mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 03:58:14 +00:00
Add SqlGeneration abstraction and implementation, optimize LLMSqlParser skip. (#487)
This commit is contained in:
@@ -1,68 +1,42 @@
|
|||||||
package com.tencent.supersonic.chat.parser;
|
package com.tencent.supersonic.chat.parser;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionCallPromptGenerator;
|
import com.tencent.supersonic.chat.parser.plugin.function.FunctionCallPromptGenerator;
|
||||||
import com.tencent.supersonic.chat.parser.sql.llm.prompt.OutputFormat;
|
|
||||||
import com.tencent.supersonic.chat.parser.sql.llm.prompt.SqlExampleLoader;
|
|
||||||
import com.tencent.supersonic.chat.parser.sql.llm.prompt.SqlPromptGenerator;
|
|
||||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq;
|
import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq;
|
||||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
|
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
|
||||||
|
import com.tencent.supersonic.chat.parser.sql.llm.SqlGeneration;
|
||||||
|
import com.tencent.supersonic.chat.parser.sql.llm.SqlGenerationFactory;
|
||||||
|
import com.tencent.supersonic.chat.parser.sql.llm.prompt.OutputFormat;
|
||||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
|
|
||||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
|
||||||
import dev.langchain4j.data.message.AiMessage;
|
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
import dev.langchain4j.model.input.Prompt;
|
import java.util.Objects;
|
||||||
import dev.langchain4j.model.input.PromptTemplate;
|
|
||||||
import dev.langchain4j.model.output.Response;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class EmbedLLMProxy implements LLMProxy {
|
public class EmbedLLMProxy implements LLMProxy {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isSkip(QueryContext queryContext) {
|
||||||
|
ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class);
|
||||||
|
if (Objects.isNull(chatLanguageModel)) {
|
||||||
|
log.warn("chatLanguageModel is null, skip EmbedLLMProxy");
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
public LLMResp query2sql(LLMReq llmReq, String modelClusterKey) {
|
public LLMResp query2sql(LLMReq llmReq, String modelClusterKey) {
|
||||||
|
|
||||||
ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class);
|
SqlGeneration sqlGeneration = SqlGenerationFactory.get(llmReq.getSqlGenerationMode());
|
||||||
|
|
||||||
SqlExampleLoader sqlExampleLoader = ContextUtils.getBean(SqlExampleLoader.class);
|
|
||||||
|
|
||||||
OptimizationConfig config = ContextUtils.getBean(OptimizationConfig.class);
|
|
||||||
|
|
||||||
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
|
||||||
config.getText2sqlCollectionName(), config.getText2sqlFewShotsNum());
|
|
||||||
|
|
||||||
String queryText = llmReq.getQueryText();
|
|
||||||
String modelName = llmReq.getSchema().getModelName();
|
String modelName = llmReq.getSchema().getModelName();
|
||||||
List<String> fieldNameList = llmReq.getSchema().getFieldNameList();
|
String sql = sqlGeneration.generation(llmReq, modelClusterKey);
|
||||||
List<ElementValue> linking = llmReq.getLinking();
|
|
||||||
|
|
||||||
SqlPromptGenerator sqlPromptGenerator = ContextUtils.getBean(SqlPromptGenerator.class);
|
|
||||||
String linkingPromptStr = sqlPromptGenerator.generateSchemaLinkingPrompt(queryText, modelName, fieldNameList,
|
|
||||||
linking, sqlExamples);
|
|
||||||
|
|
||||||
Prompt linkingPrompt = PromptTemplate.from(JsonUtil.toString(linkingPromptStr)).apply(new HashMap<>());
|
|
||||||
Response<AiMessage> linkingResult = chatLanguageModel.generate(linkingPrompt.toSystemMessage());
|
|
||||||
|
|
||||||
String schemaLinkStr = OutputFormat.schemaLinkParse(linkingResult.content().text());
|
|
||||||
|
|
||||||
String generateSqlPrompt = sqlPromptGenerator.generateSqlPrompt(queryText, modelName, schemaLinkStr,
|
|
||||||
llmReq.getCurrentDate(), sqlExamples);
|
|
||||||
|
|
||||||
Prompt sqlPrompt = PromptTemplate.from(JsonUtil.toString(generateSqlPrompt)).apply(new HashMap<>());
|
|
||||||
Response<AiMessage> sqlResult = chatLanguageModel.generate(sqlPrompt.toSystemMessage());
|
|
||||||
|
|
||||||
LLMResp result = new LLMResp();
|
LLMResp result = new LLMResp();
|
||||||
result.setQuery(queryText);
|
result.setQuery(llmReq.getQueryText());
|
||||||
result.setSchemaLinkingOutput(linkingPromptStr);
|
|
||||||
result.setSchemaLinkStr(schemaLinkStr);
|
|
||||||
result.setModelName(modelName);
|
result.setModelName(modelName);
|
||||||
result.setSqlOutput(sqlResult.content().text());
|
result.setSqlOutput(sql);
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
package com.tencent.supersonic.chat.parser;
|
package com.tencent.supersonic.chat.parser;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq;
|
import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq;
|
||||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
|
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
|
||||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||||
@@ -12,6 +13,8 @@ import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
|||||||
*/
|
*/
|
||||||
public interface LLMProxy {
|
public interface LLMProxy {
|
||||||
|
|
||||||
|
boolean isSkip(QueryContext queryContext);
|
||||||
|
|
||||||
LLMResp query2sql(LLMReq llmReq, String modelClusterKey);
|
LLMResp query2sql(LLMReq llmReq, String modelClusterKey);
|
||||||
|
|
||||||
FunctionResp requestFunction(FunctionReq functionReq);
|
FunctionResp requestFunction(FunctionReq functionReq);
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package com.tencent.supersonic.chat.parser;
|
package com.tencent.supersonic.chat.parser;
|
||||||
|
|
||||||
import com.alibaba.fastjson.JSON;
|
import com.alibaba.fastjson.JSON;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.chat.config.LLMParserConfig;
|
import com.tencent.supersonic.chat.config.LLMParserConfig;
|
||||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionCallConfig;
|
import com.tencent.supersonic.chat.parser.plugin.function.FunctionCallConfig;
|
||||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq;
|
import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq;
|
||||||
@@ -12,6 +13,7 @@ import com.tencent.supersonic.common.util.JsonUtil;
|
|||||||
import java.net.URI;
|
import java.net.URI;
|
||||||
import java.net.URL;
|
import java.net.URL;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.springframework.http.HttpEntity;
|
import org.springframework.http.HttpEntity;
|
||||||
import org.springframework.http.HttpHeaders;
|
import org.springframework.http.HttpHeaders;
|
||||||
import org.springframework.http.HttpMethod;
|
import org.springframework.http.HttpMethod;
|
||||||
@@ -26,6 +28,16 @@ import org.springframework.web.util.UriComponentsBuilder;
|
|||||||
@Slf4j
|
@Slf4j
|
||||||
public class PythonLLMProxy implements LLMProxy {
|
public class PythonLLMProxy implements LLMProxy {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isSkip(QueryContext queryContext) {
|
||||||
|
LLMParserConfig llmParserConfig = ContextUtils.getBean(LLMParserConfig.class);
|
||||||
|
if (StringUtils.isEmpty(llmParserConfig.getUrl())) {
|
||||||
|
log.warn("llmParserUrl is empty, skip PythonLLMProxy, config:{}", llmParserConfig);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
public LLMResp query2sql(LLMReq llmReq, String modelClusterKey) {
|
public LLMResp query2sql(LLMReq llmReq, String modelClusterKey) {
|
||||||
|
|
||||||
long startTime = System.currentTimeMillis();
|
long startTime = System.currentTimeMillis();
|
||||||
|
|||||||
@@ -60,13 +60,11 @@ public class LLMRequestService {
|
|||||||
private OptimizationConfig optimizationConfig;
|
private OptimizationConfig optimizationConfig;
|
||||||
|
|
||||||
public boolean isSkip(QueryContext queryCtx) {
|
public boolean isSkip(QueryContext queryCtx) {
|
||||||
QueryReq request = queryCtx.getRequest();
|
if (llmProxy.isSkip(queryCtx)) {
|
||||||
if (StringUtils.isEmpty(llmParserConfig.getUrl())) {
|
|
||||||
log.info("llm parser url is empty, skip {} , llmParserConfig:{}", LLMSqlParser.class, llmParserConfig);
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
if (SatisfactionChecker.isSkip(queryCtx)) {
|
if (SatisfactionChecker.isSkip(queryCtx)) {
|
||||||
log.info("skip {}, queryText:{}", LLMSqlParser.class, request.getQueryText());
|
log.info("skip {}, queryText:{}", LLMSqlParser.class, queryCtx.getRequest().getQueryText());
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
@@ -104,7 +102,7 @@ public class LLMRequestService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public LLMReq getLlmReq(QueryContext queryCtx, SemanticSchema semanticSchema,
|
public LLMReq getLlmReq(QueryContext queryCtx, SemanticSchema semanticSchema,
|
||||||
ModelCluster modelCluster, List<ElementValue> linkingValues) {
|
ModelCluster modelCluster, List<ElementValue> linkingValues) {
|
||||||
Map<Long, String> modelIdToName = semanticSchema.getModelIdToName();
|
Map<Long, String> modelIdToName = semanticSchema.getModelIdToName();
|
||||||
String queryText = queryCtx.getRequest().getQueryText();
|
String queryText = queryCtx.getRequest().getQueryText();
|
||||||
|
|
||||||
@@ -146,7 +144,7 @@ public class LLMRequestService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
protected List<String> getFieldNameList(QueryContext queryCtx, ModelCluster modelCluster,
|
protected List<String> getFieldNameList(QueryContext queryCtx, ModelCluster modelCluster,
|
||||||
LLMParserConfig llmParserConfig) {
|
LLMParserConfig llmParserConfig) {
|
||||||
|
|
||||||
Set<String> results = getTopNFieldNames(modelCluster, llmParserConfig);
|
Set<String> results = getTopNFieldNames(modelCluster, llmParserConfig);
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,10 @@
|
|||||||
|
package com.tencent.supersonic.chat.parser.sql.llm;
|
||||||
|
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||||
|
|
||||||
|
public interface SqlGeneration {
|
||||||
|
|
||||||
|
String generation(LLMReq llmReq, String modelClusterKey);
|
||||||
|
|
||||||
|
}
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
package com.tencent.supersonic.chat.parser.sql.llm;
|
||||||
|
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sql generation factory
|
||||||
|
*/
|
||||||
|
public class SqlGenerationFactory {
|
||||||
|
|
||||||
|
private static Map<SqlGenerationMode, SqlGeneration> sqlGenerationMap = new ConcurrentHashMap<>();
|
||||||
|
|
||||||
|
public static SqlGeneration get(SqlGenerationMode strategyType) {
|
||||||
|
return sqlGenerationMap.get(strategyType);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void addSqlGenerationForFactory(SqlGenerationMode strategy, SqlGeneration sqlGeneration) {
|
||||||
|
sqlGenerationMap.put(strategy, sqlGeneration);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,73 @@
|
|||||||
|
package com.tencent.supersonic.chat.parser.sql.llm;
|
||||||
|
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||||
|
import com.tencent.supersonic.chat.parser.sql.llm.prompt.OutputFormat;
|
||||||
|
import com.tencent.supersonic.chat.parser.sql.llm.prompt.SqlExampleLoader;
|
||||||
|
import com.tencent.supersonic.chat.parser.sql.llm.prompt.SqlPromptGenerator;
|
||||||
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||||
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
|
||||||
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||||
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
|
import dev.langchain4j.data.message.AiMessage;
|
||||||
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
|
import dev.langchain4j.model.input.Prompt;
|
||||||
|
import dev.langchain4j.model.input.PromptTemplate;
|
||||||
|
import dev.langchain4j.model.output.Response;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.beans.factory.InitializingBean;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
@Service
|
||||||
|
@Slf4j
|
||||||
|
public class TwoStepsSqlGeneration implements SqlGeneration, InitializingBean {
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private ChatLanguageModel chatLanguageModel;
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private SqlExampleLoader sqlExampleLoader;
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private OptimizationConfig optimizationConfig;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String generation(LLMReq llmReq, String modelClusterKey) {
|
||||||
|
String text2sqlCollectionName = optimizationConfig.getText2sqlCollectionName();
|
||||||
|
int text2sqlFewShotsNum = optimizationConfig.getText2sqlFewShotsNum();
|
||||||
|
String queryText = llmReq.getQueryText();
|
||||||
|
|
||||||
|
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(queryText, text2sqlCollectionName,
|
||||||
|
text2sqlFewShotsNum);
|
||||||
|
|
||||||
|
String modelName = llmReq.getSchema().getModelName();
|
||||||
|
List<String> fieldNameList = llmReq.getSchema().getFieldNameList();
|
||||||
|
List<ElementValue> linking = llmReq.getLinking();
|
||||||
|
|
||||||
|
SqlPromptGenerator sqlPromptGenerator = ContextUtils.getBean(SqlPromptGenerator.class);
|
||||||
|
String linkingPromptStr = sqlPromptGenerator.generateSchemaLinkingPrompt(queryText, modelName, fieldNameList,
|
||||||
|
linking, sqlExamples);
|
||||||
|
|
||||||
|
Prompt linkingPrompt = PromptTemplate.from(JsonUtil.toString(linkingPromptStr)).apply(new HashMap<>());
|
||||||
|
Response<AiMessage> linkingResult = chatLanguageModel.generate(linkingPrompt.toSystemMessage());
|
||||||
|
|
||||||
|
String schemaLinkStr = OutputFormat.schemaLinkParse(linkingResult.content().text());
|
||||||
|
|
||||||
|
String generateSqlPrompt = sqlPromptGenerator.generateSqlPrompt(queryText, modelName, schemaLinkStr,
|
||||||
|
llmReq.getCurrentDate(), sqlExamples);
|
||||||
|
|
||||||
|
Prompt sqlPrompt = PromptTemplate.from(JsonUtil.toString(generateSqlPrompt)).apply(new HashMap<>());
|
||||||
|
Response<AiMessage> sqlResult = chatLanguageModel.generate(sqlPrompt.toSystemMessage());
|
||||||
|
return sqlResult.content().text();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void afterPropertiesSet() {
|
||||||
|
SqlGenerationFactory.addSqlGenerationForFactory(SqlGenerationMode.TWO_STEPS, this);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -18,6 +18,8 @@ public class LLMReq {
|
|||||||
|
|
||||||
private String priorExts;
|
private String priorExts;
|
||||||
|
|
||||||
|
private SqlGenerationMode sqlGenerationMode = SqlGenerationMode.TWO_STEPS;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public static class ElementValue {
|
public static class ElementValue {
|
||||||
|
|
||||||
@@ -43,4 +45,25 @@ public class LLMReq {
|
|||||||
|
|
||||||
private String tableName;
|
private String tableName;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public enum SqlGenerationMode {
|
||||||
|
|
||||||
|
ONE_STEP("ONE_STEP"),
|
||||||
|
|
||||||
|
TWO_STEPS("TWO_STEPS"),
|
||||||
|
|
||||||
|
TWO_STEPS_WITH_CS("TWO_STEPS_WITH_CS");
|
||||||
|
|
||||||
|
|
||||||
|
private String name;
|
||||||
|
|
||||||
|
SqlGenerationMode(String name) {
|
||||||
|
this.name = name;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getName() {
|
||||||
|
return name;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,4 +23,4 @@ public class WordBuilderFactory {
|
|||||||
public static BaseWordBuilder get(DictWordType strategyType) {
|
public static BaseWordBuilder get(DictWordType strategyType) {
|
||||||
return wordNatures.get(strategyType);
|
return wordNatures.get(strategyType);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -53,20 +53,23 @@ langchain4j:
|
|||||||
temperature: 0.0
|
temperature: 0.0
|
||||||
timeout: PT60S
|
timeout: PT60S
|
||||||
#2.embedding-model
|
#2.embedding-model
|
||||||
|
#2.1 in_memory
|
||||||
embedding-model:
|
embedding-model:
|
||||||
provider: in_memory
|
provider: in_memory
|
||||||
# embedding-model:
|
#2.2 open_ai
|
||||||
# hugging-face:
|
|
||||||
# access-token: hg_access_token
|
|
||||||
# model-id: sentence-transformers/all-MiniLM-L6-v2
|
|
||||||
# timeout: 1h
|
|
||||||
|
|
||||||
# embedding-model:
|
# embedding-model:
|
||||||
# provider: open_ai
|
# provider: open_ai
|
||||||
# openai:
|
# openai:
|
||||||
# api-key: api_key
|
# api-key: api_key
|
||||||
# modelName: all-minilm-l6-v2.onnx
|
# modelName: all-minilm-l6-v2.onnx
|
||||||
|
|
||||||
|
#2.3 hugging_face
|
||||||
|
# embedding-model:
|
||||||
|
# provider: hugging_face
|
||||||
|
# hugging-face:
|
||||||
|
# access-token: hg_access_token
|
||||||
|
# model-id: sentence-transformers/all-MiniLM-L6-v2
|
||||||
|
# timeout: 1h
|
||||||
|
|
||||||
#langchain4j log
|
#langchain4j log
|
||||||
logging:
|
logging:
|
||||||
|
|||||||
Reference in New Issue
Block a user