diff --git a/README.md b/README.md index 2c1232e..91f14ae 100644 --- a/README.md +++ b/README.md @@ -150,6 +150,7 @@ JavaReceiverInputDStream receiverStream = RabbitMQUtils.createJavaStream[R](java |---------------------------|------------------------------|--------------------------------------| | hosts | RabbitMQ hosts | Yes (default: localhost) | | virtualHosts | RabbitMQ virtual Host | Yes | +| sslProtocol | SSL Protocol | Yes (default: No SSL connection) | | queueName | Queue name | Yes | | exchangeName | Exchange name | Yes | | exchangeType | Exchange type | Yes | diff --git a/src/main/scala/org/apache/spark/streaming/rabbitmq/ConfigParameters.scala b/src/main/scala/org/apache/spark/streaming/rabbitmq/ConfigParameters.scala index 764c91b..37b3cfa 100644 --- a/src/main/scala/org/apache/spark/streaming/rabbitmq/ConfigParameters.scala +++ b/src/main/scala/org/apache/spark/streaming/rabbitmq/ConfigParameters.scala @@ -33,7 +33,8 @@ object ConfigParameters { val VirtualHostKey = "virtualHost" val UserNameKey = "userName" val PasswordKey = "password" - val ConnectionKeys = List(HostsKey, VirtualHostKey, UserNameKey, PasswordKey) + val SslProtocolKey = "sslProtocol" + val ConnectionKeys = List(HostsKey, VirtualHostKey, UserNameKey, PasswordKey, SslProtocolKey) /** * Queue Connection properties @@ -59,6 +60,7 @@ object ConfigParameters { val AutoAckType = "auto" val DefaultHost = "localhost" val DefaultPrefetchCount = 1 + val DefaultSslProtocol = null /** * Message Consumed properties diff --git a/src/main/scala/org/apache/spark/streaming/rabbitmq/consumer/Consumer.scala b/src/main/scala/org/apache/spark/streaming/rabbitmq/consumer/Consumer.scala index c464386..c30c056 100644 --- a/src/main/scala/org/apache/spark/streaming/rabbitmq/consumer/Consumer.scala +++ b/src/main/scala/org/apache/spark/streaming/rabbitmq/consumer/Consumer.scala @@ -247,6 +247,9 @@ object Consumer extends Logging with ConsumerParamsUtils { private def getChannel(params: Map[String, String]): Try[Channel] = { val addresses = getAddresses(params) + if (useSslConnection(params)) { + factory.useSslProtocol(getSslProtocol(params)) + } val addressesKey = addresses.mkString(",") val connection = connections.getOrElse(addressesKey, addConnection(addressesKey, addresses)) diff --git a/src/main/scala/org/apache/spark/streaming/rabbitmq/consumer/ConsumerParamsUtils.scala b/src/main/scala/org/apache/spark/streaming/rabbitmq/consumer/ConsumerParamsUtils.scala index 124b8d3..f2d76e3 100644 --- a/src/main/scala/org/apache/spark/streaming/rabbitmq/consumer/ConsumerParamsUtils.scala +++ b/src/main/scala/org/apache/spark/streaming/rabbitmq/consumer/ConsumerParamsUtils.scala @@ -73,6 +73,14 @@ trait ConsumerParamsUtils { ExchangeAndRouting(exchangeName, exchangeType, routingKeys) } + def useSslConnection(params: Map[String, String]): Boolean = { + params.get(SslProtocolKey).orNull != null + } + + def getSslProtocol(params: Map[String, String]): String = { + params.getOrElse(SslProtocolKey, DefaultSslProtocol) + } + /** * Queue Properties */ diff --git a/src/test/scala/org/apache/spark/streaming/rabbitmq/RabbitMQConsumerIT.scala b/src/test/scala/org/apache/spark/streaming/rabbitmq/RabbitMQConsumerIT.scala index 58ae168..dfa7e5d 100644 --- a/src/test/scala/org/apache/spark/streaming/rabbitmq/RabbitMQConsumerIT.scala +++ b/src/test/scala/org/apache/spark/streaming/rabbitmq/RabbitMQConsumerIT.scala @@ -27,16 +27,30 @@ class RabbitMQConsumerIT extends TemporalDataSuite { override val exchangeName = s"$configExchangeName-${this.getClass().getName()}-${UUID.randomUUID().toString}" - test("RabbitMQ Receiver should read all the records") { + test("RabbitMQ Receiver should read all the records without SSL") { + testReadRecords(5672, null) + } + + test("RabbitMQ Receiver should read all the records with SSL") { + testReadRecords(5671, "tlsv1.2") + } + + private def hostsWithPort(port: Int): String = hosts + .split(",") + .map(h => if(h.contains(":")) h.split(":")(0) else h) + .map(_ + ":" + port) + .mkString(",") + private def testReadRecords(port: Int, ssl: String): Unit = { val receiverStream = RabbitMQUtils.createStream(ssc, Map( - "hosts" -> hosts, + "hosts" -> hostsWithPort(port), "queueName" -> queueName, "exchangeName" -> exchangeName, "exchangeType" -> exchangeType, "vHost" -> vHost, "userName" -> userName, - "password" -> password + "password" -> password, + "sslProtocol" -> ssl )) val totalEvents = ssc.sparkContext.longAccumulator("My Accumulator")