From abbe8c84a15e4b067f5ea8d949ed251132a294ad Mon Sep 17 00:00:00 2001
From: lexluo09 <39718951+lexluo09@users.noreply.github.com>
Date: Fri, 8 Dec 2023 19:24:58 +0800
Subject: [PATCH] [improvement](python) LLM related services support Java
service invocation (#484)
---
chat/core/pom.xml | 13 +
.../chat/mapper/EmbeddingMatchStrategy.java | 11 +-
.../supersonic/chat/parser/EmbedLLMProxy.java | 84 ++++++
.../function/FunctionCallPromptGenerator.java | 44 +++
.../parser/sql/llm/prompt/InputFormat.java | 42 +++
.../parser/sql/llm/prompt/OutputFormat.java | 53 ++++
.../parser/sql/llm/prompt/SqlExample.java | 32 ++
.../sql/llm/prompt/SqlExampleLoader.java | 76 +++++
.../sql/llm/prompt/SqlPromptGenerator.java | 65 ++++
.../supersonic/chat/plugin/PluginManager.java | 19 +-
.../query/SimilarMetricQueryResponder.java | 14 +-
.../llm/analytics/MetricAnalyzeQuery.java | 24 +-
.../chat/utils/ComponentFactory.java | 4 +-
.../chat/utils/SolvedQueryManager.java | 19 +-
common/pom.xml | 14 +
.../common/config/EmbeddingConfig.java | 3 +-
.../common/util/ComponentFactory.java | 23 ++
.../common/util/embedding/EmbeddingQuery.java | 2 +-
.../embedding/InMemoryS2EmbeddingStore.java | 83 +++++
...Utils.java => PythonS2EmbeddingStore.java} | 13 +-
.../common/util/embedding/Retrieval.java | 2 +-
.../util/embedding/S2EmbeddingStore.java | 19 ++
launchers/common/pom.xml | 10 +
.../java/dev/langchain4j/ModelProvider.java | 8 +
.../S2LangChain4jAutoConfiguration.java | 283 ++++++++++++++++++
.../supersonic/EmbeddingInitListener.java | 43 +++
.../supersonic/StandaloneLauncher.java | 3 +
.../main/resources/META-INF/spring.factories | 7 +-
.../src/main/resources/application-local.yaml | 36 ++-
.../integration/MetricInterpretTest.java | 8 -
.../integration/MockConfiguration.java | 6 -
pom.xml | 52 ++++
.../listener/MetaEmbeddingListener.java | 17 +-
33 files changed, 1037 insertions(+), 95 deletions(-)
create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/parser/EmbedLLMProxy.java
create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/parser/plugin/function/FunctionCallPromptGenerator.java
create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/prompt/InputFormat.java
create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/prompt/OutputFormat.java
create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/prompt/SqlExample.java
create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/prompt/SqlExampleLoader.java
create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/prompt/SqlPromptGenerator.java
create mode 100644 common/src/main/java/com/tencent/supersonic/common/util/ComponentFactory.java
create mode 100644 common/src/main/java/com/tencent/supersonic/common/util/embedding/InMemoryS2EmbeddingStore.java
rename common/src/main/java/com/tencent/supersonic/common/util/embedding/{EmbeddingUtils.java => PythonS2EmbeddingStore.java} (98%)
create mode 100644 common/src/main/java/com/tencent/supersonic/common/util/embedding/S2EmbeddingStore.java
create mode 100644 launchers/common/src/main/java/dev/langchain4j/ModelProvider.java
create mode 100644 launchers/common/src/main/java/dev/langchain4j/S2LangChain4jAutoConfiguration.java
create mode 100644 launchers/standalone/src/main/java/com/tencent/supersonic/EmbeddingInitListener.java
diff --git a/chat/core/pom.xml b/chat/core/pom.xml
index 24806171f..f169048c2 100644
--- a/chat/core/pom.xml
+++ b/chat/core/pom.xml
@@ -104,6 +104,19 @@
${mockito-inline.version}
test
+
+
+ dev.langchain4j
+ langchain4j-open-ai
+
+
+ dev.langchain4j
+ langchain4j
+
+
+ dev.langchain4j
+ langchain4j-chroma
+
diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/EmbeddingMatchStrategy.java b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/EmbeddingMatchStrategy.java
index 8163af681..739dadeab 100644
--- a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/EmbeddingMatchStrategy.java
+++ b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/EmbeddingMatchStrategy.java
@@ -4,10 +4,11 @@ import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.common.pojo.Constants;
-import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
+import com.tencent.supersonic.common.util.ComponentFactory;
import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
+import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
import com.tencent.supersonic.knowledge.dictionary.EmbeddingResult;
import com.tencent.supersonic.semantic.model.domain.listener.MetaEmbeddingListener;
import java.util.Comparator;
@@ -32,8 +33,8 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy {
@Autowired
private OptimizationConfig optimizationConfig;
- @Autowired
- private EmbeddingUtils embeddingUtils;
+
+ private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
@Override
public boolean needDelete(EmbeddingResult oneRoundResult, EmbeddingResult existResult) {
@@ -83,7 +84,7 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy {
.queryEmbeddings(null)
.build();
// step2. retrieveQuery by detectSegment
- List retrieveQueryResults = embeddingUtils.retrieveQuery(
+ List retrieveQueryResults = s2EmbeddingStore.retrieveQuery(
MetaEmbeddingListener.COLLECTION_NAME, retrieveQuery, embeddingNumber);
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
@@ -97,7 +98,7 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy {
retrievals.removeIf(retrieval -> retrieval.getDistance() > distance.doubleValue());
if (CollectionUtils.isNotEmpty(detectModelIds)) {
retrievals.removeIf(retrieval -> {
- String modelIdStr = retrieval.getMetadata().get("modelId");
+ String modelIdStr = java.util.Objects.toString(retrieval.getMetadata().get("modelId"), null);
if (StringUtils.isBlank(modelIdStr)) {
return true;
}
diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/EmbedLLMProxy.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/EmbedLLMProxy.java
new file mode 100644
index 000000000..bb371c247
--- /dev/null
+++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/EmbedLLMProxy.java
@@ -0,0 +1,84 @@
+package com.tencent.supersonic.chat.parser;
+
+import com.tencent.supersonic.chat.config.OptimizationConfig;
+import com.tencent.supersonic.chat.parser.plugin.function.FunctionCallPromptGenerator;
+import com.tencent.supersonic.chat.parser.sql.llm.prompt.OutputFormat;
+import com.tencent.supersonic.chat.parser.sql.llm.prompt.SqlExampleLoader;
+import com.tencent.supersonic.chat.parser.sql.llm.prompt.SqlPromptGenerator;
+import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq;
+import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
+import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
+import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
+import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
+import com.tencent.supersonic.common.util.ContextUtils;
+import com.tencent.supersonic.common.util.JsonUtil;
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.model.chat.ChatLanguageModel;
+import dev.langchain4j.model.input.Prompt;
+import dev.langchain4j.model.input.PromptTemplate;
+import dev.langchain4j.model.output.Response;
+import lombok.extern.slf4j.Slf4j;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+@Slf4j
+public class EmbedLLMProxy implements LLMProxy {
+
+ public LLMResp query2sql(LLMReq llmReq, String modelClusterKey) {
+
+ ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class);
+
+ SqlExampleLoader sqlExampleLoader = ContextUtils.getBean(SqlExampleLoader.class);
+
+ OptimizationConfig config = ContextUtils.getBean(OptimizationConfig.class);
+
+ List