[improvement][chat] Move python code out of chat-core module

This commit is contained in:
jerryjzhang
2023-11-16 09:58:25 +08:00
parent 13d8b9cff5
commit 8688c8c2b3
24 changed files with 0 additions and 0 deletions

View File

@@ -0,0 +1,33 @@
# -*- coding:utf-8 -*-
import os
import sys
from typing import Any, List, Mapping, Optional, Union
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from fastapi import APIRouter, Depends, HTTPException
from services.plugin_call.run import plugin_selection_run
router = APIRouter()
@router.post("/plugin_selection/")
async def tool_selection(query_body: Mapping[str, Any]):
if "queryText" not in query_body:
raise HTTPException(status_code=400, detail="query_text is not in query_body")
else:
query_text = query_body["queryText"]
if "pluginConfigs" not in query_body:
raise HTTPException(
status_code=400, detail="pluginConfigs is not in query_body"
)
else:
plugin_configs = query_body["pluginConfigs"]
resp = plugin_selection_run(query_text=query_text, plugin_configs=plugin_configs)
return resp

View File

@@ -0,0 +1,71 @@
# -*- coding:utf-8 -*-
import os
import sys
from typing import Any, List, Mapping, Optional, Union
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from fastapi import APIRouter, Depends, HTTPException
from services.query_retrieval.run import preset_query_retriever
router = APIRouter()
@router.post("/preset_query_retrival")
def preset_query_retrival(query_text_list: List[str], n_results: int = 5):
parsed_retrieval_res_format = preset_query_retriever.retrieval_query_run(query_texts_list=query_text_list, filter_condition=None, n_results=n_results)
return parsed_retrieval_res_format
@router.post("/preset_query_add")
def preset_query_add(preset_info_list: List[Mapping[str, str]]):
preset_queries = []
preset_query_ids = []
for preset_info in preset_info_list:
preset_queries.append(preset_info['preset_query'])
preset_query_ids.append(preset_info['preset_query_id'])
preset_query_retriever.add_queries(query_text_list=preset_queries, query_id_list=preset_query_ids, metadatas=None)
return "success"
@router.post("/preset_query_update")
def preset_query_update(preset_info_list: List[Mapping[str, str]]):
preset_queries = []
preset_query_ids = []
for preset_info in preset_info_list:
preset_queries.append(preset_info['preset_query'])
preset_query_ids.append(preset_info['preset_query_id'])
preset_query_retriever.update_queries(query_text_list=preset_queries, query_id_list=preset_query_ids, metadatas=None)
return "success"
@router.get("/preset_query_empty")
def preset_query_empty():
preset_query_retriever.empty_query_collection()
return "success"
@router.post("/preset_delete_by_ids")
def preset_delete_by_ids(preset_query_ids: List[str]):
preset_query_retriever.delete_queries_by_ids(preset_query_ids)
return "success"
@router.post("/preset_get_by_ids")
def preset_get_by_ids(preset_query_ids: List[str]):
preset_queries = preset_query_retriever.get_query_by_ids(preset_query_ids)
return preset_queries
@router.get("/preset_query_size")
def preset_query_size():
size = preset_query_retriever.get_query_size()
return size

View File

@@ -0,0 +1,161 @@
# -*- coding:utf-8 -*-
import os
import sys
from typing import Any, List, Mapping, Optional, Union
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from fastapi import APIRouter, Depends, HTTPException
from services.sql.run import text2sql_agent_router
router = APIRouter()
@router.post("/query2sql")
async def query2sql(query_body: Mapping[str, Any]):
if 'queryText' not in query_body:
raise HTTPException(status_code=400, detail="query_text is not in query_body")
else:
query_text = query_body['queryText']
if 'schema' not in query_body:
raise HTTPException(status_code=400, detail="schema is not in query_body")
else:
schema = query_body['schema']
if 'currentDate' not in query_body:
raise HTTPException(status_code=400, detail="currentDate is not in query_body")
else:
current_date = query_body['currentDate']
if 'linking' not in query_body:
raise HTTPException(status_code=400, detail="linking is not in query_body")
else:
linking = query_body['linking']
if 'priorExts' not in query_body:
raise HTTPException(status_code=400, detail="prior_exts is not in query_body")
else:
prior_exts = query_body['priorExts']
if 'filterCondition' not in query_body:
raise HTTPException(status_code=400, detail="filterCondition is not in query_body")
else:
filter_condition = query_body['filterCondition']
model_name = schema['modelName']
fields_list = schema['fieldNameList']
prior_schema_links = {item['fieldValue']:item['fieldName'] for item in linking}
resp = await text2sql_agent_router.async_query2sql(query_text=query_text, filter_condition=filter_condition, model_name=model_name, fields_list=fields_list,
data_date=current_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts)
return resp
@router.post("/query2sql_setting_update")
def query2sql_setting_update(query_body: Mapping[str, Any]):
if 'sqlExamplars' not in query_body:
raise HTTPException(status_code=400, detail="sqlExamplars is not in query_body")
else:
sql_examplars = query_body['sqlExamplars']
if 'sqlIds' not in query_body:
raise HTTPException(status_code=400, detail="sqlIds is not in query_body")
else:
sql_ids = query_body['sqlIds']
if 'exampleNums' not in query_body:
raise HTTPException(status_code=400, detail="exampleNums is not in query_body")
else:
example_nums = query_body['exampleNums']
if 'fewshotNums' not in query_body:
raise HTTPException(status_code=400, detail="fewshotNums is not in query_body")
else:
fewshot_nums = query_body['fewshotNums']
if 'selfConsistencyNums' not in query_body:
raise HTTPException(status_code=400, detail="selfConsistencyNums is not in query_body")
else:
self_consistency_nums = query_body['selfConsistencyNums']
if 'isShortcut' not in query_body:
raise HTTPException(status_code=400, detail="isShortcut is not in query_body")
else:
is_shortcut = query_body['isShortcut']
if 'isSelfConsistency' not in query_body:
raise HTTPException(status_code=400, detail="isSelfConsistency is not in query_body")
else:
is_self_consistency = query_body['isSelfConsistency']
text2sql_agent_router.update_configs(is_shortcut=is_shortcut, is_self_consistency=is_self_consistency,
sql_example_ids=sql_ids, sql_example_units=sql_examplars,
num_examples=example_nums, num_fewshots=fewshot_nums, num_self_consistency=self_consistency_nums)
return "success"
@router.post("/query2sql_add_examples")
def query2sql_add_examples(query_body: Mapping[str, Any]):
if 'sqlIds' not in query_body:
raise HTTPException(status_code=400, detail="sqlIds is not in query_body")
else:
sql_ids = query_body['sqlIds']
if 'sqlExamplars' not in query_body:
raise HTTPException(status_code=400,
detail="sqlExamplars is not in query_body")
else:
sql_examplars = query_body['sqlExamplars']
text2sql_agent_router.sql_agent.add_examples(sql_example_ids=sql_ids, sql_example_units=sql_examplars)
text2sql_agent_router.sql_agent_cs.add_examples(sql_example_ids=sql_ids, sql_example_units=sql_examplars)
return "success"
@router.post("/query2sql_update_examples")
def query2sql_update_examples(query_body: Mapping[str, Any]):
if 'sqlIds' not in query_body:
raise HTTPException(status_code=400, detail="sqlIds is not in query_body")
else:
sql_ids = query_body['sqlIds']
if 'sqlExamplars' not in query_body:
raise HTTPException(status_code=400,
detail="sqlExamplars is not in query_body")
else:
sql_examplars = query_body['sqlExamplars']
text2sql_agent_router.sql_agent.update_examples(sql_example_ids=sql_ids, sql_example_units=sql_examplars)
text2sql_agent_router.sql_agent_cs.update_examples(sql_example_ids=sql_ids, sql_example_units=sql_examplars)
return "success"
@router.post("/query2sql_delete_examples")
def query2sql_delete_examples(query_body: Mapping[str, Any]):
if 'sqlIds' not in query_body:
raise HTTPException(status_code=400, detail="sqlIds is not in query_body")
else:
sql_ids = query_body['sqlIds']
text2sql_agent_router.sql_agent.delete_examples(sql_example_ids=sql_ids)
text2sql_agent_router.sql_agent_cs.delete_examples(sql_example_ids=sql_ids)
return "success"
@router.get("/query2sql_count_examples")
def query2sql_count_examples():
sql_agent_examples_cnt = text2sql_agent_router.sql_agent.count_examples()
sql_agent_cs_examples_cnt = text2sql_agent_router.sql_agent_cs.count_examples()
assert sql_agent_examples_cnt == sql_agent_cs_examples_cnt
return sql_agent_examples_cnt

View File

@@ -0,0 +1,156 @@
# -*- coding:utf-8 -*-
import os
import sys
from typing import Any, List, Mapping, Optional, Union
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from fastapi import APIRouter, Depends, HTTPException
from services.query_retrieval.run import collection_manager
from services.query_retrieval.retriever import ChromaCollectionRetriever
router = APIRouter()
@router.get("/list_collections")
def list_collections():
collections = collection_manager.list_collections()
return collections
@router.get("/create_collection")
def create_collection(collection_name: str):
collection_manager.create_collection(collection_name)
return "success"
@router.get("/delete_collection")
def delete_collection(collection_name: str):
collection_manager.delete_collection(collection_name)
return "success"
@router.get("/get_collection")
def get_collection(collection_name: str):
collection = collection_manager.get_collection(collection_name)
return collection
@router.get("/get_or_create_collection")
def get_or_create_collection(collection_name: str):
collection = collection_manager.get_or_create_collection(collection_name)
return collection
@router.post("/add_query")
def query_add(collection_name:str, query_info_list: List[Mapping[str, Any]]):
queries = []
query_ids = []
metadatas = []
embeddings = []
for query_info in query_info_list:
queries.append(query_info['query'])
query_ids.append(query_info['queryId'])
metadatas.append(query_info['metadata'])
embeddings.append(query_info['queryEmbedding'])
if None in embeddings:
embeddings = None
if None in queries:
queries = None
if embeddings is None and queries is None:
raise HTTPException(status_code=400, detail="query and queryEmbedding are None")
if embeddings is not None and queries is not None:
raise HTTPException(status_code=400, detail="query and queryEmbedding are not None")
query_collection = collection_manager.get_or_create_collection(collection_name=collection_name)
query_retriever = ChromaCollectionRetriever(collection=query_collection)
query_retriever.add_queries(query_text_list=queries, query_id_list=query_ids, metadatas=metadatas, embeddings=embeddings)
return "success"
@router.post("/update_query")
def update_query(collection_name:str, query_info_list: List[Mapping[str, Any]]):
queries = []
query_ids = []
metadatas = []
embeddings = []
for query_info in query_info_list:
queries.append(query_info['query'])
query_ids.append(query_info['queryId'])
metadatas.append(query_info['metadata'])
embeddings.append(query_info['queryEmbedding'])
if None in embeddings:
embeddings = None
if None in queries:
queries = None
if embeddings is None and queries is None:
raise HTTPException(status_code=400, detail="query and queryEmbedding are None")
if embeddings is not None and queries is not None:
raise HTTPException(status_code=400, detail="query and queryEmbedding are not None")
query_collection = collection_manager.get_or_create_collection(collection_name=collection_name)
query_retriever = ChromaCollectionRetriever(collection=query_collection)
query_retriever.update_queries(query_text_list=queries, query_id_list=query_ids, metadatas=metadatas, embeddings=embeddings)
return "success"
@router.get("/empty_query")
def empty_query(collection_name:str):
query_collection = collection_manager.get_or_create_collection(collection_name=collection_name)
query_retriever = ChromaCollectionRetriever(collection=query_collection)
query_retriever.empty_query_collection()
return "success"
@router.post("/delete_query_by_ids")
def delete_query_by_ids(collection_name:str, query_ids: List[str]):
query_collection = collection_manager.get_or_create_collection(collection_name=collection_name)
query_retriever = ChromaCollectionRetriever(collection=query_collection)
query_retriever.delete_queries_by_ids(query_ids=query_ids)
return "success"
@router.post("/get_query_by_ids")
def get_query_by_ids(collection_name:str, query_ids: List[str]):
query_collection = collection_manager.get_or_create_collection(collection_name=collection_name)
query_retriever = ChromaCollectionRetriever(collection=query_collection)
queries = query_retriever.get_query_by_ids(query_ids=query_ids)
return queries
@router.get("/query_size")
def query_size(collection_name:str):
query_collection = collection_manager.get_or_create_collection(collection_name=collection_name)
query_retriever = ChromaCollectionRetriever(collection=query_collection)
size = query_retriever.get_query_size()
return size
@router.post("/retrieve_query")
def retrieve_query(collection_name:str, query_info: Mapping[str, Any], n_results:int=10):
query_collection = collection_manager.get_or_create_collection(collection_name=collection_name)
query_retriever = ChromaCollectionRetriever(collection=query_collection)
query_texts_list = query_info['queryTextsList']
qeuery_embeddings = query_info['queryEmbeddings']
filter_condition = query_info['filterCondition']
if query_texts_list is None and qeuery_embeddings is None:
raise HTTPException(status_code=400, detail="query and queryEmbedding are None")
if query_texts_list is not None and qeuery_embeddings is not None:
raise HTTPException(status_code=400, detail="query and queryEmbedding are not None")
parsed_retrieval_res_format = query_retriever.retrieval_query_run(query_texts_list=query_texts_list,
query_embeddings=qeuery_embeddings,
filter_condition=filter_condition,
n_results=n_results)
return parsed_retrieval_res_format

View File

@@ -0,0 +1,80 @@
# -*- coding:utf-8 -*-
import os
import sys
from typing import Any, List, Mapping, Optional, Union
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from fastapi import APIRouter, Depends, HTTPException
from services.query_retrieval.run import solved_query_retriever
router = APIRouter()
@router.post("/solved_query_retrival")
def solved_query_retrival(query_info: Mapping[str, Any], n_results: int = 5):
query_texts_list = query_info['queryTextsList']
filter_condition = query_info['filterCondition']
parsed_retrieval_res_format = solved_query_retriever.retrieval_query_run(query_texts_list=query_texts_list,
filter_condition=filter_condition,
n_results=n_results)
return parsed_retrieval_res_format
@router.post("/solved_query_add")
def add_solved_queries(sovled_query_info_list: List[Mapping[str, Any]]):
queries = []
query_ids = []
metadatas = []
for sovled_query_info in sovled_query_info_list:
queries.append(sovled_query_info['query'])
query_ids.append(sovled_query_info['query_id'])
metadatas.append(sovled_query_info['metadata'])
solved_query_retriever.add_queries(query_text_list=queries, query_id_list=query_ids, metadatas=metadatas)
return "success"
@router.post("/solved_query_update")
def solved_query_update(sovled_query_info_list: List[Mapping[str, Any]]):
queries = []
query_ids = []
metadatas = []
for sovled_query_info in sovled_query_info_list:
queries.append(sovled_query_info['query'])
query_ids.append(sovled_query_info['query_id'])
metadatas.append(sovled_query_info['metadata'])
solved_query_retriever.update_queries(query_text_list=queries, query_id_list=query_ids, metadatas=metadatas)
return "success"
@router.get("/solved_query_empty")
def solved_query_empty():
solved_query_retriever.empty_query_collection()
return "success"
@router.post("/solved_query_delete_by_ids")
def solved_query_delete_by_ids(query_ids: List[str]):
solved_query_retriever.delete_queries_by_ids(query_ids=query_ids)
return "success"
@router.post("/solved_query_get_by_ids")
def solved_query_get_by_ids(query_ids: List[str]):
queries = solved_query_retriever.get_query_by_ids(query_ids=query_ids)
return queries
@router.get("/solved_query_size")
def solved_query_size():
size = solved_query_retriever.get_query_size()
return size