(improvement)(headless)Remove LLMProxy abstraction as not needed any more.

This commit is contained in:
jerryjzhang
2024-07-05 10:09:23 +08:00
parent 93ea7a618c
commit 14b9086d83
11 changed files with 15 additions and 174 deletions

View File

@@ -2,7 +2,6 @@ package com.tencent.supersonic.common.jsqlparser;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import org.apache.commons.collections.CollectionUtils;
@@ -69,7 +68,7 @@ public class SqlValidHelper {
try {
CCJSqlParserUtil.parse(sql);
return true;
} catch (JSQLParserException e) {
} catch (Exception e) {
log.error("isValidSQL parse:{}", e);
return false;
}

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.headless.chat.utils;
package com.tencent.supersonic.headless.chat.corrector;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.TimeMode;

View File

@@ -11,7 +11,6 @@ import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper;
import com.tencent.supersonic.common.jsqlparser.DateVisitor.DateBoundInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.chat.QueryContext;
import com.tencent.supersonic.headless.chat.utils.S2SqlDateHelper;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression;

View File

@@ -1,43 +0,0 @@
package com.tencent.supersonic.headless.chat.parser.llm;
import lombok.extern.slf4j.Slf4j;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@Slf4j
public class InputFormat {
public static final String SEPERATOR = "\n\n";
public static String format(String template, List<String> templateKey,
List<Map<String, String>> toFormatList) {
List<String> result = new ArrayList<>();
for (Map<String, String> formatItem : toFormatList) {
Map<String, String> retrievalMeta = subDict(formatItem, templateKey);
result.add(format(template, retrievalMeta));
}
return String.join(SEPERATOR, result);
}
public static String format(String input, Map<String, String> replacements) {
for (Map.Entry<String, String> entry : replacements.entrySet()) {
input = input.replace(entry.getKey(), entry.getValue());
}
return input;
}
private static Map<String, String> subDict(Map<String, String> dict, List<String> keys) {
Map<String, String> subDict = new HashMap<>();
for (String key : keys) {
if (dict.containsKey(key)) {
subDict.put(key, dict.get(key));
}
}
return subDict;
}
}

View File

@@ -1,26 +0,0 @@
package com.tencent.supersonic.headless.chat.parser.llm;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
/**
* LLMProxy based on langchain4j Java version.
*/
@Slf4j
@Component
public class JavaLLMProxy implements LLMProxy {
public LLMResp text2sql(LLMReq llmReq) {
SqlGenStrategy sqlGenStrategy = SqlGenStrategyFactory.get(llmReq.getSqlGenType());
String modelName = llmReq.getSchema().getDataSetName();
LLMResp result = sqlGenStrategy.generate(llmReq);
result.setQuery(llmReq.getQueryText());
result.setModelName(modelName);
return result;
}
}

View File

@@ -1,16 +0,0 @@
package com.tencent.supersonic.headless.chat.parser.llm;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
/**
* LLMProxy encapsulates functions performed by LLMs so that multiple
* orchestration frameworks (e.g. LangChain in python, LangChain4j in java)
* could be used.
*/
public interface LLMProxy {
LLMResp text2sql(LLMReq llmReq);
}

View File

@@ -109,7 +109,12 @@ public class LLMRequestService {
}
public LLMResp runText2SQL(LLMReq llmReq) {
return ComponentFactory.getLLMProxy().text2sql(llmReq);
SqlGenStrategy sqlGenStrategy = SqlGenStrategyFactory.get(llmReq.getSqlGenType());
String modelName = llmReq.getSchema().getDataSetName();
LLMResp result = sqlGenStrategy.generate(llmReq);
result.setQuery(llmReq.getQueryText());
result.setModelName(modelName);
return result;
}
protected List<String> getFieldNameList(QueryContext queryCtx, Long dataSetId,

View File

@@ -65,14 +65,14 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
);
//4.format response.
Pair<String, Map<String, Double>> sqlMapPair = OutputFormat.selfConsistencyVote(
Pair<String, Map<String, Double>> sqlMapPair = ResponseHelper.selfConsistencyVote(
Lists.newArrayList(prompt2Output.values()));
LLMResp llmResp = new LLMResp();
llmResp.setQuery(promptHelper.buildAugmentedQuestion(llmReq));
llmResp.setDbSchema(promptHelper.buildSchemaStr(llmReq));
llmResp.setSqlOutput(sqlMapPair.getLeft());
//TODO: should use the same few-shot exemplars as the one chose by self-consistency vote
llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(exemplarsList.get(0), sqlMapPair.getRight()));
llmResp.setSqlRespMap(ResponseHelper.buildSqlRespMap(exemplarsList.get(0), sqlMapPair.getRight()));
return llmResp;
}

View File

@@ -5,8 +5,6 @@ import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlResp;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.tuple.Pair;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -15,23 +13,7 @@ import java.util.regex.Pattern;
import java.util.stream.Collectors;
@Slf4j
public class OutputFormat {
public static String getSchemaLink(String schemaLink) {
String reult = "";
try {
reult = schemaLink.trim();
String pattern = "Schema_links:(.*)";
Pattern regexPattern = Pattern.compile(pattern, Pattern.DOTALL);
Matcher matcher = regexPattern.matcher(reult);
if (matcher.find()) {
return matcher.group(1).trim();
}
} catch (Exception e) {
log.error("", e);
}
return reult;
}
public class ResponseHelper {
public static String getSql(String sqlOutput) {
String sql = "";
@@ -49,37 +31,15 @@ public class OutputFormat {
return sql;
}
public static String getSchemaLinks(String text) {
String schemaLinks = "";
try {
text = text.trim();
String pattern = "Schema_links:(\\[.*?\\])|Schema_links: (\\[.*?\\])";
Pattern regexPattern = Pattern.compile(pattern);
Matcher matcher = regexPattern.matcher(text);
if (matcher.find()) {
if (matcher.group(1) != null) {
schemaLinks = matcher.group(1);
} else if (matcher.group(2) != null) {
schemaLinks = matcher.group(2);
}
}
} catch (Exception e) {
log.error("", e);
}
return schemaLinks;
}
public static Pair<String, Map<String, Double>> selfConsistencyVote(List<String> inputList) {
public static Pair<String, Map<String, Double>> selfConsistencyVote(List<String> outputList) {
Map<String, Integer> inputCounts = new HashMap<>();
for (String input : inputList) {
for (String input : outputList) {
inputCounts.put(input, inputCounts.getOrDefault(input, 0) + 1);
}
String inputMax = null;
int maxCount = 0;
int inputSize = inputList.size();
int inputSize = outputList.size();
Map<String, Double> votePercentage = new HashMap<>();
for (Map.Entry<String, Integer> entry : inputCounts.entrySet()) {
String input = entry.getKey();
@@ -94,21 +54,6 @@ public class OutputFormat {
return Pair.of(inputMax, votePercentage);
}
public static List<String> formatList(List<String> toFormatList) {
List<String> results = new ArrayList<>();
for (String toFormat : toFormatList) {
List<String> items = new ArrayList<>();
String[] split = toFormat.replace("[", "").replace("]", "").split(",");
for (String item : split) {
items.add(item.trim());
}
Collections.sort(items);
String result = "[" + String.join(",", items) + "]";
results.add(result);
}
return results;
}
public static Map<String, LLMSqlResp> buildSqlRespMap(List<SqlExemplar> sqlExamples,
Map<String, Double> sqlMap) {
if (sqlMap == null) {

View File

@@ -2,14 +2,10 @@ package com.tencent.supersonic.headless.chat.utils;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.chat.parser.llm.DataSetResolver;
import com.tencent.supersonic.headless.chat.parser.llm.JavaLLMProxy;
import com.tencent.supersonic.headless.chat.parser.llm.LLMProxy;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.core.io.support.SpringFactoriesLoader;
import java.util.List;
import java.util.Map;
import java.util.Objects;
/**
@@ -18,27 +14,8 @@ import java.util.Objects;
@Slf4j
public class ComponentFactory {
private static LLMProxy llmProxy;
private static DataSetResolver modelResolver;
public static LLMProxy getLLMProxy() {
//1.Preferentially retrieve from environment variables
String llmProxyEnv = System.getenv("llmProxy");
if (StringUtils.isNotBlank(llmProxyEnv)) {
Map<String, LLMProxy> implementations = ContextUtils.getBeansOfType(LLMProxy.class);
llmProxy = implementations.entrySet().stream()
.filter(entry -> entry.getKey().equalsIgnoreCase(llmProxyEnv))
.map(Map.Entry::getValue)
.findFirst()
.orElse(null);
}
//2.default JavaLLMProxy
if (Objects.isNull(llmProxy)) {
llmProxy = ContextUtils.getBean(JavaLLMProxy.class);
}
return llmProxy;
}
public static DataSetResolver getModelResolver() {
if (Objects.isNull(modelResolver)) {
modelResolver = init(DataSetResolver.class);

View File

@@ -11,6 +11,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
import com.tencent.supersonic.headless.chat.QueryContext;
import com.tencent.supersonic.headless.chat.corrector.S2SqlDateHelper;
import org.apache.commons.lang3.tuple.Pair;
import org.junit.Assert;
import org.junit.jupiter.api.Disabled;