[Mlir-commits] [mlir] [mlir][Linalg] Bugfix in decompose generic by unfolding permutation (PR #126737)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 13 04:13:52 PST 2025


https://github.com/gdehame updated https://github.com/llvm/llvm-project/pull/126737

>From 6bb4f94ecc5c6ac3849197c8650fada7ad61c537 Mon Sep 17 00:00:00 2001
From: gdehame <gabrieldehame at gmail.com>
Date: Tue, 11 Feb 2025 15:24:47 +0100
Subject: [PATCH 1/2] [mlir][Linalg] Bugfix in decompose generic by unfolding
 permutation

The pattern was returning success() by default, which made the greedy pattern application act as if the IR had been modified even when nothing had changed; this can prevent the driver from converging for no legitimate reason.

This patch makes the rewrite pattern return failure() by default and success() if and only if the IR changed.
---
 .../Transforms/DecomposeGenericByUnfoldingPermutation.cpp      | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
index 83c4b5bdf1097..281a248681792 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
@@ -237,8 +237,9 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
 
     newOp.getRegion().takeBody(op->getRegion(0));
     rewriter.replaceOp(op, newOp->getResults());
+    return success();
   }
-  return success();
+  return failure();
 }
 
 } // namespace

>From 57d7bbfcf0410730fecf97d6419e9b13b974a182 Mon Sep 17 00:00:00 2001
From: gdehame <gabrieldehame at gmail.com>
Date: Thu, 13 Feb 2025 13:12:47 +0100
Subject: [PATCH 2/2] Changed the patch to an early exit

---
 ...DecomposeGenericByUnfoldingPermutation.cpp | 33 +++++++++----------
 1 file changed, 16 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
index 281a248681792..96b581b24064a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
@@ -223,23 +223,22 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
     newMap[i] = rewriter.getMultiDimIdentityMap(map.getNumDims());
   }
 
-  if (isChanged) {
-    SmallVector<Value> operands = op->getOperands();
-    ValueRange operandsRef(operands);
-
-    auto newOp = rewriter.create<linalg::GenericOp>(
-        /*location=*/op.getLoc(),
-        /*resultTensorTypes=*/op->getResultTypes(),
-        /*inputs=*/newInitValues,
-        /*outputs=*/operandsRef.drop_front(op.getNumDpsInputs()),
-        /*indexingMaps=*/newMap,
-        /*iteratorTypes=*/op.getIteratorTypesArray());
-
-    newOp.getRegion().takeBody(op->getRegion(0));
-    rewriter.replaceOp(op, newOp->getResults());
-    return success();
-  }
-  return failure();
+  if (!isChanged)
+    return failure();
+
+  SmallVector<Value> operands = op->getOperands();
+  ValueRange operandsRef(operands);
+  
+  auto newOp = rewriter.create<linalg::GenericOp>(
+      /*location=*/op.getLoc(),
+      /*resultTensorTypes=*/op->getResultTypes(),
+      /*inputs=*/newInitValues,
+      /*outputs=*/operandsRef.drop_front(op.getNumDpsInputs()),
+      /*indexingMaps=*/newMap,
+      /*iteratorTypes=*/op.getIteratorTypesArray());
+  newOp.getRegion().takeBody(op->getRegion(0));
+  rewriter.replaceOp(op, newOp->getResults());
+  return success();
 }
 
 } // namespace



More information about the Mlir-commits mailing list