[Mlir-commits] [mlir] 409def2 - [mlir][vector] shape_cast(broadcast) -> broadcast canonicalization (#134939)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Apr 11 07:15:06 PDT 2025


Author: James Newling
Date: 2025-04-11T15:15:03+01:00
New Revision: 409def2867ffe7dffceeff380d7b4057bac5eac6

URL: https://github.com/llvm/llvm-project/commit/409def2867ffe7dffceeff380d7b4057bac5eac6
DIFF: https://github.com/llvm/llvm-project/commit/409def2867ffe7dffceeff380d7b4057bac5eac6.diff

LOG: [mlir][vector] shape_cast(broadcast) -> broadcast canonicalization (#134939)

Add additional cases of this canonicalization by using the 'source of
truth' function `isBroadcastableTo` to determine when it is possible to
broadcast directly to the shape resulting from the shape_cast.

---------

Signed-off-by: James Newling <james.newling at gmail.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 59f3b788cebed..754dab21ee1f3 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5768,11 +5768,10 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
   }
 };
 
-/// Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast.
-/// This only applies when the shape of the broadcast source
-/// 1. is a suffix of the shape of the result (i.e. when broadcast without
-///    reshape is expressive enough to capture the result in a single op), or
-/// 2. has the same element count as the shape cast result.
+/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either
+///   i) Y = ShapeCast(X), or
+///  ii) Y = Broadcast(X)
+/// If both (i) and (ii) are possible, (i) is chosen.
 class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
@@ -5784,26 +5783,17 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
     if (!broadcastOp)
       return failure();
 
-    ArrayRef<int64_t> broadcastSourceShape;
-    if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType()))
-      broadcastSourceShape = srcType.getShape();
-    ArrayRef<int64_t> shapeCastTargetShape =
-        shapeCastOp.getResultVectorType().getShape();
-
-    // If `broadcastSourceShape` is a suffix of the result, we can just replace
-    // with a broadcast to the final shape.
-    if (broadcastSourceShape ==
-        shapeCastTargetShape.take_back(broadcastSourceShape.size())) {
-      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
-          shapeCastOp, shapeCastOp.getResultVectorType(),
-          broadcastOp.getSource());
-      return success();
-    }
-
-    // Otherwise, if the final result has the same element count, we can replace
-    // with a shape cast.
-    if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType())) {
-      if (srcType.getNumElements() ==
+    auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
+    bool srcIsScalar = !srcVectorType;
+
+    // Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X).
+    // Example:
+    // %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32>
+    // %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32>
+    // to
+    // %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32>
+    if (srcVectorType) {
+      if (srcVectorType.getNumElements() ==
           shapeCastOp.getResultVectorType().getNumElements()) {
         rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
             shapeCastOp, shapeCastOp.getResultVectorType(),
@@ -5812,6 +5802,19 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
       }
     }
 
+    // Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X)
+    // Example
+    // %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32>
+    // %1 = vector.shape_cast %0 : vector<2x4x3xf32> to vector<8x3xf32>
+    // to
+    // %1 = vector.broadcast %in : vector<3xf32> to vector<8x3xf32>
+    VectorType dstVectorType = shapeCastOp.getResultVectorType();
+    if (srcIsScalar || isBroadcastableTo(srcVectorType, dstVectorType) ==
+                           BroadcastableToResult::Success) {
+      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+          shapeCastOp, dstVectorType, broadcastOp.getSource());
+      return success();
+    }
     return failure();
   }
 };
@@ -6072,7 +6075,7 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
   }
 };
 
-// Folds transpose(broadcast(<scalar>)) into brodcast(<scalar>).
+// Folds transpose(broadcast(<scalar>)) into broadcast(<scalar>).
 struct FoldTransposedScalarBroadcast final
     : public OpRewritePattern<vector::TransposeOp> {
   using OpRewritePattern::OpRewritePattern;

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 72064fb42741a..a6d82b85777b0 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1006,13 +1006,48 @@ func.func @canonicalize_broadcast_shapecast_to_broadcast(%arg0: vector<3xf32>) -
 
 // -----
 
-// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_shapecast
+// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_broadcast_ones
+//       CHECK:   vector.broadcast {{.*}} vector<1x1xi8> to vector<1x1x6x1x4xi8>
+//   CHECK-NOT:   vector.shape_cast
+func.func @canonicalize_broadcast_shapecast_to_broadcast_ones(%arg0: vector<1x1xi8>) -> vector<1x1x6x1x4xi8> {
+  %0 = vector.broadcast %arg0 : vector<1x1xi8> to vector<6x4xi8>
+  %1 = vector.shape_cast %0 : vector<6x4xi8> to vector<1x1x6x1x4xi8>
+  return %1 : vector<1x1x6x1x4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_broadcast_scalar
+//       CHECK:   vector.broadcast {{.*}} f32 to vector<3x4x1xf32>
+//   CHECK-NOT:   vector.shape_cast
+func.func @canonicalize_broadcast_shapecast_to_broadcast_scalar(%arg0: f32) -> vector<3x4x1xf32> {
+  %0 = vector.broadcast %arg0 : f32 to vector<12xf32>
+  %1 = vector.shape_cast %0 : vector<12xf32> to vector<3x4x1xf32>
+  return %1 : vector<3x4x1xf32>
+}
+
+// -----
+
+// In this test, broadcast (2)->(1,2,1) is not legal, but shape_cast (2)->(1,2,1) is.
+// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_shapcast
+//   CHECK-NOT:   vector.broadcast
+//       CHECK:   vector.shape_cast {{.+}} : vector<2xf32> to vector<1x2x1xf32>
+func.func @canonicalize_broadcast_shapecast_to_shapcast(%arg0 : vector<2xf32>) -> vector<1x2x1xf32> {
+  %0 = vector.broadcast %arg0 : vector<2xf32> to vector<1x2xf32>
+  %1 = vector.shape_cast %0 : vector<1x2xf32> to vector<1x2x1xf32>
+  return %1 : vector<1x2x1xf32>
+}
+
+// -----
+
+// In this test, broadcast (1)->(1,1) and shape_cast (1)->(1,1) are both legal. shape_cast is chosen.
+// CHECK-LABEL: func @canonicalize_broadcast_shapecast_both_possible
 //   CHECK-NOT:   vector.broadcast
-//       CHECK:   vector.shape_cast {{.+}} : vector<3x4xf32> to vector<1x12xf32>
-func.func @canonicalize_broadcast_shapecast_to_shapecast(%arg0: vector<3x4xf32>) -> vector<1x12xf32> {
-    %0 = vector.broadcast %arg0 : vector<3x4xf32> to vector<1x1x3x4xf32>
-    %1 = vector.shape_cast %0 : vector<1x1x3x4xf32> to vector<1x12xf32>
-    return %1 : vector<1x12xf32>
+//       CHECK:   vector.shape_cast {{.+}} : vector<1xf32> to vector<1x1xf32>
+func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>) -> vector<1x1xf32> {
+    %0 = vector.broadcast %arg0 : vector<1xf32> to vector<1x1x1xf32>
+    %1 = vector.shape_cast %0 : vector<1x1x1xf32> to vector<1x1xf32>
+    return %1 : vector<1x1xf32>
 }
 
 // -----


        


More information about the Mlir-commits mailing list