author     Diego Caballero <diegocaballero@google.com>  2022-11-24 02:16:46 +0000
committer  Diego Caballero <diegocaballero@google.com>  2022-12-13 01:33:06 +0000
commit     72fd36448d7c97c9dab094d6deda401d97baf0ef
tree       809c57441d260dab1e09864b55a0d1897b82b902 /mlir
parent     6893b151906ae9f4112197a979ed94bbab08a32e
[mlir][Vector] Initial masking support in Linalg vectorizer
This patch introduces the initial bits to support vector masking using the
`vector.mask` operation. Vectorization changes should be NFC for non-masked
cases. We can't test masked cases directly until we extend the Transform
dialect to support masking.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D137690
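For readers unfamiliar with the new op: `vector.mask` wraps a single maskable vector operation in a region and applies an i1 vector mask to it. A minimal sketch of the idea, with a hypothetical function and values (modeled on the vectorization.mlir tests added at the end of this patch):

  func.func @masked_read_sketch(%t: tensor<?xf32>) -> vector<4xf32> {
    %c0 = arith.constant 0 : index
    %pad = arith.constant 0.0 : f32
    // The runtime size of the dynamic dimension drives the mask.
    %dim = tensor.dim %t, %c0 : tensor<?xf32>
    // Enable only the first %dim lanes of a 4-wide vector.
    %mask = vector.create_mask %dim : vector<4xi1>
    // The masked read only touches the enabled lanes.
    %v = vector.mask %mask { vector.transfer_read %t[%c0], %pad {in_bounds = [true]} : tensor<?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
    return %v : vector<4xf32>
  }

This is the shape of IR the vectorizer emits around transfer reads/writes (and other maskable ops) when the chosen vector size may exceed the runtime iteration space size.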
Diffstat (limited to 'mlir')
-rw-r--r--  mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td | 30
-rw-r--r--  mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td | 41
-rw-r--r--  mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h | 16
-rw-r--r--  mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 10
-rw-r--r--  mlir/include/mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td | 18
-rw-r--r--  mlir/include/mlir/Dialect/Vector/Transforms/Passes.h | 6
-rw-r--r--  mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp | 83
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp | 551
-rw-r--r--  mlir/lib/Dialect/Vector/IR/CMakeLists.txt | 2
-rw-r--r--  mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 32
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp | 18
-rw-r--r--  mlir/test/Dialect/Linalg/vectorization.mlir | 148
12 files changed, 820 insertions, 135 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index e5006ae8c4ae..a04c48f5fa4b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -591,6 +591,36 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
return result;
}]
>,
+ InterfaceMethod<
+ /*desc=*/[{
+      Given a dimension of the iteration space of a Linalg operation, finds
+      an operand in the operation that is defined on that dimension. Returns
+      whether such an operand was found. If so, it also returns the operand
+      value and the dimension's position within the operand.
+ }],
+ /*retTy=*/"LogicalResult",
+ /*methodName=*/"mapIterationSpaceDimToOperandDim",
+ /*args=*/(ins "unsigned":$dimPos,
+ "::mlir::Value &":$operand,
+ "unsigned &":$operandDimPos),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ // Retrieve the operand and its dimension position from the first
+ // operand with a permutation map that is defined on such dimension.
+ for (auto [i, idxMap] : llvm::enumerate($_op.getIndexingMapsArray())) {
+ if (idxMap.isProjectedPermutation()) {
+ if (auto mayOperandDim = idxMap.getResultPosition(
+ getAffineDimExpr(dimPos, idxMap.getContext()))) {
+ operand = $_op->getOperand(i);
+ operandDimPos = *mayOperandDim;
+ return success();
+ }
+ }
+ }
+
+ return failure();
+ }]
+ >,
//===------------------------------------------------------------------===//
// Linalg generalization hooks.
//===------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 9fe6536f23f6..9c80f56332ce 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1115,4 +1115,45 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
}];
}
+def MaskedVectorizeOp : Op<Transform_Dialect, "structured.masked_vectorize",
+ [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ TransformOpInterface]> {
+ let description = [{
+ Vectorize the target ops, which must be Linalg ops, with masked vectors
+    of the specified sizes.
+
+ The vector sizes can be either static or dynamic (SSA values). In case of
+ SSA values, the handle must be mapped to exactly one payload op with
+ exactly one index-typed result.
+
+ #### Return modes:
+
+ This operation produces a definite failure if the dynamic vector sizes (SSA
+    values) do not satisfy the constraints mentioned above. It produces a
+ silenceable failure if at least one target op is not a Linalg op or fails to
+ vectorize.
+ }];
+
+ let arguments = (ins PDL_Operation:$target,
+ Variadic<PDL_Operation>:$vector_sizes,
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
+ $static_vector_sizes);
+ let results = (outs);
+ let assemblyFormat = [{
+ $target
+ `vector_sizes` custom<DynamicIndexList>($vector_sizes,
+ $static_vector_sizes)
+ attr-dict
+ }];
+
+ let extraClassDeclaration = [{
+ // TODO: applyToOne.
+ ::mlir::DiagnosedSilenceableFailure apply(
+ ::mlir::transform::TransformResults &transformResults,
+ ::mlir::transform::TransformState &state);
+
+ ::llvm::SmallVector<::mlir::OpFoldResult> getMixedVectorSizes();
+ }];
+}
+
#endif // LINALG_TRANSFORM_OPS
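A minimal usage sketch of the new transform op, mirroring the vectorization.mlir tests added at the end of this patch (the matched op kind and the vector sizes are illustrative):

  transform.sequence failures(propagate) {
  ^bb1(%arg1: !pdl.operation):
    // Match the payload ops to vectorize, then masked-vectorize them with a
    // canonical vector shape of 4x8.
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
    transform.structured.masked_vectorize %0 vector_sizes [4, 8]
  }

Dynamic sizes would instead be passed as handles to payload ops producing a constant index result, as described in the return modes above.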
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 7d6a58431979..7b3e0726effc 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -344,8 +344,14 @@ promoteSubviewAsNewBuffer(OpBuilder &b, Location loc, memref::SubViewOp subView,
FailureOr<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
const LinalgPromotionOptions &options);
-/// Emit a suitable vector form for a Linalg op with fully static shape.
-LogicalResult vectorize(RewriterBase &builder, LinalgOp linalgOp,
+/// Emit a suitable vector form for a Linalg op. If provided, `inputVectorSizes`
+/// are used to vectorize this operation. `inputVectorSizes` must match the rank
+/// of the iteration space of the operation, and each size must be greater than
+/// or equal to its counterpart iteration space size, if that size is static.
+/// `inputVectorSizes` also allows the vectorization of operations with dynamic
+/// shapes.
+LogicalResult vectorize(RewriterBase &rewriter, LinalgOp linalgOp,
+ ArrayRef<int64_t> inputVectorSizes = {},
bool vectorizeNDExtract = false);
/// Emit a suitable vector form for a Copy op with fully static shape.
@@ -372,8 +378,10 @@ LogicalResult promoteSubviewsPrecondition(Operation *op,
LinalgPromotionOptions options);
/// Return success if the operation can be vectorized.
-LogicalResult vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
- bool vectorizeNDExtract = false);
+LogicalResult
+vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
+ ArrayRef<int64_t> inputVectorSizes = {},
+ bool vectorizeNDExtract = false);
//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index f035ab3dbc04..176d70942a08 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -450,13 +450,13 @@ def Vector_BroadcastOp :
/// source tensor and thus correspond to "dim-1" broadcasting.
llvm::SetVector<int64_t> computeBroadcastedUnitDims();
- /// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
+ /// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
/// `broadcastedDims` dimensions in the dstShape are broadcasted.
- /// This requires (and asserts) that the broadcast is free of dim-1
+ /// This requires (and asserts) that the broadcast is free of dim-1
/// broadcasting.
/// Since vector.broadcast only allows expanding leading dimensions, an extra
/// vector.transpose may be inserted to make the broadcast possible.
- /// `value`, `dstShape` and `broadcastedDims` must be properly specified or
+ /// `value`, `dstShape` and `broadcastedDims` must be properly specified or
/// the helper will assert. This means:
/// 1. `dstShape` must not be empty.
/// 2. `broadcastedDims` must be confined to [0 .. rank(value.getVectorType)]
@@ -1179,6 +1179,8 @@ def Vector_ExtractStridedSliceOp :
let assemblyFormat = "$vector attr-dict `:` type($vector) `to` type(results)";
}
+// TODO: Tighten semantics so that masks and inbounds can't be used
+// simultaneously within the same transfer op.
def Vector_TransferReadOp :
Vector_Op<"transfer_read", [
DeclareOpInterfaceMethods<VectorTransferOpInterface>,
@@ -1394,6 +1396,8 @@ def Vector_TransferReadOp :
let hasVerifier = 1;
}
+// TODO: Tighten semantics so that masks and inbounds can't be used
+// simultaneously within the same transfer op.
def Vector_TransferWriteOp :
Vector_Op<"transfer_write", [
DeclareOpInterfaceMethods<VectorTransferOpInterface>,
diff --git a/mlir/include/mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td b/mlir/include/mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td
index bbde7bc33bf0..184ca68d24b6 100644
--- a/mlir/include/mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td
@@ -31,7 +31,9 @@ def MaskableOpInterface : OpInterface<"MaskableOpInterface"> {
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return mlir::isa<mlir::vector::MaskingOpInterface>($_op->getParentOp());
+ mlir::Operation *parentOp = $_op->getParentOp();
+ return parentOp &&
+ mlir::isa<mlir::vector::MaskingOpInterface>(parentOp);
}]>,
InterfaceMethod<
/*desc=*/"Returns the MaskingOpInterface masking this operation.",
@@ -54,18 +56,14 @@ def MaskableOpInterface : OpInterface<"MaskableOpInterface"> {
return false;
}]>,
InterfaceMethod<
- /*desc=*/"Returns the mask type expected by this operation. It requires "
- "the operation to be vectorized.",
- /*retTy=*/"mlir::VectorType",
+ /*desc=*/"Returns the mask type expected by this operation. Mostly used"
+ " for verification purposes. It requires the operation to be "
+ "vectorized.",
+ /*retTy=*/"mlir::Type",
/*methodName=*/"getExpectedMaskType",
/*args=*/(ins),
/*methodBody=*/"",
- /*defaultImplementation=*/[{
- // Default implementation is only aimed for operations that implement the
- // `getVectorType()` method.
- return $_op.getVectorType().cloneWith(/*shape=*/std::nullopt,
- IntegerType::get($_op.getContext(), /*width=*/1));
- }]>,
+ /*defaultImplementation=*/"">,
];
}
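Dropping the default implementation matters in particular for transfer ops with permutation maps, where the expected mask shape follows the permuted source-side access rather than the result vector shape (see the TransferReadOp/TransferWriteOp implementations further down). A sketch, assuming a transposed read like the one in the 2-D transpose test added below (the function and values are illustrative): the read yields a vector<4x8xf32> but is verified against a vector<8x4xi1> mask.

  func.func @transposed_masked_read_sketch(%t: tensor<?x?xf32>) -> vector<4x8xf32> {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %pad = arith.constant 0.0 : f32
    %d0 = tensor.dim %t, %c0 : tensor<?x?xf32>
    %d1 = tensor.dim %t, %c1 : tensor<?x?xf32>
    // The mask shape (8x4) follows the tensor side of the permutation map,
    // while the read produces the vector-side shape (4x8).
    %mask = vector.create_mask %d0, %d1 : vector<8x4xi1>
    %v = vector.mask %mask { vector.transfer_read %t[%c0, %c0], %pad {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : tensor<?x?xf32>, vector<4x8xf32> } : vector<8x4xi1> -> vector<4x8xf32>
    return %v : vector<4x8xf32>
  }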
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
index bf89b01e2b60..d0c06f69930d 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
@@ -22,6 +22,12 @@ std::unique_ptr<Pass> createVectorBufferizePass();
/// Creates an instance of the `vector.mask` lowering pass.
std::unique_ptr<Pass> createLowerVectorMaskPass();
+/// Populates instances of `MaskOpRewritePattern` to lower masked operations
+/// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and
+/// not its nested `MaskableOpInterface`.
+void populateVectorMaskLoweringPatternsForSideEffectingOps(
+ RewritePatternSet &patterns);
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index dc6baddadf43..79738d1737a1 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -21,6 +21,7 @@
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformUtils.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -1825,7 +1826,8 @@ struct VectorizationPattern : public RewritePattern {
LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
if (!linalgOp)
return rewriter.notifyMatchFailure(op, "expected Linalg Op");
- return vectorize(rewriter, linalgOp, vectorizeNDExtract);
+ return vectorize(rewriter, linalgOp, /*inputVectorSizes=*/{},
+ vectorizeNDExtract);
}
private:
@@ -1874,6 +1876,85 @@ transform::VectorizeOp::applyToOne(Operation *target,
}
//===----------------------------------------------------------------------===//
+// MaskedVectorizeOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::MaskedVectorizeOp::apply(
+ mlir::transform::TransformResults &transformResults,
+ mlir::transform::TransformState &state) {
+ IRRewriter rewriter(getContext());
+ ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
+ if (targets.empty())
+ return DiagnosedSilenceableFailure::success();
+
+ SmallVector<int64_t> vectorSizes;
+ for (OpFoldResult sz : getMixedVectorSizes()) {
+ if (sz.is<Attribute>()) {
+ auto attr = sz.get<Attribute>();
+ vectorSizes.push_back(attr.cast<IntegerAttr>().getInt());
+ continue;
+ }
+
+ ArrayRef<Operation *> szPayloads = state.getPayloadOps(sz.get<Value>());
+ if (szPayloads.size() != 1) {
+ auto diag = this->emitOpError(
+ "requires vector size handle that is mapped to 1 payload op");
+ diag.attachNote(sz.get<Value>().getLoc())
+ << "mapped to " << szPayloads.size() << " payload ops";
+ return DiagnosedSilenceableFailure::definiteFailure();
+ }
+
+ Operation *szPayloadOp = szPayloads[0];
+ if (szPayloadOp->getNumResults() != 1 ||
+ !szPayloadOp->getResult(0).getType().isIndex()) {
+ auto diag = this->emitOpError(
+ "requires vector size payload op with 1 index result");
+ diag.attachNote(szPayloadOp->getLoc()) << "vector size payload op";
+ return DiagnosedSilenceableFailure::definiteFailure();
+ }
+
+ IntegerAttr attr;
+ if (!matchPattern(szPayloadOp->getResult(0), m_Constant(&attr))) {
+ auto diag = this->emitOpError("requires constant vector size");
+ diag.attachNote(szPayloadOp->getLoc()) << "vector size payload op";
+ return DiagnosedSilenceableFailure::definiteFailure();
+ }
+
+ vectorSizes.push_back(attr.getInt());
+ }
+
+ // TODO: Check that the correct number of vectorSizes was provided.
+
+ for (Operation *target : targets) {
+ auto linalgOp = dyn_cast<LinalgOp>(target);
+ if (!linalgOp) {
+ Diagnostic diag(target->getLoc(), DiagnosticSeverity::Error);
+ diag << "cannot vectorize non-Linalg op";
+ return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
+ }
+
+ if (failed(linalg::vectorize(rewriter, linalgOp, vectorSizes))) {
+ Diagnostic diag(target->getLoc(), DiagnosticSeverity::Error);
+ diag << "failed to vectorize op";
+ return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
+ }
+ }
+
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform::MaskedVectorizeOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ consumesHandle(getTarget(), effects);
+ onlyReadsHandle(getVectorSizes(), effects);
+}
+
+SmallVector<OpFoldResult> MaskedVectorizeOp::getMixedVectorSizes() {
+ OpBuilder b(getContext());
+ return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b);
+}
+
+//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index a7c3c0094889..89140d42dd2f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -11,25 +11,20 @@
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/SliceAnalysis.h"
-#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/RegionUtils.h"
-#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -65,6 +60,266 @@ static OpType getSingleOpOfType(Block &block) {
return res;
}
+/// Contains the vectorization state and related methods used across the
+/// vectorization process of a given operation.
+struct VectorizationState {
+ VectorizationState(RewriterBase &rewriter) : rewriterGuard(rewriter) {}
+
+ /// Initializes the vectorization state, including the computation of the
+ /// canonical vector shape for vectorization.
+ LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
+ ArrayRef<int64_t> inputVectorSizes);
+
+ /// Returns the canonical vector shape used to vectorize the iteration space.
+ ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
+
+ /// Masks an operation with the canonical vector mask if the operation needs
+ /// masking. Returns the masked operation or the original operation if masking
+ /// is not needed. If provided, the canonical mask for this operation is
+ /// permuted using `maybeMaskingMap`.
+ Operation *maskOperation(RewriterBase &rewriter, Operation *opToMask,
+ LinalgOp linalgOp,
+ Optional<AffineMap> maybeMaskingMap = std::nullopt);
+
+private:
+ /// Initializes the iteration space static sizes using the Linalg op
+ /// information. This may become more complicated in the future.
+ void initIterSpaceStaticSizes(LinalgOp linalgOp) {
+ iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
+ }
+
+ /// Generates 'tensor.dim' operations for all the dynamic dimensions of the
+  /// iteration space to be vectorized and stores them in
+ /// `iterSpaceDynamicSizes`.
+ LogicalResult precomputeIterSpaceDynamicSizes(RewriterBase &rewriter,
+ LinalgOp linalgOp);
+
+ /// Create or retrieve an existing mask value to mask `opToMask` in the
+  /// canonical vector iteration space. If `maybeMaskingMap` is provided, the
+  /// mask is permuted using that permutation map. If a new mask is created, it
+  /// will be cached for future users.
+ Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
+ LinalgOp linalgOp,
+ Optional<AffineMap> maybeMaskingMap);
+
+  /// Holds the compile-time static sizes of the iteration space to vectorize.
+  /// Dynamic dimensions are represented using ShapedType::kDynamicSize.
+ SmallVector<int64_t> iterSpaceStaticSizes;
+
+  /// Holds the runtime sizes of the iteration space to vectorize. Static
+  /// dimensions are represented with an empty value.
+ SmallVector<Value> iterSpaceDynamicSizes;
+
+ /// Holds the canonical vector shape used to vectorize the iteration space.
+ SmallVector<int64_t> canonicalVecShape;
+
+ /// Holds the active masks for permutations of the canonical vector iteration
+ /// space.
+ DenseMap<AffineMap, Value> activeMaskCache;
+
+ /// Global vectorization guard for the incoming rewriter. It's initialized
+ /// when the vectorization state is initialized.
+ OpBuilder::InsertionGuard rewriterGuard;
+};
+
+/// Generates 'tensor.dim' operations for all the dynamic dimensions of the
+/// iteration space to be vectorized and stores them in
+/// `iterSpaceDynamicSizes`.
+LogicalResult
+VectorizationState::precomputeIterSpaceDynamicSizes(RewriterBase &rewriter,
+ LinalgOp linalgOp) {
+ // TODO: Support 0-d vectors.
+ for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
+ if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) {
+      // Add an empty value for static dimensions.
+ iterSpaceDynamicSizes.push_back(Value());
+ continue;
+ }
+
+ // Find an operand defined on this dimension of the iteration space to
+ // extract the runtime dimension size.
+ Value operand;
+ unsigned operandDimPos;
+ if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
+ operandDimPos)))
+ return failure();
+
+ Value dynamicDim = linalgOp.hasTensorSemantics()
+ ? (Value)rewriter.create<tensor::DimOp>(
+ linalgOp.getLoc(), operand, operandDimPos)
+ : (Value)rewriter.create<memref::DimOp>(
+ linalgOp.getLoc(), operand, operandDimPos);
+ iterSpaceDynamicSizes.push_back(dynamicDim);
+ }
+
+ return success();
+}
+
+/// Initializes the vectorization state, including the computation of the
+/// canonical vector shape for vectorization.
+// TODO: Move this to the constructor when we can remove the failure cases.
+LogicalResult
+VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
+ ArrayRef<int64_t> inputVectorSizes) {
+ // Initialize the insertion point.
+ rewriter.setInsertionPoint(linalgOp);
+
+ if (!inputVectorSizes.empty()) {
+ // Get the canonical vector shape from the input vector sizes provided. This
+ // path should be taken to vectorize code with dynamic shapes and when using
+ // vector sizes greater than the iteration space sizes.
+ canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
+ } else {
+ // Compute the canonical vector shape from the operation shape. If there are
+ // dynamic shapes, the operation won't be vectorized.
+ canonicalVecShape = linalgOp.getStaticLoopRanges();
+ }
+
+ LDBG("Canonical vector shape: ");
+ LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
+ LLVM_DEBUG(llvm::dbgs() << "\n");
+
+ // Initialize iteration space static sizes.
+ initIterSpaceStaticSizes(linalgOp);
+
+ // Extract and register the runtime value of any potential dynamic shape
+ // needed to compute a mask during vectorization.
+ if (failed(precomputeIterSpaceDynamicSizes(rewriter, linalgOp)))
+ return failure();
+
+ if (ShapedType::isDynamicShape(canonicalVecShape))
+ return failure();
+ return success();
+}
+
+/// Create or retrieve an existing mask value to mask `opToMask` in the
+/// canonical vector iteration space. If `maybeMaskingMap` is provided, the mask
+/// is permuted using that permutation map. If a new mask is created, it will be
+/// cached for future users.
+Value VectorizationState::getOrCreateMaskFor(
+ RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
+ Optional<AffineMap> maybeMaskingMap) {
+ // No mask is needed if the operation is not maskable.
+ auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
+ if (!maskableOp)
+ return Value();
+
+ assert(!maskableOp.isMasked() &&
+ "Masking an operation that is already masked");
+
+ // If no masking map was provided, use an identity map with the loop dims.
+ assert((!maybeMaskingMap || *maybeMaskingMap) &&
+ "Unexpected null mask permutation map");
+ AffineMap maskingMap =
+ maybeMaskingMap ? *maybeMaskingMap
+ : AffineMap::getMultiDimIdentityMap(
+ linalgOp.getNumLoops(), rewriter.getContext());
+
+ LDBG("Masking map: " << maskingMap << "\n");
+
+ // Return the active mask for the masking map of this operation if it was
+ // already created.
+ auto activeMaskIt = activeMaskCache.find(maskingMap);
+ if (activeMaskIt != activeMaskCache.end()) {
+ Value mask = activeMaskIt->second;
+ LDBG("Reusing mask: " << mask << "\n");
+ return mask;
+ }
+
+ // Compute permuted projection of the iteration space to be masked and the
+ // corresponding mask shape. If the resulting iteration space dimensions are
+ // static and identical to the mask shape, masking is not needed for this
+ // operation.
+ // TODO: Improve this check. Only projected permutation indexing maps are
+ // supported.
+ SmallVector<int64_t> permutedStaticSizes =
+ applyPermutationMap(maskingMap, ArrayRef<int64_t>(iterSpaceStaticSizes));
+ SmallVector<int64_t> maskShape =
+ applyPermutationMap(maskingMap, ArrayRef<int64_t>(canonicalVecShape));
+ LDBG("Mask shape: ");
+ LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));
+ LLVM_DEBUG(llvm::dbgs() << "\n");
+
+ if (permutedStaticSizes == maskShape) {
+ LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
+ activeMaskCache[maskingMap] = Value();
+ return Value();
+ }
+
+ // Compute the mask upper bound values by combining the permuted iteration
+ // space static sizes and the dynamic values.
+ SmallVector<Value> permutedDynamicSizes =
+ applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceDynamicSizes));
+ SmallVector<Value> upperBounds;
+ for (auto [staticBound, dynBound] :
+ llvm::zip(permutedStaticSizes, permutedDynamicSizes))
+ upperBounds.push_back(ShapedType::isDynamic(staticBound)
+ ? dynBound
+ : rewriter.create<arith::ConstantIndexOp>(
+ linalgOp.getLoc(), staticBound));
+
+ assert(!maskShape.empty() && !upperBounds.empty() &&
+ "Masked 0-d vectors are not supported yet");
+
+ // Create the mask based on the dimension size values.
+ auto maskType = VectorType::get(maskShape, rewriter.getI1Type());
+ Value mask = rewriter.create<vector::CreateMaskOp>(linalgOp.getLoc(),
+ maskType, upperBounds);
+ LDBG("Creating new mask: " << mask << "\n");
+ activeMaskCache[maskingMap] = mask;
+ return mask;
+}
+
+/// Masks an operation with the canonical vector mask if the operation needs
+/// masking. Returns the masked operation or the original operation if masking
+/// is not needed. If provided, the canonical mask for this operation is
+/// permuted using `maybeMaskingMap`.
+Operation *
+VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
+ LinalgOp linalgOp,
+ Optional<AffineMap> maybeMaskingMap) {
+ LDBG("Trying to mask: " << *opToMask << "\n");
+
+ // Create or retrieve mask for this operation.
+ Value mask =
+ getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
+
+ if (!mask) {
+ LDBG("No mask required\n");
+ return opToMask;
+ }
+
+  // Wrap the operation with a new `vector.mask` and update its def-use chain.
+ assert(opToMask && "Expected a valid operation to mask");
+ auto opResults = opToMask->getResultTypes();
+ auto createRegionMask = [opToMask](OpBuilder &builder, Location loc) {
+ Block *insBlock = builder.getInsertionBlock();
+    // Move the op to be masked into the new mask region block by splicing it
+    // out of its current block.
+    // TODO: Look for a utility (maybe in the conversion pattern rewriter) that
+    // avoids the explicit splice and insertion point juggling.
+ insBlock->getOperations().splice(
+ insBlock->begin(), opToMask->getBlock()->getOperations(), opToMask);
+ builder.create<vector::YieldOp>(loc, opToMask->getResults());
+ };
+ // TODO: Allow multiple results in vector.mask.
+ auto maskOp =
+ opResults.empty()
+ ? rewriter.create<vector::MaskOp>(opToMask->getLoc(), mask,
+ createRegionMask)
+ : rewriter.create<vector::MaskOp>(opToMask->getLoc(),
+ opToMask->getResultTypes().front(),
+ mask, createRegionMask);
+
+ Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();
+
+ for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
+ rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
+ maskOpTerminator);
+
+ LDBG("Masked operation: " << *maskOp << "\n");
+ return maskOp;
+}
+
/// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
/// projectedPermutation, compress the unused dimensions to serve as a
/// permutation_map for a vector transfer operation.
@@ -204,35 +459,44 @@ static SmallVector<bool> getReductionMask(LinalgOp linalgOp) {
/// Return the produced value or null if no value is produced.
// Note: this is a true builder that notifies the OpBuilder listener.
// TODO: Consider moving as a static helper on the ReduceOp.
-static Value buildVectorWrite(OpBuilder &b, Value value,
- OpOperand *outputOperand) {
- Operation *write;
+static Value buildVectorWrite(RewriterBase &rewriter, Value value,
+ OpOperand *outputOperand,
+ VectorizationState &state) {
Location loc = value.getLoc();
auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
- ArrayRef<int64_t> shape = linalgOp.getShape(outputOperand);
- auto vectorType = VectorType::get(
- shape, getElementTypeOrSelf(outputOperand->get().getType()));
+ AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
+ auto vectorType =
+ VectorType::get(opOperandMap.compose(state.getCanonicalVecShape()),
+ getElementTypeOrSelf(outputOperand->get().getType()));
+
+ Operation *write;
if (vectorType.getRank() > 0) {
- // 0-d case is still special: do not invert the reindexing map.
- AffineMap map =
- reindexIndexingMap(linalgOp.getMatchingIndexingMap(outputOperand));
- SmallVector<int64_t> transposeShape =
- applyPermutationMap(inversePermutation(map), vectorType.getShape());
- assert(!transposeShape.empty() && "unexpected empty transpose shape");
- vectorType = VectorType::get(transposeShape, vectorType.getElementType());
+ AffineMap writeMap = reindexIndexingMap(opOperandMap);
SmallVector<Value> indices(linalgOp.getRank(outputOperand),
- b.create<arith::ConstantIndexOp>(loc, 0));
- value = broadcastIfNeeded(b, value, vectorType.getShape());
- write = b.create<vector::TransferWriteOp>(
- loc, value, outputOperand->get(), indices, map);
+ rewriter.create<arith::ConstantIndexOp>(loc, 0));
+ value = broadcastIfNeeded(rewriter, value, vectorType.getShape());
+ write = rewriter.create<vector::TransferWriteOp>(
+ loc, value, outputOperand->get(), indices, writeMap);
} else {
+ // 0-d case is still special: do not invert the reindexing writeMap.
if (!value.getType().isa<VectorType>())
- value = b.create<vector::BroadcastOp>(loc, vectorType, value);
+ value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value);
assert(value.getType() == vectorType && "incorrect type");
- write = b.create<vector::TransferWriteOp>(
+ write = rewriter.create<vector::TransferWriteOp>(
loc, value, outputOperand->get(), ValueRange{});
}
- LDBG("vectorized op: " << *write);
+
+ write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
+
+ // If masked, set in-bounds to true. Masking guarantees that the access will
+ // be in-bounds.
+ if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
+ auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
+ SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
+ maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
+ }
+
+ LDBG("vectorized op: " << *write << "\n");
if (!write->getResults().empty())
return write->getResult(0);
return Value();
@@ -259,20 +523,22 @@ using CustomVectorizationHook = std::function<VectorizationResult(
/// CustomVectorizationHook.
static VectorizationResult
vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
- const BlockAndValueMapping &bvm, LinalgOp linalgOp,
- SmallVectorImpl<Value> &newResults) {
+ const BlockAndValueMapping &bvm, VectorizationState &state,
+ LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
auto yieldOp = dyn_cast<linalg::YieldOp>(op);
if (!yieldOp)
return VectorizationResult{VectorizationStatus::Failure, nullptr};
- for (const auto &outputs : llvm::enumerate(yieldOp.getValues())) {
+ for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
// TODO: Scan for an opportunity for reuse.
// TODO: use a map.
- Value vectorValue = bvm.lookup(outputs.value());
- Value newResult = buildVectorWrite(
- rewriter, vectorValue, linalgOp.getDpsInitOperand(outputs.index()));
+ Value vectorValue = bvm.lookup(output.value());
+ Value newResult =
+ buildVectorWrite(rewriter, vectorValue,
+ linalgOp.getDpsInitOperand(output.index()), state);
if (newResult)
newResults.push_back(newResult);
}
+
return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
}
@@ -464,7 +730,7 @@ static VectorizationResult
vectorizeOneOp(RewriterBase &rewriter, LinalgOp linalgOp, Operation *op,
const BlockAndValueMapping &bvm,
ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
- LDBG("vectorize op " << *op);
+ LDBG("vectorize op " << *op << "\n");
// 1. Try to apply any CustomVectorizationHook.
if (!customVectorizationHooks.empty()) {
@@ -561,8 +827,10 @@ vectorizeOneOp(RewriterBase &rewriter, LinalgOp linalgOp, Operation *op,
/// This is not deemed a problem as we expect canonicalizations and foldings to
/// aggressively clean up the useless work.
static LogicalResult
-vectorizeAsLinalgGeneric(RewriterBase &rewriter, LinalgOp linalgOp,
+vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
+ LinalgOp linalgOp,
SmallVectorImpl<Value> &newResults) {
+ LDBG("Vectorizing operation as linalg generic\n");
Block *block = linalgOp.getBlock();
// 2. Values defined above the region can only be broadcast for now. Make them
@@ -575,11 +843,6 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, LinalgOp linalgOp,
if (linalgOp.getNumDpsInits() == 0)
return failure();
- // TODO: the common vector shape is equal to the static loop sizes only when
- // all indexing maps are projected permutations. For convs and stencils the
- // logic will need to evolve.
- SmallVector<int64_t> commonVectorShape = linalgOp.computeStaticLoopSizes();
-
// 3. Turn all BBArgs into vector.transfer_read / load.
Location loc = linalgOp.getLoc();
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
@@ -589,35 +852,60 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, LinalgOp linalgOp,
bvm.map(bbarg, opOperand->get());
continue;
}
- VectorType readType;
- AffineMap map;
- // TODO: can we keep this simplification?
- // if (linalgOp.getShape(&opOperand).empty()) {
- // readType = VectorType::get({}, bbarg.getType());
- // } else {
- if (opOperand->getOperandNumber() < linalgOp.getNumDpsInputs()) {
- map = inverseAndBroadcastProjectedPermutation(
- linalgOp.getMatchingIndexingMap(opOperand));
- readType = VectorType::get(commonVectorShape,
- getElementTypeOrSelf(opOperand->get()));
+
+ // 3.a. Convert the indexing map for this input/output to a transfer read
+ // permutation map and masking map.
+ AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
+
+    // Drop the constant-zero results from the indexing map so it can be used
+    // as the masking map.
+ SmallVector<int64_t> zeroPos;
+ auto results = indexingMap.getResults();
+ for (auto result : llvm::enumerate(results)) {
+ if (result.value().isa<AffineConstantExpr>()) {
+ zeroPos.push_back(result.index());
+ }
+ }
+ AffineMap maskingMap = indexingMap.dropResults(zeroPos);
+
+ AffineMap readMap;
+ SmallVector<int64_t> readVecShape;
+ if (linalgOp.isDpsInput(opOperand)) {
+ // 3.a.i. For input reads we use the canonical vector shape.
+ readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
+ readVecShape = llvm::to_vector(state.getCanonicalVecShape());
} else {
- map = inversePermutation(
- reindexIndexingMap(linalgOp.getMatchingIndexingMap(opOperand)));
- readType = VectorType::get(map.compose(linalgOp.getShape(opOperand)),
- getElementTypeOrSelf(opOperand->get()));
+ // 3.a.ii. For output reads (iteration-carried dependence, e.g.,
+ // reductions), the vector shape is computed by mapping the canonical
+ // vector shape to the output domain and back to the canonical domain.
+ readMap = inversePermutation(reindexIndexingMap(indexingMap));
+ readVecShape =
+ readMap.compose(indexingMap.compose(state.getCanonicalVecShape()));
}
- // }
- auto shape = linalgOp.getShape(opOperand);
- SmallVector<Value> indices(shape.size(), zero);
- Value readValue = rewriter.create<vector::TransferReadOp>(
- loc, readType, opOperand->get(), indices, map);
- // Not all ops support 0-d vectors, extract the scalar for now.
+ auto readType =
+ VectorType::get(readVecShape, getElementTypeOrSelf(opOperand->get()));
+ SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
+
+ Operation *read = rewriter.create<vector::TransferReadOp>(
+ loc, readType, opOperand->get(), indices, readMap);
+ read = state.maskOperation(rewriter, read, linalgOp, maskingMap);
+ Value readValue = read->getResult(0);
+
+ // 3.b. If masked, set in-bounds to true. Masking guarantees that the access
+ // will be in-bounds.
+ if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
+ SmallVector<bool> inBounds(readType.getRank(), true);
+ cast<vector::TransferReadOp>(maskOp.getMaskableOp())
+ .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
+ }
+
+ // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
// TODO: remove this.
if (readValue.getType().cast<VectorType>().getRank() == 0)
readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
- LDBG("new vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue);
+ LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
+ << "\n");
bvm.map(bbarg, readValue);
bvm.map(opOperand->get(), readValue);
}
@@ -627,7 +915,7 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, LinalgOp linalgOp,
CustomVectorizationHook vectorizeYield =
[&](Operation *op,
const BlockAndValueMapping &bvm) -> VectorizationResult {
- return vectorizeLinalgYield(rewriter, op, bvm, linalgOp, newResults);
+ return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
};
hooks.push_back(vectorizeYield);
@@ -652,12 +940,14 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, LinalgOp linalgOp,
VectorizationResult result =
vectorizeOneOp(rewriter, linalgOp, &op, bvm, hooks);
if (result.status == VectorizationStatus::Failure) {
- LDBG("failed to vectorize: " << op);
+ LDBG("failed to vectorize: " << op << "\n");
return failure();
}
if (result.status == VectorizationStatus::NewOp) {
- LDBG("new vector op: " << *result.newOp;);
- bvm.map(op.getResults(), result.newOp->getResults());
+ Operation *maybeMaskedOp =
+ state.maskOperation(rewriter, result.newOp, linalgOp);
+ LDBG("New vector op: " << *maybeMaskedOp << "\n");
+ bvm.map(op.getResults(), maybeMaskedOp->getResults());
}
}
@@ -668,7 +958,7 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, LinalgOp linalgOp,
// ops that may not commute (e.g. linear reduction + non-linear instructions).
static LogicalResult reductionPreconditions(LinalgOp op) {
if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
- LDBG("reduction precondition failed: no reduction iterator");
+ LDBG("reduction precondition failed: no reduction iterator\n");
return failure();
}
for (OpOperand *opOperand : op.getDpsInitOperands()) {
@@ -678,20 +968,69 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
Operation *reduceOp = matchLinalgReduction(opOperand);
if (!reduceOp || !getCombinerOpKind(reduceOp)) {
- LDBG("reduction precondition failed: reduction detection failed");
+ LDBG("reduction precondition failed: reduction detection failed\n");
return failure();
}
}
return success();
}
-static LogicalResult vectorizeStaticLinalgOpPrecondition(
- linalg::LinalgOp op,
- ArrayRef<CustomVectorizationPrecondition> customPreconditions,
- bool vectorizeNDExtract) {
+static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op) {
+ // TODO: Masking only supports dynamic generic ops without reductions for now.
+ if (!isElementwise(op) &&
+ llvm::any_of(op.getIteratorTypesArray(), [](utils::IteratorType itType) {
+ return itType != utils::IteratorType::parallel;
+ }))
+ return failure();
+
+ // TODO: 0-d vectors are not supported yet.
+ if (llvm::any_of(op.getIndexingMapsArray(), [](AffineMap map) {
+ return map.isEmpty() || map.getResults().empty();
+ }))
+ return failure();
+
+ LDBG("Dynamically-shaped op meets vectorization pre-conditions\n");
+ return success();
+}
+
+LogicalResult
+mlir::linalg::vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
+ ArrayRef<int64_t> inputVectorSizes,
+ bool vectorizeNDExtract) {
+ // Check API contract for input vector sizes.
+ if (!inputVectorSizes.empty()) {
+ assert(inputVectorSizes.size() == linalgOp.getNumLoops() &&
+ "Input vector sizes don't match the number of loops");
+ assert(!ShapedType::isDynamicShape(inputVectorSizes) &&
+ "Input vector sizes can't have dynamic dimensions");
+ assert(llvm::all_of(
+ llvm::zip(linalgOp.getStaticLoopRanges(), inputVectorSizes),
+ [](std::tuple<int64_t, int64_t> sizePair) {
+ int64_t staticSize = std::get<0>(sizePair);
+ int64_t inputSize = std::get<1>(sizePair);
+ return ShapedType::isDynamic(staticSize) ||
+ staticSize <= inputSize;
+ }) &&
+        "Input vector sizes must be greater than or equal to the iteration "
+        "space static sizes");
+ }
+
+ // TODO: Masking is only supported for dynamic shapes so input vector sizes
+ // must be empty if the op is not dynamic.
+ if (!linalgOp.hasDynamicShape() && !inputVectorSizes.empty())
+ return failure();
+
+ if (linalgOp.hasDynamicShape() &&
+ failed(vectorizeDynamicLinalgOpPrecondition(linalgOp)))
+ return failure();
+
+ SmallVector<CustomVectorizationPrecondition> customPreconditions;
+
+ // Register CustomVectorizationPrecondition for extractOp.
+ customPreconditions.push_back(tensorExtractVectorizationPrecondition);
// All types in the body should be a supported element type for VectorType.
- for (Operation &innerOp : op->getRegion(0).front()) {
+ for (Operation &innerOp : linalgOp->getRegion(0).front()) {
// Check if any custom hook can vectorize the inner op.
if (llvm::any_of(
customPreconditions,
@@ -712,50 +1051,52 @@ static LogicalResult vectorizeStaticLinalgOpPrecondition(
return failure();
}
}
- if (isElementwise(op))
+ if (isElementwise(linalgOp))
return success();
// TODO: isaConvolutionOpInterface that can also infer from generic features.
// But we will still need stride/dilation attributes that will be annoying to
// reverse-engineer...
- if (isa<ConvolutionOpInterface>(op.getOperation()))
+ if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
return success();
// TODO: the common vector shape is equal to the static loop sizes only when
// all indexing maps are projected permutations. For convs and stencils the
// logic will need to evolve.
- if (!allIndexingsAreProjectedPermutation(op)) {
- LDBG("precondition failed: not projected permutations");
+ if (!allIndexingsAreProjectedPermutation(linalgOp)) {
+ LDBG("precondition failed: not projected permutations\n");
return failure();
}
- if (failed(reductionPreconditions(op))) {
- LDBG("precondition failed: reduction preconditions");
+ if (failed(reductionPreconditions(linalgOp))) {
+ LDBG("precondition failed: reduction preconditions\n");
return failure();
}
return success();
}
-LogicalResult
-mlir::linalg::vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
- bool vectorizeNDExtract) {
- // All types must be static shape to go to vector.
- if (linalgOp.hasDynamicShape()) {
- LDBG("precondition failed: dynamic shape");
- return failure();
- }
-
- SmallVector<CustomVectorizationPrecondition> customPreconditions;
-
- // Register CustomVectorizationPrecondition for extractOp.
- customPreconditions.push_back(tensorExtractVectorizationPrecondition);
-
- return vectorizeStaticLinalgOpPrecondition(linalgOp, customPreconditions,
- vectorizeNDExtract);
-}
-
+/// Emit a suitable vector form for a Linalg op. If provided, `inputVectorSizes`
+/// are used to vectorize this operation. `inputVectorSizes` must match the rank
+/// of the iteration space of the operation, and each size must be greater than
+/// or equal to its counterpart iteration space size, if that size is static.
+/// `inputVectorSizes` also allows the vectorization of operations with dynamic
+/// shapes.
LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, LinalgOp linalgOp,
+ ArrayRef<int64_t> inputVectorSizes,
bool vectorizeNDExtract) {
- if (failed(vectorizeLinalgOpPrecondition(linalgOp, vectorizeNDExtract)))
+ LDBG("Attempting to vectorize:\n" << linalgOp << "\n");
+ LDBG("Input vector sizes: ");
+ LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
+ LLVM_DEBUG(llvm::dbgs() << "\n");
+
+ if (failed(vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
+ vectorizeNDExtract)))
return failure();
+ // Initialize vectorization state.
+ VectorizationState state(rewriter);
+ if (failed(state.initState(rewriter, linalgOp, inputVectorSizes))) {
+ LDBG("Vectorization state couldn't be initialized\n");
+ return failure();
+ }
+
SmallVector<Value> results;
// TODO: isaConvolutionOpInterface that can also infer from generic
// features. Will require stride/dilation attributes inference.
@@ -763,10 +1104,16 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, LinalgOp linalgOp,
if (succeeded(convOr)) {
llvm::append_range(results, (*convOr)->getResults());
} else {
- if (failed(vectorizeLinalgOpPrecondition(linalgOp, vectorizeNDExtract)))
+ if (failed(vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
+ vectorizeNDExtract)))
return failure();
- LDBG("Vectorize generic by broadcasting to a common shape: " << linalgOp);
- if (failed(vectorizeAsLinalgGeneric(rewriter, linalgOp, results)))
+ LDBG("Vectorize generic by broadcasting to the canonical vector shape\n");
+ // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted to
+ // 'OpBuilder' when it is passed over to some methods like
+ // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we erase an op
+ // within these methods, the actual rewriter won't be notified and we will
+ // end up with read-after-free issues!
+ if (failed(vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results)))
return failure();
}
@@ -1262,7 +1609,7 @@ static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
if (firstOp->getBlock() != secondOp->getBlock() ||
!firstOp->isBeforeInBlock(secondOp)) {
LDBG("interleavedUses precondition failed, firstOp: "
- << *firstOp << ", second op: " << *secondOp);
+ << *firstOp << ", second op: " << *secondOp << "\n");
return true;
}
for (auto v : values) {
@@ -1275,7 +1622,7 @@ static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
(owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
continue;
LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp
- << ", second op: " << *secondOp);
+ << ", second op: " << *secondOp << "\n");
return true;
}
}
diff --git a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
index 8dfa96fbd68a..596f6422807c 100644
--- a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
@@ -5,6 +5,8 @@ add_mlir_dialect_library(MLIRVectorDialect
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Vector/IR
DEPENDS
+ MLIRMaskableOpInterfaceIncGen
+ MLIRMaskingOpInterfaceIncGen
MLIRVectorOpsIncGen
MLIRVectorOpsEnumsIncGen
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 18dae281821e..4c772c2fb11d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -447,6 +447,15 @@ void ReductionOp::print(OpAsmPrinter &p) {
p << " : " << getVector().getType() << " into " << getDest().getType();
}
+// MaskableOpInterface methods.
+
+/// Returns the mask type expected by this operation.
+Type ReductionOp::getExpectedMaskType() {
+ auto vecType = getVectorType();
+ return vecType.cloneWith(std::nullopt,
+ IntegerType::get(vecType.getContext(), /*width=*/1));
+}
+
Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
OpBuilder &builder, Location loc,
Value vector) {
@@ -3461,6 +3470,14 @@ LogicalResult TransferReadOp::verify() {
[&](Twine t) { return emitOpError(t); });
}
+// MaskableOpInterface methods.
+
+/// Returns the mask type expected by this operation. Mostly used for
+/// verification purposes. It requires the operation to be vectorized.
+Type TransferReadOp::getExpectedMaskType() {
+ return inferTransferReadMaskType(getVectorType(), getPermutationMap());
+}
+
template <typename TransferOp>
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
// TODO: support more aggressive createOrFold on:
@@ -3903,6 +3920,14 @@ LogicalResult TransferWriteOp::verify() {
[&](Twine t) { return emitOpError(t); });
}
+// MaskableOpInterface methods.
+
+/// Returns the mask type expected by this operation. Mostly used for
+/// verification purposes.
+Type TransferWriteOp::getExpectedMaskType() {
+ return inferTransferWriteMaskType(getVectorType(), getPermutationMap());
+}
+
/// Fold:
/// ```
/// %t1 = ...
@@ -5377,9 +5402,10 @@ LogicalResult MaskOp::verify() {
"expects result type to match maskable operation result type");
// Mask checks.
- if (getMask().getType() != maskableOp.getExpectedMaskType())
- return emitOpError("expects a ") << maskableOp.getExpectedMaskType()
- << " mask for the maskable operation";
+ Type expectedMaskType = maskableOp.getExpectedMaskType();
+ if (getMask().getType() != expectedMaskType)
+ return emitOpError("expects a ")
+ << expectedMaskType << " mask for the maskable operation";
// Passthru checks.
Value passthru = getPassthru();
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index aa79e54b9b30..b225662e58c5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -109,15 +109,6 @@ public:
}
};
-/// Populates instances of `MaskOpRewritePattern` to lower masked operations
-/// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and
-/// not its nested `MaskableOpInterface`.
-void populateVectorMaskLoweringPatternsForSideEffectingOps(
- RewritePatternSet &patterns) {
- patterns.add<MaskedTransferReadOpPattern, MaskedTransferWriteOpPattern>(
- patterns.getContext());
-}
-
struct LowerVectorMaskPass
: public vector::impl::LowerVectorMaskPassBase<LowerVectorMaskPass> {
using Base::Base;
@@ -141,6 +132,15 @@ struct LowerVectorMaskPass
} // namespace
+/// Populates instances of `MaskOpRewritePattern` to lower masked operations
+/// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and
+/// not its nested `MaskableOpInterface`.
+void vector::populateVectorMaskLoweringPatternsForSideEffectingOps(
+ RewritePatternSet &patterns) {
+ patterns.add<MaskedTransferReadOpPattern, MaskedTransferWriteOpPattern>(
+ patterns.getContext());
+}
+
std::unique_ptr<Pass> mlir::vector::createLowerVectorMaskPass() {
return std::make_unique<LowerVectorMaskPass>();
}
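For intuition, the two patterns registered above rewrite the enclosing `vector.mask` into the mask operand of the side-effecting transfer op it wraps. A before/after sketch under that assumption (values and shapes are illustrative, not taken verbatim from this patch):

  // Before lowering:
  %0 = vector.mask %m { vector.transfer_read %t[%c0], %pad {in_bounds = [true]} : tensor<?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
  // After lowering, the mask is carried directly by the transfer op:
  %0 = vector.transfer_read %t[%c0], %pad, %m {in_bounds = [true]} : tensor<?xf32>, vector<4xf32>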
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index a9a536ae596b..96c81d1593a3 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1,7 +1,5 @@
// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file | FileCheck %s
-// -----
-
// CHECK-LABEL: contraction_dot
func.func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref<f32>) {
@@ -130,7 +128,7 @@ transform.sequence failures(propagate) {
// CHECK-LABEL: func @generic_output_transpose
func.func @generic_output_transpose(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
- %C: memref<32x8xf32>) {
+ %C: memref<32x8xf32>) {
// CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x32x16xf32>
// CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<8x32x16xf32>
// CHECK: %[[ACC:.*]] = vector.transfer_read %{{.*}} : memref<32x8xf32>, vector<8x32xf32>
@@ -1608,3 +1606,147 @@ transform.sequence failures(propagate) {
%1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
%2 = transform.structured.vectorize %1
}
+
+// -----
+
+func.func @vectorize_dynamic_identity(%arg0: tensor<?xf32>,
+ %arg1: tensor<?xf32>,
+ %arg2: tensor<?xf32>) -> tensor<?xf32> {
+ %0 = linalg.generic { indexing_maps = [affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"] }
+ ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
+ outs(%arg2 : tensor<?xf32>) {
+ ^bb(%in0: f32, %in1: f32, %out: f32) :
+ %0 = arith.addf %in0, %in1 : f32
+ linalg.yield %0 : f32
+ } -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
+
+// CHECK-LABEL: @vectorize_dynamic_identity
+// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_4:.*]] = tensor.dim %{{.*}}, %[[VAL_3]] : tensor<?xf32>
+// CHECK: %[[VAL_7:.*]] = vector.create_mask %[[VAL_4]] : vector<4xi1>
+// CHECK: %[[VAL_8:.*]] = vector.mask %[[VAL_7]] { vector.transfer_read %{{.*}} {in_bounds = [true]} : tensor<?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+// CHECK: %[[VAL_10:.*]] = vector.mask %[[VAL_7]] { vector.transfer_read %{{.*}} {in_bounds = [true]} : tensor<?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+// CHECK: %[[VAL_12:.*]] = vector.mask %[[VAL_7]] { vector.transfer_read %{{.*}} {in_bounds = [true]} : tensor<?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+// CHECK: %[[VAL_13:.*]] = arith.addf %[[VAL_8]], %[[VAL_10]] : vector<4xf32>
+// CHECK: %[[VAL_14:.*]] = vector.mask %[[VAL_7]] { vector.transfer_write %{{.*}} {in_bounds = [true]} : vector<4xf32>, tensor<?xf32> } : vector<4xi1> -> tensor<?xf32>
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+ transform.structured.masked_vectorize %0 vector_sizes [4]
+}
+
+// -----
+
+func.func @vectorize_dynamic_1d_broadcast(%arg0: tensor<?xf32>,
+ %arg1: tensor<?xf32>,
+ %arg2: tensor<?xf32>) -> tensor<?xf32> {
+ %0 = linalg.generic { indexing_maps = [affine_map<(d0) -> (0)>,
+ affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"] }
+ ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
+ outs(%arg2 : tensor<?xf32>) {
+ ^bb(%in0: f32, %in1: f32, %out: f32) :
+ %0 = arith.addf %in0, %in1 : f32
+ linalg.yield %0 : f32
+ } -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
+
+// CHECK-LABEL: @vectorize_dynamic_1d_broadcast
+// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_4:.*]] = tensor.dim %{{.*}}, %[[VAL_3]] : tensor<?xf32>
+// CHECK: %[[VAL_7:.*]] = vector.transfer_read %{{.*}} {permutation_map = #{{.*}}} : tensor<?xf32>, vector<4xf32>
+// CHECK: %[[VAL_9:.*]] = vector.create_mask %[[VAL_4]] : vector<4xi1>
+// CHECK: %[[VAL_10:.*]] = vector.mask %[[VAL_9]] { vector.transfer_read %{{.*}} {in_bounds = [true]} : tensor<?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+// CHECK: %[[VAL_12:.*]] = vector.mask %[[VAL_9]] { vector.transfer_read %{{.*}} {in_bounds = [true]} : tensor<?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+// CHECK: %[[VAL_13:.*]] = arith.addf %[[VAL_7]], %[[VAL_10]] : vector<4xf32>
+// CHECK: %[[VAL_14:.*]] = vector.mask %{{.*}} { vector.transfer_write %[[VAL_13]], {{.*}} {in_bounds = [true]} : vector<4xf32>, tensor<?xf32> } : vector<4xi1> -> tensor<?xf32>
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+ transform.structured.masked_vectorize %0 vector_sizes [4]
+}
+
+// -----
+
+func.func @vectorize_dynamic_2d_transpose(%arg0: tensor<?x?xf32>,
+ %arg1: tensor<?x?xf32>,
+ %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"] }
+ ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg2 : tensor<?x?xf32>) {
+ ^bb(%in0: f32, %in1: f32, %out: f32) :
+ %0 = arith.addf %in0, %in1 : f32
+ linalg.yield %0 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: @vectorize_dynamic_2d_transpose
+// CHECK: %[[VAL_3:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_4:.*]] = tensor.dim %{{.*}}, %[[VAL_3]] : tensor<?x?xf32>
+// CHECK: %[[VAL_5:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_6:.*]] = tensor.dim %{{.*}}, %[[VAL_5]] : tensor<?x?xf32>
+// CHECK: %[[VAL_9:.*]] = vector.create_mask %[[VAL_6]], %[[VAL_4]] : vector<8x4xi1>
+// CHECK: %[[VAL_10:.*]] = vector.mask %[[VAL_9]] { vector.transfer_read %{{.*}} {in_bounds = [true, true], permutation_map = #{{.*}}} : tensor<?x?xf32>, vector<4x8xf32> } : vector<8x4xi1> -> vector<4x8xf32>
+// CHECK: %[[VAL_12:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_6]] : vector<4x8xi1>
+// CHECK: %[[VAL_13:.*]] = vector.mask %[[VAL_12]] { vector.transfer_read %{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<4x8xf32> } : vector<4x8xi1> -> vector<4x8xf32>
+// CHECK: %[[VAL_14:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_15:.*]] = vector.mask %[[VAL_12]] { vector.transfer_read %{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<4x8xf32> } : vector<4x8xi1> -> vector<4x8xf32>
+// CHECK: %[[VAL_16:.*]] = arith.addf %[[VAL_10]], %[[VAL_13]] : vector<4x8xf32>
+// CHECK: %[[VAL_17:.*]] = vector.mask %[[VAL_12]] { vector.transfer_write %[[VAL_16]], %{{.*}} {in_bounds = [true, true]} : vector<4x8xf32>, tensor<?x?xf32> } : vector<4x8xi1> -> tensor<?x?xf32>
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+ transform.structured.masked_vectorize %0 vector_sizes [4, 8]
+}
+
+// -----
+
+func.func @vectorize_dynamic_generic_2d_broadcast(%arg0: tensor<?x?xf32>,
+ %arg1: tensor<?x?xf32>,
+ %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"] }
+ ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg2 : tensor<?x?xf32>) {
+ ^bb(%in0: f32, %in1: f32, %out: f32) :
+ %0 = arith.addf %in0, %in1 : f32
+ linalg.yield %0 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: @vectorize_dynamic_generic_2d_broadcast
+// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_4:.*]] = tensor.dim %{{.*}}, %[[VAL_3]] : tensor<?x?xf32>
+// CHECK: %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_6:.*]] = tensor.dim %{{.*}}, %[[VAL_5]] : tensor<?x?xf32>
+// CHECK: %[[VAL_9:.*]] = vector.create_mask %[[VAL_6]] : vector<8xi1>
+// CHECK: %[[VAL_10:.*]] = vector.mask %[[VAL_9]] { vector.transfer_read %{{.*}} {in_bounds = [true, true], permutation_map = #{{.*}}} : tensor<?x?xf32>, vector<4x8xf32> } : vector<8xi1> -> vector<4x8xf32>
+// CHECK: %[[VAL_12:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_6]] : vector<4x8xi1>
+// CHECK: %[[VAL_13:.*]] = vector.mask %[[VAL_12]] { vector.transfer_read %{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<4x8xf32> } : vector<4x8xi1> -> vector<4x8xf32>
+// CHECK: %[[VAL_15:.*]] = vector.mask %[[VAL_12]] { vector.transfer_read %{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<4x8xf32> } : vector<4x8xi1> -> vector<4x8xf32>
+// CHECK: %[[VAL_16:.*]] = arith.addf %[[VAL_10]], %[[VAL_13]] : vector<4x8xf32>
+// CHECK: %[[VAL_18:.*]] = vector.mask %[[VAL_12]] { vector.transfer_write %{{.*}} {in_bounds = [true, true]} : vector<4x8xf32>, tensor<?x?xf32> } : vector<4x8xi1> -> tensor<?x?xf32>
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+ transform.structured.masked_vectorize %0 vector_sizes [4, 8]
+}
+