[Mlir-commits] [mlir] [MLIR] Harmonize the behavior of the folding API functions (PR #88508)
Christian Ulmann
llvmlistbot at llvm.org
Mon Apr 22 22:47:16 PDT 2024
https://github.com/Dinistro updated https://github.com/llvm/llvm-project/pull/88508
>From 98747419d51c4f397bbde5146895c2e9f70d93c5 Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Fri, 12 Apr 2024 12:35:29 +0000
Subject: [PATCH 1/2] [MLIR] Harmonize the behavior of the folding API
functions
This commit changes `OpBuilder::tryFold` to behave more similarly to
`Operation::fold`. Concretely, this ensures that even an in-place fold
returns `success`. This is necessary to fix a bug in the dialect
conversion that occurred when an in-place fold made an operation
legal. The dialect conversion infrastructure did not check whether the
result of an in-place fold legalized the operation and instead went
ahead and tried to apply patterns anyway.
---
mlir/include/mlir/IR/Builders.h | 16 +++++++++++----
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 3 ++-
mlir/lib/IR/Builders.cpp | 20 ++++++++++---------
.../Transforms/Utils/DialectConversion.cpp | 4 ++++
mlir/test/Transforms/test-legalizer.mlir | 10 ++++++++++
mlir/test/lib/Dialect/Test/TestOpDefs.cpp | 13 ++++++++++++
mlir/test/lib/Dialect/Test/TestOps.td | 6 ++++++
mlir/test/lib/Dialect/Test/TestPatterns.cpp | 6 ++++++
8 files changed, 64 insertions(+), 14 deletions(-)
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 3beade017d1ab9..e74505e5dbfdf4 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -517,7 +517,7 @@ class OpBuilder : public Builder {
/// Create an operation of specific op type at the current insertion point,
/// and immediately try to fold it. This functions populates 'results' with
- /// the results after folding the operation.
+ /// the results of the operation.
template <typename OpTy, typename... Args>
void createOrFold(SmallVectorImpl<Value> &results, Location location,
Args &&...args) {
@@ -530,10 +530,17 @@ class OpBuilder : public Builder {
if (block)
block->getOperations().insert(insertPoint, op);
- // Fold the operation. If successful erase it, otherwise notify.
- if (succeeded(tryFold(op, results)))
+ // Attempt to fold the operation.
+ if (succeeded(tryFold(op, results)) && !results.empty()) {
+ // Erase the operation if the fold removed the need for it.
+ // Note: The fold already populated the results in this case.
op->erase();
- else if (block && listener)
+ return;
+ }
+
+ ResultRange opResults = op->getResults();
+ results.assign(opResults.begin(), opResults.end());
+ if (block && listener)
listener->notifyOperationInserted(op, /*previous=*/{});
}
@@ -561,6 +568,7 @@ class OpBuilder : public Builder {
/// Attempts to fold the given operation and places new results within
/// 'results'. Returns success if the operation was folded, failure otherwise.
+ /// If the fold was in-place, `results` will not be filled.
/// Note: This function does not erase the operation on a successful fold.
LogicalResult tryFold(Operation *op, SmallVectorImpl<Value> &results);
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 78ff24dae68b4c..4e06b9c127e76a 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2831,7 +2831,8 @@ LogicalResult SExtOp::verify() { return verifyExtOp<SExtOp>(*this); }
/// Folds a cast op that can be chained.
template <typename T>
-static Value foldChainableCast(T castOp, typename T::FoldAdaptor adaptor) {
+static OpFoldResult foldChainableCast(T castOp,
+ typename T::FoldAdaptor adaptor) {
// cast(x : T0, T0) -> x
if (castOp.getArg().getType() == castOp.getType())
return castOp.getArg();
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 18ca3c332e0204..36e17609eab609 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -476,16 +476,14 @@ Operation *OpBuilder::create(Location loc, StringAttr opName,
return create(state);
}
-/// Attempts to fold the given operation and places new results within
-/// 'results'. Returns success if the operation was folded, failure otherwise.
-/// Note: This function does not erase the operation on a successful fold.
LogicalResult OpBuilder::tryFold(Operation *op,
SmallVectorImpl<Value> &results) {
+ assert(results.empty());
ResultRange opResults = op->getResults();
results.reserve(opResults.size());
auto cleanupFailure = [&] {
- results.assign(opResults.begin(), opResults.end());
+ results.clear();
return failure();
};
@@ -495,20 +493,24 @@ LogicalResult OpBuilder::tryFold(Operation *op,
// Try to fold the operation.
SmallVector<OpFoldResult, 4> foldResults;
- if (failed(op->fold(foldResults)) || foldResults.empty())
+ if (failed(op->fold(foldResults)))
return cleanupFailure();
+ // An in-place fold does not require generation of any constants.
+ if (foldResults.empty())
+ return success();
+
// A temporary builder used for creating constants during folding.
OpBuilder cstBuilder(context);
SmallVector<Operation *, 1> generatedConstants;
// Populate the results with the folded results.
Dialect *dialect = op->getDialect();
- for (auto it : llvm::zip_equal(foldResults, opResults.getTypes())) {
- Type expectedType = std::get<1>(it);
+ for (auto [foldResult, expectedType] :
+ llvm::zip_equal(foldResults, opResults.getTypes())) {
// Normal values get pushed back directly.
- if (auto value = llvm::dyn_cast_if_present<Value>(std::get<0>(it))) {
+ if (auto value = llvm::dyn_cast_if_present<Value>(foldResult)) {
results.push_back(value);
continue;
}
@@ -518,7 +520,7 @@ LogicalResult OpBuilder::tryFold(Operation *op,
return cleanupFailure();
// Ask the dialect to materialize a constant operation for this value.
- Attribute attr = std::get<0>(it).get<Attribute>();
+ Attribute attr = foldResult.get<Attribute>();
auto *constOp = dialect->materializeConstant(cstBuilder, attr, expectedType,
op->getLoc());
if (!constOp) {
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index d85938847c776c..d407d60334c70d 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2072,6 +2072,10 @@ OperationLegalizer::legalizeWithFold(Operation *op,
LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
return failure();
}
+ // An empty list of replacement values indicates that the fold was in-place.
+ // As the operation changed, a new legalization needs to be attempted.
+ if (replacementValues.empty())
+ return legalize(op, rewriter);
// Insert a replacement for 'op' with the folded replacement values.
rewriter.replaceOp(op, replacementValues);
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index d552f0346644b3..7530b300d57b8b 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -427,3 +427,13 @@ func.func @use_of_replaced_bbarg(%arg0: i64) {
}) : (i64) -> (i64)
"test.invalid"(%0) : (i64) -> ()
}
+
+// -----
+
+// CHECK-LABEL: @fold_legalization
+func.func @fold_legalization() -> i32 {
+ // CHECK: op_in_place_self_fold
+ // CHECK-SAME: folded = true
+ %1 = "test.op_in_place_self_fold"() : () -> (i32)
+ "test.return"(%1) : (i32) -> ()
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 7263774ca158eb..08df2e5e12286d 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -825,6 +825,19 @@ LogicalResult CompareOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// TestOpInPlaceSelfFold
+//===----------------------------------------------------------------------===//
+
+OpFoldResult TestOpInPlaceSelfFold::fold(FoldAdaptor adaptor) {
+ if (!getFolded()) {
+ // The folder adds the "folded" attribute if it is not present.
+ setFolded(true);
+ return getResult();
+ }
+ return {};
+}
+
//===----------------------------------------------------------------------===//
// TestOpFoldWithFoldAdaptor
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index b641b3da719c78..ef5fd9e7e520b0 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1351,6 +1351,12 @@ def TestOpInPlaceFold : TEST_Op<"op_in_place_fold"> {
let hasFolder = 1;
}
+def TestOpInPlaceSelfFold : TEST_Op<"op_in_place_self_fold"> {
+ let arguments = (ins OptionalAttr<BoolAttr>:$folded);
+ let results = (outs I32);
+ let hasFolder = 1;
+}
+
// Test op that simply returns success.
def TestOpInPlaceFoldSuccess : TEST_Op<"op_in_place_fold_success"> {
let results = (outs Variadic<I1>);
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 0c1731ba5f07c8..0c09136d63715a 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1168,6 +1168,12 @@ struct TestLegalizePatternDriver
target.addDynamicallyLegalOp<TestRecursiveRewriteOp>(
[](TestRecursiveRewriteOp op) { return op.getDepth() == 0; });
+ // Create a dynamically legal rule that can only be legalized by folding it.
+ target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>(
+ [](TestOpInPlaceSelfFold op) {
+ return op.getProperties().folded != nullptr;
+ });
+
// Handle a partial conversion.
if (mode == ConversionMode::Partial) {
DenseSet<Operation *> unlegalizedOps;
>From c4e1f796cb032bc209e710b6793495a78c52336a Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Mon, 22 Apr 2024 20:13:14 +0000
Subject: [PATCH 2/2] address review comments
---
mlir/include/mlir/IR/Builders.h | 2 +-
mlir/lib/IR/Builders.cpp | 2 +-
mlir/test/Transforms/test-legalizer.mlir | 2 +-
mlir/test/lib/Dialect/Test/TestOps.td | 2 +-
mlir/test/lib/Dialect/Test/TestPatterns.cpp | 4 +---
5 files changed, 5 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index e74505e5dbfdf4..0d5fa719d0dee2 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -567,7 +567,7 @@ class OpBuilder : public Builder {
}
/// Attempts to fold the given operation and places new results within
- /// 'results'. Returns success if the operation was folded, failure otherwise.
+ /// `results`. Returns success if the operation was folded, failure otherwise.
/// If the fold was in-place, `results` will not be filled.
/// Note: This function does not erase the operation on a successful fold.
LogicalResult tryFold(Operation *op, SmallVectorImpl<Value> &results);
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 36e17609eab609..d49f69a7b7ae6b 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -478,7 +478,7 @@ Operation *OpBuilder::create(Location loc, StringAttr opName,
LogicalResult OpBuilder::tryFold(Operation *op,
SmallVectorImpl<Value> &results) {
- assert(results.empty());
+ assert(results.empty() && "expected empty results");
ResultRange opResults = op->getResults();
results.reserve(opResults.size());
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 7530b300d57b8b..65c947198e06e0 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -433,7 +433,7 @@ func.func @use_of_replaced_bbarg(%arg0: i64) {
// CHECK-LABEL: @fold_legalization
func.func @fold_legalization() -> i32 {
// CHECK: op_in_place_self_fold
- // CHECK-SAME: folded = true
+ // CHECK-SAME: folded
%1 = "test.op_in_place_self_fold"() : () -> (i32)
"test.return"(%1) : (i32) -> ()
}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index ef5fd9e7e520b0..5352d574ac3943 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1352,7 +1352,7 @@ def TestOpInPlaceFold : TEST_Op<"op_in_place_fold"> {
}
def TestOpInPlaceSelfFold : TEST_Op<"op_in_place_self_fold"> {
- let arguments = (ins OptionalAttr<BoolAttr>:$folded);
+ let arguments = (ins UnitAttr:$folded);
let results = (outs I32);
let hasFolder = 1;
}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 0c09136d63715a..f9f7d4eacf948a 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1170,9 +1170,7 @@ struct TestLegalizePatternDriver
// Create a dynamically legal rule that can only be legalized by folding it.
target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>(
- [](TestOpInPlaceSelfFold op) {
- return op.getProperties().folded != nullptr;
- });
+ [](TestOpInPlaceSelfFold op) { return op.getFolded(); });
// Handle a partial conversion.
if (mode == ConversionMode::Partial) {
More information about the Mlir-commits
mailing list