diff --git a/lib/fixture_kit/coders/active_record_coder.rb b/lib/fixture_kit/coders/active_record_coder.rb index 32138ba..e42f394 100644 --- a/lib/fixture_kit/coders/active_record_coder.rb +++ b/lib/fixture_kit/coders/active_record_coder.rb @@ -30,11 +30,17 @@ def generate(parent_data: nil, &block) end def mount(data) - statements_by_connection(data).each do |connection, statements| - connection.disable_referential_integrity do - # execute_batch is private in current supported Rails versions. - # This should be revisited when Rails 8.2 makes it public. - connection.__send__(:execute_batch, statements, "FixtureKit Insert") + models_by_pool(data).each do |pool, models| + pool.with_connection do |connection| + statements = models.flat_map do |model| + [build_delete_sql(connection, model.table_name), data[model]].compact + end + + connection.disable_referential_integrity do + # execute_batch is private in current supported Rails versions. + # This should be revisited when Rails 8.2 makes it public. + connection.__send__(:execute_batch, statements, "FixtureKit Insert") + end end end end @@ -70,8 +76,8 @@ def generate_statements(models) end end - def build_delete_sql(model) - "DELETE FROM #{model.quoted_table_name}" + def build_delete_sql(connection, table_name) + "DELETE FROM #{connection.quote_table_name(table_name)}" end def build_insert_sql(table_name, columns, rows, connection) @@ -81,19 +87,15 @@ def build_insert_sql(table_name, columns, rows, connection) "INSERT INTO #{quoted_table} (#{quoted_columns.join(", ")}) VALUES #{rows.join(", ")}" end - def statements_by_connection(records) - deleted_tables = Set.new + def models_by_pool(data) + seen = Set.new - records.each_with_object({}) do |(model, sql), grouped| - connection = model.connection - grouped[connection] ||= [] - - table_key = [connection, model.table_name] - if deleted_tables.add?(table_key) - grouped[connection] << build_delete_sql(model) - end + data.each_with_object({}) do |(model, _), grouped| + pool = model.connection_pool + next unless seen.add?([pool, model.table_name]) - grouped[connection] << sql if sql + grouped[pool] ||= [] + grouped[pool] << model end end end