mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 03:58:14 +00:00
Add SqlGeneration abstraction and implementation, optimize LLMSqlParser skip. (#487)
This commit is contained in:
@@ -1,68 +1,42 @@
|
||||
package com.tencent.supersonic.chat.parser;
|
||||
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionCallPromptGenerator;
|
||||
import com.tencent.supersonic.chat.parser.sql.llm.prompt.OutputFormat;
|
||||
import com.tencent.supersonic.chat.parser.sql.llm.prompt.SqlExampleLoader;
|
||||
import com.tencent.supersonic.chat.parser.sql.llm.prompt.SqlPromptGenerator;
|
||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq;
|
||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
|
||||
import com.tencent.supersonic.chat.parser.sql.llm.SqlGeneration;
|
||||
import com.tencent.supersonic.chat.parser.sql.llm.SqlGenerationFactory;
|
||||
import com.tencent.supersonic.chat.parser.sql.llm.prompt.OutputFormat;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import java.util.Objects;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@Slf4j
|
||||
public class EmbedLLMProxy implements LLMProxy {
|
||||
|
||||
@Override
|
||||
public boolean isSkip(QueryContext queryContext) {
|
||||
ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class);
|
||||
if (Objects.isNull(chatLanguageModel)) {
|
||||
log.warn("chatLanguageModel is null, skip EmbedLLMProxy");
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
public LLMResp query2sql(LLMReq llmReq, String modelClusterKey) {
|
||||
|
||||
ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class);
|
||||
|
||||
SqlExampleLoader sqlExampleLoader = ContextUtils.getBean(SqlExampleLoader.class);
|
||||
|
||||
OptimizationConfig config = ContextUtils.getBean(OptimizationConfig.class);
|
||||
|
||||
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
||||
config.getText2sqlCollectionName(), config.getText2sqlFewShotsNum());
|
||||
|
||||
String queryText = llmReq.getQueryText();
|
||||
SqlGeneration sqlGeneration = SqlGenerationFactory.get(llmReq.getSqlGenerationMode());
|
||||
String modelName = llmReq.getSchema().getModelName();
|
||||
List<String> fieldNameList = llmReq.getSchema().getFieldNameList();
|
||||
List<ElementValue> linking = llmReq.getLinking();
|
||||
|
||||
SqlPromptGenerator sqlPromptGenerator = ContextUtils.getBean(SqlPromptGenerator.class);
|
||||
String linkingPromptStr = sqlPromptGenerator.generateSchemaLinkingPrompt(queryText, modelName, fieldNameList,
|
||||
linking, sqlExamples);
|
||||
|
||||
Prompt linkingPrompt = PromptTemplate.from(JsonUtil.toString(linkingPromptStr)).apply(new HashMap<>());
|
||||
Response<AiMessage> linkingResult = chatLanguageModel.generate(linkingPrompt.toSystemMessage());
|
||||
|
||||
String schemaLinkStr = OutputFormat.schemaLinkParse(linkingResult.content().text());
|
||||
|
||||
String generateSqlPrompt = sqlPromptGenerator.generateSqlPrompt(queryText, modelName, schemaLinkStr,
|
||||
llmReq.getCurrentDate(), sqlExamples);
|
||||
|
||||
Prompt sqlPrompt = PromptTemplate.from(JsonUtil.toString(generateSqlPrompt)).apply(new HashMap<>());
|
||||
Response<AiMessage> sqlResult = chatLanguageModel.generate(sqlPrompt.toSystemMessage());
|
||||
String sql = sqlGeneration.generation(llmReq, modelClusterKey);
|
||||
|
||||
LLMResp result = new LLMResp();
|
||||
result.setQuery(queryText);
|
||||
result.setSchemaLinkingOutput(linkingPromptStr);
|
||||
result.setSchemaLinkStr(schemaLinkStr);
|
||||
result.setQuery(llmReq.getQueryText());
|
||||
result.setModelName(modelName);
|
||||
result.setSqlOutput(sqlResult.content().text());
|
||||
result.setSqlOutput(sql);
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.tencent.supersonic.chat.parser;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq;
|
||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
@@ -12,6 +13,8 @@ import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||
*/
|
||||
public interface LLMProxy {
|
||||
|
||||
boolean isSkip(QueryContext queryContext);
|
||||
|
||||
LLMResp query2sql(LLMReq llmReq, String modelClusterKey);
|
||||
|
||||
FunctionResp requestFunction(FunctionReq functionReq);
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.tencent.supersonic.chat.parser;
|
||||
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.config.LLMParserConfig;
|
||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionCallConfig;
|
||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq;
|
||||
@@ -12,6 +13,7 @@ import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import java.net.URI;
|
||||
import java.net.URL;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.http.HttpEntity;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpMethod;
|
||||
@@ -26,6 +28,16 @@ import org.springframework.web.util.UriComponentsBuilder;
|
||||
@Slf4j
|
||||
public class PythonLLMProxy implements LLMProxy {
|
||||
|
||||
@Override
|
||||
public boolean isSkip(QueryContext queryContext) {
|
||||
LLMParserConfig llmParserConfig = ContextUtils.getBean(LLMParserConfig.class);
|
||||
if (StringUtils.isEmpty(llmParserConfig.getUrl())) {
|
||||
log.warn("llmParserUrl is empty, skip PythonLLMProxy, config:{}", llmParserConfig);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
public LLMResp query2sql(LLMReq llmReq, String modelClusterKey) {
|
||||
|
||||
long startTime = System.currentTimeMillis();
|
||||
|
||||
@@ -60,13 +60,11 @@ public class LLMRequestService {
|
||||
private OptimizationConfig optimizationConfig;
|
||||
|
||||
public boolean isSkip(QueryContext queryCtx) {
|
||||
QueryReq request = queryCtx.getRequest();
|
||||
if (StringUtils.isEmpty(llmParserConfig.getUrl())) {
|
||||
log.info("llm parser url is empty, skip {} , llmParserConfig:{}", LLMSqlParser.class, llmParserConfig);
|
||||
if (llmProxy.isSkip(queryCtx)) {
|
||||
return true;
|
||||
}
|
||||
if (SatisfactionChecker.isSkip(queryCtx)) {
|
||||
log.info("skip {}, queryText:{}", LLMSqlParser.class, request.getQueryText());
|
||||
log.info("skip {}, queryText:{}", LLMSqlParser.class, queryCtx.getRequest().getQueryText());
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
package com.tencent.supersonic.chat.parser.sql.llm;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
|
||||
public interface SqlGeneration {
|
||||
|
||||
String generation(LLMReq llmReq, String modelClusterKey);
|
||||
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
package com.tencent.supersonic.chat.parser.sql.llm;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
/**
|
||||
* Sql generation factory
|
||||
*/
|
||||
public class SqlGenerationFactory {
|
||||
|
||||
private static Map<SqlGenerationMode, SqlGeneration> sqlGenerationMap = new ConcurrentHashMap<>();
|
||||
|
||||
public static SqlGeneration get(SqlGenerationMode strategyType) {
|
||||
return sqlGenerationMap.get(strategyType);
|
||||
}
|
||||
|
||||
public static void addSqlGenerationForFactory(SqlGenerationMode strategy, SqlGeneration sqlGeneration) {
|
||||
sqlGenerationMap.put(strategy, sqlGeneration);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
package com.tencent.supersonic.chat.parser.sql.llm;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.parser.sql.llm.prompt.OutputFormat;
|
||||
import com.tencent.supersonic.chat.parser.sql.llm.prompt.SqlExampleLoader;
|
||||
import com.tencent.supersonic.chat.parser.sql.llm.prompt.SqlPromptGenerator;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class TwoStepsSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
|
||||
@Autowired
|
||||
private ChatLanguageModel chatLanguageModel;
|
||||
|
||||
@Autowired
|
||||
private SqlExampleLoader sqlExampleLoader;
|
||||
|
||||
@Autowired
|
||||
private OptimizationConfig optimizationConfig;
|
||||
|
||||
@Override
|
||||
public String generation(LLMReq llmReq, String modelClusterKey) {
|
||||
String text2sqlCollectionName = optimizationConfig.getText2sqlCollectionName();
|
||||
int text2sqlFewShotsNum = optimizationConfig.getText2sqlFewShotsNum();
|
||||
String queryText = llmReq.getQueryText();
|
||||
|
||||
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(queryText, text2sqlCollectionName,
|
||||
text2sqlFewShotsNum);
|
||||
|
||||
String modelName = llmReq.getSchema().getModelName();
|
||||
List<String> fieldNameList = llmReq.getSchema().getFieldNameList();
|
||||
List<ElementValue> linking = llmReq.getLinking();
|
||||
|
||||
SqlPromptGenerator sqlPromptGenerator = ContextUtils.getBean(SqlPromptGenerator.class);
|
||||
String linkingPromptStr = sqlPromptGenerator.generateSchemaLinkingPrompt(queryText, modelName, fieldNameList,
|
||||
linking, sqlExamples);
|
||||
|
||||
Prompt linkingPrompt = PromptTemplate.from(JsonUtil.toString(linkingPromptStr)).apply(new HashMap<>());
|
||||
Response<AiMessage> linkingResult = chatLanguageModel.generate(linkingPrompt.toSystemMessage());
|
||||
|
||||
String schemaLinkStr = OutputFormat.schemaLinkParse(linkingResult.content().text());
|
||||
|
||||
String generateSqlPrompt = sqlPromptGenerator.generateSqlPrompt(queryText, modelName, schemaLinkStr,
|
||||
llmReq.getCurrentDate(), sqlExamples);
|
||||
|
||||
Prompt sqlPrompt = PromptTemplate.from(JsonUtil.toString(generateSqlPrompt)).apply(new HashMap<>());
|
||||
Response<AiMessage> sqlResult = chatLanguageModel.generate(sqlPrompt.toSystemMessage());
|
||||
return sqlResult.content().text();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
SqlGenerationFactory.addSqlGenerationForFactory(SqlGenerationMode.TWO_STEPS, this);
|
||||
}
|
||||
}
|
||||
@@ -18,6 +18,8 @@ public class LLMReq {
|
||||
|
||||
private String priorExts;
|
||||
|
||||
private SqlGenerationMode sqlGenerationMode = SqlGenerationMode.TWO_STEPS;
|
||||
|
||||
@Data
|
||||
public static class ElementValue {
|
||||
|
||||
@@ -43,4 +45,25 @@ public class LLMReq {
|
||||
|
||||
private String tableName;
|
||||
}
|
||||
|
||||
public enum SqlGenerationMode {
|
||||
|
||||
ONE_STEP("ONE_STEP"),
|
||||
|
||||
TWO_STEPS("TWO_STEPS"),
|
||||
|
||||
TWO_STEPS_WITH_CS("TWO_STEPS_WITH_CS");
|
||||
|
||||
|
||||
private String name;
|
||||
|
||||
SqlGenerationMode(String name) {
|
||||
this.name = name;
|
||||
}
|
||||
|
||||
public String getName() {
|
||||
return name;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,20 +53,23 @@ langchain4j:
|
||||
temperature: 0.0
|
||||
timeout: PT60S
|
||||
#2.embedding-model
|
||||
#2.1 in_memory
|
||||
embedding-model:
|
||||
provider: in_memory
|
||||
# embedding-model:
|
||||
# hugging-face:
|
||||
# access-token: hg_access_token
|
||||
# model-id: sentence-transformers/all-MiniLM-L6-v2
|
||||
# timeout: 1h
|
||||
|
||||
#2.2 open_ai
|
||||
# embedding-model:
|
||||
# provider: open_ai
|
||||
# openai:
|
||||
# api-key: api_key
|
||||
# modelName: all-minilm-l6-v2.onnx
|
||||
|
||||
#2.2 hugging_face
|
||||
# embedding-model:
|
||||
# provider: hugging_face
|
||||
# hugging-face:
|
||||
# access-token: hg_access_token
|
||||
# model-id: sentence-transformers/all-MiniLM-L6-v2
|
||||
# timeout: 1h
|
||||
|
||||
#langchain4j log
|
||||
logging:
|
||||
|
||||
Reference in New Issue
Block a user