Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.beam.sdk.util.Weighted;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.math.LongMath;

/** Facade for a {@link List} that also keeps track of the weight of its contents, so caches can enforce size limits. */
public class WeightedList<T> implements Weighted {
Expand Down Expand Up @@ -71,14 +72,6 @@ public void addAll(List<T> values, long weight) {
}

/**
 * Adds {@code weight} to the tracked total weight of this list.
 *
 * <p>The addition saturates instead of overflowing: if the sum does not fit in a {@code long},
 * the total is clamped to {@link Long#MAX_VALUE} (or {@link Long#MIN_VALUE} on underflow).
 *
 * @param weight the weight delta to add to the current total
 */
public void accumulateWeight(long weight) {
  // Apply the delta exactly once. Accumulating with both the legacy Math.addExact lambda and
  // LongMath::saturatedAdd would count every delta twice.
  this.weight.accumulateAndGet(weight, LongMath::saturatedAdd);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,12 @@
import org.apache.beam.sdk.fn.stream.PrefetchableIterables;
import org.apache.beam.sdk.fn.stream.PrefetchableIterator;
import org.apache.beam.sdk.util.Weighted;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.math.LongMath;
import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Throwables;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.AbstractIterator;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;

/**
* Adapters which convert a logical series of chunks using continuation tokens over the Beam Fn
Expand Down Expand Up @@ -249,15 +251,11 @@ static class BlocksPrefix<T> extends Blocks<T> implements Shrinkable<BlocksPrefi

@Override
public long getWeight() {
try {
long sum = 8 + blocks.size() * 8L;
for (Block<T> block : blocks) {
sum = Math.addExact(sum, block.getWeight());
}
return sum;
} catch (ArithmeticException e) {
return Long.MAX_VALUE;
long sum = 8 + blocks.size() * 8L;
for (Block<T> block : blocks) {
sum = LongMath.saturatedAdd(sum, block.getWeight());
}
return sum;
}

BlocksPrefix(List<Block<T>> blocks) {
Expand All @@ -282,8 +280,7 @@ public List<Block<T>> getBlocks() {

@AutoValue
abstract static class Block<T> implements Weighted {
// Shared empty block: no values, a list weight of zero, and no continuation token. Declared
// exactly once; it is built through the private ImmutableList-based factory.
private static final Block<Void> EMPTY = fromValues(ImmutableList.of(), 0, null);

@SuppressWarnings("unchecked") // Based upon as Collections.emptyList()
public static <T> Block<T> emptyBlock() {
Expand All @@ -299,21 +296,37 @@ public static <T> Block<T> mutatedBlock(WeightedList<T> values) {
}

/**
 * Creates a block from {@code values}, weighing each element individually.
 *
 * <p>An empty list with no continuation token yields the shared empty block. Otherwise the
 * values are copied into an {@link ImmutableList} and the list weight is computed as
 * {@code size * Caches.REFERENCE_SIZE} plus the weight of every element, saturating instead of
 * overflowing.
 *
 * @param values the values to store in the block
 * @param nextToken the continuation token for the following block, or {@code null} if none
 * @return a block holding an immutable copy of {@code values}
 */
public static <T> Block<T> fromValues(List<T> values, @Nullable ByteString nextToken) {
  if (values.isEmpty() && nextToken == null) {
    return emptyBlock();
  }
  ImmutableList<T> immutableValues = ImmutableList.copyOf(values);
  long listWeight = immutableValues.size() * Caches.REFERENCE_SIZE;
  for (T value : immutableValues) {
    listWeight = LongMath.saturatedAdd(listWeight, Caches.weigh(value));
  }
  return fromValues(immutableValues, listWeight, nextToken);
}

/**
 * Creates a block from an already-weighed list, reusing the weight tracked by {@code values}.
 *
 * <p>An empty list with no continuation token yields the shared empty block. Otherwise the
 * backing list is copied into an {@link ImmutableList}; the fixed per-block overhead and the
 * token weight are added by the private factory this method delegates to.
 *
 * @param values the weighted values to store in the block
 * @param nextToken the continuation token for the following block, or {@code null} if none
 * @return a block holding an immutable copy of the backing list
 */
public static <T> Block<T> fromValues(
    WeightedList<T> values, @Nullable ByteString nextToken) {
  if (values.isEmpty() && nextToken == null) {
    return emptyBlock();
  }
  return fromValues(ImmutableList.copyOf(values.getBacking()), values.getWeight(), nextToken);
}

private static <T> Block<T> fromValues(
ImmutableList<T> values, long listWeight, @Nullable ByteString nextToken) {
long weight = LongMath.saturatedAdd(listWeight, 24);
if (nextToken != null) {
if (nextToken.isEmpty()) {
nextToken = ByteString.EMPTY;
} else {
weight += Caches.weigh(nextToken);
weight = LongMath.saturatedAdd(weight, Caches.weigh(nextToken));
}
}
return new AutoValue_StateFetchingIterators_CachingStateIterable_Block<>(
values.getBacking(), nextToken, weight);
values, nextToken, weight);
}

/** Returns the values held by this block. */
abstract List<T> getValues();
Expand Down Expand Up @@ -372,10 +385,12 @@ public void remove(Set<Object> toRemoveStructuralValues) {
totalSize += tBlock.getValues().size();
}

WeightedList<T> allValues = WeightedList.of(new ArrayList<>(totalSize), 0L);
ImmutableList.Builder<T> allValues = ImmutableList.builderWithExpectedSize(totalSize);
long weight = 0;
List<T> blockValuesToKeep = new ArrayList<>();
for (Block<T> block : blocks) {
blockValuesToKeep.clear();
boolean valueRemovedFromBlock = false;
List<T> blockValuesToKeep = new ArrayList<>();
for (T value : block.getValues()) {
if (!toRemoveStructuralValues.contains(valueCoder.structuralValue(value))) {
blockValuesToKeep.add(value);
Expand All @@ -387,13 +402,19 @@ public void remove(Set<Object> toRemoveStructuralValues) {
// If any value was removed from this block, need to estimate the weight again.
// Otherwise, just reuse the block's weight.
if (valueRemovedFromBlock) {
allValues.addAll(blockValuesToKeep, Caches.weigh(block.getValues()));
allValues.addAll(blockValuesToKeep);
for (T value : blockValuesToKeep) {
weight = LongMath.saturatedAdd(weight, Caches.weigh(value));
}
} else {
allValues.addAll(block.getValues(), block.getWeight());
allValues.addAll(block.getValues());
weight = LongMath.saturatedAdd(weight, block.getWeight());
}
}

cache.put(IterableCacheKey.INSTANCE, new MutatedBlocks<>(Block.mutatedBlock(allValues)));
cache.put(
IterableCacheKey.INSTANCE,
new MutatedBlocks<>(Block.fromValues(allValues.build(), weight, null)));
}

/**
Expand Down Expand Up @@ -484,21 +505,22 @@ private void appendHelper(List<T> newValues, long newWeight) {
for (Block<T> block : blocks) {
totalSize += block.getValues().size();
}
WeightedList<T> allValues = WeightedList.of(new ArrayList<>(totalSize), 0L);
ImmutableList.Builder<T> allValues = ImmutableList.builderWithExpectedSize(totalSize);
long weight = 0;
for (Block<T> block : blocks) {
allValues.addAll(block.getValues(), block.getWeight());
allValues.addAll(block.getValues());
weight = LongMath.saturatedAdd(weight, block.getWeight());
}
if (newWeight < 0) {
if (newValues.size() == 1) {
// Optimize weighing of the common case of value state stored as a single-element bag state.
newWeight = Caches.weigh(newValues.get(0));
} else {
newWeight = Caches.weigh(newValues);
newWeight = 0;
for (T value : newValues) {
newWeight = LongMath.saturatedAdd(newWeight, Caches.weigh(value));
}
}
allValues.addAll(newValues, newWeight);
allValues.addAll(newValues);
weight = LongMath.saturatedAdd(weight, newWeight);

cache.put(IterableCacheKey.INSTANCE, new MutatedBlocks<>(Block.mutatedBlock(allValues)));
cache.put(IterableCacheKey.INSTANCE, new MutatedBlocks<>(Block.fromValues(allValues.build(), weight, null)));
}

class CachingStateIterator implements PrefetchableIterator<T> {
Expand Down Expand Up @@ -580,8 +602,7 @@ public boolean hasNext() {
return false;
}
// Release the block while we are loading the next one.
currentBlock =
Block.fromValues(WeightedList.of(Collections.emptyList(), 0L), ByteString.EMPTY);
currentBlock = Block.emptyBlock();

@Nullable Blocks<T> existing = cache.peek(IterableCacheKey.INSTANCE);
boolean isFirstBlock = ByteString.EMPTY.equals(nextToken);
Expand Down
Loading