summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSamuel Williams <samuel.williams@oriontransfer.co.nz>2023-03-15 16:37:21 +1300
committerGitHub <noreply@github.com>2023-03-15 16:37:21 +1300
commitd25feddcbe634d95ec693bfbd710167a11c74069 (patch)
tree13a5a95e37d553586187a62a2ae8cfeae0a9026d
parent95d2f64a84a6010d5b1a5179a647b42cb53356d1 (diff)
downloadrack-d25feddcbe634d95ec693bfbd710167a11c74069.tar.gz
Fix handling of cached values in `Rack::Request`. (#2054)
* Per-class cache keys for cached query/body parameters. * Use the query parser class as the default cache key.
-rw-r--r--lib/rack/request.rb154
-rw-r--r--test/spec_request.rb66
2 files changed, 171 insertions, 49 deletions
diff --git a/lib/rack/request.rb b/lib/rack/request.rb
index e6969645..86e133cd 100644
--- a/lib/rack/request.rb
+++ b/lib/rack/request.rb
@@ -480,25 +480,114 @@ module Rack
PARSEABLE_DATA_MEDIA_TYPES.include?(media_type)
end
- # Returns the data received in the query string.
- def GET
- if get_header(RACK_REQUEST_QUERY_STRING) == query_string
- if query_hash = get_header(RACK_REQUEST_QUERY_HASH)
- return query_hash
+ # Given a current input value, and a validity key, check if the cache
+ # is valid, and if so, return the cached value. If not, yield the
+ # current value to the block, and set the cache to the result.
+ #
+ # This method does not use cache_key, so it is shared between all
+ # instance of Rack::Request and it's sub-classes.
+ private def cache_for(key, validity_key, current_value)
+ # Get the current value of the validity key and compare it with the input value:
+ if get_header(validity_key).equal?(current_value)
+ # If the values are the same, then the cache is valid, so return the cached value.
+ if has_header?(key)
+ value = get_header(key)
+ # If the cached value is an exception, then re-raise it.
+ if value.is_a?(Exception)
+ raise value.class, value.message, cause: value.cause
+ else
+ # Otherwise, return the cached value.
+ return value
+ end
end
end
- set_header(RACK_REQUEST_QUERY_HASH, expand_params(query_param_list))
+ # If the cache is not valid, then yield the current value to the block:
+ value = yield(current_value)
+
+ # Set the validity key to the current value so that we can detect changes:
+ set_header(validity_key, current_value)
+
+ # Set the cache to the result of the block, and return the result:
+ set_header(key, value)
+ rescue => error
+ # If an exception is raised, then set the cache to the exception, and re-raise it:
+ set_header(validity_key, current_value)
+ set_header(key, error)
+ raise
+ end
+
+ # This cache key is used by cached values generated by class_cache_for,
+ # specfically GET and POST. This is to ensure that the cache is not
+ # shared between instances of different classes which have different
+ # behaviour. This includes sub-classes that override query_parser or
+ # expand_params.
+ def cache_key
+ query_parser.class
+ end
+
+ # Given a current input value, and a validity key, check if the cache
+ # is valid, and if so, return the cached value. If not, yield the
+ # current value to the block, and set the cache to the result.
+ #
+ # This method uses cache_key to ensure that the cache is not shared
+ # between instances of different classes which have different
+ # behaviour of the cached operations.
+ private def class_cache_for(key, validity_key, current_value)
+ # The cache is organised in the env as:
+ # env[key][cache_key] = value
+ # and is valid as long as env[validity_key].equal?(current_value)
+
+ cache_key = self.cache_key
+
+ # Get the current value of the validity key and compare it with the input value:
+ if get_header(validity_key).equal?(current_value)
+ # Lookup the cache for the current cache key:
+ if cache = get_header(key)
+ if cache.key?(cache_key)
+ # If the cache is valid, then return the cached value.
+ value = cache[cache_key]
+ if value.is_a?(Exception)
+ # If the cached value is an exception, then re-raise it.
+ raise value.class, value.message, cause: value.cause
+ else
+ # Otherwise, return the cached value.
+ return value
+ end
+ end
+ end
+ end
+
+ # If the cache was not defined for this cache key, then create a new cache:
+ unless cache
+ set_header(key, cache = {})
+ end
+
+ begin
+ # Yield the current value to the block to generate an updated value:
+ value = yield(current_value)
+
+ # Only set this after generating the value, so that if an error or other cache depending on the same key, it will be invalidated correctly:
+ set_header(validity_key, current_value)
+ return cache[cache_key] = value
+ rescue => error
+ set_header(validity_key, current_value)
+ cache[cache_key] = error
+ raise
+ end
+ end
+
+ # Returns the data received in the query string.
+ def GET
+ class_cache_for(RACK_REQUEST_QUERY_HASH, RACK_REQUEST_QUERY_STRING, query_string) do
+ expand_params(query_param_list)
+ end
end
def query_param_list
- if get_header(RACK_REQUEST_QUERY_STRING) == query_string
- get_header(RACK_REQUEST_QUERY_PAIRS)
- else
- query_pairs = split_query(query_string, '&')
- set_header RACK_REQUEST_QUERY_STRING, query_string
- set_header RACK_REQUEST_QUERY_HASH, nil
- set_header(RACK_REQUEST_QUERY_PAIRS, query_pairs)
+ cache_for(RACK_REQUEST_QUERY_PAIRS, RACK_REQUEST_QUERY_STRING, query_string) do
+ set_header(RACK_REQUEST_QUERY_HASH, nil)
+ split_query(query_string, '&')
end
end
@@ -507,33 +596,13 @@ module Rack
# This method support both application/x-www-form-urlencoded and
# multipart/form-data.
def POST
- if get_header(RACK_REQUEST_FORM_INPUT).equal?(get_header(RACK_INPUT))
- if form_hash = get_header(RACK_REQUEST_FORM_HASH)
- return form_hash
- end
+ class_cache_for(RACK_REQUEST_FORM_HASH, RACK_REQUEST_FORM_INPUT, get_header(RACK_INPUT)) do
+ expand_params(body_param_list)
end
-
- set_header(RACK_REQUEST_FORM_HASH, expand_params(body_param_list))
end
def body_param_list
- if error = get_header(RACK_REQUEST_FORM_ERROR)
- raise error.class, error.message, cause: error.cause
- end
-
- begin
- rack_input = get_header(RACK_INPUT)
-
- form_pairs = nil
-
- # If the form data has already been memoized from the same
- # input:
- if get_header(RACK_REQUEST_FORM_INPUT).equal?(rack_input)
- if form_pairs = get_header(RACK_REQUEST_FORM_PAIRS)
- return form_pairs
- end
- end
-
+ cache_for(RACK_REQUEST_FORM_PAIRS, RACK_REQUEST_FORM_INPUT, get_header(RACK_INPUT)) do |rack_input|
if rack_input.nil?
form_pairs = []
elsif form_data? || parseable_data?
@@ -544,19 +613,16 @@ module Rack
# form_vars.sub!(/\0\z/, '') # performance replacement:
form_vars.slice!(-1) if form_vars.end_with?("\0")
- set_header RACK_REQUEST_FORM_VARS, form_vars
+ # Removing this line breaks Rail test "test_filters_rack_request_form_vars"!
+ set_header(RACK_REQUEST_FORM_VARS, form_vars)
+
form_pairs = split_query(form_vars, '&')
end
else
form_pairs = []
end
-
- set_header RACK_REQUEST_FORM_INPUT, rack_input
- set_header RACK_REQUEST_FORM_HASH, nil
- set_header(RACK_REQUEST_FORM_PAIRS, form_pairs)
- rescue => error
- set_header(RACK_REQUEST_FORM_ERROR, error)
- raise
+
+ form_pairs
end
end
diff --git a/test/spec_request.rb b/test/spec_request.rb
index 2a3f792a..e525621a 100644
--- a/test/spec_request.rb
+++ b/test/spec_request.rb
@@ -1554,12 +1554,19 @@ EOF
rack_input.write(input)
rack_input.rewind
- req = make_request Rack::MockRequest.env_for("/",
- "rack.request.form_hash" => { 'foo' => 'bar' },
- "rack.request.form_input" => rack_input,
- :input => rack_input)
+ form_hash_cache = {}
+
+ req = make_request Rack::MockRequest.env_for(
+ "/",
+ "rack.request.form_hash" => form_hash_cache,
+ "rack.request.form_input" => rack_input,
+ :input => rack_input
+ )
- req.POST.must_equal req.env['rack.request.form_hash']
+ form_hash = {'foo' => 'bar'}.freeze
+ form_hash_cache[req.cache_key] = form_hash
+
+ req.POST.must_equal form_hash
end
it "conform to the Rack spec" do
@@ -1957,4 +1964,53 @@ EOF
DelegateRequest.new super(env)
end
end
+
+ class UpperRequest < Rack::Request
+ def expand_params(parameters)
+ parameters.map do |(key, value)|
+ [key.upcase, value]
+ end.to_h
+ end
+
+ # If this is not specified, the behaviour becomes order dependent.
+ def cache_key
+ :my_request
+ end
+ end
+
+ it "correctly expands parameters" do
+ env = {"QUERY_STRING" => "foo=bar"}
+
+ request = Rack::Request.new(env)
+ request.query_param_list.must_equal [["foo", "bar"]]
+ request.GET.must_equal "foo" => "bar"
+
+ upper_request = UpperRequest.new(env)
+ upper_request.query_param_list.must_equal [["foo", "bar"]]
+ upper_request.GET.must_equal "FOO" => "bar"
+
+ env['QUERY_STRING'] = "foo=bar&bar=baz"
+
+ request.GET.must_equal "foo" => "bar", "bar" => "baz"
+ upper_request.GET.must_equal "FOO" => "bar", "BAR" => "baz"
+ end
+
+ class BrokenRequest < Rack::Request
+ def expand_params(parameters)
+ raise "boom"
+ end
+ end
+
+ it "raises an error if expand_params raises an error" do
+ env = {"QUERY_STRING" => "foo=bar"}
+
+ request = Rack::Request.new(env)
+ request.GET.must_equal "foo" => "bar"
+
+ broken_request = BrokenRequest.new(env)
+ lambda { broken_request.GET }.must_raise RuntimeError
+
+ # Subsequnt calls also raise an error:
+ lambda { broken_request.GET }.must_raise RuntimeError
+ end
end