[Mlir-commits] [mlir] [mlir][vector] Constrain broadcast->shape_cast folding (PR #190230)

Benoit Jacob llvmlistbot at llvm.org
Tue Apr 7 07:46:24 PDT 2026


https://github.com/bjacob updated https://github.com/llvm/llvm-project/pull/190230

>From febcc4f74e887eea6014740fcb873bd3b126e517 Mon Sep 17 00:00:00 2001
From: Benoit Jacob <benoit.jacob at amd.com>
Date: Thu, 2 Apr 2026 17:53:41 +0000
Subject: [PATCH 1/2] fix-ShapeCastBroadcastFolder

Signed-off-by: Benoit Jacob <benoit.jacob at amd.com>
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 48 +++++++++++--
 .../Linalg/transform-op-mmt4d-to-fma.mlir     | 69 -------------------
 .../canonicalize/vector-to-shape-cast.mlir    | 16 +++++
 3 files changed, 58 insertions(+), 75 deletions(-)
 delete mode 100644 mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 939816262a2b1..cf2da8ff54a1d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6703,6 +6703,22 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
   }
 };
 
+// vector.broadcast has two distinct semantic modes: duplication across leading
+// dimensions, and stretching across inner dimensions. This helper returns the
+// product of the inner-dimension stretching factors, i.e. of the trailing dims
+// of `dstShape` that unit dims of `srcShape` are stretched to.
+static int64_t getBroadcastStretchingFactor(ArrayRef<int64_t> srcShape,
+                                            ArrayRef<int64_t> dstShape) {
+  assert(dstShape.size() >= srcShape.size() && "dst rank must be >= src rank");
+  int64_t stretchingFactor = 1;
+  // Trailing dims of `dstShape` line up positionally with `srcShape`.
+  for (auto [srcDim, dstDim] :
+       llvm::zip_equal(srcShape, dstShape.take_back(srcShape.size())))
+    if (srcDim == 1 && dstDim != 1)
+      stretchingFactor *= dstDim;
+  return stretchingFactor;
+}
+
 /// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as Y = Broadcast(X)
 class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
 public:
@@ -6725,13 +6741,33 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
     // to
     // %1 = vector.broadcast %in : vector<3xf32> to vector<8x3xf32>
     VectorType dstVectorType = shapeCastOp.getResultVectorType();
-    if (srcIsScalar || isBroadcastableTo(srcVectorType, dstVectorType) ==
-                           BroadcastableToResult::Success) {
-      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
-          shapeCastOp, dstVectorType, broadcastOp.getSource());
-      return success();
+    ArrayRef<int64_t> dstShape = dstVectorType.getShape();
+    ArrayRef<int64_t> srcShape =
+        srcIsScalar ? ArrayRef<int64_t>{} : srcVectorType.getShape();
+    ArrayRef<int64_t> broadcastShape =
+        broadcastOp.getResultVectorType().getShape();
+
+    if (!srcIsScalar) {
+      if (isBroadcastableTo(srcVectorType, dstVectorType) !=
+          BroadcastableToResult::Success) {
+        return failure();
+      }
+      // Avoid folding if this would result in switching between the two
+      // distinct semantic modes of vector.broadcast (duplication vs
+      // stretching). See https://github.com/llvm/llvm-project/issues/190614.
+      // This is detected by a change in the stretching factor. However if the
+      // source has a single element, there is no ambiguity.
+      if (srcVectorType.getNumElements() != 1) {
+        if (getBroadcastStretchingFactor(srcShape, dstShape) !=
+            getBroadcastStretchingFactor(srcShape, broadcastShape)) {
+          return failure();
+        }
+      }
     }
-    return failure();
+
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shapeCastOp, dstVectorType,
+                                                     broadcastOp.getSource());
+    return success();
   }
 };
 
diff --git a/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir b/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir
deleted file mode 100644
index b5c6e610f58f9..0000000000000
--- a/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir
+++ /dev/null
@@ -1,69 +0,0 @@
-// RUN: mlir-opt %s -transform-interpreter | FileCheck %s
-
-func.func @mmt4d_to_fma(%A: tensor<16x16x8x1xf32>, %B: tensor<16x16x8x1xf32>, %C_in: tensor<16x16x8x8xf32>) -> tensor<16x16x8x8xf32> {
-  %res = linalg.mmt4d
-                   ins(%A, %B: tensor<16x16x8x1xf32>, tensor<16x16x8x1xf32>)
-                   outs(%C_in: tensor<16x16x8x8xf32>)
-                   -> tensor<16x16x8x8xf32>
-  return %res : tensor<16x16x8x8xf32>
-}
-
-
-// CHECK-LABEL:     @mmt4d_to_fma
-// CHECK-COUNT-8:         vector.fma
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
-    %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func">
-
-    %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %func : (!transform.op<"func.func">) -> !transform.any_op
-
-    // Step 1: Tile
-    // Tile parallel dims
-    %tiled_linalg_op_p, %loops:4 = transform.structured.tile_using_for %mmt4d tile_sizes [1, 1, 0, 8, 8, 0]
-      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
-    // Tile reduction dims
-    %tiled_linalg_op_r, %loops2:2 = transform.structured.tile_using_for %tiled_linalg_op_p tile_sizes [0, 0, 1, 0, 0, 1]
-      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
-
-    // Step 2: Vectorize
-    transform.structured.vectorize %tiled_linalg_op_r : !transform.any_op
-
-    // Step 3: Simplify
-    // vector.multi_reduction --> vector.contract
-    // Generates a 6-dim vector.contract with the dim matching the original MMT4D Op
-    // and with the following split into parallel and reduction dims:
-    //    * parallel, parallel, reduction, parallel, parallel, reduction
-    transform.apply_patterns to %func {
-      transform.apply_patterns.vector.reduction_to_contract
-      // Reduce the rank of xfer ops. This transforms vector.contract to be
-      // more matmul-like and to enable the lowering to outer product Ops.
-      transform.apply_patterns.vector.transfer_permutation_patterns
-    } : !transform.op<"func.func">
-
-    // Hoisting and LICM - not strictly required
-    %func_h = transform.structured.hoist_redundant_vector_transfers %func
-      : (!transform.op<"func.func">) -> !transform.op<"func.func">
-    %all_loops = transform.structured.match interface{LoopLikeInterface} in %func_h
-      : (!transform.op<"func.func">) -> !transform.any_op
-    transform.apply_licm to %all_loops : !transform.any_op
-    transform.loop.hoist_loop_invariant_subsets %all_loops : !transform.any_op
-
-    // Simplify the 6-dim vector.contract into a 3-dim matmul-like
-    // vector.contract with the following split into parallel and reduction
-    // dims:
-    //    * parallel, parallel, reduction
-    transform.apply_patterns to %func_h {
-      transform.apply_patterns.vector.reduction_to_contract
-      transform.apply_patterns.vector.cast_away_vector_leading_one_dim
-      transform.apply_patterns.canonicalization
-    } : !transform.op<"func.func">
-
-    // Step 4: Lower vector.contract to vector.fma via vector.outerproduct
-    transform.apply_patterns to %func_h {
-      transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
-      transform.apply_patterns.vector.lower_outerproduct
-    } : !transform.op<"func.func">
-    transform.yield
-  }
-}
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir
index 1d082179207dc..7ae190d0b3c56 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir
@@ -128,3 +128,19 @@ func.func @extract_from_broadcast(%src: vector<1x1x1xf32>) -> vector<1xf32> {
   return %1: vector<1xf32>
 }
 
+// -----
+
+// https://github.com/llvm/llvm-project/issues/190614 — ShapeCastBroadcastFolder
+// must not rewrite shape_cast(broadcast) into a single broadcast when that would
+// switch between duplication (new leading dims) and stretching (unit src dims).
+// For %arg0 = [[1], [2]]: broadcast to 2x2x1 then shape_cast to 2x2 yields
+// [[1, 2], [1, 2]]; folding to broadcast 2x1 -> 2x2 would incorrectly yield
+// [[1, 1], [2, 2]].
+// CHECK-LABEL: @no_fold_bcast_mode_switch
+// CHECK:         vector.broadcast %{{.*}} : vector<2x1xf32> to vector<2x2x1xf32>
+// CHECK-NEXT:    vector.shape_cast %{{.*}} : vector<2x2x1xf32> to vector<2x2xf32>
+func.func @no_fold_bcast_mode_switch(%arg0: vector<2x1xf32>) -> vector<2x2xf32> {
+  %0 = vector.broadcast %arg0 : vector<2x1xf32> to vector<2x2x1xf32>
+  %1 = vector.shape_cast %0 : vector<2x2x1xf32> to vector<2x2xf32>
+  return %1 : vector<2x2xf32>
+}

>From 0802e64c9e300a7f5aeb3d45853814210c87dcc4 Mon Sep 17 00:00:00 2001
From: Benoit Jacob <benoit.jacob at amd.com>
Date: Tue, 7 Apr 2026 14:43:09 +0000
Subject: [PATCH 2/2] add vector.multireduction to vector.fma test

Signed-off-by: Benoit Jacob <benoit.jacob at amd.com>
---
 .../Vector/vector-multi-reduction-to-fma.mlir | 51 +++++++++++++++++++
 1 file changed, 51 insertions(+)
 create mode 100644 mlir/test/Dialect/Vector/vector-multi-reduction-to-fma.mlir

diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-to-fma.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-to-fma.mlir
new file mode 100644
index 0000000000000..75ac060bcc666
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-to-fma.mlir
@@ -0,0 +1,51 @@
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+
+// `arith.mulf` + `vector.multi_reduction` matches what `linalg.matmul` produces
+// after `transform.structured.vectorize` (with matmul-style transfer_read layout).
+// This checks the follow-on stack: multi_reduction → contract → (transfer layout
+// cleanup) → outer-product-style contraction lowering → `vector.fma`.
+
+#map_lhs = affine_map<(d0, d1) -> (d0, 0, d1)>
+#map_rhs = affine_map<(d0, d1) -> (0, d1, d0)>
+
+// CHECK-LABEL: func @multi_reduction_to_fma
+// CHECK-SAME: memref<3x4xf32>
+// CHECK-SAME: memref<4x3xf32>
+// CHECK-SAME: memref<3x3xf32>
+// CHECK-DAG: vector.transfer_read {{.*}} : memref<3x4xf32>, vector<3x4xf32>
+// CHECK-DAG: vector.transfer_read {{.*}} : memref<4x3xf32>, vector<4x3xf32>
+// CHECK-DAG: vector.transfer_read {{.*}} : memref<3x3xf32>, vector<3x3xf32>
+// Each of the 4 K steps becomes one outer product, which lowers to 3 FMAs (one per accumulator row): 4 × 3 = 12.
+// CHECK-COUNT-12: vector.fma
+// CHECK-NOT: vector.multi_reduction
+// CHECK-NOT: vector.contract
+// CHECK-NOT: vector.outerproduct
+// CHECK: vector.transfer_write {{.*}} : vector<3x3xf32>, memref<3x3xf32>
+func.func @multi_reduction_to_fma(%A: memref<3x4xf32>, %B: memref<4x3xf32>, %C: memref<3x3xf32>) {
+  %c0 = arith.constant 0 : index
+  %p = ub.poison : f32
+  %va = vector.transfer_read %A[%c0, %c0], %p {permutation_map = #map_lhs} : memref<3x4xf32>, vector<3x3x4xf32>
+  %vb = vector.transfer_read %B[%c0, %c0], %p {permutation_map = #map_rhs} : memref<4x3xf32>, vector<3x3x4xf32>
+  %vc = vector.transfer_read %C[%c0, %c0], %p : memref<3x3xf32>, vector<3x3xf32>
+  %mul = arith.mulf %va, %vb : vector<3x3x4xf32>
+  %acc = vector.multi_reduction <add>, %mul, %vc [2] : vector<3x3x4xf32> to vector<3x3xf32>
+  vector.transfer_write %acc, %C[%c0, %c0] : vector<3x3xf32>, memref<3x3xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+    %f = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %f {
+      transform.apply_patterns.vector.reduction_to_contract
+      transform.apply_patterns.vector.transfer_permutation_patterns
+    } : !transform.any_op
+    transform.apply_patterns to %f {
+      transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
+    } : !transform.any_op
+    transform.apply_patterns to %f {
+      transform.apply_patterns.vector.lower_outerproduct
+    } : !transform.any_op
+    transform.yield
+  }
+}



More information about the Mlir-commits mailing list