[improvement][chat] Fix the issue with the DatabaseMatchStrategy variable under multi-threading (#1963)

This commit is contained in:
lexluo09
2024-12-19 10:04:17 +08:00
committed by GitHub
parent 91856ddebd
commit 8b69d57c4b
7 changed files with 42 additions and 40 deletions

View File

@@ -81,7 +81,7 @@ public class Configuration {
.setUnquotedCasing(Casing.TO_UPPER).setConformance(sqlDialect.getConformance()) .setUnquotedCasing(Casing.TO_UPPER).setConformance(sqlDialect.getConformance())
.setLex(Lex.BIG_QUERY); .setLex(Lex.BIG_QUERY);
if (EngineType.HANADB.equals(engineType)) { if (EngineType.HANADB.equals(engineType)) {
parserConfig = parserConfig.setQuoting(Quoting.DOUBLE_QUOTE); parserConfig = parserConfig.setQuoting(Quoting.DOUBLE_QUOTE);
} }
parserConfig = parserConfig.setQuotedCasing(Casing.UNCHANGED); parserConfig = parserConfig.setQuotedCasing(Casing.UNCHANGED);
parserConfig = parserConfig.setUnquotedCasing(Casing.UNCHANGED); parserConfig = parserConfig.setUnquotedCasing(Casing.UNCHANGED);

View File

@@ -21,10 +21,11 @@ public class SqlDialectFactory {
.withDatabaseProduct(DatabaseProduct.BIG_QUERY).withLiteralQuoteString("'") .withDatabaseProduct(DatabaseProduct.BIG_QUERY).withLiteralQuoteString("'")
.withLiteralEscapedQuoteString("''").withUnquotedCasing(Casing.UNCHANGED) .withLiteralEscapedQuoteString("''").withUnquotedCasing(Casing.UNCHANGED)
.withQuotedCasing(Casing.UNCHANGED).withCaseSensitive(false); .withQuotedCasing(Casing.UNCHANGED).withCaseSensitive(false);
public static final Context HANADB_CONTEXT = SqlDialect.EMPTY_CONTEXT public static final Context HANADB_CONTEXT =
.withDatabaseProduct(DatabaseProduct.BIG_QUERY).withLiteralQuoteString("'") SqlDialect.EMPTY_CONTEXT.withDatabaseProduct(DatabaseProduct.BIG_QUERY)
.withIdentifierQuoteString("\"").withLiteralEscapedQuoteString("''").withUnquotedCasing(Casing.UNCHANGED) .withLiteralQuoteString("'").withIdentifierQuoteString("\"")
.withQuotedCasing(Casing.UNCHANGED).withCaseSensitive(true); .withLiteralEscapedQuoteString("''").withUnquotedCasing(Casing.UNCHANGED)
.withQuotedCasing(Casing.UNCHANGED).withCaseSensitive(true);
private static Map<EngineType, SemanticSqlDialect> sqlDialectMap; private static Map<EngineType, SemanticSqlDialect> sqlDialectMap;
static { static {

View File

@@ -1,15 +1,13 @@
package com.tencent.supersonic.common.jsqlparser; package com.tencent.supersonic.common.jsqlparser;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.statement.select.SelectItem; import net.sf.jsqlparser.statement.select.SelectItem;
import net.sf.jsqlparser.statement.select.SelectItemVisitorAdapter; import net.sf.jsqlparser.statement.select.SelectItemVisitorAdapter;
import org.apache.commons.lang3.StringUtils;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import org.apache.commons.lang3.StringUtils;
import net.sf.jsqlparser.expression.Alias;
public class FieldAliasReplaceNameVisitor extends SelectItemVisitorAdapter { public class FieldAliasReplaceNameVisitor extends SelectItemVisitorAdapter {
private Map<String, String> fieldNameMap; private Map<String, String> fieldNameMap;

View File

@@ -465,6 +465,7 @@ public class SqlReplaceHelper {
} }
return selectStatement.toString(); return selectStatement.toString();
} }
public static String replaceAlias(String sql) { public static String replaceAlias(String sql) {
Select selectStatement = SqlSelectHelper.getSelect(sql); Select selectStatement = SqlSelectHelper.getSelect(sql);
if (!(selectStatement instanceof PlainSelect)) { if (!(selectStatement instanceof PlainSelect)) {

View File

@@ -9,6 +9,7 @@ import dev.langchain4j.model.zhipu.ZhipuAiChatModel;
import dev.langchain4j.model.zhipu.ZhipuAiEmbeddingModel; import dev.langchain4j.model.zhipu.ZhipuAiEmbeddingModel;
import org.springframework.beans.factory.InitializingBean; import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import static java.time.Duration.ofSeconds; import static java.time.Duration.ofSeconds;
@Service @Service
@@ -32,8 +33,8 @@ public class ZhipuModelFactory implements ModelFactory, InitializingBean {
return ZhipuAiEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl()) return ZhipuAiEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl())
.apiKey(embeddingModelConfig.getApiKey()).model(embeddingModelConfig.getModelName()) .apiKey(embeddingModelConfig.getApiKey()).model(embeddingModelConfig.getModelName())
.maxRetries(embeddingModelConfig.getMaxRetries()).callTimeout(ofSeconds(60)) .maxRetries(embeddingModelConfig.getMaxRetries()).callTimeout(ofSeconds(60))
.connectTimeout(ofSeconds(60)).writeTimeout(ofSeconds(60)).readTimeout(ofSeconds(60)) .connectTimeout(ofSeconds(60)).writeTimeout(ofSeconds(60))
.logRequests(embeddingModelConfig.getLogRequests()) .readTimeout(ofSeconds(60)).logRequests(embeddingModelConfig.getLogRequests())
.logResponses(embeddingModelConfig.getLogResponses()).build(); .logResponses(embeddingModelConfig.getLogResponses()).build();
} }

View File

@@ -326,34 +326,35 @@ class SqlReplaceHelperTest {
@Test @Test
void testReplaceAliasFieldName() { void testReplaceAliasFieldName() {
Map<String, String> map = new HashMap<>(); Map<String, String> map = new HashMap<>();
map.put("总访问次数", "\"总访问次数\""); map.put("总访问次数", "\"总访问次数\"");
map.put("访问次数", "\"访问次数\""); map.put("访问次数", "\"访问次数\"");
String sql = "select 部门, sum(访问次数) as 总访问次数 from 超音数 where " String sql = "select 部门, sum(访问次数) as 总访问次数 from 超音数 where "
+ "datediff('day', 数据日期, '2023-09-05') <= 3 group by 部门 order by 总访问次数 desc limit 10"; + "datediff('day', 数据日期, '2023-09-05') <= 3 group by 部门 order by 总访问次数 desc limit 10";
String replaceSql = SqlReplaceHelper.replaceAliasFieldName(sql, map); String replaceSql = SqlReplaceHelper.replaceAliasFieldName(sql, map);
System.out.println(replaceSql); System.out.println(replaceSql);
Assert.assertEquals("SELECT 部门, sum(访问次数) AS \"总访问次数\" FROM 超音数 WHERE " Assert.assertEquals("SELECT 部门, sum(访问次数) AS \"总访问次数\" FROM 超音数 WHERE "
+ "datediff('day', 数据日期, '2023-09-05') <= 3 GROUP BY 部门 ORDER BY \"总访问次数\" DESC LIMIT 10", + "datediff('day', 数据日期, '2023-09-05') <= 3 GROUP BY 部门 ORDER BY \"总访问次数\" DESC LIMIT 10",
replaceSql); replaceSql);
sql = "select 部门, sum(访问次数) as 总访问次数 from 超音数 where " sql = "select 部门, sum(访问次数) as 总访问次数 from 超音数 where "
+ "(datediff('day', 数据日期, '2023-09-05') <= 3) and 数据日期 = '2023-10-10' " + "(datediff('day', 数据日期, '2023-09-05') <= 3) and 数据日期 = '2023-10-10' "
+ "group by 部门 order by 总访问次数 desc limit 10"; + "group by 部门 order by 总访问次数 desc limit 10";
replaceSql = SqlReplaceHelper.replaceAliasFieldName(sql, map); replaceSql = SqlReplaceHelper.replaceAliasFieldName(sql, map);
System.out.println(replaceSql); System.out.println(replaceSql);
Assert.assertEquals("SELECT 部门, sum(访问次数) AS \"总访问次数\" FROM 超音数 WHERE " Assert.assertEquals("SELECT 部门, sum(访问次数) AS \"总访问次数\" FROM 超音数 WHERE "
+ "(datediff('day', 数据日期, '2023-09-05') <= 3) AND 数据日期 = '2023-10-10' " + "(datediff('day', 数据日期, '2023-09-05') <= 3) AND 数据日期 = '2023-10-10' "
+ "GROUP BY 部门 ORDER BY \"总访问次数\" DESC LIMIT 10", replaceSql); + "GROUP BY 部门 ORDER BY \"总访问次数\" DESC LIMIT 10", replaceSql);
sql = "select 部门, sum(访问次数) as 访问次数 from 超音数 where " sql = "select 部门, sum(访问次数) as 访问次数 from 超音数 where "
+ "(datediff('day', 数据日期, '2023-09-05') <= 3) and 数据日期 = '2023-10-10' " + "(datediff('day', 数据日期, '2023-09-05') <= 3) and 数据日期 = '2023-10-10' "
+ "group by 部门 order by 访问次数 desc limit 10"; + "group by 部门 order by 访问次数 desc limit 10";
replaceSql = SqlReplaceHelper.replaceAliasFieldName(sql, map); replaceSql = SqlReplaceHelper.replaceAliasFieldName(sql, map);
System.out.println(replaceSql); System.out.println(replaceSql);
Assert.assertEquals("SELECT 部门, sum(\"访问次数\") AS \"访问次数\" FROM 超音数 WHERE (datediff('day', 数据日期, " Assert.assertEquals(
+ "'2023-09-05') <= 3) AND 数据日期 = '2023-10-10' GROUP BY 部门 ORDER BY \"访问次数\" DESC LIMIT 10", "SELECT 部门, sum(\"访问次数\") AS \"访问次数\" FROM 超音数 WHERE (datediff('day', 数据日期, "
replaceSql); + "'2023-09-05') <= 3) AND 数据日期 = '2023-10-10' GROUP BY 部门 ORDER BY \"访问次数\" DESC LIMIT 10",
replaceSql);
} }
@Test @Test

View File

@@ -27,12 +27,12 @@ import java.util.stream.Collectors;
@Slf4j @Slf4j
public class DatabaseMatchStrategy extends SingleMatchStrategy<DatabaseMapResult> { public class DatabaseMatchStrategy extends SingleMatchStrategy<DatabaseMapResult> {
private List<SchemaElement> allElements; private ThreadLocal<List<SchemaElement>> allElements = ThreadLocal.withInitial(ArrayList::new);
@Override @Override
public Map<MatchText, List<DatabaseMapResult>> match(ChatQueryContext chatQueryContext, public Map<MatchText, List<DatabaseMapResult>> match(ChatQueryContext chatQueryContext,
List<S2Term> terms, Set<Long> detectDataSetIds) { List<S2Term> terms, Set<Long> detectDataSetIds) {
this.allElements = getSchemaElements(chatQueryContext); allElements.set(getSchemaElements(chatQueryContext));
return super.match(chatQueryContext, terms, detectDataSetIds); return super.match(chatQueryContext, terms, detectDataSetIds);
} }
@@ -43,7 +43,7 @@ public class DatabaseMatchStrategy extends SingleMatchStrategy<DatabaseMapResult
} }
Double metricDimensionThresholdConfig = getThreshold(chatQueryContext); Double metricDimensionThresholdConfig = getThreshold(chatQueryContext);
Map<String, Set<SchemaElement>> nameToItems = getNameToItems(allElements); Map<String, Set<SchemaElement>> nameToItems = getNameToItems(allElements.get());
List<DatabaseMapResult> results = new ArrayList<>(); List<DatabaseMapResult> results = new ArrayList<>();
for (Entry<String, Set<SchemaElement>> entry : nameToItems.entrySet()) { for (Entry<String, Set<SchemaElement>> entry : nameToItems.entrySet()) {
String name = entry.getKey(); String name = entry.getKey();