Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 26 additions & 12 deletions src/main/java/com/aws/greengrass/tes/CredentialRequestHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,18 @@ public class CredentialRequestHandler implements HttpHandler {
public static final String AUTH_HEADER = "Authorization";
public static final String IOT_CREDENTIALS_HTTP_VERB = "GET";
public static final String SUPPORTED_REQUEST_VERB = "GET";
public static final int TIME_BEFORE_CACHE_EXPIRE_IN_MIN = 5;
public static final int CLOUD_4XX_ERROR_CACHE_IN_MIN = 2;
public static final int CLOUD_5XX_ERROR_CACHE_IN_MIN = 1;
public static final int UNKNOWN_ERROR_CACHE_IN_MIN = 5;
public static final int TIME_BEFORE_CACHE_EXPIRE_IN_SEC = 300;
public static final int CLOUD_4XX_ERROR_CACHE_IN_SEC = 120;
public static final int CLOUD_5XX_ERROR_CACHE_IN_SEC = 60;
public static final int UNKNOWN_ERROR_CACHE_IN_SEC = 300;

public static final String CLOUD_4XX_ERROR_CACHE_TOPIC = "cloud4xxErrorCacheInSec";
public static final String CLOUD_5XX_ERROR_CACHE_TOPIC = "cloud5xxErrorCacheInSec";
public static final String UNKNOWN_ERROR_CACHE_TOPIC = "unknownErrorCacheInSec";

private int cloud4xxErrorCacheInSec = CLOUD_4XX_ERROR_CACHE_IN_SEC;
private int cloud5xxErrorCacheInSec = CLOUD_5XX_ERROR_CACHE_IN_SEC;
private int unknownErrorCacheInSec = UNKNOWN_ERROR_CACHE_IN_SEC;

private String iotCredentialsPath;

Expand Down Expand Up @@ -142,6 +150,12 @@ void setIotCredentialsPath(String iotRoleAlias) {
this.iotCredentialsPath = "/role-aliases/" + iotRoleAlias + "/credentials";
}

void configureCacheSettings(int cloud4xxErrorCache, int cloud5xxErrorCache, int unknownErrorCache) {
this.cloud4xxErrorCacheInSec = cloud4xxErrorCache;
this.cloud5xxErrorCacheInSec = cloud5xxErrorCache;
this.unknownErrorCacheInSec = unknownErrorCache;
}

@Override
@SuppressWarnings("PMD.AvoidCatchingThrowable")
public void handle(final HttpExchange exchange) throws IOException {
Expand Down Expand Up @@ -281,14 +295,14 @@ private byte[] getCredentialsBypassCache() {
LOGGER.atError().kv(IOT_CRED_PATH_KEY, iotCredentialsPath)
.log("Unable to cache expired credentials which expired at {}", expiry);
} else {
newExpiry = expiry.minus(Duration.ofMinutes(TIME_BEFORE_CACHE_EXPIRE_IN_MIN));
newExpiry = expiry.minus(Duration.ofSeconds(TIME_BEFORE_CACHE_EXPIRE_IN_SEC));
tesCache.get(iotCredentialsPath).responseCode = HttpURLConnection.HTTP_OK;

if (newExpiry.isBefore(Instant.now(clock))) {
LOGGER.atWarn().kv(IOT_CRED_PATH_KEY, iotCredentialsPath)
.log("Can't cache credentials as new credentials {} will "
+ "expire in less than {} minutes", expiry,
TIME_BEFORE_CACHE_EXPIRE_IN_MIN);
+ "expire in less than {} seconds", expiry,
TIME_BEFORE_CACHE_EXPIRE_IN_SEC);
} else {
LOGGER.atInfo().kv(IOT_CRED_PATH_KEY, iotCredentialsPath)
.log("Received IAM credentials that will be cached until {}", newExpiry);
Expand Down Expand Up @@ -318,7 +332,7 @@ private byte[] getCredentialsBypassCache() {
String responseString = "Failed to get connection";
response = responseString.getBytes(StandardCharsets.UTF_8);
// Use unknown error cache policy for SSL/TLS connection errors to prevent excessive retries
newExpiry = Instant.now(clock).plus(Duration.ofMinutes(UNKNOWN_ERROR_CACHE_IN_MIN));
newExpiry = Instant.now(clock).plus(Duration.ofSeconds(unknownErrorCacheInSec));
tesCache.get(iotCredentialsPath).responseCode = HttpURLConnection.HTTP_INTERNAL_ERROR;
tesCache.get(iotCredentialsPath).expiry = newExpiry;
tesCache.get(iotCredentialsPath).credentials = response;
Expand Down Expand Up @@ -421,16 +435,16 @@ private String parseExpiryFromResponse(final String credentials) throws AWSIotEx
}

private Instant getExpiryPolicyForErr(int statusCode) {
int expiryTime = UNKNOWN_ERROR_CACHE_IN_MIN; // In case of unrecognized cloud errors, back off
int expiryTime = unknownErrorCacheInSec; // In case of unrecognized cloud errors, back off
// Add caching Time-To-Live (TTL) for TES cloud errors
if (statusCode >= 400 && statusCode < 500) {
// 4xx retries are only meaningful unless a user action has been adopted, TTL should be longer
expiryTime = CLOUD_4XX_ERROR_CACHE_IN_MIN;
expiryTime = cloud4xxErrorCacheInSec;
} else if (statusCode >= 500 && statusCode < 600) {
// 5xx could be a temporary cloud unavailability, TTL should be shorter
expiryTime = CLOUD_5XX_ERROR_CACHE_IN_MIN;
expiryTime = cloud5xxErrorCacheInSec;
}
return Instant.now(clock).plus(Duration.ofMinutes(expiryTime));
return Instant.now(clock).plus(Duration.ofSeconds(expiryTime));
}

/**
Expand Down
63 changes: 63 additions & 0 deletions src/main/java/com/aws/greengrass/tes/TokenExchangeService.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ public class TokenExchangeService extends GreengrassService implements AwsCreden
private String iotRoleAlias;
private HttpServerImpl server;

private int cloud4xxErrorCache;
private int cloud5xxErrorCache;
private int unknownErrorCache;
private static final int MINIMUM_ERROR_CACHE_IN_SEC = 10;

private final AuthorizationHandler authZHandler;
private final CredentialRequestHandler credentialRequestHandler;

Expand Down Expand Up @@ -75,6 +80,56 @@ public TokenExchangeService(Topics topics,

this.authZHandler = authZHandler;
this.credentialRequestHandler = credentialRequestHandler;

cloud4xxErrorCache = validateCacheConfig(Coerce.toInt(config.lookup(
CONFIGURATION_CONFIG_KEY, CredentialRequestHandler.CLOUD_4XX_ERROR_CACHE_TOPIC).dflt(
CredentialRequestHandler.CLOUD_4XX_ERROR_CACHE_IN_SEC)),
CredentialRequestHandler.CLOUD_4XX_ERROR_CACHE_IN_SEC);
cloud5xxErrorCache = validateCacheConfig(Coerce.toInt(config.lookup(
CONFIGURATION_CONFIG_KEY, CredentialRequestHandler.CLOUD_5XX_ERROR_CACHE_TOPIC).dflt(
CredentialRequestHandler.CLOUD_5XX_ERROR_CACHE_IN_SEC)),
CredentialRequestHandler.CLOUD_5XX_ERROR_CACHE_IN_SEC);
unknownErrorCache = validateCacheConfig(Coerce.toInt(config.lookup(
CONFIGURATION_CONFIG_KEY, CredentialRequestHandler.UNKNOWN_ERROR_CACHE_TOPIC).dflt(
CredentialRequestHandler.UNKNOWN_ERROR_CACHE_IN_SEC)),
CredentialRequestHandler.UNKNOWN_ERROR_CACHE_IN_SEC);

credentialRequestHandler.configureCacheSettings(cloud4xxErrorCache, cloud5xxErrorCache, unknownErrorCache);

// Subscribe to cache configuration changes
config.subscribe((why, node) -> {
if (node != null && (node.childOf(CredentialRequestHandler.CLOUD_4XX_ERROR_CACHE_TOPIC)
|| node.childOf(CredentialRequestHandler.CLOUD_5XX_ERROR_CACHE_TOPIC)
|| node.childOf(CredentialRequestHandler.UNKNOWN_ERROR_CACHE_TOPIC))) {
logger.atDebug("tes-cache-config-change").kv("node", node).kv("why", why).log();

int newCloud4xxErrorCache = validateCacheConfig(Coerce.toInt(config.lookup(
CONFIGURATION_CONFIG_KEY, CredentialRequestHandler.CLOUD_4XX_ERROR_CACHE_TOPIC).dflt(
CredentialRequestHandler.CLOUD_4XX_ERROR_CACHE_IN_SEC)), cloud4xxErrorCache);
int newCloud5xxErrorCache = validateCacheConfig(Coerce.toInt(config.lookup(
CONFIGURATION_CONFIG_KEY, CredentialRequestHandler.CLOUD_5XX_ERROR_CACHE_TOPIC).dflt(
CredentialRequestHandler.CLOUD_5XX_ERROR_CACHE_IN_SEC)), cloud5xxErrorCache);
int newUnknownErrorCache = validateCacheConfig(Coerce.toInt(config.lookup(
CONFIGURATION_CONFIG_KEY, CredentialRequestHandler.UNKNOWN_ERROR_CACHE_TOPIC).dflt(
CredentialRequestHandler.UNKNOWN_ERROR_CACHE_IN_SEC)), unknownErrorCache);

if (cloud4xxErrorCache != newCloud4xxErrorCache
|| cloud5xxErrorCache != newCloud5xxErrorCache
|| unknownErrorCache != newUnknownErrorCache) {
cloud4xxErrorCache = newCloud4xxErrorCache;
cloud5xxErrorCache = newCloud5xxErrorCache;
unknownErrorCache = newUnknownErrorCache;
credentialRequestHandler.configureCacheSettings(
newCloud4xxErrorCache, newCloud5xxErrorCache, newUnknownErrorCache);

logger.atInfo("tes-cache-config-change").kv("unknownErrorCache", newUnknownErrorCache)
.kv("cloud4xxErrorCache", newCloud4xxErrorCache)
.kv("cloud5xxErrorCache", newCloud5xxErrorCache)
.log("Restarting TES server due to cache config change");
requestRestart();
}
}
});
}

@Override
Expand Down Expand Up @@ -130,6 +185,14 @@ private void validateConfig() {
}
}

private int validateCacheConfig(int newCacheValue, int oldCacheValue) {
if (newCacheValue < MINIMUM_ERROR_CACHE_IN_SEC) {
logger.atError().log("Error cache value must be at least {}", MINIMUM_ERROR_CACHE_IN_SEC);
return oldCacheValue;
}
return newCacheValue;
}

@Override
public AwsCredentials resolveCredentials() {
return credentialRequestHandler.getAwsCredentials();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import static com.aws.greengrass.tes.CredentialRequestHandler.CLOUD_4XX_ERROR_CACHE_IN_MIN;
import static com.aws.greengrass.tes.CredentialRequestHandler.CLOUD_5XX_ERROR_CACHE_IN_MIN;
import static com.aws.greengrass.tes.CredentialRequestHandler.TIME_BEFORE_CACHE_EXPIRE_IN_MIN;
import static com.aws.greengrass.tes.CredentialRequestHandler.UNKNOWN_ERROR_CACHE_IN_MIN;
import static com.aws.greengrass.tes.CredentialRequestHandler.CLOUD_4XX_ERROR_CACHE_IN_SEC;
import static com.aws.greengrass.tes.CredentialRequestHandler.CLOUD_5XX_ERROR_CACHE_IN_SEC;
import static com.aws.greengrass.tes.CredentialRequestHandler.TIME_BEFORE_CACHE_EXPIRE_IN_SEC;
import static com.aws.greengrass.tes.CredentialRequestHandler.UNKNOWN_ERROR_CACHE_IN_SEC;
import static com.aws.greengrass.testcommons.testutilities.ExceptionLogProtector.ignoreExceptionOfType;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString;
Expand Down Expand Up @@ -320,15 +320,15 @@ void GIVEN_credential_handler_WHEN_called_handle_THEN_caches_creds() throws Exce
verify(mockStream, times(1)).write(expectedResponse);

// Expiry time in recent future won't give error but there wil be no caching
expirationTime = Instant.now().plus(Duration.ofMinutes(TIME_BEFORE_CACHE_EXPIRE_IN_MIN - 1));
expirationTime = Instant.now().plus(Duration.ofSeconds(TIME_BEFORE_CACHE_EXPIRE_IN_SEC - 60));
responseStr = String.format(RESPONSE_STR, expirationTime.toString());
mockResponse = new IotCloudResponse(responseStr.getBytes(StandardCharsets.UTF_8), 200);
when(mockCloudHelper.sendHttpRequest(any(), any(), any(), any(), any())).thenReturn(mockResponse);
handler.handle(mockExchange);
verify(mockCloudHelper, times(2)).sendHttpRequest(any(), any(), any(), any(), any());

// Expiry time in future will result in credentials being cached
expirationTime = Instant.now().plus(Duration.ofMinutes(TIME_BEFORE_CACHE_EXPIRE_IN_MIN + 1));
expirationTime = Instant.now().plus(Duration.ofSeconds(TIME_BEFORE_CACHE_EXPIRE_IN_SEC + 60));
responseStr = String.format(RESPONSE_STR, expirationTime.toString());
mockResponse = new IotCloudResponse(responseStr.getBytes(StandardCharsets.UTF_8), 200);
when(mockCloudHelper.sendHttpRequest(any(), any(), any(), any(), any())).thenReturn(mockResponse);
Expand Down Expand Up @@ -401,7 +401,7 @@ void GIVEN_4xx_response_code_WHEN_called_handle_THEN_expire_in_2_minutes() throw
String.format("TES responded with status code: %d. Caching response. ", expectedStatus).getBytes();
// expire in 2 minutes
handler.getAwsCredentials();
Instant expirationTime = Instant.now().plus(Duration.ofMinutes(CLOUD_4XX_ERROR_CACHE_IN_MIN));
Instant expirationTime = Instant.now().plus(Duration.ofSeconds(CLOUD_4XX_ERROR_CACHE_IN_SEC));
Clock mockClock = Clock.fixed(expirationTime, ZoneId.of("UTC"));
handler.setClock(mockClock);
handler.getAwsCredentials();
Expand All @@ -425,7 +425,7 @@ void GIVEN_5xx_response_code_WHEN_called_handle_THEN_expire_in_1_minute() throws
String.format("TES responded with status code: %d. Caching response. ", expectedStatus).getBytes();
// expire in 1 minute
handler.getAwsCredentials();
Instant expirationTime = Instant.now().plus(Duration.ofMinutes(CLOUD_5XX_ERROR_CACHE_IN_MIN));
Instant expirationTime = Instant.now().plus(Duration.ofSeconds(CLOUD_5XX_ERROR_CACHE_IN_SEC));
Clock mockClock = Clock.fixed(expirationTime, ZoneId.of("UTC"));
handler.setClock(mockClock);
handler.getAwsCredentials();
Expand All @@ -449,7 +449,7 @@ void GIVEN_unknown_error_response_code_WHEN_called_handle_THEN_expire_in_5_minut
String.format("TES responded with status code: %d. Caching response. ", expectedStatus).getBytes();
// expire in 5 minutes
handler.getAwsCredentials();
Instant expirationTime = Instant.now().plus(Duration.ofMinutes(UNKNOWN_ERROR_CACHE_IN_MIN));
Instant expirationTime = Instant.now().plus(Duration.ofSeconds(UNKNOWN_ERROR_CACHE_IN_SEC));
Clock mockClock = Clock.fixed(expirationTime, ZoneId.of("UTC"));
handler.setClock(mockClock);
handler.getAwsCredentials();
Expand Down
63 changes: 63 additions & 0 deletions src/test/java/com/aws/greengrass/tes/TokenExchangeServiceTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@
import static com.aws.greengrass.lifecyclemanager.GreengrassService.SETENV_CONFIG_NAMESPACE;
import static com.aws.greengrass.lifecyclemanager.Kernel.SERVICE_TYPE_TOPIC_KEY;
import static com.aws.greengrass.lifecyclemanager.KernelCommandLine.MAIN_SERVICE_NAME;
import static com.aws.greengrass.tes.CredentialRequestHandler.CLOUD_4XX_ERROR_CACHE_TOPIC;
import static com.aws.greengrass.tes.CredentialRequestHandler.CLOUD_5XX_ERROR_CACHE_TOPIC;
import static com.aws.greengrass.tes.CredentialRequestHandler.UNKNOWN_ERROR_CACHE_TOPIC;
import static com.aws.greengrass.tes.CredentialRequestHandler.CLOUD_4XX_ERROR_CACHE_IN_SEC;
import static com.aws.greengrass.tes.CredentialRequestHandler.CLOUD_5XX_ERROR_CACHE_IN_SEC;
import static com.aws.greengrass.tes.CredentialRequestHandler.UNKNOWN_ERROR_CACHE_IN_SEC;
import static com.aws.greengrass.tes.TokenExchangeService.ACTIVE_PORT_TOPIC;
import static com.aws.greengrass.tes.TokenExchangeService.PORT_TOPIC;
import static com.aws.greengrass.tes.TokenExchangeService.TES_URI_ENV_VARIABLE_NAME;
Expand Down Expand Up @@ -156,6 +162,25 @@ void GIVEN_token_exchange_service_WHEN_started_THEN_correct_env_set(int port) th
return null;
});

Topic cloud4xxCacheTopic = mock(Topic.class);
when(cloud4xxCacheTopic.dflt(CLOUD_4XX_ERROR_CACHE_IN_SEC))
.thenReturn(cloud4xxCacheTopic);

Topic cloud5xxCacheTopic = mock(Topic.class);
when(cloud5xxCacheTopic.dflt(CLOUD_5XX_ERROR_CACHE_IN_SEC))
.thenReturn(cloud5xxCacheTopic);

Topic unknownCacheTopic = mock(Topic.class);
when(unknownCacheTopic.dflt(UNKNOWN_ERROR_CACHE_IN_SEC))
.thenReturn(unknownCacheTopic);

when(config.lookup(CONFIGURATION_CONFIG_KEY, CLOUD_4XX_ERROR_CACHE_TOPIC))
.thenReturn(cloud4xxCacheTopic);
when(config.lookup(CONFIGURATION_CONFIG_KEY, CLOUD_5XX_ERROR_CACHE_TOPIC))
.thenReturn(cloud5xxCacheTopic);
when(config.lookup(CONFIGURATION_CONFIG_KEY, UNKNOWN_ERROR_CACHE_TOPIC))
.thenReturn(unknownCacheTopic);

TokenExchangeService tes = new TokenExchangeService(config,
mockCredentialHandler,
mockAuthZHandler, deviceConfigurationWithRoleAlias(MOCK_ROLE_ALIAS));
Expand Down Expand Up @@ -205,6 +230,25 @@ void GIVEN_token_exchange_service_WHEN_started_with_empty_role_alias_THEN_server
when(configuration.lookup(SERVICES_NAMESPACE_TOPIC, DEFAULT_NUCLEUS_COMPONENT_NAME, CONFIGURATION_CONFIG_KEY,
IOT_ROLE_ALIAS_TOPIC)).thenReturn(roleTopic);

Topic cloud4xxCacheTopic = mock(Topic.class);
when(cloud4xxCacheTopic.dflt(CLOUD_4XX_ERROR_CACHE_IN_SEC))
.thenReturn(cloud4xxCacheTopic);

Topic cloud5xxCacheTopic = mock(Topic.class);
when(cloud5xxCacheTopic.dflt(CLOUD_5XX_ERROR_CACHE_IN_SEC))
.thenReturn(cloud5xxCacheTopic);

Topic unknownCacheTopic = mock(Topic.class);
when(unknownCacheTopic.dflt(UNKNOWN_ERROR_CACHE_IN_SEC))
.thenReturn(unknownCacheTopic);

when(config.lookup(CONFIGURATION_CONFIG_KEY, CLOUD_4XX_ERROR_CACHE_TOPIC))
.thenReturn(cloud4xxCacheTopic);
when(config.lookup(CONFIGURATION_CONFIG_KEY, CLOUD_5XX_ERROR_CACHE_TOPIC))
.thenReturn(cloud5xxCacheTopic);
when(config.lookup(CONFIGURATION_CONFIG_KEY, UNKNOWN_ERROR_CACHE_TOPIC))
.thenReturn(unknownCacheTopic);

TokenExchangeService tes = spy(new TokenExchangeService(config,
mockCredentialHandler,
mockAuthZHandler, deviceConfigurationWithRoleAlias(roleAlias)));
Expand Down Expand Up @@ -236,6 +280,25 @@ void GIVEN_token_exchange_service_WHEN_auth_errors_THEN_server_errors_out(Extens
when(configuration.lookup(SERVICES_NAMESPACE_TOPIC, DEFAULT_NUCLEUS_COMPONENT_NAME, CONFIGURATION_CONFIG_KEY,
IOT_ROLE_ALIAS_TOPIC)).thenReturn(roleTopic);

Topic cloud4xxCacheTopic = mock(Topic.class);
when(cloud4xxCacheTopic.dflt(CLOUD_4XX_ERROR_CACHE_IN_SEC))
.thenReturn(cloud4xxCacheTopic);

Topic cloud5xxCacheTopic = mock(Topic.class);
when(cloud5xxCacheTopic.dflt(CLOUD_5XX_ERROR_CACHE_IN_SEC))
.thenReturn(cloud5xxCacheTopic);

Topic unknownCacheTopic = mock(Topic.class);
when(unknownCacheTopic.dflt(UNKNOWN_ERROR_CACHE_IN_SEC))
.thenReturn(unknownCacheTopic);

when(config.lookup(CONFIGURATION_CONFIG_KEY, CLOUD_4XX_ERROR_CACHE_TOPIC))
.thenReturn(cloud4xxCacheTopic);
when(config.lookup(CONFIGURATION_CONFIG_KEY, CLOUD_5XX_ERROR_CACHE_TOPIC))
.thenReturn(cloud5xxCacheTopic);
when(config.lookup(CONFIGURATION_CONFIG_KEY, UNKNOWN_ERROR_CACHE_TOPIC))
.thenReturn(unknownCacheTopic);

TokenExchangeService tes = spy(new TokenExchangeService(config,
mockCredentialHandler,
mockAuthZHandler, deviceConfigurationWithRoleAlias("TEST")));
Expand Down
Loading