[Mlir-commits] [mlir] c1c4c8e - Revert "[mlir][vector] Migrate drop-lead-unit-dim to shape_cast #196206" (#199546)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon May 25 09:36:41 PDT 2026


Author: Omair Javaid
Date: 2026-05-25T21:36:36+05:00
New Revision: c1c4c8e23d099c199ea90b050742c3d6c5efcfaf

URL: https://github.com/llvm/llvm-project/commit/c1c4c8e23d099c199ea90b050742c3d6c5efcfaf
DIFF: https://github.com/llvm/llvm-project/commit/c1c4c8e23d099c199ea90b050742c3d6c5efcfaf.diff

LOG: Revert "[mlir][vector] Migrate drop-lead-unit-dim to shape_cast #196206" (#199546)

This reverts commit 24b8bb18f3417419cbd16fcd31f4e2842df952a1 from
#196206

This broke the AArch64 SVE Linux buildbots; however, it was not reported due
to a glitch in the buildbot infrastructure. The following bots are failing:

https://lab.llvm.org/buildbot/#/builders/121
https://lab.llvm.org/buildbot/#/builders/41
https://lab.llvm.org/buildbot/#/builders/4
https://lab.llvm.org/buildbot/#/builders/199
https://lab.llvm.org/buildbot/#/builders/17
https://lab.llvm.org/buildbot/#/builders/198
https://lab.llvm.org/buildbot/#/builders/143

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
    mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir
    mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
    mlir/test/Dialect/Vector/vector-transforms.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index aad42039300e3..26a702ef0f512 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -7,9 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include <numeric>
-#include <utility>
 
-#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
@@ -17,7 +15,6 @@
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/TypeUtilities.h"
-#include "llvm/ADT/Repeated.h"
 #include "llvm/ADT/STLExtras.h"
 
 #define DEBUG_TYPE "vector-drop-unit-dim"
@@ -25,9 +22,9 @@
 using namespace mlir;
 using namespace mlir::vector;
 
-// Trims leading unit dimensions from `oldType` and returns the result type.
-static VectorType trimLeadingUnitDims(VectorType oldType,
-                                      bool zeroDimsAllowed) {
+// Trims leading one dimensions from `oldType` and returns the result type.
+// Returns `vector<1xT>` if `oldType` only has one element.
+static VectorType trimLeadingOneDims(VectorType oldType) {
   ArrayRef<int64_t> oldShape = oldType.getShape();
   ArrayRef<int64_t> newShape = oldShape;
 
@@ -40,117 +37,22 @@ static VectorType trimLeadingUnitDims(VectorType oldType,
     newScalableDims = newScalableDims.drop_front(1);
   }
 
-  // Some vector ops forbid 0-D vectors.
-  if (!zeroDimsAllowed && newShape.empty()) {
+  // Make sure we have at least 1 dimension per vector type requirements.
+  if (newShape.empty()) {
     newShape = oldShape.take_back();
     newScalableDims = oldType.getScalableDims().take_back();
   }
   return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
 }
 
-static bool isNonScalableUnitDim(VectorType type, int64_t dim) {
-  assert(dim >= 0 && dim < type.getRank() &&
-         "expected a valid vector dimension");
-  return type.getShape()[dim] == 1 && !type.getScalableDims()[dim];
+/// Return a smallVector of size `rank` containing all zeros.
+static SmallVector<int64_t> splatZero(int64_t rank) {
+  return SmallVector<int64_t>(rank, 0);
 }
-
-/// Returns true if the first `k` dimensions of `type` are non-scalable unit
-/// dimensions.
-static bool areLeadingDimsUnit(VectorType type, int64_t k) {
-  assert(k >= 0 && k <= type.getRank() &&
-         "expected a valid leading dimension count");
-  return llvm::all_of(llvm::seq<int64_t>(0, k), [&](int64_t dim) {
-    return isNonScalableUnitDim(type, dim);
-  });
-}
-
-static bool areLeadingDimsUnitAfterPermutation(VectorType type,
-                                               ArrayRef<int64_t> permutation,
-                                               int64_t k) {
-  assert(k >= 0 && k <= static_cast<int64_t>(permutation.size()) &&
-         "expected a valid leading dimension count");
-  return llvm::all_of(permutation.take_front(k), [&](int64_t dim) {
-    return isNonScalableUnitDim(type, dim);
-  });
-}
-
-/// Shape-casts `operand` to the vector type obtained by dropping dimension
-/// `dim`, which must be non-scalable and unit-sized.
-static Value dropUnitDim(OpBuilder &b, Location loc, Value operand,
-                         int64_t dimToDrop, bool zeroDimsAllowed) {
-  auto oldType = cast<VectorType>(operand.getType());
-  assert(isNonScalableUnitDim(oldType, dimToDrop) &&
-         "expected a non-scalable unit dim to drop");
-  int64_t rank = oldType.getRank();
-  assert((zeroDimsAllowed || rank > 1) &&
-         "target op does not allow 0-D vectors");
-
-  SmallVector<int64_t> newShape;
-  SmallVector<bool> newScalableDims;
-  newShape.reserve(rank - 1);
-  newScalableDims.reserve(rank - 1);
-  for (auto [i, size, scalable] :
-       llvm::enumerate(oldType.getShape(), oldType.getScalableDims())) {
-    if (static_cast<int64_t>(i) == dimToDrop)
-      continue;
-    newShape.push_back(size);
-    newScalableDims.push_back(scalable);
-  }
-
-  return b.createOrFold<vector::ShapeCastOp>(
-      loc, VectorType::get(newShape, oldType.getElementType(), newScalableDims),
-      operand);
-}
-
-/// Shape-casts `operand` to the vector type obtained by dropping the first
-/// `k` non-scalable unit dimensions.
-static Value dropLeadingUnitDims(OpBuilder &b, Location loc, Value operand,
-                                 int64_t k, bool zeroDimsAllowed) {
-  auto oldType = cast<VectorType>(operand.getType());
-  assert(areLeadingDimsUnit(oldType, k) &&
-         "expected non-scalable leading unit dims to drop");
-  assert((zeroDimsAllowed || k < oldType.getRank()) &&
-         "target op does not allow 0-D vectors");
-  VectorType newType = VectorType::get(oldType.getShape().drop_front(k),
-                                       oldType.getElementType(),
-                                       oldType.getScalableDims().drop_front(k));
-  return b.createOrFold<vector::ShapeCastOp>(loc, newType, operand);
-}
-
-/// Returns the vector type obtained by applying `permutation` to `type`.
-static VectorType permuteVectorType(VectorType type,
-                                    ArrayRef<int64_t> permutation) {
-  assert(static_cast<int64_t>(permutation.size()) == type.getRank() &&
-         "expected a permutation matching the operand rank");
-  SmallVector<int64_t> permutedShape =
-      applyPermutation(type.getShape(), permutation);
-  SmallVector<bool> permutedScalableDims =
-      applyPermutation(type.getScalableDims(), permutation);
-  return VectorType::get(permutedShape, type.getElementType(),
-                         permutedScalableDims);
-}
-
-/// Like `dropLeadingUnitDims` except that if all dimensions would be dropped,
-/// the single element inside that vector is extracted and returned.
-static Value dropLeadingUnitDims0DIsScalar(OpBuilder &b, Location loc,
-                                           Value operand, int64_t k) {
-  auto oldType = cast<VectorType>(operand.getType());
-  assert(areLeadingDimsUnit(oldType, k) &&
-         "expected non-scalable leading unit dims to drop");
-
-  if (k == oldType.getRank()) {
-    SmallVector<int64_t> zeros(k, static_cast<int64_t>(0));
-    return vector::ExtractOp::create(b, loc, operand, zeros);
-  }
-
-  return dropLeadingUnitDims(b, loc, operand, k,
-                             /*zeroDimsAllowed=*/true);
-}
-
 namespace {
 
 // Casts away leading one dimensions in vector.extract_strided_slice's vector
-// input by inserting vector.shape_cast.
+// input by inserting vector.broadcast.
 struct CastAwayExtractStridedSliceLeadingOneDim
     : public OpRewritePattern<vector::ExtractStridedSliceOp> {
   using Base::Base;
@@ -161,8 +63,7 @@ struct CastAwayExtractStridedSliceLeadingOneDim
     // the same rank. Here we drop leading one dimensions from the input vector
     // type to make sure we don't cause mismatch.
     VectorType oldSrcType = extractOp.getSourceVectorType();
-    VectorType newSrcType =
-        trimLeadingUnitDims(oldSrcType, /*zeroDimsAllowed=*/false);
+    VectorType newSrcType = trimLeadingOneDims(oldSrcType);
 
     if (newSrcType.getRank() == oldSrcType.getRank())
       return failure();
@@ -177,8 +78,8 @@ struct CastAwayExtractStridedSliceLeadingOneDim
 
     Location loc = extractOp.getLoc();
 
-    Value newSrcVector = rewriter.createOrFold<vector::ShapeCastOp>(
-        loc, newSrcType, extractOp.getSource());
+    Value newSrcVector = vector::ExtractOp::create(
+        rewriter, loc, extractOp.getSource(), splatZero(dropCount));
 
     // The offsets/sizes/strides attribute can have a less number of elements
     // than the input vector's rank: it is meant for the leading dimensions.
@@ -193,7 +94,7 @@ struct CastAwayExtractStridedSliceLeadingOneDim
         rewriter, loc, newDstType, newSrcVector, newOffsets, newSizes,
         newStrides);
 
-    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, oldDstType,
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, oldDstType,
                                                      newExtractOp);
 
     return success();
@@ -201,7 +102,7 @@ struct CastAwayExtractStridedSliceLeadingOneDim
 };
 
 // Casts away leading one dimensions in vector.insert_strided_slice's vector
-// inputs by inserting vector.shape_cast.
+// inputs by inserting vector.broadcast.
 struct CastAwayInsertStridedSliceLeadingOneDim
     : public OpRewritePattern<vector::InsertStridedSliceOp> {
   using Base::Base;
@@ -209,11 +110,9 @@ struct CastAwayInsertStridedSliceLeadingOneDim
   LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp,
                                 PatternRewriter &rewriter) const override {
     VectorType oldSrcType = insertOp.getSourceVectorType();
-    VectorType newSrcType =
-        trimLeadingUnitDims(oldSrcType, /*zeroDimsAllowed=*/false);
+    VectorType newSrcType = trimLeadingOneDims(oldSrcType);
     VectorType oldDstType = insertOp.getDestVectorType();
-    VectorType newDstType =
-        trimLeadingUnitDims(oldDstType, /*zeroDimsAllowed=*/false);
+    VectorType newDstType = trimLeadingOneDims(oldDstType);
 
     int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank();
     int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
@@ -223,10 +122,10 @@ struct CastAwayInsertStridedSliceLeadingOneDim
     // Trim leading one dimensions from both operands.
     Location loc = insertOp.getLoc();
 
-    Value newSrcVector = rewriter.createOrFold<vector::ShapeCastOp>(
-        loc, newSrcType, insertOp.getValueToStore());
-    Value newDstVector = rewriter.createOrFold<vector::ShapeCastOp>(
-        loc, newDstType, insertOp.getDest());
+    Value newSrcVector = vector::ExtractOp::create(
+        rewriter, loc, insertOp.getValueToStore(), splatZero(srcDropCount));
+    Value newDstVector = vector::ExtractOp::create(
+        rewriter, loc, insertOp.getDest(), splatZero(dstDropCount));
 
     auto newOffsets = rewriter.getArrayAttr(
         insertOp.getOffsets().getValue().take_back(newDstType.getRank()));
@@ -237,7 +136,7 @@ struct CastAwayInsertStridedSliceLeadingOneDim
         rewriter, loc, newDstType, newSrcVector, newDstVector, newOffsets,
         newStrides);
 
-    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(insertOp, oldDstType,
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
                                                      newInsertOp);
 
     return success();
@@ -245,7 +144,7 @@ struct CastAwayInsertStridedSliceLeadingOneDim
 };
 
 // Casts away leading one dimensions in vector.insert's vector inputs by
-// inserting vector.shape_cast.
+// inserting vector.broadcast.
 struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
   using Base::Base;
 
@@ -255,14 +154,13 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
     Type newSrcType = oldSrcType;
     int64_t oldSrcRank = 0, newSrcRank = 0;
     if (auto type = dyn_cast<VectorType>(oldSrcType)) {
-      newSrcType = trimLeadingUnitDims(type, /*zeroDimsAllowed=*/false);
+      newSrcType = trimLeadingOneDims(type);
       oldSrcRank = type.getRank();
       newSrcRank = cast<VectorType>(newSrcType).getRank();
     }
 
     VectorType oldDstType = insertOp.getDestVectorType();
-    VectorType newDstType =
-        trimLeadingUnitDims(oldDstType, /*zeroDimsAllowed=*/oldSrcRank == 0);
+    VectorType newDstType = trimLeadingOneDims(oldDstType);
 
     int64_t srcDropCount = oldSrcRank - newSrcRank;
     int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
@@ -273,11 +171,12 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
     Location loc = insertOp.getLoc();
 
     Value newSrcVector = insertOp.getValueToStore();
-    if (oldSrcRank != 0)
-      newSrcVector = rewriter.createOrFold<vector::ShapeCastOp>(
-          loc, cast<VectorType>(newSrcType), insertOp.getValueToStore());
-    Value newDstVector = rewriter.createOrFold<vector::ShapeCastOp>(
-        loc, newDstType, insertOp.getDest());
+    if (oldSrcRank != 0) {
+      newSrcVector = vector::ExtractOp::create(
+          rewriter, loc, insertOp.getValueToStore(), splatZero(srcDropCount));
+    }
+    Value newDstVector = vector::ExtractOp::create(
+        rewriter, loc, insertOp.getDest(), splatZero(dstDropCount));
 
     // New position rank needs to be computed in two steps: (1) if destination
     // type has leading unit dims, we also trim the position array accordingly,
@@ -294,7 +193,7 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
     auto newInsertOp = vector::InsertOp::create(rewriter, loc, newSrcVector,
                                                 newDstVector, newPosition);
 
-    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(insertOp, oldDstType,
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
                                                      newInsertOp);
 
     return success();
@@ -302,10 +201,20 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
 };
 
 static Value dropUnitDimsFromMask(OpBuilder &b, Location loc, Value mask,
-                                  VectorType newType, AffineMap newMap) {
+                                  VectorType newType, AffineMap newMap,
+                                  VectorType oldMaskType) {
   // Infer the type of the new mask from the new map.
   VectorType newMaskType = inferTransferOpMaskType(newType, newMap);
-  return b.createOrFold<vector::ShapeCastOp>(loc, newMaskType, mask);
+
+  // If the new mask is broadcastable to the old result type, we can safely
+  // use a `vector.extract` to get the new mask. Otherwise the best we can
+  // do is shape cast.
+  if (vector::isBroadcastableTo(newMaskType, oldMaskType) ==
+      BroadcastableToResult::Success) {
+    int64_t dropDim = oldMaskType.getRank() - newMaskType.getRank();
+    return vector::ExtractOp::create(b, loc, mask, splatZero(dropDim));
+  }
+  return vector::ShapeCastOp::create(b, loc, newMaskType, mask);
 }
 
 // Turns vector.transfer_read on vector with leading 1 dimensions into
@@ -320,7 +229,7 @@ struct CastAwayTransferReadLeadingOneDim
     // TODO(#78787): Not supported masked op yet.
     if (cast<MaskableOpInterface>(read.getOperation()).isMasked())
       return failure();
-    // Nothing to trim when the transfer itself has rank zero.
+    // TODO: support 0-d corner case.
     if (read.getTransferRank() == 0)
       return failure();
 
@@ -329,7 +238,7 @@ struct CastAwayTransferReadLeadingOneDim
       return failure();
 
     VectorType oldType = read.getVectorType();
-    VectorType newType = trimLeadingUnitDims(oldType, /*zeroDimsAllowed=*/true);
+    VectorType newType = trimLeadingOneDims(oldType);
 
     if (newType == oldType)
       return failure();
@@ -347,14 +256,16 @@ struct CastAwayTransferReadLeadingOneDim
           read.getInBoundsAttr().getValue().take_back(newType.getRank()));
 
     Value mask = Value();
-    if (read.getMask())
+    if (read.getMask()) {
+      VectorType maskType = read.getMaskType();
       mask = dropUnitDimsFromMask(rewriter, read.getLoc(), read.getMask(),
-                                  newType, newMap);
+                                  newType, newMap, maskType);
+    }
 
     auto newRead = vector::TransferReadOp::create(
         rewriter, read.getLoc(), newType, read.getBase(), read.getIndices(),
         AffineMapAttr::get(newMap), read.getPadding(), mask, inBoundsAttr);
-    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(read, oldType, newRead);
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead);
 
     return success();
   }
@@ -372,7 +283,7 @@ struct CastAwayTransferWriteLeadingOneDim
     // TODO(#78787): Not supported masked op yet.
     if (cast<MaskableOpInterface>(write.getOperation()).isMasked())
       return failure();
-    // Nothing to trim when the transfer itself has rank zero.
+    // TODO: support 0-d corner case.
     if (write.getTransferRank() == 0)
       return failure();
 
@@ -381,9 +292,11 @@ struct CastAwayTransferWriteLeadingOneDim
       return failure();
 
     VectorType oldType = write.getVectorType();
-    VectorType newType = trimLeadingUnitDims(oldType, /*zeroDimsAllowed=*/true);
+    VectorType newType = trimLeadingOneDims(oldType);
     if (newType == oldType)
       return failure();
+    int64_t dropDim = oldType.getRank() - newType.getRank();
+
     AffineMap oldMap = write.getPermutationMap();
     ArrayRef<AffineExpr> newResults =
         oldMap.getResults().take_back(newType.getRank());
@@ -396,12 +309,13 @@ struct CastAwayTransferWriteLeadingOneDim
       inBoundsAttr = rewriter.getArrayAttr(
           write.getInBoundsAttr().getValue().take_back(newType.getRank()));
 
-    auto newVector = rewriter.createOrFold<vector::ShapeCastOp>(
-        write.getLoc(), newType, write.getVector());
+    auto newVector = vector::ExtractOp::create(
+        rewriter, write.getLoc(), write.getVector(), splatZero(dropDim));
 
     if (write.getMask()) {
-      Value newMask = dropUnitDimsFromMask(rewriter, write.getLoc(),
-                                           write.getMask(), newType, newMap);
+      VectorType maskType = write.getMaskType();
+      Value newMask = dropUnitDimsFromMask(
+          rewriter, write.getLoc(), write.getMask(), newType, newMap, maskType);
       rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
           write, newVector, write.getBase(), write.getIndices(),
           AffineMapAttr::get(newMap), newMask, inBoundsAttr);
@@ -417,15 +331,6 @@ struct CastAwayTransferWriteLeadingOneDim
 
 } // namespace
 
-namespace {
-struct VectorContractOperandCastPlan {
-  AffineMap map;
-  SmallVector<int64_t> permutation;
-  bool dropLeadingUnitDim = false;
-  bool permuteOperand = false;
-};
-} // namespace
-
 FailureOr<Value>
 mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
                                                MaskingOpInterface maskingOp,
@@ -435,7 +340,7 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
     return failure();
   if (oldAccType.getRank() < 1)
     return failure();
-  if (!isNonScalableUnitDim(oldAccType, 0))
+  if (oldAccType.getShape()[0] != 1)
     return failure();
   // currently we support only dropping one dim but the pattern can be applied
   // greedily to drop more.
@@ -462,70 +367,74 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
 
   SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(),
                                  contractOp.getAcc()};
-  SmallVector<VectorContractOperandCastPlan> operandCastPlans;
   SmallVector<Value> newOperands;
   auto loc = contractOp.getLoc();
 
-  if (maskingOp) {
-    auto oldMaskType = cast<VectorType>(maskingOp.getMask().getType());
-    if (oldMaskType.getRank() <= 1 || dimToDrop >= oldMaskType.getRank() ||
-        !isNonScalableUnitDim(oldMaskType, dimToDrop))
-      return failure();
-  }
-
   for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
     // Check if the dim to be dropped exists as a leading dim in the operand
-    // if it does then we use vector.shape_cast to drop it.
-    VectorContractOperandCastPlan plan;
+    // if it does then we use vector.extract to drop it.
+    bool validExtract = false;
     SmallVector<AffineExpr> results;
-    plan.map = it.value();
-    int64_t originalZeroDim = plan.map.getDimPosition(0);
-    if (originalZeroDim != dimToDrop) {
+    auto map = it.value();
+    int64_t orginalZeroDim = it.value().getDimPosition(0);
+    if (orginalZeroDim != dimToDrop) {
       // There are two reasons to be in this path, 1. We need to
-      // permute the operand type to make the dim to be dropped
+      // transpose the operand to make the dim to be dropped
       // leading. 2. The dim to be dropped does not exist and in
-      // that case we dont want to add a unit permutation but we must
+      // that case we dont want to add a unit transpose but we must
       // check all the indices to make sure this is the case.
-      SmallVector<AffineExpr> permutedResults;
+      bool transposeNeeded = false;
+      SmallVector<int64_t> perm;
+      SmallVector<AffineExpr> transposeResults;
 
-      for (int64_t i = 0, e = plan.map.getNumResults(); i < e; ++i) {
-        int64_t currDim = plan.map.getDimPosition(i);
+      for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
+        int64_t currDim = map.getDimPosition(i);
         if (currDim == dimToDrop) {
-          plan.permuteOperand = true;
-          plan.permutation.insert(plan.permutation.begin(), i);
+          transposeNeeded = true;
+          perm.insert(perm.begin(), i);
           auto targetExpr = rewriter.getAffineDimExpr(currDim);
-          permutedResults.insert(permutedResults.begin(), targetExpr);
+          transposeResults.insert(transposeResults.begin(), targetExpr);
         } else {
-          plan.permutation.push_back(i);
+          perm.push_back(i);
           auto targetExpr = rewriter.getAffineDimExpr(currDim);
-          permutedResults.push_back(targetExpr);
+          transposeResults.push_back(targetExpr);
         }
       }
 
-      // Update the map now so that the later shape_cast drops the correct dim.
-      if (plan.permuteOperand) {
-        plan.map = AffineMap::get(plan.map.getNumDims(), 0, permutedResults,
-                                  contractOp.getContext());
-        if (plan.map.getDimPosition(0) == dimToDrop) {
-          auto operandType = cast<VectorType>(operands[it.index()].getType());
-          if (!areLeadingDimsUnitAfterPermutation(operandType, plan.permutation,
-                                                  dropDim))
-            return failure();
+      // Checks if only the outer, unit dimensions (of size 1) are permuted.
+      // Such transposes do not materially effect the underlying vector and can
+      // be omitted. EG: perm [1, 0, 2] applied to vector<1x1x8xi32>
+      bool transposeNonOuterUnitDims = false;
+      auto operandShape = cast<ShapedType>(operands[it.index()].getType());
+      for (auto [index, dim] :
+           llvm::enumerate(ArrayRef<int64_t>(perm).drop_back(1))) {
+        if (dim != static_cast<int64_t>(index) &&
+            operandShape.getDimSize(index) != 1) {
+          transposeNonOuterUnitDims = true;
+          break;
+        }
+      }
+
+      // Do the transpose now if needed so that we can drop the
+      // correct dim using extract later.
+      if (transposeNeeded) {
+        map = AffineMap::get(map.getNumDims(), 0, transposeResults,
+                             contractOp.getContext());
+        if (transposeNonOuterUnitDims) {
+          operands[it.index()] = rewriter.createOrFold<vector::TransposeOp>(
+              loc, operands[it.index()], perm);
         }
       }
     }
     // We have taken care to have the dim to be dropped be
     // the leading dim. If its still not leading that means it
-    // does not exist in this operand and hence we do not need a shape_cast.
-    if (plan.map.getDimPosition(0) == dimToDrop)
-      plan.dropLeadingUnitDim = true;
-    if (plan.dropLeadingUnitDim && originalZeroDim == dimToDrop &&
-        !areLeadingDimsUnit(cast<VectorType>(operands[it.index()].getType()),
-                            dropDim))
-      return failure();
+    // does not exist in this operand and hence we do not need
+    // an extract.
+    if (map.getDimPosition(0) == dimToDrop)
+      validExtract = true;
 
-    for (int64_t i = 0, e = plan.map.getNumResults(); i < e; ++i) {
-      int64_t currDim = plan.map.getDimPosition(i);
+    for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
+      int64_t currDim = map.getDimPosition(i);
       if (currDim == dimToDrop)
         // This is the dim we are dropping.
         continue;
@@ -533,23 +442,15 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
           currDim < dimToDrop ? currDim : currDim - 1);
       results.push_back(targetExpr);
     }
-    newIndexingMaps.push_back(AffineMap::get(plan.map.getNumDims() - 1, 0,
-                                             results, contractOp.getContext()));
-    operandCastPlans.push_back(std::move(plan));
-  }
-
-  for (auto [plan, operand] : llvm::zip_equal(operandCastPlans, operands)) {
-    Value newOperand = operand;
-    if (plan.permuteOperand)
-      newOperand = rewriter.createOrFold<vector::ShapeCastOp>(
-          loc,
-          permuteVectorType(cast<VectorType>(newOperand.getType()),
-                            plan.permutation),
-          newOperand);
-    if (plan.dropLeadingUnitDim)
-      newOperand =
-          dropLeadingUnitDims0DIsScalar(rewriter, loc, newOperand, dropDim);
-    newOperands.push_back(newOperand);
+    newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results,
+                                             contractOp.getContext()));
+    // Extract if its a valid extraction, otherwise use the operand
+    // without extraction.
+    newOperands.push_back(validExtract
+                              ? vector::ExtractOp::create(rewriter, loc,
+                                                          operands[it.index()],
+                                                          splatZero(dropDim))
+                              : operands[it.index()]);
   }
 
   // Depending on whether this vector.contract is masked, the replacing Op
@@ -560,19 +461,13 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
       rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind());
 
   if (maskingOp) {
-    Value newMask = dropUnitDim(rewriter, loc, maskingOp.getMask(), dimToDrop,
-                                /*zeroDimsAllowed=*/false);
+    auto newMask = vector::ExtractOp::create(rewriter, loc, maskingOp.getMask(),
+                                             splatZero(dropDim));
 
     newOp = mlir::vector::maskOperation(rewriter, newOp, newMask);
   }
 
-  if (!isa<VectorType>(newOp->getResults()[0].getType()))
-    return vector::BroadcastOp::create(rewriter, loc,
-                                       contractOp->getResultTypes()[0],
-                                       newOp->getResults()[0])
-        .getResult();
-
-  return vector::ShapeCastOp::create(rewriter, loc,
+  return vector::BroadcastOp::create(rewriter, loc,
                                      contractOp->getResultTypes()[0],
                                      newOp->getResults()[0])
       .getResult();
@@ -581,9 +476,9 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
 namespace {
 
 /// Turns vector.contract on vector with leading 1 dimensions into
-/// vector.shape_cast followed by vector.contract on vector without leading
-/// 1 dimensions. Non-leading unit dimensions are dropped via direct
-/// shape_casts.
+/// vector.extract followed by vector.contract on vector without leading
+/// 1 dimensions. Also performs transpose of lhs and rhs operands if required
+/// prior to extract.
 struct CastAwayContractionLeadingOneDim
     : public MaskableOpRewritePattern<vector::ContractionOp> {
   using MaskableOpRewritePattern::MaskableOpRewritePattern;
@@ -598,15 +493,14 @@ struct CastAwayContractionLeadingOneDim
 
 /// Looks at elementwise operations on vectors with at least one leading
 /// dimension equal 1, e.g. vector<1x[4]x1xf32> (but not vector<2x[4]x1xf32>),
-/// and casts away the leading one dimensions (_plural_) with shape_cast.
+/// and cast aways the leading one dimensions (_plural_) and then broadcasts
+/// the results.
 ///
 /// Example before:
 ///     %1 = arith.mulf %arg0, %arg1 : vector<1x4x1xf32>
 /// Example after:
-///    %2 = vector.shape_cast %arg0 : vector<1x4x1xf32> to vector<4x1xf32>
-///    %3 = vector.shape_cast %arg1 : vector<1x4x1xf32> to vector<4x1xf32>
-///    %4 = arith.mulf %2, %3 : vector<4x1xf32>
-///    %5 = vector.shape_cast %4 : vector<4x1xf32> to vector<1x4x1xf32>
+///    %2 = arith.mulf %0, %1 : vector<4x1xf32>
+///    %3 = vector.broadcast %2 : vector<4x1xf32> to vector<1x4x1xf32>
 ///
 /// Does support scalable vectors.
 class CastAwayElementwiseLeadingOneDim : public RewritePattern {
@@ -622,34 +516,55 @@ class CastAwayElementwiseLeadingOneDim : public RewritePattern {
     auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0]);
     if (!vecType)
       return failure();
-    VectorType newVecType =
-        trimLeadingUnitDims(vecType, /*zeroDimsAllowed=*/true);
+    VectorType newVecType = trimLeadingOneDims(vecType);
     if (newVecType == vecType)
       return failure();
+    int64_t dropDim = vecType.getRank() - newVecType.getRank();
     SmallVector<Value, 4> newOperands;
     for (Value operand : op->getOperands()) {
-      if (auto opVecType = dyn_cast<VectorType>(operand.getType()))
-        newOperands.push_back(rewriter.createOrFold<vector::ShapeCastOp>(
-            op->getLoc(),
-            trimLeadingUnitDims(opVecType, /*zeroDimsAllowed=*/true), operand));
-      else
+      if (auto opVecType = dyn_cast<VectorType>(operand.getType())) {
+        newOperands.push_back(vector::ExtractOp::create(
+            rewriter, op->getLoc(), operand, splatZero(dropDim)));
+      } else {
         newOperands.push_back(operand);
+      }
     }
     Operation *newOp =
         rewriter.create(op->getLoc(), op->getName().getIdentifier(),
                         newOperands, newVecType, op->getAttrs());
-    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, vecType,
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecType,
                                                      newOp->getResult(0));
     return success();
   }
 };
 } // namespace
 
+// Drops `dropDim` leading dimensions from `operand` using vector.extract when
+// those dims are all non-scalable units (the cheap, structural rewrite); falls
+// back to vector.shape_cast otherwise.
+static Value dropLeadingOneDimsFromOperand(OpBuilder &b, Location loc,
+                                           Value operand, int64_t nDropped) {
+  auto oldType = cast<VectorType>(operand.getType());
+  ArrayRef<int64_t> leadingShape = oldType.getShape().take_front(nDropped);
+  ArrayRef<bool> leadingScalable =
+      oldType.getScalableDims().take_front(nDropped);
+  bool extractable =
+      llvm::all_of(leadingShape, [](int64_t d) { return d == 1; }) &&
+      llvm::none_of(leadingScalable, [](bool s) { return s; });
+  if (extractable)
+    return vector::ExtractOp::create(b, loc, operand, splatZero(nDropped));
+  VectorType newType = VectorType::get(
+      oldType.getShape().drop_front(nDropped), oldType.getElementType(),
+      oldType.getScalableDims().drop_front(nDropped));
+  return vector::ShapeCastOp::create(b, loc, newType, operand);
+}
+
 namespace {
 
-// Drops leading unit dimensions from load-like memory operations by
-// shape_casting each vector operand and shape_casting the result back to the
-// original type.
+// Drops leading 1 dimensions from load-like memory operaitons. REmoves leading
+// unit dimensions from the result types and then broadcasts back in those 1s,
+// while also extracting (or shape_cast-ing) any leading unit dimensions on
+// the input operands.
 template <typename OpTy>
 struct CastAwayLoadLikeLeadingOneDim : public OpRewritePattern<OpTy> {
   using OpRewritePattern<OpTy>::OpRewritePattern;
@@ -657,10 +572,7 @@ struct CastAwayLoadLikeLeadingOneDim : public OpRewritePattern<OpTy> {
   LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
     VectorType oldResultType = op.getVectorType();
-    constexpr bool zeroDimsAllowed =
-        llvm::is_one_of<OpTy, vector::LoadOp>::value;
-    VectorType newResultType =
-        trimLeadingUnitDims(oldResultType, zeroDimsAllowed);
+    VectorType newResultType = trimLeadingOneDims(oldResultType);
     if (newResultType == oldResultType)
       return failure();
     int64_t nDropped = oldResultType.getRank() - newResultType.getRank();
@@ -670,8 +582,8 @@ struct CastAwayLoadLikeLeadingOneDim : public OpRewritePattern<OpTy> {
     newOperands.reserve(op->getNumOperands());
     for (Value operand : op->getOperands()) {
       if (isa<VectorType>(operand.getType())) {
-        newOperands.push_back(dropLeadingUnitDims(rewriter, loc, operand,
-                                                  nDropped, zeroDimsAllowed));
+        newOperands.push_back(
+            dropLeadingOneDimsFromOperand(rewriter, loc, operand, nDropped));
       } else {
         newOperands.push_back(operand);
       }
@@ -680,14 +592,15 @@ struct CastAwayLoadLikeLeadingOneDim : public OpRewritePattern<OpTy> {
     Operation *newOp =
         rewriter.create(loc, op->getName().getIdentifier(), newOperands,
                         TypeRange{newResultType}, op->getAttrs());
-    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, oldResultType,
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, oldResultType,
                                                      newOp->getResult(0));
     return success();
   }
 };
 
-// Drops leading unit dimensions from store-like memory operations by
-// shape_casting each vector operand and leaving any scalar operands alone.
+// Drops leading 1 dimensions from store-like memory ops. Extracts or
+// `shape_cast`s away those leading unit dimensions and leaves any scalar
+// operands alone.
 template <typename OpTy>
 struct CastAwayStoreLikeLeadingOneDim : public OpRewritePattern<OpTy> {
   using OpRewritePattern<OpTy>::OpRewritePattern;
@@ -695,9 +608,7 @@ struct CastAwayStoreLikeLeadingOneDim : public OpRewritePattern<OpTy> {
   LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
     VectorType oldVecType = op.getVectorType();
-    constexpr bool zeroDimsAllowed =
-        llvm::is_one_of<OpTy, vector::StoreOp>::value;
-    VectorType newVecType = trimLeadingUnitDims(oldVecType, zeroDimsAllowed);
+    VectorType newVecType = trimLeadingOneDims(oldVecType);
     if (newVecType == oldVecType)
       return failure();
     int64_t nDropped = oldVecType.getRank() - newVecType.getRank();
@@ -707,8 +618,8 @@ struct CastAwayStoreLikeLeadingOneDim : public OpRewritePattern<OpTy> {
     newOperands.reserve(op->getNumOperands());
     for (Value operand : op->getOperands()) {
       if (isa<VectorType>(operand.getType())) {
-        newOperands.push_back(dropLeadingUnitDims(rewriter, loc, operand,
-                                                  nDropped, zeroDimsAllowed));
+        newOperands.push_back(
+            dropLeadingOneDimsFromOperand(rewriter, loc, operand, nDropped));
       } else {
         newOperands.push_back(operand);
       }
@@ -722,8 +633,8 @@ struct CastAwayStoreLikeLeadingOneDim : public OpRewritePattern<OpTy> {
   }
 };
 
-// Drops leading 1 dimensions from vector.constant_mask and shape_casts back to
-// the original shape.
+// Drops leading 1 dimensions from vector.constant_mask and inserts a
+// vector.broadcast back to the original shape.
 struct CastAwayConstantMaskLeadingOneDim
     : public OpRewritePattern<vector::ConstantMaskOp> {
   using Base::Base;
@@ -731,8 +642,7 @@ struct CastAwayConstantMaskLeadingOneDim
   LogicalResult matchAndRewrite(vector::ConstantMaskOp mask,
                                 PatternRewriter &rewriter) const override {
     VectorType oldType = mask.getType();
-    VectorType newType = trimLeadingUnitDims(oldType,
-                                             /*zeroDimsAllowed=*/true);
+    VectorType newType = trimLeadingOneDims(oldType);
 
     if (newType == oldType)
       return failure();
@@ -740,22 +650,16 @@ struct CastAwayConstantMaskLeadingOneDim
     int64_t dropDim = oldType.getRank() - newType.getRank();
     ArrayRef<int64_t> dimSizes = mask.getMaskDimSizes();
 
-    // If any of the folded unit dims has a size of `0`, the entire leading
-    // mask region is zero. Otherwise the folded unit dims have no effect on
-    // the mask.
-    SmallVector<int64_t> newDimSizes;
-    if (newType.getRank() == 0) {
-      newDimSizes.push_back(llvm::product_of(dimSizes));
-    } else {
-      int64_t flatLeadingSize =
-          llvm::product_of(dimSizes.take_front(dropDim + 1));
-      newDimSizes.push_back(flatLeadingSize);
-      newDimSizes.append(dimSizes.begin() + dropDim + 1, dimSizes.end());
-    }
+    // If any of the dropped unit dims has a size of `0`, the entire mask is a
+    // zero mask, else the unit dim has no effect on the mask.
+    int64_t flatLeadingSize =
+        llvm::product_of(dimSizes.take_front(dropDim + 1));
+    SmallVector<int64_t> newDimSizes = {flatLeadingSize};
+    newDimSizes.append(dimSizes.begin() + dropDim + 1, dimSizes.end());
 
     auto newMask = vector::ConstantMaskOp::create(rewriter, mask.getLoc(),
                                                   newType, newDimSizes);
-    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(mask, oldType, newMask);
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(mask, oldType, newMask);
     return success();
   }
 };

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 2575c9e4a85b9..752610efc6992 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1931,12 +1931,12 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
   }
 };
 
-// Helper function dropping unit non-scalable dimension from a VectorType.
-// Scalable unit dimensions are not dropped. Folding such dimensions would
-// require "shifting" the scalable flag onto some other fixed-width dim (e.g.
-// vector<[1]x4xf32> -> vector<[4]xf32>).
-static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy,
-                                                 bool zeroDimsAllowed) {
+// Helper function dropping non-scalable unit dimensions from a VectorType,
+// keeping at least 1 dimension to avoid generating 0-D vectors. Scalable unit
+// dimensions are not dropped. Folding such dimensions would require "shifting"
+// the scalable flag onto some other fixed-width dim (e.g. vector<[1]x4xf32> ->
+// vector<[4]xf32>). This could be implemented in the future.
+static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
   auto inVecShape = inVecTy.getShape();
   SmallVector<int64_t> newShape;
   SmallVector<bool> newScalableDims;
@@ -1948,8 +1948,8 @@ static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy,
     newShape.push_back(dim);
     newScalableDims.push_back(isScalable);
   }
-  // Some vector ops forbid 0-D vectors.
-  if (!zeroDimsAllowed && newShape.empty()) {
+  // All dims have been dropped, return vector<1xeType>.
+  if (newShape.empty()) {
     newShape.push_back(1);
     newScalableDims.push_back(false);
   }
@@ -2000,12 +2000,14 @@ struct DropUnitDimFromElementwiseOps final
     auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType());
     if (!sourceVectorType)
       return failure();
+    if (sourceVectorType.getRank() < 2)
+      return failure();
+
     SmallVector<Value> newOperands;
     auto loc = op->getLoc();
     for (auto operand : op->getOperands()) {
       auto opVectorType = cast<VectorType>(operand.getType());
-      auto newVType = dropNonScalableUnitDimFromType(opVectorType,
-                                                     /*zeroDimsAllowed=*/true);
+      auto newVType = dropNonScalableUnitDimFromType(opVectorType);
       if (newVType == opVectorType)
         return rewriter.notifyMatchFailure(op, "No unit dimension to remove.");
 
@@ -2014,8 +2016,7 @@ struct DropUnitDimFromElementwiseOps final
     }
 
     VectorType newResultVectorType =
-        dropNonScalableUnitDimFromType(resultVectorType,
-                                       /*zeroDimsAllowed=*/true);
+        dropNonScalableUnitDimFromType(resultVectorType);
     // Create an updated elementwise Op without unit dim.
     Operation *elementwiseOp =
         rewriter.create(loc, op->getName().getIdentifier(), newOperands,
@@ -2056,8 +2057,7 @@ struct DropUnitDimsFromTransposeOp final
                                 PatternRewriter &rewriter) const override {
     VectorType sourceType = op.getSourceVectorType();
     VectorType sourceTypeWithoutUnitDims =
-        dropNonScalableUnitDimFromType(sourceType,
-                                       /*zeroDimsAllowed=*/true);
+        dropNonScalableUnitDimFromType(sourceType);
 
     if (sourceType == sourceTypeWithoutUnitDims)
       return failure();
@@ -2082,9 +2082,9 @@ struct DropUnitDimsFromTransposeOp final
     }
 
     // Fixup for `newPerm`. The `sourceTypeWithoutUnitDims` could be vector<1xT>
-    // type when the dimensions are unit dimensions and 0-D vectors are not
-    // allowed. In this case, the newPerm should be [0].
-    if (newPerm.empty() && sourceTypeWithoutUnitDims.getRank() > 0) {
+    // type when the dimensions are unit dimensions. In this case, the newPerm
+    // should be [0].
+    if (newPerm.empty()) {
       newPerm.push_back(0);
     }
 
@@ -2139,9 +2139,7 @@ struct DropUnitDimsFromScfForOp final : OpRewritePattern<scf::ForOp> {
       if (!vectorType)
         continue;
 
-      VectorType newVectorType =
-          dropNonScalableUnitDimFromType(vectorType,
-                                         /*zeroDimsAllowed=*/true);
+      VectorType newVectorType = dropNonScalableUnitDimFromType(vectorType);
       if (vectorType == newVectorType)
         continue;
 

diff  --git a/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir b/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir
index 4e800ab169bf6..34a155fbf2fc1 100644
--- a/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir
+++ b/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir
@@ -150,26 +150,10 @@ func.func @fold_all_unit_dims(%vec: vector<1x1xf32>) -> vector<1xf32> {
 
 // CHECK-LABEL: func.func @fold_all_unit_dims(
 // CHECK-SAME:    %[[VAL_0:.*]]: vector<1x1xf32>) -> vector<1xf32>
-// CHECK:         %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x1xf32> to vector<f32>
-// CHECK:         %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x1xf32> to vector<f32>
-// CHECK:         %[[VAL_3:.*]] = arith.mulf %[[VAL_1]], %[[VAL_2]] : vector<f32>
-// CHECK:         %[[VAL_4:.*]] = vector.shape_cast %[[VAL_3]] : vector<f32> to vector<1xf32>
-// CHECK:         return %[[VAL_4]] : vector<1xf32>
-
-// -----
-
-func.func @fold_rank1_unit_dim(%vec: vector<1xf32>) -> vector<1xf32> {
-  %res = arith.addf %vec, %vec : vector<1xf32>
-  return %res : vector<1xf32>
-}
-
-// CHECK-LABEL: func.func @fold_rank1_unit_dim(
-// CHECK-SAME:    %[[VAL_0:.*]]: vector<1xf32>) -> vector<1xf32>
-// CHECK:         %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1xf32> to vector<f32>
-// CHECK:         %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<1xf32> to vector<f32>
-// CHECK:         %[[VAL_3:.*]] = arith.addf %[[VAL_1]], %[[VAL_2]] : vector<f32>
-// CHECK:         %[[VAL_4:.*]] = vector.shape_cast %[[VAL_3]] : vector<f32> to vector<1xf32>
-// CHECK:         return %[[VAL_4]] : vector<1xf32>
+// CHECK:         %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x1xf32> to vector<1xf32>
+// CHECK:         %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x1xf32> to vector<1xf32>
+// CHECK:         %[[VAL_3:.*]] = arith.mulf %[[VAL_1]], %[[VAL_2]] : vector<1xf32>
+// CHECK:         return %[[VAL_3]] : vector<1xf32>
 
 ///----------------------------------------------------------------------------------------
 /// [Pattern: DropUnitDimsFromTransposeOp]
@@ -265,11 +249,11 @@ func.func @scf_for_with_all_unit_dims(%vec: vector<1x1xf32>) -> vector<1x1xf32>
 
 // CHECK-LABEL: func.func @scf_for_with_all_unit_dims
 //  CHECK-SAME:   %[[VEC:[A-Za-z0-9]+]]: vector<1x1xf32>
-//       CHECK:   %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x1xf32> to vector<f32>
+//       CHECK:   %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x1xf32> to vector<1xf32>
 //       CHECK:   %[[LOOP:.+]] = scf.for {{.*}} iter_args(%[[ITER:.+]] = %[[CAST]])
-//       CHECK:     %[[SQRT:.+]] = math.sqrt %[[ITER]] : vector<f32>
+//       CHECK:     %[[SQRT:.+]] = math.sqrt %[[ITER]] : vector<1xf32>
 //       CHECK:     scf.yield %[[SQRT]]
-//       CHECK:   %[[CASTBACK:.+]] = vector.shape_cast %[[LOOP]] : vector<f32> to vector<1x1xf32>
+//       CHECK:   %[[CASTBACK:.+]] = vector.shape_cast %[[LOOP]] : vector<1xf32> to vector<1x1xf32>
 //       CHECK:   return %[[CASTBACK]]
 
 // -----

diff  --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
index cd1ecec455896..bf01c8a8589d9 100644
--- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -5,13 +5,13 @@
 // CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 
 // CHECK-LABEL: cast_away_contraction_leading_one_dims
-//  CHECK-NEXT:   %[[R0:.+]] = vector.shape_cast %{{.*}} : vector<1x16x8xf32> to vector<16x8xf32>
-//  CHECK-NEXT:   %[[R1:.+]] = vector.shape_cast %{{.*}} : vector<1x8x16xf32> to vector<8x16xf32>
-//  CHECK-NEXT:   %[[R2:.+]] = vector.shape_cast %{{.*}} : vector<1x16x16xf32> to vector<16x16xf32>
+//  CHECK-NEXT:   %[[R0:.+]] =  vector.extract %{{.*}}[0] : vector<16x8xf32> from vector<1x16x8xf32>
+//  CHECK-NEXT:   %[[R1:.+]] =  vector.extract %{{.*}}[0] : vector<8x16xf32> from vector<1x8x16xf32>
+//  CHECK-NEXT:   %[[R2:.+]] =  vector.extract %{{.*}}[0] : vector<16x16xf32> from vector<1x16x16xf32>
 //  CHECK-NEXT:   %[[R3:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
 //  CHECK-SAME:   iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
 //  CHECK-SAME:   %[[R0]], %[[R1]], %[[R2]] : vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32>
-//  CHECK-NEXT:   %[[R4:.+]] = vector.shape_cast %[[R3]] : vector<16x16xf32> to vector<1x16x16xf32>
+//  CHECK-NEXT:   %[[R4:.+]] = vector.broadcast %[[R3]] : vector<16x16xf32> to vector<1x16x16xf32>
 //  CHECK-NEXT:  return %[[R4]] : vector<1x16x16xf32>
 
 #contraction_accesses0 = [
@@ -36,14 +36,14 @@ func.func @cast_away_contraction_leading_one_dims(%arg0: vector<1x16x8xf32>, %ar
 
 // CHECK-LABEL:   func.func @cast_away_contraction_leading_one_dim_under_const_mask
 // CHECK:           %[[MASK:.*]] = vector.constant_mask [15, 15, 8] : vector<16x16x8xi1>
-// CHECK:           %[[R0:.*]] = vector.shape_cast %{{.*}} : vector<1x16x8xf32> to vector<16x8xf32>
-// CHECK:           %[[R1:.*]] = vector.shape_cast %{{.*}} : vector<1x8x16xf32> to vector<8x16xf32>
-// CHECK:           %[[R2:.*]] = vector.shape_cast %{{.*}} : vector<1x16x16xf32> to vector<16x16xf32>
+// CHECK:           %[[R0:.*]] = vector.extract %{{.*}}[0] : vector<16x8xf32> from vector<1x16x8xf32>
+// CHECK:           %[[R1:.*]] = vector.extract %{{.*}}[0] : vector<8x16xf32> from vector<1x8x16xf32>
+// CHECK:           %[[R2:.*]] = vector.extract %{{.*}}[0] : vector<16x16xf32> from vector<1x16x16xf32>
 // CHECK:           %[[CONTRACT:.*]] = vector.mask %[[MASK]] {
 // CHECK-SAME:        vector.contract {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]], #[[$MAP_2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
 // CHECK-SAME:          %[[R0]], %[[R1]], %[[R2]] : vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32>
 // CHECK-SAME:      } : vector<16x16x8xi1> -> vector<16x16xf32>
-// CHECK:           %[[RES:.*]] = vector.shape_cast %[[CONTRACT]] : vector<16x16xf32> to vector<1x16x16xf32>
+// CHECK:           %[[RES:.*]] = vector.broadcast %[[CONTRACT]] : vector<16x16xf32> to vector<1x16x16xf32>
 // CHECK:           return %[[RES]] : vector<1x16x16xf32>
 
 #contraction_accesses0 = [
@@ -70,15 +70,15 @@ func.func @cast_away_contraction_leading_one_dim_under_const_mask(%arg0: vector<
 // CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 
 // CHECK-LABEL:   func.func @cast_away_contraction_leading_one_dim_under_mask
-// CHECK:           %[[R0:.*]] = vector.shape_cast %{{.*}} : vector<1x16x8xf32> to vector<16x8xf32>
-// CHECK:           %[[R1:.*]] = vector.shape_cast %{{.*}} : vector<1x8x16xf32> to vector<8x16xf32>
-// CHECK:           %[[R2:.*]] = vector.shape_cast %{{.*}} : vector<1x16x16xf32> to vector<16x16xf32>
-// CHECK:           %[[M:.*]] = vector.shape_cast %{{.*}} : vector<1x16x16x8xi1> to vector<16x16x8xi1>
+// CHECK:           %[[R0:.*]] = vector.extract %{{.*}} : vector<16x8xf32> from vector<1x16x8xf32>
+// CHECK:           %[[R1:.*]] = vector.extract %{{.*}} : vector<8x16xf32> from vector<1x8x16xf32>
+// CHECK:           %[[R2:.*]] = vector.extract %{{.*}} : vector<16x16xf32> from vector<1x16x16xf32>
+// CHECK:           %[[M:.*]] = vector.extract %{{.*}} : vector<16x16x8xi1> from vector<1x16x16x8xi1>
 // CHECK:           %[[CONTRACT:.*]] = vector.mask %[[M]] {
 // CHECK-SAME:      vector.contract {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
 // CHECK-SAME:          %[[R0]], %[[R1]], %[[R2]] : vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32>
 // CHECK-SAME:      } : vector<16x16x8xi1> -> vector<16x16xf32>
-// CHECK-NEXT:      %[[RES:.*]] = vector.shape_cast %[[CONTRACT]] : vector<16x16xf32> to vector<1x16x16xf32>
+// CHECK-NEXT:      %[[RES:.*]] = vector.broadcast %[[CONTRACT]] : vector<16x16xf32> to vector<1x16x16xf32>
 // CHECK-NEXT:      return %[[RES]] : vector<1x16x16xf32>
 
 #contraction_accesses0 = [
@@ -109,14 +109,15 @@ func.func @cast_away_contraction_leading_one_dim_under_mask(
 // CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1) -> (d0)>
 
 // CHECK-LABEL: cast_away_contraction_leading_one_dims_transposeneeded
-//  CHECK-NEXT:   %[[R0:.+]] = vector.shape_cast %{{.*}} : vector<1x8x16xf32> to vector<8x16xf32>
-//  CHECK-NEXT:   %[[R1:.+]] = vector.shape_cast %{{.*}} : vector<1x1x8xf32> to vector<8xf32>
-//  CHECK-NEXT:   %[[R2:.+]] = vector.shape_cast %{{.*}} : vector<1x1x16xf32> to vector<16xf32>
+//  CHECK-NEXT:   %[[R0:.+]] =  vector.extract %{{.*}}[0] : vector<8x16xf32> from vector<1x8x16xf32>
+//  CHECK-NEXT:   %[[R1:.+]] =  vector.extract %{{.*}}[0, 0] : vector<8xf32> from vector<1x1x8xf32>
+//  CHECK-NEXT:   %[[R2:.+]] =  vector.extract %{{.*}}[0, 0] : vector<16xf32> from vector<1x1x16xf32>
 //  CHECK-NEXT:   %[[R3:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
 //  CHECK-SAME:   iterator_types = ["parallel", "reduction"], kind = #vector.kind<mul>}
 //  CHECK-SAME:   %[[R1]], %[[R0]], %[[R2]] : vector<8xf32>, vector<8x16xf32> into vector<16xf32>
-//  CHECK-NEXT:   %[[R4:.+]] = vector.shape_cast %[[R3]] : vector<16xf32> to vector<1x1x16xf32>
-//  CHECK-NEXT:  return %[[R4]] : vector<1x1x16xf32>
+//  CHECK-NEXT:   %[[R4:.+]] = vector.broadcast %[[R3]] : vector<16xf32> to vector<1x16xf32>
+//  CHECK-NEXT:   %[[R5:.+]] = vector.broadcast %[[R4]] : vector<1x16xf32> to vector<1x1x16xf32>
+//  CHECK-NEXT:  return %[[R5]] : vector<1x1x16xf32>
 
 #contraction_accesses1 = [
   affine_map<(l, i, j, k) -> (i, l, k)>,
@@ -140,13 +141,15 @@ func.func @cast_away_contraction_leading_one_dims_transposeneeded(%arg0: vector<
 // CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 
 // CHECK-LABEL: cast_away_contraction_leading_one_dims_transposeneeded2
-//  CHECK-NEXT:   %[[R1:.+]] = vector.shape_cast %{{.*}} : vector<8x1x16xf32> to vector<8x16xf32>
-//  CHECK-NEXT:   %[[R3:.+]] = vector.shape_cast %{{.*}} : vector<2x8x1xf32> to vector<2x8xf32>
-//  CHECK-NEXT:   %[[R4:.+]] = vector.shape_cast %{{.*}} : vector<1x2x16xf32> to vector<2x16xf32>
+//  CHECK-NEXT:   %[[R0:.+]] =  vector.transpose %{{.*}}[1, 0, 2] : vector<8x1x16xf32> to vector<1x8x16xf32>
+//  CHECK-NEXT:   %[[R1:.+]] =  vector.extract %[[R0]][0] : vector<8x16xf32> from vector<1x8x16xf32>
+//  CHECK-NEXT:   %[[R2:.+]] =  vector.transpose %{{.*}}[2, 0, 1] : vector<2x8x1xf32> to vector<1x2x8xf32>
+//  CHECK-NEXT:   %[[R3:.+]] =  vector.extract %[[R2]][0] : vector<2x8xf32> from vector<1x2x8xf32>
+//  CHECK-NEXT:   %[[R4:.+]] =  vector.extract %{{.*}}[0] : vector<2x16xf32> from vector<1x2x16xf32>
 //  CHECK-NEXT:   %[[R5:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
 //  CHECK-SAME:   iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
 //  CHECK-SAME:   %[[R1]], %[[R3]], %[[R4]] : vector<8x16xf32>, vector<2x8xf32> into vector<2x16xf32>
-//  CHECK-NEXT:   %[[R6:.+]] = vector.shape_cast %[[R5]] : vector<2x16xf32> to vector<1x2x16xf32>
+//  CHECK-NEXT:   %[[R6:.+]] = vector.broadcast %[[R5]] : vector<2x16xf32> to vector<1x2x16xf32>
 //  CHECK-NEXT:  return %[[R6]] : vector<1x2x16xf32>
 
 #contraction_accesses2 = [
@@ -172,14 +175,19 @@ func.func @cast_away_contraction_leading_one_dims_transposeneeded2(%arg0: vector
 
 
 // CHECK-LABEL: cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4
-//  CHECK-NEXT:   %[[R3:.+]] =  vector.shape_cast %{{.*}} : vector<1x8x1x16xf32> to vector<8x16xf32>
-//  CHECK-NEXT:   %[[R5:.+]] =  vector.shape_cast %{{.*}} : vector<1x2x8x1xf32> to vector<2x8xf32>
-//  CHECK-NEXT:   %[[R6:.+]] =  vector.shape_cast %{{.*}} : vector<1x1x2x16xf32> to vector<2x16xf32>
+//  CHECK-NEXT:   %[[R0:.+]] =  vector.extract %{{.*}}[0] : vector<8x1x16xf32> from vector<1x8x1x16xf32>
+//  CHECK-NEXT:   %[[R1:.+]] =  vector.extract %{{.*}}[0] : vector<2x8x1xf32> from vector<1x2x8x1xf32>
+//  CHECK-NEXT:   %[[R2:.+]] =  vector.transpose %[[R0]], [1, 0, 2] : vector<8x1x16xf32> to vector<1x8x16xf32>
+//  CHECK-NEXT:   %[[R3:.+]] =  vector.extract %[[R2]][0] : vector<8x16xf32> from vector<1x8x16xf32>
+//  CHECK-NEXT:   %[[R4:.+]] =  vector.transpose %[[R1]], [2, 0, 1] : vector<2x8x1xf32> to vector<1x2x8xf32>
+//  CHECK-NEXT:   %[[R5:.+]] =  vector.extract %[[R4]][0] : vector<2x8xf32> from vector<1x2x8xf32>
+//  CHECK-NEXT:   %[[R6:.+]] =  vector.extract %{{.*}}[0, 0] : vector<2x16xf32> from vector<1x1x2x16xf32>
 //  CHECK-NEXT:   %[[R7:.+]] =  vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
 //  CHECK-SAME:   iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
 //  CHECK-SAME:   %[[R3]], %[[R5]], %[[R6]] : vector<8x16xf32>, vector<2x8xf32> into vector<2x16xf32>
-//  CHECK-NEXT:   %[[R8:.+]] =  vector.shape_cast %[[R7]] : vector<2x16xf32> to vector<1x1x2x16xf32>
-//  CHECK-NEXT:  return %[[R8]] : vector<1x1x2x16xf32>
+//  CHECK-NEXT:   %[[R8:.+]] =  vector.broadcast %[[R7]] : vector<2x16xf32> to vector<1x2x16xf32>
+//  CHECK-NEXT:   %[[R9:.+]] =  vector.broadcast %[[R8]] : vector<1x2x16xf32> to vector<1x1x2x16xf32>
+//  CHECK-NEXT:  return %[[R9]] : vector<1x1x2x16xf32>
 
 #contraction_accesses2 = [
   affine_map<(m, l, i, j, k) -> (m, k, l, j)>,
@@ -203,14 +211,17 @@ func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4(%arg0:
 // CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 
 // CHECK-LABEL: cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctranspose
-//  CHECK-NEXT:   %[[R2:.+]] =  vector.shape_cast %{{.*}} : vector<1x8x1x16xf32> to vector<8x16xf32>
-//  CHECK-NEXT:   %[[R3:.+]] =  vector.shape_cast %{{.*}} : vector<1x2x8x1xf32> to vector<2x8xf32>
-//  CHECK-NEXT:   %[[R4:.+]] =  vector.shape_cast %{{.*}} : vector<1x1x2x16xf32> to vector<2x16xf32>
+//  CHECK-NEXT:   %[[R0:.+]] =  vector.transpose %{{.*}}, [2, 0, 1, 3] : vector<1x8x1x16xf32> to vector<1x1x8x16xf32>
+//  CHECK-NEXT:   %[[R1:.+]] =  vector.transpose %{{.*}}, [3, 0, 1, 2] : vector<1x2x8x1xf32> to vector<1x1x2x8xf32>
+//  CHECK-NEXT:   %[[R2:.+]] =  vector.extract %[[R0]][0, 0] : vector<8x16xf32> from vector<1x1x8x16xf32>
+//  CHECK-NEXT:   %[[R3:.+]] =  vector.extract %[[R1]][0, 0] : vector<2x8xf32> from vector<1x1x2x8xf32>
+//  CHECK-NEXT:   %[[R4:.+]] =  vector.extract %{{.*}}[0, 0] : vector<2x16xf32> from vector<1x1x2x16xf32>
 //  CHECK-NEXT:   %[[R5:.+]] =  vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
 //  CHECK-SAME:   iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
 //  CHECK-SAME:   %[[R2]], %[[R3]], %[[R4]] : vector<8x16xf32>, vector<2x8xf32> into vector<2x16xf32>
-//  CHECK-NEXT:   %[[R6:.+]] =  vector.shape_cast %[[R5]] : vector<2x16xf32> to vector<1x1x2x16xf32>
-//  CHECK-NEXT:  return %[[R6]] : vector<1x1x2x16xf32>
+//  CHECK-NEXT:   %[[R6:.+]] =  vector.broadcast %[[R5]] : vector<2x16xf32> to vector<1x2x16xf32>
+//  CHECK-NEXT:   %[[R7:.+]] =  vector.broadcast %[[R6]] : vector<1x2x16xf32> to vector<1x1x2x16xf32>
+//  CHECK-NEXT:  return %[[R7]] : vector<1x1x2x16xf32>
 
 #contraction_accesses3 = [
   affine_map<(m, l, i, j, k) -> (m, k, l, j)>,
@@ -245,7 +256,7 @@ func.func @cast_away_contraction_does_not_transpose_leading_unit_dims(%lhs: vect
 // CHECK-DAG: #[[$map_dp1:.*]] = affine_map<(d0) -> ()>
 
 // CHECK-LABEL: cast_away_contraction_leading_one_dims_to_dot_product
-//  CHECK-NEXT:   %[[R0:.+]] = vector.shape_cast %{{.*}} : vector<1x64xf32> to vector<64xf32>
+//  CHECK-NEXT:   %[[R0:.+]] = vector.extract %{{.*}}[0] : vector<64xf32> from vector<1x64xf32>
 //  CHECK-NEXT:   %[[R1:.+]] = vector.extract %{{.*}}[0] : f32 from vector<1xf32>
 //  CHECK-NEXT:   %[[R2:.+]] = vector.contract {indexing_maps = [#[[$map_dp0]], #[[$map_dp0]], #[[$map_dp1]]],
 //  CHECK-SAME:   iterator_types = ["reduction"], kind = #vector.kind<add>}
@@ -259,96 +270,44 @@ func.func @cast_away_contraction_leading_one_dims_to_dot_product(%arg0: vector<6
 }
 
 // -----
-
-// CHECK-DAG: #[[$DOT_MAP:.*]] = affine_map<(d0) -> (d0)>
-// CHECK-DAG: #[[$SCALAR_MAP:.*]] = affine_map<(d0) -> ()>
-
-// CHECK-LABEL: cast_away_masked_contraction_with_rank1_acc
-//  CHECK-NEXT:   %[[RHS:.+]] = vector.shape_cast %{{.*}} : vector<1x64xf32> to vector<64xf32>
-//  CHECK-NEXT:   %[[ACC:.+]] = vector.extract %{{.*}}[0] : f32 from vector<1xf32>
-//  CHECK-NEXT:   %[[MASK:.+]] = vector.shape_cast %{{.*}} : vector<64x1xi1> to vector<64xi1>
-//  CHECK-NEXT:   %[[DOT:.+]] = vector.mask %[[MASK]] {
-//  CHECK-SAME:     vector.contract {indexing_maps = [#[[$DOT_MAP]], #[[$DOT_MAP]], #[[$SCALAR_MAP]]], iterator_types = ["reduction"], kind = #vector.kind<add>}
-//  CHECK-SAME:     %{{.*}}, %[[RHS]], %[[ACC]] : vector<64xf32>, vector<64xf32> into f32
-//  CHECK-SAME:   } : vector<64xi1> -> f32
-//  CHECK-NEXT:   %[[RES:.+]] = vector.broadcast %[[DOT]] : f32 to vector<1xf32>
-//  CHECK-NEXT:   return %[[RES]] : vector<1xf32>
-
-func.func @cast_away_masked_contraction_with_rank1_acc(%arg0: vector<64xf32>, %arg1: vector<1x64xf32>, %arg2: vector<1xf32>, %mask: vector<64x1xi1>) -> vector<1xf32> {
-  %0 = vector.mask %mask {
-    vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d1)>], iterator_types = ["reduction", "parallel"], kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<64xf32>, vector<1x64xf32> into vector<1xf32>
-  } : vector<64x1xi1> -> vector<1xf32>
-  return %0 : vector<1xf32>
-}
-
-// -----
-
-// CHECK-LABEL: negative_cast_away_contraction_with_scalable_rank1_acc
-//  CHECK-NOT: vector.shape_cast
-//  CHECK-NOT: vector.extract
-//  CHECK-NOT: vector.broadcast
-//  CHECK-NEXT: vector.contract
-//  CHECK-NEXT: return
-
-func.func @negative_cast_away_contraction_with_scalable_rank1_acc(%arg0: vector<64xf32>, %arg1: vector<[1]x64xf32>, %arg2: vector<[1]xf32>) -> vector<[1]xf32> {
-  %0 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d1)>], iterator_types = ["reduction", "parallel"], kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<64xf32>, vector<[1]x64xf32> into vector<[1]xf32>
-  return %0 : vector<[1]xf32>
-}
-
-// -----
-
-// CHECK-LABEL: negative_cast_away_contraction_with_scalable_operand_dim
-//  CHECK-NOT: vector.shape_cast
-//  CHECK-NOT: vector.extract
-//  CHECK-NOT: vector.broadcast
-//  CHECK-NEXT: vector.contract
-//  CHECK-NEXT: return
-
-func.func @negative_cast_away_contraction_with_scalable_operand_dim(%arg0: vector<64xf32>, %arg1: vector<[1]x64xf32>, %arg2: vector<1xf32>) -> vector<1xf32> {
-  %0 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d1)>], iterator_types = ["reduction", "parallel"], kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<64xf32>, vector<[1]x64xf32> into vector<1xf32>
-  return %0 : vector<1xf32>
-}
-
-// -----
-
 // CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims
 func.func @cast_away_extract_strided_slice_leading_one_dims(%arg0: vector<1x8x8xf16>) -> vector<1x1x8xf16> {
-  // CHECK:     %[[SRC:.+]] = vector.shape_cast %{{.*}} : vector<1x8x8xf16> to vector<8x8xf16>
+  // CHECK:     %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<8x8xf16> from vector<1x8x8xf16>
   // CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SRC]] {offsets = [4], sizes = [1], strides = [1]} : vector<8x8xf16> to vector<1x8xf16>
   %0 = vector.extract_strided_slice %arg0 {offsets = [0, 4], sizes = [1, 1], strides = [1, 1]} : vector<1x8x8xf16> to vector<1x1x8xf16>
-  // CHECK:     %[[RET:.+]] = vector.shape_cast %[[EXTRACT]] : vector<1x8xf16> to vector<1x1x8xf16>
+  // CHECK:     %[[RET:.+]] = vector.broadcast %[[EXTRACT]] : vector<1x8xf16> to vector<1x1x8xf16>
   // CHECK: return %[[RET]]
   return %0: vector<1x1x8xf16>
 }
 
 // CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims_scalable
 func.func @cast_away_extract_strided_slice_leading_one_dims_scalable(%arg0: vector<1x8x[8]xf16>) -> vector<1x1x[8]xf16> {
-  // CHECK:     %[[SRC:.+]] = vector.shape_cast %{{.*}} : vector<1x8x[8]xf16> to vector<8x[8]xf16>
+  // CHECK:     %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<8x[8]xf16> from vector<1x8x[8]xf16>
   // CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SRC]] {offsets = [4], sizes = [1], strides = [1]} : vector<8x[8]xf16> to vector<1x[8]xf16>
   %0 = vector.extract_strided_slice %arg0 {offsets = [0, 4], sizes = [1, 1], strides = [1, 1]} : vector<1x8x[8]xf16> to vector<1x1x[8]xf16>
-  // CHECK:     %[[RET:.+]] = vector.shape_cast %[[EXTRACT]] : vector<1x[8]xf16> to vector<1x1x[8]xf16>
+  // CHECK:     %[[RET:.+]] = vector.broadcast %[[EXTRACT]] : vector<1x[8]xf16> to vector<1x1x[8]xf16>
   // CHECK: return %[[RET]]
   return %0: vector<1x1x[8]xf16>
 }
 
 // CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims
 func.func @cast_away_insert_strided_slice_leading_one_dims(%arg0: vector<1x8xf16>, %arg1: vector<1x8x8xf16>) -> vector<1x8x8xf16> {
-  // CHECK:    %[[SRC:.+]] = vector.shape_cast %{{.*}} : vector<1x8xf16> to vector<8xf16>
-  // CHECK:    %[[DST:.+]] = vector.shape_cast %{{.*}} : vector<1x8x8xf16> to vector<8x8xf16>
+  // CHECK:    %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<8xf16> from vector<1x8xf16>
+  // CHECK:    %[[DST:.+]] = vector.extract %{{.*}}[0] : vector<8x8xf16> from vector<1x8x8xf16>
   // CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[SRC]], %[[DST]] {offsets = [0, 0], strides = [1]} : vector<8xf16> into vector<8x8xf16>
   %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x8xf16> into vector<1x8x8xf16>
-  // CHECK:    %[[RET:.+]] = vector.shape_cast %[[INSERT]] : vector<8x8xf16> to vector<1x8x8xf16>
+  // CHECK:    %[[RET:.+]] = vector.broadcast %[[INSERT]] : vector<8x8xf16> to vector<1x8x8xf16>
   // CHECK: return %[[RET]]
   return %0: vector<1x8x8xf16>
 }
 
 // CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims_scalable
 func.func @cast_away_insert_strided_slice_leading_one_dims_scalable(%arg0: vector<1x[8]xf16>, %arg1: vector<1x8x[8]xf16>) -> vector<1x8x[8]xf16> {
-  // CHECK:    %[[SRC:.+]] = vector.shape_cast %{{.*}} : vector<1x[8]xf16> to vector<[8]xf16>
-  // CHECK:    %[[DST:.+]] = vector.shape_cast %{{.*}} : vector<1x8x[8]xf16> to vector<8x[8]xf16>
+  // CHECK:    %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<[8]xf16> from vector<1x[8]xf16>
+  // CHECK:    %[[DST:.+]] = vector.extract %{{.*}}[0] : vector<8x[8]xf16> from vector<1x8x[8]xf16>
   // CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[SRC]], %[[DST]] {offsets = [0, 0], strides = [1]} : vector<[8]xf16> into vector<8x[8]xf16>
   %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x[8]xf16> into vector<1x8x[8]xf16>
-  // CHECK:    %[[RET:.+]] = vector.shape_cast %[[INSERT]] : vector<8x[8]xf16> to vector<1x8x[8]xf16>
+  // CHECK:    %[[RET:.+]] = vector.broadcast %[[INSERT]] : vector<8x[8]xf16> to vector<1x8x[8]xf16>
   // CHECK: return %[[RET]]
   return %0: vector<1x8x[8]xf16>
 }
@@ -356,7 +315,8 @@ func.func @cast_away_insert_strided_slice_leading_one_dims_scalable(%arg0: vecto
 // CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims_one_element
 //  CHECK-SAME: %[[ARG0:.+]]: vector<1x1xf16>, %{{.+}}: vector<1x1x1xf16>
 func.func @cast_away_insert_strided_slice_leading_one_dims_one_element(%arg0: vector<1x1xf16>, %arg1: vector<1x1x1xf16>) -> vector<1x1x1xf16> {
-  // CHECK: %[[B:.+]] = vector.shape_cast %{{.*}} : vector<1x1xf16> to vector<1x1x1xf16>
+  // CHECK: %[[EXT:.+]] = vector.extract %{{.*}}[0] : vector<1xf16> from vector<1x1xf16>
+  // CHECK: %[[B:.+]] = vector.broadcast %[[EXT]] : vector<1xf16> to vector<1x1x1xf16>
   %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x1xf16> into vector<1x1x1xf16>
   // CHECK: return %[[B]]
   return %0: vector<1x1x1xf16>
@@ -365,7 +325,8 @@ func.func @cast_away_insert_strided_slice_leading_one_dims_one_element(%arg0: ve
 // CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims_one_element_scalable
 //  CHECK-SAME: %[[ARG0:.+]]: vector<1x[1]xf16>, %{{.+}}: vector<1x1x[1]xf16>
 func.func @cast_away_insert_strided_slice_leading_one_dims_one_element_scalable(%arg0: vector<1x[1]xf16>, %arg1: vector<1x1x[1]xf16>) -> vector<1x1x[1]xf16> {
-  // CHECK: %[[B:.+]] = vector.shape_cast %{{.*}} : vector<1x[1]xf16> to vector<1x1x[1]xf16>
+  // CHECK: %[[EXT:.+]] = vector.extract %{{.*}}[0] : vector<[1]xf16> from vector<1x[1]xf16>
+  // CHECK: %[[B:.+]] = vector.broadcast %[[EXT]] : vector<[1]xf16> to vector<1x1x[1]xf16>
   %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x[1]xf16> into vector<1x1x[1]xf16>
   // CHECK: return %[[B]]
   return %0: vector<1x1x[1]xf16>
@@ -378,7 +339,7 @@ func.func @cast_away_transfer_read_leading_one_dims(%arg0: memref<1x4x8x16xf16>)
   // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f16
   %f0 = arith.constant 0. : f16
   // CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]] {in_bounds = [true]} : memref<1x4x8x16xf16>, vector<4xf16>
-  // CHECK: %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<4xf16> to vector<1x4xf16>
+  // CHECK: %[[CAST:.+]] = vector.broadcast %[[READ]] : vector<4xf16> to vector<1x4xf16>
   %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0 {in_bounds = [true, true]} : memref<1x4x8x16xf16>, vector<1x4xf16>
   // CHECK: return %[[CAST]]
   return %0: vector<1x4xf16>
@@ -390,9 +351,9 @@ func.func @cast_away_masked_transfer_read_leading_one_dims(%arg0: memref<1x4x8x1
   %c0 = arith.constant 0 : index
   // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f16
   %f0 = arith.constant 0. : f16
-  // CHECK: %[[MASK_CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4xi1> to vector<4xi1>
+  // CHECK: %[[MASK_CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1>
   // CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]], %[[MASK_CAST]] {in_bounds = [true]} : memref<1x4x8x16xf16>, vector<4xf16>
-  // CHECK: %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<4xf16> to vector<1x4xf16>
+  // CHECK: %[[CAST:.+]] = vector.broadcast %[[READ]] : vector<4xf16> to vector<1x4xf16>
   %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0, %arg1 {in_bounds = [true, true]} : memref<1x4x8x16xf16>, vector<1x4xf16>
   // CHECK: return %[[CAST]]
   return %0: vector<1x4xf16>
@@ -402,7 +363,7 @@ func.func @cast_away_masked_transfer_read_leading_one_dims(%arg0: memref<1x4x8x1
 func.func @cast_away_transfer_read_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>) -> vector<1x1xf16> {
   %c0 = arith.constant 0 : index
   %f0 = arith.constant 0. : f16
-  // CHECK: vector.shape_cast %{{.+}} : vector<f16> to vector<1x1xf16>
+  // CHECK: vector.broadcast %{{.+}} : vector<1xf16> to vector<1x1xf16>
   %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0 {in_bounds = [true, true]} : memref<1x1x1x1xf16>, vector<1x1xf16>
   return %0: vector<1x1xf16>
 }
@@ -419,7 +380,7 @@ func.func @cast_away_nontrivial_map_masked_transfer_read(%arg0: memref<1x4x8xf16
   // CHECK: %[[MASK_CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4x1xi1> to vector<4xi1>
   // CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %[[F0]], %[[MASK_CAST]] {in_bounds = [true]
   // CHECK-SAME: permutation_map = #[[$MAP]]} : memref<1x4x8xf16>, vector<4xf16>
-  // CHECK: %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<4xf16> to vector<1x1x4xf16>
+  // CHECK: %[[CAST:.+]] = vector.broadcast %[[READ]] : vector<4xf16> to vector<1x1x4xf16>
   %0 = vector.transfer_read %arg0[%c0, %c0, %c0], %f0, %arg1 {in_bounds = [true, true, true],
                             permutation_map = affine_map<(d0, d1, d2) -> (d0, d2, d1)>} : memref<1x4x8xf16>, vector<1x1x4xf16>
   // CHECK: return %[[CAST]]
@@ -430,7 +391,7 @@ func.func @cast_away_nontrivial_map_masked_transfer_read(%arg0: memref<1x4x8xf16
 
 // CHECK-LABEL: func @not_insert_cast_fo4_transfer_read_under_mask
 // CHECK:      %[[MASK:.+]] = vector.constant_mask
-// CHECK:      %[[CASTED_MASK:.+]] = vector.shape_cast %[[MASK]]
+// CHECK:      %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]]
 // CHECK:      %[[RET:.+]] = vector.mask %[[CASTED_MASK]] {
 // CHECK-SAME:   vector.transfer_read {{.*}} : memref<1x1x4xf16>, vector<1x4xf16> }
 // CHECK:      return %[[RET]] : vector<1x4xf16>
@@ -450,7 +411,7 @@ func.func @not_insert_cast_fo4_transfer_read_under_mask(%arg0: memref<1x1x4xf16>
 func.func @cast_away_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>) {
   // CHECK: %[[C0:.+]] = arith.constant 0 : index
   %c0 = arith.constant 0 : index
-  // CHECK: %[[CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4xf16> to vector<4xf16>
+  // CHECK: %[[CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xf16> from vector<1x4xf16>
   // CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]} : vector<4xf16>, memref<1x4x8x16xf16>
 
   vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] {in_bounds = [true, true]} : vector<1x4xf16>, memref<1x4x8x16xf16>
@@ -461,8 +422,8 @@ func.func @cast_away_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>
 func.func @cast_away_masked_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>, %arg2: vector<1x4xi1>) {
   // CHECK: %[[C0:.+]] = arith.constant 0 : index
   %c0 = arith.constant 0 : index
-  // CHECK: %[[CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4xf16> to vector<4xf16>
-  // CHECK: %[[MASK_CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4xi1> to vector<4xi1>
+  // CHECK: %[[CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xf16> from vector<1x4xf16>
+  // CHECK: %[[MASK_CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1>
   // CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[MASK_CAST]] {in_bounds = [true]} : vector<4xf16>, memref<1x4x8x16xf16>
 
   vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0], %arg2 {in_bounds = [true, true]} : vector<1x4xf16>, memref<1x4x8x16xf16>
@@ -472,7 +433,7 @@ func.func @cast_away_masked_transfer_write_leading_one_dims(%arg0: memref<1x4x8x
 // CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims_one_element
 func.func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>, %arg1: vector<1x1xf16>) {
   %c0 = arith.constant 0 : index
-  // CHECK: vector.shape_cast %{{.+}} : vector<1x1xf16> to vector<f16>
+  // CHECK: vector.extract %{{.+}}[0] : vector<1xf16> from vector<1x1xf16>
   vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] {in_bounds = [true, true]} : vector<1x1xf16>, memref<1x1x1x1xf16>
   return
 }
@@ -481,7 +442,7 @@ func.func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1
 
 // CHECK-LABEL: func @not_insert_cast_for_transfer_write_under_mask
 // CHECK:      %[[MASK:.+]] = vector.constant_mask
-// CHECK:      %[[CASTED_MASK:.+]] = vector.shape_cast %[[MASK]]
+// CHECK:      %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]]
 // CHECK:      vector.mask %[[CASTED_MASK]] {
 // CHECK-SAME:   vector.transfer_write {{.*}} : vector<1x4xf16>, memref<1x1x4xf16> }
 // CHECK:      return
@@ -501,7 +462,7 @@ func.func @not_insert_cast_for_transfer_write_under_mask(%arg0: memref<1x1x4xf16
 func.func @cast_away_nontrivial_map_masked_transfer_write(%arg0: memref<1x4x8xf16>, %arg1: vector<1x1x4xf16>, %arg2: vector<1x4x1xi1>) {
   // CHECK: %[[C0:.+]] = arith.constant 0 : index
   %c0 = arith.constant 0 : index
-  // CHECK: %[[CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x1x4xf16> to vector<4xf16>
+  // CHECK: %[[CAST:.+]] = vector.extract %{{.*}}[0, 0] : vector<4xf16> from vector<1x1x4xf16>
   // CHECK: %[[MASK_CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4x1xi1> to vector<4xi1>
   // CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %[[MASK_CAST]] {in_bounds = [true]
   // CHECK-SAME: permutation_map = #[[$MAP]]} : vector<4xf16>, memref<1x4x8xf16>
@@ -518,25 +479,25 @@ func.func @cast_away_elementwise_leading_one_dims(
   %arg0: vector<1x1x8xf32>, %arg1: f32, %arg2: vector<1x4xf32>,
   %arg3: vector<1x4xf32>, %arg4: i1) ->
   (vector<1x1x8xf32>, vector<1x4xi1>, vector<1x4xf32>, vector<1x4xf32>) {
-  // CHECK:  vector.shape_cast %{{.*}} : vector<1x1x8xf32> to vector<8xf32>
-  // CHECK:  vector.shape_cast %{{.*}} : vector<1x1x8xf32> to vector<8xf32>
+  // CHECK:  vector.extract %{{.*}}[0, 0] : vector<8xf32> from vector<1x1x8xf32>
+  // CHECK:  vector.extract %{{.*}}[0, 0] : vector<8xf32> from vector<1x1x8xf32>
   // CHECK:  arith.addf %{{.*}}, %{{.*}} : vector<8xf32>
-  // CHECK:  vector.shape_cast %{{.*}} : vector<8xf32> to vector<1x1x8xf32>
+  // CHECK:  vector.broadcast %{{.*}} : vector<8xf32> to vector<1x1x8xf32>
   %0 = arith.addf %arg0, %arg0 : vector<1x1x8xf32>
-  // CHECK:  vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
-  // CHECK:  vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
+  // CHECK:  vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
+  // CHECK:  vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
   // CHECK:  arith.cmpf ogt, %{{.*}}, %{{.*}} : vector<4xf32>
-  // CHECK:  vector.shape_cast %{{.*}} : vector<4xi1> to vector<1x4xi1>
+  // CHECK:  vector.broadcast %{{.*}} : vector<4xi1> to vector<1x4xi1>
   %1 = arith.cmpf ogt, %arg2, %arg3 : vector<1x4xf32>
-  // CHECK:  vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
-  // CHECK:  vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
+  // CHECK:  vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
+  // CHECK:  vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
   // CHECK:  select %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi1>, vector<4xf32>
-  // CHECK:  vector.shape_cast %{{.*}} : vector<4xf32> to vector<1x4xf32>
+  // CHECK:  vector.broadcast %{{.*}} : vector<4xf32> to vector<1x4xf32>
   %2 = arith.select %1, %arg3, %arg2 : vector<1x4xi1>, vector<1x4xf32>
-  // CHECK:  vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
-  // CHECK:  vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
+  // CHECK:  vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
+  // CHECK:  vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
   // CHECK:  select %arg4, %12, %{{.*}} : vector<4xf32>
-  // CHECK:  vector.shape_cast %{{.*}} : vector<4xf32> to vector<1x4xf32>
+  // CHECK:  vector.broadcast %{{.*}} : vector<4xf32> to vector<1x4xf32>
   %3 = arith.select %arg4, %arg3, %arg2 : vector<1x4xf32>
   return %0, %1, %2, %3: vector<1x1x8xf32>, vector<1x4xi1>, vector<1x4xf32>, vector<1x4xf32>
 }
@@ -545,10 +506,10 @@ func.func @cast_away_elementwise_leading_one_dims(
 
 // CHECK-LABEL: func @cast_away_insert_leading_one_dims_scalar
 //  CHECK-SAME: (%[[S:.+]]: f32, %[[V:.+]]: vector<1x1x4xf32>)
-//       CHECK:   %[[DST_CAST:.+]] = vector.shape_cast %[[V]] : vector<1x1x4xf32> to vector<4xf32>
-//       CHECK:   %[[INSERT:.+]] = vector.insert %[[S]], %[[DST_CAST]] [0] : f32 into vector<4xf32>
-//       CHECK:   %[[RESULT_CAST:.+]] = vector.shape_cast %[[INSERT]] : vector<4xf32> to vector<1x1x4xf32>
-//       CHECK:   return %[[RESULT_CAST]]
+//       CHECK:   %[[EXTRACT:.+]] = vector.extract %[[V]][0, 0] : vector<4xf32> from vector<1x1x4xf32>
+//       CHECK:   %[[INSERT:.+]] = vector.insert %[[S]], %[[EXTRACT]] [0] : f32 into vector<4xf32>
+//       CHECK:   %[[BCAST:.+]] = vector.broadcast %[[INSERT]] : vector<4xf32> to vector<1x1x4xf32>
+//       CHECK:   return %[[BCAST]]
 func.func @cast_away_insert_leading_one_dims_scalar(%s: f32, %v: vector<1x1x4xf32>) -> vector<1x1x4xf32> {
   %0 = vector.insert %s, %v [0, 0, 0] : f32 into vector<1x1x4xf32>
   return %0: vector<1x1x4xf32>
@@ -556,27 +517,14 @@ func.func @cast_away_insert_leading_one_dims_scalar(%s: f32, %v: vector<1x1x4xf3
 
 // -----
 
-// CHECK-LABEL: func @cast_away_insert_leading_one_dims_scalar_0d_dest
-//  CHECK-SAME: (%[[S:.+]]: f32, %[[V:.+]]: vector<1x1xf32>)
-//       CHECK:   %[[DST_CAST:.+]] = vector.shape_cast %[[V]] : vector<1x1xf32> to vector<f32>
-//       CHECK:   %[[INSERT:.+]] = vector.insert %[[S]], %[[DST_CAST]] [] : f32 into vector<f32>
-//       CHECK:   %[[RESULT_CAST:.+]] = vector.shape_cast %[[INSERT]] : vector<f32> to vector<1x1xf32>
-//       CHECK:   return %[[RESULT_CAST]]
-func.func @cast_away_insert_leading_one_dims_scalar_0d_dest(%s: f32, %v: vector<1x1xf32>) -> vector<1x1xf32> {
-  %0 = vector.insert %s, %v [0, 0] : f32 into vector<1x1xf32>
-  return %0: vector<1x1xf32>
-}
-
-// -----
-
 // CHECK-LABEL:   func.func @cast_away_insert_leading_one_dims_scalar_scalable(
 // CHECK-SAME:    %[[S:.*]]: f32,
 // CHECK-SAME:    %[[V:.*]]: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
 func.func @cast_away_insert_leading_one_dims_scalar_scalable(%s: f32, %v: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
-// CHECK:           %[[DST_CAST:.*]] = vector.shape_cast %[[V]] : vector<1x1x[4]xf32> to vector<[4]xf32>
-// CHECK:           %[[INSERT:.*]] = vector.insert %[[S]], %[[DST_CAST]] [0] : f32 into vector<[4]xf32>
-// CHECK:           %[[RESULT_CAST:.*]] = vector.shape_cast %[[INSERT]] : vector<[4]xf32> to vector<1x1x[4]xf32>
-// CHECK:           return %[[RESULT_CAST]] : vector<1x1x[4]xf32>
+// CHECK:           %[[EXTRACT:.*]] = vector.extract %[[V]][0, 0] : vector<[4]xf32> from vector<1x1x[4]xf32>
+// CHECK:           %[[INSERT:.*]] = vector.insert %[[S]], %[[EXTRACT]] [0] : f32 into vector<[4]xf32>
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[INSERT]] : vector<[4]xf32> to vector<1x1x[4]xf32>
+// CHECK:           return %[[BCAST]] : vector<1x1x[4]xf32>
   %0 = vector.insert %s, %v [0, 0, 0] : f32 into vector<1x1x[4]xf32>
   return %0: vector<1x1x[4]xf32>
 }
@@ -587,10 +535,10 @@ func.func @cast_away_insert_leading_one_dims_scalar_scalable(%s: f32, %v: vector
 // CHECK-SAME:    %[[S:.*]]: f32,
 // CHECK-SAME:    %[[V:.*]]: vector<1x[1]x4xf32>) -> vector<1x[1]x4xf32> {
 func.func @cast_away_insert_leading_one_dims_scalar_skip_scalable_dim(%s: f32, %v: vector<1x[1]x4xf32>) -> vector<1x[1]x4xf32> {
-// CHECK:           %[[DST_CAST:.*]] = vector.shape_cast %[[V]] : vector<1x[1]x4xf32> to vector<[1]x4xf32>
-// CHECK:           %[[INSERT:.*]] = vector.insert %[[S]], %[[DST_CAST]] [0, 0] : f32 into vector<[1]x4xf32>
-// CHECK:           %[[RESULT_CAST:.*]] = vector.shape_cast %[[INSERT]] : vector<[1]x4xf32> to vector<1x[1]x4xf32>
-// CHECK:           return %[[RESULT_CAST]] : vector<1x[1]x4xf32>
+// CHECK:           %[[EXTRACT:.*]] = vector.extract %[[V]][0] : vector<[1]x4xf32> from vector<1x[1]x4xf32>
+// CHECK:           %[[INSERT:.*]] = vector.insert %[[S]], %[[EXTRACT]] [0, 0] : f32 into vector<[1]x4xf32>
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[INSERT]] : vector<[1]x4xf32> to vector<1x[1]x4xf32>
+// CHECK:           return %[[BCAST]] : vector<1x[1]x4xf32>
   %0 = vector.insert %s, %v [0, 0, 0] : f32 into vector<1x[1]x4xf32>
   return %0: vector<1x[1]x4xf32>
 }
@@ -599,8 +547,8 @@ func.func @cast_away_insert_leading_one_dims_scalar_skip_scalable_dim(%s: f32, %
 
 // CHECK-LABEL: func @cast_away_insert_leading_one_dims_rank1
 //  CHECK-SAME: (%[[S:.+]]: vector<4xf32>, %[[V:.+]]: vector<1x1x4xf32>)
-//       CHECK:   %[[RESULT_CAST:.+]] = vector.shape_cast %[[S]] : vector<4xf32> to vector<1x1x4xf32>
-//       CHECK:   return %[[RESULT_CAST]]
+//       CHECK:   %[[BCAST:.+]] = vector.broadcast %[[S]] : vector<4xf32> to vector<1x1x4xf32>
+//       CHECK:   return %[[BCAST]]
 func.func @cast_away_insert_leading_one_dims_rank1(%s: vector<4xf32>, %v: vector<1x1x4xf32>) -> vector<1x1x4xf32> {
   %0 = vector.insert %s, %v [0, 0] : vector<4xf32> into vector<1x1x4xf32>
   return %0: vector<1x1x4xf32>
@@ -611,8 +559,8 @@ func.func @cast_away_insert_leading_one_dims_rank1(%s: vector<4xf32>, %v: vector
 // CHECK-LABEL:   func.func @cast_away_insert_leading_one_dims_rank1_scalable(
 // CHECK-SAME:    %[[S:.*]]: vector<[4]xf32>,
 // CHECK-SAME:    %[[V:.*]]: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
-// CHECK:           %[[RESULT_CAST:.*]] = vector.shape_cast %[[S]] : vector<[4]xf32> to vector<1x1x[4]xf32>
-// CHECK:           return %[[RESULT_CAST]] : vector<1x1x[4]xf32>
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[S]] : vector<[4]xf32> to vector<1x1x[4]xf32>
+// CHECK:           return %[[BCAST]] : vector<1x1x[4]xf32>
 func.func @cast_away_insert_leading_one_dims_rank1_scalable(%s: vector<[4]xf32>, %v: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
   %0 = vector.insert %s, %v [0, 0] : vector<[4]xf32> into vector<1x1x[4]xf32>
   return %0: vector<1x1x[4]xf32>
@@ -622,8 +570,9 @@ func.func @cast_away_insert_leading_one_dims_rank1_scalable(%s: vector<[4]xf32>,
 
 // CHECK-LABEL: func @cast_away_insert_leading_one_dims_rank2
 //  CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<1x1x4xf32>)
-//       CHECK:   %[[SRC_CAST:.+]] = vector.shape_cast %[[S]] : vector<1x4xf32> to vector<1x1x4xf32>
-//       CHECK:   return %[[SRC_CAST]]
+//       CHECK:   %[[EXTRACT:.+]] = vector.extract %[[S]][0] : vector<4xf32> from vector<1x4xf32>
+//       CHECK:   %[[BCAST:.+]] = vector.broadcast %[[EXTRACT]] : vector<4xf32> to vector<1x1x4xf32>
+//       CHECK:   return %[[BCAST]]
 func.func @cast_away_insert_leading_one_dims_rank2(%s: vector<1x4xf32>, %v: vector<1x1x4xf32>) -> vector<1x1x4xf32> {
   %0 = vector.insert %s, %v [0] : vector<1x4xf32> into vector<1x1x4xf32>
   return %0: vector<1x1x4xf32>
@@ -634,8 +583,9 @@ func.func @cast_away_insert_leading_one_dims_rank2(%s: vector<1x4xf32>, %v: vect
 // CHECK-LABEL:   func.func @cast_away_insert_leading_one_dims_rank2_scalable(
 // CHECK-SAME:    %[[S:.*]]: vector<1x[4]xf32>,
 // CHECK-SAME:    %[[V:.*]]: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
-// CHECK:           %[[SRC_CAST:.*]] = vector.shape_cast %[[S]] : vector<1x[4]xf32> to vector<1x1x[4]xf32>
-// CHECK:           return %[[SRC_CAST]] : vector<1x1x[4]xf32>
+// CHECK:           %[[EXTRACT:.*]] = vector.extract %[[S]][0] : vector<[4]xf32> from vector<1x[4]xf32>
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[EXTRACT]] : vector<[4]xf32> to vector<1x1x[4]xf32>
+// CHECK:           return %[[BCAST]] : vector<1x1x[4]xf32>
 func.func @cast_away_insert_leading_one_dims_rank2_scalable(%s: vector<1x[4]xf32>, %v: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
   %0 = vector.insert %s, %v [0] : vector<1x[4]xf32> into vector<1x1x[4]xf32>
   return %0: vector<1x1x[4]xf32>
@@ -645,11 +595,11 @@ func.func @cast_away_insert_leading_one_dims_rank2_scalable(%s: vector<1x[4]xf32
 
 // CHECK-LABEL: func @cast_away_insert_leading_one_dims_rank2_one_dest
 //  CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<1x2x1x4xf32>)
-//       CHECK:   %[[SRC_CAST:.+]] = vector.shape_cast %[[S]] : vector<1x4xf32> to vector<4xf32>
-//       CHECK:   %[[DST_CAST:.+]] = vector.shape_cast %[[V]] : vector<1x2x1x4xf32> to vector<2x1x4xf32>
-//       CHECK:   %[[INSERT:.+]] = vector.insert %[[SRC_CAST]], %[[DST_CAST]] [1, 0] : vector<4xf32> into vector<2x1x4xf32>
-//       CHECK:   %[[RESULT_CAST:.+]] = vector.shape_cast %[[INSERT]] : vector<2x1x4xf32> to vector<1x2x1x4xf32>
-//       CHECK:   return %[[RESULT_CAST]]
+//       CHECK:   %[[EXTRACTS:.+]] = vector.extract %[[S]][0] : vector<4xf32> from vector<1x4xf32>
+//       CHECK:   %[[EXTRACTV:.+]] = vector.extract %[[V]][0] : vector<2x1x4xf32> from vector<1x2x1x4xf32>
+//       CHECK:   %[[INSERT:.+]] = vector.insert %[[EXTRACTS]], %[[EXTRACTV]] [1, 0] : vector<4xf32> into vector<2x1x4xf32>
+//       CHECK:   %[[BCAST:.+]] = vector.broadcast %[[INSERT]] : vector<2x1x4xf32> to vector<1x2x1x4xf32>
+//       CHECK:   return %[[BCAST]]
 func.func @cast_away_insert_leading_one_dims_rank2_one_dest(%s: vector<1x4xf32>, %v: vector<1x2x1x4xf32>) -> vector<1x2x1x4xf32> {
   %0 = vector.insert %s, %v [0, 1] : vector<1x4xf32> into vector<1x2x1x4xf32>
   return %0: vector<1x2x1x4xf32>
@@ -660,11 +610,11 @@ func.func @cast_away_insert_leading_one_dims_rank2_one_dest(%s: vector<1x4xf32>,
 // CHECK-LABEL:   func.func @cast_away_insert_leading_one_dims_rank2_one_dest_scalable(
 // CHECK-SAME:      %[[S:.*]]: vector<1x[4]xf32>,
 // CHECK-SAME:      %[[V:.*]]: vector<1x2x1x[4]xf32>) -> vector<1x2x1x[4]xf32> {
-// CHECK:           %[[SRC_CAST:.*]] = vector.shape_cast %[[S]] : vector<1x[4]xf32> to vector<[4]xf32>
-// CHECK:           %[[DST_CAST:.*]] = vector.shape_cast %[[V]] : vector<1x2x1x[4]xf32> to vector<2x1x[4]xf32>
-// CHECK:           %[[INSERT:.*]] = vector.insert %[[SRC_CAST]], %[[DST_CAST]] [1, 0] : vector<[4]xf32> into vector<2x1x[4]xf32>
-// CHECK:           %[[RESULT_CAST:.*]] = vector.shape_cast %[[INSERT]] : vector<2x1x[4]xf32> to vector<1x2x1x[4]xf32>
-// CHECK:           return %[[RESULT_CAST]] : vector<1x2x1x[4]xf32>
+// CHECK:           %[[EXTRACTS:.*]] = vector.extract %[[S]][0] : vector<[4]xf32> from vector<1x[4]xf32>
+// CHECK:           %[[EXTRACTV:.*]] = vector.extract %[[V]][0] : vector<2x1x[4]xf32> from vector<1x2x1x[4]xf32>
+// CHECK:           %[[INSERT:.*]] = vector.insert %[[EXTRACTS]], %[[EXTRACTV]] [1, 0] : vector<[4]xf32> into vector<2x1x[4]xf32>
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[INSERT]] : vector<2x1x[4]xf32> to vector<1x2x1x[4]xf32>
+// CHECK:           return %[[BCAST]] : vector<1x2x1x[4]xf32>
 func.func @cast_away_insert_leading_one_dims_rank2_one_dest_scalable(%s: vector<1x[4]xf32>, %v: vector<1x2x1x[4]xf32>) -> vector<1x2x1x[4]xf32> {
   %0 = vector.insert %s, %v [0, 1] : vector<1x[4]xf32> into vector<1x2x1x[4]xf32>
   return %0: vector<1x2x1x[4]xf32>
@@ -674,8 +624,8 @@ func.func @cast_away_insert_leading_one_dims_rank2_one_dest_scalable(%s: vector<
 
 // CHECK-LABEL: func @cast_away_insert_leading_one_dims_non_one_dest
 //  CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<8x1x4xf32>)
-//       CHECK:   %[[SRC_CAST:.+]] = vector.shape_cast %[[S]] : vector<1x4xf32> to vector<4xf32>
-//       CHECK:   %[[INSERT:.+]] = vector.insert %[[SRC_CAST]], %[[V]] [5, 0] : vector<4xf32> into vector<8x1x4xf32>
+//       CHECK:   %[[EXTRACT:.+]] = vector.extract %[[S]][0] : vector<4xf32> from vector<1x4xf32>
+//       CHECK:   %[[INSERT:.+]] = vector.insert %[[EXTRACT]], %[[V]] [5, 0] : vector<4xf32> into vector<8x1x4xf32>
 //       CHECK:   return %[[INSERT]]
 func.func @cast_away_insert_leading_one_dims_non_one_dest(%s: vector<1x4xf32>, %v: vector<8x1x4xf32>) -> vector<8x1x4xf32> {
   %0 = vector.insert %s, %v [5] : vector<1x4xf32> into vector<8x1x4xf32>
@@ -687,8 +637,8 @@ func.func @cast_away_insert_leading_one_dims_non_one_dest(%s: vector<1x4xf32>, %
 // CHECK-LABEL:   func.func @cast_away_insert_leading_one_dims_non_one_dest_scalable(
 // CHECK-SAME:      %[[S:.*]]: vector<1x[4]xf32>,
 // CHECK-SAME:      %[[V:.*]]: vector<8x1x[4]xf32>) -> vector<8x1x[4]xf32> {
-// CHECK:           %[[SRC_CAST:.*]] = vector.shape_cast %[[S]] : vector<1x[4]xf32> to vector<[4]xf32>
-// CHECK:           %[[INSERT:.*]] = vector.insert %[[SRC_CAST]], %[[V]] [5, 0] : vector<[4]xf32> into vector<8x1x[4]xf32>
+// CHECK:           %[[EXTRACT:.*]] = vector.extract %[[S]][0] : vector<[4]xf32> from vector<1x[4]xf32>
+// CHECK:           %[[INSERT:.*]] = vector.insert %[[EXTRACT]], %[[V]] [5, 0] : vector<[4]xf32> into vector<8x1x[4]xf32>
 // CHECK:           return %[[INSERT]] : vector<8x1x[4]xf32>
 func.func @cast_away_insert_leading_one_dims_non_one_dest_scalable(%s: vector<1x[4]xf32>, %v: vector<8x1x[4]xf32>) -> vector<8x1x[4]xf32> {
   %0 = vector.insert %s, %v [5] : vector<1x[4]xf32> into vector<8x1x[4]xf32>
@@ -699,11 +649,11 @@ func.func @cast_away_insert_leading_one_dims_non_one_dest_scalable(%s: vector<1x
 
 // CHECK-LABEL: func @cast_away_insert_leading_one_dims_one_two_dest
 //  CHECK-SAME: (%[[S:.+]]: vector<1x8xi1>, %[[V:.+]]: vector<1x1x8x1x8xi1>)
-//       CHECK:   %[[SRC_CAST:.+]] = vector.shape_cast %[[S]] : vector<1x8xi1> to vector<8xi1>
-//       CHECK:   %[[DST_CAST:.+]] = vector.shape_cast %[[V]] : vector<1x1x8x1x8xi1> to vector<8x1x8xi1>
-//       CHECK:   %[[INSERT:.+]] = vector.insert %[[SRC_CAST]], %[[DST_CAST]] [7, 0] : vector<8xi1> into vector<8x1x8xi1>
-//       CHECK:   %[[RESULT_CAST:.+]] = vector.shape_cast %[[INSERT]] : vector<8x1x8xi1> to vector<1x1x8x1x8xi1>
-//       CHECK:   return %[[RESULT_CAST]]
+//       CHECK:   %[[EXTRACTS:.+]] = vector.extract %[[S]][0] : vector<8xi1> from vector<1x8xi1>
+//       CHECK:   %[[EXTRACTV:.+]] = vector.extract %[[V]][0, 0] : vector<8x1x8xi1> from vector<1x1x8x1x8xi1>
+//       CHECK:   %[[INSERT:.+]] = vector.insert %[[EXTRACTS]], %[[EXTRACTV]] [7, 0] : vector<8xi1> into vector<8x1x8xi1>
+//       CHECK:   %[[BCAST:.+]] = vector.broadcast %[[INSERT]] : vector<8x1x8xi1> to vector<1x1x8x1x8xi1>
+//       CHECK:   return %[[BCAST]]
 func.func @cast_away_insert_leading_one_dims_one_two_dest(%s: vector<1x8xi1>, %v: vector<1x1x8x1x8xi1>) -> vector<1x1x8x1x8xi1> {
   %0 = vector.insert %s, %v [0, 0, 7] : vector<1x8xi1> into vector<1x1x8x1x8xi1>
   return %0: vector<1x1x8x1x8xi1>
@@ -714,11 +664,11 @@ func.func @cast_away_insert_leading_one_dims_one_two_dest(%s: vector<1x8xi1>, %v
 // CHECK-LABEL:   func.func @cast_away_insert_leading_one_dims_one_two_dest_scalable(
 // CHECK-SAME:      %[[S:.*]]: vector<1x[8]xi1>,
 // CHECK-SAME:      %[[V:.*]]: vector<1x1x8x1x[8]xi1>) -> vector<1x1x8x1x[8]xi1> {
-// CHECK:           %[[SRC_CAST:.*]] = vector.shape_cast %[[S]] : vector<1x[8]xi1> to vector<[8]xi1>
-// CHECK:           %[[DST_CAST:.*]] = vector.shape_cast %[[V]] : vector<1x1x8x1x[8]xi1> to vector<8x1x[8]xi1>
-// CHECK:           %[[INSERT:.*]] = vector.insert %[[SRC_CAST]], %[[DST_CAST]] [7, 0] : vector<[8]xi1> into vector<8x1x[8]xi1>
-// CHECK:           %[[RESULT_CAST:.*]] = vector.shape_cast %[[INSERT]] : vector<8x1x[8]xi1> to vector<1x1x8x1x[8]xi1>
-// CHECK:           return %[[RESULT_CAST]] : vector<1x1x8x1x[8]xi1>
+// CHECK:           %[[EXTRACTS:.*]] = vector.extract %[[S]][0] : vector<[8]xi1> from vector<1x[8]xi1>
+// CHECK:           %[[EXTRACTV:.*]] = vector.extract %[[V]][0, 0] : vector<8x1x[8]xi1> from vector<1x1x8x1x[8]xi1>
+// CHECK:           %[[INSERT:.*]] = vector.insert %[[EXTRACTS]], %[[EXTRACTV]] [7, 0] : vector<[8]xi1> into vector<8x1x[8]xi1>
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[INSERT]] : vector<8x1x[8]xi1> to vector<1x1x8x1x[8]xi1>
+// CHECK:           return %[[BCAST]] : vector<1x1x8x1x[8]xi1>
 func.func @cast_away_insert_leading_one_dims_one_two_dest_scalable(%s: vector<1x[8]xi1>, %v: vector<1x1x8x1x[8]xi1>) -> vector<1x1x8x1x[8]xi1> {
   %0 = vector.insert %s, %v [0, 0, 7] : vector<1x[8]xi1> into vector<1x1x8x1x[8]xi1>
   return %0: vector<1x1x8x1x[8]xi1>
@@ -728,8 +678,8 @@ func.func @cast_away_insert_leading_one_dims_one_two_dest_scalable(%s: vector<1x
 
 // CHECK-LABEL:   func.func @cast_away_constant_mask() -> vector<1x1x8x2x1xi1> {
 // CHECK:           %[[MASK:.*]] = vector.constant_mask [6, 1, 1] : vector<8x2x1xi1>
-// CHECK:           %[[MASK_CAST:.*]] = vector.shape_cast %[[MASK]] : vector<8x2x1xi1> to vector<1x1x8x2x1xi1>
-// CHECK:           return %[[MASK_CAST]] : vector<1x1x8x2x1xi1>
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[MASK]] : vector<8x2x1xi1> to vector<1x1x8x2x1xi1>
+// CHECK:           return %[[BCAST]] : vector<1x1x8x2x1xi1>
 func.func @cast_away_constant_mask() -> vector<1x1x8x2x1xi1> {
   %0 = vector.constant_mask [1, 1, 6, 1, 1] : vector<1x1x8x2x1xi1>
   return %0: vector<1x1x8x2x1xi1>
@@ -737,16 +687,6 @@ func.func @cast_away_constant_mask() -> vector<1x1x8x2x1xi1> {
 
 // -----
 
-// CHECK-LABEL:   func.func @cast_away_constant_mask_all_unit_dims() -> vector<1x1xi1> {
-// CHECK:           %[[MASK:.*]] = arith.constant dense<true> : vector<1x1xi1>
-// CHECK:           return %[[MASK]] : vector<1x1xi1>
-func.func @cast_away_constant_mask_all_unit_dims() -> vector<1x1xi1> {
-  %0 = vector.constant_mask [1, 1] : vector<1x1xi1>
-  return %0: vector<1x1xi1>
-}
-
-// -----
-
 // CHECK-LABEL:   func.func @drop_unit_dims_scalar_cond_select(
 // CHECK:           arith.select {{.*}} : vector<16xi1>
 func.func @drop_unit_dims_scalar_cond_select(%cond: i1, %arg0: vector<1x16xi1>, %arg1: vector<1x16xi1>) -> vector<1x16xi1> {
@@ -758,7 +698,7 @@ func.func @drop_unit_dims_scalar_cond_select(%cond: i1, %arg0: vector<1x16xi1>,
 
 // CHECK-LABEL: func.func @cast_away_load_leading_one_dims
 // CHECK:         %[[L:.+]] = vector.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<8x16xf32>, vector<4xf32>
-// CHECK:         %[[B:.+]] = vector.shape_cast %[[L]] : vector<4xf32> to vector<1x4xf32>
+// CHECK:         %[[B:.+]] = vector.broadcast %[[L]] : vector<4xf32> to vector<1x4xf32>
 // CHECK:         return %[[B]] : vector<1x4xf32>
 func.func @cast_away_load_leading_one_dims(%base: memref<8x16xf32>, %i: index, %j: index) -> vector<1x4xf32> {
   %0 = vector.load %base[%i, %j] : memref<8x16xf32>, vector<1x4xf32>
@@ -767,33 +707,11 @@ func.func @cast_away_load_leading_one_dims(%base: memref<8x16xf32>, %i: index, %
 
 // -----
 
-// CHECK-LABEL: func.func @cast_away_load_all_unit_dims
-// CHECK:         %[[L:.+]] = vector.load %{{.*}}[%{{.*}}] : memref<1xf32>, vector<f32>
-// CHECK:         %[[B:.+]] = vector.shape_cast %[[L]] : vector<f32> to vector<1xf32>
-// CHECK:         return %[[B]] : vector<1xf32>
-func.func @cast_away_load_all_unit_dims(%base: memref<1xf32>, %i: index) -> vector<1xf32> {
-  %0 = vector.load %base[%i] : memref<1xf32>, vector<1xf32>
-  return %0 : vector<1xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func.func @cast_away_load_leading_one_dims_scalable
-// CHECK:         %[[L:.+]] = vector.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>, vector<[4]xf32>
-// CHECK:         %[[B:.+]] = vector.shape_cast %[[L]] : vector<[4]xf32> to vector<1x[4]xf32>
-// CHECK:         return %[[B]] : vector<1x[4]xf32>
-func.func @cast_away_load_leading_one_dims_scalable(%base: memref<?x?xf32>, %i: index, %j: index) -> vector<1x[4]xf32> {
-  %0 = vector.load %base[%i, %j] : memref<?x?xf32>, vector<1x[4]xf32>
-  return %0 : vector<1x[4]xf32>
-}
-
-// -----
-
 // CHECK-LABEL: func.func @cast_away_maskedload_leading_one_dims
-// CHECK:         %[[M:.+]] = vector.shape_cast %{{.*}} : vector<1x4xi1> to vector<4xi1>
-// CHECK:         %[[P:.+]] = vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
+// CHECK:         %[[M:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1>
+// CHECK:         %[[P:.+]] = vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
 // CHECK:         %[[L:.+]] = vector.maskedload %{{.*}}[%{{.*}}], %[[M]], %[[P]] : memref<16xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
-// CHECK:         %[[B:.+]] = vector.shape_cast %[[L]] : vector<4xf32> to vector<1x4xf32>
+// CHECK:         %[[B:.+]] = vector.broadcast %[[L]] : vector<4xf32> to vector<1x4xf32>
 // CHECK:         return %[[B]] : vector<1x4xf32>
 func.func @cast_away_maskedload_leading_one_dims(%base: memref<16xf32>, %i: index, %mask: vector<1x4xi1>, %pass: vector<1x4xf32>) -> vector<1x4xf32> {
   %0 = vector.maskedload %base[%i], %mask, %pass : memref<16xf32>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32>
@@ -803,10 +721,10 @@ func.func @cast_away_maskedload_leading_one_dims(%base: memref<16xf32>, %i: inde
 // -----
 
 // CHECK-LABEL: func.func @cast_away_expandload_leading_one_dims
-// CHECK:         %[[M:.+]] = vector.shape_cast %{{.*}} : vector<1x4xi1> to vector<4xi1>
-// CHECK:         %[[P:.+]] = vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
+// CHECK:         %[[M:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1>
+// CHECK:         %[[P:.+]] = vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
 // CHECK:         %[[L:.+]] = vector.expandload %{{.*}}[%{{.*}}], %[[M]], %[[P]] : memref<16xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
-// CHECK:         %[[B:.+]] = vector.shape_cast %[[L]] : vector<4xf32> to vector<1x4xf32>
+// CHECK:         %[[B:.+]] = vector.broadcast %[[L]] : vector<4xf32> to vector<1x4xf32>
 // CHECK:         return %[[B]] : vector<1x4xf32>
 func.func @cast_away_expandload_leading_one_dims(%base: memref<16xf32>, %i: index, %mask: vector<1x4xi1>, %pass: vector<1x4xf32>) -> vector<1x4xf32> {
   %0 = vector.expandload %base[%i], %mask, %pass : memref<16xf32>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32>
@@ -816,11 +734,11 @@ func.func @cast_away_expandload_leading_one_dims(%base: memref<16xf32>, %i: inde
 // -----
 
 // CHECK-LABEL: func.func @cast_away_gather_leading_one_dims
-// CHECK:         %[[I:.+]] = vector.shape_cast %{{.*}} : vector<1x4xi32> to vector<4xi32>
-// CHECK:         %[[M:.+]] = vector.shape_cast %{{.*}} : vector<1x4xi1> to vector<4xi1>
-// CHECK:         %[[P:.+]] = vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
+// CHECK:         %[[I:.+]] = vector.extract %{{.*}}[0] : vector<4xi32> from vector<1x4xi32>
+// CHECK:         %[[M:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1>
+// CHECK:         %[[P:.+]] = vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
 // CHECK:         %[[G:.+]] = vector.gather %{{.*}}[%{{.*}}] [%[[I]]], %[[M]], %[[P]] : memref<16xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
-// CHECK:         %[[B:.+]] = vector.shape_cast %[[G]] : vector<4xf32> to vector<1x4xf32>
+// CHECK:         %[[B:.+]] = vector.broadcast %[[G]] : vector<4xf32> to vector<1x4xf32>
 // CHECK:         return %[[B]] : vector<1x4xf32>
 func.func @cast_away_gather_leading_one_dims(%base: memref<16xf32>, %i: index, %idx: vector<1x4xi32>, %mask: vector<1x4xi1>, %pass: vector<1x4xf32>) -> vector<1x4xf32> {
   %0 = vector.gather %base[%i] [%idx], %mask, %pass : memref<16xf32>, vector<1x4xi32>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32>
@@ -830,7 +748,7 @@ func.func @cast_away_gather_leading_one_dims(%base: memref<16xf32>, %i: index, %
 // -----
 
 // CHECK-LABEL: func.func @cast_away_store_leading_one_dims
-// CHECK:         %[[V:.+]] = vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
+// CHECK:         %[[V:.+]] = vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
 // CHECK:         vector.store %[[V]], %{{.*}}[%{{.*}}, %{{.*}}] : memref<8x16xf32>, vector<4xf32>
 func.func @cast_away_store_leading_one_dims(%val: vector<1x4xf32>, %base: memref<8x16xf32>, %i: index, %j: index) {
   vector.store %val, %base[%i, %j] : memref<8x16xf32>, vector<1x4xf32>
@@ -839,29 +757,9 @@ func.func @cast_away_store_leading_one_dims(%val: vector<1x4xf32>, %base: memref
 
 // -----
 
-// CHECK-LABEL: func.func @cast_away_store_all_unit_dims
-// CHECK:         %[[V:.+]] = vector.shape_cast %{{.*}} : vector<1xf32> to vector<f32>
-// CHECK:         vector.store %[[V]], %{{.*}}[%{{.*}}] : memref<1xf32>, vector<f32>
-func.func @cast_away_store_all_unit_dims(%val: vector<1xf32>, %base: memref<1xf32>, %i: index) {
-  vector.store %val, %base[%i] : memref<1xf32>, vector<1xf32>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: func.func @cast_away_store_leading_one_dims_scalable
-// CHECK:         %[[V:.+]] = vector.shape_cast %{{.*}} : vector<1x[4]xf32> to vector<[4]xf32>
-// CHECK:         vector.store %[[V]], %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>, vector<[4]xf32>
-func.func @cast_away_store_leading_one_dims_scalable(%val: vector<1x[4]xf32>, %base: memref<?x?xf32>, %i: index, %j: index) {
-  vector.store %val, %base[%i, %j] : memref<?x?xf32>, vector<1x[4]xf32>
-  return
-}
-
-// -----
-
 // CHECK-LABEL: func.func @cast_away_maskedstore_leading_one_dims
-// CHECK:         %[[M:.+]] = vector.shape_cast %{{.*}} : vector<1x4xi1> to vector<4xi1>
-// CHECK:         %[[V:.+]] = vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
+// CHECK:         %[[M:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1>
+// CHECK:         %[[V:.+]] = vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
 // CHECK:         vector.maskedstore %{{.*}}[%{{.*}}], %[[M]], %[[V]] : memref<16xf32>, vector<4xi1>, vector<4xf32>
 func.func @cast_away_maskedstore_leading_one_dims(%base: memref<16xf32>, %i: index, %mask: vector<1x4xi1>, %val: vector<1x4xf32>) {
   vector.maskedstore %base[%i], %mask, %val : memref<16xf32>, vector<1x4xi1>, vector<1x4xf32>
@@ -871,8 +769,8 @@ func.func @cast_away_maskedstore_leading_one_dims(%base: memref<16xf32>, %i: ind
 // -----
 
 // CHECK-LABEL: func.func @cast_away_compressstore_leading_one_dims
-// CHECK:         %[[M:.+]] = vector.shape_cast %{{.*}} : vector<1x4xi1> to vector<4xi1>
-// CHECK:         %[[V:.+]] = vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
+// CHECK:         %[[M:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1>
+// CHECK:         %[[V:.+]] = vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
 // CHECK:         vector.compressstore %{{.*}}[%{{.*}}], %[[M]], %[[V]] : memref<16xf32>, vector<4xi1>, vector<4xf32>
 func.func @cast_away_compressstore_leading_one_dims(%base: memref<16xf32>, %i: index, %mask: vector<1x4xi1>, %val: vector<1x4xf32>) {
   vector.compressstore %base[%i], %mask, %val : memref<16xf32>, vector<1x4xi1>, vector<1x4xf32>
@@ -882,41 +780,11 @@ func.func @cast_away_compressstore_leading_one_dims(%base: memref<16xf32>, %i: i
 // -----
 
 // CHECK-LABEL: func.func @cast_away_scatter_leading_one_dims
-// CHECK:         %[[I:.+]] = vector.shape_cast %{{.*}} : vector<1x4xi32> to vector<4xi32>
-// CHECK:         %[[M:.+]] = vector.shape_cast %{{.*}} : vector<1x4xi1> to vector<4xi1>
-// CHECK:         %[[V:.+]] = vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
+// CHECK:         %[[I:.+]] = vector.extract %{{.*}}[0] : vector<4xi32> from vector<1x4xi32>
+// CHECK:         %[[M:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1>
+// CHECK:         %[[V:.+]] = vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
 // CHECK:         vector.scatter %{{.*}}[%{{.*}}] [%[[I]]], %[[M]], %[[V]] : memref<16xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32>
 func.func @cast_away_scatter_leading_one_dims(%base: memref<16xf32>, %i: index, %idx: vector<1x4xi32>, %mask: vector<1x4xi1>, %val: vector<1x4xf32>) {
   vector.scatter %base[%i] [%idx], %mask, %val : memref<16xf32>, vector<1x4xi32>, vector<1x4xi1>, vector<1x4xf32>
   return
 }
-
-// -----
-
-// CHECK-LABEL: func.func @negative_cast_memory_ops_to_0d
-//   CHECK-NOT:   vector.shape_cast
-//       CHECK:   vector.maskedload {{.*}} : memref<16xf32>, vector<1xi1>, vector<1xf32> into vector<1xf32>
-//   CHECK-NOT:   vector.shape_cast
-//       CHECK:   vector.expandload {{.*}} : memref<16xf32>, vector<1xi1>, vector<1xf32> into vector<1xf32>
-//   CHECK-NOT:   vector.shape_cast
-//       CHECK:   vector.gather {{.*}} : memref<16xf32>, vector<1xi32>, vector<1xi1>, vector<1xf32> into vector<1xf32>
-//   CHECK-NOT:   vector.shape_cast
-//       CHECK:   vector.maskedstore {{.*}} : memref<16xf32>, vector<1xi1>, vector<1xf32>
-//   CHECK-NOT:   vector.shape_cast
-//       CHECK:   vector.compressstore {{.*}} : memref<16xf32>, vector<1xi1>, vector<1xf32>
-//   CHECK-NOT:   vector.shape_cast
-//       CHECK:   vector.scatter {{.*}} : memref<16xf32>, vector<1xi32>, vector<1xi1>, vector<1xf32>
-//   CHECK-NOT:   vector.shape_cast
-//       CHECK:   return
-func.func @negative_cast_memory_ops_to_0d(
-    %base: memref<16xf32>, %i: index, %idx: vector<1xi32>,
-    %mask: vector<1xi1>, %pass: vector<1xf32>, %val: vector<1xf32>)
-    -> (vector<1xf32>, vector<1xf32>, vector<1xf32>) {
-  %0 = vector.maskedload %base[%i], %mask, %pass : memref<16xf32>, vector<1xi1>, vector<1xf32> into vector<1xf32>
-  %1 = vector.expandload %base[%i], %mask, %pass : memref<16xf32>, vector<1xi1>, vector<1xf32> into vector<1xf32>
-  %2 = vector.gather %base[%i] [%idx], %mask, %pass : memref<16xf32>, vector<1xi32>, vector<1xi1>, vector<1xf32> into vector<1xf32>
-  vector.maskedstore %base[%i], %mask, %val : memref<16xf32>, vector<1xi1>, vector<1xf32>
-  vector.compressstore %base[%i], %mask, %val : memref<16xf32>, vector<1xi1>, vector<1xf32>
-  vector.scatter %base[%i] [%idx], %mask, %val : memref<16xf32>, vector<1xi32>, vector<1xi1>, vector<1xf32>
-  return %0, %1, %2 : vector<1xf32>, vector<1xf32>, vector<1xf32>
-}

diff  --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
index d0d3a6c0bb976..de12a87253a67 100644
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -36,7 +36,7 @@ func.func @no_change(%arg0: vector<2x[4]x1xf32>, %arg1: vector<2x[4]x1xf32>) ->
 
 // CHECK-LABEL:   func.func @cast_away_leading_one_dim(
 // CHECK:           %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<4x1xf32>
-// CHECK:           vector.shape_cast %[[MUL]] : vector<4x1xf32> to vector<1x4x1xf32>
+// CHECK:           vector.broadcast %[[MUL]] : vector<4x1xf32> to vector<1x4x1xf32>
 func.func @cast_away_leading_one_dim(%arg0: vector<1x4x1xf32>, %arg1: vector<1x4x1xf32>) -> vector<1x4x1xf32> {
   %1 = arith.mulf %arg0, %arg1 : vector<1x4x1xf32>
   return %1: vector<1x4x1xf32>
@@ -44,7 +44,7 @@ func.func @cast_away_leading_one_dim(%arg0: vector<1x4x1xf32>, %arg1: vector<1x4
 
 // CHECK-LABEL:   func.func @cast_away_leading_one_dim_scalable(
 // CHECK:           %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<[4]x1xf32>
-// CHECK:           vector.shape_cast %[[MUL]] : vector<[4]x1xf32> to vector<1x[4]x1xf32>
+// CHECK:           vector.broadcast %[[MUL]] : vector<[4]x1xf32> to vector<1x[4]x1xf32>
 func.func @cast_away_leading_one_dim_scalable(%arg0: vector<1x[4]x1xf32>, %arg1: vector<1x[4]x1xf32>) -> vector<1x[4]x1xf32> {
   %1 = arith.mulf %arg0, %arg1 : vector<1x[4]x1xf32>
   return %1: vector<1x[4]x1xf32>
@@ -277,15 +277,13 @@ func.func @contraction4x4_ikj_xfer_read_tensor(%arg0 : tensor<4x2xf32>,
 func.func @bubble_down_bitcast_in_extract(%src: vector<4xf32>) -> (f16, f16) {
   %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
   // CHECK: %[[EXTRACT1:.+]] = vector.extract %[[SRC]][1] : f32 from vector<4xf32>
-  // CHECK:  %[[INSERT1:.+]] = vector.insert %[[EXTRACT1]], %{{.+}} [] : f32 into vector<f32>
-  // CHECK:  %[[SHAPE_CAST1:.+]] = vector.shape_cast %[[INSERT1]] : vector<f32> to vector<1xf32>
-  // CHECK:    %[[CAST1:.+]] = vector.bitcast %[[SHAPE_CAST1]] : vector<1xf32> to vector<2xf16>
+  // CHECK:  %[[INSERT1:.+]] = vector.insert %[[EXTRACT1]], %{{.+}} [0] : f32 into vector<1xf32>
+  // CHECK:    %[[CAST1:.+]] = vector.bitcast %[[INSERT1]] : vector<1xf32> to vector<2xf16>
   // CHECK: %[[EXTRACT2:.+]] = vector.extract %[[CAST1]][1] : f16 from vector<2xf16>
   %1 = vector.extract %0[3] : f16 from vector<8xf16>
   // CHECK: %[[EXTRACT3:.+]] = vector.extract %[[SRC]][2] : f32 from vector<4xf32>
-  // CHECK:  %[[INSERT3:.+]] = vector.insert %[[EXTRACT3]], %{{.+}} [] : f32 into vector<f32>
-  // CHECK:  %[[SHAPE_CAST2:.+]] = vector.shape_cast %[[INSERT3]] : vector<f32> to vector<1xf32>
-  // CHECK:    %[[CAST2:.+]] = vector.bitcast %[[SHAPE_CAST2]] : vector<1xf32> to vector<2xf16>
+  // CHECK:  %[[INSERT3:.+]] = vector.insert %[[EXTRACT3]], %{{.+}} [0] : f32 into vector<1xf32>
+  // CHECK:    %[[CAST2:.+]] = vector.bitcast %[[INSERT3]] : vector<1xf32> to vector<2xf16>
   // CHECK: %[[EXTRACT4:.+]] = vector.extract %[[CAST2]][0] : f16 from vector<2xf16>
   %2 = vector.extract %0[4] : f16 from vector<8xf16>
   // CHECK: return %[[EXTRACT2]], %[[EXTRACT4]]


        


More information about the Mlir-commits mailing list