diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlReplaceHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlReplaceHelper.java
index 0f1a4aa3e..4d1964652 100644
--- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlReplaceHelper.java
+++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlReplaceHelper.java
@@ -631,5 +631,30 @@ public class SqlReplaceHelper {
             }
         }
     }
+
+    /**
+     * Inlines select-list aliases into the ORDER BY clause: each ORDER BY
+     * element that references an alias is rewritten to the aliased expression,
+     * and the aliases are removed from the select list.
+     *
+     * NOTE(review): aliases are stripped from ALL select items, not only the
+     * ones referenced in ORDER BY, which changes the result column labels —
+     * confirm this is intended.
+     *
+     * @param querySql SQL to rewrite
+     * @return rewritten SQL, or the input unchanged when there is no ORDER BY
+     */
+    public static String dealAliasToOrderBy(String querySql) {
+        Select selectStatement = SqlSelectHelper.getSelect(querySql);
+        PlainSelect plainSelect = selectStatement.getPlainSelect();
+        List<OrderByElement> orderByElementList = plainSelect.getOrderByElements();
+        if (CollectionUtils.isEmpty(orderByElementList)) {
+            return querySql;
+        }
+        // alias name -> underlying select expression
+        Map<String, Expression> aliasMap = new HashMap<>();
+        for (SelectItem<?> selectItem : plainSelect.getSelectItems()) {
+            if (!Objects.isNull(selectItem.getAlias())) {
+                aliasMap.put(selectItem.getAlias().getName(), selectItem.getExpression());
+                selectItem.setAlias(null);
+            }
+        }
+        for (OrderByElement orderByElement : orderByElementList) {
+            Expression aliased = aliasMap.get(orderByElement.getExpression().toString());
+            if (Objects.nonNull(aliased)) {
+                orderByElement.setExpression(aliased);
+            }
+        }
+        // setOrderByElements(...) on the same list was a no-op and is dropped
+        return plainSelect.toString();
+    }
+
 }
diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlSelectHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlSelectHelper.java
index cf01ced2c..14c8b2f4e 100644
--- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlSelectHelper.java
+++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlSelectHelper.java
@@ -46,6 +46,7 @@ import net.sf.jsqlparser.statement.select.SelectItem;
net.sf.jsqlparser.statement.select.SelectVisitorAdapter; import net.sf.jsqlparser.statement.select.SetOperationList; import net.sf.jsqlparser.statement.select.WithItem; +import net.sf.jsqlparser.statement.select.Limit; import org.apache.commons.lang3.StringUtils; import org.springframework.util.CollectionUtils; @@ -612,6 +613,17 @@ public class SqlSelectHelper { } } + public static Boolean hasLimit(String querySql) { + Select selectStatement = SqlSelectHelper.getSelect(querySql); + PlainSelect plainSelect = selectStatement.getPlainSelect(); + Limit limit = plainSelect.getLimit(); + if (Objects.nonNull(limit)) { + return true; + } else { + return false; + } + } + public static Map> getFieldsWithSubQuery(String sql) { List plainSelects = getPlainSelects(getPlainSelect(sql)); Map> results = new HashMap<>(); diff --git a/headless/core/pom.xml b/headless/core/pom.xml index 2a0006bad..50ef311d1 100644 --- a/headless/core/pom.xml +++ b/headless/core/pom.xml @@ -172,6 +172,12 @@ duckdb_jdbc ${duckdb_jdbc.version} + + com.tencent.supersonic + launchers-common + 0.9.2-SNAPSHOT + compile + diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/corrector/SelectCorrector.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/corrector/SelectCorrector.java index 796e7bd2a..63d5e25af 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/corrector/SelectCorrector.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/corrector/SelectCorrector.java @@ -1,6 +1,7 @@ package com.tencent.supersonic.headless.core.chat.corrector; +import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.core.pojo.QueryContext; @@ -27,5 +28,7 @@ public class SelectCorrector extends BaseSemanticCorrector { return; } 
addFieldsToSelect(semanticParseInfo, correctS2SQL); + String querySql = SqlReplaceHelper.dealAliasToOrderBy(correctS2SQL); + semanticParseInfo.getSqlInfo().setCorrectS2SQL(querySql); } } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/parser/DefaultQueryParser.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/parser/DefaultQueryParser.java index 2b653201c..4edd255ae 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/parser/DefaultQueryParser.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/parser/DefaultQueryParser.java @@ -2,10 +2,10 @@ package com.tencent.supersonic.headless.core.parser; import com.google.common.base.Strings; import com.tencent.supersonic.common.util.StringUtil; +import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper; import com.tencent.supersonic.headless.api.pojo.MetricTable; import com.tencent.supersonic.headless.api.pojo.QueryParam; import com.tencent.supersonic.headless.api.pojo.enums.AggOption; -import com.tencent.supersonic.headless.api.pojo.request.SqlExecuteReq; import com.tencent.supersonic.headless.core.parser.converter.HeadlessConverter; import com.tencent.supersonic.headless.core.pojo.DataSetQueryParam; import com.tencent.supersonic.headless.core.pojo.MetricQueryParam; @@ -54,12 +54,10 @@ public class DefaultQueryParser implements QueryParser { || Strings.isNullOrEmpty(queryStatement.getSourceId())) { throw new RuntimeException("parse Exception: " + queryStatement.getErrMsg()); } - String querySql = - Objects.nonNull(queryStatement.getLimit()) && queryStatement.getLimit() > 0 - ? 
String.format(SqlExecuteReq.LIMIT_WRAPPER, - queryStatement.getSql(), queryStatement.getLimit()) - : queryStatement.getSql(); - queryStatement.setSql(querySql); + if (!SqlSelectHelper.hasLimit(queryStatement.getSql())) { + String querySql = queryStatement.getSql() + " limit " + queryStatement.getLimit().toString(); + queryStatement.setSql(querySql); + } } public QueryStatement parser(DataSetQueryParam dataSetQueryParam, QueryStatement queryStatement) { diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/S2ChatModelProvider.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/S2ChatModelProvider.java index 23b82204c..1d1ed44f9 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/S2ChatModelProvider.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/S2ChatModelProvider.java @@ -5,7 +5,7 @@ import com.tencent.supersonic.headless.api.pojo.LLMConfig; import com.tencent.supersonic.common.pojo.enums.S2ModelProvider; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.localai.LocalAiChatModel; -import dev.langchain4j.model.openai.OpenAiChatModel; +import dev.langchain4j.model.openai.FullOpenAiChatModel; import org.apache.commons.lang3.StringUtils; import java.time.Duration; @@ -18,7 +18,7 @@ public class S2ChatModelProvider { return chatLanguageModel; } if (S2ModelProvider.OPEN_AI.name().equalsIgnoreCase(llmConfig.getProvider())) { - return OpenAiChatModel + return FullOpenAiChatModel .builder() .baseUrl(llmConfig.getBaseUrl()) .modelName(llmConfig.getModelName()) diff --git a/headless/core/src/test/java/com/tencent/supersonic/chat/core/corrector/TimeCorrectorTest.java b/headless/core/src/test/java/com/tencent/supersonic/chat/core/corrector/TimeCorrectorTest.java index d85f89223..69681c0a5 100644 --- a/headless/core/src/test/java/com/tencent/supersonic/chat/core/corrector/TimeCorrectorTest.java +++ 
+++ b/headless/core/src/test/java/com/tencent/supersonic/chat/core/corrector/TimeCorrectorTest.java
@@ -98,4 +98,4 @@ class TimeCorrectorTest {
         corrector.doCorrect(queryContext, semanticParseInfo);
         Assert.assertEquals("SELECT COUNT(1) FROM 数据库", sqlInfo.getCorrectS2SQL());
     }
-}
\ No newline at end of file
+}
diff --git a/launchers/common/src/main/java/dev/langchain4j/S2LangChain4jAutoConfiguration.java b/launchers/common/src/main/java/dev/langchain4j/S2LangChain4jAutoConfiguration.java
index 331c01876..403e7b108 100644
--- a/launchers/common/src/main/java/dev/langchain4j/S2LangChain4jAutoConfiguration.java
+++ b/launchers/common/src/main/java/dev/langchain4j/S2LangChain4jAutoConfiguration.java
@@ -16,7 +16,7 @@ import dev.langchain4j.model.localai.LocalAiChatModel;
 import dev.langchain4j.model.localai.LocalAiEmbeddingModel;
 import dev.langchain4j.model.localai.LocalAiLanguageModel;
 import dev.langchain4j.model.moderation.ModerationModel;
-import dev.langchain4j.model.openai.OpenAiChatModel;
+import dev.langchain4j.model.openai.FullOpenAiChatModel;
 import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
 import dev.langchain4j.model.openai.OpenAiLanguageModel;
 import dev.langchain4j.model.openai.OpenAiModerationModel;
@@ -53,7 +53,7 @@ public class S2LangChain4jAutoConfiguration {
         if (openAi == null || isNullOrBlank(openAi.getApiKey())) {
             throw illegalConfiguration("\n\nPlease define 'langchain4j.chat-model.openai.api-key' property");
         }
-        return OpenAiChatModel.builder()
+        return FullOpenAiChatModel.builder()
                 .baseUrl(openAi.getBaseUrl())
                 .apiKey(openAi.getApiKey())
                 .modelName(openAi.getModelName())
diff --git a/launchers/common/src/main/java/dev/langchain4j/model/ChatModel.java b/launchers/common/src/main/java/dev/langchain4j/model/ChatModel.java
new file mode 100644
index 000000000..8562df8c0
--- /dev/null
+++ b/launchers/common/src/main/java/dev/langchain4j/model/ChatModel.java
@@ -0,0 +1,30 @@
+package dev.langchain4j.model;
+
+/**
+ * Chat-model families that need special request handling (see
+ * FullOpenAiChatModel) when accessed through an OpenAI-compatible endpoint.
+ */
+public enum ChatModel {
+    ZHIPU("glm"),
+    ALI("qwen");
+
+    private final String modelName;
+
+    ChatModel(String modelName) {
+        this.modelName = modelName;
+    }
+
+    @Override
+    public String toString() {
+        return this.modelName;
+    }
+
+    /** Resolves the enum constant from its wire name; throws on unknown values. */
+    public static ChatModel from(String stringValue) {
+        for (ChatModel model : values()) {
+            if (model.modelName.equals(stringValue)) {
+                return model;
+            }
+        }
+        // was "Unknown role" — copy/paste slip from the Role enum
+        throw new IllegalArgumentException("Unknown model: '" + stringValue + "'");
+    }
+}
diff --git a/launchers/common/src/main/java/dev/langchain4j/model/openai/FullOpenAiChatModel.java b/launchers/common/src/main/java/dev/langchain4j/model/openai/FullOpenAiChatModel.java
new file mode 100644
index 000000000..933312bc5
--- /dev/null
+++ b/launchers/common/src/main/java/dev/langchain4j/model/openai/FullOpenAiChatModel.java
@@ -0,0 +1,229 @@
+package dev.langchain4j.model.openai;
+
+import dev.ai4j.openai4j.OpenAiClient;
+import dev.ai4j.openai4j.chat.ChatCompletionRequest;
+import dev.ai4j.openai4j.chat.ChatCompletionRequest.Builder;
+import dev.ai4j.openai4j.chat.ChatCompletionResponse;
+import dev.langchain4j.agent.tool.ToolSpecification;
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.data.message.ChatMessage;
+import dev.langchain4j.internal.RetryUtils;
+import dev.langchain4j.internal.Utils;
+import dev.langchain4j.model.ChatModel;
+import dev.langchain4j.model.Tokenizer;
+import dev.langchain4j.model.chat.ChatLanguageModel;
+import dev.langchain4j.model.chat.TokenCountEstimator;
+import dev.langchain4j.model.output.Response;
+
+import java.net.Proxy;
+import java.time.Duration;
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * Drop-in replacement for OpenAiChatModel that also talks to
+ * OpenAI-compatible endpoints (zhipu "glm", ali "qwen"): for those model
+ * families only model + messages are sent, since the compatible endpoints
+ * reject the OpenAI sampling parameters.
+ */
+public class FullOpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
+
+    private final OpenAiClient client;
+    private final String modelName;
+    private final Double temperature;
+    private final Double topP;
+    private final List<String> stop;
+    private final Integer maxTokens;
+    private final Double presencePenalty;
+    private final Double frequencyPenalty;
+    private final Integer maxRetries;
+    private final Tokenizer tokenizer;
+
+    public FullOpenAiChatModel(String baseUrl, String apiKey, String modelName, Double temperature,
+                               Double topP, List<String> stop, Integer maxTokens, Double presencePenalty,
+                               Double frequencyPenalty, Duration timeout, Integer maxRetries, Proxy proxy,
+                               Boolean logRequests, Boolean logResponses, Tokenizer tokenizer) {
+        baseUrl = Utils.getOrDefault(baseUrl, "https://api.openai.com/v1");
+        if ("demo".equals(apiKey)) {
+            // langchain4j's public demo proxy
+            baseUrl = "http://langchain4j.dev/demo/openai/v1";
+        }
+        timeout = Utils.getOrDefault(timeout, Duration.ofSeconds(60L));
+        this.client = OpenAiClient.builder().openAiApiKey(apiKey)
+                .baseUrl(baseUrl).callTimeout(timeout).connectTimeout(timeout)
+                .readTimeout(timeout).writeTimeout(timeout).proxy(proxy)
+                .logRequests(logRequests).logResponses(logResponses).build();
+        this.modelName = Utils.getOrDefault(modelName, "gpt-3.5-turbo");
+        this.temperature = Utils.getOrDefault(temperature, 0.7D);
+        this.topP = topP;
+        this.stop = stop;
+        this.maxTokens = maxTokens;
+        this.presencePenalty = presencePenalty;
+        this.frequencyPenalty = frequencyPenalty;
+        this.maxRetries = Utils.getOrDefault(maxRetries, 3);
+        this.tokenizer = Utils.getOrDefault(tokenizer, new OpenAiTokenizer(this.modelName));
+    }
+
+    public Response<AiMessage> generate(List<ChatMessage> messages) {
+        return this.generate(messages, (List<ToolSpecification>) null, (ToolSpecification) null);
+    }
+
+    public Response<AiMessage> generate(List<ChatMessage> messages,
+                                        List<ToolSpecification> toolSpecifications) {
+        return this.generate(messages, toolSpecifications, (ToolSpecification) null);
+    }
+
+    public Response<AiMessage> generate(List<ChatMessage> messages,
+                                        ToolSpecification toolSpecification) {
+        return this.generate(messages, Collections.singletonList(toolSpecification), toolSpecification);
+    }
+
+    private Response<AiMessage> generate(List<ChatMessage> messages,
+                                         List<ToolSpecification> toolSpecifications,
+                                         ToolSpecification toolThatMustBeExecuted) {
+        Builder requestBuilder;
+        if (modelName.contains(ChatModel.ZHIPU.toString()) || modelName.contains(ChatModel.ALI.toString())) {
+            // glm/qwen endpoints: only model + messages, no sampling parameters
+            requestBuilder = ChatCompletionRequest.builder()
+                    .model(this.modelName)
+                    .messages(ImproveInternalOpenAiHelper.toOpenAiMessages(messages, this.modelName));
+        } else {
+            requestBuilder = ChatCompletionRequest.builder()
+                    .model(this.modelName)
+                    .messages(ImproveInternalOpenAiHelper.toOpenAiMessages(messages, this.modelName))
+                    .temperature(this.temperature).topP(this.topP).stop(this.stop).maxTokens(this.maxTokens)
+                    .presencePenalty(this.presencePenalty).frequencyPenalty(this.frequencyPenalty);
+        }
+        if (toolSpecifications != null && !toolSpecifications.isEmpty()) {
+            requestBuilder.functions(InternalOpenAiHelper.toFunctions(toolSpecifications));
+        }
+        if (toolThatMustBeExecuted != null) {
+            requestBuilder.functionCall(toolThatMustBeExecuted.name());
+        }
+        ChatCompletionRequest request = requestBuilder.build();
+        ChatCompletionResponse response = RetryUtils.withRetry(
+                () -> this.client.chatCompletion(request).execute(), this.maxRetries);
+        return Response.from(InternalOpenAiHelper.aiMessageFrom(response),
+                InternalOpenAiHelper.tokenUsageFrom(response.usage()),
+                InternalOpenAiHelper.finishReasonFrom(response.choices().get(0).finishReason()));
+    }
+
+    public int estimateTokenCount(List<ChatMessage> messages) {
+        return this.tokenizer.estimateTokenCountInMessages(messages);
+    }
+
+    public static FullOpenAiChatModel withApiKey(String apiKey) {
+        return builder().apiKey(apiKey).build();
+    }
+
+    public static FullOpenAiChatModelBuilder builder() {
+        return new FullOpenAiChatModelBuilder();
+    }
+
+    /** Fluent builder mirroring OpenAiChatModel's builder surface. */
+    public static class FullOpenAiChatModelBuilder {
+        private String baseUrl;
+        private String apiKey;
+        private String modelName;
+        private Double temperature;
+        private Double topP;
+        private List<String> stop;
+        private Integer maxTokens;
+        private Double presencePenalty;
+        private Double frequencyPenalty;
+        private Duration timeout;
+        private Integer maxRetries;
+        private Proxy proxy;
+        private Boolean logRequests;
+        private Boolean logResponses;
+        private Tokenizer tokenizer;
+
+        FullOpenAiChatModelBuilder() {
+        }
+
+        public FullOpenAiChatModelBuilder baseUrl(String baseUrl) {
+            this.baseUrl = baseUrl;
+            return this;
+        }
+
+        public FullOpenAiChatModelBuilder apiKey(String apiKey) {
+            this.apiKey = apiKey;
+            return this;
+        }
+
+        public FullOpenAiChatModelBuilder modelName(String modelName) {
+            this.modelName = modelName;
+            return this;
+        }
+
+        public FullOpenAiChatModelBuilder temperature(Double temperature) {
+            this.temperature = temperature;
+            return this;
+        }
+
+        public FullOpenAiChatModelBuilder topP(Double topP) {
+            this.topP = topP;
+            return this;
+        }
+
+        public FullOpenAiChatModelBuilder stop(List<String> stop) {
+            this.stop = stop;
+            return this;
+        }
+
+        public FullOpenAiChatModelBuilder maxTokens(Integer maxTokens) {
+            this.maxTokens = maxTokens;
+            return this;
+        }
+
+        public FullOpenAiChatModelBuilder presencePenalty(Double presencePenalty) {
+            this.presencePenalty = presencePenalty;
+            return this;
+        }
+
+        public FullOpenAiChatModelBuilder frequencyPenalty(Double frequencyPenalty) {
+            this.frequencyPenalty = frequencyPenalty;
+            return this;
+        }
+
+        public FullOpenAiChatModelBuilder timeout(Duration timeout) {
+            this.timeout = timeout;
+            return this;
+        }
+
+        public FullOpenAiChatModelBuilder maxRetries(Integer maxRetries) {
+            this.maxRetries = maxRetries;
+            return this;
+        }
+
+        public FullOpenAiChatModelBuilder proxy(Proxy proxy) {
+            this.proxy = proxy;
+            return this;
+        }
+
+        public FullOpenAiChatModelBuilder logRequests(Boolean logRequests) {
+            this.logRequests = logRequests;
+            return this;
+        }
+
+        public FullOpenAiChatModelBuilder logResponses(Boolean logResponses) {
+            this.logResponses = logResponses;
+            return this;
+        }
+
+        public FullOpenAiChatModelBuilder tokenizer(Tokenizer tokenizer) {
+            this.tokenizer = tokenizer;
+            return this;
+        }
+
+        public FullOpenAiChatModel build() {
+            return new FullOpenAiChatModel(this.baseUrl, this.apiKey, this.modelName, this.temperature,
+                    this.topP, this.stop, this.maxTokens, this.presencePenalty, this.frequencyPenalty,
+                    this.timeout, this.maxRetries, this.proxy, this.logRequests, this.logResponses,
+                    this.tokenizer);
+        }
+
+        @Override
+        public String toString() {
+            // apiKey deliberately masked: the original printed the secret verbatim
+            return "FullOpenAiChatModel.FullOpenAiChatModelBuilder(baseUrl=" + this.baseUrl
+                    + ", apiKey=*****, modelName=" + this.modelName + ", temperature="
+                    + this.temperature + ", topP=" + this.topP + ", stop=" + this.stop + ", maxTokens="
+                    + this.maxTokens + ", presencePenalty=" + this.presencePenalty + ", frequencyPenalty="
+                    + this.frequencyPenalty + ", timeout=" + this.timeout + ", maxRetries=" + this.maxRetries
+                    + ", proxy=" + this.proxy + ", logRequests=" + this.logRequests + ", logResponses="
+                    + this.logResponses + ", tokenizer=" + this.tokenizer + ")";
+        }
+    }
+}
diff --git a/launchers/common/src/main/java/dev/langchain4j/model/openai/ImproveInternalOpenAiHelper.java b/launchers/common/src/main/java/dev/langchain4j/model/openai/ImproveInternalOpenAiHelper.java
new file mode 100644
index 000000000..0bdcc27ec
--- /dev/null
+++ b/launchers/common/src/main/java/dev/langchain4j/model/openai/ImproveInternalOpenAiHelper.java
@@ -0,0 +1,111 @@
+package dev.langchain4j.model.openai;
+
+import dev.ai4j.openai4j.chat.ChatCompletionChoice;
+import dev.ai4j.openai4j.chat.ChatCompletionResponse;
+import dev.ai4j.openai4j.chat.Function;
+import dev.ai4j.openai4j.chat.FunctionCall;
+import dev.ai4j.openai4j.chat.Message;
+import dev.ai4j.openai4j.chat.Parameters;
+import dev.ai4j.openai4j.chat.Role;
+import dev.ai4j.openai4j.shared.Usage;
+import dev.langchain4j.agent.tool.ToolExecutionRequest;
+import dev.langchain4j.agent.tool.ToolParameters;
+import dev.langchain4j.agent.tool.ToolSpecification;
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.data.message.ChatMessage;
+import dev.langchain4j.data.message.SystemMessage;
+import dev.langchain4j.data.message.ToolExecutionResultMessage;
+import dev.langchain4j.data.message.UserMessage;
+import dev.langchain4j.model.ChatModel;
+import dev.langchain4j.model.output.TokenUsage;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * Variant of langchain4j's InternalOpenAiHelper that is aware of the model
+ * name, so role mapping can be adjusted for OpenAI-compatible endpoints.
+ */
+public class ImproveInternalOpenAiHelper {
+    static final String OPENAI_URL = "https://api.openai.com/v1";
+    static final String OPENAI_DEMO_API_KEY = "demo";
+    static final String OPENAI_DEMO_URL = "http://langchain4j.dev/demo/openai/v1";
+
+    public ImproveInternalOpenAiHelper() {
+    }
+
+    /** Converts langchain4j messages into openai4j wire messages. */
+    public static List<Message> toOpenAiMessages(List<ChatMessage> messages, String modelName) {
+        return messages.stream()
+                .map(message -> toOpenAiMessage(message, modelName))
+                .collect(Collectors.toList());
+    }
+
+    public static Message toOpenAiMessage(ChatMessage message, String modelName) {
+        return Message.builder().role(roleFrom(message, modelName))
+                .name(nameFrom(message)).content(message.text())
+                .functionCall(functionCallFrom(message)).build();
+    }
+
+    private static String nameFrom(ChatMessage message) {
+        if (message instanceof UserMessage) {
+            return ((UserMessage) message).name();
+        }
+        return message instanceof ToolExecutionResultMessage
+                ? ((ToolExecutionResultMessage) message).toolName() : null;
+    }
+
+    private static FunctionCall functionCallFrom(ChatMessage message) {
+        if (message instanceof AiMessage) {
+            AiMessage aiMessage = (AiMessage) message;
+            if (aiMessage.toolExecutionRequest() != null) {
+                return FunctionCall.builder().name(aiMessage.toolExecutionRequest().name())
+                        .arguments(aiMessage.toolExecutionRequest().arguments()).build();
+            }
+        }
+        return null;
+    }
+
+    public static Role roleFrom(ChatMessage message, String modelName) {
+        // NOTE(review): every message is forced to USER for glm/qwen models,
+        // discarding system/assistant roles — confirm the target endpoints
+        // really require this.
+        if (modelName.contains(ChatModel.ZHIPU.toString()) || modelName.contains(ChatModel.ALI.toString())) {
+            return Role.USER;
+        }
+        if (message instanceof AiMessage) {
+            return Role.ASSISTANT;
+        } else if (message instanceof ToolExecutionResultMessage) {
+            return Role.FUNCTION;
+        } else {
+            return message instanceof SystemMessage ? Role.SYSTEM : Role.USER;
+        }
+    }
+
+    public static List<Function> toFunctions(Collection<ToolSpecification> toolSpecifications) {
+        // redundant raw (List) cast removed
+        return toolSpecifications.stream().map(ImproveInternalOpenAiHelper::toFunction)
+                .collect(Collectors.toList());
+    }
+
+    private static Function toFunction(ToolSpecification toolSpecification) {
+        return Function.builder().name(toolSpecification.name())
+                .description(toolSpecification.description())
+                .parameters(toOpenAiParameters(toolSpecification.parameters())).build();
+    }
+
+    private static Parameters toOpenAiParameters(ToolParameters toolParameters) {
+        return toolParameters == null ? Parameters.builder().build() : Parameters.builder()
+                .properties(toolParameters.properties()).required(toolParameters.required()).build();
+    }
+
+    public static AiMessage aiMessageFrom(ChatCompletionResponse response) {
+        if (response.content() != null) {
+            return AiMessage.aiMessage(response.content());
+        }
+        // no plain content -> the model asked for a function call
+        FunctionCall functionCall = response.choices().get(0).message().functionCall();
+        ToolExecutionRequest toolExecutionRequest = ToolExecutionRequest.builder()
+                .name(functionCall.name()).arguments(functionCall.arguments()).build();
+        return AiMessage.aiMessage(toolExecutionRequest);
+    }
+
+    public static TokenUsage tokenUsageFrom(Usage openAiUsage) {
+        return openAiUsage == null ? null : new TokenUsage(openAiUsage.promptTokens(),
+                openAiUsage.completionTokens(), openAiUsage.totalTokens());
+    }
+}
diff --git a/launchers/standalone/src/main/resources/supersonic-env.sh b/launchers/standalone/src/main/resources/supersonic-env.sh
index c3f0a88df..f6949a28a 100644
--- a/launchers/standalone/src/main/resources/supersonic-env.sh
+++ b/launchers/standalone/src/main/resources/supersonic-env.sh
@@ -7,4 +7,4 @@
 OPENAI_API_BASE=http://langchain4j.dev/demo/openai/v1
 OPENAI_API_KEY=demo
 OPENAI_MODEL_NAME=gpt-3.5-turbo
 OPENAI_TEMPERATURE=0.0
-OPENAI_TIMEOUT=PT60S
\ No newline at end of file
+OPENAI_TIMEOUT=PT60S