(improvement)(chat) Add support for Chroma metastore and remove additional configurations. (#1199)

This commit is contained in:
lexluo09
2024-06-23 21:53:44 +08:00
committed by GitHub
parent 4d6cbf31f7
commit 6a66db7c0e
9 changed files with 77 additions and 127 deletions

View File

@@ -11,9 +11,6 @@ public class EmbeddingConfig {
@Value("${s2.embedding.url:}")
private String url;
@Value("${s2.embedding.persistent.path:/tmp}")
private String embeddingStorePersistentPath;
@Value("${s2.embedding.recognize.path:/preset_query_retrival}")
private String recognizePath;

View File

@@ -0,0 +1,21 @@
package dev.langchain4j.chroma.spring;
import static dev.langchain4j.chroma.spring.Properties.PREFIX;
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
/**
 * Spring configuration for the Chroma integration.
 *
 * <p>Enables binding of {@link Properties} and contributes an {@link EmbeddingStoreFactory}
 * backed by {@link ChromaEmbeddingStoreFactory}.
 */
@EnableConfigurationProperties(Properties.class)
@Configuration
public class ChromaAutoConfig {

    /**
     * Creates the Chroma-backed store factory.
     *
     * <p>The bean is conditional on the {@code langchain4j.chroma.embedding-store.base-url}
     * property (the {@code PREFIX} constant plus {@code .embedding-store.base-url}).
     *
     * <p>NOTE(review): the method name mentions "milvus" although it builds a Chroma factory.
     * It is left untouched because Spring derives the default bean name from the method name.
     *
     * @param properties bound Chroma configuration handed to the factory
     * @return a new {@link ChromaEmbeddingStoreFactory}
     */
    @ConditionalOnProperty(PREFIX + ".embedding-store.base-url")
    @Bean
    EmbeddingStoreFactory milvusChatModel(Properties properties) {
        EmbeddingStoreFactory chromaFactory = new ChromaEmbeddingStoreFactory(properties);
        return chromaFactory;
    }
}

View File

@@ -0,0 +1,24 @@
package dev.langchain4j.chroma.spring;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
import dev.langchain4j.store.embedding.chroma.ChromaEmbeddingStore;
/**
 * {@link EmbeddingStoreFactory} that builds {@link ChromaEmbeddingStore} instances from the
 * bound Chroma {@link Properties}.
 */
public class ChromaEmbeddingStoreFactory implements EmbeddingStoreFactory {

    /** Bound Chroma configuration; assigned once in the constructor. */
    private final Properties properties;

    /**
     * @param properties bound Chroma configuration supplying base URL, default collection
     *                   name and timeout
     */
    public ChromaEmbeddingStoreFactory(Properties properties) {
        this.properties = properties;
    }

    /**
     * Builds a Chroma embedding store for the requested collection.
     *
     * <p>Previously the {@code collectionName} argument was ignored and the configured
     * collection name was always used, so every caller ended up on the same collection.
     * The argument now takes precedence; the configured name is only a fallback.
     *
     * @param collectionName collection to open; when {@code null} or blank the collection
     *                       name from the configuration is used instead
     * @return a store built with the configured base URL and timeout
     */
    @Override
    public EmbeddingStore create(String collectionName) {
        EmbeddingStoreProperties embeddingStore = properties.getEmbeddingStore();
        return ChromaEmbeddingStore.builder()
                .baseUrl(embeddingStore.getBaseUrl())
                .collectionName(resolveCollectionName(collectionName, embeddingStore))
                .timeout(embeddingStore.getTimeout())
                .build();
    }

    /** Returns the requested name, or the configured one when the request is null/blank. */
    private static String resolveCollectionName(String requested,
                                                EmbeddingStoreProperties configured) {
        if (requested == null || requested.trim().isEmpty()) {
            return configured.getCollectionName();
        }
        return requested;
    }
}

View File

@@ -0,0 +1,14 @@
package dev.langchain4j.chroma.spring;
import java.time.Duration;
import lombok.Getter;
import lombok.Setter;
/**
 * Settings for the Chroma embedding store, used as the nested {@code embeddingStore}
 * group of {@link Properties}. Accessors are generated by Lombok.
 */
@Getter
@Setter
class EmbeddingStoreProperties {
// Base URL of the Chroma server; ChromaEmbeddingStoreFactory passes it to the store builder.
private String baseUrl;
// Chroma collection name taken from configuration.
private String collectionName;
// Timeout passed to the store builder; presumably the client request timeout - TODO confirm.
private Duration timeout;
}

View File

@@ -0,0 +1,17 @@
package dev.langchain4j.chroma.spring;
import lombok.Getter;
import lombok.Setter;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.context.properties.NestedConfigurationProperty;
/**
 * Root configuration properties for the Chroma integration, bound from the
 * {@code langchain4j.chroma} prefix. Accessors are generated by Lombok.
 */
@Getter
@Setter
@ConfigurationProperties(prefix = Properties.PREFIX)
public class Properties {
// Property prefix; also used by ChromaAutoConfig to build its conditional property name.
static final String PREFIX = "langchain4j.chroma";
// Nested group for the embedding store settings (e.g. "<PREFIX>.embedding-store.base-url").
@NestedConfigurationProperty
EmbeddingStoreProperties embeddingStore;
}

View File

@@ -1,112 +0,0 @@
package dev.langchain4j.store.embedding;
import com.alibaba.fastjson.JSONObject;
import com.alibaba.fastjson.serializer.SerializerFeature;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.service.EmbeddingService;
import com.tencent.supersonic.common.util.ContextUtils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.util.CollectionUtils;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriComponentsBuilder;
import java.net.URI;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
/**
 * {@link EmbeddingService} implementation that calls the Python S2EmbeddingStore service
 * over HTTP.
 *
 * <p>The service base URL is read from the {@link EmbeddingConfig} bean on every call.
 * Requests are best-effort: failures are logged and surfaced as an empty response rather
 * than an exception.
 */
@Slf4j
public class PythonServiceEmbeddingService implements EmbeddingService {

    private final RestTemplate restTemplate = new RestTemplate();

    /**
     * Creates the collection on the Python service unless it is already listed there.
     *
     * @param collectionName name of the collection to create
     */
    public void addCollection(String collectionName) {
        if (getCollectionList().contains(collectionName)) {
            return;
        }
        String url = String.format("%s/create_collection?collection_name=%s",
                embeddingUrl(), collectionName);
        doRequest(url, null, HttpMethod.GET);
    }

    /**
     * Posts the given queries to the {@code add_query} endpoint of the collection.
     *
     * @param collectionName target collection
     * @param queries        queries to add; nothing is sent when null or empty
     */
    public void addQuery(String collectionName, List<EmbeddingQuery> queries) {
        if (CollectionUtils.isEmpty(queries)) {
            return;
        }
        String url = String.format("%s/add_query?collection_name=%s",
                embeddingUrl(), collectionName);
        // WriteMapNullValue keeps null-valued fields in the serialized payload.
        String body = JSONObject.toJSONString(queries, SerializerFeature.WriteMapNullValue);
        doRequest(url, body, HttpMethod.POST);
    }

    /**
     * Deletes the given queries from the collection by their query ids.
     *
     * @param collectionName target collection
     * @param queries        queries whose ids are sent; nothing is sent when null or empty
     */
    public void deleteQuery(String collectionName, List<EmbeddingQuery> queries) {
        if (CollectionUtils.isEmpty(queries)) {
            return;
        }
        List<String> queryIds = queries.stream()
                .map(EmbeddingQuery::getQueryId)
                .collect(Collectors.toList());
        String url = String.format("%s/delete_query_by_ids?collection_name=%s",
                embeddingUrl(), collectionName);
        doRequest(url, JSONObject.toJSONString(queryIds), HttpMethod.POST);
    }

    /**
     * Retrieves up to {@code num} results for the query from the collection.
     *
     * @param collectionName collection to search
     * @param retrieveQuery  query payload, serialized as JSON
     * @param num            value sent as the {@code n_results} request parameter
     * @return parsed results; an empty list when the response has no body or the body
     *         parses to nothing
     */
    public List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) {
        String url = String.format("%s/retrieve_query?collection_name=%s&n_results=%s",
                embeddingUrl(), collectionName, num);
        String body = JSONObject.toJSONString(retrieveQuery, SerializerFeature.WriteMapNullValue);
        ResponseEntity<String> responseEntity = doRequest(url, body, HttpMethod.POST);
        if (!responseEntity.hasBody()) {
            return Lists.newArrayList();
        }
        List<RetrieveQueryResult> results =
                JSONObject.parseArray(responseEntity.getBody(), RetrieveQueryResult.class);
        // parseArray yields null for a JSON "null" body; never hand null to callers.
        return results == null ? Lists.newArrayList() : results;
    }

    /** Returns the names of the collections reported by the {@code list_collections} endpoint. */
    private List<String> getCollectionList() {
        String url = embeddingUrl() + "/list_collections";
        ResponseEntity<String> responseEntity = doRequest(url, null, HttpMethod.GET);
        if (!responseEntity.hasBody()) {
            return Lists.newArrayList();
        }
        List<EmbeddingCollection> embeddingCollections =
                JSONObject.parseArray(responseEntity.getBody(), EmbeddingCollection.class);
        if (embeddingCollections == null) {
            // A JSON "null" body parses to null; treat it like "no collections".
            return Lists.newArrayList();
        }
        return embeddingCollections.stream()
                .map(EmbeddingCollection::getName)
                .collect(Collectors.toList());
    }

    /** Looks up the embedding service base URL from the {@link EmbeddingConfig} bean. */
    private String embeddingUrl() {
        return ContextUtils.getBean(EmbeddingConfig.class).getUrl();
    }

    /**
     * Sends a JSON request to the given URL and returns the raw string response.
     *
     * <p>No {@code Location} request header is set: building it with {@code URI.create(url)}
     * failed on characters {@link URI} rejects (for example spaces in a collection name)
     * before the URL below could be encoded, which made the whole request fail silently.
     *
     * @param url        absolute HTTP(S) URL; it is encoded before the request is sent
     * @param jsonBody   request body, or {@code null} to send headers only
     * @param httpMethod HTTP method to use
     * @return the service response, or an empty response (as produced by
     *         {@code ResponseEntity.of(Optional.empty())}) when the call fails
     */
    public ResponseEntity<String> doRequest(String url, String jsonBody, HttpMethod httpMethod) {
        try {
            HttpHeaders headers = new HttpHeaders();
            headers.setContentType(MediaType.APPLICATION_JSON);
            URI requestUrl = UriComponentsBuilder
                    .fromHttpUrl(url).build().encode().toUri();
            HttpEntity<String> entity = new HttpEntity<>(headers);
            if (jsonBody != null) {
                log.info("[embedding] request body :{}", jsonBody);
                entity = new HttpEntity<>(jsonBody, headers);
            }
            ResponseEntity<String> responseEntity = restTemplate.exchange(requestUrl,
                    httpMethod, entity, new ParameterizedTypeReference<String>() {
                    });
            log.info("[embedding] url :{} result body:{}", url, responseEntity);
            return responseEntity;
        } catch (Exception e) {
            // Best-effort call: log and fall through to the empty response. JVM Errors are
            // intentionally no longer swallowed here.
            log.warn("doRequest service failed, url:{}", url, e);
        }
        return ResponseEntity.of(Optional.empty());
    }
}

View File

@@ -40,8 +40,4 @@ com.tencent.supersonic.chat.postprocessor.PostProcessor=\
com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor=\
com.tencent.supersonic.chat.server.processor.execute.MetricRecommendProcessor,\
com.tencent.supersonic.chat.server.processor.execute.DimensionRecommendProcessor,\
com.tencent.supersonic.chat.server.processor.execute.MetricRatioProcessor
com.tencent.supersonic.common.service.EmbeddingService=\
dev.langchain4j.inmemory.spring.InMemoryEmbeddingService
com.tencent.supersonic.chat.server.processor.execute.MetricRatioProcessor

View File

@@ -4,10 +4,6 @@ com.tencent.supersonic.auth.authentication.interceptor.AuthenticationInterceptor
com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor=\
com.tencent.supersonic.auth.authentication.adaptor.DefaultUserAdaptor
com.tencent.supersonic.common.service.EmbeddingService=\
dev.langchain4j.inmemory.spring.InMemoryEmbeddingService
com.tencent.supersonic.headless.core.parser.converter.HeadlessConverter=\
com.tencent.supersonic.headless.core.parser.converter.DefaultDimValueConverter,\
com.tencent.supersonic.headless.core.parser.converter.SqlVariableParseConverter,\

View File

@@ -86,9 +86,6 @@ com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor=\
### common SPIs
com.tencent.supersonic.common.service.EmbeddingService=\
dev.langchain4j.inmemory.spring.InMemoryEmbeddingService
org.springframework.boot.autoconfigure.EnableAutoConfiguration=\
dev.langchain4j.spring.LangChain4jAutoConfig,\
dev.langchain4j.openai.spring.AutoConfig,\