summaryrefslogtreecommitdiff
path: root/spec/support/database
diff options
context:
space:
mode:
Diffstat (limited to 'spec/support/database')
-rw-r--r--spec/support/database/ci_tables.rb22
-rw-r--r--spec/support/database/prevent_cross_database_modification.rb109
-rw-r--r--spec/support/database/prevent_cross_joins.rb77
3 files changed, 208 insertions, 0 deletions
diff --git a/spec/support/database/ci_tables.rb b/spec/support/database/ci_tables.rb
new file mode 100644
index 00000000000..99fc7ac2501
--- /dev/null
+++ b/spec/support/database/ci_tables.rb
@@ -0,0 +1,22 @@
+# frozen_string_literal: true
+
+# This module stores the CI-related database tables which are
+# going to be moved to a separate database.
+module Database
+ module CiTables
+ def self.include?(name)
+ ci_tables.include?(name)
+ end
+
+ def self.ci_tables
+ @@ci_tables ||= Set.new.tap do |tables| # rubocop:disable Style/ClassVars
+ tables.merge(Ci::ApplicationRecord.descendants.map(&:table_name).compact)
+
+ # It was decided that taggings/tags are best placed with CI
+ # https://gitlab.com/gitlab-org/gitlab/-/issues/333413
+ tables.add('taggings')
+ tables.add('tags')
+ end
+ end
+ end
+end
diff --git a/spec/support/database/prevent_cross_database_modification.rb b/spec/support/database/prevent_cross_database_modification.rb
new file mode 100644
index 00000000000..460ee99391b
--- /dev/null
+++ b/spec/support/database/prevent_cross_database_modification.rb
@@ -0,0 +1,109 @@
+# frozen_string_literal: true
+
+module Database
+ module PreventCrossDatabaseModification
+ CrossDatabaseModificationAcrossUnsupportedTablesError = Class.new(StandardError)
+
+ module GitlabDatabaseMixin
+ def allow_cross_database_modification_within_transaction(url:)
+ cross_database_context = Database::PreventCrossDatabaseModification.cross_database_context
+ return yield unless cross_database_context && cross_database_context[:enabled]
+
+ transaction_tracker_enabled_was = cross_database_context[:enabled]
+ cross_database_context[:enabled] = false
+
+ yield
+ ensure
+ cross_database_context[:enabled] = transaction_tracker_enabled_was if cross_database_context
+ end
+ end
+
+ module SpecHelpers
+ def with_cross_database_modification_prevented
+ subscriber = ActiveSupport::Notifications.subscribe('sql.active_record') do |name, start, finish, id, payload|
+ PreventCrossDatabaseModification.prevent_cross_database_modification!(payload[:connection], payload[:sql])
+ end
+
+ PreventCrossDatabaseModification.reset_cross_database_context!
+ PreventCrossDatabaseModification.cross_database_context.merge!(enabled: true, subscriber: subscriber)
+
+ yield if block_given?
+ ensure
+ cleanup_with_cross_database_modification_prevented if block_given?
+ end
+
+ def cleanup_with_cross_database_modification_prevented
+ ActiveSupport::Notifications.unsubscribe(PreventCrossDatabaseModification.cross_database_context[:subscriber])
+ PreventCrossDatabaseModification.cross_database_context[:enabled] = false
+ end
+ end
+
+ def self.cross_database_context
+ Thread.current[:transaction_tracker]
+ end
+
+ def self.reset_cross_database_context!
+ Thread.current[:transaction_tracker] = initial_data
+ end
+
+ def self.initial_data
+ {
+ enabled: false,
+ transaction_depth_by_db: Hash.new { |h, k| h[k] = 0 },
+ modified_tables_by_db: Hash.new { |h, k| h[k] = Set.new }
+ }
+ end
+
+ def self.prevent_cross_database_modification!(connection, sql)
+ return unless cross_database_context[:enabled]
+
+ database = connection.pool.db_config.name
+
+ if sql.start_with?('SAVEPOINT')
+ cross_database_context[:transaction_depth_by_db][database] += 1
+
+ return
+ elsif sql.start_with?('RELEASE SAVEPOINT', 'ROLLBACK TO SAVEPOINT')
+ cross_database_context[:transaction_depth_by_db][database] -= 1
+ if cross_database_context[:transaction_depth_by_db][database] <= 0
+ cross_database_context[:modified_tables_by_db][database].clear
+ end
+
+ return
+ end
+
+ return if cross_database_context[:transaction_depth_by_db].values.all?(&:zero?)
+
+ tables = PgQuery.parse(sql).dml_tables
+
+ return if tables.empty?
+
+ cross_database_context[:modified_tables_by_db][database].merge(tables)
+
+ all_tables = cross_database_context[:modified_tables_by_db].values.map(&:to_a).flatten
+
+ unless PreventCrossJoins.only_ci_or_only_main?(all_tables)
+ raise Database::PreventCrossDatabaseModification::CrossDatabaseModificationAcrossUnsupportedTablesError,
+ "Cross-database data modification queries (CI and Main) were detected within " \
+ "a transaction '#{all_tables.join(", ")}' discovered"
+ end
+ end
+ end
+end
+
+Gitlab::Database.singleton_class.prepend(
+ Database::PreventCrossDatabaseModification::GitlabDatabaseMixin)
+
+RSpec.configure do |config|
+ config.include(::Database::PreventCrossDatabaseModification::SpecHelpers)
+
+ # Using before and after blocks because the around block causes problems with the let_it_be
+ # record creations. It makes an extra savepoint which breaks the transaction count logic.
+ config.before(:each, :prevent_cross_database_modification) do
+ with_cross_database_modification_prevented
+ end
+
+ config.after(:each, :prevent_cross_database_modification) do
+ cleanup_with_cross_database_modification_prevented
+ end
+end
diff --git a/spec/support/database/prevent_cross_joins.rb b/spec/support/database/prevent_cross_joins.rb
new file mode 100644
index 00000000000..789721ccd38
--- /dev/null
+++ b/spec/support/database/prevent_cross_joins.rb
@@ -0,0 +1,77 @@
+# frozen_string_literal: true
+
+# This module tries to discover and prevent cross-joins across tables
+# This will forbid usage of tables between CI and main database
+# on a same query unless explicitly allowed by. This will change execution
+# from a given point to allow cross-joins. The state will be cleared
+# on a next test run.
+#
+# This method should be used to mark METHOD introducing cross-join
+# not a test using the cross-join.
+#
+# class User
+# def ci_owned_runners
+# ::Gitlab::Database.allow_cross_joins_across_databases!(url: link-to-issue-url)
+#
+# ...
+# end
+# end
+
+module Database
+ module PreventCrossJoins
+ CrossJoinAcrossUnsupportedTablesError = Class.new(StandardError)
+
+ def self.validate_cross_joins!(sql)
+ return if Thread.current[:allow_cross_joins_across_databases]
+
+ # PgQuery might fail in some cases due to limited nesting:
+ # https://github.com/pganalyze/pg_query/issues/209
+ tables = PgQuery.parse(sql).tables
+
+ unless only_ci_or_only_main?(tables)
+ raise CrossJoinAcrossUnsupportedTablesError,
+ "Unsupported cross-join across '#{tables.join(", ")}' discovered " \
+ "when executing query '#{sql}'"
+ end
+ end
+
+ # Returns true if a set includes only CI tables, or includes only non-CI tables
+ def self.only_ci_or_only_main?(tables)
+ tables.all? { |table| CiTables.include?(table) } ||
+ tables.none? { |table| CiTables.include?(table) }
+ end
+
+ module SpecHelpers
+ def with_cross_joins_prevented
+ subscriber = ActiveSupport::Notifications.subscribe('sql.active_record') do |event|
+ ::Database::PreventCrossJoins.validate_cross_joins!(event.payload[:sql])
+ end
+
+ Thread.current[:allow_cross_joins_across_databases] = false
+
+ yield
+ ensure
+ ActiveSupport::Notifications.unsubscribe(subscriber) if subscriber
+ end
+ end
+
+ module GitlabDatabaseMixin
+ def allow_cross_joins_across_databases(url:)
+ Thread.current[:allow_cross_joins_across_databases] = true
+ super
+ end
+ end
+ end
+end
+
+Gitlab::Database.singleton_class.prepend(
+ Database::PreventCrossJoins::GitlabDatabaseMixin)
+
+RSpec.configure do |config|
+ config.include(::Database::PreventCrossJoins::SpecHelpers)
+
+ # TODO: remove `:prevent_cross_joins` to enable the check by default
+ config.around(:each, :prevent_cross_joins) do |example|
+ with_cross_joins_prevented { example.run }
+ end
+end