[Mlir-commits] [mlir] [mlir][vector] transpose(broadcast) -> broadcast canonicalization (PR #135096)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Apr 10 03:51:31 PDT 2025


================
@@ -6155,12 +6155,115 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
   }
 };
 
+/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
+/// 'order preserving', where 'order preserving' means the flattened
+/// inputs and outputs of the transpose have identical (numerical) values.
+///
+/// Example:
+/// ```
+///  %0 = vector.broadcast %input : vector<1x1xi32> to vector<1x8xi32>
+///  %1 = vector.transpose %0, [1, 0] : vector<1x8xi32>
+///                                                 to vector<8x1xi32>
+/// ```
+/// can be rewritten as the equivalent
+/// ```
+///  %0 = vector.broadcast %input : vector<1x1xi32> to vector<8x1xi32>
+/// ```
+/// The algorithm works by partitioning dimensions into groups that can be
+/// locally permuted while preserving order, and checks that the transpose
+/// only permutes within these groups.
+class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
+
+  static bool canFoldIntoPrecedingBroadcast(vector::TransposeOp transpose) {
+
+    vector::BroadcastOp broadcast =
+        transpose.getVector().getDefiningOp<vector::BroadcastOp>();
+    if (!broadcast)
+      return false;
+
+    auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
+    bool inputIsScalar = !inputType;
+    ArrayRef<int64_t> inputShape = inputType.getShape();
+    int64_t inputRank = inputType.getRank();
+    int64_t outputRank = transpose.getType().getRank();
+    int64_t deltaRank = outputRank - inputRank;
+
+    // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid
+    if (inputIsScalar)
+      return true;
+
+    // Return true if all permutation destinations for indices in [low, high)
+    // are in [low, high), so the permutation is local to the group.
+    auto isGroupBound = [&](int low, int high) {
+      ArrayRef<int64_t> permutation = transpose.getPermutation();
----------------
banach-space wrote:

Wouldn't it be more efficient to obtain the permutation outside the lambda?

https://github.com/llvm/llvm-project/pull/135096


More information about the Mlir-commits mailing list