mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
(improvement)(headless)Remove LLMProxy abstraction as not needed any more.
This commit is contained in:
@@ -2,7 +2,6 @@ package com.tencent.supersonic.common.jsqlparser;
|
||||
|
||||
import java.util.List;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.JSQLParserException;
|
||||
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
@@ -69,7 +68,7 @@ public class SqlValidHelper {
|
||||
try {
|
||||
CCJSqlParserUtil.parse(sql);
|
||||
return true;
|
||||
} catch (JSQLParserException e) {
|
||||
} catch (Exception e) {
|
||||
log.error("isValidSQL parse:{}", e);
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.headless.chat.utils;
|
||||
package com.tencent.supersonic.headless.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeMode;
|
||||
@@ -11,7 +11,6 @@ import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.DateVisitor.DateBoundInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
||||
import com.tencent.supersonic.headless.chat.utils.S2SqlDateHelper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.JSQLParserException;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
|
||||
|
||||
@Slf4j
|
||||
public class InputFormat {

    /** Separator inserted between the rendered items of a multi-item {@code format} call. */
    public static final String SEPERATOR = "\n\n";

    /**
     * Renders {@code template} once for every map in {@code toFormatList} and joins the
     * rendered strings with {@link #SEPERATOR}.
     *
     * <p>For each map, only the keys listed in {@code templateKey} are substituted, and they
     * are substituted in the order given by {@code templateKey}, so the result does not
     * depend on hash iteration order. Keys that are absent from the map, or mapped to
     * {@code null}, leave the template text untouched.
     *
     * @param template     text containing the literal keys to be replaced
     * @param templateKey  keys to substitute, in substitution order
     * @param toFormatList one key-to-value map per item to render
     * @return the rendered items joined by {@link #SEPERATOR}; an empty string when
     *         {@code toFormatList} is empty
     */
    public static String format(String template, List<String> templateKey,
            List<Map<String, String>> toFormatList) {
        List<String> result = new ArrayList<>(toFormatList.size());

        for (Map<String, String> formatItem : toFormatList) {
            Map<String, String> retrievalMeta = subDict(formatItem, templateKey);
            result.add(format(template, retrievalMeta));
        }

        return String.join(SEPERATOR, result);
    }

    /**
     * Replaces every literal occurrence of each key of {@code replacements} in {@code input}
     * with the corresponding value, visiting the entries in the map's iteration order.
     *
     * <p>Entries whose key or value is {@code null} are skipped instead of triggering a
     * {@link NullPointerException} inside {@link String#replace(CharSequence, CharSequence)}.
     *
     * @param input        text to transform
     * @param replacements literal key to replacement value
     * @return the text after all replacements were applied
     */
    public static String format(String input, Map<String, String> replacements) {
        for (Map.Entry<String, String> entry : replacements.entrySet()) {
            String key = entry.getKey();
            String value = entry.getValue();
            if (key == null || value == null) {
                // Nothing sensible to substitute; keep the template text as it is.
                continue;
            }
            input = input.replace(key, value);
        }
        return input;
    }

    /**
     * Copies the entries of {@code dict} whose key is listed in {@code keys} and whose value
     * is non-null, preserving the order of {@code keys}.
     */
    private static Map<String, String> subDict(Map<String, String> dict, List<String> keys) {
        // LinkedHashMap keeps the caller-specified key order for the later replacement pass.
        Map<String, String> subDict = new LinkedHashMap<>();
        for (String key : keys) {
            String value = dict.get(key);
            if (value != null) {
                subDict.put(key, value);
            }
        }
        return subDict;
    }
}
|
||||
@@ -1,26 +0,0 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
/**
|
||||
* LLMProxy based on langchain4j Java version.
|
||||
*/
|
||||
@Slf4j
|
||||
@Component
|
||||
public class JavaLLMProxy implements LLMProxy {
|
||||
|
||||
public LLMResp text2sql(LLMReq llmReq) {
|
||||
|
||||
SqlGenStrategy sqlGenStrategy = SqlGenStrategyFactory.get(llmReq.getSqlGenType());
|
||||
String modelName = llmReq.getSchema().getDataSetName();
|
||||
LLMResp result = sqlGenStrategy.generate(llmReq);
|
||||
result.setQuery(llmReq.getQueryText());
|
||||
result.setModelName(modelName);
|
||||
return result;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
||||
|
||||
/**
|
||||
* LLMProxy encapsulates functions performed by LLMs so that multiple
|
||||
* orchestration frameworks (e.g. LangChain in python, LangChain4j in java)
|
||||
* could be used.
|
||||
*/
|
||||
public interface LLMProxy {
|
||||
|
||||
LLMResp text2sql(LLMReq llmReq);
|
||||
|
||||
}
|
||||
@@ -109,7 +109,12 @@ public class LLMRequestService {
|
||||
}
|
||||
|
||||
public LLMResp runText2SQL(LLMReq llmReq) {
|
||||
return ComponentFactory.getLLMProxy().text2sql(llmReq);
|
||||
SqlGenStrategy sqlGenStrategy = SqlGenStrategyFactory.get(llmReq.getSqlGenType());
|
||||
String modelName = llmReq.getSchema().getDataSetName();
|
||||
LLMResp result = sqlGenStrategy.generate(llmReq);
|
||||
result.setQuery(llmReq.getQueryText());
|
||||
result.setModelName(modelName);
|
||||
return result;
|
||||
}
|
||||
|
||||
protected List<String> getFieldNameList(QueryContext queryCtx, Long dataSetId,
|
||||
|
||||
@@ -65,14 +65,14 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
);
|
||||
|
||||
//4.format response.
|
||||
Pair<String, Map<String, Double>> sqlMapPair = OutputFormat.selfConsistencyVote(
|
||||
Pair<String, Map<String, Double>> sqlMapPair = ResponseHelper.selfConsistencyVote(
|
||||
Lists.newArrayList(prompt2Output.values()));
|
||||
LLMResp llmResp = new LLMResp();
|
||||
llmResp.setQuery(promptHelper.buildAugmentedQuestion(llmReq));
|
||||
llmResp.setDbSchema(promptHelper.buildSchemaStr(llmReq));
|
||||
llmResp.setSqlOutput(sqlMapPair.getLeft());
|
||||
//TODO: should use the same few-shot exemplars as the one chose by self-consistency vote
|
||||
llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(exemplarsList.get(0), sqlMapPair.getRight()));
|
||||
llmResp.setSqlRespMap(ResponseHelper.buildSqlRespMap(exemplarsList.get(0), sqlMapPair.getRight()));
|
||||
|
||||
return llmResp;
|
||||
}
|
||||
|
||||
@@ -5,8 +5,6 @@ import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlResp;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
@@ -15,23 +13,7 @@ import java.util.regex.Pattern;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
public class OutputFormat {
|
||||
|
||||
public static String getSchemaLink(String schemaLink) {
|
||||
String reult = "";
|
||||
try {
|
||||
reult = schemaLink.trim();
|
||||
String pattern = "Schema_links:(.*)";
|
||||
Pattern regexPattern = Pattern.compile(pattern, Pattern.DOTALL);
|
||||
Matcher matcher = regexPattern.matcher(reult);
|
||||
if (matcher.find()) {
|
||||
return matcher.group(1).trim();
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error("", e);
|
||||
}
|
||||
return reult;
|
||||
}
|
||||
public class ResponseHelper {
|
||||
|
||||
public static String getSql(String sqlOutput) {
|
||||
String sql = "";
|
||||
@@ -49,37 +31,15 @@ public class OutputFormat {
|
||||
return sql;
|
||||
}
|
||||
|
||||
public static String getSchemaLinks(String text) {
|
||||
String schemaLinks = "";
|
||||
try {
|
||||
text = text.trim();
|
||||
String pattern = "Schema_links:(\\[.*?\\])|Schema_links: (\\[.*?\\])";
|
||||
Pattern regexPattern = Pattern.compile(pattern);
|
||||
Matcher matcher = regexPattern.matcher(text);
|
||||
|
||||
if (matcher.find()) {
|
||||
if (matcher.group(1) != null) {
|
||||
schemaLinks = matcher.group(1);
|
||||
} else if (matcher.group(2) != null) {
|
||||
schemaLinks = matcher.group(2);
|
||||
}
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error("", e);
|
||||
}
|
||||
|
||||
return schemaLinks;
|
||||
}
|
||||
|
||||
public static Pair<String, Map<String, Double>> selfConsistencyVote(List<String> inputList) {
|
||||
public static Pair<String, Map<String, Double>> selfConsistencyVote(List<String> outputList) {
|
||||
Map<String, Integer> inputCounts = new HashMap<>();
|
||||
for (String input : inputList) {
|
||||
for (String input : outputList) {
|
||||
inputCounts.put(input, inputCounts.getOrDefault(input, 0) + 1);
|
||||
}
|
||||
|
||||
String inputMax = null;
|
||||
int maxCount = 0;
|
||||
int inputSize = inputList.size();
|
||||
int inputSize = outputList.size();
|
||||
Map<String, Double> votePercentage = new HashMap<>();
|
||||
for (Map.Entry<String, Integer> entry : inputCounts.entrySet()) {
|
||||
String input = entry.getKey();
|
||||
@@ -94,21 +54,6 @@ public class OutputFormat {
|
||||
return Pair.of(inputMax, votePercentage);
|
||||
}
|
||||
|
||||
public static List<String> formatList(List<String> toFormatList) {
|
||||
List<String> results = new ArrayList<>();
|
||||
for (String toFormat : toFormatList) {
|
||||
List<String> items = new ArrayList<>();
|
||||
String[] split = toFormat.replace("[", "").replace("]", "").split(",");
|
||||
for (String item : split) {
|
||||
items.add(item.trim());
|
||||
}
|
||||
Collections.sort(items);
|
||||
String result = "[" + String.join(",", items) + "]";
|
||||
results.add(result);
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
public static Map<String, LLMSqlResp> buildSqlRespMap(List<SqlExemplar> sqlExamples,
|
||||
Map<String, Double> sqlMap) {
|
||||
if (sqlMap == null) {
|
||||
@@ -2,14 +2,10 @@ package com.tencent.supersonic.headless.chat.utils;
|
||||
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.chat.parser.llm.DataSetResolver;
|
||||
import com.tencent.supersonic.headless.chat.parser.llm.JavaLLMProxy;
|
||||
import com.tencent.supersonic.headless.chat.parser.llm.LLMProxy;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.core.io.support.SpringFactoriesLoader;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
@@ -18,27 +14,8 @@ import java.util.Objects;
|
||||
@Slf4j
|
||||
public class ComponentFactory {
|
||||
|
||||
private static LLMProxy llmProxy;
|
||||
private static DataSetResolver modelResolver;
|
||||
|
||||
public static LLMProxy getLLMProxy() {
|
||||
//1.Preferentially retrieve from environment variables
|
||||
String llmProxyEnv = System.getenv("llmProxy");
|
||||
if (StringUtils.isNotBlank(llmProxyEnv)) {
|
||||
Map<String, LLMProxy> implementations = ContextUtils.getBeansOfType(LLMProxy.class);
|
||||
llmProxy = implementations.entrySet().stream()
|
||||
.filter(entry -> entry.getKey().equalsIgnoreCase(llmProxyEnv))
|
||||
.map(Map.Entry::getValue)
|
||||
.findFirst()
|
||||
.orElse(null);
|
||||
}
|
||||
//2.default JavaLLMProxy
|
||||
if (Objects.isNull(llmProxy)) {
|
||||
llmProxy = ContextUtils.getBean(JavaLLMProxy.class);
|
||||
}
|
||||
return llmProxy;
|
||||
}
|
||||
|
||||
public static DataSetResolver getModelResolver() {
|
||||
if (Objects.isNull(modelResolver)) {
|
||||
modelResolver = init(DataSetResolver.class);
|
||||
|
||||
@@ -11,6 +11,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
|
||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
||||
import com.tencent.supersonic.headless.chat.corrector.S2SqlDateHelper;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.junit.Assert;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
|
||||
Reference in New Issue
Block a user