summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTim Burke <tim.burke@gmail.com>2022-12-06 11:21:37 -0800
committerTim Burke <tim.burke@gmail.com>2023-01-06 10:13:01 -0800
commita86270bed72cdbdf2f42eec562d9fae2d273fa64 (patch)
treecb404a85d716b3dc46684e26d25c7dd267479057
parent637b261cc60dbd6eca10feb399da92341268d550 (diff)
downloadswift-a86270bed72cdbdf2f42eec562d9fae2d273fa64.tar.gz
Inline parse_request from cpython
Applied deltas: - Fix http.client references - Inline HTTPStatus codes - Address request line splitting (https://bugs.python.org/issue33973) - Special-case py2 header-parsing - Address multiple leading slashes in request path (https://github.com/python/cpython/issues/99220) Closes-Bug: #1999278 Change-Id: Iae28097668213aa0734837ff21aef83251167d19 (cherry picked from commit 884f5538f8fb187b6ff18316249f1bd4b97b0952)
-rw-r--r--swift/common/http_protocol.py142
-rw-r--r--test/unit/common/test_http_protocol.py148
2 files changed, 212 insertions, 78 deletions
diff --git a/swift/common/http_protocol.py b/swift/common/http_protocol.py
index 59d7767de..4c416155d 100644
--- a/swift/common/http_protocol.py
+++ b/swift/common/http_protocol.py
@@ -16,8 +16,11 @@
from eventlet import wsgi, websocket
import six
-from swift.common.swob import wsgi_quote, wsgi_unquote, \
- wsgi_quote_plus, wsgi_unquote_plus, wsgi_to_bytes, bytes_to_wsgi
+
+if six.PY2:
+ from eventlet.green import httplib as http_client
+else:
+ from eventlet.green.http import client as http_client
class SwiftHttpProtocol(wsgi.HttpProtocol):
@@ -62,44 +65,115 @@ class SwiftHttpProtocol(wsgi.HttpProtocol):
return ''
def parse_request(self):
- # Need to track the bytes-on-the-wire for S3 signatures -- eventlet
- # would do it for us, but since we rewrite the path on py3, we need to
- # fix it ourselves later.
- self.__raw_path_info = None
+ """Parse a request (inlined from cpython@7e293984).
+
+ The request should be stored in self.raw_requestline; the results
+ are in self.command, self.path, self.request_version and
+ self.headers.
+
+ Return True for success, False for failure; on failure, any relevant
+ error response has already been sent back.
+ """
+ self.command = None # set in case of error on the first line
+ self.request_version = version = self.default_request_version
+ self.close_connection = True
+ requestline = self.raw_requestline
if not six.PY2:
- # request lines *should* be ascii per the RFC, but historically
- # we've allowed (and even have func tests that use) arbitrary
- # bytes. This breaks on py3 (see https://bugs.python.org/issue33973
- # ) but the work-around is simple: munge the request line to be
- # properly quoted.
- if self.raw_requestline.count(b' ') >= 2:
- parts = self.raw_requestline.split(b' ', 2)
- path, q, query = parts[1].partition(b'?')
- self.__raw_path_info = path
- # unquote first, so we don't over-quote something
- # that was *correctly* quoted
- path = wsgi_to_bytes(wsgi_quote(wsgi_unquote(
- bytes_to_wsgi(path))))
- query = b'&'.join(
- sep.join([
- wsgi_to_bytes(wsgi_quote_plus(wsgi_unquote_plus(
- bytes_to_wsgi(key)))),
- wsgi_to_bytes(wsgi_quote_plus(wsgi_unquote_plus(
- bytes_to_wsgi(val))))
- ])
- for part in query.split(b'&')
- for key, sep, val in (part.partition(b'='), ))
- parts[1] = path + q + query
- self.raw_requestline = b' '.join(parts)
- # else, mangled protocol, most likely; let base class deal with it
- return wsgi.HttpProtocol.parse_request(self)
+ requestline = requestline.decode('iso-8859-1')
+ requestline = requestline.rstrip('\r\n')
+ self.requestline = requestline
+ # Split off \x20 explicitly (see https://bugs.python.org/issue33973)
+ words = requestline.split(' ')
+ if len(words) == 0:
+ return False
+
+ if len(words) >= 3: # Enough to determine protocol version
+ version = words[-1]
+ try:
+ if not version.startswith('HTTP/'):
+ raise ValueError
+ base_version_number = version.split('/', 1)[1]
+ version_number = base_version_number.split(".")
+ # RFC 2145 section 3.1 says there can be only one "." and
+ # - major and minor numbers MUST be treated as
+ # separate integers;
+ # - HTTP/2.4 is a lower version than HTTP/2.13, which in
+ # turn is lower than HTTP/12.3;
+ # - Leading zeros MUST be ignored by recipients.
+ if len(version_number) != 2:
+ raise ValueError
+ version_number = int(version_number[0]), int(version_number[1])
+ except (ValueError, IndexError):
+ self.send_error(
+ 400,
+ "Bad request version (%r)" % version)
+ return False
+ if version_number >= (1, 1) and \
+ self.protocol_version >= "HTTP/1.1":
+ self.close_connection = False
+ if version_number >= (2, 0):
+ self.send_error(
+ 505,
+ "Invalid HTTP version (%s)" % base_version_number)
+ return False
+ self.request_version = version
+
+ if not 2 <= len(words) <= 3:
+ self.send_error(
+ 400,
+ "Bad request syntax (%r)" % requestline)
+ return False
+ command, path = words[:2]
+ if len(words) == 2:
+ self.close_connection = True
+ if command != 'GET':
+ self.send_error(
+ 400,
+ "Bad HTTP/0.9 request type (%r)" % command)
+ return False
+ self.command, self.path = command, path
+
+ # Examine the headers and look for a Connection directive.
+ if six.PY2:
+ self.headers = self.MessageClass(self.rfile, 0)
+ else:
+ try:
+ self.headers = http_client.parse_headers(
+ self.rfile,
+ _class=self.MessageClass)
+ except http_client.LineTooLong as err:
+ self.send_error(
+ 431,
+ "Line too long",
+ str(err))
+ return False
+ except http_client.HTTPException as err:
+ self.send_error(
+ 431,
+ "Too many headers",
+ str(err)
+ )
+ return False
+
+ conntype = self.headers.get('Connection', "")
+ if conntype.lower() == 'close':
+ self.close_connection = True
+ elif (conntype.lower() == 'keep-alive' and
+ self.protocol_version >= "HTTP/1.1"):
+ self.close_connection = False
+ # Examine the headers and look for an Expect directive
+ expect = self.headers.get('Expect', "")
+ if (expect.lower() == "100-continue" and
+ self.protocol_version >= "HTTP/1.1" and
+ self.request_version >= "HTTP/1.1"):
+ if not self.handle_expect_100():
+ return False
+ return True
if not six.PY2:
def get_environ(self, *args, **kwargs):
environ = wsgi.HttpProtocol.get_environ(self, *args, **kwargs)
- environ['RAW_PATH_INFO'] = bytes_to_wsgi(
- self.__raw_path_info)
header_payload = self.headers.get_payload()
if isinstance(header_payload, list) and len(header_payload) == 1:
header_payload = header_payload[0].get_payload()
diff --git a/test/unit/common/test_http_protocol.py b/test/unit/common/test_http_protocol.py
index 24e5225b2..b9962c7ff 100644
--- a/test/unit/common/test_http_protocol.py
+++ b/test/unit/common/test_http_protocol.py
@@ -15,6 +15,7 @@
from argparse import Namespace
from io import BytesIO
+import json
import mock
import types
import unittest
@@ -81,36 +82,14 @@ class TestSwiftHttpProtocol(unittest.TestCase):
], proto_obj.send_error.mock_calls)
self.assertEqual(('a', '123'), proto_obj.client_address)
- def test_request_line_cleanup(self):
- def do_test(line_from_socket, expected_line=None):
- if expected_line is None:
- expected_line = line_from_socket
-
- proto_obj = self._proto_obj()
- proto_obj.raw_requestline = line_from_socket
- with mock.patch('swift.common.http_protocol.wsgi.HttpProtocol') \
- as mock_super:
- proto_obj.parse_request()
-
- self.assertEqual([mock.call.parse_request(proto_obj)],
- mock_super.mock_calls)
- self.assertEqual(proto_obj.raw_requestline, expected_line)
-
- do_test(b'GET / HTTP/1.1')
- do_test(b'GET /%FF HTTP/1.1')
-
- if not six.PY2:
- do_test(b'GET /\xff HTTP/1.1', b'GET /%FF HTTP/1.1')
- do_test(b'PUT /Here%20Is%20A%20SnowMan:\xe2\x98\x83 HTTP/1.0',
- b'PUT /Here%20Is%20A%20SnowMan%3A%E2%98%83 HTTP/1.0')
- do_test(
- b'POST /?and%20it=fixes+params&'
- b'PALMTREE=\xf0%9f\x8c%b4 HTTP/1.1',
- b'POST /?and+it=fixes+params&PALMTREE=%F0%9F%8C%B4 HTTP/1.1')
+ def test_bad_request_line(self):
+ proto_obj = self._proto_obj()
+ proto_obj.raw_requestline = b'None //'
+ self.assertEqual(False, proto_obj.parse_request())
class ProtocolTest(unittest.TestCase):
- def _run_bytes_through_protocol(self, bytes_from_client):
+ def _run_bytes_through_protocol(self, bytes_from_client, app=None):
rfile = BytesIO(bytes_from_client)
wfile = BytesIO()
@@ -153,7 +132,7 @@ class ProtocolTest(unittest.TestCase):
with mock.patch.object(wfile, 'close', lambda: None), \
mock.patch.object(rfile, 'close', lambda: None):
eventlet.wsgi.server(
- fake_listen_socket, self.app,
+ fake_listen_socket, app or self.app,
protocol=self.protocol_class,
custom_pool=FakePool(),
log_output=False, # quiet the test run
@@ -170,37 +149,118 @@ class TestSwiftHttpProtocolSomeMore(ProtocolTest):
return [swob.wsgi_to_bytes(env['RAW_PATH_INFO'])]
def test_simple(self):
- bytes_out = self._run_bytes_through_protocol((
+ bytes_out = self._run_bytes_through_protocol(
b"GET /someurl HTTP/1.0\r\n"
b"User-Agent: something or other\r\n"
b"\r\n"
- ))
+ )
lines = [l for l in bytes_out.split(b"\r\n") if l]
self.assertEqual(lines[0], b"HTTP/1.1 200 OK") # sanity check
self.assertEqual(lines[-1], b'/someurl')
def test_quoted(self):
- bytes_out = self._run_bytes_through_protocol((
+ bytes_out = self._run_bytes_through_protocol(
b"GET /some%fFpath%D8%AA HTTP/1.0\r\n"
b"User-Agent: something or other\r\n"
b"\r\n"
- ))
+ )
lines = [l for l in bytes_out.split(b"\r\n") if l]
self.assertEqual(lines[0], b"HTTP/1.1 200 OK") # sanity check
self.assertEqual(lines[-1], b'/some%fFpath%D8%AA')
def test_messy(self):
- bytes_out = self._run_bytes_through_protocol((
+ bytes_out = self._run_bytes_through_protocol(
b"GET /oh\xffboy%what$now%E2%80%bd HTTP/1.0\r\n"
b"User-Agent: something or other\r\n"
b"\r\n"
- ))
+ )
lines = [l for l in bytes_out.split(b"\r\n") if l]
self.assertEqual(lines[-1], b'/oh\xffboy%what$now%E2%80%bd')
+ def test_bad_request(self):
+ bytes_out = self._run_bytes_through_protocol((
+ b"ONLY-METHOD\r\n"
+ b"Server: example.com\r\n"
+ b"\r\n"
+ ))
+ lines = [l for l in bytes_out.split(b"\r\n") if l]
+ self.assertEqual(
+ lines[0], b"HTTP/1.1 400 Bad request syntax ('ONLY-METHOD')")
+ self.assertIn(b"Bad request syntax or unsupported method.", lines[-1])
+
+ def test_leading_slashes(self):
+ bytes_out = self._run_bytes_through_protocol((
+ b"GET ///some-leading-slashes HTTP/1.0\r\n"
+ b"User-Agent: blah blah blah\r\n"
+ b"\r\n"
+ ))
+ lines = [l for l in bytes_out.split(b"\r\n") if l]
+ self.assertEqual(lines[-1], b'///some-leading-slashes')
+
+ def test_request_lines(self):
+ def app(env, start_response):
+ start_response("200 OK", [])
+ if six.PY2:
+ return [json.dumps({
+ 'RAW_PATH_INFO': env['RAW_PATH_INFO'].decode('latin1'),
+ 'QUERY_STRING': (None if 'QUERY_STRING' not in env else
+ env['QUERY_STRING'].decode('latin1')),
+ }).encode('ascii')]
+ return [json.dumps({
+ 'RAW_PATH_INFO': env['RAW_PATH_INFO'],
+ 'QUERY_STRING': env.get('QUERY_STRING'),
+ }).encode('ascii')]
+
+ def do_test(request_line, expected):
+ bytes_out = self._run_bytes_through_protocol(
+ request_line + b'\r\n\r\n',
+ app,
+ )
+ print(bytes_out)
+ resp_body = bytes_out.partition(b'\r\n\r\n')[2]
+ self.assertEqual(json.loads(resp_body), expected)
+
+ do_test(b'GET / HTTP/1.1', {
+ 'RAW_PATH_INFO': u'/',
+ 'QUERY_STRING': None,
+ })
+ do_test(b'GET /%FF HTTP/1.1', {
+ 'RAW_PATH_INFO': u'/%FF',
+ 'QUERY_STRING': None,
+ })
+
+ do_test(b'GET /\xff HTTP/1.1', {
+ 'RAW_PATH_INFO': u'/\xff',
+ 'QUERY_STRING': None,
+ })
+ do_test(b'PUT /Here%20Is%20A%20SnowMan:\xe2\x98\x83 HTTP/1.0', {
+ 'RAW_PATH_INFO': u'/Here%20Is%20A%20SnowMan:\xe2\x98\x83',
+ 'QUERY_STRING': None,
+ })
+ do_test(
+ b'POST /?and%20it=does+nothing+to+params&'
+ b'PALMTREE=\xf0%9f\x8c%b4 HTTP/1.1', {
+ 'RAW_PATH_INFO': u'/',
+ 'QUERY_STRING': (u'and%20it=does+nothing+to+params'
+ u'&PALMTREE=\xf0%9f\x8c%b4'),
+ }
+ )
+ do_test(b'GET // HTTP/1.1', {
+ 'RAW_PATH_INFO': u'//',
+ 'QUERY_STRING': None,
+ })
+ do_test(b'GET //bar HTTP/1.1', {
+ 'RAW_PATH_INFO': u'//bar',
+ 'QUERY_STRING': None,
+ })
+ do_test(b'GET //////baz HTTP/1.1', {
+ 'RAW_PATH_INFO': u'//////baz',
+ 'QUERY_STRING': None,
+ })
+
class TestProxyProtocol(ProtocolTest):
protocol_class = http_protocol.SwiftHttpProxiedProtocol
@@ -222,12 +282,12 @@ class TestProxyProtocol(ProtocolTest):
return [body.encode("utf-8")]
def test_request_with_proxy(self):
- bytes_out = self._run_bytes_through_protocol((
+ bytes_out = self._run_bytes_through_protocol(
b"PROXY TCP4 192.168.0.1 192.168.0.11 56423 4433\r\n"
b"GET /someurl HTTP/1.0\r\n"
b"User-Agent: something or other\r\n"
b"\r\n"
- ))
+ )
lines = [l for l in bytes_out.split(b"\r\n") if l]
self.assertEqual(lines[0], b"HTTP/1.1 200 OK") # sanity check
@@ -238,12 +298,12 @@ class TestProxyProtocol(ProtocolTest):
])
def test_request_with_proxy_https(self):
- bytes_out = self._run_bytes_through_protocol((
+ bytes_out = self._run_bytes_through_protocol(
b"PROXY TCP4 192.168.0.1 192.168.0.11 56423 443\r\n"
b"GET /someurl HTTP/1.0\r\n"
b"User-Agent: something or other\r\n"
b"\r\n"
- ))
+ )
lines = [l for l in bytes_out.split(b"\r\n") if l]
self.assertEqual(lines[0], b"HTTP/1.1 200 OK") # sanity check
@@ -254,7 +314,7 @@ class TestProxyProtocol(ProtocolTest):
])
def test_multiple_requests_with_proxy(self):
- bytes_out = self._run_bytes_through_protocol((
+ bytes_out = self._run_bytes_through_protocol(
b"PROXY TCP4 192.168.0.1 192.168.0.11 56423 443\r\n"
b"GET /someurl HTTP/1.1\r\n"
b"User-Agent: something or other\r\n"
@@ -263,7 +323,7 @@ class TestProxyProtocol(ProtocolTest):
b"User-Agent: something or other\r\n"
b"Connection: close\r\n"
b"\r\n"
- ))
+ )
lines = bytes_out.split(b"\r\n")
self.assertEqual(lines[0], b"HTTP/1.1 200 OK") # sanity check
@@ -277,12 +337,12 @@ class TestProxyProtocol(ProtocolTest):
self.assertEqual(addr_lines, [b"https is on (scheme https)"] * 2)
def test_missing_proxy_line(self):
- bytes_out = self._run_bytes_through_protocol((
+ bytes_out = self._run_bytes_through_protocol(
# whoops, no PROXY line here
b"GET /someurl HTTP/1.0\r\n"
b"User-Agent: something or other\r\n"
b"\r\n"
- ))
+ )
lines = [l for l in bytes_out.split(b"\r\n") if l]
self.assertIn(b"400 Invalid PROXY line", lines[0])
@@ -303,12 +363,12 @@ class TestProxyProtocol(ProtocolTest):
for unknown_line in [b'PROXY UNKNOWN', # mimimal valid unknown
b'PROXY UNKNOWNblahblah', # also valid
b'PROXY UNKNOWN a b c d']:
- bytes_out = self._run_bytes_through_protocol((
+ bytes_out = self._run_bytes_through_protocol(
unknown_line + (b"\r\n"
b"GET /someurl HTTP/1.0\r\n"
b"User-Agent: something or other\r\n"
b"\r\n")
- ))
+ )
lines = [l for l in bytes_out.split(b"\r\n") if l]
self.assertIn(b"200 OK", lines[0])