[Mlir-commits] [mlir] [mlir][vector] Canonicalize vector.extract and vector.broadcast to vector.shape_cast (PR #174452)

James Newling llvmlistbot at llvm.org
Mon Jan 5 09:54:39 PST 2026


https://github.com/newling updated https://github.com/llvm/llvm-project/pull/174452

>From 6682fd0c588f61a8c38dadcd1dacb2ac2b2b60fa Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 5 Jan 2026 09:38:50 -0800
Subject: [PATCH 1/2] Add canonicalizer patterns and tests

Signed-off-by: James Newling <james.newling at gmail.com>
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      |  96 +++++++++----
 mlir/test/Dialect/Vector/canonicalize.mlir    |  22 +--
 .../canonicalize/vector-from-elements.mlir    |   2 +-
 .../canonicalize/vector-to-shape-cast.mlir    | 130 ++++++++++++++++++
 .../vector-transfer-to-vector-load-store.mlir |   8 +-
 .../Vector/vector-warp-distribute.mlir        |   4 +-
 6 files changed, 219 insertions(+), 43 deletions(-)
 create mode 100644 mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 12bdc9646ee84..27dd4e2e034fc 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2359,11 +2359,48 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
   return success();
 }
 
+/// Replace `vector.extract` with `vector.shape_cast`.
+///
+/// BEFORE:
+/// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
+/// AFTER:
+/// %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
+///
+/// The canonical form of vector operations that reshape vectors is shape_cast.
+struct ExtractToShapeCast final : public OpRewritePattern<vector::ExtractOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType sourceType = extractOp.getSourceVectorType();
+    VectorType outType = dyn_cast<VectorType>(extractOp.getType());
+    if (!outType)
+      return failure();
+
+    if (sourceType.getNumElements() != outType.getNumElements())
+      return rewriter.notifyMatchFailure(
+          extractOp, "extract to vector with fewer elements");
+
+    // Negative values in `position` mean that the extracted value is poison.
+    // There is a vector.extract folder for this.
+    if (llvm::any_of(extractOp.getMixedPosition(),
+                     [](OpFoldResult v) { return !isConstantIntValue(v, 0); }))
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "leaving for extract poison folder");
+
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, outType,
+                                                     extractOp.getSource());
+
+    return success();
+  }
+};
+
 } // namespace
 
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
+  results
+      .add<ExtractOpFromBroadcast, ExtractOpFromCreateMask, ExtractToShapeCast>(
+          context);
   results.add(foldExtractFromShapeCastToShapeCast);
   results.add(foldExtractFromFromElements);
 }
@@ -3076,13 +3113,43 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
     return success();
   }
 };
+
+/// Replace `vector.broadcast` with `vector.shape_cast`.
+///
+/// BEFORE:
+/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+/// AFTER:
+/// %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+///
+/// The canonical form of vector operations that reshape vectors is shape_cast.
+struct BroadcastToShapeCast final
+    : public OpRewritePattern<vector::BroadcastOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::BroadcastOp broadcast,
+                                PatternRewriter &rewriter) const override {
+
+    auto sourceType = dyn_cast<VectorType>(broadcast.getSourceType());
+    if (!sourceType) {
+      return rewriter.notifyMatchFailure(
+          broadcast, "source is a scalar, shape_cast doesn't support scalar");
+    }
+
+    VectorType outType = broadcast.getType();
+    if (sourceType.getNumElements() != outType.getNumElements()) {
+      return rewriter.notifyMatchFailure(
+          broadcast, "broadcast to a greater number of elements");
+    }
+
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(broadcast, outType,
+                                                     broadcast.getSource());
+    return success();
+  }
+};
 } // namespace
 
 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  // BroadcastToShapeCast is not a default canonicalization, it is opt-in by
-  // calling `populateCastAwayVectorLeadingOneDimPatterns`
-  results.add<BroadcastFolder>(context);
+  results.add<BroadcastFolder, BroadcastToShapeCast>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -6479,10 +6546,7 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
   }
 };
 
-/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either
-///   i) Y = ShapeCast(X), or
-///  ii) Y = Broadcast(X)
-/// If both (i) and (ii) are possible, (i) is chosen.
+/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as Y = Broadcast(X)
 class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
 public:
   using Base::Base;
@@ -6497,22 +6561,6 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
     auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
     bool srcIsScalar = !srcVectorType;
 
-    // Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X).
-    // Example:
-    // %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32>
-    // %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32>
-    // to
-    // %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32>
-    if (srcVectorType) {
-      if (srcVectorType.getNumElements() ==
-          shapeCastOp.getResultVectorType().getNumElements()) {
-        rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
-            shapeCastOp, shapeCastOp.getResultVectorType(),
-            broadcastOp.getSource());
-        return success();
-      }
-    }
-
     // Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X)
     // Example
     // %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32>
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index e17b1cfbe5e0d..fdc8e487ff6bb 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -814,7 +814,7 @@ func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
 
 // CHECK-LABEL: negative_fold_extract_broadcast
 //       CHECK:   vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32>
-//       CHECK:   vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x1x4xf32>
+//       CHECK:   vector.shape_cast %{{.*}} : vector<1x1x4xf32> to vector<4xf32>
 func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32> {
   %b = vector.broadcast %a : vector<1x1xf32> to vector<1x1x4xf32>
   %r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x1x4xf32>
@@ -931,7 +931,7 @@ func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : inde
 
 // CHECK-LABEL: fold_extract_broadcastlike_shape_cast
 //  CHECK-SAME:   %[[A:.*]]: vector<1xf32>
-//       CHECK:   %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<1x1xf32>
+//       CHECK:   %[[R:.*]] = vector.shape_cast %[[A]] : vector<1xf32> to vector<1x1xf32>
 //       CHECK:   return %[[R]] : vector<1x1xf32>
 func.func @fold_extract_broadcastlike_shape_cast(%a : vector<1xf32>, %idx0 : index)
   -> vector<1x1xf32> {
@@ -1561,7 +1561,7 @@ func.func @extract_strided_broadcast4(%arg0: f32) -> vector<1x4xf32> {
 
 // -----
 
-// Check the case where the same dimension is both broadcasted and sliced 
+// Check the case where the same dimension is both broadcasted and sliced
 // CHECK-LABEL: func @extract_strided_broadcast5
 //  CHECK-SAME: (%[[ARG:.+]]: vector<2x1xf32>)
 //       CHECK: %[[V:.+]] = vector.broadcast %[[ARG]] : vector<2x1xf32> to vector<2x4xf32>
@@ -1918,7 +1918,7 @@ func.func @store_to_load_tensor_perm_broadcast(%arg0 : tensor<4x4x4xf32>,
 //  CHECK-SAME: (%[[V0:.*]]: vector<4x8xf32>, %[[MEM:.*]]: tensor<1x1x4x8xf32>)
 // CHECK-NOT: vector.transfer_write
 // CHECK-NOT: vector.transfer_read
-// CHECK: %[[RET:.+]] = vector.broadcast %[[V0]] : vector<4x8xf32> to vector<1x1x4x8xf32>
+// CHECK: %[[RET:.+]] = vector.shape_cast %[[V0]] : vector<4x8xf32> to vector<1x1x4x8xf32>
 // CHECK: return %[[RET]]
 func.func @store_to_load_tensor_forwarding_unit_dim_broadcast(
     %vec: vector<4x8xf32>,
@@ -2197,8 +2197,8 @@ func.func @extract_strided_splatlike(%arg0: f16) -> vector<2x4xf16> {
 
 // CHECK-LABEL: func @insert_extract_to_broadcast
 //  CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>)
-//       CHECK:   %[[V0:.*]] = vector.extract %[[ARG0]][0, 0] : vector<4xf32> from vector<1x1x4xf32>
-//       CHECK:   %[[V1:.*]] = vector.broadcast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
+//       CHECK:   %[[V0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xf32> to vector<4xf32>
+//       CHECK:   %[[V1:.*]] = vector.shape_cast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
 //       CHECK:   return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32>
 func.func @insert_extract_to_broadcast(%arg0 : vector<1x1x4xf32>,
   %arg1 : vector<4xf32>) -> (vector<4xf32>, vector<1x1x4xf32>) {
@@ -2565,7 +2565,7 @@ func.func @shuffle_1d_rhs_poison() -> vector<4xi32> {
 
 // CHECK-LABEL: func @shuffle_canonicalize_0d
 func.func @shuffle_canonicalize_0d(%v0 : vector<i32>, %v1 : vector<i32>) -> vector<1xi32> {
-  // CHECK: vector.broadcast %{{.*}} : vector<i32> to vector<1xi32>
+  // CHECK: vector.shape_cast %{{.*}} : vector<i32> to vector<1xi32>
   %shuffle = vector.shuffle %v0, %v1 [0] : vector<i32>, vector<i32>
   return %shuffle : vector<1xi32>
 }
@@ -3047,7 +3047,7 @@ func.func @transfer_read_from_rank_reducing_extract_slice(%src: tensor<1x8x8x8xf
 func.func @extract_from_broadcast(%src: vector<1x1x1xf32>) -> vector<1xf32> {
   %0 = vector.broadcast %src : vector<1x1x1xf32> to vector<1x1x32x1xf32>
 
-  //  CHECK-NEXT:   %0 = vector.extract {{.*}}[0, 0] : vector<1xf32> from vector<1x1x1xf32>
+  //  CHECK-NEXT:   %0 = vector.shape_cast {{.*}} : vector<1x1x1xf32> to vector<1xf32>
   //  CHECK-NEXT:   return %0 : vector<1xf32>
   %1 = vector.extract %0[0, 0, 31] : vector<1xf32> from vector<1x1x32x1xf32>
   return %1: vector<1xf32>
@@ -3554,7 +3554,7 @@ func.func @from_elements_index_to_i64_conversion() -> vector<3xi64> {
 // +---------------------------------------------------------------------------
 
 // CHECK-LABEL: transpose_from_elements_1d
-// CHECK-SAME:  %[[EL_0:.*]]: i32, %[[EL_1:.*]]: i32 
+// CHECK-SAME:  %[[EL_0:.*]]: i32, %[[EL_1:.*]]: i32
 func.func @transpose_from_elements_1d(%el_0: i32, %el_1: i32) -> vector<2xi32> {
   %v = vector.from_elements %el_0, %el_1 : vector<2xi32>
   %t = vector.transpose %v, [0] : vector<2xi32> to vector<2xi32>
@@ -3565,7 +3565,7 @@ func.func @transpose_from_elements_1d(%el_0: i32, %el_1: i32) -> vector<2xi32> {
 }
 
 // CHECK-LABEL: transpose_from_elements_2d
-// CHECK-SAME:  %[[EL_0_0:.*]]: i32, %[[EL_0_1:.*]]: i32, %[[EL_0_2:.*]]: i32, %[[EL_1_0:.*]]: i32, %[[EL_1_1:.*]]: i32, %[[EL_1_2:.*]]: i32 
+// CHECK-SAME:  %[[EL_0_0:.*]]: i32, %[[EL_0_1:.*]]: i32, %[[EL_0_2:.*]]: i32, %[[EL_1_0:.*]]: i32, %[[EL_1_1:.*]]: i32, %[[EL_1_2:.*]]: i32
 func.func @transpose_from_elements_2d(
   %el_0_0: i32, %el_0_1: i32, %el_0_2: i32,
   %el_1_0: i32, %el_1_1: i32, %el_1_2: i32
@@ -3579,7 +3579,7 @@ func.func @transpose_from_elements_2d(
 }
 
 // CHECK-LABEL: transpose_from_elements_3d
-// CHECK-SAME:  %[[EL_0_0_0:.*]]: i32, %[[EL_0_0_1:.*]]: i32, %[[EL_0_1_0:.*]]: i32, %[[EL_0_1_1:.*]]: i32, %[[EL_0_2_0:.*]]: i32, %[[EL_0_2_1:.*]]: i32, %[[EL_1_0_0:.*]]: i32, %[[EL_1_0_1:.*]]: i32, %[[EL_1_1_0:.*]]: i32, %[[EL_1_1_1:.*]]: i32, %[[EL_1_2_0:.*]]: i32, %[[EL_1_2_1:.*]]: i32 
+// CHECK-SAME:  %[[EL_0_0_0:.*]]: i32, %[[EL_0_0_1:.*]]: i32, %[[EL_0_1_0:.*]]: i32, %[[EL_0_1_1:.*]]: i32, %[[EL_0_2_0:.*]]: i32, %[[EL_0_2_1:.*]]: i32, %[[EL_1_0_0:.*]]: i32, %[[EL_1_0_1:.*]]: i32, %[[EL_1_1_0:.*]]: i32, %[[EL_1_1_1:.*]]: i32, %[[EL_1_2_0:.*]]: i32, %[[EL_1_2_1:.*]]: i32
 func.func @transpose_from_elements_3d(
   %el_0_0_0: i32, %el_0_0_1: i32, %el_0_1_0: i32, %el_0_1_1: i32, %el_0_2_0: i32, %el_0_2_1: i32,
   %el_1_0_0: i32, %el_1_0_1: i32, %el_1_1_0: i32, %el_1_1_1: i32, %el_1_2_0: i32, %el_1_2_1: i32
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
index f43328f621787..aa6539e466c95 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
@@ -81,7 +81,7 @@ func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<
 
 // CHECK-LABEL: func @to_shape_cast_rank2_to_rank1(
 //  CHECK-SAME:       %[[A:.*]]: vector<1x2xi8>)
-//       CHECK:       %[[EXTRACT:.*]] = vector.extract %[[A]][0] : vector<2xi8> from vector<1x2xi8>
+//       CHECK:       %[[EXTRACT:.*]] = vector.shape_cast %[[A]] : vector<1x2xi8> to vector<2xi8>
 //       CHECK:       return %[[EXTRACT]] : vector<2xi8>
 func.func @to_shape_cast_rank2_to_rank1(%arg0: vector<1x2xi8>) -> vector<2xi8> {
   %0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir
new file mode 100644
index 0000000000000..86fc6a5c0d6fc
--- /dev/null
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir
@@ -0,0 +1,130 @@
+// RUN: mlir-opt %s -split-input-file -canonicalize |  FileCheck %s
+
+// This file contains tests where a vector.shape_cast is the result
+// of canonicalization.
+
+// **--------------------------------------------------------** //
+//   Tests of BroadcastToShapeCast
+// **--------------------------------------------------------** //
+
+// CHECK-LABEL: @broadcast_to_shape_cast
+//  CHECK-SAME: %[[ARG0:.*]]: vector<4xi8>
+//  CHECK-NEXT: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG0]]
+//  CHECK-NEXT: return %[[SHAPE_CAST]] : vector<1x1x4xi8>
+func.func @broadcast_to_shape_cast(%arg0 : vector<4xi8>) -> vector<1x1x4xi8> {
+  %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+  return %0 : vector<1x1x4xi8>
+}
+
+// -----
+
+// broadcast can only be transformed to a shape_cast if the number of elements is
+// unchanged by the broadcast.
+// CHECK-LABEL: @negative_broadcast_increased_elements_to_shape_cast
+//   CHECK-NOT: shape_cast
+//       CHECK: return
+func.func @negative_broadcast_increased_elements_to_shape_cast(%arg0 : vector<1x4xi8>) -> vector<2x3x4xi8> {
+  %0 = vector.broadcast %arg0 : vector<1x4xi8> to vector<2x3x4xi8>
+  return %0 : vector<2x3x4xi8>
+}
+
+// -----
+
+// shape_cast does not support scalar inputs/outputs, so a broadcast of a scalar
+// cannot be transformed to a shape_cast.
+// CHECK-LABEL: @negative_broadcast_scalar_to_shape_cast
+//   CHECK-NOT: shape_cast
+//       CHECK: return
+func.func @negative_broadcast_scalar_to_shape_cast(%arg0 : i8) -> vector<1xi8> {
+  %0 = vector.broadcast %arg0 : i8 to vector<1xi8>
+  return %0 : vector<1xi8>
+}
+
+// -----
+
+// In this test, broadcast (2)->(1,2,1) is not legal, but shape_cast (2)->(1,2,1) is.
+// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_shapecast
+//   CHECK-NOT:   vector.broadcast
+//       CHECK:   vector.shape_cast {{.+}} : vector<2xf32> to vector<1x2x1xf32>
+func.func @canonicalize_broadcast_shapecast_to_shapecast(%arg0 : vector<2xf32>) -> vector<1x2x1xf32> {
+  %0 = vector.broadcast %arg0 : vector<2xf32> to vector<1x2xf32>
+  %1 = vector.shape_cast %0 : vector<1x2xf32> to vector<1x2x1xf32>
+  return %1 : vector<1x2x1xf32>
+}
+
+// -----
+
+// In this test, broadcast (1)->(1,1) and shape_cast (1)->(1,1) are both legal. shape_cast is chosen.
+// CHECK-LABEL: func @canonicalize_broadcast_shapecast_both_possible
+//   CHECK-NOT:   vector.broadcast
+//       CHECK:   vector.shape_cast {{.+}} : vector<1xf32> to vector<1x1xf32>
+func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>) -> vector<1x1xf32> {
+    %0 = vector.broadcast %arg0 : vector<1xf32> to vector<1x1x1xf32>
+    %1 = vector.shape_cast %0 : vector<1x1x1xf32> to vector<1x1xf32>
+    return %1 : vector<1x1xf32>
+}
+
+// -----
+
+// **--------------------------------------------------------** //
+//   Tests of ExtractToShapeCast
+// **--------------------------------------------------------** //
+
+// CHECK-LABEL: @extract_to_shape_cast
+//  CHECK-SAME: %[[ARG0:.*]]: vector<1x4xf32>
+//  CHECK-NEXT: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG0]]
+//  CHECK-NEXT: return %[[SHAPE_CAST]] : vector<4xf32>
+func.func @extract_to_shape_cast(%arg0 : vector<1x4xf32>) -> vector<4xf32> {
+  %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
+  return %0 : vector<4xf32>
+}
+
+// -----
+
+// In this example, arg1 might be negative indicating poison. We could
+// convert this to shape_cast (would be a legal transform with poison)
+// but we conservatively choose not to.
+// CHECK-LABEL: @negative_extract_to_shape_cast
+//   CHECK-NOT: shape_cast
+func.func @negative_extract_to_shape_cast(%arg0 : vector<1x4xf32>, %arg1 : index) -> vector<4xf32> {
+  %0 = vector.extract %arg0[%arg1] : vector<4xf32> from vector<1x4xf32>
+  return %0 : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: fold_extract_shapecast_to_shapecast
+//  CHECK-SAME: (%[[ARG:.+]]: vector<3x4xf32>)
+//       CHECK:   %[[R:.+]] = vector.shape_cast %[[ARG]] : vector<3x4xf32> to vector<12xf32>
+//       CHECK:   return %[[R]]
+func.func @fold_extract_shapecast_to_shapecast(%arg0 : vector<3x4xf32>) -> vector<12xf32> {
+  %0 = vector.shape_cast %arg0 : vector<3x4xf32> to vector<1x12xf32>
+  %r = vector.extract %0[0] : vector<12xf32> from vector<1x12xf32>
+  return %r : vector<12xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_extract_to_shape_cast
+//  CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>)
+//       CHECK:   %[[V0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xf32> to vector<4xf32>
+//       CHECK:   %[[V1:.*]] = vector.shape_cast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
+//       CHECK:   return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32>
+func.func @insert_extract_to_shape_cast(%arg0 : vector<1x1x4xf32>,
+  %arg1 : vector<4xf32>) -> (vector<4xf32>, vector<1x1x4xf32>) {
+  %0 = vector.extract %arg0[0, 0] : vector<4xf32> from vector<1x1x4xf32>
+  %1 = vector.insert %arg1, %arg0 [0, 0] : vector<4xf32> into vector<1x1x4xf32>
+  return %0, %1 : vector<4xf32>, vector<1x1x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @extract_from_broadcast
+func.func @extract_from_broadcast(%src: vector<1x1x1xf32>) -> vector<1xf32> {
+  %0 = vector.broadcast %src : vector<1x1x1xf32> to vector<1x1x32x1xf32>
+  //  CHECK-NEXT:   %[[RES:.*]] = vector.shape_cast{{.*}} vector<1x1x1xf32> to vector<1xf32>
+  //  CHECK-NEXT:   return %[[RES]] : vector<1xf32>
+  %1 = vector.extract %0[0, 0, 31] : vector<1xf32> from vector<1x1x32x1xf32>
+  return %1: vector<1xf32>
+}
+
diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index 45afbffc1be48..8206c1a3ec865 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -24,7 +24,7 @@ func.func @vector_transfer_ops_0d_tensor(%src: tensor<f32>) -> vector<1xf32> {
     %f0 = arith.constant 0.0 : f32
 
 //  CHECK:   %[[S:.*]] = vector.transfer_read %[[SRC]][]
-//  CHECK:   %[[V:.*]] = vector.broadcast %[[S]] : vector<f32> to vector<1xf32>
+//  CHECK:   %[[V:.*]] = vector.shape_cast %[[S]] : vector<f32> to vector<1xf32>
     %res = vector.transfer_read %src[], %f0 {in_bounds = [true], permutation_map = affine_map<()->(0)>} :
       tensor<f32>, vector<1xf32>
 
@@ -369,8 +369,7 @@ func.func @transfer_write_broadcast_unit_dim_tensor(
   %c0 = arith.constant 0 : index
 
   %res = vector.transfer_write %vec_0, %dst_0[%c0, %c0, %c0, %c0] {in_bounds = [false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>} : vector<14x8x16xf32>, tensor<?x?x?x?xf32>
-  // CHECK: %[[NEW_VEC0:.*]] = vector.broadcast %{{.*}} : vector<14x8x16xf32> to vector<1x14x8x16xf32>
-  // CHECK: %[[NEW_VEC1:.*]] = vector.transpose %[[NEW_VEC0]], [1, 2, 0, 3] : vector<1x14x8x16xf32> to vector<14x8x1x16xf32>
+  // CHECK: %[[NEW_VEC1:.*]] = vector.shape_cast %{{.*}} : vector<14x8x16xf32> to vector<14x8x1x16xf32>
   // CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC1]], %[[DST0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [false, false, true, true]} : vector<14x8x1x16xf32>, tensor<?x?x?x?xf32>
 
   return %res : tensor<?x?x?x?xf32>
@@ -385,8 +384,7 @@ func.func @transfer_write_broadcast_unit_dim_memref(
   %c0 = arith.constant 0 : index
 
   vector.transfer_write %vec_0, %mem_0[%c0, %c0, %c0, %c0] {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>} : vector<8x16xf32>, memref<?x?x?x?xf32>
-  // CHECK: %[[NEW_VEC0:.*]] = vector.broadcast %{{.*}} : vector<8x16xf32> to vector<1x8x16xf32>
-  // CHECK: %[[NEW_VEC1:.*]] = vector.transpose %[[NEW_VEC0]], [1, 2, 0] : vector<1x8x16xf32> to vector<8x16x1xf32>
+  // CHECK: %[[NEW_VEC1:.*]] = vector.shape_cast %{{.*}} : vector<8x16xf32> to vector<8x16x1xf32>
   // CHECK: vector.transfer_write %[[NEW_VEC1]], %[[MEM0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [false, false, true]} : vector<8x16x1xf32>, memref<?x?x?x?xf32>
 
   return
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 135db02d543ef..2d0330043db06 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1534,7 +1534,7 @@ func.func @vector_insert_strided_slice_2d_to_2d(%laneid: index) -> (vector<64x1x
 //   CHECK-PROP-DAG:   %[[THREADID:.*]] = gpu.thread_id  x
 //       CHECK-PROP:   %[[W:.*]] = gpu.warp_execute_on_lane_0(%[[THREADID]])[32] args(%[[IN2]]
 //       CHECK-PROP:     %[[GATHER:.*]] = vector.gather %[[AR1]][{{.*}}]
-//       CHECK-PROP:     %[[EXTRACT:.*]] = vector.extract %[[GATHER]][0] : vector<64xi32> from vector<1x64xi32>
+//       CHECK-PROP:     %[[EXTRACT:.*]] = vector.shape_cast %[[GATHER]] : vector<1x64xi32> to vector<64xi32>
 //       CHECK-PROP:     %[[CAST:.*]] = arith.index_cast %[[EXTRACT]] : vector<64xi32> to vector<64xindex>
 //       CHECK-PROP:     %[[EXTRACTELT:.*]] = vector.extract %[[CAST]][{{.*}}] : index from vector<64xindex>
 //       CHECK-PROP:     gpu.yield %[[EXTRACTELT]] : index
@@ -1571,7 +1571,7 @@ func.func @transfer_read_prop_operands(%in2: vector<1x2xindex>, %ar1 :  memref<1
 // CHECK-PROP-LABEL: func @dont_fold_vector_broadcast(
 //       CHECK-PROP:   %[[r:.*]] = gpu.warp_execute_on_lane_0{{.*}} -> (vector<1x2xf32>)
 //       CHECK-PROP:     %[[some_def:.*]] = "some_def"
-//       CHECK-PROP:     %[[broadcast:.*]] = vector.broadcast %[[some_def]] : vector<64xf32> to vector<1x64xf32>
+//       CHECK-PROP:     %[[broadcast:.*]] = vector.shape_cast %[[some_def]] : vector<64xf32> to vector<1x64xf32>
 //       CHECK-PROP:     gpu.yield %[[broadcast]] : vector<1x64xf32>
 //       CHECK-PROP:   vector.print %[[r]] : vector<1x2xf32>
 func.func @dont_fold_vector_broadcast(%laneid: index) {

>From 1f49b3dc3b29e5b25b928d2cc1219636d5071f42 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 5 Jan 2026 09:59:06 -0800
Subject: [PATCH 2/2] update linalg test

---
 .../Dialect/Linalg/transform-op-peel-and-vectorize-conv.mlir    | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/Linalg/transform-op-peel-and-vectorize-conv.mlir b/mlir/test/Dialect/Linalg/transform-op-peel-and-vectorize-conv.mlir
index 4bb40bef9fba2..4660cc75a1940 100644
--- a/mlir/test/Dialect/Linalg/transform-op-peel-and-vectorize-conv.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-peel-and-vectorize-conv.mlir
@@ -24,7 +24,7 @@ func.func @conv(%arg0: tensor<1x1080x1962x48xi32>, %arg1: tensor<1x43x48xi32>) -
 // Loop over the Filter width dim
 // CHECK:                 scf.for %{{.*}} = %[[C0]] to %[[C_43]] step %[[C1]] {{.*}} -> (tensor<1x1x4x?xi32>) {
 // CHECK-NOT:               vector.mask
-// CHECK:                   vector.broadcast {{.*}} : vector<[4]xi32> to vector<1x4x[4]xi32>
+// CHECK:                   vector.broadcast {{.*}} : vector<1x[4]xi32> to vector<1x4x[4]xi32>
 // CHECK-NEXT:              arith.muli {{.*}} : vector<1x4x[4]xi32>
 // CHECK-NEXT:              arith.addi {{.*}} : vector<1x4x[4]xi32>
 // CHECK-NOT:               vector.mask



More information about the Mlir-commits mailing list