[Mlir-commits] [mlir] [MLIR] Minor fixes to FoldTransposeBroadcast rewrite (NFC) (PR #140083)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 15 08:33:06 PDT 2025


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Momchil Velikov (momchil-velikov)

<details>
<summary>Changes</summary>

This patch contains two minor changes, which I believe reflect the original author's intent:

* When folding `transpose(broadcast(x))`, emit `broadcast(x)` instead of `broadcast(broadcast(x))`. The latter can produce an invalid `vector.broadcast` and causes intermittent verifier failures, e.g.:
```
mlir-asm-printer: 'func.func' failed to verify and will be printed in generic form
"func.func"() <{function_type = (vector<4x1x1x7xi8>) -> vector<3x2x4x5x6x7xi8>, sym_name = "broadcast_transpose_mixed_example"}> ({
^bb0(%arg0: vector<4x1x1x7xi8>):
  %0 = "vector.broadcast"(%arg0) : (vector<4x1x1x7xi8>) -> vector<2x3x4x5x6x7xi8>
  %1 = "vector.broadcast"(%0) : (vector<2x3x4x5x6x7xi8>) -> vector<3x2x4x5x6x7xi8>
  "func.return"(%1) : (vector<3x2x4x5x6x7xi8>) -> ()
}) : () -> ()
```

* When checking permutation groups, the variable `low` was set to zero only once, so the check was quadratic. It looks like the intent was for `low` to track the beginning of each dimension group. (Nevertheless, the check was correct.)

---
Full diff: https://github.com/llvm/llvm-project/pull/140083.diff


1 Files Affected:

- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+3-2) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 79bf87ccd34af..7ae43b64a5deb 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6201,7 +6201,7 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
     bool inputIsScalar = !inputType;
     if (inputIsScalar) {
       rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
-                                                       transpose.getVector());
+                                                       broadcast.getSource());
       return success();
     }
 
@@ -6227,6 +6227,7 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
                 transpose, "permutation not local to group");
           }
         }
+        low = high;
       }
     }
 
@@ -6241,7 +6242,7 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
            "not broadcastable directly to transpose output");
 
     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
-                                                     transpose.getVector());
+                                                     broadcast.getSource());
 
     return success();
   }

``````````
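A minimal C++ sketch of the group-locality check, contrasting the fixed loop (with the `low = high;` line from the second hunk) and the old one. It assumes, hypothetically, that the pattern rejects the permutation when some `permutation[i]` with `i` in `[low, high)` falls outside `[low, high)`, and it takes the group end indices as an input instead of deriving them from the broadcast's input shape. `isLocalToGroups`, `isLocalToPrefixes`, and `checkGroupExamples` are illustrative names, not MLIR code.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper: checks that `perm` maps every group of consecutive
// indices [low, high) into itself. Assumes `groupEnds` is increasing and its
// last entry equals perm.size(). Advancing `low` to `high` after each group
// means every index is visited once: linear time.
bool isLocalToGroups(const std::vector<int64_t> &perm,
                     const std::vector<int64_t> &groupEnds) {
  int64_t low = 0;
  for (int64_t high : groupEnds) {
    for (int64_t i = low; i < high; ++i)
      if (perm[i] < low || perm[i] >= high)
        return false;
    low = high; // the fix: the next group starts where this one ended
  }
  return true;
}

// Same check with `low` stuck at 0: the scans cover [0, h1), [0, h2), ...,
// which is quadratic in the number of groups. It accepts exactly the same
// permutations, because a bijection that maps each prefix [0, h) into itself
// must also map each difference [h_prev, h) into itself.
bool isLocalToPrefixes(const std::vector<int64_t> &perm,
                       const std::vector<int64_t> &groupEnds) {
  for (int64_t high : groupEnds)
    for (int64_t i = 0; i < high; ++i)
      if (perm[i] < 0 || perm[i] >= high)
        return false;
  return true;
}

// Both versions agree; only the amount of work differs.
void checkGroupExamples() {
  const std::vector<int64_t> ends = {2, 4, 6};
  const std::vector<int64_t> local = {1, 0, 2, 3, 5, 4};    // swaps inside groups
  const std::vector<int64_t> crossing = {2, 1, 0, 3, 4, 5}; // moves index 0 to 2
  assert(isLocalToGroups(local, ends) && isLocalToPrefixes(local, ends));
  assert(!isLocalToGroups(crossing, ends) && !isLocalToPrefixes(crossing, ends));
}
```

With `low` stuck at zero, `n` groups of size one cost `1 + 2 + ... + n` iterations, which is where the quadratic behavior comes from; the set of accepted permutations is the same either way.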

</details>


https://github.com/llvm/llvm-project/pull/140083

