[Mlir-commits] [mlir] Relax checks on vector.shape_cast (PR #136587)
James Newling
llvmlistbot at llvm.org
Tue Apr 22 16:30:06 PDT 2025
https://github.com/newling updated https://github.com/llvm/llvm-project/pull/136587
From 84dff32df484a5d001db6334f034ff343900cac2 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 21 Apr 2025 10:48:22 -0700
Subject: [PATCH 1/4] remove checks for collapse / expand
---
.../mlir/Dialect/Vector/IR/VectorOps.td | 21 +----
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 83 +++----------------
mlir/test/Dialect/Vector/canonicalize.mlir | 7 +-
mlir/test/Dialect/Vector/invalid.mlir | 13 ---
mlir/test/Dialect/Vector/ops.mlir | 16 ++++
5 files changed, 32 insertions(+), 108 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index d7518943229ea..64b5c58cfec24 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2244,18 +2244,8 @@ def Vector_ShapeCastOp :
Results<(outs AnyVectorOfAnyRank:$result)> {
let summary = "shape_cast casts between vector shapes";
let description = [{
- The shape_cast operation casts between an n-D source vector shape and
- a k-D result vector shape (the element type remains the same).
-
- If reducing rank (n > k), result dimension sizes must be a product
- of contiguous source dimension sizes.
- If expanding rank (n < k), source dimensions must factor into a
- contiguous sequence of destination dimension sizes.
- Each source dim is expanded (or contiguous sequence of source dims combined)
- in source dimension list order (i.e. 0 <= i < n), to produce a contiguous
- sequence of result dims (or a single result dim), in result dimension list
- order (i.e. 0 <= j < k). The product of all source dimension sizes and all
- result dimension sizes must match.
+ The shape_cast operation casts from a source vector to a target vector,
+ retaining the element type and number of elements.
It is currently assumed that this operation does not require moving data,
and that it will be folded away before lowering vector operations.
@@ -2268,12 +2258,7 @@ def Vector_ShapeCastOp :
Example:
```mlir
- // Example casting to a lower vector rank.
- %1 = vector.shape_cast %0 : vector<5x1x4x3xf32> to vector<20x3xf32>
-
- // Example casting to a higher vector rank.
- %3 = vector.shape_cast %2 : vector<10x12x8xf32> to vector<5x2x3x4x8xf32>
-
+ %1 = vector.shape_cast %0 : vector<4x3xf32> to vector<3x2x2xf32>
```
}];
let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 368259b38b153..2fcfab4770d4d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5505,48 +5505,18 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
setResultRanges(getResult(), argRanges.front());
}
-/// Returns true if each element of 'a' is equal to the product of a contiguous
-/// sequence of the elements of 'b'. Returns false otherwise.
-static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
- unsigned rankA = a.size();
- unsigned rankB = b.size();
- assert(rankA < rankB);
-
- auto isOne = [](int64_t v) { return v == 1; };
-
- // Special-case for n-D to 0-d shape cast. 'b' must be all ones to be shape
- // casted to a 0-d vector.
- if (rankA == 0 && llvm::all_of(b, isOne))
- return true;
-
- unsigned i = 0;
- unsigned j = 0;
- while (i < rankA && j < rankB) {
- int64_t dimA = a[i];
- int64_t dimB = 1;
- while (dimB < dimA && j < rankB)
- dimB *= b[j++];
- if (dimA != dimB)
- break;
- ++i;
-
- // Handle the case when trailing dimensions are of size 1.
- // Include them into the contiguous sequence.
- if (i < rankA && llvm::all_of(a.slice(i), isOne))
- i = rankA;
- if (j < rankB && llvm::all_of(b.slice(j), isOne))
- j = rankB;
- }
+LogicalResult ShapeCastOp::verify() {
+ auto sourceVectorType =
+ llvm::dyn_cast_or_null<VectorType>(getSource().getType());
+ auto resultVectorType =
+ llvm::dyn_cast_or_null<VectorType>(getResult().getType());
- return i == rankA && j == rankB;
-}
+ if (!sourceVectorType) return failure();
+ if (!resultVectorType) return failure();
-static LogicalResult verifyVectorShapeCast(Operation *op,
- VectorType sourceVectorType,
- VectorType resultVectorType) {
// Check that element type is the same.
if (sourceVectorType.getElementType() != resultVectorType.getElementType())
- return op->emitOpError("source/result vectors must have same element type");
+ return emitOpError("source/result vectors must have same element type");
auto sourceShape = sourceVectorType.getShape();
auto resultShape = resultVectorType.getShape();
@@ -5556,24 +5526,13 @@ static LogicalResult verifyVectorShapeCast(Operation *op,
int64_t resultDimProduct = std::accumulate(
resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
if (sourceDimProduct != resultDimProduct)
- return op->emitOpError("source/result number of elements must match");
-
- // Check that expanding/contracting rank cases.
- unsigned sourceRank = sourceVectorType.getRank();
- unsigned resultRank = resultVectorType.getRank();
- if (sourceRank < resultRank) {
- if (!isValidShapeCast(sourceShape, resultShape))
- return op->emitOpError("invalid shape cast");
- } else if (sourceRank > resultRank) {
- if (!isValidShapeCast(resultShape, sourceShape))
- return op->emitOpError("invalid shape cast");
- }
+ return emitOpError("source/result number of elements must match");
// Check that (non-)scalability is preserved
int64_t sourceNScalableDims = sourceVectorType.getNumScalableDims();
int64_t resultNScalableDims = resultVectorType.getNumScalableDims();
if (sourceNScalableDims != resultNScalableDims)
- return op->emitOpError("different number of scalable dims at source (")
+ return emitOpError("different number of scalable dims at source (")
<< sourceNScalableDims << ") and result (" << resultNScalableDims
<< ")";
sourceVectorType.getNumDynamicDims();
@@ -5581,19 +5540,6 @@ static LogicalResult verifyVectorShapeCast(Operation *op,
return success();
}
-LogicalResult ShapeCastOp::verify() {
- auto sourceVectorType =
- llvm::dyn_cast_or_null<VectorType>(getSource().getType());
- auto resultVectorType =
- llvm::dyn_cast_or_null<VectorType>(getResult().getType());
-
- // Check if source/result are of vector type.
- if (sourceVectorType && resultVectorType)
- return verifyVectorShapeCast(*this, sourceVectorType, resultVectorType);
-
- return success();
-}
-
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
// No-op shape cast.
@@ -5609,15 +5555,6 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
VectorType srcType = otherOp.getSource().getType();
if (resultType == srcType)
return otherOp.getSource();
- if (srcType.getRank() < resultType.getRank()) {
- if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
- return {};
- } else if (srcType.getRank() > resultType.getRank()) {
- if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
- return {};
- } else {
- return {};
- }
setOperand(otherOp.getSource());
return getResult();
}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 2d365ac2b4287..5bf8b9338c498 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -950,10 +950,9 @@ func.func @insert_no_fold_scalar_to_0d(%v: vector<f32>) -> vector<f32> {
// -----
-// CHECK-LABEL: dont_fold_expand_collapse
-// CHECK: %[[A:.*]] = vector.shape_cast %{{.*}} : vector<1x1x64xf32> to vector<1x1x8x8xf32>
-// CHECK: %[[B:.*]] = vector.shape_cast %{{.*}} : vector<1x1x8x8xf32> to vector<8x8xf32>
-// CHECK: return %[[B]] : vector<8x8xf32>
+// CHECK-LABEL: fold_expand_collapse
+// CHECK: %[[A:.*]] = vector.shape_cast %{{.*}} : vector<1x1x64xf32> to vector<8x8xf32>
+// CHECK: return %[[A]] : vector<8x8xf32>
func.func @dont_fold_expand_collapse(%arg0: vector<1x1x64xf32>) -> vector<8x8xf32> {
%0 = vector.shape_cast %arg0 : vector<1x1x64xf32> to vector<1x1x8x8xf32>
%1 = vector.shape_cast %0 : vector<1x1x8x8xf32> to vector<8x8xf32>
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 3a8320971bac4..14399d5a19394 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1145,19 +1145,6 @@ func.func @shape_cast_wrong_num_elements(%arg0 : vector<5x1x3x2xf32>) {
// -----
-func.func @shape_cast_invalid_rank_reduction(%arg0 : vector<5x1x3x2xf32>) {
- // expected-error at +1 {{invalid shape cast}}
- %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<2x15xf32>
-}
-
-// -----
-
-func.func @shape_cast_invalid_rank_expansion(%arg0 : vector<15x2xf32>) {
- // expected-error at +1 {{invalid shape cast}}
- %0 = vector.shape_cast %arg0 : vector<15x2xf32> to vector<5x2x3x1xf32>
-}
-
-// -----
func.func @shape_cast_scalability_flag_is_dropped(%arg0 : vector<15x[2]xf32>) {
// expected-error at +1 {{different number of scalable dims at source (1) and result (0)}}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 8ae1e9f9d0c64..bbd8f51445549 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -543,6 +543,22 @@ func.func @vector_print_on_scalar(%arg0: i64) {
return
}
+// CHECK-LABEL: @shape_cast_valid_rank_reduction
+func.func @shape_cast_valid_rank_reduction(%arg0 : vector<5x1x3x2xf32>) {
+ // CHECK: vector.shape_cast %{{.*}} : vector<5x1x3x2xf32> to vector<2x15xf32>
+ %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<2x15xf32>
+ return
+}
+
+
+// CHECK-LABEL: @shape_cast_valid_rank_expansion
+func.func @shape_cast_valid_rank_expansion(%arg0 : vector<15x2xf32>) {
+ // CHECK: vector.shape_cast %{{.*}} : vector<15x2xf32> to vector<5x2x3x1xf32>
+ %0 = vector.shape_cast %arg0 : vector<15x2xf32> to vector<5x2x3x1xf32>
+ return
+}
+
+
// CHECK-LABEL: @shape_cast
func.func @shape_cast(%arg0 : vector<5x1x3x2xf32>,
%arg1 : vector<8x1xf32>,
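For illustration (not part of the patch): with the expand/collapse checks removed, a cast like the one below now verifies so long as the element count matches, even though 2x15 cannot be formed as a product of contiguous source dimensions. This is exactly the case the deleted invalid.mlir tests used to reject, and the new ops.mlir tests now accept.

```mlir
// Accepted under the relaxed verifier: 5*1*3*2 = 2*15 = 30 elements.
%0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<2x15xf32>
```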
From b04d8692afd8e2a67f9438159997ffbd421ab2d9 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 21 Apr 2025 10:57:01 -0700
Subject: [PATCH 2/4] clang-format
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2fcfab4770d4d..fec71935c51e3 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5511,8 +5511,10 @@ LogicalResult ShapeCastOp::verify() {
auto resultVectorType =
llvm::dyn_cast_or_null<VectorType>(getResult().getType());
- if (!sourceVectorType) return failure();
- if (!resultVectorType) return failure();
+ if (!sourceVectorType)
+ return failure();
+ if (!resultVectorType)
+ return failure();
// Check that element type is the same.
if (sourceVectorType.getElementType() != resultVectorType.getElementType())
From 5753fbc67fe33ab71c9258c703fac35cbddfdb78 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 21 Apr 2025 11:25:12 -0700
Subject: [PATCH 3/4] fold shape_casts now
---
.../Dialect/Vector/CPU/X86/vector-transpose-lowering.mlir | 6 ++----
1 file changed, 2 insertions(+), 4 deletions(-)
diff --git a/mlir/test/Dialect/Vector/CPU/X86/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/CPU/X86/vector-transpose-lowering.mlir
index ae2b5393ca449..60ad54bf5c370 100644
--- a/mlir/test/Dialect/Vector/CPU/X86/vector-transpose-lowering.mlir
+++ b/mlir/test/Dialect/Vector/CPU/X86/vector-transpose-lowering.mlir
@@ -26,8 +26,7 @@ func.func @transpose4x8xf32(%arg0: vector<4x8xf32>) -> vector<8x4xf32> {
// CHECK-NEXT: vector.insert {{.*}}[1]
// CHECK-NEXT: vector.insert {{.*}}[2]
// CHECK-NEXT: vector.insert {{.*}}[3]
- // CHECK-NEXT: vector.shape_cast {{.*}} vector<4x8xf32> to vector<32xf32>
- // CHECK-NEXT: vector.shape_cast {{.*}} vector<32xf32> to vector<8x4xf32>
+ // CHECK-NEXT: vector.shape_cast {{.*}} vector<4x8xf32> to vector<8x4xf32>
%0 = vector.transpose %arg0, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
return %0 : vector<8x4xf32>
}
@@ -54,8 +53,7 @@ func.func @transpose021_1x4x8xf32(%arg0: vector<1x4x8xf32>) -> vector<1x8x4xf32>
// CHECK-NEXT: vector.insert {{.*}}[1]
// CHECK-NEXT: vector.insert {{.*}}[2]
// CHECK-NEXT: vector.insert {{.*}}[3]
- // CHECK-NEXT: vector.shape_cast {{.*}} vector<4x8xf32> to vector<32xf32>
- // CHECK-NEXT: vector.shape_cast {{.*}} vector<32xf32> to vector<1x8x4xf32>
+ // CHECK-NEXT: vector.shape_cast {{.*}} vector<4x8xf32> to vector<1x8x4xf32>
%0 = vector.transpose %arg0, [0, 2, 1] : vector<1x4x8xf32> to vector<1x8x4xf32>
return %0 : vector<1x8x4xf32>
}
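For illustration (not part of the patch): the test updates above rely on the relaxed fold, under which a pair of shape_casts through any intermediate shape now folds to a single cast:

```mlir
// Before folding:
%0 = vector.shape_cast %arg0 : vector<4x8xf32> to vector<32xf32>
%1 = vector.shape_cast %0 : vector<32xf32> to vector<8x4xf32>
// After folding, a single cast remains:
%1 = vector.shape_cast %arg0 : vector<4x8xf32> to vector<8x4xf32>
```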
From 570b023f2029410258d93b497914b60026488874 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 22 Apr 2025 16:29:55 -0700
Subject: [PATCH 4/4] improve draft version, add canonicalizer
---
.../mlir/Dialect/Vector/IR/VectorOps.td | 7 +-
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 67 ++++++++++---------
mlir/test/Dialect/Vector/canonicalize.mlir | 12 ++++
mlir/test/Dialect/Vector/invalid.mlir | 6 +-
mlir/test/Dialect/Vector/ops.mlir | 2 -
5 files changed, 53 insertions(+), 41 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 64b5c58cfec24..ddfd78f6c2dc0 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2244,8 +2244,8 @@ def Vector_ShapeCastOp :
Results<(outs AnyVectorOfAnyRank:$result)> {
let summary = "shape_cast casts between vector shapes";
let description = [{
- The shape_cast operation casts from a source vector to a target vector,
- retaining the element type and number of elements.
+ Casts to a vector with the same number of elements, element type, and
+ number of scalable dimensions.
It is currently assumed that this operation does not require moving data,
and that it will be folded away before lowering vector operations.
@@ -2255,10 +2255,11 @@ def Vector_ShapeCastOp :
2-D MLIR vector to a 1-D flattened LLVM vector.shape_cast lowering to LLVM
is supported in that particular case, for now.
- Example:
+ Examples:
```mlir
%1 = vector.shape_cast %0 : vector<4x3xf32> to vector<3x2x2xf32>
+ %3 = vector.shape_cast %2 : vector<[2]x3x[4]xi8> to vector<3x[1]x[8]xi8>
```
}];
let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index fec71935c51e3..732a5d21a4b87 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5506,54 +5506,46 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
}
LogicalResult ShapeCastOp::verify() {
- auto sourceVectorType =
- llvm::dyn_cast_or_null<VectorType>(getSource().getType());
- auto resultVectorType =
- llvm::dyn_cast_or_null<VectorType>(getResult().getType());
- if (!sourceVectorType)
- return failure();
- if (!resultVectorType)
- return failure();
+ VectorType sourceType = getSourceVectorType();
+ VectorType resultType = getResultVectorType();
- // Check that element type is the same.
- if (sourceVectorType.getElementType() != resultVectorType.getElementType())
- return emitOpError("source/result vectors must have same element type");
- auto sourceShape = sourceVectorType.getShape();
- auto resultShape = resultVectorType.getShape();
+ // Check that element type is preserved
+ if (sourceType.getElementType() != resultType.getElementType())
+ return emitOpError("has different source and result element types");
- // Check that product of source dim sizes matches product of result dim sizes.
- int64_t sourceDimProduct = std::accumulate(
- sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
- int64_t resultDimProduct = std::accumulate(
- resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
- if (sourceDimProduct != resultDimProduct)
- return emitOpError("source/result number of elements must match");
+ // Check that number of elements is preserved
+ int64_t sourceNElms = sourceType.getNumElements();
+ int64_t resultNElms = resultType.getNumElements();
+ if (sourceNElms != resultNElms) {
+ return emitOpError() << "has different number of elements at source ("
+ << sourceNElms << ") and result (" << resultNElms
+ << ")";
+ }
// Check that (non-)scalability is preserved
- int64_t sourceNScalableDims = sourceVectorType.getNumScalableDims();
- int64_t resultNScalableDims = resultVectorType.getNumScalableDims();
+ int64_t sourceNScalableDims = sourceType.getNumScalableDims();
+ int64_t resultNScalableDims = resultType.getNumScalableDims();
if (sourceNScalableDims != resultNScalableDims)
- return emitOpError("different number of scalable dims at source (")
- << sourceNScalableDims << ") and result (" << resultNScalableDims
- << ")";
- sourceVectorType.getNumDynamicDims();
+ return emitOpError() << "has different number of scalable dims at source ("
+ << sourceNScalableDims << ") and result ("
+ << resultNScalableDims << ")";
return success();
}
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
+ VectorType resultType = getType();
+
// No-op shape cast.
- if (getSource().getType() == getType())
+ if (getSource().getType() == resultType)
return getSource();
- VectorType resultType = getType();
-
- // Canceling shape casts.
+ // Y = shape_cast(shape_cast(X))
+ // -> X, if X and Y have same type
+ // -> shape_cast(X) otherwise.
if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
-
- // Only allows valid transitive folding (expand/collapse dimensions).
VectorType srcType = otherOp.getSource().getType();
if (resultType == srcType)
return otherOp.getSource();
@@ -5561,10 +5553,19 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
return getResult();
}
- // Cancelling broadcast and shape cast ops.
+ // Y = shape_cast(broadcast(X))
+ // -> X, if X and Y have same type, else
+ // -> shape_cast(X) if X is a vector and the broadcast preserves
+ // number of elements.
if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
if (bcastOp.getSourceType() == resultType)
return bcastOp.getSource();
+ if (auto bcastSrcType = dyn_cast<VectorType>(bcastOp.getSourceType())) {
+ if (bcastSrcType.getNumElements() == resultType.getNumElements()) {
+ setOperand(bcastOp.getSource());
+ return getResult();
+ }
+ }
}
// shape_cast(constant) -> constant
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 5bf8b9338c498..04d8e613d4156 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -972,6 +972,18 @@ func.func @fold_broadcast_shapecast(%arg0: vector<4xf32>) -> vector<4xf32> {
// -----
+// CHECK-LABEL: func @fold_count_preserving_broadcast_shapecast
+// CHECK-SAME: (%[[V:.+]]: vector<4xf32>)
+// CHECK: %[[SHAPECAST:.*]] = vector.shape_cast %[[V]] : vector<4xf32> to vector<2x2xf32>
+// CHECK: return %[[SHAPECAST]] : vector<2x2xf32>
+func.func @fold_count_preserving_broadcast_shapecast(%arg0: vector<4xf32>) -> vector<2x2xf32> {
+ %0 = vector.broadcast %arg0 : vector<4xf32> to vector<1x1x4xf32>
+ %1 = vector.shape_cast %0 : vector<1x1x4xf32> to vector<2x2xf32>
+ return %1 : vector<2x2xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @canonicalize_broadcast_shapecast_scalar
// CHECK: vector.broadcast
// CHECK-NOT: vector.shape_cast
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 14399d5a19394..fa4837126accb 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1131,21 +1131,21 @@ func.func @cannot_print_string_with_source_set(%vec: vector<[4]xf32>) {
// -----
+
func.func @shape_cast_wrong_element_type(%arg0 : vector<5x1x3x2xf32>) {
- // expected-error at +1 {{op source/result vectors must have same element type}}
+ // expected-error at +1 {{'vector.shape_cast' op has different source and result element types}}
%0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<15x2xi32>
}
// -----
func.func @shape_cast_wrong_num_elements(%arg0 : vector<5x1x3x2xf32>) {
- // expected-error at +1 {{op source/result number of elements must match}}
+ // expected-error at +1 {{'vector.shape_cast' op has different number of elements at source (30) and result (20)}}
%0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<10x2xf32>
}
// -----
-
func.func @shape_cast_scalability_flag_is_dropped(%arg0 : vector<15x[2]xf32>) {
// expected-error at +1 {{different number of scalable dims at source (1) and result (0)}}
%0 = vector.shape_cast %arg0 : vector<15x[2]xf32> to vector<30xf32>
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index bbd8f51445549..36f7db8c39d4d 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -550,7 +550,6 @@ func.func @shape_cast_valid_rank_reduction(%arg0 : vector<5x1x3x2xf32>) {
return
}
-
// CHECK-LABEL: @shape_cast_valid_rank_expansion
func.func @shape_cast_valid_rank_expansion(%arg0 : vector<15x2xf32>) {
// CHECK: vector.shape_cast %{{.*}} : vector<15x2xf32> to vector<5x2x3x1xf32>
@@ -558,7 +557,6 @@ func.func @shape_cast_valid_rank_expansion(%arg0 : vector<15x2xf32>) {
return
}
-
// CHECK-LABEL: @shape_cast
func.func @shape_cast(%arg0 : vector<5x1x3x2xf32>,
%arg1 : vector<8x1xf32>,
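For illustration (not part of the patch): the new broadcast fold in patch 4 rewrites a shape_cast of an element-count-preserving broadcast into a single shape_cast of the broadcast's source, mirroring the fold_count_preserving_broadcast_shapecast test above:

```mlir
// 4 elements in, 4 elements out: the broadcast only adds unit dims.
%0 = vector.broadcast %arg0 : vector<4xf32> to vector<1x1x4xf32>
%1 = vector.shape_cast %0 : vector<1x1x4xf32> to vector<2x2xf32>
// Folds to:
%1 = vector.shape_cast %arg0 : vector<4xf32> to vector<2x2xf32>
```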