Skip to content

Commit d2565b0

Browse files
committed
backup/wip
1 parent 413048d commit d2565b0

File tree

4 files changed

+791
-364
lines changed

4 files changed

+791
-364
lines changed

rust/otap-dataflow/crates/otap/src/experimental/azure_monitor_exporter/auth.rs

Lines changed: 66 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,16 @@ use azure_identity::{
88
ManagedIdentityCredentialOptions, UserAssignedId,
99
};
1010
use std::sync::Arc;
11+
use tokio::sync::RwLock;
1112

1213
use crate::experimental::azure_monitor_exporter::config::{AuthConfig, AuthMethod};
1314

1415
#[derive(Clone)]
1516
pub struct Auth {
1617
credential: Arc<dyn TokenCredential>,
1718
scope: String,
18-
cached_token: AccessToken,
19+
// Thread-safe shared token cache
20+
cached_token: Arc<RwLock<Option<AccessToken>>>,
1921
}
2022

2123
impl Auth {
@@ -25,31 +27,39 @@ impl Auth {
2527
Ok(Self {
2628
credential,
2729
scope: auth_config.scope.clone(),
28-
cached_token: AccessToken {
29-
token: "".into(),
30-
expires_on: OffsetDateTime::now_utc(),
31-
},
30+
cached_token: Arc::new(RwLock::new(None)),
3231
})
3332
}
3433

3534
pub fn from_credential(credential: Arc<dyn TokenCredential>, scope: String) -> Self {
3635
Self {
3736
credential,
3837
scope,
39-
cached_token: AccessToken {
40-
token: "".into(),
41-
expires_on: OffsetDateTime::now_utc(),
42-
},
38+
cached_token: Arc::new(RwLock::new(None)),
4339
}
4440
}
4541

46-
pub async fn get_token(&mut self) -> Result<AccessToken, String> {
42+
pub async fn get_token(&self) -> Result<AccessToken, String> {
4743
println!("[AzureMonitorExporter][Auth] Acquiring token");
4844

49-
if self.cached_token.expires_on
50-
> OffsetDateTime::now_utc() + azure_core::time::Duration::minutes(5)
45+
// Try to use cached token
5146
{
52-
return Ok(self.cached_token.clone());
47+
let cached = self.cached_token.read().await;
48+
if let Some(token) = &*cached {
49+
if token.expires_on > OffsetDateTime::now_utc() {
50+
return Ok(token.clone());
51+
}
52+
}
53+
}
54+
55+
// Need to refresh - acquire write lock
56+
let mut cached = self.cached_token.write().await;
57+
58+
// Double-check in case another thread refreshed while we waited
59+
if let Some(token) = &*cached {
60+
if token.expires_on > OffsetDateTime::now_utc() {
61+
return Ok(token.clone());
62+
}
5363
}
5464

5565
let token_response = self
@@ -62,16 +72,14 @@ impl Auth {
6272
.map_err(|e| format!("Failed to get token: {e}"))?;
6373

6474
// Update the cached token
65-
self.cached_token = token_response.clone();
75+
*cached = Some(token_response.clone());
6676

6777
Ok(token_response)
6878
}
6979

70-
pub fn invalidate_token(&mut self) {
71-
self.cached_token = AccessToken {
72-
token: "".into(),
73-
expires_on: OffsetDateTime::now_utc(),
74-
};
80+
pub async fn invalidate_token(&self) {
81+
let mut cached = self.cached_token.write().await;
82+
*cached = None;
7583
}
7684

7785
#[allow(clippy::print_stdout)]
@@ -148,7 +156,7 @@ mod tests {
148156
call_count: call_count.clone(),
149157
});
150158

151-
let mut auth = Auth::from_credential(credential, "scope".to_string());
159+
let auth = Auth::from_credential(credential, "scope".to_string());
152160

153161
// First call should hit the credential
154162
let token1 = auth.get_token().await.unwrap();
@@ -171,7 +179,7 @@ mod tests {
171179
call_count: call_count.clone(),
172180
});
173181

174-
let mut auth = Auth::from_credential(credential, "scope".to_string());
182+
let auth = Auth::from_credential(credential, "scope".to_string());
175183

176184
// First call
177185
let _ = auth.get_token().await.unwrap();
@@ -191,22 +199,22 @@ mod tests {
191199
call_count: call_count.clone(),
192200
});
193201

194-
let mut auth = Auth::from_credential(credential, "scope".to_string());
202+
let auth = Auth::from_credential(credential, "scope".to_string());
195203

196204
// First call
197205
let _ = auth.get_token().await.unwrap();
198206
assert_eq!(*call_count.lock().unwrap(), 1);
199207

200208
// Invalidate
201-
auth.invalidate_token();
209+
auth.invalidate_token().await;
202210

203211
// Should refresh
204212
let _ = auth.get_token().await.unwrap();
205213
assert_eq!(*call_count.lock().unwrap(), 2);
206214
}
207215

208-
#[test]
209-
fn test_new_with_managed_identity() {
216+
#[tokio::test]
217+
async fn test_new_with_managed_identity() {
210218
let auth_config = AuthConfig {
211219
method: AuthMethod::ManagedIdentity,
212220
client_id: Some("test-client-id".to_string()),
@@ -217,11 +225,12 @@ mod tests {
217225
assert!(auth.is_ok());
218226
let auth = auth.unwrap();
219227
assert_eq!(auth.scope, "https://test.scope");
220-
assert_eq!(auth.cached_token.token.secret(), "");
228+
// Check that cached_token is None initially
229+
assert!(auth.cached_token.read().await.is_none());
221230
}
222231

223-
#[test]
224-
fn test_new_with_system_assigned_managed_identity() {
232+
#[tokio::test]
233+
async fn test_new_with_system_assigned_managed_identity() {
225234
let auth_config = AuthConfig {
226235
method: AuthMethod::ManagedIdentity,
227236
client_id: None,
@@ -232,11 +241,12 @@ mod tests {
232241
assert!(auth.is_ok());
233242
let auth = auth.unwrap();
234243
assert_eq!(auth.scope, "https://test.scope");
235-
assert_eq!(auth.cached_token.token.secret(), "");
244+
// Check that cached_token is None initially
245+
assert!(auth.cached_token.read().await.is_none());
236246
}
237247

238-
#[test]
239-
fn test_new_with_development_auth() {
248+
#[tokio::test]
249+
async fn test_new_with_development_auth() {
240250
let auth_config = AuthConfig {
241251
method: AuthMethod::Development,
242252
client_id: None,
@@ -248,7 +258,8 @@ mod tests {
248258
match auth {
249259
Ok(auth) => {
250260
assert_eq!(auth.scope, "https://test.scope");
251-
assert_eq!(auth.cached_token.token.secret(), "");
261+
// Check that cached_token is None initially
262+
assert!(auth.cached_token.read().await.is_none());
252263
}
253264
Err(err) => {
254265
// Expected if Azure CLI/Azure Developer CLI is not installed
@@ -257,6 +268,21 @@ mod tests {
257268
}
258269
}
259270

271+
#[tokio::test]
272+
async fn test_from_credential_initializes_with_none() {
273+
let credential = Arc::new(MockCredential {
274+
token: "test_token".to_string(),
275+
expires_in: azure_core::time::Duration::minutes(60),
276+
call_count: Arc::new(Mutex::new(0)),
277+
});
278+
279+
let auth = Auth::from_credential(credential, "test_scope".to_string());
280+
281+
// Check that initial cached token is None
282+
assert!(auth.cached_token.read().await.is_none());
283+
assert_eq!(auth.scope, "test_scope");
284+
}
285+
260286
#[tokio::test]
261287
async fn test_get_token_with_exactly_5_minute_buffer() {
262288
let call_count = Arc::new(Mutex::new(0));
@@ -267,7 +293,7 @@ mod tests {
267293
call_count: call_count.clone(),
268294
});
269295

270-
let mut auth = Auth::from_credential(credential, "scope".to_string());
296+
let auth = Auth::from_credential(credential, "scope".to_string());
271297

272298
// First call
273299
let _ = auth.get_token().await.unwrap();
@@ -298,7 +324,7 @@ mod tests {
298324
}
299325

300326
let credential = Arc::new(FailingCredential);
301-
let mut auth = Auth::from_credential(credential, "scope".to_string());
327+
let auth = Auth::from_credential(credential, "scope".to_string());
302328

303329
let result = auth.get_token().await;
304330
assert!(result.is_err());
@@ -314,7 +340,7 @@ mod tests {
314340
call_count: call_count.clone(),
315341
});
316342

317-
let mut auth = Auth::from_credential(credential, "scope".to_string());
343+
let auth = Auth::from_credential(credential, "scope".to_string());
318344

319345
// Get token twice and verify they are different instances (cloned)
320346
let token1 = auth.get_token().await.unwrap();
@@ -326,8 +352,8 @@ mod tests {
326352
assert_eq!(*call_count.lock().unwrap(), 1);
327353
}
328354

329-
#[test]
330-
fn test_from_credential_initializes_with_expired_token() {
355+
#[tokio::test]
356+
async fn test_from_credential_initializes_with_expired_token() {
331357
let credential = Arc::new(MockCredential {
332358
token: "test_token".to_string(),
333359
expires_in: azure_core::time::Duration::minutes(60),
@@ -336,9 +362,9 @@ mod tests {
336362

337363
let auth = Auth::from_credential(credential, "test_scope".to_string());
338364

339-
// Check that initial cached token is expired
340-
assert_eq!(auth.cached_token.token.secret(), "");
341-
assert!(auth.cached_token.expires_on <= OffsetDateTime::now_utc());
365+
// Check that initial cached token is None
366+
let cached = auth.cached_token.read().await;
367+
assert!(cached.is_none(), "Initial cached token should be None");
342368
assert_eq!(auth.scope, "test_scope");
343369
}
344370
}

rust/otap-dataflow/crates/otap/src/experimental/azure_monitor_exporter/client.rs

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33

44
use azure_core::credentials::TokenCredential;
55
use azure_core::time::OffsetDateTime;
6+
7+
use bytes::Bytes;
8+
69
use reqwest::{
710
Client,
811
header::{AUTHORIZATION, CONTENT_ENCODING, CONTENT_TYPE, HeaderValue},
@@ -28,6 +31,8 @@ pub struct LogsIngestionClient {
2831

2932
/// Token expiry time using monotonic clock for faster comparisons
3033
pub token_valid_until: Instant,
34+
35+
token_refresh_after: Instant,
3136
}
3237

3338
impl LogsIngestionClient {
@@ -53,6 +58,7 @@ impl LogsIngestionClient {
5358
auth: Auth::from_credential(credential, scope),
5459
auth_header: HeaderValue::from_static("Bearer "),
5560
token_valid_until: Instant::now(),
61+
token_refresh_after: Instant::now(),
5662
}
5763
}
5864

@@ -79,14 +85,16 @@ impl LogsIngestionClient {
7985
);
8086

8187
let auth =
82-
Auth::new(&config.auth).map_err(|e| format!("Failed to create auth handler: {e}"))?;
88+
Auth::new(&config.auth)
89+
.map_err(|e| format!("Failed to create auth handler: {e}"))?;
8390

8491
Ok(Self {
8592
http_client,
8693
endpoint,
8794
auth,
8895
auth_header: HeaderValue::from_static("Bearer "),
8996
token_valid_until: Instant::now(),
97+
token_refresh_after: Instant::now(),
9098
})
9199
}
92100

@@ -96,7 +104,7 @@ impl LogsIngestionClient {
96104
let now = Instant::now();
97105

98106
// Fast path: token is still valid
99-
if now < self.token_valid_until {
107+
if now < self.token_refresh_after {
100108
return Ok(());
101109
}
102110

@@ -114,10 +122,10 @@ impl LogsIngestionClient {
114122
// Calculate validity using Instant for faster comparisons
115123
// Refresh 5 minutes before expiry
116124
let valid_seconds = (token.expires_on - OffsetDateTime::now_utc())
117-
.whole_seconds()
118-
.saturating_sub(300); // 5 minutes = 300 seconds
125+
.whole_seconds();
119126

120127
self.token_valid_until = now + Duration::from_secs(valid_seconds.max(0) as u64);
128+
self.token_refresh_after = self.token_valid_until - Duration::from_secs(300);
121129

122130
println!("[AzureMonitorExporter] Acquired new token, valid for {} seconds, valid until {:?}, current time {:?}", valid_seconds, self.token_valid_until, now);
123131

@@ -134,10 +142,7 @@ impl LogsIngestionClient {
134142
/// # Returns
135143
/// * `Ok(())` - If the request was successful
136144
/// * `Err(String)` - Error message if the request failed
137-
pub async fn send(&mut self, body: Vec<u8>) -> Result<(), String> {
138-
// Ensure we have a valid token (fast path when cached)
139-
self.ensure_valid_token().await?;
140-
145+
pub async fn send(&mut self, body: Bytes) -> Result<(), String> {
141146
let start = Instant::now();
142147

143148
// Send compressed body - avoid cloning headers by setting them individually
@@ -168,7 +173,9 @@ impl LogsIngestionClient {
168173
401 => {
169174
// Invalidate token and force refresh on next call
170175
self.token_valid_until = Instant::now();
171-
self.auth.invalidate_token();
176+
self.auth.invalidate_token().await;
177+
self.ensure_valid_token().await?;
178+
172179
Err(format!("Authentication failed: {error}"))
173180
}
174181
403 => Err(format!("Authorization failed: {error}")),
@@ -183,6 +190,7 @@ impl LogsIngestionClient {
183190
mod tests {
184191
use super::*;
185192
use super::super::config::{ApiConfig, AuthConfig, AuthMethod};
193+
use azure_core::Bytes;
186194
use azure_core::credentials::TokenRequestOptions;
187195
use azure_core::credentials::{AccessToken, TokenCredential};
188196
use std::sync::Mutex;
@@ -277,7 +285,7 @@ mod tests {
277285
"scope".to_string(),
278286
);
279287

280-
let result = client.send(vec![1, 2, 3]).await;
288+
let result = client.send(Bytes::from(vec![1, 2, 3])).await;
281289
assert!(result.is_ok());
282290
assert_eq!(*call_count.lock().unwrap(), 1); // Token fetched once
283291
}
@@ -312,7 +320,7 @@ mod tests {
312320
);
313321

314322
// This should fail with 401, but invalidate the token
315-
let result = client.send(vec![1, 2, 3]).await;
323+
let result = client.send(Bytes::from(vec![1, 2, 3])).await;
316324
assert!(result.is_err());
317325
assert!(result.unwrap_err().contains("Authentication failed"));
318326

@@ -326,7 +334,7 @@ mod tests {
326334
.mount(&mock_server)
327335
.await;
328336

329-
let result = client.send(vec![1, 2, 3]).await;
337+
let result = client.send(Bytes::from(vec![1, 2, 3])).await;
330338
assert!(result.is_ok());
331339
assert_eq!(*call_count.lock().unwrap(), 2); // Token fetched again
332340
}
@@ -357,7 +365,7 @@ mod tests {
357365
"scope".to_string(),
358366
);
359367

360-
let result = client.send(vec![1, 2, 3]).await;
368+
let result = client.send(Bytes::from(vec![1, 2, 3])).await;
361369
assert!(result.is_err());
362370
assert!(result.unwrap_err().contains("Rate limited"));
363371
}

0 commit comments

Comments
 (0)