[Mlir-commits] [mlir] 95f34e3 - [mlir][memref] Fix bug in verification of memref.collapse_shape

Stephan Herhut llvmlistbot at llvm.org
Mon Nov 29 06:47:54 PST 2021


Author: Stephan Herhut
Date: 2021-11-29T15:47:12+01:00
New Revision: 95f34e318c469806879a0cd1a6c5290901ed12df

URL: https://github.com/llvm/llvm-project/commit/95f34e318c469806879a0cd1a6c5290901ed12df
DIFF: https://github.com/llvm/llvm-project/commit/95f34e318c469806879a0cd1a6c5290901ed12df.diff

LOG: [mlir][memref] Fix bug in verification of memref.collapse_shape

The verifier computed an illegal type with negative dimension size when collapsing partially static memrefs.

Differential Revision: https://reviews.llvm.org/D114702

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/test/Dialect/MemRef/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index f806cf51a9d0a..ac36d1dc11bbb 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1209,6 +1209,12 @@ static void print(OpAsmPrinter &p, CollapseShapeOp op) {
 static bool isReshapableDimBand(unsigned dim, unsigned extent,
                                 ArrayRef<int64_t> sizes,
                                 ArrayRef<AffineExpr> strides) {
+  // Bands of extent one can be reshaped, as they are not reshaped at all.
+  if (extent == 1)
+    return true;
+  // Otherwise, the size of the first dimension needs to be known.
+  if (ShapedType::isDynamic(sizes[dim]))
+    return false;
   assert(sizes.size() == strides.size() && "mismatched ranks");
   // off by 1 indexing to avoid out of bounds
   //                       V
@@ -1217,7 +1223,7 @@ static bool isReshapableDimBand(unsigned dim, unsigned extent,
     // there is no relation between dynamic sizes and dynamic strides: we do not
     // have enough information to know whether a "-1" size corresponds to the
     // proper symbol in the AffineExpr of a stride.
-    if (ShapedType::isDynamic(sizes[dim + 1]))
+    if (ShapedType::isDynamic(sizes[idx + 1]))
       return false;
     // TODO: Refine this by passing the proper nDims and nSymbols so we can
     // simplify on the fly and catch more reshapable cases.

diff  --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index 9a5f3fee74ae0..620d93dd2ccce 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -5,6 +5,7 @@
 // CHECK-DAG: #[[$strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>
 // CHECK-DAG: #[[$strided2DOFF0:.*]] = affine_map<(d0, d1)[s0] -> (d0 * s0 + d1)>
 // CHECK-DAG: #[[$strided3DOFF0:.*]] = affine_map<(d0, d1, d2)[s0, s1] -> (d0 * s0 + d1 * s1 + d2)>
+// CHECK-DAG: #[[$strided2D42:.*]] = affine_map<(d0, d1) -> (d0 * 42 + d1)>
 
 // CHECK-LABEL: func @memref_reinterpret_cast
 func @memref_reinterpret_cast(%in: memref<?xf32>)
@@ -143,7 +144,8 @@ func @expand_collapse_shape_static(%arg0: memref<3x4x5xf32>,
 
 func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
          %arg1: memref<?x?x?xf32, offset : 0, strides : [?, ?, 1]>,
-         %arg2: memref<?x?x?xf32, offset : ?, strides : [?, ?, 1]>) {
+         %arg2: memref<?x?x?xf32, offset : ?, strides : [?, ?, 1]>,
+         %arg3: memref<?x42xf32, offset : 0, strides : [42, 1]>) {
   %0 = memref.collapse_shape %arg0 [[0, 1], [2]] :
     memref<?x?x?xf32> into memref<?x?xf32>
   %r0 = memref.expand_shape %0 [[0, 1], [2]] :
@@ -160,6 +162,12 @@ func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
   %r2 = memref.expand_shape %2 [[0, 1], [2]] :
     memref<?x?xf32, offset : ?, strides : [?, 1]> into
     memref<?x4x?xf32, offset : ?, strides : [?, ?, 1]>
+  %3 = memref.collapse_shape %arg3 [[0, 1]] :
+    memref<?x42xf32, offset : 0, strides : [42, 1]> into
+    memref<?xf32, offset : 0, strides : [1]>
+  %r3 = memref.expand_shape %3 [[0, 1]] :
+    memref<?xf32, offset : 0, strides : [1]> into
+    memref<?x42xf32, offset : 0, strides : [42, 1]>
   return
 }
 // CHECK-LABEL: func @expand_collapse_shape_dynamic
@@ -175,6 +183,10 @@ func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
 //  CHECK-SAME:     memref<?x?x?xf32, #[[$strided3D]]> into memref<?x?xf32, #[[$strided2D]]>
 //       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1], [2]]
 //  CHECK-SAME:     memref<?x?xf32, #[[$strided2D]]> into memref<?x4x?xf32, #[[$strided3D]]>
+//       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1]]
+//  CHECK-SAME:     memref<?x42xf32, #[[$strided2D42]]> into memref<?xf32>
+//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1]]
+//  CHECK-SAME:     memref<?xf32> into memref<?x42xf32, #[[$strided2D42]]>
 
 func @expand_collapse_shape_zero_dim(%arg0 : memref<1x1xf32>, %arg1 : memref<f32>)
     -> (memref<f32>, memref<1x1xf32>) {


        


More information about the Mlir-commits mailing list