Skip to content

Commit 28b898a

Browse files
author
actiontech
committed
Merge branch 'main' of https://github.com/actiontech/sqle into main
2 parents 30f1a6e + f677fd6 commit 28b898a

28 files changed

+518
-357
lines changed

sqle/api/controller/v1/configuration.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -419,7 +419,7 @@ func getGitAuthMethod(url, username, password string) (transport.AuthMethod, err
419419
if err != nil {
420420
return nil, err
421421
}
422-
if systemVariable.Code != 0 {
422+
if systemVariable.Data.SystemVariableSSHPrimaryKey != "" {
423423
return nil, errors.New(errors.DataNotExist, fmt.Errorf("git ssh private key not found"))
424424
}
425425
publicKeys, err := sshTransport.NewPublicKeys("git", []byte(systemVariable.Data.SystemVariableSSHPrimaryKey), "")
@@ -550,7 +550,7 @@ func GetSSHPublicKey(c echo.Context) error {
550550
if err != nil {
551551
return controller.JSONBaseErrorReq(c, err)
552552
}
553-
if systemVariable.Code != 0 {
553+
if systemVariable.Data.SystemVariableSSHPrimaryKey == "" {
554554
return c.JSON(http.StatusOK, SSHPublicKeyInfoV1Rsp{
555555
BaseRes: controller.NewBaseReq(nil),
556556
Data: SSHPublicKeyInfo{

sqle/api/controller/v1/workflow.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -834,6 +834,14 @@ func toGlobalWorkflowRes(workflows []*model.WorkflowListDetail, projectMap Proje
834834
InstanceName: instanceMap.InstanceName(id),
835835
})
836836
}
837+
838+
CurrentStepAssigneeUserNames := make([]string, 0)
839+
for _, currentStepAssigneeUser := range strings.Split(workflow.CurrentStepAssigneeUserIds.String, ",") {
840+
if currentStepAssigneeUser == "" {
841+
continue
842+
}
843+
CurrentStepAssigneeUserNames = append(CurrentStepAssigneeUserNames, dms.GetUserNameWithDelTag(currentStepAssigneeUser))
844+
}
837845
workflowRes := &WorkflowDetailResV1{
838846
ProjectName: projectMap.ProjectName(workflow.ProjectId),
839847
ProjectUid: workflow.ProjectId,
@@ -845,7 +853,7 @@ func toGlobalWorkflowRes(workflows []*model.WorkflowListDetail, projectMap Proje
845853
CreateUser: utils.AddDelTag(workflow.CreateUserDeletedAt, workflow.CreateUser.String),
846854
CreateTime: workflow.CreateTime,
847855
CurrentStepType: workflow.CurrentStepType.String,
848-
CurrentStepAssigneeUser: strings.Split(workflow.CurrentStepAssigneeUserIds.String, ","),
856+
CurrentStepAssigneeUser: CurrentStepAssigneeUserNames,
849857
Status: workflow.Status,
850858
}
851859
workflowsResV1 = append(workflowsResV1, workflowRes)

sqle/api/controller/v2/workflow.go

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -778,20 +778,6 @@ func UpdateWorkflowV2(c echo.Context) error {
778778
fmt.Errorf("you are not allow to operate the workflow")))
779779
}
780780

781-
template, exist, err := s.GetWorkflowTemplateByProjectId(workflow.ProjectId)
782-
if err != nil {
783-
return controller.JSONBaseErrorReq(c, err)
784-
}
785-
if !exist {
786-
return controller.JSONBaseErrorReq(c, errors.New(errors.DataConflict,
787-
fmt.Errorf("failed to find the corresponding workflow template based on the task id")))
788-
}
789-
790-
err = v1.CheckWorkflowCanCommit(template, tasks)
791-
if err != nil {
792-
return controller.JSONBaseErrorReq(c, err)
793-
}
794-
795781
err = s.UpdateWorkflowRecord(workflow, tasks)
796782
if err != nil {
797783
return c.JSON(http.StatusOK, controller.NewBaseReq(err))

sqle/driver/mysql/analysis.go

Lines changed: 4 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"github.com/actiontech/sqle/sqle/driver"
1010
"github.com/actiontech/sqle/sqle/driver/mysql/executor"
1111
"github.com/actiontech/sqle/sqle/driver/mysql/plocale"
12+
utilV2 "github.com/actiontech/sqle/sqle/driver/mysql/rule/ai/util"
1213
"github.com/actiontech/sqle/sqle/driver/mysql/util"
1314
driverV2 "github.com/actiontech/sqle/sqle/driver/v2"
1415
"github.com/actiontech/sqle/sqle/utils"
@@ -285,48 +286,9 @@ func (i *MysqlDriverImpl) ExtractSchemaTableList(sql string) ([]SchemaTable, err
285286
}
286287
}
287288

288-
getMultiTables := func(stmt *ast.Join) {
289-
tables := util.GetTables(stmt)
290-
for _, t := range tables {
291-
addTable(t)
292-
}
293-
}
294-
295-
switch stmt := node.(type) {
296-
case *ast.SelectStmt:
297-
if stmt.From == nil {
298-
break
299-
}
300-
getMultiTables(stmt.From.TableRefs)
301-
case *ast.UnionStmt:
302-
for _, selectStmt := range stmt.SelectList.Selects {
303-
if selectStmt.From == nil {
304-
continue
305-
}
306-
getMultiTables(selectStmt.From.TableRefs)
307-
}
308-
case *ast.UpdateStmt:
309-
getMultiTables(stmt.TableRefs.TableRefs)
310-
case *ast.InsertStmt:
311-
getMultiTables(stmt.Table.TableRefs)
312-
if stmt.Select != nil {
313-
// TODO:INSERT INTO SQLE00115_t1_tmp_employee (id, cname, sex, age, salary) SELECT 4000002, '小张', 0, 25, (SELECT AVG(salary) FROM SQLE00115_t1_employee) 对于这条SQL,解析器无法解析子查询为一个Select语句,而是认为是一个文本
314-
if selectStmt, ok := stmt.Select.(*ast.SelectStmt); ok {
315-
if selectStmt.From != nil{
316-
getMultiTables(selectStmt.From.TableRefs)
317-
}
318-
}
319-
}
320-
case *ast.DeleteStmt:
321-
getMultiTables(stmt.TableRefs.TableRefs)
322-
case *ast.LoadDataStmt:
323-
addTable(stmt.Table)
324-
case *ast.ShowStmt:
325-
if stmt.Table != nil {
326-
addTable(stmt.Table)
327-
}
328-
default:
329-
return nil, fmt.Errorf("the sql is `%v`, we don't support analysing this sql", sql)
289+
tables := utilV2.GetTableNames(node)
290+
for _, t := range tables {
291+
addTable(t)
330292
}
331293

332294
return schemaTables, nil

sqle/driver/mysql/executor/executor.go

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ type Db interface {
2323
Close()
2424
Ping() error
2525
Exec(query string) (driver.Result, error)
26-
Transact(qs ...string) ([]driver.Result, error)
26+
Transact(qs ...string) (*driverV2.TxResponse, error)
2727
Query(query string, args ...interface{}) ([]map[string]sql.NullString, error)
2828
QueryWithContext(ctx context.Context, query string, args ...interface{}) (column []string, row [][]sql.NullString, err error)
2929
Logger() *logrus.Entry
@@ -137,14 +137,13 @@ func (c *BaseConn) Exec(query string) (driver.Result, error) {
137137
return result, errors.New(errors.ConnectRemoteDatabaseError, err)
138138
}
139139

140-
func (c *BaseConn) Transact(qs ...string) ([]driver.Result, error) {
140+
func (c *BaseConn) Transact(qs ...string) (*driverV2.TxResponse, error) {
141141
var err error
142142
var tx *sql.Tx
143-
var results []driver.Result
144143
c.Logger().Infof("doing sql transact, host: %s, port: %s, user: %s", c.host, c.port, c.user)
145144
tx, err = c.conn.BeginTx(context.Background(), nil)
146145
if err != nil {
147-
return results, err
146+
return nil, err
148147
}
149148
defer func() {
150149
if p := recover(); p != nil {
@@ -168,14 +167,23 @@ func (c *BaseConn) Transact(qs ...string) ([]driver.Result, error) {
168167
c.Logger().Info("done sql transact")
169168
}
170169
}()
171-
for _, query := range qs {
170+
171+
results := &driverV2.TxResponse{
172+
ExecResult: make([]driver.Result, 0, len(qs)),
173+
}
174+
for k, query := range qs {
172175
var txResult driver.Result
173176
txResult, err = tx.Exec(query)
174177
if err != nil {
178+
// SQL执行报错记录序号和错误信息在业务结构中
179+
results.ExecErr = &driverV2.ExecErr{
180+
ErrSqlIndex: uint32(k),
181+
SqlExecErrMsg: err.Error(),
182+
}
175183
c.Logger().Errorf("exec sql failed, error: %s, query: %s", err, query)
176-
return results, err
184+
return results, nil
177185
} else {
178-
results = append(results, txResult)
186+
results.ExecResult = append(results.ExecResult, txResult)
179187
c.Logger().Infof("exec sql success, query: %s", query)
180188
}
181189
}

sqle/driver/mysql/mysql.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ func (i *MysqlDriverImpl) onlineddlWithGhost(query string) (bool, error) {
240240
return int64(tableSize) > i.cnf.DDLGhostMinSize, nil
241241
}
242242

243-
func (i *MysqlDriverImpl) Tx(ctx context.Context, queries ...string) ([]_driver.Result, error) {
243+
func (i *MysqlDriverImpl) Tx(ctx context.Context, queries ...string) (*driverV2.TxResponse, error) {
244244
if i.IsOfflineAudit() {
245245
return nil, nil
246246
}

sqle/driver/mysql/rule/ai/util/extractor.go

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -158,12 +158,23 @@ type SubqueryExprExtractor struct {
158158
expr []*ast.SubqueryExpr
159159
}
160160

161-
func (te *SubqueryExprExtractor) Enter(in ast.Node) (node ast.Node, skipChildren bool) {
162-
e, ok := in.(*ast.SubqueryExpr)
163-
if !ok {
161+
func (te *SubqueryExprExtractor) Enter(in ast.Node) (ast.Node, bool) {
162+
// 情况 1: WHERE、IN、EXISTS 中的子查询
163+
if sub, ok := in.(*ast.SubqueryExpr); ok {
164+
te.expr = append(te.expr, sub)
164165
return in, false
165166
}
166-
te.expr = append(te.expr, e)
167+
168+
// 情况 2: FROM 中的子查询(即 SelectStmt 嵌套在 TableSource 中)
169+
if ts, ok := in.(*ast.TableSource); ok {
170+
if sel, ok := ts.Source.(*ast.SelectStmt); ok {
171+
// 包装成 SubqueryExpr 形式以便统一处理
172+
te.expr = append(te.expr, &ast.SubqueryExpr{
173+
Query: sel,
174+
})
175+
}
176+
}
177+
167178
return in, false
168179
}
169180

sqle/driver/mysql/rule_00108_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,15 +72,15 @@ func TestRuleSQLE00108(t *testing.T) {
7272
nil, newTestResult())
7373

7474
// 这个子查询中实际扫描表的子查询 只有1个,因此不算违规
75-
runAIRuleCase(rule, t, "case 15: SELECT语句中使用JOIN的ON条件中嵌套子查询5层, 但是实际扫描表的子查询只有1次,因此不算违规",
75+
runAIRuleCase(rule, t, "case 15: SELECT语句中使用JOIN的ON条件中嵌套子查询5层",
7676
"SELECT st1.id FROM st1 JOIN st_class ON st1.cid in (SELECT cid FROM (SELECT cid FROM (SELECT cid FROM (SELECT cid FROM (SELECT cid FROM st_class WHERE cname = 'class2') AS sub1) AS sub2) AS sub3) AS sub4);",
7777
session.NewAIMockContext().WithSQL("CREATE TABLE st1 (id INT, cid INT); CREATE TABLE st_class (cid INT, cname VARCHAR(50));"),
7878
nil, newTestResult())
7979

80-
runAIRuleCase(rule, t, "case 16: SELECT语句中使用JOIN的ON条件中嵌套子查询6层, 但是实际扫描表的子查询只有2次,因此不算违规",
80+
runAIRuleCase(rule, t, "case 16: SELECT语句中使用JOIN的ON条件中嵌套子查询6层, 违规",
8181
"SELECT st1.id FROM st1 JOIN st_class ON st1.cid in (SELECT cid FROM (SELECT cid FROM (SELECT cid FROM (SELECT cid FROM (SELECT cid FROM st_class WHERE cname in (SELECT cname FROM st_class WHERE cname = 'class2')) AS sub1) AS sub2) AS sub3) AS sub4);",
8282
session.NewAIMockContext().WithSQL("CREATE TABLE st1 (id INT, cid INT); CREATE TABLE st_class (cid INT, cname VARCHAR(50));"),
83-
nil, newTestResult())
83+
nil, newTestResult().addResult(ruleName))
8484

8585
runAIRuleCase(rule, t, "case 17: SELECT语句中 查询列中, 嵌套子查询2层",
8686
"SELECT 1, st1.id, (SELECT (SELECT id0 FROM exist_db.exist_tb_1 WHERE id1 = 'value') xx2 FROM exist_db.exist_tb_1 WHERE id1 = 'value') xxx FROM st1;",

sqle/driver/mysql/rule_00109_test.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,11 @@ func TestRuleSQLE00109(t *testing.T) {
7171
WithSQL("CREATE TABLE orders_archive (id INT, amount DECIMAL(10,2));"),
7272
nil, newTestResult())
7373

74-
// 特殊情况:解析器好像对from中嵌套的子查询 ,会被视作为属于非实际路由的子查询,因此没有累计(相当于只是个外壳嵌套而已,因此这里可以视作为不违规
75-
runAIRuleCase(rule, t, "case 9: SELECT ... UNION ALL SELECT语句中的子查询使用LIMIT,但是这里子查询其实是可以直接push到外层的,可以不算做子查询",
74+
runAIRuleCase(rule, t, "case 9: SELECT ... UNION ALL SELECT语句中的子查询使用LIMIT",
7675
"SELECT * FROM (SELECT id FROM products LIMIT 1) AS sub UNION ALL SELECT * FROM products;",
7776
session.NewAIMockContext().
7877
WithSQL("CREATE TABLE products (id INT, name VARCHAR(100));"),
79-
nil, newTestResult())
78+
nil, newTestResult().addResult(ruleName))
8079

8180
runAIRuleCase(rule, t, "case 10: SELECT ... UNION ALL SELECT语句中的子查询不使用LIMIT",
8281
"SELECT * FROM (SELECT id FROM products) AS sub UNION ALL SELECT * FROM products;",

sqle/driver/plugin_adapter_v1.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ func (s *PluginImplV1) ExecBatch(ctx context.Context, sqls ...string) ([]sqlDriv
223223
return nil, fmt.Errorf("unimplement this method")
224224
}
225225

226-
func (p *PluginImplV1) Tx(ctx context.Context, queries ...string) ([]sqlDriver.Result, error) {
226+
func (p *PluginImplV1) Tx(ctx context.Context, queries ...string) (*driverV2.TxResponse, error) {
227227
client, err := p.DriverManager.GetAuditDriver()
228228
if err != nil {
229229
return nil, err

0 commit comments

Comments
 (0)