Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(uat): resolve issue when key file in PCKS#8 format can't be read #384

Merged
merged 3 commits into from
Aug 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,9 @@ private MqttConnectOptions convertParams(MqttLib.ConnectionParams connectionPara
connectionParams.getCert() != null);
connectionOptions.setServerURIs(new String[]{uri});

if (connectionParams.getKey() != null) {
SSLSocketFactory sslSocketFactory = SslUtil.getSocketFactory(connectionParams);
if (connectionParams.getCert() != null) {
SSLSocketFactory sslSocketFactory = SslUtil.getSocketFactory(
connectionParams.getCa(), connectionParams.getCert(), connectionParams.getKey());
connectionOptions.setSocketFactory(sslSocketFactory);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ class ConnectionParams {
/** Clean session (clean start) flag of CONNECT packet. */
private boolean cleanSession;

/** Content of CA, optional. */
private String ca;
/** List of CA, optional. */
private List<String> ca;

/** Content of MQTT client's certificate, optional. */
private String cert;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,7 @@ public void createMqttConnection(MqttConnectRequest request,
return;
}

final String ca = String.join("\n", caList);
connectionParamsBuilder.ca(ca).cert(cert).key(key);
connectionParamsBuilder.ca(caList).cert(cert).key(key);
}

logger.atInfo().log("createMqttConnection: clientId {} broker {}:{}", clientId, host, port);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,9 @@ private MqttConnectionOptions convertConnectParams(MqttLib.ConnectionParams conn
connectionOptions.setServerURIs(new String[]{uri});
connectionOptions.setConnectionTimeout(connectionParams.getConnectionTimeout());

if (connectionParams.getKey() != null) {
SSLSocketFactory sslSocketFactory = SslUtil.getSocketFactory(connectionParams);
if (connectionParams.getCert() != null) {
SSLSocketFactory sslSocketFactory = SslUtil.getSocketFactory(
connectionParams.getCa(), connectionParams.getCert(), connectionParams.getKey());
connectionOptions.setSocketFactory(sslSocketFactory);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,105 +5,132 @@

package com.aws.greengrass.testing.util;

import com.aws.greengrass.testing.mqtt5.client.MqttLib;
import lombok.experimental.UtilityClass;
import org.bouncycastle.cert.X509CertificateHolder;
import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.bouncycastle.openssl.PEMKeyPair;
import org.bouncycastle.openssl.PEMParser;
import org.bouncycastle.openssl.jcajce.JcaPEMKeyConverter;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.security.GeneralSecurityException;
import java.security.KeyPair;
import java.security.KeyFactory;
import java.security.KeyStore;
import java.security.PrivateKey;
import java.security.Security;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.security.cert.Certificate;
import java.security.spec.PKCS8EncodedKeySpec;
import java.util.Base64;
import java.util.List;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManagerFactory;

@UtilityClass
public final class SslUtil {
private SslUtil() {
}
private static final char[] PASSWORD = "".toCharArray();

/**
* generate SSL socket factory.
*
* @param connectionParams MQTT connection parameters
* @throws IOException on errors
* @throws GeneralSecurityException on errors
*/
public static SSLSocketFactory getSocketFactory(MqttLib.ConnectionParams connectionParams)
throws IOException, GeneralSecurityException {
return getSocketFactory(connectionParams.getCa(), connectionParams.getCert(), connectionParams.getKey());
// PKCS#8 format
private static final String PEM_PRIVATE_START = "-----BEGIN PRIVATE KEY-----";
private static final String PEM_PRIVATE_END = "-----END PRIVATE KEY-----";

// PKCS#1 format
private static final String PEM_RSA_PRIVATE_START = "-----BEGIN RSA PRIVATE KEY-----";
private static final String PEM_RSA_PRIVATE_END = "-----END RSA PRIVATE KEY-----";


static {
Security.addProvider(new BouncyCastleProvider());
}

/**
* generate SSL socket factory.
* Generates SSL socket factory.
*
* @param caCrtFile certificate authority
* @param crtFile certification
* @param keyFile private key
* @param caList the list of certificate authority
* @param crt the certificate
* @param key the private key
* @return instance of ssl socket factory with attached client credential and CA list
* @throws IOException on errors
* @throws GeneralSecurityException on errors
*/
public static SSLSocketFactory getSocketFactory(final String caCrtFile, final String crtFile, final String keyFile)
public static SSLSocketFactory getSocketFactory(final List<String> caList, final String crt, final String key)
throws IOException, GeneralSecurityException {
Security.addProvider(new BouncyCastleProvider());

// load CA certificate
X509Certificate caCert = null;

InputStream bis = new ByteArrayInputStream(caCrtFile.getBytes());
CertificateFactory cf = CertificateFactory.getInstance("X.509");
TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance("X509");

while (bis.available() > 0) {
caCert = (X509Certificate) cf.generateCertificate(bis);
}
// CA certificates are used to authenticate server
KeyStore caKeyStore = KeyStore.getInstance(KeyStore.getDefaultType());
caKeyStore.load(null, null);

// load client certificate
bis = new ByteArrayInputStream(crtFile.getBytes());
X509Certificate cert = null;
while (bis.available() > 0) {
cert = (X509Certificate) cf.generateCertificate(bis);
// load CA certificates
int certNo = 0;
for (String ca : caList) {
Certificate caCertificate = getCerificate(ca);
caKeyStore.setCertificateEntry(String.format("ca-certificate-%d", ++certNo), caCertificate);
}

// CA certificate is used to authenticate server
KeyStore caKs = KeyStore.getInstance(KeyStore.getDefaultType());
caKs.load(null, null);
caKs.setCertificateEntry("ca-certificate", caCert);
TrustManagerFactory tmf = TrustManagerFactory.getInstance("X509");
tmf.init(caKs);
trustManagerFactory.init(caKeyStore);

// client key and certificates are sent to server so it can authenticate
KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType());
ks.load(null, null);
ks.setCertificateEntry("certificate", cert);
KeyStore credKeyStore = KeyStore.getInstance(KeyStore.getDefaultType());
credKeyStore.load(null, null);
Certificate certificate = getCerificate(crt);
credKeyStore.setCertificateEntry("certificate", certificate);

// load client private key
Object object;
KeyPair key;
try (PEMParser pemParser = new PEMParser(new InputStreamReader(new ByteArrayInputStream(keyFile.getBytes())))) {
object = pemParser.readObject();
JcaPEMKeyConverter converter = new JcaPEMKeyConverter().setProvider("BC");
key = converter.getKeyPair((PEMKeyPair) object);
PrivateKey privateKey = loadPrivateKey(key);
credKeyStore.setKeyEntry("private-key", privateKey, PASSWORD, new Certificate[]{certificate});

}
ks.setKeyEntry("private-key", key.getPrivate(), "".toCharArray(),
new java.security.cert.Certificate[]{cert});
KeyManagerFactory kmf = KeyManagerFactory.getInstance(KeyManagerFactory
.getDefaultAlgorithm());
kmf.init(ks, "".toCharArray());
KeyManagerFactory keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
keyManagerFactory.init(credKeyStore, PASSWORD);

// finally, create SSL socket factory
SSLContext context = SSLContext.getInstance("TLSv1.2");
context.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);
// FIXME: probably that force to use only TLS 1.2
// SSLContext sslContext = SSLContext.getInstance("TLSv1.2");
SSLContext sslContext = SSLContext.getInstance("TLS");
sslContext.init(keyManagerFactory.getKeyManagers(), trustManagerFactory.getTrustManagers(), null);

return context.getSocketFactory();
return sslContext.getSocketFactory();
}

private static Certificate getCerificate(final String certificatePem)
throws IOException, GeneralSecurityException {

try (PEMParser pemParser = new PEMParser(
new InputStreamReader(
new ByteArrayInputStream(certificatePem.getBytes())))) {
X509CertificateHolder certHolder = (X509CertificateHolder) pemParser.readObject();
JcaX509CertificateConverter certificateConverter = new JcaX509CertificateConverter().setProvider("BC");
return certificateConverter.getCertificate(certHolder);
}
}

private static PrivateKey loadPrivateKey(String privateKeyPem) throws GeneralSecurityException, IOException {
if (privateKeyPem.contains(PEM_PRIVATE_START)) {
// PKCS#8 format
privateKeyPem = privateKeyPem.replace(PEM_PRIVATE_START, "").replace(PEM_PRIVATE_END, "");
privateKeyPem = privateKeyPem.replaceAll("\\s", "");

byte[] pkcs8EncodedKey = Base64.getDecoder().decode(privateKeyPem);

KeyFactory factory = KeyFactory.getInstance("RSA");
return factory.generatePrivate(new PKCS8EncodedKeySpec(pkcs8EncodedKey));

} else if (privateKeyPem.contains(PEM_RSA_PRIVATE_START)) {
// PKCS#1 format
try (PEMParser pemParser = new PEMParser(
new InputStreamReader(
new ByteArrayInputStream(privateKeyPem.getBytes())))) {
Object object = pemParser.readObject();
JcaPEMKeyConverter converter = new JcaPEMKeyConverter().setProvider("BC");
return converter.getKeyPair((PEMKeyPair) object).getPrivate();
}
}

throw new GeneralSecurityException("Not supported format of a private key");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class ConnectionParams {
/** Clean session (clean start) flag of CONNECT packet. */
private boolean cleanSession;

/** Content of CA, optional. */
/** Content of CA list joined by \n, optional. */
private String ca;

/** Content of MQTT client's certificate, optional. */
Expand Down
Loading