[Mlir-commits] [mlir] 26f93d9 - [mlir] OperationFolder: fix crash in creation of single-result-ops with in-place folds
Alex Zinenko
llvmlistbot at llvm.org
Wed May 6 11:40:43 PDT 2020
Author: Alex Zinenko
Date: 2020-05-06T20:40:32+02:00
New Revision: 26f93d9f373a1e638b621391ef7ba9bdf7b79044
URL: https://github.com/llvm/llvm-project/commit/26f93d9f373a1e638b621391ef7ba9bdf7b79044
DIFF: https://github.com/llvm/llvm-project/commit/26f93d9f373a1e638b621391ef7ba9bdf7b79044.diff
LOG: [mlir] OperationFolder: fix crash in creation of single-result-ops with in-place folds
When the folding is performed in place, the `::fold` function does not populate
its `results` argument to indicate that. (In the folding hook for single-result
operations, the result of the original operation is expected to be returned,
but it is then ignored by the wrapper.) `OperationFolder::create` would
erroneously rely on the _operation_ having zero results instead of on the
_folding_ producing zero new results to populate the list of results with those
of the original operation. This would lead to a crash for single-result ops
with in-place folds where the first result is accessed unconditionally because
the list of results was not properly populated. Use the list of values produced
by the folding instead.
Differential Revision: https://reviews.llvm.org/D79497
Added:
Modified:
mlir/include/mlir/Transforms/FoldUtils.h
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/lib/Dialect/Test/TestPatterns.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Transforms/FoldUtils.h b/mlir/include/mlir/Transforms/FoldUtils.h
index f8c678d11b6a..d427f0b2406d 100644
--- a/mlir/include/mlir/Transforms/FoldUtils.h
+++ b/mlir/include/mlir/Transforms/FoldUtils.h
@@ -77,14 +77,14 @@ class OperationFolder {
void create(OpBuilder &builder, SmallVectorImpl<Value> &results,
Location location, Args &&... args) {
// The op needs to be inserted only if the fold (below) fails, or the number
- // of results of the op is zero (which is treated as an in-place
- // fold). Using create methods of the builder will insert the op, so not
- // using it here.
+ // of results produced by the successful folding is zero (which is treated
+ // as an in-place fold). Using create methods of the builder will insert the
+ // op, so not using it here.
OperationState state(location, OpTy::getOperationName());
OpTy::build(builder, state, std::forward<Args>(args)...);
Operation *op = Operation::create(state);
- if (failed(tryToFold(builder, op, results)) || op->getNumResults() == 0) {
+ if (failed(tryToFold(builder, op, results)) || results.empty()) {
builder.insert(op);
results.assign(op->result_begin(), op->result_end());
return;
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 1a40f9989eae..fb7acde61b98 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -323,6 +323,15 @@ LogicalResult TestOpWithVariadicResultsAndFolder::fold(
return success();
}
+OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
+ assert(operands.size() == 1);
+ if (operands.front()) {
+ setAttr("attr", operands.front());
+ return getResult();
+ }
+ return {};
+}
+
LogicalResult mlir::OpWithInferTypeInterfaceOp::inferReturnTypes(
MLIRContext *, Optional<Location> location, ValueRange operands,
ArrayRef<NamedAttribute> attributes, RegionRange regions,
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index f9140f2e9bdc..3e49a1d9a1f9 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -734,6 +734,17 @@ def TestCommutativeOp : TEST_Op<"op_commutative", [Commutative]> {
let results = (outs I32);
}
+def TestOpInPlaceFoldAnchor : TEST_Op<"op_in_place_fold_anchor"> {
+ let arguments = (ins I32);
+ let results = (outs I32);
+}
+
+def TestOpInPlaceFold : TEST_Op<"op_in_place_fold"> {
+ let arguments = (ins I32:$op, I32Attr:$attr);
+ let results = (outs I32);
+ let hasFolder = 1;
+}
+
//===----------------------------------------------------------------------===//
// Test Patterns (Symbol Binding)
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index deb1cf5bb075..7c91d5f6c682 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -8,9 +8,11 @@
#include "TestDialect.h"
#include "mlir/Conversion/StandardToStandard/StandardToStandard.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/FoldUtils.h"
using namespace mlir;
@@ -39,13 +41,36 @@ namespace {
//===----------------------------------------------------------------------===//
namespace {
+struct FoldingPattern : public RewritePattern {
+public:
+ FoldingPattern(MLIRContext *context)
+ : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(),
+ /*benefit=*/1, context) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ // Exercise OperationFolder API for a single-result operation that is folded
+ // upon construction. The operation being created through the folder has an
+ // in-place folder, and it should be still present in the output.
+ // Furthermore, the folder should not crash when attempting to recover the
+ // (unchanged) operation result.
+ OperationFolder folder(op->getContext());
+ Value result = folder.create<TestOpInPlaceFold>(
+ rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0),
+ rewriter.getI32IntegerAttr(0));
+ assert(result);
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
struct TestPatternDriver : public PassWrapper<TestPatternDriver, FunctionPass> {
void runOnFunction() override {
mlir::OwningRewritePatternList patterns;
populateWithGenerated(&getContext(), &patterns);
// Verify named pattern is generated with expected name.
- patterns.insert<TestNamedPatternRule>(&getContext());
+ patterns.insert<FoldingPattern, TestNamedPatternRule>(&getContext());
applyPatternsAndFoldGreedily(getFunction(), patterns);
}
More information about the Mlir-commits
mailing list