[Mlir-commits] [mlir] [MLIR] Harmonize the behavior of the folding API functions (PR #88508)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Apr 12 05:56:06 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir-llvm

Author: Christian Ulmann (Dinistro)

<details>
<summary>Changes</summary>

This commit changes `OpBuilder::tryFold` to behave more similarly to `Operation::fold`. Concretely, this ensures that even an in-place fold returns `success`.
This is necessary to fix a bug in the dialect conversion that occurred when an in-place fold made an operation legal. The dialect conversion infrastructure did not check whether the result of an in-place fold legalized the operation and instead went ahead and tried to apply patterns anyway.

The added test contains a simplified version of a breakage we observed downstream.

---
Full diff: https://github.com/llvm/llvm-project/pull/88508.diff


8 Files Affected:

- (modified) mlir/include/mlir/IR/Builders.h (+12-4) 
- (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+2-1) 
- (modified) mlir/lib/IR/Builders.cpp (+11-9) 
- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+4) 
- (modified) mlir/test/Transforms/test-legalizer.mlir (+10) 
- (modified) mlir/test/lib/Dialect/Test/TestDialect.cpp (+9) 
- (modified) mlir/test/lib/Dialect/Test/TestOps.td (+6) 
- (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+6) 


``````````diff
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 3beade017d1ab9..e74505e5dbfdf4 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -517,7 +517,7 @@ class OpBuilder : public Builder {
 
   /// Create an operation of specific op type at the current insertion point,
   /// and immediately try to fold it. This functions populates 'results' with
-  /// the results after folding the operation.
+  /// the results of the operation.
   template <typename OpTy, typename... Args>
   void createOrFold(SmallVectorImpl<Value> &results, Location location,
                     Args &&...args) {
@@ -530,10 +530,17 @@ class OpBuilder : public Builder {
     if (block)
       block->getOperations().insert(insertPoint, op);
 
-    // Fold the operation. If successful erase it, otherwise notify.
-    if (succeeded(tryFold(op, results)))
+    // Attempt to fold the operation.
+    if (succeeded(tryFold(op, results)) && !results.empty()) {
+      // Erase the operation, if the fold removed the need for this operation.
+      // Note: The fold already populated the results in this case.
       op->erase();
-    else if (block && listener)
+      return;
+    }
+
+    ResultRange opResults = op->getResults();
+    results.assign(opResults.begin(), opResults.end());
+    if (block && listener)
       listener->notifyOperationInserted(op, /*previous=*/{});
   }
 
@@ -561,6 +568,7 @@ class OpBuilder : public Builder {
 
   /// Attempts to fold the given operation and places new results within
   /// 'results'. Returns success if the operation was folded, failure otherwise.
+  /// If the fold was in-place, `results` will not be filled.
   /// Note: This function does not erase the operation on a successful fold.
   LogicalResult tryFold(Operation *op, SmallVectorImpl<Value> &results);
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index f90240a67dcc5f..0fff06df39c1e7 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2763,7 +2763,8 @@ LogicalResult SExtOp::verify() { return verifyExtOp<SExtOp>(*this); }
 
 /// Folds a cast op that can be chained.
 template <typename T>
-static Value foldChainableCast(T castOp, typename T::FoldAdaptor adaptor) {
+static OpFoldResult foldChainableCast(T castOp,
+                                      typename T::FoldAdaptor adaptor) {
   // cast(x : T0, T0) -> x
   if (castOp.getArg().getType() == castOp.getType())
     return castOp.getArg();
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 18ca3c332e0204..36e17609eab609 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -476,16 +476,14 @@ Operation *OpBuilder::create(Location loc, StringAttr opName,
   return create(state);
 }
 
-/// Attempts to fold the given operation and places new results within
-/// 'results'. Returns success if the operation was folded, failure otherwise.
-/// Note: This function does not erase the operation on a successful fold.
 LogicalResult OpBuilder::tryFold(Operation *op,
                                  SmallVectorImpl<Value> &results) {
+  assert(results.empty());
   ResultRange opResults = op->getResults();
 
   results.reserve(opResults.size());
   auto cleanupFailure = [&] {
-    results.assign(opResults.begin(), opResults.end());
+    results.clear();
     return failure();
   };
 
@@ -495,20 +493,24 @@ LogicalResult OpBuilder::tryFold(Operation *op,
 
   // Try to fold the operation.
   SmallVector<OpFoldResult, 4> foldResults;
-  if (failed(op->fold(foldResults)) || foldResults.empty())
+  if (failed(op->fold(foldResults)))
     return cleanupFailure();
 
+  // An in-place fold does not require generation of any constants.
+  if (foldResults.empty())
+    return success();
+
   // A temporary builder used for creating constants during folding.
   OpBuilder cstBuilder(context);
   SmallVector<Operation *, 1> generatedConstants;
 
   // Populate the results with the folded results.
   Dialect *dialect = op->getDialect();
-  for (auto it : llvm::zip_equal(foldResults, opResults.getTypes())) {
-    Type expectedType = std::get<1>(it);
+  for (auto [foldResult, expectedType] :
+       llvm::zip_equal(foldResults, opResults.getTypes())) {
 
     // Normal values get pushed back directly.
-    if (auto value = llvm::dyn_cast_if_present<Value>(std::get<0>(it))) {
+    if (auto value = llvm::dyn_cast_if_present<Value>(foldResult)) {
       results.push_back(value);
       continue;
     }
@@ -518,7 +520,7 @@ LogicalResult OpBuilder::tryFold(Operation *op,
       return cleanupFailure();
 
     // Ask the dialect to materialize a constant operation for this value.
-    Attribute attr = std::get<0>(it).get<Attribute>();
+    Attribute attr = foldResult.get<Attribute>();
     auto *constOp = dialect->materializeConstant(cstBuilder, attr, expectedType,
                                                  op->getLoc());
     if (!constOp) {
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 8671c1008902a0..18d6f7daa4bea7 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2072,6 +2072,10 @@ OperationLegalizer::legalizeWithFold(Operation *op,
     LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
     return failure();
   }
+  // An empty list of replacement values indicates that the fold was in-place.
+  // As the operation changed, a new legalization needs to be attempted.
+  if (replacementValues.empty())
+    return legalize(op, rewriter);
 
   // Insert a replacement for 'op' with the folded replacement values.
   rewriter.replaceOp(op, replacementValues);
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index d552f0346644b3..7530b300d57b8b 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -427,3 +427,13 @@ func.func @use_of_replaced_bbarg(%arg0: i64) {
   }) : (i64) -> (i64)
   "test.invalid"(%0) : (i64) -> ()
 }
+
+// -----
+
+// CHECK-LABEL: @fold_legalization
+func.func @fold_legalization() -> i32 {
+  // CHECK: op_in_place_self_fold
+  // CHECK-SAME: folded = true
+  %1 = "test.op_in_place_self_fold"() : () -> (i32)
+  "test.return"(%1) : (i32) -> ()
+}
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 380c74a47e509a..becd0d68bf1d3e 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -568,6 +568,15 @@ OpFoldResult TestOpInPlaceFold::fold(FoldAdaptor adaptor) {
   return {};
 }
 
+OpFoldResult TestOpInPlaceSelfFold::fold(FoldAdaptor adaptor) {
+  if (!getProperties().folded) {
+    // The folder adds the "folded" if not present.
+    getProperties().folded = BoolAttr::get(getContext(), true);
+    return getResult();
+  }
+  return {};
+}
+
 OpFoldResult TestOpFoldWithFoldAdaptor::fold(FoldAdaptor adaptor) {
   int64_t sum = 0;
   if (auto value = dyn_cast_or_null<IntegerAttr>(adaptor.getOp()))
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index e6c3601d08dad0..663064d51f1bbe 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1351,6 +1351,12 @@ def TestOpInPlaceFold : TEST_Op<"op_in_place_fold"> {
   let hasFolder = 1;
 }
 
+def TestOpInPlaceSelfFold : TEST_Op<"op_in_place_self_fold"> {
+  let arguments = (ins OptionalAttr<BoolAttr>:$folded);
+  let results = (outs I32);
+  let hasFolder = 1;
+}
+
 // Test op that simply returns success.
 def TestOpInPlaceFoldSuccess : TEST_Op<"op_in_place_fold_success"> {
   let results = (outs Variadic<I1>);
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 76dc825fe44515..285e39dd9016e1 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1167,6 +1167,12 @@ struct TestLegalizePatternDriver
     target.addDynamicallyLegalOp<TestRecursiveRewriteOp>(
         [](TestRecursiveRewriteOp op) { return op.getDepth() == 0; });
 
+    // Create a dynamically legal rule that can only be legalized by folding it.
+    target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>(
+        [](TestOpInPlaceSelfFold op) {
+          return op.getProperties().folded != nullptr;
+        });
+
     // Handle a partial conversion.
     if (mode == ConversionMode::Partial) {
       DenseSet<Operation *> unlegalizedOps;

``````````

</details>


https://github.com/llvm/llvm-project/pull/88508


More information about the Mlir-commits mailing list