Skip to content

Commit 8fa8b0d

Browse files
authored
[python] Fix with_shard feature for blob data (#6691)
1 parent caf6cf8 commit 8fa8b0d

File tree

7 files changed

+395
-251
lines changed

7 files changed

+395
-251
lines changed

paimon-python/pypaimon/read/reader/concat_batch_reader.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ def read_arrow_batch(self) -> Optional[RecordBatch]:
7575
min(self.split_end_row, self.cur_end) - self.split_start_row)
7676
elif cur_begin < self.split_end_row <= self.cur_end:
7777
return batch.slice(0, self.split_end_row - cur_begin)
78+
else:
79+
# return empty RecordBatch if the batch size has not reached split_start_row
80+
return pa.RecordBatch.from_arrays([], [])
7881
else:
7982
return batch
8083

paimon-python/pypaimon/read/scanner/full_starting_scanner.py

Lines changed: 107 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -117,14 +117,20 @@ def _append_only_filter_by_shard(self, partitioned_files: defaultdict) -> (defau
117117
for entry in file_entries:
118118
total_row += entry.file.row_count
119119

120-
# Calculate number of rows this shard should process
121-
# Last shard handles all remaining rows (handles non-divisible cases)
122-
if self.idx_of_this_subtask == self.number_of_para_subtasks - 1:
123-
num_row = total_row - total_row // self.number_of_para_subtasks * self.idx_of_this_subtask
120+
# Calculate number of rows this shard should process using balanced distribution
121+
# Distribute remainder evenly among first few shards to avoid last shard overload
122+
base_rows_per_shard = total_row // self.number_of_para_subtasks
123+
remainder = total_row % self.number_of_para_subtasks
124+
125+
# Each of the first 'remainder' shards gets one extra row
126+
if self.idx_of_this_subtask < remainder:
127+
num_row = base_rows_per_shard + 1
128+
start_row = self.idx_of_this_subtask * (base_rows_per_shard + 1)
124129
else:
125-
num_row = total_row // self.number_of_para_subtasks
126-
# Calculate start row and end row position for current shard in all data
127-
start_row = self.idx_of_this_subtask * (total_row // self.number_of_para_subtasks)
130+
num_row = base_rows_per_shard
131+
start_row = (remainder * (base_rows_per_shard + 1) +
132+
(self.idx_of_this_subtask - remainder) * base_rows_per_shard)
133+
128134
end_row = start_row + num_row
129135

130136
plan_start_row = 0
@@ -160,24 +166,25 @@ def _append_only_filter_by_shard(self, partitioned_files: defaultdict) -> (defau
160166

161167
def _data_evolution_filter_by_shard(self, partitioned_files: defaultdict) -> (defaultdict, int, int):
162168
total_row = 0
163-
first_row_id_set = set()
164-
# Sort by file creation time to ensure consistent sharding
165169
for key, file_entries in partitioned_files.items():
166170
for entry in file_entries:
167-
if entry.file.first_row_id is None:
168-
total_row += entry.file.row_count
169-
elif entry.file.first_row_id not in first_row_id_set:
170-
first_row_id_set.add(entry.file.first_row_id)
171+
if not self._is_blob_file(entry.file.file_name):
171172
total_row += entry.file.row_count
172173

173-
# Calculate number of rows this shard should process
174-
# Last shard handles all remaining rows (handles non-divisible cases)
175-
if self.idx_of_this_subtask == self.number_of_para_subtasks - 1:
176-
num_row = total_row - total_row // self.number_of_para_subtasks * self.idx_of_this_subtask
174+
# Calculate number of rows this shard should process using balanced distribution
175+
# Distribute remainder evenly among first few shards to avoid last shard overload
176+
base_rows_per_shard = total_row // self.number_of_para_subtasks
177+
remainder = total_row % self.number_of_para_subtasks
178+
179+
# Each of the first 'remainder' shards gets one extra row
180+
if self.idx_of_this_subtask < remainder:
181+
num_row = base_rows_per_shard + 1
182+
start_row = self.idx_of_this_subtask * (base_rows_per_shard + 1)
177183
else:
178-
num_row = total_row // self.number_of_para_subtasks
179-
# Calculate start row and end row position for current shard in all data
180-
start_row = self.idx_of_this_subtask * (total_row // self.number_of_para_subtasks)
184+
num_row = base_rows_per_shard
185+
start_row = (remainder * (base_rows_per_shard + 1) +
186+
(self.idx_of_this_subtask - remainder) * base_rows_per_shard)
187+
181188
end_row = start_row + num_row
182189

183190
plan_start_row = 0
@@ -188,14 +195,13 @@ def _data_evolution_filter_by_shard(self, partitioned_files: defaultdict) -> (de
188195
# Iterate through all file entries to find files that overlap with current shard range
189196
for key, file_entries in partitioned_files.items():
190197
filtered_entries = []
191-
first_row_id_set = set()
198+
blob_added = False # If it is true, all blobs corresponding to this data file are added
192199
for entry in file_entries:
193-
if entry.file.first_row_id is not None:
194-
if entry.file.first_row_id in first_row_id_set:
200+
if self._is_blob_file(entry.file.file_name):
201+
if blob_added:
195202
filtered_entries.append(entry)
196-
continue
197-
else:
198-
first_row_id_set.add(entry.file.first_row_id)
203+
continue
204+
blob_added = False
199205
entry_begin_row = entry_end_row # Starting row position of current file in all data
200206
entry_end_row += entry.file.row_count # Update to row position after current file
201207

@@ -213,18 +219,24 @@ def _data_evolution_filter_by_shard(self, partitioned_files: defaultdict) -> (de
213219
plan_end_row = end_row - splits_start_row
214220
# Add files that overlap with shard range to result
215221
filtered_entries.append(entry)
222+
blob_added = True
216223
if filtered_entries:
217224
filtered_partitioned_files[key] = filtered_entries
218225

219226
return filtered_partitioned_files, plan_start_row, plan_end_row
220227

221228
def _compute_split_start_end_row(self, splits: List[Split], plan_start_row, plan_end_row):
222229
file_end_row = 0 # end row position of current file in all data
230+
223231
for split in splits:
232+
row_cnt = 0
224233
files = split.files
225234
split_start_row = file_end_row
226235
# Iterate through all file entries to find files that overlap with current shard range
227236
for file in files:
237+
if self._is_blob_file(file.file_name):
238+
continue
239+
row_cnt += file.row_count
228240
file_begin_row = file_end_row # Starting row position of current file in all data
229241
file_end_row += file.row_count # Update to row position after current file
230242

@@ -238,7 +250,7 @@ def _compute_split_start_end_row(self, splits: List[Split], plan_start_row, plan
238250
if split.split_start_row is None:
239251
split.split_start_row = 0
240252
if split.split_end_row is None:
241-
split.split_end_row = split.row_count
253+
split.split_end_row = row_cnt
242254

243255
def _primary_key_filter_by_shard(self, file_entries: List[ManifestEntry]) -> List[ManifestEntry]:
244256
filtered_entries = []
@@ -359,61 +371,20 @@ def weight_func(fl: List[DataFileMeta]) -> int:
359371
splits += self._build_split_from_pack(flatten_packed_files, file_entries, True)
360372
return splits
361373

362-
def _build_split_from_pack(self, packed_files, file_entries, for_primary_key_split: bool) -> List['Split']:
363-
splits = []
364-
for file_group in packed_files:
365-
raw_convertible = True
366-
if for_primary_key_split:
367-
raw_convertible = len(file_group) == 1
368-
369-
file_paths = []
370-
total_file_size = 0
371-
total_record_count = 0
372-
373-
for data_file in file_group:
374-
data_file.set_file_path(self.table.table_path, file_entries[0].partition,
375-
file_entries[0].bucket)
376-
file_paths.append(data_file.file_path)
377-
total_file_size += data_file.file_size
378-
total_record_count += data_file.row_count
379-
380-
if file_paths:
381-
split = Split(
382-
files=file_group,
383-
partition=file_entries[0].partition,
384-
bucket=file_entries[0].bucket,
385-
_file_paths=file_paths,
386-
_row_count=total_record_count,
387-
_file_size=total_file_size,
388-
raw_convertible=raw_convertible
389-
)
390-
splits.append(split)
391-
return splits
392-
393-
@staticmethod
394-
def _pack_for_ordered(items: List, weight_func: Callable, target_weight: int) -> List[List]:
395-
packed = []
396-
bin_items = []
397-
bin_weight = 0
398-
399-
for item in items:
400-
weight = weight_func(item)
401-
if bin_weight + weight > target_weight and len(bin_items) > 0:
402-
packed.append(list(bin_items))
403-
bin_items.clear()
404-
bin_weight = 0
405-
406-
bin_weight += weight
407-
bin_items.append(item)
408-
409-
if len(bin_items) > 0:
410-
packed.append(bin_items)
374+
def _create_data_evolution_splits(self, file_entries: List[ManifestEntry]) -> List['Split']:
375+
def sort_key(manifest_entry: ManifestEntry) -> tuple:
376+
first_row_id = manifest_entry.file.first_row_id if manifest_entry.file.first_row_id is not None else float(
377+
'-inf')
378+
is_blob = 1 if self._is_blob_file(manifest_entry.file.file_name) else 0
379+
# For files with same firstRowId, sort by maxSequenceNumber in descending order
380+
# (larger sequence number means more recent data)
381+
max_seq = manifest_entry.file.max_sequence_number
382+
return first_row_id, is_blob, -max_seq
411383

412-
return packed
384+
sorted_entries = sorted(file_entries, key=sort_key)
413385

414-
def _create_data_evolution_splits(self, file_entries: List[ManifestEntry]) -> List['Split']:
415386
partitioned_files = defaultdict(list)
416-
for entry in file_entries:
387+
for entry in sorted_entries:
417388
partitioned_files[(tuple(entry.partition.values), entry.bucket)].append(entry)
418389

419390
if self.idx_of_this_subtask is not None:
@@ -423,11 +394,11 @@ def weight_func(file_list: List[DataFileMeta]) -> int:
423394
return max(sum(f.file_size for f in file_list), self.open_file_cost)
424395

425396
splits = []
426-
for key, file_entries in partitioned_files.items():
427-
if not file_entries:
397+
for key, sorted_entries in partitioned_files.items():
398+
if not sorted_entries:
428399
continue
429400

430-
data_files: List[DataFileMeta] = [e.file for e in file_entries]
401+
data_files: List[DataFileMeta] = [e.file for e in sorted_entries]
431402

432403
# Split files by firstRowId for data evolution
433404
split_by_row_id = self._split_by_row_id(data_files)
@@ -442,7 +413,7 @@ def weight_func(file_list: List[DataFileMeta]) -> int:
442413
for pack in packed_files
443414
]
444415

445-
splits += self._build_split_from_pack(flatten_packed_files, file_entries, False)
416+
splits += self._build_split_from_pack(flatten_packed_files, sorted_entries, False)
446417

447418
if self.idx_of_this_subtask is not None:
448419
self._compute_split_start_end_row(splits, plan_start_row, plan_end_row)
@@ -451,18 +422,8 @@ def weight_func(file_list: List[DataFileMeta]) -> int:
451422
def _split_by_row_id(self, files: List[DataFileMeta]) -> List[List[DataFileMeta]]:
452423
split_by_row_id = []
453424

454-
def sort_key(file: DataFileMeta) -> tuple:
455-
first_row_id = file.first_row_id if file.first_row_id is not None else float('-inf')
456-
is_blob = 1 if self._is_blob_file(file.file_name) else 0
457-
# For files with same firstRowId, sort by maxSequenceNumber in descending order
458-
# (larger sequence number means more recent data)
459-
max_seq = file.max_sequence_number
460-
return (first_row_id, is_blob, -max_seq)
461-
462-
sorted_files = sorted(files, key=sort_key)
463-
464425
# Filter blob files to only include those within the row ID range of non-blob files
465-
sorted_files = self._filter_blob(sorted_files)
426+
sorted_files = self._filter_blob(files)
466427

467428
# Split files by firstRowId
468429
last_row_id = -1
@@ -499,6 +460,58 @@ def sort_key(file: DataFileMeta) -> tuple:
499460

500461
return split_by_row_id
501462

463+
def _build_split_from_pack(self, packed_files, file_entries, for_primary_key_split: bool) -> List['Split']:
464+
splits = []
465+
for file_group in packed_files:
466+
raw_convertible = True
467+
if for_primary_key_split:
468+
raw_convertible = len(file_group) == 1
469+
470+
file_paths = []
471+
total_file_size = 0
472+
total_record_count = 0
473+
474+
for data_file in file_group:
475+
data_file.set_file_path(self.table.table_path, file_entries[0].partition,
476+
file_entries[0].bucket)
477+
file_paths.append(data_file.file_path)
478+
total_file_size += data_file.file_size
479+
total_record_count += data_file.row_count
480+
481+
if file_paths:
482+
split = Split(
483+
files=file_group,
484+
partition=file_entries[0].partition,
485+
bucket=file_entries[0].bucket,
486+
_file_paths=file_paths,
487+
_row_count=total_record_count,
488+
_file_size=total_file_size,
489+
raw_convertible=raw_convertible
490+
)
491+
splits.append(split)
492+
return splits
493+
494+
@staticmethod
495+
def _pack_for_ordered(items: List, weight_func: Callable, target_weight: int) -> List[List]:
496+
packed = []
497+
bin_items = []
498+
bin_weight = 0
499+
500+
for item in items:
501+
weight = weight_func(item)
502+
if bin_weight + weight > target_weight and len(bin_items) > 0:
503+
packed.append(list(bin_items))
504+
bin_items.clear()
505+
bin_weight = 0
506+
507+
bin_weight += weight
508+
bin_items.append(item)
509+
510+
if len(bin_items) > 0:
511+
packed.append(bin_items)
512+
513+
return packed
514+
502515
@staticmethod
503516
def _is_blob_file(file_name: str) -> bool:
504517
return file_name.endswith('.blob')

paimon-python/pypaimon/read/table_read.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ def to_arrow(self, splits: List[Split]) -> Optional[pyarrow.Table]:
8080
schema = PyarrowFieldParser.from_paimon_schema(self.read_type)
8181
table_list = []
8282
for batch in iter(batch_reader.read_next_batch, None):
83+
if batch.num_rows == 0:
84+
continue
8385
table_list.append(self._try_to_pad_batch_by_schema(batch, schema))
8486

8587
if not table_list:

0 commit comments

Comments (0)