From 276b224c13bbb180c4538ff498bc7caf2097972e Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Sat, 16 Dec 2023 17:24:24 +0800 Subject: [PATCH] [improvement](chat) Support sqlGenerationMode in python interface (#521) --- .../java/com/tencent/supersonic/chat/parser/JavaLLMProxy.java | 4 +++- .../supersonic/chat/parser/sql/llm/LLMRequestService.java | 2 +- .../com/tencent/supersonic/chat/query/llm/s2sql/LLMReq.java | 2 +- chat/python/services_router/query2sql_service.py | 2 +- 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/JavaLLMProxy.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/JavaLLMProxy.java index d02ee0f26..57083c176 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/JavaLLMProxy.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/JavaLLMProxy.java @@ -8,6 +8,7 @@ import com.tencent.supersonic.chat.parser.sql.llm.OutputFormat; import com.tencent.supersonic.chat.parser.sql.llm.SqlGeneration; import com.tencent.supersonic.chat.parser.sql.llm.SqlGenerationFactory; import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq; +import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode; import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp; import com.tencent.supersonic.common.util.ContextUtils; import dev.langchain4j.model.chat.ChatLanguageModel; @@ -35,7 +36,8 @@ public class JavaLLMProxy implements LLMProxy { public LLMResp query2sql(LLMReq llmReq, String modelClusterKey) { - SqlGeneration sqlGeneration = SqlGenerationFactory.get(llmReq.getSqlGenerationMode()); + SqlGeneration sqlGeneration = SqlGenerationFactory.get( + SqlGenerationMode.valueOf(llmReq.getSqlGenerationMode())); String modelName = llmReq.getSchema().getModelName(); Map sqlWeight = sqlGeneration.generation(llmReq, modelClusterKey); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/LLMRequestService.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/LLMRequestService.java index 391d86893..9388184a8 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/LLMRequestService.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/LLMRequestService.java @@ -132,7 +132,7 @@ public class LLMRequestService { currentDate = DateUtils.getBeforeDate(0); } llmReq.setCurrentDate(currentDate); - llmReq.setSqlGenerationMode(optimizationConfig.getSqlGenerationMode()); + llmReq.setSqlGenerationMode(optimizationConfig.getSqlGenerationMode().getName()); return llmReq; } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/s2sql/LLMReq.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/s2sql/LLMReq.java index f4a78ab2a..c11676201 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/s2sql/LLMReq.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/s2sql/LLMReq.java @@ -19,7 +19,7 @@ public class LLMReq { private String priorExts; - private SqlGenerationMode sqlGenerationMode; + private String sqlGenerationMode; @Data public static class ElementValue { diff --git a/chat/python/services_router/query2sql_service.py b/chat/python/services_router/query2sql_service.py index 75c6262d2..76bada405 100644 --- a/chat/python/services_router/query2sql_service.py +++ b/chat/python/services_router/query2sql_service.py @@ -48,7 +48,7 @@ async def query2sql(query_body: Mapping[str, Any]): if 'sqlGenerationMode' not in query_body: raise HTTPException(status_code=400, detail="sql_generation_mode is not in query_body") else: - sql_generation_mode = query_body['sql_generation_mode'] + sql_generation_mode = query_body['sqlGenerationMode'] model_name = schema['modelName'] fields_list = schema['fieldNameList']