Skip to content

Commit 2e439f0

Browse files
authored
Merge pull request #9 from RobotControlStack/feat/jpeg-encoding
feat: add transparent jpeg support
2 parents 78621e5 + 8398849 commit 2e439f0

5 files changed

Lines changed: 66 additions & 15 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ dependencies = [
1616
"wandb",
1717
"pillow",
1818
"tqdm",
19+
"simplejpeg",
1920
]
2021
readme = "README.md"
2122
maintainers = [{ name = "Tobias Jülg", email = "tobias.juelg@utn.de" }]

src/agents/client.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import base64
12
import dataclasses
23
from dataclasses import asdict
34
from multiprocessing import shared_memory
@@ -6,8 +7,9 @@
67
import json_numpy
78
import numpy as np
89
import rpyc
10+
import simplejpeg
911

10-
from agents.policies import Act, Agent, Obs, SharedMemoryPayload
12+
from agents.policies import Act, Agent, CameraDataType, Obs, SharedMemoryPayload
1113

1214

1315
def dataclass_from_dict(klass, d):
@@ -20,7 +22,7 @@ def dataclass_from_dict(klass, d):
2022

2123

2224
class RemoteAgent(Agent):
23-
def __init__(self, host: str, port: int, model: str, on_same_machine: bool = False):
25+
def __init__(self, host: str, port: int, model: str, on_same_machine: bool = False, jpeg_encoding: bool = False):
2426
"""Connect to a remote agent service.
2527
2628
Args:
@@ -29,15 +31,18 @@ def __init__(self, host: str, port: int, model: str, on_same_machine: bool = Fal
2931
model (str): Name of the model to connect to.
3032
on_same_machine (bool, optional): If True, assumes the agent is running on the same machine and uses
3133
shared memory for more efficient communication. Defaults to False.
34+
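jpeg_encoding (bool, optional): If True, camera images are JPEG-encoded (and base64-wrapped) before sending to reduce transfer size. Ignored when on_same_machine is True, since shared memory takes precedence.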
35+
Defaults to False.
3236
"""
3337
self.on_same_machine = on_same_machine
38+
self.jpeg_encoding = jpeg_encoding
3439
self._shm: dict[str, shared_memory.SharedMemory] = {}
3540
self.c = rpyc.connect(
3641
host, port, config={"allow_pickle": True, "allow_public_attrs": True, "sync_request_timeout": 300}
3742
)
3843
assert model == self.c.root.name()
3944

40-
def _to_shared_memory(self, obs: Obs) -> Obs:
45+
def _process(self, obs: Obs) -> Obs:
4146
if self.on_same_machine:
4247
camera_dict = {}
4348
for camera_name, camera_data in obs.cameras.items():
@@ -54,17 +59,26 @@ def _to_shared_memory(self, obs: Obs) -> Obs:
5459
dtype=camera_data.dtype.name,
5560
)
5661
obs.cameras = camera_dict
57-
obs.camera_data_in_shared_memory = True
62+
obs.camera_data_type = CameraDataType.SHARED_MEMORY
63+
elif self.jpeg_encoding:
64+
camera_dict = {}
65+
for camera_name, camera_data in obs.cameras.items():
66+
assert isinstance(camera_data, np.ndarray)
67+
camera_dict[camera_name] = base64.urlsafe_b64encode(
68+
simplejpeg.encode_jpeg(np.ascontiguousarray(camera_data))
69+
).decode("utf-8")
70+
obs.cameras = camera_dict
71+
obs.camera_data_type = CameraDataType.JPEG_ENCODED
5872
return obs
5973

6074
def act(self, obs: Obs) -> Act:
61-
obs = self._to_shared_memory(obs)
75+
obs = self._process(obs)
6276
obs = json_numpy.dumps(asdict(obs))
6377
# action, done, info
6478
return dataclass_from_dict(Act, json_numpy.loads(self.c.root.act(obs)))
6579

6680
def reset(self, obs: Obs, instruction: Any, **kwargs) -> dict[str, Any]:
67-
obs = self._to_shared_memory(obs)
81+
obs = self._process(obs)
6882
obs_dict = asdict(obs)
6983
# info
7084
return json_numpy.loads(self.c.root.reset(json_numpy.dumps((obs_dict, instruction, kwargs))))

src/agents/policies.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from typing import Any, Union
1313

1414
import numpy as np
15+
import simplejpeg
1516
from PIL import Image
1617

1718

@@ -22,12 +23,18 @@ class SharedMemoryPayload:
2223
dtype: str = "uint8"
2324

2425

26+
import enum  # imported beside its only user so this block stays self-contained


class CameraDataType(str, enum.Enum):
    """Tag describing how the values of ``Obs.cameras`` are currently encoded.

    The members mix in ``str`` so that they keep comparing equal to, and
    JSON-serializing as, the plain strings used on the wire (an ``Obs`` is
    rebuilt from JSON on the receiving side, where this field is a plain
    ``str``). Being a real enum additionally gives validation via
    ``CameraDataType("...")`` and iteration over the closed set of values.
    """

    # Each camera entry is a SharedMemoryPayload pointing at a shared memory segment.
    SHARED_MEMORY = "shared_memory"
    # Each camera entry is a url-safe base64 string wrapping JPEG bytes.
    JPEG_ENCODED = "jpeg_encoded"
    # Each camera entry is the image array itself (the default for a fresh Obs).
    RAW = "raw"

    def __str__(self) -> str:
        # Keep str()/format() output equal to the wire value on every Python
        # version; the default Enum rendering would be "CameraDataType.RAW".
        return self.value
30+
31+
2532
@dataclass(kw_only=True)
2633
class Obs:
27-
cameras: dict[str, np.ndarray | SharedMemoryPayload] = field(default_factory=dict)
34+
cameras: dict[str, np.ndarray | SharedMemoryPayload | str] = field(default_factory=dict)
35+
camera_data_type: str = CameraDataType.RAW
2836
gripper: float | None = None
2937
info: dict[str, Any] = field(default_factory=dict)
30-
camera_data_in_shared_memory: bool = False
3138

3239

3340
@dataclass(kw_only=True)
@@ -53,9 +60,9 @@ def initialize(self):
5360
# heavy initialization, e.g. loading models
5461
pass
5562

56-
def _from_shared_memory(self, obs: Obs) -> Obs:
63+
def _to_numpy(self, obs: Obs) -> Obs:
5764
"""Transparently resolve camera data that arrives via shared memory or as JPEG-encoded strings into image arrays; modifies obs in place."""
58-
if obs.camera_data_in_shared_memory:
65+
if obs.camera_data_type == CameraDataType.SHARED_MEMORY:
5966
camera_dict = {}
6067
for camera_name, camera_data in obs.cameras.items():
6168
assert isinstance(camera_data, SharedMemoryPayload)
@@ -65,12 +72,19 @@ def _from_shared_memory(self, obs: Obs) -> Obs:
6572
camera_data.shape, dtype=camera_data.dtype, buffer=self._shm[camera_data.shm_name].buf
6673
)
6774
obs.cameras = camera_dict
75+
elif obs.camera_data_type == CameraDataType.JPEG_ENCODED:
76+
camera_dict = {}
77+
for camera_name, camera_data in obs.cameras.items():
78+
assert isinstance(camera_data, str)
79+
camera_dict[camera_name] = simplejpeg.decode_jpeg(base64.urlsafe_b64decode(camera_data))
80+
obs.cameras = camera_dict
81+
obs.camera_data_type = CameraDataType.RAW
6882
return obs
6983

7084
def act(self, obs: Obs) -> Act:
7185
assert self.instruction is not None, "forgot reset?"
7286
self.step += 1
73-
self._from_shared_memory(obs)
87+
self._to_numpy(obs)
7488

7589
return Act(action=np.zeros(7, dtype=np.float32), done=False, info={})
7690

@@ -79,7 +93,7 @@ def reset(self, obs: Obs, instruction: Any, **kwargs) -> dict[str, Any]:
7993
self.step = 0
8094
self.episode += 1
8195
self.instruction = instruction
82-
self._from_shared_memory(obs)
96+
self._to_numpy(obs)
8397
# info
8498
return {}
8599

src/agents/server.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import rpyc
1111

1212
from agents.client import dataclass_from_dict
13-
from agents.policies import Agent, Obs, SharedMemoryPayload
13+
from agents.policies import Agent, CameraDataType, Obs, SharedMemoryPayload
1414

1515
logging.basicConfig(
1616
format="%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s",
@@ -51,7 +51,7 @@ def act(self, obs_bytes: bytes) -> str:
5151
assert self._is_initialized, "AgentService not initialized, wait until is_initialized is True"
5252
# action, done, info
5353
obs = typing.cast(Obs, dataclass_from_dict(Obs, json_numpy.loads(obs_bytes)))
54-
if obs.camera_data_in_shared_memory:
54+
if obs.camera_data_type == CameraDataType.SHARED_MEMORY:
5555
obs.cameras = {
5656
camera_name: dataclass_from_dict(SharedMemoryPayload, camera_data)
5757
for camera_name, camera_data in obs.cameras.items()
@@ -64,7 +64,7 @@ def reset(self, args: bytes) -> str:
6464
# info
6565
obs, instruction, kwargs = json_numpy.loads(args)
6666
obs_dclass = typing.cast(Obs, dataclass_from_dict(Obs, obs))
67-
if obs_dclass.camera_data_in_shared_memory:
67+
if obs_dclass.camera_data_type == CameraDataType.SHARED_MEMORY:
6868
obs_dclass.cameras = {
6969
camera_name: dataclass_from_dict(SharedMemoryPayload, camera_data)
7070
for camera_name, camera_data in obs_dclass.cameras.items()

src/tests/test_connection.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,17 @@ def _test_connection(agent: RemoteAgent):
3636
assert not a1.done
3737

3838

39+
def _test_connection_jpeg(agent: RemoteAgent):
    """Reset ``agent`` with a single all-zero camera frame and verify its reply.

    The reply of ``agent.reset`` is expected to echo the instruction and to
    report shape, dtype and pixel data for the ``rgb_side`` camera.

    NOTE(review): the pixel comparison is exact. That presumably only holds
    because a uniformly black frame survives the JPEG round trip unchanged;
    confirm before swapping in a non-trivial test image.
    """
    frame = np.zeros((256, 256, 3), dtype=np.uint8)
    task = "do something"

    info = agent.reset(Obs(cameras={"rgb_side": frame}), task)

    assert info["instruction"] == task
    assert info["shapes"] == {"rgb_side": [256, 256, 3]}
    assert info["dtype"] == {"rgb_side": "uint8"}
    assert (info["data"]["rgb_side"] == frame).all()
48+
49+
3950
def test_connection_numpy_serialization():
4051
with start_server("test", {}, 8080, "localhost") as p:
4152
sleep(2)
@@ -56,3 +67,14 @@ def test_connection_numpy_shm():
5667
sleep(0.1)
5768
_test_connection(agent)
5869
p.send_signal(subprocess.signal.SIGINT)
70+
71+
72+
def test_connection_numpy_jpeg():
    """Run the JPEG connection check against a freshly started "test" server.

    Starts the server on localhost:8080, connects a ``RemoteAgent`` with
    ``jpeg_encoding=True``, waits until the agent reports it is initialized
    and then runs ``_test_connection_jpeg``. The server process is always
    sent SIGINT afterwards, even when an assertion fails.
    """
    # Use the signal module directly instead of reaching through the
    # undocumented ``subprocess.signal`` attribute.
    import signal

    with start_server("test", {}, 8080, "localhost") as p:
        try:
            # Fixed delay before connecting; presumably gives the server time
            # to start listening -- TODO confirm / replace with a readiness poll.
            sleep(2)
            agent = RemoteAgent("localhost", 8080, "test", jpeg_encoding=True)
            with agent:
                while not agent.is_initialized():
                    sleep(0.1)
                _test_connection_jpeg(agent)
        finally:
            # Always interrupt the server: if the signal were skipped on a
            # failing assertion, leaving the ``with`` block could wait on a
            # server that never exits (assuming ``p`` is Popen-like -- verify).
            p.send_signal(signal.SIGINT)

0 commit comments

Comments
 (0)