|
| 1 | +package encoding |
| 2 | + |
| 3 | +import ( |
| 4 | + "fmt" |
| 5 | + "net/netip" |
| 6 | + "reflect" |
| 7 | + "strings" |
| 8 | +) |
| 9 | + |
| 10 | +const tagName = "spoe" |
| 11 | + |
| 12 | +// Unmarshal unmarshals KV entries from the scanner into the provided struct. |
| 13 | +// The struct should have fields tagged with `spoe:"keyname"` to map KV entry |
| 14 | +// names to struct fields. |
| 15 | +// |
| 16 | +// Supported field types: |
| 17 | +// - string, []byte (for DataTypeString and DataTypeBinary) |
| 18 | +// - int32, int64, uint32, uint64 (for integer types) |
| 19 | +// - bool (for DataTypeBool) |
| 20 | +// - netip.Addr (for DataTypeIPV4 and DataTypeIPV6) |
| 21 | +// - pointer types for optional fields (nil if key not found) |
| 22 | +// |
| 23 | +// Example: |
| 24 | +// |
| 25 | +// type RequestData struct { |
| 26 | +// Headers []byte `spoe:"headers"` |
| 27 | +// Status int32 `spoe:"status-code"` |
| 28 | +// IP netip.Addr `spoe:"client-ip"` |
| 29 | +// Optional *string `spoe:"optional-field"` |
| 30 | +// } |
| 31 | +func (k *KVScanner) Unmarshal(v any) error { |
| 32 | + rv := reflect.ValueOf(v) |
| 33 | + if rv.Kind() != reflect.Pointer || rv.IsNil() { |
| 34 | + return fmt.Errorf("unmarshal target must be a non-nil pointer to struct") |
| 35 | + } |
| 36 | + |
| 37 | + rv = rv.Elem() |
| 38 | + if rv.Kind() != reflect.Struct { |
| 39 | + return fmt.Errorf("unmarshal target must be a pointer to struct") |
| 40 | + } |
| 41 | + |
| 42 | + rt := rv.Type() |
| 43 | + |
| 44 | + // Build a slice of field info to avoid string allocations during lookup |
| 45 | + type fieldInfo struct { |
| 46 | + keyStr string // cached for NameEquals and error messages |
| 47 | + fieldIdx int |
| 48 | + field reflect.Value // cached to avoid repeated rv.Field() calls |
| 49 | + fieldKind reflect.Kind // cached to avoid repeated Kind() calls |
| 50 | + isPointer bool // cached to avoid repeated checks |
| 51 | + } |
| 52 | + fields := make([]fieldInfo, 0, rt.NumField()) |
| 53 | + pointerFieldIndices := make([]int, 0, rt.NumField()) // track pointer field indices for final cleanup |
| 54 | + for i := 0; i < rt.NumField(); i++ { |
| 55 | + field := rt.Field(i) |
| 56 | + tag := field.Tag.Get(tagName) |
| 57 | + if tag == "" || tag == "-" { |
| 58 | + continue |
| 59 | + } |
| 60 | + |
| 61 | + // Handle comma-separated options (e.g., "keyname,omitempty") |
| 62 | + // Use IndexByte to avoid allocation from strings.Split |
| 63 | + commaIdx := strings.IndexByte(tag, ',') |
| 64 | + var key string |
| 65 | + if commaIdx >= 0 { |
| 66 | + key = tag[:commaIdx] |
| 67 | + } else { |
| 68 | + key = tag |
| 69 | + } |
| 70 | + if key != "" { |
| 71 | + fv := rv.Field(i) |
| 72 | + fk := fv.Kind() |
| 73 | + isPtr := fk == reflect.Pointer |
| 74 | + fields = append(fields, fieldInfo{ |
| 75 | + keyStr: key, |
| 76 | + fieldIdx: i, |
| 77 | + field: fv, |
| 78 | + fieldKind: fk, |
| 79 | + isPointer: isPtr, |
| 80 | + }) |
| 81 | + if isPtr { |
| 82 | + pointerFieldIndices = append(pointerFieldIndices, i) |
| 83 | + } |
| 84 | + } |
| 85 | + } |
| 86 | + |
| 87 | + entry := AcquireKVEntry() |
| 88 | + defer ReleaseKVEntry(entry) |
| 89 | + |
| 90 | + // Track which pointer fields have been set (to clear unset ones later) |
| 91 | + setPointerFields := make(map[int]bool, len(pointerFieldIndices)) |
| 92 | + |
| 93 | + for k.Next(entry) { |
| 94 | + var fi *fieldInfo |
| 95 | + // Use NameEquals to avoid string allocation during lookup |
| 96 | + for i := range fields { |
| 97 | + if entry.NameEquals(fields[i].keyStr) { |
| 98 | + fi = &fields[i] |
| 99 | + break |
| 100 | + } |
| 101 | + } |
| 102 | + if fi == nil { |
| 103 | + // Unknown key, skip it |
| 104 | + continue |
| 105 | + } |
| 106 | + |
| 107 | + if !fi.field.CanSet() { |
| 108 | + return fmt.Errorf("field %s is not settable", rt.Field(fi.fieldIdx).Name) |
| 109 | + } |
| 110 | + |
| 111 | + if err := setFieldValue(fi.field, fi.fieldKind, entry); err != nil { |
| 112 | + return fmt.Errorf("field %s (key %q): %w", rt.Field(fi.fieldIdx).Name, fi.keyStr, err) |
| 113 | + } |
| 114 | + |
| 115 | + // Track if this is a pointer field that was set |
| 116 | + if fi.isPointer { |
| 117 | + setPointerFields[fi.fieldIdx] = true |
| 118 | + } |
| 119 | + } |
| 120 | + |
| 121 | + if err := k.Error(); err != nil { |
| 122 | + return fmt.Errorf("scanner error: %w", err) |
| 123 | + } |
| 124 | + |
| 125 | + // Set pointer fields to nil if they weren't set (important for pooled structs) |
| 126 | + // Only iterate through known pointer fields instead of all fields |
| 127 | + for _, idx := range pointerFieldIndices { |
| 128 | + if !setPointerFields[idx] { |
| 129 | + rv.Field(idx).Set(reflect.Zero(rt.Field(idx).Type)) |
| 130 | + } |
| 131 | + } |
| 132 | + |
| 133 | + return nil |
| 134 | +} |
| 135 | + |
| 136 | +func setFieldValue(field reflect.Value, fieldKind reflect.Kind, entry *KVEntry) error { |
| 137 | + fieldType := field.Type() |
| 138 | + |
| 139 | + // Handle pointer types |
| 140 | + if fieldKind == reflect.Pointer { |
| 141 | + if entry.dataType == DataTypeNull { |
| 142 | + field.Set(reflect.Zero(fieldType)) |
| 143 | + return nil |
| 144 | + } |
| 145 | + |
| 146 | + // Create new value of the pointed-to type |
| 147 | + elemType := fieldType.Elem() |
| 148 | + elemValue := reflect.New(elemType).Elem() |
| 149 | + if err := setValue(elemValue, elemType.Kind(), entry); err != nil { |
| 150 | + return err |
| 151 | + } |
| 152 | + field.Set(elemValue.Addr()) |
| 153 | + return nil |
| 154 | + } |
| 155 | + |
| 156 | + return setValue(field, fieldKind, entry) |
| 157 | +} |
| 158 | + |
| 159 | +var netipAddrType = reflect.TypeOf((*netip.Addr)(nil)).Elem() |
| 160 | + |
| 161 | +func setValue(field reflect.Value, fieldKind reflect.Kind, entry *KVEntry) error { |
| 162 | + fieldType := field.Type() |
| 163 | + |
| 164 | + switch fieldKind { |
| 165 | + case reflect.String: |
| 166 | + if entry.dataType != DataTypeString { |
| 167 | + return fmt.Errorf("expected string, got %d", entry.dataType) |
| 168 | + } |
| 169 | + // Value() returns string for DataTypeString |
| 170 | + field.SetString(entry.Value().(string)) |
| 171 | + |
| 172 | + case reflect.Slice: |
| 173 | + if fieldType.Elem().Kind() != reflect.Uint8 { |
| 174 | + return fmt.Errorf("unsupported slice type: %s", fieldType) |
| 175 | + } |
| 176 | + // []byte |
| 177 | + if entry.dataType != DataTypeString && entry.dataType != DataTypeBinary { |
| 178 | + return fmt.Errorf("expected string or binary, got %d", entry.dataType) |
| 179 | + } |
| 180 | + // Copy the bytes to avoid referencing the underlying buffer |
| 181 | + val := entry.ValueBytes() |
| 182 | + cp := make([]byte, len(val)) |
| 183 | + copy(cp, val) |
| 184 | + field.SetBytes(cp) |
| 185 | + |
| 186 | + case reflect.Int32: |
| 187 | + if entry.dataType != DataTypeInt32 { |
| 188 | + return fmt.Errorf("expected int32, got %d", entry.dataType) |
| 189 | + } |
| 190 | + field.SetInt(entry.ValueInt()) |
| 191 | + |
| 192 | + case reflect.Int64: |
| 193 | + if entry.dataType != DataTypeInt64 { |
| 194 | + return fmt.Errorf("expected int64, got %d", entry.dataType) |
| 195 | + } |
| 196 | + field.SetInt(entry.ValueInt()) |
| 197 | + |
| 198 | + case reflect.Uint32: |
| 199 | + if entry.dataType != DataTypeUInt32 { |
| 200 | + return fmt.Errorf("expected uint32, got %d", entry.dataType) |
| 201 | + } |
| 202 | + field.SetUint(uint64(entry.ValueInt())) |
| 203 | + |
| 204 | + case reflect.Uint64: |
| 205 | + if entry.dataType != DataTypeUInt64 { |
| 206 | + return fmt.Errorf("expected uint64, got %d", entry.dataType) |
| 207 | + } |
| 208 | + field.SetUint(uint64(entry.ValueInt())) |
| 209 | + |
| 210 | + case reflect.Bool: |
| 211 | + if entry.dataType != DataTypeBool { |
| 212 | + return fmt.Errorf("expected bool, got %d", entry.dataType) |
| 213 | + } |
| 214 | + field.SetBool(entry.ValueBool()) |
| 215 | + |
| 216 | + default: |
| 217 | + // Check for netip.Addr (using cached type) |
| 218 | + if fieldType == netipAddrType { |
| 219 | + if entry.dataType != DataTypeIPV4 && entry.dataType != DataTypeIPV6 { |
| 220 | + return fmt.Errorf("expected IP address, got %d", entry.dataType) |
| 221 | + } |
| 222 | + addr := entry.ValueAddr() |
| 223 | + field.Set(reflect.ValueOf(addr)) |
| 224 | + return nil |
| 225 | + } |
| 226 | + |
| 227 | + return fmt.Errorf("unsupported field type: %s", fieldType) |
| 228 | + } |
| 229 | + |
| 230 | + return nil |
| 231 | +} |
0 commit comments