Skip to content

Commit 3f98836

Browse files
authored
Merge pull request #40 from HRogge/master
Allow defaults library to use UnmarshalText() and UnmarshalJSON()
2 parents 325dfb4 + d1d4e26 commit 3f98836

File tree

2 files changed

+63
-3
lines changed

2 files changed

+63
-3
lines changed

defaults.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package defaults
22

33
import (
4+
"encoding"
45
"encoding/json"
56
"errors"
67
"reflect"
@@ -61,6 +62,10 @@ func setField(field reflect.Value, defaultVal string) error {
6162

6263
isInitial := isInitialValue(field)
6364
if isInitial {
65+
if unmarshalByInterface(field, defaultVal) {
66+
return nil
67+
}
68+
6469
switch field.Kind() {
6570
case reflect.Bool:
6671
if val, err := strconv.ParseBool(defaultVal); err == nil {
@@ -194,6 +199,24 @@ func setField(field reflect.Value, defaultVal string) error {
194199
return nil
195200
}
196201

202+
func unmarshalByInterface(field reflect.Value, defaultVal string) bool {
203+
asText, ok := field.Addr().Interface().(encoding.TextUnmarshaler)
204+
if ok && defaultVal != "" {
205+
// if field implements encode.TextUnmarshaler, try to use it before decode by kind
206+
if err := asText.UnmarshalText([]byte(defaultVal)); err == nil {
207+
return true
208+
}
209+
}
210+
asJSON, ok := field.Addr().Interface().(json.Unmarshaler)
211+
if ok && defaultVal != "" && defaultVal != "{}" && defaultVal != "[]" {
212+
// if field implements json.Unmarshaler, try to use it before decode by kind
213+
if err := asJSON.UnmarshalJSON([]byte(defaultVal)); err == nil {
214+
return true
215+
}
216+
}
217+
return false
218+
}
219+
197220
func isInitialValue(field reflect.Value) bool {
198221
return reflect.DeepEqual(reflect.Zero(field.Type()).Interface(), field.Interface())
199222
}

defaults_test.go

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
package defaults
22

33
import (
4+
"encoding/json"
5+
"errors"
6+
"net"
47
"reflect"
8+
"strconv"
59
"testing"
610
"time"
711

@@ -112,9 +116,12 @@ type Sample struct {
112116
MyMap MyMap `default:"{}"`
113117
MySlice MySlice `default:"[]"`
114118

115-
StructWithJSON Struct `default:"{\"Foo\": 123}"`
116-
StructPtrWithJSON *Struct `default:"{\"Foo\": 123}"`
117-
MapWithJSON map[string]int `default:"{\"foo\": 123}"`
119+
StructWithText net.IP `default:"10.0.0.1"`
120+
StructPtrWithText *net.IP `default:"10.0.0.1"`
121+
StructWithJSON Struct `default:"{\"Foo\": 123}"`
122+
StructPtrWithJSON *Struct `default:"{\"Foo\": 123}"`
123+
MapWithJSON map[string]int `default:"{\"foo\": 123}"`
124+
TypeWithUnmarshalJSON JSONOnlyType `default:"\"one\""`
118125

119126
MapOfPtrStruct map[string]*Struct
120127
MapOfStruct map[string]Struct
@@ -155,6 +162,24 @@ type Embedded struct {
155162
Int int `default:"1"`
156163
}
157164

165+
type JSONOnlyType int
166+
167+
func (j *JSONOnlyType) UnmarshalJSON(b []byte) error {
168+
var tmp string
169+
if err := json.Unmarshal(b, &tmp); err != nil {
170+
return err
171+
}
172+
if i, err := strconv.Atoi(tmp); err == nil {
173+
*j = JSONOnlyType(i)
174+
return nil
175+
}
176+
if tmp == "one" {
177+
*j = 1
178+
return nil
179+
}
180+
return errors.New("cannot unmarshal")
181+
}
182+
158183
func TestMustSet(t *testing.T) {
159184

160185
t.Run("right way", func(t *testing.T) {
@@ -485,6 +510,14 @@ func TestInit(t *testing.T) {
485510
}
486511
})
487512

513+
t.Run("complex types with text unmarshal", func(t *testing.T) {
514+
if !sample.StructWithText.Equal(net.ParseIP("10.0.0.1")) {
515+
t.Errorf("it should initialize struct with text")
516+
}
517+
if !sample.StructPtrWithText.Equal(net.ParseIP("10.0.0.1")) {
518+
t.Errorf("it should initialize struct with text")
519+
}
520+
})
488521
t.Run("complex types with json", func(t *testing.T) {
489522
if sample.StructWithJSON.Foo != 123 {
490523
t.Errorf("it should initialize struct with json")
@@ -499,6 +532,10 @@ func TestInit(t *testing.T) {
499532
t.Errorf("it should initialize slice with json")
500533
}
501534

535+
if int(sample.TypeWithUnmarshalJSON) != 1 {
536+
t.Errorf("it should initialize json unmarshaled value")
537+
}
538+
502539
t.Run("invalid json", func(t *testing.T) {
503540
if err := Set(&struct {
504541
I []int `default:"[!]"`

0 commit comments

Comments
 (0)