1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
|
//===- LowerMemorySpaceAttributes.cpp ------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// Implementation of a pass that rewrites the IR so that uses of
/// `gpu::AddressSpaceAttr` in memref memory space annotations are replaced
/// with caller-specified numeric values.
///
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Debug.h"
namespace mlir {
#define GEN_PASS_DEF_GPULOWERMEMORYSPACEATTRIBUTESPASS
#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
} // namespace mlir
using namespace mlir;
using namespace mlir::gpu;
//===----------------------------------------------------------------------===//
// Conversion Target
//===----------------------------------------------------------------------===//
/// Returns true if the given `type` is considered as legal during memory space
/// attribute lowering.
/// Reports whether `type` may remain in the IR after memory space lowering.
/// Only a memref type (any `BaseMemRefType`) whose memory space is a
/// `gpu::AddressSpaceAttr` is illegal. A memref with no (null) memory space and
/// every non-memref type are legal.
static bool isLegalType(Type type) {
  auto memRefType = type.dyn_cast<BaseMemRefType>();
  if (!memRefType)
    return true;
  Attribute memorySpace = memRefType.getMemorySpace();
  return !memorySpace.isa_and_nonnull<gpu::AddressSpaceAttr>();
}
/// Returns true if the given `attr` is considered legal during memory space
/// attribute lowering.
/// Reports whether `attr` may remain in the IR after memory space lowering.
/// Only a `TypeAttr` can be illegal, and only when the type it wraps is
/// rejected by `isLegalType`. Every other attribute kind is accepted.
static bool isLegalAttr(Attribute attr) {
  auto typeAttr = attr.dyn_cast<TypeAttr>();
  return !typeAttr || isLegalType(typeAttr.getValue());
}
/// Returns true if the given `op` is legal during memory space attribute
/// lowering.
/// Dynamic legality check used by the conversion target.
///
/// For ops implementing `FunctionOpInterface`, the declared argument and
/// result types plus the argument types of the function body region are
/// checked. For all other ops, the operand types, result types and every
/// attribute value (via `isLegalAttr`) are checked.
static bool isLegalOp(Operation *op) {
  if (auto funcOp = dyn_cast<FunctionOpInterface>(op))
    return llvm::all_of(funcOp.getArgumentTypes(), isLegalType) &&
           llvm::all_of(funcOp.getResultTypes(), isLegalType) &&
           llvm::all_of(funcOp.getFunctionBody().getArgumentTypes(),
                        isLegalType);

  if (!llvm::all_of(op->getOperandTypes(), isLegalType))
    return false;
  if (!llvm::all_of(op->getResultTypes(), isLegalType))
    return false;
  return llvm::all_of(op->getAttrs(), [](const NamedAttribute &namedAttr) {
    return isLegalAttr(namedAttr.getValue());
  });
}
/// Installs `isLegalOp` as the dynamic legality callback that `target` applies
/// to ops it has no other legality information for.
void gpu::populateLowerMemorySpaceOpLegality(ConversionTarget &target) {
  target.markUnknownOpDynamicallyLegal(
      [](Operation *op) { return isLegalOp(op); });
}
//===----------------------------------------------------------------------===//
// Type Converter
//===----------------------------------------------------------------------===//
/// Wraps the numeric memory space `space` in an `IntegerAttr` of a 64-bit
/// integer type, suitable for use as a memref memory-space attribute.
///
/// File-local helper: `static` keeps it out of the global namespace. Without
/// it, the function would be exported with external linkage and no prior
/// declaration.
static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
  return IntegerAttr::get(IntegerType::get(ctx, 64), space);
}
/// Registers a conversion on `typeConverter` that rewrites every
/// `gpu::AddressSpaceAttr` nested inside a type into the numeric memory space
/// returned by `mapping`, wrapped by `wrapNumericMemorySpace`.
///
/// Types that do not implement `SubElementTypeInterface` are returned
/// unchanged.
void mlir::gpu::populateMemorySpaceAttributeTypeConversions(
    TypeConverter &typeConverter, const MemorySpaceMapping &mapping) {
  typeConverter.addConversion([mapping](Type type) -> Optional<Type> {
    // Types that cannot expose their nested attributes are kept as they are.
    auto container = type.dyn_cast_or_null<SubElementTypeInterface>();
    if (!container)
      return type;

    // Replace each nested `gpu::AddressSpaceAttr` with its numeric value.
    // Returning std::nullopt declines to replace the visited attribute.
    Type converted = container.replaceSubElements(
        [mapping](Attribute attr) -> std::optional<Attribute> {
          if (auto addrSpace = attr.dyn_cast_or_null<gpu::AddressSpaceAttr>())
            return wrapNumericMemorySpace(attr.getContext(),
                                          mapping(addrSpace.getValue()));
          return std::nullopt;
        });
    return converted;
  });
}
namespace {
/// Catch-all conversion pattern that rebuilds an op whose operand types,
/// result types, region argument types, or `TypeAttr` attributes mention a
/// `gpu::AddressSpaceAttr` memory space. The pattern's type converter replaces
/// those memory spaces with numeric ones.
///
/// The new op keeps the original name, location, successors and regions; only
/// types are rewritten. If any type fails to convert, the pattern reports a
/// match failure before modifying the IR.
struct LowerMemRefAddressSpacePattern final : public ConversionPattern {
  LowerMemRefAddressSpacePattern(MLIRContext *context, TypeConverter &converter)
      : ConversionPattern(converter, MatchAnyOpTypeTag(), 1, context) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto *converter = getTypeConverter();

    // Phase 1: compute every converted type without touching the IR. A failed
    // conversion is then reported cleanly instead of producing a malformed op
    // (wrong result count) or a `TypeAttr` wrapping a null type.

    // Attributes: rewrite the type held by each `TypeAttr`; every other
    // attribute is copied through unchanged.
    SmallVector<NamedAttribute> newAttrs;
    newAttrs.reserve(op->getAttrs().size());
    for (const NamedAttribute &attr : op->getAttrs()) {
      auto typeAttr = attr.getValue().dyn_cast<TypeAttr>();
      if (!typeAttr) {
        newAttrs.push_back(attr);
        continue;
      }
      Type newType = converter->convertType(typeAttr.getValue());
      if (!newType)
        return rewriter.notifyMatchFailure(op, "failed to convert TypeAttr");
      newAttrs.emplace_back(attr.getName(), TypeAttr::get(newType));
    }

    // Result types.
    SmallVector<Type> newResults;
    if (failed(converter->convertTypes(op->getResultTypes(), newResults)))
      return rewriter.notifyMatchFailure(op, "failed to convert result types");

    // Region argument types: one signature conversion per region, computed
    // from the original region before it is moved into the new op.
    SmallVector<TypeConverter::SignatureConversion, 1> regionConversions;
    regionConversions.reserve(op->getNumRegions());
    for (Region &region : op->getRegions()) {
      TypeConverter::SignatureConversion &conversion =
          regionConversions.emplace_back(region.getNumArguments());
      if (failed(converter->convertSignatureArgs(region.getArgumentTypes(),
                                                 conversion)))
        return rewriter.notifyMatchFailure(
            op, "failed to convert region argument types");
    }

    // Phase 2: build the replacement op from the operand values handed to the
    // pattern by the conversion driver, the converted types, and the original
    // successors.
    OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
                         newResults, newAttrs, op->getSuccessors());
    for (unsigned i = 0, e = op->getNumRegions(); i != e; ++i) {
      Region *newRegion = state.addRegion();
      rewriter.inlineRegionBefore(op->getRegion(i), *newRegion,
                                  newRegion->begin());
      rewriter.applySignatureConversion(newRegion, regionConversions[i]);
    }

    Operation *newOp = rewriter.create(state);
    rewriter.replaceOp(op, newOp->getResults());
    return success();
  }
};
} // namespace
/// Adds the catch-all `LowerMemRefAddressSpacePattern`, driven by
/// `typeConverter`, to `patterns`.
void mlir::gpu::populateMemorySpaceLoweringPatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<LowerMemRefAddressSpacePattern>(ctx, typeConverter);
}
namespace {
/// Pass driver that runs a full dialect conversion on the root op. The
/// conversion lowers `gpu::AddressSpaceAttr` memory spaces to the numeric
/// values configured by the pass options `globalAddrSpace`,
/// `workgroupAddrSpace` and `privateAddrSpace`.
class LowerMemorySpaceAttributesPass
    : public mlir::impl::GPULowerMemorySpaceAttributesPassBase<
          LowerMemorySpaceAttributesPass> {
public:
  using Base::Base;

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();

    // Maps each `AddressSpace` enumerator to its configured numeric value.
    auto toNumericSpace = [this](AddressSpace space) -> unsigned {
      switch (space) {
      case AddressSpace::Global:
        return globalAddrSpace;
      case AddressSpace::Workgroup:
        return workgroupAddrSpace;
      case AddressSpace::Private:
        return privateAddrSpace;
      }
      llvm_unreachable("unknown address space enum value");
    };

    // Identity conversion as a catch-all, plus the memory-space rewrite.
    TypeConverter typeConverter;
    typeConverter.addConversion([](Type t) { return t; });
    populateMemorySpaceAttributeTypeConversions(typeConverter, toNumericSpace);

    ConversionTarget target(*ctx);
    populateLowerMemorySpaceOpLegality(target);

    RewritePatternSet patterns(ctx);
    populateMemorySpaceLoweringPatterns(typeConverter, patterns);

    if (failed(applyFullConversion(getOperation(), target,
                                   std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace
|