Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,24 @@
import net.schmizz.sshj.userauth.keyprovider.OpenSSHKeyFile;
import net.schmizz.sshj.userauth.keyprovider.PKCS8KeyFile;
import net.schmizz.sshj.userauth.method.AuthPublickey;
import org.bouncycastle.asn1.ASN1Primitive;
import org.bouncycastle.asn1.ASN1Sequence;
import org.bouncycastle.asn1.DERNull;
import org.bouncycastle.asn1.pkcs.PKCSObjectIdentifiers;
import org.bouncycastle.asn1.pkcs.PrivateKeyInfo;
import org.bouncycastle.asn1.pkcs.RSAPrivateKey;
import org.bouncycastle.asn1.x509.AlgorithmIdentifier;
import org.bouncycastle.jce.provider.BouncyCastleProvider;

import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.Reader;
import java.io.StringReader;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.nio.charset.StandardCharsets;
import java.security.Security;
import java.util.Base64;

import static com.appsmith.external.constants.ConnectionMethod.CONNECTION_METHOD_SSH;
import static com.appsmith.external.constants.PluginConstants.HostName.LOCALHOST;
Expand Down Expand Up @@ -76,7 +82,6 @@ public static SSHTunnelContext createSSHTunnel(
client.addHostKeyVerifier(new PromiscuousVerifier());

client.connect(sshHost, sshPort);
Reader targetReader = new InputStreamReader(new ByteArrayInputStream(key.getDecodedContent()));
String keyContent;
KeyProvider keyFile = null;
try (Reader reader = new StringReader(new String(key.getDecodedContent(), StandardCharsets.UTF_8));
Expand All @@ -102,11 +107,17 @@ public static SSHTunnelContext createSSHTunnel(
OpenSSHKeyFile openSSHKeyFile = new OpenSSHKeyFile();
openSSHKeyFile.init(new StringReader(keyContent));
keyFile = openSSHKeyFile;
} else if (keyContent.contains(PKCS_8_PEM_HEADER) || keyContent.contains(PKCS_1_PEM_HEADER)) {
// Handle PEM (PKCS#8) and RSA PEM formats
} else if (keyContent.contains(PKCS_8_PEM_HEADER)) {
// Handle PEM (PKCS#8) format
PKCS8KeyFile pkcs8KeyFile = new PKCS8KeyFile();
pkcs8KeyFile.init(new StringReader(keyContent));
keyFile = pkcs8KeyFile;
} else if (keyContent.contains(PKCS_1_PEM_HEADER)) {
// Handle traditional RSA (PKCS#1) format by converting it to PKCS#8
String pkcs8FormattedKey = convertRsaPkcs1ToPkcs8(keyContent);
PKCS8KeyFile pkcs8KeyFile = new PKCS8KeyFile();
pkcs8KeyFile.init(new StringReader(pkcs8FormattedKey));
keyFile = pkcs8KeyFile;
} else {
throw new AppsmithPluginException(
AppsmithPluginError.PLUGIN_DATASOURCE_ARGUMENT_ERROR, INVALID_SSH_KEY_FORMAT_ERROR_MSG);
Expand Down Expand Up @@ -232,4 +243,39 @@ public static Boolean isSSHTunnelConnected(SSHTunnelContext sshTunnelContext) {
SSHClient sshClient = sshTunnelContext.getSshClient();
return sshClient != null && sshClient.isConnected() && sshClient.isAuthenticated();
}

static String convertRsaPkcs1ToPkcs8(String keyContent) throws IOException {
String sanitizedContent = keyContent
.replace("-----BEGIN RSA PRIVATE KEY-----", "")
.replace("-----END RSA PRIVATE KEY-----", "")
.replaceAll("\\s", "");

if (sanitizedContent.isEmpty()) {
throw new IOException("Empty RSA key content");
}

byte[] pkcs1Bytes;
try {
pkcs1Bytes = Base64.getDecoder().decode(sanitizedContent);
} catch (IllegalArgumentException e) {
throw new IOException("Invalid Base64 encoding for RSA key", e);
}

ASN1Primitive asn1Primitive = ASN1Primitive.fromByteArray(pkcs1Bytes);
if (!(asn1Primitive instanceof ASN1Sequence)) {
throw new IOException("Invalid RSA key structure");
}
ASN1Sequence asn1Sequence = (ASN1Sequence) asn1Primitive;
RSAPrivateKey rsaPrivateKey = RSAPrivateKey.getInstance(asn1Sequence);

PrivateKeyInfo privateKeyInfo = new PrivateKeyInfo(
new AlgorithmIdentifier(PKCSObjectIdentifiers.rsaEncryption, DERNull.INSTANCE), rsaPrivateKey);

return toPemPrivateKey(privateKeyInfo.getEncoded());
}

private static String toPemPrivateKey(byte[] pkcs8Bytes) {
String base64Encoded = Base64.getMimeEncoder(64, new byte[] {'\n'}).encodeToString(pkcs8Bytes);
return "-----BEGIN PRIVATE KEY-----\n" + base64Encoded + "\n-----END PRIVATE KEY-----\n";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,21 @@
import com.appsmith.external.models.Property;
import com.appsmith.external.models.SSHConnection;
import net.schmizz.sshj.SSHClient;
import net.schmizz.sshj.userauth.keyprovider.OpenSSHKeyFile;
import net.schmizz.sshj.userauth.keyprovider.PKCS8KeyFile;
import org.bouncycastle.asn1.pkcs.RSAPrivateKey;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

import java.io.IOException;
import java.io.Reader;
import java.io.StringReader;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.Security;
import java.security.interfaces.RSAPrivateCrtKey;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;

import static com.appsmith.external.helpers.SSHUtils.getConnectionContext;
Expand All @@ -31,28 +36,15 @@ public class SSHUtilsTest {

@BeforeAll
static void setup() {
Security.addProvider(new BouncyCastleProvider()); // Ensure BouncyCastle is available for OpenSSH keys
}

/* Test OpenSSH Key Parsing */
@Test
public void testOpenSSHKeyParsing() throws Exception {
String opensshKey = "-----BEGIN OPENSSH PRIVATE KEY-----\n"
+ "b3BlbnNzaC1rZXktdmVyc2lvbjE=\n"
+ "-----END OPENSSH PRIVATE KEY-----";

Reader reader = new StringReader(opensshKey);
OpenSSHKeyFile openSSHKeyFile = new OpenSSHKeyFile();
openSSHKeyFile.init(reader);

assertNotNull(openSSHKeyFile);
Security.addProvider(
new BouncyCastleProvider()); // Ensure BouncyCastle algorithms are registered for key parsing
}

/* Test PKCS#8 PEM Key Parsing */
@Test
public void testPKCS8PEMKeyParsing() throws Exception {
String pkcs8Key =
"-----BEGIN PRIVATE KEY-----\n" + "MIIEvQIBADANBgkqhkiG9w0BAQEFAASC...\n" + "-----END PRIVATE KEY-----";
KeyPair keyPair = generateRsaKeyPair();
String pkcs8Key = toPkcs8Pem(keyPair);

Reader reader = new StringReader(pkcs8Key);
PKCS8KeyFile pkcs8KeyFile = new PKCS8KeyFile();
Expand All @@ -64,14 +56,17 @@ public void testPKCS8PEMKeyParsing() throws Exception {
/* Test RSA PEM Key Parsing */
@Test
public void testRSAPEMKeyParsing() throws Exception {
String rsaKey =
"-----BEGIN RSA PRIVATE KEY-----\n" + "MIIEowIBAAKCAQEA7...\n" + "-----END RSA PRIVATE KEY-----";
KeyPair keyPair = generateRsaKeyPair();
String rsaPkcs1 = toPkcs1Pem((RSAPrivateCrtKey) keyPair.getPrivate());

String convertedKey = SSHUtils.convertRsaPkcs1ToPkcs8(rsaPkcs1);

Reader reader = new StringReader(rsaKey);
Reader reader = new StringReader(convertedKey);
PKCS8KeyFile pkcs8KeyFile = new PKCS8KeyFile();
pkcs8KeyFile.init(reader);

assertNotNull(pkcs8KeyFile);
assertTrue(convertedKey.contains("BEGIN PRIVATE KEY"));
}

/* Test is ssh enabled method */
Expand Down Expand Up @@ -168,4 +163,34 @@ public void testDefaultDBPortValue() {

assertEquals(getDBPortFromConfigOrDefault(datasourceConfiguration, 1234L), 1234L);
}

private KeyPair generateRsaKeyPair() throws Exception {
KeyPairGenerator generator = KeyPairGenerator.getInstance("RSA");
generator.initialize(1024);
return generator.generateKeyPair();
}

private String toPkcs8Pem(KeyPair keyPair) {
byte[] pkcs8Bytes = keyPair.getPrivate().getEncoded();
return formatPem("PRIVATE KEY", pkcs8Bytes);
}

private String toPkcs1Pem(RSAPrivateCrtKey privateKey) throws IOException {
RSAPrivateKey bcPrivateKey = new RSAPrivateKey(
privateKey.getModulus(),
privateKey.getPublicExponent(),
privateKey.getPrivateExponent(),
privateKey.getPrimeP(),
privateKey.getPrimeQ(),
privateKey.getPrimeExponentP(),
privateKey.getPrimeExponentQ(),
privateKey.getCrtCoefficient());

return formatPem("RSA PRIVATE KEY", bcPrivateKey.getEncoded());
}

private String formatPem(String header, byte[] encodedBytes) {
String base64 = Base64.getMimeEncoder(64, new byte[] {'\n'}).encodeToString(encodedBytes);
return "-----BEGIN " + header + "-----\n" + base64 + "\n-----END " + header + "-----\n";
}
}