[Mlir-commits] [mlir] b4e8b8e - [mlir][vector] Canonicalize broadcast of shape_cast (#150523)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Aug 8 09:25:35 PDT 2025


Author: Min-Yih Hsu
Date: 2025-08-08T09:25:32-07:00
New Revision: b4e8b8ee914d2de2c6c33c656fe8c06f5c11a01b

URL: https://github.com/llvm/llvm-project/commit/b4e8b8ee914d2de2c6c33c656fe8c06f5c11a01b
DIFF: https://github.com/llvm/llvm-project/commit/b4e8b8ee914d2de2c6c33c656fe8c06f5c11a01b.diff

LOG: [mlir][vector] Canonicalize broadcast of shape_cast (#150523)

Fold `broadcast(shape_cast(x))` into `broadcast(x)` if the type of x is
compatible with broadcast's result type and the shape_cast only adds or removes ones in the leading dimensions.

---------

Co-authored-by: Andrzej Warzyński <andrzej.warzynski at gmail.com>
Co-authored-by: James Newling <james.newling at gmail.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a450056a3041a..cb4783d26a114 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2841,9 +2841,47 @@ LogicalResult BroadcastOp::verify() {
   llvm_unreachable("unexpected vector.broadcast op error");
 }
 
+/// Folds broadcast(shape_cast(x)) to broadcast(x) in place (rewires the source
+/// operand) when x is broadcastable to the result type and the shape_cast only
+/// adds/removes leading unit dims. Returns success() iff the op was modified.
+static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) {
+  auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
+  if (!srcShapeCast)
+    return failure();
+
+  VectorType srcType = srcShapeCast.getSourceVectorType();
+  VectorType destType = broadcastOp.getResultVectorType();
+  // x itself (not the shape_cast result) must broadcast to the result type.
+  if (vector::isBroadcastableTo(srcType, destType) !=
+      BroadcastableToResult::Success)
+    return failure();
+
+  ArrayRef<int64_t> srcShape = srcType.getShape();
+  ArrayRef<int64_t> shapecastShape =
+      srcShapeCast.getResultVectorType().getShape();
+  // Compare the innermost min(rank) dims: if shape_cast only touched leading
+  // unit dims, these must match exactly; otherwise bail out.
+  unsigned numTrailingDims = std::min(srcShape.size(), shapecastShape.size());
+  if (!llvm::equal(srcShape.take_back(numTrailingDims),
+                   shapecastShape.take_back(numTrailingDims)))
+    return failure();
+
+  assert(all_of(srcShape.drop_back(numTrailingDims),
+                [](int64_t E) { return E == 1; }) &&
+         all_of(shapecastShape.drop_back(numTrailingDims),
+                [](int64_t E) { return E == 1; }) &&
+         "ill-formed shape_cast"); // leftover leading dims must all be 1
+
+  broadcastOp.getSourceMutable().assign(srcShapeCast.getSource());
+  return success();
+}
+
 OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
   if (getSourceType() == getResultVectorType())
     return getSource();
+  if (succeeded(foldBroadcastOfShapeCast(*this)))
+    return getResult();
+
   if (!adaptor.getSource())
     return {};
   auto vectorType = getResultVectorType();

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index f86fb387be5b8..4a7176e1f8d7d 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1168,6 +1168,106 @@ func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>)
 
 // -----
 
+// CHECK-LABEL: func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim
+//   CHECK-NOT:   vector.shape_cast
+//       CHECK:   vector.broadcast {{.+}} : vector<2xf32> to vector<32x2xf32>
+func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim(%arg0 : vector<2xf32>) -> vector<32x2xf32> {
+  %0 = vector.shape_cast %arg0 : vector<2xf32> to vector<1x2xf32>  // prepends a unit dim (2 -> 1x2); expected to fold away
+  %1 = vector.broadcast %0 : vector<1x2xf32> to vector<32x2xf32>
+  return %1 : vector<32x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim2(
+// CHECK-SAME:      %[[ARG0:.*]]: vector<2x1xf32>) -> vector<32x2x1xf32> {
+// CHECK:           %[[VAL_0:.*]] = vector.broadcast %[[ARG0]] : vector<2x1xf32> to vector<32x2x1xf32>
+// CHECK:           return %[[VAL_0]] : vector<32x2x1xf32>
+// CHECK:         }
+func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim2(%arg0 : vector<2x1xf32>) -> vector<32x2x1xf32> {
+  %0 = vector.shape_cast %arg0 : vector<2x1xf32> to vector<1x2x1xf32>  // prepends a unit dim; trailing 2x1 is unchanged
+  %1 = vector.broadcast %0 : vector<1x2x1xf32> to vector<32x2x1xf32>
+  return %1 : vector<32x2x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim3(
+// CHECK-SAME:      %[[ARG0:.*]]: vector<2x1xf32>) -> vector<32x2x4xf32> {
+// CHECK:           %[[VAL_0:.*]] = vector.broadcast %[[ARG0]] : vector<2x1xf32> to vector<32x2x4xf32>
+// CHECK:           return %[[VAL_0]] : vector<32x2x4xf32>
+// CHECK:         }
+func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim3(%arg0 : vector<2x1xf32>) -> vector<32x2x4xf32> {
+  %0 = vector.shape_cast %arg0 : vector<2x1xf32> to vector<1x2x1xf32>  // prepends a unit dim; trailing 2x1 is unchanged
+  %1 = vector.broadcast %0 : vector<1x2x1xf32> to vector<32x2x4xf32>  // also stretches the trailing unit dim (1 -> 4)
+  return %1 : vector<32x2x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @canonicalize_shapecast_broadcast_to_broadcast_remove_leading_dim(
+// CHECK-SAME:      %[[ARG0:.*]]: vector<1x2xf32>) -> vector<32x2xf32> {
+// CHECK:           %[[VAL_0:.*]] = vector.broadcast %[[ARG0]] : vector<1x2xf32> to vector<32x2xf32>
+// CHECK:           return %[[VAL_0]] : vector<32x2xf32>
+// CHECK:         }
+func.func @canonicalize_shapecast_broadcast_to_broadcast_remove_leading_dim(%arg0 : vector<1x2xf32>) -> vector<32x2xf32> {
+  %0 = vector.shape_cast %arg0 : vector<1x2xf32> to vector<2xf32>  // drops the leading unit dim (1x2 -> 2); expected to fold away
+  %1 = vector.broadcast %0 : vector<2xf32> to vector<32x2xf32>
+  return %1 : vector<32x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @negative_canonicalize_shapecast_broadcast_invalid_shape
+//       CHECK:   vector.shape_cast {{.+}} : vector<64xf32> to vector<4x16xf32>
+//       CHECK:   vector.broadcast {{.+}} : vector<4x16xf32> to vector<2x4x16xf32>
+func.func @negative_canonicalize_shapecast_broadcast_invalid_shape(%arg0 : vector<64xf32>) -> vector<2x4x16xf32> {
+  %0 = vector.shape_cast %arg0 : vector<64xf32> to vector<4x16xf32>  // genuine reshape (64 -> 4x16), not a unit-dim change; must remain
+  %1 = vector.broadcast %0 : vector<4x16xf32> to vector<2x4x16xf32>
+  return %1 : vector<2x4x16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @negative_canonicalize_shapecast_broadcast_invalid_broadcasted_dims
+//       CHECK:   vector.shape_cast {{.+}} : vector<2x1xf32> to vector<1x2xf32>
+//       CHECK:   vector.broadcast {{.+}} : vector<1x2xf32> to vector<2x2xf32>
+func.func @negative_canonicalize_shapecast_broadcast_invalid_broadcasted_dims(%arg0 : vector<2x1xf32>) -> vector<2x2xf32> {
+  %0 = vector.shape_cast %arg0 : vector<2x1xf32> to vector<1x2xf32>  // unit dim moves from trailing to leading (2x1 -> 1x2); must remain
+  %1 = vector.broadcast %0 : vector<1x2xf32> to vector<2x2xf32>
+  return %1 : vector<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @negative_canonicalize_shapecast_broadcast_to_broadcast_append_dim(
+// CHECK-SAME:      %[[ARG0:.*]]: vector<2xf32>) -> vector<2x4xf32> {
+// CHECK:           %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<2xf32> to vector<2x1xf32>
+// CHECK:           %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<2x1xf32> to vector<2x4xf32>
+// CHECK:           return %[[VAL_1]] : vector<2x4xf32>
+// CHECK:         }
+func.func @negative_canonicalize_shapecast_broadcast_to_broadcast_append_dim(%arg0 : vector<2xf32>) -> vector<2x4xf32> {
+  %0 = vector.shape_cast %arg0 : vector<2xf32> to vector<2x1xf32>  // appends a trailing unit dim (2 -> 2x1); must remain
+  %1 = vector.broadcast %0 : vector<2x1xf32> to vector<2x4xf32>
+  return %1 : vector<2x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @negative_canonicalize_shapecast_broadcast_to_broadcast_remove_trailing_dim(
+// CHECK-SAME:      %[[ARG0:.*]]: vector<2x1xf32>) -> vector<32x2xf32> {
+// CHECK:           %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<2x1xf32> to vector<2xf32>
+// CHECK:           %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<2xf32> to vector<32x2xf32>
+// CHECK:           return %[[VAL_1]] : vector<32x2xf32>
+// CHECK:         }
+func.func @negative_canonicalize_shapecast_broadcast_to_broadcast_remove_trailing_dim(%arg0 : vector<2x1xf32>) -> vector<32x2xf32> {
+  %0 = vector.shape_cast %arg0 : vector<2x1xf32> to vector<2xf32>  // drops a trailing unit dim (2x1 -> 2); must remain
+  %1 = vector.broadcast %0 : vector<2xf32> to vector<32x2xf32>
+  return %1 : vector<32x2xf32>
+}
+
+// -----
+
 // CHECK-LABEL: fold_vector_transfer_masks
 func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
   // CHECK: %[[C0:.+]] = arith.constant 0 : index


        


More information about the Mlir-commits mailing list