[Mlir-commits] [mlir] [MLIR][Linalg] Fix decompose generic op pattern (PR #117650)

Ian Wood llvmlistbot at llvm.org
Mon Nov 25 16:32:58 PST 2024


https://github.com/IanWood1 created https://github.com/llvm/llvm-project/pull/117650

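The `DecomposeProjectedPermutation` pattern (in `DecomposeGenericByUnfoldingPermutation.cpp`) can return `success()` even when it did not modify the IR. A pattern rewrite driver, such as the greedy driver, treats `success()` as a signal that the IR changed. As a result, it can keep re-applying the pattern and loop indefinitely. This patch prevents that by returning `failure()` early when `isChanged == false`. The rewrite logic that previously sat inside `if (isChanged)` is otherwise unchanged.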

>From 642d92768e1c03ef32bbf3d3a47a6dfc58344d40 Mon Sep 17 00:00:00 2001
From: Ian Wood <ianwood2024 at u.northwestern.edu>
Date: Tue, 26 Nov 2024 03:44:04 -0800
Subject: [PATCH] Return failure on no change

Signed-off-by: Ian Wood <ianwood2024 at u.northwestern.edu>
---
 ...DecomposeGenericByUnfoldingPermutation.cpp | 30 ++++++++++---------
 1 file changed, 16 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
index 83c4b5bdf10976..6828aae7fe44c5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
@@ -223,21 +223,23 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
     newMap[i] = rewriter.getMultiDimIdentityMap(map.getNumDims());
   }
 
-  if (isChanged) {
-    SmallVector<Value> operands = op->getOperands();
-    ValueRange operandsRef(operands);
-
-    auto newOp = rewriter.create<linalg::GenericOp>(
-        /*location=*/op.getLoc(),
-        /*resultTensorTypes=*/op->getResultTypes(),
-        /*inputs=*/newInitValues,
-        /*outputs=*/operandsRef.drop_front(op.getNumDpsInputs()),
-        /*indexingMaps=*/newMap,
-        /*iteratorTypes=*/op.getIteratorTypesArray());
-
-    newOp.getRegion().takeBody(op->getRegion(0));
-    rewriter.replaceOp(op, newOp->getResults());
+  if (!isChanged) {
+    return failure();
   }
+
+  SmallVector<Value> operands = op->getOperands();
+  ValueRange operandsRef(operands);
+
+  auto newOp = rewriter.create<linalg::GenericOp>(
+      /*location=*/op.getLoc(),
+      /*resultTensorTypes=*/op->getResultTypes(),
+      /*inputs=*/newInitValues,
+      /*outputs=*/operandsRef.drop_front(op.getNumDpsInputs()),
+      /*indexingMaps=*/newMap,
+      /*iteratorTypes=*/op.getIteratorTypesArray());
+
+  newOp.getRegion().takeBody(op->getRegion(0));
+  rewriter.replaceOp(op, newOp->getResults());
   return success();
 }
 



More information about the Mlir-commits mailing list