2 changes: 1 addition & 1 deletion nodes/server_manager.py
@@ -148,7 +148,7 @@ async def start(self, port=None, host=None):
self.host = host

# Get the path to the ComfyStream server directory and script
server_dir = Path(__file__).parent.parent / "server"
server_dir = Path(__file__).parent.parent / "src" / "comfystream" / "server"
server_script = server_dir / "app.py"
logging.info(f"Server script: {server_script}")

17 changes: 17 additions & 0 deletions src/comfystream/__init__.py
@@ -0,0 +1,17 @@
from .client import ComfyStreamClient
from .pipeline import Pipeline
from .server.utils import set_temporary_log_level
from .server.app import VideoStreamTrack, AudioStreamTrack
from .server.utils import FPSMeter
from .server.metrics import MetricsManager, StreamStatsManager

__all__ = [
'ComfyStreamClient',
'Pipeline',
'set_temporary_log_level',
'VideoStreamTrack',
'AudioStreamTrack',
'FPSMeter',
'MetricsManager',
'StreamStatsManager'
]
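
For reference, a minimal sketch of the new top-level surface this file exposes, assuming the package is installed so that src/comfystream resolves as the comfystream package:

# A sketch of the package-level imports defined above; Pipeline wraps a
# ComfyStreamClient, and set_temporary_log_level is re-exported from
# server.utils. (Assumes ComfyStreamClient accepts its defaults here.)
from comfystream import ComfyStreamClient, Pipeline, set_temporary_log_level

pipeline = Pipeline(width=512, height=512)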
109 changes: 94 additions & 15 deletions server/pipeline.py → src/comfystream/pipeline.py
@@ -3,27 +3,38 @@
import numpy as np
import asyncio
import logging
from typing import Any, Dict, Union, List, Optional

from typing import Any, Dict, Union, List
from comfystream.client import ComfyStreamClient
from utils import temporary_log_level
from comfystream.server.utils import set_temporary_log_level

WARMUP_RUNS = 5

logger = logging.getLogger(__name__)


class Pipeline:
def __init__(self, width=512, height=512, comfyui_inference_log_level: int = None, **kwargs):
"""A pipeline for processing video and audio frames using ComfyUI.

This class provides a high-level interface for processing video and audio frames
through a ComfyUI-based processing pipeline. It handles frame preprocessing,
postprocessing, and queue management.
"""

def __init__(self, width: int = 512, height: int = 512,
comfyui_inference_log_level: Optional[int] = None, **kwargs):
"""Initialize the pipeline with the given configuration.

Args:
width: Width of the video frames (default: 512)
height: Height of the video frames (default: 512)
comfyui_inference_log_level: The logging level for ComfyUI inference.
Defaults to None, using the global ComfyUI log level.
**kwargs: Additional arguments to pass to the ComfyStreamClient
"""
self.client = ComfyStreamClient(**kwargs)
self.width = kwargs.get("width", 512)
self.height = kwargs.get("height", 512)
self.width = width
self.height = height

self.video_incoming_frames = asyncio.Queue()
self.audio_incoming_frames = asyncio.Queue()
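
Worth noting in this hunk: the old body read self.width = kwargs.get("width", 512) after width and height had already been bound to the named parameters, so they could never appear in **kwargs and explicit arguments were silently ignored. A minimal sketch of the fixed behavior (assuming ComfyStreamClient constructs with no extra kwargs):

# Before: Pipeline(width=1024, height=576) still produced a 512x512 pipeline,
# because the bound parameters never reached kwargs. After this change the
# explicit parameters are assigned directly.
p = Pipeline(width=1024, height=576)
assert (p.width, p.height) == (1024, 576)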
@@ -33,7 +44,8 @@ def __init__(self, width=512, height=512, comfyui_inference_log_level: int = None, **kwargs):
self._comfyui_inference_log_level = comfyui_inference_log_level

async def warm_video(self):
# Create dummy frame with the CURRENT resolution settings (which might have been updated via control channel)
"""Warm up the video processing pipeline with dummy frames."""
# Create dummy frame with the CURRENT resolution settings
dummy_frame = av.VideoFrame()
dummy_frame.side_data.input = torch.randn(1, self.height, self.width, 3)

@@ -44,6 +56,7 @@ async def warm_audio(self):
await self.client.get_video_output()

async def warm_audio(self):
"""Warm up the audio processing pipeline with dummy frames."""
dummy_frame = av.AudioFrame()
dummy_frame.side_data.input = np.random.randint(-32768, 32767, int(48000 * 0.5), dtype=np.int16) # TODO: adds a lot of delay if it doesn't match the buffer size, is warmup needed?
dummy_frame.sample_rate = 48000
@@ -53,63 +66,124 @@ async def warm_audio(self):
await self.client.get_audio_output()

async def set_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]):
"""Set the processing prompts for the pipeline.

Args:
prompts: Either a single prompt dictionary or a list of prompt dictionaries
"""
if isinstance(prompts, list):
await self.client.set_prompts(prompts)
else:
await self.client.set_prompts([prompts])

async def update_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]):
"""Update the existing processing prompts.

Args:
prompts: Either a single prompt dictionary or a list of prompt dictionaries
"""
if isinstance(prompts, list):
await self.client.update_prompts(prompts)
else:
await self.client.update_prompts([prompts])

async def put_video_frame(self, frame: av.VideoFrame):
"""Queue a video frame for processing.

Args:
frame: The video frame to process
"""
frame.side_data.input = self.video_preprocess(frame)
frame.side_data.skipped = True
self.client.put_video_input(frame)
await self.video_incoming_frames.put(frame)

async def put_audio_frame(self, frame: av.AudioFrame):
"""Queue an audio frame for processing.

Args:
frame: The audio frame to process
"""
frame.side_data.input = self.audio_preprocess(frame)
frame.side_data.skipped = True
self.client.put_audio_input(frame)
await self.audio_incoming_frames.put(frame)

def video_preprocess(self, frame: av.VideoFrame) -> Union[torch.Tensor, np.ndarray]:
"""Preprocess a video frame before processing.

Args:
frame: The video frame to preprocess

Returns:
The preprocessed frame as a tensor or numpy array
"""
frame_np = frame.to_ndarray(format="rgb24").astype(np.float32) / 255.0
return torch.from_numpy(frame_np).unsqueeze(0)

def audio_preprocess(self, frame: av.AudioFrame) -> Union[torch.Tensor, np.ndarray]:
"""Preprocess an audio frame before processing.

Args:
frame: The audio frame to preprocess

Returns:
The preprocessed frame as a tensor or numpy array
"""
return frame.to_ndarray().ravel().reshape(-1, 2).mean(axis=1).astype(np.int16)
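
The one-liner above compresses several steps; a toy illustration of the interleaved stereo-to-mono downmix it performs, assuming a packed s16 layout (which the reshape(-1, 2) implies):

import numpy as np

# Interleaved stereo buffer [L0, R0, L1, R1] with shape (1, 4).
buf = np.array([[100, 300, -200, -400]], dtype=np.int16)
mono = buf.ravel().reshape(-1, 2).mean(axis=1).astype(np.int16)
# mono == [200, -300]: each output sample is the mean of one L/R pair.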

def video_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.VideoFrame:
"""Postprocess a video frame after processing.

Args:
output: The processed output tensor or numpy array

Returns:
The postprocessed video frame
"""
return av.VideoFrame.from_ndarray(
(output * 255.0).clamp(0, 255).to(dtype=torch.uint8).squeeze(0).cpu().numpy()
)

def audio_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.AudioFrame:
"""Postprocess an audio frame after processing.

Args:
output: The processed output tensor or numpy array

Returns:
The postprocessed audio frame
"""
return av.AudioFrame.from_ndarray(np.repeat(output, 2).reshape(1, -1))

async def get_processed_video_frame(self):
# TODO: make it generic to support purely generative video cases
async with temporary_log_level("comfy", self._comfyui_inference_log_level):
# TODO: make it generic to support purely generative video cases
async def get_processed_video_frame(self) -> av.VideoFrame:
"""Get the next processed video frame.

Returns:
The processed video frame
"""
async with set_temporary_log_level("comfy", self._comfyui_inference_log_level):
out_tensor = await self.client.get_video_output()
frame = await self.video_incoming_frames.get()
while frame.side_data.skipped:
frame = await self.video_incoming_frames.get()

processed_frame = self.video_postprocess(out_tensor)
processed_frame.pts = frame.pts
processed_frame.time_base = frame.time_base

return processed_frame

async def get_processed_audio_frame(self):
# TODO: make it generic to support purely generative audio cases and also add frame skipping
async def get_processed_audio_frame(self) -> av.AudioFrame:
"""Get the next processed audio frame.

Returns:
The processed audio frame
"""
frame = await self.audio_incoming_frames.get()
if frame.samples > len(self.processed_audio_buffer):
async with temporary_log_level("comfy", self._comfyui_inference_log_level):
async with set_temporary_log_level("comfy", self._comfyui_inference_log_level):
out_tensor = await self.client.get_audio_output()
self.processed_audio_buffer = np.concatenate([self.processed_audio_buffer, out_tensor])
out_data = self.processed_audio_buffer[:frame.samples]
@@ -123,9 +197,14 @@ async def get_processed_audio_frame(self):
return processed_frame

async def get_nodes_info(self) -> Dict[str, Any]:
"""Get information about all nodes in the current prompt including metadata."""
"""Get information about all nodes in the current prompt including metadata.

Returns:
Dictionary containing node information
"""
nodes_info = await self.client.get_available_nodes()
return nodes_info

async def cleanup(self):
await self.client.cleanup()
"""Clean up resources used by the pipeline."""
await self.client.cleanup()
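
Taken together, the renamed module keeps the same frame flow under the new import path. A minimal driver sketch (the prompt dict is hypothetical, and frame.side_data handling is assumed to work as in the methods above):

import asyncio
import av
import numpy as np
from comfystream.pipeline import Pipeline

async def main():
    pipeline = Pipeline(width=512, height=512)
    await pipeline.set_prompts({"example": "hypothetical prompt graph"})
    await pipeline.warm_video()
    # Feed one synthetic RGB frame, then read back its processed counterpart.
    frame = av.VideoFrame.from_ndarray(
        np.zeros((512, 512, 3), dtype=np.uint8), format="rgb24")
    await pipeline.put_video_frame(frame)
    processed = await pipeline.get_processed_video_frame()
    await pipeline.cleanup()

asyncio.run(main())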
File renamed without changes.
13 changes: 3 additions & 10 deletions server/app.py → src/comfystream/server/app.py
@@ -5,13 +5,6 @@
import os
import sys

import torch

# Initialize CUDA before any other imports to prevent core dump.
if torch.cuda.is_available():
torch.cuda.init()


from aiohttp import web
from aiortc import (
MediaStreamTrack,
@@ -22,10 +15,10 @@
)
from aiortc.codecs import h264
from aiortc.rtcrtpsender import RTCRtpSender
from pipeline import Pipeline
from comfystream.pipeline import Pipeline
from twilio.rest import Client
from utils import patch_loop_datagram, add_prefix_to_app_routes, FPSMeter
from metrics import MetricsManager, StreamStatsManager
from comfystream.server.utils import patch_loop_datagram, add_prefix_to_app_routes, FPSMeter
from comfystream.server.metrics import MetricsManager, StreamStatsManager
import time

logger = logging.getLogger(__name__)
File renamed without changes.
@@ -1,2 +1,2 @@
from .utils import patch_loop_datagram, add_prefix_to_app_routes, temporary_log_level
from .utils import patch_loop_datagram, add_prefix_to_app_routes, set_temporary_log_level
from .fps_meter import FPSMeter
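
A one-line sketch of the re-export path after the rename (module path as introduced in this PR):

from comfystream.server import set_temporary_log_level, FPSMeter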
@@ -4,7 +4,7 @@
import logging
import time
from collections import deque
from metrics import MetricsManager
from comfystream.server.metrics import MetricsManager

logger = logging.getLogger(__name__)

@@ -67,7 +67,7 @@ def add_prefix_to_app_routes(app: web.Application, prefix: str):


@asynccontextmanager
async def temporary_log_level(logger_name: str, level: int):
async def set_temporary_log_level(logger_name: str, level: int):
"""Temporarily set the log level of a logger.

Args:
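
For context, a sketch of the renamed context manager in use, assuming it restores the previous level on exit as the docstring describes:

import logging

async def quiet_inference(pipeline):
    # Raise the "comfy" logger to WARNING only for the duration of this call.
    async with set_temporary_log_level("comfy", logging.WARNING):
        return await pipeline.client.get_video_output()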