[Mlir-commits] [mlir] Expose Tail Kind Call to MLIR (PR #98080)

Steffi Stumpos llvmlistbot at llvm.org
Tue Jul 9 08:30:14 PDT 2024


https://github.com/stumpOS updated https://github.com/llvm/llvm-project/pull/98080

>From e3738c0eea822b1b4b06f59b9858105d31d257de Mon Sep 17 00:00:00 2001
From: Steffi Stumpos <stumposs12 at gmail.com>
Date: Tue, 2 Jul 2024 17:56:42 -0600
Subject: [PATCH] expose tail kind call to mlir

---
 .../mlir/Dialect/LLVMIR/LLVMAttrDefs.td       |  9 +++++
 mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h  |  1 +
 mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td | 29 ++++++++++++++
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td   |  3 +-
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp    | 24 +++++++++---
 .../LLVMIR/LLVMToLLVMIRTranslation.cpp        |  1 +
 mlir/test/Dialect/LLVMIR/roundtrip.mlir       | 38 ++++++++++++++++++
 mlir/test/Dialect/LLVMIR/tail-call-kinds.mlir | 39 +++++++++++++++++++
 8 files changed, 137 insertions(+), 7 deletions(-)
 create mode 100644 mlir/test/Dialect/LLVMIR/tail-call-kinds.mlir

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index b05366d2a635d..25a6ee27b01db 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -1077,4 +1077,13 @@ def LLVM_PoisonAttr : LLVM_Attr<"Poison", "poison">;
 /// Folded into from LLVM::ZeroOp.
 def LLVM_ZeroAttr : LLVM_Attr<"Zero", "zero">;
 
+//===----------------------------------------------------------------------===//
+// TailCallKindAttr
+//===----------------------------------------------------------------------===//
+
+def TailCallKindAttr : LLVM_Attr<"TailCallKind", "tailcallkind"> {
+  let parameters = (ins "TailCallKind":$TailCallKind);
+  let assemblyFormat = "`<` $TailCallKind `>`";
+}
+
 #endif // LLVMIR_ATTRDEFS
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
index 3a93be21da375..3ede857733242 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
@@ -89,6 +89,7 @@ class TBAANodeAttr : public Attribute {
 // TODO: this shouldn't be needed after we unify the attribute generation, i.e.
 // --gen-attr-* and --gen-attrdef-*.
 using cconv::CConv;
+using tailcallkind::TailCallKind;
 using linkage::Linkage;
 } // namespace LLVM
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
index f8e85004d5f93..f41a97f9ecc81 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
@@ -279,6 +279,35 @@ def CConv : DialectAttr<
           "::mlir::LLVM::CConvAttr::get($_builder.getContext(), $0)";
 }
 
+//===----------------------------------------------------------------------===//
+// TailCallKind
+//===----------------------------------------------------------------------===//
+
+def TailCallKindNone : LLVM_EnumAttrCase<"None", "none", "TCK_None", 0>;
+def TailCallKindTail : LLVM_EnumAttrCase<"Tail", "tail", "TCK_Tail", 1>;
+def TailCallKindMustTail : LLVM_EnumAttrCase<"MustTail", "musttail", "TCK_MustTail", 2>;
+def TailCallKindNoTailCall : LLVM_EnumAttrCase<"NoTail", "notail", "TCK_NoTail", 3>;
+
+def TailCallKindEnum : LLVM_EnumAttr<
+    "TailCallKind",
+    "::llvm::CallInst::TailCallKind",
+    "Tail Call Kind",
+    [TailCallKindNone, TailCallKindNoTailCall,
+    TailCallKindMustTail, TailCallKindTail]> {
+  let cppNamespace = "::mlir::LLVM::tailcallkind";
+}
+
+def TailCallKind : DialectAttr<
+    LLVM_Dialect,
+    CPred<"::llvm::isa<::mlir::LLVM::TailCallKindAttr>($_self)">,
+    "LLVM Tail Call Kind specification"> {
+  let storageType = "::mlir::LLVM::TailCallKindAttr";
+  let returnType = "::mlir::LLVM::tailcallkind::TailCallKind";
+  let convertFromStorage = "$_self.getTailCallKind()";
+  let constBuilderCall =
+          "::mlir::LLVM::TailCallKindAttr::get($_builder.getContext(), $0)";
+}
+
 //===----------------------------------------------------------------------===//
 // DIEmissionKind
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 3774bda05eb2b..807dc41e11d28 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -650,7 +650,8 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
                   DefaultValuedAttr<LLVM_FastmathFlagsAttr,
                                    "{}">:$fastmathFlags,
                   OptionalAttr<DenseI32ArrayAttr>:$branch_weights,
-                  DefaultValuedAttr<CConv, "CConv::C">:$CConv);
+                  DefaultValuedAttr<CConv, "CConv::C">:$CConv,
+                  DefaultValuedAttr<TailCallKind, "TailCallKind::None">:$TailCallKind);
   // Append the aliasing related attributes defined in LLVM_MemAccessOpBase.
   let arguments = !con(args, aliasAttrs);
   let results = (outs Optional<LLVM_Type>:$result);
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index d1280aceeb7b6..fdeb55ef0c17b 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -44,6 +44,7 @@ using namespace mlir;
 using namespace mlir::LLVM;
 using mlir::LLVM::cconv::getMaxEnumValForCConv;
 using mlir::LLVM::linkage::getMaxEnumValForLinkage;
+using mlir::LLVM::tailcallkind::getMaxEnumValForTailCallKind;
 
 #include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc"
 
@@ -197,6 +198,7 @@ struct EnumTraits {};
 REGISTER_ENUM_TYPE(Linkage);
 REGISTER_ENUM_TYPE(UnnamedAddr);
 REGISTER_ENUM_TYPE(CConv);
+REGISTER_ENUM_TYPE(TailCallKind);
 REGISTER_ENUM_TYPE(Visibility);
 } // namespace
 
@@ -974,7 +976,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
   build(builder, state, results,
         TypeAttr::get(getLLVMFuncType(builder.getContext(), results, args)),
         callee, args, /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
-        /*CConv=*/nullptr,
+        /*CConv=*/nullptr, /*TailCallKind=*/nullptr,
         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
         /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
 }
@@ -997,7 +999,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
   build(builder, state, getCallOpResultTypes(calleeType),
         TypeAttr::get(calleeType), callee, args, /*fastmathFlags=*/nullptr,
         /*branch_weights=*/nullptr, /*CConv=*/nullptr,
-        /*access_groups=*/nullptr,
+        /*TailCallKind=*/nullptr, /*access_groups=*/nullptr,
         /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
 }
 
@@ -1006,7 +1008,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
   build(builder, state, getCallOpResultTypes(calleeType),
         TypeAttr::get(calleeType), /*callee=*/nullptr, args,
         /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
-        /*CConv=*/nullptr,
+        /*CConv=*/nullptr, /*TailCallKind=*/nullptr,
         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
         /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
 }
@@ -1017,7 +1019,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
   build(builder, state, getCallOpResultTypes(calleeType),
         TypeAttr::get(calleeType), SymbolRefAttr::get(func), args,
         /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
-        /*CConv=*/nullptr,
+        /*CConv=*/nullptr, /*TailCallKind=*/nullptr,
         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
         /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
 }
@@ -1180,6 +1182,9 @@ void CallOp::print(OpAsmPrinter &p) {
   if (getCConv() != LLVM::CConv::C)
     p << stringifyCConv(getCConv()) << ' ';
 
+  if(getTailCallKind() != LLVM::TailCallKind::None)
+    p << tailcallkind::stringifyTailCallKind(getTailCallKind()) << ' ';
+
   // Print the direct callee if present as a function attribute, or an indirect
   // callee (first operand) otherwise.
   if (isDirect)
@@ -1194,7 +1199,8 @@ void CallOp::print(OpAsmPrinter &p) {
     p << " vararg(" << calleeType << ")";
 
   p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
-                          {getCConvAttrName(), "callee", "callee_type"});
+                          {getCConvAttrName(), "callee", "callee_type",
+                           getTailCallKindAttrName()});
 
   p << " : ";
   if (!isDirect)
@@ -1262,7 +1268,7 @@ static ParseResult parseOptionalCallFuncPtr(
   return success();
 }
 
-// <operation> ::= `llvm.call` (cconv)? (function-id | ssa-use)
+// <operation> ::= `llvm.call` (cconv)? (tailcallkind)? (function-id | ssa-use)
 //                             `(` ssa-use-list `)`
 //                             ( `vararg(` var-arg-func-type `)` )?
 //                             attribute-dict? `:` (type `,`)? function-type
@@ -1277,6 +1283,12 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
       CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>(
                                               parser, result, LLVM::CConv::C)));
 
+  result.addAttribute(
+      getTailCallKindAttrName(result.name),
+      TailCallKindAttr::get(parser.getContext(),
+                            parseOptionalLLVMKeyword<TailCallKind>(
+                                parser, result, LLVM::TailCallKind::None)));
+
   // Parse a function pointer for indirect calls.
   if (parseOptionalCallFuncPtr(parser, operands))
     return failure();
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index f144c7158d679..3d6dd1247b413 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -218,6 +218,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
                                 operandsRef.drop_front());
     }
     call->setCallingConv(convertCConvToLLVM(callOp.getCConv()));
+    call->setTailCallKind(convertTailCallKindToLLVM(callOp.getTailCallKind()));
     moduleTranslation.setAccessGroupsMetadata(callOp, call);
     moduleTranslation.setAliasScopeMetadata(callOp, call);
     moduleTranslation.setTBAAMetadata(callOp, call);
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index 2386dde19301e..ca9748a2b8b7b 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -673,3 +673,41 @@ llvm.func @experimental_constrained_fptrunc(%in: f64) {
   %4 = llvm.intr.experimental.constrained.fptrunc %in tonearestaway ignore : f64 to f32
   llvm.return
 }
+
+// CHECK: llvm.func @tail_call_target() -> i32
+llvm.func @tail_call_target() -> i32
+
+// CHECK-LABEL: @test_none
+llvm.func @test_none() -> i32 {
+  // CHECK-NEXT: llvm.call @tail_call_target() : () -> i32
+  %0 = llvm.call none @tail_call_target() : () -> i32
+  llvm.return %0 : i32
+}
+
+// CHECK-LABEL: @test_default
+llvm.func @test_default() -> i32 {
+  // CHECK-NEXT: llvm.call @tail_call_target() : () -> i32
+  %0 = llvm.call @tail_call_target() : () -> i32
+  llvm.return %0 : i32
+}
+
+// CHECK-LABEL: @test_musttail
+llvm.func @test_musttail() -> i32 {
+  // CHECK-NEXT: llvm.call musttail @tail_call_target() : () -> i32
+  %0 = llvm.call musttail @tail_call_target() : () -> i32
+  llvm.return %0 : i32
+}
+
+// CHECK-LABEL: @test_tail
+llvm.func @test_tail() -> i32 {
+  // CHECK-NEXT: llvm.call tail @tail_call_target() : () -> i32
+  %0 = llvm.call tail @tail_call_target() : () -> i32
+  llvm.return %0 : i32
+}
+
+// CHECK-LABEL: @test_notail
+llvm.func @test_notail() -> i32 {
+  // CHECK-NEXT: llvm.call notail @tail_call_target() : () -> i32
+  %0 = llvm.call notail @tail_call_target() : () -> i32
+  llvm.return %0 : i32
+}
diff --git a/mlir/test/Dialect/LLVMIR/tail-call-kinds.mlir b/mlir/test/Dialect/LLVMIR/tail-call-kinds.mlir
new file mode 100644
index 0000000000000..73a6aa2f91cba
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/tail-call-kinds.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
+
+// CHECK: declare i32 @foo()
+llvm.func @foo() -> i32
+
+// CHECK-LABEL: @test_none
+llvm.func @test_none() -> i32 {
+  // CHECK-NEXT: call i32 @foo()
+  %0 = llvm.call none @foo() : () -> i32
+  llvm.return %0 : i32
+}
+
+// CHECK-LABEL: @test_default
+llvm.func @test_default() -> i32 {
+  // CHECK-NEXT: call i32 @foo()
+  %0 = llvm.call @foo() : () -> i32
+  llvm.return %0 : i32
+}
+
+// CHECK-LABEL: @test_musttail
+llvm.func @test_musttail() -> i32 {
+  // CHECK-NEXT: musttail call i32 @foo()
+  %0 = llvm.call musttail @foo() : () -> i32
+  llvm.return %0 : i32
+}
+
+// CHECK-LABEL: @test_tail
+llvm.func @test_tail() -> i32 {
+  // CHECK-NEXT: tail call i32 @foo()
+  %0 = llvm.call tail @foo() : () -> i32
+  llvm.return %0 : i32
+}
+
+// CHECK-LABEL: @test_notail
+llvm.func @test_notail() -> i32 {
+  // CHECK-NEXT: notail call i32 @foo()
+  %0 = llvm.call notail @foo() : () -> i32
+  llvm.return %0 : i32
+}



More information about the Mlir-commits mailing list