Skip to content

Commit 51961ae

Browse files
committed
Fix #165, support for multi-byte characters in StringFileWrapper
1 parent 72de7f5 commit 51961ae

File tree

2 files changed

+154
-27
lines changed

2 files changed

+154
-27
lines changed

src/json_repair/utils/string_file_wrapper.py

Lines changed: 95 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,16 @@ def __init__(self, fd: TextIO, chunk_length: int) -> None:
1919
buffer_length (int): The length of each buffer chunk.
2020
"""
2121
self.fd = fd
22-
self.length: int = 0
23-
# Buffers are 1MB strings that are read from the file
24-
# and kept in memory to keep reads low
22+
# Buffers are chunks of text read from the file and cached to reduce disk access.
2523
self.buffers: dict[int, str] = {}
26-
# chunk_length is in bytes
2724
if not chunk_length or chunk_length < 2:
2825
chunk_length = 1_000_000
26+
# chunk_length now refers to the number of characters per chunk.
2927
self.buffer_length = chunk_length
28+
# Keep track of the starting file position ("cookie") for each chunk so we can
29+
# seek safely without landing in the middle of a multibyte code point.
30+
self._chunk_positions: list[int] = [0]
31+
self.length: int | None = None
3032

3133
def get_buffer(self, index: int) -> str:
3234
"""
@@ -38,15 +40,33 @@ def get_buffer(self, index: int) -> str:
3840
Returns:
3941
str: The buffer chunk at the specified index.
4042
"""
41-
if self.buffers.get(index) is None:
42-
self.fd.seek(index * self.buffer_length)
43-
self.buffers[index] = self.fd.read(self.buffer_length)
44-
# Save memory by keeping max 2MB buffer chunks and min 2 chunks
45-
if len(self.buffers) > max(2, 2_000_000 / self.buffer_length):
46-
oldest_key = next(iter(self.buffers))
47-
if oldest_key != index:
48-
self.buffers.pop(oldest_key)
49-
return self.buffers[index]
43+
if index < 0:
44+
raise IndexError("Negative indexing is not supported")
45+
46+
cached = self.buffers.get(index)
47+
if cached is not None:
48+
return cached
49+
50+
self._ensure_chunk_position(index)
51+
start_pos = self._chunk_positions[index]
52+
self.fd.seek(start_pos)
53+
chunk = self.fd.read(self.buffer_length)
54+
if not chunk:
55+
raise IndexError("Chunk index out of range")
56+
end_pos = self.fd.tell()
57+
if len(self._chunk_positions) <= index + 1:
58+
self._chunk_positions.append(end_pos)
59+
if len(chunk) < self.buffer_length:
60+
self.length = index * self.buffer_length + len(chunk)
61+
62+
self.buffers[index] = chunk
63+
# Save memory by keeping max 2MB buffer chunks and min 2 chunks
64+
max_buffers = max(2, int(2_000_000 / self.buffer_length))
65+
if len(self.buffers) > max_buffers:
66+
oldest_key = next(iter(self.buffers))
67+
if oldest_key != index:
68+
self.buffers.pop(oldest_key)
69+
return chunk
5070

5171
def __getitem__(self, index: int | slice) -> str:
5272
"""
@@ -62,18 +82,49 @@ def __getitem__(self, index: int | slice) -> str:
6282
# self.buffers[index]: the row in the array of length 1MB, index is `i` modulo CHUNK_LENGTH
6383
# self.buffures[index][j]: the column of the row that is `i` remainder CHUNK_LENGTH
6484
if isinstance(index, slice):
65-
buffer_index = index.start // self.buffer_length
66-
buffer_end = index.stop // self.buffer_length
85+
total_len = len(self)
86+
start = 0 if index.start is None else index.start
87+
stop = total_len if index.stop is None else index.stop
88+
step = 1 if index.step is None else index.step
89+
90+
if start < 0:
91+
start += total_len
92+
if stop < 0:
93+
stop += total_len
94+
95+
start = max(start, 0)
96+
stop = min(stop, total_len)
97+
98+
if step == 0:
99+
raise ValueError("slice step cannot be zero")
100+
if step != 1:
101+
return "".join(self[i] for i in range(start, stop, step))
102+
103+
if start >= stop:
104+
return ""
105+
106+
buffer_index = start // self.buffer_length
107+
buffer_end = (stop - 1) // self.buffer_length
108+
start_mod = start % self.buffer_length
109+
stop_mod = stop % self.buffer_length
110+
if stop_mod == 0 and stop > start:
111+
stop_mod = self.buffer_length
67112
if buffer_index == buffer_end:
68-
return self.get_buffer(buffer_index)[index.start % self.buffer_length : index.stop % self.buffer_length]
69-
else:
70-
start_slice = self.get_buffer(buffer_index)[index.start % self.buffer_length :]
71-
end_slice = self.get_buffer(buffer_end)[: index.stop % self.buffer_length]
72-
middle_slices = [self.get_buffer(i) for i in range(buffer_index + 1, buffer_end)]
73-
return start_slice + "".join(middle_slices) + end_slice
113+
buffer = self.get_buffer(buffer_index)
114+
return buffer[start_mod:stop_mod]
115+
116+
start_slice = self.get_buffer(buffer_index)[start_mod:]
117+
end_slice = self.get_buffer(buffer_end)[:stop_mod]
118+
middle_slices = [self.get_buffer(i) for i in range(buffer_index + 1, buffer_end)]
119+
return start_slice + "".join(middle_slices) + end_slice
74120
else:
121+
if index < 0:
122+
index += len(self)
123+
if index < 0:
124+
raise IndexError("string index out of range")
75125
buffer_index = index // self.buffer_length
76-
return self.get_buffer(buffer_index)[index % self.buffer_length]
126+
buffer = self.get_buffer(buffer_index)
127+
return buffer[index % self.buffer_length]
77128

78129
def __len__(self) -> int:
79130
"""
@@ -82,11 +133,10 @@ def __len__(self) -> int:
82133
Returns:
83134
int: The total number of characters in the file.
84135
"""
85-
if self.length < 1:
86-
current_position = self.fd.tell()
87-
self.fd.seek(0, os.SEEK_END)
88-
self.length = self.fd.tell()
89-
self.fd.seek(current_position)
136+
if self.length is None:
137+
while self.length is None:
138+
chunk_index = len(self._chunk_positions)
139+
self._ensure_chunk_position(chunk_index)
90140
return self.length
91141

92142
def __setitem__(self, index: int | slice, value: str) -> None: # pragma: no cover
@@ -106,3 +156,21 @@ def __setitem__(self, index: int | slice, value: str) -> None: # pragma: no cov
106156
self.fd.seek(start)
107157
self.fd.write(value)
108158
self.fd.seek(current_position)
159+
160+
def _ensure_chunk_position(self, chunk_index: int) -> None:
161+
"""
162+
Ensure that we know the starting file position for the given chunk index.
163+
"""
164+
while len(self._chunk_positions) <= chunk_index:
165+
prev_index = len(self._chunk_positions) - 1
166+
start_pos = self._chunk_positions[-1]
167+
self.fd.seek(start_pos, os.SEEK_SET)
168+
chunk = self.fd.read(self.buffer_length)
169+
end_pos = self.fd.tell()
170+
if len(chunk) < self.buffer_length:
171+
self.length = prev_index * self.buffer_length + len(chunk)
172+
self._chunk_positions.append(end_pos)
173+
if not chunk:
174+
break
175+
if len(self._chunk_positions) <= chunk_index:
176+
raise IndexError("Chunk index out of range")
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import pytest
2+
3+
from src.json_repair.utils.string_file_wrapper import StringFileWrapper
4+
5+
6+
def test_string_file_wrapper_handles_multibyte(tmp_path):
7+
text = "\u0800"
8+
file_path = tmp_path / "multibyte.json"
9+
file_path.write_text(text, encoding="utf-8")
10+
with file_path.open("r", encoding="utf-8") as handle:
11+
wrapper = StringFileWrapper(handle, chunk_length=2)
12+
assert wrapper[0:1] == text
13+
assert wrapper[0:2] == text
14+
assert wrapper[0] == text
15+
16+
17+
def test_string_file_wrapper_invalid_buffer_access(tmp_path):
18+
file_path = tmp_path / "buffer.json"
19+
file_path.write_text("ab", encoding="utf-8")
20+
with file_path.open("r", encoding="utf-8") as handle:
21+
wrapper = StringFileWrapper(handle, chunk_length=1)
22+
with pytest.raises(IndexError):
23+
wrapper.get_buffer(-1)
24+
# Build chunk metadata and then request the chunk that resides past EOF.
25+
len(wrapper)
26+
with pytest.raises(IndexError):
27+
wrapper.get_buffer(2)
28+
29+
30+
def test_string_file_wrapper_slice_variations(tmp_path):
31+
file_path = tmp_path / "slice.json"
32+
file_path.write_text("abcd", encoding="utf-8")
33+
with file_path.open("r", encoding="utf-8") as handle:
34+
wrapper = StringFileWrapper(handle, chunk_length=2)
35+
assert wrapper[-2:4] == "cd"
36+
assert wrapper[0:-1] == "abc"
37+
assert wrapper[3:1] == ""
38+
assert wrapper[0:4:2] == "ac"
39+
with pytest.raises(ValueError, match="slice step cannot be zero"):
40+
_ = wrapper[::0]
41+
42+
43+
def test_string_file_wrapper_negative_indices(tmp_path):
44+
file_path = tmp_path / "index.json"
45+
file_path.write_text("xyz", encoding="utf-8")
46+
with file_path.open("r", encoding="utf-8") as handle:
47+
wrapper = StringFileWrapper(handle, chunk_length=2)
48+
assert wrapper[-1] == "z"
49+
with pytest.raises(IndexError):
50+
_ = wrapper[-10]
51+
52+
53+
def test_string_file_wrapper_ensure_chunk_position_raises(tmp_path):
54+
file_path = tmp_path / "ensure.json"
55+
file_path.write_text("foo", encoding="utf-8")
56+
with file_path.open("r", encoding="utf-8") as handle:
57+
wrapper = StringFileWrapper(handle, chunk_length=1)
58+
with pytest.raises(IndexError):
59+
wrapper._ensure_chunk_position(10)

0 commit comments

Comments
 (0)