diff --git a/README.md b/README.md
index 55ec8e7..a60f5fb 100644
--- a/README.md
+++ b/README.md
@@ -219,12 +219,21 @@ A possible solution to remove duplicates when reading the written data could be
- `admin.url` (Deprecated) |
+ `admin.url` |
A service HTTP URL of your Pulsar cluster |
No |
None |
Streaming and Batch |
- The Pulsar `serviceHttpUrl` configuration. |
+ The Pulsar `serviceHttpUrl` configuration. Only needed when `maxBytesPerTrigger` is specified |
+
+
+
+ `maxBytesPerTrigger` |
+ A long value in unit of number of bytes |
+ No |
+ None |
+ Streaming and Batch |
+ A soft limit of the maximum number of bytes we want to process per microbatch. If this is specified, `admin.url` also needs to be specified. |
diff --git a/src/main/scala/org/apache/spark/sql/pulsar/PulsarHelper.scala b/src/main/scala/org/apache/spark/sql/pulsar/PulsarHelper.scala
index 275c4d6..fe8bdc5 100644
--- a/src/main/scala/org/apache/spark/sql/pulsar/PulsarHelper.scala
+++ b/src/main/scala/org/apache/spark/sql/pulsar/PulsarHelper.scala
@@ -22,16 +22,22 @@ import scala.collection.mutable
import scala.language.postfixOps
import scala.util.control.NonFatal
+import org.apache.pulsar.client.admin.PulsarAdmin
import org.apache.pulsar.client.api.{MessageId, PulsarClient}
import org.apache.pulsar.client.impl.{MessageIdImpl, PulsarClientImpl}
import org.apache.pulsar.client.impl.schema.BytesSchema
+import org.apache.pulsar.client.internal.DefaultImplementation
import org.apache.pulsar.common.api.proto.CommandGetTopicsOfNamespace
import org.apache.pulsar.common.naming.TopicName
import org.apache.pulsar.common.schema.SchemaInfo
import org.apache.pulsar.shade.com.google.common.util.concurrent.Uninterruptibles
import org.apache.spark.internal.Logging
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.connector.read.streaming.{ReadAllAvailable, ReadLimit}
import org.apache.spark.sql.pulsar.PulsarOptions._
+import org.apache.spark.sql.pulsar.PulsarSourceUtils.{getEntryId, getLedgerId}
+import org.apache.spark.sql.pulsar.SpecificPulsarOffset.getTopicOffsets
import org.apache.spark.sql.types.StructType
/**
@@ -40,6 +46,7 @@ import org.apache.spark.sql.types.StructType
*/
private[pulsar] case class PulsarHelper(
serviceUrl: String,
+ adminUrl: Option[String],
clientConf: ju.Map[String, Object],
driverGroupIdPrefix: String,
caseInsensitiveParameters: Map[String, String],
@@ -55,6 +62,12 @@ private[pulsar] case class PulsarHelper(
private var topics: Seq[String] = _
private var topicPartitions: Seq[String] = _
+ // We can call adminUrl.get because admissionControlHelper
+ // will only be called if latestOffset is called and there should
+ // be an exception thrown in PulsarProvider if maxBytes is set,
+ // and adminUrl is not set
+ private lazy val admissionControlHelper = new PulsarAdmissionControlHelper(adminUrl.get)
+
override def close(): Unit = {
// do nothing
}
@@ -122,7 +135,9 @@ private[pulsar] case class PulsarHelper(
offset.foreach { case (tp, mid) =>
try {
val (subscription, _) = extractSubscription(predefinedSubscription, tp)
- CachedConsumer.getOrCreate(tp, subscription, client).seek(mid)
+ val consumer = CachedConsumer.getOrCreate(tp, subscription, client)
+ if (!consumer.isConnected) consumer.getLastMessageId
+ consumer.seek(mid)
} catch {
case e: Throwable =>
throw new RuntimeException(
@@ -207,6 +222,35 @@ private[pulsar] case class PulsarHelper(
}.toMap)
}
+ def latestOffsets(startingOffset: streaming.Offset,
+ totalReadLimit: Long): SpecificPulsarOffset = {
+ // implement helper inside PulsarHelper in order to use getTopicPartitions
+ val topicPartitions = getTopicPartitions
+ // add new partitions from PulsarAdmin, set to earliest entry and ledger id based on limit
+ // start a reader, get to the earliest offset for new topic partitions
+ val existingStartOffsets = if (startingOffset != null) {
+ getTopicOffsets(startingOffset.asInstanceOf[org.apache.spark.sql.execution.streaming.Offset])
+ } else {
+ Map[String, MessageId]()
+ }
+ val newTopics = topicPartitions.toSet.diff(existingStartOffsets.keySet)
+ val startPartitionOffsets = existingStartOffsets ++ newTopics.map(topicPartition
+ => {
+ topicPartition -> MessageId.earliest
+ })
+ val offsets = mutable.Map[String, MessageId]()
+ val numPartitions = startPartitionOffsets.size
+ // move all topic partition logic to helper function
+ val readLimit = totalReadLimit / numPartitions
+ startPartitionOffsets.keys.foreach { topicPartition =>
+ val startMessageId = startPartitionOffsets.apply(topicPartition)
+ offsets += (topicPartition ->
+ admissionControlHelper.latestOffsetForTopicPartition(
+ topicPartition, startMessageId, readLimit))
+ }
+ SpecificPulsarOffset(offsets.toMap)
+ }
+
def fetchLatestOffsetForTopic(topic: String): MessageId = {
val messageId =
try {
@@ -472,3 +516,68 @@ private[pulsar] case class PulsarHelper(
CachedConsumer.getOrCreate(topic, subscriptionName, client).getLastMessageId
}
}
+
+class PulsarAdmissionControlHelper(adminUrl: String)
+ extends Logging {
+
+ private lazy val pulsarAdmin = PulsarAdmin.builder().serviceHttpUrl(adminUrl).build()
+
+ import scala.collection.JavaConverters._
+
+ def latestOffsetForTopicPartition(topicPartition: String,
+ startMessageId: MessageId,
+ readLimit: Long): MessageId = {
+ val startLedgerId = getLedgerId(startMessageId)
+ val startEntryId = getEntryId(startMessageId)
+ val stats = pulsarAdmin.topics.getInternalStats(topicPartition)
+ val ledgers = pulsarAdmin.topics.getInternalStats(topicPartition).ledgers.
+ asScala.filter(_.ledgerId >= startLedgerId).sortBy(_.ledgerId)
+ // The last ledger of the ledgers list doesn't have .size or .entries
+ // properly populated, and the corresponding info is in currentLedgerSize
+ // and currentLedgerEntries
+ if (ledgers.nonEmpty) {
+ ledgers.last.size = stats.currentLedgerSize
+ ledgers.last.entries = stats.currentLedgerEntries
+ }
+ val partitionIndex = if (topicPartition.contains(PartitionSuffix)) {
+ topicPartition.split(PartitionSuffix)(1).toInt
+ } else {
+ -1
+ }
+ var messageId = startMessageId
+ var readLimitLeft = readLimit
+ ledgers.filter(_.entries != 0).sortBy(_.ledgerId).foreach { ledger =>
+ assert(readLimitLeft >= 0)
+ if (readLimitLeft == 0) {
+ return messageId
+ }
+ val avgBytesPerEntries = ledger.size / ledger.entries
+ // approximation of bytes left in ledger to deal with case
+ // where we are at the middle of the ledger
+ val bytesLeftInLedger = if (ledger.ledgerId == startLedgerId) {
+ avgBytesPerEntries * (ledger.entries - startEntryId - 1)
+ } else {
+ ledger.size
+ }
+ if (readLimitLeft > bytesLeftInLedger) {
+ readLimitLeft -= bytesLeftInLedger
+ messageId = DefaultImplementation
+ .getDefaultImplementation
+ .newMessageId(ledger.ledgerId, ledger.entries - 1, partitionIndex)
+ } else {
+ val numEntriesToRead = Math.max(1, readLimitLeft / avgBytesPerEntries)
+ val lastEntryId = if (ledger.ledgerId != startLedgerId) {
+ numEntriesToRead - 1
+ } else {
+ startEntryId + numEntriesToRead
+ }
+ val lastEntryRead = Math.min(ledger.entries - 1, lastEntryId)
+ messageId = DefaultImplementation
+ .getDefaultImplementation
+ .newMessageId(ledger.ledgerId, lastEntryRead, partitionIndex)
+ readLimitLeft = 0
+ }
+ }
+ messageId
+ }
+}
diff --git a/src/main/scala/org/apache/spark/sql/pulsar/PulsarOptions.scala b/src/main/scala/org/apache/spark/sql/pulsar/PulsarOptions.scala
index d0f5122..d9ec02f 100644
--- a/src/main/scala/org/apache/spark/sql/pulsar/PulsarOptions.scala
+++ b/src/main/scala/org/apache/spark/sql/pulsar/PulsarOptions.scala
@@ -36,6 +36,7 @@ private[pulsar] object PulsarOptions {
val TopicOptionKeys: Set[String] = Set(TopicSingle, TopicMulti, TopicPattern)
val ServiceUrlOptionKey: String = "service.url"
+ val AdminUrlOptionKey: String = "admin.url"
val StartingOffsetsOptionKey: String = "startingOffsets".toLowerCase(Locale.ROOT)
val StartingTime: String = "startingTime".toLowerCase(Locale.ROOT)
val EndingTime: String = "endingTime".toLowerCase(Locale.ROOT)
@@ -45,6 +46,7 @@ private[pulsar] object PulsarOptions {
val SubscriptionPrefix: String = "subscriptionPrefix".toLowerCase(Locale.ROOT)
val PredefinedSubscription: String = "predefinedSubscription".toLowerCase(Locale.ROOT)
+ val MaxBytesPerTrigger: String = "maxBytesPerTrigger".toLowerCase(Locale.ROOT)
val PollTimeoutMS: String = "pollTimeoutMs".toLowerCase(Locale.ROOT)
val FailOnDataLossOptionKey: String = "failOnDataLoss".toLowerCase(Locale.ROOT)
diff --git a/src/main/scala/org/apache/spark/sql/pulsar/PulsarProvider.scala b/src/main/scala/org/apache/spark/sql/pulsar/PulsarProvider.scala
index a0fa502..9a69938 100644
--- a/src/main/scala/org/apache/spark/sql/pulsar/PulsarProvider.scala
+++ b/src/main/scala/org/apache/spark/sql/pulsar/PulsarProvider.scala
@@ -56,12 +56,13 @@ private[pulsar] class PulsarProvider
parameters: Map[String, String]): (String, StructType) = {
val caseInsensitiveParams = validateStreamOptions(parameters)
- val (clientConfig, _, serviceUrlConfig) = prepareConfForReader(parameters)
+ val (clientConfig, _, serviceUrlConfig, adminUrl) = prepareConfForReader(parameters)
val subscriptionNamePrefix = s"spark-pulsar-${UUID.randomUUID}"
val inferredSchema = Utils.tryWithResource(
PulsarHelper(
serviceUrlConfig,
+ adminUrl,
clientConfig,
subscriptionNamePrefix,
caseInsensitiveParams,
@@ -84,13 +85,14 @@ private[pulsar] class PulsarProvider
logDebug(s"Creating Pulsar source: $parameters")
val caseInsensitiveParams = validateStreamOptions(parameters)
- val (clientConfig, readerConfig, serviceUrl) = prepareConfForReader(parameters)
+ val (clientConfig, readerConfig, serviceUrl, adminUrl) = prepareConfForReader(parameters)
logDebug(
s"Client config: $clientConfig; Reader config: $readerConfig; Service URL: $serviceUrl")
val subscriptionNamePrefix = getSubscriptionPrefix(parameters)
val pulsarHelper = PulsarHelper(
serviceUrl,
+ adminUrl,
clientConfig,
subscriptionNamePrefix,
caseInsensitiveParams,
@@ -105,6 +107,12 @@ private[pulsar] class PulsarProvider
pulsarHelper.offsetForEachTopic(caseInsensitiveParams, LatestOffset, StartOptionKey)
pulsarHelper.setupCursor(offset)
+ val maxBytes = maxBytesPerTrigger(caseInsensitiveParams)
+ if (adminUrl.isEmpty && maxBytes != 0L) {
+ throw new IllegalArgumentException("admin.url " +
+ "must be specified if maxBytesPerTrigger is specified")
+ }
+
new PulsarSource(
sqlContext,
pulsarHelper,
@@ -113,6 +121,7 @@ private[pulsar] class PulsarProvider
metadataPath,
offset,
pollTimeoutMs(caseInsensitiveParams),
+ maxBytesPerTrigger(caseInsensitiveParams),
failOnDataLoss(caseInsensitiveParams),
subscriptionNamePrefix,
jsonOptions)
@@ -125,10 +134,11 @@ private[pulsar] class PulsarProvider
val subscriptionNamePrefix = getSubscriptionPrefix(parameters, isBatch = true)
- val (clientConfig, readerConfig, serviceUrl) = prepareConfForReader(parameters)
+ val (clientConfig, readerConfig, serviceUrl, adminUrl) = prepareConfForReader(parameters)
val (start, end, schema, pSchema) = Utils.tryWithResource(
PulsarHelper(
serviceUrl,
+ adminUrl,
clientConfig,
subscriptionNamePrefix,
caseInsensitiveParams,
@@ -366,6 +376,10 @@ private[pulsar] object PulsarProvider extends Logging {
parameters(ServiceUrlOptionKey)
}
+ private def getAdminUrl(parameters: Map[String, String]): Option[String] = {
+ parameters.get(AdminUrlOptionKey)
+ }
+
private def getAllowDifferentTopicSchemas(parameters: Map[String, String]): Boolean = {
parameters.getOrElse(AllowDifferentTopicSchemas, "false").toBoolean
}
@@ -380,6 +394,13 @@ private[pulsar] object PulsarProvider extends Logging {
(SparkEnv.get.conf.getTimeAsSeconds("spark.network.timeout", "120s") * 1000).toString)
.toInt
+ private def maxBytesPerTrigger(caseInsensitiveParams: Map[String, String]): Long =
+ caseInsensitiveParams
+ .getOrElse(
+ PulsarOptions.MaxBytesPerTrigger,
+ 0L.toString
+ ).toLong
+
private def validateGeneralOptions(
caseInsensitiveParams: Map[String, String]): Map[String, String] = {
if (!caseInsensitiveParams.contains(ServiceUrlOptionKey)) {
@@ -486,9 +507,10 @@ private[pulsar] object PulsarProvider extends Logging {
}
private def prepareConfForReader(parameters: Map[String, String])
- : (ju.Map[String, Object], ju.Map[String, Object], String) = {
+ : (ju.Map[String, Object], ju.Map[String, Object], String, Option[String]) = {
val serviceUrl = getServiceUrl(parameters)
+ val adminUrl = getAdminUrl(parameters)
var clientParams = getClientParams(parameters)
clientParams += (ServiceUrlOptionKey -> serviceUrl)
val readerParams = getReaderParams(parameters)
@@ -496,7 +518,7 @@ private[pulsar] object PulsarProvider extends Logging {
(
paramsToPulsarConf("pulsar.client", clientParams),
paramsToPulsarConf("pulsar.reader", readerParams),
- serviceUrl)
+ serviceUrl, adminUrl)
}
private def prepareConfForProducer(parameters: Map[String, String])
diff --git a/src/main/scala/org/apache/spark/sql/pulsar/PulsarSource.scala b/src/main/scala/org/apache/spark/sql/pulsar/PulsarSource.scala
index 851ddca..8405e65 100644
--- a/src/main/scala/org/apache/spark/sql/pulsar/PulsarSource.scala
+++ b/src/main/scala/org/apache/spark/sql/pulsar/PulsarSource.scala
@@ -13,19 +13,30 @@
*/
package org.apache.spark.sql.pulsar
+
import java.{util => ju}
+import scala.collection.JavaConverters.asScalaBufferConverter
+import scala.collection.mutable
+
+import org.apache.pulsar.client.admin.PulsarAdmin
import org.apache.pulsar.client.api.MessageId
import org.apache.pulsar.client.impl.MessageIdImpl
+import org.apache.pulsar.client.internal.DefaultImplementation
import org.apache.pulsar.common.schema.SchemaInfo
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.json.JSONOptionsInRead
-import org.apache.spark.sql.execution.streaming.{Offset, Source}
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.connector.read.streaming.{ReadAllAvailable, ReadLimit, ReadMaxFiles, SupportsAdmissionControl}
+import org.apache.spark.sql.execution.streaming.{Offset, SerializedOffset, Source}
+import org.apache.spark.sql.pulsar.PulsarOptions.ServiceUrlOptionKey
+import org.apache.spark.sql.pulsar.SpecificPulsarOffset.getTopicOffsets
import org.apache.spark.sql.types.StructType
+
private[pulsar] class PulsarSource(
sqlContext: SQLContext,
pulsarHelper: PulsarHelper,
@@ -34,11 +45,13 @@ private[pulsar] class PulsarSource(
metadataPath: String,
startingOffsets: PerTopicOffset,
pollTimeoutMs: Int,
+ maxBytesPerTrigger: Long,
failOnDataLoss: Boolean,
subscriptionNamePrefix: String,
jsonOptions: JSONOptionsInRead)
extends Source
- with Logging {
+ with Logging
+ with SupportsAdmissionControl {
import PulsarSourceUtils._
@@ -54,17 +67,37 @@ private[pulsar] class PulsarSource(
private var currentTopicOffsets: Option[Map[String, MessageId]] = None
+
private lazy val pulsarSchema: SchemaInfo = pulsarHelper.getPulsarSchema
override def schema(): StructType = SchemaUtils.pulsarSourceSchema(pulsarSchema)
override def getOffset: Option[Offset] = {
- // Make sure initialTopicOffsets is initialized
+ throw new UnsupportedOperationException(
+ "latestOffset(Offset, ReadLimit) should be called instead of this method")
+ }
+
+ override def latestOffset(startingOffset: streaming.Offset,
+ readLimit: ReadLimit): streaming.Offset = {
initialTopicOffsets
- val latest = pulsarHelper.fetchLatestOffsets()
- currentTopicOffsets = Some(latest.topicOffsets)
- logDebug(s"GetOffset: ${latest.topicOffsets.toSeq.map(_.toString).sorted}")
- Some(latest.asInstanceOf[Offset])
+ readLimit match {
+ case ReadMaxBytes(maxBytes) =>
+ startingOffset match {
+ // deals with the case where we add a topic-partition after
+ // the stream has started, since adding a new topic-partition
+ // sets startingOffset to null
+ case null => pulsarHelper.latestOffsets(initialTopicOffsets, maxBytes)
+ case startingOffset => pulsarHelper.latestOffsets(startingOffset, maxBytes)
+ }
+ case _: ReadAllAvailable => pulsarHelper.fetchLatestOffsets()
+ }
+ }
+ override def getDefaultReadLimit: ReadLimit = {
+ if (maxBytesPerTrigger == 0L) {
+ ReadLimit.allAvailable()
+ } else {
+ ReadMaxBytes(maxBytesPerTrigger)
+ }
}
override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
@@ -169,3 +202,7 @@ private[pulsar] class PulsarSource(
}
}
+
+/** A read limit that admits a soft-max of `maxBytes` per micro-batch. */
+case class ReadMaxBytes(maxBytes: Long) extends ReadLimit
+
diff --git a/src/main/scala/org/apache/spark/sql/pulsar/PulsarSources.scala b/src/main/scala/org/apache/spark/sql/pulsar/PulsarSources.scala
index ec86a48..990578d 100644
--- a/src/main/scala/org/apache/spark/sql/pulsar/PulsarSources.scala
+++ b/src/main/scala/org/apache/spark/sql/pulsar/PulsarSources.scala
@@ -120,6 +120,36 @@ private[pulsar] object PulsarSourceUtils extends Logging {
}
}
+ def getLedgerId(mid: MessageId): Long = {
+ mid match {
+ case bmid: BatchMessageIdImpl =>
+ bmid.getLedgerId
+ case midi: MessageIdImpl => midi.getLedgerId
+ case t: TopicMessageIdImpl => getLedgerId(t.getInnerMessageId)
+ case up: UserProvidedMessageId => up.getLedgerId
+ }
+ }
+
+ def getEntryId(mid: MessageId): Long = {
+ mid match {
+ case bmid: BatchMessageIdImpl =>
+ bmid.getEntryId
+ case midi: MessageIdImpl => midi.getEntryId
+ case t: TopicMessageIdImpl => getEntryId(t.getInnerMessageId)
+ case up: UserProvidedMessageId => up.getEntryId
+ }
+ }
+
+ def getPartitionIndex(mid: MessageId): Int = {
+ mid match {
+ case bmid: BatchMessageIdImpl =>
+ bmid.getPartitionIndex
+ case midi: MessageIdImpl => midi.getPartitionIndex
+ case t: TopicMessageIdImpl => getPartitionIndex(t.getInnerMessageId)
+ case up: UserProvidedMessageId => up.getPartitionIndex
+ }
+ }
+
def seekableLatestMid(mid: MessageId): MessageId = {
if (messageExists(mid)) mid else MessageId.earliest
}
diff --git a/src/test/scala/org/apache/spark/sql/pulsar/PulsarAdmissionControlSuite.scala b/src/test/scala/org/apache/spark/sql/pulsar/PulsarAdmissionControlSuite.scala
new file mode 100644
index 0000000..6dc416f
--- /dev/null
+++ b/src/test/scala/org/apache/spark/sql/pulsar/PulsarAdmissionControlSuite.scala
@@ -0,0 +1,257 @@
+package org.apache.spark.sql.pulsar
+
+import org.apache.pulsar.client.admin.PulsarAdmin
+import org.apache.pulsar.client.api.MessageId
+import org.apache.pulsar.client.internal.DefaultImplementation
+import org.apache.spark.sql.pulsar.PulsarSourceUtils.{getEntryId, getLedgerId}
+import org.apache.spark.sql.streaming.Trigger.{Once, ProcessingTime}
+import org.apache.spark.util.Utils
+
+class PulsarAdmissionControlSuite extends PulsarSourceTest {
+
+ import PulsarOptions._
+ import testImplicits._
+
+ private val maxEntriesPerLedger = "managedLedgerMaxEntriesPerLedger"
+ private val ledgerRolloverTime = "managedLedgerMinLedgerRolloverTimeMinutes"
+ private val approxSizeOfInt = 50
+
+ override def beforeAll(): Unit = {
+ brokerConfigs.put(maxEntriesPerLedger, "3")
+ brokerConfigs.put(ledgerRolloverTime, "0")
+ super.beforeAll()
+ }
+
+ override def afterAll(): Unit = {
+ super.afterAll()
+ }
+
+ test("Only admit first entry of ledger") {
+ val topic = newTopic()
+ val messageIds = sendMessages(topic, Array("1", "2", "3"))
+ val firstMid = messageIds.head._2
+ val firstLedger = getLedgerId(firstMid)
+ val firstEntry = getEntryId(firstMid)
+ require(getLatestOffsets(Set(topic)).size === 1)
+ val admissionControlHelper = new PulsarAdmissionControlHelper(adminUrl)
+ val offset = admissionControlHelper.latestOffsetForTopicPartition(topic, MessageId.earliest, approxSizeOfInt)
+ assert(getLedgerId(offset) == firstLedger && getEntryId(offset) == firstEntry)
+ }
+
+ test("Admit entry in the middle of the ledger") {
+ val topic = newTopic()
+ val messageIds = sendMessages(topic, Array("1", "2", "3"))
+ val firstMid = messageIds.head._2
+ val secondMid = messageIds.apply(1)._2
+ require(getLatestOffsets(Set(topic)).size === 1)
+ val admissionControlHelper = new PulsarAdmissionControlHelper(adminUrl)
+ val offset = admissionControlHelper.latestOffsetForTopicPartition(topic, firstMid, approxSizeOfInt)
+ assert(getLedgerId(offset) == getLedgerId(secondMid) && getEntryId(offset) == getEntryId(secondMid))
+
+ }
+
+ test("Check last batch where message size is greater than maxBytesPerTrigger") {
+ val topic = newTopic()
+ sendMessages(topic, Array("-1"))
+ require(getLatestOffsets(Set(topic)).size === 1)
+
+ val pulsar = spark.readStream
+ .format("pulsar")
+ .option(TopicSingle, topic)
+ .option(ServiceUrlOptionKey, serviceUrl)
+ .option(AdminUrlOptionKey, adminUrl)
+ .option(FailOnDataLossOptionKey, "true")
+ .option(MaxBytesPerTrigger, approxSizeOfInt * 3)
+ .load()
+ .selectExpr("CAST(__key AS STRING)", "CAST(value AS STRING)")
+ .as[(String, String)]
+
+ val mapped = pulsar.map(kv => kv._2.toInt + 1)
+
+ testStream(mapped)(
+ StartStream(trigger = ProcessingTime(1000)),
+ makeSureGetOffsetCalled,
+ AddPulsarData(Set(topic), 1, 2, 3),
+ AddPulsarData(Set(topic), 4, 5, 6, 7, 8, 9),
+ AssertOnQuery { query =>
+ query.recentProgress.map(microBatch =>
+ microBatch.numInputRows <= 4
+ ).forall(_ == true)
+ }
+ )
+ }
+
+ test("Admission Control for multiple topics") {
+ val topic1 = newTopic()
+ val topic2 = newTopic()
+
+ val pulsar = spark.readStream
+ .format("pulsar")
+ .option(TopicMulti, s"$topic1,$topic2")
+ .option(ServiceUrlOptionKey, serviceUrl)
+ .option(AdminUrlOptionKey, adminUrl)
+ .option(FailOnDataLossOptionKey, "true")
+ .option(MaxBytesPerTrigger, approxSizeOfInt * 6)
+ .load()
+ .selectExpr("CAST(__key AS STRING)", "CAST(value AS STRING)")
+ .as[(String, String)]
+
+ val mapped = pulsar.map(kv => kv._2.toInt + 1)
+
+ testStream(mapped)(
+ StartStream(trigger = ProcessingTime(1000)),
+ makeSureGetOffsetCalled,
+ AddPulsarData(Set(topic1), 1, 2, 3),
+ AddPulsarData(Set(topic2), 4, 5, 6, 7, 8, 9),
+ AssertOnQuery { query =>
+ query.recentProgress.map(microBatch =>
+ microBatch.numInputRows <= 4
+ ).forall(_ == true)
+ }
+ )
+ }
+
+ test("Admission Control for concurrent topic writes") {
+ val topic1 = newTopic()
+ val topic2 = newTopic()
+
+ val pulsar = spark.readStream
+ .format("pulsar")
+ .option(TopicMulti, s"$topic1,$topic2")
+ .option(ServiceUrlOptionKey, serviceUrl)
+ .option(AdminUrlOptionKey, adminUrl)
+ .option(FailOnDataLossOptionKey, "true")
+ .option(MaxBytesPerTrigger, approxSizeOfInt * 6)
+ .load()
+ .selectExpr("CAST(__key AS STRING)", "CAST(value AS STRING)")
+ .as[(String, String)]
+
+ val mapped = pulsar.map(kv => kv._2.toInt + 1)
+
+ testStream(mapped)(
+ StartStream(trigger = ProcessingTime(1000)),
+ makeSureGetOffsetCalled,
+ AddPulsarData(Set(topic1, topic2), 1, 2, 3),
+ AddPulsarData(Set(topic1, topic2), 4, 5, 6, 7, 8, 9),
+ AssertOnQuery { query =>
+ query.recentProgress.map(microBatch =>
+ microBatch.numInputRows <= 4
+ ).forall(_ == true)
+ }
+ )
+ }
+
+ test("Admission Control with one topic-partition") {
+ val topic = newTopic()
+
+ Utils.tryWithResource(PulsarAdmin.builder().serviceHttpUrl(adminUrl).build()) { admin =>
+ admin.topics().createPartitionedTopic(topic, 1)
+ require(getLatestOffsets(Set(topic)).size === 1)
+ }
+
+ val reader = spark.readStream
+ .format("pulsar")
+ .option(ServiceUrlOptionKey, serviceUrl)
+ .option(AdminUrlOptionKey, adminUrl)
+ .option(FailOnDataLossOptionKey, "true")
+ .option(MaxBytesPerTrigger, approxSizeOfInt * 3)
+
+ val pulsar = reader
+ .option(TopicSingle, topic)
+ .load()
+ .selectExpr("CAST(__key AS STRING)", "CAST(value AS STRING)")
+ .as[(String, String)]
+ val mapped = pulsar.map(kv => kv._2.toInt)
+
+ testStream(mapped)(
+ StartStream(trigger = ProcessingTime(1000)),
+ makeSureGetOffsetCalled,
+ AddPulsarDataWithPartition(topic, Some(0), 1, 2, 3, 4),
+ AssertOnQuery { query =>
+ query.recentProgress.map(microBatch =>
+ microBatch.numInputRows <= 4
+ ).forall(_ == true)
+ }
+ )
+ }
+
+ test("Admission Control with multiple topic-partitions") {
+ val topic = newTopic()
+
+ Utils.tryWithResource(PulsarAdmin.builder().serviceHttpUrl(adminUrl).build()) { admin =>
+ admin.topics().createPartitionedTopic(topic, 2)
+ require(getLatestOffsets(Set(topic)).size === 2)
+ }
+
+ val reader = spark.readStream
+ .format("pulsar")
+ .option(ServiceUrlOptionKey, serviceUrl)
+ .option(AdminUrlOptionKey, adminUrl)
+ .option(FailOnDataLossOptionKey, "true")
+ .option(MaxBytesPerTrigger, approxSizeOfInt * 4)
+
+ val pulsar = reader
+ .option(TopicSingle, topic)
+ .load()
+ .selectExpr("CAST(__key AS STRING)", "CAST(value AS STRING)")
+ .as[(String, String)]
+ val mapped = pulsar.map(kv => kv._2.toInt)
+
+ testStream(mapped)(
+ StartStream(trigger = ProcessingTime(1000)),
+ makeSureGetOffsetCalled,
+ AddPulsarDataWithPartition(topic, Some(0), 1, 2, 3, 4),
+ AddPulsarDataWithPartition(topic, Some(1), 5, 6, 7, 8),
+ AssertOnQuery { query =>
+ query.recentProgress.map(microBatch =>
+ microBatch.numInputRows <= 3
+ ).forall(_ == true)
+ }
+ )
+ }
+
+ test("Add topic-partition after starting stream") {
+ val topic = newTopic()
+
+ Utils.tryWithResource(PulsarAdmin.builder().serviceHttpUrl(adminUrl).build()) { admin =>
+ admin.topics().createPartitionedTopic(topic, 1)
+ require(getLatestOffsets(Set(topic)).size === 1)
+ }
+
+ val reader = spark.readStream
+ .format("pulsar")
+ .option(ServiceUrlOptionKey, serviceUrl)
+ .option(AdminUrlOptionKey, adminUrl)
+ .option(FailOnDataLossOptionKey, "true")
+ .option(MaxBytesPerTrigger, approxSizeOfInt * 4)
+
+ val pulsar = reader
+ .option(TopicSingle, topic)
+ .load()
+ .selectExpr("CAST(__key AS STRING)", "CAST(value AS STRING)")
+ .as[(String, String)]
+ val mapped = pulsar.map(kv => kv._2.toInt)
+
+ testStream(mapped)(
+ StartStream(trigger = ProcessingTime(1000)),
+ makeSureGetOffsetCalled,
+ AddPulsarDataWithPartition(topic, Some(0), 1, 2, 3, 4),
+ AssertOnQuery { query =>
+ query.recentProgress.map(microBatch =>
+ microBatch.numInputRows <= 4
+ ).forall(_ == true)
+ }
+ )
+
+ addPartitions(topic, 2)
+
+ testStream(mapped)(
+ AddPulsarDataWithPartition(topic, Some(1), 5, 6, 7, 8),
+ AssertOnQuery { query =>
+ query.recentProgress.map(microBatch =>
+ microBatch.numInputRows <= 3
+ ).forall(_ == true)
+ }
+ )
+ }
+}
diff --git a/src/test/scala/org/apache/spark/sql/pulsar/PulsarSourceTest.scala b/src/test/scala/org/apache/spark/sql/pulsar/PulsarSourceTest.scala
index 7eb1ea9..749871d 100644
--- a/src/test/scala/org/apache/spark/sql/pulsar/PulsarSourceTest.scala
+++ b/src/test/scala/org/apache/spark/sql/pulsar/PulsarSourceTest.scala
@@ -126,6 +126,72 @@ class PulsarSourceTest extends StreamTest with SharedSparkSession with PulsarTes
s"AddPulsarData(topics = $topics, data = $data, message = $message)"
}
+ /**
+ * Add data to Pulsar with partition specified
+ *
+ * `topicAction` can be used to run actions for each topic before inserting data.
+ */
+ case class AddPulsarDataWithPartition(
+ topic: String,
+ partition: Option[Int],
+ data: Int*)(
+ implicit ensureDataInMultiplePartition: Boolean = false,
+ concurrent: Boolean = false,
+ message: String = "",
+ topicAction: (String, Option[MessageId]) => Unit = (_, _) => {})
+ extends AddData {
+
+ val topics = Set(topic)
+ override def addData(query: Option[StreamExecution]): (SparkDataStream, Offset) = {
+ query match {
+ // Make sure no Spark job is running when deleting a topic
+ case Some(m: MicroBatchExecution) => m.processAllAvailable()
+ case _ =>
+ }
+
+ val existingTopics = getAllTopicsSize().toMap
+ val newTopics = topics.diff(existingTopics.keySet)
+ for (newTopic <- newTopics) {
+ topicAction(newTopic, None)
+ }
+ for (existingTopicPartitions <- existingTopics) {
+ topicAction(existingTopicPartitions._1, Some(existingTopicPartitions._2))
+ }
+
+ require(
+ query.nonEmpty,
+ "Cannot add data when there is no query for finding the active pulsar source")
+
+ val sources = query.get.logicalPlan.collect {
+ case StreamingExecutionRelation(source: PulsarSource, _, _) => source
+ case StreamingExecutionRelation(source: PulsarMicroBatchReader, _, _) => source
+ }.distinct
+
+ if (sources.isEmpty) {
+ throw new Exception(
+ "Could not find Pulsar source in the StreamExecution logical plan to add data to")
+ } else if (sources.size > 1) {
+ throw new Exception(
+ "Could not select the Pulsar source in the StreamExecution logical plan as there" +
+ "are multiple Pulsar sources:\n\t" + sources.mkString("\n\t"))
+ }
+ val pulsarSource = sources.head
+ val topic = topics.toSeq(Random.nextInt(topics.size))
+
+ sendMessages(topic, data.map {
+ _.toString
+ }.toArray, partition)
+ val sizes = getLatestOffsets(topics).toSeq
+ val offset = SpecificPulsarOffset(sizes: _*)
+ logInfo(s"Added data, expected offset $offset")
+ (pulsarSource, offset)
+ }
+
+ override def toString: String =
+ s"AddPulsarDataWithPartition(topics = $topics, partition = $partition, " +
+ s"data = $data, message = $message)"
+ }
+
/**
* Add data to Pulsar.
*
diff --git a/src/test/scala/org/apache/spark/sql/pulsar/PulsarTest.scala b/src/test/scala/org/apache/spark/sql/pulsar/PulsarTest.scala
index 5b034b6..827aebd 100644
--- a/src/test/scala/org/apache/spark/sql/pulsar/PulsarTest.scala
+++ b/src/test/scala/org/apache/spark/sql/pulsar/PulsarTest.scala
@@ -20,6 +20,7 @@ import java.time.{Clock, Duration}
import java.util.{Map => JMap}
import scala.collection.JavaConverters._
+import scala.collection.mutable
import scala.reflect.ClassTag
import org.scalatest.concurrent.Eventually.{eventually, timeout}
@@ -54,11 +55,15 @@ trait PulsarTest extends BeforeAndAfterAll with BeforeAndAfterEach {
var serviceUrl: String = null
var adminUrl: String = null
+ val brokerConfigs = mutable.Map[String, String]()
private val logger: Logger = LoggerFactory.getLogger("pulsar-spark-test-logger")
override def beforeAll(): Unit = {
pulsarContainer = new PulsarContainer(parse("apachepulsar/pulsar:" + CURRENT_VERSION))
pulsarContainer.withStartupTimeout(Duration.ofMinutes(5))
+ brokerConfigs.foreach( kv =>
+ pulsarContainer.withEnv("PULSAR_PREFIX_" + kv._1, kv._2)
+ )
pulsarContainer.start()
@@ -80,6 +85,7 @@ trait PulsarTest extends BeforeAndAfterAll with BeforeAndAfterEach {
if (pulsarContainer != null) {
pulsarContainer.stop()
pulsarContainer.close()
+ brokerConfigs.clear()
}
}