[Mlir-commits] [mlir] b9745ad - [mlir][tensor/memref] Disallow Collapse/ExpandShapeOps that do not reduce/increase the rank
Matthias Springer
llvmlistbot at llvm.org
Wed Nov 23 00:31:32 PST 2022
Author: Matthias Springer
Date: 2022-11-23T09:19:35+01:00
New Revision: b9745ad81273a0d866873fdaee2843ca87e15f18
URL: https://github.com/llvm/llvm-project/commit/b9745ad81273a0d866873fdaee2843ca87e15f18
DIFF: https://github.com/llvm/llvm-project/commit/b9745ad81273a0d866873fdaee2843ca87e15f18.diff
LOG: [mlir][tensor/memref] Disallow Collapse/ExpandShapeOps that do not reduce/increase the rank
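CollapseShapeOps that do not reduce the rank (i.e., keep it unchanged or increase it) and ExpandShapeOps that do not increase the rank (i.e., keep it unchanged or reduce it) are invalid; the verifiers now reject them.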
Differential Revision: https://reviews.llvm.org/D138498
Added:
Modified:
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/MemRef/invalid.mlir
mlir/test/Dialect/Tensor/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index e53879b618cc7..2bbb57e6e0d28 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2219,6 +2219,10 @@ LogicalResult ExpandShapeOp::verify() {
MemRefType srcType = getSrcType();
MemRefType resultType = getResultType();
+ if (srcType.getRank() >= resultType.getRank())
+ return emitOpError("expected rank expansion, but found source rank ")
+ << srcType.getRank() << " >= result rank " << resultType.getRank();
+
// Verify result shape.
if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
resultType.getShape(),
@@ -2370,6 +2374,10 @@ LogicalResult CollapseShapeOp::verify() {
MemRefType srcType = getSrcType();
MemRefType resultType = getResultType();
+ if (srcType.getRank() <= resultType.getRank())
+ return emitOpError("expected rank reduction, but found source rank ")
+ << srcType.getRank() << " <= result rank " << resultType.getRank();
+
// Verify result shape.
if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),
srcType.getShape(), getReassociationIndices(),
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 826c69e23f048..36e3aadbc5982 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1398,10 +1398,22 @@ static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
}
LogicalResult ExpandShapeOp::verify() {
+ auto srcType = getSrcType();
+ auto resultType = getResultType();
+ if (srcType.getRank() >= resultType.getRank())
+ return emitOpError("expected rank expansion, but found source rank ")
+ << srcType.getRank() << " >= result rank " << resultType.getRank();
+
return verifyTensorReshapeOp(*this, getResultType(), getSrcType());
}
LogicalResult CollapseShapeOp::verify() {
+ auto srcType = getSrcType();
+ auto resultType = getResultType();
+ if (srcType.getRank() <= resultType.getRank())
+ return emitOpError("expected rank reduction, but found source rank ")
+ << srcType.getRank() << " <= result rank " << resultType.getRank();
+
return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
}
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 344f22cb7d2eb..ccbf929dbd202 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -392,9 +392,9 @@ func.func @copy_different_eltype(%arg0: memref<2xf32>, %arg1: memref<2xf16>) {
// -----
-func.func @expand_shape(%arg0: memref<f32>) {
- // expected-error @+1 {{invalid number of reassociation groups: found 1, expected 0}}
- %0 = memref.expand_shape %arg0 [[0]] : memref<f32> into memref<f32>
+func.func @expand_shape(%arg0: memref<?x?xf32>) {
+ // expected-error @+1 {{invalid number of reassociation groups: found 1, expected 2}}
+ %0 = memref.expand_shape %arg0 [[0, 1]] : memref<?x?xf32> into memref<?x5x?xf32>
return
}
@@ -408,16 +408,30 @@ func.func @expand_shape(%arg0: memref<f32>) {
// -----
-func.func @collapse_shape_to_higher_rank(%arg0: memref<f32>) {
- // expected-error @+1 {{op reassociation index 0 is out of bounds}}
- %0 = memref.collapse_shape %arg0 [[0]] : memref<f32> into memref<1xf32>
+func.func @collapse_shape_out_of_bounds(%arg0: memref<?x?xf32>) {
+ // expected-error @+1 {{op reassociation index 2 is out of bounds}}
+ %0 = memref.collapse_shape %arg0 [[0, 1, 2]] : memref<?x?xf32> into memref<?xf32>
+}
+
+// -----
+
+func.func @expand_shape_invalid_ranks(%arg0: memref<?x?xf32>) {
+ // expected-error @+1 {{op expected rank expansion, but found source rank 2 >= result rank 2}}
+ %0 = memref.expand_shape %arg0 [[0], [1]] : memref<?x?xf32> into memref<?x?xf32>
+}
+
+// -----
+
+func.func @collapse_shape_invalid_ranks(%arg0: memref<?x?xf32>) {
+ // expected-error @+1 {{op expected rank reduction, but found source rank 2 <= result rank 2}}
+ %0 = memref.collapse_shape %arg0 [[0], [1]] : memref<?x?xf32> into memref<?x?xf32>
}
// -----
-func.func @expand_shape_to_smaller_rank(%arg0: memref<1xf32>) {
- // expected-error @+1 {{op reassociation index 0 is out of bounds}}
- %0 = memref.expand_shape %arg0 [[0]] : memref<1xf32> into memref<f32>
+func.func @expand_shape_out_of_bounds(%arg0: memref<?xf32>) {
+ // expected-error @+1 {{op reassociation index 2 is out of bounds}}
+ %0 = memref.expand_shape %arg0 [[0, 1, 2]] : memref<?xf32> into memref<4x?xf32>
}
// -----
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index b085053296ca4..4a25cd494a7e2 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -294,6 +294,20 @@ func.func @illegal_collapsing_reshape_mixed_tensor_2(%arg0 : tensor<?x4x5xf32>)
// -----
+func.func @expand_shape_invalid_ranks(%arg0: tensor<?x?xf32>) {
+ // expected-error @+1 {{op expected rank expansion, but found source rank 2 >= result rank 2}}
+ %0 = tensor.expand_shape %arg0 [[0], [1]] : tensor<?x?xf32> into tensor<?x?xf32>
+}
+
+// -----
+
+func.func @collapse_shape_invalid_ranks(%arg0: tensor<?x?xf32>) {
+ // expected-error @+1 {{op expected rank reduction, but found source rank 2 <= result rank 2}}
+ %0 = tensor.collapse_shape %arg0 [[0], [1]] : tensor<?x?xf32> into tensor<?x?xf32>
+}
+
+// -----
+
func.func @rank(%0: f32) {
// expected-error @+1 {{'tensor.rank' op operand #0 must be tensor of any type values}}
"tensor.rank"(%0): (f32)->index
More information about the Mlir-commits
mailing list