diff options
Diffstat (limited to 'lib/gitlab/database/unidirectional_copy_trigger.rb')
-rw-r--r-- | lib/gitlab/database/unidirectional_copy_trigger.rb | 97 |
1 files changed, 97 insertions, 0 deletions
diff --git a/lib/gitlab/database/unidirectional_copy_trigger.rb b/lib/gitlab/database/unidirectional_copy_trigger.rb new file mode 100644 index 00000000000..029c894a5ff --- /dev/null +++ b/lib/gitlab/database/unidirectional_copy_trigger.rb @@ -0,0 +1,97 @@ +# frozen_string_literal: true + +module Gitlab + module Database + class UnidirectionalCopyTrigger + def self.on_table(table_name, connection: ActiveRecord::Base.connection) + new(table_name, connection) + end + + def name(from_column_names, to_column_names) + from_column_names, to_column_names = check_column_names!(from_column_names, to_column_names) + + unchecked_name(from_column_names, to_column_names) + end + + def create(from_column_names, to_column_names, trigger_name: nil) + from_column_names, to_column_names = check_column_names!(from_column_names, to_column_names) + trigger_name ||= unchecked_name(from_column_names, to_column_names) + + assignment_clauses = assignment_clauses_for_columns(from_column_names, to_column_names) + + connection.execute(<<~SQL) + CREATE OR REPLACE FUNCTION #{trigger_name}() + RETURNS trigger AS + $BODY$ + BEGIN + #{assignment_clauses}; + RETURN NEW; + END; + $BODY$ + LANGUAGE 'plpgsql' + VOLATILE + SQL + + connection.execute(<<~SQL) + DROP TRIGGER IF EXISTS #{trigger_name} + ON #{quoted_table_name} + SQL + + connection.execute(<<~SQL) + CREATE TRIGGER #{trigger_name} + BEFORE INSERT OR UPDATE + ON #{quoted_table_name} + FOR EACH ROW + EXECUTE FUNCTION #{trigger_name}() + SQL + end + + def drop(trigger_name) + connection.execute("DROP TRIGGER IF EXISTS #{trigger_name} ON #{quoted_table_name}") + connection.execute("DROP FUNCTION IF EXISTS #{trigger_name}()") + end + + private + + attr_reader :table_name, :connection + + def initialize(table_name, connection) + @table_name = table_name + @connection = connection + end + + def quoted_table_name + @quoted_table_name ||= connection.quote_table_name(table_name) + end + + def check_column_names!(from_column_names, to_column_names) + from_column_names = Array.wrap(from_column_names) + to_column_names = Array.wrap(to_column_names) + + unless from_column_names.size == to_column_names.size + raise ArgumentError, 'number of source and destination columns must match' + end + + [from_column_names, to_column_names] + end + + def unchecked_name(from_column_names, to_column_names) + joined_column_names = from_column_names.zip(to_column_names).flatten.join('_') + 'trigger_' + Digest::SHA256.hexdigest("#{table_name}_#{joined_column_names}").first(12) + end + + def assignment_clauses_for_columns(from_column_names, to_column_names) + combined_column_names = to_column_names.zip(from_column_names) + + assignment_clauses = combined_column_names.map do |(new_name, old_name)| + new_name = connection.quote_column_name(new_name) + old_name = connection.quote_column_name(old_name) + + "NEW.#{new_name} := NEW.#{old_name}" + end + + assignment_clauses.join(";\n ") + end + end + end +end |