diff --git a/srt/src/main/java/com/pedro/srt/srt/CommandsManager.kt b/srt/src/main/java/com/pedro/srt/srt/CommandsManager.kt index 5292875f8..ca7c1865d 100644 --- a/srt/src/main/java/com/pedro/srt/srt/CommandsManager.kt +++ b/srt/src/main/java/com/pedro/srt/srt/CommandsManager.kt @@ -73,6 +73,8 @@ class CommandsManager { return encryptor?.type ?: EncryptionType.NONE } + fun encryptionEnabled() = encryptor != null + fun loadStartTs() { startTS = TimeUtils.getCurrentTimeMicro() } diff --git a/srt/src/main/java/com/pedro/srt/srt/SrtClient.kt b/srt/src/main/java/com/pedro/srt/srt/SrtClient.kt index 5ea80b22f..f88fc8819 100644 --- a/srt/src/main/java/com/pedro/srt/srt/SrtClient.kt +++ b/srt/src/main/java/com/pedro/srt/srt/SrtClient.kt @@ -140,7 +140,7 @@ class SrtClient(private val connectChecker: ConnectChecker) { */ fun setPassphrase(passphrase: String, type: EncryptionType) { if (!isStreaming) { - if (passphrase.length < 10 || passphrase.length > 79) { + if (passphrase.length !in 10..79) { throw IllegalArgumentException("passphrase must between 10 and 79 length") } commandsManager.setPassphrase(passphrase, type) @@ -213,6 +213,15 @@ class SrtClient(private val connectChecker: ConnectChecker) { val port = urlParser.port ?: 8888 val path = urlParser.getQuery("streamid") ?: urlParser.getFullPath() latency = urlParser.getQuery("latency")?.toIntOrNull() ?: latency + val passphrase = urlParser.getQuery("passphrase") ?: "" + if (passphrase.isNotEmpty() && passphrase.length in 10..79) { + val encryptionType = when (urlParser.getQuery("pbkeylen")?.toIntOrNull()) { + 192 -> EncryptionType.AES192 + 256 -> EncryptionType.AES256 + else -> EncryptionType.AES128 + } + commandsManager.setPassphrase(passphrase, encryptionType) + } if (path.isEmpty()) { isStreaming = false onMainThread { @@ -232,7 +241,7 @@ class SrtClient(private val connectChecker: ConnectChecker) { commandsManager.writeHandshake(socket, response.copy( encryption = commandsManager.getEncryptType(), - extensionField = ExtensionField.calculateValue(response.extensionField), + extensionField = ExtensionField.calculateValue(response.extensionField, commandsManager.encryptionEnabled()), handshakeType = HandshakeType.CONCLUSION, handshakeExtension = HandshakeExtension( flags = ExtensionContentFlag.TSBPDSND.value or ExtensionContentFlag.TSBPDRCV.value or diff --git a/srt/src/main/java/com/pedro/srt/srt/packets/control/handshake/ExtensionField.kt b/srt/src/main/java/com/pedro/srt/srt/packets/control/handshake/ExtensionField.kt index b2a08b809..79aecefde 100644 --- a/srt/src/main/java/com/pedro/srt/srt/packets/control/handshake/ExtensionField.kt +++ b/srt/src/main/java/com/pedro/srt/srt/packets/control/handshake/ExtensionField.kt @@ -27,9 +27,9 @@ enum class ExtensionField(val value: Int) { companion object { infix fun from(value: Int): ExtensionField = entries.firstOrNull { it.value == value } ?: throw IOException("unknown extension field: $value") - fun calculateValue(value: Int): Int { + fun calculateValue(value: Int, encrypted: Boolean): Int { val hsV5enabled = value and HS_V5_FLAG.value != 0 - val extensionField = HS_REQ.value or CONFIG.value + val extensionField = if (encrypted) HS_REQ.value or KM_REQ.value or CONFIG.value else HS_REQ.value or CONFIG.value return if (hsV5enabled) HS_V5_FLAG.value or KM_REQ.value or extensionField else extensionField } } diff --git a/srt/src/main/java/com/pedro/srt/utils/EncryptionUtil.kt b/srt/src/main/java/com/pedro/srt/utils/EncryptionUtil.kt index 0e2467e13..ed0a609ec 100644 --- a/srt/src/main/java/com/pedro/srt/utils/EncryptionUtil.kt +++ b/srt/src/main/java/com/pedro/srt/utils/EncryptionUtil.kt @@ -40,7 +40,7 @@ class EncryptionUtil(val type: EncryptionType, passphrase: String) { private val cipherType = CipherType.CTR private val salt: ByteArray private val keyLength: Int = when (type) { - EncryptionType.NONE -> 0 + EncryptionType.NONE, EncryptionType.AES128 -> 16 EncryptionType.AES192 -> 24 EncryptionType.AES256 -> 32 diff --git a/srt/src/test/java/com/pedro/srt/srt/control/HandshakeTest.kt b/srt/src/test/java/com/pedro/srt/srt/control/HandshakeTest.kt index b7d402736..26cf6b67c 100644 --- a/srt/src/test/java/com/pedro/srt/srt/control/HandshakeTest.kt +++ b/srt/src/test/java/com/pedro/srt/srt/control/HandshakeTest.kt @@ -34,8 +34,9 @@ class HandshakeTest { @Test fun `test extension field calculation`() { - assertEquals(131079, ExtensionField.calculateValue(150039)) - assertEquals(5, ExtensionField.calculateValue(18967)) + assertEquals(131079, ExtensionField.calculateValue(150039, false)) + assertEquals(5, ExtensionField.calculateValue(18967, false)) + assertEquals(7, ExtensionField.calculateValue(18967, true)) } @Test