[Mlir-commits] [mlir] [mlir][vector] Better handle rank-preserving shape_cast (PR #135855)
James Newling
llvmlistbot at llvm.org
Tue Apr 15 16:14:56 PDT 2025
https://github.com/newling updated https://github.com/llvm/llvm-project/pull/135855
From c1f4264a71d6d80350056d0d9ca86a0ac2c1e04f Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 15 Apr 2025 13:09:43 -0700
Subject: [PATCH 1/3] fix edge case where n=k (rank-preserving shape_cast)
---
.../mlir/Dialect/Vector/IR/VectorOps.td | 17 ++---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 63 ++++++++++---------
mlir/test/Dialect/Vector/invalid.mlir | 15 +++--
mlir/test/Dialect/Vector/ops.mlir | 8 +++
4 files changed, 61 insertions(+), 42 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 7fc56b1aa4e7e..a9e25f23ef90f 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2244,18 +2244,19 @@ def Vector_ShapeCastOp :
Results<(outs AnyVectorOfAnyRank:$result)> {
let summary = "shape_cast casts between vector shapes";
let description = [{
- The shape_cast operation casts between an n-D source vector shape and
- a k-D result vector shape (the element type remains the same).
+ The shape_cast operation casts from an n-D source vector to a k-D result
+ vector. The element type remains the same, as does the number of elements
+ (product of dimensions).
+
+ If reducing or preserving rank (n >= k), each result dimension size must be
+ a product of contiguous source dimension sizes. If expanding rank (n < k),
+ each source dimension size must be a product of a contiguous sequence of
+ result dimension sizes.
- If reducing rank (n > k), result dimension sizes must be a product
- of contiguous source dimension sizes.
- If expanding rank (n < k), source dimensions must factor into a
- contiguous sequence of destination dimension sizes.
Each source dim is expanded (or a contiguous sequence of source dims combined)
in source dimension list order (i.e. 0 <= i < n), to produce a contiguous
sequence of result dims (or a single result dim), in result dimension list
- order (i.e. 0 <= j < k). The product of all source dimension sizes and all
- result dimension sizes must match.
+ order (i.e. 0 <= j < k).
It is currently assumed that this operation does not require moving data,
and that it will be folded away before lowering vector operations.
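For illustration, here is how the reworded rule plays out on concrete shapes
(a standalone sketch, with shapes drawn from the tests in this patch; not
part of the diff itself):

  // Rank-reducing (n > k): valid, since 15 = 5 * 1 * 3 and 2 = 2 are
  // products of contiguous source dimension sizes.
  func.func @collapse(%arg0 : vector<5x1x3x2xf32>) -> vector<15x2xf32> {
    %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<15x2xf32>
    return %0 : vector<15x2xf32>
  }

  // Rank-preserving (n == k): valid, only unit dims move.
  func.func @preserve(%arg0 : vector<1x4xf32>) -> vector<4x1xf32> {
    %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4x1xf32>
    return %0 : vector<4x1xf32>
  }

  // Rank-preserving transpose: invalid. Neither 2 nor 3 is the product of
  // a contiguous run of (3, 2) consumed in order, in either direction.
  // %0 = vector.shape_cast %arg0 : vector<3x2xf32> to vector<2x3xf32>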
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index bee5c1fd6ed58..554dbba081898 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5534,10 +5534,10 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
/// Returns true if each element of 'a' is equal to the product of a contiguous
/// sequence of the elements of 'b'. Returns false otherwise.
-static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
+static bool isValidExpandingShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
unsigned rankA = a.size();
unsigned rankB = b.size();
- assert(rankA < rankB);
+ assert(rankA <= rankB);
auto isOne = [](int64_t v) { return v == 1; };
@@ -5573,34 +5573,36 @@ static LogicalResult verifyVectorShapeCast(Operation *op,
VectorType resultVectorType) {
// Check that element type is the same.
if (sourceVectorType.getElementType() != resultVectorType.getElementType())
- return op->emitOpError("source/result vectors must have same element type");
- auto sourceShape = sourceVectorType.getShape();
- auto resultShape = resultVectorType.getShape();
+ return op->emitOpError("has different source and result element types");
+ ArrayRef<int64_t> lowRankShape = sourceVectorType.getShape();
+ ArrayRef<int64_t> highRankShape = resultVectorType.getShape();
+ if (lowRankShape.size() > highRankShape.size())
+ std::swap(lowRankShape, highRankShape);
// Check that product of source dim sizes matches product of result dim sizes.
- int64_t sourceDimProduct = std::accumulate(
- sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
- int64_t resultDimProduct = std::accumulate(
- resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
- if (sourceDimProduct != resultDimProduct)
- return op->emitOpError("source/result number of elements must match");
-
- // Check that expanding/contracting rank cases.
- unsigned sourceRank = sourceVectorType.getRank();
- unsigned resultRank = resultVectorType.getRank();
- if (sourceRank < resultRank) {
- if (!isValidShapeCast(sourceShape, resultShape))
- return op->emitOpError("invalid shape cast");
- } else if (sourceRank > resultRank) {
- if (!isValidShapeCast(resultShape, sourceShape))
- return op->emitOpError("invalid shape cast");
+ int64_t nLowRankElms =
+ std::accumulate(lowRankShape.begin(), lowRankShape.end(), 1LL,
+ std::multiplies<int64_t>{});
+ int64_t nHighRankElms =
+ std::accumulate(highRankShape.begin(), highRankShape.end(), 1LL,
+ std::multiplies<int64_t>{});
+
+ if (nLowRankElms != nHighRankElms) {
+ return op->emitOpError(
+ "has a different number of source and result elements");
+ }
+
+ if (!isValidExpandingShapeCast(lowRankShape, highRankShape)) {
+ return op->emitOpError(
+ "is invalid (does not uniformly collapse or expand)");
}
// Check that (non-)scalability is preserved
int64_t sourceNScalableDims = sourceVectorType.getNumScalableDims();
int64_t resultNScalableDims = resultVectorType.getNumScalableDims();
if (sourceNScalableDims != resultNScalableDims)
- return op->emitOpError("different number of scalable dims at source (")
+ return op->emitOpError(
+ "has a different number of scalable dims at source (")
<< sourceNScalableDims << ") and result (" << resultNScalableDims
<< ")";
sourceVectorType.getNumDynamicDims();
@@ -5634,17 +5636,18 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
// Only allows valid transitive folding (expand/collapse dimensions).
VectorType srcType = otherOp.getSource().getType();
+
if (resultType == srcType)
return otherOp.getSource();
- if (srcType.getRank() < resultType.getRank()) {
- if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
- return {};
- } else if (srcType.getRank() > resultType.getRank()) {
- if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
- return {};
- } else {
+
+ ArrayRef<int64_t> lowRankShape = srcType.getShape();
+ ArrayRef<int64_t> highRankShape = resultType.getShape();
+ if (lowRankShape.size() > highRankShape.size())
+ std::swap(lowRankShape, highRankShape);
+
+ if (!isValidExpandingShapeCast(lowRankShape, highRankShape))
return {};
- }
+
setOperand(otherOp.getSource());
return getResult();
}
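In words, the check performed by isValidExpandingShapeCast (a paraphrase of
the code above, not authoritative): walk the lower-rank shape left to right,
and for each of its dimension sizes consume a contiguous run of the
higher-rank shape's dimensions whose product matches; remaining unit dims are
absorbed by the isOne test. The invalid.mlir cases below fail exactly this
walk, e.g.:

  // No leading contiguous run of (5, 1, 3, 2) has product 2 (the candidate
  // products are 5, 5, 15, 30), so verification now reports:
  //   error: 'vector.shape_cast' op is invalid (does not uniformly collapse or expand)
  // %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<2x15xf32>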
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index dbf829e014b8d..9f94fb0574504 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1132,28 +1132,35 @@ func.func @cannot_print_string_with_source_set(%vec: vector<[4]xf32>) {
// -----
func.func @shape_cast_wrong_element_type(%arg0 : vector<5x1x3x2xf32>) {
- // expected-error @+1 {{op source/result vectors must have same element type}}
+ // expected-error @+1 {{op has different source and result element types}}
%0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<15x2xi32>
}
// -----
func.func @shape_cast_wrong_num_elements(%arg0 : vector<5x1x3x2xf32>) {
- // expected-error @+1 {{op source/result number of elements must match}}
+ // expected-error @+1 {{op has a different number of source and result elements}}
%0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<10x2xf32>
}
// -----
+func.func @shape_cast_invalid_rank_preserving(%arg0 : vector<3x2xf32>) {
+ // expected-error @+1 {{op is invalid (does not uniformly collapse or expand)}}
+ %0 = vector.shape_cast %arg0 : vector<3x2xf32> to vector<2x3xf32>
+}
+
+// -----
+
func.func @shape_cast_invalid_rank_reduction(%arg0 : vector<5x1x3x2xf32>) {
- // expected-error @+1 {{invalid shape cast}}
+ // expected-error @+1 {{op is invalid (does not uniformly collapse or expand)}}
%0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<2x15xf32>
}
// -----
func.func @shape_cast_invalid_rank_expansion(%arg0 : vector<15x2xf32>) {
- // expected-error @+1 {{invalid shape cast}}
+ // expected-error @+1 {{op is invalid (does not uniformly collapse or expand)}}
%0 = vector.shape_cast %arg0 : vector<15x2xf32> to vector<5x2x3x1xf32>
}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 8ae1e9f9d0c64..527bccf8383ca 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -576,6 +576,14 @@ func.func @shape_cast_0d(%arg0 : vector<1x1x1x1xf32>) -> (vector<1x1x1x1xf32>) {
return %1 : vector<1x1x1x1xf32>
}
+// CHECK-LABEL: @shape_cast_rank_preserving
+func.func @shape_cast_rank_preserving(%arg0 : vector<1x4xf32>) -> vector<4x1xf32> {
+
+ // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4x1xf32>
+ %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4x1xf32>
+ return %0 : vector<4x1xf32>
+}
+
// CHECK-LABEL: @bitcast
func.func @bitcast(%arg0 : vector<5x1x3x2xf32>,
%arg1 : vector<8x1xi32>,
From b4914638b9426472c780cc94ac29224353d1b9d4 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 15 Apr 2025 15:10:54 -0700
Subject: [PATCH 2/3] clang-format
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 554dbba081898..120dd57659e6b 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5534,7 +5534,8 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
/// Returns true if each element of 'a' is equal to the product of a contiguous
/// sequence of the elements of 'b'. Returns false otherwise.
-static bool isValidExpandingShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
+static bool isValidExpandingShapeCast(ArrayRef<int64_t> a,
+ ArrayRef<int64_t> b) {
unsigned rankA = a.size();
unsigned rankB = b.size();
assert(rankA <= rankB);
From 7bfc219d9d3e4881a70d6866ec59b077da22e71b Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 15 Apr 2025 16:14:40 -0700
Subject: [PATCH 3/3] update tests
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 41 ++++++-------
mlir/test/Dialect/Vector/canonicalize.mlir | 10 ++--
...-shape-cast-lowering-scalable-vectors.mlir | 58 +++++++++----------
...vector-shape-cast-lowering-transforms.mlir | 21 -------
4 files changed, 52 insertions(+), 78 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 120dd57659e6b..07b6baf961a3e 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5534,11 +5534,12 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
/// Returns true if each element of 'a' is equal to the product of a contiguous
/// sequence of the elements of 'b'. Returns false otherwise.
-static bool isValidExpandingShapeCast(ArrayRef<int64_t> a,
- ArrayRef<int64_t> b) {
+static bool isExpandingShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
unsigned rankA = a.size();
unsigned rankB = b.size();
- assert(rankA <= rankB);
+ if (rankA > rankB) {
+ return false;
+ }
auto isOne = [](int64_t v) { return v == 1; };
@@ -5565,35 +5566,34 @@ static bool isValidExpandingShapeCast(ArrayRef<int64_t> a,
if (j < rankB && llvm::all_of(b.slice(j), isOne))
j = rankB;
}
-
return i == rankA && j == rankB;
}
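+/// Returns true if a shape_cast between shapes 'a' and 'b' is valid, i.e.
+/// one of the two shapes is a uniform expansion of the other.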
+static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
+ return isExpandingShapeCast(a, b) || isExpandingShapeCast(b, a);
+}
+
static LogicalResult verifyVectorShapeCast(Operation *op,
VectorType sourceVectorType,
VectorType resultVectorType) {
// Check that element type is the same.
if (sourceVectorType.getElementType() != resultVectorType.getElementType())
return op->emitOpError("has different source and result element types");
- ArrayRef<int64_t> lowRankShape = sourceVectorType.getShape();
- ArrayRef<int64_t> highRankShape = resultVectorType.getShape();
- if (lowRankShape.size() > highRankShape.size())
- std::swap(lowRankShape, highRankShape);
+ ArrayRef<int64_t> inShape = sourceVectorType.getShape();
+ ArrayRef<int64_t> outShape = resultVectorType.getShape();
// Check that product of source dim sizes matches product of result dim sizes.
- int64_t nLowRankElms =
- std::accumulate(lowRankShape.begin(), lowRankShape.end(), 1LL,
- std::multiplies<int64_t>{});
- int64_t nHighRankElms =
- std::accumulate(highRankShape.begin(), highRankShape.end(), 1LL,
- std::multiplies<int64_t>{});
-
- if (nLowRankElms != nHighRankElms) {
+ int64_t nInElms = std::accumulate(inShape.begin(), inShape.end(), 1LL,
+ std::multiplies<int64_t>{});
+ int64_t nOutElms = std::accumulate(outShape.begin(), outShape.end(), 1LL,
+ std::multiplies<int64_t>{});
+
+ if (nInElms != nOutElms) {
return op->emitOpError(
"has a different number of source and result elements");
}
- if (!isValidExpandingShapeCast(lowRankShape, highRankShape)) {
+ if (!isValidShapeCast(inShape, outShape)) {
return op->emitOpError(
"is invalid (does not uniformly collapse or expand)");
}
@@ -5641,12 +5641,7 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
if (resultType == srcType)
return otherOp.getSource();
- ArrayRef<int64_t> lowRankShape = srcType.getShape();
- ArrayRef<int64_t> highRankShape = resultType.getShape();
- if (lowRankShape.size() > highRankShape.size())
- std::swap(lowRankShape, highRankShape);
-
- if (!isValidExpandingShapeCast(lowRankShape, highRankShape))
+ if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
return {};
setOperand(otherOp.getSource());
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 78b0ea78849e8..8d24e1bf2ba94 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1290,12 +1290,12 @@ func.func @extract_strided_broadcast4(%arg0: f32) -> vector<1x4xf32> {
// -----
// CHECK-LABEL: consecutive_shape_cast
-// CHECK: %[[C:.*]] = vector.shape_cast %{{.*}} : vector<16xf16> to vector<4x4xf16>
-// CHECK-NEXT: return %[[C]] : vector<4x4xf16>
-func.func @consecutive_shape_cast(%arg0: vector<16xf16>) -> vector<4x4xf16> {
+// CHECK: %[[C:.*]] = vector.shape_cast %{{.*}} : vector<16xf16> to vector<2x2x4xf16>
+// CHECK-NEXT: return %[[C]] : vector<2x2x4xf16>
+func.func @consecutive_shape_cast(%arg0: vector<16xf16>) -> vector<2x2x4xf16> {
%0 = vector.shape_cast %arg0 : vector<16xf16> to vector<2x8xf16>
- %1 = vector.shape_cast %0 : vector<2x8xf16> to vector<4x4xf16>
- return %1 : vector<4x4xf16>
+ %1 = vector.shape_cast %0 : vector<2x8xf16> to vector<2x2x4xf16>
+ return %1 : vector<2x2x4xf16>
}
// -----
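The updated test reflects the fold's behavior under the new validity check
(a sketch of the canonicalization, matching the CHECK lines above): two
back-to-back shape_casts fold into one whenever the end-to-end cast is
itself a uniform collapse or expand.

  // Before canonicalization:
  %0 = vector.shape_cast %arg0 : vector<16xf16> to vector<2x8xf16>
  %1 = vector.shape_cast %0 : vector<2x8xf16> to vector<2x2x4xf16>
  // After: a single cast, since 16 -> 2x2x4 is a valid expansion
  // (16 = 2 * 2 * 4).
  %1 = vector.shape_cast %arg0 : vector<16xf16> to vector<2x2x4xf16>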
diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir
index f4becad3c79c1..2faa47c1b08a8 100644
--- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir
+++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir
@@ -74,23 +74,23 @@ func.func @i8_1d_to_2d_last_dim_scalable(%arg0: vector<[32]xi8>) -> vector<4x[8]
// CHECK-LABEL: f32_permute_leading_non_scalable_dims
// CHECK-SAME: %[[arg0:.*]]: vector<2x3x[4]xf32>
-func.func @f32_permute_leading_non_scalable_dims(%arg0: vector<2x3x[4]xf32>) -> vector<3x2x[4]xf32> {
- // CHECK-NEXT: %[[ub:.*]] = ub.poison : vector<3x2x[4]xf32>
+func.func @f32_permute_leading_non_scalable_dims(%arg0: vector<2x3x[4]xf32>) -> vector<1x6x[4]xf32> {
+ // CHECK-NEXT: %[[ub:.*]] = ub.poison : vector<1x6x[4]xf32>
// CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0, 0] : vector<[4]xf32> from vector<2x3x[4]xf32>
- // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[ub]] [0, 0] : vector<[4]xf32> into vector<3x2x[4]xf32>
+ // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[ub]] [0, 0] : vector<[4]xf32> into vector<1x6x[4]xf32>
// CHECK-NEXT: %[[subvec1:.*]] = vector.extract %[[arg0]][0, 1] : vector<[4]xf32> from vector<2x3x[4]xf32>
- // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[subvec1]], %[[res0]] [0, 1] : vector<[4]xf32> into vector<3x2x[4]xf32>
+ // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[subvec1]], %[[res0]] [0, 1] : vector<[4]xf32> into vector<1x6x[4]xf32>
// CHECK-NEXT: %[[subvec2:.*]] = vector.extract %[[arg0]][0, 2] : vector<[4]xf32> from vector<2x3x[4]xf32>
- // CHECK-NEXT: %[[res2:.*]] = vector.insert %[[subvec2]], %[[res1]] [1, 0] : vector<[4]xf32> into vector<3x2x[4]xf32>
+ // CHECK-NEXT: %[[res2:.*]] = vector.insert %[[subvec2]], %[[res1]] [0, 2] : vector<[4]xf32> into vector<1x6x[4]xf32>
// CHECK-NEXT: %[[subvec3:.*]] = vector.extract %[[arg0]][1, 0] : vector<[4]xf32> from vector<2x3x[4]xf32>
- // CHECK-NEXT: %[[res3:.*]] = vector.insert %[[subvec3]], %[[res2]] [1, 1] : vector<[4]xf32> into vector<3x2x[4]xf32>
+ // CHECK-NEXT: %[[res3:.*]] = vector.insert %[[subvec3]], %[[res2]] [0, 3] : vector<[4]xf32> into vector<1x6x[4]xf32>
// CHECK-NEXT: %[[subvec4:.*]] = vector.extract %[[arg0]][1, 1] : vector<[4]xf32> from vector<2x3x[4]xf32>
- // CHECK-NEXT: %[[res4:.*]] = vector.insert %[[subvec4]], %[[res3]] [2, 0] : vector<[4]xf32> into vector<3x2x[4]xf32>
+ // CHECK-NEXT: %[[res4:.*]] = vector.insert %[[subvec4]], %[[res3]] [0, 4] : vector<[4]xf32> into vector<1x6x[4]xf32>
// CHECK-NEXT: %[[subvec5:.*]] = vector.extract %[[arg0]][1, 2] : vector<[4]xf32> from vector<2x3x[4]xf32>
- // CHECK-NEXT: %[[res5:.*]] = vector.insert %[[subvec5]], %[[res4]] [2, 1] : vector<[4]xf32> into vector<3x2x[4]xf32>
- %res = vector.shape_cast %arg0: vector<2x3x[4]xf32> to vector<3x2x[4]xf32>
- // CHECK-NEXT: return %[[res5]] : vector<3x2x[4]xf32>
- return %res : vector<3x2x[4]xf32>
+ // CHECK-NEXT: %[[res5:.*]] = vector.insert %[[subvec5]], %[[res4]] [0, 5] : vector<[4]xf32> into vector<1x6x[4]xf32>
+ %res = vector.shape_cast %arg0: vector<2x3x[4]xf32> to vector<1x6x[4]xf32>
+ // CHECK-NEXT: return %[[res5]] : vector<1x6x[4]xf32>
+ return %res : vector<1x6x[4]xf32>
}
// -----
@@ -117,48 +117,48 @@ func.func @f64_flatten_leading_non_scalable_dims(%arg0: vector<2x2x[2]xf64>) ->
// CHECK-LABEL: f32_reduce_trailing_scalable_dim
// CHECK-SAME: %[[arg0:.*]]: vector<3x[4]xf32>
-func.func @f32_reduce_trailing_scalable_dim(%arg0: vector<3x[4]xf32>) -> vector<6x[2]xf32>
+func.func @f32_reduce_trailing_scalable_dim(%arg0: vector<3x[4]xf32>) -> vector<3x2x[2]xf32>
{
- // CHECK-NEXT: %[[ub:.*]] = ub.poison : vector<6x[2]xf32>
+ // CHECK-NEXT: %[[ub:.*]] = ub.poison : vector<3x2x[2]xf32>
// CHECK-NEXT: %[[srcvec0:.*]] = vector.extract %[[arg0]][0] : vector<[4]xf32> from vector<3x[4]xf32>
// CHECK-NEXT: %[[subvec0:.*]] = vector.scalable.extract %[[srcvec0]][0] : vector<[2]xf32> from vector<[4]xf32>
- // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[ub]] [0] : vector<[2]xf32> into vector<6x[2]xf32>
+ // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[ub]] [0, 0] : vector<[2]xf32> into vector<3x2x[2]xf32>
// CHECK-NEXT: %[[subvec1:.*]] = vector.scalable.extract %[[srcvec0]][2] : vector<[2]xf32> from vector<[4]xf32>
- // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[subvec1]], %[[res0]] [1] : vector<[2]xf32> into vector<6x[2]xf32>
+ // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[subvec1]], %[[res0]] [0, 1] : vector<[2]xf32> into vector<3x2x[2]xf32>
// CHECK-NEXT: %[[srcvec1:.*]] = vector.extract %[[arg0]][1] : vector<[4]xf32> from vector<3x[4]xf32>
// CHECK-NEXT: %[[subvec2:.*]] = vector.scalable.extract %[[srcvec1]][0] : vector<[2]xf32> from vector<[4]xf32>
- // CHECK-NEXT: %[[res2:.*]] = vector.insert %[[subvec2]], %[[res1]] [2] : vector<[2]xf32> into vector<6x[2]xf32>
+ // CHECK-NEXT: %[[res2:.*]] = vector.insert %[[subvec2]], %[[res1]] [1, 0] : vector<[2]xf32> into vector<3x2x[2]xf32>
// CHECK-NEXT: %[[subvec3:.*]] = vector.scalable.extract %[[srcvec1]][2] : vector<[2]xf32> from vector<[4]xf32>
- // CHECK-NEXT: %[[res3:.*]] = vector.insert %[[subvec3]], %[[res2]] [3] : vector<[2]xf32> into vector<6x[2]xf32>
+ // CHECK-NEXT: %[[res3:.*]] = vector.insert %[[subvec3]], %[[res2]] [1, 1] : vector<[2]xf32> into vector<3x2x[2]xf32>
// CHECK-NEXT: %[[srcvec2:.*]] = vector.extract %[[arg0]][2] : vector<[4]xf32> from vector<3x[4]xf32>
// CHECK-NEXT: %[[subvec4:.*]] = vector.scalable.extract %[[srcvec2]][0] : vector<[2]xf32> from vector<[4]xf32>
- // CHECK-NEXT: %[[res4:.*]] = vector.insert %[[subvec4]], %[[res3]] [4] : vector<[2]xf32> into vector<6x[2]xf32>
+ // CHECK-NEXT: %[[res4:.*]] = vector.insert %[[subvec4]], %[[res3]] [2, 0] : vector<[2]xf32> into vector<3x2x[2]xf32>
// CHECK-NEXT: %[[subvec5:.*]] = vector.scalable.extract %[[srcvec2]][2] : vector<[2]xf32> from vector<[4]xf32>
- // CHECK-NEXT: %[[res5:.*]] = vector.insert %[[subvec5]], %[[res4]] [5] : vector<[2]xf32> into vector<6x[2]xf32>
- %res = vector.shape_cast %arg0: vector<3x[4]xf32> to vector<6x[2]xf32>
- // CHECK-NEXT: return %[[res5]] : vector<6x[2]xf32>
- return %res: vector<6x[2]xf32>
+ // CHECK-NEXT: %[[res5:.*]] = vector.insert %[[subvec5]], %[[res4]] [2, 1] : vector<[2]xf32> into vector<3x2x[2]xf32>
+ %res = vector.shape_cast %arg0: vector<3x[4]xf32> to vector<3x2x[2]xf32>
+ // CHECK-NEXT: return %[[res5]] : vector<3x2x[2]xf32>
+ return %res: vector<3x2x[2]xf32>
}
// -----
// CHECK-LABEL: f32_increase_trailing_scalable_dim
-// CHECK-SAME: %[[arg0:.*]]: vector<4x[2]xf32>
-func.func @f32_increase_trailing_scalable_dim(%arg0: vector<4x[2]xf32>) -> vector<2x[4]xf32>
+// CHECK-SAME: %[[arg0:.*]]: vector<2x2x[2]xf32>
+func.func @f32_increase_trailing_scalable_dim(%arg0: vector<2x2x[2]xf32>) -> vector<2x[4]xf32>
{
// CHECK-DAG: %[[ub0:.*]] = ub.poison : vector<2x[4]xf32>
// CHECK-DAG: %[[ub1:.*]] = ub.poison : vector<[4]xf32>
- // CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0] : vector<[2]xf32> from vector<4x[2]xf32>
+ // CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0, 0] : vector<[2]xf32> from vector<2x2x[2]xf32>
// CHECK-NEXT: %[[resvec1:.*]] = vector.scalable.insert %[[subvec0]], %[[ub1]][0] : vector<[2]xf32> into vector<[4]xf32>
- // CHECK-NEXT: %[[subvec1:.*]] = vector.extract %[[arg0]][1] : vector<[2]xf32> from vector<4x[2]xf32>
+ // CHECK-NEXT: %[[subvec1:.*]] = vector.extract %[[arg0]][0, 1] : vector<[2]xf32> from vector<2x2x[2]xf32>
// CHECK-NEXT: %[[resvec2:.*]] = vector.scalable.insert %[[subvec1]], %[[resvec1]][2] : vector<[2]xf32> into vector<[4]xf32>
// CHECK-NEXT: %[[res0:.*]] = vector.insert %[[resvec2]], %[[ub0]] [0] : vector<[4]xf32> into vector<2x[4]xf32>
- // CHECK-NEXT: %[[subvec3:.*]] = vector.extract %[[arg0]][2] : vector<[2]xf32> from vector<4x[2]xf32>
+ // CHECK-NEXT: %[[subvec3:.*]] = vector.extract %[[arg0]][1, 0] : vector<[2]xf32> from vector<2x2x[2]xf32>
// CHECK-NEXT: %[[resvec4:.*]] = vector.scalable.insert %[[subvec3]], %[[ub1]][0] : vector<[2]xf32> into vector<[4]xf32>
- // CHECK-NEXT: %[[subvec4:.*]] = vector.extract %[[arg0]][3] : vector<[2]xf32> from vector<4x[2]xf32>
+ // CHECK-NEXT: %[[subvec4:.*]] = vector.extract %[[arg0]][1, 1] : vector<[2]xf32> from vector<2x2x[2]xf32>
// CHECK-NEXT: %[[resvec5:.*]] = vector.scalable.insert %[[subvec4]], %[[resvec4]][2] : vector<[2]xf32> into vector<[4]xf32>
// CHECK-NEXT: %[[res1:.*]] = vector.insert %[[resvec5]], %[[res0]] [1] : vector<[4]xf32> into vector<2x[4]xf32>
- %res = vector.shape_cast %arg0: vector<4x[2]xf32> to vector<2x[4]xf32>
+ %res = vector.shape_cast %arg0: vector<2x2x[2]xf32> to vector<2x[4]xf32>
// CHECK-NEXT: return %[[res1]] : vector<2x[4]xf32>
return %res: vector<2x[4]xf32>
}
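The rewritten scalable tests keep each cast a uniform collapse or expand
while preserving the number of scalable dims, as the verifier requires. For
example (shapes taken from the updated test above):

  // The scalable [4] splits into 2 x [2]; one scalable dim on each side.
  func.func @split_scalable(%arg0 : vector<3x[4]xf32>) -> vector<3x2x[2]xf32> {
    %0 = vector.shape_cast %arg0 : vector<3x[4]xf32> to vector<3x2x[2]xf32>
    return %0 : vector<3x2x[2]xf32>
  }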
diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
index ef32f8c6a1cdb..fbfe3789b871b 100644
--- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
@@ -57,27 +57,6 @@ func.func @shape_casts(%a: vector<2x2xf32>) -> (vector<4xf32>, vector<2x2xf32>)
return %r0, %1 : vector<4xf32>, vector<2x2xf32>
}
-// CHECK-LABEL: func @shape_cast_2d2d
-// CHECK-SAME: %[[A:.*]]: vector<3x2xf32>
-// CHECK: %[[UB:.*]] = ub.poison : vector<2x3xf32>
-// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : f32 from vector<3x2xf32>
-// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[UB]] [0, 0] : f32 into vector<2x3xf32>
-// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : f32 from vector<3x2xf32>
-// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] : f32 into vector<2x3xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : f32 from vector<3x2xf32>
-// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [0, 2] : f32 into vector<2x3xf32>
-// CHECK: %[[T6:.*]] = vector.extract %[[A]][1, 1] : f32 from vector<3x2xf32>
-// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0] : f32 into vector<2x3xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : f32 from vector<3x2xf32>
-// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 1] : f32 into vector<2x3xf32>
-// CHECK: %[[T10:.*]] = vector.extract %[[A]][2, 1] : f32 from vector<3x2xf32>
-// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 2] : f32 into vector<2x3xf32>
-// CHECK: return %[[T11]] : vector<2x3xf32>
-
-func.func @shape_cast_2d2d(%arg0 : vector<3x2xf32>) -> vector<2x3xf32> {
- %s = vector.shape_cast %arg0: vector<3x2xf32> to vector<2x3xf32>
- return %s : vector<2x3xf32>
-}
// CHECK-LABEL: func @shape_cast_3d1d
// CHECK-SAME: %[[A:.*]]: vector<1x3x2xf32>