[Mlir-commits] [mlir] 5380e30 - [mlir] translate memref.reshape ops that have static shapes

Ashay Rane llvmlistbot at llvm.org
Thu May 12 11:57:41 PDT 2022


Author: Ashay Rane
Date: 2022-05-12T11:57:20-07:00
New Revision: 5380e30e047bbac9b2cceb69162eb8db1e1a7abf

URL: https://github.com/llvm/llvm-project/commit/5380e30e047bbac9b2cceb69162eb8db1e1a7abf
DIFF: https://github.com/llvm/llvm-project/commit/5380e30e047bbac9b2cceb69162eb8db1e1a7abf.diff

LOG: [mlir] translate memref.reshape ops that have static shapes

This patch reuses the approach taken for translating memref.reinterpret_cast
ops to add translation rules for memref.reshape ops that have a static shape
argument.  Since reshape ops don't take offset, size, or stride operands, this
patch copies the allocated and aligned pointers from the source MemRef
descriptor and sets the offset, sizes, and strides from the static result
type.

Reviewed By: ftynse, cathyzhyi

Differential Revision: https://reviews.llvm.org/D125039

Added: 
    

Modified: 
    mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
    mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 110b40adf777f..579f2da92fa41 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -26,6 +26,12 @@ using namespace mlir;
 
+/// Returns true unless ShapedType::isDynamicStrideOrOffset reports
+/// `strideOrOffset` as dynamic, i.e. the value is known at compile time.
+static bool isStaticStrideOrOffset(int64_t strideOrOffset) {
+  return !ShapedType::isDynamicStrideOrOffset(strideOrOffset);
+}
+
 namespace {
 
 struct AllocOpLowering : public AllocLikeOpLLVMLowering {
   AllocOpLowering(LLVMTypeConverter &converter)
       : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
@@ -1091,11 +1095,52 @@ struct MemRefReshapeOpLowering
                                   Type srcType, memref::ReshapeOp reshapeOp,
                                   memref::ReshapeOp::Adaptor adaptor,
                                   Value *descriptor) const {
-    // Conversion for statically-known shape args is performed via
-    // `memref_reinterpret_cast`.
     auto shapeMemRefType = reshapeOp.shape().getType().cast<MemRefType>();
-    if (shapeMemRefType.hasStaticShape())
-      return failure();
+    if (shapeMemRefType.hasStaticShape()) {
+      MemRefType targetMemRefType =
+          reshapeOp.getResult().getType().cast<MemRefType>();
+      auto llvmTargetDescriptorTy =
+          typeConverter->convertType(targetMemRefType)
+              .dyn_cast_or_null<LLVM::LLVMStructType>();
+      if (!llvmTargetDescriptorTy)
+        return failure();
+
+      // Create descriptor.
+      Location loc = reshapeOp.getLoc();
+      auto desc =
+          MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
+
+      // Set allocated and aligned pointers.
+      Value allocatedPtr, alignedPtr;
+      extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
+                               reshapeOp.source(), adaptor.source(),
+                               &allocatedPtr, &alignedPtr);
+      desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
+      desc.setAlignedPtr(rewriter, loc, alignedPtr);
+
+      // Extract the offset and strides from the type.
+      int64_t offset;
+      SmallVector<int64_t> strides;
+      if (failed(getStridesAndOffset(targetMemRefType, strides, offset)))
+        return rewriter.notifyMatchFailure(
+            reshapeOp, "failed to get stride and offset exprs");
+
+      if (!isStaticStrideOrOffset(offset))
+        return rewriter.notifyMatchFailure(reshapeOp,
+                                           "dynamic offset is unsupported");
+      if (!llvm::all_of(strides, isStaticStrideOrOffset))
+        return rewriter.notifyMatchFailure(reshapeOp,
+                                           "dynamic strides are unsupported");
+
+      desc.setConstantOffset(rewriter, loc, offset);
+      for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
+        desc.setConstantSize(rewriter, loc, i, targetMemRefType.getDimSize(i));
+        desc.setConstantStride(rewriter, loc, i, strides[i]);
+      }
+
+      *descriptor = desc;
+      return success();
+    }
 
     // The shape is a rank-1 tensor with unknown length.
     Location loc = reshapeOp.getLoc();
@@ -1499,10 +1544,7 @@ class ReassociatingReshapeOpConversion
     for (auto &en : llvm::enumerate(dstShape))
       dstDesc.setSize(rewriter, loc, en.index(), en.value());
 
-    auto isStaticStride = [](int64_t stride) {
-      return !ShapedType::isDynamicStrideOrOffset(stride);
-    };
-    if (llvm::all_of(strides, isStaticStride)) {
+    if (llvm::all_of(strides, isStaticStrideOrOffset)) {
       for (auto &en : llvm::enumerate(strides))
         dstDesc.setConstantStride(rewriter, loc, en.index(), en.value());
     } else if (srcType.getLayout().isIdentity() &&

diff  --git a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
index 93f1449002902..b8f3717682a05 100644
--- a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
@@ -217,3 +217,41 @@ module attributes { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<index, 32>> } {
   }
 }
 
+// -----
+
+memref.global "private" constant @__constant_3xi64 : memref<3xi64> = dense<[2, 6, 20]>
+
+// The shape operand has a static type, so the lowering copies the source
+// pointers and emits the offset, sizes and strides of the identity-layout
+// result type as constants: offset 0, sizes [2, 6, 20], strides [120, 20, 1].
+// CHECK-LABEL: func @memref.reshape
+// CHECK-SAME:    %[[arg0:.*]]: memref<4x5x6xf32>) -> memref<2x6x20xf32>
+func.func @memref.reshape(%arg0: memref<4x5x6xf32>) -> memref<2x6x20xf32> {
+  // CHECK: %[[cast0:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : memref<4x5x6xf32> to !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+  %0 = memref.get_global @__constant_3xi64 : memref<3xi64>
+
+  // CHECK: %[[undef:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+  // CHECK: %[[elem0:.*]] = llvm.extractvalue %[[cast0]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+  // CHECK: %[[elem1:.*]] = llvm.extractvalue %[[cast0]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+  // CHECK: %[[insert0:.*]] = llvm.insertvalue %[[elem0]], %[[undef]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+  // CHECK: %[[insert1:.*]] = llvm.insertvalue %[[elem1]], %[[insert0]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+  // CHECK: %[[zero:.*]] = llvm.mlir.constant(0 : index) : i64
+  // CHECK: %[[insert2:.*]] = llvm.insertvalue %[[zero]], %[[insert1]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+  // CHECK: %[[two:.*]] = llvm.mlir.constant(2 : index) : i64
+  // CHECK: %[[insert3:.*]] = llvm.insertvalue %[[two]], %[[insert2]][3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+  // CHECK: %[[hundred_and_twenty:.*]] = llvm.mlir.constant(120 : index) : i64
+  // CHECK: %[[insert4:.*]] = llvm.insertvalue %[[hundred_and_twenty]], %[[insert3]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+  // CHECK: %[[six:.*]] = llvm.mlir.constant(6 : index) : i64
+  // CHECK: %[[insert5:.*]] = llvm.insertvalue %[[six]], %[[insert4]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+  // CHECK: %[[twenty0:.*]] = llvm.mlir.constant(20 : index) : i64
+  // CHECK: %[[insert6:.*]] = llvm.insertvalue %[[twenty0]], %[[insert5]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+  // CHECK: %[[twenty1:.*]] = llvm.mlir.constant(20 : index) : i64
+  // CHECK: %[[insert7:.*]] = llvm.insertvalue %[[twenty1]], %[[insert6]][3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+  // CHECK: %[[one:.*]] = llvm.mlir.constant(1 : index) : i64
+  // CHECK: %[[insert8:.*]] = llvm.insertvalue %[[one]], %[[insert7]][4, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+  // CHECK: %[[cast1:.*]] = builtin.unrealized_conversion_cast %[[insert8]] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)> to memref<2x6x20xf32>
+  %1 = memref.reshape %arg0(%0) : (memref<4x5x6xf32>, memref<3xi64>) -> memref<2x6x20xf32>
+
+  // CHECK: return %[[cast1]] : memref<2x6x20xf32>
+  return %1 : memref<2x6x20xf32>
+}


        


More information about the Mlir-commits mailing list