[Mlir-commits] [mlir] [mlir][vector] shape_cast(broadcast) -> broadcast canonicalization (PR #134939)

James Newling llvmlistbot at llvm.org
Tue Apr 8 15:46:24 PDT 2025


https://github.com/newling created https://github.com/llvm/llvm-project/pull/134939

Cover additional cases of this canonicalization by using the 'source of truth' function `isBroadcastableTo` to determine when it is possible to broadcast directly to the shape resulting from the shape_cast.

>From d5d59c27560aab0cf55543e9bab56bcbbb072963 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 8 Apr 2025 15:45:48 -0700
Subject: [PATCH] cover additional cases of shape_cast(broadcast) -> broadcast
 canonicalization

Signed-off-by: James Newling <james.newling at gmail.com>
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 48 +++++++++++++---------
 mlir/test/Dialect/Vector/canonicalize.mlir | 40 ++++++++++++++++++
 2 files changed, 69 insertions(+), 19 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 98d98f067de14..c6d8ec1e1cf69 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5778,8 +5778,8 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
 
 /// Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast.
 /// This only applies when the shape of the broadcast source
-/// 1. is a suffix of the shape of the result (i.e. when broadcast without
-///    reshape is expressive enough to capture the result in a single op), or
+/// 1. can be broadcast directly to the final shape, and both this direct
+///    broadcast and the original one merely repeat the source elements, or
 /// 2. has the same element count as the shape cast result.
 class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
 public:
@@ -5792,24 +5792,34 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
     if (!broadcastOp)
       return failure();
 
-    ArrayRef<int64_t> broadcastSourceShape;
-    if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType()))
-      broadcastSourceShape = srcType.getShape();
-    ArrayRef<int64_t> shapeCastTargetShape =
-        shapeCastOp.getResultVectorType().getShape();
-
-    // If `broadcastSourceShape` is a suffix of the result, we can just replace
-    // with a broadcast to the final shape.
-    if (broadcastSourceShape ==
-        shapeCastTargetShape.take_back(broadcastSourceShape.size())) {
-      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
-          shapeCastOp, shapeCastOp.getResultVectorType(),
-          broadcastOp.getSource());
-      return success();
+    {
+      VectorType dstType = shapeCastOp.getResultVectorType();
+      auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
+      bool isScalar = !srcType;
+      // Being broadcastable to the final shape is not sufficient on its own.
+      // E.g. with %x = [a, b] : vector<2x1xf32>, broadcasting to vector<2x2>
+      // and shape casting to vector<2x2x1> gives [a, a, b, b] (row-major),
+      // whereas broadcasting directly to vector<2x2x1> gives [a, b, a, b].
+      // The rewrite is only valid when both broadcasts merely repeat the
+      // source elements, i.e. when the source shape without its leading unit
+      // dims is a suffix of both the broadcast and the shape_cast result.
+      auto repeatsSource = [&](VectorType type) {
+        ArrayRef<int64_t> coreShape =
+            srcType.getShape().drop_while([](int64_t d) { return d == 1; });
+        return type.getShape().take_back(coreShape.size()) == coreShape;
+      };
+      if (isScalar || (isBroadcastableTo(srcType, dstType) ==
+                           BroadcastableToResult::Success &&
+                       repeatsSource(broadcastOp.getResultVectorType()) &&
+                       repeatsSource(dstType))) {
+        rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+            shapeCastOp, dstType, broadcastOp.getSource());
+        return success();
+      }
     }
 
-    // Otherwise, if the final result has the same element count, we can replace
-    // with a shape cast.
+    // If the final result has the same element count, we can replace with a
+    // shape cast.
     if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType())) {
       if (srcType.getNumElements() ==
           shapeCastOp.getResultVectorType().getNumElements()) {
@@ -6079,7 +6089,7 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
   }
 };
 
-// Folds transpose(broadcast(<scalar>)) into brodcast(<scalar>).
+// Folds transpose(broadcast(<scalar>)) into broadcast(<scalar>).
 struct FoldTransposedScalarBroadcast final
     : public OpRewritePattern<vector::TransposeOp> {
   using OpRewritePattern::OpRewritePattern;
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index b7db8ec834be7..d7617d79b5cbf 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1017,6 +1017,46 @@ func.func @canonicalize_broadcast_shapecast_to_shapecast(%arg0: vector<3x4xf32>)
 
 // -----
 
+
+// CHECK-LABEL: func @canonicalize_broadcast_shapecast_scalar
+//       CHECK:   vector.broadcast
+//  CHECK-SAME:   f32 to vector<3x4x1xf32>
+//   CHECK-NOT:   vector.shape_cast
+func.func @canonicalize_broadcast_shapecast_scalar(%arg0: f32) -> vector<3x4x1xf32> {
+  %0 = vector.broadcast %arg0 : f32 to vector<12xf32>
+  %1 = vector.shape_cast %0 : vector<12xf32> to vector<3x4x1xf32>
+  return %1 : vector<3x4x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @canonicalize_broadcast_shapecast_ones
+//       CHECK:   vector.broadcast
+//  CHECK-SAME:   vector<1x1xi8> to vector<1x1x6x1x4xi8>
+//   CHECK-NOT:   vector.shape_cast
+func.func @canonicalize_broadcast_shapecast_ones(%arg0: vector<1x1xi8>) -> vector<1x1x6x1x4xi8> {
+  %0 = vector.broadcast %arg0 : vector<1x1xi8> to vector<6x4xi8>
+  %1 = vector.shape_cast %0 : vector<6x4xi8> to vector<1x1x6x1x4xi8>
+  return %1 : vector<1x1x6x1x4xi8>
+}
+
+// -----
+
+// The broadcast stretches a non-leading unit dim, so broadcasting directly to
+// the shape_cast result shape would reorder the elements: no fold.
+// CHECK-LABEL: func @negative_canonicalize_broadcast_shapecast_stretched_dim
+//       CHECK:   vector.broadcast
+//  CHECK-SAME:   vector<2x1xf32> to vector<2x2xf32>
+//       CHECK:   vector.shape_cast
+//  CHECK-SAME:   vector<2x2xf32> to vector<2x2x1xf32>
+func.func @negative_canonicalize_broadcast_shapecast_stretched_dim(%arg0: vector<2x1xf32>) -> vector<2x2x1xf32> {
+  %0 = vector.broadcast %arg0 : vector<2x1xf32> to vector<2x2xf32>
+  %1 = vector.shape_cast %0 : vector<2x2xf32> to vector<2x2x1xf32>
+  return %1 : vector<2x2x1xf32>
+}
+
+// -----
+
 // CHECK-LABEL: fold_vector_transfer_masks
 func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
   // CHECK: %[[C0:.+]] = arith.constant 0 : index



More information about the Mlir-commits mailing list