diff --git a/lib/job-iteration/active_record_cursor.rb b/lib/job-iteration/active_record_cursor.rb index 4789d1b1..9b16fc4e 100644 --- a/lib/job-iteration/active_record_cursor.rb +++ b/lib/job-iteration/active_record_cursor.rb @@ -19,8 +19,11 @@ def initialize end def initialize(relation, columns = nil, position = nil) - columns ||= "#{relation.table_name}.#{relation.primary_key}" - @columns = Array.wrap(columns) + @columns = if columns + Array(columns) + else + Array(relation.primary_key).map { |pk| "#{relation.table_name}.#{pk}" } + end self.position = Array.wrap(position) raise ArgumentError, "Must specify at least one column" if columns.empty? if relation.joins_values.present? && !@columns.all? { |column| column.to_s.include?(".") } diff --git a/lib/job-iteration/active_record_enumerator.rb b/lib/job-iteration/active_record_enumerator.rb index 5e7ab832..363a4ecf 100644 --- a/lib/job-iteration/active_record_enumerator.rb +++ b/lib/job-iteration/active_record_enumerator.rb @@ -10,7 +10,11 @@ class ActiveRecordEnumerator def initialize(relation, columns: nil, batch_size: 100, cursor: nil) @relation = relation @batch_size = batch_size - @columns = Array(columns || "#{relation.table_name}.#{relation.primary_key}") + @columns = if columns + Array(columns) + else + Array(relation.primary_key).map { |pk| "#{relation.table_name}.#{pk}" } + end @cursor = cursor end diff --git a/test/test_helper.rb b/test/test_helper.rb index 80cbd963..f684e031 100644 --- a/test/test_helper.rb +++ b/test/test_helper.rb @@ -44,6 +44,10 @@ def enqueue_at(job, _delay) class Product < ActiveRecord::Base end +class TravelRoute < ActiveRecord::Base + self.primary_key = [:origin, :destination] +end + host = ENV["USING_DEV"] == "1" ? "job-iteration.railgun" : "localhost" connection_config = { @@ -68,11 +72,16 @@ class Product < ActiveRecord::Base config.redis = { host: host } end -ActiveRecord::Base.connection.create_table(Product.table_name, force: true) do |t| +ActiveRecord::Base.connection.create_table(:products, force: true) do |t| t.string(:name) t.timestamps end +ActiveRecord::Base.connection.create_table(:travel_routes, force: true, primary_key: [:origin, :destination]) do |t| + t.string(:destination) + t.string(:origin) +end + module LoggingHelpers def assert_logged(message) old_logger = ActiveJob::Base.logger @@ -124,6 +133,7 @@ def insert_fixtures end def truncate_fixtures + ActiveRecord::Base.connection.truncate(TravelRoute.table_name) ActiveRecord::Base.connection.truncate(Product.table_name) end end diff --git a/test/unit/active_record_enumerator_test.rb b/test/unit/active_record_enumerator_test.rb index 26fc886d..f1a83ab1 100644 --- a/test/unit/active_record_enumerator_test.rb +++ b/test/unit/active_record_enumerator_test.rb @@ -105,6 +105,19 @@ class ActiveRecordEnumeratorTest < IterationUnitTest assert_equal(10, enum.size) end + test "enumerator for a relation with a composite primary key" do + TravelRoute.create!(origin: "A", destination: "B") + TravelRoute.create!(origin: "A", destination: "C") + TravelRoute.create!(origin: "B", destination: "A") + + enum = build_enumerator(relation: TravelRoute.all, batch_size: 2) + + cursors = [] + enum.records.each { |_record, cursor| cursors << cursor } + + assert_equal([["A", "B"], ["A", "C"], ["B", "A"]], cursors) + end if ActiveRecord.version >= Gem::Version.new("7.1.0.alpha") + private def build_enumerator(relation: Product.all, batch_size: 2, columns: nil, cursor: nil)