################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
"""
Module to read a Paimon table into a Ray Dataset using the Ray Datasource API.
"""
import heapq
import itertools
import logging
from functools import partial
from typing import Iterable, List, Optional

import pyarrow
from ray.data.datasource import Datasource

from pypaimon.read.split import Split
from pypaimon.read.table_read import TableRead
from pypaimon.schema.data_types import PyarrowFieldParser

logger = logging.getLogger(__name__)


class PaimonDatasource(Datasource):
    """
    Ray Data Datasource implementation for reading Paimon tables.

    This datasource enables distributed parallel reading of Paimon table splits,
    allowing Ray to read multiple splits concurrently across the cluster.
    """
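    # Illustrative usage (a sketch, assuming the standard Ray Data API): an instance
    # of this class is typically consumed through ray.data.read_datasource, e.g.
    #
    #     import ray
    #     ds = ray.data.read_datasource(PaimonDatasource(table_read, splits))
    #     ds.show(5)
    #
    # See the end of this module for a sketch of how `table_read` and `splits`
    # might be obtained from pypaimon.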

    def __init__(self, table_read: TableRead, splits: List[Split]):
        """
        Initialize PaimonDatasource.

        Args:
            table_read: TableRead instance for reading data
            splits: List of splits to read
        """
        self.table_read = table_read
        self.splits = splits
        self._schema = None

    def get_name(self) -> str:
        identifier = self.table_read.table.identifier
        table_name = identifier.get_full_name() if hasattr(identifier, 'get_full_name') else str(identifier)
        return f"PaimonTable({table_name})"

    def estimate_inmemory_data_size(self) -> Optional[int]:
        if not self.splits:
            return 0

        # Sum up file sizes from all splits
        total_size = sum(split.file_size for split in self.splits)
        return total_size if total_size > 0 else None

    @staticmethod
    def _distribute_splits_into_equal_chunks(
        splits: Iterable[Split], n_chunks: int
    ) -> List[List[Split]]:
        """
        Greedily distribute the splits across chunks (one chunk per read task) by
        file size, assigning each split, largest first, to the currently smallest
        chunk, so that the chunk totals are as even as possible.
        """
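        # Illustrative example (not from the codebase): splits with file sizes
        # [9, 7, 5, 3] distributed into n_chunks=2 are assigned largest-first to
        # the currently smallest chunk: 9 -> chunk 0, 7 -> chunk 1, 5 -> chunk 1,
        # 3 -> chunk 0, giving two chunks of total size 12 each.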
        chunks = [list() for _ in range(n_chunks)]
        chunk_sizes = [(0, chunk_id) for chunk_id in range(n_chunks)]
        heapq.heapify(chunk_sizes)

        # From largest to smallest, add the splits to the smallest chunk one at a time
        for split in sorted(
            splits, key=lambda s: s.file_size if hasattr(s, 'file_size') and s.file_size > 0 else 0, reverse=True
        ):
            smallest_chunk = heapq.heappop(chunk_sizes)
            chunks[smallest_chunk[1]].append(split)
            split_size = split.file_size if hasattr(split, 'file_size') and split.file_size > 0 else 0
            heapq.heappush(
                chunk_sizes,
                (smallest_chunk[0] + split_size, smallest_chunk[1]),
            )

        return chunks

    def get_read_tasks(self, parallelism: int) -> List:
        """Return a list of read tasks that can be executed in parallel."""
        from ray.data.datasource import ReadTask
        from ray.data.block import BlockMetadata

        # Validate parallelism parameter
        if parallelism < 1:
            raise ValueError(f"parallelism must be at least 1, got {parallelism}")

        # Get schema for metadata
        if self._schema is None:
            self._schema = PyarrowFieldParser.from_paimon_schema(self.table_read.read_type)

        # Adjust parallelism if it exceeds the number of splits
        if parallelism > len(self.splits):
            parallelism = len(self.splits)
            logger.warning(
                f"Reducing the parallelism to {parallelism}, as that is the number of splits"
            )

        # Store necessary information for creating readers in Ray workers
        # Extract these to avoid serializing the entire self object in closures
        table = self.table_read.table
        predicate = self.table_read.predicate
        read_type = self.table_read.read_type
        schema = self._schema

        # Create a partial function to avoid capturing self in closure
        # This reduces serialization overhead (see https://github.com/ray-project/ray/issues/49107)
        def _get_read_task(
            splits: List[Split],
            table=table,
            predicate=predicate,
            read_type=read_type,
            schema=schema,
        ) -> Iterable[pyarrow.Table]:
            """Read function that will be executed by Ray workers."""
            from pypaimon.read.table_read import TableRead
            worker_table_read = TableRead(table, predicate, read_type)

            # Read all splits in this chunk
            arrow_table = worker_table_read.to_arrow(splits)

            # Return as a list to allow Ray to split into multiple blocks if needed
            if arrow_table is not None and arrow_table.num_rows > 0:
                return [arrow_table]
            else:
                # Return empty table with correct schema
                empty_table = pyarrow.Table.from_arrays(
                    [pyarrow.array([], type=field.type) for field in schema],
                    schema=schema
                )
                return [empty_table]

        # Use partial to create read function without capturing self
        get_read_task = partial(
            _get_read_task,
            table=table,
            predicate=predicate,
            read_type=read_type,
            schema=schema,
        )

        read_tasks = []

        # Distribute splits across tasks using load balancing algorithm
        for chunk_splits in self._distribute_splits_into_equal_chunks(self.splits, parallelism):
            if not chunk_splits:
                continue

            # Calculate metadata for this chunk
            total_rows = 0
            total_size = 0

            for split in chunk_splits:
                if predicate is None:
                    # Only estimate rows if no predicate (predicate filtering changes row count)
                    if hasattr(split, 'row_count') and split.row_count > 0:
                        total_rows += split.row_count
                if hasattr(split, 'file_size') and split.file_size > 0:
                    total_size += split.file_size

            input_files = list(itertools.chain.from_iterable(
                split.file_paths
                for split in chunk_splits
                if hasattr(split, 'file_paths') and split.file_paths
            ))

            # For PrimaryKey tables, we can't accurately estimate num_rows before merge
            if table and table.is_primary_key_table:
                num_rows = None  # Let Ray calculate actual row count after merge
            elif predicate is not None:
                num_rows = None  # Can't estimate with predicate filtering
            else:
                num_rows = total_rows if total_rows > 0 else None
            size_bytes = total_size if total_size > 0 else None

            metadata = BlockMetadata(
                num_rows=num_rows,
                size_bytes=size_bytes,
                input_files=input_files if input_files else None,
                exec_stats=None,  # Will be populated by Ray during execution
            )

            # TODO: per_task_row_limit is not supported in Ray 2.48.0, will be added in future versions
            read_tasks.append(
                ReadTask(
                    read_fn=lambda splits=chunk_splits: get_read_task(splits),
                    metadata=metadata,
                    schema=schema,
                )
            )

        return read_tasks
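
# Example wiring (a sketch, assuming the usual pypaimon read-builder API; `table`
# is assumed to be a pypaimon Table obtained from a catalog and is not defined in
# this module):
#
#     read_builder = table.new_read_builder()
#     table_read = read_builder.new_read()
#     splits = read_builder.new_scan().plan().splits()
#
#     import ray
#     dataset = ray.data.read_datasource(PaimonDatasource(table_read, splits))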