diff --git a/rust/src/http_client.rs b/rust/src/http_client.rs index e67dae169f6..ca4bf1590b8 100644 --- a/rust/src/http_client.rs +++ b/rust/src/http_client.rs @@ -137,7 +137,7 @@ fn get_runtime<'a>(reactor: &Bound<'a, PyAny>) -> PyResult = OnceCell::new(); /// Access to the `twisted.internet.defer` module. -fn defer(py: Python<'_>) -> PyResult<&Bound> { +fn defer(py: Python<'_>) -> PyResult<&Bound<'_, PyAny>> { Ok(DEFER .get_or_try_init(|| py.import("twisted.internet.defer").map(Into::into))? .bind(py)) diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 6522148fa15..46e22a5cd6f 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -11,6 +11,7 @@ pub mod http; pub mod http_client; pub mod identifier; pub mod matrix_const; +pub mod msc4388_rendezvous; pub mod push; pub mod rendezvous; pub mod segmenter; @@ -54,6 +55,7 @@ fn synapse_rust(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { events::register_module(py, m)?; http_client::register_module(py, m)?; rendezvous::register_module(py, m)?; + msc4388_rendezvous::register_module(py, m)?; segmenter::register_module(py, m)?; Ok(()) diff --git a/rust/src/msc4388_rendezvous/mod.rs b/rust/src/msc4388_rendezvous/mod.rs new file mode 100644 index 00000000000..02204f51d23 --- /dev/null +++ b/rust/src/msc4388_rendezvous/mod.rs @@ -0,0 +1,381 @@ +/* + * This file is licensed under the Affero General Public License (AGPL) version 3. + * + * Copyright (C) 2025 Element Creations, Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * See the GNU Affero General Public License for more details: + * . + */ + +use std::{ + collections::BTreeMap, + time::{Duration, SystemTime}, +}; + +use bytes::Bytes; +use headers::{ + AccessControlAllowHeaders, AccessControlAllowMethods, AccessControlAllowOrigin, HeaderMapExt, +}; +use http::header::HeaderName; +use http::{header, HeaderMap, Method, Response, StatusCode}; +use pyo3::{ + pyclass, pymethods, + types::{PyAnyMethods, PyModule, PyModuleMethods}, + Bound, IntoPyObject, Py, PyAny, PyObject, PyResult, Python, +}; +use ulid::Ulid; + +use self::session::Session; +use crate::{ + errors::{NotFoundError, SynapseError}, + http::{http_request_from_twisted, http_response_to_twisted}, + UnwrapInfallible, +}; + +mod session; + +// Annoyingly we need to set the normal CORS headers on every response as the Python layer doesn't do it for us. +// List is taken from https://spec.matrix.org/v1.16/client-server-api/#web-browser-clients +fn prepare_headers(headers: &mut HeaderMap) { + headers.typed_insert(AccessControlAllowOrigin::ANY); + headers.typed_insert(AccessControlAllowMethods::from_iter([ + Method::POST, + Method::GET, + Method::PUT, + Method::DELETE, + Method::OPTIONS, + ])); + headers.typed_insert(AccessControlAllowHeaders::from_iter([ + HeaderName::from_static("x-requested-with"), + header::CONTENT_TYPE, + header::AUTHORIZATION, + ])); +} + +#[pyclass] +struct MSC4388RendezvousHandler { + clock: PyObject, + sessions: BTreeMap, + capacity: usize, + max_content_length: u64, + ttl: Duration, +} + +impl MSC4388RendezvousHandler { + /// Check the length of the data parameter and throw error if invalid. + fn check_data_length(&self, data: &str) -> PyResult<()> { + let data_length = data.len() as u64; + if data_length > self.max_content_length { + let mut headers = HeaderMap::new(); + prepare_headers(&mut headers); + + return Err(SynapseError::new( + StatusCode::PAYLOAD_TOO_LARGE, + "Payload too large".to_owned(), + "M_TOO_LARGE", + None, + Some(headers), + )); + } + Ok(()) + } + + /// Evict expired sessions and remove the oldest sessions until we're under the capacity. + fn evict(&mut self, now: SystemTime) { + // First remove all the entries which expired + self.sessions.retain(|_, session| !session.expired(now)); + + // Then we remove the oldest entries until we're under the limit + while self.sessions.len() > self.capacity { + self.sessions.pop_first(); + } + } +} + +#[pymethods] +impl MSC4388RendezvousHandler { + #[new] + #[pyo3(signature = (homeserver, /, capacity=100, max_content_length=4*1024, eviction_interval=60*1000, ttl=2*60*1000))] + fn new( + py: Python<'_>, + homeserver: &Bound<'_, PyAny>, + capacity: usize, + max_content_length: u64, + eviction_interval: u64, + ttl: u64, + ) -> PyResult> { + let clock = homeserver + .call_method0("get_clock")? + .into_pyobject(py) + .unwrap_infallible() + .unbind(); + + // Construct a Python object so that we can get a reference to the + // evict method and schedule it to run. + let self_ = Py::new( + py, + Self { + clock, + sessions: BTreeMap::new(), + capacity, + max_content_length, + ttl: Duration::from_millis(ttl), + }, + )?; + + let evict = self_.getattr(py, "_evict")?; + homeserver.call_method0("get_clock")?.call_method( + "looping_call", + (evict, eviction_interval), + None, + )?; + + Ok(self_) + } + + fn _evict(&mut self, py: Python<'_>) -> PyResult<()> { + let clock = self.clock.bind(py); + let now: u64 = clock.call_method0("time_msec")?.extract()?; + let now = SystemTime::UNIX_EPOCH + Duration::from_millis(now); + self.evict(now); + + Ok(()) + } + + fn handle_post(&mut self, py: Python<'_>, twisted_request: &Bound<'_, PyAny>) -> PyResult<()> { + let request = http_request_from_twisted(twisted_request)?; + + let clock = self.clock.bind(py); + let now: u64 = clock.call_method0("time_msec")?.extract()?; + let now = SystemTime::UNIX_EPOCH + Duration::from_millis(now); + + // We trigger an immediate eviction if we're at 2x the capacity + if self.sessions.len() >= self.capacity * 2 { + self.evict(now); + } + + // Generate a new ULID for the session from the current time. + let id = Ulid::from_datetime(now); + + // parse JSON body out of request + let json: serde_json::Value = + serde_json::from_slice(&request.into_body()).map_err(|_| { + let mut headers = HeaderMap::new(); + prepare_headers(&mut headers); + + SynapseError::new( + StatusCode::BAD_REQUEST, + "Invalid JSON in request body".to_owned(), + "M_INVALID_PARAM", + None, + Some(headers), + ) + })?; + + let data: String = json["data"].as_str().map(|s| s.to_owned()).ok_or_else(|| { + let mut headers = HeaderMap::new(); + prepare_headers(&mut headers); + + SynapseError::new( + StatusCode::BAD_REQUEST, + "Missing 'data' field in JSON body".to_owned(), + "M_INVALID_PARAM", + None, + Some(headers), + ) + })?; + + self.check_data_length(&data)?; + + let session = Session::new(id, data, now, self.ttl); + + let response_body = serde_json::to_string(&session.post_response()).map_err(|_| { + let mut headers = HeaderMap::new(); + prepare_headers(&mut headers); + + SynapseError::new( + StatusCode::INTERNAL_SERVER_ERROR, + "Failed to serialize response".to_owned(), + "M_UNKNOWN", + None, + Some(headers), + ) + })?; + let mut response = Response::new(response_body.as_bytes()); + *response.status_mut() = StatusCode::OK; + let headers = response.headers_mut(); + prepare_headers(headers); + http_response_to_twisted(twisted_request, response)?; + + self.sessions.insert(id, session); + + Ok(()) + } + + fn handle_get( + &mut self, + py: Python<'_>, + twisted_request: &Bound<'_, PyAny>, + id: &str, + ) -> PyResult<()> { + let now: u64 = self.clock.call_method0(py, "time_msec")?.extract(py)?; + let now = SystemTime::UNIX_EPOCH + Duration::from_millis(now); + + let id: Ulid = id.parse().map_err(|_| NotFoundError::new())?; + let session = self + .sessions + .get(&id) + .filter(|s| !s.expired(now)) + .ok_or_else(NotFoundError::new)?; + + let response_body = serde_json::to_string(&session.get_response()).map_err(|_| { + let mut headers = HeaderMap::new(); + prepare_headers(&mut headers); + + SynapseError::new( + StatusCode::INTERNAL_SERVER_ERROR, + "Failed to serialize response".to_owned(), + "M_UNKNOWN", + None, + Some(headers), + ) + })?; + let mut response = Response::new(response_body.as_bytes()); + *response.status_mut() = StatusCode::OK; + prepare_headers(response.headers_mut()); + http_response_to_twisted(twisted_request, response)?; + + Ok(()) + } + + fn handle_put( + &mut self, + py: Python<'_>, + twisted_request: &Bound<'_, PyAny>, + id: &str, + ) -> PyResult<()> { + let request = http_request_from_twisted(twisted_request)?; + + // parse JSON body out of request + let json: serde_json::Value = + serde_json::from_slice(&request.into_body()).map_err(|_| { + let mut headers = HeaderMap::new(); + prepare_headers(&mut headers); + + SynapseError::new( + StatusCode::BAD_REQUEST, + "Invalid JSON in request body".to_owned(), + "M_INVALID_PARAM", + None, + Some(headers), + ) + })?; + + let sequence_token: String = json["sequence_token"] + .as_str() + .map(|s| s.to_owned()) + .ok_or_else(|| { + let mut headers = HeaderMap::new(); + prepare_headers(&mut headers); + + SynapseError::new( + StatusCode::BAD_REQUEST, + "Missing 'sequence_token' field in JSON body".to_owned(), + "M_INVALID_PARAM", + None, + Some(headers), + ) + })?; + + let data: String = json["data"].as_str().map(|s| s.to_owned()).ok_or_else(|| { + let mut headers = HeaderMap::new(); + prepare_headers(&mut headers); + + SynapseError::new( + StatusCode::BAD_REQUEST, + "Missing 'data' field in JSON body".to_owned(), + "M_INVALID_PARAM", + None, + Some(headers), + ) + })?; + + self.check_data_length(&data)?; + + let now: u64 = self.clock.call_method0(py, "time_msec")?.extract(py)?; + let now = SystemTime::UNIX_EPOCH + Duration::from_millis(now); + + let id: Ulid = id.parse().map_err(|_| NotFoundError::new())?; + let session = self + .sessions + .get_mut(&id) + .filter(|s| !s.expired(now)) + .ok_or_else(NotFoundError::new)?; + + if !session.sequence_token().eq(&sequence_token) { + let mut headers = HeaderMap::new(); + prepare_headers(&mut headers); + + return Err(SynapseError::new( + StatusCode::CONFLICT, + "sequence_token does not match".to_owned(), + "IO_ELEMENT_MSC4388_CONCURRENT_WRITE", + None, + Some(headers), + )); + } + + session.update(data, now); + + let response_body = serde_json::to_string(&session.put_response()).map_err(|_| { + SynapseError::new( + StatusCode::INTERNAL_SERVER_ERROR, + "Failed to serialize response".to_owned(), + "M_UNKNOWN", + None, + None, + ) + })?; + let mut response = Response::new(response_body.as_bytes()); + *response.status_mut() = StatusCode::OK; + prepare_headers(response.headers_mut()); + http_response_to_twisted(twisted_request, response)?; + + Ok(()) + } + + fn handle_delete(&mut self, twisted_request: &Bound<'_, PyAny>, id: &str) -> PyResult<()> { + let _request = http_request_from_twisted(twisted_request)?; + + let id: Ulid = id.parse().map_err(|_| NotFoundError::new())?; + let _session = self.sessions.remove(&id).ok_or_else(NotFoundError::new)?; + + let mut response = Response::new(Bytes::new()); + *response.status_mut() = StatusCode::OK; + prepare_headers(response.headers_mut()); + http_response_to_twisted(twisted_request, response)?; + + Ok(()) + } +} + +pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { + let child_module = PyModule::new(py, "msc4388_rendezvous")?; + + child_module.add_class::()?; + + m.add_submodule(&child_module)?; + + // We need to manually add the module to sys.modules to make `from + // synapse.synapse_rust import rendezvous` work. + py.import("sys")? + .getattr("modules")? + .set_item("synapse.synapse_rust.msc4388_rendezvous", child_module)?; + + Ok(()) +} diff --git a/rust/src/msc4388_rendezvous/session.rs b/rust/src/msc4388_rendezvous/session.rs new file mode 100644 index 00000000000..87fcdad663e --- /dev/null +++ b/rust/src/msc4388_rendezvous/session.rs @@ -0,0 +1,109 @@ +/* + * This file is licensed under the Affero General Public License (AGPL) version 3. + * + * Copyright (C) 2025 Element Creations, Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * See the GNU Affero General Public License for more details: + * . + */ + +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _}; +use serde::Serialize; +use sha2::{Digest, Sha256}; +use ulid::Ulid; + +/// A single session, containing data, metadata, and expiry information. +pub struct Session { + id: Ulid, + hash: [u8; 32], + data: String, + last_modified: SystemTime, + expires: SystemTime, +} + +#[derive(Serialize)] +pub struct PostResponse { + id: String, + sequence_token: String, + expires_ts: u64, +} + +#[derive(Serialize)] +pub struct GetResponse { + data: String, + sequence_token: String, + expires_ts: u64, +} + +#[derive(Serialize)] +pub struct PutResponse { + sequence_token: String, +} + +impl Session { + /// Create a new session with the given data and time-to-live. + pub fn new(id: Ulid, data: String, now: SystemTime, ttl: Duration) -> Self { + let hash = Sha256::digest(&data).into(); + Self { + id, + hash, + data, + expires: now + ttl, + last_modified: now, + } + } + + /// Returns true if the session has expired at the given time. + pub fn expired(&self, now: SystemTime) -> bool { + self.expires <= now + } + + /// Update the session with new data and last modified time. + pub fn update(&mut self, data: String, now: SystemTime) { + self.hash = Sha256::digest(&data).into(); + self.data = data; + self.last_modified = now; + } + + /// The sequence token for the session. + pub fn sequence_token(&self) -> String { + URL_SAFE_NO_PAD.encode(self.hash) + } + + pub fn get_response(&self) -> GetResponse { + GetResponse { + data: self.data.clone(), + sequence_token: self.sequence_token(), + expires_ts: self + .expires + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_millis() as u64, + } + } + + pub fn post_response(&self) -> PostResponse { + PostResponse { + id: self.id.to_string(), + sequence_token: self.sequence_token(), + expires_ts: self + .expires + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_millis() as u64, + } + } + + pub fn put_response(&self) -> PutResponse { + PutResponse { + sequence_token: self.sequence_token(), + } + } +} diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index f82e8572f22..b266bb174f0 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -523,7 +523,7 @@ def read_config( "msc4069_profile_inhibit_propagation", False ) - # MSC4108: Mechanism to allow OIDC sign in and E2EE set up via QR code + # MSC4108: Mechanism to allow OIDC sign in and E2EE set up via QR code - 2024 version: self.msc4108_enabled = experimental.get("msc4108_enabled", False) self.msc4108_delegation_endpoint: Optional[str] = experimental.get( @@ -548,6 +548,25 @@ def read_config( ("experimental", "msc4108_delegation_endpoint"), ) + # MSC4388: Secure out-of-band channel for sign in with QR: + msc4388_mode = experimental.get("msc4388_mode", "off") + + if ["off", "public", "authenticated"].count(msc4388_mode) != 1: + raise ConfigError( + "msc4388_mode must be one of 'off', 'public' or 'authenticated'", + ("experimental", "msc4388_mode"), + ) + self.msc4388_enabled: bool = msc4388_mode != "off" + self.msc4388_requires_authentication: bool = msc4388_mode == "authenticated" + + if self.msc4388_enabled and not ( + config.get("matrix_authentication_service") or {} + ).get("enabled", False): + raise ConfigError( + "MSC4388 requires matrix_authentication_service to be enabled", + ("experimental", "msc4388_enabled"), + ) + # MSC4133: Custom profile fields self.msc4133_enabled: bool = experimental.get("msc4133_enabled", False) diff --git a/synapse/rest/client/rendezvous.py b/synapse/rest/client/rendezvous.py index a1808847f0a..388942711df 100644 --- a/synapse/rest/client/rendezvous.py +++ b/synapse/rest/client/rendezvous.py @@ -68,9 +68,55 @@ def on_POST(self, request: SynapseRequest) -> None: self._handler.handle_post(request) +class MSC4388CreateRendezvousServlet(RestServlet): + PATTERNS = client_patterns( + "/io.element.msc4388/rendezvous$", releases=[], v1=False, unstable=True + ) + + def __init__(self, hs: "HomeServer") -> None: + super().__init__() + self._handler = hs.get_msc4388_rendezvous_handler() + self.auth = hs.get_auth() + self.require_authentication = ( + hs.config.experimental.msc4388_requires_authentication + ) + + async def on_POST(self, request: SynapseRequest) -> None: + if self.require_authentication: + # This will raise if the user is not authenticated + await self.auth.get_user_by_req(request) + self._handler.handle_post(request) + + +class MSC4388UpdateRendezvousServlet(RestServlet): + PATTERNS = client_patterns( + "/io.element.msc4388/rendezvous/(?P[^/]+)$", + releases=[], + v1=False, + unstable=True, + ) + + def __init__(self, hs: "HomeServer") -> None: + super().__init__() + self._handler = hs.get_msc4388_rendezvous_handler() + + def on_GET(self, request: SynapseRequest, rendezvous_id: str) -> None: + self._handler.handle_get(request, rendezvous_id) + + def on_PUT(self, request: SynapseRequest, rendezvous_id: str) -> None: + self._handler.handle_put(request, rendezvous_id) + + def on_DELETE(self, request: SynapseRequest, rendezvous_id: str) -> None: + self._handler.handle_delete(request, rendezvous_id) + + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: if hs.config.experimental.msc4108_enabled: MSC4108RendezvousServlet(hs).register(http_server) if hs.config.experimental.msc4108_delegation_endpoint is not None: MSC4108DelegationRendezvousServlet(hs).register(http_server) + + if hs.config.experimental.msc4388_enabled: + MSC4388CreateRendezvousServlet(hs).register(http_server) + MSC4388UpdateRendezvousServlet(hs).register(http_server) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index dee2cdb637b..db00a3ecb33 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -161,7 +161,7 @@ async def on_GET(self, request: SynapseRequest) -> tuple[int, JsonDict]: "org.matrix.msc4069": self.config.experimental.msc4069_profile_inhibit_propagation, # Allows clients to handle push for encrypted events. "org.matrix.msc4028": self.config.experimental.msc4028_push_encrypted_events, - # MSC4108: Mechanism to allow OIDC sign in and E2EE set up via QR code + # MSC4108: Mechanism to allow OIDC sign in and E2EE set up via QR code - 2024 version "org.matrix.msc4108": ( self.config.experimental.msc4108_enabled or ( @@ -169,6 +169,8 @@ async def on_GET(self, request: SynapseRequest) -> tuple[int, JsonDict]: is not None ) ), + # MSC4388: Secure out-of-band channel for sign in with QR + "io.element.msc4388": (self.config.experimental.msc4388_enabled), # MSC4140: Delayed events "org.matrix.msc4140": bool(self.config.server.max_event_delay_ms), # Simplified sliding sync diff --git a/synapse/server.py b/synapse/server.py index 2c252ce86fd..c174c767dd0 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -170,6 +170,7 @@ from synapse.storage import Databases from synapse.storage.controllers import StorageControllers from synapse.streams.events import EventSources +from synapse.synapse_rust.msc4388_rendezvous import MSC4388RendezvousHandler from synapse.synapse_rust.rendezvous import RendezvousHandler from synapse.types import DomainSpecificString, ISynapseReactor from synapse.util import SYNAPSE_VERSION @@ -1156,6 +1157,10 @@ def get_room_forgetter_handler(self) -> RoomForgetterHandler: def get_rendezvous_handler(self) -> RendezvousHandler: return RendezvousHandler(self) + @cache_in_self + def get_msc4388_rendezvous_handler(self) -> MSC4388RendezvousHandler: + return MSC4388RendezvousHandler(self) + @cache_in_self def get_outbound_redis_connection(self) -> "ConnectionHandler": """ diff --git a/synapse/synapse_rust/msc4388_rendezvous.pyi b/synapse/synapse_rust/msc4388_rendezvous.pyi new file mode 100644 index 00000000000..a22cf3a017d --- /dev/null +++ b/synapse/synapse_rust/msc4388_rendezvous.pyi @@ -0,0 +1,30 @@ +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2025 Element Creations, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . + +from twisted.web.iweb import IRequest + +from synapse.server import HomeServer + +class MSC4388RendezvousHandler: + def __init__( + self, + homeserver: HomeServer, + /, + capacity: int = 100, # This should be configurable + max_content_length: int = 4 * 1024, # MSC4388 specifies maximum of 4KB + eviction_interval: int = 60 * 1000, + ttl: int = 2 * 60 * 1000, # MSC4388 specifies minimum of 120 seconds + ) -> None: ... + def handle_post(self, request: IRequest) -> None: ... + def handle_get(self, request: IRequest, session_id: str) -> None: ... + def handle_put(self, request: IRequest, session_id: str) -> None: ... + def handle_delete(self, request: IRequest, session_id: str) -> None: ... diff --git a/tests/rest/client/test_msc4388_rendezvous.py b/tests/rest/client/test_msc4388_rendezvous.py new file mode 100644 index 00000000000..5f996018f7a --- /dev/null +++ b/tests/rest/client/test_msc4388_rendezvous.py @@ -0,0 +1,632 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2025 Element Creations, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . + + +import json +import urllib.parse +from typing import Any, Mapping +from unittest.mock import Mock + +from twisted.internet.testing import MemoryReactor + +from synapse.rest import admin +from synapse.rest.client import login, rendezvous +from synapse.server import HomeServer +from synapse.types import UserID +from synapse.util.clock import Clock + +from tests import unittest +from tests.unittest import checked_cast, override_config +from tests.utils import HAS_AUTHLIB + +rz_endpoint = "/_matrix/client/unstable/io.element.msc4388/rendezvous" + + +class RendezvousServletTestCase(unittest.HomeserverTestCase): + """ + Test the experimental MSC4388 rendezvous endpoint. + """ + + servlets = [ + admin.register_servlets, + login.register_servlets, + rendezvous.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + self.hs = self.setup_test_homeserver() + return self.hs + + def setup_mock_oauth(self) -> None: + """ + This isn't a very elegant way to mock the OAuth API, but it works for our purposes. + """ + + # Import this here so that we've checked that authlib is available. + from synapse.api.auth.mas import MasDelegatedAuth + + self.auth = checked_cast(MasDelegatedAuth, self.hs.get_auth()) + + self._rust_client = Mock(spec=["post"]) + self._rust_client.post = self._mock_oauth_response + self.auth._rust_http_client = self._rust_client + + async def _mock_oauth_response( + self, + url: str, + response_limit: int, + headers: Mapping[str, str], + request_body: str, + ) -> bytes: + # get the token from the request body which is form encoded + parsed_body = urllib.parse.parse_qs(request_body) + token = parsed_body.get("token", [""])[0] + + if not token.startswith("mock_token_"): + return bytes(json.dumps({"active": False}).encode("utf-8")) + token = token.replace("mock_token_", "") + + username, device_id = token.split("_", 1) + user_id = UserID(username, self.hs.hostname) + store = self.hs.get_datastores().main + + # Check th user exists in the store + user_info = await store.get_user_by_id(user_id=user_id.to_string()) + if user_info is None: + return bytes(json.dumps({"active": False}).encode("utf-8")) + + # Check the device exists in the store + device = await store.get_device( + user_id=user_id.to_string(), device_id=device_id + ) + if device is None: + return bytes(json.dumps({"active": False}).encode("utf-8")) + + return bytes( + json.dumps( + { + "active": True, + "scope": "urn:matrix:client:device:" + + device_id + + " urn:matrix:client:api:*", + "username": username, + } + ).encode("utf-8") + ) + + def register_oauth_user(self, username: str, device_id: str) -> str: + # Provision the user and the device + store = self.hs.get_datastores().main + user_id = UserID(username, self.hs.hostname) + + self.get_success(store.register_user(user_id=user_id.to_string())) + self.get_success( + store.store_device( + user_id=user_id.to_string(), + device_id=device_id, + initial_device_display_name=None, + ) + ) + # Generate an access token for the device + return "mock_token_" + username + "_" + device_id + + def test_disabled(self) -> None: + channel = self.make_request("POST", rz_endpoint, {}, access_token=None) + self.assertEqual(channel.code, 404) + + @override_config( + { + "experimental_features": { + "msc4388_mode": "off", + }, + } + ) + def test_off(self) -> None: + channel = self.make_request("POST", rz_endpoint, {}, access_token=None) + self.assertEqual(channel.code, 404) + + @unittest.skip_unless(HAS_AUTHLIB, "requires authlib") + @override_config( + { + "disable_registration": True, + "matrix_authentication_service": { + "enabled": True, + "secret": "secret_value", + "endpoint": "https://issuer", + }, + "experimental_features": { + "msc4388_mode": "public", + }, + } + ) + def test_rendezvous_public(self) -> None: + """ + Test the MSC4108 rendezvous endpoint, including: + - Creating a session + - Getting the data back + - Updating the data + - Deleting the data + - Sequence token handling + """ + # We can post arbitrary data to the endpoint + channel = self.make_request( + "POST", + rz_endpoint, + {"data": "foo=bar"}, + access_token=None, + ) + self.assertEqual(channel.code, 200) + rendezvous_id = channel.json_body["id"] + sequence_token = channel.json_body["sequence_token"] + expires_ts = channel.json_body["expires_ts"] + self.assertGreater(expires_ts, self.hs.get_clock().time_msec()) + + session_endpoint = rz_endpoint + f"/{rendezvous_id}" + + # We can get the data back + channel = self.make_request( + "GET", + session_endpoint, + access_token=None, + ) + + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["data"], "foo=bar") + self.assertEqual(channel.json_body["sequence_token"], sequence_token) + self.assertEqual(channel.json_body["expires_ts"], expires_ts) + + # We can update the data + channel = self.make_request( + "PUT", + session_endpoint, + {"sequence_token": sequence_token, "data": "foo=baz"}, + access_token=None, + ) + + self.assertEqual(channel.code, 200) + old_sequence_token = sequence_token + new_sequence_token = channel.json_body["sequence_token"] + + # If we try to update it again with the old etag, it should fail + channel = self.make_request( + "PUT", + session_endpoint, + {"sequence_token": old_sequence_token, "data": "bar=baz"}, + access_token=None, + ) + + self.assertEqual(channel.code, 409) + self.assertEqual( + channel.json_body["errcode"], "IO_ELEMENT_MSC4388_CONCURRENT_WRITE" + ) + + # We should get the updated data + channel = self.make_request( + "GET", + session_endpoint, + access_token=None, + ) + + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["data"], "foo=baz") + self.assertEqual(channel.json_body["sequence_token"], new_sequence_token) + self.assertEqual(channel.json_body["expires_ts"], expires_ts) + + # We can delete the data + channel = self.make_request( + "DELETE", + session_endpoint, + access_token=None, + ) + + self.assertEqual(channel.code, 200) + + # If we try to get the data again, it should fail + channel = self.make_request( + "GET", + session_endpoint, + access_token=None, + ) + + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") + + @unittest.skip_unless(HAS_AUTHLIB, "requires authlib") + @override_config( + { + "disable_registration": True, + "matrix_authentication_service": { + "enabled": True, + "secret": "secret_value", + "endpoint": "https://issuer", + }, + "experimental_features": { + "msc4388_mode": "authenticated", + }, + } + ) + def test_rendezvous_requires_authentication(self) -> None: + """ + Test the MSC4108 rendezvous endpoint when configured with the mode authenticated, including: + - Creating a session + - Getting the data back + - Updating the data + - Deleting the data + - Sequence token handling + """ + self.setup_mock_oauth() + alice_token = self.register_oauth_user("alice", "device1") + + # This should fail without authentication: + channel = self.make_request( + "POST", + rz_endpoint, + {"data": "foo=bar"}, + access_token=None, + ) + self.assertEqual(channel.code, 401) + + # This should work as we are now authenticated + channel = self.make_request( + "POST", + rz_endpoint, + {"data": "foo=bar"}, + access_token=alice_token, + ) + self.assertEqual(channel.code, 200) + rendezvous_id = channel.json_body["id"] + sequence_token = channel.json_body["sequence_token"] + expires_ts = channel.json_body["expires_ts"] + self.assertGreater(expires_ts, self.hs.get_clock().time_msec()) + + session_endpoint = rz_endpoint + f"/{rendezvous_id}" + + # We can get the data back without authentication + channel = self.make_request( + "GET", + session_endpoint, + access_token=None, + ) + + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["data"], "foo=bar") + self.assertEqual(channel.json_body["sequence_token"], sequence_token) + self.assertEqual(channel.json_body["expires_ts"], expires_ts) + + # We can update the data without authentication + channel = self.make_request( + "PUT", + session_endpoint, + {"sequence_token": sequence_token, "data": "foo=baz"}, + access_token=None, + ) + + self.assertEqual(channel.code, 200) + new_sequence_token = channel.json_body["sequence_token"] + + # We should get the updated data without authentication + channel = self.make_request( + "GET", + session_endpoint, + access_token=None, + ) + + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["data"], "foo=baz") + self.assertEqual(channel.json_body["sequence_token"], new_sequence_token) + self.assertEqual(channel.json_body["expires_ts"], expires_ts) + + # We can delete the data without authentication + channel = self.make_request( + "DELETE", + session_endpoint, + access_token=None, + ) + + self.assertEqual(channel.code, 200) + + # If we try to get the data again, it should fail + channel = self.make_request( + "GET", + session_endpoint, + access_token=None, + ) + + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") + + @unittest.skip_unless(HAS_AUTHLIB, "requires authlib") + @override_config( + { + "disable_registration": True, + "matrix_authentication_service": { + "enabled": True, + "secret": "secret_value", + "endpoint": "https://issuer", + }, + "experimental_features": { + "msc4388_mode": "public", + }, + } + ) + def test_expiration(self) -> None: + """ + Test that entries are evicted after a TTL. + """ + # Start a new session + channel = self.make_request( + "POST", + rz_endpoint, + {"data": "foo=bar"}, + access_token=None, + ) + self.assertEqual(channel.code, 200) + session_endpoint = rz_endpoint + "/" + channel.json_body["id"] + + # Sanity check that we can get the data back + channel = self.make_request( + "GET", + session_endpoint, + access_token=None, + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["data"], "foo=bar") + + # Advance the clock, TTL of entries is 2 minutes + self.reactor.advance(120) + + # Get the data back, it should be gone + channel = self.make_request( + "GET", + session_endpoint, + access_token=None, + ) + self.assertEqual(channel.code, 404) + + @unittest.skip_unless(HAS_AUTHLIB, "requires authlib") + @override_config( + { + "disable_registration": True, + "matrix_authentication_service": { + "enabled": True, + "secret": "secret_value", + "endpoint": "https://issuer", + }, + "experimental_features": { + "msc4388_mode": "public", + }, + } + ) + def test_capacity(self) -> None: + """ + Test that a capacity limit is enforced on the rendezvous sessions, as old + entries are evicted at an interval when the limit is reached. + """ + # Start a new session + channel = self.make_request( + "POST", + rz_endpoint, + {"data": "foo=bar"}, + access_token=None, + ) + self.assertEqual(channel.code, 200) + session_endpoint = rz_endpoint + "/" + channel.json_body["id"] + + # Sanity check that we can get the data back + channel = self.make_request( + "GET", + session_endpoint, + access_token=None, + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["data"], "foo=bar") + + # We advance the clock to make sure that this entry is the "lowest" in the session list + self.reactor.advance(1) + + # Start a lot of new sessions + for _ in range(100): + channel = self.make_request( + "POST", + rz_endpoint, + {"data": "foo=bar"}, + access_token=None, + ) + self.assertEqual(channel.code, 200) + + # Get the data back, it should still be there, as the eviction hasn't run yet + channel = self.make_request( + "GET", + session_endpoint, + access_token=None, + ) + + self.assertEqual(channel.code, 200) + + # Advance the clock, as it will trigger the eviction + self.reactor.advance(59) + + # Get the data back, it should be gone + channel = self.make_request( + "GET", + session_endpoint, + access_token=None, + ) + + self.assertEqual(channel.code, 404) + + @unittest.skip_unless(HAS_AUTHLIB, "requires authlib") + @override_config( + { + "disable_registration": True, + "matrix_authentication_service": { + "enabled": True, + "secret": "secret_value", + "endpoint": "https://issuer", + }, + "experimental_features": { + "msc4388_mode": "public", + }, + } + ) + def test_hard_capacity(self) -> None: + """ + Test that a hard capacity limit is enforced on the rendezvous sessions, as old + entries are evicted immediately when the limit is reached. + """ + # Start a new session + channel = self.make_request( + "POST", + rz_endpoint, + {"data": "foo=bar"}, + access_token=None, + ) + self.assertEqual(channel.code, 200) + session_endpoint = rz_endpoint + "/" + channel.json_body["id"] + # We advance the clock to make sure that this entry is the "lowest" in the session list + self.reactor.advance(1) + + # Sanity check that we can get the data back + channel = self.make_request( + "GET", + session_endpoint, + access_token=None, + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["data"], "foo=bar") + + # Start a lot of new sessions + for _ in range(200): + channel = self.make_request( + "POST", + rz_endpoint, + {"data": "foo=bar"}, + access_token=None, + ) + self.assertEqual(channel.code, 200) + + # Get the data back, it should already be gone as we hit the hard limit + channel = self.make_request( + "GET", + session_endpoint, + access_token=None, + ) + + self.assertEqual(channel.code, 404) + + @unittest.skip_unless(HAS_AUTHLIB, "requires authlib") + @override_config( + { + "disable_registration": True, + "matrix_authentication_service": { + "enabled": True, + "secret": "secret_value", + "endpoint": "https://issuer", + }, + "experimental_features": { + "msc4388_mode": "public", + }, + } + ) + def test_data_type(self) -> None: + """ + Test that the data field is restricted to string. + """ + invalid_datas: list[Any] = [123214, ["asd"], {"asd": "asdsad"}, None] + + # We cannot post invalid non-string data field values to the endpoint + for invalid_data in invalid_datas: + channel = self.make_request( + "POST", + rz_endpoint, + {"data": invalid_data}, + access_token=None, + ) + self.assertEqual(channel.code, 400) + self.assertEqual(channel.json_body["errcode"], "M_INVALID_PARAM") + + # Make a valid request + channel = self.make_request( + "POST", + rz_endpoint, + {"data": "test"}, + access_token=None, + ) + self.assertEqual(channel.code, 200) + rendezvous_id = channel.json_body["id"] + sequence_token = channel.json_body["sequence_token"] + + session_endpoint = rz_endpoint + f"/{rendezvous_id}" + + # We can't update the data with invalid data + for invalid_data in invalid_datas: + channel = self.make_request( + "PUT", + session_endpoint, + {"sequence_token": sequence_token, "data": invalid_data}, + access_token=None, + ) + self.assertEqual(channel.code, 400) + self.assertEqual(channel.json_body["errcode"], "M_INVALID_PARAM") + + @unittest.skip_unless(HAS_AUTHLIB, "requires authlib") + @override_config( + { + "disable_registration": True, + "matrix_authentication_service": { + "enabled": True, + "secret": "secret_value", + "endpoint": "https://issuer", + }, + "experimental_features": { + "msc4388_mode": "public", + }, + } + ) + def test_max_length(self) -> None: + """ + Test that the data max length is restricted. + """ + too_long_data = "a" * 5000 # MSC4108 specifies 4KB max length + + channel = self.make_request( + "POST", + rz_endpoint, + {"data": too_long_data}, + access_token=None, + ) + self.assertEqual(channel.code, 413) + self.assertEqual(channel.json_body["errcode"], "M_TOO_LARGE") + + # Make a valid request + channel = self.make_request( + "POST", + rz_endpoint, + {"data": "test"}, + access_token=None, + ) + self.assertEqual(channel.code, 200) + rendezvous_id = channel.json_body["id"] + sequence_token = channel.json_body["sequence_token"] + + session_endpoint = rz_endpoint + f"/{rendezvous_id}" + + # We can't update the data with invalid data + channel = self.make_request( + "PUT", + session_endpoint, + {"sequence_token": sequence_token, "data": too_long_data}, + access_token=None, + ) + self.assertEqual(channel.code, 413) + self.assertEqual(channel.json_body["errcode"], "M_TOO_LARGE")