summaryrefslogtreecommitdiff
path: root/app/services/ml/experiment_tracking/candidate_repository.rb
blob: f1fd93d78162c0e2b0274b5e144df62fb8d32f50 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# frozen_string_literal: true

module Ml
  module ExperimentTracking
    # Persistence helpers for ML candidates and the metrics, params and
    # metadata (tags) attached to them.
    #
    # Single-record writers (+add_metric!+, +add_param!+, +add_tag!+) go through
    # the candidate's associations and raise on failure. Bulk writers
    # (+add_metrics+, +add_params+, +add_tags+) build plain attribute hashes and
    # hand them to +insert_all+ on the entity class.
    class CandidateRepository
      # NOTE(review): +experiment+ and +candidate+ are never assigned inside this
      # class; the accessors are kept so existing callers are unaffected.
      attr_accessor :project, :user, :experiment, :candidate

      # @param project [#id] project whose id scopes lookups in {#by_iid}
      # @param user [Object] user recorded on candidates created by {#create!}
      def initialize(project, user)
        @project = project
        @user = user
      end

      # Looks a candidate up through
      # +::Ml::Candidate.with_project_id_and_iid+ using this repository's project.
      #
      # @param iid [Object] internal id passed through unchanged
      # @return whatever +with_project_id_and_iid+ returns
      def by_iid(iid)
        ::Ml::Candidate.with_project_id_and_iid(project.id, iid)
      end

      # Creates a candidate on +experiment+ and bulk-inserts its tags.
      #
      # The candidate name is +name+ when present, otherwise the value of the
      # 'mlflow.runName' tag (if any). A missing +start_time+ is stored as 0.
      #
      # @param experiment [#candidates] owner of the new candidate
      # @param start_time [Object, nil] start time; falls back to 0 when nil
      # @param tags [Array<Hash>, nil] tag definitions with +:key+ / +:value+
      # @param name [String, nil] explicit candidate name
      # @return the created candidate
      # @raise whatever +experiment.candidates.create!+ raises
      def create!(experiment, start_time, tags = nil, name = nil)
        candidate = experiment.candidates.create!(
          user: user,
          name: candidate_name(name, tags),
          start_time: start_time || 0
        )

        add_tags(candidate, tags)

        candidate
      end

      # Assigns the given attributes and saves the candidate.
      #
      # @param candidate [Object] record to update
      # @param status [String, nil] new status; downcased before assignment,
      #   left untouched when nil
      # @param end_time [Object, nil] new end time; left untouched when nil
      # @return the result of +candidate.save+
      def update(candidate, status, end_time)
        candidate.status = status.downcase if status
        candidate.end_time = end_time if end_time

        candidate.save
      end

      # Creates one metric through +candidate.metrics+.
      #
      # @return the created metric
      # @raise whatever +candidate.metrics.create!+ raises
      def add_metric!(candidate, name, value, tracked_at, step)
        candidate.metrics.create!(
          name: name,
          value: value,
          tracked_at: tracked_at,
          step: step
        )
      end

      # Creates one param through +candidate.params+.
      #
      # @return the created param
      # @raise whatever +candidate.params.create!+ raises
      def add_param!(candidate, name, value)
        candidate.params.create!(name: name, value: value)
      end

      # Creates one tag; tags are stored through +candidate.metadata+.
      #
      # @return the created metadata record
      # @raise whatever +candidate.metadata.create!+ raises
      def add_tag!(candidate, name, value)
        candidate.metadata.create!(name: name, value: value)
      end

      # Bulk-inserts metrics. Each definition is read with +:key+, +:value+,
      # +:timestamp+ (stored as +tracked_at+) and +:step+.
      #
      # @param candidate [#id, nil]
      # @param metric_definitions [Array<Hash>, nil]
      # @return the +insert_all+ result, or nil when nothing was inserted
      def add_metrics(candidate, metric_definitions)
        extra_keys = { tracked_at: :timestamp, step: :step }
        insert_many(candidate, metric_definitions, ::Ml::CandidateMetric, extra_keys)
      end

      # Bulk-inserts params. Each definition is read with +:key+ and +:value+.
      #
      # @param candidate [#id, nil]
      # @param param_definitions [Array<Hash>, nil]
      # @return the +insert_all+ result, or nil when nothing was inserted
      def add_params(candidate, param_definitions)
        insert_many(candidate, param_definitions, ::Ml::CandidateParam)
      end

      # Bulk-inserts tags. Each definition is read with +:key+ and +:value+.
      #
      # @param candidate [#id, nil]
      # @param tag_definitions [Array<Hash>, nil]
      # @return the +insert_all+ result, or nil when nothing was inserted
      def add_tags(candidate, tag_definitions)
        insert_many(candidate, tag_definitions, ::Ml::CandidateMetadata)
      end

      private

      # @return [Hash] +created_at+ / +updated_at+ set to the same current time
      def timestamps
        current_time = Time.zone.now

        { created_at: current_time, updated_at: current_time }
      end

      # Builds one attribute hash per definition and inserts them in a single
      # +insert_all+ call on +entity_class+.
      #
      # @param candidate [#id, nil] owner of the rows; nothing happens when blank
      # @param definitions [Array<Hash>, nil] source hashes; nothing happens
      #   when blank
      # @param entity_class [Class] class responding to +insert_all+
      # @param extra_keys [Hash{Symbol => Symbol}] column name => definition key
      #   for columns beyond +name+ and +value+
      # @return the +insert_all+ result, or nil when nothing was inserted
      def insert_many(candidate, definitions, entity_class, extra_keys = {})
        return unless candidate.present? && definitions.present?

        # Sample the clock once so every row of this batch carries the same
        # created_at/updated_at instead of drifting row by row.
        batch_timestamps = timestamps

        entities = definitions.map do |d|
          {
            candidate_id: candidate.id,
            name: d[:key],
            value: d[:value],
            **extra_keys.transform_values { |old_key| d[old_key] },
            **batch_timestamps
          }
        end

        entity_class.insert_all(entities, returning: false) unless entities.empty?
      end

      # Resolves the name for a new candidate.
      #
      # @param name [String, nil] explicit name, used when present
      # @param tags [Array<Hash>, nil] searched for the 'mlflow.runName' key
      # @return [Object, nil] the explicit name, the tag value, or nil
      def candidate_name(name, tags)
        return name if name.present?
        return unless tags.present?

        tags.detect { |t| t[:key] == 'mlflow.runName' }&.dig(:value)
      end
    end
  end
end