diff options
Diffstat (limited to 'lib/gitlab/database/each_database.rb')
-rw-r--r-- | lib/gitlab/database/each_database.rb | 37 |
1 files changed, 29 insertions, 8 deletions
diff --git a/lib/gitlab/database/each_database.rb b/lib/gitlab/database/each_database.rb index c3eea0515d4..cccd4b48723 100644 --- a/lib/gitlab/database/each_database.rb +++ b/lib/gitlab/database/each_database.rb @@ -4,8 +4,11 @@ module Gitlab module Database module EachDatabase class << self - def each_database_connection - Gitlab::Database.database_base_models.each_pair do |connection_name, model| + def each_database_connection(only: nil) + selected_names = Array.wrap(only) + base_models = select_base_models(selected_names) + + base_models.each_pair do |connection_name, model| connection = model.connection with_shared_connection(connection, connection_name) do @@ -14,34 +17,52 @@ module Gitlab end end - def each_model_connection(models, &blk) + def each_model_connection(models, only_on: nil, &blk) + selected_databases = Array.wrap(only_on).map(&:to_sym) + models.each do |model| # If model is shared, iterate all available base connections # Example: `LooseForeignKeys::DeletedRecord` if model < ::Gitlab::Database::SharedModel - with_shared_model_connections(model, &blk) + with_shared_model_connections(model, selected_databases, &blk) else - with_model_connection(model, &blk) + with_model_connection(model, selected_databases, &blk) end end end private - def with_shared_model_connections(shared_model, &blk) + def select_base_models(names) + base_models = Gitlab::Database.database_base_models + + return base_models if names.empty? + + names.each_with_object(HashWithIndifferentAccess.new) do |name, hash| + raise ArgumentError, "#{name} is not a valid database name" unless base_models.key?(name) + + hash[name] = base_models[name] + end + end + + def with_shared_model_connections(shared_model, selected_databases, &blk) Gitlab::Database.database_base_models.each_pair do |connection_name, connection_model| if shared_model.limit_connection_names next unless shared_model.limit_connection_names.include?(connection_name.to_sym) end + next if selected_databases.present? && !selected_databases.include?(connection_name.to_sym) + with_shared_connection(connection_model.connection, connection_name) do yield shared_model, connection_name end end end - def with_model_connection(model, &blk) - connection_name = model.connection.pool.db_config.name + def with_model_connection(model, selected_databases, &blk) + connection_name = model.connection_db_config.name + + return if selected_databases.present? && !selected_databases.include?(connection_name.to_sym) with_shared_connection(model.connection, connection_name) do yield model, connection_name |