summaryrefslogtreecommitdiff
path: root/numpy/core/code_generators/genapi.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core/code_generators/genapi.py')
-rw-r--r--numpy/core/code_generators/genapi.py90
1 files changed, 64 insertions, 26 deletions
diff --git a/numpy/core/code_generators/genapi.py b/numpy/core/code_generators/genapi.py
index 68ae30d5b..2cdaba52d 100644
--- a/numpy/core/code_generators/genapi.py
+++ b/numpy/core/code_generators/genapi.py
@@ -6,17 +6,35 @@ See ``find_function`` for how functions should be formatted, and
specified.
"""
-from numpy.distutils.conv_template import process_file as process_c_file
-
import hashlib
import io
import os
import re
import sys
+import importlib.util
import textwrap
from os.path import join
+
+def get_processor():
+ # Convoluted because we can't import from numpy.distutils
+ # (numpy is not yet built)
+ conv_template_path = os.path.join(
+ os.path.dirname(__file__),
+ '..', '..', 'distutils', 'conv_template.py'
+ )
+ spec = importlib.util.spec_from_file_location(
+ 'conv_template', conv_template_path
+ )
+ mod = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(mod)
+ return mod.process_file
+
+
+process_c_file = get_processor()
+
+
__docformat__ = 'restructuredtext'
# The files under src/ that are scanned for API functions
@@ -81,6 +99,27 @@ def _repl(str):
return str.replace('Bool', 'npy_bool')
+class MinVersion:
+ def __init__(self, version):
+ """ Version should be the normal NumPy version, e.g. "1.25" """
+ major, minor = version.split(".")
+ self.version = f"NPY_{major}_{minor}_API_VERSION"
+
+ def __str__(self):
+ # Used by version hashing:
+ return self.version
+
+ def add_guard(self, name, normal_define):
+ """Wrap a definition behind a version guard"""
+ wrap = textwrap.dedent(f"""
+ #if NPY_FEATURE_VERSION >= {self.version}
+ {{define}}
+ #endif""")
+
+ # we only insert `define` later to avoid confusing dedent:
+ return wrap.format(define=normal_define)
+
+
class StealRef:
def __init__(self, arg):
self.arg = arg # counting from 1
@@ -113,21 +152,6 @@ class Function:
doccomment = ''
return '%s%s %s(%s)' % (doccomment, self.return_type, self.name, argstr)
- def to_ReST(self):
- lines = ['::', '', ' ' + self.return_type]
- argstr = ',\000'.join([self._format_arg(*a) for a in self.args])
- name = ' %s' % (self.name,)
- s = textwrap.wrap('(%s)' % (argstr,), width=72,
- initial_indent=name,
- subsequent_indent=' ' * (len(name)+1),
- break_long_words=False)
- for l in s:
- lines.append(l.replace('\000', ' ').rstrip())
- lines.append('')
- if self.doc:
- lines.append(textwrap.dedent(self.doc))
- return '\n'.join(lines)
-
def api_hash(self):
m = hashlib.md5()
m.update(remove_whitespace(self.return_type))
@@ -389,7 +413,20 @@ class FunctionApi:
def __init__(self, name, index, annotations, return_type, args, api_name):
self.name = name
self.index = index
- self.annotations = annotations
+
+ self.min_version = None
+ self.annotations = []
+ for annotation in annotations:
+ # String checks, because manual import breaks isinstance
+ if type(annotation).__name__ == "StealRef":
+ self.annotations.append(annotation)
+ elif type(annotation).__name__ == "MinVersion":
+ if self.min_version is not None:
+ raise ValueError("Two minimum versions specified!")
+ self.min_version = annotation
+ else:
+ raise ValueError(f"unknown annotation {annotation}")
+
self.return_type = return_type
self.args = args
self.api_name = api_name
@@ -401,13 +438,14 @@ class FunctionApi:
return argstr
def define_from_array_api_string(self):
- define = """\
-#define %s \\\n (*(%s (*)(%s)) \\
- %s[%d])""" % (self.name,
- self.return_type,
- self._argtypes_string(),
- self.api_name,
- self.index)
+ arguments = self._argtypes_string()
+ define = textwrap.dedent(f"""\
+ #define {self.name} \\
+ (*({self.return_type} (*)({arguments})) \\
+ {self.api_name}[{self.index}])""")
+
+ if self.min_version is not None:
+ define = self.min_version.add_guard(self.name, define)
return define
def array_api_define(self):
@@ -498,7 +536,7 @@ def get_versions_hash():
d = []
file = os.path.join(os.path.dirname(__file__), 'cversions.txt')
- with open(file, 'r') as fid:
+ with open(file) as fid:
for line in fid:
m = VERRE.match(line)
if m: