[Mlir-commits] [mlir] 3effc35 - [mlir] Lower DimOp to LLVM for unranked memrefs.

Alexander Belyaev llvmlistbot at llvm.org
Thu Aug 6 02:46:27 PDT 2020


Author: Alexander Belyaev
Date: 2020-08-06T11:46:11+02:00
New Revision: 3effc35015ce0163579943ba1c86c35d0a3ba752

URL: https://github.com/llvm/llvm-project/commit/3effc35015ce0163579943ba1c86c35d0a3ba752
DIFF: https://github.com/llvm/llvm-project/commit/3effc35015ce0163579943ba1c86c35d0a3ba752.diff

LOG: [mlir] Lower DimOp to LLVM for unranked memrefs.

Differential Revision: https://reviews.llvm.org/D85361
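
For context, the new pattern lowers IR such as the following (taken from the
test added in this commit):

    func @dim_of_unranked(%unranked: memref<*xi32>) -> index {
      %c0 = constant 0 : index
      %dim = dim %unranked, %c0 : memref<*xi32>
      return %dim : index
    }

The lowering bitcasts the descriptor pointer carried by the unranked memref to
a rank-0 memref descriptor pointer, GEPs to its offset field, and loads the
requested size at `index + 1` past that field, since the sizes of the
underlying ranked descriptor immediately follow its offset.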

Added: 
    

Modified: 
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
    mlir/test/mlir-cpu-runner/unranked_memref.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index a5ecbe4381de..b57a1607bc97 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -2370,30 +2370,85 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<DimOp> {
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     auto dimOp = cast<DimOp>(op);
-    auto loc = op->getLoc();
+    Type operandType = dimOp.memrefOrTensor().getType();
+    if (operandType.isa<UnrankedMemRefType>()) {
+      rewriter.replaceOp(op, {extractSizeOfUnrankedMemRef(operandType, dimOp,
+                                                          operands, rewriter)});
+
+      return success();
+    }
+    if (operandType.isa<MemRefType>()) {
+      rewriter.replaceOp(op, {extractSizeOfRankedMemRef(operandType, dimOp,
+                                                        operands, rewriter)});
+      return success();
+    }
+    return failure();
+  }
+
+private:
+  Value extractSizeOfUnrankedMemRef(Type operandType, DimOp dimOp,
+                                    ArrayRef<Value> operands,
+                                    ConversionPatternRewriter &rewriter) const {
+    Location loc = dimOp.getLoc();
     DimOp::Adaptor transformed(operands);
 
+    auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
+    auto scalarMemRefType =
+        MemRefType::get({}, unrankedMemRefType.getElementType());
+    unsigned addressSpace = unrankedMemRefType.getMemorySpace();
+
+    // Extract pointer to the underlying ranked descriptor and bitcast it to a
+    // memref<element_type> descriptor pointer to minimize the number of GEP
+    // operations.
+    UnrankedMemRefDescriptor unrankedDesc(transformed.memrefOrTensor());
+    Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
+    Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
+        loc,
+        typeConverter.convertType(scalarMemRefType)
+            .cast<LLVM::LLVMType>()
+            .getPointerTo(addressSpace),
+        underlyingRankedDesc);
+
+    // Get pointer to offset field of memref<element_type> descriptor.
+    Type indexPtrTy = typeConverter.getIndexType().getPointerTo(addressSpace);
+    Value two = rewriter.create<LLVM::ConstantOp>(
+        loc, typeConverter.convertType(rewriter.getI32Type()),
+        rewriter.getI32IntegerAttr(2));
+    Value offsetPtr = rewriter.create<LLVM::GEPOp>(
+        loc, indexPtrTy, scalarMemRefDescPtr,
+        ValueRange({createIndexConstant(rewriter, loc, 0), two}));
+
+    // The size value that we have to extract can be obtained using GEPop with
+    // `dimOp.index() + 1` index argument.
+    Value idxPlusOne = rewriter.create<LLVM::AddOp>(
+        loc, createIndexConstant(rewriter, loc, 1), transformed.index());
+    Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
+                                                 ValueRange({idxPlusOne}));
+    return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
+  }
+
+  Value extractSizeOfRankedMemRef(Type operandType, DimOp dimOp,
+                                  ArrayRef<Value> operands,
+                                  ConversionPatternRewriter &rewriter) const {
+    Location loc = dimOp.getLoc();
+    DimOp::Adaptor transformed(operands);
     // Take advantage if index is constant.
-    MemRefType memRefType = dimOp.memrefOrTensor().getType().cast<MemRefType>();
+    MemRefType memRefType = operandType.cast<MemRefType>();
     if (Optional<int64_t> index = dimOp.getConstantIndex()) {
       int64_t i = index.getValue();
       if (memRefType.isDynamicDim(i)) {
-        // Extract dynamic size from the memref descriptor.
+        // extract dynamic size from the memref descriptor.
         MemRefDescriptor descriptor(transformed.memrefOrTensor());
-        rewriter.replaceOp(op, {descriptor.size(rewriter, loc, i)});
-      } else {
-        // Use constant for static size.
-        int64_t dimSize = memRefType.getDimSize(i);
-        rewriter.replaceOp(op, createIndexConstant(rewriter, loc, dimSize));
+        return descriptor.size(rewriter, loc, i);
       }
-      return success();
+      // Use constant for static size.
+      int64_t dimSize = memRefType.getDimSize(i);
+      return createIndexConstant(rewriter, loc, dimSize);
     }
-
     Value index = dimOp.index();
     int64_t rank = memRefType.getRank();
     MemRefDescriptor memrefDescriptor(transformed.memrefOrTensor());
-    rewriter.replaceOp(op, {memrefDescriptor.size(rewriter, loc, index, rank)});
-    return success();
+    return memrefDescriptor.size(rewriter, loc, index, rank);
   }
 };
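
The GEP arithmetic in `extractSizeOfUnrankedMemRef` relies on the standard
memref descriptor layout, where the dynamic sizes immediately follow the
offset field. A rough sketch of the types involved, mirroring the i32 case
from the convert-to-llvmir.mlir test below:

    // Unranked descriptor: rank plus a type-erased pointer to the ranked
    // descriptor.
    !llvm.struct<(i64, ptr<i8>)>
    // Rank-0 view obtained by the bitcast: allocated pointer, aligned
    // pointer, offset. The sizes (and strides) of the real ranked descriptor
    // live directly past the offset, so a GEP to field 2 followed by an
    // index of `dimOp.index() + 1` addresses size[index].
    !llvm.ptr<struct<(ptr<i32>, ptr<i32>, i64)>>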
 

diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index a78e2427b2fe..d084620f3a03 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1326,7 +1326,7 @@ static LogicalResult verify(DimOp op) {
   } else if (auto memrefType = type.dyn_cast<MemRefType>()) {
     if (index.getValue() >= memrefType.getRank())
       return op.emitOpError("index is out of range");
-  } else if (type.isa<UnrankedTensorType>()) {
+  } else if (type.isa<UnrankedTensorType>() || type.isa<UnrankedMemRefType>()) {
     // Assume index to be in range.
   } else {
     llvm_unreachable("expected operand with tensor or memref type");
@@ -1342,9 +1342,13 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
   if (!index)
     return {};
 
-  // Fold if the shape extent along the given index is known.
   auto argTy = memrefOrTensor().getType();
+  // Fold if the shape extent along the given index is known.
   if (auto shapedTy = argTy.dyn_cast<ShapedType>()) {
+    // Folding for unranked types (UnrankedMemRefType, UnrankedTensorType) is
+    // not supported.
+    if (!shapedTy.hasRank())
+      return {};
     if (!shapedTy.isDynamicDim(index.getInt())) {
       Builder builder(getContext());
       return builder.getIndexAttr(shapedTy.getShape()[index.getInt()]);
@@ -1357,7 +1361,7 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
     return {};
 
   // The size at the given index is now known to be a dynamic size of a memref.
-  auto memref = memrefOrTensor().getDefiningOp();
+  auto *memref = memrefOrTensor().getDefiningOp();
   unsigned unsignedIndex = index.getValue().getZExtValue();
   if (auto alloc = dyn_cast_or_null<AllocOp>(memref))
     return *(alloc.getDynamicSizes().begin() +
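
With the verifier and folder changes above, `dim` of an unranked memref is now
accepted, and its folder bails out instead of trying to read a static shape.
An illustrative snippet (not part of this commit's tests):

    func @dim_of_unranked_cast(%ranked: memref<4x3xf32>) -> index {
      %unranked = memref_cast %ranked : memref<4x3xf32> to memref<*xf32>
      %c0 = constant 0 : index
      // Not constant-folded: the operand type has no rank, so DimOp::fold()
      // returns nothing and the LLVM lowering handles the query at runtime.
      %dim = dim %unranked, %c0 : memref<*xf32>
      return %dim : index
    }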

diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
index d5d966fe1115..e7935bc165f9 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -1336,3 +1336,42 @@ func @rank_of_ranked(%ranked: memref<?xi32>) {
 }
 // CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
 // CHECK32: llvm.mlir.constant(1 : index) : !llvm.i32
+
+// -----
+
+// CHECK-LABEL: func @dim_of_unranked
+// CHECK32-LABEL: func @dim_of_unranked
+func @dim_of_unranked(%unranked: memref<*xi32>) -> index {
+  %c0 = constant 0 : index
+  %dim = dim %unranked, %c0 : memref<*xi32>
+  return %dim : index
+}
+// CHECK-NEXT: llvm.mlir.undef : !llvm.struct<(i64, ptr<i8>)>
+// CHECK-NEXT: llvm.insertvalue 
+// CHECK-NEXT: %[[UNRANKED_DESC:.*]] = llvm.insertvalue
+// CHECK-NEXT: %[[C0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+
+// CHECK-NEXT: %[[RANKED_DESC:.*]] = llvm.extractvalue %[[UNRANKED_DESC]][1]
+// CHECK-SAME:   : !llvm.struct<(i64, ptr<i8>)>
+
+// CHECK-NEXT: %[[ZERO_D_DESC:.*]] = llvm.bitcast %[[RANKED_DESC]]
+// CHECK-SAME:   : !llvm.ptr<i8> to !llvm.ptr<struct<(ptr<i32>, ptr<i32>, i64)>>
+
+// CHECK-NEXT: %[[C2_i32:.*]] = llvm.mlir.constant(2 : i32) : !llvm.i32
+// CHECK-NEXT: %[[C0_:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+
+// CHECK-NEXT: %[[OFFSET_PTR:.*]] = llvm.getelementptr %[[ZERO_D_DESC]]{{\[}}
+// CHECK-SAME:   %[[C0_]], %[[C2_i32]]] : (!llvm.ptr<struct<(ptr<i32>, ptr<i32>,
+// CHECK-SAME:   i64)>>, !llvm.i64, !llvm.i32) -> !llvm.ptr<i64>
+
+// CHECK-NEXT: %[[C1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK-NEXT: %[[INDEX_INC:.*]] = llvm.add %[[C1]], %[[C0]] : !llvm.i64
+
+// CHECK-NEXT: %[[SIZE_PTR:.*]] = llvm.getelementptr %[[OFFSET_PTR]]{{\[}}
+// CHECK-SAME:   %[[INDEX_INC]]] : (!llvm.ptr<i64>, !llvm.i64) -> !llvm.ptr<i64>
+
+// CHECK-NEXT: %[[SIZE:.*]] = llvm.load %[[SIZE_PTR]] : !llvm.ptr<i64>
+// CHECK-NEXT: llvm.return %[[SIZE]] : !llvm.i64
+
+// CHECK32: %[[SIZE:.*]] = llvm.load %{{.*}} : !llvm.ptr<i32>
+// CHECK32-NEXT: llvm.return %[[SIZE]] : !llvm.i32

diff --git a/mlir/test/mlir-cpu-runner/unranked_memref.mlir b/mlir/test/mlir-cpu-runner/unranked_memref.mlir
index df760f593db2..4e0c58f71340 100644
--- a/mlir/test/mlir-cpu-runner/unranked_memref.mlir
+++ b/mlir/test/mlir-cpu-runner/unranked_memref.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_test_cblas%shlibext,%linalg_test_lib_dir/libmlir_test_cblas_interface%shlibext | FileCheck %s
+// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext | FileCheck %s
 
 // CHECK: rank = 2
 // CHECK-SAME: sizes = [10, 3]
@@ -66,6 +66,7 @@ func @main() -> () {
 
     call @return_var_memref_caller() : () -> ()
     call @return_two_var_memref_caller() : () -> ()
+    call @dim_op_of_unranked() : () -> ()
     return
 }
 
@@ -100,3 +101,25 @@ func @return_var_memref(%arg0: memref<4x3xf32>) -> memref<*xf32> {
   %0 = memref_cast %arg0: memref<4x3xf32> to memref<*xf32>
   return %0 : memref<*xf32>
 }
+
+func @print_i64(index) -> ()
+func @print_newline() -> ()
+
+func @dim_op_of_unranked() {
+  %ranked = alloc() : memref<4x3xf32>
+  %unranked = memref_cast %ranked: memref<4x3xf32> to memref<*xf32>
+
+  %c0 = constant 0 : index
+  %dim_0 = dim %unranked, %c0 : memref<*xf32>
+  call @print_i64(%dim_0) : (index) -> ()
+  call @print_newline() : () -> ()
+  // CHECK: 4
+
+  %c1 = constant 1 : index
+  %dim_1 = dim %unranked, %c1 : memref<*xf32>
+  call @print_i64(%dim_1) : (index) -> ()
+  call @print_newline() : () -> ()
+  // CHECK: 3
+
+  return
+}

