From efd597289105d2d18776e73f4d59cc3b60e0a3f0 Mon Sep 17 00:00:00 2001 From: main Date: Mon, 17 Mar 2025 02:32:22 -0700 Subject: [PATCH 1/2] Rudimentary SOCKS5 support --- pyproject.toml | 2 +- rnsh/args.py | 28 ++- rnsh/rnsh.py | 41 +++- {tests => rnsh/socksext}/__init__.py | 0 rnsh/socksext/counterpart.py | 114 +++++++++ rnsh/socksext/protocol.py | 23 ++ rnsh/socksext/socksproxy.py | 339 +++++++++++++++++++++++++++ rnsh/testlogging.py | 36 --- tests/helpers.py | 164 ------------- tests/reticulum_test_config | 15 -- tests/test_args.py | 106 --------- tests/test_exception.py | 9 - tests/test_process.py | 241 ------------------- tests/test_protocol.py | 109 --------- tests/test_retry.py | 162 ------------- tests/test_rnsh.py | 223 ------------------ 16 files changed, 536 insertions(+), 1076 deletions(-) rename {tests => rnsh/socksext}/__init__.py (100%) create mode 100644 rnsh/socksext/counterpart.py create mode 100644 rnsh/socksext/protocol.py create mode 100644 rnsh/socksext/socksproxy.py delete mode 100644 rnsh/testlogging.py delete mode 100644 tests/helpers.py delete mode 100644 tests/reticulum_test_config delete mode 100644 tests/test_args.py delete mode 100644 tests/test_exception.py delete mode 100644 tests/test_process.py delete mode 100644 tests/test_protocol.py delete mode 100644 tests/test_retry.py delete mode 100644 tests/test_rnsh.py diff --git a/pyproject.toml b/pyproject.toml index de8c8e6..a9fa004 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ license = "MIT" readme = "README.md" [tool.poetry.dependencies] -python = "^3.7" +python = ">3.12" rns = ">=0.9.0" [tool.poetry.scripts] diff --git a/rnsh/args.py b/rnsh/args.py index 8544fd6..52dd27b 100644 --- a/rnsh/args.py +++ b/rnsh/args.py @@ -18,17 +18,23 @@ def _split_array_at(arr: [_T], at: _T) -> ([_T], [_T]): usage = \ ''' Usage: - rnsh -l [-c ] [-i | -s ] [-v... | -q...] -p - rnsh -l [-c ] [-i | -s ] [-v... | -q...] - [-b ] [-n] [-a ] ([-a ] ...) [-A | -C] - [[--] [ ...]] + rnsh [--socks5] -l [-c ] [-i | -s ] [-v... | -q...] -p + rnsh [--socks5] -l [-c ] [-i | -s ] [-v... | -q...] + [-b ] [-n] [-a ] ([-a ] ...) [-A | -C] + [[--] [ ...]] rnsh [-c ] [-i ] [-v... | -q...] -p + rnsh [-c ] [-i ] [-v... | -q...] [-N] [-m] [-w ] + [--socks5 [--socks5-host ] [--socks5-port ]] + rnsh [-c ] [-i ] [-v... | -q...] [-N] [-m] [-w ] [[--] [ ...]] rnsh -h rnsh --version Options: + --socks5 Enable socks5 proxy mode. + --socks5-host HOST SOCKS5 proxy host + --socks5-port PORT SOCKS5 proxy port -c DIR --config DIR Alternate Reticulum config directory to use -i FILE --identity FILE Specific identity file to use -s NAME --service NAME Service name for identity file if not default @@ -79,6 +85,20 @@ def __init__(self, argv: [str]): args = docopt.docopt(usage, argv=self.docopts_argv[1:], version=f"rnsh {rnsh.__version__}") # json.dump(args, sys.stdout) + self.socks5 = "--socks5" in args + self.socks5_host = args.get("--socks5-host") or "127.0.0.1" + try: + if "--socks5-port" in args: + port_string = args.get("--socks5-port") + if port_string is None: + self.socks5_port = 1080 + else: + self.socks5_port = int(port_string) + else: + self.socks5_port = 1080 + except ValueError: + print("Invalid value for --socks5-port") + sys.exit(1) self.listen = args.get("--listen", None) or False self.service_name = args.get("--service", None) if self.listen and (self.service_name is None or len(self.service_name) > 0): diff --git a/rnsh/rnsh.py b/rnsh/rnsh.py index c1e8486..e973d79 100644 --- a/rnsh/rnsh.py +++ b/rnsh/rnsh.py @@ -48,12 +48,11 @@ import re import contextlib import rnsh.args -import pwd -import rnsh.protocol as protocol -import rnsh.helpers as helpers import rnsh.loop import rnsh.listener as listener import rnsh.initiator as initiator +from rnsh.socksext.socksproxy import SOCKS5Proxy +from rnsh.socksext.counterpart import SOCKS5CounterPart module_logger = __logging.getLogger(__name__) @@ -104,12 +103,11 @@ def print_identity(configdir, identitypath, service_name, include_destination: b verbose_set = False -async def _rnsh_cli_main(): +async def _rnsh_cli_main(args): global verbose_set log = _get_logger("main") _loop = asyncio.get_running_loop() rnslogging.set_main_loop(_loop) - args = rnsh.args.Args(sys.argv) verbose_set = args.verbose > 0 if args.print_identity: @@ -156,12 +154,43 @@ async def _rnsh_cli_main(): return 1 + +async def _rnsocks_cli_main(args): + global verbose_set + log = _get_logger("main") + _loop = asyncio.get_running_loop() + rnslogging.set_main_loop(_loop) + args = rnsh.args.Args(sys.argv) + verbose_set = args.verbose > 0 + + if args.listen: + cpart = SOCKS5CounterPart() + cpart.run() + else: + if args.destination is None: + print("No destination specified for socks5 client mode, exiting") + return 1 + host = args.socks5_host or "127.0.0.1" + port = 1080 + if args.socks5_port is not None: + try: + port = int(args.socks5_port) + except Exception: + print("Invalid socks5 port specified, exiting") + return 1 + proxy = SOCKS5Proxy(host=host, port=port, destination_hash=args.destination) + proxy.start() + def rnsh_cli(): global verbose_set return_code = 1 exc = None try: - return_code = asyncio.run(_rnsh_cli_main()) + args = rnsh.args.Args(sys.argv) + if args.socks5: + return_code = asyncio.run(_rnsocks_cli_main(args)) + else: + return_code = asyncio.run(_rnsh_cli_main(args)) except SystemExit: pass except KeyboardInterrupt: diff --git a/tests/__init__.py b/rnsh/socksext/__init__.py similarity index 100% rename from tests/__init__.py rename to rnsh/socksext/__init__.py diff --git a/rnsh/socksext/counterpart.py b/rnsh/socksext/counterpart.py new file mode 100644 index 0000000..6069064 --- /dev/null +++ b/rnsh/socksext/counterpart.py @@ -0,0 +1,114 @@ +import os +import sys +import threading +import socket +import RNS + +from rnsh.socksext.socksproxy import SOCKS_APP_NAME +from rnsh.socksext.protocol import RequestMessage + +COUNTERPART_IDENTITY_FILE = "socks5_identity" + + +class SOCKS5CounterPart: + def __init__(self): + self.reticulum = RNS.Reticulum(configdir=None, loglevel=RNS.LOG_INFO) + self.identity = self.load_or_create_identity() + self.destination = RNS.Destination( + self.identity, RNS.Destination.IN, RNS.Destination.SINGLE, SOCKS_APP_NAME + ) + self.channels = {} + self.connections = {} + self.lock = threading.Lock() + self.next_link_id = 0 + + def load_or_create_identity(self): + if os.path.exists(COUNTERPART_IDENTITY_FILE): + identity = RNS.Identity.from_file(COUNTERPART_IDENTITY_FILE) + print("Loaded existing identity") + else: + identity = RNS.Identity() + identity.to_file(COUNTERPART_IDENTITY_FILE) + print("Created and saved new identity") + return identity + + def handle_message(self, message: RequestMessage, link_id: int): + try: + # Parse bytes directly + data = message.data + parts = data.split(b":", 2) + if len(parts) < 2: + print(f"Invalid message format on link {link_id}") + return + command = parts[0].decode('utf-8') # Command is ASCII + handler_id = int(parts[1].decode('utf-8')) # Handler ID is ASCII + payload = parts[2] if len(parts) > 2 else b"" + + if command == "CONNECT": + addr, port = payload.decode('utf-8').split(":", 1) # CONNECT payload is text + port = int(port) + print(f"Received CONNECT {handler_id} for {addr}:{port}") + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(10) + sock.connect((addr, port)) + with self.lock: + self.connections[handler_id] = sock + threading.Thread(target=self.relay_from_destination, args=(handler_id, sock, link_id), daemon=True).start() + + elif command == "DATA": + print(f"Received {len(payload)} bytes for {handler_id}") + with self.lock: + if handler_id in self.connections: + sock = self.connections[handler_id] + sock.sendall(payload) + else: + print(f"No connection for {handler_id}") + except Exception as e: + print(f"Error handling message: {e}") + + def relay_from_destination(self, handler_id: int, sock: socket.socket, link_id: int): + try: + while True: + data = sock.recv(4096) + if not data: + break + with self.lock: + if link_id in self.channels: + channel = self.channels[link_id] + response = RequestMessage() + response.data = f"DATA:{handler_id}:".encode() + data + channel.send(response) + print(f"Sent {len(data)} bytes back for {handler_id} on link {link_id}") + except Exception as e: + print(f"Error relaying from destination for {handler_id}: {e}") + finally: + with self.lock: + if handler_id in self.connections: + del self.connections[handler_id] + sock.close() + + def link_established(self, link): + link_id = self.next_link_id + self.next_link_id += 1 + print(f"Link {link_id} established from {link.get_remote_identity()}") + channel = link.get_channel() + channel.register_message_type(RequestMessage) + channel.add_message_handler(lambda msg: self.handle_message(msg, link_id)) + with self.lock: + self.channels[link_id] = channel + + def run(self): + print(f"Destination hash: {self.destination.hash.hex()}") + self.destination.set_link_established_callback(self.link_established) + self.destination.accepts_links(True) + self.destination.announce() + print("Counterpart running. Press Ctrl+C to exit.") + sys.stdout.flush() + + try: + threading.Event().wait() + except KeyboardInterrupt: + print("Shutting down...") + self.reticulum.exit_handler() + sys.stdout.flush() + sys.exit(0) \ No newline at end of file diff --git a/rnsh/socksext/protocol.py b/rnsh/socksext/protocol.py new file mode 100644 index 0000000..5dc6aad --- /dev/null +++ b/rnsh/socksext/protocol.py @@ -0,0 +1,23 @@ +import socket +import threading +from typing import Optional + +import RNS + + +class SOCKS5Request: + def __init__(self, client_socket: socket.socket, addr: str, port: int, handler_id: int): + self.client_socket = client_socket + self.addr = addr + self.port = port + self.handler_id = handler_id + self.response: Optional[bytes] = None + self.event = threading.Event() + + +class RequestMessage(RNS.MessageBase): + MSGTYPE = 0x0091 + def pack(self): + return self.data + def unpack(self, raw): + self.data = raw diff --git a/rnsh/socksext/socksproxy.py b/rnsh/socksext/socksproxy.py new file mode 100644 index 0000000..b9771f6 --- /dev/null +++ b/rnsh/socksext/socksproxy.py @@ -0,0 +1,339 @@ +import os +import queue +import socket +import struct +import threading +import time +from collections import deque +from typing import Dict, Tuple, Optional + +import RNS + +from rnsh.socksext.protocol import SOCKS5Request, RequestMessage + +SOCKS_APP_NAME = "socks5proxy" + +class SOCKS5Proxy: + request_handlers: Dict[int, SOCKS5Request] = {} + + def __init__(self, host='127.0.0.1', port=1080, destination_hash: str = None): + self.host = host + self.port = port + self.server_socket = None + self.request_queue = queue.Queue() + self.running = False + self.handler_id = 0 + self.lock = threading.Lock() + self.destination_hash = bytes.fromhex(destination_hash) if destination_hash else None + self.link_pool = LinkPool(self.destination_hash) if destination_hash else None + + def start(self): + self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self.server_socket.bind((self.host, self.port)) + self.server_socket.listen(10) + self.running = True + + if self.link_pool: + self.link_pool.start() + + threading.Thread(target=self.process_responses, daemon=True).start() + print(f"SOCKS5 proxy started on {self.host}:{self.port}") + + try: + while self.running: + client_socket, addr = self.server_socket.accept() + client_thread = threading.Thread( + target=self.handle_client, + args=(client_socket, addr) + ) + client_thread.daemon = True + client_thread.start() + except KeyboardInterrupt: + print("\nShutting down proxy...") + finally: + self.running = False + if self.link_pool: + self.link_pool.stop() + self.server_socket.close() + + def handle_client(self, client_socket: socket.socket, client_addr: tuple): + try: + version, nmethods = struct.unpack("!BB", client_socket.recv(2)) + if version != 5: + return + client_socket.recv(nmethods) + client_socket.sendall(struct.pack("!BB", 5, 0)) + + version, cmd, rsv, atype = struct.unpack("!BBBB", client_socket.recv(4)) + if version != 5 or cmd != 1: + return + + if atype == 1: + addr = socket.inet_ntoa(client_socket.recv(4)) + elif atype == 3: + addr_len = struct.unpack("!B", client_socket.recv(1))[0] + addr = client_socket.recv(addr_len).decode() + else: + return + + port = struct.unpack("!H", client_socket.recv(2))[0] + print(f"Request from {client_addr} to {addr}:{port}") + + with self.lock: + handler_id = self.handler_id + self.handler_id += 1 + request = SOCKS5Request(client_socket, addr, port, handler_id) + self.request_handlers[handler_id] = request + self.request_queue.put((handler_id, request)) + + if request.event.wait(timeout=30): + client_socket.sendall(struct.pack("!BBBB4sH", 5, 0, 0, 1, + socket.inet_aton("0.0.0.0"), 0)) + self.relay_data(request) + else: + print(f"Timeout waiting for connection to {addr}:{port}") + client_socket.sendall(b"\x05\x07\x00\x01\x00\x00\x00\x00\x00\x00") + except Exception as e: + print(f"Error handling client {client_addr}: {e}") + finally: + with self.lock: + for hid, req in list(self.request_handlers.items()): + if req.client_socket == client_socket: + del self.request_handlers[hid] + client_socket.close() + + def relay_data(self, request: SOCKS5Request): + try: + client_socket = request.client_socket + handler_id = request.handler_id + while self.running: + data = client_socket.recv(4096) + if not data: + break + self.link_pool.send_data(handler_id, data) + except Exception as e: + print(f"Error relaying data for handler {request.handler_id}: {e}") + + def process_responses(self): + while self.running: + try: + handler_id, request = self.request_queue.get(timeout=1.0) + + if self.link_pool: + if not self.link_pool.connect(handler_id, request.addr, request.port): + request.response = b"Failed to connect" + request.event.set() + else: + request.event.set() + else: + time.sleep(0.1) + response = f"Hello from {request.addr}:{request.port}".encode('utf-8') + request.response = ( + b"HTTP/1.1 200 OK\r\n" + b"Content-Type: text/plain\r\n" + b"Content-Length: " + str(len(response)).encode('utf-8') + b"\r\n" + b"\r\n" + response + ) + request.event.set() + + self.request_queue.task_done() + + except queue.Empty: + continue + except Exception as e: + print(f"Error processing response: {e}") + +class LinkPool: + def __init__(self, destination_hash: bytes, pool_size: int = 1, configdir: str = None): + self.destination_hash = destination_hash + self.pool_size = pool_size + self.configdir = configdir + self.links: Dict[int, RNS.Link] = {} + self.channels: Dict[int, RNS.Channel.Channel] = {} + self.active_link_ids = deque(maxlen=pool_size) + self.lock = threading.Lock() + self.reticulum = RNS.Reticulum(configdir=self.configdir, loglevel=RNS.LOG_INFO) + self.running = False + self.next_link_id = 0 + self.responses = {} + print("Initializing Reticulum network...") + self.identity = self.load_or_create_identity() + self.target_identity = RNS.Identity.recall(destination_hash) + if not self.target_identity: + print(f"Waiting for identity of {destination_hash.hex()}...") + timeout = time.time() + 10 + while not self.target_identity and time.time() < timeout: + self.target_identity = RNS.Identity.recall(destination_hash) + time.sleep(1) + if not self.target_identity: + raise RuntimeError(f"Could not recall identity for {destination_hash.hex()}") + self.target_destination = RNS.Destination( + self.target_identity, + RNS.Destination.OUT, + RNS.Destination.SINGLE, + SOCKS_APP_NAME, + ) + + def load_or_create_identity(self): + identity_file = "proxy_identity" + if os.path.exists(identity_file): + identity = RNS.Identity.from_file(identity_file) + print("Loaded proxy identity") + else: + identity = RNS.Identity() + identity.to_file(identity_file) + print("Created and saved proxy identity") + return identity + + def start(self): + self.running = True + threading.Thread(target=self.maintain_pool, daemon=True).start() + + def maintain_pool(self): + while self.running: + with self.lock: + print(f"Pool state: active={len(self.active_link_ids)}/{self.pool_size}, " + f"links={len(self.links)}, channels={len(self.channels)}") + + for link_id in list(self.links.keys()): + if self.links[link_id].status == RNS.Link.CLOSED: + self._cleanup_link(link_id) + + if len(self.active_link_ids) < self.pool_size: + if not RNS.Transport.has_path(self.destination_hash): + RNS.Transport.request_path(self.destination_hash) + print(f"Requested path to {self.destination_hash.hex()}") + time.sleep(1) + continue + + link = RNS.Link(self.target_destination) + link_id = self.next_link_id + self.next_link_id += 1 + self.links[link_id] = link + link.set_link_established_callback(lambda l: self.link_established(link_id)) + print(f"Initiated link {link_id}, status: {link.status}") + + timeout = time.time() + 10 + while link.status != RNS.Link.ACTIVE and time.time() < timeout: + time.sleep(0.1) + if link.status != RNS.Link.ACTIVE: + print(f"Link {link_id} failed to activate, status: {link.status}") + del self.links[link_id] + continue + print(f"Link {link_id} confirmed ACTIVE") + time.sleep(5) + + def _cleanup_link(self, link_id: int): + if link_id in self.links: + self.links[link_id].teardown() + del self.links[link_id] + if link_id in self.channels: + del self.channels[link_id] + if link_id in self.active_link_ids: + self.active_link_ids.remove(link_id) + if link_id in self.responses: + del self.responses[link_id] + print(f"Cleaned up link {link_id}") + + def link_established(self, link_id): + with self.lock: + if link_id in self.links: + link = self.links[link_id] + link.identify(self.identity) + channel = link.get_channel() + channel.register_message_type(RequestMessage) + channel.add_message_handler(lambda msg: self.handle_channel_message(msg, link_id)) + self.channels[link_id] = channel + if link_id not in self.active_link_ids: + self.active_link_ids.append(link_id) + print(f"Link {link_id} established with channel, active pool size: {len(self.active_link_ids)}") + + def get_available_link(self) -> Tuple[Optional[RNS.Link], Optional[int]]: + with self.lock: + if not self.active_link_ids: + print("No active links available") + return None, None + + for _ in range(len(self.active_link_ids)): + link_id = self.active_link_ids.popleft() + if (link_id in self.links and + link_id in self.channels and + self.links[link_id].status == RNS.Link.ACTIVE): + self.active_link_ids.append(link_id) + print(f"Using active link {link_id}") + return self.links[link_id], link_id + else: + print(f"Link {link_id} invalid, cleaning up") + self._cleanup_link(link_id) + print("No active links with channels available") + return None, None + + def connect(self, handler_id: int, addr: str, port: int): + link, link_id = self.get_available_link() + if link and link_id is not None: + channel = self.channels[link_id] + if channel and channel.is_ready_to_send(): + request_data = f"CONNECT:{handler_id}:{addr}:{port}".encode() + message = RequestMessage() + message.data = request_data + channel.send(message) + print(f"Sent CONNECT request {handler_id} over channel on link {link_id}") + with self.lock: + self.responses[handler_id] = queue.Queue() + return True + print(f"Failed to connect request {handler_id}: No link available") + return False + + def send_data(self, handler_id: int, data: bytes): + link, link_id = self.get_available_link() + if link and link_id is not None: + channel = self.channels[link_id] + if channel and channel.is_ready_to_send(): + message = RequestMessage() + message.data = f"DATA:{handler_id}:".encode() + data + channel.send(message) + print(f"Sent {len(data)} bytes for {handler_id} over channel on link {link_id}") + else: + print(f"Channel not ready for link {link_id}") + else: + print(f"No link available to send data for {handler_id}") + + def handle_channel_message(self, message, link_id): + try: + # Parse bytes directly, no full decode + data = message.data + parts = data.split(b":", 2) + if len(parts) < 2: + print(f"Invalid message format on link {link_id}") + return + command = parts[0].decode('utf-8') # Command is ASCII + handler_id = int(parts[1].decode('utf-8')) # Handler ID is ASCII + payload = parts[2] if len(parts) > 2 else b"" + + with self.lock: + if handler_id not in SOCKS5Proxy.request_handlers: + print(f"Handler {handler_id} not found") + return + request = SOCKS5Proxy.request_handlers[handler_id] + + if command == "DATA": + with self.lock: + if handler_id in self.responses: + self.responses[handler_id].put(payload) + print(f"Link {link_id} queued {len(payload)} bytes for {handler_id}") + request.client_socket.sendall(payload) + else: + print(f"Unknown command {command} for {handler_id} on link {link_id}") + except Exception as e: + print(f"Error handling channel message on link {link_id}: {e}") + + def stop(self): + self.running = False + with self.lock: + for link in self.links.values(): + link.teardown() + self.links.clear() + self.channels.clear() + self.active_link_ids.clear() + self.responses.clear() \ No newline at end of file diff --git a/rnsh/testlogging.py b/rnsh/testlogging.py deleted file mode 100644 index 38d59c2..0000000 --- a/rnsh/testlogging.py +++ /dev/null @@ -1,36 +0,0 @@ -# MIT License -# -# Copyright (c) 2023 Aaron Heise -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -import logging as __logging -import os - -log_format = '%(levelname)-6s %(name)-40s %(message)s [%(threadName)s]' \ - if os.environ.get('UNDER_SYSTEMD') == "1" \ - else '\r%(asctime)s.%(msecs)03d %(levelname)-6s %(name)-40s %(message)s [%(threadName)s]' - -__logging.basicConfig( - level=__logging.INFO, - # format='%(asctime)s.%(msecs)03d %(levelname)-6s %(threadName)-15s %(name)-15s %(message)s', - format=log_format, - datefmt='%Y-%m-%d %H:%M:%S', - handlers=[__logging.StreamHandler()]) - diff --git a/tests/helpers.py b/tests/helpers.py deleted file mode 100644 index 3d32045..0000000 --- a/tests/helpers.py +++ /dev/null @@ -1,164 +0,0 @@ -import logging -import time -import types -import typing -import tempfile - -import pytest - -import rnsh.rnsh -import asyncio -import rnsh.process -import contextlib -import threading -import os -import pathlib -import tests -import shutil -import random - -module_logger = logging.getLogger(__name__) - -module_abs_filename = os.path.abspath(tests.__file__) -module_dir = os.path.dirname(module_abs_filename) - - -class SubprocessReader(contextlib.AbstractContextManager): - def __init__(self, argv: [str], env: dict = None, name: str = None, stdin_is_pipe: bool = False, - stdout_is_pipe: bool = False, stderr_is_pipe: bool = False): - self._log = module_logger.getChild(self.__class__.__name__ + ("" if name is None else f"({name})")) - self.name = name or "subproc" - self.process: rnsh.process.CallbackSubprocess - self.loop = asyncio.get_running_loop() - self.env = env or os.environ.copy() - self.argv = argv - self._lock = threading.RLock() - self._stdout = bytearray() - self._stderr = bytearray() - self.return_code: int = None - self.process = rnsh.process.CallbackSubprocess(argv=self.argv, - env=self.env, - loop=self.loop, - stdout_callback=self._stdout_cb, - terminated_callback=self._terminated_cb, - stderr_callback=self._stderr_cb, - stdin_is_pipe=stdin_is_pipe, - stdout_is_pipe=stdout_is_pipe, - stderr_is_pipe=stderr_is_pipe) - - def _stdout_cb(self, data): - self._log.debug(f"_stdout_cb({data})") - with self._lock: - self._stdout.extend(data) - - def read(self): - self._log.debug(f"read()") - with self._lock: - data = self._stdout.copy() - self._stdout.clear() - self._log.debug(f"read() returns {data}") - return data - - def _stderr_cb(self, data): - self._log.debug(f"_stderr_cb({data})") - with self._lock: - self._stderr.extend(data) - - def read_err(self): - self._log.debug(f"read_err()") - with self._lock: - data = self._stderr.copy() - self._stderr.clear() - self._log.debug(f"read_err() returns {data}") - return data - - def _terminated_cb(self, rc): - self._log.debug(f"_terminated_cb({rc})") - self.return_code = rc - - def start(self): - self._log.debug(f"start()") - self.process.start() - - def cleanup(self): - self._log.debug(f"cleanup()") - if self.process and self.process.running: - self.process.terminate(kill_delay=0.1) - time.sleep(0.5) - - def __exit__(self, __exc_type: typing.Type[BaseException], __exc_value: BaseException, - __traceback: types.TracebackType) -> bool: - self._log.debug(f"__exit__({__exc_type}, {__exc_value}, {__traceback})") - self.cleanup() - return False - - -def replace_text_in_file(filename: str, text: str, replacement: str): - # Read in the file - with open(filename, 'r') as file: - filedata = file.read() - - # Replace the target string - filedata = filedata.replace(text, replacement) - - # Write the file out again - with open(filename, 'w') as file: - file.write(filedata) - - -class tempdir(object): - """Sets the cwd within the context - - Args: - path (Path): The path to the cwd - """ - def __init__(self, cd: bool = False): - self.cd = cd - self.tempdir = tempfile.TemporaryDirectory() - self.path = self.tempdir.name - self.origin = pathlib.Path().absolute() - self.configfile = os.path.join(self.path, "config") - - def setup_files(self): - shutil.copy(os.path.join(module_dir, "reticulum_test_config"), self.configfile) - port1 = random.randint(30000, 65000) - port2 = port1 + 1 - replace_text_in_file(self.configfile, "22222", str(port1)) - replace_text_in_file(self.configfile, "22223", str(port2)) - - - def __enter__(self): - self.setup_files() - if self.cd: - os.chdir(self.path) - - return self.path - - def __exit__(self, exc, value, tb): - if self.cd: - os.chdir(self.origin) - self.tempdir.__exit__(exc, value, tb) - - -def test_config_and_cleanup(): - td = None - with tests.helpers.tempdir() as td: - assert os.path.isfile(os.path.join(td, "config")) - with open(os.path.join(td, "config"), 'r') as file: - filedata = file.read() - assert filedata.index("acehoss test config") > 0 - with pytest.raises(ValueError): - filedata.index("22222") - assert not os.path.exists(os.path.join(td, "config")) - - -def wait_for_condition(condition: callable, timeout: float): - tm = time.time() + timeout - while tm > time.time() and not condition(): - time.sleep(0.01) - - -async def wait_for_condition_async(condition: callable, timeout: float): - tm = time.time() + timeout - while tm > time.time() and not condition(): - await asyncio.sleep(0.01) \ No newline at end of file diff --git a/tests/reticulum_test_config b/tests/reticulum_test_config deleted file mode 100644 index c02ab68..0000000 --- a/tests/reticulum_test_config +++ /dev/null @@ -1,15 +0,0 @@ -# acehoss test config -[reticulum] - enable_transport = False - share_instance = Yes - shared_instance_port = 22222 - instance_control_port = 22223 - panic_on_interface_error = No - -[logging] - loglevel = 7 - -[interfaces] - [[Default Interface]] - type = AutoInterface - enabled = Yes diff --git a/tests/test_args.py b/tests/test_args.py deleted file mode 100644 index fec13a5..0000000 --- a/tests/test_args.py +++ /dev/null @@ -1,106 +0,0 @@ -import rnsh.args -import shlex -from rnsh import docopt - -def test_program_args(): - docopt_threw = False - try: - args = rnsh.args.Args(shlex.split("rnsh -l -n one two three")) - assert args.listen - assert args.program == "one" - assert args.program_args == ["two", "three"] - assert args.command_line == ["one", "two", "three"] - except docopt.DocoptExit: - docopt_threw = True - assert not docopt_threw - - -def test_program_args_dash(): - docopt_threw = False - try: - args = rnsh.args.Args(shlex.split("rnsh -l -n -- one -l -C")) - assert args.listen - assert args.program == "one" - assert args.program_args == ["-l", "-C"] - assert args.command_line == ["one", "-l", "-C"] - except docopt.DocoptExit: - docopt_threw = True - assert not docopt_threw - -def test_program_initiate_no_args(): - docopt_threw = False - try: - args = rnsh.args.Args(shlex.split("rnsh one")) - assert not args.listen - assert args.destination == "one" - assert not args.no_id - assert args.command_line == [] - except docopt.DocoptExit: - docopt_threw = True - assert not docopt_threw - - -def test_program_initiate_no_auth(): - docopt_threw = False - try: - args = rnsh.args.Args(shlex.split("rnsh -N one")) - assert not args.listen - assert args.destination == "one" - assert args.no_id - assert args.command_line == [] - except docopt.DocoptExit: - docopt_threw = True - assert not docopt_threw - - -def test_program_initiate_dash_args(): - docopt_threw = False - try: - args = rnsh.args.Args(shlex.split("rnsh --config ~/Projects/rnsh/testconfig -vvvvvvv a5f72aefc2cb3cdba648f73f77c4e887 -- -l")) - assert not args.listen - assert args.config == "~/Projects/rnsh/testconfig" - assert args.verbose == 7 - assert args.destination == "a5f72aefc2cb3cdba648f73f77c4e887" - assert args.command_line == ["-l"] - except docopt.DocoptExit: - docopt_threw = True - assert not docopt_threw - - -def test_program_listen_dash_args(): - docopt_threw = False - try: - args = rnsh.args.Args(shlex.split("rnsh -l --config ~/Projects/rnsh/testconfig -n -C -- /bin/pwd")) - assert args.listen - assert args.config == "~/Projects/rnsh/testconfig" - assert args.destination is None - assert args.no_auth - assert args.no_remote_cmd - assert args.command_line == ["/bin/pwd"] - except docopt.DocoptExit: - docopt_threw = True - assert not docopt_threw - - -def test_program_listen_config_print(): - docopt_threw = False - try: - args = rnsh.args.Args(shlex.split("rnsh -l --config testconfig -p")) - assert args.listen - assert args.config == "testconfig" - assert args.print_identity - assert args.command_line == [] - except docopt.DocoptExit: - docopt_threw = True - assert not docopt_threw - - -def test_split_at(): - a, b = rnsh.args._split_array_at(["one", "two", "three"], "two") - assert a == ["one"] - assert b == ["three"] - -def test_split_at_not_found(): - a, b = rnsh.args._split_array_at(["one", "two", "three"], "four") - assert a == ["one", "two", "three"] - assert b == [] \ No newline at end of file diff --git a/tests/test_exception.py b/tests/test_exception.py deleted file mode 100644 index 32a6b09..0000000 --- a/tests/test_exception.py +++ /dev/null @@ -1,9 +0,0 @@ -import pytest -import rnsh.exception as exception - -def test_permit(): - with pytest.raises(SystemExit): - with exception.permit(SystemExit): - raise Exception("Should not bubble") - with exception.permit(SystemExit): - raise SystemExit() \ No newline at end of file diff --git a/tests/test_process.py b/tests/test_process.py deleted file mode 100644 index 7f8ebe8..0000000 --- a/tests/test_process.py +++ /dev/null @@ -1,241 +0,0 @@ -import tests.helpers -import time -import pytest -import rnsh.process -import asyncio -import logging -import multiprocessing.pool -logging.getLogger().setLevel(logging.DEBUG) - - -@pytest.mark.skip_ci -@pytest.mark.asyncio -async def test_echo(): - """ - Echoing some text through cat. - """ - with tests.helpers.SubprocessReader(argv=["/bin/cat"]) as state: - state.start() - assert state.process is not None - assert state.process.running - message = "test\n" - state.process.write(message.encode("utf-8")) - await asyncio.sleep(0.1) - data = state.read() - state.process.write(rnsh.process.CTRL_D) - await asyncio.sleep(0.1) - assert len(data) > 0 - decoded = data.decode("utf-8") - assert decoded == message.replace("\n", "\r\n") * 2 - assert not state.process.running - - -@pytest.mark.skip_ci -@pytest.mark.asyncio -async def test_echo_live(): - """ - Check for immediate echo - """ - with tests.helpers.SubprocessReader(argv=["/bin/cat"]) as state: - state.start() - assert state.process is not None - assert state.process.running - message = "t" - state.process.write(message.encode("utf-8")) - await asyncio.sleep(0.1) - data = state.read() - state.process.write(rnsh.process.CTRL_C) - await asyncio.sleep(0.1) - assert len(data) > 0 - decoded = data.decode("utf-8") - assert decoded == message - assert not state.process.running - - -@pytest.mark.skip_ci -@pytest.mark.asyncio -async def test_echo_live_pipe_in(): - """ - Check for immediate echo - """ - with tests.helpers.SubprocessReader(argv=["/bin/cat"], stdin_is_pipe=True) as state: - state.start() - assert state.process is not None - assert state.process.running - message = "t" - state.process.write(message.encode("utf-8")) - await asyncio.sleep(0.1) - data = state.read() - state.process.close_stdin() - await asyncio.sleep(0.1) - assert len(data) > 0 - decoded = data.decode("utf-8") - assert decoded == message - assert not state.process.running - - -@pytest.mark.skip_ci -@pytest.mark.asyncio -async def test_echo_live_pipe_out(): - """ - Check for immediate echo - """ - with tests.helpers.SubprocessReader(argv=["/bin/cat"], stdout_is_pipe=True) as state: - state.start() - assert state.process is not None - assert state.process.running - message = "t" - state.process.write(message.encode("utf-8")) - state.process.write(rnsh.process.CTRL_D) - await asyncio.sleep(0.1) - data = state.read() - assert len(data) > 0 - decoded = data.decode("utf-8") - assert decoded == message - data = state.read_err() - assert len(data) > 0 - state.process.close_stdin() - await asyncio.sleep(0.1) - assert not state.process.running - - -@pytest.mark.skip_ci -@pytest.mark.asyncio -async def test_echo_live_pipe_err(): - """ - Check for immediate echo - """ - with tests.helpers.SubprocessReader(argv=["/bin/cat"], stderr_is_pipe=True) as state: - state.start() - assert state.process is not None - assert state.process.running - message = "t" - state.process.write(message.encode("utf-8")) - await asyncio.sleep(0.1) - data = state.read() - state.process.write(rnsh.process.CTRL_C) - await asyncio.sleep(0.1) - assert len(data) > 0 - decoded = data.decode("utf-8") - assert decoded == message - assert not state.process.running - - -@pytest.mark.skip_ci -@pytest.mark.asyncio -async def test_echo_live_pipe_out_err(): - """ - Check for immediate echo - """ - with tests.helpers.SubprocessReader(argv=["/bin/cat"], stdout_is_pipe=True, stderr_is_pipe=True) as state: - state.start() - assert state.process is not None - assert state.process.running - message = "t" - state.process.write(message.encode("utf-8")) - state.process.write(rnsh.process.CTRL_D) - await asyncio.sleep(0.1) - data = state.read() - assert len(data) > 0 - decoded = data.decode("utf-8") - assert decoded == message - data = state.read_err() - assert len(data) == 0 - state.process.close_stdin() - await asyncio.sleep(0.1) - assert not state.process.running - - - -@pytest.mark.skip_ci -@pytest.mark.asyncio -async def test_echo_live_pipe_all(): - """ - Check for immediate echo - """ - with tests.helpers.SubprocessReader(argv=["/bin/cat"], stdout_is_pipe=True, stderr_is_pipe=True, - stdin_is_pipe=True) as state: - state.start() - assert state.process is not None - assert state.process.running - message = "t" - state.process.write(message.encode("utf-8")) - await asyncio.sleep(0.1) - data = state.read() - state.process.close_stdin() - await asyncio.sleep(0.1) - assert len(data) > 0 - decoded = data.decode("utf-8") - assert decoded == message - assert not state.process.running - - -@pytest.mark.skip_ci -@pytest.mark.asyncio -async def test_double_echo_live(): - """ - Check for immediate echo - """ - with tests.helpers.SubprocessReader(name="state", argv=["/bin/cat"]) as state: - with tests.helpers.SubprocessReader(name="state2", argv=["/bin/cat"]) as state2: - state.start() - state2.start() - assert state.process is not None - assert state.process.running - assert state2.process is not None - assert state2.process.running - message = "t" - state.process.write(message.encode("utf-8")) - state2.process.write(message.encode("utf-8")) - await asyncio.sleep(0.1) - data = state.read() - data2 = state2.read() - state.process.write(rnsh.process.CTRL_C) - state2.process.write(rnsh.process.CTRL_C) - await asyncio.sleep(0.1) - assert len(data) > 0 - assert len(data2) > 0 - decoded = data.decode("utf-8") - decoded2 = data.decode("utf-8") - assert decoded == message - assert decoded2 == message - assert not state.process.running - assert not state2.process.running - - -@pytest.mark.asyncio -async def test_event_wait_any(): - delay = 0.5 - with multiprocessing.pool.ThreadPool() as pool: - loop = asyncio.get_running_loop() - evt1 = asyncio.Event() - evt2 = asyncio.Event() - - def assert_between(min, max, val): - assert min <= val <= max - - # test 1: both timeout - ts = time.time() - finished = await rnsh.process.event_wait_any([evt1, evt2], timeout=delay*2) - assert_between(delay*2, delay*2.1, time.time() - ts) - assert finished is None - assert not evt1.is_set() - assert not evt2.is_set() - - #test 2: evt1 set, evt2 not set - hits = 0 - - def test2_bg(): - nonlocal hits - hits += 1 - time.sleep(delay) - evt1.set() - - ts = time.time() - pool.apply_async(test2_bg) - finished = await rnsh.process.event_wait_any([evt1, evt2], timeout=delay * 2) - assert_between(delay * 0.5, delay * 1.5, time.time() - ts) - assert hits == 1 - assert evt1.is_set() - assert not evt2.is_set() - assert finished == evt1 diff --git a/tests/test_protocol.py b/tests/test_protocol.py deleted file mode 100644 index 66a4e23..0000000 --- a/tests/test_protocol.py +++ /dev/null @@ -1,109 +0,0 @@ -from __future__ import annotations - -import logging - -from RNS.Channel import TPacket, MessageState, ChannelOutletBase, Channel -from typing import Callable - -logging.getLogger().setLevel(logging.DEBUG) - -import rnsh.protocol -import contextlib -import typing -import types -import time -import uuid -from RNS.Channel import MessageBase - - -module_logger = logging.getLogger(__name__) - - -def test_send_receive_streamdata(): - message = rnsh.protocol.StreamDataMessage(stream_id=rnsh.protocol.StreamDataMessage.STREAM_ID_STDIN, - data=b'Test', eof=True) - rx_message = message.__class__() - rx_message.unpack(message.pack()) - - assert isinstance(rx_message, message.__class__) - assert rx_message.stream_id == message.stream_id - assert rx_message.data == message.data - assert rx_message.eof == message.eof - - -def test_send_receive_noop(): - message = rnsh.protocol.NoopMessage() - - rx_message = message.__class__() - rx_message.unpack(message.pack()) - - assert isinstance(rx_message, message.__class__) - - -def test_send_receive_execute(): - message = rnsh.protocol.ExecuteCommandMesssage(cmdline=["test", "one", "two"], - pipe_stdin=False, - pipe_stdout=True, - pipe_stderr=False, - tcflags=[12, 34, 56, [78, 90]], - term="xtermmmm") - rx_message = message.__class__() - rx_message.unpack(message.pack()) - - assert isinstance(rx_message, message.__class__) - assert rx_message.cmdline == message.cmdline - assert rx_message.pipe_stdin == message.pipe_stdin - assert rx_message.pipe_stdout == message.pipe_stdout - assert rx_message.pipe_stderr == message.pipe_stderr - assert rx_message.tcflags == message.tcflags - assert rx_message.term == message.term - - -def test_send_receive_windowsize(): - message = rnsh.protocol.WindowSizeMessage(1, 2, 3, 4) - rx_message = message.__class__() - rx_message.unpack(message.pack()) - - assert isinstance(rx_message, message.__class__) - assert rx_message.rows == message.rows - assert rx_message.cols == message.cols - assert rx_message.hpix == message.hpix - assert rx_message.vpix == message.vpix - - -def test_send_receive_versioninfo(): - message = rnsh.protocol.VersionInfoMessage(sw_version="1.2.3") - message.protocol_version = 30 - rx_message = message.__class__() - rx_message.unpack(message.pack()) - - assert isinstance(rx_message, message.__class__) - assert rx_message.sw_version == message.sw_version - assert rx_message.protocol_version == message.protocol_version - - -def test_send_receive_error(): - message = rnsh.protocol.ErrorMessage(msg="TESTerr", - fatal=True, - data={"one": 2}) - rx_message = message.__class__() - rx_message.unpack(message.pack()) - - assert isinstance(rx_message, message.__class__) - assert rx_message.msg == message.msg - assert rx_message.fatal == message.fatal - assert rx_message.data == message.data - - -def test_send_receive_cmdexit(): - message = rnsh.protocol.CommandExitedMessage(5) - rx_message = message.__class__() - rx_message.unpack(message.pack()) - - assert isinstance(rx_message, message.__class__) - assert rx_message.return_code == message.return_code - - - - - diff --git a/tests/test_retry.py b/tests/test_retry.py deleted file mode 100644 index 24f5485..0000000 --- a/tests/test_retry.py +++ /dev/null @@ -1,162 +0,0 @@ -import uuid -import time -from types import TracebackType -from typing import Type - -import rnsh.retry -from contextlib import AbstractContextManager -import logging -logging.getLogger().setLevel(logging.DEBUG) - - -class State(AbstractContextManager): - def __init__(self, delay: float): - self.delay = delay - self.retry_thread = rnsh.retry.RetryThread(self.delay / 10.0) - self.tries = 0 - self.callbacks = 0 - self.timed_out = False - self.tag = str(uuid.uuid4()) - self.results = [self.tag, self.tag, self.tag] - self.got_tag = None - assert self.retry_thread.is_alive() - - def cleanup(self): - self.retry_thread.wait() - assert self.tries != 0 - self.retry_thread.close() - assert not self.retry_thread.is_alive() - - def retry(self, tag, tries): - self.tries = tries - self.got_tag = tag - self.callbacks += 1 - return self.results[tries - 1] - - def timeout(self, tag, tries): - self.tries = tries - self.got_tag = tag - self.timed_out = True - self.callbacks += 1 - - def __exit__(self, __exc_type: Type[BaseException], __exc_value: BaseException, - __traceback: TracebackType) -> bool: - self.cleanup() - return False - - -def test_retry_timeout(): - - with State(0.1) as state: - return_tag = state.retry_thread.begin(try_limit=3, - wait_delay=state.delay, - try_callback=state.retry, - timeout_callback=state.timeout) - assert return_tag == state.tag - assert state.tries == 1 - assert state.callbacks == 1 - assert state.got_tag is None - assert not state.timed_out - time.sleep(state.delay / 2.0) - time.sleep(state.delay) - assert state.tries == 2 - assert state.callbacks == 2 - assert state.got_tag == state.tag - assert not state.timed_out - time.sleep(state.delay) - assert state.tries == 3 - assert state.callbacks == 3 - assert state.got_tag == state.tag - assert not state.timed_out - - # check timeout - time.sleep(state.delay) - assert state.tries == 3 - assert state.callbacks == 4 - assert state.got_tag == state.tag - assert state.timed_out - - # check no more callbacks - time.sleep(state.delay * 3.0) - assert state.callbacks == 4 - assert state.tries == 3 - - -def test_retry_immediate_complete(): - with State(0.01) as state: - state.results[0] = False - return_tag = state.retry_thread.begin(try_limit=3, - wait_delay=state.delay, - try_callback=state.retry, - timeout_callback=state.timeout) - assert not return_tag - assert state.callbacks == 1 - assert not state.got_tag - assert not state.timed_out - time.sleep(state.delay * 3) - assert state.tries == 1 - assert state.callbacks == 1 - assert not state.got_tag - assert not state.timed_out - - -def test_retry_return_complete(): - with State(0.01) as state: - state.results[1] = False - return_tag = state.retry_thread.begin(try_limit=3, - wait_delay=state.delay, - try_callback=state.retry, - timeout_callback=state.timeout) - assert return_tag == state.tag - assert state.callbacks == 1 - assert state.got_tag is None - assert not state.timed_out - time.sleep(state.delay / 2.0) - time.sleep(state.delay) - assert state.tries == 2 - assert state.callbacks == 2 - assert state.got_tag == state.tag - assert not state.timed_out - - time.sleep(state.delay) - assert state.tries == 2 - assert state.callbacks == 2 - assert state.got_tag == state.tag - assert not state.timed_out - - # check no more callbacks - time.sleep(state.delay * 3.0) - assert state.callbacks == 2 - assert state.tries == 2 - - -def test_retry_set_complete(): - with State(0.01) as state: - return_tag = state.retry_thread.begin(try_limit=3, - wait_delay=state.delay, - try_callback=state.retry, - timeout_callback=state.timeout) - assert return_tag == state.tag - assert state.callbacks == 1 - assert state.got_tag is None - assert not state.timed_out - time.sleep(state.delay / 2.0) - time.sleep(state.delay) - assert state.tries == 2 - assert state.callbacks == 2 - assert state.got_tag == state.tag - assert not state.timed_out - - state.retry_thread.complete(state.tag) - - time.sleep(state.delay) - assert state.tries == 2 - assert state.callbacks == 2 - assert state.got_tag == state.tag - assert not state.timed_out - - # check no more callbacks - time.sleep(state.delay * 3.0) - assert state.callbacks == 2 - assert state.tries == 2 - diff --git a/tests/test_rnsh.py b/tests/test_rnsh.py deleted file mode 100644 index 6cde405..0000000 --- a/tests/test_rnsh.py +++ /dev/null @@ -1,223 +0,0 @@ -import logging -logging.getLogger().setLevel(logging.DEBUG) - -import tests.helpers -import rnsh.rnsh -import rnsh.process -import shlex -import pytest -import time -import asyncio -import re -import os - - -def test_version(): - assert rnsh.__version__ != "0.0.0" - assert rnsh.__version__ != "0.0.1" - - -@pytest.mark.skip_ci -@pytest.mark.asyncio -async def test_wrapper(): - with tests.helpers.tempdir() as td: - with tests.helpers.SubprocessReader(argv=shlex.split(f"date")) as wrapper: - wrapper.start() - assert wrapper.process is not None - assert wrapper.process.running - await asyncio.sleep(1) - text = wrapper.read().decode("utf-8") - assert len(text) > 5 - assert not wrapper.process.running - - - -@pytest.mark.skip_ci -@pytest.mark.asyncio -async def test_rnsh_listen_start_stop(): - with tests.helpers.tempdir() as td: - with tests.helpers.SubprocessReader(argv=shlex.split(f"poetry run rnsh -l --config \"{td}\" -n -C -vvvvvv -- /bin/ls")) as wrapper: - wrapper.start() - await asyncio.sleep(0.1) - assert wrapper.process.running - # wait for process to start up - await asyncio.sleep(3) - # read the output - text = wrapper.read().decode("utf-8") - # listener should have printed "listening - assert text.index("listening") is not None - # stop process with SIGINT - wrapper.process.write(rnsh.process.CTRL_C) - # wait for process to wind down - start_time = time.time() - while wrapper.process.running and time.time() - start_time < 5: - await asyncio.sleep(0.1) - assert not wrapper.process.running - - -async def get_listener_id_and_dest(td: str) -> tuple[str, str]: - with tests.helpers.SubprocessReader(name="getid", argv=shlex.split(f"poetry run -- rnsh -l -c \"{td}\" -p")) as wrapper: - wrapper.start() - await asyncio.sleep(0.1) - assert wrapper.process.running - # wait for process to start up - await tests.helpers.wait_for_condition_async(lambda: not wrapper.process.running, 5) - assert not wrapper.process.running - await asyncio.sleep(2) - # read the output - text = wrapper.read().decode("utf-8").replace("\r", "").replace("\n", "") - assert text.index("Using service name \"default\"") is not None - assert text.index("Identity") is not None - match = re.search(r"<([a-f0-9]{32})>[^<]+<([a-f0-9]{32})>", text) - assert match is not None - ih = match.group(1) - assert len(ih) == 32 - dh = match.group(2) - assert len(dh) == 32 - await asyncio.sleep(0.1) - return ih, dh - - -async def get_initiator_id(td: str) -> str: - with tests.helpers.SubprocessReader(name="getid", argv=shlex.split(f"poetry run -- rnsh -c \"{td}\" -p")) as wrapper: - wrapper.start() - await asyncio.sleep(0.1) - assert wrapper.process.running - # wait for process to start up - await tests.helpers.wait_for_condition_async(lambda: not wrapper.process.running, 5) - assert not wrapper.process.running - # read the output - text = wrapper.read().decode("utf-8").replace("\r", "").replace("\n", "") - assert text.index("Identity") is not None - match = re.search(r"<([a-f0-9]{32})>", text) - assert match is not None - ih = match.group(1) - assert len(ih) == 32 - await asyncio.sleep(0.1) - return ih - - - -@pytest.mark.skip_ci -@pytest.mark.asyncio -async def test_rnsh_get_listener_id_and_dest() -> [int]: - with tests.helpers.tempdir() as td: - ih, dh = await get_listener_id_and_dest(td) - assert len(ih) == 32 - assert len(dh) == 32 - - -@pytest.mark.skip_ci -@pytest.mark.asyncio -async def test_rnsh_get_initiator_id() -> [int]: - with tests.helpers.tempdir() as td: - ih = await get_initiator_id(td) - assert len(ih) == 32 - - -async def do_connected_test(listener_args: str, initiator_args: str, test: callable): - with tests.helpers.tempdir() as td: - ih, dh = await get_listener_id_and_dest(td) - iih = await get_initiator_id(td) - assert len(ih) == 32 - assert len(dh) == 32 - assert len(iih) == 32 - assert "dh" in initiator_args - initiator_args = initiator_args.replace("dh", dh) - listener_args = listener_args.replace("iih", iih) - with tests.helpers.SubprocessReader(name="listener", argv=shlex.split(f"poetry run -- rnsh -l -c \"{td}\" {listener_args}")) as listener, \ - tests.helpers.SubprocessReader(name="initiator", argv=shlex.split(f"poetry run -- rnsh -q -c \"{td}\" {initiator_args}")) as initiator: - # listener startup - listener.start() - await asyncio.sleep(0.1) - assert listener.process.running - # wait for process to start up - await asyncio.sleep(5) - # read the output - text = listener.read().decode("utf-8") - assert text.index(dh) is not None - - # initiator run - initiator.start() - assert initiator.process.running - - await test(td, ih, dh, iih, listener, initiator) - - # expect test to shut down initiator - assert not initiator.process.running - - # stop process with SIGINT - listener.process.write(rnsh.process.CTRL_C) - # wait for process to wind down - start_time = time.time() - while listener.process.running and time.time() - start_time < 5: - await asyncio.sleep(0.1) - assert not listener.process.running - - -@pytest.mark.skip_ci -@pytest.mark.asyncio -async def test_rnsh_get_echo_through(): - cwd = os.getcwd() - - async def test(td: str, ih: str, dh: str, iih: str, listener: tests.helpers.SubprocessReader, - initiator: tests.helpers.SubprocessReader): - start_time = time.time() - while initiator.return_code is None and time.time() - start_time < 3: - await asyncio.sleep(0.1) - text = initiator.read().decode("utf-8").replace("\r", "").replace("\n", "") - assert text == cwd - - await do_connected_test("-n -C -- /bin/pwd", "dh", test) - - -@pytest.mark.skip_ci -@pytest.mark.asyncio -async def test_rnsh_no_ident(): - cwd = os.getcwd() - - async def test(td: str, ih: str, dh: str, iih: str, listener: tests.helpers.SubprocessReader, - initiator: tests.helpers.SubprocessReader): - start_time = time.time() - while initiator.return_code is None and time.time() - start_time < 3: - await asyncio.sleep(0.1) - text = initiator.read().decode("utf-8").replace("\r", "").replace("\n", "") - assert text == cwd - - await do_connected_test("-n -C -- /bin/pwd", "-N dh", test) - - -@pytest.mark.skip_ci -@pytest.mark.asyncio -async def test_rnsh_invalid_ident(): - cwd = os.getcwd() - - async def test(td: str, ih: str, dh: str, iih: str, listener: tests.helpers.SubprocessReader, - initiator: tests.helpers.SubprocessReader): - start_time = time.time() - while initiator.return_code is None and time.time() - start_time < 3: - await asyncio.sleep(0.1) - text = initiator.read().decode("utf-8").replace("\r", "").replace("\n", "") - assert "not allowed" in text - - await do_connected_test("-a 12345678901234567890123456789012 -C -- /bin/pwd", "dh", test) - - -@pytest.mark.skip_ci -@pytest.mark.asyncio -async def test_rnsh_valid_ident(): - cwd = os.getcwd() - - async def test(td: str, ih: str, dh: str, iih: str, listener: tests.helpers.SubprocessReader, - initiator: tests.helpers.SubprocessReader): - start_time = time.time() - while initiator.return_code is None and time.time() - start_time < 3: - await asyncio.sleep(0.1) - text = initiator.read().decode("utf-8").replace("\r", "").replace("\n", "") - assert (text == cwd) - - await do_connected_test("-a iih -C -- /bin/pwd", "dh", test) - - - - From 374e73a582e4eae90721717450287c3e41fb0abd Mon Sep 17 00:00:00 2001 From: main Date: Mon, 17 Mar 2025 22:12:09 -0700 Subject: [PATCH 2/2] announcement + retry on send --- rnsh/socksext/counterpart.py | 62 ++++++++++++++++++++++++++++-------- rnsh/socksext/socksproxy.py | 32 ++++++++++++------- 2 files changed, 69 insertions(+), 25 deletions(-) diff --git a/rnsh/socksext/counterpart.py b/rnsh/socksext/counterpart.py index 6069064..63f519e 100644 --- a/rnsh/socksext/counterpart.py +++ b/rnsh/socksext/counterpart.py @@ -2,6 +2,7 @@ import sys import threading import socket +import time import RNS from rnsh.socksext.socksproxy import SOCKS_APP_NAME @@ -11,7 +12,7 @@ class SOCKS5CounterPart: - def __init__(self): + def __init__(self, announce_interval: int = 60): self.reticulum = RNS.Reticulum(configdir=None, loglevel=RNS.LOG_INFO) self.identity = self.load_or_create_identity() self.destination = RNS.Destination( @@ -21,6 +22,8 @@ def __init__(self): self.connections = {} self.lock = threading.Lock() self.next_link_id = 0 + self.running = False + self.announce_interval = announce_interval def load_or_create_identity(self): if os.path.exists(COUNTERPART_IDENTITY_FILE): @@ -34,18 +37,17 @@ def load_or_create_identity(self): def handle_message(self, message: RequestMessage, link_id: int): try: - # Parse bytes directly data = message.data parts = data.split(b":", 2) if len(parts) < 2: print(f"Invalid message format on link {link_id}") return - command = parts[0].decode('utf-8') # Command is ASCII - handler_id = int(parts[1].decode('utf-8')) # Handler ID is ASCII + command = parts[0].decode('utf-8') + handler_id = int(parts[1].decode('utf-8')) payload = parts[2] if len(parts) > 2 else b"" if command == "CONNECT": - addr, port = payload.decode('utf-8').split(":", 1) # CONNECT payload is text + addr, port = payload.decode('utf-8').split(":", 1) port = int(port) print(f"Received CONNECT {handler_id} for {addr}:{port}") sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -67,20 +69,43 @@ def handle_message(self, message: RequestMessage, link_id: int): print(f"Error handling message: {e}") def relay_from_destination(self, handler_id: int, sock: socket.socket, link_id: int): + max_retries = 5 + retry_delay = 1 # seconds try: while True: - data = sock.recv(4096) - if not data: - break + try: + data = sock.recv(4096) + if not data: + print(f"Destination closed for handler {handler_id}") + break + except socket.error as e: + print(f"Socket recv error for handler {handler_id}: {e}") + break # Exit on recv error (destination likely closed) + with self.lock: - if link_id in self.channels: - channel = self.channels[link_id] - response = RequestMessage() - response.data = f"DATA:{handler_id}:".encode() + data + if link_id not in self.channels: + print(f"Channel for link {link_id} gone for handler {handler_id}") + break + channel = self.channels[link_id] + response = RequestMessage() + response.data = f"DATA:{handler_id}:".encode() + data + + retries = 0 + while retries < max_retries: + try: channel.send(response) print(f"Sent {len(data)} bytes back for {handler_id} on link {link_id}") + break + except Exception as e: + retries += 1 + print(f"Channel send failed for handler {handler_id} (retry {retries}/{max_retries}): {e}") + if retries == max_retries: + print(f"Max retries reached for handler {handler_id}, giving up") + return # Exit thread if retries exhausted + time.sleep(retry_delay) + except Exception as e: - print(f"Error relaying from destination for {handler_id}: {e}") + print(f"Unexpected error in relay for handler {handler_id}: {e}") finally: with self.lock: if handler_id in self.connections: @@ -97,10 +122,20 @@ def link_established(self, link): with self.lock: self.channels[link_id] = channel + def announce_loop(self): + while self.running: + self.destination.announce() + print(f"Announced destination {self.destination.hash.hex()}") + time.sleep(self.announce_interval) + def run(self): print(f"Destination hash: {self.destination.hash.hex()}") self.destination.set_link_established_callback(self.link_established) self.destination.accepts_links(True) + self.running = True + + threading.Thread(target=self.announce_loop, daemon=True).start() + self.destination.announce() print("Counterpart running. Press Ctrl+C to exit.") sys.stdout.flush() @@ -109,6 +144,7 @@ def run(self): threading.Event().wait() except KeyboardInterrupt: print("Shutting down...") + self.running = False self.reticulum.exit_handler() sys.stdout.flush() sys.exit(0) \ No newline at end of file diff --git a/rnsh/socksext/socksproxy.py b/rnsh/socksext/socksproxy.py index b9771f6..f673431 100644 --- a/rnsh/socksext/socksproxy.py +++ b/rnsh/socksext/socksproxy.py @@ -159,15 +159,7 @@ def __init__(self, destination_hash: bytes, pool_size: int = 1, configdir: str = self.responses = {} print("Initializing Reticulum network...") self.identity = self.load_or_create_identity() - self.target_identity = RNS.Identity.recall(destination_hash) - if not self.target_identity: - print(f"Waiting for identity of {destination_hash.hex()}...") - timeout = time.time() + 10 - while not self.target_identity and time.time() < timeout: - self.target_identity = RNS.Identity.recall(destination_hash) - time.sleep(1) - if not self.target_identity: - raise RuntimeError(f"Could not recall identity for {destination_hash.hex()}") + self.target_identity = self.wait_for_identity(self.destination_hash) self.target_destination = RNS.Destination( self.target_identity, RNS.Destination.OUT, @@ -186,6 +178,23 @@ def load_or_create_identity(self): print("Created and saved proxy identity") return identity + def wait_for_identity(self, destination_hash: bytes, timeout: int = 60*30): + """Wait for the identity to be recalled, with path request""" + target_identity = RNS.Identity.recall(destination_hash) + if not target_identity: + print(f"Waiting for identity of {destination_hash.hex()}...") + RNS.Transport.request_path(destination_hash) # Force path discovery + start_time = time.time() + while not target_identity and time.time() - start_time < timeout: + target_identity = RNS.Identity.recall(destination_hash) + if not target_identity: + time.sleep(1) + print(f"Still waiting... Elapsed: {int(time.time() - start_time)}s") + if not target_identity: + raise RuntimeError(f"Could not recall identity for {destination_hash.hex()} after {timeout}s") + print(f"Identity recalled for {destination_hash.hex()}") + return target_identity + def start(self): self.running = True threading.Thread(target=self.maintain_pool, daemon=True).start() @@ -301,14 +310,13 @@ def send_data(self, handler_id: int, data: bytes): def handle_channel_message(self, message, link_id): try: - # Parse bytes directly, no full decode data = message.data parts = data.split(b":", 2) if len(parts) < 2: print(f"Invalid message format on link {link_id}") return - command = parts[0].decode('utf-8') # Command is ASCII - handler_id = int(parts[1].decode('utf-8')) # Handler ID is ASCII + command = parts[0].decode('utf-8') + handler_id = int(parts[1].decode('utf-8')) payload = parts[2] if len(parts) > 2 else b"" with self.lock: