diff --git a/README.md b/README.md index 55ec8e7..a60f5fb 100644 --- a/README.md +++ b/README.md @@ -219,12 +219,21 @@ A possible solution to remove duplicates when reading the written data could be - `admin.url` (Deprecated) + `admin.url` A service HTTP URL of your Pulsar cluster No None Streaming and Batch - The Pulsar `serviceHttpUrl` configuration. + The Pulsar `serviceHttpUrl` configuration. Only needed when `maxBytesPerTrigger` is specified + + + + `maxBytesPerTrigger` + A long value in unit of number of bytes + No + None + Streaming and Batch + A soft limit of the maximum number of bytes we want to process per microbatch. If this is specified, `admin.url` also needs to be specified. diff --git a/src/main/scala/org/apache/spark/sql/pulsar/PulsarHelper.scala b/src/main/scala/org/apache/spark/sql/pulsar/PulsarHelper.scala index 275c4d6..fe8bdc5 100644 --- a/src/main/scala/org/apache/spark/sql/pulsar/PulsarHelper.scala +++ b/src/main/scala/org/apache/spark/sql/pulsar/PulsarHelper.scala @@ -22,16 +22,22 @@ import scala.collection.mutable import scala.language.postfixOps import scala.util.control.NonFatal +import org.apache.pulsar.client.admin.PulsarAdmin import org.apache.pulsar.client.api.{MessageId, PulsarClient} import org.apache.pulsar.client.impl.{MessageIdImpl, PulsarClientImpl} import org.apache.pulsar.client.impl.schema.BytesSchema +import org.apache.pulsar.client.internal.DefaultImplementation import org.apache.pulsar.common.api.proto.CommandGetTopicsOfNamespace import org.apache.pulsar.common.naming.TopicName import org.apache.pulsar.common.schema.SchemaInfo import org.apache.pulsar.shade.com.google.common.util.concurrent.Uninterruptibles import org.apache.spark.internal.Logging +import org.apache.spark.sql.connector.read.streaming +import org.apache.spark.sql.connector.read.streaming.{ReadAllAvailable, ReadLimit} import org.apache.spark.sql.pulsar.PulsarOptions._ +import org.apache.spark.sql.pulsar.PulsarSourceUtils.{getEntryId, getLedgerId} +import org.apache.spark.sql.pulsar.SpecificPulsarOffset.getTopicOffsets import org.apache.spark.sql.types.StructType /** @@ -40,6 +46,7 @@ import org.apache.spark.sql.types.StructType */ private[pulsar] case class PulsarHelper( serviceUrl: String, + adminUrl: Option[String], clientConf: ju.Map[String, Object], driverGroupIdPrefix: String, caseInsensitiveParameters: Map[String, String], @@ -55,6 +62,12 @@ private[pulsar] case class PulsarHelper( private var topics: Seq[String] = _ private var topicPartitions: Seq[String] = _ + // We can call adminUrl.get because admissionControlHelper + // will only be called if latestOffset is called and there should + // be an exception thrown in PulsarProvider if maxBytes is set, + // and adminUrl is not set + private lazy val admissionControlHelper = new PulsarAdmissionControlHelper(adminUrl.get) + override def close(): Unit = { // do nothing } @@ -122,7 +135,9 @@ private[pulsar] case class PulsarHelper( offset.foreach { case (tp, mid) => try { val (subscription, _) = extractSubscription(predefinedSubscription, tp) - CachedConsumer.getOrCreate(tp, subscription, client).seek(mid) + val consumer = CachedConsumer.getOrCreate(tp, subscription, client) + if (!consumer.isConnected) consumer.getLastMessageId + consumer.seek(mid) } catch { case e: Throwable => throw new RuntimeException( @@ -207,6 +222,35 @@ private[pulsar] case class PulsarHelper( }.toMap) } + def latestOffsets(startingOffset: streaming.Offset, + totalReadLimit: Long): SpecificPulsarOffset = { + // implement helper inside PulsarHelper in 
order to use getTopicPartitions + val topicPartitions = getTopicPartitions + // add new partitions from PulsarAdmin, set to earliest entry and ledger id based on limit + // start a reader, get to the earliest offset for new topic partitions + val existingStartOffsets = if (startingOffset != null) { + getTopicOffsets(startingOffset.asInstanceOf[org.apache.spark.sql.execution.streaming.Offset]) + } else { + Map[String, MessageId]() + } + val newTopics = topicPartitions.toSet.diff(existingStartOffsets.keySet) + val startPartitionOffsets = existingStartOffsets ++ newTopics.map(topicPartition + => { + topicPartition -> MessageId.earliest + }) + val offsets = mutable.Map[String, MessageId]() + val numPartitions = startPartitionOffsets.size + // move all topic partition logic to helper function + val readLimit = totalReadLimit / numPartitions + startPartitionOffsets.keys.foreach { topicPartition => + val startMessageId = startPartitionOffsets.apply(topicPartition) + offsets += (topicPartition -> + admissionControlHelper.latestOffsetForTopicPartition( + topicPartition, startMessageId, readLimit)) + } + SpecificPulsarOffset(offsets.toMap) + } + def fetchLatestOffsetForTopic(topic: String): MessageId = { val messageId = try { @@ -472,3 +516,68 @@ private[pulsar] case class PulsarHelper( CachedConsumer.getOrCreate(topic, subscriptionName, client).getLastMessageId } } + +class PulsarAdmissionControlHelper(adminUrl: String) + extends Logging { + + private lazy val pulsarAdmin = PulsarAdmin.builder().serviceHttpUrl(adminUrl).build() + + import scala.collection.JavaConverters._ + + def latestOffsetForTopicPartition(topicPartition: String, + startMessageId: MessageId, + readLimit: Long): MessageId = { + val startLedgerId = getLedgerId(startMessageId) + val startEntryId = getEntryId(startMessageId) + val stats = pulsarAdmin.topics.getInternalStats(topicPartition) + val ledgers = pulsarAdmin.topics.getInternalStats(topicPartition).ledgers. 
+ asScala.filter(_.ledgerId >= startLedgerId).sortBy(_.ledgerId) + // The last ledger of the ledgers list doesn't have .size or .entries + // properly populated, and the corresponding info is in currentLedgerSize + // and currentLedgerEntries + if (ledgers.nonEmpty) { + ledgers.last.size = stats.currentLedgerSize + ledgers.last.entries = stats.currentLedgerEntries + } + val partitionIndex = if (topicPartition.contains(PartitionSuffix)) { + topicPartition.split(PartitionSuffix)(1).toInt + } else { + -1 + } + var messageId = startMessageId + var readLimitLeft = readLimit + ledgers.filter(_.entries != 0).sortBy(_.ledgerId).foreach { ledger => + assert(readLimitLeft >= 0) + if (readLimitLeft == 0) { + return messageId + } + val avgBytesPerEntries = ledger.size / ledger.entries + // approximation of bytes left in ledger to deal with case + // where we are at the middle of the ledger + val bytesLeftInLedger = if (ledger.ledgerId == startLedgerId) { + avgBytesPerEntries * (ledger.entries - startEntryId - 1) + } else { + ledger.size + } + if (readLimitLeft > bytesLeftInLedger) { + readLimitLeft -= bytesLeftInLedger + messageId = DefaultImplementation + .getDefaultImplementation + .newMessageId(ledger.ledgerId, ledger.entries - 1, partitionIndex) + } else { + val numEntriesToRead = Math.max(1, readLimitLeft / avgBytesPerEntries) + val lastEntryId = if (ledger.ledgerId != startLedgerId) { + numEntriesToRead - 1 + } else { + startEntryId + numEntriesToRead + } + val lastEntryRead = Math.min(ledger.entries - 1, lastEntryId) + messageId = DefaultImplementation + .getDefaultImplementation + .newMessageId(ledger.ledgerId, lastEntryRead, partitionIndex) + readLimitLeft = 0 + } + } + messageId + } +} diff --git a/src/main/scala/org/apache/spark/sql/pulsar/PulsarOptions.scala b/src/main/scala/org/apache/spark/sql/pulsar/PulsarOptions.scala index d0f5122..d9ec02f 100644 --- a/src/main/scala/org/apache/spark/sql/pulsar/PulsarOptions.scala +++ b/src/main/scala/org/apache/spark/sql/pulsar/PulsarOptions.scala @@ -36,6 +36,7 @@ private[pulsar] object PulsarOptions { val TopicOptionKeys: Set[String] = Set(TopicSingle, TopicMulti, TopicPattern) val ServiceUrlOptionKey: String = "service.url" + val AdminUrlOptionKey: String = "admin.url" val StartingOffsetsOptionKey: String = "startingOffsets".toLowerCase(Locale.ROOT) val StartingTime: String = "startingTime".toLowerCase(Locale.ROOT) val EndingTime: String = "endingTime".toLowerCase(Locale.ROOT) @@ -45,6 +46,7 @@ private[pulsar] object PulsarOptions { val SubscriptionPrefix: String = "subscriptionPrefix".toLowerCase(Locale.ROOT) val PredefinedSubscription: String = "predefinedSubscription".toLowerCase(Locale.ROOT) + val MaxBytesPerTrigger: String = "maxBytesPerTrigger".toLowerCase(Locale.ROOT) val PollTimeoutMS: String = "pollTimeoutMs".toLowerCase(Locale.ROOT) val FailOnDataLossOptionKey: String = "failOnDataLoss".toLowerCase(Locale.ROOT) diff --git a/src/main/scala/org/apache/spark/sql/pulsar/PulsarProvider.scala b/src/main/scala/org/apache/spark/sql/pulsar/PulsarProvider.scala index a0fa502..9a69938 100644 --- a/src/main/scala/org/apache/spark/sql/pulsar/PulsarProvider.scala +++ b/src/main/scala/org/apache/spark/sql/pulsar/PulsarProvider.scala @@ -56,12 +56,13 @@ private[pulsar] class PulsarProvider parameters: Map[String, String]): (String, StructType) = { val caseInsensitiveParams = validateStreamOptions(parameters) - val (clientConfig, _, serviceUrlConfig) = prepareConfForReader(parameters) + val (clientConfig, _, serviceUrlConfig, adminUrl) = 
prepareConfForReader(parameters) val subscriptionNamePrefix = s"spark-pulsar-${UUID.randomUUID}" val inferredSchema = Utils.tryWithResource( PulsarHelper( serviceUrlConfig, + adminUrl, clientConfig, subscriptionNamePrefix, caseInsensitiveParams, @@ -84,13 +85,14 @@ private[pulsar] class PulsarProvider logDebug(s"Creating Pulsar source: $parameters") val caseInsensitiveParams = validateStreamOptions(parameters) - val (clientConfig, readerConfig, serviceUrl) = prepareConfForReader(parameters) + val (clientConfig, readerConfig, serviceUrl, adminUrl) = prepareConfForReader(parameters) logDebug( s"Client config: $clientConfig; Reader config: $readerConfig; Service URL: $serviceUrl") val subscriptionNamePrefix = getSubscriptionPrefix(parameters) val pulsarHelper = PulsarHelper( serviceUrl, + adminUrl, clientConfig, subscriptionNamePrefix, caseInsensitiveParams, @@ -105,6 +107,12 @@ private[pulsar] class PulsarProvider pulsarHelper.offsetForEachTopic(caseInsensitiveParams, LatestOffset, StartOptionKey) pulsarHelper.setupCursor(offset) + val maxBytes = maxBytesPerTrigger(caseInsensitiveParams) + if (adminUrl.isEmpty && maxBytes != 0L) { + throw new IllegalArgumentException("admin.url " + + "must be specified if maxBytesPerTrigger is specified") + } + new PulsarSource( sqlContext, pulsarHelper, @@ -113,6 +121,7 @@ private[pulsar] class PulsarProvider metadataPath, offset, pollTimeoutMs(caseInsensitiveParams), + maxBytesPerTrigger(caseInsensitiveParams), failOnDataLoss(caseInsensitiveParams), subscriptionNamePrefix, jsonOptions) @@ -125,10 +134,11 @@ private[pulsar] class PulsarProvider val subscriptionNamePrefix = getSubscriptionPrefix(parameters, isBatch = true) - val (clientConfig, readerConfig, serviceUrl) = prepareConfForReader(parameters) + val (clientConfig, readerConfig, serviceUrl, adminUrl) = prepareConfForReader(parameters) val (start, end, schema, pSchema) = Utils.tryWithResource( PulsarHelper( serviceUrl, + adminUrl, clientConfig, subscriptionNamePrefix, caseInsensitiveParams, @@ -366,6 +376,10 @@ private[pulsar] object PulsarProvider extends Logging { parameters(ServiceUrlOptionKey) } + private def getAdminUrl(parameters: Map[String, String]): Option[String] = { + parameters.get(AdminUrlOptionKey) + } + private def getAllowDifferentTopicSchemas(parameters: Map[String, String]): Boolean = { parameters.getOrElse(AllowDifferentTopicSchemas, "false").toBoolean } @@ -380,6 +394,13 @@ private[pulsar] object PulsarProvider extends Logging { (SparkEnv.get.conf.getTimeAsSeconds("spark.network.timeout", "120s") * 1000).toString) .toInt + private def maxBytesPerTrigger(caseInsensitiveParams: Map[String, String]): Long = + caseInsensitiveParams + .getOrElse( + PulsarOptions.MaxBytesPerTrigger, + 0L.toString + ).toLong + private def validateGeneralOptions( caseInsensitiveParams: Map[String, String]): Map[String, String] = { if (!caseInsensitiveParams.contains(ServiceUrlOptionKey)) { @@ -486,9 +507,10 @@ private[pulsar] object PulsarProvider extends Logging { } private def prepareConfForReader(parameters: Map[String, String]) - : (ju.Map[String, Object], ju.Map[String, Object], String) = { + : (ju.Map[String, Object], ju.Map[String, Object], String, Option[String]) = { val serviceUrl = getServiceUrl(parameters) + val adminUrl = getAdminUrl(parameters) var clientParams = getClientParams(parameters) clientParams += (ServiceUrlOptionKey -> serviceUrl) val readerParams = getReaderParams(parameters) @@ -496,7 +518,7 @@ private[pulsar] object PulsarProvider extends Logging { ( 
paramsToPulsarConf("pulsar.client", clientParams), paramsToPulsarConf("pulsar.reader", readerParams), - serviceUrl) + serviceUrl, adminUrl) } private def prepareConfForProducer(parameters: Map[String, String]) diff --git a/src/main/scala/org/apache/spark/sql/pulsar/PulsarSource.scala b/src/main/scala/org/apache/spark/sql/pulsar/PulsarSource.scala index 851ddca..8405e65 100644 --- a/src/main/scala/org/apache/spark/sql/pulsar/PulsarSource.scala +++ b/src/main/scala/org/apache/spark/sql/pulsar/PulsarSource.scala @@ -13,19 +13,30 @@ */ package org.apache.spark.sql.pulsar + import java.{util => ju} +import scala.collection.JavaConverters.asScalaBufferConverter +import scala.collection.mutable + +import org.apache.pulsar.client.admin.PulsarAdmin import org.apache.pulsar.client.api.MessageId import org.apache.pulsar.client.impl.MessageIdImpl +import org.apache.pulsar.client.internal.DefaultImplementation import org.apache.pulsar.common.schema.SchemaInfo import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json.JSONOptionsInRead -import org.apache.spark.sql.execution.streaming.{Offset, Source} +import org.apache.spark.sql.connector.read.streaming +import org.apache.spark.sql.connector.read.streaming.{ReadAllAvailable, ReadLimit, ReadMaxFiles, SupportsAdmissionControl} +import org.apache.spark.sql.execution.streaming.{Offset, SerializedOffset, Source} +import org.apache.spark.sql.pulsar.PulsarOptions.ServiceUrlOptionKey +import org.apache.spark.sql.pulsar.SpecificPulsarOffset.getTopicOffsets import org.apache.spark.sql.types.StructType + private[pulsar] class PulsarSource( sqlContext: SQLContext, pulsarHelper: PulsarHelper, @@ -34,11 +45,13 @@ private[pulsar] class PulsarSource( metadataPath: String, startingOffsets: PerTopicOffset, pollTimeoutMs: Int, + maxBytesPerTrigger: Long, failOnDataLoss: Boolean, subscriptionNamePrefix: String, jsonOptions: JSONOptionsInRead) extends Source - with Logging { + with Logging + with SupportsAdmissionControl { import PulsarSourceUtils._ @@ -54,17 +67,37 @@ private[pulsar] class PulsarSource( private var currentTopicOffsets: Option[Map[String, MessageId]] = None + private lazy val pulsarSchema: SchemaInfo = pulsarHelper.getPulsarSchema override def schema(): StructType = SchemaUtils.pulsarSourceSchema(pulsarSchema) override def getOffset: Option[Offset] = { - // Make sure initialTopicOffsets is initialized + throw new UnsupportedOperationException( + "latestOffset(Offset, ReadLimit) should be called instead of this method") + } + + override def latestOffset(startingOffset: streaming.Offset, + readLimit: ReadLimit): streaming.Offset = { initialTopicOffsets - val latest = pulsarHelper.fetchLatestOffsets() - currentTopicOffsets = Some(latest.topicOffsets) - logDebug(s"GetOffset: ${latest.topicOffsets.toSeq.map(_.toString).sorted}") - Some(latest.asInstanceOf[Offset]) + readLimit match { + case ReadMaxBytes(maxBytes) => + startingOffset match { + // deals with the case where we add a topic-partition after + // the stream has started, since adding a new topic-partition + // sets startingOffset to null + case null => pulsarHelper.latestOffsets(initialTopicOffsets, maxBytes) + case startingOffset => pulsarHelper.latestOffsets(startingOffset, maxBytes) + } + case _: ReadAllAvailable => pulsarHelper.fetchLatestOffsets() + } + } + override def getDefaultReadLimit: ReadLimit = { + if (maxBytesPerTrigger == 0L) { + ReadLimit.allAvailable() + } else { 
+ ReadMaxBytes(maxBytesPerTrigger) + } } override def getBatch(start: Option[Offset], end: Offset): DataFrame = { @@ -169,3 +202,7 @@ private[pulsar] class PulsarSource( } } + +/** A read limit that admits a soft-max of `maxBytes` per micro-batch. */ +case class ReadMaxBytes(maxBytes: Long) extends ReadLimit + diff --git a/src/main/scala/org/apache/spark/sql/pulsar/PulsarSources.scala b/src/main/scala/org/apache/spark/sql/pulsar/PulsarSources.scala index ec86a48..990578d 100644 --- a/src/main/scala/org/apache/spark/sql/pulsar/PulsarSources.scala +++ b/src/main/scala/org/apache/spark/sql/pulsar/PulsarSources.scala @@ -120,6 +120,36 @@ private[pulsar] object PulsarSourceUtils extends Logging { } } + def getLedgerId(mid: MessageId): Long = { + mid match { + case bmid: BatchMessageIdImpl => + bmid.getLedgerId + case midi: MessageIdImpl => midi.getLedgerId + case t: TopicMessageIdImpl => getLedgerId(t.getInnerMessageId) + case up: UserProvidedMessageId => up.getLedgerId + } + } + + def getEntryId(mid: MessageId): Long = { + mid match { + case bmid: BatchMessageIdImpl => + bmid.getEntryId + case midi: MessageIdImpl => midi.getEntryId + case t: TopicMessageIdImpl => getEntryId(t.getInnerMessageId) + case up: UserProvidedMessageId => up.getEntryId + } + } + + def getPartitionIndex(mid: MessageId): Int = { + mid match { + case bmid: BatchMessageIdImpl => + bmid.getPartitionIndex + case midi: MessageIdImpl => midi.getPartitionIndex + case t: TopicMessageIdImpl => getPartitionIndex(t.getInnerMessageId) + case up: UserProvidedMessageId => up.getPartitionIndex + } + } + def seekableLatestMid(mid: MessageId): MessageId = { if (messageExists(mid)) mid else MessageId.earliest } diff --git a/src/test/scala/org/apache/spark/sql/pulsar/PulsarAdmissionControlSuite.scala b/src/test/scala/org/apache/spark/sql/pulsar/PulsarAdmissionControlSuite.scala new file mode 100644 index 0000000..6dc416f --- /dev/null +++ b/src/test/scala/org/apache/spark/sql/pulsar/PulsarAdmissionControlSuite.scala @@ -0,0 +1,257 @@ +package org.apache.spark.sql.pulsar + +import org.apache.pulsar.client.admin.PulsarAdmin +import org.apache.pulsar.client.api.MessageId +import org.apache.pulsar.client.internal.DefaultImplementation +import org.apache.spark.sql.pulsar.PulsarSourceUtils.{getEntryId, getLedgerId} +import org.apache.spark.sql.streaming.Trigger.{Once, ProcessingTime} +import org.apache.spark.util.Utils + +class PulsarAdmissionControlSuite extends PulsarSourceTest { + + import PulsarOptions._ + import testImplicits._ + + private val maxEntriesPerLedger = "managedLedgerMaxEntriesPerLedger" + private val ledgerRolloverTime = "managedLedgerMinLedgerRolloverTimeMinutes" + private val approxSizeOfInt = 50 + + override def beforeAll(): Unit = { + brokerConfigs.put(maxEntriesPerLedger, "3") + brokerConfigs.put(ledgerRolloverTime, "0") + super.beforeAll() + } + + override def afterAll(): Unit = { + super.afterAll() + } + + test("Only admit first entry of ledger") { + val topic = newTopic() + val messageIds = sendMessages(topic, Array("1", "2", "3")) + val firstMid = messageIds.head._2 + val firstLedger = getLedgerId(firstMid) + val firstEntry = getEntryId(firstMid) + require(getLatestOffsets(Set(topic)).size === 1) + val admissionControlHelper = new PulsarAdmissionControlHelper(adminUrl) + val offset = admissionControlHelper.latestOffsetForTopicPartition(topic, MessageId.earliest, approxSizeOfInt) + assert(getLedgerId(offset) == firstLedger && getEntryId(offset) == firstEntry) + } + + test("Admit entry in the middle of the ledger") { + 
val topic = newTopic() + val messageIds = sendMessages(topic, Array("1", "2", "3")) + val firstMid = messageIds.head._2 + val secondMid = messageIds.apply(1)._2 + require(getLatestOffsets(Set(topic)).size === 1) + val admissionControlHelper = new PulsarAdmissionControlHelper(adminUrl) + val offset = admissionControlHelper.latestOffsetForTopicPartition(topic, firstMid, approxSizeOfInt) + assert(getLedgerId(offset) == getLedgerId(secondMid) && getEntryId(offset) == getEntryId(secondMid)) + + } + + test("Check last batch where message size is greater than maxBytesPerTrigger") { + val topic = newTopic() + sendMessages(topic, Array("-1")) + require(getLatestOffsets(Set(topic)).size === 1) + + val pulsar = spark.readStream + .format("pulsar") + .option(TopicSingle, topic) + .option(ServiceUrlOptionKey, serviceUrl) + .option(AdminUrlOptionKey, adminUrl) + .option(FailOnDataLossOptionKey, "true") + .option(MaxBytesPerTrigger, approxSizeOfInt * 3) + .load() + .selectExpr("CAST(__key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + + val mapped = pulsar.map(kv => kv._2.toInt + 1) + + testStream(mapped)( + StartStream(trigger = ProcessingTime(1000)), + makeSureGetOffsetCalled, + AddPulsarData(Set(topic), 1, 2, 3), + AddPulsarData(Set(topic), 4, 5, 6, 7, 8, 9), + AssertOnQuery { query => + query.recentProgress.map(microBatch => + microBatch.numInputRows <= 4 + ).forall(_ == true) + } + ) + } + + test("Admission Control for multiple topics") { + val topic1 = newTopic() + val topic2 = newTopic() + + val pulsar = spark.readStream + .format("pulsar") + .option(TopicMulti, s"$topic1,$topic2") + .option(ServiceUrlOptionKey, serviceUrl) + .option(AdminUrlOptionKey, adminUrl) + .option(FailOnDataLossOptionKey, "true") + .option(MaxBytesPerTrigger, approxSizeOfInt * 6) + .load() + .selectExpr("CAST(__key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + + val mapped = pulsar.map(kv => kv._2.toInt + 1) + + testStream(mapped)( + StartStream(trigger = ProcessingTime(1000)), + makeSureGetOffsetCalled, + AddPulsarData(Set(topic1), 1, 2, 3), + AddPulsarData(Set(topic2), 4, 5, 6, 7, 8, 9), + AssertOnQuery { query => + query.recentProgress.map(microBatch => + microBatch.numInputRows <= 4 + ).forall(_ == true) + } + ) + } + + test("Admission Control for concurrent topic writes") { + val topic1 = newTopic() + val topic2 = newTopic() + + val pulsar = spark.readStream + .format("pulsar") + .option(TopicMulti, s"$topic1,$topic2") + .option(ServiceUrlOptionKey, serviceUrl) + .option(AdminUrlOptionKey, adminUrl) + .option(FailOnDataLossOptionKey, "true") + .option(MaxBytesPerTrigger, approxSizeOfInt * 6) + .load() + .selectExpr("CAST(__key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + + val mapped = pulsar.map(kv => kv._2.toInt + 1) + + testStream(mapped)( + StartStream(trigger = ProcessingTime(1000)), + makeSureGetOffsetCalled, + AddPulsarData(Set(topic1, topic2), 1, 2, 3), + AddPulsarData(Set(topic1, topic2), 4, 5, 6, 7, 8, 9), + AssertOnQuery { query => + query.recentProgress.map(microBatch => + microBatch.numInputRows <= 4 + ).forall(_ == true) + } + ) + } + + test("Admission Control with one topic-partition") { + val topic = newTopic() + + Utils.tryWithResource(PulsarAdmin.builder().serviceHttpUrl(adminUrl).build()) { admin => + admin.topics().createPartitionedTopic(topic, 1) + require(getLatestOffsets(Set(topic)).size === 1) + } + + val reader = spark.readStream + .format("pulsar") + .option(ServiceUrlOptionKey, serviceUrl) + .option(AdminUrlOptionKey, adminUrl) + 
.option(FailOnDataLossOptionKey, "true") + .option(MaxBytesPerTrigger, approxSizeOfInt * 3) + + val pulsar = reader + .option(TopicSingle, topic) + .load() + .selectExpr("CAST(__key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped = pulsar.map(kv => kv._2.toInt) + + testStream(mapped)( + StartStream(trigger = ProcessingTime(1000)), + makeSureGetOffsetCalled, + AddPulsarDataWithPartition(topic, Some(0), 1, 2, 3, 4), + AssertOnQuery { query => + query.recentProgress.map(microBatch => + microBatch.numInputRows <= 4 + ).forall(_ == true) + } + ) + } + + test("Admission Control with multiple topic-partitions") { + val topic = newTopic() + + Utils.tryWithResource(PulsarAdmin.builder().serviceHttpUrl(adminUrl).build()) { admin => + admin.topics().createPartitionedTopic(topic, 2) + require(getLatestOffsets(Set(topic)).size === 2) + } + + val reader = spark.readStream + .format("pulsar") + .option(ServiceUrlOptionKey, serviceUrl) + .option(AdminUrlOptionKey, adminUrl) + .option(FailOnDataLossOptionKey, "true") + .option(MaxBytesPerTrigger, approxSizeOfInt * 4) + + val pulsar = reader + .option(TopicSingle, topic) + .load() + .selectExpr("CAST(__key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped = pulsar.map(kv => kv._2.toInt) + + testStream(mapped)( + StartStream(trigger = ProcessingTime(1000)), + makeSureGetOffsetCalled, + AddPulsarDataWithPartition(topic, Some(0), 1, 2, 3, 4), + AddPulsarDataWithPartition(topic, Some(1), 5, 6, 7, 8), + AssertOnQuery { query => + query.recentProgress.map(microBatch => + microBatch.numInputRows <= 3 + ).forall(_ == true) + } + ) + } + + test("Add topic-partition after starting stream") { + val topic = newTopic() + + Utils.tryWithResource(PulsarAdmin.builder().serviceHttpUrl(adminUrl).build()) { admin => + admin.topics().createPartitionedTopic(topic, 1) + require(getLatestOffsets(Set(topic)).size === 1) + } + + val reader = spark.readStream + .format("pulsar") + .option(ServiceUrlOptionKey, serviceUrl) + .option(AdminUrlOptionKey, adminUrl) + .option(FailOnDataLossOptionKey, "true") + .option(MaxBytesPerTrigger, approxSizeOfInt * 4) + + val pulsar = reader + .option(TopicSingle, topic) + .load() + .selectExpr("CAST(__key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped = pulsar.map(kv => kv._2.toInt) + + testStream(mapped)( + StartStream(trigger = ProcessingTime(1000)), + makeSureGetOffsetCalled, + AddPulsarDataWithPartition(topic, Some(0), 1, 2, 3, 4), + AssertOnQuery { query => + query.recentProgress.map(microBatch => + microBatch.numInputRows <= 4 + ).forall(_ == true) + } + ) + + addPartitions(topic, 2) + + testStream(mapped)( + AddPulsarDataWithPartition(topic, Some(1), 5, 6, 7, 8), + AssertOnQuery { query => + query.recentProgress.map(microBatch => + microBatch.numInputRows <= 3 + ).forall(_ == true) + } + ) + } +} diff --git a/src/test/scala/org/apache/spark/sql/pulsar/PulsarSourceTest.scala b/src/test/scala/org/apache/spark/sql/pulsar/PulsarSourceTest.scala index 7eb1ea9..749871d 100644 --- a/src/test/scala/org/apache/spark/sql/pulsar/PulsarSourceTest.scala +++ b/src/test/scala/org/apache/spark/sql/pulsar/PulsarSourceTest.scala @@ -126,6 +126,72 @@ class PulsarSourceTest extends StreamTest with SharedSparkSession with PulsarTes s"AddPulsarData(topics = $topics, data = $data, message = $message)" } + /** + * Add data to Pulsar with partition specified + * + * `topicAction` can be used to run actions for each topic before inserting data. 
+ */ + case class AddPulsarDataWithPartition( + topic: String, + partition: Option[Int], + data: Int*)( + implicit ensureDataInMultiplePartition: Boolean = false, + concurrent: Boolean = false, + message: String = "", + topicAction: (String, Option[MessageId]) => Unit = (_, _) => {}) + extends AddData { + + val topics = Set(topic) + override def addData(query: Option[StreamExecution]): (SparkDataStream, Offset) = { + query match { + // Make sure no Spark job is running when deleting a topic + case Some(m: MicroBatchExecution) => m.processAllAvailable() + case _ => + } + + val existingTopics = getAllTopicsSize().toMap + val newTopics = topics.diff(existingTopics.keySet) + for (newTopic <- newTopics) { + topicAction(newTopic, None) + } + for (existingTopicPartitions <- existingTopics) { + topicAction(existingTopicPartitions._1, Some(existingTopicPartitions._2)) + } + + require( + query.nonEmpty, + "Cannot add data when there is no query for finding the active pulsar source") + + val sources = query.get.logicalPlan.collect { + case StreamingExecutionRelation(source: PulsarSource, _, _) => source + case StreamingExecutionRelation(source: PulsarMicroBatchReader, _, _) => source + }.distinct + + if (sources.isEmpty) { + throw new Exception( + "Could not find Pulsar source in the StreamExecution logical plan to add data to") + } else if (sources.size > 1) { + throw new Exception( + "Could not select the Pulsar source in the StreamExecution logical plan as there" + + "are multiple Pulsar sources:\n\t" + sources.mkString("\n\t")) + } + val pulsarSource = sources.head + val topic = topics.toSeq(Random.nextInt(topics.size)) + + sendMessages(topic, data.map { + _.toString + }.toArray, partition) + val sizes = getLatestOffsets(topics).toSeq + val offset = SpecificPulsarOffset(sizes: _*) + logInfo(s"Added data, expected offset $offset") + (pulsarSource, offset) + } + + override def toString: String = + s"AddPulsarDataWithPartition(topics = $topics, partition = $partition, " + + s"data = $data, message = $message)" + } + /** * Add data to Pulsar. * diff --git a/src/test/scala/org/apache/spark/sql/pulsar/PulsarTest.scala b/src/test/scala/org/apache/spark/sql/pulsar/PulsarTest.scala index 5b034b6..827aebd 100644 --- a/src/test/scala/org/apache/spark/sql/pulsar/PulsarTest.scala +++ b/src/test/scala/org/apache/spark/sql/pulsar/PulsarTest.scala @@ -20,6 +20,7 @@ import java.time.{Clock, Duration} import java.util.{Map => JMap} import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.reflect.ClassTag import org.scalatest.concurrent.Eventually.{eventually, timeout} @@ -54,11 +55,15 @@ trait PulsarTest extends BeforeAndAfterAll with BeforeAndAfterEach { var serviceUrl: String = null var adminUrl: String = null + val brokerConfigs = mutable.Map[String, String]() private val logger: Logger = LoggerFactory.getLogger("pulsar-spark-test-logger") override def beforeAll(): Unit = { pulsarContainer = new PulsarContainer(parse("apachepulsar/pulsar:" + CURRENT_VERSION)) pulsarContainer.withStartupTimeout(Duration.ofMinutes(5)) + brokerConfigs.foreach( kv => + pulsarContainer.withEnv("PULSAR_PREFIX_" + kv._1, kv._2) + ) pulsarContainer.start() @@ -80,6 +85,7 @@ trait PulsarTest extends BeforeAndAfterAll with BeforeAndAfterEach { if (pulsarContainer != null) { pulsarContainer.stop() pulsarContainer.close() + brokerConfigs.clear() } }
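
For context on how the new `maxBytesPerTrigger` and `admin.url` options fit together from the user side, a minimal sketch follows. It assumes the option keys introduced in PulsarOptions above plus the existing `topic` key; the service/admin URLs, topic name, byte budget, and checkpoint path are placeholders, not values taken from this change.

import org.apache.spark.sql.SparkSession

object MaxBytesPerTriggerSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("pulsar-max-bytes-per-trigger")
      .master("local[2]")
      .getOrCreate()

    // maxBytesPerTrigger is a soft cap on the bytes admitted per micro-batch.
    // When it is set, admin.url must also be set, since the admission-control
    // helper reads ledger stats through the Pulsar admin API; PulsarProvider
    // rejects maxBytesPerTrigger without admin.url.
    val df = spark.readStream
      .format("pulsar")
      .option("service.url", "pulsar://localhost:6650")              // placeholder broker URL
      .option("admin.url", "http://localhost:8080")                  // placeholder admin URL
      .option("topic", "persistent://public/default/example-topic")  // placeholder topic
      .option("maxBytesPerTrigger", 10L * 1024 * 1024)               // ~10 MiB soft limit per batch
      .load()
      .selectExpr("CAST(__key AS STRING)", "CAST(value AS STRING)")

    df.writeStream
      .format("console")
      .option("checkpointLocation", "/tmp/pulsar-maxbytes-example")  // placeholder path
      .start()
      .awaitTermination()
  }
}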
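As a back-of-the-envelope illustration of the admission-control arithmetic: `latestOffsets` divides the total byte budget evenly across topic partitions, and `latestOffsetForTopicPartition` then walks that partition's ledgers, using size/entries as an average entry size to decide how far into a ledger the batch may advance. The numbers below are invented solely to show the calculation (they mirror the branch for a ledger other than the starting ledger; the starting ledger additionally offsets by startEntryId).

object AdmissionControlArithmeticSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical inputs: a 20 MiB total budget spread over 4 partitions.
    val totalReadLimit = 20L * 1024 * 1024
    val numPartitions  = 4
    val perPartition   = totalReadLimit / numPartitions      // 5,242,880 bytes per partition

    // Hypothetical ledger for one partition: 300 entries totalling 15,000,000 bytes.
    val ledgerEntries    = 300L
    val ledgerSize       = 15000000L
    val avgBytesPerEntry = ledgerSize / ledgerEntries         // 50,000 bytes per entry

    // Mirrors Math.max(1, readLimitLeft / avgBytesPerEntries): at least one
    // entry is always admitted so the stream keeps making progress.
    val entriesToRead = math.max(1L, perPartition / avgBytesPerEntry)
    val lastEntryId   = math.min(ledgerEntries - 1, entriesToRead - 1)

    println(s"per-partition budget: $perPartition bytes")
    println(s"admit entries 0..$lastEntryId of this ledger")  // 0..103 with these numbers
  }
}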