Skip to content

Commit a7bae2a

Browse files
committed
Java version
1 parent 1d98933 commit a7bae2a

File tree

4 files changed

+367
-1
lines changed

4 files changed

+367
-1
lines changed
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
package org.apache.beam.sdk.util;
19+
20+
import com.google.api.gax.rpc.AlreadyExistsException;
21+
import com.google.api.gax.rpc.NotFoundException;
22+
import com.google.cloud.kms.v1.CryptoKeyName;
23+
import com.google.cloud.kms.v1.EncryptResponse;
24+
import com.google.cloud.kms.v1.KeyManagementServiceClient;
25+
import com.google.cloud.secretmanager.v1.AccessSecretVersionResponse;
26+
import com.google.cloud.secretmanager.v1.ProjectName;
27+
import com.google.cloud.secretmanager.v1.Replication;
28+
import com.google.cloud.secretmanager.v1.SecretPayload;
29+
import com.google.cloud.secretmanager.v1.SecretManagerServiceClient;
30+
import com.google.cloud.secretmanager.v1.SecretName;
31+
import com.google.cloud.secretmanager.v1.SecretVersionName;
32+
import com.google.crypto.tink.subtle.Hkdf;
33+
import com.google.protobuf.ByteString;
34+
import java.io.IOException;
35+
import java.security.GeneralSecurityException;
36+
import java.security.SecureRandom;
37+
import java.util.Base64;
38+
import org.slf4j.Logger;
39+
import org.slf4j.LoggerFactory;
40+
41+
/**
42+
* A {@link Secret} manager implementation that retrieves secrets from Google Cloud Secret Manager.
43+
*/
44+
public class GcpHsmGeneratedSecret implements Secret {
45+
private static final Logger LOG = LoggerFactory.getLogger(GcpHsmGeneratedSecret.class);
46+
private final String projectId;
47+
private final String locationId;
48+
private final String keyRingId;
49+
private final String keyId;
50+
private final String secretId;
51+
52+
public GcpHsmGeneratedSecret(
53+
String projectId, String locationId, String keyRingId, String keyId, String jobName) {
54+
this.projectId = projectId;
55+
this.locationId = locationId;
56+
this.keyRingId = keyRingId;
57+
this.keyId = keyId;
58+
this.secretId = "HsmGeneratedSecret_" + jobName;
59+
}
60+
61+
/**
62+
* Returns the secret as a byte array. Assumes that the current active service account has
63+
* permissions to read the secret.
64+
*
65+
* @return The secret as a byte array.
66+
*/
67+
@Override
68+
public byte[] getSecretBytes() {
69+
try (SecretManagerServiceClient client = SecretManagerServiceClient.create()) {
70+
SecretVersionName secretVersionName = SecretVersionName.of(projectId, secretId, "1");
71+
72+
try {
73+
AccessSecretVersionResponse response = client.accessSecretVersion(secretVersionName);
74+
return response.getPayload().getData().toByteArray();
75+
} catch (NotFoundException e) {
76+
LOG.info(
77+
"Secret version {} not found. Creating new secret and version.",
78+
secretVersionName.toString());
79+
}
80+
81+
ProjectName projectName = ProjectName.of(projectId);
82+
SecretName secretName = SecretName.of(projectId, secretId);
83+
try {
84+
com.google.cloud.secretmanager.v1.Secret secret =
85+
com.google.cloud.secretmanager.v1.Secret.newBuilder()
86+
.setReplication(
87+
Replication.newBuilder()
88+
.setAutomatic(Replication.Automatic.newBuilder().build()))
89+
.build();
90+
client.createSecret(projectName, secretId, secret);
91+
} catch (AlreadyExistsException e) {
92+
LOG.info("Secret {} already exists. Adding new version.", secretName.toString());
93+
}
94+
95+
byte[] newKey = generateDek();
96+
97+
try {
98+
// Try to access again in case another thread created it.
99+
AccessSecretVersionResponse response = client.accessSecretVersion(secretVersionName);
100+
return response.getPayload().getData().toByteArray();
101+
} catch (NotFoundException e) {
102+
LOG.info(
103+
"Secret version {} not found after re-check. Creating new secret and version.",
104+
secretVersionName.toString());
105+
}
106+
107+
SecretPayload payload = SecretPayload.newBuilder().setData(ByteString.copyFrom(newKey)).build();
108+
client.addSecretVersion(secretName, payload);
109+
AccessSecretVersionResponse response = client.accessSecretVersion(secretVersionName);
110+
return response.getPayload().getData().toByteArray();
111+
112+
} catch (IOException | GeneralSecurityException e) {
113+
throw new RuntimeException("Failed to retrieve or create secret bytes", e);
114+
}
115+
}
116+
117+
private byte[] generateDek() throws IOException, GeneralSecurityException {
118+
int dekSize = 32;
119+
try (KeyManagementServiceClient client = KeyManagementServiceClient.create()) {
120+
// 1. Generate nonce_one
121+
SecureRandom random = new SecureRandom();
122+
byte[] nonceOne = new byte[dekSize];
123+
random.nextBytes(nonceOne);
124+
125+
// 2. Encrypt to get nonce_two
126+
CryptoKeyName keyName = CryptoKeyName.of(projectId, locationId, keyRingId, keyId);
127+
EncryptResponse response = client.encrypt(keyName, ByteString.copyFrom(nonceOne));
128+
byte[] nonceTwo = response.getCiphertext().toByteArray();
129+
130+
// 3. Generate DK
131+
byte[] dk = new byte[dekSize];
132+
random.nextBytes(dk);
133+
134+
// 4. Derive DEK using HKDF
135+
byte[] dek = Hkdf.computeHkdf("HmacSha256", dk, nonceTwo, new byte[0], dekSize);
136+
137+
// 5. Base64 encode
138+
return Base64.getUrlEncoder().encode(dek);
139+
}
140+
}
141+
}

sdks/java/core/src/main/java/org/apache/beam/sdk/util/Secret.java

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,29 @@ static Secret parseSecretOption(String secretOption) {
7676
"version_name must contain a valid value for versionName parameter");
7777
}
7878
return new GcpSecret(versionName);
79+
case "gcphsmgeneratedsecret":
80+
Set<String> gcpHsmGeneratedSecretParams =
81+
new HashSet<>(
82+
Arrays.asList("project_id", "location_id", "key_ring_id", "key_id", "job_name"));
83+
for (String paramName : paramMap.keySet()) {
84+
if (!gcpHsmGeneratedSecretParams.contains(paramName)) {
85+
throw new RuntimeException(
86+
String.format(
87+
"Invalid secret parameter %s, GcpHsmGeneratedSecret only supports the following parameters: %s",
88+
paramName, gcpHsmGeneratedSecretParams));
89+
}
90+
}
91+
return new GcpHsmGeneratedSecret(
92+
paramMap.get("project_id"),
93+
paramMap.get("location_id"),
94+
paramMap.get("key_ring_id"),
95+
paramMap.get("key_id"),
96+
paramMap.get("job_name"));
7997
default:
8098
throw new RuntimeException(
8199
String.format(
82-
"Invalid secret type %s, currently only GcpSecret is supported", secretType));
100+
"Invalid secret type %s, currently only GcpSecret and GcpHsmGeneratedSecret are supported",
101+
secretType));
83102
}
84103
}
85104
}

sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByEncryptedKeyTest.java

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import org.apache.beam.sdk.testing.NeedsRunner;
4040
import org.apache.beam.sdk.testing.PAssert;
4141
import org.apache.beam.sdk.testing.TestPipeline;
42+
import org.apache.beam.sdk.util.GcpHsmGeneratedSecret;
4243
import org.apache.beam.sdk.util.GcpSecret;
4344
import org.apache.beam.sdk.util.Secret;
4445
import org.apache.beam.sdk.values.KV;
@@ -102,6 +103,9 @@ public void testGroupByKeyFakeSecret() {
102103
private static final String PROJECT_ID = "apache-beam-testing";
103104
private static final String SECRET_ID = "gbek-test";
104105
private static Secret gcpSecret;
106+
private static Secret gcpHsmGeneratedSecret;
107+
private static String keyRingId;
108+
private static String keyId;
105109

106110
@BeforeClass
107111
public static void setup() throws IOException {
@@ -131,13 +135,59 @@ public static void setup() throws IOException {
131135
.build());
132136
}
133137
gcpSecret = new GcpSecret(secretName.toString() + "/versions/latest");
138+
139+
try {
140+
com.google.cloud.kms.v1.KeyManagementServiceClient kmsClient =
141+
com.google.cloud.kms.v1.KeyManagementServiceClient.create();
142+
String locationId = "global";
143+
keyRingId = "gbek-test-key-ring-" + System.currentTimeMillis();
144+
com.google.cloud.kms.v1.KeyRingName keyRingName =
145+
com.google.cloud.kms.v1.KeyRingName.of(PROJECT_ID, locationId, keyRingId);
146+
kmsClient.createKeyRing(
147+
keyRingName.getProject(),
148+
keyRingName.getLocation(),
149+
keyRingId,
150+
com.google.cloud.kms.v1.KeyRing.newBuilder().build());
151+
152+
keyId = "gbek-test-key-" + System.currentTimeMillis();
153+
com.google.cloud.kms.v1.CryptoKey key =
154+
com.google.cloud.kms.v1.CryptoKey.newBuilder()
155+
.setPurpose(
156+
com.google.cloud.kms.v1.CryptoKey.CryptoKeyPurpose.ENCRYPT_DECRYPT)
157+
.build();
158+
kmsClient.createCryptoKey(keyRingName, keyId, key);
159+
gcpHsmGeneratedSecret =
160+
new GcpHsmGeneratedSecret(
161+
PROJECT_ID,
162+
locationId,
163+
keyRingId,
164+
keyId,
165+
String.format("gbek-test-job-%d", new SecureRandom().nextInt(10000)));
166+
} catch (Exception e) {
167+
gcpHsmGeneratedSecret = null;
168+
}
134169
}
135170

136171
@AfterClass
137172
public static void tearDown() throws IOException {
138173
SecretManagerServiceClient client = SecretManagerServiceClient.create();
139174
SecretName secretName = SecretName.of(PROJECT_ID, SECRET_ID);
140175
client.deleteSecret(secretName);
176+
if (gcpHsmGeneratedSecret != null) {
177+
com.google.cloud.kms.v1.KeyManagementServiceClient kmsClient =
178+
com.google.cloud.kms.v1.KeyManagementServiceClient.create();
179+
com.google.cloud.kms.v1.CryptoKeyName keyName =
180+
com.google.cloud.kms.v1.CryptoKeyName.of(PROJECT_ID, "global", keyRingId, keyId);
181+
for (com.google.cloud.kms.v1.CryptoKeyVersion version :
182+
kmsClient.listCryptoKeyVersions(keyName).iterateAll()) {
183+
if (version.getState()
184+
== com.google.cloud.kms.v1.CryptoKeyVersion.CryptoKeyVersionState.ENABLED
185+
|| version.getState()
186+
== com.google.cloud.kms.v1.CryptoKeyVersion.CryptoKeyVersionState.DISABLED) {
187+
kmsClient.destroyCryptoKeyVersion(version.getName());
188+
}
189+
}
190+
}
141191
}
142192

143193
@Test
@@ -183,6 +233,43 @@ public void testGroupByKeyGcpSecretThrows() {
183233
assertThrows(RuntimeException.class, () -> p.run());
184234
}
185235

236+
@Test
237+
@Category(NeedsRunner.class)
238+
public void testGroupByKeyGcpHsmGeneratedSecret() {
239+
if (gcpHsmGeneratedSecret == null) {
240+
return;
241+
}
242+
List<KV<@Nullable String, Integer>> ungroupedPairs =
243+
Arrays.asList(
244+
KV.of(null, 3),
245+
KV.of("k1", 3),
246+
KV.of("k5", Integer.MAX_VALUE),
247+
KV.of("k5", Integer.MIN_VALUE),
248+
KV.of("k2", 66),
249+
KV.of("k1", 4),
250+
KV.of(null, 5),
251+
KV.of("k2", -33),
252+
KV.of("k3", 0));
253+
254+
PCollection<KV<String, Integer>> input =
255+
p.apply(
256+
Create.of(ungroupedPairs)
257+
.withCoder(KvCoder.of(NullableCoder.of(StringUtf8Coder.of()), VarIntCoder.of())));
258+
259+
PCollection<KV<String, Iterable<Integer>>> output =
260+
input.apply(GroupByEncryptedKey.<String, Integer>create(gcpHsmGeneratedSecret));
261+
262+
PAssert.that(output.apply("Sort", MapElements.via(new SortValues())))
263+
.containsInAnyOrder(
264+
KV.of("k1", Arrays.asList(3, 4)),
265+
KV.of(null, Arrays.asList(3, 5)),
266+
KV.of("k5", Arrays.asList(Integer.MIN_VALUE, Integer.MAX_VALUE)),
267+
KV.of("k2", Arrays.asList(-33, 66)),
268+
KV.of("k3", Arrays.asList(0)));
269+
270+
p.run();
271+
}
272+
186273
private static class SortValues
187274
extends SimpleFunction<KV<String, Iterable<Integer>>, KV<String, List<Integer>>> {
188275
@Override

0 commit comments

Comments
 (0)