Skip to content

Commit 563ed6a

Browse files
committed
feat: implement close_notify for graceful connection shutdown
Add close_notify support for both DTLS 1.2 and DTLS 1.3, implementing graceful connection shutdown per RFC 5246 §7.2.1 and RFC 9147 §5.10. DTLS 1.2: close() sends close_notify and transitions to Closed state. Receiving close_notify triggers a reciprocal alert and discards pending writes. No half-close support (full close only). DTLS 1.3: close() sends close_notify and enters HalfClosedLocal state where the read half remains open. Receiving close_notify while half-closed transitions to Closed. Incoming KeyUpdate messages are still processed (recv keys updated) but no outgoing records are sent. Engine tracks close_notify at the record layer (filtering app data after the alert sequence), while client/server handle connection state and Output::CloseNotify emission. Error::ConnectionClosed replaces SecurityError("connection closed") for send_application_data on closed connections.
1 parent 2439bf5 commit 563ed6a

File tree

15 files changed

+1510
-630
lines changed

15 files changed

+1510
-630
lines changed

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ references into your provided buffer:
7171
- `PeerCert(&[u8])`: peer leaf certificate (DER) — validate in your app
7272
- `KeyingMaterial(KeyingMaterial, SrtpProfile)`: DTLS‑SRTP export
7373
- `ApplicationData(&[u8])`: plaintext received from peer
74+
- `CloseNotify`: peer sent a `close_notify` alert (graceful shutdown)
7475

7576
## Example (Sans‑IO loop)
7677

@@ -106,6 +107,10 @@ fn example_event_loop(mut dtls: Dtls) -> Result<(), dimpl::Error> {
106107
Output::ApplicationData(_data) => {
107108
// Deliver plaintext to application
108109
}
110+
Output::CloseNotify => {
111+
// Peer initiated graceful shutdown
112+
break;
113+
}
109114
_ => {}
110115
}
111116
}

src/dtls12/client.rs

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ pub struct Client {
7777

7878
/// Data that is sent before we are connected.
7979
queued_data: Vec<Buf>,
80+
81+
/// Whether we have already emitted a CloseNotify output event.
82+
close_notify_reported: bool,
8083
}
8184

8285
#[derive(Debug, PartialEq, Eq)]
@@ -106,6 +109,7 @@ impl Client {
106109
last_now: now,
107110
local_events: VecDeque::new(),
108111
queued_data: Vec::new(),
112+
close_notify_reported: false,
109113
}
110114
}
111115

@@ -153,6 +157,7 @@ impl Client {
153157
last_now: now,
154158
local_events: VecDeque::new(),
155159
queued_data: Vec::new(),
160+
close_notify_reported: false,
156161
};
157162
client.handle_timeout(now)?;
158163
Ok(client)
@@ -173,13 +178,18 @@ impl Client {
173178
}
174179

175180
pub fn poll_output<'a>(&mut self, buf: &'a mut [u8]) -> Output<'a> {
176-
let last_now = self.last_now;
177-
178181
if let Some(event) = self.local_events.pop_front() {
179182
return event.into_output(buf, &self.server_certificates);
180183
}
181-
182-
self.engine.poll_output(buf, last_now)
184+
let output = self.engine.poll_output(buf, self.last_now);
185+
if matches!(output, Output::Timeout(_))
186+
&& self.engine.close_notify_received()
187+
&& !self.close_notify_reported
188+
{
189+
self.close_notify_reported = true;
190+
return Output::CloseNotify;
191+
}
192+
output
183193
}
184194

185195
/// Explicitly start the handshake process by sending a ClientHello
@@ -198,6 +208,10 @@ impl Client {
198208
/// This should only be called when the client is in the Running state,
199209
/// after the handshake is complete.
200210
pub fn send_application_data(&mut self, data: &[u8]) -> Result<(), Error> {
211+
if self.state == State::Closed {
212+
return Err(Error::ConnectionClosed);
213+
}
214+
201215
if self.state != State::AwaitApplicationData {
202216
self.queued_data.push(data.to_buf());
203217
return Ok(());
@@ -213,6 +227,25 @@ impl Client {
213227
Ok(())
214228
}
215229

230+
/// Initiate graceful shutdown by sending a `close_notify` alert.
231+
pub fn close(&mut self) -> Result<(), Error> {
232+
if self.state == State::Closed {
233+
return Ok(());
234+
}
235+
if self.state != State::AwaitApplicationData {
236+
self.engine.abort();
237+
self.state = State::Closed;
238+
return Ok(());
239+
}
240+
self.engine
241+
.create_record(ContentType::Alert, 1, false, |body| {
242+
body.push(1); // level: warning
243+
body.push(0); // description: close_notify
244+
})?;
245+
self.state = State::Closed;
246+
Ok(())
247+
}
248+
216249
fn make_progress(&mut self) -> Result<(), Error> {
217250
loop {
218251
let prev_state = self.state;
@@ -247,6 +280,7 @@ enum State {
247280
AwaitNewSessionTicket,
248281
AwaitFinished,
249282
AwaitApplicationData,
283+
Closed,
250284
}
251285

252286
impl State {
@@ -268,6 +302,7 @@ impl State {
268302
State::AwaitNewSessionTicket => "AwaitNewSessionTicket",
269303
State::AwaitFinished => "AwaitFinished",
270304
State::AwaitApplicationData => "AwaitApplicationData",
305+
State::Closed => "Closed",
271306
}
272307
}
273308

@@ -289,6 +324,7 @@ impl State {
289324
State::AwaitNewSessionTicket => self.await_new_session_ticket(client),
290325
State::AwaitFinished => self.await_finished(client),
291326
State::AwaitApplicationData => self.await_application_data(client),
327+
State::Closed => Ok(self),
292328
}
293329
}
294330

@@ -1051,6 +1087,19 @@ impl State {
10511087
}
10521088

10531089
fn await_application_data(self, client: &mut Client) -> Result<Self, Error> {
1090+
if client.engine.close_notify_received() {
1091+
// RFC 5246 §7.2.1: respond with a reciprocal close_notify and
1092+
// close down immediately, discarding any pending writes.
1093+
client.engine.discard_pending_writes();
1094+
client
1095+
.engine
1096+
.create_record(ContentType::Alert, 1, false, |body| {
1097+
body.push(1); // level: warning
1098+
body.push(0); // description: close_notify
1099+
})?;
1100+
return Ok(State::Closed);
1101+
}
1102+
10541103
if !client.queued_data.is_empty() {
10551104
debug!(
10561105
"Sending queued application data: {}",

src/dtls12/engine.rs

Lines changed: 112 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ use std::sync::Arc;
33
use std::sync::atomic::AtomicBool;
44
use std::time::{Duration, Instant};
55

6+
use arrayvec::ArrayVec;
7+
68
use super::queue::{QueueRx, QueueTx};
79
use crate::buffer::{Buf, BufferPool, TmpBuf};
810
use crate::crypto::{Aad, Iv, Nonce};
@@ -88,6 +90,9 @@ pub struct Engine {
8890

8991
/// Whether we are ready to release application data from poll_output.
9092
release_app_data: bool,
93+
94+
/// Whether a close_notify alert has been received from the peer.
95+
close_notify_received: bool,
9196
}
9297

9398
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
@@ -140,6 +145,7 @@ impl Engine {
140145
flight_timeout: Timeout::Unarmed,
141146
connect_timeout: Timeout::Unarmed,
142147
release_app_data: false,
148+
close_notify_received: false,
143149
}
144150
}
145151

@@ -206,11 +212,12 @@ impl Engine {
206212

207213
/// Insert the Incoming using the logic:
208214
///
209-
/// 1. If it is a handshake, sort by the message_seq
210-
/// 2. If it is not a handshake, sort by sequence_number
215+
/// 1. Extract alert records for immediate processing.
216+
/// 2. If it is a handshake, sort by the message_seq
217+
/// 3. If it is not a handshake, sort by sequence_number
211218
///
212219
fn insert_incoming(&mut self, incoming: Incoming) -> Result<(), Error> {
213-
// Capacity guard
220+
// Capacity guard before iterating records.
214221
if self.queue_rx.len() >= self.config.max_queue_rx() {
215222
warn!(
216223
"Receive queue full (max {}): {:?}",
@@ -220,6 +227,10 @@ impl Engine {
220227
return Err(Error::ReceiveQueueFull);
221228
}
222229

230+
let Some(incoming) = self.extract_alerts(incoming)? else {
231+
return Ok(());
232+
};
233+
223234
// Dispatch to specialized handlers
224235
if incoming.first().first_handshake().is_some() {
225236
self.insert_incoming_handshake(incoming)
@@ -228,6 +239,83 @@ impl Engine {
228239
}
229240
}
230241

242+
/// Process alert records from the incoming datagram, returning the
243+
/// remaining non-alert records for queuing.
244+
///
245+
/// A single UDP datagram can contain mixed record types, so each
246+
/// record is inspected individually. Alert records are handled
247+
/// per-epoch for authentication, and after receiving close_notify,
248+
/// any further ApplicationData is discarded.
249+
fn extract_alerts(&mut self, incoming: Incoming) -> Result<Option<Incoming>, Error> {
250+
let mut remaining = ArrayVec::new();
251+
for record in incoming.into_records() {
252+
if record.record().content_type == ContentType::Alert {
253+
let epoch = record.record().sequence.epoch;
254+
if epoch == 0 {
255+
if self.peer_encryption_enabled {
256+
// Post-handshake: epoch 0 alerts are unauthenticated, discard
257+
self.buffers_free.push(record.into_buffer());
258+
continue;
259+
}
260+
// During handshake: accept fatal alerts (level==2)
261+
let fragment = record.record().fragment(record.buffer());
262+
if fragment.len() >= 2 && fragment[0] == 2 {
263+
let description = fragment[1];
264+
self.buffers_free.push(record.into_buffer());
265+
return Err(Error::SecurityError(format!(
266+
"Received fatal alert: level=2, description={}",
267+
description
268+
)));
269+
}
270+
// Non-fatal epoch 0 alert during handshake: discard
271+
self.buffers_free.push(record.into_buffer());
272+
continue;
273+
}
274+
if !self.peer_encryption_enabled {
275+
// Epoch ≥ 1 but peer encryption not yet enabled: keep for
276+
// re-parsing after enable_peer_encryption (ciphertext record).
277+
remaining.try_push(record).ok();
278+
continue;
279+
}
280+
// Authenticated alert (epoch ≥ 1, peer encryption enabled)
281+
let fragment = record.record().fragment(record.buffer());
282+
if fragment.len() >= 2 {
283+
let level = fragment[0];
284+
let description = fragment[1];
285+
if description == 0 {
286+
// close_notify: signal graceful shutdown
287+
self.close_notify_received = true;
288+
self.buffers_free.push(record.into_buffer());
289+
continue;
290+
} else if level == 2 {
291+
// Fatal alert (non close_notify)
292+
self.buffers_free.push(record.into_buffer());
293+
return Err(Error::SecurityError(format!(
294+
"Received fatal alert: level={}, description={}",
295+
level, description
296+
)));
297+
}
298+
}
299+
// Warning alerts with non-zero description: discard
300+
self.buffers_free.push(record.into_buffer());
301+
continue;
302+
}
303+
304+
// After close_notify (from this or a prior datagram), discard
305+
// any further ApplicationData — the read half is closed.
306+
if self.close_notify_received
307+
&& record.record().content_type == ContentType::ApplicationData
308+
{
309+
self.buffers_free.push(record.into_buffer());
310+
continue;
311+
}
312+
313+
remaining.try_push(record).ok();
314+
}
315+
316+
Ok(Incoming::from_records(remaining))
317+
}
318+
231319
fn insert_incoming_handshake(&mut self, incoming: Incoming) -> Result<(), Error> {
232320
let first_record = incoming.first();
233321
let handshake = first_record
@@ -899,6 +987,27 @@ impl Engine {
899987
self.release_app_data = true;
900988
}
901989

990+
/// Whether a close_notify alert has been received from the peer.
991+
pub fn close_notify_received(&self) -> bool {
992+
self.close_notify_received
993+
}
994+
995+
/// Discard all pending outgoing data.
996+
///
997+
/// RFC 5246 §7.2.1: on receiving close_notify, discard any pending writes.
998+
pub fn discard_pending_writes(&mut self) {
999+
self.queue_tx.clear();
1000+
}
1001+
1002+
/// Abort the connection: flush all queued output, retransmission state, and
1003+
/// disable timers so that no further packets are emitted.
1004+
pub fn abort(&mut self) {
1005+
self.queue_tx.clear();
1006+
self.flight_saved_records.clear();
1007+
self.flight_timeout = Timeout::Disabled;
1008+
self.connect_timeout = Timeout::Disabled;
1009+
}
1010+
9021011
/// Pop a buffer from the buffer pool for temporary use
9031012
pub(crate) fn pop_buffer(&mut self) -> Buf {
9041013
self.buffers_free.pop()

src/dtls12/incoming.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,17 @@ impl Incoming {
3030
pub fn into_records(self) -> impl Iterator<Item = Record> {
3131
self.records.records.into_iter()
3232
}
33+
34+
/// Create an Incoming from pre-filtered records.
35+
/// Returns None if records is empty (same invariant as parse_packet).
36+
pub fn from_records(records: ArrayVec<Record, 8>) -> Option<Self> {
37+
if records.is_empty() {
38+
return None;
39+
}
40+
Some(Incoming {
41+
records: Box::new(Records { records }),
42+
})
43+
}
3344
}
3445

3546
impl Incoming {

0 commit comments

Comments
 (0)