[Mlir-commits] [mlir] 6323065 - [mlir] support returning unranked memrefs

Alex Zinenko llvmlistbot at llvm.org
Fri Jun 26 06:37:46 PDT 2020


Author: Alex Zinenko
Date: 2020-06-26T15:37:37+02:00
New Revision: 6323065fd6026de926b15bb609f4601e366a300c

URL: https://github.com/llvm/llvm-project/commit/6323065fd6026de926b15bb609f4601e366a300c
DIFF: https://github.com/llvm/llvm-project/commit/6323065fd6026de926b15bb609f4601e366a300c.diff

LOG: [mlir] support returning unranked memrefs

Initially, unranked memref descriptors in the LLVM dialect were designed only
to be passed into functions; an assertion in the standard-to-LLVM conversion
guarded against returning unranked memrefs from functions. This is insufficient
for functions that return an unranked memref whose rank the caller does not
know in advance: the caller cannot allocate the descriptor and pass it in as an
argument.

Introduce a calling convention for returning unranked memref descriptors as
follows. An unranked memref descriptor always points to a ranked memref
descriptor stored on the stack of the current function. When an unranked memref
descriptor is returned from a function, the ranked memref descriptor it points
to is copied to dynamically allocated memory, the ownership of which is
transferred to the caller. The caller is responsible for deallocating the
dynamically allocated memory and for copying the pointed-to ranked memref
descriptor onto its stack.
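
In rough C++ terms, the lowered code implements the following ownership flow (a
minimal sketch with illustrative names, not actual MLIR or generated code; the
descriptor size is computed at run time from the rank):

#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>

// Unranked descriptor: a rank plus a type-erased pointer to the ranked
// descriptor { T*, T*, index, index[rank], index[rank] }.
struct UnrankedDescriptor {
  int64_t rank;
  void *ranked;
};

// Callee side (std.return lowering): move the ranked descriptor to the heap
// so it survives the callee's stack frame; ownership passes to the caller.
UnrankedDescriptor returnSide(UnrankedDescriptor d, std::size_t descBytes) {
  void *heapCopy = std::malloc(descBytes);
  std::memcpy(heapCopy, d.ranked, descBytes);
  return {d.rank, heapCopy};
}

// Caller side (std.call lowering): copy the ranked descriptor back onto the
// caller's stack (`stackBuf` stands in for llvm.alloca) and free the heap
// copy, restoring the descriptor-on-stack invariant.
UnrankedDescriptor callSide(UnrankedDescriptor r, void *stackBuf,
                            std::size_t descBytes) {
  std::memcpy(stackBuf, r.ranked, descBytes);
  std::free(r.ranked);
  return {r.rank, stackBuf};
}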

Provide default lowerings for std.return, std.call and std.call_indirect that
maintain the convention defined above.

This convention is additionally exercised by a runtime test to guard against
memory errors.

Differential Revision: https://reviews.llvm.org/D82647

Added: 
    

Modified: 
    mlir/docs/ConversionToLLVMDialect.md
    mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/test/Conversion/StandardToLLVM/calling-convention.mlir
    mlir/test/Dialect/LLVMIR/roundtrip.mlir
    mlir/test/Target/llvmir-intrinsics.mlir
    mlir/test/mlir-cpu-runner/unranked_memref.mlir

Removed: 
    


################################################################################
diff --git a/mlir/docs/ConversionToLLVMDialect.md b/mlir/docs/ConversionToLLVMDialect.md
index 15d09cb4b87a..e65df4444b80 100644
--- a/mlir/docs/ConversionToLLVMDialect.md
+++ b/mlir/docs/ConversionToLLVMDialect.md
@@ -246,7 +246,7 @@ func @bar() {
 }
 ```
 
-### Calling Convention for `memref`
+### Calling Convention for Ranked `memref`
 
 Function _arguments_ of `memref` type, ranked or unranked, are _expanded_ into a
 list of arguments of non-aggregate types that the memref descriptor defined
@@ -317,7 +317,9 @@ llvm.func @bar() {
 
 ```
 
-For **unranked** memrefs, the list of function arguments always contains two
+### Calling Convention for Unranked `memref`
+
+For unranked memrefs, the list of function arguments always contains two
 elements, same as the unranked memref descriptor: an integer rank, and a
 type-erased (`!llvm<"i8*">`) pointer to the ranked memref descriptor. Note that
 while the _calling convention_ does not require stack allocation, _casting_ to
@@ -369,6 +371,20 @@ llvm.func @bar() {
 }
 ```
 
+**Lifetime.** The second element of the unranked memref descriptor points to
+some memory in which the ranked memref descriptor is stored. By convention,
+this memory is allocated on the stack and has the lifetime of the function.
+(*Note:* due to the function-length lifetime, creating multiple unranked
+memref descriptors, e.g., in a loop, may lead to stack overflows.) If an
+unranked descriptor has to be returned from a function, the ranked descriptor
+it points to is copied into dynamically allocated memory, and the pointer in
+the unranked descriptor is updated accordingly. The allocation happens
+immediately before returning. It is the responsibility of the caller to free
+the dynamically allocated memory. The default conversion of `std.call` and
+`std.call_indirect` copies the ranked descriptor to newly allocated memory on
+the caller's stack. This preserves the convention that the ranked memref
+descriptor pointed to by an unranked memref descriptor lives on the stack.
+
 *This convention may or may not apply if the conversion of MemRef types is
 overridden by the user.*
 

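The stack-overflow note in the documentation above has a direct C analogy:
like alloca, the per-descriptor stack allocation is released only when the
function returns, not when the enclosing scope ends. A hypothetical
illustration (not generated code; <alloca.h> is glibc-specific):

#include <alloca.h>
#include <cstddef>
#include <cstring>

void castInLoop(const char *rankedDesc, std::size_t descBytes, int n) {
  for (int i = 0; i < n; ++i) {
    // Stands in for the llvm.alloca emitted per cast to an unranked memref;
    // every iteration's allocation stays live until the function returns.
    void *copy = alloca(descBytes);
    std::memcpy(copy, rankedDesc, descBytes);
  }
}
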
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index a7e4ff2f52cf..c96341094af2 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -129,6 +129,9 @@ class LLVMTypeConverter : public TypeConverter {
   /// Gets the bitwidth of the index type when converted to LLVM.
   unsigned getIndexTypeBitwidth() { return customizations.indexBitwidth; }
 
+  /// Gets the pointer bitwidth.
+  unsigned getPointerBitwidth(unsigned addressSpace = 0);
+
 protected:
   /// LLVM IR module used to parse/create types.
   llvm::Module *module;
@@ -386,6 +389,13 @@ class UnrankedMemRefDescriptor : public StructBuilder {
   /// Returns the number of non-aggregate values that would be produced by
   /// `unpack`.
   static unsigned getNumUnpackedValues() { return 2; }
+
+  /// Builds IR computing the sizes in bytes (suitable for opaque allocation)
+  /// of the given descriptors and appends the corresponding values to `sizes`.
+  static void computeSizes(OpBuilder &builder, Location loc,
+                           LLVMTypeConverter &typeConverter,
+                           ArrayRef<UnrankedMemRefDescriptor> values,
+                           SmallVectorImpl<Value> &sizes);
 };
 
 /// Base class for operation conversions targeting the LLVM IR dialect. Provides

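The arithmetic emitted by computeSizes (implemented in StandardToLLVM.cpp
below) reduces to the following, assuming the densely packed layout
{ type*, type*, index, index[rank], index[rank] }. This is a sketch of the
formula only, not of the builder-based IR construction:

#include <cstdint>

// 2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index), in bytes.
uint64_t rankedDescriptorBytes(uint64_t rank, uint64_t ptrBytes,
                               uint64_t indexBytes) {
  return 2 * ptrBytes + (1 + 2 * rank) * indexBytes;
}
// E.g., rank 2 with 64-bit pointers and indices: 2 * 8 + 5 * 8 = 56 bytes.
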
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 30e34440c2dd..3d8e52cecf94 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -794,6 +794,13 @@ def LLVM_PowOp : LLVM_BinarySameArgsIntrinsicOp<"pow">;
 def LLVM_BitReverseOp : LLVM_UnaryIntrinsicOp<"bitreverse">;
 def LLVM_CtPopOp : LLVM_UnaryIntrinsicOp<"ctpop">;
 
+def LLVM_MemcpyOp : LLVM_ZeroResultIntrOp<"memcpy", [0, 1, 2]>,
+                    Arguments<(ins LLVM_Type:$dst, LLVM_Type:$src,
+                               LLVM_Type:$len, LLVM_Type:$isVolatile)>;
+def LLVM_MemcpyInlineOp : LLVM_ZeroResultIntrOp<"memcpy.inline", [0, 1, 2]>,
+                          Arguments<(ins LLVM_Type:$dst, LLVM_Type:$src,
+                                     LLVM_Type:$len, LLVM_Type:$isVolatile)>;
+
 //
 // Vector Reductions.
 //

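The new intrinsic wrappers are used from lowerings via the generated builders,
as in the conversion code below. A minimal sketch (assuming the usual MLIR
includes are available; `emitByteCopy` is an illustrative helper, not part of
the patch):

// Copies `len` bytes from `src` to `dst`; translates to a call of
// @llvm.memcpy.p0i8.p0i8.* when exporting to LLVM IR.
static void emitByteCopy(OpBuilder &builder, Location loc, Value dst,
                         Value src, Value len, Value isVolatile) {
  builder.create<LLVM::MemcpyOp>(loc, dst, src, len, isVolatile);
}
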
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 19c451fa3fe9..9376d53dc994 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -24,6 +24,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Support/LogicalResult.h"
+#include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/Passes.h"
 #include "mlir/Transforms/Utils.h"
@@ -184,6 +185,10 @@ LLVM::LLVMType LLVMTypeConverter::getIndexType() {
   return LLVM::LLVMType::getIntNTy(llvmDialect, getIndexTypeBitwidth());
 }
 
+unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) {
+  return module->getDataLayout().getPointerSizeInBits(addressSpace);
+}
+
 Type LLVMTypeConverter::convertIndexType(IndexType type) {
   return getIndexType();
 }
@@ -769,6 +774,51 @@ void UnrankedMemRefDescriptor::unpack(OpBuilder &builder, Location loc,
   results.push_back(d.memRefDescPtr(builder, loc));
 }
 
+void UnrankedMemRefDescriptor::computeSizes(
+    OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter,
+    ArrayRef<UnrankedMemRefDescriptor> values, SmallVectorImpl<Value> &sizes) {
+  if (values.empty())
+    return;
+
+  // Cache the index type.
+  LLVM::LLVMType indexType = typeConverter.getIndexType();
+
+  // Initialize shared constants.
+  Value one = createIndexAttrConstant(builder, loc, indexType, 1);
+  Value two = createIndexAttrConstant(builder, loc, indexType, 2);
+  Value pointerSize = createIndexAttrConstant(
+      builder, loc, indexType, ceilDiv(typeConverter.getPointerBitwidth(), 8));
+  Value indexSize =
+      createIndexAttrConstant(builder, loc, indexType,
+                              ceilDiv(typeConverter.getIndexTypeBitwidth(), 8));
+
+  sizes.reserve(sizes.size() + values.size());
+  for (UnrankedMemRefDescriptor desc : values) {
+    // Emit IR computing the memory necessary to store the descriptor. This
+    // assumes the descriptor to be
+    //   { type*, type*, index, index[rank], index[rank] }
+    // and densely packed, so the total size is
+    //   2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index).
+    // TODO: consider including the actual size (including eventual padding due
+    // to data layout) into the unranked descriptor.
+    Value doublePointerSize =
+        builder.create<LLVM::MulOp>(loc, indexType, two, pointerSize);
+
+    // (1 + 2 * rank) * sizeof(index)
+    Value rank = desc.rank(builder, loc);
+    Value doubleRank = builder.create<LLVM::MulOp>(loc, indexType, two, rank);
+    Value doubleRankIncremented =
+        builder.create<LLVM::AddOp>(loc, indexType, doubleRank, one);
+    Value rankIndexSize = builder.create<LLVM::MulOp>(
+        loc, indexType, doubleRankIncremented, indexSize);
+
+    // Total allocation size.
+    Value allocationSize = builder.create<LLVM::AddOp>(
+        loc, indexType, doublePointerSize, rankIndexSize);
+    sizes.push_back(allocationSize);
+  }
+}
+
 LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const {
   return *typeConverter.getDialect();
 }
@@ -1863,6 +1913,104 @@ struct AllocOpLowering : public AllocLikeOpLowering<AllocOp> {
 
 using AllocaOpLowering = AllocLikeOpLowering<AllocaOp>;
 
+/// For each operand that was originally an unranked memref descriptor, copies
+/// the ranked descriptor it points to into dynamically allocated memory (if
+/// `toDynamic` is set) or from dynamically allocated memory onto the stack
+/// (otherwise).
+static LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc,
+                                             LLVMTypeConverter &typeConverter,
+                                             TypeRange origTypes,
+                                             SmallVectorImpl<Value> &operands,
+                                             bool toDynamic) {
+  assert(origTypes.size() == operands.size() &&
+         "expected as many original types as operands");
+
+  // Find operands of unranked memref type and store them.
+  SmallVector<UnrankedMemRefDescriptor, 4> unrankedMemrefs;
+  for (unsigned i = 0, e = operands.size(); i < e; ++i) {
+    if (!origTypes[i].isa<UnrankedMemRefType>())
+      continue;
+    unrankedMemrefs.emplace_back(operands[i]);
+  }
+
+  if (unrankedMemrefs.empty())
+    return success();
+
+  // Compute allocation sizes.
+  SmallVector<Value, 4> sizes;
+  UnrankedMemRefDescriptor::computeSizes(builder, loc, typeConverter,
+                                         unrankedMemrefs, sizes);
+
+  // Get frequently used types.
+  auto voidType = LLVM::LLVMType::getVoidTy(typeConverter.getDialect());
+  auto voidPtrType = LLVM::LLVMType::getInt8PtrTy(typeConverter.getDialect());
+  auto i1Type = LLVM::LLVMType::getInt1Ty(typeConverter.getDialect());
+  LLVM::LLVMType indexType = typeConverter.getIndexType();
+
+  // Find the malloc and free, or declare them if necessary.
+  auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>();
+  auto mallocFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("malloc");
+  if (!mallocFunc && toDynamic) {
+    OpBuilder::InsertionGuard guard(builder);
+    builder.setInsertionPointToStart(module.getBody());
+    mallocFunc = builder.create<LLVM::LLVMFuncOp>(
+        builder.getUnknownLoc(), "malloc",
+        LLVM::LLVMType::getFunctionTy(
+            voidPtrType, llvm::makeArrayRef(indexType), /*isVarArg=*/false));
+  }
+  auto freeFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("free");
+  if (!freeFunc && !toDynamic) {
+    OpBuilder::InsertionGuard guard(builder);
+    builder.setInsertionPointToStart(module.getBody());
+    freeFunc = builder.create<LLVM::LLVMFuncOp>(
+        builder.getUnknownLoc(), "free",
+        LLVM::LLVMType::getFunctionTy(voidType, llvm::makeArrayRef(voidPtrType),
+                                      /*isVarArg=*/false));
+  }
+
+  // Initialize shared constants.
+  Value zero =
+      builder.create<LLVM::ConstantOp>(loc, i1Type, builder.getBoolAttr(false));
+
+  unsigned unrankedMemrefPos = 0;
+  for (unsigned i = 0, e = operands.size(); i < e; ++i) {
+    Type type = origTypes[i];
+    if (!type.isa<UnrankedMemRefType>())
+      continue;
+    Value allocationSize = sizes[unrankedMemrefPos++];
+    UnrankedMemRefDescriptor desc(operands[i]);
+
+    // Allocate memory, copy, and free the source if necessary.
+    Value memory =
+        toDynamic
+            ? builder.create<LLVM::CallOp>(loc, mallocFunc, allocationSize)
+                  .getResult(0)
+            : builder.create<LLVM::AllocaOp>(loc, voidPtrType, allocationSize,
+                                             /*alignment=*/0);
+
+    Value source = desc.memRefDescPtr(builder, loc);
+    builder.create<LLVM::MemcpyOp>(loc, memory, source, allocationSize, zero);
+    if (!toDynamic)
+      builder.create<LLVM::CallOp>(loc, freeFunc, source);
+
+    // Create a new descriptor. The same descriptor can be returned multiple
+    // times; attempting to modify its pointer in place could lead to memory
+    // leaks (allocated twice and overwritten) or double frees (the caller
+    // cannot know whether two descriptors point to the same memory).
+    Type descriptorType = typeConverter.convertType(type);
+    if (!descriptorType)
+      return failure();
+    auto updatedDesc =
+        UnrankedMemRefDescriptor::undef(builder, loc, descriptorType);
+    Value rank = desc.rank(builder, loc);
+    updatedDesc.setRank(builder, loc, rank);
+    updatedDesc.setMemRefDescPtr(builder, loc, memory);
+
+    operands[i] = updatedDesc;
+  }
+
+  return success();
+}
+
 // A CallOp automatically promotes MemRefType to a sequence of alloca/store and
 // passes the pointer to the MemRef across function boundaries.
 template <typename CallOpType>
@@ -1882,13 +2030,6 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
     unsigned numResults = callOp.getNumResults();
     auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
 
-    for (Type resType : resultTypes) {
-      assert(!resType.isa<UnrankedMemRefType>() &&
-             "Returning unranked memref is not supported. Pass result as an"
-             "argument instead.");
-      (void)resType;
-    }
-
     if (numResults != 0) {
       if (!(packedResult =
                 this->typeConverter.packFunctionResults(resultTypes)))
@@ -1900,25 +2041,25 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
     auto newOp = rewriter.create<LLVM::CallOp>(op->getLoc(), packedResult,
                                                promoted, op->getAttrs());
 
-    // If < 2 results, packing did not do anything and we can just return.
-    if (numResults < 2) {
-      rewriter.replaceOp(op, newOp.getResults());
-      return success();
-    }
-
-    // Otherwise, it had been converted to an operation producing a structure.
-    // Extract individual results from the structure and return them as list.
-    // TODO(aminim, ntv, riverriddle, zinenko): this seems like patching around
-    // a particular interaction between MemRefType and CallOp lowering. Find a
-    // way to avoid special casing.
     SmallVector<Value, 4> results;
-    results.reserve(numResults);
-    for (unsigned i = 0; i < numResults; ++i) {
-      auto type = this->typeConverter.convertType(op->getResult(i).getType());
-      results.push_back(rewriter.create<LLVM::ExtractValueOp>(
-          op->getLoc(), type, newOp.getOperation()->getResult(0),
-          rewriter.getI64ArrayAttr(i)));
+    if (numResults < 2) {
+      // If < 2 results, packing did not do anything and we can just return.
+      results.append(newOp.result_begin(), newOp.result_end());
+    } else {
+      // Otherwise, it had been converted to an operation producing a structure.
+      // Extract individual results from the structure and return them as list.
+      results.reserve(numResults);
+      for (unsigned i = 0; i < numResults; ++i) {
+        auto type = this->typeConverter.convertType(op->getResult(i).getType());
+        results.push_back(rewriter.create<LLVM::ExtractValueOp>(
+            op->getLoc(), type, newOp.getOperation()->getResult(0),
+            rewriter.getI64ArrayAttr(i)));
+      }
     }
+    if (failed(copyUnrankedDescriptors(
+            rewriter, op->getLoc(), this->typeConverter, op->getResultTypes(),
+            results, /*toDynamic=*/false)))
+      return failure();
     rewriter.replaceOp(op, results);
 
     return success();
@@ -2397,6 +2538,10 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<ReturnOp> {
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     unsigned numArguments = op->getNumOperands();
+    auto updatedOperands = llvm::to_vector<4>(operands);
+    if (failed(copyUnrankedDescriptors(rewriter, op->getLoc(), typeConverter,
+                                       op->getOperands().getTypes(),
+                                       updatedOperands, /*toDynamic=*/true)))
+      return failure();
 
     // If ReturnOp has 0 or 1 operand, create it and return immediately.
     if (numArguments == 0) {
@@ -2406,7 +2551,7 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<ReturnOp> {
     }
     if (numArguments == 1) {
       rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
-          op, ArrayRef<Type>(), operands.front(), op->getAttrs());
+          op, ArrayRef<Type>(), updatedOperands, op->getAttrs());
       return success();
     }
 
@@ -2418,7 +2563,7 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<ReturnOp> {
     Value packed = rewriter.create<LLVM::UndefOp>(op->getLoc(), packedType);
     for (unsigned i = 0; i < numArguments; ++i) {
       packed = rewriter.create<LLVM::InsertValueOp>(
-          op->getLoc(), packedType, packed, operands[i],
+          op->getLoc(), packedType, packed, updatedOperands[i],
           rewriter.getI64ArrayAttr(i));
     }
     rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, ArrayRef<Type>(), packed,

diff --git a/mlir/test/Conversion/StandardToLLVM/calling-convention.mlir b/mlir/test/Conversion/StandardToLLVM/calling-convention.mlir
index 87bdab2680f9..e17bf3e24225 100644
--- a/mlir/test/Conversion/StandardToLLVM/calling-convention.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/calling-convention.mlir
@@ -109,3 +109,134 @@ func @other_callee(%arg0: memref<?xf32>, %arg1: index) attributes { llvm.emit_c_
 
 // EMIT_C_ATTRIBUTE: @_mlir_ciface_other_callee
 // EMIT_C_ATTRIBUTE:   llvm.call @other_callee
+
+//===========================================================================//
+// Calling convention on returning unranked memrefs.
+//===========================================================================//
+
+// CHECK-LABEL: llvm.func @return_var_memref_caller
+func @return_var_memref_caller(%arg0: memref<4x3xf32>) {
+  // CHECK: %[[CALL_RES:.*]] = llvm.call @return_var_memref
+  %0 = call @return_var_memref(%arg0) : (memref<4x3xf32>) -> memref<*xf32>
+
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : index)
+  // CHECK: %[[TWO:.*]] = llvm.mlir.constant(2 : index)
+  // These sizes depend on the data layout, so we do not match specific values.
+  // CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant
+  // CHECK: %[[IDX_SIZE:.*]] = llvm.mlir.constant
+
+  // CHECK: %[[DOUBLE_PTR_SIZE:.*]] = llvm.mul %[[TWO]], %[[PTR_SIZE]]
+  // CHECK: %[[RANK:.*]] = llvm.extractvalue %[[CALL_RES]][0] : !llvm<"{ i64, i8* }">
+  // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[TWO]], %[[RANK]]
+  // CHECK: %[[DOUBLE_RANK_INC:.*]] = llvm.add %[[DOUBLE_RANK]], %[[ONE]]
+  // CHECK: %[[TABLES_SIZE:.*]] = llvm.mul %[[DOUBLE_RANK_INC]], %[[IDX_SIZE]]
+  // CHECK: %[[ALLOC_SIZE:.*]] = llvm.add %[[DOUBLE_PTR_SIZE]], %[[TABLES_SIZE]]
+  // CHECK: %[[FALSE:.*]] = llvm.mlir.constant(false)
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOC_SIZE]] x !llvm.i8
+  // CHECK: %[[SOURCE:.*]] = llvm.extractvalue %[[CALL_RES]][1]
+  // CHECK: "llvm.intr.memcpy"(%[[ALLOCA]], %[[SOURCE]], %[[ALLOC_SIZE]], %[[FALSE]])
+  // CHECK: llvm.call @free(%[[SOURCE]])
+  // CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm<"{ i64, i8* }">
+  // CHECK: %[[RANK:.*]] = llvm.extractvalue %[[CALL_RES]][0] : !llvm<"{ i64, i8* }">
+  // CHECK: %[[DESC_1:.*]] = llvm.insertvalue %[[RANK]], %[[DESC]][0]
+  // CHECK: llvm.insertvalue %[[ALLOCA]], %[[DESC_1]][1]
+  return
+}
+
+// CHECK-LABEL: llvm.func @return_var_memref
+func @return_var_memref(%arg0: memref<4x3xf32>) -> memref<*xf32> {
+  // Match the construction of the unranked descriptor.
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca
+  // CHECK: %[[MEMORY:.*]] = llvm.bitcast %[[ALLOCA]]
+  // CHECK: %[[DESC_0:.*]] = llvm.mlir.undef : !llvm<"{ i64, i8* }">
+  // CHECK: %[[DESC_1:.*]] = llvm.insertvalue %{{.*}}, %[[DESC_0]][0]
+  // CHECK: %[[DESC_2:.*]] = llvm.insertvalue %[[MEMORY]], %[[DESC_1]][1]
+  %0 = memref_cast %arg0: memref<4x3xf32> to memref<*xf32>
+
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : index)
+  // CHECK: %[[TWO:.*]] = llvm.mlir.constant(2 : index)
+  // These sizes depend on the data layout, so we do not match specific values.
+  // CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant
+  // CHECK: %[[IDX_SIZE:.*]] = llvm.mlir.constant
+
+  // CHECK: %[[DOUBLE_PTR_SIZE:.*]] = llvm.mul %[[TWO]], %[[PTR_SIZE]]
+  // CHECK: %[[RANK:.*]] = llvm.extractvalue %[[DESC_2]][0] : !llvm<"{ i64, i8* }">
+  // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[TWO]], %[[RANK]]
+  // CHECK: %[[DOUBLE_RANK_INC:.*]] = llvm.add %[[DOUBLE_RANK]], %[[ONE]]
+  // CHECK: %[[TABLES_SIZE:.*]] = llvm.mul %[[DOUBLE_RANK_INC]], %[[IDX_SIZE]]
+  // CHECK: %[[ALLOC_SIZE:.*]] = llvm.add %[[DOUBLE_PTR_SIZE]], %[[TABLES_SIZE]]
+  // CHECK: %[[FALSE:.*]] = llvm.mlir.constant(false)
+  // CHECK: %[[ALLOCATED:.*]] = llvm.call @malloc(%[[ALLOC_SIZE]])
+  // CHECK: %[[SOURCE:.*]] = llvm.extractvalue %[[DESC_2]][1]
+  // CHECK: "llvm.intr.memcpy"(%[[ALLOCATED]], %[[SOURCE]], %[[ALLOC_SIZE]], %[[FALSE]])
+  // CHECK: %[[NEW_DESC:.*]] = llvm.mlir.undef : !llvm<"{ i64, i8* }">
+  // CHECK: %[[RANK:.*]] = llvm.extractvalue %[[DESC_2]][0] : !llvm<"{ i64, i8* }">
+  // CHECK: %[[NEW_DESC_1:.*]] = llvm.insertvalue %[[RANK]], %[[NEW_DESC]][0]
+  // CHECK: %[[NEW_DESC_2:.*]] = llvm.insertvalue %[[ALLOCATED]], %[[NEW_DESC_1]][1]
+  // CHECK: llvm.return %[[NEW_DESC_2]]
+  return %0 : memref<*xf32>
+}
+
+// CHECK-LABEL: llvm.func @return_two_var_memref_caller
+func @return_two_var_memref_caller(%arg0: memref<4x3xf32>) {
+  // Only check that we create two different descriptors using different
+  // memory, and deallocate both sources. The size computation is the same as
+  // for the single result.
+  // CHECK: %[[CALL_RES:.*]] = llvm.call @return_two_var_memref
+  // CHECK: %[[RES_1:.*]] = llvm.extractvalue %[[CALL_RES]][0]
+  // CHECK: %[[RES_2:.*]] = llvm.extractvalue %[[CALL_RES]][1]
+  %0:2 = call @return_two_var_memref(%arg0) : (memref<4x3xf32>) -> (memref<*xf32>, memref<*xf32>)
+
+  // CHECK: %[[ALLOCA_1:.*]] = llvm.alloca %{{.*}} x !llvm.i8
+  // CHECK: %[[SOURCE_1:.*]] = llvm.extractvalue %[[RES_1]][1] : ![[DESC_TYPE:.*]]
+  // CHECK: "llvm.intr.memcpy"(%[[ALLOCA_1]], %[[SOURCE_1]], %{{.*}}, %[[FALSE:.*]])
+  // CHECK: llvm.call @free(%[[SOURCE_1]])
+  // CHECK: %[[DESC_1:.*]] = llvm.mlir.undef : ![[DESC_TYPE]]
+  // CHECK: %[[DESC_11:.*]] = llvm.insertvalue %{{.*}}, %[[DESC_1]][0]
+  // CHECK: llvm.insertvalue %[[ALLOCA_1]], %[[DESC_11]][1]
+
+  // CHECK: %[[ALLOCA_2:.*]] = llvm.alloca %{{.*}} x !llvm.i8
+  // CHECK: %[[SOURCE_2:.*]] = llvm.extractvalue %[[RES_2]][1]
+  // CHECK: "llvm.intr.memcpy"(%[[ALLOCA_2]], %[[SOURCE_2]], %{{.*}}, %[[FALSE]])
+  // CHECK: llvm.call @free(%[[SOURCE_2]])
+  // CHECK: %[[DESC_2:.*]] = llvm.mlir.undef : ![[DESC_TYPE]]
+  // CHECK: %[[DESC_21:.*]] = llvm.insertvalue %{{.*}}, %[[DESC_2]][0]
+  // CHECK: llvm.insertvalue %[[ALLOCA_2]], %[[DESC_21]][1]
+  return
+}
+
+// CHECK-LABEL: llvm.func @return_two_var_memref
+func @return_two_var_memref(%arg0: memref<4x3xf32>) -> (memref<*xf32>, memref<*xf32>) {
+  // Match the construction of the unranked descriptor.
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca
+  // CHECK: %[[MEMORY:.*]] = llvm.bitcast %[[ALLOCA]]
+  // CHECK: %[[DESC_0:.*]] = llvm.mlir.undef : !llvm<"{ i64, i8* }">
+  // CHECK: %[[DESC_1:.*]] = llvm.insertvalue %{{.*}}, %[[DESC_0]][0]
+  // CHECK: %[[DESC_2:.*]] = llvm.insertvalue %[[MEMORY]], %[[DESC_1]][1]
+  %0 = memref_cast %arg0 : memref<4x3xf32> to memref<*xf32>
+
+  // Only check that we allocate the memory for each operand of the "return"
+  // separately, even if both operands are the same value. The calling
+  // convention requires the caller to free them and the caller cannot know
+  // whether they are the same value or not.
+  // CHECK: %[[ALLOCATED_1:.*]] = llvm.call @malloc(%{{.*}})
+  // CHECK: %[[SOURCE_1:.*]] = llvm.extractvalue %[[DESC_2]][1]
+  // CHECK: "llvm.intr.memcpy"(%[[ALLOCATED_1]], %[[SOURCE_1]], %{{.*}}, %[[FALSE:.*]])
+  // CHECK: %[[RES_1:.*]] = llvm.mlir.undef
+  // CHECK: %[[RES_11:.*]] = llvm.insertvalue %{{.*}}, %[[RES_1]][0]
+  // CHECK: %[[RES_12:.*]] = llvm.insertvalue %[[ALLOCATED_1]], %[[RES_11]][1]
+
+  // CHECK: %[[ALLOCATED_2:.*]] = llvm.call @malloc(%{{.*}})
+  // CHECK: %[[SOURCE_2:.*]] = llvm.extractvalue %[[DESC_2]][1]
+  // CHECK: "llvm.intr.memcpy"(%[[ALLOCATED_2]], %[[SOURCE_2]], %{{.*}}, %[[FALSE]])
+  // CHECK: %[[RES_2:.*]] = llvm.mlir.undef
+  // CHECK: %[[RES_21:.*]] = llvm.insertvalue %{{.*}}, %[[RES_2]][0]
+  // CHECK: %[[RES_22:.*]] = llvm.insertvalue %[[ALLOCATED_2]], %[[RES_21]][1]
+
+  // CHECK: %[[RESULTS:.*]] = llvm.mlir.undef : !llvm<"{ { i64, i8* }, { i64, i8* } }">
+  // CHECK: %[[RESULTS_1:.*]] = llvm.insertvalue %[[RES_12]], %[[RESULTS]]
+  // CHECK: %[[RESULTS_2:.*]] = llvm.insertvalue %[[RES_22]], %[[RESULTS_1]]
+  // CHECK: llvm.return %[[RESULTS_2]]
+  return %0, %0 : memref<*xf32>, memref<*xf32>
+}
+

diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index e782d5de1aaa..a6ce8d9e2195 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -1,7 +1,9 @@
 // RUN: mlir-opt %s | mlir-opt | FileCheck %s
 
-// CHECK-LABEL: func @ops(%arg0: !llvm.i32, %arg1: !llvm.float)
-func @ops(%arg0 : !llvm.i32, %arg1 : !llvm.float) {
+// CHECK-LABEL: func @ops
+func @ops(%arg0: !llvm.i32, %arg1: !llvm.float,
+          %arg2: !llvm<"i8*">, %arg3: !llvm<"i8*">,
+          %arg4: !llvm.i1) {
 // Integer arithmetic binary operations.
 //
 // CHECK-NEXT:  %0 = llvm.add %arg0, %arg0 : !llvm.i32
@@ -109,6 +111,17 @@ func @ops(%arg0 : !llvm.i32, %arg1 : !llvm.float) {
 // CHECK: "llvm.intr.ctpop"(%{{.*}}) : (!llvm.i32) -> !llvm.i32
   %33 = "llvm.intr.ctpop"(%arg0) : (!llvm.i32) -> !llvm.i32
 
+// CHECK: "llvm.intr.memcpy"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i32, !llvm.i1) -> ()
+  "llvm.intr.memcpy"(%arg2, %arg3, %arg0, %arg4) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i32, !llvm.i1) -> ()
+
+// CHECK: "llvm.intr.memcpy"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i32, !llvm.i1) -> ()
+  "llvm.intr.memcpy"(%arg2, %arg3, %arg0, %arg4) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i32, !llvm.i1) -> ()
+
+// CHECK: %[[SZ:.*]] = llvm.mlir.constant
+  %sz = llvm.mlir.constant(10: i64) : !llvm.i64
+// CHECK: "llvm.intr.memcpy.inline"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1) -> ()
+  "llvm.intr.memcpy.inline"(%arg2, %arg3, %sz, %arg4) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1) -> ()
+
 // CHECK:  llvm.return
   llvm.return
 }
@@ -315,4 +328,4 @@ func @useFenceInst() {
   // CHECK:  release
   llvm.fence release
   return
-}
\ No newline at end of file
+}

diff --git a/mlir/test/Target/llvmir-intrinsics.mlir b/mlir/test/Target/llvmir-intrinsics.mlir
index e04b40e916fd..45292124aedc 100644
--- a/mlir/test/Target/llvmir-intrinsics.mlir
+++ b/mlir/test/Target/llvmir-intrinsics.mlir
@@ -202,6 +202,17 @@ llvm.func @masked_intrinsics(%A: !llvm<"<7 x float>*">, %mask: !llvm<"<7 x i1>">
   llvm.return
 }
 
+// CHECK-LABEL: @memcpy_test
+llvm.func @memcpy_test(%arg0: !llvm.i32, %arg1: !llvm.i1, %arg2: !llvm<"i8*">, %arg3: !llvm<"i8*">) {
+  // CHECK: call void @llvm.memcpy.p0i8.p0i8.i32(i8* %{{.*}}, i8* %{{.*}}, i32 %{{.*}}, i1 %{{.*}})
+  "llvm.intr.memcpy"(%arg2, %arg3, %arg0, %arg1) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i32, !llvm.i1) -> ()
+  %sz = llvm.mlir.constant(10: i64) : !llvm.i64
+  // CHECK: call void @llvm.memcpy.inline.p0i8.p0i8.i64(i8* %{{.*}}, i8* %{{.*}}, i64 10, i1 %{{.*}})
+  "llvm.intr.memcpy.inline"(%arg2, %arg3, %sz, %arg1) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1) -> ()
+  llvm.return
+}
+
 // Check that intrinsics are declared with appropriate types.
 // CHECK-DAG: declare float @llvm.fma.f32(float, float, float)
 // CHECK-DAG: declare <8 x float> @llvm.fma.v8f32(<8 x float>, <8 x float>, <8 x float>) #0
@@ -231,3 +242,5 @@ llvm.func @masked_intrinsics(%A: !llvm<"<7 x float>*">, %mask: !llvm<"<7 x i1>">
 // CHECK-DAG: declare void @llvm.matrix.column.major.store.v48f32.p0f32(<48 x float>, float* nocapture writeonly, i64, i1 immarg, i32 immarg, i32 immarg)
 // CHECK-DAG: declare <7 x float> @llvm.masked.load.v7f32.p0v7f32(<7 x float>*, i32 immarg, <7 x i1>, <7 x float>)
 // CHECK-DAG: declare void @llvm.masked.store.v7f32.p0v7f32(<7 x float>, <7 x float>*, i32 immarg, <7 x i1>)
+// CHECK-DAG: declare void @llvm.memcpy.p0i8.p0i8.i32(i8* noalias nocapture writeonly, i8* noalias nocapture readonly, i32, i1 immarg)
+// CHECK-DAG: declare void @llvm.memcpy.inline.p0i8.p0i8.i64(i8* noalias nocapture writeonly, i8* noalias nocapture readonly, i64 immarg, i1 immarg)

diff --git a/mlir/test/mlir-cpu-runner/unranked_memref.mlir b/mlir/test/mlir-cpu-runner/unranked_memref.mlir
index 0eb68ac03368..df760f593db2 100644
--- a/mlir/test/mlir-cpu-runner/unranked_memref.mlir
+++ b/mlir/test/mlir-cpu-runner/unranked_memref.mlir
@@ -18,6 +18,21 @@
 // CHECK: rank = 0
 // 122 is ASCII for 'z'.
 // CHECK: [z]
+//
+// CHECK: rank = 2
+// CHECK-SAME: sizes = [4, 3]
+// CHECK-SAME: strides = [3, 1]
+// CHECK-COUNT-4: [1, 1, 1]
+//
+// CHECK: rank = 2
+// CHECK-SAME: sizes = [4, 3]
+// CHECK-SAME: strides = [3, 1]
+// CHECK-COUNT-4: [1, 1, 1]
+//
+// CHECK: rank = 2
+// CHECK-SAME: sizes = [4, 3]
+// CHECK-SAME: strides = [3, 1]
+// CHECK-COUNT-4: [1, 1, 1]
 func @main() -> () {
     %A = alloc() : memref<10x3xf32, 0>
     %f2 = constant 2.00000e+00 : f32
@@ -48,8 +63,40 @@ func @main() -> () {
     call @print_memref_i8(%U4) : (memref<*xi8>) -> ()
 
     dealloc %A : memref<10x3xf32, 0>
+
+    call @return_var_memref_caller() : () -> ()
+    call @return_two_var_memref_caller() : () -> ()
     return
 }
 
 func @print_memref_i8(memref<*xi8>) attributes { llvm.emit_c_interface }
 func @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface }
+
+func @return_two_var_memref_caller() {
+  %0 = alloca() : memref<4x3xf32>
+  %c0f32 = constant 1.0 : f32
+  linalg.fill(%0, %c0f32) : memref<4x3xf32>, f32
+  %1:2 = call @return_two_var_memref(%0) : (memref<4x3xf32>) -> (memref<*xf32>, memref<*xf32>)
+  call @print_memref_f32(%1#0) : (memref<*xf32>) -> ()
+  call @print_memref_f32(%1#1) : (memref<*xf32>) -> ()
+  return
+}
+
+func @return_two_var_memref(%arg0: memref<4x3xf32>) -> (memref<*xf32>, memref<*xf32>) {
+  %0 = memref_cast %arg0 : memref<4x3xf32> to memref<*xf32>
+  return %0, %0 : memref<*xf32>, memref<*xf32>
+}
+
+func @return_var_memref_caller() {
+  %0 = alloca() : memref<4x3xf32>
+  %c0f32 = constant 1.0 : f32
+  linalg.fill(%0, %c0f32) : memref<4x3xf32>, f32
+  %1 = call @return_var_memref(%0) : (memref<4x3xf32>) -> memref<*xf32>
+  call @print_memref_f32(%1) : (memref<*xf32>) -> ()
+  return
+}
+
+func @return_var_memref(%arg0: memref<4x3xf32>) -> memref<*xf32> {
+  %0 = memref_cast %arg0: memref<4x3xf32> to memref<*xf32>
+  return %0 : memref<*xf32>
+}




More information about the Mlir-commits mailing list