diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/EmbedLLMProxy.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/EmbedLLMProxy.java index bb371c247..4be84f15d 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/EmbedLLMProxy.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/EmbedLLMProxy.java @@ -1,68 +1,42 @@ package com.tencent.supersonic.chat.parser; -import com.tencent.supersonic.chat.config.OptimizationConfig; +import com.tencent.supersonic.chat.api.pojo.QueryContext; import com.tencent.supersonic.chat.parser.plugin.function.FunctionCallPromptGenerator; -import com.tencent.supersonic.chat.parser.sql.llm.prompt.OutputFormat; -import com.tencent.supersonic.chat.parser.sql.llm.prompt.SqlExampleLoader; -import com.tencent.supersonic.chat.parser.sql.llm.prompt.SqlPromptGenerator; import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq; import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp; +import com.tencent.supersonic.chat.parser.sql.llm.SqlGeneration; +import com.tencent.supersonic.chat.parser.sql.llm.SqlGenerationFactory; +import com.tencent.supersonic.chat.parser.sql.llm.prompt.OutputFormat; import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq; -import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue; import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp; import com.tencent.supersonic.common.util.ContextUtils; -import com.tencent.supersonic.common.util.JsonUtil; -import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.model.chat.ChatLanguageModel; -import dev.langchain4j.model.input.Prompt; -import dev.langchain4j.model.input.PromptTemplate; -import dev.langchain4j.model.output.Response; +import java.util.Objects; import lombok.extern.slf4j.Slf4j; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - @Slf4j public class EmbedLLMProxy implements LLMProxy { + @Override + public boolean 
isSkip(QueryContext queryContext) { + ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class); + if (Objects.isNull(chatLanguageModel)) { + log.warn("chatLanguageModel is null, skip EmbedLLMProxy"); + return true; + } + return false; + } + public LLMResp query2sql(LLMReq llmReq, String modelClusterKey) { - ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class); - - SqlExampleLoader sqlExampleLoader = ContextUtils.getBean(SqlExampleLoader.class); - - OptimizationConfig config = ContextUtils.getBean(OptimizationConfig.class); - - List> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(), - config.getText2sqlCollectionName(), config.getText2sqlFewShotsNum()); - - String queryText = llmReq.getQueryText(); + SqlGeneration sqlGeneration = SqlGenerationFactory.get(llmReq.getSqlGenerationMode()); String modelName = llmReq.getSchema().getModelName(); - List fieldNameList = llmReq.getSchema().getFieldNameList(); - List linking = llmReq.getLinking(); - - SqlPromptGenerator sqlPromptGenerator = ContextUtils.getBean(SqlPromptGenerator.class); - String linkingPromptStr = sqlPromptGenerator.generateSchemaLinkingPrompt(queryText, modelName, fieldNameList, - linking, sqlExamples); - - Prompt linkingPrompt = PromptTemplate.from(JsonUtil.toString(linkingPromptStr)).apply(new HashMap<>()); - Response linkingResult = chatLanguageModel.generate(linkingPrompt.toSystemMessage()); - - String schemaLinkStr = OutputFormat.schemaLinkParse(linkingResult.content().text()); - - String generateSqlPrompt = sqlPromptGenerator.generateSqlPrompt(queryText, modelName, schemaLinkStr, - llmReq.getCurrentDate(), sqlExamples); - - Prompt sqlPrompt = PromptTemplate.from(JsonUtil.toString(generateSqlPrompt)).apply(new HashMap<>()); - Response sqlResult = chatLanguageModel.generate(sqlPrompt.toSystemMessage()); + String sql = sqlGeneration.generation(llmReq, modelClusterKey); LLMResp result = new LLMResp(); - 
result.setQuery(queryText); - result.setSchemaLinkingOutput(linkingPromptStr); - result.setSchemaLinkStr(schemaLinkStr); + result.setQuery(llmReq.getQueryText()); result.setModelName(modelName); - result.setSqlOutput(sqlResult.content().text()); + result.setSqlOutput(sql); return result; } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/LLMProxy.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/LLMProxy.java index 8f242f102..ba47ea5c2 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/LLMProxy.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/LLMProxy.java @@ -1,5 +1,6 @@ package com.tencent.supersonic.chat.parser; +import com.tencent.supersonic.chat.api.pojo.QueryContext; import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq; import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp; import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq; @@ -12,6 +13,8 @@ import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp; */ public interface LLMProxy { + boolean isSkip(QueryContext queryContext); + LLMResp query2sql(LLMReq llmReq, String modelClusterKey); FunctionResp requestFunction(FunctionReq functionReq); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/PythonLLMProxy.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/PythonLLMProxy.java index 4b457499e..78d573b97 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/PythonLLMProxy.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/PythonLLMProxy.java @@ -1,6 +1,7 @@ package com.tencent.supersonic.chat.parser; import com.alibaba.fastjson.JSON; +import com.tencent.supersonic.chat.api.pojo.QueryContext; import com.tencent.supersonic.chat.config.LLMParserConfig; import com.tencent.supersonic.chat.parser.plugin.function.FunctionCallConfig; import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq; @@ -12,6 +13,7 @@ import 
com.tencent.supersonic.common.util.JsonUtil; import java.net.URI; import java.net.URL; import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; import org.springframework.http.HttpEntity; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; @@ -26,6 +28,16 @@ import org.springframework.web.util.UriComponentsBuilder; @Slf4j public class PythonLLMProxy implements LLMProxy { + @Override + public boolean isSkip(QueryContext queryContext) { + LLMParserConfig llmParserConfig = ContextUtils.getBean(LLMParserConfig.class); + if (StringUtils.isEmpty(llmParserConfig.getUrl())) { + log.warn("llmParserUrl is empty, skip PythonLLMProxy, config:{}", llmParserConfig); + return true; + } + return false; + } + public LLMResp query2sql(LLMReq llmReq, String modelClusterKey) { long startTime = System.currentTimeMillis(); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/LLMRequestService.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/LLMRequestService.java index 2df434f09..eab9d75f8 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/LLMRequestService.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/LLMRequestService.java @@ -60,13 +60,11 @@ public class LLMRequestService { private OptimizationConfig optimizationConfig; public boolean isSkip(QueryContext queryCtx) { - QueryReq request = queryCtx.getRequest(); - if (StringUtils.isEmpty(llmParserConfig.getUrl())) { - log.info("llm parser url is empty, skip {} , llmParserConfig:{}", LLMSqlParser.class, llmParserConfig); + if (llmProxy.isSkip(queryCtx)) { return true; } if (SatisfactionChecker.isSkip(queryCtx)) { - log.info("skip {}, queryText:{}", LLMSqlParser.class, request.getQueryText()); + log.info("skip {}, queryText:{}", LLMSqlParser.class, queryCtx.getRequest().getQueryText()); return true; } return false; @@ -104,7 +102,7 @@ public class LLMRequestService { } 
public LLMReq getLlmReq(QueryContext queryCtx, SemanticSchema semanticSchema, - ModelCluster modelCluster, List linkingValues) { + ModelCluster modelCluster, List linkingValues) { Map modelIdToName = semanticSchema.getModelIdToName(); String queryText = queryCtx.getRequest().getQueryText(); @@ -146,7 +144,7 @@ public class LLMRequestService { } protected List getFieldNameList(QueryContext queryCtx, ModelCluster modelCluster, - LLMParserConfig llmParserConfig) { + LLMParserConfig llmParserConfig) { Set results = getTopNFieldNames(modelCluster, llmParserConfig); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/SqlGeneration.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/SqlGeneration.java new file mode 100644 index 000000000..7970828d0 --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/SqlGeneration.java @@ -0,0 +1,10 @@ +package com.tencent.supersonic.chat.parser.sql.llm; + + +import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq; + +public interface SqlGeneration { + + String generation(LLMReq llmReq, String modelClusterKey); + +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/SqlGenerationFactory.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/SqlGenerationFactory.java new file mode 100644 index 000000000..cd6fd0bc2 --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/SqlGenerationFactory.java @@ -0,0 +1,22 @@ +package com.tencent.supersonic.chat.parser.sql.llm; + + +import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** + * Sql generation factory + */ +public class SqlGenerationFactory { + + private static Map sqlGenerationMap = new ConcurrentHashMap<>(); + + public static SqlGeneration get(SqlGenerationMode strategyType) { + return sqlGenerationMap.get(strategyType); + } + + 
public static void addSqlGenerationForFactory(SqlGenerationMode strategy, SqlGeneration sqlGeneration) { + sqlGenerationMap.put(strategy, sqlGeneration); + } +} \ No newline at end of file diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/TwoStepsSqlGeneration.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/TwoStepsSqlGeneration.java new file mode 100644 index 000000000..54839bbd5 --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/TwoStepsSqlGeneration.java @@ -0,0 +1,73 @@ +package com.tencent.supersonic.chat.parser.sql.llm; + + +import com.tencent.supersonic.chat.config.OptimizationConfig; +import com.tencent.supersonic.chat.parser.sql.llm.prompt.OutputFormat; +import com.tencent.supersonic.chat.parser.sql.llm.prompt.SqlExampleLoader; +import com.tencent.supersonic.chat.parser.sql.llm.prompt.SqlPromptGenerator; +import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq; +import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue; +import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode; +import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.common.util.JsonUtil; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.input.Prompt; +import dev.langchain4j.model.input.PromptTemplate; +import dev.langchain4j.model.output.Response; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.InitializingBean; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Service; + +@Service +@Slf4j +public class TwoStepsSqlGeneration implements SqlGeneration, InitializingBean { + + @Autowired + private ChatLanguageModel chatLanguageModel; + + @Autowired + private SqlExampleLoader sqlExampleLoader; + + 
@Autowired + private OptimizationConfig optimizationConfig; + + @Override + public String generation(LLMReq llmReq, String modelClusterKey) { + String text2sqlCollectionName = optimizationConfig.getText2sqlCollectionName(); + int text2sqlFewShotsNum = optimizationConfig.getText2sqlFewShotsNum(); + String queryText = llmReq.getQueryText(); + + List> sqlExamples = sqlExampleLoader.retrieverSqlExamples(queryText, text2sqlCollectionName, + text2sqlFewShotsNum); + + String modelName = llmReq.getSchema().getModelName(); + List fieldNameList = llmReq.getSchema().getFieldNameList(); + List linking = llmReq.getLinking(); + + SqlPromptGenerator sqlPromptGenerator = ContextUtils.getBean(SqlPromptGenerator.class); + String linkingPromptStr = sqlPromptGenerator.generateSchemaLinkingPrompt(queryText, modelName, fieldNameList, + linking, sqlExamples); + + Prompt linkingPrompt = PromptTemplate.from(JsonUtil.toString(linkingPromptStr)).apply(new HashMap<>()); + Response linkingResult = chatLanguageModel.generate(linkingPrompt.toSystemMessage()); + + String schemaLinkStr = OutputFormat.schemaLinkParse(linkingResult.content().text()); + + String generateSqlPrompt = sqlPromptGenerator.generateSqlPrompt(queryText, modelName, schemaLinkStr, + llmReq.getCurrentDate(), sqlExamples); + + Prompt sqlPrompt = PromptTemplate.from(JsonUtil.toString(generateSqlPrompt)).apply(new HashMap<>()); + Response sqlResult = chatLanguageModel.generate(sqlPrompt.toSystemMessage()); + return sqlResult.content().text(); + } + + @Override + public void afterPropertiesSet() { + SqlGenerationFactory.addSqlGenerationForFactory(SqlGenerationMode.TWO_STEPS, this); + } +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/s2sql/LLMReq.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/s2sql/LLMReq.java index 78984f9c8..cb5c215ce 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/s2sql/LLMReq.java +++ 
b/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/s2sql/LLMReq.java @@ -18,6 +18,8 @@ public class LLMReq { private String priorExts; + private SqlGenerationMode sqlGenerationMode = SqlGenerationMode.TWO_STEPS; + @Data public static class ElementValue { @@ -43,4 +45,25 @@ public class LLMReq { private String tableName; } + + public enum SqlGenerationMode { + + ONE_STEP("ONE_STEP"), + + TWO_STEPS("TWO_STEPS"), + + TWO_STEPS_WITH_CS("TWO_STEPS_WITH_CS"); + + + private String name; + + SqlGenerationMode(String name) { + this.name = name; + } + + public String getName() { + return name; + } + + } } diff --git a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/dictionary/builder/WordBuilderFactory.java b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/dictionary/builder/WordBuilderFactory.java index cc6b6fdf8..eb8c12354 100644 --- a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/dictionary/builder/WordBuilderFactory.java +++ b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/dictionary/builder/WordBuilderFactory.java @@ -23,4 +23,4 @@ public class WordBuilderFactory { public static BaseWordBuilder get(DictWordType strategyType) { return wordNatures.get(strategyType); } -} +} \ No newline at end of file diff --git a/launchers/standalone/src/main/resources/application-local.yaml b/launchers/standalone/src/main/resources/application-local.yaml index 142ad652e..096bc7226 100644 --- a/launchers/standalone/src/main/resources/application-local.yaml +++ b/launchers/standalone/src/main/resources/application-local.yaml @@ -53,20 +53,23 @@ langchain4j: temperature: 0.0 timeout: PT60S #2.embedding-model + #2.1 in_memory embedding-model: provider: in_memory -# embedding-model: -# hugging-face: -# access-token: hg_access_token -# model-id: sentence-transformers/all-MiniLM-L6-v2 -# timeout: 1h - + #2.2 open_ai # embedding-model: # provider: open_ai # openai: # api-key: api_key # modelName: all-minilm-l6-v2.onnx + 
#2.3 hugging_face +# embedding-model: +# provider: hugging_face +# hugging-face: +# access-token: hg_access_token +# model-id: sentence-transformers/all-MiniLM-L6-v2 +# timeout: 1h #langchain4j log logging: