Files
supersonic/chat/python/utils/chromadb_utils.py

157 lines
5.2 KiB
Python

# -*- coding:utf-8 -*-
from typing import Any, List, Mapping, Optional, Union
import chromadb
from chromadb.api import Collection
from chromadb.config import Settings
from chromadb.api import Collection, Documents, Embeddings
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from instances.logging_instance import logger
def empty_chroma_collection_2(collection:Collection):
    """Empty a collection by deleting it and recreating it under the same name.

    The replacement is created through the same client object and is given
    the original collection's name, metadata and embedding function.

    Args:
        collection: The collection to reset. It is deleted via its client.

    Returns:
        The collection returned by ``get_or_create_collection`` after the
        delete.
    """
    # Capture everything needed to rebuild the collection before deleting it.
    client = collection._client
    collection_name = collection.name
    recreate_kwargs = {
        'name': collection_name,
        'metadata': collection.metadata,
        'embedding_function': collection._embedding_function,
    }

    client.delete_collection(collection_name)
    new_collection = client.get_or_create_collection(**recreate_kwargs)

    # Log the size of the replacement so the reset can be checked in the logs.
    size_of_new_collection = new_collection.count()
    logger.info(f'Collection {collection_name} emptied. Size of new collection: {size_of_new_collection}')
    return new_collection
def empty_chroma_collection(collection:Collection) -> None:
    """Call ``collection.delete()`` with no ids and no filters.

    NOTE(review): whether a no-argument delete removes every record or is
    rejected appears to depend on the chromadb version -- confirm against
    the release this project pins.

    Args:
        collection: The collection to delete from.
    """
    collection.delete()
def add_chroma_collection(collection:Collection,
                          queries:List[str],
                          query_ids:List[str],
                          metadatas:List[Mapping[str, str]]=None,
                          embeddings:Embeddings=None
                          ) -> None:
    """Add records to ``collection`` via ``collection.add``.

    Args:
        collection: Target collection.
        queries: Texts to store; forwarded as ``documents``.
        query_ids: Identifiers for the texts; forwarded as ``ids``.
        metadatas: Optional metadata; forwarded as ``metadatas``.
        embeddings: Optional embeddings; forwarded as ``embeddings``.
    """
    payload = {
        'ids': query_ids,
        'documents': queries,
        'metadatas': metadatas,
        'embeddings': embeddings,
    }
    collection.add(**payload)
def update_chroma_collection(collection:Collection,
                             queries:List[str],
                             query_ids:List[str],
                             metadatas:List[Mapping[str, str]]=None,
                             embeddings:Embeddings=None
                             ) -> None:
    """Update records in ``collection`` via ``collection.update``.

    Args:
        collection: Target collection.
        queries: Replacement texts; forwarded as ``documents``.
        query_ids: Identifiers of the records to update; forwarded as ``ids``.
        metadatas: Optional metadata; forwarded as ``metadatas``.
        embeddings: Optional embeddings; forwarded as ``embeddings``.
    """
    payload = {
        'ids': query_ids,
        'documents': queries,
        'metadatas': metadatas,
        'embeddings': embeddings,
    }
    collection.update(**payload)
def query_chroma_collection(collection:Collection, query_texts:List[str]=None, query_embeddings:Embeddings=None,
                            filter_condition:Mapping[str,str]=None, n_results:int=10):
    """Query ``collection``, optionally restricting hits by metadata equality.

    Args:
        collection: Collection to search.
        query_texts: Query strings; forwarded as ``query_texts``.
        query_embeddings: Query vectors; forwarded as ``query_embeddings``.
        filter_condition: Mapping of metadata key to required value. ``None``
            or an empty mapping means no filter. A single-entry mapping is
            forwarded unchanged; a mapping with several entries is rewritten
            as an ``$and`` of per-key ``$eq`` clauses.
        n_results: Forwarded as ``n_results``.

    Returns:
        Whatever ``collection.query`` returns.
    """
    if not filter_condition:
        # None and {} both mean "no filter". An empty mapping would otherwise
        # be turned into {'$and': []}, which is not a meaningful filter.
        outer_filter = None
    elif len(filter_condition) == 1:
        # A lone key/value pair is passed through as given.
        outer_filter = filter_condition
    else:
        outer_filter = {
            '$and': [{key: {'$eq': value}} for key, value in filter_condition.items()]
        }

    logger.info('outer_filter: {}'.format(outer_filter))
    return collection.query(query_texts=query_texts, query_embeddings=query_embeddings,
                            n_results=n_results, where=outer_filter)
def parse_retrieval_chroma_collection_query(res:Mapping[str, Any]):
    """Regroup a column-oriented query result into one list of hit dicts per query.

    Args:
        res: Mapping with an ``'ids'`` entry holding one list of ids per
            query, and ``'distances'``, ``'documents'`` and ``'metadatas'``
            entries laid out the same way. Any of the latter three may be
            missing or ``None``; that field is then reported as ``None`` for
            every hit.

    Returns:
        A list with one entry per query. Each entry is a list of dicts with
        the keys ``'id'``, ``'distance'``, ``'query'`` (taken from
        ``'documents'``) and ``'metadata'``.

    Raises:
        KeyError: If ``res`` has no ``'ids'`` entry.
    """
    ids_per_query = res['ids']

    def hits_field(key, query_idx, n_hits):
        # NOTE(review): Chroma presumably leaves a field as None when it is
        # excluded through `include` -- confirm. Such fields are padded with
        # None so the remaining fields can still be returned.
        per_query = res.get(key)
        if per_query is None:
            return [None] * n_hits
        return per_query[query_idx]

    parsed_res = []
    for query_idx, id_ls in enumerate(ids_per_query):
        n_hits = len(id_ls)
        # Assumes the per-query lists are aligned position by position, as the
        # original index-based loops did.
        columns = zip(
            id_ls,
            hits_field('distances', query_idx, n_hits),
            hits_field('documents', query_idx, n_hits),
            hits_field('metadatas', query_idx, n_hits),
        )
        parsed_res.append([
            {
                'id': doc_id,
                'distance': distance,
                'query': sentence,
                'metadata': metadata,
            }
            for doc_id, distance, sentence, metadata in columns
        ])
    return parsed_res
def chroma_collection_query_retrieval_format(query_list:List[str], query_embeddings:Embeddings ,retrieval_list:List[Mapping[str, Any]]):
    """Pair each query (text or embedding) with its retrieval result by position.

    Exactly one of ``query_list`` and ``query_embeddings`` must be given.

    Args:
        query_list: Query texts, or ``None`` when embeddings are used.
        query_embeddings: Query embeddings, or ``None`` when texts are used.
        retrieval_list: Retrieval results; entry ``i`` belongs to query ``i``.

    Returns:
        A list of dicts, one per query, each holding ``'retrieval'`` plus
        either ``'query'`` (text input) or ``'query_embedding'`` (embedding
        input).

    Raises:
        ValueError: If both or neither of ``query_list`` and
            ``query_embeddings`` are provided.
        IndexError: If ``retrieval_list`` has fewer entries than there are
            queries.
    """
    # ValueError is a subclass of Exception, so callers that catch Exception
    # keep working.
    if query_list is not None and query_embeddings is not None:
        raise ValueError("query_list and query_embeddings are not None")
    if query_list is None and query_embeddings is None:
        raise ValueError("query_list and query_embeddings are None")

    if query_list is not None:
        label, queries = 'query', query_list
    else:
        label, queries = 'query_embedding', query_embeddings

    # Index retrieval_list explicitly instead of using zip(): a retrieval_list
    # that is too short should raise rather than silently drop queries.
    return [
        {label: query, 'retrieval': retrieval_list[idx]}
        for idx, query in enumerate(queries)
    ]
def delete_chroma_collection_by_ids(collection:Collection, query_ids:List[str]) -> None:
    """Delete from ``collection`` by id.

    Args:
        collection: Collection to delete from.
        query_ids: Identifiers to delete; forwarded as ``ids``.
    """
    ids_to_remove = query_ids
    collection.delete(ids=ids_to_remove)
def get_chroma_collection_by_ids(collection:Collection, query_ids:List[str]):
    """Fetch from ``collection`` by id.

    Args:
        collection: Collection to read from.
        query_ids: Identifiers to look up; forwarded as ``ids``.

    Returns:
        Whatever ``collection.get`` returns.
    """
    return collection.get(ids=query_ids)
def get_chroma_collection_size(collection:Collection) -> int:
    """Return the value of ``collection.count()``.

    Args:
        collection: Collection to measure.
    """
    size = collection.count()
    return size