[llvm-branch-commits] [llvm] [mlir] [mlir][LLVM] add argument and result attributes to llvm.call (PR #123177)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Tue Jan 21 06:27:41 PST 2025
https://github.com/jeanPerier updated https://github.com/llvm/llvm-project/pull/123177
>From 137705661c184ea1530982c19163341933ab421e Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Wed, 15 Jan 2025 09:09:53 -0800
Subject: [PATCH 1/4] [mlir][LLVM] add argument and result attributes to
llvm.call
---
llvm/include/llvm/IR/InstrTypes.h | 11 +++
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 4 +-
.../include/mlir/Target/LLVMIR/ModuleImport.h | 8 ++-
.../mlir/Target/LLVMIR/ModuleTranslation.h | 9 ++-
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 67 +++++++++++++------
.../LLVMIR/LLVMToLLVMIRTranslation.cpp | 21 ++++++
mlir/lib/Target/LLVMIR/ModuleImport.cpp | 35 ++++++++++
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 52 ++++++++++----
mlir/test/Dialect/LLVMIR/invalid.mlir | 2 +
mlir/test/Dialect/LLVMIR/roundtrip.mlir | 20 ++++++
.../LLVMIR/Import/call-argument-attributes.ll | 25 +++++++
.../LLVMIR/call-argument-attributes.mlir | 17 +++++
12 files changed, 230 insertions(+), 41 deletions(-)
create mode 100644 mlir/test/Target/LLVMIR/Import/call-argument-attributes.ll
create mode 100644 mlir/test/Target/LLVMIR/call-argument-attributes.mlir
diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h
index b8d9cc10292f4a..0e391325eebdce 100644
--- a/llvm/include/llvm/IR/InstrTypes.h
+++ b/llvm/include/llvm/IR/InstrTypes.h
@@ -1490,6 +1490,11 @@ class CallBase : public Instruction {
Attrs = Attrs.addRetAttribute(getContext(), Attr);
}
+ /// Adds all attributes collected in \p B to the return value.
+ void addRetAttrs(const AttrBuilder &B) {
+ Attrs = Attrs.addRetAttributes(getContext(), B);
+ }
+
/// Adds the attribute to the indicated argument
void addParamAttr(unsigned ArgNo, Attribute::AttrKind Kind) {
assert(ArgNo < arg_size() && "Out of bounds");
@@ -1502,6 +1507,12 @@ class CallBase : public Instruction {
Attrs = Attrs.addParamAttribute(getContext(), ArgNo, Attr);
}
+ /// Adds all attributes collected in \p B to the argument at index \p ArgNo.
+ void addParamAttrs(unsigned ArgNo, const AttrBuilder &B) {
+ assert(ArgNo < arg_size() && "Out of bounds");
+ Attrs = Attrs.addParamAttributes(getContext(), ArgNo, B);
+ }
+
/// removes the attribute from the list of attributes.
void removeAttributeAtIndex(unsigned i, Attribute::AttrKind Kind) {
Attrs = Attrs.removeAttributeAtIndex(getContext(), i, Kind);
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index b2281536aa40b6..85f5c6cc8cca07 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -755,7 +755,9 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
VariadicOfVariadic<LLVM_Type,
"op_bundle_sizes">:$op_bundle_operands,
DenseI32ArrayAttr:$op_bundle_sizes,
- OptionalAttr<ArrayAttr>:$op_bundle_tags);
+ OptionalAttr<ArrayAttr>:$op_bundle_tags,
+ OptionalAttr<DictArrayAttr>:$arg_attrs,
+ OptionalAttr<DictArrayAttr>:$res_attrs);
// Append the aliasing related attributes defined in LLVM_MemAccessOpBase.
let arguments = !con(args, aliasAttrs);
let results = (outs Optional<LLVM_Type>:$result);
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index 33c9af7c6335a4..86e1d6a04cd096 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -326,14 +326,18 @@ class ModuleImport {
SmallVectorImpl<Type> &types,
SmallVectorImpl<Value> &operands,
bool allowInlineAsm = false);
- /// Converts the parameter attributes attached to `func` and adds them to the
- /// `funcOp`.
+ /// Converts the parameter and result attributes attached to `func` and adds
+ /// them to the `funcOp`.
void convertParameterAttributes(llvm::Function *func, LLVMFuncOp funcOp,
OpBuilder &builder);
/// Converts the AttributeSet of one parameter in LLVM IR to a corresponding
/// DictionaryAttr for the LLVM dialect.
DictionaryAttr convertParameterAttribute(llvm::AttributeSet llvmParamAttrs,
OpBuilder &builder);
+ /// Converts the parameter and result attributes attached to `call` and adds
+ /// them to the `callOp`.
+ void convertParameterAttributes(llvm::CallBase *call, CallOpInterface callOp,
+ OpBuilder &builder);
/// Returns the builtin type equivalent to the given LLVM dialect type or
/// nullptr if there is no equivalent. The returned type can be used to create
/// an attribute for a GlobalOp or a ConstantOp.
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index 1b62437761ed9d..88fc17ca4fda24 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -228,6 +228,11 @@ class ModuleTranslation {
/*recordInsertions=*/false);
}
+ /// Translates the parameter attributes of a call (argIdx -1 means result)
+ /// into the returned AttrBuilder. Returns failure if any translation failed.
+ FailureOr<llvm::AttrBuilder> convertParameterAttrs(CallOp callOp, int argIdx,
+ DictionaryAttr paramAttrs);
+
/// Gets the named metadata in the LLVM IR module being constructed, creating
/// it if it does not exist.
llvm::NamedMDNode *getOrInsertNamedModuleMetadata(StringRef name);
@@ -346,8 +351,8 @@ class ModuleTranslation {
convertDialectAttributes(Operation *op,
ArrayRef<llvm::Instruction *> instructions);
- /// Translates parameter attributes and adds them to the returned AttrBuilder.
- /// Returns failure if any of the translations failed.
+ /// Translates parameter attributes of a function and adds them to the
+ /// returned AttrBuilder. Returns failure if any of the translations failed.
FailureOr<llvm::AttrBuilder>
convertParameterAttrs(LLVMFuncOp func, int argIdx, DictionaryAttr paramAttrs);
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index ef1e0222e05f06..6c4988bac7813e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1033,6 +1033,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
/*memory_effects=*/nullptr,
/*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
/*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
+ /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
/*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
}
@@ -1060,6 +1061,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
/*convergent=*/nullptr,
/*no_unwind=*/nullptr, /*will_return=*/nullptr,
/*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
+ /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
/*access_groups=*/nullptr,
/*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
}
@@ -1073,6 +1075,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
/*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
/*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
/*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
+ /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
/*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
}
@@ -1087,6 +1090,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
/*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
/*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
+ /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
/*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
}
@@ -1331,42 +1335,52 @@ void CallOp::print(OpAsmPrinter &p) {
getVarCalleeTypeAttrName(), getCConvAttrName(),
getOperandSegmentSizesAttrName(),
getOpBundleSizesAttrName(),
- getOpBundleTagsAttrName()});
+ getOpBundleTagsAttrName(), getArgAttrsAttrName(),
+ getResAttrsAttrName()});
p << " : ";
if (!isDirect)
p << getOperand(0).getType() << ", ";
// Reconstruct the function MLIR function type from operand and result types.
- p.printFunctionalType(args.getTypes(), getResultTypes());
+ call_interface_impl::printFunctionSignature(
+ p, *this, args.getTypes(), /*isVariadic=*/false, getResultTypes());
}
/// Parses the type of a call operation and resolves the operands if the parsing
/// succeeds. Returns failure otherwise.
static ParseResult parseCallTypeAndResolveOperands(
OpAsmParser &parser, OperationState &result, bool isDirect,
- ArrayRef<OpAsmParser::UnresolvedOperand> operands) {
+ ArrayRef<OpAsmParser::UnresolvedOperand> operands,
+ SmallVectorImpl<DictionaryAttr> &argAttrs,
+ SmallVectorImpl<DictionaryAttr> &resultAttrs) {
SMLoc trailingTypesLoc = parser.getCurrentLocation();
SmallVector<Type> types;
- if (parser.parseColonTypeList(types))
+ if (parser.parseColon())
return failure();
-
- if (isDirect && types.size() != 1)
- return parser.emitError(trailingTypesLoc,
- "expected direct call to have 1 trailing type");
- if (!isDirect && types.size() != 2)
- return parser.emitError(trailingTypesLoc,
- "expected indirect call to have 2 trailing types");
-
- auto funcType = llvm::dyn_cast<FunctionType>(types.pop_back_val());
- if (!funcType)
+ if (!isDirect) {
+ types.emplace_back();
+ if (parser.parseType(types.back()))
+ return failure();
+ if (parser.parseOptionalComma())
+ return parser.emitError(
+ trailingTypesLoc, "expected indirect call to have 2 trailing types");
+ }
+ SmallVector<Type> argTypes;
+ SmallVector<Type> resTypes;
+ if (call_interface_impl::parseFunctionSignature(parser, argTypes, argAttrs,
+ resTypes, resultAttrs)) {
+ if (isDirect)
+ return parser.emitError(trailingTypesLoc,
+ "expected direct call to have 1 trailing type");
return parser.emitError(trailingTypesLoc,
"expected trailing function type");
- if (funcType.getNumResults() > 1)
+ }
+
if (resTypes.size() > 1)
return parser.emitError(trailingTypesLoc,
"expected function with 0 or 1 result");
- if (funcType.getNumResults() == 1 &&
- llvm::isa<LLVM::LLVMVoidType>(funcType.getResult(0)))
+ if (resTypes.size() == 1 && llvm::isa<LLVM::LLVMVoidType>(resTypes[0]))
return parser.emitError(trailingTypesLoc,
"expected a non-void result type");
@@ -1374,12 +1388,12 @@ static ParseResult parseCallTypeAndResolveOperands(
// indirect calls, while the types list is emtpy for direct calls.
// Append the function input types to resolve the call operation
// operands.
- llvm::append_range(types, funcType.getInputs());
+ llvm::append_range(types, argTypes);
if (parser.resolveOperands(operands, types, parser.getNameLoc(),
result.operands))
return failure();
- if (funcType.getNumResults() != 0)
- result.addTypes(funcType.getResults());
+ if (!resTypes.empty())
+ result.addTypes(resTypes);
return success();
}
@@ -1493,8 +1507,14 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
return failure();
// Parse the trailing type list and resolve the operands.
- if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands))
+ SmallVector<DictionaryAttr> argAttrs;
+ SmallVector<DictionaryAttr> resultAttrs;
+ if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands,
+ argAttrs, resultAttrs))
return failure();
+ call_interface_impl::addArgAndResultAttrs(
+ parser.getBuilder(), result, argAttrs, resultAttrs,
+ getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands,
opBundleOperandTypes,
getOpBundleSizesAttrName(result.name)))
@@ -1714,7 +1734,18 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
return failure();
// Parse the trailing type list and resolve the function operands.
- if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands))
+ SmallVector<DictionaryAttr> argAttrs;
+ SmallVector<DictionaryAttr> resultAttrs;
+ SMLoc trailingTypesLoc = parser.getCurrentLocation();
+ if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands,
+ argAttrs, resultAttrs))
return failure();
+ // Unlike llvm.call, llvm.invoke does not record argument/result attributes,
+ // so reject them here instead of silently dropping them.
+ auto hasAttrs = [](DictionaryAttr attrs) { return attrs && !attrs.empty(); };
+ if (llvm::any_of(argAttrs, hasAttrs) || llvm::any_of(resultAttrs, hasAttrs))
+ return parser.emitError(trailingTypesLoc,
+ "argument and result attributes are not supported "
+ "on llvm.invoke");
if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands,
opBundleOperandTypes,
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index 2084e527773ca8..52f42df60f0015 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -265,6 +265,33 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
if (callOp.getWillReturnAttr())
call->addFnAttr(llvm::Attribute::WillReturn);
+ // Translate per-argument attributes; entries with an empty dictionary carry
+ // nothing to translate and are skipped.
+ if (ArrayAttr argAttrsArray = callOp.getArgAttrsAttr())
+ for (auto [argIdx, argAttrsAttr] : llvm::enumerate(argAttrsArray)) {
+ auto argAttrs = llvm::cast<DictionaryAttr>(argAttrsAttr);
+ if (argAttrs.empty())
+ continue;
+ FailureOr<llvm::AttrBuilder> attrBuilder =
+ moduleTranslation.convertParameterAttrs(callOp, argIdx, argAttrs);
+ if (failed(attrBuilder))
+ return failure();
+ call->addParamAttrs(argIdx, *attrBuilder);
+ }
+
+ // A call has at most one result, whose attributes are res_attrs[0].
+ ArrayAttr resAttrsArray = callOp.getResAttrsAttr();
+ if (resAttrsArray && resAttrsArray.size() == 1) {
+ auto resAttrs = llvm::cast<DictionaryAttr>(resAttrsArray[0]);
+ if (!resAttrs.empty()) {
+ FailureOr<llvm::AttrBuilder> attrBuilder =
+ moduleTranslation.convertParameterAttrs(callOp, -1, resAttrs);
+ if (failed(attrBuilder))
+ return failure();
+ call->addRetAttrs(*attrBuilder);
+ }
+ }
+
if (MemoryEffectsAttr memAttr = callOp.getMemoryEffectsAttr()) {
llvm::MemoryEffects memEffects =
llvm::MemoryEffects(llvm::MemoryEffects::Location::ArgMem,
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index eba86f06d09056..f65bf6584d51f2 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1641,6 +1641,9 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
if (callInst->hasFnAttr(llvm::Attribute::WillReturn))
callOp.setWillReturn(true);
+ // Handle parameter and result attributes.
+ convertParameterAttributes(callInst, callOp, builder);
+
llvm::MemoryEffects memEffects = callInst->getMemoryEffects();
ModRefInfo othermem = convertModRefInfoFromLLVM(
memEffects.getModRef(llvm::MemoryEffects::Location::Other));
@@ -2084,6 +2087,38 @@ void ModuleImport::convertParameterAttributes(llvm::Function *func,
builder.getArrayAttr(convertParameterAttribute(llvmResAttr, builder)));
}
+/// Copies the LLVM IR parameter and return attributes of `call` onto `callOp`
+/// as `arg_attrs` / `res_attrs`. Each array is only set when at least one
+/// argument (respectively the return value) carries attributes.
+void ModuleImport::convertParameterAttributes(llvm::CallBase *call,
+ CallOpInterface callOp,
+ OpBuilder &builder) {
+ auto llvmAttrs = call->getAttributes();
+ SmallVector<llvm::AttributeSet> llvmArgAttrsSet;
+ bool anyArgAttrs = false;
+ // Gather the attribute set of every argument and note if any is non-empty.
+ for (size_t i = 0, e = call->arg_size(); i < e; ++i) {
+ llvmArgAttrsSet.emplace_back(llvmAttrs.getParamAttrs(i));
+ if (llvmArgAttrsSet.back().hasAttributes())
+ anyArgAttrs = true;
+ }
+ // Builds an ArrayAttr, replacing null dictionaries with empty ones.
+ auto getArrayAttr = [&](ArrayRef<DictionaryAttr> dictAttrs) {
+ SmallVector<Attribute> attrs;
+ for (auto &dict : dictAttrs)
+ attrs.push_back(dict ? dict : builder.getDictionaryAttr({}));
+ return builder.getArrayAttr(attrs);
+ };
+ if (anyArgAttrs) {
+ SmallVector<DictionaryAttr> argAttrs;
+ for (auto &llvmArgAttrs : llvmArgAttrsSet)
+ argAttrs.emplace_back(convertParameterAttribute(llvmArgAttrs, builder));
+ callOp.setArgAttrsAttr(getArrayAttr(argAttrs));
+ }
+
+ llvm::AttributeSet llvmResAttr = llvmAttrs.getRetAttrs();
+ // Leave res_attrs unset when the return value has no attributes.
+ if (!llvmResAttr.hasAttributes())
+ return;
+ SmallVector<DictionaryAttr, 1> resAttrs;
+ resAttrs.emplace_back(convertParameterAttribute(llvmResAttr, builder));
+ callOp.setResAttrsAttr(getArrayAttr(resAttrs));
+}
+
LogicalResult ModuleImport::processFunction(llvm::Function *func) {
clearRegionState();
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 4367100e3aca68..b2d2c1cddca318 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1563,6 +1563,26 @@ static void convertFunctionKernelAttributes(LLVMFuncOp func,
}
}
+/// Adds `llvmKind` to `attrBuilder`, taking the payload from the value of
+/// `namedAttr`: a converted type for TypeAttr, the raw integer for
+/// IntegerAttr, no payload for UnitAttr, and a constant range for
+/// LLVM::ConstantRangeAttr. Values of any other kind are ignored.
+static void convertParameterAttr(llvm::AttrBuilder &attrBuilder,
+ llvm::Attribute::AttrKind llvmKind,
+ NamedAttribute namedAttr,
+ ModuleTranslation &moduleTranslation) {
+ llvm::TypeSwitch<Attribute>(namedAttr.getValue())
+ .Case<TypeAttr>([&](auto typeAttr) {
+ attrBuilder.addTypeAttr(
+ llvmKind, moduleTranslation.convertType(typeAttr.getValue()));
+ })
+ .Case<IntegerAttr>([&](auto intAttr) {
+ attrBuilder.addRawIntAttr(llvmKind, intAttr.getInt());
+ })
+ .Case<UnitAttr>([&](auto) { attrBuilder.addAttribute(llvmKind); })
+ .Case<LLVM::ConstantRangeAttr>([&](auto rangeAttr) {
+ attrBuilder.addConstantRangeAttr(
+ llvmKind,
+ llvm::ConstantRange(rangeAttr.getLower(), rangeAttr.getUpper()));
+ });
+}
+
FailureOr<llvm::AttrBuilder>
ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx,
DictionaryAttr paramAttrs) {
@@ -1573,20 +1593,7 @@ ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx,
auto it = attrNameToKindMapping.find(namedAttr.getName());
if (it != attrNameToKindMapping.end()) {
llvm::Attribute::AttrKind llvmKind = it->second;
-
- llvm::TypeSwitch<Attribute>(namedAttr.getValue())
- .Case<TypeAttr>([&](auto typeAttr) {
- attrBuilder.addTypeAttr(llvmKind, convertType(typeAttr.getValue()));
- })
- .Case<IntegerAttr>([&](auto intAttr) {
- attrBuilder.addRawIntAttr(llvmKind, intAttr.getInt());
- })
- .Case<UnitAttr>([&](auto) { attrBuilder.addAttribute(llvmKind); })
- .Case<LLVM::ConstantRangeAttr>([&](auto rangeAttr) {
- attrBuilder.addConstantRangeAttr(
- llvmKind, llvm::ConstantRange(rangeAttr.getLower(),
- rangeAttr.getUpper()));
- });
+ convertParameterAttr(attrBuilder, llvmKind, namedAttr, *this);
} else if (namedAttr.getNameDialect()) {
if (failed(iface.convertParameterAttr(func, argIdx, namedAttr, *this)))
return failure();
@@ -1596,6 +1603,23 @@ ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx,
return attrBuilder;
}
+/// Call-site counterpart of the LLVMFuncOp overload: only entries whose name
+/// has an LLVM attribute kind mapping are translated. Other entries (including
+/// dialect attributes) are skipped rather than forwarded to a dialect
+/// interface, so this never returns failure. `argIdx` is currently unused.
+FailureOr<llvm::AttrBuilder>
+ModuleTranslation::convertParameterAttrs(CallOp, int argIdx,
+ DictionaryAttr paramAttrs) {
+ llvm::AttrBuilder attrBuilder(llvmModule->getContext());
+ auto attrNameToKindMapping = getAttrNameToKindMapping();
+
+ for (auto namedAttr : paramAttrs) {
+ auto it = attrNameToKindMapping.find(namedAttr.getName());
+ if (it != attrNameToKindMapping.end()) {
+ llvm::Attribute::AttrKind llvmKind = it->second;
+ convertParameterAttr(attrBuilder, llvmKind, namedAttr, *this);
+ }
+ }
+
+ return attrBuilder;
+}
+
LogicalResult ModuleTranslation::convertFunctionSignatures() {
// Declare all functions first because there may be function calls that form a
// call graph with cycles, or global initializers that reference functions.
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 25806d9d0edd72..14cdcc06625c06 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -235,6 +235,7 @@ func.func @call_missing_ptr_type(%callee : !llvm.func<i8 (i8)>, %arg : i8) {
func.func private @standard_func_callee()
func.func @call_missing_ptr_type(%arg : i8) {
+ // expected-error@+2 {{expected '('}}
// expected-error@+1 {{expected direct call to have 1 trailing type}}
llvm.call @standard_func_callee(%arg) : !llvm.ptr, (i8) -> (i8)
llvm.return
@@ -251,6 +252,7 @@ func.func @call_non_pointer_type(%callee : !llvm.func<i8 (i8)>, %arg : i8) {
// -----
func.func @call_non_function_type(%callee : !llvm.ptr, %arg : i8) {
+ // expected-error@+2 {{expected '('}}
// expected-error@+1 {{expected trailing function type}}
llvm.call %callee(%arg) : !llvm.ptr, !llvm.func<i8 (i8)>
llvm.return
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index 88660ce598f3c2..e565772f06b03c 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -941,3 +941,23 @@ llvm.func @test_assume_intr_with_opbundles(%arg0 : !llvm.ptr) {
llvm.intr.assume %0 ["tag1"(%1, %2 : i32, i32), "tag2"(%3 : i32)] : i1
llvm.return
}
+
+llvm.func @somefunc(i32, !llvm.ptr)
+
+// CHECK-LABEL: llvm.func @test_call_arg_attrs_direct(
+// CHECK-SAME: %[[VAL_0:.*]]: i32,
+// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr)
+llvm.func @test_call_arg_attrs_direct(%arg0: i32, %arg1: !llvm.ptr) {
+ // CHECK: llvm.call @somefunc(%[[VAL_0]], %[[VAL_1]]) : (i32, !llvm.ptr {llvm.byval = i64}) -> ()
+ llvm.call @somefunc(%arg0, %arg1) : (i32, !llvm.ptr {llvm.byval = i64}) -> ()
+ llvm.return
+}
+
+// CHECK-LABEL: llvm.func @test_call_arg_attrs_indirect(
+// CHECK-SAME: %[[VAL_0:.*]]: i16,
+// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr
+llvm.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: !llvm.ptr) -> i16 {
+ // CHECK: llvm.call tail %[[VAL_1]](%[[VAL_0]]) : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
+ %0 = llvm.call tail %arg1(%arg0) : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
+ llvm.return %0 : i16
+}
diff --git a/mlir/test/Target/LLVMIR/Import/call-argument-attributes.ll b/mlir/test/Target/LLVMIR/Import/call-argument-attributes.ll
new file mode 100644
index 00000000000000..8294579b48c63c
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/call-argument-attributes.ll
@@ -0,0 +1,25 @@
+; RUN: mlir-translate -import-llvm %s | FileCheck %s
+
+; CHECK-LABEL: llvm.func @somefunc(i32, !llvm.ptr)
+declare void @somefunc(i32, ptr)
+
+; CHECK-LABEL: llvm.func @test_call_arg_attrs_direct(
+; CHECK-SAME: %[[VAL_0:.*]]: i32,
+; CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr)
+llvm.func @test_call_arg_attrs_direct(%arg0: i32, %arg1: !llvm.ptr) {
+declare void @somefunc(i32, ptr)
+; CHECK-LABEL: @test_call_arg_attrs_direct
+define void @test_call_arg_attrs_direct(i32 %0, ptr %1) {
+ ; CHECK: llvm.call @somefunc(%[[VAL_0]], %[[VAL_1]]) : (i32, !llvm.ptr {llvm.byval = i64}) -> ()
+ call void @somefunc(i32 %0, ptr byval(i64) %1)
+ ret void
+}
+
+; CHECK-LABEL: llvm.func @test_call_arg_attrs_indirect(
+; CHECK-SAME: %[[VAL_0:.*]]: i16,
+; CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr
+define i16 @test_call_arg_attrs_indirect(i16 %0, ptr %1) {
+; CHECK: llvm.call tail %[[VAL_1]](%[[VAL_0]]) : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
+ %3 = tail call signext i16 %1(i16 noundef signext %0)
+ ret i16 %3
+}
diff --git a/mlir/test/Target/LLVMIR/call-argument-attributes.mlir b/mlir/test/Target/LLVMIR/call-argument-attributes.mlir
new file mode 100644
index 00000000000000..89b1f29a68623b
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/call-argument-attributes.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+llvm.func @somefunc(i32, !llvm.ptr)
+
+// CHECK-LABEL: define void @test_call_arg_attrs_direct
+llvm.func @test_call_arg_attrs_direct(%arg0: i32, %arg1: !llvm.ptr) {
+ // CHECK: call void @somefunc(i32 %{{.*}}, ptr byval(i64) %{{.*}})
+ llvm.call @somefunc(%arg0, %arg1) : (i32, !llvm.ptr {llvm.byval = i64}) -> ()
+ llvm.return
+}
+
+// CHECK-LABEL: define i16 @test_call_arg_attrs_indirect
+llvm.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: !llvm.ptr) -> i16 {
+ // CHECK: tail call signext i16 %{{.*}}(i16 noundef signext %{{.*}})
+ %0 = llvm.call tail %arg1(%arg0) : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
+ llvm.return %0 : i16
+}
>From 879b03de74daffe4f83a5c72f76fbeed495f73bf Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Thu, 16 Jan 2025 04:42:18 -0800
Subject: [PATCH 2/4] remove bogus extra lines in new test
---
mlir/test/Target/LLVMIR/Import/call-argument-attributes.ll | 3 ---
1 file changed, 3 deletions(-)
diff --git a/mlir/test/Target/LLVMIR/Import/call-argument-attributes.ll b/mlir/test/Target/LLVMIR/Import/call-argument-attributes.ll
index 8294579b48c63c..2c86ca6b03125e 100644
--- a/mlir/test/Target/LLVMIR/Import/call-argument-attributes.ll
+++ b/mlir/test/Target/LLVMIR/Import/call-argument-attributes.ll
@@ -6,9 +6,6 @@ declare void @somefunc(i32, ptr)
; CHECK-LABEL: llvm.func @test_call_arg_attrs_direct(
; CHECK-SAME: %[[VAL_0:.*]]: i32,
; CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr)
-llvm.func @test_call_arg_attrs_direct(%arg0: i32, %arg1: !llvm.ptr) {
-declare void @somefunc(i32, ptr)
-; CHECK-LABEL: @test_call_arg_attrs_direct
define void @test_call_arg_attrs_direct(i32 %0, ptr %1) {
; CHECK: llvm.call @somefunc(%[[VAL_0]], %[[VAL_1]]) : (i32, !llvm.ptr {llvm.byval = i64}) -> ()
call void @somefunc(i32 %0, ptr byval(i64) %1)
>From 854e43c16d73a3645cd224be2861836373079b1f Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Thu, 16 Jan 2025 08:59:35 -0800
Subject: [PATCH 3/4] change inheritance level of new interface
---
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 1 +
mlir/lib/Target/LLVMIR/ModuleImport.cpp | 8 ++++++--
2 files changed, 7 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 85f5c6cc8cca07..5c7a697107c237 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -699,6 +699,7 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
[AttrSizedOperandSegments,
DeclareOpInterfaceMethods<FastmathFlagsInterface>,
DeclareOpInterfaceMethods<CallOpInterface>,
+ DeclareOpInterfaceMethods<ArgumentAttributesOpInterface>,
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
DeclareOpInterfaceMethods<BranchWeightOpInterface>]> {
let summary = "Call to an LLVM function.";
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index f65bf6584d51f2..f51b577c255660 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -2090,6 +2090,10 @@ void ModuleImport::convertParameterAttributes(llvm::Function *func,
void ModuleImport::convertParameterAttributes(llvm::CallBase *call,
CallOpInterface callOp,
OpBuilder &builder) {
+ auto argAttrsOpInterface =
+ dyn_cast<ArgumentAttributesOpInterface>(callOp.getOperation());
+ if (!argAttrsOpInterface)
+ return;
auto llvmAttrs = call->getAttributes();
SmallVector<llvm::AttributeSet> llvmArgAttrsSet;
bool anyArgAttrs = false;
@@ -2108,7 +2112,7 @@ void ModuleImport::convertParameterAttributes(llvm::CallBase *call,
SmallVector<DictionaryAttr> argAttrs;
for (auto &llvmArgAttrs : llvmArgAttrsSet)
argAttrs.emplace_back(convertParameterAttribute(llvmArgAttrs, builder));
- callOp.setArgAttrsAttr(getArrayAttr(argAttrs));
+ argAttrsOpInterface.setArgAttrsAttr(getArrayAttr(argAttrs));
}
llvm::AttributeSet llvmResAttr = llvmAttrs.getRetAttrs();
@@ -2116,7 +2120,7 @@ void ModuleImport::convertParameterAttributes(llvm::CallBase *call,
return;
SmallVector<DictionaryAttr, 1> resAttrs;
resAttrs.emplace_back(convertParameterAttribute(llvmResAttr, builder));
- callOp.setResAttrsAttr(getArrayAttr(resAttrs));
+ argAttrsOpInterface.setResAttrsAttr(getArrayAttr(resAttrs));
}
LogicalResult ModuleImport::processFunction(llvm::Function *func) {
>From b588ff3dcdcd85414ce0ed6b274cce2ee3db2bd5 Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Tue, 21 Jan 2025 02:02:00 -0800
Subject: [PATCH 4/4] adapt to ArgumentAttributesOpInterface iface removal
---
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 1 -
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 3 ++-
mlir/lib/Target/LLVMIR/ModuleImport.cpp | 8 ++------
3 files changed, 4 insertions(+), 8 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index f6721819d04e6f..ee6e10efed4f16 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -701,7 +701,6 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
[AttrSizedOperandSegments,
DeclareOpInterfaceMethods<FastmathFlagsInterface>,
DeclareOpInterfaceMethods<CallOpInterface>,
- DeclareOpInterfaceMethods<ArgumentAttributesOpInterface>,
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
DeclareOpInterfaceMethods<BranchWeightOpInterface>]> {
let summary = "Call to an LLVM function.";
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index a1a14f41e122b5..c10abdc24527e4 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1344,7 +1344,8 @@ void CallOp::print(OpAsmPrinter &p) {
// Reconstruct the function MLIR function type from operand and result types.
call_interface_impl::printFunctionSignature(
- p, *this, args.getTypes(), /*isVariadic=*/false, getResultTypes());
+ p, args.getTypes(), getArgAttrsAttr(),
+ /*isVariadic=*/false, getResultTypes(), getResAttrsAttr());
}
/// Parses the type of a call operation and resolves the operands if the parsing
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index f51b577c255660..f65bf6584d51f2 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -2090,10 +2090,6 @@ void ModuleImport::convertParameterAttributes(llvm::Function *func,
void ModuleImport::convertParameterAttributes(llvm::CallBase *call,
CallOpInterface callOp,
OpBuilder &builder) {
- auto argAttrsOpInterface =
- dyn_cast<ArgumentAttributesOpInterface>(callOp.getOperation());
- if (!argAttrsOpInterface)
- return;
auto llvmAttrs = call->getAttributes();
SmallVector<llvm::AttributeSet> llvmArgAttrsSet;
bool anyArgAttrs = false;
@@ -2112,7 +2108,7 @@ void ModuleImport::convertParameterAttributes(llvm::CallBase *call,
SmallVector<DictionaryAttr> argAttrs;
for (auto &llvmArgAttrs : llvmArgAttrsSet)
argAttrs.emplace_back(convertParameterAttribute(llvmArgAttrs, builder));
- argAttrsOpInterface.setArgAttrsAttr(getArrayAttr(argAttrs));
+ callOp.setArgAttrsAttr(getArrayAttr(argAttrs));
}
llvm::AttributeSet llvmResAttr = llvmAttrs.getRetAttrs();
@@ -2120,7 +2116,7 @@ void ModuleImport::convertParameterAttributes(llvm::CallBase *call,
return;
SmallVector<DictionaryAttr, 1> resAttrs;
resAttrs.emplace_back(convertParameterAttribute(llvmResAttr, builder));
- argAttrsOpInterface.setResAttrsAttr(getArrayAttr(resAttrs));
+ callOp.setResAttrsAttr(getArrayAttr(resAttrs));
}
LogicalResult ModuleImport::processFunction(llvm::Function *func) {
More information about the llvm-branch-commits
mailing list