[Mlir-commits] [mlir] a89fc12 - [mlir] Support return and call ops in bare-ptr calling convention

Diego Caballero llvmlistbot at llvm.org
Tue Sep 29 12:09:45 PDT 2020


Author: Diego Caballero
Date: 2020-09-29T12:00:47-07:00
New Revision: a89fc12653c520a5a70249e07c0a394584f4abbe

URL: https://github.com/llvm/llvm-project/commit/a89fc12653c520a5a70249e07c0a394584f4abbe
DIFF: https://github.com/llvm/llvm-project/commit/a89fc12653c520a5a70249e07c0a394584f4abbe.diff

LOG: [mlir] Support return and call ops in bare-ptr calling convention

This patch adds support for the 'return' and 'call' ops to the bare-ptr
calling convention. These changes also align the bare-ptr calling
convention code with the latest changes in the default calling convention
and reduce the amount of customization code needed.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D87724
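
For illustration only, a rough sketch of the intended effect, pieced together from the updated test expectations below; the pass-option spelling 'use-bare-ptr-memref-call-conv' is an assumption and may differ from the exact RUN line:

    // Input: a statically shaped memref flows through a return and a call.
    func @callee(%arg : memref<32x18xf32>) -> memref<32x18xf32> {
      return %arg : memref<32x18xf32>
    }

    // Lowered with something like
    //   mlir-opt -convert-std-to-llvm='use-bare-ptr-memref-call-conv=1'
    // the signature uses bare element pointers instead of descriptor structs:
    //   llvm.func @callee(%arg0: !llvm.ptr<float>) -> !llvm.ptr<float>
    // Inside the body, %arg0 is promoted back to a full memref descriptor, and the
    // lowered 'return' extracts the aligned pointer (field 1 of the descriptor) as
    // the result. 'call' works symmetrically: it passes the aligned pointer and
    // rebuilds a descriptor around the bare pointer returned by the callee.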

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index ab047a08f404..d98a0ff6efb3 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -27,6 +27,7 @@ class Type;
 
 namespace mlir {
 
+class BaseMemRefType;
 class ComplexType;
 class LLVMTypeConverter;
 class UnrankedMemRefType;
@@ -74,15 +75,28 @@ class LLVMTypeConverter : public TypeConverter {
                                           SignatureConversion &result);
 
   /// Convert a non-empty list of types to be returned from a function into a
-  /// supported LLVM IR type.  In particular, if more than one values is
+  /// supported LLVM IR type.  In particular, if more than one value is
   /// returned, create an LLVM IR structure type with elements that correspond
   /// to each of the MLIR types converted with `convertType`.
   Type packFunctionResults(ArrayRef<Type> types);
 
+  /// Convert a type in the context of the default or bare pointer calling
+  /// convention. Calling convention sensitive types, such as MemRefType and
+  /// UnrankedMemRefType, are converted following the specific rules for the
+  /// calling convention. Calling convention independent types are converted
+  /// following the default LLVM type conversions.
+  Type convertCallingConventionType(Type type);
+
+  /// Promote the bare pointers in 'values' that resulted from memrefs to
+  /// descriptors. 'stdTypes' holds the types of 'values' before the conversion
+  /// to the LLVM-IR dialect (i.e., MemRefType, or any other Standard type).
+  void promoteBarePtrsToDescriptors(ConversionPatternRewriter &rewriter,
+                                    Location loc, ArrayRef<Type> stdTypes,
+                                    SmallVectorImpl<Value> &values);
+
   /// Returns the MLIR context.
   MLIRContext &getContext();
 
-
   /// Returns the LLVM dialect.
   LLVM::LLVMDialect *getDialect() { return llvmDialect; }
 
@@ -179,6 +193,9 @@ class LLVMTypeConverter : public TypeConverter {
   // runtime rank and a pointer to the static ranked memref desc
   Type convertUnrankedMemRefType(UnrankedMemRefType type);
 
+  /// Convert a memref type to a bare pointer to the memref element type.
+  Type convertMemRefToBarePtr(BaseMemRefType type);
+
   // Convert a 1D vector type into an LLVM vector type.
   Type convertVectorType(VectorType type);
 

diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 186c8ec48fa5..c77c0b529caf 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -80,37 +80,12 @@ LogicalResult mlir::structFuncArgTypeConverter(LLVMTypeConverter &converter,
   return success();
 }
 
-/// Convert a MemRef type to a bare pointer to the MemRef element type.
-static Type convertMemRefTypeToBarePtr(LLVMTypeConverter &converter,
-                                       MemRefType type) {
-  int64_t offset;
-  SmallVector<int64_t, 4> strides;
-  if (failed(getStridesAndOffset(type, strides, offset)))
-    return {};
-
-  LLVM::LLVMType elementType =
-      unwrap(converter.convertType(type.getElementType()));
-  if (!elementType)
-    return {};
-  return elementType.getPointerTo(type.getMemorySpace());
-}
-
 /// Callback to convert function argument types. It converts MemRef function
 /// arguments to bare pointers to the MemRef element type.
 LogicalResult mlir::barePtrFuncArgTypeConverter(LLVMTypeConverter &converter,
                                                 Type type,
                                                 SmallVectorImpl<Type> &result) {
-  // TODO: Add support for unranked memref.
-  if (auto memrefTy = type.dyn_cast<MemRefType>()) {
-    auto llvmTy = convertMemRefTypeToBarePtr(converter, memrefTy);
-    if (!llvmTy)
-      return failure();
-
-    result.push_back(llvmTy);
-    return success();
-  }
-
-  auto llvmTy = converter.convertType(type);
+  auto llvmTy = converter.convertCallingConventionType(type);
   if (!llvmTy)
     return failure();
 
@@ -272,14 +247,14 @@ SmallVector<Type, 2> LLVMTypeConverter::convertUnrankedMemRefSignature() {
 // Function has one VoidType result.  If MLIR Function has more than one result,
 // they are packed into an LLVM StructType in their order of appearance.
 LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature(
-    FunctionType type, bool isVariadic,
+    FunctionType funcTy, bool isVariadic,
     LLVMTypeConverter::SignatureConversion &result) {
   // Select the argument converter depending on the calling convention.
   auto funcArgConverter = options.useBarePtrCallConv
                               ? barePtrFuncArgTypeConverter
                               : structFuncArgTypeConverter;
   // Convert argument types one by one and check for errors.
-  for (auto &en : llvm::enumerate(type.getInputs())) {
+  for (auto &en : llvm::enumerate(funcTy.getInputs())) {
     Type type = en.value();
     SmallVector<Type, 8> converted;
     if (failed(funcArgConverter(*this, type, converted)))
@@ -296,9 +271,9 @@ LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature(
   // if it returns one element, convert it, otherwise pack the result types into
   // a struct.
   LLVM::LLVMType resultType =
-      type.getNumResults() == 0
+      funcTy.getNumResults() == 0
           ? LLVM::LLVMType::getVoidTy(&getContext())
-          : unwrap(packFunctionResults(type.getResults()));
+          : unwrap(packFunctionResults(funcTy.getResults()));
   if (!resultType)
     return {};
   return LLVM::LLVMType::getFunctionTy(resultType, argTypes, isVariadic);
@@ -394,6 +369,36 @@ Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) {
   return LLVM::LLVMType::getStructTy(rankTy, ptrTy);
 }
 
+/// Convert a memref type to a bare pointer to the memref element type.
+Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) {
+  if (type.isa<UnrankedMemRefType>())
+    // Unranked memref is not supported in the bare pointer calling convention.
+    return {};
+
+  // Check that the memref has static shape, strides and offset. Otherwise, it
+  // cannot be lowered to a bare pointer.
+  auto memrefTy = type.cast<MemRefType>();
+  if (!memrefTy.hasStaticShape())
+    return {};
+
+  int64_t offset = 0;
+  SmallVector<int64_t, 4> strides;
+  if (failed(getStridesAndOffset(memrefTy, strides, offset)))
+    return {};
+
+  for (int64_t stride : strides)
+    if (ShapedType::isDynamicStrideOrOffset(stride))
+      return {};
+
+  if (ShapedType::isDynamicStrideOrOffset(offset))
+    return {};
+
+  LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
+  if (!elementType)
+    return {};
+  return elementType.getPointerTo(type.getMemorySpace());
+}
+
 // Convert an n-D vector type to an LLVM vector type via (n-1)-D array type when
 // n > 1.
 // For example, `vector<4 x f32>` converts to `!llvm.type<"<4 x float>">` and
@@ -410,6 +415,37 @@ Type LLVMTypeConverter::convertVectorType(VectorType type) {
   return vectorType;
 }
 
+/// Convert a type in the context of the default or bare pointer calling
+/// convention. Calling convention sensitive types, such as MemRefType and
+/// UnrankedMemRefType, are converted following the specific rules for the
+/// calling convention. Calling convention independent types are converted
+/// following the default LLVM type conversions.
+Type LLVMTypeConverter::convertCallingConventionType(Type type) {
+  if (options.useBarePtrCallConv)
+    if (auto memrefTy = type.dyn_cast<BaseMemRefType>())
+      return convertMemRefToBarePtr(memrefTy);
+
+  return convertType(type);
+}
+
+/// Promote the bare pointers in 'values' that resulted from memrefs to
+/// descriptors. 'stdTypes' holds the types of 'values' before the conversion
+/// to the LLVM-IR dialect (i.e., MemRefType, or any other Standard type).
+void LLVMTypeConverter::promoteBarePtrsToDescriptors(
+    ConversionPatternRewriter &rewriter, Location loc, ArrayRef<Type> stdTypes,
+    SmallVectorImpl<Value> &values) {
+  assert(stdTypes.size() == values.size() &&
+         "The number of types and values doesn't match");
+  for (unsigned i = 0, end = values.size(); i < end; ++i) {
+    Type stdTy = stdTypes[i];
+    if (auto memrefTy = stdTy.dyn_cast<MemRefType>())
+      values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this,
+                                                    memrefTy, values[i]);
+    else
+      llvm_unreachable("Unranked memrefs are not supported");
+  }
+}
+
 ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName,
                                            MLIRContext *context,
                                            LLVMTypeConverter &typeConverter,
@@ -1088,18 +1124,6 @@ namespace {
 struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> {
 protected:
   using ConvertOpToLLVMPattern<FuncOp>::ConvertOpToLLVMPattern;
-  using UnsignedTypePair = std::pair<unsigned, Type>;
-
-  // Gather the positions and types of memref-typed arguments in a given
-  // FunctionType.
-  void getMemRefArgIndicesAndTypes(
-      FunctionType type, SmallVectorImpl<UnsignedTypePair> &argsInfo) const {
-    argsInfo.reserve(type.getNumInputs());
-    for (auto en : llvm::enumerate(type.getInputs())) {
-      if (en.value().isa<MemRefType, UnrankedMemRefType>())
-        argsInfo.push_back({en.index(), en.value()});
-    }
-  }
 
   // Convert input FuncOp to LLVMFuncOp by using the LLVMTypeConverter provided
   // to this legalization pattern.
@@ -1192,11 +1216,10 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase {
                   ConversionPatternRewriter &rewriter) const override {
     auto funcOp = cast<FuncOp>(op);
 
-    // Store the positions and type of memref-typed arguments so that we can
-    // promote them to MemRef descriptor structs at the beginning of the
-    // function.
-    SmallVector<UnsignedTypePair, 4> promotedArgsInfo;
-    getMemRefArgIndicesAndTypes(funcOp.getType(), promotedArgsInfo);
+    // Store the types of memref-typed arguments before the conversion so that
+    // we can promote them to MemRef descriptors at the beginning of the function.
+    SmallVector<Type, 8> oldArgTypes =
+        llvm::to_vector<8>(funcOp.getType().getInputs());
 
     auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
     if (!newFuncOp)
@@ -1206,27 +1229,42 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase {
       return success();
     }
 
-    // Promote bare pointers from MemRef arguments to a MemRef descriptor struct
-    // at the beginning of the function so that all the MemRefs in the function
-    // have a uniform representation.
-    Block *firstBlock = &newFuncOp.getBody().front();
-    rewriter.setInsertionPoint(firstBlock, firstBlock->begin());
-    auto funcLoc = funcOp.getLoc();
-    for (const auto &argInfo : promotedArgsInfo) {
-      // TODO: Add support for unranked MemRefs.
-      if (auto memrefType = argInfo.second.dyn_cast<MemRefType>()) {
-        // Replace argument with a placeholder (undef), promote argument to a
-        // MemRef descriptor and replace placeholder with the last instruction
-        // of the MemRef descriptor. The placeholder is needed to avoid
-        // replacing argument uses in the MemRef descriptor instructions.
-        BlockArgument arg = firstBlock->getArgument(argInfo.first);
-        Value placeHolder =
-            rewriter.create<LLVM::UndefOp>(funcLoc, arg.getType());
-        rewriter.replaceUsesOfBlockArgument(arg, placeHolder);
-        auto desc = MemRefDescriptor::fromStaticShape(
-            rewriter, funcLoc, typeConverter, memrefType, arg);
-        rewriter.replaceOp(placeHolder.getDefiningOp(), {desc});
-      }
+    // Promote bare pointers from memref arguments to memref descriptors at the
+    // beginning of the function so that all the memrefs in the function have a
+    // uniform representation.
+    Block *entryBlock = &newFuncOp.getBody().front();
+    auto blockArgs = entryBlock->getArguments();
+    assert(blockArgs.size() == oldArgTypes.size() &&
+           "The number of arguments and types doesn't match");
+
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(entryBlock);
+    for (auto it : llvm::zip(blockArgs, oldArgTypes)) {
+      BlockArgument arg = std::get<0>(it);
+      Type argTy = std::get<1>(it);
+
+      // Unranked memrefs are not supported in the bare pointer calling
+      // convention. We should have bailed out before in the presence of
+      // unranked memrefs.
+      assert(!argTy.isa<UnrankedMemRefType>() &&
+             "Unranked memref is not supported");
+      auto memrefTy = argTy.dyn_cast<MemRefType>();
+      if (!memrefTy)
+        continue;
+
+      // Replace barePtr with a placeholder (undef), promote barePtr to a ranked
+      // or unranked memref descriptor and replace placeholder with the last
+      // instruction of the memref descriptor.
+      // TODO: The placeholder is needed to avoid replacing barePtr uses in the
+      // MemRef descriptor instructions. We may want to have a utility in the
+      // rewriter to properly handle this use case.
+      Location loc = op->getLoc();
+      auto placeholder = rewriter.create<LLVM::UndefOp>(loc, memrefTy);
+      rewriter.replaceUsesOfBlockArgument(arg, placeholder);
+
+      Value desc = MemRefDescriptor::fromStaticShape(
+          rewriter, loc, typeConverter, memrefTy, arg);
+      rewriter.replaceOp(placeholder, {desc});
     }
 
     rewriter.eraseOp(op);
@@ -2138,12 +2176,22 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
             rewriter.getI64ArrayAttr(i)));
       }
     }
-    if (failed(copyUnrankedDescriptors(
-            rewriter, op->getLoc(), this->typeConverter, op->getResultTypes(),
-            results, /*toDynamic=*/false)))
+
+    if (this->typeConverter.getOptions().useBarePtrCallConv) {
+      // For the bare-ptr calling convention, promote memref results to
+      // descriptors.
+      assert(results.size() == resultTypes.size() &&
+             "The number of arguments and types doesn't match");
+      this->typeConverter.promoteBarePtrsToDescriptors(rewriter, op->getLoc(),
+                                                       resultTypes, results);
+    } else if (failed(copyUnrankedDescriptors(rewriter, op->getLoc(),
+                                              this->typeConverter, resultTypes,
+                                              results,
+                                              /*toDynamic=*/false))) {
       return failure();
-    rewriter.replaceOp(op, results);
+    }
 
+    rewriter.replaceOp(op, results);
     return success();
   }
 };
@@ -2706,11 +2754,32 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<ReturnOp> {
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
     unsigned numArguments = op->getNumOperands();
-    auto updatedOperands = llvm::to_vector<4>(operands);
-    copyUnrankedDescriptors(rewriter, op->getLoc(), typeConverter,
-                            op->getOperands().getTypes(), updatedOperands,
-                            /*toDynamic=*/true);
+    SmallVector<Value, 4> updatedOperands;
+
+    if (typeConverter.getOptions().useBarePtrCallConv) {
+      // For the bare-ptr calling convention, extract the aligned pointer to
+      // be returned from the memref descriptor.
+      for (auto it : llvm::zip(op->getOperands(), operands)) {
+        Type oldTy = std::get<0>(it).getType();
+        Value newOperand = std::get<1>(it);
+        if (oldTy.isa<MemRefType>()) {
+          MemRefDescriptor memrefDesc(newOperand);
+          newOperand = memrefDesc.alignedPtr(rewriter, loc);
+        } else if (oldTy.isa<UnrankedMemRefType>()) {
+          // Unranked memref is not supported in the bare pointer calling
+          // convention.
+          return failure();
+        }
+        updatedOperands.push_back(newOperand);
+      }
+    } else {
+      updatedOperands = llvm::to_vector<4>(operands);
+      copyUnrankedDescriptors(rewriter, loc, typeConverter,
+                              op->getOperands().getTypes(), updatedOperands,
+                              /*toDynamic=*/true);
+    }
 
     // If ReturnOp has 0 or 1 operand, create it and return immediately.
     if (numArguments == 0) {
@@ -2729,10 +2798,10 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<ReturnOp> {
     auto packedType = typeConverter.packFunctionResults(
         llvm::to_vector<4>(op->getOperandTypes()));
 
-    Value packed = rewriter.create<LLVM::UndefOp>(op->getLoc(), packedType);
+    Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
     for (unsigned i = 0; i < numArguments; ++i) {
       packed = rewriter.create<LLVM::InsertValueOp>(
-          op->getLoc(), packedType, packed, updatedOperands[i],
+          loc, packedType, packed, updatedOperands[i],
           rewriter.getI64ArrayAttr(i));
     }
     rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
@@ -3380,17 +3449,21 @@ void mlir::populateStdToLLVMConversionPatterns(
   populateStdToLLVMMemoryConversionPatterns(converter, patterns);
 }
 
-// Create an LLVM IR structure type if there is more than one result.
+/// Convert a non-empty list of types to be returned from a function into a
+/// supported LLVM IR type.  In particular, if more than one value is returned,
+/// create an LLVM IR structure type with elements that correspond to each of
+/// the MLIR types converted with `convertType`.
 Type LLVMTypeConverter::packFunctionResults(ArrayRef<Type> types) {
   assert(!types.empty() && "expected non-empty list of type");
 
   if (types.size() == 1)
-    return convertType(types.front());
+    return convertCallingConventionType(types.front());
 
   SmallVector<LLVM::LLVMType, 8> resultTypes;
   resultTypes.reserve(types.size());
   for (auto t : types) {
-    auto converted = convertType(t).dyn_cast_or_null<LLVM::LLVMType>();
+    auto converted =
+        convertCallingConventionType(t).dyn_cast_or_null<LLVM::LLVMType>();
     if (!converted)
       return {};
     resultTypes.push_back(converted);
@@ -3426,16 +3499,27 @@ SmallVector<Value, 4> LLVMTypeConverter::promoteOperands(Location loc,
     auto operand = std::get<0>(it);
     auto llvmOperand = std::get<1>(it);
 
-    if (operand.getType().isa<UnrankedMemRefType>()) {
-      UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand,
-                                       promotedOperands);
-      continue;
-    }
-    if (auto memrefType = operand.getType().dyn_cast<MemRefType>()) {
-      MemRefDescriptor::unpack(builder, loc, llvmOperand,
-                               operand.getType().cast<MemRefType>(),
-                               promotedOperands);
-      continue;
+    if (options.useBarePtrCallConv) {
+      // For the bare-ptr calling convention, we only have to extract the
+      // aligned pointer of a memref.
+      if (auto memrefType = operand.getType().dyn_cast<MemRefType>()) {
+        MemRefDescriptor desc(llvmOperand);
+        llvmOperand = desc.alignedPtr(builder, loc);
+      } else if (operand.getType().isa<UnrankedMemRefType>()) {
+        llvm_unreachable("Unranked memrefs are not supported");
+      }
+    } else {
+      if (operand.getType().isa<UnrankedMemRefType>()) {
+        UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand,
+                                         promotedOperands);
+        continue;
+      }
+      if (auto memrefType = operand.getType().dyn_cast<MemRefType>()) {
+        MemRefDescriptor::unpack(builder, loc, llvmOperand,
+                                 operand.getType().cast<MemRefType>(),
+                                 promotedOperands);
+        continue;
+      }
     }
 
     promotedOperands.push_back(llvmOperand);

diff --git a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
index 5cccca3795b3..5dd36ba6d2ac 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
@@ -14,13 +14,13 @@ func @check_noalias(%static : memref<2xf32> {llvm.noalias = true}, %other : memr
 // CHECK-COUNT-5: !llvm.i64
 // CHECK-SAME: -> !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
 // BAREPTR-LABEL: func @check_static_return
-// BAREPTR-SAME: (%[[arg:.*]]: !llvm.ptr<float>) -> !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)> {
+// BAREPTR-SAME: (%[[arg:.*]]: !llvm.ptr<float>) -> !llvm.ptr<float> {
 func @check_static_return(%static : memref<32x18xf32>) -> memref<32x18xf32> {
 // CHECK:  llvm.return %{{.*}} : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
 
 // BAREPTR: %[[udf:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-// BAREPTR-NEXT: %[[base:.*]] = llvm.insertvalue %[[arg]], %[[udf]][0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-// BAREPTR-NEXT: %[[aligned:.*]] = llvm.insertvalue %[[arg]], %[[base]][1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
+// BAREPTR-NEXT: %[[base0:.*]] = llvm.insertvalue %[[arg]], %[[udf]][0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
+// BAREPTR-NEXT: %[[aligned:.*]] = llvm.insertvalue %[[arg]], %[[base0]][1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
 // BAREPTR-NEXT: %[[val0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
 // BAREPTR-NEXT: %[[ins0:.*]] = llvm.insertvalue %[[val0]], %[[aligned]][2] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
 // BAREPTR-NEXT: %[[val1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64
@@ -31,7 +31,8 @@ func @check_static_return(%static : memref<32x18xf32>) -> memref<32x18xf32> {
 // BAREPTR-NEXT: %[[ins3:.*]] = llvm.insertvalue %[[val3]], %[[ins2]][3, 1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
 // BAREPTR-NEXT: %[[val4:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
 // BAREPTR-NEXT: %[[ins4:.*]] = llvm.insertvalue %[[val4]], %[[ins3]][4, 1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-// BAREPTR-NEXT: llvm.return %[[ins4]] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
+// BAREPTR-NEXT: %[[base1:.*]] = llvm.extractvalue %[[ins4]][1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
+// BAREPTR-NEXT: llvm.return %[[base1]] : !llvm.ptr<float>
   return %static : memref<32x18xf32>
 }
 
@@ -42,13 +43,13 @@ func @check_static_return(%static : memref<32x18xf32>) -> memref<32x18xf32> {
 // CHECK-COUNT-5: !llvm.i64
 // CHECK-SAME: -> !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
 // BAREPTR-LABEL: func @check_static_return_with_offset
-// BAREPTR-SAME: (%[[arg:.*]]: !llvm.ptr<float>) -> !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)> {
+// BAREPTR-SAME: (%[[arg:.*]]: !llvm.ptr<float>) -> !llvm.ptr<float> {
 func @check_static_return_with_offset(%static : memref<32x18xf32, offset:7, strides:[22,1]>) -> memref<32x18xf32, offset:7, strides:[22,1]> {
 // CHECK:  llvm.return %{{.*}} : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
 
 // BAREPTR: %[[udf:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-// BAREPTR-NEXT: %[[base:.*]] = llvm.insertvalue %[[arg]], %[[udf]][0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-// BAREPTR-NEXT: %[[aligned:.*]] = llvm.insertvalue %[[arg]], %[[base]][1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
+// BAREPTR-NEXT: %[[base0:.*]] = llvm.insertvalue %[[arg]], %[[udf]][0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
+// BAREPTR-NEXT: %[[aligned:.*]] = llvm.insertvalue %[[arg]], %[[base0]][1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
 // BAREPTR-NEXT: %[[val0:.*]] = llvm.mlir.constant(7 : index) : !llvm.i64
 // BAREPTR-NEXT: %[[ins0:.*]] = llvm.insertvalue %[[val0]], %[[aligned]][2] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
 // BAREPTR-NEXT: %[[val1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64
@@ -59,14 +60,15 @@ func @check_static_return_with_offset(%static : memref<32x18xf32, offset:7, stri
 // BAREPTR-NEXT: %[[ins3:.*]] = llvm.insertvalue %[[val3]], %[[ins2]][3, 1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
 // BAREPTR-NEXT: %[[val4:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
 // BAREPTR-NEXT: %[[ins4:.*]] = llvm.insertvalue %[[val4]], %[[ins3]][4, 1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-// BAREPTR-NEXT: llvm.return %[[ins4]] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
+// BAREPTR-NEXT: %[[base1:.*]] = llvm.extractvalue %[[ins4]][1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
+// BAREPTR-NEXT: llvm.return %[[base1]] : !llvm.ptr<float>
   return %static : memref<32x18xf32, offset:7, strides:[22,1]>
 }
 
 // -----
 
 // CHECK-LABEL: func @zero_d_alloc() -> !llvm.struct<(ptr<float>, ptr<float>, i64)> {
-// BAREPTR-LABEL: func @zero_d_alloc() -> !llvm.struct<(ptr<float>, ptr<float>, i64)> {
+// BAREPTR-LABEL: func @zero_d_alloc() -> !llvm.ptr<float> {
 func @zero_d_alloc() -> memref<f32> {
 // CHECK-NEXT:  llvm.mlir.constant(1 : index) : !llvm.i64
 // CHECK-NEXT:  %[[null:.*]] = llvm.mlir.null : !llvm.ptr<float>
@@ -174,7 +176,7 @@ func @aligned_1d_alloc() -> memref<42xf32> {
 // -----
 
 // CHECK-LABEL: func @static_alloc() -> !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)> {
-// BAREPTR-LABEL: func @static_alloc() -> !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)> {
+// BAREPTR-LABEL: func @static_alloc() -> !llvm.ptr<float> {
 func @static_alloc() -> memref<32x18xf32> {
 //      CHECK:  %[[sz1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64
 // CHECK-NEXT:  %[[sz2:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64
@@ -388,3 +390,29 @@ func @static_memref_dim(%static : memref<42x32x15x13x27xf32>) {
   %4 = dim %static, %c4 : memref<42x32x15x13x27xf32>
   return
 }
+
+// -----
+
+// BAREPTR: llvm.func @foo(!llvm.ptr<i8>) -> !llvm.ptr<i8>
+func @foo(memref<10xi8>) -> memref<20xi8>
+
+// BAREPTR-LABEL: func @check_memref_func_call
+// BAREPTR-SAME:    %[[in:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
+func @check_memref_func_call(%in : memref<10xi8>) -> memref<20xi8> {
+  // BAREPTR:         %[[inDesc:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 0]
+  // BAREPTR-NEXT:    %[[barePtr:.*]] = llvm.extractvalue %[[inDesc]][1] : !llvm.struct<(ptr<i8>, ptr<i8>, i64, array<1 x i64>, array<1 x i64>)>
+  // BAREPTR-NEXT:    %[[call:.*]] = llvm.call @foo(%[[barePtr]]) : (!llvm.ptr<i8>) -> !llvm.ptr<i8>
+  // BAREPTR-NEXT:    %[[desc0:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<i8>, ptr<i8>, i64, array<1 x i64>, array<1 x i64>)>
+  // BAREPTR-NEXT:    %[[desc1:.*]] = llvm.insertvalue %[[call]], %[[desc0]][0] : !llvm.struct<(ptr<i8>, ptr<i8>, i64, array<1 x i64>, array<1 x i64>)>
+  // BAREPTR-NEXT:    %[[desc2:.*]] = llvm.insertvalue %[[call]], %[[desc1]][1] : !llvm.struct<(ptr<i8>, ptr<i8>, i64, array<1 x i64>, array<1 x i64>)>
+  // BAREPTR-NEXT:    %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+  // BAREPTR-NEXT:    %[[desc4:.*]] = llvm.insertvalue %[[c0]], %[[desc2]][2] : !llvm.struct<(ptr<i8>, ptr<i8>, i64, array<1 x i64>, array<1 x i64>)>
+  // BAREPTR-NEXT:    %[[c20:.*]] = llvm.mlir.constant(20 : index) : !llvm.i64
+  // BAREPTR-NEXT:    %[[desc6:.*]] = llvm.insertvalue %[[c20]], %[[desc4]][3, 0] : !llvm.struct<(ptr<i8>, ptr<i8>, i64, array<1 x i64>, array<1 x i64>)>
+  // BAREPTR-NEXT:    %[[c1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+  // BAREPTR-NEXT:    %[[outDesc:.*]] = llvm.insertvalue %[[c1]], %[[desc6]][4, 0] : !llvm.struct<(ptr<i8>, ptr<i8>, i64, array<1 x i64>, array<1 x i64>)>
+  %res = call @foo(%in) : (memref<10xi8>) -> (memref<20xi8>)
+  // BAREPTR-NEXT:    %[[res:.*]] = llvm.extractvalue %[[outDesc]][1] : !llvm.struct<(ptr<i8>, ptr<i8>, i64, array<1 x i64>, array<1 x i64>)>
+  // BAREPTR-NEXT:    llvm.return %[[res]] : !llvm.ptr<i8>
+  return %res : memref<20xi8>
+}


        

