@@ -2,15 +2,19 @@ package resolvers
22
33import (
44 "context"
5+ "fmt"
6+ "strconv"
57
68 policyinfo "github.com/aws/amazon-network-policy-controller-k8s/api/v1alpha1"
79 "github.com/aws/amazon-network-policy-controller-k8s/pkg/k8s"
810 "github.com/go-logr/logr"
911 "github.com/pkg/errors"
12+ "golang.org/x/exp/maps"
1013 corev1 "k8s.io/api/core/v1"
1114 networking "k8s.io/api/networking/v1"
1215 metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
1316 "k8s.io/apimachinery/pkg/labels"
17+ "k8s.io/apimachinery/pkg/types"
1418 "k8s.io/apimachinery/pkg/util/intstr"
1519 "sigs.k8s.io/controller-runtime/pkg/client"
1620)
@@ -216,10 +220,14 @@ func (r *defaultEndpointsResolver) getIngressRulesPorts(ctx context.Context, pol
216220 var portList []policyinfo.Port
217221 for _ , pod := range podList .Items {
218222 portList = append (portList , r .getPortList (pod , ports )... )
219- r .logger .Info ("Got ingress port" , "port " , portList , " pod" , pod )
223+ r .logger .Info ("Got ingress port from pod " , "pod " , types. NamespacedName { Namespace : pod . Namespace , Name : pod . Name }. String () )
220224 }
221225
222- return portList
226+ // since we pull ports from dst pods, we should deduplicate them
227+ deduppedPorts := dedupPorts (portList )
228+ r .logger .Info ("Got ingress ports from dst pods" , "port" , deduppedPorts )
229+
230+ return deduppedPorts
223231}
224232
225233func (r * defaultEndpointsResolver ) getPortList (pod corev1.Pod , ports []networking.NetworkPolicyPort ) []policyinfo.Port {
@@ -455,3 +463,25 @@ func (r *defaultEndpointsResolver) getMatchingServicePort(ctx context.Context, s
455463 }
456464 return 0 , errors .Errorf ("unable to find matching service listen port %s for service %s" , port .String (), k8s .NamespacedName (svc ))
457465}
466+
467+ func dedupPorts (policyPorts []policyinfo.Port ) []policyinfo.Port {
468+ ports := make (map [string ]policyinfo.Port )
469+ for _ , port := range policyPorts {
470+ prot , p , ep := "" , "" , ""
471+ if port .Protocol != nil {
472+ prot = string (* port .Protocol )
473+ }
474+ if port .Port != nil {
475+ p = strconv .FormatInt (int64 (* port .Port ), 10 )
476+ }
477+ if port .EndPort != nil {
478+ ep = strconv .FormatInt (int64 (* port .EndPort ), 10 )
479+ }
480+
481+ ports [fmt .Sprintf ("%s@%s@%s" , prot , p , ep )] = port
482+ }
483+ if len (ports ) > 0 {
484+ return maps .Values (ports )
485+ }
486+ return nil
487+ }
0 commit comments