This update initiates the multipart upload as soon as a record begins, and closes the file on flush.

Signed-off-by: Aindriu Lavelle <[email protected]>
aindriu-aiven committed Oct 23, 2024
1 parent 939796f commit 4776d6d
Showing 2 changed files with 177 additions and 30 deletions.
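At a high level, the change makes the sink create an output writer lazily when the first record for a file arrives, keep that writer open across put() calls so the S3 multipart upload can proceed incrementally, and close all writers on flush() to complete the uploads. Below is a minimal, self-contained sketch of that pattern; plain JDK streams stand in for S3OutputStream and OutputWriter, and the class and method names here are illustrative, not the connector's API.

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Illustrative only: plain OutputStreams stand in for the connector's
// S3OutputStream/OutputWriter; this is not the connector's code.
public class LazyWriterSketch {

    // One open writer per output file, kept across put() calls.
    private final Map<String, OutputStream> writers = new HashMap<>();

    // Analogue of put(): reuse the open writer for this file, or create one.
    // In the connector, creating the stream is what starts the S3 multipart upload.
    public void put(final String filename, final List<String> newRecords) {
        final OutputStream out = writers.computeIfAbsent(filename, key -> new ByteArrayOutputStream());
        try {
            for (final String record : newRecords) {
                out.write((record + System.lineSeparator()).getBytes(StandardCharsets.UTF_8));
            }
        } catch (final IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    // Analogue of flush(): close every open writer, which in the connector flushes any
    // buffered data and completes the multipart upload, then forget the writers so the
    // next put() starts new files.
    public void flush() {
        writers.values().forEach(out -> {
            try {
                out.close();
            } catch (final IOException e) {
                throw new UncheckedIOException(e);
            }
        });
        writers.clear();
    }

    public static void main(final String[] args) {
        final LazyWriterSketch task = new LazyWriterSketch();
        task.put("topic0-0-10", List.of("{\"name\":\"name0\"}"));
        task.put("topic0-0-10", List.of("{\"name\":\"name0\"}")); // same writer is reused
        task.flush(); // writer closed; the upload would complete here
    }
}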
@@ -25,9 +25,11 @@
import java.time.ZonedDateTime;
import java.time.format.DateTimeFormatter;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

import org.apache.kafka.clients.consumer.OffsetAndMetadata;
import org.apache.kafka.common.TopicPartition;
@@ -53,17 +55,19 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@SuppressWarnings("PMD.ExcessiveImports")
@SuppressWarnings({ "PMD.ExcessiveImports", "PMD.TooManyMethods" })
public final class S3SinkTask extends SinkTask {

private static final Logger LOGGER = LoggerFactory.getLogger(AivenKafkaConnectS3SinkConnector.class);
private static final Logger LOGGER = LoggerFactory.getLogger(S3SinkTask.class);

private RecordGrouper recordGrouper;

private S3SinkConfig config;

private AmazonS3 s3Client;

private Map<String, OutputWriter> writers;

AwsCredentialProviderFactory credentialFactory = new AwsCredentialProviderFactory();

@SuppressWarnings("PMD.UnnecessaryConstructor") // required by Connect
Expand All @@ -76,6 +80,7 @@ public void start(final Map<String, String> props) {
Objects.requireNonNull(props, "props hasn't been set");
config = new S3SinkConfig(props);
s3Client = createAmazonS3Client(config);
writers = new HashMap<>();
try {
recordGrouper = RecordGrouperFactory.newRecordGrouper(config);
} catch (final Exception e) { // NOPMD AvoidCatchingGenericException
@@ -110,39 +115,104 @@ public void put(final Collection<SinkRecord> records) {
Objects.requireNonNull(records, "records cannot be null");
LOGGER.info("Processing {} records", records.size());
records.forEach(recordGrouper::put);

recordGrouper.records().forEach((filename, groupedRecords) -> writeToS3(filename, groupedRecords, records));

}

/**
* Flush is used to roll the file over and complete the S3 multipart upload.
*
* @param offsets
*/
@Override
public void flush(final Map<TopicPartition, OffsetAndMetadata> offsets) {
try {
recordGrouper.records().forEach(this::flushFile);
} finally {
recordGrouper.clear();
}
// On flush, collect the currently active writers.
final Collection<OutputWriter> activeWriters = writers.values();
// Clear recordGrouper so it restarts offset heads etc.; on the next put new writers will be created.
recordGrouper.clear();
// Close each active writer. Calling close writes anything still buffered and completes the S3 multipart upload.
activeWriters.forEach(writer -> {
try {
writer.close();
} catch (final IOException e) {
throw new ConnectException(e);
}
});
// All writers are now closed, so drop them from the map.
writers.clear();

}

private void flushFile(final String filename, final List<SinkRecord> records) {
Objects.requireNonNull(records, "records cannot be null");
if (records.isEmpty()) {
return;
/**
* getOutputWriter checks whether a compatible OutputWriter already exists for the given file, creating one if
* necessary, and returns it to the caller.
*
* @param filename
* used to write to S3
* @param sinkRecord
* a sinkRecord used to create a new S3OutputStream
* @return correct OutputWriter for writing a particular record to S3
*/
private OutputWriter getOutputWriter(final String filename, final SinkRecord sinkRecord) {
final String fileNameTemplate = getFileNameTemplate(filename, sinkRecord);

if (writers.get(fileNameTemplate) == null) {
final var out = newStreamFor(filename, sinkRecord);
try {
writers.put(fileNameTemplate,
OutputWriter.builder()
.withCompressionType(config.getCompressionType())
.withExternalProperties(config.originalsStrings())
.withOutputFields(config.getOutputFields())
.withEnvelopeEnabled(config.envelopeEnabled())
.build(out, config.getFormatType()));
} catch (IOException e) {
throw new ConnectException(e);
}
}
return writers.get(fileNameTemplate);
}

/**
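* Writes the grouped records for a file to S3, appending only the new batch of records to the open multipart upload.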
*
* @param filename
* the name of the file in S3 to be written to
* @param records
* all records in this record grouping, including those already written to S3
* @param recordToBeWritten
* new records from put() which are to be written to S3
*/
private void writeToS3(final String filename, final List<SinkRecord> records,
final Collection<SinkRecord> recordToBeWritten) {

final SinkRecord sinkRecord = records.get(0);
try (var out = newStreamFor(filename, sinkRecord);
var outputWriter = OutputWriter.builder()
.withCompressionType(config.getCompressionType())
.withExternalProperties(config.originalsStrings())
.withOutputFields(config.getOutputFields())
.withEnvelopeEnabled(config.envelopeEnabled())
.build(out, config.getFormatType())) {
outputWriter.writeRecords(records);
} catch (final IOException e) {
// This writer is being left open until a flush occurs.
final OutputWriter writer; // NOPMD CloseResource
try {
writer = getOutputWriter(filename, sinkRecord);
// The record grouper returns all records for this filename; we only want the new batch of records to be
// added to the multipart upload.
writer.writeRecords(records.stream().filter(recordToBeWritten::contains).collect(Collectors.toList()));
} catch (IOException e) {
throw new ConnectException(e);
}

}

@Override
public void stop() {
writers.forEach((k, v) -> {
try {
v.close();
} catch (IOException e) {
throw new ConnectException(e);
}
});
s3Client.shutdown();

LOGGER.info("Stop S3 Sink Task");
}

@@ -152,11 +222,15 @@ public String version() {
}

private OutputStream newStreamFor(final String filename, final SinkRecord record) {
final var fullKey = config.usesFileNameTemplate() ? filename : oldFullKey(record);
final var fullKey = getFileNameTemplate(filename, record);
return new S3OutputStream(config.getAwsS3BucketName(), fullKey, config.getAwsS3PartSize(), s3Client,
config.getServerSideEncryptionAlgorithmName());
}

private String getFileNameTemplate(final String filename, final SinkRecord record) {
return config.usesFileNameTemplate() ? filename : oldFullKey(record);
}

private EndpointConfiguration newEndpointConfiguration(final S3SinkConfig config) {
if (Objects.isNull(config.getAwsS3EndPoint())) {
return null;
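The new test below asserts object names of the form prefix-topic-partition-offset, with the starting offset zero-padded to 20 digits (for example prefix-topic0-0-00000000000000000010). A quick, hypothetical illustration of that naming follows; the real connector derives names from its filename template configuration.

// Hypothetical helper mirroring the object names asserted in the test below;
// this is not the connector's own key-building code.
public class KeyNameSketch {

    static String objectKey(final String prefix, final String topic, final int partition, final long startOffset) {
        return String.format("%s%s-%d-%020d", prefix, topic, partition, startOffset);
    }

    public static void main(final String[] args) {
        System.out.println(objectKey("prefix-", "topic0", 0, 10)); // prefix-topic0-0-00000000000000000010
    }
}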
@@ -19,8 +19,10 @@
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import java.io.BufferedReader;
import java.io.IOException;
@@ -431,9 +433,7 @@ void failedForStringValuesByDefault() {

);

task.put(records);

assertThatThrownBy(() -> task.flush(null)).isInstanceOf(ConnectException.class)
assertThatThrownBy(() -> task.put(records)).isInstanceOf(ConnectException.class)
.hasMessage("Record value schema type must be BYTES, STRING given");
}

@@ -501,9 +501,7 @@ void failedForStructValuesByDefault() {
createRecordWithStructValueSchema("topic0", 1, "key1", "name1", 20, 1001),
createRecordWithStructValueSchema("topic1", 0, "key2", "name2", 30, 1002));

task.put(records);

assertThatThrownBy(() -> task.flush(null)).isInstanceOf(ConnectException.class)
assertThatThrownBy(() -> task.put(records)).isInstanceOf(ConnectException.class)
.hasMessage("Record value schema type must be BYTES, STRUCT given");
}

@@ -689,17 +687,92 @@ void supportUnwrappedJsonEnvelopeForStructAndClassicJson() throws IOException {
void requestCredentialProviderFromFactoryOnStart() {
final S3SinkTask task = new S3SinkTask();

final AwsCredentialProviderFactory mockedFactory = Mockito.mock(AwsCredentialProviderFactory.class);
final AWSCredentialsProvider provider = Mockito.mock(AWSCredentialsProvider.class);
final AwsCredentialProviderFactory mockedFactory = mock(AwsCredentialProviderFactory.class);
final AWSCredentialsProvider provider = mock(AWSCredentialsProvider.class);

task.credentialFactory = mockedFactory;
Mockito.when(mockedFactory.getProvider(any(S3SinkConfig.class))).thenReturn(provider);
when(mockedFactory.getProvider(any(S3SinkConfig.class))).thenReturn(provider);

task.start(properties);

verify(mockedFactory, Mockito.times(1)).getProvider(any(S3SinkConfig.class));
}

@Test
void multiPartUploadWriteOnlyExpectedRecordsAndFilesToS3() throws IOException {
final String compression = "none";
properties.put(S3SinkConfig.FILE_COMPRESSION_TYPE_CONFIG, compression);
properties.put(S3SinkConfig.FORMAT_OUTPUT_FIELDS_CONFIG, "value");
properties.put(S3SinkConfig.FORMAT_OUTPUT_ENVELOPE_CONFIG, "false");
properties.put(S3SinkConfig.FORMAT_OUTPUT_TYPE_CONFIG, "json");
properties.put(S3SinkConfig.AWS_S3_PREFIX_CONFIG, "prefix-");

final S3SinkTask task = new S3SinkTask();
task.start(properties);
int timestamp = 1000;
int offset1 = 10;
int offset2 = 20;
int offset3 = 30;
final List<List<SinkRecord>> allRecords = new ArrayList<>();
for (int i = 0; i < 3; i++) {
allRecords.add(
List.of(createRecordWithStructValueSchema("topic0", 0, "key0", "name0", offset1++, timestamp++),
createRecordWithStructValueSchema("topic0", 1, "key1", "name1", offset2++, timestamp++),
createRecordWithStructValueSchema("topic1", 0, "key2", "name2", offset3++, timestamp++)));
}
final TopicPartition tp00 = new TopicPartition("topic0", 0);
final TopicPartition tp01 = new TopicPartition("topic0", 1);
final TopicPartition tp10 = new TopicPartition("topic1", 0);
final Collection<TopicPartition> tps = List.of(tp00, tp01, tp10);
task.open(tps);

allRecords.forEach(task::put);

final Map<TopicPartition, OffsetAndMetadata> offsets = new HashMap<>();
offsets.put(tp00, new OffsetAndMetadata(offset1));
offsets.put(tp01, new OffsetAndMetadata(offset2));
offsets.put(tp10, new OffsetAndMetadata(offset3));
task.flush(offsets);

final CompressionType compressionType = CompressionType.forName(compression);

List<String> expectedBlobs = Lists.newArrayList(
"prefix-topic0-0-00000000000000000010" + compressionType.extension(),
"prefix-topic0-1-00000000000000000020" + compressionType.extension(),
"prefix-topic1-0-00000000000000000030" + compressionType.extension());
assertThat(expectedBlobs).allMatch(blobName -> testBucketAccessor.doesObjectExist(blobName));

assertThat(testBucketAccessor.readLines("prefix-topic0-0-00000000000000000010", compression))
.containsExactly("[", "{\"name\":\"name0\"},", "{\"name\":\"name0\"},", "{\"name\":\"name0\"}", "]");
assertThat(testBucketAccessor.readLines("prefix-topic0-1-00000000000000000020", compression))
.containsExactly("[", "{\"name\":\"name1\"},", "{\"name\":\"name1\"},", "{\"name\":\"name1\"}", "]");
assertThat(testBucketAccessor.readLines("prefix-topic1-0-00000000000000000030", compression))
.containsExactly("[", "{\"name\":\"name2\"},", "{\"name\":\"name2\"},", "{\"name\":\"name2\"}", "]");
// Reset and send another batch of records to S3
allRecords.clear();
for (int i = 0; i < 3; i++) {
allRecords.add(
List.of(createRecordWithStructValueSchema("topic0", 0, "key0", "name0", offset1++, timestamp++),
createRecordWithStructValueSchema("topic0", 1, "key1", "name1", offset2++, timestamp++),
createRecordWithStructValueSchema("topic1", 0, "key2", "name2", offset3++, timestamp++)));
}
allRecords.forEach(task::put);
offsets.clear();
offsets.put(tp00, new OffsetAndMetadata(offset1));
offsets.put(tp01, new OffsetAndMetadata(offset2));
offsets.put(tp10, new OffsetAndMetadata(offset3));
task.flush(offsets);
expectedBlobs.clear();
expectedBlobs = Lists.newArrayList("prefix-topic0-0-00000000000000000010" + compressionType.extension(),
"prefix-topic0-1-00000000000000000020" + compressionType.extension(),
"prefix-topic1-0-00000000000000000030" + compressionType.extension(),
"prefix-topic0-0-00000000000000000013" + compressionType.extension(),
"prefix-topic0-1-00000000000000000023" + compressionType.extension(),
"prefix-topic1-0-00000000000000000033" + compressionType.extension());
assertThat(expectedBlobs).allMatch(blobName -> testBucketAccessor.doesObjectExist(blobName));

}

private SinkRecord createRecordWithStringValueSchema(final String topic, final int partition, final String key,
final String value, final int offset, final long timestamp) {
return new SinkRecord(topic, partition, Schema.STRING_SCHEMA, key, Schema.STRING_SCHEMA, value, offset,
