[improvement](python) LLM related services support Java service invocation (#484)

This commit is contained in:
lexluo09
2023-12-08 19:24:58 +08:00
committed by GitHub
parent 6c0f88d8b5
commit abbe8c84a1
33 changed files with 1037 additions and 95 deletions

View File

@@ -43,6 +43,16 @@
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
</dependency>
<!--langchain4j-->
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-spring-boot-starter</artifactId>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-embeddings-all-minilm-l6-v2</artifactId>
</dependency>
</dependencies>
</project>

View File

@@ -0,0 +1,8 @@
package dev.langchain4j;
enum ModelProvider {
OPEN_AI,
HUGGING_FACE,
LOCAL_AI,
IN_MEMORY
}

View File

@@ -0,0 +1,283 @@
package dev.langchain4j;
import static dev.langchain4j.ModelProvider.OPEN_AI;
import static dev.langchain4j.exception.IllegalConfigurationException.illegalConfiguration;
import static dev.langchain4j.internal.Utils.isNullOrBlank;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.huggingface.HuggingFaceChatModel;
import dev.langchain4j.model.huggingface.HuggingFaceEmbeddingModel;
import dev.langchain4j.model.huggingface.HuggingFaceLanguageModel;
import dev.langchain4j.model.language.LanguageModel;
import dev.langchain4j.model.localai.LocalAiChatModel;
import dev.langchain4j.model.localai.LocalAiEmbeddingModel;
import dev.langchain4j.model.localai.LocalAiLanguageModel;
import dev.langchain4j.model.moderation.ModerationModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
import dev.langchain4j.model.openai.OpenAiLanguageModel;
import dev.langchain4j.model.openai.OpenAiModerationModel;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Lazy;
import org.springframework.context.annotation.Primary;
@Configuration
@EnableConfigurationProperties(LangChain4jProperties.class)
public class S2LangChain4jAutoConfiguration {
@Autowired
private LangChain4jProperties properties;
@Bean
@Lazy
@ConditionalOnMissingBean
ChatLanguageModel chatLanguageModel(LangChain4jProperties properties) {
if (properties.getChatModel() == null) {
throw illegalConfiguration("\n\nPlease define 'langchain4j.chat-model' properties, for example:\n"
+ "langchain4j.chat-model.provider = openai\n"
+ "langchain4j.chat-model.openai.api-key = sk-...\n");
}
switch (properties.getChatModel().getProvider()) {
case OPEN_AI:
OpenAi openAi = properties.getChatModel().getOpenAi();
if (openAi == null || isNullOrBlank(openAi.getApiKey())) {
throw illegalConfiguration("\n\nPlease define 'langchain4j.chat-model.openai.api-key' property");
}
return OpenAiChatModel.builder()
.baseUrl(openAi.getBaseUrl())
.apiKey(openAi.getApiKey())
.modelName(openAi.getModelName())
.temperature(openAi.getTemperature())
.topP(openAi.getTopP())
.maxTokens(openAi.getMaxTokens())
.presencePenalty(openAi.getPresencePenalty())
.frequencyPenalty(openAi.getFrequencyPenalty())
.timeout(openAi.getTimeout())
.maxRetries(openAi.getMaxRetries())
.logRequests(openAi.getLogRequests())
.logResponses(openAi.getLogResponses())
.build();
case HUGGING_FACE:
HuggingFace huggingFace = properties.getChatModel().getHuggingFace();
if (huggingFace == null || isNullOrBlank(huggingFace.getAccessToken())) {
throw illegalConfiguration(
"\n\nPlease define 'langchain4j.chat-model.huggingface.access-token' property");
}
return HuggingFaceChatModel.builder()
.accessToken(huggingFace.getAccessToken())
.modelId(huggingFace.getModelId())
.timeout(huggingFace.getTimeout())
.temperature(huggingFace.getTemperature())
.maxNewTokens(huggingFace.getMaxNewTokens())
.returnFullText(huggingFace.getReturnFullText())
.waitForModel(huggingFace.getWaitForModel())
.build();
case LOCAL_AI:
LocalAi localAi = properties.getChatModel().getLocalAi();
if (localAi == null || isNullOrBlank(localAi.getBaseUrl())) {
throw illegalConfiguration("\n\nPlease define 'langchain4j.chat-model.localai.base-url' property");
}
if (isNullOrBlank(localAi.getModelName())) {
throw illegalConfiguration(
"\n\nPlease define 'langchain4j.chat-model.localai.model-name' property");
}
return LocalAiChatModel.builder()
.baseUrl(localAi.getBaseUrl())
.modelName(localAi.getModelName())
.temperature(localAi.getTemperature())
.topP(localAi.getTopP())
.maxTokens(localAi.getMaxTokens())
.timeout(localAi.getTimeout())
.maxRetries(localAi.getMaxRetries())
.logRequests(localAi.getLogRequests())
.logResponses(localAi.getLogResponses())
.build();
default:
throw illegalConfiguration("Unsupported chat model provider: %s",
properties.getChatModel().getProvider());
}
}
@Bean
@Lazy
@ConditionalOnMissingBean
LanguageModel languageModel(LangChain4jProperties properties) {
if (properties.getLanguageModel() == null) {
throw illegalConfiguration("\n\nPlease define 'langchain4j.language-model' properties, for example:\n"
+ "langchain4j.language-model.provider = openai\n"
+ "langchain4j.language-model.openai.api-key = sk-...\n");
}
switch (properties.getLanguageModel().getProvider()) {
case OPEN_AI:
OpenAi openAi = properties.getLanguageModel().getOpenAi();
if (openAi == null || isNullOrBlank(openAi.getApiKey())) {
throw illegalConfiguration(
"\n\nPlease define 'langchain4j.language-model.openai.api-key' property");
}
return OpenAiLanguageModel.builder()
.apiKey(openAi.getApiKey())
.modelName(openAi.getModelName())
.temperature(openAi.getTemperature())
.timeout(openAi.getTimeout())
.maxRetries(openAi.getMaxRetries())
.logRequests(openAi.getLogRequests())
.logResponses(openAi.getLogResponses())
.build();
case HUGGING_FACE:
HuggingFace huggingFace = properties.getLanguageModel().getHuggingFace();
if (huggingFace == null || isNullOrBlank(huggingFace.getAccessToken())) {
throw illegalConfiguration(
"\n\nPlease define 'langchain4j.language-model.huggingface.access-token' property");
}
return HuggingFaceLanguageModel.builder()
.accessToken(huggingFace.getAccessToken())
.modelId(huggingFace.getModelId())
.timeout(huggingFace.getTimeout())
.temperature(huggingFace.getTemperature())
.maxNewTokens(huggingFace.getMaxNewTokens())
.returnFullText(huggingFace.getReturnFullText())
.waitForModel(huggingFace.getWaitForModel())
.build();
case LOCAL_AI:
LocalAi localAi = properties.getLanguageModel().getLocalAi();
if (localAi == null || isNullOrBlank(localAi.getBaseUrl())) {
throw illegalConfiguration(
"\n\nPlease define 'langchain4j.language-model.localai.base-url' property");
}
if (isNullOrBlank(localAi.getModelName())) {
throw illegalConfiguration(
"\n\nPlease define 'langchain4j.language-model.localai.model-name' property");
}
return LocalAiLanguageModel.builder()
.baseUrl(localAi.getBaseUrl())
.modelName(localAi.getModelName())
.temperature(localAi.getTemperature())
.topP(localAi.getTopP())
.maxTokens(localAi.getMaxTokens())
.timeout(localAi.getTimeout())
.maxRetries(localAi.getMaxRetries())
.logRequests(localAi.getLogRequests())
.logResponses(localAi.getLogResponses())
.build();
default:
throw illegalConfiguration("Unsupported language model provider: %s",
properties.getLanguageModel().getProvider());
}
}
@Bean
@Lazy
@ConditionalOnMissingBean
@Primary
EmbeddingModel embeddingModel(LangChain4jProperties properties) {
if (properties.getEmbeddingModel() == null || properties.getEmbeddingModel().getProvider() == null) {
throw illegalConfiguration("\n\nPlease define 'langchain4j.embedding-model' properties, for example:\n"
+ "langchain4j.embedding-model.provider = openai\n"
+ "langchain4j.embedding-model.openai.api-key = sk-...\n");
}
switch (properties.getEmbeddingModel().getProvider()) {
case OPEN_AI:
OpenAi openAi = properties.getEmbeddingModel().getOpenAi();
if (openAi == null || isNullOrBlank(openAi.getApiKey())) {
throw illegalConfiguration(
"\n\nPlease define 'langchain4j.embedding-model.openai.api-key' property");
}
return OpenAiEmbeddingModel.builder()
.apiKey(openAi.getApiKey())
.modelName(openAi.getModelName())
.timeout(openAi.getTimeout())
.maxRetries(openAi.getMaxRetries())
.logRequests(openAi.getLogRequests())
.logResponses(openAi.getLogResponses())
.build();
case HUGGING_FACE:
HuggingFace huggingFace = properties.getEmbeddingModel().getHuggingFace();
if (huggingFace == null || isNullOrBlank(huggingFace.getAccessToken())) {
throw illegalConfiguration(
"\n\nPlease define 'langchain4j.embedding-model.huggingface.access-token' property");
}
return HuggingFaceEmbeddingModel.builder()
.accessToken(huggingFace.getAccessToken())
.modelId(huggingFace.getModelId())
.waitForModel(huggingFace.getWaitForModel())
.timeout(huggingFace.getTimeout())
.build();
case LOCAL_AI:
LocalAi localAi = properties.getEmbeddingModel().getLocalAi();
if (localAi == null || isNullOrBlank(localAi.getBaseUrl())) {
throw illegalConfiguration(
"\n\nPlease define 'langchain4j.embedding-model.localai.base-url' property");
}
if (isNullOrBlank(localAi.getModelName())) {
throw illegalConfiguration(
"\n\nPlease define 'langchain4j.embedding-model.localai.model-name' property");
}
return LocalAiEmbeddingModel.builder()
.baseUrl(localAi.getBaseUrl())
.modelName(localAi.getModelName())
.timeout(localAi.getTimeout())
.maxRetries(localAi.getMaxRetries())
.logRequests(localAi.getLogRequests())
.logResponses(localAi.getLogResponses())
.build();
case IN_MEMORY:
return new AllMiniLmL6V2EmbeddingModel();
default:
throw illegalConfiguration("Unsupported embedding model provider: %s",
properties.getEmbeddingModel().getProvider());
}
}
@Bean
@Lazy
@ConditionalOnMissingBean
ModerationModel moderationModel(LangChain4jProperties properties) {
if (properties.getModerationModel() == null) {
throw illegalConfiguration("\n\nPlease define 'langchain4j.moderation-model' properties, for example:\n"
+ "langchain4j.moderation-model.provider = openai\n"
+ "langchain4j.moderation-model.openai.api-key = sk-...\n");
}
if (properties.getModerationModel().getProvider() != OPEN_AI) {
throw illegalConfiguration("Unsupported moderation model provider: %s",
properties.getModerationModel().getProvider());
}
OpenAi openAi = properties.getModerationModel().getOpenAi();
if (openAi == null || isNullOrBlank(openAi.getApiKey())) {
throw illegalConfiguration("\n\nPlease define 'langchain4j.moderation-model.openai.api-key' property");
}
return OpenAiModerationModel.builder()
.apiKey(openAi.getApiKey())
.modelName(openAi.getModelName())
.timeout(openAi.getTimeout())
.maxRetries(openAi.getMaxRetries())
.logRequests(openAi.getLogRequests())
.logResponses(openAi.getLogResponses())
.build();
}
}

View File

@@ -0,0 +1,43 @@
package com.tencent.supersonic;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.chat.parser.sql.llm.prompt.SqlExample;
import com.tencent.supersonic.chat.parser.sql.llm.prompt.SqlExampleLoader;
import com.tencent.supersonic.chat.parser.EmbedLLMProxy;
import com.tencent.supersonic.chat.parser.LLMProxy;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.CommandLineRunner;
import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component;
@Slf4j
@Component
@Order(4)
public class EmbeddingInitListener implements CommandLineRunner {
protected LLMProxy llmProxy = ComponentFactory.getLLMProxy();
@Autowired
private SqlExampleLoader sqlExampleLoader;
@Autowired
private OptimizationConfig optimizationConfig;
@Override
public void run(String... args) {
initSqlExamples();
}
public void initSqlExamples() {
try {
if (llmProxy instanceof EmbedLLMProxy) {
List<SqlExample> sqlExamples = sqlExampleLoader.getSqlExamples();
String collectionName = optimizationConfig.getText2sqlCollectionName();
sqlExampleLoader.addEmbeddingStore(sqlExamples, collectionName);
}
} catch (Exception e) {
log.error("initSqlExamples error", e);
}
}
}

View File

@@ -1,9 +1,11 @@
package com.tencent.supersonic;
import dev.langchain4j.S2LangChain4jAutoConfiguration;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.autoconfigure.data.mongo.MongoDataAutoConfiguration;
import org.springframework.boot.autoconfigure.mongo.MongoAutoConfiguration;
import org.springframework.context.annotation.Import;
import org.springframework.scheduling.annotation.EnableAsync;
import org.springframework.scheduling.annotation.EnableScheduling;
@@ -11,6 +13,7 @@ import org.springframework.scheduling.annotation.EnableScheduling;
exclude = {MongoAutoConfiguration.class, MongoDataAutoConfiguration.class})
@EnableScheduling
@EnableAsync
@Import(S2LangChain4jAutoConfiguration.class)
public class StandaloneLauncher {
public static void main(String[] args) {

View File

@@ -31,7 +31,7 @@ com.tencent.supersonic.chat.processor.ParseResultProcessor=\
com.tencent.supersonic.chat.processor.RespBuildProcessor
com.tencent.supersonic.chat.parser.LLMProxy=\
com.tencent.supersonic.chat.parser.PythonLLMProxy
com.tencent.supersonic.chat.parser.EmbedLLMProxy
com.tencent.supersonic.chat.api.component.SemanticInterpreter=\
com.tencent.supersonic.knowledge.semantic.LocalSemanticInterpreter
@@ -46,4 +46,7 @@ com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor=\
com.tencent.supersonic.auth.authentication.adaptor.DefaultUserAdaptor
com.tencent.supersonic.chat.query.QueryResponder=\
com.tencent.supersonic.chat.query.SimilarMetricQueryResponder
com.tencent.supersonic.chat.query.SimilarMetricQueryResponder
com.tencent.supersonic.common.util.embedding.S2EmbeddingStore=\
com.tencent.supersonic.common.util.embedding.InMemoryS2EmbeddingStore

View File

@@ -36,8 +36,40 @@ mybatis:
llm:
parser:
url: http://127.0.0.1:9092
url:
embedding:
url: http://127.0.0.1:9092
functionCall:
url: http://127.0.0.1:9092
url: http://127.0.0.1:9092
#langchain4j config
langchain4j:
#1.chat-model
chat-model:
provider: open_ai
openai:
api-key: api_key
model-name: gpt-3.5-turbo
temperature: 0.0
timeout: PT60S
#2.embedding-model
embedding-model:
provider: in_memory
# embedding-model:
# hugging-face:
# access-token: hg_access_token
# model-id: sentence-transformers/all-MiniLM-L6-v2
# timeout: 1h
# embedding-model:
# provider: open_ai
# openai:
# api-key: api_key
# modelName: all-minilm-l6-v2.onnx
#langchain4j log
logging:
level:
dev.langchain4j: DEBUG
dev.ai4j.openai4j: DEBUG

View File

@@ -1,16 +1,13 @@
package com.tencent.supersonic.integration;
import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.StandaloneLauncher;
import com.tencent.supersonic.chat.query.llm.analytics.LLMAnswerResp;
import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.mock.mockito.MockBean;
import org.springframework.http.ResponseEntity;
import org.springframework.test.context.ActiveProfiles;
import org.springframework.test.context.junit4.SpringRunner;
@@ -21,13 +18,9 @@ public class MetricInterpretTest {
@MockBean
private AgentService agentService;
@MockBean
private EmbeddingConfig embeddingConfig;
@MockBean
private EmbeddingUtils embeddingUtils;
@Test
public void testMetricInterpret() throws Exception {
MockConfiguration.mockAgent(agentService);
@@ -36,7 +29,6 @@ public class MetricInterpretTest {
LLMAnswerResp lLmAnswerResp = new LLMAnswerResp();
lLmAnswerResp.setAssistantMessage("alice最近在超音数的访问情况有增多");
MockConfiguration.embeddingUtils(embeddingUtils, ResponseEntity.ok(JSONObject.toJSONString(lLmAnswerResp)));
}
}

View File

@@ -1,20 +1,17 @@
package com.tencent.supersonic.integration;
import static org.mockito.ArgumentMatchers.anyObject;
import static org.mockito.Mockito.when;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
import com.tencent.supersonic.util.DataUtils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.ResponseEntity;
@Configuration
@Slf4j
@@ -38,7 +35,4 @@ public class MockConfiguration {
when(agentService.getAgent(1)).thenReturn(DataUtils.getAgent());
}
public static void embeddingUtils(EmbeddingUtils embeddingUtils, ResponseEntity<String> responseEntity) {
when(embeddingUtils.doRequest(anyObject(), anyObject(), anyObject())).thenReturn(responseEntity);
}
}