[Mlir-commits] [mlir] 5fee179 - [mlir] translate memref.reshape with static shapes but dynamic dims

Ashay Rane llvmlistbot at llvm.org
Thu Jun 2 10:01:03 PDT 2022


Author: Ashay Rane
Date: 2022-06-02T10:00:58-07:00
New Revision: 5fee1799f4d8da59c251e2d04172fc2f387cbe54

URL: https://github.com/llvm/llvm-project/commit/5fee1799f4d8da59c251e2d04172fc2f387cbe54
DIFF: https://github.com/llvm/llvm-project/commit/5fee1799f4d8da59c251e2d04172fc2f387cbe54.diff

LOG: [mlir] translate memref.reshape with static shapes but dynamic dims

Prior to this patch, the lowering of memref.reshape operations to the
LLVM dialect failed if the shape argument had a static shape with
dynamic dimensions.  This patch adds the necessary support so that when
the shape argument has dynamic values, the lowering probes the dimension
at runtime to set the size in the `MemRefDescriptor` type.  This patch
also computes the stride for dynamic dimensions by deriving it from the
sizes of the inner dimensions.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D126604

Added: 
    

Modified: 
    mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
    mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 3dfd591b0fca8..97b631afc33cd 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1128,14 +1128,43 @@ struct MemRefReshapeOpLowering
       if (!isStaticStrideOrOffset(offset))
         return rewriter.notifyMatchFailure(reshapeOp,
                                            "dynamic offset is unsupported");
-      if (!llvm::all_of(strides, isStaticStrideOrOffset))
-        return rewriter.notifyMatchFailure(reshapeOp,
-                                           "dynamic strides are unsupported");
 
       desc.setConstantOffset(rewriter, loc, offset);
-      for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
-        desc.setConstantSize(rewriter, loc, i, targetMemRefType.getDimSize(i));
-        desc.setConstantStride(rewriter, loc, i, strides[i]);
+
+      assert(targetMemRefType.getLayout().isIdentity() &&
+             "Identity layout map is a precondition of a valid reshape op");
+
+      Value stride = nullptr;
+      int64_t targetRank = targetMemRefType.getRank();
+      for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
+        if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
+          // If the stride for this dimension is static, use it directly.
+          // Otherwise, reuse the running product of the inner dimension sizes.
+          stride = createIndexConstant(rewriter, loc, strides[i]);
+        } else if (!stride) {
+          // `stride` is null only in the first iteration of the loop.  However,
+          // since the target memref has an identity layout, we can safely set
+          // the innermost stride to 1.
+          stride = createIndexConstant(rewriter, loc, 1);
+        }
+
+        Value dimSize;
+        int64_t size = targetMemRefType.getDimSize(i);
+        // If the size of this dimension is dynamic, then load it at runtime
+        // from the shape operand.
+        if (!ShapedType::isDynamic(size)) {
+          dimSize = createIndexConstant(rewriter, loc, size);
+        } else {
+          Value shapeOp = reshapeOp.shape();
+          Value index = createIndexConstant(rewriter, loc, i);
+          dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index);
+        }
+
+        desc.setSize(rewriter, loc, i, dimSize);
+        desc.setStride(rewriter, loc, i, stride);
+
+        // Prepare the stride value for the next dimension.
+        stride = rewriter.create<LLVM::MulOp>(loc, stride, dimSize);
       }
 
       *descriptor = desc;

diff  --git a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
index b8f3717682a05..1296f8e4881cd 100644
--- a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
@@ -232,23 +232,77 @@ func.func @memref.reshape(%arg0: memref<4x5x6xf32>) -> memref<2x6x20xf32> {
   // CHECK: %[[elem1:.*]] = llvm.extractvalue %[[cast0]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
   // CHECK: %[[insert0:.*]] = llvm.insertvalue %[[elem0]], %[[undef]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
   // CHECK: %[[insert1:.*]] = llvm.insertvalue %[[elem1]], %[[insert0:.*]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+
   // CHECK: %[[zero:.*]] = llvm.mlir.constant(0 : index) : i64
   // CHECK: %[[insert2:.*]] = llvm.insertvalue %[[zero]], %[[insert1]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-  // CHECK: %[[two:.*]] = llvm.mlir.constant(2 : index) : i64
-  // CHECK: %[[insert3:.*]] = llvm.insertvalue %[[two]], %[[insert2]][3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-  // CHECK: %[[hundred_and_twenty:.*]] = llvm.mlir.constant(120 : index) : i64
-  // CHECK: %[[insert4:.*]] = llvm.insertvalue %[[hundred_and_twenty]], %[[insert3]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-  // CHECK: %[[six:.*]] = llvm.mlir.constant(6 : index) : i64
-  // CHECK: %[[insert5:.*]] = llvm.insertvalue %[[six]], %[[insert4]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+
+  // CHECK: %[[one:.*]] = llvm.mlir.constant(1 : index) : i64
   // CHECK: %[[twenty0:.*]] = llvm.mlir.constant(20 : index) : i64
-  // CHECK: %[[insert6:.*]] = llvm.insertvalue %[[twenty0]], %[[insert5]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+  // CHECK: %[[insert3:.*]] = llvm.insertvalue %[[twenty0]], %[[insert2]][3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+  // CHECK: %[[insert4:.*]] = llvm.insertvalue %[[one]], %[[insert3]][4, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+
   // CHECK: %[[twenty1:.*]] = llvm.mlir.constant(20 : index) : i64
-  // CHECK: %[[insert7:.*]] = llvm.insertvalue %[[twenty1]], %[[insert6]][3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-  // CHECK: %[[one:.*]] = llvm.mlir.constant(1 : index) : i64
-  // CHECK: %[[insert8:.*]] = llvm.insertvalue %[[one]], %[[insert7]][4, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+  // CHECK: %[[six:.*]] = llvm.mlir.constant(6 : index) : i64
+  // CHECK: %[[insert5:.*]] = llvm.insertvalue %[[six]], %[[insert4]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+  // CHECK: %[[insert6:.*]] = llvm.insertvalue %[[twenty1]], %[[insert5]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+
+  // CHECK: %[[hundred_and_twenty:.*]] = llvm.mlir.constant(120 : index) : i64
+  // CHECK: %[[two:.*]] = llvm.mlir.constant(2 : index) : i64
+  // CHECK: %[[insert7:.*]] = llvm.insertvalue %[[two]], %[[insert6]][3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+  // CHECK: %[[insert8:.*]] = llvm.insertvalue %[[hundred_and_twenty]], %[[insert7]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+
   // CHECK: %[[cast1:.*]] = builtin.unrealized_conversion_cast %[[insert8]] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)> to memref<2x6x20xf32>
   %1 = memref.reshape %arg0(%0) : (memref<4x5x6xf32>, memref<3xi64>) -> memref<2x6x20xf32>
 
   // CHECK: return %[[cast1]] : memref<2x6x20xf32>
   return %1 : memref<2x6x20xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @memref.reshape.dynamic.dim
+// CHECK-SAME:    %[[arg:.*]]: memref<?x?x?xf32>, %[[shape:.*]]: memref<4xi64>) -> memref<?x?x12x32xf32>
+func.func @memref.reshape.dynamic.dim(%arg: memref<?x?x?xf32>, %shape: memref<4xi64>) -> memref<?x?x12x32xf32> {
+  // CHECK: %[[arg_cast:.*]] = builtin.unrealized_conversion_cast %[[arg]] : memref<?x?x?xf32> to !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+  // CHECK: %[[shape_cast:.*]] = builtin.unrealized_conversion_cast %[[shape]] : memref<4xi64> to !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: %[[undef:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<4 x i64>, array<4 x i64>)>
+  // CHECK: %[[alloc_ptr:.*]] = llvm.extractvalue %[[arg_cast]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+  // CHECK: %[[align_ptr:.*]] = llvm.extractvalue %[[arg_cast]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+  // CHECK: %[[insert0:.*]] = llvm.insertvalue %[[alloc_ptr]], %[[undef]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<4 x i64>, array<4 x i64>)>
+  // CHECK: %[[insert1:.*]] = llvm.insertvalue %[[align_ptr]], %[[insert0]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<4 x i64>, array<4 x i64>)>
+
+  // CHECK: %[[zero0:.*]] = llvm.mlir.constant(0 : index) : i64
+  // CHECK: %[[insert2:.*]] = llvm.insertvalue %[[zero0]], %[[insert1]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<4 x i64>, array<4 x i64>)>
+
+  // CHECK: %[[one0:.*]] = llvm.mlir.constant(1 : index) : i64
+  // CHECK: %[[thirty_two0:.*]] = llvm.mlir.constant(32 : index) : i64
+  // CHECK: %[[insert3:.*]] = llvm.insertvalue %[[thirty_two0]], %[[insert2]][3, 3] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<4 x i64>, array<4 x i64>)>
+  // CHECK: %[[insert4:.*]] = llvm.insertvalue %[[one0]], %[[insert3]][4, 3] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<4 x i64>, array<4 x i64>)>
+
+  // CHECK: %[[thirty_two1:.*]] = llvm.mlir.constant(32 : index) : i64
+  // CHECK: %[[twelve:.*]] = llvm.mlir.constant(12 : index) : i64
+  // CHECK: %[[insert5:.*]] = llvm.insertvalue %[[twelve]], %[[insert4]][3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<4 x i64>, array<4 x i64>)>
+  // CHECK: %[[insert6:.*]] = llvm.insertvalue %[[thirty_two1]], %[[insert5]][4, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<4 x i64>, array<4 x i64>)>
+
+  // CHECK: %[[three_hundred_and_eighty_four:.*]] = llvm.mlir.constant(384 : index) : i64
+  // CHECK: %[[one1:.*]] = llvm.mlir.constant(1 : index) : i64
+  // CHECK: %[[shape_ptr0:.*]] = llvm.extractvalue %[[shape_cast]][1] : !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: %[[shape_gep0:.*]] = llvm.getelementptr %[[shape_ptr0]][%[[one1]]] : (!llvm.ptr<i64>, i64) -> !llvm.ptr<i64>
+  // CHECK: %[[shape_load0:.*]] = llvm.load %[[shape_gep0]] : !llvm.ptr<i64>
+  // CHECK: %[[insert7:.*]] = llvm.insertvalue %[[shape_load0]], %[[insert6]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<4 x i64>, array<4 x i64>)>
+  // CHECK: %[[insert8:.*]] = llvm.insertvalue %[[three_hundred_and_eighty_four]], %[[insert7]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<4 x i64>, array<4 x i64>)>
+
+  // CHECK: %[[mul:.*]] = llvm.mul %[[three_hundred_and_eighty_four]], %[[shape_load0]]  : i64
+  // CHECK: %[[zero1:.*]] = llvm.mlir.constant(0 : index) : i64
+  // CHECK: %[[shape_ptr1:.*]] = llvm.extractvalue %[[shape_cast]][1] : !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: %[[shape_gep1:.*]] = llvm.getelementptr %[[shape_ptr1]][%[[zero1]]] : (!llvm.ptr<i64>, i64) -> !llvm.ptr<i64>
+  // CHECK: %[[shape_load1:.*]] = llvm.load %[[shape_gep1]] : !llvm.ptr<i64>
+  // CHECK: %[[insert9:.*]] = llvm.insertvalue %[[shape_load1]], %[[insert8]][3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<4 x i64>, array<4 x i64>)>
+  // CHECK: %[[insert10:.*]] = llvm.insertvalue %[[mul]], %[[insert9]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<4 x i64>, array<4 x i64>)>
+
+  // CHECK: %[[result_cast:.*]] = builtin.unrealized_conversion_cast %[[insert10]] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<4 x i64>, array<4 x i64>)> to memref<?x?x12x32xf32>
+  %0 = memref.reshape %arg(%shape) : (memref<?x?x?xf32>, memref<4xi64>) -> memref<?x?x12x32xf32>
+
+  return %0 : memref<?x?x12x32xf32>
+  // CHECK: return %[[result_cast]] : memref<?x?x12x32xf32>
+}


        


More information about the Mlir-commits mailing list