summaryrefslogtreecommitdiff
path: root/mlir/lib/Interfaces
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Interfaces')
-rw-r--r--mlir/lib/Interfaces/ViewLikeInterface.cpp55
1 files changed, 49 insertions, 6 deletions
diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
index 49331b516468..dfeda72b3811 100644
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -182,29 +182,72 @@ bool mlir::detail::sameOffsetsSizesAndStrides(
SmallVector<OpFoldResult, 4>
mlir::getMixedOffsets(OffsetSizeAndStrideOpInterface op,
ArrayAttr staticOffsets, ValueRange offsets) {
- return getMixedValues(staticOffsets, offsets, op.getDynamicOffsetIndicator());
+ SmallVector<OpFoldResult, 4> res;
+ unsigned numDynamic = 0;
+ unsigned count = static_cast<unsigned>(staticOffsets.size());
+ for (unsigned idx = 0; idx < count; ++idx) {
+ if (op.isDynamicOffset(idx))
+ res.push_back(offsets[numDynamic++]);
+ else
+ res.push_back(staticOffsets[idx]);
+ }
+ return res;
}
SmallVector<OpFoldResult, 4>
mlir::getMixedSizes(OffsetSizeAndStrideOpInterface op, ArrayAttr staticSizes,
ValueRange sizes) {
- return getMixedValues(staticSizes, sizes, op.getDynamicSizeIndicator());
+ SmallVector<OpFoldResult, 4> res;
+ unsigned numDynamic = 0;
+ unsigned count = static_cast<unsigned>(staticSizes.size());
+ for (unsigned idx = 0; idx < count; ++idx) {
+ if (op.isDynamicSize(idx))
+ res.push_back(sizes[numDynamic++]);
+ else
+ res.push_back(staticSizes[idx]);
+ }
+ return res;
}
SmallVector<OpFoldResult, 4>
mlir::getMixedStrides(OffsetSizeAndStrideOpInterface op,
ArrayAttr staticStrides, ValueRange strides) {
- return getMixedValues(staticStrides, strides, op.getDynamicStrideIndicator());
+ SmallVector<OpFoldResult, 4> res;
+ unsigned numDynamic = 0;
+ unsigned count = static_cast<unsigned>(staticStrides.size());
+ for (unsigned idx = 0; idx < count; ++idx) {
+ if (op.isDynamicStride(idx))
+ res.push_back(strides[numDynamic++]);
+ else
+ res.push_back(staticStrides[idx]);
+ }
+ return res;
+}
+
+static std::pair<ArrayAttr, SmallVector<Value>>
+decomposeMixedImpl(OpBuilder &b,
+ const SmallVectorImpl<OpFoldResult> &mixedValues,
+ const int64_t dynamicValuePlaceholder) {
+ SmallVector<int64_t> staticValues;
+ SmallVector<Value> dynamicValues;
+ for (const auto &it : mixedValues) {
+ if (it.is<Attribute>()) {
+ staticValues.push_back(it.get<Attribute>().cast<IntegerAttr>().getInt());
+ } else {
+ staticValues.push_back(ShapedType::kDynamicStrideOrOffset);
+ dynamicValues.push_back(it.get<Value>());
+ }
+ }
+ return {b.getI64ArrayAttr(staticValues), dynamicValues};
}
std::pair<ArrayAttr, SmallVector<Value>> mlir::decomposeMixedStridesOrOffsets(
OpBuilder &b, const SmallVectorImpl<OpFoldResult> &mixedValues) {
- return decomposeMixedValues(b, mixedValues,
- ShapedType::kDynamicStrideOrOffset);
+ return decomposeMixedImpl(b, mixedValues, ShapedType::kDynamicStrideOrOffset);
}
std::pair<ArrayAttr, SmallVector<Value>>
mlir::decomposeMixedSizes(OpBuilder &b,
const SmallVectorImpl<OpFoldResult> &mixedValues) {
- return decomposeMixedValues(b, mixedValues, ShapedType::kDynamicSize);
+ return decomposeMixedImpl(b, mixedValues, ShapedType::kDynamicSize);
}