diff --git a/ingestify/application/dataset_store.py b/ingestify/application/dataset_store.py index 66a8fdf..ffc6b46 100644 --- a/ingestify/application/dataset_store.py +++ b/ingestify/application/dataset_store.py @@ -5,6 +5,7 @@ from contextlib import contextmanager import threading from io import BytesIO +from ingestify.utils import BufferedStream from typing import ( Dict, @@ -298,30 +299,36 @@ def iter_dataset_collection_batches( # dataset = self.dataset_repository. # self.dataset_repository.destroy_dataset(dataset_id) - def _prepare_write_stream(self, file_: DraftFile) -> tuple[BytesIO, int, str]: + def _prepare_write_stream( + self, file_: DraftFile + ) -> tuple[BinaryIO, int, str, Optional[str]]: + if file_.content_compression_method == "gzip": + # Already gzip - store as-is, no CPU cost + stream = file_.stream + stream.seek(0, os.SEEK_END) + storage_size = stream.tell() + stream.seek(0) + return stream, storage_size, ".gz", "gzip" + if self.storage_compression_method == "gzip": - stream = BytesIO() + stream = BufferedStream() with gzip.GzipFile(fileobj=stream, compresslevel=9, mode="wb") as fp: shutil.copyfileobj(file_.stream, fp) stream.seek(0, os.SEEK_END) storage_size = stream.tell() stream.seek(0) - suffix = ".gz" - else: - stream = file_.stream - storage_size = file_.size - suffix = "" + return stream, storage_size, ".gz", "gzip" - return stream, storage_size, suffix + return file_.stream, file_.size, "", None def _prepare_read_stream( self, - ) -> tuple[Callable[[BinaryIO], Awaitable[BytesIO]], str]: + ) -> tuple[Callable[[BinaryIO], Awaitable[BinaryIO]], str]: if self.storage_compression_method == "gzip": - def reader(fh: BinaryIO) -> BytesIO: - stream = BytesIO() + def reader(fh: BinaryIO) -> BinaryIO: + stream = BufferedStream() with gzip.GzipFile(fileobj=fh, compresslevel=9, mode="rb") as fp: shutil.copyfileobj(fp, stream) stream.seek(0) @@ -355,7 +362,12 @@ def _persist_files( # File didn't change. Ignore it. 
continue - stream, storage_size, suffix = self._prepare_write_stream(file_) + ( + stream, + storage_size, + suffix, + compression_method, + ) = self._prepare_write_stream(file_) # TODO: check if this is a very clean way to go from DraftFile to File full_path = self.file_repository.save_content( @@ -369,7 +381,7 @@ def _persist_files( file_, file_id, storage_size=storage_size, - storage_compression_method=self.storage_compression_method, + storage_compression_method=compression_method, path=self.file_repository.get_relative_path(full_path), ) diff --git a/ingestify/domain/models/dataset/file.py b/ingestify/domain/models/dataset/file.py index 6dfa3d4..e2e32f8 100644 --- a/ingestify/domain/models/dataset/file.py +++ b/ingestify/domain/models/dataset/file.py @@ -4,8 +4,10 @@ from io import BytesIO, StringIO import hashlib +from pydantic import field_validator + from ingestify.domain.models.base import BaseModel -from ingestify.utils import utcnow +from ingestify.utils import utcnow, BufferedStream class DraftFile(BaseModel): @@ -17,7 +19,20 @@ class DraftFile(BaseModel): data_feed_key: str # Example: 'events' data_spec_version: str # Example: 'v3' data_serialization_format: str # Example: 'json' - stream: BytesIO + content_compression_method: Optional[str] = None # Example: 'gzip' + stream: BufferedStream + + @field_validator("stream", mode="before") + @classmethod + def coerce_to_buffered_stream(cls, v): + if isinstance(v, BufferedStream): + return v + if isinstance(v, (BytesIO, bytes)): + data = v if isinstance(v, bytes) else v.getvalue() + return BufferedStream.from_stream(BytesIO(data)) + if hasattr(v, "read"): + return BufferedStream.from_stream(v) + raise ValueError(f"Cannot coerce {type(v)} to BufferedStream") @classmethod def from_input( @@ -32,26 +47,20 @@ def from_input( if isinstance(file_, (DraftFile, NotModifiedFile)): return file_ elif isinstance(file_, str): - stream = BytesIO(file_.encode("utf-8")) + data = file_.encode("utf-8") elif isinstance(file_, 
bytes): - stream = BytesIO(file_) + data = file_ elif isinstance(file_, StringIO): - stream = BytesIO(file_.read().encode("utf-8")) - elif isinstance(file_, BytesIO): - stream = file_ + data = file_.read().encode("utf-8") elif hasattr(file_, "read"): - data = file_.read() - if isinstance(data, bytes): - stream = BytesIO(data) - else: - stream = BytesIO(data.encode("utf-8")) + raw = file_.read() + data = raw if isinstance(raw, bytes) else raw.encode("utf-8") else: raise Exception(f"Not possible to create DraftFile from {type(file_)}") - data = stream.read() size = len(data) tag = hashlib.sha1(data).hexdigest() - stream.seek(0) + stream = BufferedStream.from_stream(BytesIO(data)) now = utcnow() @@ -127,7 +136,12 @@ class LoadedFile(BaseModel): data_serialization_format: Optional[str] # Example: 'json' storage_compression_method: Optional[str] # Example: 'gzip' storage_path: Path - stream_: Union[BinaryIO, BytesIO, Callable[[], Awaitable[Union[BinaryIO, BytesIO]]]] + stream_: Union[ + BinaryIO, + BytesIO, + BufferedStream, + Callable[[], Awaitable[Union[BinaryIO, BytesIO, BufferedStream]]], + ] revision_id: Optional[int] = None # This can be used when a Revision is squashed def load_stream(self): diff --git a/ingestify/infra/fetch/http.py b/ingestify/infra/fetch/http.py index 45bb651..0ef3493 100644 --- a/ingestify/infra/fetch/http.py +++ b/ingestify/infra/fetch/http.py @@ -3,7 +3,7 @@ from email.utils import format_datetime, parsedate from hashlib import sha1 from io import BytesIO -from typing import Optional, Callable, Tuple, Union +from typing import BinaryIO, Optional, Callable, Tuple, Union import requests from requests.adapters import HTTPAdapter @@ -11,7 +11,12 @@ from ingestify.domain.models import DraftFile, File from ingestify.domain.models.dataset.file import NotModifiedFile -from ingestify.utils import utcnow +from ingestify.utils import ( + utcnow, + BufferedStream, + detect_compression, + gzip_uncompressed_size, +) _session = None @@ -75,7 +80,7 @@ def 
retrieve_http(
     ignore_not_found = http_kwargs.pop("ignore_not_found", False)
 
-    response = get_session().get(url, headers=headers, **http_kwargs)
+    response = get_session().get(url, headers=headers, stream=True, **http_kwargs)
     if response.status_code == 404 and ignore_not_found:
         return NotModifiedFile(
             modified_at=last_modified, reason="404 http code and ignore-not-found"
         )
@@ -96,12 +101,9 @@ def retrieve_http(
 
     modified_at = utcnow()
     tag = response.headers.get("etag")
-    # content_length = int(response.headers.get("content-length", 0))
 
     if pager:
-        """
-        A pager helps with responses that return the data in pages.
-        """
+        # Pager assembles multiple small JSON responses — load fully into memory
         data_path, pager_fn = pager
         data = []
         while True:
@@ -111,24 +113,45 @@ def retrieve_http(
             if not next_url:
                 break
             else:
-                response = requests.get(next_url, headers=headers, **http_kwargs)
-
-        content = json.dumps({data_path: data}).encode("utf-8")
+                response = requests.get(
+                    next_url, headers=headers, stream=True, **http_kwargs
+                )
+
+        content_bytes = json.dumps({data_path: data}).encode("utf-8")
+        if not tag:
+            tag = sha1(content_bytes).hexdigest()
+        if current_file and current_file.tag == tag:
+            return NotModifiedFile(
+                modified_at=last_modified, reason="tag matched current_file"
+            )
+        stream = BufferedStream.from_stream(BytesIO(content_bytes))
+        content_length = len(content_bytes)
+        content_compression_method = None  # locally built JSON is never compressed
     else:
-        content = response.content
+        # Stream response body directly into BufferedStream, hashing on the fly
+        raw_stream = BufferedStream()
+        hasher = sha1()
+        for chunk in response.iter_content(chunk_size=1024 * 1024):
+            hasher.update(chunk)
+            raw_stream.write(chunk)
 
-        if not tag:
-            tag = sha1(content).hexdigest()
+        if not tag:
+            tag = hasher.hexdigest()
 
-    # if not content_length: - Don't use http header as it might be wrong
-    # for example in case of compressed data
-    content_length = len(content)
+        if current_file and current_file.tag == tag:
+            return NotModifiedFile(
+                modified_at=last_modified, reason="tag matched current_file"
+            )
 
-    if current_file and current_file.tag == tag:
-        # Not changed. Don't keep it
-        return NotModifiedFile(
-            modified_at=last_modified, reason="tag matched current_file"
-        )
+        raw_stream.seek(0)
+        content_compression_method = detect_compression(raw_stream)
+        if content_compression_method == "gzip":
+            content_length = gzip_uncompressed_size(raw_stream)
+        else:
+            raw_stream.seek(0, 2)
+            content_length = raw_stream.tell()
+            raw_stream.seek(0)
+        stream = raw_stream
 
     return DraftFile(
         created_at=utcnow(),
@@ -136,6 +159,7 @@ def retrieve_http(
         tag=tag,
         size=content_length,
         content_type=response.headers.get("content-type"),
-        stream=BytesIO(content),
+        content_compression_method=content_compression_method,
+        stream=stream,
         **file_attributes,
     )
diff --git a/ingestify/tests/test_http_fetch.py b/ingestify/tests/test_http_fetch.py
new file mode 100644
index 0000000..cdafafc
--- /dev/null
+++ b/ingestify/tests/test_http_fetch.py
@@ -0,0 +1,59 @@
+import gzip
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from ingestify.infra.fetch.http import retrieve_http
+from ingestify.utils import BufferedStream
+
+
+def make_mock_response(content, status_code=200, headers=None):
+    headers = headers or {}
+    mock = MagicMock()
+    mock.status_code = status_code
+    mock.headers = MagicMock()
+    mock.headers.get = lambda key, default=None: headers.get(key, default)
+    mock.headers.__contains__ = lambda self, key: key in headers
+    mock.raise_for_status = MagicMock()
+    mock.iter_content = lambda chunk_size=1: [content]
+    return mock
+
+
+FILE_KWARGS = dict(
+    file_data_feed_key="test",
+    file_data_spec_version="v1",
+    file_data_serialization_format="json",
+)
+
+PLAIN_JSON = b'{"key": "value"}' * 100
+
+
+def test_plain_content_size_and_stream():
+    with patch("ingestify.infra.fetch.http.get_session") as mock_session:
+        mock_session.return_value.get.return_value = make_mock_response(PLAIN_JSON)
+        result =
retrieve_http("https://example.com/data.json", **FILE_KWARGS) + + assert isinstance(result.stream, BufferedStream) + assert result.size == len(PLAIN_JSON) + assert result.stream.read() == PLAIN_JSON + + +def test_gzip_content_stored_as_is_with_uncompressed_size(): + compressed = gzip.compress(PLAIN_JSON) + + with patch("ingestify.infra.fetch.http.get_session") as mock_session: + mock_session.return_value.get.return_value = make_mock_response(compressed) + result = retrieve_http("https://example.com/data.json.gz", **FILE_KWARGS) + + assert isinstance(result.stream, BufferedStream) + assert result.content_compression_method == "gzip" + assert result.size == len(PLAIN_JSON) # uncompressed size from gzip trailer + assert result.stream.read() == compressed # stored as-is + + +def test_plain_content_has_no_compression_method(): + with patch("ingestify.infra.fetch.http.get_session") as mock_session: + mock_session.return_value.get.return_value = make_mock_response(PLAIN_JSON) + result = retrieve_http("https://example.com/data.json", **FILE_KWARGS) + + assert result.content_compression_method is None diff --git a/ingestify/tests/test_utils.py b/ingestify/tests/test_utils.py new file mode 100644 index 0000000..6007d75 --- /dev/null +++ b/ingestify/tests/test_utils.py @@ -0,0 +1,29 @@ +import gzip +from io import BytesIO + +from ingestify.utils import BufferedStream, detect_compression, gzip_uncompressed_size + +PLAIN = b'{"key": "value"}' * 100 + + +def to_stream(data: bytes) -> BufferedStream: + return BufferedStream.from_stream(BytesIO(data)) + + +def test_detect_compression_gzip(): + assert detect_compression(to_stream(gzip.compress(PLAIN))) == "gzip" + + +def test_detect_compression_plain(): + assert detect_compression(to_stream(PLAIN)) is None + + +def test_detect_compression_resets_position(): + stream = to_stream(gzip.compress(PLAIN)) + detect_compression(stream) + assert stream.tell() == 0 + + +def test_gzip_uncompressed_size(): + compressed = gzip.compress(PLAIN) + 
_DEFAULT_BUFFER_SIZE = 5 * 1024 * 1024  # 5MB in memory before spilling to disk


class BufferedStream(tempfile.SpooledTemporaryFile):
    """Binary buffer that stays in memory up to *max_size*, then spills to disk.

    A drop-in replacement for BytesIO when payload sizes are unknown and may
    be large: small payloads never touch the filesystem, large ones never
    balloon memory.
    """

    def __init__(self, max_size: int = _DEFAULT_BUFFER_SIZE):
        # Always binary mode: all ingestify streams carry bytes.
        super().__init__(max_size=max_size, mode="w+b")

    @classmethod
    def from_stream(
        cls, source: BinaryIO, max_size: int = _DEFAULT_BUFFER_SIZE
    ) -> "BufferedStream":
        """Copy *source* (from its current position) into a new, rewound buffer."""
        buffer = cls(max_size=max_size)
        shutil.copyfileobj(source, buffer)
        buffer.seek(0)
        return buffer


def gzip_uncompressed_size(stream: BinaryIO) -> int:
    """Read the uncompressed size from the gzip trailer (last 4 bytes).

    The ISIZE trailer field stores the original length modulo 2**32, so the
    result is wrong for payloads of 4 GiB or more, and for multi-member gzip
    files only the last member's size is reported. Leaves *stream* rewound
    to position 0. Requires a seekable stream.
    """
    stream.seek(-4, os.SEEK_END)
    size = int.from_bytes(stream.read(4), "little")
    stream.seek(0)
    return size


def detect_compression(stream: BinaryIO) -> Optional[str]:
    """Detect compression method from magic bytes; rewinds *stream* to 0.

    Returns "gzip" for the gzip magic number, otherwise None.
    """
    header = stream.read(2)
    stream.seek(0)
    return "gzip" if header == b"\x1f\x8b" else None