diff options
Diffstat (limited to 'mlir/lib/Interfaces')
-rw-r--r-- | mlir/lib/Interfaces/ViewLikeInterface.cpp | 55 |
1 files changed, 49 insertions, 6 deletions
diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp index 49331b516468..dfeda72b3811 100644 --- a/mlir/lib/Interfaces/ViewLikeInterface.cpp +++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp @@ -182,29 +182,72 @@ bool mlir::detail::sameOffsetsSizesAndStrides( SmallVector<OpFoldResult, 4> mlir::getMixedOffsets(OffsetSizeAndStrideOpInterface op, ArrayAttr staticOffsets, ValueRange offsets) { - return getMixedValues(staticOffsets, offsets, op.getDynamicOffsetIndicator()); + SmallVector<OpFoldResult, 4> res; + unsigned numDynamic = 0; + unsigned count = static_cast<unsigned>(staticOffsets.size()); + for (unsigned idx = 0; idx < count; ++idx) { + if (op.isDynamicOffset(idx)) + res.push_back(offsets[numDynamic++]); + else + res.push_back(staticOffsets[idx]); + } + return res; } SmallVector<OpFoldResult, 4> mlir::getMixedSizes(OffsetSizeAndStrideOpInterface op, ArrayAttr staticSizes, ValueRange sizes) { - return getMixedValues(staticSizes, sizes, op.getDynamicSizeIndicator()); + SmallVector<OpFoldResult, 4> res; + unsigned numDynamic = 0; + unsigned count = static_cast<unsigned>(staticSizes.size()); + for (unsigned idx = 0; idx < count; ++idx) { + if (op.isDynamicSize(idx)) + res.push_back(sizes[numDynamic++]); + else + res.push_back(staticSizes[idx]); + } + return res; } SmallVector<OpFoldResult, 4> mlir::getMixedStrides(OffsetSizeAndStrideOpInterface op, ArrayAttr staticStrides, ValueRange strides) { - return getMixedValues(staticStrides, strides, op.getDynamicStrideIndicator()); + SmallVector<OpFoldResult, 4> res; + unsigned numDynamic = 0; + unsigned count = static_cast<unsigned>(staticStrides.size()); + for (unsigned idx = 0; idx < count; ++idx) { + if (op.isDynamicStride(idx)) + res.push_back(strides[numDynamic++]); + else + res.push_back(staticStrides[idx]); + } + return res; +} + +static std::pair<ArrayAttr, SmallVector<Value>> +decomposeMixedImpl(OpBuilder &b, + const SmallVectorImpl<OpFoldResult> &mixedValues, + const int64_t dynamicValuePlaceholder) { + SmallVector<int64_t> staticValues; + SmallVector<Value> dynamicValues; + for (const auto &it : mixedValues) { + if (it.is<Attribute>()) { + staticValues.push_back(it.get<Attribute>().cast<IntegerAttr>().getInt()); + } else { + staticValues.push_back(ShapedType::kDynamicStrideOrOffset); + dynamicValues.push_back(it.get<Value>()); + } + } + return {b.getI64ArrayAttr(staticValues), dynamicValues}; } std::pair<ArrayAttr, SmallVector<Value>> mlir::decomposeMixedStridesOrOffsets( OpBuilder &b, const SmallVectorImpl<OpFoldResult> &mixedValues) { - return decomposeMixedValues(b, mixedValues, - ShapedType::kDynamicStrideOrOffset); + return decomposeMixedImpl(b, mixedValues, ShapedType::kDynamicStrideOrOffset); } std::pair<ArrayAttr, SmallVector<Value>> mlir::decomposeMixedSizes(OpBuilder &b, const SmallVectorImpl<OpFoldResult> &mixedValues) { - return decomposeMixedValues(b, mixedValues, ShapedType::kDynamicSize); + return decomposeMixedImpl(b, mixedValues, ShapedType::kDynamicSize); } |