[Mlir-commits] [mlir] Expose Tail Kind Call to MLIR (PR #98080)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jul 8 14:28:16 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-llvm
Author: Steffi Stumpos (stumpOS)
<details>
<summary>Changes</summary>
I would like to mark a call op in the LLVM dialect as `musttail`. The calling convention attribute only exposes a `Tail` calling convention; it cannot express `musttail`, which in LLVM is a marker on the call instruction rather than a calling convention. I noticed that LLVM's `CallInst` has an additional field that specifies the kind of tail call. I exposed this in the LLVM dialect by adding another attribute that maps to `llvm::CallInst::TailCallKind`.
---
Full diff: https://github.com/llvm/llvm-project/pull/98080.diff
9 Files Affected:
- (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td (+10)
- (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h (+1)
- (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td (+29)
- (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td (+2-1)
- (modified) mlir/include/mlir/IR/DialectInterface.h (+1)
- (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+18-6)
- (modified) mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp (+1)
- (modified) mlir/test/Dialect/LLVMIR/roundtrip.mlir (+38)
- (added) mlir/test/Dialect/LLVMIR/tail-call-kinds.mlir (+39)
``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index b05366d2a635df..db43fb023b107c 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -1077,4 +1077,14 @@ def LLVM_PoisonAttr : LLVM_Attr<"Poison", "poison">;
/// Folded into from LLVM::ZeroOp.
def LLVM_ZeroAttr : LLVM_Attr<"Zero", "zero">;
+
+//===----------------------------------------------------------------------===//
+// TailCallKindAttr
+//===----------------------------------------------------------------------===//
+
+def TailCallKindAttr : LLVM_Attr<"TailCallKind", "tailcallkind"> {
+ let parameters = (ins "TailCallKind":$TailCallKind);
+ let assemblyFormat = "`<` $TailCallKind `>`";
+}
+
#endif // LLVMIR_ATTRDEFS
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
index 3a93be21da3756..3ede8577332422 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
@@ -89,6 +89,7 @@ class TBAANodeAttr : public Attribute {
// TODO: this shouldn't be needed after we unify the attribute generation, i.e.
// --gen-attr-* and --gen-attrdef-*.
using cconv::CConv;
+using tailcallkind::TailCallKind;
using linkage::Linkage;
} // namespace LLVM
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
index f8e85004d5f93c..f41a97f9ecc818 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
@@ -279,6 +279,35 @@ def CConv : DialectAttr<
"::mlir::LLVM::CConvAttr::get($_builder.getContext(), $0)";
}
+//===----------------------------------------------------------------------===//
+// TailCallKind
+//===----------------------------------------------------------------------===//
+
+def TailCallKindNone : LLVM_EnumAttrCase<"None", "none", "TCK_None", 0>;
+def TailCallKindTail : LLVM_EnumAttrCase<"Tail", "tail", "TCK_Tail", 1>;
+def TailCallKindMustTail : LLVM_EnumAttrCase<"MustTail", "musttail", "TCK_MustTail", 2>;
+def TailCallKindNoTailCall : LLVM_EnumAttrCase<"NoTail", "notail", "TCK_NoTail", 3>;
+
+def TailCallKindEnum : LLVM_EnumAttr<
+ "TailCallKind",
+ "::llvm::CallInst::TailCallKind",
+ "Tail Call Kind",
+ [TailCallKindNone, TailCallKindNoTailCall,
+ TailCallKindMustTail, TailCallKindTail]> {
+ let cppNamespace = "::mlir::LLVM::tailcallkind";
+}
+
+def TailCallKind : DialectAttr<
+ LLVM_Dialect,
+ CPred<"::llvm::isa<::mlir::LLVM::TailCallKindAttr>($_self)">,
+ "LLVM Calling Convention specification"> {
+ let storageType = "::mlir::LLVM::TailCallKindAttr";
+ let returnType = "::mlir::LLVM::tailcallkind::TailCallKind";
+ let convertFromStorage = "$_self.getTailCallKind()";
+ let constBuilderCall =
+ "::mlir::LLVM::TailCallKindAttr::get($_builder.getContext(), $0)";
+}
+
//===----------------------------------------------------------------------===//
// DIEmissionKind
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 3774bda05eb2be..807dc41e11d28d 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -650,7 +650,8 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
DefaultValuedAttr<LLVM_FastmathFlagsAttr,
"{}">:$fastmathFlags,
OptionalAttr<DenseI32ArrayAttr>:$branch_weights,
- DefaultValuedAttr<CConv, "CConv::C">:$CConv);
+ DefaultValuedAttr<CConv, "CConv::C">:$CConv,
+ DefaultValuedAttr<TailCallKind, "TailCallKind::None">:$TailCallKind);
// Append the aliasing related attributes defined in LLVM_MemAccessOpBase.
let arguments = !con(args, aliasAttrs);
let results = (outs Optional<LLVM_Type>:$result);
diff --git a/mlir/include/mlir/IR/DialectInterface.h b/mlir/include/mlir/IR/DialectInterface.h
index 3a7ad87b161eea..36502cc304b695 100644
--- a/mlir/include/mlir/IR/DialectInterface.h
+++ b/mlir/include/mlir/IR/DialectInterface.h
@@ -12,6 +12,7 @@
#include "mlir/Support/TypeID.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
+#include <vector>
namespace mlir {
class Dialect;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index d1280aceeb7b6c..fdeb55ef0c17bb 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -44,6 +44,7 @@ using namespace mlir;
using namespace mlir::LLVM;
using mlir::LLVM::cconv::getMaxEnumValForCConv;
using mlir::LLVM::linkage::getMaxEnumValForLinkage;
+using mlir::LLVM::tailcallkind::getMaxEnumValForTailCallKind;
#include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc"
@@ -197,6 +198,7 @@ struct EnumTraits {};
REGISTER_ENUM_TYPE(Linkage);
REGISTER_ENUM_TYPE(UnnamedAddr);
REGISTER_ENUM_TYPE(CConv);
+REGISTER_ENUM_TYPE(TailCallKind);
REGISTER_ENUM_TYPE(Visibility);
} // namespace
@@ -974,7 +976,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
build(builder, state, results,
TypeAttr::get(getLLVMFuncType(builder.getContext(), results, args)),
callee, args, /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
- /*CConv=*/nullptr,
+ /*CConv=*/nullptr, /*TailCallKind=*/nullptr,
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
/*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
}
@@ -997,7 +999,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
build(builder, state, getCallOpResultTypes(calleeType),
TypeAttr::get(calleeType), callee, args, /*fastmathFlags=*/nullptr,
/*branch_weights=*/nullptr, /*CConv=*/nullptr,
- /*access_groups=*/nullptr,
+ /*TailCallKind=*/nullptr, /*access_groups=*/nullptr,
/*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
}
@@ -1006,7 +1008,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
build(builder, state, getCallOpResultTypes(calleeType),
TypeAttr::get(calleeType), /*callee=*/nullptr, args,
/*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
- /*CConv=*/nullptr,
+ /*CConv=*/nullptr, /*TailCallKind=*/nullptr,
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
/*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
}
@@ -1017,7 +1019,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
build(builder, state, getCallOpResultTypes(calleeType),
TypeAttr::get(calleeType), SymbolRefAttr::get(func), args,
/*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
- /*CConv=*/nullptr,
+ /*CConv=*/nullptr, /*TailCallKind=*/nullptr,
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
/*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
}
@@ -1180,6 +1182,9 @@ void CallOp::print(OpAsmPrinter &p) {
if (getCConv() != LLVM::CConv::C)
p << stringifyCConv(getCConv()) << ' ';
+ if(getTailCallKind() != LLVM::TailCallKind::None)
+ p << tailcallkind::stringifyTailCallKind(getTailCallKind()) << ' ';
+
// Print the direct callee if present as a function attribute, or an indirect
// callee (first operand) otherwise.
if (isDirect)
@@ -1194,7 +1199,8 @@ void CallOp::print(OpAsmPrinter &p) {
p << " vararg(" << calleeType << ")";
p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
- {getCConvAttrName(), "callee", "callee_type"});
+ {getCConvAttrName(), "callee", "callee_type",
+ getTailCallKindAttrName()});
p << " : ";
if (!isDirect)
@@ -1262,7 +1268,7 @@ static ParseResult parseOptionalCallFuncPtr(
return success();
}
-// <operation> ::= `llvm.call` (cconv)? (function-id | ssa-use)
+// <operation> ::= `llvm.call` (cconv)? (tailcallkind)? (function-id | ssa-use)
// `(` ssa-use-list `)`
// ( `vararg(` var-arg-func-type `)` )?
// attribute-dict? `:` (type `,`)? function-type
@@ -1277,6 +1283,12 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>(
parser, result, LLVM::CConv::C)));
+ result.addAttribute(
+ getTailCallKindAttrName(result.name),
+ TailCallKindAttr::get(parser.getContext(),
+ parseOptionalLLVMKeyword<TailCallKind>(
+ parser, result, LLVM::TailCallKind::None)));
+
// Parse a function pointer for indirect calls.
if (parseOptionalCallFuncPtr(parser, operands))
return failure();
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index f144c7158d6796..3d6dd1247b4136 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -218,6 +218,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
operandsRef.drop_front());
}
call->setCallingConv(convertCConvToLLVM(callOp.getCConv()));
+ call->setTailCallKind(convertTailCallKindToLLVM(callOp.getTailCallKind()));
moduleTranslation.setAccessGroupsMetadata(callOp, call);
moduleTranslation.setAliasScopeMetadata(callOp, call);
moduleTranslation.setTBAAMetadata(callOp, call);
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index 2386dde19301e3..ca9748a2b8b7bc 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -673,3 +673,41 @@ llvm.func @experimental_constrained_fptrunc(%in: f64) {
%4 = llvm.intr.experimental.constrained.fptrunc %in tonearestaway ignore : f64 to f32
llvm.return
}
+
+// CHECK: llvm.func @tail_call_target() -> i32
+llvm.func @tail_call_target() -> i32
+
+// CHECK-LABEL: @test_none
+llvm.func @test_none() -> i32 {
+ // CHECK-NEXT: llvm.call @tail_call_target() : () -> i32
+ %0 = llvm.call none @tail_call_target() : () -> i32
+ llvm.return %0 : i32
+}
+
+// CHECK-LABEL: @test_default
+llvm.func @test_default() -> i32 {
+ // CHECK-NEXT: llvm.call @tail_call_target() : () -> i32
+ %0 = llvm.call @tail_call_target() : () -> i32
+ llvm.return %0 : i32
+}
+
+// CHECK-LABEL: @test_musttail
+llvm.func @test_musttail() -> i32 {
+ // CHECK-NEXT: llvm.call musttail @tail_call_target() : () -> i32
+ %0 = llvm.call musttail @tail_call_target() : () -> i32
+ llvm.return %0 : i32
+}
+
+// CHECK-LABEL: @test_tail
+llvm.func @test_tail() -> i32 {
+ // CHECK-NEXT: llvm.call tail @tail_call_target() : () -> i32
+ %0 = llvm.call tail @tail_call_target() : () -> i32
+ llvm.return %0 : i32
+}
+
+// CHECK-LABEL: @test_notail
+llvm.func @test_notail() -> i32 {
+ // CHECK-NEXT: llvm.call notail @tail_call_target() : () -> i32
+ %0 = llvm.call notail @tail_call_target() : () -> i32
+ llvm.return %0 : i32
+}
diff --git a/mlir/test/Dialect/LLVMIR/tail-call-kinds.mlir b/mlir/test/Dialect/LLVMIR/tail-call-kinds.mlir
new file mode 100644
index 00000000000000..73a6aa2f91cbaa
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/tail-call-kinds.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
+
+// CHECK: declare i32 @foo()
+llvm.func @foo() -> i32
+
+// CHECK-LABEL: @test_none
+llvm.func @test_none() -> i32 {
+ // CHECK-NEXT: call i32 @foo()
+ %0 = llvm.call none @foo() : () -> i32
+ llvm.return %0 : i32
+}
+
+// CHECK-LABEL: @test_default
+llvm.func @test_default() -> i32 {
+ // CHECK-NEXT: call i32 @foo()
+ %0 = llvm.call @foo() : () -> i32
+ llvm.return %0 : i32
+}
+
+// CHECK-LABEL: @test_musttail
+llvm.func @test_musttail() -> i32 {
+ // CHECK-NEXT: musttail call i32 @foo()
+ %0 = llvm.call musttail @foo() : () -> i32
+ llvm.return %0 : i32
+}
+
+// CHECK-LABEL: @test_tail
+llvm.func @test_tail() -> i32 {
+ // CHECK-NEXT: tail call i32 @foo()
+ %0 = llvm.call tail @foo() : () -> i32
+ llvm.return %0 : i32
+}
+
+// CHECK-LABEL: @test_notail
+llvm.func @test_notail() -> i32 {
+ // CHECK-NEXT: notail call i32 @foo()
+ %0 = llvm.call notail @foo() : () -> i32
+ llvm.return %0 : i32
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/98080
More information about the Mlir-commits mailing list