Skip to content

Commit eb9ac63

Browse files
committed
Implement handling of #[serde(default)] on variant in internally tagged enums
If tag will not be found in the data, the default tag will be assumed
1 parent ce8144f commit eb9ac63

File tree

3 files changed

+148
-16
lines changed

3 files changed

+148
-16
lines changed

serde/src/private/de.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -816,17 +816,18 @@ mod content {
816816
pub struct TaggedContentVisitor<T> {
817817
tag_name: &'static str,
818818
expecting: &'static str,
819-
value: PhantomData<T>,
819+
/// If set, this tag will be used if tag will not be found in data
820+
default: Option<T>,
820821
}
821822

822823
impl<T> TaggedContentVisitor<T> {
823824
/// Visitor for the content of an internally tagged enum with the given
824825
/// tag name.
825-
pub fn new(name: &'static str, expecting: &'static str) -> Self {
826+
pub fn new(name: &'static str, expecting: &'static str, default: Option<T>) -> Self {
826827
TaggedContentVisitor {
827828
tag_name: name,
828829
expecting,
829-
value: PhantomData,
830+
default,
830831
}
831832
}
832833
}
@@ -846,6 +847,8 @@ mod content {
846847
where
847848
S: SeqAccess<'de>,
848849
{
850+
// We do not support sequence representation without tags, because that may
851+
// create ambiguity during deserialization
849852
let tag = match tri!(seq.next_element()) {
850853
Some(tag) => tag,
851854
None => {
@@ -879,7 +882,7 @@ mod content {
879882
}
880883
}
881884
}
882-
match tag {
885+
match tag.or(self.default) {
883886
None => Err(de::Error::missing_field(self.tag_name)),
884887
Some(tag) => Ok((tag, Content::Map(vec))),
885888
}

serde_derive/src/de/enum_internally.rs

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,33 +27,42 @@ pub(super) fn deserialize(
2727
let (variants_stmt, variant_visitor) = enum_::prepare_enum_variant_enum(variants);
2828

2929
// Match arms to extract a variant from a string
30-
let variant_arms = variants
30+
let mut variants = variants
3131
.iter()
3232
.enumerate()
33-
.filter(|&(_, variant)| !variant.attrs.skip_deserializing())
34-
.map(|(i, variant)| {
35-
let variant_name = field_i(i);
33+
.filter(|&(_, variant)| !variant.attrs.skip_deserializing());
34+
let variant_arms = variants.clone().map(|(i, variant)| {
35+
let variant_name = field_i(i);
3636

37-
let block = Match(deserialize_internally_tagged_variant(
38-
params, variant, cattrs,
39-
));
37+
let block = Match(deserialize_internally_tagged_variant(
38+
params, variant, cattrs,
39+
));
4040

41-
quote! {
42-
__Field::#variant_name => #block
43-
}
44-
});
41+
quote! {
42+
__Field::#variant_name => #block
43+
}
44+
});
4545

4646
let expecting = format!("internally tagged enum {}", params.type_name());
4747
let expecting = cattrs.expecting().unwrap_or(&expecting);
4848

49+
// We checked that only one variant is marked with #[serde(default)]
50+
let default = match variants.find(|(_, variant)| variant.attrs.default()) {
51+
Some((i, _)) => {
52+
let default = field_i(i);
53+
quote! { _serde::#private::Some(__Field::#default) }
54+
}
55+
None => quote! { _serde::#private::None },
56+
};
57+
4958
quote_block! {
5059
#variant_visitor
5160

5261
#variants_stmt
5362

5463
let (__tag, __content) = _serde::Deserializer::deserialize_any(
5564
__deserializer,
56-
_serde::#private::de::TaggedContentVisitor::<__Field>::new(#tag, #expecting))?;
65+
_serde::#private::de::TaggedContentVisitor::<__Field>::new(#tag, #expecting, #default))?;
5766
let __deserializer = _serde::#private::de::ContentDeserializer::<__D::Error>::new(__content);
5867

5968
match __tag {

test_suite/tests/test_enum_internally_tagged.rs

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1037,6 +1037,126 @@ mod struct_enum {
10371037
}
10381038
}
10391039

1040+
#[test]
1041+
fn default_variant() {
1042+
#[derive(Debug, PartialEq, Serialize, Deserialize)]
1043+
#[serde(tag = "tag")]
1044+
enum InternallyTaggedWithDefault {
1045+
Unit,
1046+
NewtypeUnit(()),
1047+
NewtypeUnitStruct(Unit),
1048+
NewtypeNewtype(Newtype),
1049+
NewtypeMap(BTreeMap<String, String>),
1050+
NewtypeStruct(Struct),
1051+
NewtypeEnum(Enum),
1052+
#[serde(default)]
1053+
Struct {
1054+
a: u8,
1055+
},
1056+
StructEnum {
1057+
enum_: Enum,
1058+
},
1059+
}
1060+
1061+
let value = InternallyTaggedWithDefault::Struct { a: 1 };
1062+
1063+
// Special case: no tag field, use enum tokens
1064+
assert_de_tokens(
1065+
&value,
1066+
&[
1067+
Token::Struct {
1068+
name: "InternallyTagged",
1069+
len: 1,
1070+
},
1071+
Token::Str("a"),
1072+
Token::U8(1),
1073+
Token::StructEnd,
1074+
],
1075+
);
1076+
assert_de_tokens(
1077+
&value,
1078+
&[
1079+
Token::Struct {
1080+
name: "InternallyTagged",
1081+
len: 1,
1082+
},
1083+
Token::BorrowedStr("a"),
1084+
Token::U8(1),
1085+
Token::StructEnd,
1086+
],
1087+
);
1088+
1089+
// Special case: no tag field, Map representation
1090+
assert_de_tokens(
1091+
&value,
1092+
&[
1093+
Token::Map { len: Some(1) },
1094+
Token::Str("a"),
1095+
Token::U8(1),
1096+
Token::MapEnd,
1097+
],
1098+
);
1099+
assert_de_tokens(
1100+
&value,
1101+
&[
1102+
Token::Map { len: Some(1) },
1103+
Token::BorrowedStr("a"),
1104+
Token::U8(1),
1105+
Token::MapEnd,
1106+
],
1107+
);
1108+
1109+
// Special case: Map representation, unknown tag
1110+
assert_de_tokens_error::<InternallyTaggedWithDefault>(
1111+
&[
1112+
Token::Map { len: Some(1) },
1113+
Token::Str("tag"),
1114+
Token::Str("Z"),
1115+
Token::MapEnd,
1116+
],
1117+
"unknown variant `Z`, expected one of \
1118+
`Unit`, \
1119+
`NewtypeUnit`, \
1120+
`NewtypeUnitStruct`, \
1121+
`NewtypeNewtype`, \
1122+
`NewtypeMap`, \
1123+
`NewtypeStruct`, \
1124+
`NewtypeEnum`, \
1125+
`Struct`, \
1126+
`StructEnum`",
1127+
);
1128+
1129+
// Special case: Seq representation, unknown tag
1130+
assert_de_tokens_error::<InternallyTaggedWithDefault>(
1131+
&[
1132+
Token::Seq { len: Some(1) },
1133+
Token::Str("Z"), // tag
1134+
Token::SeqEnd,
1135+
],
1136+
"unknown variant `Z`, expected one of \
1137+
`Unit`, \
1138+
`NewtypeUnit`, \
1139+
`NewtypeUnitStruct`, \
1140+
`NewtypeNewtype`, \
1141+
`NewtypeMap`, \
1142+
`NewtypeStruct`, \
1143+
`NewtypeEnum`, \
1144+
`Struct`, \
1145+
`StructEnum`",
1146+
);
1147+
1148+
// Special case: Seq representation cannot be used without a tag due to ambiguity
1149+
assert_de_tokens_error::<InternallyTaggedWithDefault>(
1150+
&[
1151+
Token::Seq { len: Some(1) },
1152+
Token::U8(1), // tag (== NewtypeUnit)
1153+
Token::SeqEnd,
1154+
],
1155+
// The error is not very clear, because actually we got end of sequence instead of a Unit
1156+
"invalid type: sequence, expected unit",
1157+
);
1158+
}
1159+
10401160
#[test]
10411161
fn wrong_tag() {
10421162
assert_de_tokens_error::<InternallyTagged>(

0 commit comments

Comments
 (0)