(improvement)[build] Use Spotless to customize the code formatting (#1750)

This commit is contained in:
lexluo09
2024-10-04 00:05:04 +08:00
committed by GitHub
parent 44d1cde34f
commit 71a9954be5
521 changed files with 7811 additions and 13046 deletions

View File

@@ -24,13 +24,12 @@ public class LoadRemoveService {
}
List<String> resultList = new ArrayList<>(value);
if (StringUtils.isNotBlank(mapperRemoveNaturePrefix)) {
resultList.removeIf(
nature -> {
if (Objects.isNull(nature)) {
return false;
}
return nature.startsWith(mapperRemoveNaturePrefix);
});
resultList.removeIf(nature -> {
if (Objects.isNull(nature)) {
return false;
}
return nature.startsWith(mapperRemoveNaturePrefix);
});
}
return resultList;
}

View File

@@ -253,19 +253,8 @@ public abstract class BaseNode<V> implements Comparable<BaseNode> {
@Override
public String toString() {
return "BaseNode{"
+ "child="
+ Arrays.toString(child)
+ ", status="
+ status
+ ", c="
+ c
+ ", value="
+ value
+ ", prefix='"
+ prefix
+ '\''
+ '}';
return "BaseNode{" + "child=" + Arrays.toString(child) + ", status=" + status + ", c=" + c
+ ", value=" + value + ", prefix='" + prefix + '\'' + '}';
}
public void walkNode(Set<Map.Entry<String, V>> entrySet) {

View File

@@ -34,13 +34,8 @@ public class CoreDictionary {
if (!load(PATH)) {
throw new IllegalArgumentException("核心词典" + PATH + "加载失败");
} else {
Predefine.logger.info(
PATH
+ "加载成功,"
+ trie.size()
+ "个词条,耗时"
+ (System.currentTimeMillis() - start)
+ "ms");
Predefine.logger.info(PATH + "加载成功," + trie.size() + "个词条,耗时"
+ (System.currentTimeMillis() - start) + "ms");
}
}
@@ -77,22 +72,14 @@ public class CoreDictionary {
map.put(param[0], attribute);
totalFrequency += attribute.totalFrequency;
}
Predefine.logger.info(
"核心词典读入词条"
+ map.size()
+ " 全部频次"
+ totalFrequency
+ ",耗时"
+ (System.currentTimeMillis() - start)
+ "ms");
Predefine.logger.info("核心词典读入词条" + map.size() + " 全部频次" + totalFrequency + ",耗时"
+ (System.currentTimeMillis() - start) + "ms");
br.close();
trie.build(map);
Predefine.logger.info("核心词典加载成功:" + trie.size() + "个词条,下面将写入缓存……");
try {
DataOutputStream out =
new DataOutputStream(
new BufferedOutputStream(
IOUtil.newOutputStream(path + Predefine.BIN_EXT)));
DataOutputStream out = new DataOutputStream(
new BufferedOutputStream(IOUtil.newOutputStream(path + Predefine.BIN_EXT)));
Collection<Attribute> attributeList = map.values();
out.writeInt(attributeList.size());
for (Attribute attribute : attributeList) {
@@ -278,11 +265,8 @@ public class CoreDictionary {
}
return attribute;
} catch (Exception e) {
Predefine.logger.warning(
"使用字符串"
+ natureWithFrequency
+ "创建词条属性失败!"
+ TextUtility.exceptionToString(e));
Predefine.logger.warning("使用字符串" + natureWithFrequency + "创建词条属性失败!"
+ TextUtility.exceptionToString(e));
return null;
}
}
@@ -409,9 +393,7 @@ public class CoreDictionary {
if (originals == null || originals.length == 0) {
return null;
}
return Arrays.stream(originals)
.filter(o -> o != null)
.distinct()
return Arrays.stream(originals).filter(o -> o != null).distinct()
.collect(Collectors.toList());
}
}

View File

@@ -47,8 +47,7 @@ public abstract class WordBasedSegment extends Segment {
}
vertex = (Vertex) var1.next();
} while (!vertex.realWord.equals("")
&& !vertex.realWord.equals("")
} while (!vertex.realWord.equals("") && !vertex.realWord.equals("")
&& !vertex.realWord.equals("-"));
vertex.confirmNature(Nature.w);
@@ -66,8 +65,7 @@ public abstract class WordBasedSegment extends Segment {
if (currentNature == Nature.nx
&& (next.hasNature(Nature.q) || next.hasNature(Nature.n))) {
String[] param = current.realWord.split("-", 1);
if (param.length == 2
&& TextUtility.isAllNum(param[0])
if (param.length == 2 && TextUtility.isAllNum(param[0])
&& TextUtility.isAllNum(param[1])) {
current = current.copy();
current.realWord = param[0];
@@ -112,10 +110,8 @@ public abstract class WordBasedSegment extends Segment {
current.confirmNature(Nature.m, true);
} else if (current.realWord.length() > 1) {
char last = current.realWord.charAt(current.realWord.length() - 1);
current =
Vertex.newNumberInstance(
current.realWord.substring(
0, current.realWord.length() - 1));
current = Vertex.newNumberInstance(
current.realWord.substring(0, current.realWord.length() - 1));
listIterator.previous();
listIterator.previous();
listIterator.set(current);
@@ -162,9 +158,7 @@ public abstract class WordBasedSegment extends Segment {
charTypeArray[i] = CharType.get(c);
if (c == '.' && i < charArray.length - 1 && CharType.get(charArray[i + 1]) == 9) {
charTypeArray[i] = 9;
} else if (c == '.'
&& i < charArray.length - 1
&& charArray[i + 1] >= '0'
} else if (c == '.' && i < charArray.length - 1 && charArray[i + 1] >= '0'
&& charArray[i + 1] <= '9') {
charTypeArray[i] = 5;
} else if (charTypeArray[i] == 8) {
@@ -227,7 +221,7 @@ public abstract class WordBasedSegment extends Segment {
while (listIterator.hasNext()) {
next = (Vertex) listIterator.next();
if (!TextUtility.isAllNum(current.realWord)
&& !TextUtility.isAllChineseNum(current.realWord)
&& !TextUtility.isAllChineseNum(current.realWord)
|| !TextUtility.isAllNum(next.realWord)
&& !TextUtility.isAllChineseNum(next.realWord)) {
current = next;
@@ -252,21 +246,16 @@ public abstract class WordBasedSegment extends Segment {
DoubleArrayTrie.Searcher searcher = CoreDictionary.trie.getSearcher(charArray, 0);
while (searcher.next()) {
wordNetStorage.add(
searcher.begin + 1,
new Vertex(
new String(charArray, searcher.begin, searcher.length),
(CoreDictionary.Attribute) searcher.value,
searcher.index));
wordNetStorage.add(searcher.begin + 1,
new Vertex(new String(charArray, searcher.begin, searcher.length),
(CoreDictionary.Attribute) searcher.value, searcher.index));
}
if (this.config.forceCustomDictionary) {
this.customDictionary.parseText(
charArray,
this.customDictionary.parseText(charArray,
new AhoCorasickDoubleArrayTrie.IHit<CoreDictionary.Attribute>() {
public void hit(int begin, int end, CoreDictionary.Attribute value) {
wordNetStorage.add(
begin + 1,
wordNetStorage.add(begin + 1,
new Vertex(new String(charArray, begin, end - begin), value));
}
});
@@ -279,11 +268,9 @@ public abstract class WordBasedSegment extends Segment {
while (i < vertexes.length) {
if (vertexes[i].isEmpty()) {
int j;
for (j = i + 1;
j < vertexes.length - 1
&& (vertexes[j].isEmpty()
|| CharType.get(charArray[j - 1]) == 11);
++j) {}
for (j = i + 1; j < vertexes.length - 1 && (vertexes[j].isEmpty()
|| CharType.get(charArray[j - 1]) == 11); ++j) {
}
wordNetStorage.add(i, Segment.quickAtomSegment(charArray, i - 1, j - 1));
i = j;
@@ -310,10 +297,8 @@ public abstract class WordBasedSegment extends Segment {
addTerms(termList, vertex, line - 1);
termMain.offset = line - 1;
if (vertex.realWord.length() > 2) {
label43:
for (int currentLine = line;
currentLine < line + vertex.realWord.length();
++currentLine) {
label43: for (int currentLine = line; currentLine < line
+ vertex.realWord.length(); ++currentLine) {
Iterator iterator = wordNetAll.descendingIterator(currentLine);
while (true) {
@@ -327,8 +312,8 @@ public abstract class WordBasedSegment extends Segment {
&& smallVertex.realWord.length() < this.config.indexMode);
if (smallVertex != vertex
&& currentLine + smallVertex.realWord.length()
<= line + vertex.realWord.length()) {
&& currentLine + smallVertex.realWord.length() <= line
+ vertex.realWord.length()) {
listIterator.add(smallVertex);
// Term termSub = convert(smallVertex);
// termSub.offset = currentLine - 1;
@@ -346,8 +331,8 @@ public abstract class WordBasedSegment extends Segment {
}
protected static void speechTagging(List<Vertex> vertexList) {
Viterbi.compute(
vertexList, CoreDictionaryTransformMatrixDictionary.transformMatrixDictionary);
Viterbi.compute(vertexList,
CoreDictionaryTransformMatrixDictionary.transformMatrixDictionary);
}
protected void addTerms(List<Term> terms, Vertex vertex, int offset) {

View File

@@ -42,19 +42,13 @@ public class Term {
}
// todo opt
/*
String wordOri = word.toLowerCase();
CoreDictionary.Attribute attribute = getDynamicCustomDictionary().get(wordOri);
if (attribute == null) {
attribute = CoreDictionary.get(wordOri);
if (attribute == null) {
attribute = CustomDictionary.get(wordOri);
}
}
if (attribute != null && nature != null && attribute.hasNature(nature)) {
return attribute.getNatureFrequency(nature);
}
return attribute == null ? 0 : attribute.totalFrequency;
*/
* String wordOri = word.toLowerCase(); CoreDictionary.Attribute attribute =
* getDynamicCustomDictionary().get(wordOri); if (attribute == null) { attribute =
* CoreDictionary.get(wordOri); if (attribute == null) { attribute =
* CustomDictionary.get(wordOri); } } if (attribute != null && nature != null &&
* attribute.hasNature(nature)) { return attribute.getNatureFrequency(nature); } return
* attribute == null ? 0 : attribute.totalFrequency;
*/
return 0;
}

View File

@@ -51,19 +51,18 @@ public class Configuration {
public static SqlValidator.Config getValidatorConfig(EngineType engineType) {
SemanticSqlDialect sqlDialect = SqlDialectFactory.getSqlDialect(engineType);
return SqlValidator.Config.DEFAULT
.withConformance(sqlDialect.getConformance())
return SqlValidator.Config.DEFAULT.withConformance(sqlDialect.getConformance())
.withDefaultNullCollation(config.defaultNullCollation())
.withLenientOperatorLookup(true);
}
static {
configProperties.put(
CalciteConnectionProperty.CASE_SENSITIVE.camelName(), Boolean.TRUE.toString());
configProperties.put(
CalciteConnectionProperty.UNQUOTED_CASING.camelName(), Casing.UNCHANGED.toString());
configProperties.put(
CalciteConnectionProperty.QUOTED_CASING.camelName(), Casing.TO_LOWER.toString());
configProperties.put(CalciteConnectionProperty.CASE_SENSITIVE.camelName(),
Boolean.TRUE.toString());
configProperties.put(CalciteConnectionProperty.UNQUOTED_CASING.camelName(),
Casing.UNCHANGED.toString());
configProperties.put(CalciteConnectionProperty.QUOTED_CASING.camelName(),
Casing.TO_LOWER.toString());
}
public static SqlParser.Config getParserConfig(EngineType engineType) {
@@ -76,15 +75,10 @@ public class Configuration {
parserConfig.setQuotedCasing(config.quotedCasing());
parserConfig.setConformance(config.conformance());
parserConfig.setLex(Lex.BIG_QUERY);
parserConfig
.setParserFactory(SqlParserImpl.FACTORY)
.setCaseSensitive(false)
.setIdentifierMaxLength(Integer.MAX_VALUE)
.setQuoting(Quoting.BACK_TICK)
.setQuoting(Quoting.SINGLE_QUOTE)
.setQuotedCasing(Casing.TO_UPPER)
.setUnquotedCasing(Casing.TO_UPPER)
.setConformance(sqlDialect.getConformance())
parserConfig.setParserFactory(SqlParserImpl.FACTORY).setCaseSensitive(false)
.setIdentifierMaxLength(Integer.MAX_VALUE).setQuoting(Quoting.BACK_TICK)
.setQuoting(Quoting.SINGLE_QUOTE).setQuotedCasing(Casing.TO_UPPER)
.setUnquotedCasing(Casing.TO_UPPER).setConformance(sqlDialect.getConformance())
.setLex(Lex.BIG_QUERY);
parserConfig = parserConfig.setQuotedCasing(Casing.UNCHANGED);
parserConfig = parserConfig.setUnquotedCasing(Casing.UNCHANGED);
@@ -96,61 +90,39 @@ public class Configuration {
tables.add(SqlStdOperatorTable.instance());
SqlOperatorTable operatorTable = new ChainedSqlOperatorTable(tables);
// operatorTable.
Prepare.CatalogReader catalogReader =
new CalciteCatalogReader(
rootSchema,
Collections.singletonList(rootSchema.getName()),
typeFactory,
config);
return SqlValidatorUtil.newValidator(
operatorTable,
catalogReader,
typeFactory,
Prepare.CatalogReader catalogReader = new CalciteCatalogReader(rootSchema,
Collections.singletonList(rootSchema.getName()), typeFactory, config);
return SqlValidatorUtil.newValidator(operatorTable, catalogReader, typeFactory,
Configuration.getValidatorConfig(engineType));
}
public static SqlValidatorWithHints getSqlValidatorWithHints(
CalciteSchema rootSchema, EngineType engineTyp) {
return new SqlAdvisorValidator(
SqlStdOperatorTable.instance(),
new CalciteCatalogReader(
rootSchema,
Collections.singletonList(rootSchema.getName()),
typeFactory,
config),
typeFactory,
SqlValidator.Config.DEFAULT);
public static SqlValidatorWithHints getSqlValidatorWithHints(CalciteSchema rootSchema,
EngineType engineTyp) {
return new SqlAdvisorValidator(SqlStdOperatorTable.instance(),
new CalciteCatalogReader(rootSchema,
Collections.singletonList(rootSchema.getName()), typeFactory, config),
typeFactory, SqlValidator.Config.DEFAULT);
}
public static SqlToRelConverter.Config getConverterConfig() {
HintStrategyTable strategies = HintStrategyTable.builder().build();
return SqlToRelConverter.config()
.withHintStrategyTable(strategies)
.withTrimUnusedFields(true)
.withExpand(true)
return SqlToRelConverter.config().withHintStrategyTable(strategies)
.withTrimUnusedFields(true).withExpand(true)
.addRelBuilderConfigTransform(c -> c.withSimplify(false));
}
public static SqlToRelConverter getSqlToRelConverter(
SqlValidatorScope scope,
SqlValidator sqlValidator,
RelOptPlanner relOptPlanner,
EngineType engineType) {
public static SqlToRelConverter getSqlToRelConverter(SqlValidatorScope scope,
SqlValidator sqlValidator, RelOptPlanner relOptPlanner, EngineType engineType) {
RexBuilder rexBuilder = new RexBuilder(typeFactory);
RelOptCluster cluster = RelOptCluster.create(relOptPlanner, rexBuilder);
FrameworkConfig fromworkConfig =
Frameworks.newConfigBuilder()
.parserConfig(getParserConfig(engineType))
Frameworks.newConfigBuilder().parserConfig(getParserConfig(engineType))
.defaultSchema(
scope.getValidator().getCatalogReader().getRootSchema().plus())
.build();
return new SqlToRelConverter(
new ViewExpanderImpl(),
sqlValidator,
(CatalogReader) scope.getValidator().getCatalogReader(),
cluster,
fromworkConfig.getConvertletTable(),
getConverterConfig());
return new SqlToRelConverter(new ViewExpanderImpl(), sqlValidator,
(CatalogReader) scope.getValidator().getCatalogReader(), cluster,
fromworkConfig.getConvertletTable(), getConverterConfig());
}
public static SqlAdvisor getSqlAdvisor(SqlValidatorWithHints validator, EngineType engineType) {
@@ -159,15 +131,10 @@ public class Configuration {
public static SqlWriterConfig getSqlWriterConfig(EngineType engineType) {
SemanticSqlDialect sqlDialect = SqlDialectFactory.getSqlDialect(engineType);
SqlWriterConfig config =
SqlPrettyWriter.config()
.withDialect(sqlDialect)
.withKeywordsLowerCase(false)
.withClauseEndsLine(true)
.withAlwaysUseParentheses(false)
.withSelectListItemsOnSeparateLines(false)
.withUpdateSetListNewline(false)
.withIndentation(0);
SqlWriterConfig config = SqlPrettyWriter.config().withDialect(sqlDialect)
.withKeywordsLowerCase(false).withClauseEndsLine(true)
.withAlwaysUseParentheses(false).withSelectListItemsOnSeparateLines(false)
.withUpdateSetListNewline(false).withIndentation(0);
if (EngineType.MYSQL.equals(engineType)) {
// no backticks around function name
config = config.withQuoteAllIdentifiers(false);

View File

@@ -17,8 +17,8 @@ public class SemanticSqlDialect extends SqlDialect {
super(context);
}
public static void unparseFetchUsingAnsi(
SqlWriter writer, @Nullable SqlNode offset, @Nullable SqlNode fetch) {
public static void unparseFetchUsingAnsi(SqlWriter writer, @Nullable SqlNode offset,
@Nullable SqlNode fetch) {
Preconditions.checkArgument(fetch != null || offset != null);
SqlWriter.Frame fetchFrame;
writer.newlineAndIndent();
@@ -74,11 +74,11 @@ public class SemanticSqlDialect extends SqlDialect {
return true;
}
public void unparseSqlIntervalLiteral(
SqlWriter writer, SqlIntervalLiteral literal, int leftPrec, int rightPrec) {}
public void unparseSqlIntervalLiteral(SqlWriter writer, SqlIntervalLiteral literal,
int leftPrec, int rightPrec) {}
public void unparseOffsetFetch(
SqlWriter writer, @Nullable SqlNode offset, @Nullable SqlNode fetch) {
public void unparseOffsetFetch(SqlWriter writer, @Nullable SqlNode offset,
@Nullable SqlNode fetch) {
unparseFetchUsingAnsi(writer, offset, fetch);
}
}

View File

@@ -13,22 +13,14 @@ import java.util.Objects;
public class SqlDialectFactory {
public static final Context DEFAULT_CONTEXT =
SqlDialect.EMPTY_CONTEXT
.withDatabaseProduct(DatabaseProduct.BIG_QUERY)
.withLiteralQuoteString("'")
.withLiteralEscapedQuoteString("''")
.withIdentifierQuoteString("`")
.withUnquotedCasing(Casing.UNCHANGED)
.withQuotedCasing(Casing.UNCHANGED)
.withCaseSensitive(false);
public static final Context POSTGRESQL_CONTEXT =
SqlDialect.EMPTY_CONTEXT
.withDatabaseProduct(DatabaseProduct.BIG_QUERY)
.withLiteralQuoteString("'")
.withLiteralEscapedQuoteString("''")
.withUnquotedCasing(Casing.UNCHANGED)
.withQuotedCasing(Casing.UNCHANGED)
.withCaseSensitive(false);
SqlDialect.EMPTY_CONTEXT.withDatabaseProduct(DatabaseProduct.BIG_QUERY)
.withLiteralQuoteString("'").withLiteralEscapedQuoteString("''")
.withIdentifierQuoteString("`").withUnquotedCasing(Casing.UNCHANGED)
.withQuotedCasing(Casing.UNCHANGED).withCaseSensitive(false);
public static final Context POSTGRESQL_CONTEXT = SqlDialect.EMPTY_CONTEXT
.withDatabaseProduct(DatabaseProduct.BIG_QUERY).withLiteralQuoteString("'")
.withLiteralEscapedQuoteString("''").withUnquotedCasing(Casing.UNCHANGED)
.withQuotedCasing(Casing.UNCHANGED).withCaseSensitive(false);
private static Map<EngineType, SemanticSqlDialect> sqlDialectMap;
static {

View File

@@ -20,12 +20,8 @@ import java.util.List;
@Slf4j
public class SqlMergeWithUtils {
public static String mergeWith(
EngineType engineType,
String sql,
List<String> parentSqlList,
List<String> parentWithNameList)
throws SqlParseException {
public static String mergeWith(EngineType engineType, String sql, List<String> parentSqlList,
List<String> parentWithNameList) throws SqlParseException {
SqlParser.Config parserConfig = Configuration.getParserConfig(engineType);
// Parse the main SQL statement
@@ -45,14 +41,12 @@ public class SqlMergeWithUtils {
SqlNode sqlNode2 = parser.parseQuery();
// Create a new WITH item for parentWithName without quotes
SqlWithItem withItem =
new SqlWithItem(
SqlParserPos.ZERO,
new SqlIdentifier(
parentWithName, SqlParserPos.ZERO), // false to avoid quotes
null,
sqlNode2,
SqlLiteral.createBoolean(false, SqlParserPos.ZERO));
SqlWithItem withItem = new SqlWithItem(SqlParserPos.ZERO,
new SqlIdentifier(parentWithName, SqlParserPos.ZERO), // false
// to
// avoid
// quotes
null, sqlNode2, SqlLiteral.createBoolean(false, SqlParserPos.ZERO));
// Add the new WITH item to the list
withItemList.add(withItem);
@@ -66,11 +60,8 @@ public class SqlMergeWithUtils {
}
// Create a new SqlWith node
SqlWith finalSqlNode =
new SqlWith(
SqlParserPos.ZERO,
new SqlNodeList(withItemList, SqlParserPos.ZERO),
sqlNode1);
SqlWith finalSqlNode = new SqlWith(SqlParserPos.ZERO,
new SqlNodeList(withItemList, SqlParserPos.ZERO), sqlNode1);
// Custom SqlPrettyWriter configuration to avoid quoting identifiers
SqlWriterConfig config = Configuration.getSqlWriterConfig(engineType);
// Pretty print the final SQL

View File

@@ -45,10 +45,8 @@ public class SqlParseUtils {
sqlParserInfo.setAllFields(
sqlParserInfo.getAllFields().stream().distinct().collect(Collectors.toList()));
sqlParserInfo.setSelectFields(
sqlParserInfo.getSelectFields().stream()
.distinct()
.collect(Collectors.toList()));
sqlParserInfo.setSelectFields(sqlParserInfo.getSelectFields().stream().distinct()
.collect(Collectors.toList()));
return sqlParserInfo;
} catch (SqlParseException e) {
@@ -108,13 +106,10 @@ public class SqlParseUtils {
SqlSelect sqlSelect = (SqlSelect) select;
SqlNodeList selectList = sqlSelect.getSelectList();
selectList
.getList()
.forEach(
list -> {
Set<String> selectFields = handlerField(list);
sqlParserInfo.getSelectFields().addAll(selectFields);
});
selectList.getList().forEach(list -> {
Set<String> selectFields = handlerField(list);
sqlParserInfo.getSelectFields().addAll(selectFields);
});
String tableName = handlerFrom(sqlSelect.getFrom());
sqlParserInfo.setTableName(tableName);
@@ -129,14 +124,10 @@ public class SqlParseUtils {
results.addAll(formFields);
}
sqlSelect
.getSelectList()
.getList()
.forEach(
list -> {
Set<String> selectFields = handlerField(list);
results.addAll(selectFields);
});
sqlSelect.getSelectList().getList().forEach(list -> {
Set<String> selectFields = handlerField(list);
results.addAll(selectFields);
});
if (sqlSelect.hasWhere()) {
Set<String> whereFields = handlerField(sqlSelect.getWhere());
@@ -148,11 +139,10 @@ public class SqlParseUtils {
}
SqlNodeList group = sqlSelect.getGroup();
if (group != null) {
group.forEach(
groupField -> {
Set<String> groupByFields = handlerField(groupField);
results.addAll(groupByFields);
});
group.forEach(groupField -> {
Set<String> groupByFields = handlerField(groupField);
results.addAll(groupByFields);
});
}
return results;
}
@@ -213,12 +203,9 @@ public class SqlParseUtils {
}
}
if (field instanceof SqlNodeList) {
((SqlNodeList) field)
.getList()
.forEach(
node -> {
fields.addAll(handlerField(node));
});
((SqlNodeList) field).getList().forEach(node -> {
fields.addAll(handlerField(node));
});
}
break;
}
@@ -243,12 +230,9 @@ public class SqlParseUtils {
SqlIdentifier sqlIdentifier = (SqlIdentifier) operandList.get(0);
String simple = sqlIdentifier.getSimple();
SqlBasicCall aliasedNode =
new SqlBasicCall(
SqlStdOperatorTable.AS,
new SqlNode[] {
sqlBasicCall,
new SqlIdentifier(simple.toLowerCase(), SqlParserPos.ZERO)
},
new SqlBasicCall(SqlStdOperatorTable.AS,
new SqlNode[] {sqlBasicCall, new SqlIdentifier(
simple.toLowerCase(), SqlParserPos.ZERO)},
SqlParserPos.ZERO);
selectList.set(selectList.indexOf(node), aliasedNode);
}

View File

@@ -11,10 +11,7 @@ public class ViewExpanderImpl implements RelOptTable.ViewExpander {
public ViewExpanderImpl() {}
@Override
public RelRoot expandView(
RelDataType rowType,
String queryString,
List<String> schemaPath,
public RelRoot expandView(RelDataType rowType, String queryString, List<String> schemaPath,
List<String> dataSetPath) {
return null;
}

View File

@@ -20,98 +20,37 @@ import java.util.List;
@Slf4j
public class ChatModelParameterConfig extends ParameterConfig {
public static final Parameter CHAT_MODEL_PROVIDER =
new Parameter(
"s2.chat.model.provider",
OpenAiModelFactory.PROVIDER,
"接口协议",
"",
"list",
"对话模型配置",
getCandidateValues());
public static final Parameter CHAT_MODEL_PROVIDER = new Parameter("s2.chat.model.provider",
OpenAiModelFactory.PROVIDER, "接口协议", "", "list", "对话模型配置", getCandidateValues());
public static final Parameter CHAT_MODEL_BASE_URL =
new Parameter(
"s2.chat.model.base.url",
OpenAiModelFactory.DEFAULT_BASE_URL,
"BaseUrl",
"",
"string",
"对话模型配置",
null,
getBaseUrlDependency());
public static final Parameter CHAT_MODEL_ENDPOINT =
new Parameter(
"s2.chat.model.endpoint",
"llama_2_70b",
"Endpoint",
"",
"string",
"对话模型配置",
null,
getEndpointDependency());
public static final Parameter CHAT_MODEL_API_KEY =
new Parameter(
"s2.chat.model.api.key",
DEMO,
"ApiKey",
"",
"password",
"对话模型配置",
null,
getApiKeyDependency());
public static final Parameter CHAT_MODEL_SECRET_KEY =
new Parameter(
"s2.chat.model.secretKey",
"demo",
"SecretKey",
"",
"password",
"对话模型配置",
null,
getSecretKeyDependency());
new Parameter("s2.chat.model.base.url", OpenAiModelFactory.DEFAULT_BASE_URL, "BaseUrl",
"", "string", "对话模型配置", null, getBaseUrlDependency());
public static final Parameter CHAT_MODEL_ENDPOINT = new Parameter("s2.chat.model.endpoint",
"llama_2_70b", "Endpoint", "", "string", "对话模型配置", null, getEndpointDependency());
public static final Parameter CHAT_MODEL_API_KEY = new Parameter("s2.chat.model.api.key", DEMO,
"ApiKey", "", "password", "对话模型配置", null, getApiKeyDependency());
public static final Parameter CHAT_MODEL_SECRET_KEY = new Parameter("s2.chat.model.secretKey",
"demo", "SecretKey", "", "password", "对话模型配置", null, getSecretKeyDependency());
public static final Parameter CHAT_MODEL_NAME =
new Parameter(
"s2.chat.model.name",
"gpt-4o-mini",
"ModelName",
"",
"string",
"对话模型配置",
null,
getModelNameDependency());
public static final Parameter CHAT_MODEL_NAME = new Parameter("s2.chat.model.name",
"gpt-4o-mini", "ModelName", "", "string", "对话模型配置", null, getModelNameDependency());
public static final Parameter CHAT_MODEL_ENABLE_SEARCH =
new Parameter(
"s2.chat.model.enableSearch",
"false",
"是否启用搜索增强功能设为false表示不启用",
"",
"bool",
"对话模型配置",
null,
getEnableSearchDependency());
new Parameter("s2.chat.model.enableSearch", "false", "是否启用搜索增强功能设为false表示不启用", "",
"bool", "对话模型配置", null, getEnableSearchDependency());
public static final Parameter CHAT_MODEL_TEMPERATURE =
new Parameter(
"s2.chat.model.temperature", "0.0", "Temperature", "", "slider", "对话模型配置");
public static final Parameter CHAT_MODEL_TEMPERATURE = new Parameter(
"s2.chat.model.temperature", "0.0", "Temperature", "", "slider", "对话模型配置");
public static final Parameter CHAT_MODEL_TIMEOUT =
new Parameter("s2.chat.model.timeout", "60", "超时时间(秒)", "", "number", "对话模型配置");
@Override
public List<Parameter> getSysParameters() {
return Lists.newArrayList(
CHAT_MODEL_PROVIDER,
CHAT_MODEL_BASE_URL,
CHAT_MODEL_ENDPOINT,
CHAT_MODEL_API_KEY,
CHAT_MODEL_SECRET_KEY,
CHAT_MODEL_NAME,
CHAT_MODEL_ENABLE_SEARCH,
CHAT_MODEL_TEMPERATURE,
CHAT_MODEL_TIMEOUT);
return Lists.newArrayList(CHAT_MODEL_PROVIDER, CHAT_MODEL_BASE_URL, CHAT_MODEL_ENDPOINT,
CHAT_MODEL_API_KEY, CHAT_MODEL_SECRET_KEY, CHAT_MODEL_NAME,
CHAT_MODEL_ENABLE_SEARCH, CHAT_MODEL_TEMPERATURE, CHAT_MODEL_TIMEOUT);
}
public ChatModelConfig convert() {
@@ -125,36 +64,24 @@ public class ChatModelParameterConfig extends ParameterConfig {
String secretKey = getParameterValue(CHAT_MODEL_SECRET_KEY);
String enableSearch = getParameterValue(CHAT_MODEL_ENABLE_SEARCH);
return ChatModelConfig.builder()
.provider(chatModelProvider)
.baseUrl(chatModelBaseUrl)
.apiKey(chatModelApiKey)
.modelName(chatModelName)
return ChatModelConfig.builder().provider(chatModelProvider).baseUrl(chatModelBaseUrl)
.apiKey(chatModelApiKey).modelName(chatModelName)
.enableSearch(Boolean.valueOf(enableSearch))
.temperature(Double.valueOf(chatModelTemperature))
.timeOut(Long.valueOf(chatModelTimeout))
.endpoint(endpoint)
.secretKey(secretKey)
.timeOut(Long.valueOf(chatModelTimeout)).endpoint(endpoint).secretKey(secretKey)
.build();
}
private static List<String> getCandidateValues() {
return Lists.newArrayList(
OpenAiModelFactory.PROVIDER,
OllamaModelFactory.PROVIDER,
QianfanModelFactory.PROVIDER,
ZhipuModelFactory.PROVIDER,
LocalAiModelFactory.PROVIDER,
DashscopeModelFactory.PROVIDER,
return Lists.newArrayList(OpenAiModelFactory.PROVIDER, OllamaModelFactory.PROVIDER,
QianfanModelFactory.PROVIDER, ZhipuModelFactory.PROVIDER,
LocalAiModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER,
AzureModelFactory.PROVIDER);
}
private static List<Parameter.Dependency> getBaseUrlDependency() {
return getDependency(
CHAT_MODEL_PROVIDER.getName(),
getCandidateValues(),
ImmutableMap.of(
OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_BASE_URL,
return getDependency(CHAT_MODEL_PROVIDER.getName(), getCandidateValues(),
ImmutableMap.of(OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_BASE_URL,
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_BASE_URL,
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_BASE_URL,
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_BASE_URL,
@@ -164,30 +91,18 @@ public class ChatModelParameterConfig extends ParameterConfig {
}
private static List<Parameter.Dependency> getApiKeyDependency() {
return getDependency(
CHAT_MODEL_PROVIDER.getName(),
Lists.newArrayList(
OpenAiModelFactory.PROVIDER,
QianfanModelFactory.PROVIDER,
ZhipuModelFactory.PROVIDER,
LocalAiModelFactory.PROVIDER,
AzureModelFactory.PROVIDER,
DashscopeModelFactory.PROVIDER),
ImmutableMap.of(
OpenAiModelFactory.PROVIDER, DEMO,
QianfanModelFactory.PROVIDER, DEMO,
ZhipuModelFactory.PROVIDER, DEMO,
LocalAiModelFactory.PROVIDER, DEMO,
AzureModelFactory.PROVIDER, DEMO,
DashscopeModelFactory.PROVIDER, DEMO));
return getDependency(CHAT_MODEL_PROVIDER.getName(),
Lists.newArrayList(OpenAiModelFactory.PROVIDER, QianfanModelFactory.PROVIDER,
ZhipuModelFactory.PROVIDER, LocalAiModelFactory.PROVIDER,
AzureModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER),
ImmutableMap.of(OpenAiModelFactory.PROVIDER, DEMO, QianfanModelFactory.PROVIDER,
DEMO, ZhipuModelFactory.PROVIDER, DEMO, LocalAiModelFactory.PROVIDER, DEMO,
AzureModelFactory.PROVIDER, DEMO, DashscopeModelFactory.PROVIDER, DEMO));
}
private static List<Parameter.Dependency> getModelNameDependency() {
return getDependency(
CHAT_MODEL_PROVIDER.getName(),
getCandidateValues(),
ImmutableMap.of(
OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_MODEL_NAME,
return getDependency(CHAT_MODEL_PROVIDER.getName(), getCandidateValues(),
ImmutableMap.of(OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_MODEL_NAME,
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_MODEL_NAME,
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_MODEL_NAME,
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_MODEL_NAME,
@@ -197,23 +112,19 @@ public class ChatModelParameterConfig extends ParameterConfig {
}
private static List<Parameter.Dependency> getEndpointDependency() {
return getDependency(
CHAT_MODEL_PROVIDER.getName(),
Lists.newArrayList(QianfanModelFactory.PROVIDER),
ImmutableMap.of(
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_ENDPOINT));
return getDependency(CHAT_MODEL_PROVIDER.getName(),
Lists.newArrayList(QianfanModelFactory.PROVIDER), ImmutableMap
.of(QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_ENDPOINT));
}
private static List<Parameter.Dependency> getEnableSearchDependency() {
return getDependency(
CHAT_MODEL_PROVIDER.getName(),
return getDependency(CHAT_MODEL_PROVIDER.getName(),
Lists.newArrayList(DashscopeModelFactory.PROVIDER),
ImmutableMap.of(DashscopeModelFactory.PROVIDER, "false"));
}
private static List<Parameter.Dependency> getSecretKeyDependency() {
return getDependency(
CHAT_MODEL_PROVIDER.getName(),
return getDependency(CHAT_MODEL_PROVIDER.getName(),
Lists.newArrayList(QianfanModelFactory.PROVIDER),
ImmutableMap.of(QianfanModelFactory.PROVIDER, DEMO));
}

View File

@@ -22,89 +22,35 @@ import java.util.List;
@Slf4j
public class EmbeddingModelParameterConfig extends ParameterConfig {
public static final Parameter EMBEDDING_MODEL_PROVIDER =
new Parameter(
"s2.embedding.model.provider",
InMemoryModelFactory.PROVIDER,
"接口协议",
"",
"list",
"向量模型配置",
getCandidateValues());
new Parameter("s2.embedding.model.provider", InMemoryModelFactory.PROVIDER, "接口协议", "",
"list", "向量模型配置", getCandidateValues());
public static final Parameter EMBEDDING_MODEL_BASE_URL =
new Parameter(
"s2.embedding.model.base.url",
"",
"BaseUrl",
"",
"string",
"向量模型配置",
null,
getBaseUrlDependency());
new Parameter("s2.embedding.model.base.url", "", "BaseUrl", "", "string", "向量模型配置",
null, getBaseUrlDependency());
public static final Parameter EMBEDDING_MODEL_API_KEY =
new Parameter(
"s2.embedding.model.api.key",
"",
"ApiKey",
"",
"password",
"向量模型配置",
null,
getApiKeyDependency());
new Parameter("s2.embedding.model.api.key", "", "ApiKey", "", "password", "向量模型配置",
null, getApiKeyDependency());
public static final Parameter EMBEDDING_MODEL_SECRET_KEY =
new Parameter(
"s2.embedding.model.secretKey",
"demo",
"SecretKey",
"",
"password",
"向量模型配置",
null,
getSecretKeyDependency());
new Parameter("s2.embedding.model.secretKey", "demo", "SecretKey", "", "password",
"向量模型配置", null, getSecretKeyDependency());
public static final Parameter EMBEDDING_MODEL_NAME =
new Parameter(
"s2.embedding.model.name",
EmbeddingModelConstant.BGE_SMALL_ZH,
"ModelName",
"",
"string",
"向量模型配置",
null,
getModelNameDependency());
new Parameter("s2.embedding.model.name", EmbeddingModelConstant.BGE_SMALL_ZH,
"ModelName", "", "string", "向量模型配置", null, getModelNameDependency());
public static final Parameter EMBEDDING_MODEL_PATH =
new Parameter(
"s2.embedding.model.path",
"",
"模型路径",
"",
"string",
"向量模型配置",
null,
getModelPathDependency());
public static final Parameter EMBEDDING_MODEL_PATH = new Parameter("s2.embedding.model.path",
"", "模型路径", "", "string", "向量模型配置", null, getModelPathDependency());
public static final Parameter EMBEDDING_MODEL_VOCABULARY_PATH =
new Parameter(
"s2.embedding.model.vocabulary.path",
"",
"词汇表路径",
"",
"string",
"向量模型配置",
null,
getModelPathDependency());
new Parameter("s2.embedding.model.vocabulary.path", "", "词汇表路径", "", "string", "向量模型配置",
null, getModelPathDependency());
@Override
public List<Parameter> getSysParameters() {
return Lists.newArrayList(
EMBEDDING_MODEL_PROVIDER,
EMBEDDING_MODEL_BASE_URL,
EMBEDDING_MODEL_API_KEY,
EMBEDDING_MODEL_SECRET_KEY,
EMBEDDING_MODEL_NAME,
EMBEDDING_MODEL_PATH,
EMBEDDING_MODEL_VOCABULARY_PATH);
return Lists.newArrayList(EMBEDDING_MODEL_PROVIDER, EMBEDDING_MODEL_BASE_URL,
EMBEDDING_MODEL_API_KEY, EMBEDDING_MODEL_SECRET_KEY, EMBEDDING_MODEL_NAME,
EMBEDDING_MODEL_PATH, EMBEDDING_MODEL_VOCABULARY_PATH);
}
public EmbeddingModelConfig convert() {
@@ -115,40 +61,24 @@ public class EmbeddingModelParameterConfig extends ParameterConfig {
String modelPath = getParameterValue(EMBEDDING_MODEL_PATH);
String vocabularyPath = getParameterValue(EMBEDDING_MODEL_VOCABULARY_PATH);
String secretKey = getParameterValue(EMBEDDING_MODEL_SECRET_KEY);
return EmbeddingModelConfig.builder()
.provider(provider)
.baseUrl(baseUrl)
.apiKey(apiKey)
.secretKey(secretKey)
.modelName(modelName)
.modelPath(modelPath)
.vocabularyPath(vocabularyPath)
.build();
return EmbeddingModelConfig.builder().provider(provider).baseUrl(baseUrl).apiKey(apiKey)
.secretKey(secretKey).modelName(modelName).modelPath(modelPath)
.vocabularyPath(vocabularyPath).build();
}
private static ArrayList<String> getCandidateValues() {
return Lists.newArrayList(
InMemoryModelFactory.PROVIDER,
OpenAiModelFactory.PROVIDER,
OllamaModelFactory.PROVIDER,
DashscopeModelFactory.PROVIDER,
QianfanModelFactory.PROVIDER,
ZhipuModelFactory.PROVIDER,
return Lists.newArrayList(InMemoryModelFactory.PROVIDER, OpenAiModelFactory.PROVIDER,
OllamaModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER,
QianfanModelFactory.PROVIDER, ZhipuModelFactory.PROVIDER,
AzureModelFactory.PROVIDER);
}
private static List<Parameter.Dependency> getBaseUrlDependency() {
return getDependency(
EMBEDDING_MODEL_PROVIDER.getName(),
Lists.newArrayList(
OpenAiModelFactory.PROVIDER,
OllamaModelFactory.PROVIDER,
AzureModelFactory.PROVIDER,
DashscopeModelFactory.PROVIDER,
QianfanModelFactory.PROVIDER,
ZhipuModelFactory.PROVIDER),
ImmutableMap.of(
OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_BASE_URL,
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
Lists.newArrayList(OpenAiModelFactory.PROVIDER, OllamaModelFactory.PROVIDER,
AzureModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER,
QianfanModelFactory.PROVIDER, ZhipuModelFactory.PROVIDER),
ImmutableMap.of(OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_BASE_URL,
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_BASE_URL,
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_BASE_URL,
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_BASE_URL,
@@ -157,63 +87,43 @@ public class EmbeddingModelParameterConfig extends ParameterConfig {
}
private static List<Parameter.Dependency> getApiKeyDependency() {
return getDependency(
EMBEDDING_MODEL_PROVIDER.getName(),
Lists.newArrayList(
OpenAiModelFactory.PROVIDER,
AzureModelFactory.PROVIDER,
DashscopeModelFactory.PROVIDER,
QianfanModelFactory.PROVIDER,
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
Lists.newArrayList(OpenAiModelFactory.PROVIDER, AzureModelFactory.PROVIDER,
DashscopeModelFactory.PROVIDER, QianfanModelFactory.PROVIDER,
ZhipuModelFactory.PROVIDER),
ImmutableMap.of(
OpenAiModelFactory.PROVIDER,
DEMO,
AzureModelFactory.PROVIDER,
DEMO,
DashscopeModelFactory.PROVIDER,
DEMO,
QianfanModelFactory.PROVIDER,
DEMO,
ZhipuModelFactory.PROVIDER,
DEMO));
ImmutableMap.of(OpenAiModelFactory.PROVIDER, DEMO, AzureModelFactory.PROVIDER, DEMO,
DashscopeModelFactory.PROVIDER, DEMO, QianfanModelFactory.PROVIDER, DEMO,
ZhipuModelFactory.PROVIDER, DEMO));
}
private static List<Parameter.Dependency> getModelNameDependency() {
return getDependency(
EMBEDDING_MODEL_PROVIDER.getName(),
Lists.newArrayList(
InMemoryModelFactory.PROVIDER,
OpenAiModelFactory.PROVIDER,
OllamaModelFactory.PROVIDER,
AzureModelFactory.PROVIDER,
DashscopeModelFactory.PROVIDER,
QianfanModelFactory.PROVIDER,
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
Lists.newArrayList(InMemoryModelFactory.PROVIDER, OpenAiModelFactory.PROVIDER,
OllamaModelFactory.PROVIDER, AzureModelFactory.PROVIDER,
DashscopeModelFactory.PROVIDER, QianfanModelFactory.PROVIDER,
ZhipuModelFactory.PROVIDER),
ImmutableMap.of(
InMemoryModelFactory.PROVIDER, EmbeddingModelConstant.BGE_SMALL_ZH,
ImmutableMap.of(InMemoryModelFactory.PROVIDER, EmbeddingModelConstant.BGE_SMALL_ZH,
OpenAiModelFactory.PROVIDER,
OpenAiModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
OpenAiModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
OllamaModelFactory.PROVIDER,
OllamaModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
OllamaModelFactory.DEFAULT_EMBEDDING_MODEL_NAME, AzureModelFactory.PROVIDER,
AzureModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
DashscopeModelFactory.PROVIDER,
DashscopeModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
DashscopeModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
QianfanModelFactory.PROVIDER,
QianfanModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
QianfanModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
ZhipuModelFactory.PROVIDER,
ZhipuModelFactory.DEFAULT_EMBEDDING_MODEL_NAME));
ZhipuModelFactory.DEFAULT_EMBEDDING_MODEL_NAME));
}
private static List<Parameter.Dependency> getModelPathDependency() {
return getDependency(
EMBEDDING_MODEL_PROVIDER.getName(),
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
Lists.newArrayList(InMemoryModelFactory.PROVIDER),
ImmutableMap.of(InMemoryModelFactory.PROVIDER, ""));
}
private static List<Parameter.Dependency> getSecretKeyDependency() {
return getDependency(
EMBEDDING_MODEL_PROVIDER.getName(),
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
Lists.newArrayList(QianfanModelFactory.PROVIDER),
ImmutableMap.of(QianfanModelFactory.PROVIDER, DEMO));
}

View File

@@ -15,83 +15,38 @@ import java.util.List;
@Service("EmbeddingStoreParameterConfig")
@Slf4j
public class EmbeddingStoreParameterConfig extends ParameterConfig {
public static final Parameter EMBEDDING_STORE_PROVIDER =
new Parameter(
"s2.embedding.store.provider",
EmbeddingStoreType.IN_MEMORY.name(),
"向量库类型",
"目前支持三种类型IN_MEMORY、MILVUS、CHROMA",
"list",
"向量库配置",
getCandidateValues());
public static final Parameter EMBEDDING_STORE_PROVIDER = new Parameter(
"s2.embedding.store.provider", EmbeddingStoreType.IN_MEMORY.name(), "向量库类型",
"目前支持三种类型IN_MEMORY、MILVUS、CHROMA", "list", "向量库配置", getCandidateValues());
public static final Parameter EMBEDDING_STORE_BASE_URL =
new Parameter(
"s2.embedding.store.base.url",
"",
"BaseUrl",
"",
"string",
"向量库配置",
null,
new Parameter("s2.embedding.store.base.url", "", "BaseUrl", "", "string", "向量库配置", null,
getBaseUrlDependency());
public static final Parameter EMBEDDING_STORE_API_KEY =
new Parameter(
"s2.embedding.store.api.key",
"",
"ApiKey",
"",
"password",
"向量库配置",
null,
new Parameter("s2.embedding.store.api.key", "", "ApiKey", "", "password", "向量库配置", null,
getApiKeyDependency());
public static final Parameter EMBEDDING_STORE_PERSIST_PATH =
new Parameter(
"s2.embedding.store.persist.path",
"",
"持久化路径",
"默认不持久化,如需持久化请填写持久化路径。" + "注意:如果变更了向量模型需删除该路径下已保存的文件或修改持久化路径",
"string",
"向量库配置",
null,
getPathDependency());
new Parameter("s2.embedding.store.persist.path", "", "持久化路径",
"默认不持久化,如需持久化请填写持久化路径。" + "注意:如果变更了向量模型需删除该路径下已保存的文件或修改持久化路径", "string",
"向量库配置", null, getPathDependency());
public static final Parameter EMBEDDING_STORE_TIMEOUT =
new Parameter("s2.embedding.store.timeout", "60", "超时时间(秒)", "", "number", "向量库配置");
public static final Parameter EMBEDDING_STORE_DIMENSION =
new Parameter(
"s2.embedding.store.dimension",
"",
"纬度",
"",
"number",
"向量库配置",
null,
new Parameter("s2.embedding.store.dimension", "", "纬度", "", "number", "向量库配置", null,
getDimensionDependency());
public static final Parameter EMBEDDING_STORE_DATABASE_NAME =
new Parameter(
"s2.embedding.store.databaseName",
"",
"DatabaseName",
"",
"string",
"向量库配置",
null,
getDatabaseNameDependency());
new Parameter("s2.embedding.store.databaseName", "", "DatabaseName", "", "string",
"向量库配置", null, getDatabaseNameDependency());
@Override
public List<Parameter> getSysParameters() {
return Lists.newArrayList(
EMBEDDING_STORE_PROVIDER,
EMBEDDING_STORE_BASE_URL,
EMBEDDING_STORE_API_KEY,
EMBEDDING_STORE_DATABASE_NAME,
EMBEDDING_STORE_PERSIST_PATH,
EMBEDDING_STORE_TIMEOUT,
EMBEDDING_STORE_DIMENSION);
return Lists.newArrayList(EMBEDDING_STORE_PROVIDER, EMBEDDING_STORE_BASE_URL,
EMBEDDING_STORE_API_KEY, EMBEDDING_STORE_DATABASE_NAME,
EMBEDDING_STORE_PERSIST_PATH, EMBEDDING_STORE_TIMEOUT, EMBEDDING_STORE_DIMENSION);
}
public EmbeddingStoreConfig convert() {
@@ -105,58 +60,44 @@ public class EmbeddingStoreParameterConfig extends ParameterConfig {
if (StringUtils.isNumeric(getParameterValue(EMBEDDING_STORE_DIMENSION))) {
dimension = Integer.valueOf(getParameterValue(EMBEDDING_STORE_DIMENSION));
}
return EmbeddingStoreConfig.builder()
.provider(provider)
.baseUrl(baseUrl)
.apiKey(apiKey)
.persistPath(persistPath)
.databaseName(databaseName)
.timeOut(Long.valueOf(timeOut))
.dimension(dimension)
.build();
return EmbeddingStoreConfig.builder().provider(provider).baseUrl(baseUrl).apiKey(apiKey)
.persistPath(persistPath).databaseName(databaseName).timeOut(Long.valueOf(timeOut))
.dimension(dimension).build();
}
private static ArrayList<String> getCandidateValues() {
return Lists.newArrayList(
EmbeddingStoreType.IN_MEMORY.name(),
EmbeddingStoreType.MILVUS.name(),
EmbeddingStoreType.CHROMA.name());
return Lists.newArrayList(EmbeddingStoreType.IN_MEMORY.name(),
EmbeddingStoreType.MILVUS.name(), EmbeddingStoreType.CHROMA.name());
}
private static List<Parameter.Dependency> getBaseUrlDependency() {
return getDependency(
EMBEDDING_STORE_PROVIDER.getName(),
Lists.newArrayList(
EmbeddingStoreType.MILVUS.name(), EmbeddingStoreType.CHROMA.name()),
ImmutableMap.of(
EmbeddingStoreType.MILVUS.name(), "http://localhost:19530",
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
Lists.newArrayList(EmbeddingStoreType.MILVUS.name(),
EmbeddingStoreType.CHROMA.name()),
ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "http://localhost:19530",
EmbeddingStoreType.CHROMA.name(), "http://localhost:8000"));
}
private static List<Parameter.Dependency> getApiKeyDependency() {
return getDependency(
EMBEDDING_STORE_PROVIDER.getName(),
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
Lists.newArrayList(EmbeddingStoreType.MILVUS.name()),
ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), DEMO));
}
private static List<Parameter.Dependency> getPathDependency() {
return getDependency(
EMBEDDING_STORE_PROVIDER.getName(),
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
Lists.newArrayList(EmbeddingStoreType.IN_MEMORY.name()),
ImmutableMap.of(EmbeddingStoreType.IN_MEMORY.name(), ""));
}
private static List<Parameter.Dependency> getDimensionDependency() {
return getDependency(
EMBEDDING_STORE_PROVIDER.getName(),
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
Lists.newArrayList(EmbeddingStoreType.MILVUS.name()),
ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "384"));
}
private static List<Parameter.Dependency> getDatabaseNameDependency() {
return getDependency(
EMBEDDING_STORE_PROVIDER.getName(),
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
Lists.newArrayList(EmbeddingStoreType.MILVUS.name()),
ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), ""));
}

View File

@@ -15,9 +15,11 @@ import java.util.Map;
@Service
public abstract class ParameterConfig {
public static final String DEMO = "demo";
@Autowired private SystemConfigService sysConfigService;
@Autowired
private SystemConfigService sysConfigService;
@Autowired private Environment environment;
@Autowired
private Environment environment;
/** @return system parameters to be set with user interface */
protected List<Parameter> getSysParameters() {
@@ -46,10 +48,8 @@ public abstract class ParameterConfig {
return value;
}
protected static List<Parameter.Dependency> getDependency(
String dependencyParameterName,
List<String> includesValue,
Map<String, String> setDefaultValue) {
protected static List<Parameter.Dependency> getDependency(String dependencyParameterName,
List<String> includesValue, Map<String, String> setDefaultValue) {
Parameter.Dependency.Show show = new Parameter.Dependency.Show();
show.setIncludesValue(includesValue);

View File

@@ -38,11 +38,8 @@ public class SystemConfig {
if (StringUtils.isBlank(name)) {
return "";
}
Map<String, String> nameToValue =
getParameters().stream()
.collect(
Collectors.toMap(
Parameter::getName, Parameter::getValue, (k1, k2) -> k1));
Map<String, String> nameToValue = getParameters().stream()
.collect(Collectors.toMap(Parameter::getName, Parameter::getValue, (k1, k2) -> k1));
return nameToValue.get(name);
}
@@ -69,15 +66,11 @@ public class SystemConfig {
if (CollectionUtils.isEmpty(parameters)) {
return defaultParameters;
}
Map<String, String> parameterNameValueMap =
parameters.stream()
.collect(
Collectors.toMap(
Parameter::getName, Parameter::getValue, (v1, v2) -> v2));
Map<String, String> parameterNameValueMap = parameters.stream()
.collect(Collectors.toMap(Parameter::getName, Parameter::getValue, (v1, v2) -> v2));
for (Parameter parameter : defaultParameters) {
parameter.setValue(
parameterNameValueMap.getOrDefault(
parameter.getName(), parameter.getDefaultValue()));
parameter.setValue(parameterNameValueMap.getOrDefault(parameter.getName(),
parameter.getDefaultValue()));
}
return defaultParameters;
}

View File

@@ -14,8 +14,8 @@ import org.springframework.web.servlet.ModelAndView;
@Slf4j
public class LogInterceptor implements HandlerInterceptor {
@Override
public boolean preHandle(
HttpServletRequest request, HttpServletResponse response, Object handler) {
public boolean preHandle(HttpServletRequest request, HttpServletResponse response,
Object handler) {
// use previous traceId
String traceId = request.getHeader(TraceIdUtil.TRACE_ID);
if (StringUtils.isBlank(traceId)) {
@@ -27,17 +27,12 @@ public class LogInterceptor implements HandlerInterceptor {
}
@Override
public void postHandle(
HttpServletRequest request,
HttpServletResponse response,
Object handler,
ModelAndView modelAndView)
throws Exception {}
public void postHandle(HttpServletRequest request, HttpServletResponse response, Object handler,
ModelAndView modelAndView) throws Exception {}
@Override
public void afterCompletion(
HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex)
throws Exception {
public void afterCompletion(HttpServletRequest request, HttpServletResponse response,
Object handler, Exception ex) throws Exception {
// remove after Completing
TraceIdUtil.remove();
}

View File

@@ -5,13 +5,9 @@ import java.util.Map;
import java.util.stream.Collectors;
public enum AggregateEnum {
MOST("最多", "max"),
HIGHEST("", "max"),
MAXIMUN("最大", "max"),
LEAST("最少", "min"),
SMALLEST("最小", "min"),
LOWEST("最低", "min"),
AVERAGE("平均", "avg");
MOST("最多", "max"), HIGHEST("最高", "max"), MAXIMUN("最大", "max"), LEAST("最少",
"min"), SMALLEST("", "min"), LOWEST("最低", "min"), AVERAGE("平均", "avg");
private String aggregateCh;
private String aggregateEN;
@@ -29,9 +25,7 @@ public enum AggregateEnum {
}
public static Map<String, String> getAggregateEnum() {
return Arrays.stream(AggregateEnum.values())
.collect(
Collectors.toMap(
AggregateEnum::getAggregateCh, AggregateEnum::getAggregateEN));
return Arrays.stream(AggregateEnum.values()).collect(
Collectors.toMap(AggregateEnum::getAggregateCh, AggregateEnum::getAggregateEN));
}
}

View File

@@ -15,8 +15,8 @@ public class CustomExpressionDeParser extends ExpressionDeParser {
private boolean dealNull;
private boolean dealNotNull;
public CustomExpressionDeParser(
Set<String> removeFieldNames, boolean dealNull, boolean dealNotNull) {
public CustomExpressionDeParser(Set<String> removeFieldNames, boolean dealNull,
boolean dealNotNull) {
this.removeFieldNames = removeFieldNames;
this.dealNull = dealNull;
this.dealNotNull = dealNotNull;
@@ -45,12 +45,10 @@ public class CustomExpressionDeParser extends ExpressionDeParser {
Expression leftExpression = ((AndExpression) binaryExpression).getLeftExpression();
Expression rightExpression = ((AndExpression) binaryExpression).getRightExpression();
boolean leftIsNull =
leftExpression instanceof IsNullExpression
&& shouldSkip((IsNullExpression) leftExpression);
boolean rightIsNull =
rightExpression instanceof IsNullExpression
&& shouldSkip((IsNullExpression) rightExpression);
boolean leftIsNull = leftExpression instanceof IsNullExpression
&& shouldSkip((IsNullExpression) leftExpression);
boolean rightIsNull = rightExpression instanceof IsNullExpression
&& shouldSkip((IsNullExpression) rightExpression);
if (leftIsNull && rightIsNull) {
// Skip both expressions

View File

@@ -13,8 +13,8 @@ import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
@Slf4j
public class DateFunctionHelper {
public static String getStartDateStr(
ComparisonOperator minorThanEquals, ExpressionList<?> expressions) {
public static String getStartDateStr(ComparisonOperator minorThanEquals,
ExpressionList<?> expressions) {
String unitValue = getUnit(expressions);
String dateValue = getEndDateValue(expressions);
String dateStr = "";

View File

@@ -23,9 +23,8 @@ public class ExpressionReplaceVisitor extends ExpressionVisitorAdapter {
expr.getWhenExpression().accept(this);
if (expr.getThenExpression() instanceof Column) {
Column column = (Column) expr.getThenExpression();
Expression expression =
QueryExpressionReplaceVisitor.getExpression(
QueryExpressionReplaceVisitor.getReplaceExpr(column, fieldExprMap));
Expression expression = QueryExpressionReplaceVisitor.getExpression(
QueryExpressionReplaceVisitor.getReplaceExpr(column, fieldExprMap));
if (Objects.nonNull(expression)) {
expr.setThenExpression(expression);
}
@@ -52,20 +51,16 @@ public class ExpressionReplaceVisitor extends ExpressionVisitorAdapter {
}
}
if (left instanceof Column) {
Expression expression =
QueryExpressionReplaceVisitor.getExpression(
QueryExpressionReplaceVisitor.getReplaceExpr(
(Column) left, fieldExprMap));
Expression expression = QueryExpressionReplaceVisitor.getExpression(
QueryExpressionReplaceVisitor.getReplaceExpr((Column) left, fieldExprMap));
if (Objects.nonNull(expression)) {
expr.setLeftExpression(expression);
leftVisited = true;
}
}
if (right instanceof Column) {
Expression expression =
QueryExpressionReplaceVisitor.getExpression(
QueryExpressionReplaceVisitor.getReplaceExpr(
(Column) right, fieldExprMap));
Expression expression = QueryExpressionReplaceVisitor.getExpression(
QueryExpressionReplaceVisitor.getReplaceExpr((Column) right, fieldExprMap));
if (Objects.nonNull(expression)) {
expr.setRightExpression(expression);
rightVisited = true;
@@ -81,9 +76,8 @@ public class ExpressionReplaceVisitor extends ExpressionVisitorAdapter {
private boolean visitFunction(Function function) {
if (function.getParameters().getExpressions().get(0) instanceof Column) {
Expression expression =
QueryExpressionReplaceVisitor.getExpression(
QueryExpressionReplaceVisitor.getReplaceExpr(function, fieldExprMap));
Expression expression = QueryExpressionReplaceVisitor.getExpression(
QueryExpressionReplaceVisitor.getReplaceExpr(function, fieldExprMap));
if (Objects.nonNull(expression)) {
ExpressionList<Expression> expressions = new ExpressionList<>();
expressions.add(expression);

View File

@@ -130,8 +130,8 @@ public class FieldAndValueAcquireVisitor extends ExpressionVisitorAdapter {
Arrays.stream(DatePeriodEnum.values()).collect(Collectors.toList());
DatePeriodEnum periodEnum = DatePeriodEnum.get(functionName);
if (Objects.nonNull(periodEnum) && collect.contains(periodEnum)) {
fieldExpression.setFieldValue(
getFieldValue(rightExpression) + periodEnum.getChName());
fieldExpression
.setFieldValue(getFieldValue(rightExpression) + periodEnum.getChName());
return fieldExpression;
} else {
// deal with aggregate function

View File

@@ -31,8 +31,8 @@ public class FieldValueReplaceVisitor extends ExpressionVisitorAdapter {
private boolean exactReplace;
private Map<String, Map<String, String>> filedNameToValueMap;
public FieldValueReplaceVisitor(
boolean exactReplace, Map<String, Map<String, String>> filedNameToValueMap) {
public FieldValueReplaceVisitor(boolean exactReplace,
Map<String, Map<String, String>> filedNameToValueMap) {
this.exactReplace = exactReplace;
this.filedNameToValueMap = filedNameToValueMap;
}
@@ -67,24 +67,20 @@ public class FieldValueReplaceVisitor extends ExpressionVisitorAdapter {
ExpressionList rightItemsList = (ExpressionList) inExpression.getRightExpression();
List<Expression> expressions = rightItemsList.getExpressions();
List<String> values = new ArrayList<>();
expressions.stream()
.forEach(
o -> {
if (o instanceof StringValue) {
values.add(((StringValue) o).getValue());
}
});
expressions.stream().forEach(o -> {
if (o instanceof StringValue) {
values.add(((StringValue) o).getValue());
}
});
if (valueMap == null || CollectionUtils.isEmpty(values)) {
return;
}
List<Expression> newExpressions = new ArrayList<>();
values.stream()
.forEach(
o -> {
String replaceValue = valueMap.getOrDefault(o, o);
StringValue stringValue = new StringValue(replaceValue);
newExpressions.add(stringValue);
});
values.stream().forEach(o -> {
String replaceValue = valueMap.getOrDefault(o, o);
StringValue stringValue = new StringValue(replaceValue);
newExpressions.add(stringValue);
});
rightItemsList.setExpressions(newExpressions);
inExpression.setRightExpression(rightItemsList);
}

View File

@@ -34,11 +34,9 @@ public class FiledNameReplaceVisitor extends ExpressionVisitorAdapter {
Expression leftExpression = expr.getLeftExpression();
Expression rightExpression = expr.getRightExpression();
if (!(rightExpression instanceof StringValue)
|| !(leftExpression instanceof Column)
if (!(rightExpression instanceof StringValue) || !(leftExpression instanceof Column)
|| CollectionUtils.isEmpty(fieldValueToFieldNames)
|| Objects.isNull(rightExpression)
|| Objects.isNull(leftExpression)) {
|| Objects.isNull(rightExpression) || Objects.isNull(leftExpression)) {
return;
}

View File

@@ -21,8 +21,8 @@ public class FunctionAliasReplaceVisitor extends SelectItemVisitorAdapter {
// 2.alias's fieldName not equal. "sum(pv) as pv" cannot be replaced.
if (Objects.nonNull(selectExpressionItem.getAlias())
&& !selectExpressionItem.getAlias().getName().equalsIgnoreCase(columnName)) {
aliasToActualExpression.put(
selectExpressionItem.getAlias().getName(), function.toString());
aliasToActualExpression.put(selectExpressionItem.getAlias().getName(),
function.toString());
selectExpressionItem.setAlias(null);
}
}

View File

@@ -16,8 +16,8 @@ public class FunctionNameReplaceVisitor extends ExpressionVisitorAdapter {
private Map<String, String> functionMap;
private Map<String, UnaryOperator> functionCallMap;
public FunctionNameReplaceVisitor(
Map<String, String> functionMap, Map<String, UnaryOperator> functionCallMap) {
public FunctionNameReplaceVisitor(Map<String, String> functionMap,
Map<String, UnaryOperator> functionCallMap) {
this.functionMap = functionMap;
this.functionCallMap = functionCallMap;
}

View File

@@ -19,8 +19,8 @@ public class GroupByFunctionReplaceVisitor implements GroupByVisitor {
private Map<String, String> functionMap;
private Map<String, UnaryOperator> functionCallMap;
public GroupByFunctionReplaceVisitor(
Map<String, String> functionMap, Map<String, UnaryOperator> functionCallMap) {
public GroupByFunctionReplaceVisitor(Map<String, String> functionMap,
Map<String, UnaryOperator> functionCallMap) {
this.functionMap = functionMap;
this.functionCallMap = functionCallMap;
}

View File

@@ -53,11 +53,8 @@ public class GroupByReplaceVisitor implements GroupByVisitor {
return expression.toString();
}
private void replaceExpression(
List<Expression> groupByExpressions,
int index,
Expression expression,
String replaceColumn) {
private void replaceExpression(List<Expression> groupByExpressions, int index,
Expression expression, String replaceColumn) {
if (expression instanceof Column) {
groupByExpressions.set(index, new Column(replaceColumn));
} else if (expression instanceof Function) {
@@ -68,8 +65,7 @@ public class GroupByReplaceVisitor implements GroupByVisitor {
Function function = (Function) expression;
if (function.getParameters().size() > 1) {
function.getParameters().stream()
.skip(1)
function.getParameters().stream().skip(1)
.forEach(e -> newExpressionList.add((Function) e));
}
function.setParameters(newExpressionList);

View File

@@ -27,26 +27,14 @@ public class JsqlConstants {
public static final String IN_CONSTANT = " 1 in (1) ";
public static final String LIKE_CONSTANT = "1 like 1";
public static final String IN = "IN";
public static final Map<String, String> rightMap =
Stream.of(
new AbstractMap.SimpleEntry<>("<=", "<="),
new AbstractMap.SimpleEntry<>("<", "<"),
new AbstractMap.SimpleEntry<>(">=", "<="),
new AbstractMap.SimpleEntry<>(">", "<"),
new AbstractMap.SimpleEntry<>("=", "<="))
.collect(
toMap(
AbstractMap.SimpleEntry::getKey,
AbstractMap.SimpleEntry::getValue));
public static final Map<String, String> leftMap =
Stream.of(
new AbstractMap.SimpleEntry<>("<=", ">="),
new AbstractMap.SimpleEntry<>("<", ">"),
new AbstractMap.SimpleEntry<>(">=", "<="),
new AbstractMap.SimpleEntry<>(">", "<"),
new AbstractMap.SimpleEntry<>("=", ">="))
.collect(
toMap(
AbstractMap.SimpleEntry::getKey,
AbstractMap.SimpleEntry::getValue));
public static final Map<String, String> rightMap = Stream.of(
new AbstractMap.SimpleEntry<>("<=", "<="), new AbstractMap.SimpleEntry<>("<", "<"),
new AbstractMap.SimpleEntry<>(">=", "<="), new AbstractMap.SimpleEntry<>(">", "<"),
new AbstractMap.SimpleEntry<>("=", "<="))
.collect(toMap(AbstractMap.SimpleEntry::getKey, AbstractMap.SimpleEntry::getValue));
public static final Map<String, String> leftMap = Stream.of(
new AbstractMap.SimpleEntry<>("<=", ">="), new AbstractMap.SimpleEntry<>("<", ">"),
new AbstractMap.SimpleEntry<>(">=", "<="), new AbstractMap.SimpleEntry<>(">", "<"),
new AbstractMap.SimpleEntry<>("=", ">="))
.collect(toMap(AbstractMap.SimpleEntry::getKey, AbstractMap.SimpleEntry::getValue));
}

View File

@@ -13,8 +13,8 @@ import java.util.stream.Collectors;
@Slf4j
public class ParseVisitorHelper {
public void replaceColumn(
Column column, Map<String, String> fieldNameMap, boolean exactReplace) {
public void replaceColumn(Column column, Map<String, String> fieldNameMap,
boolean exactReplace) {
String columnName = StringUtil.replaceBackticks(column.getColumnName());
String replaceColumn = getReplaceValue(columnName, fieldNameMap, exactReplace);
if (StringUtils.isNotBlank(replaceColumn)) {
@@ -22,8 +22,8 @@ public class ParseVisitorHelper {
}
}
public String getReplaceValue(
String beforeValue, Map<String, String> valueMap, boolean exactReplace) {
public String getReplaceValue(String beforeValue, Map<String, String> valueMap,
boolean exactReplace) {
String value = valueMap.get(beforeValue);
if (StringUtils.isNotBlank(value)) {
return value;
@@ -31,19 +31,13 @@ public class ParseVisitorHelper {
if (exactReplace) {
return null;
}
Optional<Entry<String, String>> first =
valueMap.entrySet().stream()
.sorted(
(k1, k2) -> {
String k1Value = k1.getKey();
String k2Value = k2.getKey();
Double k1Similarity = getSimilarity(beforeValue, k1Value);
Double k2Similarity = getSimilarity(beforeValue, k2Value);
return k2Similarity.compareTo(k1Similarity);
})
.collect(Collectors.toList())
.stream()
.findFirst();
Optional<Entry<String, String>> first = valueMap.entrySet().stream().sorted((k1, k2) -> {
String k1Value = k1.getKey();
String k2Value = k2.getKey();
Double k1Similarity = getSimilarity(beforeValue, k1Value);
Double k2Similarity = getSimilarity(beforeValue, k2Value);
return k2Similarity.compareTo(k1Similarity);
}).collect(Collectors.toList()).stream().findFirst();
if (first.isPresent()) {
return first.get().getValue();
@@ -68,16 +62,12 @@ public class ParseVisitorHelper {
char cj = word2.charAt(j - 1);
if (ci == cj) {
dp[i][j] = dp[i - 1][j - 1];
} else if (i > 1
&& j > 1
&& ci == word2.charAt(j - 2)
} else if (i > 1 && j > 1 && ci == word2.charAt(j - 2)
&& cj == word1.charAt(i - 2)) {
dp[i][j] = 1 + Math.min(dp[i - 2][j - 2], Math.min(dp[i][j - 1], dp[i - 1][j]));
} else {
dp[i][j] =
Math.min(
dp[i - 1][j - 1] + 1,
Math.min(dp[i][j - 1] + 1, dp[i - 1][j] + 1));
dp[i][j] = Math.min(dp[i - 1][j - 1] + 1,
Math.min(dp[i][j - 1] + 1, dp[i - 1][j] + 1));
}
}
}

View File

@@ -43,32 +43,21 @@ public class SqlAddHelper {
}
if (selectStatement instanceof PlainSelect) {
PlainSelect plainSelect = (PlainSelect) selectStatement;
fields.stream()
.filter(Objects::nonNull)
.forEach(
field -> {
SelectItem<Column> selectExpressionItem =
new SelectItem(new Column(field));
plainSelect.addSelectItems(selectExpressionItem);
});
fields.stream().filter(Objects::nonNull).forEach(field -> {
SelectItem<Column> selectExpressionItem = new SelectItem(new Column(field));
plainSelect.addSelectItems(selectExpressionItem);
});
} else if (selectStatement instanceof SetOperationList) {
SetOperationList setOperationList = (SetOperationList) selectStatement;
if (!CollectionUtils.isEmpty(setOperationList.getSelects())) {
setOperationList
.getSelects()
.forEach(
subSelectBody -> {
PlainSelect subPlainSelect = (PlainSelect) subSelectBody;
fields.stream()
.forEach(
field -> {
SelectItem<Column> selectExpressionItem =
new SelectItem(new Column(field));
subPlainSelect.addSelectItems(
selectExpressionItem);
});
});
setOperationList.getSelects().forEach(subSelectBody -> {
PlainSelect subPlainSelect = (PlainSelect) subSelectBody;
fields.stream().forEach(field -> {
SelectItem<Column> selectExpressionItem = new SelectItem(new Column(field));
subPlainSelect.addSelectItems(selectExpressionItem);
});
});
}
}
return selectStatement.toString();
@@ -88,13 +77,10 @@ public class SqlAddHelper {
SetOperationList setOperationList =
(SetOperationList) selectStatement.getSetOperationList();
if (!CollectionUtils.isEmpty(setOperationList.getSelects())) {
setOperationList
.getSelects()
.forEach(
subSelectBody -> {
PlainSelect subPlainSelect = (PlainSelect) subSelectBody;
plainSelectList.add(subPlainSelect);
});
setOperationList.getSelects().forEach(subSelectBody -> {
PlainSelect subPlainSelect = (PlainSelect) subSelectBody;
plainSelectList.add(subPlainSelect);
});
}
}
@@ -238,18 +224,15 @@ public class SqlAddHelper {
if (!(selectStatement instanceof PlainSelect)) {
return sql;
}
selectStatement.accept(
new SelectVisitorAdapter() {
@Override
public void visit(PlainSelect plainSelect) {
addAggregateToSelectItems(
plainSelect.getSelectItems(), fieldNameToAggregate);
addAggregateToOrderByItems(
plainSelect.getOrderByElements(), fieldNameToAggregate);
addAggregateToGroupByItems(plainSelect.getGroupBy(), fieldNameToAggregate);
addAggregateToWhereItems(plainSelect.getWhere(), fieldNameToAggregate);
}
});
selectStatement.accept(new SelectVisitorAdapter() {
@Override
public void visit(PlainSelect plainSelect) {
addAggregateToSelectItems(plainSelect.getSelectItems(), fieldNameToAggregate);
addAggregateToOrderByItems(plainSelect.getOrderByElements(), fieldNameToAggregate);
addAggregateToGroupByItems(plainSelect.getGroupBy(), fieldNameToAggregate);
addAggregateToWhereItems(plainSelect.getWhere(), fieldNameToAggregate);
}
});
return selectStatement.toString();
}
@@ -276,8 +259,8 @@ public class SqlAddHelper {
return selectStatement.toString();
}
private static void addAggregateToSelectItems(
List<SelectItem<?>> selectItems, Map<String, String> fieldNameToAggregate) {
private static void addAggregateToSelectItems(List<SelectItem<?>> selectItems,
Map<String, String> fieldNameToAggregate) {
for (SelectItem selectItem : selectItems) {
Expression expression = selectItem.getExpression();
Function function =
@@ -289,8 +272,8 @@ public class SqlAddHelper {
}
}
private static void addAggregateToOrderByItems(
List<OrderByElement> orderByElements, Map<String, String> fieldNameToAggregate) {
private static void addAggregateToOrderByItems(List<OrderByElement> orderByElements,
Map<String, String> fieldNameToAggregate) {
if (orderByElements == null) {
return;
}
@@ -305,8 +288,8 @@ public class SqlAddHelper {
}
}
private static void addAggregateToGroupByItems(
GroupByElement groupByElement, Map<String, String> fieldNameToAggregate) {
private static void addAggregateToGroupByItems(GroupByElement groupByElement,
Map<String, String> fieldNameToAggregate) {
if (groupByElement == null) {
return;
}
@@ -321,16 +304,16 @@ public class SqlAddHelper {
}
}
private static void addAggregateToWhereItems(
Expression whereExpression, Map<String, String> fieldNameToAggregate) {
private static void addAggregateToWhereItems(Expression whereExpression,
Map<String, String> fieldNameToAggregate) {
if (whereExpression == null) {
return;
}
modifyWhereExpression(whereExpression, fieldNameToAggregate);
}
private static void modifyWhereExpression(
Expression whereExpression, Map<String, String> fieldNameToAggregate) {
private static void modifyWhereExpression(Expression whereExpression,
Map<String, String> fieldNameToAggregate) {
if (SqlSelectHelper.isLogicExpression(whereExpression)) {
if (whereExpression instanceof AndExpression) {
AndExpression andExpression = (AndExpression) whereExpression;
@@ -347,15 +330,15 @@ public class SqlAddHelper {
modifyWhereExpression(rightExpression, fieldNameToAggregate);
}
} else if (whereExpression instanceof Parenthesis) {
modifyWhereExpression(
((Parenthesis) whereExpression).getExpression(), fieldNameToAggregate);
modifyWhereExpression(((Parenthesis) whereExpression).getExpression(),
fieldNameToAggregate);
} else {
setAggToFunction(whereExpression, fieldNameToAggregate);
}
}
private static void setAggToFunction(
Expression expression, Map<String, String> fieldNameToAggregate) {
private static void setAggToFunction(Expression expression,
Map<String, String> fieldNameToAggregate) {
if (!(expression instanceof ComparisonOperator)) {
return;
}
@@ -363,20 +346,16 @@ public class SqlAddHelper {
if (comparisonOperator.getRightExpression() instanceof Column) {
String columnName =
((Column) (comparisonOperator).getRightExpression()).getColumnName();
Function function =
SqlSelectFunctionHelper.getFunction(
comparisonOperator.getRightExpression(),
fieldNameToAggregate.get(columnName));
Function function = SqlSelectFunctionHelper.getFunction(
comparisonOperator.getRightExpression(), fieldNameToAggregate.get(columnName));
if (Objects.nonNull(function)) {
comparisonOperator.setRightExpression(function);
}
}
if (comparisonOperator.getLeftExpression() instanceof Column) {
String columnName = ((Column) (comparisonOperator).getLeftExpression()).getColumnName();
Function function =
SqlSelectFunctionHelper.getFunction(
comparisonOperator.getLeftExpression(),
fieldNameToAggregate.get(columnName));
Function function = SqlSelectFunctionHelper.getFunction(
comparisonOperator.getLeftExpression(), fieldNameToAggregate.get(columnName));
if (Objects.nonNull(function)) {
comparisonOperator.setLeftExpression(function);
}

View File

@@ -27,18 +27,17 @@ public class SqlAsHelper {
if (plainSelect instanceof Select) {
Select select = plainSelect;
Select selectBody = select.getSelectBody();
selectBody.accept(
new SelectVisitorAdapter() {
@Override
public void visit(PlainSelect plainSelect) {
extractAliasesFromSelect(plainSelect, aliases);
}
selectBody.accept(new SelectVisitorAdapter() {
@Override
public void visit(PlainSelect plainSelect) {
extractAliasesFromSelect(plainSelect, aliases);
}
@Override
public void visit(WithItem withItem) {
withItem.getSelectBody().accept(this);
}
});
@Override
public void visit(WithItem withItem) {
withItem.getSelectBody().accept(this);
}
});
}
}
return new ArrayList<>(aliases);

View File

@@ -1,6 +1,5 @@
package com.tencent.supersonic.common.jsqlparser;
public enum SqlEditEnum {
NUMBER_FILTER,
DATEDIFF
NUMBER_FILTER, DATEDIFF
}

View File

@@ -67,15 +67,14 @@ public class SqlRemoveHelper {
}
List<SelectItem<?>> selectItems = ((PlainSelect) selectStatement).getSelectItems();
Set<String> fields = new HashSet<>();
selectItems.removeIf(
selectItem -> {
String field = selectItem.getExpression().toString();
if (fields.contains(field)) {
return true;
}
fields.add(field);
return false;
});
selectItems.removeIf(selectItem -> {
String field = selectItem.getExpression().toString();
if (fields.contains(field)) {
return true;
}
fields.add(field);
return false;
});
((PlainSelect) selectStatement).setSelectItems(selectItems);
return selectStatement.toString();
}
@@ -85,18 +84,17 @@ public class SqlRemoveHelper {
if (!(selectStatement instanceof PlainSelect)) {
return sql;
}
selectStatement.accept(
new SelectVisitorAdapter() {
@Override
public void visit(PlainSelect plainSelect) {
removeWhereCondition(plainSelect.getWhere(), removeFieldNames);
}
});
selectStatement.accept(new SelectVisitorAdapter() {
@Override
public void visit(PlainSelect plainSelect) {
removeWhereCondition(plainSelect.getWhere(), removeFieldNames);
}
});
return removeNumberFilter(selectStatement.toString());
}
private static void removeWhereCondition(
Expression whereExpression, Set<String> removeFieldNames) {
private static void removeWhereCondition(Expression whereExpression,
Set<String> removeFieldNames) {
if (whereExpression == null) {
return;
}
@@ -121,8 +119,8 @@ public class SqlRemoveHelper {
return selectStatement.toString();
}
private static void removeWhereExpression(
Expression whereExpression, Set<String> removeFieldNames) {
private static void removeWhereExpression(Expression whereExpression,
Set<String> removeFieldNames) {
if (SqlSelectHelper.isLogicExpression(whereExpression)) {
BinaryExpression binaryExpression = (BinaryExpression) whereExpression;
Expression leftExpression = binaryExpression.getLeftExpression();
@@ -131,8 +129,8 @@ public class SqlRemoveHelper {
removeWhereExpression(leftExpression, removeFieldNames);
removeWhereExpression(rightExpression, removeFieldNames);
} else if (whereExpression instanceof Parenthesis) {
removeWhereExpression(
((Parenthesis) whereExpression).getExpression(), removeFieldNames);
removeWhereExpression(((Parenthesis) whereExpression).getExpression(),
removeFieldNames);
} else {
removeExpressionWithConstant(whereExpression, removeFieldNames);
}
@@ -152,8 +150,8 @@ public class SqlRemoveHelper {
return constant;
}
private static void removeExpressionWithConstant(
Expression expression, Set<String> removeFieldNames) {
private static void removeExpressionWithConstant(Expression expression,
Set<String> removeFieldNames) {
try {
if (expression instanceof ComparisonOperator) {
handleComparisonOperator((ComparisonOperator) expression, removeFieldNames);
@@ -167,13 +165,10 @@ public class SqlRemoveHelper {
}
}
private static void handleComparisonOperator(
ComparisonOperator comparisonOperator, Set<String> removeFieldNames)
throws JSQLParserException {
String columnName =
SqlSelectHelper.getColumnName(
comparisonOperator.getLeftExpression(),
comparisonOperator.getRightExpression());
private static void handleComparisonOperator(ComparisonOperator comparisonOperator,
Set<String> removeFieldNames) throws JSQLParserException {
String columnName = SqlSelectHelper.getColumnName(comparisonOperator.getLeftExpression(),
comparisonOperator.getRightExpression());
if (!removeFieldNames.contains(columnName)) {
return;
}
@@ -185,9 +180,8 @@ public class SqlRemoveHelper {
private static void handleInExpression(InExpression inExpression, Set<String> removeFieldNames)
throws JSQLParserException {
String columnName =
SqlSelectHelper.getColumnName(
inExpression.getLeftExpression(), inExpression.getRightExpression());
String columnName = SqlSelectHelper.getColumnName(inExpression.getLeftExpression(),
inExpression.getRightExpression());
if (!removeFieldNames.contains(columnName)) {
return;
}
@@ -196,12 +190,10 @@ public class SqlRemoveHelper {
updateInExpression(inExpression, constantExpression);
}
private static void handleLikeExpression(
LikeExpression likeExpression, Set<String> removeFieldNames)
throws JSQLParserException {
String columnName =
SqlSelectHelper.getColumnName(
likeExpression.getLeftExpression(), likeExpression.getRightExpression());
private static void handleLikeExpression(LikeExpression likeExpression,
Set<String> removeFieldNames) throws JSQLParserException {
String columnName = SqlSelectHelper.getColumnName(likeExpression.getLeftExpression(),
likeExpression.getRightExpression());
if (!removeFieldNames.contains(columnName)) {
return;
}
@@ -210,8 +202,8 @@ public class SqlRemoveHelper {
updateLikeExpression(likeExpression, constantExpression);
}
private static void updateComparisonOperator(
ComparisonOperator original, ComparisonOperator constantExpression) {
private static void updateComparisonOperator(ComparisonOperator original,
ComparisonOperator constantExpression) {
original.setLeftExpression(constantExpression.getLeftExpression());
original.setRightExpression(constantExpression.getRightExpression());
original.setASTNode(constantExpression.getASTNode());
@@ -223,8 +215,8 @@ public class SqlRemoveHelper {
original.setASTNode(constantExpression.getASTNode());
}
private static void updateLikeExpression(
LikeExpression original, LikeExpression constantExpression) {
private static void updateLikeExpression(LikeExpression original,
LikeExpression constantExpression) {
original.setLeftExpression(constantExpression.getLeftExpression());
original.setRightExpression(constantExpression.getRightExpression());
}
@@ -234,13 +226,12 @@ public class SqlRemoveHelper {
if (!(selectStatement instanceof PlainSelect)) {
return sql;
}
selectStatement.accept(
new SelectVisitorAdapter() {
@Override
public void visit(PlainSelect plainSelect) {
removeWhereCondition(plainSelect.getHaving(), removeFieldNames);
}
});
selectStatement.accept(new SelectVisitorAdapter() {
@Override
public void visit(PlainSelect plainSelect) {
removeWhereCondition(plainSelect.getHaving(), removeFieldNames);
}
});
return removeNumberFilter(selectStatement.toString());
}
@@ -254,16 +245,13 @@ public class SqlRemoveHelper {
return sql;
}
ExpressionList groupByExpressionList = groupByElement.getGroupByExpressionList();
groupByExpressionList
.getExpressions()
.removeIf(
expression -> {
if (expression instanceof Column) {
Column column = (Column) expression;
return fields.contains(column.getColumnName());
}
return false;
});
groupByExpressionList.getExpressions().removeIf(expression -> {
if (expression instanceof Column) {
Column column = (Column) expression;
return fields.contains(column.getColumnName());
}
return false;
});
if (CollectionUtils.isEmpty(groupByExpressionList.getExpressions())) {
((PlainSelect) selectStatement).setGroupByElement(null);
}
@@ -279,15 +267,14 @@ public class SqlRemoveHelper {
Iterator<SelectItem<?>> iterator = selectItems.iterator();
while (iterator.hasNext()) {
SelectItem selectItem = iterator.next();
selectItem.accept(
new SelectItemVisitorAdapter() {
@Override
public void visit(SelectItem item) {
if (fields.contains(item.getExpression().toString())) {
iterator.remove();
}
}
});
selectItem.accept(new SelectItemVisitorAdapter() {
@Override
public void visit(SelectItem item) {
if (fields.contains(item.getExpression().toString())) {
iterator.remove();
}
}
});
}
if (selectItems.isEmpty()) {
selectItems.add(new SelectItem(new AllColumns()));
@@ -345,17 +332,14 @@ public class SqlRemoveHelper {
}
}
private static Expression dealComparisonOperatorFilter(
Expression expression, SqlEditEnum sqlEditEnum) {
private static Expression dealComparisonOperatorFilter(Expression expression,
SqlEditEnum sqlEditEnum) {
if (Objects.isNull(expression)) {
return null;
}
if (expression instanceof GreaterThanEquals
|| expression instanceof GreaterThan
|| expression instanceof MinorThan
|| expression instanceof MinorThanEquals
|| expression instanceof EqualsTo
|| expression instanceof NotEqualsTo) {
if (expression instanceof GreaterThanEquals || expression instanceof GreaterThan
|| expression instanceof MinorThan || expression instanceof MinorThanEquals
|| expression instanceof EqualsTo || expression instanceof NotEqualsTo) {
return removeSingleFilter((ComparisonOperator) expression, sqlEditEnum);
} else if (expression instanceof InExpression) {
InExpression inExpression = (InExpression) expression;
@@ -369,14 +353,14 @@ public class SqlRemoveHelper {
return expression;
}
private static Expression removeSingleFilter(
ComparisonOperator comparisonExpression, SqlEditEnum sqlEditEnum) {
private static Expression removeSingleFilter(ComparisonOperator comparisonExpression,
SqlEditEnum sqlEditEnum) {
Expression leftExpression = comparisonExpression.getLeftExpression();
return recursionBase(leftExpression, comparisonExpression, sqlEditEnum);
}
private static Expression recursionBase(
Expression leftExpression, Expression expression, SqlEditEnum sqlEditEnum) {
private static Expression recursionBase(Expression leftExpression, Expression expression,
SqlEditEnum sqlEditEnum) {
if (sqlEditEnum.equals(SqlEditEnum.NUMBER_FILTER)) {
return distinguishNumberFilter(leftExpression, expression);
}
@@ -386,8 +370,8 @@ public class SqlRemoveHelper {
return expression;
}
private static Expression distinguishNumberFilter(
Expression leftExpression, Expression expression) {
private static Expression distinguishNumberFilter(Expression leftExpression,
Expression expression) {
if (leftExpression instanceof LongValue) {
return null;
} else {
@@ -403,8 +387,8 @@ public class SqlRemoveHelper {
return removeIsNullOrNotNullInWhere(false, true, sql, removeFieldNames);
}
public static String removeIsNullOrNotNullInWhere(
boolean dealNull, boolean dealNotNull, String sql, Set<String> removeFieldNames) {
public static String removeIsNullOrNotNullInWhere(boolean dealNull, boolean dealNotNull,
String sql, Set<String> removeFieldNames) {
Select selectStatement = SqlSelectHelper.getSelect(sql);
if (!(selectStatement instanceof PlainSelect)) {
return sql;

View File

@@ -46,57 +46,46 @@ import java.util.function.UnaryOperator;
/** Sql Parser replace Helper */
@Slf4j
public class SqlReplaceHelper {
public static String replaceAggFields(
String sql, Map<String, Pair<String, String>> fieldNameToAggMap) {
public static String replaceAggFields(String sql,
Map<String, Pair<String, String>> fieldNameToAggMap) {
Select selectStatement = SqlSelectHelper.getSelect(sql);
if (!(selectStatement instanceof PlainSelect)) {
return sql;
}
((PlainSelect) selectStatement)
.getSelectItems().stream()
.forEach(
o -> {
SelectItem selectExpressionItem = (SelectItem) o;
if (selectExpressionItem.getExpression() instanceof Function) {
Function function =
(Function) selectExpressionItem.getExpression();
Column column =
(Column)
function.getParameters()
.getExpressions()
.get(0);
if (fieldNameToAggMap.containsKey(column.getColumnName())) {
Pair<String, String> agg =
fieldNameToAggMap.get(column.getColumnName());
String field = agg.getKey();
String func = agg.getRight();
if (AggOperatorEnum.isCountDistinct(func)) {
function.setName("count");
function.setDistinct(true);
} else {
function.setName(func);
}
function.withParameters(new Column(field));
if (Objects.nonNull(selectExpressionItem.getAlias())
&& StringUtils.isNotBlank(field)) {
selectExpressionItem.getAlias().setName(field);
}
}
}
});
((PlainSelect) selectStatement).getSelectItems().stream().forEach(o -> {
SelectItem selectExpressionItem = (SelectItem) o;
if (selectExpressionItem.getExpression() instanceof Function) {
Function function = (Function) selectExpressionItem.getExpression();
Column column = (Column) function.getParameters().getExpressions().get(0);
if (fieldNameToAggMap.containsKey(column.getColumnName())) {
Pair<String, String> agg = fieldNameToAggMap.get(column.getColumnName());
String field = agg.getKey();
String func = agg.getRight();
if (AggOperatorEnum.isCountDistinct(func)) {
function.setName("count");
function.setDistinct(true);
} else {
function.setName(func);
}
function.withParameters(new Column(field));
if (Objects.nonNull(selectExpressionItem.getAlias())
&& StringUtils.isNotBlank(field)) {
selectExpressionItem.getAlias().setName(field);
}
}
}
});
return selectStatement.toString();
}
public static String replaceValue(
String sql, Map<String, Map<String, String>> filedNameToValueMap) {
public static String replaceValue(String sql,
Map<String, Map<String, String>> filedNameToValueMap) {
return replaceValue(sql, filedNameToValueMap, true);
}
public static String replaceValue(
String sql,
Map<String, Map<String, String>> filedNameToValueMap,
boolean exactReplace) {
public static String replaceValue(String sql,
Map<String, Map<String, String>> filedNameToValueMap, boolean exactReplace) {
Select selectStatement = SqlSelectHelper.getSelect(sql);
if (!(selectStatement instanceof PlainSelect)) {
return sql;
@@ -113,8 +102,8 @@ public class SqlReplaceHelper {
return selectStatement.toString();
}
public static String replaceFieldNameByValue(
String sql, Map<String, Set<String>> fieldValueToFieldNames) {
public static String replaceFieldNameByValue(String sql,
Map<String, Set<String>> fieldValueToFieldNames) {
Select selectStatement = SqlSelectHelper.getSelect(sql);
if (!(selectStatement instanceof PlainSelect)) {
return sql;
@@ -145,14 +134,11 @@ public class SqlReplaceHelper {
} else if (select instanceof SetOperationList) {
SetOperationList setOperationList = (SetOperationList) select;
if (!CollectionUtils.isEmpty(setOperationList.getSelects())) {
setOperationList
.getSelects()
.forEach(
subSelectBody -> {
PlainSelect subPlainSelect = (PlainSelect) subSelectBody;
plainSelectList.add(subPlainSelect);
getFromSelect(subPlainSelect.getFromItem(), plainSelectList);
});
setOperationList.getSelects().forEach(subSelectBody -> {
PlainSelect subPlainSelect = (PlainSelect) subSelectBody;
plainSelectList.add(subPlainSelect);
getFromSelect(subPlainSelect.getFromItem(), plainSelectList);
});
}
}
}
@@ -161,8 +147,8 @@ public class SqlReplaceHelper {
return replaceFields(sql, fieldNameMap, false);
}
public static String replaceFields(
String sql, Map<String, String> fieldNameMap, boolean exactReplace) {
public static String replaceFields(String sql, Map<String, String> fieldNameMap,
boolean exactReplace) {
Select selectStatement = SqlSelectHelper.getSelect(sql);
List<PlainSelect> plainSelectList = SqlSelectHelper.getWithItem(selectStatement);
if (selectStatement instanceof PlainSelect) {
@@ -172,14 +158,11 @@ public class SqlReplaceHelper {
} else if (selectStatement instanceof SetOperationList) {
SetOperationList setOperationList = (SetOperationList) selectStatement;
if (!CollectionUtils.isEmpty(setOperationList.getSelects())) {
setOperationList
.getSelects()
.forEach(
subSelectBody -> {
PlainSelect subPlainSelect = (PlainSelect) subSelectBody;
plainSelectList.add(subPlainSelect);
getFromSelect(subPlainSelect.getFromItem(), plainSelectList);
});
setOperationList.getSelects().forEach(subSelectBody -> {
PlainSelect subPlainSelect = (PlainSelect) subSelectBody;
plainSelectList.add(subPlainSelect);
getFromSelect(subPlainSelect.getFromItem(), plainSelectList);
});
}
List<OrderByElement> orderByElements = setOperationList.getOrderByElements();
if (!CollectionUtils.isEmpty(orderByElements)) {
@@ -197,8 +180,8 @@ public class SqlReplaceHelper {
return selectStatement.toString();
}
private static void replaceFieldsInPlainOneSelect(
Map<String, String> fieldNameMap, boolean exactReplace, PlainSelect plainSelect) {
private static void replaceFieldsInPlainOneSelect(Map<String, String> fieldNameMap,
boolean exactReplace, PlainSelect plainSelect) {
// 1. replace where fields
Expression where = plainSelect.getWhere();
FieldReplaceVisitor visitor = new FieldReplaceVisitor(fieldNameMap, exactReplace);
@@ -220,14 +203,10 @@ public class SqlReplaceHelper {
} else if (select instanceof SetOperationList) {
SetOperationList setOperationList = (SetOperationList) select;
if (!CollectionUtils.isEmpty(setOperationList.getSelects())) {
setOperationList
.getSelects()
.forEach(
subSelectBody -> {
PlainSelect subPlainSelect = (PlainSelect) subSelectBody;
replaceFieldsInPlainOneSelect(
fieldNameMap, exactReplace, subPlainSelect);
});
setOperationList.getSelects().forEach(subSelectBody -> {
PlainSelect subPlainSelect = (PlainSelect) subSelectBody;
replaceFieldsInPlainOneSelect(fieldNameMap, exactReplace, subPlainSelect);
});
}
}
}
@@ -253,11 +232,9 @@ public class SqlReplaceHelper {
if (!CollectionUtils.isEmpty(joins)) {
for (Join join : joins) {
if (!CollectionUtils.isEmpty(join.getOnExpressions())) {
join.getOnExpressions().stream()
.forEach(
onExpression -> {
onExpression.accept(visitor);
});
join.getOnExpressions().stream().forEach(onExpression -> {
onExpression.accept(visitor);
});
}
if (!(join.getRightItem() instanceof ParenthesedSelect)) {
continue;
@@ -278,8 +255,8 @@ public class SqlReplaceHelper {
return replaceFunction(sql, functionMap, null);
}
public static String replaceFunction(
String sql, Map<String, String> functionMap, Map<String, UnaryOperator> functionCall) {
public static String replaceFunction(String sql, Map<String, String> functionMap,
Map<String, UnaryOperator> functionCall) {
Select selectStatement = SqlSelectHelper.getSelect(sql);
if (!(selectStatement instanceof PlainSelect)) {
return sql;
@@ -293,10 +270,8 @@ public class SqlReplaceHelper {
return selectStatement.toString();
}
private static void replaceFunction(
Map<String, String> functionMap,
Map<String, UnaryOperator> functionCall,
PlainSelect selectBody) {
private static void replaceFunction(Map<String, String> functionMap,
Map<String, UnaryOperator> functionCall, PlainSelect selectBody) {
PlainSelect plainSelect = selectBody;
// 1. replace where dataDiff function
Expression where = plainSelect.getWhere();
@@ -356,8 +331,8 @@ public class SqlReplaceHelper {
}
}
private static void replaceComparisonOperatorFunction(
Map<String, String> functionMap, Expression expression) {
private static void replaceComparisonOperatorFunction(Map<String, String> functionMap,
Expression expression) {
if (Objects.isNull(expression)) {
return;
}
@@ -376,8 +351,8 @@ public class SqlReplaceHelper {
}
}
private static void replaceOrderByFunction(
Map<String, String> functionMap, List<OrderByElement> orderByElementList) {
private static void replaceOrderByFunction(Map<String, String> functionMap,
List<OrderByElement> orderByElementList) {
if (Objects.isNull(orderByElementList)) {
return;
}
@@ -410,25 +385,23 @@ public class SqlReplaceHelper {
List<PlainSelect> plainSelectList = SqlSelectHelper.getWithItem(selectStatement);
if (!CollectionUtils.isEmpty(plainSelectList)) {
List<String> withNameList = SqlSelectHelper.getWithName(sql);
plainSelectList.stream()
.forEach(
plainSelect -> {
if (plainSelect.getFromItem() instanceof Table) {
Table table = (Table) plainSelect.getFromItem();
if (!withNameList.contains(table.getName())) {
replaceSingleTable(plainSelect, tableName);
}
}
if (plainSelect.getFromItem() instanceof ParenthesedSelect) {
ParenthesedSelect parenthesedSelect =
(ParenthesedSelect) plainSelect.getFromItem();
PlainSelect subPlainSelect = parenthesedSelect.getPlainSelect();
Table table = (Table) subPlainSelect.getFromItem();
if (!withNameList.contains(table.getName())) {
replaceSingleTable(subPlainSelect, tableName);
}
}
});
plainSelectList.stream().forEach(plainSelect -> {
if (plainSelect.getFromItem() instanceof Table) {
Table table = (Table) plainSelect.getFromItem();
if (!withNameList.contains(table.getName())) {
replaceSingleTable(plainSelect, tableName);
}
}
if (plainSelect.getFromItem() instanceof ParenthesedSelect) {
ParenthesedSelect parenthesedSelect =
(ParenthesedSelect) plainSelect.getFromItem();
PlainSelect subPlainSelect = parenthesedSelect.getPlainSelect();
Table table = (Table) subPlainSelect.getFromItem();
if (!withNameList.contains(table.getName())) {
replaceSingleTable(subPlainSelect, tableName);
}
}
});
return selectStatement.toString();
}
if (selectStatement instanceof PlainSelect) {
@@ -438,14 +411,11 @@ public class SqlReplaceHelper {
} else if (selectStatement instanceof SetOperationList) {
SetOperationList setOperationList = (SetOperationList) selectStatement;
if (!CollectionUtils.isEmpty(setOperationList.getSelects())) {
setOperationList
.getSelects()
.forEach(
subSelectBody -> {
PlainSelect subPlainSelect = (PlainSelect) subSelectBody;
replaceSingleTable(subPlainSelect, tableName);
replaceSubTable(subPlainSelect, tableName);
});
setOperationList.getSelects().forEach(subSelectBody -> {
PlainSelect subPlainSelect = (PlainSelect) subSelectBody;
replaceSingleTable(subPlainSelect, tableName);
replaceSubTable(subPlainSelect, tableName);
});
}
}
@@ -476,15 +446,12 @@ public class SqlReplaceHelper {
plainSelects.add(plainSelect);
List<PlainSelect> painSelects = SqlSelectHelper.getPlainSelects(plainSelects);
for (PlainSelect painSelect : painSelects) {
painSelect.accept(
new SelectVisitorAdapter() {
@Override
public void visit(PlainSelect plainSelect) {
plainSelect
.getFromItem()
.accept(new TableNameReplaceVisitor(tableName));
}
});
painSelect.accept(new SelectVisitorAdapter() {
@Override
public void visit(PlainSelect plainSelect) {
plainSelect.getFromItem().accept(new TableNameReplaceVisitor(tableName));
}
});
List<Join> joins = painSelect.getJoins();
if (!CollectionUtils.isEmpty(joins)) {
for (Join join : joins) {
@@ -494,8 +461,7 @@ public class SqlReplaceHelper {
List<PlainSelect> subPlainSelects =
SqlSelectHelper.getPlainSelects(plainSelectList);
for (PlainSelect subPlainSelect : subPlainSelects) {
subPlainSelect
.getFromItem()
subPlainSelect.getFromItem()
.accept(new TableNameReplaceVisitor(tableName));
}
} else if (join.getRightItem() instanceof Table) {
@@ -524,8 +490,8 @@ public class SqlReplaceHelper {
return selectStatement.toString();
}
public static String replaceHavingValue(
String sql, Map<String, Map<String, String>> filedNameToValueMap) {
public static String replaceHavingValue(String sql,
Map<String, Map<String, String>> filedNameToValueMap) {
Select selectStatement = SqlSelectHelper.getSelect(sql);
if (!(selectStatement instanceof PlainSelect)) {
return sql;
@@ -539,8 +505,8 @@ public class SqlReplaceHelper {
return selectStatement.toString();
}
public static Expression distinguishDateDiffFilter(
Expression leftExpression, Expression expression) {
public static Expression distinguishDateDiffFilter(Expression leftExpression,
Expression expression) {
if (leftExpression instanceof Function) {
Function function = (Function) leftExpression;
if (function.getName().equals(JsqlConstants.DATE_FUNCTION)) {
@@ -558,17 +524,14 @@ public class SqlReplaceHelper {
String endDateCondExpr =
columnName + endDateOperator + StringUtil.getCommaWrap(endDateValue);
ComparisonOperator rightExpression =
(ComparisonOperator)
CCJSqlParserUtil.parseCondExpression(endDateCondExpr);
ComparisonOperator rightExpression = (ComparisonOperator) CCJSqlParserUtil
.parseCondExpression(endDateCondExpr);
String startDateCondExpr =
columnName
+ StringUtil.getSpaceWrap(startDateOperator)
columnName + StringUtil.getSpaceWrap(startDateOperator)
+ StringUtil.getCommaWrap(startDateValue);
ComparisonOperator newLeftExpression =
(ComparisonOperator)
CCJSqlParserUtil.parseCondExpression(startDateCondExpr);
ComparisonOperator newLeftExpression = (ComparisonOperator) CCJSqlParserUtil
.parseCondExpression(startDateCondExpr);
AndExpression andExpression =
new AndExpression(newLeftExpression, rightExpression);
@@ -576,8 +539,8 @@ public class SqlReplaceHelper {
|| JsqlConstants.GREATER_THAN_EQUALS.equals(dateOperator)) {
return newLeftExpression;
} else {
return CCJSqlParserUtil.parseCondExpression(
"(" + andExpression.toString() + ")");
return CCJSqlParserUtil
.parseCondExpression("(" + andExpression.toString() + ")");
}
} catch (JSQLParserException e) {
log.error("JSQLParserException", e);
@@ -608,30 +571,24 @@ public class SqlReplaceHelper {
}
}
}
plainSelect.getOrderByElements().stream()
.forEach(
o -> {
if (o.getExpression() instanceof Function) {
Function function = (Function) o.getExpression();
if (function.getParameters().size() == 1
&& function.getParameters().get(0)
instanceof Column) {
Column column =
(Column) function.getParameters().get(0);
if (selectNames.containsKey(column.getColumnName())) {
o.setExpression(
new LongValue(
selectNames.get(
column.getColumnName())));
}
}
}
});
plainSelect.getOrderByElements().stream().forEach(o -> {
if (o.getExpression() instanceof Function) {
Function function = (Function) o.getExpression();
if (function.getParameters().size() == 1
&& function.getParameters().get(0) instanceof Column) {
Column column = (Column) function.getParameters().get(0);
if (selectNames.containsKey(column.getColumnName())) {
o.setExpression(
new LongValue(selectNames.get(column.getColumnName())));
}
}
}
});
}
if (plainSelect.getFromItem() instanceof ParenthesedSelect) {
ParenthesedSelect parenthesedSelect = (ParenthesedSelect) plainSelect.getFromItem();
parenthesedSelect.setSelect(
replaceAggAliasOrderItem(parenthesedSelect.getSelect()));
parenthesedSelect
.setSelect(replaceAggAliasOrderItem(parenthesedSelect.getSelect()));
}
return selectStatement;
}
@@ -665,13 +622,10 @@ public class SqlReplaceHelper {
} else if (selectStatement instanceof SetOperationList) {
SetOperationList setOperationList = (SetOperationList) selectStatement;
if (!CollectionUtils.isEmpty(setOperationList.getSelects())) {
setOperationList
.getSelects()
.forEach(
subSelectBody -> {
PlainSelect subPlainSelect = (PlainSelect) subSelectBody;
plainSelectList.add(subPlainSelect);
});
setOperationList.getSelects().forEach(subSelectBody -> {
PlainSelect subPlainSelect = (PlainSelect) subSelectBody;
plainSelectList.add(subPlainSelect);
});
}
} else {
return sql;
@@ -683,8 +637,8 @@ public class SqlReplaceHelper {
return selectStatement.toString();
}
private static void replacePlainSelectByExpr(
PlainSelect plainSelect, Map<String, String> replace) {
private static void replacePlainSelectByExpr(PlainSelect plainSelect,
Map<String, String> replace) {
QueryExpressionReplaceVisitor expressionReplaceVisitor =
new QueryExpressionReplaceVisitor(replace);
for (SelectItem selectItem : plainSelect.getSelectItems()) {
@@ -703,9 +657,8 @@ public class SqlReplaceHelper {
List<OrderByElement> orderByElements = plainSelect.getOrderByElements();
if (!CollectionUtils.isEmpty(orderByElements)) {
for (OrderByElement orderByElement : orderByElements) {
orderByElement.setExpression(
QueryExpressionReplaceVisitor.replace(
orderByElement.getExpression(), replace));
orderByElement.setExpression(QueryExpressionReplaceVisitor
.replace(orderByElement.getExpression(), replace));
}
}
}

View File

@@ -58,8 +58,8 @@ public class SqlSelectFunctionHelper {
return visitor.getFunctionNames();
}
public static Function getFunction(
Expression expression, Map<String, String> fieldNameToAggregate) {
public static Function getFunction(Expression expression,
Map<String, String> fieldNameToAggregate) {
if (!(expression instanceof Column)) {
return null;
}
@@ -100,8 +100,7 @@ public class SqlSelectFunctionHelper {
FunctionVisitor visitor = new FunctionVisitor();
expression.accept(visitor);
Set<String> functions = visitor.getFunctionNames();
return functions.stream()
.filter(t -> aggregateFunctionName.contains(t.toUpperCase()))
return functions.stream().filter(t -> aggregateFunctionName.contains(t.toUpperCase()))
.collect(Collectors.toList());
}
return new ArrayList<>();

View File

@@ -70,12 +70,9 @@ public class SqlSelectHelper {
having.accept(new FieldAndValueAcquireVisitor(result));
}
}
result =
result.stream()
.filter(
fieldExpression ->
StringUtils.isNotBlank(fieldExpression.getFieldName()))
.collect(Collectors.toSet());
result = result.stream()
.filter(fieldExpression -> StringUtils.isNotBlank(fieldExpression.getFieldName()))
.collect(Collectors.toSet());
return new ArrayList<>(result);
}
@@ -90,31 +87,27 @@ public class SqlSelectHelper {
}
public static void getWhereFields(List<PlainSelect> plainSelectList, Set<String> result) {
plainSelectList.stream()
.forEach(
plainSelect -> {
Expression where = plainSelect.getWhere();
if (Objects.nonNull(where)) {
where.accept(new FieldAcquireVisitor(result));
}
});
plainSelectList.stream().forEach(plainSelect -> {
Expression where = plainSelect.getWhere();
if (Objects.nonNull(where)) {
where.accept(new FieldAcquireVisitor(result));
}
});
}
public static List<String> gePureSelectFields(String sql) {
List<PlainSelect> plainSelectList = getPlainSelect(sql);
Set<String> result = new HashSet<>();
plainSelectList.stream()
.forEach(
plainSelect -> {
List<SelectItem<?>> selectItems = plainSelect.getSelectItems();
for (SelectItem selectItem : selectItems) {
if (!(selectItem.getExpression() instanceof Column)) {
continue;
}
Column column = (Column) selectItem.getExpression();
result.add(column.getColumnName());
}
});
plainSelectList.stream().forEach(plainSelect -> {
List<SelectItem<?>> selectItems = plainSelect.getSelectItems();
for (SelectItem selectItem : selectItems) {
if (!(selectItem.getExpression() instanceof Column)) {
continue;
}
Column column = (Column) selectItem.getExpression();
result.add(column.getColumnName());
}
});
return new ArrayList<>(result);
}
@@ -128,14 +121,12 @@ public class SqlSelectHelper {
public static Set<String> getSelectFields(List<PlainSelect> plainSelectList) {
Set<String> result = new HashSet<>();
plainSelectList.stream()
.forEach(
plainSelect -> {
List<SelectItem<?>> selectItems = plainSelect.getSelectItems();
for (SelectItem selectItem : selectItems) {
selectItem.accept(new FieldAcquireVisitor(result));
}
});
plainSelectList.stream().forEach(plainSelect -> {
List<SelectItem<?>> selectItems = plainSelect.getSelectItems();
for (SelectItem selectItem : selectItems) {
selectItem.accept(new FieldAcquireVisitor(result));
}
});
return result;
}
@@ -152,13 +143,10 @@ public class SqlSelectHelper {
} else if (selectStatement instanceof SetOperationList) {
SetOperationList setOperationList = (SetOperationList) selectStatement;
if (!CollectionUtils.isEmpty(setOperationList.getSelects())) {
setOperationList
.getSelects()
.forEach(
subSelectBody -> {
PlainSelect subPlainSelect = (PlainSelect) subSelectBody;
getSubPlainSelect(subPlainSelect, plainSelectList);
});
setOperationList.getSelects().forEach(subSelectBody -> {
PlainSelect subPlainSelect = (PlainSelect) subSelectBody;
getSubPlainSelect(subPlainSelect, plainSelectList);
});
}
}
return plainSelectList;
@@ -235,39 +223,37 @@ public class SqlSelectHelper {
List<PlainSelect> plainSelects = new ArrayList<>();
for (PlainSelect plainSelect : plainSelectList) {
plainSelects.add(plainSelect);
ExpressionVisitorAdapter expressionVisitor =
new ExpressionVisitorAdapter() {
@Override
public void visit(Select subSelect) {
if (subSelect instanceof ParenthesedSelect) {
ParenthesedSelect parenthesedSelect = (ParenthesedSelect) subSelect;
if (parenthesedSelect.getSelect() instanceof PlainSelect) {
plainSelects.add(parenthesedSelect.getPlainSelect());
}
}
ExpressionVisitorAdapter expressionVisitor = new ExpressionVisitorAdapter() {
@Override
public void visit(Select subSelect) {
if (subSelect instanceof ParenthesedSelect) {
ParenthesedSelect parenthesedSelect = (ParenthesedSelect) subSelect;
if (parenthesedSelect.getSelect() instanceof PlainSelect) {
plainSelects.add(parenthesedSelect.getPlainSelect());
}
};
}
}
};
plainSelect.accept(
new SelectVisitorAdapter() {
@Override
public void visit(PlainSelect plainSelect) {
Expression whereExpression = plainSelect.getWhere();
if (whereExpression != null) {
whereExpression.accept(expressionVisitor);
}
Expression having = plainSelect.getHaving();
if (Objects.nonNull(having)) {
having.accept(expressionVisitor);
}
List<SelectItem<?>> selectItems = plainSelect.getSelectItems();
if (!CollectionUtils.isEmpty(selectItems)) {
for (SelectItem selectItem : selectItems) {
selectItem.accept(expressionVisitor);
}
}
plainSelect.accept(new SelectVisitorAdapter() {
@Override
public void visit(PlainSelect plainSelect) {
Expression whereExpression = plainSelect.getWhere();
if (whereExpression != null) {
whereExpression.accept(expressionVisitor);
}
Expression having = plainSelect.getHaving();
if (Objects.nonNull(having)) {
having.accept(expressionVisitor);
}
List<SelectItem<?>> selectItems = plainSelect.getSelectItems();
if (!CollectionUtils.isEmpty(selectItems)) {
for (SelectItem selectItem : selectItems) {
selectItem.accept(expressionVisitor);
}
});
}
}
});
}
return plainSelects;
}
@@ -313,14 +299,11 @@ public class SqlSelectHelper {
private static void getLateralViewsFields(PlainSelect plainSelect, Set<String> result) {
List<LateralView> lateralViews = plainSelect.getLateralViews();
if (!CollectionUtils.isEmpty(lateralViews)) {
lateralViews.stream()
.forEach(
l -> {
if (Objects.nonNull(l.getGeneratorFunction())) {
l.getGeneratorFunction()
.accept(new FieldAcquireVisitor(result));
}
});
lateralViews.stream().forEach(l -> {
if (Objects.nonNull(l.getGeneratorFunction())) {
l.getGeneratorFunction().accept(new FieldAcquireVisitor(result));
}
});
}
}
@@ -425,11 +408,9 @@ public class SqlSelectHelper {
private static void getOrderByFields(PlainSelect plainSelect, Set<String> result) {
Set<FieldExpression> orderByFieldExpressions = getOrderByFields(plainSelect);
Set<String> collect =
orderByFieldExpressions.stream()
.map(fieldExpression -> fieldExpression.getFieldName())
.filter(Objects::nonNull)
.collect(Collectors.toSet());
Set<String> collect = orderByFieldExpressions.stream()
.map(fieldExpression -> fieldExpression.getFieldName()).filter(Objects::nonNull)
.collect(Collectors.toSet());
result.addAll(collect);
}
@@ -487,9 +468,8 @@ public class SqlSelectHelper {
if (selectItem.getExpression() instanceof Function) {
Function function = (Function) selectItem.getExpression();
if (Objects.nonNull(function.getParameters())
&& !CollectionUtils.isEmpty(
function.getParameters().getExpressions())) {
if (Objects.nonNull(function.getParameters()) && !CollectionUtils
.isEmpty(function.getParameters().getExpressions())) {
String columnName =
function.getParameters().getExpressions().get(0).toString();
result.add(columnName);
@@ -516,9 +496,8 @@ public class SqlSelectHelper {
if (alias != null && StringUtils.isNotBlank(alias.getName())) {
result.add(alias.getName());
} else {
if (Objects.nonNull(function.getParameters())
&& !CollectionUtils.isEmpty(
function.getParameters().getExpressions())) {
if (Objects.nonNull(function.getParameters()) && !CollectionUtils
.isEmpty(function.getParameters().getExpressions())) {
String columnName =
function.getParameters().getExpressions().get(0).toString();
result.add(columnName);
@@ -552,9 +531,8 @@ public class SqlSelectHelper {
}
public static boolean isLogicExpression(Expression whereExpression) {
return whereExpression instanceof AndExpression
|| (whereExpression instanceof OrExpression
|| (whereExpression instanceof XorExpression));
return whereExpression instanceof AndExpression || (whereExpression instanceof OrExpression
|| (whereExpression instanceof XorExpression));
}
public static String getColumnName(Expression leftExpression, Expression rightExpression) {
@@ -789,8 +767,8 @@ public class SqlSelectHelper {
return results;
}
private static void getFieldsWithSubQuery(
PlainSelect plainSelect, Map<String, Set<String>> fields) {
private static void getFieldsWithSubQuery(PlainSelect plainSelect,
Map<String, Set<String>> fields) {
if (plainSelect.getFromItem() instanceof Table) {
List<String> withAlias = new ArrayList<>();
if (!CollectionUtils.isEmpty(plainSelect.getWithItemsList())) {
@@ -807,10 +785,8 @@ public class SqlSelectHelper {
if (!fields.containsKey(table.getFullyQualifiedName())) {
fields.put(tableName, new HashSet<>());
}
List<String> sqlFields =
getFieldsByPlainSelect(plainSelect).stream()
.map(f -> f.replaceAll("`", ""))
.collect(Collectors.toList());
List<String> sqlFields = getFieldsByPlainSelect(plainSelect).stream()
.map(f -> f.replaceAll("`", "")).collect(Collectors.toList());
fields.get(tableName).addAll(sqlFields);
}
}
@@ -826,8 +802,8 @@ public class SqlSelectHelper {
((ParenthesedSelect) join.getRightItem()).getPlainSelect(), fields);
}
if (join.getFromItem() instanceof ParenthesedSelect) {
getFieldsWithSubQuery(
((ParenthesedSelect) join.getFromItem()).getPlainSelect(), fields);
getFieldsWithSubQuery(((ParenthesedSelect) join.getFromItem()).getPlainSelect(),
fields);
}
}
}

View File

@@ -5,4 +5,5 @@ import com.tencent.supersonic.common.persistence.dataobject.SystemConfigDO;
import org.apache.ibatis.annotations.Mapper;
@Mapper
public interface SystemConfigMapper extends BaseMapper<SystemConfigDO> {}
public interface SystemConfigMapper extends BaseMapper<SystemConfigDO> {
}

View File

@@ -35,23 +35,15 @@ public class Criterion {
public boolean isNeedApostrophe() {
return Arrays.stream(StringDataType.values())
.filter(value -> this.dataType.equalsIgnoreCase(value.getType()))
.findFirst()
.filter(value -> this.dataType.equalsIgnoreCase(value.getType())).findFirst()
.isPresent();
}
public enum NumericDataType {
TINYINT("TINYINT"),
SMALLINT("SMALLINT"),
MEDIUMINT("MEDIUMINT"),
INT("INT"),
INTEGER("INTEGER"),
BIGINT("BIGINT"),
FLOAT("FLOAT"),
DOUBLE("DOUBLE"),
DECIMAL("DECIMAL"),
NUMERIC("NUMERIC"),
;
TINYINT("TINYINT"), SMALLINT("SMALLINT"), MEDIUMINT("MEDIUMINT"), INT("INT"), INTEGER(
"INTEGER"), BIGINT("BIGINT"), FLOAT(
"FLOAT"), DOUBLE("DOUBLE"), DECIMAL("DECIMAL"), NUMERIC("NUMERIC"),;
private String type;
NumericDataType(String type) {
@@ -64,9 +56,8 @@ public class Criterion {
}
public enum StringDataType {
VARCHAR("VARCHAR"),
STRING("STRING"),
;
VARCHAR("VARCHAR"), STRING("STRING"),;
private String type;
StringDataType(String type) {

View File

@@ -11,8 +11,8 @@ public class DataUpdateEvent extends ApplicationEvent {
private Long id;
private TypeEnums type;
public DataUpdateEvent(
Object source, String name, String newName, Long modelId, Long id, TypeEnums type) {
public DataUpdateEvent(Object source, String name, String newName, Long modelId, Long id,
TypeEnums type) {
super(source);
this.name = name;
this.newName = newName;

View File

@@ -70,10 +70,8 @@ public class DateConf {
return false;
}
DateConf dateConf = (DateConf) o;
return dateMode == dateConf.dateMode
&& Objects.equals(startDate, dateConf.startDate)
&& Objects.equals(endDate, dateConf.endDate)
&& Objects.equals(unit, dateConf.unit)
return dateMode == dateConf.dateMode && Objects.equals(startDate, dateConf.startDate)
&& Objects.equals(endDate, dateConf.endDate) && Objects.equals(unit, dateConf.unit)
&& Objects.equals(period, dateConf.period);
}
@@ -89,11 +87,7 @@ public class DateConf {
* the element, [unit, period] 4 - AVAILABLE, dynamic time which guaranteed to query some
* data, [startDate, endDate] 5 - ALL, all table data
*/
BETWEEN,
LIST,
RECENT,
AVAILABLE,
ALL
BETWEEN, LIST, RECENT, AVAILABLE, ALL
}
@Override

View File

@@ -47,8 +47,6 @@ public class Filter {
}
public enum Relation {
FILTER,
OR,
AND
FILTER, OR, AND
}
}

View File

@@ -13,22 +13,28 @@ import java.util.Map;
/**
* 1.Password Field:
*
* <p>dataType: string name: password require: true/false or any value/empty placeholder: 'Please
* enter the relevant configuration information' value: initial value Text Input Field:
* <p>
* dataType: string name: password require: true/false or any value/empty placeholder: 'Please enter
* the relevant configuration information' value: initial value Text Input Field:
*
* <p>2.dataType: string require: true/false or any value/empty placeholder: 'Please enter the
* relevant configuration information' value: initial value Long Text Input Field:
* <p>
* 2.dataType: string require: true/false or any value/empty placeholder: 'Please enter the relevant
* configuration information' value: initial value Long Text Input Field:
*
* <p>3.dataType: longText require: true/false or any value/empty placeholder: 'Please enter the
* <p>
* 3.dataType: longText require: true/false or any value/empty placeholder: 'Please enter the
* relevant configuration information' value: initial value Number Input Field:
*
* <p>4.dataType: number require: true/false or any value/empty placeholder: 'Please enter the
* relevant configuration information' value: initial value Switch Component:
* <p>
* 4.dataType: number require: true/false or any value/empty placeholder: 'Please enter the relevant
* configuration information' value: initial value Switch Component:
*
* <p>5.dataType: bool require: true/false or any value/empty value: initial value Select Dropdown
* <p>
* 5.dataType: bool require: true/false or any value/empty value: initial value Select Dropdown
* Component:
*
* <p>6.dataType: list candidateValues: ["OPEN_AI", "OLLAMA"] or [{label: 'Model Name 1', value:
* <p>
* 6.dataType: list candidateValues: ["OPEN_AI", "OLLAMA"] or [{label: 'Model Name 1', value:
* 'OPEN_AI'}, {label: 'Model Name 2', value: 'OLLAMA'}] require: true/false or any value/empty
* placeholder: 'Please enter the relevant configuration information' value: initial value
*/
@@ -43,35 +49,18 @@ public class Parameter {
private List<String> candidateValues;
private List<Dependency> dependencies;
public Parameter(
String name,
String defaultValue,
String comment,
String description,
String dataType,
String module) {
public Parameter(String name, String defaultValue, String comment, String description,
String dataType, String module) {
this(name, defaultValue, comment, description, dataType, module, null, null);
}
public Parameter(
String name,
String defaultValue,
String comment,
String description,
String dataType,
String module,
List<String> candidateValues) {
public Parameter(String name, String defaultValue, String comment, String description,
String dataType, String module, List<String> candidateValues) {
this(name, defaultValue, comment, description, dataType, module, candidateValues, null);
}
public Parameter(
String name,
String defaultValue,
String comment,
String description,
String dataType,
String module,
List<String> candidateValues,
public Parameter(String name, String defaultValue, String comment, String description,
String dataType, String module, List<String> candidateValues,
List<Dependency> dependencies) {
this.name = name;
this.defaultValue = defaultValue;

View File

@@ -9,16 +9,13 @@ public enum AggOperatorEnum {
SUM("SUM"),
COUNT("COUNT"),
COUNT_DISTINCT("COUNT_DISTINCT"),
DISTINCT("DISTINCT"),
COUNT("COUNT"), COUNT_DISTINCT("COUNT_DISTINCT"), DISTINCT("DISTINCT"),
TOPN("TOPN"),
PERCENTILE("PERCENTILE"),
RATIO_ROLL("RATIO_ROLL"),
RATIO_OVER("RATIO_OVER"),
RATIO_ROLL("RATIO_ROLL"), RATIO_OVER("RATIO_OVER"),
UNKNOWN("UNKNOWN");

View File

@@ -1,14 +1,7 @@
package com.tencent.supersonic.common.pojo.enums;
public enum AggregateTypeEnum {
SUM,
AVG,
MAX,
MIN,
TOPN,
DISTINCT,
COUNT,
NONE;
SUM, AVG, MAX, MIN, TOPN, DISTINCT, COUNT, NONE;
public static AggregateTypeEnum of(String agg) {
for (AggregateTypeEnum aggEnum : AggregateTypeEnum.values()) {

View File

@@ -1,7 +1,5 @@
package com.tencent.supersonic.common.pojo.enums;
public enum ApiItemType {
METRIC,
TAG,
DIMENSION
METRIC, TAG, DIMENSION
}

View File

@@ -1,6 +1,5 @@
package com.tencent.supersonic.common.pojo.enums;
public enum AuthType {
VISIBLE,
ADMIN
VISIBLE, ADMIN
}

View File

@@ -1,9 +1,7 @@
package com.tencent.supersonic.common.pojo.enums;
public enum ConfigMode {
DETAIL("DETAIL"),
AGG("AGG"),
UNKNOWN("UNKNOWN");
DETAIL("DETAIL"), AGG("AGG"), UNKNOWN("UNKNOWN");
private String mode;

View File

@@ -1,11 +1,8 @@
package com.tencent.supersonic.common.pojo.enums;
public enum DatePeriodEnum {
DAY(""),
WEEK(""),
MONTH(""),
QUARTER("季度"),
YEAR("");
DAY(""), WEEK(""), MONTH(""), QUARTER("季度"), YEAR("");
private String chName;
DatePeriodEnum(String chName) {

View File

@@ -51,8 +51,7 @@ public enum DictWordType {
return DATASET;
}
// dimension value
if (natures.length == 3
&& StringUtils.isNumeric(natures[1])
if (natures.length == 3 && StringUtils.isNumeric(natures[1])
&& StringUtils.isNumeric(natures[2])) {
return VALUE;
}

View File

@@ -1,14 +1,8 @@
package com.tencent.supersonic.common.pojo.enums;
public enum EngineType {
TDW(0, "tdw"),
MYSQL(1, "mysql"),
DORIS(2, "doris"),
CLICKHOUSE(3, "clickhouse"),
KAFKA(4, "kafka"),
H2(5, "h2"),
POSTGRESQL(6, "postgresql"),
OTHER(7, "other");
TDW(0, "tdw"), MYSQL(1, "mysql"), DORIS(2, "doris"), CLICKHOUSE(3, "clickhouse"), KAFKA(4,
"kafka"), H2(5, "h2"), POSTGRESQL(6, "postgresql"), OTHER(7, "other");
private Integer code;

View File

@@ -1,7 +1,5 @@
package com.tencent.supersonic.common.pojo.enums;
public enum EventType {
ADD,
UPDATE,
DELETE
ADD, UPDATE, DELETE
}

View File

@@ -10,20 +10,10 @@ import net.sf.jsqlparser.expression.operators.relational.MinorThan;
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
public enum FilterOperatorEnum {
IN("IN"),
NOT_IN("NOT_IN"),
EQUALS("="),
BETWEEN("BETWEEN"),
GREATER_THAN(">"),
GREATER_THAN_EQUALS(">="),
IS_NULL("IS_NULL"),
IS_NOT_NULL("IS_NOT_NULL"),
LIKE("LIKE"),
MINOR_THAN("<"),
MINOR_THAN_EQUALS("<="),
NOT_EQUALS("!="),
SQL_PART("SQL_PART"),
EXISTS("EXISTS");
IN("IN"), NOT_IN("NOT_IN"), EQUALS("="), BETWEEN("BETWEEN"), GREATER_THAN(
">"), GREATER_THAN_EQUALS(">="), IS_NULL("IS_NULL"), IS_NOT_NULL("IS_NOT_NULL"), LIKE(
"LIKE"), MINOR_THAN("<"), MINOR_THAN_EQUALS(
"<="), NOT_EQUALS("!="), SQL_PART("SQL_PART"), EXISTS("EXISTS");
private String value;
@@ -48,8 +38,7 @@ public enum FilterOperatorEnum {
}
public static boolean isValueCompare(FilterOperatorEnum filterOperatorEnum) {
return EQUALS.equals(filterOperatorEnum)
|| GREATER_THAN.equals(filterOperatorEnum)
return EQUALS.equals(filterOperatorEnum) || GREATER_THAN.equals(filterOperatorEnum)
|| GREATER_THAN_EQUALS.equals(filterOperatorEnum)
|| MINOR_THAN.equals(filterOperatorEnum)
|| MINOR_THAN_EQUALS.equals(filterOperatorEnum)

View File

@@ -1,8 +1,7 @@
package com.tencent.supersonic.common.pojo.enums;
public enum PublishEnum {
UN_PUBLISHED(0),
PUBLISHED(1);
UN_PUBLISHED(0), PUBLISHED(1);
private Integer code;

View File

@@ -1,13 +1,8 @@
package com.tencent.supersonic.common.pojo.enums;
public enum RatioOverType {
DAY_ON_DAY("日环比"),
WEEK_ON_DAY("环比"),
WEEK_ON_WEEK("周环比"),
MONTH_ON_WEEK("月环比"),
MONTH_ON_MONTH("月环比"),
YEAR_ON_MONTH("年同比"),
YEAR_ON_YEAR("年环比");
DAY_ON_DAY("日环比"), WEEK_ON_DAY("周环比"), WEEK_ON_WEEK("周环比"), MONTH_ON_WEEK(
"月环比"), MONTH_ON_MONTH("月环比"), YEAR_ON_MONTH("年同比"), YEAR_ON_YEAR("环比");
private String showName;

View File

@@ -1,11 +1,9 @@
package com.tencent.supersonic.common.pojo.enums;
public enum ReturnCode {
SUCCESS(200, "success"),
INVALID_REQUEST(400, "invalid request"),
INVALID_PERMISSION(401, "invalid permission"),
ACCESS_ERROR(403, "access denied"),
SYSTEM_ERROR(500, "system error");
SUCCESS(200, "success"), INVALID_REQUEST(400, "invalid request"), INVALID_PERMISSION(401,
"invalid permission"), ACCESS_ERROR(403,
"access denied"), SYSTEM_ERROR(500, "system error");
private final int code;
private final String message;

View File

@@ -1,9 +1,7 @@
package com.tencent.supersonic.common.pojo.enums;
public enum SensitiveLevelEnum {
LOW(0),
MID(1),
HIGH(2);
LOW(0), MID(1), HIGH(2);
private Integer code;

View File

@@ -1,12 +1,8 @@
package com.tencent.supersonic.common.pojo.enums;
public enum StatusEnum {
INITIALIZED("INITIALIZED", 0),
ONLINE("ONLINE", 1),
OFFLINE("OFFLINE", 2),
DELETED("DELETED", 3),
UNAVAILABLE("UNAVAILABLE", 4),
UNKNOWN("UNKNOWN", -1);
INITIALIZED("INITIALIZED", 0), ONLINE("ONLINE", 1), OFFLINE("OFFLINE", 2), DELETED("DELETED",
3), UNAVAILABLE("UNAVAILABLE", 4), UNKNOWN("UNKNOWN", -1);
private String status;
private Integer code;

View File

@@ -1,9 +1,7 @@
package com.tencent.supersonic.common.pojo.enums;
public enum Text2SQLType {
ONLY_RULE,
ONLY_LLM,
RULE_AND_LLM;
ONLY_RULE, ONLY_LLM, RULE_AND_LLM;
public boolean enableRule() {
return this.equals(ONLY_RULE) || this.equals(RULE_AND_LLM);

View File

@@ -31,33 +31,23 @@ public enum TimeDimensionEnum {
}
public static List<String> getNameList() {
return Arrays.stream(TimeDimensionEnum.values())
.map(TimeDimensionEnum::getName)
return Arrays.stream(TimeDimensionEnum.values()).map(TimeDimensionEnum::getName)
.collect(Collectors.toList());
}
public static List<String> getChNameList() {
return Arrays.stream(TimeDimensionEnum.values())
.map(TimeDimensionEnum::getChName)
return Arrays.stream(TimeDimensionEnum.values()).map(TimeDimensionEnum::getChName)
.collect(Collectors.toList());
}
public static Map<String, String> getChNameToNameMap() {
return Arrays.stream(TimeDimensionEnum.values())
.collect(
Collectors.toMap(
TimeDimensionEnum::getChName,
TimeDimensionEnum::getName,
(k1, k2) -> k1));
return Arrays.stream(TimeDimensionEnum.values()).collect(Collectors
.toMap(TimeDimensionEnum::getChName, TimeDimensionEnum::getName, (k1, k2) -> k1));
}
public static Map<String, String> getNameToNameMap() {
return Arrays.stream(TimeDimensionEnum.values())
.collect(
Collectors.toMap(
TimeDimensionEnum::getName,
TimeDimensionEnum::getName,
(k1, k2) -> k1));
return Arrays.stream(TimeDimensionEnum.values()).collect(Collectors
.toMap(TimeDimensionEnum::getName, TimeDimensionEnum::getName, (k1, k2) -> k1));
}
public String getName() {

View File

@@ -1,13 +1,5 @@
package com.tencent.supersonic.common.pojo.enums;
public enum TypeEnums {
METRIC,
DIMENSION,
TAG_OBJECT,
TAG,
DOMAIN,
ENTITY,
DATASET,
MODEL,
UNKNOWN
METRIC, DIMENSION, TAG_OBJECT, TAG, DOMAIN, ENTITY, DATASET, MODEL, UNKNOWN
}

View File

@@ -13,7 +13,8 @@ import org.springframework.web.bind.annotation.RestController;
@RequestMapping({"/api/semantic/parameter"})
public class SystemConfigController {
@Autowired private SystemConfigService sysConfigService;
@Autowired
private SystemConfigService sysConfigService;
@PostMapping
public Boolean save(@RequestBody SystemConfig systemConfig) {

View File

@@ -16,8 +16,8 @@ public interface EmbeddingService {
void deleteQuery(String collectionName, List<TextSegment> queries);
List<RetrieveQueryResult> retrieveQuery(
String collectionName, RetrieveQuery retrieveQuery, int num);
List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery,
int num);
void removeAll();
}

View File

@@ -37,11 +37,8 @@ import java.util.stream.Collectors;
@Slf4j
public class EmbeddingServiceImpl implements EmbeddingService {
private Cache<String, Boolean> cache =
CacheBuilder.newBuilder()
.maximumSize(10000)
.expireAfterWrite(10, TimeUnit.HOURS)
.build();
private Cache<String, Boolean> cache = CacheBuilder.newBuilder().maximumSize(10000)
.expireAfterWrite(10, TimeUnit.HOURS).build();
@Override
public void addQuery(String collectionName, List<TextSegment> queries) {
@@ -59,17 +56,14 @@ public class EmbeddingServiceImpl implements EmbeddingService {
embeddingStore.add(embedding, query);
cache.put(TextSegmentConvert.getQueryId(query), true);
} catch (Exception e) {
log.error(
"embeddingModel embed error question: {}, embeddingStore: {}",
question,
embeddingStore.getClass().getSimpleName(),
e);
log.error("embeddingModel embed error question: {}, embeddingStore: {}", question,
embeddingStore.getClass().getSimpleName(), e);
}
}
}
private boolean existSegment(
EmbeddingStore embeddingStore, TextSegment query, Embedding embedding) {
private boolean existSegment(EmbeddingStore embeddingStore, TextSegment query,
Embedding embedding) {
String queryId = TextSegmentConvert.getQueryId(query);
if (queryId == null) {
return false;
@@ -82,13 +76,8 @@ public class EmbeddingServiceImpl implements EmbeddingService {
Map<String, Object> filterCondition = new HashMap<>();
filterCondition.put(TextSegmentConvert.QUERY_ID, queryId);
Filter filter = createCombinedFilter(filterCondition);
EmbeddingSearchRequest request =
EmbeddingSearchRequest.builder()
.queryEmbedding(embedding)
.filter(filter)
.minScore(1.0d)
.maxResults(1)
.build();
EmbeddingSearchRequest request = EmbeddingSearchRequest.builder().queryEmbedding(embedding)
.filter(filter).minScore(1.0d).maxResults(1).build();
EmbeddingSearchResult result = embeddingStore.search(request);
List<EmbeddingMatch<TextSegment>> relevant = result.matches();
@@ -104,10 +93,8 @@ public class EmbeddingServiceImpl implements EmbeddingService {
try {
List<String> queryIds =
queries.stream()
.map(textSegment -> TextSegmentConvert.getQueryId(textSegment))
.filter(Objects::nonNull)
.collect(Collectors.toList());
queries.stream().map(textSegment -> TextSegmentConvert.getQueryId(textSegment))
.filter(Objects::nonNull).collect(Collectors.toList());
if (CollectionUtils.isNotEmpty(queryIds)) {
MetadataFilterBuilder filterBuilder =
new MetadataFilterBuilder(TextSegmentConvert.QUERY_ID);
@@ -122,21 +109,15 @@ public class EmbeddingServiceImpl implements EmbeddingService {
}
@Override
public List<RetrieveQueryResult> retrieveQuery(
String collectionName, RetrieveQuery retrieveQuery, int num) {
public List<RetrieveQueryResult> retrieveQuery(String collectionName,
RetrieveQuery retrieveQuery, int num) {
EmbeddingStore embeddingStore =
EmbeddingStoreFactoryProvider.getFactory().create(collectionName);
EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel();
Map<String, Object> filterCondition = retrieveQuery.getFilterCondition();
return retrieveQuery.getQueryTextsList().stream()
.map(
queryText ->
retrieveSingleQuery(
queryText,
embeddingModel,
embeddingStore,
filterCondition,
num))
return retrieveQuery
.getQueryTextsList().stream().map(queryText -> retrieveSingleQuery(queryText,
embeddingModel, embeddingStore, filterCondition, num))
.collect(Collectors.toList());
}
@@ -152,28 +133,17 @@ public class EmbeddingServiceImpl implements EmbeddingService {
cache.invalidateAll();
}
private RetrieveQueryResult retrieveSingleQuery(
String queryText,
EmbeddingModel embeddingModel,
EmbeddingStore embeddingStore,
Map<String, Object> filterCondition,
int num) {
private RetrieveQueryResult retrieveSingleQuery(String queryText, EmbeddingModel embeddingModel,
EmbeddingStore embeddingStore, Map<String, Object> filterCondition, int num) {
Embedding embeddedText = embeddingModel.embed(queryText).content();
Filter filter = createCombinedFilter(filterCondition);
EmbeddingSearchRequest request =
EmbeddingSearchRequest.builder()
.queryEmbedding(embeddedText)
.filter(filter)
.maxResults(num)
.build();
EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
.queryEmbedding(embeddedText).filter(filter).maxResults(num).build();
EmbeddingSearchResult<TextSegment> result = embeddingStore.search(request);
List<Retrieval> retrievals =
result.matches().stream()
.map(this::convertToRetrieval)
.sorted(Comparator.comparingDouble(Retrieval::getSimilarity))
.limit(num)
.collect(Collectors.toList());
List<Retrieval> retrievals = result.matches().stream().map(this::convertToRetrieval)
.sorted(Comparator.comparingDouble(Retrieval::getSimilarity)).limit(num)
.collect(Collectors.toList());
RetrieveQueryResult retrieveQueryResult = new RetrieveQueryResult();
retrieveQueryResult.setQuery(queryText);
@@ -209,10 +179,8 @@ public class EmbeddingServiceImpl implements EmbeddingService {
// Create an OR filter for each value in the list
for (String value : (List<String>) fieldValue) {
IsEqualTo equalToFilter = new IsEqualTo(fieldName, value);
fieldFilter =
(fieldFilter == null)
? equalToFilter
: Filter.or(fieldFilter, equalToFilter);
fieldFilter = (fieldFilter == null) ? equalToFilter
: Filter.or(fieldFilter, equalToFilter);
}
} else if (fieldValue instanceof String) {
// Create a simple equality filter
@@ -220,10 +188,8 @@ public class EmbeddingServiceImpl implements EmbeddingService {
}
// Combine the current field filter with the overall filter using AND logic
if (fieldFilter != null) {
combinedFilter =
(combinedFilter == null)
? fieldFilter
: Filter.and(combinedFilter, fieldFilter);
combinedFilter = (combinedFilter == null) ? fieldFilter
: Filter.and(combinedFilter, fieldFilter);
}
}
return combinedFilter;

View File

@@ -35,14 +35,15 @@ public class ExemplarServiceImpl implements ExemplarService, CommandLineRunner {
private final ObjectMapper objectMapper = JsonUtil.INSTANCE.getObjectMapper();
@Autowired private EmbeddingConfig embeddingConfig;
@Autowired
private EmbeddingConfig embeddingConfig;
@Autowired private EmbeddingService embeddingService;
@Autowired
private EmbeddingService embeddingService;
public void storeExemplar(String collection, Text2SQLExemplar exemplar) {
Metadata metadata =
Metadata.from(
JsonUtil.toMap(JsonUtil.toString(exemplar), String.class, Object.class));
Metadata metadata = Metadata
.from(JsonUtil.toMap(JsonUtil.toString(exemplar), String.class, Object.class));
TextSegment segment = TextSegment.from(exemplar.getQuestion(), metadata);
TextSegmentConvert.addQueryId(segment, exemplar.getQuestion());
@@ -50,9 +51,8 @@ public class ExemplarServiceImpl implements ExemplarService, CommandLineRunner {
}
public void removeExemplar(String collection, Text2SQLExemplar exemplar) {
Metadata metadata =
Metadata.from(
JsonUtil.toMap(JsonUtil.toString(exemplar), String.class, Object.class));
Metadata metadata = Metadata
.from(JsonUtil.toMap(JsonUtil.toString(exemplar), String.class, Object.class));
TextSegment segment = TextSegment.from(exemplar.getQuestion(), metadata);
TextSegmentConvert.addQueryId(segment, exemplar.getQuestion());
@@ -70,18 +70,11 @@ public class ExemplarServiceImpl implements ExemplarService, CommandLineRunner {
RetrieveQuery.builder().queryTextsList(Lists.newArrayList(query)).build();
List<RetrieveQueryResult> results =
embeddingService.retrieveQuery(collection, retrieveQuery, num);
results.stream()
.forEach(
ret -> {
ret.getRetrieval().stream()
.forEach(
r -> {
exemplars.add(
JsonUtil.mapToObject(
r.getMetadata(),
Text2SQLExemplar.class));
});
});
results.stream().forEach(ret -> {
ret.getRetrieval().stream().forEach(r -> {
exemplars.add(JsonUtil.mapToObject(r.getMetadata(), Text2SQLExemplar.class));
});
});
return exemplars;
}

View File

@@ -21,7 +21,8 @@ import java.util.concurrent.atomic.AtomicReference;
public class SystemConfigServiceImpl extends ServiceImpl<SystemConfigMapper, SystemConfigDO>
implements SystemConfigService {
@Autowired private Environment environment;
@Autowired
private Environment environment;
// Cache field to store the system configuration
private AtomicReference<SystemConfig> cachedSystemConfig = new AtomicReference<>();
@@ -44,13 +45,11 @@ public class SystemConfigServiceImpl extends ServiceImpl<SystemConfigMapper, Sys
systemConfig.setId(1);
systemConfig.init();
// use system property to initialize system parameter
systemConfig.getParameters().stream()
.forEach(
p -> {
if (environment.containsProperty(p.getName())) {
p.setValue(environment.getProperty(p.getName()));
}
});
systemConfig.getParameters().stream().forEach(p -> {
if (environment.containsProperty(p.getName())) {
p.setValue(environment.getProperty(p.getName()));
}
});
save(systemConfig);
return systemConfig;
}
@@ -68,9 +67,8 @@ public class SystemConfigServiceImpl extends ServiceImpl<SystemConfigMapper, Sys
private SystemConfig convert(SystemConfigDO systemConfigDO) {
SystemConfig sysParameter = new SystemConfig();
sysParameter.setId(systemConfigDO.getId());
List<Parameter> parameters =
JsonUtil.toObject(
systemConfigDO.getParameters(), new TypeReference<List<Parameter>>() {});
List<Parameter> parameters = JsonUtil.toObject(systemConfigDO.getParameters(),
new TypeReference<List<Parameter>>() {});
sysParameter.setParameters(parameters);
sysParameter.setAdminList(systemConfigDO.getAdmin());
return sysParameter;

View File

@@ -1,13 +1,12 @@
package com.tencent.supersonic.common.util;
import lombok.extern.slf4j.Slf4j;
import javax.crypto.Cipher;
import javax.crypto.SecretKeyFactory;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.PBEKeySpec;
import javax.crypto.spec.SecretKeySpec;
import lombok.extern.slf4j.Slf4j;
import java.security.MessageDigest;
import java.security.spec.KeySpec;
import java.util.Arrays;
@@ -121,10 +120,8 @@ public class AESEncryptionUtil {
int len = hexString.length();
byte[] byteArray = new byte[len / 2];
for (int i = 0; i < len; i += 2) {
byteArray[i / 2] =
(byte)
((Character.digit(hexString.charAt(i), 16) << 4)
+ Character.digit(hexString.charAt(i + 1), 16));
byteArray[i / 2] = (byte) ((Character.digit(hexString.charAt(i), 16) << 4)
+ Character.digit(hexString.charAt(i + 1), 16));
}
return byteArray;
}

View File

@@ -62,12 +62,10 @@ public class DateModeUtils {
* @return
*/
public String hasDataModeStr(ItemDateResp dateDate, DateConf dateInfo) {
if (Objects.isNull(dateDate)
|| StringUtils.isEmpty(dateDate.getStartDate())
if (Objects.isNull(dateDate) || StringUtils.isEmpty(dateDate.getStartDate())
|| StringUtils.isEmpty(dateDate.getStartDate())) {
return String.format(
"(%s >= '%s' and %s <= '%s')",
sysDateCol, dateInfo.getStartDate(), sysDateCol, dateInfo.getEndDate());
return String.format("(%s >= '%s' and %s <= '%s')", sysDateCol, dateInfo.getStartDate(),
sysDateCol, dateInfo.getEndDate());
} else {
log.info("dateDate:{}", dateDate);
}
@@ -81,31 +79,22 @@ public class DateModeUtils {
if (endReq.isAfter(endData)) {
if (DatePeriodEnum.DAY.equals(dateInfo.getPeriod())) {
Long unit =
getInterval(
dateInfo.getStartDate(),
dateInfo.getEndDate(),
dateFormatStr,
ChronoUnit.DAYS);
Long unit = getInterval(dateInfo.getStartDate(), dateInfo.getEndDate(),
dateFormatStr, ChronoUnit.DAYS);
LocalDate dateMax = endData;
LocalDate dateMin = dateMax.minusDays(unit - 1);
return String.format(
"(%s >= '%s' and %s <= '%s')", sysDateCol, dateMin, sysDateCol, dateMax);
return String.format("(%s >= '%s' and %s <= '%s')", sysDateCol, dateMin, sysDateCol,
dateMax);
}
if (DatePeriodEnum.MONTH.equals(dateInfo.getPeriod())) {
Long unit =
getInterval(
dateInfo.getStartDate(),
dateInfo.getEndDate(),
dateFormatStr,
ChronoUnit.MONTHS);
Long unit = getInterval(dateInfo.getStartDate(), dateInfo.getEndDate(),
dateFormatStr, ChronoUnit.MONTHS);
return generateMonthSql(endData, unit, dateFormatStr);
}
}
return String.format(
"(%s >= '%s' and %s <= '%s')",
sysDateCol, dateInfo.getStartDate(), sysDateCol, dateInfo.getEndDate());
return String.format("(%s >= '%s' and %s <= '%s')", sysDateCol, dateInfo.getStartDate(),
sysDateCol, dateInfo.getEndDate());
}
public String generateMonthSql(LocalDate endData, Long unit, String dateFormatStr) {
@@ -131,9 +120,8 @@ public class DateModeUtils {
public String recentDayStr(ItemDateResp dateDate, DateConf dateInfo) {
ImmutablePair<String, String> dayRange = recentDay(dateDate, dateInfo);
return String.format(
"(%s >= '%s' and %s <= '%s')",
sysDateCol, dayRange.left, sysDateCol, dayRange.right);
return String.format("(%s >= '%s' and %s <= '%s')", sysDateCol, dayRange.left, sysDateCol,
dayRange.right);
}
public ImmutablePair<String, String> recentDay(ItemDateResp dateDate, DateConf dateInfo) {
@@ -143,7 +131,7 @@ public class DateModeUtils {
}
DateTimeFormatter formatter = DateTimeFormatter.ofPattern(dateFormatStr);
LocalDate end = LocalDate.parse(dateDate.getEndDate(), formatter);
// todo unavailableDateList logic
// todo unavailableDateList logic
Integer unit = dateInfo.getUnit() - 1;
String start = end.minusDays(unit).format(formatter);
@@ -154,16 +142,15 @@ public class DateModeUtils {
DateTimeFormatter formatter = DateTimeFormatter.ofPattern(dateFormatStr);
String endStr = endData.format(formatter);
String start = endData.minusMonths(unit).format(formatter);
return String.format(
"(%s >= '%s' and %s <= '%s')", sysDateMonthCol, start, sysDateMonthCol, endStr);
return String.format("(%s >= '%s' and %s <= '%s')", sysDateMonthCol, start, sysDateMonthCol,
endStr);
}
public String recentMonthStr(ItemDateResp dateDate, DateConf dateInfo) {
List<ImmutablePair<String, String>> range = recentMonth(dateDate, dateInfo);
if (range.size() == 1) {
return String.format(
"(%s >= '%s' and %s <= '%s')",
sysDateMonthCol, range.get(0).left, sysDateMonthCol, range.get(0).right);
return String.format("(%s >= '%s' and %s <= '%s')", sysDateMonthCol, range.get(0).left,
sysDateMonthCol, range.get(0).right);
}
if (range.size() > 0) {
StringJoiner joiner = new StringJoiner(",");
@@ -173,21 +160,15 @@ public class DateModeUtils {
return "";
}
public List<ImmutablePair<String, String>> recentMonth(
ItemDateResp dateDate, DateConf dateInfo) {
LocalDate endData =
LocalDate.parse(
dateDate.getEndDate(),
DateTimeFormatter.ofPattern(dateDate.getDateFormat()));
public List<ImmutablePair<String, String>> recentMonth(ItemDateResp dateDate,
DateConf dateInfo) {
LocalDate endData = LocalDate.parse(dateDate.getEndDate(),
DateTimeFormatter.ofPattern(dateDate.getDateFormat()));
List<ImmutablePair<String, String>> ret = new ArrayList<>();
if (dateDate.getDatePeriod() != null
&& DatePeriodEnum.MONTH.equals(dateDate.getDatePeriod())) {
Long unit =
getInterval(
dateInfo.getStartDate(),
dateInfo.getEndDate(),
dateDate.getDateFormat(),
ChronoUnit.MONTHS);
Long unit = getInterval(dateInfo.getStartDate(), dateInfo.getEndDate(),
dateDate.getDateFormat(), ChronoUnit.MONTHS);
LocalDate dateMax = endData;
List<String> months = generateMonthStr(dateMax, unit, dateDate.getDateFormat());
if (!CollectionUtils.isEmpty(months)) {
@@ -207,16 +188,14 @@ public class DateModeUtils {
public String recentWeekStr(LocalDate endData, Long unit) {
DateTimeFormatter formatter = DateTimeFormatter.ofPattern(DAY_FORMAT);
String start = endData.minusDays(unit * 7).format(formatter);
return String.format(
"(%s >= '%s' and %s <= '%s')",
sysDateWeekCol, start, sysDateWeekCol, endData.format(formatter));
return String.format("(%s >= '%s' and %s <= '%s')", sysDateWeekCol, start, sysDateWeekCol,
endData.format(formatter));
}
public String recentWeekStr(ItemDateResp dateDate, DateConf dateInfo) {
ImmutablePair<String, String> dayRange = recentWeek(dateDate, dateInfo);
return String.format(
"(%s >= '%s' and %s <= '%s')",
sysDateWeekCol, dayRange.left, sysDateWeekCol, dayRange.right);
return String.format("(%s >= '%s' and %s <= '%s')", sysDateWeekCol, dayRange.left,
sysDateWeekCol, dayRange.right);
}
public ImmutablePair<String, String> recentWeek(ItemDateResp dateDate, DateConf dateInfo) {
@@ -231,8 +210,8 @@ public class DateModeUtils {
return ImmutablePair.of(start, end.format(formatter));
}
private Long getInterval(
String startDate, String endDate, String dateFormat, ChronoUnit chronoUnit) {
private Long getInterval(String startDate, String endDate, String dateFormat,
ChronoUnit chronoUnit) {
DateTimeFormatter formatter = DateTimeFormatter.ofPattern(dateFormat);
try {
LocalDate start = LocalDate.parse(startDate, formatter);
@@ -270,34 +249,23 @@ public class DateModeUtils {
if (DatePeriodEnum.MONTH.equals(dateInfo.getPeriod())) {
// startDate YYYYMM
if (!dateInfo.getStartDate().contains(Constants.MINUS)) {
return String.format(
"%s >= '%s' and %s <= '%s'",
sysDateMonthCol,
dateInfo.getStartDate(),
sysDateMonthCol,
dateInfo.getEndDate());
return String.format("%s >= '%s' and %s <= '%s'", sysDateMonthCol,
dateInfo.getStartDate(), sysDateMonthCol, dateInfo.getEndDate());
}
LocalDate endData =
LocalDate.parse(dateInfo.getEndDate(), DateTimeFormatter.ofPattern(DAY_FORMAT));
LocalDate startData =
LocalDate.parse(
dateInfo.getStartDate(), DateTimeFormatter.ofPattern(DAY_FORMAT));
LocalDate startData = LocalDate.parse(dateInfo.getStartDate(),
DateTimeFormatter.ofPattern(DAY_FORMAT));
DateTimeFormatter formatter = DateTimeFormatter.ofPattern(MONTH_FORMAT);
return String.format(
"%s >= '%s' and %s <= '%s'",
sysDateMonthCol,
startData.format(formatter),
sysDateMonthCol,
endData.format(formatter));
return String.format("%s >= '%s' and %s <= '%s'", sysDateMonthCol,
startData.format(formatter), sysDateMonthCol, endData.format(formatter));
}
if (DatePeriodEnum.WEEK.equals(dateInfo.getPeriod())) {
return String.format(
"%s >= '%s' and %s <= '%s'",
sysDateWeekCol, dateInfo.getStartDate(), sysDateWeekCol, dateInfo.getEndDate());
return String.format("%s >= '%s' and %s <= '%s'", sysDateWeekCol,
dateInfo.getStartDate(), sysDateWeekCol, dateInfo.getEndDate());
}
return String.format(
"%s >= '%s' and %s <= '%s'",
sysDateCol, dateInfo.getStartDate(), sysDateCol, dateInfo.getEndDate());
return String.format("%s >= '%s' and %s <= '%s'", sysDateCol, dateInfo.getStartDate(),
sysDateCol, dateInfo.getEndDate());
}
/**
@@ -335,8 +303,8 @@ public class DateModeUtils {
if (DatePeriodEnum.DAY.equals(dateInfo.getPeriod())) {
LocalDate dateMax = LocalDate.now().minusDays(1);
LocalDate dateMin = dateMax.minusDays(unit - 1);
return String.format(
"(%s >= '%s' and %s <= '%s')", sysDateCol, dateMin, sysDateCol, dateMax);
return String.format("(%s >= '%s' and %s <= '%s')", sysDateCol, dateMin, sysDateCol,
dateMax);
}
if (DatePeriodEnum.WEEK.equals(dateInfo.getPeriod())) {
@@ -352,9 +320,8 @@ public class DateModeUtils {
return recentMonthStr(dateMax, unit.longValue() * 12, MONTH_FORMAT);
}
return String.format(
"(%s >= '%s' and %s <= '%s')",
sysDateCol, LocalDate.now().minusDays(2), sysDateCol, LocalDate.now().minusDays(1));
return String.format("(%s >= '%s' and %s <= '%s')", sysDateCol,
LocalDate.now().minusDays(2), sysDateCol, LocalDate.now().minusDays(1));
}
public String getDateWhereStr(DateConf dateInfo) {

View File

@@ -90,8 +90,8 @@ public class DateUtils {
return startDate.format(DEFAULT_DATE_FORMATTER2);
}
public static String getBeforeDate(
String currentDate, int intervalDay, DatePeriodEnum datePeriodEnum) {
public static String getBeforeDate(String currentDate, int intervalDay,
DatePeriodEnum datePeriodEnum) {
LocalDate specifiedDate = LocalDate.parse(currentDate, DEFAULT_DATE_FORMATTER2);
LocalDate result = null;
switch (datePeriodEnum) {
@@ -101,9 +101,8 @@ public class DateUtils {
case WEEK:
result = specifiedDate.minusWeeks(intervalDay);
if (intervalDay == 0) {
result =
result.with(
TemporalAdjusters.previousOrSame(java.time.DayOfWeek.MONDAY));
result = result
.with(TemporalAdjusters.previousOrSame(java.time.DayOfWeek.MONDAY));
}
break;
case MONTH:
@@ -115,14 +114,13 @@ public class DateUtils {
case QUARTER:
result = specifiedDate.minusMonths(intervalDay * 3L);
if (intervalDay == 0) {
TemporalAdjuster firstDayOfQuarter =
temporal -> {
LocalDate tempDate = LocalDate.from(temporal);
int month = tempDate.get(ChronoField.MONTH_OF_YEAR);
int firstMonthOfQuarter = ((month - 1) / 3) * 3 + 1;
return tempDate.with(ChronoField.MONTH_OF_YEAR, firstMonthOfQuarter)
.with(TemporalAdjusters.firstDayOfMonth());
};
TemporalAdjuster firstDayOfQuarter = temporal -> {
LocalDate tempDate = LocalDate.from(temporal);
int month = tempDate.get(ChronoField.MONTH_OF_YEAR);
int firstMonthOfQuarter = ((month - 1) / 3) * 3 + 1;
return tempDate.with(ChronoField.MONTH_OF_YEAR, firstMonthOfQuarter)
.with(TemporalAdjusters.firstDayOfMonth());
};
result = result.with(firstDayOfQuarter);
}
break;
@@ -162,8 +160,8 @@ public class DateUtils {
return !timeString.equals("00:00:00");
}
public static List<String> getDateList(
String startDateStr, String endDateStr, DatePeriodEnum period) {
public static List<String> getDateList(String startDateStr, String endDateStr,
DatePeriodEnum period) {
try {
LocalDate startDate = LocalDate.parse(startDateStr);
LocalDate endDate = LocalDate.parse(endDateStr);

View File

@@ -22,12 +22,8 @@ public class FileUtils {
return -1;
}
File file = new File(path);
Optional<Long> lastModified =
Arrays.stream(file.listFiles())
.filter(f -> f.isFile())
.map(f -> f.lastModified())
.sorted(Collections.reverseOrder())
.findFirst();
Optional<Long> lastModified = Arrays.stream(file.listFiles()).filter(f -> f.isFile())
.map(f -> f.lastModified()).sorted(Collections.reverseOrder()).findFirst();
if (lastModified.isPresent()) {
return lastModified.get();
@@ -42,8 +38,8 @@ public class FileUtils {
return null;
}
public static void scanDirectory(
File file, int maxLevel, Map<Integer, List<File>> directories) {
public static void scanDirectory(File file, int maxLevel,
Map<Integer, List<File>> directories) {
if (maxLevel < 0) {
return;
}

View File

@@ -77,34 +77,25 @@ public class HttpClientUtils {
private static void init() {
try {
SSLConnectionSocketFactory sslConnectionSocketFactory =
new SSLConnectionSocketFactory(
SSLContexts.custom()
.loadTrustMaterial((chain, authType) -> true)
.build(),
new String[] {"SSLv2Hello", "SSLv3", "TLSv1", "TLSv1.1", "TLSv1.2"},
null,
NoopHostnameVerifier.INSTANCE);
SSLConnectionSocketFactory sslConnectionSocketFactory = new SSLConnectionSocketFactory(
SSLContexts.custom().loadTrustMaterial((chain, authType) -> true).build(),
new String[] {"SSLv2Hello", "SSLv3", "TLSv1", "TLSv1.1", "TLSv1.2"}, null,
NoopHostnameVerifier.INSTANCE);
PoolingHttpClientConnectionManager connManager =
new PoolingHttpClientConnectionManager(
RegistryBuilder.<ConnectionSocketFactory>create()
.register(
"http", PlainConnectionSocketFactory.getSocketFactory())
.register("https", sslConnectionSocketFactory)
.build());
PoolingHttpClientConnectionManager connManager = new PoolingHttpClientConnectionManager(
RegistryBuilder.<ConnectionSocketFactory>create()
.register("http", PlainConnectionSocketFactory.getSocketFactory())
.register("https", sslConnectionSocketFactory).build());
connManager.setMaxTotal(DEFAULT_MAX_TOTAL_CONN);
connManager.setDefaultMaxPerRoute(DEFAULT_MAX_CONN_PERHOST);
RequestConfig requestConfig =
RequestConfig.custom()
// 请求超时时间
.setConnectTimeout(DEFAULT_CONNECTION_TIMEOUT)
// 等待数据超时时间
.setSocketTimeout(DEFAULT_READ_TIMEOUT)
// 连接不够用时等待超时时间
.setConnectionRequestTimeout(DEFAULT_CONN_REQUEST_TIMEOUT)
.build();
RequestConfig requestConfig = RequestConfig.custom()
// 请求超时时间
.setConnectTimeout(DEFAULT_CONNECTION_TIMEOUT)
// 等待数据超时时间
.setSocketTimeout(DEFAULT_READ_TIMEOUT)
// 连接不够用时等待超时时间
.setConnectionRequestTimeout(DEFAULT_CONN_REQUEST_TIMEOUT).build();
HttpRequestRetryHandler httpRequestRetryHandler =
(exception, executionCount, context) -> {
@@ -116,49 +107,39 @@ public class HttpClientUtils {
}
if (exception instanceof NoHttpResponseException) {
// 如果服务器丢掉了连接,那么就重试
log.warn(
"Retry, No response from server on {} error: {}",
executionCount,
exception.getMessage());
log.warn("Retry, No response from server on {} error: {}",
executionCount, exception.getMessage());
return true;
} else if (exception instanceof SocketException) {
// 如果服务器断开了连接,那么就重试
log.warn(
"Retry, No connection from server on {} error: {}",
executionCount,
exception.getMessage());
log.warn("Retry, No connection from server on {} error: {}",
executionCount, exception.getMessage());
return true;
}
return false;
};
httpClient =
HttpClients.custom()
// 设置连接池
.setConnectionManager(connManager)
// 设置超时时间
.setDefaultRequestConfig(requestConfig)
// 设置连接存活时间
.setKeepAliveStrategy(
new DefaultConnectionKeepAliveStrategy() {
@Override
public long getKeepAliveDuration(
final HttpResponse response,
final HttpContext context) {
long keepAlive =
super.getKeepAliveDuration(response, context);
if (keepAlive == -1) {
keepAlive = 5000;
}
return keepAlive;
}
})
.setRetryHandler(httpRequestRetryHandler)
// 设置连接存活时间
.setConnectionTimeToLive(5000L, TimeUnit.MILLISECONDS)
// 关闭无效和空闲的连接
.evictIdleConnections(5L, TimeUnit.SECONDS)
.build();
httpClient = HttpClients.custom()
// 设置连接池
.setConnectionManager(connManager)
// 设置超时时间
.setDefaultRequestConfig(requestConfig)
// 设置连接存活时间
.setKeepAliveStrategy(new DefaultConnectionKeepAliveStrategy() {
@Override
public long getKeepAliveDuration(final HttpResponse response,
final HttpContext context) {
long keepAlive = super.getKeepAliveDuration(response, context);
if (keepAlive == -1) {
keepAlive = 5000;
}
return keepAlive;
}
}).setRetryHandler(httpRequestRetryHandler)
// 设置连接存活时间
.setConnectionTimeToLive(5000L, TimeUnit.MILLISECONDS)
// 关闭无效和空闲的连接
.evictIdleConnections(5L, TimeUnit.SECONDS).build();
} catch (Exception e) {
log.error(e.getMessage(), e);
throw new RuntimeException(e);
@@ -193,45 +174,34 @@ public class HttpClientUtils {
*
* @return
*/
public static HttpClientResult doPost(
String url,
String proxyHost,
Integer proxyPort,
Map<String, String> headers,
Map<String, String> params) {
return RetryUtils.exec(
() -> {
HttpPost httpPost = null;
CloseableHttpResponse response = null;
try {
httpPost = new HttpPost(url);
setProxy(httpPost, proxyHost, proxyPort);
public static HttpClientResult doPost(String url, String proxyHost, Integer proxyPort,
Map<String, String> headers, Map<String, String> params) {
return RetryUtils.exec(() -> {
HttpPost httpPost = null;
CloseableHttpResponse response = null;
try {
httpPost = new HttpPost(url);
setProxy(httpPost, proxyHost, proxyPort);
// 封装header参数
packageHeader(headers, httpPost);
// 封装请求参数
packageParam(params, httpPost);
// 封装header参数
packageHeader(headers, httpPost);
// 封装请求参数
packageParam(params, httpPost);
response = httpClient.execute(httpPost);
// 获取返回结果
HttpClientResult result = getHttpClientResult(response);
log.info(
"uri:{}, req:{}, resp:{}",
url,
"headers:" + getHeaders(httpPost) + "------params:" + params,
result);
return result;
} catch (Exception e) {
log.error(
"uri:{}, req:{}",
url,
"headers:" + headers + "------params:" + params,
e);
throw new RuntimeException(e.getMessage());
} finally {
close(httpPost, response);
}
});
response = httpClient.execute(httpPost);
// 获取返回结果
HttpClientResult result = getHttpClientResult(response);
log.info("uri:{}, req:{}, resp:{}", url,
"headers:" + getHeaders(httpPost) + "------params:" + params, result);
return result;
} catch (Exception e) {
log.error("uri:{}, req:{}", url, "headers:" + headers + "------params:" + params,
e);
throw new RuntimeException(e.getMessage());
} finally {
close(httpPost, response);
}
});
}
/**
@@ -242,8 +212,8 @@ public class HttpClientUtils {
* @param params
* @return
*/
public static HttpClientResult doPost(
String url, Map<String, String> header, Map<String, String> params) {
public static HttpClientResult doPost(String url, Map<String, String> header,
Map<String, String> params) {
return doPost(url, null, null, header, params);
}
@@ -279,53 +249,42 @@ public class HttpClientUtils {
* @return
* @throws Exception
*/
public static HttpClientResult doGet(
String url,
String proxyHost,
Integer proxyPort,
Map<String, String> headers,
Map<String, String> params) {
return RetryUtils.exec(
() -> {
HttpGet httpGet = null;
CloseableHttpResponse response = null;
try {
// 创建访问的地址
URIBuilder uriBuilder = new URIBuilder(url);
if (params != null) {
Set<Map.Entry<String, String>> entrySet = params.entrySet();
for (Map.Entry<String, String> entry : entrySet) {
uriBuilder.setParameter(entry.getKey(), entry.getValue());
}
}
httpGet = new HttpGet(uriBuilder.build());
setProxy(httpGet, proxyHost, proxyPort);
// 设置请求头
packageHeader(headers, httpGet);
response = httpClient.execute(httpGet);
// 获取返回结果
HttpClientResult res = getHttpClientResult(response);
log.debug(
"GET uri:{}, req:{}, resp:{}",
url,
"headers:" + getHeaders(httpGet) + "------params:" + params,
res);
return res;
} catch (Exception e) {
log.error(
"GET error! uri:{}, req:{}",
url,
"headers:" + headers + "------params:" + params,
e);
throw new RuntimeException(e.getMessage());
} finally {
close(httpGet, response);
public static HttpClientResult doGet(String url, String proxyHost, Integer proxyPort,
Map<String, String> headers, Map<String, String> params) {
return RetryUtils.exec(() -> {
HttpGet httpGet = null;
CloseableHttpResponse response = null;
try {
// 创建访问的地址
URIBuilder uriBuilder = new URIBuilder(url);
if (params != null) {
Set<Map.Entry<String, String>> entrySet = params.entrySet();
for (Map.Entry<String, String> entry : entrySet) {
uriBuilder.setParameter(entry.getKey(), entry.getValue());
}
});
}
httpGet = new HttpGet(uriBuilder.build());
setProxy(httpGet, proxyHost, proxyPort);
// 设置请求头
packageHeader(headers, httpGet);
response = httpClient.execute(httpGet);
// 获取返回结果
HttpClientResult res = getHttpClientResult(response);
log.debug("GET uri:{}, req:{}, resp:{}", url,
"headers:" + getHeaders(httpGet) + "------params:" + params, res);
return res;
} catch (Exception e) {
log.error("GET error! uri:{}, req:{}", url,
"headers:" + headers + "------params:" + params, e);
throw new RuntimeException(e.getMessage());
} finally {
close(httpGet, response);
}
});
}
/**
@@ -336,8 +295,8 @@ public class HttpClientUtils {
* @param params
* @return
*/
public static HttpClientResult doGet(
String url, Map<String, String> header, Map<String, String> params) {
public static HttpClientResult doGet(String url, Map<String, String> header,
Map<String, String> params) {
return doGet(url, null, null, header, params);
}
@@ -399,9 +358,8 @@ public class HttpClientUtils {
* @param httpMethod
* @throws UnsupportedEncodingException
*/
public static void packageParam(
Map<String, String> params, HttpEntityEnclosingRequestBase httpMethod)
throws UnsupportedEncodingException {
public static void packageParam(Map<String, String> params,
HttpEntityEnclosingRequestBase httpMethod) throws UnsupportedEncodingException {
if (params != null) {
List<NameValuePair> nvps = new ArrayList<NameValuePair>();
Set<Map.Entry<String, String>> entrySet = params.entrySet();
@@ -416,13 +374,9 @@ public class HttpClientUtils {
public static void setProxy(HttpRequestBase httpMethod, String proxyHost, Integer proxyPort) {
if (!StringUtils.isEmpty(proxyHost) && proxyPort != null) {
RequestConfig config =
RequestConfig.custom()
.setProxy(new HttpHost(proxyHost, proxyPort))
.setConnectTimeout(10000)
.setSocketTimeout(10000)
.setConnectionRequestTimeout(3000)
.build();
RequestConfig config = RequestConfig.custom()
.setProxy(new HttpHost(proxyHost, proxyPort)).setConnectTimeout(10000)
.setSocketTimeout(10000).setConnectionRequestTimeout(3000).build();
httpMethod.setConfig(config);
}
}
@@ -437,50 +391,39 @@ public class HttpClientUtils {
*
* @return
*/
public static HttpClientResult doPostJSON(
String url,
String proxyHost,
Integer proxyPort,
Map<String, String> headers,
String req) {
return RetryUtils.exec(
() -> {
HttpPost httpPost = null;
CloseableHttpResponse response = null;
try {
httpPost = new HttpPost(url);
setProxy(httpPost, proxyHost, proxyPort);
public static HttpClientResult doPostJSON(String url, String proxyHost, Integer proxyPort,
Map<String, String> headers, String req) {
return RetryUtils.exec(() -> {
HttpPost httpPost = null;
CloseableHttpResponse response = null;
try {
httpPost = new HttpPost(url);
setProxy(httpPost, proxyHost, proxyPort);
// 封装header参数
packageHeader(headers, httpPost);
httpPost.setHeader("Content-Type", "application/json;charset=UTF-8");
// 封装header参数
packageHeader(headers, httpPost);
httpPost.setHeader("Content-Type", "application/json;charset=UTF-8");
// 封装请求参数
StringEntity stringEntity = new StringEntity(req, ENCODING); // 解决中文乱码问题
stringEntity.setContentEncoding("UTF-8");
// 封装请求参数
StringEntity stringEntity = new StringEntity(req, ENCODING); // 解决中文乱码问题
stringEntity.setContentEncoding("UTF-8");
httpPost.setEntity(stringEntity);
httpPost.setEntity(stringEntity);
response = httpClient.execute(httpPost);
// 获取返回结果
HttpClientResult res = getHttpClientResult(response);
log.info(
"doPostJSON uri:{}, req:{}, resp:{}",
url,
"headers:" + getHeaders(httpPost) + "------req:" + req,
res);
return res;
} catch (Exception e) {
log.error(
"doPostJSON error! uri:{}, req:{}",
url,
"headers:" + headers + "------req:" + req,
e);
throw new RuntimeException(e.getMessage());
} finally {
close(httpPost, response);
}
});
response = httpClient.execute(httpPost);
// 获取返回结果
HttpClientResult res = getHttpClientResult(response);
log.info("doPostJSON uri:{}, req:{}, resp:{}", url,
"headers:" + getHeaders(httpPost) + "------req:" + req, res);
return res;
} catch (Exception e) {
log.error("doPostJSON error! uri:{}, req:{}", url,
"headers:" + headers + "------req:" + req, e);
throw new RuntimeException(e.getMessage());
} finally {
close(httpPost, response);
}
});
}
public static HttpClientResult doPostJSON(String url, String req) {
@@ -488,56 +431,45 @@ public class HttpClientUtils {
}
/** get json */
public static HttpClientResult doGetJSON(
String url,
String proxyHost,
Integer proxyPort,
Map<String, String> headers,
Map<String, String> params) {
return RetryUtils.exec(
() -> {
HttpGet httpGet = null;
CloseableHttpResponse response = null;
try {
// 创建访问的地址
URIBuilder uriBuilder = new URIBuilder(url);
if (params != null) {
Set<Map.Entry<String, String>> entrySet = params.entrySet();
for (Map.Entry<String, String> entry : entrySet) {
uriBuilder.setParameter(entry.getKey(), entry.getValue());
}
}
httpGet = new HttpGet(uriBuilder.build());
setProxy(httpGet, proxyHost, proxyPort);
// 设置请求头
packageHeader(headers, httpGet);
httpGet.setHeader("Content-Type", "application/json;charset=UTF-8");
response = httpClient.execute(httpGet);
// 获取返回结果
HttpClientResult res = getHttpClientResult(response);
log.info(
"doGetJSON uri:{}, req:{}, resp:{}",
url,
"headers:" + getHeaders(httpGet) + "------params:" + params,
res);
return res;
} catch (Exception e) {
log.warn(
"doGetJSON error! uri:{}, req:{}",
url,
"headers:" + headers + "------params:" + params,
e);
throw new RuntimeException(e.getMessage());
} finally {
close(httpGet, response);
public static HttpClientResult doGetJSON(String url, String proxyHost, Integer proxyPort,
Map<String, String> headers, Map<String, String> params) {
return RetryUtils.exec(() -> {
HttpGet httpGet = null;
CloseableHttpResponse response = null;
try {
// 创建访问的地址
URIBuilder uriBuilder = new URIBuilder(url);
if (params != null) {
Set<Map.Entry<String, String>> entrySet = params.entrySet();
for (Map.Entry<String, String> entry : entrySet) {
uriBuilder.setParameter(entry.getKey(), entry.getValue());
}
});
}
httpGet = new HttpGet(uriBuilder.build());
setProxy(httpGet, proxyHost, proxyPort);
// 设置请求头
packageHeader(headers, httpGet);
httpGet.setHeader("Content-Type", "application/json;charset=UTF-8");
response = httpClient.execute(httpGet);
// 获取返回结果
HttpClientResult res = getHttpClientResult(response);
log.info("doGetJSON uri:{}, req:{}, resp:{}", url,
"headers:" + getHeaders(httpGet) + "------params:" + params, res);
return res;
} catch (Exception e) {
log.warn("doGetJSON error! uri:{}, req:{}", url,
"headers:" + headers + "------params:" + params, e);
throw new RuntimeException(e.getMessage());
} finally {
close(httpGet, response);
}
});
}
private static HttpClientResult getHttpClientResult(CloseableHttpResponse response)
@@ -564,82 +496,63 @@ public class HttpClientUtils {
* @param fullFilePath
* @return
*/
public static HttpClientResult doFileUploadBodyParams(
String url,
Map<String, String> headers,
Map<String, String> bodyParams,
String fullFilePath) {
public static HttpClientResult doFileUploadBodyParams(String url, Map<String, String> headers,
Map<String, String> bodyParams, String fullFilePath) {
return doFileUpload(url, null, null, headers, null, bodyParams, fullFilePath);
}
public static HttpClientResult doFileUpload(
String url,
String proxyHost,
Integer proxyPort,
Map<String, String> headers,
Map<String, String> params,
Map<String, String> bodyParams,
public static HttpClientResult doFileUpload(String url, String proxyHost, Integer proxyPort,
Map<String, String> headers, Map<String, String> params, Map<String, String> bodyParams,
String fullFilePath) {
return RetryUtils.exec(
() -> {
InputStream inputStream = null;
CloseableHttpResponse response = null;
HttpPost httpPost = null;
try {
return RetryUtils.exec(() -> {
InputStream inputStream = null;
CloseableHttpResponse response = null;
HttpPost httpPost = null;
try {
File uploadFile = new File(fullFilePath);
inputStream = new FileInputStream(uploadFile);
File uploadFile = new File(fullFilePath);
inputStream = new FileInputStream(uploadFile);
httpPost = new HttpPost(url);
setProxy(httpPost, proxyHost, proxyPort);
httpPost = new HttpPost(url);
setProxy(httpPost, proxyHost, proxyPort);
packageHeader(headers, httpPost);
packageHeader(headers, httpPost);
HttpEntity entity =
getFileUploadHttpEntity(
params, bodyParams, inputStream, uploadFile.getName());
httpPost.setEntity(entity);
HttpEntity entity = getFileUploadHttpEntity(params, bodyParams, inputStream,
uploadFile.getName());
httpPost.setEntity(entity);
response = httpClient.execute(httpPost);
// 执行请求并获得响应结果
HttpClientResult res = getHttpClientResult(response);
log.info(
"doFileUpload uri:{}, req:{}, resp:{}",
url,
"params:" + params + ", fullFilePath:" + fullFilePath,
res);
return res;
} catch (Exception e) {
log.error(
"doFileUpload error! uri:{}, req:{}",
url,
"params:" + params + ", fullFilePath:" + fullFilePath,
e);
throw new RuntimeException(e.getMessage());
} finally {
try {
if (null != inputStream) {
inputStream.close();
}
// 释放资源
close(httpPost, response);
} catch (IOException e) {
log.error("HttpClientUtils release error!", e);
}
response = httpClient.execute(httpPost);
// 执行请求并获得响应结果
HttpClientResult res = getHttpClientResult(response);
log.info("doFileUpload uri:{}, req:{}, resp:{}", url,
"params:" + params + ", fullFilePath:" + fullFilePath, res);
return res;
} catch (Exception e) {
log.error("doFileUpload error! uri:{}, req:{}", url,
"params:" + params + ", fullFilePath:" + fullFilePath, e);
throw new RuntimeException(e.getMessage());
} finally {
try {
if (null != inputStream) {
inputStream.close();
}
});
// 释放资源
close(httpPost, response);
} catch (IOException e) {
log.error("HttpClientUtils release error!", e);
}
}
});
}
private static HttpEntity getFileUploadHttpEntity(
Map<String, String> params,
Map<String, String> bodyParams,
InputStream inputStream,
String fileName)
private static HttpEntity getFileUploadHttpEntity(Map<String, String> params,
Map<String, String> bodyParams, InputStream inputStream, String fileName)
throws UnsupportedEncodingException {
MultipartEntityBuilder builder = MultipartEntityBuilder.create();
builder.setMode(HttpMultipartMode.BROWSER_COMPATIBLE);
builder.addBinaryBody(
"file", inputStream, ContentType.create("multipart/form-data"), fileName);
builder.addBinaryBody("file", inputStream, ContentType.create("multipart/form-data"),
fileName);
if (!CollectionUtils.isEmpty(bodyParams)) {
for (String bodyParamsKey : bodyParams.keySet()) {
@@ -649,8 +562,7 @@ public class HttpClientUtils {
// 构建请求参数 普通表单项
if (!CollectionUtils.isEmpty(params)) {
for (Map.Entry<String, String> entry : params.entrySet()) {
builder.addPart(
entry.getKey(),
builder.addPart(entry.getKey(),
new StringBody(entry.getValue(), ContentType.MULTIPART_FORM_DATA));
}
}
@@ -668,41 +580,34 @@ public class HttpClientUtils {
* @return
*/
public static HttpClientResult doDelete(String url, Map<String, String> headers, String req) {
return RetryUtils.exec(
() -> {
HttpDeleteWithBody httpDelete = null;
CloseableHttpResponse response = null;
try {
httpDelete = new HttpDeleteWithBody(url);
// 封装header参数
packageHeader(headers, httpDelete);
httpDelete.setHeader("Content-Type", "application/json;charset=UTF-8");
// 封装请求参数
StringEntity stringEntity = new StringEntity(req, ENCODING); // 解决中文乱码问题
stringEntity.setContentEncoding("UTF-8");
return RetryUtils.exec(() -> {
HttpDeleteWithBody httpDelete = null;
CloseableHttpResponse response = null;
try {
httpDelete = new HttpDeleteWithBody(url);
// 封装header参数
packageHeader(headers, httpDelete);
httpDelete.setHeader("Content-Type", "application/json;charset=UTF-8");
// 封装请求参数
StringEntity stringEntity = new StringEntity(req, ENCODING); // 解决中文乱码问题
stringEntity.setContentEncoding("UTF-8");
httpDelete.setEntity(stringEntity);
httpDelete.setEntity(stringEntity);
response = httpClient.execute(httpDelete);
response = httpClient.execute(httpDelete);
HttpClientResult res = getHttpClientResult(response);
log.info(
"doDeleteJSON uri:{}, req:{}, resp:{}",
url,
"headers:" + getHeaders(httpDelete) + "------req:" + req,
res);
return res;
} catch (Exception e) {
log.error(
"doDeleteJSON error! uri:{}, req:{}",
url,
"headers:" + headers + "------req:" + req,
e);
throw new RuntimeException(e.getMessage());
} finally {
close(httpDelete, response);
}
});
HttpClientResult res = getHttpClientResult(response);
log.info("doDeleteJSON uri:{}, req:{}, resp:{}", url,
"headers:" + getHeaders(httpDelete) + "------req:" + req, res);
return res;
} catch (Exception e) {
log.error("doDeleteJSON error! uri:{}, req:{}", url,
"headers:" + headers + "------req:" + req, e);
throw new RuntimeException(e.getMessage());
} finally {
close(httpDelete, response);
}
});
}
private static class HttpDeleteWithBody extends HttpEntityEnclosingRequestBase {
@@ -730,37 +635,30 @@ public class HttpClientUtils {
}
public static HttpClientResult doPutJson(String url, Map<String, String> headers, String req) {
return RetryUtils.exec(
() -> {
HttpPut httpPut = null;
CloseableHttpResponse response = null;
try {
httpPut = new HttpPut(url);
// 封装header参数
packageHeader(headers, httpPut);
httpPut.setHeader("Content-Type", "application/json;charset=UTF-8");
// 封装请求参数
StringEntity stringEntity = new StringEntity(req, ENCODING); // 解决中文乱码问题
stringEntity.setContentEncoding("UTF-8");
httpPut.setEntity(stringEntity);
response = httpClient.execute(httpPut);
HttpClientResult res = getHttpClientResult(response);
log.info(
"doPutJSON uri:{}, req:{}, resp:{}",
url,
"headers:" + getHeaders(httpPut) + "------req:" + req,
res);
return res;
} catch (Exception e) {
log.error(
"doPutJSON error! uri:{}, req:{}",
url,
"headers:" + headers + "------req:" + req,
e);
throw new RuntimeException(e.getMessage());
} finally {
close(httpPut, response);
}
});
return RetryUtils.exec(() -> {
HttpPut httpPut = null;
CloseableHttpResponse response = null;
try {
httpPut = new HttpPut(url);
// 封装header参数
packageHeader(headers, httpPut);
httpPut.setHeader("Content-Type", "application/json;charset=UTF-8");
// 封装请求参数
StringEntity stringEntity = new StringEntity(req, ENCODING); // 解决中文乱码问题
stringEntity.setContentEncoding("UTF-8");
httpPut.setEntity(stringEntity);
response = httpClient.execute(httpPut);
HttpClientResult res = getHttpClientResult(response);
log.info("doPutJSON uri:{}, req:{}, resp:{}", url,
"headers:" + getHeaders(httpPut) + "------req:" + req, res);
return res;
} catch (Exception e) {
log.error("doPutJSON error! uri:{}, req:{}", url,
"headers:" + headers + "------req:" + req, e);
throw new RuntimeException(e.getMessage());
} finally {
close(httpPut, response);
}
});
}
}

View File

@@ -30,7 +30,8 @@ public class JsonUtil {
public static final JsonUtil INSTANCE = new JsonUtil();
@Getter private final ObjectMapper objectMapper = new ObjectMapper();
@Getter
private final ObjectMapper objectMapper = new ObjectMapper();
public JsonUtil() {
// 当属性为null时不参与序列化
@@ -400,10 +401,8 @@ public class JsonUtil {
try {
notNull(keyClass, "key class is null");
notNull(valueClass, "value class is null");
JavaType type =
objectMapper
.getTypeFactory()
.constructParametricType(Map.class, keyClass, valueClass);
JavaType type = objectMapper.getTypeFactory().constructParametricType(Map.class,
keyClass, valueClass);
return objectMapper.readValue(json, type);
} catch (Exception e) {
throw new JsonException(e);
@@ -503,8 +502,7 @@ public class JsonUtil {
}
try {
JsonNode jsonNode = readTree(string);
return objectMapper
.writerWithDefaultPrettyPrinter()
return objectMapper.writerWithDefaultPrettyPrinter()
.writeValueAsString(jsonNode);
} catch (Exception e) {
return string;
@@ -617,10 +615,7 @@ public class JsonUtil {
super(cause);
}
private JsonException(
String message,
Throwable cause,
boolean enableSuppression,
private JsonException(String message, Throwable cause, boolean enableSuppression,
boolean writableStackTrace) {
super(message, cause, enableSuppression, writableStackTrace);
}

View File

@@ -12,8 +12,8 @@ public class SignatureUtils {
return DigestUtils.sha1Hex(psw);
}
public static Pair<Boolean, String> isValidSignature(
String appKey, String appSecret, long timestamp, String signatureToCheck) {
public static Pair<Boolean, String> isValidSignature(String appKey, String appSecret,
long timestamp, String signatureToCheck) {
long currentTimeMillis = System.currentTimeMillis();
if (currentTimeMillis < timestamp) {

View File

@@ -59,13 +59,11 @@ public class SqlFilterUtils {
StringJoiner joiner = new StringJoiner(Constants.AND_UPPER);
if (!CollectionUtils.isEmpty(filters)) {
filters.stream()
.forEach(
filter -> {
if (StringUtils.isNotEmpty(dealFilter(filter, isBizName))) {
joiner.add(SPACE + dealFilter(filter, isBizName) + SPACE);
}
});
filters.stream().forEach(filter -> {
if (StringUtils.isNotEmpty(dealFilter(filter, isBizName))) {
joiner.add(SPACE + dealFilter(filter, isBizName) + SPACE);
}
});
log.debug("getWhereClause, where sql : {}", joiner);
return joiner.toString();
}
@@ -160,8 +158,8 @@ public class SqlFilterUtils {
throw new RuntimeException("criterion.getValue() can not be null");
}
StringBuilder whereClause = new StringBuilder();
whereClause.append(
criterion.getColumn() + SPACE + criterion.getOperator().getValue() + SPACE);
whereClause
.append(criterion.getColumn() + SPACE + criterion.getOperator().getValue() + SPACE);
String value = criterion.getValue().toString();
if (criterion.isNeedApostrophe() && !Pattern.matches(pattern, value)) {
// like click => 'like%'
@@ -170,10 +168,9 @@ public class SqlFilterUtils {
} else {
// like 'click' => 'like%'
whereClause.append(
Constants.APOSTROPHE
+ value.replaceAll(Constants.APOSTROPHE, Constants.PERCENT_SIGN)
+ Constants.APOSTROPHE);
whereClause.append(Constants.APOSTROPHE
+ value.replaceAll(Constants.APOSTROPHE, Constants.PERCENT_SIGN)
+ Constants.APOSTROPHE);
}
return whereClause.toString();
}
@@ -184,8 +181,8 @@ public class SqlFilterUtils {
}
StringBuilder whereClause = new StringBuilder();
whereClause.append(
criterion.getColumn() + SPACE + criterion.getOperator().getValue() + SPACE);
whereClause
.append(criterion.getColumn() + SPACE + criterion.getOperator().getValue() + SPACE);
List values = (List) criterion.getValue();
whereClause.append(PARENTHESES_START);
StringJoiner joiner = new StringJoiner(",");
@@ -209,19 +206,12 @@ public class SqlFilterUtils {
}
if (criterion.isNeedApostrophe()) {
return String.format(
"(%s >= %s and %s <= %s)",
criterion.getColumn(),
valueApostropheLogic(values.get(0).toString()),
criterion.getColumn(),
return String.format("(%s >= %s and %s <= %s)", criterion.getColumn(),
valueApostropheLogic(values.get(0).toString()), criterion.getColumn(),
valueApostropheLogic(values.get(1).toString()));
}
return String.format(
"(%s >= %s and %s <= %s)",
criterion.getColumn(),
values.get(0).toString(),
criterion.getColumn(),
values.get(1).toString());
return String.format("(%s >= %s and %s <= %s)", criterion.getColumn(),
values.get(0).toString(), criterion.getColumn(), values.get(1).toString());
}
private String singleValueLogic(Criterion criterion) {
@@ -229,8 +219,8 @@ public class SqlFilterUtils {
throw new RuntimeException("criterion.getValue() can not be null");
}
StringBuilder whereClause = new StringBuilder();
whereClause.append(
criterion.getColumn() + SPACE + criterion.getOperator().getValue() + SPACE);
whereClause
.append(criterion.getColumn() + SPACE + criterion.getOperator().getValue() + SPACE);
String value = criterion.getValue().toString();
if (criterion.isNeedApostrophe()) {
value = valueApostropheLogic(value);
@@ -258,10 +248,7 @@ public class SqlFilterUtils {
if (Objects.isNull(criterion) || Objects.isNull(criterion.getValue())) {
throw new RuntimeException("criterion.getValue() can not be null");
}
return PARENTHESES_START
+ SPACE
+ criterion.getValue().toString()
+ SPACE
return PARENTHESES_START + SPACE + criterion.getValue().toString() + SPACE
+ PARENTHESES_END;
}
}

View File

@@ -28,7 +28,7 @@ public class StringUtil {
* @param v1
* @param v2
* @return value 0 if v1 equal to v2; less than 0 if v1 is less than v2; greater than 0 if v1 is
* greater than v2
* greater than v2
*/
public static int compareVersion(String v1, String v2) {
String[] v1s = v1.split("\\.");

View File

@@ -12,8 +12,8 @@ public class ThreadMdcUtil {
}
}
public static <T> Callable<T> wrap(
final Callable<T> callable, final Map<String, String> context) {
public static <T> Callable<T> wrap(final Callable<T> callable,
final Map<String, String> context) {
return () -> {
if (context == null) {
MDC.clear();

View File

@@ -51,10 +51,8 @@ public class YamlUtils {
.disable(YAMLGenerator.Feature.LITERAL_BLOCK_STYLE);
try {
String yaml = mapper.writeValueAsString(object);
return yaml.replaceAll("\"True\"", "true")
.replaceAll("\"true\"", "true")
.replaceAll("\"false\"", "false")
.replaceAll("\"False\"", "false");
return yaml.replaceAll("\"True\"", "true").replaceAll("\"true\"", "true")
.replaceAll("\"false\"", "false").replaceAll("\"False\"", "false");
} catch (IOException e) {
log.error("", e);
}

View File

@@ -24,11 +24,8 @@ public class ChromaEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
@Override
public EmbeddingStore createEmbeddingStore(String collectionName) {
return ChromaEmbeddingStore.builder()
.baseUrl(storeProperties.getBaseUrl())
.collectionName(collectionName)
.timeout(storeProperties.getTimeout())
.build();
return ChromaEmbeddingStore.builder().baseUrl(storeProperties.getBaseUrl())
.collectionName(collectionName).timeout(storeProperties.getTimeout()).build();
}
private static EmbeddingStoreProperties createPropertiesFromConfig(

View File

@@ -12,5 +12,6 @@ public class Properties {
static final String PREFIX = "langchain4j.chroma";
@NestedConfigurationProperty EmbeddingStoreProperties embeddingStore;
@NestedConfigurationProperty
EmbeddingStoreProperties embeddingStore;
}

View File

@@ -20,18 +20,15 @@ public class DashscopeAutoConfig {
@ConditionalOnProperty(PREFIX + ".chat-model.api-key")
QwenChatModel qwenChatModel(Properties properties) {
ChatModelProperties chatModelProperties = properties.getChatModel();
return QwenChatModel.builder()
.baseUrl(chatModelProperties.getBaseUrl())
return QwenChatModel.builder().baseUrl(chatModelProperties.getBaseUrl())
.apiKey(chatModelProperties.getApiKey())
.modelName(chatModelProperties.getModelName())
.topP(chatModelProperties.getTopP())
.modelName(chatModelProperties.getModelName()).topP(chatModelProperties.getTopP())
.topK(chatModelProperties.getTopK())
.enableSearch(chatModelProperties.getEnableSearch())
.seed(chatModelProperties.getSeed())
.repetitionPenalty(chatModelProperties.getRepetitionPenalty())
.temperature(chatModelProperties.getTemperature())
.stops(chatModelProperties.getStops())
.maxTokens(chatModelProperties.getMaxTokens())
.stops(chatModelProperties.getStops()).maxTokens(chatModelProperties.getMaxTokens())
.build();
}
@@ -39,18 +36,15 @@ public class DashscopeAutoConfig {
@ConditionalOnProperty(PREFIX + ".streaming-chat-model.api-key")
QwenStreamingChatModel qwenStreamingChatModel(Properties properties) {
ChatModelProperties chatModelProperties = properties.getStreamingChatModel();
return QwenStreamingChatModel.builder()
.baseUrl(chatModelProperties.getBaseUrl())
return QwenStreamingChatModel.builder().baseUrl(chatModelProperties.getBaseUrl())
.apiKey(chatModelProperties.getApiKey())
.modelName(chatModelProperties.getModelName())
.topP(chatModelProperties.getTopP())
.modelName(chatModelProperties.getModelName()).topP(chatModelProperties.getTopP())
.topK(chatModelProperties.getTopK())
.enableSearch(chatModelProperties.getEnableSearch())
.seed(chatModelProperties.getSeed())
.repetitionPenalty(chatModelProperties.getRepetitionPenalty())
.temperature(chatModelProperties.getTemperature())
.stops(chatModelProperties.getStops())
.maxTokens(chatModelProperties.getMaxTokens())
.stops(chatModelProperties.getStops()).maxTokens(chatModelProperties.getMaxTokens())
.build();
}
@@ -58,47 +52,33 @@ public class DashscopeAutoConfig {
@ConditionalOnProperty(PREFIX + ".language-model.api-key")
QwenLanguageModel qwenLanguageModel(Properties properties) {
ChatModelProperties languageModel = properties.getLanguageModel();
return QwenLanguageModel.builder()
.baseUrl(languageModel.getBaseUrl())
.apiKey(languageModel.getApiKey())
.modelName(languageModel.getModelName())
.topP(languageModel.getTopP())
.topK(languageModel.getTopK())
.enableSearch(languageModel.getEnableSearch())
.seed(languageModel.getSeed())
return QwenLanguageModel.builder().baseUrl(languageModel.getBaseUrl())
.apiKey(languageModel.getApiKey()).modelName(languageModel.getModelName())
.topP(languageModel.getTopP()).topK(languageModel.getTopK())
.enableSearch(languageModel.getEnableSearch()).seed(languageModel.getSeed())
.repetitionPenalty(languageModel.getRepetitionPenalty())
.temperature(languageModel.getTemperature())
.stops(languageModel.getStops())
.maxTokens(languageModel.getMaxTokens())
.build();
.temperature(languageModel.getTemperature()).stops(languageModel.getStops())
.maxTokens(languageModel.getMaxTokens()).build();
}
@Bean
@ConditionalOnProperty(PREFIX + ".streaming-language-model.api-key")
QwenStreamingLanguageModel qwenStreamingLanguageModel(Properties properties) {
ChatModelProperties languageModel = properties.getStreamingLanguageModel();
return QwenStreamingLanguageModel.builder()
.baseUrl(languageModel.getBaseUrl())
.apiKey(languageModel.getApiKey())
.modelName(languageModel.getModelName())
.topP(languageModel.getTopP())
.topK(languageModel.getTopK())
.enableSearch(languageModel.getEnableSearch())
.seed(languageModel.getSeed())
return QwenStreamingLanguageModel.builder().baseUrl(languageModel.getBaseUrl())
.apiKey(languageModel.getApiKey()).modelName(languageModel.getModelName())
.topP(languageModel.getTopP()).topK(languageModel.getTopK())
.enableSearch(languageModel.getEnableSearch()).seed(languageModel.getSeed())
.repetitionPenalty(languageModel.getRepetitionPenalty())
.temperature(languageModel.getTemperature())
.stops(languageModel.getStops())
.maxTokens(languageModel.getMaxTokens())
.build();
.temperature(languageModel.getTemperature()).stops(languageModel.getStops())
.maxTokens(languageModel.getMaxTokens()).build();
}
@Bean
@ConditionalOnProperty(PREFIX + ".embedding-model.api-key")
QwenEmbeddingModel qwenEmbeddingModel(Properties properties) {
EmbeddingModelProperties embeddingModelProperties = properties.getEmbeddingModel();
return QwenEmbeddingModel.builder()
.apiKey(embeddingModelProperties.getApiKey())
.modelName(embeddingModelProperties.getModelName())
.build();
return QwenEmbeddingModel.builder().apiKey(embeddingModelProperties.getApiKey())
.modelName(embeddingModelProperties.getModelName()).build();
}
}

View File

@@ -12,13 +12,18 @@ public class Properties {
static final String PREFIX = "langchain4j.dashscope";
@NestedConfigurationProperty ChatModelProperties chatModel;
@NestedConfigurationProperty
ChatModelProperties chatModel;
@NestedConfigurationProperty ChatModelProperties streamingChatModel;
@NestedConfigurationProperty
ChatModelProperties streamingChatModel;
@NestedConfigurationProperty ChatModelProperties languageModel;
@NestedConfigurationProperty
ChatModelProperties languageModel;
@NestedConfigurationProperty ChatModelProperties streamingLanguageModel;
@NestedConfigurationProperty
ChatModelProperties streamingLanguageModel;
@NestedConfigurationProperty EmbeddingModelProperties embeddingModel;
@NestedConfigurationProperty
EmbeddingModelProperties embeddingModel;
}

View File

@@ -74,8 +74,8 @@ public class InMemoryEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
if (MapUtils.isEmpty(super.collectionNameToStore)) {
return;
}
for (Map.Entry<String, EmbeddingStore<TextSegment>> entry :
collectionNameToStore.entrySet()) {
for (Map.Entry<String, EmbeddingStore<TextSegment>> entry : collectionNameToStore
.entrySet()) {
Path filePath = getPersistPath(entry.getKey());
if (Objects.isNull(filePath)) {
continue;

View File

@@ -12,7 +12,9 @@ public class Properties {
static final String PREFIX = "langchain4j.in-memory";
@NestedConfigurationProperty EmbeddingStoreProperties embeddingStore;
@NestedConfigurationProperty
EmbeddingStoreProperties embeddingStore;
@NestedConfigurationProperty EmbeddingModelProperties embeddingModel;
@NestedConfigurationProperty
EmbeddingModelProperties embeddingModel;
}

View File

@@ -20,70 +20,58 @@ public class LocalAiAutoConfig {
@ConditionalOnProperty(PREFIX + ".chat-model.base-url")
LocalAiChatModel localAiChatModel(Properties properties) {
ChatModelProperties chatModelProperties = properties.getChatModel();
return LocalAiChatModel.builder()
.baseUrl(chatModelProperties.getBaseUrl())
return LocalAiChatModel.builder().baseUrl(chatModelProperties.getBaseUrl())
.modelName(chatModelProperties.getModelName())
.temperature(chatModelProperties.getTemperature())
.topP(chatModelProperties.getTopP())
.maxRetries(chatModelProperties.getMaxRetries())
.topP(chatModelProperties.getTopP()).maxRetries(chatModelProperties.getMaxRetries())
.logRequests(chatModelProperties.getLogRequests())
.logResponses(chatModelProperties.getLogResponses())
.build();
.logResponses(chatModelProperties.getLogResponses()).build();
}
@Bean
@ConditionalOnProperty(PREFIX + ".streaming-chat-model.base-url")
LocalAiStreamingChatModel localAiStreamingChatModel(Properties properties) {
ChatModelProperties chatModelProperties = properties.getStreamingChatModel();
return LocalAiStreamingChatModel.builder()
.temperature(chatModelProperties.getTemperature())
.topP(chatModelProperties.getTopP())
.baseUrl(chatModelProperties.getBaseUrl())
return LocalAiStreamingChatModel.builder().temperature(chatModelProperties.getTemperature())
.topP(chatModelProperties.getTopP()).baseUrl(chatModelProperties.getBaseUrl())
.modelName(chatModelProperties.getModelName())
.logRequests(chatModelProperties.getLogRequests())
.logResponses(chatModelProperties.getLogResponses())
.build();
.logResponses(chatModelProperties.getLogResponses()).build();
}
@Bean
@ConditionalOnProperty(PREFIX + ".language-model.base-url")
LocalAiLanguageModel localAiLanguageModel(Properties properties) {
LanguageModelProperties languageModelProperties = properties.getLanguageModel();
return LocalAiLanguageModel.builder()
.topP(languageModelProperties.getTopP())
return LocalAiLanguageModel.builder().topP(languageModelProperties.getTopP())
.baseUrl(languageModelProperties.getBaseUrl())
.modelName(languageModelProperties.getModelName())
.temperature(languageModelProperties.getTemperature())
.maxRetries(languageModelProperties.getMaxRetries())
.logRequests(languageModelProperties.getLogRequests())
.logResponses(languageModelProperties.getLogResponses())
.build();
.logResponses(languageModelProperties.getLogResponses()).build();
}
@Bean
@ConditionalOnProperty(PREFIX + ".streaming-language-model.base-url")
LocalAiStreamingLanguageModel localAiStreamingLanguageModel(Properties properties) {
LanguageModelProperties languageModelProperties = properties.getStreamingLanguageModel();
return LocalAiStreamingLanguageModel.builder()
.topP(languageModelProperties.getTopP())
return LocalAiStreamingLanguageModel.builder().topP(languageModelProperties.getTopP())
.baseUrl(languageModelProperties.getBaseUrl())
.modelName(languageModelProperties.getModelName())
.temperature(languageModelProperties.getTemperature())
.logRequests(languageModelProperties.getLogRequests())
.logResponses(languageModelProperties.getLogResponses())
.build();
.logResponses(languageModelProperties.getLogResponses()).build();
}
@Bean
@ConditionalOnProperty(PREFIX + ".embedding-model.base-url")
LocalAiEmbeddingModel localAiEmbeddingModel(Properties properties) {
EmbeddingModelProperties embeddingModelProperties = properties.getEmbeddingModel();
return LocalAiEmbeddingModel.builder()
.baseUrl(embeddingModelProperties.getBaseUrl())
return LocalAiEmbeddingModel.builder().baseUrl(embeddingModelProperties.getBaseUrl())
.modelName(embeddingModelProperties.getModelName())
.maxRetries(embeddingModelProperties.getMaxRetries())
.logRequests(embeddingModelProperties.getLogRequests())
.logResponses(embeddingModelProperties.getLogResponses())
.build();
.logResponses(embeddingModelProperties.getLogResponses()).build();
}
}

View File

@@ -12,13 +12,18 @@ public class Properties {
static final String PREFIX = "langchain4j.local-ai";
@NestedConfigurationProperty ChatModelProperties chatModel;
@NestedConfigurationProperty
ChatModelProperties chatModel;
@NestedConfigurationProperty ChatModelProperties streamingChatModel;
@NestedConfigurationProperty
ChatModelProperties streamingChatModel;
@NestedConfigurationProperty LanguageModelProperties languageModel;
@NestedConfigurationProperty
LanguageModelProperties languageModel;
@NestedConfigurationProperty LanguageModelProperties streamingLanguageModel;
@NestedConfigurationProperty
LanguageModelProperties streamingLanguageModel;
@NestedConfigurationProperty EmbeddingModelProperties embeddingModel;
@NestedConfigurationProperty
EmbeddingModelProperties embeddingModel;
}

View File

@@ -29,21 +29,15 @@ public class MilvusEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
@Override
public EmbeddingStore<TextSegment> createEmbeddingStore(String collectionName) {
return MilvusEmbeddingStore.builder()
.host(storeProperties.getHost())
.port(storeProperties.getPort())
.collectionName(collectionName)
.dimension(storeProperties.getDimension())
.indexType(storeProperties.getIndexType())
.metricType(storeProperties.getMetricType())
.uri(storeProperties.getUri())
.token(storeProperties.getToken())
.username(storeProperties.getUsername())
return MilvusEmbeddingStore.builder().host(storeProperties.getHost())
.port(storeProperties.getPort()).collectionName(collectionName)
.dimension(storeProperties.getDimension()).indexType(storeProperties.getIndexType())
.metricType(storeProperties.getMetricType()).uri(storeProperties.getUri())
.token(storeProperties.getToken()).username(storeProperties.getUsername())
.password(storeProperties.getPassword())
.consistencyLevel(storeProperties.getConsistencyLevel())
.retrieveEmbeddingsOnSearch(storeProperties.getRetrieveEmbeddingsOnSearch())
.autoFlushOnInsert(storeProperties.getAutoFlushOnInsert())
.databaseName(storeProperties.getDatabaseName())
.build();
.databaseName(storeProperties.getDatabaseName()).build();
}
}

View File

@@ -12,5 +12,6 @@ public class Properties {
static final String PREFIX = "langchain4j.milvus";
@NestedConfigurationProperty EmbeddingStoreProperties embeddingStore;
@NestedConfigurationProperty
EmbeddingStoreProperties embeddingStore;
}

View File

@@ -13,10 +13,10 @@ import java.util.Objects;
/**
* An embedding model that runs within your Java application's process. Any BERT-based model (e.g.,
* from HuggingFace) can be used, as long as it is in ONNX format. Information on how to convert
* models into ONNX format can be found <a
* href="https://huggingface.co/docs/optimum/exporters/onnx/usage_guides/export_a_model">here</a>.
* Many models already converted to ONNX format are available <a
* href="https://huggingface.co/Xenova">here</a>. Copy from
* models into ONNX format can be found <a href=
* "https://huggingface.co/docs/optimum/exporters/onnx/usage_guides/export_a_model">here</a>. Many
* models already converted to ONNX format are available
* <a href="https://huggingface.co/Xenova">here</a>. Copy from
* dev.langchain4j.model.embedding.OnnxEmbeddingModel.
*/
public class S2OnnxEmbeddingModel extends AbstractInProcessEmbeddingModel {
@@ -28,9 +28,8 @@ public class S2OnnxEmbeddingModel extends AbstractInProcessEmbeddingModel {
if (shouldReloadModel(pathToModel, vocabularyPath)) {
synchronized (S2OnnxEmbeddingModel.class) {
if (shouldReloadModel(pathToModel, vocabularyPath)) {
URL resource =
AbstractInProcessEmbeddingModel.class.getResource(
"/bert-vocabulary-en.txt");
URL resource = AbstractInProcessEmbeddingModel.class
.getResource("/bert-vocabulary-en.txt");
if (StringUtils.isNotBlank(vocabularyPath)) {
try {
resource = Paths.get(vocabularyPath).toUri().toURL();
@@ -56,15 +55,14 @@ public class S2OnnxEmbeddingModel extends AbstractInProcessEmbeddingModel {
}
private static boolean shouldReloadModel(String pathToModel, String vocabularyPath) {
return cachedModel == null
|| !Objects.equals(cachedModelPath, pathToModel)
return cachedModel == null || !Objects.equals(cachedModelPath, pathToModel)
|| !Objects.equals(cachedVocabularyPath, vocabularyPath);
}
static OnnxBertBiEncoder loadFromFileSystem(Path pathToModel, URL vocabularyFile) {
try {
return new OnnxBertBiEncoder(
Files.newInputStream(pathToModel), vocabularyFile, PoolingMode.MEAN);
return new OnnxBertBiEncoder(Files.newInputStream(pathToModel), vocabularyFile,
PoolingMode.MEAN);
} catch (IOException e) {
throw new RuntimeException(e);
}

View File

@@ -60,8 +60,8 @@ import static java.util.Collections.singletonList;
/**
* Represents an OpenAI language model with a chat completion interface, such as gpt-3.5-turbo and
* gpt-4. You can find description of parameters <a
* href="https://platform.openai.com/docs/api-reference/chat/create">here</a>.
* gpt-4. You can find description of parameters
* <a href="https://platform.openai.com/docs/api-reference/chat/create">here</a>.
*/
@Slf4j
public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
@@ -88,32 +88,13 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
private final List<ChatModelListener> listeners;
@Builder
public OpenAiChatModel(
String baseUrl,
String apiKey,
String organizationId,
String modelName,
Double temperature,
Double topP,
List<String> stop,
Integer maxTokens,
Double presencePenalty,
Double frequencyPenalty,
Map<String, Integer> logitBias,
String responseFormat,
Boolean strictJsonSchema,
Integer seed,
String user,
Boolean strictTools,
Boolean parallelToolCalls,
Duration timeout,
Integer maxRetries,
Proxy proxy,
Boolean logRequests,
Boolean logResponses,
Tokenizer tokenizer,
Map<String, String> customHeaders,
List<ChatModelListener> listeners) {
public OpenAiChatModel(String baseUrl, String apiKey, String organizationId, String modelName,
Double temperature, Double topP, List<String> stop, Integer maxTokens,
Double presencePenalty, Double frequencyPenalty, Map<String, Integer> logitBias,
String responseFormat, Boolean strictJsonSchema, Integer seed, String user,
Boolean strictTools, Boolean parallelToolCalls, Duration timeout, Integer maxRetries,
Proxy proxy, Boolean logRequests, Boolean logResponses, Tokenizer tokenizer,
Map<String, String> customHeaders, List<ChatModelListener> listeners) {
baseUrl = getOrDefault(baseUrl, OPENAI_URL);
if (OPENAI_DEMO_API_KEY.equals(apiKey)) {
@@ -123,21 +104,11 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
timeout = getOrDefault(timeout, ofSeconds(60));
this.client =
OpenAiClient.builder()
.openAiApiKey(apiKey)
.baseUrl(baseUrl)
.organizationId(organizationId)
.callTimeout(timeout)
.connectTimeout(timeout)
.readTimeout(timeout)
.writeTimeout(timeout)
.proxy(proxy)
.logRequests(logRequests)
.logResponses(logResponses)
.userAgent(DEFAULT_USER_AGENT)
.customHeaders(customHeaders)
.build();
this.client = OpenAiClient.builder().openAiApiKey(apiKey).baseUrl(baseUrl)
.organizationId(organizationId).callTimeout(timeout).connectTimeout(timeout)
.readTimeout(timeout).writeTimeout(timeout).proxy(proxy).logRequests(logRequests)
.logResponses(logResponses).userAgent(DEFAULT_USER_AGENT)
.customHeaders(customHeaders).build();
this.modelName = getOrDefault(modelName, GPT_3_5_TURBO);
this.temperature = getOrDefault(temperature, 0.7);
this.topP = topP;
@@ -146,14 +117,10 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
this.presencePenalty = presencePenalty;
this.frequencyPenalty = frequencyPenalty;
this.logitBias = logitBias;
this.responseFormat =
responseFormat == null
? null
: ResponseFormat.builder()
.type(
ResponseFormatType.valueOf(
responseFormat.toUpperCase(Locale.ROOT)))
.build();
this.responseFormat = responseFormat == null ? null
: ResponseFormat.builder()
.type(ResponseFormatType.valueOf(responseFormat.toUpperCase(Locale.ROOT)))
.build();
this.strictJsonSchema = getOrDefault(strictJsonSchema, false);
this.seed = seed;
this.user = user;
@@ -183,61 +150,44 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
}
@Override
public Response<AiMessage> generate(
List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
public Response<AiMessage> generate(List<ChatMessage> messages,
List<ToolSpecification> toolSpecifications) {
return generate(messages, toolSpecifications, null, this.responseFormat);
}
@Override
public Response<AiMessage> generate(
List<ChatMessage> messages, ToolSpecification toolSpecification) {
return generate(
messages, singletonList(toolSpecification), toolSpecification, this.responseFormat);
public Response<AiMessage> generate(List<ChatMessage> messages,
ToolSpecification toolSpecification) {
return generate(messages, singletonList(toolSpecification), toolSpecification,
this.responseFormat);
}
@Override
public ChatResponse chat(ChatRequest request) {
Response<AiMessage> response =
generate(
request.messages(),
request.toolSpecifications(),
null,
generate(request.messages(), request.toolSpecifications(), null,
getOrDefault(
toOpenAiResponseFormat(request.responseFormat(), strictJsonSchema),
this.responseFormat));
return ChatResponse.builder()
.aiMessage(response.content())
.tokenUsage(response.tokenUsage())
.finishReason(response.finishReason())
.build();
return ChatResponse.builder().aiMessage(response.content())
.tokenUsage(response.tokenUsage()).finishReason(response.finishReason()).build();
}
private Response<AiMessage> generate(
List<ChatMessage> messages,
List<ToolSpecification> toolSpecifications,
ToolSpecification toolThatMustBeExecuted,
private Response<AiMessage> generate(List<ChatMessage> messages,
List<ToolSpecification> toolSpecifications, ToolSpecification toolThatMustBeExecuted,
ResponseFormat responseFormat) {
if (responseFormat != null
&& responseFormat.type() == JSON_SCHEMA
if (responseFormat != null && responseFormat.type() == JSON_SCHEMA
&& responseFormat.jsonSchema() == null) {
responseFormat = null;
}
ChatCompletionRequest.Builder requestBuilder =
ChatCompletionRequest.builder()
.model(modelName)
.messages(toOpenAiMessages(messages))
.topP(topP)
.stop(stop)
.maxTokens(maxTokens)
.presencePenalty(presencePenalty)
.frequencyPenalty(frequencyPenalty)
.logitBias(logitBias)
.responseFormat(responseFormat)
.seed(seed)
.user(user)
.parallelToolCalls(parallelToolCalls);
ChatCompletionRequest.Builder requestBuilder = ChatCompletionRequest.builder()
.model(modelName).messages(toOpenAiMessages(messages)).topP(topP).stop(stop)
.maxTokens(maxTokens).presencePenalty(presencePenalty)
.frequencyPenalty(frequencyPenalty).logitBias(logitBias)
.responseFormat(responseFormat).seed(seed).user(user)
.parallelToolCalls(parallelToolCalls);
if (!(baseUrl.contains(ZHIPU))) {
requestBuilder.temperature(temperature);
@@ -257,40 +207,33 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
Map<Object, Object> attributes = new ConcurrentHashMap<>();
ChatModelRequestContext requestContext =
new ChatModelRequestContext(modelListenerRequest, attributes);
listeners.forEach(
listener -> {
try {
listener.onRequest(requestContext);
} catch (Exception e) {
log.warn("Exception while calling model listener", e);
}
});
listeners.forEach(listener -> {
try {
listener.onRequest(requestContext);
} catch (Exception e) {
log.warn("Exception while calling model listener", e);
}
});
try {
ChatCompletionResponse chatCompletionResponse =
withRetry(() -> client.chatCompletion(request).execute(), maxRetries);
Response<AiMessage> response =
Response.from(
aiMessageFrom(chatCompletionResponse),
tokenUsageFrom(chatCompletionResponse.usage()),
finishReasonFrom(
chatCompletionResponse.choices().get(0).finishReason()));
Response<AiMessage> response = Response.from(aiMessageFrom(chatCompletionResponse),
tokenUsageFrom(chatCompletionResponse.usage()),
finishReasonFrom(chatCompletionResponse.choices().get(0).finishReason()));
ChatModelResponse modelListenerResponse =
createModelListenerResponse(
chatCompletionResponse.id(), chatCompletionResponse.model(), response);
ChatModelResponseContext responseContext =
new ChatModelResponseContext(
modelListenerResponse, modelListenerRequest, attributes);
listeners.forEach(
listener -> {
try {
listener.onResponse(responseContext);
} catch (Exception e) {
log.warn("Exception while calling model listener", e);
}
});
ChatModelResponse modelListenerResponse = createModelListenerResponse(
chatCompletionResponse.id(), chatCompletionResponse.model(), response);
ChatModelResponseContext responseContext = new ChatModelResponseContext(
modelListenerResponse, modelListenerRequest, attributes);
listeners.forEach(listener -> {
try {
listener.onResponse(responseContext);
} catch (Exception e) {
log.warn("Exception while calling model listener", e);
}
});
return response;
} catch (RuntimeException e) {
@@ -305,14 +248,13 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
ChatModelErrorContext errorContext =
new ChatModelErrorContext(error, modelListenerRequest, null, attributes);
listeners.forEach(
listener -> {
try {
listener.onError(errorContext);
} catch (Exception e2) {
log.warn("Exception while calling model listener", e2);
}
});
listeners.forEach(listener -> {
try {
listener.onError(errorContext);
} catch (Exception e2) {
log.warn("Exception while calling model listener", e2);
}
});
throw e;
}
@@ -328,8 +270,8 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
}
public static OpenAiChatModelBuilder builder() {
for (OpenAiChatModelBuilderFactory factory :
loadFactories(OpenAiChatModelBuilderFactory.class)) {
for (OpenAiChatModelBuilderFactory factory : loadFactories(
OpenAiChatModelBuilderFactory.class)) {
return factory.get();
}
return new OpenAiChatModelBuilder();

View File

@@ -3,9 +3,8 @@ package dev.langchain4j.model.openai;
public enum OpenAiChatModelName {
GPT_3_5_TURBO("gpt-3.5-turbo"), // alias
@Deprecated
GPT_3_5_TURBO_0613("gpt-3.5-turbo-0613"),
GPT_3_5_TURBO_1106("gpt-3.5-turbo-1106"),
GPT_3_5_TURBO_0125("gpt-3.5-turbo-0125"),
GPT_3_5_TURBO_0613("gpt-3.5-turbo-0613"), GPT_3_5_TURBO_1106(
"gpt-3.5-turbo-1106"), GPT_3_5_TURBO_0125("gpt-3.5-turbo-0125"),
GPT_3_5_TURBO_16K("gpt-3.5-turbo-16k"), // alias
@Deprecated
@@ -13,22 +12,18 @@ public enum OpenAiChatModelName {
GPT_4("gpt-4"), // alias
@Deprecated
GPT_4_0314("gpt-4-0314"),
GPT_4_0613("gpt-4-0613"),
GPT_4_0314("gpt-4-0314"), GPT_4_0613("gpt-4-0613"),
GPT_4_TURBO_PREVIEW("gpt-4-turbo-preview"), // alias
GPT_4_1106_PREVIEW("gpt-4-1106-preview"),
GPT_4_0125_PREVIEW("gpt-4-0125-preview"),
GPT_4_1106_PREVIEW("gpt-4-1106-preview"), GPT_4_0125_PREVIEW("gpt-4-0125-preview"),
GPT_4_32K("gpt-4-32k"), // alias
GPT_4_32K_0314("gpt-4-32k-0314"),
GPT_4_32K_0613("gpt-4-32k-0613"),
GPT_4_32K_0314("gpt-4-32k-0314"), GPT_4_32K_0613("gpt-4-32k-0613"),
@Deprecated
GPT_4_VISION_PREVIEW("gpt-4-vision-preview"),
GPT_4_O("gpt-4o"),
GPT_4_O_MINI("gpt-4o-mini");
GPT_4_O("gpt-4o"), GPT_4_O_MINI("gpt-4o-mini");
private final String stringValue;

View File

@@ -1,9 +1,7 @@
package dev.langchain4j.model.zhipu;
public enum ChatCompletionModel {
GLM_4("glm-4"),
GLM_3_TURBO("glm-3-turbo"),
CHATGLM_TURBO("chatglm_turbo");
GLM_4("glm-4"), GLM_3_TURBO("glm-3-turbo"), CHATGLM_TURBO("chatglm_turbo");
private final String value;

View File

@@ -27,8 +27,8 @@ import static java.util.Collections.singletonList;
/**
* Represents an ZhipuAi language model with a chat completion interface, such as glm-3-turbo and
* glm-4. You can find description of parameters <a
* href="https://open.bigmodel.cn/dev/api">here</a>.
* glm-4. You can find description of parameters
* <a href="https://open.bigmodel.cn/dev/api">here</a>.
*/
public class ZhipuAiChatModel implements ChatLanguageModel {
@@ -41,15 +41,8 @@ public class ZhipuAiChatModel implements ChatLanguageModel {
private final ZhipuAiClient client;
@Builder
public ZhipuAiChatModel(
String baseUrl,
String apiKey,
Double temperature,
Double topP,
String model,
Integer maxRetries,
Integer maxToken,
Boolean logRequests,
public ZhipuAiChatModel(String baseUrl, String apiKey, Double temperature, Double topP,
String model, Integer maxRetries, Integer maxToken, Boolean logRequests,
Boolean logResponses) {
this.baseUrl = getOrDefault(baseUrl, "https://open.bigmodel.cn/");
this.temperature = getOrDefault(temperature, 0.7);
@@ -57,18 +50,14 @@ public class ZhipuAiChatModel implements ChatLanguageModel {
this.model = getOrDefault(model, ChatCompletionModel.GLM_4.toString());
this.maxRetries = getOrDefault(maxRetries, 3);
this.maxToken = getOrDefault(maxToken, 512);
this.client =
ZhipuAiClient.builder()
.baseUrl(this.baseUrl)
.apiKey(apiKey)
.logRequests(getOrDefault(logRequests, false))
.logResponses(getOrDefault(logResponses, false))
.build();
this.client = ZhipuAiClient.builder().baseUrl(this.baseUrl).apiKey(apiKey)
.logRequests(getOrDefault(logRequests, false))
.logResponses(getOrDefault(logResponses, false)).build();
}
public static ZhipuAiChatModelBuilder builder() {
for (ZhipuAiChatModelBuilderFactory factories :
loadFactories(ZhipuAiChatModelBuilderFactory.class)) {
for (ZhipuAiChatModelBuilderFactory factories : loadFactories(
ZhipuAiChatModelBuilderFactory.class)) {
return factories.get();
}
return new ZhipuAiChatModelBuilder();
@@ -80,15 +69,13 @@ public class ZhipuAiChatModel implements ChatLanguageModel {
}
@Override
public Response<AiMessage> generate(
List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
public Response<AiMessage> generate(List<ChatMessage> messages,
List<ToolSpecification> toolSpecifications) {
ensureNotEmpty(messages, "messages");
ChatCompletionRequest.Builder requestBuilder =
ChatCompletionRequest.builder().model(this.model).maxTokens(maxToken).stream(false)
.topP(topP)
.toolChoice(AUTO)
.messages(toZhipuAiMessages(messages));
.topP(topP).toolChoice(AUTO).messages(toZhipuAiMessages(messages));
if (!isNullOrEmpty(toolSpecifications)) {
requestBuilder.tools(toTools(toolSpecifications));
@@ -96,17 +83,15 @@ public class ZhipuAiChatModel implements ChatLanguageModel {
ChatCompletionResponse response =
withRetry(() -> client.chatCompletion(requestBuilder.build()), maxRetries);
return Response.from(
aiMessageFrom(response),
tokenUsageFrom(response.getUsage()),
return Response.from(aiMessageFrom(response), tokenUsageFrom(response.getUsage()),
finishReasonFrom(response.getChoices().get(0).getFinishReason()));
}
@Override
public Response<AiMessage> generate(
List<ChatMessage> messages, ToolSpecification toolSpecification) {
return generate(
messages, toolSpecification != null ? singletonList(toolSpecification) : null);
public Response<AiMessage> generate(List<ChatMessage> messages,
ToolSpecification toolSpecification) {
return generate(messages,
toolSpecification != null ? singletonList(toolSpecification) : null);
}
public static class ZhipuAiChatModelBuilder {

View File

@@ -20,36 +20,27 @@ public class AzureModelFactory implements ModelFactory, InitializingBean {
@Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
AzureOpenAiChatModel.Builder builder =
AzureOpenAiChatModel.builder()
.endpoint(modelConfig.getBaseUrl())
.apiKey(modelConfig.getApiKey())
.deploymentName(modelConfig.getModelName())
.temperature(modelConfig.getTemperature())
.maxRetries(modelConfig.getMaxRetries())
.topP(modelConfig.getTopP())
.timeout(
Duration.ofSeconds(
modelConfig.getTimeOut() == null
? 0L
: modelConfig.getTimeOut()))
.logRequestsAndResponses(
modelConfig.getLogRequests() != null
&& modelConfig.getLogResponses());
AzureOpenAiChatModel.Builder builder = AzureOpenAiChatModel.builder()
.endpoint(modelConfig.getBaseUrl()).apiKey(modelConfig.getApiKey())
.deploymentName(modelConfig.getModelName())
.temperature(modelConfig.getTemperature()).maxRetries(modelConfig.getMaxRetries())
.topP(modelConfig.getTopP())
.timeout(Duration.ofSeconds(
modelConfig.getTimeOut() == null ? 0L : modelConfig.getTimeOut()))
.logRequestsAndResponses(
modelConfig.getLogRequests() != null && modelConfig.getLogResponses());
return builder.build();
}
@Override
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
AzureOpenAiEmbeddingModel.Builder builder =
AzureOpenAiEmbeddingModel.builder()
.endpoint(embeddingModelConfig.getBaseUrl())
AzureOpenAiEmbeddingModel.builder().endpoint(embeddingModelConfig.getBaseUrl())
.apiKey(embeddingModelConfig.getApiKey())
.deploymentName(embeddingModelConfig.getModelName())
.maxRetries(embeddingModelConfig.getMaxRetries())
.logRequestsAndResponses(
embeddingModelConfig.getLogRequests() != null
&& embeddingModelConfig.getLogResponses());
.logRequestsAndResponses(embeddingModelConfig.getLogRequests() != null
&& embeddingModelConfig.getLogResponses());
return builder.build();
}

View File

@@ -19,25 +19,17 @@ public class DashscopeModelFactory implements ModelFactory, InitializingBean {
@Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
return QwenChatModel.builder()
.baseUrl(modelConfig.getBaseUrl())
.apiKey(modelConfig.getApiKey())
.modelName(modelConfig.getModelName())
.temperature(
modelConfig.getTemperature() == null
? 0L
: modelConfig.getTemperature().floatValue())
.topP(modelConfig.getTopP())
.enableSearch(modelConfig.getEnableSearch())
.build();
return QwenChatModel.builder().baseUrl(modelConfig.getBaseUrl())
.apiKey(modelConfig.getApiKey()).modelName(modelConfig.getModelName())
.temperature(modelConfig.getTemperature() == null ? 0L
: modelConfig.getTemperature().floatValue())
.topP(modelConfig.getTopP()).enableSearch(modelConfig.getEnableSearch()).build();
}
@Override
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
return QwenEmbeddingModel.builder()
.apiKey(embeddingModelConfig.getApiKey())
.modelName(embeddingModelConfig.getModelName())
.build();
return QwenEmbeddingModel.builder().apiKey(embeddingModelConfig.getApiKey())
.modelName(embeddingModelConfig.getModelName()).build();
}
@Override

View File

@@ -19,27 +19,20 @@ public class LocalAiModelFactory implements ModelFactory, InitializingBean {
@Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
return LocalAiChatModel.builder()
.baseUrl(modelConfig.getBaseUrl())
.modelName(modelConfig.getModelName())
.temperature(modelConfig.getTemperature())
.timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
.topP(modelConfig.getTopP())
return LocalAiChatModel.builder().baseUrl(modelConfig.getBaseUrl())
.modelName(modelConfig.getModelName()).temperature(modelConfig.getTemperature())
.timeout(Duration.ofSeconds(modelConfig.getTimeOut())).topP(modelConfig.getTopP())
.logRequests(modelConfig.getLogRequests())
.logResponses(modelConfig.getLogResponses())
.maxRetries(modelConfig.getMaxRetries())
.logResponses(modelConfig.getLogResponses()).maxRetries(modelConfig.getMaxRetries())
.build();
}
@Override
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModel) {
return LocalAiEmbeddingModel.builder()
.baseUrl(embeddingModel.getBaseUrl())
.modelName(embeddingModel.getModelName())
.maxRetries(embeddingModel.getMaxRetries())
return LocalAiEmbeddingModel.builder().baseUrl(embeddingModel.getBaseUrl())
.modelName(embeddingModel.getModelName()).maxRetries(embeddingModel.getMaxRetries())
.logRequests(embeddingModel.getLogRequests())
.logResponses(embeddingModel.getLogResponses())
.build();
.logResponses(embeddingModel.getLogResponses()).build();
}
@Override

View File

@@ -25,8 +25,7 @@ public class ModelProvider {
}
public static ChatLanguageModel getChatModel(ChatModelConfig modelConfig) {
if (modelConfig == null
|| StringUtils.isBlank(modelConfig.getProvider())
if (modelConfig == null || StringUtils.isBlank(modelConfig.getProvider())
|| StringUtils.isBlank(modelConfig.getBaseUrl())) {
ChatModelParameterConfig parameterConfig =
ContextUtils.getBean(ChatModelParameterConfig.class);

View File

@@ -21,27 +21,20 @@ public class OllamaModelFactory implements ModelFactory, InitializingBean {
@Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
return OllamaChatModel.builder()
.baseUrl(modelConfig.getBaseUrl())
.modelName(modelConfig.getModelName())
.temperature(modelConfig.getTemperature())
.timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
.topP(modelConfig.getTopP())
.maxRetries(modelConfig.getMaxRetries())
.logRequests(modelConfig.getLogRequests())
.logResponses(modelConfig.getLogResponses())
.build();
return OllamaChatModel.builder().baseUrl(modelConfig.getBaseUrl())
.modelName(modelConfig.getModelName()).temperature(modelConfig.getTemperature())
.timeout(Duration.ofSeconds(modelConfig.getTimeOut())).topP(modelConfig.getTopP())
.maxRetries(modelConfig.getMaxRetries()).logRequests(modelConfig.getLogRequests())
.logResponses(modelConfig.getLogResponses()).build();
}
@Override
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
return OllamaEmbeddingModel.builder()
.baseUrl(embeddingModelConfig.getBaseUrl())
return OllamaEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl())
.modelName(embeddingModelConfig.getModelName())
.maxRetries(embeddingModelConfig.getMaxRetries())
.logRequests(embeddingModelConfig.getLogRequests())
.logResponses(embeddingModelConfig.getLogResponses())
.build();
.logResponses(embeddingModelConfig.getLogResponses()).build();
}
@Override

View File

@@ -21,29 +21,22 @@ public class OpenAiModelFactory implements ModelFactory, InitializingBean {
@Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
return OpenAiChatModel.builder()
.baseUrl(modelConfig.getBaseUrl())
.modelName(modelConfig.getModelName())
.apiKey(modelConfig.keyDecrypt())
.temperature(modelConfig.getTemperature())
.topP(modelConfig.getTopP())
return OpenAiChatModel.builder().baseUrl(modelConfig.getBaseUrl())
.modelName(modelConfig.getModelName()).apiKey(modelConfig.keyDecrypt())
.temperature(modelConfig.getTemperature()).topP(modelConfig.getTopP())
.maxRetries(modelConfig.getMaxRetries())
.timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
.logRequests(modelConfig.getLogRequests())
.logResponses(modelConfig.getLogResponses())
.build();
.logResponses(modelConfig.getLogResponses()).build();
}
@Override
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModel) {
return OpenAiEmbeddingModel.builder()
.baseUrl(embeddingModel.getBaseUrl())
.apiKey(embeddingModel.getApiKey())
.modelName(embeddingModel.getModelName())
return OpenAiEmbeddingModel.builder().baseUrl(embeddingModel.getBaseUrl())
.apiKey(embeddingModel.getApiKey()).modelName(embeddingModel.getModelName())
.maxRetries(embeddingModel.getMaxRetries())
.logRequests(embeddingModel.getLogRequests())
.logResponses(embeddingModel.getLogResponses())
.build();
.logResponses(embeddingModel.getLogResponses()).build();
}
@Override

View File

@@ -21,31 +21,23 @@ public class QianfanModelFactory implements ModelFactory, InitializingBean {
@Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
return QianfanChatModel.builder()
.baseUrl(modelConfig.getBaseUrl())
.apiKey(modelConfig.getApiKey())
.secretKey(modelConfig.getSecretKey())
.endpoint(modelConfig.getEndpoint())
.modelName(modelConfig.getModelName())
.temperature(modelConfig.getTemperature())
.topP(modelConfig.getTopP())
.maxRetries(modelConfig.getMaxRetries())
.logRequests(modelConfig.getLogRequests())
.logResponses(modelConfig.getLogResponses())
.build();
return QianfanChatModel.builder().baseUrl(modelConfig.getBaseUrl())
.apiKey(modelConfig.getApiKey()).secretKey(modelConfig.getSecretKey())
.endpoint(modelConfig.getEndpoint()).modelName(modelConfig.getModelName())
.temperature(modelConfig.getTemperature()).topP(modelConfig.getTopP())
.maxRetries(modelConfig.getMaxRetries()).logRequests(modelConfig.getLogRequests())
.logResponses(modelConfig.getLogResponses()).build();
}
@Override
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
return QianfanEmbeddingModel.builder()
.baseUrl(embeddingModelConfig.getBaseUrl())
return QianfanEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl())
.apiKey(embeddingModelConfig.getApiKey())
.secretKey(embeddingModelConfig.getSecretKey())
.modelName(embeddingModelConfig.getModelName())
.maxRetries(embeddingModelConfig.getMaxRetries())
.logRequests(embeddingModelConfig.getLogRequests())
.logResponses(embeddingModelConfig.getLogResponses())
.build();
.logResponses(embeddingModelConfig.getLogResponses()).build();
}
@Override

View File

@@ -19,28 +19,20 @@ public class ZhipuModelFactory implements ModelFactory, InitializingBean {
@Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
return ZhipuAiChatModel.builder()
.baseUrl(modelConfig.getBaseUrl())
.apiKey(modelConfig.getApiKey())
.model(modelConfig.getModelName())
.temperature(modelConfig.getTemperature())
.topP(modelConfig.getTopP())
.maxRetries(modelConfig.getMaxRetries())
.logRequests(modelConfig.getLogRequests())
.logResponses(modelConfig.getLogResponses())
.build();
return ZhipuAiChatModel.builder().baseUrl(modelConfig.getBaseUrl())
.apiKey(modelConfig.getApiKey()).model(modelConfig.getModelName())
.temperature(modelConfig.getTemperature()).topP(modelConfig.getTopP())
.maxRetries(modelConfig.getMaxRetries()).logRequests(modelConfig.getLogRequests())
.logResponses(modelConfig.getLogResponses()).build();
}
@Override
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
return ZhipuAiEmbeddingModel.builder()
.baseUrl(embeddingModelConfig.getBaseUrl())
.apiKey(embeddingModelConfig.getApiKey())
.model(embeddingModelConfig.getModelName())
return ZhipuAiEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl())
.apiKey(embeddingModelConfig.getApiKey()).model(embeddingModelConfig.getModelName())
.maxRetries(embeddingModelConfig.getMaxRetries())
.logRequests(embeddingModelConfig.getLogRequests())
.logResponses(embeddingModelConfig.getLogResponses())
.build();
.logResponses(embeddingModelConfig.getLogResponses()).build();
}
@Override

Some files were not shown because too many files have changed in this diff Show More