mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 12:07:42 +00:00
[improvement][chat] Fix the issue with the DatabaseMatchStrategy variable under multi-threading (#1963)
This commit is contained in:
@@ -21,9 +21,10 @@ public class SqlDialectFactory {
|
||||
.withDatabaseProduct(DatabaseProduct.BIG_QUERY).withLiteralQuoteString("'")
|
||||
.withLiteralEscapedQuoteString("''").withUnquotedCasing(Casing.UNCHANGED)
|
||||
.withQuotedCasing(Casing.UNCHANGED).withCaseSensitive(false);
|
||||
public static final Context HANADB_CONTEXT = SqlDialect.EMPTY_CONTEXT
|
||||
.withDatabaseProduct(DatabaseProduct.BIG_QUERY).withLiteralQuoteString("'")
|
||||
.withIdentifierQuoteString("\"").withLiteralEscapedQuoteString("''").withUnquotedCasing(Casing.UNCHANGED)
|
||||
public static final Context HANADB_CONTEXT =
|
||||
SqlDialect.EMPTY_CONTEXT.withDatabaseProduct(DatabaseProduct.BIG_QUERY)
|
||||
.withLiteralQuoteString("'").withIdentifierQuoteString("\"")
|
||||
.withLiteralEscapedQuoteString("''").withUnquotedCasing(Casing.UNCHANGED)
|
||||
.withQuotedCasing(Casing.UNCHANGED).withCaseSensitive(true);
|
||||
private static Map<EngineType, SemanticSqlDialect> sqlDialectMap;
|
||||
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
package com.tencent.supersonic.common.jsqlparser;
|
||||
|
||||
import net.sf.jsqlparser.expression.Alias;
|
||||
import net.sf.jsqlparser.statement.select.SelectItem;
|
||||
import net.sf.jsqlparser.statement.select.SelectItemVisitorAdapter;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import net.sf.jsqlparser.expression.Alias;
|
||||
|
||||
public class FieldAliasReplaceNameVisitor extends SelectItemVisitorAdapter {
|
||||
private Map<String, String> fieldNameMap;
|
||||
|
||||
|
||||
@@ -465,6 +465,7 @@ public class SqlReplaceHelper {
|
||||
}
|
||||
return selectStatement.toString();
|
||||
}
|
||||
|
||||
public static String replaceAlias(String sql) {
|
||||
Select selectStatement = SqlSelectHelper.getSelect(sql);
|
||||
if (!(selectStatement instanceof PlainSelect)) {
|
||||
|
||||
@@ -9,6 +9,7 @@ import dev.langchain4j.model.zhipu.ZhipuAiChatModel;
|
||||
import dev.langchain4j.model.zhipu.ZhipuAiEmbeddingModel;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import static java.time.Duration.ofSeconds;
|
||||
|
||||
@Service
|
||||
@@ -32,8 +33,8 @@ public class ZhipuModelFactory implements ModelFactory, InitializingBean {
|
||||
return ZhipuAiEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl())
|
||||
.apiKey(embeddingModelConfig.getApiKey()).model(embeddingModelConfig.getModelName())
|
||||
.maxRetries(embeddingModelConfig.getMaxRetries()).callTimeout(ofSeconds(60))
|
||||
.connectTimeout(ofSeconds(60)).writeTimeout(ofSeconds(60)).readTimeout(ofSeconds(60))
|
||||
.logRequests(embeddingModelConfig.getLogRequests())
|
||||
.connectTimeout(ofSeconds(60)).writeTimeout(ofSeconds(60))
|
||||
.readTimeout(ofSeconds(60)).logRequests(embeddingModelConfig.getLogRequests())
|
||||
.logResponses(embeddingModelConfig.getLogResponses()).build();
|
||||
}
|
||||
|
||||
|
||||
@@ -351,7 +351,8 @@ class SqlReplaceHelperTest {
|
||||
+ "group by 部门 order by 访问次数 desc limit 10";
|
||||
replaceSql = SqlReplaceHelper.replaceAliasFieldName(sql, map);
|
||||
System.out.println(replaceSql);
|
||||
Assert.assertEquals("SELECT 部门, sum(\"访问次数\") AS \"访问次数\" FROM 超音数 WHERE (datediff('day', 数据日期, "
|
||||
Assert.assertEquals(
|
||||
"SELECT 部门, sum(\"访问次数\") AS \"访问次数\" FROM 超音数 WHERE (datediff('day', 数据日期, "
|
||||
+ "'2023-09-05') <= 3) AND 数据日期 = '2023-10-10' GROUP BY 部门 ORDER BY \"访问次数\" DESC LIMIT 10",
|
||||
replaceSql);
|
||||
}
|
||||
|
||||
@@ -27,12 +27,12 @@ import java.util.stream.Collectors;
|
||||
@Slf4j
|
||||
public class DatabaseMatchStrategy extends SingleMatchStrategy<DatabaseMapResult> {
|
||||
|
||||
private List<SchemaElement> allElements;
|
||||
private ThreadLocal<List<SchemaElement>> allElements = ThreadLocal.withInitial(ArrayList::new);
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<DatabaseMapResult>> match(ChatQueryContext chatQueryContext,
|
||||
List<S2Term> terms, Set<Long> detectDataSetIds) {
|
||||
this.allElements = getSchemaElements(chatQueryContext);
|
||||
allElements.set(getSchemaElements(chatQueryContext));
|
||||
return super.match(chatQueryContext, terms, detectDataSetIds);
|
||||
}
|
||||
|
||||
@@ -43,7 +43,7 @@ public class DatabaseMatchStrategy extends SingleMatchStrategy<DatabaseMapResult
|
||||
}
|
||||
|
||||
Double metricDimensionThresholdConfig = getThreshold(chatQueryContext);
|
||||
Map<String, Set<SchemaElement>> nameToItems = getNameToItems(allElements);
|
||||
Map<String, Set<SchemaElement>> nameToItems = getNameToItems(allElements.get());
|
||||
List<DatabaseMapResult> results = new ArrayList<>();
|
||||
for (Entry<String, Set<SchemaElement>> entry : nameToItems.entrySet()) {
|
||||
String name = entry.getKey();
|
||||
|
||||
Reference in New Issue
Block a user