[Mlir-commits] [mlir] d6ab047 - [mlir] Convert MemRefReinterpretCastOp to LLVM.

Alexander Belyaev llvmlistbot at llvm.org
Mon Oct 26 12:13:49 PDT 2020


Author: Alexander Belyaev
Date: 2020-10-26T20:13:17+01:00
New Revision: d6ab0474c6efc5a614a28ed21070f11d587467f8

URL: https://github.com/llvm/llvm-project/commit/d6ab0474c6efc5a614a28ed21070f11d587467f8
DIFF: https://github.com/llvm/llvm-project/commit/d6ab0474c6efc5a614a28ed21070f11d587467f8.diff

LOG: [mlir] Convert MemRefReinterpretCastOp to LLVM.

https://llvm.discourse.group/t/rfc-standard-memref-cast-ops/1454/15

Differential Revision: https://reviews.llvm.org/D90033

Added: 
    mlir/test/mlir-cpu-runner/memref_reinterpret_cast.mlir

Modified: 
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 3fe60f5e88d4..5a6c50f2a549 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -2416,6 +2416,114 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> {
   }
 };
 
+struct MemRefReinterpretCastOpLowering
+    : public ConvertOpToLLVMPattern<MemRefReinterpretCastOp> {
+  using ConvertOpToLLVMPattern<MemRefReinterpretCastOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto castOp = cast<MemRefReinterpretCastOp>(op);
+    MemRefReinterpretCastOp::Adaptor adaptor(operands, op->getAttrDictionary());
+    Type srcType = castOp.source().getType();
+
+    Value descriptor;
+    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
+                                               adaptor, &descriptor)))
+      return failure();
+    rewriter.replaceOp(op, {descriptor});
+    return success();
+  }
+
+private:
+  LogicalResult
+  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
+                                  Type srcType, MemRefReinterpretCastOp castOp,
+                                  MemRefReinterpretCastOp::Adaptor adaptor,
+                                  Value *descriptor) const {
+    MemRefType targetMemRefType =
+        castOp.getResult().getType().cast<MemRefType>();
+    auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
+                                      .dyn_cast_or_null<LLVM::LLVMType>();
+    if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
+      return failure();
+
+    // Create descriptor.
+    Location loc = castOp.getLoc();
+    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
+
+    // Set allocated and aligned pointers.
+    Value allocatedPtr, alignedPtr;
+    extractPointers(loc, rewriter, castOp.source(), adaptor.source(),
+                    &allocatedPtr, &alignedPtr);
+    desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
+    desc.setAlignedPtr(rewriter, loc, alignedPtr);
+
+    // Set offset.
+    if (castOp.isDynamicOffset(0))
+      desc.setOffset(rewriter, loc, adaptor.offsets()[0]);
+    else
+      desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
+
+    // Set sizes and strides.
+    unsigned dynSizeId = 0;
+    unsigned dynStrideId = 0;
+    for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
+      if (castOp.isDynamicSize(i))
+        desc.setSize(rewriter, loc, i, adaptor.sizes()[dynSizeId++]);
+      else
+        desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
+
+      if (castOp.isDynamicStride(i))
+        desc.setStride(rewriter, loc, i, adaptor.strides()[dynStrideId++]);
+      else
+        desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
+    }
+    *descriptor = desc;
+    return success();
+  }
+
+  void extractPointers(Location loc, ConversionPatternRewriter &rewriter,
+                       Value originalOperand, Value convertedOperand,
+                       Value *allocatedPtr, Value *alignedPtr) const {
+    Type operandType = originalOperand.getType();
+    if (operandType.isa<MemRefType>()) {
+      MemRefDescriptor desc(convertedOperand);
+      *allocatedPtr = desc.allocatedPtr(rewriter, loc);
+      *alignedPtr = desc.alignedPtr(rewriter, loc);
+      return;
+    }
+
+    unsigned memorySpace =
+        operandType.cast<UnrankedMemRefType>().getMemorySpace();
+    Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
+    LLVM::LLVMType llvmElementType =
+        typeConverter.convertType(elementType).cast<LLVM::LLVMType>();
+    LLVM::LLVMType elementPtrPtrType =
+        llvmElementType.getPointerTo(memorySpace).getPointerTo();
+
+    // Extract pointer to the underlying ranked memref descriptor and cast it to
+    // ElemType**.
+    UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
+    Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
+    Value elementPtrPtr = rewriter.create<LLVM::BitcastOp>(
+        loc, elementPtrPtrType, underlyingDescPtr);
+
+    LLVM::LLVMType int32Type =
+        typeConverter.convertType(rewriter.getI32Type()).cast<LLVM::LLVMType>();
+
+    // Extract and set allocated pointer.
+    *allocatedPtr = rewriter.create<LLVM::LoadOp>(loc, elementPtrPtr);
+
+    // Extract and set aligned pointer.
+    Value one = rewriter.create<LLVM::ConstantOp>(
+        loc, int32Type, rewriter.getI32IntegerAttr(1));
+    Value alignedGep = rewriter.create<LLVM::GEPOp>(
+        loc, elementPtrPtrType, elementPtrPtr, ValueRange({one}));
+    *alignedPtr = rewriter.create<LLVM::LoadOp>(loc, alignedGep);
+  }
+};
+
 struct DialectCastOpLowering
     : public ConvertOpToLLVMPattern<LLVM::DialectCastOp> {
   using ConvertOpToLLVMPattern<LLVM::DialectCastOp>::ConvertOpToLLVMPattern;
@@ -3532,6 +3640,7 @@ void mlir::populateStdToLLVMMemoryConversionPatterns(
       DimOpLowering,
       LoadOpLowering,
       MemRefCastOpLowering,
+      MemRefReinterpretCastOpLowering,
       RankOpLowering,
       StoreOpLowering,
       SubViewOpLowering,

diff --git a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
index 8e7b22574432..8447474484e2 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
@@ -432,3 +432,60 @@ func @memref_dim_with_dyn_index(%arg : memref<3x?xf32>, %idx : index) -> index {
   %result = dim %arg, %idx : memref<3x?xf32>
   return %result : index
 }
+
+// CHECK-LABEL: @memref_reinterpret_cast_ranked_to_static_shape
+func @memref_reinterpret_cast_ranked_to_static_shape(%input : memref<2x3xf32>) {
+  %output = memref_reinterpret_cast %input to
+           offset: [0], sizes: [6, 1], strides: [1, 1]
+           : memref<2x3xf32> to memref<6x1xf32>
+  return
+}
+// CHECK: [[INPUT:%.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : [[TY:!.*]]
+// CHECK: [[OUT_0:%.*]] = llvm.mlir.undef : [[TY]]
+// CHECK: [[BASE_PTR:%.*]] = llvm.extractvalue [[INPUT]][0] : [[TY]]
+// CHECK: [[ALIGNED_PTR:%.*]] = llvm.extractvalue [[INPUT]][1] : [[TY]]
+// CHECK: [[OUT_1:%.*]] = llvm.insertvalue [[BASE_PTR]], [[OUT_0]][0] : [[TY]]
+// CHECK: [[OUT_2:%.*]] = llvm.insertvalue [[ALIGNED_PTR]], [[OUT_1]][1] : [[TY]]
+// CHECK: [[OFFSET:%.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK: [[OUT_3:%.*]] = llvm.insertvalue [[OFFSET]], [[OUT_2]][2] : [[TY]]
+// CHECK: [[SIZE_0:%.*]] = llvm.mlir.constant(6 : index) : !llvm.i64
+// CHECK: [[OUT_4:%.*]] = llvm.insertvalue [[SIZE_0]], [[OUT_3]][3, 0] : [[TY]]
+// CHECK: [[SIZE_1:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: [[OUT_5:%.*]] = llvm.insertvalue [[SIZE_1]], [[OUT_4]][4, 0] : [[TY]]
+// CHECK: [[STRIDE_0:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: [[OUT_6:%.*]] = llvm.insertvalue [[STRIDE_0]], [[OUT_5]][3, 1] : [[TY]]
+// CHECK: [[STRIDE_1:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: [[OUT_7:%.*]] = llvm.insertvalue [[STRIDE_1]], [[OUT_6]][4, 1] : [[TY]]
+
+// CHECK-LABEL: @memref_reinterpret_cast_unranked_to_dynamic_shape
+func @memref_reinterpret_cast_unranked_to_dynamic_shape(%offset: index,
+                                                        %size_0 : index,
+                                                        %size_1 : index,
+                                                        %stride_0 : index,
+                                                        %stride_1 : index,
+                                                        %input : memref<*xf32>) {
+  %output = memref_reinterpret_cast %input to
+           offset: [%offset], sizes: [%size_0, %size_1],
+           strides: [%stride_0, %stride_1]
+           : memref<*xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+  return
+}
+// CHECK-SAME: ([[OFFSET:%[a-z,0-9]+]]: !llvm.i64,
+// CHECK-SAME: [[SIZE_0:%[a-z,0-9]+]]: !llvm.i64, [[SIZE_1:%[a-z,0-9]+]]: !llvm.i64,
+// CHECK-SAME: [[STRIDE_0:%[a-z,0-9]+]]: !llvm.i64, [[STRIDE_1:%[a-z,0-9]+]]: !llvm.i64,
+// CHECK: [[INPUT:%.*]] = llvm.insertvalue {{.*}}[1] : !llvm.struct<(i64, ptr<i8>)>
+// CHECK: [[OUT_0:%.*]] = llvm.mlir.undef : [[TY:!.*]]
+// CHECK: [[DESCRIPTOR:%.*]] = llvm.extractvalue [[INPUT]][1] : !llvm.struct<(i64, ptr<i8>)>
+// CHECK: [[BASE_PTR_PTR:%.*]] = llvm.bitcast [[DESCRIPTOR]] : !llvm.ptr<i8> to !llvm.ptr<ptr<float>>
+// CHECK: [[BASE_PTR:%.*]] = llvm.load [[BASE_PTR_PTR]] : !llvm.ptr<ptr<float>>
+// CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
+// CHECK: [[ALIGNED_PTR_PTR:%.*]] = llvm.getelementptr [[BASE_PTR_PTR]]{{\[}}[[C1]]]
+// CHECK-SAME: : (!llvm.ptr<ptr<float>>, !llvm.i32) -> !llvm.ptr<ptr<float>>
+// CHECK: [[ALIGNED_PTR:%.*]] = llvm.load [[ALIGNED_PTR_PTR]] : !llvm.ptr<ptr<float>>
+// CHECK: [[OUT_1:%.*]] = llvm.insertvalue [[BASE_PTR]], [[OUT_0]][0] : [[TY]]
+// CHECK: [[OUT_2:%.*]] = llvm.insertvalue [[ALIGNED_PTR]], [[OUT_1]][1] : [[TY]]
+// CHECK: [[OUT_3:%.*]] = llvm.insertvalue [[OFFSET]], [[OUT_2]][2] : [[TY]]
+// CHECK: [[OUT_4:%.*]] = llvm.insertvalue [[SIZE_0]], [[OUT_3]][3, 0] : [[TY]]
+// CHECK: [[OUT_5:%.*]] = llvm.insertvalue [[STRIDE_0]], [[OUT_4]][4, 0] : [[TY]]
+// CHECK: [[OUT_6:%.*]] = llvm.insertvalue [[SIZE_1]], [[OUT_5]][3, 1] : [[TY]]
+// CHECK: [[OUT_7:%.*]] = llvm.insertvalue [[STRIDE_1]], [[OUT_6]][4, 1] : [[TY]]

diff --git a/mlir/test/mlir-cpu-runner/memref_reinterpret_cast.mlir b/mlir/test/mlir-cpu-runner/memref_reinterpret_cast.mlir
new file mode 100644
index 000000000000..4f933a7784cb
--- /dev/null
+++ b/mlir/test/mlir-cpu-runner/memref_reinterpret_cast.mlir
@@ -0,0 +1,105 @@
+// RUN: mlir-opt %s -convert-scf-to-std -convert-std-to-llvm \
+// RUN: | mlir-cpu-runner -e main -entry-point-result=void \
+// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+func @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface }
+
+func @main() -> () {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+
+  // Initialize input.
+  %input = alloc() : memref<2x3xf32>
+  %dim_x = dim %input, %c0 : memref<2x3xf32>
+  %dim_y = dim %input, %c1 : memref<2x3xf32>
+  scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) {
+    %prod = muli %i,  %dim_y : index
+    %val = addi %prod, %j : index
+    %val_i64 = index_cast %val : index to i64
+    %val_f32 = sitofp %val_i64 : i64 to f32
+    store %val_f32, %input[%i, %j] : memref<2x3xf32>
+  }
+  %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
+  // CHECK-NEXT: [0,   1,   2]
+  // CHECK-NEXT: [3,   4,   5]
+
+  // Test cases.
+  call @cast_ranked_memref_to_static_shape(%input) : (memref<2x3xf32>) -> ()
+  call @cast_ranked_memref_to_dynamic_shape(%input) : (memref<2x3xf32>) -> ()
+  call @cast_unranked_memref_to_static_shape(%input) : (memref<2x3xf32>) -> ()
+  call @cast_unranked_memref_to_dynamic_shape(%input) : (memref<2x3xf32>) -> ()
+  return
+}
+
+func @cast_ranked_memref_to_static_shape(%input : memref<2x3xf32>) {
+  %output = memref_reinterpret_cast %input to
+           offset: [0], sizes: [6, 1], strides: [1, 1]
+           : memref<2x3xf32> to memref<6x1xf32>
+
+  %unranked_output = memref_cast %output
+      : memref<6x1xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [6, 1] strides = [1, 1] data =
+  // CHECK-NEXT: [0],
+  // CHECK-NEXT: [1],
+  // CHECK-NEXT: [2],
+  // CHECK-NEXT: [3],
+  // CHECK-NEXT: [4],
+  // CHECK-NEXT: [5]
+  return
+}
+
+func @cast_ranked_memref_to_dynamic_shape(%input : memref<2x3xf32>) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c6 = constant 6 : index
+  %output = memref_reinterpret_cast %input to
+           offset: [%c0], sizes: [%c1, %c6], strides: [%c6, %c1]
+           : memref<2x3xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+
+  %unranked_output = memref_cast %output
+      : memref<?x?xf32, offset: ?, strides: [?, ?]> to memref<*xf32>
+  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [1, 6] strides = [6, 1] data =
+  // CHECK-NEXT: [0,   1,   2,   3,   4,   5]
+  return
+}
+
+func @cast_unranked_memref_to_static_shape(%input : memref<2x3xf32>) {
+  %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
+  %output = memref_reinterpret_cast %unranked_input to
+           offset: [0], sizes: [6, 1], strides: [1, 1]
+           : memref<*xf32> to memref<6x1xf32>
+
+  %unranked_output = memref_cast %output
+      : memref<6x1xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [6, 1] strides = [1, 1] data =
+  // CHECK-NEXT: [0],
+  // CHECK-NEXT: [1],
+  // CHECK-NEXT: [2],
+  // CHECK-NEXT: [3],
+  // CHECK-NEXT: [4],
+  // CHECK-NEXT: [5]
+  return
+}
+
+func @cast_unranked_memref_to_dynamic_shape(%input : memref<2x3xf32>) {
+  %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c6 = constant 6 : index
+  %output = memref_reinterpret_cast %unranked_input to
+           offset: [%c0], sizes: [%c1, %c6], strides: [%c6, %c1]
+           : memref<*xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+
+  %unranked_output = memref_cast %output
+      : memref<?x?xf32, offset: ?, strides: [?, ?]> to memref<*xf32>
+  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [1, 6] strides = [6, 1] data =
+  // CHECK-NEXT: [0,   1,   2,   3,   4,   5]
+  return
+}


        


More information about the Mlir-commits mailing list