[Mlir-commits] [mlir] [MLIR][Linalg] Fix decompose generic op pattern (PR #117650)
Ian Wood
llvmlistbot at llvm.org
Mon Nov 25 16:32:58 PST 2024
https://github.com/IanWood1 created https://github.com/llvm/llvm-project/pull/117650
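The `DecomposeProjectedPermutation` pattern (in `DecomposeGenericByUnfoldingPermutation.cpp`) can return `success()` even when it has not modified the IR. Reporting success without changing the IR violates the rewrite-pattern contract and can make the greedy pattern driver loop indefinitely. This patch returns `failure()` early when `isChanged == false`, so the pattern reports success only after it has actually replaced the op.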
>From 642d92768e1c03ef32bbf3d3a47a6dfc58344d40 Mon Sep 17 00:00:00 2001
From: Ian Wood <ianwood2024 at u.northwestern.edu>
Date: Tue, 26 Nov 2024 03:44:04 -0800
Subject: [PATCH] Return failure on no change
Signed-off-by: Ian Wood <ianwood2024 at u.northwestern.edu>
---
...DecomposeGenericByUnfoldingPermutation.cpp | 30 ++++++++++---------
1 file changed, 16 insertions(+), 14 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
index 83c4b5bdf10976..6828aae7fe44c5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
@@ -223,21 +223,23 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
newMap[i] = rewriter.getMultiDimIdentityMap(map.getNumDims());
}
- if (isChanged) {
- SmallVector<Value> operands = op->getOperands();
- ValueRange operandsRef(operands);
-
- auto newOp = rewriter.create<linalg::GenericOp>(
- /*location=*/op.getLoc(),
- /*resultTensorTypes=*/op->getResultTypes(),
- /*inputs=*/newInitValues,
- /*outputs=*/operandsRef.drop_front(op.getNumDpsInputs()),
- /*indexingMaps=*/newMap,
- /*iteratorTypes=*/op.getIteratorTypesArray());
-
- newOp.getRegion().takeBody(op->getRegion(0));
- rewriter.replaceOp(op, newOp->getResults());
+ if (!isChanged) {
+ return failure();
}
+
+ SmallVector<Value> operands = op->getOperands();
+ ValueRange operandsRef(operands);
+
+ auto newOp = rewriter.create<linalg::GenericOp>(
+ /*location=*/op.getLoc(),
+ /*resultTensorTypes=*/op->getResultTypes(),
+ /*inputs=*/newInitValues,
+ /*outputs=*/operandsRef.drop_front(op.getNumDpsInputs()),
+ /*indexingMaps=*/newMap,
+ /*iteratorTypes=*/op.getIteratorTypesArray());
+
+ newOp.getRegion().takeBody(op->getRegion(0));
+ rewriter.replaceOp(op, newOp->getResults());
return success();
}
More information about the Mlir-commits mailing list