summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRobert Bradshaw <robertwb@gmail.com>2016-12-06 03:06:27 -0800
committerRobert Bradshaw <robertwb@gmail.com>2016-12-06 03:06:27 -0800
commita531a31c64f76cf5031f04b79b64f2cae66a116b (patch)
tree2bfdf9604503dad5195d8e77485b2e53da0ac10c
parent2a3603379307c489e743854caacf81b20124d854 (diff)
downloadcython-a531a31c64f76cf5031f04b79b64f2cae66a116b.tar.gz
Walk cpp class hierarchy for class template deduction.
-rw-r--r--Cython/Compiler/PyrexTypes.py28
-rw-r--r--tests/run/cpp_template_functions.pyx9
-rw-r--r--tests/run/cpp_template_functions_helper.h9
3 files changed, 33 insertions, 13 deletions
diff --git a/Cython/Compiler/PyrexTypes.py b/Cython/Compiler/PyrexTypes.py
index 29302110f..95b08435b 100644
--- a/Cython/Compiler/PyrexTypes.py
+++ b/Cython/Compiler/PyrexTypes.py
@@ -3498,18 +3498,24 @@ class CppClassType(CType):
def deduce_template_params(self, actual):
if self == actual:
return {}
- elif not hasattr(actual, 'template_type'):
- # Untemplated type?
- return None
- # TODO(robertwb): Actual type equality.
- elif (self.template_type or self).empty_declaration_code() == actual.template_type.empty_declaration_code():
- return reduce(
- merge_template_deductions,
- [formal_param.deduce_template_params(actual_param)
- for (formal_param, actual_param) in zip(self.templates, actual.templates)],
- {})
+ elif actual.is_cpp_class:
+ self_template_type = self.template_type or self
+ def all_bases(cls):
+ yield cls
+ for parent in cls.base_classes:
+ for base in all_bases(parent):
+ yield base
+ for actual_base in all_bases(actual):
+ if (actual_base.template_type
+ and self_template_type.empty_declaration_code()
+ == actual_base.template_type.empty_declaration_code()):
+ return reduce(
+ merge_template_deductions,
+ [formal_param.deduce_template_params(actual_param)
+ for (formal_param, actual_param) in zip(self.templates, actual_base.templates)],
+ {})
else:
- return None
+ return {}
def declaration_code(self, entity_code,
for_display = 0, dll_linkage = None, pyrex = 0,
diff --git a/tests/run/cpp_template_functions.pyx b/tests/run/cpp_template_functions.pyx
index 3ade34615..d66b3e013 100644
--- a/tests/run/cpp_template_functions.pyx
+++ b/tests/run/cpp_template_functions.pyx
@@ -9,9 +9,12 @@ cdef extern from "cpp_template_functions_helper.h":
cdef cppclass A[T]:
pair[T, U] method[U](T, U)
U part_method[U](pair[T, U])
+ U part_method_ref[U](pair[T, U]&)
cdef T nested_deduction[T](const T*)
pair[T, U] pair_arg[T, U](pair[T, U] a)
cdef T* pointer_param[T](T*)
+ cdef cppclass double_pair(pair[double, double]):
+ double_pair(double, double)
def test_no_arg():
"""
@@ -48,13 +51,15 @@ def test_method(int x, int y):
def test_part_method(int x, int y):
"""
>>> test_part_method(5, 10)
- (10.0, 10)
+ (10.0, 10, 10.0)
"""
cdef A[int] a_int
cdef pair[int, double] p_int = (x, y)
cdef A[double] a_double
cdef pair[double, int] p_double = (x, y)
- return a_int.part_method(p_int), a_double.part_method(p_double)
+ return (a_int.part_method(p_int),
+ a_double.part_method(p_double),
+ a_double.part_method_ref(double_pair(x, y)))
def test_simple_deduction(int x, double y):
"""
diff --git a/tests/run/cpp_template_functions_helper.h b/tests/run/cpp_template_functions_helper.h
index e93e1ebfe..4fd9dbb36 100644
--- a/tests/run/cpp_template_functions_helper.h
+++ b/tests/run/cpp_template_functions_helper.h
@@ -24,6 +24,10 @@ class A {
U part_method(std::pair<T, U> p) {
return p.second;
}
+ template <typename U>
+ U part_method_ref(const std::pair<T, U>& p) {
+ return p.second;
+ }
};
template <typename T>
@@ -40,3 +44,8 @@ template <typename T>
T* pointer_param(T* param) {
return param;
}
+
+class double_pair : public std::pair<double, double> {
+ public:
+ double_pair(double x, double y) : std::pair<double, double>(x, y) { };
+};