(improvement)(Headless) supports glm-4、glm-3-turbo、qwen (#1058)

This commit is contained in:
mainmain
2024-05-30 21:51:35 +08:00
committed by GitHub
parent b4bc92e586
commit 4e6c076481
12 changed files with 427 additions and 13 deletions

View File

@@ -172,6 +172,12 @@
<artifactId>duckdb_jdbc</artifactId>
<version>${duckdb_jdbc.version}</version>
</dependency>
<dependency>
<groupId>com.tencent.supersonic</groupId>
<artifactId>launchers-common</artifactId>
<version>0.9.2-SNAPSHOT</version>
<scope>compile</scope>
</dependency>
</dependencies>

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.headless.core.chat.corrector;
import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
@@ -27,5 +28,7 @@ public class SelectCorrector extends BaseSemanticCorrector {
return;
}
addFieldsToSelect(semanticParseInfo, correctS2SQL);
String querySql = SqlReplaceHelper.dealAliasToOrderBy(correctS2SQL);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(querySql);
}
}

View File

@@ -2,10 +2,10 @@ package com.tencent.supersonic.headless.core.parser;
import com.google.common.base.Strings;
import com.tencent.supersonic.common.util.StringUtil;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.headless.api.pojo.MetricTable;
import com.tencent.supersonic.headless.api.pojo.QueryParam;
import com.tencent.supersonic.headless.api.pojo.enums.AggOption;
import com.tencent.supersonic.headless.api.pojo.request.SqlExecuteReq;
import com.tencent.supersonic.headless.core.parser.converter.HeadlessConverter;
import com.tencent.supersonic.headless.core.pojo.DataSetQueryParam;
import com.tencent.supersonic.headless.core.pojo.MetricQueryParam;
@@ -54,12 +54,10 @@ public class DefaultQueryParser implements QueryParser {
|| Strings.isNullOrEmpty(queryStatement.getSourceId())) {
throw new RuntimeException("parse Exception: " + queryStatement.getErrMsg());
}
String querySql =
Objects.nonNull(queryStatement.getLimit()) && queryStatement.getLimit() > 0
? String.format(SqlExecuteReq.LIMIT_WRAPPER,
queryStatement.getSql(), queryStatement.getLimit())
: queryStatement.getSql();
queryStatement.setSql(querySql);
if (!SqlSelectHelper.hasLimit(queryStatement.getSql())) {
String querySql = queryStatement.getSql() + " limit " + queryStatement.getLimit().toString();
queryStatement.setSql(querySql);
}
}
public QueryStatement parser(DataSetQueryParam dataSetQueryParam, QueryStatement queryStatement) {

View File

@@ -5,7 +5,7 @@ import com.tencent.supersonic.headless.api.pojo.LLMConfig;
import com.tencent.supersonic.common.pojo.enums.S2ModelProvider;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.localai.LocalAiChatModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.model.openai.FullOpenAiChatModel;
import org.apache.commons.lang3.StringUtils;
import java.time.Duration;
@@ -18,7 +18,7 @@ public class S2ChatModelProvider {
return chatLanguageModel;
}
if (S2ModelProvider.OPEN_AI.name().equalsIgnoreCase(llmConfig.getProvider())) {
return OpenAiChatModel
return FullOpenAiChatModel
.builder()
.baseUrl(llmConfig.getBaseUrl())
.modelName(llmConfig.getModelName())

View File

@@ -98,4 +98,4 @@ class TimeCorrectorTest {
corrector.doCorrect(queryContext, semanticParseInfo);
Assert.assertEquals("SELECT COUNT(1) FROM 数据库", sqlInfo.getCorrectS2SQL());
}
}
}