@@ -21,10 +21,10 @@ import (
2121 "context"
2222 "database/sql"
2323 "errors"
24+ "strings"
2425 "testing"
2526
2627 "github.com/DATA-DOG/go-sqlmock"
27- "github.com/agiledragon/gomonkey/v2"
2828 "github.com/golang/mock/gomock"
2929 "github.com/stretchr/testify/assert"
3030
@@ -36,6 +36,17 @@ import (
3636 "seata.apache.org/seata-go/pkg/rm"
3737)
3838
39+ type fnPatch struct {
40+ restore func ()
41+ }
42+
43+ func (p * fnPatch ) Reset () {
44+ if p != nil && p .restore != nil {
45+ p .restore ()
46+ p .restore = nil
47+ }
48+ }
49+
3950func initMockIndexMeta () []types.IndexMeta {
4051 return []types.IndexMeta {
4152 {
@@ -62,20 +73,20 @@ func initMockColumnMeta() []types.ColumnMeta {
6273 }
6374}
6475
65- func initGetIndexesStub (m * mysqlTrigger , indexMeta []types.IndexMeta ) * gomonkey. Patches {
66- getIndexesStub := gomonkey . ApplyPrivateMethod ( m , "getIndexes" ,
67- func (_ * mysqlTrigger , ctx context.Context , dbName string , tableName string , conn * sql.Conn ) ([]types.IndexMeta , error ) {
68- return indexMeta , nil
69- })
70- return getIndexesStub
76+ func initGetIndexesStub (_ * mysqlTrigger , indexMeta []types.IndexMeta ) * fnPatch {
77+ old := getIndexesFn
78+ getIndexesFn = func (_ * mysqlTrigger , ctx context.Context , dbName string , tableName string , conn * sql.Conn ) ([]types.IndexMeta , error ) {
79+ return indexMeta , nil
80+ }
81+ return & fnPatch { restore : func () { getIndexesFn = old }}
7182}
7283
73- func initGetColumnMetasStub (m * mysqlTrigger , columnMeta []types.ColumnMeta ) * gomonkey. Patches {
74- getColumnMetasStub := gomonkey . ApplyPrivateMethod ( m , "getColumnMetas" ,
75- func (_ * mysqlTrigger , ctx context.Context , dbName string , table string , conn * sql.Conn ) ([]types.ColumnMeta , error ) {
76- return columnMeta , nil
77- })
78- return getColumnMetasStub
84+ func initGetColumnMetasStub (_ * mysqlTrigger , columnMeta []types.ColumnMeta ) * fnPatch {
85+ old := getColumnMetasFn
86+ getColumnMetasFn = func (_ * mysqlTrigger , ctx context.Context , dbName string , table string , conn * sql.Conn ) ([]types.ColumnMeta , error ) {
87+ return columnMeta , nil
88+ }
89+ return & fnPatch { restore : func () { getColumnMetasFn = old }}
7990}
8091
8192func Test_mysqlTrigger_LoadOne (t * testing.T ) {
@@ -234,22 +245,22 @@ func Test_mysqlTrigger_LoadOne_ErrorCases(t *testing.T) {
234245 m := & mysqlTrigger {}
235246
236247 if tt .columnMetaErr != nil {
237- getColumnMetasStub := gomonkey . ApplyPrivateMethod ( m , "getColumnMetas" ,
238- func (_ * mysqlTrigger , ctx context.Context , dbName string , table string , conn * sql.Conn ) ([]types.ColumnMeta , error ) {
239- return nil , tt .columnMetaErr
240- })
241- defer getColumnMetasStub . Reset ()
248+ old := getColumnMetasFn
249+ getColumnMetasFn = func (_ * mysqlTrigger , ctx context.Context , dbName string , table string , conn * sql.Conn ) ([]types.ColumnMeta , error ) {
250+ return nil , tt .columnMetaErr
251+ }
252+ defer func () { getColumnMetasFn = old } ()
242253 } else {
243254 getColumnMetasStub := initGetColumnMetasStub (m , tt .columnMeta )
244255 defer getColumnMetasStub .Reset ()
245256 }
246257
247258 if tt .indexMetaErr != nil {
248- getIndexesStub := gomonkey . ApplyPrivateMethod ( m , "getIndexes" ,
249- func (_ * mysqlTrigger , ctx context.Context , dbName string , tableName string , conn * sql.Conn ) ([]types.IndexMeta , error ) {
250- return nil , tt .indexMetaErr
251- })
252- defer getIndexesStub . Reset ()
259+ old := getIndexesFn
260+ getIndexesFn = func (_ * mysqlTrigger , ctx context.Context , dbName string , tableName string , conn * sql.Conn ) ([]types.IndexMeta , error ) {
261+ return nil , tt .indexMetaErr
262+ }
263+ defer func () { getIndexesFn = old } ()
253264 } else {
254265 getIndexesStub := initGetIndexesStub (m , tt .indexMeta )
255266 defer getIndexesStub .Reset ()
@@ -258,8 +269,10 @@ func Test_mysqlTrigger_LoadOne_ErrorCases(t *testing.T) {
258269 _ , err := m .LoadOne (context .Background (), "testdb" , "testtable" , nil )
259270
260271 if tt .expectError {
261- assert .Error (t , err )
262- assert .Contains (t , err .Error (), tt .errorContains )
272+ if ! assert .Error (t , err ) {
273+ return
274+ }
275+ assert .Contains (t , strings .ToLower (err .Error ()), strings .ToLower (tt .errorContains ))
263276 } else {
264277 assert .NoError (t , err )
265278 }
@@ -604,15 +617,15 @@ func Test_mysqlTrigger_LoadAll_ErrorHandling(t *testing.T) {
604617 indexMeta := initMockIndexMeta ()
605618
606619 callCount := 0
607- getColumnMetasStub := gomonkey . ApplyPrivateMethod ( m , "getColumnMetas" ,
608- func (_ * mysqlTrigger , ctx context.Context , dbName string , table string , conn * sql.Conn ) ([]types.ColumnMeta , error ) {
609- callCount ++
610- if callCount == 2 {
611- return nil , errors .New ("column error" )
612- }
613- return columnMeta , nil
614- })
615- defer getColumnMetasStub . Reset ()
620+ old := getColumnMetasFn
621+ getColumnMetasFn = func (_ * mysqlTrigger , ctx context.Context , dbName string , table string , conn * sql.Conn ) ([]types.ColumnMeta , error ) {
622+ callCount ++
623+ if callCount == 2 {
624+ return nil , errors .New ("column error" )
625+ }
626+ return columnMeta , nil
627+ }
628+ defer func () { getColumnMetasFn = old } ()
616629
617630 getIndexesStub := initGetIndexesStub (m , indexMeta )
618631 defer getIndexesStub .Reset ()
0 commit comments