Add SqlGeneration abstraction and implementation, optimize LLMSqlParser skip. (#487)

This commit is contained in:
lexluo09
2023-12-10 11:34:48 +08:00
committed by GitHub
parent 0e0ba51750
commit 6af661459c
10 changed files with 176 additions and 58 deletions

View File

@@ -1,68 +1,42 @@
package com.tencent.supersonic.chat.parser;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.parser.plugin.function.FunctionCallPromptGenerator;
import com.tencent.supersonic.chat.parser.sql.llm.prompt.OutputFormat;
import com.tencent.supersonic.chat.parser.sql.llm.prompt.SqlExampleLoader;
import com.tencent.supersonic.chat.parser.sql.llm.prompt.SqlPromptGenerator;
import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq;
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
import com.tencent.supersonic.chat.parser.sql.llm.SqlGeneration;
import com.tencent.supersonic.chat.parser.sql.llm.SqlGenerationFactory;
import com.tencent.supersonic.chat.parser.sql.llm.prompt.OutputFormat;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import java.util.Objects;
import lombok.extern.slf4j.Slf4j;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@Slf4j
public class EmbedLLMProxy implements LLMProxy {
/**
 * Decides whether this proxy should be skipped.
 *
 * <p>The proxy is skipped when {@link ContextUtils#getBean} yields no
 * {@link ChatLanguageModel}; a warning is logged in that case. The query context
 * itself is not inspected.
 *
 * @param queryContext the current query context (unused here)
 * @return {@code true} when no chat language model is available, {@code false} otherwise
 */
@Override
public boolean isSkip(QueryContext queryContext) {
    final boolean modelMissing = ContextUtils.getBean(ChatLanguageModel.class) == null;
    if (modelMissing) {
        log.warn("chatLanguageModel is null, skip EmbedLLMProxy");
    }
    return modelMissing;
}
/**
 * Builds an {@link LLMResp} for the given request.
 *
 * <p>As written, the body does two things in sequence: it runs an inline two-step prompt
 * flow against the {@link ChatLanguageModel} (schema linking, then SQL generation), and it
 * also delegates to the {@link SqlGeneration} returned by {@link SqlGenerationFactory} for
 * the request's generation mode. The SQL output is set twice; the second call passes the
 * delegate's result.
 *
 * <p>NOTE(review): this block looks like the removed and added sides of a diff rendered
 * together (duplicate setQuery/setSqlOutput calls, duplicated model round-trips). Confirm
 * against the real file which statements are live before changing anything here.
 *
 * @param llmReq the request carrying query text, schema, linking values and generation mode
 * @param modelClusterKey passed through unchanged to {@link SqlGeneration#generation}
 * @return the populated response
 */
public LLMResp query2sql(LLMReq llmReq, String modelClusterKey) {
// Collaborators are looked up from the application context on every call.
ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class);
SqlExampleLoader sqlExampleLoader = ContextUtils.getBean(SqlExampleLoader.class);
OptimizationConfig config = ContextUtils.getBean(OptimizationConfig.class);
// SQL examples are retrieved for the query text, using the collection name and count from config.
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
config.getText2sqlCollectionName(), config.getText2sqlFewShotsNum());
String queryText = llmReq.getQueryText();
// NOTE(review): the factory lookup is a plain map get, so this may be null when no strategy
// is registered for the requested mode; the later generation(...) call would then throw. Confirm.
SqlGeneration sqlGeneration = SqlGenerationFactory.get(llmReq.getSqlGenerationMode());
String modelName = llmReq.getSchema().getModelName();
List<String> fieldNameList = llmReq.getSchema().getFieldNameList();
List<ElementValue> linking = llmReq.getLinking();
SqlPromptGenerator sqlPromptGenerator = ContextUtils.getBean(SqlPromptGenerator.class);
// First model round-trip: schema linking prompt, answer parsed by OutputFormat.schemaLinkParse.
String linkingPromptStr = sqlPromptGenerator.generateSchemaLinkingPrompt(queryText, modelName, fieldNameList,
linking, sqlExamples);
Prompt linkingPrompt = PromptTemplate.from(JsonUtil.toString(linkingPromptStr)).apply(new HashMap<>());
Response<AiMessage> linkingResult = chatLanguageModel.generate(linkingPrompt.toSystemMessage());
String schemaLinkStr = OutputFormat.schemaLinkParse(linkingResult.content().text());
// Second model round-trip: SQL prompt built from the parsed schema links and the request's current date.
String generateSqlPrompt = sqlPromptGenerator.generateSqlPrompt(queryText, modelName, schemaLinkStr,
llmReq.getCurrentDate(), sqlExamples);
Prompt sqlPrompt = PromptTemplate.from(JsonUtil.toString(generateSqlPrompt)).apply(new HashMap<>());
Response<AiMessage> sqlResult = chatLanguageModel.generate(sqlPrompt.toSystemMessage());
// Delegated generation through the strategy selected above.
String sql = sqlGeneration.generation(llmReq, modelClusterKey);
LLMResp result = new LLMResp();
result.setQuery(queryText);
// Note: the value stored here is the linking PROMPT text, not the model's linking answer.
result.setSchemaLinkingOutput(linkingPromptStr);
result.setSchemaLinkStr(schemaLinkStr);
// Same value as queryText above (both come from llmReq.getQueryText()).
result.setQuery(llmReq.getQueryText());
result.setModelName(modelName);
result.setSqlOutput(sqlResult.content().text());
// Second setSqlOutput call; presumably replaces the inline result set on the previous line.
result.setSqlOutput(sql);
return result;
}

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.chat.parser;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq;
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
@@ -12,6 +13,8 @@ import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
*/
public interface LLMProxy {
/**
 * Tells the caller whether this proxy should be skipped for the given query.
 *
 * @param queryContext the current query context
 * @return {@code true} if the proxy should not be used, {@code false} otherwise
 */
boolean isSkip(QueryContext queryContext);
/**
 * Produces the {@link LLMResp} for a query-to-SQL request.
 *
 * @param llmReq the request to process
 * @param modelClusterKey key of the model cluster the request belongs to
 *        (NOTE(review): exact semantics are defined by callers not shown here)
 * @return the response produced by the implementation
 */
LLMResp query2sql(LLMReq llmReq, String modelClusterKey);
/**
 * Handles a function request and returns its response.
 *
 * @param functionReq the function request
 * @return the function response produced by the implementation
 */
FunctionResp requestFunction(FunctionReq functionReq);

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.chat.parser;
import com.alibaba.fastjson.JSON;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.config.LLMParserConfig;
import com.tencent.supersonic.chat.parser.plugin.function.FunctionCallConfig;
import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq;
@@ -12,6 +13,7 @@ import com.tencent.supersonic.common.util.JsonUtil;
import java.net.URI;
import java.net.URL;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
@@ -26,6 +28,16 @@ import org.springframework.web.util.UriComponentsBuilder;
@Slf4j
public class PythonLLMProxy implements LLMProxy {
/**
 * Decides whether this proxy should be skipped.
 *
 * <p>The proxy is skipped when the URL in {@link LLMParserConfig} is empty; a warning
 * including the config is logged in that case. The query context itself is not inspected.
 *
 * @param queryContext the current query context (unused here)
 * @return {@code true} when no parser URL is configured, {@code false} otherwise
 */
@Override
public boolean isSkip(QueryContext queryContext) {
    LLMParserConfig parserConfig = ContextUtils.getBean(LLMParserConfig.class);
    boolean urlMissing = StringUtils.isEmpty(parserConfig.getUrl());
    if (urlMissing) {
        log.warn("llmParserUrl is empty, skip PythonLLMProxy, config:{}", parserConfig);
    }
    return urlMissing;
}
public LLMResp query2sql(LLMReq llmReq, String modelClusterKey) {
long startTime = System.currentTimeMillis();

View File

@@ -60,13 +60,11 @@ public class LLMRequestService {
private OptimizationConfig optimizationConfig;
public boolean isSkip(QueryContext queryCtx) {
QueryReq request = queryCtx.getRequest();
if (StringUtils.isEmpty(llmParserConfig.getUrl())) {
log.info("llm parser url is empty, skip {} , llmParserConfig:{}", LLMSqlParser.class, llmParserConfig);
if (llmProxy.isSkip(queryCtx)) {
return true;
}
if (SatisfactionChecker.isSkip(queryCtx)) {
log.info("skip {}, queryText:{}", LLMSqlParser.class, request.getQueryText());
log.info("skip {}, queryText:{}", LLMSqlParser.class, queryCtx.getRequest().getQueryText());
return true;
}
return false;
@@ -104,7 +102,7 @@ public class LLMRequestService {
}
public LLMReq getLlmReq(QueryContext queryCtx, SemanticSchema semanticSchema,
ModelCluster modelCluster, List<ElementValue> linkingValues) {
ModelCluster modelCluster, List<ElementValue> linkingValues) {
Map<Long, String> modelIdToName = semanticSchema.getModelIdToName();
String queryText = queryCtx.getRequest().getQueryText();
@@ -146,7 +144,7 @@ public class LLMRequestService {
}
protected List<String> getFieldNameList(QueryContext queryCtx, ModelCluster modelCluster,
LLMParserConfig llmParserConfig) {
LLMParserConfig llmParserConfig) {
Set<String> results = getTopNFieldNames(modelCluster, llmParserConfig);

View File

@@ -0,0 +1,10 @@
package com.tencent.supersonic.chat.parser.sql.llm;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
/**
 * Strategy for turning an {@link LLMReq} into SQL text.
 */
public interface SqlGeneration {
/**
 * Generates SQL for the given request.
 *
 * @param llmReq the request to generate SQL for
 * @param modelClusterKey key of the model cluster the request belongs to
 *        (NOTE(review): how implementations use it is not visible here)
 * @return the generated SQL text
 */
String generation(LLMReq llmReq, String modelClusterKey);
}

View File

@@ -0,0 +1,22 @@
package com.tencent.supersonic.chat.parser.sql.llm;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
 * Sql generation factory.
 *
 * <p>A static registry that maps each {@link SqlGenerationMode} to the {@link SqlGeneration}
 * registered for it. Registrations are kept in a {@link ConcurrentHashMap}.
 */
public class SqlGenerationFactory {

    /** Registered strategies keyed by mode; the reference never changes after class init. */
    private static final Map<SqlGenerationMode, SqlGeneration> SQL_GENERATION_MAP = new ConcurrentHashMap<>();

    /**
     * Looks up the strategy registered for a mode.
     *
     * @param strategyType the mode to look up; may be {@code null}
     * @return the registered strategy, or {@code null} when {@code strategyType} is
     *         {@code null} or nothing has been registered for it
     */
    public static SqlGeneration get(SqlGenerationMode strategyType) {
        // ConcurrentHashMap rejects null keys with a message-less NullPointerException;
        // a null mode can never have a registration, so report it as "not registered".
        if (strategyType == null) {
            return null;
        }
        return SQL_GENERATION_MAP.get(strategyType);
    }

    /**
     * Registers (or replaces) the strategy for a mode.
     *
     * @param strategy the mode to register under; must not be {@code null}
     * @param sqlGeneration the strategy implementation; must not be {@code null}
     * @throws NullPointerException if either argument is {@code null}
     */
    public static void addSqlGenerationForFactory(SqlGenerationMode strategy, SqlGeneration sqlGeneration) {
        // Same exception type the map itself would throw, but with a message that names the culprit.
        if (strategy == null) {
            throw new NullPointerException("strategy must not be null");
        }
        if (sqlGeneration == null) {
            throw new NullPointerException("sqlGeneration must not be null");
        }
        SQL_GENERATION_MAP.put(strategy, sqlGeneration);
    }
}

View File

@@ -0,0 +1,73 @@
package com.tencent.supersonic.chat.parser.sql.llm;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.chat.parser.sql.llm.prompt.OutputFormat;
import com.tencent.supersonic.chat.parser.sql.llm.prompt.SqlExampleLoader;
import com.tencent.supersonic.chat.parser.sql.llm.prompt.SqlPromptGenerator;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
/**
 * {@link SqlGeneration} implementation that queries the chat model twice: once with a
 * schema-linking prompt, and once with a SQL prompt built from the parsed linking answer.
 *
 * <p>The bean registers itself with {@link SqlGenerationFactory} under
 * {@link SqlGenerationMode#TWO_STEPS} once its properties have been set.
 */
@Service
@Slf4j
public class TwoStepsSqlGeneration implements SqlGeneration, InitializingBean {

    @Autowired
    private ChatLanguageModel chatLanguageModel;

    @Autowired
    private SqlExampleLoader sqlExampleLoader;

    @Autowired
    private OptimizationConfig optimizationConfig;

    /**
     * Generates SQL for the request in two model round-trips.
     *
     * @param llmReq the request carrying query text, schema, linking values and current date
     * @param modelClusterKey not used by this implementation
     * @return the text of the model's answer to the SQL prompt
     */
    @Override
    public String generation(LLMReq llmReq, String modelClusterKey) {
        // Examples fed into both prompts, retrieved for this query text.
        String collection = optimizationConfig.getText2sqlCollectionName();
        int fewShots = optimizationConfig.getText2sqlFewShotsNum();
        String question = llmReq.getQueryText();
        List<Map<String, String>> examples = sqlExampleLoader.retrieverSqlExamples(question, collection, fewShots);

        String model = llmReq.getSchema().getModelName();
        List<String> fields = llmReq.getSchema().getFieldNameList();
        List<ElementValue> linkedValues = llmReq.getLinking();
        SqlPromptGenerator promptGenerator = ContextUtils.getBean(SqlPromptGenerator.class);

        // Step 1: schema linking.
        String linkingPromptText = promptGenerator.generateSchemaLinkingPrompt(question, model, fields,
                linkedValues, examples);
        String schemaLinks = OutputFormat.schemaLinkParse(askModel(linkingPromptText));

        // Step 2: SQL generation based on the parsed schema links.
        String sqlPromptText = promptGenerator.generateSqlPrompt(question, model, schemaLinks,
                llmReq.getCurrentDate(), examples);
        return askModel(sqlPromptText);
    }

    /** Registers this instance as the strategy for {@link SqlGenerationMode#TWO_STEPS}. */
    @Override
    public void afterPropertiesSet() {
        SqlGenerationFactory.addSqlGenerationForFactory(SqlGenerationMode.TWO_STEPS, this);
    }

    /** Sends the prompt text to the chat model as a system message and returns the reply text. */
    private String askModel(String promptText) {
        Prompt prompt = PromptTemplate.from(JsonUtil.toString(promptText)).apply(new HashMap<>());
        Response<AiMessage> reply = chatLanguageModel.generate(prompt.toSystemMessage());
        return reply.content().text();
    }
}

View File

@@ -18,6 +18,8 @@ public class LLMReq {
private String priorExts;
private SqlGenerationMode sqlGenerationMode = SqlGenerationMode.TWO_STEPS;
@Data
public static class ElementValue {
@@ -43,4 +45,25 @@ public class LLMReq {
private String tableName;
}
/**
 * Modes that select which SQL generation strategy handles a request.
 *
 * <p>Each constant carries a name string equal to its own identifier.
 */
public enum SqlGenerationMode {

    ONE_STEP("ONE_STEP"),

    TWO_STEPS("TWO_STEPS"),

    TWO_STEPS_WITH_CS("TWO_STEPS_WITH_CS");

    /** Name of the mode; final because enum constants are shared singletons and must stay immutable. */
    private final String name;

    SqlGenerationMode(String name) {
        this.name = name;
    }

    /**
     * Returns the name given to this mode at construction.
     *
     * @return the mode name
     */
    public String getName() {
        return name;
    }
}
}

View File

@@ -23,4 +23,4 @@ public class WordBuilderFactory {
/**
 * Returns whatever {@code wordNatures} holds for the given dictionary word type.
 *
 * @param strategyType the word type to look up
 * @return the result of {@code wordNatures.get(strategyType)}
 */
public static BaseWordBuilder get(DictWordType strategyType) {
    final BaseWordBuilder builder = wordNatures.get(strategyType);
    return builder;
}
}
}