[Mlir-commits] [mlir] ae0fb61 - [MLIR] Check for static shape before bare pointer conversion

Lorenzo Chelini llvmlistbot at llvm.org
Tue Apr 5 08:56:55 PDT 2022


Author: Lorenzo Chelini
Date: 2022-04-05T17:56:41+02:00
New Revision: ae0fb61303f84f64d4a84dedf72036672741f87f

URL: https://github.com/llvm/llvm-project/commit/ae0fb61303f84f64d4a84dedf72036672741f87f
DIFF: https://github.com/llvm/llvm-project/commit/ae0fb61303f84f64d4a84dedf72036672741f87f.diff

LOG: [MLIR] Check for static shape before bare pointer conversion

Originally, the ReturnOp conversion changed the result type to a bare
pointer whenever the type was a memref. This is incorrect: a memref can only
be converted to a bare pointer if it has static shape, strides and offset.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D123121

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
    mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
    mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
    mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index e5d4b5fd8e9d6..f9ad3586c87f2 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -125,6 +125,9 @@ class LLVMTypeConverter : public TypeConverter {
   unsigned getUnrankedMemRefDescriptorSize(UnrankedMemRefType type,
                                            const DataLayout &layout);
 
+  /// Check if a memref type can be converted to a bare pointer.
+  bool canConvertToBarePtr(BaseMemRefType type);
+
 protected:
   /// Pointer to the LLVM dialect.
   LLVM::LLVMDialect *llvmDialect;
@@ -191,8 +194,8 @@ class LLVMTypeConverter : public TypeConverter {
   /// These types can be recomposed to a unranked memref descriptor struct.
   SmallVector<Type, 2> getUnrankedMemRefDescriptorFields();
 
-  // Convert an unranked memref type to an LLVM type that captures the
-  // runtime rank and a pointer to the static ranked memref desc
+  /// Convert an unranked memref type to an LLVM type that captures the
+  /// runtime rank and a pointer to the static ranked memref desc
   Type convertUnrankedMemRefType(UnrankedMemRefType type);
 
   /// Convert a memref type to a bare pointer to the memref element type.

diff  --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 2f3d75a69bf16..c75ee4c4ae300 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -599,7 +599,8 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
       for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
         Type oldTy = std::get<0>(it).getType();
         Value newOperand = std::get<1>(it);
-        if (oldTy.isa<MemRefType>()) {
+        if (oldTy.isa<MemRefType>() && getTypeConverter()->canConvertToBarePtr(
+                                           oldTy.cast<BaseMemRefType>())) {
           MemRefDescriptor memrefDesc(newOperand);
           newOperand = memrefDesc.alignedPtr(rewriter, loc);
         } else if (oldTy.isa<UnrankedMemRefType>()) {

diff  --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index b5c26a8fefa60..2ac505c917c9c 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -366,30 +366,37 @@ Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) {
                                           getUnrankedMemRefDescriptorFields());
 }
 
-/// Convert a memref type to a bare pointer to the memref element type.
-Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) {
+/// Return true if `type` is ranked with static shape, strides and offset.
+bool LLVMTypeConverter::canConvertToBarePtr(BaseMemRefType type) {
   if (type.isa<UnrankedMemRefType>())
     // Unranked memref is not supported in the bare pointer calling convention.
-    return {};
+    return false;
 
   // Check that the memref has static shape, strides and offset. Otherwise, it
   // cannot be lowered to a bare pointer.
   auto memrefTy = type.cast<MemRefType>();
   if (!memrefTy.hasStaticShape())
-    return {};
+    return false;
 
   int64_t offset = 0;
   SmallVector<int64_t, 4> strides;
   if (failed(getStridesAndOffset(memrefTy, strides, offset)))
-    return {};
+    return false;
 
   for (int64_t stride : strides)
     if (ShapedType::isDynamicStrideOrOffset(stride))
-      return {};
+      return false;
 
   if (ShapedType::isDynamicStrideOrOffset(offset))
-    return {};
+    return false;
+
+  return true;
+}
 
+/// Convert a memref type to a bare pointer to the memref element type.
+Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) {
+  if (!canConvertToBarePtr(type))
+    return {};
   Type elementType = convertType(type.getElementType());
   if (!elementType)
     return {};

diff  --git a/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir b/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir
index 89a56df36f525..2e9bfefb3e466 100644
--- a/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir
+++ b/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir
@@ -86,3 +86,12 @@ func @check_memref_func_call(%in : memref<10xi8>) -> memref<20xi8> {
   // BAREPTR-NEXT:    llvm.return %[[res]] : !llvm.ptr<i8>
   return %res : memref<20xi8>
 }
+
+// -----
+
+// BAREPTR-LABEL: func @check_return(
+// BAREPTR-SAME: %{{.*}}: memref<?xi8>) -> memref<?xi8> 
+func @check_return(%in : memref<?xi8>) -> memref<?xi8> {
+  // BAREPTR: llvm.return {{.*}} : !llvm.struct<(ptr<i8>, ptr<i8>, i64, array<1 x i64>, array<1 x i64>)>
+  return %in : memref<?xi8>
+}


        


More information about the Mlir-commits mailing list