[Mlir-commits] [mlir] [MLIR][Linalg] Fix decompose generic op pattern (PR #117650)

llvmlistbot at llvm.org
Mon Nov 25 21:06:38 PST 2024


llvmbot wrote:



@llvm/pr-subscribers-mlir-linalg

Author: Ian Wood (IanWood1)

<details>
<summary>Changes</summary>

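The `DecomposeProjectedPermutation` pattern in `DecomposeGenericByUnfoldingPermutation.cpp` can return `success()` even when it has not modified the IR. The greedy pattern rewrite driver keeps re-applying patterns as long as one of them reports success, so this can loop indefinitely. This patch prevents that by returning `failure()` early when `isChanged == false`.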

---
Full diff: https://github.com/llvm/llvm-project/pull/117650.diff


1 File Affected:

- (modified) mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp (+16-14) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
index 83c4b5bdf10976..6828aae7fe44c5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
@@ -223,21 +223,23 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
     newMap[i] = rewriter.getMultiDimIdentityMap(map.getNumDims());
   }
 
-  if (isChanged) {
-    SmallVector<Value> operands = op->getOperands();
-    ValueRange operandsRef(operands);
-
-    auto newOp = rewriter.create<linalg::GenericOp>(
-        /*location=*/op.getLoc(),
-        /*resultTensorTypes=*/op->getResultTypes(),
-        /*inputs=*/newInitValues,
-        /*outputs=*/operandsRef.drop_front(op.getNumDpsInputs()),
-        /*indexingMaps=*/newMap,
-        /*iteratorTypes=*/op.getIteratorTypesArray());
-
-    newOp.getRegion().takeBody(op->getRegion(0));
-    rewriter.replaceOp(op, newOp->getResults());
+  if (!isChanged) {
+    return failure();
   }
+
+  SmallVector<Value> operands = op->getOperands();
+  ValueRange operandsRef(operands);
+
+  auto newOp = rewriter.create<linalg::GenericOp>(
+      /*location=*/op.getLoc(),
+      /*resultTensorTypes=*/op->getResultTypes(),
+      /*inputs=*/newInitValues,
+      /*outputs=*/operandsRef.drop_front(op.getNumDpsInputs()),
+      /*indexingMaps=*/newMap,
+      /*iteratorTypes=*/op.getIteratorTypesArray());
+
+  newOp.getRegion().takeBody(op->getRegion(0));
+  rewriter.replaceOp(op, newOp->getResults());
   return success();
 }
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/117650

