[Mlir-commits] [mlir] 99e1308 - [mlir][LLVM] handle argument and result attributes in llvm.call and llvm.invoke (#123177)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 11 00:39:54 PST 2025


Author: jeanPerier
Date: 2025-02-11T09:39:51+01:00
New Revision: 99e1308c41b24e2422324d68be28e5370196e5d6

URL: https://github.com/llvm/llvm-project/commit/99e1308c41b24e2422324d68be28e5370196e5d6
DIFF: https://github.com/llvm/llvm-project/commit/99e1308c41b24e2422324d68be28e5370196e5d6.diff

LOG: [mlir][LLVM] handle argument and result attributes in llvm.call and llvm.invoke (#123177)

Update the llvm.call/llvm.invoke pretty printer/parser and the LLVM IR import/export
to handle argument and result attributes.

This patch builds on PR 123176, which modified the
CallOpInterface and added argument and result attributes to
llvm.call and llvm.invoke without otherwise handling them.

RFC: https://discourse.llvm.org/t/mlir-rfc-adding-argument-and-result-attributes-to-llvm-call/84107

Added: 
    mlir/test/Target/LLVMIR/Import/call-argument-attributes.ll
    mlir/test/Target/LLVMIR/Import/invoke-argument-attributes.ll
    mlir/test/Target/LLVMIR/call-argument-attributes.mlir
    mlir/test/Target/LLVMIR/invoke-argument-attributes.mlir

Modified: 
    llvm/include/llvm/IR/InstrTypes.h
    mlir/include/mlir/Target/LLVMIR/ModuleImport.h
    mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
    mlir/lib/Target/LLVMIR/ModuleImport.cpp
    mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
    mlir/test/Dialect/LLVMIR/invalid.mlir
    mlir/test/Dialect/LLVMIR/roundtrip.mlir

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h
index 26be02d4b193de6..90fe864d4ae71bf 100644
--- a/llvm/include/llvm/IR/InstrTypes.h
+++ b/llvm/include/llvm/IR/InstrTypes.h
@@ -1490,6 +1490,11 @@ class CallBase : public Instruction {
     Attrs = Attrs.addRetAttribute(getContext(), Attr);
   }
 
+  /// Adds every attribute in \p B to the return value.
+  void addRetAttrs(const AttrBuilder &B) {
+    Attrs = Attrs.addRetAttributes(getContext(), B);
+  }
+
   /// Adds the attribute to the indicated argument
   void addParamAttr(unsigned ArgNo, Attribute::AttrKind Kind) {
     assert(ArgNo < arg_size() && "Out of bounds");
@@ -1502,6 +1507,12 @@ class CallBase : public Instruction {
     Attrs = Attrs.addParamAttribute(getContext(), ArgNo, Attr);
   }
 
+  /// Adds every attribute in \p B to the argument at index \p ArgNo.
+  void addParamAttrs(unsigned ArgNo, const AttrBuilder &B) {
+    assert(ArgNo < arg_size() && "Out of bounds");
+    Attrs = Attrs.addParamAttributes(getContext(), ArgNo, B);
+  }
+
   /// removes the attribute from the list of attributes.
   void removeAttributeAtIndex(unsigned i, Attribute::AttrKind Kind) {
     Attrs = Attrs.removeAttributeAtIndex(getContext(), i, Kind);

diff  --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index d4032c6bc4356b5..4642d58760ca862 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -341,14 +341,18 @@ class ModuleImport {
   FailureOr<LLVMFunctionType> convertFunctionType(llvm::CallBase *callInst);
   /// Returns the callee name, or an empty symbol if the call is not direct.
   FlatSymbolRefAttr convertCalleeName(llvm::CallBase *callInst);
-  /// Converts the parameter attributes attached to `func` and adds them to
-  /// the `funcOp`.
+  /// Converts the parameter and result attributes attached to `func` and adds
+  /// them to the `funcOp`.
   void convertParameterAttributes(llvm::Function *func, LLVMFuncOp funcOp,
                                   OpBuilder &builder);
   /// Converts the AttributeSet of one parameter in LLVM IR to a corresponding
   /// DictionaryAttr for the LLVM dialect.
   DictionaryAttr convertParameterAttribute(llvm::AttributeSet llvmParamAttrs,
                                            OpBuilder &builder);
+  /// Converts the parameter and result attributes attached to `call` and adds
+  /// them to the `callOp`.
+  void convertParameterAttributes(llvm::CallBase *call, CallOpInterface callOp,
+                                  OpBuilder &builder);
   /// Converts the attributes attached to `inst` and adds them to the `op`.
   LogicalResult convertCallAttributes(llvm::CallInst *inst, CallOp op);
   /// Converts the attributes attached to `inst` and adds them to the `op`.

diff  --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index 9de1d29fa8ec8fb..e86d576bdb24119 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -235,6 +235,11 @@ class ModuleTranslation {
                             /*recordInsertions=*/false);
   }
 
+  /// Translates parameter attributes of a call and adds them to the returned
+  /// AttrBuilder. Returns failure if any of the translations failed.
+  FailureOr<llvm::AttrBuilder> convertParameterAttrs(CallOpInterface callOp,
+                                                     DictionaryAttr paramAttrs);
+
   /// Gets the named metadata in the LLVM IR module being constructed, creating
   /// it if it does not exist.
   llvm::NamedMDNode *getOrInsertNamedModuleMetadata(StringRef name);
@@ -359,8 +364,8 @@ class ModuleTranslation {
   convertDialectAttributes(Operation *op,
                            ArrayRef<llvm::Instruction *> instructions);
 
-  /// Translates parameter attributes and adds them to the returned AttrBuilder.
-  /// Returns failure if any of the translations failed.
+  /// Translates parameter attributes of a function and adds them to the
+  /// returned AttrBuilder. Returns failure if any of the translations failed.
   FailureOr<llvm::AttrBuilder>
   convertParameterAttrs(LLVMFuncOp func, int argIdx, DictionaryAttr paramAttrs);
 

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 2c93b6ed08cf8b6..bfcba40555a7c60 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1335,42 +1335,53 @@ void CallOp::print(OpAsmPrinter &p) {
                            getVarCalleeTypeAttrName(), getCConvAttrName(),
                            getOperandSegmentSizesAttrName(),
                            getOpBundleSizesAttrName(),
-                           getOpBundleTagsAttrName()});
+                           getOpBundleTagsAttrName(), getArgAttrsAttrName(),
+                           getResAttrsAttrName()});
 
   p << " : ";
   if (!isDirect)
     p << getOperand(0).getType() << ", ";
 
-  // Reconstruct the function MLIR function type from operand and result types.
-  p.printFunctionalType(args.getTypes(), getResultTypes());
+  // Reconstruct the MLIR function type from operand and result types.
+  call_interface_impl::printFunctionSignature(
+      p, args.getTypes(), getArgAttrsAttr(),
+      /*isVariadic=*/false, getResultTypes(), getResAttrsAttr());
 }
 
 /// Parses the type of a call operation and resolves the operands if the parsing
 /// succeeds. Returns failure otherwise.
 static ParseResult parseCallTypeAndResolveOperands(
     OpAsmParser &parser, OperationState &result, bool isDirect,
-    ArrayRef<OpAsmParser::UnresolvedOperand> operands) {
+    ArrayRef<OpAsmParser::UnresolvedOperand> operands,
+    SmallVectorImpl<DictionaryAttr> &argAttrs,
+    SmallVectorImpl<DictionaryAttr> &resultAttrs) {
   SMLoc trailingTypesLoc = parser.getCurrentLocation();
   SmallVector<Type> types;
-  if (parser.parseColonTypeList(types))
+  if (parser.parseColon())
     return failure();
-
-  if (isDirect && types.size() != 1)
-    return parser.emitError(trailingTypesLoc,
-                            "expected direct call to have 1 trailing type");
-  if (!isDirect && types.size() != 2)
-    return parser.emitError(trailingTypesLoc,
-                            "expected indirect call to have 2 trailing types");
-
-  auto funcType = llvm::dyn_cast<FunctionType>(types.pop_back_val());
-  if (!funcType)
+  if (!isDirect) {
+    types.emplace_back();
+    if (parser.parseType(types.back()))
+      return failure();
+    if (parser.parseOptionalComma())
+      return parser.emitError(
+          trailingTypesLoc, "expected indirect call to have 2 trailing types");
+  }
+  SmallVector<Type> argTypes;
+  SmallVector<Type> resTypes;
+  if (call_interface_impl::parseFunctionSignature(parser, argTypes, argAttrs,
+                                                  resTypes, resultAttrs)) {
+    if (isDirect)
+      return parser.emitError(trailingTypesLoc,
+                              "expected direct call to have 1 trailing types");
     return parser.emitError(trailingTypesLoc,
                             "expected trailing function type");
-  if (funcType.getNumResults() > 1)
+  }
+
+  if (resTypes.size() > 1)
     return parser.emitError(trailingTypesLoc,
                             "expected function with 0 or 1 result");
-  if (funcType.getNumResults() == 1 &&
-      llvm::isa<LLVM::LLVMVoidType>(funcType.getResult(0)))
+  if (resTypes.size() == 1 && llvm::isa<LLVM::LLVMVoidType>(resTypes[0]))
     return parser.emitError(trailingTypesLoc,
                             "expected a non-void result type");
 
@@ -1378,12 +1389,12 @@ static ParseResult parseCallTypeAndResolveOperands(
   // indirect calls, while the types list is emtpy for direct calls.
   // Append the function input types to resolve the call operation
   // operands.
-  llvm::append_range(types, funcType.getInputs());
+  llvm::append_range(types, argTypes);
   if (parser.resolveOperands(operands, types, parser.getNameLoc(),
                              result.operands))
     return failure();
-  if (funcType.getNumResults() != 0)
-    result.addTypes(funcType.getResults());
+  if (resTypes.size() != 0)
+    result.addTypes(resTypes);
 
   return success();
 }
@@ -1497,8 +1508,14 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
     return failure();
 
   // Parse the trailing type list and resolve the operands.
-  if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands))
+  SmallVector<DictionaryAttr> argAttrs;
+  SmallVector<DictionaryAttr> resultAttrs;
+  if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands,
+                                      argAttrs, resultAttrs))
     return failure();
+  call_interface_impl::addArgAndResultAttrs(
+      parser.getBuilder(), result, argAttrs, resultAttrs,
+      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
   if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands,
                               opBundleOperandTypes,
                               getOpBundleSizesAttrName(result.name)))
@@ -1643,14 +1660,16 @@ void InvokeOp::print(OpAsmPrinter &p) {
                           {getCalleeAttrName(), getOperandSegmentSizeAttr(),
                            getCConvAttrName(), getVarCalleeTypeAttrName(),
                            getOpBundleSizesAttrName(),
-                           getOpBundleTagsAttrName()});
+                           getOpBundleTagsAttrName(), getArgAttrsAttrName(),
+                           getResAttrsAttrName()});
 
   p << " : ";
   if (!isDirect)
     p << getOperand(0).getType() << ", ";
-  p.printFunctionalType(
-      llvm::drop_begin(getCalleeOperands().getTypes(), isDirect ? 0 : 1),
-      getResultTypes());
+  call_interface_impl::printFunctionSignature(
+      p, getCalleeOperands().drop_front(isDirect ? 0 : 1).getTypes(),
+      getArgAttrsAttr(),
+      /*isVariadic=*/false, getResultTypes(), getResAttrsAttr());
 }
 
 // <operation> ::= `llvm.invoke` (cconv)? (function-id | ssa-use)
@@ -1659,7 +1678,8 @@ void InvokeOp::print(OpAsmPrinter &p) {
 //                  `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
 //                  ( `vararg(` var-callee-type `)` )?
 //                  ( `[` op-bundles-list `]` )?
-//                  attribute-dict? `:` (type `,`)? function-type
+//                  attribute-dict? `:` (type `,`)?
+//                  function-type-with-argument-attributes
 ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
   SmallVector<OpAsmParser::UnresolvedOperand, 8> operands;
   SymbolRefAttr funcAttr;
@@ -1721,8 +1741,15 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
     return failure();
 
   // Parse the trailing type list and resolve the function operands.
-  if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands))
+  SmallVector<DictionaryAttr> argAttrs;
+  SmallVector<DictionaryAttr> resultAttrs;
+  if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands,
+                                      argAttrs, resultAttrs))
     return failure();
+  call_interface_impl::addArgAndResultAttrs(
+      parser.getBuilder(), result, argAttrs, resultAttrs,
+      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
+
   if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands,
                               opBundleOperandTypes,
                               getOpBundleSizesAttrName(result.name)))

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index 3afea87ca92c12f..4f89ee703ebb92c 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -224,6 +224,39 @@ static void convertLinkerOptionsOp(ArrayAttr options,
   linkerMDNode->addOperand(listMDNode);
 }
 
+static LogicalResult // Copies callOp's arg/result attributes onto call.
+convertParameterAndResultAttrs(CallOpInterface callOp, llvm::CallBase *call,
+                               LLVM::ModuleTranslation &moduleTranslation) {
+  if (ArrayAttr argAttrsArray = callOp.getArgAttrsAttr()) {
+    for (auto [argIdx, argAttrsAttr] : llvm::enumerate(argAttrsArray)) {
+      if (auto argAttrs = cast<DictionaryAttr>(argAttrsAttr);
+          !argAttrs.empty()) { // Empty dictionaries are skipped.
+        FailureOr<llvm::AttrBuilder> attrBuilder =
+            moduleTranslation.convertParameterAttrs(callOp, argAttrs);
+        if (failed(attrBuilder)) // Error already emitted.
+          return failure();
+        call->addParamAttrs(argIdx, *attrBuilder); // Asserts in-bounds.
+      }
+    }
+  }
+
+  ArrayAttr resAttrsArray = callOp.getResAttrsAttr();
+  if (resAttrsArray && resAttrsArray.size() > 0) {
+    if (resAttrsArray.size() != 1) // NOTE(review): diag names llvm.func
+      return mlir::emitError(callOp.getLoc(),
+                             "llvm.func cannot have multiple results");
+    if (auto resAttrs = cast<DictionaryAttr>(resAttrsArray[0]);
+        !resAttrs.empty()) {
+      FailureOr<llvm::AttrBuilder> attrBuilder =
+          moduleTranslation.convertParameterAttrs(callOp, resAttrs);
+      if (failed(attrBuilder))
+        return failure();
+      call->addRetAttrs(*attrBuilder);
+    }
+  }
+  return success();
+}
+
 static LogicalResult
 convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
                      LLVM::ModuleTranslation &moduleTranslation) {
@@ -265,6 +298,9 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
     if (callOp.getWillReturnAttr())
       call->addFnAttr(llvm::Attribute::WillReturn);
 
+    if (failed(convertParameterAndResultAttrs(callOp, call, moduleTranslation)))
+      return failure();
+
     if (MemoryEffectsAttr memAttr = callOp.getMemoryEffectsAttr()) {
       llvm::MemoryEffects memEffects =
           llvm::MemoryEffects(llvm::MemoryEffects::Location::ArgMem,
@@ -372,6 +408,9 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
           operandsRef.drop_front(), opBundles);
     }
     result->setCallingConv(convertCConvToLLVM(invOp.getCConv()));
+    if (failed(
+            convertParameterAndResultAttrs(invOp, result, moduleTranslation)))
+      return failure();
     moduleTranslation.mapBranch(invOp, result);
     // InvokeOp can only have 0 or 1 result
     if (invOp->getNumResults() != 0) {

diff  --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 5c61ecaeeefed7b..fd0283b856b6b60 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1756,6 +1756,8 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
       auto callOp = builder.create<CallOp>(loc, *funcTy, callee, *operands);
       if (failed(convertCallAttributes(callInst, callOp)))
         return failure();
+      // Handle parameter and result attributes.
+      convertParameterAttributes(callInst, callOp, builder);
       return callOp.getOperation();
     }();
 
@@ -1836,6 +1838,9 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
     if (failed(convertInvokeAttributes(invokeInst, invokeOp)))
       return failure();
 
+    // Handle parameter and result attributes.
+    convertParameterAttributes(invokeInst, invokeOp, builder);
+
     if (!invokeInst->getType()->isVoidTy())
       mapValue(inst, invokeOp.getResults().front());
     else
@@ -2199,6 +2204,37 @@ void ModuleImport::convertParameterAttributes(llvm::Function *func,
       builder.getArrayAttr(convertParameterAttribute(llvmResAttr, builder)));
 }
 
+void ModuleImport::convertParameterAttributes(llvm::CallBase *call,
+                                              CallOpInterface callOp,
+                                              OpBuilder &builder) {
+  llvm::AttributeList llvmAttrs = call->getAttributes(); // Call-site attrs.
+  SmallVector<llvm::AttributeSet> llvmArgAttrsSet;
+  bool anyArgAttrs = false; // Set if any argument has attributes.
+  for (size_t i = 0, e = call->arg_size(); i < e; ++i) {
+    llvmArgAttrsSet.emplace_back(llvmAttrs.getParamAttrs(i));
+    if (llvmArgAttrsSet.back().hasAttributes())
+      anyArgAttrs = true;
+  }
+  auto getArrayAttr = [&](ArrayRef<DictionaryAttr> dictAttrs) {
+    SmallVector<Attribute> attrs;
+    for (auto &dict : dictAttrs) // Null entries become empty dicts.
+      attrs.push_back(dict ? dict : builder.getDictionaryAttr({}));
+    return builder.getArrayAttr(attrs);
+  };
+  if (anyArgAttrs) { // Otherwise arg_attrs is left unset.
+    SmallVector<DictionaryAttr> argAttrs;
+    for (auto &llvmArgAttrs : llvmArgAttrsSet)
+      argAttrs.emplace_back(convertParameterAttribute(llvmArgAttrs, builder));
+    callOp.setArgAttrsAttr(getArrayAttr(argAttrs));
+  }
+
+  llvm::AttributeSet llvmResAttr = llvmAttrs.getRetAttrs();
+  if (!llvmResAttr.hasAttributes()) // Leave res_attrs unset.
+    return;
+  DictionaryAttr resAttrs = convertParameterAttribute(llvmResAttr, builder);
+  callOp.setResAttrsAttr(getArrayAttr({resAttrs}));
+}
+
 template <typename Op>
 static LogicalResult convertCallBaseAttributes(llvm::CallBase *inst, Op op) {
   op.setCConv(convertCConvFromLLVM(inst->getCallingConv()));

diff  --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index ed61cb255be8fa7..3da47de6ac24b89 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1617,30 +1617,50 @@ static void convertFunctionKernelAttributes(LLVMFuncOp func,
   }
 }
 
+static LogicalResult convertParameterAttr(llvm::AttrBuilder &attrBuilder,
+                                          llvm::Attribute::AttrKind llvmKind,
+                                          NamedAttribute namedAttr,
+                                          ModuleTranslation &moduleTranslation,
+                                          Location loc) { // Used for errors.
+  return llvm::TypeSwitch<Attribute, LogicalResult>(namedAttr.getValue())
+      .Case<TypeAttr>([&](auto typeAttr) { // e.g. llvm.byval = i64
+        attrBuilder.addTypeAttr(
+            llvmKind, moduleTranslation.convertType(typeAttr.getValue()));
+        return success();
+      })
+      .Case<IntegerAttr>([&](auto intAttr) { // Stored as a raw int.
+        attrBuilder.addRawIntAttr(llvmKind, intAttr.getInt());
+        return success();
+      })
+      .Case<UnitAttr>([&](auto) { // e.g. llvm.noundef
+        attrBuilder.addAttribute(llvmKind);
+        return success();
+      })
+      .Case<LLVM::ConstantRangeAttr>([&](auto rangeAttr) {
+        attrBuilder.addConstantRangeAttr(
+            llvmKind,
+            llvm::ConstantRange(rangeAttr.getLower(), rangeAttr.getUpper()));
+        return success();
+      })
+      .Default([loc](auto) { // Any other value kind fails.
+        return emitError(loc, "unsupported parameter attribute type");
+      });
+}
+
 FailureOr<llvm::AttrBuilder>
 ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx,
                                          DictionaryAttr paramAttrs) {
   llvm::AttrBuilder attrBuilder(llvmModule->getContext());
   auto attrNameToKindMapping = getAttrNameToKindMapping();
+  Location loc = func.getLoc();
 
   for (auto namedAttr : paramAttrs) {
     auto it = attrNameToKindMapping.find(namedAttr.getName());
     if (it != attrNameToKindMapping.end()) {
       llvm::Attribute::AttrKind llvmKind = it->second;
-
-      llvm::TypeSwitch<Attribute>(namedAttr.getValue())
-          .Case<TypeAttr>([&](auto typeAttr) {
-            attrBuilder.addTypeAttr(llvmKind, convertType(typeAttr.getValue()));
-          })
-          .Case<IntegerAttr>([&](auto intAttr) {
-            attrBuilder.addRawIntAttr(llvmKind, intAttr.getInt());
-          })
-          .Case<UnitAttr>([&](auto) { attrBuilder.addAttribute(llvmKind); })
-          .Case<LLVM::ConstantRangeAttr>([&](auto rangeAttr) {
-            attrBuilder.addConstantRangeAttr(
-                llvmKind, llvm::ConstantRange(rangeAttr.getLower(),
-                                              rangeAttr.getUpper()));
-          });
+      if (failed(convertParameterAttr(attrBuilder, llvmKind, namedAttr, *this,
+                                      loc)))
+        return failure();
     } else if (namedAttr.getNameDialect()) {
       if (failed(iface.convertParameterAttr(func, argIdx, namedAttr, *this)))
         return failure();
@@ -1650,6 +1670,26 @@ ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx,
   return attrBuilder;
 }
 
+FailureOr<llvm::AttrBuilder> // Call/invoke overload.
+ModuleTranslation::convertParameterAttrs(CallOpInterface callOp,
+                                         DictionaryAttr paramAttrs) {
+  llvm::AttrBuilder attrBuilder(llvmModule->getContext());
+  Location loc = callOp.getLoc(); // Reported on failure.
+  auto attrNameToKindMapping = getAttrNameToKindMapping();
+
+  for (auto namedAttr : paramAttrs) {
+    auto it = attrNameToKindMapping.find(namedAttr.getName());
+    if (it != attrNameToKindMapping.end()) { // Unmapped names are dropped.
+      llvm::Attribute::AttrKind llvmKind = it->second;
+      if (failed(convertParameterAttr(attrBuilder, llvmKind, namedAttr, *this,
+                                      loc)))
+        return failure();
+    }
+  }
+
+  return attrBuilder;
+}
+
 LogicalResult ModuleTranslation::convertFunctionSignatures() {
   // Declare all functions first because there may be function calls that form a
   // call graph with cycles, or global initializers that reference functions.

diff  --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 08bc5a2d65a8aab..0415ab00bdb0500 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -235,6 +235,7 @@ func.func @call_missing_ptr_type(%callee : !llvm.func<i8 (i8)>, %arg : i8) {
 func.func private @standard_func_callee()
 
 func.func @call_missing_ptr_type(%arg : i8) {
+  // expected-error at +2 {{expected '('}}
   // expected-error at +1 {{expected direct call to have 1 trailing type}}
   llvm.call @standard_func_callee(%arg) : !llvm.ptr, (i8) -> (i8)
   llvm.return
@@ -251,6 +252,7 @@ func.func @call_non_pointer_type(%callee : !llvm.func<i8 (i8)>, %arg : i8) {
 // -----
 
 func.func @call_non_function_type(%callee : !llvm.ptr, %arg : i8) {
+  // expected-error at +2 {{expected '('}}
   // expected-error at +1 {{expected trailing function type}}
   llvm.call %callee(%arg) : !llvm.ptr, !llvm.func<i8 (i8)>
   llvm.return

diff  --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index 88660ce598f3c22..09a0cd57e2675d8 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -941,3 +941,48 @@ llvm.func @test_assume_intr_with_opbundles(%arg0 : !llvm.ptr) {
   llvm.intr.assume %0 ["tag1"(%1, %2 : i32, i32), "tag2"(%3 : i32)] : i1
   llvm.return
 }
+
+llvm.func @somefunc(i32, !llvm.ptr) // Shared callee for the tests below.
+
+// CHECK-LABEL: llvm.func @test_call_arg_attrs_direct(
+// CHECK-SAME:    %[[VAL_0:.*]]: i32,
+// CHECK-SAME:    %[[VAL_1:.*]]: !llvm.ptr)
+llvm.func @test_call_arg_attrs_direct(%arg0: i32, %arg1: !llvm.ptr) {
+  // CHECK: llvm.call @somefunc(%[[VAL_0]], %[[VAL_1]]) : (i32, !llvm.ptr {llvm.byval = i64}) -> ()
+  llvm.call @somefunc(%arg0, %arg1) : (i32, !llvm.ptr {llvm.byval = i64}) -> ()
+  llvm.return
+}
+
+// CHECK-LABEL: llvm.func @test_call_arg_attrs_indirect(
+// CHECK-SAME:    %[[VAL_0:.*]]: i16,
+// CHECK-SAME:    %[[VAL_1:.*]]: !llvm.ptr
+llvm.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: !llvm.ptr) -> i16 {
+  // CHECK: llvm.call tail %[[VAL_1]](%[[VAL_0]]) : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
+  %0 = llvm.call tail %arg1(%arg0) : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
+  llvm.return %0 : i16
+}
+
+// CHECK-LABEL:   llvm.func @test_invoke_arg_attrs(
+// CHECK-SAME:      %[[VAL_0:.*]]: i16) attributes {personality = @__gxx_personality_v0} {
+llvm.func @test_invoke_arg_attrs(%arg0: i16) attributes { personality = @__gxx_personality_v0 } {
+  // CHECK:  llvm.invoke @somefunc(%[[VAL_0]]) to ^bb2 unwind ^bb1 : (i16 {llvm.noundef, llvm.signext}) -> ()
+  llvm.invoke @somefunc(%arg0) to ^bb2 unwind ^bb1 : (i16 {llvm.noundef, llvm.signext}) -> () // NOTE(review): @somefunc is declared (i32, !llvm.ptr); presumably the invoke callee type is not checked here.
+^bb1:
+  %1 = llvm.landingpad cleanup : !llvm.struct<(ptr, i32)>
+  llvm.return
+^bb2:
+  llvm.return
+}
+
+// CHECK-LABEL:   llvm.func @test_invoke_arg_attrs_indirect(
+// CHECK-SAME:      %[[VAL_0:.*]]: i16,
+// CHECK-SAME:      %[[VAL_1:.*]]: !llvm.ptr) -> i16 attributes {personality = @__gxx_personality_v0} {
+llvm.func @test_invoke_arg_attrs_indirect(%arg0: i16, %arg1: !llvm.ptr) -> i16 attributes { personality = @__gxx_personality_v0 } {
+  // CHECK: llvm.invoke %[[VAL_1]](%[[VAL_0]]) to ^bb2 unwind ^bb1 : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
+  %0 = llvm.invoke %arg1(%arg0) to ^bb2 unwind ^bb1 : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
+^bb1:
+  %1 = llvm.landingpad cleanup : !llvm.struct<(ptr, i32)>
+  llvm.return %0 : i16 // NOTE(review): invoke result used on the unwind path; parses here, but LLVM IR would reject it.
+^bb2:
+  llvm.return %0 : i16
+}

diff  --git a/mlir/test/Target/LLVMIR/Import/call-argument-attributes.ll b/mlir/test/Target/LLVMIR/Import/call-argument-attributes.ll
new file mode 100644
index 000000000000000..fa39c79bf085953
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/call-argument-attributes.ll
@@ -0,0 +1,22 @@
+; RUN: mlir-translate -import-llvm %s | FileCheck %s
+
+; CHECK-LABEL: llvm.func @somefunc(i32, !llvm.ptr)
+declare void @somefunc(i32, ptr) ; Declared without attributes; the call sites supply them.
+
+; CHECK-LABEL: llvm.func @test_call_arg_attrs_direct(
+; CHECK-SAME:    %[[VAL_0:.*]]: i32,
+; CHECK-SAME:    %[[VAL_1:.*]]: !llvm.ptr)
+define void @test_call_arg_attrs_direct(i32 %0, ptr %1) {
+  ; CHECK: llvm.call @somefunc(%[[VAL_0]], %[[VAL_1]]) : (i32, !llvm.ptr {llvm.byval = i64}) -> ()
+  call void @somefunc(i32 %0, ptr byval(i64) %1)
+  ret void
+}
+
+; CHECK-LABEL: llvm.func @test_call_arg_attrs_indirect(
+; CHECK-SAME:    %[[VAL_0:.*]]: i16,
+; CHECK-SAME:    %[[VAL_1:.*]]: !llvm.ptr
+define i16 @test_call_arg_attrs_indirect(i16 %0, ptr %1) {
+  ; CHECK: llvm.call tail %[[VAL_1]](%[[VAL_0]]) : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
+  %3 = tail call signext i16 %1(i16 noundef signext %0) ; Indirect call with argument and result attributes.
+  ret i16 %3
+}

diff  --git a/mlir/test/Target/LLVMIR/Import/invoke-argument-attributes.ll b/mlir/test/Target/LLVMIR/Import/invoke-argument-attributes.ll
new file mode 100644
index 000000000000000..42489832fd18470
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/invoke-argument-attributes.ll
@@ -0,0 +1,26 @@
+; RUN: mlir-translate -import-llvm %s | FileCheck %s
+
+; CHECK-LABEL:   llvm.func @test(
+; CHECK-SAME:      %[[VAL_0:.*]]: i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext}) attributes {personality = @__gxx_personality_v0} {
+define signext i16 @test(i16 noundef signext %0) personality ptr @__gxx_personality_v0 {
+  ; CHECK:           %[[VAL_3:.*]] = llvm.invoke @somefunc(%[[VAL_0]]) to ^bb2 unwind ^bb1 : (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
+  %2 = invoke signext i16 @somefunc(i16 noundef signext %0) ; Call-site attributes to import.
+          to label %7 unwind label %3
+
+3:                                                ; preds = %1
+  %4 = landingpad { ptr, i32 }
+          catch ptr null
+  %5 = extractvalue { ptr, i32 } %4, 0
+  %6 = tail call ptr @__cxa_begin_catch(ptr %5) #2 ; NOTE(review): attribute group #2 is not defined in this file.
+  tail call void @__cxa_end_catch()
+  br label %7
+
+7:                                                ; preds = %1, %3
+  %8 = phi i16 [ 0, %3 ], [ %2, %1 ]
+  ret i16 %8
+}
+
+declare noundef signext i16 @somefunc(i16 noundef signext) ; Result noundef here is absent from the invoke site.
+declare i32 @__gxx_personality_v0(...)
+declare ptr @__cxa_begin_catch(ptr)
+declare void @__cxa_end_catch()

diff  --git a/mlir/test/Target/LLVMIR/call-argument-attributes.mlir b/mlir/test/Target/LLVMIR/call-argument-attributes.mlir
new file mode 100644
index 000000000000000..b3d286dcda50456
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/call-argument-attributes.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+llvm.func @somefunc(i32, !llvm.ptr) // Declared without attributes; the call sites supply them.
+
+// CHECK-LABEL: define void @test_call_arg_attrs_direct
+llvm.func @test_call_arg_attrs_direct(%arg0: i32, %arg1: !llvm.ptr) {
+  // CHECK: call void @somefunc(i32 %{{.*}}, ptr byval(i64) %{{.*}})
+  llvm.call @somefunc(%arg0, %arg1) : (i32, !llvm.ptr {llvm.byval = i64}) -> ()
+  llvm.return
+}
+
+// CHECK-LABEL: define i16 @test_call_arg_attrs_indirect
+llvm.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: !llvm.ptr) -> i16 {
+  // CHECK: tail call signext i16 %{{.*}}(i16 noundef signext %{{.*}})
+  %0 = llvm.call tail %arg1(%arg0) : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
+  llvm.return %0 : i16
+}

diff  --git a/mlir/test/Target/LLVMIR/invoke-argument-attributes.mlir b/mlir/test/Target/LLVMIR/invoke-argument-attributes.mlir
new file mode 100644
index 000000000000000..ea8ed1d416435d3
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/invoke-argument-attributes.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: @test
+llvm.func @test(%arg0: i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext}) attributes {personality = @__gxx_personality_v0} {
+  %0 = llvm.mlir.zero : !llvm.ptr
+  %1 = llvm.mlir.constant(0 : i16) : i16
+  // CHECK:      invoke signext i16 @somefunc(i16 noundef signext %{{.*}})
+  // CHECK-NEXT:   to label %{{.*}} unwind label %{{.*}}
+  %2 = llvm.invoke @somefunc(%arg0) to ^bb2 unwind ^bb1 : (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
+^bb1:  // pred: ^bb0
+  %3 = llvm.landingpad (catch %0 : !llvm.ptr) : !llvm.struct<(ptr, i32)>
+  %4 = llvm.extractvalue %3[0] : !llvm.struct<(ptr, i32)>
+  %5 = llvm.call tail @__cxa_begin_catch(%4) : (!llvm.ptr) -> !llvm.ptr
+  llvm.call tail @__cxa_end_catch() : () -> ()
+  llvm.br ^bb3(%1 : i16)
+^bb2:  // pred: ^bb0
+  llvm.br ^bb3(%2 : i16)
+^bb3(%6: i16):  // 2 preds: ^bb1, ^bb2
+  llvm.return %6 : i16
+}
+
+llvm.func @somefunc(i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.noundef, llvm.signext}) // Call-site attrs come from the invoke; its result has no noundef.
+llvm.func @__gxx_personality_v0(...) -> i32
+llvm.func @__cxa_begin_catch(!llvm.ptr) -> !llvm.ptr
+llvm.func @__cxa_end_catch()


        


More information about the Mlir-commits mailing list