[Mlir-commits] [mlir] [MLIR][LLVM] Tail call support for inline asm op (PR #140826)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue May 20 17:48:32 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-gpu

Author: Bruno Cardoso Lopes (bcardosolopes)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/140826.diff


11 Files Affected:

- (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td (+11-3) 
- (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+2-1) 
- (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+1) 
- (modified) mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp (+1) 
- (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+15) 
- (modified) mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp (+2-1) 
- (modified) mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp (+4) 
- (modified) mlir/lib/Target/LLVMIR/ModuleImport.cpp (+4) 
- (modified) mlir/test/Dialect/LLVMIR/invalid.mlir (+8) 
- (modified) mlir/test/Target/LLVMIR/Import/instructions.ll (+2-2) 
- (modified) mlir/test/Target/LLVMIR/llvmir.mlir (+8-2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 61ba8f7b991c8..ba13806a38a4c 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -758,9 +758,9 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
     the LLVM function type that uses an explicit void type to model functions
     that do not return a value.
 
-    If this operatin has the `no_inline` attribute, then this specific function call 
-    will never be inlined. The opposite behavior will occur if the call has `always_inline` 
-    attribute. The `inline_hint` attribute indicates that it is desirable to inline 
+    If this operation has the `no_inline` attribute, then this specific function call
+    will never be inlined. The opposite behavior will occur if the call has `always_inline`
+    attribute. The `inline_hint` attribute indicates that it is desirable to inline
     this function call.
 
     Examples:
@@ -2298,6 +2298,9 @@ def LLVM_InlineAsmOp : LLVM_Op<"inline_asm", [DeclareOpInterfaceMethods<MemoryEf
     written, or referenced.
     Attempting to define or reference any symbol or any global behavior is
     considered undefined behavior at this time.
+    If `tail_call_kind` is used, the operation behaves like the specified
+    tail call kind. The `musttail` kind is not available for this operation,
+    since it isn't supported by LLVM's inline asm.
   }];
   let arguments = (
     ins Variadic<LLVM_Type>:$operands,
@@ -2305,6 +2308,8 @@ def LLVM_InlineAsmOp : LLVM_Op<"inline_asm", [DeclareOpInterfaceMethods<MemoryEf
         StrAttr:$constraints,
         UnitAttr:$has_side_effects,
         UnitAttr:$is_align_stack,
+        OptionalAttr<
+        DefaultValuedAttr<TailCallKind, "TailCallKind::None">>:$tail_call_kind,
         OptionalAttr<
           DefaultValuedAttr<AsmATTOrIntel, "AsmDialect::AD_ATT">>:$asm_dialect,
         OptionalAttr<ArrayAttr>:$operand_attrs);
@@ -2314,6 +2319,7 @@ def LLVM_InlineAsmOp : LLVM_Op<"inline_asm", [DeclareOpInterfaceMethods<MemoryEf
   let assemblyFormat = [{
     (`has_side_effects` $has_side_effects^)?
     (`is_align_stack` $is_align_stack^)?
+    (`tail_call_kind` `=` $tail_call_kind^)?
     (`asm_dialect` `=` $asm_dialect^)?
     (`operand_attrs` `=` $operand_attrs^)?
     attr-dict
@@ -2326,6 +2332,8 @@ def LLVM_InlineAsmOp : LLVM_Op<"inline_asm", [DeclareOpInterfaceMethods<MemoryEf
       return "elementtype";
     }
   }];
+
+  let hasVerifier = 1;
 }
 
 //===--------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 0694cf27faff4..ae07666574260 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -439,7 +439,8 @@ struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
           op,
           /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
           /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true,
-          /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
+          /*is_align_stack=*/false, /*tail_call_kind=*/nullptr,
+          /*asm_dialect=*/asmDialectAttr,
           /*operand_attrs=*/ArrayAttr());
       return success();
     }
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index eb3558d2460e4..c5e5094a35010 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -572,6 +572,7 @@ static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
       /*constraints=*/constraintStr,
       /*has_side_effects=*/true,
       /*is_align_stack=*/false,
+      /*tail_call_kind=*/nullptr,
       /*asm_dialect=*/asmDialectAttr,
       /*operand_attrs=*/ArrayAttr());
 }
diff --git a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
index f2e71e7795c3e..9b83fd190b9eb 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
@@ -140,6 +140,7 @@ LLVM::InlineAsmOp PtxBuilder::build() {
       /*constraints=*/registerConstraints.data(),
       /*has_side_effects=*/interfaceOp.hasSideEffect(),
       /*is_align_stack=*/false,
+      /*tail_call_kind=*/nullptr,
       /*asm_dialect=*/asmDialectAttr,
       /*operand_attrs=*/ArrayAttr());
 }
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index d8abf6fd41301..c7528c970a4ba 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -4042,6 +4042,21 @@ LogicalResult LLVM::masked_scatter::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// InlineAsmOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult InlineAsmOp::verify() {
+  if (!getTailCallKindAttr())
+    return success();
+
+  if (getTailCallKindAttr().getTailCallKind() == TailCallKind::MustTail)
+    return emitOpError(
+        "tail call kind 'musttail' is not supported by this operation");
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // LLVMDialect initialization, type parsing, and registration.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp b/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
index 3fc05c8cb8707..5bce7c7625e8d 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
@@ -41,7 +41,8 @@ Value mlir::x86vector::avx2::inline_asm::mm256BlendPsAsm(
   auto asmOp = b.create<LLVM::InlineAsmOp>(
       v1.getType(), /*operands=*/asmVals, /*asm_string=*/asmStr,
       /*constraints=*/asmCstr, /*has_side_effects=*/false,
-      /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
+      /*is_align_stack=*/false, /*tail_call_kind=*/nullptr,
+      /*asm_dialect=*/asmDialectAttr,
       /*operand_attrs=*/ArrayAttr());
   return asmOp.getResult(0);
 }
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index c954dffb376bb..37387130ee012 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -19,6 +19,7 @@
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/InlineAsm.h"
+#include "llvm/IR/Instructions.h"
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/MatrixBuilder.h"
 #include "llvm/IR/Operator.h"
@@ -507,6 +508,9 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
     llvm::CallInst *inst = builder.CreateCall(
         inlineAsmInst,
         moduleTranslation.lookupValues(inlineAsmOp.getOperands()));
+    if (inlineAsmOp.getTailCallKindAttr())
+      inst->setTailCallKind(convertTailCallKindToLLVM(
+          inlineAsmOp.getTailCallKindAttr().getTailCallKind()));
     if (auto maybeOperandAttrs = inlineAsmOp.getOperandAttrs()) {
       llvm::AttributeList attrList;
       for (const auto &it : llvm::enumerate(*maybeOperandAttrs)) {
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 8945ae933dd65..8041480ac687f 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -2201,6 +2201,10 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
                 builder.getStringAttr(asmI->getAsmString()),
                 builder.getStringAttr(asmI->getConstraintString()),
                 asmI->hasSideEffects(), asmI->isAlignStack(),
+                callInst->isTailCall()
+                    ? TailCallKindAttr::get(mlirModule.getContext(),
+                                            TailCallKind::Tail)
+                    : nullptr,
                 AsmDialectAttr::get(
                     mlirModule.getContext(),
                     convertAsmDialectFromLLVM(asmI->getDialect())),
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index f5adf4b3bf33d..251ca716c7a7a 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1882,3 +1882,11 @@ llvm.mlir.global internal constant @bad_array_attr_simple_type() : !llvm.array<2
   %0 = llvm.mlir.constant([2.5, 7.4]) : !llvm.array<2 x f64>
   llvm.return %0 : !llvm.array<2 x f64>
 }
+
+// -----
+
+llvm.func @inlineAsmMustTail(%arg0: i32, %arg1 : !llvm.ptr) {
+  // expected-error @+1 {{op tail call kind 'musttail' is not supported}}
+  %8 = llvm.inline_asm tail_call_kind = <musttail> "foo", "=r,=r,r" %arg0 : (i32) -> !llvm.struct<(i8, i8)>
+  llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/Import/instructions.ll b/mlir/test/Target/LLVMIR/Import/instructions.ll
index 68ef47c3f42f1..d92abd4408249 100644
--- a/mlir/test/Target/LLVMIR/Import/instructions.ll
+++ b/mlir/test/Target/LLVMIR/Import/instructions.ll
@@ -554,8 +554,8 @@ define i32 @inlineasm(i32 %arg1) {
 define void @inlineasm2() {
   %p = alloca ptr, align 8
   ; CHECK: {{.*}} = llvm.alloca %0 x !llvm.ptr {alignment = 8 : i64} : (i32) -> !llvm.ptr
-  ; CHECK-NEXT: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [{elementtype = !llvm.ptr}] "", "*m,~{memory}" {{.*}} : (!llvm.ptr) -> !llvm.void
-  call void asm sideeffect "", "*m,~{memory}"(ptr elementtype(ptr) %p)
+  ; CHECK-NEXT: llvm.inline_asm has_side_effects tail_call_kind = <tail> asm_dialect = att operand_attrs = [{elementtype = !llvm.ptr}] "", "*m,~{memory}" {{.*}} : (!llvm.ptr) -> !llvm.void
+  tail call void asm sideeffect "", "*m,~{memory}"(ptr elementtype(ptr) %p)
   ret void
 }
 
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 237612244d8de..7742259e7a478 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -2081,8 +2081,14 @@ llvm.func @useInlineAsm(%arg0: i32, %arg1 : !llvm.ptr) {
   // CHECK-NEXT:  call { i8, i8 } asm "foo", "=r,=r,r"(i32 {{.*}})
   %5 = llvm.inline_asm "foo", "=r,=r,r" %arg0 : (i32) -> !llvm.struct<(i8, i8)>
 
-  // CHECK-NEXT:  call void asm sideeffect "", "*m,~{memory}"(ptr elementtype(ptr) %1)
-  %6 = llvm.inline_asm has_side_effects operand_attrs = [{elementtype = !llvm.ptr}] "", "*m,~{memory}" %arg1 : (!llvm.ptr) -> !llvm.void
+  // CHECK-NEXT:  tail call void asm sideeffect "", "*m,~{memory}"(ptr elementtype(ptr) %1)
+  %6 = llvm.inline_asm has_side_effects tail_call_kind = <tail> operand_attrs = [{elementtype = !llvm.ptr}] "", "*m,~{memory}" %arg1 : (!llvm.ptr) -> !llvm.void
+
+  // CHECK-NEXT:  = call { i8, i8 } asm "foo", "=r,=r,r"(i32 {{.*}})
+  %7 = llvm.inline_asm tail_call_kind = <none> "foo", "=r,=r,r" %arg0 : (i32) -> !llvm.struct<(i8, i8)>
+
+  // CHECK-NEXT:  notail call { i8, i8 } asm "foo", "=r,=r,r"(i32 {{.*}})
+  %8 = llvm.inline_asm tail_call_kind = <notail> "foo", "=r,=r,r" %arg0 : (i32) -> !llvm.struct<(i8, i8)>
 
   llvm.return
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/140826


More information about the Mlir-commits mailing list