[Mlir-commits] [mlir] d6ab047 - [mlir] Convert MemRefReinterpretCastOp to LLVM.
Alexander Belyaev
llvmlistbot at llvm.org
Mon Oct 26 12:13:49 PDT 2020
Author: Alexander Belyaev
Date: 2020-10-26T20:13:17+01:00
New Revision: d6ab0474c6efc5a614a28ed21070f11d587467f8
URL: https://github.com/llvm/llvm-project/commit/d6ab0474c6efc5a614a28ed21070f11d587467f8
DIFF: https://github.com/llvm/llvm-project/commit/d6ab0474c6efc5a614a28ed21070f11d587467f8.diff
LOG: [mlir] Convert MemRefReinterpretCastOp to LLVM.
https://llvm.discourse.group/t/rfc-standard-memref-cast-ops/1454/15
Differential Revision: https://reviews.llvm.org/D90033
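As context for the lowering below, this is the op being converted, shown on the same static-shape case the new tests exercise (a sketch, not an excerpt from the diff):

func @example(%input : memref<2x3xf32>) {
  // Reinterpret a 2x3 buffer as a 6x1 view over the same underlying storage.
  %output = memref_reinterpret_cast %input to
      offset: [0], sizes: [6, 1], strides: [1, 1]
      : memref<2x3xf32> to memref<6x1xf32>
  return
}

The new pattern rebuilds the LLVM memref descriptor for the result type directly from the op's offset, size, and stride operands and attributes.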
Added:
mlir/test/mlir-cpu-runner/memref_reinterpret_cast.mlir
Modified:
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 3fe60f5e88d4..5a6c50f2a549 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -2416,6 +2416,114 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> {
}
};
+struct MemRefReinterpretCastOpLowering
+ : public ConvertOpToLLVMPattern<MemRefReinterpretCastOp> {
+ using ConvertOpToLLVMPattern<MemRefReinterpretCastOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto castOp = cast<MemRefReinterpretCastOp>(op);
+ MemRefReinterpretCastOp::Adaptor adaptor(operands, op->getAttrDictionary());
+ Type srcType = castOp.source().getType();
+
+ Value descriptor;
+ if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
+ adaptor, &descriptor)))
+ return failure();
+ rewriter.replaceOp(op, {descriptor});
+ return success();
+ }
+
+private:
+ LogicalResult
+ convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
+ Type srcType, MemRefReinterpretCastOp castOp,
+ MemRefReinterpretCastOp::Adaptor adaptor,
+ Value *descriptor) const {
+ MemRefType targetMemRefType =
+ castOp.getResult().getType().cast<MemRefType>();
+ auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
+ .dyn_cast_or_null<LLVM::LLVMType>();
+ if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
+ return failure();
+
+ // Create descriptor.
+ Location loc = castOp.getLoc();
+ auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
+
+ // Set allocated and aligned pointers.
+ Value allocatedPtr, alignedPtr;
+ extractPointers(loc, rewriter, castOp.source(), adaptor.source(),
+ &allocatedPtr, &alignedPtr);
+ desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
+ desc.setAlignedPtr(rewriter, loc, alignedPtr);
+
+ // Set offset.
+ if (castOp.isDynamicOffset(0))
+ desc.setOffset(rewriter, loc, adaptor.offsets()[0]);
+ else
+ desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
+
+ // Set sizes and strides.
+ unsigned dynSizeId = 0;
+ unsigned dynStrideId = 0;
+ for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
+ if (castOp.isDynamicSize(i))
+ desc.setSize(rewriter, loc, i, adaptor.sizes()[dynSizeId++]);
+ else
+ desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
+
+ if (castOp.isDynamicStride(i))
+ desc.setStride(rewriter, loc, i, adaptor.strides()[dynStrideId++]);
+ else
+ desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
+ }
+ *descriptor = desc;
+ return success();
+ }
+
+ void extractPointers(Location loc, ConversionPatternRewriter &rewriter,
+ Value originalOperand, Value convertedOperand,
+ Value *allocatedPtr, Value *alignedPtr) const {
+ Type operandType = originalOperand.getType();
+ if (operandType.isa<MemRefType>()) {
+ MemRefDescriptor desc(convertedOperand);
+ *allocatedPtr = desc.allocatedPtr(rewriter, loc);
+ *alignedPtr = desc.alignedPtr(rewriter, loc);
+ return;
+ }
+
+ unsigned memorySpace =
+ operandType.cast<UnrankedMemRefType>().getMemorySpace();
+ Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
+ LLVM::LLVMType llvmElementType =
+ typeConverter.convertType(elementType).cast<LLVM::LLVMType>();
+ LLVM::LLVMType elementPtrPtrType =
+ llvmElementType.getPointerTo(memorySpace).getPointerTo();
+
+ // Extract pointer to the underlying ranked memref descriptor and cast it to
+ // ElemType**.
+ UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
+ Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
+ Value elementPtrPtr = rewriter.create<LLVM::BitcastOp>(
+ loc, elementPtrPtrType, underlyingDescPtr);
+
+ LLVM::LLVMType int32Type =
+ typeConverter.convertType(rewriter.getI32Type()).cast<LLVM::LLVMType>();
+
+ // Extract and set allocated pointer.
+ *allocatedPtr = rewriter.create<LLVM::LoadOp>(loc, elementPtrPtr);
+
+ // Extract and set aligned pointer.
+ Value one = rewriter.create<LLVM::ConstantOp>(
+ loc, int32Type, rewriter.getI32IntegerAttr(1));
+ Value alignedGep = rewriter.create<LLVM::GEPOp>(
+ loc, elementPtrPtrType, elementPtrPtr, ValueRange({one}));
+ *alignedPtr = rewriter.create<LLVM::LoadOp>(loc, alignedGep);
+ }
+};
+
struct DialectCastOpLowering
: public ConvertOpToLLVMPattern<LLVM::DialectCastOp> {
using ConvertOpToLLVMPattern<LLVM::DialectCastOp>::ConvertOpToLLVMPattern;
@@ -3532,6 +3640,7 @@ void mlir::populateStdToLLVMMemoryConversionPatterns(
DimOpLowering,
LoadOpLowering,
MemRefCastOpLowering,
+ MemRefReinterpretCastOpLowering,
RankOpLowering,
StoreOpLowering,
SubViewOpLowering,
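With the pattern registered above, the op lowers to plain descriptor construction. A hedged sketch of the target descriptor type for the 2-D f32 results in the tests below, using this commit's LLVM-dialect type spelling (fields: allocated pointer, aligned pointer, offset, sizes array, strides array; the exact printed form may differ):

// Assumed layout of the lowered rank-2 f32 memref descriptor:
!llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>

The insertvalue indices checked below ([0], [1], [2], [3, i], [4, i]) address exactly these five fields.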
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
index 8e7b22574432..8447474484e2 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
@@ -432,3 +432,60 @@ func @memref_dim_with_dyn_index(%arg : memref<3x?xf32>, %idx : index) -> index {
%result = dim %arg, %idx : memref<3x?xf32>
return %result : index
}
+
+// CHECK-LABEL: @memref_reinterpret_cast_ranked_to_static_shape
+func @memref_reinterpret_cast_ranked_to_static_shape(%input : memref<2x3xf32>) {
+ %output = memref_reinterpret_cast %input to
+ offset: [0], sizes: [6, 1], strides: [1, 1]
+ : memref<2x3xf32> to memref<6x1xf32>
+ return
+}
+// CHECK: [[INPUT:%.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : [[TY:!.*]]
+// CHECK: [[OUT_0:%.*]] = llvm.mlir.undef : [[TY]]
+// CHECK: [[BASE_PTR:%.*]] = llvm.extractvalue [[INPUT]][0] : [[TY]]
+// CHECK: [[ALIGNED_PTR:%.*]] = llvm.extractvalue [[INPUT]][1] : [[TY]]
+// CHECK: [[OUT_1:%.*]] = llvm.insertvalue [[BASE_PTR]], [[OUT_0]][0] : [[TY]]
+// CHECK: [[OUT_2:%.*]] = llvm.insertvalue [[ALIGNED_PTR]], [[OUT_1]][1] : [[TY]]
+// CHECK: [[OFFSET:%.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK: [[OUT_3:%.*]] = llvm.insertvalue [[OFFSET]], [[OUT_2]][2] : [[TY]]
+// CHECK: [[SIZE_0:%.*]] = llvm.mlir.constant(6 : index) : !llvm.i64
+// CHECK: [[OUT_4:%.*]] = llvm.insertvalue [[SIZE_0]], [[OUT_3]][3, 0] : [[TY]]
+// CHECK: [[SIZE_1:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: [[OUT_5:%.*]] = llvm.insertvalue [[SIZE_1]], [[OUT_4]][4, 0] : [[TY]]
+// CHECK: [[STRIDE_0:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: [[OUT_6:%.*]] = llvm.insertvalue [[STRIDE_0]], [[OUT_5]][3, 1] : [[TY]]
+// CHECK: [[STRIDE_1:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: [[OUT_7:%.*]] = llvm.insertvalue [[STRIDE_1]], [[OUT_6]][4, 1] : [[TY]]
+
+// CHECK-LABEL: @memref_reinterpret_cast_unranked_to_dynamic_shape
+func @memref_reinterpret_cast_unranked_to_dynamic_shape(%offset: index,
+ %size_0 : index,
+ %size_1 : index,
+ %stride_0 : index,
+ %stride_1 : index,
+ %input : memref<*xf32>) {
+ %output = memref_reinterpret_cast %input to
+ offset: [%offset], sizes: [%size_0, %size_1],
+ strides: [%stride_0, %stride_1]
+ : memref<*xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+ return
+}
+// CHECK-SAME: ([[OFFSET:%[a-z,0-9]+]]: !llvm.i64,
+// CHECK-SAME: [[SIZE_0:%[a-z,0-9]+]]: !llvm.i64, [[SIZE_1:%[a-z,0-9]+]]: !llvm.i64,
+// CHECK-SAME: [[STRIDE_0:%[a-z,0-9]+]]: !llvm.i64, [[STRIDE_1:%[a-z,0-9]+]]: !llvm.i64,
+// CHECK: [[INPUT:%.*]] = llvm.insertvalue {{.*}}[1] : !llvm.struct<(i64, ptr<i8>)>
+// CHECK: [[OUT_0:%.*]] = llvm.mlir.undef : [[TY:!.*]]
+// CHECK: [[DESCRIPTOR:%.*]] = llvm.extractvalue [[INPUT]][1] : !llvm.struct<(i64, ptr<i8>)>
+// CHECK: [[BASE_PTR_PTR:%.*]] = llvm.bitcast [[DESCRIPTOR]] : !llvm.ptr<i8> to !llvm.ptr<ptr<float>>
+// CHECK: [[BASE_PTR:%.*]] = llvm.load [[BASE_PTR_PTR]] : !llvm.ptr<ptr<float>>
+// CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
+// CHECK: [[ALIGNED_PTR_PTR:%.*]] = llvm.getelementptr [[BASE_PTR_PTR]]{{\[}}[[C1]]]
+// CHECK-SAME: : (!llvm.ptr<ptr<float>>, !llvm.i32) -> !llvm.ptr<ptr<float>>
+// CHECK: [[ALIGNED_PTR:%.*]] = llvm.load [[ALIGNED_PTR_PTR]] : !llvm.ptr<ptr<float>>
+// CHECK: [[OUT_1:%.*]] = llvm.insertvalue [[BASE_PTR]], [[OUT_0]][0] : [[TY]]
+// CHECK: [[OUT_2:%.*]] = llvm.insertvalue [[ALIGNED_PTR]], [[OUT_1]][1] : [[TY]]
+// CHECK: [[OUT_3:%.*]] = llvm.insertvalue [[OFFSET]], [[OUT_2]][2] : [[TY]]
+// CHECK: [[OUT_4:%.*]] = llvm.insertvalue [[SIZE_0]], [[OUT_3]][3, 0] : [[TY]]
+// CHECK: [[OUT_5:%.*]] = llvm.insertvalue [[STRIDE_0]], [[OUT_4]][4, 0] : [[TY]]
+// CHECK: [[OUT_6:%.*]] = llvm.insertvalue [[SIZE_1]], [[OUT_5]][3, 1] : [[TY]]
+// CHECK: [[OUT_7:%.*]] = llvm.insertvalue [[STRIDE_1]], [[OUT_6]][4, 1] : [[TY]]
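The unranked path in the checks above reads both pointers out of the type-erased descriptor: it bitcasts the opaque ptr<i8> payload to ptr<ptr<float>> and loads fields 0 and 1. A standalone sketch of that access pattern, with hypothetical value names:

// Cast the opaque descriptor pointer so the first two fields can be loaded.
%base_ptr_ptr = llvm.bitcast %underlying_desc : !llvm.ptr<i8> to !llvm.ptr<ptr<float>>
// Field 0: allocated (base) pointer.
%base_ptr = llvm.load %base_ptr_ptr : !llvm.ptr<ptr<float>>
%c1 = llvm.mlir.constant(1 : i32) : !llvm.i32
// Field 1: aligned pointer, one element past the base pointer slot.
%aligned_ptr_ptr = llvm.getelementptr %base_ptr_ptr[%c1] : (!llvm.ptr<ptr<float>>, !llvm.i32) -> !llvm.ptr<ptr<float>>
%aligned_ptr = llvm.load %aligned_ptr_ptr : !llvm.ptr<ptr<float>>

This works because the allocated and aligned pointers are the first two fields of every ranked memref descriptor, regardless of rank.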
diff --git a/mlir/test/mlir-cpu-runner/memref_reinterpret_cast.mlir b/mlir/test/mlir-cpu-runner/memref_reinterpret_cast.mlir
new file mode 100644
index 000000000000..4f933a7784cb
--- /dev/null
+++ b/mlir/test/mlir-cpu-runner/memref_reinterpret_cast.mlir
@@ -0,0 +1,105 @@
+// RUN: mlir-opt %s -convert-scf-to-std -convert-std-to-llvm \
+// RUN: | mlir-cpu-runner -e main -entry-point-result=void \
+// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+func @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface }
+
+func @main() -> () {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+
+ // Initialize input.
+ %input = alloc() : memref<2x3xf32>
+ %dim_x = dim %input, %c0 : memref<2x3xf32>
+ %dim_y = dim %input, %c1 : memref<2x3xf32>
+ scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) {
+ %prod = muli %i, %dim_y : index
+ %val = addi %prod, %j : index
+ %val_i64 = index_cast %val : index to i64
+ %val_f32 = sitofp %val_i64 : i64 to f32
+ store %val_f32, %input[%i, %j] : memref<2x3xf32>
+ }
+ %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
+ call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> ()
+ // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
+ // CHECK-NEXT: [0, 1, 2]
+ // CHECK-NEXT: [3, 4, 5]
+
+ // Test cases.
+ call @cast_ranked_memref_to_static_shape(%input) : (memref<2x3xf32>) -> ()
+ call @cast_ranked_memref_to_dynamic_shape(%input) : (memref<2x3xf32>) -> ()
+ call @cast_unranked_memref_to_static_shape(%input) : (memref<2x3xf32>) -> ()
+ call @cast_unranked_memref_to_dynamic_shape(%input) : (memref<2x3xf32>) -> ()
+ return
+}
+
+func @cast_ranked_memref_to_static_shape(%input : memref<2x3xf32>) {
+ %output = memref_reinterpret_cast %input to
+ offset: [0], sizes: [6, 1], strides: [1, 1]
+ : memref<2x3xf32> to memref<6x1xf32>
+
+ %unranked_output = memref_cast %output
+ : memref<6x1xf32> to memref<*xf32>
+ call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+ // CHECK: rank = 2 offset = 0 sizes = [6, 1] strides = [1, 1] data =
+ // CHECK-NEXT: [0],
+ // CHECK-NEXT: [1],
+ // CHECK-NEXT: [2],
+ // CHECK-NEXT: [3],
+ // CHECK-NEXT: [4],
+ // CHECK-NEXT: [5]
+ return
+}
+
+func @cast_ranked_memref_to_dynamic_shape(%input : memref<2x3xf32>) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c6 = constant 6 : index
+ %output = memref_reinterpret_cast %input to
+ offset: [%c0], sizes: [%c1, %c6], strides: [%c6, %c1]
+ : memref<2x3xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+
+ %unranked_output = memref_cast %output
+ : memref<?x?xf32, offset: ?, strides: [?, ?]> to memref<*xf32>
+ call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+ // CHECK: rank = 2 offset = 0 sizes = [1, 6] strides = [6, 1] data =
+ // CHECK-NEXT: [0, 1, 2, 3, 4, 5]
+ return
+}
+
+func @cast_unranked_memref_to_static_shape(%input : memref<2x3xf32>) {
+ %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
+ %output = memref_reinterpret_cast %unranked_input to
+ offset: [0], sizes: [6, 1], strides: [1, 1]
+ : memref<*xf32> to memref<6x1xf32>
+
+ %unranked_output = memref_cast %output
+ : memref<6x1xf32> to memref<*xf32>
+ call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+ // CHECK: rank = 2 offset = 0 sizes = [6, 1] strides = [1, 1] data =
+ // CHECK-NEXT: [0],
+ // CHECK-NEXT: [1],
+ // CHECK-NEXT: [2],
+ // CHECK-NEXT: [3],
+ // CHECK-NEXT: [4],
+ // CHECK-NEXT: [5]
+ return
+}
+
+func @cast_unranked_memref_to_dynamic_shape(%input : memref<2x3xf32>) {
+ %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c6 = constant 6 : index
+ %output = memref_reinterpret_cast %unranked_input to
+ offset: [%c0], sizes: [%c1, %c6], strides: [%c6, %c1]
+ : memref<*xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+
+ %unranked_output = memref_cast %output
+ : memref<?x?xf32, offset: ?, strides: [?, ?]> to memref<*xf32>
+ call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+ // CHECK: rank = 2 offset = 0 sizes = [1, 6] strides = [6, 1] data =
+ // CHECK-NEXT: [0, 1, 2, 3, 4, 5]
+ return
+}
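The op is not limited to rank-preserving casts; a hedged extra example (not part of the commit's tests) flattening to rank 1 with the same descriptor machinery:

func @flatten(%input : memref<2x3xf32>) -> memref<6xf32> {
  // The same 6 contiguous elements, now viewed as a rank-1 memref.
  %flat = memref_reinterpret_cast %input to
      offset: [0], sizes: [6], strides: [1]
      : memref<2x3xf32> to memref<6xf32>
  return %flat : memref<6xf32>
}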