summaryrefslogtreecommitdiff
path: root/lib/gitlab/database/each_database.rb
blob: cccd4b48723b8f6fe12f240544c9e7d50d917f88 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# frozen_string_literal: true

module Gitlab
  module Database
    # Helpers that run a block once per database, each time inside
    # `Gitlab::Database::SharedModel.using_connection` for the connection
    # being iterated.
    module EachDatabase
      class << self
        # Yields every selected base model's connection.
        #
        # @param only [String, Symbol, Array<String, Symbol>, nil] database names
        #   to restrict iteration to; `nil` (the default) iterates every entry of
        #   `Gitlab::Database.database_base_models`.
        # @yieldparam connection the base model's `connection`
        # @yieldparam connection_name the database name (key of the base-model hash)
        # @raise [ArgumentError] if a name given in `only` is not a known database;
        #   this is checked before anything is yielded.
        def each_database_connection(only: nil)
          selected_names = Array.wrap(only)
          base_models = select_base_models(selected_names)

          base_models.each_pair do |connection_name, model|
            connection = model.connection

            with_shared_connection(connection, connection_name) do
              yield connection, connection_name
            end
          end
        end

        # Yields each model together with the name of every database it should
        # be processed on.
        #
        # Models inheriting from `Gitlab::Database::SharedModel` (for example
        # `LooseForeignKeys::DeletedRecord`) are yielded once per base database,
        # optionally narrowed by the model's `limit_connection_names`. Any other
        # model is yielded once, on its own `connection`.
        #
        # @param models [Enumerable<Class>] the models to iterate
        # @param only_on [String, Symbol, Array<String, Symbol>, nil] restricts
        #   iteration to these database names; `nil` means no restriction.
        #   Names are not validated here: an unknown name simply matches nothing.
        # @yieldparam model [Class] the model being processed
        # @yieldparam connection_name the database name for this iteration
        def each_model_connection(models, only_on: nil, &blk)
          selected_databases = Array.wrap(only_on).map(&:to_sym)

          models.each do |model|
            # Shared models are iterated over every available base connection.
            if model < ::Gitlab::Database::SharedModel
              with_shared_model_connections(model, selected_databases, &blk)
            else
              with_model_connection(model, selected_databases, &blk)
            end
          end
        end

        private

        # Returns the entries of `Gitlab::Database.database_base_models` named in
        # `names`, or all of them when `names` is empty. Raises ArgumentError for
        # an unknown name.
        def select_base_models(names)
          base_models = Gitlab::Database.database_base_models

          return base_models if names.empty?

          names.each_with_object(HashWithIndifferentAccess.new) do |name, hash|
            raise ArgumentError, "#{name} is not a valid database name" unless base_models.key?(name)

            hash[name] = base_models[name]
          end
        end

        # Yields `shared_model` once for every base database that passes both the
        # model's own `limit_connection_names` and the caller's selection.
        def with_shared_model_connections(shared_model, selected_databases)
          # Read once up front; a falsy value means the model is not limited.
          allowed_names = shared_model.limit_connection_names

          Gitlab::Database.database_base_models.each_pair do |connection_name, connection_model|
            next if allowed_names && !allowed_names.include?(connection_name.to_sym)
            next unless database_selected?(selected_databases, connection_name)

            with_shared_connection(connection_model.connection, connection_name) do
              yield shared_model, connection_name
            end
          end
        end

        # Yields `model` once, on its own connection, if its database is selected.
        def with_model_connection(model, selected_databases)
          connection_name = model.connection_db_config.name

          return unless database_selected?(selected_databases, connection_name)

          with_shared_connection(model.connection, connection_name) do
            yield model, connection_name
          end
        end

        # True when no databases were selected (no restriction) or when
        # `connection_name` is one of the selected databases.
        def database_selected?(selected_databases, connection_name)
          selected_databases.empty? || selected_databases.include?(connection_name.to_sym)
        end

        # Runs the block inside `SharedModel.using_connection(connection)`,
        # logging the switch at debug level first.
        def with_shared_connection(connection, connection_name)
          Gitlab::Database::SharedModel.using_connection(connection) do
            Gitlab::AppLogger.debug(message: 'Switched database connection', connection_name: connection_name)

            yield
          end
        end
      end
    end
  end
end