//===- DataLayoutPropagation.cpp -----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"

namespace mlir {
#define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;

#define DEBUG_TYPE "linalg-data-layout-propagation"

namespace {

// The struct contains the information needed to map packing metadata (inner
// tiles, inner dims positions, and the permutation of outer dims) onto the
// iteration domain of Linalg ops.
struct PackInfo {
  int64_t getNumTiledLoops() const { return tileToPointMapping.size(); };
  // InnerDimsPos on iteration domain, which follows the order in pack ops.
  SmallVector<int64_t> tiledDimsPos;
  // The sizes of tiling data dimensions on iteration domain.
  llvm::DenseMap<int64_t, OpFoldResult> domainDimAndTileMapping;
  // The mapping from a dimension of iteration domain to the corresponding
  // inner tiling dimension on iteration domain.
  llvm::DenseMap<int64_t, int64_t> tileToPointMapping;
  // The permutation of outer dims (on domain).
  SmallVector<int64_t> outerDimsOnDomainPerm;
  Optional<Value> paddingValue;
};

static PackInfo getPackingInfoFromConsumer(
    AffineMap indexingMap, ArrayRef<OpFoldResult> innerTileSizes,
    ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm,
    Optional<Value> paddingValue = std::nullopt) {
  LLVM_DEBUG(
      { llvm::dbgs() << "--- Construct PackInfo From A Consumer ---\n"; });
  PackInfo packInfo;
  packInfo.paddingValue = paddingValue;
  int64_t origNumDims = indexingMap.getNumDims();
  SmallVector<AffineExpr> exprs(indexingMap.getResults());
  for (auto [index, innerDimPos, tileSize] :
       llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()),
                       innerDimsPos, innerTileSizes)) {
    int64_t domainDimPos =
        exprs[innerDimPos].cast<AffineDimExpr>().getPosition();
    packInfo.tiledDimsPos.push_back(domainDimPos);
    packInfo.domainDimAndTileMapping[domainDimPos] = tileSize;
    packInfo.tileToPointMapping[domainDimPos] = origNumDims + index;
    LLVM_DEBUG({
      llvm::dbgs() << "map innerDimPos=" << innerDimPos
                   << " to iteration dimension (d" << domainDimPos << ", d"
                   << packInfo.tileToPointMapping[domainDimPos]
                   << "), which has size=("
                   << packInfo.domainDimAndTileMapping[domainDimPos] << ")\n";
    });
  }

  for (auto dim : outerDimsPerm)
    packInfo.outerDimsOnDomainPerm.push_back(indexingMap.getDimPosition(dim));
  if (!packInfo.outerDimsOnDomainPerm.empty()) {
    LLVM_DEBUG({
      llvm::dbgs() << "map outerDimsPerm to ";
      for (auto dim : packInfo.outerDimsOnDomainPerm)
        llvm::dbgs() << dim << " ";
      llvm::dbgs() << "\n";
    });
  }

  return packInfo;
}
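
// Illustrative example for the function above (hypothetical values, mirroring
// the debug output): for a consumer pack op with inner_dims_pos = [0] and
// inner_tiles = [8] whose source is produced through the indexing map
// (d0, d1) -> (d1, d0), packed data dimension 0 corresponds to iteration
// dimension d1. The resulting PackInfo is tiledDimsPos = [1],
// domainDimAndTileMapping = {1 -> 8}, and tileToPointMapping = {1 -> 2},
// i.e., a new point loop d2 of size 8 is introduced for d1.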

/// Returns a tuple for packed operand and indexing_map with the assumptions:
///   1) The generic op is the producer of the pack op.
///   2) The generic op has only one result.
/// If the operand is a scalar or packing dimensions are all irrelevant to the
/// operand, the operand and the updated indexing map will be returned.
/// Otherwise, it returns the packed operand and the updated indexing map. E.g.,
///
///   #map0 = affine_map<(d0, d1) -> (d0, d1)>
///   #map1 = affine_map<(d0, d1) -> (d0)>
///   #map2 = affine_map<(d0, d1) -> (d1)>
///   %0 = linalg.generic {indexing_maps = [#map1, #map2, #map0],
///                        iterator_types = ["parallel", "parallel"]}
///       ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
///       outs(%init : tensor<?x?xf32>) {
///     ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
///       %4 = arith.addf %arg3, %arg4 : f32
///       linalg.yield %4 : f32
///   } -> tensor<?x?xf32>
///   %1 = tensor.pack %0
///     inner_dims_pos = [0, 1]
///     inner_tiles = [8, 2]
///     into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
///
/// Taking the first input operand as an example, the inner tile size of d0 is
/// 8. Thus, the below operation and `affine_map<(d0, d1, d2, d3) -> (d0, d2)>`
/// will be returned.
///
///   %pack = tensor.pack %arg0
///     inner_dims_pos = [0]
///     inner_tiles = [8]
///     into %init : tensor<?xf32> -> tensor<?x8xf32>
static std::tuple<Value, AffineMap>
getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
                               GenericOp genericOp, OpOperand *opOperand) {
  int64_t numOrigLoops = genericOp.getNumLoops();
  int64_t numInnerLoops = packInfo.getNumTiledLoops();
  int64_t numLoops = numOrigLoops + numInnerLoops;
  AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand);
  llvm::DenseMap<int64_t, int64_t> domainDimToOperandDim;
  SmallVector<AffineExpr> exprs(origIndexingMap.getResults());
  if (genericOp.isScalar(opOperand))
    return std::make_tuple(opOperand->get(),
                           AffineMap::get(numLoops, 0, exprs, b.getContext()));

  // Step 1. Construct the information of packing data dimensions; append inner
  // dimensions to the indexing maps for the operand.
  for (auto [index, expr] : llvm::enumerate(exprs)) {
    int64_t dimPos = expr.cast<AffineDimExpr>().getPosition();
    domainDimToOperandDim[dimPos] = index;
  }
  SmallVector<int64_t> innerDimsPos;
  SmallVector<OpFoldResult> innerTileSizes;
  for (auto dimPos : packInfo.tiledDimsPos) {
    if (!domainDimToOperandDim.count(dimPos))
      continue;
    int64_t index = domainDimToOperandDim[dimPos];
    innerTileSizes.push_back(packInfo.domainDimAndTileMapping[dimPos]);
    innerDimsPos.push_back(index);
    exprs.push_back(b.getAffineDimExpr(packInfo.tileToPointMapping[dimPos]));
  }

  // Step 2. Fold transpose variants (i.e., outerDimsPerm) into generic op.
  // TODO: should we propagate the permutation of outer dims to the pack op?
  SmallVector<int64_t> outerDimsPerm;
  if (!packInfo.outerDimsOnDomainPerm.empty()) {
    SmallVector<int64_t> inversedOuterPerm =
        invertPermutationVector(packInfo.outerDimsOnDomainPerm);
    for (auto i : llvm::seq<unsigned>(0, origIndexingMap.getNumResults())) {
      int64_t dimPos = exprs[i].cast<AffineDimExpr>().getPosition();
      exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]);
    }
  }
  auto indexingMap = AffineMap::get(numLoops, 0, exprs, b.getContext());

  // The operand does not have dimensions that relate to the pack op.
  if (innerDimsPos.empty())
    return std::make_tuple(opOperand->get(), indexingMap);

  auto empty = tensor::PackOp::createDestinationTensor(
      b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
  auto packedOperand = b.create<tensor::PackOp>(
      loc, opOperand->get(), empty, innerDimsPos, innerTileSizes,
      packInfo.paddingValue, outerDimsPerm);
  return std::make_tuple(packedOperand, indexingMap);
}
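
// Continuing the doc-comment example above: for the second input operand
// %arg1 (indexing map (d0, d1) -> (d1), with d1 tiled by 2 into the point
// loop d3), the analogous result would be a pack op with inner_dims_pos = [0]
// and inner_tiles = [2] producing a tensor<?x2xf32>, together with the
// updated indexing map (d0, d1, d2, d3) -> (d1, d3).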

/// Bubbles up tensor.pack op through an elementwise generic op. This swaps
/// pack(generic) to generic(pack). The new generic op works on the packed
/// domain; pack ops are created for the input and output operands. E.g.,
///
///   #map0 = affine_map<(d0, d1) -> (d0, d1)>
///   %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
///   %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
///   %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
///   %3 = linalg.generic {indexing_maps = [#map0, #map0],
///                        iterator_types = ["parallel", "parallel"]}
///       ins(%arg0 : tensor<?x?xf32>)
///       outs(%2 : tensor<?x?xf32>) {
///     ^bb0(%arg3: f32, %arg4: f32):
///       %4 = arith.addf %arg3, %arg3 : f32
///       linalg.yield %4 : f32
///   } -> tensor<?x?xf32>
///   %4 = tensor.pack %3
///     inner_dims_pos = [0, 1]
///     inner_tiles = [8, 2]
///     into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
///
/// will be converted to
///
///   #map = affine_map<()[s0] -> (s0 ceildiv 8)>
///   #map1 = affine_map<()[s0] -> (s0 ceildiv 2)>
///   #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
///   %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
///   %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
///   %0 = affine.apply #map()[%dim]
///   %1 = affine.apply #map1()[%dim_0]
///   %2 = tensor.empty(%0, %1) : tensor<?x?x8x2xf32>
///   %pack = tensor.pack %arg0
///     inner_dims_pos = [0, 1]
///     inner_tiles = [8, 2]
///     into %2 : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
///   %3 = linalg.generic {indexing_maps = [#map2, #map2],
///       iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
///       ins(%pack : tensor<?x?x8x2xf32>)
///       outs(%arg1 : tensor<?x?x8x2xf32>) {
///     ^bb0(%in: f32, %out: f32):
///       %4 = arith.addf %in, %in : f32
///       linalg.yield %4 : f32
///   } -> tensor<?x?x8x2xf32>
static FailureOr<GenericOp>
bubbleUpPackOpThroughElemGenericOp(RewriterBase &rewriter,
                                   tensor::PackOp packOp) {
  auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
  if (!genericOp)
    return failure();

  if (!isElementwise(genericOp))
    return failure();

  // TODO: Relax the restriction. We are able to bubble up the pack op through
  // multi-result generic op. It just needs more work.
  if (genericOp.getNumResults() != 1)
    return failure();

  // TODO: Add an option for allowing padding values. It could introduce
  // undefined behavior if we unconditionally propagate pack op through all
  // the ops. E.g., if the padding value is zero and there are division ops in
  // a generic op. Some values of padding area could be NaN (0/0).
  if (packOp.getPaddingValue())
    return failure();

  OpOperand *opOperand = genericOp.getDpsInitOperand(0);
  auto packInfo = getPackingInfoFromConsumer(
      genericOp.getMatchingIndexingMap(opOperand), packOp.getMixedTiles(),
      packOp.getInnerDimsPos(), packOp.getOuterDimsPerm(),
      packOp.getPaddingValue());

  Location loc = packOp.getLoc();
  SmallVector<Value> inputOperands;
  SmallVector<AffineMap> indexingMaps;
  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
    auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
        rewriter, loc, packInfo, genericOp, inputOperand);
    inputOperands.push_back(packedOperand);
    indexingMaps.push_back(packedIndexingMap);
  }

  int64_t numLoops = genericOp.getNumLoops();
  int64_t numInnerLoops = packInfo.getNumTiledLoops();
  int64_t newNumLoops = numLoops + numInnerLoops;
  SmallVector<utils::IteratorType> iterTypes =
      genericOp.getIteratorTypesArray();
  iterTypes.append(numInnerLoops, utils::IteratorType::parallel);

  // Rebuild the indexing map for the corresponding init operand.
  auto [packedOutOperand, packedOutIndexingMap] =
      getOrCreatePackedViewOfOperand(rewriter, loc, packInfo, genericOp,
                                     opOperand);
  SmallVector<AffineExpr> outExprs(
      packedOutIndexingMap.getResults().drop_back(numInnerLoops));
  // Apply transpose to the indexing map, because we'll replace the init
  // operand with the destination of pack op.
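  // E.g., with an identity output map and outer_dims_perm = [1, 0], the
  // packed init map computed above starts as (d1, d0); applying the
  // permutation yields (d0, d1), which matches the dimension order of the
  // pack op's destination (whose outer dims are already permuted).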
  auto outerDimsPerm = packOp.getOuterDimsPerm();
  if (!outerDimsPerm.empty()) {
    applyPermutationToVector(outExprs, outerDimsPerm);
  }
  for (int i = 0; i < numInnerLoops; ++i)
    outExprs.push_back(rewriter.getAffineDimExpr(numLoops + i));
  AffineMap outMap =
      AffineMap::get(newNumLoops, 0, outExprs, rewriter.getContext());
  indexingMaps.push_back(outMap);

  auto newGenericOp = rewriter.create<GenericOp>(
      loc, packOp.getDestType(), inputOperands, packOp.getDest(), indexingMaps,
      iterTypes, /*bodyBuild=*/nullptr,
      linalg::getPrunedAttributeList(genericOp));
  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
                             newGenericOp.getRegion().begin());
  return newGenericOp;
}

// Wrapper pattern that applies the bubbleUpPackOpThroughElemGenericOp method.
struct BubbleUpPackOpThroughElemGenericOpPattern
    : public OpRewritePattern<tensor::PackOp> {
  using OpRewritePattern<tensor::PackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto genericOp = bubbleUpPackOpThroughElemGenericOp(rewriter, packOp);
    if (failed(genericOp))
      return failure();
    rewriter.replaceOp(packOp, genericOp.value().getResults());
    return success();
  }
};
} // namespace

void mlir::linalg::populateDataLayoutPropagationPatterns(
    RewritePatternSet &patterns) {
  patterns.insert<BubbleUpPackOpThroughElemGenericOpPattern>(
      patterns.getContext());
}
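
// A minimal sketch of a pass that drives the patterns above. It assumes the
// base class generated by the GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION
// include at the top of this file; the struct name here is illustrative.
namespace {
struct LinalgDataLayoutPropagationPass
    : public impl::LinalgDataLayoutPropagationBase<
          LinalgDataLayoutPropagationPass> {
  void runOnOperation() override {
    // Apply the propagation patterns greedily until a fixed point is reached.
    RewritePatternSet patterns(&getContext());
    linalg::populateDataLayoutPropagationPatterns(patterns);
    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace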