[improvement](chat) Change llmparser to pyllm, retrieve LLMProxy from environment variables, defaulting to JavaLLMProxy. (#497)

This commit is contained in:
lexluo09
2023-12-12 14:29:35 +08:00
committed by GitHub
parent 49bb2c6d8b
commit 73899e3174
17 changed files with 84 additions and 72 deletions

View File

@@ -4,7 +4,9 @@ import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
import lombok.extern.slf4j.Slf4j;
/**
* Perform SQL corrections on the "From" section in S2SQL.
*/
@Slf4j
public class FromCorrector extends BaseSemanticCorrector {

View File

@@ -18,7 +18,7 @@ import java.util.Set;
import java.util.stream.Collectors;
/**
* Perform SQL corrections on the "group by" section in S2SQL.
* Perform SQL corrections on the "Group by" section in S2SQL.
*/
@Slf4j
public class GroupByCorrector extends BaseSemanticCorrector {

View File

@@ -13,11 +13,13 @@ import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.model.chat.ChatLanguageModel;
import java.util.Objects;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
/**
* LLMProxy based on langchain4j Java version.
*/
@Slf4j
@Component
public class JavaLLMProxy implements LLMProxy {
@Override

View File

@@ -19,6 +19,7 @@ import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Component;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriComponentsBuilder;
@@ -26,6 +27,7 @@ import org.springframework.web.util.UriComponentsBuilder;
* PythonLLMProxy sends requests to LangChain-based python service.
*/
@Slf4j
@Component
public class PythonLLMProxy implements LLMProxy {
@Override

View File

@@ -2,7 +2,6 @@ package com.tencent.supersonic.chat.parser.plugin.embedding;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.parser.LLMProxy;
import com.tencent.supersonic.chat.parser.PythonLLMProxy;
import com.tencent.supersonic.chat.parser.plugin.ParseMode;
import com.tencent.supersonic.chat.parser.plugin.PluginParser;
@@ -27,12 +26,10 @@ import org.springframework.util.CollectionUtils;
@Slf4j
public class EmbeddingRecallParser extends PluginParser {
protected LLMProxy llmInterpreter = ComponentFactory.getLLMProxy();
@Override
public boolean checkPreCondition(QueryContext queryContext) {
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
if (StringUtils.isBlank(embeddingConfig.getUrl()) && llmInterpreter instanceof PythonLLMProxy) {
if (StringUtils.isBlank(embeddingConfig.getUrl()) && ComponentFactory.getLLMProxy() instanceof PythonLLMProxy) {
return false;
}
List<Plugin> plugins = getPluginList(queryContext);

View File

@@ -2,7 +2,6 @@ package com.tencent.supersonic.chat.parser.plugin.function;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.parser.PythonLLMProxy;
import com.tencent.supersonic.chat.parser.LLMProxy;
import com.tencent.supersonic.chat.parser.plugin.ParseMode;
import com.tencent.supersonic.chat.parser.plugin.PluginParser;
import com.tencent.supersonic.chat.plugin.Plugin;
@@ -27,13 +26,11 @@ import org.springframework.util.CollectionUtils;
@Slf4j
public class FunctionCallParser extends PluginParser {
protected LLMProxy llmInterpreter = ComponentFactory.getLLMProxy();
@Override
public boolean checkPreCondition(QueryContext queryContext) {
FunctionCallConfig functionCallConfig = ContextUtils.getBean(FunctionCallConfig.class);
String functionUrl = functionCallConfig.getUrl();
if (StringUtils.isBlank(functionUrl) && llmInterpreter instanceof PythonLLMProxy) {
if (StringUtils.isBlank(functionUrl) && ComponentFactory.getLLMProxy() instanceof PythonLLMProxy) {
log.info("functionUrl:{}, skip function parser, queryText:{}", functionUrl,
queryContext.getRequest().getQueryText());
return false;
@@ -84,7 +81,7 @@ public class FunctionCallParser extends PluginParser {
FunctionReq functionReq = FunctionReq.builder()
.queryText(queryContext.getRequest().getQueryText())
.pluginConfigs(pluginToFunctionCall).build();
functionResp = llmInterpreter.requestFunction(functionReq);
functionResp = ComponentFactory.getLLMProxy().requestFunction(functionReq);
}
return functionResp;
}

View File

@@ -12,7 +12,6 @@ import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.config.LLMParserConfig;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.chat.parser.LLMProxy;
import com.tencent.supersonic.chat.parser.SatisfactionChecker;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
@@ -26,13 +25,6 @@ import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.semantic.api.model.pojo.SchemaItem;
import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashSet;
@@ -42,13 +34,17 @@ import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
@Slf4j
@Service
public class LLMRequestService {
protected LLMProxy llmProxy = ComponentFactory.getLLMProxy();
protected SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer();
@Autowired
private LLMParserConfig llmParserConfig;
@@ -60,7 +56,7 @@ public class LLMRequestService {
private OptimizationConfig optimizationConfig;
public boolean isSkip(QueryContext queryCtx) {
if (llmProxy.isSkip(queryCtx)) {
if (ComponentFactory.getLLMProxy().isSkip(queryCtx)) {
return true;
}
if (SatisfactionChecker.isSkip(queryCtx)) {
@@ -140,7 +136,7 @@ public class LLMRequestService {
}
public LLMResp requestLLM(LLMReq llmReq, String modelClusterKey) {
return llmProxy.query2sql(llmReq, modelClusterKey);
return ComponentFactory.getLLMProxy().query2sql(llmReq, modelClusterKey);
}
protected List<String> getFieldNameList(QueryContext queryCtx, ModelCluster modelCluster,

View File

@@ -4,16 +4,22 @@ import com.tencent.supersonic.chat.api.component.SchemaMapper;
import com.tencent.supersonic.chat.api.component.SemanticCorrector;
import com.tencent.supersonic.chat.api.component.SemanticInterpreter;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.parser.JavaLLMProxy;
import com.tencent.supersonic.chat.parser.LLMProxy;
import com.tencent.supersonic.chat.parser.sql.llm.ModelResolver;
import com.tencent.supersonic.chat.processor.ParseResultProcessor;
import com.tencent.supersonic.chat.query.QueryResponder;
import com.tencent.supersonic.common.util.ContextUtils;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.core.io.support.SpringFactoriesLoader;
@Slf4j
public class ComponentFactory {
private static List<SchemaMapper> schemaMappers = new ArrayList<>();
@@ -57,9 +63,21 @@ public class ComponentFactory {
}
public static LLMProxy getLLMProxy() {
if (Objects.isNull(llmProxy)) {
llmProxy = init(LLMProxy.class);
//1.Preferentially retrieve from environment variables
String llmProxyEnv = System.getenv("llmProxy");
if (StringUtils.isNotBlank(llmProxyEnv)) {
Map<String, LLMProxy> implementations = ContextUtils.getBeansOfType(LLMProxy.class);
llmProxy = implementations.entrySet().stream()
.filter(entry -> entry.getKey().equalsIgnoreCase(llmProxyEnv))
.map(Map.Entry::getValue)
.findFirst()
.orElse(null);
}
//2.default JavaLLMProxy
if (Objects.isNull(llmProxy)) {
llmProxy = ContextUtils.getBean(JavaLLMProxy.class);
}
log.info("llmProxy:{}", llmProxy);
return llmProxy;
}