diff options
author | Robert Bradshaw <robertwb@gmail.com> | 2016-12-06 03:06:27 -0800 |
---|---|---|
committer | Robert Bradshaw <robertwb@gmail.com> | 2016-12-06 03:06:27 -0800 |
commit | a531a31c64f76cf5031f04b79b64f2cae66a116b (patch) | |
tree | 2bfdf9604503dad5195d8e77485b2e53da0ac10c | |
parent | 2a3603379307c489e743854caacf81b20124d854 (diff) | |
download | cython-a531a31c64f76cf5031f04b79b64f2cae66a116b.tar.gz |
Walk cpp class hierarchy for class template deduction.
-rw-r--r-- | Cython/Compiler/PyrexTypes.py | 28 | ||||
-rw-r--r-- | tests/run/cpp_template_functions.pyx | 9 | ||||
-rw-r--r-- | tests/run/cpp_template_functions_helper.h | 9 |
3 files changed, 33 insertions, 13 deletions
diff --git a/Cython/Compiler/PyrexTypes.py b/Cython/Compiler/PyrexTypes.py index 29302110f..95b08435b 100644 --- a/Cython/Compiler/PyrexTypes.py +++ b/Cython/Compiler/PyrexTypes.py @@ -3498,18 +3498,24 @@ class CppClassType(CType): def deduce_template_params(self, actual): if self == actual: return {} - elif not hasattr(actual, 'template_type'): - # Untemplated type? - return None - # TODO(robertwb): Actual type equality. - elif (self.template_type or self).empty_declaration_code() == actual.template_type.empty_declaration_code(): - return reduce( - merge_template_deductions, - [formal_param.deduce_template_params(actual_param) - for (formal_param, actual_param) in zip(self.templates, actual.templates)], - {}) + elif actual.is_cpp_class: + self_template_type = self.template_type or self + def all_bases(cls): + yield cls + for parent in cls.base_classes: + for base in all_bases(parent): + yield base + for actual_base in all_bases(actual): + if (actual_base.template_type + and self_template_type.empty_declaration_code() + == actual_base.template_type.empty_declaration_code()): + return reduce( + merge_template_deductions, + [formal_param.deduce_template_params(actual_param) + for (formal_param, actual_param) in zip(self.templates, actual_base.templates)], + {}) else: - return None + return {} def declaration_code(self, entity_code, for_display = 0, dll_linkage = None, pyrex = 0, diff --git a/tests/run/cpp_template_functions.pyx b/tests/run/cpp_template_functions.pyx index 3ade34615..d66b3e013 100644 --- a/tests/run/cpp_template_functions.pyx +++ b/tests/run/cpp_template_functions.pyx @@ -9,9 +9,12 @@ cdef extern from "cpp_template_functions_helper.h": cdef cppclass A[T]: pair[T, U] method[U](T, U) U part_method[U](pair[T, U]) + U part_method_ref[U](pair[T, U]&) cdef T nested_deduction[T](const T*) pair[T, U] pair_arg[T, U](pair[T, U] a) cdef T* pointer_param[T](T*) + cdef cppclass double_pair(pair[double, double]): + double_pair(double, double) def test_no_arg(): """ @@ -48,13 +51,15 @@ def test_method(int x, int y): def test_part_method(int x, int y): """ >>> test_part_method(5, 10) - (10.0, 10) + (10.0, 10, 10.0) """ cdef A[int] a_int cdef pair[int, double] p_int = (x, y) cdef A[double] a_double cdef pair[double, int] p_double = (x, y) - return a_int.part_method(p_int), a_double.part_method(p_double) + return (a_int.part_method(p_int), + a_double.part_method(p_double), + a_double.part_method_ref(double_pair(x, y))) def test_simple_deduction(int x, double y): """ diff --git a/tests/run/cpp_template_functions_helper.h b/tests/run/cpp_template_functions_helper.h index e93e1ebfe..4fd9dbb36 100644 --- a/tests/run/cpp_template_functions_helper.h +++ b/tests/run/cpp_template_functions_helper.h @@ -24,6 +24,10 @@ class A { U part_method(std::pair<T, U> p) { return p.second; } + template <typename U> + U part_method_ref(const std::pair<T, U>& p) { + return p.second; + } }; template <typename T> @@ -40,3 +44,8 @@ template <typename T> T* pointer_param(T* param) { return param; } + +class double_pair : public std::pair<double, double> { + public: + double_pair(double x, double y) : std::pair<double, double>(x, y) { }; +}; |