mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 04:27:39 +00:00
add auto-CoT feature (#483)
* 1.refactor the retrieval module. 2.refactor the http service module. 3.upgrade text2sql output format the parse for absolute time related expression in query. * fix bug. * upgrade the config module, now support config llm suppoted by langchain. * fix conflicts. * update text2sql config reload to be compitable with new config format. * modify default config. * 1.add self-consistency feature for text2sql. 2.upgrade llm api call from sync to async. 3.refactor text2sql module. 4. refactor semantical retriever modules. * merege with upstream master * add general retrieve service. * add api service for sql_agent for crud opereations of few-shots examples. * modify requirements * add auto-cot feature --------- Co-authored-by: shaoweigong <shaoweigong@tencent.com>
This commit is contained in:
@@ -8,7 +8,7 @@ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from services.sql.run import text2sql_agent_router
|
||||
from services.s2ql.run import text2sql_agent_router
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -45,12 +45,17 @@ async def query2sql(query_body: Mapping[str, Any]):
|
||||
else:
|
||||
filter_condition = query_body['filterCondition']
|
||||
|
||||
if 'sql_generation_mode' not in query_body:
|
||||
raise HTTPException(status_code=400, detail="sql_generation_mode is not in query_body")
|
||||
else:
|
||||
sql_generation_mode = query_body['sql_generation_mode']
|
||||
|
||||
model_name = schema['modelName']
|
||||
fields_list = schema['fieldNameList']
|
||||
prior_schema_links = {item['fieldValue']:item['fieldName'] for item in linking}
|
||||
|
||||
resp = await text2sql_agent_router.async_query2sql(query_text=query_text, filter_condition=filter_condition, model_name=model_name, fields_list=fields_list,
|
||||
data_date=current_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts)
|
||||
resp = await text2sql_agent_router.async_query2sql(question=query_text, filter_condition=filter_condition, model_name=model_name, fields_list=fields_list,
|
||||
data_date=current_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts, sql_generation_mode=sql_generation_mode)
|
||||
|
||||
return resp
|
||||
|
||||
@@ -82,18 +87,7 @@ def query2sql_setting_update(query_body: Mapping[str, Any]):
|
||||
else:
|
||||
self_consistency_nums = query_body['selfConsistencyNums']
|
||||
|
||||
if 'isShortcut' not in query_body:
|
||||
raise HTTPException(status_code=400, detail="isShortcut is not in query_body")
|
||||
else:
|
||||
is_shortcut = query_body['isShortcut']
|
||||
|
||||
if 'isSelfConsistency' not in query_body:
|
||||
raise HTTPException(status_code=400, detail="isSelfConsistency is not in query_body")
|
||||
else:
|
||||
is_self_consistency = query_body['isSelfConsistency']
|
||||
|
||||
text2sql_agent_router.update_configs(is_shortcut=is_shortcut, is_self_consistency=is_self_consistency,
|
||||
sql_example_ids=sql_ids, sql_example_units=sql_examplars,
|
||||
text2sql_agent_router.update_configs(sql_example_ids=sql_ids, sql_example_units=sql_examplars,
|
||||
num_examples=example_nums, num_fewshots=fewshot_nums, num_self_consistency=self_consistency_nums)
|
||||
|
||||
return "success"
|
||||
@@ -112,8 +106,7 @@ def query2sql_add_examples(query_body: Mapping[str, Any]):
|
||||
else:
|
||||
sql_examplars = query_body['sqlExamplars']
|
||||
|
||||
text2sql_agent_router.sql_agent.add_examples(sql_example_ids=sql_ids, sql_example_units=sql_examplars)
|
||||
text2sql_agent_router.sql_agent_cs.add_examples(sql_example_ids=sql_ids, sql_example_units=sql_examplars)
|
||||
text2sql_agent_router.add_examples(sql_example_ids=sql_ids, sql_example_units=sql_examplars)
|
||||
|
||||
return "success"
|
||||
|
||||
@@ -131,8 +124,7 @@ def query2sql_update_examples(query_body: Mapping[str, Any]):
|
||||
else:
|
||||
sql_examplars = query_body['sqlExamplars']
|
||||
|
||||
text2sql_agent_router.sql_agent.update_examples(sql_example_ids=sql_ids, sql_example_units=sql_examplars)
|
||||
text2sql_agent_router.sql_agent_cs.update_examples(sql_example_ids=sql_ids, sql_example_units=sql_examplars)
|
||||
text2sql_agent_router.update_examples(sql_example_ids=sql_ids, sql_example_units=sql_examplars)
|
||||
|
||||
return "success"
|
||||
|
||||
@@ -144,18 +136,24 @@ def query2sql_delete_examples(query_body: Mapping[str, Any]):
|
||||
else:
|
||||
sql_ids = query_body['sqlIds']
|
||||
|
||||
text2sql_agent_router.sql_agent.delete_examples(sql_example_ids=sql_ids)
|
||||
text2sql_agent_router.sql_agent_cs.delete_examples(sql_example_ids=sql_ids)
|
||||
text2sql_agent_router.delete_examples(sql_example_ids=sql_ids)
|
||||
|
||||
return "success"
|
||||
|
||||
@router.post("/query2sql_get_examples")
|
||||
def query2sql_get_examples(query_body: Mapping[str, Any]):
|
||||
if 'sqlIds' not in query_body:
|
||||
raise HTTPException(status_code=400, detail="sqlIds is not in query_body")
|
||||
else:
|
||||
sql_ids = query_body['sqlIds']
|
||||
|
||||
examples = text2sql_agent_router.get_examples(sql_example_ids=sql_ids)
|
||||
|
||||
return examples
|
||||
|
||||
@router.get("/query2sql_count_examples")
|
||||
def query2sql_count_examples():
|
||||
sql_agent_examples_cnt = text2sql_agent_router.sql_agent.count_examples()
|
||||
sql_agent_cs_examples_cnt = text2sql_agent_router.sql_agent_cs.count_examples()
|
||||
examples_cnt = text2sql_agent_router.count_examples()
|
||||
|
||||
assert sql_agent_examples_cnt == sql_agent_cs_examples_cnt
|
||||
|
||||
return sql_agent_examples_cnt
|
||||
return examples_cnt
|
||||
|
||||
|
||||
Reference in New Issue
Block a user