[Mlir-commits] [mlir] [MLIR] Minor fixes to FoldTransposeBroadcast rewrite (NFC) (PR #140083)
Momchil Velikov
llvmlistbot at llvm.org
Thu May 15 08:32:33 PDT 2025
https://github.com/momchil-velikov created https://github.com/llvm/llvm-project/pull/140083
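This patch contains two minor changes, which I believe reflect the original author's intent.
* when folding `transpose(broadcast(x))`, emit `broadcast(x)` instead of `broadcast(broadcast(x))`. In the latter, the outer broadcast goes from the already-broadcast shape to its transposed shape, which in general is not a valid broadcast. This causes intermittent verifier failures, e.g.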
```
mlir-asm-printer: 'func.func' failed to verify and will be printed in generic form
"func.func"() <{function_type = (vector<4x1x1x7xi8>) -> vector<3x2x4x5x6x7xi8>, sym_name = "broadcast_transpose_mixed_example"}> ({
^bb0(%arg0: vector<4x1x1x7xi8>):
%0 = "vector.broadcast"(%arg0) : (vector<4x1x1x7xi8>) -> vector<2x3x4x5x6x7xi8>
%1 = "vector.broadcast"(%0) : (vector<2x3x4x5x6x7xi8>) -> vector<3x2x4x5x6x7xi8>
"func.return"(%1) : (vector<3x2x4x5x6x7xi8>) -> ()
}) : () -> ()
```
* when checking permutation groups, the variable `low` was set to zero only once. Each group check therefore rescanned from dimension 0, and the overall check was quadratic. It looks like the intent was for `low` to track the beginning of each dimension group. (Nevertheless, the check was correct.)
From db8b5d9062b26efc40d417cfeae8828c8317da1d Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Thu, 15 May 2025 15:14:58 +0000
Subject: [PATCH] [MLIR] Minor fixes to FoldTransposeBroadcast rewrite (NFC)
This patch contains two minor changes, which I believe reflect
the original author's intent.
* when folding `transpose(broadcast(x))`, emit `broadcast(x)`
instead of `broadcast(broadcast(x))`. The latter causes intermittent
verifier failures, e.g.
```
mlir-asm-printer: 'func.func' failed to verify and will be printed in generic form
"func.func"() <{function_type = (vector<4x1x1x7xi8>) -> vector<3x2x4x5x6x7xi8>, sym_name = "broadcast_transpose_mixed_example"}> ({
^bb0(%arg0: vector<4x1x1x7xi8>):
%0 = "vector.broadcast"(%arg0) : (vector<4x1x1x7xi8>) -> vector<2x3x4x5x6x7xi8>
%1 = "vector.broadcast"(%0) : (vector<2x3x4x5x6x7xi8>) -> vector<3x2x4x5x6x7xi8>
"func.return"(%1) : (vector<3x2x4x5x6x7xi8>) -> ()
}) : () -> ()
```
* when checking permutation groups, the variable `low` was set to
zero only once, so the checking was quadratic. It looks like the
intent was for `low` to track the beginning of each dimension group.
(Nevertheless, the check was correct.)
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 79bf87ccd34af..7ae43b64a5deb 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6201,7 +6201,7 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
bool inputIsScalar = !inputType;
if (inputIsScalar) {
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
- transpose.getVector());
+ broadcast.getSource());
return success();
}
@@ -6227,6 +6227,7 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
transpose, "permutation not local to group");
}
}
+ low = high;
}
}
@@ -6241,7 +6242,7 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
"not broadcastable directly to transpose output");
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
- transpose.getVector());
+ broadcast.getSource());
return success();
}