[Mlir-commits] [mlir] [mlir][LLVM] Add operand bundle support (PR #108933)

Sirui Mu llvmlistbot at llvm.org
Tue Sep 24 06:58:50 PDT 2024


https://github.com/Lancern updated https://github.com/llvm/llvm-project/pull/108933

>From 7a417985258517a2b1a1a860e78c1599422cd87b Mon Sep 17 00:00:00 2001
From: Sirui Mu <msrlancern at gmail.com>
Date: Mon, 23 Sep 2024 23:19:47 +0800
Subject: [PATCH 1/7] [mlir][LLVM] Add operand bundle support

---
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td   |  44 ++-
 mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp |   4 +
 .../Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp    |  10 +-
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp    | 251 ++++++++++++++++--
 .../LLVMIR/LLVMToLLVMIRTranslation.cpp        |  68 ++++-
 mlir/test/Dialect/LLVMIR/invalid.mlir         |  18 +-
 mlir/test/Target/LLVMIR/llvmir.mlir           |  49 ++++
 7 files changed, 401 insertions(+), 43 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index d956d7f27f784d..4ede00d4fa4320 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -551,7 +551,15 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [
                    Variadic<LLVM_Type>:$normalDestOperands,
                    Variadic<LLVM_Type>:$unwindDestOperands,
                    OptionalAttr<DenseI32ArrayAttr>:$branch_weights,
-                   DefaultValuedAttr<CConv, "CConv::C">:$CConv);
+                   DefaultValuedAttr<CConv, "CConv::C">:$CConv,
+                   VariadicOfVariadic<LLVM_Type,
+                                      "op_bundle_sizes">:$op_bundle_operands,
+                   DenseI32ArrayAttr:$op_bundle_sizes,
+                   DefaultValuedProperty<
+                     ArrayProperty<StringProperty, "operand bundle tags">,
+                     "ArrayRef<std::string>{}",
+                     "SmallVector<std::string>{}"
+                   >:$op_bundle_tags);
   let results = (outs Optional<LLVM_Type>:$result);
   let successors = (successor AnySuccessor:$normalDest,
                               AnySuccessor:$unwindDest);
@@ -587,7 +595,8 @@ def LLVM_LandingpadOp : LLVM_Op<"landingpad"> {
 //===----------------------------------------------------------------------===//
 
 def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
-                    [DeclareOpInterfaceMethods<FastmathFlagsInterface>,
+                    [AttrSizedOperandSegments,
+                     DeclareOpInterfaceMethods<FastmathFlagsInterface>,
                      DeclareOpInterfaceMethods<CallOpInterface>,
                      DeclareOpInterfaceMethods<SymbolUserOpInterface>,
                      DeclareOpInterfaceMethods<BranchWeightOpInterface>]> {
@@ -641,8 +650,15 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
                   OptionalAttr<LLVM_MemoryEffectsAttr>:$memory_effects,
                   OptionalAttr<UnitAttr>:$convergent,
                   OptionalAttr<UnitAttr>:$no_unwind,
-                  OptionalAttr<UnitAttr>:$will_return
-                  );
+                  OptionalAttr<UnitAttr>:$will_return,
+                  VariadicOfVariadic<LLVM_Type,
+                                     "op_bundle_sizes">:$op_bundle_operands,
+                  DenseI32ArrayAttr:$op_bundle_sizes,
+                  DefaultValuedProperty<
+                    ArrayProperty<StringProperty, "operand bundle tags">,
+                    "ArrayRef<std::string>{}",
+                    "SmallVector<std::string>{}"
+                  >:$op_bundle_tags);
   // Append the aliasing related attributes defined in LLVM_MemAccessOpBase.
   let arguments = !con(args, aliasAttrs);
   let results = (outs Optional<LLVM_Type>:$result);
@@ -662,6 +678,7 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
     OpBuilder<(ins "LLVMFunctionType":$calleeType, "StringRef":$callee,
                    CArg<"ValueRange", "{}">:$args)>
   ];
+  let hasVerifier = 1;
   let hasCustomAssemblyFormat = 1;
   let extraClassDeclaration = [{
     /// Returns the callee function type.
@@ -1875,7 +1892,8 @@ def LLVM_InlineAsmOp : LLVM_Op<"inline_asm", [DeclareOpInterfaceMethods<MemoryEf
 
 def LLVM_CallIntrinsicOp
     : LLVM_Op<"call_intrinsic",
-              [DeclareOpInterfaceMethods<FastmathFlagsInterface>]> {
+              [AttrSizedOperandSegments,
+               DeclareOpInterfaceMethods<FastmathFlagsInterface>]> {
   let summary = "Call to an LLVM intrinsic function.";
   let description = [{
     Call the specified llvm intrinsic. If the intrinsic is overloaded, use
@@ -1883,13 +1901,25 @@ def LLVM_CallIntrinsicOp
   }];
   let arguments = (ins StrAttr:$intrin, Variadic<LLVM_Type>:$args,
                        DefaultValuedAttr<LLVM_FastmathFlagsAttr,
-                                         "{}">:$fastmathFlags);
+                                         "{}">:$fastmathFlags,
+                       VariadicOfVariadic<LLVM_Type,
+                                          "op_bundle_sizes">:$op_bundle_operands,
+                       DenseI32ArrayAttr:$op_bundle_sizes,
+                       DefaultValuedProperty<
+                         ArrayProperty<StringProperty, "operand bundle tags">,
+                         "ArrayRef<std::string>{}",
+                         "SmallVector<std::string>{}"
+                       >:$op_bundle_tags);
   let results = (outs Optional<LLVM_Type>:$results);
   let llvmBuilder = [{
     return convertCallLLVMIntrinsicOp(op, builder, moduleTranslation);
   }];
   let assemblyFormat = [{
-    $intrin `(` $args `)` `:` functional-type($args, $results) attr-dict
+    $intrin `(` $args `)`
+    ( custom<OpBundles>($op_bundle_operands, type($op_bundle_operands),
+                        $op_bundle_tags)^ )?
+    `:` functional-type($args, $results)
+    attr-dict
   }];
 
   let hasVerifier = 1;
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 4c2e8682285c52..2cc77e8fd41b9a 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -544,6 +544,10 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
         callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
         promoted, callOp->getAttrs());
 
+    newOp.getProperties().operandSegmentSizes = {
+        static_cast<int32_t>(promoted.size()), 0};
+    newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
+
     SmallVector<Value, 4> results;
     if (numResults < 2) {
       // If < 2 results, packing did not do anything and we can just return.
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index ca786316324198..6ae607f75adbd5 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -837,8 +837,11 @@ class FunctionCallPattern
   matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     if (callOp.getNumResults() == 0) {
-      rewriter.replaceOpWithNewOp<LLVM::CallOp>(
+      auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
           callOp, std::nullopt, adaptor.getOperands(), callOp->getAttrs());
+      newOp.getProperties().operandSegmentSizes = {
+          static_cast<int32_t>(adaptor.getOperands().size()), 0};
+      newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
       return success();
     }
 
@@ -846,8 +849,11 @@ class FunctionCallPattern
     auto dstType = typeConverter.convertType(callOp.getType(0));
     if (!dstType)
       return rewriter.notifyMatchFailure(callOp, "type conversion failed");
-    rewriter.replaceOpWithNewOp<LLVM::CallOp>(
+    auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
         callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
+    newOp.getProperties().operandSegmentSizes = {
+        static_cast<int32_t>(adaptor.getOperands().size()), 0};
+    newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 205d7494d4378c..837e0e41800d81 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -220,6 +220,88 @@ static RetTy parseOptionalLLVMKeyword(OpAsmParser &parser,
   return static_cast<RetTy>(index);
 }
 
+//===----------------------------------------------------------------------===//
+// Operand bundle helpers.
+//===----------------------------------------------------------------------===//
+
+/// Prints one operand bundle as `"tag"(operands : types)`; a bundle without
+/// operands is printed as `"tag"()` so that the output can be parsed back.
+static void printOneOpBundle(OpAsmPrinter &p, OperandRange operands,
+                             TypeRange operandTypes, StringRef tag) {
+  p.printString(tag);
+  p << "(";
+  if (!operands.empty()) {
+    p.printOperands(operands);
+    p << " : ";
+    llvm::interleaveComma(operandTypes, p);
+  }
+  p << ")";
+}
+
+/// Prints all operand bundles as a comma-separated list inside `[` and `]`.
+static void printOpBundles(OpAsmPrinter &p, Operation *op,
+                           OperandRangeRange opBundleOperands,
+                           TypeRangeRange opBundleOperandTypes,
+                           ArrayRef<std::string> opBundleTags) {
+  p << "[";
+  llvm::interleaveComma(
+      llvm::zip(opBundleOperands, opBundleOperandTypes, opBundleTags), p,
+      [&p](auto bundle) {
+        printOneOpBundle(p, std::get<0>(bundle), std::get<1>(bundle),
+                         std::get<2>(bundle));
+      });
+  p << "]";
+}
+
+/// Parses one operand bundle, either `"tag"(operands : types)` or the empty
+/// form `"tag"()`, and appends its operands, types and tag to the outputs.
+static ParseResult parseOneOpBundle(
+    OpAsmParser &p,
+    SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> &opBundleOperands,
+    SmallVector<SmallVector<Type>> &opBundleOperandTypes,
+    SmallVector<std::string> &opBundleTags) {
+  auto currentParserLoc = p.getCurrentLocation();
+  SmallVector<OpAsmParser::UnresolvedOperand> operands;
+  SmallVector<Type> types;
+  std::string tag;
+
+  if (p.parseString(&tag))
+    return p.emitError(currentParserLoc, "expect operand bundle tag");
+  if (p.parseLParen())
+    return failure();
+  // An immediate `)` is an empty bundle; parseTypeList rejects empty lists.
+  if (p.parseOptionalRParen()) {
+    if (p.parseOperandList(operands) || p.parseColon() ||
+        p.parseTypeList(types) || p.parseRParen())
+      return failure();
+  }
+  opBundleOperands.push_back(std::move(operands));
+  opBundleOperandTypes.push_back(std::move(types));
+  opBundleTags.push_back(std::move(tag));
+  return success();
+}
+
+static std::optional<ParseResult> parseOpBundles(
+    OpAsmParser &p,
+    SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> &opBundleOperands,
+    SmallVector<SmallVector<Type>> &opBundleOperandTypes,
+    SmallVector<std::string> &opBundleTags) {
+  if (p.parseOptionalLSquare())
+    return std::nullopt;
+
+  auto bundleParser = [&] {
+    return parseOneOpBundle(p, opBundleOperands, opBundleOperandTypes,
+                            opBundleTags);
+  };
+  if (p.parseCommaSeparatedList(bundleParser))
+    return failure();
+
+  if (p.parseRSquare())
+    return failure();
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Printing, parsing, folding and builder for LLVM::CmpOp.
 //===----------------------------------------------------------------------===//
@@ -954,6 +1036,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
         /*CConv=*/nullptr, /*TailCallKind=*/nullptr,
         /*memory_effects=*/nullptr,
         /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
+        /*op_bundle_operands=*/{}, /*op_bundle_tags=*/std::nullopt,
         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
         /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
 }
@@ -980,6 +1063,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
         /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
         /*convergent=*/nullptr,
         /*no_unwind=*/nullptr, /*will_return=*/nullptr,
+        /*op_bundle_operands=*/{}, /*op_bundle_tags=*/std::nullopt,
         /*access_groups=*/nullptr,
         /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
 }
@@ -992,6 +1076,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
         /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
         /*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
         /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
+        /*op_bundle_operands=*/{}, /*op_bundle_tags=*/std::nullopt,
         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
         /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
 }
@@ -1004,6 +1089,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
         /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
         /*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
         /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
+        /*op_bundle_operands=*/{}, /*op_bundle_tags=*/std::nullopt,
         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
         /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
 }
@@ -1027,7 +1113,7 @@ void CallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
 }
 
 Operation::operand_range CallOp::getArgOperands() {
-  return getOperands().drop_front(getCallee().has_value() ? 0 : 1);
+  return getCalleeOperands().drop_front(getCallee().has_value() ? 0 : 1);
 }
 
 MutableOperandRange CallOp::getArgOperandsMutable() {
@@ -1100,6 +1186,27 @@ LogicalResult verifyCallOpVarCalleeType(OpTy callOp) {
   return success();
 }
 
+template <typename OpType>
+static LogicalResult verifyOperandBundles(OpType &op) {
+  OperandRangeRange opBundleOperands = op.getOpBundleOperands();
+  std::optional<ArrayRef<std::string>> opBundleTags = op.getOpBundleTags();
+
+  if (!opBundleTags.has_value()) {
+    if (!opBundleOperands.empty())
+      return op.emitError("expected operand bundle tags");
+    return success();
+  }
+
+  if (opBundleTags->size() != opBundleOperands.size())
+    return op.emitError("expected ")
+           << opBundleOperands.size()
+           << " operand bundle tags, but actually got " << opBundleTags->size();
+
+  return success();
+}
+
+LogicalResult CallOp::verify() { return verifyOperandBundles(*this); }
+
 LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   if (failed(verifyCallOpVarCalleeType(*this)))
     return failure();
@@ -1150,15 +1257,15 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   // Verify that the operand and result types match the callee.
 
   if (!funcType.isVarArg() &&
-      funcType.getNumParams() != (getNumOperands() - isIndirect))
+      funcType.getNumParams() != (getCalleeOperands().size() - isIndirect))
     return emitOpError() << "incorrect number of operands ("
-                         << (getNumOperands() - isIndirect)
+                         << (getCalleeOperands().size() - isIndirect)
                          << ") for callee (expecting: "
                          << funcType.getNumParams() << ")";
 
-  if (funcType.getNumParams() > (getNumOperands() - isIndirect))
+  if (funcType.getNumParams() > (getCalleeOperands().size() - isIndirect))
     return emitOpError() << "incorrect number of operands ("
-                         << (getNumOperands() - isIndirect)
+                         << (getCalleeOperands().size() - isIndirect)
                          << ") for varargs callee (expecting at least: "
                          << funcType.getNumParams() << ")";
 
@@ -1208,16 +1315,24 @@ void CallOp::print(OpAsmPrinter &p) {
   else
     p << getOperand(0);
 
-  auto args = getOperands().drop_front(isDirect ? 0 : 1);
+  auto args = getCalleeOperands().drop_front(isDirect ? 0 : 1);
   p << '(' << args << ')';
 
   // Print the variadic callee type if the call is variadic.
   if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
     p << " vararg(" << *varCalleeType << ")";
 
+  if (!getOpBundleOperands().empty()) {
+    p << " ";
+    printOpBundles(p, *this, getOpBundleOperands(),
+                   getOpBundleOperands().getTypes(), getOpBundleTags());
+  }
+
   p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
                           {getCalleeAttrName(), getTailCallKindAttrName(),
-                           getVarCalleeTypeAttrName(), getCConvAttrName()});
+                           getVarCalleeTypeAttrName(), getCConvAttrName(),
+                           getOperandSegmentSizesAttrName(),
+                           getOpBundleSizesAttrName()});
 
   p << " : ";
   if (!isDirect)
@@ -1285,14 +1400,53 @@ static ParseResult parseOptionalCallFuncPtr(
   return success();
 }
 
+/// Resolves every parsed operand bundle operand into `state.operands` and
+/// stores the per-bundle operand counts under `opBundleSizesAttrName`.
+static ParseResult resolveOpBundleOperands(
+    OpAsmParser &parser, SMLoc loc, OperationState &state,
+    ArrayRef<SmallVector<OpAsmParser::UnresolvedOperand>> opBundleOperands,
+    ArrayRef<SmallVector<Type>> opBundleOperandTypes,
+    StringAttr opBundleSizesAttrName) {
+  assert(opBundleOperands.size() == opBundleOperandTypes.size() &&
+         "operand bundle operand groups and type groups should match");
+  unsigned opBundleIndex = 0;
+  for (const auto &[operands, types] :
+       llvm::zip(opBundleOperands, opBundleOperandTypes)) {
+    if (operands.size() != types.size())
+      return parser.emitError(loc, "expected ")
+             << operands.size()
+             << " types for operand bundle operands for operand bundle #"
+             << opBundleIndex << ", but actually got " << types.size();
+    if (parser.resolveOperands(operands, types, loc, state.operands))
+      return failure();
+    // Advance the index so the diagnostic above names the failing bundle.
+    ++opBundleIndex;
+  }
+
+  SmallVector<int32_t> opBundleSizes;
+  opBundleSizes.reserve(opBundleOperands.size());
+  for (const auto &operands : opBundleOperands)
+    opBundleSizes.push_back(operands.size());
+  state.addAttribute(
+      opBundleSizesAttrName,
+      DenseI32ArrayAttr::get(parser.getContext(), opBundleSizes));
+  return success();
+}
+
 // <operation> ::= `llvm.call` (cconv)? (tailcallkind)? (function-id | ssa-use)
 //                             `(` ssa-use-list `)`
 //                             ( `vararg(` var-callee-type `)` )?
+//                             ( `[` op-bundle (`,` op-bundle)* `]` )?
+//                             (see the <op-bundle> production below)
 //                             attribute-dict? `:` (type `,`)? function-type
+// <op-bundle> ::= string-literal `(` ssa-use-list `:` type-list `)`
 ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
   SymbolRefAttr funcAttr;
   TypeAttr varCalleeType;
   SmallVector<OpAsmParser::UnresolvedOperand> operands;
+  SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> opBundleOperands;
+  SmallVector<SmallVector<Type>> opBundleOperandTypes;
+  SmallVector<std::string> opBundleTags;
 
   // Default to C Calling Convention if no keyword is provided.
   result.addAttribute(
@@ -1333,11 +1487,35 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
       return failure();
   }
 
+  auto opBundlesLoc = parser.getCurrentLocation();
+  if (auto result = parseOpBundles(parser, opBundleOperands,
+                                   opBundleOperandTypes, opBundleTags);
+      result.has_value() && failed(*result))
+    return failure();
+  if (!opBundleTags.empty())
+    result.getOrAddProperties<CallOp::Properties>().op_bundle_tags =
+        std::move(opBundleTags);
+
   if (parser.parseOptionalAttrDict(result.attributes))
     return failure();
 
   // Parse the trailing type list and resolve the operands.
-  return parseCallTypeAndResolveOperands(parser, result, isDirect, operands);
+  if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands))
+    return failure();
+  if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands,
+                              opBundleOperandTypes,
+                              getOpBundleSizesAttrName(result.name)))
+    return failure();
+
+  int32_t numOpBundleOperands = 0;
+  for (const auto &operands : opBundleOperands)
+    numOpBundleOperands += operands.size();
+
+  result.addAttribute(
+      CallOp::getOperandSegmentSizeAttr(),
+      parser.getBuilder().getDenseI32ArrayAttr(
+          {static_cast<int32_t>(operands.size()), numOpBundleOperands}));
+  return success();
 }
 
 LLVMFunctionType CallOp::getCalleeFunctionType() {
@@ -1356,7 +1534,8 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
   auto calleeType = func.getFunctionType();
   build(builder, state, getCallOpResultTypes(calleeType),
         getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), ops,
-        normalOps, unwindOps, nullptr, nullptr, normal, unwind);
+        normalOps, unwindOps, nullptr, nullptr, {}, std::nullopt, normal,
+        unwind);
 }
 
 void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
@@ -1365,7 +1544,7 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
                      ValueRange unwindOps) {
   build(builder, state, tys,
         /*var_callee_type=*/nullptr, callee, ops, normalOps, unwindOps, nullptr,
-        nullptr, normal, unwind);
+        nullptr, {}, std::nullopt, normal, unwind);
 }
 
 void InvokeOp::build(OpBuilder &builder, OperationState &state,
@@ -1374,7 +1553,7 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state,
                      Block *unwind, ValueRange unwindOps) {
   build(builder, state, getCallOpResultTypes(calleeType),
         getCallOpVarCalleeType(calleeType), callee, ops, normalOps, unwindOps,
-        nullptr, nullptr, normal, unwind);
+        nullptr, nullptr, {}, std::nullopt, normal, unwind);
 }
 
 SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) {
@@ -1402,7 +1581,7 @@ void InvokeOp::setCalleeFromCallable(CallInterfaceCallable callee) {
 }
 
 Operation::operand_range InvokeOp::getArgOperands() {
-  return getOperands().drop_front(getCallee().has_value() ? 0 : 1);
+  return getCalleeOperands().drop_front(getCallee().has_value() ? 0 : 1);
 }
 
 MutableOperandRange InvokeOp::getArgOperandsMutable() {
@@ -1423,6 +1602,9 @@ LogicalResult InvokeOp::verify() {
     return emitError("first operation in unwind destination should be a "
                      "llvm.landingpad operation");
 
+  if (failed(verifyOperandBundles(*this)))
+    return failure();
+
   return success();
 }
 
@@ -1452,9 +1634,16 @@ void InvokeOp::print(OpAsmPrinter &p) {
   if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
     p << " vararg(" << *varCalleeType << ")";
 
+  if (!getOpBundleOperands().empty()) {
+    p << " ";
+    printOpBundles(p, *this, getOpBundleOperands(),
+                   getOpBundleOperands().getTypes(), getOpBundleTags());
+  }
+
   p.printOptionalAttrDict((*this)->getAttrs(),
                           {getCalleeAttrName(), getOperandSegmentSizeAttr(),
-                           getCConvAttrName(), getVarCalleeTypeAttrName()});
+                           getCConvAttrName(), getVarCalleeTypeAttrName(),
+                           getOpBundleSizesAttrName()});
 
   p << " : ";
   if (!isDirect)
@@ -1468,11 +1657,17 @@ void InvokeOp::print(OpAsmPrinter &p) {
 //                  `to` bb-id (`[` ssa-use-and-type-list `]`)?
 //                  `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
 //                  ( `vararg(` var-callee-type `)` )?
+//                  ( `[` op-bundle (`,` op-bundle)* `]` )?
+//                  (see the <op-bundle> production below)
 //                  attribute-dict? `:` (type `,`)? function-type
+// <op-bundle> ::= string-literal `(` ssa-use-list `:` type-list `)`
 ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
   SmallVector<OpAsmParser::UnresolvedOperand, 8> operands;
   SymbolRefAttr funcAttr;
   TypeAttr varCalleeType;
+  SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> opBundleOperands;
+  SmallVector<SmallVector<Type>> opBundleOperandTypes;
+  SmallVector<std::string> opBundleTags;
   Block *normalDest, *unwindDest;
   SmallVector<Value, 4> normalOperands, unwindOperands;
   Builder &builder = parser.getBuilder();
@@ -1513,22 +1708,40 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
       return failure();
   }
 
+  auto opBundlesLoc = parser.getCurrentLocation();
+  if (auto result = parseOpBundles(parser, opBundleOperands,
+                                   opBundleOperandTypes, opBundleTags);
+      result.has_value() && failed(*result))
+    return failure();
+  if (!opBundleTags.empty())
+    result.getOrAddProperties<InvokeOp::Properties>().op_bundle_tags =
+        std::move(opBundleTags);
+
   if (parser.parseOptionalAttrDict(result.attributes))
     return failure();
 
   // Parse the trailing type list and resolve the function operands.
   if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands))
     return failure();
+  if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands,
+                              opBundleOperandTypes,
+                              getOpBundleSizesAttrName(result.name)))
+    return failure();
 
   result.addSuccessors({normalDest, unwindDest});
   result.addOperands(normalOperands);
   result.addOperands(unwindOperands);
 
-  result.addAttribute(InvokeOp::getOperandSegmentSizeAttr(),
-                      builder.getDenseI32ArrayAttr(
-                          {static_cast<int32_t>(operands.size()),
-                           static_cast<int32_t>(normalOperands.size()),
-                           static_cast<int32_t>(unwindOperands.size())}));
+  int32_t numOpBundleOperands = 0;
+  for (const auto &operands : opBundleOperands)
+    numOpBundleOperands += operands.size();
+
+  result.addAttribute(
+      InvokeOp::getOperandSegmentSizeAttr(),
+      builder.getDenseI32ArrayAttr({static_cast<int32_t>(operands.size()),
+                                    static_cast<int32_t>(normalOperands.size()),
+                                    static_cast<int32_t>(unwindOperands.size()),
+                                    numOpBundleOperands}));
   return success();
 }
 
@@ -3108,6 +3321,8 @@ OpFoldResult LLVM::OrOp::fold(FoldAdaptor adaptor) {
 LogicalResult CallIntrinsicOp::verify() {
   if (!getIntrin().starts_with("llvm."))
     return emitOpError() << "intrinsic name must start with 'llvm.'";
+  if (failed(verifyOperandBundles(*this)))
+    return failure();
   return success();
 }
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index d948ff5eaf1769..53ca302518a90e 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -102,6 +102,40 @@ getOverloadedDeclaration(CallIntrinsicOp op, llvm::Intrinsic::ID id,
   return llvm::Intrinsic::getDeclaration(module, id, overloadedArgTysRef);
 }
 
+static llvm::OperandBundleDef
+convertOperandBundle(OperandRange bundleOperands, StringRef bundleTag,
+                     LLVM::ModuleTranslation &moduleTranslation) {
+  std::vector<llvm::Value *> operands;
+  operands.reserve(bundleOperands.size());
+  for (auto bundleArg : bundleOperands)
+    operands.push_back(moduleTranslation.lookupValue(bundleArg));
+  return llvm::OperandBundleDef(bundleTag.str(), std::move(operands));
+}
+
+static SmallVector<llvm::OperandBundleDef>
+convertOperandBundles(OperandRangeRange bundleOperands,
+                      ArrayRef<std::string> bundleTags,
+                      LLVM::ModuleTranslation &moduleTranslation) {
+  assert(bundleOperands.size() == bundleTags.size() &&
+         "operand bundles and tags do not match");
+
+  SmallVector<llvm::OperandBundleDef> bundles;
+  bundles.reserve(bundleOperands.size());
+
+  for (auto [operands, tag] : llvm::zip(bundleOperands, bundleTags))
+    bundles.push_back(convertOperandBundle(operands, tag, moduleTranslation));
+  return bundles;
+}
+
+static SmallVector<llvm::OperandBundleDef>
+convertOperandBundles(OperandRangeRange bundleOperands,
+                      std::optional<ArrayRef<std::string>> bundleTags,
+                      LLVM::ModuleTranslation &moduleTranslation) {
+  if (!bundleTags.has_value())
+    bundleTags.emplace();
+  return convertOperandBundles(bundleOperands, *bundleTags, moduleTranslation);
+}
+
 /// Builder for LLVM_CallIntrinsicOp
 static LogicalResult
 convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder,
@@ -138,15 +172,15 @@ convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder,
   // Check the argument types of the call. If the function is variadic, check
   // the subrange of required arguments.
   if (!fn->getFunctionType()->isVarArg() &&
-      op.getNumOperands() != fn->arg_size()) {
+      op.getArgs().size() != fn->arg_size()) {
     return mlir::emitError(op.getLoc(), "intrinsic call has ")
-           << op.getNumOperands() << " operands but " << op.getIntrinAttr()
+           << op.getArgs().size() << " operands but " << op.getIntrinAttr()
            << " expects " << fn->arg_size();
   }
   if (fn->getFunctionType()->isVarArg() &&
-      op.getNumOperands() < fn->arg_size()) {
+      op.getArgs().size() < fn->arg_size()) {
     return mlir::emitError(op.getLoc(), "intrinsic call has ")
-           << op.getNumOperands() << " operands but variadic "
+           << op.getArgs().size() << " operands but variadic "
            << op.getIntrinAttr() << " expects at least " << fn->arg_size();
   }
   // Check the arguments up to the number the function requires.
@@ -164,8 +198,10 @@ convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder,
   FastmathFlagsInterface itf = op;
   builder.setFastMathFlags(getFastmathFlags(itf));
 
-  auto *inst =
-      builder.CreateCall(fn, moduleTranslation.lookupValues(op.getOperands()));
+  auto *inst = builder.CreateCall(
+      fn, moduleTranslation.lookupValues(op.getArgs()),
+      convertOperandBundles(op.getOpBundleOperands(), op.getOpBundleTags(),
+                            moduleTranslation));
   if (op.getNumResults() == 1)
     moduleTranslation.mapValue(op->getResults().front()) = inst;
   return success();
@@ -205,17 +241,21 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
   // itself.  Otherwise, this is an indirect call and the callee is the first
   // operand, look it up as a normal value.
   if (auto callOp = dyn_cast<LLVM::CallOp>(opInst)) {
-    auto operands = moduleTranslation.lookupValues(callOp.getOperands());
+    auto operands = moduleTranslation.lookupValues(callOp.getCalleeOperands());
+    SmallVector<llvm::OperandBundleDef> opBundles =
+        convertOperandBundles(callOp.getOpBundleOperands(),
+                              callOp.getOpBundleTags(), moduleTranslation);
     ArrayRef<llvm::Value *> operandsRef(operands);
     llvm::CallInst *call;
     if (auto attr = callOp.getCalleeAttr()) {
-      call = builder.CreateCall(
-          moduleTranslation.lookupFunction(attr.getValue()), operandsRef);
+      call =
+          builder.CreateCall(moduleTranslation.lookupFunction(attr.getValue()),
+                             operandsRef, opBundles);
     } else {
       llvm::FunctionType *calleeType = llvm::cast<llvm::FunctionType>(
           moduleTranslation.convertType(callOp.getCalleeFunctionType()));
       call = builder.CreateCall(calleeType, operandsRef.front(),
-                                operandsRef.drop_front());
+                                operandsRef.drop_front(), opBundles);
     }
     call->setCallingConv(convertCConvToLLVM(callOp.getCConv()));
     call->setTailCallKind(convertTailCallKindToLLVM(callOp.getTailCallKind()));
@@ -312,13 +352,17 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
 
   if (auto invOp = dyn_cast<LLVM::InvokeOp>(opInst)) {
     auto operands = moduleTranslation.lookupValues(invOp.getCalleeOperands());
+    SmallVector<llvm::OperandBundleDef> opBundles =
+        convertOperandBundles(invOp.getOpBundleOperands(),
+                              invOp.getOpBundleTags(), moduleTranslation);
     ArrayRef<llvm::Value *> operandsRef(operands);
     llvm::InvokeInst *result;
     if (auto attr = opInst.getAttrOfType<FlatSymbolRefAttr>("callee")) {
       result = builder.CreateInvoke(
           moduleTranslation.lookupFunction(attr.getValue()),
           moduleTranslation.lookupBlock(invOp.getSuccessor(0)),
-          moduleTranslation.lookupBlock(invOp.getSuccessor(1)), operandsRef);
+          moduleTranslation.lookupBlock(invOp.getSuccessor(1)), operandsRef,
+          opBundles);
     } else {
       llvm::FunctionType *calleeType = llvm::cast<llvm::FunctionType>(
           moduleTranslation.convertType(invOp.getCalleeFunctionType()));
@@ -326,7 +370,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
           calleeType, operandsRef.front(),
           moduleTranslation.lookupBlock(invOp.getSuccessor(0)),
           moduleTranslation.lookupBlock(invOp.getSuccessor(1)),
-          operandsRef.drop_front());
+          operandsRef.drop_front(), opBundles);
     }
     result->setCallingConv(convertCConvToLLVM(invOp.getCConv()));
     moduleTranslation.mapBranch(invOp, result);
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 6670e4b186c397..1121691133108f 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -218,7 +218,7 @@ func.func @store_unaligned_atomic(%val : f32, %ptr : !llvm.ptr) {
 
 func.func @invalid_call() {
   // expected-error@+1 {{'llvm.call' op must have either a `callee` attribute or at least an operand}}
-  "llvm.call"() : () -> ()
+  "llvm.call"() {op_bundle_sizes = array<i32>} : () -> ()
   llvm.return
 }
 
@@ -286,7 +286,7 @@ func.func @call_non_llvm() {
 
 func.func @call_non_llvm_arg(%arg0 : tensor<*xi32>) {
   // expected-error@+1 {{'llvm.call' op operand #0 must be variadic of LLVM dialect-compatible type}}
-  "llvm.call"(%arg0) : (tensor<*xi32>) -> ()
+  "llvm.call"(%arg0) {operandSegmentSizes = array<i32: 1, 0>, op_bundle_sizes = array<i32>} : (tensor<*xi32>) -> ()
   llvm.return
 }
 
@@ -1588,7 +1588,7 @@ llvm.func @variadic(...)
 
 llvm.func @invalid_variadic_call(%arg: i32)  {
   // expected-error@+1 {{missing var_callee_type attribute for vararg call}}
-  "llvm.call"(%arg) <{callee = @variadic}> : (i32) -> ()
+  "llvm.call"(%arg) <{callee = @variadic}> {operandSegmentSizes = array<i32: 1, 0>, op_bundle_sizes = array<i32>} : (i32) -> ()
   llvm.return
 }
 
@@ -1598,7 +1598,7 @@ llvm.func @variadic(...)
 
 llvm.func @invalid_variadic_call(%arg: i32)  {
   // expected-error@+1 {{missing var_callee_type attribute for vararg call}}
-  "llvm.call"(%arg) <{callee = @variadic}> : (i32) -> ()
+  "llvm.call"(%arg) <{callee = @variadic}> {operandSegmentSizes = array<i32: 1, 0>, op_bundle_sizes = array<i32>} : (i32) -> ()
   llvm.return
 }
 
@@ -1655,3 +1655,13 @@ llvm.func @alwaysinline_noinline() attributes { always_inline, no_inline } {
 llvm.func @optnone_requires_noinline() attributes { optimize_none } {
   llvm.return
 }
+
+// -----
+
+llvm.func @foo()
+llvm.func @wrong_number_of_bundle_types() {
+  %0 = llvm.mlir.constant(0 : i32) : i32
+  // expected-error@+1 {{expected 1 types for operand bundle operands for operand bundle #0, but actually got 2}}
+  llvm.call @foo() ["tag"(%0 : i32, i32)] : () -> () bundletype((i32, i32))
+  llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 7eca1a40373054..69c16fb369f3bd 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -2621,3 +2621,52 @@ llvm.func @reqd_work_group_size() attributes {reqd_work_group_size = array<i32:
 llvm.func @intel_reqd_sub_group_size() attributes {intel_reqd_sub_group_size = 32 : i32}
 
 // CHECK: ![[#INTEL_REQD_SUB_GROUP_SIZE]] = !{i32 32}
+
+// -----
+
+llvm.func @foo()
+
+llvm.func @call_with_opbundle() {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.mlir.constant(2 : i32) : i32
+  %2 = llvm.mlir.constant(3 : i32) : i32
+  llvm.call @foo() ["tag1"(%0, %1 : i32, i32), "tag2"(%2 : i32)] : () -> ()
+  llvm.return
+}
+
+//      CHECK: define void @call_with_opbundle() {
+// CHECK-NEXT:   call void @foo() [ "tag1"(i32 1, i32 2), "tag2"(i32 3) ]
+// CHECK-NEXT:   ret void
+// CHECK-NEXT: }
+
+llvm.func @__gxx_personality_v0(...) -> i32
+llvm.func @invoke_with_opbundle() attributes { personality = @__gxx_personality_v0 } {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.mlir.constant(2 : i32) : i32
+  %2 = llvm.mlir.constant(3 : i32) : i32
+  llvm.invoke @foo() to ^bb2 unwind ^bb1 ["tag1"(%0, %1 : i32, i32), "tag2"(%2 : i32)] : () -> ()
+
+^bb1:
+  %3 = llvm.landingpad cleanup : !llvm.struct<(ptr, i32)>
+  llvm.return
+
+^bb2:
+  llvm.return
+}
+
+//      CHECK: define void @invoke_with_opbundle() personality ptr @__gxx_personality_v0 {
+// CHECK-NEXT:   invoke void @foo() [ "tag1"(i32 1, i32 2), "tag2"(i32 3) ]
+// CHECK-NEXT:           to label %{{.+}} unwind label %{{.+}}
+//      CHECK: }
+
+llvm.func @call_intrin_with_opbundle(%arg0 : !llvm.ptr) {
+  %0 = llvm.mlir.constant(1 : i1) : i1
+  %1 = llvm.mlir.constant(16 : i32) : i32
+  llvm.call_intrinsic "llvm.assume"(%0) ["align"(%arg0, %1 : !llvm.ptr, i32)] : (i1) -> ()
+  llvm.return
+}
+
+//      CHECK: define void @call_intrin_with_opbundle(ptr %0) {
+// CHECK-NEXT:   call void @llvm.assume(i1 true) [ "align"(ptr %0, i32 16) ]
+// CHECK-NEXT:   ret void
+// CHECK-NEXT: }

>From ec456b7b63ade66b0a206a1281e4b55264fd148f Mon Sep 17 00:00:00 2001
From: Sirui Mu <msrlancern at gmail.com>
Date: Tue, 24 Sep 2024 20:28:32 +0800
Subject: [PATCH 2/7] resolve nits before adding more tests

---
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp    | 46 +++++++------------
 .../LLVMIR/LLVMToLLVMIRTranslation.cpp        |  4 +-
 mlir/test/Dialect/LLVMIR/invalid.mlir         |  2 +-
 3 files changed, 19 insertions(+), 33 deletions(-)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 837e0e41800d81..e7550b653d27be 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -253,7 +253,7 @@ static ParseResult parseOneOpBundle(
     SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> &opBundleOperands,
     SmallVector<SmallVector<Type>> &opBundleOperandTypes,
     SmallVector<std::string> &opBundleTags) {
-  auto currentParserLoc = p.getCurrentLocation();
+  SMLoc currentParserLoc = p.getCurrentLocation();
   SmallVector<OpAsmParser::UnresolvedOperand> operands;
   SmallVector<Type> types;
   std::string tag;
@@ -1189,18 +1189,12 @@ LogicalResult verifyCallOpVarCalleeType(OpTy callOp) {
 template <typename OpType>
 static LogicalResult verifyOperandBundles(OpType &op) {
   OperandRangeRange opBundleOperands = op.getOpBundleOperands();
-  std::optional<ArrayRef<std::string>> opBundleTags = op.getOpBundleTags();
+  ArrayRef<std::string> opBundleTags = op.getOpBundleTags();
 
-  if (!opBundleTags.has_value()) {
-    if (!opBundleOperands.empty())
-      return op.emitError("expected operand bundle tags");
-    return success();
-  }
-
-  if (opBundleTags->size() != opBundleOperands.size())
+  if (opBundleTags.size() != opBundleOperands.size())
     return op.emitError("expected ")
            << opBundleOperands.size()
-           << " operand bundle tags, but actually got " << opBundleTags->size();
+           << " operand bundle tags, but actually got " << opBundleTags.size();
 
   return success();
 }
@@ -1405,12 +1399,9 @@ static ParseResult resolveOpBundleOperands(
     ArrayRef<SmallVector<OpAsmParser::UnresolvedOperand>> opBundleOperands,
     ArrayRef<SmallVector<Type>> opBundleOperandTypes,
     StringAttr opBundleSizesAttrName) {
-  assert(opBundleOperands.size() == opBundleOperandTypes.size() &&
-         "operand bundle operand groups and type groups should match");
-
   unsigned opBundleIndex = 0;
   for (const auto &[operands, types] :
-       llvm::zip(opBundleOperands, opBundleOperandTypes)) {
+       llvm::zip_equal(opBundleOperands, opBundleOperandTypes)) {
     if (operands.size() != types.size())
       return parser.emitError(loc, "expected ")
              << operands.size()
@@ -1422,9 +1413,8 @@ static ParseResult resolveOpBundleOperands(
 
   SmallVector<int32_t> opBundleSizes;
   opBundleSizes.reserve(opBundleOperands.size());
-  for (const auto &operands : opBundleOperands) {
+  for (const auto &operands : opBundleOperands)
     opBundleSizes.push_back(operands.size());
-  }
 
   state.addAttribute(
       opBundleSizesAttrName,
@@ -1436,10 +1426,8 @@ static ParseResult resolveOpBundleOperands(
 // <operation> ::= `llvm.call` (cconv)? (tailcallkind)? (function-id | ssa-use)
 //                             `(` ssa-use-list `)`
 //                             ( `vararg(` var-callee-type `)` )?
-//                             ( `bundlearg(` ssa-use-list-list `)` )?
-//                             ( `bundletags(` str-elements-attr `) )
+//                             ( `[` op-bundles-list `]` )?
 //                             attribute-dict? `:` (type `,`)? function-type
-//                             (`,` `bundletype(` type-list-list `)`)?
 ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
   SymbolRefAttr funcAttr;
   TypeAttr varCalleeType;
@@ -1487,10 +1475,10 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
       return failure();
   }
 
-  auto opBundlesLoc = parser.getCurrentLocation();
-  if (auto result = parseOpBundles(parser, opBundleOperands,
-                                   opBundleOperandTypes, opBundleTags);
-      result.has_value() && failed(*result))
+  SMLoc opBundlesLoc = parser.getCurrentLocation();
+  if (std::optional<ParseResult> result = parseOpBundles(
+          parser, opBundleOperands, opBundleOperandTypes, opBundleTags);
+      result && failed(*result))
     return failure();
   if (!opBundleTags.empty())
     result.getOrAddProperties<CallOp::Properties>().op_bundle_tags =
@@ -1657,10 +1645,8 @@ void InvokeOp::print(OpAsmPrinter &p) {
 //                  `to` bb-id (`[` ssa-use-and-type-list `]`)?
 //                  `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
 //                  ( `vararg(` var-callee-type `)` )?
-//                  ( `bundlearg(` ssa-use-list-list `)` )?
-//                  ( `bundletags(` str-elements-attr `) )
+//                  ( `[` op-bundles-list `]` )?
 //                  attribute-dict? `:` (type `,`)? function-type
-//                  (`,` `bundletype(` type-list-list `)`)?
 ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
   SmallVector<OpAsmParser::UnresolvedOperand, 8> operands;
   SymbolRefAttr funcAttr;
@@ -1708,10 +1694,10 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
       return failure();
   }
 
-  auto opBundlesLoc = parser.getCurrentLocation();
-  if (auto result = parseOpBundles(parser, opBundleOperands,
-                                   opBundleOperandTypes, opBundleTags);
-      result.has_value() && failed(*result))
+  SMLoc opBundlesLoc = parser.getCurrentLocation();
+  if (std::optional<ParseResult> result = parseOpBundles(
+          parser, opBundleOperands, opBundleOperandTypes, opBundleTags);
+      result && failed(*result))
     return failure();
   if (!opBundleTags.empty())
     result.getOrAddProperties<InvokeOp::Properties>().op_bundle_tags =
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index 53ca302518a90e..cd4e760c3b4bcf 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -107,7 +107,7 @@ convertOperandBundle(OperandRange bundleOperands, StringRef bundleTag,
                      LLVM::ModuleTranslation &moduleTranslation) {
   std::vector<llvm::Value *> operands;
   operands.reserve(bundleOperands.size());
-  for (auto bundleArg : bundleOperands)
+  for (Value bundleArg : bundleOperands)
     operands.push_back(moduleTranslation.lookupValue(bundleArg));
   return llvm::OperandBundleDef(bundleTag.str(), std::move(operands));
 }
@@ -131,7 +131,7 @@ static SmallVector<llvm::OperandBundleDef>
 convertOperandBundles(OperandRangeRange bundleOperands,
                       std::optional<ArrayRef<std::string>> bundleTags,
                       LLVM::ModuleTranslation &moduleTranslation) {
-  if (!bundleTags.has_value())
+  if (!bundleTags)
     bundleTags.emplace();
   return convertOperandBundles(bundleOperands, *bundleTags, moduleTranslation);
 }
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 1121691133108f..afe01d3ff89d68 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1662,6 +1662,6 @@ llvm.func @foo()
 llvm.func @wrong_number_of_bundle_types() {
   %0 = llvm.mlir.constant(0 : i32) : i32
   // expected-error@+1 {{expected 1 types for operand bundle operands for operand bundle #0, but actually got 2}}
-  llvm.call @foo() ["tag"(%0 : i32, i32)] : () -> () bundletype((i32, i32))
+  llvm.call @foo() ["tag"(%0 : i32, i32)] : () -> ()
   llvm.return
 }

>From cda14f5fb63bbdd5a7d398e7f494a6b29c5d4edb Mon Sep 17 00:00:00 2001
From: Sirui Mu <msrlancern at gmail.com>
Date: Tue, 24 Sep 2024 20:51:38 +0800
Subject: [PATCH 3/7] parse empty operand bundles list

---
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp |  3 +++
 mlir/test/Target/LLVMIR/llvmir.mlir        | 10 ++++++++++
 2 files changed, 13 insertions(+)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index e7550b653d27be..80f6ae7a224c8b 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -289,6 +289,9 @@ static std::optional<ParseResult> parseOpBundles(
   if (p.parseOptionalLSquare())
     return std::nullopt;
 
+  if (succeeded(p.parseOptionalRSquare()))
+    return success();
+
   auto bundleParser = [&] {
     return parseOneOpBundle(p, opBundleOperands, opBundleOperandTypes,
                             opBundleTags);
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 69c16fb369f3bd..c74b4b641eda05 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -2626,6 +2626,16 @@ llvm.func @intel_reqd_sub_group_size() attributes {intel_reqd_sub_group_size = 3
 
 llvm.func @foo()
 
+llvm.func @call_with_empty_opbundle() {
+  llvm.call @foo() [] : () -> ()
+  llvm.return
+}
+
+//      CHECK: define void @call_with_empty_opbundle() {
+// CHECK-NEXT:   call void @foo()
+// CHECK-NEXT:   ret void
+// CHECK-NEXT: }
+
 llvm.func @call_with_opbundle() {
   %0 = llvm.mlir.constant(1 : i32) : i32
   %1 = llvm.mlir.constant(2 : i32) : i32

>From 02df51c63e50a11fc0ceb3c3946d62e1116f2231 Mon Sep 17 00:00:00 2001
From: Sirui Mu <msrlancern at gmail.com>
Date: Tue, 24 Sep 2024 21:02:35 +0800
Subject: [PATCH 4/7] add test for operand bundle verifier

---
 mlir/test/Dialect/LLVMIR/invalid.mlir | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index afe01d3ff89d68..9388d7ef24936e 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1665,3 +1665,18 @@ llvm.func @wrong_number_of_bundle_types() {
   llvm.call @foo() ["tag"(%0 : i32, i32)] : () -> ()
   llvm.return
 }
+
+// -----
+
+llvm.func @foo()
+llvm.func @wrong_number_of_bundle_tags() {
+  %0 = llvm.mlir.constant(0 : i32) : i32
+  %1 = llvm.mlir.constant(1 : i32) : i32
+  // expected-error@+1 {{expected 2 operand bundle tags, but actually got 1}}
+  "llvm.call"(%0, %1) <{ op_bundle_tags = ["tag"] }> {
+    callee = @foo,
+    operandSegmentSizes = array<i32: 0, 2>,
+    op_bundle_sizes = array<i32: 1, 1>
+  } : (i32, i32) -> ()
+  llvm.return
+}

>From 2b79ea9ddcc97422d35906a22a5d7bf3b46788a3 Mon Sep 17 00:00:00 2001
From: Sirui Mu <msrlancern at gmail.com>
Date: Tue, 24 Sep 2024 21:29:34 +0800
Subject: [PATCH 5/7] parse empty operands within a bundle

---
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 14 +++++---------
 mlir/test/Target/LLVMIR/llvmir.mlir        | 10 ++++++++++
 2 files changed, 15 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 80f6ae7a224c8b..4b95e5486e74ac 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -264,15 +264,11 @@ static ParseResult parseOneOpBundle(
   if (p.parseLParen())
     return failure();
 
-  if (p.parseOperandList(operands))
-    return failure();
-  if (p.parseColon())
-    return failure();
-  if (p.parseTypeList(types))
-    return failure();
-
-  if (p.parseRParen())
-    return failure();
+  if (p.parseOptionalRParen()) {
+    if (p.parseOperandList(operands) || p.parseColon() ||
+        p.parseTypeList(types) || p.parseRParen())
+      return failure();
+  }
 
   opBundleOperands.push_back(std::move(operands));
   opBundleOperandTypes.push_back(std::move(types));
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index c74b4b641eda05..89f9c15cf84e04 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -2636,6 +2636,16 @@ llvm.func @call_with_empty_opbundle() {
 // CHECK-NEXT:   ret void
 // CHECK-NEXT: }
 
+llvm.func @call_with_empty_opbundle_operands() {
+  llvm.call @foo() ["tag"()] : () -> ()
+  llvm.return
+}
+
+//      CHECK: define void @call_with_empty_opbundle_operands() {
+// CHECK-NEXT:   call void @foo() [ "tag"() ]
+// CHECK-NEXT:   ret void
+// CHECK-NEXT: }
+
 llvm.func @call_with_opbundle() {
   %0 = llvm.mlir.constant(1 : i32) : i32
   %1 = llvm.mlir.constant(2 : i32) : i32

>From 1c88e5924004a99a46566e268bef7de924bb7d78 Mon Sep 17 00:00:00 2001
From: Sirui Mu <msrlancern at gmail.com>
Date: Tue, 24 Sep 2024 21:32:43 +0800
Subject: [PATCH 6/7] nit: replace llvm::zip with llvm::zip_equal

---
 .../Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index cd4e760c3b4bcf..78a3f1809aec31 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -116,13 +116,10 @@ static SmallVector<llvm::OperandBundleDef>
 convertOperandBundles(OperandRangeRange bundleOperands,
                       ArrayRef<std::string> bundleTags,
                       LLVM::ModuleTranslation &moduleTranslation) {
-  assert(bundleOperands.size() == bundleTags.size() &&
-         "operand bundles and tags do not match");
-
   SmallVector<llvm::OperandBundleDef> bundles;
   bundles.reserve(bundleOperands.size());
 
-  for (auto [operands, tag] : llvm::zip(bundleOperands, bundleTags))
+  for (auto [operands, tag] : llvm::zip_equal(bundleOperands, bundleTags))
     bundles.push_back(convertOperandBundle(operands, tag, moduleTranslation));
   return bundles;
 }

>From 7a54369eb3c9e103924b429527363728d1d14d5c Mon Sep 17 00:00:00 2001
From: Sirui Mu <msrlancern at gmail.com>
Date: Tue, 24 Sep 2024 21:58:14 +0800
Subject: [PATCH 7/7] add roundtrip test for operand bundle syntax

---
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 17 +++--
 mlir/test/Dialect/LLVMIR/roundtrip.mlir    | 83 ++++++++++++++++++++++
 2 files changed, 94 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 4b95e5486e74ac..0561c364c7d591 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -228,9 +228,13 @@ static void printOneOpBundle(OpAsmPrinter &p, OperandRange operands,
                              TypeRange operandTypes, StringRef tag) {
   p.printString(tag);
   p << "(";
-  p.printOperands(operands);
-  p << " : ";
-  llvm::interleaveComma(operandTypes, p);
+
+  if (!operands.empty()) {
+    p.printOperands(operands);
+    p << " : ";
+    llvm::interleaveComma(operandTypes, p);
+  }
+
   p << ")";
 }
 
@@ -1611,7 +1615,7 @@ void InvokeOp::print(OpAsmPrinter &p) {
   else
     p << getOperand(0);
 
-  p << '(' << getOperands().drop_front(isDirect ? 0 : 1) << ')';
+  p << '(' << getCalleeOperands().drop_front(isDirect ? 0 : 1) << ')';
   p << " to ";
   p.printSuccessorAndUseList(getNormalDest(), getNormalDestOperands());
   p << " unwind ";
@@ -1635,8 +1639,9 @@ void InvokeOp::print(OpAsmPrinter &p) {
   p << " : ";
   if (!isDirect)
     p << getOperand(0).getType() << ", ";
-  p.printFunctionalType(llvm::drop_begin(getOperandTypes(), isDirect ? 0 : 1),
-                        getResultTypes());
+  p.printFunctionalType(
+      llvm::drop_begin(getCalleeOperands().getTypes(), isDirect ? 0 : 1),
+      getResultTypes());
 }
 
 // <operation> ::= `llvm.invoke` (cconv)? (function-id | ssa-use)
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index ff16bb0f857dda..cbefa64d55b405 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -729,3 +729,86 @@ llvm.func @test_notail() -> i32 {
   %0 = llvm.call notail @tail_call_target() : () -> i32
   llvm.return %0 : i32
 }
+
+llvm.func @op_bundle_target()
+
+// CHECK-LABEL: @test_call_with_empty_opbundle
+llvm.func @test_call_with_empty_opbundle() {
+  // CHECK: llvm.call @op_bundle_target() : () -> ()
+  llvm.call @op_bundle_target() [] : () -> ()
+  llvm.return
+}
+
+// CHECK-LABEL: @test_call_with_empty_opbundle_operands
+llvm.func @test_call_with_empty_opbundle_operands() {
+  // CHECK: llvm.call @op_bundle_target() ["tag"()] : () -> ()
+  llvm.call @op_bundle_target() ["tag"()] : () -> ()
+  llvm.return
+}
+
+// CHECK-LABEL: @test_call_with_opbundle
+llvm.func @test_call_with_opbundle() {
+  %0 = llvm.mlir.constant(0 : i32) : i32
+  %1 = llvm.mlir.constant(1 : i32) : i32
+  %2 = llvm.mlir.constant(2 : i32) : i32
+  // CHECK: llvm.call @op_bundle_target() ["tag1"(%{{.+}}, %{{.+}} : i32, i32), "tag2"(%{{.+}} : i32)] : () -> ()
+  llvm.call @op_bundle_target() ["tag1"(%0, %1 : i32, i32), "tag2"(%2 : i32)] : () -> ()
+  llvm.return
+}
+
+// CHECK-LABEL: @test_invoke_with_empty_opbundle
+llvm.func @test_invoke_with_empty_opbundle() attributes { personality = @__gxx_personality_v0 } {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.mlir.constant(2 : i32) : i32
+  %2 = llvm.mlir.constant(3 : i32) : i32
+  // CHECK: llvm.invoke @op_bundle_target() to ^{{.+}} unwind ^{{.+}} : () -> ()
+  llvm.invoke @op_bundle_target() to ^bb2 unwind ^bb1 [] : () -> ()
+
+^bb1:
+  %3 = llvm.landingpad cleanup : !llvm.struct<(ptr, i32)>
+  llvm.return
+
+^bb2:
+  llvm.return
+}
+
+// CHECK-LABEL: @test_invoke_with_empty_opbundle_operands
+llvm.func @test_invoke_with_empty_opbundle_operands() attributes { personality = @__gxx_personality_v0 } {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.mlir.constant(2 : i32) : i32
+  %2 = llvm.mlir.constant(3 : i32) : i32
+  // CHECK: llvm.invoke @op_bundle_target() to ^{{.+}} unwind ^{{.+}} ["tag"()] : () -> ()
+  llvm.invoke @op_bundle_target() to ^bb2 unwind ^bb1 ["tag"()] : () -> ()
+
+^bb1:
+  %3 = llvm.landingpad cleanup : !llvm.struct<(ptr, i32)>
+  llvm.return
+
+^bb2:
+  llvm.return
+}
+
+// CHECK-LABEL: @test_invoke_with_opbundle
+llvm.func @test_invoke_with_opbundle() attributes { personality = @__gxx_personality_v0 } {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.mlir.constant(2 : i32) : i32
+  %2 = llvm.mlir.constant(3 : i32) : i32
+  // CHECK: llvm.invoke @op_bundle_target() to ^{{.+}} unwind ^{{.+}} ["tag1"(%{{.+}}, %{{.+}} : i32, i32), "tag2"(%{{.+}} : i32)] : () -> ()
+  llvm.invoke @op_bundle_target() to ^bb2 unwind ^bb1 ["tag1"(%0, %1 : i32, i32), "tag2"(%2 : i32)] : () -> ()
+
+^bb1:
+  %3 = llvm.landingpad cleanup : !llvm.struct<(ptr, i32)>
+  llvm.return
+
+^bb2:
+  llvm.return
+}
+
+// CHECK-LABEL: @test_call_intrin_with_opbundle
+llvm.func @test_call_intrin_with_opbundle(%arg0 : !llvm.ptr) {
+  %0 = llvm.mlir.constant(1 : i1) : i1
+  %1 = llvm.mlir.constant(16 : i32) : i32
+  // CHECK: llvm.call_intrinsic "llvm.assume"(%{{.+}}) ["align"(%{{.+}}, %{{.+}} : !llvm.ptr, i32)] : (i1) -> ()
+  llvm.call_intrinsic "llvm.assume"(%0) ["align"(%arg0, %1 : !llvm.ptr, i32)] : (i1) -> ()
+  llvm.return
+}



More information about the Mlir-commits mailing list