Skip to content

Commit 77bf722

Browse files
committed
feat: implement close_notify for graceful connection shutdown
DTLS 1.2: sender queues close_notify alert and transitions to Closed; receiver auto-replies with reciprocal close_notify per RFC 5246 §7.2.1, discards pending writes, and closes immediately. DTLS 1.3: sender enters HalfClosedLocal (write-closed, read-open) per RFC 8446 §6.1; receiver filters records by sequence number per RFC 9147 §5.10 and lets the application decide when to close. ACKs are suppressed after close_notify is sent.
1 parent 3079cc9 commit 77bf722

16 files changed

Lines changed: 1527 additions & 671 deletions

File tree

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ references into your provided buffer:
7171
- `PeerCert(&[u8])`: peer leaf certificate (DER) — validate in your app
7272
- `KeyingMaterial(KeyingMaterial, SrtpProfile)`: DTLS‑SRTP export
7373
- `ApplicationData(&[u8])`: plaintext received from peer
74+
- `CloseNotify`: peer sent a `close_notify` alert (graceful shutdown)
7475

7576
## Example (Sans‑IO loop)
7677

@@ -106,6 +107,10 @@ fn example_event_loop(mut dtls: Dtls) -> Result<(), dimpl::Error> {
106107
Output::ApplicationData(_data) => {
107108
// Deliver plaintext to application
108109
}
110+
Output::CloseNotify => {
111+
// Peer initiated graceful shutdown
112+
break;
113+
}
109114
_ => {}
110115
}
111116
}

src/crypto/rust_crypto/kx_group.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ impl EcdhKeyExchange {
4747
match group {
4848
NamedGroup::X25519 => {
4949
use rand_core::OsRng;
50-
let secret = x25519_dalek::EphemeralSecret::random_from_rng(&mut OsRng);
50+
let secret = x25519_dalek::EphemeralSecret::random_from_rng(OsRng);
5151
let public_key_obj = x25519_dalek::PublicKey::from(&secret);
5252
buf.clear();
5353
buf.extend_from_slice(public_key_obj.as_bytes());

src/dtls12/client.rs

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -173,13 +173,10 @@ impl Client {
173173
}
174174

175175
pub fn poll_output<'a>(&mut self, buf: &'a mut [u8]) -> Output<'a> {
176-
let last_now = self.last_now;
177-
178176
if let Some(event) = self.local_events.pop_front() {
179177
return event.into_output(buf, &self.server_certificates);
180178
}
181-
182-
self.engine.poll_output(buf, last_now)
179+
self.engine.poll_output(buf, self.last_now)
183180
}
184181

185182
/// Explicitly start the handshake process by sending a ClientHello
@@ -198,6 +195,10 @@ impl Client {
198195
/// This should only be called when the client is in the Running state,
199196
/// after the handshake is complete.
200197
pub fn send_application_data(&mut self, data: &[u8]) -> Result<(), Error> {
198+
if self.state == State::Closed {
199+
return Err(Error::SecurityError("connection closed".to_string()));
200+
}
201+
201202
if self.state != State::AwaitApplicationData {
202203
self.queued_data.push(data.to_buf());
203204
return Ok(());
@@ -213,6 +214,25 @@ impl Client {
213214
Ok(())
214215
}
215216

217+
/// Initiate graceful shutdown by sending a `close_notify` alert.
218+
pub fn close(&mut self) -> Result<(), Error> {
219+
if self.state == State::Closed {
220+
return Ok(());
221+
}
222+
if self.state != State::AwaitApplicationData {
223+
self.engine.abort();
224+
self.state = State::Closed;
225+
return Ok(());
226+
}
227+
self.engine
228+
.create_record(ContentType::Alert, 1, false, |body| {
229+
body.push(1); // level: warning
230+
body.push(0); // description: close_notify
231+
})?;
232+
self.state = State::Closed;
233+
Ok(())
234+
}
235+
216236
fn make_progress(&mut self) -> Result<(), Error> {
217237
loop {
218238
let prev_state = self.state;
@@ -247,6 +267,7 @@ enum State {
247267
AwaitNewSessionTicket,
248268
AwaitFinished,
249269
AwaitApplicationData,
270+
Closed,
250271
}
251272

252273
impl State {
@@ -268,6 +289,7 @@ impl State {
268289
State::AwaitNewSessionTicket => "AwaitNewSessionTicket",
269290
State::AwaitFinished => "AwaitFinished",
270291
State::AwaitApplicationData => "AwaitApplicationData",
292+
State::Closed => "Closed",
271293
}
272294
}
273295

@@ -289,6 +311,7 @@ impl State {
289311
State::AwaitNewSessionTicket => self.await_new_session_ticket(client),
290312
State::AwaitFinished => self.await_finished(client),
291313
State::AwaitApplicationData => self.await_application_data(client),
314+
State::Closed => Ok(self),
292315
}
293316
}
294317

@@ -1051,6 +1074,19 @@ impl State {
10511074
}
10521075

10531076
fn await_application_data(self, client: &mut Client) -> Result<Self, Error> {
1077+
if client.engine.close_notify_received() {
1078+
// RFC 5246 §7.2.1: respond with a reciprocal close_notify and
1079+
// close down immediately, discarding any pending writes.
1080+
client.engine.discard_pending_writes();
1081+
client
1082+
.engine
1083+
.create_record(ContentType::Alert, 1, false, |body| {
1084+
body.push(1); // level: warning
1085+
body.push(0); // description: close_notify
1086+
})?;
1087+
return Ok(State::Closed);
1088+
}
1089+
10541090
if !client.queued_data.is_empty() {
10551091
debug!(
10561092
"Sending queued application data: {}",

src/dtls12/engine.rs

Lines changed: 114 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ use std::sync::Arc;
33
use std::sync::atomic::AtomicBool;
44
use std::time::{Duration, Instant};
55

6+
use arrayvec::ArrayVec;
7+
68
use super::queue::{QueueRx, QueueTx};
79
use crate::buffer::{Buf, BufferPool, TmpBuf};
810
use crate::crypto::{Aad, Iv, Nonce};
@@ -88,6 +90,12 @@ pub struct Engine {
8890

8991
/// Whether we are ready to release application data from poll_output.
9092
release_app_data: bool,
93+
94+
/// Whether a close_notify alert has been received from the peer.
95+
close_notify_received: bool,
96+
97+
/// Whether we have already emitted a CloseNotify output event.
98+
close_notify_reported: bool,
9199
}
92100

93101
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
@@ -140,6 +148,8 @@ impl Engine {
140148
flight_timeout: Timeout::Unarmed,
141149
connect_timeout: Timeout::Unarmed,
142150
release_app_data: false,
151+
close_notify_received: false,
152+
close_notify_reported: false,
143153
}
144154
}
145155

@@ -206,11 +216,12 @@ impl Engine {
206216

207217
/// Insert the Incoming using the logic:
208218
///
209-
/// 1. If it is a handshake, sort by the message_seq
210-
/// 2. If it is not a handshake, sort by sequence_number
219+
/// 1. Handle alert records per-record (not just the first record).
220+
/// 2. If it is a handshake, sort by the message_seq
221+
/// 3. If it is not a handshake, sort by sequence_number
211222
///
212223
fn insert_incoming(&mut self, incoming: Incoming) -> Result<(), Error> {
213-
// Capacity guard
224+
// Capacity guard before iterating records.
214225
if self.queue_rx.len() >= self.config.max_queue_rx() {
215226
warn!(
216227
"Receive queue full (max {}): {:?}",
@@ -220,6 +231,80 @@ impl Engine {
220231
return Err(Error::ReceiveQueueFull);
221232
}
222233

234+
// Handle Alert records individually; collect the rest for queuing.
235+
// A single UDP datagram can contain mixed record types, so we process
236+
// each record individually without discarding siblings.
237+
let mut remaining = ArrayVec::new();
238+
for record in incoming.into_records() {
239+
if record.record().content_type == ContentType::Alert {
240+
let epoch = record.record().sequence.epoch;
241+
if epoch == 0 {
242+
if self.peer_encryption_enabled {
243+
// Post-handshake: epoch 0 alerts are unauthenticated, discard
244+
self.buffers_free.push(record.into_buffer());
245+
continue;
246+
}
247+
// During handshake: accept fatal alerts (level==2)
248+
let fragment = record.record().fragment(record.buffer());
249+
if fragment.len() >= 2 && fragment[0] == 2 {
250+
let description = fragment[1];
251+
self.buffers_free.push(record.into_buffer());
252+
return Err(Error::SecurityError(format!(
253+
"Received fatal alert: level=2, description={}",
254+
description
255+
)));
256+
}
257+
// Non-fatal epoch 0 alert during handshake: discard
258+
self.buffers_free.push(record.into_buffer());
259+
continue;
260+
}
261+
if !self.peer_encryption_enabled {
262+
// Epoch ≥ 1 but peer encryption not yet enabled: keep for
263+
// re-parsing after enable_peer_encryption (ciphertext record).
264+
remaining.try_push(record).ok();
265+
continue;
266+
}
267+
// Authenticated alert (epoch ≥ 1, peer encryption enabled)
268+
let fragment = record.record().fragment(record.buffer());
269+
if fragment.len() >= 2 {
270+
let level = fragment[0];
271+
let description = fragment[1];
272+
if description == 0 {
273+
// close_notify: signal graceful shutdown
274+
self.close_notify_received = true;
275+
self.buffers_free.push(record.into_buffer());
276+
continue;
277+
} else if level == 2 {
278+
// Fatal alert (non close_notify)
279+
self.buffers_free.push(record.into_buffer());
280+
return Err(Error::SecurityError(format!(
281+
"Received fatal alert: level={}, description={}",
282+
level, description
283+
)));
284+
}
285+
}
286+
// Warning alerts with non-zero description: discard
287+
self.buffers_free.push(record.into_buffer());
288+
continue;
289+
}
290+
291+
// After close_notify (from this or a prior datagram), discard
292+
// any further ApplicationData — the read half is closed.
293+
if self.close_notify_received
294+
&& record.record().content_type == ContentType::ApplicationData
295+
{
296+
self.buffers_free.push(record.into_buffer());
297+
continue;
298+
}
299+
300+
remaining.try_push(record).ok();
301+
}
302+
303+
let incoming = match Incoming::from_records(remaining) {
304+
Some(incoming) => incoming,
305+
None => return Ok(()),
306+
};
307+
223308
// Dispatch to specialized handlers
224309
if incoming.first().first_handshake().is_some() {
225310
self.insert_incoming_handshake(incoming)
@@ -370,6 +455,11 @@ impl Engine {
370455
return Output::Packet(p);
371456
}
372457

458+
if self.release_app_data && self.close_notify_received && !self.close_notify_reported {
459+
self.close_notify_reported = true;
460+
return Output::CloseNotify;
461+
}
462+
373463
let next_timeout = self.poll_timeout(now);
374464

375465
Output::Timeout(next_timeout)
@@ -899,6 +989,27 @@ impl Engine {
899989
self.release_app_data = true;
900990
}
901991

992+
/// Whether a close_notify alert has been received from the peer.
993+
pub fn close_notify_received(&self) -> bool {
994+
self.close_notify_received
995+
}
996+
997+
/// Discard all pending outgoing data.
998+
///
999+
/// RFC 5246 §7.2.1: on receiving close_notify, discard any pending writes.
1000+
pub fn discard_pending_writes(&mut self) {
1001+
self.queue_tx.clear();
1002+
}
1003+
1004+
/// Abort the connection: flush all queued output, retransmission state, and
1005+
/// disable timers so that no further packets are emitted.
1006+
pub fn abort(&mut self) {
1007+
self.queue_tx.clear();
1008+
self.flight_saved_records.clear();
1009+
self.flight_timeout = Timeout::Disabled;
1010+
self.connect_timeout = Timeout::Disabled;
1011+
}
1012+
9021013
/// Pop a buffer from the buffer pool for temporary use
9031014
pub(crate) fn pop_buffer(&mut self) -> Buf {
9041015
self.buffers_free.pop()

src/dtls12/incoming.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,17 @@ impl Incoming {
3030
pub fn into_records(self) -> impl Iterator<Item = Record> {
3131
self.records.records.into_iter()
3232
}
33+
34+
/// Create an Incoming from pre-filtered records.
35+
/// Returns None if records is empty (same invariant as parse_packet).
36+
pub fn from_records(records: ArrayVec<Record, 8>) -> Option<Self> {
37+
if records.is_empty() {
38+
return None;
39+
}
40+
Some(Incoming {
41+
records: Box::new(Records { records }),
42+
})
43+
}
3344
}
3445

3546
impl Incoming {

0 commit comments

Comments
 (0)