mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-13 13:07:32 +00:00
[improvement](chat) Support sqlGenerationMode in python interface (#521)
This commit is contained in:
@@ -8,6 +8,7 @@ import com.tencent.supersonic.chat.parser.sql.llm.OutputFormat;
|
||||
import com.tencent.supersonic.chat.parser.sql.llm.SqlGeneration;
|
||||
import com.tencent.supersonic.chat.parser.sql.llm.SqlGenerationFactory;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
@@ -35,7 +36,8 @@ public class JavaLLMProxy implements LLMProxy {
|
||||
|
||||
public LLMResp query2sql(LLMReq llmReq, String modelClusterKey) {
|
||||
|
||||
SqlGeneration sqlGeneration = SqlGenerationFactory.get(llmReq.getSqlGenerationMode());
|
||||
SqlGeneration sqlGeneration = SqlGenerationFactory.get(
|
||||
SqlGenerationMode.valueOf(llmReq.getSqlGenerationMode()));
|
||||
String modelName = llmReq.getSchema().getModelName();
|
||||
Map<String, Double> sqlWeight = sqlGeneration.generation(llmReq, modelClusterKey);
|
||||
|
||||
|
||||
@@ -132,7 +132,7 @@ public class LLMRequestService {
|
||||
currentDate = DateUtils.getBeforeDate(0);
|
||||
}
|
||||
llmReq.setCurrentDate(currentDate);
|
||||
llmReq.setSqlGenerationMode(optimizationConfig.getSqlGenerationMode());
|
||||
llmReq.setSqlGenerationMode(optimizationConfig.getSqlGenerationMode().getName());
|
||||
return llmReq;
|
||||
}
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ public class LLMReq {
|
||||
|
||||
private String priorExts;
|
||||
|
||||
private SqlGenerationMode sqlGenerationMode;
|
||||
private String sqlGenerationMode;
|
||||
|
||||
@Data
|
||||
public static class ElementValue {
|
||||
|
||||
@@ -48,7 +48,7 @@ async def query2sql(query_body: Mapping[str, Any]):
|
||||
if 'sqlGenerationMode' not in query_body:
|
||||
raise HTTPException(status_code=400, detail="sql_generation_mode is not in query_body")
|
||||
else:
|
||||
sql_generation_mode = query_body['sql_generation_mode']
|
||||
sql_generation_mode = query_body['sqlGenerationMode']
|
||||
|
||||
model_name = schema['modelName']
|
||||
fields_list = schema['fieldNameList']
|
||||
|
||||
Reference in New Issue
Block a user