[Mlir-commits] [mlir] 3effc35 - [mlir] Lower DimOp to LLVM for unranked memrefs.
Alexander Belyaev
llvmlistbot at llvm.org
Thu Aug 6 02:46:27 PDT 2020
Author: Alexander Belyaev
Date: 2020-08-06T11:46:11+02:00
New Revision: 3effc35015ce0163579943ba1c86c35d0a3ba752
URL: https://github.com/llvm/llvm-project/commit/3effc35015ce0163579943ba1c86c35d0a3ba752
DIFF: https://github.com/llvm/llvm-project/commit/3effc35015ce0163579943ba1c86c35d0a3ba752.diff
LOG: [mlir] Lower DimOp to LLVM for unranked memrefs.
Differential Revision: https://reviews.llvm.org/D85361
Added:
Modified:
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
mlir/test/mlir-cpu-runner/unranked_memref.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index a5ecbe4381de..b57a1607bc97 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -2370,30 +2370,85 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<DimOp> {
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto dimOp = cast<DimOp>(op);
- auto loc = op->getLoc();
+ Type operandType = dimOp.memrefOrTensor().getType();
+ if (operandType.isa<UnrankedMemRefType>()) {
+ rewriter.replaceOp(op, {extractSizeOfUnrankedMemRef(operandType, dimOp,
+ operands, rewriter)});
+
+ return success();
+ }
+ if (operandType.isa<MemRefType>()) {
+ rewriter.replaceOp(op, {extractSizeOfRankedMemRef(operandType, dimOp,
+ operands, rewriter)});
+ return success();
+ }
+ return failure();
+ }
+
+private:
+ Value extractSizeOfUnrankedMemRef(Type operandType, DimOp dimOp,
+ ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ Location loc = dimOp.getLoc();
DimOp::Adaptor transformed(operands);
+ auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
+ auto scalarMemRefType =
+ MemRefType::get({}, unrankedMemRefType.getElementType());
+ unsigned addressSpace = unrankedMemRefType.getMemorySpace();
+
+ // Extract pointer to the underlying ranked descriptor and bitcast it to a
+ // memref<element_type> descriptor pointer to minimize the number of GEP
+ // operations.
+ UnrankedMemRefDescriptor unrankedDesc(transformed.memrefOrTensor());
+ Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
+ Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
+ loc,
+ typeConverter.convertType(scalarMemRefType)
+ .cast<LLVM::LLVMType>()
+ .getPointerTo(addressSpace),
+ underlyingRankedDesc);
+
+ // Get pointer to offset field of memref<element_type> descriptor.
+ Type indexPtrTy = typeConverter.getIndexType().getPointerTo(addressSpace);
+ Value two = rewriter.create<LLVM::ConstantOp>(
+ loc, typeConverter.convertType(rewriter.getI32Type()),
+ rewriter.getI32IntegerAttr(2));
+ Value offsetPtr = rewriter.create<LLVM::GEPOp>(
+ loc, indexPtrTy, scalarMemRefDescPtr,
+ ValueRange({createIndexConstant(rewriter, loc, 0), two}));
+
+ // The size value that we have to extract can be obtained using GEPop with
+ // `dimOp.index() + 1` index argument.
+ Value idxPlusOne = rewriter.create<LLVM::AddOp>(
+ loc, createIndexConstant(rewriter, loc, 1), transformed.index());
+ Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
+ ValueRange({idxPlusOne}));
+ return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
+ }
+
+ Value extractSizeOfRankedMemRef(Type operandType, DimOp dimOp,
+ ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ Location loc = dimOp.getLoc();
+ DimOp::Adaptor transformed(operands);
// Take advantage if index is constant.
- MemRefType memRefType = dimOp.memrefOrTensor().getType().cast<MemRefType>();
+ MemRefType memRefType = operandType.cast<MemRefType>();
if (Optional<int64_t> index = dimOp.getConstantIndex()) {
int64_t i = index.getValue();
if (memRefType.isDynamicDim(i)) {
- // Extract dynamic size from the memref descriptor.
+ // extract dynamic size from the memref descriptor.
MemRefDescriptor descriptor(transformed.memrefOrTensor());
- rewriter.replaceOp(op, {descriptor.size(rewriter, loc, i)});
- } else {
- // Use constant for static size.
- int64_t dimSize = memRefType.getDimSize(i);
- rewriter.replaceOp(op, createIndexConstant(rewriter, loc, dimSize));
+ return descriptor.size(rewriter, loc, i);
}
- return success();
+ // Use constant for static size.
+ int64_t dimSize = memRefType.getDimSize(i);
+ return createIndexConstant(rewriter, loc, dimSize);
}
-
Value index = dimOp.index();
int64_t rank = memRefType.getRank();
MemRefDescriptor memrefDescriptor(transformed.memrefOrTensor());
- rewriter.replaceOp(op, {memrefDescriptor.size(rewriter, loc, index, rank)});
- return success();
+ return memrefDescriptor.size(rewriter, loc, index, rank);
}
};
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index a78e2427b2fe..d084620f3a03 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1326,7 +1326,7 @@ static LogicalResult verify(DimOp op) {
} else if (auto memrefType = type.dyn_cast<MemRefType>()) {
if (index.getValue() >= memrefType.getRank())
return op.emitOpError("index is out of range");
- } else if (type.isa<UnrankedTensorType>()) {
+ } else if (type.isa<UnrankedTensorType>() || type.isa<UnrankedMemRefType>()) {
// Assume index to be in range.
} else {
llvm_unreachable("expected operand with tensor or memref type");
@@ -1342,9 +1342,13 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
if (!index)
return {};
- // Fold if the shape extent along the given index is known.
auto argTy = memrefOrTensor().getType();
+ // Fold if the shape extent along the given index is known.
if (auto shapedTy = argTy.dyn_cast<ShapedType>()) {
+ // Folding for unranked types (UnrankedMemRefType, UnrankedTensorType) is
+ // not supported.
+ if (!shapedTy.hasRank())
+ return {};
if (!shapedTy.isDynamicDim(index.getInt())) {
Builder builder(getContext());
return builder.getIndexAttr(shapedTy.getShape()[index.getInt()]);
@@ -1357,7 +1361,7 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
return {};
// The size at the given index is now known to be a dynamic size of a memref.
- auto memref = memrefOrTensor().getDefiningOp();
+ auto *memref = memrefOrTensor().getDefiningOp();
unsigned unsignedIndex = index.getValue().getZExtValue();
if (auto alloc = dyn_cast_or_null<AllocOp>(memref))
return *(alloc.getDynamicSizes().begin() +
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
index d5d966fe1115..e7935bc165f9 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -1336,3 +1336,42 @@ func @rank_of_ranked(%ranked: memref<?xi32>) {
}
// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK32: llvm.mlir.constant(1 : index) : !llvm.i32
+
+// -----
+
+// CHECK-LABEL: func @dim_of_unranked
+// CHECK32-LABEL: func @dim_of_unranked
+func @dim_of_unranked(%unranked: memref<*xi32>) -> index {
+ %c0 = constant 0 : index
+ %dim = dim %unranked, %c0 : memref<*xi32>
+ return %dim : index
+}
+// CHECK-NEXT: llvm.mlir.undef : !llvm.struct<(i64, ptr<i8>)>
+// CHECK-NEXT: llvm.insertvalue
+// CHECK-NEXT: %[[UNRANKED_DESC:.*]] = llvm.insertvalue
+// CHECK-NEXT: %[[C0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+
+// CHECK-NEXT: %[[RANKED_DESC:.*]] = llvm.extractvalue %[[UNRANKED_DESC]][1]
+// CHECK-SAME: : !llvm.struct<(i64, ptr<i8>)>
+
+// CHECK-NEXT: %[[ZERO_D_DESC:.*]] = llvm.bitcast %[[RANKED_DESC]]
+// CHECK-SAME: : !llvm.ptr<i8> to !llvm.ptr<struct<(ptr<i32>, ptr<i32>, i64)>>
+
+// CHECK-NEXT: %[[C2_i32:.*]] = llvm.mlir.constant(2 : i32) : !llvm.i32
+// CHECK-NEXT: %[[C0_:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+
+// CHECK-NEXT: %[[OFFSET_PTR:.*]] = llvm.getelementptr %[[ZERO_D_DESC]]{{\[}}
+// CHECK-SAME: %[[C0_]], %[[C2_i32]]] : (!llvm.ptr<struct<(ptr<i32>, ptr<i32>,
+// CHECK-SAME: i64)>>, !llvm.i64, !llvm.i32) -> !llvm.ptr<i64>
+
+// CHECK-NEXT: %[[C1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK-NEXT: %[[INDEX_INC:.*]] = llvm.add %[[C1]], %[[C0]] : !llvm.i64
+
+// CHECK-NEXT: %[[SIZE_PTR:.*]] = llvm.getelementptr %[[OFFSET_PTR]]{{\[}}
+// CHECK-SAME: %[[INDEX_INC]]] : (!llvm.ptr<i64>, !llvm.i64) -> !llvm.ptr<i64>
+
+// CHECK-NEXT: %[[SIZE:.*]] = llvm.load %[[SIZE_PTR]] : !llvm.ptr<i64>
+// CHECK-NEXT: llvm.return %[[SIZE]] : !llvm.i64
+
+// CHECK32: %[[SIZE:.*]] = llvm.load %{{.*}} : !llvm.ptr<i32>
+// CHECK32-NEXT: llvm.return %[[SIZE]] : !llvm.i32
diff --git a/mlir/test/mlir-cpu-runner/unranked_memref.mlir b/mlir/test/mlir-cpu-runner/unranked_memref.mlir
index df760f593db2..4e0c58f71340 100644
--- a/mlir/test/mlir-cpu-runner/unranked_memref.mlir
+++ b/mlir/test/mlir-cpu-runner/unranked_memref.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_test_cblas%shlibext,%linalg_test_lib_dir/libmlir_test_cblas_interface%shlibext | FileCheck %s
+// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext | FileCheck %s
// CHECK: rank = 2
// CHECK-SAME: sizes = [10, 3]
@@ -66,6 +66,7 @@ func @main() -> () {
call @return_var_memref_caller() : () -> ()
call @return_two_var_memref_caller() : () -> ()
+ call @dim_op_of_unranked() : () -> ()
return
}
@@ -100,3 +101,25 @@ func @return_var_memref(%arg0: memref<4x3xf32>) -> memref<*xf32> {
%0 = memref_cast %arg0: memref<4x3xf32> to memref<*xf32>
return %0 : memref<*xf32>
}
+
+func @print_i64(index) -> ()
+func @print_newline() -> ()
+
+func @dim_op_of_unranked() {
+ %ranked = alloc() : memref<4x3xf32>
+ %unranked = memref_cast %ranked: memref<4x3xf32> to memref<*xf32>
+
+ %c0 = constant 0 : index
+ %dim_0 = dim %unranked, %c0 : memref<*xf32>
+ call @print_i64(%dim_0) : (index) -> ()
+ call @print_newline() : () -> ()
+ // CHECK: 4
+
+ %c1 = constant 1 : index
+ %dim_1 = dim %unranked, %c1 : memref<*xf32>
+ call @print_i64(%dim_1) : (index) -> ()
+ call @print_newline() : () -> ()
+ // CHECK: 3
+
+ return
+}
More information about the Mlir-commits mailing list.