[Mlir-commits] [mlir] Expose Tail Kind Call to MLIR (PR #98080)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jul 8 14:28:16 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-llvm

Author: Steffi Stumpos (stumpOS)

<details>
<summary>Changes</summary>

I would like to mark a call op in the LLVM dialect as `musttail`. The calling convention attribute only exposes Tail, not Musttail. I noticed that LLVM's `CallInst` has an additional field that specifies the tail call kind. I bubbled this up to the LLVM dialect by adding another attribute that maps to `llvm::CallInst::TailCallKind`.

---
Full diff: https://github.com/llvm/llvm-project/pull/98080.diff


9 Files Affected:

- (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td (+10) 
- (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h (+1) 
- (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td (+29) 
- (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td (+2-1) 
- (modified) mlir/include/mlir/IR/DialectInterface.h (+1) 
- (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+18-6) 
- (modified) mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp (+1) 
- (modified) mlir/test/Dialect/LLVMIR/roundtrip.mlir (+38) 
- (added) mlir/test/Dialect/LLVMIR/tail-call-kinds.mlir (+39) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index b05366d2a635df..db43fb023b107c 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -1077,4 +1077,14 @@ def LLVM_PoisonAttr : LLVM_Attr<"Poison", "poison">;
 /// Folded into from LLVM::ZeroOp.
 def LLVM_ZeroAttr : LLVM_Attr<"Zero", "zero">;
 
+
+//===----------------------------------------------------------------------===//
+// TailCallKindAttr
+//===----------------------------------------------------------------------===//
+
+def TailCallKindAttr : LLVM_Attr<"TailCallKind", "tailcallkind"> {
+  let parameters = (ins "TailCallKind":$TailCallKind);
+  let assemblyFormat = "`<` $TailCallKind `>`";
+}
+
 #endif // LLVMIR_ATTRDEFS
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
index 3a93be21da3756..3ede8577332422 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
@@ -89,6 +89,7 @@ class TBAANodeAttr : public Attribute {
 // TODO: this shouldn't be needed after we unify the attribute generation, i.e.
 // --gen-attr-* and --gen-attrdef-*.
 using cconv::CConv;
+using tailcallkind::TailCallKind;
 using linkage::Linkage;
 } // namespace LLVM
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
index f8e85004d5f93c..f41a97f9ecc818 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
@@ -279,6 +279,35 @@ def CConv : DialectAttr<
           "::mlir::LLVM::CConvAttr::get($_builder.getContext(), $0)";
 }
 
+//===----------------------------------------------------------------------===//
+// TailCallKind
+//===----------------------------------------------------------------------===//
+
+def TailCallKindNone : LLVM_EnumAttrCase<"None", "none", "TCK_None", 0>;
+def TailCallKindTail : LLVM_EnumAttrCase<"Tail", "tail", "TCK_Tail", 1>;
+def TailCallKindMustTail : LLVM_EnumAttrCase<"MustTail", "musttail", "TCK_MustTail", 2>;
+def TailCallKindNoTailCall : LLVM_EnumAttrCase<"NoTail", "notail", "TCK_NoTail", 3>;
+
+def TailCallKindEnum : LLVM_EnumAttr<
+    "TailCallKind",
+    "::llvm::CallInst::TailCallKind",
+    "Tail Call Kind",
+    [TailCallKindNone, TailCallKindNoTailCall,
+    TailCallKindMustTail, TailCallKindTail]> {
+  let cppNamespace = "::mlir::LLVM::tailcallkind";
+}
+
+def TailCallKind : DialectAttr<
+    LLVM_Dialect,
+    CPred<"::llvm::isa<::mlir::LLVM::TailCallKindAttr>($_self)">,
+    "LLVM Calling Convention specification"> {
+  let storageType = "::mlir::LLVM::TailCallKindAttr";
+  let returnType = "::mlir::LLVM::tailcallkind::TailCallKind";
+  let convertFromStorage = "$_self.getTailCallKind()";
+  let constBuilderCall =
+          "::mlir::LLVM::TailCallKindAttr::get($_builder.getContext(), $0)";
+}
+
 //===----------------------------------------------------------------------===//
 // DIEmissionKind
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 3774bda05eb2be..807dc41e11d28d 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -650,7 +650,8 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
                   DefaultValuedAttr<LLVM_FastmathFlagsAttr,
                                    "{}">:$fastmathFlags,
                   OptionalAttr<DenseI32ArrayAttr>:$branch_weights,
-                  DefaultValuedAttr<CConv, "CConv::C">:$CConv);
+                  DefaultValuedAttr<CConv, "CConv::C">:$CConv,
+                  DefaultValuedAttr<TailCallKind, "TailCallKind::None">:$TailCallKind);
   // Append the aliasing related attributes defined in LLVM_MemAccessOpBase.
   let arguments = !con(args, aliasAttrs);
   let results = (outs Optional<LLVM_Type>:$result);
diff --git a/mlir/include/mlir/IR/DialectInterface.h b/mlir/include/mlir/IR/DialectInterface.h
index 3a7ad87b161eea..36502cc304b695 100644
--- a/mlir/include/mlir/IR/DialectInterface.h
+++ b/mlir/include/mlir/IR/DialectInterface.h
@@ -12,6 +12,7 @@
 #include "mlir/Support/TypeID.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
+#include <vector>
 
 namespace mlir {
 class Dialect;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index d1280aceeb7b6c..fdeb55ef0c17bb 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -44,6 +44,7 @@ using namespace mlir;
 using namespace mlir::LLVM;
 using mlir::LLVM::cconv::getMaxEnumValForCConv;
 using mlir::LLVM::linkage::getMaxEnumValForLinkage;
+using mlir::LLVM::tailcallkind::getMaxEnumValForTailCallKind;
 
 #include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc"
 
@@ -197,6 +198,7 @@ struct EnumTraits {};
 REGISTER_ENUM_TYPE(Linkage);
 REGISTER_ENUM_TYPE(UnnamedAddr);
 REGISTER_ENUM_TYPE(CConv);
+REGISTER_ENUM_TYPE(TailCallKind);
 REGISTER_ENUM_TYPE(Visibility);
 } // namespace
 
@@ -974,7 +976,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
   build(builder, state, results,
         TypeAttr::get(getLLVMFuncType(builder.getContext(), results, args)),
         callee, args, /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
-        /*CConv=*/nullptr,
+        /*CConv=*/nullptr, /*TailCallKind=*/nullptr,
         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
         /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
 }
@@ -997,7 +999,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
   build(builder, state, getCallOpResultTypes(calleeType),
         TypeAttr::get(calleeType), callee, args, /*fastmathFlags=*/nullptr,
         /*branch_weights=*/nullptr, /*CConv=*/nullptr,
-        /*access_groups=*/nullptr,
+        /*TailCallKind=*/nullptr, /*access_groups=*/nullptr,
         /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
 }
 
@@ -1006,7 +1008,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
   build(builder, state, getCallOpResultTypes(calleeType),
         TypeAttr::get(calleeType), /*callee=*/nullptr, args,
         /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
-        /*CConv=*/nullptr,
+        /*CConv=*/nullptr, /*TailCallKind=*/nullptr,
         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
         /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
 }
@@ -1017,7 +1019,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
   build(builder, state, getCallOpResultTypes(calleeType),
         TypeAttr::get(calleeType), SymbolRefAttr::get(func), args,
         /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
-        /*CConv=*/nullptr,
+        /*CConv=*/nullptr, /*TailCallKind=*/nullptr,
         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
         /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
 }
@@ -1180,6 +1182,9 @@ void CallOp::print(OpAsmPrinter &p) {
   if (getCConv() != LLVM::CConv::C)
     p << stringifyCConv(getCConv()) << ' ';
 
+  if(getTailCallKind() != LLVM::TailCallKind::None)
+    p << tailcallkind::stringifyTailCallKind(getTailCallKind()) << ' ';
+
   // Print the direct callee if present as a function attribute, or an indirect
   // callee (first operand) otherwise.
   if (isDirect)
@@ -1194,7 +1199,8 @@ void CallOp::print(OpAsmPrinter &p) {
     p << " vararg(" << calleeType << ")";
 
   p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
-                          {getCConvAttrName(), "callee", "callee_type"});
+                          {getCConvAttrName(), "callee", "callee_type",
+                           getTailCallKindAttrName()});
 
   p << " : ";
   if (!isDirect)
@@ -1262,7 +1268,7 @@ static ParseResult parseOptionalCallFuncPtr(
   return success();
 }
 
-// <operation> ::= `llvm.call` (cconv)? (function-id | ssa-use)
+// <operation> ::= `llvm.call` (cconv)? (tailcallkind)? (function-id | ssa-use)
 //                             `(` ssa-use-list `)`
 //                             ( `vararg(` var-arg-func-type `)` )?
 //                             attribute-dict? `:` (type `,`)? function-type
@@ -1277,6 +1283,12 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
       CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>(
                                               parser, result, LLVM::CConv::C)));
 
+  result.addAttribute(
+      getTailCallKindAttrName(result.name),
+      TailCallKindAttr::get(parser.getContext(),
+                            parseOptionalLLVMKeyword<TailCallKind>(
+                                parser, result, LLVM::TailCallKind::None)));
+
   // Parse a function pointer for indirect calls.
   if (parseOptionalCallFuncPtr(parser, operands))
     return failure();
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index f144c7158d6796..3d6dd1247b4136 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -218,6 +218,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
                                 operandsRef.drop_front());
     }
     call->setCallingConv(convertCConvToLLVM(callOp.getCConv()));
+    call->setTailCallKind(convertTailCallKindToLLVM(callOp.getTailCallKind()));
     moduleTranslation.setAccessGroupsMetadata(callOp, call);
     moduleTranslation.setAliasScopeMetadata(callOp, call);
     moduleTranslation.setTBAAMetadata(callOp, call);
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index 2386dde19301e3..ca9748a2b8b7bc 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -673,3 +673,41 @@ llvm.func @experimental_constrained_fptrunc(%in: f64) {
   %4 = llvm.intr.experimental.constrained.fptrunc %in tonearestaway ignore : f64 to f32
   llvm.return
 }
+
+// CHECK: llvm.func @tail_call_target() -> i32
+llvm.func @tail_call_target() -> i32
+
+// CHECK-LABEL: @test_none
+llvm.func @test_none() -> i32 {
+  // CHECK-NEXT: llvm.call @tail_call_target() : () -> i32
+  %0 = llvm.call none @tail_call_target() : () -> i32
+  llvm.return %0 : i32
+}
+
+// CHECK-LABEL: @test_default
+llvm.func @test_default() -> i32 {
+  // CHECK-NEXT: llvm.call @tail_call_target() : () -> i32
+  %0 = llvm.call @tail_call_target() : () -> i32
+  llvm.return %0 : i32
+}
+
+// CHECK-LABEL: @test_musttail
+llvm.func @test_musttail() -> i32 {
+  // CHECK-NEXT: llvm.call musttail @tail_call_target() : () -> i32
+  %0 = llvm.call musttail @tail_call_target() : () -> i32
+  llvm.return %0 : i32
+}
+
+// CHECK-LABEL: @test_tail
+llvm.func @test_tail() -> i32 {
+  // CHECK-NEXT: llvm.call tail @tail_call_target() : () -> i32
+  %0 = llvm.call tail @tail_call_target() : () -> i32
+  llvm.return %0 : i32
+}
+
+// CHECK-LABEL: @test_notail
+llvm.func @test_notail() -> i32 {
+  // CHECK-NEXT: llvm.call notail @tail_call_target() : () -> i32
+  %0 = llvm.call notail @tail_call_target() : () -> i32
+  llvm.return %0 : i32
+}
diff --git a/mlir/test/Dialect/LLVMIR/tail-call-kinds.mlir b/mlir/test/Dialect/LLVMIR/tail-call-kinds.mlir
new file mode 100644
index 00000000000000..73a6aa2f91cbaa
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/tail-call-kinds.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
+
+// CHECK: declare i32 @foo()
+llvm.func @foo() -> i32
+
+// CHECK-LABEL: @test_none
+llvm.func @test_none() -> i32 {
+  // CHECK-NEXT: call i32 @foo()
+  %0 = llvm.call none @foo() : () -> i32
+  llvm.return %0 : i32
+}
+
+// CHECK-LABEL: @test_default
+llvm.func @test_default() -> i32 {
+  // CHECK-NEXT: call i32 @foo()
+  %0 = llvm.call @foo() : () -> i32
+  llvm.return %0 : i32
+}
+
+// CHECK-LABEL: @test_musttail
+llvm.func @test_musttail() -> i32 {
+  // CHECK-NEXT: musttail call i32 @foo()
+  %0 = llvm.call musttail @foo() : () -> i32
+  llvm.return %0 : i32
+}
+
+// CHECK-LABEL: @test_tail
+llvm.func @test_tail() -> i32 {
+  // CHECK-NEXT: tail call i32 @foo()
+  %0 = llvm.call tail @foo() : () -> i32
+  llvm.return %0 : i32
+}
+
+// CHECK-LABEL: @test_notail
+llvm.func @test_notail() -> i32 {
+  // CHECK-NEXT: notail call i32 @foo()
+  %0 = llvm.call notail @foo() : () -> i32
+  llvm.return %0 : i32
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/98080


More information about the Mlir-commits mailing list