summaryrefslogtreecommitdiff
path: root/numpy/core/defchararray.py
diff options
context:
space:
mode:
authorStephan Hoyer <shoyer@google.com>2018-10-11 18:07:41 -0700
committerStephan Hoyer <shoyer@google.com>2018-10-11 18:07:41 -0700
commit39fb169f62b97c1ef1c162b1999e41bcc9ed0b4e (patch)
tree7fe37740d21585ff67a823952180c18b847e64a4 /numpy/core/defchararray.py
parent2ed08ba2ec29a8ef5fb60fa1a17a3d9366ae6c5d (diff)
downloadnumpy-39fb169f62b97c1ef1c162b1999e41bcc9ed0b4e.tar.gz
ENH: __array_function__ for np.core.defchararray
Diffstat (limited to 'numpy/core/defchararray.py')
-rw-r--r--numpy/core/defchararray.py149
1 files changed, 149 insertions, 0 deletions
diff --git a/numpy/core/defchararray.py b/numpy/core/defchararray.py
index 6d0a0add5..f58abfbd5 100644
--- a/numpy/core/defchararray.py
+++ b/numpy/core/defchararray.py
@@ -22,6 +22,7 @@ from .numerictypes import string_, unicode_, integer, object_, bool_, character
from .numeric import ndarray, compare_chararrays
from .numeric import array as narray
from numpy.core.multiarray import _vec_string
+from numpy.core.overrides import array_function_dispatch
from numpy.compat import asbytes, long
import numpy
@@ -95,6 +96,11 @@ def _get_num_chars(a):
return a.itemsize
+def _binary_op_dispatcher(x1, x2):
+ return (x1, x2)
+
+
+@array_function_dispatch(_binary_op_dispatcher)
def equal(x1, x2):
"""
Return (x1 == x2) element-wise.
@@ -119,6 +125,8 @@ def equal(x1, x2):
"""
return compare_chararrays(x1, x2, '==', True)
+
+@array_function_dispatch(_binary_op_dispatcher)
def not_equal(x1, x2):
"""
Return (x1 != x2) element-wise.
@@ -143,6 +151,8 @@ def not_equal(x1, x2):
"""
return compare_chararrays(x1, x2, '!=', True)
+
+@array_function_dispatch(_binary_op_dispatcher)
def greater_equal(x1, x2):
"""
Return (x1 >= x2) element-wise.
@@ -168,6 +178,8 @@ def greater_equal(x1, x2):
"""
return compare_chararrays(x1, x2, '>=', True)
+
+@array_function_dispatch(_binary_op_dispatcher)
def less_equal(x1, x2):
"""
Return (x1 <= x2) element-wise.
@@ -192,6 +204,8 @@ def less_equal(x1, x2):
"""
return compare_chararrays(x1, x2, '<=', True)
+
+@array_function_dispatch(_binary_op_dispatcher)
def greater(x1, x2):
"""
Return (x1 > x2) element-wise.
@@ -216,6 +230,8 @@ def greater(x1, x2):
"""
return compare_chararrays(x1, x2, '>', True)
+
+@array_function_dispatch(_binary_op_dispatcher)
def less(x1, x2):
"""
Return (x1 < x2) element-wise.
@@ -240,6 +256,12 @@ def less(x1, x2):
"""
return compare_chararrays(x1, x2, '<', True)
+
+def _unary_op_dispatcher(a):
+ return (a,)
+
+
+@array_function_dispatch(_unary_op_dispatcher)
def str_len(a):
"""
Return len(a) element-wise.
@@ -259,6 +281,8 @@ def str_len(a):
"""
return _vec_string(a, integer, '__len__')
+
+@array_function_dispatch(_binary_op_dispatcher)
def add(x1, x2):
"""
Return element-wise string concatenation for two arrays of str or unicode.
@@ -285,6 +309,12 @@ def add(x1, x2):
dtype = _use_unicode(arr1, arr2)
return _vec_string(arr1, (dtype, out_size), '__add__', (arr2,))
+
+def _multiply_dispathcer(a, i):
+ return (a,)
+
+
+@array_function_dispatch(_multiply_dispathcer)
def multiply(a, i):
"""
Return (a * i), that is string multiple concatenation,
@@ -313,6 +343,12 @@ def multiply(a, i):
return _vec_string(
a_arr, (a_arr.dtype.type, out_size), '__mul__', (i_arr,))
+
+def _mod_dispatcher(a, values):
+ return (a, values)
+
+
+@array_function_dispatch(_mod_dispatcher)
def mod(a, values):
"""
Return (a % i), that is pre-Python 2.6 string formatting
@@ -339,6 +375,8 @@ def mod(a, values):
return _to_string_or_unicode_array(
_vec_string(a, object_, '__mod__', (values,)))
+
+@array_function_dispatch(_unary_op_dispatcher)
def capitalize(a):
"""
Return a copy of `a` with only the first character of each element
@@ -377,6 +415,11 @@ def capitalize(a):
return _vec_string(a_arr, a_arr.dtype, 'capitalize')
+def _center_dispatcher(a, width, fillchar=None):
+ return (a,)
+
+
+@array_function_dispatch(_center_dispatcher)
def center(a, width, fillchar=' '):
"""
Return a copy of `a` with its elements centered in a string of
@@ -413,6 +456,11 @@ def center(a, width, fillchar=' '):
a_arr, (a_arr.dtype.type, size), 'center', (width_arr, fillchar))
+def _count_dispatcher(a, sub, start=None, end=None):
+ return (a,)
+
+
+@array_function_dispatch(_count_dispatcher)
def count(a, sub, start=0, end=None):
"""
Returns an array with the number of non-overlapping occurrences of
@@ -459,6 +507,11 @@ def count(a, sub, start=0, end=None):
return _vec_string(a, integer, 'count', [sub, start] + _clean_args(end))
+def _code_dispatcher(a, encoding=None, errors=None):
+ return (a,)
+
+
+@array_function_dispatch(_code_dispatcher)
def decode(a, encoding=None, errors=None):
"""
Calls `str.decode` element-wise.
@@ -505,6 +558,7 @@ def decode(a, encoding=None, errors=None):
_vec_string(a, object_, 'decode', _clean_args(encoding, errors)))
+@array_function_dispatch(_code_dispatcher)
def encode(a, encoding=None, errors=None):
"""
Calls `str.encode` element-wise.
@@ -540,6 +594,11 @@ def encode(a, encoding=None, errors=None):
_vec_string(a, object_, 'encode', _clean_args(encoding, errors)))
+def _endswith_dispatcher(a, suffix, start=None, end=None):
+ return (a,)
+
+
+@array_function_dispatch(_endswith_dispatcher)
def endswith(a, suffix, start=0, end=None):
"""
Returns a boolean array which is `True` where the string element
@@ -584,6 +643,11 @@ def endswith(a, suffix, start=0, end=None):
a, bool_, 'endswith', [suffix, start] + _clean_args(end))
+def _expandtabs_dispatcher(a, tabsize=None):
+ return (a,)
+
+
+@array_function_dispatch(_expandtabs_dispatcher)
def expandtabs(a, tabsize=8):
"""
Return a copy of each string element where all tab characters are
@@ -619,6 +683,7 @@ def expandtabs(a, tabsize=8):
_vec_string(a, object_, 'expandtabs', (tabsize,)))
+@array_function_dispatch(_count_dispatcher)
def find(a, sub, start=0, end=None):
"""
For each element, return the lowest index in the string where
@@ -654,6 +719,7 @@ def find(a, sub, start=0, end=None):
a, integer, 'find', [sub, start] + _clean_args(end))
+@array_function_dispatch(_count_dispatcher)
def index(a, sub, start=0, end=None):
"""
Like `find`, but raises `ValueError` when the substring is not found.
@@ -681,6 +747,8 @@ def index(a, sub, start=0, end=None):
return _vec_string(
a, integer, 'index', [sub, start] + _clean_args(end))
+
+@array_function_dispatch(_unary_op_dispatcher)
def isalnum(a):
"""
Returns true for each element if all characters in the string are
@@ -705,6 +773,8 @@ def isalnum(a):
"""
return _vec_string(a, bool_, 'isalnum')
+
+@array_function_dispatch(_unary_op_dispatcher)
def isalpha(a):
"""
Returns true for each element if all characters in the string are
@@ -729,6 +799,8 @@ def isalpha(a):
"""
return _vec_string(a, bool_, 'isalpha')
+
+@array_function_dispatch(_unary_op_dispatcher)
def isdigit(a):
"""
Returns true for each element if all characters in the string are
@@ -753,6 +825,8 @@ def isdigit(a):
"""
return _vec_string(a, bool_, 'isdigit')
+
+@array_function_dispatch(_unary_op_dispatcher)
def islower(a):
"""
Returns true for each element if all cased characters in the
@@ -778,6 +852,8 @@ def islower(a):
"""
return _vec_string(a, bool_, 'islower')
+
+@array_function_dispatch(_unary_op_dispatcher)
def isspace(a):
"""
Returns true for each element if there are only whitespace
@@ -803,6 +879,8 @@ def isspace(a):
"""
return _vec_string(a, bool_, 'isspace')
+
+@array_function_dispatch(_unary_op_dispatcher)
def istitle(a):
"""
Returns true for each element if the element is a titlecased
@@ -827,6 +905,8 @@ def istitle(a):
"""
return _vec_string(a, bool_, 'istitle')
+
+@array_function_dispatch(_unary_op_dispatcher)
def isupper(a):
"""
Returns true for each element if all cased characters in the
@@ -852,6 +932,12 @@ def isupper(a):
"""
return _vec_string(a, bool_, 'isupper')
+
+def _join_dispatcher(sep, seq):
+ return (sep, seq)
+
+
+@array_function_dispatch(_join_dispatcher)
def join(sep, seq):
"""
Return a string which is the concatenation of the strings in the
@@ -877,6 +963,12 @@ def join(sep, seq):
_vec_string(sep, object_, 'join', (seq,)))
+
+def _just_dispatcher(a, width, fillchar=None):
+ return (a,)
+
+
+@array_function_dispatch(_just_dispatcher)
def ljust(a, width, fillchar=' '):
"""
Return an array with the elements of `a` left-justified in a
@@ -912,6 +1004,7 @@ def ljust(a, width, fillchar=' '):
a_arr, (a_arr.dtype.type, size), 'ljust', (width_arr, fillchar))
+@array_function_dispatch(_unary_op_dispatcher)
def lower(a):
"""
Return an array with the elements converted to lowercase.
@@ -948,6 +1041,11 @@ def lower(a):
return _vec_string(a_arr, a_arr.dtype, 'lower')
+def _strip_dispatcher(a, chars=None):
+ return (a,)
+
+
+@array_function_dispatch(_strip_dispatcher)
def lstrip(a, chars=None):
"""
For each element in `a`, return a copy with the leading characters
@@ -1005,6 +1103,11 @@ def lstrip(a, chars=None):
return _vec_string(a_arr, a_arr.dtype, 'lstrip', (chars,))
+def _partition_dispatcher(a, sep):
+ return (a,)
+
+
+@array_function_dispatch(_partition_dispatcher)
def partition(a, sep):
"""
Partition each element in `a` around `sep`.
@@ -1040,6 +1143,11 @@ def partition(a, sep):
_vec_string(a, object_, 'partition', (sep,)))
+def _replace_dispatcher(a, old, new, count=None):
+ return (a,)
+
+
+@array_function_dispatch(_replace_dispatcher)
def replace(a, old, new, count=None):
"""
For each element in `a`, return a copy of the string with all
@@ -1072,6 +1180,7 @@ def replace(a, old, new, count=None):
a, object_, 'replace', [old, new] + _clean_args(count)))
+@array_function_dispatch(_count_dispatcher)
def rfind(a, sub, start=0, end=None):
"""
For each element in `a`, return the highest index in the string
@@ -1104,6 +1213,7 @@ def rfind(a, sub, start=0, end=None):
a, integer, 'rfind', [sub, start] + _clean_args(end))
+@array_function_dispatch(_count_dispatcher)
def rindex(a, sub, start=0, end=None):
"""
Like `rfind`, but raises `ValueError` when the substring `sub` is
@@ -1133,6 +1243,7 @@ def rindex(a, sub, start=0, end=None):
a, integer, 'rindex', [sub, start] + _clean_args(end))
+@array_function_dispatch(_just_dispatcher)
def rjust(a, width, fillchar=' '):
"""
Return an array with the elements of `a` right-justified in a
@@ -1168,6 +1279,7 @@ def rjust(a, width, fillchar=' '):
a_arr, (a_arr.dtype.type, size), 'rjust', (width_arr, fillchar))
+@array_function_dispatch(_partition_dispatcher)
def rpartition(a, sep):
"""
Partition (split) each element around the right-most separator.
@@ -1203,6 +1315,11 @@ def rpartition(a, sep):
_vec_string(a, object_, 'rpartition', (sep,)))
+def _split_dispatcher(a, sep=None, maxsplit=None):
+ return (a,)
+
+
+@array_function_dispatch(_split_dispatcher)
def rsplit(a, sep=None, maxsplit=None):
"""
For each element in `a`, return a list of the words in the
@@ -1240,6 +1357,11 @@ def rsplit(a, sep=None, maxsplit=None):
a, object_, 'rsplit', [sep] + _clean_args(maxsplit))
+def _strip_dispatcher(a, chars=None):
+ return (a,)
+
+
+@array_function_dispatch(_strip_dispatcher)
def rstrip(a, chars=None):
"""
For each element in `a`, return a copy with the trailing
@@ -1284,6 +1406,7 @@ def rstrip(a, chars=None):
return _vec_string(a_arr, a_arr.dtype, 'rstrip', (chars,))
+@array_function_dispatch(_split_dispatcher)
def split(a, sep=None, maxsplit=None):
"""
For each element in `a`, return a list of the words in the
@@ -1318,6 +1441,11 @@ def split(a, sep=None, maxsplit=None):
a, object_, 'split', [sep] + _clean_args(maxsplit))
+def _splitlines_dispatcher(a, keepends=None):
+ return (a,)
+
+
+@array_function_dispatch(_splitlines_dispatcher)
def splitlines(a, keepends=None):
"""
For each element in `a`, return a list of the lines in the
@@ -1347,6 +1475,11 @@ def splitlines(a, keepends=None):
a, object_, 'splitlines', _clean_args(keepends))
+def _startswith_dispatcher(a, prefix, start=None, end=None):
+ return (a,)
+
+
+@array_function_dispatch(_startswith_dispatcher)
def startswith(a, prefix, start=0, end=None):
"""
Returns a boolean array which is `True` where the string element
@@ -1378,6 +1511,7 @@ def startswith(a, prefix, start=0, end=None):
a, bool_, 'startswith', [prefix, start] + _clean_args(end))
+@array_function_dispatch(_strip_dispatcher)
def strip(a, chars=None):
"""
For each element in `a`, return a copy with the leading and
@@ -1426,6 +1560,7 @@ def strip(a, chars=None):
return _vec_string(a_arr, a_arr.dtype, 'strip', _clean_args(chars))
+@array_function_dispatch(_unary_op_dispatcher)
def swapcase(a):
"""
Return element-wise a copy of the string with
@@ -1463,6 +1598,7 @@ def swapcase(a):
return _vec_string(a_arr, a_arr.dtype, 'swapcase')
+@array_function_dispatch(_unary_op_dispatcher)
def title(a):
"""
Return element-wise title cased version of string or unicode.
@@ -1502,6 +1638,11 @@ def title(a):
return _vec_string(a_arr, a_arr.dtype, 'title')
+def _translate_dispatcher(a, table, deletechars=None):
+ return (a,)
+
+
+@array_function_dispatch(_translate_dispatcher)
def translate(a, table, deletechars=None):
"""
For each element in `a`, return a copy of the string where all
@@ -1538,6 +1679,7 @@ def translate(a, table, deletechars=None):
a_arr, a_arr.dtype, 'translate', [table] + _clean_args(deletechars))
+@array_function_dispatch(_unary_op_dispatcher)
def upper(a):
"""
Return an array with the elements converted to uppercase.
@@ -1574,6 +1716,11 @@ def upper(a):
return _vec_string(a_arr, a_arr.dtype, 'upper')
+def _zfill_dispatcher(a, width):
+ return (a,)
+
+
+@array_function_dispatch(_zfill_dispatcher)
def zfill(a, width):
"""
Return the numeric string left-filled with zeros
@@ -1604,6 +1751,7 @@ def zfill(a, width):
a_arr, (a_arr.dtype.type, size), 'zfill', (width_arr,))
+@array_function_dispatch(_unary_op_dispatcher)
def isnumeric(a):
"""
For each element, return True if there are only numeric
@@ -1635,6 +1783,7 @@ def isnumeric(a):
return _vec_string(a, bool_, 'isnumeric')
+@array_function_dispatch(_unary_op_dispatcher)
def isdecimal(a):
"""
For each element, return True if there are only decimal