mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 03:58:14 +00:00
[improvement][chat] Move python code out of chat-core module
This commit is contained in:
156
chat/python/utils/chromadb_utils.py
Normal file
156
chat/python/utils/chromadb_utils.py
Normal file
@@ -0,0 +1,156 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from typing import Any, List, Mapping, Optional, Union
|
||||
|
||||
import chromadb
|
||||
from chromadb.api import Collection
|
||||
from chromadb.config import Settings
|
||||
from chromadb.api import Collection, Documents, Embeddings
|
||||
|
||||
import os
|
||||
import sys
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from instances.logging_instance import logger
|
||||
|
||||
|
||||
def empty_chroma_collection_2(collection:Collection):
|
||||
collection_name = collection.name
|
||||
client = collection._client
|
||||
metadata = collection.metadata
|
||||
embedding_function = collection._embedding_function
|
||||
|
||||
client.delete_collection(collection_name)
|
||||
|
||||
new_collection = client.get_or_create_collection(name=collection_name,
|
||||
metadata=metadata,
|
||||
embedding_function=embedding_function)
|
||||
|
||||
size_of_new_collection = new_collection.count()
|
||||
|
||||
logger.info(f'Collection {collection_name} emptied. Size of new collection: {size_of_new_collection}')
|
||||
|
||||
return new_collection
|
||||
|
||||
|
||||
def empty_chroma_collection(collection:Collection) -> None:
|
||||
collection.delete()
|
||||
|
||||
|
||||
def add_chroma_collection(collection:Collection,
|
||||
queries:List[str],
|
||||
query_ids:List[str],
|
||||
metadatas:List[Mapping[str, str]]=None,
|
||||
embeddings:Embeddings=None
|
||||
) -> None:
|
||||
|
||||
collection.add(documents=queries,
|
||||
ids=query_ids,
|
||||
metadatas=metadatas,
|
||||
embeddings=embeddings)
|
||||
|
||||
|
||||
def update_chroma_collection(collection:Collection,
|
||||
queries:List[str],
|
||||
query_ids:List[str],
|
||||
metadatas:List[Mapping[str, str]]=None,
|
||||
embeddings:Embeddings=None
|
||||
) -> None:
|
||||
|
||||
collection.update(documents=queries,
|
||||
ids=query_ids,
|
||||
metadatas=metadatas,
|
||||
embeddings=embeddings)
|
||||
|
||||
|
||||
def query_chroma_collection(collection:Collection, query_texts:List[str]=None, query_embeddings:Embeddings=None,
|
||||
filter_condition:Mapping[str,str]=None, n_results:int=10):
|
||||
outer_opt = '$and'
|
||||
inner_opt = '$eq'
|
||||
|
||||
if filter_condition is not None:
|
||||
if len(filter_condition)==1:
|
||||
outer_filter = filter_condition
|
||||
else:
|
||||
inner_filter = [{_k: {inner_opt:_v}} for _k, _v in filter_condition.items()]
|
||||
outer_filter = {outer_opt: inner_filter}
|
||||
else:
|
||||
outer_filter = None
|
||||
|
||||
logger.info('outer_filter: {}'.format(outer_filter))
|
||||
|
||||
res = collection.query(query_texts=query_texts, query_embeddings=query_embeddings,
|
||||
n_results=n_results, where=outer_filter)
|
||||
return res
|
||||
|
||||
|
||||
def parse_retrieval_chroma_collection_query(res:List[Mapping[str, Any]]):
|
||||
parsed_res = [[] for _ in range(0, len(res['ids']))]
|
||||
|
||||
retrieval_ids = res['ids']
|
||||
retrieval_distances = res['distances']
|
||||
retrieval_sentences = res['documents']
|
||||
retrieval_metadatas = res['metadatas']
|
||||
|
||||
for query_idx in range(0, len(retrieval_ids)):
|
||||
id_ls = retrieval_ids[query_idx]
|
||||
distance_ls = retrieval_distances[query_idx]
|
||||
sentence_ls = retrieval_sentences[query_idx]
|
||||
metadata_ls = retrieval_metadatas[query_idx]
|
||||
|
||||
for idx in range(0, len(id_ls)):
|
||||
id = id_ls[idx]
|
||||
distance = distance_ls[idx]
|
||||
sentence = sentence_ls[idx]
|
||||
metadata = metadata_ls[idx]
|
||||
|
||||
parsed_res[query_idx].append({
|
||||
'id': id,
|
||||
'distance': distance,
|
||||
'query': sentence,
|
||||
'metadata': metadata
|
||||
})
|
||||
|
||||
return parsed_res
|
||||
|
||||
def chroma_collection_query_retrieval_format(query_list:List[str], query_embeddings:Embeddings ,retrieval_list:List[Mapping[str, Any]]):
|
||||
res = []
|
||||
|
||||
if query_list is not None and query_embeddings is not None:
|
||||
raise Exception("query_list and query_embeddings are not None")
|
||||
if query_list is None and query_embeddings is None:
|
||||
raise Exception("query_list and query_embeddings are None")
|
||||
|
||||
if query_list is not None:
|
||||
for query_idx in range(0, len(query_list)):
|
||||
query = query_list[query_idx]
|
||||
retrieval = retrieval_list[query_idx]
|
||||
|
||||
res.append({
|
||||
'query': query,
|
||||
'retrieval': retrieval
|
||||
})
|
||||
else:
|
||||
for query_idx in range(0, len(query_embeddings)):
|
||||
query_embedding = query_embeddings[query_idx]
|
||||
retrieval = retrieval_list[query_idx]
|
||||
|
||||
res.append({
|
||||
'query_embedding': query_embedding,
|
||||
'retrieval': retrieval
|
||||
})
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def delete_chroma_collection_by_ids(collection:Collection, query_ids:List[str]) -> None:
|
||||
collection.delete(ids=query_ids)
|
||||
|
||||
def get_chroma_collection_by_ids(collection:Collection, query_ids:List[str]):
|
||||
res = collection.get(ids=query_ids)
|
||||
|
||||
return res
|
||||
|
||||
def get_chroma_collection_size(collection:Collection) -> int:
|
||||
return collection.count()
|
||||
|
||||
Reference in New Issue
Block a user