mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
* 1. refactor the retrieval module. 2. refactor the http service module. 3. upgrade the text2sql output format and the parsing of absolute-time expressions in queries. * fix bug. * upgrade the config module; it now supports configuring LLMs supported by langchain. * fix conflicts. * update text2sql config reload to be compatible with the new config format. * modify default config. * 1. add self-consistency feature for text2sql. 2. upgrade llm api calls from sync to async. 3. refactor text2sql module. 4. refactor semantic retriever modules. * merge with upstream master * add general retrieve service. * add api service for sql_agent for crud operations on few-shot examples. * modify requirements * add auto-cot feature --------- Co-authored-by: shaoweigong <shaoweigong@tencent.com>
98 lines
4.5 KiB
Python
98 lines
4.5 KiB
Python
# -*- coding:utf-8 -*-
|
|
|
|
import os
|
|
import sys
|
|
import uuid
|
|
from typing import Any, List, Mapping, Optional, Union
|
|
|
|
import chromadb
|
|
from chromadb import Client
|
|
from chromadb.config import Settings
|
|
from chromadb.api import Collection, Documents, Embeddings
|
|
from chromadb.api.types import CollectionMetadata
|
|
|
|
# Put this file's directory and its two ancestor directories on sys.path,
# outermost first (same order as appending them one by one). Presumably this
# lets the project-local `instances` / `utils` imports resolve regardless of
# the working directory.
_this_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.extend([
    os.path.dirname(os.path.dirname(_this_dir)),
    os.path.dirname(_this_dir),
    _this_dir,
])
del _this_dir  # keep the module namespace identical to before
|
|
|
|
from instances.logging_instance import logger
|
|
from utils.chromadb_utils import (get_chroma_collection_size, query_chroma_collection,
|
|
parse_retrieval_chroma_collection_query, chroma_collection_query_retrieval_format,
|
|
get_chroma_collection_by_ids, get_chroma_collection_size,
|
|
add_chroma_collection, update_chroma_collection, delete_chroma_collection_by_ids,
|
|
empty_chroma_collection_2)
|
|
|
|
from utils.text2vec import Text2VecEmbeddingFunction
|
|
|
|
class ChromaCollectionRetriever(object):
    """Query and CRUD helpers bound to a single Chroma collection.

    Every method delegates to a helper from ``utils.chromadb_utils`` and
    passes ``self.collection`` as the collection to operate on.
    """

    def __init__(self, collection: Collection):
        """Wrap ``collection``; all subsequent operations target it."""
        self.collection = collection

    def retrieval_query_run(self, query_texts_list: Optional[List[str]] = None,
                            query_embeddings: Optional[Embeddings] = None,
                            filter_condition: Optional[Mapping[str, str]] = None,
                            n_results: int = 5):
        """Query the collection, then parse and format the raw result.

        Args:
            query_texts_list: Query texts. Presumably either this or
                ``query_embeddings`` must be given — TODO confirm against
                ``query_chroma_collection``.
            query_embeddings: Pre-computed query embeddings.
            filter_condition: Filter forwarded unchanged to the query helper.
            n_results: Result count forwarded to the query helper (default 5).

        Returns:
            The value produced by ``chroma_collection_query_retrieval_format``
            from the inputs and the parsed query result.
        """
        retrieval_res = query_chroma_collection(self.collection, query_texts_list, query_embeddings,
                                                filter_condition, n_results)

        parsed_retrieval_res = parse_retrieval_chroma_collection_query(retrieval_res)
        # '{}' placeholders: logger is presumably loguru-style — confirm in
        # instances.logging_instance before switching to %-style.
        logger.debug('parsed_retrieval_res: {}', parsed_retrieval_res)
        parsed_retrieval_res_format = chroma_collection_query_retrieval_format(
            query_texts_list, query_embeddings, parsed_retrieval_res)
        logger.debug('parsed_retrieval_res_format: {}', parsed_retrieval_res_format)

        return parsed_retrieval_res_format

    def get_query_by_ids(self, query_ids: List[str]):
        """Return the result of ``get_chroma_collection_by_ids`` for ``query_ids``."""
        queries = get_chroma_collection_by_ids(self.collection, query_ids)
        return queries

    def get_query_size(self):
        """Return ``get_chroma_collection_size`` for the wrapped collection."""
        return get_chroma_collection_size(self.collection)

    def add_queries(self, query_text_list: List[str],
                    query_id_list: List[str],
                    metadatas: Optional[List[Mapping[str, str]]] = None,
                    embeddings: Optional[Embeddings] = None):
        """Add queries to the collection via ``add_chroma_collection``.

        ``metadatas`` and ``embeddings`` are forwarded as-is (``None`` when
        omitted). Always returns ``True``; errors propagate from the helper.
        """
        add_chroma_collection(self.collection, query_text_list, query_id_list, metadatas, embeddings)
        return True

    def update_queries(self, query_text_list: List[str],
                       query_id_list: List[str],
                       metadatas: Optional[List[Mapping[str, str]]] = None,
                       embeddings: Optional[Embeddings] = None):
        """Update queries in the collection via ``update_chroma_collection``.

        Arguments are forwarded as-is. Always returns ``True``; errors
        propagate from the helper.
        """
        update_chroma_collection(self.collection, query_text_list, query_id_list, metadatas, embeddings)
        return True

    def delete_queries_by_ids(self, query_ids: List[str]):
        """Delete the given ids via ``delete_chroma_collection_by_ids``; returns ``True``."""
        delete_chroma_collection_by_ids(self.collection, query_ids)
        return True

    def empty_query_collection(self):
        """Empty the collection via ``empty_chroma_collection_2``; returns ``True``.

        The helper's return value replaces ``self.collection``. Presumably it
        returns a fresh (re-created) collection handle — TODO confirm.
        """
        self.collection = empty_chroma_collection_2(self.collection)
        return True
|
|
|
|
class CollectionManager:
    """Manage collections on a Chroma client.

    Collections obtained through this manager are bound to its
    ``embedding_func``. ``collection_meta`` (optional) is passed as the
    ``metadata`` argument of the create-style calls only.
    """

    def __init__(self, chroma_client: Client, embedding_func: Text2VecEmbeddingFunction, collection_meta: Optional[CollectionMetadata] = None):
        self.chroma_client = chroma_client
        self.embedding_func = embedding_func
        self.collection_meta = collection_meta

    def _creation_kwargs(self, collection_name: str):
        """Keyword arguments shared by ``create_collection`` and ``get_or_create_collection``."""
        return dict(name=collection_name,
                    embedding_function=self.embedding_func,
                    metadata=self.collection_meta)

    def list_collections(self):
        """Return whatever the client's ``list_collections`` returns."""
        return self.chroma_client.list_collections()

    def get_collection(self, collection_name: str):
        """Fetch an existing collection by name, bound to the embedding function."""
        return self.chroma_client.get_collection(name=collection_name,
                                                 embedding_function=self.embedding_func)

    def create_collection(self, collection_name: str):
        """Create a collection with this manager's embedding function and metadata."""
        return self.chroma_client.create_collection(**self._creation_kwargs(collection_name))

    def get_or_create_collection(self, collection_name: str):
        """Fetch the named collection, creating it if the client reports it missing."""
        return self.chroma_client.get_or_create_collection(**self._creation_kwargs(collection_name))

    def delete_collection(self, collection_name: str):
        """Delete the named collection on the client; returns ``True``."""
        self.chroma_client.delete_collection(collection_name)
        return True