@@ -8,14 +8,16 @@ use azure_identity::{
88 ManagedIdentityCredentialOptions , UserAssignedId ,
99} ;
1010use std:: sync:: Arc ;
11+ use tokio:: sync:: RwLock ;
1112
1213use crate :: experimental:: azure_monitor_exporter:: config:: { AuthConfig , AuthMethod } ;
1314
1415#[ derive( Clone ) ]
1516pub struct Auth {
1617 credential : Arc < dyn TokenCredential > ,
1718 scope : String ,
18- cached_token : AccessToken ,
19+ // Thread-safe shared token cache
20+ cached_token : Arc < RwLock < Option < AccessToken > > > ,
1921}
2022
2123impl Auth {
@@ -25,31 +27,39 @@ impl Auth {
2527 Ok ( Self {
2628 credential,
2729 scope : auth_config. scope . clone ( ) ,
28- cached_token : AccessToken {
29- token : "" . into ( ) ,
30- expires_on : OffsetDateTime :: now_utc ( ) ,
31- } ,
30+ cached_token : Arc :: new ( RwLock :: new ( None ) ) ,
3231 } )
3332 }
3433
3534 pub fn from_credential ( credential : Arc < dyn TokenCredential > , scope : String ) -> Self {
3635 Self {
3736 credential,
3837 scope,
39- cached_token : AccessToken {
40- token : "" . into ( ) ,
41- expires_on : OffsetDateTime :: now_utc ( ) ,
42- } ,
38+ cached_token : Arc :: new ( RwLock :: new ( None ) ) ,
4339 }
4440 }
4541
46- pub async fn get_token ( & mut self ) -> Result < AccessToken , String > {
42+ pub async fn get_token ( & self ) -> Result < AccessToken , String > {
4743 println ! ( "[AzureMonitorExporter][Auth] Acquiring token" ) ;
4844
49- if self . cached_token . expires_on
50- > OffsetDateTime :: now_utc ( ) + azure_core:: time:: Duration :: minutes ( 5 )
45+ // Try to use cached token
5146 {
52- return Ok ( self . cached_token . clone ( ) ) ;
47+ let cached = self . cached_token . read ( ) . await ;
48+ if let Some ( token) = & * cached {
49+ if token. expires_on > OffsetDateTime :: now_utc ( ) {
50+ return Ok ( token. clone ( ) ) ;
51+ }
52+ }
53+ }
54+
55+ // Need to refresh - acquire write lock
56+ let mut cached = self . cached_token . write ( ) . await ;
57+
58+ // Double-check in case another thread refreshed while we waited
59+ if let Some ( token) = & * cached {
60+ if token. expires_on > OffsetDateTime :: now_utc ( ) {
61+ return Ok ( token. clone ( ) ) ;
62+ }
5363 }
5464
5565 let token_response = self
@@ -62,16 +72,14 @@ impl Auth {
6272 . map_err ( |e| format ! ( "Failed to get token: {e}" ) ) ?;
6373
6474 // Update the cached token
65- self . cached_token = token_response. clone ( ) ;
75+ * cached = Some ( token_response. clone ( ) ) ;
6676
6777 Ok ( token_response)
6878 }
6979
70- pub fn invalidate_token ( & mut self ) {
71- self . cached_token = AccessToken {
72- token : "" . into ( ) ,
73- expires_on : OffsetDateTime :: now_utc ( ) ,
74- } ;
80+ pub async fn invalidate_token ( & self ) {
81+ let mut cached = self . cached_token . write ( ) . await ;
82+ * cached = None ;
7583 }
7684
7785 #[ allow( clippy:: print_stdout) ]
@@ -148,7 +156,7 @@ mod tests {
148156 call_count : call_count. clone ( ) ,
149157 } ) ;
150158
151- let mut auth = Auth :: from_credential ( credential, "scope" . to_string ( ) ) ;
159+ let auth = Auth :: from_credential ( credential, "scope" . to_string ( ) ) ;
152160
153161 // First call should hit the credential
154162 let token1 = auth. get_token ( ) . await . unwrap ( ) ;
@@ -171,7 +179,7 @@ mod tests {
171179 call_count : call_count. clone ( ) ,
172180 } ) ;
173181
174- let mut auth = Auth :: from_credential ( credential, "scope" . to_string ( ) ) ;
182+ let auth = Auth :: from_credential ( credential, "scope" . to_string ( ) ) ;
175183
176184 // First call
177185 let _ = auth. get_token ( ) . await . unwrap ( ) ;
@@ -191,22 +199,22 @@ mod tests {
191199 call_count : call_count. clone ( ) ,
192200 } ) ;
193201
194- let mut auth = Auth :: from_credential ( credential, "scope" . to_string ( ) ) ;
202+ let auth = Auth :: from_credential ( credential, "scope" . to_string ( ) ) ;
195203
196204 // First call
197205 let _ = auth. get_token ( ) . await . unwrap ( ) ;
198206 assert_eq ! ( * call_count. lock( ) . unwrap( ) , 1 ) ;
199207
200208 // Invalidate
201- auth. invalidate_token ( ) ;
209+ auth. invalidate_token ( ) . await ;
202210
203211 // Should refresh
204212 let _ = auth. get_token ( ) . await . unwrap ( ) ;
205213 assert_eq ! ( * call_count. lock( ) . unwrap( ) , 2 ) ;
206214 }
207215
208- #[ test]
209- fn test_new_with_managed_identity ( ) {
216+ #[ tokio :: test]
217+ async fn test_new_with_managed_identity ( ) {
210218 let auth_config = AuthConfig {
211219 method : AuthMethod :: ManagedIdentity ,
212220 client_id : Some ( "test-client-id" . to_string ( ) ) ,
@@ -217,11 +225,12 @@ mod tests {
217225 assert ! ( auth. is_ok( ) ) ;
218226 let auth = auth. unwrap ( ) ;
219227 assert_eq ! ( auth. scope, "https://test.scope" ) ;
220- assert_eq ! ( auth. cached_token. token. secret( ) , "" ) ;
228+ // Check that cached_token is None initially
229+ assert ! ( auth. cached_token. read( ) . await . is_none( ) ) ;
221230 }
222231
223- #[ test]
224- fn test_new_with_system_assigned_managed_identity ( ) {
232+ #[ tokio :: test]
233+ async fn test_new_with_system_assigned_managed_identity ( ) {
225234 let auth_config = AuthConfig {
226235 method : AuthMethod :: ManagedIdentity ,
227236 client_id : None ,
@@ -232,11 +241,12 @@ mod tests {
232241 assert ! ( auth. is_ok( ) ) ;
233242 let auth = auth. unwrap ( ) ;
234243 assert_eq ! ( auth. scope, "https://test.scope" ) ;
235- assert_eq ! ( auth. cached_token. token. secret( ) , "" ) ;
244+ // Check that cached_token is None initially
245+ assert ! ( auth. cached_token. read( ) . await . is_none( ) ) ;
236246 }
237247
238- #[ test]
239- fn test_new_with_development_auth ( ) {
248+ #[ tokio :: test]
249+ async fn test_new_with_development_auth ( ) {
240250 let auth_config = AuthConfig {
241251 method : AuthMethod :: Development ,
242252 client_id : None ,
@@ -248,7 +258,8 @@ mod tests {
248258 match auth {
249259 Ok ( auth) => {
250260 assert_eq ! ( auth. scope, "https://test.scope" ) ;
251- assert_eq ! ( auth. cached_token. token. secret( ) , "" ) ;
261+ // Check that cached_token is None initially
262+ assert ! ( auth. cached_token. read( ) . await . is_none( ) ) ;
252263 }
253264 Err ( err) => {
254265 // Expected if Azure CLI/Azure Developer CLI is not installed
@@ -257,6 +268,21 @@ mod tests {
257268 }
258269 }
259270
271+ #[ tokio:: test]
272+ async fn test_from_credential_initializes_with_none ( ) {
273+ let credential = Arc :: new ( MockCredential {
274+ token : "test_token" . to_string ( ) ,
275+ expires_in : azure_core:: time:: Duration :: minutes ( 60 ) ,
276+ call_count : Arc :: new ( Mutex :: new ( 0 ) ) ,
277+ } ) ;
278+
279+ let auth = Auth :: from_credential ( credential, "test_scope" . to_string ( ) ) ;
280+
281+ // Check that initial cached token is None
282+ assert ! ( auth. cached_token. read( ) . await . is_none( ) ) ;
283+ assert_eq ! ( auth. scope, "test_scope" ) ;
284+ }
285+
260286 #[ tokio:: test]
261287 async fn test_get_token_with_exactly_5_minute_buffer ( ) {
262288 let call_count = Arc :: new ( Mutex :: new ( 0 ) ) ;
@@ -267,7 +293,7 @@ mod tests {
267293 call_count : call_count. clone ( ) ,
268294 } ) ;
269295
270- let mut auth = Auth :: from_credential ( credential, "scope" . to_string ( ) ) ;
296+ let auth = Auth :: from_credential ( credential, "scope" . to_string ( ) ) ;
271297
272298 // First call
273299 let _ = auth. get_token ( ) . await . unwrap ( ) ;
@@ -298,7 +324,7 @@ mod tests {
298324 }
299325
300326 let credential = Arc :: new ( FailingCredential ) ;
301- let mut auth = Auth :: from_credential ( credential, "scope" . to_string ( ) ) ;
327+ let auth = Auth :: from_credential ( credential, "scope" . to_string ( ) ) ;
302328
303329 let result = auth. get_token ( ) . await ;
304330 assert ! ( result. is_err( ) ) ;
@@ -314,7 +340,7 @@ mod tests {
314340 call_count : call_count. clone ( ) ,
315341 } ) ;
316342
317- let mut auth = Auth :: from_credential ( credential, "scope" . to_string ( ) ) ;
343+ let auth = Auth :: from_credential ( credential, "scope" . to_string ( ) ) ;
318344
319345 // Get token twice and verify they are different instances (cloned)
320346 let token1 = auth. get_token ( ) . await . unwrap ( ) ;
@@ -326,8 +352,8 @@ mod tests {
326352 assert_eq ! ( * call_count. lock( ) . unwrap( ) , 1 ) ;
327353 }
328354
329- #[ test]
330- fn test_from_credential_initializes_with_expired_token ( ) {
355+ #[ tokio :: test]
356+ async fn test_from_credential_initializes_with_expired_token ( ) {
331357 let credential = Arc :: new ( MockCredential {
332358 token : "test_token" . to_string ( ) ,
333359 expires_in : azure_core:: time:: Duration :: minutes ( 60 ) ,
@@ -336,9 +362,9 @@ mod tests {
336362
337363 let auth = Auth :: from_credential ( credential, "test_scope" . to_string ( ) ) ;
338364
339- // Check that initial cached token is expired
340- assert_eq ! ( auth. cached_token. token . secret ( ) , "" ) ;
341- assert ! ( auth . cached_token . expires_on <= OffsetDateTime :: now_utc ( ) ) ;
365+ // Check that initial cached token is None
366+ let cached = auth. cached_token . read ( ) . await ;
367+ assert ! ( cached . is_none ( ) , "Initial cached token should be None" ) ;
342368 assert_eq ! ( auth. scope, "test_scope" ) ;
343369 }
344370}
0 commit comments