[Mlir-commits] [mlir] 38d2306 - [MLIR] Minor fixes to FoldTransposeBroadcast rewrite (#140083)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon May 19 01:45:45 PDT 2025


Author: Momchil Velikov
Date: 2025-05-19T09:45:41+01:00
New Revision: 38d2306b62d6b0b7cc0a1bc0d73a2f9c8323bd87

URL: https://github.com/llvm/llvm-project/commit/38d2306b62d6b0b7cc0a1bc0d73a2f9c8323bd87
DIFF: https://github.com/llvm/llvm-project/commit/38d2306b62d6b0b7cc0a1bc0d73a2f9c8323bd87.diff

LOG: [MLIR] Minor fixes to FoldTransposeBroadcast rewrite (#140083)

This patch contains two minor changes that I believe reflect the original
author's intent.

* when folding `transpose(broadcast(x))`, emit `broadcast(x)` instead of
`broadcast(broadcast(x))`. The latter causes transient verifier
failures with `mlir-opt --debug`, e.g.:
```
mlir-asm-printer: 'func.func' failed to verify and will be printed in generic form
"func.func"() <{function_type = (vector<4x1x1x7xi8>) -> vector<3x2x4x5x6x7xi8>, sym_name = "broadcast_transpose_mixed_example"}> ({
^bb0(%arg0: vector<4x1x1x7xi8>):
  %0 = "vector.broadcast"(%arg0) : (vector<4x1x1x7xi8>) -> vector<2x3x4x5x6x7xi8>
  %1 = "vector.broadcast"(%0) : (vector<2x3x4x5x6x7xi8>) -> vector<3x2x4x5x6x7xi8>
  "func.return"(%1) : (vector<3x2x4x5x6x7xi8>) -> ()
}) : () -> ()
```

* when checking permutation groups, the variable `low` was set to zero
just once, so each group re-scanned all dimensions from index 0 and the
check was quadratic in the vector rank. It looks like the intent was for
`low` to track the beginning of each dimension group. (Nevertheless, the
check was correct.)
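The verifier failure in the first item comes down to shapes: the old code broadcast the *intermediate* value (`transpose.getVector()`, type `vector<2x3x4x5x6x7xi8>`) instead of the original source (`broadcast.getSource()`, type `vector<4x1x1x7xi8>`). The small C++ model below makes that concrete. It assumes the usual vector broadcast rule (shapes aligned at the trailing dimension; each source dimension must equal the target dimension or be 1) and the `vector.transpose` convention that result dimension `i` is input dimension `perm[i]`. `transposeShape` and `isBroadcastable` are hypothetical helpers written for illustration, not MLIR APIs.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

// Hypothetical helper: shape after a transpose, where result dimension i
// takes input dimension perm[i].
Shape transposeShape(const Shape &shape, const std::vector<int64_t> &perm) {
  assert(shape.size() == perm.size() && "permutation rank mismatch");
  Shape result(shape.size());
  for (std::size_t i = 0; i < perm.size(); ++i)
    result[i] = shape[static_cast<std::size_t>(perm[i])];
  return result;
}

// Hypothetical model of the broadcast compatibility rule for vector sources:
// right-align the shapes; each source dim must match the target dim or be 1.
bool isBroadcastable(const Shape &src, const Shape &dst) {
  if (src.size() > dst.size())
    return false; // a broadcast never reduces rank
  std::size_t offset = dst.size() - src.size();
  for (std::size_t i = 0; i < src.size(); ++i) {
    int64_t s = src[i];
    int64_t d = dst[i + offset];
    if (s != d && s != 1)
      return false;
  }
  return true;
}
```

With the shapes from the example, transposing `2x3x4x5x6x7` by `[1, 0, 2, 3, 4, 5]` yields `3x2x4x5x6x7`. The intermediate `2x3x4x5x6x7` is not broadcastable to that (the leading `2` vs `3` mismatch is what the verifier rejects), whereas the original source `4x1x1x7` is, which is why broadcasting `broadcast.getSource()` directly is valid.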

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 79bf87ccd34af..7ae43b64a5deb 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6201,7 +6201,7 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
     bool inputIsScalar = !inputType;
     if (inputIsScalar) {
       rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
-                                                       transpose.getVector());
+                                                       broadcast.getSource());
       return success();
     }
 
@@ -6227,6 +6227,7 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
                 transpose, "permutation not local to group");
           }
         }
+        low = high;
       }
     }
 
@@ -6241,7 +6242,7 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
            "not broadcastable directly to transpose output");
 
     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
-                                                     transpose.getVector());
+                                                     broadcast.getSource());
 
     return success();
   }
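The second hunk only shows the tail of the group check, so the following is a hypothetical standalone reconstruction for illustration, not the code in `VectorOps.cpp`. It assumes the source shape is right-aligned to the output rank and that a group ends just before and just after every non-unit source dimension (leading broadcast dimensions and runs of 1s form one group); the names `low`, `high`, and `deltaRank` echo the diff, while `isPermutationLocalToGroups` and the use of `std::vector` instead of `ArrayRef` are assumptions.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical sketch: returns true if `permutation` (over the broadcast
// result's dimensions) only permutes dimensions within their own group.
bool isPermutationLocalToGroups(const std::vector<int64_t> &inputShape,
                                const std::vector<int64_t> &permutation) {
  const int64_t inputRank = static_cast<int64_t>(inputShape.size());
  const int64_t outputRank = static_cast<int64_t>(permutation.size());
  assert(inputRank <= outputRank && "broadcast cannot reduce rank");
  const int64_t deltaRank = outputRank - inputRank;

  int64_t low = 0; // first output index of the current group
  for (int64_t inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
    bool notOne = inputShape[inputIndex] != 1;
    bool prevNotOne = inputIndex != 0 && inputShape[inputIndex - 1] != 1;
    // A group ends just before a non-unit dim and just after one.
    if (notOne || prevNotOne) {
      int64_t high = inputIndex + deltaRank; // current group is [low, high)
      for (int64_t i = low; i < high; ++i)
        if (permutation[i] < low || permutation[i] >= high)
          return false; // "permutation not local to group"
      low = high; // the fix: the next group starts where this one ended
    }
  }
  // The trailing group [low, outputRank) needs no check: the permutation is a
  // bijection already closed on [0, low), so it maps the rest onto itself.
  return true;
}
```

For the example source `4x1x1x7` broadcast to rank 6, the groups are `[0,2)`, `[2,3)`, `[3,5)` and `[5,6)`, so `[1, 0, 2, 3, 4, 5]` and `[0, 1, 2, 4, 3, 5]` are accepted while `[0, 2, 1, 3, 4, 5]` is rejected. With `low = high`, each output index is scanned by exactly one group, making the check linear in the rank; without it, every group re-scans `[0, high)`, which gives the same verdict at quadratic cost.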

