diff options
author | Samuel Williams <samuel.williams@oriontransfer.co.nz> | 2023-03-15 16:37:21 +1300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-03-15 16:37:21 +1300 |
commit | d25feddcbe634d95ec693bfbd710167a11c74069 (patch) | |
tree | 13a5a95e37d553586187a62a2ae8cfeae0a9026d | |
parent | 95d2f64a84a6010d5b1a5179a647b42cb53356d1 (diff) | |
download | rack-d25feddcbe634d95ec693bfbd710167a11c74069.tar.gz |
Fix handling of cached values in `Rack::Request`. (#2054)
* Per-class cache keys for cached query/body parameters.
* Use the query parser class as the default cache key.
-rw-r--r-- | lib/rack/request.rb | 154 | ||||
-rw-r--r-- | test/spec_request.rb | 66 |
2 files changed, 171 insertions, 49 deletions
diff --git a/lib/rack/request.rb b/lib/rack/request.rb index e6969645..86e133cd 100644 --- a/lib/rack/request.rb +++ b/lib/rack/request.rb @@ -480,25 +480,114 @@ module Rack PARSEABLE_DATA_MEDIA_TYPES.include?(media_type) end - # Returns the data received in the query string. - def GET - if get_header(RACK_REQUEST_QUERY_STRING) == query_string - if query_hash = get_header(RACK_REQUEST_QUERY_HASH) - return query_hash + # Given a current input value, and a validity key, check if the cache + # is valid, and if so, return the cached value. If not, yield the + # current value to the block, and set the cache to the result. + # + # This method does not use cache_key, so it is shared between all + # instance of Rack::Request and it's sub-classes. + private def cache_for(key, validity_key, current_value) + # Get the current value of the validity key and compare it with the input value: + if get_header(validity_key).equal?(current_value) + # If the values are the same, then the cache is valid, so return the cached value. + if has_header?(key) + value = get_header(key) + # If the cached value is an exception, then re-raise it. + if value.is_a?(Exception) + raise value.class, value.message, cause: value.cause + else + # Otherwise, return the cached value. + return value + end end end - set_header(RACK_REQUEST_QUERY_HASH, expand_params(query_param_list)) + # If the cache is not valid, then yield the current value to the block: + value = yield(current_value) + + # Set the validity key to the current value so that we can detect changes: + set_header(validity_key, current_value) + + # Set the cache to the result of the block, and return the result: + set_header(key, value) + rescue => error + # If an exception is raised, then set the cache to the exception, and re-raise it: + set_header(validity_key, current_value) + set_header(key, error) + raise + end + + # This cache key is used by cached values generated by class_cache_for, + # specfically GET and POST. This is to ensure that the cache is not + # shared between instances of different classes which have different + # behaviour. This includes sub-classes that override query_parser or + # expand_params. + def cache_key + query_parser.class + end + + # Given a current input value, and a validity key, check if the cache + # is valid, and if so, return the cached value. If not, yield the + # current value to the block, and set the cache to the result. + # + # This method uses cache_key to ensure that the cache is not shared + # between instances of different classes which have different + # behaviour of the cached operations. + private def class_cache_for(key, validity_key, current_value) + # The cache is organised in the env as: + # env[key][cache_key] = value + # and is valid as long as env[validity_key].equal?(current_value) + + cache_key = self.cache_key + + # Get the current value of the validity key and compare it with the input value: + if get_header(validity_key).equal?(current_value) + # Lookup the cache for the current cache key: + if cache = get_header(key) + if cache.key?(cache_key) + # If the cache is valid, then return the cached value. + value = cache[cache_key] + if value.is_a?(Exception) + # If the cached value is an exception, then re-raise it. + raise value.class, value.message, cause: value.cause + else + # Otherwise, return the cached value. + return value + end + end + end + end + + # If the cache was not defined for this cache key, then create a new cache: + unless cache + set_header(key, cache = {}) + end + + begin + # Yield the current value to the block to generate an updated value: + value = yield(current_value) + + # Only set this after generating the value, so that if an error or other cache depending on the same key, it will be invalidated correctly: + set_header(validity_key, current_value) + return cache[cache_key] = value + rescue => error + set_header(validity_key, current_value) + cache[cache_key] = error + raise + end + end + + # Returns the data received in the query string. + def GET + class_cache_for(RACK_REQUEST_QUERY_HASH, RACK_REQUEST_QUERY_STRING, query_string) do + expand_params(query_param_list) + end end def query_param_list - if get_header(RACK_REQUEST_QUERY_STRING) == query_string - get_header(RACK_REQUEST_QUERY_PAIRS) - else - query_pairs = split_query(query_string, '&') - set_header RACK_REQUEST_QUERY_STRING, query_string - set_header RACK_REQUEST_QUERY_HASH, nil - set_header(RACK_REQUEST_QUERY_PAIRS, query_pairs) + cache_for(RACK_REQUEST_QUERY_PAIRS, RACK_REQUEST_QUERY_STRING, query_string) do + set_header(RACK_REQUEST_QUERY_HASH, nil) + split_query(query_string, '&') end end @@ -507,33 +596,13 @@ module Rack # This method support both application/x-www-form-urlencoded and # multipart/form-data. def POST - if get_header(RACK_REQUEST_FORM_INPUT).equal?(get_header(RACK_INPUT)) - if form_hash = get_header(RACK_REQUEST_FORM_HASH) - return form_hash - end + class_cache_for(RACK_REQUEST_FORM_HASH, RACK_REQUEST_FORM_INPUT, get_header(RACK_INPUT)) do + expand_params(body_param_list) end - - set_header(RACK_REQUEST_FORM_HASH, expand_params(body_param_list)) end def body_param_list - if error = get_header(RACK_REQUEST_FORM_ERROR) - raise error.class, error.message, cause: error.cause - end - - begin - rack_input = get_header(RACK_INPUT) - - form_pairs = nil - - # If the form data has already been memoized from the same - # input: - if get_header(RACK_REQUEST_FORM_INPUT).equal?(rack_input) - if form_pairs = get_header(RACK_REQUEST_FORM_PAIRS) - return form_pairs - end - end - + cache_for(RACK_REQUEST_FORM_PAIRS, RACK_REQUEST_FORM_INPUT, get_header(RACK_INPUT)) do |rack_input| if rack_input.nil? form_pairs = [] elsif form_data? || parseable_data? @@ -544,19 +613,16 @@ module Rack # form_vars.sub!(/\0\z/, '') # performance replacement: form_vars.slice!(-1) if form_vars.end_with?("\0") - set_header RACK_REQUEST_FORM_VARS, form_vars + # Removing this line breaks Rail test "test_filters_rack_request_form_vars"! + set_header(RACK_REQUEST_FORM_VARS, form_vars) + form_pairs = split_query(form_vars, '&') end else form_pairs = [] end - - set_header RACK_REQUEST_FORM_INPUT, rack_input - set_header RACK_REQUEST_FORM_HASH, nil - set_header(RACK_REQUEST_FORM_PAIRS, form_pairs) - rescue => error - set_header(RACK_REQUEST_FORM_ERROR, error) - raise + + form_pairs end end diff --git a/test/spec_request.rb b/test/spec_request.rb index 2a3f792a..e525621a 100644 --- a/test/spec_request.rb +++ b/test/spec_request.rb @@ -1554,12 +1554,19 @@ EOF rack_input.write(input) rack_input.rewind - req = make_request Rack::MockRequest.env_for("/", - "rack.request.form_hash" => { 'foo' => 'bar' }, - "rack.request.form_input" => rack_input, - :input => rack_input) + form_hash_cache = {} + + req = make_request Rack::MockRequest.env_for( + "/", + "rack.request.form_hash" => form_hash_cache, + "rack.request.form_input" => rack_input, + :input => rack_input + ) - req.POST.must_equal req.env['rack.request.form_hash'] + form_hash = {'foo' => 'bar'}.freeze + form_hash_cache[req.cache_key] = form_hash + + req.POST.must_equal form_hash end it "conform to the Rack spec" do @@ -1957,4 +1964,53 @@ EOF DelegateRequest.new super(env) end end + + class UpperRequest < Rack::Request + def expand_params(parameters) + parameters.map do |(key, value)| + [key.upcase, value] + end.to_h + end + + # If this is not specified, the behaviour becomes order dependent. + def cache_key + :my_request + end + end + + it "correctly expands parameters" do + env = {"QUERY_STRING" => "foo=bar"} + + request = Rack::Request.new(env) + request.query_param_list.must_equal [["foo", "bar"]] + request.GET.must_equal "foo" => "bar" + + upper_request = UpperRequest.new(env) + upper_request.query_param_list.must_equal [["foo", "bar"]] + upper_request.GET.must_equal "FOO" => "bar" + + env['QUERY_STRING'] = "foo=bar&bar=baz" + + request.GET.must_equal "foo" => "bar", "bar" => "baz" + upper_request.GET.must_equal "FOO" => "bar", "BAR" => "baz" + end + + class BrokenRequest < Rack::Request + def expand_params(parameters) + raise "boom" + end + end + + it "raises an error if expand_params raises an error" do + env = {"QUERY_STRING" => "foo=bar"} + + request = Rack::Request.new(env) + request.GET.must_equal "foo" => "bar" + + broken_request = BrokenRequest.new(env) + lambda { broken_request.GET }.must_raise RuntimeError + + # Subsequnt calls also raise an error: + lambda { broken_request.GET }.must_raise RuntimeError + end end |