(improvement)(chat) Support configuring embeddingModel or embeddingStore at the agent level. (#1361)

This commit is contained in:
lexluo09
2024-07-06 20:44:23 +08:00
committed by GitHub
parent d39db734c4
commit 6db6aaf98d
42 changed files with 669 additions and 299 deletions

View File

@@ -3,12 +3,12 @@ package com.tencent.supersonic.headless.api.pojo.request;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
import com.tencent.supersonic.common.config.LLMConfig;
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
import lombok.Data;
@@ -27,7 +27,7 @@ public class QueryReq {
private MapModeEnum mapModeEnum = MapModeEnum.STRICT;
private SchemaMapInfo mapInfo = new SchemaMapInfo();
private QueryDataType queryDataType = QueryDataType.ALL;
private LLMConfig llmConfig;
private ChatModelConfig llmConfig;
private PromptConfig promptConfig;
private List<SqlExemplar> dynamicExemplars = Lists.newArrayList();
}

View File

@@ -2,10 +2,11 @@ package com.tencent.supersonic.headless.chat;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.config.ModelConfig;
import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.config.LLMConfig;
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
@@ -50,7 +51,8 @@ public class QueryContext {
@JsonIgnore
private WorkflowState workflowState;
private QueryDataType queryDataType = QueryDataType.ALL;
private LLMConfig llmConfig;
private ModelConfig modelConfig;
private ChatModelConfig llmConfig;
private PromptConfig promptConfig;
private List<SqlExemplar> dynamicExemplars;

View File

@@ -1,8 +1,5 @@
package com.tencent.supersonic.headless.chat.parser.llm;
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_LINKING_VALUE_ENABLE;
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_STRATEGY_TYPE;
import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
@@ -17,6 +14,12 @@ import com.tencent.supersonic.headless.chat.parser.SatisfactionChecker;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.headless.chat.utils.ComponentFactory;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
@@ -26,11 +29,9 @@ import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_LINKING_VALUE_ENABLE;
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_STRATEGY_TYPE;
@Slf4j
@Service
@@ -101,6 +102,7 @@ public class LLMRequestService {
llmReq.setCurrentDate(DateUtils.getBeforeDate(0));
llmReq.setSqlGenType(LLMReq.SqlGenType.valueOf(parserConfig.getParameterValue(PARSER_STRATEGY_TYPE)));
llmReq.setModelConfig(queryCtx.getModelConfig());
llmReq.setLlmConfig(queryCtx.getLlmConfig());
llmReq.setPromptConfig(queryCtx.getPromptConfig());
llmReq.setDynamicExemplars(queryCtx.getDynamicExemplars());
@@ -118,7 +120,7 @@ public class LLMRequestService {
}
protected List<String> getFieldNameList(QueryContext queryCtx, Long dataSetId,
LLMParserConfig llmParserConfig) {
LLMParserConfig llmParserConfig) {
Set<String> results = getTopNFieldNames(queryCtx, dataSetId, llmParserConfig);

View File

@@ -1,10 +1,10 @@
package com.tencent.supersonic.headless.chat.parser.llm;
import com.tencent.supersonic.common.config.LLMConfig;
import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.provider.ChatLanguageModelProvider;
import dev.langchain4j.provider.ModelProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.InitializingBean;
@@ -23,8 +23,8 @@ public abstract class SqlGenStrategy implements InitializingBean {
@Autowired
protected PromptHelper promptHelper;
protected ChatLanguageModel getChatLanguageModel(LLMConfig llmConfig) {
return ChatLanguageModelProvider.provide(llmConfig);
protected ChatLanguageModel getChatLanguageModel(ChatModelConfig llmConfig) {
return ModelProvider.provideChatModel(llmConfig);
}
abstract LLMResp generate(LLMReq llmReq);

View File

@@ -2,7 +2,8 @@ package com.tencent.supersonic.headless.chat.query.llm.s2sql;
import com.fasterxml.jackson.annotation.JsonValue;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.config.LLMConfig;
import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.config.ModelConfig;
import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.common.pojo.SqlExemplar;
@@ -27,8 +28,8 @@ public class LLMReq {
private SqlGenType sqlGenType;
private LLMConfig llmConfig;
private ModelConfig modelConfig;
private ChatModelConfig llmConfig;
private PromptConfig promptConfig;
private List<SqlExemplar> dynamicExemplars;