summaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
authorArash Taheri-Dezfouli <ataheridezfouli@groq.com>2023-05-11 14:29:16 -0500
committermax <maksim.levental@gmail.com>2023-05-11 16:20:47 -0500
commitf22008ed89eac028cd70f91de3adf41a481f6d22 (patch)
treef277becf263a740d6c3058f0d802ba196dceb291 /mlir
parent62c4c614eea8078918d04cb33ce54ef8f9987766 (diff)
downloadllvm-f22008ed89eac028cd70f91de3adf41a481f6d22.tar.gz
[MLIR] Add InferShapedTypeOpInterface bindings
Add C and python bindings for InferShapedTypeOpInterface and ShapedTypeComponents. This allows users to invoke InferShapedTypeOpInterface for ops that implement it. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D149494
Diffstat (limited to 'mlir')
-rw-r--r--mlir/include/mlir-c/Interfaces.h27
-rw-r--r--mlir/lib/Bindings/Python/IRInterfaces.cpp305
-rw-r--r--mlir/lib/CAPI/Interfaces/Interfaces.cpp127
-rw-r--r--mlir/python/mlir/_mlir_libs/_mlir/ir.pyi24
-rw-r--r--mlir/test/python/dialects/python_test.py53
-rw-r--r--mlir/test/python/python_test_ops.td27
6 files changed, 483 insertions, 80 deletions
diff --git a/mlir/include/mlir-c/Interfaces.h b/mlir/include/mlir-c/Interfaces.h
index 405e2bb7173e..a5a3473eaef5 100644
--- a/mlir/include/mlir-c/Interfaces.h
+++ b/mlir/include/mlir-c/Interfaces.h
@@ -60,6 +60,33 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(
void *properties, intptr_t nRegions, MlirRegion *regions,
MlirTypesCallback callback, void *userData);
+//===----------------------------------------------------------------------===//
+// InferShapedTypeOpInterface.
+//===----------------------------------------------------------------------===//
+
+/// Returns the interface TypeID of the InferShapedTypeOpInterface.
+MLIR_CAPI_EXPORTED MlirTypeID mlirInferShapedTypeOpInterfaceTypeID();
+
+/// These callbacks are used to return multiple shaped type components from
+/// functions while transferring ownership to the caller. The first argument is
+/// the has rank boolean followed by the the rank and a pointer to the shape
+/// (if applicable). The next argument is the element type, then the attribute.
+/// The last argument is an opaque pointer forwarded to the callback by the
+/// caller. This callback will be called potentially multiple times for each
+/// shaped type components.
+typedef void (*MlirShapedTypeComponentsCallback)(bool, intptr_t,
+ const int64_t *, MlirType,
+ MlirAttribute, void *);
+
+/// Infers the return shaped type components of the operation. Calls `callback`
+/// with the types of inferred arguments on success. Returns failure otherwise.
+MLIR_CAPI_EXPORTED MlirLogicalResult
+mlirInferShapedTypeOpInterfaceInferReturnTypes(
+ MlirStringRef opName, MlirContext context, MlirLocation location,
+ intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
+ void *properties, intptr_t nRegions, MlirRegion *regions,
+ MlirShapedTypeComponentsCallback callback, void *userData);
+
#ifdef __cplusplus
}
#endif
diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp
index 766d6f3e4793..0a7a25c0005f 100644
--- a/mlir/lib/Bindings/Python/IRInterfaces.cpp
+++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp
@@ -6,8 +6,8 @@
//
//===----------------------------------------------------------------------===//
-#include <utility>
#include <optional>
+#include <utility>
#include "IRModule.h"
#include "mlir-c/BuiltinAttributes.h"
@@ -35,6 +35,83 @@ constexpr static const char *inferReturnTypesDoc =
R"(Given the arguments required to build an operation, attempts to infer
its return types. Raises ValueError on failure.)";
+constexpr static const char *inferReturnTypeComponentsDoc =
+ R"(Given the arguments required to build an operation, attempts to infer
+its return shaped type components. Raises ValueError on failure.)";
+
+namespace {
+
+/// Takes in an optional ist of operands and converts them into a SmallVector
+/// of MlirVlaues. Returns an empty SmallVector if the list is empty.
+llvm::SmallVector<MlirValue> wrapOperands(std::optional<py::list> operandList) {
+ llvm::SmallVector<MlirValue> mlirOperands;
+
+ if (!operandList || operandList->empty()) {
+ return mlirOperands;
+ }
+
+ // Note: as the list may contain other lists this may not be final size.
+ mlirOperands.reserve(operandList->size());
+ for (const auto &&it : llvm::enumerate(*operandList)) {
+ PyValue *val;
+ try {
+ val = py::cast<PyValue *>(it.value());
+ if (!val)
+ throw py::cast_error();
+ mlirOperands.push_back(val->get());
+ continue;
+ } catch (py::cast_error &err) {
+ // Intentionally unhandled to try sequence below first.
+ (void)err;
+ }
+
+ try {
+ auto vals = py::cast<py::sequence>(it.value());
+ for (py::object v : vals) {
+ try {
+ val = py::cast<PyValue *>(v);
+ if (!val)
+ throw py::cast_error();
+ mlirOperands.push_back(val->get());
+ } catch (py::cast_error &err) {
+ throw py::value_error(
+ (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
+ " must be a Value or Sequence of Values (" + err.what() + ")")
+ .str());
+ }
+ }
+ continue;
+ } catch (py::cast_error &err) {
+ throw py::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) +
+ " must be a Value or Sequence of Values (" +
+ err.what() + ")")
+ .str());
+ }
+
+ throw py::cast_error();
+ }
+
+ return mlirOperands;
+}
+
+/// Takes in an optional vector of PyRegions and returns a SmallVector of
+/// MlirRegion. Returns an empty SmallVector if the list is empty.
+llvm::SmallVector<MlirRegion>
+wrapRegions(std::optional<std::vector<PyRegion>> regions) {
+ llvm::SmallVector<MlirRegion> mlirRegions;
+
+ if (regions) {
+ mlirRegions.reserve(regions->size());
+ for (PyRegion &region : *regions) {
+ mlirRegions.push_back(region);
+ }
+ }
+
+ return mlirRegions;
+}
+
+} // namespace
+
/// CRTP base class for Python classes representing MLIR Op interfaces.
/// Interface hierarchies are flat so no base class is expected here. The
/// derived class is expected to define the following static fields:
@@ -104,7 +181,7 @@ public:
/// Creates the Python bindings for this class in the given module.
static void bind(py::module &m) {
- py::class_<ConcreteIface> cls(m, "InferTypeOpInterface",
+ py::class_<ConcreteIface> cls(m, ConcreteIface::pyClassName,
py::module_local());
cls.def(py::init<py::object, DefaultingPyMlirContext>(), py::arg("object"),
py::arg("context") = py::none(), constructorDoc)
@@ -155,7 +232,7 @@ private:
py::object obj;
};
-/// Python wrapper for InterTypeOpInterface. This interface has only static
+/// Python wrapper for InferTypeOpInterface. This interface has only static
/// methods.
class PyInferTypeOpInterface
: public PyConcreteOpInterface<PyInferTypeOpInterface> {
@@ -191,59 +268,8 @@ public:
std::optional<std::vector<PyRegion>> regions,
DefaultingPyMlirContext context,
DefaultingPyLocation location) {
- llvm::SmallVector<MlirValue> mlirOperands;
- llvm::SmallVector<MlirRegion> mlirRegions;
-
- if (operandList && !operandList->empty()) {
- // Note: as the list may contain other lists this may not be final size.
- mlirOperands.reserve(operandList->size());
- for (const auto& it : llvm::enumerate(*operandList)) {
- PyValue* val;
- try {
- val = py::cast<PyValue *>(it.value());
- if (!val)
- throw py::cast_error();
- mlirOperands.push_back(val->get());
- continue;
- } catch (py::cast_error &err) {
- // Intentionally unhandled to try sequence below first.
- (void)err;
- }
-
- try {
- auto vals = py::cast<py::sequence>(it.value());
- for (py::object v : vals) {
- try {
- val = py::cast<PyValue *>(v);
- if (!val)
- throw py::cast_error();
- mlirOperands.push_back(val->get());
- } catch (py::cast_error &err) {
- throw py::value_error(
- (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
- " must be a Value or Sequence of Values (" + err.what() +
- ")")
- .str());
- }
- }
- continue;
- } catch (py::cast_error &err) {
- throw py::value_error(
- (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
- " must be a Value or Sequence of Values (" + err.what() + ")")
- .str());
- }
-
- throw py::cast_error();
- }
- }
-
- if (regions) {
- mlirRegions.reserve(regions->size());
- for (PyRegion &region : *regions) {
- mlirRegions.push_back(region);
- }
- }
+ llvm::SmallVector<MlirValue> mlirOperands = wrapOperands(operandList);
+ llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(regions);
std::vector<PyType> inferredTypes;
PyMlirContext &pyContext = context.resolve();
@@ -275,7 +301,172 @@ public:
}
};
-void populateIRInterfaces(py::module &m) { PyInferTypeOpInterface::bind(m); }
+/// Wrapper around an shaped type components.
+class PyShapedTypeComponents {
+public:
+ PyShapedTypeComponents(MlirType elementType) : elementType(elementType) {}
+ PyShapedTypeComponents(py::list shape, MlirType elementType)
+ : shape(shape), elementType(elementType), ranked(true) {}
+ PyShapedTypeComponents(py::list shape, MlirType elementType,
+ MlirAttribute attribute)
+ : shape(shape), elementType(elementType), attribute(attribute),
+ ranked(true) {}
+ PyShapedTypeComponents(PyShapedTypeComponents &) = delete;
+ PyShapedTypeComponents(PyShapedTypeComponents &&other)
+ : shape(other.shape), elementType(other.elementType),
+ attribute(other.attribute), ranked(other.ranked) {}
+
+ static void bind(py::module &m) {
+ py::class_<PyShapedTypeComponents>(m, "ShapedTypeComponents",
+ py::module_local())
+ .def_property_readonly(
+ "element_type",
+ [](PyShapedTypeComponents &self) {
+ return PyType(PyMlirContext::forContext(
+ mlirTypeGetContext(self.elementType)),
+ self.elementType);
+ },
+ "Returns the element type of the shaped type components.")
+ .def_static(
+ "get",
+ [](PyType &elementType) {
+ return PyShapedTypeComponents(elementType);
+ },
+ py::arg("element_type"),
+ "Create an shaped type components object with only the element "
+ "type.")
+ .def_static(
+ "get",
+ [](py::list shape, PyType &elementType) {
+ return PyShapedTypeComponents(shape, elementType);
+ },
+ py::arg("shape"), py::arg("element_type"),
+ "Create a ranked shaped type components object.")
+ .def_static(
+ "get",
+ [](py::list shape, PyType &elementType, PyAttribute &attribute) {
+ return PyShapedTypeComponents(shape, elementType, attribute);
+ },
+ py::arg("shape"), py::arg("element_type"), py::arg("attribute"),
+ "Create a ranked shaped type components object with attribute.")
+ .def_property_readonly(
+ "has_rank",
+ [](PyShapedTypeComponents &self) -> bool { return self.ranked; },
+ "Returns whether the given shaped type component is ranked.")
+ .def_property_readonly(
+ "rank",
+ [](PyShapedTypeComponents &self) -> py::object {
+ if (!self.ranked) {
+ return py::none();
+ }
+ return py::int_(self.shape.size());
+ },
+ "Returns the rank of the given ranked shaped type components. If "
+ "the shaped type components does not have a rank, None is "
+ "returned.")
+ .def_property_readonly(
+ "shape",
+ [](PyShapedTypeComponents &self) -> py::object {
+ if (!self.ranked) {
+ return py::none();
+ }
+ return py::list(self.shape);
+ },
+ "Returns the shape of the ranked shaped type components as a list "
+ "of integers. Returns none if the shaped type component does not "
+ "have a rank.");
+ }
+
+ pybind11::object getCapsule();
+ static PyShapedTypeComponents createFromCapsule(pybind11::object capsule);
+
+private:
+ py::list shape;
+ MlirType elementType;
+ MlirAttribute attribute;
+ bool ranked{false};
+};
+
+/// Python wrapper for InferShapedTypeOpInterface. This interface has only
+/// static methods.
+class PyInferShapedTypeOpInterface
+ : public PyConcreteOpInterface<PyInferShapedTypeOpInterface> {
+public:
+ using PyConcreteOpInterface<
+ PyInferShapedTypeOpInterface>::PyConcreteOpInterface;
+
+ constexpr static const char *pyClassName = "InferShapedTypeOpInterface";
+ constexpr static GetTypeIDFunctionTy getInterfaceID =
+ &mlirInferShapedTypeOpInterfaceTypeID;
+
+ /// C-style user-data structure for type appending callback.
+ struct AppendResultsCallbackData {
+ std::vector<PyShapedTypeComponents> &inferredShapedTypeComponents;
+ };
+
+ /// Appends the shaped type components provided as unpacked shape, element
+ /// type, attribute to the user-data.
+ static void appendResultsCallback(bool hasRank, intptr_t rank,
+ const int64_t *shape, MlirType elementType,
+ MlirAttribute attribute, void *userData) {
+ auto *data = static_cast<AppendResultsCallbackData *>(userData);
+ if (!hasRank) {
+ data->inferredShapedTypeComponents.emplace_back(elementType);
+ } else {
+ py::list shapeList;
+ for (intptr_t i = 0; i < rank; ++i) {
+ shapeList.append(shape[i]);
+ }
+ data->inferredShapedTypeComponents.emplace_back(shapeList, elementType,
+ attribute);
+ }
+ }
+
+ /// Given the arguments required to build an operation, attempts to infer the
+ /// shaped type components. Throws value_error on failure.
+ std::vector<PyShapedTypeComponents> inferReturnTypeComponents(
+ std::optional<py::list> operandList,
+ std::optional<PyAttribute> attributes, void *properties,
+ std::optional<std::vector<PyRegion>> regions,
+ DefaultingPyMlirContext context, DefaultingPyLocation location) {
+ llvm::SmallVector<MlirValue> mlirOperands = wrapOperands(operandList);
+ llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(regions);
+
+ std::vector<PyShapedTypeComponents> inferredShapedTypeComponents;
+ PyMlirContext &pyContext = context.resolve();
+ AppendResultsCallbackData data{inferredShapedTypeComponents};
+ MlirStringRef opNameRef =
+ mlirStringRefCreate(getOpName().data(), getOpName().length());
+ MlirAttribute attributeDict =
+ attributes ? attributes->get() : mlirAttributeGetNull();
+
+ MlirLogicalResult result = mlirInferShapedTypeOpInterfaceInferReturnTypes(
+ opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(),
+ mlirOperands.data(), attributeDict, properties, mlirRegions.size(),
+ mlirRegions.data(), &appendResultsCallback, &data);
+
+ if (mlirLogicalResultIsFailure(result)) {
+ throw py::value_error("Failed to infer result shape type components");
+ }
+
+ return inferredShapedTypeComponents;
+ }
+
+ static void bindDerived(ClassTy &cls) {
+ cls.def("inferReturnTypeComponents",
+ &PyInferShapedTypeOpInterface::inferReturnTypeComponents,
+ py::arg("operands") = py::none(),
+ py::arg("attributes") = py::none(), py::arg("regions") = py::none(),
+ py::arg("properties") = py::none(), py::arg("context") = py::none(),
+ py::arg("loc") = py::none(), inferReturnTypeComponentsDoc);
+ }
+};
+
+void populateIRInterfaces(py::module &m) {
+ PyInferTypeOpInterface::bind(m);
+ PyShapedTypeComponents::bind(m);
+ PyInferShapedTypeOpInterface::bind(m);
+}
} // namespace python
} // namespace mlir
diff --git a/mlir/lib/CAPI/Interfaces/Interfaces.cpp b/mlir/lib/CAPI/Interfaces/Interfaces.cpp
index 029feed3a359..e597a7bcb4f2 100644
--- a/mlir/lib/CAPI/Interfaces/Interfaces.cpp
+++ b/mlir/lib/CAPI/Interfaces/Interfaces.cpp
@@ -11,14 +11,65 @@
#include "mlir-c/Interfaces.h"
#include "mlir/CAPI/IR.h"
+#include "mlir/CAPI/Interfaces.h"
#include "mlir/CAPI/Support.h"
#include "mlir/CAPI/Wrap.h"
+#include "mlir/IR/ValueRange.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "llvm/ADT/ScopeExit.h"
#include <optional>
using namespace mlir;
+namespace {
+
+std::optional<RegisteredOperationName>
+getRegisteredOperationName(MlirContext context, MlirStringRef opName) {
+ StringRef name(opName.data, opName.length);
+ std::optional<RegisteredOperationName> info =
+ RegisteredOperationName::lookup(name, unwrap(context));
+ return info;
+}
+
+std::optional<Location> maybeGetLocation(MlirLocation location) {
+ std::optional<Location> maybeLocation;
+ if (!mlirLocationIsNull(location))
+ maybeLocation = unwrap(location);
+ return maybeLocation;
+}
+
+SmallVector<Value> unwrapOperands(intptr_t nOperands, MlirValue *operands) {
+ SmallVector<Value> unwrappedOperands;
+ (void)unwrapList(nOperands, operands, unwrappedOperands);
+ return unwrappedOperands;
+}
+
+DictionaryAttr unwrapAttributes(MlirAttribute attributes) {
+ DictionaryAttr attributeDict;
+ if (!mlirAttributeIsNull(attributes))
+ attributeDict = unwrap(attributes).cast<DictionaryAttr>();
+ return attributeDict;
+}
+
+SmallVector<std::unique_ptr<Region>> unwrapRegions(intptr_t nRegions,
+ MlirRegion *regions) {
+ // Create a vector of unique pointers to regions and make sure they are not
+ // deleted when exiting the scope. This is a hack caused by C++ API expecting
+ // an list of unique pointers to regions (without ownership transfer
+ // semantics) and C API making ownership transfer explicit.
+ SmallVector<std::unique_ptr<Region>> unwrappedRegions;
+ unwrappedRegions.reserve(nRegions);
+ for (intptr_t i = 0; i < nRegions; ++i)
+ unwrappedRegions.emplace_back(unwrap(*(regions + i)));
+ auto cleaner = llvm::make_scope_exit([&]() {
+ for (auto &region : unwrappedRegions)
+ region.release();
+ });
+ return unwrappedRegions;
+}
+
+} // namespace
+
bool mlirOperationImplementsInterface(MlirOperation operation,
MlirTypeID interfaceTypeID) {
std::optional<RegisteredOperationName> info =
@@ -45,31 +96,15 @@ MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(
MlirTypesCallback callback, void *userData) {
StringRef name(opName.data, opName.length);
std::optional<RegisteredOperationName> info =
- RegisteredOperationName::lookup(name, unwrap(context));
+ getRegisteredOperationName(context, opName);
if (!info)
return mlirLogicalResultFailure();
- std::optional<Location> maybeLocation;
- if (!mlirLocationIsNull(location))
- maybeLocation = unwrap(location);
- SmallVector<Value> unwrappedOperands;
- (void)unwrapList(nOperands, operands, unwrappedOperands);
- DictionaryAttr attributeDict;
- if (!mlirAttributeIsNull(attributes))
- attributeDict = unwrap(attributes).cast<DictionaryAttr>();
-
- // Create a vector of unique pointers to regions and make sure they are not
- // deleted when exiting the scope. This is a hack caused by C++ API expecting
- // an list of unique pointers to regions (without ownership transfer
- // semantics) and C API making ownership transfer explicit.
- SmallVector<std::unique_ptr<Region>> unwrappedRegions;
- unwrappedRegions.reserve(nRegions);
- for (intptr_t i = 0; i < nRegions; ++i)
- unwrappedRegions.emplace_back(unwrap(*(regions + i)));
- auto cleaner = llvm::make_scope_exit([&]() {
- for (auto &region : unwrappedRegions)
- region.release();
- });
+ std::optional<Location> maybeLocation = maybeGetLocation(location);
+ SmallVector<Value> unwrappedOperands = unwrapOperands(nOperands, operands);
+ DictionaryAttr attributeDict = unwrapAttributes(attributes);
+ SmallVector<std::unique_ptr<Region>> unwrappedRegions =
+ unwrapRegions(nRegions, regions);
SmallVector<Type> inferredTypes;
if (failed(info->getInterface<InferTypeOpInterface>()->inferReturnTypes(
@@ -84,3 +119,51 @@ MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(
callback(wrappedInferredTypes.size(), wrappedInferredTypes.data(), userData);
return mlirLogicalResultSuccess();
}
+
+MlirTypeID mlirInferShapedTypeOpInterfaceTypeID() {
+ return wrap(InferShapedTypeOpInterface::getInterfaceID());
+}
+
+MlirLogicalResult mlirInferShapedTypeOpInterfaceInferReturnTypes(
+ MlirStringRef opName, MlirContext context, MlirLocation location,
+ intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
+ void *properties, intptr_t nRegions, MlirRegion *regions,
+ MlirShapedTypeComponentsCallback callback, void *userData) {
+ std::optional<RegisteredOperationName> info =
+ getRegisteredOperationName(context, opName);
+ if (!info)
+ return mlirLogicalResultFailure();
+
+ std::optional<Location> maybeLocation = maybeGetLocation(location);
+ SmallVector<Value> unwrappedOperands = unwrapOperands(nOperands, operands);
+ DictionaryAttr attributeDict = unwrapAttributes(attributes);
+ SmallVector<std::unique_ptr<Region>> unwrappedRegions =
+ unwrapRegions(nRegions, regions);
+
+ SmallVector<ShapedTypeComponents> inferredTypeComponents;
+ if (failed(info->getInterface<InferShapedTypeOpInterface>()
+ ->inferReturnTypeComponents(
+ unwrap(context), maybeLocation,
+ mlir::ValueRange(llvm::ArrayRef(unwrappedOperands)),
+ attributeDict, properties, unwrappedRegions,
+ inferredTypeComponents)))
+ return mlirLogicalResultFailure();
+
+ bool hasRank;
+ intptr_t rank;
+ const int64_t *shapeData;
+ for (ShapedTypeComponents t : inferredTypeComponents) {
+ if (t.hasRank()) {
+ hasRank = true;
+ rank = t.getDims().size();
+ shapeData = t.getDims().data();
+ } else {
+ hasRank = false;
+ rank = 0;
+ shapeData = nullptr;
+ }
+ callback(hasRank, rank, shapeData, wrap(t.getElementType()),
+ wrap(t.getAttribute()), userData);
+ }
+ return mlirLogicalResultSuccess();
+}
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index 75b25bd8c1c9..714935fe12e2 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -62,6 +62,7 @@ __all__ = [
"FloatAttr",
"FunctionType",
"IndexType",
+ "InferShapedTypeOpInterface",
"InferTypeOpInterface",
"InsertionPoint",
"IntegerAttr",
@@ -88,6 +89,7 @@ __all__ = [
"RegionIterator",
"RegionSequence",
"ShapedType",
+ "ShapedTypeComponents",
"StringAttr",
"SymbolTable",
"TupleType",
@@ -689,9 +691,17 @@ class IndexType(Type):
@staticmethod
def isinstance(arg: Any) -> bool: ...
+class InferShapedTypeOpInterface:
+ def __init__(self, object: object, context: Optional[Context] = None) -> None: ...
+ def inferReturnTypeComponents(self, operands: Optional[List] = None, attributes: Optional[Attribute] = None, properties = None, regions: Optional[List[Region]] = None, context: Optional[Context] = None, loc: Optional[Location] = None) -> List[ShapedTypeComponents]: ...
+ @property
+ def operation(self) -> Operation: ...
+ @property
+ def opview(self) -> OpView: ...
+
class InferTypeOpInterface:
def __init__(self, object: object, context: Optional[Context] = None) -> None: ...
- def inferReturnTypes(self, operands: Optional[List] = None, attributes: Optional[Attribute] = None, regions: Optional[List[Region]] = None, context: Optional[Context] = None, loc: Optional[Location] = None) -> List[Type]: ...
+ def inferReturnTypes(self, operands: Optional[List] = None, attributes: Optional[Attribute] = None, properties = None, regions: Optional[List[Region]] = None, context: Optional[Context] = None, loc: Optional[Location] = None) -> List[Type]: ...
@property
def operation(self) -> Operation: ...
@property
@@ -1016,6 +1026,18 @@ class ShapedType(Type):
@property
def shape(self) -> List[int]: ...
+class ShapedTypeComponents:
+ @property
+ def element_type(self) -> Type: ...
+ @staticmethod
+ def get(*args, **kwargs) -> ShapedTypeComponents: ...
+ @property
+ def has_rank(self) -> bool: ...
+ @property
+ def rank(self) -> int: ...
+ @property
+ def shape(self) -> List[int]: ...
+
# TODO: Auto-generated. Audit and fix.
class StringAttr(Attribute):
def __init__(self, cast_from_attr: Attribute) -> None: ...
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index d826540bec1d..8280e5ec73a7 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -1,6 +1,7 @@
# RUN: %PYTHON %s | FileCheck %s
from mlir.ir import *
+import mlir.dialects.func as func
import mlir.dialects.python_test as test
import mlir.dialects.tensor as tensor
@@ -330,3 +331,55 @@ def testTensorValue():
# CHECK: False
print(tt.is_null())
+
+
+# CHECK-LABEL: TEST: inferReturnTypeComponents
+@run
+def inferReturnTypeComponents():
+ with Context() as ctx, Location.unknown(ctx):
+ test.register_python_test_dialect(ctx)
+ module = Module.create()
+ i32 = IntegerType.get_signless(32)
+ with InsertionPoint(module.body):
+ resultType = UnrankedTensorType.get(i32)
+ operandTypes = [
+ RankedTensorType.get([1, 3, 10, 10], i32),
+ UnrankedTensorType.get(i32),
+ ]
+ f = func.FuncOp(
+ "test_inferReturnTypeComponents", (operandTypes, [resultType])
+ )
+ entry_block = Block.create_at_start(f.operation.regions[0], operandTypes)
+ with InsertionPoint(entry_block):
+ ranked_op = test.InferShapedTypeComponentsOp(
+ resultType, entry_block.arguments[0]
+ )
+ unranked_op = test.InferShapedTypeComponentsOp(
+ resultType, entry_block.arguments[1]
+ )
+
+ # CHECK: has rank: True
+ # CHECK: rank: 4
+ # CHECK: element type: i32
+ # CHECK: shape: [1, 3, 10, 10]
+ iface = InferShapedTypeOpInterface(ranked_op)
+ shaped_type_components = iface.inferReturnTypeComponents(
+ operands=[ranked_op.operand]
+ )[0]
+ print("has rank:", shaped_type_components.has_rank)
+ print("rank:", shaped_type_components.rank)
+ print("element type:", shaped_type_components.element_type)
+ print("shape:", shaped_type_components.shape)
+
+ # CHECK: has rank: False
+ # CHECK: rank: None
+ # CHECK: element type: i32
+ # CHECK: shape: None
+ iface = InferShapedTypeOpInterface(unranked_op)
+ shaped_type_components = iface.inferReturnTypeComponents(
+ operands=[unranked_op.operand]
+ )[0]
+ print("has rank:", shaped_type_components.has_rank)
+ print("rank:", shaped_type_components.rank)
+ print("element type:", shaped_type_components.element_type)
+ print("shape:", shaped_type_components.shape)
diff --git a/mlir/test/python/python_test_ops.td b/mlir/test/python/python_test_ops.td
index 692fcb938961..e1a03c6ee217 100644
--- a/mlir/test/python/python_test_ops.td
+++ b/mlir/test/python/python_test_ops.td
@@ -90,6 +90,33 @@ def InferResultsImpliedOp : TestOp<"infer_results_implied_op"> {
let results = (outs I32:$integer, F64:$flt, Index:$index);
}
+def InferShapedTypeComponentsOp : TestOp<"infer_shaped_type_components_op",
+ [DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+ ["inferReturnTypeComponents"]>]> {
+ let arguments = (ins AnyTensor:$operand);
+ let results = (outs AnyTensor:$result);
+
+ let extraClassDefinition = [{
+ ::mlir::LogicalResult $cppClass::inferReturnTypeComponents(
+ ::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location,
+ ::mlir::ValueShapeRange operands, ::mlir::DictionaryAttr attributes,
+ ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
+ ::llvm::SmallVectorImpl<
+ ::mlir::ShapedTypeComponents>& inferredShapedTypeComponents) {
+ $cppClass::Adaptor adaptor(operands, attributes, properties, regions);
+ auto operandType =
+ adaptor.getOperand().getType().cast<::mlir::ShapedType>();
+ if (operandType.hasRank()) {
+ inferredShapedTypeComponents.emplace_back(operandType.getShape(),
+ operandType.getElementType());
+ } else {
+ inferredShapedTypeComponents.emplace_back(operandType.getElementType());
+ }
+ return ::mlir::success();
+ }
+ }];
+}
+
def SameOperandAndResultTypeOp : TestOp<"same_operand_and_result_type_op",
[SameOperandsAndResultType]> {
let arguments = (ins Variadic<AnyType>);