diff --git a/src/esaml_sp.erl b/src/esaml_sp.erl index 7d7844b..6ddfa82 100644 --- a/src/esaml_sp.erl +++ b/src/esaml_sp.erl @@ -316,19 +316,48 @@ decrypt_assertion(Xml, #esaml_sp{key = PrivateKey}) -> [EncryptedData] = xmerl_xpath:string("./xenc:EncryptedData", Xml, [{namespace, XencNs}]), [#xmlText{value = CipherValue64}] = xmerl_xpath:string("xenc:CipherData/xenc:CipherValue/text()", EncryptedData, [{namespace, XencNs}]), CipherValue = base64:decode(CipherValue64), - SymmetricKey = decrypt_key_info(EncryptedData, PrivateKey), + SymmetricKey = decrypt_key_info(EncryptedData, Xml, PrivateKey), [#xmlAttribute{value = Algorithm}] = xmerl_xpath:string("./xenc:EncryptionMethod/@Algorithm", EncryptedData, [{namespace, XencNs}]), AssertionXml = block_decrypt(Algorithm, SymmetricKey, CipherValue), {Assertion, _} = xmerl_scan:string(AssertionXml, [{namespace_conformant, true}]), Assertion. -decrypt_key_info(EncryptedData, Key) -> +decrypt_key_info(EncryptedData, EncryptedAssertion, Key) -> DsNs = [{"ds", 'http://www.w3.org/2000/09/xmldsig#'}], XencNs = [{"xenc", 'http://www.w3.org/2001/04/xmlenc#'}], [KeyInfo] = xmerl_xpath:string("./ds:KeyInfo", EncryptedData, [{namespace, DsNs}]), - [#xmlAttribute{value = Algorithm}] = xmerl_xpath:string("./xenc:EncryptedKey/xenc:EncryptionMethod/@Algorithm", KeyInfo, [{namespace, XencNs}]), - [#xmlText{value = CipherValue64}] = xmerl_xpath:string("./xenc:EncryptedKey/xenc:CipherData/xenc:CipherValue/text()", KeyInfo, [{namespace, XencNs}]), + + %% Try standard nested EncryptedKey first (backward compatibility) + case xmerl_xpath:string("./xenc:EncryptedKey", KeyInfo, [{namespace, XencNs}]) of + [EncryptedKey | _] -> + %% Standard pattern: EncryptedKey nested in KeyInfo + extract_symmetric_key(EncryptedKey, Key, XencNs); + [] -> + %% Try RetrievalMethod pattern (Okta) + case xmerl_xpath:string("./ds:RetrievalMethod/@URI", KeyInfo, [{namespace, DsNs}]) of + [#xmlAttribute{value = URI}] -> + %% Extract ID from URI (remove leading # if present) + KeyId = case URI of + [$# | Rest] -> Rest; + _ -> URI + end, + %% Search for sibling EncryptedKey by ID within EncryptedAssertion + XPath = lists:flatten(io_lib:format("./xenc:EncryptedKey[@Id='~s']", [KeyId])), + case xmerl_xpath:string(XPath, EncryptedAssertion, [{namespace, XencNs}]) of + [EncryptedKey | _] -> + extract_symmetric_key(EncryptedKey, Key, XencNs); + [] -> + error({encrypted_key_not_found, KeyId}) + end; + [] -> + error(no_encrypted_key_found) + end + end. + +extract_symmetric_key(EncryptedKey, Key, XencNs) -> + [#xmlAttribute{value = Algorithm}] = xmerl_xpath:string("./xenc:EncryptionMethod/@Algorithm", EncryptedKey, [{namespace, XencNs}]), + [#xmlText{value = CipherValue64}] = xmerl_xpath:string("./xenc:CipherData/xenc:CipherValue/text()", EncryptedKey, [{namespace, XencNs}]), CipherValue = base64:decode(CipherValue64), decrypt(CipherValue, Algorithm, Key).