[Mlir-commits] [mlir] 0990f1a - [MLIR][Standard] Lower `std.dim` with dynamic dimension operand to LLVM
Frederik Gossen
llvmlistbot at llvm.org
Tue Jun 16 13:58:29 PDT 2020
Author: Frederik Gossen
Date: 2020-06-16T20:57:42Z
New Revision: 0990f1a3adefb6121573c63fec4879dc2b2010c0
URL: https://github.com/llvm/llvm-project/commit/0990f1a3adefb6121573c63fec4879dc2b2010c0
DIFF: https://github.com/llvm/llvm-project/commit/0990f1a3adefb6121573c63fec4879dc2b2010c0.diff
LOG: [MLIR][Standard] Lower `std.dim` with dynamic dimension operand to LLVM
Implement the missing lowering from `std.dim` to the LLVM dialect for the case
where the dimension index operand is dynamic (not a compile-time constant).
Differential Revision: https://reviews.llvm.org/D81834
Added:
Modified:
mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index a5c14139668b..a7e4ff2f52cf 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -278,6 +278,7 @@ class MemRefDescriptor : public StructBuilder {
/// Builds IR extracting the pos-th size from the descriptor.
Value size(OpBuilder &builder, Location loc, unsigned pos);
+ Value size(OpBuilder &builder, Location loc, Value pos, int64_t rank);
/// Builds IR inserting the pos-th size into the descriptor
void setSize(OpBuilder &builder, Location loc, unsigned pos, Value size);
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 4d1cad7e9985..a316f2e56041 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -568,6 +568,29 @@ Value MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) {
builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos}));
}
+Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos,
+ int64_t rank) {
+ auto indexTy = indexType.cast<LLVM::LLVMType>();
+ auto indexPtrTy = indexTy.getPointerTo();
+ auto arrayTy = LLVM::LLVMType::getArrayTy(indexTy, rank);
+ auto arrayPtrTy = arrayTy.getPointerTo();
+ // `pos` is an SSA value, not a constant, so the [rank x index] sizes array is indexed through memory (GEP + load).
+ // Copy size values to stack-allocated memory. NOTE(review): the alloca is emitted at the builder's insertion point, not the entry block, and `pos` is not bounds-checked against `rank`.
+ auto zero = createIndexAttrConstant(builder, loc, indexType, 0);
+ auto one = createIndexAttrConstant(builder, loc, indexType, 1);
+ auto sizes = builder.create<LLVM::ExtractValueOp>(
+ loc, arrayTy, value,
+ builder.getI64ArrayAttr({kSizePosInMemRefDescriptor}));
+ auto sizesPtr =
+ builder.create<LLVM::AllocaOp>(loc, arrayPtrTy, one, /*alignment=*/0);
+ builder.create<LLVM::StoreOp>(loc, sizes, sizesPtr);
+
+ // Load and return the size value at position `pos`.
+ auto resultPtr = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizesPtr,
+ ValueRange({zero, pos}));
+ return builder.create<LLVM::LoadOp>(loc, resultPtr);
+}
+
/// Builds IR inserting the pos-th size into the descriptor
void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos,
Value size) {
@@ -576,7 +599,6 @@ void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos,
builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos}));
}
-/// Builds IR inserting the pos-th size into the descriptor
void MemRefDescriptor::setConstantSize(OpBuilder &builder, Location loc,
unsigned pos, uint64_t size) {
setSize(builder, loc, pos,
@@ -598,7 +620,6 @@ void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos,
builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos}));
}
-/// Builds IR inserting the pos-th stride into the descriptor
void MemRefDescriptor::setConstantStride(OpBuilder &builder, Location loc,
unsigned pos, uint64_t stride) {
setStride(builder, loc, pos,
@@ -2117,25 +2138,29 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<DimOp> {
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto dimOp = cast<DimOp>(op);
+ auto loc = op->getLoc();
DimOp::Adaptor transformed(operands);
- MemRefType type = dimOp.memrefOrTensor().getType().cast<MemRefType>();
- Optional<int64_t> index = dimOp.getConstantIndex();
- if (!index.hasValue()) {
- // TODO: Implement this lowering.
- return failure();
+ // Take advantage if index is constant.
+ MemRefType memRefType = dimOp.memrefOrTensor().getType().cast<MemRefType>();
+ if (Optional<int64_t> index = dimOp.getConstantIndex()) {
+ int64_t i = index.getValue();
+ if (memRefType.isDynamicDim(i)) {
+ // Extract dynamic size from the memref descriptor.
+ MemRefDescriptor descriptor(transformed.memrefOrTensor());
+ rewriter.replaceOp(op, {descriptor.size(rewriter, loc, i)});
+ } else {
+ // Use constant for static size.
+ int64_t dimSize = memRefType.getDimSize(i);
+ rewriter.replaceOp(op, createIndexConstant(rewriter, loc, dimSize));
+ }
+ return success();
}
- int64_t i = index.getValue();
- // Extract dynamic size from the memref descriptor.
- if (type.isDynamicDim(i))
- rewriter.replaceOp(op, {MemRefDescriptor(transformed.memrefOrTensor())
- .size(rewriter, op->getLoc(), i)});
- else
- // Use constant for static size.
- rewriter.replaceOp(
- op, createIndexConstant(rewriter, op->getLoc(), type.getDimSize(i)));
-
+ Value index = dimOp.index();
+ int64_t rank = memRefType.getRank();
+ MemRefDescriptor memrefDescriptor(transformed.memrefOrTensor());
+ rewriter.replaceOp(op, {memrefDescriptor.size(rewriter, loc, index, rank)});
return success();
}
};
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 0ce03b3cf114..dd67f2de86f0 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1308,7 +1308,7 @@ static LogicalResult verify(DimOp op) {
}
OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
- auto index = operands[1].dyn_cast<IntegerAttr>();
+ auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
// All forms of folding require a known index.
if (!index)
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
index b13a502a2e60..a1a9634818b3 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
@@ -408,3 +408,26 @@ func @mixed_memref_dim(%mixed : memref<42x?x?x13x?xf32>) {
%4 = dim %mixed, %c4 : memref<42x?x?x13x?xf32>
return
}
+// Non-constant dim index: sizes are spilled to an alloca and read back via a dynamic GEP and load.
+// CHECK-LABEL: @memref_dim_with_dyn_index
+// CHECK-SAME: %[[ALLOC_PTR:.*]]: !llvm<"float*">, %[[ALIGN_PTR:.*]]: !llvm<"float*">, %[[OFFSET:.*]]: !llvm.i64, %[[SIZE0:.*]]: !llvm.i64, %[[SIZE1:.*]]: !llvm.i64, %[[STRIDE0:.*]]: !llvm.i64, %[[STRIDE1:.*]]: !llvm.i64, %[[IDX:.*]]: !llvm.i64) -> !llvm.i64
+func @memref_dim_with_dyn_index(%arg : memref<3x?xf32>, %idx : index) -> index {
+ // CHECK-NEXT: %[[DESCR0:.*]] = llvm.mlir.undef : [[DESCR_TY:!llvm<"{ float\*, float\*, i64, \[2 x i64\], \[2 x i64\] }">]]
+ // CHECK-NEXT: %[[DESCR1:.*]] = llvm.insertvalue %[[ALLOC_PTR]], %[[DESCR0]][0] : [[DESCR_TY]]
+ // CHECK-NEXT: %[[DESCR2:.*]] = llvm.insertvalue %[[ALIGN_PTR]], %[[DESCR1]][1] : [[DESCR_TY]]
+ // CHECK-NEXT: %[[DESCR3:.*]] = llvm.insertvalue %[[OFFSET]], %[[DESCR2]][2] : [[DESCR_TY]]
+ // CHECK-NEXT: %[[DESCR4:.*]] = llvm.insertvalue %[[SIZE0]], %[[DESCR3]][3, 0] : [[DESCR_TY]]
+ // CHECK-NEXT: %[[DESCR5:.*]] = llvm.insertvalue %[[STRIDE0]], %[[DESCR4]][4, 0] : [[DESCR_TY]]
+ // CHECK-NEXT: %[[DESCR6:.*]] = llvm.insertvalue %[[SIZE1]], %[[DESCR5]][3, 1] : [[DESCR_TY]]
+ // CHECK-NEXT: %[[DESCR7:.*]] = llvm.insertvalue %[[STRIDE1]], %[[DESCR6]][4, 1] : [[DESCR_TY]]
+ // CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+ // CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+ // CHECK-DAG: %[[SIZES:.*]] = llvm.extractvalue %[[DESCR7]][3] : [[DESCR_TY]]
+ // CHECK-DAG: %[[SIZES_PTR:.*]] = llvm.alloca %[[C1]] x !llvm<"[2 x i64]"> : (!llvm.i64) -> !llvm<"[2 x i64]*">
+ // CHECK-DAG: llvm.store %[[SIZES]], %[[SIZES_PTR]] : !llvm<"[2 x i64]*">
+ // CHECK-DAG: %[[RESULT_PTR:.*]] = llvm.getelementptr %[[SIZES_PTR]][%[[C0]], %[[IDX]]] : (!llvm<"[2 x i64]*">, !llvm.i64, !llvm.i64) -> !llvm<"i64*">
+ // CHECK-DAG: %[[RESULT:.*]] = llvm.load %[[RESULT_PTR]] : !llvm<"i64*">
+ // CHECK-DAG: llvm.return %[[RESULT]] : !llvm.i64
+ %result = dim %arg, %idx : memref<3x?xf32>
+ return %result : index
+}
More information about the Mlir-commits
mailing list