[Mlir-commits] [mlir] [MLIR] Minor fixes to FoldTransposeBroadcast rewrite (NFC) (PR #140083)
llvmlistbot at llvm.org
Thu May 15 08:33:05 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir-vector
Author: Momchil Velikov (momchil-velikov)
<details>
<summary>Changes</summary>
This patch makes two minor changes that I believe reflect the original author's intent:
* When folding `transpose(broadcast(x))`, emit `broadcast(x)` instead of `broadcast(broadcast(x))`. The latter causes intermittent verifier failures, e.g.:
```
mlir-asm-printer: 'func.func' failed to verify and will be printed in generic form
"func.func"() <{function_type = (vector<4x1x1x7xi8>) -> vector<3x2x4x5x6x7xi8>, sym_name = "broadcast_transpose_mixed_example"}> ({
^bb0(%arg0: vector<4x1x1x7xi8>):
%0 = "vector.broadcast"(%arg0) : (vector<4x1x1x7xi8>) -> vector<2x3x4x5x6x7xi8>
%1 = "vector.broadcast"(%0) : (vector<2x3x4x5x6x7xi8>) -> vector<3x2x4x5x6x7xi8>
"func.return"(%1) : (vector<3x2x4x5x6x7xi8>) -> ()
}) : () -> ()
```
* When checking permutation groups, the variable `low` was set to zero only once, so every group check rescanned all earlier dimensions and the check was quadratic. It looks like the intent was for `low` to track the start of each dimension group. (Nevertheless, the check was correct.)
---
Full diff: https://github.com/llvm/llvm-project/pull/140083.diff
1 File Affected:
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+3-2)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 79bf87ccd34af..7ae43b64a5deb 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6201,7 +6201,7 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
bool inputIsScalar = !inputType;
if (inputIsScalar) {
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
- transpose.getVector());
+ broadcast.getSource());
return success();
}
@@ -6227,6 +6227,7 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
transpose, "permutation not local to group");
}
}
+ low = high;
}
}
@@ -6241,7 +6242,7 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
"not broadcastable directly to transpose output");
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
- transpose.getVector());
+ broadcast.getSource());
return success();
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/140083