@@ -117,14 +117,20 @@ def _append_only_filter_by_shard(self, partitioned_files: defaultdict) -> (defau
117117 for entry in file_entries :
118118 total_row += entry .file .row_count
119119
120- # Calculate number of rows this shard should process
121- # Last shard handles all remaining rows (handles non-divisible cases)
122- if self .idx_of_this_subtask == self .number_of_para_subtasks - 1 :
123- num_row = total_row - total_row // self .number_of_para_subtasks * self .idx_of_this_subtask
120+ # Calculate number of rows this shard should process using balanced distribution
121+ # Distribute remainder evenly among first few shards to avoid last shard overload
122+ base_rows_per_shard = total_row // self .number_of_para_subtasks
123+ remainder = total_row % self .number_of_para_subtasks
124+
125+ # Each of the first 'remainder' shards gets one extra row
126+ if self .idx_of_this_subtask < remainder :
127+ num_row = base_rows_per_shard + 1
128+ start_row = self .idx_of_this_subtask * (base_rows_per_shard + 1 )
124129 else :
125- num_row = total_row // self .number_of_para_subtasks
126- # Calculate start row and end row position for current shard in all data
127- start_row = self .idx_of_this_subtask * (total_row // self .number_of_para_subtasks )
130+ num_row = base_rows_per_shard
131+ start_row = (remainder * (base_rows_per_shard + 1 ) +
132+ (self .idx_of_this_subtask - remainder ) * base_rows_per_shard )
133+
128134 end_row = start_row + num_row
129135
130136 plan_start_row = 0
@@ -160,24 +166,25 @@ def _append_only_filter_by_shard(self, partitioned_files: defaultdict) -> (defau
160166
161167 def _data_evolution_filter_by_shard (self , partitioned_files : defaultdict ) -> (defaultdict , int , int ):
162168 total_row = 0
163- first_row_id_set = set ()
164- # Sort by file creation time to ensure consistent sharding
165169 for key , file_entries in partitioned_files .items ():
166170 for entry in file_entries :
167- if entry .file .first_row_id is None :
168- total_row += entry .file .row_count
169- elif entry .file .first_row_id not in first_row_id_set :
170- first_row_id_set .add (entry .file .first_row_id )
171+ if not self ._is_blob_file (entry .file .file_name ):
171172 total_row += entry .file .row_count
172173
173- # Calculate number of rows this shard should process
174- # Last shard handles all remaining rows (handles non-divisible cases)
175- if self .idx_of_this_subtask == self .number_of_para_subtasks - 1 :
176- num_row = total_row - total_row // self .number_of_para_subtasks * self .idx_of_this_subtask
174+ # Calculate number of rows this shard should process using balanced distribution
175+ # Distribute remainder evenly among first few shards to avoid last shard overload
176+ base_rows_per_shard = total_row // self .number_of_para_subtasks
177+ remainder = total_row % self .number_of_para_subtasks
178+
179+ # Each of the first 'remainder' shards gets one extra row
180+ if self .idx_of_this_subtask < remainder :
181+ num_row = base_rows_per_shard + 1
182+ start_row = self .idx_of_this_subtask * (base_rows_per_shard + 1 )
177183 else :
178- num_row = total_row // self .number_of_para_subtasks
179- # Calculate start row and end row position for current shard in all data
180- start_row = self .idx_of_this_subtask * (total_row // self .number_of_para_subtasks )
184+ num_row = base_rows_per_shard
185+ start_row = (remainder * (base_rows_per_shard + 1 ) +
186+ (self .idx_of_this_subtask - remainder ) * base_rows_per_shard )
187+
181188 end_row = start_row + num_row
182189
183190 plan_start_row = 0
@@ -188,14 +195,13 @@ def _data_evolution_filter_by_shard(self, partitioned_files: defaultdict) -> (de
188195 # Iterate through all file entries to find files that overlap with current shard range
189196 for key , file_entries in partitioned_files .items ():
190197 filtered_entries = []
191- first_row_id_set = set ()
198+ blob_added = False # If it is true, all blobs corresponding to this data file are added
192199 for entry in file_entries :
193- if entry .file .first_row_id is not None :
194- if entry . file . first_row_id in first_row_id_set :
200+ if self . _is_blob_file ( entry .file .file_name ) :
201+ if blob_added :
195202 filtered_entries .append (entry )
196- continue
197- else :
198- first_row_id_set .add (entry .file .first_row_id )
203+ continue
204+ blob_added = False
199205 entry_begin_row = entry_end_row # Starting row position of current file in all data
200206 entry_end_row += entry .file .row_count # Update to row position after current file
201207
@@ -213,18 +219,24 @@ def _data_evolution_filter_by_shard(self, partitioned_files: defaultdict) -> (de
213219 plan_end_row = end_row - splits_start_row
214220 # Add files that overlap with shard range to result
215221 filtered_entries .append (entry )
222+ blob_added = True
216223 if filtered_entries :
217224 filtered_partitioned_files [key ] = filtered_entries
218225
219226 return filtered_partitioned_files , plan_start_row , plan_end_row
220227
221228 def _compute_split_start_end_row (self , splits : List [Split ], plan_start_row , plan_end_row ):
222229 file_end_row = 0 # end row position of current file in all data
230+
223231 for split in splits :
232+ row_cnt = 0
224233 files = split .files
225234 split_start_row = file_end_row
226235 # Iterate through all file entries to find files that overlap with current shard range
227236 for file in files :
237+ if self ._is_blob_file (file .file_name ):
238+ continue
239+ row_cnt += file .row_count
228240 file_begin_row = file_end_row # Starting row position of current file in all data
229241 file_end_row += file .row_count # Update to row position after current file
230242
@@ -238,7 +250,7 @@ def _compute_split_start_end_row(self, splits: List[Split], plan_start_row, plan
238250 if split .split_start_row is None :
239251 split .split_start_row = 0
240252 if split .split_end_row is None :
241- split .split_end_row = split . row_count
253+ split .split_end_row = row_cnt
242254
243255 def _primary_key_filter_by_shard (self , file_entries : List [ManifestEntry ]) -> List [ManifestEntry ]:
244256 filtered_entries = []
@@ -359,61 +371,20 @@ def weight_func(fl: List[DataFileMeta]) -> int:
359371 splits += self ._build_split_from_pack (flatten_packed_files , file_entries , True )
360372 return splits
361373
362- def _build_split_from_pack (self , packed_files , file_entries , for_primary_key_split : bool ) -> List ['Split' ]:
363- splits = []
364- for file_group in packed_files :
365- raw_convertible = True
366- if for_primary_key_split :
367- raw_convertible = len (file_group ) == 1
368-
369- file_paths = []
370- total_file_size = 0
371- total_record_count = 0
372-
373- for data_file in file_group :
374- data_file .set_file_path (self .table .table_path , file_entries [0 ].partition ,
375- file_entries [0 ].bucket )
376- file_paths .append (data_file .file_path )
377- total_file_size += data_file .file_size
378- total_record_count += data_file .row_count
379-
380- if file_paths :
381- split = Split (
382- files = file_group ,
383- partition = file_entries [0 ].partition ,
384- bucket = file_entries [0 ].bucket ,
385- _file_paths = file_paths ,
386- _row_count = total_record_count ,
387- _file_size = total_file_size ,
388- raw_convertible = raw_convertible
389- )
390- splits .append (split )
391- return splits
392-
393- @staticmethod
394- def _pack_for_ordered (items : List , weight_func : Callable , target_weight : int ) -> List [List ]:
395- packed = []
396- bin_items = []
397- bin_weight = 0
398-
399- for item in items :
400- weight = weight_func (item )
401- if bin_weight + weight > target_weight and len (bin_items ) > 0 :
402- packed .append (list (bin_items ))
403- bin_items .clear ()
404- bin_weight = 0
405-
406- bin_weight += weight
407- bin_items .append (item )
408-
409- if len (bin_items ) > 0 :
410- packed .append (bin_items )
374+ def _create_data_evolution_splits (self , file_entries : List [ManifestEntry ]) -> List ['Split' ]:
375+ def sort_key (manifest_entry : ManifestEntry ) -> tuple :
376+ first_row_id = manifest_entry .file .first_row_id if manifest_entry .file .first_row_id is not None else float (
377+ '-inf' )
378+ is_blob = 1 if self ._is_blob_file (manifest_entry .file .file_name ) else 0
379+ # For files with same firstRowId, sort by maxSequenceNumber in descending order
380+ # (larger sequence number means more recent data)
381+ max_seq = manifest_entry .file .max_sequence_number
382+ return first_row_id , is_blob , - max_seq
411383
412- return packed
384+ sorted_entries = sorted ( file_entries , key = sort_key )
413385
414- def _create_data_evolution_splits (self , file_entries : List [ManifestEntry ]) -> List ['Split' ]:
415386 partitioned_files = defaultdict (list )
416- for entry in file_entries :
387+ for entry in sorted_entries :
417388 partitioned_files [(tuple (entry .partition .values ), entry .bucket )].append (entry )
418389
419390 if self .idx_of_this_subtask is not None :
@@ -423,11 +394,11 @@ def weight_func(file_list: List[DataFileMeta]) -> int:
423394 return max (sum (f .file_size for f in file_list ), self .open_file_cost )
424395
425396 splits = []
426- for key , file_entries in partitioned_files .items ():
427- if not file_entries :
397+ for key , sorted_entries in partitioned_files .items ():
398+ if not sorted_entries :
428399 continue
429400
430- data_files : List [DataFileMeta ] = [e .file for e in file_entries ]
401+ data_files : List [DataFileMeta ] = [e .file for e in sorted_entries ]
431402
432403 # Split files by firstRowId for data evolution
433404 split_by_row_id = self ._split_by_row_id (data_files )
@@ -442,7 +413,7 @@ def weight_func(file_list: List[DataFileMeta]) -> int:
442413 for pack in packed_files
443414 ]
444415
445- splits += self ._build_split_from_pack (flatten_packed_files , file_entries , False )
416+ splits += self ._build_split_from_pack (flatten_packed_files , sorted_entries , False )
446417
447418 if self .idx_of_this_subtask is not None :
448419 self ._compute_split_start_end_row (splits , plan_start_row , plan_end_row )
@@ -451,18 +422,8 @@ def weight_func(file_list: List[DataFileMeta]) -> int:
451422 def _split_by_row_id (self , files : List [DataFileMeta ]) -> List [List [DataFileMeta ]]:
452423 split_by_row_id = []
453424
454- def sort_key (file : DataFileMeta ) -> tuple :
455- first_row_id = file .first_row_id if file .first_row_id is not None else float ('-inf' )
456- is_blob = 1 if self ._is_blob_file (file .file_name ) else 0
457- # For files with same firstRowId, sort by maxSequenceNumber in descending order
458- # (larger sequence number means more recent data)
459- max_seq = file .max_sequence_number
460- return (first_row_id , is_blob , - max_seq )
461-
462- sorted_files = sorted (files , key = sort_key )
463-
464425 # Filter blob files to only include those within the row ID range of non-blob files
465- sorted_files = self ._filter_blob (sorted_files )
426+ sorted_files = self ._filter_blob (files )
466427
467428 # Split files by firstRowId
468429 last_row_id = - 1
@@ -499,6 +460,58 @@ def sort_key(file: DataFileMeta) -> tuple:
499460
500461 return split_by_row_id
501462
463+ def _build_split_from_pack (self , packed_files , file_entries , for_primary_key_split : bool ) -> List ['Split' ]:
464+ splits = []
465+ for file_group in packed_files :
466+ raw_convertible = True
467+ if for_primary_key_split :
468+ raw_convertible = len (file_group ) == 1
469+
470+ file_paths = []
471+ total_file_size = 0
472+ total_record_count = 0
473+
474+ for data_file in file_group :
475+ data_file .set_file_path (self .table .table_path , file_entries [0 ].partition ,
476+ file_entries [0 ].bucket )
477+ file_paths .append (data_file .file_path )
478+ total_file_size += data_file .file_size
479+ total_record_count += data_file .row_count
480+
481+ if file_paths :
482+ split = Split (
483+ files = file_group ,
484+ partition = file_entries [0 ].partition ,
485+ bucket = file_entries [0 ].bucket ,
486+ _file_paths = file_paths ,
487+ _row_count = total_record_count ,
488+ _file_size = total_file_size ,
489+ raw_convertible = raw_convertible
490+ )
491+ splits .append (split )
492+ return splits
493+
494+ @staticmethod
495+ def _pack_for_ordered (items : List , weight_func : Callable , target_weight : int ) -> List [List ]:
496+ packed = []
497+ bin_items = []
498+ bin_weight = 0
499+
500+ for item in items :
501+ weight = weight_func (item )
502+ if bin_weight + weight > target_weight and len (bin_items ) > 0 :
503+ packed .append (list (bin_items ))
504+ bin_items .clear ()
505+ bin_weight = 0
506+
507+ bin_weight += weight
508+ bin_items .append (item )
509+
510+ if len (bin_items ) > 0 :
511+ packed .append (bin_items )
512+
513+ return packed
514+
502515 @staticmethod
503516 def _is_blob_file (file_name : str ) -> bool :
504517 return file_name .endswith ('.blob' )
0 commit comments