[Mlir-commits] [mlir] b9745ad - [mlir][tensor/memref] Disallow Collapse/ExpandShapeOps that do not reduce/increase the rank

Matthias Springer llvmlistbot at llvm.org
Wed Nov 23 00:31:32 PST 2022


Author: Matthias Springer
Date: 2022-11-23T09:19:35+01:00
New Revision: b9745ad81273a0d866873fdaee2843ca87e15f18

URL: https://github.com/llvm/llvm-project/commit/b9745ad81273a0d866873fdaee2843ca87e15f18
DIFF: https://github.com/llvm/llvm-project/commit/b9745ad81273a0d866873fdaee2843ca87e15f18.diff

LOG: [mlir][tensor/memref] Disallow Collapse/ExpandShapeOps that do not reduce/increase the rank

A CollapseShapeOp that does not reduce the rank, or an ExpandShapeOp that does not increase the rank (i.e., the rank stays the same or changes in the wrong direction), is invalid.

Differential Revision: https://reviews.llvm.org/D138498

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/MemRef/invalid.mlir
    mlir/test/Dialect/Tensor/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index e53879b618cc7..2bbb57e6e0d28 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2219,6 +2219,10 @@ LogicalResult ExpandShapeOp::verify() {
   MemRefType srcType = getSrcType();
   MemRefType resultType = getResultType();
 
+  if (srcType.getRank() >= resultType.getRank())
+    return emitOpError("expected rank expansion, but found source rank ")
+           << srcType.getRank() << " >= result rank " << resultType.getRank();
+
   // Verify result shape.
   if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
                                   resultType.getShape(),
@@ -2370,6 +2374,10 @@ LogicalResult CollapseShapeOp::verify() {
   MemRefType srcType = getSrcType();
   MemRefType resultType = getResultType();
 
+  if (srcType.getRank() <= resultType.getRank())
+    return emitOpError("expected rank reduction, but found source rank ")
+           << srcType.getRank() << " <= result rank " << resultType.getRank();
+
   // Verify result shape.
   if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),
                                   srcType.getShape(), getReassociationIndices(),

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 826c69e23f048..36e3aadbc5982 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1398,10 +1398,22 @@ static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
 }
 
 LogicalResult ExpandShapeOp::verify() {
+  auto srcType = getSrcType();
+  auto resultType = getResultType();
+  if (srcType.getRank() >= resultType.getRank())
+    return emitOpError("expected rank expansion, but found source rank ")
+           << srcType.getRank() << " >= result rank " << resultType.getRank();
+
   return verifyTensorReshapeOp(*this, getResultType(), getSrcType());
 }
 
 LogicalResult CollapseShapeOp::verify() {
+  auto srcType = getSrcType();
+  auto resultType = getResultType();
+  if (srcType.getRank() <= resultType.getRank())
+    return emitOpError("expected rank reduction, but found source rank ")
+           << srcType.getRank() << " <= result rank " << resultType.getRank();
+
   return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
 }
 

diff  --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 344f22cb7d2eb..ccbf929dbd202 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -392,9 +392,9 @@ func.func @copy_different_eltype(%arg0: memref<2xf32>, %arg1: memref<2xf16>) {
 
 // -----
 
-func.func @expand_shape(%arg0: memref<f32>) {
-  // expected-error @+1 {{invalid number of reassociation groups: found 1, expected 0}}
-  %0 = memref.expand_shape %arg0 [[0]] : memref<f32> into memref<f32>
+func.func @expand_shape(%arg0: memref<?x?xf32>) {
+  // expected-error @+1 {{invalid number of reassociation groups: found 1, expected 2}}
+  %0 = memref.expand_shape %arg0 [[0, 1]] : memref<?x?xf32> into memref<?x5x?xf32>
   return
 }
 
@@ -408,16 +408,30 @@ func.func @expand_shape(%arg0: memref<f32>) {
 
 // -----
 
-func.func @collapse_shape_to_higher_rank(%arg0: memref<f32>) {
-  // expected-error @+1 {{op reassociation index 0 is out of bounds}}
-  %0 = memref.collapse_shape %arg0 [[0]] : memref<f32> into memref<1xf32>
+func.func @collapse_shape_out_of_bounds(%arg0: memref<?x?xf32>) {
+  // expected-error @+1 {{op reassociation index 2 is out of bounds}}
+  %0 = memref.collapse_shape %arg0 [[0, 1, 2]] : memref<?x?xf32> into memref<?xf32>
+}
+
+// -----
+
+func.func @expand_shape_invalid_ranks(%arg0: memref<?x?xf32>) {
+  // expected-error @+1 {{op expected rank expansion, but found source rank 2 >= result rank 2}}
+  %0 = memref.expand_shape %arg0 [[0], [1]] : memref<?x?xf32> into memref<?x?xf32>
+}
+
+// -----
+
+func.func @collapse_shape_invalid_ranks(%arg0: memref<?x?xf32>) {
+  // expected-error @+1 {{op expected rank reduction, but found source rank 2 <= result rank 2}}
+  %0 = memref.collapse_shape %arg0 [[0], [1]] : memref<?x?xf32> into memref<?x?xf32>
 }
 
 // -----
 
-func.func @expand_shape_to_smaller_rank(%arg0: memref<1xf32>) {
-  // expected-error @+1 {{op reassociation index 0 is out of bounds}}
-  %0 = memref.expand_shape %arg0 [[0]] : memref<1xf32> into memref<f32>
+func.func @expand_shape_out_of_bounds(%arg0: memref<?xf32>) {
+  // expected-error @+1 {{op reassociation index 2 is out of bounds}}
+  %0 = memref.expand_shape %arg0 [[0, 1, 2]] : memref<?xf32> into memref<4x?xf32>
 }
 
 // -----

diff  --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index b085053296ca4..4a25cd494a7e2 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -294,6 +294,20 @@ func.func @illegal_collapsing_reshape_mixed_tensor_2(%arg0 : tensor<?x4x5xf32>)
 
 // -----
 
+func.func @expand_shape_invalid_ranks(%arg0: tensor<?x?xf32>) {
+  // expected-error @+1 {{op expected rank expansion, but found source rank 2 >= result rank 2}}
+  %0 = tensor.expand_shape %arg0 [[0], [1]] : tensor<?x?xf32> into tensor<?x?xf32>
+}
+
+// -----
+
+func.func @collapse_shape_invalid_ranks(%arg0: tensor<?x?xf32>) {
+  // expected-error @+1 {{op expected rank reduction, but found source rank 2 <= result rank 2}}
+  %0 = tensor.collapse_shape %arg0 [[0], [1]] : tensor<?x?xf32> into tensor<?x?xf32>
+}
+
+// -----
+
 func.func @rank(%0: f32) {
   // expected-error@+1 {{'tensor.rank' op operand #0 must be tensor of any type values}}
   "tensor.rank"(%0): (f32)->index


        


More information about the Mlir-commits mailing list