[Mlir-commits] [mlir] 0990f1a - [MLIR][Standard] Lower `std.dim` with dynamic dimension operand to LLVM

Frederik Gossen llvmlistbot at llvm.org
Tue Jun 16 13:58:29 PDT 2020


Author: Frederik Gossen
Date: 2020-06-16T20:57:42Z
New Revision: 0990f1a3adefb6121573c63fec4879dc2b2010c0

URL: https://github.com/llvm/llvm-project/commit/0990f1a3adefb6121573c63fec4879dc2b2010c0
DIFF: https://github.com/llvm/llvm-project/commit/0990f1a3adefb6121573c63fec4879dc2b2010c0.diff

LOG: [MLIR][Standard] Lower `std.dim` with dynamic dimension operand to LLVM

Implement the missing lowering from `std.dim` to the LLVM dialect for the case
where the dimension index is a dynamic (non-constant) operand.

Differential Revision: https://reviews.llvm.org/D81834

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index a5c14139668b..a7e4ff2f52cf 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -278,6 +278,7 @@ class MemRefDescriptor : public StructBuilder {
 
   /// Builds IR extracting the pos-th size from the descriptor.
   Value size(OpBuilder &builder, Location loc, unsigned pos);
+  Value size(OpBuilder &builder, Location loc, Value pos, int64_t rank);
 
   /// Builds IR inserting the pos-th size into the descriptor
   void setSize(OpBuilder &builder, Location loc, unsigned pos, Value size);

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 4d1cad7e9985..a316f2e56041 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -568,6 +568,29 @@ Value MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) {
       builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos}));
 }
 
+Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos,
+                             int64_t rank) {
+  auto indexTy = indexType.cast<LLVM::LLVMType>();
+  auto indexPtrTy = indexTy.getPointerTo();
+  auto arrayTy = LLVM::LLVMType::getArrayTy(indexTy, rank);
+  auto arrayPtrTy = arrayTy.getPointerTo();
+
+  // Copy the sizes array to an alloca so it can be indexed by dynamic `pos`.
+  auto zero = createIndexAttrConstant(builder, loc, indexType, 0);
+  auto one = createIndexAttrConstant(builder, loc, indexType, 1);
+  auto sizes = builder.create<LLVM::ExtractValueOp>(
+      loc, arrayTy, value,
+      builder.getI64ArrayAttr({kSizePosInMemRefDescriptor}));
+  auto sizesPtr =
+      builder.create<LLVM::AllocaOp>(loc, arrayPtrTy, one, /*alignment=*/0);
+  builder.create<LLVM::StoreOp>(loc, sizes, sizesPtr);
+
+  // Load and return the size at position `pos` (no bounds check is emitted).
+  auto resultPtr = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizesPtr,
+                                               ValueRange({zero, pos}));
+  return builder.create<LLVM::LoadOp>(loc, resultPtr);
+}
+
 /// Builds IR inserting the pos-th size into the descriptor
 void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos,
                                Value size) {
@@ -576,7 +599,6 @@ void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos,
       builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos}));
 }
 
-/// Builds IR inserting the pos-th size into the descriptor
 void MemRefDescriptor::setConstantSize(OpBuilder &builder, Location loc,
                                        unsigned pos, uint64_t size) {
   setSize(builder, loc, pos,
@@ -598,7 +620,6 @@ void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos,
       builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos}));
 }
 
-/// Builds IR inserting the pos-th stride into the descriptor
 void MemRefDescriptor::setConstantStride(OpBuilder &builder, Location loc,
                                          unsigned pos, uint64_t stride) {
   setStride(builder, loc, pos,
@@ -2117,25 +2138,29 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<DimOp> {
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     auto dimOp = cast<DimOp>(op);
+    auto loc = op->getLoc();
     DimOp::Adaptor transformed(operands);
-    MemRefType type = dimOp.memrefOrTensor().getType().cast<MemRefType>();
 
-    Optional<int64_t> index = dimOp.getConstantIndex();
-    if (!index.hasValue()) {
-      // TODO: Implement this lowering.
-      return failure();
+    // Constant index: read the size directly, no stack spill is needed.
+    MemRefType memRefType = dimOp.memrefOrTensor().getType().cast<MemRefType>();
+    if (Optional<int64_t> index = dimOp.getConstantIndex()) {
+      int64_t i = index.getValue();
+      if (memRefType.isDynamicDim(i)) {
+        // Extract dynamic size from the memref descriptor.
+        MemRefDescriptor descriptor(transformed.memrefOrTensor());
+        rewriter.replaceOp(op, {descriptor.size(rewriter, loc, i)});
+      } else {
+        // Use constant for static size.
+        int64_t dimSize = memRefType.getDimSize(i);
+        rewriter.replaceOp(op, createIndexConstant(rewriter, loc, dimSize));
+      }
+      return success();
     }
 
-    int64_t i = index.getValue();
-    // Extract dynamic size from the memref descriptor.
-    if (type.isDynamicDim(i))
-      rewriter.replaceOp(op, {MemRefDescriptor(transformed.memrefOrTensor())
-                                  .size(rewriter, op->getLoc(), i)});
-    else
-      // Use constant for static size.
-      rewriter.replaceOp(
-          op, createIndexConstant(rewriter, op->getLoc(), type.getDimSize(i)));
-
+    Value index = transformed.index(); // Use the type-converted operand.
+    int64_t rank = memRefType.getRank();
+    MemRefDescriptor memrefDescriptor(transformed.memrefOrTensor());
+    rewriter.replaceOp(op, {memrefDescriptor.size(rewriter, loc, index, rank)});
     return success();
   }
 };

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 0ce03b3cf114..dd67f2de86f0 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1308,7 +1308,7 @@ static LogicalResult verify(DimOp op) {
 }
 
 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
-  auto index = operands[1].dyn_cast<IntegerAttr>();
+  auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
 
   // All forms of folding require a known index.
   if (!index)

diff  --git a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
index b13a502a2e60..a1a9634818b3 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
@@ -408,3 +408,26 @@ func @mixed_memref_dim(%mixed : memref<42x?x?x13x?xf32>) {
   %4 = dim %mixed, %c4 : memref<42x?x?x13x?xf32>
   return
 }
+
+// CHECK-LABEL: @memref_dim_with_dyn_index
+// CHECK-SAME: %[[ALLOC_PTR:.*]]: !llvm<"float*">, %[[ALIGN_PTR:.*]]: !llvm<"float*">, %[[OFFSET:.*]]: !llvm.i64, %[[SIZE0:.*]]: !llvm.i64, %[[SIZE1:.*]]: !llvm.i64, %[[STRIDE0:.*]]: !llvm.i64, %[[STRIDE1:.*]]: !llvm.i64, %[[IDX:.*]]: !llvm.i64) -> !llvm.i64
+func @memref_dim_with_dyn_index(%arg : memref<3x?xf32>, %idx : index) -> index {
+  // CHECK-NEXT: %[[DESCR0:.*]] = llvm.mlir.undef : [[DESCR_TY:!llvm<"{ float\*, float\*, i64, \[2 x i64\], \[2 x i64\] }">]]
+  // CHECK-NEXT: %[[DESCR1:.*]] = llvm.insertvalue %[[ALLOC_PTR]], %[[DESCR0]][0] : [[DESCR_TY]]
+  // CHECK-NEXT: %[[DESCR2:.*]] = llvm.insertvalue %[[ALIGN_PTR]], %[[DESCR1]][1] : [[DESCR_TY]]
+  // CHECK-NEXT: %[[DESCR3:.*]] = llvm.insertvalue %[[OFFSET]],    %[[DESCR2]][2] : [[DESCR_TY]]
+  // CHECK-NEXT: %[[DESCR4:.*]] = llvm.insertvalue %[[SIZE0]],     %[[DESCR3]][3, 0] : [[DESCR_TY]]
+  // CHECK-NEXT: %[[DESCR5:.*]] = llvm.insertvalue %[[STRIDE0]],   %[[DESCR4]][4, 0] : [[DESCR_TY]]
+  // CHECK-NEXT: %[[DESCR6:.*]] = llvm.insertvalue %[[SIZE1]],     %[[DESCR5]][3, 1] : [[DESCR_TY]]
+  // CHECK-NEXT: %[[DESCR7:.*]] = llvm.insertvalue %[[STRIDE1]],   %[[DESCR6]][4, 1] : [[DESCR_TY]]
+  // CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+  // CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+  // CHECK-DAG: %[[SIZES:.*]] = llvm.extractvalue %[[DESCR7]][3] : [[DESCR_TY]]
+  // CHECK-DAG: %[[SIZES_PTR:.*]] = llvm.alloca %[[C1]] x !llvm<"[2 x i64]"> : (!llvm.i64) -> !llvm<"[2 x i64]*">
+  // CHECK-DAG: llvm.store %[[SIZES]], %[[SIZES_PTR]] : !llvm<"[2 x i64]*">
+  // CHECK-DAG: %[[RESULT_PTR:.*]] = llvm.getelementptr %[[SIZES_PTR]][%[[C0]], %[[IDX]]] : (!llvm<"[2 x i64]*">, !llvm.i64, !llvm.i64) -> !llvm<"i64*">
+  // CHECK-DAG: %[[RESULT:.*]] = llvm.load %[[RESULT_PTR]] : !llvm<"i64*">
+  // CHECK: llvm.return %[[RESULT]] : !llvm.i64
+  %result = dim %arg, %idx : memref<3x?xf32>
+  return %result : index
+}


        


More information about the Mlir-commits mailing list