Skip to content

Commit d76c1dc

Browse files
committed
Fix helper
1 parent b0dc683 commit d76c1dc

File tree

4 files changed

+44
-16
lines changed

4 files changed

+44
-16
lines changed

configurator.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ func New[T any](
1414
configPtr: new(T),
1515
providers: providers,
1616
registeredProviders: map[string]struct{}{},
17+
registeredTags: map[string]struct{}{},
1718
}
1819

1920
return cfg.initValues()
@@ -22,6 +23,7 @@ func New[T any](
2223
type Configurator[T any] struct {
2324
configPtr *T
2425
providers []Provider
26+
registeredTags map[string]struct{}
2527
registeredProviders map[string]struct{}
2628
}
2729

@@ -42,6 +44,11 @@ func (c *Configurator[T]) initValues() (*T, error) {
4244
}
4345
c.registeredProviders[p.Name()] = struct{}{}
4446

47+
if _, ok := c.registeredTags[p.Tag()]; ok {
48+
return nil, ErrProviderTagCollision
49+
}
50+
c.registeredTags[p.Tag()] = struct{}{}
51+
4552
if err := p.Init(c.configPtr); err != nil {
4653
return nil, fmt.Errorf("cannot init [%s] provider: %w", p.Name(), err)
4754
}
@@ -100,7 +107,7 @@ func (c *Configurator[T]) applyProviders(field reflect.StructField, v reflect.Va
100107
}
101108

102109
for _, provider := range c.providers {
103-
if _, found := fetchTagKey(field.Tag)[provider.Tag()]; !found {
110+
if _, found := fetchTagKey(field.Tag, c.registeredTags)[provider.Tag()]; !found {
104111
// skip provider if it's not specified in tags
105112
continue
106113
}

errors.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,5 @@ var (
1212
ErrInvalidInput = errors.New("invalid input")
1313
ErrNoProviders = errors.New("no providers")
1414
ErrProviderNameCollision = errors.New("provider name collision")
15+
ErrProviderTagCollision = errors.New("provider tag collision")
1516
)

helper.go

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,15 @@ package configuration
22

33
import (
44
"reflect"
5-
"strings"
65
)
76

8-
func fetchTagKey(t reflect.StructTag) map[string]struct{} {
7+
func fetchTagKey(t reflect.StructTag, registered map[string]struct{}) map[string]struct{} {
98
keys := map[string]struct{}{}
109

11-
pairs := strings.Split(string(t), " ")
12-
if len(pairs) == 1 && pairs[0] == "" {
13-
return keys
14-
}
15-
16-
for _, pair := range pairs {
17-
kv := strings.Split(pair, `:"`)
18-
if len(kv) < 1 {
19-
return keys
10+
for rt := range registered {
11+
if _, ok := t.Lookup(rt); ok {
12+
keys[rt] = struct{}{}
2013
}
21-
keys[kv[0]] = struct{}{}
2214
}
2315

2416
return keys

helper_test.go

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,13 @@ import (
88
func Test_fetchTagKey(t *testing.T) {
99
t.Parallel()
1010

11+
registredTags := map[string]struct{}{
12+
"json": {},
13+
"xml": {},
14+
"flag": {},
15+
"default": {},
16+
}
17+
1118
tests := []struct {
1219
name string
1320
in reflect.StructTag
@@ -27,9 +34,9 @@ func Test_fetchTagKey(t *testing.T) {
2734
},
2835
{
2936
name: "non-empty tag value",
30-
in: reflect.StructTag(`json:"id"`),
37+
in: reflect.StructTag(`default:"one;two"`),
3138
want: map[string]struct{}{
32-
"json": {},
39+
"default": {},
3340
},
3441
},
3542
{
@@ -40,14 +47,35 @@ func Test_fetchTagKey(t *testing.T) {
4047
"xml": {},
4148
},
4249
},
50+
{
51+
name: "malformed tag",
52+
in: reflect.StructTag(`json`),
53+
want: map[string]struct{}{},
54+
},
55+
{
56+
name: "tag with spaces in value",
57+
in: reflect.StructTag(`flag:"name_flag||Some description" json:"id" xml:"ID"`),
58+
want: map[string]struct{}{
59+
"flag": {},
60+
"json": {},
61+
"xml": {},
62+
},
63+
},
64+
{
65+
name: "special characters in tag value",
66+
in: reflect.StructTag(`default:"one;two-three_1,/ and this!=2*7&^3:,. $"`),
67+
want: map[string]struct{}{
68+
"default": {},
69+
},
70+
},
4371
}
4472
for _, tt := range tests {
4573
tt := tt
4674

4775
t.Run(tt.name, func(t *testing.T) {
4876
t.Parallel()
4977

50-
assert(t, tt.want, fetchTagKey(tt.in))
78+
assert(t, tt.want, fetchTagKey(tt.in, registredTags))
5179
})
5280
}
5381
}

0 commit comments

Comments
 (0)