diff options
Diffstat (limited to 'numpy/core/code_generators/genapi.py')
-rw-r--r-- | numpy/core/code_generators/genapi.py | 90 |
1 files changed, 64 insertions, 26 deletions
diff --git a/numpy/core/code_generators/genapi.py b/numpy/core/code_generators/genapi.py index 68ae30d5b..2cdaba52d 100644 --- a/numpy/core/code_generators/genapi.py +++ b/numpy/core/code_generators/genapi.py @@ -6,17 +6,35 @@ See ``find_function`` for how functions should be formatted, and specified. """ -from numpy.distutils.conv_template import process_file as process_c_file - import hashlib import io import os import re import sys +import importlib.util import textwrap from os.path import join + +def get_processor(): + # Convoluted because we can't import from numpy.distutils + # (numpy is not yet built) + conv_template_path = os.path.join( + os.path.dirname(__file__), + '..', '..', 'distutils', 'conv_template.py' + ) + spec = importlib.util.spec_from_file_location( + 'conv_template', conv_template_path + ) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod.process_file + + +process_c_file = get_processor() + + __docformat__ = 'restructuredtext' # The files under src/ that are scanned for API functions @@ -81,6 +99,27 @@ def _repl(str): return str.replace('Bool', 'npy_bool') +class MinVersion: + def __init__(self, version): + """ Version should be the normal NumPy version, e.g. "1.25" """ + major, minor = version.split(".") + self.version = f"NPY_{major}_{minor}_API_VERSION" + + def __str__(self): + # Used by version hashing: + return self.version + + def add_guard(self, name, normal_define): + """Wrap a definition behind a version guard""" + wrap = textwrap.dedent(f""" + #if NPY_FEATURE_VERSION >= {self.version} + {{define}} + #endif""") + + # we only insert `define` later to avoid confusing dedent: + return wrap.format(define=normal_define) + + class StealRef: def __init__(self, arg): self.arg = arg # counting from 1 @@ -113,21 +152,6 @@ class Function: doccomment = '' return '%s%s %s(%s)' % (doccomment, self.return_type, self.name, argstr) - def to_ReST(self): - lines = ['::', '', ' ' + self.return_type] - argstr = ',\000'.join([self._format_arg(*a) for a in self.args]) - name = ' %s' % (self.name,) - s = textwrap.wrap('(%s)' % (argstr,), width=72, - initial_indent=name, - subsequent_indent=' ' * (len(name)+1), - break_long_words=False) - for l in s: - lines.append(l.replace('\000', ' ').rstrip()) - lines.append('') - if self.doc: - lines.append(textwrap.dedent(self.doc)) - return '\n'.join(lines) - def api_hash(self): m = hashlib.md5() m.update(remove_whitespace(self.return_type)) @@ -389,7 +413,20 @@ class FunctionApi: def __init__(self, name, index, annotations, return_type, args, api_name): self.name = name self.index = index - self.annotations = annotations + + self.min_version = None + self.annotations = [] + for annotation in annotations: + # String checks, because manual import breaks isinstance + if type(annotation).__name__ == "StealRef": + self.annotations.append(annotation) + elif type(annotation).__name__ == "MinVersion": + if self.min_version is not None: + raise ValueError("Two minimum versions specified!") + self.min_version = annotation + else: + raise ValueError(f"unknown annotation {annotation}") + self.return_type = return_type self.args = args self.api_name = api_name @@ -401,13 +438,14 @@ class FunctionApi: return argstr def define_from_array_api_string(self): - define = """\ -#define %s \\\n (*(%s (*)(%s)) \\ - %s[%d])""" % (self.name, - self.return_type, - self._argtypes_string(), - self.api_name, - self.index) + arguments = self._argtypes_string() + define = textwrap.dedent(f"""\ + #define {self.name} \\ + (*({self.return_type} (*)({arguments})) \\ + {self.api_name}[{self.index}])""") + + if self.min_version is not None: + define = self.min_version.add_guard(self.name, define) return define def array_api_define(self): @@ -498,7 +536,7 @@ def get_versions_hash(): d = [] file = os.path.join(os.path.dirname(__file__), 'cversions.txt') - with open(file, 'r') as fid: + with open(file) as fid: for line in fid: m = VERRE.match(line) if m: |