Skip to content

Commit 726257d

Browse files
authored
add pod and service utils (#7)
* add pod and service utils
1 parent 8238750 commit 726257d

File tree

5 files changed

+587
-1
lines changed

5 files changed

+587
-1
lines changed

internal/eventhandlers/pod.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@ package eventhandlers
1818

1919
import (
2020
"context"
21+
2122
"github.com/aws/amazon-network-policy-controller-k8s/pkg/k8s"
2223
"github.com/aws/amazon-network-policy-controller-k8s/pkg/resolvers"
23-
2424
"github.com/go-logr/logr"
2525
corev1 "k8s.io/api/core/v1"
2626
"k8s.io/apimachinery/pkg/api/equality"
@@ -81,6 +81,10 @@ func (h *enqueueRequestForPodEvent) Generic(_ context.Context, _ event.GenericEv
8181
}
8282

8383
func (h *enqueueRequestForPodEvent) enqueueReferredPolicies(ctx context.Context, _ workqueue.RateLimitingInterface, pod *corev1.Pod, podOld *corev1.Pod) {
84+
if len(k8s.GetPodIP(pod)) == 0 {
85+
h.logger.V(1).Info("Pod does not have an IP yet", "pod", k8s.NamespacedName(pod))
86+
return
87+
}
8488
referredPolicies, err := h.policyResolver.GetReferredPoliciesForPod(ctx, pod, podOld)
8589
if err != nil {
8690
h.logger.Error(err, "Unable to get referred policies", "pod", k8s.NamespacedName(pod))

pkg/k8s/pod_utils.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package k8s
2+
3+
import (
4+
"github.com/pkg/errors"
5+
corev1 "k8s.io/api/core/v1"
6+
"k8s.io/apimachinery/pkg/util/intstr"
7+
)
8+
9+
const (
10+
podIPAnnotation = "vpc.amazonaws.com/pod-ips"
11+
)
12+
13+
// GetPodIP returns the pod IP from the pod status or from the pod annotation.
14+
func GetPodIP(pod *corev1.Pod) string {
15+
if len(pod.Status.PodIP) > 0 {
16+
return pod.Status.PodIP
17+
} else {
18+
return pod.Annotations[podIPAnnotation]
19+
}
20+
}
21+
22+
// LookupContainerPortAndName returns numerical containerPort and portName for specific port and protocol
23+
func LookupContainerPortAndName(pod *corev1.Pod, port intstr.IntOrString, protocol corev1.Protocol) (int32, string, error) {
24+
for _, podContainer := range pod.Spec.Containers {
25+
for _, podPort := range podContainer.Ports {
26+
if podPort.Protocol != protocol {
27+
continue
28+
}
29+
switch port.Type {
30+
case intstr.String:
31+
if podPort.Name == port.StrVal {
32+
return podPort.ContainerPort, podPort.Name, nil
33+
}
34+
case intstr.Int:
35+
if podPort.ContainerPort == port.IntVal {
36+
return podPort.ContainerPort, podPort.Name, nil
37+
}
38+
}
39+
}
40+
}
41+
if port.Type == intstr.Int {
42+
return port.IntVal, "", nil
43+
}
44+
return 0, "", errors.Errorf("unable to find port %s on pod %s", port.String(), NamespacedName(pod))
45+
}

pkg/k8s/pod_utils_test.go

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
package k8s
2+
3+
import (
4+
"github.com/stretchr/testify/assert"
5+
corev1 "k8s.io/api/core/v1"
6+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
7+
"k8s.io/apimachinery/pkg/util/intstr"
8+
"testing"
9+
)
10+
11+
func Test_GetPodIP(t *testing.T) {
12+
tests := []struct {
13+
name string
14+
pod *corev1.Pod
15+
want string
16+
}{
17+
{
18+
name: "pod with status IP",
19+
pod: &corev1.Pod{
20+
Status: corev1.PodStatus{
21+
PodIP: "192.168.11.22",
22+
},
23+
},
24+
want: "192.168.11.22",
25+
},
26+
{
27+
name: "pod with annotation IP",
28+
pod: &corev1.Pod{
29+
ObjectMeta: metav1.ObjectMeta{
30+
Annotations: map[string]string{
31+
podIPAnnotation: "1.2.3.4",
32+
},
33+
},
34+
},
35+
want: "1.2.3.4",
36+
},
37+
{
38+
name: "pod without status IP or annotation IP",
39+
pod: &corev1.Pod{},
40+
},
41+
}
42+
for _, tt := range tests {
43+
t.Run(tt.name, func(t *testing.T) {
44+
got := GetPodIP(tt.pod)
45+
assert.Equal(t, tt.want, got)
46+
})
47+
}
48+
}
49+
50+
func Test_LookupContainerPortAndName(t *testing.T) {
51+
pod := &corev1.Pod{
52+
ObjectMeta: metav1.ObjectMeta{
53+
Namespace: "default",
54+
Name: "pod",
55+
},
56+
Spec: corev1.PodSpec{
57+
Containers: []corev1.Container{
58+
{
59+
Ports: []corev1.ContainerPort{
60+
{
61+
Name: "http",
62+
ContainerPort: 80,
63+
Protocol: corev1.ProtocolTCP,
64+
},
65+
},
66+
},
67+
{
68+
Ports: []corev1.ContainerPort{
69+
{
70+
Name: "https",
71+
ContainerPort: 443,
72+
Protocol: corev1.ProtocolTCP,
73+
},
74+
{
75+
ContainerPort: 8080,
76+
Protocol: corev1.ProtocolTCP,
77+
},
78+
},
79+
},
80+
},
81+
},
82+
}
83+
type want struct {
84+
port int32
85+
name string
86+
}
87+
type args struct {
88+
pod *corev1.Pod
89+
protocol corev1.Protocol
90+
port intstr.IntOrString
91+
}
92+
tests := []struct {
93+
name string
94+
args args
95+
want want
96+
wantErr string
97+
}{
98+
{
99+
name: "resolve numeric pod",
100+
args: args{
101+
pod: pod,
102+
port: intstr.FromInt(8080),
103+
},
104+
want: want{
105+
port: 8080,
106+
},
107+
},
108+
{
109+
name: "numeric pod not in pod spec can still be resolved",
110+
args: args{
111+
pod: pod,
112+
port: intstr.FromInt(9090),
113+
},
114+
want: want{
115+
port: 9090,
116+
},
117+
},
118+
{
119+
name: "lookup based on port name",
120+
args: args{
121+
pod: pod,
122+
port: intstr.FromString("http"),
123+
},
124+
want: want{
125+
port: 80,
126+
name: "http",
127+
},
128+
},
129+
{
130+
name: "lookup based on port name in another container",
131+
args: args{
132+
pod: pod,
133+
port: intstr.FromString("https"),
134+
},
135+
want: want{
136+
port: 443,
137+
name: "https",
138+
},
139+
},
140+
{
141+
name: "port matches, but protocol does not",
142+
args: args{
143+
pod: pod,
144+
port: intstr.FromString("https"),
145+
protocol: corev1.ProtocolUDP,
146+
},
147+
wantErr: "unable to find port https on pod default/pod",
148+
},
149+
{
150+
name: "numeric port lookup ignores the protocol",
151+
args: args{
152+
pod: pod,
153+
port: intstr.FromInt(443),
154+
protocol: corev1.ProtocolUDP,
155+
},
156+
want: want{
157+
port: 443,
158+
},
159+
},
160+
{
161+
name: "nonexistent port name",
162+
args: args{
163+
pod: pod,
164+
port: intstr.FromString("nonexistent"),
165+
},
166+
wantErr: "unable to find port nonexistent on pod default/pod",
167+
},
168+
}
169+
for _, tt := range tests {
170+
t.Run(tt.name, func(t *testing.T) {
171+
protocol := tt.args.protocol
172+
if len(protocol) == 0 {
173+
protocol = corev1.ProtocolTCP
174+
}
175+
got := want{}
176+
var err error
177+
got.port, got.name, err = LookupContainerPortAndName(tt.args.pod, tt.args.port, protocol)
178+
if len(tt.wantErr) > 0 {
179+
assert.EqualError(t, err, tt.wantErr)
180+
} else {
181+
assert.NoError(t, err)
182+
assert.Equal(t, tt.want, got)
183+
}
184+
})
185+
}
186+
}

pkg/k8s/service_utils.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package k8s
2+
3+
import (
4+
"github.com/pkg/errors"
5+
corev1 "k8s.io/api/core/v1"
6+
"k8s.io/apimachinery/pkg/util/intstr"
7+
)
8+
9+
// LookupServiceListenPort returns the numerical port for the service listen port if the target port name matches
10+
// or the port number and the protocol matches the target port. If no matching port is found, it returns a 0 and an error.
11+
func LookupServiceListenPort(svc *corev1.Service, port intstr.IntOrString, protocol corev1.Protocol) (int32, error) {
12+
for _, svcPort := range svc.Spec.Ports {
13+
if svcPort.TargetPort.Type == port.Type && svcPort.TargetPort.String() == port.String() && svcPort.Protocol == protocol {
14+
return svcPort.Port, nil
15+
}
16+
}
17+
return 0, errors.Errorf("unable to find port %s on service %s", port.String(), NamespacedName(svc))
18+
}
19+
20+
// LookupListenPortFromPodSpec returns the numerical listener port from the service spec if the input port matches the target port
21+
// in the pod spec
22+
func LookupListenPortFromPodSpec(svc *corev1.Service, pod *corev1.Pod, port intstr.IntOrString, protocol corev1.Protocol) (int32, error) {
23+
containerPort, containerPortName, err := LookupContainerPortAndName(pod, port, protocol)
24+
if err != nil {
25+
return 0, err
26+
}
27+
for _, svcPort := range svc.Spec.Ports {
28+
if svcPort.Protocol != protocol {
29+
continue
30+
}
31+
switch svcPort.TargetPort.Type {
32+
case intstr.String:
33+
if containerPortName == svcPort.TargetPort.StrVal {
34+
return svcPort.Port, nil
35+
}
36+
37+
case intstr.Int:
38+
if containerPort == svcPort.TargetPort.IntVal {
39+
return svcPort.Port, nil
40+
}
41+
}
42+
}
43+
return 0, errors.Errorf("unable to find listener port for port %s on service %s", port.String(), NamespacedName(svc))
44+
}

0 commit comments

Comments
 (0)