summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthias Springer <springerm@google.com>2021-12-14 21:21:15 +0900
committerMatthias Springer <springerm@google.com>2021-12-14 21:29:43 +0900
commit81eece7f2693fef4b85cc5a93c5894255b651e56 (patch)
tree4ef47a37b43eebf7a032fe5fa0ab3fd59ce825fc
parent6847379e895bfe4c6294f09f4e0d2c3bd949846e (diff)
downloadllvm-81eece7f2693fef4b85cc5a93c5894255b651e56.tar.gz
[mlir][linalg][bufferize] Debug output as IR attributes
Instead of printing analysis debug information to stderr, annotate the IR. This makes it easier to understand decisions made by the analysis, especially in larger input IR. Differential Revision: https://reviews.llvm.org/D115575
-rw-r--r--mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h4
-rw-r--r--mlir/include/mlir/Dialect/Linalg/Passes.td3
-rw-r--r--mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp139
-rw-r--r--mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp11
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp1
5 files changed, 45 insertions, 113 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index df327aa5e243..36e539a42992 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -140,6 +140,10 @@ struct BufferizationOptions {
/// checking the results of the analysis) and post analysis steps.
bool testAnalysisOnly = false;
+ /// If set to `true`, the IR is annotated with details about RaW conflicts.
+ /// For debugging only. Should be used together with `testAnalysisOnly`.
+ bool printConflicts = false;
+
/// Registered post analysis steps.
PostAnalysisStepList postAnalysisSteps;
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index cde61ff98ab8..504bc562148f 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -39,6 +39,9 @@ def LinalgComprehensiveModuleBufferize :
Option<"testAnalysisOnly", "test-analysis-only", "bool",
/*default=*/"false",
"Only runs inplaceability analysis (for testing purposes only)">,
+ Option<"printConflicts", "print-conflicts", "bool",
+ /*default=*/"false",
+ "Annotates IR with RaW conflicts. Requires test-analysis-only.">,
Option<"allowReturnMemref", "allow-return-memref", "bool",
/*default=*/"false",
"Allows the return of memrefs (for testing purposes only)">,
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index 6cfae7cc702e..d6c36f3b98a5 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -117,25 +117,12 @@
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"
-#include "llvm/Support/Debug.h"
-#include "llvm/Support/FormatVariadic.h"
-
-#define DEBUG_TYPE "comprehensive-module-bufferize"
using namespace mlir;
using namespace linalg;
using namespace tensor;
using namespace comprehensive_bufferize;
-#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
-#define LDBG(X) LLVM_DEBUG(DBGS() << X)
-
-// Forward declarations.
-#ifndef NDEBUG
-static std::string printOperationInfo(Operation *, bool prefix = true);
-static std::string printValueInfo(Value, bool prefix = true);
-#endif
-
static bool isaTensor(Type t) { return t.isa<TensorType>(); }
//===----------------------------------------------------------------------===//
@@ -164,65 +151,12 @@ static void setInPlaceOpResult(OpResult opResult, bool inPlace) {
attr ? SmallVector<StringRef>(
llvm::to_vector<4>(attr.getAsValueRange<StringAttr>()))
: SmallVector<StringRef>(op->getNumResults(), "false");
- LDBG("->set inPlace=" << inPlace << " <- #" << opResult.getResultNumber()
- << ": " << printOperationInfo(op) << "\n");
inPlaceVector[opResult.getResultNumber()] = inPlace ? "true" : "false";
op->setAttr(kInPlaceResultsAttrName,
OpBuilder(op).getStrArrayAttr(inPlaceVector));
}
//===----------------------------------------------------------------------===//
-// Printing helpers.
-//===----------------------------------------------------------------------===//
-
-#ifndef NDEBUG
-/// Helper method printing the bufferization information of a buffer / tensor.
-static void printTensorOrBufferInfo(std::string prefix, Value value,
- AsmState &state, llvm::raw_ostream &os) {
- if (!value.getType().isa<ShapedType>())
- return;
- os << prefix;
- value.printAsOperand(os, state);
- os << " : " << value.getType();
-}
-
-/// Print the operation name and bufferization information.
-static std::string printOperationInfo(Operation *op, bool prefix) {
- std::string result;
- llvm::raw_string_ostream os(result);
- AsmState state(op->getParentOfType<mlir::FuncOp>());
- StringRef tab = prefix ? "\n[" DEBUG_TYPE "]\t" : "";
- os << tab << op->getName();
- SmallVector<Value> shapedOperands;
- for (OpOperand &opOperand : op->getOpOperands()) {
- std::string prefix =
- llvm::formatv("{0} -> #{1} ", tab, opOperand.getOperandNumber());
- printTensorOrBufferInfo(prefix, opOperand.get(), state, os);
- }
- for (OpResult opResult : op->getOpResults()) {
- std::string prefix =
- llvm::formatv("{0} <- #{1} ", tab, opResult.getResultNumber());
- printTensorOrBufferInfo(prefix, opResult, state, os);
- }
- return result;
-}
-
-/// Print the bufferization information for the defining op or block argument.
-static std::string printValueInfo(Value value, bool prefix) {
- auto *op = value.getDefiningOp();
- if (op)
- return printOperationInfo(op, prefix);
- // Print the block argument bufferization information.
- std::string result;
- llvm::raw_string_ostream os(result);
- AsmState state(value.getParentRegion()->getParentOfType<mlir::FuncOp>());
- os << value;
- printTensorOrBufferInfo("\n\t - ", value, state, os);
- return result;
-}
-#endif
-
-//===----------------------------------------------------------------------===//
// Bufferization-specific alias analysis.
//===----------------------------------------------------------------------===//
@@ -251,7 +185,6 @@ static bool isInplaceMemoryWrite(OpOperand &opOperand,
static bool aliasesNonWritableBuffer(Value value,
const BufferizationAliasInfo &aliasInfo,
BufferizationState &state) {
- LDBG("WRITABILITY ANALYSIS FOR " << printValueInfo(value) << "\n");
bool foundNonWritableBuffer = false;
aliasInfo.applyOnAliases(value, [&](Value v) {
// Query BufferizableOpInterface to see if the OpResult is writable.
@@ -270,11 +203,6 @@ static bool aliasesNonWritableBuffer(Value value,
foundNonWritableBuffer = true;
});
- if (foundNonWritableBuffer)
- LDBG("--> NON WRITABLE\n");
- else
- LDBG("--> WRITABLE\n");
-
return foundNonWritableBuffer;
}
@@ -282,23 +210,15 @@ static bool aliasesNonWritableBuffer(Value value,
/// to some buffer write.
static bool aliasesInPlaceWrite(Value value,
const BufferizationAliasInfo &aliasInfo) {
- LDBG("----Start aliasesInPlaceWrite\n");
- LDBG("-------for : " << printValueInfo(value) << '\n');
bool foundInplaceWrite = false;
aliasInfo.applyOnAliases(value, [&](Value v) {
for (auto &use : v.getUses()) {
if (isInplaceMemoryWrite(use, aliasInfo)) {
- LDBG("-----------wants to bufferize to inPlace write: "
- << printOperationInfo(use.getOwner()) << '\n');
foundInplaceWrite = true;
return;
}
}
});
-
- if (!foundInplaceWrite)
- LDBG("----------->does not alias an inplace write\n");
-
return foundInplaceWrite;
}
@@ -317,6 +237,39 @@ static bool happensBefore(Operation *a, Operation *b,
return false;
}
+/// Annotate IR with details about the detected RaW conflict.
+static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
+ Value lastWrite) {
+ static uint64_t counter = 0;
+ Operation *readingOp = uRead->getOwner();
+ Operation *conflictingWritingOp = uConflictingWrite->getOwner();
+
+ OpBuilder b(conflictingWritingOp->getContext());
+ std::string id = "C_" + std::to_string(counter++);
+
+ std::string conflictingWriteAttr =
+ id +
+ "[CONFL-WRITE: " + std::to_string(uConflictingWrite->getOperandNumber()) +
+ "]";
+ conflictingWritingOp->setAttr(conflictingWriteAttr, b.getUnitAttr());
+
+ std::string readAttr =
+ id + "[READ: " + std::to_string(uRead->getOperandNumber()) + "]";
+ readingOp->setAttr(readAttr, b.getUnitAttr());
+
+ if (auto opResult = lastWrite.dyn_cast<OpResult>()) {
+ std::string lastWriteAttr = id + "[LAST-WRITE: result " +
+ std::to_string(opResult.getResultNumber()) +
+ "]";
+ opResult.getDefiningOp()->setAttr(lastWriteAttr, b.getUnitAttr());
+ } else {
+ auto bbArg = lastWrite.cast<BlockArgument>();
+ std::string lastWriteAttr =
+ id + "[LAST-WRITE: bbArg " + std::to_string(bbArg.getArgNumber()) + "]";
+ bbArg.getOwner()->getParentOp()->setAttr(lastWriteAttr, b.getUnitAttr());
+ }
+}
+
/// Given sets of uses and writes, return true if there is a RaW conflict under
/// the assumption that all given reads/writes alias the same buffer and that
/// all given writes bufferize inplace.
@@ -351,14 +304,6 @@ static bool hasReadAfterWriteInterference(
// met for uConflictingWrite to be an actual conflict.
Operation *conflictingWritingOp = uConflictingWrite->getOwner();
- // Print some debug info.
- LDBG("Found potential conflict:\n");
- LDBG("READ = #" << uRead->getOperandNumber() << " of "
- << printOperationInfo(readingOp) << "\n");
- LDBG("CONFLICTING WRITE = #"
- << uConflictingWrite->getOperandNumber() << " of "
- << printOperationInfo(conflictingWritingOp) << "\n");
-
// No conflict if the readingOp dominates conflictingWritingOp, i.e., the
// write is not visible when reading.
if (happensBefore(readingOp, conflictingWritingOp, domInfo))
@@ -387,8 +332,6 @@ static bool hasReadAfterWriteInterference(
if (insideMutuallyExclusiveRegions(readingOp, conflictingWritingOp))
continue;
- LDBG("WRITE = #" << printValueInfo(lastWrite) << "\n");
-
// No conflict if the conflicting write happens before the last
// write.
if (Operation *writingOp = lastWrite.getDefiningOp()) {
@@ -413,12 +356,14 @@ static bool hasReadAfterWriteInterference(
continue;
// All requirements are met. Conflict found!
- LDBG("CONFLICT CONFIRMED!\n\n");
+
+ if (options.printConflicts)
+ annotateConflict(uRead, uConflictingWrite, lastWrite);
+
return true;
}
}
- LDBG("NOT A CONFLICT!\n\n");
return false;
}
@@ -530,7 +475,6 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand, OpResult opResult,
if (!hasWrite)
return false;
- LDBG("->the corresponding buffer is not writeable\n");
return true;
}
@@ -548,13 +492,6 @@ static LogicalResult bufferizableInPlaceAnalysisImpl(
"operand and result do not match");
#endif // NDEBUG
- int64_t resultNumber = result.getResultNumber();
- (void)resultNumber;
- LDBG('\n');
- LDBG("Inplace analysis for <- #" << resultNumber << " -> #"
- << operand.getOperandNumber() << " in "
- << printValueInfo(result) << '\n');
-
bool foundInterference =
wouldCreateWriteToNonWritableBuffer(operand, result, aliasInfo, state) ||
wouldCreateReadAfterWriteInterference(operand, result, domInfo, state,
@@ -565,8 +502,6 @@ static LogicalResult bufferizableInPlaceAnalysisImpl(
else
aliasInfo.bufferizeInPlace(result, operand);
- LDBG("Done inplace analysis for result #" << resultNumber << '\n');
-
return success();
}
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index 15fb9abd759e..cb0d75cb366f 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -14,12 +14,6 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Operation.h"
-#include "llvm/Support/Debug.h"
-#include "llvm/Support/FormatVariadic.h"
-
-#define DEBUG_TYPE "comprehensive-module-bufferize"
-#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
-#define LDBG(X) LLVM_DEBUG(DBGS() << X)
using namespace mlir;
using namespace linalg;
@@ -181,7 +175,6 @@ static FunctionType getOrCreateBufferizedFunctionType(
auto it2 = bufferizedFunctionTypes.try_emplace(
funcOp, getBufferizedFunctionType(funcOp.getContext(), argumentTypes,
resultTypes));
- LDBG("FT: " << funcOp.getType() << " -> " << it2.first->second << "\n");
return it2.first->second;
}
@@ -227,7 +220,6 @@ static void equivalenceAnalysis(FuncOp funcOp,
/// future.
static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
BufferizationState &state) {
- LLVM_DEBUG(DBGS() << "Begin bufferizeFuncOpBoundary:\n" << funcOp << "\n");
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
// If nothing to do then we are done.
@@ -261,7 +253,6 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
funcOp, funcOp.getType().getInputs(), TypeRange{},
moduleState.bufferizedFunctionTypes);
funcOp.setType(bufferizedFuncType);
- LLVM_DEBUG(DBGS() << "End bufferizeFuncOpBoundary no fun body: " << funcOp);
return success();
}
@@ -341,8 +332,6 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
// 4. Rewrite the FuncOp type to buffer form.
funcOp.setType(bufferizedFuncType);
- LLVM_DEBUG(DBGS() << "End bufferizeFuncOpBoundary:\n" << funcOp);
-
return success();
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index 58f5845f0bf4..2ce198b86bce 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -89,6 +89,7 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
options.allowUnknownOps = allowUnknownOps;
options.analysisFuzzerSeed = analysisFuzzerSeed;
options.testAnalysisOnly = testAnalysisOnly;
+ options.printConflicts = printConflicts;
// Enable InitTensorOp elimination.
options.addPostAnalysisStep<