1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
|
//===- AsmParserState.cpp -------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Parser/AsmParserState.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"
#include "llvm/ADT/StringExtras.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// AsmParserState::Impl
//===----------------------------------------------------------------------===//
struct AsmParserState::Impl {
  /// Maps each referenced SymbolRefAttr to every place it was used. A single
  /// use is stored as a list of ranges, one per component of the (possibly
  /// nested) symbol reference.
  using SymbolUseMap =
      DenseMap<Attribute, SmallVector<SmallVector<SMRange>, 0>>;

  /// Parser state for an operation whose definition has been started but has
  /// not been finalized yet.
  struct PartialOpDef {
    explicit PartialOpDef(const OperationName &opName) {
      // Only operations with the SymbolTable trait collect symbol uses.
      if (opName.hasTrait<OpTrait::SymbolTable>())
        symbolTable = std::make_unique<SymbolUseMap>();
    }

    /// True when a symbol use map was allocated, i.e. the operation is a
    /// symbol table.
    bool isSymbolTable() const { return symbolTable != nullptr; }

    /// Symbol uses nested within this operation; non-null only for symbol
    /// table operations.
    std::unique_ptr<SymbolUseMap> symbolTable;
  };

  /// Attach every recorded symbol use to the operation it refers to.
  void resolveSymbolUses();

  /// Definitions of the operations in the source file, plus an index from each
  /// operation to its slot in `operations`.
  SmallVector<std::unique_ptr<OperationDefinition>> operations;
  DenseMap<Operation *, unsigned> operationToIdx;

  /// Definitions of the blocks in the source file, plus an index from each
  /// block to its slot in `blocks`.
  SmallVector<std::unique_ptr<BlockDefinition>> blocks;
  DenseMap<Block *, unsigned> blocksToIdx;

  /// Uses of placeholder values created for forward references. This is
  /// expected to be empty once the parser finishes successfully.
  DenseMap<Value, SmallVector<SMLoc>> placeholderValueUses;

  /// Every finalized symbol table operation, paired with the symbol uses
  /// collected inside it.
  SmallVector<std::pair<Operation *, std::unique_ptr<SymbolUseMap>>>
      symbolTableOperations;

  /// Stack of operation definitions that are still being parsed.
  SmallVector<PartialOpDef> partialOperations;

  /// Stack of symbol use maps into which symbol uses are recorded while
  /// parsing.
  SmallVector<SymbolUseMap *> symbolUseScopes;

  /// Symbol table cache used to look up the targets of symbol uses.
  SymbolTableCollection symbolTable;
};
void AsmParserState::Impl::resolveSymbolUses() {
SmallVector<Operation *> symbolOps;
for (auto &opAndUseMapIt : symbolTableOperations) {
for (auto &it : *opAndUseMapIt.second) {
symbolOps.clear();
if (failed(symbolTable.lookupSymbolIn(
opAndUseMapIt.first, it.first.cast<SymbolRefAttr>(), symbolOps)))
continue;
for (ArrayRef<SMRange> useRange : it.second) {
for (const auto &symIt : llvm::zip(symbolOps, useRange)) {
auto opIt = operationToIdx.find(std::get<0>(symIt));
if (opIt != operationToIdx.end())
operations[opIt->second]->symbolUses.push_back(std::get<1>(symIt));
}
}
}
}
}
//===----------------------------------------------------------------------===//
// AsmParserState
//===----------------------------------------------------------------------===//
AsmParserState::AsmParserState() : impl(std::make_unique<Impl>()) {}
AsmParserState::~AsmParserState() = default;
AsmParserState &AsmParserState::operator=(AsmParserState &&other) {
impl = std::move(other.impl);
return *this;
}
//===----------------------------------------------------------------------===//
// Access State
/// Return a range over all recorded block definitions, yielding each
/// definition by reference instead of by owning pointer.
iterator_range<AsmParserState::BlockDefIterator>
AsmParserState::getBlockDefs() const {
  ArrayRef<std::unique_ptr<BlockDefinition>> ownedDefs =
      llvm::makeArrayRef(impl->blocks);
  return llvm::make_pointee_range(ownedDefs);
}
auto AsmParserState::getBlockDef(Block *block) const
-> const BlockDefinition * {
auto it = impl->blocksToIdx.find(block);
return it == impl->blocksToIdx.end() ? nullptr : &*impl->blocks[it->second];
}
/// Return a range over all recorded operation definitions, yielding each
/// definition by reference instead of by owning pointer.
iterator_range<AsmParserState::OperationDefIterator>
AsmParserState::getOpDefs() const {
  ArrayRef<std::unique_ptr<OperationDefinition>> ownedDefs =
      llvm::makeArrayRef(impl->operations);
  return llvm::make_pointee_range(ownedDefs);
}
auto AsmParserState::getOpDef(Operation *op) const
-> const OperationDefinition * {
auto it = impl->operationToIdx.find(op);
return it == impl->operationToIdx.end() ? nullptr
: &*impl->operations[it->second];
}
/// Scan the body of a string token starting at `curPtr`, which is the
/// character after the opening quote. The result points one past the
/// character that ended the token: a closing `"`, a newline, vertical tab or
/// form feed, or the character after a `\` that does not begin a recognized
/// escape. If the NUL terminator is reached first, the result points at the
/// NUL itself so it never runs past the buffer.
static const char *lexLocStringTok(const char *curPtr) {
  for (char c = *curPtr++; c != '\0'; c = *curPtr++) {
    switch (c) {
    case '"':
    case '\n':
    case '\v':
    case '\f':
      // Terminal character: the token ends just after it.
      return curPtr;
    case '\\':
      // Single-character escapes: \" \\ \n \t.
      switch (*curPtr) {
      case '"':
      case '\\':
      case 'n':
      case 't':
        ++curPtr;
        continue;
      default:
        break;
      }
      // Two-digit hex escape. The first check short-circuits on NUL, so the
      // second character is never read past the terminator.
      if (!llvm::isHexDigit(curPtr[0]) || !llvm::isHexDigit(curPtr[1]))
        return curPtr;
      curPtr += 2;
      continue;
    default:
      continue;
    }
  }
  // The NUL terminator was consumed by the loop; step back onto it.
  return curPtr - 1;
}
/// Convert the start location of an identifier or string token into the
/// source range covering the whole token.
///
/// - An invalid `loc` yields an empty range.
/// - A token starting with `"` is lexed as a string, including its closing
///   quote.
/// - Any other token includes its first character unconditionally
///   (presumably a sigil such as `%` or `^`; confirm against callers). It then
///   extends over alphanumeric characters and `$ . _ -`.
SMRange AsmParserState::convertIdLocToRange(SMLoc loc) {
  if (!loc.isValid())
    return SMRange();
  const char *curPtr = loc.getPointer();

  if (*curPtr == '"') {
    // String token: scan from just past the opening quote.
    curPtr = lexLocStringTok(curPtr + 1);
  } else {
    // Otherwise, default to handling an identifier.
    //
    // Cast to `unsigned char` before calling `isalnum`. Passing a negative
    // `char`, such as a byte of a non-ASCII UTF-8 sequence, to the <cctype>
    // classifiers is undefined behavior.
    auto isIdentifierChar = [](char c) {
      return isalnum(static_cast<unsigned char>(c)) || c == '$' || c == '.' ||
             c == '_' || c == '-';
    };
    // The first character is always consumed. NUL is not an identifier
    // character, so the scan stops at the end of the buffer.
    while (*curPtr && isIdentifierChar(*(++curPtr)))
      continue;
  }
  return SMRange(loc, SMLoc::getFromPointer(curPtr));
}
//===----------------------------------------------------------------------===//
// Populate State
/// Begin populating state for `topLevelOp`. Its partial definition is opened
/// here and is expected to be closed by `finalize`.
void AsmParserState::initialize(Operation *topLevelOp) {
  startOperationDefinition(topLevelOp->getName());
  // A symbol-table top-level operation gets its own symbol use scope, so that
  // symbol references parsed inside it are collected into its use map.
  Impl::PartialOpDef &topLevelDef = impl->partialOperations.back();
  if (!topLevelDef.isSymbolTable())
    return;
  impl->symbolUseScopes.push_back(topLevelDef.symbolTable.get());
}
/// Finish populating state for the top-level operation `topLevelOp`. This
/// pops its partial definition, closes the symbol use scope that `initialize`
/// opened for it (if any), and resolves all collected symbol uses.
void AsmParserState::finalize(Operation *topLevelOp) {
  assert(!impl->partialOperations.empty() &&
         "expected valid partial operation definition");
  Impl::PartialOpDef partialOpDef = impl->partialOperations.pop_back_val();

  // A symbol-table top-level operation contributes its collected uses to the
  // resolution below.
  if (partialOpDef.isSymbolTable()) {
    // `initialize` pushed this operation's use map as a symbol scope. Pop it
    // so the scope stack is balanced once parsing completes. Otherwise later
    // symbol uses would be recorded into a map that has already been
    // resolved. The identity check keeps this safe if the scope stack was
    // populated differently.
    if (!impl->symbolUseScopes.empty() &&
        impl->symbolUseScopes.back() == partialOpDef.symbolTable.get())
      impl->symbolUseScopes.pop_back();
    impl->symbolTableOperations.emplace_back(
        topLevelOp, std::move(partialOpDef.symbolTable));
  }
  impl->resolveSymbolUses();
}
void AsmParserState::startOperationDefinition(const OperationName &opName) {
impl->partialOperations.emplace_back(opName);
}
/// Close the innermost partial definition and record the full definition of
/// `op`. This includes its name range, end location and result groups; each
/// group is a (start index, location) pair.
void AsmParserState::finalizeOperationDefinition(
    Operation *op, SMRange nameLoc, SMLoc endLoc,
    ArrayRef<std::pair<unsigned, SMLoc>> resultGroups) {
  assert(!impl->partialOperations.empty() &&
         "expected valid partial operation definition");
  Impl::PartialOpDef partialOpDef = impl->partialOperations.pop_back_val();

  // Build the definition, converting each result group's location into the
  // range of its token.
  auto def = std::make_unique<OperationDefinition>(op, nameLoc, endLoc);
  for (const std::pair<unsigned, SMLoc> &group : resultGroups) {
    SMRange groupRange = convertIdLocToRange(group.second);
    def->resultGroups.emplace_back(group.first, groupRange);
  }

  // Store the definition and index it by operation.
  unsigned defIdx = impl->operations.size();
  impl->operationToIdx.try_emplace(op, defIdx);
  impl->operations.push_back(std::move(def));

  // Symbol table operations keep their collected symbol uses. These are
  // resolved later by `resolveSymbolUses`.
  if (!partialOpDef.isSymbolTable())
    return;
  impl->symbolTableOperations.emplace_back(
      op, std::move(partialOpDef.symbolTable));
}
void AsmParserState::startRegionDefinition() {
assert(!impl->partialOperations.empty() &&
"expected valid partial operation definition");
// If the parent operation of this region is a symbol table, we also push a
// new symbol scope.
Impl::PartialOpDef &partialOpDef = impl->partialOperations.back();
if (partialOpDef.isSymbolTable())
impl->symbolUseScopes.push_back(partialOpDef.symbolTable.get());
}
void AsmParserState::finalizeRegionDefinition() {
assert(!impl->partialOperations.empty() &&
"expected valid partial operation definition");
// If the parent operation of this region is a symbol table, pop the symbol
// scope for this region.
Impl::PartialOpDef &partialOpDef = impl->partialOperations.back();
if (partialOpDef.isSymbolTable())
impl->symbolUseScopes.pop_back();
}
/// Record that `block` is defined at `location`.
void AsmParserState::addDefinition(Block *block, SMLoc location) {
  // Try to claim the next slot for this block. If the block already has an
  // entry, it came from a forward reference (see the `addUses` overload for
  // blocks), and only its definition location still needs filling in.
  auto insertResult =
      impl->blocksToIdx.try_emplace(block, impl->blocks.size());
  if (!insertResult.second) {
    impl->blocks[insertResult.first->second]->definition.loc =
        convertIdLocToRange(location);
    return;
  }
  impl->blocks.emplace_back(
      std::make_unique<BlockDefinition>(block, convertIdLocToRange(location)));
}
/// Record that the block argument `blockArg` is defined at `location`. The
/// owner block must already have an entry.
void AsmParserState::addDefinition(BlockArgument blockArg,
                                   SMLoc location) {
  auto blockIt = impl->blocksToIdx.find(blockArg.getOwner());
  assert(blockIt != impl->blocksToIdx.end() &&
         "expected owner block to have an entry");
  BlockDefinition &ownerDef = *impl->blocks[blockIt->second];
  // Grow the argument list on demand so a slot exists for this argument.
  unsigned argNo = blockArg.getArgNumber();
  if (argNo >= ownerDef.arguments.size())
    ownerDef.arguments.resize(argNo + 1);
  ownerDef.arguments[argNo] = SMDefinition(convertIdLocToRange(location));
}
/// Record uses of `value` at each of `locations`.
void AsmParserState::addUses(Value value, ArrayRef<SMLoc> locations) {
  // Block argument: attach the uses to the argument's definition inside its
  // owner block, which must already have been recorded.
  if (BlockArgument arg = value.dyn_cast<BlockArgument>()) {
    auto blockIt = impl->blocksToIdx.find(arg.getOwner());
    assert(blockIt != impl->blocksToIdx.end() &&
           "expected valid block definition for block argument");
    SMDefinition &argDef =
        impl->blocks[blockIt->second]->arguments[arg.getArgNumber()];
    for (SMLoc useLoc : locations)
      argDef.uses.emplace_back(convertIdLocToRange(useLoc));
    return;
  }

  // Otherwise this is an operation result.
  OpResult result = value.cast<OpResult>();
  auto opIt = impl->operationToIdx.find(result.getOwner());
  if (opIt == impl->operationToIdx.end()) {
    // The defining operation has not been recorded yet, so the value is
    // treated as a placeholder. Its uses are stashed until `refineDefinition`
    // moves them onto the real value.
    SmallVector<SMLoc> &pendingUses = impl->placeholderValueUses[value];
    pendingUses.append(locations.begin(), locations.end());
    return;
  }

  // Result groups are sorted by increasing start index. The group holding
  // this result is therefore the last one whose start index is <= the
  // result number.
  unsigned resultNo = result.getResultNumber();
  OperationDefinition &opDef = *impl->operations[opIt->second];
  for (auto &group : llvm::reverse(opDef.resultGroups)) {
    if (group.startIndex > resultNo)
      continue;
    for (SMLoc useLoc : locations)
      group.definition.uses.push_back(convertIdLocToRange(useLoc));
    return;
  }
  llvm_unreachable("expected valid result group for value use");
}
/// Record uses of `block` at each of `locations`. A use may appear before the
/// block's definition; in that case an entry is created now, and a later
/// `addDefinition(Block *, SMLoc)` fills in its definition location.
void AsmParserState::addUses(Block *block, ArrayRef<SMLoc> locations) {
  auto insertResult =
      impl->blocksToIdx.try_emplace(block, impl->blocks.size());
  if (insertResult.second)
    impl->blocks.emplace_back(std::make_unique<BlockDefinition>(block));
  BlockDefinition &blockDef = *impl->blocks[insertResult.first->second];
  for (SMLoc useLoc : locations)
    blockDef.definition.uses.push_back(convertIdLocToRange(useLoc));
}
/// Record a use of the symbol reference `refAttr`. `locations` holds one range
/// for the root reference and one for each nested reference. The use goes
/// into the innermost active symbol use scope; uses outside any scope are
/// ignored.
void AsmParserState::addUses(SymbolRefAttr refAttr,
                             ArrayRef<SMRange> locations) {
  if (impl->symbolUseScopes.empty())
    return;
  assert((refAttr.getNestedReferences().size() + 1) == locations.size() &&
         "expected the same number of references as provided locations");
  Impl::SymbolUseMap &innermostScope = *impl->symbolUseScopes.back();
  innermostScope[refAttr].emplace_back(locations.begin(), locations.end());
}
/// Move the uses recorded for the placeholder `oldValue` onto its real
/// definition `newValue`, and drop the placeholder entry.
void AsmParserState::refineDefinition(Value oldValue, Value newValue) {
  auto it = impl->placeholderValueUses.find(oldValue);
  assert(it != impl->placeholderValueUses.end() &&
         "expected `oldValue` to be a placeholder");
  // Take ownership of the uses and erase the entry *before* calling addUses.
  // If `newValue` is itself still a placeholder, addUses inserts into
  // `placeholderValueUses`. That can grow (rehash) the map, which would
  // invalidate `it` and the storage a reference to `it->second` points into.
  // Erasing via the iterator also avoids a second hash lookup.
  SmallVector<SMLoc> uses = std::move(it->second);
  impl->placeholderValueUses.erase(it);
  addUses(newValue, uses);
}
|