Skip to content

Commit d213e4a

Browse files
committed
optimize code
1 parent 9f3d3e2 commit d213e4a

File tree

11 files changed

+42
-48
lines changed

11 files changed

+42
-48
lines changed

sqlrec-core/src/main/java/com/sqlrec/model/ModelManager.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,6 @@ public static List<CheckpointInfo> trainModel(SqlTrainModel sqlTrainModel, Strin
112112

113113
String k8sYaml = modelController.genModelTrainK8sYaml(modelConfig, modelTrainConf);
114114
k8sYaml = injectPodConfig(k8sYaml, modelConfig, modelTrainConf.getParams());
115-
K8sManager.applyYaml(k8sYaml);
116115

117116
Checkpoint checkpoint = new Checkpoint();
118117
checkpoint.setModelName(modelTrainConf.getModelName());
@@ -125,7 +124,8 @@ public static List<CheckpointInfo> trainModel(SqlTrainModel sqlTrainModel, Strin
125124
checkpoint.setCreatedAt(System.currentTimeMillis());
126125
checkpoint.setUpdatedAt(System.currentTimeMillis());
127126

128-
DbUtils.upsertCheckpoint(checkpoint);
127+
DbUtils.insertCheckpoint(checkpoint);
128+
K8sManager.applyYaml(k8sYaml);
129129

130130
List<CheckpointInfo> checkpointInfos = new ArrayList<>();
131131
checkpointInfos.add(new CheckpointInfo(modelTrainConf.getModelName(), modelTrainConf.getCheckpointName()));
@@ -181,7 +181,6 @@ public static List<CheckpointInfo> exportModel(SqlExportModel sqlExportModel, St
181181

182182
String k8sYaml = modelController.genModelExportK8sYaml(modelConfig, modelExportConf);
183183
k8sYaml = injectPodConfig(k8sYaml, modelConfig, modelExportConf.getParams());
184-
K8sManager.applyYaml(k8sYaml);
185184

186185
List<CheckpointInfo> checkpointInfos = new ArrayList<>();
187186
for (String exportCheckpointName : exportCheckpointNames) {
@@ -196,10 +195,12 @@ public static List<CheckpointInfo> exportModel(SqlExportModel sqlExportModel, St
196195
checkpoint.setCreatedAt(System.currentTimeMillis());
197196
checkpoint.setUpdatedAt(System.currentTimeMillis());
198197

199-
DbUtils.upsertCheckpoint(checkpoint);
198+
DbUtils.insertCheckpoint(checkpoint);
200199
checkpointInfos.add(new CheckpointInfo(modelExportConf.getModelName(), exportCheckpointName));
201200
}
202201

202+
K8sManager.applyYaml(k8sYaml);
203+
203204
return checkpointInfos;
204205
}
205206

sqlrec-core/src/main/java/com/sqlrec/utils/DbMapper.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,11 @@ public interface DbMapper {
8484
"ON CONFLICT (model_name, checkpoint_name) DO UPDATE SET model_ddl = #{modelDdl}, ddl = #{ddl}, yaml = #{yaml}, checkpoint_type = #{checkpointType}, status = #{status}, updated_at = #{updatedAt}")
8585
void upsertCheckpoint(Checkpoint checkpoint);
8686

87+
@Insert("INSERT INTO checkpoint " +
88+
"(model_name, checkpoint_name, model_ddl, ddl, yaml, checkpoint_type, status, created_at, updated_at) " +
89+
"VALUES (#{modelName}, #{checkpointName}, #{modelDdl}, #{ddl}, #{yaml}, #{checkpointType}, #{status}, #{createdAt}, #{updatedAt})")
90+
void insertCheckpoint(Checkpoint checkpoint);
91+
8792
@Delete("DELETE FROM checkpoint WHERE model_name = #{modelName} AND checkpoint_name = #{checkpointName}")
8893
void deleteCheckpoint(@Param("modelName") String modelName, @Param("checkpointName") String checkpointName);
8994

sqlrec-core/src/main/java/com/sqlrec/utils/DbUtils.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,10 @@ public static void upsertCheckpoint(Checkpoint checkpoint) {
141141
executeVoid(dbMapper -> dbMapper.upsertCheckpoint(checkpoint));
142142
}
143143

144+
public static void insertCheckpoint(Checkpoint checkpoint) {
145+
executeVoid(dbMapper -> dbMapper.insertCheckpoint(checkpoint));
146+
}
147+
144148
public static void deleteCheckpoint(String modelName, String checkpointName) {
145149
executeVoid(dbMapper -> dbMapper.deleteCheckpoint(modelName, checkpointName));
146150
}

sqlrec-core/src/main/java/com/sqlrec/utils/JavaFunctionUtils.java

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,21 @@
1515

1616
public class JavaFunctionUtils {
1717
private static final Logger log = LoggerFactory.getLogger(JavaFunctionUtils.class);
18+
private static volatile boolean skipHmsQuery = false;
1819
private static Map<String, Class<?>> javaFunctionClassMap = new ConcurrentHashMap<>();
1920
private static Map<String, Long> functionUpdateTime = new ConcurrentHashMap<>();
2021
private static Cache<String, String> notExistCache = Caffeine.newBuilder()
2122
.expireAfterWrite(SqlRecConfigs.FUNCTION_UPDATE_INTERVAL.getValue(), TimeUnit.SECONDS)
2223
.build();
2324

25+
public static void setSkipHmsQuery(boolean skip) {
26+
skipHmsQuery = skip;
27+
}
28+
29+
public static boolean isSkipHmsQuery() {
30+
return skipHmsQuery;
31+
}
32+
2433
public static Object getTableFunction(String db, String funName) throws Exception {
2534
String mapKey = getMapKey(db, funName);
2635
Class<?> clazz = javaFunctionClassMap.get(mapKey);
@@ -30,11 +39,6 @@ public static Object getTableFunction(String db, String funName) throws Exceptio
3039
if (clazz == null) {
3140
return null;
3241
}
33-
// for test
34-
if (clazz.isPrimitive()) {
35-
return null;
36-
}
37-
3842
return clazz.getDeclaredConstructor().newInstance();
3943
}
4044

@@ -71,6 +75,9 @@ public static String getJavaFunctionClassName(String db, String funName) throws
7175
if (FunctionConfigs.DEFAULT_JAVA_FUNCTION_CONFIGS.containsKey(funName)) {
7276
return FunctionConfigs.DEFAULT_JAVA_FUNCTION_CONFIGS.get(funName);
7377
}
78+
if (skipHmsQuery) {
79+
return null;
80+
}
7481
org.apache.hadoop.hive.metastore.api.Function functionObj = HmsClient.getFunctionObj(db, funName);
7582
if (functionObj == null) {
7683
throw new Exception("Function not found: " + funName);

sqlrec-core/src/main/java/com/sqlrec/utils/KvJoinUtils.java

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import org.apache.calcite.jdbc.JavaTypeFactoryImpl;
66
import org.apache.calcite.linq4j.Enumerable;
77
import org.apache.calcite.linq4j.Linq4j;
8-
import org.apache.calcite.rel.core.Join;
98
import org.apache.calcite.rel.core.JoinRelType;
109
import org.apache.calcite.rel.type.RelDataType;
1110
import org.apache.calcite.rex.RexCall;
@@ -86,7 +85,7 @@ private static Enumerable joinByPrimaryKey(
8685
return Linq4j.asEnumerable(merged);
8786
}
8887

89-
private static Object[] copyValues(Object[] leftValue, Object[] rightValue, int leftSize, int rightSize) {
88+
public static Object[] copyValues(Object[] leftValue, Object[] rightValue, int leftSize, int rightSize) {
9089
Object[] copy = new Object[leftSize + rightSize];
9190
System.arraycopy(leftValue, 0, copy, 0, leftSize);
9291
if (rightValue != null) {
@@ -95,11 +94,6 @@ private static Object[] copyValues(Object[] leftValue, Object[] rightValue, int
9594
return copy;
9695
}
9796

98-
public static Map.Entry<Integer, Integer> getJoinKeyColIndex(Join join) {
99-
RexNode condition = join.getCondition();
100-
return getJoinKeyColIndex(condition);
101-
}
102-
10397
public static Map.Entry<Integer, Integer> getJoinKeyColIndex(RexNode condition) {
10498
List<Integer> indexList = new ArrayList<>();
10599
if (condition instanceof RexCall) {

sqlrec-core/src/main/java/com/sqlrec/utils/KvTableUtils.java

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,13 @@ public static CacheTable getScanCacheTable(RelNode relNode) {
3030
return table.unwrap(CacheTable.class);
3131
}
3232

33+
public static boolean isKvTable(RelOptTable table) {
34+
if (table == null) {
35+
return false;
36+
}
37+
return table.unwrap(SqlRecKvTable.class) != null;
38+
}
39+
3340
public static RelOptTable getScanTable(RelNode aNode) {
3441
if (aNode instanceof RelSubset) {
3542
RelSubset relNode = ((RelSubset) aNode);
@@ -49,12 +56,4 @@ public static RelOptTable getScanTable(RelNode aNode) {
4956

5057
return null;
5158
}
52-
53-
public static boolean isKvTable(RelOptTable table) {
54-
if (table == null) {
55-
return false;
56-
}
57-
return table.unwrap(SqlRecKvTable.class) != null;
58-
}
59-
6059
}

sqlrec-core/src/main/java/com/sqlrec/utils/VectorJoinUtils.java

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,6 @@ private static Enumerable doVectorJoin(
9595
for (Object[] leftValue : leftValues) {
9696
Object leftEmbedding = leftValue[leftEmbeddingColIndex];
9797
List<Float> embedding = DataTransformUtils.convertToFloatVec(leftEmbedding);
98-
9998
List<Object[]> rightValues = rightTable.searchByEmbeddingWithScore(
10099
leftValue,
101100
embedding,
@@ -104,11 +103,9 @@ private static Enumerable doVectorJoin(
104103
limit,
105104
vectorProjectColumns
106105
);
107-
108-
if (rightValues == null || rightValues.isEmpty()) {
109-
} else {
106+
if (rightValues != null) {
110107
for (Object[] rightValue : rightValues) {
111-
Object[] joinRow = copyValues(leftValue, rightValue, leftSize, rightSize);
108+
Object[] joinRow = KvJoinUtils.copyValues(leftValue, rightValue, leftSize, rightSize);
112109
Object[] projectRow = buildProjectRow(joinRow, projectColumns, projectSize);
113110
result.add(projectRow);
114111
}
@@ -131,21 +128,12 @@ private static Object[] buildProjectRow(Object[] joinRow, List<Integer> projectC
131128
return projectRow;
132129
}
133130

134-
private static Object[] copyValues(Object[] leftValue, Object[] rightValue, int leftSize, int rightSize) {
135-
Object[] copy = new Object[leftSize + rightSize];
136-
System.arraycopy(leftValue, 0, copy, 0, leftSize);
137-
if (rightValue != null) {
138-
System.arraycopy(rightValue, 0, copy, leftSize, rightSize);
139-
}
140-
return copy;
141-
}
142-
143131
public static VectorJoinConfig extractVectorJoinConfig(
144132
LogicalSort sort,
145133
LogicalProject project,
146134
LogicalFilter filter,
147-
LogicalJoin join) {
148-
135+
LogicalJoin join
136+
) {
149137
VectorJoinConfig config = new VectorJoinConfig();
150138

151139
if (sort != null && sort.fetch != null && sort.fetch instanceof RexLiteral) {

sqlrec-core/src/test/java/com/sqlrec/TestExceptionIgnore.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,7 @@ public void testExceptionIgnore() throws Exception {
1818
HmsSchema.setGlobalSchema(schema);
1919

2020
JavaFunctionUtils.registerTableFunction("default", "test_fun", TestExceptionIgnore.TestFunction.class);
21-
JavaFunctionUtils.registerTableFunction("default", "sql_fun1", Integer.TYPE); // avoid find function in hms
22-
JavaFunctionUtils.registerTableFunction("default", "sql_fun2", Integer.TYPE); // avoid find function in hms
23-
JavaFunctionUtils.registerTableFunction("default", "sql_fun3", Integer.TYPE); // avoid find function in hms
21+
JavaFunctionUtils.setSkipHmsQuery(true);
2422

2523
List<String> sqlFun1 = Arrays.asList(
2624
"create sql function sql_fun1",

sqlrec-core/src/test/java/com/sqlrec/TestSqlFunction.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,7 @@ protected Map<String, Table> getTableMap() {
2828
});
2929

3030
HmsSchema.setGlobalSchema(schema);
31-
JavaFunctionUtils.registerTableFunction("default", "fun1", Integer.TYPE); // avoid find function in hms
32-
JavaFunctionUtils.registerTableFunction("default", "fun2", Integer.TYPE); // avoid find function in hms
33-
JavaFunctionUtils.registerTableFunction("default", "fun3", Integer.TYPE); // avoid find function in hms
31+
JavaFunctionUtils.setSkipHmsQuery(true);
3432

3533
testSqlFunctionCompile(schema);
3634

sqlrec-frontend/src/test/java/com/sqlrec/connectors/TestJoin.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ protected Map<String, Table> getTableMap() {
100100
});
101101
HmsSchema.setGlobalSchema(schema);
102102

103-
JavaFunctionUtils.registerTableFunction("default", "test_join", Integer.TYPE); // avoid find function in hms
103+
JavaFunctionUtils.setSkipHmsQuery(true);
104104
List<String> joinFuncSqlList = Arrays.asList(
105105
"create or replace sql function test_join",
106106
"define input table user_info(id integer)",

0 commit comments

Comments
 (0)