From 252b26de135a7a4ac3f9c6d1ab352ee323c5cd74 Mon Sep 17 00:00:00 2001 From: Tim Perry Date: Thu, 26 Feb 2026 10:55:16 +0100 Subject: [PATCH] Extract client-side handshake into separate components --- doc/ws.md | 39 +++ index.js | 2 + lib/handshake-request.js | 205 +++++++++++ lib/handshake-validator.js | 113 ++++++ lib/websocket.js | 257 +++----------- test/handshake-request.test.js | 572 +++++++++++++++++++++++++++++++ test/handshake-validator.test.js | 297 ++++++++++++++++ 7 files changed, 1275 insertions(+), 210 deletions(-) create mode 100644 lib/handshake-request.js create mode 100644 lib/handshake-validator.js create mode 100644 test/handshake-request.test.js create mode 100644 test/handshake-validator.test.js diff --git a/doc/ws.md b/doc/ws.md index a9c3a2fd5..012a5285b 100644 --- a/doc/ws.md +++ b/doc/ws.md @@ -48,6 +48,8 @@ - [websocket.send(data[, options][, callback])](#websocketsenddata-options-callback) - [websocket.terminate()](#websocketterminate) - [websocket.url](#websocketurl) +- [Class: HandshakeRequest](#class-handshakerequest) +- [Class: HandshakeValidator](#class-handshakevalidator) - [createWebSocketStream(websocket[, options])](#createwebsocketstreamwebsocket-options) - [Environment variables](#environment-variables) - [WS_NO_BUFFER_UTIL](#ws_no_buffer_util) @@ -318,6 +320,12 @@ This class represents a WebSocket. It extends the `EventEmitter`. takes a `Buffer` that must be filled synchronously and is called before a message is sent, for each message. By default, the buffer is filled with cryptographically strong random bytes. + - `handshakeRequest` {HandshakeRequest} A + [`HandshakeRequest`](#class-handshakerequest) instance used to build the + HTTP upgrade request. + - `handshakeValidator` {HandshakeValidator} A + [`HandshakeValidator`](#class-handshakevalidator) instance used to validate + the server's upgrade response. - `handshakeTimeout` {Number} Timeout in milliseconds for the handshake request. This is reset after every redirection. - `maxPayload` {Number} The maximum allowed message size in bytes. Defaults to @@ -628,6 +636,37 @@ Forcibly close the connection. Internally, this calls [`socket.destroy()`][]. The URL of the WebSocket server. Server clients don't have this attribute. +## Class: HandshakeRequest + +This class builds the HTTP upgrade request for a WebSocket client handshake. + +### handshakeRequest.build(address, protocols, opts[, extensionOfferHeader]) + +- `address` {String|url.URL} The URL to connect to. +- `protocols` {Array} The subprotocols. +- `opts` {Object} Connection options (as passed to the `WebSocket` constructor). +- `extensionOfferHeader` {String} The `Sec-WebSocket-Extensions` header value. + +Build the handshake request. Returns an object containing `parsedUrl` {url.URL}, +`key` {String}, `protocolSet` {Set}, and request options suitable for +`http.request()` / `https.request()` (`host`, `port`, `path`, `headers`, etc.). + +## Class: HandshakeValidator + +This class validates a WebSocket server's upgrade response. + +### handshakeValidator.validate(res, key, protocolSet[, perMessageDeflate]) + +- `res` {http.IncomingMessage} The HTTP upgrade response. +- `key` {String} The `Sec-WebSocket-Key` that was sent. +- `protocolSet` {Set} The subprotocols that were offered. +- `perMessageDeflate` {Object} The `PerMessageDeflate` instance, or `null`. +- Returns: {Object} An object with `protocol` {String} and `extensions` {Object} + properties. + +Validate the server's upgrade response and return the negotiated protocol and +extensions. Throws an `Error` if validation fails. + ## createWebSocketStream(websocket[, options]) - `websocket` {WebSocket} A `WebSocket` object. diff --git a/index.js b/index.js index 41edb3b81..e2bc80832 100644 --- a/index.js +++ b/index.js @@ -6,6 +6,8 @@ WebSocket.createWebSocketStream = require('./lib/stream'); WebSocket.Server = require('./lib/websocket-server'); WebSocket.Receiver = require('./lib/receiver'); WebSocket.Sender = require('./lib/sender'); +WebSocket.HandshakeRequest = require('./lib/handshake-request'); +WebSocket.HandshakeValidator = require('./lib/handshake-validator'); WebSocket.WebSocket = WebSocket; WebSocket.WebSocketServer = WebSocket.Server; diff --git a/lib/handshake-request.js b/lib/handshake-request.js new file mode 100644 index 000000000..a4c5f64c9 --- /dev/null +++ b/lib/handshake-request.js @@ -0,0 +1,205 @@ +'use strict'; + +const { randomBytes } = require('crypto'); +const { URL } = require('url'); + +const subprotocolRegex = /^[!#$%&'*+\-.0-9A-Z^_`|a-z~]+$/; + +/** + * Builds the HTTP request for a WebSocket handshake. + * Individual methods can be subclassed to customize behavior. + */ +class HandshakeRequest { + /** + * @param {(String|URL)} address The URL to connect to + * @param {Array} protocols The subprotocols + * @param {Object} opts Options object + * @param {String} [extensionOfferHeader] The Sec-WebSocket-Extensions value + * @return {Object} An object with `parsedUrl`, `key`, `protocolSet`, and + * fields suitable for `http.request()` / `https.request()` + */ + build(address, protocols, opts, extensionOfferHeader) { + const parsedUrl = this.parseUrl(address); + this.validateUrl(parsedUrl); + + const key = this.generateKey(); + const protocolSet = this.buildProtocolSet(protocols); + + const headers = {}; + + // + // User headers first, then WS protocol headers (which take precedence). + // + if (opts.headers) { + Object.assign(headers, opts.headers); + } + + headers['Sec-WebSocket-Version'] = String(opts.protocolVersion); + headers['Sec-WebSocket-Key'] = key; + headers['Connection'] = 'Upgrade'; + headers['Upgrade'] = 'websocket'; + + if (extensionOfferHeader) { + headers['Sec-WebSocket-Extensions'] = extensionOfferHeader; + } + if (protocolSet.size) { + headers['Sec-WebSocket-Protocol'] = protocols.join(','); + } + if (opts.origin) { + if (opts.protocolVersion < 13) { + headers['Sec-WebSocket-Origin'] = opts.origin; + } else { + headers['Origin'] = opts.origin; + } + } + + const isSecure = parsedUrl.protocol === 'wss:'; + const isIpcUrl = parsedUrl.protocol === 'ws+unix:'; + const defaultPort = isSecure ? 443 : 80; + + const host = parsedUrl.hostname.startsWith('[') + ? parsedUrl.hostname.slice(1, -1) + : parsedUrl.hostname; + + const path = parsedUrl.pathname + parsedUrl.search; + + const result = { + parsedUrl, + key, + protocolSet, + host, + port: parsedUrl.port || defaultPort, + path, + defaultPort, + timeout: opts.handshakeTimeout, + headers + }; + + // + // Handle auth. URL credentials take precedence over opts.auth. + // Node.js http.request() generates the Authorization header from + // opts.auth. + // + if (parsedUrl.username || parsedUrl.password) { + result.auth = `${parsedUrl.username}:${parsedUrl.password}`; + } else if (opts.auth) { + result.auth = opts.auth; + } + + if (isIpcUrl) { + const parts = path.split(':'); + + result.socketPath = parts[0]; + result.path = parts[1]; + } + + return result; + } + + parseUrl(address) { + let parsedUrl; + + if (address instanceof URL) { + parsedUrl = address; + } else { + try { + parsedUrl = new URL(address); + } catch (e) { + throw new SyntaxError(`Invalid URL: ${address}`); + } + } + + if (parsedUrl.protocol === 'http:') { + parsedUrl.protocol = 'ws:'; + } else if (parsedUrl.protocol === 'https:') { + parsedUrl.protocol = 'wss:'; + } + + return parsedUrl; + } + + validateUrl(parsedUrl) { + const isSecure = parsedUrl.protocol === 'wss:'; + const isIpcUrl = parsedUrl.protocol === 'ws+unix:'; + let message; + + if (parsedUrl.protocol !== 'ws:' && !isSecure && !isIpcUrl) { + message = + 'The URL\'s protocol must be one of "ws:", "wss:", ' + + '"http:", "https:", or "ws+unix:"'; + } else if (isIpcUrl && !parsedUrl.pathname) { + message = "The URL's pathname is empty"; + } else if (parsedUrl.hash) { + message = 'The URL contains a fragment identifier'; + } + + if (message) throw new SyntaxError(message); + } + + generateKey() { + return randomBytes(16).toString('base64'); + } + + initRedirectOptions(options) { + // + // Shallow copy the user provided options so that headers can be changed + // without mutating the original object. + // + const headers = options.headers; + options = { ...options, headers: {} }; + + if (headers) { + for (const [key, value] of Object.entries(headers)) { + options.headers[key.toLowerCase()] = value; + } + } + + return options; + } + + stripRedirectAuth(opts, isSameHost) { + // + // Match curl 7.77.0 behavior and drop the following headers. These + // headers are also dropped when following a redirect to a subdomain. + // + delete opts.headers.authorization; + delete opts.headers.cookie; + + if (!isSameHost) delete opts.headers.host; + + opts.auth = undefined; + } + + injectAuthHeader(headers, auth) { + // + // Match curl 7.77.0 behavior and make the first `Authorization` header win. + // If the `Authorization` header is set, then there is nothing to do as it + // will take precedence. + // + if (auth && !headers.authorization) { + headers.authorization = 'Basic ' + Buffer.from(auth).toString('base64'); + } + } + + buildProtocolSet(protocols) { + const protocolSet = new Set(); + + for (const protocol of protocols) { + if ( + typeof protocol !== 'string' || + !subprotocolRegex.test(protocol) || + protocolSet.has(protocol) + ) { + throw new SyntaxError( + 'An invalid or duplicated subprotocol was specified' + ); + } + + protocolSet.add(protocol); + } + + return protocolSet; + } +} + +module.exports = HandshakeRequest; diff --git a/lib/handshake-validator.js b/lib/handshake-validator.js new file mode 100644 index 000000000..86912b245 --- /dev/null +++ b/lib/handshake-validator.js @@ -0,0 +1,113 @@ +'use strict'; + +const { createHash } = require('crypto'); + +const { GUID } = require('./constants'); +const { parse } = require('./extension'); +const PerMessageDeflate = require('./permessage-deflate'); + +/** + * Validates a WebSocket upgrade response. Subclass and override individual + * methods to customize validation behavior. + */ +class HandshakeValidator { + /** + * @param {Object} res The HTTP upgrade response + * @param {String} key The `Sec-WebSocket-Key` that was sent + * @param {Set} protocolSet The subprotocols that were offered + * @param {Object} [perMessageDeflate] The PerMessageDeflate instance, if any + * @return {{ protocol: String, extensions: Object }} + */ + validate(res, key, protocolSet, perMessageDeflate) { + this.validateUpgrade(res); + this.validateAcceptKey(res.headers['sec-websocket-accept'], key); + + const protocol = this.validateSubprotocol( + res.headers['sec-websocket-protocol'], + protocolSet + ); + + const extensions = this.validateExtensions( + res.headers['sec-websocket-extensions'], + perMessageDeflate + ); + + return { protocol, extensions }; + } + + validateUpgrade(res) { + const upgrade = res.headers.upgrade; + + if (upgrade === undefined || upgrade.toLowerCase() !== 'websocket') { + throw new Error('Invalid Upgrade header'); + } + } + + validateAcceptKey(actual, key) { + const expected = createHash('sha1') + .update(key + GUID) + .digest('base64'); + + if (actual !== expected) { + throw new Error('Invalid Sec-WebSocket-Accept header'); + } + } + + validateSubprotocol(serverProt, protocolSet) { + if (serverProt !== undefined) { + if (!protocolSet.size) { + throw new Error('Server sent a subprotocol but none was requested'); + } + + if (!protocolSet.has(serverProt)) { + throw new Error('Server sent an invalid subprotocol'); + } + + return serverProt; + } + + if (protocolSet.size) { + throw new Error('Server sent no subprotocol'); + } + + return ''; + } + + validateExtensions(headerValue, perMessageDeflate) { + if (headerValue === undefined) return {}; + + if (!perMessageDeflate) { + throw new Error( + 'Server sent a Sec-WebSocket-Extensions header but no extension ' + + 'was requested' + ); + } + + let extensions; + + try { + extensions = parse(headerValue); + } catch (err) { + throw new Error('Invalid Sec-WebSocket-Extensions header'); + } + + const extensionNames = Object.keys(extensions); + + if ( + extensionNames.length !== 1 || + extensionNames[0] !== PerMessageDeflate.extensionName + ) { + throw new Error('Server indicated an extension that was not requested'); + } + + try { + perMessageDeflate.accept(extensions[PerMessageDeflate.extensionName]); + } catch (err) { + throw new Error('Invalid Sec-WebSocket-Extensions header'); + } + + return { [PerMessageDeflate.extensionName]: perMessageDeflate }; + } +} + +module.exports = HandshakeValidator; diff --git a/lib/websocket.js b/lib/websocket.js index ca2e1ad80..458a4bf63 100644 --- a/lib/websocket.js +++ b/lib/websocket.js @@ -7,11 +7,12 @@ const https = require('https'); const http = require('http'); const net = require('net'); const tls = require('tls'); -const { randomBytes, createHash } = require('crypto'); const { Duplex, Readable } = require('stream'); const { URL } = require('url'); +const HandshakeValidator = require('./handshake-validator'); const PerMessageDeflate = require('./permessage-deflate'); +const HandshakeRequest = require('./handshake-request'); const Receiver = require('./receiver'); const Sender = require('./sender'); const { isBlob } = require('./validation'); @@ -20,7 +21,6 @@ const { BINARY_TYPES, CLOSE_TIMEOUT, EMPTY_BUFFER, - GUID, kForOnEventAttribute, kListener, kStatusCode, @@ -30,13 +30,12 @@ const { const { EventTarget: { addEventListener, removeEventListener } } = require('./event-target'); -const { format, parse } = require('./extension'); +const { format } = require('./extension'); const { toBuffer } = require('./buffer-util'); const kAborted = Symbol('kAborted'); const protocolVersions = [8, 13]; const readyStates = ['CONNECTING', 'OPEN', 'CLOSING', 'CLOSED']; -const subprotocolRegex = /^[!#$%&'*+\-.0-9A-Z^_`|a-z~]+$/; /** * Class representing a WebSocket. @@ -686,73 +685,12 @@ function initAsClient(websocket, address, protocols, options) { ); } - let parsedUrl; + const handshakeValidator = + opts.handshakeValidator || new HandshakeValidator(); + const handshakeRequest = opts.handshakeRequest || new HandshakeRequest(); - if (address instanceof URL) { - parsedUrl = address; - } else { - try { - parsedUrl = new URL(address); - } catch { - throw new SyntaxError(`Invalid URL: ${address}`); - } - } - - if (parsedUrl.protocol === 'http:') { - parsedUrl.protocol = 'ws:'; - } else if (parsedUrl.protocol === 'https:') { - parsedUrl.protocol = 'wss:'; - } - - websocket._url = parsedUrl.href; - - const isSecure = parsedUrl.protocol === 'wss:'; - const isIpcUrl = parsedUrl.protocol === 'ws+unix:'; - let invalidUrlMessage; - - if (parsedUrl.protocol !== 'ws:' && !isSecure && !isIpcUrl) { - invalidUrlMessage = - 'The URL\'s protocol must be one of "ws:", "wss:", ' + - '"http:", "https:", or "ws+unix:"'; - } else if (isIpcUrl && !parsedUrl.pathname) { - invalidUrlMessage = "The URL's pathname is empty"; - } else if (parsedUrl.hash) { - invalidUrlMessage = 'The URL contains a fragment identifier'; - } - - if (invalidUrlMessage) { - const err = new SyntaxError(invalidUrlMessage); - - if (websocket._redirects === 0) { - throw err; - } else { - emitErrorAndClose(websocket, err); - return; - } - } - - const defaultPort = isSecure ? 443 : 80; - const key = randomBytes(16).toString('base64'); - const request = isSecure ? https.request : http.request; - const protocolSet = new Set(); let perMessageDeflate; - - opts.createConnection = - opts.createConnection || (isSecure ? tlsConnect : netConnect); - opts.defaultPort = opts.defaultPort || defaultPort; - opts.port = parsedUrl.port || defaultPort; - opts.host = parsedUrl.hostname.startsWith('[') - ? parsedUrl.hostname.slice(1, -1) - : parsedUrl.hostname; - opts.headers = { - ...opts.headers, - 'Sec-WebSocket-Version': opts.protocolVersion, - 'Sec-WebSocket-Key': key, - Connection: 'Upgrade', - Upgrade: 'websocket' - }; - opts.path = parsedUrl.pathname + parsedUrl.search; - opts.timeout = opts.handshakeTimeout; + let extensionOfferHeader; if (opts.perMessageDeflate) { perMessageDeflate = new PerMessageDeflate( @@ -760,44 +698,38 @@ function initAsClient(websocket, address, protocols, options) { false, opts.maxPayload ); - opts.headers['Sec-WebSocket-Extensions'] = format({ + extensionOfferHeader = format({ [PerMessageDeflate.extensionName]: perMessageDeflate.offer() }); } - if (protocols.length) { - for (const protocol of protocols) { - if ( - typeof protocol !== 'string' || - !subprotocolRegex.test(protocol) || - protocolSet.has(protocol) - ) { - throw new SyntaxError( - 'An invalid or duplicated subprotocol was specified' - ); - } - protocolSet.add(protocol); - } + let handshake; - opts.headers['Sec-WebSocket-Protocol'] = protocols.join(','); - } - if (opts.origin) { - if (opts.protocolVersion < 13) { - opts.headers['Sec-WebSocket-Origin'] = opts.origin; - } else { - opts.headers.Origin = opts.origin; - } - } - if (parsedUrl.username || parsedUrl.password) { - opts.auth = `${parsedUrl.username}:${parsedUrl.password}`; + try { + handshake = handshakeRequest.build( + address, + protocols, + opts, + extensionOfferHeader + ); + } catch (err) { + if (websocket._redirects === 0) throw err; + emitErrorAndClose(websocket, err); + return; } - if (isIpcUrl) { - const parts = opts.path.split(':'); + const { parsedUrl, key, protocolSet, ...requestOptions } = handshake; - opts.socketPath = parts[0]; - opts.path = parts[1]; - } + websocket._url = parsedUrl.href; + + const isSecure = parsedUrl.protocol === 'wss:'; + const isIpcUrl = parsedUrl.protocol === 'ws+unix:'; + + opts.createConnection = + opts.createConnection || (isSecure ? tlsConnect : netConnect); + const httpRequest = isSecure ? https.request : http.request; + + Object.assign(opts, requestOptions); let req; @@ -809,19 +741,7 @@ function initAsClient(websocket, address, protocols, options) { ? opts.socketPath : parsedUrl.host; - const headers = options && options.headers; - - // - // Shallow copy the user provided options so that headers can be changed - // without mutating the original object. - // - options = { ...options, headers: {} }; - - if (headers) { - for (const [key, value] of Object.entries(headers)) { - options.headers[key.toLowerCase()] = value; - } - } + options = handshakeRequest.initRedirectOptions(options || {}); } else if (websocket.listenerCount('redirect') === 0) { const isSameHost = isIpcUrl ? websocket._originalIpc @@ -832,30 +752,13 @@ function initAsClient(websocket, address, protocols, options) { : parsedUrl.host === websocket._originalHostOrSocketPath; if (!isSameHost || (websocket._originalSecure && !isSecure)) { - // - // Match curl 7.77.0 behavior and drop the following headers. These - // headers are also dropped when following a redirect to a subdomain. - // - delete opts.headers.authorization; - delete opts.headers.cookie; - - if (!isSameHost) delete opts.headers.host; - - opts.auth = undefined; + handshakeRequest.stripRedirectAuth(opts, isSameHost); } } - // - // Match curl 7.77.0 behavior and make the first `Authorization` header win. - // If the `Authorization` header is set, then there is nothing to do as it - // will take precedence. - // - if (opts.auth && !options.headers.authorization) { - options.headers.authorization = - 'Basic ' + Buffer.from(opts.auth).toString('base64'); - } + handshakeRequest.injectAuthHeader(options.headers, opts.auth); - req = websocket._req = request(opts); + req = websocket._req = httpRequest(opts); if (websocket._redirects) { // @@ -870,7 +773,7 @@ function initAsClient(websocket, address, protocols, options) { websocket.emit('redirect', websocket.url, req); } } else { - req = websocket._req = request(opts); + req = websocket._req = httpRequest(opts); } if (opts.timeout) { @@ -934,86 +837,20 @@ function initAsClient(websocket, address, protocols, options) { req = websocket._req = null; - const upgrade = res.headers.upgrade; - - if (upgrade === undefined || upgrade.toLowerCase() !== 'websocket') { - abortHandshake(websocket, socket, 'Invalid Upgrade header'); - return; - } - - const digest = createHash('sha1') - .update(key + GUID) - .digest('base64'); - - if (res.headers['sec-websocket-accept'] !== digest) { - abortHandshake(websocket, socket, 'Invalid Sec-WebSocket-Accept header'); - return; - } - - const serverProt = res.headers['sec-websocket-protocol']; - let protError; - - if (serverProt !== undefined) { - if (!protocolSet.size) { - protError = 'Server sent a subprotocol but none was requested'; - } else if (!protocolSet.has(serverProt)) { - protError = 'Server sent an invalid subprotocol'; - } - } else if (protocolSet.size) { - protError = 'Server sent no subprotocol'; - } - - if (protError) { - abortHandshake(websocket, socket, protError); + try { + const { protocol, extensions } = handshakeValidator.validate( + res, + key, + protocolSet, + perMessageDeflate + ); + websocket._protocol = protocol; + websocket._extensions = extensions; + } catch (err) { + abortHandshake(websocket, socket, err.message); return; } - if (serverProt) websocket._protocol = serverProt; - - const secWebSocketExtensions = res.headers['sec-websocket-extensions']; - - if (secWebSocketExtensions !== undefined) { - if (!perMessageDeflate) { - const message = - 'Server sent a Sec-WebSocket-Extensions header but no extension ' + - 'was requested'; - abortHandshake(websocket, socket, message); - return; - } - - let extensions; - - try { - extensions = parse(secWebSocketExtensions); - } catch (err) { - const message = 'Invalid Sec-WebSocket-Extensions header'; - abortHandshake(websocket, socket, message); - return; - } - - const extensionNames = Object.keys(extensions); - - if ( - extensionNames.length !== 1 || - extensionNames[0] !== PerMessageDeflate.extensionName - ) { - const message = 'Server indicated an extension that was not requested'; - abortHandshake(websocket, socket, message); - return; - } - - try { - perMessageDeflate.accept(extensions[PerMessageDeflate.extensionName]); - } catch (err) { - const message = 'Invalid Sec-WebSocket-Extensions header'; - abortHandshake(websocket, socket, message); - return; - } - - websocket._extensions[PerMessageDeflate.extensionName] = - perMessageDeflate; - } - websocket.setSocket(socket, head, { allowSynchronousEvents: opts.allowSynchronousEvents, generateMask: opts.generateMask, diff --git a/test/handshake-request.test.js b/test/handshake-request.test.js new file mode 100644 index 000000000..8bac8925c --- /dev/null +++ b/test/handshake-request.test.js @@ -0,0 +1,572 @@ +'use strict'; + +const assert = require('assert'); +const { URL } = require('url'); + +const HandshakeRequest = require('../lib/handshake-request'); +const WebSocket = require('..'); + +describe('HandshakeRequest', () => { + function makeOpts(overrides = {}) { + return { + protocolVersion: 13, + maxPayload: 100 * 1024 * 1024, + headers: {}, + ...overrides + }; + } + + describe('#parseUrl', () => { + it('parses a string URL', () => { + const r = new HandshakeRequest(); + const url = r.parseUrl('ws://example.com/path'); + + assert.strictEqual(url.protocol, 'ws:'); + assert.strictEqual(url.hostname, 'example.com'); + assert.strictEqual(url.pathname, '/path'); + }); + + it('accepts a URL object', () => { + const r = new HandshakeRequest(); + const input = new URL('ws://example.com'); + const url = r.parseUrl(input); + + assert.strictEqual(url, input); + }); + + it('normalizes http: to ws:', () => { + const r = new HandshakeRequest(); + const url = r.parseUrl('http://example.com'); + + assert.strictEqual(url.protocol, 'ws:'); + }); + + it('normalizes https: to wss:', () => { + const r = new HandshakeRequest(); + const url = r.parseUrl('https://example.com'); + + assert.strictEqual(url.protocol, 'wss:'); + }); + + it('throws on invalid URL string', () => { + const r = new HandshakeRequest(); + + assert.throws(() => r.parseUrl('not a url'), { + name: 'SyntaxError', + message: 'Invalid URL: not a url' + }); + }); + }); + + describe('#validateUrl', () => { + it('accepts ws: URLs', () => { + const r = new HandshakeRequest(); + + assert.doesNotThrow(() => r.validateUrl(new URL('ws://example.com'))); + }); + + it('accepts wss: URLs', () => { + const r = new HandshakeRequest(); + + assert.doesNotThrow(() => r.validateUrl(new URL('wss://example.com'))); + }); + + it('rejects unsupported protocols', () => { + const r = new HandshakeRequest(); + + assert.throws(() => r.validateUrl(new URL('ftp://example.com')), { + name: 'SyntaxError', + message: /protocol must be one of/ + }); + }); + + it('rejects URLs with fragment identifiers', () => { + const r = new HandshakeRequest(); + + assert.throws(() => r.validateUrl(new URL('ws://example.com#frag')), { + message: 'The URL contains a fragment identifier' + }); + }); + }); + + describe('#generateKey', () => { + it('returns a base64-encoded string', () => { + const r = new HandshakeRequest(); + const key = r.generateKey(); + + assert.strictEqual(typeof key, 'string'); + assert.strictEqual(Buffer.from(key, 'base64').length, 16); + }); + + it('returns different values each time', () => { + const r = new HandshakeRequest(); + + assert.notStrictEqual(r.generateKey(), r.generateKey()); + }); + }); + + describe('#buildProtocolSet', () => { + it('returns an empty set for no protocols', () => { + const r = new HandshakeRequest(); + const set = r.buildProtocolSet([]); + + assert.strictEqual(set.size, 0); + }); + + it('returns a set of valid protocols', () => { + const r = new HandshakeRequest(); + const set = r.buildProtocolSet(['foo', 'bar']); + + assert.strictEqual(set.size, 2); + assert.ok(set.has('foo')); + assert.ok(set.has('bar')); + }); + + it('throws on duplicate protocols', () => { + const r = new HandshakeRequest(); + + assert.throws(() => r.buildProtocolSet(['foo', 'foo']), { + name: 'SyntaxError', + message: 'An invalid or duplicated subprotocol was specified' + }); + }); + + it('throws on invalid protocol characters', () => { + const r = new HandshakeRequest(); + + assert.throws(() => r.buildProtocolSet(['foo bar']), { + name: 'SyntaxError' + }); + }); + + it('throws on non-string protocols', () => { + const r = new HandshakeRequest(); + + assert.throws(() => r.buildProtocolSet([123]), { + name: 'SyntaxError' + }); + }); + }); + + describe('#build', () => { + it('returns parsedUrl, key, and protocolSet', () => { + const r = new HandshakeRequest(); + const opts = makeOpts(); + + const result = r.build('ws://example.com/path?q=1', [], opts); + + assert.strictEqual(result.parsedUrl.href, 'ws://example.com/path?q=1'); + assert.strictEqual(typeof result.key, 'string'); + assert.strictEqual(result.protocolSet.size, 0); + }); + + it('returns correct request options', () => { + const r = new HandshakeRequest(); + const opts = makeOpts(); + + const reqOpts = r.build('ws://example.com/path?q=1', [], opts); + + assert.strictEqual(reqOpts.host, 'example.com'); + assert.strictEqual(reqOpts.port, 80); + assert.strictEqual(reqOpts.path, '/path?q=1'); + assert.strictEqual(reqOpts.defaultPort, 80); + }); + + it('includes WS protocol headers in the headers object', () => { + const r = new HandshakeRequest(); + const opts = makeOpts(); + + const { key, headers } = r.build('ws://example.com', [], opts); + + assert.strictEqual(headers['Connection'], 'Upgrade'); + assert.strictEqual(headers['Upgrade'], 'websocket'); + assert.strictEqual(headers['Sec-WebSocket-Version'], '13'); + assert.strictEqual(headers['Sec-WebSocket-Key'], key); + }); + + it('uses port 443 for wss:', () => { + const r = new HandshakeRequest(); + const opts = makeOpts(); + + const reqOpts = r.build('wss://example.com', [], opts); + + assert.strictEqual(reqOpts.parsedUrl.protocol, 'wss:'); + assert.strictEqual(reqOpts.port, 443); + }); + + it('uses an explicit port when provided', () => { + const r = new HandshakeRequest(); + const opts = makeOpts(); + + const reqOpts = r.build('ws://example.com:9000', [], opts); + + assert.strictEqual(reqOpts.port, '9000'); + }); + + it('strips brackets from IPv6 hostnames', () => { + const r = new HandshakeRequest(); + const opts = makeOpts(); + + const reqOpts = r.build('ws://[::1]:8080/path', [], opts); + + assert.strictEqual(reqOpts.host, '::1'); + }); + + it('adds the extension offer header', () => { + const r = new HandshakeRequest(); + const opts = makeOpts(); + + const { headers } = r.build( + 'ws://example.com', + [], + opts, + 'permessage-deflate' + ); + + assert.strictEqual( + headers['Sec-WebSocket-Extensions'], + 'permessage-deflate' + ); + }); + + it('omits extension header when offer is empty', () => { + const r = new HandshakeRequest(); + const opts = makeOpts(); + + const { headers } = r.build('ws://example.com', [], opts); + + assert.strictEqual(headers['Sec-WebSocket-Extensions'], undefined); + }); + + it('sets the subprotocol header', () => { + const r = new HandshakeRequest(); + const opts = makeOpts(); + + const { headers, protocolSet } = r.build( + 'ws://example.com', + ['foo', 'bar'], + opts + ); + + assert.strictEqual(headers['Sec-WebSocket-Protocol'], 'foo,bar'); + assert.strictEqual(protocolSet.size, 2); + }); + + it('sets the Origin header for version >= 13', () => { + const r = new HandshakeRequest(); + const opts = makeOpts({ origin: 'http://example.com' }); + + const { headers } = r.build('ws://example.com', [], opts); + + assert.strictEqual(headers['Origin'], 'http://example.com'); + assert.strictEqual(headers['Sec-WebSocket-Origin'], undefined); + }); + + it('sets Sec-WebSocket-Origin for version < 13', () => { + const r = new HandshakeRequest(); + const opts = makeOpts({ + origin: 'http://example.com', + protocolVersion: 8 + }); + + const { headers } = r.build('ws://example.com', [], opts); + + assert.strictEqual(headers['Sec-WebSocket-Origin'], 'http://example.com'); + }); + + it('extracts auth from the URL', () => { + const r = new HandshakeRequest(); + const opts = makeOpts(); + + const reqOpts = r.build('ws://user:pass@example.com', [], opts); + + assert.strictEqual(reqOpts.auth, 'user:pass'); + }); + + it('extracts auth from opts.auth', () => { + const r = new HandshakeRequest(); + const opts = makeOpts({ auth: 'foo:bar' }); + + const reqOpts = r.build('ws://example.com', [], opts); + + assert.strictEqual(reqOpts.auth, 'foo:bar'); + }); + + it('URL auth takes precedence over opts.auth', () => { + const r = new HandshakeRequest(); + const opts = makeOpts({ auth: 'foo:bar' }); + + const reqOpts = r.build('ws://baz:qux@example.com', [], opts); + + assert.strictEqual(reqOpts.auth, 'baz:qux'); + }); + + it('maps handshakeTimeout to timeout', () => { + const r = new HandshakeRequest(); + const opts = makeOpts({ handshakeTimeout: 5000 }); + + const reqOpts = r.build('ws://example.com', [], opts); + + assert.strictEqual(reqOpts.timeout, 5000); + }); + + it('does not include auth when URL has no credentials', () => { + const r = new HandshakeRequest(); + const opts = makeOpts(); + + const reqOpts = r.build('ws://example.com', [], opts); + + assert.strictEqual(reqOpts.auth, undefined); + }); + + it('headers is a plain object', () => { + const r = new HandshakeRequest(); + const opts = makeOpts(); + + const { headers } = r.build('ws://example.com', [], opts); + + assert.strictEqual(typeof headers, 'object'); + assert.ok(!Array.isArray(headers)); + }); + + it('handles IPC URLs', () => { + const r = new HandshakeRequest(); + const opts = makeOpts(); + + const reqOpts = r.build('ws+unix:///tmp/sock:/path', [], opts); + + assert.strictEqual(reqOpts.socketPath, '/tmp/sock'); + assert.strictEqual(reqOpts.path, '/path'); + }); + }); + + describe('object headers (options.headers)', () => { + it('preserves user-supplied headers', () => { + const r = new HandshakeRequest(); + const opts = makeOpts({ headers: { 'X-Custom': 'value' } }); + + const { headers } = r.build('ws://example.com', [], opts); + + assert.strictEqual(headers['X-Custom'], 'value'); + }); + + it('WS protocol headers overwrite user headers with same name', () => { + const r = new HandshakeRequest(); + const opts = makeOpts({ headers: { Connection: 'keep-alive' } }); + + const { headers } = r.build('ws://example.com', [], opts); + + assert.strictEqual(headers['Connection'], 'Upgrade'); + }); + + it('handles array header values', () => { + const r = new HandshakeRequest(); + const opts = makeOpts({ + headers: { 'X-Multi': ['one', 'two'] } + }); + + const { headers } = r.build('ws://example.com', [], opts); + + assert.deepStrictEqual(headers['X-Multi'], ['one', 'two']); + }); + + it('coerces non-string values to strings', () => { + const r = new HandshakeRequest(); + const opts = makeOpts({ headers: { 'X-Num': 42 } }); + + const { headers } = r.build('ws://example.com', [], opts); + + assert.strictEqual(headers['X-Num'], 42); + }); + }); + + describe('subclassing', () => { + it('allows overriding generateKey', () => { + class FixedKeyRequest extends HandshakeRequest { + generateKey() { + return 'fixed-key-for-testing=='; + } + } + + const r = new FixedKeyRequest(); + const opts = makeOpts(); + + const { key, headers } = r.build('ws://example.com', [], opts); + + assert.strictEqual(key, 'fixed-key-for-testing=='); + assert.strictEqual( + headers['Sec-WebSocket-Key'], + 'fixed-key-for-testing==' + ); + }); + + it('allows overriding parseUrl', () => { + class RewritingRequest extends HandshakeRequest { + parseUrl(address) { + return super.parseUrl(address.replace('internal:', 'ws:')); + } + } + + const r = new RewritingRequest(); + const opts = makeOpts(); + + const { parsedUrl } = r.build('internal://example.com', [], opts); + + assert.strictEqual(parsedUrl.href, 'ws://example.com/'); + }); + + it('allows overriding validateUrl to accept custom protocols', () => { + class LenientRequest extends HandshakeRequest { + validateUrl() {} + } + + const r = new LenientRequest(); + const opts = makeOpts(); + + assert.doesNotThrow(() => r.build('ftp://example.com', [], opts)); + }); + }); + + describe('#initRedirectOptions', () => { + it('lowercases header keys', () => { + const r = new HandshakeRequest(); + const result = r.initRedirectOptions({ + headers: { + 'Content-Type': 'text/plain', + Authorization: 'Bearer token' + } + }); + + assert.deepStrictEqual(result, { + headers: { + 'content-type': 'text/plain', + authorization: 'Bearer token' + } + }); + }); + + it('returns an empty object when headers is undefined', () => { + const r = new HandshakeRequest(); + const result = r.initRedirectOptions({}); + assert.deepStrictEqual(result, { headers: {} }); + }); + + it('returns a new object', () => { + const r = new HandshakeRequest(); + const original = { headers: { 'X-Foo': 'bar' } }; + const result = r.initRedirectOptions(original); + + assert.notStrictEqual(result, original); + }); + }); + + describe('#stripRedirectAuth', () => { + it('deletes authorization and cookie headers', () => { + const r = new HandshakeRequest(); + const headers = { + authorization: 'Basic abc', + cookie: 'session=xyz', + host: 'example.com', + 'x-custom': 'value' + }; + + r.stripRedirectAuth({ headers }, true); + + assert.strictEqual(headers.authorization, undefined); + assert.strictEqual(headers.cookie, undefined); + assert.strictEqual(headers.host, 'example.com'); + assert.strictEqual(headers['x-custom'], 'value'); + }); + + it('also deletes host when not same host', () => { + const r = new HandshakeRequest(); + const headers = { + authorization: 'Basic abc', + cookie: 'session=xyz', + host: 'example.com', + 'x-custom': 'value' + }; + + r.stripRedirectAuth({ headers }, false); + + assert.strictEqual(headers.authorization, undefined); + assert.strictEqual(headers.cookie, undefined); + assert.strictEqual(headers.host, undefined); + assert.strictEqual(headers['x-custom'], 'value'); + }); + + it('does not throw when headers are missing', () => { + const r = new HandshakeRequest(); + const headers = { 'x-custom': 'value' }; + + assert.doesNotThrow(() => r.stripRedirectAuth({ headers }, false)); + }); + }); + + describe('#injectAuthHeader', () => { + it('sets the authorization header from auth string', () => { + const r = new HandshakeRequest(); + const headers = {}; + + r.injectAuthHeader(headers, 'user:pass'); + + assert.strictEqual( + headers.authorization, + 'Basic ' + Buffer.from('user:pass').toString('base64') + ); + }); + + it('does not overwrite an existing authorization header', () => { + const r = new HandshakeRequest(); + const headers = { authorization: 'Bearer existing' }; + + r.injectAuthHeader(headers, 'user:pass'); + + assert.strictEqual(headers.authorization, 'Bearer existing'); + }); + + it('does nothing when auth is undefined', () => { + const r = new HandshakeRequest(); + const headers = {}; + + r.injectAuthHeader(headers, undefined); + + assert.strictEqual(headers.authorization, undefined); + }); + + it('does nothing when auth is empty string', () => { + const r = new HandshakeRequest(); + const headers = {}; + + r.injectAuthHeader(headers, ''); + + assert.strictEqual(headers.authorization, undefined); + }); + }); + + describe('Integration with WebSocket', () => { + it('uses a custom HandshakeRequest', (done) => { + let called = false; + + class CustomRequest extends HandshakeRequest { + generateKey() { + called = true; + return super.generateKey(); + } + } + + const wss = new WebSocket.Server({ port: 0 }, () => { + const ws = new WebSocket(`ws://localhost:${wss.address().port}`, { + handshakeRequest: new CustomRequest() + }); + + ws.on('open', () => { + assert.ok(called); + ws.close(); + }); + + ws.on('close', () => wss.close(done)); + }); + }); + }); +}); diff --git a/test/handshake-validator.test.js b/test/handshake-validator.test.js new file mode 100644 index 000000000..3f196a954 --- /dev/null +++ b/test/handshake-validator.test.js @@ -0,0 +1,297 @@ +'use strict'; + +const assert = require('assert'); +const { createHash } = require('crypto'); + +const HandshakeValidator = require('../lib/handshake-validator'); +const { GUID } = require('../lib/constants'); +const WebSocket = require('..'); + +function computeAccept(key) { + return createHash('sha1') + .update(key + GUID) + .digest('base64'); +} + +describe('HandshakeValidator', () => { + const key = 'dGhlIHNhbXBsZSBub25jZQ=='; + const accept = computeAccept(key); + + function makeRes(overrides = {}) { + return { + headers: { + upgrade: 'websocket', + 'sec-websocket-accept': accept, + ...overrides + } + }; + } + + describe('#validate', () => { + it('accepts a valid handshake with no subprotocol or extensions', () => { + const v = new HandshakeValidator(); + const { protocol, extensions } = v.validate( + makeRes(), + key, + new Set(), + null + ); + + assert.strictEqual(protocol, ''); + assert.deepStrictEqual(extensions, {}); + }); + + it('throws on missing Upgrade header', () => { + const v = new HandshakeValidator(); + + assert.throws( + () => v.validate(makeRes({ upgrade: undefined }), key, new Set(), null), + { message: 'Invalid Upgrade header' } + ); + }); + + it('throws on wrong Upgrade header value', () => { + const v = new HandshakeValidator(); + + assert.throws( + () => v.validate(makeRes({ upgrade: 'http' }), key, new Set(), null), + { message: 'Invalid Upgrade header' } + ); + }); + + it('throws on invalid Sec-WebSocket-Accept', () => { + const v = new HandshakeValidator(); + + assert.throws( + () => + v.validate( + makeRes({ 'sec-websocket-accept': 'wrong' }), + key, + new Set(), + null + ), + { message: 'Invalid Sec-WebSocket-Accept header' } + ); + }); + + it('throws if server sends a subprotocol but none was requested', () => { + const v = new HandshakeValidator(); + + assert.throws( + () => + v.validate( + makeRes({ 'sec-websocket-protocol': 'foo' }), + key, + new Set(), + null + ), + { message: 'Server sent a subprotocol but none was requested' } + ); + }); + + it('throws if server sends an invalid subprotocol', () => { + const v = new HandshakeValidator(); + + assert.throws( + () => + v.validate( + makeRes({ 'sec-websocket-protocol': 'bar' }), + key, + new Set(['foo']), + null + ), + { message: 'Server sent an invalid subprotocol' } + ); + }); + + it('throws if server omits subprotocol when one was requested', () => { + const v = new HandshakeValidator(); + + assert.throws(() => v.validate(makeRes(), key, new Set(['foo']), null), { + message: 'Server sent no subprotocol' + }); + }); + + it('returns the matched subprotocol', () => { + const v = new HandshakeValidator(); + const { protocol } = v.validate( + makeRes({ 'sec-websocket-protocol': 'foo' }), + key, + new Set(['foo', 'bar']), + null + ); + + assert.strictEqual(protocol, 'foo'); + }); + + it('throws if server sends extensions but none were requested', () => { + const v = new HandshakeValidator(); + + assert.throws( + () => + v.validate( + makeRes({ 'sec-websocket-extensions': 'permessage-deflate' }), + key, + new Set(), + null + ), + { message: /no extension was requested/ } + ); + }); + + it('throws if server indicates an unrequested extension', () => { + const v = new HandshakeValidator(); + + assert.throws( + () => + v.validate( + makeRes({ 'sec-websocket-extensions': 'foo' }), + key, + new Set(), + {} + ), + { message: 'Server indicated an extension that was not requested' } + ); + }); + + it('wraps extension accept errors', () => { + const v = new HandshakeValidator(); + const fakeDeflate = { + accept() { + throw new Error('accept failure'); + } + }; + + assert.throws( + () => + v.validate( + makeRes({ + 'sec-websocket-extensions': 'permessage-deflate' + }), + key, + new Set(), + fakeDeflate + ), + { message: 'Invalid Sec-WebSocket-Extensions header' } + ); + }); + + it('wraps unparseable extensions header', () => { + const v = new HandshakeValidator(); + + assert.throws( + () => + v.validate( + makeRes({ + 'sec-websocket-extensions': 'permessage-deflate; =bad' + }), + key, + new Set(), + {} + ), + { message: 'Invalid Sec-WebSocket-Extensions header' } + ); + }); + }); + + describe('subclassing', () => { + it('allows overriding validateAcceptKey', () => { + let called = false; + + class CustomValidator extends HandshakeValidator { + validateAcceptKey(actual, k) { + called = true; + assert.strictEqual(actual, 'custom'); + assert.strictEqual(k, key); + } + } + + const v = new CustomValidator(); + + v.validate( + makeRes({ 'sec-websocket-accept': 'custom' }), + key, + new Set(), + null + ); + assert.ok(called); + }); + + it('allows overriding validateSubprotocol to be lenient', () => { + class LenientValidator extends HandshakeValidator { + validateSubprotocol(serverProt, protocolSet) { + if (serverProt !== undefined) { + return super.validateSubprotocol(serverProt, protocolSet); + } + + return ''; + } + } + + const v = new LenientValidator(); + const { protocol } = v.validate(makeRes(), key, new Set(['foo']), null); + + assert.strictEqual(protocol, ''); + }); + + it('allows overriding validateUpgrade', () => { + class SkipUpgradeValidator extends HandshakeValidator { + validateUpgrade() {} + } + + const v = new SkipUpgradeValidator(); + const { protocol } = v.validate( + makeRes({ upgrade: undefined }), + key, + new Set(), + null + ); + + assert.strictEqual(protocol, ''); + }); + }); + + describe('Integration with WebSocket', () => { + it('uses a custom handshakeValidator subclass', (done) => { + let called = false; + + class CustomValidator extends HandshakeValidator { + validateAcceptKey() { + called = true; + } + } + + const wss = new WebSocket.Server({ port: 0 }, () => { + const ws = new WebSocket(`ws://localhost:${wss.address().port}`, { + handshakeValidator: new CustomValidator() + }); + + ws.on('open', () => { + assert.ok(called); + ws.close(); + }); + + ws.on('close', () => wss.close(done)); + }); + }); + + it('aborts when a custom validator rejects', (done) => { + class RejectingValidator extends HandshakeValidator { + validateAcceptKey() { + throw new Error('rejected'); + } + } + + const wss = new WebSocket.Server({ port: 0 }, () => { + const ws = new WebSocket(`ws://localhost:${wss.address().port}`, { + handshakeValidator: new RejectingValidator() + }); + + ws.on('error', (err) => { + assert.strictEqual(err.message, 'rejected'); + ws.on('close', () => wss.close(done)); + }); + }); + }); + }); +});