mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
(improvement)(pyllm)Use HTTP parameter llm_config in place of the default llm_config
This commit is contained in:
@@ -4,11 +4,13 @@ import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryReposi
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.LLMConfig;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.core.utils.S2ChatModelProvider;
|
||||
import com.tencent.supersonic.headless.server.service.ChatQueryService;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
@@ -50,7 +52,7 @@ public class MultiTurnParser implements ChatParser {
|
||||
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||
Environment environment = ContextUtils.getBean(Environment.class);
|
||||
Boolean multiTurn = environment.getProperty("multi.turn", Boolean.class);
|
||||
if (Boolean.FALSE.equals(multiTurn)) {
|
||||
if (!Boolean.TRUE.equals(multiTurn)) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -73,6 +75,7 @@ public class MultiTurnParser implements ChatParser {
|
||||
.histQuestion(lastParseResult.getQueryText())
|
||||
.curtSchema(curtMapStr)
|
||||
.histSchema(histMapStr)
|
||||
.llmConfig(queryReq.getLlmConfig())
|
||||
.build());
|
||||
chatParseContext.setQueryText(rewrittenQuery);
|
||||
log.info("Last Query: {} Current Query: {}, Rewritten Query: {}",
|
||||
@@ -80,7 +83,6 @@ public class MultiTurnParser implements ChatParser {
|
||||
}
|
||||
|
||||
private String rewriteQuery(RewriteContext context) {
|
||||
|
||||
Map<String, Object> variables = new HashMap<>();
|
||||
variables.put("curtQuestion", context.getCurtQuestion());
|
||||
variables.put("histQuestion", context.getHistQuestion());
|
||||
@@ -89,14 +91,13 @@ public class MultiTurnParser implements ChatParser {
|
||||
|
||||
Prompt prompt = promptTemplate.apply(variables);
|
||||
keyPipelineLog.info("request prompt:{}", prompt.toSystemMessage());
|
||||
ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class);
|
||||
|
||||
ChatLanguageModel chatLanguageModel = S2ChatModelProvider.provide(context.getLlmConfig());
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
|
||||
|
||||
String result = response.content().text();
|
||||
keyPipelineLog.info("model response:{}", result);
|
||||
//3.format response.
|
||||
String rewriteQuery = response.content().text();
|
||||
|
||||
return rewriteQuery;
|
||||
return response.content().text();
|
||||
}
|
||||
|
||||
private String generateSchemaPrompt(List<SchemaElementMatch> elementMatches) {
|
||||
@@ -142,5 +143,6 @@ public class MultiTurnParser implements ChatParser {
|
||||
private String histQuestion;
|
||||
private String curtSchema;
|
||||
private String histSchema;
|
||||
private LLMConfig llmConfig;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,7 +27,8 @@ public abstract class BaseMapper implements SchemaMapper {
|
||||
|
||||
String simpleName = this.getClass().getSimpleName();
|
||||
long startTime = System.currentTimeMillis();
|
||||
log.debug("before {},mapInfo:{}", simpleName, queryContext.getMapInfo().getDataSetElementMatches());
|
||||
log.debug("before {},mapInfo:{}", simpleName,
|
||||
queryContext.getMapInfo().getDataSetElementMatches());
|
||||
|
||||
try {
|
||||
doMap(queryContext);
|
||||
@@ -37,7 +38,8 @@ public abstract class BaseMapper implements SchemaMapper {
|
||||
}
|
||||
|
||||
long cost = System.currentTimeMillis() - startTime;
|
||||
log.info("after {},cost:{},mapInfo:{}", simpleName, cost, queryContext.getMapInfo().getDataSetElementMatches());
|
||||
log.debug("after {},cost:{},mapInfo:{}", simpleName, cost,
|
||||
queryContext.getMapInfo().getDataSetElementMatches());
|
||||
}
|
||||
|
||||
private void filter(QueryContext queryContext) {
|
||||
@@ -130,7 +132,7 @@ public abstract class BaseMapper implements SchemaMapper {
|
||||
}
|
||||
SchemaElement elementDb = dataSetSchema.getElement(elementType, elementID);
|
||||
if (Objects.isNull(elementDb)) {
|
||||
log.info("element is null, elementType:{},elementID:{}", elementType, elementID);
|
||||
log.warn("element is null, elementType:{},elementID:{}", elementType, elementID);
|
||||
return null;
|
||||
}
|
||||
BeanUtils.copyProperties(elementDb, element);
|
||||
|
||||
@@ -30,7 +30,7 @@ import java.util.ArrayList;
|
||||
@Component
|
||||
public class PythonLLMProxy implements LLMProxy {
|
||||
|
||||
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||
private static final Logger keyPipelineLog = LoggerFactory.getLogger(PythonLLMProxy.class);
|
||||
|
||||
@Override
|
||||
public boolean isSkip(QueryContext queryContext) {
|
||||
|
||||
@@ -9,13 +9,13 @@ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
from config.config_parse import LLM_PROVIDER_NAME, llm_config_dict
|
||||
|
||||
|
||||
def get_llm_provider(llm_provider_name: str, llm_config_dict: dict):
|
||||
if llm_provider_name in llms.type_to_cls_dict:
|
||||
llm_provider = llms.type_to_cls_dict[llm_provider_name]
|
||||
llm = llm_provider(**llm_config_dict)
|
||||
def get_llm(llm_config: dict):
|
||||
if LLM_PROVIDER_NAME in llms.type_to_cls_dict:
|
||||
llm_provider = llms.type_to_cls_dict[LLM_PROVIDER_NAME]
|
||||
if llm_config is None:
|
||||
llm = llm_provider(**llm_config_dict)
|
||||
else:
|
||||
llm = llm_provider(**llm_config)
|
||||
return llm
|
||||
else:
|
||||
raise Exception("llm_provider_name is not supported: {}".format(llm_provider_name))
|
||||
|
||||
|
||||
llm = get_llm_provider(LLM_PROVIDER_NAME, llm_config_dict)
|
||||
raise Exception("llm_provider_name is not supported: {}".format(LLM_PROVIDER_NAME))
|
||||
@@ -12,17 +12,15 @@ from plugin_call.prompt_construct import (
|
||||
construct_task_prompt,
|
||||
plugin_selection_output_parse,
|
||||
)
|
||||
from instances.llm_instance import llm
|
||||
|
||||
# def plugin_selection_run(
|
||||
# query_text: str, plugin_configs: List[Mapping[str, Any]]
|
||||
# ) -> Union[Mapping[str, str], None]:
|
||||
|
||||
def plugin_selection_run(
|
||||
query_text: str, plugin_configs: List[Mapping[str, Any]]
|
||||
) -> Union[Mapping[str, str], None]:
|
||||
# tools_prompt = construct_plugin_pool_prompt(plugin_configs)
|
||||
|
||||
tools_prompt = construct_plugin_pool_prompt(plugin_configs)
|
||||
# task_prompt = construct_task_prompt(query_text, tools_prompt)
|
||||
# llm_output = llm(task_prompt)
|
||||
# parsed_output = plugin_selection_output_parse(llm_output)
|
||||
|
||||
task_prompt = construct_task_prompt(query_text, tools_prompt)
|
||||
llm_output = llm(task_prompt)
|
||||
parsed_output = plugin_selection_output_parse(llm_output)
|
||||
|
||||
return parsed_output
|
||||
# return parsed_output
|
||||
|
||||
@@ -14,7 +14,6 @@ import json
|
||||
from s2sql.constructor import FewShotPromptTemplate2
|
||||
from s2sql.sql_agent import Text2DSLAgent, Text2DSLAgentAutoCoT, Text2DSLAgentWrapper
|
||||
|
||||
from instances.llm_instance import llm
|
||||
from instances.chromadb_instance import client as chromadb_client
|
||||
from instances.logging_instance import logger
|
||||
from instances.text2vec_instance import emb_func
|
||||
@@ -40,9 +39,9 @@ text2dsl_agent_act_example_prompter = FewShotPromptTemplate2(collection=text2dsl
|
||||
few_shot_seperator='\n\n')
|
||||
|
||||
text2sql_agent = Text2DSLAgent(num_fewshots=TEXT2DSL_FEWSHOTS_NUM, num_examples=TEXT2DSL_EXAMPLE_NUM, num_self_consistency=TEXT2DSL_SELF_CONSISTENCY_NUM,
|
||||
sql_example_prompter=text2dsl_agent_example_prompter, llm=llm)
|
||||
sql_example_prompter=text2dsl_agent_example_prompter)
|
||||
text2sql_agent_autoCoT = Text2DSLAgentAutoCoT(num_fewshots=TEXT2DSL_FEWSHOTS_NUM, num_examples=TEXT2DSL_EXAMPLE_NUM, num_self_consistency=TEXT2DSL_SELF_CONSISTENCY_NUM,
|
||||
sql_example_prompter=text2dsl_agent_act_example_prompter, llm=llm,
|
||||
sql_example_prompter=text2dsl_agent_act_example_prompter,
|
||||
auto_cot_min_window_size=ACT_MIN_WINDOWN_SIZE, auto_cot_max_window_size=ACT_MAX_WINDOWN_SIZE)
|
||||
|
||||
sql_ids = [str(i) for i in range(0, len(sql_exemplars))]
|
||||
|
||||
@@ -17,17 +17,17 @@ from instances.logging_instance import logger
|
||||
from s2sql.constructor import FewShotPromptTemplate2
|
||||
from s2sql.output_parser import schema_link_parse, combo_schema_link_parse, combo_sql_parse
|
||||
from s2sql.auto_cot_run import transform_sql_example, transform_sql_example_autoCoT_run
|
||||
from instances.llm_instance import get_llm
|
||||
|
||||
|
||||
class Text2DSLAgentBase(object):
|
||||
def __init__(self, num_fewshots:int, num_examples:int, num_self_consistency:int,
|
||||
sql_example_prompter:FewShotPromptTemplate2, llm: BaseLLM) -> None:
|
||||
sql_example_prompter:FewShotPromptTemplate2) -> None:
|
||||
self.num_fewshots = num_fewshots
|
||||
self.num_examples = num_examples
|
||||
assert self.num_fewshots <= self.num_examples
|
||||
self.num_self_consistency = num_self_consistency
|
||||
|
||||
self.llm = llm
|
||||
self.sql_example_prompter = sql_example_prompter
|
||||
|
||||
def get_examples_candidates(self, question: str, filter_condition: Mapping[str, str], num_examples: int)->List[Mapping[str, str]]:
|
||||
@@ -82,9 +82,9 @@ class Text2DSLAgentBase(object):
|
||||
|
||||
class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
||||
def __init__(self, num_fewshots:int, num_examples:int, num_self_consistency:int,
|
||||
sql_example_prompter:FewShotPromptTemplate2, llm: BaseLLM,
|
||||
sql_example_prompter:FewShotPromptTemplate2,
|
||||
auto_cot_min_window_size: int, auto_cot_max_window_size: int):
|
||||
super().__init__(num_fewshots, num_examples, num_self_consistency, sql_example_prompter, llm)
|
||||
super().__init__(num_fewshots, num_examples, num_self_consistency, sql_example_prompter)
|
||||
|
||||
assert auto_cot_min_window_size <= auto_cot_max_window_size
|
||||
self.auto_cot_min_window_size = auto_cot_min_window_size
|
||||
@@ -218,7 +218,8 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
||||
|
||||
async def async_query2sql(self, question: str, filter_condition: Mapping[str,str],
|
||||
model_name: str, fields_list: List[str],
|
||||
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str):
|
||||
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str,
|
||||
llm_config:dict):
|
||||
logger.info("question: {}".format(question))
|
||||
logger.info("filter_condition: {}".format(filter_condition))
|
||||
logger.info("model_name: {}".format(model_name))
|
||||
@@ -230,7 +231,8 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
||||
fewshot_example_meta_list = self.get_examples_candidates(question, filter_condition, self.num_examples)
|
||||
schema_linking_prompt = self.generate_schema_linking_prompt(question, current_date, model_name, fields_list, prior_schema_links, prior_exts, fewshot_example_meta_list)
|
||||
logger.debug("schema_linking_prompt->{}".format(schema_linking_prompt))
|
||||
schema_link_output = await self.llm._call_async(schema_linking_prompt)
|
||||
llm = get_llm(llm_config)
|
||||
schema_link_output = await llm._call_async(schema_linking_prompt)
|
||||
logger.debug("schema_link_output->{}".format(schema_link_output))
|
||||
|
||||
schema_link_str = schema_link_parse(schema_link_output)
|
||||
@@ -238,7 +240,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
||||
|
||||
sql_prompt = self.generate_sql_prompt(question, model_name, fields_list, schema_link_str, current_date, prior_schema_links, prior_exts, fewshot_example_meta_list)
|
||||
logger.debug("sql_prompt->{}".format(sql_prompt))
|
||||
sql_output = await self.llm._call_async(sql_prompt)
|
||||
sql_output = await llm._call_async(sql_prompt)
|
||||
|
||||
resp = dict()
|
||||
resp['question'] = question
|
||||
@@ -261,7 +263,8 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
||||
|
||||
async def async_query2sql_shortcut(self, question: str, filter_condition: Mapping[str,str],
|
||||
model_name: str, fields_list: List[str],
|
||||
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str):
|
||||
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str,
|
||||
llm_config:dict):
|
||||
logger.info("question: {}".format(question))
|
||||
logger.info("filter_condition: {}".format(filter_condition))
|
||||
logger.info("model_name: {}".format(model_name))
|
||||
@@ -273,7 +276,8 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
||||
fewshot_example_meta_list = self.get_examples_candidates(question, filter_condition, self.num_examples)
|
||||
schema_linking_sql_shortcut_prompt = self.generate_schema_linking_sql_prompt(question, current_date, model_name, fields_list, prior_schema_links, prior_exts, fewshot_example_meta_list)
|
||||
logger.debug("schema_linking_sql_shortcut_prompt->{}".format(schema_linking_sql_shortcut_prompt))
|
||||
schema_linking_sql_shortcut_output = await self.llm._call_async(schema_linking_sql_shortcut_prompt)
|
||||
llm = get_llm(llm_config)
|
||||
schema_linking_sql_shortcut_output = await llm._call_async(schema_linking_sql_shortcut_prompt)
|
||||
logger.debug("schema_linking_sql_shortcut_output->{}".format(schema_linking_sql_shortcut_output))
|
||||
|
||||
schema_linking_str = combo_schema_link_parse(schema_linking_sql_shortcut_output)
|
||||
@@ -298,11 +302,13 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
||||
return resp
|
||||
|
||||
async def generate_schema_linking_tasks(self, question: str, model_name: str, fields_list: List[str],
|
||||
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, fewshot_example_list_combo:List[List[Mapping[str, str]]]):
|
||||
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str,
|
||||
fewshot_example_list_combo:List[List[Mapping[str, str]]], llm_config: dict):
|
||||
|
||||
schema_linking_prompt_pool = self.generate_schema_linking_prompt_pool(question, current_date, model_name, fields_list, prior_schema_links, prior_exts, fewshot_example_list_combo)
|
||||
logger.debug("schema_linking_prompt_pool->{}".format(schema_linking_prompt_pool))
|
||||
schema_linking_output_pool = await asyncio.gather(*[self.llm._call_async(schema_linking_prompt) for schema_linking_prompt in schema_linking_prompt_pool])
|
||||
llm = get_llm(llm_config)
|
||||
schema_linking_output_pool = await asyncio.gather(*[llm._call_async(schema_linking_prompt) for schema_linking_prompt in schema_linking_prompt_pool])
|
||||
logger.debug("schema_linking_output_pool->{}".format(schema_linking_output_pool))
|
||||
|
||||
schema_linking_str_pool = [schema_link_parse(schema_linking_output) for schema_linking_output in schema_linking_output_pool]
|
||||
@@ -310,19 +316,22 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
||||
return schema_linking_str_pool, schema_linking_output_pool, schema_linking_prompt_pool
|
||||
|
||||
async def generate_sql_tasks(self, question: str, model_name: str, fields_list: List[str], schema_link_str_pool: List[str],
|
||||
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, fewshot_example_list_combo:List[List[Mapping[str, str]]]):
|
||||
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, fewshot_example_list_combo:List[List[Mapping[str, str]]], llm_config: dict):
|
||||
|
||||
sql_prompt_pool = self.generate_sql_prompt_pool(question, model_name, fields_list, schema_link_str_pool, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo)
|
||||
logger.debug("sql_prompt_pool->{}".format(sql_prompt_pool))
|
||||
sql_output_pool = await asyncio.gather(*[self.llm._call_async(sql_prompt) for sql_prompt in sql_prompt_pool])
|
||||
llm = get_llm(llm_config)
|
||||
sql_output_pool = await asyncio.gather(*[llm._call_async(sql_prompt) for sql_prompt in sql_prompt_pool])
|
||||
logger.debug("sql_output_pool->{}".format(sql_output_pool))
|
||||
|
||||
return sql_output_pool, sql_prompt_pool
|
||||
|
||||
async def generate_schema_linking_sql_tasks(self, question: str, model_name: str, fields_list: List[str],
|
||||
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, fewshot_example_list_combo:List[List[Mapping[str, str]]]):
|
||||
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str,
|
||||
fewshot_example_list_combo:List[List[Mapping[str, str]]],llm_config: dict):
|
||||
schema_linking_sql_prompt_pool = self.generate_schema_linking_sql_prompt_pool(question, current_date, model_name, fields_list, prior_schema_links, prior_exts, fewshot_example_list_combo)
|
||||
schema_linking_sql_output_task_pool = [self.llm._call_async(schema_linking_sql_prompt) for schema_linking_sql_prompt in schema_linking_sql_prompt_pool]
|
||||
llm = get_llm(llm_config)
|
||||
schema_linking_sql_output_task_pool = [llm._call_async(schema_linking_sql_prompt) for schema_linking_sql_prompt in schema_linking_sql_prompt_pool]
|
||||
schema_linking_sql_output_res_pool = await asyncio.gather(*schema_linking_sql_output_task_pool)
|
||||
logger.debug("schema_linking_sql_output_res_pool->{}".format(schema_linking_sql_output_res_pool))
|
||||
|
||||
@@ -330,7 +339,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
||||
|
||||
async def tasks_run(self, question: str, filter_condition: Mapping[str,str],
|
||||
model_name: str, fields_list: List[str],
|
||||
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str):
|
||||
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, llm_config: dict):
|
||||
logger.info("question: {}".format(question))
|
||||
logger.info("filter_condition: {}".format(filter_condition))
|
||||
logger.info("model_name: {}".format(model_name))
|
||||
@@ -342,14 +351,14 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
||||
fewshot_example_meta_list = self.get_examples_candidates(question, filter_condition, self.num_examples)
|
||||
fewshot_example_list_combo = self.get_fewshot_example_combos(fewshot_example_meta_list, self.num_fewshots)
|
||||
|
||||
schema_linking_candidate_list, _, schema_linking_prompt_list = await self.generate_schema_linking_tasks(question, model_name, fields_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo)
|
||||
schema_linking_candidate_list, _, schema_linking_prompt_list = await self.generate_schema_linking_tasks(question, model_name, fields_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo, llm_config)
|
||||
logger.debug(f'schema_linking_candidate_list:{schema_linking_candidate_list}')
|
||||
schema_linking_candidate_sorted_list = self.schema_linking_list_str_unify(schema_linking_candidate_list)
|
||||
logger.debug(f'schema_linking_candidate_sorted_list:{schema_linking_candidate_sorted_list}')
|
||||
|
||||
schema_linking_output_max, schema_linking_output_vote_percentage = self.self_consistency_vote(schema_linking_candidate_sorted_list)
|
||||
|
||||
sql_output_candicates, sql_output_prompt_list = await self.generate_sql_tasks(question, model_name, fields_list, schema_linking_candidate_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo)
|
||||
sql_output_candicates, sql_output_prompt_list = await self.generate_sql_tasks(question, model_name, fields_list, schema_linking_candidate_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo, llm_config)
|
||||
logger.debug(f'sql_output_candicates:{sql_output_candicates}')
|
||||
sql_output_max, sql_output_vote_percentage = self.self_consistency_vote(sql_output_candicates)
|
||||
|
||||
@@ -374,7 +383,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
||||
return resp
|
||||
|
||||
async def tasks_run_shortcut(self, question: str, filter_condition: Mapping[str,str], model_name: str, fields_list: List[str],
|
||||
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str):
|
||||
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, llm_config: dict):
|
||||
logger.info("question: {}".format(question))
|
||||
logger.info("filter_condition: {}".format(filter_condition))
|
||||
logger.info("model_name: {}".format(model_name))
|
||||
@@ -420,8 +429,8 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
||||
|
||||
class Text2DSLAgent(Text2DSLAgentBase):
|
||||
def __init__(self, num_fewshots:int, num_examples:int, num_self_consistency:int,
|
||||
sql_example_prompter:FewShotPromptTemplate2, llm: BaseLLM,) -> None:
|
||||
super().__init__(num_fewshots, num_examples, num_self_consistency, sql_example_prompter, llm)
|
||||
sql_example_prompter:FewShotPromptTemplate2) -> None:
|
||||
super().__init__(num_fewshots, num_examples, num_self_consistency, sql_example_prompter)
|
||||
|
||||
def reload_setting(self, sql_example_ids:List[str], sql_example_units: List[Mapping[str, str]], num_examples:int, num_fewshots:int, num_self_consistency:int):
|
||||
self.num_fewshots = num_fewshots
|
||||
@@ -554,12 +563,13 @@ class Text2DSLAgent(Text2DSLAgentBase):
|
||||
|
||||
async def generate_schema_linking_tasks(self, question: str, domain_name: str,
|
||||
fields_list: List[str], prior_schema_links: Mapping[str,str],
|
||||
fewshot_example_list_combo:List[List[Mapping[str, str]]]):
|
||||
fewshot_example_list_combo:List[List[Mapping[str, str]]], llm_config: dict):
|
||||
|
||||
schema_linking_prompt_pool = self.generate_schema_linking_prompt_pool(question, domain_name,
|
||||
fields_list, prior_schema_links,
|
||||
fewshot_example_list_combo)
|
||||
schema_linking_output_task_pool = [self.llm._call_async(schema_linking_prompt) for schema_linking_prompt in schema_linking_prompt_pool]
|
||||
llm = get_llm(llm_config)
|
||||
schema_linking_output_task_pool = [llm._call_async(schema_linking_prompt) for schema_linking_prompt in schema_linking_prompt_pool]
|
||||
schema_linking_output_pool = await asyncio.gather(*schema_linking_output_task_pool)
|
||||
logger.debug(f'schema_linking_output_pool:{schema_linking_output_pool}')
|
||||
|
||||
@@ -568,25 +578,29 @@ class Text2DSLAgent(Text2DSLAgentBase):
|
||||
return schema_linking_str_pool
|
||||
|
||||
async def generate_sql_tasks(self, question: str, domain_name: str, data_date: str,
|
||||
schema_link_str_pool: List[str], fewshot_example_list_combo:List[List[Mapping[str, str]]]):
|
||||
schema_link_str_pool: List[str], fewshot_example_list_combo:List[List[Mapping[str, str]]],
|
||||
llm_config: dict):
|
||||
|
||||
sql_prompt_pool = self.generate_sql_prompt_pool(question, domain_name, schema_link_str_pool, data_date, fewshot_example_list_combo)
|
||||
sql_output_task_pool = [self.llm._call_async(sql_prompt) for sql_prompt in sql_prompt_pool]
|
||||
llm = get_llm(llm_config)
|
||||
sql_output_task_pool = [llm._call_async(sql_prompt) for sql_prompt in sql_prompt_pool]
|
||||
sql_output_res_pool = await asyncio.gather(*sql_output_task_pool)
|
||||
logger.debug(f'sql_output_res_pool:{sql_output_res_pool}')
|
||||
|
||||
return sql_output_res_pool
|
||||
|
||||
async def generate_schema_linking_sql_tasks(self, question: str, domain_name: str, fields_list: List[str], data_date: str,
|
||||
prior_schema_links: Mapping[str,str], fewshot_example_list_combo:List[List[Mapping[str, str]]]):
|
||||
prior_schema_links: Mapping[str,str], fewshot_example_list_combo:List[List[Mapping[str, str]]],
|
||||
llm_config: dict):
|
||||
schema_linking_sql_prompt_pool = self.generate_schema_linking_sql_prompt_pool(question, domain_name, fields_list, data_date, prior_schema_links, fewshot_example_list_combo)
|
||||
schema_linking_sql_output_task_pool = [self.llm._call_async(schema_linking_sql_prompt) for schema_linking_sql_prompt in schema_linking_sql_prompt_pool]
|
||||
llm = get_llm(llm_config)
|
||||
schema_linking_sql_output_task_pool = [llm._call_async(schema_linking_sql_prompt) for schema_linking_sql_prompt in schema_linking_sql_prompt_pool]
|
||||
schema_linking_sql_output_res_pool = await asyncio.gather(*schema_linking_sql_output_task_pool)
|
||||
logger.debug(f'schema_linking_sql_output_res_pool:{schema_linking_sql_output_res_pool}')
|
||||
|
||||
return schema_linking_sql_output_res_pool
|
||||
|
||||
async def tasks_run(self, question: str, filter_condition: Mapping[str, str], domain_name: str, fields_list: List[str], prior_schema_links: Mapping[str,str], data_date: str, prior_exts: str):
|
||||
async def tasks_run(self, question: str, filter_condition: Mapping[str, str], domain_name: str, fields_list: List[str], prior_schema_links: Mapping[str,str], data_date: str, prior_exts: str, llm_config: dict):
|
||||
logger.info("question: {}".format(question))
|
||||
logger.info("domain_name: {}".format(domain_name))
|
||||
logger.info("fields_list: {}".format(fields_list))
|
||||
@@ -601,7 +615,7 @@ class Text2DSLAgent(Text2DSLAgentBase):
|
||||
fewshot_example_meta_list = self.get_examples_candidates(question, filter_condition, self.num_examples)
|
||||
fewshot_example_list_combo = self.get_fewshot_example_combos(fewshot_example_meta_list, self.num_fewshots)
|
||||
|
||||
schema_linking_candidate_list = await self.generate_schema_linking_tasks(question, domain_name, fields_list, prior_schema_links, fewshot_example_list_combo)
|
||||
schema_linking_candidate_list = await self.generate_schema_linking_tasks(question, domain_name, fields_list, prior_schema_links, fewshot_example_list_combo, llm_config)
|
||||
logger.debug(f'schema_linking_candidate_list:{schema_linking_candidate_list}')
|
||||
schema_linking_candidate_sorted_list = self.schema_linking_list_str_unify(schema_linking_candidate_list)
|
||||
logger.debug(f'schema_linking_candidate_sorted_list:{schema_linking_candidate_sorted_list}')
|
||||
@@ -675,7 +689,7 @@ class Text2DSLAgent(Text2DSLAgentBase):
|
||||
|
||||
async def async_query2sql(self, question: str, filter_condition: Mapping[str,str],
|
||||
model_name: str, fields_list: List[str],
|
||||
data_date: str, prior_schema_links: Mapping[str,str], prior_exts: str):
|
||||
data_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, llm_config: dict):
|
||||
logger.info("question: {}".format(question))
|
||||
logger.info("model_name: {}".format(model_name))
|
||||
logger.info("fields_list: {}".format(fields_list))
|
||||
@@ -690,13 +704,14 @@ class Text2DSLAgent(Text2DSLAgentBase):
|
||||
fewshot_example_meta_list = self.get_examples_candidates(question, filter_condition, self.num_examples)
|
||||
schema_linking_prompt = self.generate_schema_linking_prompt(question, model_name, fields_list, prior_schema_links, fewshot_example_meta_list)
|
||||
logger.debug("schema_linking_prompt->{}".format(schema_linking_prompt))
|
||||
schema_link_output = await self.llm._call_async(schema_linking_prompt)
|
||||
llm = get_llm(llm_config)
|
||||
schema_link_output = await llm._call_async(schema_linking_prompt)
|
||||
|
||||
schema_link_str = schema_link_parse(schema_link_output)
|
||||
|
||||
sql_prompt = self.generate_sql_prompt(question, model_name, schema_link_str, data_date, fewshot_example_meta_list)
|
||||
logger.debug("sql_prompt->{}".format(sql_prompt))
|
||||
sql_output = await self.llm._call_async(sql_prompt)
|
||||
sql_output = await llm._call_async(sql_prompt)
|
||||
|
||||
resp = dict()
|
||||
resp['question'] = question
|
||||
@@ -716,7 +731,8 @@ class Text2DSLAgent(Text2DSLAgentBase):
|
||||
|
||||
async def async_query2sql_shortcut(self, question: str, filter_condition: Mapping[str,str],
|
||||
model_name: str, fields_list: List[str],
|
||||
data_date: str, prior_schema_links: Mapping[str,str], prior_exts: str):
|
||||
data_date: str, prior_schema_links: Mapping[str,str], prior_exts: str,
|
||||
llm_config: dict):
|
||||
|
||||
logger.info("question: {}".format(question))
|
||||
logger.info("model_name: {}".format(model_name))
|
||||
@@ -732,7 +748,8 @@ class Text2DSLAgent(Text2DSLAgentBase):
|
||||
fewshot_example_meta_list = self.get_examples_candidates(question, filter_condition, self.num_examples)
|
||||
schema_linking_sql_shortcut_prompt = self.generate_schema_linking_sql_prompt(question, model_name, data_date, fields_list, prior_schema_links, fewshot_example_meta_list)
|
||||
logger.debug("schema_linking_sql_shortcut_prompt->{}".format(schema_linking_sql_shortcut_prompt))
|
||||
schema_linking_sql_shortcut_output = await self.llm._call_async(schema_linking_sql_shortcut_prompt)
|
||||
llm = get_llm(llm_config)
|
||||
schema_linking_sql_shortcut_output = await llm._call_async(schema_linking_sql_shortcut_prompt)
|
||||
|
||||
schema_linking_str = combo_schema_link_parse(schema_linking_sql_shortcut_output)
|
||||
sql_str = combo_sql_parse(schema_linking_sql_shortcut_output)
|
||||
@@ -764,26 +781,26 @@ class Text2DSLAgentWrapper(object):
|
||||
|
||||
async def async_query2sql(self, question: str, filter_condition: Mapping[str,str],
|
||||
model_name: str, fields_list: List[str],
|
||||
data_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, sql_generation_mode: str):
|
||||
data_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, sql_generation_mode: str, llm_config: dict):
|
||||
|
||||
if sql_generation_mode not in (sql_mode.value for sql_mode in SqlModeEnum):
|
||||
raise ValueError(f"sql_generation_mode: {sql_generation_mode} is not in SqlModeEnum")
|
||||
|
||||
if sql_generation_mode == '1_pass_auto_cot':
|
||||
logger.info(f"sql wrapper: {sql_generation_mode}")
|
||||
resp = await self.sql_agent_act.async_query2sql_shortcut(question=question, filter_condition=filter_condition, model_name=model_name, fields_list=fields_list, current_date=data_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts)
|
||||
resp = await self.sql_agent_act.async_query2sql_shortcut(question=question, filter_condition=filter_condition, model_name=model_name, fields_list=fields_list, current_date=data_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts, llm_config=llm_config)
|
||||
return resp
|
||||
elif sql_generation_mode == '1_pass_auto_cot_self_consistency':
|
||||
logger.info(f"sql wrapper: {sql_generation_mode}")
|
||||
resp = await self.sql_agent_act.tasks_run_shortcut(question=question, filter_condition=filter_condition, model_name=model_name, fields_list=fields_list, current_date=data_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts)
|
||||
resp = await self.sql_agent_act.tasks_run_shortcut(question=question, filter_condition=filter_condition, model_name=model_name, fields_list=fields_list, current_date=data_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts, llm_config=llm_config)
|
||||
return resp
|
||||
elif sql_generation_mode == '2_pass_auto_cot':
|
||||
logger.info(f"sql wrapper: {sql_generation_mode}")
|
||||
resp = await self.sql_agent_act.async_query2sql(question=question, filter_condition=filter_condition, model_name=model_name, fields_list=fields_list, current_date=data_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts)
|
||||
resp = await self.sql_agent_act.async_query2sql(question=question, filter_condition=filter_condition, model_name=model_name, fields_list=fields_list, current_date=data_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts, llm_config=llm_config)
|
||||
return resp
|
||||
elif sql_generation_mode == '2_pass_auto_cot_self_consistency':
|
||||
logger.info(f"sql wrapper: {sql_generation_mode}")
|
||||
resp = await self.sql_agent_act.tasks_run(question=question, filter_condition=filter_condition, model_name=model_name, fields_list=fields_list, current_date=data_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts)
|
||||
resp = await self.sql_agent_act.tasks_run(question=question, filter_condition=filter_condition, model_name=model_name, fields_list=fields_list, current_date=data_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts, llm_config=llm_config)
|
||||
return resp
|
||||
else:
|
||||
raise ValueError(f'sql_generation_mode:{sql_generation_mode} is not in SqlModeEnum')
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, List, Mapping, Optional, Union
|
||||
import ast
|
||||
from typing import Any, Mapping
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from services.s2sql.run import text2sql_agent_router
|
||||
|
||||
@@ -50,12 +51,20 @@ async def query2sql(query_body: Mapping[str, Any]):
|
||||
else:
|
||||
sql_generation_mode = query_body['sqlGenerationMode']
|
||||
|
||||
if 'llmConfig' in query_body:
|
||||
llm_config = ast.literal_eval(str(query_body['llmConfig']))
|
||||
else:
|
||||
llm_config = None
|
||||
|
||||
dataset_name = schema['dataSetName']
|
||||
fields_list = schema['fieldNameList']
|
||||
prior_schema_links = {item['fieldValue']:item['fieldName'] for item in linking}
|
||||
|
||||
resp = await text2sql_agent_router.async_query2sql(question=query_text, filter_condition=filter_condition, model_name=dataset_name, fields_list=fields_list,
|
||||
data_date=current_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts, sql_generation_mode=sql_generation_mode)
|
||||
resp = await text2sql_agent_router.async_query2sql(question=query_text, filter_condition=filter_condition,
|
||||
model_name=dataset_name, fields_list=fields_list,
|
||||
data_date=current_date, prior_schema_links=prior_schema_links,
|
||||
prior_exts=prior_exts, sql_generation_mode=sql_generation_mode,
|
||||
llm_config=llm_config)
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
@@ -7,14 +7,12 @@ import uvicorn
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from typing import Any, List, Mapping
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi import FastAPI
|
||||
|
||||
from config.config_parse import LLMPARSER_HOST, LLMPARSER_PORT
|
||||
|
||||
from services_router import (query2sql_service, preset_query_service,
|
||||
solved_query_service, plugin_call_service, retriever_service)
|
||||
solved_query_service, retriever_service)
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
@@ -26,7 +24,7 @@ def read_health():
|
||||
app.include_router(preset_query_service.router)
|
||||
app.include_router(solved_query_service.router)
|
||||
app.include_router(query2sql_service.router)
|
||||
app.include_router(plugin_call_service.router)
|
||||
#app.include_router(plugin_call_service.router)
|
||||
app.include_router(retriever_service.router)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -22,8 +22,6 @@ import com.tencent.supersonic.chat.server.service.PluginService;
|
||||
import com.tencent.supersonic.common.pojo.SysParameter;
|
||||
import com.tencent.supersonic.common.service.SysParameterService;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.api.pojo.LLMConfig;
|
||||
import com.tencent.supersonic.common.pojo.enums.S2ModelProvider;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
@@ -177,10 +175,7 @@ public class ChatDemoLoader implements CommandLineRunner {
|
||||
agentConfig.getTools().add(llmParserTool);
|
||||
}
|
||||
agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
|
||||
LLMConfig llmConfig = new LLMConfig(S2ModelProvider.OPEN_AI.name(),
|
||||
"", "your_key", "gpt-3.5-turbo");
|
||||
MultiTurnConfig multiTurnConfig = new MultiTurnConfig(false);
|
||||
agent.setLlmConfig(llmConfig);
|
||||
agent.setMultiTurnConfig(multiTurnConfig);
|
||||
agentService.createAgent(agent, User.getFakeUser());
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
# Replace with your LLM configs
|
||||
# Note: The default API key `demo` is provided by langchain4j community
|
||||
# which limits 1000 tokens per request.
|
||||
OPENAI_API_BASE=https://api.openai.com/v1
|
||||
OPENAI_API_BASE=http://langchain4j.dev/demo/openai/v1
|
||||
OPENAI_API_KEY=demo
|
||||
OPENAI_MODEL_NAME=gpt-3.5-turbo
|
||||
OPENAI_TEMPERATURE=0.0
|
||||
|
||||
Reference in New Issue
Block a user