
Commit ed22ca7

[python] support paimon as ray datasource for distributed processing (#6686)
1 parent 76b8343 commit ed22ca7

8 files changed: +1228 lines added, -3 lines removed


docs/content/program-api/python-api.md

Lines changed: 16 additions & 1 deletion
@@ -355,7 +355,7 @@ print(duckdb_con.query("SELECT * FROM duckdb_table WHERE f0 = 1").fetchdf())
 
 This requires `ray` to be installed.
 
-You can convert the splits into a Ray dataset and handle it by Ray API:
+You can convert the splits into a Ray Dataset and handle it by Ray Data API for distributed processing:
 
 ```python
 table_read = read_builder.new_read()
@@ -376,6 +376,21 @@ print(ray_dataset.to_pandas())
 # ...
 ```
 
+The `to_ray()` method supports a `parallelism` parameter to control distributed reading. Use `parallelism=1` for single-task read (default) or `parallelism > 1` for distributed read with multiple Ray workers:
+
+```python
+# Simple mode (single task)
+ray_dataset = table_read.to_ray(splits, parallelism=1)
+
+# Distributed mode with 4 parallel tasks
+ray_dataset = table_read.to_ray(splits, parallelism=4)
+
+# Use Ray Data operations
+mapped_dataset = ray_dataset.map(lambda row: {'value': row['value'] * 2})
+filtered_dataset = ray_dataset.filter(lambda row: row['score'] > 80)
+df = ray_dataset.to_pandas()
+```
+
 ### Incremental Read
 
 This API allows reading data committed between two snapshot timestamps. The steps are as follows.
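
The distributed path composes with a normal Ray runtime. A minimal sketch, assuming `table_read` and `splits` were obtained as in the snippet above; the `num_cpus` value is illustrative:

```python
import ray

# Start (or connect to) a local Ray runtime; num_cpus is illustrative.
ray.init(ignore_reinit_error=True, num_cpus=4)

# Distributed read: splits are balanced across up to 4 read tasks.
ray_dataset = table_read.to_ray(splits, parallelism=4)

# Standard Ray Data operations work on the resulting Dataset.
print(ray_dataset.count())
print(ray_dataset.schema())
print(ray_dataset.take(5))
```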

paimon-python/dev/requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@ polars==1.8.0; python_version=="3.8"
 polars==1.32.0; python_version>"3.8"
 pyarrow==6.0.1; python_version < "3.8"
 pyarrow==16; python_version >= "3.8"
+ray==2.48.0
 readerwriterlock==1.0.9
 zstandard==0.19.0; python_version<"3.9"
 zstandard==0.24.0; python_version>="3.9"

paimon-python/pypaimon/catalog/rest/rest_token_file_io.py

Lines changed: 15 additions & 0 deletions
@@ -40,6 +40,21 @@ def __init__(self, identifier: Identifier, path: str,
         self.log = logging.getLogger(__name__)
         super().__init__(path, catalog_options)
 
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        # Remove non-serializable objects
+        state.pop('lock', None)
+        state.pop('api_instance', None)
+        # token can be serialized, but we'll refresh it on deserialization
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        # Recreate lock after deserialization
+        self.lock = threading.Lock()
+        # api_instance will be recreated when needed
+        self.api_instance = None
+
     def _initialize_oss_fs(self, path) -> FileSystem:
         self.try_to_refresh_token()
         self.properties.update(self.token.token)
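
For context, `__getstate__`/`__setstate__` is the standard pickling-hook pattern: drop members that cannot be serialized (locks, live API clients) before the object is shipped to a Ray worker, and rebuild them after deserialization. A minimal, self-contained sketch; the `TokenClient` class and its fields are hypothetical, not part of pypaimon:

```python
import pickle
import threading


class TokenClient:
    """Hypothetical client holding a lock and a lazily created API handle."""

    def __init__(self, endpoint: str):
        self.endpoint = endpoint
        self.lock = threading.Lock()   # not picklable
        self.api_instance = object()   # stand-in for a live connection

    def __getstate__(self):
        state = self.__dict__.copy()
        # Drop members that cannot cross process boundaries.
        state.pop('lock', None)
        state.pop('api_instance', None)
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Recreate the lock; the API handle is recreated lazily when needed.
        self.lock = threading.Lock()
        self.api_instance = None


client = TokenClient("https://example.invalid")
restored = pickle.loads(pickle.dumps(client))
assert restored.endpoint == "https://example.invalid"
assert restored.api_instance is None
```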

paimon-python/pypaimon/read/ray_datasource.py

Lines changed: 210 additions & 0 deletions
@@ -0,0 +1,210 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+"""
+Module to read a Paimon table into a Ray Dataset, by using the Ray Datasource API.
+"""
+import heapq
+import itertools
+import logging
+from functools import partial
+from typing import List, Optional, Iterable
+
+import pyarrow
+
+from pypaimon.read.split import Split
+from pypaimon.read.table_read import TableRead
+from pypaimon.schema.data_types import PyarrowFieldParser
+
+logger = logging.getLogger(__name__)
+
+from ray.data.datasource import Datasource
+
+
+class PaimonDatasource(Datasource):
+    """
+    Ray Data Datasource implementation for reading Paimon tables.
+
+    This datasource enables distributed parallel reading of Paimon table splits,
+    allowing Ray to read multiple splits concurrently across the cluster.
+    """
+
+    def __init__(self, table_read: TableRead, splits: List[Split]):
+        """
+        Initialize PaimonDatasource.
+
+        Args:
+            table_read: TableRead instance for reading data
+            splits: List of splits to read
+        """
+        self.table_read = table_read
+        self.splits = splits
+        self._schema = None
+
+    def get_name(self) -> str:
+        identifier = self.table_read.table.identifier
+        table_name = identifier.get_full_name() if hasattr(identifier, 'get_full_name') else str(identifier)
+        return f"PaimonTable({table_name})"
+
+    def estimate_inmemory_data_size(self) -> Optional[int]:
+        if not self.splits:
+            return 0
+
+        # Sum up file sizes from all splits
+        total_size = sum(split.file_size for split in self.splits)
+        return total_size if total_size > 0 else None
+
+    @staticmethod
+    def _distribute_splits_into_equal_chunks(
+        splits: Iterable[Split], n_chunks: int
+    ) -> List[List[Split]]:
+        """
+        Implement a greedy knapsack algorithm to distribute the splits across tasks,
+        based on their file size, as evenly as possible.
+        """
+        chunks = [list() for _ in range(n_chunks)]
+        chunk_sizes = [(0, chunk_id) for chunk_id in range(n_chunks)]
+        heapq.heapify(chunk_sizes)
+
+        # From largest to smallest, add the splits to the smallest chunk one at a time
+        for split in sorted(
+            splits, key=lambda s: s.file_size if hasattr(s, 'file_size') and s.file_size > 0 else 0, reverse=True
+        ):
+            smallest_chunk = heapq.heappop(chunk_sizes)
+            chunks[smallest_chunk[1]].append(split)
+            split_size = split.file_size if hasattr(split, 'file_size') and split.file_size > 0 else 0
+            heapq.heappush(
+                chunk_sizes,
+                (smallest_chunk[0] + split_size, smallest_chunk[1]),
+            )
+
+        return chunks
+
+    def get_read_tasks(self, parallelism: int) -> List:
+        """Return a list of read tasks that can be executed in parallel."""
+        from ray.data.datasource import ReadTask
+        from ray.data.block import BlockMetadata
+
+        # Validate parallelism parameter
+        if parallelism < 1:
+            raise ValueError(f"parallelism must be at least 1, got {parallelism}")
+
+        # Get schema for metadata
+        if self._schema is None:
+            self._schema = PyarrowFieldParser.from_paimon_schema(self.table_read.read_type)
+
+        # Adjust parallelism if it exceeds the number of splits
+        if parallelism > len(self.splits):
+            parallelism = len(self.splits)
+            logger.warning(
+                f"Reducing the parallelism to {parallelism}, as that is the number of splits"
+            )
+
+        # Store necessary information for creating readers in Ray workers
+        # Extract these to avoid serializing the entire self object in closures
+        table = self.table_read.table
+        predicate = self.table_read.predicate
+        read_type = self.table_read.read_type
+        schema = self._schema
+
+        # Create a partial function to avoid capturing self in closure
+        # This reduces serialization overhead (see https://github.com/ray-project/ray/issues/49107)
+        def _get_read_task(
+            splits: List[Split],
+            table=table,
+            predicate=predicate,
+            read_type=read_type,
+            schema=schema,
+        ) -> Iterable[pyarrow.Table]:
+            """Read function that will be executed by Ray workers."""
+            from pypaimon.read.table_read import TableRead
+            worker_table_read = TableRead(table, predicate, read_type)
+
+            # Read all splits in this chunk
+            arrow_table = worker_table_read.to_arrow(splits)
+
+            # Return as a list to allow Ray to split into multiple blocks if needed
+            if arrow_table is not None and arrow_table.num_rows > 0:
+                return [arrow_table]
+            else:
+                # Return empty table with correct schema
+                empty_table = pyarrow.Table.from_arrays(
+                    [pyarrow.array([], type=field.type) for field in schema],
+                    schema=schema
+                )
+                return [empty_table]
+
+        # Use partial to create read function without capturing self
+        get_read_task = partial(
+            _get_read_task,
+            table=table,
+            predicate=predicate,
+            read_type=read_type,
+            schema=schema,
+        )
+
+        read_tasks = []
+
+        # Distribute splits across tasks using load balancing algorithm
+        for chunk_splits in self._distribute_splits_into_equal_chunks(self.splits, parallelism):
+            if not chunk_splits:
+                continue
+
+            # Calculate metadata for this chunk
+            total_rows = 0
+            total_size = 0
+
+            for split in chunk_splits:
+                if predicate is None:
+                    # Only estimate rows if no predicate (predicate filtering changes row count)
+                    if hasattr(split, 'row_count') and split.row_count > 0:
+                        total_rows += split.row_count
+                if hasattr(split, 'file_size') and split.file_size > 0:
+                    total_size += split.file_size
+
+            input_files = list(itertools.chain.from_iterable(
+                split.file_paths
+                for split in chunk_splits
+                if hasattr(split, 'file_paths') and split.file_paths
+            ))
+
+            # For PrimaryKey tables, we can't accurately estimate num_rows before merge
+            if table and table.is_primary_key_table:
+                num_rows = None  # Let Ray calculate actual row count after merge
+            elif predicate is not None:
+                num_rows = None  # Can't estimate with predicate filtering
+            else:
+                num_rows = total_rows if total_rows > 0 else None
+            size_bytes = total_size if total_size > 0 else None
+
+            metadata = BlockMetadata(
+                num_rows=num_rows,
+                size_bytes=size_bytes,
+                input_files=input_files if input_files else None,
+                exec_stats=None,  # Will be populated by Ray during execution
+            )
+
+            # TODO: per_task_row_limit is not supported in Ray 2.48.0, will be added in future versions
+            read_tasks.append(
+                ReadTask(
+                    read_fn=lambda splits=chunk_splits: get_read_task(splits),
+                    metadata=metadata,
+                    schema=schema,
+                )
+            )
+
+        return read_tasks
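
To make the balancing strategy in `_distribute_splits_into_equal_chunks` concrete, here is a minimal standalone sketch of the same heap-based greedy algorithm, run on made-up file sizes (the `sizes` values are purely illustrative):

```python
import heapq
from typing import List


def distribute(sizes: List[int], n_chunks: int) -> List[List[int]]:
    # Same idea as _distribute_splits_into_equal_chunks: always append the
    # next-largest item to the currently smallest chunk.
    chunks = [[] for _ in range(n_chunks)]
    heap = [(0, i) for i in range(n_chunks)]
    heapq.heapify(heap)
    for size in sorted(sizes, reverse=True):
        total, idx = heapq.heappop(heap)
        chunks[idx].append(size)
        heapq.heappush(heap, (total + size, idx))
    return chunks


# Hypothetical split sizes in bytes, distributed over 4 read tasks.
print(distribute([900, 700, 400, 300, 200, 100], 4))
# e.g. [[900], [700], [400, 100], [300, 200]]
```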

paimon-python/pypaimon/read/table_read.py

Lines changed: 22 additions & 2 deletions
@@ -128,10 +128,30 @@ def to_duckdb(self, splits: List[Split], table_name: str,
         con.register(table_name, self.to_arrow(splits))
         return con
 
-    def to_ray(self, splits: List[Split]) -> "ray.data.dataset.Dataset":
+    def to_ray(self, splits: List[Split], parallelism: int = 1) -> "ray.data.dataset.Dataset":
+        """Convert Paimon table data to Ray Dataset."""
         import ray
 
-        return ray.data.from_arrow(self.to_arrow(splits))
+        if not splits:
+            schema = PyarrowFieldParser.from_paimon_schema(self.read_type)
+            empty_table = pyarrow.Table.from_arrays(
+                [pyarrow.array([], type=field.type) for field in schema],
+                schema=schema
+            )
+            return ray.data.from_arrow(empty_table)
+
+        # Validate parallelism parameter
+        if parallelism < 1:
+            raise ValueError(f"parallelism must be at least 1, got {parallelism}")
+
+        if parallelism == 1:
+            # Single-task read (simple mode)
+            return ray.data.from_arrow(self.to_arrow(splits))
+        else:
+            # Distributed read with specified parallelism
+            from pypaimon.read.ray_datasource import PaimonDatasource
+            datasource = PaimonDatasource(self, splits)
+            return ray.data.read_datasource(datasource, parallelism=parallelism)
 
     def _create_split_read(self, split: Split) -> SplitRead:
         if self.table.is_primary_key_table and not split.raw_convertible:
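
Putting the new `parallelism` argument in context, a hedged usage sketch: the split-planning calls and the `'value'` column are assumptions based on the documented pypaimon read flow, not taken verbatim from this commit:

```python
# Assumed pypaimon read flow; exact split-planning calls may differ.
table_read = read_builder.new_read()
splits = read_builder.new_scan().plan().splits()

# parallelism == 1 (default): single-task read via ray.data.from_arrow().
local_ds = table_read.to_ray(splits)

# parallelism > 1: distributed read via PaimonDatasource and
# ray.data.read_datasource(), one read task per balanced chunk of splits.
dist_ds = table_read.to_ray(splits, parallelism=4)

# Illustrative Ray Data transformation; 'value' is a hypothetical column.
print(dist_ds.map(lambda row: {'value': row['value'] * 2}).to_pandas())
```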
