[Mlir-commits] [mlir] 824fc35 - [mlir][vector] Constrain broadcast->shape_cast folding (#190230)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Apr 7 11:26:06 PDT 2026


Author: Benoit Jacob
Date: 2026-04-07T14:26:01-04:00
New Revision: 824fc35bd9453e1776f098cc5d7123b1034ea510

URL: https://github.com/llvm/llvm-project/commit/824fc35bd9453e1776f098cc5d7123b1034ea510
DIFF: https://github.com/llvm/llvm-project/commit/824fc35bd9453e1776f098cc5d7123b1034ea510.diff

LOG: [mlir][vector] Constrain broadcast->shape_cast folding (#190230)

Fixes https://github.com/llvm/llvm-project/issues/190614.

Do not fold broadcast->shape_cast when that would result in switching
between the two distinct semantic modes of `vector.broadcast`, as
explained in https://github.com/llvm/llvm-project/issues/190614.

This fixes incorrect-result bugs in IREE:
https://github.com/iree-org/iree/issues/23952

---------

Signed-off-by: Benoit Jacob <benoit.jacob at amd.com>

Added: 
    mlir/test/Dialect/Vector/vector-multi-reduction-to-fma.mlir

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir

Removed: 
    mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 939816262a2b1..cf2da8ff54a1d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6703,6 +6703,22 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
   }
 };
 
+/// vector.broadcast has two semantic modes: duplication across new leading dims
+/// and stretching of unit source dims. Returns the product of the stretching
+/// factors for srcShape -> dstShape (rank(src) <= rank(dst)). NOTE: equal
+/// products don't imply equal element order: 1x2x1->2x2x1 vs 1x2x1->1x2x2.
+static int64_t getBroadcastStretchingFactor(ArrayRef<int64_t> srcShape,
+                                            ArrayRef<int64_t> dstShape) {
+  int64_t stretchingFactor = 1; // Not int: the product can exceed INT_MAX.
+  size_t numLeadingDims = dstShape.size() - srcShape.size();
+  for (size_t i = 0, e = srcShape.size(); i < e; ++i) {
+    int64_t dstDim = dstShape[numLeadingDims + i];
+    if (srcShape[i] == 1 && dstDim != 1)
+      stretchingFactor *= dstDim;
+  }
+  return stretchingFactor;
+}
+
 /// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as Y = Broadcast(X)
 class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
 public:
@@ -6725,13 +6741,33 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
     // to
     // %1 = vector.broadcast %in : vector<3xf32> to vector<8x3xf32>
     VectorType dstVectorType = shapeCastOp.getResultVectorType();
-    if (srcIsScalar || isBroadcastableTo(srcVectorType, dstVectorType) ==
-                           BroadcastableToResult::Success) {
-      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
-          shapeCastOp, dstVectorType, broadcastOp.getSource());
-      return success();
+    ArrayRef<int64_t> dstShape = dstVectorType.getShape();
+    ArrayRef<int64_t> srcShape =
+        srcIsScalar ? ArrayRef<int64_t>{} : srcVectorType.getShape();
+    ArrayRef<int64_t> broadcastShape =
+        broadcastOp.getResultVectorType().getShape();
+
+    if (!srcIsScalar) {
+      if (isBroadcastableTo(srcVectorType, dstVectorType) !=
+          BroadcastableToResult::Success) {
+        return failure();
+      }
+      // Avoid folding if this would result in switching between the two
+      // distinct semantic modes of vector.broadcast (duplication vs
+      // stretching). See https://github.com/llvm/llvm-project/issues/190614.
+      // This is detected by a change in the stretching factor. However if the
+      // source has a single element, there is no ambiguity.
+      if (srcVectorType.getNumElements() != 1) {
+        if (getBroadcastStretchingFactor(srcShape, dstShape) !=
+            getBroadcastStretchingFactor(srcShape, broadcastShape)) {
+          return failure();
+        }
+      }
     }
-    return failure();
+
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shapeCastOp, dstVectorType,
+                                                     broadcastOp.getSource());
+    return success();
   }
 };
 

diff  --git a/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir b/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir
deleted file mode 100644
index b5c6e610f58f9..0000000000000
--- a/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir
+++ /dev/null
@@ -1,69 +0,0 @@
-// RUN: mlir-opt %s -transform-interpreter | FileCheck %s
-
-func.func @mmt4d_to_fma(%A: tensor<16x16x8x1xf32>, %B: tensor<16x16x8x1xf32>, %C_in: tensor<16x16x8x8xf32>) -> tensor<16x16x8x8xf32> {
-  %res = linalg.mmt4d
-                   ins(%A, %B: tensor<16x16x8x1xf32>, tensor<16x16x8x1xf32>)
-                   outs(%C_in: tensor<16x16x8x8xf32>)
-                   -> tensor<16x16x8x8xf32>
-  return %res : tensor<16x16x8x8xf32>
-}
-
-
-// CHECK-LABEL:     @mmt4d_to_fma
-// CHECK-COUNT-8:         vector.fma
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
-    %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func">
-
-    %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %func : (!transform.op<"func.func">) -> !transform.any_op
-
-    // Step 1: Tile
-    // Tile parallel dims
-    %tiled_linalg_op_p, %loops:4 = transform.structured.tile_using_for %mmt4d tile_sizes [1, 1, 0, 8, 8, 0]
-      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
-    // Tile reduction dims
-    %tiled_linalg_op_r, %loops2:2 = transform.structured.tile_using_for %tiled_linalg_op_p tile_sizes [0, 0, 1, 0, 0, 1]
-      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
-
-    // Step 2: Vectorize
-    transform.structured.vectorize %tiled_linalg_op_r : !transform.any_op
-
-    // Step 3: Simplify
-    // vector.multi_reduction --> vector.contract
-    // Generates a 6-dim vector.contract with the dim matching the original MMT4D Op
-    // and with the following split into parallel and reduction dims:
-    //    * parallel, parallel, reduction, parallel, parallel, reduction
-    transform.apply_patterns to %func {
-      transform.apply_patterns.vector.reduction_to_contract
-      // Reduce the rank of xfer ops. This transforms vector.contract to be
-      // more matmul-like and to enable the lowering to outer product Ops.
-      transform.apply_patterns.vector.transfer_permutation_patterns
-    } : !transform.op<"func.func">
-
-    // Hoisting and LICM - not strictly required
-    %func_h = transform.structured.hoist_redundant_vector_transfers %func
-      : (!transform.op<"func.func">) -> !transform.op<"func.func">
-    %all_loops = transform.structured.match interface{LoopLikeInterface} in %func_h
-      : (!transform.op<"func.func">) -> !transform.any_op
-    transform.apply_licm to %all_loops : !transform.any_op
-    transform.loop.hoist_loop_invariant_subsets %all_loops : !transform.any_op
-
-    // Simplify the 6-dim vector.contract into a 3-dim matmul-like
-    // vector.contract with the following split into parallel and reduction
-    // dims:
-    //    * parallel, parallel, reduction
-    transform.apply_patterns to %func_h {
-      transform.apply_patterns.vector.reduction_to_contract
-      transform.apply_patterns.vector.cast_away_vector_leading_one_dim
-      transform.apply_patterns.canonicalization
-    } : !transform.op<"func.func">
-
-    // Step 4: Lower vector.contract to vector.fma via vector.outerproduct
-    transform.apply_patterns to %func_h {
-      transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
-      transform.apply_patterns.vector.lower_outerproduct
-    } : !transform.op<"func.func">
-    transform.yield
-  }
-}

diff  --git a/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir
index 1d082179207dc..7ae190d0b3c56 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir
@@ -128,3 +128,19 @@ func.func @extract_from_broadcast(%src: vector<1x1x1xf32>) -> vector<1xf32> {
   return %1: vector<1xf32>
 }
 
+// -----
+
+// https://github.com/llvm/llvm-project/issues/190614 — ShapeCastBroadcastFolder
+// must not rewrite shape_cast(broadcast) into a lower-rank broadcast when that
+// changes duplication (new leading dims) vs stretching (source unit dim).
+// For %arg0 = [[1], [2]]: broadcast to 2x2x1 then shape_cast to 2x2 yields
+// [[1, 2], [1, 2]]; folding to broadcast 2x1 -> 2x2 would incorrectly yield
+// [[1, 1], [2, 2]].
+// CHECK-LABEL: @no_fold_bcast_mode_switch
+// CHECK:         vector.broadcast %{{.*}} : vector<2x1xf32> to vector<2x2x1xf32>
+// CHECK-NEXT:    vector.shape_cast %{{.*}} : vector<2x2x1xf32> to vector<2x2xf32>
+func.func @no_fold_bcast_mode_switch(%arg0: vector<2x1xf32>) -> vector<2x2xf32> {
+  %bcast = vector.broadcast %arg0 : vector<2x1xf32> to vector<2x2x1xf32>
+  %flat = vector.shape_cast %bcast : vector<2x2x1xf32> to vector<2x2xf32>
+  return %flat : vector<2x2xf32>
+}

diff  --git a/mlir/test/Dialect/Vector/vector-multi-reduction-to-fma.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-to-fma.mlir
new file mode 100644
index 0000000000000..75ac060bcc666
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-to-fma.mlir
@@ -0,0 +1,51 @@
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+
+// `arith.mulf` + `vector.multi_reduction` matches what `linalg.matmul` produces
+// after `transform.structured.vectorize` (with matmul-style transfer_read layout).
+// This checks the follow-on stack: multi_reduction → contract → (transfer layout
+// cleanup) → outer-product-style contraction lowering → `vector.fma`.
+
+#map_lhs = affine_map<(d0, d1) -> (d0, 0, d1)>
+#map_rhs = affine_map<(d0, d1) -> (0, d1, d0)>
+
+// CHECK-LABEL: func @multi_reduction_to_fma
+// CHECK-SAME: memref<3x4xf32>
+// CHECK-SAME: memref<4x3xf32>
+// CHECK-SAME: memref<3x3xf32>
+// CHECK-DAG: vector.transfer_read {{.*}} : memref<3x4xf32>, vector<3x4xf32>
+// CHECK-DAG: vector.transfer_read {{.*}} : memref<4x3xf32>, vector<4x3xf32>
+// CHECK-DAG: vector.transfer_read {{.*}} : memref<3x3xf32>, vector<3x3xf32>
+// Each of the 4 K steps becomes an outer product, lowered to one FMA per output row: 4 × 3 = 12.
+// CHECK-COUNT-12: vector.fma
+// CHECK-NOT: vector.multi_reduction
+// CHECK-NOT: vector.contract
+// CHECK-NOT: vector.outerproduct
+// CHECK: vector.transfer_write {{.*}} : vector<3x3xf32>, memref<3x3xf32>
+func.func @multi_reduction_to_fma(%A: memref<3x4xf32>, %B: memref<4x3xf32>, %C: memref<3x3xf32>) {
+  %idx0 = arith.constant 0 : index
+  %pad = ub.poison : f32
+  %lhs = vector.transfer_read %A[%idx0, %idx0], %pad {permutation_map = #map_lhs} : memref<3x4xf32>, vector<3x3x4xf32>
+  %rhs = vector.transfer_read %B[%idx0, %idx0], %pad {permutation_map = #map_rhs} : memref<4x3xf32>, vector<3x3x4xf32>
+  %init = vector.transfer_read %C[%idx0, %idx0], %pad : memref<3x3xf32>, vector<3x3xf32>
+  %prod = arith.mulf %lhs, %rhs : vector<3x3x4xf32>
+  %sum = vector.multi_reduction <add>, %prod, %init [2] : vector<3x3x4xf32> to vector<3x3xf32>
+  vector.transfer_write %sum, %C[%idx0, %idx0] : vector<3x3xf32>, memref<3x3xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %funcs = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %funcs {
+      transform.apply_patterns.vector.reduction_to_contract
+      transform.apply_patterns.vector.transfer_permutation_patterns
+    } : !transform.any_op
+    transform.apply_patterns to %funcs {
+      transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
+    } : !transform.any_op
+    transform.apply_patterns to %funcs {
+      transform.apply_patterns.vector.lower_outerproduct
+    } : !transform.any_op
+    transform.yield
+  }
+}


        


More information about the Mlir-commits mailing list