[Mlir-commits] [llvm] [mlir] [mlir][LLVM] add argument and result attributes to llvm.call (PR #123177)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 6 02:30:41 PST 2025


https://github.com/jeanPerier updated https://github.com/llvm/llvm-project/pull/123177

>From f062d7811e81a4be97811d6addd53187cf11e96d Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Mon, 3 Feb 2025 08:31:01 -0800
Subject: [PATCH 1/3] [mlir][LLVM] add argument and result attributes to
 llvm.call

---
 llvm/include/llvm/IR/InstrTypes.h             | 11 ++++
 .../include/mlir/Target/LLVMIR/ModuleImport.h |  8 ++-
 .../mlir/Target/LLVMIR/ModuleTranslation.h    |  9 ++-
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp    | 64 ++++++++++++-------
 .../LLVMIR/LLVMToLLVMIRTranslation.cpp        | 21 ++++++
 mlir/lib/Target/LLVMIR/ModuleImport.cpp       | 34 ++++++++++
 mlir/lib/Target/LLVMIR/ModuleTranslation.cpp  | 52 +++++++++++----
 mlir/test/Dialect/LLVMIR/invalid.mlir         |  2 +
 mlir/test/Dialect/LLVMIR/roundtrip.mlir       | 20 ++++++
 .../LLVMIR/Import/call-argument-attributes.ll | 22 +++++++
 .../LLVMIR/call-argument-attributes.mlir      | 17 +++++
 11 files changed, 220 insertions(+), 40 deletions(-)
 create mode 100644 mlir/test/Target/LLVMIR/Import/call-argument-attributes.ll
 create mode 100644 mlir/test/Target/LLVMIR/call-argument-attributes.mlir

diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h
index 26be02d4b193de6..90fe864d4ae71bf 100644
--- a/llvm/include/llvm/IR/InstrTypes.h
+++ b/llvm/include/llvm/IR/InstrTypes.h
@@ -1490,6 +1490,11 @@ class CallBase : public Instruction {
     Attrs = Attrs.addRetAttribute(getContext(), Attr);
   }
 
+  /// Adds attributes to the return value.
+  void addRetAttrs(const AttrBuilder &B) {
+    Attrs = Attrs.addRetAttributes(getContext(), B);
+  }
+
   /// Adds the attribute to the indicated argument
   void addParamAttr(unsigned ArgNo, Attribute::AttrKind Kind) {
     assert(ArgNo < arg_size() && "Out of bounds");
@@ -1502,6 +1507,12 @@ class CallBase : public Instruction {
     Attrs = Attrs.addParamAttribute(getContext(), ArgNo, Attr);
   }
 
+  /// Adds attributes to the indicated argument
+  void addParamAttrs(unsigned ArgNo, const AttrBuilder &B) {
+    assert(ArgNo < arg_size() && "Out of bounds");
+    Attrs = Attrs.addParamAttributes(getContext(), ArgNo, B);
+  }
+
   /// removes the attribute from the list of attributes.
   void removeAttributeAtIndex(unsigned i, Attribute::AttrKind Kind) {
     Attrs = Attrs.removeAttributeAtIndex(getContext(), i, Kind);
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index 80ae4d679624c25..d09c73c2f467dec 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -335,14 +335,18 @@ class ModuleImport {
   FailureOr<LLVMFunctionType> convertFunctionType(llvm::CallBase *callInst);
   /// Returns the callee name, or an empty symbol if the call is not direct.
   FlatSymbolRefAttr convertCalleeName(llvm::CallBase *callInst);
-  /// Converts the parameter attributes attached to `func` and adds them to
-  /// the `funcOp`.
+  /// Converts the parameter and result attributes attached to `func` and adds
+  /// them to the `funcOp`.
   void convertParameterAttributes(llvm::Function *func, LLVMFuncOp funcOp,
                                   OpBuilder &builder);
   /// Converts the AttributeSet of one parameter in LLVM IR to a corresponding
   /// DictionaryAttr for the LLVM dialect.
   DictionaryAttr convertParameterAttribute(llvm::AttributeSet llvmParamAttrs,
                                            OpBuilder &builder);
+  /// Converts the parameter and result attributes attached to `call` and adds
+  /// them to the `callOp`.
+  void convertParameterAttributes(llvm::CallBase *call, CallOpInterface callOp,
+                                  OpBuilder &builder);
   /// Converts the attributes attached to `inst` and adds them to the `op`.
   LogicalResult convertCallAttributes(llvm::CallInst *inst, CallOp op);
   /// Converts the attributes attached to `inst` and adds them to the `op`.
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index 1b62437761ed9d2..88fc17ca4fda248 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -228,6 +228,11 @@ class ModuleTranslation {
                             /*recordInsertions=*/false);
   }
 
+  /// Translates parameter attributes of a call and adds them to the returned
+  /// AttrBuilder. Returns failure if any of the translations failed.
+  FailureOr<llvm::AttrBuilder> convertParameterAttrs(CallOp callOp, int argIdx,
+                                                     DictionaryAttr paramAttrs);
+
   /// Gets the named metadata in the LLVM IR module being constructed, creating
   /// it if it does not exist.
   llvm::NamedMDNode *getOrInsertNamedModuleMetadata(StringRef name);
@@ -346,8 +351,8 @@ class ModuleTranslation {
   convertDialectAttributes(Operation *op,
                            ArrayRef<llvm::Instruction *> instructions);
 
-  /// Translates parameter attributes and adds them to the returned AttrBuilder.
-  /// Returns failure if any of the translations failed.
+  /// Translates parameter attributes of a function and adds them to the
+  /// returned AttrBuilder. Returns failure if any of the translations failed.
   FailureOr<llvm::AttrBuilder>
   convertParameterAttrs(LLVMFuncOp func, int argIdx, DictionaryAttr paramAttrs);
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index a6e996f3fb810db..25d45f70b09ac54 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1335,42 +1335,53 @@ void CallOp::print(OpAsmPrinter &p) {
                            getVarCalleeTypeAttrName(), getCConvAttrName(),
                            getOperandSegmentSizesAttrName(),
                            getOpBundleSizesAttrName(),
-                           getOpBundleTagsAttrName()});
+                           getOpBundleTagsAttrName(), getArgAttrsAttrName(),
+                           getResAttrsAttrName()});
 
   p << " : ";
   if (!isDirect)
     p << getOperand(0).getType() << ", ";
 
   // Reconstruct the function MLIR function type from operand and result types.
-  p.printFunctionalType(args.getTypes(), getResultTypes());
+  call_interface_impl::printFunctionSignature(
+      p, args.getTypes(), getArgAttrsAttr(),
+      /*isVariadic=*/false, getResultTypes(), getResAttrsAttr());
 }
 
 /// Parses the type of a call operation and resolves the operands if the parsing
 /// succeeds. Returns failure otherwise.
 static ParseResult parseCallTypeAndResolveOperands(
     OpAsmParser &parser, OperationState &result, bool isDirect,
-    ArrayRef<OpAsmParser::UnresolvedOperand> operands) {
+    ArrayRef<OpAsmParser::UnresolvedOperand> operands,
+    SmallVectorImpl<DictionaryAttr> &argAttrs,
+    SmallVectorImpl<DictionaryAttr> &resultAttrs) {
   SMLoc trailingTypesLoc = parser.getCurrentLocation();
   SmallVector<Type> types;
-  if (parser.parseColonTypeList(types))
+  if (parser.parseColon())
     return failure();
-
-  if (isDirect && types.size() != 1)
-    return parser.emitError(trailingTypesLoc,
-                            "expected direct call to have 1 trailing type");
-  if (!isDirect && types.size() != 2)
-    return parser.emitError(trailingTypesLoc,
-                            "expected indirect call to have 2 trailing types");
-
-  auto funcType = llvm::dyn_cast<FunctionType>(types.pop_back_val());
-  if (!funcType)
+  if (!isDirect) {
+    types.emplace_back();
+    if (parser.parseType(types.back()))
+      return failure();
+    if (parser.parseOptionalComma())
+      return parser.emitError(
+          trailingTypesLoc, "expected indirect call to have 2 trailing types");
+  }
+  SmallVector<Type> argTypes, resTypes;
+  if (call_interface_impl::parseFunctionSignature(parser, argTypes, argAttrs,
+                                                  resTypes, resultAttrs)) {
+    // The signature parser has already reported why it failed; add context.
+    if (isDirect)
+      return parser.emitError(trailingTypesLoc,
+                              "expected direct call to have 1 trailing type");
     return parser.emitError(trailingTypesLoc,
                             "expected trailing function type");
-  if (funcType.getNumResults() > 1)
+  }
+
+  if (resTypes.size() > 1)
     return parser.emitError(trailingTypesLoc,
                             "expected function with 0 or 1 result");
-  if (funcType.getNumResults() == 1 &&
-      llvm::isa<LLVM::LLVMVoidType>(funcType.getResult(0)))
+  if (resTypes.size() == 1 && llvm::isa<LLVM::LLVMVoidType>(resTypes[0]))
     return parser.emitError(trailingTypesLoc,
                             "expected a non-void result type");
 
@@ -1378,12 +1389,12 @@ static ParseResult parseCallTypeAndResolveOperands(
   // indirect calls, while the types list is emtpy for direct calls.
   // Append the function input types to resolve the call operation
   // operands.
-  llvm::append_range(types, funcType.getInputs());
+  llvm::append_range(types, argTypes);
   if (parser.resolveOperands(operands, types, parser.getNameLoc(),
                              result.operands))
     return failure();
-  if (funcType.getNumResults() != 0)
-    result.addTypes(funcType.getResults());
+  if (resTypes.size() != 0)
+    result.addTypes(resTypes);
 
   return success();
 }
@@ -1497,8 +1508,14 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
     return failure();
 
   // Parse the trailing type list and resolve the operands.
-  if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands))
+  SmallVector<DictionaryAttr> argAttrs;
+  SmallVector<DictionaryAttr> resultAttrs;
+  if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands,
+                                      argAttrs, resultAttrs))
     return failure();
+  call_interface_impl::addArgAndResultAttrs(
+      parser.getBuilder(), result, argAttrs, resultAttrs,
+      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
   if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands,
                               opBundleOperandTypes,
                               getOpBundleSizesAttrName(result.name)))
@@ -1721,7 +1738,10 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
     return failure();
 
   // Parse the trailing type list and resolve the function operands.
-  if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands))
+  SmallVector<DictionaryAttr> argAttrs;
+  SmallVector<DictionaryAttr> resultAttrs;
+  if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands,
+                                      argAttrs, resultAttrs))
     return failure();
   if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands,
                               opBundleOperandTypes,
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index 2084e527773ca82..52f42df60f0015f 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -265,6 +265,27 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
     if (callOp.getWillReturnAttr())
       call->addFnAttr(llvm::Attribute::WillReturn);
 
+    if (ArrayAttr argAttrsArray = callOp.getArgAttrsAttr())
+      for (auto [argIdx, argAttrsAttr] : llvm::enumerate(argAttrsArray)) {
+        if (auto argAttrs = llvm::cast<DictionaryAttr>(argAttrsAttr)) {
+          FailureOr<llvm::AttrBuilder> attrBuilder =
+              moduleTranslation.convertParameterAttrs(callOp, argIdx, argAttrs);
+          if (failed(attrBuilder))
+            return failure();
+          call->addParamAttrs(argIdx, *attrBuilder);
+        }
+      }
+
+    ArrayAttr resAttrsArray = callOp.getResAttrsAttr();
+    if (resAttrsArray && resAttrsArray.size() == 1)
+      if (auto resAttrs = llvm::cast<DictionaryAttr>(resAttrsArray[0])) {
+        FailureOr<llvm::AttrBuilder> attrBuilder =
+            moduleTranslation.convertParameterAttrs(callOp, -1, resAttrs);
+        if (failed(attrBuilder))
+          return failure();
+        call->addRetAttrs(*attrBuilder);
+      }
+
     if (MemoryEffectsAttr memAttr = callOp.getMemoryEffectsAttr()) {
       llvm::MemoryEffects memEffects =
           llvm::MemoryEffects(llvm::MemoryEffects::Location::ArgMem,
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 5ebde22cccbdf3e..8d779c5083eb601 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1706,6 +1706,8 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
       auto callOp = builder.create<CallOp>(loc, *funcTy, callee, *operands);
       if (failed(convertCallAttributes(callInst, callOp)))
         return failure();
+      // Handle parameter and result attributes.
+      convertParameterAttributes(callInst, callOp, builder);
       return callOp.getOperation();
     }();
 
@@ -2149,6 +2151,38 @@ void ModuleImport::convertParameterAttributes(llvm::Function *func,
       builder.getArrayAttr(convertParameterAttribute(llvmResAttr, builder)));
 }
 
+void ModuleImport::convertParameterAttributes(llvm::CallBase *call,
+                                              CallOpInterface callOp,
+                                              OpBuilder &builder) {
+  auto llvmAttrs = call->getAttributes();
+  SmallVector<llvm::AttributeSet> llvmArgAttrsSet;
+  bool anyArgAttrs = false;
+  for (size_t i = 0, e = call->arg_size(); i < e; ++i) {
+    llvmArgAttrsSet.emplace_back(llvmAttrs.getParamAttrs(i));
+    if (llvmArgAttrsSet.back().hasAttributes())
+      anyArgAttrs = true;
+  }
+  auto getArrayAttr = [&](ArrayRef<DictionaryAttr> dictAttrs) {
+    SmallVector<Attribute> attrs;
+    for (auto &dict : dictAttrs)
+      attrs.push_back(dict ? dict : builder.getDictionaryAttr({}));
+    return builder.getArrayAttr(attrs);
+  };
+  if (anyArgAttrs) {
+    SmallVector<DictionaryAttr> argAttrs;
+    for (auto &llvmArgAttrs : llvmArgAttrsSet)
+      argAttrs.emplace_back(convertParameterAttribute(llvmArgAttrs, builder));
+    callOp.setArgAttrsAttr(getArrayAttr(argAttrs));
+  }
+
+  llvm::AttributeSet llvmResAttr = llvmAttrs.getRetAttrs();
+  if (!llvmResAttr.hasAttributes())
+    return;
+  SmallVector<DictionaryAttr, 1> resAttrs;
+  resAttrs.emplace_back(convertParameterAttribute(llvmResAttr, builder));
+  callOp.setResAttrsAttr(getArrayAttr(resAttrs));
+}
+
 template <typename Op>
 static LogicalResult convertCallBaseAttributes(llvm::CallBase *inst, Op op) {
   op.setCConv(convertCConvFromLLVM(inst->getCallingConv()));
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 4367100e3aca682..b2d2c1cddca318a 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1563,6 +1563,26 @@ static void convertFunctionKernelAttributes(LLVMFuncOp func,
   }
 }
 
+static void convertParameterAttr(llvm::AttrBuilder &attrBuilder,
+                                 llvm::Attribute::AttrKind llvmKind,
+                                 NamedAttribute namedAttr,
+                                 ModuleTranslation &moduleTranslation) {
+  llvm::TypeSwitch<Attribute>(namedAttr.getValue())
+      .Case<TypeAttr>([&](auto typeAttr) {
+        attrBuilder.addTypeAttr(
+            llvmKind, moduleTranslation.convertType(typeAttr.getValue()));
+      })
+      .Case<IntegerAttr>([&](auto intAttr) {
+        attrBuilder.addRawIntAttr(llvmKind, intAttr.getInt());
+      })
+      .Case<UnitAttr>([&](auto) { attrBuilder.addAttribute(llvmKind); })
+      .Case<LLVM::ConstantRangeAttr>([&](auto rangeAttr) {
+        attrBuilder.addConstantRangeAttr(
+            llvmKind,
+            llvm::ConstantRange(rangeAttr.getLower(), rangeAttr.getUpper()));
+      });
+}
+
 FailureOr<llvm::AttrBuilder>
 ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx,
                                          DictionaryAttr paramAttrs) {
@@ -1573,20 +1593,7 @@ ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx,
     auto it = attrNameToKindMapping.find(namedAttr.getName());
     if (it != attrNameToKindMapping.end()) {
       llvm::Attribute::AttrKind llvmKind = it->second;
-
-      llvm::TypeSwitch<Attribute>(namedAttr.getValue())
-          .Case<TypeAttr>([&](auto typeAttr) {
-            attrBuilder.addTypeAttr(llvmKind, convertType(typeAttr.getValue()));
-          })
-          .Case<IntegerAttr>([&](auto intAttr) {
-            attrBuilder.addRawIntAttr(llvmKind, intAttr.getInt());
-          })
-          .Case<UnitAttr>([&](auto) { attrBuilder.addAttribute(llvmKind); })
-          .Case<LLVM::ConstantRangeAttr>([&](auto rangeAttr) {
-            attrBuilder.addConstantRangeAttr(
-                llvmKind, llvm::ConstantRange(rangeAttr.getLower(),
-                                              rangeAttr.getUpper()));
-          });
+      convertParameterAttr(attrBuilder, llvmKind, namedAttr, *this);
     } else if (namedAttr.getNameDialect()) {
       if (failed(iface.convertParameterAttr(func, argIdx, namedAttr, *this)))
         return failure();
@@ -1596,6 +1603,23 @@ ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx,
   return attrBuilder;
 }
 
+FailureOr<llvm::AttrBuilder>
+ModuleTranslation::convertParameterAttrs(CallOp, int argIdx,
+                                         DictionaryAttr paramAttrs) {
+  llvm::AttrBuilder attrBuilder(llvmModule->getContext());
+  auto attrNameToKindMapping = getAttrNameToKindMapping();
+
+  for (auto namedAttr : paramAttrs) {
+    auto it = attrNameToKindMapping.find(namedAttr.getName());
+    if (it != attrNameToKindMapping.end()) {
+      llvm::Attribute::AttrKind llvmKind = it->second;
+      convertParameterAttr(attrBuilder, llvmKind, namedAttr, *this);
+    }
+  }
+
+  return attrBuilder;
+}
+
 LogicalResult ModuleTranslation::convertFunctionSignatures() {
   // Declare all functions first because there may be function calls that form a
   // call graph with cycles, or global initializers that reference functions.
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 5c939318fe3ed67..76c57e76f849355 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -235,6 +235,7 @@ func.func @call_missing_ptr_type(%callee : !llvm.func<i8 (i8)>, %arg : i8) {
 func.func private @standard_func_callee()
 
 func.func @call_missing_ptr_type(%arg : i8) {
+  // expected-error@+2 {{expected '('}}
   // expected-error@+1 {{expected direct call to have 1 trailing type}}
   llvm.call @standard_func_callee(%arg) : !llvm.ptr, (i8) -> (i8)
   llvm.return
@@ -251,6 +252,7 @@ func.func @call_non_pointer_type(%callee : !llvm.func<i8 (i8)>, %arg : i8) {
 // -----
 
 func.func @call_non_function_type(%callee : !llvm.ptr, %arg : i8) {
+  // expected-error@+2 {{expected '('}}
   // expected-error@+1 {{expected trailing function type}}
   llvm.call %callee(%arg) : !llvm.ptr, !llvm.func<i8 (i8)>
   llvm.return
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index 88660ce598f3c22..e565772f06b03c3 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -941,3 +941,23 @@ llvm.func @test_assume_intr_with_opbundles(%arg0 : !llvm.ptr) {
   llvm.intr.assume %0 ["tag1"(%1, %2 : i32, i32), "tag2"(%3 : i32)] : i1
   llvm.return
 }
+
+llvm.func @somefunc(i32, !llvm.ptr)
+
+// CHECK-LABEL: llvm.func @test_call_arg_attrs_direct(
+// CHECK-SAME:    %[[VAL_0:.*]]: i32,
+// CHECK-SAME:    %[[VAL_1:.*]]: !llvm.ptr)
+llvm.func @test_call_arg_attrs_direct(%arg0: i32, %arg1: !llvm.ptr) {
+  // CHECK: llvm.call @somefunc(%[[VAL_0]], %[[VAL_1]]) : (i32, !llvm.ptr {llvm.byval = i64}) -> ()
+  llvm.call @somefunc(%arg0, %arg1) : (i32, !llvm.ptr {llvm.byval = i64}) -> ()
+  llvm.return
+}
+
+// CHECK-LABEL: llvm.func @test_call_arg_attrs_indirect(
+// CHECK-SAME:    %[[VAL_0:.*]]: i16,
+// CHECK-SAME:    %[[VAL_1:.*]]: !llvm.ptr
+llvm.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: !llvm.ptr) -> i16 {
+  // CHECK: llvm.call tail %[[VAL_1]](%[[VAL_0]]) : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
+  %0 = llvm.call tail %arg1(%arg0) : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
+  llvm.return %0 : i16
+}
diff --git a/mlir/test/Target/LLVMIR/Import/call-argument-attributes.ll b/mlir/test/Target/LLVMIR/Import/call-argument-attributes.ll
new file mode 100644
index 000000000000000..2c86ca6b03125e4
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/call-argument-attributes.ll
@@ -0,0 +1,22 @@
+; RUN: mlir-translate -import-llvm %s | FileCheck %s
+
+; CHECK-LABEL: llvm.func @somefunc(i32, !llvm.ptr)
+declare void @somefunc(i32, ptr)
+
+; CHECK-LABEL: llvm.func @test_call_arg_attrs_direct(
+; CHECK-SAME:    %[[VAL_0:.*]]: i32,
+; CHECK-SAME:    %[[VAL_1:.*]]: !llvm.ptr)
+define void @test_call_arg_attrs_direct(i32 %0, ptr %1) {
+  ; CHECK: llvm.call @somefunc(%[[VAL_0]], %[[VAL_1]]) : (i32, !llvm.ptr {llvm.byval = i64}) -> ()
+  call void @somefunc(i32 %0, ptr byval(i64) %1)
+  ret void
+}
+
+; CHECK-LABEL: llvm.func @test_call_arg_attrs_indirect(
+; CHECK-SAME:    %[[VAL_0:.*]]: i16,
+; CHECK-SAME:    %[[VAL_1:.*]]: !llvm.ptr
+define i16 @test_call_arg_attrs_indirect(i16 %0, ptr %1) {
+; CHECK: llvm.call tail %[[VAL_1]](%[[VAL_0]]) : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
+  %3 = tail call signext i16 %1(i16 noundef signext %0)
+  ret i16 %3
+}
diff --git a/mlir/test/Target/LLVMIR/call-argument-attributes.mlir b/mlir/test/Target/LLVMIR/call-argument-attributes.mlir
new file mode 100644
index 000000000000000..89b1f29a68623b7
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/call-argument-attributes.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+llvm.func @somefunc(i32, !llvm.ptr)
+
+// CHECK-LABEL: define void @test_call_arg_attrs_direct
+llvm.func @test_call_arg_attrs_direct(%arg0: i32, %arg1: !llvm.ptr) {
+  // CHECK: call void @somefunc(i32 %{{.*}}, ptr byval(i64) %{{.*}})
+  llvm.call @somefunc(%arg0, %arg1) : (i32, !llvm.ptr {llvm.byval = i64}) -> ()
+  llvm.return
+}
+
+// CHECK-LABEL: define i16 @test_call_arg_attrs_indirec
+llvm.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: !llvm.ptr) -> i16 {
+  // CHECK: tail call signext i16 %{{.*}}(i16 noundef signext %{{.*}})
+  %0 = llvm.call tail %arg1(%arg0) : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
+  llvm.return %0 : i16
+}

>From b9a67a93fc4e7a292e4be66d3eb3a5393e9da412 Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Mon, 3 Feb 2025 07:20:43 -0800
Subject: [PATCH 2/3] remove argIdx arg + style nits

---
 mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h    |  2 +-
 .../LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp  | 10 ++++++----
 mlir/lib/Target/LLVMIR/ModuleTranslation.cpp           |  3 +--
 mlir/test/Target/LLVMIR/call-argument-attributes.mlir  |  2 +-
 4 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index 88fc17ca4fda248..25f17ba4f6a35ac 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -230,7 +230,7 @@ class ModuleTranslation {
 
   /// Translates parameter attributes of a call and adds them to the returned
   /// AttrBuilder. Returns failure if any of the translations failed.
-  FailureOr<llvm::AttrBuilder> convertParameterAttrs(CallOp callOp, int argIdx,
+  FailureOr<llvm::AttrBuilder> convertParameterAttrs(CallOp callOp,
                                                      DictionaryAttr paramAttrs);
 
   /// Gets the named metadata in the LLVM IR module being constructed, creating
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index 52f42df60f0015f..822d6d7bb467b20 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -265,26 +265,28 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
     if (callOp.getWillReturnAttr())
       call->addFnAttr(llvm::Attribute::WillReturn);
 
-    if (ArrayAttr argAttrsArray = callOp.getArgAttrsAttr())
+    if (ArrayAttr argAttrsArray = callOp.getArgAttrsAttr()) {
       for (auto [argIdx, argAttrsAttr] : llvm::enumerate(argAttrsArray)) {
         if (auto argAttrs = llvm::cast<DictionaryAttr>(argAttrsAttr)) {
           FailureOr<llvm::AttrBuilder> attrBuilder =
-              moduleTranslation.convertParameterAttrs(callOp, argIdx, argAttrs);
+              moduleTranslation.convertParameterAttrs(callOp, argAttrs);
           if (failed(attrBuilder))
             return failure();
           call->addParamAttrs(argIdx, *attrBuilder);
         }
       }
+    }
 
     ArrayAttr resAttrsArray = callOp.getResAttrsAttr();
-    if (resAttrsArray && resAttrsArray.size() == 1)
+    if (resAttrsArray && resAttrsArray.size() == 1) {
       if (auto resAttrs = llvm::cast<DictionaryAttr>(resAttrsArray[0])) {
         FailureOr<llvm::AttrBuilder> attrBuilder =
-            moduleTranslation.convertParameterAttrs(callOp, -1, resAttrs);
+            moduleTranslation.convertParameterAttrs(callOp, resAttrs);
         if (failed(attrBuilder))
           return failure();
         call->addRetAttrs(*attrBuilder);
       }
+    }
 
     if (MemoryEffectsAttr memAttr = callOp.getMemoryEffectsAttr()) {
       llvm::MemoryEffects memEffects =
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index b2d2c1cddca318a..5cee1fc5c1cf443 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1604,8 +1604,7 @@ ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx,
 }
 
 FailureOr<llvm::AttrBuilder>
-ModuleTranslation::convertParameterAttrs(CallOp, int argIdx,
-                                         DictionaryAttr paramAttrs) {
+ModuleTranslation::convertParameterAttrs(CallOp, DictionaryAttr paramAttrs) {
   llvm::AttrBuilder attrBuilder(llvmModule->getContext());
   auto attrNameToKindMapping = getAttrNameToKindMapping();
 
diff --git a/mlir/test/Target/LLVMIR/call-argument-attributes.mlir b/mlir/test/Target/LLVMIR/call-argument-attributes.mlir
index 89b1f29a68623b7..b3d286dcda50456 100644
--- a/mlir/test/Target/LLVMIR/call-argument-attributes.mlir
+++ b/mlir/test/Target/LLVMIR/call-argument-attributes.mlir
@@ -9,7 +9,7 @@ llvm.func @test_call_arg_attrs_direct(%arg0: i32, %arg1: !llvm.ptr) {
   llvm.return
 }
 
-// CHECK-LABEL: define i16 @test_call_arg_attrs_indirec
+// CHECK-LABEL: define i16 @test_call_arg_attrs_indirect
 llvm.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: !llvm.ptr) -> i16 {
   // CHECK: tail call signext i16 %{{.*}}(i16 noundef signext %{{.*}})
   %0 = llvm.call tail %arg1(%arg0) : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})

>From 25a74a909bce87e19a0a842572e3865894f96f43 Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Thu, 6 Feb 2025 01:17:06 -0800
Subject: [PATCH 3/3] handle invoke parameter attributes in import, export,
 pretty printing

---
 .../mlir/Target/LLVMIR/ModuleTranslation.h    |  3 +-
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp    | 17 ++++--
 .../LLVMIR/LLVMToLLVMIRTranslation.cpp        | 55 +++++++++++--------
 mlir/lib/Target/LLVMIR/ModuleImport.cpp       |  3 +
 mlir/lib/Target/LLVMIR/ModuleTranslation.cpp  |  2 +-
 mlir/test/Dialect/LLVMIR/roundtrip.mlir       | 25 +++++++++
 .../Import/invoke-argument-attributes.ll      | 26 +++++++++
 .../LLVMIR/invoke-argument-attributes.mlir    | 25 +++++++++
 8 files changed, 126 insertions(+), 30 deletions(-)
 create mode 100644 mlir/test/Target/LLVMIR/Import/invoke-argument-attributes.ll
 create mode 100644 mlir/test/Target/LLVMIR/invoke-argument-attributes.mlir

diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index 25f17ba4f6a35ac..3ad7f1e33f0a3f5 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -230,8 +230,7 @@ class ModuleTranslation {
 
   /// Translates parameter attributes of a call and adds them to the returned
   /// AttrBuilder. Returns failure if any of the translations failed.
-  FailureOr<llvm::AttrBuilder> convertParameterAttrs(CallOp callOp,
-                                                     DictionaryAttr paramAttrs);
+  FailureOr<llvm::AttrBuilder> convertParameterAttrs(DictionaryAttr paramAttrs);
 
   /// Gets the named metadata in the LLVM IR module being constructed, creating
   /// it if it does not exist.
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 25d45f70b09ac54..bea90acc9364ff8 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1660,14 +1660,16 @@ void InvokeOp::print(OpAsmPrinter &p) {
                           {getCalleeAttrName(), getOperandSegmentSizeAttr(),
                            getCConvAttrName(), getVarCalleeTypeAttrName(),
                            getOpBundleSizesAttrName(),
-                           getOpBundleTagsAttrName()});
+                           getOpBundleTagsAttrName(), getArgAttrsAttrName(),
+                           getResAttrsAttrName()});
 
   p << " : ";
   if (!isDirect)
     p << getOperand(0).getType() << ", ";
-  p.printFunctionalType(
-      llvm::drop_begin(getCalleeOperands().getTypes(), isDirect ? 0 : 1),
-      getResultTypes());
+  call_interface_impl::printFunctionSignature(
+      p, getCalleeOperands().drop_front(isDirect ? 0 : 1).getTypes(),
+      getArgAttrsAttr(),
+      /*isVariadic=*/false, getResultTypes(), getResAttrsAttr());
 }
 
 // <operation> ::= `llvm.invoke` (cconv)? (function-id | ssa-use)
@@ -1676,7 +1678,8 @@ void InvokeOp::print(OpAsmPrinter &p) {
 //                  `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
 //                  ( `vararg(` var-callee-type `)` )?
 //                  ( `[` op-bundles-list `]` )?
-//                  attribute-dict? `:` (type `,`)? function-type
+//                  attribute-dict? `:` (type `,`)?
+//                  function-type-with-argument-attributes
 ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
   SmallVector<OpAsmParser::UnresolvedOperand, 8> operands;
   SymbolRefAttr funcAttr;
@@ -1743,6 +1746,10 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
   if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands,
                                       argAttrs, resultAttrs))
     return failure();
+  call_interface_impl::addArgAndResultAttrs(
+      parser.getBuilder(), result, argAttrs, resultAttrs,
+      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
+
   if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands,
                               opBundleOperandTypes,
                               getOpBundleSizesAttrName(result.name)))
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index 822d6d7bb467b20..fc295ab7e8e1de9 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -224,6 +224,44 @@ static void convertLinkerOptionsOp(ArrayAttr options,
   linkerMDNode->addOperand(listMDNode);
 }
 
+/// Translates the argument and result attribute dictionaries attached to
+/// `callOp` (shared by the llvm.call and llvm.invoke lowerings) into LLVM IR
+/// parameter and return attributes on `call`. Empty dictionaries are skipped.
+/// Returns failure if any dictionary cannot be converted.
+static LogicalResult
+convertParameterAndResultAttrs(CallOpInterface callOp, llvm::CallBase *call,
+                               LLVM::ModuleTranslation &moduleTranslation) {
+  if (ArrayAttr argAttrsArray = callOp.getArgAttrsAttr()) {
+    for (auto [argIdx, argAttrsAttr] : llvm::enumerate(argAttrsArray)) {
+      // `llvm::cast` asserts on a type mismatch and never yields null, so it
+      // cannot act as a guard; skip empty dictionaries explicitly instead.
+      auto argAttrs = llvm::cast<DictionaryAttr>(argAttrsAttr);
+      if (argAttrs.empty())
+        continue;
+      FailureOr<llvm::AttrBuilder> attrBuilder =
+          moduleTranslation.convertParameterAttrs(argAttrs);
+      if (failed(attrBuilder))
+        return failure();
+      call->addParamAttrs(argIdx, *attrBuilder);
+    }
+  }
+
+  // Result attributes are translated only when exactly one result dictionary
+  // is present (an LLVM call/invoke yields at most one value).
+  ArrayAttr resAttrsArray = callOp.getResAttrsAttr();
+  if (!resAttrsArray || resAttrsArray.size() != 1)
+    return success();
+  auto resAttrs = llvm::cast<DictionaryAttr>(resAttrsArray[0]);
+  if (resAttrs.empty())
+    return success();
+  FailureOr<llvm::AttrBuilder> attrBuilder =
+      moduleTranslation.convertParameterAttrs(resAttrs);
+  if (failed(attrBuilder))
+    return failure();
+  call->addRetAttrs(*attrBuilder);
+  return success();
+}
+
 static LogicalResult
 convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
                      LLVM::ModuleTranslation &moduleTranslation) {
@@ -265,28 +293,8 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
     if (callOp.getWillReturnAttr())
       call->addFnAttr(llvm::Attribute::WillReturn);
 
-    if (ArrayAttr argAttrsArray = callOp.getArgAttrsAttr()) {
-      for (auto [argIdx, argAttrsAttr] : llvm::enumerate(argAttrsArray)) {
-        if (auto argAttrs = llvm::cast<DictionaryAttr>(argAttrsAttr)) {
-          FailureOr<llvm::AttrBuilder> attrBuilder =
-              moduleTranslation.convertParameterAttrs(callOp, argAttrs);
-          if (failed(attrBuilder))
-            return failure();
-          call->addParamAttrs(argIdx, *attrBuilder);
-        }
-      }
-    }
-
-    ArrayAttr resAttrsArray = callOp.getResAttrsAttr();
-    if (resAttrsArray && resAttrsArray.size() == 1) {
-      if (auto resAttrs = llvm::cast<DictionaryAttr>(resAttrsArray[0])) {
-        FailureOr<llvm::AttrBuilder> attrBuilder =
-            moduleTranslation.convertParameterAttrs(callOp, resAttrs);
-        if (failed(attrBuilder))
-          return failure();
-        call->addRetAttrs(*attrBuilder);
-      }
-    }
+    if (failed(convertParameterAndResultAttrs(callOp, call, moduleTranslation)))
+      return failure();
 
     if (MemoryEffectsAttr memAttr = callOp.getMemoryEffectsAttr()) {
       llvm::MemoryEffects memEffects =
@@ -395,6 +403,9 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
           operandsRef.drop_front(), opBundles);
     }
     result->setCallingConv(convertCConvToLLVM(invOp.getCConv()));
+    if (failed(
+            convertParameterAndResultAttrs(invOp, result, moduleTranslation)))
+      return failure();
     moduleTranslation.mapBranch(invOp, result);
     // InvokeOp can only have 0 or 1 result
     if (invOp->getNumResults() != 0) {
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 8d779c5083eb601..a55a65f9067add8 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1788,6 +1788,9 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
     if (failed(convertInvokeAttributes(invokeInst, invokeOp)))
       return failure();
 
+    // Handle parameter and result attributes.
+    convertParameterAttributes(invokeInst, invokeOp, builder);
+
     if (!invokeInst->getType()->isVoidTy())
       mapValue(inst, invokeOp.getResults().front());
     else
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 5cee1fc5c1cf443..21e15b5dbf96ddb 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1604,7 +1604,7 @@ ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx,
 }
 
 FailureOr<llvm::AttrBuilder>
-ModuleTranslation::convertParameterAttrs(CallOp, DictionaryAttr paramAttrs) {
+ModuleTranslation::convertParameterAttrs(DictionaryAttr paramAttrs) {
   llvm::AttrBuilder attrBuilder(llvmModule->getContext());
   auto attrNameToKindMapping = getAttrNameToKindMapping();
 
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index e565772f06b03c3..09a0cd57e2675d8 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -961,3 +961,29 @@ llvm.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: !llvm.ptr) -> i16 {
   %0 = llvm.call tail %arg1(%arg0) : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
   llvm.return %0 : i16
 }
+
+// CHECK-LABEL:   llvm.func @test_invoke_arg_attrs(
+// CHECK-SAME:      %[[VAL_0:.*]]: i16) attributes {personality = @__gxx_personality_v0} {
+llvm.func @test_invoke_arg_attrs(%arg0: i16) attributes { personality = @__gxx_personality_v0 } {
+  // CHECK:  llvm.invoke @somefunc(%[[VAL_0]]) to ^bb2 unwind ^bb1 : (i16 {llvm.noundef, llvm.signext}) -> ()
+  llvm.invoke @somefunc(%arg0) to ^bb2 unwind ^bb1 : (i16 {llvm.noundef, llvm.signext}) -> ()
+^bb1:
+  %1 = llvm.landingpad cleanup : !llvm.struct<(ptr, i32)>
+  llvm.return
+^bb2:
+  llvm.return
+}
+
+// CHECK-LABEL:   llvm.func @test_invoke_arg_attrs_indirect(
+// CHECK-SAME:      %[[VAL_0:.*]]: i16,
+// CHECK-SAME:      %[[VAL_1:.*]]: !llvm.ptr) -> i16 attributes {personality = @__gxx_personality_v0} {
+llvm.func @test_invoke_arg_attrs_indirect(%arg0: i16, %arg1: !llvm.ptr) -> i16 attributes { personality = @__gxx_personality_v0 } {
+  // CHECK: llvm.invoke %[[VAL_1]](%[[VAL_0]]) to ^bb2 unwind ^bb1 : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
+  %0 = llvm.invoke %arg1(%arg0) to ^bb2 unwind ^bb1 : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
+^bb1:
+  %1 = llvm.landingpad cleanup : !llvm.struct<(ptr, i32)>
+  // The invoke result is only available on the normal path; return %arg0.
+  llvm.return %arg0 : i16
+^bb2:
+  llvm.return %0 : i16
+}
diff --git a/mlir/test/Target/LLVMIR/Import/invoke-argument-attributes.ll b/mlir/test/Target/LLVMIR/Import/invoke-argument-attributes.ll
new file mode 100644
index 000000000000000..e606961e1a252e5
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/invoke-argument-attributes.ll
@@ -0,0 +1,27 @@
+; RUN: mlir-translate -import-llvm %s | FileCheck %s
+
+; Checks that parameter and result attributes on an invoke are imported.
+; CHECK-LABEL:   llvm.func @test(
+; CHECK-SAME:      %[[VAL_0:.*]]: i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext}) attributes {personality = @__gxx_personality_v0} {
+define signext i16 @test(i16 noundef signext %0) personality ptr @__gxx_personality_v0 {
+; CHECK:           %[[VAL_3:.*]] = llvm.invoke @somefunc(%[[VAL_0]]) to ^bb2 unwind ^bb1 : (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
+  %2 = invoke signext i16 @somefunc(i16 noundef signext %0)
+          to label %7 unwind label %3
+
+3:                                                ; preds = %1
+  %4 = landingpad { ptr, i32 }
+          catch ptr null
+  %5 = extractvalue { ptr, i32 } %4, 0
+  %6 = tail call ptr @__cxa_begin_catch(ptr %5)
+  tail call void @__cxa_end_catch()
+  br label %7
+
+7:                                                ; preds = %1, %3
+  %8 = phi i16 [ 0, %3 ], [ %2, %1 ]
+  ret i16 %8
+}
+
+declare noundef signext i16 @somefunc(i16 noundef signext)
+declare i32 @__gxx_personality_v0(...)
+declare ptr @__cxa_begin_catch(ptr)
+declare void @__cxa_end_catch()
diff --git a/mlir/test/Target/LLVMIR/invoke-argument-attributes.mlir b/mlir/test/Target/LLVMIR/invoke-argument-attributes.mlir
new file mode 100644
index 000000000000000..5d6e49bfe09e886
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/invoke-argument-attributes.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: @test
+llvm.func @test(%arg0: i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext}) attributes {personality = @__gxx_personality_v0} {
+  %null = llvm.mlir.zero : !llvm.ptr
+  %zero = llvm.mlir.constant(0 : i16) : i16
+// CHECK:      invoke signext i16 @somefunc(i16 noundef signext %{{.*}})
+// CHECK-NEXT:   to label %{{.*}} unwind label %{{.*}}
+  %res = llvm.invoke @somefunc(%arg0) to ^normal unwind ^unwind : (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
+^unwind:  // unwind destination: landing pad with a null catch clause
+  %lp = llvm.landingpad (catch %null : !llvm.ptr) : !llvm.struct<(ptr, i32)>
+  %exn = llvm.extractvalue %lp[0] : !llvm.struct<(ptr, i32)>
+  %caught = llvm.call tail @__cxa_begin_catch(%exn) : (!llvm.ptr) -> !llvm.ptr
+  llvm.call tail @__cxa_end_catch() : () -> ()
+  llvm.br ^exit(%zero : i16)
+^normal:  // normal destination: forward the invoke result
+  llvm.br ^exit(%res : i16)
+^exit(%ret: i16):  // join point of ^unwind and ^normal
+  llvm.return %ret : i16
+}
+
+llvm.func @somefunc(i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.noundef, llvm.signext})
+llvm.func @__gxx_personality_v0(...) -> i32
+llvm.func @__cxa_begin_catch(!llvm.ptr) -> !llvm.ptr
+llvm.func @__cxa_end_catch()



More information about the Mlir-commits mailing list