diff --git a/Sources/Adapters/GCDWebServer/GCDHTTPServer.swift b/Sources/Adapters/GCDWebServer/GCDHTTPServer.swift index d64bfcaa87..e1142c7aca 100644 --- a/Sources/Adapters/GCDWebServer/GCDHTTPServer.swift +++ b/Sources/Adapters/GCDWebServer/GCDHTTPServer.swift @@ -127,7 +127,7 @@ public final class GCDHTTPServer: HTTPServer, Loggable { for request: ReadiumGCDWebServerRequest, completion: @escaping (HTTPServerRequest, HTTPServerResponse, HTTPRequestHandler.OnFailure?) -> Void ) { - let completion = { request, resource, failureHandler in + let dispatchCompletion = { (request: HTTPServerRequest, resource: HTTPServerResponse, failureHandler: HTTPRequestHandler.OnFailure?) in // Escape the queue to avoid deadlocks if something is using the // server in the handler. DispatchQueue.global().async { @@ -166,20 +166,22 @@ public final class GCDHTTPServer: HTTPServer, Loggable { var response = handler.onRequest(request) response.resource = transform(resource: response.resource, request: request, at: endpoint) - completion(request, response, handler.onFailure) + dispatchCompletion(request, response, handler.onFailure) return } log(.warning, "Resource not found for request \(request)") - completion( + dispatchCompletion( HTTPServerRequest(url: url, href: nil), - HTTPServerResponse(error: .errorResponse(HTTPResponse( - request: HTTPRequest(url: url), - url: url, - status: .notFound, - headers: [:], - mediaType: nil, - body: nil + HTTPServerResponse(error: .errorResponse(HTTPFetchResponse( + response: HTTPResponse( + request: HTTPRequest(url: url), + url: url, + status: .notFound, + headers: [:], + mediaType: nil + ), + body: Data() ))), nil ) diff --git a/Sources/LCP/License/License.swift b/Sources/LCP/License/License.swift index e8fab8f935..51fedcc3af 100644 --- a/Sources/LCP/License/License.swift +++ b/Sources/LCP/License/License.swift @@ -223,7 +223,7 @@ extension License: LCPLicense { // done, in case it changed the License. return try await httpClient .fetch(HTTPRequest(url: statusURL, headers: ["Accept": MediaType.lcpStatusDocument.string])) - .map { $0.body ?? Data() } + .map(\.body) .get() } @@ -251,11 +251,11 @@ extension License: LCPLicense { let url = try await makeRenewURL(from: preferredEndDate()) return try await httpClient.fetch(HTTPRequest(url: url, method: .put)) - .map { $0.body ?? Data() } + .map(\.body) .mapError { error -> RenewError in switch error { case let .errorResponse(response): - switch response.status { + switch response.response.status { case .badRequest: return .renewFailed case .forbidden: @@ -299,7 +299,7 @@ extension License: LCPLicense { .mapError { error -> ReturnError in switch error { case let .errorResponse(response): - switch response.status { + switch response.response.status { case .badRequest: return .returnFailed case .forbidden: @@ -311,7 +311,7 @@ extension License: LCPLicense { return .unexpectedServerError(error) } } - .map { $0.body ?? Data() } + .map(\.body) .get() try await validateStatusDocument(data: data) diff --git a/Sources/LCP/License/LicenseValidation.swift b/Sources/LCP/License/LicenseValidation.swift index cc23150d00..2411610ee9 100644 --- a/Sources/LCP/License/LicenseValidation.swift +++ b/Sources/LCP/License/LicenseValidation.swift @@ -299,7 +299,7 @@ extension LicenseValidation { // Short timeout to avoid blocking the License, since the LSD is optional. timeoutInterval: 5 )) - .map { $0.body ?? Data() } + .map(\.body) .get() try await raise(.retrievedStatusData(data)) @@ -316,7 +316,7 @@ extension LicenseValidation { let data = try await httpClient // Short timeout to avoid blocking the License, since it can be updated next time. .fetch(HTTPRequest(url: url, timeoutInterval: 5)) - .map { $0.body ?? Data() } + .map(\.body) .get() try await raise(.retrievedLicenseData(data)) diff --git a/Sources/LCP/Services/CRLService.swift b/Sources/LCP/Services/CRLService.swift index 043a9413ed..aa1766cf40 100644 --- a/Sources/LCP/Services/CRLService.swift +++ b/Sources/LCP/Services/CRLService.swift @@ -53,9 +53,11 @@ final class CRLService { .mapError { _ in LCPError.crlFetching } .get() - guard let body = response.body?.base64EncodedString() else { + guard !response.body.isEmpty else { throw LCPError.crlFetching } + + let body = response.body.base64EncodedString() return "-----BEGIN X509 CRL-----\(body)-----END X509 CRL-----" } diff --git a/Sources/Shared/Toolkit/HTTP/DefaultHTTPClient.swift b/Sources/Shared/Toolkit/HTTP/DefaultHTTPClient.swift index f4ad5cca93..b4be9166df 100644 --- a/Sources/Shared/Toolkit/HTTP/DefaultHTTPClient.swift +++ b/Sources/Shared/Toolkit/HTTP/DefaultHTTPClient.swift @@ -23,17 +23,18 @@ public enum URLAuthenticationChallengeResponse: Sendable { public protocol DefaultHTTPClientDelegate: AnyObject { /// Tells the delegate that the HTTP client will start a new `request`. /// - /// Warning: You MUST call the `completion` handler with the request to start, otherwise the client will hang. - /// /// You can modify the `request`, for example by adding additional HTTP headers or redirecting to a different URL, - /// before calling the `completion` handler with the new request. + /// before returning the new request. + /// + /// - Note: If this method returns a failure, the request is aborted immediately and `httpClient(_:request:didFailWithError:)` + /// is NOT called. func httpClient(_ httpClient: DefaultHTTPClient, willStartRequest request: HTTPRequest) async -> HTTPResult /// Asks the delegate to recover from an `error` received for the given `request`. /// /// This can be used to implement custom authentication flows, for example. /// - /// You can call the `completion` handler with either: + /// You can return either: /// * a new request to start /// * the `error` argument, if you cannot recover from it /// * a new `HTTPError` to provide additional information @@ -51,8 +52,8 @@ public protocol DefaultHTTPClientDelegate: AnyObject { /// You do not need to do anything with this `response`, which the HTTP client will handle. This is merely for /// informational purposes. /// - /// This will be called only if `httpClient(_:recoverRequest:fromError:completion:)` is not implemented, or returns - /// an error. + /// This will be called only if `httpClient(_:recoverRequest:fromError:)` is not implemented, or returns + /// an error. It is also NOT called if `httpClient(_:willStartRequest:)` fails and aborts the request. func httpClient(_ httpClient: DefaultHTTPClient, request: HTTPRequest, didFailWithError error: HTTPError) /// Requests credentials from the delegate in response to an authentication request from the remote server. @@ -113,6 +114,11 @@ public final class DefaultHTTPClient: HTTPClient, Loggable { return "\(appName)/\(appVersion) \(deviceName) \(device.systemName)/\(device.systemVersion) CFNetwork/\(cfNetworkVersion) Darwin/\(darwinVersion)" }() + public weak var delegate: DefaultHTTPClientDelegate? + + private let session: URLSession + private let userAgent: String + /// Creates a `DefaultHTTPClient` with common configuration settings. /// /// - Parameters: @@ -152,12 +158,6 @@ public final class DefaultHTTPClient: HTTPClient, Loggable { self.init(configuration: config, userAgent: userAgent, delegate: delegate) } - public weak var delegate: DefaultHTTPClientDelegate? - - private let tasks: HTTPTaskManager - private let session: URLSession - private let userAgent: String - /// Creates a `DefaultHTTPClient` with a custom configuration. /// /// - Parameters: @@ -169,14 +169,9 @@ public final class DefaultHTTPClient: HTTPClient, Loggable { userAgent: String? = nil, delegate: DefaultHTTPClientDelegate? = nil ) { - let tasks = HTTPTaskManager() - self.userAgent = userAgent ?? DefaultHTTPClient.defaultUserAgent self.delegate = delegate - self.tasks = tasks - // Note that URLSession keeps a strong reference to its delegate, so we - // don't use the DefaultHTTPClient itself as its delegate. - session = URLSession(configuration: configuration, delegate: tasks, delegateQueue: nil) + session = URLSession(configuration: configuration, delegate: nil, delegateQueue: nil) } deinit { @@ -185,51 +180,122 @@ public final class DefaultHTTPClient: HTTPClient, Loggable { public func stream( request: any HTTPRequestConvertible, - consume: @escaping (Data, Double?) -> HTTPResult + onReceiveResponse: ((HTTPResponse) async -> HTTPResult)? = nil, + consume: @Sendable (Data, Double?) -> HTTPResult ) async -> HTTPResult { await request.httpRequest() .asyncFlatMap(willStartRequest) .asyncFlatMap { request in - await startTask(for: request, consume: consume) + let result = await startTask(for: request, onReceiveResponse: onReceiveResponse, consume: consume) .asyncRecover { error in await recover(request, from: error) .asyncFlatMap { newRequest in - await stream(request: newRequest, consume: consume) + await streamOnce(request: newRequest, onReceiveResponse: onReceiveResponse, consume: consume) } } + + if case let .failure(error) = result { + delegate?.httpClient(self, request: request, didFailWithError: error) + } + + return result } } - /// Creates and starts a new task for the `request`, whose cancellable will be exposed through `mediator`. - private func startTask(for request: HTTPRequest, consume: @escaping HTTPTask.Consume) async -> HTTPResult { + private func streamOnce( + request: any HTTPRequestConvertible, + onReceiveResponse: ((HTTPResponse) async -> HTTPResult)?, + consume: @Sendable (Data, Double?) -> HTTPResult + ) async -> HTTPResult { + await request.httpRequest() + .asyncFlatMap { request in + await startTask(for: request, onReceiveResponse: onReceiveResponse, consume: consume) + } + } + + /// Creates and starts an async byte stream for the `request`. + private func startTask( + for request: HTTPRequest, + onReceiveResponse: ((HTTPResponse) async -> HTTPResult)?, + consume: @Sendable (Data, Double?) -> HTTPResult + ) async -> HTTPResult { var request = request if request.userAgent == nil { request.userAgent = userAgent } - let result = await tasks.start( + let taskDelegate = TaskDelegate( request: request, - task: session.dataTask(with: request.urlRequest), - receiveResponse: { [weak self] response in - if let self = self { - self.delegate?.httpClient(self, request: request, didReceiveResponse: response) + clientDelegate: delegate, + client: self + ) + + do { + let task = session.dataTask(with: request.urlRequest) + task.delegate = taskDelegate + + let (stream, response) = try await withTaskCancellationHandler { + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<(AsyncThrowingStream, URLResponse), Error>) in + taskDelegate.responseContinuation = continuation + task.resume() } - }, - receiveChallenge: { [weak self] challenge in - if let self = self, let delegate = self.delegate { - return await delegate.httpClient(self, request: request, didReceive: challenge) - } else { - return .performDefaultHandling + } onCancel: { + task.cancel() + } + + guard let httpURLResponse = response as? HTTPURLResponse, let url = httpURLResponse.url?.httpURL else { + return .failure(.malformedResponse(nil)) + } + + let httpResponse = HTTPResponse(request: request, response: httpURLResponse, url: url) + delegate?.httpClient(self, request: request, didReceiveResponse: httpResponse) + + if !httpResponse.status.isSuccess { + let capacity = min(1024 * 1024, Int(httpResponse.fullContentLength ?? 1024)) + var errorData = Data() + + for try await chunk in stream { + if errorData.count < capacity { + errorData.append(chunk) + } else { + break + } } - }, - consume: consume - ) + errorData = errorData.prefix(capacity) + return .failure(.errorResponse(HTTPFetchResponse(response: httpResponse, body: errorData))) + } - if let delegate = delegate, case let .failure(error) = result { - delegate.httpClient(self, request: request, didFailWithError: error) - } + if request.hasHeader("Range"), !httpResponse.acceptsByteRanges { + log(.error, "Streaming ranges requires the remote HTTP server to support byte range requests: \(url)") + return .failure(.rangeNotSupported) + } + + if let onReceive = onReceiveResponse { + let result = await onReceive(httpResponse) + if case let .failure(error) = result { + return .failure(error) + } + } + + let expectedBytes = httpResponse.fullContentLength + var readBytes: Int64 = httpResponse.contentRangeOffset + + for try await chunk in stream { + readBytes += Int64(chunk.count) + let progress = expectedBytes.map { $0 > 0 ? Double(min(readBytes, $0)) / Double($0) : 1.0 } + if case let .failure(error) = consume(chunk, progress) { + return .failure(error) + } + } - return result + return .success(httpResponse) + + } catch { + if (error is CancellationError) || ((error as? URLError)?.code == .cancelled) { + return .failure(.cancelled) + } + return .failure(.wrap(error) ?? .other(error)) + } } /// Lets the `delegate` customize the `request` if needed, before actually starting it. @@ -241,7 +307,7 @@ public final class DefaultHTTPClient: HTTPClient, Loggable { .flatMap { $0.httpRequest() } } - /// Attempts to recover from a `error` by asking the `delegate` for a new request. + /// Attempts to recover from an `error` by asking the `delegate` for a new request. private func recover(_ request: HTTPRequest, from error: HTTPError) async -> HTTPResult { if let delegate = delegate { return await delegate.httpClient(self, recoverRequest: request, fromError: error) @@ -250,265 +316,104 @@ public final class DefaultHTTPClient: HTTPClient, Loggable { } } - private class HTTPTaskManager: NSObject, URLSessionDataDelegate { - /// On-going tasks. - @Atomic private var tasks: [HTTPTask] = [] - - func start( - request: HTTPRequest, - task sessionTask: URLSessionDataTask, - receiveResponse: @escaping HTTPTask.ReceiveResponse, - receiveChallenge: @escaping HTTPTask.ReceiveChallenge, - consume: @escaping HTTPTask.Consume - ) async -> HTTPResult { - let task = HTTPTask( - request: request, - task: sessionTask, - receiveResponse: receiveResponse, - receiveChallenge: receiveChallenge, - consume: consume - ) - $tasks.write { $0.append(task) } - - let result = await withTaskCancellationHandler { - await withCheckedContinuation { continuation in - task.start(with: continuation) - } - } onCancel: { - task.cancel() - } - - $tasks.write { $0.removeAll { $0.task == sessionTask } } - - return result - } - - private func findTask(for urlTask: URLSessionTask) -> HTTPTask? { - let task = tasks.first { $0.task == urlTask } - if task == nil { - log(.error, "Cannot find on-going HTTP task for \(urlTask)") - } - return task - } - - // MARK: - URLSessionDataDelegate + /// Isolated proxy to pass challenges back to the `DefaultHTTPClientDelegate`. + /// URLSession guarantees its delegate callbacks are serialized, so the mutable `authTask` is safe. + /// Both `urlSession(_:task:didCompleteWithError:)` and `urlSession(_:task:didReceiveChallenge:completionHandler:)` + /// run on the same serial delegate queue. + private final class TaskDelegate: NSObject, URLSessionDataDelegate, @unchecked Sendable { + let request: HTTPRequest + weak var clientDelegate: DefaultHTTPClientDelegate? + weak var client: DefaultHTTPClient? + var authTask: Task? - func urlSession(_ session: URLSession, dataTask: URLSessionDataTask, didReceive response: URLResponse, completionHandler: @escaping (URLSession.ResponseDisposition) -> Void) { - guard let task = findTask(for: dataTask) else { - completionHandler(.cancel) - return - } - task.urlSession(session, didReceive: response, completionHandler: completionHandler) - } - - func urlSession(_ session: URLSession, dataTask: URLSessionDataTask, didReceive data: Data) { - findTask(for: dataTask)?.urlSession(session, didReceive: data) - } - - func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) { - findTask(for: task)?.urlSession(session, didCompleteWithError: error) - } - - func urlSession(_ session: URLSession, task: URLSessionTask, didReceive challenge: URLAuthenticationChallenge, completionHandler: @escaping (URLSession.AuthChallengeDisposition, URLCredential?) -> Void) { - guard let task = findTask(for: task) else { - completionHandler(.performDefaultHandling, nil) - return - } - - task.urlSession(session, didReceive: challenge, completion: completionHandler) - } - } - - /// Represents an on-going HTTP task. - private class HTTPTask: Cancellable, Loggable { - typealias Continuation = CheckedContinuation, Never> - typealias ReceiveResponse = (HTTPResponse) -> Void - typealias ReceiveChallenge = (URLAuthenticationChallenge) async -> URLAuthenticationChallengeResponse - typealias Consume = (Data, Double?) -> HTTPResult - - private let request: HTTPRequest - fileprivate let task: URLSessionTask - private let receiveResponse: ReceiveResponse - private let receiveChallenge: ReceiveChallenge - private let consume: Consume - - /// States the HTTP task can be in. - private var state: State = .initializing - - private enum State { - /// Waiting to start the task. - case initializing - - /// Waiting for the HTTP response. - case start(continuation: Continuation) - - /// We received a success response, the data will be sent to - /// `consume` progressively. - case stream(continuation: Continuation, response: HTTPResponse, readBytes: Int64) - - /// We received an error response, the data will be accumulated in - /// `response.body` if the error is an `HTTPError.errorResponse`, as - /// it could be needed for example when the response is an OPDS - /// Authentication Document. - case failure(continuation: Continuation, error: HTTPError) - - /// The request is terminated. - case finished - - var continuation: Continuation? { - switch self { - case .initializing, .finished: - return nil - case let .start(continuation): - return continuation - case let .stream(continuation, _, _): - return continuation - case let .failure(continuation, _): - return continuation - } - } - } + var streamContinuation: AsyncThrowingStream.Continuation? + var responseContinuation: CheckedContinuation<(AsyncThrowingStream, URLResponse), Error>? init( request: HTTPRequest, - task: URLSessionDataTask, - receiveResponse: @escaping ReceiveResponse, - receiveChallenge: @escaping ReceiveChallenge, - consume: @escaping Consume + clientDelegate: DefaultHTTPClientDelegate?, + client: DefaultHTTPClient ) { self.request = request - self.task = task - self.receiveResponse = receiveResponse - self.receiveChallenge = receiveChallenge - self.consume = consume - } - - deinit { - finish() - } - - func start(with continuation: Continuation) { - log(.info, request) - state = .start(continuation: continuation) - task.resume() + self.clientDelegate = clientDelegate + self.client = client } - func cancel() { - task.cancel() - } - - private func finish() { - switch state { - case let .start(continuation): - continuation.resume(returning: .failure(.cancelled)) - - case let .stream(continuation, response, _): - continuation.resume(returning: .success(response)) - - case let .failure(continuation, error): - var errorDescription = "" - dump(error, to: &errorDescription) - log(.error, "\(request.method) \(request.url) failed with:\n\(errorDescription)") - continuation.resume(returning: .failure(error)) - - case .initializing, .finished: - break + func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) { + authTask?.cancel() + if let responseContinuation = responseContinuation { + self.responseContinuation = nil + if let error = error { + responseContinuation.resume(throwing: error) + } else { + responseContinuation.resume(throwing: URLError(.badServerResponse)) + } + } else { + if let error = error { + streamContinuation?.finish(throwing: error) + } else { + streamContinuation?.finish() + } } - - state = .finished } - func urlSession(_ session: URLSession, didReceive urlResponse: URLResponse, completionHandler: @escaping (URLSession.ResponseDisposition) -> Void) { - if case .finished = state { - completionHandler(.cancel) - return - } - guard - let continuation = state.continuation, - let urlResponse = urlResponse as? HTTPURLResponse, - let url = urlResponse.url?.httpURL - else { - completionHandler(.cancel) - return - } - - let response = HTTPResponse(request: request, response: urlResponse, url: url) + func urlSession(_ session: URLSession, dataTask: URLSessionDataTask, didReceive response: URLResponse, completionHandler: @escaping (URLSession.ResponseDisposition) -> Void) { + if let responseContinuation = responseContinuation { + self.responseContinuation = nil - guard response.status.isSuccess else { - state = .failure(continuation: continuation, error: .errorResponse(response)) - completionHandler(.allow) - return - } + var streamContinuation: AsyncThrowingStream.Continuation! + let stream = AsyncThrowingStream { cont in + streamContinuation = cont + } + self.streamContinuation = streamContinuation - guard !request.hasHeader("Range") || response.acceptsByteRanges else { - log(.error, "Streaming ranges requires the remote HTTP server to support byte range requests: \(url)") - state = .failure(continuation: continuation, error: .rangeNotSupported) - completionHandler(.cancel) - return + responseContinuation.resume(returning: (stream, response)) } - - state = .stream(continuation: continuation, response: response, readBytes: 0) - receiveResponse(response) - completionHandler(.allow) } - func urlSession(_ session: URLSession, didReceive data: Data) { - switch state { - case .initializing, .start, .finished: - break + func urlSession(_ session: URLSession, dataTask: URLSessionDataTask, didReceive data: Data) { + streamContinuation?.yield(data) + } - case .stream(let continuation, let response, var readBytes): - readBytes += Int64(data.count) - var progress: Double? = nil - if let expectedBytes = response.contentLength { - progress = Double(min(readBytes, expectedBytes)) / Double(expectedBytes) - } + func urlSession( + _ session: URLSession, + task: URLSessionTask, + didReceive challenge: URLAuthenticationChallenge, + completionHandler: @escaping (URLSession.AuthChallengeDisposition, URLCredential?) -> Void + ) { + guard let client = client else { + completionHandler(.performDefaultHandling, nil) + return + } - switch consume(data, progress) { - case .success: - state = .stream(continuation: continuation, response: response, readBytes: readBytes) - case let .failure(error): - state = .failure(continuation: continuation, error: error) + authTask?.cancel() + authTask = Task { + if Task.isCancelled { + completionHandler(.cancelAuthenticationChallenge, nil) + return } - case .failure(let continuation, var error): - if case var .errorResponse(response) = error { - var body = response.body ?? Data() - body.append(data) - response.body = body - error = .errorResponse(response) - } + if let delegate = clientDelegate { + let response = await delegate.httpClient(client, request: request, didReceive: challenge) - state = .failure(continuation: continuation, error: error) - } - } + if Task.isCancelled { + completionHandler(.cancelAuthenticationChallenge, nil) + return + } - func urlSession(_ session: URLSession, didCompleteWithError error: Error?) { - if let error = error { - if case .failure = state { - // No-op, we don't want to overwrite the failure state in this case. - } else if let continuation = state.continuation { - state = .failure(continuation: continuation, error: .wrap(error) ?? .other(error)) + switch response { + case let .useCredential(credential): + completionHandler(.useCredential, credential) + case .performDefaultHandling: + completionHandler(.performDefaultHandling, nil) + case .cancelAuthenticationChallenge: + completionHandler(.cancelAuthenticationChallenge, nil) + case .rejectProtectionSpace: + completionHandler(.rejectProtectionSpace, nil) + } } else { - state = .finished - } - } - finish() - } - - func urlSession(_ session: URLSession, didReceive challenge: URLAuthenticationChallenge, completion: @escaping (URLSession.AuthChallengeDisposition, URLCredential?) -> Void) { - Task { - let response = await receiveChallenge(challenge) - switch response { - case let .useCredential(credential): - completion(.useCredential, credential) - case .performDefaultHandling: - completion(.performDefaultHandling, nil) - case .cancelAuthenticationChallenge: - completion(.cancelAuthenticationChallenge, nil) - case .rejectProtectionSpace: - completion(.rejectProtectionSpace, nil) + completionHandler(.performDefaultHandling, nil) } } } @@ -539,7 +444,7 @@ private extension HTTPRequest { } private extension HTTPResponse { - init(request: HTTPRequest, response: HTTPURLResponse, url: HTTPURL, body: Data? = nil) { + init(request: HTTPRequest, response: HTTPURLResponse, url: HTTPURL) { var headers: [String: String] = [:] for (k, v) in response.allHeaderFields { if let ks = k as? String, let vs = v as? String { @@ -551,8 +456,7 @@ private extension HTTPResponse { url: url, status: HTTPStatus(rawValue: response.statusCode), headers: headers, - mediaType: response.mimeType.flatMap { MediaType($0) }, - body: body + mediaType: response.mimeType.flatMap { MediaType($0) } ) } } diff --git a/Sources/Shared/Toolkit/HTTP/HTTPClient.swift b/Sources/Shared/Toolkit/HTTP/HTTPClient.swift index f8e6109b17..50aec17f58 100644 --- a/Sources/Shared/Toolkit/HTTP/HTTPClient.swift +++ b/Sources/Shared/Toolkit/HTTP/HTTPClient.swift @@ -5,7 +5,9 @@ // import Foundation -import UIKit +#if canImport(UIKit) + import UIKit +#endif /// An HTTP client performs HTTP requests. /// @@ -15,34 +17,39 @@ public protocol HTTPClient: Loggable { /// /// - Parameters: /// - request: Request to the streamed resource. - /// also access it in the completion block after consuming the data. + /// - onReceiveResponse: Optional callback allowing you to intercept the response headers and cancel early. /// - consume: Callback called for each chunk of data received. Callers /// are responsible to accumulate the data if needed. Return an error /// to abort the request. + /// The `progress` parameter represents the overall resource progress (including + /// any `contentRangeOffset` for range requests), not just the progress of the current chunk. func stream( request: HTTPRequestConvertible, - consume: @escaping (_ chunk: Data, _ progress: Double?) -> HTTPResult + onReceiveResponse: ((HTTPResponse) async -> HTTPResult)?, + consume: @Sendable (_ chunk: Data, _ progress: Double?) -> HTTPResult ) async -> HTTPResult } +/// Safe because the `consume` closure is called serially by the `stream` implementation. +private final class _HTTPFetchBox: @unchecked Sendable { + var data = Data() + init() {} +} + public extension HTTPClient { - /// Fetches the resource from the given `request`. - func fetch(_ request: HTTPRequestConvertible) async -> HTTPResult { - var data = Data() - let response = await stream( + /// Fetches the resource from the given `request` and returns the response alongside the accumulated data. + func fetch(_ request: HTTPRequestConvertible) async -> HTTPResult { + let box = _HTTPFetchBox() + let responseResult = await stream( request: request, + onReceiveResponse: nil, consume: { chunk, _ in - data.append(chunk) + box.data.append(chunk) return .success(()) } ) - return response - .map { - var response = $0 - response.body = data - return response - } + return responseResult.map { HTTPFetchResponse(response: $0, body: box.data) } } /// Fetches the resource and attempts to decode it with the given `decoder`. @@ -55,10 +62,7 @@ public extension HTTPClient { await fetch(request) .flatMap { response in do { - guard - let body = response.body, - let result = try decoder(response, body) - else { + guard let result = try decoder(response.response, response.body) else { return .failure(.malformedResponse(nil)) } return .success(result) @@ -84,16 +88,18 @@ public extension HTTPClient { } } - /// Fetches the resource as an `UIImage`. - func fetchImage(_ request: HTTPRequestConvertible) async -> HTTPResult { - await fetch(request) { - UIImage(data: $1) + #if canImport(UIKit) + /// Fetches the resource as an `UIImage`. + func fetchImage(_ request: HTTPRequestConvertible) async -> HTTPResult { + await fetch(request) { + UIImage(data: $1) + } } - } + #endif /// Downloads the resource at a temporary location. /// - /// You are responsible for moving or deleting the downloaded file in the `completion` block. + /// You are responsible for moving or deleting the downloaded file. func download( _ request: HTTPRequestConvertible, onProgress: @escaping (Double) -> Void @@ -107,14 +113,16 @@ public extension HTTPClient { let fileHandle: FileHandle do { - try "".write(to: location.url, atomically: true, encoding: .utf8) + try Data().write(to: location.url) fileHandle = try FileHandle(forWritingTo: location.url) } catch { return .failure(.fileSystem(.io(error))) } + defer { try? fileHandle.close() } let result = await stream( request: request, + onReceiveResponse: nil, consume: { data, progression in do { try fileHandle.seekToEnd() @@ -164,7 +172,7 @@ public struct HTTPStatus: Equatable, RawRepresentable, ExpressibleByIntegerLiter /// Returns whether this represents a successful HTTP status. public var isSuccess: Bool { - (200 ..< 400).contains(rawValue) + (200 ..< 300).contains(rawValue) } /// (200) OK. @@ -213,23 +221,18 @@ public struct HTTPResponse: Equatable { /// Media type provided in the `Content-Type` header. public let mediaType: MediaType? - /// Response body content, when available. - public var body: Data? - public init( request: HTTPRequest, url: HTTPURL, status: HTTPStatus, headers: [String: String], - mediaType: MediaType?, - body: Data? + mediaType: MediaType? ) { self.request = request self.url = url self.status = status self.headers = headers self.mediaType = mediaType - self.body = body } /// Finds the value of the first header matching the given name. @@ -261,24 +264,66 @@ public struct HTTPResponse: Equatable { .takeIf { $0 >= 0 } } + /// The full expected content length for this resource, when known. + /// + /// This will be the total length of the resource, even for byte range requests. + public var fullContentLength: Int64? { + if let contentRange = valueForHeader("Content-Range"), + let totalLengthString = contentRange.split(separator: "/").last?.trimmingCharacters(in: .whitespaces), + let totalLength = Int64(totalLengthString) + { + return totalLength + } + return contentLength + } + + /// Offset of the current response in the full resource. + public var contentRangeOffset: Int64 { + guard let contentRange = valueForHeader("Content-Range"), + let rangeString = contentRange.split(separator: " ", maxSplits: 1).last, + let rangeStartString = rangeString.split(separator: "-").first?.trimmingCharacters(in: .whitespaces), + let rangeStart = Int64(rangeStartString) + else { + return 0 + } + return rangeStart + } + /// The resource filename as provided by the server in the `Content-Disposition` header. public var filename: String? { - if let disposition = headers["Content-Disposition"] { - let array = disposition.split(separator: ";") - var filenameString: String? - switch array.count { - case 1: - filenameString = String(array[0]).trimmingCharacters(in: .whitespaces) - case 2: - filenameString = String(array[1]).trimmingCharacters(in: .whitespaces) - default: - break + guard let disposition = headers["Content-Disposition"] else { + return nil + } + + let parts = disposition.split(separator: ";") + .map { $0.trimmingCharacters(in: .whitespaces) } + + // Look for filename* first as it takes precedence + for part in parts { + if part.hasPrefix("filename*=") { + let value = part.replacingOccurrences(of: "filename*=", with: "") + let encodingParts = value.split(separator: "'", omittingEmptySubsequences: false) + if encodingParts.count == 3 { + let encoding = String(encodingParts[0]).lowercased() + let encodedFilename = String(encodingParts[2]) + if encoding == "utf-8", let decoded = encodedFilename.removingPercentEncoding { + return decoded + } + } } + } - if let filenameString = filenameString, filenameString.starts(with: "filename=") { - return filenameString.replacingOccurrences(of: "filename=", with: "") + // Fallback to filename + for part in parts { + if part.hasPrefix("filename=") { + var value = part.replacingOccurrences(of: "filename=", with: "") + if value.hasPrefix("\""), value.hasSuffix("\"") { + value = String(value.dropFirst().dropLast()) + } + return value } } + return nil } } @@ -301,3 +346,17 @@ public struct HTTPDownload: Sendable { self.mediaType = mediaType } } + +/// HTTP response with the whole body as a Data buffer. +public struct HTTPFetchResponse { + /// The HTTP response from the server. + public let response: HTTPResponse + + /// The raw data received in the response body. + public let body: Data + + public init(response: HTTPResponse, body: Data) { + self.response = response + self.body = body + } +} diff --git a/Sources/Shared/Toolkit/HTTP/HTTPError.swift b/Sources/Shared/Toolkit/HTTP/HTTPError.swift index dbbe8cd843..a935ebb836 100644 --- a/Sources/Shared/Toolkit/HTTP/HTTPError.swift +++ b/Sources/Shared/Toolkit/HTTP/HTTPError.swift @@ -17,7 +17,7 @@ public enum HTTPError: Error, Loggable { case malformedResponse(Error?) /// The server returned a response with an HTTP status error. - case errorResponse(HTTPResponse) + case errorResponse(HTTPFetchResponse) /// The client, server or gateways timed out. case timeout(Error?) @@ -52,13 +52,13 @@ public enum HTTPError: Error, Loggable { public func problemDetails() throws -> HTTPProblemDetails? { guard case let .errorResponse(response) = self, - response.mediaType?.matches(.problemDetails) == true, - let body = response.body + response.response.mediaType?.matches(.problemDetails) == true, + !response.body.isEmpty else { return nil } - return try HTTPProblemDetails(data: body) + return try HTTPProblemDetails(data: response.body) } /// Wraps a native error into an `HTTPError`, if possible. diff --git a/Sources/Shared/Toolkit/HTTP/HTTPResource.swift b/Sources/Shared/Toolkit/HTTP/HTTPResource.swift index 1422579289..e3501dd2bd 100644 --- a/Sources/Shared/Toolkit/HTTP/HTTPResource.swift +++ b/Sources/Shared/Toolkit/HTTP/HTTPResource.swift @@ -50,10 +50,10 @@ public actor HTTPResource: Resource { private func headResponse() async -> ReadResult { if _headResponse == nil { _headResponse = await client.fetch(HTTPRequest(url: url, method: .head)) - .map { $0 as HTTPResponse? } + .map { $0.response as HTTPResponse? } .flatMapError { error in switch error { - case let .errorResponse(response) where response.status == .methodNotAllowed: + case let .errorResponse(response) where response.response.status == .methodNotAllowed: return .success(nil) default: return .failure(.access(.http(error))) @@ -63,7 +63,7 @@ public actor HTTPResource: Resource { return _headResponse! } - public func stream(range: Range?, consume: @escaping (Data) -> Void) async -> ReadResult { + public func stream(range: Range?, consume: (Data) -> Void) async -> ReadResult { let request = { var request = HTTPRequest(url: url) if let range = range { @@ -74,6 +74,7 @@ public actor HTTPResource: Resource { return await client.stream( request: request, + onReceiveResponse: nil, consume: { data, _ in consume(data) return .success(()) diff --git a/Sources/Shared/Toolkit/HTTP/HTTPServer.swift b/Sources/Shared/Toolkit/HTTP/HTTPServer.swift index 0f9132f7cd..6903098e2e 100644 --- a/Sources/Shared/Toolkit/HTTP/HTTPServer.swift +++ b/Sources/Shared/Toolkit/HTTP/HTTPServer.swift @@ -79,14 +79,13 @@ public extension HTTPServer { onFailure: HTTPRequestHandler.OnFailure? = nil ) throws -> HTTPURL { func onRequest(request: HTTPServerRequest) -> HTTPServerResponse { - lazy var notFound = HTTPError.errorResponse(HTTPResponse( + lazy var notFound = HTTPError.errorResponse(HTTPFetchResponse(response: HTTPResponse( request: HTTPRequest(url: request.url), url: request.url, status: .notFound, headers: [:], - mediaType: nil, - body: nil - )) + mediaType: nil + ), body: Data())) guard let href = request.href, diff --git a/TestApp/Sources/App/Readium.swift b/TestApp/Sources/App/Readium.swift index c2232079de..ca4d543736 100644 --- a/TestApp/Sources/App/Readium.swift +++ b/TestApp/Sources/App/Readium.swift @@ -102,7 +102,7 @@ extension ReadiumShared.HTTPError: UserErrorConvertible { UserError(cause: self) { switch self { case let .errorResponse(response): - switch response.status { + switch response.response.status { case .notFound: return "error_not_found".localized case .unauthorized, .forbidden: diff --git a/Tests/SharedTests/Toolkit/HTTP/DefaultHTTPClientTests.swift b/Tests/SharedTests/Toolkit/HTTP/DefaultHTTPClientTests.swift new file mode 100644 index 0000000000..639f52e9e8 --- /dev/null +++ b/Tests/SharedTests/Toolkit/HTTP/DefaultHTTPClientTests.swift @@ -0,0 +1,1274 @@ +// +// Copyright 2026 Readium Foundation. All rights reserved. +// Use of this source code is governed by the BSD-style license +// available in the top-level LICENSE file of the project. +// + +import Foundation +@testable import ReadiumShared +import Testing + +private final class Box: @unchecked Sendable { + var value: T + init(_ value: T) { + self.value = value + } +} + +@Suite(.serialized) +struct DefaultHTTPClientTests { + /// Creates a `DefaultHTTPClient` configured with `MockHTTPURLProtocol` + /// for intercepting all requests. + private func makeClient( + userAgent: String? = nil, + additionalHeaders: [String: String]? = nil, + requestTimeout: TimeInterval? = nil, + resourceTimeout: TimeInterval? = nil, + ephemeral: Bool = true, + delegate: DefaultHTTPClientDelegate? = nil + ) -> DefaultHTTPClient { + let config = URLSessionConfiguration.ephemeral + config.protocolClasses = [MockHTTPURLProtocol.self] + + if let additionalHeaders = additionalHeaders { + config.httpAdditionalHeaders = additionalHeaders + } + if let requestTimeout = requestTimeout { + config.timeoutIntervalForRequest = requestTimeout + } + if let resourceTimeout = resourceTimeout { + config.timeoutIntervalForResource = resourceTimeout + } + + return DefaultHTTPClient( + configuration: config, + userAgent: userAgent, + delegate: delegate + ) + } + + private func makeURL(_ path: String = "/test") -> HTTPURL { + HTTPURL(string: "https://example.com\(path)")! + } + + // MARK: - User Agent + + @Suite(.serialized) + struct UserAgent { + init() { + MockHTTPURLProtocol.requestHandler = nil + } + + private func makeClient( + userAgent: String? = nil, + delegate: DefaultHTTPClientDelegate? = nil + ) -> DefaultHTTPClient { + let config = URLSessionConfiguration.ephemeral + config.protocolClasses = [MockHTTPURLProtocol.self] + return DefaultHTTPClient( + configuration: config, + userAgent: userAgent, + delegate: delegate + ) + } + + private func makeURL(_ path: String = "/test") -> HTTPURL { + HTTPURL(string: "https://example.com\(path)")! + } + + @Test("Default user agent is set when none provided on request") + func defaultUserAgentIsSet() async { + var receivedUserAgent: String? + + MockHTTPURLProtocol.requestHandler = { request in + receivedUserAgent = request.value(forHTTPHeaderField: "User-Agent") + return .success(body: Data("ok".utf8)) + } + + let client = makeClient() + _ = await client.fetch(makeURL()) + + #expect(receivedUserAgent != nil) + #expect(receivedUserAgent?.isEmpty == false) + } + + @Test("Custom user agent overrides default") + func customUserAgent() async { + var receivedUserAgent: String? + let customUA = "MyApp/1.0" + + MockHTTPURLProtocol.requestHandler = { request in + receivedUserAgent = request.value(forHTTPHeaderField: "User-Agent") + return .success(body: Data("ok".utf8)) + } + + let client = makeClient(userAgent: customUA) + _ = await client.fetch(makeURL()) + + #expect(receivedUserAgent == customUA) + } + + @Test("Per-request user agent takes precedence over client default") + func perRequestUserAgent() async { + var receivedUserAgent: String? + let requestUA = "RequestSpecific/2.0" + + MockHTTPURLProtocol.requestHandler = { request in + receivedUserAgent = request.value(forHTTPHeaderField: "User-Agent") + return .success(body: Data("ok".utf8)) + } + + let client = makeClient(userAgent: "ClientDefault/1.0") + var request = HTTPRequest(url: makeURL()) + request.userAgent = requestUA + _ = await client.fetch(request) + + #expect(receivedUserAgent == requestUA) + } + + @Test("Default user agent string is non-empty") + func defaultUserAgentStringFormat() { + let ua = DefaultHTTPClient.defaultUserAgent + #expect(!ua.isEmpty) + } + } + + // MARK: - Headers + + @Suite(.serialized) + struct Headers { + init() { + MockHTTPURLProtocol.requestHandler = nil + } + + private func makeClient( + additionalHeaders: [String: String]? = nil + ) -> DefaultHTTPClient { + let config = URLSessionConfiguration.ephemeral + config.protocolClasses = [MockHTTPURLProtocol.self] + if let additionalHeaders = additionalHeaders { + config.httpAdditionalHeaders = additionalHeaders + } + return DefaultHTTPClient(configuration: config) + } + + private func makeURL() -> HTTPURL { + HTTPURL(string: "https://example.com/test")! + } + + @Test("Additional headers from configuration are sent") + func additionalHeaders() async { + var receivedHeader: String? + + MockHTTPURLProtocol.requestHandler = { request in + receivedHeader = request.value(forHTTPHeaderField: "X-Custom") + return .success(body: Data("ok".utf8)) + } + + let client = makeClient(additionalHeaders: ["X-Custom": "hello"]) + _ = await client.fetch(makeURL()) + + #expect(receivedHeader == "hello") + } + + @Test("Per-request headers are sent") + func perRequestHeaders() async { + var receivedHeader: String? + + MockHTTPURLProtocol.requestHandler = { request in + receivedHeader = request.value(forHTTPHeaderField: "X-Request") + return .success(body: Data("ok".utf8)) + } + + let config = URLSessionConfiguration.ephemeral + config.protocolClasses = [MockHTTPURLProtocol.self] + let client = DefaultHTTPClient(configuration: config) + let request = HTTPRequest(url: makeURL(), headers: ["X-Request": "value"]) + _ = await client.fetch(request) + + #expect(receivedHeader == "value") + } + } + + // MARK: - Streaming + + @Suite(.serialized) + struct Streaming { + init() { + MockHTTPURLProtocol.requestHandler = nil + } + + private func makeClient() -> DefaultHTTPClient { + let config = URLSessionConfiguration.ephemeral + config.protocolClasses = [MockHTTPURLProtocol.self] + return DefaultHTTPClient(configuration: config) + } + + private func makeURL() -> HTTPURL { + HTTPURL(string: "https://example.com/stream")! + } + + @Test("Stream delivers data in chunks") + func streamDeliversChunks() async throws { + let chunk1 = Data("hello ".utf8) + let chunk2 = Data("world".utf8) + + MockHTTPURLProtocol.requestHandler = { _ in + .success(chunks: [chunk1, chunk2]) + } + + nonisolated(unsafe) var receivedChunks: [Data] = [] + + let result = await makeClient().stream( + request: makeURL() + ) { data, _ in + receivedChunks.append(data) + return .success(()) + } + + let response = try result.get() + #expect(response.status == .ok) + // URLSession may coalesce chunks, so verify total data + let totalData = receivedChunks.reduce(Data(), +) + #expect(totalData == chunk1 + chunk2) + } + + @Test("Stream reports progress when Content-Length is known") + func streamReportsProgress() async throws { + let body = Data("hello world".utf8) + + MockHTTPURLProtocol.requestHandler = { _ in + .success( + headers: ["Content-Length": "\(body.count)"], + body: body + ) + } + + let lastProgress = Box(nil) + + let result = await makeClient().stream( + request: makeURL() + ) { _, progress in + if let progress = progress { + lastProgress.value = progress + } + return .success(()) + } + + _ = try result.get() + #expect(lastProgress.value != nil) + // Final progress should be 1.0 (all data received) + if let progress = lastProgress.value { + #expect(progress > 0) + #expect(progress <= 1.0) + } + } + + @Test("Stream reports nil progress when Content-Length is unknown") + func streamReportsNilProgressWhenContentLengthUnknown() async throws { + MockHTTPURLProtocol.requestHandler = { _ in + .success(body: Data("data".utf8)) + } + + let allProgressValues = Box<[Double?]>([]) + + let result = await makeClient().stream( + request: makeURL() + ) { _, progress in + allProgressValues.value.append(progress) + return .success(()) + } + + _ = try result.get() + // All progress values should be nil since no Content-Length + #expect(allProgressValues.value.allSatisfy { $0 == nil }) + } + + @Test("Returning failure from consume aborts the stream") + func consumeFailureAbortsStream() async { + let largeBody = Data(repeating: 0x42, count: 1024) + + MockHTTPURLProtocol.requestHandler = { _ in + .success( + headers: ["Content-Length": "\(largeBody.count)"], + chunks: [ + Data(largeBody[0 ..< 512]), + Data(largeBody[512...]), + ] + ) + } + + let result = await makeClient().stream( + request: makeURL() + ) { _, _ in + .failure(.cancelled) + } + + guard case .failure(.cancelled) = result else { + Issue.record("Expected .cancelled failure but got \(result)") + return + } + } + } + + // MARK: - Fetch + + @Suite(.serialized) + struct Fetch { + init() { + MockHTTPURLProtocol.requestHandler = nil + } + + private func makeClient() -> DefaultHTTPClient { + let config = URLSessionConfiguration.ephemeral + config.protocolClasses = [MockHTTPURLProtocol.self] + return DefaultHTTPClient(configuration: config) + } + + private func makeURL() -> HTTPURL { + HTTPURL(string: "https://example.com/fetch")! + } + + @Test("Fetch accumulates streamed data into response body") + func fetchAccumulatesData() async throws { + let chunk1 = Data("hello ".utf8) + let chunk2 = Data("world".utf8) + + MockHTTPURLProtocol.requestHandler = { _ in + .success(chunks: [chunk1, chunk2]) + } + + let response = try await makeClient().fetch(makeURL()).get() + + #expect(response.body == chunk1 + chunk2) + } + + @Test("Fetch returns correct response metadata") + func fetchReturnsMetadata() async throws { + MockHTTPURLProtocol.requestHandler = { _ in + .success( + statusCode: 200, + headers: [ + "Content-Type": "text/plain", + "Content-Length": "5", + ], + body: Data("hello".utf8) + ) + } + + let response = try await makeClient().fetch(makeURL()).get() + + #expect(response.response.status == .ok) + #expect(response.response.mediaType == MediaType.text) + } + + @Test("fetchString returns decoded string") + func fetchString() async throws { + let text = "Hello, Readium!" + + MockHTTPURLProtocol.requestHandler = { _ in + .success( + headers: ["Content-Type": "text/plain; charset=utf-8"], + body: Data(text.utf8) + ) + } + + let result = try await makeClient().fetchString(makeURL()).get() + #expect(result == text) + } + } + + // MARK: - Download + + @Suite(.serialized) + struct Download { + init() { + MockHTTPURLProtocol.requestHandler = nil + } + + private func makeClient() -> DefaultHTTPClient { + let config = URLSessionConfiguration.ephemeral + config.protocolClasses = [MockHTTPURLProtocol.self] + return DefaultHTTPClient(configuration: config) + } + + private func makeURL() -> HTTPURL { + HTTPURL(string: "https://example.com/download")! + } + + @Test("Download writes data to a temporary file") + func downloadWritesToFile() async throws { + let content = Data("file content".utf8) + + MockHTTPURLProtocol.requestHandler = { _ in + .success( + headers: [ + "Content-Length": "\(content.count)", + "Content-Type": "application/octet-stream", + ], + body: content + ) + } + + let download = try await makeClient() + .download(makeURL()) { _ in } + .get() + + let downloadedData = try Data(contentsOf: download.location.url) + #expect(downloadedData == content) + + // Cleanup + try FileManager.default.removeItem(at: download.location.url) + } + + @Test("Download reports progress") + func downloadReportsProgress() async throws { + let content = Data(repeating: 0x42, count: 1024) + + MockHTTPURLProtocol.requestHandler = { _ in + .success( + headers: ["Content-Length": "\(content.count)"], + body: content + ) + } + + var progressValues: [Double] = [] + + let download = try await makeClient() + .download(makeURL()) { progress in + progressValues.append(progress) + } + .get() + + #expect(!progressValues.isEmpty) + if let last = progressValues.last { + #expect(last > 0) + #expect(last <= 1.0) + } + + // Cleanup + try FileManager.default.removeItem(at: download.location.url) + } + + @Test("Download cleans up temporary file on failure") + func downloadCleansUpOnFailure() async { + MockHTTPURLProtocol.requestHandler = { _ in + .success(statusCode: 500, body: Data("error".utf8)) + } + + let tempDir = URL(fileURLWithPath: NSTemporaryDirectory(), isDirectory: true) + let countBefore = (try? FileManager.default.contentsOfDirectory(atPath: tempDir.path))?.count ?? 0 + + let result = await makeClient() + .download(makeURL()) { _ in } + + let countAfter = (try? FileManager.default.contentsOfDirectory(atPath: tempDir.path))?.count ?? 0 + + guard case .failure = result else { + Issue.record("Expected failure") + return + } + #expect(countAfter == countBefore, "Temporary file should be deleted on failure") + } + + @Test("Download preserves suggested filename from Content-Disposition") + func downloadSuggestedFilename() async throws { + MockHTTPURLProtocol.requestHandler = { _ in + .success( + headers: [ + "Content-Disposition": "attachment; filename=book.epub", + "Content-Type": "application/epub+zip", + ], + body: Data("epub".utf8) + ) + } + + let download = try await makeClient() + .download(makeURL()) { _ in } + .get() + + #expect(download.suggestedFilename == "book.epub") + + // Cleanup + try FileManager.default.removeItem(at: download.location.url) + } + } + + // MARK: - HTTP Errors + + @Suite(.serialized) + struct HTTPErrors { + init() { + MockHTTPURLProtocol.requestHandler = nil + } + + private func makeClient() -> DefaultHTTPClient { + let config = URLSessionConfiguration.ephemeral + config.protocolClasses = [MockHTTPURLProtocol.self] + return DefaultHTTPClient(configuration: config) + } + + private func makeURL() -> HTTPURL { + HTTPURL(string: "https://example.com/error")! + } + + @Test( + "HTTP error status codes return .errorResponse", + arguments: [400, 401, 403, 404, 405, 500] + ) + func httpErrorStatusCodes(statusCode: Int) async { + MockHTTPURLProtocol.requestHandler = { _ in + .success(statusCode: statusCode, body: Data()) + } + + let result = await makeClient().fetch(makeURL()) + + guard case let .failure(.errorResponse(response)) = result else { + Issue.record("Expected .errorResponse for status \(statusCode)") + return + } + #expect(response.response.status.rawValue == statusCode) + } + + @Test("Error response body is accumulated") + func errorResponseIncludesBody() async { + let errorBody = Data(""" + {"type": "https://example.com/auth", "title": "Authentication Required"} + """.utf8) + + MockHTTPURLProtocol.requestHandler = { _ in + .success( + statusCode: 401, + headers: ["Content-Type": "application/problem+json"], + body: errorBody + ) + } + + let result = await makeClient().fetch(makeURL()) + + guard case let .failure(.errorResponse(response)) = result else { + Issue.record("Expected .errorResponse") + return + } + #expect(response.body == errorBody) + } + + @Test( + "2xx status codes are treated as success", + arguments: [200, 201, 204] + ) + func successStatusCodes(statusCode: Int) async { + MockHTTPURLProtocol.requestHandler = { _ in + .success(statusCode: statusCode, body: Data("ok".utf8)) + } + + let result = await makeClient().fetch(makeURL()) + + guard case .success = result else { + Issue.record("Status \(statusCode) should be success but got \(result)") + return + } + } + } + + // MARK: - Network Errors + + @Suite(.serialized) + struct NetworkErrors { + init() { + MockHTTPURLProtocol.requestHandler = nil + } + + private func makeClient() -> DefaultHTTPClient { + let config = URLSessionConfiguration.ephemeral + config.protocolClasses = [MockHTTPURLProtocol.self] + return DefaultHTTPClient(configuration: config) + } + + private func makeURL() -> HTTPURL { + HTTPURL(string: "https://example.com/network")! + } + + @Test("Timeout returns .timeout error") + func timeoutError() async { + MockHTTPURLProtocol.requestHandler = { _ in + .error(URLError(.timedOut)) + } + + let result = await makeClient().fetch(makeURL()) + + guard case .failure(.timeout) = result else { + Issue.record("Expected .timeout error, got \(result)") + return + } + } + + @Test("Cannot connect to host returns .unreachable error") + func unreachableError() async { + MockHTTPURLProtocol.requestHandler = { _ in + .error(URLError(.cannotConnectToHost)) + } + + let result = await makeClient().fetch(makeURL()) + + guard case .failure(.unreachable) = result else { + Issue.record("Expected .unreachable error, got \(result)") + return + } + } + + @Test("Cannot find host returns .unreachable error") + func cannotFindHostError() async { + MockHTTPURLProtocol.requestHandler = { _ in + .error(URLError(.cannotFindHost)) + } + + let result = await makeClient().fetch(makeURL()) + + guard case .failure(.unreachable) = result else { + Issue.record("Expected .unreachable error, got \(result)") + return + } + } + + @Test("Not connected to internet returns .offline error") + func offlineError() async { + MockHTTPURLProtocol.requestHandler = { _ in + .error(URLError(.notConnectedToInternet)) + } + + let result = await makeClient().fetch(makeURL()) + + guard case .failure(.offline) = result else { + Issue.record("Expected .offline error, got \(result)") + return + } + } + + @Test("Network connection lost returns .offline error") + func networkConnectionLostError() async { + MockHTTPURLProtocol.requestHandler = { _ in + .error(URLError(.networkConnectionLost)) + } + + let result = await makeClient().fetch(makeURL()) + + guard case .failure(.offline) = result else { + Issue.record("Expected .offline error, got \(result)") + return + } + } + + @Test("Cancelled request returns .cancelled error") + func cancelledError() async { + MockHTTPURLProtocol.requestHandler = { _ in + .error(URLError(.cancelled)) + } + + let result = await makeClient().fetch(makeURL()) + + guard case .failure(.cancelled) = result else { + Issue.record("Expected .cancelled error, got \(result)") + return + } + } + + @Test("Secure connection failed returns .security error") + func securityError() async { + MockHTTPURLProtocol.requestHandler = { _ in + .error(URLError(.secureConnectionFailed)) + } + + let result = await makeClient().fetch(makeURL()) + + guard case .failure(.security) = result else { + Issue.record("Expected .security error, got \(result)") + return + } + } + + @Test("Too many redirects returns .redirection error") + func redirectionError() async { + MockHTTPURLProtocol.requestHandler = { _ in + .error(URLError(.httpTooManyRedirects)) + } + + let result = await makeClient().fetch(makeURL()) + + guard case .failure(.redirection) = result else { + Issue.record("Expected .redirection error, got \(result)") + return + } + } + + @Test("Bad server response returns .malformedResponse error") + func malformedResponseError() async { + MockHTTPURLProtocol.requestHandler = { _ in + .error(URLError(.badServerResponse)) + } + + let result = await makeClient().fetch(makeURL()) + + guard case .failure(.malformedResponse) = result else { + Issue.record("Expected .malformedResponse error, got \(result)") + return + } + } + + @Test("Unknown URLError returns .other error") + func otherError() async { + MockHTTPURLProtocol.requestHandler = { _ in + .error(URLError(.backgroundSessionWasDisconnected)) + } + + let result = await makeClient().fetch(makeURL()) + + guard case .failure(.other) = result else { + Issue.record("Expected .other error, got \(result)") + return + } + } + } + + // MARK: - Cancellation + + @Suite(.serialized) + struct Cancellation { + init() { + MockHTTPURLProtocol.requestHandler = nil + } + + private func makeClient() -> DefaultHTTPClient { + let config = URLSessionConfiguration.ephemeral + config.protocolClasses = [MockHTTPURLProtocol.self] + return DefaultHTTPClient(configuration: config) + } + + private func makeURL() -> HTTPURL { + HTTPURL(string: "https://example.com/cancel")! + } + + @Test("Cancelling the Swift task cancels the HTTP request") + func cancelledTaskReturnsCancelledError() async { + MockHTTPURLProtocol.requestHandler = { _ in + .delayed(seconds: 2, then: .success(body: Data("late".utf8))) + } + + let task = Task { + await makeClient().fetch(makeURL()) + } + + // Give the request time to start, then cancel. + try? await Task.sleep(nanoseconds: 100_000_000) // 100ms + task.cancel() + + let result = await task.value + + guard case .failure(.cancelled) = result else { + // URLSession may also report .timeout or other errors on cancel + // depending on timing, so we accept any failure. + if case .failure = result { + return + } + Issue.record("Expected failure after cancellation, got \(result)") + return + } + } + } + + // MARK: - Range Requests + + @Suite(.serialized) + struct RangeRequests { + init() { + MockHTTPURLProtocol.requestHandler = nil + } + + private func makeClient() -> DefaultHTTPClient { + let config = URLSessionConfiguration.ephemeral + config.protocolClasses = [MockHTTPURLProtocol.self] + return DefaultHTTPClient(configuration: config) + } + + private func makeURL() -> HTTPURL { + HTTPURL(string: "https://example.com/range")! + } + + @Test("Range request succeeds when server signals Accept-Ranges") + func rangeRequestSuccessViaAcceptRanges() async throws { + let partialContent = Data("partial".utf8) + + MockHTTPURLProtocol.requestHandler = { request in + let rangeHeader = request.value(forHTTPHeaderField: "Range") + #expect(rangeHeader != nil) + + return .success( + statusCode: 206, + headers: [ + "Accept-Ranges": "bytes", + "Content-Length": "\(partialContent.count)", + ], + body: partialContent + ) + } + + var httpRequest = HTTPRequest(url: makeURL()) + httpRequest.setRange(0 ..< 7) + + let response = try await makeClient().fetch(httpRequest).get() + #expect(response.response.status == .partialContent) + #expect(response.body == partialContent) + } + + @Test("Range request succeeds when server signals Content-Range without Accept-Ranges") + func rangeRequestSuccessViaContentRange() async throws { + let partialContent = Data("partial".utf8) + + MockHTTPURLProtocol.requestHandler = { _ in + .success( + statusCode: 206, + headers: [ + "Content-Range": "bytes 0-6/100", + "Content-Length": "\(partialContent.count)", + ], + body: partialContent + ) + } + + var httpRequest = HTTPRequest(url: makeURL()) + httpRequest.setRange(0 ..< 7) + + let response = try await makeClient().fetch(httpRequest).get() + #expect(response.response.status == .partialContent) + #expect(response.body == partialContent) + } + + @Test("Range request fails when server does not support byte ranges") + func rangeRequestFailsWithoutServerSupport() async { + MockHTTPURLProtocol.requestHandler = { _ in + .success( + statusCode: 200, + headers: [:], + body: Data("full content".utf8) + ) + } + + var httpRequest = HTTPRequest(url: makeURL()) + httpRequest.setRange(0 ..< 7) + + let result = await makeClient().fetch(httpRequest) + + guard case .failure(.rangeNotSupported) = result else { + Issue.record("Expected .rangeNotSupported, got \(result)") + return + } + } + } + + // MARK: - Delegate Callbacks + + @Suite(.serialized) + struct DelegateCallbacks { + init() { + MockHTTPURLProtocol.requestHandler = nil + } + + private func makeClient(delegate: DefaultHTTPClientDelegate) -> DefaultHTTPClient { + let config = URLSessionConfiguration.ephemeral + config.protocolClasses = [MockHTTPURLProtocol.self] + return DefaultHTTPClient( + configuration: config, + delegate: delegate + ) + } + + private func makeURL() -> HTTPURL { + HTTPURL(string: "https://example.com/delegate")! + } + + @Test("willStartRequest is called before the request") + func willStartRequestIsCalled() async { + MockHTTPURLProtocol.requestHandler = { _ in + .success(body: Data("ok".utf8)) + } + + let delegate = SpyDelegate() + let client = makeClient(delegate: delegate) + _ = await client.fetch(makeURL()) + + #expect(delegate.willStartRequestCalled) + } + + @Test("willStartRequest can modify the request") + func willStartRequestModifiesRequest() async { + var receivedHeader: String? + + MockHTTPURLProtocol.requestHandler = { request in + receivedHeader = request.value(forHTTPHeaderField: "X-Injected") + return .success(body: Data("ok".utf8)) + } + + let delegate = SpyDelegate() + delegate.onWillStartRequest = { request in + var modified = request + modified.headers["X-Injected"] = "by-delegate" + return .success(modified) + } + + let client = makeClient(delegate: delegate) + _ = await client.fetch(makeURL()) + + #expect(receivedHeader == "by-delegate") + } + + @Test("willStartRequest returning failure aborts the request without sending it") + func willStartRequestFailureAbortsRequest() async { + MockHTTPURLProtocol.requestHandler = { _ in + Issue.record("Request should not have been sent") + return .success(body: Data()) + } + + let delegate = SpyDelegate() + delegate.onWillStartRequest = { _ in + .failure(.cancelled) + } + + let client = makeClient(delegate: delegate) + let result = await client.fetch(makeURL()) + + guard case .failure(.cancelled) = result else { + Issue.record("Expected .cancelled from willStartRequest failure, got \(result)") + return + } + #expect(!delegate.didFailWithErrorCalled) + } + + @Test("didReceiveResponse is called on success") + func didReceiveResponseIsCalled() async { + MockHTTPURLProtocol.requestHandler = { _ in + .success(body: Data("ok".utf8)) + } + + let delegate = SpyDelegate() + let client = makeClient(delegate: delegate) + _ = await client.fetch(makeURL()) + + #expect(delegate.didReceiveResponseCalled) + } + + @Test("didFailWithError is called on failure") + func didFailWithErrorIsCalled() async { + MockHTTPURLProtocol.requestHandler = { _ in + .error(URLError(.timedOut)) + } + + let delegate = SpyDelegate() + let client = makeClient(delegate: delegate) + _ = await client.fetch(makeURL()) + + #expect(delegate.didFailWithErrorCalled) + } + + @Test("recoverRequest can retry with a new request") + func recoverRequestRetries() async throws { + var requestCount = 0 + + MockHTTPURLProtocol.requestHandler = { _ in + requestCount += 1 + if requestCount == 1 { + return .error(URLError(.timedOut)) + } + return .success(body: Data("recovered".utf8)) + } + + let delegate = SpyDelegate() + delegate.onRecoverRequest = { request, _ in + // Retry the same request + .success(request) + } + + let client = makeClient(delegate: delegate) + let result = await client.fetch(makeURL()) + + let response = try result.get() + #expect(response.body == Data("recovered".utf8)) + #expect(requestCount == 2) + } + + @Test("recoverRequest propagates error when unrecoverable") + func recoverRequestPropagatesError() async { + MockHTTPURLProtocol.requestHandler = { _ in + .error(URLError(.timedOut)) + } + + let delegate = SpyDelegate() + delegate.onRecoverRequest = { _, error in + .failure(error) + } + + let client = makeClient(delegate: delegate) + let result = await client.fetch(makeURL()) + + guard case .failure(.timeout) = result else { + Issue.record("Expected .timeout, got \(result)") + return + } + #expect(delegate.didFailWithErrorCalled) + } + + @Test("willStartRequest can redirect to a different URL") + func willStartRequestRedirects() async throws { + MockHTTPURLProtocol.requestHandler = { request in + let path = request.url?.path ?? "" + if path == "/redirected" { + return .success(body: Data("redirected response".utf8)) + } + return .success(statusCode: 404, body: Data()) + } + + let delegate = SpyDelegate() + delegate.onWillStartRequest = { _ in + let redirectURL = HTTPURL(string: "https://example.com/redirected")! + return .success(HTTPRequest(url: redirectURL)) + } + + let client = makeClient(delegate: delegate) + let response = try await client.fetch(makeURL()).get() + #expect(response.body == Data("redirected response".utf8)) + } + } + + // MARK: - Authentication Challenges + + @Suite(.serialized) + struct AuthenticationChallenges { + init() { + MockHTTPURLProtocol.requestHandler = nil + } + + private func makeClient(delegate: DefaultHTTPClientDelegate) -> DefaultHTTPClient { + let config = URLSessionConfiguration.ephemeral + config.protocolClasses = [MockHTTPURLProtocol.self] + return DefaultHTTPClient( + configuration: config, + delegate: delegate + ) + } + + private func makeURL() -> HTTPURL { + HTTPURL(string: "https://example.com/auth")! + } + + @Test("Delegate receives authentication challenge") + func delegateReceivesChallenge() async { + MockHTTPURLProtocol.requestHandler = { _ in + .authenticationChallenge( + host: "example.com", + method: NSURLAuthenticationMethodHTTPBasic, + then: .success(body: Data("authenticated".utf8)) + ) + } + + let delegate = SpyDelegate() + delegate.onDidReceiveChallenge = { _ in + .performDefaultHandling + } + + let client = makeClient(delegate: delegate) + _ = await client.fetch(makeURL()) + + #expect(delegate.didReceiveChallengeCalled) + } + + @Test("Regular request succeeds without a delegate") + func regularRequestSucceedsWithoutDelegate() async { + MockHTTPURLProtocol.requestHandler = { _ in + .success(body: Data("ok".utf8)) + } + + let config = URLSessionConfiguration.ephemeral + config.protocolClasses = [MockHTTPURLProtocol.self] + let client = DefaultHTTPClient(configuration: config, delegate: nil) + + let result = await client.fetch(makeURL()) + + guard case .success = result else { + Issue.record("Expected success for a regular request without a delegate, got \(result)") + return + } + } + } + + // MARK: - Configuration + + @Suite(.serialized) + struct Configuration { + init() { + MockHTTPURLProtocol.requestHandler = nil + } + + private func makeURL() -> HTTPURL { + HTTPURL(string: "https://example.com/config")! + } + + @Test("Request timeout is passed to URLSessionConfiguration") + func requestTimeoutIsApplied() async { + var receivedTimeout: TimeInterval? + + MockHTTPURLProtocol.requestHandler = { request in + receivedTimeout = request.timeoutInterval + return .success(body: Data("ok".utf8)) + } + + let config = URLSessionConfiguration.ephemeral + config.protocolClasses = [MockHTTPURLProtocol.self] + config.timeoutIntervalForRequest = 42.0 + let client = DefaultHTTPClient(configuration: config) + _ = await client.fetch(makeURL()) + + // URLSession may apply its own timeout logic, but the + // configuration value should influence the request. + #expect(receivedTimeout != nil) + } + + @Test("Per-request timeout overrides session timeout") + func perRequestTimeoutOverridesSession() async { + var receivedTimeout: TimeInterval? + + MockHTTPURLProtocol.requestHandler = { request in + receivedTimeout = request.timeoutInterval + return .success(body: Data("ok".utf8)) + } + + let config = URLSessionConfiguration.ephemeral + config.protocolClasses = [MockHTTPURLProtocol.self] + config.timeoutIntervalForRequest = 60.0 + let client = DefaultHTTPClient(configuration: config) + + var request = HTTPRequest(url: makeURL()) + request.timeoutInterval = 5.0 + _ = await client.fetch(request) + + #expect(receivedTimeout == 5.0) + } + + @Test("HTTP method is correctly transmitted") + func httpMethodIsTransmitted() async { + var receivedMethod: String? + + MockHTTPURLProtocol.requestHandler = { request in + receivedMethod = request.httpMethod + return .success(body: Data("ok".utf8)) + } + + let config = URLSessionConfiguration.ephemeral + config.protocolClasses = [MockHTTPURLProtocol.self] + let client = DefaultHTTPClient(configuration: config) + + let request = HTTPRequest(url: makeURL(), method: .post) + _ = await client.fetch(request) + + #expect(receivedMethod == "POST") + } + + @Test("Request body is transmitted for POST requests") + func requestBodyIsTransmitted() async { + var receivedBody: Data? + + MockHTTPURLProtocol.requestHandler = { request in + if let stream = request.httpBodyStream { + stream.open() + var data = Data() + let buffer = UnsafeMutablePointer.allocate(capacity: 1024) + defer { buffer.deallocate() } + while stream.hasBytesAvailable { + let bytesRead = stream.read(buffer, maxLength: 1024) + if bytesRead > 0 { + data.append(buffer, count: bytesRead) + } + } + stream.close() + receivedBody = data + } else { + receivedBody = request.httpBody + } + return .success(body: Data("ok".utf8)) + } + + let config = URLSessionConfiguration.ephemeral + config.protocolClasses = [MockHTTPURLProtocol.self] + let client = DefaultHTTPClient(configuration: config) + + let bodyData = Data("request body".utf8) + let request = HTTPRequest(url: makeURL(), method: .post, body: .data(bodyData)) + _ = await client.fetch(request) + + #expect(receivedBody == bodyData) + } + } +} + +// MARK: - Spy Delegate + +/// A test spy implementing `DefaultHTTPClientDelegate` that records calls +/// and allows customizing behavior via closures. +private class SpyDelegate: DefaultHTTPClientDelegate, @unchecked Sendable { + var willStartRequestCalled = false + var didReceiveResponseCalled = false + var didFailWithErrorCalled = false + var didReceiveChallengeCalled = false + + var lastRequest: HTTPRequest? + var lastResponse: HTTPResponse? + var lastError: HTTPError? + var lastChallenge: URLAuthenticationChallenge? + + var onWillStartRequest: ((HTTPRequest) -> HTTPResult)? + var onRecoverRequest: ((HTTPRequest, HTTPError) -> HTTPResult)? + var onDidReceiveChallenge: ((URLAuthenticationChallenge) -> URLAuthenticationChallengeResponse)? + + func httpClient( + _ httpClient: DefaultHTTPClient, + willStartRequest request: HTTPRequest + ) async -> HTTPResult { + willStartRequestCalled = true + lastRequest = request + return onWillStartRequest?(request) ?? .success(request) + } + + func httpClient( + _ httpClient: DefaultHTTPClient, + recoverRequest request: HTTPRequest, + fromError error: HTTPError + ) async -> HTTPResult { + onRecoverRequest?(request, error) ?? .failure(error) + } + + func httpClient( + _ httpClient: DefaultHTTPClient, + request: HTTPRequest, + didReceiveResponse response: HTTPResponse + ) { + didReceiveResponseCalled = true + lastResponse = response + } + + func httpClient( + _ httpClient: DefaultHTTPClient, + request: HTTPRequest, + didFailWithError error: HTTPError + ) { + didFailWithErrorCalled = true + lastError = error + } + + func httpClient( + _ httpClient: DefaultHTTPClient, + request: HTTPRequest, + didReceive challenge: URLAuthenticationChallenge + ) async -> URLAuthenticationChallengeResponse { + didReceiveChallengeCalled = true + lastChallenge = challenge + return onDidReceiveChallenge?(challenge) ?? .performDefaultHandling + } +} diff --git a/Tests/SharedTests/Toolkit/HTTP/HTTPProblemDetailsTests.swift b/Tests/SharedTests/Toolkit/HTTP/HTTPProblemDetailsTests.swift index a2a639a3b9..7576000281 100644 --- a/Tests/SharedTests/Toolkit/HTTP/HTTPProblemDetailsTests.swift +++ b/Tests/SharedTests/Toolkit/HTTP/HTTPProblemDetailsTests.swift @@ -4,46 +4,77 @@ // available in the top-level LICENSE file of the project. // +import Foundation @testable import ReadiumShared -import XCTest +import Testing -class HTTPProblemDetailsTests: XCTestCase { +@Suite("HTTPProblemDetails") +struct HTTPProblemDetailsTests { /// Parses a minimal Problem Details JSON. - func testParseMinimalJSON() throws { + @Test func parseMinimalJSON() throws { let json = """ {"title": "You do not have enough credit."} """.data(using: .utf8)! - XCTAssertEqual(try (HTTPProblemDetails(data: json)).title, "You do not have enough credit.") + let details = try HTTPProblemDetails(data: json) + #expect(details.title == "You do not have enough credit.") } /// Parses a full Problem Details JSON. - func testParseFullJSON() throws { + @Test func parseFullJSON() throws { let json = """ { "type": "https://example.net/validation-error", "title": "Your request parameters didn't validate.", "status": 400, + "detail": "Age must be a positive integer.", + "instance": "https://example.net/validation-error/123", "invalid-params": [ { "name": "age", "reason": "must be a positive integer" - }, - { - "name": "color", - "reason": "must be 'green', 'red' or 'blue'" } ] } """.data(using: .utf8)! - XCTAssertEqual( - try HTTPProblemDetails(data: json), - HTTPProblemDetails( - title: "Your request parameters didn't validate.", - type: "https://example.net/validation-error", - status: 400 - ) + let details = try HTTPProblemDetails(data: json) + #expect(details.title == "Your request parameters didn't validate.") + #expect(details.type == "https://example.net/validation-error") + #expect(details.status == 400) + #expect(details.detail == "Age must be a positive integer.") + #expect(details.instance == "https://example.net/validation-error/123") + } + + @Test func parseInvalidJSON() { + let json = """ + {"not-a-title": "Missing title"} + """.data(using: .utf8)! + + #expect(throws: HTTPProblemDetails.Error.self) { + try HTTPProblemDetails(data: json) + } + } + + @Test func extractFromHTTPError() throws { + let json = """ + {"title": "Forbidden action"} + """.data(using: .utf8)! + + let fetchResponse = try HTTPFetchResponse( + response: HTTPResponse( + request: HTTPRequest(url: #require(HTTPURL(string: "http://example.com"))), + url: #require(HTTPURL(string: "http://example.com")), + status: .forbidden, + headers: ["Content-Type": "application/problem+json"], + mediaType: .problemDetails + ), + body: json ) + + let error = HTTPError.errorResponse(fetchResponse) + let details = try error.problemDetails() + + #expect(details?.title == "Forbidden action") } } diff --git a/Tests/SharedTests/Toolkit/HTTP/HTTPRequestTests.swift b/Tests/SharedTests/Toolkit/HTTP/HTTPRequestTests.swift new file mode 100644 index 0000000000..6ced069534 --- /dev/null +++ b/Tests/SharedTests/Toolkit/HTTP/HTTPRequestTests.swift @@ -0,0 +1,55 @@ +// +// Copyright 2026 Readium Foundation. All rights reserved. +// Use of this source code is governed by the BSD-style license +// available in the top-level LICENSE file of the project. +// + +import Foundation +@testable import ReadiumShared +import Testing + +@Suite("HTTPRequest") +struct HTTPRequestTests { + private let url = HTTPURL(string: "http://example.com")! + + @Test func setRange() { + var request = HTTPRequest(url: url) + + request.setRange(0 ..< 100) + #expect(request.headers["Range"] == "bytes=0-99") + + request.setRange(100 ..< 200) + #expect(request.headers["Range"] == "bytes=100-199") + } + + @Test func setRangeUntilEnd() { + var request = HTTPRequest(url: url) + + request.setRange(100 ..< 100) + #expect(request.headers["Range"] == "bytes=100-") + } + + @Test func setPOSTForm() { + var request = HTTPRequest(url: url) + request.setPOSTForm([ + "field1": "value1", + "field2": "value with spaces", + "field3": "special&*characters", + "field4": nil, + ]) + + #expect(request.method == .post) + #expect(request.headers["Content-Type"] == "application/x-www-form-urlencoded") + + if case let .data(data) = request.body, let bodyString = String(data: data, encoding: .utf8) { + let parts = bodyString.split(separator: "&") + #expect(parts.contains("field1=value1")) + #expect(parts.contains("field2=value+with+spaces")) + #expect(parts.contains("field3=special%26*characters")) + #expect(parts.contains("field4=")) + #expect(parts.count == 4) + } else { + Issue.record("Expected data body") + } + } +} diff --git a/Tests/SharedTests/Toolkit/HTTP/HTTPResourceTests.swift b/Tests/SharedTests/Toolkit/HTTP/HTTPResourceTests.swift new file mode 100644 index 0000000000..0991b4c157 --- /dev/null +++ b/Tests/SharedTests/Toolkit/HTTP/HTTPResourceTests.swift @@ -0,0 +1,110 @@ +// +// Copyright 2026 Readium Foundation. All rights reserved. +// Use of this source code is governed by the BSD-style license +// available in the top-level LICENSE file of the project. +// + +import Foundation +@testable import ReadiumShared +import Testing + +@Suite("HTTPResource") +struct HTTPResourceTests { + private let url = HTTPURL(string: "http://example.com/book.epub")! + + class MockHTTPClient: HTTPClient { + var fetchResults: [String: HTTPResult] = [:] + var fetchCount = 0 + + func stream( + request: HTTPRequestConvertible, + onReceiveResponse: ((HTTPResponse) async -> HTTPResult)?, + consume: (Data, Double?) -> HTTPResult + ) async -> HTTPResult { + let req = try! request.httpRequest().get() + let key = "\(req.method.rawValue) \(req.url.string)" + fetchCount += 1 + + if let result = fetchResults[key] { + switch result { + case let .success(fetchResponse): + if let onReceiveResponse = onReceiveResponse { + let _ = await onReceiveResponse(fetchResponse.response) + } + _ = consume(fetchResponse.body, 1.0) + return .success(fetchResponse.response) + case let .failure(error): + return .failure(error) + } + } + return .failure(.cancelled) + } + } + + @Test func headResponseIsCached() async throws { + let client = MockHTTPClient() + let resource = HTTPResource(url: url, client: client) + + client.fetchResults["HEAD \(url.string)"] = .success(HTTPFetchResponse( + response: HTTPResponse( + request: HTTPRequest(url: url), + url: url, + status: .ok, + headers: ["Content-Length": "1024"], + mediaType: .epub + ), + body: Data() + )) + + let length1 = await resource.estimatedLength() + try #expect(length1.get() == 1024) + #expect(client.fetchCount == 1) + + let length2 = await resource.estimatedLength() + try #expect(length2.get() == 1024) + #expect(client.fetchCount == 1) // Should be cached + } + + @Test func headResponseFallbackOnMethodNotAllowed() async throws { + let client = MockHTTPClient() + let resource = HTTPResource(url: url, client: client) + + let response = HTTPFetchResponse( + response: HTTPResponse( + request: HTTPRequest(url: url, method: .head), + url: url, + status: .methodNotAllowed, + headers: [:], + mediaType: nil + ), + body: Data() + ) + client.fetchResults["HEAD \(url.string)"] = .failure(.errorResponse(response)) + + let length = await resource.estimatedLength() + try #expect(length.get() == nil) + #expect(client.fetchCount == 1) + } + + @Test func streamWithRange() async throws { + let client = MockHTTPClient() + let resource = HTTPResource(url: url, client: client) + + client.fetchResults["GET \(url.string)"] = try .success(HTTPFetchResponse( + response: HTTPResponse( + request: HTTPRequest(url: url), + url: url, + status: .partialContent, + headers: ["Content-Range": "bytes 0-9/100"], + mediaType: .epub + ), + body: #require("0123456789".data(using: .utf8)) + )) + + var streamedData = Data() + let result = await resource.stream(range: 0 ..< 10, consume: { streamedData.append($0) }) + + try result.get() + #expect(streamedData == "0123456789".data(using: .utf8)) + } +} diff --git a/Tests/SharedTests/Toolkit/HTTP/HTTPResponseTests.swift b/Tests/SharedTests/Toolkit/HTTP/HTTPResponseTests.swift new file mode 100644 index 0000000000..aff97b0153 --- /dev/null +++ b/Tests/SharedTests/Toolkit/HTTP/HTTPResponseTests.swift @@ -0,0 +1,67 @@ +// +// Copyright 2026 Readium Foundation. All rights reserved. +// Use of this source code is governed by the BSD-style license +// available in the top-level LICENSE file of the project. +// + +import Foundation +@testable import ReadiumShared +import Testing + +@Suite("HTTPResponse") +struct HTTPResponseTests { + private let request = HTTPRequest(url: HTTPURL(string: "http://example.com")!) + private let url = HTTPURL(string: "http://example.com")! + + @Test func valueForHeader() { + let response = HTTPResponse( + request: request, + url: url, + status: .ok, + headers: ["Content-Type": "application/pdf", "X-Custom": "Value"], + mediaType: .pdf + ) + + #expect(response.valueForHeader("Content-Type") == "application/pdf") + #expect(response.valueForHeader("content-type") == "application/pdf") + #expect(response.valueForHeader("X-Custom") == "Value") + #expect(response.valueForHeader("Unknown") == nil) + } + + @Test func acceptsByteRanges() { + var response = HTTPResponse(request: request, url: url, status: .ok, headers: ["Accept-Ranges": "bytes"], mediaType: nil) + #expect(response.acceptsByteRanges) + + response = HTTPResponse(request: request, url: url, status: .ok, headers: ["Content-Range": "bytes 0-100/1000"], mediaType: nil) + #expect(response.acceptsByteRanges) + + response = HTTPResponse(request: request, url: url, status: .ok, headers: [:], mediaType: nil) + #expect(!response.acceptsByteRanges) + } + + @Test func contentLength() { + let response = HTTPResponse(request: request, url: url, status: .ok, headers: ["Content-Length": "1024"], mediaType: nil) + #expect(response.contentLength == 1024) + + let responseInvalid = HTTPResponse(request: request, url: url, status: .ok, headers: ["Content-Length": "invalid"], mediaType: nil) + #expect(responseInvalid.contentLength == nil) + } + + @Test func filename() { + var response = HTTPResponse(request: request, url: url, status: .ok, headers: ["Content-Disposition": "attachment; filename=book.epub"], mediaType: nil) + #expect(response.filename == "book.epub") + + response = HTTPResponse(request: request, url: url, status: .ok, headers: ["Content-Disposition": "filename=image.png"], mediaType: nil) + #expect(response.filename == "image.png") + + response = HTTPResponse(request: request, url: url, status: .ok, headers: ["Content-Disposition": "inline"], mediaType: nil) + #expect(response.filename == nil) + + response = HTTPResponse(request: request, url: url, status: .ok, headers: ["Content-Disposition": "attachment; filename*=UTF-8''%e2%82%ac%20rates; filename=fallback.txt"], mediaType: nil) + #expect(response.filename == "€ rates") + + // Malformed UTF-8 in filename* should fall back to filename + response = HTTPResponse(request: request, url: url, status: .ok, headers: ["Content-Disposition": "attachment; filename*=UTF-8''%FF%FF; filename=fallback.txt"], mediaType: nil) + #expect(response.filename == "fallback.txt") + } +} diff --git a/Tests/SharedTests/Toolkit/HTTP/MockHTTPURLProtocol.swift b/Tests/SharedTests/Toolkit/HTTP/MockHTTPURLProtocol.swift new file mode 100644 index 0000000000..403312829b --- /dev/null +++ b/Tests/SharedTests/Toolkit/HTTP/MockHTTPURLProtocol.swift @@ -0,0 +1,160 @@ +// +// Copyright 2026 Readium Foundation. All rights reserved. +// Use of this source code is governed by the BSD-style license +// available in the top-level LICENSE file of the project. +// + +import Foundation + +/// A `URLProtocol` subclass that intercepts HTTP requests for testing +/// `DefaultHTTPClient` without hitting the network. +/// +/// Configure the static `requestHandler` before each test to control +/// the response returned for intercepted requests. +final class MockHTTPURLProtocol: URLProtocol { + /// Handler called for each intercepted request. Returns the response + /// configuration to simulate. + /// + /// Must be set before starting a request. + nonisolated(unsafe) static var requestHandler: ((URLRequest) -> MockResponse)? + + /// Describes a mock response to return for an intercepted request. + indirect enum MockResponse { + /// A successful response delivered as a sequence of data chunks. + case success( + statusCode: Int = 200, + headers: [String: String] = [:], + chunks: [Data] = [] + ) + + /// A simulated network error. + case error(URLError) + + /// A response that is delivered after a delay, useful for timeout + /// testing. Delivery is cancelled early if `stopLoading()` is called. + case delayed( + seconds: TimeInterval, + then: MockResponse + ) + + /// An authentication challenge, followed by a response if the + /// challenge is resolved. + case authenticationChallenge( + host: String = "example.com", + method: String = NSURLAuthenticationMethodHTTPBasic, + then: MockResponse + ) + + /// Convenience for a simple success response with a single body. + static func success( + statusCode: Int = 200, + headers: [String: String] = [:], + body: Data = Data() + ) -> MockResponse { + .success(statusCode: statusCode, headers: headers, chunks: [body]) + } + } + + // MARK: - URLProtocol + + override class func canInit(with request: URLRequest) -> Bool { + true + } + + override class func canonicalRequest(for request: URLRequest) -> URLRequest { + request + } + + /// Set to `true` by `stopLoading()` so that an in-progress `.delayed` + /// delivery can abort early rather than blocking until the full interval + /// elapses. + private nonisolated(unsafe) var isStopped = false + + override func startLoading() { + guard let handler = Self.requestHandler else { + fatalError("MockHTTPURLProtocol.requestHandler is not set.") + } + + let mockResponse = handler(request) + deliver(mockResponse) + } + + override func stopLoading() { + isStopped = true + } + + // MARK: - Private + + private func deliver(_ response: MockResponse) { + switch response { + case let .success(statusCode, headers, chunks): + deliverSuccess(statusCode: statusCode, headers: headers, chunks: chunks) + + case let .error(urlError): + client?.urlProtocol(self, didFailWithError: urlError) + + case let .delayed(seconds, then): + // Poll in short intervals so that `stopLoading()` can interrupt + // the wait without blocking the thread for the full duration. + let pollInterval: TimeInterval = 0.05 + var elapsed: TimeInterval = 0 + while elapsed < seconds, !isStopped { + Thread.sleep(forTimeInterval: pollInterval) + elapsed += pollInterval + } + if !isStopped { + deliver(then) + } + + case let .authenticationChallenge(host, method, then): + let protectionSpace = URLProtectionSpace( + host: host, + port: 443, + protocol: "https", + realm: "Test", + authenticationMethod: method + ) + let challenge = URLAuthenticationChallenge( + protectionSpace: protectionSpace, + proposedCredential: nil, + previousFailureCount: 0, + failureResponse: nil, + error: nil, + sender: MockAuthChallengeSender() + ) + client?.urlProtocol(self, didReceive: challenge) + + deliver(then) + } + } + + private func deliverSuccess(statusCode: Int, headers: [String: String], chunks: [Data]) { + guard let url = request.url, + let response = HTTPURLResponse( + url: url, + statusCode: statusCode, + httpVersion: "HTTP/1.1", + headerFields: headers + ) + else { + client?.urlProtocol(self, didFailWithError: URLError(.badServerResponse)) + return + } + + client?.urlProtocol(self, didReceive: response, cacheStoragePolicy: .notAllowed) + + for chunk in chunks { + client?.urlProtocol(self, didLoad: chunk) + } + + client?.urlProtocolDidFinishLoading(self) + } +} + +/// A minimal `URLAuthenticationChallengeSender` implementation required +/// to construct `URLAuthenticationChallenge` objects in tests. +private class MockAuthChallengeSender: NSObject, URLAuthenticationChallengeSender { + func use(_ credential: URLCredential, for challenge: URLAuthenticationChallenge) {} + func continueWithoutCredential(for challenge: URLAuthenticationChallenge) {} + func cancel(_ challenge: URLAuthenticationChallenge) {} +}