1+ import base64
12import dataclasses
23from dataclasses import asdict
34from multiprocessing import shared_memory
67import json_numpy
78import numpy as np
89import rpyc
10+ import simplejpeg
911
10- from agents .policies import Act , Agent , Obs , SharedMemoryPayload
12+ from agents .policies import Act , Agent , CameraDataType , Obs , SharedMemoryPayload
1113
1214
1315def dataclass_from_dict (klass , d ):
@@ -20,7 +22,7 @@ def dataclass_from_dict(klass, d):
2022
2123
2224class RemoteAgent (Agent ):
23- def __init__ (self , host : str , port : int , model : str , on_same_machine : bool = False ):
25+ def __init__ (self , host : str , port : int , model : str , on_same_machine : bool = False , jpeg_encoding : bool = False ):
2426 """Connect to a remote agent service.
2527
2628 Args:
@@ -29,15 +31,18 @@ def __init__(self, host: str, port: int, model: str, on_same_machine: bool = Fal
2931 model (str): Name of the model to connect to.
3032 on_same_machine (bool, optional): If True, assumes the agent is running on the same machine and uses
3133 shared memory for more efficient communication. Defaults to False.
34+ jpeg_encoding (bool, optional): If True the image data is jpeg encoded for smaller transfer size.
35+ Defaults to False.
3236 """
3337 self .on_same_machine = on_same_machine
38+ self .jpeg_encoding = jpeg_encoding
3439 self ._shm : dict [str , shared_memory .SharedMemory ] = {}
3540 self .c = rpyc .connect (
3641 host , port , config = {"allow_pickle" : True , "allow_public_attrs" : True , "sync_request_timeout" : 300 }
3742 )
3843 assert model == self .c .root .name ()
3944
40- def _to_shared_memory (self , obs : Obs ) -> Obs :
45+ def _process (self , obs : Obs ) -> Obs :
4146 if self .on_same_machine :
4247 camera_dict = {}
4348 for camera_name , camera_data in obs .cameras .items ():
@@ -54,17 +59,26 @@ def _to_shared_memory(self, obs: Obs) -> Obs:
5459 dtype = camera_data .dtype .name ,
5560 )
5661 obs .cameras = camera_dict
57- obs .camera_data_in_shared_memory = True
62+ obs .camera_data_type = CameraDataType .SHARED_MEMORY
63+ elif self .jpeg_encoding :
64+ camera_dict = {}
65+ for camera_name , camera_data in obs .cameras .items ():
66+ assert isinstance (camera_data , np .ndarray )
67+ camera_dict [camera_name ] = base64 .urlsafe_b64encode (
68+ simplejpeg .encode_jpeg (np .ascontiguousarray (camera_data ))
69+ ).decode ("utf-8" )
70+ obs .cameras = camera_dict
71+ obs .camera_data_type = CameraDataType .JPEG_ENCODED
5872 return obs
5973
6074 def act (self , obs : Obs ) -> Act :
61- obs = self ._to_shared_memory (obs )
75+ obs = self ._process (obs )
6276 obs = json_numpy .dumps (asdict (obs ))
6377 # action, done, info
6478 return dataclass_from_dict (Act , json_numpy .loads (self .c .root .act (obs )))
6579
6680 def reset (self , obs : Obs , instruction : Any , ** kwargs ) -> dict [str , Any ]:
67- obs = self ._to_shared_memory (obs )
81+ obs = self ._process (obs )
6882 obs_dict = asdict (obs )
6983 # info
7084 return json_numpy .loads (self .c .root .reset (json_numpy .dumps ((obs_dict , instruction , kwargs ))))
0 commit comments