[Mlir-commits] [mlir] 38d2306 - [MLIR] Minor fixes to FoldTransposeBroadcast rewrite (#140083)
llvmlistbot at llvm.org
Mon May 19 01:45:45 PDT 2025
Author: Momchil Velikov
Date: 2025-05-19T09:45:41+01:00
New Revision: 38d2306b62d6b0b7cc0a1bc0d73a2f9c8323bd87
URL: https://github.com/llvm/llvm-project/commit/38d2306b62d6b0b7cc0a1bc0d73a2f9c8323bd87
DIFF: https://github.com/llvm/llvm-project/commit/38d2306b62d6b0b7cc0a1bc0d73a2f9c8323bd87.diff
LOG: [MLIR] Minor fixes to FoldTransposeBroadcast rewrite (#140083)
This patch contains two minor changes, which I believe were the original
author's intent.
* when folding `transpose(broadcast(x))` emit `broadcast(x)` instead of
`broadcast(broadcast(x))`. The latter causes transient verifier
failures with `mlir-opt --debug`, e.g.
```
mlir-asm-printer: 'func.func' failed to verify and will be printed in generic form
"func.func"() <{function_type = (vector<4x1x1x7xi8>) -> vector<3x2x4x5x6x7xi8>, sym_name = "broadcast_transpose_mixed_example"}> ({
^bb0(%arg0: vector<4x1x1x7xi8>):
%0 = "vector.broadcast"(%arg0) : (vector<4x1x1x7xi8>) -> vector<2x3x4x5x6x7xi8>
%1 = "vector.broadcast"(%0) : (vector<2x3x4x5x6x7xi8>) -> vector<3x2x4x5x6x7xi8>
"func.return"(%1) : (vector<3x2x4x5x6x7xi8>) -> ()
}) : () -> ()
```
* when checking permutation groups the variable `low` was set just once
to zero, so the check was quadratic. It looks like the intent was for `low`
to track the beginning of each dimension group. (Nevertheless, the check
was correct.)
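For illustration, the group check described in the second item can be sketched
as a standalone function. This is a simplified sketch, not the code in
VectorOps.cpp: the function name `permutationIsLocalToGroups`, the use of
`std::vector`, and the rule used to detect the end of a group are assumptions
made for the example.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical standalone helper (not the actual MLIR code). Returns true when
// every permutation entry maps a dimension to a position inside its own group,
// where groups are delimited by the non-unit dimensions of the broadcast input.
bool permutationIsLocalToGroups(const std::vector<int64_t> &inputShape,
                                const std::vector<int64_t> &permutation) {
  const int64_t inputRank = static_cast<int64_t>(inputShape.size());
  const int64_t outputRank = static_cast<int64_t>(permutation.size());
  const int64_t deltaRank = outputRank - inputRank;
  if (deltaRank < 0)
    return false; // The broadcast result cannot have lower rank than its input.

  int64_t low = 0; // Start of the group currently being checked.
  for (int64_t inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
    // Assumed rule: a group ends at a non-unit dimension, or right after one.
    const bool notOne = inputShape[inputIndex] != 1;
    const bool prevNotOne = inputIndex != 0 && inputShape[inputIndex - 1] != 1;
    if (!notOne && !prevNotOne)
      continue;

    const int64_t high = inputIndex + deltaRank; // One past the end of the group.
    for (int64_t i = low; i < high; ++i)
      if (permutation[i] < low || permutation[i] >= high)
        return false; // Permutation not local to group.

    // Advance to the next group so that each index is visited only once.
    low = high;
  }
  return true;
}
```

With an input shape of 4x1x1x7 broadcast to rank 6, as in the example above,
the permutation [1, 0, 2, 3, 4, 5] (assumed here to match the printed result
type) only swaps the two leading broadcast dimensions and passes the check,
whereas [0, 1, 3, 2, 4, 5] moves the non-unit dimension of size 4 and is
rejected. Without the final `low = high` assignment the inner loop would
restart from index 0 for every group.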
Added:
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 79bf87ccd34af..7ae43b64a5deb 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6201,7 +6201,7 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
bool inputIsScalar = !inputType;
if (inputIsScalar) {
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
- transpose.getVector());
+ broadcast.getSource());
return success();
}
@@ -6227,6 +6227,7 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
transpose, "permutation not local to group");
}
}
+ low = high;
}
}
@@ -6241,7 +6242,7 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
"not broadcastable directly to transpose output");
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
- transpose.getVector());
+ broadcast.getSource());
return success();
}