mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 20:51:48 +00:00
[improvement](chat) Support sqlGenerationMode in python interface (#521)
This commit is contained in:
@@ -8,6 +8,7 @@ import com.tencent.supersonic.chat.parser.sql.llm.OutputFormat;
|
|||||||
import com.tencent.supersonic.chat.parser.sql.llm.SqlGeneration;
|
import com.tencent.supersonic.chat.parser.sql.llm.SqlGeneration;
|
||||||
import com.tencent.supersonic.chat.parser.sql.llm.SqlGenerationFactory;
|
import com.tencent.supersonic.chat.parser.sql.llm.SqlGenerationFactory;
|
||||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||||
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
@@ -35,7 +36,8 @@ public class JavaLLMProxy implements LLMProxy {
|
|||||||
|
|
||||||
public LLMResp query2sql(LLMReq llmReq, String modelClusterKey) {
|
public LLMResp query2sql(LLMReq llmReq, String modelClusterKey) {
|
||||||
|
|
||||||
SqlGeneration sqlGeneration = SqlGenerationFactory.get(llmReq.getSqlGenerationMode());
|
SqlGeneration sqlGeneration = SqlGenerationFactory.get(
|
||||||
|
SqlGenerationMode.valueOf(llmReq.getSqlGenerationMode()));
|
||||||
String modelName = llmReq.getSchema().getModelName();
|
String modelName = llmReq.getSchema().getModelName();
|
||||||
Map<String, Double> sqlWeight = sqlGeneration.generation(llmReq, modelClusterKey);
|
Map<String, Double> sqlWeight = sqlGeneration.generation(llmReq, modelClusterKey);
|
||||||
|
|
||||||
|
|||||||
@@ -132,7 +132,7 @@ public class LLMRequestService {
|
|||||||
currentDate = DateUtils.getBeforeDate(0);
|
currentDate = DateUtils.getBeforeDate(0);
|
||||||
}
|
}
|
||||||
llmReq.setCurrentDate(currentDate);
|
llmReq.setCurrentDate(currentDate);
|
||||||
llmReq.setSqlGenerationMode(optimizationConfig.getSqlGenerationMode());
|
llmReq.setSqlGenerationMode(optimizationConfig.getSqlGenerationMode().getName());
|
||||||
return llmReq;
|
return llmReq;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ public class LLMReq {
|
|||||||
|
|
||||||
private String priorExts;
|
private String priorExts;
|
||||||
|
|
||||||
private SqlGenerationMode sqlGenerationMode;
|
private String sqlGenerationMode;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public static class ElementValue {
|
public static class ElementValue {
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ async def query2sql(query_body: Mapping[str, Any]):
|
|||||||
if 'sqlGenerationMode' not in query_body:
|
if 'sqlGenerationMode' not in query_body:
|
||||||
raise HTTPException(status_code=400, detail="sql_generation_mode is not in query_body")
|
raise HTTPException(status_code=400, detail="sql_generation_mode is not in query_body")
|
||||||
else:
|
else:
|
||||||
sql_generation_mode = query_body['sql_generation_mode']
|
sql_generation_mode = query_body['sqlGenerationMode']
|
||||||
|
|
||||||
model_name = schema['modelName']
|
model_name = schema['modelName']
|
||||||
fields_list = schema['fieldNameList']
|
fields_list = schema['fieldNameList']
|
||||||
|
|||||||
Reference in New Issue
Block a user