[Mlir-commits] [mlir] ae0fb61 - [MLIR] Check for static shape before bare pointer conversion
Lorenzo Chelini
llvmlistbot at llvm.org
Tue Apr 5 08:56:55 PDT 2022
Author: Lorenzo Chelini
Date: 2022-04-05T17:56:41+02:00
New Revision: ae0fb61303f84f64d4a84dedf72036672741f87f
URL: https://github.com/llvm/llvm-project/commit/ae0fb61303f84f64d4a84dedf72036672741f87f
DIFF: https://github.com/llvm/llvm-project/commit/ae0fb61303f84f64d4a84dedf72036672741f87f.diff
LOG: [MLIR] Check for static shape before bare pointer conversion
Previously, the ReturnOp lowering converted every returned memref to a bare
pointer. This is incorrect: a memref can be converted to a bare pointer only
if its shape, strides, and offset are all static.
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D123121
Added:
Modified:
mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index e5d4b5fd8e9d6..f9ad3586c87f2 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -125,6 +125,9 @@ class LLVMTypeConverter : public TypeConverter {
unsigned getUnrankedMemRefDescriptorSize(UnrankedMemRefType type,
const DataLayout &layout);
+ /// Check if a memref type can be converted to a bare pointer.
+ bool canConvertToBarePtr(BaseMemRefType type);
+
protected:
/// Pointer to the LLVM dialect.
LLVM::LLVMDialect *llvmDialect;
@@ -191,8 +194,8 @@ class LLVMTypeConverter : public TypeConverter {
/// These types can be recomposed to a unranked memref descriptor struct.
SmallVector<Type, 2> getUnrankedMemRefDescriptorFields();
- // Convert an unranked memref type to an LLVM type that captures the
- // runtime rank and a pointer to the static ranked memref desc
+ /// Convert an unranked memref type to an LLVM type that captures the
+ /// runtime rank and a pointer to the static ranked memref desc
Type convertUnrankedMemRefType(UnrankedMemRefType type);
/// Convert a memref type to a bare pointer to the memref element type.
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 2f3d75a69bf16..c75ee4c4ae300 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -599,7 +599,8 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
Type oldTy = std::get<0>(it).getType();
Value newOperand = std::get<1>(it);
- if (oldTy.isa<MemRefType>()) {
+ if (oldTy.isa<MemRefType>() && getTypeConverter()->canConvertToBarePtr(
+ oldTy.cast<BaseMemRefType>())) {
MemRefDescriptor memrefDesc(newOperand);
newOperand = memrefDesc.alignedPtr(rewriter, loc);
} else if (oldTy.isa<UnrankedMemRefType>()) {
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index b5c26a8fefa60..2ac505c917c9c 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -366,30 +366,37 @@ Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) {
getUnrankedMemRefDescriptorFields());
}
-/// Convert a memref type to a bare pointer to the memref element type.
-Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) {
+// Check if a memref type can be converted to a bare pointer.
+bool LLVMTypeConverter::canConvertToBarePtr(BaseMemRefType type) {
if (type.isa<UnrankedMemRefType>())
// Unranked memref is not supported in the bare pointer calling convention.
- return {};
+ return false;
// Check that the memref has static shape, strides and offset. Otherwise, it
// cannot be lowered to a bare pointer.
auto memrefTy = type.cast<MemRefType>();
if (!memrefTy.hasStaticShape())
- return {};
+ return false;
int64_t offset = 0;
SmallVector<int64_t, 4> strides;
if (failed(getStridesAndOffset(memrefTy, strides, offset)))
- return {};
+ return false;
for (int64_t stride : strides)
if (ShapedType::isDynamicStrideOrOffset(stride))
- return {};
+ return false;
if (ShapedType::isDynamicStrideOrOffset(offset))
- return {};
+ return false;
+
+ return true;
+}
+/// Convert a memref type to a bare pointer to the memref element type.
+Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) {
+ if (!canConvertToBarePtr(type))
+ return {};
Type elementType = convertType(type.getElementType());
if (!elementType)
return {};
diff --git a/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir b/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir
index 89a56df36f525..2e9bfefb3e466 100644
--- a/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir
+++ b/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir
@@ -86,3 +86,12 @@ func @check_memref_func_call(%in : memref<10xi8>) -> memref<20xi8> {
// BAREPTR-NEXT: llvm.return %[[res]] : !llvm.ptr<i8>
return %res : memref<20xi8>
}
+
+// -----
+
+// BAREPTR-LABEL: func @check_return(
+// BAREPTR-SAME: %{{.*}}: memref<?xi8>) -> memref<?xi8>
+func @check_return(%in : memref<?xi8>) -> memref<?xi8> {
+ // BAREPTR: llvm.return {{.*}} : !llvm.struct<(ptr<i8>, ptr<i8>, i64, array<1 x i64>, array<1 x i64>)>
+ return %in : memref<?xi8>
+}
More information about the Mlir-commits mailing list