[Mlir-commits] [mlir] [vector][mlir] Canonicalize to shape_cast where possible (PR #140583)

James Newling llvmlistbot at llvm.org
Thu Jun 26 10:59:50 PDT 2025


https://github.com/newling updated https://github.com/llvm/llvm-project/pull/140583

>From cce1483dc8fbc696904581baba3b8d1318847ab9 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Thu, 5 Jun 2025 11:07:43 -0700
Subject: [PATCH 1/6] use shape_cast as canonical type for extract broadcast
 and transpose

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 236 ++++++++++--------
 .../Transforms/LowerVectorTranspose.cpp       |  60 +----
 mlir/test/Dialect/Vector/canonicalize.mlir    |  52 ++--
 .../canonicalize/vector-from-elements.mlir    |   4 +-
 .../canonicalize/vector-shape-cast.mlir       | 164 ++++++++++++
 .../Vector/canonicalize/vector-transpose.mlir |   2 -
 .../drop-unit-dims-with-shape-cast.mlir       |  12 -
 ...vector-shape-cast-lowering-transforms.mlir |  60 +++++
 .../vector-transfer-to-vector-load-store.mlir |  12 +-
 .../Vector/vector-transpose-lowering.mlir     |  85 -------
 .../Vector/vector-warp-distribute.mlir        |   8 +-
 11 files changed, 379 insertions(+), 316 deletions(-)
 create mode 100644 mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 862ed7bae1fbb..244860530dca3 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2351,11 +2351,45 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
   return success();
 }
 
+/// For example,
+/// ```
+/// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
+/// ```
+/// becomes
+/// ```
+/// %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
+/// ```
+struct ExtractToShapeCast final : public OpRewritePattern<vector::ExtractOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType sourceType = extractOp.getSourceVectorType();
+    VectorType outType = dyn_cast<VectorType>(extractOp.getType());
+    if (!outType)
+      return failure();
+
+    // Only an all-zero constant `position` is equivalent to a shape_cast. A
+    // dynamic position may be negative (poison), so it cannot be converted.
+    if (llvm::any_of(extractOp.getMixedPosition(),
+                     [](OpFoldResult v) { return !isConstantIntValue(v, 0); }))
+      return failure();
+
+    if (sourceType.getNumElements() != outType.getNumElements())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, outType,
+                                                     extractOp.getVector());
+    return success();
+  }
+};
+
 } // namespace
 
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
+  results
+      .add<ExtractOpFromBroadcast, ExtractOpFromCreateMask, ExtractToShapeCast>(
+          context);
   results.add(foldExtractFromShapeCastToShapeCast);
   results.add(foldExtractFromFromElements);
 }
@@ -2867,13 +2901,40 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
     return success();
   }
 };
+
+/// For example,
+/// ```
+/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+/// ```
+/// becomes
+/// ```
+/// %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+/// ```
+struct BroadcastToShapeCast final
+    : public OpRewritePattern<vector::BroadcastOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::BroadcastOp broadcast,
+                                PatternRewriter &rewriter) const override {
+    auto sourceType = dyn_cast<VectorType>(broadcast.getSourceType());
+    if (!sourceType) {
+      return rewriter.notifyMatchFailure(
+          broadcast, "source is a scalar, shape_cast doesn't support scalar");
+    }
+
+    VectorType outType = broadcast.getType();
+    if (sourceType.getNumElements() != outType.getNumElements())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(broadcast, outType,
+                                                     broadcast.getSource());
+    return success();
+  }
+};
 } // namespace
 
 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  // BroadcastToShapeCast is not a default canonicalization, it is opt-in by
-  // calling `populateCastAwayVectorLeadingOneDimPatterns`
-  results.add<BroadcastFolder>(context);
+  results.add<BroadcastFolder, BroadcastToShapeCast>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -5816,30 +5877,6 @@ LogicalResult ShapeCastOp::verify() {
   return success();
 }
 
-/// Return true if `transpose` does not permute a pair of non-unit dims.
-/// By `order preserving` we mean that the flattened versions of the input and
-/// output vectors are (numerically) identical. In other words `transpose` is
-/// effectively a shape cast.
-static bool isOrderPreserving(TransposeOp transpose) {
-  ArrayRef<int64_t> permutation = transpose.getPermutation();
-  VectorType sourceType = transpose.getSourceVectorType();
-  ArrayRef<int64_t> inShape = sourceType.getShape();
-  ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
-  auto isNonScalableUnitDim = [&](int64_t dim) {
-    return inShape[dim] == 1 && !inDimIsScalable[dim];
-  };
-  int64_t current = 0;
-  for (auto p : permutation) {
-    if (!isNonScalableUnitDim(p)) {
-      if (p < current) {
-        return false;
-      }
-      current = p;
-    }
-  }
-  return true;
-}
-
 OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
 
   VectorType resultType = getType();
@@ -5854,22 +5891,6 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
     return getResult();
   }
 
-  // shape_cast(transpose(x)) -> shape_cast(x)
-  if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
-    if (isOrderPreserving(transpose)) {
-      setOperand(transpose.getVector());
-      return getResult();
-    }
-    return {};
-  }
-
-  // Y = shape_cast(broadcast(X))
-  //      -> X, if X and Y have same type
-  if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
-    if (bcastOp.getSourceType() == resultType)
-      return bcastOp.getSource();
-  }
-
   // shape_cast(constant) -> constant
   if (auto splatAttr =
           llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
@@ -5991,10 +6012,7 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
   }
 };
 
-/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either
-///   i) Y = ShapeCast(X), or
-///  ii) Y = Broadcast(X)
-/// If both (i) and (ii) are possible, (i) is chosen.
+/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as Y = Broadcast(X)
 class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
@@ -6009,22 +6027,6 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
     auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
     bool srcIsScalar = !srcVectorType;
 
-    // Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X).
-    // Example:
-    // %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32>
-    // %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32>
-    // to
-    // %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32>
-    if (srcVectorType) {
-      if (srcVectorType.getNumElements() ==
-          shapeCastOp.getResultVectorType().getNumElements()) {
-        rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
-            shapeCastOp, shapeCastOp.getResultVectorType(),
-            broadcastOp.getSource());
-        return success();
-      }
-    }
-
     // Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X)
     // Example
     // %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32>
@@ -6225,21 +6227,6 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
   if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector()))
     return ub::PoisonAttr::get(getContext());
 
-  // Eliminate identity transposes, and more generally any transposes that
-  // preserves the shape without permuting elements.
-  //
-  // Examples of what to fold:
-  // %0 = vector.transpose %arg, [0, 1] : vector<1x1xi8> to vector<1x1xi8>
-  // %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8>
-  // %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8>
-  //
-  // Example of what NOT to fold:
-  // %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
-  //
-  if (getSourceVectorType() == getResultVectorType() &&
-      isOrderPreserving(*this))
-    return getVector();
-
   return {};
 }
 
@@ -6359,32 +6346,6 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
   }
 };
 
-/// Folds transpose(shape_cast) into a new shape_cast.
-class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(TransposeOp transposeOp,
-                                PatternRewriter &rewriter) const override {
-    auto shapeCastOp =
-        transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
-    if (!shapeCastOp)
-      return failure();
-    if (!isOrderPreserving(transposeOp))
-      return failure();
-
-    VectorType resultType = transposeOp.getType();
-
-    // We don't need to check isValidShapeCast at this point, because it is
-    // guaranteed that merging the transpose into the the shape_cast is a valid
-    // shape_cast, because the transpose just inserts/removes ones.
-
-    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
-                                                     shapeCastOp.getSource());
-    return success();
-  }
-};
-
 /// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
 /// 'order preserving', where 'order preserving' means the flattened
 /// inputs and outputs of the transpose have identical (numerical) values.
@@ -6480,12 +6441,73 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
   }
 };
 
+/// Return true if `transpose` does not permute a pair of non-unit dims.
+/// By `order preserving` we mean that the flattened versions of the input and
+/// output vectors are (numerically) identical. In other words `transpose` is
+/// effectively a shape cast.
+static bool isOrderPreserving(TransposeOp transpose) {
+  ArrayRef<int64_t> permutation = transpose.getPermutation();
+  VectorType sourceType = transpose.getSourceVectorType();
+  ArrayRef<int64_t> inShape = sourceType.getShape();
+  ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
+  auto isNonScalableUnitDim = [&](int64_t dim) {
+    return inShape[dim] == 1 && !inDimIsScalable[dim];
+  };
+  int64_t current = 0;
+  for (auto p : permutation) {
+    if (!isNonScalableUnitDim(p)) {
+      if (p < current) {
+        return false;
+      }
+      current = p;
+    }
+  }
+  return true;
+}
+
+/// For example,
+/// ```
+/// %0 = vector.transpose %arg0, [0, 2, 1] :
+///                   vector<2x1x2xf32> to vector<2x2x1xf32>
+/// ```
+/// becomes
+/// ```
+/// %0 = vector.shape_cast %arg0 :
+///                   vector<2x1x2xf32> to vector<2x2x1xf32>
+/// ```
+struct TransposeToShapeCast final
+    : public OpRewritePattern<vector::TransposeOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::TransposeOp transpose,
+                                PatternRewriter &rewriter) const override {
+
+    // This pattern does
+    //    transpose -> shape_cast
+    // But another pattern, ConvertIllegalShapeCastOpsToTransposes, does
+    //    shape_cast -> shape_cast(transpose)
+    // i.e. it reintroduces a transpose. When paired, these 2 patterns can
+    // cause infinite cycles in pattern rewriting.
+    // ConvertIllegalShapeCastOpsToTransposes only matches on scalable
+    // vectors, so by disabling this pattern for scalable vectors the
+    // cycle is avoided.
+    // TODO: Check if ConvertIllegalShapeCastOpsToTransposes is
+    // still needed. If it's not, then we can rewrite scalable vectors here.
+    if (!isOrderPreserving(transpose) || transpose.getType().isScalable()) {
+      return rewriter.notifyMatchFailure(
+          transpose, "not order preserving, so not semantically a 'copy'");
+    }
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+        transpose, transpose.getType(), transpose.getVector());
+    return success();
+  }
+};
+
 } // namespace
 
 void vector::TransposeOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
-  results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
-              FoldTransposeSplat, FoldTransposeBroadcast>(context);
+  results.add<FoldTransposeCreateMask, TransposeFolder, FoldTransposeSplat,
+              FoldTransposeBroadcast, TransposeToShapeCast>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 732e316c93381..4d6175fe6428b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -382,63 +382,6 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
   vector::VectorTransposeLowering vectorTransposeLowering;
 };
 
-/// Rewrites vector.transpose as vector.shape_cast. This pattern is only applied
-/// to 2D vectors with at least one unit dim. For example:
-///
-/// Replace:
-///   vector.transpose %0, [1, 0] : vector<4x1xi32>> to
-///                                 vector<1x4xi32>
-/// with:
-///   vector.shape_cast %0 : vector<4x1xi32> to vector<1x4xi32>
-///
-/// Source with leading unit dim (inverse) is also replaced. Unit dim must
-/// be fixed. Non-unit dim can be scalable.
-///
-/// TODO: This pattern was introduced specifically to help lower scalable
-/// vectors. In hindsight, a more specialised canonicalization (for shape_cast's
-/// to cancel out) would be preferable:
-///
-///  BEFORE:
-///     %0 = some_op
-///     %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<[4]x1xf32>
-///     %2 = vector.transpose %1 [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
-///  AFTER:
-///     %0 = some_op
-///     %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<1x[4]xf32>
-///
-/// Given the context above, we may want to consider (re-)moving this pattern
-/// at some later time. I am leaving it for now in case there are other users
-/// that I am not aware of.
-class Transpose2DWithUnitDimToShapeCast
-    : public OpRewritePattern<vector::TransposeOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  Transpose2DWithUnitDimToShapeCast(MLIRContext *context,
-                                    PatternBenefit benefit = 1)
-      : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
-
-  LogicalResult matchAndRewrite(vector::TransposeOp op,
-                                PatternRewriter &rewriter) const override {
-    Value input = op.getVector();
-    VectorType resType = op.getResultVectorType();
-
-    // Set up convenience transposition table.
-    ArrayRef<int64_t> transp = op.getPermutation();
-
-    if (resType.getRank() == 2 &&
-        ((resType.getShape().front() == 1 &&
-          !resType.getScalableDims().front()) ||
-         (resType.getShape().back() == 1 &&
-          !resType.getScalableDims().back())) &&
-        transp == ArrayRef<int64_t>({1, 0})) {
-      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
-      return success();
-    }
-
-    return failure();
-  }
-};
 
 /// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
 /// If the strategy is Shuffle1D, it will be lowered to:
@@ -511,8 +454,7 @@ class TransposeOp2DToShuffleLowering
 void mlir::vector::populateVectorTransposeLoweringPatterns(
     RewritePatternSet &patterns,
     VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit) {
-  patterns.add<Transpose2DWithUnitDimToShapeCast>(patterns.getContext(),
-                                                  benefit);
+  BroadcastOp::getCanonicalizationPatterns(patterns, patterns.getContext());
   patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
       vectorTransposeLowering, patterns.getContext(), benefit);
 }
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 65b73375831da..306a34e41bd82 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -753,12 +753,13 @@ func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
 
 // -----
 
+
 // CHECK-LABEL: negative_fold_extract_broadcast
-//       CHECK:   vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32>
-//       CHECK:   vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x1x4xf32>
+//       CHECK:   vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x2x4xf32>
+//       CHECK:   vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x2x4xf32>
 func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32> {
-  %b = vector.broadcast %a : vector<1x1xf32> to vector<1x1x4xf32>
-  %r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x1x4xf32>
+  %b = vector.broadcast %a : vector<1x1xf32> to vector<1x2x4xf32>
+  %r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x2x4xf32>
   return %r : vector<4xf32>
 }
 
@@ -797,8 +798,8 @@ func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
 // rank(extract_output) < rank(broadcast_input)
 func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
   %idx0 : index, %idx1 : index) -> vector<4xf32> {
-  %b = vector.broadcast %a : vector<2x4xf32> to vector<1x2x4xf32>
-  %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
+  %b = vector.broadcast %a : vector<2x4xf32> to vector<2x2x4xf32>
+  %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<2x2x4xf32>
   return %r : vector<4xf32>
 }
 
@@ -1033,30 +1034,6 @@ func.func @canonicalize_broadcast_shapecast_to_broadcast_scalar(%arg0: f32) -> v
 
 // -----
 
-// In this test, broadcast (2)->(1,2,1) is not legal, but shape_cast (2)->(1,2,1) is.
-// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_shapcast
-//   CHECK-NOT:   vector.broadcast
-//       CHECK:   vector.shape_cast {{.+}} : vector<2xf32> to vector<1x2x1xf32>
-func.func @canonicalize_broadcast_shapecast_to_shapcast(%arg0 : vector<2xf32>) -> vector<1x2x1xf32> {
-  %0 = vector.broadcast %arg0 : vector<2xf32> to vector<1x2xf32>
-  %1 = vector.shape_cast %0 : vector<1x2xf32> to vector<1x2x1xf32>
-  return %1 : vector<1x2x1xf32>
-}
-
-// -----
-
-// In this test, broadcast (1)->(1,1) and shape_cast (1)->(1,1) are both legal. shape_cast is chosen.
-// CHECK-LABEL: func @canonicalize_broadcast_shapecast_both_possible
-//   CHECK-NOT:   vector.broadcast
-//       CHECK:   vector.shape_cast {{.+}} : vector<1xf32> to vector<1x1xf32>
-func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>) -> vector<1x1xf32> {
-    %0 = vector.broadcast %arg0 : vector<1xf32> to vector<1x1x1xf32>
-    %1 = vector.shape_cast %0 : vector<1x1x1xf32> to vector<1x1xf32>
-    return %1 : vector<1x1xf32>
-}
-
-// -----
-
 // CHECK-LABEL: fold_vector_transfer_masks
 func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
   // CHECK: %[[C0:.+]] = arith.constant 0 : index
@@ -1920,12 +1897,12 @@ func.func @extract_strided_splat(%arg0: f16) -> vector<2x4xf16> {
 
 // -----
 
-// CHECK-LABEL: func @insert_extract_to_broadcast
+// CHECK-LABEL: func @insert_extract_to_shape_cast
 //  CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>)
-//       CHECK:   %[[V0:.*]] = vector.extract %[[ARG0]][0, 0] : vector<4xf32> from vector<1x1x4xf32>
-//       CHECK:   %[[V1:.*]] = vector.broadcast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
+//       CHECK:   %[[V0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xf32> to vector<4xf32>
+//       CHECK:   %[[V1:.*]] = vector.shape_cast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
 //       CHECK:   return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32>
-func.func @insert_extract_to_broadcast(%arg0 : vector<1x1x4xf32>,
+func.func @insert_extract_to_shape_cast(%arg0 : vector<1x1x4xf32>,
   %arg1 : vector<4xf32>) -> (vector<4xf32>, vector<1x1x4xf32>) {
   %0 = vector.extract %arg0[0, 0] : vector<4xf32> from vector<1x1x4xf32>
   %1 = vector.insert %arg1, %arg0 [0, 0] : vector<4xf32> into vector<1x1x4xf32>
@@ -2277,7 +2254,7 @@ func.func @shuffle_1d_rhs_poison() -> vector<4xi32> {
 
 // CHECK-LABEL: func @shuffle_canonicalize_0d
 func.func @shuffle_canonicalize_0d(%v0 : vector<i32>, %v1 : vector<i32>) -> vector<1xi32> {
-  // CHECK: vector.broadcast %{{.*}} : vector<i32> to vector<1xi32>
+  // CHECK: vector.shape_cast %{{.*}} : vector<i32> to vector<1xi32>
   %shuffle = vector.shuffle %v0, %v1 [0] : vector<i32>, vector<i32>
   return %shuffle : vector<1xi32>
 }
@@ -2764,9 +2741,8 @@ func.func @transfer_read_from_rank_reducing_extract_slice(%src: tensor<1x8x8x8xf
 // CHECK-LABEL: func.func @extract_from_broadcast
 func.func @extract_from_broadcast(%src: vector<1x1x1xf32>) -> vector<1xf32> {
   %0 = vector.broadcast %src : vector<1x1x1xf32> to vector<1x1x32x1xf32>
-
-  //  CHECK-NEXT:   %0 = vector.extract {{.*}}[0, 0] : vector<1xf32> from vector<1x1x1xf32>
-  //  CHECK-NEXT:   return %0 : vector<1xf32>
+  //  CHECK-NEXT:   %[[RES:.*]] = vector.shape_cast{{.*}} vector<1x1x1xf32> to vector<1xf32>
+  //  CHECK-NEXT:   return %[[RES]] : vector<1xf32>
   %1 = vector.extract %0[0, 0, 31] : vector<1xf32> from vector<1x1x32x1xf32>
   return %1: vector<1xf32>
 }
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
index fdab2a8918a2e..d5f96a8928770 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
@@ -81,8 +81,8 @@ func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<
 
 // CHECK-LABEL: func @to_shape_cast_rank2_to_rank1(
 //  CHECK-SAME:       %[[A:.*]]: vector<1x2xi8>)
-//       CHECK:       %[[EXTRACT:.*]] = vector.extract %[[A]][0] : vector<2xi8> from vector<1x2xi8>
-//       CHECK:       return %[[EXTRACT]] : vector<2xi8>
+//       CHECK:       %[[SC:.*]] = vector.shape_cast %[[A]] : vector<1x2xi8> to vector<2xi8>
+//       CHECK:       return %[[SC]] : vector<2xi8>
 func.func @to_shape_cast_rank2_to_rank1(%arg0: vector<1x2xi8>) -> vector<2xi8> {
   %0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
   %1 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8>
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir
new file mode 100644
index 0000000000000..357df0f129a5e
--- /dev/null
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir
@@ -0,0 +1,164 @@
+// RUN: mlir-opt %s -split-input-file -canonicalize |  FileCheck %s
+
+
+// +----------------------------------------
+//  Tests of BroadcastToShapeCast
+// +----------------------------------------
+
+// CHECK-LABEL: @broadcast_to_shape_cast
+//  CHECK-SAME:  %[[ARG0:.*]]: vector<4xi8>
+//  CHECK-NEXT:  %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
+//  CHECK-NEXT:  return %[[SCAST]] : vector<1x1x4xi8>
+func.func @broadcast_to_shape_cast(%arg0 : vector<4xi8>) -> vector<1x1x4xi8> {
+  %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+  return %0 : vector<1x1x4xi8>
+}
+
+// -----
+
+// broadcast can only be transformed to a shape_cast if the number of elements is
+// unchanged by the broadcast
+// CHECK-LABEL: @negative_broadcast_increased_elements_to_shape_cast
+//   CHECK-NOT: shape_cast
+//       CHECK: return
+func.func @negative_broadcast_increased_elements_to_shape_cast(%arg0 : vector<1x4xi8>) -> vector<2x3x4xi8> {
+  %0 = vector.broadcast %arg0 : vector<1x4xi8> to vector<2x3x4xi8>
+  return %0 : vector<2x3x4xi8>
+}
+
+// -----
+
+// shape_cast does not support scalar inputs/outputs, so a broadcast of a scalar
+// cannot be transformed to a shape_cast.
+// CHECK-LABEL: @negative_broadcast_scalar_to_shape_cast
+//   CHECK-NOT: shape_cast
+//       CHECK: return
+func.func @negative_broadcast_scalar_to_shape_cast(%arg0 : i8) -> vector<1xi8> {
+  %0 = vector.broadcast %arg0 : i8 to vector<1xi8>
+  return %0 : vector<1xi8>
+}
+
+// -----
+
+// +----------------------------------------
+//  Tests of TransposeToShapeCast
+// +----------------------------------------
+
+// In this test, the permutation maps the non-unit dimensions (0 and 2) as follows:
+// 0 -> 0
+// 2 -> 1
+// Because 0 < 1, this permutation is order preserving and effectively a shape_cast.
+// CHECK-LABEL: @transpose_to_shape_cast
+//  CHECK-SAME:  %[[ARG0:.*]]: vector<2x1x2xf32>
+//  CHECK-NEXT:  %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
+//  CHECK-NEXT:  return %[[SCAST]] : vector<2x2x1xf32>
+func.func @transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> {
+  %0 = vector.transpose %arg0, [0, 2, 1] : vector<2x1x2xf32> to vector<2x2x1xf32>
+  return %0 : vector<2x2x1xf32>
+}
+
+// -----
+
+// In this test, the permutation maps the non-unit dimensions (1 and 2) as follows:
+// 1 -> 0
+// 2 -> 4
+// Because 0 < 4, this permutation is order preserving and effectively a shape_cast.
+// CHECK-LABEL: @shape_cast_of_transpose
+//  CHECK-SAME:   %[[ARG:.*]]: vector<1x4x4x1x1xi8>)
+//       CHECK:   %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
+//  CHECK-SAME:   vector<1x4x4x1x1xi8> to vector<4x1x1x1x4xi8>
+//       CHECK:   return %[[SHAPE_CAST]]
+func.func @shape_cast_of_transpose(%arg : vector<1x4x4x1x1xi8>) -> vector<4x1x1x1x4xi8> {
+  %0 = vector.transpose %arg, [1, 0, 3, 4, 2]  : vector<1x4x4x1x1xi8> to vector<4x1x1x1x4xi8>
+  return %0 : vector<4x1x1x1x4xi8>
+}
+
+// -----
+
+// Scalable dimensions should be treated as non-unit dimensions.
+// CHECK-LABEL: @transpose_scalable_unit
+//   CHECK-NOT: shape_cast
+func.func @transpose_scalable_unit(%arg : vector<[1]x4xi8>) -> vector<4x[1]xi8> {
+  %0 = vector.transpose %arg, [1, 0] : vector<[1]x4xi8> to vector<4x[1]xi8>
+  return %0 : vector<4x[1]xi8>
+}
+
+// -----
+
+// In this test, the mapping of non-unit dimensions (1 and 2) is as follows:
+// 1 -> 2
+// 2 -> 1
+// As this is not increasing (2 > 1), this transpose is not order
+// preserving and cannot be treated as a shape_cast.
+// CHECK-LABEL: @negative_transpose_to_shape_cast
+//   CHECK-NOT: shape_cast
+func.func @negative_transpose_to_shape_cast(%arg : vector<1x4x4x1xi8>) -> vector<1x4x4x1xi8> {
+  %0 = vector.transpose %arg, [0, 2, 1, 3]
+     : vector<1x4x4x1xi8> to vector<1x4x4x1xi8>
+  return %0 : vector<1x4x4x1xi8>
+}
+
+// -----
+
+// Currently the conversion shape_cast(transpose) -> shape_cast is disabled for
+// scalable vectors because of bad interaction with ConvertIllegalShapeCastOpsToTransposes
+// CHECK-LABEL: @negative_shape_cast_of_transpose_scalable
+//       CHECK:  vector.transpose
+//       CHECK:  vector.shape_cast
+func.func @negative_shape_cast_of_transpose_scalable(%arg : vector<[4]x1xi8>) -> vector<[4]xi8> {
+  %0 = vector.transpose %arg, [1, 0] : vector<[4]x1xi8> to vector<1x[4]xi8>
+  %1 = vector.shape_cast %0 : vector<1x[4]xi8> to vector<[4]xi8>
+  return %1 : vector<[4]xi8>
+}
+
+// -----
+
+// The conversion transpose(shape_cast) -> shape_cast is currently disabled for scalable
+// vectors.
+// CHECK-LABEL: @negative_transpose_of_shape_cast_scalable
+//       CHECK: vector.shape_cast
+//       CHECK: vector.transpose
+func.func @negative_transpose_of_shape_cast_scalable(%arg : vector<[4]xi8>) -> vector<[4]x1xi8> {
+  %0 = vector.shape_cast %arg : vector<[4]xi8> to vector<1x[4]xi8>
+  %1 = vector.transpose %0, [1, 0] : vector<1x[4]xi8> to vector<[4]x1xi8>
+  return %1 : vector<[4]x1xi8>
+}
+
+// -----
+
+// A test where a transpose cannot be transformed to a shape_cast because it is not order
+// preserving
+// CHECK-LABEL: @negative_transpose_to_shape_cast
+//  CHECK-SAME:  %[[ARG0:.*]]: vector<2x1x2xf32>
+//  CHECK-NEXT:  %[[TRANSPOSE:.*]] = vector.transpose %[[ARG0]], [2, 0, 1]
+//  CHECK-NEXT:  return %[[TRANSPOSE]] : vector<2x2x1xf32>
+func.func @negative_transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> {
+  %0 = vector.transpose %arg0, [2, 0, 1] : vector<2x1x2xf32> to vector<2x2x1xf32>
+  return %0 : vector<2x2x1xf32>
+}
+
+// -----
+
+// +----------------------------------------
+//  Tests of ExtractToShapeCast
+// +----------------------------------------
+
+// CHECK-LABEL: @extract_to_shape_cast
+//  CHECK-SAME:  %[[ARG0:.*]]: vector<1x4xf32>
+//  CHECK-NEXT:  %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
+//  CHECK-NEXT:  return %[[SCAST]] : vector<4xf32>
+func.func @extract_to_shape_cast(%arg0 : vector<1x4xf32>) -> vector<4xf32> {
+  %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
+  return %0 : vector<4xf32>
+}
+
+// -----
+
+// In this example, arg1 is dynamic and might be negative (indicating poison), so the extract cannot become a shape_cast.
+// CHECK-LABEL: @negative_extract_to_shape_cast
+//   CHECK-NOT: shape_cast
+func.func @negative_extract_to_shape_cast(%arg0 : vector<1x4xf32>, %arg1 : index) -> vector<4xf32> {
+  %0 = vector.extract %arg0[%arg1] : vector<4xf32> from vector<1x4xf32>
+  return %0 : vector<4xf32>
+}
+
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
index f1e1c5e896c66..8778669149e2c 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
@@ -141,8 +141,6 @@ func.func @negative_broadcast_transpose_021(%arg0 : vector<3x1x3xi8>) -> vector<
   return %1 : vector<3x3x3xi8>
 }
 
-// -----
-
 /// +--------------------------------------------------------------------------
 ///  Tests of ShapeCastOp::fold:  shape_cast(transpose) -> shape_cast
 /// +--------------------------------------------------------------------------
diff --git a/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir b/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir
index 34a155fbf2fc1..44abe2ac46fce 100644
--- a/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir
+++ b/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir
@@ -188,18 +188,6 @@ func.func @transpose_with_scalable_unit_dims(%vec: vector<[1]x1x2x4x1xf32>) -> v
 
 // -----
 
-func.func @transpose_with_all_unit_dims(%vec: vector<1x1x1xf32>) -> vector<1x1x1xf32> {
-  %res = vector.transpose %vec, [0, 2, 1] : vector<1x1x1xf32> to vector<1x1x1xf32>
-  return %res : vector<1x1x1xf32>
-}
-// The `vec` is returned because there are other flattening patterns that fold
-// vector.shape_cast ops away.
-// CHECK-LABEL: func.func @transpose_with_all_unit_dims
-// CHECK-SAME:      %[[VEC:.[a-zA-Z0-9]+]]
-// CHECK-NEXT:    return %[[VEC]]
-
-// -----
-
 func.func @negative_transpose_with_no_unit_dims(%vec: vector<4x2x3xf32>) -> vector<4x3x2xf32> {
   %res = vector.transpose %vec, [0, 2, 1] : vector<4x2x3xf32> to vector<4x3x2xf32>
   return %res : vector<4x3x2xf32>
diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
index 5011d8b2b2ef6..97a8a9a9c2597 100644
--- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
@@ -387,6 +387,66 @@ func.func @non_dividing_gcd_increasing(%arg0 : vector<3x10xi8>) -> vector<2x15xi
   return %0 : vector<2x15xi8>
 }
 
+// **--------------------------------------------------------** //
+//   Tests where the shape_cast is equivalent to a transpose
+// **--------------------------------------------------------** //
+
+// CHECK-LABEL: func @transpose102_1x8x8xf32
+//       CHECK: vector.extract {{.*}}[0, 0] : vector<8xf32> from vector<1x8x8xf32>
+//  CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8xf32> into vector<8x1x8xf32>
+//  CHECK-NEXT: vector.extract {{.*}}[0, 1] : vector<8xf32> from vector<1x8x8xf32>
+//  CHECK-NEXT: vector.insert {{.*}} [1, 0] : vector<8xf32> into vector<8x1x8xf32>
+//  CHECK-NEXT: vector.extract {{.*}}[0, 2] : vector<8xf32> from vector<1x8x8xf32>
+//  CHECK-NEXT: vector.insert {{.*}} [2, 0] : vector<8xf32> into vector<8x1x8xf32>
+//  CHECK-NEXT: vector.extract {{.*}}[0, 3] : vector<8xf32> from vector<1x8x8xf32>
+//  CHECK-NEXT: vector.insert {{.*}} [3, 0] : vector<8xf32> into vector<8x1x8xf32>
+//  CHECK-NEXT: vector.extract {{.*}}[0, 4] : vector<8xf32> from vector<1x8x8xf32>
+//  CHECK-NEXT: vector.insert {{.*}} [4, 0] : vector<8xf32> into vector<8x1x8xf32>
+//  CHECK-NEXT: vector.extract {{.*}}[0, 5] : vector<8xf32> from vector<1x8x8xf32>
+//  CHECK-NEXT: vector.insert {{.*}} [5, 0] : vector<8xf32> into vector<8x1x8xf32>
+//  CHECK-NEXT: vector.extract {{.*}}[0, 6] : vector<8xf32> from vector<1x8x8xf32>
+//  CHECK-NEXT: vector.insert {{.*}} [6, 0] : vector<8xf32> into vector<8x1x8xf32>
+//  CHECK-NEXT: vector.extract {{.*}}[0, 7] : vector<8xf32> from vector<1x8x8xf32>
+//  CHECK-NEXT: vector.insert {{.*}} [7, 0] : vector<8xf32> into vector<8x1x8xf32>
+func.func @transpose102_1x8x8xf32(%arg0: vector<1x8x8xf32>) -> vector<8x1x8xf32> {
+  %0 = vector.shape_cast %arg0 : vector<1x8x8xf32> to vector<8x1x8xf32>
+  return %0 : vector<8x1x8xf32>
+}
+
+// CHECK-LABEL: func @transpose102_8x1x8xf32
+//       CHECK: vector.extract {{.*}}[0, 0] : vector<8xf32> from vector<8x1x8xf32>
+//  CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8xf32> into vector<1x8x8xf32>
+//  CHECK-NEXT: vector.extract {{.*}}[1, 0] : vector<8xf32> from vector<8x1x8xf32>
+//  CHECK-NEXT: vector.insert {{.*}} [0, 1] : vector<8xf32> into vector<1x8x8xf32>
+//  CHECK-NEXT: vector.extract {{.*}}[2, 0] : vector<8xf32> from vector<8x1x8xf32>
+//  CHECK-NEXT: vector.insert {{.*}} [0, 2] : vector<8xf32> into vector<1x8x8xf32>
+//  CHECK-NEXT: vector.extract {{.*}}[3, 0] : vector<8xf32> from vector<8x1x8xf32>
+//  CHECK-NEXT: vector.insert {{.*}} [0, 3] : vector<8xf32> into vector<1x8x8xf32>
+//  CHECK-NEXT: vector.extract {{.*}}[4, 0] : vector<8xf32> from vector<8x1x8xf32>
+//  CHECK-NEXT: vector.insert {{.*}} [0, 4] : vector<8xf32> into vector<1x8x8xf32>
+//  CHECK-NEXT: vector.extract {{.*}}[5, 0] : vector<8xf32> from vector<8x1x8xf32>
+//  CHECK-NEXT: vector.insert {{.*}} [0, 5] : vector<8xf32> into vector<1x8x8xf32>
+//  CHECK-NEXT: vector.extract {{.*}}[6, 0] : vector<8xf32> from vector<8x1x8xf32>
+//  CHECK-NEXT: vector.insert {{.*}} [0, 6] : vector<8xf32> into vector<1x8x8xf32>
+//  CHECK-NEXT: vector.extract {{.*}}[7, 0] : vector<8xf32> from vector<8x1x8xf32>
+//  CHECK-NEXT: vector.insert {{.*}} [0, 7] : vector<8xf32> into vector<1x8x8xf32>
+func.func @transpose102_8x1x8xf32(%arg0: vector<8x1x8xf32>) -> vector<1x8x8xf32> {
+  %0 = vector.shape_cast %arg0 : vector<8x1x8xf32> to vector<1x8x8xf32>
+  return %0 : vector<1x8x8xf32>
+}
+
+// CHECK-LABEL:   func @transpose1023_2x1x8x4xf32(
+//       CHECK: vector.extract {{.*}}[0, 0] : vector<8x4xf32> from vector<2x1x8x4xf32>
+//  CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8x4xf32> into vector<1x2x8x4xf32>
+//  CHECK-NEXT: vector.extract {{.*}}[1, 0] : vector<8x4xf32> from vector<2x1x8x4xf32>
+//  CHECK-NEXT: vector.insert {{.*}} [0, 1] : vector<8x4xf32> into vector<1x2x8x4xf32>
+func.func @transpose1023_2x1x8x4xf32(%arg0: vector<2x1x8x4xf32>) -> vector<1x2x8x4xf32> {
+  // Note the 2-D extract/insert pair since dimensions 2 and 3 are not transposed!
+  %0 = vector.shape_cast %arg0 : vector<2x1x8x4xf32> to vector<1x2x8x4xf32>
+  return %0 : vector<1x2x8x4xf32>
+}
+
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %f = transform.structured.match ops{["func.func"]} in %module_op
diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index 511ab70f35086..7886fba6c80c4 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -24,7 +24,7 @@ func.func @vector_transfer_ops_0d_tensor(%src: tensor<f32>) -> vector<1xf32> {
     %f0 = arith.constant 0.0 : f32
 
 //  CHECK:   %[[S:.*]] = vector.transfer_read %[[SRC]][]
-//  CHECK:   %[[V:.*]] = vector.broadcast %[[S]] : vector<f32> to vector<1xf32>
+//  CHECK:   %[[V:.*]] = vector.shape_cast %[[S]] : vector<f32> to vector<1xf32>
     %res = vector.transfer_read %src[], %f0 {in_bounds = [true], permutation_map = affine_map<()->(0)>} :
       tensor<f32>, vector<1xf32>
 
@@ -369,9 +369,8 @@ func.func @transfer_write_broadcast_unit_dim_tensor(
   %c0 = arith.constant 0 : index
 
   %res = vector.transfer_write %vec_0, %dst_0[%c0, %c0, %c0, %c0] {in_bounds = [false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>} : vector<14x8x16xf32>, tensor<?x?x?x?xf32>
-  // CHECK: %[[NEW_VEC0:.*]] = vector.broadcast %{{.*}} : vector<14x8x16xf32> to vector<1x14x8x16xf32>
-  // CHECK: %[[NEW_VEC1:.*]] = vector.transpose %[[NEW_VEC0]], [1, 2, 0, 3] : vector<1x14x8x16xf32> to vector<14x8x1x16xf32>
-  // CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC1]], %[[DST0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [false, false, true, true]} : vector<14x8x1x16xf32>, tensor<?x?x?x?xf32>
+  // CHECK: %[[NEW_VEC0:.*]] = vector.shape_cast %{{.*}} : vector<14x8x16xf32> to vector<14x8x1x16xf32>
+  // CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC0]], %[[DST0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [false, false, true, true]} : vector<14x8x1x16xf32>, tensor<?x?x?x?xf32>
 
   return %res : tensor<?x?x?x?xf32>
 }
@@ -385,9 +384,8 @@ func.func @transfer_write_broadcast_unit_dim_memref(
   %c0 = arith.constant 0 : index
 
   vector.transfer_write %vec_0, %mem_0[%c0, %c0, %c0, %c0] {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>} : vector<8x16xf32>, memref<?x?x?x?xf32>
-  // CHECK: %[[NEW_VEC0:.*]] = vector.broadcast %{{.*}} : vector<8x16xf32> to vector<1x8x16xf32>
-  // CHECK: %[[NEW_VEC1:.*]] = vector.transpose %[[NEW_VEC0]], [1, 2, 0] : vector<1x8x16xf32> to vector<8x16x1xf32>
-  // CHECK: vector.transfer_write %[[NEW_VEC1]], %[[MEM0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [false, false, true]} : vector<8x16x1xf32>, memref<?x?x?x?xf32>
+  // CHECK: %[[NEW_VEC0:.*]] = vector.shape_cast %{{.*}} : vector<8x16xf32> to vector<8x16x1xf32>
+  // CHECK: vector.transfer_write %[[NEW_VEC0]], %[[MEM0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [false, false, true]} : vector<8x16x1xf32>, memref<?x?x?x?xf32>
 
   return
 }
diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
index a730f217f027d..dc9350f706fcd 100644
--- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
@@ -21,61 +21,6 @@ func.func @transpose23(%arg0: vector<2x3xf32>) -> vector<3x2xf32> {
   return %0 : vector<3x2xf32>
 }
 
-// CHECK-LABEL: func @transpose102_1x8x8xf32
-func.func @transpose102_1x8x8xf32(%arg0: vector<1x8x8xf32>) -> vector<8x1x8xf32> {
-  //      CHECK: vector.extract {{.*}}[0, 0] : vector<8xf32> from vector<1x8x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8xf32> into vector<8x1x8xf32>
-  // CHECK-NEXT: vector.extract {{.*}}[0, 1] : vector<8xf32> from vector<1x8x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [1, 0] : vector<8xf32> into vector<8x1x8xf32>
-  // CHECK-NEXT: vector.extract {{.*}}[0, 2] : vector<8xf32> from vector<1x8x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [2, 0] : vector<8xf32> into vector<8x1x8xf32>
-  // CHECK-NEXT: vector.extract {{.*}}[0, 3] : vector<8xf32> from vector<1x8x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [3, 0] : vector<8xf32> into vector<8x1x8xf32>
-  // CHECK-NEXT: vector.extract {{.*}}[0, 4] : vector<8xf32> from vector<1x8x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [4, 0] : vector<8xf32> into vector<8x1x8xf32>
-  // CHECK-NEXT: vector.extract {{.*}}[0, 5] : vector<8xf32> from vector<1x8x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [5, 0] : vector<8xf32> into vector<8x1x8xf32>
-  // CHECK-NEXT: vector.extract {{.*}}[0, 6] : vector<8xf32> from vector<1x8x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [6, 0] : vector<8xf32> into vector<8x1x8xf32>
-  // CHECK-NEXT: vector.extract {{.*}}[0, 7] : vector<8xf32> from vector<1x8x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [7, 0] : vector<8xf32> into vector<8x1x8xf32>
-  %0 = vector.transpose %arg0, [1, 0, 2] : vector<1x8x8xf32> to vector<8x1x8xf32>
-  return %0 : vector<8x1x8xf32>
-}
-
-// CHECK-LABEL: func @transpose102_8x1x8xf32
-func.func @transpose102_8x1x8xf32(%arg0: vector<8x1x8xf32>) -> vector<1x8x8xf32> {
-  //      CHECK: vector.extract {{.*}}[0, 0] : vector<8xf32> from vector<8x1x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8xf32> into vector<1x8x8xf32>
-  // CHECK-NEXT: vector.extract {{.*}}[1, 0] : vector<8xf32> from vector<8x1x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [0, 1] : vector<8xf32> into vector<1x8x8xf32>
-  // CHECK-NEXT: vector.extract {{.*}}[2, 0] : vector<8xf32> from vector<8x1x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [0, 2] : vector<8xf32> into vector<1x8x8xf32>
-  // CHECK-NEXT: vector.extract {{.*}}[3, 0] : vector<8xf32> from vector<8x1x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [0, 3] : vector<8xf32> into vector<1x8x8xf32>
-  // CHECK-NEXT: vector.extract {{.*}}[4, 0] : vector<8xf32> from vector<8x1x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [0, 4] : vector<8xf32> into vector<1x8x8xf32>
-  // CHECK-NEXT: vector.extract {{.*}}[5, 0] : vector<8xf32> from vector<8x1x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [0, 5] : vector<8xf32> into vector<1x8x8xf32>
-  // CHECK-NEXT: vector.extract {{.*}}[6, 0] : vector<8xf32> from vector<8x1x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [0, 6] : vector<8xf32> into vector<1x8x8xf32>
-  // CHECK-NEXT: vector.extract {{.*}}[7, 0] : vector<8xf32> from vector<8x1x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [0, 7] : vector<8xf32> into vector<1x8x8xf32>
-  %0 = vector.transpose %arg0, [1, 0, 2] : vector<8x1x8xf32> to vector<1x8x8xf32>
-  return %0 : vector<1x8x8xf32>
-}
-
-// CHECK-LABEL:   func @transpose1023_2x1x8x4xf32(
-func.func @transpose1023_2x1x8x4xf32(%arg0: vector<2x1x8x4xf32>) -> vector<1x2x8x4xf32> {
-  // Note the 2-D extract/insert pair since dimensions 2 and 3 are not transposed!
-  //      CHECK: vector.extract {{.*}}[0, 0] : vector<8x4xf32> from vector<2x1x8x4xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8x4xf32> into vector<1x2x8x4xf32>
-  // CHECK-NEXT: vector.extract {{.*}}[1, 0] : vector<8x4xf32> from vector<2x1x8x4xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [0, 1] : vector<8x4xf32> into vector<1x2x8x4xf32>
-  %0 = vector.transpose %arg0, [1, 0, 2, 3] : vector<2x1x8x4xf32> to vector<1x2x8x4xf32>
-  return %0 : vector<1x2x8x4xf32>
-}
-
 /// Scalable dim should not be unrolled.
 
 // CHECK-LABEL: func @transpose23_scalable
@@ -348,36 +293,6 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-/// Transpose of rank-2 vector with leading or trailing unit dim to shape_cast.
-
-// CHECK-LABEL: func @transpose10_4x1xf32
-func.func @transpose10_4x1xf32(%arg0: vector<4x1xf32>) -> vector<1x4xf32> {
-  // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<4x1xf32> to vector<1x4xf32>
-  %0 = vector.transpose %arg0, [1, 0] : vector<4x1xf32> to vector<1x4xf32>
-  return %0 : vector<1x4xf32>
-}
-
-// CHECK-LABEL: func @transpose10_nx4x1xf32
-func.func @transpose10_nx4x1xf32(%arg0: vector<[4]x1xf32>) -> vector<1x[4]xf32> {
-  // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<[4]x1xf32> to vector<1x[4]xf32>
-  %0 = vector.transpose %arg0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
-  return %0 : vector<1x[4]xf32>
-}
-
-// CHECK-LABEL: func @transpose10_1x4xf32
-func.func @transpose10_1x4xf32(%arg0: vector<1x4xf32>) -> vector<4x1xf32> {
-  // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4x1xf32>
-  %0 = vector.transpose %arg0, [1, 0] : vector<1x4xf32> to vector<4x1xf32>
-  return %0 : vector<4x1xf32>
-}
-
-// CHECK-LABEL: func @transpose10_1xnx4xf32
-func.func @transpose10_1xnx4xf32(%arg0: vector<1x[4]xf32>) -> vector<[4]x1xf32> {
-  // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x[4]xf32> to vector<[4]x1xf32>
-  %0 = vector.transpose %arg0, [1, 0] : vector<1x[4]xf32> to vector<[4]x1xf32>
-  return %0 : vector<[4]x1xf32>
-}
-
 /// Scalable unit dim should not be lowered to shape_cast.
 
 // CHECK-LABEL: func @transpose10_4x1xf32_scalable
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 38771f2593449..ba47799729f1d 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1311,8 +1311,8 @@ func.func @vector_insert_2d_broadcast(%laneid: index) -> (vector<4x96xf32>) {
 //   CHECK-PROP-DAG:   %[[THREADID:.*]] = gpu.thread_id  x
 //       CHECK-PROP:   %[[W:.*]] = gpu.warp_execute_on_lane_0(%[[THREADID]])[32] args(%[[IN2]]
 //       CHECK-PROP:     %[[GATHER:.*]] = vector.gather %[[AR1]][{{.*}}]
-//       CHECK-PROP:     %[[EXTRACT:.*]] = vector.extract %[[GATHER]][0] : vector<64xi32> from vector<1x64xi32>
-//       CHECK-PROP:     %[[CAST:.*]] = arith.index_cast %[[EXTRACT]] : vector<64xi32> to vector<64xindex>
+//       CHECK-PROP:     %[[SHAPE_CAST:.*]] = vector.shape_cast %[[GATHER]] :  vector<1x64xi32> to vector<64xi32>
+//       CHECK-PROP:     %[[CAST:.*]] = arith.index_cast %[[SHAPE_CAST]] : vector<64xi32> to vector<64xindex>
 //       CHECK-PROP:     %[[EXTRACTELT:.*]] = vector.extract %[[CAST]][{{.*}}] : index from vector<64xindex>
 //       CHECK-PROP:     gpu.yield %[[EXTRACTELT]] : index
 //       CHECK-PROP:   %[[APPLY:.*]] = affine.apply #[[$MAP]]()[%[[THREADID]]]
@@ -1348,8 +1348,8 @@ func.func @transfer_read_prop_operands(%in2: vector<1x2xindex>, %ar1 :  memref<1
 // CHECK-PROP-LABEL: func @dont_fold_vector_broadcast(
 //       CHECK-PROP:   %[[r:.*]] = gpu.warp_execute_on_lane_0{{.*}} -> (vector<1x2xf32>)
 //       CHECK-PROP:     %[[some_def:.*]] = "some_def"
-//       CHECK-PROP:     %[[broadcast:.*]] = vector.broadcast %[[some_def]] : vector<64xf32> to vector<1x64xf32>
-//       CHECK-PROP:     gpu.yield %[[broadcast]] : vector<1x64xf32>
+//       CHECK-PROP:     %[[shape_cast:.*]] = vector.shape_cast %[[some_def]] : vector<64xf32> to vector<1x64xf32>
+//       CHECK-PROP:     gpu.yield %[[shape_cast]] : vector<1x64xf32>
 //       CHECK-PROP:   vector.print %[[r]] : vector<1x2xf32>
 func.func @dont_fold_vector_broadcast(%laneid: index) {
   %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x2xf32>) {

>From f2e541725b70527abcd3f3c03c0379082d329d3a Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Wed, 25 Jun 2025 16:04:08 -0700
Subject: [PATCH 2/6] simplifying tweaks

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 13 +-----------
 .../canonicalize/vector-shape-cast.mlir       | 20 ++++++++-----------
 2 files changed, 9 insertions(+), 24 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 244860530dca3..ab92c164876e6 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6481,18 +6481,7 @@ struct TransposeToShapeCast final
   LogicalResult matchAndRewrite(vector::TransposeOp transpose,
                                 PatternRewriter &rewriter) const override {
 
-    // This folder does
-    //    shape_cast(transpose) -> shape_cast
-    // But another pattern, ConvertIllegalShapeCastOpsToTransposes, does
-    //    shape_cast -> shape_cast(transpose)
-    // i.e. the complete opposite. When paired, these 2 patterns can cause
-    // infinite cycles in pattern rewriting.
-    // ConvertIllegalShapeCastOpsToTransposes only matches on scalable
-    // vectors, so by disabling this folder for scalable vectors the
-    // cycle is avoided.
-    // TODO: Check if ConvertIllegalShapeCastOpsToTransposes is
-    // still needed. If it's not, then we can fold here.
-    if (!isOrderPreserving(transpose) || transpose.getType().isScalable()) {
+    if (!isOrderPreserving(transpose)) {
       return rewriter.notifyMatchFailure(
           transpose, "not order preserving, so not semantically a 'copy'");
     }
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir
index 357df0f129a5e..342aebce4f522 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir
@@ -100,12 +100,10 @@ func.func @negative_transpose_to_shape_cast(%arg : vector<1x4x4x1xi8>) -> vector
 
 // -----
 
-// Currently the conversion shape_cast(transpose) -> shape_cast is disabled for
-// scalable vectors because of bad interaction with ConvertIllegalShapeCastOpsToTransposes
-// CHECK-LABEL: @negative_shape_cast_of_transpose_scalable
-//       CHECK:  vector.transpose
-//       CHECK:  vector.shape_cast
-func.func @negative_shape_cast_of_transpose_scalable(%arg : vector<[4]x1xi8>) -> vector<[4]xi8> {
+// CHECK-LABEL: @shape_cast_of_transpose_scalable
+//  CHECK-NEXT:  vector.shape_cast
+//  CHECK-NEXT:  return
+func.func @shape_cast_of_transpose_scalable(%arg : vector<[4]x1xi8>) -> vector<[4]xi8> {
   %0 = vector.transpose %arg, [1, 0] : vector<[4]x1xi8> to vector<1x[4]xi8>
   %1 = vector.shape_cast %0 : vector<1x[4]xi8> to vector<[4]xi8>
   return %1 : vector<[4]xi8>
@@ -113,12 +111,10 @@ func.func @negative_shape_cast_of_transpose_scalable(%arg : vector<[4]x1xi8>) ->
 
 // -----
 
-// The conversion transpose(shape_cast) -> shape_cast is currently disabled for scalable
-// vectors.
-// CHECK-LABEL: @negative_transpose_of_shape_cast_scalable
-//       CHECK: vector.shape_cast
-//       CHECK: vector.transpose
-func.func @negative_transpose_of_shape_cast_scalable(%arg : vector<[4]xi8>) -> vector<[4]x1xi8> {
+// CHECK-LABEL: @transpose_of_shape_cast_scalable
+//  CHECK-NEXT: vector.shape_cast
+//  CHECK-NEXT: return
+func.func @transpose_of_shape_cast_scalable(%arg : vector<[4]xi8>) -> vector<[4]x1xi8> {
   %0 = vector.shape_cast %arg : vector<[4]xi8> to vector<1x[4]xi8>
   %1 = vector.transpose %0, [1, 0] : vector<1x[4]xi8> to vector<[4]x1xi8>
   return %1 : vector<[4]x1xi8>

>From 24f7531a58e35970d2033ce68f146b35a7bbd123 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Wed, 25 Jun 2025 16:33:20 -0700
Subject: [PATCH 3/6] ArmSME fix

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 28 ++++++-------------
 .../Transforms/LowerVectorTranspose.cpp       |  3 --
 .../Dialect/ArmSME/vector-legalization.mlir   |  8 +++---
 .../Vector/canonicalize/vector-transpose.mlir |  2 ++
 4 files changed, 14 insertions(+), 27 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ab92c164876e6..83f302aff9d60 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2351,14 +2351,10 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
   return success();
 }
 
-/// For example,
-/// ```
+/// BEFORE:
 /// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
-/// ```
-/// becomes
-/// ```
+/// AFTER:
 /// %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
-/// ```
 struct ExtractToShapeCast final : public OpRewritePattern<vector::ExtractOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
@@ -2368,8 +2364,8 @@ struct ExtractToShapeCast final : public OpRewritePattern<vector::ExtractOp> {
     if (!outType)
       return failure();
 
-    // Negative values in `position` indicates poison, cannot convert to
-    // shape_cast
+    // Negative values in `position` indicate poison, which cannot be
+    // represented with a shape_cast
     if (llvm::any_of(extractOp.getMixedPosition(),
                      [](OpFoldResult v) { return !isConstantIntValue(v, 0); }))
       return failure();
@@ -2902,14 +2898,10 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
   }
 };
 
-/// For example,
-/// ```
+/// BEFORE:
 /// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
-/// ```
-/// becomes
-/// ```
+/// AFTER:
 /// %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8>
-/// ```
 struct BroadcastToShapeCast final
     : public OpRewritePattern<vector::BroadcastOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -6465,16 +6457,12 @@ static bool isOrderPreserving(TransposeOp transpose) {
   return true;
 }
 
-/// For example,
-/// ```
+/// BEFORE:
 /// %0 = vector.transpose %arg0, [0, 2, 1] :
 ///                   vector<2x1x2xf32> to vector<2x2x1xf32>
-/// ```
-/// becomes
-/// ```
+/// AFTER:
 /// %0 = vector.shape_cast %arg0 :
 ///                   vector<2x1x2xf32> to vector<2x2x1xf32>
-/// ```
 struct TransposeToShapeCast final
     : public OpRewritePattern<vector::TransposeOp> {
   using OpRewritePattern::OpRewritePattern;
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 4d6175fe6428b..71410eda28297 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -11,7 +11,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -382,7 +381,6 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
   vector::VectorTransposeLowering vectorTransposeLowering;
 };
 
-
 /// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
 /// If the strategy is Shuffle1D, it will be lowered to:
 ///   vector.shape_cast 2D -> 1D
@@ -454,7 +452,6 @@ class TransposeOp2DToShuffleLowering
 void mlir::vector::populateVectorTransposeLoweringPatterns(
     RewritePatternSet &patterns,
     VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit) {
-  BroadcastOp::getCanonicalizationPatterns(patterns, patterns.getContext());
   patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
       vectorTransposeLowering, patterns.getContext(), benefit);
 }
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index 6cdf576272ebc..a9a2fdccdd82f 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -480,11 +480,11 @@ func.func @lift_illegal_transpose_to_memory_with_in_bounds_attr(%a: index, %b: i
 
 // -----
 
-// The pass should do nothing (and not crash).
-// CHECK-LABEL: @illegal_transpose_no_defining_source_op
-func.func @illegal_transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32>
+// CHECK-LABEL: @transpose_no_defining_source_op
+func.func @transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32>
 {
-  // CHECK: vector.transpose
+  // CHECK:      vector.shape_cast
+  // CHECK-SAME: vector<[4]x1xf32> to vector<1x[4]xf32>
   %0 = vector.transpose %vec, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
   return %0 : vector<1x[4]xf32>
 }
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
index 8778669149e2c..f1e1c5e896c66 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
@@ -141,6 +141,8 @@ func.func @negative_broadcast_transpose_021(%arg0 : vector<3x1x3xi8>) -> vector<
   return %1 : vector<3x3x3xi8>
 }
 
+// -----
+
 /// +--------------------------------------------------------------------------
 ///  Tests of ShapeCastOp::fold:  shape_cast(transpose) -> shape_cast
 /// +--------------------------------------------------------------------------

>From e6735226358f110cf1de0760ac924507fc0f7e7c Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Thu, 26 Jun 2025 08:32:18 -0700
Subject: [PATCH 4/6] reinstate folders

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 79 +++++++++++++------
 mlir/test/Dialect/Vector/canonicalize.mlir    | 15 +++-
 .../canonicalize/vector-shape-cast.mlir       | 36 ++++-----
 .../drop-unit-dims-with-shape-cast.mlir       | 12 +++
 mlir/test/Dialect/Vector/single-fold.mlir     |  3 +-
 5 files changed, 99 insertions(+), 46 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 83f302aff9d60..b51ca3445820a 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5869,6 +5869,30 @@ LogicalResult ShapeCastOp::verify() {
   return success();
 }
 
+/// Return true if `transpose` does not permute a pair of non-unit dims.
+/// By `order preserving` we mean that the flattened versions of the input and
+/// output vectors are (numerically) identical. In other words `transpose` is
+/// effectively a shape cast.
+static bool isOrderPreserving(TransposeOp transpose) {
+  ArrayRef<int64_t> permutation = transpose.getPermutation();
+  VectorType sourceType = transpose.getSourceVectorType();
+  ArrayRef<int64_t> inShape = sourceType.getShape();
+  ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
+  auto isNonScalableUnitDim = [&](int64_t dim) {
+    return inShape[dim] == 1 && !inDimIsScalable[dim];
+  };
+  int64_t current = 0;
+  for (auto p : permutation) {
+    if (!isNonScalableUnitDim(p)) {
+      if (p < current) {
+        return false;
+      }
+      current = p;
+    }
+  }
+  return true;
+}
+
 OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
 
   VectorType resultType = getType();
@@ -5883,6 +5907,22 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
     return getResult();
   }
 
+  // shape_cast(transpose(x)) -> shape_cast(x)
+  if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
+    if (isOrderPreserving(transpose)) {
+      setOperand(transpose.getVector());
+      return getResult();
+    }
+    return {};
+  }
+
+  // Y = shape_cast(broadcast(X))
+  //      -> X, if X and Y have same type
+  if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
+    if (bcastOp.getSourceType() == resultType)
+      return bcastOp.getSource();
+  }
+
   // shape_cast(constant) -> constant
   if (auto splatAttr =
           llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
@@ -6219,6 +6259,21 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
   if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector()))
     return ub::PoisonAttr::get(getContext());
 
+  // Eliminate identity transposes, and more generally any transpose that
+  // preserves the shape without permuting elements.
+  //
+  // Examples of what to fold:
+  // %0 = vector.transpose %arg, [0, 1] : vector<1x1xi8> to vector<1x1xi8>
+  // %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8>
+  // %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8>
+  //
+  // Example of what NOT to fold:
+  //
+  // %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
+  if (getSourceVectorType() == getResultVectorType() &&
+      isOrderPreserving(*this))
+    return getVector();
+
   return {};
 }
 
@@ -6433,30 +6488,6 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
   }
 };
 
-/// Return true if `transpose` does not permute a pair of non-unit dims.
-/// By `order preserving` we mean that the flattened versions of the input and
-/// output vectors are (numerically) identical. In other words `transpose` is
-/// effectively a shape cast.
-static bool isOrderPreserving(TransposeOp transpose) {
-  ArrayRef<int64_t> permutation = transpose.getPermutation();
-  VectorType sourceType = transpose.getSourceVectorType();
-  ArrayRef<int64_t> inShape = sourceType.getShape();
-  ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
-  auto isNonScalableUnitDim = [&](int64_t dim) {
-    return inShape[dim] == 1 && !inDimIsScalable[dim];
-  };
-  int64_t current = 0;
-  for (auto p : permutation) {
-    if (!isNonScalableUnitDim(p)) {
-      if (p < current) {
-        return false;
-      }
-      current = p;
-    }
-  }
-  return true;
-}
-
 /// BEFORE:
 /// %0 = vector.transpose %arg0, [0, 2, 1] :
 ///                   vector<2x1x2xf32> to vector<2x2x1xf32>
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 306a34e41bd82..374c71c814e89 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -451,16 +451,25 @@ func.func @extract_strided_fold_insert(%a: vector<2x8xf32>, %b: vector<1x4xf32>,
 // -----
 
 // CHECK-LABEL: transpose_3D_identity
-// CHECK-SAME: ([[ARG:%.*]]: vector<4x3x2xf32>)
+//  CHECK-SAME: ([[ARG:%.*]]: vector<4x3x2xf32>)
+//  CHECK-NEXT: return [[ARG]]
 func.func @transpose_3D_identity(%arg : vector<4x3x2xf32>) -> vector<4x3x2xf32> {
-  // CHECK-NOT: transpose
   %0 = vector.transpose %arg, [0, 1, 2] : vector<4x3x2xf32> to vector<4x3x2xf32>
-  // CHECK-NEXT: return [[ARG]]
   return %0 : vector<4x3x2xf32>
 }
 
 // -----
 
+// CHECK-LABEL: transpose_0D_identity
+//  CHECK-SAME: ([[ARG:%.*]]: vector<i8>)
+//  CHECK-NEXT: return [[ARG]]
+func.func @transpose_0D_identity(%arg : vector<i8>) -> vector<i8> {
+  %0 = vector.transpose %arg, [] : vector<i8> to vector<i8>
+  return %0 : vector<i8>
+}
+
+// -----
+
 // CHECK-LABEL: transpose_2D_sequence
 // CHECK-SAME: ([[ARG:%.*]]: vector<4x3xf32>)
 func.func @transpose_2D_sequence(%arg : vector<4x3xf32>) -> vector<4x3xf32> {
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir
index 342aebce4f522..8aaf1783dd7e6 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir
@@ -6,9 +6,9 @@
 // +----------------------------------------
 
 // CHECK-LABEL: @broadcast_to_shape_cast
-//  CHECK-SAME:  %[[ARG0:.*]]: vector<4xi8>
-//  CHECK-NEXT:  %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
-//  CHECK-NEXT:  return %[[SCAST]] : vector<1x1x4xi8>
+//  CHECK-SAME: %[[ARG0:.*]]: vector<4xi8>
+//  CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
+//  CHECK-NEXT: return %[[SCAST]] : vector<1x1x4xi8>
 func.func @broadcast_to_shape_cast(%arg0 : vector<4xi8>) -> vector<1x1x4xi8> {
   %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
   return %0 : vector<1x1x4xi8>
@@ -49,9 +49,9 @@ func.func @negative_broadcast_scalar_to_shape_cast(%arg0 : i8) -> vector<1xi8> {
 // 2 -> 1
 // Because 0 < 1, this permutation is order preserving and effectively a shape_cast.
 // CHECK-LABEL: @transpose_to_shape_cast
-//  CHECK-SAME:  %[[ARG0:.*]]: vector<2x1x2xf32>
-//  CHECK-NEXT:  %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
-//  CHECK-NEXT:  return %[[SCAST]] : vector<2x2x1xf32>
+//  CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32>
+//  CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
+//  CHECK-NEXT: return %[[SCAST]] : vector<2x2x1xf32>
 func.func @transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> {
   %0 = vector.transpose %arg0, [0, 2, 1] : vector<2x1x2xf32> to vector<2x2x1xf32>
   return %0 : vector<2x2x1xf32>
@@ -64,10 +64,10 @@ func.func @transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf3
 // 2 -> 4
 // Because 0 < 4, this permutation is order preserving and effectively a shape_cast.
 // CHECK-LABEL: @shape_cast_of_transpose
-//  CHECK-SAME:   %[[ARG:.*]]: vector<1x4x4x1x1xi8>)
-//       CHECK:   %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
-//  CHECK-SAME:   vector<1x4x4x1x1xi8> to vector<4x1x1x1x4xi8>
-//       CHECK:   return %[[SHAPE_CAST]]
+//  CHECK-SAME: %[[ARG:.*]]: vector<1x4x4x1x1xi8>)
+//       CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
+//  CHECK-SAME: vector<1x4x4x1x1xi8> to vector<4x1x1x1x4xi8>
+//       CHECK: return %[[SHAPE_CAST]]
 func.func @shape_cast_of_transpose(%arg : vector<1x4x4x1x1xi8>) -> vector<4x1x1x1x4xi8> {
   %0 = vector.transpose %arg, [1, 0, 3, 4, 2]  : vector<1x4x4x1x1xi8> to vector<4x1x1x1x4xi8>
   return %0 : vector<4x1x1x1x4xi8>
@@ -101,8 +101,8 @@ func.func @negative_transpose_to_shape_cast(%arg : vector<1x4x4x1xi8>) -> vector
 // -----
 
 // CHECK-LABEL: @shape_cast_of_transpose_scalable
-//  CHECK-NEXT:  vector.shape_cast
-//  CHECK-NEXT:  return
+//  CHECK-NEXT: vector.shape_cast
+//  CHECK-NEXT: return
 func.func @shape_cast_of_transpose_scalable(%arg : vector<[4]x1xi8>) -> vector<[4]xi8> {
   %0 = vector.transpose %arg, [1, 0] : vector<[4]x1xi8> to vector<1x[4]xi8>
   %1 = vector.shape_cast %0 : vector<1x[4]xi8> to vector<[4]xi8>
@@ -125,9 +125,9 @@ func.func @transpose_of_shape_cast_scalable(%arg : vector<[4]xi8>) -> vector<[4]
 // A test where a transpose cannot be transformed to a shape_cast because it is not order
 // preserving
 // CHECK-LABEL: @negative_transpose_to_shape_cast
-//  CHECK-SAME:  %[[ARG0:.*]]: vector<2x1x2xf32>
-//  CHECK-NEXT:  %[[TRANSPOSE:.*]] = vector.transpose %[[ARG0]], [2, 0, 1]
-//  CHECK-NEXT:  return %[[TRANSPOSE]] : vector<2x2x1xf32>
+//  CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32>
+//  CHECK-NEXT: %[[TRANSPOSE:.*]] = vector.transpose %[[ARG0]], [2, 0, 1]
+//  CHECK-NEXT: return %[[TRANSPOSE]] : vector<2x2x1xf32>
 func.func @negative_transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> {
   %0 = vector.transpose %arg0, [2, 0, 1] : vector<2x1x2xf32> to vector<2x2x1xf32>
   return %0 : vector<2x2x1xf32>
@@ -140,9 +140,9 @@ func.func @negative_transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector
 // +----------------------------------------
 
 // CHECK-LABEL: @extract_to_shape_cast
-//  CHECK-SAME:  %[[ARG0:.*]]: vector<1x4xf32>
-//  CHECK-NEXT:  %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
-//  CHECK-NEXT:  return %[[SCAST]] : vector<4xf32>
+//  CHECK-SAME: %[[ARG0:.*]]: vector<1x4xf32>
+//  CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
+//  CHECK-NEXT: return %[[SCAST]] : vector<4xf32>
 func.func @extract_to_shape_cast(%arg0 : vector<1x4xf32>) -> vector<4xf32> {
   %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
   return %0 : vector<4xf32>
diff --git a/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir b/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir
index 44abe2ac46fce..34a155fbf2fc1 100644
--- a/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir
+++ b/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir
@@ -188,6 +188,18 @@ func.func @transpose_with_scalable_unit_dims(%vec: vector<[1]x1x2x4x1xf32>) -> v
 
 // -----
 
+func.func @transpose_with_all_unit_dims(%vec: vector<1x1x1xf32>) -> vector<1x1x1xf32> {
+  %res = vector.transpose %vec, [0, 2, 1] : vector<1x1x1xf32> to vector<1x1x1xf32>
+  return %res : vector<1x1x1xf32>
+}
+// The `vec` is returned because there are other flattening patterns that fold
+// vector.shape_cast ops away.
+// CHECK-LABEL: func.func @transpose_with_all_unit_dims
+// CHECK-SAME:      %[[VEC:.[a-zA-Z0-9]+]]
+// CHECK-NEXT:    return %[[VEC]]
+
+// -----
+
 func.func @negative_transpose_with_no_unit_dims(%vec: vector<4x2x3xf32>) -> vector<4x3x2xf32> {
   %res = vector.transpose %vec, [0, 2, 1] : vector<4x2x3xf32> to vector<4x3x2xf32>
   return %res : vector<4x3x2xf32>
diff --git a/mlir/test/Dialect/Vector/single-fold.mlir b/mlir/test/Dialect/Vector/single-fold.mlir
index baccdc3f51c05..866b1563699eb 100644
--- a/mlir/test/Dialect/Vector/single-fold.mlir
+++ b/mlir/test/Dialect/Vector/single-fold.mlir
@@ -35,4 +35,5 @@ func.func @fold_insert_in_single_pass() -> vector<2xf16> {
   // CHECK: arith.constant dense<[0.000000e+00, 2.500000e+00]> : vector<2xf16>
   %0 = vector.insert %c2, %cst [%c1] : f16 into vector<2xf16>
   return %0 : vector<2xf16>
-} 
\ No newline at end of file
+}
+

>From aa992927e8498f3f12d193f13274387ec132bf56 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Thu, 26 Jun 2025 08:49:34 -0700
Subject: [PATCH 5/6] minor tweaks

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp                 | 9 +++++----
 .../Dialect/Vector/canonicalize/vector-shape-cast.mlir   | 2 ++
 mlir/test/Dialect/Vector/single-fold.mlir                | 3 +--
 3 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index b51ca3445820a..08cc4af158e10 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6267,9 +6267,9 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
   // %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8>
   // %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8>
   //
-  // Example of what NOT to fold:
-  //
+  // Example of what not to fold:
   // %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
+  //
   if (getSourceVectorType() == getResultVectorType() &&
       isOrderPreserving(*this))
     return getVector();
@@ -6514,8 +6514,9 @@ struct TransposeToShapeCast final
 
 void vector::TransposeOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
-  results.add<FoldTransposeCreateMask, TransposeFolder, FoldTransposeSplat,
-              FoldTransposeBroadcast, TransposeToShapeCast>(context);
+  results.add<FoldTransposeBroadcast, FoldTransposeCreateMask,
+              FoldTransposeSplat, TransposeFolder, TransposeToShapeCast>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir
index 8aaf1783dd7e6..e249a6afcc993 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir
@@ -1,5 +1,7 @@
 // RUN: mlir-opt %s -split-input-file -canonicalize |  FileCheck %s
 
+// This file contains tests where a vector.shape_cast gets canonicalized, or where a
+// vector.shape_cast is the result of a canonicalization. Not all such tests must live in this file.
 
 // +----------------------------------------
 //  Tests of BroadcastToShapeCast
diff --git a/mlir/test/Dialect/Vector/single-fold.mlir b/mlir/test/Dialect/Vector/single-fold.mlir
index 866b1563699eb..baccdc3f51c05 100644
--- a/mlir/test/Dialect/Vector/single-fold.mlir
+++ b/mlir/test/Dialect/Vector/single-fold.mlir
@@ -35,5 +35,4 @@ func.func @fold_insert_in_single_pass() -> vector<2xf16> {
   // CHECK: arith.constant dense<[0.000000e+00, 2.500000e+00]> : vector<2xf16>
   %0 = vector.insert %c2, %cst [%c1] : f16 into vector<2xf16>
   return %0 : vector<2xf16>
-}
-
+} 
\ No newline at end of file

>From 7ad1802e8fcba780aa0b2d578e9b08012cf4cd29 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Thu, 26 Jun 2025 10:59:34 -0700
Subject: [PATCH 6/6] reintroduce removed tests

---
 .../Transforms/LowerVectorTranspose.cpp       |  2 +
 mlir/test/Dialect/Vector/canonicalize.mlir    | 24 +++++++++
 .../Vector/vector-transpose-lowering.mlir     | 51 +++++++++++++++++++
 3 files changed, 77 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 71410eda28297..3d12527b86283 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -452,6 +452,8 @@ class TransposeOp2DToShuffleLowering
 void mlir::vector::populateVectorTransposeLoweringPatterns(
     RewritePatternSet &patterns,
     VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit) {
+  TransposeOp::getCanonicalizationPatterns(patterns, patterns.getContext());
+  ShapeCastOp::getCanonicalizationPatterns(patterns, patterns.getContext());
   patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
       vectorTransposeLowering, patterns.getContext(), benefit);
 }
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 374c71c814e89..4f8c4525f2df3 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1043,6 +1043,30 @@ func.func @canonicalize_broadcast_shapecast_to_broadcast_scalar(%arg0: f32) -> v
 
 // -----
 
+// In this test, broadcast (2)->(1,2,1) is not legal, but shape_cast (2)->(1,2,1) is.
+// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_shapcast
+//   CHECK-NOT:   vector.broadcast
+//       CHECK:   vector.shape_cast {{.+}} : vector<2xf32> to vector<1x2x1xf32>
+func.func @canonicalize_broadcast_shapecast_to_shapcast(%arg0 : vector<2xf32>) -> vector<1x2x1xf32> {
+  %0 = vector.broadcast %arg0 : vector<2xf32> to vector<1x2xf32>
+  %1 = vector.shape_cast %0 : vector<1x2xf32> to vector<1x2x1xf32>
+  return %1 : vector<1x2x1xf32>
+}
+
+// -----
+
+// In this test, broadcast (1)->(1,1) and shape_cast (1)->(1,1) are both legal. shape_cast is chosen.
+// CHECK-LABEL: func @canonicalize_broadcast_shapecast_both_possible
+//   CHECK-NOT:   vector.broadcast
+//       CHECK:   vector.shape_cast {{.+}} : vector<1xf32> to vector<1x1xf32>
+func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>) -> vector<1x1xf32> {
+    %0 = vector.broadcast %arg0 : vector<1xf32> to vector<1x1x1xf32>
+    %1 = vector.shape_cast %0 : vector<1x1x1xf32> to vector<1x1xf32>
+    return %1 : vector<1x1xf32>
+}
+
+// -----
+
 // CHECK-LABEL: fold_vector_transfer_masks
 func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
   // CHECK: %[[C0:.+]] = arith.constant 0 : index
diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
index dc9350f706fcd..e5dfb0645eeb9 100644
--- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
@@ -21,6 +21,27 @@ func.func @transpose23(%arg0: vector<2x3xf32>) -> vector<3x2xf32> {
   return %0 : vector<3x2xf32>
 }
 
+// CHECK-LABEL: func @transpose102_1x8x8xf32
+func.func @transpose102_1x8x8xf32(%arg0: vector<1x8x8xf32>) -> vector<8x1x8xf32> {
+  //     CHECK: vector.shape_cast
+  %0 = vector.transpose %arg0, [1, 0, 2] : vector<1x8x8xf32> to vector<8x1x8xf32>
+  return %0 : vector<8x1x8xf32>
+}
+
+// CHECK-LABEL: func @transpose102_8x1x8xf32
+func.func @transpose102_8x1x8xf32(%arg0: vector<8x1x8xf32>) -> vector<1x8x8xf32> {
+  //     CHECK: vector.shape_cast
+  %0 = vector.transpose %arg0, [1, 0, 2] : vector<8x1x8xf32> to vector<1x8x8xf32>
+  return %0 : vector<1x8x8xf32>
+}
+
+// CHECK-LABEL: func @transpose1023_2x1x8x4xf32(
+func.func @transpose1023_2x1x8x4xf32(%arg0: vector<2x1x8x4xf32>) -> vector<1x2x8x4xf32> {
+  //     CHECK: vector.shape_cast
+  %0 = vector.transpose %arg0, [1, 0, 2, 3] : vector<2x1x8x4xf32> to vector<1x2x8x4xf32>
+  return %0 : vector<1x2x8x4xf32>
+}
+
 /// Scalable dim should not be unrolled.
 
 // CHECK-LABEL: func @transpose23_scalable
@@ -293,6 +314,36 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+/// Transpose of rank-2 vector with leading or trailing unit dim to shape_cast.
+
+// CHECK-LABEL: func @transpose10_4x1xf32
+func.func @transpose10_4x1xf32(%arg0: vector<4x1xf32>) -> vector<1x4xf32> {
+  // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<4x1xf32> to vector<1x4xf32>
+  %0 = vector.transpose %arg0, [1, 0] : vector<4x1xf32> to vector<1x4xf32>
+  return %0 : vector<1x4xf32>
+}
+
+// CHECK-LABEL: func @transpose10_nx4x1xf32
+func.func @transpose10_nx4x1xf32(%arg0: vector<[4]x1xf32>) -> vector<1x[4]xf32> {
+  // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<[4]x1xf32> to vector<1x[4]xf32>
+  %0 = vector.transpose %arg0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
+  return %0 : vector<1x[4]xf32>
+}
+
+// CHECK-LABEL: func @transpose10_1x4xf32
+func.func @transpose10_1x4xf32(%arg0: vector<1x4xf32>) -> vector<4x1xf32> {
+  // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4x1xf32>
+  %0 = vector.transpose %arg0, [1, 0] : vector<1x4xf32> to vector<4x1xf32>
+  return %0 : vector<4x1xf32>
+}
+
+// CHECK-LABEL: func @transpose10_1xnx4xf32
+func.func @transpose10_1xnx4xf32(%arg0: vector<1x[4]xf32>) -> vector<[4]x1xf32> {
+  // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x[4]xf32> to vector<[4]x1xf32>
+  %0 = vector.transpose %arg0, [1, 0] : vector<1x[4]xf32> to vector<[4]x1xf32>
+  return %0 : vector<[4]x1xf32>
+}
+
 /// Scalable unit dim should not be lowered to shape_cast.
 
 // CHECK-LABEL: func @transpose10_4x1xf32_scalable



More information about the Mlir-commits mailing list