[improvement](chat) Support sqlGenerationMode in python interface (#521)

This commit is contained in:
lexluo09
2023-12-16 17:24:24 +08:00
committed by GitHub
parent f03da53d6f
commit 276b224c13
4 changed files with 6 additions and 4 deletions

View File

@@ -8,6 +8,7 @@ import com.tencent.supersonic.chat.parser.sql.llm.OutputFormat;
import com.tencent.supersonic.chat.parser.sql.llm.SqlGeneration;
import com.tencent.supersonic.chat.parser.sql.llm.SqlGenerationFactory;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.model.chat.ChatLanguageModel;
@@ -35,7 +36,8 @@ public class JavaLLMProxy implements LLMProxy {
public LLMResp query2sql(LLMReq llmReq, String modelClusterKey) {
SqlGeneration sqlGeneration = SqlGenerationFactory.get(llmReq.getSqlGenerationMode());
SqlGeneration sqlGeneration = SqlGenerationFactory.get(
SqlGenerationMode.valueOf(llmReq.getSqlGenerationMode()));
String modelName = llmReq.getSchema().getModelName();
Map<String, Double> sqlWeight = sqlGeneration.generation(llmReq, modelClusterKey);

View File

@@ -132,7 +132,7 @@ public class LLMRequestService {
currentDate = DateUtils.getBeforeDate(0);
}
llmReq.setCurrentDate(currentDate);
llmReq.setSqlGenerationMode(optimizationConfig.getSqlGenerationMode());
llmReq.setSqlGenerationMode(optimizationConfig.getSqlGenerationMode().getName());
return llmReq;
}

View File

@@ -19,7 +19,7 @@ public class LLMReq {
private String priorExts;
private SqlGenerationMode sqlGenerationMode;
private String sqlGenerationMode;
@Data
public static class ElementValue {

View File

@@ -48,7 +48,7 @@ async def query2sql(query_body: Mapping[str, Any]):
if 'sqlGenerationMode' not in query_body:
raise HTTPException(status_code=400, detail="sql_generation_mode is not in query_body")
else:
sql_generation_mode = query_body['sql_generation_mode']
sql_generation_mode = query_body['sqlGenerationMode']
model_name = schema['modelName']
fields_list = schema['fieldNameList']