[Mlir-commits] [mlir] [mlir][LLVM] Add operand bundle support (PR #108933)

Sirui Mu llvmlistbot at llvm.org
Mon Sep 16 23:33:39 PDT 2024


https://github.com/Lancern created https://github.com/llvm/llvm-project/pull/108933

This PR adds support for LLVM [operand bundles](https://llvm.org/docs/LangRef.html#operand-bundles) to the MLIR LLVM dialect. It affects the three operations that make function calls: `llvm.call`, `llvm.invoke`, and `llvm.call_intrinsic`.

This PR adds two new parameters to each of these three operations. The first is a variadic operand, `bundle_operands`, which holds the SSA values used by the operand bundles. The second is an optional `#llvm.opbundles` attribute, which is a list of `#llvm.opbundle` attributes. Each `#llvm.opbundle` attribute describes a single operand bundle: it contains a string tag and a list of integer indices into `bundle_operands` that select the SSA values belonging to that bundle. The verifier requires every index to be in range and every bundle operand to be referenced by at least one bundle. In the custom assembly format, the bundle operands are written in a `bundlearg(...)` clause, and their types are given as a trailing `tuple<...>` type.

>From bbb2bc21bde88059172eea5a51e1b17ec15e2f85 Mon Sep 17 00:00:00 2001
From: Sirui Mu <msrlancern at gmail.com>
Date: Tue, 17 Sep 2024 14:21:01 +0800
Subject: [PATCH] [mlir][LLVM] Add operand bundle support

---
 .../mlir/Dialect/LLVMIR/LLVMAttrDefs.td       |  29 +++
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td   |  25 ++-
 mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp |   4 +
 .../Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp    |  12 +-
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp    | 186 +++++++++++++++---
 .../LLVMIR/LLVMToLLVMIRTranslation.cpp        |  62 ++++--
 mlir/test/Dialect/LLVMIR/invalid.mlir         |  52 ++++-
 mlir/test/Target/LLVMIR/llvmir.mlir           |  47 +++++
 8 files changed, 363 insertions(+), 54 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index 2da45eba77655b..67d43e4d2e0657 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -1210,4 +1210,33 @@ def WorkgroupAttributionAttr
   let assemblyFormat = "`<` $num_elements `,` $element_type `>`";
 }
 
+//===----------------------------------------------------------------------===//
+// OperandBundleAttr
+//===----------------------------------------------------------------------===//
+
+def LLVM_OperandBundleAttr : LLVM_Attr<"OperandBundle", "opbundle"> {
+  let summary = "Operand bundle information";
+  let description = [{
+    Provide information about a single operand bundle. Each operand bundle has a
+    string tag together with various number of SSA value uses. The SSA values
+    are specified through indices into the operation's operand bundle operands.
+  }];
+
+  let parameters = (ins "StringAttr":$tag,
+                        OptionalArrayRefParameter<"uint32_t">:$argIndices);
+  let assemblyFormat = [{
+    `<` $tag (`,` $argIndices^)? `>`
+  }];
+}
+
+def LLVM_OperandBundlesAttr : LLVM_Attr<"OperandBundles", "opbundles"> {
+  let summary = "A list of operand bundle attributes";
+  let description = "A list of operand bundle attributes";
+
+  let parameters = (ins ArrayRefParameter<"OperandBundleAttr">:$bundles);
+  let assemblyFormat = [{
+    `<` $bundles `>`
+  }];
+}
+
 #endif // LLVMIR_ATTRDEFS
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index d956d7f27f784d..b10a7bbad8eea2 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -550,8 +550,10 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [
                    Variadic<LLVM_Type>:$callee_operands,
                    Variadic<LLVM_Type>:$normalDestOperands,
                    Variadic<LLVM_Type>:$unwindDestOperands,
+                   Variadic<LLVM_Type>:$bundle_operands,
                    OptionalAttr<DenseI32ArrayAttr>:$branch_weights,
-                   DefaultValuedAttr<CConv, "CConv::C">:$CConv);
+                   DefaultValuedAttr<CConv, "CConv::C">:$CConv,
+                   OptionalAttr<LLVM_OperandBundlesAttr>:$op_bundles);
   let results = (outs Optional<LLVM_Type>:$result);
   let successors = (successor AnySuccessor:$normalDest,
                               AnySuccessor:$unwindDest);
@@ -587,7 +589,8 @@ def LLVM_LandingpadOp : LLVM_Op<"landingpad"> {
 //===----------------------------------------------------------------------===//
 
 def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
-                    [DeclareOpInterfaceMethods<FastmathFlagsInterface>,
+                    [AttrSizedOperandSegments,
+                     DeclareOpInterfaceMethods<FastmathFlagsInterface>,
                      DeclareOpInterfaceMethods<CallOpInterface>,
                      DeclareOpInterfaceMethods<SymbolUserOpInterface>,
                      DeclareOpInterfaceMethods<BranchWeightOpInterface>]> {
@@ -633,6 +636,7 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
   dag args = (ins OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$var_callee_type,
                   OptionalAttr<FlatSymbolRefAttr>:$callee,
                   Variadic<LLVM_Type>:$callee_operands,
+                  Variadic<LLVM_Type>:$bundle_operands,
                   DefaultValuedAttr<LLVM_FastmathFlagsAttr,
                                    "{}">:$fastmathFlags,
                   OptionalAttr<DenseI32ArrayAttr>:$branch_weights,
@@ -641,7 +645,8 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
                   OptionalAttr<LLVM_MemoryEffectsAttr>:$memory_effects,
                   OptionalAttr<UnitAttr>:$convergent,
                   OptionalAttr<UnitAttr>:$no_unwind,
-                  OptionalAttr<UnitAttr>:$will_return
+                  OptionalAttr<UnitAttr>:$will_return,
+                  OptionalAttr<LLVM_OperandBundlesAttr>:$op_bundles
                   );
   // Append the aliasing related attributes defined in LLVM_MemAccessOpBase.
   let arguments = !con(args, aliasAttrs);
@@ -662,6 +667,7 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
     OpBuilder<(ins "LLVMFunctionType":$calleeType, "StringRef":$callee,
                    CArg<"ValueRange", "{}">:$args)>
   ];
+  let hasVerifier = 1;
   let hasCustomAssemblyFormat = 1;
   let extraClassDeclaration = [{
     /// Returns the callee function type.
@@ -1875,21 +1881,28 @@ def LLVM_InlineAsmOp : LLVM_Op<"inline_asm", [DeclareOpInterfaceMethods<MemoryEf
 
 def LLVM_CallIntrinsicOp
     : LLVM_Op<"call_intrinsic",
-              [DeclareOpInterfaceMethods<FastmathFlagsInterface>]> {
+              [AttrSizedOperandSegments,
+               DeclareOpInterfaceMethods<FastmathFlagsInterface>]> {
   let summary = "Call to an LLVM intrinsic function.";
   let description = [{
     Call the specified llvm intrinsic. If the intrinsic is overloaded, use
     the MLIR function type of this op to determine which intrinsic to call.
   }];
   let arguments = (ins StrAttr:$intrin, Variadic<LLVM_Type>:$args,
+                       Variadic<LLVM_Type>:$bundle_operands,
                        DefaultValuedAttr<LLVM_FastmathFlagsAttr,
-                                         "{}">:$fastmathFlags);
+                                         "{}">:$fastmathFlags,
+                       OptionalAttr<LLVM_OperandBundlesAttr>:$op_bundles);
   let results = (outs Optional<LLVM_Type>:$results);
   let llvmBuilder = [{
     return convertCallLLVMIntrinsicOp(op, builder, moduleTranslation);
   }];
   let assemblyFormat = [{
-    $intrin `(` $args `)` `:` functional-type($args, $results) attr-dict
+    $intrin `(` $args `)`
+    ( `bundlearg` `(` $bundle_operands^ `)` )?
+    `:` functional-type($args, $results)
+    ( `,` `tuple` `<` type($bundle_operands)^ `>` )?
+    attr-dict
   }];
 
   let hasVerifier = 1;
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 4c2e8682285c52..85ec6031dfee70 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -544,6 +544,10 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
         callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
         promoted, callOp->getAttrs());
 
+    newOp->setAttr(newOp.getOperandSegmentSizesAttrName(),
+                   rewriter.getDenseI32ArrayAttr(
+                       {static_cast<int32_t>(promoted.size()), 0}));
+
     SmallVector<Value, 4> results;
     if (numResults < 2) {
       // If < 2 results, packing did not do anything and we can just return.
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index ca786316324198..6bb8203da1898a 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -837,8 +837,12 @@ class FunctionCallPattern
   matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     if (callOp.getNumResults() == 0) {
-      rewriter.replaceOpWithNewOp<LLVM::CallOp>(
+      auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
           callOp, std::nullopt, adaptor.getOperands(), callOp->getAttrs());
+      newOp->setAttr(
+          newOp.getOperandSegmentSizesAttrName(),
+          rewriter.getDenseI32ArrayAttr(
+              {static_cast<int32_t>(adaptor.getOperands().size()), 0}));
       return success();
     }
 
@@ -846,8 +850,12 @@ class FunctionCallPattern
     auto dstType = typeConverter.convertType(callOp.getType(0));
     if (!dstType)
       return rewriter.notifyMatchFailure(callOp, "type conversion failed");
-    rewriter.replaceOpWithNewOp<LLVM::CallOp>(
+    auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
         callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
+    newOp->setAttr(
+        newOp.getOperandSegmentSizesAttrName(),
+        rewriter.getDenseI32ArrayAttr(
+            {static_cast<int32_t>(adaptor.getOperands().size()), 0}));
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 205d7494d4378c..bb7df718ccad61 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -949,12 +949,14 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
                    FlatSymbolRefAttr callee, ValueRange args) {
   assert(callee && "expected non-null callee in direct call builder");
   build(builder, state, results,
-        /*var_callee_type=*/nullptr, callee, args, /*fastmathFlags=*/nullptr,
+        /*var_callee_type=*/nullptr, callee, args, /*bundle_operands=*/{},
+        /*fastmathFlags=*/nullptr,
         /*branch_weights=*/nullptr,
         /*CConv=*/nullptr, /*TailCallKind=*/nullptr,
         /*memory_effects=*/nullptr,
         /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
-        /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
+        /*op_bundles=*/nullptr, /*access_groups=*/nullptr,
+        /*alias_scopes=*/nullptr,
         /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
 }
 
@@ -975,11 +977,12 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
                    ValueRange args) {
   build(builder, state, getCallOpResultTypes(calleeType),
         getCallOpVarCalleeType(calleeType), callee, args,
+        /*bundle_operands=*/{},
         /*fastmathFlags=*/nullptr,
         /*branch_weights=*/nullptr, /*CConv=*/nullptr,
         /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
         /*convergent=*/nullptr,
-        /*no_unwind=*/nullptr, /*will_return=*/nullptr,
+        /*no_unwind=*/nullptr, /*will_return=*/nullptr, /*op_bundles=*/nullptr,
         /*access_groups=*/nullptr,
         /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
 }
@@ -988,12 +991,12 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
                    LLVMFunctionType calleeType, ValueRange args) {
   build(builder, state, getCallOpResultTypes(calleeType),
         getCallOpVarCalleeType(calleeType),
-        /*callee=*/nullptr, args,
+        /*callee=*/nullptr, args, /*bundle_operands=*/{},
         /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
         /*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
         /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
-        /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
-        /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
+        /*op_bundles=*/nullptr, /*access_groups=*/nullptr,
+        /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
 }
 
 void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
@@ -1001,11 +1004,12 @@ void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
   auto calleeType = func.getFunctionType();
   build(builder, state, getCallOpResultTypes(calleeType),
         getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), args,
+        /*bundle_operands=*/{},
         /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
         /*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
         /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
-        /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
-        /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
+        /*op_bundles=*/nullptr, /*access_groups=*/nullptr,
+        /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
 }
 
 CallInterfaceCallable CallOp::getCallableForCallee() {
@@ -1027,7 +1031,7 @@ void CallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
 }
 
 Operation::operand_range CallOp::getArgOperands() {
-  return getOperands().drop_front(getCallee().has_value() ? 0 : 1);
+  return getCalleeOperands().drop_front(getCallee().has_value() ? 0 : 1);
 }
 
 MutableOperandRange CallOp::getArgOperandsMutable() {
@@ -1100,6 +1104,38 @@ LogicalResult verifyCallOpVarCalleeType(OpTy callOp) {
   return success();
 }
 
+template <typename OpType>
+static LogicalResult verifyOperandBundleOperands(OpType &op) {
+  ValueRange opBundleOperands = op.getBundleOperands();
+  OperandBundlesAttr opBundles = op.getOpBundlesAttr();
+
+  if (!opBundles) {
+    if (!opBundleOperands.empty())
+      return op.emitError("expected operand bundles attribute");
+    return success();
+  }
+
+  DenseSet<uint32_t> seenOperandIdx;
+  for (OperandBundleAttr bundle : opBundles.getBundles()) {
+    for (uint32_t bundleOperandIdx : bundle.getArgIndices()) {
+      if (bundleOperandIdx >= opBundleOperands.size())
+        return op.emitError("operand bundle argument index ")
+               << bundleOperandIdx << " is out of range";
+      seenOperandIdx.insert(bundleOperandIdx);
+    }
+  }
+
+  for (uint32_t idx = 0; idx < opBundleOperands.size(); ++idx) {
+    if (!seenOperandIdx.contains(idx))
+      return op.emitError("operand bundle argument at index ")
+             << idx << " is not included in any operand bundles";
+  }
+
+  return success();
+}
+
+LogicalResult CallOp::verify() { return verifyOperandBundleOperands(*this); }
+
 LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   if (failed(verifyCallOpVarCalleeType(*this)))
     return failure();
@@ -1150,15 +1186,15 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   // Verify that the operand and result types match the callee.
 
   if (!funcType.isVarArg() &&
-      funcType.getNumParams() != (getNumOperands() - isIndirect))
+      funcType.getNumParams() != (getCalleeOperands().size() - isIndirect))
     return emitOpError() << "incorrect number of operands ("
-                         << (getNumOperands() - isIndirect)
+                         << (getCalleeOperands().size() - isIndirect)
                          << ") for callee (expecting: "
                          << funcType.getNumParams() << ")";
 
-  if (funcType.getNumParams() > (getNumOperands() - isIndirect))
+  if (funcType.getNumParams() > (getCalleeOperands().size() - isIndirect))
     return emitOpError() << "incorrect number of operands ("
-                         << (getNumOperands() - isIndirect)
+                         << (getCalleeOperands().size() - isIndirect)
                          << ") for varargs callee (expecting at least: "
                          << funcType.getNumParams() << ")";
 
@@ -1208,16 +1244,24 @@ void CallOp::print(OpAsmPrinter &p) {
   else
     p << getOperand(0);
 
-  auto args = getOperands().drop_front(isDirect ? 0 : 1);
+  auto args = getCalleeOperands().drop_front(isDirect ? 0 : 1);
   p << '(' << args << ')';
 
   // Print the variadic callee type if the call is variadic.
   if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
     p << " vararg(" << *varCalleeType << ")";
 
+  // Print the operand bundles, if any.
+  if (!getBundleOperands().empty()) {
+    p << " bundlearg(";
+    p.printOperands(getBundleOperands());
+    p << ")";
+  }
+
   p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
                           {getCalleeAttrName(), getTailCallKindAttrName(),
-                           getVarCalleeTypeAttrName(), getCConvAttrName()});
+                           getVarCalleeTypeAttrName(), getCConvAttrName(),
+                           getOperandSegmentSizesAttrName()});
 
   p << " : ";
   if (!isDirect)
@@ -1225,24 +1269,53 @@ void CallOp::print(OpAsmPrinter &p) {
 
   // Reconstruct the function MLIR function type from operand and result types.
   p.printFunctionalType(args.getTypes(), getResultTypes());
+
+  if (!getBundleOperands().empty()) {
+    SmallVector<Type> opBundleArgTypes;
+    opBundleArgTypes.reserve(getBundleOperands().size());
+    for (auto opBundleArg : getBundleOperands())
+      opBundleArgTypes.push_back(opBundleArg.getType());
+
+    p << ", tuple<";
+    llvm::interleaveComma(opBundleArgTypes, p);
+    p << ">";
+  }
 }
 
 /// Parses the type of a call operation and resolves the operands if the parsing
 /// succeeds. Returns failure otherwise.
 static ParseResult parseCallTypeAndResolveOperands(
     OpAsmParser &parser, OperationState &result, bool isDirect,
-    ArrayRef<OpAsmParser::UnresolvedOperand> operands) {
+    ArrayRef<OpAsmParser::UnresolvedOperand> operands,
+    ArrayRef<OpAsmParser::UnresolvedOperand> opBundleOperands) {
   SMLoc trailingTypesLoc = parser.getCurrentLocation();
   SmallVector<Type> types;
   if (parser.parseColonTypeList(types))
     return failure();
 
-  if (isDirect && types.size() != 1)
-    return parser.emitError(trailingTypesLoc,
-                            "expected direct call to have 1 trailing type");
-  if (!isDirect && types.size() != 2)
+  if (isDirect && opBundleOperands.empty() && types.size() != 1)
+    return parser.emitError(
+        trailingTypesLoc,
+        "expected direct call without operand bundles to have 1 trailing type");
+  if (isDirect && !opBundleOperands.empty() && types.size() != 2)
+    return parser.emitError(
+        trailingTypesLoc,
+        "expected direct call with operand bundles to have 2 trailing types");
+  if (!isDirect && opBundleOperands.empty() && types.size() != 2)
     return parser.emitError(trailingTypesLoc,
-                            "expected indirect call to have 2 trailing types");
+                            "expected indirect call without operand bundles to "
+                            "have 2 trailing types");
+  if (!isDirect && !opBundleOperands.empty() && types.size() != 3)
+    return parser.emitError(
+        trailingTypesLoc,
+        "expected indirect call with operand bundles to have 3 trailing types");
+
+  TupleType opBundleTypes;
+  if (!opBundleOperands.empty()) {
+    opBundleTypes = llvm::dyn_cast<TupleType>(types.pop_back_val());
+    if (!opBundleTypes)
+      return parser.emitError(trailingTypesLoc, "expected trailing tuple type");
+  }
 
   auto funcType = llvm::dyn_cast<FunctionType>(types.pop_back_val());
   if (!funcType)
@@ -1267,6 +1340,12 @@ static ParseResult parseCallTypeAndResolveOperands(
   if (funcType.getNumResults() != 0)
     result.addTypes(funcType.getResults());
 
+  if (!opBundleOperands.empty()) {
+    if (parser.resolveOperands(opBundleOperands, opBundleTypes.getTypes(),
+                               parser.getNameLoc(), result.operands))
+      return failure();
+  }
+
   return success();
 }
 
@@ -1288,7 +1367,9 @@ static ParseResult parseOptionalCallFuncPtr(
 // <operation> ::= `llvm.call` (cconv)? (tailcallkind)? (function-id | ssa-use)
 //                             `(` ssa-use-list `)`
 //                             ( `vararg(` var-callee-type `)` )?
+//                             ( `bundlearg(` ssa-use-list `)` )?
 //                             attribute-dict? `:` (type `,`)? function-type
+//                             (`,` `tuple` `<` type-list `>`)?
 ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
   SymbolRefAttr funcAttr;
   TypeAttr varCalleeType;
@@ -1333,11 +1414,25 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
       return failure();
   }
 
+  SmallVector<OpAsmParser::UnresolvedOperand> opBundleOperands;
+  bool hasOpBundles = parser.parseOptionalKeyword("bundlearg").succeeded();
+  if (hasOpBundles &&
+      parser.parseOperandList(opBundleOperands, OpAsmParser::Delimiter::Paren))
+    return failure();
+
   if (parser.parseOptionalAttrDict(result.attributes))
     return failure();
 
   // Parse the trailing type list and resolve the operands.
-  return parseCallTypeAndResolveOperands(parser, result, isDirect, operands);
+  if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands,
+                                      opBundleOperands))
+    return failure();
+
+  result.addAttribute(CallOp::getOperandSegmentSizeAttr(),
+                      parser.getBuilder().getDenseI32ArrayAttr(
+                          {static_cast<int32_t>(operands.size()),
+                           static_cast<int32_t>(opBundleOperands.size())}));
+  return success();
 }
 
 LLVMFunctionType CallOp::getCalleeFunctionType() {
@@ -1356,7 +1451,7 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
   auto calleeType = func.getFunctionType();
   build(builder, state, getCallOpResultTypes(calleeType),
         getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), ops,
-        normalOps, unwindOps, nullptr, nullptr, normal, unwind);
+        normalOps, unwindOps, {}, nullptr, nullptr, nullptr, normal, unwind);
 }
 
 void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
@@ -1364,8 +1459,8 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
                      ValueRange normalOps, Block *unwind,
                      ValueRange unwindOps) {
   build(builder, state, tys,
-        /*var_callee_type=*/nullptr, callee, ops, normalOps, unwindOps, nullptr,
-        nullptr, normal, unwind);
+        /*var_callee_type=*/nullptr, callee, ops, normalOps, unwindOps, {},
+        nullptr, nullptr, nullptr, normal, unwind);
 }
 
 void InvokeOp::build(OpBuilder &builder, OperationState &state,
@@ -1374,7 +1469,7 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state,
                      Block *unwind, ValueRange unwindOps) {
   build(builder, state, getCallOpResultTypes(calleeType),
         getCallOpVarCalleeType(calleeType), callee, ops, normalOps, unwindOps,
-        nullptr, nullptr, normal, unwind);
+        {}, nullptr, nullptr, nullptr, normal, unwind);
 }
 
 SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) {
@@ -1402,7 +1497,7 @@ void InvokeOp::setCalleeFromCallable(CallInterfaceCallable callee) {
 }
 
 Operation::operand_range InvokeOp::getArgOperands() {
-  return getOperands().drop_front(getCallee().has_value() ? 0 : 1);
+  return getCalleeOperands().drop_front(getCallee().has_value() ? 0 : 1);
 }
 
 MutableOperandRange InvokeOp::getArgOperandsMutable() {
@@ -1423,6 +1518,9 @@ LogicalResult InvokeOp::verify() {
     return emitError("first operation in unwind destination should be a "
                      "llvm.landingpad operation");
 
+  if (failed(verifyOperandBundleOperands(*this)))
+    return failure();
+
   return success();
 }
 
@@ -1452,6 +1550,13 @@ void InvokeOp::print(OpAsmPrinter &p) {
   if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
     p << " vararg(" << *varCalleeType << ")";
 
+  // Print the operand bundles, if any.
+  if (!getBundleOperands().empty()) {
+    p << " bundlearg(";
+    p.printOperands(getBundleOperands());
+    p << ")";
+  }
+
   p.printOptionalAttrDict((*this)->getAttrs(),
                           {getCalleeAttrName(), getOperandSegmentSizeAttr(),
                            getCConvAttrName(), getVarCalleeTypeAttrName()});
@@ -1461,6 +1566,17 @@ void InvokeOp::print(OpAsmPrinter &p) {
     p << getOperand(0).getType() << ", ";
-  p.printFunctionalType(llvm::drop_begin(getOperandTypes(), isDirect ? 0 : 1),
-                        getResultTypes());
+  // Build the function type from the callee arguments only. Successor and
+  // operand bundle operands are not callee arguments, and the parser resolves
+  // only the callee operands against this type.
+  auto calleeArgs = getCalleeOperands().drop_front(isDirect ? 0 : 1);
+  p.printFunctionalType(calleeArgs.getTypes(), getResultTypes());
+
+  // Print the types of the operand bundle operands, if any, as a trailing
+  // tuple type.
+  if (!getBundleOperands().empty()) {
+    p << ", tuple<";
+    llvm::interleaveComma(getBundleOperands().getTypes(), p);
+    p << ">";
+  }
 }
 
 // <operation> ::= `llvm.invoke` (cconv)? (function-id | ssa-use)
@@ -1468,7 +1584,9 @@ void InvokeOp::print(OpAsmPrinter &p) {
 //                  `to` bb-id (`[` ssa-use-and-type-list `]`)?
 //                  `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
 //                  ( `vararg(` var-callee-type `)` )?
+//                  ( `bundlearg(` ssa-use-list `)` )?
 //                  attribute-dict? `:` (type `,`)? function-type
+//                  (`,` `tuple` `<` type-list `>`)?
 ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
   SmallVector<OpAsmParser::UnresolvedOperand, 8> operands;
   SymbolRefAttr funcAttr;
@@ -1513,11 +1631,18 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
       return failure();
   }
 
+  SmallVector<OpAsmParser::UnresolvedOperand> opBundleOperands;
+  bool hasOpBundles = parser.parseOptionalKeyword("bundlearg").succeeded();
+  if (hasOpBundles &&
+      parser.parseOperandList(opBundleOperands, OpAsmParser::Delimiter::Paren))
+    return failure();
+
   if (parser.parseOptionalAttrDict(result.attributes))
     return failure();
 
   // Parse the trailing type list and resolve the function operands.
-  if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands))
+  if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands,
+                                      opBundleOperands))
     return failure();
 
   result.addSuccessors({normalDest, unwindDest});
@@ -1528,7 +1653,8 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
                       builder.getDenseI32ArrayAttr(
                           {static_cast<int32_t>(operands.size()),
                            static_cast<int32_t>(normalOperands.size()),
-                           static_cast<int32_t>(unwindOperands.size())}));
+                           static_cast<int32_t>(unwindOperands.size()),
+                           static_cast<int32_t>(opBundleOperands.size())}));
   return success();
 }
 
@@ -3108,6 +3234,8 @@ OpFoldResult LLVM::OrOp::fold(FoldAdaptor adaptor) {
 LogicalResult CallIntrinsicOp::verify() {
   if (!getIntrin().starts_with("llvm."))
     return emitOpError() << "intrinsic name must start with 'llvm.'";
+  if (failed(verifyOperandBundleOperands(*this)))
+    return failure();
   return success();
 }
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index d948ff5eaf1769..ab81f24a9b74f6 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -102,6 +102,36 @@ getOverloadedDeclaration(CallIntrinsicOp op, llvm::Intrinsic::ID id,
   return llvm::Intrinsic::getDeclaration(module, id, overloadedArgTysRef);
 }
 
+static llvm::OperandBundleDef
+convertOperandBundle(ValueRange bundleOperands, OperandBundleAttr bundleAttr,
+                     LLVM::ModuleTranslation &moduleTranslation) {
+  std::vector<llvm::Value *> operands;
+  operands.reserve(bundleAttr.getArgIndices().size());
+  for (uint32_t idx : bundleAttr.getArgIndices())
+    operands.push_back(moduleTranslation.lookupValue(bundleOperands[idx]));
+  return llvm::OperandBundleDef(bundleAttr.getTag().str(), std::move(operands));
+}
+
+static SmallVector<llvm::OperandBundleDef>
+convertOperandBundles(ValueRange bundleOperands, OperandBundlesAttr bundleAttrs,
+                      LLVM::ModuleTranslation &moduleTranslation) {
+  SmallVector<llvm::OperandBundleDef> bundles;
+  bundles.reserve(bundleAttrs.getBundles().size());
+  for (OperandBundleAttr bundle : bundleAttrs.getBundles())
+    bundles.push_back(
+        convertOperandBundle(bundleOperands, bundle, moduleTranslation));
+  return bundles;
+}
+
+static SmallVector<llvm::OperandBundleDef>
+convertOperandBundles(ValueRange bundleOperands,
+                      std::optional<OperandBundlesAttr> bundleAttrs,
+                      LLVM::ModuleTranslation &moduleTranslation) {
+  if (!bundleAttrs.has_value())
+    return {};
+  return convertOperandBundles(bundleOperands, *bundleAttrs, moduleTranslation);
+}
+
 /// Builder for LLVM_CallIntrinsicOp
 static LogicalResult
 convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder,
@@ -138,15 +168,15 @@ convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder,
   // Check the argument types of the call. If the function is variadic, check
   // the subrange of required arguments.
   if (!fn->getFunctionType()->isVarArg() &&
-      op.getNumOperands() != fn->arg_size()) {
+      op.getArgs().size() != fn->arg_size()) {
     return mlir::emitError(op.getLoc(), "intrinsic call has ")
-           << op.getNumOperands() << " operands but " << op.getIntrinAttr()
+           << op.getArgs().size() << " operands but " << op.getIntrinAttr()
            << " expects " << fn->arg_size();
   }
   if (fn->getFunctionType()->isVarArg() &&
-      op.getNumOperands() < fn->arg_size()) {
+      op.getArgs().size() < fn->arg_size()) {
     return mlir::emitError(op.getLoc(), "intrinsic call has ")
-           << op.getNumOperands() << " operands but variadic "
+           << op.getArgs().size() << " operands but variadic "
            << op.getIntrinAttr() << " expects at least " << fn->arg_size();
   }
   // Check the arguments up to the number the function requires.
@@ -164,8 +194,10 @@ convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder,
   FastmathFlagsInterface itf = op;
   builder.setFastMathFlags(getFastmathFlags(itf));
 
-  auto *inst =
-      builder.CreateCall(fn, moduleTranslation.lookupValues(op.getOperands()));
+  auto *inst = builder.CreateCall(
+      fn, moduleTranslation.lookupValues(op.getArgs()),
+      convertOperandBundles(op.getBundleOperands(), op.getOpBundles(),
+                            moduleTranslation));
   if (op.getNumResults() == 1)
     moduleTranslation.mapValue(op->getResults().front()) = inst;
   return success();
@@ -205,17 +237,20 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
   // itself.  Otherwise, this is an indirect call and the callee is the first
   // operand, look it up as a normal value.
   if (auto callOp = dyn_cast<LLVM::CallOp>(opInst)) {
-    auto operands = moduleTranslation.lookupValues(callOp.getOperands());
+    auto operands = moduleTranslation.lookupValues(callOp.getCalleeOperands());
+    SmallVector<llvm::OperandBundleDef> opBundles = convertOperandBundles(
+        callOp.getBundleOperands(), callOp.getOpBundles(), moduleTranslation);
     ArrayRef<llvm::Value *> operandsRef(operands);
     llvm::CallInst *call;
     if (auto attr = callOp.getCalleeAttr()) {
-      call = builder.CreateCall(
-          moduleTranslation.lookupFunction(attr.getValue()), operandsRef);
+      call =
+          builder.CreateCall(moduleTranslation.lookupFunction(attr.getValue()),
+                             operandsRef, opBundles);
     } else {
       llvm::FunctionType *calleeType = llvm::cast<llvm::FunctionType>(
           moduleTranslation.convertType(callOp.getCalleeFunctionType()));
       call = builder.CreateCall(calleeType, operandsRef.front(),
-                                operandsRef.drop_front());
+                                operandsRef.drop_front(), opBundles);
     }
     call->setCallingConv(convertCConvToLLVM(callOp.getCConv()));
     call->setTailCallKind(convertTailCallKindToLLVM(callOp.getTailCallKind()));
@@ -312,13 +347,16 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
 
   if (auto invOp = dyn_cast<LLVM::InvokeOp>(opInst)) {
     auto operands = moduleTranslation.lookupValues(invOp.getCalleeOperands());
+    SmallVector<llvm::OperandBundleDef> opBundles = convertOperandBundles(
+        invOp.getBundleOperands(), invOp.getOpBundles(), moduleTranslation);
     ArrayRef<llvm::Value *> operandsRef(operands);
     llvm::InvokeInst *result;
     if (auto attr = opInst.getAttrOfType<FlatSymbolRefAttr>("callee")) {
       result = builder.CreateInvoke(
           moduleTranslation.lookupFunction(attr.getValue()),
           moduleTranslation.lookupBlock(invOp.getSuccessor(0)),
-          moduleTranslation.lookupBlock(invOp.getSuccessor(1)), operandsRef);
+          moduleTranslation.lookupBlock(invOp.getSuccessor(1)), operandsRef,
+          opBundles);
     } else {
       llvm::FunctionType *calleeType = llvm::cast<llvm::FunctionType>(
           moduleTranslation.convertType(invOp.getCalleeFunctionType()));
@@ -326,7 +364,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
           calleeType, operandsRef.front(),
           moduleTranslation.lookupBlock(invOp.getSuccessor(0)),
           moduleTranslation.lookupBlock(invOp.getSuccessor(1)),
-          operandsRef.drop_front());
+          operandsRef.drop_front(), opBundles);
     }
     result->setCallingConv(convertCConvToLLVM(invOp.getCConv()));
     moduleTranslation.mapBranch(invOp, result);
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 6670e4b186c397..e4227151d371c7 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -225,7 +225,7 @@ func.func @invalid_call() {
 // -----
 
 func.func @call_missing_ptr_type(%callee : !llvm.func<i8 (i8)>, %arg : i8) {
-  // expected-error@+1 {{expected indirect call to have 2 trailing types}}
+  // expected-error@+1 {{expected indirect call without operand bundles to have 2 trailing types}}
   llvm.call %callee(%arg) : (i8) -> (i8)
   llvm.return
 }
@@ -235,7 +235,7 @@ func.func @call_missing_ptr_type(%callee : !llvm.func<i8 (i8)>, %arg : i8) {
 func.func private @standard_func_callee()
 
 func.func @call_missing_ptr_type(%arg : i8) {
-  // expected-error@+1 {{expected direct call to have 1 trailing type}}
+  // expected-error@+1 {{expected direct call without operand bundles to have 1 trailing type}}
   llvm.call @standard_func_callee(%arg) : !llvm.ptr, (i8) -> (i8)
   llvm.return
 }
@@ -286,7 +286,7 @@ func.func @call_non_llvm() {
 
 func.func @call_non_llvm_arg(%arg0 : tensor<*xi32>) {
   // expected-error@+1 {{'llvm.call' op operand #0 must be variadic of LLVM dialect-compatible type}}
-  "llvm.call"(%arg0) : (tensor<*xi32>) -> ()
+  "llvm.call"(%arg0) {operandSegmentSizes = array<i32: 1, 0>} : (tensor<*xi32>) -> ()
   llvm.return
 }
 
@@ -1588,7 +1588,7 @@ llvm.func @variadic(...)
 
 llvm.func @invalid_variadic_call(%arg: i32)  {
   // expected-error@+1 {{missing var_callee_type attribute for vararg call}}
-  "llvm.call"(%arg) <{callee = @variadic}> : (i32) -> ()
+  "llvm.call"(%arg) <{callee = @variadic}> {operandSegmentSizes = array<i32: 1, 0>} : (i32) -> ()
   llvm.return
 }
 
@@ -1598,7 +1598,7 @@ llvm.func @variadic(...)
 
 llvm.func @invalid_variadic_call(%arg: i32)  {
   // expected-error@+1 {{missing var_callee_type attribute for vararg call}}
-  "llvm.call"(%arg) <{callee = @variadic}> : (i32) -> ()
+  "llvm.call"(%arg) <{callee = @variadic}> {operandSegmentSizes = array<i32: 1, 0>} : (i32) -> ()
   llvm.return
 }
 
@@ -1655,3 +1655,45 @@ llvm.func @alwaysinline_noinline() attributes { always_inline, no_inline } {
 llvm.func @optnone_requires_noinline() attributes { optimize_none } {
   llvm.return
 }
+
+// -----
+llvm.func @foo()
+llvm.func @opbundle_no_types() {
+  %0 = llvm.mlir.constant(0 : i32) : i32
+  // expected-error@+1 {{expected direct call with operand bundles to have 2 trailing types}}
+  llvm.call @foo() bundlearg(%0) {op_bundles = #llvm.opbundles<#llvm.opbundle<"tag", 0>>} : () -> ()
+}
+
+// -----
+llvm.func @foo()
+llvm.func @opbundle_no_attr() {
+  %0 = llvm.mlir.constant(0 : i32) : i32
+  // expected-error@+1 {{expected operand bundles attribute}}
+  llvm.call @foo() bundlearg(%0) : () -> (), tuple<i32>
+}
+
+// -----
+llvm.func @foo()
+llvm.func @opbundle_types_mismatch() {
+  %0 = llvm.mlir.constant(0 : i32) : i32
+  %1 = llvm.mlir.constant(1 : i32) : i32
+  // expected-error@+1 {{2 operands present, but expected 1}}
+  llvm.call @foo() bundlearg(%0, %1) : () -> (), tuple<i32>
+}
+
+// -----
+llvm.func @foo()
+llvm.func @opbundle_arg_idx_out_of_range() {
+  %0 = llvm.mlir.constant(0 : i32) : i32
+  // expected-error@+1 {{operand bundle argument index 1 is out of range}}
+  llvm.call @foo() bundlearg(%0) {op_bundles = #llvm.opbundles<#llvm.opbundle<"tag", 1>>} : () -> (), tuple<i32>
+}
+
+// -----
+llvm.func @foo()
+llvm.func @opbundle_arg_not_used() {
+  %0 = llvm.mlir.constant(0 : i32) : i32
+  %1 = llvm.mlir.constant(1 : i32) : i32
+  // expected-error@+1 {{operand bundle argument at index 0 is not included in any operand bundles}}
+  llvm.call @foo() bundlearg(%0, %1) {op_bundles = #llvm.opbundles<#llvm.opbundle<"tag", 1>>} : () -> (), tuple<i32, i32>
+}
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 7eca1a40373054..6bf9a00f3d3739 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -2621,3 +2621,50 @@ llvm.func @reqd_work_group_size() attributes {reqd_work_group_size = array<i32:
 llvm.func @intel_reqd_sub_group_size() attributes {intel_reqd_sub_group_size = 32 : i32}
 
 // CHECK: ![[#INTEL_REQD_SUB_GROUP_SIZE]] = !{i32 32}
+
+// -----
+
+llvm.func @foo()
+
+llvm.func @call_with_opbundle() {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.mlir.constant(2 : i32) : i32
+  llvm.call @foo() bundlearg(%0, %1) {op_bundles = #llvm.opbundles<#llvm.opbundle<"tag", 0, 1>>} : () -> (), tuple<i32, i32>
+  llvm.return
+}
+
+//      CHECK: define void @call_with_opbundle() {
+// CHECK-NEXT:   call void @foo() [ "tag"(i32 1, i32 2) ]
+// CHECK-NEXT:   ret void
+// CHECK-NEXT: }
+
+llvm.func @__gxx_personality_v0(...) -> i32
+llvm.func @invoke_with_opbundle() attributes { personality = @__gxx_personality_v0 } {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.mlir.constant(2 : i32) : i32
+  llvm.invoke @foo() to ^bb2 unwind ^bb1 bundlearg(%0, %1) {op_bundles = #llvm.opbundles<#llvm.opbundle<"tag", 0, 1>>} : () -> (), tuple<i32, i32>
+
+^bb1:
+  %2 = llvm.landingpad cleanup : !llvm.struct<(ptr, i32)>
+  llvm.return
+
+^bb2:
+  llvm.return
+}
+
+//      CHECK: define void @invoke_with_opbundle() personality ptr @__gxx_personality_v0 {
+// CHECK-NEXT:   invoke void @foo() [ "tag"(i32 1, i32 2) ]
+// CHECK-NEXT:           to label %{{.+}} unwind label %{{.+}}
+//      CHECK: }
+
+llvm.func @call_intrin_with_opbundle(%arg0 : !llvm.ptr) {
+  %0 = llvm.mlir.constant(1 : i1) : i1
+  %1 = llvm.mlir.constant(16 : i32) : i32
+  llvm.call_intrinsic "llvm.assume"(%0) bundlearg(%arg0, %1) : (i1) -> (), tuple<!llvm.ptr, i32> {op_bundles = #llvm.opbundles<#llvm.opbundle<"align", 0, 1>>}
+  llvm.return
+}
+
+//      CHECK: define void @call_intrin_with_opbundle(ptr %0) {
+// CHECK-NEXT:   call void @llvm.assume(i1 true) [ "align"(ptr %0, i32 16) ]
+// CHECK-NEXT:   ret void
+// CHECK-NEXT: }



More information about the Mlir-commits mailing list