[Mlir-commits] [mlir] [mlir][vector] Better handle rank-preserving shape_cast (PR #135855)

James Newling llvmlistbot at llvm.org
Tue Apr 15 18:56:45 PDT 2025


https://github.com/newling updated https://github.com/llvm/llvm-project/pull/135855

>From c1f4264a71d6d80350056d0d9ca86a0ac2c1e04f Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 15 Apr 2025 13:09:43 -0700
Subject: [PATCH 1/4] fix edge case where n=k (rank-preserving shape_cast)

---
 .../mlir/Dialect/Vector/IR/VectorOps.td       | 17 ++---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 63 ++++++++++---------
 mlir/test/Dialect/Vector/invalid.mlir         | 15 +++--
 mlir/test/Dialect/Vector/ops.mlir             |  8 +++
 4 files changed, 61 insertions(+), 42 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 7fc56b1aa4e7e..a9e25f23ef90f 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2244,18 +2244,19 @@ def Vector_ShapeCastOp :
     Results<(outs AnyVectorOfAnyRank:$result)> {
   let summary = "shape_cast casts between vector shapes";
   let description = [{
-    The shape_cast operation casts between an n-D source vector shape and
-    a k-D result vector shape (the element type remains the same).
+    The shape_cast operation casts from an n-D source vector to a k-D result
+    vector. The element type remains the same, as does the number of elements
+    (product of dimensions).
+
+    If reducing or preserving rank (n >= k), all result dimension sizes must be
+    products of contiguous source dimension sizes. If expanding rank (n < k),
+    source dimensions must all factor into contiguous sequences of destination
+    dimension sizes.
 
-    If reducing rank (n > k), result dimension sizes must be a product
-    of contiguous source dimension sizes.
-    If expanding rank (n < k), source dimensions must factor into a
-    contiguous sequence of destination dimension sizes.
     Each source dim is expanded (or contiguous sequence of source dims combined)
     in source dimension list order (i.e. 0 <= i < n), to produce a contiguous
     sequence of result dims (or a single result dim), in result dimension list
-    order (i.e. 0 <= j < k). The product of all source dimension sizes and all
-    result dimension sizes must match.
+    order (i.e. 0 <= j < k).
 
     It is currently assumed that this operation does not require moving data,
     and that it will be folded away before lowering vector operations.
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index bee5c1fd6ed58..554dbba081898 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5534,10 +5534,10 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 /// Returns true if each element of 'a' is equal to the product of a contiguous
 /// sequence of the elements of 'b'. Returns false otherwise.
-static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
+static bool isValidExpandingShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
   unsigned rankA = a.size();
   unsigned rankB = b.size();
-  assert(rankA < rankB);
+  assert(rankA <= rankB);
 
   auto isOne = [](int64_t v) { return v == 1; };
 
@@ -5573,34 +5573,36 @@ static LogicalResult verifyVectorShapeCast(Operation *op,
                                            VectorType resultVectorType) {
   // Check that element type is the same.
   if (sourceVectorType.getElementType() != resultVectorType.getElementType())
-    return op->emitOpError("source/result vectors must have same element type");
-  auto sourceShape = sourceVectorType.getShape();
-  auto resultShape = resultVectorType.getShape();
+    return op->emitOpError("has different source and result element types");
+  ArrayRef<int64_t> lowRankShape = sourceVectorType.getShape();
+  ArrayRef<int64_t> highRankShape = resultVectorType.getShape();
+  if (lowRankShape.size() > highRankShape.size())
+    std::swap(lowRankShape, highRankShape);
 
   // Check that product of source dim sizes matches product of result dim sizes.
-  int64_t sourceDimProduct = std::accumulate(
-      sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
-  int64_t resultDimProduct = std::accumulate(
-      resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
-  if (sourceDimProduct != resultDimProduct)
-    return op->emitOpError("source/result number of elements must match");
-
-  // Check that expanding/contracting rank cases.
-  unsigned sourceRank = sourceVectorType.getRank();
-  unsigned resultRank = resultVectorType.getRank();
-  if (sourceRank < resultRank) {
-    if (!isValidShapeCast(sourceShape, resultShape))
-      return op->emitOpError("invalid shape cast");
-  } else if (sourceRank > resultRank) {
-    if (!isValidShapeCast(resultShape, sourceShape))
-      return op->emitOpError("invalid shape cast");
+  int64_t nLowRankElms =
+      std::accumulate(lowRankShape.begin(), lowRankShape.end(), 1LL,
+                      std::multiplies<int64_t>{});
+  int64_t nHighRankElms =
+      std::accumulate(highRankShape.begin(), highRankShape.end(), 1LL,
+                      std::multiplies<int64_t>{});
+
+  if (nLowRankElms != nHighRankElms) {
+    return op->emitOpError(
+        "has a different number of source and result elements");
+  }
+
+  if (!isValidExpandingShapeCast(lowRankShape, highRankShape)) {
+    return op->emitOpError(
+        "is invalid (does not uniformly collapse or expand)");
   }
 
   // Check that (non-)scalability is preserved
   int64_t sourceNScalableDims = sourceVectorType.getNumScalableDims();
   int64_t resultNScalableDims = resultVectorType.getNumScalableDims();
   if (sourceNScalableDims != resultNScalableDims)
-    return op->emitOpError("different number of scalable dims at source (")
+    return op->emitOpError(
+               "has a different number of scalable dims at source (")
            << sourceNScalableDims << ") and result (" << resultNScalableDims
            << ")";
   sourceVectorType.getNumDynamicDims();
@@ -5634,17 +5636,18 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
 
     // Only allows valid transitive folding (expand/collapse dimensions).
     VectorType srcType = otherOp.getSource().getType();
+
     if (resultType == srcType)
       return otherOp.getSource();
-    if (srcType.getRank() < resultType.getRank()) {
-      if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
-        return {};
-    } else if (srcType.getRank() > resultType.getRank()) {
-      if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
-        return {};
-    } else {
+
+    ArrayRef<int64_t> lowRankShape = srcType.getShape();
+    ArrayRef<int64_t> highRankShape = resultType.getShape();
+    if (lowRankShape.size() > highRankShape.size())
+      std::swap(lowRankShape, highRankShape);
+
+    if (!isValidExpandingShapeCast(lowRankShape, highRankShape))
       return {};
-    }
+
     setOperand(otherOp.getSource());
     return getResult();
   }
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index dbf829e014b8d..9f94fb0574504 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1132,28 +1132,35 @@ func.func @cannot_print_string_with_source_set(%vec: vector<[4]xf32>) {
 // -----
 
 func.func @shape_cast_wrong_element_type(%arg0 : vector<5x1x3x2xf32>) {
-  // expected-error@+1 {{op source/result vectors must have same element type}}
+  // expected-error@+1 {{op has different source and result element types}}
   %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<15x2xi32>
 }
 
 // -----
 
 func.func @shape_cast_wrong_num_elements(%arg0 : vector<5x1x3x2xf32>) {
-  // expected-error@+1 {{op source/result number of elements must match}}
+  // expected-error@+1 {{op has a different number of source and result elements}}
   %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<10x2xf32>
 }
 
 // -----
 
+func.func @shape_cast_invalid_rank_preserving(%arg0 : vector<3x2xf32>) {
+  // expected-error@+1 {{op is invalid (does not uniformly collapse or expand)}}
+  %0 = vector.shape_cast %arg0 : vector<3x2xf32> to vector<2x3xf32>
+}
+
+// -----
+
 func.func @shape_cast_invalid_rank_reduction(%arg0 : vector<5x1x3x2xf32>) {
-  // expected-error@+1 {{invalid shape cast}}
+  // expected-error@+1 {{op is invalid (does not uniformly collapse or expand)}}
   %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<2x15xf32>
 }
 
 // -----
 
 func.func @shape_cast_invalid_rank_expansion(%arg0 : vector<15x2xf32>) {
-  // expected-error@+1 {{invalid shape cast}}
+  // expected-error@+1 {{op is invalid (does not uniformly collapse or expand)}}
   %0 = vector.shape_cast %arg0 : vector<15x2xf32> to vector<5x2x3x1xf32>
 }
 
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 8ae1e9f9d0c64..527bccf8383ca 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -576,6 +576,14 @@ func.func @shape_cast_0d(%arg0 : vector<1x1x1x1xf32>) -> (vector<1x1x1x1xf32>) {
   return %1 : vector<1x1x1x1xf32>
 }
 
+// CHECK-LABEL: @shape_cast_rank_preserving
+func.func @shape_cast_rank_preserving(%arg0 : vector<1x4xf32>) -> vector<4x1xf32> {
+
+  // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4x1xf32>
+  %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4x1xf32>
+  return %0 : vector<4x1xf32>
+}
+
 // CHECK-LABEL: @bitcast
 func.func @bitcast(%arg0 : vector<5x1x3x2xf32>,
                  %arg1 : vector<8x1xi32>,

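For illustration, a small standalone example of the two cases this patch distinguishes, based on the tests added above; %arg0 (vector<1x4xf32>) and %arg1 (vector<3x2xf32>) are assumed function arguments, and the snippet is a sketch rather than part of the patch:

  // Now accepted: a rank-preserving cast in which only unit dimensions move,
  // so it is still a pure expand/collapse of the single non-unit dimension.
  %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4x1xf32>

  // Still rejected: 3x2 -> 2x3 would require reordering elements, so it is
  // neither a collapse nor an expansion.
  // expected-error@+1 {{op is invalid (does not uniformly collapse or expand)}}
  %1 = vector.shape_cast %arg1 : vector<3x2xf32> to vector<2x3xf32>
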
>From b4914638b9426472c780cc94ac29224353d1b9d4 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 15 Apr 2025 15:10:54 -0700
Subject: [PATCH 2/4] clang-format

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 554dbba081898..120dd57659e6b 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5534,7 +5534,8 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 /// Returns true if each element of 'a' is equal to the product of a contiguous
 /// sequence of the elements of 'b'. Returns false otherwise.
-static bool isValidExpandingShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
+static bool isValidExpandingShapeCast(ArrayRef<int64_t> a,
+                                      ArrayRef<int64_t> b) {
   unsigned rankA = a.size();
   unsigned rankB = b.size();
   assert(rankA <= rankB);

>From 7bfc219d9d3e4881a70d6866ec59b077da22e71b Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 15 Apr 2025 16:14:40 -0700
Subject: [PATCH 3/4] update tests

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 41 ++++++-------
 mlir/test/Dialect/Vector/canonicalize.mlir    | 10 ++--
 ...-shape-cast-lowering-scalable-vectors.mlir | 58 +++++++++----------
 ...vector-shape-cast-lowering-transforms.mlir | 21 -------
 4 files changed, 52 insertions(+), 78 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 120dd57659e6b..07b6baf961a3e 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5534,11 +5534,12 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 /// Returns true if each element of 'a' is equal to the product of a contiguous
 /// sequence of the elements of 'b'. Returns false otherwise.
-static bool isValidExpandingShapeCast(ArrayRef<int64_t> a,
-                                      ArrayRef<int64_t> b) {
+static bool isExpandingShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
   unsigned rankA = a.size();
   unsigned rankB = b.size();
-  assert(rankA <= rankB);
+  if (rankA > rankB) {
+    return false;
+  }
 
   auto isOne = [](int64_t v) { return v == 1; };
 
@@ -5565,35 +5566,34 @@ static bool isValidExpandingShapeCast(ArrayRef<int64_t> a,
     if (j < rankB && llvm::all_of(b.slice(j), isOne))
       j = rankB;
   }
-
   return i == rankA && j == rankB;
 }
 
+static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
+  return isExpandingShapeCast(a, b) || isExpandingShapeCast(b, a);
+}
+
 static LogicalResult verifyVectorShapeCast(Operation *op,
                                            VectorType sourceVectorType,
                                            VectorType resultVectorType) {
   // Check that element type is the same.
   if (sourceVectorType.getElementType() != resultVectorType.getElementType())
     return op->emitOpError("has different source and result element types");
-  ArrayRef<int64_t> lowRankShape = sourceVectorType.getShape();
-  ArrayRef<int64_t> highRankShape = resultVectorType.getShape();
-  if (lowRankShape.size() > highRankShape.size())
-    std::swap(lowRankShape, highRankShape);
+  ArrayRef<int64_t> inShape = sourceVectorType.getShape();
+  ArrayRef<int64_t> outShape = resultVectorType.getShape();
 
   // Check that product of source dim sizes matches product of result dim sizes.
-  int64_t nLowRankElms =
-      std::accumulate(lowRankShape.begin(), lowRankShape.end(), 1LL,
-                      std::multiplies<int64_t>{});
-  int64_t nHighRankElms =
-      std::accumulate(highRankShape.begin(), highRankShape.end(), 1LL,
-                      std::multiplies<int64_t>{});
-
-  if (nLowRankElms != nHighRankElms) {
+  int64_t nInElms = std::accumulate(inShape.begin(), inShape.end(), 1LL,
+                                    std::multiplies<int64_t>{});
+  int64_t nOutElms = std::accumulate(outShape.begin(), outShape.end(), 1LL,
+                                     std::multiplies<int64_t>{});
+
+  if (nInElms != nOutElms) {
     return op->emitOpError(
         "has a different number of source and result elements");
   }
 
-  if (!isValidExpandingShapeCast(lowRankShape, highRankShape)) {
+  if (!isValidShapeCast(inShape, outShape)) {
     return op->emitOpError(
         "is invalid (does not uniformly collapse or expand)");
   }
@@ -5641,12 +5641,7 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
     if (resultType == srcType)
       return otherOp.getSource();
 
-    ArrayRef<int64_t> lowRankShape = srcType.getShape();
-    ArrayRef<int64_t> highRankShape = resultType.getShape();
-    if (lowRankShape.size() > highRankShape.size())
-      std::swap(lowRankShape, highRankShape);
-
-    if (!isValidExpandingShapeCast(lowRankShape, highRankShape))
+    if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
       return {};
 
     setOperand(otherOp.getSource());
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 78b0ea78849e8..8d24e1bf2ba94 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1290,12 +1290,12 @@ func.func @extract_strided_broadcast4(%arg0: f32) -> vector<1x4xf32> {
 // -----
 
 // CHECK-LABEL: consecutive_shape_cast
-//       CHECK:   %[[C:.*]] = vector.shape_cast %{{.*}} : vector<16xf16> to vector<4x4xf16>
-//  CHECK-NEXT:   return %[[C]] : vector<4x4xf16>
-func.func @consecutive_shape_cast(%arg0: vector<16xf16>) -> vector<4x4xf16> {
+//       CHECK:   %[[C:.*]] = vector.shape_cast %{{.*}} : vector<16xf16> to vector<2x2x4xf16>
+//  CHECK-NEXT:   return %[[C]] : vector<2x2x4xf16>
+func.func @consecutive_shape_cast(%arg0: vector<16xf16>) -> vector<2x2x4xf16> {
   %0 = vector.shape_cast %arg0 : vector<16xf16> to vector<2x8xf16>
-  %1 = vector.shape_cast %0 : vector<2x8xf16> to vector<4x4xf16>
-  return %1 : vector<4x4xf16>
+  %1 = vector.shape_cast %0 : vector<2x8xf16> to vector<2x2x4xf16>
+  return %1 : vector<2x2x4xf16>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir
index f4becad3c79c1..2faa47c1b08a8 100644
--- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir
+++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir
@@ -74,23 +74,23 @@ func.func @i8_1d_to_2d_last_dim_scalable(%arg0: vector<[32]xi8>) -> vector<4x[8]
 
 // CHECK-LABEL: f32_permute_leading_non_scalable_dims
 // CHECK-SAME: %[[arg0:.*]]: vector<2x3x[4]xf32>
-func.func @f32_permute_leading_non_scalable_dims(%arg0: vector<2x3x[4]xf32>) -> vector<3x2x[4]xf32> {
-  // CHECK-NEXT: %[[ub:.*]] = ub.poison : vector<3x2x[4]xf32>
+func.func @f32_permute_leading_non_scalable_dims(%arg0: vector<2x3x[4]xf32>) -> vector<1x6x[4]xf32> {
+  // CHECK-NEXT: %[[ub:.*]] = ub.poison : vector<1x6x[4]xf32>
   // CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0, 0] : vector<[4]xf32> from vector<2x3x[4]xf32>
-  // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[ub]] [0, 0] : vector<[4]xf32> into vector<3x2x[4]xf32>
+  // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[ub]] [0, 0] : vector<[4]xf32> into vector<1x6x[4]xf32>
   // CHECK-NEXT: %[[subvec1:.*]] = vector.extract %[[arg0]][0, 1] : vector<[4]xf32> from vector<2x3x[4]xf32>
-  // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[subvec1]], %[[res0]] [0, 1] : vector<[4]xf32> into vector<3x2x[4]xf32>
+  // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[subvec1]], %[[res0]] [0, 1] : vector<[4]xf32> into vector<1x6x[4]xf32>
   // CHECK-NEXT: %[[subvec2:.*]] = vector.extract %[[arg0]][0, 2] : vector<[4]xf32> from vector<2x3x[4]xf32>
-  // CHECK-NEXT: %[[res2:.*]] = vector.insert %[[subvec2]], %[[res1]] [1, 0] : vector<[4]xf32> into vector<3x2x[4]xf32>
+  // CHECK-NEXT: %[[res2:.*]] = vector.insert %[[subvec2]], %[[res1]] [0, 2] : vector<[4]xf32> into vector<1x6x[4]xf32>
   // CHECK-NEXT: %[[subvec3:.*]] = vector.extract %[[arg0]][1, 0] : vector<[4]xf32> from vector<2x3x[4]xf32>
-  // CHECK-NEXT: %[[res3:.*]] = vector.insert %[[subvec3]], %[[res2]] [1, 1] : vector<[4]xf32> into vector<3x2x[4]xf32>
+  // CHECK-NEXT: %[[res3:.*]] = vector.insert %[[subvec3]], %[[res2]] [0, 3] : vector<[4]xf32> into vector<1x6x[4]xf32>
   // CHECK-NEXT: %[[subvec4:.*]] = vector.extract %[[arg0]][1, 1] : vector<[4]xf32> from vector<2x3x[4]xf32>
-  // CHECK-NEXT: %[[res4:.*]] = vector.insert %[[subvec4]], %[[res3]] [2, 0] : vector<[4]xf32> into vector<3x2x[4]xf32>
+  // CHECK-NEXT: %[[res4:.*]] = vector.insert %[[subvec4]], %[[res3]] [0, 4] : vector<[4]xf32> into vector<1x6x[4]xf32>
   // CHECK-NEXT: %[[subvec5:.*]] = vector.extract %[[arg0]][1, 2] : vector<[4]xf32> from vector<2x3x[4]xf32>
-  // CHECK-NEXT: %[[res5:.*]] = vector.insert %[[subvec5]], %[[res4]] [2, 1] : vector<[4]xf32> into vector<3x2x[4]xf32>
-  %res = vector.shape_cast %arg0: vector<2x3x[4]xf32> to vector<3x2x[4]xf32>
-  // CHECK-NEXT: return %[[res5]] : vector<3x2x[4]xf32>
-  return %res : vector<3x2x[4]xf32>
+  // CHECK-NEXT: %[[res5:.*]] = vector.insert %[[subvec5]], %[[res4]] [0, 5] : vector<[4]xf32> into vector<1x6x[4]xf32>
+  %res = vector.shape_cast %arg0: vector<2x3x[4]xf32> to vector<1x6x[4]xf32>
+  // CHECK-NEXT: return %[[res5]] : vector<1x6x[4]xf32>
+  return %res : vector<1x6x[4]xf32>
 }
 
 // -----
@@ -117,48 +117,48 @@ func.func @f64_flatten_leading_non_scalable_dims(%arg0: vector<2x2x[2]xf64>) ->
 
 // CHECK-LABEL: f32_reduce_trailing_scalable_dim
 // CHECK-SAME: %[[arg0:.*]]: vector<3x[4]xf32>
-func.func @f32_reduce_trailing_scalable_dim(%arg0: vector<3x[4]xf32>) -> vector<6x[2]xf32>
+func.func @f32_reduce_trailing_scalable_dim(%arg0: vector<3x[4]xf32>) -> vector<3x2x[2]xf32>
 {
-  // CHECK-NEXT: %[[ub:.*]] = ub.poison : vector<6x[2]xf32>
+  // CHECK-NEXT: %[[ub:.*]] = ub.poison : vector<3x2x[2]xf32>
   // CHECK-NEXT: %[[srcvec0:.*]] = vector.extract %[[arg0]][0] : vector<[4]xf32> from vector<3x[4]xf32>
   // CHECK-NEXT: %[[subvec0:.*]] = vector.scalable.extract %[[srcvec0]][0] : vector<[2]xf32> from vector<[4]xf32>
-  // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[ub]] [0] : vector<[2]xf32> into vector<6x[2]xf32>
+  // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[ub]] [0, 0] : vector<[2]xf32> into vector<3x2x[2]xf32>
   // CHECK-NEXT: %[[subvec1:.*]] = vector.scalable.extract %[[srcvec0]][2] : vector<[2]xf32> from vector<[4]xf32>
-  // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[subvec1]], %[[res0]] [1] : vector<[2]xf32> into vector<6x[2]xf32>
+  // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[subvec1]], %[[res0]] [0, 1] : vector<[2]xf32> into vector<3x2x[2]xf32>
   // CHECK-NEXT: %[[srcvec1:.*]] = vector.extract %[[arg0]][1] : vector<[4]xf32> from vector<3x[4]xf32>
   // CHECK-NEXT: %[[subvec2:.*]] = vector.scalable.extract %[[srcvec1]][0] : vector<[2]xf32> from vector<[4]xf32>
-  // CHECK-NEXT: %[[res2:.*]] = vector.insert %[[subvec2]], %[[res1]] [2] : vector<[2]xf32> into vector<6x[2]xf32>
+  // CHECK-NEXT: %[[res2:.*]] = vector.insert %[[subvec2]], %[[res1]] [1, 0] : vector<[2]xf32> into vector<3x2x[2]xf32>
   // CHECK-NEXT: %[[subvec3:.*]] = vector.scalable.extract %[[srcvec1]][2] : vector<[2]xf32> from vector<[4]xf32>
-  // CHECK-NEXT: %[[res3:.*]] = vector.insert %[[subvec3]], %[[res2]] [3] : vector<[2]xf32> into vector<6x[2]xf32>
+  // CHECK-NEXT: %[[res3:.*]] = vector.insert %[[subvec3]], %[[res2]] [1, 1] : vector<[2]xf32> into vector<3x2x[2]xf32>
   // CHECK-NEXT: %[[srcvec2:.*]] = vector.extract %[[arg0]][2] : vector<[4]xf32> from vector<3x[4]xf32>
   // CHECK-NEXT: %[[subvec4:.*]] = vector.scalable.extract %[[srcvec2]][0] : vector<[2]xf32> from vector<[4]xf32>
-  // CHECK-NEXT: %[[res4:.*]] = vector.insert %[[subvec4]], %[[res3]] [4] : vector<[2]xf32> into vector<6x[2]xf32>
+  // CHECK-NEXT: %[[res4:.*]] = vector.insert %[[subvec4]], %[[res3]] [2, 0] : vector<[2]xf32> into vector<3x2x[2]xf32>
   // CHECK-NEXT: %[[subvec5:.*]] = vector.scalable.extract %[[srcvec2]][2] : vector<[2]xf32> from vector<[4]xf32>
-  // CHECK-NEXT: %[[res5:.*]] = vector.insert %[[subvec5]], %[[res4]] [5] : vector<[2]xf32> into vector<6x[2]xf32>
-  %res = vector.shape_cast %arg0: vector<3x[4]xf32> to vector<6x[2]xf32>
-  // CHECK-NEXT: return %[[res5]] : vector<6x[2]xf32>
-  return %res: vector<6x[2]xf32>
+  // CHECK-NEXT: %[[res5:.*]] = vector.insert %[[subvec5]], %[[res4]] [2, 1] : vector<[2]xf32> into vector<3x2x[2]xf32>
+  %res = vector.shape_cast %arg0: vector<3x[4]xf32> to vector<3x2x[2]xf32>
+  // CHECK-NEXT: return %[[res5]] : vector<3x2x[2]xf32>
+  return %res: vector<3x2x[2]xf32>
 }
 
 // -----
 
 // CHECK-LABEL: f32_increase_trailing_scalable_dim
-// CHECK-SAME: %[[arg0:.*]]: vector<4x[2]xf32>
-func.func @f32_increase_trailing_scalable_dim(%arg0: vector<4x[2]xf32>) -> vector<2x[4]xf32>
+// CHECK-SAME: %[[arg0:.*]]: vector<2x2x[2]xf32>
+func.func @f32_increase_trailing_scalable_dim(%arg0: vector<2x2x[2]xf32>) -> vector<2x[4]xf32>
 {
   //  CHECK-DAG: %[[ub0:.*]] = ub.poison : vector<2x[4]xf32>
   //  CHECK-DAG: %[[ub1:.*]] = ub.poison : vector<[4]xf32>
-  // CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0] : vector<[2]xf32> from vector<4x[2]xf32>
+  // CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0, 0] : vector<[2]xf32> from vector<2x2x[2]xf32>
   // CHECK-NEXT: %[[resvec1:.*]] = vector.scalable.insert %[[subvec0]], %[[ub1]][0] : vector<[2]xf32> into vector<[4]xf32>
-  // CHECK-NEXT: %[[subvec1:.*]] = vector.extract %[[arg0]][1] : vector<[2]xf32> from vector<4x[2]xf32>
+  // CHECK-NEXT: %[[subvec1:.*]] = vector.extract %[[arg0]][0, 1] : vector<[2]xf32> from vector<2x2x[2]xf32>
   // CHECK-NEXT: %[[resvec2:.*]] = vector.scalable.insert %[[subvec1]], %[[resvec1]][2] : vector<[2]xf32> into vector<[4]xf32>
   // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[resvec2]], %[[ub0]] [0] : vector<[4]xf32> into vector<2x[4]xf32>
-  // CHECK-NEXT: %[[subvec3:.*]] = vector.extract %[[arg0]][2] : vector<[2]xf32> from vector<4x[2]xf32>
+  // CHECK-NEXT: %[[subvec3:.*]] = vector.extract %[[arg0]][1, 0] : vector<[2]xf32> from vector<2x2x[2]xf32>
   // CHECK-NEXT: %[[resvec4:.*]] = vector.scalable.insert %[[subvec3]], %[[ub1]][0] : vector<[2]xf32> into vector<[4]xf32>
-  // CHECK-NEXT: %[[subvec4:.*]] = vector.extract %[[arg0]][3] : vector<[2]xf32> from vector<4x[2]xf32>
+  // CHECK-NEXT: %[[subvec4:.*]] = vector.extract %[[arg0]][1, 1] : vector<[2]xf32> from vector<2x2x[2]xf32>
   // CHECK-NEXT: %[[resvec5:.*]] = vector.scalable.insert %[[subvec4]], %[[resvec4]][2] : vector<[2]xf32> into vector<[4]xf32>
   // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[resvec5]], %[[res0]] [1] : vector<[4]xf32> into vector<2x[4]xf32>
-  %res = vector.shape_cast %arg0: vector<4x[2]xf32> to vector<2x[4]xf32>
+  %res = vector.shape_cast %arg0: vector<2x2x[2]xf32> to vector<2x[4]xf32>
   // CHECK-NEXT: return %[[res1]] : vector<2x[4]xf32>
   return %res: vector<2x[4]xf32>
 }
diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
index ef32f8c6a1cdb..fbfe3789b871b 100644
--- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
@@ -57,27 +57,6 @@ func.func @shape_casts(%a: vector<2x2xf32>) -> (vector<4xf32>, vector<2x2xf32>)
   return %r0, %1 : vector<4xf32>, vector<2x2xf32>
 }
 
-// CHECK-LABEL: func @shape_cast_2d2d
-// CHECK-SAME: %[[A:.*]]: vector<3x2xf32>
-// CHECK: %[[UB:.*]] = ub.poison : vector<2x3xf32>
-// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : f32 from vector<3x2xf32>
-// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[UB]] [0, 0] : f32 into vector<2x3xf32>
-// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : f32 from vector<3x2xf32>
-// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] : f32 into vector<2x3xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : f32 from vector<3x2xf32>
-// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [0, 2] : f32 into vector<2x3xf32>
-// CHECK: %[[T6:.*]] = vector.extract %[[A]][1, 1] : f32 from vector<3x2xf32>
-// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0] : f32 into vector<2x3xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : f32 from vector<3x2xf32>
-// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 1] : f32 into vector<2x3xf32>
-// CHECK: %[[T10:.*]] = vector.extract %[[A]][2, 1] : f32 from vector<3x2xf32>
-// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 2] : f32 into vector<2x3xf32>
-// CHECK: return %[[T11]] : vector<2x3xf32>
-
-func.func @shape_cast_2d2d(%arg0 : vector<3x2xf32>) -> vector<2x3xf32> {
-  %s = vector.shape_cast %arg0: vector<3x2xf32> to vector<2x3xf32>
-  return %s : vector<2x3xf32>
-}
 
 // CHECK-LABEL: func @shape_cast_3d1d
 // CHECK-SAME: %[[A:.*]]: vector<1x3x2xf32>

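For reference, a sketch of the transitive fold that the updated isValidShapeCast check permits, mirroring the updated consecutive_shape_cast test above; %arg0 is an assumed vector<16xf16> value and the snippet is illustrative only:

  // Both casts are pure expansions of the same 16 elements, so
  // ShapeCastOp::fold combines them into a single cast.
  %0 = vector.shape_cast %arg0 : vector<16xf16> to vector<2x8xf16>
  %1 = vector.shape_cast %0 : vector<2x8xf16> to vector<2x2x4xf16>
  // After canonicalization:
  //   %1 = vector.shape_cast %arg0 : vector<16xf16> to vector<2x2x4xf16>
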
>From db1b71738607d73eb15c9fd14f94961dd6d0dc20 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 15 Apr 2025 18:56:31 -0700
Subject: [PATCH 4/4] tighten

---
 .../mlir/Dialect/Vector/IR/VectorOps.td       |  9 +-
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 95 +++++++++----------
 mlir/test/Dialect/Vector/canonicalize.mlir    | 16 ++--
 mlir/test/Dialect/Vector/ops.mlir             | 20 ++++
 4 files changed, 81 insertions(+), 59 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index a9e25f23ef90f..7d5b5048131d8 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2248,10 +2248,11 @@ def Vector_ShapeCastOp :
     vector. The element type remains the same, as does the number of elements
     (product of dimensions).
 
-    If reducing or preserving rank (n >= k), all result dimension sizes must be
-    products of contiguous source dimension sizes. If expanding rank (n < k),
-    source dimensions must all factor into contiguous sequences of destination
-    dimension sizes.
+    A shape_cast must be either collapsing or expanding. Collapsing means all
+    result dimension sizes are products of contiguous source dimension sizes.
+    Expanding means source dimensions all factor into contiguous sequences of
+    destination dimension sizes. Size 1 dimensions in source and destination
+    are ignored.
 
     Each source dim is expanded (or contiguous sequence of source dims combined)
     in source dimension list order (i.e. 0 <= i < n), to produce a contiguous
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 07b6baf961a3e..d0797db00d528 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5532,41 +5532,34 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
   setResultRanges(getResult(), argRanges.front());
 }
 
-/// Returns true if each element of 'a' is equal to the product of a contiguous
-/// sequence of the elements of 'b'. Returns false otherwise.
+/// Returns true if each element of 'a' is either 1 or equal to the product of a
+/// contiguous sequence of the elements of 'b'. Returns false otherwise.
+///
+/// This function assumes that the products of the elements of 'a' and 'b' are equal.
 static bool isExpandingShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
-  unsigned rankA = a.size();
-  unsigned rankB = b.size();
-  if (rankA > rankB) {
-    return false;
-  }
-
-  auto isOne = [](int64_t v) { return v == 1; };
-
-  // Special-case for n-D to 0-d shape cast. 'b' must be all ones to be shape
-  // casted to a 0-d vector.
-  if (rankA == 0 && llvm::all_of(b, isOne))
-    return true;
 
+  unsigned rankA = a.size();
   unsigned i = 0;
   unsigned j = 0;
-  while (i < rankA && j < rankB) {
+  while (i < rankA) {
+    if (a[i] == 1) {
+      ++i;
+      continue;
+    }
+
     int64_t dimA = a[i];
     int64_t dimB = 1;
-    while (dimB < dimA && j < rankB)
+
+    while (dimB < dimA) {
       dimB *= b[j++];
-    if (dimA != dimB)
-      break;
-    ++i;
+    }
 
-    // Handle the case when trailing dimensions are of size 1.
-    // Include them into the contiguous sequence.
-    if (i < rankA && llvm::all_of(a.slice(i), isOne))
-      i = rankA;
-    if (j < rankB && llvm::all_of(b.slice(j), isOne))
-      j = rankB;
+    if (dimA != dimB) {
+      return false;
+    }
+    ++i;
   }
-  return i == rankA && j == rankB;
+  return true;
 }
 
 static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
@@ -5582,7 +5575,8 @@ static LogicalResult verifyVectorShapeCast(Operation *op,
   ArrayRef<int64_t> inShape = sourceVectorType.getShape();
   ArrayRef<int64_t> outShape = resultVectorType.getShape();
 
-  // Check that product of source dim sizes matches product of result dim sizes.
+  // Check that product of source dim sizes matches product of result dim
+  // sizes.
   int64_t nInElms = std::accumulate(inShape.begin(), inShape.end(), 1LL,
                                     std::multiplies<int64_t>{});
   int64_t nOutElms = std::accumulate(outShape.begin(), outShape.end(), 1LL,
@@ -5702,8 +5696,8 @@ static VectorType trimTrailingOneDims(VectorType oldType) {
 ///
 /// Looks at `vector.shape_cast` Ops that simply "drop" the trailing unit
 /// dimension. If the input vector comes from `vector.create_mask` for which
-/// the corresponding mask input value is 1 (e.g. `%c1` below), then it is safe
-/// to fold shape_cast into create_mask.
+/// the corresponding mask input value is 1 (e.g. `%c1` below), then it is
+/// safe to fold shape_cast into create_mask.
 ///
 /// BEFORE:
 ///    %1 = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x[4]x1x1xi1>
@@ -5970,8 +5964,8 @@ LogicalResult TypeCastOp::verify() {
   auto resultType = getResultMemRefType();
   if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
       getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
-    return emitOpError(
-               "expects result and operand with same underlying scalar type: ")
+    return emitOpError("expects result and operand with same underlying "
+                       "scalar type: ")
            << resultType;
   if (extractShape(sourceType) != extractShape(resultType))
     return emitOpError(
@@ -6009,7 +6003,8 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
       return attr.reshape(getResultVectorType());
 
   // Eliminate identity transpose ops. This happens when the dimensions of the
-  // input vector remain in their original order after the transpose operation.
+  // input vector remain in their original order after the transpose
+  // operation.
   ArrayRef<int64_t> perm = getPermutation();
 
   // Check if the permutation of the dimensions contains sequential values:
@@ -6068,7 +6063,8 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
       return result;
     };
 
-    // Return if the input of 'transposeOp' is not defined by another transpose.
+    // Return if the input of 'transposeOp' is not defined by another
+    // transpose.
     vector::TransposeOp parentTransposeOp =
         transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
     if (!parentTransposeOp)
@@ -6212,8 +6208,9 @@ LogicalResult ConstantMaskOp::verify() {
       return emitOpError(
           "only supports 'none set' or 'all set' scalable dimensions");
   }
-  // Verify that if one mask dim size is zero, they all should be zero (because
-  // the mask region is a conjunction of each mask dimension interval).
+  // Verify that if one mask dim size is zero, they all should be zero
+  // (because the mask region is a conjunction of each mask dimension
+  // interval).
   bool anyZeros = llvm::is_contained(maskDimSizes, 0);
   bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
   if (anyZeros && !allZeros)
@@ -6251,7 +6248,8 @@ void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
 
 LogicalResult CreateMaskOp::verify() {
   auto vectorType = llvm::cast<VectorType>(getResult().getType());
-  // Verify that an operand was specified for each result vector each dimension.
+  // Verify that an operand was specified for each result vector each
+  // dimension.
   if (vectorType.getRank() == 0) {
     if (getNumOperands() != 1)
       return emitOpError(
@@ -6458,8 +6456,8 @@ void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
 void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
   OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
       MaskOp>::ensureTerminator(region, builder, loc);
-  // Keep the default yield terminator if the number of masked operations is not
-  // the expected. This case will trigger a verification failure.
+  // Keep the default yield terminator if the number of masked operations is
+  // not the expected. This case will trigger a verification failure.
   Block &block = region.front();
   if (block.getOperations().size() != 2)
     return;
@@ -6563,9 +6561,9 @@ LogicalResult MaskOp::fold(FoldAdaptor adaptor,
   return success();
 }
 
-// Elides empty vector.mask operations with or without return values. Propagates
-// the yielded values by the vector.yield terminator, if any, or erases the op,
-// otherwise.
+// Elides empty vector.mask operations with or without return values.
+// Propagates the yielded values by the vector.yield terminator, if any, or
+// erases the op, otherwise.
 class ElideEmptyMaskOp : public OpRewritePattern<MaskOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -6668,7 +6666,8 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
   if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
     return {};
 
-  // SplatElementsAttr::get treats single value for second arg as being a splat.
+  // SplatElementsAttr::get treats single value for second arg as being a
+  // splat.
   return SplatElementsAttr::get(getType(), {constOperand});
 }
 
@@ -6790,12 +6789,12 @@ Operation *mlir::vector::maskOperation(OpBuilder &builder,
 }
 
 /// Creates a vector select operation that picks values from `newValue` or
-/// `passthru` for each result vector lane based on `mask`. This utility is used
-/// to propagate the pass-thru value of vector.mask or for cases where only the
-/// pass-thru value propagation is needed. VP intrinsics do not support
-/// pass-thru values and every mask-out lane is set to poison. LLVM backends are
-/// usually able to match op + select patterns and fold them into a native
-/// target instructions.
+/// `passthru` for each result vector lane based on `mask`. This utility is
+/// used to propagate the pass-thru value of vector.mask or for cases where
+/// only the pass-thru value propagation is needed. VP intrinsics do not
+/// support pass-thru values and every mask-out lane is set to poison. LLVM
+/// backends are usually able to match op + select patterns and fold them into
+/// a native target instructions.
 Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
                                    Value newValue, Value passthru) {
   if (!mask)
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 8d24e1bf2ba94..986e11d948052 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -950,14 +950,16 @@ func.func @insert_no_fold_scalar_to_0d(%v: vector<f32>) -> vector<f32> {
 
 // -----
 
+// The definition of shape_cast stipulates that it must be either expanding or collapsing;
+// it cannot be a mixture of both.
 // CHECK-LABEL: dont_fold_expand_collapse
-//       CHECK:   %[[A:.*]] = vector.shape_cast %{{.*}} : vector<1x1x64xf32> to vector<1x1x8x8xf32>
-//       CHECK:   %[[B:.*]] = vector.shape_cast %{{.*}} : vector<1x1x8x8xf32> to vector<8x8xf32>
-//       CHECK:   return %[[B]] : vector<8x8xf32>
-func.func @dont_fold_expand_collapse(%arg0: vector<1x1x64xf32>) -> vector<8x8xf32> {
-    %0 = vector.shape_cast %arg0 : vector<1x1x64xf32> to vector<1x1x8x8xf32>
-    %1 = vector.shape_cast %0 : vector<1x1x8x8xf32> to vector<8x8xf32>
-    return %1 : vector<8x8xf32>
+//       CHECK:   %[[A:.*]] = vector.shape_cast %{{.*}} : vector<2x2x9xf32> to vector<2x2x3x3xf32>
+//       CHECK:   %[[B:.*]] = vector.shape_cast %{{.*}} : vector<2x2x3x3xf32> to vector<4x3x3xf32>
+//       CHECK:   return %[[B]] : vector<4x3x3xf32>
+func.func @dont_fold_expand_collapse(%arg0: vector<2x2x9xf32>) -> vector<4x3x3xf32> {
+    %0 = vector.shape_cast %arg0 : vector<2x2x9xf32> to vector<2x2x3x3xf32>
+    %1 = vector.shape_cast %0 : vector<2x2x3x3xf32> to vector<4x3x3xf32>
+    return %1 : vector<4x3x3xf32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 527bccf8383ca..504c6c300e9f0 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -581,9 +581,29 @@ func.func @shape_cast_rank_preserving(%arg0 : vector<1x4xf32>) -> vector<4x1xf32
 
   // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4x1xf32>
   %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4x1xf32>
+
   return %0 : vector<4x1xf32>
 }
 
+
+// CHECK-LABEL: @collapse_but_increase_rank
+func.func @collapse_but_increase_rank(%arg0 : vector<2x3x5x7xf32>) -> vector<1x6x1x35x1xf32> {
+
+  // CHECK: vector.shape_cast %{{.*}} : vector<2x3x5x7xf32> to vector<1x6x1x35x1xf32>
+  %0 = vector.shape_cast %arg0 : vector<2x3x5x7xf32> to vector<1x6x1x35x1xf32>
+
+  return %0 : vector<1x6x1x35x1xf32>
+}
+
+// CHECK-LABEL: @expand_but_decrease_rank
+func.func @expand_but_decrease_rank(%arg0 : vector<1x1x6xi8>) -> vector<2x3xi8> {
+
+  // CHECK: vector.shape_cast %{{.*}} : vector<1x1x6xi8> to vector<2x3xi8>
+  %0 = vector.shape_cast %arg0 : vector<1x1x6xi8> to vector<2x3xi8>
+
+  return %0 : vector<2x3xi8>
+}
+
 // CHECK-LABEL: @bitcast
 func.func @bitcast(%arg0 : vector<5x1x3x2xf32>,
                  %arg1 : vector<8x1xi32>,

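To make the "size 1 dimensions ... are ignored" wording concrete, a small example along the lines of the ops.mlir tests added above; %arg0 is an assumed vector<2x3x5x7xf32> value and the snippet is illustrative only:

  // A collapse that nevertheless increases rank: ignoring the unit result
  // dimensions, 6 = 2 * 3 and 35 = 5 * 7 are products of contiguous source
  // dimensions, so the cast is valid.
  %0 = vector.shape_cast %arg0 : vector<2x3x5x7xf32> to vector<1x6x1x35x1xf32>
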

