[Mlir-commits] [mlir] a89fc12 - [mlir] Support return and call ops in bare-ptr calling convention
Diego Caballero
llvmlistbot at llvm.org
Tue Sep 29 12:09:45 PDT 2020
Author: Diego Caballero
Date: 2020-09-29T12:00:47-07:00
New Revision: a89fc12653c520a5a70249e07c0a394584f4abbe
URL: https://github.com/llvm/llvm-project/commit/a89fc12653c520a5a70249e07c0a394584f4abbe
DIFF: https://github.com/llvm/llvm-project/commit/a89fc12653c520a5a70249e07c0a394584f4abbe.diff
LOG: [mlir] Support return and call ops in bare-ptr calling convention
This patch adds support for the 'return' and 'call' ops to the bare-ptr
calling convention. These changes also align the bare-ptr calling
convention code with the latest changes in the default calling convention
and reduce the amount of customization code needed.
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D87724
Added:
Modified:
mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index ab047a08f404..d98a0ff6efb3 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -27,6 +27,7 @@ class Type;
namespace mlir {
+class BaseMemRefType;
class ComplexType;
class LLVMTypeConverter;
class UnrankedMemRefType;
@@ -74,15 +75,28 @@ class LLVMTypeConverter : public TypeConverter {
SignatureConversion &result);
/// Convert a non-empty list of types to be returned from a function into a
- /// supported LLVM IR type. In particular, if more than one values is
+ /// supported LLVM IR type. In particular, if more than one value is
/// returned, create an LLVM IR structure type with elements that correspond
/// to each of the MLIR types converted with `convertType`.
Type packFunctionResults(ArrayRef<Type> types);
+ /// Convert a type in the context of the default or bare pointer calling
+ /// convention. Calling convention sensitive types, such as MemRefType and
+ /// UnrankedMemRefType, are converted following the specific rules for the
+ /// calling convention. Calling convention independent types are converted
+ /// following the default LLVM type conversions.
+ Type convertCallingConventionType(Type type);
+
+ /// Promote the bare pointers in 'values' that resulted from memrefs to
+ /// descriptors. 'stdTypes' holds the types of 'values' before the conversion
+ /// to the LLVM-IR dialect (i.e., MemRefType, or any other Standard type).
+ void promoteBarePtrsToDescriptors(ConversionPatternRewriter &rewriter,
+ Location loc, ArrayRef<Type> stdTypes,
+ SmallVectorImpl<Value> &values);
+
/// Returns the MLIR context.
MLIRContext &getContext();
-
/// Returns the LLVM dialect.
LLVM::LLVMDialect *getDialect() { return llvmDialect; }
@@ -179,6 +193,9 @@ class LLVMTypeConverter : public TypeConverter {
// runtime rank and a pointer to the static ranked memref desc
Type convertUnrankedMemRefType(UnrankedMemRefType type);
+ /// Convert a memref type to a bare pointer to the memref element type.
+ Type convertMemRefToBarePtr(BaseMemRefType type);
+
// Convert a 1D vector type into an LLVM vector type.
Type convertVectorType(VectorType type);
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 186c8ec48fa5..c77c0b529caf 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -80,37 +80,12 @@ LogicalResult mlir::structFuncArgTypeConverter(LLVMTypeConverter &converter,
return success();
}
-/// Convert a MemRef type to a bare pointer to the MemRef element type.
-static Type convertMemRefTypeToBarePtr(LLVMTypeConverter &converter,
- MemRefType type) {
- int64_t offset;
- SmallVector<int64_t, 4> strides;
- if (failed(getStridesAndOffset(type, strides, offset)))
- return {};
-
- LLVM::LLVMType elementType =
- unwrap(converter.convertType(type.getElementType()));
- if (!elementType)
- return {};
- return elementType.getPointerTo(type.getMemorySpace());
-}
-
/// Callback to convert function argument types. It converts MemRef function
/// arguments to bare pointers to the MemRef element type.
LogicalResult mlir::barePtrFuncArgTypeConverter(LLVMTypeConverter &converter,
Type type,
SmallVectorImpl<Type> &result) {
- // TODO: Add support for unranked memref.
- if (auto memrefTy = type.dyn_cast<MemRefType>()) {
- auto llvmTy = convertMemRefTypeToBarePtr(converter, memrefTy);
- if (!llvmTy)
- return failure();
-
- result.push_back(llvmTy);
- return success();
- }
-
- auto llvmTy = converter.convertType(type);
+ auto llvmTy = converter.convertCallingConventionType(type);
if (!llvmTy)
return failure();
@@ -272,14 +247,14 @@ SmallVector<Type, 2> LLVMTypeConverter::convertUnrankedMemRefSignature() {
// Function has one VoidType result. If MLIR Function has more than one result,
// they are into an LLVM StructType in their order of appearance.
LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature(
- FunctionType type, bool isVariadic,
+ FunctionType funcTy, bool isVariadic,
LLVMTypeConverter::SignatureConversion &result) {
// Select the argument converter depending on the calling convetion.
auto funcArgConverter = options.useBarePtrCallConv
? barePtrFuncArgTypeConverter
: structFuncArgTypeConverter;
// Convert argument types one by one and check for errors.
- for (auto &en : llvm::enumerate(type.getInputs())) {
+ for (auto &en : llvm::enumerate(funcTy.getInputs())) {
Type type = en.value();
SmallVector<Type, 8> converted;
if (failed(funcArgConverter(*this, type, converted)))
@@ -296,9 +271,9 @@ LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature(
// if it returns on element, convert it, otherwise pack the result types into
// a struct.
LLVM::LLVMType resultType =
- type.getNumResults() == 0
+ funcTy.getNumResults() == 0
? LLVM::LLVMType::getVoidTy(&getContext())
- : unwrap(packFunctionResults(type.getResults()));
+ : unwrap(packFunctionResults(funcTy.getResults()));
if (!resultType)
return {};
return LLVM::LLVMType::getFunctionTy(resultType, argTypes, isVariadic);
@@ -394,6 +369,36 @@ Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) {
return LLVM::LLVMType::getStructTy(rankTy, ptrTy);
}
+/// Convert a memref type to a bare pointer to the memref element type.
+Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) {
+ if (type.isa<UnrankedMemRefType>())
+ // Unranked memref is not supported in the bare pointer calling convention.
+ return {};
+
+ // Check that the memref has static shape, strides and offset. Otherwise, it
+ // cannot be lowered to a bare pointer.
+ auto memrefTy = type.cast<MemRefType>();
+ if (!memrefTy.hasStaticShape())
+ return {};
+
+ int64_t offset = 0;
+ SmallVector<int64_t, 4> strides;
+ if (failed(getStridesAndOffset(memrefTy, strides, offset)))
+ return {};
+
+ for (int64_t stride : strides)
+ if (ShapedType::isDynamicStrideOrOffset(stride))
+ return {};
+
+ if (ShapedType::isDynamicStrideOrOffset(offset))
+ return {};
+
+ LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
+ if (!elementType)
+ return {};
+ return elementType.getPointerTo(type.getMemorySpace());
+}
+
// Convert an n-D vector type to an LLVM vector type via (n-1)-D array type when
// n > 1.
// For example, `vector<4 x f32>` converts to `!llvm.type<"<4 x float>">` and
@@ -410,6 +415,37 @@ Type LLVMTypeConverter::convertVectorType(VectorType type) {
return vectorType;
}
+/// Convert a type in the context of the default or bare pointer calling
+/// convention. Calling convention sensitive types, such as MemRefType and
+/// UnrankedMemRefType, are converted following the specific rules for the
+/// calling convention. Calling convention independent types are converted
+/// following the default LLVM type conversions.
+Type LLVMTypeConverter::convertCallingConventionType(Type type) {
+ if (options.useBarePtrCallConv)
+ if (auto memrefTy = type.dyn_cast<BaseMemRefType>())
+ return convertMemRefToBarePtr(memrefTy);
+
+ return convertType(type);
+}
+
+/// Promote the bare pointers in 'values' that resulted from memrefs to
+/// descriptors. 'stdTypes' holds the types of 'values' before the conversion
+/// to the LLVM-IR dialect (i.e., MemRefType, or any other Standard type).
+/// Values whose standard type is not a ranked memref are left unchanged.
+void LLVMTypeConverter::promoteBarePtrsToDescriptors(
+ ConversionPatternRewriter &rewriter, Location loc, ArrayRef<Type> stdTypes,
+ SmallVectorImpl<Value> &values) {
+ assert(stdTypes.size() == values.size() &&
+ "The number of types and values doesn't match");
+ for (unsigned i = 0, end = values.size(); i < end; ++i) {
+ assert(!stdTypes[i].isa<UnrankedMemRefType>() &&
+ "Unranked memrefs are not supported");
+ if (auto memrefTy = stdTypes[i].dyn_cast<MemRefType>())
+ values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this,
+ memrefTy, values[i]);
+ }
+}
+
ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName,
MLIRContext *context,
LLVMTypeConverter &typeConverter,
@@ -1088,18 +1124,6 @@ namespace {
struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> {
protected:
using ConvertOpToLLVMPattern<FuncOp>::ConvertOpToLLVMPattern;
- using UnsignedTypePair = std::pair<unsigned, Type>;
-
- // Gather the positions and types of memref-typed arguments in a given
- // FunctionType.
- void getMemRefArgIndicesAndTypes(
- FunctionType type, SmallVectorImpl<UnsignedTypePair> &argsInfo) const {
- argsInfo.reserve(type.getNumInputs());
- for (auto en : llvm::enumerate(type.getInputs())) {
- if (en.value().isa<MemRefType, UnrankedMemRefType>())
- argsInfo.push_back({en.index(), en.value()});
- }
- }
// Convert input FuncOp to LLVMFuncOp by using the LLVMTypeConverter provided
// to this legalization pattern.
@@ -1192,11 +1216,10 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase {
ConversionPatternRewriter &rewriter) const override {
auto funcOp = cast<FuncOp>(op);
- // Store the positions and type of memref-typed arguments so that we can
- // promote them to MemRef descriptor structs at the beginning of the
- // function.
- SmallVector<UnsignedTypePair, 4> promotedArgsInfo;
- getMemRefArgIndicesAndTypes(funcOp.getType(), promotedArgsInfo);
+ // Store the type of memref-typed arguments before the conversion so that we
+ // can promote them to MemRef descriptor at the beginning of the function.
+ SmallVector<Type, 8> oldArgTypes =
+ llvm::to_vector<8>(funcOp.getType().getInputs());
auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
if (!newFuncOp)
@@ -1206,27 +1229,42 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase {
return success();
}
- // Promote bare pointers from MemRef arguments to a MemRef descriptor struct
- // at the beginning of the function so that all the MemRefs in the function
- // have a uniform representation.
- Block *firstBlock = &newFuncOp.getBody().front();
- rewriter.setInsertionPoint(firstBlock, firstBlock->begin());
- auto funcLoc = funcOp.getLoc();
- for (const auto &argInfo : promotedArgsInfo) {
- // TODO: Add support for unranked MemRefs.
- if (auto memrefType = argInfo.second.dyn_cast<MemRefType>()) {
- // Replace argument with a placeholder (undef), promote argument to a
- // MemRef descriptor and replace placeholder with the last instruction
- // of the MemRef descriptor. The placeholder is needed to avoid
- // replacing argument uses in the MemRef descriptor instructions.
- BlockArgument arg = firstBlock->getArgument(argInfo.first);
- Value placeHolder =
- rewriter.create<LLVM::UndefOp>(funcLoc, arg.getType());
- rewriter.replaceUsesOfBlockArgument(arg, placeHolder);
- auto desc = MemRefDescriptor::fromStaticShape(
- rewriter, funcLoc, typeConverter, memrefType, arg);
- rewriter.replaceOp(placeHolder.getDefiningOp(), {desc});
- }
+ // Promote bare pointers from memref arguments to memref descriptors at the
+ // beginning of the function so that all the memrefs in the function have a
+ // uniform representation.
+ Block *entryBlock = &newFuncOp.getBody().front();
+ auto blockArgs = entryBlock->getArguments();
+ assert(blockArgs.size() == oldArgTypes.size() &&
+ "The number of arguments and types doesn't match");
+
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointToStart(entryBlock);
+ for (auto it : llvm::zip(blockArgs, oldArgTypes)) {
+ BlockArgument arg = std::get<0>(it);
+ Type argTy = std::get<1>(it);
+
+ // Unranked memrefs are not supported in the bare pointer calling
+ // convention. We should have bailed out before in the presence of
+ // unranked memrefs.
+ assert(!argTy.isa<UnrankedMemRefType>() &&
+ "Unranked memref is not supported");
+ auto memrefTy = argTy.dyn_cast<MemRefType>();
+ if (!memrefTy)
+ continue;
+
+ // Replace barePtr with a placeholder (undef), promote barePtr to a ranked
+ // memref descriptor (unranked memrefs were rejected by the assert above)
+ // and replace the placeholder with the last instruction of the descriptor.
+ // TODO: The placeholder is needed to avoid replacing barePtr uses in the
+ // MemRef descriptor instructions. We may want to have a utility in the
+ // rewriter to properly handle this use case.
+ Location loc = op->getLoc();
+ auto placeholder = rewriter.create<LLVM::UndefOp>(loc, memrefTy);
+ rewriter.replaceUsesOfBlockArgument(arg, placeholder);
+
+ Value desc = MemRefDescriptor::fromStaticShape(
+ rewriter, loc, typeConverter, memrefTy, arg);
+ rewriter.replaceOp(placeholder, {desc});
}
rewriter.eraseOp(op);
@@ -2138,12 +2176,22 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
rewriter.getI64ArrayAttr(i)));
}
}
- if (failed(copyUnrankedDescriptors(
- rewriter, op->getLoc(), this->typeConverter, op->getResultTypes(),
- results, /*toDynamic=*/false)))
+
+ if (this->typeConverter.getOptions().useBarePtrCallConv) {
+ // For the bare-ptr calling convention, promote memref results to
+ // descriptors.
+ assert(results.size() == resultTypes.size() &&
+ "The number of results and types doesn't match");
+ this->typeConverter.promoteBarePtrsToDescriptors(rewriter, op->getLoc(),
+ resultTypes, results);
+ } else if (failed(copyUnrankedDescriptors(rewriter, op->getLoc(),
+ this->typeConverter, resultTypes,
+ results,
+ /*toDynamic=*/false))) {
return failure();
- rewriter.replaceOp(op, results);
+ }
+ rewriter.replaceOp(op, results);
return success();
}
};
@@ -2706,11 +2754,32 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<ReturnOp> {
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
+ Location loc = op->getLoc();
unsigned numArguments = op->getNumOperands();
- auto updatedOperands = llvm::to_vector<4>(operands);
- copyUnrankedDescriptors(rewriter, op->getLoc(), typeConverter,
- op->getOperands().getTypes(), updatedOperands,
- /*toDynamic=*/true);
+ SmallVector<Value, 4> updatedOperands;
+
+ if (typeConverter.getOptions().useBarePtrCallConv) {
+ // For the bare-ptr calling convention, extract the aligned pointer to
+ // be returned from the memref descriptor.
+ for (auto it : llvm::zip(op->getOperands(), operands)) {
+ Type oldTy = std::get<0>(it).getType();
+ Value newOperand = std::get<1>(it);
+ if (oldTy.isa<MemRefType>()) {
+ MemRefDescriptor memrefDesc(newOperand);
+ newOperand = memrefDesc.alignedPtr(rewriter, loc);
+ } else if (oldTy.isa<UnrankedMemRefType>()) {
+ // Unranked memref is not supported in the bare pointer calling
+ // convention.
+ return failure();
+ }
+ updatedOperands.push_back(newOperand);
+ }
+ } else {
+ updatedOperands = llvm::to_vector<4>(operands);
+ copyUnrankedDescriptors(rewriter, loc, typeConverter,
+ op->getOperands().getTypes(), updatedOperands,
+ /*toDynamic=*/true);
+ }
// If ReturnOp has 0 or 1 operand, create it and return immediately.
if (numArguments == 0) {
@@ -2729,10 +2798,10 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<ReturnOp> {
auto packedType = typeConverter.packFunctionResults(
llvm::to_vector<4>(op->getOperandTypes()));
- Value packed = rewriter.create<LLVM::UndefOp>(op->getLoc(), packedType);
+ Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
for (unsigned i = 0; i < numArguments; ++i) {
packed = rewriter.create<LLVM::InsertValueOp>(
- op->getLoc(), packedType, packed, updatedOperands[i],
+ loc, packedType, packed, updatedOperands[i],
rewriter.getI64ArrayAttr(i));
}
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
@@ -3380,17 +3449,21 @@ void mlir::populateStdToLLVMConversionPatterns(
populateStdToLLVMMemoryConversionPatterns(converter, patterns);
}
-// Create an LLVM IR structure type if there is more than one result.
+/// Convert a non-empty list of types to be returned from a function into a
+/// supported LLVM IR type. In particular, if more than one value is returned,
+/// create an LLVM IR structure type with elements that correspond to each of
+/// the MLIR types converted with `convertType`.
Type LLVMTypeConverter::packFunctionResults(ArrayRef<Type> types) {
assert(!types.empty() && "expected non-empty list of type");
if (types.size() == 1)
- return convertType(types.front());
+ return convertCallingConventionType(types.front());
SmallVector<LLVM::LLVMType, 8> resultTypes;
resultTypes.reserve(types.size());
for (auto t : types) {
- auto converted = convertType(t).dyn_cast_or_null<LLVM::LLVMType>();
+ auto converted =
+ convertCallingConventionType(t).dyn_cast_or_null<LLVM::LLVMType>();
if (!converted)
return {};
resultTypes.push_back(converted);
@@ -3426,16 +3499,27 @@ SmallVector<Value, 4> LLVMTypeConverter::promoteOperands(Location loc,
auto operand = std::get<0>(it);
auto llvmOperand = std::get<1>(it);
- if (operand.getType().isa<UnrankedMemRefType>()) {
- UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand,
- promotedOperands);
- continue;
- }
- if (auto memrefType = operand.getType().dyn_cast<MemRefType>()) {
- MemRefDescriptor::unpack(builder, loc, llvmOperand,
- operand.getType().cast<MemRefType>(),
- promotedOperands);
- continue;
+ if (options.useBarePtrCallConv) {
+ // For the bare-ptr calling convention, we only have to extract the
+ // aligned pointer of a memref.
+ if (auto memrefType = operand.getType().dyn_cast<MemRefType>()) {
+ MemRefDescriptor desc(llvmOperand);
+ llvmOperand = desc.alignedPtr(builder, loc);
+ } else if (operand.getType().isa<UnrankedMemRefType>()) {
+ llvm_unreachable("Unranked memrefs are not supported");
+ }
+ } else {
+ if (operand.getType().isa<UnrankedMemRefType>()) {
+ UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand,
+ promotedOperands);
+ continue;
+ }
+ if (auto memrefType = operand.getType().dyn_cast<MemRefType>()) {
+ MemRefDescriptor::unpack(builder, loc, llvmOperand,
+ operand.getType().cast<MemRefType>(),
+ promotedOperands);
+ continue;
+ }
}
promotedOperands.push_back(llvmOperand);
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
index 5cccca3795b3..5dd36ba6d2ac 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
@@ -14,13 +14,13 @@ func @check_noalias(%static : memref<2xf32> {llvm.noalias = true}, %other : memr
// CHECK-COUNT-5: !llvm.i64
// CHECK-SAME: -> !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// BAREPTR-LABEL: func @check_static_return
-// BAREPTR-SAME: (%[[arg:.*]]: !llvm.ptr<float>) -> !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)> {
+// BAREPTR-SAME: (%[[arg:.*]]: !llvm.ptr<float>) -> !llvm.ptr<float> {
func @check_static_return(%static : memref<32x18xf32>) -> memref<32x18xf32> {
// CHECK: llvm.return %{{.*}} : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// BAREPTR: %[[udf:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-// BAREPTR-NEXT: %[[base:.*]] = llvm.insertvalue %[[arg]], %[[udf]][0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-// BAREPTR-NEXT: %[[aligned:.*]] = llvm.insertvalue %[[arg]], %[[base]][1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
+// BAREPTR-NEXT: %[[base0:.*]] = llvm.insertvalue %[[arg]], %[[udf]][0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
+// BAREPTR-NEXT: %[[aligned:.*]] = llvm.insertvalue %[[arg]], %[[base0]][1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// BAREPTR-NEXT: %[[val0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// BAREPTR-NEXT: %[[ins0:.*]] = llvm.insertvalue %[[val0]], %[[aligned]][2] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// BAREPTR-NEXT: %[[val1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64
@@ -31,7 +31,8 @@ func @check_static_return(%static : memref<32x18xf32>) -> memref<32x18xf32> {
// BAREPTR-NEXT: %[[ins3:.*]] = llvm.insertvalue %[[val3]], %[[ins2]][3, 1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// BAREPTR-NEXT: %[[val4:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// BAREPTR-NEXT: %[[ins4:.*]] = llvm.insertvalue %[[val4]], %[[ins3]][4, 1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-// BAREPTR-NEXT: llvm.return %[[ins4]] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
+// BAREPTR-NEXT: %[[base1:.*]] = llvm.extractvalue %[[ins4]][1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
+// BAREPTR-NEXT: llvm.return %[[base1]] : !llvm.ptr<float>
return %static : memref<32x18xf32>
}
@@ -42,13 +43,13 @@ func @check_static_return(%static : memref<32x18xf32>) -> memref<32x18xf32> {
// CHECK-COUNT-5: !llvm.i64
// CHECK-SAME: -> !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// BAREPTR-LABEL: func @check_static_return_with_offset
-// BAREPTR-SAME: (%[[arg:.*]]: !llvm.ptr<float>) -> !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)> {
+// BAREPTR-SAME: (%[[arg:.*]]: !llvm.ptr<float>) -> !llvm.ptr<float> {
func @check_static_return_with_offset(%static : memref<32x18xf32, offset:7, strides:[22,1]>) -> memref<32x18xf32, offset:7, strides:[22,1]> {
// CHECK: llvm.return %{{.*}} : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// BAREPTR: %[[udf:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-// BAREPTR-NEXT: %[[base:.*]] = llvm.insertvalue %[[arg]], %[[udf]][0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-// BAREPTR-NEXT: %[[aligned:.*]] = llvm.insertvalue %[[arg]], %[[base]][1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
+// BAREPTR-NEXT: %[[base0:.*]] = llvm.insertvalue %[[arg]], %[[udf]][0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
+// BAREPTR-NEXT: %[[aligned:.*]] = llvm.insertvalue %[[arg]], %[[base0]][1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// BAREPTR-NEXT: %[[val0:.*]] = llvm.mlir.constant(7 : index) : !llvm.i64
// BAREPTR-NEXT: %[[ins0:.*]] = llvm.insertvalue %[[val0]], %[[aligned]][2] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// BAREPTR-NEXT: %[[val1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64
@@ -59,14 +60,15 @@ func @check_static_return_with_offset(%static : memref<32x18xf32, offset:7, stri
// BAREPTR-NEXT: %[[ins3:.*]] = llvm.insertvalue %[[val3]], %[[ins2]][3, 1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// BAREPTR-NEXT: %[[val4:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// BAREPTR-NEXT: %[[ins4:.*]] = llvm.insertvalue %[[val4]], %[[ins3]][4, 1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-// BAREPTR-NEXT: llvm.return %[[ins4]] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
+// BAREPTR-NEXT: %[[base1:.*]] = llvm.extractvalue %[[ins4]][1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
+// BAREPTR-NEXT: llvm.return %[[base1]] : !llvm.ptr<float>
return %static : memref<32x18xf32, offset:7, strides:[22,1]>
}
// -----
// CHECK-LABEL: func @zero_d_alloc() -> !llvm.struct<(ptr<float>, ptr<float>, i64)> {
-// BAREPTR-LABEL: func @zero_d_alloc() -> !llvm.struct<(ptr<float>, ptr<float>, i64)> {
+// BAREPTR-LABEL: func @zero_d_alloc() -> !llvm.ptr<float> {
func @zero_d_alloc() -> memref<f32> {
// CHECK-NEXT: llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm.ptr<float>
@@ -174,7 +176,7 @@ func @aligned_1d_alloc() -> memref<42xf32> {
// -----
// CHECK-LABEL: func @static_alloc() -> !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)> {
-// BAREPTR-LABEL: func @static_alloc() -> !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)> {
+// BAREPTR-LABEL: func @static_alloc() -> !llvm.ptr<float> {
func @static_alloc() -> memref<32x18xf32> {
// CHECK: %[[sz1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64
// CHECK-NEXT: %[[sz2:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64
@@ -388,3 +390,29 @@ func @static_memref_dim(%static : memref<42x32x15x13x27xf32>) {
%4 = dim %static, %c4 : memref<42x32x15x13x27xf32>
return
}
+
+// -----
+
+// BAREPTR: llvm.func @foo(!llvm.ptr<i8>) -> !llvm.ptr<i8>
+func @foo(memref<10xi8>) -> memref<20xi8>
+
+// BAREPTR-LABEL: func @check_memref_func_call
+// BAREPTR-SAME: %[[in:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
+func @check_memref_func_call(%in : memref<10xi8>) -> memref<20xi8> {
+ // BAREPTR: %[[inDesc:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 0]
+ // BAREPTR-NEXT: %[[barePtr:.*]] = llvm.extractvalue %[[inDesc]][1] : !llvm.struct<(ptr<i8>, ptr<i8>, i64, array<1 x i64>, array<1 x i64>)>
+ // BAREPTR-NEXT: %[[call:.*]] = llvm.call @foo(%[[barePtr]]) : (!llvm.ptr<i8>) -> !llvm.ptr<i8>
+ // BAREPTR-NEXT: %[[desc0:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<i8>, ptr<i8>, i64, array<1 x i64>, array<1 x i64>)>
+ // BAREPTR-NEXT: %[[desc1:.*]] = llvm.insertvalue %[[call]], %[[desc0]][0] : !llvm.struct<(ptr<i8>, ptr<i8>, i64, array<1 x i64>, array<1 x i64>)>
+ // BAREPTR-NEXT: %[[desc2:.*]] = llvm.insertvalue %[[call]], %[[desc1]][1] : !llvm.struct<(ptr<i8>, ptr<i8>, i64, array<1 x i64>, array<1 x i64>)>
+ // BAREPTR-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+ // BAREPTR-NEXT: %[[desc4:.*]] = llvm.insertvalue %[[c0]], %[[desc2]][2] : !llvm.struct<(ptr<i8>, ptr<i8>, i64, array<1 x i64>, array<1 x i64>)>
+ // BAREPTR-NEXT: %[[c20:.*]] = llvm.mlir.constant(20 : index) : !llvm.i64
+ // BAREPTR-NEXT: %[[desc6:.*]] = llvm.insertvalue %[[c20]], %[[desc4]][3, 0] : !llvm.struct<(ptr<i8>, ptr<i8>, i64, array<1 x i64>, array<1 x i64>)>
+ // BAREPTR-NEXT: %[[c1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+ // BAREPTR-NEXT: %[[outDesc:.*]] = llvm.insertvalue %[[c1]], %[[desc6]][4, 0] : !llvm.struct<(ptr<i8>, ptr<i8>, i64, array<1 x i64>, array<1 x i64>)>
+ %res = call @foo(%in) : (memref<10xi8>) -> (memref<20xi8>)
+ // BAREPTR-NEXT: %[[res:.*]] = llvm.extractvalue %[[outDesc]][1] : !llvm.struct<(ptr<i8>, ptr<i8>, i64, array<1 x i64>, array<1 x i64>)>
+ // BAREPTR-NEXT: llvm.return %[[res]] : !llvm.ptr<i8>
+ return %res : memref<20xi8>
+}
More information about the Mlir-commits
mailing list