summaryrefslogtreecommitdiff
path: root/choose_rounds.py
blob: 294995023aa13d6c71e62624e26beab752a64056 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
"""cli helper for selecting appropriate <rounds> value for a given hash"""
#=============================================================================
# imports
#=============================================================================
# core
import math
import logging; log = logging.getLogger(__name__)
import sys
# site
# pkg
from passlib.registry import get_crypt_handler
from passlib.utils import tick
# local
# the only public name exported by this module
__all__ = ["main"]

#=============================================================================
# main
#=============================================================================
_usage = "usage: python choose_rounds.py <hash_name> [<target_milliseconds>] [<backend>]\n"

def main(*args):
    """command-line entry point: suggest a <rounds> value for a hash.

    :param args:
        command-line arguments (without the program name), in the order
        ``<hash_name> [<target_milliseconds>] [<backend>]``.
        passing ``-h`` or ``--help`` as the first argument prints usage.
        ``<target_milliseconds>`` defaults to 350.

    :returns:
        ``1`` if help was requested or an argument was invalid,
        otherwise ``None`` after the report has been printed
        (so ``sys.exit(main(...))`` exits with status 0 on success).

    speed is estimated by timing ``hasher.verify()``: first at half the
    handler's default cost, then twice more at the rounds value predicted
    to take roughly the target time, so the final sample is close in size
    to the value being recommended.
    """
    #---------------------------------------------------------------
    # parse args
    #---------------------------------------------------------------
    args = list(args)
    def print_error(msg):
        print("error: %s\n" % msg)

    # parse hasher
    if not args:
        print_error("hash name not specified")
        print(_usage)
        return 1
    name = args.pop(0)
    if name in ("-h", "--help"):
        print(_usage)
        return 1
    try:
        hasher = get_crypt_handler(name)
    except KeyError:
        print_error("unknown hash %r" % name)
        return 1
    if 'rounds' not in hasher.setting_kwds:
        print_error("%s does not support variable rounds" % name)
        return 1

    # parse target time. keep the integer milliseconds for display
    # (re-deriving it from float seconds can truncate e.g. 290 -> 289),
    # and use seconds for the speed math.
    if args:
        try:
            target_ms = int(args.pop(0))
            if target_ms <= 0:
                raise ValueError
        except ValueError:
            print_error("target time must be integer milliseconds > 0")
            return 1
    else:
        target_ms = 350
    target = target_ms / 1000.0

    # parse backend
    if args:
        backend = args.pop(0)
        if hasattr(hasher, "set_backend"):
            hasher.set_backend(backend)
        else:
            print_error("%s does not support multiple backends" % name)
            return 1

    #---------------------------------------------------------------
    # setup some helper functions
    #---------------------------------------------------------------
    if hasher.rounds_cost == "log2":
        # rounds parameter is the log2 of the work done,
        # so speed = (2**rounds) / elapsed
        def rounds_to_cost(rounds):
            return 2 ** rounds
        def cost_to_rounds(cost):
            return math.log(cost, 2)
    else:
        # time cost varies linearly with rounds parameter,
        # so speed = rounds / elapsed
        assert hasher.rounds_cost == "linear"
        rounds_to_cost = cost_to_rounds = lambda value: value

    def clamp_rounds(rounds):
        """convert float rounds to int value, clamped to hasher's limits"""
        if hasher.max_rounds and rounds > hasher.max_rounds:
            rounds = hasher.max_rounds
        rounds = int(rounds)
        if getattr(hasher, "_avoid_even_rounds", False):
            # handler flags even values as undesirable; force the low bit on
            rounds |= 1
        return max(hasher.min_rounds, rounds)

    def average(seq):
        """arithmetic mean of *seq*; iterators are materialized first"""
        if not hasattr(seq, "__len__"):
            seq = tuple(seq)
        return sum(seq) / len(seq)

    def estimate_speed(rounds):
        """estimate speed (cost units per second) using specified # of rounds"""
        # time a single verify() call
        secret = "S0m3-S3Kr1T"
        hash = hasher.using(rounds=rounds).hash(secret)
        def helper():
            start = tick()
            hasher.verify(secret, hash)
            return tick() - start
        # best of 4 runs, each the mean of 4 verify() timings.
        # XXX: way too much variability between sampling runs,
        #      would like to improve this bit
        elapsed = min(average(helper() for _ in range(4)) for _ in range(4))
        return rounds_to_cost(rounds) / elapsed

    #---------------------------------------------------------------
    # get rough estimate of speed using fraction of default_rounds
    # (so we don't take crazy long amounts of time on slow systems)
    #---------------------------------------------------------------
    rounds = clamp_rounds(cost_to_rounds(.5 * rounds_to_cost(hasher.default_rounds)))
    speed = estimate_speed(rounds)

    #---------------------------------------------------------------
    # re-do estimate using previous result,
    # to get more accurate sample using a larger number of rounds.
    #---------------------------------------------------------------
    for _ in range(2):
        rounds = clamp_rounds(cost_to_rounds(speed * target))
        speed = estimate_speed(rounds)

    #---------------------------------------------------------------
    # using final estimate, calc desired number of rounds for target time
    #---------------------------------------------------------------
    if hasattr(hasher, "backends"):
        name = "%s (using %s backend)" % (name, hasher.get_backend())
    print("hash............: %s" % name)
    if speed < 1000:
        speedstr = "%.2f" % speed
    else:
        speedstr = int(speed)
    print("speed...........: %s iterations/second" % speedstr)
    print("target time.....: %d ms" % (target_ms,))
    rounds = cost_to_rounds(speed * target)
    if hasher.rounds_cost == "log2":
        # for log2 rounds parameter, target time will usually fall
        # somewhere between two integer values, which will have large gulf
        # between them. if target is within <tolerance> percent of
        # one of two ends, report it, otherwise list both and let user decide.
        tolerance = .05
        lower = clamp_rounds(rounds)
        upper = clamp_rounds(math.ceil(rounds))
        lower_elapsed = rounds_to_cost(lower) / speed
        upper_elapsed = rounds_to_cost(upper) / speed
        if (target-lower_elapsed)/target < tolerance:
            print("target rounds...: %d" % lower)
        elif (upper_elapsed-target)/target < tolerance:
            print("target rounds...: %d" % upper)
        else:
            faster = (target - lower_elapsed)
            print("target rounds...: %d (%dms -- %dms/%d%% faster than requested)" %
                  (lower, lower_elapsed*1000, faster * 1000, round(100 * faster / target)))
            slower = (upper_elapsed - target)
            print("target rounds...: %d (%dms -- %dms/%d%% slower than requested)" %
                  (upper, upper_elapsed*1000, slower * 1000, round(100 * slower / target)))
    else:
        # for linear rounds parameter, just use nearest integer value
        rounds = clamp_rounds(round(rounds))
        print("target rounds...: %d" % (rounds,))
    print()

if __name__ == "__main__":
    # forward everything after the program name; main()'s return value
    # (None on success, 1 on error) becomes the process exit status
    _exit_status = main(*sys.argv[1:])
    sys.exit(_exit_status)

#=============================================================================
# eof
#=============================================================================