diff --git a/src/ErrorOrX.Generators/Analyzers/ErrorOrEndpointAnalyzer.BodyAndValidation.cs b/src/ErrorOrX.Generators/Analyzers/ErrorOrEndpointAnalyzer.BodyAndValidation.cs new file mode 100644 index 0000000..89ca774 --- /dev/null +++ b/src/ErrorOrX.Generators/Analyzers/ErrorOrEndpointAnalyzer.BodyAndValidation.cs @@ -0,0 +1,104 @@ +using ErrorOr.Generators; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.Diagnostics; + +namespace ErrorOr.Analyzers; + +/// +/// Body-source counting (EOE006) and DataAnnotations reflection check (EOE039) for the +/// . Body classification routes through +/// so analyzer and generator share one source of truth +/// (FQN + inheritance, not name-substring match). +/// +public sealed partial class ErrorOrEndpointAnalyzer +{ + private static int CountBodySources(IMethodSymbol method) + { + var bodyCount = 0; + var hasFromForm = false; + var hasStream = false; + + foreach (var param in method.Parameters) + { + // Check for body-related attributes using HasAttribute + if (param.HasAttribute(WellKnownTypes.FromBodyAttribute)) + { + bodyCount++; + continue; + } + + if (param.HasAttribute(WellKnownTypes.FromFormAttribute)) + { + hasFromForm = true; + continue; + } + + // Body-related types: route through ErrorOrContext so analyzer + generator + // share one source of truth (FQN + inheritance, not name substring match). + if (ErrorOrContext.IsStream(param.Type) || ErrorOrContext.IsPipeReader(param.Type)) + hasStream = true; + else if (ErrorOrContext.IsFormFile(param.Type) || + ErrorOrContext.IsFormFileCollection(param.Type) || + ErrorOrContext.IsFormCollection(param.Type)) + hasFromForm = true; + } + + // Multiple [FromBody] is always an error + if (bodyCount > 1) return bodyCount; + + // Otherwise return number of distinct body source buckets used + return (bodyCount > 0 ? 1 : 0) + (hasFromForm ? 1 : 0) + (hasStream ? 1 : 0); + } + + private static bool HasAcceptedResponseAttribute(ISymbol method) + { + return method.HasAttribute(WellKnownTypes.AcceptedResponseAttribute); + } + + /// + /// Checks if any parameter has validation attributes from System.ComponentModel.DataAnnotations. + /// Validator.TryValidateObject uses reflection internally. + /// + private static void CheckForValidationAttributes( + in SymbolAnalysisContext context, + IMethodSymbol method) + { + var validationAttributeType = context.Compilation.GetTypeByMetadataName(WellKnownTypes.ValidationAttribute); + if (validationAttributeType is null) return; + + foreach (var param in method.Parameters) + { + foreach (var attr in param.GetAttributes()) + { + if (attr.AttributeClass is null) continue; + + // Check if the attribute inherits from ValidationAttribute + if (InheritsFrom(attr.AttributeClass, validationAttributeType)) + { + context.ReportDiagnostic(Diagnostic.Create( + Descriptors.ValidationUsesReflection, + param.Locations.FirstOrDefault() ?? method.Locations.FirstOrDefault(), + param.Name, + method.Name)); + break; // Only report once per parameter + } + } + } + } + + /// + /// Checks if a type inherits from a base type. + /// + private static bool InheritsFrom(ITypeSymbol type, ISymbol baseType) + { + var current = type.BaseType; + while (current is not null) + { + if (SymbolEqualityComparer.Default.Equals(current, baseType)) return true; + + current = current.BaseType; + } + + return false; + } +} diff --git a/src/ErrorOrX.Generators/Analyzers/ErrorOrEndpointAnalyzer.RouteValidation.cs b/src/ErrorOrX.Generators/Analyzers/ErrorOrEndpointAnalyzer.RouteValidation.cs new file mode 100644 index 0000000..fce2c79 --- /dev/null +++ b/src/ErrorOrX.Generators/Analyzers/ErrorOrEndpointAnalyzer.RouteValidation.cs @@ -0,0 +1,186 @@ +using ANcpLua.Roslyn.Utilities; +using ErrorOr.Generators; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.Diagnostics; + +namespace ErrorOr.Analyzers; + +/// +/// Route pattern validation and route-constraint type checking. Hosts EOE005 +/// (pattern syntax) and EOE020 (constraint vs. CLR-type mismatch) for the +/// . +/// +public sealed partial class ErrorOrEndpointAnalyzer +{ + /// + /// Validates route constraint types match method parameter types (EOE020). + /// + private static void ValidateConstraintTypes( + in SymbolAnalysisContext context, + ImmutableArray routeParams, + IReadOnlyDictionary methodParamsByRouteName, + Location attributeLocation) + { + foreach (var rp in routeParams) + ValidateSingleRouteConstraint(in context, rp, methodParamsByRouteName, attributeLocation); + } + + /// + /// Validates a single route parameter constraint against its bound method parameter. + /// + private static void ValidateSingleRouteConstraint( + in SymbolAnalysisContext context, + RouteParameterInfo rp, + IReadOnlyDictionary methodParamsByRouteName, + Location attributeLocation) + { + // Skip if no constraint or not bound to a method parameter + if (rp.Constraint is not { } constraint || + !methodParamsByRouteName.TryGetValue(rp.Name, out var mp)) + { + return; + } + + if (mp.TypeFqn is not { } typeFqn) return; + + // Skip format-only constraints + if (IsFormatOnlyConstraint(constraint)) return; + + // Validate based on constraint type + if (rp.IsCatchAll) + ValidateCatchAllConstraint(in context, rp, mp, typeFqn, attributeLocation); + else + ValidateTypedConstraint(in context, rp, constraint, mp, typeFqn, attributeLocation); + } + + /// + /// Checks if a constraint is format-only and doesn't constrain the CLR type. + /// Delegates to shared RouteValidator to avoid duplication. + /// + private static bool IsFormatOnlyConstraint(string constraint) + { + return RouteValidator.FormatOnlyConstraints.Contains(constraint); + } + + /// + /// Validates that a catch-all parameter is bound to a string type. + /// + private static void ValidateCatchAllConstraint( + in SymbolAnalysisContext context, + RouteParameterInfo rp, + RouteMethodParameterInfo mp, + string typeFqn, + Location attributeLocation) + { + if (!IsStringType(typeFqn)) + { + context.ReportDiagnostic(Diagnostic.Create( + Descriptors.RouteConstraintTypeMismatch, + attributeLocation, + rp.Name, + "*", + "string", + mp.Name, + NormalizeTypeName(typeFqn))); + } + } + + /// + /// Validates that a typed constraint matches the bound parameter type. + /// Uses shared RouteValidator.ConstraintToTypes to avoid duplication. + /// + private static void ValidateTypedConstraint( + in SymbolAnalysisContext context, + RouteParameterInfo rp, + string constraint, + RouteMethodParameterInfo mp, + string typeFqn, + Location attributeLocation) + { + // Look up expected types for this constraint using shared RouteValidator + if (!RouteValidator.ConstraintToTypes.TryGetValue(constraint, + out var expectedTypes)) + { + return; // Unknown constraint (e.g., custom) - skip validation + } + + // Get the actual type, unwrapping Nullable for optional parameters + var actualTypeFqn = typeFqn.UnwrapNullable(rp.IsOptional || mp.IsNullable); + + // Check if actual type matches any expected type + if (!DoesTypeMatchConstraint(actualTypeFqn, expectedTypes)) + { + context.ReportDiagnostic(Diagnostic.Create( + Descriptors.RouteConstraintTypeMismatch, + attributeLocation, + rp.Name, + constraint, + expectedTypes[0], + mp.Name, + NormalizeTypeName(typeFqn))); + } + } + + /// + /// Checks if an actual type matches any of the expected types for a constraint. + /// + private static bool DoesTypeMatchConstraint(string actualTypeFqn, IEnumerable expectedTypes) + { + foreach (var expected in expectedTypes) + { + if (TypeNamesMatch(actualTypeFqn, expected)) + return true; + } + + return false; + } + + private static List ValidateRoutePattern(string pattern) + { + var issues = new List(); + + if (string.IsNullOrWhiteSpace(pattern)) + { + issues.Add("Route pattern cannot be empty"); + return issues; + } + + // Strip escaped braces before validation (matches RouteValidator behavior) + // This prevents false positives for routes like /api/{{version}}/users + var escapedStripped = pattern.Replace("{{", "").Replace("}}", ""); + + // Check for empty parameter names: {} + if (escapedStripped.Contains("{}")) + issues.Add("Route contains empty parameter '{}'. Parameter names are required"); + + // Check for unclosed braces + var openCount = escapedStripped.Count(static c => c == '{'); + var closeCount = escapedStripped.Count(static c => c == '}'); + if (openCount != closeCount) issues.Add($"Route has mismatched braces: {openCount} '{{' and {closeCount} '}}'"); + + // Check for duplicate parameter names using RouteValidator (single source of truth) + var paramNames = new HashSet(StringComparer.OrdinalIgnoreCase); + foreach (var rp in RouteValidator.ExtractRouteParameters(pattern)) + { + if (!paramNames.Add(rp.Name)) + issues.Add($"Route contains duplicate parameter '{{{rp.Name}}}'"); + } + + return issues; + } + + private static bool IsStringType(string typeFqn) + { + return typeFqn.IsStringType(); + } + + private static bool TypeNamesMatch(string actualFqn, string expected) + { + return actualFqn.TypeNamesEqual(expected); + } + + private static string NormalizeTypeName(string typeFqn) + { + return typeFqn.NormalizeTypeName(); + } +} diff --git a/src/ErrorOrX.Generators/Analyzers/ErrorOrEndpointAnalyzer.cs b/src/ErrorOrX.Generators/Analyzers/ErrorOrEndpointAnalyzer.cs index a3f8025..cbd0b7f 100644 --- a/src/ErrorOrX.Generators/Analyzers/ErrorOrEndpointAnalyzer.cs +++ b/src/ErrorOrX.Generators/Analyzers/ErrorOrEndpointAnalyzer.cs @@ -10,13 +10,23 @@ namespace ErrorOr.Analyzers; /// Provides immediate IDE feedback for common issues. /// /// -/// This analyzer handles single-method diagnostics that can run fast. -/// Cross-file diagnostics (EOE004, EOE007) remain in the generator. -/// Route classification (Stream/PipeReader/IFormFile/etc.) is delegated to -/// so analyzer and generator stay in lockstep. +/// +/// This analyzer handles single-method diagnostics that can run fast. +/// Cross-file diagnostics (EOE004, EOE007) remain in the generator. +/// Route classification (Stream/PipeReader/IFormFile/etc.) is delegated to +/// so analyzer and generator stay in lockstep. +/// +/// +/// Split across: +/// +/// ErrorOrEndpointAnalyzer.cs — Entry, Initialize, top-level analysis loop, return-type / attribute extraction. +/// ErrorOrEndpointAnalyzer.RouteValidation.cs — Pattern parsing + per-constraint validation. +/// ErrorOrEndpointAnalyzer.BodyAndValidation.cs — Body-source counting, DataAnnotations reflection check. +/// +/// /// [DiagnosticAnalyzer(LanguageNames.CSharp)] -public sealed class ErrorOrEndpointAnalyzer : DiagnosticAnalyzer +public sealed partial class ErrorOrEndpointAnalyzer : DiagnosticAnalyzer { /// public override ImmutableArray SupportedDiagnostics { get; } = @@ -170,176 +180,6 @@ private static void AnalyzeEndpoint( CheckForValidationAttributes(in context, method); } - /// - /// Checks if any parameter has validation attributes from System.ComponentModel.DataAnnotations. - /// Validator.TryValidateObject uses reflection internally. - /// - private static void CheckForValidationAttributes( - in SymbolAnalysisContext context, - IMethodSymbol method) - { - var validationAttributeType = context.Compilation.GetTypeByMetadataName(WellKnownTypes.ValidationAttribute); - if (validationAttributeType is null) return; - - foreach (var param in method.Parameters) - { - foreach (var attr in param.GetAttributes()) - { - if (attr.AttributeClass is null) continue; - - // Check if the attribute inherits from ValidationAttribute - if (InheritsFrom(attr.AttributeClass, validationAttributeType)) - { - context.ReportDiagnostic(Diagnostic.Create( - Descriptors.ValidationUsesReflection, - param.Locations.FirstOrDefault() ?? method.Locations.FirstOrDefault(), - param.Name, - method.Name)); - break; // Only report once per parameter - } - } - } - } - - /// - /// Checks if a type inherits from a base type. - /// - private static bool InheritsFrom(ITypeSymbol type, ISymbol baseType) - { - var current = type.BaseType; - while (current is not null) - { - if (SymbolEqualityComparer.Default.Equals(current, baseType)) return true; - - current = current.BaseType; - } - - return false; - } - - /// - /// Validates route constraint types match method parameter types (EOE020). - /// - private static void ValidateConstraintTypes( - in SymbolAnalysisContext context, - ImmutableArray routeParams, - IReadOnlyDictionary methodParamsByRouteName, - Location attributeLocation) - { - foreach (var rp in routeParams) - ValidateSingleRouteConstraint(in context, rp, methodParamsByRouteName, attributeLocation); - } - - /// - /// Validates a single route parameter constraint against its bound method parameter. - /// - private static void ValidateSingleRouteConstraint( - in SymbolAnalysisContext context, - RouteParameterInfo rp, - IReadOnlyDictionary methodParamsByRouteName, - Location attributeLocation) - { - // Skip if no constraint or not bound to a method parameter - if (rp.Constraint is not { } constraint || - !methodParamsByRouteName.TryGetValue(rp.Name, out var mp)) - { - return; - } - - if (mp.TypeFqn is not { } typeFqn) return; - - // Skip format-only constraints - if (IsFormatOnlyConstraint(constraint)) return; - - // Validate based on constraint type - if (rp.IsCatchAll) - ValidateCatchAllConstraint(in context, rp, mp, typeFqn, attributeLocation); - else - ValidateTypedConstraint(in context, rp, constraint, mp, typeFqn, attributeLocation); - } - - /// - /// Checks if a constraint is format-only and doesn't constrain the CLR type. - /// Delegates to shared RouteValidator to avoid duplication. - /// - private static bool IsFormatOnlyConstraint(string constraint) - { - return RouteValidator.FormatOnlyConstraints.Contains(constraint); - } - - /// - /// Validates that a catch-all parameter is bound to a string type. - /// - private static void ValidateCatchAllConstraint( - in SymbolAnalysisContext context, - RouteParameterInfo rp, - RouteMethodParameterInfo mp, - string typeFqn, - Location attributeLocation) - { - if (!IsStringType(typeFqn)) - { - context.ReportDiagnostic(Diagnostic.Create( - Descriptors.RouteConstraintTypeMismatch, - attributeLocation, - rp.Name, - "*", - "string", - mp.Name, - NormalizeTypeName(typeFqn))); - } - } - - /// - /// Validates that a typed constraint matches the bound parameter type. - /// Uses shared RouteValidator.ConstraintToTypes to avoid duplication. - /// - private static void ValidateTypedConstraint( - in SymbolAnalysisContext context, - RouteParameterInfo rp, - string constraint, - RouteMethodParameterInfo mp, - string typeFqn, - Location attributeLocation) - { - // Look up expected types for this constraint using shared RouteValidator - if (!RouteValidator.ConstraintToTypes.TryGetValue(constraint, - out var expectedTypes)) - { - return; // Unknown constraint (e.g., custom) - skip validation - } - - // Get the actual type, unwrapping Nullable for optional parameters - var actualTypeFqn = typeFqn.UnwrapNullable(rp.IsOptional || mp.IsNullable); - - // Check if actual type matches any expected type - if (!DoesTypeMatchConstraint(actualTypeFqn, expectedTypes)) - { - context.ReportDiagnostic(Diagnostic.Create( - Descriptors.RouteConstraintTypeMismatch, - attributeLocation, - rp.Name, - constraint, - expectedTypes[0], - mp.Name, - NormalizeTypeName(typeFqn))); - } - } - - /// - /// Checks if an actual type matches any of the expected types for a constraint. - /// - private static bool DoesTypeMatchConstraint(string actualTypeFqn, IEnumerable expectedTypes) - { - foreach (var expected in expectedTypes) - { - if (TypeNamesMatch(actualTypeFqn, expected)) - return true; - } - - return false; - } - private static bool IsErrorOr(ITypeSymbol type) { return type is INamedTypeSymbol { Name: "ErrorOr", IsGenericType: true } named && @@ -426,96 +266,4 @@ private static ImmutableArray ExtractRouteParametersWithCons { return RouteValidator.ExtractRouteParameters(pattern); } - - private static int CountBodySources(IMethodSymbol method) - { - var bodyCount = 0; - var hasFromForm = false; - var hasStream = false; - - foreach (var param in method.Parameters) - { - // Check for body-related attributes using HasAttribute - if (param.HasAttribute(WellKnownTypes.FromBodyAttribute)) - { - bodyCount++; - continue; - } - - if (param.HasAttribute(WellKnownTypes.FromFormAttribute)) - { - hasFromForm = true; - continue; - } - - // Body-related types: route through ErrorOrContext so analyzer + generator - // share one source of truth (FQN + inheritance, not name substring match). - if (ErrorOrContext.IsStream(param.Type) || ErrorOrContext.IsPipeReader(param.Type)) - hasStream = true; - else if (ErrorOrContext.IsFormFile(param.Type) || - ErrorOrContext.IsFormFileCollection(param.Type) || - ErrorOrContext.IsFormCollection(param.Type)) - hasFromForm = true; - } - - // Multiple [FromBody] is always an error - if (bodyCount > 1) return bodyCount; - - // Otherwise return number of distinct body source buckets used - return (bodyCount > 0 ? 1 : 0) + (hasFromForm ? 1 : 0) + (hasStream ? 1 : 0); - } - - private static bool HasAcceptedResponseAttribute(ISymbol method) - { - return method.HasAttribute(WellKnownTypes.AcceptedResponseAttribute); - } - - private static List ValidateRoutePattern(string pattern) - { - var issues = new List(); - - if (string.IsNullOrWhiteSpace(pattern)) - { - issues.Add("Route pattern cannot be empty"); - return issues; - } - - // Strip escaped braces before validation (matches RouteValidator behavior) - // This prevents false positives for routes like /api/{{version}}/users - var escapedStripped = pattern.Replace("{{", "").Replace("}}", ""); - - // Check for empty parameter names: {} - if (escapedStripped.Contains("{}")) - issues.Add("Route contains empty parameter '{}'. Parameter names are required"); - - // Check for unclosed braces - var openCount = escapedStripped.Count(static c => c == '{'); - var closeCount = escapedStripped.Count(static c => c == '}'); - if (openCount != closeCount) issues.Add($"Route has mismatched braces: {openCount} '{{' and {closeCount} '}}'"); - - // Check for duplicate parameter names using RouteValidator (single source of truth) - var paramNames = new HashSet(StringComparer.OrdinalIgnoreCase); - foreach (var rp in RouteValidator.ExtractRouteParameters(pattern)) - { - if (!paramNames.Add(rp.Name)) - issues.Add($"Route contains duplicate parameter '{{{rp.Name}}}'"); - } - - return issues; - } - - private static bool IsStringType(string typeFqn) - { - return typeFqn.IsStringType(); - } - - private static bool TypeNamesMatch(string actualFqn, string expected) - { - return actualFqn.TypeNamesEqual(expected); - } - - private static string NormalizeTypeName(string typeFqn) - { - return typeFqn.NormalizeTypeName(); - } } diff --git a/src/ErrorOrX.Generators/Core/ErrorOrEndpointGenerator.Emitter.ErrorHandling.cs b/src/ErrorOrX.Generators/Core/ErrorOrEndpointGenerator.Emitter.ErrorHandling.cs new file mode 100644 index 0000000..63c8ce3 --- /dev/null +++ b/src/ErrorOrX.Generators/Core/ErrorOrEndpointGenerator.Emitter.ErrorHandling.cs @@ -0,0 +1,108 @@ +using ErrorOr.Generators.Emitters; + +namespace ErrorOr.Generators; + +/// +/// Emits the error-to-result dispatch inside Invoke_Ep{N}_Core: +/// +/// Union-type path that switches on ErrorType-to-Result factories. +/// Validation handling (DataAnnotations + Error.Validation aggregation). +/// ProblemDetails construction and Location-header emission for Created+Id responses. +/// +/// +public sealed partial class ErrorOrEndpointGenerator +{ + private static void EmitUnionTypeErrorHandling( + StringBuilder code, + in EndpointDescriptor ep, + in InvokerContext ctx) + { + code.AppendLine(" if (result.IsError)"); + code.AppendLine(" {"); + code.AppendLine( + $" if (result.Errors.Count is 0) return {ctx.WrapReturn($"{WellKnownTypes.Fqn.TypedResults.InternalServerError}(new {WellKnownTypes.Fqn.ProblemDetails} {{ Title = \"Error\", Detail = \"An error occurred but no details were provided.\", Status = 500 }})")};"); + code.AppendLine(" var first = result.Errors[0];"); + + EmitValidationHandling(code, in ep, in ctx); + EmitProblemDetailsBuilding(code); + EmitErrorTypeSwitch(code, in ep, in ctx); + + code.AppendLine(" }"); + code.AppendLine(); + + var successFactory = GetSuccessFactoryWithLocation(in ep, ctx.SuccessInfo); + code.AppendLine($" return {ctx.WrapReturn(successFactory)};"); + } + + private static string GetSuccessFactoryWithLocation(in EndpointDescriptor ep, SuccessResponseInfo successInfo) + { + // POST + Created(201) + body with Id property → emit Location header + if (ep.HttpVerb == HttpVerb.Post + && successInfo is { StatusCode: 201, HasBody: true } + && ep.LocationIdPropertyName is { Length: > 0 } idProp) + { + return + $"{WellKnownTypes.Fqn.TypedResults.Created}($\"{{ctx.Request.Path}}/{{result.Value.{idProp}}}\", result.Value)"; + } + + return successInfo.Factory; + } + + private static void EmitValidationHandling(StringBuilder code, in EndpointDescriptor ep, + in InvokerContext ctx) + { + var hasValidation = !ep.ErrorInference.InferredErrorTypeNames.IsDefaultOrEmpty && + ep.ErrorInference.InferredErrorTypeNames.AsImmutableArray().Contains(ErrorMapping.Validation); + + if (!hasValidation) return; + + code.AppendLine($" if (first.Type == {WellKnownTypes.Fqn.ErrorType}.Validation)"); + code.AppendLine(" {"); + BindingCodeEmitter.EmitValidationDictBuilder( + code, 20, "validationDict", "result.Errors", "e", + "e.Code", "e.Description", + $"e.Type != {WellKnownTypes.Fqn.ErrorType}.Validation"); + code.AppendLine( + $" return {ctx.WrapReturn($"{WellKnownTypes.Fqn.TypedResults.ValidationProblem}(validationDict)")};"); + code.AppendLine(" }"); + } + + private static void EmitProblemDetailsBuilding(StringBuilder code) + { + code.AppendLine($" var problem = new {WellKnownTypes.Fqn.ProblemDetails}"); + code.AppendLine(" {"); + code.AppendLine(" Title = first.Code,"); + code.AppendLine(" Detail = first.Description,"); + code.AppendLine( + $" Status = first.Type switch {{ {ErrorMapping.GenerateStatusSwitch(WellKnownTypes.Fqn.ErrorType)} }}"); + code.AppendLine(" };"); + code.AppendLine(" problem.Type = $\"https://httpstatuses.io/{problem.Status}\";"); + code.AppendLine(); + } + + private static void EmitErrorTypeSwitch(StringBuilder code, in EndpointDescriptor ep, + in InvokerContext ctx) + { + code.AppendLine(" switch (first.Type)"); + code.AppendLine(" {"); + + if (!ep.ErrorInference.InferredErrorTypeNames.IsDefaultOrEmpty) + { + foreach (var errorTypeName in ep.ErrorInference.InferredErrorTypeNames.AsImmutableArray() + .Where(static e => e != ErrorMapping.Validation) + .Distinct() + .OrderBy(static x => x, StringComparer.Ordinal)) + { + var factory = ErrorMapping.GetFactory(errorTypeName); + + code.AppendLine($" case {WellKnownTypes.Fqn.ErrorType}.{errorTypeName}:"); + code.AppendLine($" return {ctx.WrapReturn(factory)};"); + } + } + + code.AppendLine(" default:"); + code.AppendLine( + $" return {ctx.WrapReturn(ErrorMapping.GetFactory(ErrorMapping.Failure))};"); + code.AppendLine(" }"); + } +} diff --git a/src/ErrorOrX.Generators/Core/ErrorOrEndpointGenerator.Emitter.Invoker.cs b/src/ErrorOrX.Generators/Core/ErrorOrEndpointGenerator.Emitter.Invoker.cs new file mode 100644 index 0000000..3313ed9 --- /dev/null +++ b/src/ErrorOrX.Generators/Core/ErrorOrEndpointGenerator.Emitter.Invoker.cs @@ -0,0 +1,251 @@ +using ErrorOr.Generators.Emitters; + +namespace ErrorOr.Generators; + +/// +/// Per-endpoint invoker emission. Generates the two-method AOT-safe pattern: +/// +/// Invoke_Ep{N} — typed-return wrapper for OpenAPI visibility. +/// Invoke_Ep{N}_Core — body emission with binding, validation, dispatch. +/// +/// Also emits the bind-failure helpers (BindFail, BindFail415) and the +/// DataAnnotations validation block before the handler call. +/// +public sealed partial class ErrorOrEndpointGenerator +{ + private static void EmitInvoker(StringBuilder code, in EndpointDescriptor ep, int index, int maxArity) + { + var ctx = ComputeInvokerContext(in ep, index, maxArity); + var (bodyCode, usesBindFail) = EmitBodyCode(in ep, in ctx); + + EmitWrapperMethod(code, in ctx); + EmitCoreMethod(code, bodyCode, in ctx, usesBindFail); + } + + private static InvokerContext ComputeInvokerContext( + in EndpointDescriptor ep, + int index, + int maxArity) + { + var successInfo = ResultsUnionTypeBuilder.GetSuccessResponseInfo( + ep.SuccessTypeFqn, ep.SuccessKind, ep.IsAcceptedResponse); + + var hasFormBinding = ep.HasFormParams; + var hasBodyBinding = ep.HasBodyOrFormBinding; + + var unionResult = ResultsUnionTypeBuilder.ComputeReturnType( + ep.SuccessTypeFqn, ep.SuccessKind, + ep.ErrorInference.InferredErrorTypeNames, ep.ErrorInference.InferredCustomErrors, + ep.ErrorInference.DeclaredProducesErrors, hasBodyBinding, maxArity, + ep.IsAcceptedResponse, ep.Middleware, ep.HasParameterValidation); + + var needsAwait = ep.IsAsync || hasBodyBinding || ep.HasBindAsyncParam; + + return new InvokerContext(successInfo, unionResult, hasFormBinding, hasBodyBinding, needsAwait, index); + } + + private static (StringBuilder Code, bool UsesBindFail) EmitBodyCode( + in EndpointDescriptor ep, + in InvokerContext ctx) + { + var bodyCode = new StringBuilder(); + var usesBindFail = ctx.HasFormBinding; + + if (ctx.HasFormBinding) EmitFormContentTypeGuard(bodyCode); + + var args = new StringBuilder(); + var validationParams = new List<(int Index, string ParamName)>(); + for (var i = 0; i < ep.HandlerParameters.Length; i++) + { + var param = ep.HandlerParameters[i]; + usesBindFail |= BindingCodeEmitter.EmitParameterBinding(bodyCode, in param, $"p{i}", "BindFail"); + if (i > 0) args.Append(", "); + + args.Append(BindingCodeEmitter.BuildArgumentExpression(in param, $"p{i}")); + + if (param.RequiresValidation) validationParams.Add((i, $"p{i}")); + } + + if (validationParams.Count > 0) + EmitBclValidation(bodyCode, validationParams, ctx.UnionResult.ReturnTypeFqn, ctx.NeedsAwait); + + var awaitKeyword = ep.IsAsync ? "await " : ""; + bodyCode.AppendLine( + $" var result = {awaitKeyword}{ep.HandlerContainingTypeFqn}.{ep.HandlerMethodName}({args});"); + + EmitErrorHandling(bodyCode, in ep, in ctx); + + return (bodyCode, usesBindFail); + } + + private static void EmitErrorHandling( + StringBuilder bodyCode, + in EndpointDescriptor ep, + in InvokerContext ctx) + { + if (ep.Sse.IsSse) + { + bodyCode.AppendLine( + $" if (result.IsError) return {ctx.WrapReturn("ToProblem(result.Errors)")};"); + bodyCode.AppendLine( + $" return {ctx.WrapReturn($"{WellKnownTypes.Fqn.TypedResults.ServerSentEvents}(result.Value)")};"); + } + else if (ctx.UnionResult.CanUseUnion) + { + EmitUnionTypeErrorHandling(bodyCode, in ep, in ctx); + } + else + { + // Use minimal interface (IsError/Errors/Value) instead of convenience Match API + var successFactory = GetSuccessFactoryWithLocation(in ep, ctx.SuccessInfo); + bodyCode.AppendLine( + $" if (result.IsError) return {ctx.WrapReturn("ToProblem(result.Errors)")};"); + bodyCode.AppendLine($" return {ctx.WrapReturn(successFactory)};"); + } + } + + private static void EmitWrapperMethod(StringBuilder code, in InvokerContext ctx) + { + var returnType = ctx.UnionResult.ReturnTypeFqn; + code.AppendLine($" private static async Task<{returnType}> {ctx.WrapperName}(HttpContext ctx)"); + code.AppendLine(" {"); + code.AppendLine($" return await {ctx.CoreName}(ctx);"); + code.AppendLine(" }"); + code.AppendLine(); + } + + private static void EmitCoreMethod( + StringBuilder code, + StringBuilder bodyCode, + in InvokerContext ctx, + bool usesBindFail) + { + var returnType = ctx.UnionResult.ReturnTypeFqn; + code.AppendLine( + ctx.NeedsAwait + ? $" private static async Task<{returnType}> {ctx.CoreName}(HttpContext ctx)" + : $" private static Task<{returnType}> {ctx.CoreName}(HttpContext ctx)"); + + code.AppendLine(" {"); + + if (usesBindFail) + EmitBindFailHelper(code, returnType, ctx.NeedsAwait, ctx.UnionResult.UsesValidationProblemFor400); + + if (ctx.HasBodyBinding) EmitBindFail415Helper(code, returnType, ctx.NeedsAwait); + + code.Append(bodyCode); + code.AppendLine(" }"); + code.AppendLine(); + } + + /// + /// Emits BCL validation calls for parameters that have ValidationAttribute or implement IValidatableObject. + /// Uses System.ComponentModel.DataAnnotations.Validator.TryValidateObject for validation. + /// + private static void EmitBclValidation(StringBuilder code, List<(int Index, string ParamName)> validationParams, + string returnTypeFqn, bool isAsync) + { + code.AppendLine(); + code.AppendLine(" // BCL Validation"); + + foreach (var (_, paramName) in validationParams) + { + code.AppendLine( + $" var {paramName}ValidationResults = new {WellKnownTypes.Fqn.List}<{WellKnownTypes.Fqn.ValidationResult}>();"); + code.AppendLine( + $" if (!{WellKnownTypes.Fqn.Validator}.TryValidateObject({paramName}!, new {WellKnownTypes.Fqn.ValidationContext}({paramName}!), {paramName}ValidationResults, validateAllProperties: true))"); + code.AppendLine(" {"); + BindingCodeEmitter.EmitValidationDictBuilder( + code, 16, "validationDict", $"{paramName}ValidationResults", "vr", + "key", "vr.ErrorMessage ?? \"\"", + keyVarDecl: "var key = vr.MemberNames.FirstOrDefault() ?? \"\";"); + + var returnExpr = isAsync + ? $"{WellKnownTypes.Fqn.TypedResults.ValidationProblem}(validationDict)" + : $"Task.FromResult<{returnTypeFqn}>({WellKnownTypes.Fqn.TypedResults.ValidationProblem}(validationDict))"; + code.AppendLine($" return {returnExpr};"); + code.AppendLine(" }"); + } + + code.AppendLine(); + } + + private static void EmitBindFailHelper(StringBuilder code, string returnTypeFqn, bool isAsync, + bool useValidationProblem) + { + var returnType = isAsync ? returnTypeFqn : $"Task<{returnTypeFqn}>"; + + if (useValidationProblem) + { + // Use ValidationProblem to match the Results<..., ValidationProblem, ...> union type + const string validationProblemExpr = + $"{WellKnownTypes.Fqn.TypedResults.ValidationProblem}(new {WellKnownTypes.Fqn.Dictionary} {{ [param] = [reason] }})"; + var returnExpr = + isAsync ? validationProblemExpr : $"Task.FromResult<{returnTypeFqn}>({validationProblemExpr})"; + + code.AppendLine($" static {returnType} BindFail(string param, string reason)"); + code.AppendLine($" => {returnExpr};"); + code.AppendLine(); + } + else + { + // Use BadRequest to match the Results<..., BadRequest, ...> union type + code.AppendLine( + $" static {WellKnownTypes.Fqn.ProblemDetails} CreateBindProblem(string param, string reason) => new()"); + code.AppendLine(" {"); + code.AppendLine(" Title = \"Bad Request\","); + code.AppendLine(" Detail = $\"Parameter '{param}' {reason}.\","); + code.AppendLine(" Status = 400,"); + code.AppendLine($" Type = \"{WellKnownTypes.Constants.HttpStatusesBaseUrl}400\","); + code.AppendLine(" };"); + code.AppendLine(); + + const string badRequestExpr = + $"{WellKnownTypes.Fqn.TypedResults.BadRequest}(CreateBindProblem(param, reason))"; + var returnExpr = isAsync ? badRequestExpr : $"Task.FromResult<{returnTypeFqn}>({badRequestExpr})"; + + code.AppendLine($" static {returnType} BindFail(string param, string reason)"); + code.AppendLine($" => {returnExpr};"); + code.AppendLine(); + } + } + + private static void EmitBindFail415Helper(StringBuilder code, string returnTypeFqn, bool isAsync) + { + const string expr = $"{WellKnownTypes.Fqn.TypedResults.StatusCode}(415)"; + var returnExpr = isAsync ? expr : $"Task.FromResult<{returnTypeFqn}>({expr})"; + var returnType = isAsync ? returnTypeFqn : $"Task<{returnTypeFqn}>"; + + code.AppendLine($" static {returnType} BindFail415()"); + code.AppendLine($" => {returnExpr};"); + code.AppendLine(); + } + + private static void EmitFormContentTypeGuard(StringBuilder code) + { + code.AppendLine( + " if (!ctx.Request.HasFormContentType) return BindFail415();"); + code.AppendLine(" var form = await ctx.Request.ReadFormAsync(ctx.RequestAborted);"); + code.AppendLine(); + } + + /// + /// Context for invoker emission, holding precomputed values and providing helper methods. + /// + private readonly record struct InvokerContext( + SuccessResponseInfo SuccessInfo, + UnionTypeResult UnionResult, + bool HasFormBinding, + bool HasBodyBinding, + bool NeedsAwait, + int Index) + { + public string WrapperName => $"Invoke_Ep{Index}"; + public string CoreName => $"Invoke_Ep{Index}_Core"; + + public string WrapReturn(string expr) + { + return NeedsAwait ? expr : $"Task.FromResult<{UnionResult.ReturnTypeFqn}>({expr})"; + } + } +} diff --git a/src/ErrorOrX.Generators/Core/ErrorOrEndpointGenerator.Emitter.Versioning.cs b/src/ErrorOrX.Generators/Core/ErrorOrEndpointGenerator.Emitter.Versioning.cs new file mode 100644 index 0000000..136c6ae --- /dev/null +++ b/src/ErrorOrX.Generators/Core/ErrorOrEndpointGenerator.Emitter.Versioning.cs @@ -0,0 +1,94 @@ +using Microsoft.CodeAnalysis; + +namespace ErrorOr.Generators; + +/// +/// Emits Asp.Versioning.Http calls — global version set, per-endpoint version mapping, +/// and version-neutral marker. Triggered when at least one endpoint declares +/// [ApiVersion] or [ApiVersionNeutral]. +/// +public sealed partial class ErrorOrEndpointGenerator +{ + /// + /// Computes the global version set from all endpoints. + /// + private static VersionSetContext ComputeGlobalVersionSet(ImmutableArray endpoints) + { + var hasVersionNeutral = endpoints.Any(static ep => ep.Versioning.IsVersionNeutral); + + var sortedVersions = endpoints + .SelectMany(static ep => ep.Versioning.SupportedVersions.AsImmutableArray()) + .Distinct() + .OrderBy(static v => v.MajorVersion) + .ThenBy(static v => v.MinorVersion ?? 0) + .ToImmutableArray(); + + return new VersionSetContext(sortedVersions, hasVersionNeutral); + } + + /// + /// Emits the version set builder before endpoint mappings. + /// + private static void EmitVersionSet(StringBuilder code, VersionSetContext versionSet) + { + code.AppendLine(" // API Versioning: Build version set for all endpoints"); + code.AppendLine(" var versionSet = app.NewApiVersionSet()"); + + foreach (var v in versionSet.AllVersions) + { + var versionExpr = v.MinorVersion.HasValue + ? $"new {WellKnownTypes.Fqn.ApiVersion}({v.MajorVersion}, {v.MinorVersion.Value})" + : $"new {WellKnownTypes.Fqn.ApiVersion}({v.MajorVersion})"; + + code.AppendLine(v.IsDeprecated + ? $" .HasDeprecatedApiVersion({versionExpr})" + : $" .HasApiVersion({versionExpr})"); + } + + code.AppendLine(" .ReportApiVersions()"); + code.AppendLine(" .Build();"); + code.AppendLine(); + } + + /// + /// Emits API versioning fluent calls for an endpoint. + /// + private static void EmitVersioningCalls(StringBuilder code, in VersioningInfo versioning, bool hasGlobalVersionSet) + { + // If no global version set exists, don't emit anything + if (!hasGlobalVersionSet) return; + + // Version-neutral endpoints don't map to any specific version + if (versioning.IsVersionNeutral) + { + code.AppendLine(" .IsApiVersionNeutral()"); + return; + } + + // Apply the version set to the endpoint + code.AppendLine(" .WithApiVersionSet(versionSet)"); + + // If endpoint has specific versions to map to, emit MapToApiVersion calls + var effectiveVersions = versioning.EffectiveVersions; + if (!effectiveVersions.IsDefaultOrEmpty) + { + foreach (var v in effectiveVersions.AsImmutableArray()) + { + var versionExpr = v.MinorVersion.HasValue + ? $"new {WellKnownTypes.Fqn.ApiVersion}({v.MajorVersion}, {v.MinorVersion.Value})" + : $"new {WellKnownTypes.Fqn.ApiVersion}({v.MajorVersion})"; + code.AppendLine($" .MapToApiVersion({versionExpr})"); + } + } + } + + /// + /// Aggregated version set information with hasVersioning flag. + /// + private readonly record struct VersionSetContext( + ImmutableArray AllVersions, + bool HasVersionNeutralEndpoints) + { + public bool HasVersioning => !AllVersions.IsDefaultOrEmpty || HasVersionNeutralEndpoints; + } +} diff --git a/src/ErrorOrX.Generators/Core/ErrorOrEndpointGenerator.Emitter.cs b/src/ErrorOrX.Generators/Core/ErrorOrEndpointGenerator.Emitter.cs index 6f57cd3..0dd0438 100644 --- a/src/ErrorOrX.Generators/Core/ErrorOrEndpointGenerator.Emitter.cs +++ b/src/ErrorOrX.Generators/Core/ErrorOrEndpointGenerator.Emitter.cs @@ -6,7 +6,14 @@ namespace ErrorOr.Generators; /// -/// Code emission logic for endpoint mappings. +/// Code emission entry point for endpoint mappings. +/// Orchestrates the per-emission-phase partials: +/// +/// Emitter.Versioning.cs — API version set + per-endpoint versioning calls. +/// Emitter.Invoker.cs — Per-endpoint wrapper/core method emission and bind helpers. +/// Emitter.ErrorHandling.cs — Error-to-ProblemDetails / Results-union dispatch. +/// Emitter.JsonContext.cs, Emitter.Options.cs, Emitter.Support.cs — Existing partials. +/// /// public sealed partial class ErrorOrEndpointGenerator { @@ -180,417 +187,4 @@ private static void EmitMapCall(StringBuilder code, in EndpointDescriptor ep, in MapCallEmitter.EmitMapCallEnd(code, index, " "); code.AppendLine(); } - - /// - /// Computes the global version set from all endpoints. - /// - private static VersionSetContext ComputeGlobalVersionSet(ImmutableArray endpoints) - { - var hasVersionNeutral = endpoints.Any(static ep => ep.Versioning.IsVersionNeutral); - - var sortedVersions = endpoints - .SelectMany(static ep => ep.Versioning.SupportedVersions.AsImmutableArray()) - .Distinct() - .OrderBy(static v => v.MajorVersion) - .ThenBy(static v => v.MinorVersion ?? 0) - .ToImmutableArray(); - - return new VersionSetContext(sortedVersions, hasVersionNeutral); - } - - /// - /// Emits the version set builder before endpoint mappings. - /// - private static void EmitVersionSet(StringBuilder code, VersionSetContext versionSet) - { - code.AppendLine(" // API Versioning: Build version set for all endpoints"); - code.AppendLine(" var versionSet = app.NewApiVersionSet()"); - - foreach (var v in versionSet.AllVersions) - { - var versionExpr = v.MinorVersion.HasValue - ? $"new {WellKnownTypes.Fqn.ApiVersion}({v.MajorVersion}, {v.MinorVersion.Value})" - : $"new {WellKnownTypes.Fqn.ApiVersion}({v.MajorVersion})"; - - code.AppendLine(v.IsDeprecated - ? $" .HasDeprecatedApiVersion({versionExpr})" - : $" .HasApiVersion({versionExpr})"); - } - - code.AppendLine(" .ReportApiVersions()"); - code.AppendLine(" .Build();"); - code.AppendLine(); - } - - /// - /// Emits API versioning fluent calls for an endpoint. - /// - private static void EmitVersioningCalls(StringBuilder code, in VersioningInfo versioning, bool hasGlobalVersionSet) - { - // If no global version set exists, don't emit anything - if (!hasGlobalVersionSet) return; - - // Version-neutral endpoints don't map to any specific version - if (versioning.IsVersionNeutral) - { - code.AppendLine(" .IsApiVersionNeutral()"); - return; - } - - // Apply the version set to the endpoint - code.AppendLine(" .WithApiVersionSet(versionSet)"); - - // If endpoint has specific versions to map to, emit MapToApiVersion calls - var effectiveVersions = versioning.EffectiveVersions; - if (!effectiveVersions.IsDefaultOrEmpty) - { - foreach (var v in effectiveVersions.AsImmutableArray()) - { - var versionExpr = v.MinorVersion.HasValue - ? $"new {WellKnownTypes.Fqn.ApiVersion}({v.MajorVersion}, {v.MinorVersion.Value})" - : $"new {WellKnownTypes.Fqn.ApiVersion}({v.MajorVersion})"; - code.AppendLine($" .MapToApiVersion({versionExpr})"); - } - } - } - - private static InvokerContext ComputeInvokerContext( - in EndpointDescriptor ep, - int index, - int maxArity) - { - var successInfo = ResultsUnionTypeBuilder.GetSuccessResponseInfo( - ep.SuccessTypeFqn, ep.SuccessKind, ep.IsAcceptedResponse); - - var hasFormBinding = ep.HasFormParams; - var hasBodyBinding = ep.HasBodyOrFormBinding; - - var unionResult = ResultsUnionTypeBuilder.ComputeReturnType( - ep.SuccessTypeFqn, ep.SuccessKind, - ep.ErrorInference.InferredErrorTypeNames, ep.ErrorInference.InferredCustomErrors, - ep.ErrorInference.DeclaredProducesErrors, hasBodyBinding, maxArity, - ep.IsAcceptedResponse, ep.Middleware, ep.HasParameterValidation); - - var needsAwait = ep.IsAsync || hasBodyBinding || ep.HasBindAsyncParam; - - return new InvokerContext(successInfo, unionResult, hasFormBinding, hasBodyBinding, needsAwait, index); - } - - private static (StringBuilder Code, bool UsesBindFail) EmitBodyCode( - in EndpointDescriptor ep, - in InvokerContext ctx) - { - var bodyCode = new StringBuilder(); - var usesBindFail = ctx.HasFormBinding; - - if (ctx.HasFormBinding) EmitFormContentTypeGuard(bodyCode); - - var args = new StringBuilder(); - var validationParams = new List<(int Index, string ParamName)>(); - for (var i = 0; i < ep.HandlerParameters.Length; i++) - { - var param = ep.HandlerParameters[i]; - usesBindFail |= BindingCodeEmitter.EmitParameterBinding(bodyCode, in param, $"p{i}", "BindFail"); - if (i > 0) args.Append(", "); - - args.Append(BindingCodeEmitter.BuildArgumentExpression(in param, $"p{i}")); - - if (param.RequiresValidation) validationParams.Add((i, $"p{i}")); - } - - if (validationParams.Count > 0) - EmitBclValidation(bodyCode, validationParams, ctx.UnionResult.ReturnTypeFqn, ctx.NeedsAwait); - - var awaitKeyword = ep.IsAsync ? "await " : ""; - bodyCode.AppendLine( - $" var result = {awaitKeyword}{ep.HandlerContainingTypeFqn}.{ep.HandlerMethodName}({args});"); - - EmitErrorHandling(bodyCode, in ep, in ctx); - - return (bodyCode, usesBindFail); - } - - private static void EmitErrorHandling( - StringBuilder bodyCode, - in EndpointDescriptor ep, - in InvokerContext ctx) - { - if (ep.Sse.IsSse) - { - bodyCode.AppendLine( - $" if (result.IsError) return {ctx.WrapReturn("ToProblem(result.Errors)")};"); - bodyCode.AppendLine( - $" return {ctx.WrapReturn($"{WellKnownTypes.Fqn.TypedResults.ServerSentEvents}(result.Value)")};"); - } - else if (ctx.UnionResult.CanUseUnion) - { - EmitUnionTypeErrorHandling(bodyCode, in ep, in ctx); - } - else - { - // Use minimal interface (IsError/Errors/Value) instead of convenience Match API - var successFactory = GetSuccessFactoryWithLocation(in ep, ctx.SuccessInfo); - bodyCode.AppendLine( - $" if (result.IsError) return {ctx.WrapReturn("ToProblem(result.Errors)")};"); - bodyCode.AppendLine($" return {ctx.WrapReturn(successFactory)};"); - } - } - - private static void EmitWrapperMethod(StringBuilder code, in InvokerContext ctx) - { - var returnType = ctx.UnionResult.ReturnTypeFqn; - code.AppendLine($" private static async Task<{returnType}> {ctx.WrapperName}(HttpContext ctx)"); - code.AppendLine(" {"); - code.AppendLine($" return await {ctx.CoreName}(ctx);"); - code.AppendLine(" }"); - code.AppendLine(); - } - - private static void EmitCoreMethod( - StringBuilder code, - StringBuilder bodyCode, - in InvokerContext ctx, - bool usesBindFail) - { - var returnType = ctx.UnionResult.ReturnTypeFqn; - code.AppendLine( - ctx.NeedsAwait - ? $" private static async Task<{returnType}> {ctx.CoreName}(HttpContext ctx)" - : $" private static Task<{returnType}> {ctx.CoreName}(HttpContext ctx)"); - - code.AppendLine(" {"); - - if (usesBindFail) - EmitBindFailHelper(code, returnType, ctx.NeedsAwait, ctx.UnionResult.UsesValidationProblemFor400); - - if (ctx.HasBodyBinding) EmitBindFail415Helper(code, returnType, ctx.NeedsAwait); - - code.Append(bodyCode); - code.AppendLine(" }"); - code.AppendLine(); - } - - private static void EmitInvoker(StringBuilder code, in EndpointDescriptor ep, int index, int maxArity) - { - var ctx = ComputeInvokerContext(in ep, index, maxArity); - var (bodyCode, usesBindFail) = EmitBodyCode(in ep, in ctx); - - EmitWrapperMethod(code, in ctx); - EmitCoreMethod(code, bodyCode, in ctx, usesBindFail); - } - - /// - /// Emits BCL validation calls for parameters that have ValidationAttribute or implement IValidatableObject. - /// Uses System.ComponentModel.DataAnnotations.Validator.TryValidateObject for validation. - /// - private static void EmitBclValidation(StringBuilder code, List<(int Index, string ParamName)> validationParams, - string returnTypeFqn, bool isAsync) - { - code.AppendLine(); - code.AppendLine(" // BCL Validation"); - - foreach (var (_, paramName) in validationParams) - { - code.AppendLine( - $" var {paramName}ValidationResults = new {WellKnownTypes.Fqn.List}<{WellKnownTypes.Fqn.ValidationResult}>();"); - code.AppendLine( - $" if (!{WellKnownTypes.Fqn.Validator}.TryValidateObject({paramName}!, new {WellKnownTypes.Fqn.ValidationContext}({paramName}!), {paramName}ValidationResults, validateAllProperties: true))"); - code.AppendLine(" {"); - BindingCodeEmitter.EmitValidationDictBuilder( - code, 16, "validationDict", $"{paramName}ValidationResults", "vr", - "key", "vr.ErrorMessage ?? \"\"", - keyVarDecl: "var key = vr.MemberNames.FirstOrDefault() ?? \"\";"); - - var returnExpr = isAsync - ? $"{WellKnownTypes.Fqn.TypedResults.ValidationProblem}(validationDict)" - : $"Task.FromResult<{returnTypeFqn}>({WellKnownTypes.Fqn.TypedResults.ValidationProblem}(validationDict))"; - code.AppendLine($" return {returnExpr};"); - code.AppendLine(" }"); - } - - code.AppendLine(); - } - - private static void EmitBindFailHelper(StringBuilder code, string returnTypeFqn, bool isAsync, - bool useValidationProblem) - { - var returnType = isAsync ? returnTypeFqn : $"Task<{returnTypeFqn}>"; - - if (useValidationProblem) - { - // Use ValidationProblem to match the Results<..., ValidationProblem, ...> union type - const string validationProblemExpr = - $"{WellKnownTypes.Fqn.TypedResults.ValidationProblem}(new {WellKnownTypes.Fqn.Dictionary} {{ [param] = [reason] }})"; - var returnExpr = - isAsync ? validationProblemExpr : $"Task.FromResult<{returnTypeFqn}>({validationProblemExpr})"; - - code.AppendLine($" static {returnType} BindFail(string param, string reason)"); - code.AppendLine($" => {returnExpr};"); - code.AppendLine(); - } - else - { - // Use BadRequest to match the Results<..., BadRequest, ...> union type - code.AppendLine( - $" static {WellKnownTypes.Fqn.ProblemDetails} CreateBindProblem(string param, string reason) => new()"); - code.AppendLine(" {"); - code.AppendLine(" Title = \"Bad Request\","); - code.AppendLine(" Detail = $\"Parameter '{param}' {reason}.\","); - code.AppendLine(" Status = 400,"); - code.AppendLine($" Type = \"{WellKnownTypes.Constants.HttpStatusesBaseUrl}400\","); - code.AppendLine(" };"); - code.AppendLine(); - - const string badRequestExpr = - $"{WellKnownTypes.Fqn.TypedResults.BadRequest}(CreateBindProblem(param, reason))"; - var returnExpr = isAsync ? badRequestExpr : $"Task.FromResult<{returnTypeFqn}>({badRequestExpr})"; - - code.AppendLine($" static {returnType} BindFail(string param, string reason)"); - code.AppendLine($" => {returnExpr};"); - code.AppendLine(); - } - } - - private static void EmitBindFail415Helper(StringBuilder code, string returnTypeFqn, bool isAsync) - { - const string expr = $"{WellKnownTypes.Fqn.TypedResults.StatusCode}(415)"; - var returnExpr = isAsync ? expr : $"Task.FromResult<{returnTypeFqn}>({expr})"; - var returnType = isAsync ? returnTypeFqn : $"Task<{returnTypeFqn}>"; - - code.AppendLine($" static {returnType} BindFail415()"); - code.AppendLine($" => {returnExpr};"); - code.AppendLine(); - } - - private static void EmitUnionTypeErrorHandling( - StringBuilder code, - in EndpointDescriptor ep, - in InvokerContext ctx) - { - code.AppendLine(" if (result.IsError)"); - code.AppendLine(" {"); - code.AppendLine( - $" if (result.Errors.Count is 0) return {ctx.WrapReturn($"{WellKnownTypes.Fqn.TypedResults.InternalServerError}(new {WellKnownTypes.Fqn.ProblemDetails} {{ Title = \"Error\", Detail = \"An error occurred but no details were provided.\", Status = 500 }})")};"); - code.AppendLine(" var first = result.Errors[0];"); - - EmitValidationHandling(code, in ep, in ctx); - EmitProblemDetailsBuilding(code); - EmitErrorTypeSwitch(code, in ep, in ctx); - - code.AppendLine(" }"); - code.AppendLine(); - - var successFactory = GetSuccessFactoryWithLocation(in ep, ctx.SuccessInfo); - code.AppendLine($" return {ctx.WrapReturn(successFactory)};"); - } - - private static string GetSuccessFactoryWithLocation(in EndpointDescriptor ep, SuccessResponseInfo successInfo) - { - // POST + Created(201) + body with Id property → emit Location header - if (ep.HttpVerb == HttpVerb.Post - && successInfo is { StatusCode: 201, HasBody: true } - && ep.LocationIdPropertyName is { Length: > 0 } idProp) - { - return - $"{WellKnownTypes.Fqn.TypedResults.Created}($\"{{ctx.Request.Path}}/{{result.Value.{idProp}}}\", result.Value)"; - } - - return successInfo.Factory; - } - - private static void EmitValidationHandling(StringBuilder code, in EndpointDescriptor ep, - in InvokerContext ctx) - { - var hasValidation = !ep.ErrorInference.InferredErrorTypeNames.IsDefaultOrEmpty && - ep.ErrorInference.InferredErrorTypeNames.AsImmutableArray().Contains(ErrorMapping.Validation); - - if (!hasValidation) return; - - code.AppendLine($" if (first.Type == {WellKnownTypes.Fqn.ErrorType}.Validation)"); - code.AppendLine(" {"); - BindingCodeEmitter.EmitValidationDictBuilder( - code, 20, "validationDict", "result.Errors", "e", - "e.Code", "e.Description", - $"e.Type != {WellKnownTypes.Fqn.ErrorType}.Validation"); - code.AppendLine( - $" return {ctx.WrapReturn($"{WellKnownTypes.Fqn.TypedResults.ValidationProblem}(validationDict)")};"); - code.AppendLine(" }"); - } - - private static void EmitProblemDetailsBuilding(StringBuilder code) - { - code.AppendLine($" var problem = new {WellKnownTypes.Fqn.ProblemDetails}"); - code.AppendLine(" {"); - code.AppendLine(" Title = first.Code,"); - code.AppendLine(" Detail = first.Description,"); - code.AppendLine( - $" Status = first.Type switch {{ {ErrorMapping.GenerateStatusSwitch(WellKnownTypes.Fqn.ErrorType)} }}"); - code.AppendLine(" };"); - code.AppendLine(" problem.Type = $\"https://httpstatuses.io/{problem.Status}\";"); - code.AppendLine(); - } - - private static void EmitErrorTypeSwitch(StringBuilder code, in EndpointDescriptor ep, - in InvokerContext ctx) - { - code.AppendLine(" switch (first.Type)"); - code.AppendLine(" {"); - - if (!ep.ErrorInference.InferredErrorTypeNames.IsDefaultOrEmpty) - { - foreach (var errorTypeName in ep.ErrorInference.InferredErrorTypeNames.AsImmutableArray() - .Where(static e => e != ErrorMapping.Validation) - .Distinct() - .OrderBy(static x => x, StringComparer.Ordinal)) - { - var factory = ErrorMapping.GetFactory(errorTypeName); - - code.AppendLine($" case {WellKnownTypes.Fqn.ErrorType}.{errorTypeName}:"); - code.AppendLine($" return {ctx.WrapReturn(factory)};"); - } - } - - code.AppendLine(" default:"); - code.AppendLine( - $" return {ctx.WrapReturn(ErrorMapping.GetFactory(ErrorMapping.Failure))};"); - code.AppendLine(" }"); - } - - private static void EmitFormContentTypeGuard(StringBuilder code) - { - code.AppendLine( - " if (!ctx.Request.HasFormContentType) return BindFail415();"); - code.AppendLine(" var form = await ctx.Request.ReadFormAsync(ctx.RequestAborted);"); - code.AppendLine(); - } - - /// - /// Aggregated version set information with hasVersioning flag. - /// - private readonly record struct VersionSetContext( - ImmutableArray AllVersions, - bool HasVersionNeutralEndpoints) - { - public bool HasVersioning => !AllVersions.IsDefaultOrEmpty || HasVersionNeutralEndpoints; - } - - /// - /// Context for invoker emission, holding precomputed values and providing helper methods. - /// - private readonly record struct InvokerContext( - SuccessResponseInfo SuccessInfo, - UnionTypeResult UnionResult, - bool HasFormBinding, - bool HasBodyBinding, - bool NeedsAwait, - int Index) - { - public string WrapperName => $"Invoke_Ep{Index}"; - public string CoreName => $"Invoke_Ep{Index}_Core"; - - public string WrapReturn(string expr) - { - return NeedsAwait ? expr : $"Task.FromResult<{UnionResult.ReturnTypeFqn}>({expr})"; - } - } } diff --git a/src/ErrorOrX.Generators/Core/ErrorOrEndpointGenerator.Extractor.ErrorInference.cs b/src/ErrorOrX.Generators/Core/ErrorOrEndpointGenerator.Extractor.ErrorInference.cs new file mode 100644 index 0000000..635382f --- /dev/null +++ b/src/ErrorOrX.Generators/Core/ErrorOrEndpointGenerator.Extractor.ErrorInference.cs @@ -0,0 +1,274 @@ +using ANcpLua.Roslyn.Utilities.Models; +using ErrorOr.Analyzers; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; + +namespace ErrorOr.Generators; + +/// +/// Body-walking error inference: descends through a handler's syntax tree +/// (and through called same-assembly symbols) to collect every Error.X() and +/// Error.Custom("code", ...) factory invocation. Drives the union-type computation +/// and the [ProducesError] documentation diagnostics. +/// +public sealed partial class ErrorOrEndpointGenerator +{ + private static (EquatableArray ErrorTypeNames, EquatableArray CustomErrors) + InferErrorTypesFromMethod( + GeneratorAttributeSyntaxContext ctx, + ISymbol method, + ErrorOrContext context, + ImmutableArray.Builder diagnostics, + bool hasExplicitProducesError) + { + if (GetMethodBody(method) is not { } body) return (default, default); + + var methodName = method.Name; + var (errorTypeNames, customErrors) = CollectErrorTypes(ctx.SemanticModel, body, context, diagnostics, + methodName, + hasExplicitProducesError); + return (ToSortedErrorArray(errorTypeNames), new EquatableArray([.. customErrors])); + } + + private static SyntaxNode? GetMethodBody(ISymbol method) + { + var refs = method.DeclaringSyntaxReferences; + if (refs.IsDefaultOrEmpty || refs.Length is 0) return null; + + var syntax = refs[0].GetSyntax(); + return syntax switch + { + MethodDeclarationSyntax m => (SyntaxNode?)m.Body ?? m.ExpressionBody, + LocalFunctionStatementSyntax f => (SyntaxNode?)f.Body ?? f.ExpressionBody, + _ => null + }; + } + + private static (HashSet ErrorTypeNames, List CustomErrors) CollectErrorTypes( + SemanticModel semanticModel, + SyntaxNode body, + ErrorOrContext context, + ImmutableArray.Builder diagnostics, + string endpointMethodName, + bool hasExplicitProducesError) + { + var set = new HashSet(StringComparer.Ordinal); + var customErrors = new List(); + var visitedSymbols = new HashSet(SymbolEqualityComparer.Default); + var seenCustomCodes = new HashSet(StringComparer.Ordinal); + CollectErrorTypesRecursive(semanticModel, body, set, customErrors, visitedSymbols, seenCustomCodes, + context, diagnostics, endpointMethodName, hasExplicitProducesError); + return (set, customErrors); + } + + private static void CollectErrorTypesRecursive( + SemanticModel semanticModel, + SyntaxNode node, + ISet errorTypeNames, + ICollection customErrors, + ISet visitedSymbols, + ISet seenCustomCodes, + ErrorOrContext context, + ImmutableArray.Builder diagnostics, + string endpointMethodName, + bool hasExplicitProducesError) + { + foreach (var child in node.DescendantNodes()) + { + ProcessNode(semanticModel, child, errorTypeNames, customErrors, visitedSymbols, seenCustomCodes, context, + diagnostics, endpointMethodName, hasExplicitProducesError); + } + } + + private static void ProcessNode( + SemanticModel semanticModel, + SyntaxNode child, + ISet errorTypeNames, + ICollection customErrors, + ISet visitedSymbols, + ISet seenCustomCodes, + ErrorOrContext context, + ImmutableArray.Builder diagnostics, + string endpointMethodName, + bool hasExplicitProducesError) + { + if (TryHandleErrorFactoryInvocation( + semanticModel, + child, + errorTypeNames, + customErrors, + seenCustomCodes, + context, + diagnostics)) + { + return; + } + + // Check for interface/abstract method calls that return ErrorOr + if (TryDetectUndocumentedInterfaceCall( + semanticModel, + child, + context, + endpointMethodName, + hasExplicitProducesError, + diagnostics, + errorTypeNames, + customErrors, + seenCustomCodes)) + { + return; + } + + if (!TryGetReferencedSymbol(semanticModel, child, visitedSymbols, out var symbol)) return; + + foreach (var reference in symbol.DeclaringSyntaxReferences) + { + var bodyToScan = GetBodyToScan(reference.GetSyntax()); + if (bodyToScan is not null) + { + CollectErrorTypesRecursive(semanticModel, bodyToScan, errorTypeNames, customErrors, + visitedSymbols, seenCustomCodes, context, diagnostics, endpointMethodName, + hasExplicitProducesError); + } + } + } + + private static bool TryHandleErrorFactoryInvocation( + SemanticModel semanticModel, + SyntaxNode node, + ISet errorTypeNames, + ICollection customErrors, + ISet seenCustomCodes, + ErrorOrContext context, + ImmutableArray.Builder diagnostics) + { + if (!IsErrorFactoryInvocation(semanticModel, node, context, out var factoryName, out var invocation)) + return false; + + // Validate and return the factory name if it's a known ErrorType + if (ErrorMapping.IsKnownErrorType(factoryName)) + { + errorTypeNames.Add(factoryName); + return true; + } + + if (factoryName == "Custom" && invocation is not null) + { + var customInfo = ExtractCustomErrorInfo(semanticModel, invocation); + if (customInfo is { } info && seenCustomCodes.Add(info.ErrorCode)) customErrors.Add(info); + + return true; + } + + // Unknown factory method - report diagnostic + // This fails loud instead of silently ignoring it or falling back to a default + diagnostics.Add(DiagnosticInfo.Create( + Descriptors.UnknownErrorFactory, + node.GetLocation(), + factoryName)); + + return true; + } + + private static bool TryGetReferencedSymbol( + SemanticModel semanticModel, + SyntaxNode node, + ISet visitedSymbols, + [NotNullWhen(true)] out ISymbol? symbol) + { + // Conditional assignment: only resolve symbol for relevant syntax nodes + symbol = node is IdentifierNameSyntax or MemberAccessExpressionSyntax + ? semanticModel.GetSymbolInfo(node).Symbol + : null; + + // Chained guards with short-circuit evaluation: + // 1. Type check (also handles null) + // 2. Same-assembly check (avoid external symbols) + // - ILocalSymbol has no ContainingAssembly but is always in scope (local to current method) + // 3. Add to visited (side-effect only if we'll use it, returns false if duplicate) + return symbol is IPropertySymbol or IFieldSymbol or ILocalSymbol or IMethodSymbol && + (symbol is ILocalSymbol || + symbol.ContainingAssembly?.IsEqualTo(semanticModel.Compilation.Assembly) == true) && + visitedSymbols.Add(symbol); + } + + private static SyntaxNode? GetBodyToScan(SyntaxNode syntax) + { + return syntax switch + { + PropertyDeclarationSyntax p => (SyntaxNode?)p.ExpressionBody ?? p.AccessorList, + MethodDeclarationSyntax m => (SyntaxNode?)m.Body ?? m.ExpressionBody, + VariableDeclaratorSyntax v => v.Initializer, + _ => syntax + }; + } + + private static CustomErrorInfo? ExtractCustomErrorInfo(SemanticModel semanticModel, + InvocationExpressionSyntax invocation) + { + // Error.Custom(int type, string code, string description, Dictionary? metadata = null) + // The 'code' parameter (second arg) is what we want for deduplication + var args = invocation.ArgumentList.Arguments; + if (args.Count < 2) return null; + + // Try to extract the 'code' (second argument) + var codeArg = args[1].Expression; + string? errorCode = null; + + // Try constant folding + var constantValue = semanticModel.GetConstantValue(codeArg); + if (constantValue is { HasValue: true, Value: string codeStr }) + errorCode = codeStr; + else if (codeArg is LiteralExpressionSyntax { Token.Value: string literalStr }) errorCode = literalStr; + + // Pattern matching establishes non-null for compiler + if (errorCode is not { Length: > 0 } code) return null; + + return new CustomErrorInfo(code); + } + + private static bool IsErrorFactoryInvocation( + SemanticModel semanticModel, + SyntaxNode node, + ErrorOrContext context, + out string factoryName, + out InvocationExpressionSyntax? invocation) + { + factoryName = string.Empty; + invocation = null; + + if (node is not InvocationExpressionSyntax inv) return false; + + invocation = inv; + + // Fast-path: Error.X(...) where Error is a simple identifier + if (inv.Expression is MemberAccessExpressionSyntax + { + Expression: IdentifierNameSyntax { Identifier.Text: "Error" }, + Name: IdentifierNameSyntax { Identifier.Text: var name } + }) + { + factoryName = name; + return true; + } + + // Semantic fallback: resolve invoked method and ensure it's actually ErrorOr.Error. + if (semanticModel.GetSymbolInfo(inv).Symbol is not IMethodSymbol symbol || + !ErrorOrContext.MatchesType(symbol.ContainingType, WellKnownTypes.ErrorStruct)) + { + return false; + } + + factoryName = symbol.Name; + return true; + } + + private static EquatableArray ToSortedErrorArray(HashSet set) + { + if (set.Count is 0) return default; + + var array = set.ToArray(); + Array.Sort(array, StringComparer.Ordinal); + return new EquatableArray([.. array]); + } +} diff --git a/src/ErrorOrX.Generators/Core/ErrorOrEndpointGenerator.Extractor.InterfaceDetection.cs b/src/ErrorOrX.Generators/Core/ErrorOrEndpointGenerator.Extractor.InterfaceDetection.cs new file mode 100644 index 0000000..078cf8b --- /dev/null +++ b/src/ErrorOrX.Generators/Core/ErrorOrEndpointGenerator.Extractor.InterfaceDetection.cs @@ -0,0 +1,138 @@ +using ANcpLua.Roslyn.Utilities.Models; +using ErrorOr.Analyzers; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; + +namespace ErrorOr.Generators; + +/// +/// Detects calls to interface or abstract methods that return ErrorOr<T>. +/// If the called declaration carries [ReturnsError], the errors flow into the +/// union-type computation. Otherwise, if the endpoint itself isn't documented with +/// [ProducesError], the EOE024 diagnostic fires (fail-loud rather than +/// silently producing a 500-only Results union). +/// +public sealed partial class ErrorOrEndpointGenerator +{ + /// + /// Detects calls to interface/abstract methods returning ErrorOr. + /// If the interface method has [ReturnsError] attributes, extract them. + /// If not, and endpoint has no [ProducesError], emit ERROR (FAIL LOUD). + /// + private static bool TryDetectUndocumentedInterfaceCall( + SemanticModel semanticModel, + SyntaxNode node, + ErrorOrContext context, + string endpointMethodName, + bool hasExplicitProducesError, + ImmutableArray.Builder diagnostics, + ISet errorTypeNames, + ICollection customErrors, + ISet seenCustomCodes) + { + // Only check invocation expressions + if (node is not InvocationExpressionSyntax invocation) return false; + + var symbolInfo = semanticModel.GetSymbolInfo(invocation); + if (symbolInfo.Symbol is not IMethodSymbol methodSymbol) return false; + + // Check if method returns ErrorOr + if (!ReturnsErrorOr(methodSymbol, context)) return false; + + // Check if it's an interface or abstract method (no implementation to scan) + var containingType = methodSymbol.ContainingType; + var isInterfaceOrAbstract = containingType?.TypeKind == TypeKind.Interface || + methodSymbol.IsAbstract || + methodSymbol.IsVirtual; + + if (!isInterfaceOrAbstract) return false; + + // Try to extract [ReturnsError] attributes from the interface method + var hasReturnsError = TryExtractReturnsErrorAttributes( + methodSymbol, context, errorTypeNames, customErrors, seenCustomCodes); + + if (hasReturnsError) return true; // Successfully extracted errors from interface + + // If endpoint already has [ProducesError] attributes, assume developer knows what they're doing + if (hasExplicitProducesError) return true; // No error, endpoint is explicitly documented + + // FAIL LOUD: Interface call without documentation + var methodDisplayName = $"{containingType?.Name ?? "?"}.{methodSymbol.Name}"; + diagnostics.Add(DiagnosticInfo.Create( + Descriptors.UndocumentedInterfaceCall, + node.GetLocation(), + endpointMethodName, + methodDisplayName)); + + return true; + } + + /// + /// Extracts [ReturnsError] attributes from an interface/abstract method. + /// Returns true if any [ReturnsError] attributes were found. + /// + private static bool TryExtractReturnsErrorAttributes( + ISymbol method, + ErrorOrContext context, + ISet errorTypeNames, + ICollection customErrors, + ISet seenCustomCodes) + { + var foundAny = false; + + foreach (var attr in method.GetAttributes()) + { + if (!ErrorOrContext.MatchesType(attr.AttributeClass, WellKnownTypes.ReturnsErrorAttribute)) continue; + + var args = attr.ConstructorArguments; + if (args.Length < 2) continue; + + // Distinguish constructors by the first argument's type: + // 1. ReturnsErrorAttribute(ErrorType errorType, string errorCode) — args[0].Type is enum + // 2. ReturnsErrorAttribute(int statusCode, string errorCode) — args[0].Type is int + if (args[0].Value is not int intValue || args[1].Value is not string errorCode) continue; + + if (args[0].Type is INamedTypeSymbol { TypeKind: TypeKind.Enum }) + { + // Standard ErrorType — map enum int value to string name + var errorTypeName = MapEnumValueToName(intValue); + if (errorTypeName is not null) errorTypeNames.Add(errorTypeName); + } + else + { + // Custom error with explicit HTTP status code + if (seenCustomCodes.Add(errorCode)) customErrors.Add(new CustomErrorInfo(errorCode)); + } + + foundAny = true; + } + + return foundAny; + } + + /// + /// Maps runtime ErrorType enum integer value to its name. + /// The enum values are: Failure=0, Unexpected=1, Validation=2, Conflict=3, NotFound=4, Unauthorized=5, Forbidden=6 + /// + private static string? MapEnumValueToName(int enumValue) + { + return enumValue switch + { + 0 => ErrorMapping.Failure, + 1 => ErrorMapping.Unexpected, + 2 => ErrorMapping.Validation, + 3 => ErrorMapping.Conflict, + 4 => ErrorMapping.NotFound, + 5 => ErrorMapping.Unauthorized, + 6 => ErrorMapping.Forbidden, + _ => null + }; + } + + private static bool ReturnsErrorOr(IMethodSymbol method, ErrorOrContext context) + { + // Reuse existing helpers - unwrap Task/ValueTask, then check for ErrorOr + var unwrapped = method.ReturnType.GetTaskResultType() ?? method.ReturnType; + return IsErrorOrType(unwrapped, context, out _); + } +} diff --git a/src/ErrorOrX.Generators/Core/ErrorOrEndpointGenerator.Extractor.cs b/src/ErrorOrX.Generators/Core/ErrorOrEndpointGenerator.Extractor.cs index 6001a89..258ea7c 100644 --- a/src/ErrorOrX.Generators/Core/ErrorOrEndpointGenerator.Extractor.cs +++ b/src/ErrorOrX.Generators/Core/ErrorOrEndpointGenerator.Extractor.cs @@ -1,12 +1,17 @@ -using ANcpLua.Roslyn.Utilities.Models; -using ErrorOr.Analyzers; using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp.Syntax; namespace ErrorOr.Generators; /// -/// Partial class containing extraction and parameter binding logic. +/// Partial class containing return-type extraction logic. +/// Determines the success type, async-ness, SSE shape, and inaccessibility / open-generic +/// diagnostics for a handler method's ErrorOr<T> return type. +/// Sibling extractor partials: +/// +/// Extractor.ErrorInference.cs — body-walking to collect Error.X() calls. +/// Extractor.InterfaceDetection.cs — undocumented interface-call detection. +/// Extractor.Metadata.cs — endpoint metadata (route, attributes, middleware). +/// /// public sealed partial class ErrorOrEndpointGenerator { @@ -168,23 +173,6 @@ private static bool IsErrorOrType( return false; } - private static (EquatableArray ErrorTypeNames, EquatableArray CustomErrors) - InferErrorTypesFromMethod( - GeneratorAttributeSyntaxContext ctx, - ISymbol method, - ErrorOrContext context, - ImmutableArray.Builder diagnostics, - bool hasExplicitProducesError) - { - if (GetMethodBody(method) is not { } body) return (default, default); - - var methodName = method.Name; - var (errorTypeNames, customErrors) = CollectErrorTypes(ctx.SemanticModel, body, context, diagnostics, - methodName, - hasExplicitProducesError); - return (ToSortedErrorArray(errorTypeNames), new EquatableArray([.. customErrors])); - } - private static EquatableArray ExtractProducesErrorAttributes( ISymbol method, ErrorOrContext context) @@ -212,369 +200,4 @@ private static bool HasAcceptedResponseAttribute(ISymbol method, ErrorOrContext { return ErrorOrContext.HasAttribute(method, WellKnownTypes.AcceptedResponseAttribute); } - - private static SyntaxNode? GetMethodBody(ISymbol method) - { - var refs = method.DeclaringSyntaxReferences; - if (refs.IsDefaultOrEmpty || refs.Length is 0) return null; - - var syntax = refs[0].GetSyntax(); - return syntax switch - { - MethodDeclarationSyntax m => (SyntaxNode?)m.Body ?? m.ExpressionBody, - LocalFunctionStatementSyntax f => (SyntaxNode?)f.Body ?? f.ExpressionBody, - _ => null - }; - } - - private static (HashSet ErrorTypeNames, List CustomErrors) CollectErrorTypes( - SemanticModel semanticModel, - SyntaxNode body, - ErrorOrContext context, - ImmutableArray.Builder diagnostics, - string endpointMethodName, - bool hasExplicitProducesError) - { - var set = new HashSet(StringComparer.Ordinal); - var customErrors = new List(); - var visitedSymbols = new HashSet(SymbolEqualityComparer.Default); - var seenCustomCodes = new HashSet(StringComparer.Ordinal); - CollectErrorTypesRecursive(semanticModel, body, set, customErrors, visitedSymbols, seenCustomCodes, - context, diagnostics, endpointMethodName, hasExplicitProducesError); - return (set, customErrors); - } - - private static void CollectErrorTypesRecursive( - SemanticModel semanticModel, - SyntaxNode node, - ISet errorTypeNames, - ICollection customErrors, - ISet visitedSymbols, - ISet seenCustomCodes, - ErrorOrContext context, - ImmutableArray.Builder diagnostics, - string endpointMethodName, - bool hasExplicitProducesError) - { - foreach (var child in node.DescendantNodes()) - { - ProcessNode(semanticModel, child, errorTypeNames, customErrors, visitedSymbols, seenCustomCodes, context, - diagnostics, endpointMethodName, hasExplicitProducesError); - } - } - - private static void ProcessNode( - SemanticModel semanticModel, - SyntaxNode child, - ISet errorTypeNames, - ICollection customErrors, - ISet visitedSymbols, - ISet seenCustomCodes, - ErrorOrContext context, - ImmutableArray.Builder diagnostics, - string endpointMethodName, - bool hasExplicitProducesError) - { - if (TryHandleErrorFactoryInvocation( - semanticModel, - child, - errorTypeNames, - customErrors, - seenCustomCodes, - context, - diagnostics)) - { - return; - } - - // Check for interface/abstract method calls that return ErrorOr - if (TryDetectUndocumentedInterfaceCall( - semanticModel, - child, - context, - endpointMethodName, - hasExplicitProducesError, - diagnostics, - errorTypeNames, - customErrors, - seenCustomCodes)) - { - return; - } - - if (!TryGetReferencedSymbol(semanticModel, child, visitedSymbols, out var symbol)) return; - - foreach (var reference in symbol.DeclaringSyntaxReferences) - { - var bodyToScan = GetBodyToScan(reference.GetSyntax()); - if (bodyToScan is not null) - { - CollectErrorTypesRecursive(semanticModel, bodyToScan, errorTypeNames, customErrors, - visitedSymbols, seenCustomCodes, context, diagnostics, endpointMethodName, - hasExplicitProducesError); - } - } - } - - /// - /// Detects calls to interface/abstract methods returning ErrorOr. - /// If the interface method has [ReturnsError] attributes, extract them. - /// If not, and endpoint has no [ProducesError], emit ERROR (FAIL LOUD). - /// - private static bool TryDetectUndocumentedInterfaceCall( - SemanticModel semanticModel, - SyntaxNode node, - ErrorOrContext context, - string endpointMethodName, - bool hasExplicitProducesError, - ImmutableArray.Builder diagnostics, - ISet errorTypeNames, - ICollection customErrors, - ISet seenCustomCodes) - { - // Only check invocation expressions - if (node is not InvocationExpressionSyntax invocation) return false; - - var symbolInfo = semanticModel.GetSymbolInfo(invocation); - if (symbolInfo.Symbol is not IMethodSymbol methodSymbol) return false; - - // Check if method returns ErrorOr - if (!ReturnsErrorOr(methodSymbol, context)) return false; - - // Check if it's an interface or abstract method (no implementation to scan) - var containingType = methodSymbol.ContainingType; - var isInterfaceOrAbstract = containingType?.TypeKind == TypeKind.Interface || - methodSymbol.IsAbstract || - methodSymbol.IsVirtual; - - if (!isInterfaceOrAbstract) return false; - - // Try to extract [ReturnsError] attributes from the interface method - var hasReturnsError = TryExtractReturnsErrorAttributes( - methodSymbol, context, errorTypeNames, customErrors, seenCustomCodes); - - if (hasReturnsError) return true; // Successfully extracted errors from interface - - // If endpoint already has [ProducesError] attributes, assume developer knows what they're doing - if (hasExplicitProducesError) return true; // No error, endpoint is explicitly documented - - // FAIL LOUD: Interface call without documentation - var methodDisplayName = $"{containingType?.Name ?? "?"}.{methodSymbol.Name}"; - diagnostics.Add(DiagnosticInfo.Create( - Descriptors.UndocumentedInterfaceCall, - node.GetLocation(), - endpointMethodName, - methodDisplayName)); - - return true; - } - - /// - /// Extracts [ReturnsError] attributes from an interface/abstract method. - /// Returns true if any [ReturnsError] attributes were found. - /// - private static bool TryExtractReturnsErrorAttributes( - ISymbol method, - ErrorOrContext context, - ISet errorTypeNames, - ICollection customErrors, - ISet seenCustomCodes) - { - var foundAny = false; - - foreach (var attr in method.GetAttributes()) - { - if (!ErrorOrContext.MatchesType(attr.AttributeClass, WellKnownTypes.ReturnsErrorAttribute)) continue; - - var args = attr.ConstructorArguments; - if (args.Length < 2) continue; - - // Distinguish constructors by the first argument's type: - // 1. ReturnsErrorAttribute(ErrorType errorType, string errorCode) — args[0].Type is enum - // 2. ReturnsErrorAttribute(int statusCode, string errorCode) — args[0].Type is int - if (args[0].Value is not int intValue || args[1].Value is not string errorCode) continue; - - if (args[0].Type is INamedTypeSymbol { TypeKind: TypeKind.Enum }) - { - // Standard ErrorType — map enum int value to string name - var errorTypeName = MapEnumValueToName(intValue); - if (errorTypeName is not null) errorTypeNames.Add(errorTypeName); - } - else - { - // Custom error with explicit HTTP status code - if (seenCustomCodes.Add(errorCode)) customErrors.Add(new CustomErrorInfo(errorCode)); - } - - foundAny = true; - } - - return foundAny; - } - - /// - /// Maps runtime ErrorType enum integer value to its name. - /// The enum values are: Failure=0, Unexpected=1, Validation=2, Conflict=3, NotFound=4, Unauthorized=5, Forbidden=6 - /// - private static string? MapEnumValueToName(int enumValue) - { - return enumValue switch - { - 0 => ErrorMapping.Failure, - 1 => ErrorMapping.Unexpected, - 2 => ErrorMapping.Validation, - 3 => ErrorMapping.Conflict, - 4 => ErrorMapping.NotFound, - 5 => ErrorMapping.Unauthorized, - 6 => ErrorMapping.Forbidden, - _ => null - }; - } - - private static bool ReturnsErrorOr(IMethodSymbol method, ErrorOrContext context) - { - // Reuse existing helpers - unwrap Task/ValueTask, then check for ErrorOr - var unwrapped = method.ReturnType.GetTaskResultType() ?? method.ReturnType; - return IsErrorOrType(unwrapped, context, out _); - } - - private static bool TryHandleErrorFactoryInvocation( - SemanticModel semanticModel, - SyntaxNode node, - ISet errorTypeNames, - ICollection customErrors, - ISet seenCustomCodes, - ErrorOrContext context, - ImmutableArray.Builder diagnostics) - { - if (!IsErrorFactoryInvocation(semanticModel, node, context, out var factoryName, out var invocation)) - return false; - - // Validate and return the factory name if it's a known ErrorType - if (ErrorMapping.IsKnownErrorType(factoryName)) - { - errorTypeNames.Add(factoryName); - return true; - } - - if (factoryName == "Custom" && invocation is not null) - { - var customInfo = ExtractCustomErrorInfo(semanticModel, invocation); - if (customInfo is { } info && seenCustomCodes.Add(info.ErrorCode)) customErrors.Add(info); - - return true; - } - - // Unknown factory method - report diagnostic - // This fails loud instead of silently ignoring it or falling back to a default - diagnostics.Add(DiagnosticInfo.Create( - Descriptors.UnknownErrorFactory, - node.GetLocation(), - factoryName)); - - return true; - } - - private static bool TryGetReferencedSymbol( - SemanticModel semanticModel, - SyntaxNode node, - ISet visitedSymbols, - [NotNullWhen(true)] out ISymbol? symbol) - { - // Conditional assignment: only resolve symbol for relevant syntax nodes - symbol = node is IdentifierNameSyntax or MemberAccessExpressionSyntax - ? semanticModel.GetSymbolInfo(node).Symbol - : null; - - // Chained guards with short-circuit evaluation: - // 1. Type check (also handles null) - // 2. Same-assembly check (avoid external symbols) - // - ILocalSymbol has no ContainingAssembly but is always in scope (local to current method) - // 3. Add to visited (side-effect only if we'll use it, returns false if duplicate) - return symbol is IPropertySymbol or IFieldSymbol or ILocalSymbol or IMethodSymbol && - (symbol is ILocalSymbol || - symbol.ContainingAssembly?.IsEqualTo(semanticModel.Compilation.Assembly) == true) && - visitedSymbols.Add(symbol); - } - - private static SyntaxNode? GetBodyToScan(SyntaxNode syntax) - { - return syntax switch - { - PropertyDeclarationSyntax p => (SyntaxNode?)p.ExpressionBody ?? p.AccessorList, - MethodDeclarationSyntax m => (SyntaxNode?)m.Body ?? m.ExpressionBody, - VariableDeclaratorSyntax v => v.Initializer, - _ => syntax - }; - } - - private static CustomErrorInfo? ExtractCustomErrorInfo(SemanticModel semanticModel, - InvocationExpressionSyntax invocation) - { - // Error.Custom(int type, string code, string description, Dictionary? metadata = null) - // The 'code' parameter (second arg) is what we want for deduplication - var args = invocation.ArgumentList.Arguments; - if (args.Count < 2) return null; - - // Try to extract the 'code' (second argument) - var codeArg = args[1].Expression; - string? errorCode = null; - - // Try constant folding - var constantValue = semanticModel.GetConstantValue(codeArg); - if (constantValue is { HasValue: true, Value: string codeStr }) - errorCode = codeStr; - else if (codeArg is LiteralExpressionSyntax { Token.Value: string literalStr }) errorCode = literalStr; - - // Pattern matching establishes non-null for compiler - if (errorCode is not { Length: > 0 } code) return null; - - return new CustomErrorInfo(code); - } - - private static bool IsErrorFactoryInvocation( - SemanticModel semanticModel, - SyntaxNode node, - ErrorOrContext context, - out string factoryName, - out InvocationExpressionSyntax? invocation) - { - factoryName = string.Empty; - invocation = null; - - if (node is not InvocationExpressionSyntax inv) return false; - - invocation = inv; - - // Fast-path: Error.X(...) where Error is a simple identifier - if (inv.Expression is MemberAccessExpressionSyntax - { - Expression: IdentifierNameSyntax { Identifier.Text: "Error" }, - Name: IdentifierNameSyntax { Identifier.Text: var name } - }) - { - factoryName = name; - return true; - } - - // Semantic fallback: resolve invoked method and ensure it's actually ErrorOr.Error. - if (semanticModel.GetSymbolInfo(inv).Symbol is not IMethodSymbol symbol || - !ErrorOrContext.MatchesType(symbol.ContainingType, WellKnownTypes.ErrorStruct)) - { - return false; - } - - factoryName = symbol.Name; - return true; - } - - private static EquatableArray ToSortedErrorArray(HashSet set) - { - if (set.Count is 0) return default; - - var array = set.ToArray(); - Array.Sort(array, StringComparer.Ordinal); - return new EquatableArray([.. array]); - } - } diff --git a/src/ErrorOrX.Generators/Core/ErrorOrEndpointGenerator.Initialize.Attributes.cs b/src/ErrorOrX.Generators/Core/ErrorOrEndpointGenerator.Initialize.Attributes.cs new file mode 100644 index 0000000..f63af9e --- /dev/null +++ b/src/ErrorOrX.Generators/Core/ErrorOrEndpointGenerator.Initialize.Attributes.cs @@ -0,0 +1,127 @@ +using Microsoft.CodeAnalysis; + +namespace ErrorOr.Generators; + +/// +/// Marker-attribute emission via RegisterPostInitializationOutput. +/// Defines [Get], [Post], [Put], [Delete], [Patch], +/// [ErrorOrEndpoint], [ProducesError], [AcceptedResponse], +/// [ReturnsError], and [RouteGroup] as compile-time-injected types in the +/// consumer's ErrorOr namespace. Must live with this generator (not +/// OpenApiTransformerGenerator) because ForAttributeWithMetadataName only +/// sees types injected by the same generator. +/// +public sealed partial class ErrorOrEndpointGenerator +{ + private static void EmitAttributes(IncrementalGeneratorPostInitializationContext context) + { + const string source = """ + // + #nullable enable + + namespace ErrorOr + { + /// + /// Marks a static method as an ErrorOr endpoint with explicit HTTP method and route. + /// Prefer using [Get], [Post], [Put], [Delete], or [Patch] for standard HTTP methods. + /// + [global::System.AttributeUsage(global::System.AttributeTargets.Method, AllowMultiple = false)] + public sealed class ErrorOrEndpointAttribute : global::System.Attribute + { + public ErrorOrEndpointAttribute(string httpMethod, string route) + { + HttpMethod = httpMethod; + Route = route; + } + public string HttpMethod { get; } + public string Route { get; } + } + + [global::System.AttributeUsage(global::System.AttributeTargets.Method, AllowMultiple = false)] + public sealed class GetAttribute : global::System.Attribute + { + public GetAttribute(string route) => Route = route; + public string Route { get; } + } + + [global::System.AttributeUsage(global::System.AttributeTargets.Method, AllowMultiple = false)] + public sealed class PostAttribute : global::System.Attribute + { + public PostAttribute(string route) => Route = route; + public string Route { get; } + } + + [global::System.AttributeUsage(global::System.AttributeTargets.Method, AllowMultiple = false)] + public sealed class PutAttribute : global::System.Attribute + { + public PutAttribute(string route) => Route = route; + public string Route { get; } + } + + [global::System.AttributeUsage(global::System.AttributeTargets.Method, AllowMultiple = false)] + public sealed class DeleteAttribute : global::System.Attribute + { + public DeleteAttribute(string route) => Route = route; + public string Route { get; } + } + + [global::System.AttributeUsage(global::System.AttributeTargets.Method, AllowMultiple = false)] + public sealed class PatchAttribute : global::System.Attribute + { + public PatchAttribute(string route) => Route = route; + public string Route { get; } + } + + [global::System.AttributeUsage(global::System.AttributeTargets.Method, AllowMultiple = true)] + public sealed class ProducesErrorAttribute : global::System.Attribute + { + public ProducesErrorAttribute(int statusCode, string errorType) + { + StatusCode = statusCode; + ErrorType = errorType; + } + public int StatusCode { get; } + public string ErrorType { get; } + } + + [global::System.AttributeUsage(global::System.AttributeTargets.Method, AllowMultiple = false)] + public sealed class AcceptedResponseAttribute : global::System.Attribute { } + + [global::System.AttributeUsage(global::System.AttributeTargets.Method, AllowMultiple = true)] + public sealed class ReturnsErrorAttribute : global::System.Attribute + { + public ReturnsErrorAttribute(global::ErrorOr.ErrorType errorType, string errorCode) + { + ErrorType = errorType; + ErrorCode = errorCode; + } + public ReturnsErrorAttribute(int statusCode, string errorCode) + { + StatusCode = statusCode; + ErrorCode = errorCode; + ErrorType = null; + } + public global::ErrorOr.ErrorType? ErrorType { get; } + public int? StatusCode { get; } + public string ErrorCode { get; } + } + + /// + /// Marks a class as a route group for versioned API endpoints. + /// All endpoints in the class will be mapped under the specified path prefix + /// using the eShop-style NewVersionedApi() pattern when combined with [ApiVersion]. + /// + [global::System.AttributeUsage(global::System.AttributeTargets.Class, AllowMultiple = false)] + public sealed class RouteGroupAttribute : global::System.Attribute + { + public RouteGroupAttribute(string path) => Path = path; + public string Path { get; } + public string? ApiName { get; set; } + } + } + """; + + // Use a different file name to avoid conflicts with OpenApiTransformerGenerator + context.AddSource("ErrorOrEndpointAttributes.Mappings.g.cs", source); + } +} diff --git a/src/ErrorOrX.Generators/Core/ErrorOrEndpointGenerator.Initialize.EndpointFlow.cs b/src/ErrorOrX.Generators/Core/ErrorOrEndpointGenerator.Initialize.EndpointFlow.cs new file mode 100644 index 0000000..e55f512 --- /dev/null +++ b/src/ErrorOrX.Generators/Core/ErrorOrEndpointGenerator.Initialize.EndpointFlow.cs @@ -0,0 +1,314 @@ +using ANcpLua.Roslyn.Utilities; +using ANcpLua.Roslyn.Utilities.Models; +using ErrorOr.Analyzers; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; + +namespace ErrorOr.Generators; + +/// +/// Per-attribute endpoint discovery and descriptor construction. The pipeline: +/// +/// fans out one SyntaxProvider per HTTP-method attribute. +/// shape-validates the method via DiagnosticFlow railway. +/// binds parameters, validates routes/versions, and builds the descriptor. +/// +/// +public sealed partial class ErrorOrEndpointGenerator +{ + private static IncrementalValueProvider> CombineHttpMethodProviders( + IncrementalGeneratorInitializationContext context, + IncrementalValueProvider errorOrContextProvider) + { + var getProvider = CreateEndpointProvider(context, WellKnownTypes.GetAttribute, errorOrContextProvider); + var postProvider = CreateEndpointProvider(context, WellKnownTypes.PostAttribute, errorOrContextProvider); + var putProvider = CreateEndpointProvider(context, WellKnownTypes.PutAttribute, errorOrContextProvider); + var deleteProvider = CreateEndpointProvider(context, WellKnownTypes.DeleteAttribute, errorOrContextProvider); + var patchProvider = CreateEndpointProvider(context, WellKnownTypes.PatchAttribute, errorOrContextProvider); + var baseProvider = + CreateEndpointProvider(context, WellKnownTypes.ErrorOrEndpointAttribute, errorOrContextProvider); + + return IncrementalProviderExtensions.CombineAll( + getProvider, postProvider, putProvider, + deleteProvider, patchProvider, baseProvider); + } + + private static IncrementalValuesProvider CreateEndpointProvider( + IncrementalGeneratorInitializationContext context, + string attributeName, + IncrementalValueProvider errorOrContextProvider) + { + return context.SyntaxProvider + .ForAttributeWithMetadataName( + attributeName, + static (node, _) => node is MethodDeclarationSyntax, + static (ctx, _) => ctx) + .Combine(errorOrContextProvider) + .SelectFlow(static (pair, ct) => + { + var (ctx, errorOrContext) = pair; + return AnalyzeEndpointFlow(ctx, errorOrContext, ct); + }) + .WithTrackingName(TrackingNames.EndpointBindingFlow(attributeName)) + .ReportAndContinue(context) + .SelectMany(static (endpoints, _) => endpoints.AsImmutableArray()); + } + + private static DiagnosticFlow> AnalyzeEndpointFlow( + GeneratorAttributeSyntaxContext ctx, + ErrorOrContext errorOrContext, + CancellationToken ct) + { + if (ctx.TargetSymbol is not IMethodSymbol method || ctx.Attributes.IsDefaultOrEmpty) + return Helpers.EmptyEndpointFlow(); + + var location = method.Locations.FirstOrDefault() ?? Location.None; + + // 1. Validate shape using SemanticGuard + DiagnosticFlow (The Railway Pattern) + var methodAnalysisFlow = SemanticGuard.For(method) + .MustBeStatic(DiagnosticInfo.Create(Descriptors.NonStaticHandler, location, method.Name)) + .ToFlow() + .Then(m => + { + var returnInfo = ExtractErrorOrReturnType(m.ReturnType, errorOrContext); + + // EOE015: Anonymous return type + if (returnInfo.IsAnonymousType) + { + return DiagnosticFlow.Fail<(IMethodSymbol, ErrorOrReturnTypeInfo)>( + DiagnosticInfo.Create(Descriptors.AnonymousReturnTypeNotSupported, location, m.Name)); + } + + // EOE018: Inaccessible return type + if (returnInfo.IsInaccessibleType) + { + return DiagnosticFlow.Fail<(IMethodSymbol, ErrorOrReturnTypeInfo)>( + DiagnosticInfo.Create(Descriptors.InaccessibleTypeNotSupported, location, + returnInfo.InaccessibleTypeName ?? "unknown", + m.Name, + returnInfo.InaccessibleTypeAccessibility ?? "private")); + } + + // EOE019: Type parameter in return type + if (returnInfo.IsTypeParameter) + { + return DiagnosticFlow.Fail<(IMethodSymbol, ErrorOrReturnTypeInfo)>( + DiagnosticInfo.Create(Descriptors.TypeParameterNotSupported, location, + m.Name, + returnInfo.TypeParameterName ?? "T")); + } + + return returnInfo.SuccessTypeFqn is not null + ? DiagnosticFlow.Ok((m, returnInfo)) + : DiagnosticFlow.Fail<(IMethodSymbol, ErrorOrReturnTypeInfo)>( + DiagnosticInfo.Create(Descriptors.InvalidReturnType, location, m.Name)); + }) + .Then(pair => + { + var (m, returnInfo) = pair; + var builder = ImmutableArray.CreateBuilder(); + + // EOE033: Validate PascalCase naming convention + if (NamingValidator.ValidatePascalCase(m.Name, location) is { } namingDiagnostic) + builder.Add(namingDiagnostic); + + // Extract method-level attributes first (needed for interface call detection) + var producesErrors = ExtractProducesErrorAttributes(m, errorOrContext); + var isAcceptedResponse = HasAcceptedResponseAttribute(m, errorOrContext); + var hasExplicitProducesError = !producesErrors.IsDefaultOrEmpty; + + // Extract middleware attributes (BCL: Authorize, RateLimiting, OutputCache, CORS) + var middleware = ExtractMiddlewareAttributes(m, errorOrContext); + + // Infer errors once per method (now with interface call detection) + var (inferredErrors, customErrors) = + InferErrorTypesFromMethod(ctx, m, errorOrContext, builder, hasExplicitProducesError); + + var analysis = new MethodAnalysis( + returnInfo, + inferredErrors, + customErrors, + producesErrors, + isAcceptedResponse, + middleware); + + var flow = DiagnosticFlow.Ok(analysis); + foreach (var diag in builder) + flow = flow.Warn(diag); + + return flow; + }); + + // 2. Map method analysis to individual attribute descriptors + var flows = ImmutableArray.CreateBuilder>(ctx.Attributes.Length); + foreach (var attr in ctx.Attributes) + { + if (attr is null) continue; + + var flow = methodAnalysisFlow.Then(analysis => + ProcessAttributeFlow(method, in analysis, attr, errorOrContext, ct)); + flows.Add(flow); + } + + if (flows.Count is 0) return Helpers.EmptyEndpointFlow(); + + return DiagnosticFlow.Collect(flows.ToImmutable()) + .Select(static endpoints => new EquatableArray(endpoints)); + } + + private static DiagnosticFlow ProcessAttributeFlow( + IMethodSymbol method, + in MethodAnalysis analysis, + AttributeData attr, + ErrorOrContext errorOrContext, + CancellationToken ct) + { + ct.ThrowIfCancellationRequested(); + + if (attr.AttributeClass is not { } attrClass) return DiagnosticFlow.Fail(); + + var attrName = attrClass.Name; + + var (verb, pattern, customMethod) = ExtractHttpMethodAndPattern(attr, attrName); + if (verb is null) return DiagnosticFlow.Fail(); + + // Guard: SuccessTypeFqn validated in upstream .Then() but compiler doesn't know + if (analysis.ReturnInfo.SuccessTypeFqn is not { } successTypeFqn) + return DiagnosticFlow.Fail(); + + var builder = ImmutableArray.CreateBuilder(); + + // Extract route parameters as HashSet for binding + var routeParamInfos = RouteValidator.ExtractRouteParameters(pattern); + var routeParamNames = routeParamInfos + .Select(static r => r.Name) + .ToImmutableHashSet(StringComparer.OrdinalIgnoreCase); + + var bindingFlow = RouteBindingHelper.BindRouteParameters( + method, + routeParamNames, + errorOrContext, + verb.Value); + if (!bindingFlow.IsSuccess) return DiagnosticFlow.Fail(bindingFlow.Diagnostics); + + builder.AddRange(bindingFlow.Diagnostics.AsImmutableArray()); + var bindingAnalysis = bindingFlow.ValueOrDefault(); + + // Validate route pattern + builder.AddRange(RouteValidator.ValidatePattern(pattern, method)); + + // Extract method parameter info for route binding validation + var methodParams = bindingAnalysis.RouteParameters.AsImmutableArray(); + + // Validate route parameters are bound + builder.AddRange(RouteValidator.ValidateParameterBindings( + pattern, routeParamInfos, methodParams, method)); + + // Validate route constraint types + builder.AddRange(RouteValidator.ValidateConstraintTypes( + routeParamInfos, methodParams, method)); + + // Extract API versioning attributes + var versioning = ExtractVersioningAttributes(method, errorOrContext); + + // Validate API versioning configuration (EOE027-EOE031) + var rawClassVersions = ExtractRawClassVersionStrings(method, errorOrContext); + var rawMethodVersions = ExtractRawMethodVersionStrings(method, errorOrContext); + var location = method.Locations.FirstOrDefault() ?? Location.None; + builder.AddRange(ApiVersioningValidator.Validate( + method.Name, +in versioning, + rawClassVersions, + rawMethodVersions, + location, + errorOrContext.HasApiVersioningSupport, + method)); + + // Extract route group configuration for eShop-style grouping + var routeGroup = ExtractRouteGroupInfo(method, errorOrContext); + + // Extract custom endpoint metadata + var metadata = ExtractMetadata(method); + + var descriptor = new EndpointDescriptor( + verb.Value, + pattern, + successTypeFqn, + analysis.ReturnInfo.Kind, + analysis.ReturnInfo.IsAsync, + method.ContainingType?.GetFullyQualifiedName() ?? "Unknown", + method.Name, + bindingAnalysis.Parameters, + new ErrorInferenceInfo( + analysis.InferredErrorTypeNames, + analysis.InferredCustomErrors, + analysis.ProducesErrors), + new SseInfo( + analysis.ReturnInfo.IsSse, + analysis.ReturnInfo.SseItemTypeFqn), + analysis.IsAcceptedResponse, + analysis.ReturnInfo.IdPropertyName, + analysis.Middleware, + versioning, + routeGroup, + metadata, + customMethod); + + var flow = DiagnosticFlow.Ok(descriptor); + foreach (var diag in builder) + flow = flow.Warn(diag); + + return flow; + } + + private static (HttpVerb? Verb, string Pattern, string? CustomMethod) ExtractHttpMethodAndPattern( + AttributeData attr, + string attrName) + { + var verb = HttpVerbExtensions.TryParseFromAttribute(attrName, attr.ConstructorArguments); + + // For ErrorOrEndpointAttribute with unrecognized methods (e.g., "CONNECT", "PROPFIND"), + // store the raw method string so we can emit MapMethods with it + string? customMethod = null; + var isErrorOrEndpoint = attrName.Contains("ErrorOrEndpoint"); + if (verb is null && isErrorOrEndpoint && + attr.ConstructorArguments is [{ Value: string rawMethod }, ..]) + { + customMethod = rawMethod.ToUpperInvariant(); + verb = HttpVerb.Get; // placeholder — MapMethods is used when CustomHttpMethod is set + } + + if (verb is null) return (null, "/", null); + + // Extract pattern - index differs for ErrorOrEndpoint (has httpMethod arg first) + var patternIndex = isErrorOrEndpoint ? 1 : 0; + var pattern = attr.GetConstructorArgument(patternIndex) is { } p + ? p + : "/"; + + return (verb, pattern, customMethod); + } + + /// + /// Incremental pipeline tracking names for caching diagnostics. + /// + private static class TrackingNames + { + public const string ResultsUnionMaxArity = "ResultsUnionMaxArity"; + public const string ErrorOrContext = "ErrorOrContext"; + + public static string EndpointBindingFlow(string attributeName) + { + return attributeName switch + { + WellKnownTypes.GetAttribute => "EndpointBindingFlow.Get", + WellKnownTypes.PostAttribute => "EndpointBindingFlow.Post", + WellKnownTypes.PutAttribute => "EndpointBindingFlow.Put", + WellKnownTypes.DeleteAttribute => "EndpointBindingFlow.Delete", + WellKnownTypes.PatchAttribute => "EndpointBindingFlow.Patch", + WellKnownTypes.ErrorOrEndpointAttribute => "EndpointBindingFlow.Custom", + _ => "EndpointBindingFlow.Unknown" + }; + } + } +} diff --git a/src/ErrorOrX.Generators/Core/ErrorOrEndpointGenerator.Initialize.cs b/src/ErrorOrX.Generators/Core/ErrorOrEndpointGenerator.Initialize.cs index 97a6f1d..096fa8f 100644 --- a/src/ErrorOrX.Generators/Core/ErrorOrEndpointGenerator.Initialize.cs +++ b/src/ErrorOrX.Generators/Core/ErrorOrEndpointGenerator.Initialize.cs @@ -1,8 +1,5 @@ -using ANcpLua.Roslyn.Utilities.Models; using ANcpLua.Roslyn.Utilities; -using ErrorOr.Analyzers; using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Diagnostics; namespace ErrorOr.Generators; @@ -10,6 +7,13 @@ namespace ErrorOr.Generators; /// /// Generator entry point for ErrorOr endpoint mappings. /// Generates MapErrorOrEndpoints() and AddErrorOrEndpoints() fluent configuration extension methods. +/// +/// Pipeline wiring lives here. Sibling partials: +/// +/// Initialize.Attributes.cs — Marker attribute emission via PostInitializationOutput. +/// Initialize.EndpointFlow.cs — Per-attribute endpoint discovery, validation, and descriptor build. +/// +/// /// [Generator(LanguageNames.CSharp)] public sealed partial class ErrorOrEndpointGenerator : IIncrementalGenerator @@ -50,124 +54,6 @@ public void Initialize(IncrementalGeneratorInitializationContext context) context.RegisterSourceOutput(emitInput, static (spc, ctx) => EmitMappingsAndRunAnalysis(spc, in ctx)); } - /// - /// Emits the marker attributes that users apply to their endpoint handler methods. - /// This must be registered by this generator (not just OpenApiTransformerGenerator) - /// because ForAttributeWithMetadataName only sees types from PostInitializationOutput - /// within the same generator. - /// - private static void EmitAttributes(IncrementalGeneratorPostInitializationContext context) - { - const string source = """ - // - #nullable enable - - namespace ErrorOr - { - /// - /// Marks a static method as an ErrorOr endpoint with explicit HTTP method and route. - /// Prefer using [Get], [Post], [Put], [Delete], or [Patch] for standard HTTP methods. - /// - [global::System.AttributeUsage(global::System.AttributeTargets.Method, AllowMultiple = false)] - public sealed class ErrorOrEndpointAttribute : global::System.Attribute - { - public ErrorOrEndpointAttribute(string httpMethod, string route) - { - HttpMethod = httpMethod; - Route = route; - } - public string HttpMethod { get; } - public string Route { get; } - } - - [global::System.AttributeUsage(global::System.AttributeTargets.Method, AllowMultiple = false)] - public sealed class GetAttribute : global::System.Attribute - { - public GetAttribute(string route) => Route = route; - public string Route { get; } - } - - [global::System.AttributeUsage(global::System.AttributeTargets.Method, AllowMultiple = false)] - public sealed class PostAttribute : global::System.Attribute - { - public PostAttribute(string route) => Route = route; - public string Route { get; } - } - - [global::System.AttributeUsage(global::System.AttributeTargets.Method, AllowMultiple = false)] - public sealed class PutAttribute : global::System.Attribute - { - public PutAttribute(string route) => Route = route; - public string Route { get; } - } - - [global::System.AttributeUsage(global::System.AttributeTargets.Method, AllowMultiple = false)] - public sealed class DeleteAttribute : global::System.Attribute - { - public DeleteAttribute(string route) => Route = route; - public string Route { get; } - } - - [global::System.AttributeUsage(global::System.AttributeTargets.Method, AllowMultiple = false)] - public sealed class PatchAttribute : global::System.Attribute - { - public PatchAttribute(string route) => Route = route; - public string Route { get; } - } - - [global::System.AttributeUsage(global::System.AttributeTargets.Method, AllowMultiple = true)] - public sealed class ProducesErrorAttribute : global::System.Attribute - { - public ProducesErrorAttribute(int statusCode, string errorType) - { - StatusCode = statusCode; - ErrorType = errorType; - } - public int StatusCode { get; } - public string ErrorType { get; } - } - - [global::System.AttributeUsage(global::System.AttributeTargets.Method, AllowMultiple = false)] - public sealed class AcceptedResponseAttribute : global::System.Attribute { } - - [global::System.AttributeUsage(global::System.AttributeTargets.Method, AllowMultiple = true)] - public sealed class ReturnsErrorAttribute : global::System.Attribute - { - public ReturnsErrorAttribute(global::ErrorOr.ErrorType errorType, string errorCode) - { - ErrorType = errorType; - ErrorCode = errorCode; - } - public ReturnsErrorAttribute(int statusCode, string errorCode) - { - StatusCode = statusCode; - ErrorCode = errorCode; - ErrorType = null; - } - public global::ErrorOr.ErrorType? ErrorType { get; } - public int? StatusCode { get; } - public string ErrorCode { get; } - } - - /// - /// Marks a class as a route group for versioned API endpoints. - /// All endpoints in the class will be mapped under the specified path prefix - /// using the eShop-style NewVersionedApi() pattern when combined with [ApiVersion]. - /// - [global::System.AttributeUsage(global::System.AttributeTargets.Class, AllowMultiple = false)] - public sealed class RouteGroupAttribute : global::System.Attribute - { - public RouteGroupAttribute(string path) => Path = path; - public string Path { get; } - public string? ApiName { get; set; } - } - } - """; - - // Use a different file name to avoid conflicts with OpenApiTransformerGenerator - context.AddSource("ErrorOrEndpointAttributes.Mappings.g.cs", source); - } - private static bool ParseGenerateJsonContextOption(AnalyzerConfigOptionsProvider options, CancellationToken _) { options.GlobalOptions.TryGetValue("build_property.ErrorOrGenerateJsonContext", out var value); @@ -180,23 +66,6 @@ private static bool ParsePublishAotOption(AnalyzerConfigOptionsProvider options, return string.Equals(value, "true", StringComparison.OrdinalIgnoreCase); } - private static IncrementalValueProvider> CombineHttpMethodProviders( - IncrementalGeneratorInitializationContext context, - IncrementalValueProvider errorOrContextProvider) - { - var getProvider = CreateEndpointProvider(context, WellKnownTypes.GetAttribute, errorOrContextProvider); - var postProvider = CreateEndpointProvider(context, WellKnownTypes.PostAttribute, errorOrContextProvider); - var putProvider = CreateEndpointProvider(context, WellKnownTypes.PutAttribute, errorOrContextProvider); - var deleteProvider = CreateEndpointProvider(context, WellKnownTypes.DeleteAttribute, errorOrContextProvider); - var patchProvider = CreateEndpointProvider(context, WellKnownTypes.PatchAttribute, errorOrContextProvider); - var baseProvider = - CreateEndpointProvider(context, WellKnownTypes.ErrorOrEndpointAttribute, errorOrContextProvider); - - return IncrementalProviderExtensions.CombineAll( - getProvider, postProvider, putProvider, - deleteProvider, patchProvider, baseProvider); - } - private static void EmitMappingsAndRunAnalysis( SourceProductionContext spc, in EmitContext ctx) @@ -229,285 +98,6 @@ private static void ReportVersioningInconsistencies(SourceProductionContext spc, spc.ReportDiagnostic(diagnostic); } - private static IncrementalValuesProvider CreateEndpointProvider( - IncrementalGeneratorInitializationContext context, - string attributeName, - IncrementalValueProvider errorOrContextProvider) - { - return context.SyntaxProvider - .ForAttributeWithMetadataName( - attributeName, - static (node, _) => node is MethodDeclarationSyntax, - static (ctx, _) => ctx) - .Combine(errorOrContextProvider) - .SelectFlow(static (pair, ct) => - { - var (ctx, errorOrContext) = pair; - return AnalyzeEndpointFlow(ctx, errorOrContext, ct); - }) - .WithTrackingName(TrackingNames.EndpointBindingFlow(attributeName)) - .ReportAndContinue(context) - .SelectMany(static (endpoints, _) => endpoints.AsImmutableArray()); - } - - private static DiagnosticFlow> AnalyzeEndpointFlow( - GeneratorAttributeSyntaxContext ctx, - ErrorOrContext errorOrContext, - CancellationToken ct) - { - if (ctx.TargetSymbol is not IMethodSymbol method || ctx.Attributes.IsDefaultOrEmpty) - return Helpers.EmptyEndpointFlow(); - - var location = method.Locations.FirstOrDefault() ?? Location.None; - - // 1. Validate shape using SemanticGuard + DiagnosticFlow (The Railway Pattern) - var methodAnalysisFlow = SemanticGuard.For(method) - .MustBeStatic(DiagnosticInfo.Create(Descriptors.NonStaticHandler, location, method.Name)) - .ToFlow() - .Then(m => - { - var returnInfo = ExtractErrorOrReturnType(m.ReturnType, errorOrContext); - - // EOE015: Anonymous return type - if (returnInfo.IsAnonymousType) - { - return DiagnosticFlow.Fail<(IMethodSymbol, ErrorOrReturnTypeInfo)>( - DiagnosticInfo.Create(Descriptors.AnonymousReturnTypeNotSupported, location, m.Name)); - } - - // EOE018: Inaccessible return type - if (returnInfo.IsInaccessibleType) - { - return DiagnosticFlow.Fail<(IMethodSymbol, ErrorOrReturnTypeInfo)>( - DiagnosticInfo.Create(Descriptors.InaccessibleTypeNotSupported, location, - returnInfo.InaccessibleTypeName ?? "unknown", - m.Name, - returnInfo.InaccessibleTypeAccessibility ?? "private")); - } - - // EOE019: Type parameter in return type - if (returnInfo.IsTypeParameter) - { - return DiagnosticFlow.Fail<(IMethodSymbol, ErrorOrReturnTypeInfo)>( - DiagnosticInfo.Create(Descriptors.TypeParameterNotSupported, location, - m.Name, - returnInfo.TypeParameterName ?? "T")); - } - - return returnInfo.SuccessTypeFqn is not null - ? DiagnosticFlow.Ok((m, returnInfo)) - : DiagnosticFlow.Fail<(IMethodSymbol, ErrorOrReturnTypeInfo)>( - DiagnosticInfo.Create(Descriptors.InvalidReturnType, location, m.Name)); - }) - .Then(pair => - { - var (m, returnInfo) = pair; - var builder = ImmutableArray.CreateBuilder(); - - // EOE033: Validate PascalCase naming convention - if (NamingValidator.ValidatePascalCase(m.Name, location) is { } namingDiagnostic) - builder.Add(namingDiagnostic); - - // Extract method-level attributes first (needed for interface call detection) - var producesErrors = ExtractProducesErrorAttributes(m, errorOrContext); - var isAcceptedResponse = HasAcceptedResponseAttribute(m, errorOrContext); - var hasExplicitProducesError = !producesErrors.IsDefaultOrEmpty; - - // Extract middleware attributes (BCL: Authorize, RateLimiting, OutputCache, CORS) - var middleware = ExtractMiddlewareAttributes(m, errorOrContext); - - // Infer errors once per method (now with interface call detection) - var (inferredErrors, customErrors) = - InferErrorTypesFromMethod(ctx, m, errorOrContext, builder, hasExplicitProducesError); - - var analysis = new MethodAnalysis( - returnInfo, - inferredErrors, - customErrors, - producesErrors, - isAcceptedResponse, - middleware); - - var flow = DiagnosticFlow.Ok(analysis); - foreach (var diag in builder) - flow = flow.Warn(diag); - - return flow; - }); - - // 2. Map method analysis to individual attribute descriptors - var flows = ImmutableArray.CreateBuilder>(ctx.Attributes.Length); - foreach (var attr in ctx.Attributes) - { - if (attr is null) continue; - - var flow = methodAnalysisFlow.Then(analysis => - ProcessAttributeFlow(method, in analysis, attr, errorOrContext, ct)); - flows.Add(flow); - } - - if (flows.Count is 0) return Helpers.EmptyEndpointFlow(); - - return DiagnosticFlow.Collect(flows.ToImmutable()) - .Select(static endpoints => new EquatableArray(endpoints)); - } - - private static DiagnosticFlow ProcessAttributeFlow( - IMethodSymbol method, - in MethodAnalysis analysis, - AttributeData attr, - ErrorOrContext errorOrContext, - CancellationToken ct) - { - ct.ThrowIfCancellationRequested(); - - if (attr.AttributeClass is not { } attrClass) return DiagnosticFlow.Fail(); - - var attrName = attrClass.Name; - - var (verb, pattern, customMethod) = ExtractHttpMethodAndPattern(attr, attrName); - if (verb is null) return DiagnosticFlow.Fail(); - - // Guard: SuccessTypeFqn validated in upstream .Then() but compiler doesn't know - if (analysis.ReturnInfo.SuccessTypeFqn is not { } successTypeFqn) - return DiagnosticFlow.Fail(); - - var builder = ImmutableArray.CreateBuilder(); - - // Extract route parameters as HashSet for binding - var routeParamInfos = RouteValidator.ExtractRouteParameters(pattern); - var routeParamNames = routeParamInfos - .Select(static r => r.Name) - .ToImmutableHashSet(StringComparer.OrdinalIgnoreCase); - - var bindingFlow = RouteBindingHelper.BindRouteParameters( - method, - routeParamNames, - errorOrContext, - verb.Value); - if (!bindingFlow.IsSuccess) return DiagnosticFlow.Fail(bindingFlow.Diagnostics); - - builder.AddRange(bindingFlow.Diagnostics.AsImmutableArray()); - var bindingAnalysis = bindingFlow.ValueOrDefault(); - - // Validate route pattern - builder.AddRange(RouteValidator.ValidatePattern(pattern, method)); - - // Extract method parameter info for route binding validation - var methodParams = bindingAnalysis.RouteParameters.AsImmutableArray(); - - // Validate route parameters are bound - builder.AddRange(RouteValidator.ValidateParameterBindings( - pattern, routeParamInfos, methodParams, method)); - - // Validate route constraint types - builder.AddRange(RouteValidator.ValidateConstraintTypes( - routeParamInfos, methodParams, method)); - - // Extract API versioning attributes - var versioning = ExtractVersioningAttributes(method, errorOrContext); - - // Validate API versioning configuration (EOE027-EOE031) - var rawClassVersions = ExtractRawClassVersionStrings(method, errorOrContext); - var rawMethodVersions = ExtractRawMethodVersionStrings(method, errorOrContext); - var location = method.Locations.FirstOrDefault() ?? Location.None; - builder.AddRange(ApiVersioningValidator.Validate( - method.Name, -in versioning, - rawClassVersions, - rawMethodVersions, - location, - errorOrContext.HasApiVersioningSupport, - method)); - - // Extract route group configuration for eShop-style grouping - var routeGroup = ExtractRouteGroupInfo(method, errorOrContext); - - // Extract custom endpoint metadata - var metadata = ExtractMetadata(method); - - var descriptor = new EndpointDescriptor( - verb.Value, - pattern, - successTypeFqn, - analysis.ReturnInfo.Kind, - analysis.ReturnInfo.IsAsync, - method.ContainingType?.GetFullyQualifiedName() ?? "Unknown", - method.Name, - bindingAnalysis.Parameters, - new ErrorInferenceInfo( - analysis.InferredErrorTypeNames, - analysis.InferredCustomErrors, - analysis.ProducesErrors), - new SseInfo( - analysis.ReturnInfo.IsSse, - analysis.ReturnInfo.SseItemTypeFqn), - analysis.IsAcceptedResponse, - analysis.ReturnInfo.IdPropertyName, - analysis.Middleware, - versioning, - routeGroup, - metadata, - customMethod); - - var flow = DiagnosticFlow.Ok(descriptor); - foreach (var diag in builder) - flow = flow.Warn(diag); - - return flow; - } - - private static (HttpVerb? Verb, string Pattern, string? CustomMethod) ExtractHttpMethodAndPattern( - AttributeData attr, - string attrName) - { - var verb = HttpVerbExtensions.TryParseFromAttribute(attrName, attr.ConstructorArguments); - - // For ErrorOrEndpointAttribute with unrecognized methods (e.g., "CONNECT", "PROPFIND"), - // store the raw method string so we can emit MapMethods with it - string? customMethod = null; - var isErrorOrEndpoint = attrName.Contains("ErrorOrEndpoint"); - if (verb is null && isErrorOrEndpoint && - attr.ConstructorArguments is [{ Value: string rawMethod }, ..]) - { - customMethod = rawMethod.ToUpperInvariant(); - verb = HttpVerb.Get; // placeholder — MapMethods is used when CustomHttpMethod is set - } - - if (verb is null) return (null, "/", null); - - // Extract pattern - index differs for ErrorOrEndpoint (has httpMethod arg first) - var patternIndex = isErrorOrEndpoint ? 1 : 0; - var pattern = attr.GetConstructorArgument(patternIndex) is { } p - ? p - : "/"; - - return (verb, pattern, customMethod); - } - - /// - /// Incremental pipeline tracking names for caching diagnostics. - /// - private static class TrackingNames - { - public const string ResultsUnionMaxArity = "ResultsUnionMaxArity"; - public const string ErrorOrContext = "ErrorOrContext"; - - public static string EndpointBindingFlow(string attributeName) - { - return attributeName switch - { - WellKnownTypes.GetAttribute => "EndpointBindingFlow.Get", - WellKnownTypes.PostAttribute => "EndpointBindingFlow.Post", - WellKnownTypes.PutAttribute => "EndpointBindingFlow.Put", - WellKnownTypes.DeleteAttribute => "EndpointBindingFlow.Delete", - WellKnownTypes.PatchAttribute => "EndpointBindingFlow.Patch", - WellKnownTypes.ErrorOrEndpointAttribute => "EndpointBindingFlow.Custom", - _ => "EndpointBindingFlow.Unknown" - }; - } - } - /// /// Flattened context for the combined Roslyn pipeline inputs to . /// diff --git a/src/ErrorOrX.Generators/Core/ErrorOrEndpointGenerator.ParameterBinding.Classifiers.cs b/src/ErrorOrX.Generators/Core/ErrorOrEndpointGenerator.ParameterBinding.Classifiers.cs new file mode 100644 index 0000000..1e737da --- /dev/null +++ b/src/ErrorOrX.Generators/Core/ErrorOrEndpointGenerator.ParameterBinding.Classifiers.cs @@ -0,0 +1,295 @@ +using ANcpLua.Roslyn.Utilities; +using ANcpLua.Roslyn.Utilities.Models; +using ErrorOr.Analyzers; +using Microsoft.CodeAnalysis; + +namespace ErrorOr.Generators; + +/// +/// Partial class containing per-binding-source parameter classifiers. +/// Each ClassifyFrom*Parameter validates one attribute family and emits the +/// matching EOE0xx diagnostic on failure. and +/// additionally recurse into constructor parameters +/// to build nested trees. +/// +public sealed partial class ErrorOrEndpointGenerator +{ + /// + /// Classifies [FromRoute] parameter with proper EOE010 diagnostic. + /// + private static ParameterClassificationResult ClassifyFromRouteParameter( + in ParameterMeta meta, + ImmutableHashSet routeParameters, + ISymbol method, + ImmutableArray.Builder diagnostics) + { + var hasTryParse = meta.CustomBinding is CustomBindingMethod.TryParse or CustomBindingMethod.TryParseWithFormat; + + // EOE010: [FromRoute] requires primitive or TryParse + if (meta.RouteKind is null && !hasTryParse) + { + diagnostics.Add(DiagnosticInfo.Create( + Descriptors.InvalidFromRouteType, + method.Locations.FirstOrDefault() ?? Location.None, + meta.Name, + meta.TypeFqn)); + return ParameterClassificationResult.Error; + } + + return ParameterSuccess(in meta, ParameterSource.Route, meta.BoundName, + customBinding: meta.CustomBinding); + } + + /// + /// Classifies implicit route parameter with proper EOE010 diagnostic. + /// + private static ParameterClassificationResult ClassifyImplicitRouteParameter( + in ParameterMeta meta, + ISymbol method, + ImmutableArray.Builder diagnostics) + { + var hasTryParse = meta.CustomBinding is CustomBindingMethod.TryParse or CustomBindingMethod.TryParseWithFormat; + + // EOE010: Route parameters must use supported primitive types or TryParse + if (meta.RouteKind is null && !hasTryParse) + { + diagnostics.Add(DiagnosticInfo.Create( + Descriptors.InvalidFromRouteType, + method.Locations.FirstOrDefault() ?? Location.None, + meta.Name, + meta.TypeFqn)); + return ParameterClassificationResult.Error; + } + + return ParameterSuccess(in meta, ParameterSource.Route, meta.Name, customBinding: meta.CustomBinding); + } + + /// + /// Classifies [FromQuery] parameter with proper EOE011 diagnostic. + /// + private static ParameterClassificationResult ClassifyFromQueryParameter( + in ParameterMeta meta, + ISymbol method, + ImmutableArray.Builder diagnostics) + { + // Valid: primitive type + if (meta.RouteKind is not null) + return ParameterSuccess(in meta, ParameterSource.Query, queryName: meta.BoundName); + + // Valid: collection of primitives + if (meta is { IsCollection: true, CollectionItemPrimitiveKind: not null }) + return ParameterSuccess(in meta, ParameterSource.Query, queryName: meta.BoundName); + + // Valid: has TryParse + if (meta.CustomBinding is CustomBindingMethod.TryParse or CustomBindingMethod.TryParseWithFormat) + { + return ParameterSuccess(in meta, ParameterSource.Query, queryName: meta.BoundName, + customBinding: meta.CustomBinding); + } + + // EOE011: [FromQuery] only supports primitives or collections of primitives + diagnostics.Add(DiagnosticInfo.Create( + Descriptors.InvalidFromQueryType, + method.Locations.FirstOrDefault() ?? Location.None, + meta.Name, + meta.TypeFqn)); + return ParameterClassificationResult.Error; + } + + /// + /// Classifies [FromHeader] parameter with proper EOE014 diagnostic. + /// + private static ParameterClassificationResult ClassifyFromHeaderParameter( + in ParameterMeta meta, + ISymbol method, + ImmutableArray.Builder diagnostics) + { + // Valid: primitive type (has implicit TryParse) + if (meta.RouteKind is not null) + return ParameterSuccess(in meta, ParameterSource.Header, headerName: meta.BoundName); + + // Valid: collection of strings or primitives + if (meta is { IsCollection: true, CollectionItemPrimitiveKind: not null }) + return ParameterSuccess(in meta, ParameterSource.Header, headerName: meta.BoundName); + + // Valid: has TryParse + if (meta.CustomBinding is CustomBindingMethod.TryParse or CustomBindingMethod.TryParseWithFormat) + { + return ParameterSuccess(in meta, ParameterSource.Header, headerName: meta.BoundName, + customBinding: meta.CustomBinding); + } + + // EOE014: [FromHeader] requires string, primitive with TryParse, or collection thereof + diagnostics.Add(DiagnosticInfo.Create( + Descriptors.InvalidFromHeaderType, + method.Locations.FirstOrDefault() ?? Location.None, + meta.Name, + meta.TypeFqn)); + return ParameterClassificationResult.Error; + } + + private static ParameterClassificationResult ClassifyFromFormParameter( + in ParameterMeta meta, + ITypeSymbol type, + ErrorOrContext context) + { + if (meta.IsFormFile) return ParameterSuccess(in meta, ParameterSource.FormFile, formName: meta.BoundName); + + if (meta.IsFormFileCollection) + return ParameterSuccess(in meta, ParameterSource.FormFiles, formName: meta.BoundName); + + if (meta.IsFormCollection) + return ParameterSuccess(in meta, ParameterSource.FormCollection, formName: meta.BoundName); + + if (meta.RouteKind is not null || meta is { IsCollection: true, CollectionItemPrimitiveKind: not null }) + return ParameterSuccess(in meta, ParameterSource.Form, formName: meta.BoundName); + + // Complex DTO - let BCL handle form binding + return ClassifyFormDtoParameter(in meta, type, context); + } + + private static ParameterClassificationResult ClassifyFormDtoParameter( + in ParameterMeta meta, + ITypeSymbol type, + ErrorOrContext context) + { + // For complex form DTOs, analyze the constructor to build child parameter info + // BCL handles actual binding - we just need structure for code generation + if (type is not INamedTypeSymbol typeSymbol) + { + // Non-named types get simple form binding - BCL will handle/fail at runtime + return ParameterSuccess(in meta, ParameterSource.Form, formName: meta.BoundName); + } + + var constructor = typeSymbol.Constructors + .Where(static c => c.DeclaredAccessibility == Accessibility.Public && !c.IsStatic) + .OrderByDescending(static c => c.Parameters.Length) + .FirstOrDefault(); + + if (constructor is null || constructor.Parameters.Length is 0) + { + // No suitable constructor - simple form binding + return ParameterSuccess(in meta, ParameterSource.Form, formName: meta.BoundName); + } + + // Build child parameters for DTO constructor + var children = ImmutableArray.CreateBuilder(constructor.Parameters.Length); + + foreach (var paramSymbol in constructor.Parameters) + { + var childMeta = CreateParameterMeta(paramSymbol, context); + + ParameterSource childSource; + if (childMeta.IsFormFile) + childSource = ParameterSource.FormFile; + else if (childMeta.IsFormFileCollection) + childSource = ParameterSource.FormFiles; + else + childSource = ParameterSource.Form; + + children.Add(new EndpointParameter( + childMeta.Name, + childMeta.TypeFqn, + childSource, + childMeta.BoundName, + childMeta.IsNullable, + childMeta.IsNonNullableValueType, + childMeta.IsCollection, + childMeta.CollectionItemTypeFqn, + default)); + } + + return new ParameterClassificationResult(IsError: false, new EndpointParameter( + meta.Name, + meta.TypeFqn, + ParameterSource.Form, + meta.BoundName, + meta.IsNullable, + meta.IsNonNullableValueType, +IsCollection: false, +CollectionItemTypeFqn: null, + new EquatableArray(children.ToImmutable()), + CustomBindingMethod.None, + meta.RequiresValidation, + ValidatableProperties: meta.ValidatableProperties)); + } + + /// + /// Classifies [AsParameters] with proper EOE012/EOE013/EOE016/EOE017 diagnostics. + /// + private static ParameterClassificationResult ClassifyAsParameters( + in ParameterMeta meta, + ITypeSymbol type, + ImmutableHashSet routeParameters, + ISymbol method, + ImmutableArray.Builder diagnostics, + ErrorOrContext context, + HttpVerb httpVerb) + { + // EOE017: [AsParameters] cannot be nullable + if (meta.IsNullable) + { + diagnostics.Add(DiagnosticInfo.Create(Descriptors.NullableAsParametersNotSupported, + method.Locations.FirstOrDefault() ?? Location.None, meta.Name)); + return ParameterClassificationResult.Error; + } + + // EOE012: [AsParameters] can only be used on class or struct types + if (type is not INamedTypeSymbol typeSymbol) + { + diagnostics.Add(DiagnosticInfo.Create(Descriptors.InvalidAsParametersType, method, meta.Name, + meta.TypeFqn)); + return ParameterClassificationResult.Error; + } + + var constructor = typeSymbol.Constructors + .Where(static c => c.DeclaredAccessibility == Accessibility.Public && !c.IsStatic) + .OrderByDescending(static c => c.Parameters.Length) + .FirstOrDefault(); + + // EOE013: [AsParameters] type must have an accessible constructor + if (constructor is null) + { + diagnostics.Add(DiagnosticInfo.Create(Descriptors.AsParametersNoConstructor, method, + typeSymbol.ToDisplayString())); + return ParameterClassificationResult.Error; + } + + var children = ImmutableArray.CreateBuilder(); + foreach (var paramSymbol in constructor.Parameters) + { + var childMeta = CreateParameterMeta(paramSymbol, context); + + // EOE016: Nested [AsParameters] not supported + if (childMeta.HasAsParameters) + { + diagnostics.Add(DiagnosticInfo.Create(Descriptors.NestedAsParametersNotSupported, + method.Locations.FirstOrDefault() ?? Location.None, + typeSymbol.ToDisplayString(), + paramSymbol.Name)); + return ParameterClassificationResult.Error; + } + + var result = ClassifyParameter(in childMeta, paramSymbol, routeParameters, method, diagnostics, context, + httpVerb); + + if (result.IsError) return ParameterClassificationResult.Error; + + children.Add(result.Parameter); + } + + return new ParameterClassificationResult(IsError: false, new EndpointParameter( + meta.Name, + meta.TypeFqn, + ParameterSource.AsParameters, +KeyName: null, + meta.IsNullable, + meta.IsNonNullableValueType, +IsCollection: false, +CollectionItemTypeFqn: null, + new EquatableArray(children.ToImmutable()), + CustomBindingMethod.None, + meta.RequiresValidation, + ValidatableProperties: meta.ValidatableProperties)); + } +} diff --git a/src/ErrorOrX.Generators/Core/ErrorOrEndpointGenerator.ParameterBinding.Meta.cs b/src/ErrorOrX.Generators/Core/ErrorOrEndpointGenerator.ParameterBinding.Meta.cs new file mode 100644 index 0000000..f913bfa --- /dev/null +++ b/src/ErrorOrX.Generators/Core/ErrorOrEndpointGenerator.ParameterBinding.Meta.cs @@ -0,0 +1,163 @@ +using ANcpLua.Roslyn.Utilities; +using Microsoft.CodeAnalysis; + +namespace ErrorOr.Generators; + +/// +/// Partial class containing symbol-to-meta extraction for parameter binding. +/// Converts into for downstream classification. +/// +public sealed partial class ErrorOrEndpointGenerator +{ + private static ParameterMeta CreateParameterMeta( + IParameterSymbol parameter, + ErrorOrContext context) + { + var type = parameter.Type; + var typeFqn = type.GetFullyQualifiedName(); + + var flags = BuildFlags(parameter, type, context); + var specialKind = DetectSpecialKind(type, context); + + var (isCollection, itemType, itemPrimitiveKind) = AnalyzeCollectionType(type, context); + if (isCollection) flags |= ParameterFlags.Collection; + + // Determine bound name based on explicit attribute or default to parameter name + var boundName = DetermineBoundName(parameter, flags, context); + + var serviceKey = flags.HasFlag(ParameterFlags.FromKeyedServices) + ? ExtractKeyFromKeyedServiceAttribute(parameter) + : null; + + var validatableProperties = flags.HasFlag(ParameterFlags.RequiresValidation) + ? ErrorOrContext.CollectValidatableProperties(type) + : default; + + return new ParameterMeta( + parameter.Name, + typeFqn, + TryGetRoutePrimitiveKind(type, context), + flags, + specialKind, + serviceKey, + boundName, + itemType?.GetFullyQualifiedName(), + itemPrimitiveKind, + DetectCustomBinding(type, context), + DetectEmptyBodyBehavior(parameter), + validatableProperties); + } + + private static ParameterFlags BuildFlags(IParameterSymbol parameter, ITypeSymbol type, ErrorOrContext context) + { + var flags = ParameterFlags.None; + + if (HasParameterAttribute(parameter, WellKnownTypes.FromBodyAttribute)) + flags |= ParameterFlags.FromBody; + + if (HasParameterAttribute(parameter, WellKnownTypes.FromRouteAttribute)) + flags |= ParameterFlags.FromRoute; + + if (HasParameterAttribute(parameter, WellKnownTypes.FromQueryAttribute)) + flags |= ParameterFlags.FromQuery; + + if (HasParameterAttribute(parameter, WellKnownTypes.FromHeaderAttribute)) + flags |= ParameterFlags.FromHeader; + + if (HasParameterAttribute(parameter, WellKnownTypes.FromFormAttribute)) + flags |= ParameterFlags.FromForm; + + if (HasParameterAttribute(parameter, WellKnownTypes.FromServicesAttribute)) + flags |= ParameterFlags.FromServices; + + if (HasParameterAttribute(parameter, WellKnownTypes.FromKeyedServicesAttribute)) + flags |= ParameterFlags.FromKeyedServices; + + if (HasParameterAttribute(parameter, WellKnownTypes.AsParametersAttribute)) + flags |= ParameterFlags.AsParameters; + + var (isNullable, isNonNullableValueType) = GetParameterNullability(type, parameter.NullableAnnotation); + if (isNullable) flags |= ParameterFlags.Nullable; + + if (isNonNullableValueType) flags |= ParameterFlags.NonNullableValueType; + + if (ErrorOrContext.RequiresValidation(type)) flags |= ParameterFlags.RequiresValidation; + + return flags; + } + + private static SpecialParameterKind DetectSpecialKind(ITypeSymbol type, ErrorOrContext context) + { + if (ErrorOrContext.IsHttpContext(type)) return SpecialParameterKind.HttpContext; + + if (ErrorOrContext.IsCancellationToken(type)) return SpecialParameterKind.CancellationToken; + + if (ErrorOrContext.IsFormFile(type)) return SpecialParameterKind.FormFile; + + if (ErrorOrContext.IsFormFileCollection(type)) return SpecialParameterKind.FormFileCollection; + + if (ErrorOrContext.IsFormCollection(type)) return SpecialParameterKind.FormCollection; + + if (ErrorOrContext.IsStream(type)) return SpecialParameterKind.Stream; + + return ErrorOrContext.IsPipeReader(type) ? SpecialParameterKind.PipeReader : SpecialParameterKind.None; + } + + private static string DetermineBoundName(ISymbol parameter, ParameterFlags flags, ErrorOrContext context) + { + // Try to get explicit name from binding attribute + if (flags.HasFlag(ParameterFlags.FromRoute)) + { + return TryGetAttributeName(parameter, context, WellKnownTypes.FromRouteAttribute) ?? + parameter.Name; + } + + if (flags.HasFlag(ParameterFlags.FromQuery)) + { + return TryGetAttributeName(parameter, context, WellKnownTypes.FromQueryAttribute) ?? + parameter.Name; + } + + if (flags.HasFlag(ParameterFlags.FromHeader)) + { + return TryGetAttributeName(parameter, context, WellKnownTypes.FromHeaderAttribute) ?? + parameter.Name; + } + + if (flags.HasFlag(ParameterFlags.FromForm)) + return TryGetAttributeName(parameter, context, WellKnownTypes.FromFormAttribute) ?? parameter.Name; + + return parameter.Name; + } + + private readonly struct AttributeNameMatcher + { + private readonly string _fullName; + private readonly string _shortName; + private readonly string _shortNameWithoutAttr; + + public AttributeNameMatcher(string fullName) + { + _fullName = fullName; + var lastDot = fullName.LastIndexOf('.'); + _shortName = lastDot >= 0 ? fullName[(lastDot + 1)..] : fullName; + _shortNameWithoutAttr = + _shortName.EndsWithOrdinal("Attribute") ? _shortName[..^"Attribute".Length] : _shortName; + } + + public bool IsMatch(ISymbol? attributeClass) + { + if (attributeClass is not ITypeSymbol typeSymbol) return false; + + var display = typeSymbol.GetFullyQualifiedName(); + + if (display.StartsWithOrdinal("global::")) display = display[8..]; + + // Strict match: Must match FQN or ShortName (if FQN not available/provided) + // We drop loose EndsWith matching to avoid collisions + return display == _fullName || + display == _shortName || + display == _shortNameWithoutAttr; + } + } +} diff --git a/src/ErrorOrX.Generators/Core/ErrorOrEndpointGenerator.ParameterBinding.cs b/src/ErrorOrX.Generators/Core/ErrorOrEndpointGenerator.ParameterBinding.cs index f2a95b8..4baf546 100644 --- a/src/ErrorOrX.Generators/Core/ErrorOrEndpointGenerator.ParameterBinding.cs +++ b/src/ErrorOrX.Generators/Core/ErrorOrEndpointGenerator.ParameterBinding.cs @@ -6,8 +6,10 @@ namespace ErrorOr.Generators; /// -/// Partial class containing parameter binding logic. -/// Includes diagnostic wiring for invalid body, route, query, header, form, and AsParameters bindings. +/// Partial class containing parameter binding entry point and classification dispatcher. +/// Per-source classifiers live in ErrorOrEndpointGenerator.ParameterBinding.Classifiers.cs. +/// Symbol-to-meta extraction lives in ErrorOrEndpointGenerator.ParameterBinding.Meta.cs. +/// Type-shape helpers live in ErrorOrEndpointGenerator.ParameterBinding.TypeAnalysis.cs. /// public sealed partial class ErrorOrEndpointGenerator { @@ -49,127 +51,6 @@ private static ParameterMeta[] BuildParameterMetas( return metas; } - private static ParameterMeta CreateParameterMeta( - IParameterSymbol parameter, - ErrorOrContext context) - { - var type = parameter.Type; - var typeFqn = type.GetFullyQualifiedName(); - - var flags = BuildFlags(parameter, type, context); - var specialKind = DetectSpecialKind(type, context); - - var (isCollection, itemType, itemPrimitiveKind) = AnalyzeCollectionType(type, context); - if (isCollection) flags |= ParameterFlags.Collection; - - // Determine bound name based on explicit attribute or default to parameter name - var boundName = DetermineBoundName(parameter, flags, context); - - var serviceKey = flags.HasFlag(ParameterFlags.FromKeyedServices) - ? ExtractKeyFromKeyedServiceAttribute(parameter) - : null; - - var validatableProperties = flags.HasFlag(ParameterFlags.RequiresValidation) - ? ErrorOrContext.CollectValidatableProperties(type) - : default; - - return new ParameterMeta( - parameter.Name, - typeFqn, - TryGetRoutePrimitiveKind(type, context), - flags, - specialKind, - serviceKey, - boundName, - itemType?.GetFullyQualifiedName(), - itemPrimitiveKind, - DetectCustomBinding(type, context), - DetectEmptyBodyBehavior(parameter), - validatableProperties); - } - - private static ParameterFlags BuildFlags(IParameterSymbol parameter, ITypeSymbol type, ErrorOrContext context) - { - var flags = ParameterFlags.None; - - if (HasParameterAttribute(parameter, WellKnownTypes.FromBodyAttribute)) - flags |= ParameterFlags.FromBody; - - if (HasParameterAttribute(parameter, WellKnownTypes.FromRouteAttribute)) - flags |= ParameterFlags.FromRoute; - - if (HasParameterAttribute(parameter, WellKnownTypes.FromQueryAttribute)) - flags |= ParameterFlags.FromQuery; - - if (HasParameterAttribute(parameter, WellKnownTypes.FromHeaderAttribute)) - flags |= ParameterFlags.FromHeader; - - if (HasParameterAttribute(parameter, WellKnownTypes.FromFormAttribute)) - flags |= ParameterFlags.FromForm; - - if (HasParameterAttribute(parameter, WellKnownTypes.FromServicesAttribute)) - flags |= ParameterFlags.FromServices; - - if (HasParameterAttribute(parameter, WellKnownTypes.FromKeyedServicesAttribute)) - flags |= ParameterFlags.FromKeyedServices; - - if (HasParameterAttribute(parameter, WellKnownTypes.AsParametersAttribute)) - flags |= ParameterFlags.AsParameters; - - var (isNullable, isNonNullableValueType) = GetParameterNullability(type, parameter.NullableAnnotation); - if (isNullable) flags |= ParameterFlags.Nullable; - - if (isNonNullableValueType) flags |= ParameterFlags.NonNullableValueType; - - if (ErrorOrContext.RequiresValidation(type)) flags |= ParameterFlags.RequiresValidation; - - return flags; - } - - private static SpecialParameterKind DetectSpecialKind(ITypeSymbol type, ErrorOrContext context) - { - if (ErrorOrContext.IsHttpContext(type)) return SpecialParameterKind.HttpContext; - - if (ErrorOrContext.IsCancellationToken(type)) return SpecialParameterKind.CancellationToken; - - if (ErrorOrContext.IsFormFile(type)) return SpecialParameterKind.FormFile; - - if (ErrorOrContext.IsFormFileCollection(type)) return SpecialParameterKind.FormFileCollection; - - if (ErrorOrContext.IsFormCollection(type)) return SpecialParameterKind.FormCollection; - - if (ErrorOrContext.IsStream(type)) return SpecialParameterKind.Stream; - - return ErrorOrContext.IsPipeReader(type) ? SpecialParameterKind.PipeReader : SpecialParameterKind.None; - } - - private static string DetermineBoundName(ISymbol parameter, ParameterFlags flags, ErrorOrContext context) - { - // Try to get explicit name from binding attribute - if (flags.HasFlag(ParameterFlags.FromRoute)) - { - return TryGetAttributeName(parameter, context, WellKnownTypes.FromRouteAttribute) ?? - parameter.Name; - } - - if (flags.HasFlag(ParameterFlags.FromQuery)) - { - return TryGetAttributeName(parameter, context, WellKnownTypes.FromQueryAttribute) ?? - parameter.Name; - } - - if (flags.HasFlag(ParameterFlags.FromHeader)) - { - return TryGetAttributeName(parameter, context, WellKnownTypes.FromHeaderAttribute) ?? - parameter.Name; - } - - if (flags.HasFlag(ParameterFlags.FromForm)) - return TryGetAttributeName(parameter, context, WellKnownTypes.FromFormAttribute) ?? parameter.Name; - - return parameter.Name; - } - private static ParameterBindingResult BuildEndpointParameters( ParameterMeta[] metas, ImmutableArray parameters, @@ -327,285 +208,6 @@ private static ParameterClassificationResult InferParameterSource( return ParameterSuccess(in meta, ParameterSource.Service); } - /// - /// Classifies [FromRoute] parameter with proper EOE010 diagnostic. - /// - private static ParameterClassificationResult ClassifyFromRouteParameter( - in ParameterMeta meta, - ImmutableHashSet routeParameters, - ISymbol method, - ImmutableArray.Builder diagnostics) - { - var hasTryParse = meta.CustomBinding is CustomBindingMethod.TryParse or CustomBindingMethod.TryParseWithFormat; - - // EOE010: [FromRoute] requires primitive or TryParse - if (meta.RouteKind is null && !hasTryParse) - { - diagnostics.Add(DiagnosticInfo.Create( - Descriptors.InvalidFromRouteType, - method.Locations.FirstOrDefault() ?? Location.None, - meta.Name, - meta.TypeFqn)); - return ParameterClassificationResult.Error; - } - - return ParameterSuccess(in meta, ParameterSource.Route, meta.BoundName, - customBinding: meta.CustomBinding); - } - - /// - /// Classifies implicit route parameter with proper EOE010 diagnostic. - /// - private static ParameterClassificationResult ClassifyImplicitRouteParameter( - in ParameterMeta meta, - ISymbol method, - ImmutableArray.Builder diagnostics) - { - var hasTryParse = meta.CustomBinding is CustomBindingMethod.TryParse or CustomBindingMethod.TryParseWithFormat; - - // EOE010: Route parameters must use supported primitive types or TryParse - if (meta.RouteKind is null && !hasTryParse) - { - diagnostics.Add(DiagnosticInfo.Create( - Descriptors.InvalidFromRouteType, - method.Locations.FirstOrDefault() ?? Location.None, - meta.Name, - meta.TypeFqn)); - return ParameterClassificationResult.Error; - } - - return ParameterSuccess(in meta, ParameterSource.Route, meta.Name, customBinding: meta.CustomBinding); - } - - /// - /// Classifies [FromQuery] parameter with proper EOE011 diagnostic. - /// - private static ParameterClassificationResult ClassifyFromQueryParameter( - in ParameterMeta meta, - ISymbol method, - ImmutableArray.Builder diagnostics) - { - // Valid: primitive type - if (meta.RouteKind is not null) - return ParameterSuccess(in meta, ParameterSource.Query, queryName: meta.BoundName); - - // Valid: collection of primitives - if (meta is { IsCollection: true, CollectionItemPrimitiveKind: not null }) - return ParameterSuccess(in meta, ParameterSource.Query, queryName: meta.BoundName); - - // Valid: has TryParse - if (meta.CustomBinding is CustomBindingMethod.TryParse or CustomBindingMethod.TryParseWithFormat) - { - return ParameterSuccess(in meta, ParameterSource.Query, queryName: meta.BoundName, - customBinding: meta.CustomBinding); - } - - // EOE011: [FromQuery] only supports primitives or collections of primitives - diagnostics.Add(DiagnosticInfo.Create( - Descriptors.InvalidFromQueryType, - method.Locations.FirstOrDefault() ?? Location.None, - meta.Name, - meta.TypeFqn)); - return ParameterClassificationResult.Error; - } - - /// - /// Classifies [FromHeader] parameter with proper EOE014 diagnostic. - /// - private static ParameterClassificationResult ClassifyFromHeaderParameter( - in ParameterMeta meta, - ISymbol method, - ImmutableArray.Builder diagnostics) - { - // Valid: primitive type (has implicit TryParse) - if (meta.RouteKind is not null) - return ParameterSuccess(in meta, ParameterSource.Header, headerName: meta.BoundName); - - // Valid: collection of strings or primitives - if (meta is { IsCollection: true, CollectionItemPrimitiveKind: not null }) - return ParameterSuccess(in meta, ParameterSource.Header, headerName: meta.BoundName); - - // Valid: has TryParse - if (meta.CustomBinding is CustomBindingMethod.TryParse or CustomBindingMethod.TryParseWithFormat) - { - return ParameterSuccess(in meta, ParameterSource.Header, headerName: meta.BoundName, - customBinding: meta.CustomBinding); - } - - // EOE014: [FromHeader] requires string, primitive with TryParse, or collection thereof - diagnostics.Add(DiagnosticInfo.Create( - Descriptors.InvalidFromHeaderType, - method.Locations.FirstOrDefault() ?? Location.None, - meta.Name, - meta.TypeFqn)); - return ParameterClassificationResult.Error; - } - - private static ParameterClassificationResult ClassifyFromFormParameter( - in ParameterMeta meta, - ITypeSymbol type, - ErrorOrContext context) - { - if (meta.IsFormFile) return ParameterSuccess(in meta, ParameterSource.FormFile, formName: meta.BoundName); - - if (meta.IsFormFileCollection) - return ParameterSuccess(in meta, ParameterSource.FormFiles, formName: meta.BoundName); - - if (meta.IsFormCollection) - return ParameterSuccess(in meta, ParameterSource.FormCollection, formName: meta.BoundName); - - if (meta.RouteKind is not null || meta is { IsCollection: true, CollectionItemPrimitiveKind: not null }) - return ParameterSuccess(in meta, ParameterSource.Form, formName: meta.BoundName); - - // Complex DTO - let BCL handle form binding - return ClassifyFormDtoParameter(in meta, type, context); - } - - private static ParameterClassificationResult ClassifyFormDtoParameter( - in ParameterMeta meta, - ITypeSymbol type, - ErrorOrContext context) - { - // For complex form DTOs, analyze the constructor to build child parameter info - // BCL handles actual binding - we just need structure for code generation - if (type is not INamedTypeSymbol typeSymbol) - { - // Non-named types get simple form binding - BCL will handle/fail at runtime - return ParameterSuccess(in meta, ParameterSource.Form, formName: meta.BoundName); - } - - var constructor = typeSymbol.Constructors - .Where(static c => c.DeclaredAccessibility == Accessibility.Public && !c.IsStatic) - .OrderByDescending(static c => c.Parameters.Length) - .FirstOrDefault(); - - if (constructor is null || constructor.Parameters.Length is 0) - { - // No suitable constructor - simple form binding - return ParameterSuccess(in meta, ParameterSource.Form, formName: meta.BoundName); - } - - // Build child parameters for DTO constructor - var children = ImmutableArray.CreateBuilder(constructor.Parameters.Length); - - foreach (var paramSymbol in constructor.Parameters) - { - var childMeta = CreateParameterMeta(paramSymbol, context); - - ParameterSource childSource; - if (childMeta.IsFormFile) - childSource = ParameterSource.FormFile; - else if (childMeta.IsFormFileCollection) - childSource = ParameterSource.FormFiles; - else - childSource = ParameterSource.Form; - - children.Add(new EndpointParameter( - childMeta.Name, - childMeta.TypeFqn, - childSource, - childMeta.BoundName, - childMeta.IsNullable, - childMeta.IsNonNullableValueType, - childMeta.IsCollection, - childMeta.CollectionItemTypeFqn, - default)); - } - - return new ParameterClassificationResult(IsError: false, new EndpointParameter( - meta.Name, - meta.TypeFqn, - ParameterSource.Form, - meta.BoundName, - meta.IsNullable, - meta.IsNonNullableValueType, -IsCollection: false, -CollectionItemTypeFqn: null, - new EquatableArray(children.ToImmutable()), - CustomBindingMethod.None, - meta.RequiresValidation, - ValidatableProperties: meta.ValidatableProperties)); - } - - /// - /// Classifies [AsParameters] with proper EOE012/EOE013/EOE016/EOE017 diagnostics. - /// - private static ParameterClassificationResult ClassifyAsParameters( - in ParameterMeta meta, - ITypeSymbol type, - ImmutableHashSet routeParameters, - ISymbol method, - ImmutableArray.Builder diagnostics, - ErrorOrContext context, - HttpVerb httpVerb) - { - // EOE017: [AsParameters] cannot be nullable - if (meta.IsNullable) - { - diagnostics.Add(DiagnosticInfo.Create(Descriptors.NullableAsParametersNotSupported, - method.Locations.FirstOrDefault() ?? Location.None, meta.Name)); - return ParameterClassificationResult.Error; - } - - // EOE012: [AsParameters] can only be used on class or struct types - if (type is not INamedTypeSymbol typeSymbol) - { - diagnostics.Add(DiagnosticInfo.Create(Descriptors.InvalidAsParametersType, method, meta.Name, - meta.TypeFqn)); - return ParameterClassificationResult.Error; - } - - var constructor = typeSymbol.Constructors - .Where(static c => c.DeclaredAccessibility == Accessibility.Public && !c.IsStatic) - .OrderByDescending(static c => c.Parameters.Length) - .FirstOrDefault(); - - // EOE013: [AsParameters] type must have an accessible constructor - if (constructor is null) - { - diagnostics.Add(DiagnosticInfo.Create(Descriptors.AsParametersNoConstructor, method, - typeSymbol.ToDisplayString())); - return ParameterClassificationResult.Error; - } - - var children = ImmutableArray.CreateBuilder(); - foreach (var paramSymbol in constructor.Parameters) - { - var childMeta = CreateParameterMeta(paramSymbol, context); - - // EOE016: Nested [AsParameters] not supported - if (childMeta.HasAsParameters) - { - diagnostics.Add(DiagnosticInfo.Create(Descriptors.NestedAsParametersNotSupported, - method.Locations.FirstOrDefault() ?? Location.None, - typeSymbol.ToDisplayString(), - paramSymbol.Name)); - return ParameterClassificationResult.Error; - } - - var result = ClassifyParameter(in childMeta, paramSymbol, routeParameters, method, diagnostics, context, - httpVerb); - - if (result.IsError) return ParameterClassificationResult.Error; - - children.Add(result.Parameter); - } - - return new ParameterClassificationResult(IsError: false, new EndpointParameter( - meta.Name, - meta.TypeFqn, - ParameterSource.AsParameters, -KeyName: null, - meta.IsNullable, - meta.IsNonNullableValueType, -IsCollection: false, -CollectionItemTypeFqn: null, - new EquatableArray(children.ToImmutable()), - CustomBindingMethod.None, - meta.RequiresValidation, - ValidatableProperties: meta.ValidatableProperties)); - } - private static ParameterClassificationResult ParameterSuccess( in ParameterMeta meta, ParameterSource source, @@ -638,35 +240,4 @@ private readonly record struct ParameterClassificationResult(bool IsError, Endpo { public static readonly ParameterClassificationResult Error = new(IsError: true, default); } - - private readonly struct AttributeNameMatcher - { - private readonly string _fullName; - private readonly string _shortName; - private readonly string _shortNameWithoutAttr; - - public AttributeNameMatcher(string fullName) - { - _fullName = fullName; - var lastDot = fullName.LastIndexOf('.'); - _shortName = lastDot >= 0 ? fullName[(lastDot + 1)..] : fullName; - _shortNameWithoutAttr = - _shortName.EndsWithOrdinal("Attribute") ? _shortName[..^"Attribute".Length] : _shortName; - } - - public bool IsMatch(ISymbol? attributeClass) - { - if (attributeClass is not ITypeSymbol typeSymbol) return false; - - var display = typeSymbol.GetFullyQualifiedName(); - - if (display.StartsWithOrdinal("global::")) display = display[8..]; - - // Strict match: Must match FQN or ShortName (if FQN not available/provided) - // We drop loose EndsWith matching to avoid collisions - return display == _fullName || - display == _shortName || - display == _shortNameWithoutAttr; - } - } } diff --git a/src/ErrorOrX.Generators/Core/OpenApiTransformerGenerator.Emitter.cs b/src/ErrorOrX.Generators/Core/OpenApiTransformerGenerator.Emitter.cs new file mode 100644 index 0000000..4bfa78c --- /dev/null +++ b/src/ErrorOrX.Generators/Core/OpenApiTransformerGenerator.Emitter.cs @@ -0,0 +1,367 @@ +using ANcpLua.Roslyn.Utilities; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.Text; + +namespace ErrorOr.Generators; + +/// +/// Source emission for OpenAPI transformers: tag transformers, the operation transformer, +/// the schema transformer, and the AddErrorOrOpenApi registration extension. +/// +public sealed partial class OpenApiTransformerGenerator +{ + private static void Emit( + SourceProductionContext spc, + ImmutableArray endpoints, + ImmutableArray types) + { + if (endpoints.IsDefaultOrEmpty) return; + + var code = new StringBuilder(); + code.AppendLine("// "); + code.AppendLine("#nullable enable"); + code.AppendLine(); + code.AppendLine("using System;"); + code.AppendLine("using System.Collections.Frozen;"); + code.AppendLine("using System.Collections.Generic;"); + code.AppendLine("using System.Threading;"); + code.AppendLine("using System.Threading.Tasks;"); + code.AppendLine("using Microsoft.AspNetCore.OpenApi;"); + code.AppendLine("using Microsoft.AspNetCore.Routing;"); + code.AppendLine("using Microsoft.Extensions.DependencyInjection;"); + code.AppendLine("using Microsoft.OpenApi;"); + code.AppendLine(); + code.AppendLine("namespace ErrorOr.Generated;"); + code.AppendLine(); + + // Collect unique tags (1 attribute → 1 transformer) + var tags = endpoints.Select(static e => e.TagName).Distinct(StringComparer.Ordinal) + .OrderBy(static t => t, StringComparer.Ordinal).ToList(); + + // Emit tag transformers (strict 1:1 - one transformer per unique tag) + foreach (var tag in tags) EmitTagTransformer(code, tag); + + // Emit operation transformer (applies XML doc summaries) + var hasOperationDocs = EmitOperationTransformer(code, endpoints); + + // Emit schema transformer (applies type descriptions) + var hasTypeDocs = false; + if (!types.IsDefaultOrEmpty) hasTypeDocs = EmitSchemaTransformer(code, types); + + // Emit registration extension + EmitRegistrationExtension(code, tags, hasOperationDocs, hasTypeDocs); + + spc.AddSource("OpenApiTransformers.g.cs", SourceText.From(code.ToString(), Encoding.UTF8)); + } + + private static void EmitTagTransformer(StringBuilder code, string tagName) + { + var safeTagName = tagName.SanitizeIdentifier(); + code.AppendLine("/// "); + code.AppendLine($"/// Document transformer for tag: {tagName}"); + code.AppendLine($"/// Generated from: [ErrorOrEndpoint] attribute on *{tagName}Endpoints class"); + code.AppendLine("/// "); + code.AppendLine($"file sealed class Tag_{safeTagName}_Transformer : IOpenApiDocumentTransformer"); + code.AppendLine("{"); + code.AppendLine(" public Task TransformAsync("); + code.AppendLine(" OpenApiDocument document,"); + code.AppendLine(" OpenApiDocumentTransformerContext context,"); + code.AppendLine(" CancellationToken cancellationToken)"); + code.AppendLine(" {"); + // OpenApiDocument.Tags setter auto-wraps with OpenApiTagComparer.Instance + // which handles deduplication by Name - no manual .Any() check needed + code.AppendLine(" document.Tags ??= new HashSet();"); + code.AppendLine($" document.Tags.Add(new OpenApiTag {{ Name = \"{tagName}\" }});"); + code.AppendLine(" return Task.CompletedTask;"); + code.AppendLine(" }"); + code.AppendLine("}"); + code.AppendLine(); + } + + private static bool EmitOperationTransformer(StringBuilder code, ImmutableArray endpoints) + { + // Collect operations with XML docs (summary/description OR parameter docs) + var opsWithDocs = endpoints + .Where(static e => !string.IsNullOrEmpty(e.Summary) || !string.IsNullOrEmpty(e.Description) || + !e.ParameterDocs.IsDefaultOrEmpty) + .OrderBy(static e => e.Pattern, StringComparer.Ordinal) + .ThenBy(static e => e.HttpMethod, StringComparer.Ordinal).ToList(); + + // Collect operations with OpenAPI parameter definitions + var opsWithParams = endpoints + .Where(static e => !e.Parameters.IsDefaultOrEmpty) + .OrderBy(static e => e.OperationId, StringComparer.Ordinal) + .ToList(); + + if (opsWithDocs.Count is 0 && opsWithParams.Count is 0) return false; + + // Collect operations with parameter docs + var opsWithParamDocs = opsWithDocs + .Where(static e => !e.ParameterDocs.IsDefaultOrEmpty) + .ToList(); + + code.AppendLine("/// "); + code.AppendLine( + "/// Operation transformer that applies XML documentation and parameter definitions to operations."); + code.AppendLine("/// Each entry is a strict 1:1 mapping from handler signature to operation metadata."); + code.AppendLine("/// "); + code.AppendLine("file sealed class XmlDocOperationTransformer : IOpenApiOperationTransformer"); + code.AppendLine("{"); + code.AppendLine(" // Pre-computed metadata from XML docs (compile-time extraction)"); + code.AppendLine( + " private static readonly FrozenDictionary OperationDocs ="); + code.AppendLine(" new Dictionary"); + code.AppendLine(" {"); + + foreach (var op in opsWithDocs.Where(static e => + !string.IsNullOrEmpty(e.Summary) || !string.IsNullOrEmpty(e.Description))) + { + var summary = op.Summary is not null ? $"\"{op.Summary.EscapeCSharpString()}\"" : "null"; + var description = op.Description is not null ? $"\"{op.Description.EscapeCSharpString()}\"" : "null"; + code.AppendLine($" [\"{op.OperationId}\"] = ({summary}, {description}),"); + } + + code.AppendLine(" }.ToFrozenDictionary(StringComparer.Ordinal);"); + code.AppendLine(); + + // Emit parameter docs dictionary + code.AppendLine(" // Pre-computed parameter descriptions from XML tags"); + code.AppendLine( + " private static readonly FrozenDictionary> ParameterDocs ="); + code.AppendLine(" new Dictionary>"); + code.AppendLine(" {"); + + foreach (var op in opsWithParamDocs) + { + code.AppendLine($" [\"{op.OperationId}\"] = new Dictionary"); + code.AppendLine(" {"); + foreach (var (paramName, paramDesc) in op.ParameterDocs.AsImmutableArray()) + { + code.AppendLine( + $" [\"{paramName.EscapeCSharpString()}\"] = \"{paramDesc.EscapeCSharpString()}\","); + } + + code.AppendLine(" }.ToFrozenDictionary(StringComparer.Ordinal),"); + } + + code.AppendLine(" }.ToFrozenDictionary(StringComparer.Ordinal);"); + code.AppendLine(); + + // Emit parameter definitions dictionary + if (opsWithParams.Count > 0) + { + code.AppendLine(" // Pre-computed parameter definitions from handler signatures"); + code.AppendLine( + " private static readonly FrozenDictionary ParameterDefs ="); + code.AppendLine( + " new Dictionary"); + code.AppendLine(" {"); + + foreach (var op in opsWithParams) + { + code.Append($" [\"{op.OperationId}\"] = [("); + var first = true; + foreach (var p in op.Parameters.AsImmutableArray()) + { + if (!first) code.Append("), ("); + + var format = p.SchemaFormat is not null ? $"\"{p.SchemaFormat}\"" : "null"; + var locationEnum = p.Location switch + { + "path" => "ParameterLocation.Path", + "header" => "ParameterLocation.Header", + _ => "ParameterLocation.Query" + }; + var schemaTypeEnum = ToJsonSchemaTypeEnum(p.SchemaType); + code.Append( + $"\"{p.Name}\", {locationEnum}, {(p.Required ? "true" : "false")}, {schemaTypeEnum}, {format}"); + first = false; + } + + code.AppendLine(")],"); + } + + code.AppendLine(" }.ToFrozenDictionary(StringComparer.Ordinal);"); + code.AppendLine(); + } + + code.AppendLine(" public Task TransformAsync("); + code.AppendLine(" OpenApiOperation operation,"); + code.AppendLine(" OpenApiOperationTransformerContext context,"); + code.AppendLine(" CancellationToken cancellationToken)"); + code.AppendLine(" {"); + code.AppendLine(" string? operationId = null;"); + code.AppendLine(" var metadata = context.Description.ActionDescriptor?.EndpointMetadata;"); + code.AppendLine(" if (metadata is not null)"); + code.AppendLine(" {"); + code.AppendLine(" for (var i = 0; i < metadata.Count; i++)"); + code.AppendLine(" {"); + code.AppendLine(" if (metadata[i] is IEndpointNameMetadata nameMetadata)"); + code.AppendLine(" {"); + code.AppendLine(" operationId = nameMetadata.EndpointName;"); + code.AppendLine(" break;"); + code.AppendLine(" }"); + code.AppendLine(" }"); + code.AppendLine(" }"); + code.AppendLine(); + code.AppendLine(" if (operationId is null)"); + code.AppendLine(" return Task.CompletedTask;"); + code.AppendLine(); + code.AppendLine(" // Apply summary and description"); + code.AppendLine(" if (OperationDocs.TryGetValue(operationId, out var docs))"); + code.AppendLine(" {"); + code.AppendLine(" if (docs.Summary is not null)"); + code.AppendLine(" operation.Summary ??= docs.Summary;"); + code.AppendLine(" if (docs.Description is not null)"); + code.AppendLine(" operation.Description ??= docs.Description;"); + code.AppendLine(" }"); + code.AppendLine(); + + // Emit parameter definitions application code + if (opsWithParams.Count > 0) + { + code.AppendLine(" // Add parameter definitions from handler signatures"); + code.AppendLine(" if (ParameterDefs.TryGetValue(operationId, out var paramDefs))"); + code.AppendLine(" {"); + code.AppendLine(" operation.Parameters ??= [];"); + code.AppendLine( + " foreach (var (pName, pLocation, pRequired, pSchemaType, pSchemaFormat) in paramDefs)"); + code.AppendLine(" {"); + code.AppendLine(" var schema = new OpenApiSchema { Type = pSchemaType };"); + code.AppendLine(" if (pSchemaFormat is not null) schema.Format = pSchemaFormat;"); + code.AppendLine(" operation.Parameters.Add(new OpenApiParameter"); + code.AppendLine(" {"); + code.AppendLine(" Name = pName,"); + code.AppendLine(" In = pLocation,"); + code.AppendLine(" Required = pRequired,"); + code.AppendLine(" Schema = schema"); + code.AppendLine(" });"); + code.AppendLine(" }"); + code.AppendLine(" }"); + code.AppendLine(); + } + + code.AppendLine(" // Apply parameter descriptions"); + code.AppendLine( + " if (ParameterDocs.TryGetValue(operationId, out var paramDocs) && operation.Parameters is not null)"); + code.AppendLine(" {"); + code.AppendLine(" foreach (var param in operation.Parameters)"); + code.AppendLine(" {"); + code.AppendLine( + " if (param.Name is not null && paramDocs.TryGetValue(param.Name, out var paramDesc))"); + code.AppendLine(" {"); + code.AppendLine(" param.Description ??= paramDesc;"); + code.AppendLine(" }"); + code.AppendLine(" }"); + code.AppendLine(" }"); + code.AppendLine(); + code.AppendLine(" return Task.CompletedTask;"); + code.AppendLine(" }"); + code.AppendLine("}"); + code.AppendLine(); + + return true; + } + + private static bool EmitSchemaTransformer(StringBuilder code, ImmutableArray types) + { + var typesWithDocs = types.OrderBy(static t => t.TypeKey, StringComparer.Ordinal).ToList(); + + if (typesWithDocs.Count is 0) return false; + + code.AppendLine("/// "); + code.AppendLine("/// Schema transformer that applies type XML documentation to schemas."); + code.AppendLine("/// Each entry is a strict 1:1 mapping from XML doc to schema description."); + code.AppendLine("/// AOT-safe: Uses Type as dictionary key (no runtime reflection)."); + code.AppendLine("/// "); + code.AppendLine("file sealed class XmlDocSchemaTransformer : IOpenApiSchemaTransformer"); + code.AppendLine("{"); + code.AppendLine( + " // Pre-computed type descriptions from XML docs (AOT-safe: Type keys resolved at compile-time)"); + code.AppendLine(" private static readonly FrozenDictionary TypeDescriptions ="); + code.AppendLine(" new Dictionary"); + code.AppendLine(" {"); + + foreach (var type in typesWithDocs) + { + // Convert reflection-style name (Namespace.Outer+Inner) to C# typeof expression (global::Namespace.Outer.Inner) + var typeofExpr = ConvertToTypeofExpression(type.TypeKey); + code.AppendLine($" [typeof({typeofExpr})] = \"{type.Description.EscapeCSharpString()}\","); + } + + code.AppendLine(" }.ToFrozenDictionary();"); + code.AppendLine(); + code.AppendLine(" public Task TransformAsync("); + code.AppendLine(" OpenApiSchema schema,"); + code.AppendLine(" OpenApiSchemaTransformerContext context,"); + code.AppendLine(" CancellationToken cancellationToken)"); + code.AppendLine(" {"); + // AOT-safe: Direct Type lookup without reflection + code.AppendLine(" var type = context.JsonTypeInfo.Type;"); + code.AppendLine(" // For generic types, lookup the generic type definition"); + code.AppendLine(" var lookupType = type.IsGenericType ? type.GetGenericTypeDefinition() : type;"); + code.AppendLine(" if (TypeDescriptions.TryGetValue(lookupType, out var description))"); + code.AppendLine(" {"); + code.AppendLine(" schema.Description ??= description;"); + code.AppendLine(" }"); + code.AppendLine(" return Task.CompletedTask;"); + code.AppendLine(" }"); + code.AppendLine("}"); + code.AppendLine(); + + return true; + } + + private static void EmitRegistrationExtension( + StringBuilder code, + List tags, + bool hasOperationDocs, + bool hasTypeDocs) + { + code.AppendLine("/// "); + code.AppendLine("/// Extension methods for registering generated OpenAPI transformers."); + code.AppendLine("/// "); + code.AppendLine("public static class GeneratedOpenApiExtensions"); + code.AppendLine("{"); + code.AppendLine(" /// "); + code.AppendLine(" /// Adds OpenAPI with generated transformers for ErrorOr endpoints."); + code.AppendLine(" /// Each transformer is registered following the strict 1:1 mapping rule."); + code.AppendLine(" /// "); + code.AppendLine(" public static IServiceCollection AddErrorOrOpenApi("); + code.AppendLine(" this IServiceCollection services,"); + code.AppendLine(" string documentName = \"v1\")"); + code.AppendLine(" {"); + code.AppendLine(" services.AddOpenApi(documentName, options =>"); + code.AppendLine(" {"); + + // Register tag transformers (1:1 - one per tag) + foreach (var tag in tags) + { + var safeTagName = tag.SanitizeIdentifier(); + code.AppendLine($" // Tag: {tag}"); + code.AppendLine($" options.AddDocumentTransformer(new Tag_{safeTagName}_Transformer());"); + } + + // Register operation transformer if we have docs + if (hasOperationDocs) + { + code.AppendLine(); + code.AppendLine(" // XML doc summaries → operation metadata"); + code.AppendLine(" options.AddOperationTransformer(new XmlDocOperationTransformer());"); + } + + // Register schema transformer if we have type docs + if (hasTypeDocs) + { + code.AppendLine(); + code.AppendLine(" // XML doc summaries → schema descriptions"); + code.AppendLine(" options.AddSchemaTransformer(new XmlDocSchemaTransformer());"); + } + + code.AppendLine(" });"); + code.AppendLine(); + code.AppendLine(" return services;"); + code.AppendLine(" }"); + code.AppendLine("}"); + } +} diff --git a/src/ErrorOrX.Generators/Core/OpenApiTransformerGenerator.Extractor.cs b/src/ErrorOrX.Generators/Core/OpenApiTransformerGenerator.Extractor.cs new file mode 100644 index 0000000..325ccf9 --- /dev/null +++ b/src/ErrorOrX.Generators/Core/OpenApiTransformerGenerator.Extractor.cs @@ -0,0 +1,388 @@ +using ANcpLua.Roslyn.Utilities; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; + +namespace ErrorOr.Generators; + +/// +/// Metadata extraction logic for the OpenAPI transformer generator: pulls XML doc, +/// parameter definitions, and type descriptions out of the compilation. +/// +public sealed partial class OpenApiTransformerGenerator +{ + private static OpenApiEndpointInfo? ExtractOpenApiMetadata( + GeneratorAttributeSyntaxContext ctx, + CancellationToken ct) + { + ct.ThrowIfCancellationRequested(); + + if (ctx.TargetSymbol is not IMethodSymbol { IsStatic: true } method) return null; + + // Extract HTTP method and pattern from attribute + // Combined null check: attr exists AND has a valid AttributeClass + if (ctx.Attributes.FirstOrDefault() is not { AttributeClass: { } attrClass } attr) return null; + + var attrClassName = attrClass.ToDisplayString(); + + var (httpMethod, pattern) = attrClassName switch + { + WellKnownTypes.GetAttribute => (WellKnownTypes.HttpMethod.Get, GetPattern(attr)), + WellKnownTypes.PostAttribute => (WellKnownTypes.HttpMethod.Post, GetPattern(attr)), + WellKnownTypes.PutAttribute => (WellKnownTypes.HttpMethod.Put, GetPattern(attr)), + WellKnownTypes.DeleteAttribute => (WellKnownTypes.HttpMethod.Delete, GetPattern(attr)), + WellKnownTypes.PatchAttribute => (WellKnownTypes.HttpMethod.Patch, GetPattern(attr)), + WellKnownTypes.HeadAttribute => (WellKnownTypes.HttpMethod.Head, GetPattern(attr)), + WellKnownTypes.OptionsAttribute => (WellKnownTypes.HttpMethod.Options, GetPattern(attr)), + WellKnownTypes.TraceAttribute => (WellKnownTypes.HttpMethod.Trace, GetPattern(attr)), + WellKnownTypes.ErrorOrEndpointAttribute => GetBaseAttributeInfo(attr), + _ => (null, null) + }; + + if (httpMethod is null || pattern is null) return null; + + // Extract XML documentation + var xmlDoc = method.GetDocumentationCommentXml(cancellationToken: ct); + var (summary, description) = ParseXmlDoc(xmlDoc); + var parameterDocs = ParseParamTags(xmlDoc); + + // Extract containing type info for tag generation + var containingType = method.ContainingType; + var containingTypeFqn = containingType.GetFullyQualifiedName(); + var (tagName, operationId) = EndpointNameHelper.GetEndpointIdentity(containingTypeFqn, method.Name); + + var parameters = ExtractParameterDefinitions(method, pattern); + + return new OpenApiEndpointInfo( + operationId, + tagName, + summary, + description, + httpMethod.ToUpperInvariant(), + pattern, + new EquatableArray<(string, string)>(parameterDocs), + new EquatableArray(parameters)); + } + + private static (string? summary, string? description) ParseXmlDoc(string? xml) + { + if (xml is null || string.IsNullOrWhiteSpace(xml)) return (null, null); + + string? summary = null; + string? description = null; + + // Simple XML parsing for summary and remarks + var summaryStart = xml.IndexOfOrdinal(""); + var summaryEnd = xml.IndexOfOrdinal(""); + if (summaryStart >= 0 && summaryEnd > summaryStart) + { + summary = xml.Substring(summaryStart + 9, summaryEnd - summaryStart - 9) + .Trim() + .Replace("\r\n", " ") + .Replace('\n', ' ') + .Trim(); + } + + var remarksStart = xml.IndexOfOrdinal(""); + var remarksEnd = xml.IndexOfOrdinal(""); + if (remarksStart >= 0 && remarksEnd > remarksStart) + { + description = xml.Substring(remarksStart + 9, remarksEnd - remarksStart - 9) + .Trim() + .Replace("\r\n", " ") + .Replace('\n', ' ') + .Trim(); + } + + return (summary, description); + } + + private static ImmutableArray<(string ParamName, string Description)> ParseParamTags(string? xml) + { + if (xml is null || string.IsNullOrWhiteSpace(xml)) return ImmutableArray<(string, string)>.Empty; + + var parameters = new List<(string, string)>(); + var searchPos = 0; + + while (true) + { + var paramStart = xml.IndexOf("", nameEnd, StringComparison.Ordinal); + if (contentStart < 0) break; + + contentStart++; + + var contentEnd = xml.IndexOf("", contentStart, StringComparison.Ordinal); + if (contentEnd < 0) break; + + var description = xml.Substring(contentStart, contentEnd - contentStart) + .Trim() + .Replace("\r\n", " ") + .Replace("\n", " ") + .Trim(); + if (!string.IsNullOrWhiteSpace(description)) parameters.Add((paramName, description)); + + searchPos = contentEnd + 8; + } + + return [.. parameters]; + } + + private static ImmutableArray ExtractParameterDefinitions( + IMethodSymbol method, string pattern) + { + var routeParams = RouteValidator.ExtractRouteParameters(pattern); + var routeParamNames = new HashSet(StringComparer.OrdinalIgnoreCase); + foreach (var rp in routeParams) + routeParamNames.Add(rp.Name); + + var parameters = new List(); + + foreach (var param in method.Parameters) + { + var typeFqn = param.Type.ToDisplayString(); + + // Skip special types (services, context, etc.) + if (IsSkippedParameterType(param, typeFqn)) continue; + + // Check explicit binding attributes + var (explicitLocation, explicitName) = GetExplicitBinding(param); + + string name; + string location; + bool required; + + if (explicitLocation is not null) + { + // Explicit attribute wins + name = explicitName ?? param.Name; + location = explicitLocation; + required = location == "path" || + (param.Type.NullableAnnotation != NullableAnnotation.Annotated && + !param.HasExplicitDefaultValue); + } + else if (routeParamNames.Contains(param.Name)) + { + // Route parameter + name = param.Name; + location = "path"; + required = true; + + // Check if optional in route template + foreach (var rp in routeParams) + { + if (string.Equals(rp.Name, param.Name, StringComparison.OrdinalIgnoreCase) && rp.IsOptional) + { + required = false; + break; + } + } + } + else if (IsPrimitiveType(typeFqn)) + { + // Primitive not in route = query + name = param.Name; + location = "query"; + required = param.Type.NullableAnnotation != NullableAnnotation.Annotated && + !param.HasExplicitDefaultValue; + } + else + { + // Complex type without explicit binding - skip (it's body or service) + continue; + } + + var (schemaType, schemaFormat) = GetOpenApiSchema(typeFqn); + parameters.Add(new OpenApiParameterInfo(name, location, required, schemaType, schemaFormat)); + } + + return [.. parameters]; + } + + private static bool IsSkippedParameterType(IParameterSymbol param, string typeFqn) + { + // Skip special framework types + if (typeFqn is WellKnownTypes.HttpContext or WellKnownTypes.CancellationToken + or WellKnownTypes.FormFile or WellKnownTypes.FormFileCollection + or WellKnownTypes.Stream or WellKnownTypes.PipeReader or WellKnownTypes.FormCollection) + { + return true; + } + + // Skip interface types (services) + if (param.Type.TypeKind == TypeKind.Interface) return true; + + // Skip abstract types (services) + if (param.Type is { IsAbstract: true, TypeKind: TypeKind.Class }) return true; + + // Skip [FromServices] / [FromKeyedServices] / [FromBody] / [FromForm] + foreach (var attr in param.GetAttributes()) + { + var attrName = attr.AttributeClass?.ToDisplayString(); + if (attrName is WellKnownTypes.FromServicesAttribute or WellKnownTypes.FromBodyAttribute + or WellKnownTypes.FromFormAttribute or WellKnownTypes.FromKeyedServicesAttribute) + { + return true; + } + } + + return false; + } + + private static (string? Location, string? Name) GetExplicitBinding(ISymbol param) + { + foreach (var attr in param.GetAttributes()) + { + var attrName = attr.AttributeClass?.ToDisplayString(); + switch (attrName) + { + case WellKnownTypes.FromRouteAttribute: + { + var name = GetAttributeStringArg(attr, "Name"); + return ("path", name); + } + case WellKnownTypes.FromQueryAttribute: + { + var name = GetAttributeStringArg(attr, "Name"); + return ("query", name); + } + case WellKnownTypes.FromHeaderAttribute: + { + var name = GetAttributeStringArg(attr, "Name"); + return ("header", name); + } + } + } + + return (null, null); + } + + private static string? GetAttributeStringArg(AttributeData attr, string propName) + { + foreach (var kvp in attr.NamedArguments) + { + if (kvp.Key == propName && kvp.Value.Value is string s && !string.IsNullOrWhiteSpace(s)) + return s; + } + + return null; + } + + private static bool IsPrimitiveType(string typeFqn) + { + // Strip nullable wrapper + var type = typeFqn.EndsWithOrdinal("?") ? typeFqn.Substring(0, typeFqn.Length - 1) : typeFqn; + + return type is "int" or "System.Int32" + or "long" or "System.Int64" + or "short" or "System.Int16" + or "uint" or "System.UInt32" + or "ulong" or "System.UInt64" + or "ushort" or "System.UInt16" + or "byte" or "System.Byte" + or "sbyte" or "System.SByte" + or "bool" or "System.Boolean" + or "decimal" or "System.Decimal" + or "double" or "System.Double" + or "float" or "System.Single" + or "string" or "System.String" + or "System.Guid" + or "System.DateTime" + or "System.DateTimeOffset" + or "System.DateOnly" + or "System.TimeOnly" + or "System.TimeSpan"; + } + + private static (string SchemaType, string? SchemaFormat) GetOpenApiSchema(string typeFqn) + { + // Strip nullable wrapper + var type = typeFqn.EndsWithOrdinal("?") ? typeFqn.Substring(0, typeFqn.Length - 1) : typeFqn; + + return type switch + { + "int" or "System.Int32" => ("integer", "int32"), + "long" or "System.Int64" => ("integer", "int64"), + "short" or "System.Int16" => ("integer", "int16"), + "uint" or "System.UInt32" => ("integer", "int32"), + "ulong" or "System.UInt64" => ("integer", "int64"), + "ushort" or "System.UInt16" => ("integer", "int16"), + "byte" or "System.Byte" => ("integer", "int32"), + "sbyte" or "System.SByte" => ("integer", "int32"), + "bool" or "System.Boolean" => ("boolean", null), + "decimal" or "System.Decimal" => ("number", "double"), + "double" or "System.Double" => ("number", "double"), + "float" or "System.Single" => ("number", "float"), + "System.Guid" => ("string", "uuid"), + "System.DateTime" => ("string", "date-time"), + "System.DateTimeOffset" => ("string", "date-time"), + "System.DateOnly" => ("string", "date"), + "System.TimeOnly" => ("string", "time"), + "System.TimeSpan" => ("string", "duration"), + _ => ("string", null) + }; + } + + /// + /// Maps internal schema type string to OpenApi v2.0 JsonSchemaType enum name for emission. + /// + private static string ToJsonSchemaTypeEnum(string schemaType) + { + return schemaType switch + { + "integer" => "JsonSchemaType.Integer", + "number" => "JsonSchemaType.Number", + "boolean" => "JsonSchemaType.Boolean", + _ => "JsonSchemaType.String" + }; + } + + private static string GetReflectionFullName(ISymbol symbol) + { + var fqn = ((ITypeSymbol)symbol).GetFullyQualifiedName(); + return fqn.StartsWithOrdinal("global::") ? fqn.Substring("global::".Length) : fqn; + } + + private static TypeMetadataInfo? ExtractTypeMetadata( + GeneratorSyntaxContext ctx, + CancellationToken ct) + { + if (ctx.Node is not TypeDeclarationSyntax typeDecl) return null; + + // Skip null symbols and compiler-generated types + if (ctx.SemanticModel.GetDeclaredSymbol(typeDecl, ct) is not INamedTypeSymbol symbol || + symbol.IsImplicitlyDeclared) + { + return null; + } + + // Skip types without XML docs + var xmlDoc = symbol.GetDocumentationCommentXml(cancellationToken: ct); + if (string.IsNullOrWhiteSpace(xmlDoc)) return null; + + var (summary, _) = ParseXmlDoc(xmlDoc); + if (summary is null) return null; + + var typeKey = GetReflectionFullName(symbol); + + return new TypeMetadataInfo(typeKey, summary); + } + + /// + /// Converts a reflection-style type name to a C# typeof expression. + /// Example: "Namespace.Outer+Inner" → "global::Namespace.Outer.Inner" + /// + private static string ConvertToTypeofExpression(string reflectionName) + { + // Replace nested type separator (+) with C# dot notation + var csharpName = reflectionName.Replace('+', '.'); + return $"global::{csharpName}"; + } +} diff --git a/src/ErrorOrX.Generators/Emitters/BindingCodeEmitter.Body.cs b/src/ErrorOrX.Generators/Emitters/BindingCodeEmitter.Body.cs new file mode 100644 index 0000000..9e1882e --- /dev/null +++ b/src/ErrorOrX.Generators/Emitters/BindingCodeEmitter.Body.cs @@ -0,0 +1,142 @@ +namespace ErrorOr.Generators.Emitters; + +/// +/// Request-body binding emission: JSON body (with empty-body Allow/Disallow split), +/// multipart/form-data field binding, and [AsParameters] constructor expansion. +/// All three consume something from ctx.Request and produce either a parsed DTO or a +/// BindFail short-circuit. +/// +internal static partial class BindingCodeEmitter +{ + internal static bool EmitBodyBinding(StringBuilder code, in EndpointParameter param, string paramName, + string bindFailFn) + { + // Determine effective behavior: explicit > nullability-based default + var effectiveBehavior = param.EmptyBodyBehavior; + if (effectiveBehavior == EmptyBodyBehavior.Default) + effectiveBehavior = param.IsNullable ? EmptyBodyBehavior.Allow : EmptyBodyBehavior.Disallow; + + return effectiveBehavior switch + { + EmptyBodyBehavior.Allow => EmitBodyBindingAllow(code, in param, paramName, bindFailFn), + _ => EmitBodyBindingDisallow(code, in param, paramName, bindFailFn) + }; + } + + internal static bool EmitBodyBindingAllow(StringBuilder code, in EndpointParameter param, string paramName, + string bindFailFn) + { + // Allow empty bodies - check ContentLength before reading + code.AppendLine($" {param.TypeFqn}? {paramName};"); + code.AppendLine(" if (ctx.Request.ContentLength is null or 0)"); + code.AppendLine(" {"); + code.AppendLine($" {paramName} = default;"); + code.AppendLine(" }"); + code.AppendLine(" else"); + code.AppendLine(" {"); + code.AppendLine(" if (!ctx.Request.HasJsonContentType()) return BindFail415();"); + code.AppendLine(" try"); + code.AppendLine(" {"); + code.AppendLine( + $" {paramName} = await ctx.Request.ReadFromJsonAsync<{param.TypeFqn}>(cancellationToken: ctx.RequestAborted);"); + code.AppendLine(" }"); + code.AppendLine($" catch ({WellKnownTypes.Fqn.JsonException})"); + code.AppendLine(" {"); + code.AppendLine($" return {bindFailFn}(\"{param.Name}\", \"has invalid JSON format\");"); + code.AppendLine(" }"); + code.AppendLine(" }"); + return true; + } + + internal static bool EmitBodyBindingDisallow(StringBuilder code, in EndpointParameter param, string paramName, + string bindFailFn) + { + // Disallow empty bodies - reject with 400 if empty + code.AppendLine(" if (ctx.Request.ContentLength is null or 0)"); + code.AppendLine($" return {bindFailFn}(\"{param.Name}\", \"is required\");"); + code.AppendLine(" if (!ctx.Request.HasJsonContentType()) return BindFail415();"); + code.AppendLine($" {param.TypeFqn}? {paramName};"); + code.AppendLine(" try"); + code.AppendLine(" {"); + code.AppendLine( + $" {paramName} = await ctx.Request.ReadFromJsonAsync<{param.TypeFqn}>(cancellationToken: ctx.RequestAborted);"); + code.AppendLine(" }"); + code.AppendLine($" catch ({WellKnownTypes.Fqn.JsonException})"); + code.AppendLine(" {"); + code.AppendLine($" return {bindFailFn}(\"{param.Name}\", \"has invalid JSON format\");"); + code.AppendLine(" }"); + code.AppendLine( + $" if ({paramName} is null) return {bindFailFn}(\"{param.Name}\", \"is required\");"); + return true; + } + + internal static bool EmitFormBinding(StringBuilder code, in EndpointParameter param, string paramName, + string bindFailFn) + { + if (!param.Children.IsDefaultOrEmpty) + { + var usesBindFail = false; + for (var i = 0; i < param.Children.Length; i++) + { + var child = param.Children[i]; + usesBindFail |= EmitParameterBinding(code, in child, $"{paramName}_f{i}", bindFailFn); + } + + var args = string.Join(", ", param.Children.AsImmutableArray().Select((_, i) => $"{paramName}_f{i}")); + code.AppendLine($" var {paramName} = new {param.TypeFqn}({args});"); + return usesBindFail; + } + + var usesBindFailScalar = false; + var fieldName = param.KeyName ?? param.Name; + var declType = param.IsNullable && !param.TypeFqn.EndsWithOrdinal("?") ? param.TypeFqn + "?" : param.TypeFqn; + code.AppendLine($" {declType} {paramName};"); + code.AppendLine( + $" if (!form.TryGetValue(\"{fieldName}\", out var {paramName}Raw) || {paramName}Raw.Count is 0)"); + code.AppendLine(" {"); + if (param.IsNullable) + { + code.AppendLine($" {paramName} = default;"); + } + else + { + usesBindFailScalar = true; + code.AppendLine($" return {bindFailFn}(\"{param.Name}\", \"is required\");"); + } + + code.AppendLine(" }"); + code.AppendLine(" else"); + code.AppendLine(" {"); + if (param.TypeFqn.IsStringType()) + { + code.AppendLine($" {paramName} = {paramName}Raw.ToString();"); + } + else + { + usesBindFailScalar = true; + code.AppendLine( + $" if (!{GetTryParseExpression(param.TypeFqn, paramName + "Raw.ToString()", paramName + "Temp")}) return {bindFailFn}(\"{param.Name}\", \"has invalid format\");"); + code.AppendLine($" {paramName} = {paramName}Temp;"); + } + + code.AppendLine(" }"); + return usesBindFailScalar; + } + + internal static bool EmitAsParametersBinding(StringBuilder code, in EndpointParameter param, string paramName, + string bindFailFn) + { + var usesBindFail = false; + var childVars = new List(); + for (var i = 0; i < param.Children.Length; i++) + { + var child = param.Children[i]; + var childVarName = $"{paramName}_c{i}"; + usesBindFail |= EmitParameterBinding(code, in child, childVarName, bindFailFn); + childVars.Add(BuildArgumentExpression(in child, childVarName)); + } + + code.AppendLine($" var {paramName} = new {param.TypeFqn}({string.Join(", ", childVars)});"); + return usesBindFail; + } +} diff --git a/src/ErrorOrX.Generators/Emitters/BindingCodeEmitter.Parsing.cs b/src/ErrorOrX.Generators/Emitters/BindingCodeEmitter.Parsing.cs new file mode 100644 index 0000000..6ce7b28 --- /dev/null +++ b/src/ErrorOrX.Generators/Emitters/BindingCodeEmitter.Parsing.cs @@ -0,0 +1,148 @@ +namespace ErrorOr.Generators.Emitters; + +/// +/// Cross-cutting helpers shared by every binding family: +/// +/// — composes the call-site expression respecting nullability and value-vs-reference type rules. +/// — table-routed BCL-aware TryParse invocation per type FQN. +/// — emits the Dictionary<string, string[]> aggregation pattern used by both DataAnnotations and ErrorOr.Validation paths. +/// +/// +internal static partial class BindingCodeEmitter +{ + internal static string BuildArgumentExpression(in EndpointParameter param, string paramName) + { + var source = param.Source; + + if (source == ParameterSource.Body && !param.IsNullable) + return paramName + "!"; + + if (source == ParameterSource.Route && param is { IsNullable: false, IsNonNullableValueType: false }) + return paramName + "!"; + + if (source is ParameterSource.Query or ParameterSource.Header + && param is { IsNullable: false, IsNonNullableValueType: true }) + { + return paramName + ".Value"; + } + + if (source is ParameterSource.Query or ParameterSource.Header + && param is { IsNullable: false, IsNonNullableValueType: false }) + { + return paramName + "!"; + } + + return paramName; + } + + internal static string GetTryParseExpression(string typeFqn, string rawName, string outputName, + CustomBindingMethod customBinding = CustomBindingMethod.None) + { + if (customBinding is CustomBindingMethod.TryParse) + { + var baseType = typeFqn.TrimEnd('?'); + return $"{baseType}.TryParse({rawName}, out var {outputName})"; + } + + if (customBinding is CustomBindingMethod.TryParseWithFormat) + { + var baseType = typeFqn.TrimEnd('?'); + return + $"{baseType}.TryParse({rawName}, global::System.Globalization.CultureInfo.InvariantCulture, out var {outputName})"; + } + + var normalized = typeFqn.Replace("global::", "").TrimEnd('?'); + return normalized switch + { + // Integer types - no IFormatProvider overload + "System.Int32" or "int" => $"int.TryParse({rawName}, out var {outputName})", + "System.Int64" or "long" => $"long.TryParse({rawName}, out var {outputName})", + "System.Int16" or "short" => $"short.TryParse({rawName}, out var {outputName})", + "System.Byte" or "byte" => $"byte.TryParse({rawName}, out var {outputName})", + "System.SByte" or "sbyte" => $"sbyte.TryParse({rawName}, out var {outputName})", + "System.UInt16" or "ushort" => $"ushort.TryParse({rawName}, out var {outputName})", + "System.UInt32" or "uint" => $"uint.TryParse({rawName}, out var {outputName})", + "System.UInt64" or "ulong" => $"ulong.TryParse({rawName}, out var {outputName})", + "System.Int128" => $"global::System.Int128.TryParse({rawName}, out var {outputName})", + "System.UInt128" => $"global::System.UInt128.TryParse({rawName}, out var {outputName})", + + // Other types without IFormatProvider overload + "System.Boolean" or "bool" => $"bool.TryParse({rawName}, out var {outputName})", + "System.Guid" => $"global::System.Guid.TryParse({rawName}, out var {outputName})", + "System.Uri" => + $"global::System.Uri.TryCreate({rawName}, global::System.UriKind.RelativeOrAbsolute, out var {outputName})", + + // Culture-sensitive floating point types - use InvariantCulture + "System.Decimal" or "decimal" => + $"decimal.TryParse({rawName}, global::System.Globalization.NumberStyles.Number, global::System.Globalization.CultureInfo.InvariantCulture, out var {outputName})", + "System.Double" or "double" => + $"double.TryParse({rawName}, global::System.Globalization.NumberStyles.Float | global::System.Globalization.NumberStyles.AllowThousands, global::System.Globalization.CultureInfo.InvariantCulture, out var {outputName})", + "System.Single" or "float" => + $"float.TryParse({rawName}, global::System.Globalization.NumberStyles.Float | global::System.Globalization.NumberStyles.AllowThousands, global::System.Globalization.CultureInfo.InvariantCulture, out var {outputName})", + "System.Half" => + $"global::System.Half.TryParse({rawName}, global::System.Globalization.NumberStyles.Float, global::System.Globalization.CultureInfo.InvariantCulture, out var {outputName})", + + // Culture-sensitive date/time types - use InvariantCulture + "System.DateTime" => + $"global::System.DateTime.TryParse({rawName}, global::System.Globalization.CultureInfo.InvariantCulture, global::System.Globalization.DateTimeStyles.RoundtripKind, out var {outputName})", + "System.DateTimeOffset" => + $"global::System.DateTimeOffset.TryParse({rawName}, global::System.Globalization.CultureInfo.InvariantCulture, global::System.Globalization.DateTimeStyles.RoundtripKind, out var {outputName})", + "System.DateOnly" => + $"global::System.DateOnly.TryParse({rawName}, global::System.Globalization.CultureInfo.InvariantCulture, global::System.Globalization.DateTimeStyles.None, out var {outputName})", + "System.TimeOnly" => + $"global::System.TimeOnly.TryParse({rawName}, global::System.Globalization.CultureInfo.InvariantCulture, global::System.Globalization.DateTimeStyles.None, out var {outputName})", + "System.TimeSpan" => + $"global::System.TimeSpan.TryParse({rawName}, global::System.Globalization.CultureInfo.InvariantCulture, out var {outputName})", + + _ => "false" + }; + } + + /// + /// Emits the standard validation dictionary building pattern. + /// Consolidates the repeated logic for aggregating errors by key into string arrays. + /// + /// The StringBuilder to append to. + /// Base indentation (number of spaces). + /// Name of the dictionary variable. + /// The collection to iterate over. + /// Name of the loop variable. + /// Expression to get the dictionary key. + /// Expression to get the value to add. + /// Optional filter expression (items not matching are skipped). + /// Optional local variable declaration for key (emitted before TryGetValue). + internal static void EmitValidationDictBuilder( + StringBuilder code, + int indent, + string dictName, + string iteratorSource, + string iteratorVar, + string keyExpr, + string valueExpr, + string? filterExpr = null, + string? keyVarDecl = null) + { + var pad = new string(' ', indent); + var pad4 = new string(' ', indent + 4); + var pad8 = new string(' ', indent + 8); + + code.AppendLine($"{pad}var {dictName} = new {WellKnownTypes.Fqn.Dictionary}();"); + code.AppendLine($"{pad}foreach (var {iteratorVar} in {iteratorSource})"); + code.AppendLine($"{pad}{{"); + + if (filterExpr is not null) code.AppendLine($"{pad4}if ({filterExpr}) continue;"); + + if (keyVarDecl is not null) code.AppendLine($"{pad4}{keyVarDecl}"); + + code.AppendLine($"{pad4}if (!{dictName}.TryGetValue({keyExpr}, out var existing))"); + code.AppendLine($"{pad8}{dictName}[{keyExpr}] = new[] {{ {valueExpr} }};"); + code.AppendLine($"{pad4}else"); + code.AppendLine($"{pad4}{{"); + code.AppendLine($"{pad8}var arr = new string[existing.Length + 1];"); + code.AppendLine($"{pad8}existing.CopyTo(arr, 0);"); + code.AppendLine($"{pad8}arr[existing.Length] = {valueExpr};"); + code.AppendLine($"{pad8}{dictName}[{keyExpr}] = arr;"); + code.AppendLine($"{pad4}}}"); + code.AppendLine($"{pad}}}"); + } +} diff --git a/src/ErrorOrX.Generators/Emitters/BindingCodeEmitter.Query.cs b/src/ErrorOrX.Generators/Emitters/BindingCodeEmitter.Query.cs new file mode 100644 index 0000000..aaee07c --- /dev/null +++ b/src/ErrorOrX.Generators/Emitters/BindingCodeEmitter.Query.cs @@ -0,0 +1,174 @@ +namespace ErrorOr.Generators.Emitters; + +/// +/// Query and Header binding emission. Both sources share the same scalar/collection branching +/// and the same nullable-vs-required emission shape; only the source-extraction call differs +/// (TryGetQueryValue / ctx.Request.Query["..."] vs ctx.Request.Headers.TryGetValue). +/// Also hosts since it's a query-bound custom hook. +/// +internal static partial class BindingCodeEmitter +{ + internal static bool EmitQueryBinding(StringBuilder code, in EndpointParameter param, string paramName, + string bindFailFn) + { + if (param.CustomBinding is CustomBindingMethod.BindAsync or CustomBindingMethod.BindAsyncWithParam) + return EmitBindAsyncBinding(code, in param, paramName, bindFailFn); + + var queryKey = param.KeyName ?? param.Name; + return param is { IsCollection: true, CollectionItemTypeFqn: { } itemType } + ? EmitCollectionQueryBinding(code, in param, paramName, queryKey, itemType, bindFailFn) + : EmitScalarQueryBinding(code, in param, paramName, queryKey, bindFailFn); + } + + internal static bool EmitBindAsyncBinding(StringBuilder code, in EndpointParameter param, string paramName, + string bindFailFn) + { + var baseType = param.TypeFqn.TrimEnd('?'); + code.AppendLine($" var {paramName} = await {baseType}.BindAsync(ctx);"); + if (param.IsNullable) return false; + + code.AppendLine( + $" if ({paramName} is null) return {bindFailFn}(\"{param.Name}\", \"binding failed\");"); + return true; + } + + internal static bool EmitCollectionQueryBinding(StringBuilder code, in EndpointParameter param, string paramName, + string queryKey, string itemType, string bindFailFn) + { + code.AppendLine($" var {paramName}Raw = ctx.Request.Query[\"{queryKey}\"];"); + code.AppendLine($" var {paramName}List = new {WellKnownTypes.Fqn.List}<{itemType}>();"); + code.AppendLine($" foreach (var item in {paramName}Raw)"); + code.AppendLine(" {"); + + var usesBindFail = false; + if (itemType.IsStringType()) + { + code.AppendLine( + $" if (item is {{ Length: > 0 }} validItem) {paramName}List.Add(validItem);"); + } + else + { + usesBindFail = true; + code.AppendLine( + $" if ({GetTryParseExpression(itemType, "item", "parsedItem")}) {paramName}List.Add(parsedItem);"); + code.AppendLine( + $" else if (!string.IsNullOrEmpty(item)) return {bindFailFn}(\"{param.Name}\", \"has invalid item format\");"); + } + + code.AppendLine(" }"); + var isArray = param.TypeFqn.EndsWithOrdinal("[]"); + var assignment = isArray ? $"{paramName}List.ToArray()" : $"{paramName}List"; + code.AppendLine($" var {paramName} = {assignment};"); + return usesBindFail; + } + + internal static bool EmitScalarQueryBinding(StringBuilder code, in EndpointParameter param, string paramName, + string queryKey, string bindFailFn) + { + var usesBindFail = false; + var declType = param.TypeFqn.EndsWithOrdinal("?") ? param.TypeFqn : param.TypeFqn + "?"; + code.AppendLine($" {declType} {paramName};"); + code.AppendLine($" if (!TryGetQueryValue(ctx, \"{queryKey}\", out var {paramName}Raw))"); + code.AppendLine(" {"); + if (param.IsNullable) + { + code.AppendLine($" {paramName} = default;"); + } + else + { + usesBindFail = true; + code.AppendLine($" return {bindFailFn}(\"{param.Name}\", \"is required\");"); + } + + code.AppendLine(" }"); + code.AppendLine(" else"); + code.AppendLine(" {"); + if (param.TypeFqn.IsStringType()) + { + code.AppendLine($" {paramName} = {paramName}Raw;"); + } + else + { + usesBindFail = true; + code.AppendLine( + $" if (!{GetTryParseExpression(param.TypeFqn, paramName + "Raw", paramName + "Temp", param.CustomBinding)}) return {bindFailFn}(\"{param.Name}\", \"has invalid format\");"); + code.AppendLine($" {paramName} = {paramName}Temp;"); + } + + code.AppendLine(" }"); + return usesBindFail; + } + + internal static bool EmitHeaderBinding(StringBuilder code, in EndpointParameter param, string paramName, + string bindFailFn) + { + var key = param.KeyName ?? param.Name; + var usesBindFail = false; + + if (param is { IsCollection: true, CollectionItemTypeFqn: { } itemType }) + { + code.AppendLine($" {param.TypeFqn} {paramName};"); + code.AppendLine( + $" if (!ctx.Request.Headers.TryGetValue(\"{key}\", out var {paramName}Raw) || {paramName}Raw.Count is 0)"); + code.AppendLine(" {"); + if (param.IsNullable) + { + code.AppendLine($" {paramName} = default!;"); + } + else + { + usesBindFail = true; + code.AppendLine($" return {bindFailFn}(\"{param.Name}\", \"is required\");"); + } + + code.AppendLine(" }"); + code.AppendLine(" else"); + code.AppendLine(" {"); + code.AppendLine($" var {paramName}List = new {WellKnownTypes.Fqn.List}<{itemType}>();"); + code.AppendLine($" foreach (var item in {paramName}Raw)"); + code.AppendLine(" {"); + code.AppendLine( + itemType.IsStringType() + ? $" if (item is {{ Length: > 0 }} validItem) {paramName}List.Add(validItem);" + : $" if ({GetTryParseExpression(itemType, "item", "parsedItem")}) {paramName}List.Add(parsedItem);"); + code.AppendLine(" }"); + var isArray = param.TypeFqn.EndsWithOrdinal("[]"); + var assignment = isArray ? $"{paramName}List.ToArray()" : $"{paramName}List"; + code.AppendLine($" {paramName} = {assignment};"); + } + else + { + var declType = param.TypeFqn.EndsWithOrdinal("?") ? param.TypeFqn : param.TypeFqn + "?"; + code.AppendLine($" {declType} {paramName};"); + code.AppendLine( + $" if (!ctx.Request.Headers.TryGetValue(\"{key}\", out var {paramName}Raw) || {paramName}Raw.Count is 0)"); + code.AppendLine(" {"); + if (param.IsNullable) + { + code.AppendLine($" {paramName} = default;"); + } + else + { + usesBindFail = true; + code.AppendLine($" return {bindFailFn}(\"{param.Name}\", \"is required\");"); + } + + code.AppendLine(" }"); + code.AppendLine(" else"); + code.AppendLine(" {"); + if (param.TypeFqn.IsStringType()) + { + code.AppendLine($" {paramName} = {paramName}Raw.ToString();"); + } + else + { + usesBindFail = true; + code.AppendLine( + $" if (!{GetTryParseExpression(param.TypeFqn, paramName + "Raw.ToString()", paramName + "Temp")}) return {bindFailFn}(\"{param.Name}\", \"has invalid format\"); {paramName} = {paramName}Temp;"); + } + } + + code.AppendLine(" }"); + return usesBindFail; + } +} diff --git a/src/ErrorOrX.Generators/Emitters/BindingCodeEmitter.cs b/src/ErrorOrX.Generators/Emitters/BindingCodeEmitter.cs index 733ab1f..1204622 100644 --- a/src/ErrorOrX.Generators/Emitters/BindingCodeEmitter.cs +++ b/src/ErrorOrX.Generators/Emitters/BindingCodeEmitter.cs @@ -1,6 +1,16 @@ namespace ErrorOr.Generators.Emitters; -internal static class BindingCodeEmitter +/// +/// Emits the per-parameter C# binding code consumed by the generated Invoke_Ep{N}_Core +/// methods. Partial across: +/// +/// BindingCodeEmitter.cs — Dispatcher, Route, and special / service / form-file bindings. +/// BindingCodeEmitter.Query.cs — Query and Header bindings (scalar + collection). +/// BindingCodeEmitter.Body.cs — Body, Form (DTO + scalar), and AsParameters expansion. +/// BindingCodeEmitter.Parsing.cs — Shared BuildArgumentExpression, GetTryParseExpression, validation dict builder. +/// +/// +internal static partial class BindingCodeEmitter { /// /// Emits parameter binding code and returns whether BindFail helper is used. @@ -100,436 +110,4 @@ internal static bool EmitRouteBinding(StringBuilder code, in EndpointParameter p : $" if (!TryGetRouteValue(ctx, \"{routeName}\", out var {paramName}Raw) || !{GetTryParseExpression(param.TypeFqn, paramName + "Raw", paramName, param.CustomBinding)}) return {bindFailFn}(\"{param.Name}\", \"has invalid format\");"); return true; } - - internal static bool EmitQueryBinding(StringBuilder code, in EndpointParameter param, string paramName, - string bindFailFn) - { - if (param.CustomBinding is CustomBindingMethod.BindAsync or CustomBindingMethod.BindAsyncWithParam) - return EmitBindAsyncBinding(code, in param, paramName, bindFailFn); - - var queryKey = param.KeyName ?? param.Name; - return param is { IsCollection: true, CollectionItemTypeFqn: { } itemType } - ? EmitCollectionQueryBinding(code, in param, paramName, queryKey, itemType, bindFailFn) - : EmitScalarQueryBinding(code, in param, paramName, queryKey, bindFailFn); - } - - internal static bool EmitBindAsyncBinding(StringBuilder code, in EndpointParameter param, string paramName, - string bindFailFn) - { - var baseType = param.TypeFqn.TrimEnd('?'); - code.AppendLine($" var {paramName} = await {baseType}.BindAsync(ctx);"); - if (param.IsNullable) return false; - - code.AppendLine( - $" if ({paramName} is null) return {bindFailFn}(\"{param.Name}\", \"binding failed\");"); - return true; - } - - internal static bool EmitCollectionQueryBinding(StringBuilder code, in EndpointParameter param, string paramName, - string queryKey, string itemType, string bindFailFn) - { - code.AppendLine($" var {paramName}Raw = ctx.Request.Query[\"{queryKey}\"];"); - code.AppendLine($" var {paramName}List = new {WellKnownTypes.Fqn.List}<{itemType}>();"); - code.AppendLine($" foreach (var item in {paramName}Raw)"); - code.AppendLine(" {"); - - var usesBindFail = false; - if (itemType.IsStringType()) - { - code.AppendLine( - $" if (item is {{ Length: > 0 }} validItem) {paramName}List.Add(validItem);"); - } - else - { - usesBindFail = true; - code.AppendLine( - $" if ({GetTryParseExpression(itemType, "item", "parsedItem")}) {paramName}List.Add(parsedItem);"); - code.AppendLine( - $" else if (!string.IsNullOrEmpty(item)) return {bindFailFn}(\"{param.Name}\", \"has invalid item format\");"); - } - - code.AppendLine(" }"); - var isArray = param.TypeFqn.EndsWithOrdinal("[]"); - var assignment = isArray ? $"{paramName}List.ToArray()" : $"{paramName}List"; - code.AppendLine($" var {paramName} = {assignment};"); - return usesBindFail; - } - - internal static bool EmitScalarQueryBinding(StringBuilder code, in EndpointParameter param, string paramName, - string queryKey, string bindFailFn) - { - var usesBindFail = false; - var declType = param.TypeFqn.EndsWithOrdinal("?") ? param.TypeFqn : param.TypeFqn + "?"; - code.AppendLine($" {declType} {paramName};"); - code.AppendLine($" if (!TryGetQueryValue(ctx, \"{queryKey}\", out var {paramName}Raw))"); - code.AppendLine(" {"); - if (param.IsNullable) - { - code.AppendLine($" {paramName} = default;"); - } - else - { - usesBindFail = true; - code.AppendLine($" return {bindFailFn}(\"{param.Name}\", \"is required\");"); - } - - code.AppendLine(" }"); - code.AppendLine(" else"); - code.AppendLine(" {"); - if (param.TypeFqn.IsStringType()) - { - code.AppendLine($" {paramName} = {paramName}Raw;"); - } - else - { - usesBindFail = true; - code.AppendLine( - $" if (!{GetTryParseExpression(param.TypeFqn, paramName + "Raw", paramName + "Temp", param.CustomBinding)}) return {bindFailFn}(\"{param.Name}\", \"has invalid format\");"); - code.AppendLine($" {paramName} = {paramName}Temp;"); - } - - code.AppendLine(" }"); - return usesBindFail; - } - - internal static bool EmitHeaderBinding(StringBuilder code, in EndpointParameter param, string paramName, - string bindFailFn) - { - var key = param.KeyName ?? param.Name; - var usesBindFail = false; - - if (param is { IsCollection: true, CollectionItemTypeFqn: { } itemType }) - { - code.AppendLine($" {param.TypeFqn} {paramName};"); - code.AppendLine( - $" if (!ctx.Request.Headers.TryGetValue(\"{key}\", out var {paramName}Raw) || {paramName}Raw.Count is 0)"); - code.AppendLine(" {"); - if (param.IsNullable) - { - code.AppendLine($" {paramName} = default!;"); - } - else - { - usesBindFail = true; - code.AppendLine($" return {bindFailFn}(\"{param.Name}\", \"is required\");"); - } - - code.AppendLine(" }"); - code.AppendLine(" else"); - code.AppendLine(" {"); - code.AppendLine($" var {paramName}List = new {WellKnownTypes.Fqn.List}<{itemType}>();"); - code.AppendLine($" foreach (var item in {paramName}Raw)"); - code.AppendLine(" {"); - code.AppendLine( - itemType.IsStringType() - ? $" if (item is {{ Length: > 0 }} validItem) {paramName}List.Add(validItem);" - : $" if ({GetTryParseExpression(itemType, "item", "parsedItem")}) {paramName}List.Add(parsedItem);"); - code.AppendLine(" }"); - var isArray = param.TypeFqn.EndsWithOrdinal("[]"); - var assignment = isArray ? $"{paramName}List.ToArray()" : $"{paramName}List"; - code.AppendLine($" {paramName} = {assignment};"); - } - else - { - var declType = param.TypeFqn.EndsWithOrdinal("?") ? param.TypeFqn : param.TypeFqn + "?"; - code.AppendLine($" {declType} {paramName};"); - code.AppendLine( - $" if (!ctx.Request.Headers.TryGetValue(\"{key}\", out var {paramName}Raw) || {paramName}Raw.Count is 0)"); - code.AppendLine(" {"); - if (param.IsNullable) - { - code.AppendLine($" {paramName} = default;"); - } - else - { - usesBindFail = true; - code.AppendLine($" return {bindFailFn}(\"{param.Name}\", \"is required\");"); - } - - code.AppendLine(" }"); - code.AppendLine(" else"); - code.AppendLine(" {"); - if (param.TypeFqn.IsStringType()) - { - code.AppendLine($" {paramName} = {paramName}Raw.ToString();"); - } - else - { - usesBindFail = true; - code.AppendLine( - $" if (!{GetTryParseExpression(param.TypeFqn, paramName + "Raw.ToString()", paramName + "Temp")}) return {bindFailFn}(\"{param.Name}\", \"has invalid format\"); {paramName} = {paramName}Temp;"); - } - } - - code.AppendLine(" }"); - return usesBindFail; - } - - internal static bool EmitBodyBinding(StringBuilder code, in EndpointParameter param, string paramName, - string bindFailFn) - { - // Determine effective behavior: explicit > nullability-based default - var effectiveBehavior = param.EmptyBodyBehavior; - if (effectiveBehavior == EmptyBodyBehavior.Default) - effectiveBehavior = param.IsNullable ? EmptyBodyBehavior.Allow : EmptyBodyBehavior.Disallow; - - return effectiveBehavior switch - { - EmptyBodyBehavior.Allow => EmitBodyBindingAllow(code, in param, paramName, bindFailFn), - _ => EmitBodyBindingDisallow(code, in param, paramName, bindFailFn) - }; - } - - internal static bool EmitBodyBindingAllow(StringBuilder code, in EndpointParameter param, string paramName, - string bindFailFn) - { - // Allow empty bodies - check ContentLength before reading - code.AppendLine($" {param.TypeFqn}? {paramName};"); - code.AppendLine(" if (ctx.Request.ContentLength is null or 0)"); - code.AppendLine(" {"); - code.AppendLine($" {paramName} = default;"); - code.AppendLine(" }"); - code.AppendLine(" else"); - code.AppendLine(" {"); - code.AppendLine(" if (!ctx.Request.HasJsonContentType()) return BindFail415();"); - code.AppendLine(" try"); - code.AppendLine(" {"); - code.AppendLine( - $" {paramName} = await ctx.Request.ReadFromJsonAsync<{param.TypeFqn}>(cancellationToken: ctx.RequestAborted);"); - code.AppendLine(" }"); - code.AppendLine($" catch ({WellKnownTypes.Fqn.JsonException})"); - code.AppendLine(" {"); - code.AppendLine($" return {bindFailFn}(\"{param.Name}\", \"has invalid JSON format\");"); - code.AppendLine(" }"); - code.AppendLine(" }"); - return true; - } - - internal static bool EmitBodyBindingDisallow(StringBuilder code, in EndpointParameter param, string paramName, - string bindFailFn) - { - // Disallow empty bodies - reject with 400 if empty - code.AppendLine(" if (ctx.Request.ContentLength is null or 0)"); - code.AppendLine($" return {bindFailFn}(\"{param.Name}\", \"is required\");"); - code.AppendLine(" if (!ctx.Request.HasJsonContentType()) return BindFail415();"); - code.AppendLine($" {param.TypeFqn}? {paramName};"); - code.AppendLine(" try"); - code.AppendLine(" {"); - code.AppendLine( - $" {paramName} = await ctx.Request.ReadFromJsonAsync<{param.TypeFqn}>(cancellationToken: ctx.RequestAborted);"); - code.AppendLine(" }"); - code.AppendLine($" catch ({WellKnownTypes.Fqn.JsonException})"); - code.AppendLine(" {"); - code.AppendLine($" return {bindFailFn}(\"{param.Name}\", \"has invalid JSON format\");"); - code.AppendLine(" }"); - code.AppendLine( - $" if ({paramName} is null) return {bindFailFn}(\"{param.Name}\", \"is required\");"); - return true; - } - - internal static bool EmitFormBinding(StringBuilder code, in EndpointParameter param, string paramName, - string bindFailFn) - { - if (!param.Children.IsDefaultOrEmpty) - { - var usesBindFail = false; - for (var i = 0; i < param.Children.Length; i++) - { - var child = param.Children[i]; - usesBindFail |= EmitParameterBinding(code, in child, $"{paramName}_f{i}", bindFailFn); - } - - var args = string.Join(", ", param.Children.AsImmutableArray().Select((_, i) => $"{paramName}_f{i}")); - code.AppendLine($" var {paramName} = new {param.TypeFqn}({args});"); - return usesBindFail; - } - - var usesBindFailScalar = false; - var fieldName = param.KeyName ?? param.Name; - var declType = param.IsNullable && !param.TypeFqn.EndsWithOrdinal("?") ? param.TypeFqn + "?" : param.TypeFqn; - code.AppendLine($" {declType} {paramName};"); - code.AppendLine( - $" if (!form.TryGetValue(\"{fieldName}\", out var {paramName}Raw) || {paramName}Raw.Count is 0)"); - code.AppendLine(" {"); - if (param.IsNullable) - { - code.AppendLine($" {paramName} = default;"); - } - else - { - usesBindFailScalar = true; - code.AppendLine($" return {bindFailFn}(\"{param.Name}\", \"is required\");"); - } - - code.AppendLine(" }"); - code.AppendLine(" else"); - code.AppendLine(" {"); - if (param.TypeFqn.IsStringType()) - { - code.AppendLine($" {paramName} = {paramName}Raw.ToString();"); - } - else - { - usesBindFailScalar = true; - code.AppendLine( - $" if (!{GetTryParseExpression(param.TypeFqn, paramName + "Raw.ToString()", paramName + "Temp")}) return {bindFailFn}(\"{param.Name}\", \"has invalid format\");"); - code.AppendLine($" {paramName} = {paramName}Temp;"); - } - - code.AppendLine(" }"); - return usesBindFailScalar; - } - - internal static bool EmitAsParametersBinding(StringBuilder code, in EndpointParameter param, string paramName, - string bindFailFn) - { - var usesBindFail = false; - var childVars = new List(); - for (var i = 0; i < param.Children.Length; i++) - { - var child = param.Children[i]; - var childVarName = $"{paramName}_c{i}"; - usesBindFail |= EmitParameterBinding(code, in child, childVarName, bindFailFn); - childVars.Add(BuildArgumentExpression(in child, childVarName)); - } - - code.AppendLine($" var {paramName} = new {param.TypeFqn}({string.Join(", ", childVars)});"); - return usesBindFail; - } - - internal static string BuildArgumentExpression(in EndpointParameter param, string paramName) - { - var source = param.Source; - - if (source == ParameterSource.Body && !param.IsNullable) - return paramName + "!"; - - if (source == ParameterSource.Route && param is { IsNullable: false, IsNonNullableValueType: false }) - return paramName + "!"; - - if (source is ParameterSource.Query or ParameterSource.Header - && param is { IsNullable: false, IsNonNullableValueType: true }) - { - return paramName + ".Value"; - } - - if (source is ParameterSource.Query or ParameterSource.Header - && param is { IsNullable: false, IsNonNullableValueType: false }) - { - return paramName + "!"; - } - - return paramName; - } - - internal static string GetTryParseExpression(string typeFqn, string rawName, string outputName, - CustomBindingMethod customBinding = CustomBindingMethod.None) - { - if (customBinding is CustomBindingMethod.TryParse) - { - var baseType = typeFqn.TrimEnd('?'); - return $"{baseType}.TryParse({rawName}, out var {outputName})"; - } - - if (customBinding is CustomBindingMethod.TryParseWithFormat) - { - var baseType = typeFqn.TrimEnd('?'); - return - $"{baseType}.TryParse({rawName}, global::System.Globalization.CultureInfo.InvariantCulture, out var {outputName})"; - } - - var normalized = typeFqn.Replace("global::", "").TrimEnd('?'); - return normalized switch - { - // Integer types - no IFormatProvider overload - "System.Int32" or "int" => $"int.TryParse({rawName}, out var {outputName})", - "System.Int64" or "long" => $"long.TryParse({rawName}, out var {outputName})", - "System.Int16" or "short" => $"short.TryParse({rawName}, out var {outputName})", - "System.Byte" or "byte" => $"byte.TryParse({rawName}, out var {outputName})", - "System.SByte" or "sbyte" => $"sbyte.TryParse({rawName}, out var {outputName})", - "System.UInt16" or "ushort" => $"ushort.TryParse({rawName}, out var {outputName})", - "System.UInt32" or "uint" => $"uint.TryParse({rawName}, out var {outputName})", - "System.UInt64" or "ulong" => $"ulong.TryParse({rawName}, out var {outputName})", - "System.Int128" => $"global::System.Int128.TryParse({rawName}, out var {outputName})", - "System.UInt128" => $"global::System.UInt128.TryParse({rawName}, out var {outputName})", - - // Other types without IFormatProvider overload - "System.Boolean" or "bool" => $"bool.TryParse({rawName}, out var {outputName})", - "System.Guid" => $"global::System.Guid.TryParse({rawName}, out var {outputName})", - "System.Uri" => - $"global::System.Uri.TryCreate({rawName}, global::System.UriKind.RelativeOrAbsolute, out var {outputName})", - - // Culture-sensitive floating point types - use InvariantCulture - "System.Decimal" or "decimal" => - $"decimal.TryParse({rawName}, global::System.Globalization.NumberStyles.Number, global::System.Globalization.CultureInfo.InvariantCulture, out var {outputName})", - "System.Double" or "double" => - $"double.TryParse({rawName}, global::System.Globalization.NumberStyles.Float | global::System.Globalization.NumberStyles.AllowThousands, global::System.Globalization.CultureInfo.InvariantCulture, out var {outputName})", - "System.Single" or "float" => - $"float.TryParse({rawName}, global::System.Globalization.NumberStyles.Float | global::System.Globalization.NumberStyles.AllowThousands, global::System.Globalization.CultureInfo.InvariantCulture, out var {outputName})", - "System.Half" => - $"global::System.Half.TryParse({rawName}, global::System.Globalization.NumberStyles.Float, global::System.Globalization.CultureInfo.InvariantCulture, out var {outputName})", - - // Culture-sensitive date/time types - use InvariantCulture - "System.DateTime" => - $"global::System.DateTime.TryParse({rawName}, global::System.Globalization.CultureInfo.InvariantCulture, global::System.Globalization.DateTimeStyles.RoundtripKind, out var {outputName})", - "System.DateTimeOffset" => - $"global::System.DateTimeOffset.TryParse({rawName}, global::System.Globalization.CultureInfo.InvariantCulture, global::System.Globalization.DateTimeStyles.RoundtripKind, out var {outputName})", - "System.DateOnly" => - $"global::System.DateOnly.TryParse({rawName}, global::System.Globalization.CultureInfo.InvariantCulture, global::System.Globalization.DateTimeStyles.None, out var {outputName})", - "System.TimeOnly" => - $"global::System.TimeOnly.TryParse({rawName}, global::System.Globalization.CultureInfo.InvariantCulture, global::System.Globalization.DateTimeStyles.None, out var {outputName})", - "System.TimeSpan" => - $"global::System.TimeSpan.TryParse({rawName}, global::System.Globalization.CultureInfo.InvariantCulture, out var {outputName})", - - _ => "false" - }; - } - - /// - /// Emits the standard validation dictionary building pattern. - /// Consolidates the repeated logic for aggregating errors by key into string arrays. - /// - /// The StringBuilder to append to. - /// Base indentation (number of spaces). - /// Name of the dictionary variable. - /// The collection to iterate over. - /// Name of the loop variable. - /// Expression to get the dictionary key. - /// Expression to get the value to add. - /// Optional filter expression (items not matching are skipped). - /// Optional local variable declaration for key (emitted before TryGetValue). - internal static void EmitValidationDictBuilder( - StringBuilder code, - int indent, - string dictName, - string iteratorSource, - string iteratorVar, - string keyExpr, - string valueExpr, - string? filterExpr = null, - string? keyVarDecl = null) - { - var pad = new string(' ', indent); - var pad4 = new string(' ', indent + 4); - var pad8 = new string(' ', indent + 8); - - code.AppendLine($"{pad}var {dictName} = new {WellKnownTypes.Fqn.Dictionary}();"); - code.AppendLine($"{pad}foreach (var {iteratorVar} in {iteratorSource})"); - code.AppendLine($"{pad}{{"); - - if (filterExpr is not null) code.AppendLine($"{pad4}if ({filterExpr}) continue;"); - - if (keyVarDecl is not null) code.AppendLine($"{pad4}{keyVarDecl}"); - - code.AppendLine($"{pad4}if (!{dictName}.TryGetValue({keyExpr}, out var existing))"); - code.AppendLine($"{pad8}{dictName}[{keyExpr}] = new[] {{ {valueExpr} }};"); - code.AppendLine($"{pad4}else"); - code.AppendLine($"{pad4}{{"); - code.AppendLine($"{pad8}var arr = new string[existing.Length + 1];"); - code.AppendLine($"{pad8}existing.CopyTo(arr, 0);"); - code.AppendLine($"{pad8}arr[existing.Length] = {valueExpr};"); - code.AppendLine($"{pad8}{dictName}[{keyExpr}] = arr;"); - code.AppendLine($"{pad4}}}"); - code.AppendLine($"{pad}}}"); - } } diff --git a/src/ErrorOrX.Generators/Models/EndpointDescriptor.cs b/src/ErrorOrX.Generators/Models/EndpointDescriptor.cs new file mode 100644 index 0000000..8d764b7 --- /dev/null +++ b/src/ErrorOrX.Generators/Models/EndpointDescriptor.cs @@ -0,0 +1,183 @@ +namespace ErrorOr.Generators; + +/// +/// Represents a metadata entry for an endpoint. +/// +internal readonly record struct MetadataEntry(string Key, string Value); + +/// +/// Classifies the success response type for HTTP status code mapping. +/// +internal enum SuccessKind +{ + Payload, + Success, + Created, + Updated, + Deleted +} + +/// +/// Success response information for OpenAPI metadata. +/// +internal readonly record struct SuccessResponseInfo( + string ResultTypeFqn, + int StatusCode, + bool HasBody, + string Factory); + +/// +/// Result of union type computation. +/// +internal readonly record struct UnionTypeResult( + bool CanUseUnion, + string ReturnTypeFqn, + EquatableArray ExplicitProduceCodes, + bool UsesValidationProblemFor400 = false); + +/// +/// Complete descriptor for an ErrorOr endpoint used for code generation. +/// +internal readonly record struct EndpointDescriptor( + HttpVerb HttpVerb, + string Pattern, + string SuccessTypeFqn, + SuccessKind SuccessKind, + bool IsAsync, + string HandlerContainingTypeFqn, + string HandlerMethodName, + EquatableArray HandlerParameters, + ErrorInferenceInfo ErrorInference, + SseInfo Sse = default, + bool IsAcceptedResponse = false, + string? LocationIdPropertyName = null, + MiddlewareInfo Middleware = default, + VersioningInfo Versioning = default, + RouteGroupInfo RouteGroup = default, + EquatableArray Metadata = default, + string? CustomHttpMethod = null) +{ + /// Gets the HTTP method string for emission (e.g., "GET", "POST", or custom like "CONNECT"). + public string HttpMethod => CustomHttpMethod ?? HttpVerb.ToHttpString(); + + /// + /// Returns true if any parameter is bound from body. + /// + public bool HasBodyParam + { + get + { + foreach (var p in HandlerParameters.AsImmutableArray()) + { + if (p.Source == ParameterSource.Body) + return true; + } + + return false; + } + } + + /// + /// Returns true if any parameter is bound from form-related sources. + /// + public bool HasFormParams + { + get + { + foreach (var p in HandlerParameters.AsImmutableArray()) + { + if (p.Source.IsFormRelated()) + return true; + } + + return false; + } + } + + /// + /// Returns true if endpoint has body or form binding (for OpenAPI and validation). + /// Uses single-pass enumeration to avoid multiple iterations. + /// + public bool HasBodyOrFormBinding + { + get + { + foreach (var p in HandlerParameters.AsImmutableArray()) + { + if (p.Source == ParameterSource.Body || p.Source.IsFormRelated()) + return true; + } + + return false; + } + } + + /// + /// Returns true if any parameter uses BindAsync custom binding. + /// + public bool HasBindAsyncParam + { + get + { + foreach (var p in HandlerParameters.AsImmutableArray()) + { + if (p.CustomBinding is CustomBindingMethod.BindAsync or CustomBindingMethod.BindAsyncWithParam) + return true; + } + + return false; + } + } + + /// + /// Returns true if any parameter requires DataAnnotations validation. + /// + public bool HasParameterValidation + { + get + { + foreach (var p in HandlerParameters.AsImmutableArray()) + { + if (p.RequiresValidation) + return true; + } + + return false; + } + } + + /// Gets metadata value by key, or null if not found. + public string? GetMetadata(string key) + { + foreach (var entry in Metadata.AsImmutableArray()) + { + if (entry.Key == key) + return entry.Value; + } + + return null; + } + + /// Returns true if metadata with the given key exists. + public bool HasMetadata(string key) + { + foreach (var entry in Metadata.AsImmutableArray()) + { + if (entry.Key == key) + return true; + } + + return false; + } +} + +/// +/// Well-known metadata key constants. +/// +internal static class MetadataKeys +{ + public const string Deprecated = "erroror:deprecated"; + public const string DeprecatedMessage = "erroror:deprecated-message"; + public const string OpenApiExtension = "openapi:x-"; + public const string CustomTag = "openapi:tag"; +} diff --git a/src/ErrorOrX.Generators/Models/EndpointModels.cs b/src/ErrorOrX.Generators/Models/EndpointModels.cs deleted file mode 100644 index 359bb26..0000000 --- a/src/ErrorOrX.Generators/Models/EndpointModels.cs +++ /dev/null @@ -1,580 +0,0 @@ -namespace ErrorOr.Generators; - -/// -/// Represents a metadata entry for an endpoint. -/// -internal readonly record struct MetadataEntry(string Key, string Value); - -/// -/// Represents the custom binding method detected on a parameter type. -/// -internal enum CustomBindingMethod -{ - None, - TryParse, - TryParseWithFormat, - BindAsync, - BindAsyncWithParam, - Bindable -} - -/// -/// Primitive types that can be bound from route templates. -/// -internal enum RoutePrimitiveKind -{ - String, - Int32, - Int64, - Int16, - UInt32, - UInt64, - UInt16, - Byte, - SByte, - Boolean, - Decimal, - Double, - Single, - Guid, - DateTime, - DateTimeOffset, - DateOnly, - TimeOnly, - TimeSpan -} - -/// -/// Classifies the success response type for HTTP status code mapping. -/// -internal enum SuccessKind -{ - Payload, - Success, - Created, - Updated, - Deleted -} - -/// -/// Flags for parameter binding characteristics. -/// -[Flags] -internal enum ParameterFlags -{ - None = 0, - FromServices = 1 << 0, - FromKeyedServices = 1 << 1, - FromBody = 1 << 2, - FromRoute = 1 << 3, - FromQuery = 1 << 4, - FromHeader = 1 << 5, - FromForm = 1 << 6, - AsParameters = 1 << 7, - Nullable = 1 << 8, - NonNullableValueType = 1 << 9, - Collection = 1 << 10, - RequiresValidation = 1 << 11 -} - -/// -/// Special parameter kinds that have dedicated binding. -/// -internal enum SpecialParameterKind -{ - None, - HttpContext, - CancellationToken, - FormFile, - FormFileCollection, - FormCollection, - Stream, - PipeReader -} - -/// -/// Specifies how empty request bodies should be handled. -/// -internal enum EmptyBodyBehavior -{ - /// Framework default: Nullable allows empty, non-nullable rejects. - Default, - - /// Empty bodies are valid (null/default assigned). - Allow, - - /// Empty bodies are invalid (400 Bad Request). - Disallow -} - -/// -/// Represents a bound endpoint parameter with its source and type information. -/// -internal readonly record struct EndpointParameter( - string Name, - string TypeFqn, - ParameterSource Source, - string? KeyName, - bool IsNullable, - bool IsNonNullableValueType, - bool IsCollection, - string? CollectionItemTypeFqn, - EquatableArray Children, - CustomBindingMethod CustomBinding = CustomBindingMethod.None, - bool RequiresValidation = false, - EmptyBodyBehavior EmptyBodyBehavior = EmptyBodyBehavior.Default, - EquatableArray ValidatableProperties = default); - -/// -/// Raw metadata extracted from a method parameter for binding classification. -/// -internal readonly struct ParameterMeta( - string name, - string typeFqn, - RoutePrimitiveKind? routeKind, - ParameterFlags flags, - SpecialParameterKind specialKind, - string? serviceKey, - string boundName, - string? collectionItemTypeFqn, - RoutePrimitiveKind? collectionItemPrimitiveKind, - CustomBindingMethod customBinding, - EmptyBodyBehavior emptyBodyBehavior = EmptyBodyBehavior.Default, - EquatableArray validatableProperties = default) -{ - public string Name { get; } = name; - public string TypeFqn { get; } = typeFqn; - public RoutePrimitiveKind? RouteKind { get; } = routeKind; - public ParameterFlags Flags { get; } = flags; - public SpecialParameterKind SpecialKind { get; } = specialKind; - public string? ServiceKey { get; } = serviceKey; - public string BoundName { get; } = boundName; - public string? CollectionItemTypeFqn { get; } = collectionItemTypeFqn; - public RoutePrimitiveKind? CollectionItemPrimitiveKind { get; } = collectionItemPrimitiveKind; - public CustomBindingMethod CustomBinding { get; } = customBinding; - public EmptyBodyBehavior EmptyBodyBehavior { get; } = emptyBodyBehavior; - public EquatableArray ValidatableProperties { get; } = validatableProperties; - - public bool HasFromBody => Flags.HasFlag(ParameterFlags.FromBody); - public bool HasFromRoute => Flags.HasFlag(ParameterFlags.FromRoute); - public bool HasFromQuery => Flags.HasFlag(ParameterFlags.FromQuery); - public bool HasFromHeader => Flags.HasFlag(ParameterFlags.FromHeader); - public bool HasFromForm => Flags.HasFlag(ParameterFlags.FromForm); - public bool HasFromServices => Flags.HasFlag(ParameterFlags.FromServices); - public bool HasFromKeyedServices => Flags.HasFlag(ParameterFlags.FromKeyedServices); - public bool HasAsParameters => Flags.HasFlag(ParameterFlags.AsParameters); - public bool IsNullable => Flags.HasFlag(ParameterFlags.Nullable); - public bool IsNonNullableValueType => Flags.HasFlag(ParameterFlags.NonNullableValueType); - public bool IsCollection => Flags.HasFlag(ParameterFlags.Collection); - public bool RequiresValidation => Flags.HasFlag(ParameterFlags.RequiresValidation); - - public bool IsHttpContext => SpecialKind == SpecialParameterKind.HttpContext; - public bool IsCancellationToken => SpecialKind == SpecialParameterKind.CancellationToken; - public bool IsFormFile => SpecialKind == SpecialParameterKind.FormFile; - public bool IsFormFileCollection => SpecialKind == SpecialParameterKind.FormFileCollection; - public bool IsFormCollection => SpecialKind == SpecialParameterKind.FormCollection; - public bool IsStream => SpecialKind == SpecialParameterKind.Stream; - public bool IsPipeReader => SpecialKind == SpecialParameterKind.PipeReader; - - public bool HasExplicitBinding => (Flags & ( - ParameterFlags.FromBody | ParameterFlags.FromRoute | ParameterFlags.FromQuery | - ParameterFlags.FromHeader | ParameterFlags.FromForm | ParameterFlags.FromServices | - ParameterFlags.FromKeyedServices | ParameterFlags.AsParameters)) != ParameterFlags.None; -} - -/// -/// Represents a custom error detected via Error.Custom() call. -/// -internal readonly record struct CustomErrorInfo( - string ErrorCode); - -/// -/// Represents a [ProducesError] attribute on an endpoint method. -/// -internal readonly record struct ProducesErrorInfo( - int StatusCode); - -/// -/// Result of extracting the ErrorOr return type, including SSE detection. -/// -internal readonly record struct ErrorOrReturnTypeInfo( - string? SuccessTypeFqn, - bool IsAsync, - bool IsSse, - string? SseItemTypeFqn, - SuccessKind Kind, - string? IdPropertyName = null, - bool IsAnonymousType = false, - bool IsInaccessibleType = false, - string? InaccessibleTypeName = null, - string? InaccessibleTypeAccessibility = null, - bool IsTypeParameter = false, - string? TypeParameterName = null); - -/// -/// Pre-computed method-level analysis shared across multiple HTTP method attributes. -/// -internal readonly struct MethodAnalysis( - ErrorOrReturnTypeInfo returnInfo, - EquatableArray inferredErrorTypeNames, - EquatableArray inferredCustomErrors, - EquatableArray producesErrors, - bool isAcceptedResponse, - MiddlewareInfo middleware) -{ - public ErrorOrReturnTypeInfo ReturnInfo { get; } = returnInfo; - public EquatableArray InferredErrorTypeNames { get; } = inferredErrorTypeNames; - public EquatableArray InferredCustomErrors { get; } = inferredCustomErrors; - public EquatableArray ProducesErrors { get; } = producesErrors; - public bool IsAcceptedResponse { get; } = isAcceptedResponse; - public MiddlewareInfo Middleware { get; } = middleware; -} - -/// -/// SSE (Server-Sent Events) configuration for an endpoint. -/// -internal readonly record struct SseInfo( - bool IsSse, - string? SseItemTypeFqn); - -/// -/// Error inference results for an endpoint. -/// -internal readonly record struct ErrorInferenceInfo( - EquatableArray InferredErrorTypeNames, - EquatableArray InferredCustomErrors, - EquatableArray DeclaredProducesErrors); - -/// -/// Complete descriptor for an ErrorOr endpoint used for code generation. -/// -internal readonly record struct EndpointDescriptor( - HttpVerb HttpVerb, - string Pattern, - string SuccessTypeFqn, - SuccessKind SuccessKind, - bool IsAsync, - string HandlerContainingTypeFqn, - string HandlerMethodName, - EquatableArray HandlerParameters, - ErrorInferenceInfo ErrorInference, - SseInfo Sse = default, - bool IsAcceptedResponse = false, - string? LocationIdPropertyName = null, - MiddlewareInfo Middleware = default, - VersioningInfo Versioning = default, - RouteGroupInfo RouteGroup = default, - EquatableArray Metadata = default, - string? CustomHttpMethod = null) -{ - /// Gets the HTTP method string for emission (e.g., "GET", "POST", or custom like "CONNECT"). - public string HttpMethod => CustomHttpMethod ?? HttpVerb.ToHttpString(); - - /// - /// Returns true if any parameter is bound from body. - /// - public bool HasBodyParam - { - get - { - foreach (var p in HandlerParameters.AsImmutableArray()) - { - if (p.Source == ParameterSource.Body) - return true; - } - - return false; - } - } - - /// - /// Returns true if any parameter is bound from form-related sources. - /// - public bool HasFormParams - { - get - { - foreach (var p in HandlerParameters.AsImmutableArray()) - { - if (p.Source.IsFormRelated()) - return true; - } - - return false; - } - } - - /// - /// Returns true if endpoint has body or form binding (for OpenAPI and validation). - /// Uses single-pass enumeration to avoid multiple iterations. - /// - public bool HasBodyOrFormBinding - { - get - { - foreach (var p in HandlerParameters.AsImmutableArray()) - { - if (p.Source == ParameterSource.Body || p.Source.IsFormRelated()) - return true; - } - - return false; - } - } - - /// - /// Returns true if any parameter uses BindAsync custom binding. - /// - public bool HasBindAsyncParam - { - get - { - foreach (var p in HandlerParameters.AsImmutableArray()) - { - if (p.CustomBinding is CustomBindingMethod.BindAsync or CustomBindingMethod.BindAsyncWithParam) - return true; - } - - return false; - } - } - - /// - /// Returns true if any parameter requires DataAnnotations validation. - /// - public bool HasParameterValidation - { - get - { - foreach (var p in HandlerParameters.AsImmutableArray()) - { - if (p.RequiresValidation) - return true; - } - - return false; - } - } - - /// Gets metadata value by key, or null if not found. - public string? GetMetadata(string key) - { - foreach (var entry in Metadata.AsImmutableArray()) - { - if (entry.Key == key) - return entry.Value; - } - - return null; - } - - /// Returns true if metadata with the given key exists. - public bool HasMetadata(string key) - { - foreach (var entry in Metadata.AsImmutableArray()) - { - if (entry.Key == key) - return true; - } - - return false; - } -} - -/// -/// Success response information for OpenAPI metadata. -/// -internal readonly record struct SuccessResponseInfo( - string ResultTypeFqn, - int StatusCode, - bool HasBody, - string Factory); - -/// -/// Result of union type computation. -/// -internal readonly record struct UnionTypeResult( - bool CanUseUnion, - string ReturnTypeFqn, - EquatableArray ExplicitProduceCodes, - bool UsesValidationProblemFor400 = false); - -/// -/// Middleware configuration extracted from BCL attributes. -/// -internal readonly record struct MiddlewareInfo( - bool RequiresAuthorization, - EquatableArray AuthorizationPolicies, - bool AllowAnonymous, - bool EnableRateLimiting, - string? RateLimitingPolicy, - bool DisableRateLimiting, - bool EnableOutputCache, - string? OutputCachePolicy, - int? OutputCacheDuration, - bool EnableCors, - string? CorsPolicy, - bool DisableCors) -{ - public bool HasAny => - RequiresAuthorization || AllowAnonymous || - EnableRateLimiting || DisableRateLimiting || - EnableOutputCache || - EnableCors || DisableCors; -} - -/// -/// Information about a route parameter extracted from the route template. -/// -internal readonly record struct RouteParameterInfo( - string Name, - string? Constraint, - bool IsOptional, - bool IsCatchAll); - -/// -/// Information about a method parameter relevant to route binding validation. -/// -internal readonly record struct RouteMethodParameterInfo( - string Name, - string? BoundRouteName, - string? TypeFqn, - bool IsNullable); - -/// -/// Result of parameter binding analysis. -/// -internal readonly record struct ParameterBindingResult(bool IsValid, EquatableArray Parameters) -{ - public static readonly ParameterBindingResult Empty = new(IsValid: true, default); - public static readonly ParameterBindingResult Invalid = new(IsValid: false, default); -} - -/// -/// Information about a user-defined JsonSerializerContext. -/// -internal readonly record struct JsonContextInfo( - string ClassName, - string? Namespace, - EquatableArray SerializableTypes, - bool HasCamelCasePolicy); - -/// -/// Represents a parameter for OpenAPI documentation. -/// -internal readonly record struct OpenApiParameterInfo( - string Name, - string Location, - bool Required, - string SchemaType, - string? SchemaFormat); - -/// -/// Immutable endpoint info for OpenAPI generation. -/// -internal readonly record struct OpenApiEndpointInfo( - string OperationId, - string TagName, - string? Summary, - string? Description, - string HttpMethod, - string Pattern, - EquatableArray<(string ParamName, string Description)> ParameterDocs, - EquatableArray Parameters); - -/// -/// Immutable type metadata for schema generation. -/// -internal readonly record struct TypeMetadataInfo( - string TypeKey, - string Description); - -/// -/// Result of route binding analysis containing bound parameters and route-specific extraction. -/// -internal readonly record struct RouteBindingAnalysis( - EquatableArray Parameters, - EquatableArray RouteParameters); - -/// -/// Represents a single API version extracted from [ApiVersion] attribute. -/// -internal readonly record struct ApiVersionInfo( - int MajorVersion, - int? MinorVersion, - string? Status, - bool IsDeprecated); - -/// -/// API versioning configuration extracted from endpoint class or method. -/// -internal readonly record struct VersioningInfo( - EquatableArray SupportedVersions, - EquatableArray MappedVersions, - bool IsVersionNeutral) -{ - /// - /// Returns true if any versioning attributes were found. - /// - public bool HasVersioning => !SupportedVersions.IsDefaultOrEmpty || IsVersionNeutral; - - /// - /// Returns the versions this endpoint should be mapped to. - /// Uses MappedVersions if specified, otherwise falls back to SupportedVersions. - /// - public EquatableArray EffectiveVersions => - MappedVersions.IsDefaultOrEmpty ? SupportedVersions : MappedVersions; -} - -/// -/// Route group configuration extracted from [RouteGroup] attribute on containing type. -/// -internal readonly record struct RouteGroupInfo( - string? GroupPath, - string? ApiName) -{ - /// - /// Returns true if route grouping is enabled for this endpoint. - /// - public bool HasRouteGroup => GroupPath is not null; -} - -/// -/// A named argument literal for a validation attribute (e.g., MinimumLength = 1). -/// -internal readonly record struct NamedArgLiteral(string Name, string Value); - -/// -/// Represents a validation attribute extracted from a property for the IValidatableInfoResolver emitter. -/// -internal readonly record struct ValidatableAttributeInfo( - string AttributeTypeFqn, - EquatableArray ConstructorArgLiterals, - EquatableArray NamedArgLiterals); - -/// -/// Represents a property on a validatable type, with its validation attribute metadata. -/// -internal readonly record struct ValidatablePropertyDescriptor( - string Name, - string TypeFqn, - string DisplayName, - EquatableArray ValidationAttributes); - -/// -/// Represents a type that requires validation, along with its validatable properties. -/// -internal readonly record struct ValidatableTypeDescriptor( - string TypeFqn, - EquatableArray Properties); - -/// -/// Well-known metadata key constants. -/// -internal static class MetadataKeys -{ - public const string Deprecated = "erroror:deprecated"; - public const string DeprecatedMessage = "erroror:deprecated-message"; - public const string OpenApiExtension = "openapi:x-"; - public const string CustomTag = "openapi:tag"; -} diff --git a/src/ErrorOrX.Generators/Models/EndpointParameters.cs b/src/ErrorOrX.Generators/Models/EndpointParameters.cs new file mode 100644 index 0000000..23fa7a2 --- /dev/null +++ b/src/ErrorOrX.Generators/Models/EndpointParameters.cs @@ -0,0 +1,175 @@ +namespace ErrorOr.Generators; + +/// +/// Represents the custom binding method detected on a parameter type. +/// +internal enum CustomBindingMethod +{ + None, + TryParse, + TryParseWithFormat, + BindAsync, + BindAsyncWithParam, + Bindable +} + +/// +/// Primitive types that can be bound from route templates. +/// +internal enum RoutePrimitiveKind +{ + String, + Int32, + Int64, + Int16, + UInt32, + UInt64, + UInt16, + Byte, + SByte, + Boolean, + Decimal, + Double, + Single, + Guid, + DateTime, + DateTimeOffset, + DateOnly, + TimeOnly, + TimeSpan +} + +/// +/// Flags for parameter binding characteristics. +/// +[Flags] +internal enum ParameterFlags +{ + None = 0, + FromServices = 1 << 0, + FromKeyedServices = 1 << 1, + FromBody = 1 << 2, + FromRoute = 1 << 3, + FromQuery = 1 << 4, + FromHeader = 1 << 5, + FromForm = 1 << 6, + AsParameters = 1 << 7, + Nullable = 1 << 8, + NonNullableValueType = 1 << 9, + Collection = 1 << 10, + RequiresValidation = 1 << 11 +} + +/// +/// Special parameter kinds that have dedicated binding. +/// +internal enum SpecialParameterKind +{ + None, + HttpContext, + CancellationToken, + FormFile, + FormFileCollection, + FormCollection, + Stream, + PipeReader +} + +/// +/// Specifies how empty request bodies should be handled. +/// +internal enum EmptyBodyBehavior +{ + /// Framework default: Nullable allows empty, non-nullable rejects. + Default, + + /// Empty bodies are valid (null/default assigned). + Allow, + + /// Empty bodies are invalid (400 Bad Request). + Disallow +} + +/// +/// Represents a bound endpoint parameter with its source and type information. +/// +internal readonly record struct EndpointParameter( + string Name, + string TypeFqn, + ParameterSource Source, + string? KeyName, + bool IsNullable, + bool IsNonNullableValueType, + bool IsCollection, + string? CollectionItemTypeFqn, + EquatableArray Children, + CustomBindingMethod CustomBinding = CustomBindingMethod.None, + bool RequiresValidation = false, + EmptyBodyBehavior EmptyBodyBehavior = EmptyBodyBehavior.Default, + EquatableArray ValidatableProperties = default); + +/// +/// Raw metadata extracted from a method parameter for binding classification. +/// +internal readonly struct ParameterMeta( + string name, + string typeFqn, + RoutePrimitiveKind? routeKind, + ParameterFlags flags, + SpecialParameterKind specialKind, + string? serviceKey, + string boundName, + string? collectionItemTypeFqn, + RoutePrimitiveKind? collectionItemPrimitiveKind, + CustomBindingMethod customBinding, + EmptyBodyBehavior emptyBodyBehavior = EmptyBodyBehavior.Default, + EquatableArray validatableProperties = default) +{ + public string Name { get; } = name; + public string TypeFqn { get; } = typeFqn; + public RoutePrimitiveKind? RouteKind { get; } = routeKind; + public ParameterFlags Flags { get; } = flags; + public SpecialParameterKind SpecialKind { get; } = specialKind; + public string? ServiceKey { get; } = serviceKey; + public string BoundName { get; } = boundName; + public string? CollectionItemTypeFqn { get; } = collectionItemTypeFqn; + public RoutePrimitiveKind? CollectionItemPrimitiveKind { get; } = collectionItemPrimitiveKind; + public CustomBindingMethod CustomBinding { get; } = customBinding; + public EmptyBodyBehavior EmptyBodyBehavior { get; } = emptyBodyBehavior; + public EquatableArray ValidatableProperties { get; } = validatableProperties; + + public bool HasFromBody => Flags.HasFlag(ParameterFlags.FromBody); + public bool HasFromRoute => Flags.HasFlag(ParameterFlags.FromRoute); + public bool HasFromQuery => Flags.HasFlag(ParameterFlags.FromQuery); + public bool HasFromHeader => Flags.HasFlag(ParameterFlags.FromHeader); + public bool HasFromForm => Flags.HasFlag(ParameterFlags.FromForm); + public bool HasFromServices => Flags.HasFlag(ParameterFlags.FromServices); + public bool HasFromKeyedServices => Flags.HasFlag(ParameterFlags.FromKeyedServices); + public bool HasAsParameters => Flags.HasFlag(ParameterFlags.AsParameters); + public bool IsNullable => Flags.HasFlag(ParameterFlags.Nullable); + public bool IsNonNullableValueType => Flags.HasFlag(ParameterFlags.NonNullableValueType); + public bool IsCollection => Flags.HasFlag(ParameterFlags.Collection); + public bool RequiresValidation => Flags.HasFlag(ParameterFlags.RequiresValidation); + + public bool IsHttpContext => SpecialKind == SpecialParameterKind.HttpContext; + public bool IsCancellationToken => SpecialKind == SpecialParameterKind.CancellationToken; + public bool IsFormFile => SpecialKind == SpecialParameterKind.FormFile; + public bool IsFormFileCollection => SpecialKind == SpecialParameterKind.FormFileCollection; + public bool IsFormCollection => SpecialKind == SpecialParameterKind.FormCollection; + public bool IsStream => SpecialKind == SpecialParameterKind.Stream; + public bool IsPipeReader => SpecialKind == SpecialParameterKind.PipeReader; + + public bool HasExplicitBinding => (Flags & ( + ParameterFlags.FromBody | ParameterFlags.FromRoute | ParameterFlags.FromQuery | + ParameterFlags.FromHeader | ParameterFlags.FromForm | ParameterFlags.FromServices | + ParameterFlags.FromKeyedServices | ParameterFlags.AsParameters)) != ParameterFlags.None; +} + +/// +/// Result of parameter binding analysis. +/// +internal readonly record struct ParameterBindingResult(bool IsValid, EquatableArray Parameters) +{ + public static readonly ParameterBindingResult Empty = new(IsValid: true, default); + public static readonly ParameterBindingResult Invalid = new(IsValid: false, default); +} diff --git a/src/ErrorOrX.Generators/Models/MethodAnalysis.cs b/src/ErrorOrX.Generators/Models/MethodAnalysis.cs new file mode 100644 index 0000000..461f244 --- /dev/null +++ b/src/ErrorOrX.Generators/Models/MethodAnalysis.cs @@ -0,0 +1,64 @@ +namespace ErrorOr.Generators; + +/// +/// Represents a custom error detected via Error.Custom() call. +/// +internal readonly record struct CustomErrorInfo( + string ErrorCode); + +/// +/// Represents a [ProducesError] attribute on an endpoint method. +/// +internal readonly record struct ProducesErrorInfo( + int StatusCode); + +/// +/// Result of extracting the ErrorOr return type, including SSE detection. +/// +internal readonly record struct ErrorOrReturnTypeInfo( + string? SuccessTypeFqn, + bool IsAsync, + bool IsSse, + string? SseItemTypeFqn, + SuccessKind Kind, + string? IdPropertyName = null, + bool IsAnonymousType = false, + bool IsInaccessibleType = false, + string? InaccessibleTypeName = null, + string? InaccessibleTypeAccessibility = null, + bool IsTypeParameter = false, + string? TypeParameterName = null); + +/// +/// Pre-computed method-level analysis shared across multiple HTTP method attributes. +/// +internal readonly struct MethodAnalysis( + ErrorOrReturnTypeInfo returnInfo, + EquatableArray inferredErrorTypeNames, + EquatableArray inferredCustomErrors, + EquatableArray producesErrors, + bool isAcceptedResponse, + MiddlewareInfo middleware) +{ + public ErrorOrReturnTypeInfo ReturnInfo { get; } = returnInfo; + public EquatableArray InferredErrorTypeNames { get; } = inferredErrorTypeNames; + public EquatableArray InferredCustomErrors { get; } = inferredCustomErrors; + public EquatableArray ProducesErrors { get; } = producesErrors; + public bool IsAcceptedResponse { get; } = isAcceptedResponse; + public MiddlewareInfo Middleware { get; } = middleware; +} + +/// +/// SSE (Server-Sent Events) configuration for an endpoint. +/// +internal readonly record struct SseInfo( + bool IsSse, + string? SseItemTypeFqn); + +/// +/// Error inference results for an endpoint. +/// +internal readonly record struct ErrorInferenceInfo( + EquatableArray InferredErrorTypeNames, + EquatableArray InferredCustomErrors, + EquatableArray DeclaredProducesErrors); diff --git a/src/ErrorOrX.Generators/Models/MiddlewareInfo.cs b/src/ErrorOrX.Generators/Models/MiddlewareInfo.cs new file mode 100644 index 0000000..b3ade9f --- /dev/null +++ b/src/ErrorOrX.Generators/Models/MiddlewareInfo.cs @@ -0,0 +1,25 @@ +namespace ErrorOr.Generators; + +/// +/// Middleware configuration extracted from BCL attributes. +/// +internal readonly record struct MiddlewareInfo( + bool RequiresAuthorization, + EquatableArray AuthorizationPolicies, + bool AllowAnonymous, + bool EnableRateLimiting, + string? RateLimitingPolicy, + bool DisableRateLimiting, + bool EnableOutputCache, + string? OutputCachePolicy, + int? OutputCacheDuration, + bool EnableCors, + string? CorsPolicy, + bool DisableCors) +{ + public bool HasAny => + RequiresAuthorization || AllowAnonymous || + EnableRateLimiting || DisableRateLimiting || + EnableOutputCache || + EnableCors || DisableCors; +} diff --git a/src/ErrorOrX.Generators/Models/OpenApiModels.cs b/src/ErrorOrX.Generators/Models/OpenApiModels.cs new file mode 100644 index 0000000..3797ecc --- /dev/null +++ b/src/ErrorOrX.Generators/Models/OpenApiModels.cs @@ -0,0 +1,40 @@ +namespace ErrorOr.Generators; + +/// +/// Information about a user-defined JsonSerializerContext. +/// +internal readonly record struct JsonContextInfo( + string ClassName, + string? Namespace, + EquatableArray SerializableTypes, + bool HasCamelCasePolicy); + +/// +/// Represents a parameter for OpenAPI documentation. +/// +internal readonly record struct OpenApiParameterInfo( + string Name, + string Location, + bool Required, + string SchemaType, + string? SchemaFormat); + +/// +/// Immutable endpoint info for OpenAPI generation. +/// +internal readonly record struct OpenApiEndpointInfo( + string OperationId, + string TagName, + string? Summary, + string? Description, + string HttpMethod, + string Pattern, + EquatableArray<(string ParamName, string Description)> ParameterDocs, + EquatableArray Parameters); + +/// +/// Immutable type metadata for schema generation. +/// +internal readonly record struct TypeMetadataInfo( + string TypeKey, + string Description); diff --git a/src/ErrorOrX.Generators/Models/RouteModels.cs b/src/ErrorOrX.Generators/Models/RouteModels.cs new file mode 100644 index 0000000..0c69e34 --- /dev/null +++ b/src/ErrorOrX.Generators/Models/RouteModels.cs @@ -0,0 +1,39 @@ +namespace ErrorOr.Generators; + +/// +/// Information about a route parameter extracted from the route template. +/// +internal readonly record struct RouteParameterInfo( + string Name, + string? Constraint, + bool IsOptional, + bool IsCatchAll); + +/// +/// Information about a method parameter relevant to route binding validation. +/// +internal readonly record struct RouteMethodParameterInfo( + string Name, + string? BoundRouteName, + string? TypeFqn, + bool IsNullable); + +/// +/// Result of route binding analysis containing bound parameters and route-specific extraction. +/// +internal readonly record struct RouteBindingAnalysis( + EquatableArray Parameters, + EquatableArray RouteParameters); + +/// +/// Route group configuration extracted from [RouteGroup] attribute on containing type. +/// +internal readonly record struct RouteGroupInfo( + string? GroupPath, + string? ApiName) +{ + /// + /// Returns true if route grouping is enabled for this endpoint. + /// + public bool HasRouteGroup => GroupPath is not null; +} diff --git a/src/ErrorOrX.Generators/Models/ValidationModels.cs b/src/ErrorOrX.Generators/Models/ValidationModels.cs new file mode 100644 index 0000000..5c531b2 --- /dev/null +++ b/src/ErrorOrX.Generators/Models/ValidationModels.cs @@ -0,0 +1,30 @@ +namespace ErrorOr.Generators; + +/// +/// A named argument literal for a validation attribute (e.g., MinimumLength = 1). +/// +internal readonly record struct NamedArgLiteral(string Name, string Value); + +/// +/// Represents a validation attribute extracted from a property for the IValidatableInfoResolver emitter. +/// +internal readonly record struct ValidatableAttributeInfo( + string AttributeTypeFqn, + EquatableArray ConstructorArgLiterals, + EquatableArray NamedArgLiterals); + +/// +/// Represents a property on a validatable type, with its validation attribute metadata. +/// +internal readonly record struct ValidatablePropertyDescriptor( + string Name, + string TypeFqn, + string DisplayName, + EquatableArray ValidationAttributes); + +/// +/// Represents a type that requires validation, along with its validatable properties. +/// +internal readonly record struct ValidatableTypeDescriptor( + string TypeFqn, + EquatableArray Properties); diff --git a/src/ErrorOrX.Generators/Models/VersioningModels.cs b/src/ErrorOrX.Generators/Models/VersioningModels.cs new file mode 100644 index 0000000..94df949 --- /dev/null +++ b/src/ErrorOrX.Generators/Models/VersioningModels.cs @@ -0,0 +1,31 @@ +namespace ErrorOr.Generators; + +/// +/// Represents a single API version extracted from [ApiVersion] attribute. +/// +internal readonly record struct ApiVersionInfo( + int MajorVersion, + int? MinorVersion, + string? Status, + bool IsDeprecated); + +/// +/// API versioning configuration extracted from endpoint class or method. +/// +internal readonly record struct VersioningInfo( + EquatableArray SupportedVersions, + EquatableArray MappedVersions, + bool IsVersionNeutral) +{ + /// + /// Returns true if any versioning attributes were found. + /// + public bool HasVersioning => !SupportedVersions.IsDefaultOrEmpty || IsVersionNeutral; + + /// + /// Returns the versions this endpoint should be mapped to. + /// Uses MappedVersions if specified, otherwise falls back to SupportedVersions. + /// + public EquatableArray EffectiveVersions => + MappedVersions.IsDefaultOrEmpty ? SupportedVersions : MappedVersions; +} diff --git a/src/ErrorOrX.Generators/OpenApiTransformerGenerator.cs b/src/ErrorOrX.Generators/OpenApiTransformerGenerator.cs index 943e05f..755380b 100644 --- a/src/ErrorOrX.Generators/OpenApiTransformerGenerator.cs +++ b/src/ErrorOrX.Generators/OpenApiTransformerGenerator.cs @@ -1,7 +1,6 @@ using ANcpLua.Roslyn.Utilities; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp.Syntax; -using Microsoft.CodeAnalysis.Text; namespace ErrorOr.Generators; @@ -9,7 +8,7 @@ namespace ErrorOr.Generators; /// Generates OpenAPI transformers from XML documentation on ErrorOr endpoints. /// [Generator(LanguageNames.CSharp)] -public sealed class OpenApiTransformerGenerator : IIncrementalGenerator +public sealed partial class OpenApiTransformerGenerator : IIncrementalGenerator { /// public void Initialize(IncrementalGeneratorInitializationContext context) @@ -56,59 +55,6 @@ private static IncrementalValuesProvider CreateEndpointProv .WhereNotNull(); } - private static OpenApiEndpointInfo? ExtractOpenApiMetadata( - GeneratorAttributeSyntaxContext ctx, - CancellationToken ct) - { - ct.ThrowIfCancellationRequested(); - - if (ctx.TargetSymbol is not IMethodSymbol { IsStatic: true } method) return null; - - // Extract HTTP method and pattern from attribute - // Combined null check: attr exists AND has a valid AttributeClass - if (ctx.Attributes.FirstOrDefault() is not { AttributeClass: { } attrClass } attr) return null; - - var attrClassName = attrClass.ToDisplayString(); - - var (httpMethod, pattern) = attrClassName switch - { - WellKnownTypes.GetAttribute => (WellKnownTypes.HttpMethod.Get, GetPattern(attr)), - WellKnownTypes.PostAttribute => (WellKnownTypes.HttpMethod.Post, GetPattern(attr)), - WellKnownTypes.PutAttribute => (WellKnownTypes.HttpMethod.Put, GetPattern(attr)), - WellKnownTypes.DeleteAttribute => (WellKnownTypes.HttpMethod.Delete, GetPattern(attr)), - WellKnownTypes.PatchAttribute => (WellKnownTypes.HttpMethod.Patch, GetPattern(attr)), - WellKnownTypes.HeadAttribute => (WellKnownTypes.HttpMethod.Head, GetPattern(attr)), - WellKnownTypes.OptionsAttribute => (WellKnownTypes.HttpMethod.Options, GetPattern(attr)), - WellKnownTypes.TraceAttribute => (WellKnownTypes.HttpMethod.Trace, GetPattern(attr)), - WellKnownTypes.ErrorOrEndpointAttribute => GetBaseAttributeInfo(attr), - _ => (null, null) - }; - - if (httpMethod is null || pattern is null) return null; - - // Extract XML documentation - var xmlDoc = method.GetDocumentationCommentXml(cancellationToken: ct); - var (summary, description) = ParseXmlDoc(xmlDoc); - var parameterDocs = ParseParamTags(xmlDoc); - - // Extract containing type info for tag generation - var containingType = method.ContainingType; - var containingTypeFqn = containingType.GetFullyQualifiedName(); - var (tagName, operationId) = EndpointNameHelper.GetEndpointIdentity(containingTypeFqn, method.Name); - - var parameters = ExtractParameterDefinitions(method, pattern); - - return new OpenApiEndpointInfo( - operationId, - tagName, - summary, - description, - httpMethod.ToUpperInvariant(), - pattern, - new EquatableArray<(string, string)>(parameterDocs), - new EquatableArray(parameters)); - } - private static string GetPattern(AttributeData attr) { if (attr.GetConstructorArgument(0) is { } p && @@ -129,682 +75,4 @@ private static (string? httpMethod, string? pattern) GetBaseAttributeInfo(Attrib ? (null, null) : (method, string.IsNullOrWhiteSpace(pattern) ? "/" : pattern); } - - private static (string? summary, string? description) ParseXmlDoc(string? xml) - { - if (xml is null || string.IsNullOrWhiteSpace(xml)) return (null, null); - - string? summary = null; - string? description = null; - - // Simple XML parsing for summary and remarks - var summaryStart = xml.IndexOfOrdinal(""); - var summaryEnd = xml.IndexOfOrdinal(""); - if (summaryStart >= 0 && summaryEnd > summaryStart) - { - summary = xml.Substring(summaryStart + 9, summaryEnd - summaryStart - 9) - .Trim() - .Replace("\r\n", " ") - .Replace('\n', ' ') - .Trim(); - } - - var remarksStart = xml.IndexOfOrdinal(""); - var remarksEnd = xml.IndexOfOrdinal(""); - if (remarksStart >= 0 && remarksEnd > remarksStart) - { - description = xml.Substring(remarksStart + 9, remarksEnd - remarksStart - 9) - .Trim() - .Replace("\r\n", " ") - .Replace('\n', ' ') - .Trim(); - } - - return (summary, description); - } - - private static ImmutableArray<(string ParamName, string Description)> ParseParamTags(string? xml) - { - if (xml is null || string.IsNullOrWhiteSpace(xml)) return ImmutableArray<(string, string)>.Empty; - - var parameters = new List<(string, string)>(); - var searchPos = 0; - - while (true) - { - var paramStart = xml.IndexOf("", nameEnd, StringComparison.Ordinal); - if (contentStart < 0) break; - - contentStart++; - - var contentEnd = xml.IndexOf("", contentStart, StringComparison.Ordinal); - if (contentEnd < 0) break; - - var description = xml.Substring(contentStart, contentEnd - contentStart) - .Trim() - .Replace("\r\n", " ") - .Replace("\n", " ") - .Trim(); - if (!string.IsNullOrWhiteSpace(description)) parameters.Add((paramName, description)); - - searchPos = contentEnd + 8; - } - - return [.. parameters]; - } - - private static ImmutableArray ExtractParameterDefinitions( - IMethodSymbol method, string pattern) - { - var routeParams = RouteValidator.ExtractRouteParameters(pattern); - var routeParamNames = new HashSet(StringComparer.OrdinalIgnoreCase); - foreach (var rp in routeParams) - routeParamNames.Add(rp.Name); - - var parameters = new List(); - - foreach (var param in method.Parameters) - { - var typeFqn = param.Type.ToDisplayString(); - - // Skip special types (services, context, etc.) - if (IsSkippedParameterType(param, typeFqn)) continue; - - // Check explicit binding attributes - var (explicitLocation, explicitName) = GetExplicitBinding(param); - - string name; - string location; - bool required; - - if (explicitLocation is not null) - { - // Explicit attribute wins - name = explicitName ?? param.Name; - location = explicitLocation; - required = location == "path" || - (param.Type.NullableAnnotation != NullableAnnotation.Annotated && - !param.HasExplicitDefaultValue); - } - else if (routeParamNames.Contains(param.Name)) - { - // Route parameter - name = param.Name; - location = "path"; - required = true; - - // Check if optional in route template - foreach (var rp in routeParams) - { - if (string.Equals(rp.Name, param.Name, StringComparison.OrdinalIgnoreCase) && rp.IsOptional) - { - required = false; - break; - } - } - } - else if (IsPrimitiveType(typeFqn)) - { - // Primitive not in route = query - name = param.Name; - location = "query"; - required = param.Type.NullableAnnotation != NullableAnnotation.Annotated && - !param.HasExplicitDefaultValue; - } - else - { - // Complex type without explicit binding - skip (it's body or service) - continue; - } - - var (schemaType, schemaFormat) = GetOpenApiSchema(typeFqn); - parameters.Add(new OpenApiParameterInfo(name, location, required, schemaType, schemaFormat)); - } - - return [.. parameters]; - } - - private static bool IsSkippedParameterType(IParameterSymbol param, string typeFqn) - { - // Skip special framework types - if (typeFqn is WellKnownTypes.HttpContext or WellKnownTypes.CancellationToken - or WellKnownTypes.FormFile or WellKnownTypes.FormFileCollection - or WellKnownTypes.Stream or WellKnownTypes.PipeReader or WellKnownTypes.FormCollection) - { - return true; - } - - // Skip interface types (services) - if (param.Type.TypeKind == TypeKind.Interface) return true; - - // Skip abstract types (services) - if (param.Type is { IsAbstract: true, TypeKind: TypeKind.Class }) return true; - - // Skip [FromServices] / [FromKeyedServices] / [FromBody] / [FromForm] - foreach (var attr in param.GetAttributes()) - { - var attrName = attr.AttributeClass?.ToDisplayString(); - if (attrName is WellKnownTypes.FromServicesAttribute or WellKnownTypes.FromBodyAttribute - or WellKnownTypes.FromFormAttribute or WellKnownTypes.FromKeyedServicesAttribute) - { - return true; - } - } - - return false; - } - - private static (string? Location, string? Name) GetExplicitBinding(ISymbol param) - { - foreach (var attr in param.GetAttributes()) - { - var attrName = attr.AttributeClass?.ToDisplayString(); - switch (attrName) - { - case WellKnownTypes.FromRouteAttribute: - { - var name = GetAttributeStringArg(attr, "Name"); - return ("path", name); - } - case WellKnownTypes.FromQueryAttribute: - { - var name = GetAttributeStringArg(attr, "Name"); - return ("query", name); - } - case WellKnownTypes.FromHeaderAttribute: - { - var name = GetAttributeStringArg(attr, "Name"); - return ("header", name); - } - } - } - - return (null, null); - } - - private static string? GetAttributeStringArg(AttributeData attr, string propName) - { - foreach (var kvp in attr.NamedArguments) - { - if (kvp.Key == propName && kvp.Value.Value is string s && !string.IsNullOrWhiteSpace(s)) - return s; - } - - return null; - } - - private static bool IsPrimitiveType(string typeFqn) - { - // Strip nullable wrapper - var type = typeFqn.EndsWithOrdinal("?") ? typeFqn.Substring(0, typeFqn.Length - 1) : typeFqn; - - return type is "int" or "System.Int32" - or "long" or "System.Int64" - or "short" or "System.Int16" - or "uint" or "System.UInt32" - or "ulong" or "System.UInt64" - or "ushort" or "System.UInt16" - or "byte" or "System.Byte" - or "sbyte" or "System.SByte" - or "bool" or "System.Boolean" - or "decimal" or "System.Decimal" - or "double" or "System.Double" - or "float" or "System.Single" - or "string" or "System.String" - or "System.Guid" - or "System.DateTime" - or "System.DateTimeOffset" - or "System.DateOnly" - or "System.TimeOnly" - or "System.TimeSpan"; - } - - private static (string SchemaType, string? SchemaFormat) GetOpenApiSchema(string typeFqn) - { - // Strip nullable wrapper - var type = typeFqn.EndsWithOrdinal("?") ? typeFqn.Substring(0, typeFqn.Length - 1) : typeFqn; - - return type switch - { - "int" or "System.Int32" => ("integer", "int32"), - "long" or "System.Int64" => ("integer", "int64"), - "short" or "System.Int16" => ("integer", "int16"), - "uint" or "System.UInt32" => ("integer", "int32"), - "ulong" or "System.UInt64" => ("integer", "int64"), - "ushort" or "System.UInt16" => ("integer", "int16"), - "byte" or "System.Byte" => ("integer", "int32"), - "sbyte" or "System.SByte" => ("integer", "int32"), - "bool" or "System.Boolean" => ("boolean", null), - "decimal" or "System.Decimal" => ("number", "double"), - "double" or "System.Double" => ("number", "double"), - "float" or "System.Single" => ("number", "float"), - "System.Guid" => ("string", "uuid"), - "System.DateTime" => ("string", "date-time"), - "System.DateTimeOffset" => ("string", "date-time"), - "System.DateOnly" => ("string", "date"), - "System.TimeOnly" => ("string", "time"), - "System.TimeSpan" => ("string", "duration"), - _ => ("string", null) - }; - } - - /// - /// Maps internal schema type string to OpenApi v2.0 JsonSchemaType enum name for emission. - /// - private static string ToJsonSchemaTypeEnum(string schemaType) - { - return schemaType switch - { - "integer" => "JsonSchemaType.Integer", - "number" => "JsonSchemaType.Number", - "boolean" => "JsonSchemaType.Boolean", - _ => "JsonSchemaType.String" - }; - } - - private static string GetReflectionFullName(ISymbol symbol) - { - var fqn = ((ITypeSymbol)symbol).GetFullyQualifiedName(); - return fqn.StartsWithOrdinal("global::") ? fqn.Substring("global::".Length) : fqn; - } - - private static TypeMetadataInfo? ExtractTypeMetadata( - GeneratorSyntaxContext ctx, - CancellationToken ct) - { - if (ctx.Node is not TypeDeclarationSyntax typeDecl) return null; - - // Skip null symbols and compiler-generated types - if (ctx.SemanticModel.GetDeclaredSymbol(typeDecl, ct) is not INamedTypeSymbol symbol || - symbol.IsImplicitlyDeclared) - { - return null; - } - - // Skip types without XML docs - var xmlDoc = symbol.GetDocumentationCommentXml(cancellationToken: ct); - if (string.IsNullOrWhiteSpace(xmlDoc)) return null; - - var (summary, _) = ParseXmlDoc(xmlDoc); - if (summary is null) return null; - - var typeKey = GetReflectionFullName(symbol); - - return new TypeMetadataInfo(typeKey, summary); - } - - private static void Emit( - SourceProductionContext spc, - ImmutableArray endpoints, - ImmutableArray types) - { - if (endpoints.IsDefaultOrEmpty) return; - - var code = new StringBuilder(); - code.AppendLine("// "); - code.AppendLine("#nullable enable"); - code.AppendLine(); - code.AppendLine("using System;"); - code.AppendLine("using System.Collections.Frozen;"); - code.AppendLine("using System.Collections.Generic;"); - code.AppendLine("using System.Threading;"); - code.AppendLine("using System.Threading.Tasks;"); - code.AppendLine("using Microsoft.AspNetCore.OpenApi;"); - code.AppendLine("using Microsoft.AspNetCore.Routing;"); - code.AppendLine("using Microsoft.Extensions.DependencyInjection;"); - code.AppendLine("using Microsoft.OpenApi;"); - code.AppendLine(); - code.AppendLine("namespace ErrorOr.Generated;"); - code.AppendLine(); - - // Collect unique tags (1 attribute → 1 transformer) - var tags = endpoints.Select(static e => e.TagName).Distinct(StringComparer.Ordinal) - .OrderBy(static t => t, StringComparer.Ordinal).ToList(); - - // Emit tag transformers (strict 1:1 - one transformer per unique tag) - foreach (var tag in tags) EmitTagTransformer(code, tag); - - // Emit operation transformer (applies XML doc summaries) - var hasOperationDocs = EmitOperationTransformer(code, endpoints); - - // Emit schema transformer (applies type descriptions) - var hasTypeDocs = false; - if (!types.IsDefaultOrEmpty) hasTypeDocs = EmitSchemaTransformer(code, types); - - // Emit registration extension - EmitRegistrationExtension(code, tags, hasOperationDocs, hasTypeDocs); - - spc.AddSource("OpenApiTransformers.g.cs", SourceText.From(code.ToString(), Encoding.UTF8)); - } - - private static void EmitTagTransformer(StringBuilder code, string tagName) - { - var safeTagName = tagName.SanitizeIdentifier(); - code.AppendLine("/// "); - code.AppendLine($"/// Document transformer for tag: {tagName}"); - code.AppendLine($"/// Generated from: [ErrorOrEndpoint] attribute on *{tagName}Endpoints class"); - code.AppendLine("/// "); - code.AppendLine($"file sealed class Tag_{safeTagName}_Transformer : IOpenApiDocumentTransformer"); - code.AppendLine("{"); - code.AppendLine(" public Task TransformAsync("); - code.AppendLine(" OpenApiDocument document,"); - code.AppendLine(" OpenApiDocumentTransformerContext context,"); - code.AppendLine(" CancellationToken cancellationToken)"); - code.AppendLine(" {"); - // OpenApiDocument.Tags setter auto-wraps with OpenApiTagComparer.Instance - // which handles deduplication by Name - no manual .Any() check needed - code.AppendLine(" document.Tags ??= new HashSet();"); - code.AppendLine($" document.Tags.Add(new OpenApiTag {{ Name = \"{tagName}\" }});"); - code.AppendLine(" return Task.CompletedTask;"); - code.AppendLine(" }"); - code.AppendLine("}"); - code.AppendLine(); - } - - private static bool EmitOperationTransformer(StringBuilder code, ImmutableArray endpoints) - { - // Collect operations with XML docs (summary/description OR parameter docs) - var opsWithDocs = endpoints - .Where(static e => !string.IsNullOrEmpty(e.Summary) || !string.IsNullOrEmpty(e.Description) || - !e.ParameterDocs.IsDefaultOrEmpty) - .OrderBy(static e => e.Pattern, StringComparer.Ordinal) - .ThenBy(static e => e.HttpMethod, StringComparer.Ordinal).ToList(); - - // Collect operations with OpenAPI parameter definitions - var opsWithParams = endpoints - .Where(static e => !e.Parameters.IsDefaultOrEmpty) - .OrderBy(static e => e.OperationId, StringComparer.Ordinal) - .ToList(); - - if (opsWithDocs.Count is 0 && opsWithParams.Count is 0) return false; - - // Collect operations with parameter docs - var opsWithParamDocs = opsWithDocs - .Where(static e => !e.ParameterDocs.IsDefaultOrEmpty) - .ToList(); - - code.AppendLine("/// "); - code.AppendLine( - "/// Operation transformer that applies XML documentation and parameter definitions to operations."); - code.AppendLine("/// Each entry is a strict 1:1 mapping from handler signature to operation metadata."); - code.AppendLine("/// "); - code.AppendLine("file sealed class XmlDocOperationTransformer : IOpenApiOperationTransformer"); - code.AppendLine("{"); - code.AppendLine(" // Pre-computed metadata from XML docs (compile-time extraction)"); - code.AppendLine( - " private static readonly FrozenDictionary OperationDocs ="); - code.AppendLine(" new Dictionary"); - code.AppendLine(" {"); - - foreach (var op in opsWithDocs.Where(static e => - !string.IsNullOrEmpty(e.Summary) || !string.IsNullOrEmpty(e.Description))) - { - var summary = op.Summary is not null ? $"\"{op.Summary.EscapeCSharpString()}\"" : "null"; - var description = op.Description is not null ? $"\"{op.Description.EscapeCSharpString()}\"" : "null"; - code.AppendLine($" [\"{op.OperationId}\"] = ({summary}, {description}),"); - } - - code.AppendLine(" }.ToFrozenDictionary(StringComparer.Ordinal);"); - code.AppendLine(); - - // Emit parameter docs dictionary - code.AppendLine(" // Pre-computed parameter descriptions from XML tags"); - code.AppendLine( - " private static readonly FrozenDictionary> ParameterDocs ="); - code.AppendLine(" new Dictionary>"); - code.AppendLine(" {"); - - foreach (var op in opsWithParamDocs) - { - code.AppendLine($" [\"{op.OperationId}\"] = new Dictionary"); - code.AppendLine(" {"); - foreach (var (paramName, paramDesc) in op.ParameterDocs.AsImmutableArray()) - { - code.AppendLine( - $" [\"{paramName.EscapeCSharpString()}\"] = \"{paramDesc.EscapeCSharpString()}\","); - } - - code.AppendLine(" }.ToFrozenDictionary(StringComparer.Ordinal),"); - } - - code.AppendLine(" }.ToFrozenDictionary(StringComparer.Ordinal);"); - code.AppendLine(); - - // Emit parameter definitions dictionary - if (opsWithParams.Count > 0) - { - code.AppendLine(" // Pre-computed parameter definitions from handler signatures"); - code.AppendLine( - " private static readonly FrozenDictionary ParameterDefs ="); - code.AppendLine( - " new Dictionary"); - code.AppendLine(" {"); - - foreach (var op in opsWithParams) - { - code.Append($" [\"{op.OperationId}\"] = [("); - var first = true; - foreach (var p in op.Parameters.AsImmutableArray()) - { - if (!first) code.Append("), ("); - - var format = p.SchemaFormat is not null ? $"\"{p.SchemaFormat}\"" : "null"; - var locationEnum = p.Location switch - { - "path" => "ParameterLocation.Path", - "header" => "ParameterLocation.Header", - _ => "ParameterLocation.Query" - }; - var schemaTypeEnum = ToJsonSchemaTypeEnum(p.SchemaType); - code.Append( - $"\"{p.Name}\", {locationEnum}, {(p.Required ? "true" : "false")}, {schemaTypeEnum}, {format}"); - first = false; - } - - code.AppendLine(")],"); - } - - code.AppendLine(" }.ToFrozenDictionary(StringComparer.Ordinal);"); - code.AppendLine(); - } - - code.AppendLine(" public Task TransformAsync("); - code.AppendLine(" OpenApiOperation operation,"); - code.AppendLine(" OpenApiOperationTransformerContext context,"); - code.AppendLine(" CancellationToken cancellationToken)"); - code.AppendLine(" {"); - code.AppendLine(" string? operationId = null;"); - code.AppendLine(" var metadata = context.Description.ActionDescriptor?.EndpointMetadata;"); - code.AppendLine(" if (metadata is not null)"); - code.AppendLine(" {"); - code.AppendLine(" for (var i = 0; i < metadata.Count; i++)"); - code.AppendLine(" {"); - code.AppendLine(" if (metadata[i] is IEndpointNameMetadata nameMetadata)"); - code.AppendLine(" {"); - code.AppendLine(" operationId = nameMetadata.EndpointName;"); - code.AppendLine(" break;"); - code.AppendLine(" }"); - code.AppendLine(" }"); - code.AppendLine(" }"); - code.AppendLine(); - code.AppendLine(" if (operationId is null)"); - code.AppendLine(" return Task.CompletedTask;"); - code.AppendLine(); - code.AppendLine(" // Apply summary and description"); - code.AppendLine(" if (OperationDocs.TryGetValue(operationId, out var docs))"); - code.AppendLine(" {"); - code.AppendLine(" if (docs.Summary is not null)"); - code.AppendLine(" operation.Summary ??= docs.Summary;"); - code.AppendLine(" if (docs.Description is not null)"); - code.AppendLine(" operation.Description ??= docs.Description;"); - code.AppendLine(" }"); - code.AppendLine(); - - // Emit parameter definitions application code - if (opsWithParams.Count > 0) - { - code.AppendLine(" // Add parameter definitions from handler signatures"); - code.AppendLine(" if (ParameterDefs.TryGetValue(operationId, out var paramDefs))"); - code.AppendLine(" {"); - code.AppendLine(" operation.Parameters ??= [];"); - code.AppendLine( - " foreach (var (pName, pLocation, pRequired, pSchemaType, pSchemaFormat) in paramDefs)"); - code.AppendLine(" {"); - code.AppendLine(" var schema = new OpenApiSchema { Type = pSchemaType };"); - code.AppendLine(" if (pSchemaFormat is not null) schema.Format = pSchemaFormat;"); - code.AppendLine(" operation.Parameters.Add(new OpenApiParameter"); - code.AppendLine(" {"); - code.AppendLine(" Name = pName,"); - code.AppendLine(" In = pLocation,"); - code.AppendLine(" Required = pRequired,"); - code.AppendLine(" Schema = schema"); - code.AppendLine(" });"); - code.AppendLine(" }"); - code.AppendLine(" }"); - code.AppendLine(); - } - - code.AppendLine(" // Apply parameter descriptions"); - code.AppendLine( - " if (ParameterDocs.TryGetValue(operationId, out var paramDocs) && operation.Parameters is not null)"); - code.AppendLine(" {"); - code.AppendLine(" foreach (var param in operation.Parameters)"); - code.AppendLine(" {"); - code.AppendLine( - " if (param.Name is not null && paramDocs.TryGetValue(param.Name, out var paramDesc))"); - code.AppendLine(" {"); - code.AppendLine(" param.Description ??= paramDesc;"); - code.AppendLine(" }"); - code.AppendLine(" }"); - code.AppendLine(" }"); - code.AppendLine(); - code.AppendLine(" return Task.CompletedTask;"); - code.AppendLine(" }"); - code.AppendLine("}"); - code.AppendLine(); - - return true; - } - - private static bool EmitSchemaTransformer(StringBuilder code, ImmutableArray types) - { - var typesWithDocs = types.OrderBy(static t => t.TypeKey, StringComparer.Ordinal).ToList(); - - if (typesWithDocs.Count is 0) return false; - - code.AppendLine("/// "); - code.AppendLine("/// Schema transformer that applies type XML documentation to schemas."); - code.AppendLine("/// Each entry is a strict 1:1 mapping from XML doc to schema description."); - code.AppendLine("/// AOT-safe: Uses Type as dictionary key (no runtime reflection)."); - code.AppendLine("/// "); - code.AppendLine("file sealed class XmlDocSchemaTransformer : IOpenApiSchemaTransformer"); - code.AppendLine("{"); - code.AppendLine( - " // Pre-computed type descriptions from XML docs (AOT-safe: Type keys resolved at compile-time)"); - code.AppendLine(" private static readonly FrozenDictionary TypeDescriptions ="); - code.AppendLine(" new Dictionary"); - code.AppendLine(" {"); - - foreach (var type in typesWithDocs) - { - // Convert reflection-style name (Namespace.Outer+Inner) to C# typeof expression (global::Namespace.Outer.Inner) - var typeofExpr = ConvertToTypeofExpression(type.TypeKey); - code.AppendLine($" [typeof({typeofExpr})] = \"{type.Description.EscapeCSharpString()}\","); - } - - code.AppendLine(" }.ToFrozenDictionary();"); - code.AppendLine(); - code.AppendLine(" public Task TransformAsync("); - code.AppendLine(" OpenApiSchema schema,"); - code.AppendLine(" OpenApiSchemaTransformerContext context,"); - code.AppendLine(" CancellationToken cancellationToken)"); - code.AppendLine(" {"); - // AOT-safe: Direct Type lookup without reflection - code.AppendLine(" var type = context.JsonTypeInfo.Type;"); - code.AppendLine(" // For generic types, lookup the generic type definition"); - code.AppendLine(" var lookupType = type.IsGenericType ? type.GetGenericTypeDefinition() : type;"); - code.AppendLine(" if (TypeDescriptions.TryGetValue(lookupType, out var description))"); - code.AppendLine(" {"); - code.AppendLine(" schema.Description ??= description;"); - code.AppendLine(" }"); - code.AppendLine(" return Task.CompletedTask;"); - code.AppendLine(" }"); - code.AppendLine("}"); - code.AppendLine(); - - return true; - } - - /// - /// Converts a reflection-style type name to a C# typeof expression. - /// Example: "Namespace.Outer+Inner" → "global::Namespace.Outer.Inner" - /// - private static string ConvertToTypeofExpression(string reflectionName) - { - // Replace nested type separator (+) with C# dot notation - var csharpName = reflectionName.Replace('+', '.'); - return $"global::{csharpName}"; - } - - private static void EmitRegistrationExtension( - StringBuilder code, - List tags, - bool hasOperationDocs, - bool hasTypeDocs) - { - code.AppendLine("/// "); - code.AppendLine("/// Extension methods for registering generated OpenAPI transformers."); - code.AppendLine("/// "); - code.AppendLine("public static class GeneratedOpenApiExtensions"); - code.AppendLine("{"); - code.AppendLine(" /// "); - code.AppendLine(" /// Adds OpenAPI with generated transformers for ErrorOr endpoints."); - code.AppendLine(" /// Each transformer is registered following the strict 1:1 mapping rule."); - code.AppendLine(" /// "); - code.AppendLine(" public static IServiceCollection AddErrorOrOpenApi("); - code.AppendLine(" this IServiceCollection services,"); - code.AppendLine(" string documentName = \"v1\")"); - code.AppendLine(" {"); - code.AppendLine(" services.AddOpenApi(documentName, options =>"); - code.AppendLine(" {"); - - // Register tag transformers (1:1 - one per tag) - foreach (var tag in tags) - { - var safeTagName = tag.SanitizeIdentifier(); - code.AppendLine($" // Tag: {tag}"); - code.AppendLine($" options.AddDocumentTransformer(new Tag_{safeTagName}_Transformer());"); - } - - // Register operation transformer if we have docs - if (hasOperationDocs) - { - code.AppendLine(); - code.AppendLine(" // XML doc summaries → operation metadata"); - code.AppendLine(" options.AddOperationTransformer(new XmlDocOperationTransformer());"); - } - - // Register schema transformer if we have type docs - if (hasTypeDocs) - { - code.AppendLine(); - code.AppendLine(" // XML doc summaries → schema descriptions"); - code.AppendLine(" options.AddSchemaTransformer(new XmlDocSchemaTransformer());"); - } - - code.AppendLine(" });"); - code.AppendLine(); - code.AppendLine(" return services;"); - code.AppendLine(" }"); - code.AppendLine("}"); - } } diff --git a/tests/ErrorOrX.Generators.Tests/BindingTypeValidationTests.cs b/tests/ErrorOrX.Generators.Tests/BindingTypeValidationTests.cs new file mode 100644 index 0000000..161d2bc --- /dev/null +++ b/tests/ErrorOrX.Generators.Tests/BindingTypeValidationTests.cs @@ -0,0 +1,199 @@ +namespace ErrorOrX.Generators.Tests; + +/// +/// Tests for parameter binding type validation diagnostics +/// (EOE010-EOE014, EOE016-EOE017). Verifies that invalid types used with +/// [FromRoute], [FromQuery], [FromHeader], and +/// [AsParameters] are detected and reported. +/// +public class BindingTypeValidationTests : GeneratorTestBase +{ + #region EOE010 - Invalid [FromRoute] type + + [Fact] + public Task EOE010_Invalid_FromRoute_Type_Complex() + { + const string Source = """ + using ErrorOr; + using Microsoft.AspNetCore.Mvc; + + namespace DiagnosticTest; + + public class ComplexFilter { public string Name { get; set; } } + + public static class TodoApi + { + [Get("/todos/{filter}")] + public static ErrorOr GetByFilter([FromRoute] ComplexFilter filter) => "todo"; + } + """; + + return VerifyAsync(Source); + } + + #endregion + + #region EOE011 - Invalid [FromQuery] type + + [Fact] + public Task EOE011_Invalid_FromQuery_Type_Complex() + { + const string Source = """ + using ErrorOr; + using Microsoft.AspNetCore.Mvc; + + namespace DiagnosticTest; + + public class ComplexFilter { public string Name { get; set; } } + + public static class TodoApi + { + [Get("/todos")] + public static ErrorOr Search([FromQuery] ComplexFilter filter) => "todos"; + } + """; + + return VerifyAsync(Source); + } + + #endregion + + #region EOE012 - Invalid [AsParameters] type + + [Fact] + public Task EOE012_Invalid_AsParameters_Type_Primitive() + { + const string Source = """ + using ErrorOr; + using Microsoft.AspNetCore.Http; + + namespace DiagnosticTest; + + public static class TodoApi + { + [Get("/todos")] + public static ErrorOr Search([AsParameters] int page) => "todos"; + } + """; + + return VerifyAsync(Source); + } + + #endregion + + #region EOE013 - [AsParameters] type has no constructor + + [Fact] + public Task EOE013_AsParameters_No_Constructor() + { + const string Source = """ + using ErrorOr; + using Microsoft.AspNetCore.Http; + + namespace DiagnosticTest; + + public class SearchParams + { + private SearchParams() { } + public string Query { get; set; } + } + + public static class TodoApi + { + [Get("/todos")] + public static ErrorOr Search([AsParameters] SearchParams search) => "todos"; + } + """; + + return VerifyAsync(Source); + } + + #endregion + + #region EOE014 - Invalid [FromHeader] type + + [Fact] + public Task EOE014_Invalid_FromHeader_Type_Complex() + { + const string Source = """ + using ErrorOr; + using Microsoft.AspNetCore.Mvc; + + namespace DiagnosticTest; + + public class ComplexHeader { public string Value { get; set; } } + + public static class TodoApi + { + [Get("/todos")] + public static ErrorOr GetAll([FromHeader] ComplexHeader header) => "todos"; + } + """; + + return VerifyAsync(Source); + } + + #endregion + + #region EOE016 - Nested [AsParameters] not supported + + [Fact] + public Task EOE016_Nested_AsParameters() + { + const string Source = """ + using ErrorOr; + using Microsoft.AspNetCore.Http; + + namespace DiagnosticTest; + + public class InnerParams + { + public int Page { get; set; } + } + + public class OuterParams + { + [AsParameters] + public InnerParams Inner { get; set; } + } + + public static class TodoApi + { + [Get("/todos")] + public static ErrorOr Search([AsParameters] OuterParams search) => "todos"; + } + """; + + return VerifyAsync(Source); + } + + #endregion + + #region EOE017 - Nullable [AsParameters] not supported + + [Fact] + public Task EOE017_Nullable_AsParameters() + { + const string Source = """ + using ErrorOr; + using Microsoft.AspNetCore.Http; + + namespace DiagnosticTest; + + public class SearchParams + { + public string Query { get; set; } + } + + public static class TodoApi + { + [Get("/todos")] + public static ErrorOr Search([AsParameters] SearchParams? search) => "todos"; + } + """; + + return VerifyAsync(Source); + } + + #endregion +} diff --git a/tests/ErrorOrX.Generators.Tests/DiagnosticTests.cs b/tests/ErrorOrX.Generators.Tests/DiagnosticTests.cs deleted file mode 100644 index 25b4b29..0000000 --- a/tests/ErrorOrX.Generators.Tests/DiagnosticTests.cs +++ /dev/null @@ -1,956 +0,0 @@ -namespace ErrorOrX.Generators.Tests; - -/// -/// Tests for generator diagnostics (EOE003-EOE038). -/// Verifies that invalid endpoint configurations are detected and reported. -/// -public class DiagnosticTests : GeneratorTestBase -{ - #region EOE010 - Invalid [FromRoute] type - - [Fact] - public Task EOE010_Invalid_FromRoute_Type_Complex() - { - const string Source = """ - using ErrorOr; - using Microsoft.AspNetCore.Mvc; - - namespace DiagnosticTest; - - public class ComplexFilter { public string Name { get; set; } } - - public static class TodoApi - { - [Get("/todos/{filter}")] - public static ErrorOr GetByFilter([FromRoute] ComplexFilter filter) => "todo"; - } - """; - - return VerifyAsync(Source); - } - - #endregion - - #region EOE011 - Invalid [FromQuery] type - - [Fact] - public Task EOE011_Invalid_FromQuery_Type_Complex() - { - const string Source = """ - using ErrorOr; - using Microsoft.AspNetCore.Mvc; - - namespace DiagnosticTest; - - public class ComplexFilter { public string Name { get; set; } } - - public static class TodoApi - { - [Get("/todos")] - public static ErrorOr Search([FromQuery] ComplexFilter filter) => "todos"; - } - """; - - return VerifyAsync(Source); - } - - #endregion - - #region EOE012 - Invalid [AsParameters] type - - [Fact] - public Task EOE012_Invalid_AsParameters_Type_Primitive() - { - const string Source = """ - using ErrorOr; - using Microsoft.AspNetCore.Http; - - namespace DiagnosticTest; - - public static class TodoApi - { - [Get("/todos")] - public static ErrorOr Search([AsParameters] int page) => "todos"; - } - """; - - return VerifyAsync(Source); - } - - #endregion - - #region EOE013 - [AsParameters] type has no constructor - - [Fact] - public Task EOE013_AsParameters_No_Constructor() - { - const string Source = """ - using ErrorOr; - using Microsoft.AspNetCore.Http; - - namespace DiagnosticTest; - - public class SearchParams - { - private SearchParams() { } - public string Query { get; set; } - } - - public static class TodoApi - { - [Get("/todos")] - public static ErrorOr Search([AsParameters] SearchParams search) => "todos"; - } - """; - - return VerifyAsync(Source); - } - - #endregion - - #region EOE014 - Invalid [FromHeader] type - - [Fact] - public Task EOE014_Invalid_FromHeader_Type_Complex() - { - const string Source = """ - using ErrorOr; - using Microsoft.AspNetCore.Mvc; - - namespace DiagnosticTest; - - public class ComplexHeader { public string Value { get; set; } } - - public static class TodoApi - { - [Get("/todos")] - public static ErrorOr GetAll([FromHeader] ComplexHeader header) => "todos"; - } - """; - - return VerifyAsync(Source); - } - - #endregion - - #region EOE015 - Anonymous return type not supported - - [Fact] - public Task EOE015_Anonymous_Return_Type() - { - const string Source = """ - using ErrorOr; - - namespace DiagnosticTest; - - public static class TodoApi - { - [Get("/data")] - public static ErrorOr GetData() => new { Name = "test" }; - } - """; - - return VerifyAsync(Source); - } - - #endregion - - #region EOE016 - Nested [AsParameters] not supported - - [Fact] - public Task EOE016_Nested_AsParameters() - { - const string Source = """ - using ErrorOr; - using Microsoft.AspNetCore.Http; - - namespace DiagnosticTest; - - public class InnerParams - { - public int Page { get; set; } - } - - public class OuterParams - { - [AsParameters] - public InnerParams Inner { get; set; } - } - - public static class TodoApi - { - [Get("/todos")] - public static ErrorOr Search([AsParameters] OuterParams search) => "todos"; - } - """; - - return VerifyAsync(Source); - } - - #endregion - - #region EOE017 - Nullable [AsParameters] not supported - - [Fact] - public Task EOE017_Nullable_AsParameters() - { - const string Source = """ - using ErrorOr; - using Microsoft.AspNetCore.Http; - - namespace DiagnosticTest; - - public class SearchParams - { - public string Query { get; set; } - } - - public static class TodoApi - { - [Get("/todos")] - public static ErrorOr Search([AsParameters] SearchParams? search) => "todos"; - } - """; - - return VerifyAsync(Source); - } - - #endregion - - #region EOE018 - Inaccessible type in endpoint - - [Fact] - public Task EOE018_Private_Return_Type() - { - const string Source = """ - using ErrorOr; - - namespace DiagnosticTest; - - public static class TodoApi - { - private class SecretData { public string Value { get; set; } } - - [Get("/secret")] - public static ErrorOr GetSecret() => new SecretData { Value = "secret" }; - } - """; - - return VerifyAsync(Source); - } - - #endregion - - #region EOE019 - Type parameter not supported - - [Fact] - public Task EOE019_Generic_Type_Parameter() - { - const string Source = """ - using ErrorOr; - - namespace DiagnosticTest; - - public static class GenericApi - { - [Get("/items")] - public static ErrorOr GetItem() where T : class => default!; - } - """; - - return VerifyAsync(Source); - } - - #endregion - - #region EOE023 - Unknown error factory - - [Fact] - public Task EOE023_Unknown_Error_Factory() - { - const string Source = """ - using ErrorOr; - - namespace DiagnosticTest; - - public static class TodoApi - { - [Get("/todos/{id}")] - public static ErrorOr GetById(int id) - { - if (id < 0) - return Error.Custom(999, "custom", "description"); - return "todo"; - } - } - """; - - return VerifyAsync(Source); - } - - #endregion - - #region EOE022 - Too many result types - - [Fact] - public Task EOE022_Too_Many_Result_Types() - { - const string Source = """ - using ErrorOr; - - namespace DiagnosticTest; - - public static class TodoApi - { - [Get("/todos/{id}")] - public static ErrorOr GetById(int id) - { - if (id == 0) return Error.NotFound("Todo.NotFound", "Not found"); - if (id == 1) return Error.Validation("Todo.Invalid", "Invalid"); - if (id == 2) return Error.Conflict("Todo.Conflict", "Conflict"); - if (id == 3) return Error.Unauthorized("Todo.Unauthorized", "Unauthorized"); - if (id == 4) return Error.Forbidden("Todo.Forbidden", "Forbidden"); - if (id == 5) return Error.Failure("Todo.Failure", "Failure"); - if (id == 6) return Error.Unexpected("Todo.Unexpected", "Unexpected"); - return $"todo {id}"; - } - } - """; - - return VerifyAsync(Source); - } - - #endregion - - #region EOE007 - Type not in JSON context - - [Fact] - public Task EOE007_Type_Not_In_Json_Context() - { - const string Source = """ - using ErrorOr; - using System.Text.Json.Serialization; - - namespace DiagnosticTest; - - public record Todo(int Id, string Title); - public record AnotherType(string Name); - - // User context that's missing Todo - [JsonSourceGenerationOptions(PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase)] - [JsonSerializable(typeof(AnotherType))] - internal partial class AppJsonContext : JsonSerializerContext { } - - public static class TodoApi - { - [Get("/todos/{id}")] - public static ErrorOr GetById(int id) => new Todo(id, "Title"); - } - """; - - return VerifyAsync(Source); - } - - #endregion - - #region EOE003 - Route parameter not bound - - [Fact] - public Task EOE003_Route_Parameter_Not_Bound() - { - const string Source = """ - using ErrorOr; - - namespace DiagnosticTest; - - public static class TodoApi - { - [Get("/todos/{id}")] - public static ErrorOr GetById() => "todo"; - } - """; - - return VerifyAsync(Source); - } - - [Fact] - public Task EOE003_Route_Parameter_With_Constraint_Not_Bound() - { - const string Source = """ - using ErrorOr; - - namespace DiagnosticTest; - - public static class TodoApi - { - [Get("/todos/{id:int}")] - public static ErrorOr GetById() => "todo"; - } - """; - - return VerifyAsync(Source); - } - - #endregion - - #region EOE005 - Invalid route pattern - - [Fact] - public Task EOE005_Unclosed_Brace_In_Route() - { - const string Source = """ - using ErrorOr; - - namespace DiagnosticTest; - - public static class TodoApi - { - [Get("/todos/{id")] - public static ErrorOr GetById(int id) => "todo"; - } - """; - - return VerifyAsync(Source); - } - - [Fact] - public Task EOE005_Unmatched_Close_Brace() - { - const string Source = """ - using ErrorOr; - - namespace DiagnosticTest; - - public static class TodoApi - { - [Get("/todos/id}")] - public static ErrorOr GetById(int id) => "todo"; - } - """; - - return VerifyAsync(Source); - } - - [Fact] - public Task EOE005_Empty_Parameter_Name() - { - const string Source = """ - using ErrorOr; - - namespace DiagnosticTest; - - public static class TodoApi - { - [Get("/todos/{}")] - public static ErrorOr GetById() => "todo"; - } - """; - - return VerifyAsync(Source); - } - - #endregion - - #region EOE006 - Multiple body sources - - [Fact] - public Task EOE006_Multiple_Body_Sources_FromBody_And_FromForm() - { - const string Source = """ - using ErrorOr; - using Microsoft.AspNetCore.Http; - using Microsoft.AspNetCore.Mvc; - - namespace DiagnosticTest; - - public record CreateRequest(string Name); - - public static class TodoApi - { - [Post("/todos")] - public static ErrorOr Create( - [FromBody] CreateRequest body, - [FromForm] IFormFile file) => "created"; - } - """; - - return VerifyAsync(Source); - } - - [Fact] - public Task EOE006_Multiple_Body_Sources_Stream_And_FromBody() - { - const string Source = """ - using ErrorOr; - using System.IO; - using Microsoft.AspNetCore.Mvc; - - namespace DiagnosticTest; - - public record CreateRequest(string Name); - - public static class TodoApi - { - [Post("/upload")] - public static ErrorOr Upload( - [FromBody] CreateRequest body, - Stream data) => "uploaded"; - } - """; - - return VerifyAsync(Source); - } - - #endregion - - #region EOE020 - Route constraint type mismatch - - [Fact] - public Task EOE020_Int_Constraint_With_String_Parameter() - { - const string Source = """ - using ErrorOr; - - namespace DiagnosticTest; - - public static class TodoApi - { - [Get("/todos/{id:int}")] - public static ErrorOr GetById(string id) => "todo"; - } - """; - - return VerifyAsync(Source); - } - - [Fact] - public Task EOE020_Guid_Constraint_With_Int_Parameter() - { - const string Source = """ - using ErrorOr; - - namespace DiagnosticTest; - - public static class TodoApi - { - [Get("/todos/{id:guid}")] - public static ErrorOr GetById(int id) => "todo"; - } - """; - - return VerifyAsync(Source); - } - - #endregion - - #region EOE021 - Ambiguous parameter binding - - [Fact] - public Task EOE021_Complex_Type_On_Get_Without_Binding() - { - const string Source = """ - using ErrorOr; - - namespace DiagnosticTest; - - public class SearchFilter - { - public string Query { get; set; } - public int Page { get; set; } - } - - public static class TodoApi - { - [Get("/todos")] - public static ErrorOr Search(SearchFilter filter) => "todos"; - } - """; - - return VerifyAsync(Source); - } - - [Fact] - public Task EOE021_Complex_Type_On_Delete_Without_Binding() - { - const string Source = """ - using ErrorOr; - - namespace DiagnosticTest; - - public class DeleteOptions - { - public bool Force { get; set; } - } - - public static class TodoApi - { - [Delete("/todos/{id}")] - public static ErrorOr Delete(int id, DeleteOptions options) => "deleted"; - } - """; - - return VerifyAsync(Source); - } - - #endregion - - #region EOE024 - Undocumented interface call - - [Fact] - public Task EOE024_Undocumented_Interface_Call() - { - const string Source = """ - using ErrorOr; - - namespace DiagnosticTest; - - public interface ITodoService - { - ErrorOr GetById(int id); - } - - public static class TodoApi - { - [Get("/todos/{id}")] - public static ErrorOr GetById(int id, ITodoService svc) - => svc.GetById(id); - } - """; - - return VerifyAsync(Source); - } - - [Fact] - public Task EOE024_Interface_Call_With_ProducesError_No_Diagnostic() - { - const string Source = """ - using ErrorOr; - - namespace DiagnosticTest; - - public interface ITodoService - { - ErrorOr GetById(int id); - } - - public static class TodoApi - { - [Get("/todos/{id}")] - [ProducesError(404, "NotFound")] - public static ErrorOr GetById(int id, ITodoService svc) - => svc.GetById(id); - } - """; - - return VerifyAsync(Source); - } - - #endregion - - #region EOE025 - Missing CamelCase policy - - [Fact] - public Task EOE025_Missing_CamelCase_Policy() - { - const string Source = """ - using ErrorOr; - using System.Text.Json.Serialization; - - namespace DiagnosticTest; - - public record Todo(int Id, string Title); - - // User context WITHOUT CamelCase policy - [JsonSerializable(typeof(Todo))] - internal partial class AppJsonContext : JsonSerializerContext { } - - public static class TodoApi - { - [Get("/todos/{id}")] - public static ErrorOr GetById(int id) => new Todo(id, "Title"); - } - """; - - return VerifyAsync(Source); - } - - [Fact] - public Task EOE025_With_CamelCase_Policy_No_Diagnostic() - { - const string Source = """ - using ErrorOr; - using System.Text.Json.Serialization; - - namespace DiagnosticTest; - - public record Todo(int Id, string Title); - - // User context WITH CamelCase policy - [JsonSourceGenerationOptions(PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase)] - [JsonSerializable(typeof(Todo))] - internal partial class AppJsonContext : JsonSerializerContext { } - - public static class TodoApi - { - [Get("/todos/{id}")] - public static ErrorOr GetById(int id) => new Todo(id, "Title"); - } - """; - - return VerifyAsync(Source); - } - - #endregion - - #region EOE033 - Handler method name not PascalCase - - [Fact] - public Task EOE033_Method_Name_Lowercase_Start() - { - const string Source = """ - using ErrorOr; - - namespace DiagnosticTest; - - public static class TodoApi - { - [Get("/todos/{id}")] - public static ErrorOr getById(int id) => $"todo {id}"; - } - """; - - return VerifyAsync(Source); - } - - [Fact] - public Task EOE033_Method_Name_With_Underscore() - { - const string Source = """ - using ErrorOr; - - namespace DiagnosticTest; - - public static class TodoApi - { - [Get("/todos/{id}")] - public static ErrorOr Get_By_Id(int id) => $"todo {id}"; - } - """; - - return VerifyAsync(Source); - } - - [Fact] - public Task EOE033_Method_Name_Snake_Case() - { - const string Source = """ - using ErrorOr; - - namespace DiagnosticTest; - - public static class TodoApi - { - [Get("/todos/{id}")] - public static ErrorOr get_by_id(int id) => $"todo {id}"; - } - """; - - return VerifyAsync(Source); - } - - #endregion - - #region EOE039 - DataAnnotations validation uses reflection - - [Fact] - public Task EOE039_Validation_Attribute_On_Parameter() - { - const string Source = """ - using ErrorOr; - using System.ComponentModel.DataAnnotations; - - namespace DiagnosticTest; - - public record CreateTodoRequest([Required] string Title); - - public static class TodoApi - { - [Post("/todos")] - public static ErrorOr Create(CreateTodoRequest request) => "created"; - } - """; - - return VerifyAsync(Source); - } - - [Fact] - public Task EOE039_Multiple_Validation_Attributes() - { - const string Source = """ - using ErrorOr; - using System.ComponentModel.DataAnnotations; - - namespace DiagnosticTest; - - public static class TodoApi - { - [Post("/todos")] - public static ErrorOr Create( - [Required] [StringLength(100)] string title, - [Range(1, 100)] int priority) => "created"; - } - """; - - return VerifyAsync(Source); - } - - #endregion - - #region EOE041 - JsonSerializerContext missing error types - - [Fact] - public Task EOE041_Missing_ProblemDetails_In_JsonContext() - { - const string Source = """ - using ErrorOr; - using System.Text.Json.Serialization; - - namespace DiagnosticTest; - - public record Todo(int Id, string Title); - - // User context missing ProblemDetails types - [JsonSourceGenerationOptions(PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase)] - [JsonSerializable(typeof(Todo))] - internal partial class AppJsonContext : JsonSerializerContext { } - - public static class TodoApi - { - [Get("/todos/{id}")] - public static ErrorOr GetById(int id) => new Todo(id, "Title"); - } - """; - - return VerifyAsync(Source); - } - - [Fact] - public Task EOE041_No_Diagnostic_When_ProblemDetails_Present() - { - const string Source = """ - using ErrorOr; - using System.Text.Json.Serialization; - using Microsoft.AspNetCore.Mvc; - using Microsoft.AspNetCore.Http; - - namespace DiagnosticTest; - - public record Todo(int Id, string Title); - - // User context WITH ProblemDetails types - [JsonSourceGenerationOptions(PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase)] - [JsonSerializable(typeof(Todo))] - [JsonSerializable(typeof(ProblemDetails))] - [JsonSerializable(typeof(HttpValidationProblemDetails))] - internal partial class AppJsonContext : JsonSerializerContext { } - - public static class TodoApi - { - [Get("/todos/{id}")] - public static ErrorOr GetById(int id) => new Todo(id, "Title"); - } - """; - - return VerifyAsync(Source); - } - - #endregion - - #region Valid cases - no diagnostics - - [Fact] - public Task Valid_Route_Parameter_Bound() - { - const string Source = """ - using ErrorOr; - - namespace DiagnosticTest; - - public static class TodoApi - { - [Get("/todos/{id}")] - public static ErrorOr GetById(int id) => $"todo {id}"; - } - """; - - return VerifyAsync(Source); - } - - [Fact] - public Task Valid_Complex_Type_With_AsParameters() - { - const string Source = """ - using ErrorOr; - using Microsoft.AspNetCore.Http; - - namespace DiagnosticTest; - - public class SearchFilter - { - public string Query { get; set; } - public int Page { get; set; } - } - - public static class TodoApi - { - [Get("/todos")] - public static ErrorOr Search([AsParameters] SearchFilter filter) => "todos"; - } - """; - - return VerifyAsync(Source); - } - - [Fact] - public Task Valid_Complex_Type_With_FromBody_On_Post() - { - const string Source = """ - using ErrorOr; - using Microsoft.AspNetCore.Mvc; - - namespace DiagnosticTest; - - public record CreateTodoRequest(string Title); - - public static class TodoApi - { - [Post("/todos")] - public static ErrorOr Create([FromBody] CreateTodoRequest request) => "created"; - } - """; - - return VerifyAsync(Source); - } - - [Fact] - public Task Valid_Service_Type_Inferred() - { - const string Source = """ - using ErrorOr; - - namespace DiagnosticTest; - - public interface ITodoService { } - - public static class TodoApi - { - [Get("/todos")] - public static ErrorOr GetAll(ITodoService service) => "todos"; - } - """; - - return VerifyAsync(Source); - } - - #endregion -} diff --git a/tests/ErrorOrX.Generators.Tests/JsonAotValidationTests.cs b/tests/ErrorOrX.Generators.Tests/JsonAotValidationTests.cs new file mode 100644 index 0000000..995f004 --- /dev/null +++ b/tests/ErrorOrX.Generators.Tests/JsonAotValidationTests.cs @@ -0,0 +1,203 @@ +namespace ErrorOrX.Generators.Tests; + +/// +/// Tests for JSON-serialization and AOT-safety diagnostics (EOE007, +/// EOE025, EOE039, EOE041). Covers missing types in the user's +/// JsonSerializerContext, missing CamelCase policy, +/// DataAnnotations validation that relies on reflection, and the absence +/// of error-payload types (ProblemDetails) in the user's context. +/// +public class JsonAotValidationTests : GeneratorTestBase +{ + #region EOE007 - Type not in JSON context + + [Fact] + public Task EOE007_Type_Not_In_Json_Context() + { + const string Source = """ + using ErrorOr; + using System.Text.Json.Serialization; + + namespace DiagnosticTest; + + public record Todo(int Id, string Title); + public record AnotherType(string Name); + + // User context that's missing Todo + [JsonSourceGenerationOptions(PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase)] + [JsonSerializable(typeof(AnotherType))] + internal partial class AppJsonContext : JsonSerializerContext { } + + public static class TodoApi + { + [Get("/todos/{id}")] + public static ErrorOr GetById(int id) => new Todo(id, "Title"); + } + """; + + return VerifyAsync(Source); + } + + #endregion + + #region EOE025 - Missing CamelCase policy + + [Fact] + public Task EOE025_Missing_CamelCase_Policy() + { + const string Source = """ + using ErrorOr; + using System.Text.Json.Serialization; + + namespace DiagnosticTest; + + public record Todo(int Id, string Title); + + // User context WITHOUT CamelCase policy + [JsonSerializable(typeof(Todo))] + internal partial class AppJsonContext : JsonSerializerContext { } + + public static class TodoApi + { + [Get("/todos/{id}")] + public static ErrorOr GetById(int id) => new Todo(id, "Title"); + } + """; + + return VerifyAsync(Source); + } + + [Fact] + public Task EOE025_With_CamelCase_Policy_No_Diagnostic() + { + const string Source = """ + using ErrorOr; + using System.Text.Json.Serialization; + + namespace DiagnosticTest; + + public record Todo(int Id, string Title); + + // User context WITH CamelCase policy + [JsonSourceGenerationOptions(PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase)] + [JsonSerializable(typeof(Todo))] + internal partial class AppJsonContext : JsonSerializerContext { } + + public static class TodoApi + { + [Get("/todos/{id}")] + public static ErrorOr GetById(int id) => new Todo(id, "Title"); + } + """; + + return VerifyAsync(Source); + } + + #endregion + + #region EOE039 - DataAnnotations validation uses reflection + + [Fact] + public Task EOE039_Validation_Attribute_On_Parameter() + { + const string Source = """ + using ErrorOr; + using System.ComponentModel.DataAnnotations; + + namespace DiagnosticTest; + + public record CreateTodoRequest([Required] string Title); + + public static class TodoApi + { + [Post("/todos")] + public static ErrorOr Create(CreateTodoRequest request) => "created"; + } + """; + + return VerifyAsync(Source); + } + + [Fact] + public Task EOE039_Multiple_Validation_Attributes() + { + const string Source = """ + using ErrorOr; + using System.ComponentModel.DataAnnotations; + + namespace DiagnosticTest; + + public static class TodoApi + { + [Post("/todos")] + public static ErrorOr Create( + [Required] [StringLength(100)] string title, + [Range(1, 100)] int priority) => "created"; + } + """; + + return VerifyAsync(Source); + } + + #endregion + + #region EOE041 - JsonSerializerContext missing error types + + [Fact] + public Task EOE041_Missing_ProblemDetails_In_JsonContext() + { + const string Source = """ + using ErrorOr; + using System.Text.Json.Serialization; + + namespace DiagnosticTest; + + public record Todo(int Id, string Title); + + // User context missing ProblemDetails types + [JsonSourceGenerationOptions(PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase)] + [JsonSerializable(typeof(Todo))] + internal partial class AppJsonContext : JsonSerializerContext { } + + public static class TodoApi + { + [Get("/todos/{id}")] + public static ErrorOr GetById(int id) => new Todo(id, "Title"); + } + """; + + return VerifyAsync(Source); + } + + [Fact] + public Task EOE041_No_Diagnostic_When_ProblemDetails_Present() + { + const string Source = """ + using ErrorOr; + using System.Text.Json.Serialization; + using Microsoft.AspNetCore.Mvc; + using Microsoft.AspNetCore.Http; + + namespace DiagnosticTest; + + public record Todo(int Id, string Title); + + // User context WITH ProblemDetails types + [JsonSourceGenerationOptions(PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase)] + [JsonSerializable(typeof(Todo))] + [JsonSerializable(typeof(ProblemDetails))] + [JsonSerializable(typeof(HttpValidationProblemDetails))] + internal partial class AppJsonContext : JsonSerializerContext { } + + public static class TodoApi + { + [Get("/todos/{id}")] + public static ErrorOr GetById(int id) => new Todo(id, "Title"); + } + """; + + return VerifyAsync(Source); + } + + #endregion +} diff --git a/tests/ErrorOrX.Generators.Tests/MiddlewareEmissionAuthorizationTests.cs b/tests/ErrorOrX.Generators.Tests/MiddlewareEmissionAuthorizationTests.cs new file mode 100644 index 0000000..126703e --- /dev/null +++ b/tests/ErrorOrX.Generators.Tests/MiddlewareEmissionAuthorizationTests.cs @@ -0,0 +1,198 @@ +namespace ErrorOrX.Generators.Tests; + +/// +/// Tests for [Authorize]/[AllowAnonymous] emission. Security-critical because +/// the AOT wrapper delegates lose original method attributes, so RequireAuthorization MUST +/// be emitted as a fluent call. Also covers combined-middleware emission as a smoke test. +/// +public class MiddlewareEmissionAuthorizationTests : GeneratorTestBase +{ + [Fact] + public async Task Multiple_Middleware_Attributes_All_Emit() + { + const string Source = """ + using ErrorOr; + using Microsoft.AspNetCore.Authorization; + using Microsoft.AspNetCore.RateLimiting; + using Microsoft.AspNetCore.OutputCaching; + + public static class Api + { + [Get("/api")] + [Authorize("ApiPolicy")] + [EnableRateLimiting("standard")] + [OutputCache(PolicyName = "ApiCache")] + public static ErrorOr GetData() => "data"; + } + """; + + using var result = await RunAsync(Source); + + result.Diagnostics.Should().BeEmpty(); + var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; + generated.Should().Contain(".RequireAuthorization(\"ApiPolicy\")"); + generated.Should().Contain(".RequireRateLimiting(\"standard\")"); + generated.Should().Contain(".CacheOutput(\"ApiCache\")"); + } + + [Fact] + public async Task Authorize_Attribute_Emits_RequireAuthorization() + { + const string Source = """ + using ErrorOr; + using Microsoft.AspNetCore.Authorization; + + public static class Api + { + [Get("/admin")] + [Authorize] + public static ErrorOr AdminOnly() => "secret"; + } + """; + + using var result = await RunAsync(Source); + + result.Diagnostics.Should().BeEmpty(); + var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; + generated.Should().Contain(".RequireAuthorization()"); + } + + [Fact] + public async Task Authorize_With_Policy_String_Emits_Policy_Name() + { + const string Source = """ + using ErrorOr; + using Microsoft.AspNetCore.Authorization; + + public static class Api + { + [Get("/admin")] + [Authorize("AdminPolicy")] + public static ErrorOr AdminOnly() => "secret"; + } + """; + + using var result = await RunAsync(Source); + + result.Diagnostics.Should().BeEmpty(); + var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; + generated.Should().Contain(".RequireAuthorization(\"AdminPolicy\")"); + } + + [Fact] + public async Task Authorize_With_Policy_Named_Parameter_Emits_Policy_Name() + { + const string Source = """ + using ErrorOr; + using Microsoft.AspNetCore.Authorization; + + public static class Api + { + [Get("/admin")] + [Authorize(Policy = "AdminPolicy")] + public static ErrorOr AdminOnly() => "secret"; + } + """; + + using var result = await RunAsync(Source); + + result.Diagnostics.Should().BeEmpty(); + var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; + generated.Should().Contain(".RequireAuthorization(\"AdminPolicy\")"); + } + + [Fact] + public async Task Authorize_With_Roles_Emits_Roles() + { + const string Source = """ + using ErrorOr; + using Microsoft.AspNetCore.Authorization; + + public static class Api + { + [Get("/admin")] + [Authorize(Roles = "Admin,Manager")] + public static ErrorOr AdminOnly() => "secret"; + } + """; + + using var result = await RunAsync(Source); + + result.Diagnostics.Should().BeEmpty(); + var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; + // Roles should be handled - either via policy builder or RequireAuthorization with role + generated.Should().Match("*.RequireAuthorization*Admin*"); + } + + [Fact] + public async Task Multiple_Authorize_Attributes_Emit_All() + { + const string Source = """ + using ErrorOr; + using Microsoft.AspNetCore.Authorization; + + public static class Api + { + [Get("/admin")] + [Authorize("Policy1")] + [Authorize("Policy2")] + public static ErrorOr AdminOnly() => "secret"; + } + """; + + using var result = await RunAsync(Source); + + result.Diagnostics.Should().BeEmpty(); + var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; + generated.Should().Contain("Policy1"); + generated.Should().Contain("Policy2"); + } + + [Fact] + public async Task AllowAnonymous_Overrides_Authorize() + { + const string Source = """ + using ErrorOr; + using Microsoft.AspNetCore.Authorization; + + public static class Api + { + [Get("/public")] + [Authorize] + [AllowAnonymous] + public static ErrorOr PublicEndpoint() => "public"; + } + """; + + using var result = await RunAsync(Source); + + result.Diagnostics.Should().BeEmpty(); + var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; + // AllowAnonymous should emit AllowAnonymous(), not RequireAuthorization() + generated.Should().Contain(".AllowAnonymous()"); + // The RequireAuthorization should be suppressed by AllowAnonymous + } + + [Fact] + public async Task Authorize_With_AuthenticationSchemes_Emits_Scheme() + { + const string Source = """ + using ErrorOr; + using Microsoft.AspNetCore.Authorization; + + public static class Api + { + [Get("/api")] + [Authorize(AuthenticationSchemes = "Bearer")] + public static ErrorOr ApiEndpoint() => "data"; + } + """; + + using var result = await RunAsync(Source); + + result.Diagnostics.Should().BeEmpty(); + var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; + // Authentication schemes should be handled + generated.Should().Match("*.RequireAuthorization*"); + } +} diff --git a/tests/ErrorOrX.Generators.Tests/MiddlewareEmissionCorsAndMetadataTests.cs b/tests/ErrorOrX.Generators.Tests/MiddlewareEmissionCorsAndMetadataTests.cs new file mode 100644 index 0000000..72eb5f2 --- /dev/null +++ b/tests/ErrorOrX.Generators.Tests/MiddlewareEmissionCorsAndMetadataTests.cs @@ -0,0 +1,180 @@ +namespace ErrorOrX.Generators.Tests; + +/// +/// CORS attribute emission ([EnableCors], [DisableCors]), endpoint metadata +/// (.WithName, .WithTags), and the security-regression tests that pin the +/// "wrapper does not drop [Authorize]" contract — the single most important invariant +/// of the middleware emitter. +/// +public class MiddlewareEmissionCorsAndMetadataTests : GeneratorTestBase +{ + [Fact] + public async Task EnableCors_Emits_RequireCors() + { + const string Source = """ + using ErrorOr; + using Microsoft.AspNetCore.Cors; + + public static class Api + { + [Get("/api")] + [EnableCors] + public static ErrorOr GetData() => "data"; + } + """; + + using var result = await RunAsync(Source); + + result.Diagnostics.Should().BeEmpty(); + var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; + generated.Should().Contain(".RequireCors()"); + } + + [Fact] + public async Task EnableCors_With_Policy_Emits_Policy_Name() + { + const string Source = """ + using ErrorOr; + using Microsoft.AspNetCore.Cors; + + public static class Api + { + [Get("/api")] + [EnableCors("AllowAll")] + public static ErrorOr GetData() => "data"; + } + """; + + using var result = await RunAsync(Source); + + result.Diagnostics.Should().BeEmpty(); + var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; + generated.Should().Contain(".RequireCors(\"AllowAll\")"); + } + + [Fact] + public async Task DisableCors_Emits_DisableCors() + { + const string Source = """ + using ErrorOr; + using Microsoft.AspNetCore.Cors; + + public static class Api + { + [Get("/internal")] + [DisableCors] + public static ErrorOr Internal() => "internal"; + } + """; + + using var result = await RunAsync(Source); + + result.Diagnostics.Should().BeEmpty(); + var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; + // DisableCors should be emitted + generated.Should().Match("*Cors*"); + } + + [Fact] + public async Task EndpointName_Uses_ClassName_And_MethodName() + { + const string Source = """ + using ErrorOr; + + public static class TodoApi + { + [Get("/todos/{id}")] + public static ErrorOr GetById(int id) => $"Todo {id}"; + } + """; + + using var result = await RunAsync(Source); + + result.Diagnostics.Should().BeEmpty(); + var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; + generated.Should().Contain(".WithName(\"TodoApi_GetById\")"); + } + + [Fact] + public async Task EndpointTags_Uses_ClassName() + { + const string Source = """ + using ErrorOr; + + public static class UserApi + { + [Get("/users")] + public static ErrorOr GetAll() => "users"; + } + """; + + using var result = await RunAsync(Source); + + result.Diagnostics.Should().BeEmpty(); + var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; + generated.Should().Contain(".WithTags(\"UserApi\")"); + } + + [Fact] + public async Task Security_Authorize_NotLost_InWrapper() + { + // This is the critical test - verifying the wrapper pattern doesn't lose security + const string Source = """ + using ErrorOr; + using Microsoft.AspNetCore.Authorization; + + public static class AdminApi + { + [Get("/admin/secrets")] + [Authorize("SuperAdmin")] + public static ErrorOr GetSecrets() => "top-secret"; + } + """; + + using var result = await RunAsync(Source); + + result.Diagnostics.Should().BeEmpty(); + var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; + + // The wrapper method (Invoke_Ep0) doesn't have [Authorize], so we MUST emit: + generated.Should().Contain(".RequireAuthorization(\"SuperAdmin\")", + "wrapper delegates lose original method attributes, so RequireAuthorization MUST be emitted"); + + // Also verify the endpoint builder chain includes auth + var lines = generated.Split('\n'); + var endpointLine = lines.FirstOrDefault(static l => + l.Contains("MapGet", StringComparison.Ordinal) && l.Contains("/admin/secrets", StringComparison.Ordinal)); + + // The RequireAuthorization should be in the fluent chain + generated.Should().Match("*MapGet*admin/secrets*RequireAuthorization*SuperAdmin*"); + } + + [Fact] + public async Task Security_Multiple_Policies_All_Applied() + { + const string Source = """ + using ErrorOr; + using Microsoft.AspNetCore.Authorization; + using Microsoft.AspNetCore.RateLimiting; + + public static class SecureApi + { + [Get("/secure/data")] + [Authorize("Policy1")] + [Authorize("Policy2")] + [EnableRateLimiting("strict")] + public static ErrorOr GetSecureData() => "secure"; + } + """; + + using var result = await RunAsync(Source); + + result.Diagnostics.Should().BeEmpty(); + var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; + + // All security attributes must be preserved + generated.Should().Contain("Policy1"); + generated.Should().Contain("Policy2"); + generated.Should().Contain(".RequireRateLimiting(\"strict\")"); + } +} diff --git a/tests/ErrorOrX.Generators.Tests/MiddlewareEmissionRateLimitingAndCachingTests.cs b/tests/ErrorOrX.Generators.Tests/MiddlewareEmissionRateLimitingAndCachingTests.cs new file mode 100644 index 0000000..f6483de --- /dev/null +++ b/tests/ErrorOrX.Generators.Tests/MiddlewareEmissionRateLimitingAndCachingTests.cs @@ -0,0 +1,176 @@ +namespace ErrorOrX.Generators.Tests; + +/// +/// Tests for rate-limiting ([EnableRateLimiting], [DisableRateLimiting]) and +/// output-caching ([OutputCache] with PolicyName / Duration / VaryByQueryKeys) attribute +/// emission. Verifies Disable overrides Enable rather than being silently dropped. +/// +public class MiddlewareEmissionRateLimitingAndCachingTests : GeneratorTestBase +{ + [Fact] + public async Task EnableRateLimiting_Emits_RequireRateLimiting() + { + const string Source = """ + using ErrorOr; + using Microsoft.AspNetCore.RateLimiting; + + public static class Api + { + [Get("/api")] + [EnableRateLimiting("fixed")] + public static ErrorOr GetData() => "data"; + } + """; + + using var result = await RunAsync(Source); + + result.Diagnostics.Should().BeEmpty(); + var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; + generated.Should().Contain(".RequireRateLimiting(\"fixed\")"); + } + + [Fact] + public async Task DisableRateLimiting_Emits_DisableRateLimiting() + { + const string Source = """ + using ErrorOr; + using Microsoft.AspNetCore.RateLimiting; + + public static class Api + { + [Get("/unlimited")] + [DisableRateLimiting] + public static ErrorOr Unlimited() => "unlimited"; + } + """; + + using var result = await RunAsync(Source); + + result.Diagnostics.Should().BeEmpty(); + var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; + generated.Should().Contain(".DisableRateLimiting()"); + } + + [Fact] + public async Task DisableRateLimiting_Overrides_EnableRateLimiting() + { + const string Source = """ + using ErrorOr; + using Microsoft.AspNetCore.RateLimiting; + + public static class Api + { + [Get("/unlimited")] + [EnableRateLimiting("fixed")] + [DisableRateLimiting] + public static ErrorOr Unlimited() => "unlimited"; + } + """; + + using var result = await RunAsync(Source); + + result.Diagnostics.Should().BeEmpty(); + var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; + + // Extract only the endpoint mapping lines (exclude doc comments) + var endpointLines = generated.Split('\n') + .Where(static l => !l.TrimStart().StartsWith("///", StringComparison.Ordinal)) + .Where(static l => l.Contains("MapGet", StringComparison.Ordinal) || + l.Contains(".DisableRateLimiting", StringComparison.Ordinal) || + l.Contains(".RequireRateLimiting", StringComparison.Ordinal)); + var endpointSection = string.Join("\n", endpointLines); + + // DisableRateLimiting should override EnableRateLimiting + endpointSection.Should().Contain(".DisableRateLimiting()"); + endpointSection.Should().NotContain(".RequireRateLimiting("); + } + + [Fact] + public async Task OutputCache_Emits_CacheOutput() + { + const string Source = """ + using ErrorOr; + using Microsoft.AspNetCore.OutputCaching; + + public static class Api + { + [Get("/cached")] + [OutputCache] + public static ErrorOr Cached() => "cached"; + } + """; + + using var result = await RunAsync(Source); + + result.Diagnostics.Should().BeEmpty(); + var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; + generated.Should().Contain(".CacheOutput()"); + } + + [Fact] + public async Task OutputCache_With_PolicyName_Emits_Policy() + { + const string Source = """ + using ErrorOr; + using Microsoft.AspNetCore.OutputCaching; + + public static class Api + { + [Get("/cached")] + [OutputCache(PolicyName = "MyPolicy")] + public static ErrorOr Cached() => "cached"; + } + """; + + using var result = await RunAsync(Source); + + result.Diagnostics.Should().BeEmpty(); + var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; + generated.Should().Contain(".CacheOutput(\"MyPolicy\")"); + } + + [Fact] + public async Task OutputCache_With_Duration_Emits_Duration() + { + const string Source = """ + using ErrorOr; + using Microsoft.AspNetCore.OutputCaching; + + public static class Api + { + [Get("/cached")] + [OutputCache(Duration = 60)] + public static ErrorOr Cached() => "cached"; + } + """; + + using var result = await RunAsync(Source); + + result.Diagnostics.Should().BeEmpty(); + var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; + // Duration should be handled - either via policy builder or inline + generated.Should().Match("*.CacheOutput*"); + } + + [Fact] + public async Task OutputCache_With_VaryByQueryKeys_Emits_VaryBy() + { + const string Source = """ + using ErrorOr; + using Microsoft.AspNetCore.OutputCaching; + + public static class Api + { + [Get("/cached")] + [OutputCache(VaryByQueryKeys = new[] { "page", "sort" })] + public static ErrorOr Cached() => "cached"; + } + """; + + using var result = await RunAsync(Source); + + result.Diagnostics.Should().BeEmpty(); + var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; + generated.Should().Match("*.CacheOutput*"); + } +} diff --git a/tests/ErrorOrX.Generators.Tests/MiddlewareEmissionTests.cs b/tests/ErrorOrX.Generators.Tests/MiddlewareEmissionTests.cs deleted file mode 100644 index 8ed5399..0000000 --- a/tests/ErrorOrX.Generators.Tests/MiddlewareEmissionTests.cs +++ /dev/null @@ -1,563 +0,0 @@ -namespace ErrorOrX.Generators.Tests; - -/// -/// Tests for middleware attribute emission in the ErrorOrEndpointGenerator. -/// Security-critical: Verifies that [Authorize], [EnableRateLimiting], [OutputCache], [EnableCors] -/// are correctly translated to fluent calls since wrapper delegates lose original attributes. -/// -public class MiddlewareEmissionTests : GeneratorTestBase -{ - #region Combined Middleware - - [Fact] - public async Task Multiple_Middleware_Attributes_All_Emit() - { - const string Source = """ - using ErrorOr; - using Microsoft.AspNetCore.Authorization; - using Microsoft.AspNetCore.RateLimiting; - using Microsoft.AspNetCore.OutputCaching; - - public static class Api - { - [Get("/api")] - [Authorize("ApiPolicy")] - [EnableRateLimiting("standard")] - [OutputCache(PolicyName = "ApiCache")] - public static ErrorOr GetData() => "data"; - } - """; - - using var result = await RunAsync(Source); - - result.Diagnostics.Should().BeEmpty(); - var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; - generated.Should().Contain(".RequireAuthorization(\"ApiPolicy\")"); - generated.Should().Contain(".RequireRateLimiting(\"standard\")"); - generated.Should().Contain(".CacheOutput(\"ApiCache\")"); - } - - #endregion - - #region Authorization Middleware - - [Fact] - public async Task Authorize_Attribute_Emits_RequireAuthorization() - { - const string Source = """ - using ErrorOr; - using Microsoft.AspNetCore.Authorization; - - public static class Api - { - [Get("/admin")] - [Authorize] - public static ErrorOr AdminOnly() => "secret"; - } - """; - - using var result = await RunAsync(Source); - - result.Diagnostics.Should().BeEmpty(); - var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; - generated.Should().Contain(".RequireAuthorization()"); - } - - [Fact] - public async Task Authorize_With_Policy_String_Emits_Policy_Name() - { - const string Source = """ - using ErrorOr; - using Microsoft.AspNetCore.Authorization; - - public static class Api - { - [Get("/admin")] - [Authorize("AdminPolicy")] - public static ErrorOr AdminOnly() => "secret"; - } - """; - - using var result = await RunAsync(Source); - - result.Diagnostics.Should().BeEmpty(); - var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; - generated.Should().Contain(".RequireAuthorization(\"AdminPolicy\")"); - } - - [Fact] - public async Task Authorize_With_Policy_Named_Parameter_Emits_Policy_Name() - { - const string Source = """ - using ErrorOr; - using Microsoft.AspNetCore.Authorization; - - public static class Api - { - [Get("/admin")] - [Authorize(Policy = "AdminPolicy")] - public static ErrorOr AdminOnly() => "secret"; - } - """; - - using var result = await RunAsync(Source); - - result.Diagnostics.Should().BeEmpty(); - var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; - generated.Should().Contain(".RequireAuthorization(\"AdminPolicy\")"); - } - - [Fact] - public async Task Authorize_With_Roles_Emits_Roles() - { - const string Source = """ - using ErrorOr; - using Microsoft.AspNetCore.Authorization; - - public static class Api - { - [Get("/admin")] - [Authorize(Roles = "Admin,Manager")] - public static ErrorOr AdminOnly() => "secret"; - } - """; - - using var result = await RunAsync(Source); - - result.Diagnostics.Should().BeEmpty(); - var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; - // Roles should be handled - either via policy builder or RequireAuthorization with role - generated.Should().Match("*.RequireAuthorization*Admin*"); - } - - [Fact] - public async Task Multiple_Authorize_Attributes_Emit_All() - { - const string Source = """ - using ErrorOr; - using Microsoft.AspNetCore.Authorization; - - public static class Api - { - [Get("/admin")] - [Authorize("Policy1")] - [Authorize("Policy2")] - public static ErrorOr AdminOnly() => "secret"; - } - """; - - using var result = await RunAsync(Source); - - result.Diagnostics.Should().BeEmpty(); - var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; - generated.Should().Contain("Policy1"); - generated.Should().Contain("Policy2"); - } - - [Fact] - public async Task AllowAnonymous_Overrides_Authorize() - { - const string Source = """ - using ErrorOr; - using Microsoft.AspNetCore.Authorization; - - public static class Api - { - [Get("/public")] - [Authorize] - [AllowAnonymous] - public static ErrorOr PublicEndpoint() => "public"; - } - """; - - using var result = await RunAsync(Source); - - result.Diagnostics.Should().BeEmpty(); - var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; - // AllowAnonymous should emit AllowAnonymous(), not RequireAuthorization() - generated.Should().Contain(".AllowAnonymous()"); - // The RequireAuthorization should be suppressed by AllowAnonymous - } - - [Fact] - public async Task Authorize_With_AuthenticationSchemes_Emits_Scheme() - { - const string Source = """ - using ErrorOr; - using Microsoft.AspNetCore.Authorization; - - public static class Api - { - [Get("/api")] - [Authorize(AuthenticationSchemes = "Bearer")] - public static ErrorOr ApiEndpoint() => "data"; - } - """; - - using var result = await RunAsync(Source); - - result.Diagnostics.Should().BeEmpty(); - var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; - // Authentication schemes should be handled - generated.Should().Match("*.RequireAuthorization*"); - } - - #endregion - - #region Rate Limiting Middleware - - [Fact] - public async Task EnableRateLimiting_Emits_RequireRateLimiting() - { - const string Source = """ - using ErrorOr; - using Microsoft.AspNetCore.RateLimiting; - - public static class Api - { - [Get("/api")] - [EnableRateLimiting("fixed")] - public static ErrorOr GetData() => "data"; - } - """; - - using var result = await RunAsync(Source); - - result.Diagnostics.Should().BeEmpty(); - var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; - generated.Should().Contain(".RequireRateLimiting(\"fixed\")"); - } - - [Fact] - public async Task DisableRateLimiting_Emits_DisableRateLimiting() - { - const string Source = """ - using ErrorOr; - using Microsoft.AspNetCore.RateLimiting; - - public static class Api - { - [Get("/unlimited")] - [DisableRateLimiting] - public static ErrorOr Unlimited() => "unlimited"; - } - """; - - using var result = await RunAsync(Source); - - result.Diagnostics.Should().BeEmpty(); - var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; - generated.Should().Contain(".DisableRateLimiting()"); - } - - [Fact] - public async Task DisableRateLimiting_Overrides_EnableRateLimiting() - { - const string Source = """ - using ErrorOr; - using Microsoft.AspNetCore.RateLimiting; - - public static class Api - { - [Get("/unlimited")] - [EnableRateLimiting("fixed")] - [DisableRateLimiting] - public static ErrorOr Unlimited() => "unlimited"; - } - """; - - using var result = await RunAsync(Source); - - result.Diagnostics.Should().BeEmpty(); - var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; - - // Extract only the endpoint mapping lines (exclude doc comments) - var endpointLines = generated.Split('\n') - .Where(static l => !l.TrimStart().StartsWith("///", StringComparison.Ordinal)) - .Where(static l => l.Contains("MapGet", StringComparison.Ordinal) || - l.Contains(".DisableRateLimiting", StringComparison.Ordinal) || - l.Contains(".RequireRateLimiting", StringComparison.Ordinal)); - var endpointSection = string.Join("\n", endpointLines); - - // DisableRateLimiting should override EnableRateLimiting - endpointSection.Should().Contain(".DisableRateLimiting()"); - endpointSection.Should().NotContain(".RequireRateLimiting("); - } - - #endregion - - #region Output Caching Middleware - - [Fact] - public async Task OutputCache_Emits_CacheOutput() - { - const string Source = """ - using ErrorOr; - using Microsoft.AspNetCore.OutputCaching; - - public static class Api - { - [Get("/cached")] - [OutputCache] - public static ErrorOr Cached() => "cached"; - } - """; - - using var result = await RunAsync(Source); - - result.Diagnostics.Should().BeEmpty(); - var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; - generated.Should().Contain(".CacheOutput()"); - } - - [Fact] - public async Task OutputCache_With_PolicyName_Emits_Policy() - { - const string Source = """ - using ErrorOr; - using Microsoft.AspNetCore.OutputCaching; - - public static class Api - { - [Get("/cached")] - [OutputCache(PolicyName = "MyPolicy")] - public static ErrorOr Cached() => "cached"; - } - """; - - using var result = await RunAsync(Source); - - result.Diagnostics.Should().BeEmpty(); - var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; - generated.Should().Contain(".CacheOutput(\"MyPolicy\")"); - } - - [Fact] - public async Task OutputCache_With_Duration_Emits_Duration() - { - const string Source = """ - using ErrorOr; - using Microsoft.AspNetCore.OutputCaching; - - public static class Api - { - [Get("/cached")] - [OutputCache(Duration = 60)] - public static ErrorOr Cached() => "cached"; - } - """; - - using var result = await RunAsync(Source); - - result.Diagnostics.Should().BeEmpty(); - var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; - // Duration should be handled - either via policy builder or inline - generated.Should().Match("*.CacheOutput*"); - } - - [Fact] - public async Task OutputCache_With_VaryByQueryKeys_Emits_VaryBy() - { - const string Source = """ - using ErrorOr; - using Microsoft.AspNetCore.OutputCaching; - - public static class Api - { - [Get("/cached")] - [OutputCache(VaryByQueryKeys = new[] { "page", "sort" })] - public static ErrorOr Cached() => "cached"; - } - """; - - using var result = await RunAsync(Source); - - result.Diagnostics.Should().BeEmpty(); - var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; - generated.Should().Match("*.CacheOutput*"); - } - - #endregion - - #region CORS Middleware - - [Fact] - public async Task EnableCors_Emits_RequireCors() - { - const string Source = """ - using ErrorOr; - using Microsoft.AspNetCore.Cors; - - public static class Api - { - [Get("/api")] - [EnableCors] - public static ErrorOr GetData() => "data"; - } - """; - - using var result = await RunAsync(Source); - - result.Diagnostics.Should().BeEmpty(); - var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; - generated.Should().Contain(".RequireCors()"); - } - - [Fact] - public async Task EnableCors_With_Policy_Emits_Policy_Name() - { - const string Source = """ - using ErrorOr; - using Microsoft.AspNetCore.Cors; - - public static class Api - { - [Get("/api")] - [EnableCors("AllowAll")] - public static ErrorOr GetData() => "data"; - } - """; - - using var result = await RunAsync(Source); - - result.Diagnostics.Should().BeEmpty(); - var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; - generated.Should().Contain(".RequireCors(\"AllowAll\")"); - } - - [Fact] - public async Task DisableCors_Emits_DisableCors() - { - const string Source = """ - using ErrorOr; - using Microsoft.AspNetCore.Cors; - - public static class Api - { - [Get("/internal")] - [DisableCors] - public static ErrorOr Internal() => "internal"; - } - """; - - using var result = await RunAsync(Source); - - result.Diagnostics.Should().BeEmpty(); - var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; - // DisableCors should be emitted - generated.Should().Match("*Cors*"); - } - - #endregion - - #region Endpoint Naming and Tags - - [Fact] - public async Task EndpointName_Uses_ClassName_And_MethodName() - { - const string Source = """ - using ErrorOr; - - public static class TodoApi - { - [Get("/todos/{id}")] - public static ErrorOr GetById(int id) => $"Todo {id}"; - } - """; - - using var result = await RunAsync(Source); - - result.Diagnostics.Should().BeEmpty(); - var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; - generated.Should().Contain(".WithName(\"TodoApi_GetById\")"); - } - - [Fact] - public async Task EndpointTags_Uses_ClassName() - { - const string Source = """ - using ErrorOr; - - public static class UserApi - { - [Get("/users")] - public static ErrorOr GetAll() => "users"; - } - """; - - using var result = await RunAsync(Source); - - result.Diagnostics.Should().BeEmpty(); - var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; - generated.Should().Contain(".WithTags(\"UserApi\")"); - } - - #endregion - - #region Security - Attribute Not Lost - - [Fact] - public async Task Security_Authorize_NotLost_InWrapper() - { - // This is the critical test - verifying the wrapper pattern doesn't lose security - const string Source = """ - using ErrorOr; - using Microsoft.AspNetCore.Authorization; - - public static class AdminApi - { - [Get("/admin/secrets")] - [Authorize("SuperAdmin")] - public static ErrorOr GetSecrets() => "top-secret"; - } - """; - - using var result = await RunAsync(Source); - - result.Diagnostics.Should().BeEmpty(); - var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; - - // The wrapper method (Invoke_Ep0) doesn't have [Authorize], so we MUST emit: - generated.Should().Contain(".RequireAuthorization(\"SuperAdmin\")", - "wrapper delegates lose original method attributes, so RequireAuthorization MUST be emitted"); - - // Also verify the endpoint builder chain includes auth - var lines = generated.Split('\n'); - var endpointLine = lines.FirstOrDefault(static l => - l.Contains("MapGet", StringComparison.Ordinal) && l.Contains("/admin/secrets", StringComparison.Ordinal)); - - // The RequireAuthorization should be in the fluent chain - generated.Should().Match("*MapGet*admin/secrets*RequireAuthorization*SuperAdmin*"); - } - - [Fact] - public async Task Security_Multiple_Policies_All_Applied() - { - const string Source = """ - using ErrorOr; - using Microsoft.AspNetCore.Authorization; - using Microsoft.AspNetCore.RateLimiting; - - public static class SecureApi - { - [Get("/secure/data")] - [Authorize("Policy1")] - [Authorize("Policy2")] - [EnableRateLimiting("strict")] - public static ErrorOr GetSecureData() => "secure"; - } - """; - - using var result = await RunAsync(Source); - - result.Diagnostics.Should().BeEmpty(); - var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; - - // All security attributes must be preserved - generated.Should().Contain("Policy1"); - generated.Should().Contain("Policy2"); - generated.Should().Contain(".RequireRateLimiting(\"strict\")"); - } - - #endregion -} diff --git a/tests/ErrorOrX.Generators.Tests/NamingAndValidCaseTests.cs b/tests/ErrorOrX.Generators.Tests/NamingAndValidCaseTests.cs new file mode 100644 index 0000000..85dfce7 --- /dev/null +++ b/tests/ErrorOrX.Generators.Tests/NamingAndValidCaseTests.cs @@ -0,0 +1,155 @@ +namespace ErrorOrX.Generators.Tests; + +/// +/// Tests for handler-method naming diagnostics (EOE033) plus baseline +/// valid-input cases that must NOT emit any diagnostics. The valid cases +/// guard against false positives in the other rule sets. +/// +public class NamingAndValidCaseTests : GeneratorTestBase +{ + #region EOE033 - Handler method name not PascalCase + + [Fact] + public Task EOE033_Method_Name_Lowercase_Start() + { + const string Source = """ + using ErrorOr; + + namespace DiagnosticTest; + + public static class TodoApi + { + [Get("/todos/{id}")] + public static ErrorOr getById(int id) => $"todo {id}"; + } + """; + + return VerifyAsync(Source); + } + + [Fact] + public Task EOE033_Method_Name_With_Underscore() + { + const string Source = """ + using ErrorOr; + + namespace DiagnosticTest; + + public static class TodoApi + { + [Get("/todos/{id}")] + public static ErrorOr Get_By_Id(int id) => $"todo {id}"; + } + """; + + return VerifyAsync(Source); + } + + [Fact] + public Task EOE033_Method_Name_Snake_Case() + { + const string Source = """ + using ErrorOr; + + namespace DiagnosticTest; + + public static class TodoApi + { + [Get("/todos/{id}")] + public static ErrorOr get_by_id(int id) => $"todo {id}"; + } + """; + + return VerifyAsync(Source); + } + + #endregion + + #region Valid cases - no diagnostics + + [Fact] + public Task Valid_Route_Parameter_Bound() + { + const string Source = """ + using ErrorOr; + + namespace DiagnosticTest; + + public static class TodoApi + { + [Get("/todos/{id}")] + public static ErrorOr GetById(int id) => $"todo {id}"; + } + """; + + return VerifyAsync(Source); + } + + [Fact] + public Task Valid_Complex_Type_With_AsParameters() + { + const string Source = """ + using ErrorOr; + using Microsoft.AspNetCore.Http; + + namespace DiagnosticTest; + + public class SearchFilter + { + public string Query { get; set; } + public int Page { get; set; } + } + + public static class TodoApi + { + [Get("/todos")] + public static ErrorOr Search([AsParameters] SearchFilter filter) => "todos"; + } + """; + + return VerifyAsync(Source); + } + + [Fact] + public Task Valid_Complex_Type_With_FromBody_On_Post() + { + const string Source = """ + using ErrorOr; + using Microsoft.AspNetCore.Mvc; + + namespace DiagnosticTest; + + public record CreateTodoRequest(string Title); + + public static class TodoApi + { + [Post("/todos")] + public static ErrorOr Create([FromBody] CreateTodoRequest request) => "created"; + } + """; + + return VerifyAsync(Source); + } + + [Fact] + public Task Valid_Service_Type_Inferred() + { + const string Source = """ + using ErrorOr; + + namespace DiagnosticTest; + + public interface ITodoService { } + + public static class TodoApi + { + [Get("/todos")] + public static ErrorOr GetAll(ITodoService service) => "todos"; + } + """; + + return VerifyAsync(Source); + } + + #endregion +} diff --git a/tests/ErrorOrX.Generators.Tests/ParameterBindingExplicitAttributeTests.cs b/tests/ErrorOrX.Generators.Tests/ParameterBindingExplicitAttributeTests.cs new file mode 100644 index 0000000..ef07e07 --- /dev/null +++ b/tests/ErrorOrX.Generators.Tests/ParameterBindingExplicitAttributeTests.cs @@ -0,0 +1,151 @@ +namespace ErrorOrX.Generators.Tests; + +/// +/// Tests for explicit attribute bindings — [FromBody], [FromServices], +/// [FromKeyedServices], [FromRoute], [FromQuery], [FromHeader]. +/// Verifies that explicit attributes override inference and that Name = "..." +/// overrides the parameter name on the bound source. +/// +public class ParameterBindingExplicitAttributeTests : GeneratorTestBase +{ + [Fact] + public async Task FromKeyedServices_Binds_With_Key() + { + const string Source = """ + using ErrorOr; + using Microsoft.Extensions.DependencyInjection; + + public interface ICache { string Get(string key); } + + public static class Api + { + [Get("/cached")] + public static ErrorOr Handler([FromKeyedServices("redis")] ICache cache) + => cache.Get("key"); + } + """; + + using var result = await RunAsync(Source); + + result.Diagnostics.Should().BeEmpty(); + var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; + generated.Should().Contain("GetRequiredKeyedService(\"redis\")"); + } + + [Fact] + public async Task FromBody_Attribute_Forces_Body_Binding() + { + const string Source = """ + using ErrorOr; + using Microsoft.AspNetCore.Mvc; + using System.Text.Json.Serialization; + using Microsoft.AspNetCore.Http; + + public record Payload(string Data); + + public static class Api + { + [Get("/test")] + public static ErrorOr Handler([FromBody] Payload payload) => payload.Data; + } + + [JsonSourceGenerationOptions(PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase)] + [JsonSerializable(typeof(Payload))] + [JsonSerializable(typeof(ProblemDetails))] + [JsonSerializable(typeof(HttpValidationProblemDetails))] + internal partial class TestJsonContext : JsonSerializerContext { } + """; + + using var result = await RunAsync(Source); + + result.Diagnostics.Should().BeEmpty(); + var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; + generated.Should().Contain("ReadFromJsonAsync"); + } + + [Fact] + public async Task FromServices_Attribute_Forces_Service_Binding() + { + const string Source = """ + using ErrorOr; + using Microsoft.AspNetCore.Mvc; + + public class MyHelper { public string Help() => "help"; } + + public static class Api + { + [Get("/test")] + public static ErrorOr Handler([FromServices] MyHelper helper) => helper.Help(); + } + """; + + using var result = await RunAsync(Source); + + result.Diagnostics.Should().BeEmpty(); + var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; + generated.Should().Contain("GetRequiredService"); + } + + [Fact] + public async Task FromRoute_Attribute_Forces_Route_Binding() + { + const string Source = """ + using ErrorOr; + using Microsoft.AspNetCore.Mvc; + + public static class Api + { + [Get("/items/{itemId}")] + public static ErrorOr Handler([FromRoute(Name = "itemId")] int id) => $"Item {id}"; + } + """; + + using var result = await RunAsync(Source); + + result.Diagnostics.Should().BeEmpty(); + var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; + generated.Should().Contain("TryGetRouteValue(ctx, \"itemId\""); + } + + [Fact] + public async Task FromQuery_Attribute_With_Name_Uses_Custom_Key() + { + const string Source = """ + using ErrorOr; + using Microsoft.AspNetCore.Mvc; + + public static class Api + { + [Get("/search")] + public static ErrorOr Handler([FromQuery(Name = "q")] string searchTerm) => searchTerm; + } + """; + + using var result = await RunAsync(Source); + + result.Diagnostics.Should().BeEmpty(); + var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; + generated.Should().Contain("TryGetQueryValue(ctx, \"q\""); + } + + [Fact] + public async Task FromHeader_Binds_From_Headers() + { + const string Source = """ + using ErrorOr; + using Microsoft.AspNetCore.Mvc; + + public static class Api + { + [Get("/test")] + public static ErrorOr Handler([FromHeader(Name = "X-Api-Key")] string apiKey) => apiKey; + } + """; + + using var result = await RunAsync(Source); + + result.Diagnostics.Should().BeEmpty(); + var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; + generated.Should().Contain("ctx.Request.Headers.TryGetValue(\"X-Api-Key\""); + } +} diff --git a/tests/ErrorOrX.Generators.Tests/ParameterBindingInferenceTests.cs b/tests/ErrorOrX.Generators.Tests/ParameterBindingInferenceTests.cs new file mode 100644 index 0000000..2c7a3a8 --- /dev/null +++ b/tests/ErrorOrX.Generators.Tests/ParameterBindingInferenceTests.cs @@ -0,0 +1,249 @@ +namespace ErrorOrX.Generators.Tests; + +/// +/// Tests for smart parameter binding inference: HTTP-method + type → source mapping, +/// service detection (interface, abstract, DI naming patterns), and the EOE021 +/// ambiguous-parameter diagnostic for bodyless verbs with complex types. +/// +public class ParameterBindingInferenceTests : GeneratorTestBase +{ + [Theory] + [InlineData("Post")] + [InlineData("Put")] + [InlineData("Patch")] + public async Task Complex_Type_On_BodyMethod_Infers_Body(string httpMethod) + { + var source = $$""" + using ErrorOr; + using System.Text.Json.Serialization; + using Microsoft.AspNetCore.Mvc; + using Microsoft.AspNetCore.Http; + + public record CreateRequest(string Name); + public record Response(int Id, string Name); + + public static class Api + { + [{{httpMethod}}("/test")] + public static ErrorOr Handler(CreateRequest req) => new Response(1, req.Name); + } + + [JsonSourceGenerationOptions(PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase)] + [JsonSerializable(typeof(CreateRequest))] + [JsonSerializable(typeof(Response))] + [JsonSerializable(typeof(ProblemDetails))] + [JsonSerializable(typeof(HttpValidationProblemDetails))] + internal partial class TestJsonContext : JsonSerializerContext { } + """; + + using var result = await RunAsync(source); + + result.Diagnostics.Should().BeEmpty(); + var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; + generated.Should().Contain("ReadFromJsonAsync"); + } + + [Fact] + public async Task Mixed_Parameter_Sources_Bind_Correctly() + { + const string Source = """ + using ErrorOr; + using System.Text.Json.Serialization; + using Microsoft.AspNetCore.Mvc; + using Microsoft.AspNetCore.Http; + + public interface ITodoService { string Create(int userId, string title); } + public record CreateTodoRequest(string Title); + + public static class Api + { + [Post("/users/{userId}/todos")] + public static ErrorOr Create( + int userId, // Route (matches {userId}) + CreateTodoRequest req, // Body (POST + complex type) + ITodoService svc) // Service (interface) + => svc.Create(userId, req.Title); + } + + [JsonSourceGenerationOptions(PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase)] + [JsonSerializable(typeof(CreateTodoRequest))] + [JsonSerializable(typeof(ProblemDetails))] + [JsonSerializable(typeof(HttpValidationProblemDetails))] + internal partial class TestJsonContext : JsonSerializerContext { } + """; + + using var result = await RunAsync(Source); + + result.Diagnostics.Should().BeEmpty(); + var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; + generated.Should().Contain("TryGetRouteValue(ctx, \"userId\""); + generated.Should().Contain("ReadFromJsonAsync"); + generated.Should().Contain("GetRequiredService"); + } + + [Fact] + public async Task Interface_Type_Infers_Service() + { + const string Source = """ + using ErrorOr; + + public interface IMyService { string GetValue(); } + + public static class Api + { + [Get("/test")] + public static ErrorOr Handler(IMyService svc) => svc.GetValue(); + } + """; + + using var result = await RunAsync(Source); + + result.Diagnostics.Should().BeEmpty(); + var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; + generated.Should().Contain("GetRequiredService"); + } + + [Fact] + public async Task Abstract_Type_Infers_Service() + { + const string Source = """ + using ErrorOr; + + public abstract class BaseService { public abstract string GetValue(); } + + public static class Api + { + [Get("/test")] + public static ErrorOr Handler(BaseService svc) => svc.GetValue(); + } + """; + + using var result = await RunAsync(Source); + + result.Diagnostics.Should().BeEmpty(); + var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; + generated.Should().Contain("GetRequiredService"); + } + + [Theory] + [InlineData("TodoRepository")] + [InlineData("TodoHandler")] + [InlineData("TodoManager")] + [InlineData("ConfigProvider")] + [InlineData("TodoFactory")] + [InlineData("HttpClient")] + public async Task Service_Naming_Pattern_Infers_Service(string typeName) + { + var source = $$""" + using ErrorOr; + + public class {{typeName}} { public string GetValue() => "test"; } + + public static class Api + { + [Get("/test")] + public static ErrorOr Handler({{typeName}} svc) => svc.GetValue(); + } + """; + + using var result = await RunAsync(source); + + result.Diagnostics.Should().BeEmpty(); + var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; + generated.Should().Contain($"GetRequiredService"); + } + + [Fact] + public async Task DbContext_Pattern_Infers_Service() + { + const string Source = """ + using ErrorOr; + + public class AppDbContext { public string Query() => "data"; } + + public static class Api + { + [Get("/test")] + public static ErrorOr Handler(AppDbContext db) => db.Query(); + } + """; + + using var result = await RunAsync(Source); + + result.Diagnostics.Should().BeEmpty(); + var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; + generated.Should().Contain("GetRequiredService"); + } + + [Theory] + [InlineData("Get")] + [InlineData("Delete")] + public async Task Complex_Type_On_BodylessMethod_Emits_EOE021(string httpMethod) + { + var source = $$""" + using ErrorOr; + + public record SearchFilter(string Query, int Page); + + public static class Api + { + [{{httpMethod}}("/test")] + public static ErrorOr Handler(SearchFilter filter) => "result"; + } + """; + + using var result = await RunAsync(source); + + result.Diagnostics.Should().ContainSingle(static d => d.Id == "EOE021"); + var diagnostic = result.Diagnostics.First(static d => d.Id == "EOE021"); + diagnostic.GetMessage().Should().Contain("filter"); + diagnostic.GetMessage().Should().Contain(httpMethod.ToUpperInvariant()); + } + + [Fact] + public async Task Complex_Type_With_Explicit_FromQuery_NoWarning() + { + const string Source = """ + using ErrorOr; + using Microsoft.AspNetCore.Mvc; + + public record SearchFilter(string Query, int Page); + + public static class Api + { + [Get("/test")] + public static ErrorOr Handler([FromQuery] SearchFilter filter) => "result"; + } + """; + + using var result = await RunAsync(Source); + + // EOE011: [FromQuery] only supports primitives or collections of primitives + // This is expected behavior - complex types can't be query bound without [AsParameters] + result.Diagnostics.Should().ContainSingle(static d => d.Id == "EOE011"); + } + + [Fact] + public async Task Complex_Type_With_AsParameters_NoWarning() + { + const string Source = """ + using ErrorOr; + using Microsoft.AspNetCore.Http; + + public record SearchFilter(string Query, int Page); + + public static class Api + { + [Get("/test")] + public static ErrorOr Handler([AsParameters] SearchFilter filter) => "result"; + } + """; + + using var result = await RunAsync(Source); + + result.Diagnostics.Should().BeEmpty(); + var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; + // AsParameters expands to individual bindings + generated.Should().Contain("TryGetQueryValue"); + } +} diff --git a/tests/ErrorOrX.Generators.Tests/ParameterBindingRouteQueryTests.cs b/tests/ErrorOrX.Generators.Tests/ParameterBindingRouteQueryTests.cs new file mode 100644 index 0000000..a970807 --- /dev/null +++ b/tests/ErrorOrX.Generators.Tests/ParameterBindingRouteQueryTests.cs @@ -0,0 +1,170 @@ +namespace ErrorOrX.Generators.Tests; + +/// +/// Tests for implicit primitive binding (route by name match, query for unmapped primitives, +/// primitive collections), and custom binding via static TryParse on user-defined types. +/// +public class ParameterBindingRouteQueryTests : GeneratorTestBase +{ + [Fact] + public async Task Type_With_TryParse_Uses_Custom_Binding() + { + const string Source = """ + using ErrorOr; + + public readonly struct CustomId + { + public int Value { get; } + private CustomId(int value) => Value = value; + public static bool TryParse(string? s, out CustomId result) + { + if (int.TryParse(s, out var value)) + { + result = new CustomId(value); + return true; + } + result = default; + return false; + } + } + + public static class Api + { + [Get("/items/{id}")] + public static ErrorOr GetById(CustomId id) => $"Item {id.Value}"; + } + """; + + using var result = await RunAsync(Source); + + result.Diagnostics.Should().BeEmpty(); + var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; + generated.Should().Contain("CustomId.TryParse"); + } + + [Fact] + public async Task Route_Parameter_Name_Match_Binds_From_Route() + { + const string Source = """ + using ErrorOr; + + public static class Api + { + [Get("/todos/{id}")] + public static ErrorOr GetById(int id) => $"Todo {id}"; + } + """; + + using var result = await RunAsync(Source); + + result.Diagnostics.Should().BeEmpty(); + var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; + generated.Should().Contain("TryGetRouteValue(ctx, \"id\""); + } + + [Fact] + public async Task Multiple_Route_Parameters_Bind_Correctly() + { + const string Source = """ + using ErrorOr; + + public static class Api + { + [Get("/users/{userId}/posts/{postId}")] + public static ErrorOr GetPost(int userId, int postId) + => $"User {userId} Post {postId}"; + } + """; + + using var result = await RunAsync(Source); + + result.Diagnostics.Should().BeEmpty(); + var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; + generated.Should().Contain("TryGetRouteValue(ctx, \"userId\""); + generated.Should().Contain("TryGetRouteValue(ctx, \"postId\""); + } + + [Fact] + public async Task Guid_Route_Parameter_Uses_TryParse() + { + const string Source = """ + using ErrorOr; + using System; + + public static class Api + { + [Get("/items/{id}")] + public static ErrorOr GetById(Guid id) => id.ToString(); + } + """; + + using var result = await RunAsync(Source); + + result.Diagnostics.Should().BeEmpty(); + var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; + generated.Should().Contain("Guid.TryParse"); + } + + [Fact] + public async Task Primitive_NotInRoute_Infers_Query() + { + const string Source = """ + using ErrorOr; + + public static class Api + { + [Get("/search")] + public static ErrorOr Search(string query, int page) => $"Query: {query}, Page: {page}"; + } + """; + + using var result = await RunAsync(Source); + + result.Diagnostics.Should().BeEmpty(); + var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; + generated.Should().Contain("TryGetQueryValue(ctx, \"query\""); + generated.Should().Contain("TryGetQueryValue(ctx, \"page\""); + } + + [Fact] + public async Task Primitive_Array_Binds_As_Query_Collection() + { + const string Source = """ + using ErrorOr; + + public static class Api + { + [Get("/filter")] + public static ErrorOr Filter(int[] ids) => $"Count: {ids.Length}"; + } + """; + + using var result = await RunAsync(Source); + + result.Diagnostics.Should().BeEmpty(); + var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; + generated.Should().Contain("ctx.Request.Query[\"ids\"]"); + generated.Should().Contain("ToArray()"); + } + + [Fact] + public async Task Nullable_Query_Parameter_Allows_Missing() + { + const string Source = """ + using ErrorOr; + + public static class Api + { + [Get("/search")] + public static ErrorOr Search(string? query) => query ?? "all"; + } + """; + + using var result = await RunAsync(Source); + + result.Diagnostics.Should().BeEmpty(); + var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; + // Nullable query parameters use default when missing, not BindFail + generated.Should().Contain("= default"); + } +} diff --git a/tests/ErrorOrX.Generators.Tests/ParameterBindingSpecialTypesTests.cs b/tests/ErrorOrX.Generators.Tests/ParameterBindingSpecialTypesTests.cs new file mode 100644 index 0000000..043ee9e --- /dev/null +++ b/tests/ErrorOrX.Generators.Tests/ParameterBindingSpecialTypesTests.cs @@ -0,0 +1,73 @@ +namespace ErrorOrX.Generators.Tests; + +/// +/// Tests for special parameter types that bypass normal classification: +/// HttpContext, CancellationToken, and Stream for request-body access. +/// +public class ParameterBindingSpecialTypesTests : GeneratorTestBase +{ + [Fact] + public async Task HttpContext_Binds_Directly() + { + const string Source = """ + using ErrorOr; + using Microsoft.AspNetCore.Http; + + public static class Api + { + [Get("/info")] + public static ErrorOr GetInfo(HttpContext context) => context.Request.Path; + } + """; + + using var result = await RunAsync(Source); + + result.Diagnostics.Should().BeEmpty(); + var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; + // HttpContext binds directly from ctx parameter (uses p0, p1, etc. naming) + generated.Should().Contain("= ctx;"); + generated.Should().Contain("global::Api.GetInfo(p0)"); + } + + [Fact] + public async Task CancellationToken_Binds_From_RequestAborted() + { + const string Source = """ + using ErrorOr; + using System.Threading; + + public static class Api + { + [Get("/long-running")] + public static ErrorOr LongRunning(CancellationToken cancellationToken) => "done"; + } + """; + + using var result = await RunAsync(Source); + + result.Diagnostics.Should().BeEmpty(); + var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; + generated.Should().Contain("ctx.RequestAborted"); + } + + [Fact] + public async Task Stream_Binds_From_RequestBody() + { + const string Source = """ + using ErrorOr; + using System.IO; + + public static class Api + { + [Post("/upload")] + public static ErrorOr Upload(Stream body) => "uploaded"; + } + """; + + using var result = await RunAsync(Source); + + result.Diagnostics.Should().BeEmpty(); + var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; + generated.Should().Contain("ctx.Request.Body"); + } +} diff --git a/tests/ErrorOrX.Generators.Tests/ParameterBindingTests.cs b/tests/ErrorOrX.Generators.Tests/ParameterBindingTests.cs deleted file mode 100644 index 17f2ce2..0000000 --- a/tests/ErrorOrX.Generators.Tests/ParameterBindingTests.cs +++ /dev/null @@ -1,660 +0,0 @@ -namespace ErrorOrX.Generators.Tests; - -/// -/// Tests for smart parameter binding inference in the ErrorOrEndpointGenerator. -/// Covers: service detection, body inference, route/query binding, special types, and diagnostics. -/// -public class ParameterBindingTests : GeneratorTestBase -{ - #region HTTP Method + Complex Type → Body Inference - - [Theory] - [InlineData("Post")] - [InlineData("Put")] - [InlineData("Patch")] - public async Task Complex_Type_On_BodyMethod_Infers_Body(string httpMethod) - { - var source = $$""" - using ErrorOr; - using System.Text.Json.Serialization; - using Microsoft.AspNetCore.Mvc; - using Microsoft.AspNetCore.Http; - - public record CreateRequest(string Name); - public record Response(int Id, string Name); - - public static class Api - { - [{{httpMethod}}("/test")] - public static ErrorOr Handler(CreateRequest req) => new Response(1, req.Name); - } - - [JsonSourceGenerationOptions(PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase)] - [JsonSerializable(typeof(CreateRequest))] - [JsonSerializable(typeof(Response))] - [JsonSerializable(typeof(ProblemDetails))] - [JsonSerializable(typeof(HttpValidationProblemDetails))] - internal partial class TestJsonContext : JsonSerializerContext { } - """; - - using var result = await RunAsync(source); - - result.Diagnostics.Should().BeEmpty(); - var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; - generated.Should().Contain("ReadFromJsonAsync"); - } - - #endregion - - #region Keyed Services - - [Fact] - public async Task FromKeyedServices_Binds_With_Key() - { - const string Source = """ - using ErrorOr; - using Microsoft.Extensions.DependencyInjection; - - public interface ICache { string Get(string key); } - - public static class Api - { - [Get("/cached")] - public static ErrorOr Handler([FromKeyedServices("redis")] ICache cache) - => cache.Get("key"); - } - """; - - using var result = await RunAsync(Source); - - result.Diagnostics.Should().BeEmpty(); - var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; - generated.Should().Contain("GetRequiredKeyedService(\"redis\")"); - } - - #endregion - - #region Combined Parameters - - [Fact] - public async Task Mixed_Parameter_Sources_Bind_Correctly() - { - const string Source = """ - using ErrorOr; - using System.Text.Json.Serialization; - using Microsoft.AspNetCore.Mvc; - using Microsoft.AspNetCore.Http; - - public interface ITodoService { string Create(int userId, string title); } - public record CreateTodoRequest(string Title); - - public static class Api - { - [Post("/users/{userId}/todos")] - public static ErrorOr Create( - int userId, // Route (matches {userId}) - CreateTodoRequest req, // Body (POST + complex type) - ITodoService svc) // Service (interface) - => svc.Create(userId, req.Title); - } - - [JsonSourceGenerationOptions(PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase)] - [JsonSerializable(typeof(CreateTodoRequest))] - [JsonSerializable(typeof(ProblemDetails))] - [JsonSerializable(typeof(HttpValidationProblemDetails))] - internal partial class TestJsonContext : JsonSerializerContext { } - """; - - using var result = await RunAsync(Source); - - result.Diagnostics.Should().BeEmpty(); - var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; - generated.Should().Contain("TryGetRouteValue(ctx, \"userId\""); - generated.Should().Contain("ReadFromJsonAsync"); - generated.Should().Contain("GetRequiredService"); - } - - #endregion - - #region Custom Binding (TryParse) - - [Fact] - public async Task Type_With_TryParse_Uses_Custom_Binding() - { - const string Source = """ - using ErrorOr; - - public readonly struct CustomId - { - public int Value { get; } - private CustomId(int value) => Value = value; - public static bool TryParse(string? s, out CustomId result) - { - if (int.TryParse(s, out var value)) - { - result = new CustomId(value); - return true; - } - result = default; - return false; - } - } - - public static class Api - { - [Get("/items/{id}")] - public static ErrorOr GetById(CustomId id) => $"Item {id.Value}"; - } - """; - - using var result = await RunAsync(Source); - - result.Diagnostics.Should().BeEmpty(); - var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; - generated.Should().Contain("CustomId.TryParse"); - } - - #endregion - - #region Interface and Abstract Type → Service Inference - - [Fact] - public async Task Interface_Type_Infers_Service() - { - const string Source = """ - using ErrorOr; - - public interface IMyService { string GetValue(); } - - public static class Api - { - [Get("/test")] - public static ErrorOr Handler(IMyService svc) => svc.GetValue(); - } - """; - - using var result = await RunAsync(Source); - - result.Diagnostics.Should().BeEmpty(); - var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; - generated.Should().Contain("GetRequiredService"); - } - - [Fact] - public async Task Abstract_Type_Infers_Service() - { - const string Source = """ - using ErrorOr; - - public abstract class BaseService { public abstract string GetValue(); } - - public static class Api - { - [Get("/test")] - public static ErrorOr Handler(BaseService svc) => svc.GetValue(); - } - """; - - using var result = await RunAsync(Source); - - result.Diagnostics.Should().BeEmpty(); - var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; - generated.Should().Contain("GetRequiredService"); - } - - #endregion - - #region Service Naming Patterns - - [Theory] - [InlineData("TodoRepository")] - [InlineData("TodoHandler")] - [InlineData("TodoManager")] - [InlineData("ConfigProvider")] - [InlineData("TodoFactory")] - [InlineData("HttpClient")] - public async Task Service_Naming_Pattern_Infers_Service(string typeName) - { - var source = $$""" - using ErrorOr; - - public class {{typeName}} { public string GetValue() => "test"; } - - public static class Api - { - [Get("/test")] - public static ErrorOr Handler({{typeName}} svc) => svc.GetValue(); - } - """; - - using var result = await RunAsync(source); - - result.Diagnostics.Should().BeEmpty(); - var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; - generated.Should().Contain($"GetRequiredService"); - } - - [Fact] - public async Task DbContext_Pattern_Infers_Service() - { - const string Source = """ - using ErrorOr; - - public class AppDbContext { public string Query() => "data"; } - - public static class Api - { - [Get("/test")] - public static ErrorOr Handler(AppDbContext db) => db.Query(); - } - """; - - using var result = await RunAsync(Source); - - result.Diagnostics.Should().BeEmpty(); - var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; - generated.Should().Contain("GetRequiredService"); - } - - #endregion - - #region GET/DELETE + Complex Type → EOE021 Warning - - [Theory] - [InlineData("Get")] - [InlineData("Delete")] - public async Task Complex_Type_On_BodylessMethod_Emits_EOE021(string httpMethod) - { - var source = $$""" - using ErrorOr; - - public record SearchFilter(string Query, int Page); - - public static class Api - { - [{{httpMethod}}("/test")] - public static ErrorOr Handler(SearchFilter filter) => "result"; - } - """; - - using var result = await RunAsync(source); - - result.Diagnostics.Should().ContainSingle(static d => d.Id == "EOE021"); - var diagnostic = result.Diagnostics.First(static d => d.Id == "EOE021"); - diagnostic.GetMessage().Should().Contain("filter"); - diagnostic.GetMessage().Should().Contain(httpMethod.ToUpperInvariant()); - } - - [Fact] - public async Task Complex_Type_With_Explicit_FromQuery_NoWarning() - { - const string Source = """ - using ErrorOr; - using Microsoft.AspNetCore.Mvc; - - public record SearchFilter(string Query, int Page); - - public static class Api - { - [Get("/test")] - public static ErrorOr Handler([FromQuery] SearchFilter filter) => "result"; - } - """; - - using var result = await RunAsync(Source); - - // EOE011: [FromQuery] only supports primitives or collections of primitives - // This is expected behavior - complex types can't be query bound without [AsParameters] - result.Diagnostics.Should().ContainSingle(static d => d.Id == "EOE011"); - } - - [Fact] - public async Task Complex_Type_With_AsParameters_NoWarning() - { - const string Source = """ - using ErrorOr; - using Microsoft.AspNetCore.Http; - - public record SearchFilter(string Query, int Page); - - public static class Api - { - [Get("/test")] - public static ErrorOr Handler([AsParameters] SearchFilter filter) => "result"; - } - """; - - using var result = await RunAsync(Source); - - result.Diagnostics.Should().BeEmpty(); - var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; - // AsParameters expands to individual bindings - generated.Should().Contain("TryGetQueryValue"); - } - - #endregion - - #region Route Parameter Binding - - [Fact] - public async Task Route_Parameter_Name_Match_Binds_From_Route() - { - const string Source = """ - using ErrorOr; - - public static class Api - { - [Get("/todos/{id}")] - public static ErrorOr GetById(int id) => $"Todo {id}"; - } - """; - - using var result = await RunAsync(Source); - - result.Diagnostics.Should().BeEmpty(); - var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; - generated.Should().Contain("TryGetRouteValue(ctx, \"id\""); - } - - [Fact] - public async Task Multiple_Route_Parameters_Bind_Correctly() - { - const string Source = """ - using ErrorOr; - - public static class Api - { - [Get("/users/{userId}/posts/{postId}")] - public static ErrorOr GetPost(int userId, int postId) - => $"User {userId} Post {postId}"; - } - """; - - using var result = await RunAsync(Source); - - result.Diagnostics.Should().BeEmpty(); - var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; - generated.Should().Contain("TryGetRouteValue(ctx, \"userId\""); - generated.Should().Contain("TryGetRouteValue(ctx, \"postId\""); - } - - [Fact] - public async Task Guid_Route_Parameter_Uses_TryParse() - { - const string Source = """ - using ErrorOr; - using System; - - public static class Api - { - [Get("/items/{id}")] - public static ErrorOr GetById(Guid id) => id.ToString(); - } - """; - - using var result = await RunAsync(Source); - - result.Diagnostics.Should().BeEmpty(); - var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; - generated.Should().Contain("Guid.TryParse"); - } - - #endregion - - #region Query Parameter Binding - - [Fact] - public async Task Primitive_NotInRoute_Infers_Query() - { - const string Source = """ - using ErrorOr; - - public static class Api - { - [Get("/search")] - public static ErrorOr Search(string query, int page) => $"Query: {query}, Page: {page}"; - } - """; - - using var result = await RunAsync(Source); - - result.Diagnostics.Should().BeEmpty(); - var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; - generated.Should().Contain("TryGetQueryValue(ctx, \"query\""); - generated.Should().Contain("TryGetQueryValue(ctx, \"page\""); - } - - [Fact] - public async Task Primitive_Array_Binds_As_Query_Collection() - { - const string Source = """ - using ErrorOr; - - public static class Api - { - [Get("/filter")] - public static ErrorOr Filter(int[] ids) => $"Count: {ids.Length}"; - } - """; - - using var result = await RunAsync(Source); - - result.Diagnostics.Should().BeEmpty(); - var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; - generated.Should().Contain("ctx.Request.Query[\"ids\"]"); - generated.Should().Contain("ToArray()"); - } - - [Fact] - public async Task Nullable_Query_Parameter_Allows_Missing() - { - const string Source = """ - using ErrorOr; - - public static class Api - { - [Get("/search")] - public static ErrorOr Search(string? query) => query ?? "all"; - } - """; - - using var result = await RunAsync(Source); - - result.Diagnostics.Should().BeEmpty(); - var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; - // Nullable query parameters use default when missing, not BindFail - generated.Should().Contain("= default"); - } - - #endregion - - #region Special Types - - [Fact] - public async Task HttpContext_Binds_Directly() - { - const string Source = """ - using ErrorOr; - using Microsoft.AspNetCore.Http; - - public static class Api - { - [Get("/info")] - public static ErrorOr GetInfo(HttpContext context) => context.Request.Path; - } - """; - - using var result = await RunAsync(Source); - - result.Diagnostics.Should().BeEmpty(); - var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; - // HttpContext binds directly from ctx parameter (uses p0, p1, etc. naming) - generated.Should().Contain("= ctx;"); - generated.Should().Contain("global::Api.GetInfo(p0)"); - } - - [Fact] - public async Task CancellationToken_Binds_From_RequestAborted() - { - const string Source = """ - using ErrorOr; - using System.Threading; - - public static class Api - { - [Get("/long-running")] - public static ErrorOr LongRunning(CancellationToken cancellationToken) => "done"; - } - """; - - using var result = await RunAsync(Source); - - result.Diagnostics.Should().BeEmpty(); - var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; - generated.Should().Contain("ctx.RequestAborted"); - } - - [Fact] - public async Task Stream_Binds_From_RequestBody() - { - const string Source = """ - using ErrorOr; - using System.IO; - - public static class Api - { - [Post("/upload")] - public static ErrorOr Upload(Stream body) => "uploaded"; - } - """; - - using var result = await RunAsync(Source); - - result.Diagnostics.Should().BeEmpty(); - var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; - generated.Should().Contain("ctx.Request.Body"); - } - - #endregion - - #region Explicit Attribute Bindings - - [Fact] - public async Task FromBody_Attribute_Forces_Body_Binding() - { - const string Source = """ - using ErrorOr; - using Microsoft.AspNetCore.Mvc; - using System.Text.Json.Serialization; - using Microsoft.AspNetCore.Http; - - public record Payload(string Data); - - public static class Api - { - [Get("/test")] - public static ErrorOr Handler([FromBody] Payload payload) => payload.Data; - } - - [JsonSourceGenerationOptions(PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase)] - [JsonSerializable(typeof(Payload))] - [JsonSerializable(typeof(ProblemDetails))] - [JsonSerializable(typeof(HttpValidationProblemDetails))] - internal partial class TestJsonContext : JsonSerializerContext { } - """; - - using var result = await RunAsync(Source); - - result.Diagnostics.Should().BeEmpty(); - var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; - generated.Should().Contain("ReadFromJsonAsync"); - } - - [Fact] - public async Task FromServices_Attribute_Forces_Service_Binding() - { - const string Source = """ - using ErrorOr; - using Microsoft.AspNetCore.Mvc; - - public class MyHelper { public string Help() => "help"; } - - public static class Api - { - [Get("/test")] - public static ErrorOr Handler([FromServices] MyHelper helper) => helper.Help(); - } - """; - - using var result = await RunAsync(Source); - - result.Diagnostics.Should().BeEmpty(); - var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; - generated.Should().Contain("GetRequiredService"); - } - - [Fact] - public async Task FromRoute_Attribute_Forces_Route_Binding() - { - const string Source = """ - using ErrorOr; - using Microsoft.AspNetCore.Mvc; - - public static class Api - { - [Get("/items/{itemId}")] - public static ErrorOr Handler([FromRoute(Name = "itemId")] int id) => $"Item {id}"; - } - """; - - using var result = await RunAsync(Source); - - result.Diagnostics.Should().BeEmpty(); - var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; - generated.Should().Contain("TryGetRouteValue(ctx, \"itemId\""); - } - - [Fact] - public async Task FromQuery_Attribute_With_Name_Uses_Custom_Key() - { - const string Source = """ - using ErrorOr; - using Microsoft.AspNetCore.Mvc; - - public static class Api - { - [Get("/search")] - public static ErrorOr Handler([FromQuery(Name = "q")] string searchTerm) => searchTerm; - } - """; - - using var result = await RunAsync(Source); - - result.Diagnostics.Should().BeEmpty(); - var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; - generated.Should().Contain("TryGetQueryValue(ctx, \"q\""); - } - - [Fact] - public async Task FromHeader_Binds_From_Headers() - { - const string Source = """ - using ErrorOr; - using Microsoft.AspNetCore.Mvc; - - public static class Api - { - [Get("/test")] - public static ErrorOr Handler([FromHeader(Name = "X-Api-Key")] string apiKey) => apiKey; - } - """; - - using var result = await RunAsync(Source); - - result.Diagnostics.Should().BeEmpty(); - var generated = result.Files.First(static f => f.HintName == "ErrorOrEndpointMappings.cs").Content; - generated.Should().Contain("ctx.Request.Headers.TryGetValue(\"X-Api-Key\""); - } - - #endregion -} diff --git a/tests/ErrorOrX.Generators.Tests/RouteBodyValidationTests.cs b/tests/ErrorOrX.Generators.Tests/RouteBodyValidationTests.cs new file mode 100644 index 0000000..5559314 --- /dev/null +++ b/tests/ErrorOrX.Generators.Tests/RouteBodyValidationTests.cs @@ -0,0 +1,320 @@ +namespace ErrorOrX.Generators.Tests; + +/// +/// Tests for route shape, body-source, and high-level handler validation +/// diagnostics (EOE003, EOE005-EOE006, EOE015, EOE018-EOE021). Covers +/// unbound route parameters, malformed route patterns, multiple body +/// sources, inaccessible or unsupported return types, generic type +/// parameters, route constraint mismatches, and ambiguous bindings. +/// +public class RouteBodyValidationTests : GeneratorTestBase +{ + #region EOE003 - Route parameter not bound + + [Fact] + public Task EOE003_Route_Parameter_Not_Bound() + { + const string Source = """ + using ErrorOr; + + namespace DiagnosticTest; + + public static class TodoApi + { + [Get("/todos/{id}")] + public static ErrorOr GetById() => "todo"; + } + """; + + return VerifyAsync(Source); + } + + [Fact] + public Task EOE003_Route_Parameter_With_Constraint_Not_Bound() + { + const string Source = """ + using ErrorOr; + + namespace DiagnosticTest; + + public static class TodoApi + { + [Get("/todos/{id:int}")] + public static ErrorOr GetById() => "todo"; + } + """; + + return VerifyAsync(Source); + } + + #endregion + + #region EOE005 - Invalid route pattern + + [Fact] + public Task EOE005_Unclosed_Brace_In_Route() + { + const string Source = """ + using ErrorOr; + + namespace DiagnosticTest; + + public static class TodoApi + { + [Get("/todos/{id")] + public static ErrorOr GetById(int id) => "todo"; + } + """; + + return VerifyAsync(Source); + } + + [Fact] + public Task EOE005_Unmatched_Close_Brace() + { + const string Source = """ + using ErrorOr; + + namespace DiagnosticTest; + + public static class TodoApi + { + [Get("/todos/id}")] + public static ErrorOr GetById(int id) => "todo"; + } + """; + + return VerifyAsync(Source); + } + + [Fact] + public Task EOE005_Empty_Parameter_Name() + { + const string Source = """ + using ErrorOr; + + namespace DiagnosticTest; + + public static class TodoApi + { + [Get("/todos/{}")] + public static ErrorOr GetById() => "todo"; + } + """; + + return VerifyAsync(Source); + } + + #endregion + + #region EOE006 - Multiple body sources + + [Fact] + public Task EOE006_Multiple_Body_Sources_FromBody_And_FromForm() + { + const string Source = """ + using ErrorOr; + using Microsoft.AspNetCore.Http; + using Microsoft.AspNetCore.Mvc; + + namespace DiagnosticTest; + + public record CreateRequest(string Name); + + public static class TodoApi + { + [Post("/todos")] + public static ErrorOr Create( + [FromBody] CreateRequest body, + [FromForm] IFormFile file) => "created"; + } + """; + + return VerifyAsync(Source); + } + + [Fact] + public Task EOE006_Multiple_Body_Sources_Stream_And_FromBody() + { + const string Source = """ + using ErrorOr; + using System.IO; + using Microsoft.AspNetCore.Mvc; + + namespace DiagnosticTest; + + public record CreateRequest(string Name); + + public static class TodoApi + { + [Post("/upload")] + public static ErrorOr Upload( + [FromBody] CreateRequest body, + Stream data) => "uploaded"; + } + """; + + return VerifyAsync(Source); + } + + #endregion + + #region EOE015 - Anonymous return type not supported + + [Fact] + public Task EOE015_Anonymous_Return_Type() + { + const string Source = """ + using ErrorOr; + + namespace DiagnosticTest; + + public static class TodoApi + { + [Get("/data")] + public static ErrorOr GetData() => new { Name = "test" }; + } + """; + + return VerifyAsync(Source); + } + + #endregion + + #region EOE018 - Inaccessible type in endpoint + + [Fact] + public Task EOE018_Private_Return_Type() + { + const string Source = """ + using ErrorOr; + + namespace DiagnosticTest; + + public static class TodoApi + { + private class SecretData { public string Value { get; set; } } + + [Get("/secret")] + public static ErrorOr GetSecret() => new SecretData { Value = "secret" }; + } + """; + + return VerifyAsync(Source); + } + + #endregion + + #region EOE019 - Type parameter not supported + + [Fact] + public Task EOE019_Generic_Type_Parameter() + { + const string Source = """ + using ErrorOr; + + namespace DiagnosticTest; + + public static class GenericApi + { + [Get("/items")] + public static ErrorOr GetItem() where T : class => default!; + } + """; + + return VerifyAsync(Source); + } + + #endregion + + #region EOE020 - Route constraint type mismatch + + [Fact] + public Task EOE020_Int_Constraint_With_String_Parameter() + { + const string Source = """ + using ErrorOr; + + namespace DiagnosticTest; + + public static class TodoApi + { + [Get("/todos/{id:int}")] + public static ErrorOr GetById(string id) => "todo"; + } + """; + + return VerifyAsync(Source); + } + + [Fact] + public Task EOE020_Guid_Constraint_With_Int_Parameter() + { + const string Source = """ + using ErrorOr; + + namespace DiagnosticTest; + + public static class TodoApi + { + [Get("/todos/{id:guid}")] + public static ErrorOr GetById(int id) => "todo"; + } + """; + + return VerifyAsync(Source); + } + + #endregion + + #region EOE021 - Ambiguous parameter binding + + [Fact] + public Task EOE021_Complex_Type_On_Get_Without_Binding() + { + const string Source = """ + using ErrorOr; + + namespace DiagnosticTest; + + public class SearchFilter + { + public string Query { get; set; } + public int Page { get; set; } + } + + public static class TodoApi + { + [Get("/todos")] + public static ErrorOr Search(SearchFilter filter) => "todos"; + } + """; + + return VerifyAsync(Source); + } + + [Fact] + public Task EOE021_Complex_Type_On_Delete_Without_Binding() + { + const string Source = """ + using ErrorOr; + + namespace DiagnosticTest; + + public class DeleteOptions + { + public bool Force { get; set; } + } + + public static class TodoApi + { + [Delete("/todos/{id}")] + public static ErrorOr Delete(int id, DeleteOptions options) => "deleted"; + } + """; + + return VerifyAsync(Source); + } + + #endregion +} diff --git a/tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE010_Invalid_FromRoute_Type_Complex.verified.txt b/tests/ErrorOrX.Generators.Tests/Snapshots/BindingTypeValidationTests.EOE010_Invalid_FromRoute_Type_Complex.verified.txt similarity index 100% rename from tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE010_Invalid_FromRoute_Type_Complex.verified.txt rename to tests/ErrorOrX.Generators.Tests/Snapshots/BindingTypeValidationTests.EOE010_Invalid_FromRoute_Type_Complex.verified.txt diff --git a/tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE011_Invalid_FromQuery_Type_Complex.verified.txt b/tests/ErrorOrX.Generators.Tests/Snapshots/BindingTypeValidationTests.EOE011_Invalid_FromQuery_Type_Complex.verified.txt similarity index 100% rename from tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE011_Invalid_FromQuery_Type_Complex.verified.txt rename to tests/ErrorOrX.Generators.Tests/Snapshots/BindingTypeValidationTests.EOE011_Invalid_FromQuery_Type_Complex.verified.txt diff --git a/tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE012_Invalid_AsParameters_Type_Primitive.verified.txt b/tests/ErrorOrX.Generators.Tests/Snapshots/BindingTypeValidationTests.EOE012_Invalid_AsParameters_Type_Primitive.verified.txt similarity index 100% rename from tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE012_Invalid_AsParameters_Type_Primitive.verified.txt rename to tests/ErrorOrX.Generators.Tests/Snapshots/BindingTypeValidationTests.EOE012_Invalid_AsParameters_Type_Primitive.verified.txt diff --git a/tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE013_AsParameters_No_Constructor.verified.txt b/tests/ErrorOrX.Generators.Tests/Snapshots/BindingTypeValidationTests.EOE013_AsParameters_No_Constructor.verified.txt similarity index 100% rename from tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE013_AsParameters_No_Constructor.verified.txt rename to tests/ErrorOrX.Generators.Tests/Snapshots/BindingTypeValidationTests.EOE013_AsParameters_No_Constructor.verified.txt diff --git a/tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE014_Invalid_FromHeader_Type_Complex.verified.txt b/tests/ErrorOrX.Generators.Tests/Snapshots/BindingTypeValidationTests.EOE014_Invalid_FromHeader_Type_Complex.verified.txt similarity index 100% rename from tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE014_Invalid_FromHeader_Type_Complex.verified.txt rename to tests/ErrorOrX.Generators.Tests/Snapshots/BindingTypeValidationTests.EOE014_Invalid_FromHeader_Type_Complex.verified.txt diff --git a/tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE016_Nested_AsParameters.verified.txt b/tests/ErrorOrX.Generators.Tests/Snapshots/BindingTypeValidationTests.EOE016_Nested_AsParameters.verified.txt similarity index 100% rename from tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE016_Nested_AsParameters.verified.txt rename to tests/ErrorOrX.Generators.Tests/Snapshots/BindingTypeValidationTests.EOE016_Nested_AsParameters.verified.txt diff --git a/tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE017_Nullable_AsParameters.verified.txt b/tests/ErrorOrX.Generators.Tests/Snapshots/BindingTypeValidationTests.EOE017_Nullable_AsParameters.verified.txt similarity index 100% rename from tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE017_Nullable_AsParameters.verified.txt rename to tests/ErrorOrX.Generators.Tests/Snapshots/BindingTypeValidationTests.EOE017_Nullable_AsParameters.verified.txt diff --git a/tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE007_Type_Not_In_Json_Context.verified.txt b/tests/ErrorOrX.Generators.Tests/Snapshots/JsonAotValidationTests.EOE007_Type_Not_In_Json_Context.verified.txt similarity index 100% rename from tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE007_Type_Not_In_Json_Context.verified.txt rename to tests/ErrorOrX.Generators.Tests/Snapshots/JsonAotValidationTests.EOE007_Type_Not_In_Json_Context.verified.txt diff --git a/tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE025_Missing_CamelCase_Policy.verified.txt b/tests/ErrorOrX.Generators.Tests/Snapshots/JsonAotValidationTests.EOE025_Missing_CamelCase_Policy.verified.txt similarity index 100% rename from tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE025_Missing_CamelCase_Policy.verified.txt rename to tests/ErrorOrX.Generators.Tests/Snapshots/JsonAotValidationTests.EOE025_Missing_CamelCase_Policy.verified.txt diff --git a/tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE025_With_CamelCase_Policy_No_Diagnostic.verified.txt b/tests/ErrorOrX.Generators.Tests/Snapshots/JsonAotValidationTests.EOE025_With_CamelCase_Policy_No_Diagnostic.verified.txt similarity index 100% rename from tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE025_With_CamelCase_Policy_No_Diagnostic.verified.txt rename to tests/ErrorOrX.Generators.Tests/Snapshots/JsonAotValidationTests.EOE025_With_CamelCase_Policy_No_Diagnostic.verified.txt diff --git a/tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE039_Multiple_Validation_Attributes.verified.txt b/tests/ErrorOrX.Generators.Tests/Snapshots/JsonAotValidationTests.EOE039_Multiple_Validation_Attributes.verified.txt similarity index 100% rename from tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE039_Multiple_Validation_Attributes.verified.txt rename to tests/ErrorOrX.Generators.Tests/Snapshots/JsonAotValidationTests.EOE039_Multiple_Validation_Attributes.verified.txt diff --git a/tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE039_Validation_Attribute_On_Parameter.verified.txt b/tests/ErrorOrX.Generators.Tests/Snapshots/JsonAotValidationTests.EOE039_Validation_Attribute_On_Parameter.verified.txt similarity index 100% rename from tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE039_Validation_Attribute_On_Parameter.verified.txt rename to tests/ErrorOrX.Generators.Tests/Snapshots/JsonAotValidationTests.EOE039_Validation_Attribute_On_Parameter.verified.txt diff --git a/tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE041_Missing_ProblemDetails_In_JsonContext.verified.txt b/tests/ErrorOrX.Generators.Tests/Snapshots/JsonAotValidationTests.EOE041_Missing_ProblemDetails_In_JsonContext.verified.txt similarity index 100% rename from tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE041_Missing_ProblemDetails_In_JsonContext.verified.txt rename to tests/ErrorOrX.Generators.Tests/Snapshots/JsonAotValidationTests.EOE041_Missing_ProblemDetails_In_JsonContext.verified.txt diff --git a/tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE041_No_Diagnostic_When_ProblemDetails_Present.verified.txt b/tests/ErrorOrX.Generators.Tests/Snapshots/JsonAotValidationTests.EOE041_No_Diagnostic_When_ProblemDetails_Present.verified.txt similarity index 100% rename from tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE041_No_Diagnostic_When_ProblemDetails_Present.verified.txt rename to tests/ErrorOrX.Generators.Tests/Snapshots/JsonAotValidationTests.EOE041_No_Diagnostic_When_ProblemDetails_Present.verified.txt diff --git a/tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE033_Method_Name_Lowercase_Start.verified.txt b/tests/ErrorOrX.Generators.Tests/Snapshots/NamingAndValidCaseTests.EOE033_Method_Name_Lowercase_Start.verified.txt similarity index 100% rename from tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE033_Method_Name_Lowercase_Start.verified.txt rename to tests/ErrorOrX.Generators.Tests/Snapshots/NamingAndValidCaseTests.EOE033_Method_Name_Lowercase_Start.verified.txt diff --git a/tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE033_Method_Name_Snake_Case.verified.txt b/tests/ErrorOrX.Generators.Tests/Snapshots/NamingAndValidCaseTests.EOE033_Method_Name_Snake_Case.verified.txt similarity index 100% rename from tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE033_Method_Name_Snake_Case.verified.txt rename to tests/ErrorOrX.Generators.Tests/Snapshots/NamingAndValidCaseTests.EOE033_Method_Name_Snake_Case.verified.txt diff --git a/tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE033_Method_Name_With_Underscore.verified.txt b/tests/ErrorOrX.Generators.Tests/Snapshots/NamingAndValidCaseTests.EOE033_Method_Name_With_Underscore.verified.txt similarity index 100% rename from tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE033_Method_Name_With_Underscore.verified.txt rename to tests/ErrorOrX.Generators.Tests/Snapshots/NamingAndValidCaseTests.EOE033_Method_Name_With_Underscore.verified.txt diff --git a/tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.Valid_Complex_Type_With_AsParameters.verified.txt b/tests/ErrorOrX.Generators.Tests/Snapshots/NamingAndValidCaseTests.Valid_Complex_Type_With_AsParameters.verified.txt similarity index 100% rename from tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.Valid_Complex_Type_With_AsParameters.verified.txt rename to tests/ErrorOrX.Generators.Tests/Snapshots/NamingAndValidCaseTests.Valid_Complex_Type_With_AsParameters.verified.txt diff --git a/tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.Valid_Complex_Type_With_FromBody_On_Post.verified.txt b/tests/ErrorOrX.Generators.Tests/Snapshots/NamingAndValidCaseTests.Valid_Complex_Type_With_FromBody_On_Post.verified.txt similarity index 100% rename from tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.Valid_Complex_Type_With_FromBody_On_Post.verified.txt rename to tests/ErrorOrX.Generators.Tests/Snapshots/NamingAndValidCaseTests.Valid_Complex_Type_With_FromBody_On_Post.verified.txt diff --git a/tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.Valid_Route_Parameter_Bound.verified.txt b/tests/ErrorOrX.Generators.Tests/Snapshots/NamingAndValidCaseTests.Valid_Route_Parameter_Bound.verified.txt similarity index 100% rename from tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.Valid_Route_Parameter_Bound.verified.txt rename to tests/ErrorOrX.Generators.Tests/Snapshots/NamingAndValidCaseTests.Valid_Route_Parameter_Bound.verified.txt diff --git a/tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.Valid_Service_Type_Inferred.verified.txt b/tests/ErrorOrX.Generators.Tests/Snapshots/NamingAndValidCaseTests.Valid_Service_Type_Inferred.verified.txt similarity index 100% rename from tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.Valid_Service_Type_Inferred.verified.txt rename to tests/ErrorOrX.Generators.Tests/Snapshots/NamingAndValidCaseTests.Valid_Service_Type_Inferred.verified.txt diff --git a/tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE003_Route_Parameter_Not_Bound.verified.txt b/tests/ErrorOrX.Generators.Tests/Snapshots/RouteBodyValidationTests.EOE003_Route_Parameter_Not_Bound.verified.txt similarity index 100% rename from tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE003_Route_Parameter_Not_Bound.verified.txt rename to tests/ErrorOrX.Generators.Tests/Snapshots/RouteBodyValidationTests.EOE003_Route_Parameter_Not_Bound.verified.txt diff --git a/tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE003_Route_Parameter_With_Constraint_Not_Bound.verified.txt b/tests/ErrorOrX.Generators.Tests/Snapshots/RouteBodyValidationTests.EOE003_Route_Parameter_With_Constraint_Not_Bound.verified.txt similarity index 100% rename from tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE003_Route_Parameter_With_Constraint_Not_Bound.verified.txt rename to tests/ErrorOrX.Generators.Tests/Snapshots/RouteBodyValidationTests.EOE003_Route_Parameter_With_Constraint_Not_Bound.verified.txt diff --git a/tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE005_Empty_Parameter_Name.verified.txt b/tests/ErrorOrX.Generators.Tests/Snapshots/RouteBodyValidationTests.EOE005_Empty_Parameter_Name.verified.txt similarity index 100% rename from tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE005_Empty_Parameter_Name.verified.txt rename to tests/ErrorOrX.Generators.Tests/Snapshots/RouteBodyValidationTests.EOE005_Empty_Parameter_Name.verified.txt diff --git a/tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE005_Unclosed_Brace_In_Route.verified.txt b/tests/ErrorOrX.Generators.Tests/Snapshots/RouteBodyValidationTests.EOE005_Unclosed_Brace_In_Route.verified.txt similarity index 100% rename from tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE005_Unclosed_Brace_In_Route.verified.txt rename to tests/ErrorOrX.Generators.Tests/Snapshots/RouteBodyValidationTests.EOE005_Unclosed_Brace_In_Route.verified.txt diff --git a/tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE005_Unmatched_Close_Brace.verified.txt b/tests/ErrorOrX.Generators.Tests/Snapshots/RouteBodyValidationTests.EOE005_Unmatched_Close_Brace.verified.txt similarity index 100% rename from tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE005_Unmatched_Close_Brace.verified.txt rename to tests/ErrorOrX.Generators.Tests/Snapshots/RouteBodyValidationTests.EOE005_Unmatched_Close_Brace.verified.txt diff --git a/tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE006_Multiple_Body_Sources_FromBody_And_FromForm.verified.txt b/tests/ErrorOrX.Generators.Tests/Snapshots/RouteBodyValidationTests.EOE006_Multiple_Body_Sources_FromBody_And_FromForm.verified.txt similarity index 100% rename from tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE006_Multiple_Body_Sources_FromBody_And_FromForm.verified.txt rename to tests/ErrorOrX.Generators.Tests/Snapshots/RouteBodyValidationTests.EOE006_Multiple_Body_Sources_FromBody_And_FromForm.verified.txt diff --git a/tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE006_Multiple_Body_Sources_Stream_And_FromBody.verified.txt b/tests/ErrorOrX.Generators.Tests/Snapshots/RouteBodyValidationTests.EOE006_Multiple_Body_Sources_Stream_And_FromBody.verified.txt similarity index 100% rename from tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE006_Multiple_Body_Sources_Stream_And_FromBody.verified.txt rename to tests/ErrorOrX.Generators.Tests/Snapshots/RouteBodyValidationTests.EOE006_Multiple_Body_Sources_Stream_And_FromBody.verified.txt diff --git a/tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE015_Anonymous_Return_Type.verified.txt b/tests/ErrorOrX.Generators.Tests/Snapshots/RouteBodyValidationTests.EOE015_Anonymous_Return_Type.verified.txt similarity index 100% rename from tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE015_Anonymous_Return_Type.verified.txt rename to tests/ErrorOrX.Generators.Tests/Snapshots/RouteBodyValidationTests.EOE015_Anonymous_Return_Type.verified.txt diff --git a/tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE018_Private_Return_Type.verified.txt b/tests/ErrorOrX.Generators.Tests/Snapshots/RouteBodyValidationTests.EOE018_Private_Return_Type.verified.txt similarity index 100% rename from tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE018_Private_Return_Type.verified.txt rename to tests/ErrorOrX.Generators.Tests/Snapshots/RouteBodyValidationTests.EOE018_Private_Return_Type.verified.txt diff --git a/tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE019_Generic_Type_Parameter.verified.txt b/tests/ErrorOrX.Generators.Tests/Snapshots/RouteBodyValidationTests.EOE019_Generic_Type_Parameter.verified.txt similarity index 100% rename from tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE019_Generic_Type_Parameter.verified.txt rename to tests/ErrorOrX.Generators.Tests/Snapshots/RouteBodyValidationTests.EOE019_Generic_Type_Parameter.verified.txt diff --git a/tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE020_Guid_Constraint_With_Int_Parameter.verified.txt b/tests/ErrorOrX.Generators.Tests/Snapshots/RouteBodyValidationTests.EOE020_Guid_Constraint_With_Int_Parameter.verified.txt similarity index 100% rename from tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE020_Guid_Constraint_With_Int_Parameter.verified.txt rename to tests/ErrorOrX.Generators.Tests/Snapshots/RouteBodyValidationTests.EOE020_Guid_Constraint_With_Int_Parameter.verified.txt diff --git a/tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE020_Int_Constraint_With_String_Parameter.verified.txt b/tests/ErrorOrX.Generators.Tests/Snapshots/RouteBodyValidationTests.EOE020_Int_Constraint_With_String_Parameter.verified.txt similarity index 100% rename from tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE020_Int_Constraint_With_String_Parameter.verified.txt rename to tests/ErrorOrX.Generators.Tests/Snapshots/RouteBodyValidationTests.EOE020_Int_Constraint_With_String_Parameter.verified.txt diff --git a/tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE021_Complex_Type_On_Delete_Without_Binding.verified.txt b/tests/ErrorOrX.Generators.Tests/Snapshots/RouteBodyValidationTests.EOE021_Complex_Type_On_Delete_Without_Binding.verified.txt similarity index 100% rename from tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE021_Complex_Type_On_Delete_Without_Binding.verified.txt rename to tests/ErrorOrX.Generators.Tests/Snapshots/RouteBodyValidationTests.EOE021_Complex_Type_On_Delete_Without_Binding.verified.txt diff --git a/tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE021_Complex_Type_On_Get_Without_Binding.verified.txt b/tests/ErrorOrX.Generators.Tests/Snapshots/RouteBodyValidationTests.EOE021_Complex_Type_On_Get_Without_Binding.verified.txt similarity index 100% rename from tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE021_Complex_Type_On_Get_Without_Binding.verified.txt rename to tests/ErrorOrX.Generators.Tests/Snapshots/RouteBodyValidationTests.EOE021_Complex_Type_On_Get_Without_Binding.verified.txt diff --git a/tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE022_Too_Many_Result_Types.verified.txt b/tests/ErrorOrX.Generators.Tests/Snapshots/UnionTypeAndFactoryTests.EOE022_Too_Many_Result_Types.verified.txt similarity index 100% rename from tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE022_Too_Many_Result_Types.verified.txt rename to tests/ErrorOrX.Generators.Tests/Snapshots/UnionTypeAndFactoryTests.EOE022_Too_Many_Result_Types.verified.txt diff --git a/tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE023_Unknown_Error_Factory.verified.txt b/tests/ErrorOrX.Generators.Tests/Snapshots/UnionTypeAndFactoryTests.EOE023_Unknown_Error_Factory.verified.txt similarity index 100% rename from tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE023_Unknown_Error_Factory.verified.txt rename to tests/ErrorOrX.Generators.Tests/Snapshots/UnionTypeAndFactoryTests.EOE023_Unknown_Error_Factory.verified.txt diff --git a/tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE024_Interface_Call_With_ProducesError_No_Diagnostic.verified.txt b/tests/ErrorOrX.Generators.Tests/Snapshots/UnionTypeAndFactoryTests.EOE024_Interface_Call_With_ProducesError_No_Diagnostic.verified.txt similarity index 100% rename from tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE024_Interface_Call_With_ProducesError_No_Diagnostic.verified.txt rename to tests/ErrorOrX.Generators.Tests/Snapshots/UnionTypeAndFactoryTests.EOE024_Interface_Call_With_ProducesError_No_Diagnostic.verified.txt diff --git a/tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE024_Undocumented_Interface_Call.verified.txt b/tests/ErrorOrX.Generators.Tests/Snapshots/UnionTypeAndFactoryTests.EOE024_Undocumented_Interface_Call.verified.txt similarity index 100% rename from tests/ErrorOrX.Generators.Tests/Snapshots/DiagnosticTests.EOE024_Undocumented_Interface_Call.verified.txt rename to tests/ErrorOrX.Generators.Tests/Snapshots/UnionTypeAndFactoryTests.EOE024_Undocumented_Interface_Call.verified.txt diff --git a/tests/ErrorOrX.Generators.Tests/UnionTypeAndFactoryTests.cs b/tests/ErrorOrX.Generators.Tests/UnionTypeAndFactoryTests.cs new file mode 100644 index 0000000..0e2c111 --- /dev/null +++ b/tests/ErrorOrX.Generators.Tests/UnionTypeAndFactoryTests.cs @@ -0,0 +1,122 @@ +namespace ErrorOrX.Generators.Tests; + +/// +/// Tests for results-union and error-factory diagnostics +/// (EOE022-EOE024). Covers unions with too many result types, unknown +/// Error.* factories, and undocumented interface calls that need an +/// explicit [ProducesError] attribute. +/// +public class UnionTypeAndFactoryTests : GeneratorTestBase +{ + #region EOE023 - Unknown error factory + + [Fact] + public Task EOE023_Unknown_Error_Factory() + { + const string Source = """ + using ErrorOr; + + namespace DiagnosticTest; + + public static class TodoApi + { + [Get("/todos/{id}")] + public static ErrorOr GetById(int id) + { + if (id < 0) + return Error.Custom(999, "custom", "description"); + return "todo"; + } + } + """; + + return VerifyAsync(Source); + } + + #endregion + + #region EOE022 - Too many result types + + [Fact] + public Task EOE022_Too_Many_Result_Types() + { + const string Source = """ + using ErrorOr; + + namespace DiagnosticTest; + + public static class TodoApi + { + [Get("/todos/{id}")] + public static ErrorOr GetById(int id) + { + if (id == 0) return Error.NotFound("Todo.NotFound", "Not found"); + if (id == 1) return Error.Validation("Todo.Invalid", "Invalid"); + if (id == 2) return Error.Conflict("Todo.Conflict", "Conflict"); + if (id == 3) return Error.Unauthorized("Todo.Unauthorized", "Unauthorized"); + if (id == 4) return Error.Forbidden("Todo.Forbidden", "Forbidden"); + if (id == 5) return Error.Failure("Todo.Failure", "Failure"); + if (id == 6) return Error.Unexpected("Todo.Unexpected", "Unexpected"); + return $"todo {id}"; + } + } + """; + + return VerifyAsync(Source); + } + + #endregion + + #region EOE024 - Undocumented interface call + + [Fact] + public Task EOE024_Undocumented_Interface_Call() + { + const string Source = """ + using ErrorOr; + + namespace DiagnosticTest; + + public interface ITodoService + { + ErrorOr GetById(int id); + } + + public static class TodoApi + { + [Get("/todos/{id}")] + public static ErrorOr GetById(int id, ITodoService svc) + => svc.GetById(id); + } + """; + + return VerifyAsync(Source); + } + + [Fact] + public Task EOE024_Interface_Call_With_ProducesError_No_Diagnostic() + { + const string Source = """ + using ErrorOr; + + namespace DiagnosticTest; + + public interface ITodoService + { + ErrorOr GetById(int id); + } + + public static class TodoApi + { + [Get("/todos/{id}")] + [ProducesError(404, "NotFound")] + public static ErrorOr GetById(int id, ITodoService svc) + => svc.GetById(id); + } + """; + + return VerifyAsync(Source); + } + + #endregion +} diff --git a/tests/ErrorOrX.Tests/ErrorOr/ErrorOrTests.cs b/tests/ErrorOrX.Tests/ErrorOr/ErrorOr.AccessTests.cs similarity index 51% rename from tests/ErrorOrX.Tests/ErrorOr/ErrorOrTests.cs rename to tests/ErrorOrX.Tests/ErrorOr/ErrorOr.AccessTests.cs index eec42ad..b56512e 100644 --- a/tests/ErrorOrX.Tests/ErrorOr/ErrorOrTests.cs +++ b/tests/ErrorOrX.Tests/ErrorOr/ErrorOr.AccessTests.cs @@ -1,12 +1,11 @@ namespace ErrorOrX.Tests.ErrorOr; /// -/// Unit tests for the ErrorOr<TValue> struct. +/// Tests for ErrorOr<TValue> state, value/error access, implicit conversions, +/// and Error factory ErrorType assignment. /// -public class ErrorOrTests +public class ErrorOrAccessTests { - #region Error with Metadata Tests - [Fact] public void Error_WithMetadata_ShouldContainMetadata() { @@ -22,10 +21,6 @@ public void Error_WithMetadata_ShouldContainMetadata() error.Metadata["Field"].Should().Be("Email"); } - #endregion - - #region IsError and IsSuccess Tests - [Fact] public void IsError_WhenCreatedWithValue_ShouldBeFalse() { @@ -66,10 +61,6 @@ public void IsSuccess_WhenCreatedWithError_ShouldBeFalse() result.IsSuccess.Should().BeFalse(); } - #endregion - - #region Value Access Tests - [Fact] public void Value_WhenCreatedWithValue_ShouldReturnValue() { @@ -98,10 +89,6 @@ public void Value_WhenCreatedWithError_ShouldThrowInvalidOperationException() .WithMessage("*cannot be accessed when errors have been recorded*"); } - #endregion - - #region Errors Access Tests - [Fact] public void Errors_WhenCreatedWithError_ShouldReturnErrors() { @@ -179,10 +166,6 @@ public void FirstError_WhenCreatedWithValue_ShouldThrowInvalidOperationException .WithMessage("*cannot be accessed when no errors have been recorded*"); } - #endregion - - #region Implicit Conversion Tests - [Fact] public void ImplicitConversion_FromValue_ShouldCreateSuccessResult() { @@ -226,10 +209,6 @@ public void ImplicitConversion_FromErrorList_ShouldCreateErrorResult() result.Errors.Should().HaveCount(2); } - #endregion - - #region Error Type Tests - [Fact] public void Error_Failure_ShouldHaveCorrectType() { @@ -301,265 +280,4 @@ public void Error_Unexpected_ShouldHaveCorrectType() // Assert error.Type.Should().Be(ErrorType.Unexpected); } - - #endregion - - #region Match Tests - - [Fact] - public void Match_WhenSuccess_ShouldInvokeOnValue() - { - // Arrange - ErrorOr result = 42; - var onValueInvoked = false; - var onErrorInvoked = false; - - // Act - var output = result.Match( - value => - { - onValueInvoked = true; - return value * 2; - }, - _ => - { - onErrorInvoked = true; - return -1; - }); - - // Assert - onValueInvoked.Should().BeTrue(); - onErrorInvoked.Should().BeFalse(); - output.Should().Be(84); - } - - [Fact] - public void Match_WhenError_ShouldInvokeOnError() - { - // Arrange - ErrorOr result = Error.Failure("Test.Error", "A test error"); - var onValueInvoked = false; - var onErrorInvoked = false; - - // Act - var output = result.Match( - value => - { - onValueInvoked = true; - return value * 2; - }, - _ => - { - onErrorInvoked = true; - return -1; - }); - - // Assert - onValueInvoked.Should().BeFalse(); - onErrorInvoked.Should().BeTrue(); - output.Should().Be(-1); - } - - #endregion - - #region Switch Tests - - [Fact] - public void Switch_WhenSuccess_ShouldInvokeOnValue() - { - // Arrange - ErrorOr result = 42; - var onValueInvoked = false; - var onErrorInvoked = false; - - // Act - result.Switch( - _ => onValueInvoked = true, - _ => onErrorInvoked = true); - - // Assert - onValueInvoked.Should().BeTrue(); - onErrorInvoked.Should().BeFalse(); - } - - [Fact] - public void Switch_WhenError_ShouldInvokeOnError() - { - // Arrange - ErrorOr result = Error.Failure("Test.Error", "A test error"); - var onValueInvoked = false; - var onErrorInvoked = false; - - // Act - result.Switch( - _ => onValueInvoked = true, - _ => onErrorInvoked = true); - - // Assert - onValueInvoked.Should().BeFalse(); - onErrorInvoked.Should().BeTrue(); - } - - #endregion - - #region Then/ThenDo Tests - - [Fact] - public void Then_WhenSuccess_ShouldTransformValue() - { - // Arrange - ErrorOr result = 42; - - // Act - var transformed = result.Then(static value => value.ToString()); - - // Assert - transformed.IsError.Should().BeFalse(); - transformed.Value.Should().Be("42"); - } - - [Fact] - public void Then_WhenError_ShouldPropagateError() - { - // Arrange - var error = Error.Failure("Test.Error", "A test error"); - ErrorOr result = error; - - // Act - var transformed = result.Then(static value => value.ToString()); - - // Assert - transformed.IsError.Should().BeTrue(); - transformed.FirstError.Should().Be(error); - } - - [Fact] - public void ThenDo_WhenSuccess_ShouldExecuteAction() - { - // Arrange - ErrorOr result = 42; - var actionExecuted = false; - - // Act - var output = result.ThenDo(_ => actionExecuted = true); - - // Assert - actionExecuted.Should().BeTrue(); - output.IsError.Should().BeFalse(); - output.Value.Should().Be(42); - } - - [Fact] - public void ThenDo_WhenError_ShouldNotExecuteAction() - { - // Arrange - var error = Error.Failure("Test.Error", "A test error"); - ErrorOr result = error; - var actionExecuted = false; - - // Act - var output = result.ThenDo(_ => actionExecuted = true); - - // Assert - actionExecuted.Should().BeFalse(); - output.IsError.Should().BeTrue(); - } - - #endregion - - #region Else Tests - - [Fact] - public void Else_WhenSuccess_ShouldReturnOriginalValue() - { - // Arrange - ErrorOr result = 42; - - // Act - Else returns ErrorOr, so extract .Value - var output = result.Else(static _ => -1); - - // Assert - output.IsError.Should().BeFalse(); - output.Value.Should().Be(42); - } - - [Fact] - public void Else_WhenError_ShouldReturnFallbackValue() - { - // Arrange - ErrorOr result = Error.Failure("Test.Error", "A test error"); - - // Act - Else returns ErrorOr with the fallback value - var output = result.Else(static _ => -1); - - // Assert - output.IsError.Should().BeFalse(); - output.Value.Should().Be(-1); - } - - [Fact] - public void ElseWithValue_WhenError_ShouldReturnFallbackValue() - { - // Arrange - ErrorOr result = Error.Failure("Test.Error", "A test error"); - - // Act - Else returns ErrorOr with the fallback value - var output = result.Else(-1); - - // Assert - output.IsError.Should().BeFalse(); - output.Value.Should().Be(-1); - } - - #endregion - - #region FailIf Tests - - [Fact] - public void FailIf_WhenPredicateIsFalse_ShouldReturnOriginalValue() - { - // Arrange - ErrorOr result = 42; - var error = Error.Validation("Test.Validation", "Value is invalid"); - - // Act - var output = result.FailIf(static value => value < 0, in error); - - // Assert - output.IsError.Should().BeFalse(); - output.Value.Should().Be(42); - } - - [Fact] - public void FailIf_WhenPredicateIsTrue_ShouldReturnError() - { - // Arrange - ErrorOr result = -5; - var error = Error.Validation("Test.Validation", "Value must be positive"); - - // Act - var output = result.FailIf(static value => value < 0, in error); - - // Assert - output.IsError.Should().BeTrue(); - output.FirstError.Should().Be(error); - } - - [Fact] - public void FailIf_WhenAlreadyError_ShouldReturnOriginalError() - { - // Arrange - var originalError = Error.Failure("Original.Error", "Original error"); - ErrorOr result = originalError; - var newError = Error.Validation("Test.Validation", "Value is invalid"); - - // Act - var output = result.FailIf(static _ => true, in newError); - - // Assert - output.IsError.Should().BeTrue(); - output.FirstError.Should().Be(originalError); - } - - #endregion } diff --git a/tests/ErrorOrX.Tests/ErrorOr/ErrorOr.ChainingTests.cs b/tests/ErrorOrX.Tests/ErrorOr/ErrorOr.ChainingTests.cs new file mode 100644 index 0000000..ebddb77 --- /dev/null +++ b/tests/ErrorOrX.Tests/ErrorOr/ErrorOr.ChainingTests.cs @@ -0,0 +1,159 @@ +namespace ErrorOrX.Tests.ErrorOr; + +/// +/// Tests for the railway-oriented chaining operators on ErrorOr<TValue>: +/// Then / ThenDo (continue on success), Else (fallback on error), +/// and FailIf (conditional failure). The deeper async variants live in +/// ErrorOr.ElseAsyncTests.cs; the basic synchronous core lives here. +/// +public class ErrorOrChainingTests +{ + [Fact] + public void Then_WhenSuccess_ShouldTransformValue() + { + // Arrange + ErrorOr result = 42; + + // Act + var transformed = result.Then(static value => value.ToString()); + + // Assert + transformed.IsError.Should().BeFalse(); + transformed.Value.Should().Be("42"); + } + + [Fact] + public void Then_WhenError_ShouldPropagateError() + { + // Arrange + var error = Error.Failure("Test.Error", "A test error"); + ErrorOr result = error; + + // Act + var transformed = result.Then(static value => value.ToString()); + + // Assert + transformed.IsError.Should().BeTrue(); + transformed.FirstError.Should().Be(error); + } + + [Fact] + public void ThenDo_WhenSuccess_ShouldExecuteAction() + { + // Arrange + ErrorOr result = 42; + var actionExecuted = false; + + // Act + var output = result.ThenDo(_ => actionExecuted = true); + + // Assert + actionExecuted.Should().BeTrue(); + output.IsError.Should().BeFalse(); + output.Value.Should().Be(42); + } + + [Fact] + public void ThenDo_WhenError_ShouldNotExecuteAction() + { + // Arrange + var error = Error.Failure("Test.Error", "A test error"); + ErrorOr result = error; + var actionExecuted = false; + + // Act + var output = result.ThenDo(_ => actionExecuted = true); + + // Assert + actionExecuted.Should().BeFalse(); + output.IsError.Should().BeTrue(); + } + + [Fact] + public void Else_WhenSuccess_ShouldReturnOriginalValue() + { + // Arrange + ErrorOr result = 42; + + // Act - Else returns ErrorOr, so extract .Value + var output = result.Else(static _ => -1); + + // Assert + output.IsError.Should().BeFalse(); + output.Value.Should().Be(42); + } + + [Fact] + public void Else_WhenError_ShouldReturnFallbackValue() + { + // Arrange + ErrorOr result = Error.Failure("Test.Error", "A test error"); + + // Act - Else returns ErrorOr with the fallback value + var output = result.Else(static _ => -1); + + // Assert + output.IsError.Should().BeFalse(); + output.Value.Should().Be(-1); + } + + [Fact] + public void ElseWithValue_WhenError_ShouldReturnFallbackValue() + { + // Arrange + ErrorOr result = Error.Failure("Test.Error", "A test error"); + + // Act - Else returns ErrorOr with the fallback value + var output = result.Else(-1); + + // Assert + output.IsError.Should().BeFalse(); + output.Value.Should().Be(-1); + } + + [Fact] + public void FailIf_WhenPredicateIsFalse_ShouldReturnOriginalValue() + { + // Arrange + ErrorOr result = 42; + var error = Error.Validation("Test.Validation", "Value is invalid"); + + // Act + var output = result.FailIf(static value => value < 0, in error); + + // Assert + output.IsError.Should().BeFalse(); + output.Value.Should().Be(42); + } + + [Fact] + public void FailIf_WhenPredicateIsTrue_ShouldReturnError() + { + // Arrange + ErrorOr result = -5; + var error = Error.Validation("Test.Validation", "Value must be positive"); + + // Act + var output = result.FailIf(static value => value < 0, in error); + + // Assert + output.IsError.Should().BeTrue(); + output.FirstError.Should().Be(error); + } + + [Fact] + public void FailIf_WhenAlreadyError_ShouldReturnOriginalError() + { + // Arrange + var originalError = Error.Failure("Original.Error", "Original error"); + ErrorOr result = originalError; + var newError = Error.Validation("Test.Validation", "Value is invalid"); + + // Act + var output = result.FailIf(static _ => true, in newError); + + // Assert + output.IsError.Should().BeTrue(); + output.FirstError.Should().Be(originalError); + } +} diff --git a/tests/ErrorOrX.Tests/ErrorOr/ErrorOr.MatchAndSwitchTests.cs b/tests/ErrorOrX.Tests/ErrorOr/ErrorOr.MatchAndSwitchTests.cs new file mode 100644 index 0000000..81077d8 --- /dev/null +++ b/tests/ErrorOrX.Tests/ErrorOr/ErrorOr.MatchAndSwitchTests.cs @@ -0,0 +1,98 @@ +namespace ErrorOrX.Tests.ErrorOr; + +/// +/// Tests for the Match (transformation) and Switch (side-effect) handlers +/// on ErrorOr<TValue>. +/// +public class ErrorOrMatchAndSwitchTests +{ + [Fact] + public void Match_WhenSuccess_ShouldInvokeOnValue() + { + // Arrange + ErrorOr result = 42; + var onValueInvoked = false; + var onErrorInvoked = false; + + // Act + var output = result.Match( + value => + { + onValueInvoked = true; + return value * 2; + }, + _ => + { + onErrorInvoked = true; + return -1; + }); + + // Assert + onValueInvoked.Should().BeTrue(); + onErrorInvoked.Should().BeFalse(); + output.Should().Be(84); + } + + [Fact] + public void Match_WhenError_ShouldInvokeOnError() + { + // Arrange + ErrorOr result = Error.Failure("Test.Error", "A test error"); + var onValueInvoked = false; + var onErrorInvoked = false; + + // Act + var output = result.Match( + value => + { + onValueInvoked = true; + return value * 2; + }, + _ => + { + onErrorInvoked = true; + return -1; + }); + + // Assert + onValueInvoked.Should().BeFalse(); + onErrorInvoked.Should().BeTrue(); + output.Should().Be(-1); + } + + [Fact] + public void Switch_WhenSuccess_ShouldInvokeOnValue() + { + // Arrange + ErrorOr result = 42; + var onValueInvoked = false; + var onErrorInvoked = false; + + // Act + result.Switch( + _ => onValueInvoked = true, + _ => onErrorInvoked = true); + + // Assert + onValueInvoked.Should().BeTrue(); + onErrorInvoked.Should().BeFalse(); + } + + [Fact] + public void Switch_WhenError_ShouldInvokeOnError() + { + // Arrange + ErrorOr result = Error.Failure("Test.Error", "A test error"); + var onValueInvoked = false; + var onErrorInvoked = false; + + // Act + result.Switch( + _ => onValueInvoked = true, + _ => onErrorInvoked = true); + + // Assert + onValueInvoked.Should().BeFalse(); + onErrorInvoked.Should().BeTrue(); + } +}