diff --git a/cmd/proxygenerator/main.go b/cmd/proxygenerator/main.go index 0a7a593b..162bb92a 100644 --- a/cmd/proxygenerator/main.go +++ b/cmd/proxygenerator/main.go @@ -31,7 +31,12 @@ func main() { log.Print(interceptorErr) } - if serviceErr != nil || interceptorErr != nil { + requestHeaderErr := generateRequestHeader(cfg) + if requestHeaderErr != nil { + log.Print(requestHeaderErr) + } + + if serviceErr != nil || interceptorErr != nil || requestHeaderErr != nil { os.Exit(1) } } diff --git a/cmd/proxygenerator/request_header.go b/cmd/proxygenerator/request_header.go new file mode 100644 index 00000000..95ef28a7 --- /dev/null +++ b/cmd/proxygenerator/request_header.go @@ -0,0 +1,318 @@ +package main + +import ( + "bytes" + "fmt" + "go/format" + "os" + "regexp" + "sort" + "strings" + "text/template" + + "golang.org/x/tools/imports" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protodesc" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/descriptorpb" + + protometa "go.temporal.io/api/protometa/v1" +) + +const requestHeaderFile = "../../proxy/request_header.go" + +const requestHeaderTemplateText = ` +// Code generated by proxygenerator; DO NOT EDIT. + +package proxy + +import ( + "context" + "errors" + "fmt" + + "google.golang.org/grpc/metadata" + "google.golang.org/protobuf/proto" + {{range $path, $alias := .Imports}} + {{$alias}} "{{$path}}" + {{end}} +) + +// ExtractHeadersOptions contains options for extracting Temporal request headers. +type ExtractHeadersOptions struct { + // Request is the proto message to extract headers from. Required. + Request proto.Message + + // ExistingMetadata contains existing metadata to check for duplicates. + // If provided, headers that already exist will not be added again. + // If nil, no duplicate checking is performed. + ExistingMetadata metadata.MD +} + +// ExtractTemporalRequestHeaders extracts field values from request messages and returns +// them as a slice of key-value pairs suitable for use with metadata.AppendToOutgoingContext. +// Returns nil if no headers should be set. +func ExtractTemporalRequestHeaders(ctx context.Context, opts ExtractHeadersOptions) ([]string, error) { + if opts.Request == nil { + return nil, errors.New("request cannot be nil") + } + + var headers []string + + // Set namespace header if present and not already exists + if len(opts.ExistingMetadata.Get("temporal-namespace")) == 0 { + if nsReq, ok := opts.Request.(interface{ GetNamespace() string }); ok { + if ns := nsReq.GetNamespace(); ns != "" { + headers = append(headers, "temporal-namespace", ns) + } + } + } + + // Set headers from proto annotations{{if .Methods}} + switch r := opts.Request.(type) { +{{range .Methods}} case *{{.PackageAlias}}.{{.RequestType}}: +{{range .Headers}}{{.Code}}{{end}} +{{end}} }{{end}} + + return headers, nil +} +` + +var requestHeaderTemplate = template.Must(template.New("request_header").Parse(requestHeaderTemplateText)) + +type requestHeaderTemplateInput struct { + Methods []methodHeaderInfo + Imports map[string]string // map[importPath]alias +} + +type methodHeaderInfo struct { + PackageAlias string + RequestType string + Headers []headerInfo +} + +type headerInfo struct { + Code string +} + +func generateRequestHeader(cfg config) error { + data, err := os.ReadFile(cfg.descriptorPath) + if err != nil { + return fmt.Errorf("reading descriptor set: %w", err) + } + + var fdSet descriptorpb.FileDescriptorSet + if err := proto.Unmarshal(data, &fdSet); err != nil { + return fmt.Errorf("unmarshalling descriptor set: %w", err) + } + + files, err := protodesc.NewFiles(&fdSet) + if err != nil { + return fmt.Errorf("creating file registry: %w", err) + } + + var allMethods []methodHeaderInfo + importsMap := make(map[string]string) // map[importPath]alias + + files.RangeFiles(func(fd protoreflect.FileDescriptor) bool { + services := fd.Services() + for i := 0; i < services.Len(); i++ { + service := services.Get(i) + methods, importPath, alias, err := extractMethodHeaders(service) + if err != nil { + return false + } + + if len(methods) > 0 && importPath != "" { + importsMap[importPath] = alias + } + + allMethods = append(allMethods, methods...) + } + return true + }) + + if err != nil { + return err + } + + // Sort methods alphabetically by RequestType for consistent output + sort.Slice(allMethods, func(i, j int) bool { + return allMethods[i].RequestType < allMethods[j].RequestType + }) + + buf := &bytes.Buffer{} + err = requestHeaderTemplate.Execute(buf, requestHeaderTemplateInput{ + Methods: allMethods, + Imports: importsMap, + }) + if err != nil { + return fmt.Errorf("executing template: %w", err) + } + + src, err := imports.Process(requestHeaderFile, buf.Bytes(), nil) + if err != nil { + return fmt.Errorf("failed to process imports: %w", err) + } + + src, err = format.Source(src) + if err != nil { + return fmt.Errorf("failed to format: %w", err) + } + + if cfg.verifyOnly { + currentSrc, err := os.ReadFile(requestHeaderFile) + if err != nil { + return fmt.Errorf("failed to read existing file: %w", err) + } + + if !bytes.Equal(src, currentSrc) { + return fmt.Errorf("generated file does not match existing file: %s", requestHeaderFile) + } + + return nil + } + + if err := os.WriteFile(requestHeaderFile, src, 0666); err != nil { + return fmt.Errorf("failed to write file: %w", err) + } + + return nil +} + +func extractMethodHeaders(service protoreflect.ServiceDescriptor) ([]methodHeaderInfo, string, string, error) { + var methods []methodHeaderInfo + + // Get the Go package info from the file descriptor + fileDesc := service.ParentFile() + goPackageOption := fileDesc.Options().(*descriptorpb.FileOptions).GetGoPackage() + + // Parse go_package option: "go.temporal.io/api/workflowservice/v1;workflowservice" + // Format is "import/path;packagename" or just "import/path" + parts := strings.Split(goPackageOption, ";") + importPath := parts[0] + var packageAlias string + if len(parts) > 1 { + packageAlias = parts[1] + } else { + // Use last part of import path as alias + pathParts := strings.Split(importPath, "/") + packageAlias = pathParts[len(pathParts)-1] + } + + for i := 0; i < service.Methods().Len(); i++ { + method := service.Methods().Get(i) + opts := method.Options() + if opts == nil { + continue + } + + if !proto.HasExtension(opts, protometa.E_RequestHeader) { + continue + } + + ext := proto.GetExtension(opts, protometa.E_RequestHeader) + annotations, ok := ext.([]*protometa.RequestHeaderAnnotation) + if !ok || len(annotations) == 0 { + continue + } + + requestTypeName := string(method.Input().Name()) + requestMsgDesc := method.Input() + var headerInfos []headerInfo + + for _, annotation := range annotations { + code, err := generateHeaderCode(annotation.GetHeader(), annotation.GetValue(), "r", requestMsgDesc) + if err != nil { + return nil, "", "", fmt.Errorf("failed to generate header code for %s.%s: %w", service.Name(), method.Name(), err) + } + headerInfos = append(headerInfos, headerInfo{Code: code}) + } + methods = append(methods, methodHeaderInfo{ + PackageAlias: packageAlias, + RequestType: requestTypeName, + Headers: headerInfos, + }) + } + + return methods, importPath, packageAlias, nil +} + +func generateHeaderCode(headerName, valueTemplate, reqVar string, msgDesc protoreflect.MessageDescriptor) (string, error) { + fieldPaths := parseValueTemplate(valueTemplate) + + if len(fieldPaths) == 0 { + return fmt.Sprintf("\t\tif %q != \"\" && len(opts.ExistingMetadata.Get(%q)) == 0 {\n\t\t\theaders = append(headers, %q, %q)\n\t\t}", valueTemplate, headerName, headerName, valueTemplate), nil + } + + if len(fieldPaths) > 1 { + return "", fmt.Errorf("only one field interpolation is supported, found %d", len(fieldPaths)) + } + + fieldPath := fieldPaths[0] + accessor, err := generateFieldAccessor(fieldPath, reqVar, msgDesc) + if err != nil { + return "", fmt.Errorf("failed to generate accessor for %s: %w", fieldPath, err) + } + + finalValue := strings.Replace(valueTemplate, "{"+fieldPath+"}", "%s", 1) + + // Generate code that checks if the field value is non-empty and header doesn't exist before formatting and appending + return fmt.Sprintf("\t\tif val := %s; val != \"\" && len(opts.ExistingMetadata.Get(%q)) == 0 {\n\t\t\theaders = append(headers, %q, fmt.Sprintf(%q, val))\n\t\t}", + accessor, headerName, headerName, finalValue), nil +} + +func generateFieldAccessor(fieldPath, varName string, msgDesc protoreflect.MessageDescriptor) (string, error) { + parts := strings.Split(fieldPath, ".") + accessor := varName + currentMsg := msgDesc + + for _, part := range parts { + field := currentMsg.Fields().ByName(protoreflect.Name(part)) + if field == nil { + return "", fmt.Errorf("field %s not found in message %s", part, currentMsg.FullName()) + } + + goName := protoFieldToGoName(part) + accessor = fmt.Sprintf("%s.Get%s()", accessor, goName) + + if field.Kind() == protoreflect.MessageKind && field.Message() != nil { + currentMsg = field.Message() + } + } + + return accessor, nil +} + +func parseValueTemplate(valueTemplate string) []string { + re := regexp.MustCompile(`\{([^}]+)\}`) + matches := re.FindAllStringSubmatch(valueTemplate, -1) + + fieldPaths := make([]string, 0, len(matches)) + for _, match := range matches { + if len(match) > 1 { + fieldPaths = append(fieldPaths, match[1]) + } + } + + return fieldPaths +} + +func protoFieldToGoName(protoName string) string { + parts := strings.Split(protoName, "_") + for i := range parts { + if len(parts[i]) > 0 { + parts[i] = strings.ToUpper(parts[i][:1]) + parts[i][1:] + } + } + return strings.Join(parts, "") +} + +func getPackageName(service protoreflect.ServiceDescriptor) string { + fullName := string(service.FullName()) + parts := strings.Split(fullName, ".") + if len(parts) < 2 { + return "unknown" + } + return parts[len(parts)-2] +} diff --git a/protometa/v1/annotations.go-helpers.pb.go b/protometa/v1/annotations.go-helpers.pb.go new file mode 100644 index 00000000..2647a02c --- /dev/null +++ b/protometa/v1/annotations.go-helpers.pb.go @@ -0,0 +1,43 @@ +// Code generated by protoc-gen-go-helpers. DO NOT EDIT. +package protometa + +import ( + "google.golang.org/protobuf/proto" +) + +// Marshal an object of type RequestHeaderAnnotation to the protobuf v3 wire format +func (val *RequestHeaderAnnotation) Marshal() ([]byte, error) { + return proto.Marshal(val) +} + +// Unmarshal an object of type RequestHeaderAnnotation from the protobuf v3 wire format +func (val *RequestHeaderAnnotation) Unmarshal(buf []byte) error { + return proto.Unmarshal(buf, val) +} + +// Size returns the size of the object, in bytes, once serialized +func (val *RequestHeaderAnnotation) Size() int { + return proto.Size(val) +} + +// Equal returns whether two RequestHeaderAnnotation values are equivalent by recursively +// comparing the message's fields. +// For more information see the documentation for +// https://pkg.go.dev/google.golang.org/protobuf/proto#Equal +func (this *RequestHeaderAnnotation) Equal(that interface{}) bool { + if that == nil { + return this == nil + } + + var that1 *RequestHeaderAnnotation + switch t := that.(type) { + case *RequestHeaderAnnotation: + that1 = t + case RequestHeaderAnnotation: + that1 = &t + default: + return false + } + + return proto.Equal(this, that1) +} diff --git a/protometa/v1/annotations.pb.go b/protometa/v1/annotations.pb.go new file mode 100644 index 00000000..4537a7de --- /dev/null +++ b/protometa/v1/annotations.pb.go @@ -0,0 +1,170 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// plugins: +// protoc-gen-go +// protoc +// source: temporal/api/protometa/v1/annotations.proto + +package protometa + +import ( + reflect "reflect" + sync "sync" + unsafe "unsafe" + + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + descriptorpb "google.golang.org/protobuf/types/descriptorpb" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// RequestHeaderAnnotation allows specifying that field values from a request +// should be propagated as outbound headers. +// +// The value field supports template interpolation where field paths enclosed +// in braces will be replaced with the actual field values from the request. +// For example: +// +// value: "{workflow_execution.workflow_id}" +// value: "workflow-{workflow_execution.workflow_id}" +// value: "{namespace}/{workflow_execution.workflow_id}" +type RequestHeaderAnnotation struct { + state protoimpl.MessageState `protogen:"open.v1"` + // The name of the header to set (e.g., "temporal-resource-id") + Header string `protobuf:"bytes,1,opt,name=header,proto3" json:"header,omitempty"` + // A template string that may contain field paths in braces. + // Field paths use dot notation to traverse nested messages. + // Example: "{workflow_execution.workflow_id}" + Value string `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *RequestHeaderAnnotation) Reset() { + *x = RequestHeaderAnnotation{} + mi := &file_temporal_api_protometa_v1_annotations_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *RequestHeaderAnnotation) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RequestHeaderAnnotation) ProtoMessage() {} + +func (x *RequestHeaderAnnotation) ProtoReflect() protoreflect.Message { + mi := &file_temporal_api_protometa_v1_annotations_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RequestHeaderAnnotation.ProtoReflect.Descriptor instead. +func (*RequestHeaderAnnotation) Descriptor() ([]byte, []int) { + return file_temporal_api_protometa_v1_annotations_proto_rawDescGZIP(), []int{0} +} + +func (x *RequestHeaderAnnotation) GetHeader() string { + if x != nil { + return x.Header + } + return "" +} + +func (x *RequestHeaderAnnotation) GetValue() string { + if x != nil { + return x.Value + } + return "" +} + +var file_temporal_api_protometa_v1_annotations_proto_extTypes = []protoimpl.ExtensionInfo{ + { + ExtendedType: (*descriptorpb.MethodOptions)(nil), + ExtensionType: ([]*RequestHeaderAnnotation)(nil), + Field: 7234001, + Name: "temporal.api.protometa.v1.request_header", + Tag: "bytes,7234001,rep,name=request_header", + Filename: "temporal/api/protometa/v1/annotations.proto", + }, +} + +// Extension fields to descriptorpb.MethodOptions. +var ( + // repeated temporal.api.protometa.v1.RequestHeaderAnnotation request_header = 7234001; + E_RequestHeader = &file_temporal_api_protometa_v1_annotations_proto_extTypes[0] +) + +var File_temporal_api_protometa_v1_annotations_proto protoreflect.FileDescriptor + +const file_temporal_api_protometa_v1_annotations_proto_rawDesc = "" + + "\n" + + "+temporal/api/protometa/v1/annotations.proto\x12\x19temporal.api.protometa.v1\x1a google/protobuf/descriptor.proto\"G\n" + + "\x17RequestHeaderAnnotation\x12\x16\n" + + "\x06header\x18\x01 \x01(\tR\x06header\x12\x14\n" + + "\x05value\x18\x02 \x01(\tR\x05value:|\n" + + "\x0erequest_header\x12\x1e.google.protobuf.MethodOptions\x18\xd1รน\x03 \x03(\v22.temporal.api.protometa.v1.RequestHeaderAnnotationR\rrequestHeaderB\x9c\x01\n" + + "\x1cio.temporal.api.protometa.v1B\x10AnnotationsProtoP\x01Z)go.temporal.io/api/protometa/v1;protometa\xaa\x02\x1bTemporalio.Api.Protometa.V1\xea\x02\x1eTemporalio::Api::Protometa::V1b\x06proto3" + +var ( + file_temporal_api_protometa_v1_annotations_proto_rawDescOnce sync.Once + file_temporal_api_protometa_v1_annotations_proto_rawDescData []byte +) + +func file_temporal_api_protometa_v1_annotations_proto_rawDescGZIP() []byte { + file_temporal_api_protometa_v1_annotations_proto_rawDescOnce.Do(func() { + file_temporal_api_protometa_v1_annotations_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_temporal_api_protometa_v1_annotations_proto_rawDesc), len(file_temporal_api_protometa_v1_annotations_proto_rawDesc))) + }) + return file_temporal_api_protometa_v1_annotations_proto_rawDescData +} + +var file_temporal_api_protometa_v1_annotations_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_temporal_api_protometa_v1_annotations_proto_goTypes = []any{ + (*RequestHeaderAnnotation)(nil), // 0: temporal.api.protometa.v1.RequestHeaderAnnotation + (*descriptorpb.MethodOptions)(nil), // 1: google.protobuf.MethodOptions +} +var file_temporal_api_protometa_v1_annotations_proto_depIdxs = []int32{ + 1, // 0: temporal.api.protometa.v1.request_header:extendee -> google.protobuf.MethodOptions + 0, // 1: temporal.api.protometa.v1.request_header:type_name -> temporal.api.protometa.v1.RequestHeaderAnnotation + 2, // [2:2] is the sub-list for method output_type + 2, // [2:2] is the sub-list for method input_type + 1, // [1:2] is the sub-list for extension type_name + 0, // [0:1] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_temporal_api_protometa_v1_annotations_proto_init() } +func file_temporal_api_protometa_v1_annotations_proto_init() { + if File_temporal_api_protometa_v1_annotations_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_temporal_api_protometa_v1_annotations_proto_rawDesc), len(file_temporal_api_protometa_v1_annotations_proto_rawDesc)), + NumEnums: 0, + NumMessages: 1, + NumExtensions: 1, + NumServices: 0, + }, + GoTypes: file_temporal_api_protometa_v1_annotations_proto_goTypes, + DependencyIndexes: file_temporal_api_protometa_v1_annotations_proto_depIdxs, + MessageInfos: file_temporal_api_protometa_v1_annotations_proto_msgTypes, + ExtensionInfos: file_temporal_api_protometa_v1_annotations_proto_extTypes, + }.Build() + File_temporal_api_protometa_v1_annotations_proto = out.File + file_temporal_api_protometa_v1_annotations_proto_goTypes = nil + file_temporal_api_protometa_v1_annotations_proto_depIdxs = nil +} diff --git a/proxy/request_header.go b/proxy/request_header.go new file mode 100644 index 00000000..5eefa29f --- /dev/null +++ b/proxy/request_header.go @@ -0,0 +1,275 @@ +// Code generated by proxygenerator; DO NOT EDIT. + +package proxy + +import ( + "context" + "errors" + "fmt" + + "google.golang.org/grpc/metadata" + "google.golang.org/protobuf/proto" + + workflowservice "go.temporal.io/api/workflowservice/v1" +) + +// ExtractHeadersOptions contains options for extracting Temporal request headers. +type ExtractHeadersOptions struct { + // Request is the proto message to extract headers from. Required. + Request proto.Message + + // ExistingMetadata contains existing metadata to check for duplicates. + // If provided, headers that already exist will not be added again. + // If nil, no duplicate checking is performed. + ExistingMetadata metadata.MD +} + +// ExtractTemporalRequestHeaders extracts field values from request messages and returns +// them as a slice of key-value pairs suitable for use with metadata.AppendToOutgoingContext. +// Returns nil if no headers should be set. +func ExtractTemporalRequestHeaders(ctx context.Context, opts ExtractHeadersOptions) ([]string, error) { + if opts.Request == nil { + return nil, errors.New("request cannot be nil") + } + + var headers []string + + // Set namespace header if present and not already exists + if len(opts.ExistingMetadata.Get("temporal-namespace")) == 0 { + if nsReq, ok := opts.Request.(interface{ GetNamespace() string }); ok { + if ns := nsReq.GetNamespace(); ns != "" { + headers = append(headers, "temporal-namespace", ns) + } + } + } + + // Set headers from proto annotations + switch r := opts.Request.(type) { + case *workflowservice.CreateScheduleRequest: + if val := r.GetScheduleId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("schedule:%s", val)) + } + case *workflowservice.DeleteScheduleRequest: + if val := r.GetScheduleId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("schedule:%s", val)) + } + case *workflowservice.DeleteWorkerDeploymentRequest: + if val := r.GetDeploymentName(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("deployment:%s", val)) + } + case *workflowservice.DeleteWorkerDeploymentVersionRequest: + if val := r.GetDeploymentVersion().GetDeploymentName(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("deployment:%s", val)) + } + case *workflowservice.DeleteWorkflowExecutionRequest: + if val := r.GetWorkflowExecution().GetWorkflowId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("workflow:%s", val)) + } + case *workflowservice.DescribeBatchOperationRequest: + if val := r.GetJobId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("batch:%s", val)) + } + case *workflowservice.DescribeScheduleRequest: + if val := r.GetScheduleId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("schedule:%s", val)) + } + case *workflowservice.DescribeTaskQueueRequest: + if val := r.GetTaskQueue().GetName(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("taskqueue:%s", val)) + } + case *workflowservice.DescribeWorkerDeploymentRequest: + if val := r.GetDeploymentName(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("deployment:%s", val)) + } + case *workflowservice.DescribeWorkerDeploymentVersionRequest: + if val := r.GetDeploymentVersion().GetDeploymentName(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("deployment:%s", val)) + } + case *workflowservice.DescribeWorkerRequest: + if val := r.GetWorkerInstanceKey(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("worker:%s", val)) + } + case *workflowservice.DescribeWorkflowExecutionRequest: + if val := r.GetExecution().GetWorkflowId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("workflow:%s", val)) + } + case *workflowservice.ExecuteMultiOperationRequest: + if val := r.GetResourceId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("workflow:%s", val)) + } + case *workflowservice.FetchWorkerConfigRequest: + if val := r.GetResourceId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("worker:%s", val)) + } + case *workflowservice.GetWorkflowExecutionHistoryRequest: + if val := r.GetExecution().GetWorkflowId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("workflow:%s", val)) + } + case *workflowservice.GetWorkflowExecutionHistoryReverseRequest: + if val := r.GetExecution().GetWorkflowId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("workflow:%s", val)) + } + case *workflowservice.ListScheduleMatchingTimesRequest: + if val := r.GetScheduleId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("schedule:%s", val)) + } + case *workflowservice.ListTaskQueuePartitionsRequest: + if val := r.GetTaskQueue().GetName(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("taskqueue:%s", val)) + } + case *workflowservice.PatchScheduleRequest: + if val := r.GetScheduleId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("schedule:%s", val)) + } + case *workflowservice.PauseActivityRequest: + if val := r.GetExecution().GetWorkflowId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("workflow:%s", val)) + } + case *workflowservice.PauseWorkflowExecutionRequest: + if val := r.GetWorkflowId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("workflow:%s", val)) + } + case *workflowservice.PollWorkflowExecutionUpdateRequest: + if val := r.GetUpdateRef().GetWorkflowExecution().GetWorkflowId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("workflow:%s", val)) + } + case *workflowservice.QueryWorkflowRequest: + if val := r.GetExecution().GetWorkflowId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("workflow:%s", val)) + } + case *workflowservice.RecordActivityTaskHeartbeatByIdRequest: + if val := r.GetResourceId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("%s", val)) + } + case *workflowservice.RecordActivityTaskHeartbeatRequest: + if val := r.GetResourceId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("%s", val)) + } + case *workflowservice.RecordWorkerHeartbeatRequest: + if val := r.GetResourceId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("worker:%s", val)) + } + case *workflowservice.RequestCancelWorkflowExecutionRequest: + if val := r.GetWorkflowExecution().GetWorkflowId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("workflow:%s", val)) + } + case *workflowservice.ResetActivityRequest: + if val := r.GetExecution().GetWorkflowId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("workflow:%s", val)) + } + case *workflowservice.ResetStickyTaskQueueRequest: + if val := r.GetExecution().GetWorkflowId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("workflow:%s", val)) + } + case *workflowservice.ResetWorkflowExecutionRequest: + if val := r.GetWorkflowExecution().GetWorkflowId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("workflow:%s", val)) + } + case *workflowservice.RespondActivityTaskCanceledByIdRequest: + if val := r.GetResourceId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("%s", val)) + } + case *workflowservice.RespondActivityTaskCanceledRequest: + if val := r.GetResourceId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("%s", val)) + } + case *workflowservice.RespondActivityTaskCompletedByIdRequest: + if val := r.GetResourceId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("%s", val)) + } + case *workflowservice.RespondActivityTaskCompletedRequest: + if val := r.GetResourceId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("%s", val)) + } + case *workflowservice.RespondActivityTaskFailedByIdRequest: + if val := r.GetResourceId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("%s", val)) + } + case *workflowservice.RespondActivityTaskFailedRequest: + if val := r.GetResourceId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("%s", val)) + } + case *workflowservice.RespondWorkflowTaskCompletedRequest: + if val := r.GetResourceId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("workflow:%s", val)) + } + case *workflowservice.RespondWorkflowTaskFailedRequest: + if val := r.GetResourceId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("workflow:%s", val)) + } + case *workflowservice.SetWorkerDeploymentCurrentVersionRequest: + if val := r.GetDeploymentName(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("deployment:%s", val)) + } + case *workflowservice.SetWorkerDeploymentManagerRequest: + if val := r.GetDeploymentName(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("deployment:%s", val)) + } + case *workflowservice.SetWorkerDeploymentRampingVersionRequest: + if val := r.GetDeploymentName(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("deployment:%s", val)) + } + case *workflowservice.SignalWithStartWorkflowExecutionRequest: + if val := r.GetWorkflowId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("workflow:%s", val)) + } + case *workflowservice.SignalWorkflowExecutionRequest: + if val := r.GetWorkflowExecution().GetWorkflowId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("workflow:%s", val)) + } + case *workflowservice.StartBatchOperationRequest: + if val := r.GetJobId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("batch:%s", val)) + } + case *workflowservice.StartWorkflowExecutionRequest: + if val := r.GetWorkflowId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("workflow:%s", val)) + } + case *workflowservice.StopBatchOperationRequest: + if val := r.GetJobId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("batch:%s", val)) + } + case *workflowservice.TerminateWorkflowExecutionRequest: + if val := r.GetWorkflowExecution().GetWorkflowId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("workflow:%s", val)) + } + case *workflowservice.UnpauseActivityRequest: + if val := r.GetExecution().GetWorkflowId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("workflow:%s", val)) + } + case *workflowservice.UnpauseWorkflowExecutionRequest: + if val := r.GetWorkflowId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("workflow:%s", val)) + } + case *workflowservice.UpdateActivityOptionsRequest: + if val := r.GetExecution().GetWorkflowId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("workflow:%s", val)) + } + case *workflowservice.UpdateScheduleRequest: + if val := r.GetScheduleId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("schedule:%s", val)) + } + case *workflowservice.UpdateTaskQueueConfigRequest: + if val := r.GetTaskQueue(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("taskqueue:%s", val)) + } + case *workflowservice.UpdateWorkerConfigRequest: + if val := r.GetResourceId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("worker:%s", val)) + } + case *workflowservice.UpdateWorkerDeploymentVersionMetadataRequest: + if val := r.GetDeploymentVersion().GetDeploymentName(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("deployment:%s", val)) + } + case *workflowservice.UpdateWorkflowExecutionOptionsRequest: + if val := r.GetWorkflowExecution().GetWorkflowId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("workflow:%s", val)) + } + case *workflowservice.UpdateWorkflowExecutionRequest: + if val := r.GetWorkflowExecution().GetWorkflowId(); val != "" && len(opts.ExistingMetadata.Get("temporal-resource-id")) == 0 { + headers = append(headers, "temporal-resource-id", fmt.Sprintf("workflow:%s", val)) + } + } + + return headers, nil +} diff --git a/proxy/request_header_test.go b/proxy/request_header_test.go new file mode 100644 index 00000000..4ed73975 --- /dev/null +++ b/proxy/request_header_test.go @@ -0,0 +1,74 @@ +package proxy + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "google.golang.org/grpc/metadata" + + "go.temporal.io/api/workflowservice/v1" +) + +// findHeader searches for a header key in the headers slice and returns its value +func findHeader(headers []string, key string) (string, bool) { + for i := 0; i < len(headers); i += 2 { + if i+1 < len(headers) && headers[i] == key { + return headers[i+1], true + } + } + return "", false +} + +func TestExtractTemporalRequestHeaders_NamespaceAlwaysIncluded(t *testing.T) { + req := &workflowservice.StartWorkflowExecutionRequest{ + Namespace: "test-namespace", + WorkflowId: "test-workflow", + } + + headers, err := ExtractTemporalRequestHeaders(context.Background(), ExtractHeadersOptions{ + Request: req, + }) + require.NoError(t, err) + + // Namespace should always be included in the headers + nsVal, found := findHeader(headers, "temporal-namespace") + require.True(t, found, "Expected temporal-namespace header, but not found") + require.Equal(t, "test-namespace", nsVal) +} + +func TestExtractTemporalRequestHeaders_EmptyWorkflowId(t *testing.T) { + req := &workflowservice.StartWorkflowExecutionRequest{ + Namespace: "test-namespace", + WorkflowId: "", + } + + headers, err := ExtractTemporalRequestHeaders(context.Background(), ExtractHeadersOptions{ + Request: req, + }) + require.NoError(t, err) + + // Should still set namespace + nsVal, found := findHeader(headers, "temporal-namespace") + require.True(t, found, "Expected temporal-namespace header even with empty workflow_id") + require.Equal(t, "test-namespace", nsVal) +} + +func TestExtractTemporalRequestHeaders_SkipExistingHeaders(t *testing.T) { + req := &workflowservice.StartWorkflowExecutionRequest{ + Namespace: "test-namespace", + WorkflowId: "test-workflow", + } + + existingMD := metadata.MD{} + existingMD.Set("temporal-namespace", "existing-namespace") + + headers, err := ExtractTemporalRequestHeaders(context.Background(), ExtractHeadersOptions{ + Request: req, + ExistingMetadata: existingMD, + }) + require.NoError(t, err) + + // Should not add any headers since they already exist + require.Empty(t, headers, "Expected no headers to be added when they already exist") +}