[Mlir-commits] [mlir] 4a01079 - Expose Tail Kind Call to MLIR (#98080)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jul 9 14:04:37 PDT 2024
Author: Steffi Stumpos
Date: 2024-07-09T14:04:33-07:00
New Revision: 4a010799317dfe19758477f693968fc594c1895d
URL: https://github.com/llvm/llvm-project/commit/4a010799317dfe19758477f693968fc594c1895d
DIFF: https://github.com/llvm/llvm-project/commit/4a010799317dfe19758477f693968fc594c1895d.diff
LOG: Expose Tail Kind Call to MLIR (#98080)
I would like to mark a call op in the LLVM dialect as musttail. The calling
convention attribute only exposes tail, not musttail. I noticed that LLVM's
CallInst has an additional field that specifies the kind of tail call. I
bubbled this up to the LLVM dialect by adding another attribute that maps to
llvm::CallInst::TailCallKind.
Added:
mlir/test/Dialect/LLVMIR/tail-call-kinds.mlir
mlir/test/Target/LLVMIR/Import/tail-kind.ll
Modified:
mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
mlir/lib/Target/LLVMIR/ModuleImport.cpp
mlir/test/Dialect/LLVMIR/roundtrip.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index b05366d2a635d..25a6ee27b01db 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -1077,4 +1077,13 @@ def LLVM_PoisonAttr : LLVM_Attr<"Poison", "poison">;
/// Folded into from LLVM::ZeroOp.
def LLVM_ZeroAttr : LLVM_Attr<"Zero", "zero">;
+//===----------------------------------------------------------------------===//
+// TailCallKindAttr
+//===----------------------------------------------------------------------===//
+
+def TailCallKindAttr : LLVM_Attr<"TailCallKind", "tailcallkind"> {
+ let parameters = (ins "TailCallKind":$TailCallKind);
+ let assemblyFormat = "`<` $TailCallKind `>`";
+}
+
#endif // LLVMIR_ATTRDEFS
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
index 3a93be21da375..3ede857733242 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
@@ -89,6 +89,7 @@ class TBAANodeAttr : public Attribute {
// TODO: this shouldn't be needed after we unify the attribute generation, i.e.
// --gen-attr-* and --gen-attrdef-*.
using cconv::CConv;
+using tailcallkind::TailCallKind;
using linkage::Linkage;
} // namespace LLVM
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
index f8e85004d5f93..f41a97f9ecc81 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
@@ -279,6 +279,35 @@ def CConv : DialectAttr<
"::mlir::LLVM::CConvAttr::get($_builder.getContext(), $0)";
}
+//===----------------------------------------------------------------------===//
+// TailCallKind
+//===----------------------------------------------------------------------===//
+
+def TailCallKindNone : LLVM_EnumAttrCase<"None", "none", "TCK_None", 0>;
+def TailCallKindTail : LLVM_EnumAttrCase<"Tail", "tail", "TCK_Tail", 1>;
+def TailCallKindMustTail : LLVM_EnumAttrCase<"MustTail", "musttail", "TCK_MustTail", 2>;
+def TailCallKindNoTailCall : LLVM_EnumAttrCase<"NoTail", "notail", "TCK_NoTail", 3>;
+// Mirrors ::llvm::CallInst::TailCallKind (TCK_* values); cases in value order.
+def TailCallKindEnum : LLVM_EnumAttr<
+ "TailCallKind",
+ "::llvm::CallInst::TailCallKind",
+ "Tail Call Kind",
+ [TailCallKindNone, TailCallKindTail,
+ TailCallKindMustTail, TailCallKindNoTailCall]> {
+ let cppNamespace = "::mlir::LLVM::tailcallkind";
+}
+
+def TailCallKind : DialectAttr<
+ LLVM_Dialect,
+ CPred<"::llvm::isa<::mlir::LLVM::TailCallKindAttr>($_self)">,
+ "LLVM Tail Call Kind specification"> {
+ let storageType = "::mlir::LLVM::TailCallKindAttr";
+ let returnType = "::mlir::LLVM::tailcallkind::TailCallKind";
+ let convertFromStorage = "$_self.getTailCallKind()";
+ let constBuilderCall =
+ "::mlir::LLVM::TailCallKindAttr::get($_builder.getContext(), $0)";
+}
+
//===----------------------------------------------------------------------===//
// DIEmissionKind
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 54f38c93e5080..65dfcf93d7029 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -650,7 +650,8 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
DefaultValuedAttr<LLVM_FastmathFlagsAttr,
"{}">:$fastmathFlags,
OptionalAttr<DenseI32ArrayAttr>:$branch_weights,
- DefaultValuedAttr<CConv, "CConv::C">:$CConv);
+ DefaultValuedAttr<CConv, "CConv::C">:$CConv,
+ DefaultValuedAttr<TailCallKind, "TailCallKind::None">:$TailCallKind);
// Append the aliasing related attributes defined in LLVM_MemAccessOpBase.
let arguments = !con(args, aliasAttrs);
let results = (outs Optional<LLVM_Type>:$result);
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index a01c4ee4923eb..9372caf6e32a7 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -44,6 +44,7 @@ using namespace mlir;
using namespace mlir::LLVM;
using mlir::LLVM::cconv::getMaxEnumValForCConv;
using mlir::LLVM::linkage::getMaxEnumValForLinkage;
+using mlir::LLVM::tailcallkind::getMaxEnumValForTailCallKind;
#include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc"
@@ -197,6 +198,7 @@ struct EnumTraits {};
REGISTER_ENUM_TYPE(Linkage);
REGISTER_ENUM_TYPE(UnnamedAddr);
REGISTER_ENUM_TYPE(CConv);
+REGISTER_ENUM_TYPE(TailCallKind);
REGISTER_ENUM_TYPE(Visibility);
} // namespace
@@ -974,7 +976,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
build(builder, state, results,
TypeAttr::get(getLLVMFuncType(builder.getContext(), results, args)),
callee, args, /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
- /*CConv=*/nullptr,
+ /*CConv=*/nullptr, /*TailCallKind=*/nullptr,
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
/*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
}
@@ -997,7 +999,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
build(builder, state, getCallOpResultTypes(calleeType),
TypeAttr::get(calleeType), callee, args, /*fastmathFlags=*/nullptr,
/*branch_weights=*/nullptr, /*CConv=*/nullptr,
- /*access_groups=*/nullptr,
+ /*TailCallKind=*/nullptr, /*access_groups=*/nullptr,
/*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
}
@@ -1006,7 +1008,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
build(builder, state, getCallOpResultTypes(calleeType),
TypeAttr::get(calleeType), /*callee=*/nullptr, args,
/*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
- /*CConv=*/nullptr,
+ /*CConv=*/nullptr, /*TailCallKind=*/nullptr,
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
/*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
}
@@ -1017,7 +1019,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
build(builder, state, getCallOpResultTypes(calleeType),
TypeAttr::get(calleeType), SymbolRefAttr::get(func), args,
/*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
- /*CConv=*/nullptr,
+ /*CConv=*/nullptr, /*TailCallKind=*/nullptr,
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
/*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
}
@@ -1180,6 +1182,9 @@ void CallOp::print(OpAsmPrinter &p) {
if (getCConv() != LLVM::CConv::C)
p << stringifyCConv(getCConv()) << ' ';
+ if(getTailCallKind() != LLVM::TailCallKind::None)
+ p << tailcallkind::stringifyTailCallKind(getTailCallKind()) << ' ';
+
// Print the direct callee if present as a function attribute, or an indirect
// callee (first operand) otherwise.
if (isDirect)
@@ -1194,7 +1199,8 @@ void CallOp::print(OpAsmPrinter &p) {
p << " vararg(" << calleeType << ")";
p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
- {getCConvAttrName(), "callee", "callee_type"});
+ {getCConvAttrName(), "callee", "callee_type",
+ getTailCallKindAttrName()});
p << " : ";
if (!isDirect)
@@ -1262,7 +1268,7 @@ static ParseResult parseOptionalCallFuncPtr(
return success();
}
-// <operation> ::= `llvm.call` (cconv)? (function-id | ssa-use)
+// <operation> ::= `llvm.call` (cconv)? (tailcallkind)? (function-id | ssa-use)
// `(` ssa-use-list `)`
// ( `vararg(` var-arg-func-type `)` )?
// attribute-dict? `:` (type `,`)? function-type
@@ -1277,6 +1283,12 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>(
parser, result, LLVM::CConv::C)));
+ result.addAttribute(
+ getTailCallKindAttrName(result.name),
+ TailCallKindAttr::get(parser.getContext(),
+ parseOptionalLLVMKeyword<TailCallKind>(
+ parser, result, LLVM::TailCallKind::None)));
+
// Parse a function pointer for indirect calls.
if (parseOptionalCallFuncPtr(parser, operands))
return failure();
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index f144c7158d679..3d6dd1247b413 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -218,6 +218,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
operandsRef.drop_front());
}
call->setCallingConv(convertCConvToLLVM(callOp.getCConv()));
+ call->setTailCallKind(convertTailCallKindToLLVM(callOp.getTailCallKind()));
moduleTranslation.setAccessGroupsMetadata(callOp, call);
moduleTranslation.setAliasScopeMetadata(callOp, call);
moduleTranslation.setTBAAMetadata(callOp, call);
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 0c8b3296f44a7..9915576bbc458 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1468,6 +1468,8 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
callOp = builder.create<CallOp>(loc, funcTy, operands);
}
callOp.setCConv(convertCConvFromLLVM(callInst->getCallingConv()));
+ callOp.setTailCallKind(
+ convertTailCallKindFromLLVM(callInst->getTailCallKind()));
setFastmathFlagsAttr(inst, callOp);
if (!callInst->getType()->isVoidTy())
mapValue(inst, callOp.getResult());
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index 2386dde19301e..ca9748a2b8b7b 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -673,3 +673,41 @@ llvm.func @experimental_constrained_fptrunc(%in: f64) {
%4 = llvm.intr.experimental.constrained.fptrunc %in tonearestaway ignore : f64 to f32
llvm.return
}
+
+// CHECK: llvm.func @tail_call_target() -> i32
+llvm.func @tail_call_target() -> i32
+
+// CHECK-LABEL: @test_none
+llvm.func @test_none() -> i32 {
+ // CHECK-NEXT: llvm.call @tail_call_target() : () -> i32
+ %0 = llvm.call none @tail_call_target() : () -> i32
+ llvm.return %0 : i32
+}
+
+// CHECK-LABEL: @test_default
+llvm.func @test_default() -> i32 {
+ // CHECK-NEXT: llvm.call @tail_call_target() : () -> i32
+ %0 = llvm.call @tail_call_target() : () -> i32
+ llvm.return %0 : i32
+}
+
+// CHECK-LABEL: @test_musttail
+llvm.func @test_musttail() -> i32 {
+ // CHECK-NEXT: llvm.call musttail @tail_call_target() : () -> i32
+ %0 = llvm.call musttail @tail_call_target() : () -> i32
+ llvm.return %0 : i32
+}
+
+// CHECK-LABEL: @test_tail
+llvm.func @test_tail() -> i32 {
+ // CHECK-NEXT: llvm.call tail @tail_call_target() : () -> i32
+ %0 = llvm.call tail @tail_call_target() : () -> i32
+ llvm.return %0 : i32
+}
+
+// CHECK-LABEL: @test_notail
+llvm.func @test_notail() -> i32 {
+ // CHECK-NEXT: llvm.call notail @tail_call_target() : () -> i32
+ %0 = llvm.call notail @tail_call_target() : () -> i32
+ llvm.return %0 : i32
+}
diff --git a/mlir/test/Dialect/LLVMIR/tail-call-kinds.mlir b/mlir/test/Dialect/LLVMIR/tail-call-kinds.mlir
new file mode 100644
index 0000000000000..73a6aa2f91cba
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/tail-call-kinds.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
+
+// CHECK: declare i32 @foo()
+llvm.func @foo() -> i32
+
+// CHECK-LABEL: @test_none
+llvm.func @test_none() -> i32 {
+ // CHECK-NEXT: call i32 @foo()
+ %0 = llvm.call none @foo() : () -> i32
+ llvm.return %0 : i32
+}
+
+// CHECK-LABEL: @test_default
+llvm.func @test_default() -> i32 {
+ // CHECK-NEXT: call i32 @foo()
+ %0 = llvm.call @foo() : () -> i32
+ llvm.return %0 : i32
+}
+
+// CHECK-LABEL: @test_musttail
+llvm.func @test_musttail() -> i32 {
+ // CHECK-NEXT: musttail call i32 @foo()
+ %0 = llvm.call musttail @foo() : () -> i32
+ llvm.return %0 : i32
+}
+
+// CHECK-LABEL: @test_tail
+llvm.func @test_tail() -> i32 {
+ // CHECK-NEXT: tail call i32 @foo()
+ %0 = llvm.call tail @foo() : () -> i32
+ llvm.return %0 : i32
+}
+
+// CHECK-LABEL: @test_notail
+llvm.func @test_notail() -> i32 {
+ // CHECK-NEXT: notail call i32 @foo()
+ %0 = llvm.call notail @foo() : () -> i32
+ llvm.return %0 : i32
+}
diff --git a/mlir/test/Target/LLVMIR/Import/tail-kind.ll b/mlir/test/Target/LLVMIR/Import/tail-kind.ll
new file mode 100644
index 0000000000000..608ae4043b671
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/tail-kind.ll
@@ -0,0 +1,35 @@
+; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s
+
+; CHECK: llvm.func @tailkind()
+declare void @tailkind()
+
+; CHECK-LABEL: @call_tailkind
+define void @call_tailkind() {
+ ; CHECK: llvm.call musttail @tailkind()
+ musttail call void @tailkind()
+ ret void
+}
+
+; // -----
+
+; CHECK: llvm.func @tailkind()
+declare void @tailkind()
+
+; CHECK-LABEL: @call_tailkind
+define void @call_tailkind() {
+ ; CHECK: llvm.call tail @tailkind()
+ tail call void @tailkind()
+ ret void
+}
+
+; // -----
+
+; CHECK: llvm.func @tailkind()
+declare void @tailkind()
+
+; CHECK-LABEL: @call_tailkind
+define void @call_tailkind() {
+ ; CHECK: llvm.call notail @tailkind()
+ notail call void @tailkind()
+ ret void
+}
More information about the Mlir-commits mailing list