[improvement](chat) Support sqlGenerationMode in python interface (#521)

This commit is contained in:
lexluo09
2023-12-16 17:24:24 +08:00
committed by GitHub
parent f03da53d6f
commit 276b224c13
4 changed files with 6 additions and 4 deletions

View File

@@ -8,6 +8,7 @@ import com.tencent.supersonic.chat.parser.sql.llm.OutputFormat;
import com.tencent.supersonic.chat.parser.sql.llm.SqlGeneration; import com.tencent.supersonic.chat.parser.sql.llm.SqlGeneration;
import com.tencent.supersonic.chat.parser.sql.llm.SqlGenerationFactory; import com.tencent.supersonic.chat.parser.sql.llm.SqlGenerationFactory;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp; import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.ChatLanguageModel;
@@ -35,7 +36,8 @@ public class JavaLLMProxy implements LLMProxy {
public LLMResp query2sql(LLMReq llmReq, String modelClusterKey) { public LLMResp query2sql(LLMReq llmReq, String modelClusterKey) {
SqlGeneration sqlGeneration = SqlGenerationFactory.get(llmReq.getSqlGenerationMode()); SqlGeneration sqlGeneration = SqlGenerationFactory.get(
SqlGenerationMode.valueOf(llmReq.getSqlGenerationMode()));
String modelName = llmReq.getSchema().getModelName(); String modelName = llmReq.getSchema().getModelName();
Map<String, Double> sqlWeight = sqlGeneration.generation(llmReq, modelClusterKey); Map<String, Double> sqlWeight = sqlGeneration.generation(llmReq, modelClusterKey);

View File

@@ -132,7 +132,7 @@ public class LLMRequestService {
currentDate = DateUtils.getBeforeDate(0); currentDate = DateUtils.getBeforeDate(0);
} }
llmReq.setCurrentDate(currentDate); llmReq.setCurrentDate(currentDate);
llmReq.setSqlGenerationMode(optimizationConfig.getSqlGenerationMode()); llmReq.setSqlGenerationMode(optimizationConfig.getSqlGenerationMode().getName());
return llmReq; return llmReq;
} }

View File

@@ -19,7 +19,7 @@ public class LLMReq {
private String priorExts; private String priorExts;
private SqlGenerationMode sqlGenerationMode; private String sqlGenerationMode;
@Data @Data
public static class ElementValue { public static class ElementValue {

View File

@@ -48,7 +48,7 @@ async def query2sql(query_body: Mapping[str, Any]):
if 'sqlGenerationMode' not in query_body: if 'sqlGenerationMode' not in query_body:
raise HTTPException(status_code=400, detail="sql_generation_mode is not in query_body") raise HTTPException(status_code=400, detail="sql_generation_mode is not in query_body")
else: else:
sql_generation_mode = query_body['sql_generation_mode'] sql_generation_mode = query_body['sqlGenerationMode']
model_name = schema['modelName'] model_name = schema['modelName']
fields_list = schema['fieldNameList'] fields_list = schema['fieldNameList']