mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-13 04:57:28 +00:00
(improvement)(Chat) Move python module from Chat To Headless (#823)
Co-authored-by: jolunoluo
This commit is contained in:
98
headless/python/services/query_retrieval/retriever.py
Normal file
98
headless/python/services/query_retrieval/retriever.py
Normal file
@@ -0,0 +1,98 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
from typing import Any, List, Mapping, Optional, Union
|
||||
|
||||
import chromadb
|
||||
from chromadb import Client
|
||||
from chromadb.config import Settings
|
||||
from chromadb.api import Collection, Documents, Embeddings
|
||||
from chromadb.api.types import CollectionMetadata
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from instances.logging_instance import logger
|
||||
from utils.chromadb_utils import (get_chroma_collection_size, query_chroma_collection,
|
||||
parse_retrieval_chroma_collection_query, chroma_collection_query_retrieval_format,
|
||||
get_chroma_collection_by_ids, get_chroma_collection_size,
|
||||
add_chroma_collection, update_chroma_collection, delete_chroma_collection_by_ids,
|
||||
empty_chroma_collection_2)
|
||||
|
||||
from utils.text2vec import Text2VecEmbeddingFunction
|
||||
|
||||
class ChromaCollectionRetriever(object):
|
||||
def __init__(self, collection:Collection):
|
||||
self.collection = collection
|
||||
|
||||
def retrieval_query_run(self, query_texts_list:List[str]=None, query_embeddings:Embeddings=None,
|
||||
filter_condition:Mapping[str,str]=None, n_results:int=5):
|
||||
|
||||
retrieval_res = query_chroma_collection(self.collection, query_texts_list, query_embeddings,
|
||||
filter_condition, n_results)
|
||||
|
||||
parsed_retrieval_res = parse_retrieval_chroma_collection_query(retrieval_res)
|
||||
logger.debug('parsed_retrieval_res: {}', parsed_retrieval_res)
|
||||
parsed_retrieval_res_format = chroma_collection_query_retrieval_format(query_texts_list, query_embeddings, parsed_retrieval_res)
|
||||
logger.debug('parsed_retrieval_res_format: {}', parsed_retrieval_res_format)
|
||||
|
||||
return parsed_retrieval_res_format
|
||||
|
||||
def get_query_by_ids(self, query_ids:List[str]):
|
||||
queries = get_chroma_collection_by_ids(self.collection, query_ids)
|
||||
return queries
|
||||
|
||||
def get_query_size(self):
|
||||
return get_chroma_collection_size(self.collection)
|
||||
|
||||
def add_queries(self, query_text_list:List[str],
|
||||
query_id_list:List[str],
|
||||
metadatas:List[Mapping[str, str]]=None,
|
||||
embeddings:Embeddings=None):
|
||||
add_chroma_collection(self.collection, query_text_list, query_id_list, metadatas, embeddings)
|
||||
return True
|
||||
|
||||
def update_queries(self, query_text_list:List[str],
|
||||
query_id_list:List[str],
|
||||
metadatas:List[Mapping[str, str]]=None,
|
||||
embeddings:Embeddings=None):
|
||||
update_chroma_collection(self.collection, query_text_list, query_id_list, metadatas, embeddings)
|
||||
return True
|
||||
|
||||
def delete_queries_by_ids(self, query_ids:List[str]):
|
||||
delete_chroma_collection_by_ids(self.collection, query_ids)
|
||||
return True
|
||||
|
||||
def empty_query_collection(self):
|
||||
self.collection = empty_chroma_collection_2(self.collection)
|
||||
|
||||
return True
|
||||
|
||||
class CollectionManager(object):
|
||||
def __init__(self, chroma_client:Client, embedding_func: Text2VecEmbeddingFunction, collection_meta: Optional[CollectionMetadata] = None):
|
||||
self.chroma_client = chroma_client
|
||||
self.embedding_func = embedding_func
|
||||
self.collection_meta = collection_meta
|
||||
|
||||
def list_collections(self):
|
||||
collection_list = self.chroma_client.list_collections()
|
||||
return collection_list
|
||||
|
||||
def get_collection(self, collection_name:str):
|
||||
collection = self.chroma_client.get_collection(name=collection_name, embedding_function=self.embedding_func)
|
||||
return collection
|
||||
|
||||
def create_collection(self, collection_name:str):
|
||||
collection = self.chroma_client.create_collection(name=collection_name, embedding_function=self.embedding_func, metadata=self.collection_meta)
|
||||
return collection
|
||||
|
||||
def get_or_create_collection(self, collection_name:str):
|
||||
collection = self.chroma_client.get_or_create_collection(name=collection_name, embedding_function=self.embedding_func, metadata=self.collection_meta)
|
||||
return collection
|
||||
|
||||
def delete_collection(self, collection_name:str):
|
||||
self.chroma_client.delete_collection(collection_name)
|
||||
return True
|
||||
37
headless/python/services/query_retrieval/run.py
Normal file
37
headless/python/services/query_retrieval/run.py
Normal file
@@ -0,0 +1,37 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
from typing import Any, List, Mapping, Optional, Union
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from instances.logging_instance import logger
|
||||
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
from chromadb.api import Collection, Documents, Embeddings
|
||||
|
||||
from utils.text2vec import Text2VecEmbeddingFunction
|
||||
from instances.chromadb_instance import client
|
||||
|
||||
from config.config_parse import SOLVED_QUERY_COLLECTION_NAME, PRESET_QUERY_COLLECTION_NAME
|
||||
from retriever import ChromaCollectionRetriever, CollectionManager
|
||||
|
||||
|
||||
emb_func = Text2VecEmbeddingFunction()
|
||||
|
||||
collection_manager = CollectionManager(chroma_client=client, embedding_func=emb_func
|
||||
,collection_meta={"hnsw:space": "cosine"})
|
||||
|
||||
solved_query_collection = collection_manager.get_or_create_collection(collection_name=SOLVED_QUERY_COLLECTION_NAME)
|
||||
preset_query_collection = collection_manager.get_or_create_collection(collection_name=PRESET_QUERY_COLLECTION_NAME)
|
||||
|
||||
|
||||
solved_query_retriever = ChromaCollectionRetriever(solved_query_collection)
|
||||
preset_query_retriever = ChromaCollectionRetriever(preset_query_collection)
|
||||
|
||||
logger.info("init_solved_query_collection_size: {}".format(solved_query_retriever.get_query_size()))
|
||||
logger.info("init_preset_query_collection_size: {}".format(preset_query_retriever.get_query_size()))
|
||||
Reference in New Issue
Block a user