[Mlir-commits] [mlir] [mlir][vector] Improve shape_cast lowering (PR #140800)

James Newling llvmlistbot at llvm.org
Tue Jun 3 14:31:51 PDT 2025


https://github.com/newling updated https://github.com/llvm/llvm-project/pull/140800

>From 49ebfa731b9168cad0946b6e43229ae0ce064bd7 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 20 May 2025 13:07:23 -0700
Subject: [PATCH 1/5] extract as large a chunk as possible in shape_cast
 lowering

---
 .../Transforms/LowerVectorShapeCast.cpp       | 80 +++++++++++--------
 ...vector-shape-cast-lowering-transforms.mlir | 51 ++++++++++++
 2 files changed, 99 insertions(+), 32 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
index 23324a007377e..d0085bffca23c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
@@ -28,17 +28,20 @@ using namespace mlir;
 using namespace mlir::vector;
 
 /// Increments n-D `indices` by `step` starting from the innermost dimension.
-static void incIdx(SmallVectorImpl<int64_t> &indices, VectorType vecType,
+static void incIdx(MutableArrayRef<int64_t> indices, ArrayRef<int64_t> shape,
                    int step = 1) {
   for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
-    assert(indices[dim] < vecType.getDimSize(dim) &&
-           "Indices are out of bound");
+    int64_t dimSize = shape[dim];
+    assert(indices[dim] < dimSize && "Indices are out of bound");
+
     indices[dim] += step;
-    if (indices[dim] < vecType.getDimSize(dim))
+
+    int64_t spill = indices[dim] / dimSize;
+    if (spill == 0)
       break;
 
-    indices[dim] = 0;
-    step = 1;
+    indices[dim] %= dimSize;
+    step = spill;
   }
 }
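
To illustrate the new spill behavior with a worked trace (the values here are chosen for illustration): with indices = {0, 2, 10}, shape = {5, 3, 11}, and step = 4, the innermost index becomes 14, which spills 14 / 11 = 1 and leaves 14 % 11 = 3; the middle index then becomes 2 + 1 = 3, which spills 1 and leaves 0; the outermost index becomes 1 with no spill. The final indices are {1, 0, 3}. Unlike the previous version, a single call can now carry spills larger than one across multiple dimensions.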
 
@@ -79,8 +82,8 @@ class ShapeCastOpNDDownCastRewritePattern
     // and destination slice insertion and generate such instructions.
     for (int64_t i = 0; i < numElts; ++i) {
       if (i != 0) {
-        incIdx(srcIdx, sourceVectorType, /*step=*/1);
-        incIdx(resIdx, resultVectorType, /*step=*/extractSize);
+        incIdx(srcIdx, sourceVectorType.getShape(), /*step=*/1);
+        incIdx(resIdx, resultVectorType.getShape(), /*step=*/extractSize);
       }
 
       Value extract =
@@ -131,8 +134,8 @@ class ShapeCastOpNDUpCastRewritePattern
     Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
     for (int64_t i = 0; i < numElts; ++i) {
       if (i != 0) {
-        incIdx(srcIdx, sourceVectorType, /*step=*/extractSize);
-        incIdx(resIdx, resultVectorType, /*step=*/1);
+        incIdx(srcIdx, sourceVectorType.getShape(), /*step=*/extractSize);
+        incIdx(resIdx, resultVectorType.getShape(), /*step=*/1);
       }
 
       Value extract = rewriter.create<vector::ExtractStridedSliceOp>(
@@ -157,41 +160,54 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    auto sourceVectorType = op.getSourceVectorType();
-    auto resultVectorType = op.getResultVectorType();
+    VectorType sourceType = op.getSourceVectorType();
+    VectorType resultType = op.getResultVectorType();
 
-    if (sourceVectorType.isScalable() || resultVectorType.isScalable())
+    if (sourceType.isScalable() || resultType.isScalable())
       return failure();
 
-    // Special case for n-D / 1-D lowerings with better implementations.
-    int64_t srcRank = sourceVectorType.getRank();
-    int64_t resRank = resultVectorType.getRank();
-    if ((srcRank > 1 && resRank == 1) || (srcRank == 1 && resRank > 1))
+    // Special case for n-D / 1-D lowerings with implementations that use
+    // extract_strided_slice / insert_strided_slice.
+    int64_t sourceRank = sourceType.getRank();
+    int64_t resultRank = resultType.getRank();
+    if ((sourceRank > 1 && resultRank == 1) ||
+        (sourceRank == 1 && resultRank > 1))
       return failure();
 
-    // Generic ShapeCast lowering path goes all the way down to unrolled scalar
-    // extract/insert chains.
-    int64_t numElts = 1;
-    for (int64_t r = 0; r < srcRank; r++)
-      numElts *= sourceVectorType.getDimSize(r);
+    int64_t numExtracts = sourceType.getNumElements();
+    int64_t nbCommonInnerDims = 0;
+    while (true) {
+      int64_t sourceDim = sourceRank - 1 - nbCommonInnerDims;
+      int64_t resultDim = resultRank - 1 - nbCommonInnerDims;
+      if (sourceDim < 0 || resultDim < 0)
+        break;
+      int64_t dimSize = sourceType.getDimSize(sourceDim);
+      if (dimSize != resultType.getDimSize(resultDim))
+        break;
+      numExtracts /= dimSize;
+      ++nbCommonInnerDims;
+    }
+
     // Replace with data movement operations:
     //    x[0,0,0] = y[0,0]
     //    x[0,0,1] = y[0,1]
     //    x[0,1,0] = y[0,2]
     // etc., incrementing the two index vectors "row-major"
     // within the source and result shape.
-    SmallVector<int64_t> srcIdx(srcRank, 0);
-    SmallVector<int64_t> resIdx(resRank, 0);
-    Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
-    for (int64_t i = 0; i < numElts; i++) {
+    SmallVector<int64_t> sourceIndex(sourceRank - nbCommonInnerDims, 0);
+    SmallVector<int64_t> resultIndex(resultRank - nbCommonInnerDims, 0);
+    Value result = rewriter.create<ub::PoisonOp>(loc, resultType);
+
+    for (int64_t i = 0; i < numExtracts; i++) {
       if (i != 0) {
-        incIdx(srcIdx, sourceVectorType);
-        incIdx(resIdx, resultVectorType);
+        incIdx(sourceIndex, sourceType.getShape().drop_back(nbCommonInnerDims));
+        incIdx(resultIndex, resultType.getShape().drop_back(nbCommonInnerDims));
       }
 
       Value extract =
-          rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
-      result = rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
+          rewriter.create<vector::ExtractOp>(loc, op.getSource(), sourceIndex);
+      result =
+          rewriter.create<vector::InsertOp>(loc, extract, result, resultIndex);
     }
     rewriter.replaceOp(op, result);
     return success();
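
As a concrete example of the loop above (the shapes come from the new tests below): casting vector<2x1x3xf32> to vector<2x3xf32>, the trailing sizes match once (3 == 3), so nbCommonInnerDims is 1 and numExtracts is 6 / 3 = 2; each iteration then moves a whole vector<3xf32> rather than a scalar.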
@@ -329,8 +345,8 @@ class ScalableShapeCastOpRewritePattern
 
       // 4. Increment the insert/extract indices, stepping by minExtractionSize
       // for the trailing dimensions.
-      incIdx(srcIdx, sourceVectorType, /*step=*/minExtractionSize);
-      incIdx(resIdx, resultVectorType, /*step=*/minExtractionSize);
+      incIdx(srcIdx, sourceVectorType.getShape(), /*step=*/minExtractionSize);
+      incIdx(resIdx, resultVectorType.getShape(), /*step=*/minExtractionSize);
     }
 
     rewriter.replaceOp(op, result);
diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
index ef32f8c6a1cdb..044a73d71e665 100644
--- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
@@ -140,6 +140,57 @@ func.func @shape_cast_1d0d(%arg0 : vector<1xf32>) -> vector<f32> {
   return %s : vector<f32>
 }
 
+
+// CHECK-LABEL:  func.func @shape_cast_squeeze_leading_one(
+// CHECK-SAME:               %[[ARG0:.*]]: vector<1x2x3xf32>) -> vector<2x3xf32> {
+// CHECK:          %[[EXTRACTED:.*]] = vector.extract %[[ARG0]][0] :
+// CHECK-SAME:               vector<2x3xf32> from vector<1x2x3xf32>
+// CHECK:          return %[[EXTRACTED]] : vector<2x3xf32>
+func.func @shape_cast_squeeze_leading_one(%arg0 : vector<1x2x3xf32>) -> vector<2x3xf32> {
+  %s = vector.shape_cast %arg0 : vector<1x2x3xf32> to vector<2x3xf32>
+  return %s : vector<2x3xf32>
+}
+
+// CHECK-LABEL:  func.func @shape_cast_squeeze_middle_one(
+// CHECK-SAME:               %[[ARG0:.*]]: vector<2x1x3xf32>) -> vector<2x3xf32> {
+// CHECK:         %[[UB:.*]] = ub.poison : vector<2x3xf32>
+// CHECK:         %[[E0:.*]] = vector.extract %[[ARG0]][0, 0] : vector<3xf32>
+// CHECK:         %[[I0:.*]] = vector.insert %[[E0]], %[[UB]] [0] : vector<3xf32>
+// CHECK-SAME:                 into vector<2x3xf32>
+// CHECK:         %[[E1:.*]] = vector.extract %[[ARG0]][1, 0] : vector<3xf32>
+// CHECK:         %[[I1:.*]] = vector.insert %[[E1]], %[[I0]] [1] : vector<3xf32>
+// CHECK-SAME:                 into vector<2x3xf32>
+// CHECK:         return %[[I1]] : vector<2x3xf32>
+func.func @shape_cast_squeeze_middle_one(%arg0 : vector<2x1x3xf32>) -> vector<2x3xf32> {
+  %s = vector.shape_cast %arg0 : vector<2x1x3xf32> to vector<2x3xf32>
+  return %s : vector<2x3xf32>
+}
+
+// CHECK-LABEL:  func.func @shape_cast_unsqueeze_leading_one(
+// CHECK-SAME:               %[[ARG0:.*]]: vector<2x3xf32>) -> vector<1x2x3xf32> {
+// CHECK:          %[[UB:.*]] = ub.poison : vector<1x2x3xf32>
+// CHECK:          %[[INSERTED:.*]] = vector.insert %[[ARG0]], %[[UB]] [0]
+// CHECK-SAME:               : vector<2x3xf32> into vector<1x2x3xf32>
+// CHECK:        return %[[INSERTED]] : vector<1x2x3xf32>
+func.func @shape_cast_unsqueeze_leading_one(%arg0 : vector<2x3xf32>) -> vector<1x2x3xf32> {
+  %s = vector.shape_cast %arg0 : vector<2x3xf32> to vector<1x2x3xf32>
+  return %s : vector<1x2x3xf32>
+}
+
+// CHECK-LABEL:  func.func @shape_cast_unsqueeze_middle_one(
+// CHECK-SAME:               %[[ARG0:.*]]: vector<2x3xf32>) -> vector<2x1x3xf32> {
+// CHECK:           %[[UB:.*]] = ub.poison : vector<2x1x3xf32>
+// CHECK:           %[[E0:.*]] = vector.extract %[[ARG0]][0] : vector<3xf32>
+// CHECK:           %[[I0:.*]] = vector.insert %[[E0]], %[[UB]] [0, 0] : vector<3xf32>
+// CHECK:           %[[E1:.*]] = vector.extract %[[ARG0]][1] : vector<3xf32>
+// CHECK:           %[[I1:.*]] = vector.insert %[[E1]], %[[I0]] [1, 0] : vector<3xf32>
+// CHECK:           return %[[I1]] : vector<2x1x3xf32>
+func.func @shape_cast_unsqueeze_middle_one(%arg0 : vector<2x3xf32>) -> vector<2x1x3xf32> {
+  %s = vector.shape_cast %arg0 : vector<2x3xf32> to vector<2x1x3xf32>
+  return %s : vector<2x1x3xf32>
+}
+
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %f = transform.structured.match ops{["func.func"]} in %module_op

>From 1a4b759e753c5649aceb0a12cfd65f06a3d35a82 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Thu, 22 May 2025 11:54:33 -0700
Subject: [PATCH 2/5] test name improvements

---
 .../vector-shape-cast-lowering-transforms.mlir | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
index 044a73d71e665..2875f159a2df9 100644
--- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
@@ -141,17 +141,19 @@ func.func @shape_cast_1d0d(%arg0 : vector<1xf32>) -> vector<f32> {
 }
 
 
-// CHECK-LABEL:  func.func @shape_cast_squeeze_leading_one(
+// The shapes have 2 inner dimension sizes in common, so the extract result is rank-2.
+// CHECK-LABEL:  func.func @squeeze_out_prefix_unit_dim(
 // CHECK-SAME:               %[[ARG0:.*]]: vector<1x2x3xf32>) -> vector<2x3xf32> {
 // CHECK:          %[[EXTRACTED:.*]] = vector.extract %[[ARG0]][0] :
 // CHECK-SAME:               vector<2x3xf32> from vector<1x2x3xf32>
 // CHECK:          return %[[EXTRACTED]] : vector<2x3xf32>
-func.func @shape_cast_squeeze_leading_one(%arg0 : vector<1x2x3xf32>) -> vector<2x3xf32> {
+func.func @squeeze_out_prefix_unit_dim(%arg0 : vector<1x2x3xf32>) -> vector<2x3xf32> {
   %s = vector.shape_cast %arg0 : vector<1x2x3xf32> to vector<2x3xf32>
   return %s : vector<2x3xf32>
 }
 
-// CHECK-LABEL:  func.func @shape_cast_squeeze_middle_one(
+// The shapes have 1 inner dimension size in common, so the extract results are rank-1.
+// CHECK-LABEL:  func.func @squeeze_out_middle_unit_dim(
 // CHECK-SAME:               %[[ARG0:.*]]: vector<2x1x3xf32>) -> vector<2x3xf32> {
 // CHECK:         %[[UB:.*]] = ub.poison : vector<2x3xf32>
 // CHECK:         %[[E0:.*]] = vector.extract %[[ARG0]][0, 0] : vector<3xf32>
@@ -161,23 +163,23 @@ func.func @shape_cast_squeeze_leading_one(%arg0 : vector<1x2x3xf32>) -> vector<2
 // CHECK:         %[[I1:.*]] = vector.insert %[[E1]], %[[I0]] [1] : vector<3xf32>
 // CHECK-SAME:                 into vector<2x3xf32>
 // CHECK:         return %[[I1]] : vector<2x3xf32>
-func.func @shape_cast_squeeze_middle_one(%arg0 : vector<2x1x3xf32>) -> vector<2x3xf32> {
+func.func @squeeze_out_middle_unit_dim(%arg0 : vector<2x1x3xf32>) -> vector<2x3xf32> {
   %s = vector.shape_cast %arg0 : vector<2x1x3xf32> to vector<2x3xf32>
   return %s : vector<2x3xf32>
 }
 
-// CHECK-LABEL:  func.func @shape_cast_unsqueeze_leading_one(
+// CHECK-LABEL:  func.func @prepend_unit_dim(
 // CHECK-SAME:               %[[ARG0:.*]]: vector<2x3xf32>) -> vector<1x2x3xf32> {
 // CHECK:          %[[UB:.*]] = ub.poison : vector<1x2x3xf32>
 // CHECK:          %[[INSERTED:.*]] = vector.insert %[[ARG0]], %[[UB]] [0]
 // CHECK-SAME:               : vector<2x3xf32> into vector<1x2x3xf32>
 // CHECK:        return %[[INSERTED]] : vector<1x2x3xf32>
-func.func @shape_cast_unsqueeze_leading_one(%arg0 : vector<2x3xf32>) -> vector<1x2x3xf32> {
+func.func @prepend_unit_dim(%arg0 : vector<2x3xf32>) -> vector<1x2x3xf32> {
   %s = vector.shape_cast %arg0 : vector<2x3xf32> to vector<1x2x3xf32>
   return %s : vector<1x2x3xf32>
 }
 
-// CHECK-LABEL:  func.func @shape_cast_unsqueeze_middle_one(
+// CHECK-LABEL:  func.func @insert_middle_unit_dim(
 // CHECK-SAME:               %[[ARG0:.*]]: vector<2x3xf32>) -> vector<2x1x3xf32> {
 // CHECK:           %[[UB:.*]] = ub.poison : vector<2x1x3xf32>
 // CHECK:           %[[E0:.*]] = vector.extract %[[ARG0]][0] : vector<3xf32>
@@ -185,7 +187,7 @@ func.func @shape_cast_unsqueeze_leading_one(%arg0 : vector<2x3xf32>) -> vector<1
 // CHECK:           %[[E1:.*]] = vector.extract %[[ARG0]][1] : vector<3xf32>
 // CHECK:           %[[I1:.*]] = vector.insert %[[E1]], %[[I0]] [1, 0] : vector<3xf32>
 // CHECK:           return %[[I1]] : vector<2x1x3xf32>
-func.func @shape_cast_unsqueeze_middle_one(%arg0 : vector<2x3xf32>) -> vector<2x1x3xf32> {
+func.func @insert_middle_unit_dim(%arg0 : vector<2x3xf32>) -> vector<2x1x3xf32> {
   %s = vector.shape_cast %arg0 : vector<2x3xf32> to vector<2x1x3xf32>
   return %s : vector<2x1x3xf32>
 }

>From 31ec6c6bf384da3235dc2c93f8c8bc151e8fb261 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Thu, 29 May 2025 09:23:43 -0700
Subject: [PATCH 3/5] unify the 3 shape cast lowering patterns, major refactor
 of testing

---
 .../Transforms/LowerVectorShapeCast.cpp       | 388 ++++++++-------
 ...vector-shape-cast-lowering-transforms.mlir | 456 +++++++++++++-----
 2 files changed, 550 insertions(+), 294 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
index d0085bffca23c..e84886e285ba9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
@@ -21,138 +21,103 @@
 #include "mlir/IR/Location.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
+#include <numeric>
 
 #define DEBUG_TYPE "vector-shape-cast-lowering"
 
 using namespace mlir;
-using namespace mlir::vector;
 
-/// Increments n-D `indices` by `step` starting from the innermost dimension.
-static void incIdx(MutableArrayRef<int64_t> indices, ArrayRef<int64_t> shape,
-                   int step = 1) {
-  for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
-    int64_t dimSize = shape[dim];
-    assert(indices[dim] < dimSize && "Indices are out of bound");
-
-    indices[dim] += step;
-
-    int64_t spill = indices[dim] / dimSize;
-    if (spill == 0)
+/// Perform the in-place update
+///    rhs <- lhs + rhs
+///
+/// where `rhs` is a number expressed in mixed base `base` with most significant
+/// dimensions on the left. For example if `rhs` is {a,b,c} and `base` is
+/// {5,3,2} then `rhs` has value a*3*2 + b*2 + c.
+///
+/// Some examples where `base` is {5,3,2}:
+/// rhs = {0,0,0}, lhs = 1  --> rhs = {0,0,1}
+/// rhs = {0,0,1}, lhs = 1  --> rhs = {0,1,0}
+/// rhs = {0,0,0}, lhs = 25 --> rhs = {4,0,1}
+///
+/// Invalid:
+/// rhs = {0,0,2}, lhs = 1 : rhs not in base {5,3,2}
+///
+/// Overflow wraps around (a carry out of the most significant dimension is
+/// silently dropped): rhs = {4,2,1}, lhs = 1 --> rhs = {0,0,0}
+static void inplaceAdd(int64_t lhs, ArrayRef<int64_t> base,
+                       MutableArrayRef<int64_t> rhs) {
+
+  // For dimensions in [rhs.size() - 1, ..., 3, 2, 1, 0]:
+  for (int dim : llvm::reverse(llvm::seq<int>(0, rhs.size()))) {
+    int64_t dimBase = base[dim];
+    assert(rhs[dim] < dimBase && "rhs not in base");
+
+    int64_t incremented = rhs[dim] + lhs;
+
+    // If the incremented value exceeds the dimension base, we must spill to the
+    // next most significant dimension and repeat (we might need to spill to
+    // more significant dimensions multiple times).
+    lhs = incremented / dimBase;
+    rhs[dim] = incremented % dimBase;
+    if (lhs == 0)
       break;
-
-    indices[dim] %= dimSize;
-    step = spill;
   }
 }
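
For illustration, here is a minimal standalone sketch of the same mixed-base addition, assuming only the C++ standard library (the in-tree version above uses ArrayRef / MutableArrayRef and llvm::reverse / llvm::seq; this re-implementation is for experimentation outside the tree):

  #include <cassert>
  #include <cstdint>
  #include <cstdio>
  #include <vector>

  // Add `lhs` to the mixed-base number `rhs` (most significant digit first),
  // spilling carries into more significant digits as needed.
  static void inplaceAdd(int64_t lhs, const std::vector<int64_t> &base,
                         std::vector<int64_t> &rhs) {
    for (int dim = static_cast<int>(rhs.size()) - 1; dim >= 0; --dim) {
      int64_t dimBase = base[dim];
      assert(rhs[dim] < dimBase && "rhs not in base");
      int64_t incremented = rhs[dim] + lhs;
      lhs = incremented / dimBase; // carry into the next dimension
      rhs[dim] = incremented % dimBase;
      if (lhs == 0)
        break;
    }
  }

  int main() {
    std::vector<int64_t> base{5, 3, 2};
    std::vector<int64_t> rhs{0, 0, 0};
    inplaceAdd(25, base, rhs);
    // Prints {4,0,1}, matching the example in the comment above.
    std::printf("{%lld,%lld,%lld}\n", (long long)rhs[0], (long long)rhs[1],
                (long long)rhs[2]);
    return 0;
  }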
 
 namespace {
-/// ShapeOp n-D -> 1-D downcast serves the purpose of flattening N-D to 1-D
-/// vectors progressively. This iterates over the n-1 major dimensions of the
-/// n-D vector and performs rewrites into:
-///   vector.extract from n-D + vector.insert_strided_slice offset into 1-D
-class ShapeCastOpNDDownCastRewritePattern
-    : public OpRewritePattern<vector::ShapeCastOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
-                                PatternRewriter &rewriter) const override {
-    auto sourceVectorType = op.getSourceVectorType();
-    auto resultVectorType = op.getResultVectorType();
-    if (sourceVectorType.isScalable() || resultVectorType.isScalable())
-      return failure();
-
-    int64_t srcRank = sourceVectorType.getRank();
-    int64_t resRank = resultVectorType.getRank();
-    if (srcRank < 2 || resRank != 1)
-      return failure();
-
-    // Compute the number of 1-D vector elements involved in the reshape.
-    int64_t numElts = 1;
-    for (int64_t dim = 0; dim < srcRank - 1; ++dim)
-      numElts *= sourceVectorType.getDimSize(dim);
-
-    auto loc = op.getLoc();
-    SmallVector<int64_t> srcIdx(srcRank - 1, 0);
-    SmallVector<int64_t> resIdx(resRank, 0);
-    int64_t extractSize = sourceVectorType.getShape().back();
-    Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
-
-    // Compute the indices of each 1-D vector element of the source extraction
-    // and destination slice insertion and generate such instructions.
-    for (int64_t i = 0; i < numElts; ++i) {
-      if (i != 0) {
-        incIdx(srcIdx, sourceVectorType.getShape(), /*step=*/1);
-        incIdx(resIdx, resultVectorType.getShape(), /*step=*/extractSize);
-      }
-
-      Value extract =
-          rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
-      result = rewriter.create<vector::InsertStridedSliceOp>(
-          loc, extract, result,
-          /*offsets=*/resIdx, /*strides=*/1);
-    }
-
-    rewriter.replaceOp(op, result);
-    return success();
-  }
-};
-
-/// ShapeOp 1-D -> n-D upcast serves the purpose of unflattening n-D from 1-D
-/// vectors progressively. This iterates over the n-1 major dimension of the n-D
-/// vector and performs rewrites into:
-///   vector.extract_strided_slice from 1-D + vector.insert into n-D
-/// Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle.
-class ShapeCastOpNDUpCastRewritePattern
-    : public OpRewritePattern<vector::ShapeCastOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
-                                PatternRewriter &rewriter) const override {
-    auto sourceVectorType = op.getSourceVectorType();
-    auto resultVectorType = op.getResultVectorType();
-    if (sourceVectorType.isScalable() || resultVectorType.isScalable())
-      return failure();
-
-    int64_t srcRank = sourceVectorType.getRank();
-    int64_t resRank = resultVectorType.getRank();
-    if (srcRank != 1 || resRank < 2)
-      return failure();
-
-    // Compute the number of 1-D vector elements involved in the reshape.
-    int64_t numElts = 1;
-    for (int64_t dim = 0; dim < resRank - 1; ++dim)
-      numElts *= resultVectorType.getDimSize(dim);
-
-    // Compute the indices of each 1-D vector element of the source slice
-    // extraction and destination insertion and generate such instructions.
-    auto loc = op.getLoc();
-    SmallVector<int64_t> srcIdx(srcRank, 0);
-    SmallVector<int64_t> resIdx(resRank - 1, 0);
-    int64_t extractSize = resultVectorType.getShape().back();
-    Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
-    for (int64_t i = 0; i < numElts; ++i) {
-      if (i != 0) {
-        incIdx(srcIdx, sourceVectorType.getShape(), /*step=*/extractSize);
-        incIdx(resIdx, resultVectorType.getShape(), /*step=*/1);
-      }
-
-      Value extract = rewriter.create<vector::ExtractStridedSliceOp>(
-          loc, op.getSource(), /*offsets=*/srcIdx, /*sizes=*/extractSize,
-          /*strides=*/1);
-      result = rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
-    }
-    rewriter.replaceOp(op, result);
-    return success();
-  }
-};
-
-// We typically should not lower general shape cast operations into data
-// movement instructions, since the assumption is that these casts are
-// optimized away during progressive lowering. For completeness, however,
-// we fall back to a reference implementation that moves all elements
-// into the right place if we get here.
+/// shape_cast is converted to a sequence of extract, extract_strided_slice,
+/// insert_strided_slice, and insert operations. The running example will be:
+///
+/// %0 = vector.shape_cast %arg0 :
+///         vector<2x2x3x4x7x11xi8> to vector<8x6x7x11xi8>
+///
+/// In this example the source and result shapes share a common suffix of 7x11.
+/// This means we can always decompose the shape_cast into extract, insert, and
+/// their strided equivalents, on vectors with shape suffix 7x11.
+///
+/// The greatest common divisor (gcd) of the first dimensions preceding the
+/// common suffix is gcd(4,6) = 2. The algorithm implemented here will operate
+/// on vectors with shapes that are `multiples` of (what we define as) the
+/// 'atomic size', 2x7x11. The atomic size is `gcd` x `common-suffix`.
+///
+///         vector<2x2x3x4x7x11xi8> to
+///             vector<8x6x7x11xi8>
+///                        ^^^^         ---> common suffix of 7x11
+///                      ^              --->    gcd(4,6) is 2 | |
+///                                                         | | |
+///                                                         v v v
+///                                  atomic size   <-----   2x7x11
+///
+///
+///
+/// The decomposition implemented in this pattern consists of a sequence of
+/// repeated steps:
+///
+///  (1) Extract vectors from the suffix of the source.
+///      In our example this is 2x2x3x4x7x11 -> 4x7x11.
+///
+///  (2) Do extract_strided_slice down to the atomic size.
+///      In our example this is 4x7x11 -> 2x7x11.
+///
+///  (3) Do insert_strided_slice to the suffix of the result.
+///      In our example this is 2x7x11 -> 6x7x11.
+///
+///  (4) Insert these vectors into the result vector.
+///      In our example this is 6x7x11 -> 8x6x7x11.
+///
+/// These steps occur with different periods. In this example
+///  (1) occurs 12 times,
+///  (2) and (3) occur 24 times, and
+///  (4) occurs 8 times.
+///
+/// Two special cases are handled separately:
+///  (1) A shape_cast that only inserts or removes leading unit dimensions.
+///  (2) A shape_cast where the gcd is 1.
+///
+/// In both of these cases, more compact IR is generated by bypassing the
+/// generic algorithm described above.
+///
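
As a sanity check on those period counts, using only numbers from the running example: the atomic slice 2x7x11 holds 2 * 7 * 11 = 154 elements and the source holds 2 * 2 * 3 * 4 * 7 * 11 = 3696, so there are 3696 / 154 = 24 atomic slices, i.e. 24 occurrences of steps (2) and (3); the extract period is 4 / 2 = 2, giving 24 / 2 = 12 occurrences of step (1); and the insert period is 6 / 2 = 3, giving 24 / 3 = 8 occurrences of step (4).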
 class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
@@ -164,50 +129,149 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
     VectorType resultType = op.getResultVectorType();
 
     if (sourceType.isScalable() || resultType.isScalable())
-      return failure();
+      return rewriter.notifyMatchFailure(
+          op, "shape_cast lowering not handled by this pattern");
+
+    const ArrayRef<int64_t> sourceShape = sourceType.getShape();
+    const ArrayRef<int64_t> resultShape = resultType.getShape();
+    const int64_t sourceRank = sourceType.getRank();
+    const int64_t resultRank = resultType.getRank();
+    const int64_t numElms = sourceType.getNumElements();
+    const Value source = op.getSource();
+
+    // Find the first dimension (counting from the end) in the source and in
+    // the result where the dimension sizes differ. Using the running example:
+    //
+    //  dimensions:  [0 1 2 3 4 5 ]    [0 1 2 3 ]
+    //  shapes:      (2,2,3,4,7,11) -> (8,6,7,11)
+    //                      ^             ^
+    //                      |             |
+    //        sourceSuffixStartDim is 3   |
+    //                                    |
+    //                               resultSuffixStartDim is 1
+    int64_t sourceSuffixStartDim = sourceRank - 1;
+    int64_t resultSuffixStartDim = resultRank - 1;
+    while (sourceSuffixStartDim >= 0 && resultSuffixStartDim >= 0 &&
+           (sourceType.getDimSize(sourceSuffixStartDim) ==
+            resultType.getDimSize(resultSuffixStartDim))) {
+      --sourceSuffixStartDim;
+      --resultSuffixStartDim;
+    }
 
-    // Special case for n-D / 1-D lowerings with implementations that use
-    // extract_strided_slice / insert_strided_slice.
-    int64_t sourceRank = sourceType.getRank();
-    int64_t resultRank = resultType.getRank();
-    if ((sourceRank > 1 && resultRank == 1) ||
-        (sourceRank == 1 && resultRank > 1))
-      return failure();
+    // This is the case where there are just some leading ones to contend with
+    // in the source or result. It can be handled with a single extract/insert
+    // pair.
+    if (resultSuffixStartDim < 0 || sourceSuffixStartDim < 0) {
+      const int64_t delta = sourceRank - resultRank;
+      const int64_t sourceLeading = delta > 0 ? delta : 0;
+      const int64_t resultLeading = delta > 0 ? 0 : -delta;
+      const Value poison = rewriter.create<ub::PoisonOp>(loc, resultType);
+      const Value extracted = rewriter.create<vector::ExtractOp>(
+          loc, source, SmallVector<int64_t>(sourceLeading, 0));
+      const Value result = rewriter.create<vector::InsertOp>(
+          loc, extracted, poison, SmallVector<int64_t>(resultLeading, 0));
+      rewriter.replaceOp(op, result);
+      return success();
+    }
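
For example, for the @prepend_unit_dim test below (vector<2x3xf32> to vector<1x2x3xf32>), the suffix scan runs off the front of the source shape, so delta = 2 - 3 = -1, sourceLeading = 0, and resultLeading = 1: the lowering is an extract of the whole source (which folds away) followed by a single vector.insert at index [0].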
 
-    int64_t numExtracts = sourceType.getNumElements();
-    int64_t nbCommonInnerDims = 0;
-    while (true) {
-      int64_t sourceDim = sourceRank - 1 - nbCommonInnerDims;
-      int64_t resultDim = resultRank - 1 - nbCommonInnerDims;
-      if (sourceDim < 0 || resultDim < 0)
-        break;
-      int64_t dimSize = sourceType.getDimSize(sourceDim);
-      if (dimSize != resultType.getDimSize(resultDim))
-        break;
-      numExtracts /= dimSize;
-      ++nbCommonInnerDims;
+    const int64_t sourceSuffixStartDimSize =
+        sourceType.getDimSize(sourceSuffixStartDim);
+    const int64_t resultSuffixStartDimSize =
+        resultType.getDimSize(resultSuffixStartDim);
+    const int64_t greatestCommonDivisor =
+        std::gcd(sourceSuffixStartDimSize, resultSuffixStartDimSize);
+    const int64_t stridedSliceRank = sourceRank - sourceSuffixStartDim;
+    const size_t extractPeriod =
+        sourceSuffixStartDimSize / greatestCommonDivisor;
+    const size_t insertPeriod =
+        resultSuffixStartDimSize / greatestCommonDivisor;
+
+    SmallVector<int64_t> atomicShape(sourceShape.begin() + sourceSuffixStartDim,
+                                     sourceShape.end());
+    atomicShape[0] = greatestCommonDivisor;
+
+    const int64_t numAtomicElms = std::accumulate(
+        atomicShape.begin(), atomicShape.end(), 1, std::multiplies<int64_t>());
+    const size_t nAtomicSlices = numElms / numAtomicElms;
+
+    // This is the case where the gcd, and therefore the strided dimension
+    // size, is 1. More compact IR is generated in this case by extracting
+    // and inserting elements directly, i.e. without extract_strided_slice
+    // or insert_strided_slice.
+    if (greatestCommonDivisor == 1) {
+      sourceSuffixStartDim += 1;
+      resultSuffixStartDim += 1;
+      SmallVector<int64_t> extractIndex(sourceSuffixStartDim, 0);
+      SmallVector<int64_t> insertIndex(resultSuffixStartDim, 0);
+      Value result = rewriter.create<ub::PoisonOp>(loc, resultType);
+      for (size_t i = 0; i < nAtomicSlices; ++i) {
+        Value extracted =
+            rewriter.create<vector::ExtractOp>(loc, source, extractIndex);
+
+        result = rewriter.create<vector::InsertOp>(loc, extracted, result,
+                                                   insertIndex);
+
+        inplaceAdd(1, sourceShape.take_front(sourceSuffixStartDim),
+                   extractIndex);
+        inplaceAdd(1, resultShape.take_front(resultSuffixStartDim),
+                   insertIndex);
+      }
+      rewriter.replaceOp(op, result);
+      return success();
     }
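
For example, for the @shape_cast_2d2d test below (vector<3x2xf32> to vector<2x3xf32>), the trailing sizes 2 and 3 have gcd 1, so the atomic slice is a single element and the loop emits 6 scalar extract/insert pairs.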
 
-    // Replace with data movement operations:
-    //    x[0,0,0] = y[0,0]
-    //    x[0,0,1] = y[0,1]
-    //    x[0,1,0] = y[0,2]
-    // etc., incrementing the two index vectors "row-major"
-    // within the source and result shape.
-    SmallVector<int64_t> sourceIndex(sourceRank - nbCommonInnerDims, 0);
-    SmallVector<int64_t> resultIndex(resultRank - nbCommonInnerDims, 0);
+    // The insert_strided_slice result's type
+    const ArrayRef<int64_t> insertStridedShape =
+        resultShape.drop_front(resultSuffixStartDim);
+    const VectorType insertStridedType =
+        VectorType::get(insertStridedShape, resultType.getElementType());
+
+    SmallVector<int64_t> extractIndex(sourceSuffixStartDim, 0);
+    SmallVector<int64_t> insertIndex(resultSuffixStartDim, 0);
+    SmallVector<int64_t> extractOffsets(stridedSliceRank, 0);
+    SmallVector<int64_t> insertOffsets(stridedSliceRank, 0);
+    const SmallVector<int64_t> sizes(stridedSliceRank, 1);
+
+    Value extracted = {};
+    Value extractedStrided = {};
+    Value insertedSlice = {};
     Value result = rewriter.create<ub::PoisonOp>(loc, resultType);
+    const Value partResult =
+        rewriter.create<ub::PoisonOp>(loc, insertStridedType);
+
+    for (size_t i = 0; i < nAtomicSlices; ++i) {
 
-    for (int64_t i = 0; i < numExtracts; i++) {
-      if (i != 0) {
-        incIdx(sourceIndex, sourceType.getShape().drop_back(nbCommonInnerDims));
-        incIdx(resultIndex, resultType.getShape().drop_back(nbCommonInnerDims));
+      const size_t extractStridedPhase = i % extractPeriod;
+      const size_t insertStridedPhase = i % insertPeriod;
+
+      // vector.extract
+      if (extractStridedPhase == 0) {
+        extracted =
+            rewriter.create<vector::ExtractOp>(loc, source, extractIndex);
+        inplaceAdd(1, sourceShape.take_front(sourceSuffixStartDim),
+                   extractIndex);
       }
 
-      Value extract =
-          rewriter.create<vector::ExtractOp>(loc, op.getSource(), sourceIndex);
-      result =
-          rewriter.create<vector::InsertOp>(loc, extract, result, resultIndex);
+      // vector.extract_strided_slice
+      extractOffsets[0] = extractStridedPhase * greatestCommonDivisor;
+      extractedStrided = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, extracted, extractOffsets, atomicShape, sizes);
+
+      // vector.insert_strided_slice
+      if (insertStridedPhase == 0) {
+        insertedSlice = partResult;
+      }
+      insertOffsets[0] = insertStridedPhase * greatestCommonDivisor;
+      insertedSlice = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, extractedStrided, insertedSlice, insertOffsets, sizes);
+
+      // vector.insert
+      if (insertStridedPhase + 1 == insertPeriod) {
+        result = rewriter.create<vector::InsertOp>(loc, insertedSlice, result,
+                                                   insertIndex);
+        inplaceAdd(1, resultType.getShape().take_front(resultSuffixStartDim),
+                   insertIndex);
+      }
     }
     rewriter.replaceOp(op, result);
     return success();
@@ -345,8 +409,8 @@ class ScalableShapeCastOpRewritePattern
 
       // 4. Increment the insert/extract indices, stepping by minExtractionSize
       // for the trailing dimensions.
-      incIdx(srcIdx, sourceVectorType.getShape(), /*step=*/minExtractionSize);
-      incIdx(resIdx, resultVectorType.getShape(), /*step=*/minExtractionSize);
+      inplaceAdd(minExtractionSize, sourceVectorType.getShape(), srcIdx);
+      inplaceAdd(minExtractionSize, resultVectorType.getShape(), resIdx);
     }
 
     rewriter.replaceOp(op, result);
@@ -363,8 +427,6 @@ class ScalableShapeCastOpRewritePattern
 
 void mlir::vector::populateVectorShapeCastLoweringPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<ShapeCastOpNDDownCastRewritePattern,
-               ShapeCastOpNDUpCastRewritePattern, ShapeCastOpRewritePattern,
-               ScalableShapeCastOpRewritePattern>(patterns.getContext(),
-                                                  benefit);
+  patterns.add<ShapeCastOpRewritePattern, ScalableShapeCastOpRewritePattern>(
+      patterns.getContext(), benefit);
 }
diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
index 2875f159a2df9..7b843750169b8 100644
--- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
@@ -1,197 +1,391 @@
 // RUN: mlir-opt %s --transform-interpreter | FileCheck %s
 
 // CHECK-LABEL: func @nop_shape_cast
-// CHECK-SAME: %[[A:.*]]: vector<16xf32>
-// CHECK:      return %[[A]] : vector<16xf32>
+//  CHECK-SAME: %[[A:.*]]: vector<16xf32>
+//       CHECK: return %[[A]] : vector<16xf32>
 func.func @nop_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> {
   %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<16xf32>
   return %0 : vector<16xf32>
 }
 
 // CHECK-LABEL: func @cancel_shape_cast
-// CHECK-SAME: %[[A:.*]]: vector<16xf32>
-// CHECK:      return %[[A]] : vector<16xf32>
-
+//  CHECK-SAME: %[[A:.*]]: vector<16xf32>
+//       CHECK: return %[[A]] : vector<16xf32>
 func.func @cancel_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> {
   %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32>
   %1 = vector.shape_cast %0 : vector<4x4xf32> to vector<16xf32>
   return %1 : vector<16xf32>
 }
 
-// Shape up and downcasts for 2-D vectors, for supporting conversion to
-// llvm.matrix operations
-// CHECK-LABEL: func @shape_casts
-func.func @shape_casts(%a: vector<2x2xf32>) -> (vector<4xf32>, vector<2x2xf32>) {
-  // CHECK-DAG: %[[ub22:.*]] = ub.poison : vector<2x2xf32>
-  // CHECK-DAG: %[[ub:.*]] = ub.poison : vector<4xf32>
-  // CHECK: %[[ex0:.*]] = vector.extract %{{.*}}[0] : vector<2xf32> from vector<2x2xf32>
-  //
-  // CHECK: %[[in0:.*]] = vector.insert_strided_slice %[[ex0]], %[[ub]]
-  // CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
-  //
-  // CHECK: %[[ex1:.*]] = vector.extract %{{.*}}[1] : vector<2xf32> from vector<2x2xf32>
-  //
-  // CHECK: %[[in2:.*]] = vector.insert_strided_slice %[[ex1]], %[[in0]]
-  // CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
-  //
+// Collapse 2-D to 1-D.
+// CHECK-LABEL: func @shape_cast_2d1d
+//  CHECK-SAME: %[[A:.*]]: vector<2x2xf32>) -> vector<4xf32> {
+//       CHECK: %[[UB:.*]] = ub.poison : vector<4xf32>
+//
+//       CHECK: %[[EX0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
+//       CHECK: %[[IN0:.*]] = vector.insert_strided_slice %[[EX0]], %[[UB]]
+//  CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
+//
+//       CHECK: %[[EX1:.*]] = vector.extract %{{.*}}[1] : vector<2xf32> from vector<2x2xf32>
+//       CHECK: %[[IN2:.*]] = vector.insert_strided_slice %[[EX1]], %[[IN0]]
+//  CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
+//       CHECK: return %[[IN2]] : vector<4xf32>
+func.func @shape_cast_2d1d(%a: vector<2x2xf32>) -> (vector<4xf32>) {
   %0 = vector.shape_cast %a : vector<2x2xf32> to vector<4xf32>
-  // CHECK: %[[add:.*]] = arith.addf %[[in2]], %[[in2]] : vector<4xf32>
-  %r0 = arith.addf %0, %0: vector<4xf32>
-  //
-  // CHECK: %[[ss0:.*]] = vector.extract_strided_slice %[[add]]
-  // CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} :
-  // CHECK-SAME: vector<4xf32> to vector<2xf32>
-  //
-  // CHECK: %[[res0:.*]] = vector.insert %[[ss0]], %[[ub22]] [0] :
-  // CHECK-SAME: vector<2xf32> into vector<2x2xf32>
-  //
-  // CHECK: %[[s2:.*]] = vector.extract_strided_slice %[[add]]
-  // CHECK-SAME: {offsets = [2], sizes = [2], strides = [1]} :
-  // CHECK-SAME: vector<4xf32> to vector<2xf32>
-  //
-  // CHECK: %[[res1:.*]] = vector.insert %[[s2]], %[[res0]] [1] :
-  // CHECK-SAME: vector<2xf32> into vector<2x2xf32>
-  //
-  %1 = vector.shape_cast %r0  : vector<4xf32> to vector<2x2xf32>
-  // CHECK: return %[[add]], %[[res1]] : vector<4xf32>, vector<2x2xf32>
-  return %r0, %1 : vector<4xf32>, vector<2x2xf32>
-}
-
-// CHECK-LABEL: func @shape_cast_2d2d
-// CHECK-SAME: %[[A:.*]]: vector<3x2xf32>
-// CHECK: %[[UB:.*]] = ub.poison : vector<2x3xf32>
-// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : f32 from vector<3x2xf32>
-// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[UB]] [0, 0] : f32 into vector<2x3xf32>
-// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : f32 from vector<3x2xf32>
-// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] : f32 into vector<2x3xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : f32 from vector<3x2xf32>
-// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [0, 2] : f32 into vector<2x3xf32>
-// CHECK: %[[T6:.*]] = vector.extract %[[A]][1, 1] : f32 from vector<3x2xf32>
-// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0] : f32 into vector<2x3xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : f32 from vector<3x2xf32>
-// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 1] : f32 into vector<2x3xf32>
-// CHECK: %[[T10:.*]] = vector.extract %[[A]][2, 1] : f32 from vector<3x2xf32>
-// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 2] : f32 into vector<2x3xf32>
-// CHECK: return %[[T11]] : vector<2x3xf32>
-
-func.func @shape_cast_2d2d(%arg0 : vector<3x2xf32>) -> vector<2x3xf32> {
-  %s = vector.shape_cast %arg0: vector<3x2xf32> to vector<2x3xf32>
-  return %s : vector<2x3xf32>
+  return %0 : vector<4xf32>
 }
 
+// Collapse 3-D to 1-D.
 // CHECK-LABEL: func @shape_cast_3d1d
-// CHECK-SAME: %[[A:.*]]: vector<1x3x2xf32>
-// CHECK: %[[UB:.*]] = ub.poison : vector<6xf32>
-// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<2xf32> from vector<1x3x2xf32>
-// CHECK: %[[T1:.*]] = vector.insert_strided_slice %[[T0]], %[[UB]]
-// CHECK-SAME:           {offsets = [0], strides = [1]} : vector<2xf32> into vector<6xf32>
-// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : vector<2xf32> from vector<1x3x2xf32>
-// CHECK: %[[T3:.*]] = vector.insert_strided_slice %[[T2]], %[[T1]]
-// CHECK-SAME:           {offsets = [2], strides = [1]} : vector<2xf32> into vector<6xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[A]][0, 2] : vector<2xf32> from vector<1x3x2xf32>
-// CHECK: %[[T5:.*]] = vector.insert_strided_slice %[[T4]], %[[T3]]
-// CHECK-SAME:           {offsets = [4], strides = [1]} : vector<2xf32> into vector<6xf32>
-// CHECK: return %[[T5]] : vector<6xf32>
-
+//  CHECK-SAME: %[[A:.*]]: vector<1x3x2xf32>
+//       CHECK: %[[UB:.*]] = ub.poison : vector<6xf32>
+//
+//       CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<2xf32> from vector<1x3x2xf32>
+//       CHECK: %[[T1:.*]] = vector.insert_strided_slice %[[T0]], %[[UB]]
+//  CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xf32> into vector<6xf32>
+//
+//       CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : vector<2xf32> from vector<1x3x2xf32>
+//       CHECK: %[[T3:.*]] = vector.insert_strided_slice %[[T2]], %[[T1]]
+//  CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xf32> into vector<6xf32>
+//
+//       CHECK: %[[T4:.*]] = vector.extract %[[A]][0, 2] : vector<2xf32> from vector<1x3x2xf32>
+//       CHECK: %[[T5:.*]] = vector.insert_strided_slice %[[T4]], %[[T3]]
+//  CHECK-SAME: {offsets = [4], strides = [1]} : vector<2xf32> into vector<6xf32>
+//       CHECK: return %[[T5]] : vector<6xf32>
 func.func @shape_cast_3d1d(%arg0 : vector<1x3x2xf32>) -> vector<6xf32> {
   %s = vector.shape_cast %arg0 : vector<1x3x2xf32> to vector<6xf32>
   return %s : vector<6xf32>
 }
 
-// CHECK-LABEL: func @shape_cast_1d3d
-// CHECK-SAME: %[[A:.*]]: vector<6xf32>
-// CHECK: %[[UB:.*]] = ub.poison : vector<2x1x3xf32>
-// CHECK: %[[T0:.*]] = vector.extract_strided_slice %[[A]]
-// CHECK-SAME:           {offsets = [0], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
-// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[UB]] [0, 0] : vector<3xf32> into vector<2x1x3xf32>
-// CHECK: %[[T2:.*]] = vector.extract_strided_slice %[[A]]
-// CHECK:                {offsets = [3], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
-// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1, 0] : vector<3xf32> into vector<2x1x3xf32>
-// CHECK: return %[[T3]] : vector<2x1x3xf32>
+// Expand 1-D to 2-D.
+// CHECK-LABEL: func.func @shape_cast_1d2d(
+//  CHECK-SAME: %[[A:.*]]: vector<4xf32>) -> vector<2x2xf32> {
+//       CHECK: %[[UB:.*]] = ub.poison : vector<2x2xf32>
+//
+//       CHECK: %[[SS0:.*]] = vector.extract_strided_slice %[[A]]
+//  CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} :
+//  CHECK-SAME: vector<4xf32> to vector<2xf32>
+//       CHECK: %[[res0:.*]] = vector.insert %[[SS0]], %[[UB]] [0] :
+//  CHECK-SAME: vector<2xf32> into vector<2x2xf32>
+//
+//       CHECK: %[[SS2:.*]] = vector.extract_strided_slice %[[A]]
+//  CHECK-SAME: {offsets = [2], sizes = [2], strides = [1]} :
+//  CHECK-SAME: vector<4xf32> to vector<2xf32>
+//       CHECK: %[[res1:.*]] = vector.insert %[[SS2]], %[[res0]] [1] :
+//  CHECK-SAME: vector<2xf32> into vector<2x2xf32>
+//       CHECK: return %[[res1]] : vector<2x2xf32>
+func.func @shape_cast_1d2d(%a: vector<4xf32>) -> (vector<2x2xf32>) {
+  %1 = vector.shape_cast %a: vector<4xf32> to vector<2x2xf32>
+  return %1 : vector<2x2xf32>
+}
 
+// Expand 1-D to 3-D.
+// CHECK-LABEL: func @shape_cast_1d3d
+//  CHECK-SAME: %[[A:.*]]: vector<6xf32>
+//       CHECK: %[[UB:.*]] = ub.poison : vector<2x1x3xf32>
+//
+//       CHECK: %[[T0:.*]] = vector.extract_strided_slice %[[A]]
+//  CHECK-SAME: {offsets = [0], sizes = [3], strides = [1]} :
+//  CHECK-SAME: vector<6xf32> to vector<3xf32>
+//       CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[UB]] [0, 0] :
+//  CHECK-SAME: vector<3xf32> into vector<2x1x3xf32>
+//
+//       CHECK: %[[T2:.*]] = vector.extract_strided_slice %[[A]]
+//  CHECK-SAME: {offsets = [3], sizes = [3], strides = [1]} :
+//  CHECK-SAME: vector<6xf32> to vector<3xf32>
+//       CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1, 0] :
+//  CHECK-SAME: vector<3xf32> into vector<2x1x3xf32>
+//       CHECK: return %[[T3]] : vector<2x1x3xf32>
 func.func @shape_cast_1d3d(%arg0 : vector<6xf32>) -> vector<2x1x3xf32> {
   %s = vector.shape_cast %arg0 : vector<6xf32> to vector<2x1x3xf32>
   return %s : vector<2x1x3xf32>
 }
 
-// CHECK-LABEL:   func.func @shape_cast_0d1d(
-// CHECK-SAME:                               %[[ARG0:.*]]: vector<f32>) -> vector<1xf32> {
-// CHECK:           %[[UB:.*]] = ub.poison : vector<1xf32>
-// CHECK:           %[[EXTRACT0:.*]] = vector.extract %[[ARG0]][] : f32 from vector<f32>
-// CHECK:           %[[RES:.*]] = vector.insert %[[EXTRACT0]], %[[UB]] [0] : f32 into vector<1xf32>
-// CHECK:           return %[[RES]] : vector<1xf32>
-// CHECK:         }
+// 2-D to 2-D where the inner-most dimension sizes share no common factor
+// greater than 1, requiring scalar element-by-element extraction and insertion.
+// CHECK-LABEL: func @shape_cast_2d2d
+//  CHECK-SAME: %[[A:.*]]: vector<3x2xf32>
+//       CHECK: %[[UB:.*]] = ub.poison : vector<2x3xf32>
+//
+//       CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : f32 from vector<3x2xf32>
+//       CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[UB]] [0, 0] :
+//
+//       CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : f32 from vector<3x2xf32>
+//       CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] :
+//
+//       CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : f32 from vector<3x2xf32>
+//       CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [0, 2] :
+//
+//       CHECK: %[[T6:.*]] = vector.extract %[[A]][1, 1] : f32 from vector<3x2xf32>
+//       CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0] :
+//
+//       CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : f32 from vector<3x2xf32>
+//       CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 1] :
+//
+//       CHECK: %[[T10:.*]] = vector.extract %[[A]][2, 1] : f32 from vector<3x2xf32>
+//       CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 2] :
+//
+//       CHECK: return %[[T11]] : vector<2x3xf32>
+func.func @shape_cast_2d2d(%arg0 : vector<3x2xf32>) -> vector<2x3xf32> {
+  %s = vector.shape_cast %arg0: vector<3x2xf32> to vector<2x3xf32>
+  return %s : vector<2x3xf32>
+}
 
+// CHECK-LABEL: func.func @shape_cast_0d1d(
+//  CHECK-SAME: %[[A:.*]]: vector<f32>) -> vector<1xf32> {
+//       CHECK: %[[UB:.*]] = ub.poison : vector<1xf32>
+//
+//       CHECK: %[[EXTRACT0:.*]] = vector.extract %[[A]][] : f32 from vector<f32>
+//       CHECK: %[[RES:.*]] = vector.insert %[[EXTRACT0]], %[[UB]] [0] :
+//       CHECK: return %[[RES]] : vector<1xf32>
 func.func @shape_cast_0d1d(%arg0 : vector<f32>) -> vector<1xf32> {
   %s = vector.shape_cast %arg0 : vector<f32> to vector<1xf32>
   return %s : vector<1xf32>
 }
 
-// CHECK-LABEL:   func.func @shape_cast_1d0d(
-// CHECK-SAME:                               %[[ARG0:.*]]: vector<1xf32>) -> vector<f32> {
-// CHECK:           %[[UB:.*]] = ub.poison : vector<f32>
-// CHECK:           %[[EXTRACT0:.*]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32>
-// CHECK:           %[[RES:.*]] = vector.insert %[[EXTRACT0]], %[[UB]] [] : f32 into vector<f32>
-// CHECK:           return %[[RES]] : vector<f32>
-// CHECK:         }
-
+// CHECK-LABEL: func.func @shape_cast_1d0d(
+//  CHECK-SAME: %[[A:.*]]: vector<1xf32>) -> vector<f32> {
+//       CHECK: %[[UB:.*]] = ub.poison : vector<f32>
+//
+//       CHECK: %[[EXTRACT0:.*]] = vector.extract %[[A]][0] : f32 from vector<1xf32>
+//       CHECK: %[[RES:.*]] = vector.insert %[[EXTRACT0]], %[[UB]] [] :
+//       CHECK: return %[[RES]] : vector<f32>
 func.func @shape_cast_1d0d(%arg0 : vector<1xf32>) -> vector<f32> {
   %s = vector.shape_cast %arg0 : vector<1xf32> to vector<f32>
   return %s : vector<f32>
 }
 
-
 // The shapes have 2 inner dimension sizes in common, so the extract result is rank-2.
-// CHECK-LABEL:  func.func @squeeze_out_prefix_unit_dim(
-// CHECK-SAME:               %[[ARG0:.*]]: vector<1x2x3xf32>) -> vector<2x3xf32> {
-// CHECK:          %[[EXTRACTED:.*]] = vector.extract %[[ARG0]][0] :
-// CHECK-SAME:               vector<2x3xf32> from vector<1x2x3xf32>
-// CHECK:          return %[[EXTRACTED]] : vector<2x3xf32>
+// CHECK-LABEL: func.func @squeeze_out_prefix_unit_dim(
+//  CHECK-SAME: %[[A:.*]]: vector<1x2x3xf32>) -> vector<2x3xf32> {
+//
+//       CHECK: %[[EXTRACTED:.*]] = vector.extract %[[A]][0] :
+//  CHECK-SAME: vector<2x3xf32> from vector<1x2x3xf32>
+//       CHECK: return %[[EXTRACTED]] : vector<2x3xf32>
 func.func @squeeze_out_prefix_unit_dim(%arg0 : vector<1x2x3xf32>) -> vector<2x3xf32> {
   %s = vector.shape_cast %arg0 : vector<1x2x3xf32> to vector<2x3xf32>
   return %s : vector<2x3xf32>
 }
 
 // The shapes have 1 inner dimension size in common, so the extract results are rank-1.
-// CHECK-LABEL:  func.func @squeeze_out_middle_unit_dim(
-// CHECK-SAME:               %[[ARG0:.*]]: vector<2x1x3xf32>) -> vector<2x3xf32> {
-// CHECK:         %[[UB:.*]] = ub.poison : vector<2x3xf32>
-// CHECK:         %[[E0:.*]] = vector.extract %[[ARG0]][0, 0] : vector<3xf32>
-// CHECK:         %[[I0:.*]] = vector.insert %[[E0]], %[[UB]] [0] : vector<3xf32>
-// CHECK-SAME:                 into vector<2x3xf32>
-// CHECK:         %[[E1:.*]] = vector.extract %[[ARG0]][1, 0] : vector<3xf32>
-// CHECK:         %[[I1:.*]] = vector.insert %[[E1]], %[[I0]] [1] : vector<3xf32>
-// CHECK-SAME:                 into vector<2x3xf32>
-// CHECK:         return %[[I1]] : vector<2x3xf32>
+// CHECK-LABEL: func.func @squeeze_out_middle_unit_dim(
+//  CHECK-SAME: %[[A:.*]]: vector<2x1x3xf32>) -> vector<2x3xf32> {
+//       CHECK: %[[UB:.*]] = ub.poison : vector<2x3xf32>
+//
+//       CHECK: %[[E0:.*]] = vector.extract %[[A]][0, 0] : vector<3xf32>
+//       CHECK: %[[I0:.*]] = vector.insert %[[E0]], %[[UB]] [0] :
+//
+//       CHECK: %[[E1:.*]] = vector.extract %[[A]][1, 0] : vector<3xf32>
+//       CHECK: %[[I1:.*]] = vector.insert %[[E1]], %[[I0]] [1] :
+//       CHECK: return %[[I1]] : vector<2x3xf32>
 func.func @squeeze_out_middle_unit_dim(%arg0 : vector<2x1x3xf32>) -> vector<2x3xf32> {
   %s = vector.shape_cast %arg0 : vector<2x1x3xf32> to vector<2x3xf32>
   return %s : vector<2x3xf32>
 }
 
-// CHECK-LABEL:  func.func @prepend_unit_dim(
-// CHECK-SAME:               %[[ARG0:.*]]: vector<2x3xf32>) -> vector<1x2x3xf32> {
-// CHECK:          %[[UB:.*]] = ub.poison : vector<1x2x3xf32>
-// CHECK:          %[[INSERTED:.*]] = vector.insert %[[ARG0]], %[[UB]] [0]
-// CHECK-SAME:               : vector<2x3xf32> into vector<1x2x3xf32>
-// CHECK:        return %[[INSERTED]] : vector<1x2x3xf32>
+// CHECK-LABEL: func.func @prepend_unit_dim(
+//  CHECK-SAME: %[[A:.*]]: vector<2x3xf32>) -> vector<1x2x3xf32> {
+//       CHECK: %[[UB:.*]] = ub.poison : vector<1x2x3xf32>
+//
+//       CHECK: %[[I0:.*]] = vector.insert %[[A]], %[[UB]] [0]
+//       CHECK: return %[[I0]] : vector<1x2x3xf32>
 func.func @prepend_unit_dim(%arg0 : vector<2x3xf32>) -> vector<1x2x3xf32> {
   %s = vector.shape_cast %arg0 : vector<2x3xf32> to vector<1x2x3xf32>
   return %s : vector<1x2x3xf32>
 }
 
-// CHECK-LABEL:  func.func @insert_middle_unit_dim(
-// CHECK-SAME:               %[[ARG0:.*]]: vector<2x3xf32>) -> vector<2x1x3xf32> {
-// CHECK:           %[[UB:.*]] = ub.poison : vector<2x1x3xf32>
-// CHECK:           %[[E0:.*]] = vector.extract %[[ARG0]][0] : vector<3xf32>
-// CHECK:           %[[I0:.*]] = vector.insert %[[E0]], %[[UB]] [0, 0] : vector<3xf32>
-// CHECK:           %[[E1:.*]] = vector.extract %[[ARG0]][1] : vector<3xf32>
-// CHECK:           %[[I1:.*]] = vector.insert %[[E1]], %[[I0]] [1, 0] : vector<3xf32>
-// CHECK:           return %[[I1]] : vector<2x1x3xf32>
+// CHECK-LABEL: func.func @insert_middle_unit_dim(
+//  CHECK-SAME: %[[A:.*]]: vector<2x3xf32>) -> vector<2x1x3xf32> {
+//       CHECK: %[[UB:.*]] = ub.poison : vector<2x1x3xf32>
+//
+//       CHECK: %[[E0:.*]] = vector.extract %[[A]][0] : vector<3xf32>
+//       CHECK: %[[I0:.*]] = vector.insert %[[E0]], %[[UB]] [0, 0] : vector<3xf32>
+//
+//       CHECK: %[[E1:.*]] = vector.extract %[[A]][1] : vector<3xf32>
+//       CHECK: %[[I1:.*]] = vector.insert %[[E1]], %[[I0]] [1, 0] : vector<3xf32>
+//       CHECK: return %[[I1]] : vector<2x1x3xf32>
 func.func @insert_middle_unit_dim(%arg0 : vector<2x3xf32>) -> vector<2x1x3xf32> {
   %s = vector.shape_cast %arg0 : vector<2x3xf32> to vector<2x1x3xf32>
   return %s : vector<2x1x3xf32>
 }
 
+// CHECK-LABEL: func.func @postpend_unit_dims(
+//  CHECK-SAME: %[[A:.*]]: vector<4xf32>) -> vector<4x1x1xf32> {
+//       CHECK: %[[UB:.*]] = ub.poison : vector<4x1x1xf32>
+//       CHECK: %[[E0:.*]] = vector.extract %[[A]][0] : f32 from vector<4xf32>
+//       CHECK: %[[I0:.*]] = vector.insert %[[E0]], %[[UB]] [0, 0, 0] : f32 into vector<4x1x1xf32>
+//       CHECK: %[[E1:.*]] = vector.extract %[[A]][1] : f32 from vector<4xf32>
+//       CHECK: %[[I1:.*]] = vector.insert %[[E1]], %[[I0]] [1, 0, 0] : f32 into vector<4x1x1xf32>
+//       CHECK: vector.extract %[[A]][2]
+//       CHECK: vector.insert {{.*}} [2, 0, 0]
+//       CHECK: vector.extract %[[A]][3]
+//       CHECK: vector.insert {{.*}} [3, 0, 0]
+//       CHECK: return
+func.func @postpend_unit_dims(%arg0 : vector<4xf32>) -> vector<4x1x1xf32> {
+  %s = vector.shape_cast %arg0 : vector<4xf32> to vector<4x1x1xf32>
+  return %s : vector<4x1x1xf32>
+}
+
+// CHECK-LABEL: func.func @expand_inner_dims(
+//  CHECK-SAME: %[[A:.*]]: vector<2x10xf32>) -> vector<2x2x5xf32> {
+//       CHECK: %[[UB:.*]] = ub.poison : vector<2x2x5xf32>
+//
+//       CHECK: %[[E0:.*]] = vector.extract %[[A]][0] : vector<10xf32>
+//       CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[E0]]
+//  CHECK-SAME: {offsets = [0], sizes = [5], {{.*}} to vector<5xf32>
+//       CHECK: %[[I0:.*]] = vector.insert %[[S0]], %[[UB]] [0, 0]
+//
+//       CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[E0]]
+//  CHECK-SAME: {offsets = [5], sizes = [5], {{.*}} to vector<5xf32>
+//       CHECK: %[[I1:.*]] = vector.insert %[[S1]], %[[I0]] [0, 1]
+//
+//       CHECK: %[[E1:.*]] = vector.extract %[[A]][1] : vector<10xf32>
+//       CHECK: %[[S2:.*]] = vector.extract_strided_slice %[[E1]]
+//  CHECK-SAME: {offsets = [0], sizes = [5], {{.*}} to vector<5xf32>
+//       CHECK: %[[I2:.*]] = vector.insert %[[S2]], %[[I1]] [1, 0]
+//
+//       CHECK: %[[S3:.*]] = vector.extract_strided_slice %[[E1]]
+//  CHECK-SAME: {offsets = [5], sizes = [5], {{.*}} to vector<5xf32>
+//       CHECK: %[[I3:.*]] = vector.insert %[[S3]], %[[I2]] [1, 1]
+//       CHECK: return %[[I3]] : vector<2x2x5xf32>
+func.func @expand_inner_dims(%arg0 : vector<2x10xf32>) -> vector<2x2x5xf32> {
+  %s = vector.shape_cast %arg0 : vector<2x10xf32> to vector<2x2x5xf32>
+  return %s : vector<2x2x5xf32>
+}
+
+
+// Some pseudocode describing how this function is lowered:
+//
+// func collapse_inner_dims(A : vector<2x2x5xi8>) -> vector<1x2x1x10xi8> {
+//   v0          = empty of shape (10)
+//   v1          = empty of shape (1,2,1,10)
+//   v0[0:5]     = A[0,0,:]
+//   v0[5:10]    = A[0,1,:]
+//   v1[0,0,0,:] = v0
+//   v0[0:5]     = A[1,0,:]
+//   v0[5:10]    = A[1,1,:]
+//   v1[0,1,0,:] = v0
+//   return v1;
+// }
+// CHECK-LABEL: func.func @collapse_inner_dims(
+//  CHECK-SAME: %[[A:.*]]: vector<2x2x5xi8>) -> vector<1x2x1x10xi8> {
+//   CHECK-DAG: %[[UBSMALL:.*]] = ub.poison : vector<10xi8>
+//   CHECK-DAG: %[[UBLARGE:.*]] = ub.poison : vector<1x2x1x10xi8>
+//
+//       CHECK: %[[EX0:.*]] = vector.extract %[[A]][0, 0]
+//       CHECK: %[[IN0:.*]] = vector.insert_strided_slice %[[EX0]], %[[UBSMALL]]
+//  CHECK-SAME: {offsets = [0], {{.*}}
+//       CHECK: %[[EX1:.*]] = vector.extract %[[A]][0, 1]
+//       CHECK: %[[IN1:.*]] = vector.insert_strided_slice %[[EX1]], %[[IN0]]
+//  CHECK-SAME: {offsets = [5], {{.*}}
+//       CHECK: %[[IN2:.*]] = vector.insert %[[IN1]], %[[UBLARGE]] [0, 0, 0]
+//
+//       CHECK: %[[EX2:.*]] = vector.extract %[[A]][1, 0]
+//       CHECK: %[[IN3:.*]] = vector.insert_strided_slice %[[EX2]], %[[UBSMALL]]
+//  CHECK-SAME: {offsets = [0], {{.*}}
+//       CHECK: %[[EX3:.*]] = vector.extract %[[A]][1, 1]
+//       CHECK: %[[IN4:.*]] = vector.insert_strided_slice %[[EX3]], %[[IN3]]
+//  CHECK-SAME: {offsets = [5], {{.*}}
+//       CHECK: %[[IN5:.*]] = vector.insert %[[IN4]], %[[IN2]] [0, 1, 0]
+//       CHECK: return %[[IN5]] : vector<1x2x1x10xi8>
+func.func @collapse_inner_dims(%arg0 : vector<2x2x5xi8>) -> vector<1x2x1x10xi8> {
+  %s = vector.shape_cast %arg0 : vector<2x2x5xi8> to vector<1x2x1x10xi8>
+  return %s : vector<1x2x1x10xi8>
+}
+
+// Some pseudocode describing how this function is lowered:
+//
+// func non_dividing_gcd_decreasing(A : vector<2x15xi8>) -> vector<3x10xi8> {
+//   v0          = empty of shape (10)
+//   v1          = empty of shape (3,10)
+//   e0          = A[0,:]
+//   v0[0:5]     = e0[0:5]
+//   v0[5:10]    = e0[5:10]
+//   v1[0,:]     = v0
+//   v0[0:5]     = e0[10:15]
+//   e1          = A[1,:]
+//   v0[5:10]    = e1[0:5]
+//   v1[1,:]     = v0
+//   v0[0:5]     = e1[5:10]
+//   v0[5:10]    = e1[10:15]
+//   v1[2,:]     = v0
+//   return v1;
+// }
+// CHECK-LABEL: func.func @non_dividing_gcd_decreasing(
+//  CHECK-SAME: %[[A:.*]]: vector<2x15xi8>) -> vector<3x10xi8> {
+//   CHECK-DAG: %[[UB0:.*]] = ub.poison : vector<10xi8>
+//   CHECK-DAG: %[[UB1:.*]] = ub.poison : vector<3x10xi8>
+//
+// First 10 elements:
+//       CHECK: %[[EX0:.*]] = vector.extract %[[A]][0] : vector<15xi8> from vector<2x15xi8>
+//       CHECK: %[[SS0:.*]] = vector.extract_strided_slice %[[EX0]]
+//  CHECK-SAME: {offsets = [0], {{.*}} to vector<5xi8>
+//       CHECK: %[[IN0:.*]] = vector.insert_strided_slice %[[SS0]], %[[UB0]]
+//  CHECK-SAME: {offsets = [0], {{.*}}
+//       CHECK: %[[SS1:.*]] = vector.extract_strided_slice %[[EX0]]
+//  CHECK-SAME: {offsets = [5], {{.*}} to vector<5xi8>
+//       CHECK: %[[IN1:.*]] = vector.insert_strided_slice %[[SS1]], %[[IN0]]
+//  CHECK-SAME: {offsets = [5], {{.*}}
+//       CHECK: %[[IN2:.*]] = vector.insert %[[IN1]], %[[UB1]] [0] : vector<10xi8> into vector<3x10xi8>
+//
+// Next 10 elements:
+//       CHECK: %[[SS2:.*]] = vector.extract_strided_slice %[[EX0]]
+//  CHECK-SAME: {offsets = [10], {{.*}} to vector<5xi8>
+//       CHECK: %[[IN3:.*]] = vector.insert_strided_slice %[[SS2]], %[[UB0]]
+//  CHECK-SAME: {offsets = [0], {{.*}}
+//       CHECK: %[[EX1:.*]] = vector.extract %[[A]][1] : vector<15xi8> from vector<2x15xi8>
+//       CHECK: %[[SS3:.*]] = vector.extract_strided_slice %[[EX1]]
+//  CHECK-SAME: {offsets = [0], {{.*}} to vector<5xi8>
+//       CHECK: %[[IN4:.*]] = vector.insert_strided_slice %[[SS3]], %[[IN3]]
+//  CHECK-SAME: {offsets = [5], {{.*}}
+//       CHECK: %[[IN5:.*]] = vector.insert %[[IN4]], %[[IN2]] [1] : vector<10xi8> into vector<3x10xi8>
+//
+// Final 10 elements:
+//       CHECK: %[[SS4:.*]] = vector.extract_strided_slice %[[EX1]]
+//  CHECK-SAME: {offsets = [5], {{.*}} to vector<5xi8>
+//       CHECK: %[[IN6:.*]] = vector.insert_strided_slice %[[SS4]], %[[UB0]]
+//  CHECK-SAME: {offsets = [0], {{.*}}
+//       CHECK: %[[SS5:.*]] = vector.extract_strided_slice %[[EX1]]
+//  CHECK-SAME: {offsets = [10], {{.*}} to vector<5xi8>
+//       CHECK: %[[IN7:.*]] = vector.insert_strided_slice %[[SS5]], %[[IN6]]
+//  CHECK-SAME: {offsets = [5], {{.*}}
+//       CHECK: %[[IN8:.*]] = vector.insert %[[IN7]], %[[IN5]] [2] : vector<10xi8> into vector<3x10xi8>
+//       CHECK: return %[[IN8]] : vector<3x10xi8>
+func.func @non_dividing_gcd_decreasing(%arg0 : vector<2x15xi8>) -> vector<3x10xi8> {
+  %0 = vector.shape_cast %arg0 : vector<2x15xi8> to vector<3x10xi8>
+  return %0 : vector<3x10xi8>
+}
+
+// CHECK-LABEL: func.func @non_dividing_gcd_increasing(
+//  CHECK-SAME: %[[A:.*]]: vector<3x10xi8>) -> vector<2x15xi8> {
+//
+//   CHECK-DAG: ub.poison : vector<15xi8>
+//   CHECK-DAG: ub.poison : vector<2x15xi8>
+//
+// Collect the first 15 elements, and insert into the first row of the result.
+//       CHECK: vector.extract %[[A]][0]
+//       CHECK: extract_strided_slice
+//       CHECK: insert_strided_slice
+//       CHECK: extract_strided_slice
+//       CHECK: insert_strided_slice
+//       CHECK: vector.extract %[[A]][1]
+//       CHECK: extract_strided_slice
+//       CHECK: insert_strided_slice
+//       CHECK: vector.insert {{.*}} [0] : vector<15xi8> into vector<2x15xi8>
+//
+// Collect the next 15 elements, and insert into the second row of the result.
+//       CHECK: extract_strided_slice
+//       CHECK: insert_strided_slice
+//       CHECK: vector.extract %[[A]][2]
+//       CHECK: extract_strided_slice
+//       CHECK: insert_strided_slice
+//       CHECK: extract_strided_slice
+//       CHECK: insert_strided_slice
+//       CHECK: vector.insert {{.*}} [1] : vector<15xi8> into vector<2x15xi8>
+func.func @non_dividing_gcd_increasing(%arg0 : vector<3x10xi8>) -> vector<2x15xi8> {
+  %0 = vector.shape_cast %arg0 : vector<3x10xi8> to vector<2x15xi8>
+  return %0 : vector<2x15xi8>
+}
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {

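An aside for readers tracing the @non_dividing_gcd_* tests above: the slice
bookkeeping can be reproduced with a small standalone C++ sketch (illustrative
only, not part of the patch; all names here are invented). Each 5-element
atomic slice corresponds to one extract_strided_slice from a source row plus
one insert_strided_slice into a result row:

// Enumerate the atomic slices moved by a
// vector<2x15xi8> -> vector<3x10xi8> shape_cast.
#include <cstdio>
#include <numeric>

int main() {
  const int srcCols = 15, dstCols = 10;
  const int numElements = 2 * srcCols;           // == 3 * dstCols == 30
  const int atomic = std::gcd(srcCols, dstCols); // 5, the atomic slice size
  for (int e = 0; e < numElements; e += atomic)
    std::printf("A[%d][%d:%d] -> B[%d][%d:%d]\n",
                e / srcCols, e % srcCols, e % srcCols + atomic,
                e / dstCols, e % dstCols, e % dstCols + atomic);
}

Running this prints the six moves A[0][0:5] -> B[0][0:5] through
A[1][10:15] -> B[2][5:10], matching the pseudocode and CHECK lines of
@non_dividing_gcd_decreasing.
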
>From 5ad484e1a37ea21dd0c16e16e33686693c7f2991 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Thu, 29 May 2025 09:42:47 -0700
Subject: [PATCH 4/5] cosmetic fixes and better failure notifications

---
 .../Vector/Transforms/LowerVectorShapeCast.cpp     | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
index e84886e285ba9..79048c5aab67a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
@@ -79,7 +79,7 @@ namespace {
 /// The greatest common divisor (gcd) of the first dimension preceding the
 /// common suffix is gcd(4,6) = 2. The algorithm implemented here will operate
 /// on vectors with shapes that are `multiples` of (what we define as) the
-/// 'atomic size', 2x7x11. The atomic size is `gcd` x `common-suffix`.
+/// 'atomic shape', 2x7x11. The atomic shape is `gcd` x `common-suffix`.
 ///
 ///         vector<2x2x3x4x7x11xi8> to
 ///             vector<8x6x7x11xi8>
@@ -87,17 +87,17 @@ namespace {
 ///                      ^              --->    gcd(4,6) is 2 | |
 ///                                                         | | |
 ///                                                         v v v
-///                                  atomic size   <-----   2x7x11
+///                                 atomic shape   <-----   2x7x11
 ///
 ///
 ///
-/// The decomposition implemented in this patterns consists of a sequence of
+/// The decomposition implemented in this pattern consists of a sequence of
 /// repeated steps:
 ///
 ///  (1) Extract vectors from the suffix of the source.
 ///      In our example this is 2x2x3x4x7x11 -> 4x7x11.
 ///
-///  (2) Do extract_strided_slice down to the atomic size.
+///  (2) Do extract_strided_slice down to the atomic shape.
 ///      In our example this is 4x7x11 -> 2x7x11.
 ///
 ///  (3) Do insert_strided_slice to the suffix of the result.
@@ -130,7 +130,8 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
 
     if (sourceType.isScalable() || resultType.isScalable())
       return rewriter.notifyMatchFailure(
-          op, "shape_cast lowering not handled by this pattern");
+          op,
+          "shape_cast where vectors are scalable not handled by this pattern");
 
     const ArrayRef<int64_t> sourceShape = sourceType.getShape();
     const ArrayRef<int64_t> resultShape = resultType.getShape();
@@ -332,7 +333,8 @@ class ScalableShapeCastOpRewritePattern
     // from >= 2-D scalable vectors or scalable vectors of fixed vectors.
     if (!isTrailingDimScalable(sourceVectorType) ||
         !isTrailingDimScalable(resultVectorType)) {
-      return failure();
+      return rewriter.notifyMatchFailure(
+          op, "trailing dims are not scalable, not handled by this pattern");
     }
 
     // The sizes of the trailing dimension of the source and result vectors, the

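To make the "atomic shape" wording in the doc comment above concrete, here is
a hedged standalone sketch of the computation it describes (illustrative only,
not the pattern's actual code; atomicShape is an invented helper):

// Atomic shape of a shape_cast: gcd of the dimensions immediately
// preceding the longest common suffix, followed by that suffix.
#include <cstdint>
#include <cstdio>
#include <numeric>
#include <vector>

static std::vector<int64_t> atomicShape(const std::vector<int64_t> &src,
                                        const std::vector<int64_t> &dst) {
  // Length of the longest common suffix of the two shapes.
  size_t suffix = 0;
  while (suffix < src.size() && suffix < dst.size() &&
         src[src.size() - 1 - suffix] == dst[dst.size() - 1 - suffix])
    ++suffix;
  // Dimensions immediately preceding the common suffix (1 if absent).
  int64_t a = suffix < src.size() ? src[src.size() - 1 - suffix] : 1;
  int64_t b = suffix < dst.size() ? dst[dst.size() - 1 - suffix] : 1;
  std::vector<int64_t> atomic{std::gcd(a, b)};
  atomic.insert(atomic.end(), src.end() - suffix, src.end());
  return atomic;
}

int main() {
  // The example from the comment: vector<2x2x3x4x7x11xi8> to
  // vector<8x6x7x11xi8> has common suffix 7x11 and gcd(4,6) == 2.
  for (int64_t d : atomicShape({2, 2, 3, 4, 7, 11}, {8, 6, 7, 11}))
    std::printf("%lld ", static_cast<long long>(d));
  std::printf("\n"); // prints: 2 7 11
}
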
>From 125573b1f889aa791e77935ca1ed0ffa634307be Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 3 Jun 2025 14:32:24 -0700
Subject: [PATCH 5/5] cosmetic improvements and comment clarification

---
 .../Transforms/LowerVectorShapeCast.cpp       |  21 ++--
 ...vector-shape-cast-lowering-transforms.mlir | 112 +++++++++---------
 2 files changed, 67 insertions(+), 66 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
index 79048c5aab67a..b10cfa9932464 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
@@ -83,8 +83,9 @@ namespace {
 ///
 ///         vector<2x2x3x4x7x11xi8> to
 ///             vector<8x6x7x11xi8>
-///                        ^^^^         ---> common suffix of 7x11
-///                      ^              --->    gcd(4,6) is 2 | |
+///                      | ||||
+///                      | ++++------------> common suffix of 7x11
+///                      +----------------->    gcd(4,6) is 2 | |
 ///                                                         | | |
 ///                                                         v v v
 ///                                 atomic shape   <-----   2x7x11
@@ -111,9 +112,9 @@ namespace {
 ///  (2) and (3) occur 24 times, and
 ///  (4) occurs 8 times.
 ///
-/// Two special cases are handled seperately:
-///  (1) A shape_cast that just does leading 1 insertion/removal
-///  (2) A shape_cast where the gcd is 1.
+/// Two special cases are handled independently in this pattern:
+///  (i) A shape_cast that just does leading 1 insertion/removal
+/// (ii) A shape_cast where the gcd is 1.
 ///
 /// These 2 cases can have more compact IR generated by not using the generic
 /// algorithm described above.
@@ -159,9 +160,9 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
       --resultSuffixStartDim;
     }
 
-    // This is the case where there are just some leading ones to contend with
-    // in the source or result. It can be handled with a single extract/insert
-    // pair.
+    // This is the case (i) where there are just some leading ones to contend
+    // with in the source or result. It can be handled with a single
+    // extract/insert pair.
     if (resultSuffixStartDim < 0 || sourceSuffixStartDim < 0) {
       const int64_t delta = sourceRank - resultRank;
       const int64_t sourceLeading = delta > 0 ? delta : 0;
@@ -195,8 +196,8 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
         atomicShape.begin(), atomicShape.end(), 1, std::multiplies<int64_t>());
     const size_t nAtomicSlices = numElms / numAtomicElms;
 
-    // This is the case where the strided dimension size is 1. More compact IR
-    // is generated in this case if we just extract and insert the elements
+    // This is the case (ii) where the strided dimension size is 1. More compact
+    // IR is generated in this case if we just extract and insert the elements
     // directly. In other words, we don't use extract_strided_slice and
     // insert_strided_slice.
     if (greatestCommonDivisor == 1) {
diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
index 7b843750169b8..5011d8b2b2ef6 100644
--- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt %s --transform-interpreter | FileCheck %s
 
 // CHECK-LABEL: func @nop_shape_cast
-//  CHECK-SAME: %[[A:.*]]: vector<16xf32>
+//  CHECK-SAME:    %[[A:.*]]: vector<16xf32>
 //       CHECK: return %[[A]] : vector<16xf32>
 func.func @nop_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> {
   %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<16xf32>
@@ -9,7 +9,7 @@ func.func @nop_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> {
 }
 
 // CHECK-LABEL: func @cancel_shape_cast
-//  CHECK-SAME: %[[A:.*]]: vector<16xf32>
+//  CHECK-SAME:    %[[A:.*]]: vector<16xf32>
 //       CHECK: return %[[A]] : vector<16xf32>
 func.func @cancel_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> {
   %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32>
@@ -19,16 +19,16 @@ func.func @cancel_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> {
 
 // Collapse 2-D to 1-D.
 // CHECK-LABEL: func @shape_cast_2d1d
-//  CHECK-SAME: %[[A:.*]]: vector<2x2xf32>) -> vector<4xf32> {
+//  CHECK-SAME:    %[[A:.*]]: vector<2x2xf32>) -> vector<4xf32> {
 //       CHECK: %[[UB:.*]] = ub.poison : vector<4xf32>
 //
 //       CHECK: %[[EX0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
 //       CHECK: %[[IN0:.*]] = vector.insert_strided_slice %[[EX0]], %[[UB]]
-//  CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
+//  CHECK-SAME:    {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
 //
 //       CHECK: %[[EX1:.*]] = vector.extract %{{.*}}[1] : vector<2xf32> from vector<2x2xf32>
 //       CHECK: %[[IN2:.*]] = vector.insert_strided_slice %[[EX1]], %[[IN0]]
-//  CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
+//  CHECK-SAME:    {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
 //       CHECK: return %[[IN2]] : vector<4xf32>
 func.func @shape_cast_2d1d(%a: vector<2x2xf32>) -> (vector<4xf32>) {
   %0 = vector.shape_cast %a : vector<2x2xf32> to vector<4xf32>
@@ -37,20 +37,20 @@ func.func @shape_cast_2d1d(%a: vector<2x2xf32>) -> (vector<4xf32>) {
 
 // Collapse 3-D to 1-D.
 // CHECK-LABEL: func @shape_cast_3d1d
-//  CHECK-SAME: %[[A:.*]]: vector<1x3x2xf32>
+//  CHECK-SAME:    %[[A:.*]]: vector<1x3x2xf32>
 //       CHECK: %[[UB:.*]] = ub.poison : vector<6xf32>
 //
 //       CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<2xf32> from vector<1x3x2xf32>
 //       CHECK: %[[T1:.*]] = vector.insert_strided_slice %[[T0]], %[[UB]]
-//  CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xf32> into vector<6xf32>
+//  CHECK-SAME:    {offsets = [0], strides = [1]} : vector<2xf32> into vector<6xf32>
 //
 //       CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : vector<2xf32> from vector<1x3x2xf32>
 //       CHECK: %[[T3:.*]] = vector.insert_strided_slice %[[T2]], %[[T1]]
-//  CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xf32> into vector<6xf32>
+//  CHECK-SAME:    {offsets = [2], strides = [1]} : vector<2xf32> into vector<6xf32>
 //
 //       CHECK: %[[T4:.*]] = vector.extract %[[A]][0, 2] : vector<2xf32> from vector<1x3x2xf32>
 //       CHECK: %[[T5:.*]] = vector.insert_strided_slice %[[T4]], %[[T3]]
-//  CHECK-SAME: {offsets = [4], strides = [1]} : vector<2xf32> into vector<6xf32>
+//  CHECK-SAME:    {offsets = [4], strides = [1]} : vector<2xf32> into vector<6xf32>
 //       CHECK: return %[[T5]] : vector<6xf32>
 func.func @shape_cast_3d1d(%arg0 : vector<1x3x2xf32>) -> vector<6xf32> {
   %s = vector.shape_cast %arg0 : vector<1x3x2xf32> to vector<6xf32>
@@ -59,20 +59,20 @@ func.func @shape_cast_3d1d(%arg0 : vector<1x3x2xf32>) -> vector<6xf32> {
 
 // Expand 1-D to 2-D.
 // CHECK-LABEL: func.func @shape_cast_1d2d(
-//  CHECK-SAME: %[[A:.*]]: vector<4xf32>) -> vector<2x2xf32> {
+//  CHECK-SAME:    %[[A:.*]]: vector<4xf32>) -> vector<2x2xf32> {
 //       CHECK: %[[UB:.*]] = ub.poison : vector<2x2xf32>
 //
 //       CHECK: %[[SS0:.*]] = vector.extract_strided_slice %[[A]]
-//  CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} :
-//  CHECK-SAME: vector<4xf32> to vector<2xf32>
+//  CHECK-SAME:    {offsets = [0], sizes = [2], strides = [1]} :
+//  CHECK-SAME:    vector<4xf32> to vector<2xf32>
 //       CHECK: %[[res0:.*]] = vector.insert %[[SS0]], %[[UB]] [0] :
-//  CHECK-SAME: vector<2xf32> into vector<2x2xf32>
+//  CHECK-SAME:    vector<2xf32> into vector<2x2xf32>
 //
 //       CHECK: %[[SS2:.*]] = vector.extract_strided_slice %[[A]]
-//  CHECK-SAME: {offsets = [2], sizes = [2], strides = [1]} :
-//  CHECK-SAME: vector<4xf32> to vector<2xf32>
+//  CHECK-SAME:    {offsets = [2], sizes = [2], strides = [1]} :
+//  CHECK-SAME:    vector<4xf32> to vector<2xf32>
 //       CHECK: %[[res1:.*]] = vector.insert %[[SS2]], %[[res0]] [1] :
-//  CHECK-SAME: vector<2xf32> into vector<2x2xf32>
+//  CHECK-SAME:    vector<2xf32> into vector<2x2xf32>
 //       CHECK: return  %[[res1]] :  vector<2x2xf32>
 func.func @shape_cast_1d2d(%a: vector<4xf32>) -> (vector<2x2xf32>) {
   %1 = vector.shape_cast %a: vector<4xf32> to vector<2x2xf32>
@@ -81,20 +81,20 @@ func.func @shape_cast_1d2d(%a: vector<4xf32>) -> (vector<2x2xf32>) {
 
 // Expand 1-D to 3-D.
 // CHECK-LABEL: func @shape_cast_1d3d
-//  CHECK-SAME: %[[A:.*]]: vector<6xf32>
+//  CHECK-SAME:    %[[A:.*]]: vector<6xf32>
 //       CHECK: %[[UB:.*]] = ub.poison : vector<2x1x3xf32>
 //
 //       CHECK: %[[T0:.*]] = vector.extract_strided_slice %[[A]]
-//  CHECK-SAME: {offsets = [0], sizes = [3], strides = [1]} :
-//  CHECK-SAME: vector<6xf32> to vector<3xf32>
+//  CHECK-SAME:    {offsets = [0], sizes = [3], strides = [1]} :
+//  CHECK-SAME:    vector<6xf32> to vector<3xf32>
 //       CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[UB]] [0, 0] :
-//  CHECK-SAME: vector<3xf32> into vector<2x1x3xf32>
+//  CHECK-SAME:    vector<3xf32> into vector<2x1x3xf32>
 //
 //       CHECK: %[[T2:.*]] = vector.extract_strided_slice %[[A]]
-//  CHECK-SAME: {offsets = [3], sizes = [3], strides = [1]} :
-//  CHECK-SAME: vector<6xf32> to vector<3xf32>
+//  CHECK-SAME:    {offsets = [3], sizes = [3], strides = [1]} :
+//  CHECK-SAME:    vector<6xf32> to vector<3xf32>
 //       CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1, 0] :
-//  CHECK-SAME: vector<3xf32> into vector<2x1x3xf32>
+//  CHECK-SAME:    vector<3xf32> into vector<2x1x3xf32>
 //       CHECK: return %[[T3]] : vector<2x1x3xf32>
 func.func @shape_cast_1d3d(%arg0 : vector<6xf32>) -> vector<2x1x3xf32> {
   %s = vector.shape_cast %arg0 : vector<6xf32> to vector<2x1x3xf32>
@@ -104,7 +104,7 @@ func.func @shape_cast_1d3d(%arg0 : vector<6xf32>) -> vector<2x1x3xf32> {
 // 2-D to 2-D where the inner-most dimensions have no common factors. This
 // case requires scalar element by element extraction and insertion.
 // CHECK-LABEL: func @shape_cast_2d2d
-//  CHECK-SAME: %[[A:.*]]: vector<3x2xf32>
+//  CHECK-SAME:    %[[A:.*]]: vector<3x2xf32>
 //       CHECK: %[[UB:.*]] = ub.poison : vector<2x3xf32>
 //
 //       CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : f32 from vector<3x2xf32>
@@ -132,7 +132,7 @@ func.func @shape_cast_2d2d(%arg0 : vector<3x2xf32>) -> vector<2x3xf32> {
 }
 
 // CHECK-LABEL: func.func @shape_cast_0d1d(
-//  CHECK-SAME: %[[A:.*]]: vector<f32>) -> vector<1xf32> {
+//  CHECK-SAME:    %[[A:.*]]: vector<f32>) -> vector<1xf32> {
 //       CHECK: %[[UB:.*]] = ub.poison : vector<1xf32>
 //
 //       CHECK: %[[EXTRACT0:.*]] = vector.extract %[[A]][] : f32 from vector<f32>
@@ -144,7 +144,7 @@ func.func @shape_cast_0d1d(%arg0 : vector<f32>) -> vector<1xf32> {
 }
 
 // CHECK-LABEL: func.func @shape_cast_1d0d(
-//  CHECK-SAME: %[[A:.*]]: vector<1xf32>) -> vector<f32> {
+//  CHECK-SAME:    %[[A:.*]]: vector<1xf32>) -> vector<f32> {
 //       CHECK: %[[UB:.*]] = ub.poison : vector<f32>
 //
 //       CHECK: %[[EXTRACT0:.*]] = vector.extract %[[A]][0] : f32 from vector<1xf32>
@@ -157,10 +157,10 @@ func.func @shape_cast_1d0d(%arg0 : vector<1xf32>) -> vector<f32> {
 
 // The shapes have 2 inner dimension sizes in common, so the extract result is rank-2.
 // CHECK-LABEL: func.func @squeeze_out_prefix_unit_dim(
-//  CHECK-SAME: %[[A:.*]]: vector<1x2x3xf32>) -> vector<2x3xf32> {
+//  CHECK-SAME:    %[[A:.*]]: vector<1x2x3xf32>) -> vector<2x3xf32> {
 //
 //       CHECK: %[[EXTRACTED:.*]] = vector.extract %[[A]][0] :
-//  CHECK-SAME: vector<2x3xf32> from vector<1x2x3xf32>
+//  CHECK-SAME:    vector<2x3xf32> from vector<1x2x3xf32>
 //       CHECK: return %[[EXTRACTED]] : vector<2x3xf32>
 func.func @squeeze_out_prefix_unit_dim(%arg0 : vector<1x2x3xf32>) -> vector<2x3xf32> {
   %s = vector.shape_cast %arg0 : vector<1x2x3xf32> to vector<2x3xf32>
@@ -169,7 +169,7 @@ func.func @squeeze_out_prefix_unit_dim(%arg0 : vector<1x2x3xf32>) -> vector<2x3x
 
 // The shapes have 1 inner dimension size in common, so the extract results are rank-1.
 // CHECK-LABEL: func.func @squeeze_out_middle_unit_dim(
-//  CHECK-SAME: %[[A:.*]]: vector<2x1x3xf32>) -> vector<2x3xf32> {
+//  CHECK-SAME:    %[[A:.*]]: vector<2x1x3xf32>) -> vector<2x3xf32> {
 //       CHECK: %[[UB:.*]] = ub.poison : vector<2x3xf32>
 //
 //       CHECK: %[[E0:.*]] = vector.extract %[[A]][0, 0] : vector<3xf32>
@@ -184,7 +184,7 @@ func.func @squeeze_out_middle_unit_dim(%arg0 : vector<2x1x3xf32>) -> vector<2x3x
 }
 
 // CHECK-LABEL: func.func @prepend_unit_dim(
-//  CHECK-SAME: %[[A:.*]]: vector<2x3xf32>) -> vector<1x2x3xf32> {
+//  CHECK-SAME:    %[[A:.*]]: vector<2x3xf32>) -> vector<1x2x3xf32> {
 //       CHECK: %[[UB:.*]] = ub.poison : vector<1x2x3xf32>
 //
 //       CHECK: %[[I0:.*]] = vector.insert %[[A]], %[[UB]] [0]
@@ -195,7 +195,7 @@ func.func @prepend_unit_dim(%arg0 : vector<2x3xf32>) -> vector<1x2x3xf32> {
 }
 
 // CHECK-LABEL: func.func @insert_middle_unit_dim(
-//  CHECK-SAME: %[[A:.*]]: vector<2x3xf32>) -> vector<2x1x3xf32> {
+//  CHECK-SAME:    %[[A:.*]]: vector<2x3xf32>) -> vector<2x1x3xf32> {
 //       CHECK: %[[UB:.*]] = ub.poison : vector<2x1x3xf32>
 //
 //       CHECK: %[[E0:.*]] = vector.extract %[[A]][0] : vector<3xf32>
@@ -210,7 +210,7 @@ func.func @insert_middle_unit_dim(%arg0 : vector<2x3xf32>) -> vector<2x1x3xf32>
 }
 
 //   CHECK-LABEL: func.func @postpend_unit_dims(
-//    CHECK-SAME: %[[A:.*]]: vector<4xf32>) -> vector<4x1x1xf32> {
+//    CHECK-SAME:    %[[A:.*]]: vector<4xf32>) -> vector<4x1x1xf32> {
 //         CHECK: %[[UB:.*]] = ub.poison : vector<4x1x1xf32>
 //         CHECK: %[[E0:.*]] = vector.extract %[[A]][0] : f32 from vector<4xf32>
 //         CHECK: %[[I0:.*]] = vector.insert %[[E0]], %[[UB]] [0, 0, 0] : f32 into vector<4x1x1xf32>
@@ -227,25 +227,25 @@ func.func @postpend_unit_dims(%arg0 : vector<4xf32>) -> vector<4x1x1xf32> {
 }
 
 // CHECK-LABEL: func.func @expand_inner_dims(
-//  CHECK-SAME: %[[A:.*]]: vector<2x10xf32>) -> vector<2x2x5xf32> {
+//  CHECK-SAME:    %[[A:.*]]: vector<2x10xf32>) -> vector<2x2x5xf32> {
 //       CHECK: %[[UB:.*]] = ub.poison : vector<2x2x5xf32>
 //
 //       CHECK: %[[E0:.*]] = vector.extract %[[A]][0] : vector<10xf32>
 //       CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[E0]]
-//  CHECK-SAME: {offsets = [0], sizes = [5], {{.*}} to vector<5xf32>
+//  CHECK-SAME:    {offsets = [0], sizes = [5], {{.*}} to vector<5xf32>
 //       CHECK: %[[I0:.*]] = vector.insert %[[S0]], %[[UB]] [0, 0]
 //
 //       CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[E0]]
-//  CHECK-SAME: {offsets = [5], sizes = [5], {{.*}} to vector<5xf32>
+//  CHECK-SAME:    {offsets = [5], sizes = [5], {{.*}} to vector<5xf32>
 //       CHECK: %[[I1:.*]] = vector.insert %[[S1]], %[[I0]] [0, 1]
 //
 //       CHECK: %[[E1:.*]] = vector.extract %[[A]][1] : vector<10xf32>
 //       CHECK: %[[S2:.*]] = vector.extract_strided_slice %[[E1]]
-//  CHECK-SAME: {offsets = [0], sizes = [5], {{.*}} to vector<5xf32>
+//  CHECK-SAME:    {offsets = [0], sizes = [5], {{.*}} to vector<5xf32>
 //       CHECK: %[[I2:.*]] = vector.insert %[[S2]], %[[I1]] [1, 0]
 //
 //       CHECK: %[[S3:.*]] = vector.extract_strided_slice %[[E1]]
-//  CHECK-SAME: {offsets = [5], sizes = [5], {{.*}} to vector<5xf32>
+//  CHECK-SAME:    {offsets = [5], sizes = [5], {{.*}} to vector<5xf32>
 //       CHECK: %[[I3:.*]] = vector.insert %[[S3]], %[[I2]] [1, 1]
 //       CHECK: return %[[I3]] : vector<2x2x5xf32>
 func.func @expand_inner_dims(%arg0 : vector<2x10xf32>) -> vector<2x2x5xf32> {
@@ -268,24 +268,24 @@ func.func @expand_inner_dims(%arg0 : vector<2x10xf32>) -> vector<2x2x5xf32> {
 //   return v1;
 // }
 // CHECK-LABEL: func.func @collapse_inner_dims(
-//  CHECK-SAME: %[[A:.*]]: vector<2x2x5xi8>) -> vector<1x2x1x10xi8> {
+//  CHECK-SAME:    %[[A:.*]]: vector<2x2x5xi8>) -> vector<1x2x1x10xi8> {
 //   CHECK-DAG: %[[UBSMALL:.*]] = ub.poison : vector<10xi8>
 //   CHECK-DAG: %[[UBLARGE:.*]] = ub.poison : vector<1x2x1x10xi8>
 //
 //       CHECK: %[[EX0:.*]] = vector.extract %[[A]][0, 0]
 //       CHECK: %[[IN0:.*]] = vector.insert_strided_slice %[[EX0]], %[[UBSMALL]]
-//  CHECK-SAME: {offsets = [0], {{.*}}
+//  CHECK-SAME:    {offsets = [0], {{.*}}
 //       CHECK: %[[EX1:.*]] = vector.extract %[[A]][0, 1]
 //       CHECK: %[[IN1:.*]] = vector.insert_strided_slice %[[EX1]], %[[IN0]]
-//  CHECK-SAME: {offsets = [5], {{.*}}
+//  CHECK-SAME:    {offsets = [5], {{.*}}
 //       CHECK: %[[IN2:.*]] = vector.insert %[[IN1]], %[[UBLARGE]] [0, 0, 0]
 //
 //       CHECK: %[[EX2:.*]] = vector.extract %[[A]][1, 0]
 //       CHECK: %[[IN3:.*]] = vector.insert_strided_slice %[[EX2]], %[[UBSMALL]]
-//  CHECK-SAME: {offsets = [0], {{.*}}
+//  CHECK-SAME:    {offsets = [0], {{.*}}
 //       CHECK: %[[EX3:.*]] = vector.extract %[[A]][1, 1]
 //       CHECK: %[[IN4:.*]] = vector.insert_strided_slice %[[EX3]], %[[IN3]]
-//  CHECK-SAME: {offsets = [5], {{.*}}
+//  CHECK-SAME:    {offsets = [5], {{.*}}
 //       CHECK: %[[IN5:.*]] = vector.insert %[[IN4]], %[[IN2]] [0, 1, 0]
 //       CHECK: return %[[IN5]] : vector<1x2x1x10xi8>
 func.func @collapse_inner_dims(%arg0 : vector<2x2x5xi8>) -> vector<1x2x1x10xi8> {
@@ -312,43 +312,43 @@ func.func @collapse_inner_dims(%arg0 : vector<2x2x5xi8>) -> vector<1x2x1x10xi8>
 //   return v1;
 // }
 // CHECK-LABEL: func.func @non_dividing_gcd_decreasing(
-//  CHECK-SAME: %[[A:.*]]: vector<2x15xi8>) -> vector<3x10xi8> {
+//  CHECK-SAME:    %[[A:.*]]: vector<2x15xi8>) -> vector<3x10xi8> {
 //   CHECK-DAG: %[[UB0:.*]] = ub.poison : vector<10xi8>
 //   CHECK-DAG: %[[UB1:.*]] = ub.poison : vector<3x10xi8>
 //
 // First 10 elements:
 //       CHECK: %[[EX0:.*]] = vector.extract %[[A]][0] : vector<15xi8> from vector<2x15xi8>
 //       CHECK: %[[SS0:.*]] = vector.extract_strided_slice %[[EX0]]
-//  CHECK-SAME: {offsets = [0], {{.*}} to vector<5xi8>
+//  CHECK-SAME:    {offsets = [0], {{.*}} to vector<5xi8>
 //       CHECK: %[[IN0:.*]] = vector.insert_strided_slice %[[SS0]], %[[UB0]]
-//  CHECK-SAME: {offsets = [0], {{.*}}
+//  CHECK-SAME:    {offsets = [0], {{.*}}
 //       CHECK: %[[SS1:.*]] = vector.extract_strided_slice %[[EX0]]
-//  CHECK-SAME: {offsets = [5], {{.*}} to vector<5xi8>
+//  CHECK-SAME:    {offsets = [5], {{.*}} to vector<5xi8>
 //       CHECK: %[[IN1:.*]] = vector.insert_strided_slice %[[SS1]], %[[IN0]]
-//  CHECK-SAME: {offsets = [5], {{.*}}
+//  CHECK-SAME:    {offsets = [5], {{.*}}
 //       CHECK: %[[IN2:.*]] = vector.insert %[[IN1]], %[[UB1]] [0] : vector<10xi8> into vector<3x10xi8>
 //
 // Next 10 elements:
 //       CHECK: %[[SS2:.*]] = vector.extract_strided_slice %[[EX0]]
-//  CHECK-SAME: {offsets = [10], {{.*}} to vector<5xi8>
+//  CHECK-SAME:    {offsets = [10], {{.*}} to vector<5xi8>
 //       CHECK: %[[IN3:.*]] = vector.insert_strided_slice %[[SS2]], %[[UB0]]
-//  CHECK-SAME: {offsets = [0], {{.*}}
+//  CHECK-SAME:    {offsets = [0], {{.*}}
 //       CHECK: %[[EX1:.*]] = vector.extract %[[A]][1] : vector<15xi8> from vector<2x15xi8>
 //       CHECK: %[[SS3:.*]] = vector.extract_strided_slice %[[EX1]]
-//  CHECK-SAME: {offsets = [0], {{.*}} to vector<5xi8>
+//  CHECK-SAME:    {offsets = [0], {{.*}} to vector<5xi8>
 //       CHECK: %[[IN4:.*]] = vector.insert_strided_slice %[[SS3]], %[[IN3]]
-//  CHECK-SAME: {offsets = [5], {{.*}}
+//  CHECK-SAME:    {offsets = [5], {{.*}}
 //       CHECK: %[[IN5:.*]] = vector.insert %[[IN4]], %[[IN2]] [1] : vector<10xi8> into vector<3x10xi8>
 //
 // Final 10 elements:
 //       CHECK: %[[SS4:.*]] = vector.extract_strided_slice %[[EX1]]
-//  CHECK-SAME: {offsets = [5], {{.*}} to vector<5xi8>
+//  CHECK-SAME:    {offsets = [5], {{.*}} to vector<5xi8>
 //       CHECK: %[[IN6:.*]] = vector.insert_strided_slice %[[SS4]], %[[UB0]]
-//  CHECK-SAME: {offsets = [0], {{.*}}
+//  CHECK-SAME:    {offsets = [0], {{.*}}
 //       CHECK: %[[SS5:.*]] = vector.extract_strided_slice %[[EX1]]
-//  CHECK-SAME: {offsets = [10], {{.*}} to vector<5xi8>
+//  CHECK-SAME:    {offsets = [10], {{.*}} to vector<5xi8>
 //       CHECK: %[[IN7:.*]] = vector.insert_strided_slice %[[SS5]], %[[IN6]]
-//  CHECK-SAME: {offsets = [5], {{.*}}
+//  CHECK-SAME:    {offsets = [5], {{.*}}
 //       CHECK: %[[IN8:.*]] = vector.insert %[[IN7]], %[[IN5]] [2] : vector<10xi8> into vector<3x10xi8>
 //       CHECK: return %[[IN8]] : vector<3x10xi8>
 func.func @non_dividing_gcd_decreasing(%arg0 : vector<2x15xi8>) -> vector<3x10xi8> {
@@ -357,7 +357,7 @@ func.func @non_dividing_gcd_decreasing(%arg0 : vector<2x15xi8>) -> vector<3x10xi
 }
 
 // CHECK-LABEL: func.func @non_dividing_gcd_increasing(
-//  CHECK-SAME: %[[A:.*]]: vector<3x10xi8>) -> vector<2x15xi8> {
+//  CHECK-SAME:    %[[A:.*]]: vector<3x10xi8>) -> vector<2x15xi8> {
 //
 //   CHECK-DAG: ub.poison : vector<15xi8>
 //   CHECK-DAG: ub.poison : vector<2x15xi8>



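A closing note on special case (ii) above: when the gcd is 1 the atomic slice
degenerates to a single scalar, so the lowering emits one scalar extract/insert
pair per element. A hedged sketch of the resulting moves (illustrative only,
not the pattern's code), using the 3x2 -> 2x3 case from @shape_cast_2d2d
earlier in the test file:

// gcd(2, 3) == 1: one scalar move per element.
#include <cstdio>

int main() {
  const int srcCols = 2, dstCols = 3, numElements = 6; // 3x2 == 2x3
  for (int e = 0; e < numElements; ++e)
    std::printf("A[%d][%d] -> B[%d][%d]\n",
                e / srcCols, e % srcCols, e / dstCols, e % dstCols);
}

This prints the same six element moves that the CHECK lines of
@shape_cast_2d2d verify.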