#!/usr/bin/python
#
# Copyright 2009 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test server for Safebrowsing protocol v2.

To test an implementation of the safebrowsing protocol, this server should
be run on the same machine as the client implementation.  The client should
connect to this server at localhost:port where port is specified as a command
line flag (--port) and perform updates normally, except that each request
should have an additional CGI param "test_step" that specifies which update
request this is for the client.  That is, it should be incremented after the
complete parsing of a downloads request, so that a downloads request and its
associated redirects all have the same test_step. The client should
also make a newkey request and then a getlists request before making the
first update request and should use test_step=1 for these requests (test_step
is 1 indexed). When the client believes that it is done with testing (because
it receives a response from an update request with no new data), it should
make a "/test_complete" request. This will return "yes" if the test is
complete and "no" otherwise.
"""

__author__ = 'gcasto@google.com (Garrett Casto)'

import BaseHTTPServer
import binascii
import base64
import cgi
import hmac
from optparse import OptionParser
import re
import sha
import sys
from threading import Timer
import time
import urlparse

import external_test_pb2

DEFAULT_PORT = 40101
DEFAULT_DATAFILE_LOCATION = "testing_input.dat"
POST_DATA_KEY = "post_data"
GETHASH_PATH = "/safebrowsing/gethash"
RESET_PATH = "/reset"
DOWNLOADS_PATH = "/safebrowsing/downloads"
TEST_COMPLETE_PATH = "/test_complete"
DATABASE_VALIDATION_PATH = "/safebrowsing/verify_database"

# Dict of step -> List of (request_path, param key, response)
response_data_by_step = {}
# Dict of step -> Dict of hash_prefix ->
# (full length hash response, expression, num times requested)
hash_data_by_step = {}
client_key = None
enforce_caching = False
validate_database = True
datafile_location = ''

def EndServer():
  sys.exit(0)

def CGIParamsToListOfTuples(cgi_params):
  return [(param.Name, param.Value) for param in cgi_params]

def SortedTupleFromParamsAndPostData(params,
                                     post_data):
  """ Make a sorted tuple from the request such that it can be inserted as
  a key in a map. params is a list of (name, value) tuples and post_data is
  a string (or None).
  """
  if post_data:
    params.append((POST_DATA_KEY, tuple(sorted(post_data.split('\n')))))
  return tuple(sorted(params))

def LoadData():
  """ Load data from datafile_location to be used by the testing server.
  """
  global response_data_by_step
  global hash_data_by_step
  global client_key
  data_file = open(datafile_location, 'rb')
  str_data = data_file.read()
  test_data = external_test_pb2.TestData()
  test_data.ParseFromString(str_data)
  print "Data Loaded"
  if test_data.HasField('ClientKey'):
    client_key = test_data.ClientKey
  else:
    client_key = None
  step = 0
  response_data_by_step = {}
  hash_data_by_step = {}
  for step_data in test_data.Steps:
    step += 1
    step_list = []
    for request_data in step_data.Requests:
      params_tuple = SortedTupleFromParamsAndPostData(
          CGIParamsToListOfTuples(request_data.Params),
          request_data.PostData)
      step_list.append((request_data.RequestPath,
                        params_tuple,
                        request_data.ServerResponse))
    response_data_by_step[step] = step_list

    hash_step_dict = {}
    for hash_request in step_data.Hashes:
      hash_step_dict[hash_request.HashPrefix] = (hash_request.ServerResponse,
                                                 hash_request.Expression,
                                                 0)
    hash_data_by_step[step] = hash_step_dict
  print "Data Parsed"

def VerifyTestComplete():
  """ Returns true if all the necessary requests have been made by the client.
  """
  global response_data_by_step
  global hash_data_by_step
  global enforce_caching

  complete = True
  for (step, step_list) in response_data_by_step.iteritems():
    if len(step_list):
      print ("Step %s has %d request(s) that were not made %s" %
             (step, len(step_list), step_list))
      complete = False

  for (step, hash_step_dict) in hash_data_by_step.iteritems():
    for (prefix,
         (response, expression, num_requests)) in hash_step_dict.iteritems():
      if ((enforce_caching and num_requests != 1) or
          num_requests == 0):
        print ("Hash prefix %s not requested the correct number of times"
               "(%d requests). Requests originated because of expression"
               " %s. Prefix is located in the following locations" %
               (binascii.hexlify(prefix),
                num_requests,
                expression))
        # This information is slightly redundant with what will be printed below
        # but it is occasionally worth seeing.
        print "Response %s" % response
        cur_index = 0
        while cur_index < len(response):
          end_header_index = response.find('\n', cur_index + 1)
          header = response[cur_index:end_header_index]
          (listname, chunk_num, hashdatalen) = header.split(":")
          print "   List '%s' in add chunk num %s" % (listname, chunk_num)
          cur_index = end_header_index + int(hashdatalen) + 1

        complete = False

  # TODO(gcasto): Have a check here that verifies that the client doesn't
  # make too many hash requests during the test run.

  return complete

class RequestHandler(BaseHTTPServer.BaseHTTPRequestHandler):
  def ParamDictToListOfTuples(self, params):
    # params is a dict mapping cgi param names to lists of values.  Since we
    # never expect a parameter to be specified multiple times, we just take
    # the first one.
    return [(name, value[0]) for (name, value) in params.iteritems()]

  def MakeParamKey(self, params, post_data=None):
    """ Make a lookup key from the request.
    """
    return SortedTupleFromParamsAndPostData(
        self.ParamDictToListOfTuples(params),
        post_data)

  def MACResponse(self, response, is_downloads_request):
    """ Returns the response wrapped with a MAC. Formatting will change
    if this is a downloads_request or hashserver_request.  If no client_key
    is set, returns the response as-is.
    """
    if client_key is None:
      return response
    unescaped_mac = hmac.new(client_key, response, sha).digest()
    return "%s%s\n%s" % (is_downloads_request and "m:" or "",
                       base64.urlsafe_b64encode(unescaped_mac),
                       response)

  def VerifyRequest(self, is_post_request):
    """ Verify that the request matches one loaded from the datafile and
    give the corresponding response. If there is no match, try to give a
    descriptive error message in the response.
    """
    parsed_url = urlparse.urlparse(self.path)
    path = parsed_url[2]
    params = cgi.parse_qs(parsed_url[4])

    step = params.get("test_step")
    if step is None or len(step) != 1:
      self.send_response(400)
      self.end_headers()
      print "No test step present."
      return
    step = int(step[0])

    if path == TEST_COMPLETE_PATH:
      self.send_response(200)
      self.end_headers()
      if VerifyTestComplete():
        self.wfile.write('yes')
      else:
        self.wfile.write('no')
    elif path == GETHASH_PATH:
      self.SynthesizeGethashResponse(step)
    elif path == RESET_PATH:
      LoadData()
      self.send_response(200)
      self.end_headers()
      self.wfile.write('done')
    else:
      self.GetCannedResponse(path, params, step, is_post_request)

  def SynthesizeGethashResponse(self, step):
    """ Create a gethash response. This will possibly combine an arbitrary
    number of hash requests from the protocol buffer.
    """
    global hash_data_by_step

    hashes_for_step = hash_data_by_step.get(step, {})
    if not hashes_for_step:
      self.send_response(400)
      self.end_headers()
      print "No response for step %d" % step
      return

    post_data = self.rfile.read(int(self.headers['Content-Length']))
    match = re.match(
        r'(?P<prefixsize>\d+):(?P<totalsize>\d+)\n(?P<prefixes>.+)',
        post_data,
        re.MULTILINE | re.IGNORECASE | re.DOTALL)
    if not match:
      self.send_response(400)
      self.end_headers()
      print "Gethash request is malformed %s" % post_data
      return

    prefixsize = int(match.group('prefixsize'))
    total_length = int(match.group('totalsize'))
    if total_length % prefixsize != 0:
      self.send_response(400)
      self.end_headers()
      print ("Gethash request is malformed, length should be a multiple of the "
             " prefix size%s" % post_data)
      return

    response = ""
    for n in range(total_length/prefixsize):
      prefix = match.group('prefixes')[n*prefixsize:n*prefixsize + prefixsize]
      hash_data = hashes_for_step.get(prefix)
      if hash_data is not None:
        # Reply with the correct response
        response += hash_data[0]
        # Remember that this hash has now been requested.
        hashes_for_step[prefix] = (hash_data[0], hash_data[1], hash_data[2] + 1)

    if not response:
      self.send_response(204)
      self.end_headers()
      return

    # Need to perform MACing before sending response out.
    self.send_response(200)
    self.end_headers()
    self.wfile.write(self.MACResponse(response, False))

  def GetCannedResponse(self, path, params, step, is_post_request):
    """ Given the parameters of a request, see if a matching response is
    found. If one is found, respond with it; otherwise respond with a 400.
    """
    responses_for_step = response_data_by_step.get(step)
    if not responses_for_step:
      self.send_response(400)
      self.end_headers()
      print "No responses for step %d" % step
      return

    # Delete unnecessary params
    del params["test_step"]
    if "client" in params:
      del params["client"]
    if "appver" in params:
      del params["appver"]
    if "key" in params:
      del params["key"]

    param_key = self.MakeParamKey(
        params,
        is_post_request and
        self.rfile.read(int(self.headers['Content-Length'])) or
        None)

    (expected_path, expected_params, server_response) = responses_for_step[0]
    if expected_path != path or param_key != expected_params:
      self.send_response(400)
      self.end_headers()
      print "Expected request with path %s and params %s." % (expected_path,
                                                              expected_params)
      print "Actual request path %s and params %s" % (path, param_key)
      return

    # Remove request that was just made
    responses_for_step.pop(0)

    # If the next request is not needed for this test run, remove it now.
    # We do this after processing instead of before for cases where the
    # data we are removing is the last request in a step.
    if responses_for_step:
      (expected_path, _, _) = responses_for_step[0]
      if expected_path == DATABASE_VALIDATION_PATH and not validate_database:
        responses_for_step.pop(0)

    if path == DOWNLOADS_PATH:
      # Need to have the redirects point to the current port.
      server_response = re.sub(r'localhost:\d+',
                               '%s:%d' % (self.server.server_address[0],
                                          self.server.server_port),
                               server_response)
      # Remove the current MAC, because it's going to be wrong now.
      if server_response.startswith('m:'):
        server_response = server_response[server_response.find('\n')+1:]
      # Add a new correct MAC.
      server_response = self.MACResponse(server_response, True)

    self.send_response(200)
    self.end_headers()
    self.wfile.write(server_response)

  def do_GET(self):
    self.VerifyRequest(False)

  def do_POST(self):
    self.VerifyRequest(True)


def SetupServer(opt_datafile_location,
                host,
                port,
                opt_enforce_caching,
                opt_validate_database):
  """Sets up the safebrowsing test server.

  Arguments:
    opt_datafile_location: The file to load testing data from.
    host: Host that the server binds to.
    port: Port that the server runs on.
    opt_enforce_caching: Whether to require the client to implement caching.
    opt_validate_database: Whether to require the client to make database
         verification requests.

  Returns:
    An HTTPServer object which the caller should call serve_forever() on.
  """
  global datafile_location
  datafile_location = opt_datafile_location
  LoadData()
  # TODO(gcasto):  Look into extending HTTPServer to remove global variables.
  global enforce_caching
  global validate_database
  enforce_caching = opt_enforce_caching
  validate_database = opt_validate_database
  return BaseHTTPServer.HTTPServer((host, port), RequestHandler)

if __name__ == '__main__':
  parser = OptionParser()
  parser.add_option("--datafile", dest="datafile_location",
                    default=DEFAULT_DATAFILE_LOCATION,
                    help="Location to load testing data from.")
  parser.add_option("--host", dest="host",
                    default='localhost', help="Host the server should bind to.")
  parser.add_option("--port", dest="port", type="int",
                    default=DEFAULT_PORT, help="Port to run the server on.")
  parser.add_option("--enforce_caching", dest="enforce_caching",
                    action="store_true", default=False,
                    help="Whether to require that the client"
                    "has implemented caching or not.")
  parser.add_option("--ignore_database_validation", dest="validate_database",
                    action="store_false", default=True,
                    help="Whether to requires that the client makes verify "
                    "database requests or not.")
  parser.add_option("--server_timeout_sec", dest="server_timeout_sec",
                    type="int", default=600,
                    help="How long to let the server run before shutting it "
                    "down. If <=0, the server will never be down")
  (options, _) = parser.parse_args()

  server = SetupServer(options.datafile_location,
                       options.host,
                       options.port,
                       options.enforce_caching,
                       options.validate_database)

  if options.server_timeout_sec > 0:
    tm = Timer(options.server_timeout_sec, EndServer)
    tm.start()

  try:
    server.serve_forever()
  except KeyboardInterrupt:
    pass
  server.server_close()
  print "Server stopped."