Skip to content

Commit 8e4c976

Browse files
committed
Skip flattening for mutations across imports
1 parent ffa6d64 commit 8e4c976

4 files changed

Lines changed: 58 additions & 67 deletions

File tree

pipeline/ingestion/src/main/java/org/datacommons/ingestion/pipeline/GraphIngestionPipeline.java

Lines changed: 46 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,20 @@
44
import com.google.gson.JsonArray;
55
import com.google.gson.JsonElement;
66
import com.google.gson.JsonParser;
7-
import java.util.ArrayList;
7+
import java.util.Arrays;
88
import java.util.List;
99
import org.apache.beam.sdk.Pipeline;
1010
import org.apache.beam.sdk.io.gcp.spanner.SpannerWriteResult;
1111
import org.apache.beam.sdk.metrics.Counter;
1212
import org.apache.beam.sdk.metrics.Metrics;
1313
import org.apache.beam.sdk.options.PipelineOptionsFactory;
14+
import org.apache.beam.sdk.transforms.Create;
1415
import org.apache.beam.sdk.transforms.Flatten;
1516
import org.apache.beam.sdk.transforms.Values;
1617
import org.apache.beam.sdk.values.PCollection;
1718
import org.apache.beam.sdk.values.PCollectionList;
1819
import org.apache.beam.sdk.values.PCollectionTuple;
20+
import org.apache.beam.sdk.values.TypeDescriptor;
1921
import org.datacommons.ingestion.spanner.SpannerClient;
2022
import org.datacommons.ingestion.util.GraphReader;
2123
import org.datacommons.ingestion.util.PipelineUtils;
@@ -70,12 +72,6 @@ public static void buildPipeline(
7072
Pipeline pipeline, IngestionPipelineOptions options, SpannerClient spannerClient) {
7173
LOGGER.info("Running import pipeline for imports: {}", options.getImportList());
7274

73-
// Initialize lists to hold mutations from all imports.
74-
List<PCollection<Void>> deleteOpsList = new ArrayList<>();
75-
List<PCollection<Mutation>> obsMutationList = new ArrayList<>();
76-
List<PCollection<Mutation>> edgeMutationList = new ArrayList<>();
77-
List<PCollection<Mutation>> nodeMutationList = new ArrayList<>();
78-
7975
// Parse the input import list JSON.
8076
JsonElement jsonElement = JsonParser.parseString(options.getImportList());
8177
JsonArray jsonArray = jsonElement.getAsJsonArray();
@@ -97,37 +93,8 @@ public static void buildPipeline(
9793
String graphPath = pathElement.getAsString();
9894

9995
// Process the individual import.
100-
processImport(
101-
pipeline,
102-
spannerClient,
103-
importName,
104-
graphPath,
105-
options.getSkipDelete(),
106-
deleteOpsList,
107-
nodeMutationList,
108-
edgeMutationList,
109-
obsMutationList);
110-
}
111-
// Finally, aggregate all collected mutations and write them to Spanner.
112-
// 1. Process Deletes:
113-
// First, execute all delete mutations to clear old data for the imports.
114-
PCollection<Void> deleted =
115-
PCollectionList.of(deleteOpsList).apply("DeleteOps", Flatten.pCollections());
116-
117-
// 2. Process Observations:
118-
// Write observation mutations after deletes are complete.
119-
if (options.getWriteObsGraph()) {
120-
spannerClient.writeMutations(pipeline, "Observations", obsMutationList, deleted);
96+
processImport(pipeline, spannerClient, importName, graphPath, options.getSkipDelete());
12197
}
122-
123-
// 3. Process Nodes:
124-
// Write node mutations after deletes are complete.
125-
SpannerWriteResult writtenNodes =
126-
spannerClient.writeMutations(pipeline, "Nodes", nodeMutationList, deleted);
127-
128-
// 4. Process Edges:
129-
// Write edge mutations only after node mutations are complete to ensure referential integrity.
130-
spannerClient.writeMutations(pipeline, "Edges", edgeMutationList, writtenNodes.getOutput());
13198
}
13299

133100
/**
@@ -138,31 +105,33 @@ public static void buildPipeline(
138105
* @param importName The name of the import.
139106
* @param graphPath The full path to the graph data.
140107
* @param skipDelete Whether to skip delete operations.
141-
* @param deleteOpsList List to collect delete Ops.
142-
* @param nodeMutationList List to collect node mutations.
143-
* @param edgeMutationList List to collect edge mutations.
144-
* @param obsMutationList List to collect observation mutations.
145108
*/
146109
private static void processImport(
147110
Pipeline pipeline,
148111
SpannerClient spannerClient,
149112
String importName,
150113
String graphPath,
151-
boolean skipDelete,
152-
List<PCollection<Void>> deleteOpsList,
153-
List<PCollection<Mutation>> nodeMutationList,
154-
List<PCollection<Mutation>> edgeMutationList,
155-
List<PCollection<Mutation>> obsMutationList) {
114+
boolean skipDelete) {
156115
LOGGER.info("Import: {} Graph path: {}", importName, graphPath);
157116

158117
String provenance = "dc/base/" + importName;
159118

160119
// 1. Prepare Deletes:
161120
// Generate mutations to delete existing data for this import/provenance.
121+
// Create a dummy signal if deletes are skipped, so downstream dependencies are satisfied
122+
// immediately.
123+
PCollection<Void> deleteObsWait;
124+
PCollection<Void> deleteEdgesWait;
162125
if (!skipDelete) {
163-
List<PCollection<Void>> deleteOps =
164-
GraphReader.deleteExistingDataForImport(importName, provenance, pipeline, spannerClient);
165-
deleteOpsList.addAll(deleteOps);
126+
deleteObsWait = spannerClient.deleteObservationsForImport(importName, pipeline);
127+
deleteEdgesWait = spannerClient.deleteEdgesForImport(provenance, pipeline);
128+
} else {
129+
deleteObsWait =
130+
pipeline.apply(
131+
"CreateEmptyObsWait-" + importName, Create.empty(TypeDescriptor.of(Void.class)));
132+
deleteEdgesWait =
133+
pipeline.apply(
134+
"CreateEmptyEdgesWait-" + importName, Create.empty(TypeDescriptor.of(Void.class)));
166135
}
167136

168137
// 2. Read and Split Graph:
@@ -176,29 +145,50 @@ private static void processImport(
176145
PCollection<McfGraph> schemaNodes = graphNodes.get(PipelineUtils.SCHEMA_NODES_TAG);
177146

178147
// 3. Process Schema Nodes:
179-
// Combine schema nodes if required, then convert to Node and Edge mutations.
148+
// Combine nodes if required.
180149
PCollection<McfGraph> combinedGraph = schemaNodes;
181150
if (IMPORTS_TO_COMBINE.contains(importName)) {
182151
combinedGraph = PipelineUtils.combineGraphNodes(schemaNodes);
183152
}
153+
154+
// Convert all nodes to mutations
184155
PCollection<Mutation> nodeMutations =
185156
GraphReader.graphToNodes(
186-
importName, combinedGraph, spannerClient, nodeCounter, nodeInvalidTypeCounter)
157+
"NodeMutations-" + importName,
158+
combinedGraph,
159+
spannerClient,
160+
nodeCounter,
161+
nodeInvalidTypeCounter)
187162
.apply("ExtractNodeMutations-" + importName, Values.create());
188163
PCollection<Mutation> edgeMutations =
189-
GraphReader.graphToEdges(importName, combinedGraph, provenance, spannerClient, edgeCounter)
164+
GraphReader.graphToEdges(
165+
"EdgeMutations-" + importName,
166+
combinedGraph,
167+
provenance,
168+
spannerClient,
169+
edgeCounter)
190170
.apply("ExtractEdgeMutations-" + importName, Values.create());
191171

192-
nodeMutationList.add(nodeMutations);
193-
edgeMutationList.add(edgeMutations);
172+
// Write Nodes
173+
SpannerWriteResult writtenNodes =
174+
spannerClient.writeMutations(pipeline, "Nodes-" + importName, List.of(nodeMutations), null);
175+
176+
PCollection<Void> writeEdgesWait =
177+
PCollectionList.of(Arrays.asList(writtenNodes.getOutput(), deleteEdgesWait))
178+
.apply("FlattenDeleteOps-" + importName, Flatten.pCollections());
179+
// Write Edges (wait for Nodes and deletes)
180+
spannerClient.writeMutations(
181+
pipeline, "Edges-" + importName, List.of(edgeMutations), writeEdgesWait);
194182

195183
// 4. Process Observation Nodes:
196184
// Build an optimized graph from observation nodes and convert to Observation mutations.
197185
PCollection<McfOptimizedGraph> optimizedGraph =
198186
PipelineUtils.buildOptimizedMcfGraph(observationNodes);
199187
PCollection<Mutation> observationMutations =
200188
GraphReader.graphToObservations(optimizedGraph, importName, spannerClient, obsCounter)
201-
.apply("ExtractObsMutations", Values.create());
202-
obsMutationList.add(observationMutations);
189+
.apply("ExtractObsMutations-" + importName, Values.create());
190+
// Write Observations (wait for delete)
191+
spannerClient.writeMutations(
192+
pipeline, "Observations-" + importName, List.of(observationMutations), deleteObsWait);
203193
}
204194
}

pipeline/ingestion/src/test/java/org/datacommons/ingestion/pipeline/GraphIngestionPipelineIntegrationTest.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
import java.io.IOException;
2424
import java.nio.charset.StandardCharsets;
2525
import java.nio.file.Files;
26-
import java.util.UUID;
2726
import org.apache.beam.runners.dataflow.DataflowRunner;
2827
import org.apache.beam.runners.dataflow.options.DataflowPipelineOptions;
2928
import org.apache.beam.sdk.Pipeline;
@@ -86,7 +85,7 @@ public class GraphIngestionPipelineIntegrationTest {
8685
private String region;
8786
private String emulatorHost;
8887
private boolean isLocal;
89-
private String importName = "TestImport-" + UUID.randomUUID().toString();
88+
private String importName = "TestImport";
9089
private String nodeNameValue = "Test Node Name";
9190
private SpannerClient spannerClient;
9291

pipeline/util/src/main/java/org/datacommons/ingestion/util/GraphReader.java

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,9 @@
55
import java.io.Serializable;
66
import java.nio.charset.StandardCharsets;
77
import java.util.ArrayList;
8-
import java.util.Arrays;
98
import java.util.Collections;
109
import java.util.List;
1110
import java.util.Map;
12-
import org.apache.beam.sdk.Pipeline;
1311
import org.apache.beam.sdk.metrics.Counter;
1412
import org.apache.beam.sdk.transforms.DoFn;
1513
import org.apache.beam.sdk.transforms.ParDo;
@@ -97,6 +95,9 @@ public static List<Edge> graphToEdges(McfGraph graph, String provenance) {
9795
String dcid = GraphUtils.getPropertyValue(pv, "dcid");
9896
String subjectId = !dcid.isEmpty() ? dcid : McfUtil.stripNamespace(nodeEntry.getKey());
9997
for (Map.Entry<String, McfGraph.Values> entry : pv.entrySet()) {
98+
if (entry.getKey().equals("dcid")) {
99+
continue;
100+
}
100101
for (TypedValue val : entry.getValue().getTypedValuesList()) {
101102
if (val.getType() != ValueType.RESOLVED_REF) {
102103
int valSize = val.getValue().getBytes(StandardCharsets.UTF_8).length;
@@ -155,13 +156,6 @@ public static Observation graphToObservations(McfOptimizedGraph graph, String im
155156
return obs.build();
156157
}
157158

158-
public static List<PCollection<Void>> deleteExistingDataForImport(
159-
String importName, String provenance, Pipeline pipeline, SpannerClient spannerClient) {
160-
return Arrays.asList(
161-
spannerClient.deleteObservationsForImport(importName, pipeline),
162-
spannerClient.deleteEdgesForImport(provenance, pipeline));
163-
}
164-
165159
public static PCollection<KV<String, Mutation>> graphToObservations(
166160
PCollection<McfOptimizedGraph> graph,
167161
String importName,

pipeline/util/src/test/java/org/datacommons/ingestion/util/GraphReaderTest.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,14 @@ public void testGraphToEdges() {
173173
.setType(ValueType.TEXT)
174174
.setValue("Subject Node"))
175175
.build())
176+
.putPvs(
177+
"dcid",
178+
McfGraph.Values.newBuilder()
179+
.addTypedValues(
180+
TypedValue.newBuilder()
181+
.setType(ValueType.TEXT)
182+
.setValue("dcid_subject"))
183+
.build())
176184
.putPvs(
177185
"typeOf",
178186
McfGraph.Values.newBuilder()

0 commit comments

Comments
 (0)