[Mlir-commits] [mlir] a3d427d - [mlir] Lower RankOp to LLVM for unranked memrefs.
Alexander Belyaev
llvmlistbot at llvm.org
Wed Aug 5 03:14:02 PDT 2020
Author: Alexander Belyaev
Date: 2020-08-05T12:13:43+02:00
New Revision: a3d427d30cd32f218f53e32b58e232ea8312aa50
URL: https://github.com/llvm/llvm-project/commit/a3d427d30cd32f218f53e32b58e232ea8312aa50
DIFF: https://github.com/llvm/llvm-project/commit/a3d427d30cd32f218f53e32b58e232ea8312aa50.diff
LOG: [mlir] Lower RankOp to LLVM for unranked memrefs.
Differential Revision: https://reviews.llvm.org/D85273
Added:
Modified:
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index d0b49bb18195..533ac629ba5a 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -2402,6 +2402,28 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<DimOp> {
}
};
+/// Lowers `std.rank` to the LLVM dialect.
+///
+/// - Unranked memref operand: the rank is read from the converted unranked
+///   memref descriptor.
+/// - Ranked memref operand: the rank is static and is materialized as an
+///   index constant.
+/// - Any other operand type (e.g. a tensor) is rejected by this pattern.
+struct RankOpLowering : public ConvertOpToLLVMPattern<RankOp> {
+  using ConvertOpToLLVMPattern<RankOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    Type operandType = cast<RankOp>(op).memrefOrTensor().getType();
+
+    if (operandType.isa<UnrankedMemRefType>()) {
+      UnrankedMemRefDescriptor desc(
+          RankOp::Adaptor(operands).memrefOrTensor());
+      Value rank = desc.rank(rewriter, loc);
+      // The descriptor's rank field is not necessarily the converted `index`
+      // type: the 32-bit-index test run still expects a `{ i64, i8* }`
+      // descriptor while `index` lowers to i32. Narrow the rank so the
+      // replacement value matches the op's converted result type.
+      // NOTE(review): assumes the index type is never wider than the rank
+      // field; revisit if wider index bitwidths become supported.
+      Type indexType = getIndexType();
+      if (rank.getType() != indexType)
+        rank = rewriter.create<LLVM::TruncOp>(loc, indexType, rank);
+      rewriter.replaceOp(op, {rank});
+      return success();
+    }
+
+    if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
+      rewriter.replaceOp(
+          op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
+      return success();
+    }
+
+    return failure();
+  }
+};
+
// Common base for load and store operations on MemRefs. Restricts the match
// to supported MemRef types. Provides functionality to emit code accessing a
// specific element of the underlying data buffer.
@@ -3272,6 +3294,7 @@ void mlir::populateStdToLLVMMemoryConversionPatterns(
DimOpLowering,
LoadOpLowering,
MemRefCastOpLowering,
+ RankOpLowering,
StoreOpLowering,
SubViewOpLowering,
ViewOpLowering,
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
index d0e883a10bee..6123f68b7e85 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -1291,3 +1291,26 @@ func @bfloat(%arg0: bf16) -> bf16 {
func @memref_index(%arg0: memref<32xindex>) -> memref<32xindex> {
return %arg0 : memref<32xindex>
}
+
+// -----
+
+// Rank of an unranked memref: the lowering reads field 0 (the rank) of the
+// unranked memref descriptor built from the function argument.
+// NOTE(review): %rank is unused, so this test does not check the type of the
+// value that replaces the op's `index` result.
+// CHECK-LABEL: func @rank_of_unranked
+// CHECK32-LABEL: func @rank_of_unranked
+func @rank_of_unranked(%unranked: memref<*xi32>) {
+ %rank = rank %unranked : memref<*xi32>
+ return
+}
+// CHECK-NEXT: llvm.mlir.undef
+// CHECK-NEXT: llvm.insertvalue
+// CHECK-NEXT: llvm.insertvalue
+// CHECK-NEXT: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i64, i8* }">
+// The 32-bit-index run still expects an i64 rank field in the descriptor.
+// CHECK32: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i64, i8* }">
+
+// Rank of a ranked memref is static: memref<?xi32> has rank 1, which is
+// emitted as an index constant (i64 by default, i32 with 32-bit indices).
+// CHECK-LABEL: func @rank_of_ranked
+// CHECK32-LABEL: func @rank_of_ranked
+func @rank_of_ranked(%ranked: memref<?xi32>) {
+ %rank = rank %ranked : memref<?xi32>
+ return
+}
+// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK32: llvm.mlir.constant(1 : index) : !llvm.i32
More information about the Mlir-commits
mailing list