[Mlir-commits] [mlir] b20e150 - [mlir] Use static shape knowledge when lowering memref.reshape

Benjamin Kramer llvmlistbot at llvm.org
Tue May 11 09:21:15 PDT 2021


Author: Benjamin Kramer
Date: 2021-05-11T18:21:09+02:00
New Revision: b20e150c9be16f69c73f4cd2986053d13d0f376a

URL: https://github.com/llvm/llvm-project/commit/b20e150c9be16f69c73f4cd2986053d13d0f376a
DIFF: https://github.com/llvm/llvm-project/commit/b20e150c9be16f69c73f4cd2986053d13d0f376a.diff

LOG: [mlir] Use static shape knowledge when lowering memref.reshape

This is actually necessary for correctness, as memref.reinterpret_cast
doesn't verify that the output shape matches the static sizes.

Differential Revision: https://reviews.llvm.org/D102232

Added: 
    

Modified: 
    mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
    mlir/test/Dialect/Standard/expand-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
index fd1b36907ae1e..d4810ac40f31b 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
@@ -91,11 +91,19 @@ struct MemRefReshapeOpConverter : public OpRewritePattern<memref::ReshapeOp> {
     Location loc = op.getLoc();
     Value stride = rewriter.create<ConstantIndexOp>(loc, 1);
     for (int i = rank - 1; i >= 0; --i) {
-      Value index = rewriter.create<ConstantIndexOp>(loc, i);
-      Value size = rewriter.create<memref::LoadOp>(loc, op.shape(), index);
-      if (!size.getType().isa<IndexType>())
-        size = rewriter.create<IndexCastOp>(loc, size, rewriter.getIndexType());
-      sizes[i] = size;
+      Value size;
+      // Load dynamic sizes from the shape input, use constants for static dims.
+      if (op.getType().isDynamicDim(i)) {
+        Value index = rewriter.create<ConstantIndexOp>(loc, i);
+        size = rewriter.create<memref::LoadOp>(loc, op.shape(), index);
+        if (!size.getType().isa<IndexType>())
+          size =
+              rewriter.create<IndexCastOp>(loc, size, rewriter.getIndexType());
+        sizes[i] = size;
+      } else {
+        sizes[i] = rewriter.getIndexAttr(op.getType().getDimSize(i));
+        size = rewriter.create<ConstantOp>(loc, sizes[i].get<Attribute>());
+      }
       strides[i] = stride;
       if (i > 0)
         stride = rewriter.create<MulIOp>(loc, stride, size);

diff  --git a/mlir/test/Dialect/Standard/expand-ops.mlir b/mlir/test/Dialect/Standard/expand-ops.mlir
index f91f9b8ff373b..e133df36684fb 100644
--- a/mlir/test/Dialect/Standard/expand-ops.mlir
+++ b/mlir/test/Dialect/Standard/expand-ops.mlir
@@ -84,19 +84,17 @@ func @floordivi(%arg0: i32, %arg1: i32) -> (i32) {
 
 // CHECK-LABEL: func @memref_reshape(
 func @memref_reshape(%input: memref<*xf32>,
-                     %shape: memref<3xi32>) -> memref<?x?x?xf32> {
+                     %shape: memref<3xi32>) -> memref<?x?x8xf32> {
   %result = memref.reshape %input(%shape)
-               : (memref<*xf32>, memref<3xi32>) -> memref<?x?x?xf32>
-  return %result : memref<?x?x?xf32>
+               : (memref<*xf32>, memref<3xi32>) -> memref<?x?x8xf32>
+  return %result : memref<?x?x8xf32>
 }
 // CHECK-SAME: [[SRC:%.*]]: memref<*xf32>,
-// CHECK-SAME: [[SHAPE:%.*]]: memref<3xi32>) -> memref<?x?x?xf32> {
+// CHECK-SAME: [[SHAPE:%.*]]: memref<3xi32>) -> memref<?x?x8xf32> {
 
 // CHECK: [[C1:%.*]] = constant 1 : index
-// CHECK: [[C2:%.*]] = constant 2 : index
-// CHECK: [[DIM_2:%.*]] = memref.load [[SHAPE]]{{\[}}[[C2]]] : memref<3xi32>
-// CHECK: [[SIZE_2:%.*]] = index_cast [[DIM_2]] : i32 to index
-// CHECK: [[STRIDE_1:%.*]] = muli [[C1]], [[SIZE_2]] : index
+// CHECK: [[C8:%.*]] = constant 8 : index
+// CHECK: [[STRIDE_1:%.*]] = muli [[C1]], [[C8]] : index
 
 // CHECK: [[C1_:%.*]] = constant 1 : index
 // CHECK: [[DIM_1:%.*]] = memref.load [[SHAPE]]{{\[}}[[C1_]]] : memref<3xi32>
@@ -108,6 +106,6 @@ func @memref_reshape(%input: memref<*xf32>,
 // CHECK: [[SIZE_0:%.*]] = index_cast [[DIM_0]] : i32 to index
 
 // CHECK: [[RESULT:%.*]] = memref.reinterpret_cast [[SRC]]
-// CHECK-SAME: to offset: [0], sizes: {{\[}}[[SIZE_0]], [[SIZE_1]], [[SIZE_2]]],
+// CHECK-SAME: to offset: [0], sizes: {{\[}}[[SIZE_0]], [[SIZE_1]], 8],
 // CHECK-SAME: strides: {{\[}}[[STRIDE_0]], [[STRIDE_1]], [[C1]]]
-// CHECK-SAME: : memref<*xf32> to memref<?x?x?xf32>
+// CHECK-SAME: : memref<*xf32> to memref<?x?x8xf32>


        


More information about the Mlir-commits mailing list