[Mlir-commits] [mlir] [mlir][vector] Folder: shape_cast(extract) -> extract (PR #146368)

James Newling llvmlistbot at llvm.org
Fri Jul 18 10:11:46 PDT 2025


https://github.com/newling updated https://github.com/llvm/llvm-project/pull/146368

>From 09ba159afe75e1ff476ff82d51668471699d40ed Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 30 Jun 2025 09:15:48 -0700
Subject: [PATCH 1/6] extend to broadcastlike, code simplifications

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 126 ++++++++++-----------
 mlir/test/Dialect/Vector/canonicalize.mlir |  46 +++++++-
 2 files changed, 104 insertions(+), 68 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 7d615bfc12984..cfad95a7aee79 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1707,59 +1707,71 @@ static bool hasZeroDimVectors(Operation *op) {
          llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
 }
 
+/// All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepends
+/// 1s, are considered 'broadcastlike'.
+static bool isBroadcastLike(Operation *op) {
+  if (isa<BroadcastOp, SplatOp>(op))
+    return true;
+
+  auto shapeCast = dyn_cast<ShapeCastOp>(op);
+  if (!shapeCast)
+    return false;
+
+  // Check that it just prepends 1s, like (2,3) -> (1,1,2,3).
+  // Condition 1: dst has hight rank.
+  // Condition 2: src shape is a suffix of dst shape.
+  VectorType srcType = shapeCast.getSourceVectorType();
+  ArrayRef<int64_t> srcShape = srcType.getShape();
+  uint64_t srcRank = srcType.getRank();
+  ArrayRef<int64_t> dstShape = shapeCast.getType().getShape();
+  return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape;
+}
+
 /// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
 static Value foldExtractFromBroadcast(ExtractOp extractOp) {
-  Operation *defOp = extractOp.getVector().getDefiningOp();
-  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
+
+  Operation *broadcastLikeOp = extractOp.getVector().getDefiningOp();
+  if (!broadcastLikeOp || !isBroadcastLike(broadcastLikeOp))
     return Value();
 
-  Value source = defOp->getOperand(0);
-  if (extractOp.getType() == source.getType())
-    return source;
-  auto getRank = [](Type type) {
-    return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
-                                       : 0;
-  };
+  Value src = broadcastLikeOp->getOperand(0);
+
+  // Replace extract(broadcast(X)) with X
+  if (extractOp.getType() == src.getType())
+    return src;
 
-  // If splat or broadcast from a scalar, just return the source scalar.
-  unsigned broadcastSrcRank = getRank(source.getType());
-  if (broadcastSrcRank == 0 && source.getType() == extractOp.getType())
-    return source;
+  // Get required types and ranks in the chain
+  //    src -> broadcastDst -> dst
+  auto srcType = llvm::dyn_cast<VectorType>(src.getType());
+  auto dstType = llvm::dyn_cast<VectorType>(extractOp.getType());
+  unsigned srcRank = srcType ? srcType.getRank() : 0;
+  unsigned broadcastDstRank = extractOp.getSourceVectorType().getRank();
+  unsigned dstRank = dstType ? dstType.getRank() : 0;
 
-  unsigned extractResultRank = getRank(extractOp.getType());
-  if (extractResultRank > broadcastSrcRank)
+  // Cannot do without the broadcast if overall the rank increases.
+  if (dstRank > srcRank)
     return Value();
-  // Check that the dimension of the result haven't been broadcasted.
-  auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
-  auto broadcastVecType = llvm::dyn_cast<VectorType>(source.getType());
-  if (extractVecType && broadcastVecType &&
-      extractVecType.getShape() !=
-          broadcastVecType.getShape().take_back(extractResultRank))
+
+  assert(srcType && "src must be a vector type because of previous checks");
+
+  ArrayRef<int64_t> srcShape = srcType.getShape();
+  if (dstType && dstType.getShape() != srcShape.take_back(dstRank))
     return Value();
 
-  auto broadcastOp = cast<vector::BroadcastOp>(defOp);
-  int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
+  // Replace extract(broadcast(X)) with extract(X).
+  // First, determine the new extraction position.
+  unsigned deltaOverall = srcRank - dstRank;
+  unsigned deltaBroadcast = broadcastDstRank - srcRank;
 
-  // Detect all the positions that come from "dim-1" broadcasting.
-  // These dimensions correspond to "dim-1" broadcasted dims; set the mathching
-  // extract position to `0` when extracting from the source operand.
-  llvm::SetVector<int64_t> broadcastedUnitDims =
-      broadcastOp.computeBroadcastedUnitDims();
-  SmallVector<OpFoldResult> extractPos(extractOp.getMixedPosition());
-  OpBuilder b(extractOp.getContext());
-  int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
-  for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
-    if (broadcastedUnitDims.contains(i))
-      extractPos[i] = b.getIndexAttr(0);
-  // `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
-  // matching extract position when extracting from the source operand.
-  int64_t rankDiff = broadcastSrcRank - extractResultRank;
-  extractPos.erase(extractPos.begin(),
-                   std::next(extractPos.begin(), extractPos.size() - rankDiff));
-  // OpBuilder is only used as a helper to build an I64ArrayAttr.
-  auto [staticPos, dynPos] = decomposeMixedValues(extractPos);
+  SmallVector<OpFoldResult> oldPositions = extractOp.getMixedPosition();
+  SmallVector<OpFoldResult> newPositions(deltaOverall);
+  IntegerAttr zero = OpBuilder(extractOp.getContext()).getIndexAttr(0);
+  for (auto [i, size] : llvm::enumerate(srcShape.take_front(deltaOverall))) {
+    newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
+  }
+  auto [staticPos, dynPos] = decomposeMixedValues(newPositions);
   extractOp->setOperands(
-      llvm::to_vector(llvm::concat<Value>(ValueRange(source), dynPos)));
+      llvm::to_vector(llvm::concat<Value>(ValueRange(src), dynPos)));
   extractOp.setStaticPosition(staticPos);
   return extractOp.getResult();
 }
@@ -2204,32 +2216,18 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
 
   LogicalResult matchAndRewrite(ExtractOp extractOp,
                                 PatternRewriter &rewriter) const override {
-    Operation *defOp = extractOp.getVector().getDefiningOp();
-    if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
-      return failure();
 
-    Value source = defOp->getOperand(0);
-    if (extractOp.getType() == source.getType())
+    Operation *broadcastLikeOp = extractOp.getVector().getDefiningOp();
+    VectorType outType = dyn_cast<VectorType>(extractOp.getType());
+    if (!broadcastLikeOp || !isBroadcastLike(broadcastLikeOp) || !outType)
       return failure();
-    auto getRank = [](Type type) {
-      return llvm::isa<VectorType>(type)
-                 ? llvm::cast<VectorType>(type).getRank()
-                 : 0;
-    };
-    unsigned broadcastSrcRank = getRank(source.getType());
-    unsigned extractResultRank = getRank(extractOp.getType());
-    // We only consider the case where the rank of the source is less than or
-    // equal to the rank of the extract dst. The other cases are handled in the
-    // folding patterns.
-    if (extractResultRank < broadcastSrcRank)
-      return failure();
-    // For scalar result, the input can only be a rank-0 vector, which will
-    // be handled by the folder.
-    if (extractResultRank == 0)
+
+    Value source = broadcastLikeOp->getOperand(0);
+    if (isBroadcastableTo(source.getType(), outType) !=
+        BroadcastableToResult::Success)
       return failure();
 
-    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
-        extractOp, extractOp.getType(), source);
+    rewriter.replaceOpWithNewOp<BroadcastOp>(extractOp, outType, source);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index ea2343efd246e..6ed64cb8313c2 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -823,10 +823,10 @@ func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32
 
 // -----
 
-// CHECK-LABEL: fold_extract_splat
+// CHECK-LABEL: fold_extract_scalar_from_splat
 //  CHECK-SAME:   %[[A:.*]]: f32
 //       CHECK:   return %[[A]] : f32
-func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
+func.func @fold_extract_scalar_from_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
   %b = vector.splat %a : vector<1x2x4xf32>
   %r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
   return %r : f32
@@ -834,6 +834,16 @@ func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : in
 
 // -----
 
+// CHECK-LABEL: fold_extract_vector_from_splat
+//       CHECK: vector.broadcast {{.*}} f32 to vector<4xf32>
+func.func @fold_extract_vector_from_splat(%a : f32, %idx0 : index, %idx1 : index) -> vector<4xf32> {
+  %b = vector.splat %a : vector<1x2x4xf32>
+  %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
+  return %r : vector<4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting
 //  CHECK-SAME:   %[[A:.*]]: vector<2x1xf32>
 //  CHECK-SAME:   %[[IDX:.*]]: index, %[[IDX1:.*]]: index, %[[IDX2:.*]]: index
@@ -863,6 +873,21 @@ func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
 
 // -----
 
+// Test where the shape_cast is broadcast-like.
+// CHECK-LABEL: fold_extract_shape_cast_to_lower_rank
+//  CHECK-SAME:   %[[A:.*]]: vector<2x4xf32>
+//  CHECK-SAME:   %[[IDX0:.*]]: index, %[[IDX1:.*]]: index
+//       CHECK:   %[[B:.+]] = vector.extract %[[A]][%[[IDX1]]] : vector<4xf32> from vector<2x4xf32>
+//       CHECK:   return %[[B]] : vector<4xf32>
+func.func @fold_extract_shape_cast_to_lower_rank(%a : vector<2x4xf32>,
+  %idx0 : index, %idx1 : index) -> vector<4xf32> {
+  %b = vector.shape_cast %a : vector<2x4xf32> to vector<1x2x4xf32>
+  %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
+  return %r : vector<4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: fold_extract_broadcast_to_higher_rank
 //       CHECK:   %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
 //       CHECK:   return %[[B]] : vector<4xf32>
@@ -890,6 +915,19 @@ func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : inde
 
 // -----
 
+// CHECK-LABEL: fold_extract_broadcastlike_shape_cast
+//  CHECK-SAME:   %[[A:.*]]: vector<1xf32>
+//       CHECK:   %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<1x1xf32>
+//       CHECK:   return %[[R]] : vector<1x1xf32>
+func.func @fold_extract_broadcastlike_shape_cast(%a : vector<1xf32>, %idx0 : index)
+  -> vector<1x1xf32> {
+  %s = vector.shape_cast %a : vector<1xf32> to vector<1x1x1xf32>
+  %r = vector.extract %s[%idx0] : vector<1x1xf32> from vector<1x1x1xf32>
+  return %r : vector<1x1xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @fold_extract_shuffle
 //  CHECK-SAME:   %[[A:.*]]: vector<8xf32>, %[[B:.*]]: vector<8xf32>
 //   CHECK-NOT:   vector.shuffle
@@ -1623,7 +1661,7 @@ func.func @negative_store_to_load_tensor_memref(
     %arg0 : tensor<?x?xf32>,
     %arg1 : memref<?x?xf32>,
     %v0 : vector<4x2xf32>
-  ) -> vector<4x2xf32> 
+  ) -> vector<4x2xf32>
 {
   %c0 = arith.constant 0 : index
   %cf0 = arith.constant 0.0 : f32
@@ -1680,7 +1718,7 @@ func.func @negative_store_to_load_tensor_broadcast_out_of_bounds(%arg0 : tensor<
 //       CHECK:   vector.transfer_read
 func.func @negative_store_to_load_tensor_broadcast_masked(
     %arg0 : tensor<?x?xf32>, %v0 : vector<4x2xf32>, %mask : vector<4x2xi1>)
-  -> vector<4x2x6xf32> 
+  -> vector<4x2x6xf32>
 {
   %c0 = arith.constant 0 : index
   %cf0 = arith.constant 0.0 : f32

>From 1c46b4eab4b8d1cc6000e0e78de13a5fa7ec9153 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 30 Jun 2025 11:32:03 -0700
Subject: [PATCH 2/6] improve comments, add test

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 92 +++++++++++++++-------
 mlir/test/Dialect/Vector/canonicalize.mlir | 14 ++++
 2 files changed, 78 insertions(+), 28 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index cfad95a7aee79..3ea8d0eb784c1 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1707,8 +1707,8 @@ static bool hasZeroDimVectors(Operation *op) {
          llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
 }
 
-/// All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepends
-/// 1s, are considered 'broadcastlike'.
+/// All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepend
+/// 1s, are considered to be 'broadcastlike'.
 static bool isBroadcastLike(Operation *op) {
   if (isa<BroadcastOp, SplatOp>(op))
     return true;
@@ -1717,9 +1717,12 @@ static bool isBroadcastLike(Operation *op) {
   if (!shapeCast)
     return false;
 
-  // Check that it just prepends 1s, like (2,3) -> (1,1,2,3).
+  // Check that shape_cast **only** prepends 1s, like (2,3) -> (1,1,2,3).
   // Condition 1: dst has hight rank.
   // Condition 2: src shape is a suffix of dst shape.
+  //
+  // Note that checking that dst shape has a prefix of 1s is not sufficient,
+  // for example (2,3) -> (1,3,2) is not broadcast-like.
   VectorType srcType = shapeCast.getSourceVectorType();
   ArrayRef<int64_t> srcShape = srcType.getShape();
   uint64_t srcRank = srcType.getRank();
@@ -1727,51 +1730,84 @@ static bool isBroadcastLike(Operation *op) {
   return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape;
 }
 
-/// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
+/// Fold extract(broadcast(X)) to either extract(X) or just X.
+///
+/// Example:
+///
+///        broadcast           extract
+/// (3, 4) --------> (2, 3, 4) ------> (4)
+///
+/// becomes
+///                  extract
+/// (3,4) ---------------------------> (4)
+///
+///
+/// The variable names used in this implementation use names which correspond to
+/// the above shapes as,
+///
+/// - (3, 4) is `input` shape.
+/// - (2, 3, 4) is `broadcast` shape.
+/// - (4) is `extract` shape.
+///
+/// This folding is possible when the trailing dimensions of the `input` shape
+/// match the `extract` shape, i.e. no extracted dimension was broadcast.
 static Value foldExtractFromBroadcast(ExtractOp extractOp) {
 
-  Operation *broadcastLikeOp = extractOp.getVector().getDefiningOp();
-  if (!broadcastLikeOp || !isBroadcastLike(broadcastLikeOp))
+  Operation *defOp = extractOp.getVector().getDefiningOp();
+  if (!defOp || !isBroadcastLike(defOp))
     return Value();
 
-  Value src = broadcastLikeOp->getOperand(0);
+  Value input = defOp->getOperand(0);
 
   // Replace extract(broadcast(X)) with X
-  if (extractOp.getType() == src.getType())
-    return src;
+  if (extractOp.getType() == input.getType())
+    return input;
 
   // Get required types and ranks in the chain
-  //    src -> broadcastDst -> dst
-  auto srcType = llvm::dyn_cast<VectorType>(src.getType());
-  auto dstType = llvm::dyn_cast<VectorType>(extractOp.getType());
-  unsigned srcRank = srcType ? srcType.getRank() : 0;
-  unsigned broadcastDstRank = extractOp.getSourceVectorType().getRank();
-  unsigned dstRank = dstType ? dstType.getRank() : 0;
+  //    input -> broadcast -> extract
+  auto inputType = llvm::dyn_cast<VectorType>(input.getType());
+  auto extractType = llvm::dyn_cast<VectorType>(extractOp.getType());
+  unsigned inputRank = inputType ? inputType.getRank() : 0;
+  unsigned broadcastRank = extractOp.getSourceVectorType().getRank();
+  unsigned extractRank = extractType ? extractType.getRank() : 0;
 
   // Cannot do without the broadcast if overall the rank increases.
-  if (dstRank > srcRank)
+  if (extractRank > inputRank)
     return Value();
 
-  assert(srcType && "src must be a vector type because of previous checks");
-
-  ArrayRef<int64_t> srcShape = srcType.getShape();
-  if (dstType && dstType.getShape() != srcShape.take_back(dstRank))
+  // Proof by contradiction that, at this point, input is a vector.
+  //     Suppose input is a scalar.
+  // ==> inputRank is 0.
+  // ==> extractRank is 0 (because extractRank <= inputRank).
+  // ==> extract is scalar (because rank-0 extraction is always scalar).
+  // ==> input and extract are scalar, so same type.
+  // ==> returned early (check same type).
+  //     Contradiction!
+  assert(inputType && "input must be a vector type because of previous checks");
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+
+  // In the case where there is a broadcast dimension in the suffix, it is not
+  // possible to replace extract(broadcast(X)) with extract(X). Example:
+  //
+  //     broadcast       extract
+  // (1) --------> (3,4) ------> (4)
+  if (extractType &&
+      extractType.getShape() != inputShape.take_back(extractRank))
     return Value();
 
   // Replace extract(broadcast(X)) with extract(X).
   // First, determine the new extraction position.
-  unsigned deltaOverall = srcRank - dstRank;
-  unsigned deltaBroadcast = broadcastDstRank - srcRank;
-
+  unsigned deltaOverall = inputRank - extractRank;
+  unsigned deltaBroadcast = broadcastRank - inputRank;
   SmallVector<OpFoldResult> oldPositions = extractOp.getMixedPosition();
   SmallVector<OpFoldResult> newPositions(deltaOverall);
   IntegerAttr zero = OpBuilder(extractOp.getContext()).getIndexAttr(0);
-  for (auto [i, size] : llvm::enumerate(srcShape.take_front(deltaOverall))) {
+  for (auto [i, size] : llvm::enumerate(inputShape.take_front(deltaOverall))) {
     newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
   }
   auto [staticPos, dynPos] = decomposeMixedValues(newPositions);
   extractOp->setOperands(
-      llvm::to_vector(llvm::concat<Value>(ValueRange(src), dynPos)));
+      llvm::to_vector(llvm::concat<Value>(ValueRange(input), dynPos)));
   extractOp.setStaticPosition(staticPos);
   return extractOp.getResult();
 }
@@ -2217,12 +2253,12 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
   LogicalResult matchAndRewrite(ExtractOp extractOp,
                                 PatternRewriter &rewriter) const override {
 
-    Operation *broadcastLikeOp = extractOp.getVector().getDefiningOp();
+    Operation *defOp = extractOp.getVector().getDefiningOp();
     VectorType outType = dyn_cast<VectorType>(extractOp.getType());
-    if (!broadcastLikeOp || !isBroadcastLike(broadcastLikeOp) || !outType)
+    if (!defOp || !isBroadcastLike(defOp) || !outType)
       return failure();
 
-    Value source = broadcastLikeOp->getOperand(0);
+    Value source = defOp->getOperand(0);
     if (isBroadcastableTo(source.getType(), outType) !=
         BroadcastableToResult::Success)
       return failure();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 6ed64cb8313c2..6809122974545 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -888,6 +888,20 @@ func.func @fold_extract_shape_cast_to_lower_rank(%a : vector<2x4xf32>,
 
 // -----
 
+// Test where the shape_cast is not broadcast-like: although the result shape has a leading 1, its trailing shape (4x2) differs from the source shape (2x4).
+// CHECK-LABEL: negative_fold_extract_shape_cast_to_lower_rank
+//  CHECK-NEXT: vector.shape_cast
+//  CHECK-NEXT: vector.extract
+//  CHECK-NEXT: return
+func.func @negative_fold_extract_shape_cast_to_lower_rank(%a : vector<2x4xf32>,
+  %idx0 : index, %idx1 : index) -> vector<2xf32> {
+  %b = vector.shape_cast %a : vector<2x4xf32> to vector<1x4x2xf32>
+  %r = vector.extract %b[%idx0, %idx1] : vector<2xf32> from vector<1x4x2xf32>
+  return %r : vector<2xf32>
+}
+
+// -----
+
 // CHECK-LABEL: fold_extract_broadcast_to_higher_rank
 //       CHECK:   %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
 //       CHECK:   return %[[B]] : vector<4xf32>

>From 302cb34913dc99f668b98742799f172e6292bb80 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Fri, 18 Jul 2025 09:34:11 -0700
Subject: [PATCH 3/6] comment improvements

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 35 ++++++++++++------------
 1 file changed, 17 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 3ea8d0eb784c1..31dba8781745f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1718,11 +1718,9 @@ static bool isBroadcastLike(Operation *op) {
     return false;
 
   // Check that shape_cast **only** prepends 1s, like (2,3) -> (1,1,2,3).
-  // Condition 1: dst has hight rank.
-  // Condition 2: src shape is a suffix of dst shape.
-  //
   // Note that checking that dst shape has a prefix of 1s is not sufficient,
-  // for example (2,3) -> (1,3,2) is not broadcast-like.
+  // for example (2,3) -> (1,3,2) is not broadcast-like. A sufficient condition
+  // is that the source shape is a suffix of the destination shape.
   VectorType srcType = shapeCast.getSourceVectorType();
   ArrayRef<int64_t> srcShape = srcType.getShape();
   uint64_t srcRank = srcType.getRank();
@@ -1734,16 +1732,16 @@ static bool isBroadcastLike(Operation *op) {
 ///
 /// Example:
 ///
-///        broadcast           extract
-/// (3, 4) --------> (2, 3, 4) ------> (4)
+///        broadcast             extract [1][2]
+/// (3, 4) --------> (2, 3, 4) ----------------> (4)
 ///
 /// becomes
-///                  extract
-/// (3,4) ---------------------------> (4)
+///                  extract [1]
+/// (3,4) -------------------------------------> (4)
 ///
 ///
-/// The variable names used in this implementation use names which correspond to
-/// the above shapes as,
+/// The variable names used in this implementation correspond to the above
+/// shapes as,
 ///
 /// - (3, 4) is `input` shape.
 /// - (2, 3, 4) is `broadcast` shape.
@@ -1775,14 +1773,15 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
   if (extractRank > inputRank)
     return Value();
 
-  // Proof by contradiction that, at this point, input is a vector.
-  //     Suppose input is a scalar.
-  // ==> inputRank is 0.
-  // ==> extractRank is 0 (because extractRank <= inputRank).
-  // ==> extract is scalar (because rank-0 extraction is always scalar).
-  // ==> input and extract are scalar, so same type.
-  // ==> returned early (check same type).
-  //     Contradiction!
+  // The above condition guarantees that input is a vector:
+  //
+  // If input is a scalar:
+  // 1) inputRank is 0, so
+  // 2) extractRank is 0 (because extractRank <= inputRank), so
+  // 3) extract is scalar (because rank-0 extraction is always scalar), s0
+  // 4) input and extract are scalar, so same type.
+  // But then we should have returned earlier when the types were compared for
+  // equivalence. So input is not a scalar at this point.
   assert(inputType && "input must be a vector type because of previous checks");
   ArrayRef<int64_t> inputShape = inputType.getShape();
 

>From 8c85bc7a0959c9cb67819e6251e0edb230ef2c05 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Fri, 18 Jul 2025 09:40:07 -0700
Subject: [PATCH 4/6] remove lengthy explanation

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 12 ++----------
 1 file changed, 2 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 31dba8781745f..7723665926295 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1718,7 +1718,7 @@ static bool isBroadcastLike(Operation *op) {
     return false;
 
   // Check that shape_cast **only** prepends 1s, like (2,3) -> (1,1,2,3).
-  // Note that checking that dst shape has a prefix of 1s is not sufficient,
+  // Checking that the destination shape has a prefix of 1s is not sufficient,
   // for example (2,3) -> (1,3,2) is not broadcast-like. A sufficient condition
   // is that the source shape is a suffix of the destination shape.
   VectorType srcType = shapeCast.getSourceVectorType();
@@ -1773,15 +1773,7 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
   if (extractRank > inputRank)
     return Value();
 
-  // The above condition guarantees that input is a vector:
-  //
-  // If input is a scalar:
-  // 1) inputRank is 0, so
-  // 2) extractRank is 0 (because extractRank <= inputRank), so
-  // 3) extract is scalar (because rank-0 extraction is always scalar), s0
-  // 4) input and extract are scalar, so same type.
-  // But then we should have returned earlier when the types were compared for
-  // equivalence. So input is not a scalar at this point.
+  // The above condition guarantees that input is a vector.
   assert(inputType && "input must be a vector type because of previous checks");
   ArrayRef<int64_t> inputShape = inputType.getShape();
 

>From cb306132fa2e1f2ae46978a269358627d64966f3 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Fri, 18 Jul 2025 09:42:18 -0700
Subject: [PATCH 5/6] broadcastlike vs broadcast-like

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 7723665926295..01eedceafb275 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1719,7 +1719,7 @@ static bool isBroadcastLike(Operation *op) {
 
   // Check that shape_cast **only** prepends 1s, like (2,3) -> (1,1,2,3).
   // Checking that the destination shape has a prefix of 1s is not sufficient,
-  // for example (2,3) -> (1,3,2) is not broadcast-like. A sufficient condition
+  // for example (2,3) -> (1,3,2) is not broadcastlike. A sufficient condition
   // is that the source shape is a suffix of the destination shape.
   VectorType srcType = shapeCast.getSourceVectorType();
   ArrayRef<int64_t> srcShape = srcType.getShape();

>From 15f78d32c09644d28b453e4d1e3dab52fba4ddd9 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Fri, 18 Jul 2025 10:12:45 -0700
Subject: [PATCH 6/6] test simplification

---
 mlir/test/Conversion/VectorToSCF/funk.mlir    | 755 ++++++++++++++++++
 .../Conversion/VectorToSCF/vector-to-scf.mlir |   3 +-
 2 files changed, 756 insertions(+), 2 deletions(-)
 create mode 100644 mlir/test/Conversion/VectorToSCF/funk.mlir

diff --git a/mlir/test/Conversion/VectorToSCF/funk.mlir b/mlir/test/Conversion/VectorToSCF/funk.mlir
new file mode 100644
index 0000000000000..556814cd04792
--- /dev/null
+++ b/mlir/test/Conversion/VectorToSCF/funk.mlir
@@ -0,0 +1,755 @@
+module {
+  func.func @vector_transfer_ops_0d(%arg0: memref<f32>) {
+    %cst = arith.constant 0.000000e+00 : f32
+    %0 = vector.transfer_read %arg0[], %cst : memref<f32>, vector<f32>
+    vector.transfer_write %0, %arg0[] : vector<f32>, memref<f32>
+    return
+  }
+}
+
+// -----
+#map = affine_map<(d0, d1) -> (d0 + d1)>
+#map1 = affine_map<(d0) -> (d0 + 1)>
+#map2 = affine_map<(d0) -> (d0 + 2)>
+#map3 = affine_map<(d0) -> (d0 + 3)>
+module {
+  func.func @materialize_read_1d() {
+    %c7 = arith.constant 7 : index
+    %cst = arith.constant dense<0.000000e+00> : vector<4xf32>
+    %c1 = arith.constant 1 : index
+    %c4 = arith.constant 4 : index
+    %c0 = arith.constant 0 : index
+    %alloc = memref.alloc() : memref<7x42xf32>
+    affine.for %arg0 = 0 to 7 step 4 {
+      affine.for %arg1 = 0 to 42 step 4 {
+        %0 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %cst) -> (vector<4xf32>) {
+          %7 = affine.apply #map(%arg0, %arg2)
+          %8 = affine.apply #map(%arg0, %arg2)
+          %9 = arith.cmpi slt, %8, %c7 : index
+          %10 = scf.if %9 -> (vector<4xf32>) {
+            %11 = memref.load %alloc[%7, %arg1] : memref<7x42xf32>
+            %12 = vector.insert %11, %arg3 [%arg2] : f32 into vector<4xf32>
+            scf.yield %12 : vector<4xf32>
+          } else {
+            scf.yield %arg3 : vector<4xf32>
+          }
+          scf.yield %10 : vector<4xf32>
+        }
+        %1 = affine.apply #map1(%arg1)
+        %2 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %cst) -> (vector<4xf32>) {
+          %7 = affine.apply #map(%arg0, %arg2)
+          %8 = affine.apply #map(%arg0, %arg2)
+          %9 = arith.cmpi slt, %8, %c7 : index
+          %10 = scf.if %9 -> (vector<4xf32>) {
+            %11 = memref.load %alloc[%7, %1] : memref<7x42xf32>
+            %12 = vector.insert %11, %arg3 [%arg2] : f32 into vector<4xf32>
+            scf.yield %12 : vector<4xf32>
+          } else {
+            scf.yield %arg3 : vector<4xf32>
+          }
+          scf.yield %10 : vector<4xf32>
+        }
+        %3 = affine.apply #map2(%arg1)
+        %4 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %cst) -> (vector<4xf32>) {
+          %7 = affine.apply #map(%arg0, %arg2)
+          %8 = affine.apply #map(%arg0, %arg2)
+          %9 = arith.cmpi slt, %8, %c7 : index
+          %10 = scf.if %9 -> (vector<4xf32>) {
+            %11 = memref.load %alloc[%7, %3] : memref<7x42xf32>
+            %12 = vector.insert %11, %arg3 [%arg2] : f32 into vector<4xf32>
+            scf.yield %12 : vector<4xf32>
+          } else {
+            scf.yield %arg3 : vector<4xf32>
+          }
+          scf.yield %10 : vector<4xf32>
+        }
+        %5 = affine.apply #map3(%arg1)
+        %6 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %cst) -> (vector<4xf32>) {
+          %7 = affine.apply #map(%arg0, %arg2)
+          %8 = affine.apply #map(%arg0, %arg2)
+          %9 = arith.cmpi slt, %8, %c7 : index
+          %10 = scf.if %9 -> (vector<4xf32>) {
+            %11 = memref.load %alloc[%7, %5] : memref<7x42xf32>
+            %12 = vector.insert %11, %arg3 [%arg2] : f32 into vector<4xf32>
+            scf.yield %12 : vector<4xf32>
+          } else {
+            scf.yield %arg3 : vector<4xf32>
+          }
+          scf.yield %10 : vector<4xf32>
+        }
+        "dummy_use"(%0, %2, %4, %6) : (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>) -> ()
+      }
+    }
+    return
+  }
+}
+
+// -----
+#map = affine_map<(d0, d1) -> (d0 + d1)>
+#map1 = affine_map<(d0, d1) -> (d0 + d1 + 1)>
+module {
+  func.func @materialize_read_1d_partially_specialized(%arg0: index, %arg1: index, %arg2: index) {
+    %c42 = arith.constant 42 : index
+    %cst = arith.constant dense<0.000000e+00> : vector<4xf32>
+    %c1 = arith.constant 1 : index
+    %c4 = arith.constant 4 : index
+    %c0 = arith.constant 0 : index
+    %alloc = memref.alloc(%arg0, %arg1, %arg2) : memref<7x?x?x42x?xf32>
+    affine.for %arg3 = 0 to 7 {
+      affine.for %arg4 = 0 to %arg0 {
+        affine.for %arg5 = 0 to %arg1 {
+          affine.for %arg6 = 0 to 42 step 2 {
+            affine.for %arg7 = 0 to %arg2 {
+              %0 = scf.for %arg8 = %c0 to %c4 step %c1 iter_args(%arg9 = %cst) -> (vector<4xf32>) {
+                %2 = affine.apply #map(%arg6, %arg8)
+                %3 = affine.apply #map(%arg6, %arg8)
+                %4 = arith.cmpi slt, %3, %c42 : index
+                %5 = scf.if %4 -> (vector<4xf32>) {
+                  %6 = memref.load %alloc[%arg3, %arg4, %arg5, %2, %arg7] : memref<7x?x?x42x?xf32>
+                  %7 = vector.insert %6, %arg9 [%arg8] : f32 into vector<4xf32>
+                  scf.yield %7 : vector<4xf32>
+                } else {
+                  scf.yield %arg9 : vector<4xf32>
+                }
+                scf.yield %5 : vector<4xf32>
+              }
+              %1 = scf.for %arg8 = %c0 to %c4 step %c1 iter_args(%arg9 = %cst) -> (vector<4xf32>) {
+                %2 = affine.apply #map1(%arg8, %arg6)
+                %3 = affine.apply #map1(%arg8, %arg6)
+                %4 = arith.cmpi slt, %3, %c42 : index
+                %5 = scf.if %4 -> (vector<4xf32>) {
+                  %6 = memref.load %alloc[%arg3, %arg4, %arg5, %2, %arg7] : memref<7x?x?x42x?xf32>
+                  %7 = vector.insert %6, %arg9 [%arg8] : f32 into vector<4xf32>
+                  scf.yield %7 : vector<4xf32>
+                } else {
+                  scf.yield %arg9 : vector<4xf32>
+                }
+                scf.yield %5 : vector<4xf32>
+              }
+              "dummy_use"(%0, %1) : (vector<4xf32>, vector<4xf32>) -> ()
+            }
+          }
+        }
+      }
+    }
+    return
+  }
+}
+
+// -----
+#map = affine_map<(d0, d1) -> (d0 + d1)>
+module {
+  func.func @materialize_read(%arg0: index, %arg1: index, %arg2: index, %arg3: index) {
+    %cst = arith.constant dense<0.000000e+00> : vector<3xf32>
+    %c3 = arith.constant 3 : index
+    %cst_0 = arith.constant dense<0.000000e+00> : vector<4x3xf32>
+    %c4 = arith.constant 4 : index
+    %c1 = arith.constant 1 : index
+    %c5 = arith.constant 5 : index
+    %c0 = arith.constant 0 : index
+    %alloc = memref.alloc(%arg0, %arg1, %arg2, %arg3) : memref<?x?x?x?xf32>
+    affine.for %arg4 = 0 to %arg0 step 3 {
+      affine.for %arg5 = 0 to %arg1 {
+        affine.for %arg6 = 0 to %arg2 {
+          affine.for %arg7 = 0 to %arg3 step 5 {
+            %alloca = memref.alloca() : memref<vector<5x4x3xf32>>
+            %0 = vector.type_cast %alloca : memref<vector<5x4x3xf32>> to memref<5xvector<4x3xf32>>
+            scf.for %arg8 = %c0 to %c5 step %c1 {
+              %2 = affine.apply #map(%arg7, %arg8)
+              %3 = arith.cmpi sgt, %arg3, %2 : index
+              scf.if %3 {
+                %4 = affine.apply #map(%arg7, %arg8)
+                %5 = vector.type_cast %0 : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
+                scf.for %arg9 = %c0 to %c4 step %c1 {
+                  %6 = scf.for %arg10 = %c0 to %c3 step %c1 iter_args(%arg11 = %cst) -> (vector<3xf32>) {
+                    %7 = affine.apply #map(%arg4, %arg10)
+                    %8 = affine.apply #map(%arg4, %arg10)
+                    %9 = arith.cmpi sgt, %arg0, %8 : index
+                    %10 = scf.if %9 -> (vector<3xf32>) {
+                      %11 = memref.load %alloc[%7, %arg5, %arg6, %4] : memref<?x?x?x?xf32>
+                      %12 = vector.insert %11, %arg11 [%arg10] : f32 into vector<3xf32>
+                      scf.yield %12 : vector<3xf32>
+                    } else {
+                      scf.yield %arg11 : vector<3xf32>
+                    }
+                    scf.yield %10 : vector<3xf32>
+                  }
+                  memref.store %6, %5[%arg8, %arg9] : memref<5x4xvector<3xf32>>
+                }
+              } else {
+                memref.store %cst_0, %0[%arg8] : memref<5xvector<4x3xf32>>
+              }
+            }
+            %1 = memref.load %alloca[] : memref<vector<5x4x3xf32>>
+            "dummy_use"(%1) : (vector<5x4x3xf32>) -> ()
+          }
+        }
+      }
+    }
+    return
+  }
+}
+
+// -----
+#map = affine_map<(d0, d1) -> (d0 + d1)>
+module {
+  func.func @materialize_write(%arg0: index, %arg1: index, %arg2: index, %arg3: index) {
+    %c4 = arith.constant 4 : index
+    %c1 = arith.constant 1 : index
+    %c3 = arith.constant 3 : index
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant dense<1.000000e+00> : vector<3x4x1x5xf32>
+    %alloc = memref.alloc(%arg0, %arg1, %arg2, %arg3) : memref<?x?x?x?xf32>
+    affine.for %arg4 = 0 to %arg0 step 3 {
+      affine.for %arg5 = 0 to %arg1 step 4 {
+        affine.for %arg6 = 0 to %arg2 {
+          affine.for %arg7 = 0 to %arg3 step 5 {
+            %alloca = memref.alloca() : memref<vector<3x4x1x5xf32>>
+            memref.store %cst, %alloca[] : memref<vector<3x4x1x5xf32>>
+            %0 = vector.type_cast %alloca : memref<vector<3x4x1x5xf32>> to memref<3xvector<4x1x5xf32>>
+            scf.for %arg8 = %c0 to %c3 step %c1 {
+              %1 = affine.apply #map(%arg4, %arg8)
+              %2 = arith.cmpi sgt, %arg0, %1 : index
+              scf.if %2 {
+                %3 = affine.apply #map(%arg4, %arg8)
+                %4 = vector.type_cast %0 : memref<3xvector<4x1x5xf32>> to memref<3x4xvector<1x5xf32>>
+                scf.for %arg9 = %c0 to %c4 step %c1 {
+                  %5 = affine.apply #map(%arg5, %arg9)
+                  %6 = arith.cmpi sgt, %arg1, %5 : index
+                  scf.if %6 {
+                    %7 = affine.apply #map(%arg5, %arg9)
+                    %8 = vector.type_cast %4 : memref<3x4xvector<1x5xf32>> to memref<3x4x1xvector<5xf32>>
+                    scf.for %arg10 = %c0 to %c1 step %c1 {
+                      %9 = affine.apply #map(%arg6, %arg10)
+                      %10 = memref.load %8[%arg8, %arg9, %arg10] : memref<3x4x1xvector<5xf32>>
+                      vector.transfer_write %10, %alloc[%3, %7, %9, %arg7] : vector<5xf32>, memref<?x?x?x?xf32>
+                    }
+                  } else {
+                  }
+                }
+              } else {
+              }
+            }
+          }
+        }
+      }
+    }
+    return
+  }
+}
+
+// -----
+#map = affine_map<(d0)[s0] -> (d0 + s0)>
+module {
+  func.func @transfer_read_progressive(%arg0: memref<?x?xf32>, %arg1: index) -> vector<3x15xf32> {
+    %cst = arith.constant dense<7.000000e+00> : vector<15xf32>
+    %c1 = arith.constant 1 : index
+    %c3 = arith.constant 3 : index
+    %c0 = arith.constant 0 : index
+    %cst_0 = arith.constant 7.000000e+00 : f32
+    %alloca = memref.alloca() : memref<vector<3x15xf32>>
+    %0 = vector.type_cast %alloca : memref<vector<3x15xf32>> to memref<3xvector<15xf32>>
+    scf.for %arg2 = %c0 to %c3 step %c1 {
+      %dim = memref.dim %arg0, %c0 : memref<?x?xf32>
+      %2 = affine.apply #map(%arg2)[%arg1]
+      %3 = arith.cmpi sgt, %dim, %2 : index
+      scf.if %3 {
+        %4 = affine.apply #map(%arg2)[%arg1]
+        %5 = vector.transfer_read %arg0[%4, %arg1], %cst_0 : memref<?x?xf32>, vector<15xf32>
+        memref.store %5, %0[%arg2] : memref<3xvector<15xf32>>
+      } else {
+        memref.store %cst, %0[%arg2] : memref<3xvector<15xf32>>
+      }
+    }
+    %1 = memref.load %alloca[] : memref<vector<3x15xf32>>
+    return %1 : vector<3x15xf32>
+  }
+}
+
+// -----
+#map = affine_map<(d0)[s0] -> (d0 + s0)>
+module {
+  func.func @transfer_write_progressive(%arg0: memref<?x?xf32>, %arg1: index, %arg2: vector<3x15xf32>) {
+    %c1 = arith.constant 1 : index
+    %c3 = arith.constant 3 : index
+    %c0 = arith.constant 0 : index
+    %alloca = memref.alloca() : memref<vector<3x15xf32>>
+    memref.store %arg2, %alloca[] : memref<vector<3x15xf32>>
+    %0 = vector.type_cast %alloca : memref<vector<3x15xf32>> to memref<3xvector<15xf32>>
+    scf.for %arg3 = %c0 to %c3 step %c1 {
+      %dim = memref.dim %arg0, %c0 : memref<?x?xf32>
+      %1 = affine.apply #map(%arg3)[%arg1]
+      %2 = arith.cmpi sgt, %dim, %1 : index
+      scf.if %2 {
+        %3 = affine.apply #map(%arg3)[%arg1]
+        %4 = memref.load %0[%arg3] : memref<3xvector<15xf32>>
+        vector.transfer_write %4, %arg0[%3, %arg1] : vector<15xf32>, memref<?x?xf32>
+      } else {
+      }
+    }
+    return
+  }
+}
+
+// -----
+#map = affine_map<(d0)[s0] -> (d0 + s0)>
+module {
+  func.func @transfer_write_progressive_inbounds(%arg0: memref<?x?xf32>, %arg1: index, %arg2: vector<3x15xf32>) {
+    %c1 = arith.constant 1 : index
+    %c3 = arith.constant 3 : index
+    %c0 = arith.constant 0 : index
+    %alloca = memref.alloca() : memref<vector<3x15xf32>>
+    memref.store %arg2, %alloca[] : memref<vector<3x15xf32>>
+    %0 = vector.type_cast %alloca : memref<vector<3x15xf32>> to memref<3xvector<15xf32>>
+    scf.for %arg3 = %c0 to %c3 step %c1 {
+      %1 = affine.apply #map(%arg3)[%arg1]
+      %2 = memref.load %0[%arg3] : memref<3xvector<15xf32>>
+      vector.transfer_write %2, %arg0[%1, %arg1] {in_bounds = [true]} : vector<15xf32>, memref<?x?xf32>
+    }
+    return
+  }
+}
+
+// -----
+module {
+  func.func @transfer_read_simple(%arg0: memref<2x2xf32>) -> vector<2x2xf32> {
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %cst = arith.constant 0.000000e+00 : f32
+    %c0 = arith.constant 0 : index
+    %alloca = memref.alloca() : memref<vector<2x2xf32>>
+    %0 = vector.type_cast %alloca : memref<vector<2x2xf32>> to memref<2xvector<2xf32>>
+    scf.for %arg1 = %c0 to %c2 step %c1 {
+      %2 = vector.transfer_read %arg0[%arg1, %c0], %cst {in_bounds = [true]} : memref<2x2xf32>, vector<2xf32>
+      memref.store %2, %0[%arg1] : memref<2xvector<2xf32>>
+    }
+    %1 = memref.load %alloca[] : memref<vector<2x2xf32>>
+    return %1 : vector<2x2xf32>
+  }
+  func.func @transfer_read_minor_identity(%arg0: memref<?x?x?x?xf32>) -> vector<3x3xf32> {
+    %cst = arith.constant dense<0.000000e+00> : vector<3xf32>
+    %c2 = arith.constant 2 : index
+    %c1 = arith.constant 1 : index
+    %c3 = arith.constant 3 : index
+    %cst_0 = arith.constant 0.000000e+00 : f32
+    %c0 = arith.constant 0 : index
+    %alloca = memref.alloca() : memref<vector<3x3xf32>>
+    %0 = vector.type_cast %alloca : memref<vector<3x3xf32>> to memref<3xvector<3xf32>>
+    scf.for %arg1 = %c0 to %c3 step %c1 {
+      %dim = memref.dim %arg0, %c2 : memref<?x?x?x?xf32>
+      %2 = arith.cmpi sgt, %dim, %arg1 : index
+      scf.if %2 {
+        %3 = vector.transfer_read %arg0[%c0, %c0, %arg1, %c0], %cst_0 : memref<?x?x?x?xf32>, vector<3xf32>
+        memref.store %3, %0[%arg1] : memref<3xvector<3xf32>>
+      } else {
+        memref.store %cst, %0[%arg1] : memref<3xvector<3xf32>>
+      }
+    }
+    %1 = memref.load %alloca[] : memref<vector<3x3xf32>>
+    return %1 : vector<3x3xf32>
+  }
+  func.func @transfer_write_minor_identity(%arg0: vector<3x3xf32>, %arg1: memref<?x?x?x?xf32>) {
+    %c2 = arith.constant 2 : index
+    %c1 = arith.constant 1 : index
+    %c3 = arith.constant 3 : index
+    %c0 = arith.constant 0 : index
+    %alloca = memref.alloca() : memref<vector<3x3xf32>>
+    memref.store %arg0, %alloca[] : memref<vector<3x3xf32>>
+    %0 = vector.type_cast %alloca : memref<vector<3x3xf32>> to memref<3xvector<3xf32>>
+    scf.for %arg2 = %c0 to %c3 step %c1 {
+      %dim = memref.dim %arg1, %c2 : memref<?x?x?x?xf32>
+      %1 = arith.cmpi sgt, %dim, %arg2 : index
+      scf.if %1 {
+        %2 = memref.load %0[%arg2] : memref<3xvector<3xf32>>
+        vector.transfer_write %2, %arg1[%c0, %c0, %arg2, %c0] : vector<3xf32>, memref<?x?x?x?xf32>
+      } else {
+      }
+    }
+    return
+  }
+}
+
+// -----
+#map = affine_map<(d0, d1) -> (d0 + d1 * 8)>
+module {
+  func.func @transfer_read_strided(%arg0: memref<8x4xf32, #map>) -> vector<4xf32> {
+    %c0 = arith.constant 0 : index
+    %c4 = arith.constant 4 : index
+    %c1 = arith.constant 1 : index
+    %cst = arith.constant dense<0.000000e+00> : vector<4xf32>
+    %0 = scf.for %arg1 = %c0 to %c4 step %c1 iter_args(%arg2 = %cst) -> (vector<4xf32>) {
+      %1 = memref.load %arg0[%c0, %arg1] : memref<8x4xf32, #map>
+      %2 = vector.insert %1, %arg2 [%arg1] : f32 into vector<4xf32>
+      scf.yield %2 : vector<4xf32>
+    }
+    return %0 : vector<4xf32>
+  }
+  func.func @transfer_write_strided(%arg0: vector<4xf32>, %arg1: memref<8x4xf32, #map>) {
+    %c0 = arith.constant 0 : index
+    %c4 = arith.constant 4 : index
+    %c1 = arith.constant 1 : index
+    scf.for %arg2 = %c0 to %c4 step %c1 {
+      %0 = vector.extract %arg0[%arg2] : f32 from vector<4xf32>
+      memref.store %0, %arg1[%c0, %arg2] : memref<8x4xf32, #map>
+    }
+    return
+  }
+}
+
+// -----
+module {
+  func.func private @fake_side_effecting_fun(vector<2x2xf32>)
+  func.func @transfer_read_within_async_execute(%arg0: memref<2x2xf32>) -> !async.token {
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0.000000e+00 : f32
+    %token = async.execute {
+      %alloca = memref.alloca() : memref<vector<2x2xf32>>
+      %0 = vector.type_cast %alloca : memref<vector<2x2xf32>> to memref<2xvector<2xf32>>
+      scf.for %arg1 = %c0 to %c2 step %c1 {
+        %2 = vector.transfer_read %arg0[%arg1, %c0], %cst {in_bounds = [true]} : memref<2x2xf32>, vector<2xf32>
+        memref.store %2, %0[%arg1] : memref<2xvector<2xf32>>
+      }
+      %1 = memref.load %alloca[] : memref<vector<2x2xf32>>
+      func.call @fake_side_effecting_fun(%1) : (vector<2x2xf32>) -> ()
+      async.yield
+    }
+    return %token : !async.token
+  }
+}
+
+// -----
+module {
+  func.func @transfer_read_with_tensor(%arg0: tensor<f32>) -> vector<1xf32> {
+    %cst = arith.constant 0.000000e+00 : f32
+    %0 = vector.transfer_read %arg0[], %cst : tensor<f32>, vector<f32>
+    %1 = vector.broadcast %0 : vector<f32> to vector<1xf32>
+    return %1 : vector<1xf32>
+  }
+}
+
+// -----
+module {
+  func.func @transfer_write_scalable(%arg0: memref<?xf32, strided<[?], offset: ?>>, %arg1: f32) {
+    %c1 = arith.constant 1 : index
+    %c16 = arith.constant 16 : index
+    %0 = llvm.mlir.undef : vector<[16]xf32>
+    %1 = llvm.mlir.undef : vector<[16]xi32>
+    %2 = llvm.mlir.constant(0 : i32) : i32
+    %c0 = arith.constant 0 : index
+    %dim = memref.dim %arg0, %c0 : memref<?xf32, strided<[?], offset: ?>>
+    %3 = llvm.intr.stepvector : vector<[16]xi32>
+    %4 = arith.index_cast %dim : index to i32
+    %5 = llvm.insertelement %4, %1[%2 : i32] : vector<[16]xi32>
+    %6 = llvm.shufflevector %5, %1 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<[16]xi32> 
+    %7 = arith.cmpi slt, %3, %6 : vector<[16]xi32>
+    %8 = llvm.insertelement %arg1, %0[%2 : i32] : vector<[16]xf32>
+    %9 = llvm.shufflevector %8, %0 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<[16]xf32> 
+    %vscale = vector.vscale
+    %c16_vscale = arith.muli %vscale, %c16 : index
+    scf.for %arg2 = %c0 to %c16_vscale step %c1 {
+      %10 = vector.extract %7[%arg2] : i1 from vector<[16]xi1>
+      scf.if %10 {
+        %11 = vector.extract %9[%arg2] : f32 from vector<[16]xf32>
+        memref.store %11, %arg0[%arg2] : memref<?xf32, strided<[?], offset: ?>>
+      } else {
+      }
+    }
+    return
+  }
+}
+
+// -----
+module {
+  func.func @vector_print_vector_0d(%arg0: vector<f32>) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    vector.print punctuation <open>
+    scf.for %arg1 = %c0 to %c1 step %c1 {
+      %0 = vector.extract %arg0[] : f32 from vector<f32>
+      vector.print %0 : f32 punctuation <no_punctuation>
+      %1 = arith.cmpi ult, %arg1, %c0 : index
+      scf.if %1 {
+        vector.print punctuation <comma>
+      }
+    }
+    vector.print punctuation <close>
+    vector.print
+    return
+  }
+}
+
+// -----
+module {
+  func.func @vector_print_vector(%arg0: vector<2x2xf32>) {
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %c0 = arith.constant 0 : index
+    %0 = vector.shape_cast %arg0 : vector<2x2xf32> to vector<4xf32>
+    vector.print punctuation <open>
+    scf.for %arg1 = %c0 to %c2 step %c1 {
+      vector.print punctuation <open>
+      scf.for %arg2 = %c0 to %c2 step %c1 {
+        %2 = arith.muli %arg1, %c2 : index
+        %3 = arith.addi %arg2, %2 : index
+        %4 = vector.extract %0[%3] : f32 from vector<4xf32>
+        vector.print %4 : f32 punctuation <no_punctuation>
+        %5 = arith.cmpi ult, %arg2, %c1 : index
+        scf.if %5 {
+          vector.print punctuation <comma>
+        }
+      }
+      vector.print punctuation <close>
+      %1 = arith.cmpi ult, %arg1, %c1 : index
+      scf.if %1 {
+        vector.print punctuation <comma>
+      }
+    }
+    vector.print punctuation <close>
+    vector.print
+    return
+  }
+}
+
+// -----
+module {
+  func.func @vector_print_scalable_vector(%arg0: vector<[4]xi32>) {
+    %c0 = arith.constant 0 : index
+    %c4 = arith.constant 4 : index
+    %c1 = arith.constant 1 : index
+    %vscale = vector.vscale
+    %c4_vscale = arith.muli %vscale, %c4 : index
+    %0 = arith.subi %c4_vscale, %c1 : index
+    vector.print punctuation <open>
+    scf.for %arg1 = %c0 to %c4_vscale step %c1 {
+      %1 = vector.extract %arg0[%arg1] : i32 from vector<[4]xi32>
+      vector.print %1 : i32 punctuation <no_punctuation>
+      %2 = arith.cmpi ult, %arg1, %0 : index
+      scf.if %2 {
+        vector.print punctuation <comma>
+      }
+    }
+    vector.print punctuation <close>
+    vector.print
+    return
+  }
+}
+
+// -----
+module {
+  func.func @transfer_read_array_of_scalable(%arg0: memref<3x?xf32>) -> vector<3x[4]xf32> {
+    %c3 = arith.constant 3 : index
+    %cst = arith.constant 0.000000e+00 : f32
+    %c1 = arith.constant 1 : index
+    %c0 = arith.constant 0 : index
+    %alloca = memref.alloca() : memref<vector<3x[4]xf32>>
+    %alloca_0 = memref.alloca() : memref<vector<3x[4]xi1>>
+    %dim = memref.dim %arg0, %c1 : memref<3x?xf32>
+    %0 = vector.create_mask %c1, %dim : vector<3x[4]xi1>
+    memref.store %0, %alloca_0[] : memref<vector<3x[4]xi1>>
+    %1 = vector.type_cast %alloca : memref<vector<3x[4]xf32>> to memref<3xvector<[4]xf32>>
+    %2 = vector.type_cast %alloca_0 : memref<vector<3x[4]xi1>> to memref<3xvector<[4]xi1>>
+    scf.for %arg1 = %c0 to %c3 step %c1 {
+      %4 = memref.load %2[%arg1] : memref<3xvector<[4]xi1>>
+      %5 = vector.transfer_read %arg0[%arg1, %c0], %cst, %4 {in_bounds = [true]} : memref<3x?xf32>, vector<[4]xf32>
+      memref.store %5, %1[%arg1] : memref<3xvector<[4]xf32>>
+    }
+    %3 = memref.load %alloca[] : memref<vector<3x[4]xf32>>
+    return %3 : vector<3x[4]xf32>
+  }
+}
+
+// -----
+module {
+  func.func @transfer_write_array_of_scalable(%arg0: vector<3x[4]xf32>, %arg1: memref<3x?xf32>) {
+    %c3 = arith.constant 3 : index
+    %c1 = arith.constant 1 : index
+    %c0 = arith.constant 0 : index
+    %alloca = memref.alloca() : memref<vector<3x[4]xf32>>
+    %alloca_0 = memref.alloca() : memref<vector<3x[4]xi1>>
+    %dim = memref.dim %arg1, %c1 : memref<3x?xf32>
+    %0 = vector.create_mask %c1, %dim : vector<3x[4]xi1>
+    memref.store %0, %alloca_0[] : memref<vector<3x[4]xi1>>
+    memref.store %arg0, %alloca[] : memref<vector<3x[4]xf32>>
+    %1 = vector.type_cast %alloca : memref<vector<3x[4]xf32>> to memref<3xvector<[4]xf32>>
+    %2 = vector.type_cast %alloca_0 : memref<vector<3x[4]xi1>> to memref<3xvector<[4]xi1>>
+    scf.for %arg2 = %c0 to %c3 step %c1 {
+      %3 = memref.load %1[%arg2] : memref<3xvector<[4]xf32>>
+      %4 = memref.load %2[%arg2] : memref<3xvector<[4]xi1>>
+      vector.transfer_write %3, %arg1[%arg2, %c0], %4 {in_bounds = [true]} : vector<[4]xf32>, memref<3x?xf32>
+    }
+    return
+  }
+}
+
+// -----
+module {
+  func.func @cannot_lower_transfer_write_with_leading_scalable(%arg0: vector<[4]x4xf32>, %arg1: memref<?x4xf32>) {
+    %c0 = arith.constant 0 : index
+    %c4 = arith.constant 4 : index
+    %dim = memref.dim %arg1, %c0 : memref<?x4xf32>
+    %0 = vector.create_mask %dim, %c4 : vector<[4]x4xi1>
+    vector.transfer_write %arg0, %arg1[%c0, %c0], %0 {in_bounds = [true, true]} : vector<[4]x4xf32>, memref<?x4xf32>
+    return
+  }
+}
+
+// -----
+module {
+  func.func @cannot_lower_transfer_read_with_leading_scalable(%arg0: memref<?x4xf32>) -> vector<[4]x4xf32> {
+    %c0 = arith.constant 0 : index
+    %c4 = arith.constant 4 : index
+    %cst = arith.constant 0.000000e+00 : f32
+    %dim = memref.dim %arg0, %c0 : memref<?x4xf32>
+    %0 = vector.create_mask %dim, %c4 : vector<[4]x4xi1>
+    %1 = vector.transfer_read %arg0[%c0, %c0], %cst, %0 {in_bounds = [true, true]} : memref<?x4xf32>, vector<[4]x4xf32>
+    return %1 : vector<[4]x4xf32>
+  }
+  func.func @does_not_crash_on_unpack_one_dim(%arg0: memref<1x1x1x1xi32>, %arg1: vector<1x1xi1>) -> vector<1x1x1x1xi32> {
+    %c1 = arith.constant 1 : index
+    %c0_i32 = arith.constant 0 : i32
+    %c0 = arith.constant 0 : index
+    %alloca = memref.alloca() : memref<vector<1x1x1x1xi32>>
+    %alloca_0 = memref.alloca() : memref<vector<1x1xi1>>
+    memref.store %arg1, %alloca_0[] : memref<vector<1x1xi1>>
+    %0 = vector.type_cast %alloca : memref<vector<1x1x1x1xi32>> to memref<1xvector<1x1x1xi32>>
+    %1 = vector.type_cast %alloca_0 : memref<vector<1x1xi1>> to memref<1xvector<1xi1>>
+    scf.for %arg2 = %c0 to %c1 step %c1 {
+      %3 = vector.type_cast %0 : memref<1xvector<1x1x1xi32>> to memref<1x1xvector<1x1xi32>>
+      scf.for %arg3 = %c0 to %c1 step %c1 {
+        %4 = vector.type_cast %3 : memref<1x1xvector<1x1xi32>> to memref<1x1x1xvector<1xi32>>
+        scf.for %arg4 = %c0 to %c1 step %c1 {
+          %5 = memref.load %1[%arg2] : memref<1xvector<1xi1>>
+          %6 = vector.transfer_read %arg0[%arg2, %c0, %c0, %c0], %c0_i32, %5 {in_bounds = [true]} : memref<1x1x1x1xi32>, vector<1xi32>
+          memref.store %6, %4[%arg2, %arg3, %arg4] : memref<1x1x1xvector<1xi32>>
+        }
+      }
+    }
+    %2 = memref.load %alloca[] : memref<vector<1x1x1x1xi32>>
+    return %2 : vector<1x1x1x1xi32>
+  }
+  func.func @add_arrays_of_scalable_vectors(%arg0: memref<1x2x?xf32>, %arg1: memref<1x2x?xf32>) -> vector<1x2x[4]xf32> {
+    %c1 = arith.constant 1 : index
+    %cst = arith.constant 0.000000e+00 : f32
+    %c2 = arith.constant 2 : index
+    %c0 = arith.constant 0 : index
+    %alloca = memref.alloca() : memref<vector<1x2x[4]xf32>>
+    %alloca_0 = memref.alloca() : memref<vector<1x2x[4]xi1>>
+    %dim = memref.dim %arg0, %c2 : memref<1x2x?xf32>
+    %0 = vector.create_mask %c2, %c2, %dim : vector<1x2x[4]xi1>
+    memref.store %0, %alloca_0[] : memref<vector<1x2x[4]xi1>>
+    %1 = vector.type_cast %alloca : memref<vector<1x2x[4]xf32>> to memref<1xvector<2x[4]xf32>>
+    %2 = vector.type_cast %alloca_0 : memref<vector<1x2x[4]xi1>> to memref<1xvector<2x[4]xi1>>
+    scf.for %arg2 = %c0 to %c1 step %c1 {
+      %4 = vector.type_cast %1 : memref<1xvector<2x[4]xf32>> to memref<1x2xvector<[4]xf32>>
+      %5 = vector.type_cast %2 : memref<1xvector<2x[4]xi1>> to memref<1x2xvector<[4]xi1>>
+      scf.for %arg3 = %c0 to %c2 step %c1 {
+        %6 = memref.load %5[%arg2, %arg3] : memref<1x2xvector<[4]xi1>>
+        %7 = vector.transfer_read %arg0[%arg2, %arg3, %c0], %cst, %6 {in_bounds = [true]} : memref<1x2x?xf32>, vector<[4]xf32>
+        memref.store %7, %4[%arg2, %arg3] : memref<1x2xvector<[4]xf32>>
+      }
+    }
+    %3 = memref.load %alloca[] : memref<vector<1x2x[4]xf32>>
+    return %3 : vector<1x2x[4]xf32>
+  }
+  func.func @cannot_fully_unroll_transfer_write_of_nd_scalable_vector(%arg0: vector<[4]x[4]xf32>, %arg1: memref<?x?xf32>) {
+    %c0 = arith.constant 0 : index
+    vector.transfer_write %arg0, %arg1[%c0, %c0] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
+    return
+  }
+  func.func @unroll_transfer_write_target_rank_zero(%arg0: vector<2xi32>) {
+    %c0 = arith.constant 0 : index
+    %alloc = memref.alloc() : memref<4xi32>
+    vector.transfer_write %arg0, %alloc[%c0] {in_bounds = [true]} : vector<2xi32>, memref<4xi32>
+    return
+  }
+}
+
+// -----
+module {
+  func.func @scalable_transpose_store_unmasked(%arg0: vector<4x[4]xf32>, %arg1: memref<?x?xf32>, %arg2: index, %arg3: index) {
+    %0 = vector.transpose %arg0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+    vector.transfer_write %0, %arg1[%arg2, %arg3] {in_bounds = [true, true]} : vector<[4]x4xf32>, memref<?x?xf32>
+    return
+  }
+}
+
+// -----
+module {
+  func.func @scalable_transpose_store_dynamic_mask(%arg0: vector<4x[4]xf32>, %arg1: memref<?x?xf32>, %arg2: index, %arg3: index, %arg4: index, %arg5: index) {
+    %0 = vector.transpose %arg0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+    %1 = vector.create_mask %arg4, %arg5 : vector<[4]x4xi1>
+    vector.transfer_write %0, %arg1[%arg2, %arg3], %1 {in_bounds = [true, true]} : vector<[4]x4xf32>, memref<?x?xf32>
+    return
+  }
+}
+
+// -----
+module {
+  func.func @scalable_transpose_store_constant_mask(%arg0: vector<4x[4]xf32>, %arg1: memref<?x?xf32>, %arg2: index, %arg3: index) {
+    %0 = vector.transpose %arg0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+    %1 = vector.constant_mask [4, 3] : vector<[4]x4xi1>
+    vector.transfer_write %0, %arg1[%arg2, %arg3], %1 {in_bounds = [true, true]} : vector<[4]x4xf32>, memref<?x?xf32>
+    return
+  }
+}
+
+// -----
+#map = affine_map<(d0)[s0] -> (d0 + s0)>
+module {
+  func.func @negative_scalable_transpose_store_0(%arg0: vector<[4]x4xf32>, %arg1: memref<?x?xf32>, %arg2: index, %arg3: index) {
+    %c1 = arith.constant 1 : index
+    %c4 = arith.constant 4 : index
+    %c0 = arith.constant 0 : index
+    %alloca = memref.alloca() : memref<vector<4x[4]xf32>>
+    %0 = vector.transpose %arg0, [1, 0] : vector<[4]x4xf32> to vector<4x[4]xf32>
+    memref.store %0, %alloca[] : memref<vector<4x[4]xf32>>
+    %1 = vector.type_cast %alloca : memref<vector<4x[4]xf32>> to memref<4xvector<[4]xf32>>
+    scf.for %arg4 = %c0 to %c4 step %c1 {
+      %2 = affine.apply #map(%arg4)[%arg2]
+      %3 = memref.load %1[%arg4] : memref<4xvector<[4]xf32>>
+      vector.transfer_write %3, %arg1[%2, %arg3] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+    }
+    return
+  }
+}
+
+// -----
+#map = affine_map<(d0)[s0] -> (d0 + s0)>
+module {
+  func.func @negative_scalable_transpose_store_1(%arg0: vector<4x[4]xf32>, %arg1: memref<?x?xf32>, %arg2: index, %arg3: index) {
+    %c1 = arith.constant 1 : index
+    %c4 = arith.constant 4 : index
+    %c0 = arith.constant 0 : index
+    %alloca = memref.alloca() : memref<vector<4x[4]xf32>>
+    %0 = vector.transpose %arg0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+    %1 = vector.transpose %0, [1, 0] : vector<[4]x4xf32> to vector<4x[4]xf32>
+    memref.store %1, %alloca[] : memref<vector<4x[4]xf32>>
+    %2 = vector.type_cast %alloca : memref<vector<4x[4]xf32>> to memref<4xvector<[4]xf32>>
+    scf.for %arg4 = %c0 to %c4 step %c1 {
+      %3 = affine.apply #map(%arg4)[%arg2]
+      %4 = memref.load %2[%arg4] : memref<4xvector<[4]xf32>>
+      vector.transfer_write %4, %arg1[%3, %arg3] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+    }
+    return
+  }
+}
+
+// -----
+module {
+  func.func @negative_scalable_transpose_store_2(%arg0: vector<4x[4]xf32>, %arg1: memref<?x?xf32>, %arg2: index, %arg3: index) {
+    %0 = vector.transpose %arg0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+    vector.transfer_write %0, %arg1[%arg2, %arg3] {in_bounds = [false, true]} : vector<[4]x4xf32>, memref<?x?xf32>
+    return
+  }
+}
+
+// -----
+module {
+  func.func @negative_scalable_transpose_store_3(%arg0: vector<[4]x4xf32>, %arg1: memref<?x?xf32>, %arg2: index, %arg3: index) {
+    vector.transfer_write %arg0, %arg1[%arg2, %arg3] {in_bounds = [true, true]} : vector<[4]x4xf32>, memref<?x?xf32>
+    return
+  }
+}
+
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
index 33177736eb5fe..1ed82954398f0 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -558,10 +558,9 @@ func.func @vector_print_vector_0d(%arg0: vector<f32>) {
 // CHECK-SAME:                                      %[[VEC:.*]]: vector<f32>) {
 // CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
-// CHECK:           %[[FLAT_VEC:.*]] = vector.shape_cast %[[VEC]] : vector<f32> to vector<1xf32>
 // CHECK:           vector.print punctuation <open>
 // CHECK:           scf.for %[[IDX:.*]] = %[[C0]] to %[[C1]] step %[[C1]] {
-// CHECK:             %[[EL:.*]] = vector.extract %[[FLAT_VEC]][%[[IDX]]] : f32 from vector<1xf32>
+// CHECK:             %[[EL:.*]] = vector.extract %[[VEC]][] : f32 from vector<f32>
 // CHECK:             vector.print %[[EL]] : f32 punctuation <no_punctuation>
 // CHECK:             %[[IS_NOT_LAST:.*]] = arith.cmpi ult, %[[IDX]], %[[C0]] : index
 // CHECK:             scf.if %[[IS_NOT_LAST]] {



More information about the Mlir-commits mailing list