[Mlir-commits] [mlir] [mlir][vector] transpose(broadcast) -> broadcast canonicalization (PR #135096)

James Newling llvmlistbot at llvm.org
Mon Apr 14 08:12:14 PDT 2025


================
@@ -6151,12 +6129,103 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
   }
 };
 
+/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
+/// 'order preserving', where 'order preserving' means the flattened
+/// inputs and outputs of the transpose have identical (numerical) values.
+///
+/// Example:
+/// ```
+///  %0 = vector.broadcast %input : vector<1x1xi32> to vector<1x8xi32>
+///  %1 = vector.transpose %0, [1, 0] : vector<1x8xi32>
+///                                                 to vector<8x1xi32>
+/// ```
+/// can be rewritten as the equivalent
+/// ```
+///  %0 = vector.broadcast %input : vector<1x1xi32> to vector<8x1xi32>
+/// ```
+/// The algorithm works by partitioning dimensions into groups that can be
+/// locally permuted while preserving order, and checks that the transpose
+/// only permutes within these groups.
+///
+/// Groups are either contiguous sequences of 1s, or non-1s (1-element groups).
+/// Consider broadcasting 4x1x1x7 to 2x3x4x5x6x7. This is equivalent to
+/// broadcasting from 1x1x4x1x1x7.
+///                   ^^^ ^ ^^^ ^
+///          groups:   0  1  2  3
+/// Order preserving permutations for this example are ones that only permute
+/// within dimensions [0,1] and [3,4] (groups 0 and 2), like (1 0 2 4 3 5).
+class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
+
+  LogicalResult matchAndRewrite(vector::TransposeOp transpose,
+                                PatternRewriter &rewriter) const override {
+
+    vector::BroadcastOp broadcast =
+        transpose.getVector().getDefiningOp<vector::BroadcastOp>();
+    if (!broadcast) {
+      return rewriter.notifyMatchFailure(transpose,
+                                         "not preceded by a broadcast");
+    }
+
+    auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
+
+    // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid
+    bool inputIsScalar = !inputType;
+    if (inputIsScalar) {
+      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+          transpose, transpose.getResultVectorType(), transpose.getVector());
+      return success();
+    }
+
+    ArrayRef<int64_t> permutation = transpose.getPermutation();
+    ArrayRef<int64_t> inputShape = inputType.getShape();
+    int64_t inputRank = inputType.getRank();
+    int64_t outputRank = transpose.getType().getRank();
+    int64_t deltaRank = outputRank - inputRank;
+
+    int low = 0;
+    for (int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
+      bool notOne = inputShape[inputIndex] != 1;
+      bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
+      bool groupEndFound = notOne || prevNotOne;
+      if (groupEndFound) {
+        int high = inputIndex + deltaRank;
+        // Return failure if not all permutation destinations for indices in
+        // [low, high) are in [low, high), i.e. the permutation is not local to
+        // the group.
+        for (int i = low; i < high; ++i) {
+          if (permutation[i] < low || permutation[i] >= high) {
+            return rewriter.notifyMatchFailure(
+                transpose, "permutation not local to group");
+          }
+        }
+      }
+    }
+
+    // We don't need to check the final group [low, outputRank) because if it is
+    // not locally bound, there must be a preceding group that already failed
+    // the check (impossible to have just 1 non-locally bound group).
+
+    // The preceding logic also ensures that at this point, the output of the
+    // transpose is definitely broadcastable from the input shape, so we don't
+    // need to check vector::isBroadcastableTo now.
----------------
newling wrote:

Added it back, apologies for the flip-flopping :)

https://github.com/llvm/llvm-project/pull/135096


More information about the Mlir-commits mailing list