[Mlir-commits] [mlir] a97e357 - [MLIR] Support `global_memref` and `get_global_memref` in standard -> LLVM conversion.
Rahul Joshi
llvmlistbot at llvm.org
Mon Nov 9 10:54:51 PST 2020
Author: Rahul Joshi
Date: 2020-11-09T10:54:21-08:00
New Revision: a97e357e8ed46e578cb34ec795ba4f9fdefee189
URL: https://github.com/llvm/llvm-project/commit/a97e357e8ed46e578cb34ec795ba4f9fdefee189
DIFF: https://github.com/llvm/llvm-project/commit/a97e357e8ed46e578cb34ec795ba4f9fdefee189.diff
LOG: [MLIR] Support `global_memref` and `get_global_memref` in standard -> LLVM conversion.
- Convert `global_memref` to LLVM::GlobalOp.
- Convert `get_global_memref` to a memref descriptor with a pointer to the first element
of the global stashed in it.
- Extend the unit test and add an mlir-cpu-runner test to validate the generated LLVM IR (see the example below).
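
As a condensed illustration of the lowering (distilled from the new test cases
in standard-to-llvm.mlir below; the SSA value names here are illustrative, the
actual output uses numbered values), input such as

  global_memref @gv0 : memref<2xf32> = uninitialized

  func @get_gv0_memref() {
    %0 = get_global_memref @gv0 : memref<2xf32>
    return
  }

becomes, roughly:

  llvm.mlir.global external @gv0() : !llvm.array<2 x float>

  // Inside @get_gv0_memref: the descriptor's aligned pointer is the address
  // of the first element of the global; the allocated pointer is set to the
  // sentinel 0xdeadbeef because this memref is never deallocated.
  %addr = llvm.mlir.addressof @gv0 : !llvm.ptr<array<2 x float>>
  %c0   = llvm.mlir.constant(0 : index) : !llvm.i64
  %gep  = llvm.getelementptr %addr[%c0, %c0]
            : (!llvm.ptr<array<2 x float>>, !llvm.i64, !llvm.i64) -> !llvm.ptr<float>
  %bad  = llvm.mlir.constant(3735928559 : index) : !llvm.i64
  %dead = llvm.inttoptr %bad : !llvm.i64 to !llvm.ptr<float>
  // ... the remaining ops build the memref descriptor struct with %dead as the
  // allocated pointer, %gep as the aligned pointer, offset 0, size 2, stride 1.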
Differential Revision: https://reviews.llvm.org/D90803
Added:
mlir/test/mlir-cpu-runner/global_memref.mlir
Modified:
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 3cd28bf919e8..4ca1bd62afc4 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -1970,11 +1970,15 @@ struct AllocLikeOpLowering : public ConvertToLLVMPattern {
return rewriter.create<LLVM::BitcastOp>(loc, ptrType, allocatedPtr);
}
+ /// Returns true if the buffer allocation needs the buffer size to be computed.
+ /// This size feeds into the `bufferSize` argument of `allocateBuffer`.
+ virtual bool needsBufferSize() const { return true; }
+
/// Allocates the underlying buffer. Returns the allocated pointer and the
/// aligned pointer.
virtual std::tuple<Value, Value>
allocateBuffer(ConversionPatternRewriter &rewriter, Location loc,
- Value cumulativeSize, Operation *op) const = 0;
+ Value bufferSize, Operation *op) const = 0;
private:
static MemRefType getMemRefResultType(Operation *op) {
@@ -2027,14 +2031,16 @@ struct AllocLikeOpLowering : public ConvertToLLVMPattern {
SmallVector<Value, 4> sizes;
this->getMemRefDescriptorSizes(loc, memRefType, operands, rewriter, sizes);
- Value cumulativeSize = this->getCumulativeSizeInBytes(
- loc, memRefType.getElementType(), sizes, rewriter);
+ Value bufferSize;
+ if (needsBufferSize())
+ bufferSize = this->getCumulativeSizeInBytes(
+ loc, memRefType.getElementType(), sizes, rewriter);
// Allocate the underlying buffer.
Value allocatedPtr;
Value alignedPtr;
std::tie(allocatedPtr, alignedPtr) =
- this->allocateBuffer(rewriter, loc, cumulativeSize, op);
+ this->allocateBuffer(rewriter, loc, bufferSize, op);
int64_t offset;
SmallVector<int64_t, 4> strides;
@@ -2065,7 +2071,7 @@ struct AllocOpLowering : public AllocLikeOpLowering {
: AllocLikeOpLowering(AllocOp::getOperationName(), converter) {}
std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
- Location loc, Value cumulativeSize,
+ Location loc, Value bufferSize,
Operation *op) const override {
// Heap allocations.
AllocOp allocOp = cast<AllocOp>(op);
@@ -2084,15 +2090,14 @@ struct AllocOpLowering : public AllocLikeOpLowering {
if (alignment) {
// Adjust the allocation size to consider alignment.
- cumulativeSize =
- rewriter.create<LLVM::AddOp>(loc, cumulativeSize, alignment);
+ bufferSize = rewriter.create<LLVM::AddOp>(loc, bufferSize, alignment);
}
// Allocate the underlying buffer and store a pointer to it in the MemRef
// descriptor.
Type elementPtrType = this->getElementPtrType(memRefType);
Value allocatedPtr =
- createAllocCall(loc, "malloc", elementPtrType, {cumulativeSize},
+ createAllocCall(loc, "malloc", elementPtrType, {bufferSize},
allocOp.getParentOfType<ModuleOp>(), rewriter);
Value alignedPtr = allocatedPtr;
@@ -2159,7 +2164,7 @@ struct AlignedAllocOpLowering : public AllocLikeOpLowering {
}
std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
- Location loc, Value cumulativeSize,
+ Location loc, Value bufferSize,
Operation *op) const override {
// Heap allocations.
AllocOp allocOp = cast<AllocOp>(op);
@@ -2170,12 +2175,11 @@ struct AlignedAllocOpLowering : public AllocLikeOpLowering {
// aligned_alloc requires size to be a multiple of alignment; we will pad
// the size to the next multiple if necessary.
if (!isMemRefSizeMultipleOf(memRefType, alignment))
- cumulativeSize =
- createAligned(rewriter, loc, cumulativeSize, allocAlignment);
+ bufferSize = createAligned(rewriter, loc, bufferSize, allocAlignment);
Type elementPtrType = this->getElementPtrType(memRefType);
Value allocatedPtr = createAllocCall(
- loc, "aligned_alloc", elementPtrType, {allocAlignment, cumulativeSize},
+ loc, "aligned_alloc", elementPtrType, {allocAlignment, bufferSize},
allocOp.getParentOfType<ModuleOp>(), rewriter);
return std::make_tuple(allocatedPtr, allocatedPtr);
@@ -2196,7 +2200,7 @@ struct AllocaOpLowering : public AllocLikeOpLowering {
/// is set to null for stack allocations. `accessAlignment` is set if
/// alignment is needed post allocation (for eg. in conjunction with malloc).
std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
- Location loc, Value cumulativeSize,
+ Location loc, Value bufferSize,
Operation *op) const override {
// With alloca, one gets a pointer to the element type right away.
@@ -2205,7 +2209,7 @@ struct AllocaOpLowering : public AllocLikeOpLowering {
auto elementPtrType = this->getElementPtrType(allocaOp.getType());
auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
- loc, elementPtrType, cumulativeSize,
+ loc, elementPtrType, bufferSize,
allocaOp.alignment() ? *allocaOp.alignment() : 0);
return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
@@ -2420,6 +2424,109 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<DeallocOp> {
}
};
+/// Returns the LLVM type of the global variable given the memref type `type`.
+static LLVM::LLVMType
+convertGlobalMemrefTypeToLLVM(MemRefType type,
+ LLVMTypeConverter &typeConverter) {
+ // The LLVM type for a global memref will be a multi-dimensional array. For
+ // declarations or uninitialized global memrefs, we can potentially flatten
+ // this to a 1D array. However, for global_memrefs with an initial value,
+ // we do not intend to flatten the ElementsAttribute when going from std ->
+ // LLVM dialect, so the LLVM type needs to be a multi-dimensional array.
+ LLVM::LLVMType elementType =
+ unwrap(typeConverter.convertType(type.getElementType()));
+ LLVM::LLVMType arrayTy = elementType;
+ // Shape has the outermost dim at index 0, so we need to walk it backwards.
+ for (int64_t dim : llvm::reverse(type.getShape()))
+ arrayTy = LLVM::LLVMType::getArrayTy(arrayTy, dim);
+ return arrayTy;
+}
+
+/// GlobalMemrefOp is lowered to an LLVM global variable.
+struct GlobalMemrefOpLowering : public ConvertOpToLLVMPattern<GlobalMemrefOp> {
+ using ConvertOpToLLVMPattern<GlobalMemrefOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto global = cast<GlobalMemrefOp>(op);
+ MemRefType type = global.type().cast<MemRefType>();
+ if (!isSupportedMemRefType(type))
+ return failure();
+
+ LLVM::LLVMType arrayTy = convertGlobalMemrefTypeToLLVM(type, typeConverter);
+
+ LLVM::Linkage linkage =
+ global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
+
+ Attribute initialValue = nullptr;
+ if (!global.isExternal() && !global.isUninitialized()) {
+ auto elementsAttr = global.initial_value()->cast<ElementsAttr>();
+ initialValue = elementsAttr;
+
+ // For scalar memrefs, the global variable created is of the element type,
+ // so unpack the elements attribute to extract the value.
+ if (type.getRank() == 0)
+ initialValue = elementsAttr.getValue({});
+ }
+
+ rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
+ op, arrayTy, global.constant(), linkage, global.sym_name(),
+ initialValue, type.getMemorySpace());
+ return success();
+ }
+};
+
+/// GetGlobalMemrefOp is lowered to a memref descriptor with a pointer to the
+/// first element of the global stashed in it. It derives from
+/// `AllocLikeOpLowering` to reuse the memref descriptor construction code.
+struct GetGlobalMemrefOpLowering : public AllocLikeOpLowering {
+ GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
+ : AllocLikeOpLowering(GetGlobalMemrefOp::getOperationName(), converter) {}
+
+ /// Allocation for GetGlobalMemrefOp just returns the address of the global
+ /// variable, so there is no need to compute a buffer size.
+ bool needsBufferSize() const override { return false; }
+
+ /// Buffer "allocation" for get_global_memref op is getting the address of
+ /// the global variable referenced.
+ std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
+ Location loc, Value bufferSize,
+ Operation *op) const override {
+ auto getGlobalOp = cast<GetGlobalMemrefOp>(op);
+ MemRefType type = getGlobalOp.result().getType().cast<MemRefType>();
+ unsigned memSpace = type.getMemorySpace();
+
+ LLVM::LLVMType arrayTy = convertGlobalMemrefTypeToLLVM(type, typeConverter);
+ auto addressOf = rewriter.create<LLVM::AddressOfOp>(
+ loc, arrayTy.getPointerTo(memSpace), getGlobalOp.name());
+
+ // Get the address of the first element in the array by creating a GEP with
+ // the address of the global variable as the base and (rank + 1) zero indices.
+ LLVM::LLVMType elementType =
+ unwrap(typeConverter.convertType(type.getElementType()));
+ LLVM::LLVMType elementPtrType = elementType.getPointerTo(memSpace);
+
+ SmallVector<Value, 4> operands = {addressOf};
+ operands.insert(operands.end(), type.getRank() + 1,
+ createIndexConstant(rewriter, loc, 0));
+ auto gep = rewriter.create<LLVM::GEPOp>(loc, elementPtrType, operands);
+
+ // We do not expect the memref obtained using `get_global_memref` to ever be
+ // deallocated. Set the allocated pointer to a known bad value to help debug
+ // if that ever happens.
+ auto intPtrType = getIntPtrType(memSpace);
+ Value deadBeefConst =
+ createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
+ auto deadBeefPtr =
+ rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);
+
+ // The allocated and aligned pointers are the same. We could potentially stash
+ // a nullptr for the allocated pointer since we do not expect any dealloc.
+ return {deadBeefPtr, gep};
+ }
+};
+
// A `rsqrt` is converted into `1 / sqrt`.
struct RsqrtOpLowering : public ConvertOpToLLVMPattern<RsqrtOp> {
using ConvertOpToLLVMPattern<RsqrtOp>::ConvertOpToLLVMPattern;
@@ -3941,6 +4048,8 @@ void mlir::populateStdToLLVMMemoryConversionPatterns(
AssumeAlignmentOpLowering,
DeallocOpLowering,
DimOpLowering,
+ GlobalMemrefOpLowering,
+ GetGlobalMemrefOpLowering,
LoadOpLowering,
MemRefCastOpLowering,
MemRefReinterpretCastOpLowering,
diff --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
index 71a35f6ccf0a..0c38d02e816a 100644
--- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
@@ -131,3 +131,82 @@ func @transpose(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
%0 = transpose %arg0 (i, j, k) -> (k, i, j) : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]> to memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2 * s1 + s0 + d0 * s2 + d1)>>
return
}
+
+// -----
+
+// CHECK: llvm.mlir.global external @gv0() : !llvm.array<2 x float>
+global_memref @gv0 : memref<2xf32> = uninitialized
+
+// CHECK: llvm.mlir.global private @gv1() : !llvm.array<2 x float>
+global_memref "private" @gv1 : memref<2xf32>
+
+// CHECK: llvm.mlir.global external @gv2(dense<{{\[\[}}0.000000e+00, 1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00, 5.000000e+00]]> : tensor<2x3xf32>) : !llvm.array<2 x array<3 x float>>
+global_memref @gv2 : memref<2x3xf32> = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]>
+
+// Test 1D memref.
+// CHECK-LABEL: func @get_gv0_memref
+func @get_gv0_memref() {
+ %0 = get_global_memref @gv0 : memref<2xf32>
+ // CHECK: %[[DIM:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64
+ // CHECK: %[[ADDR:.*]] = llvm.mlir.addressof @gv0 : !llvm.ptr<array<2 x float>>
+ // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+ // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ADDR]][%[[ZERO]], %[[ZERO]]] : (!llvm.ptr<array<2 x float>>, !llvm.i64, !llvm.i64) -> !llvm.ptr<float>
+ // CHECK: %[[DEADBEEF:.*]] = llvm.mlir.constant(3735928559 : index) : !llvm.i64
+ // CHECK: %[[DEADBEEFPTR:.*]] = llvm.inttoptr %[[DEADBEEF]] : !llvm.i64 to !llvm.ptr<float>
+ // CHECK: llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: llvm.insertvalue %[[DEADBEEFPTR]], {{.*}}[0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: llvm.insertvalue %[[GEP]], {{.*}}[1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: %[[OFFSET:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+ // CHECK: llvm.insertvalue %[[OFFSET]], {{.*}}[2] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: %[[STRIDE:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+ // CHECK: llvm.insertvalue %[[DIM]], {{.*}}[3, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: llvm.insertvalue %[[STRIDE]], {{.*}}[4, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
+ return
+}
+
+// Test 2D memref.
+// CHECK-LABEL: func @get_gv2_memref
+func @get_gv2_memref() {
+ // CHECK: %[[DIM0:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64
+ // CHECK: %[[DIM1:.*]] = llvm.mlir.constant(3 : index) : !llvm.i64
+ // CHECK: %[[ADDR:.*]] = llvm.mlir.addressof @gv2 : !llvm.ptr<array<2 x array<3 x float>>>
+ // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+ // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ADDR]][%[[ZERO]], %[[ZERO]], %[[ZERO]]] : (!llvm.ptr<array<2 x array<3 x float>>>, !llvm.i64, !llvm.i64, !llvm.i64) -> !llvm.ptr<float>
+ // CHECK: %[[DEADBEEF:.*]] = llvm.mlir.constant(3735928559 : index) : !llvm.i64
+ // CHECK: %[[DEADBEEFPTR:.*]] = llvm.inttoptr %[[DEADBEEF]] : !llvm.i64 to !llvm.ptr<float>
+ // CHECK: llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: llvm.insertvalue %[[DEADBEEFPTR]], {{.*}}[0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: llvm.insertvalue %[[GEP]], {{.*}}[1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: %[[OFFSET:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+ // CHECK: llvm.insertvalue %[[OFFSET]], {{.*}}[2] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: %[[STRIDE1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+ // CHECK: %[[STRIDE0:.*]] = llvm.mlir.constant(3 : index) : !llvm.i64
+ // CHECK: llvm.insertvalue %[[DIM0]], {{.*}}[3, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: llvm.insertvalue %[[STRIDE0]], {{.*}}[4, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: llvm.insertvalue %[[DIM1]], {{.*}}[3, 1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: llvm.insertvalue %[[STRIDE1]], {{.*}}[4, 1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
+
+ %0 = get_global_memref @gv2 : memref<2x3xf32>
+ return
+}
+
+// Test scalar memref.
+// CHECK: llvm.mlir.global external @gv3(1.000000e+00 : f32) : !llvm.float
+global_memref @gv3 : memref<f32> = dense<1.0>
+
+// CHECK-LABEL: func @get_gv3_memref
+func @get_gv3_memref() {
+ // CHECK: %[[ADDR:.*]] = llvm.mlir.addressof @gv3 : !llvm.ptr<float>
+ // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+ // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ADDR]][%[[ZERO]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
+ // CHECK: %[[DEADBEEF:.*]] = llvm.mlir.constant(3735928559 : index) : !llvm.i64
+ // CHECK: %[[DEADBEEFPTR:.*]] = llvm.inttoptr %[[DEADBEEF]] : !llvm.i64 to !llvm.ptr<float>
+ // CHECK: llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64)>
+ // CHECK: llvm.insertvalue %[[DEADBEEFPTR]], {{.*}}[0] : !llvm.struct<(ptr<float>, ptr<float>, i64)>
+ // CHECK: llvm.insertvalue %[[GEP]], {{.*}}[1] : !llvm.struct<(ptr<float>, ptr<float>, i64)>
+ // CHECK: %[[OFFSET:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+ // CHECK: llvm.insertvalue %[[OFFSET]], {{.*}}[2] : !llvm.struct<(ptr<float>, ptr<float>, i64)>
+ %0 = get_global_memref @gv3 : memref<f32>
+ return
+}
+
diff --git a/mlir/test/mlir-cpu-runner/global_memref.mlir b/mlir/test/mlir-cpu-runner/global_memref.mlir
new file mode 100644
index 000000000000..1c9cf4a33aeb
--- /dev/null
+++ b/mlir/test/mlir-cpu-runner/global_memref.mlir
@@ -0,0 +1,107 @@
+// RUN: mlir-opt %s -convert-std-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext | FileCheck %s
+
+func @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface }
+func @print_memref_i32(memref<*xi32>) attributes { llvm.emit_c_interface }
+func @printNewline() -> ()
+
+global_memref "private" @gv0 : memref<4xf32> = dense<[0.0, 1.0, 2.0, 3.0]>
+func @test1DMemref() {
+ %0 = get_global_memref @gv0 : memref<4xf32>
+ %U = memref_cast %0 : memref<4xf32> to memref<*xf32>
+ // CHECK: rank = 1
+ // CHECK: offset = 0
+ // CHECK: sizes = [4]
+ // CHECK: strides = [1]
+ // CHECK: [0, 1, 2, 3]
+ call @print_memref_f32(%U) : (memref<*xf32>) -> ()
+ call @printNewline() : () -> ()
+
+ // Overwrite some of the elements.
+ %c0 = constant 0 : index
+ %c2 = constant 2 : index
+ %fp0 = constant 4.0 : f32
+ %fp1 = constant 5.0 : f32
+ store %fp0, %0[%c0] : memref<4xf32>
+ store %fp1, %0[%c2] : memref<4xf32>
+ // CHECK: rank = 1
+ // CHECK: offset = 0
+ // CHECK: sizes = [4]
+ // CHECK: strides = [1]
+ // CHECK: [4, 1, 5, 3]
+ call @print_memref_f32(%U) : (memref<*xf32>) -> ()
+ call @printNewline() : () -> ()
+ return
+}
+
+global_memref constant @gv1 : memref<3x2xi32> = dense<[[0, 1],[2, 3],[4, 5]]>
+func @testConstantMemref() {
+ %0 = get_global_memref @gv1 : memref<3x2xi32>
+ %U = memref_cast %0 : memref<3x2xi32> to memref<*xi32>
+ // CHECK: rank = 2
+ // CHECK: offset = 0
+ // CHECK: sizes = [3, 2]
+ // CHECK: strides = [2, 1]
+ // CHECK: [0, 1]
+ // CHECK: [2, 3]
+ // CHECK: [4, 5]
+ call @print_memref_i32(%U) : (memref<*xi32>) -> ()
+ call @printNewline() : () -> ()
+ return
+}
+
+global_memref "private" @gv2 : memref<4x2xf32> = dense<[[0.0, 1.0], [2.0, 3.0], [4.0, 5.0], [6.0, 7.0]]>
+func @test2DMemref() {
+ %0 = get_global_memref @gv2 : memref<4x2xf32>
+ %U = memref_cast %0 : memref<4x2xf32> to memref<*xf32>
+ // CHECK: rank = 2
+ // CHECK: offset = 0
+ // CHECK: sizes = [4, 2]
+ // CHECK: strides = [2, 1]
+ // CHECK: [0, 1]
+ // CHECK: [2, 3]
+ // CHECK: [4, 5]
+ // CHECK: [6, 7]
+ call @print_memref_f32(%U) : (memref<*xf32>) -> ()
+ call @printNewline() : () -> ()
+
+ // Overwrite the 1.0 (at index [0, 1]) with 10.0
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %fp10 = constant 10.0 : f32
+ store %fp10, %0[%c0, %c1] : memref<4x2xf32>
+ // CHECK: rank = 2
+ // CHECK: offset = 0
+ // CHECK: sizes = [4, 2]
+ // CHECK: strides = [2, 1]
+ // CHECK: [0, 10]
+ // CHECK: [2, 3]
+ // CHECK: [4, 5]
+ // CHECK: [6, 7]
+ call @print_memref_f32(%U) : (memref<*xf32>) -> ()
+ call @printNewline() : () -> ()
+ return
+}
+
+global_memref @gv3 : memref<i32> = dense<11>
+func @testScalarMemref() {
+ %0 = get_global_memref @gv3 : memref<i32>
+ %U = memref_cast %0 : memref<i32> to memref<*xi32>
+ // CHECK: rank = 0
+ // CHECK: offset = 0
+ // CHECK: sizes = []
+ // CHECK: strides = []
+ // CHECK: [11]
+ call @print_memref_i32(%U) : (memref<*xi32>) -> ()
+ call @printNewline() : () -> ()
+ return
+}
+
+func @main() -> () {
+ call @test1DMemref() : () -> ()
+ call @testConstantMemref() : () -> ()
+ call @test2DMemref() : () -> ()
+ call @testScalarMemref() : () -> ()
+ return
+}
+
+