summaryrefslogtreecommitdiff
path: root/numpy/f2py/tests
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2021-09-26 09:34:52 -0600
committerGitHub <noreply@github.com>2021-09-26 09:34:52 -0600
commit468142495461319f06a8debf57a3c0aa2703f9df (patch)
tree19cef9e83e86436b28a278d702b0433f0ec26161 /numpy/f2py/tests
parent26e656e7989e776a9292b9eba88cef7ec9fec5cb (diff)
parent88f988ae6226ffb5e3c9ae41a2d0cf9fd982109a (diff)
downloadnumpy-468142495461319f06a8debf57a3c0aa2703f9df.tar.gz
Merge pull request #19805 from pearu/gh-8062-dimdecs
ENH: Symbolic solver for dimension specifications.
Diffstat (limited to 'numpy/f2py/tests')
-rw-r--r--numpy/f2py/tests/test_crackfortran.py99
-rw-r--r--numpy/f2py/tests/test_symbolic.py462
2 files changed, 561 insertions, 0 deletions
diff --git a/numpy/f2py/tests/test_crackfortran.py b/numpy/f2py/tests/test_crackfortran.py
index 140f42cbc..da7974d1a 100644
--- a/numpy/f2py/tests/test_crackfortran.py
+++ b/numpy/f2py/tests/test_crackfortran.py
@@ -1,3 +1,4 @@
+import pytest
import numpy as np
from numpy.testing import assert_array_equal, assert_equal
from numpy.f2py.crackfortran import markinnerspaces
@@ -39,6 +40,7 @@ class TestNoSpace(util.F2PyTest):
class TestPublicPrivate():
+
def test_defaultPrivate(self, tmp_path):
f_path = tmp_path / "mod.f90"
with f_path.open('w') as ff:
@@ -165,3 +167,100 @@ class TestMarkinnerspaces():
def test_multiple_relevant_spaces(self):
assert_equal(markinnerspaces("a 'b c' 'd e'"), "a 'b@_@c' 'd@_@e'")
assert_equal(markinnerspaces(r'a "b c" "d e"'), r'a "b@_@c" "d@_@e"')
+
+
+class TestDimSpec(util.F2PyTest):
+ """This test site tests various expressions that are used as dimension
+ specifications.
+
+ There exists two usage cases where analyzing dimensions
+ specifications are important.
+
+ In the first case, the size of output arrays must be defined based
+ on the inputs to a Fortran function. Because Fortran supports
+ arbitrary bases for indexing, for instance, `arr(lower:upper)`,
+ f2py has to evaluate an expression `upper - lower + 1` where
+ `lower` and `upper` are arbitrary expressions of input parameters.
+ The evaluation is performed in C, so f2py has to translate Fortran
+ expressions to valid C expressions (an alternative approach is
+ that a developer specifies the corresponing C expressions in a
+ .pyf file).
+
+ In the second case, when user provides an input array with a given
+ size but some hidden parameters used in dimensions specifications
+ need to be determined based on the input array size. This is a
+ harder problem because f2py has to solve the inverse problem: find
+ a parameter `p` such that `upper(p) - lower(p) + 1` equals to the
+ size of input array. In the case when this equation cannot be
+ solved (e.g. because the input array size is wrong), raise an
+ error before calling the Fortran function (that otherwise would
+ likely crash Python process when the size of input arrays is
+ wrong). f2py currently supports this case only when the equation
+ is linear with respect to unknown parameter.
+
+ """
+
+ suffix = '.f90'
+
+ code_template = textwrap.dedent("""
+ function get_arr_size_{count}(a, n) result (length)
+ integer, intent(in) :: n
+ integer, dimension({dimspec}), intent(out) :: a
+ integer length
+ length = size(a)
+ end function
+
+ subroutine get_inv_arr_size_{count}(a, n)
+ integer :: n
+ ! the value of n is computed in f2py wrapper
+ !f2py intent(out) n
+ integer, dimension({dimspec}), intent(in) :: a
+ if (a({first}).gt.0) then
+ print*, "a=", a
+ endif
+ end subroutine
+ """)
+
+ linear_dimspecs = ['n', '2*n', '2:n', 'n/2', '5 - n/2', '3*n:20',
+ 'n*(n+1):n*(n+5)']
+ nonlinear_dimspecs = ['2*n:3*n*n+2*n']
+ all_dimspecs = linear_dimspecs + nonlinear_dimspecs
+
+ code = ''
+ for count, dimspec in enumerate(all_dimspecs):
+ code += code_template.format(
+ count=count, dimspec=dimspec,
+ first=dimspec.split(':')[0] if ':' in dimspec else '1')
+
+ @pytest.mark.parametrize('dimspec', all_dimspecs)
+ def test_array_size(self, dimspec):
+
+ count = self.all_dimspecs.index(dimspec)
+ get_arr_size = getattr(self.module, f'get_arr_size_{count}')
+
+ for n in [1, 2, 3, 4, 5]:
+ sz, a = get_arr_size(n)
+ assert len(a) == sz
+
+ @pytest.mark.parametrize('dimspec', all_dimspecs)
+ def test_inv_array_size(self, dimspec):
+
+ count = self.all_dimspecs.index(dimspec)
+ get_arr_size = getattr(self.module, f'get_arr_size_{count}')
+ get_inv_arr_size = getattr(self.module, f'get_inv_arr_size_{count}')
+
+ for n in [1, 2, 3, 4, 5]:
+ sz, a = get_arr_size(n)
+ if dimspec in self.nonlinear_dimspecs:
+ # one must specify n as input, the call we'll ensure
+ # that a and n are compatible:
+ n1 = get_inv_arr_size(a, n)
+ else:
+ # in case of linear dependence, n can be determined
+ # from the shape of a:
+ n1 = get_inv_arr_size(a)
+ # n1 may be different from n (for instance, when `a` size
+ # is a function of some `n` fraction) but it must produce
+ # the same sized array
+ sz1, _ = get_arr_size(n1)
+ assert sz == sz1, (n, n1, sz, sz1)
diff --git a/numpy/f2py/tests/test_symbolic.py b/numpy/f2py/tests/test_symbolic.py
new file mode 100644
index 000000000..52cabac53
--- /dev/null
+++ b/numpy/f2py/tests/test_symbolic.py
@@ -0,0 +1,462 @@
+from numpy.testing import assert_raises
+from numpy.f2py.symbolic import (
+ Expr, Op, ArithOp, Language,
+ as_symbol, as_number, as_string, as_array, as_complex,
+ as_terms, as_factors, eliminate_quotes, insert_quotes,
+ fromstring, as_expr, as_apply,
+ as_numer_denom, as_ternary, as_ref, as_deref,
+ normalize, as_eq, as_ne, as_lt, as_gt, as_le, as_ge
+ )
+from . import util
+
+
+class TestSymbolic(util.F2PyTest):
+
+ def test_eliminate_quotes(self):
+ def worker(s):
+ r, d = eliminate_quotes(s)
+ s1 = insert_quotes(r, d)
+ assert s1 == s
+
+ for kind in ['', 'mykind_']:
+ worker(kind + '"1234" // "ABCD"')
+ worker(kind + '"1234" // ' + kind + '"ABCD"')
+ worker(kind + '"1234" // \'ABCD\'')
+ worker(kind + '"1234" // ' + kind + '\'ABCD\'')
+ worker(kind + '"1\\"2\'AB\'34"')
+ worker('a = ' + kind + "'1\\'2\"AB\"34'")
+
+ def test_sanity(self):
+ x = as_symbol('x')
+ y = as_symbol('y')
+ z = as_symbol('z')
+
+ assert x.op == Op.SYMBOL
+ assert repr(x) == "Expr(Op.SYMBOL, 'x')"
+ assert x == x
+ assert x != y
+ assert hash(x) is not None
+
+ n = as_number(123)
+ m = as_number(456)
+ assert n.op == Op.INTEGER
+ assert repr(n) == "Expr(Op.INTEGER, (123, 4))"
+ assert n == n
+ assert n != m
+ assert hash(n) is not None
+
+ fn = as_number(12.3)
+ fm = as_number(45.6)
+ assert fn.op == Op.REAL
+ assert repr(fn) == "Expr(Op.REAL, (12.3, 4))"
+ assert fn == fn
+ assert fn != fm
+ assert hash(fn) is not None
+
+ c = as_complex(1, 2)
+ c2 = as_complex(3, 4)
+ assert c.op == Op.COMPLEX
+ assert repr(c) == ("Expr(Op.COMPLEX, (Expr(Op.INTEGER, (1, 4)),"
+ " Expr(Op.INTEGER, (2, 4))))")
+ assert c == c
+ assert c != c2
+ assert hash(c) is not None
+
+ s = as_string("'123'")
+ s2 = as_string('"ABC"')
+ assert s.op == Op.STRING
+ assert repr(s) == "Expr(Op.STRING, (\"'123'\", 1))", repr(s)
+ assert s == s
+ assert s != s2
+
+ a = as_array((n, m))
+ b = as_array((n,))
+ assert a.op == Op.ARRAY
+ assert repr(a) == ("Expr(Op.ARRAY, (Expr(Op.INTEGER, (123, 4)),"
+ " Expr(Op.INTEGER, (456, 4))))")
+ assert a == a
+ assert a != b
+
+ t = as_terms(x)
+ u = as_terms(y)
+ assert t.op == Op.TERMS
+ assert repr(t) == "Expr(Op.TERMS, {Expr(Op.SYMBOL, 'x'): 1})"
+ assert t == t
+ assert t != u
+ assert hash(t) is not None
+
+ v = as_factors(x)
+ w = as_factors(y)
+ assert v.op == Op.FACTORS
+ assert repr(v) == "Expr(Op.FACTORS, {Expr(Op.SYMBOL, 'x'): 1})"
+ assert v == v
+ assert w != v
+ assert hash(v) is not None
+
+ t = as_ternary(x, y, z)
+ u = as_ternary(x, z, y)
+ assert t.op == Op.TERNARY
+ assert t == t
+ assert t != u
+ assert hash(t) is not None
+
+ e = as_eq(x, y)
+ f = as_lt(x, y)
+ assert e.op == Op.RELATIONAL
+ assert e == e
+ assert e != f
+ assert hash(e) is not None
+
+ def test_tostring_fortran(self):
+ x = as_symbol('x')
+ y = as_symbol('y')
+ z = as_symbol('z')
+ n = as_number(123)
+ m = as_number(456)
+ a = as_array((n, m))
+ c = as_complex(n, m)
+
+ assert str(x) == 'x'
+ assert str(n) == '123'
+ assert str(a) == '[123, 456]'
+ assert str(c) == '(123, 456)'
+
+ assert str(Expr(Op.TERMS, {x: 1})) == 'x'
+ assert str(Expr(Op.TERMS, {x: 2})) == '2 * x'
+ assert str(Expr(Op.TERMS, {x: -1})) == '-x'
+ assert str(Expr(Op.TERMS, {x: -2})) == '-2 * x'
+ assert str(Expr(Op.TERMS, {x: 1, y: 1})) == 'x + y'
+ assert str(Expr(Op.TERMS, {x: -1, y: -1})) == '-x - y'
+ assert str(Expr(Op.TERMS, {x: 2, y: 3})) == '2 * x + 3 * y'
+ assert str(Expr(Op.TERMS, {x: -2, y: 3})) == '-2 * x + 3 * y'
+ assert str(Expr(Op.TERMS, {x: 2, y: -3})) == '2 * x - 3 * y'
+
+ assert str(Expr(Op.FACTORS, {x: 1})) == 'x'
+ assert str(Expr(Op.FACTORS, {x: 2})) == 'x ** 2'
+ assert str(Expr(Op.FACTORS, {x: -1})) == 'x ** -1'
+ assert str(Expr(Op.FACTORS, {x: -2})) == 'x ** -2'
+ assert str(Expr(Op.FACTORS, {x: 1, y: 1})) == 'x * y'
+ assert str(Expr(Op.FACTORS, {x: 2, y: 3})) == 'x ** 2 * y ** 3'
+
+ v = Expr(Op.FACTORS, {x: 2, Expr(Op.TERMS, {x: 1, y: 1}): 3})
+ assert str(v) == 'x ** 2 * (x + y) ** 3', str(v)
+ v = Expr(Op.FACTORS, {x: 2, Expr(Op.FACTORS, {x: 1, y: 1}): 3})
+ assert str(v) == 'x ** 2 * (x * y) ** 3', str(v)
+
+ assert str(Expr(Op.APPLY, ('f', (), {}))) == 'f()'
+ assert str(Expr(Op.APPLY, ('f', (x,), {}))) == 'f(x)'
+ assert str(Expr(Op.APPLY, ('f', (x, y), {}))) == 'f(x, y)'
+ assert str(Expr(Op.INDEXING, ('f', x))) == 'f[x]'
+
+ assert str(as_ternary(x, y, z)) == 'merge(y, z, x)'
+ assert str(as_eq(x, y)) == 'x .eq. y'
+ assert str(as_ne(x, y)) == 'x .ne. y'
+ assert str(as_lt(x, y)) == 'x .lt. y'
+ assert str(as_le(x, y)) == 'x .le. y'
+ assert str(as_gt(x, y)) == 'x .gt. y'
+ assert str(as_ge(x, y)) == 'x .ge. y'
+
+ def test_tostring_c(self):
+ language = Language.C
+ x = as_symbol('x')
+ y = as_symbol('y')
+ z = as_symbol('z')
+ n = as_number(123)
+
+ assert Expr(Op.FACTORS, {x: 2}).tostring(language=language) == 'x * x'
+ assert Expr(Op.FACTORS, {x + y: 2}).tostring(
+ language=language) == '(x + y) * (x + y)'
+ assert Expr(Op.FACTORS, {x: 12}).tostring(
+ language=language) == 'pow(x, 12)'
+
+ assert as_apply(ArithOp.DIV, x, y).tostring(
+ language=language) == 'x / y'
+ assert as_apply(ArithOp.DIV, x, x + y).tostring(
+ language=language) == 'x / (x + y)'
+ assert as_apply(ArithOp.DIV, x - y, x + y).tostring(
+ language=language) == '(x - y) / (x + y)'
+ assert (x + (x - y) / (x + y) + n).tostring(
+ language=language) == '123 + x + (x - y) / (x + y)'
+
+ assert as_ternary(x, y, z).tostring(language=language) == '(x ? y : z)'
+ assert as_eq(x, y).tostring(language=language) == 'x == y'
+ assert as_ne(x, y).tostring(language=language) == 'x != y'
+ assert as_lt(x, y).tostring(language=language) == 'x < y'
+ assert as_le(x, y).tostring(language=language) == 'x <= y'
+ assert as_gt(x, y).tostring(language=language) == 'x > y'
+ assert as_ge(x, y).tostring(language=language) == 'x >= y'
+
+ def test_operations(self):
+ x = as_symbol('x')
+ y = as_symbol('y')
+ z = as_symbol('z')
+
+ assert x + x == Expr(Op.TERMS, {x: 2})
+ assert x - x == Expr(Op.INTEGER, (0, 4))
+ assert x + y == Expr(Op.TERMS, {x: 1, y: 1})
+ assert x - y == Expr(Op.TERMS, {x: 1, y: -1})
+ assert x * x == Expr(Op.FACTORS, {x: 2})
+ assert x * y == Expr(Op.FACTORS, {x: 1, y: 1})
+
+ assert +x == x
+ assert -x == Expr(Op.TERMS, {x: -1}), repr(-x)
+ assert 2 * x == Expr(Op.TERMS, {x: 2})
+ assert 2 + x == Expr(Op.TERMS, {x: 1, as_number(1): 2})
+ assert 2 * x + 3 * y == Expr(Op.TERMS, {x: 2, y: 3})
+ assert (x + y) * 2 == Expr(Op.TERMS, {x: 2, y: 2})
+
+ assert x ** 2 == Expr(Op.FACTORS, {x: 2})
+ assert (x + y) ** 2 == Expr(Op.TERMS,
+ {Expr(Op.FACTORS, {x: 2}): 1,
+ Expr(Op.FACTORS, {y: 2}): 1,
+ Expr(Op.FACTORS, {x: 1, y: 1}): 2})
+ assert (x + y) * x == x ** 2 + x * y
+ assert (x + y) ** 2 == x ** 2 + 2 * x * y + y ** 2
+ assert (x + y) ** 2 + (x - y) ** 2 == 2 * x ** 2 + 2 * y ** 2
+ assert (x + y) * z == x * z + y * z
+ assert z * (x + y) == x * z + y * z
+
+ assert (x / 2) == as_apply(ArithOp.DIV, x, as_number(2))
+ assert (2 * x / 2) == x
+ assert (3 * x / 2) == as_apply(ArithOp.DIV, 3*x, as_number(2))
+ assert (4 * x / 2) == 2 * x
+ assert (5 * x / 2) == as_apply(ArithOp.DIV, 5*x, as_number(2))
+ assert (6 * x / 2) == 3 * x
+ assert ((3*5) * x / 6) == as_apply(ArithOp.DIV, 5*x, as_number(2))
+ assert (30*x**2*y**4 / (24*x**3*y**3)) == as_apply(ArithOp.DIV,
+ 5*y, 4*x)
+ assert ((15 * x / 6) / 5) == as_apply(
+ ArithOp.DIV, x, as_number(2)), ((15 * x / 6) / 5)
+ assert (x / (5 / x)) == as_apply(ArithOp.DIV, x**2, as_number(5))
+
+ assert (x / 2.0) == Expr(Op.TERMS, {x: 0.5})
+
+ s = as_string('"ABC"')
+ t = as_string('"123"')
+
+ assert s // t == Expr(Op.STRING, ('"ABC123"', 1))
+ assert s // x == Expr(Op.CONCAT, (s, x))
+ assert x // s == Expr(Op.CONCAT, (x, s))
+
+ c = as_complex(1., 2.)
+ assert -c == as_complex(-1., -2.)
+ assert c + c == as_expr((1+2j)*2)
+ assert c * c == as_expr((1+2j)**2)
+
+ def test_substitute(self):
+ x = as_symbol('x')
+ y = as_symbol('y')
+ z = as_symbol('z')
+ a = as_array((x, y))
+
+ assert x.substitute({x: y}) == y
+ assert (x + y).substitute({x: z}) == y + z
+ assert (x * y).substitute({x: z}) == y * z
+ assert (x ** 4).substitute({x: z}) == z ** 4
+ assert (x / y).substitute({x: z}) == z / y
+ assert x.substitute({x: y + z}) == y + z
+ assert a.substitute({x: y + z}) == as_array((y + z, y))
+
+ assert as_ternary(x, y, z).substitute(
+ {x: y + z}) == as_ternary(y + z, y, z)
+ assert as_eq(x, y).substitute(
+ {x: y + z}) == as_eq(y + z, y)
+
+ def test_fromstring(self):
+
+ x = as_symbol('x')
+ y = as_symbol('y')
+ z = as_symbol('z')
+ f = as_symbol('f')
+ s = as_string('"ABC"')
+ t = as_string('"123"')
+ a = as_array((x, y))
+
+ assert fromstring('x') == x
+ assert fromstring('+ x') == x
+ assert fromstring('- x') == -x
+ assert fromstring('x + y') == x + y
+ assert fromstring('x + 1') == x + 1
+ assert fromstring('x * y') == x * y
+ assert fromstring('x * 2') == x * 2
+ assert fromstring('x / y') == x / y
+ assert fromstring('x ** 2',
+ language=Language.Python) == x ** 2
+ assert fromstring('x ** 2 ** 3',
+ language=Language.Python) == x ** 2 ** 3
+ assert fromstring('(x + y) * z') == (x + y) * z
+
+ assert fromstring('f(x)') == f(x)
+ assert fromstring('f(x,y)') == f(x, y)
+ assert fromstring('f[x]') == f[x]
+ assert fromstring('f[x][y]') == f[x][y]
+
+ assert fromstring('"ABC"') == s
+ assert normalize(fromstring('"ABC" // "123" ',
+ language=Language.Fortran)) == s // t
+ assert fromstring('f("ABC")') == f(s)
+ assert fromstring('MYSTRKIND_"ABC"') == as_string('"ABC"', 'MYSTRKIND')
+
+ assert fromstring('(/x, y/)') == a, fromstring('(/x, y/)')
+ assert fromstring('f((/x, y/))') == f(a)
+ assert fromstring('(/(x+y)*z/)') == as_array(((x+y)*z,))
+
+ assert fromstring('123') == as_number(123)
+ assert fromstring('123_2') == as_number(123, 2)
+ assert fromstring('123_myintkind') == as_number(123, 'myintkind')
+
+ assert fromstring('123.0') == as_number(123.0, 4)
+ assert fromstring('123.0_4') == as_number(123.0, 4)
+ assert fromstring('123.0_8') == as_number(123.0, 8)
+ assert fromstring('123.0e0') == as_number(123.0, 4)
+ assert fromstring('123.0d0') == as_number(123.0, 8)
+ assert fromstring('123d0') == as_number(123.0, 8)
+ assert fromstring('123e-0') == as_number(123.0, 4)
+ assert fromstring('123d+0') == as_number(123.0, 8)
+ assert fromstring('123.0_myrealkind') == as_number(123.0, 'myrealkind')
+ assert fromstring('3E4') == as_number(30000.0, 4)
+
+ assert fromstring('(1, 2)') == as_complex(1, 2)
+ assert fromstring('(1e2, PI)') == as_complex(
+ as_number(100.0), as_symbol('PI'))
+
+ assert fromstring('[1, 2]') == as_array((as_number(1), as_number(2)))
+
+ assert fromstring('POINT(x, y=1)') == as_apply(
+ as_symbol('POINT'), x, y=as_number(1))
+ assert (fromstring('PERSON(name="John", age=50, shape=(/34, 23/))')
+ == as_apply(as_symbol('PERSON'),
+ name=as_string('"John"'),
+ age=as_number(50),
+ shape=as_array((as_number(34), as_number(23)))))
+
+ assert fromstring('x?y:z') == as_ternary(x, y, z)
+
+ assert fromstring('*x') == as_deref(x)
+ assert fromstring('**x') == as_deref(as_deref(x))
+ assert fromstring('&x') == as_ref(x)
+ assert fromstring('(*x) * (*y)') == as_deref(x) * as_deref(y)
+ assert fromstring('(*x) * *y') == as_deref(x) * as_deref(y)
+ assert fromstring('*x * *y') == as_deref(x) * as_deref(y)
+ assert fromstring('*x**y') == as_deref(x) * as_deref(y)
+
+ assert fromstring('x == y') == as_eq(x, y)
+ assert fromstring('x != y') == as_ne(x, y)
+ assert fromstring('x < y') == as_lt(x, y)
+ assert fromstring('x > y') == as_gt(x, y)
+ assert fromstring('x <= y') == as_le(x, y)
+ assert fromstring('x >= y') == as_ge(x, y)
+
+ assert fromstring('x .eq. y', language=Language.Fortran) == as_eq(x, y)
+ assert fromstring('x .ne. y', language=Language.Fortran) == as_ne(x, y)
+ assert fromstring('x .lt. y', language=Language.Fortran) == as_lt(x, y)
+ assert fromstring('x .gt. y', language=Language.Fortran) == as_gt(x, y)
+ assert fromstring('x .le. y', language=Language.Fortran) == as_le(x, y)
+ assert fromstring('x .ge. y', language=Language.Fortran) == as_ge(x, y)
+
+ def test_traverse(self):
+ x = as_symbol('x')
+ y = as_symbol('y')
+ z = as_symbol('z')
+ f = as_symbol('f')
+
+ # Use traverse to substitute a symbol
+ def replace_visit(s, r=z):
+ if s == x:
+ return r
+
+ assert x.traverse(replace_visit) == z
+ assert y.traverse(replace_visit) == y
+ assert z.traverse(replace_visit) == z
+ assert (f(y)).traverse(replace_visit) == f(y)
+ assert (f(x)).traverse(replace_visit) == f(z)
+ assert (f[y]).traverse(replace_visit) == f[y]
+ assert (f[z]).traverse(replace_visit) == f[z]
+ assert (x + y + z).traverse(replace_visit) == (2 * z + y)
+ assert (x + f(y, x - z)).traverse(
+ replace_visit) == (z + f(y, as_number(0)))
+ assert as_eq(x, y).traverse(replace_visit) == as_eq(z, y)
+
+ # Use traverse to collect symbols, method 1
+ function_symbols = set()
+ symbols = set()
+
+ def collect_symbols(s):
+ if s.op is Op.APPLY:
+ oper = s.data[0]
+ function_symbols.add(oper)
+ if oper in symbols:
+ symbols.remove(oper)
+ elif s.op is Op.SYMBOL and s not in function_symbols:
+ symbols.add(s)
+
+ (x + f(y, x - z)).traverse(collect_symbols)
+ assert function_symbols == {f}
+ assert symbols == {x, y, z}
+
+ # Use traverse to collect symbols, method 2
+ def collect_symbols2(expr, symbols):
+ if expr.op is Op.SYMBOL:
+ symbols.add(expr)
+
+ symbols = set()
+ (x + f(y, x - z)).traverse(collect_symbols2, symbols)
+ assert symbols == {x, y, z, f}
+
+ # Use traverse to partially collect symbols
+ def collect_symbols3(expr, symbols):
+ if expr.op is Op.APPLY:
+ # skip traversing function calls
+ return expr
+ if expr.op is Op.SYMBOL:
+ symbols.add(expr)
+
+ symbols = set()
+ (x + f(y, x - z)).traverse(collect_symbols3, symbols)
+ assert symbols == {x}
+
+ def test_linear_solve(self):
+ x = as_symbol('x')
+ y = as_symbol('y')
+ z = as_symbol('z')
+
+ assert x.linear_solve(x) == (as_number(1), as_number(0))
+ assert (x+1).linear_solve(x) == (as_number(1), as_number(1))
+ assert (2*x).linear_solve(x) == (as_number(2), as_number(0))
+ assert (2*x+3).linear_solve(x) == (as_number(2), as_number(3))
+ assert as_number(3).linear_solve(x) == (as_number(0), as_number(3))
+ assert y.linear_solve(x) == (as_number(0), y)
+ assert (y*z).linear_solve(x) == (as_number(0), y * z)
+
+ assert (x+y).linear_solve(x) == (as_number(1), y)
+ assert (z*x+y).linear_solve(x) == (z, y)
+ assert ((z+y)*x+y).linear_solve(x) == (z + y, y)
+ assert (z*y*x+y).linear_solve(x) == (z * y, y)
+
+ assert_raises(RuntimeError, lambda: (x*x).linear_solve(x))
+
+ def test_as_numer_denom(self):
+ x = as_symbol('x')
+ y = as_symbol('y')
+ n = as_number(123)
+
+ assert as_numer_denom(x) == (x, as_number(1))
+ assert as_numer_denom(x / n) == (x, n)
+ assert as_numer_denom(n / x) == (n, x)
+ assert as_numer_denom(x / y) == (x, y)
+ assert as_numer_denom(x * y) == (x * y, as_number(1))
+ assert as_numer_denom(n + x / y) == (x + n * y, y)
+ assert as_numer_denom(n + x / (y - x / n)) == (y * n ** 2, y * n - x)
+
+ def test_polynomial_atoms(self):
+ x = as_symbol('x')
+ y = as_symbol('y')
+ n = as_number(123)
+
+ assert x.polynomial_atoms() == {x}
+ assert n.polynomial_atoms() == set()
+ assert (y[x]).polynomial_atoms() == {y[x]}
+ assert (y(x)).polynomial_atoms() == {y(x)}
+ assert (y(x) + x).polynomial_atoms() == {y(x), x}
+ assert (y(x) * x[y]).polynomial_atoms() == {y(x), x[y]}
+ assert (y(x) ** x).polynomial_atoms() == {y(x)}