diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlValidHelper.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlValidHelper.java index 9ac38858a..a32979a54 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlValidHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlValidHelper.java @@ -2,7 +2,6 @@ package com.tencent.supersonic.common.jsqlparser; import java.util.List; import lombok.extern.slf4j.Slf4j; -import net.sf.jsqlparser.JSQLParserException; import net.sf.jsqlparser.parser.CCJSqlParserUtil; import org.apache.commons.collections.CollectionUtils; @@ -69,7 +68,7 @@ public class SqlValidHelper { try { CCJSqlParserUtil.parse(sql); return true; - } catch (JSQLParserException e) { + } catch (Exception e) { log.error("isValidSQL parse:{}", e); return false; } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/S2SqlDateHelper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/S2SqlDateHelper.java similarity index 98% rename from headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/S2SqlDateHelper.java rename to headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/S2SqlDateHelper.java index 0e990f31c..fa3bc2a9f 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/S2SqlDateHelper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/S2SqlDateHelper.java @@ -1,4 +1,4 @@ -package com.tencent.supersonic.headless.chat.utils; +package com.tencent.supersonic.headless.chat.corrector; import com.tencent.supersonic.common.pojo.enums.QueryType; import com.tencent.supersonic.common.pojo.enums.TimeMode; diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrector.java index 8adfa408d..230aeb0a3 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrector.java @@ -11,7 +11,6 @@ import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper; import com.tencent.supersonic.common.jsqlparser.DateVisitor.DateBoundInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.chat.QueryContext; -import com.tencent.supersonic.headless.chat.utils.S2SqlDateHelper; import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.JSQLParserException; import net.sf.jsqlparser.expression.Expression; diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/InputFormat.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/InputFormat.java deleted file mode 100644 index 17155e481..000000000 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/InputFormat.java +++ /dev/null @@ -1,43 +0,0 @@ -package com.tencent.supersonic.headless.chat.parser.llm; - -import lombok.extern.slf4j.Slf4j; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -@Slf4j -public class InputFormat { - - public static final String SEPERATOR = "\n\n"; - - public static String format(String template, List templateKey, - List> toFormatList) { - List result = new ArrayList<>(); - - for (Map formatItem : toFormatList) { - Map retrievalMeta = subDict(formatItem, templateKey); - result.add(format(template, retrievalMeta)); - } - - return String.join(SEPERATOR, result); - } - - public static String format(String input, Map replacements) { - for (Map.Entry entry : replacements.entrySet()) { - input = input.replace(entry.getKey(), entry.getValue()); - } - return input; - } - - private static Map subDict(Map dict, List keys) { - Map subDict = new HashMap<>(); - for (String key : keys) { - if (dict.containsKey(key)) { - subDict.put(key, dict.get(key)); - } - } - return subDict; - } -} \ No newline at end of file diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/JavaLLMProxy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/JavaLLMProxy.java deleted file mode 100644 index e6d1fe3f0..000000000 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/JavaLLMProxy.java +++ /dev/null @@ -1,26 +0,0 @@ -package com.tencent.supersonic.headless.chat.parser.llm; - - -import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq; -import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp; -import lombok.extern.slf4j.Slf4j; -import org.springframework.stereotype.Component; - -/** - * LLMProxy based on langchain4j Java version. - */ -@Slf4j -@Component -public class JavaLLMProxy implements LLMProxy { - - public LLMResp text2sql(LLMReq llmReq) { - - SqlGenStrategy sqlGenStrategy = SqlGenStrategyFactory.get(llmReq.getSqlGenType()); - String modelName = llmReq.getSchema().getDataSetName(); - LLMResp result = sqlGenStrategy.generate(llmReq); - result.setQuery(llmReq.getQueryText()); - result.setModelName(modelName); - return result; - } - -} diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMProxy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMProxy.java deleted file mode 100644 index 3f7537d92..000000000 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMProxy.java +++ /dev/null @@ -1,16 +0,0 @@ -package com.tencent.supersonic.headless.chat.parser.llm; - - -import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq; -import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp; - -/** - * LLMProxy encapsulates functions performed by LLMs so that multiple - * orchestration frameworks (e.g. LangChain in python, LangChain4j in java) - * could be used. - */ -public interface LLMProxy { - - LLMResp text2sql(LLMReq llmReq); - -} diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java index 3252fe16e..565e5ca21 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java @@ -109,7 +109,12 @@ public class LLMRequestService { } public LLMResp runText2SQL(LLMReq llmReq) { - return ComponentFactory.getLLMProxy().text2sql(llmReq); + SqlGenStrategy sqlGenStrategy = SqlGenStrategyFactory.get(llmReq.getSqlGenType()); + String modelName = llmReq.getSchema().getDataSetName(); + LLMResp result = sqlGenStrategy.generate(llmReq); + result.setQuery(llmReq.getQueryText()); + result.setModelName(modelName); + return result; } protected List getFieldNameList(QueryContext queryCtx, Long dataSetId, diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java index f8166d15b..94d889389 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java @@ -65,14 +65,14 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy { ); //4.format response. - Pair> sqlMapPair = OutputFormat.selfConsistencyVote( + Pair> sqlMapPair = ResponseHelper.selfConsistencyVote( Lists.newArrayList(prompt2Output.values())); LLMResp llmResp = new LLMResp(); llmResp.setQuery(promptHelper.buildAugmentedQuestion(llmReq)); llmResp.setDbSchema(promptHelper.buildSchemaStr(llmReq)); llmResp.setSqlOutput(sqlMapPair.getLeft()); //TODO: should use the same few-shot exemplars as the one chose by self-consistency vote - llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(exemplarsList.get(0), sqlMapPair.getRight())); + llmResp.setSqlRespMap(ResponseHelper.buildSqlRespMap(exemplarsList.get(0), sqlMapPair.getRight())); return llmResp; } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OutputFormat.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/ResponseHelper.java similarity index 52% rename from headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OutputFormat.java rename to headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/ResponseHelper.java index 41c00390e..0b6d18655 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OutputFormat.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/ResponseHelper.java @@ -5,8 +5,6 @@ import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlResp; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.tuple.Pair; -import java.util.ArrayList; -import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -15,23 +13,7 @@ import java.util.regex.Pattern; import java.util.stream.Collectors; @Slf4j -public class OutputFormat { - - public static String getSchemaLink(String schemaLink) { - String reult = ""; - try { - reult = schemaLink.trim(); - String pattern = "Schema_links:(.*)"; - Pattern regexPattern = Pattern.compile(pattern, Pattern.DOTALL); - Matcher matcher = regexPattern.matcher(reult); - if (matcher.find()) { - return matcher.group(1).trim(); - } - } catch (Exception e) { - log.error("", e); - } - return reult; - } +public class ResponseHelper { public static String getSql(String sqlOutput) { String sql = ""; @@ -49,37 +31,15 @@ public class OutputFormat { return sql; } - public static String getSchemaLinks(String text) { - String schemaLinks = ""; - try { - text = text.trim(); - String pattern = "Schema_links:(\\[.*?\\])|Schema_links: (\\[.*?\\])"; - Pattern regexPattern = Pattern.compile(pattern); - Matcher matcher = regexPattern.matcher(text); - - if (matcher.find()) { - if (matcher.group(1) != null) { - schemaLinks = matcher.group(1); - } else if (matcher.group(2) != null) { - schemaLinks = matcher.group(2); - } - } - } catch (Exception e) { - log.error("", e); - } - - return schemaLinks; - } - - public static Pair> selfConsistencyVote(List inputList) { + public static Pair> selfConsistencyVote(List outputList) { Map inputCounts = new HashMap<>(); - for (String input : inputList) { + for (String input : outputList) { inputCounts.put(input, inputCounts.getOrDefault(input, 0) + 1); } String inputMax = null; int maxCount = 0; - int inputSize = inputList.size(); + int inputSize = outputList.size(); Map votePercentage = new HashMap<>(); for (Map.Entry entry : inputCounts.entrySet()) { String input = entry.getKey(); @@ -94,21 +54,6 @@ public class OutputFormat { return Pair.of(inputMax, votePercentage); } - public static List formatList(List toFormatList) { - List results = new ArrayList<>(); - for (String toFormat : toFormatList) { - List items = new ArrayList<>(); - String[] split = toFormat.replace("[", "").replace("]", "").split(","); - for (String item : split) { - items.add(item.trim()); - } - Collections.sort(items); - String result = "[" + String.join(",", items) + "]"; - results.add(result); - } - return results; - } - public static Map buildSqlRespMap(List sqlExamples, Map sqlMap) { if (sqlMap == null) { diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/ComponentFactory.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/ComponentFactory.java index c8c9d04c5..1004e5d94 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/ComponentFactory.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/ComponentFactory.java @@ -2,14 +2,10 @@ package com.tencent.supersonic.headless.chat.utils; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.headless.chat.parser.llm.DataSetResolver; -import com.tencent.supersonic.headless.chat.parser.llm.JavaLLMProxy; -import com.tencent.supersonic.headless.chat.parser.llm.LLMProxy; import lombok.extern.slf4j.Slf4j; -import org.apache.commons.lang3.StringUtils; import org.springframework.core.io.support.SpringFactoriesLoader; import java.util.List; -import java.util.Map; import java.util.Objects; /** @@ -18,27 +14,8 @@ import java.util.Objects; @Slf4j public class ComponentFactory { - private static LLMProxy llmProxy; private static DataSetResolver modelResolver; - public static LLMProxy getLLMProxy() { - //1.Preferentially retrieve from environment variables - String llmProxyEnv = System.getenv("llmProxy"); - if (StringUtils.isNotBlank(llmProxyEnv)) { - Map implementations = ContextUtils.getBeansOfType(LLMProxy.class); - llmProxy = implementations.entrySet().stream() - .filter(entry -> entry.getKey().equalsIgnoreCase(llmProxyEnv)) - .map(Map.Entry::getValue) - .findFirst() - .orElse(null); - } - //2.default JavaLLMProxy - if (Objects.isNull(llmProxy)) { - llmProxy = ContextUtils.getBean(JavaLLMProxy.class); - } - return llmProxy; - } - public static DataSetResolver getModelResolver() { if (Objects.isNull(modelResolver)) { modelResolver = init(DataSetResolver.class); diff --git a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/utils/S2SqlDateHelperTest.java b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/utils/S2SqlDateHelperTest.java index c1d51d577..fc344c511 100644 --- a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/utils/S2SqlDateHelperTest.java +++ b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/utils/S2SqlDateHelperTest.java @@ -11,6 +11,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig; import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.corrector.S2SqlDateHelper; import org.apache.commons.lang3.tuple.Pair; import org.junit.Assert; import org.junit.jupiter.api.Disabled;