[Mlir-commits] [mlir] a97e357 - [MLIR] Support `global_memref` and `get_global_memref` in standard -> LLVM conversion.

Rahul Joshi llvmlistbot at llvm.org
Mon Nov 9 10:54:51 PST 2020


Author: Rahul Joshi
Date: 2020-11-09T10:54:21-08:00
New Revision: a97e357e8ed46e578cb34ec795ba4f9fdefee189

URL: https://github.com/llvm/llvm-project/commit/a97e357e8ed46e578cb34ec795ba4f9fdefee189
DIFF: https://github.com/llvm/llvm-project/commit/a97e357e8ed46e578cb34ec795ba4f9fdefee189.diff

LOG: [MLIR] Support `global_memref` and `get_global_memref` in standard -> LLVM conversion.

- Convert `global_memref` to LLVM::GlobalOp.
- Convert `get_global_memref` to a memref descriptor with a pointer to the first element
  of the global stashed in it.
- Extend unit test and a mlir-cpu-runner test to validate the generated LLVM IR.

Differential Revision: https://reviews.llvm.org/D90803

Added: 
    mlir/test/mlir-cpu-runner/global_memref.mlir

Modified: 
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 3cd28bf919e8..4ca1bd62afc4 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -1970,11 +1970,15 @@ struct AllocLikeOpLowering : public ConvertToLLVMPattern {
     return rewriter.create<LLVM::BitcastOp>(loc, ptrType, allocatedPtr);
   }
 
+  /// Returns if buffer allocation needs buffer size to be computed. This size
+  /// feeds into the `bufferSize` argument of `allocateBuffer`.
+  virtual bool needsBufferSize() const { return true; }
+
   /// Allocates the underlying buffer. Returns the allocated pointer and the
   /// aligned pointer.
   virtual std::tuple<Value, Value>
   allocateBuffer(ConversionPatternRewriter &rewriter, Location loc,
-                 Value cumulativeSize, Operation *op) const = 0;
+                 Value bufferSize, Operation *op) const = 0;
 
 private:
   static MemRefType getMemRefResultType(Operation *op) {
@@ -2027,14 +2031,16 @@ struct AllocLikeOpLowering : public ConvertToLLVMPattern {
     SmallVector<Value, 4> sizes;
     this->getMemRefDescriptorSizes(loc, memRefType, operands, rewriter, sizes);
 
-    Value cumulativeSize = this->getCumulativeSizeInBytes(
-        loc, memRefType.getElementType(), sizes, rewriter);
+    Value bufferSize;
+    if (needsBufferSize())
+      bufferSize = this->getCumulativeSizeInBytes(
+          loc, memRefType.getElementType(), sizes, rewriter);
 
     // Allocate the underlying buffer.
     Value allocatedPtr;
     Value alignedPtr;
     std::tie(allocatedPtr, alignedPtr) =
-        this->allocateBuffer(rewriter, loc, cumulativeSize, op);
+        this->allocateBuffer(rewriter, loc, bufferSize, op);
 
     int64_t offset;
     SmallVector<int64_t, 4> strides;
@@ -2065,7 +2071,7 @@ struct AllocOpLowering : public AllocLikeOpLowering {
       : AllocLikeOpLowering(AllocOp::getOperationName(), converter) {}
 
   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
-                                          Location loc, Value cumulativeSize,
+                                          Location loc, Value bufferSize,
                                           Operation *op) const override {
     // Heap allocations.
     AllocOp allocOp = cast<AllocOp>(op);
@@ -2084,15 +2090,14 @@ struct AllocOpLowering : public AllocLikeOpLowering {
 
     if (alignment) {
       // Adjust the allocation size to consider alignment.
-      cumulativeSize =
-          rewriter.create<LLVM::AddOp>(loc, cumulativeSize, alignment);
+      bufferSize = rewriter.create<LLVM::AddOp>(loc, bufferSize, alignment);
     }
 
     // Allocate the underlying buffer and store a pointer to it in the MemRef
     // descriptor.
     Type elementPtrType = this->getElementPtrType(memRefType);
     Value allocatedPtr =
-        createAllocCall(loc, "malloc", elementPtrType, {cumulativeSize},
+        createAllocCall(loc, "malloc", elementPtrType, {bufferSize},
                         allocOp.getParentOfType<ModuleOp>(), rewriter);
 
     Value alignedPtr = allocatedPtr;
@@ -2159,7 +2164,7 @@ struct AlignedAllocOpLowering : public AllocLikeOpLowering {
   }
 
   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
-                                          Location loc, Value cumulativeSize,
+                                          Location loc, Value bufferSize,
                                           Operation *op) const override {
     // Heap allocations.
     AllocOp allocOp = cast<AllocOp>(op);
@@ -2170,12 +2175,11 @@ struct AlignedAllocOpLowering : public AllocLikeOpLowering {
     // aligned_alloc requires size to be a multiple of alignment; we will pad
     // the size to the next multiple if necessary.
     if (!isMemRefSizeMultipleOf(memRefType, alignment))
-      cumulativeSize =
-          createAligned(rewriter, loc, cumulativeSize, allocAlignment);
+      bufferSize = createAligned(rewriter, loc, bufferSize, allocAlignment);
 
     Type elementPtrType = this->getElementPtrType(memRefType);
     Value allocatedPtr = createAllocCall(
-        loc, "aligned_alloc", elementPtrType, {allocAlignment, cumulativeSize},
+        loc, "aligned_alloc", elementPtrType, {allocAlignment, bufferSize},
         allocOp.getParentOfType<ModuleOp>(), rewriter);
 
     return std::make_tuple(allocatedPtr, allocatedPtr);
@@ -2196,7 +2200,7 @@ struct AllocaOpLowering : public AllocLikeOpLowering {
   /// is set to null for stack allocations. `accessAlignment` is set if
   /// alignment is needed post allocation (for eg. in conjunction with malloc).
   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
-                                          Location loc, Value cumulativeSize,
+                                          Location loc, Value bufferSize,
                                           Operation *op) const override {
 
     // With alloca, one gets a pointer to the element type right away.
@@ -2205,7 +2209,7 @@ struct AllocaOpLowering : public AllocLikeOpLowering {
     auto elementPtrType = this->getElementPtrType(allocaOp.getType());
 
     auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
-        loc, elementPtrType, cumulativeSize,
+        loc, elementPtrType, bufferSize,
         allocaOp.alignment() ? *allocaOp.alignment() : 0);
 
     return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
@@ -2420,6 +2424,109 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<DeallocOp> {
   }
 };
 
+/// Returns the LLVM type of the global variable given the memref type `type`.
+static LLVM::LLVMType
+convertGlobalMemrefTypeToLLVM(MemRefType type,
+                              LLVMTypeConverter &typeConverter) {
+  // LLVM type for a global memref will be a multi-dimensional array. For
+  // declarations or uninitialized global memrefs, we can potentially flatten
+  // this to a 1D array. However, for global_memrefs with an initial value,
+  // we do not intend to flatten the ElementsAttribute when going from std ->
+  // LLVM dialect, so the LLVM type needs to be a multi-dimensional array.
+  LLVM::LLVMType elementType =
+      unwrap(typeConverter.convertType(type.getElementType()));
+  LLVM::LLVMType arrayTy = elementType;
+  // Walk shape innermost-first (index 0 is outermost); rank 0 -> element type.
+  for (int64_t dim : llvm::reverse(type.getShape()))
+    arrayTy = LLVM::LLVMType::getArrayTy(arrayTy, dim);
+  return arrayTy;
+}
+
+/// GlobalMemrefOp is lowered to an LLVM global variable (LLVM::GlobalOp).
+struct GlobalMemrefOpLowering : public ConvertOpToLLVMPattern<GlobalMemrefOp> {
+  using ConvertOpToLLVMPattern<GlobalMemrefOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto global = cast<GlobalMemrefOp>(op);
+    MemRefType type = global.type().cast<MemRefType>();
+    if (!isSupportedMemRefType(type))
+      return failure();
+
+    LLVM::LLVMType arrayTy = convertGlobalMemrefTypeToLLVM(type, typeConverter);
+    // Public globals get external linkage; all others become private.
+    LLVM::Linkage linkage =
+        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
+    // External declarations and uninitialized globals carry no initializer.
+    Attribute initialValue = nullptr;
+    if (!global.isExternal() && !global.isUninitialized()) {
+      auto elementsAttr = global.initial_value()->cast<ElementsAttr>();
+      initialValue = elementsAttr;
+
+      // For scalar memrefs, the global variable created is of the element type,
+      // so unpack the elements attribute to extract the value.
+      if (type.getRank() == 0)
+        initialValue = elementsAttr.getValue({});
+    }
+
+    rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
+        op, arrayTy, global.constant(), linkage, global.sym_name(),
+        initialValue, type.getMemorySpace());
+    return success();
+  }
+};
+
+/// GetGlobalMemrefOp is lowered into a Memref descriptor with the pointer to
+/// the first element stashed into the descriptor. This derives from
+/// `AllocLikeOpLowering` to reuse its Memref descriptor construction.
+struct GetGlobalMemrefOpLowering : public AllocLikeOpLowering {
+  GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
+      : AllocLikeOpLowering(GetGlobalMemrefOp::getOperationName(), converter) {}
+
+  /// Allocation for GetGlobalMemrefOp just returns the GV pointer, so no need
+  /// to compute buffer size.
+  bool needsBufferSize() const override { return false; }
+
+  /// Buffer "allocation" for get_global_memref op is getting the address of
+  /// the global variable referenced. `bufferSize` is unused (null here).
+  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
+                                          Location loc, Value bufferSize,
+                                          Operation *op) const override {
+    auto getGlobalOp = cast<GetGlobalMemrefOp>(op);
+    MemRefType type = getGlobalOp.result().getType().cast<MemRefType>();
+    unsigned memSpace = type.getMemorySpace();
+
+    LLVM::LLVMType arrayTy = convertGlobalMemrefTypeToLLVM(type, typeConverter);
+    auto addressOf = rewriter.create<LLVM::AddressOfOp>(
+        loc, arrayTy.getPointerTo(memSpace), getGlobalOp.name());
+
+    // Get the address of the first element in the array by creating a GEP with
+    // the address of the GV as the base, and (rank + 1) number of 0 indices.
+    LLVM::LLVMType elementType =
+        unwrap(typeConverter.convertType(type.getElementType()));
+    LLVM::LLVMType elementPtrType = elementType.getPointerTo(memSpace);
+
+    SmallVector<Value, 4> operands = {addressOf};
+    operands.insert(operands.end(), type.getRank() + 1,
+                    createIndexConstant(rewriter, loc, 0));
+    auto gep = rewriter.create<LLVM::GEPOp>(loc, elementPtrType, operands);
+
+    // We do not expect the memref obtained using `get_global_memref` to ever
+    // be deallocated. Set the allocated pointer to a known bad value
+    // (0xdeadbeef) to help debug if that ever happens.
+    auto intPtrType = getIntPtrType(memSpace);
+    Value deadBeefConst =
+        createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
+    auto deadBeefPtr =
+        rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);
+
+    // Only the aligned pointer (the GEP) refers to the global; the allocated
+    // pointer is the 0xdeadbeef sentinel above, since no dealloc is expected.
+    return {deadBeefPtr, gep};
+  }
+};
+
 // A `rsqrt` is converted into `1 / sqrt`.
 struct RsqrtOpLowering : public ConvertOpToLLVMPattern<RsqrtOp> {
   using ConvertOpToLLVMPattern<RsqrtOp>::ConvertOpToLLVMPattern;
@@ -3941,6 +4048,8 @@ void mlir::populateStdToLLVMMemoryConversionPatterns(
       AssumeAlignmentOpLowering,
       DeallocOpLowering,
       DimOpLowering,
+      GlobalMemrefOpLowering,
+      GetGlobalMemrefOpLowering,
       LoadOpLowering,
       MemRefCastOpLowering,
       MemRefReinterpretCastOpLowering,

diff  --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
index 71a35f6ccf0a..0c38d02e816a 100644
--- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
@@ -131,3 +131,82 @@ func @transpose(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
   %0 = transpose %arg0 (i, j, k) -> (k, i, j) : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]> to memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2 * s1 + s0 + d0 * s2 + d1)>>
   return
 }
+
+// -----
+
+// CHECK: llvm.mlir.global external @gv0() : !llvm.array<2 x float>
+global_memref @gv0 : memref<2xf32> = uninitialized
+
+// CHECK: llvm.mlir.global private @gv1() : !llvm.array<2 x float>
+global_memref "private" @gv1 : memref<2xf32>
+
+// CHECK: llvm.mlir.global external @gv2(dense<{{\[\[}}0.000000e+00, 1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00, 5.000000e+00]]> : tensor<2x3xf32>) : !llvm.array<2 x array<3 x float>>
+global_memref @gv2 : memref<2x3xf32> = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]>
+
+// Test get_global_memref lowering of a 1D memref.
+// CHECK-LABEL: func @get_gv0_memref
+func @get_gv0_memref() {
+  %0 = get_global_memref @gv0 : memref<2xf32>
+  // CHECK: %[[DIM:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64
+  // CHECK: %[[ADDR:.*]] = llvm.mlir.addressof @gv0 : !llvm.ptr<array<2 x float>>
+  // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ADDR]][%[[ZERO]], %[[ZERO]]] : (!llvm.ptr<array<2 x float>>, !llvm.i64, !llvm.i64) -> !llvm.ptr<float>
+  // CHECK: %[[DEADBEEF:.*]] = llvm.mlir.constant(3735928559 : index) : !llvm.i64
+  // CHECK: %[[DEADBEEFPTR:.*]] = llvm.inttoptr %[[DEADBEEF]] : !llvm.i64 to !llvm.ptr<float>
+  // CHECK: llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: llvm.insertvalue %[[DEADBEEFPTR]], {{.*}}[0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: llvm.insertvalue %[[GEP]], {{.*}}[1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: %[[OFFSET:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+  // CHECK: llvm.insertvalue %[[OFFSET]], {{.*}}[2] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: %[[STRIDE:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+  // CHECK: llvm.insertvalue %[[DIM]], {{.*}}[3, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: llvm.insertvalue %[[STRIDE]], {{.*}}[4, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
+  return
+}
+
+// Test get_global_memref lowering of a 2D memref.
+// CHECK-LABEL: func @get_gv2_memref
+func @get_gv2_memref() {
+  // CHECK: %[[DIM0:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64
+  // CHECK: %[[DIM1:.*]] = llvm.mlir.constant(3 : index) : !llvm.i64
+  // CHECK: %[[ADDR:.*]] = llvm.mlir.addressof @gv2 : !llvm.ptr<array<2 x array<3 x float>>>
+  // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ADDR]][%[[ZERO]], %[[ZERO]], %[[ZERO]]] : (!llvm.ptr<array<2 x array<3 x float>>>, !llvm.i64, !llvm.i64, !llvm.i64) -> !llvm.ptr<float>
+  // CHECK: %[[DEADBEEF:.*]] = llvm.mlir.constant(3735928559 : index) : !llvm.i64
+  // CHECK: %[[DEADBEEFPTR:.*]] = llvm.inttoptr %[[DEADBEEF]] : !llvm.i64 to !llvm.ptr<float>
+  // CHECK: llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: llvm.insertvalue %[[DEADBEEFPTR]], {{.*}}[0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: llvm.insertvalue %[[GEP]], {{.*}}[1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[OFFSET:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+  // CHECK: llvm.insertvalue %[[OFFSET]], {{.*}}[2] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[STRIDE1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+  // CHECK: %[[STRIDE0:.*]] = llvm.mlir.constant(3 : index) : !llvm.i64
+  // CHECK: llvm.insertvalue %[[DIM0]], {{.*}}[3, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: llvm.insertvalue %[[STRIDE0]], {{.*}}[4, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: llvm.insertvalue %[[DIM1]], {{.*}}[3, 1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: llvm.insertvalue %[[STRIDE1]], {{.*}}[4, 1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
+
+  %0 = get_global_memref @gv2 : memref<2x3xf32>
+  return
+}
+
+// Test lowering of a rank-0 (scalar) global memref and its get_global_memref.
+// CHECK: llvm.mlir.global external @gv3(1.000000e+00 : f32) : !llvm.float
+global_memref @gv3 : memref<f32> = dense<1.0>
+
+// CHECK-LABEL: func @get_gv3_memref
+func @get_gv3_memref() {
+  // CHECK: %[[ADDR:.*]] = llvm.mlir.addressof @gv3 : !llvm.ptr<float>
+  // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ADDR]][%[[ZERO]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
+  // CHECK: %[[DEADBEEF:.*]] = llvm.mlir.constant(3735928559 : index) : !llvm.i64
+  // CHECK: %[[DEADBEEFPTR:.*]] = llvm.inttoptr %[[DEADBEEF]] : !llvm.i64 to !llvm.ptr<float>
+  // CHECK: llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64)>
+  // CHECK: llvm.insertvalue %[[DEADBEEFPTR]], {{.*}}[0] : !llvm.struct<(ptr<float>, ptr<float>, i64)>
+  // CHECK: llvm.insertvalue %[[GEP]], {{.*}}[1] : !llvm.struct<(ptr<float>, ptr<float>, i64)>
+  // CHECK: %[[OFFSET:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+  // CHECK: llvm.insertvalue %[[OFFSET]], {{.*}}[2] : !llvm.struct<(ptr<float>, ptr<float>, i64)>
+  %0 = get_global_memref @gv3 : memref<f32>
+  return
+}
+

diff  --git a/mlir/test/mlir-cpu-runner/global_memref.mlir b/mlir/test/mlir-cpu-runner/global_memref.mlir
new file mode 100644
index 000000000000..1c9cf4a33aeb
--- /dev/null
+++ b/mlir/test/mlir-cpu-runner/global_memref.mlir
@@ -0,0 +1,107 @@
+// RUN: mlir-opt %s -convert-std-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext | FileCheck %s
+
+func @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface }
+func @print_memref_i32(memref<*xi32>) attributes { llvm.emit_c_interface }
+func @printNewline() -> ()
+
+global_memref "private" @gv0 : memref<4xf32> = dense<[0.0, 1.0, 2.0, 3.0]>
+func @test1DMemref() {
+  %0 = get_global_memref @gv0 : memref<4xf32>
+  %U = memref_cast %0 : memref<4xf32> to memref<*xf32>
+  // CHECK: rank = 1
+  // CHECK: offset = 0
+  // CHECK: sizes = [4]
+  // CHECK: strides = [1]
+  // CHECK: [0,  1,  2,  3]
+  call @print_memref_f32(%U) : (memref<*xf32>) -> ()
+  call @printNewline() : () -> ()
+
+  // Overwrite elements 0 and 2; the re-print must reflect the stores.
+  %c0 = constant 0 : index
+  %c2 = constant 2 : index
+  %fp0 = constant 4.0 : f32
+  %fp1 = constant 5.0 : f32
+  store %fp0, %0[%c0] : memref<4xf32>
+  store %fp1, %0[%c2] : memref<4xf32>
+  // CHECK: rank = 1
+  // CHECK: offset = 0
+  // CHECK: sizes = [4]
+  // CHECK: strides = [1]
+  // CHECK: [4,  1,  5,  3]
+  call @print_memref_f32(%U) : (memref<*xf32>) -> ()
+  call @printNewline() : () -> ()
+  return
+}
+
+global_memref constant @gv1 : memref<3x2xi32> = dense<[[0, 1],[2, 3],[4, 5]]>
+func @testConstantMemref() {
+  %0 = get_global_memref @gv1 : memref<3x2xi32>
+  %U = memref_cast %0 : memref<3x2xi32> to memref<*xi32>
+  // CHECK: rank = 2
+  // CHECK: offset = 0
+  // CHECK: sizes = [3, 2]
+  // CHECK: strides = [2, 1]
+  // CHECK: [0,   1]
+  // CHECK: [2,   3]
+  // CHECK: [4,   5]
+  call @print_memref_i32(%U) : (memref<*xi32>) -> ()
+  call @printNewline() : () -> ()
+  return
+}
+
+global_memref "private" @gv2 : memref<4x2xf32> = dense<[[0.0, 1.0], [2.0, 3.0], [4.0, 5.0], [6.0, 7.0]]>
+func @test2DMemref() {
+  %0 = get_global_memref @gv2 : memref<4x2xf32>
+  %U = memref_cast %0 : memref<4x2xf32> to memref<*xf32>
+  // CHECK: rank = 2
+  // CHECK: offset = 0
+  // CHECK: sizes = [4, 2]
+  // CHECK: strides = [2, 1]
+  // CHECK: [0,   1]
+  // CHECK: [2,   3]
+  // CHECK: [4,   5]
+  // CHECK: [6,   7]
+  call @print_memref_f32(%U) : (memref<*xf32>) -> ()
+  call @printNewline() : () -> ()
+
+  // Overwrite the 1.0 at index [0, 1] with 10.0, then re-print the global.
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %fp10 = constant 10.0 : f32
+  store %fp10, %0[%c0, %c1] : memref<4x2xf32>
+  // CHECK: rank = 2
+  // CHECK: offset = 0
+  // CHECK: sizes = [4, 2]
+  // CHECK: strides = [2, 1]
+  // CHECK: [0,   10]
+  // CHECK: [2,   3]
+  // CHECK: [4,   5]
+  // CHECK: [6,   7]
+  call @print_memref_f32(%U) : (memref<*xf32>) -> ()
+  call @printNewline() : () -> ()
+  return
+}
+
+global_memref @gv3 : memref<i32> = dense<11>
+func @testScalarMemref() {
+  %0 = get_global_memref @gv3 : memref<i32>
+  %U = memref_cast %0 : memref<i32> to memref<*xi32>
+  // CHECK: rank = 0
+  // CHECK: offset = 0
+  // CHECK: sizes = []
+  // CHECK: strides = []
+  // CHECK: [11]
+  call @print_memref_i32(%U) : (memref<*xi32>) -> ()
+  call @printNewline() : () -> ()
+  return
+}
+
+func @main() -> () {
+  call @test1DMemref() : () -> ()
+  call @testConstantMemref() : () -> ()
+  call @test2DMemref() : () -> ()
+  call @testScalarMemref() : () -> ()
+  return
+}
+
+


        


More information about the Mlir-commits mailing list