[Mlir-commits] [mlir] [MLIR][Linalg] Fix decompose generic op pattern (PR #117650)
llvmlistbot at llvm.org
Mon Nov 25 21:06:38 PST 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir-linalg
Author: Ian Wood (IanWood1)
<details>
<summary>Changes</summary>
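The `DecomposeProjectedPermutation` pattern (in `DecomposeGenericByUnfoldingPermutation.cpp`) can return `success()` even when the IR was not modified. MLIR's pattern rewrite drivers treat `success()` as a report that the IR changed, so this can make the driver loop indefinitely. This patch prevents that by returning `failure()` early when `isChanged` is false.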
---
Full diff: https://github.com/llvm/llvm-project/pull/117650.diff
1 File Affected:
- (modified) mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp (+16-14)
``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
index 83c4b5bdf10976..6828aae7fe44c5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
@@ -223,21 +223,23 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
     newMap[i] = rewriter.getMultiDimIdentityMap(map.getNumDims());
   }
 
-  if (isChanged) {
-    SmallVector<Value> operands = op->getOperands();
-    ValueRange operandsRef(operands);
-
-    auto newOp = rewriter.create<linalg::GenericOp>(
-        /*location=*/op.getLoc(),
-        /*resultTensorTypes=*/op->getResultTypes(),
-        /*inputs=*/newInitValues,
-        /*outputs=*/operandsRef.drop_front(op.getNumDpsInputs()),
-        /*indexingMaps=*/newMap,
-        /*iteratorTypes=*/op.getIteratorTypesArray());
-
-    newOp.getRegion().takeBody(op->getRegion(0));
-    rewriter.replaceOp(op, newOp->getResults());
+  if (!isChanged) {
+    return failure();
   }
+
+  SmallVector<Value> operands = op->getOperands();
+  ValueRange operandsRef(operands);
+
+  auto newOp = rewriter.create<linalg::GenericOp>(
+      /*location=*/op.getLoc(),
+      /*resultTensorTypes=*/op->getResultTypes(),
+      /*inputs=*/newInitValues,
+      /*outputs=*/operandsRef.drop_front(op.getNumDpsInputs()),
+      /*indexingMaps=*/newMap,
+      /*iteratorTypes=*/op.getIteratorTypesArray());
+
+  newOp.getRegion().takeBody(op->getRegion(0));
+  rewriter.replaceOp(op, newOp->getResults());
   return success();
 }
 
``````````
</details>
https://github.com/llvm/llvm-project/pull/117650