[Mlir-commits] [mlir] 1dfb104 - [mlir][LLVMIR] Add operand bundle support for llvm.intr.assume (#112143)

Sirui Mu llvmlistbot at llvm.org
Wed Oct 16 05:50:03 PDT 2024


Author: Sirui Mu
Date: 2024-10-16T20:49:02+08:00
New Revision: 1dfb104eac73863b06751bea225ffa6ef589577f

URL: https://github.com/llvm/llvm-project/commit/1dfb104eac73863b06751bea225ffa6ef589577f
DIFF: https://github.com/llvm/llvm-project/commit/1dfb104eac73863b06751bea225ffa6ef589577f.diff

LOG: [mlir][LLVMIR] Add operand bundle support for llvm.intr.assume (#112143)

This patch adds operand bundle support for `llvm.intr.assume`.

This patch contains two parts:

- `llvm.intr.assume` now accepts operand-bundle-related attributes and
operands. `llvm.intr.assume` does not constrain which operand bundles it
accepts, although only a few operand bundles are actually meaningful.
I plan to add some of them in future patches (e.g. `aligned` and
`separate_storage`, which are the ones I am interested in; other people
may be interested in other operand bundles as well).

- The definitions of `llvm.call`, `llvm.invoke`, and
`llvm.call_intrinsic` define `op_bundle_tags` as an operation
property. Applying the same approach to the intrinsic operations would
introduce unnecessary burden: properties are not accessible through a
generic `Operation *`, yet the import/export of intrinsics has to operate
on `Operation *`. This PR therefore changes `op_bundle_tags` from a
property to an array attribute.

This patch relands commit d8fadad07c952c4aea967aefb0900e4e43ad0555.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
    mlir/include/mlir/Target/LLVMIR/ModuleImport.h
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
    mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
    mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp
    mlir/lib/Target/LLVMIR/ModuleImport.cpp
    mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
    mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
    mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
    mlir/test/Dialect/LLVMIR/inlining.mlir
    mlir/test/Dialect/LLVMIR/roundtrip.mlir
    mlir/test/Target/LLVMIR/Import/intrinsic.ll
    mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
    mlir/test/Target/LLVMIR/llvmir-invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
index 0e38325f9891ac..e81db32bcaad03 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
@@ -71,6 +71,7 @@ class ArmSME_IntrOp<string mnemonic,
           /*bit requiresAccessGroup=*/0,
           /*bit requiresAliasAnalysis=*/0,
           /*bit requiresFastmath=*/0,
+          /*bit requiresOpBundles=*/0,
           /*list<int> immArgPositions=*/immArgPositions,
           /*list<string> immArgAttrNames=*/immArgAttrNames>;
 

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td
index 27a2b418aadb2a..ea82f7f7b8e124 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td
@@ -59,6 +59,8 @@ def LLVM_Dialect : Dialect {
     static StringRef getStructRetAttrName() { return "llvm.sret"; }
     static StringRef getWriteOnlyAttrName() { return "llvm.writeonly"; }
     static StringRef getZExtAttrName() { return "llvm.zeroext"; }
+    static StringRef getOpBundleSizesAttrName() { return "op_bundle_sizes"; }
+    static StringRef getOpBundleTagsAttrName() { return "op_bundle_tags"; }
     // TODO Restrict the usage of this to parameter attributes once there is an
     // alternative way of modeling memory effects on FunctionOpInterface.
     /// Name of the attribute that will cause the creation of a readnone memory

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index ab40c8ec4b6588..845c88b1be7750 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -120,7 +120,8 @@ def LLVM_Log2Op : LLVM_UnaryIntrOpF<"log2">;
 def LLVM_LogOp : LLVM_UnaryIntrOpF<"log">;
 def LLVM_Prefetch : LLVM_ZeroResultIntrOp<"prefetch", [0],
   /*traits=*/[], /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
-  /*immArgPositions=*/[1, 2, 3], /*immArgAttrNames=*/["rw", "hint", "cache"]
+  /*requiresOpBundles=*/0, /*immArgPositions=*/[1, 2, 3],
+  /*immArgAttrNames=*/["rw", "hint", "cache"]
 > {
   let arguments = (ins LLVM_AnyPointer:$addr, I32Attr:$rw, I32Attr:$hint, I32Attr:$cache);
 }
@@ -176,7 +177,8 @@ class LLVM_MemcpyIntrOpBase<string name> :
      DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
      DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>],
     /*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1,
-    /*immArgPositions=*/[3], /*immArgAttrNames=*/["isVolatile"]> {
+    /*requiresOpBundles=*/0, /*immArgPositions=*/[3],
+    /*immArgAttrNames=*/["isVolatile"]> {
   dag args = (ins Arg<LLVM_AnyPointer,"",[MemWrite]>:$dst,
                   Arg<LLVM_AnyPointer,"",[MemRead]>:$src,
                   AnySignlessInteger:$len, I1Attr:$isVolatile);
@@ -206,7 +208,8 @@ def LLVM_MemcpyInlineOp :
      DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
      DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>],
     /*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1,
-    /*immArgPositions=*/[2, 3], /*immArgAttrNames=*/["len", "isVolatile"]> {
+    /*requiresOpBundles=*/0, /*immArgPositions=*/[2, 3],
+    /*immArgAttrNames=*/["len", "isVolatile"]> {
   dag args = (ins Arg<LLVM_AnyPointer,"",[MemWrite]>:$dst,
                   Arg<LLVM_AnyPointer,"",[MemRead]>:$src,
                   APIntAttr:$len, I1Attr:$isVolatile);
@@ -232,7 +235,8 @@ def LLVM_MemsetOp : LLVM_ZeroResultIntrOp<"memset", [0, 2],
      DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
      DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>],
     /*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1,
-    /*immArgPositions=*/[3], /*immArgAttrNames=*/["isVolatile"]> {
+    /*requiresOpBundles=*/0, /*immArgPositions=*/[3],
+    /*immArgAttrNames=*/["isVolatile"]> {
   dag args = (ins Arg<LLVM_AnyPointer,"",[MemWrite]>:$dst,
                   I8:$val, AnySignlessInteger:$len, I1Attr:$isVolatile);
   // Append the alias attributes defined by LLVM_IntrOpBase.
@@ -286,7 +290,8 @@ def LLVM_NoAliasScopeDeclOp
 class LLVM_LifetimeBaseOp<string opName> : LLVM_ZeroResultIntrOp<opName, [1],
     [DeclareOpInterfaceMethods<PromotableOpInterface>],
     /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
-    /*immArgPositions=*/[0], /*immArgAttrNames=*/["size"]> {
+    /*requiresOpBundles=*/0, /*immArgPositions=*/[0],
+    /*immArgAttrNames=*/["size"]> {
   let arguments = (ins I64Attr:$size, LLVM_AnyPointer:$ptr);
   let assemblyFormat = "$size `,` $ptr attr-dict `:` qualified(type($ptr))";
 }
@@ -306,7 +311,8 @@ def LLVM_InvariantStartOp : LLVM_OneResultIntrOp<"invariant.start", [], [1],
 def LLVM_InvariantEndOp : LLVM_ZeroResultIntrOp<"invariant.end", [2],
     [DeclareOpInterfaceMethods<PromotableOpInterface>],
     /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
-    /*immArgPositions=*/[1], /*immArgAttrNames=*/["size"]> {
+    /*requiresOpBundles=*/0, /*immArgPositions=*/[1],
+    /*immArgAttrNames=*/["size"]> {
   let arguments = (ins LLVM_DefaultPointer:$start,
                        I64Attr:$size,
                        LLVM_AnyPointer:$ptr);
@@ -368,7 +374,7 @@ class LLVM_ConstrainedIntr<string mnem, int numArgs,
     SmallVector<Value> mlirOperands;
     SmallVector<NamedAttribute> mlirAttrs;
     if (failed(moduleImport.convertIntrinsicArguments(
-        llvmOperands.take_front( }] # numArgs # [{),
+        llvmOperands.take_front( }] # numArgs # [{), {}, false,
         {}, {}, mlirOperands, mlirAttrs))) {
       return failure();
     }
@@ -429,7 +435,26 @@ def LLVM_USHLSat : LLVM_BinarySameArgsIntrOpI<"ushl.sat">;
 //
 
 def LLVM_AssumeOp
-  : LLVM_ZeroResultIntrOp<"assume", []>, Arguments<(ins I1:$cond)>;
+    : LLVM_ZeroResultIntrOp<"assume", /*overloadedOperands=*/[], /*traits=*/[],
+                            /*requiresAccessGroup=*/0,
+                            /*requiresAliasAnalysis=*/0,
+                            /*requiresOpBundles=*/1> {
+  dag args = (ins I1:$cond);
+  let arguments = !con(args, opBundleArgs);
+
+  let assemblyFormat = [{
+    $cond
+    ( custom<OpBundles>($op_bundle_operands, type($op_bundle_operands),
+                        $op_bundle_tags)^ )?
+    `:` type($cond) attr-dict
+  }];
+
+  let builders = [
+    OpBuilder<(ins "Value":$cond)>
+  ];
+
+  let hasVerifier = 1;
+}
 
 def LLVM_SSACopyOp : LLVM_OneResultIntrOp<"ssa.copy", [], [0],
                                             [Pure, SameOperandsAndResultType]> {
@@ -992,7 +1017,8 @@ def LLVM_DebugTrap : LLVM_ZeroResultIntrOp<"debugtrap">;
 def LLVM_UBSanTrap : LLVM_ZeroResultIntrOp<"ubsantrap",
   /*overloadedOperands=*/[], /*traits=*/[],
   /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
-  /*immArgPositions=*/[0], /*immArgAttrNames=*/["failureKind"]> {
+  /*requiresOpBundles=*/0, /*immArgPositions=*/[0],
+  /*immArgAttrNames=*/["failureKind"]> {
   let arguments = (ins I8Attr:$failureKind);
 }
 

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index c3d352d8d0dd48..a38dafa4d9cf34 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -291,7 +291,7 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
                       list<int> overloadedResults, list<int> overloadedOperands,
                       list<Trait> traits, int numResults,
                       bit requiresAccessGroup = 0, bit requiresAliasAnalysis = 0,
-                      bit requiresFastmath = 0,
+                      bit requiresFastmath = 0, bit requiresOpBundles = 0,
                       list<int> immArgPositions = [],
                       list<string> immArgAttrNames = []>
     : LLVM_OpBase<dialect, opName, !listconcat(
@@ -313,6 +313,12 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
                  OptionalAttr<LLVM_AliasScopeArrayAttr>:$noalias_scopes,
                  OptionalAttr<LLVM_TBAATagArrayAttr>:$tbaa),
             (ins )));
+  dag opBundleArgs = !if(!gt(requiresOpBundles, 0),
+                         (ins VariadicOfVariadic<LLVM_Type,
+                                "op_bundle_sizes">:$op_bundle_operands,
+                              DenseI32ArrayAttr:$op_bundle_sizes,
+                              OptionalAttr<ArrayAttr>:$op_bundle_tags),
+                         (ins ));
   string llvmEnumName = enumName;
   string overloadedResultsCpp =  "{" # !interleave(overloadedResults, ", ") # "}";
   string overloadedOperandsCpp =  "{" # !interleave(overloadedOperands, ", ") # "}";
@@ -336,6 +342,8 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
     SmallVector<NamedAttribute> mlirAttrs;
     if (failed(moduleImport.convertIntrinsicArguments(
       llvmOperands,
+      llvmOpBundles,
+      }] # !if(!gt(requiresOpBundles, 0), "true", "false") # [{,
       }] # immArgPositionsCpp # [{,
       }] # immArgAttrNamesCpp # [{,
       mlirOperands,
@@ -381,12 +389,14 @@ class LLVM_IntrOp<string mnem, list<int> overloadedResults,
                   list<int> overloadedOperands, list<Trait> traits,
                   int numResults, bit requiresAccessGroup = 0,
                   bit requiresAliasAnalysis = 0, bit requiresFastmath = 0,
+                  bit requiresOpBundles = 0,
                   list<int> immArgPositions = [],
                   list<string> immArgAttrNames = []>
     : LLVM_IntrOpBase<LLVM_Dialect, "intr." # mnem, !subst(".", "_", mnem),
                       overloadedResults, overloadedOperands, traits,
                       numResults, requiresAccessGroup, requiresAliasAnalysis,
-                      requiresFastmath, immArgPositions, immArgAttrNames>;
+                      requiresFastmath, requiresOpBundles, immArgPositions,
+                      immArgAttrNames>;
 
 // Base class for LLVM intrinsic operations returning no results. Places the
 // intrinsic into the LLVM dialect and prefixes its name with "intr.".
@@ -406,11 +416,13 @@ class LLVM_ZeroResultIntrOp<string mnem, list<int> overloadedOperands = [],
                             list<Trait> traits = [],
                             bit requiresAccessGroup = 0,
                             bit requiresAliasAnalysis = 0,
+                            bit requiresOpBundles = 0,
                             list<int> immArgPositions = [],
                             list<string> immArgAttrNames = []>
     : LLVM_IntrOp<mnem, [], overloadedOperands, traits, /*numResults=*/0,
                   requiresAccessGroup, requiresAliasAnalysis,
-                  /*requiresFastMath=*/0, immArgPositions, immArgAttrNames>;
+                  /*requiresFastMath=*/0, requiresOpBundles, immArgPositions,
+                  immArgAttrNames>;
 
 // Base class for LLVM intrinsic operations returning one result. Places the
 // intrinsic into the LLVM dialect and prefixes its name with "intr.". This is
@@ -422,11 +434,12 @@ class LLVM_OneResultIntrOp<string mnem, list<int> overloadedResults = [],
                            list<int> overloadedOperands = [],
                            list<Trait> traits = [],
                            bit requiresFastmath = 0,
-                          list<int> immArgPositions = [],
-                          list<string> immArgAttrNames = []>
+                           list<int> immArgPositions = [],
+                           list<string> immArgAttrNames = []>
     : LLVM_IntrOp<mnem, overloadedResults, overloadedOperands, traits, 1,
                   /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
-                  requiresFastmath, immArgPositions, immArgAttrNames>;
+                  requiresFastmath, /*requiresOpBundles=*/0, immArgPositions,
+                  immArgAttrNames>;
 
 def LLVM_OneResultOpBuilder :
   OpBuilder<(ins "Type":$resultType, "ValueRange":$operands,

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index bbca7bc7286acb..d5def510a904d3 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -559,11 +559,7 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [
                    VariadicOfVariadic<LLVM_Type,
                                       "op_bundle_sizes">:$op_bundle_operands,
                    DenseI32ArrayAttr:$op_bundle_sizes,
-                   DefaultValuedProperty<
-                     ArrayProperty<StringProperty, "operand bundle tags">,
-                     "ArrayRef<std::string>{}",
-                     "SmallVector<std::string>{}"
-                   >:$op_bundle_tags);
+                   OptionalAttr<ArrayAttr>:$op_bundle_tags);
   let results = (outs Optional<LLVM_Type>:$result);
   let successors = (successor AnySuccessor:$normalDest,
                               AnySuccessor:$unwindDest);
@@ -678,11 +674,7 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
                   VariadicOfVariadic<LLVM_Type,
                                      "op_bundle_sizes">:$op_bundle_operands,
                   DenseI32ArrayAttr:$op_bundle_sizes,
-                  DefaultValuedProperty<
-                    ArrayProperty<StringProperty, "operand bundle tags">,
-                    "ArrayRef<std::string>{}",
-                    "SmallVector<std::string>{}"
-                  >:$op_bundle_tags);
+                  OptionalAttr<ArrayAttr>:$op_bundle_tags);
   // Append the aliasing related attributes defined in LLVM_MemAccessOpBase.
   let arguments = !con(args, aliasAttrs);
   let results = (outs Optional<LLVM_Type>:$result);
@@ -1930,11 +1922,7 @@ def LLVM_CallIntrinsicOp
                        VariadicOfVariadic<LLVM_Type,
                                           "op_bundle_sizes">:$op_bundle_operands,
                        DenseI32ArrayAttr:$op_bundle_sizes,
-                       DefaultValuedProperty<
-                         ArrayProperty<StringProperty, "operand bundle tags">,
-                         "ArrayRef<std::string>{}",
-                         "SmallVector<std::string>{}"
-                       >:$op_bundle_tags);
+                       OptionalAttr<ArrayAttr>:$op_bundle_tags);
   let results = (outs Optional<LLVM_Type>:$results);
   let llvmBuilder = [{
     return convertCallLLVMIntrinsicOp(op, builder, moduleTranslation);

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index c40ae4b1016b49..3695708439d91f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -98,7 +98,7 @@ class ROCDL_IntrOp<string mnemonic, list<int> overloadedResults,
   LLVM_IntrOpBase<ROCDL_Dialect,  mnemonic,
     "amdgcn_" # !subst(".", "_", mnemonic), overloadedResults,
     overloadedOperands, traits, numResults, requiresAccessGroup,
-    requiresAliasAnalysis, 0, immArgPositions, immArgAttrNames>;
+    requiresAliasAnalysis, 0, 0, immArgPositions, immArgAttrNames>;
 
 //===----------------------------------------------------------------------===//
 // ROCDL special register op definitions

diff  --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index 9f300bcafea537..bbb7af58d27393 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -243,6 +243,8 @@ class ModuleImport {
   /// corresponding MLIR attribute names.
   LogicalResult
   convertIntrinsicArguments(ArrayRef<llvm::Value *> values,
+                            ArrayRef<llvm::OperandBundleUse> opBundles,
+                            bool requiresOpBundles,
                             ArrayRef<unsigned> immArgPositions,
                             ArrayRef<StringLiteral> immArgAttrNames,
                             SmallVectorImpl<Value> &valuesOut,

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 12ed8cc88ae7b7..cc73878a64ff67 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -241,13 +241,18 @@ static void printOneOpBundle(OpAsmPrinter &p, OperandRange operands,
 static void printOpBundles(OpAsmPrinter &p, Operation *op,
                            OperandRangeRange opBundleOperands,
                            TypeRangeRange opBundleOperandTypes,
-                           ArrayRef<std::string> opBundleTags) {
+                           std::optional<ArrayAttr> opBundleTags) {
+  if (opBundleOperands.empty())
+    return;
+  assert(opBundleTags && "expect operand bundle tags");
+
   p << "[";
   llvm::interleaveComma(
-      llvm::zip(opBundleOperands, opBundleOperandTypes, opBundleTags), p,
+      llvm::zip(opBundleOperands, opBundleOperandTypes, *opBundleTags), p,
       [&p](auto bundle) {
+        auto bundleTag = cast<StringAttr>(std::get<2>(bundle)).getValue();
         printOneOpBundle(p, std::get<0>(bundle), std::get<1>(bundle),
-                         std::get<2>(bundle));
+                         bundleTag);
       });
   p << "]";
 }
@@ -256,7 +261,7 @@ static ParseResult parseOneOpBundle(
     OpAsmParser &p,
     SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> &opBundleOperands,
     SmallVector<SmallVector<Type>> &opBundleOperandTypes,
-    SmallVector<std::string> &opBundleTags) {
+    SmallVector<Attribute> &opBundleTags) {
   SMLoc currentParserLoc = p.getCurrentLocation();
   SmallVector<OpAsmParser::UnresolvedOperand> operands;
   SmallVector<Type> types;
@@ -276,7 +281,7 @@ static ParseResult parseOneOpBundle(
 
   opBundleOperands.push_back(std::move(operands));
   opBundleOperandTypes.push_back(std::move(types));
-  opBundleTags.push_back(std::move(tag));
+  opBundleTags.push_back(StringAttr::get(p.getContext(), tag));
 
   return success();
 }
@@ -285,16 +290,17 @@ static std::optional<ParseResult> parseOpBundles(
     OpAsmParser &p,
     SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> &opBundleOperands,
     SmallVector<SmallVector<Type>> &opBundleOperandTypes,
-    SmallVector<std::string> &opBundleTags) {
+    ArrayAttr &opBundleTags) {
   if (p.parseOptionalLSquare())
     return std::nullopt;
 
   if (succeeded(p.parseOptionalRSquare()))
     return success();
 
+  SmallVector<Attribute> opBundleTagAttrs;
   auto bundleParser = [&] {
     return parseOneOpBundle(p, opBundleOperands, opBundleOperandTypes,
-                            opBundleTags);
+                            opBundleTagAttrs);
   };
   if (p.parseCommaSeparatedList(bundleParser))
     return failure();
@@ -302,6 +308,8 @@ static std::optional<ParseResult> parseOpBundles(
   if (p.parseRSquare())
     return failure();
 
+  opBundleTags = ArrayAttr::get(p.getContext(), opBundleTagAttrs);
+
   return success();
 }
 
@@ -1039,7 +1047,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
         /*CConv=*/nullptr, /*TailCallKind=*/nullptr,
         /*memory_effects=*/nullptr,
         /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
-        /*op_bundle_operands=*/{}, /*op_bundle_tags=*/std::nullopt,
+        /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
         /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
 }
@@ -1066,7 +1074,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
         /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
         /*convergent=*/nullptr,
         /*no_unwind=*/nullptr, /*will_return=*/nullptr,
-        /*op_bundle_operands=*/{}, /*op_bundle_tags=*/std::nullopt,
+        /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
         /*access_groups=*/nullptr,
         /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
 }
@@ -1079,7 +1087,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
         /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
         /*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
         /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
-        /*op_bundle_operands=*/{}, /*op_bundle_tags=*/std::nullopt,
+        /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
         /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
 }
@@ -1092,7 +1100,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
         /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
         /*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
         /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
-        /*op_bundle_operands=*/{}, /*op_bundle_tags=*/std::nullopt,
+        /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
         /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
 }
@@ -1192,12 +1200,20 @@ LogicalResult verifyCallOpVarCalleeType(OpTy callOp) {
 template <typename OpType>
 static LogicalResult verifyOperandBundles(OpType &op) {
   OperandRangeRange opBundleOperands = op.getOpBundleOperands();
-  ArrayRef<std::string> opBundleTags = op.getOpBundleTags();
+  std::optional<ArrayAttr> opBundleTags = op.getOpBundleTags();
 
-  if (opBundleTags.size() != opBundleOperands.size())
+  auto isStringAttr = [](Attribute tagAttr) {
+    return isa<StringAttr>(tagAttr);
+  };
+  if (opBundleTags && !llvm::all_of(*opBundleTags, isStringAttr))
+    return op.emitError("operand bundle tag must be a StringAttr");
+
+  size_t numOpBundles = opBundleOperands.size();
+  size_t numOpBundleTags = opBundleTags ? opBundleTags->size() : 0;
+  if (numOpBundles != numOpBundleTags)
     return op.emitError("expected ")
-           << opBundleOperands.size()
-           << " operand bundle tags, but actually got " << opBundleTags.size();
+           << numOpBundles << " operand bundle tags, but actually got "
+           << numOpBundleTags;
 
   return success();
 }
@@ -1329,7 +1345,8 @@ void CallOp::print(OpAsmPrinter &p) {
                           {getCalleeAttrName(), getTailCallKindAttrName(),
                            getVarCalleeTypeAttrName(), getCConvAttrName(),
                            getOperandSegmentSizesAttrName(),
-                           getOpBundleSizesAttrName()});
+                           getOpBundleSizesAttrName(),
+                           getOpBundleTagsAttrName()});
 
   p << " : ";
   if (!isDirect)
@@ -1437,7 +1454,7 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
   SmallVector<OpAsmParser::UnresolvedOperand> operands;
   SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> opBundleOperands;
   SmallVector<SmallVector<Type>> opBundleOperandTypes;
-  SmallVector<std::string> opBundleTags;
+  ArrayAttr opBundleTags;
 
   // Default to C Calling Convention if no keyword is provided.
   result.addAttribute(
@@ -1483,9 +1500,9 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
           parser, opBundleOperands, opBundleOperandTypes, opBundleTags);
       result && failed(*result))
     return failure();
-  if (!opBundleTags.empty())
-    result.getOrAddProperties<CallOp::Properties>().op_bundle_tags =
-        std::move(opBundleTags);
+  if (opBundleTags && !opBundleTags.empty())
+    result.addAttribute(CallOp::getOpBundleTagsAttrName(result.name).getValue(),
+                        opBundleTags);
 
   if (parser.parseOptionalAttrDict(result.attributes))
     return failure();
@@ -1525,8 +1542,7 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
   auto calleeType = func.getFunctionType();
   build(builder, state, getCallOpResultTypes(calleeType),
         getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), ops,
-        normalOps, unwindOps, nullptr, nullptr, {}, std::nullopt, normal,
-        unwind);
+        normalOps, unwindOps, nullptr, nullptr, {}, {}, normal, unwind);
 }
 
 void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
@@ -1535,7 +1551,7 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
                      ValueRange unwindOps) {
   build(builder, state, tys,
         /*var_callee_type=*/nullptr, callee, ops, normalOps, unwindOps, nullptr,
-        nullptr, {}, std::nullopt, normal, unwind);
+        nullptr, {}, {}, normal, unwind);
 }
 
 void InvokeOp::build(OpBuilder &builder, OperationState &state,
@@ -1544,7 +1560,7 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state,
                      Block *unwind, ValueRange unwindOps) {
   build(builder, state, getCallOpResultTypes(calleeType),
         getCallOpVarCalleeType(calleeType), callee, ops, normalOps, unwindOps,
-        nullptr, nullptr, {}, std::nullopt, normal, unwind);
+        nullptr, nullptr, {}, {}, normal, unwind);
 }
 
 SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) {
@@ -1634,7 +1650,8 @@ void InvokeOp::print(OpAsmPrinter &p) {
   p.printOptionalAttrDict((*this)->getAttrs(),
                           {getCalleeAttrName(), getOperandSegmentSizeAttr(),
                            getCConvAttrName(), getVarCalleeTypeAttrName(),
-                           getOpBundleSizesAttrName()});
+                           getOpBundleSizesAttrName(),
+                           getOpBundleTagsAttrName()});
 
   p << " : ";
   if (!isDirect)
@@ -1657,7 +1674,7 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
   TypeAttr varCalleeType;
   SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> opBundleOperands;
   SmallVector<SmallVector<Type>> opBundleOperandTypes;
-  SmallVector<std::string> opBundleTags;
+  ArrayAttr opBundleTags;
   Block *normalDest, *unwindDest;
   SmallVector<Value, 4> normalOperands, unwindOperands;
   Builder &builder = parser.getBuilder();
@@ -1703,9 +1720,10 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
           parser, opBundleOperands, opBundleOperandTypes, opBundleTags);
       result && failed(*result))
     return failure();
-  if (!opBundleTags.empty())
-    result.getOrAddProperties<InvokeOp::Properties>().op_bundle_tags =
-        std::move(opBundleTags);
+  if (opBundleTags && !opBundleTags.empty())
+    result.addAttribute(
+        InvokeOp::getOpBundleTagsAttrName(result.name).getValue(),
+        opBundleTags);
 
   if (parser.parseOptionalAttrDict(result.attributes))
     return failure();
@@ -3333,7 +3351,7 @@ void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
                             mlir::StringAttr intrin, mlir::ValueRange args) {
   build(builder, state, /*resultTypes=*/TypeRange{}, intrin, args,
         FastmathFlagsAttr{},
-        /*op_bundle_operands=*/{});
+        /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{});
 }
 
 void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
@@ -3341,14 +3359,14 @@ void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
                             mlir::LLVM::FastmathFlagsAttr fastMathFlags) {
   build(builder, state, /*resultTypes=*/TypeRange{}, intrin, args,
         fastMathFlags,
-        /*op_bundle_operands=*/{});
+        /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{});
 }
 
 void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
                             mlir::Type resultType, mlir::StringAttr intrin,
                             mlir::ValueRange args) {
   build(builder, state, {resultType}, intrin, args, FastmathFlagsAttr{},
-        /*op_bundle_operands=*/{});
+        /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{});
 }
 
 void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
@@ -3356,7 +3374,7 @@ void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
                             mlir::StringAttr intrin, mlir::ValueRange args,
                             mlir::LLVM::FastmathFlagsAttr fastMathFlags) {
   build(builder, state, resultTypes, intrin, args, fastMathFlags,
-        /*op_bundle_operands=*/{});
+        /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{});
 }
 
 //===----------------------------------------------------------------------===//
@@ -3413,6 +3431,18 @@ void InlineAsmOp::getEffects(
   }
 }
 
+//===----------------------------------------------------------------------===//
+// AssumeOp (intrinsic)
+//===----------------------------------------------------------------------===//
+
+void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state,
+                           mlir::Value cond) {
+  return build(builder, state, cond, /*op_bundle_operands=*/{},
+               /*op_bundle_tags=*/{});
+}
+
+LogicalResult LLVM::AssumeOp::verify() { return verifyOperandBundles(*this); }
+
 //===----------------------------------------------------------------------===//
 // masked_gather (intrinsic)
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
index d034e576dfc579..4fd043c7c93e68 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
@@ -68,6 +68,12 @@ static LogicalResult convertIntrinsicImpl(OpBuilder &odsBuilder,
   if (isConvertibleIntrinsic(intrinsicID)) {
     SmallVector<llvm::Value *> args(inst->args());
     ArrayRef<llvm::Value *> llvmOperands(args);
+
+    SmallVector<llvm::OperandBundleUse> llvmOpBundles;
+    llvmOpBundles.reserve(inst->getNumOperandBundles());
+    for (unsigned i = 0; i < inst->getNumOperandBundles(); ++i)
+      llvmOpBundles.push_back(inst->getOperandBundleAt(i));
+
 #include "mlir/Dialect/LLVMIR/LLVMIntrinsicFromLLVMIRConversions.inc"
   }
 

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index a8595d14ccf2e5..2084e527773ca8 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -114,17 +114,27 @@ convertOperandBundle(OperandRange bundleOperands, StringRef bundleTag,
 }
 
 static SmallVector<llvm::OperandBundleDef>
-convertOperandBundles(OperandRangeRange bundleOperands,
-                      ArrayRef<std::string> bundleTags,
+convertOperandBundles(OperandRangeRange bundleOperands, ArrayAttr bundleTags,
                       LLVM::ModuleTranslation &moduleTranslation) {
   SmallVector<llvm::OperandBundleDef> bundles;
   bundles.reserve(bundleOperands.size());
 
-  for (auto [operands, tag] : llvm::zip_equal(bundleOperands, bundleTags))
+  for (auto [operands, tagAttr] : llvm::zip_equal(bundleOperands, bundleTags)) {
+    StringRef tag = cast<StringAttr>(tagAttr).getValue();
     bundles.push_back(convertOperandBundle(operands, tag, moduleTranslation));
+  }
   return bundles;
 }
 
+static SmallVector<llvm::OperandBundleDef>
+convertOperandBundles(OperandRangeRange bundleOperands,
+                      std::optional<ArrayAttr> bundleTags,
+                      LLVM::ModuleTranslation &moduleTranslation) {
+  if (!bundleTags)
+    return {};
+  return convertOperandBundles(bundleOperands, *bundleTags, moduleTranslation);
+}
+
 /// Builder for LLVM_CallIntrinsicOp
 static LogicalResult
 convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder,

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp
index bc830a77f3c580..2c0b665ad0d833 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp
@@ -50,6 +50,12 @@ static LogicalResult convertIntrinsicImpl(OpBuilder &odsBuilder,
   if (isConvertibleIntrinsic(intrinsicID)) {
     SmallVector<llvm::Value *> args(inst->args());
     ArrayRef<llvm::Value *> llvmOperands(args);
+
+    SmallVector<llvm::OperandBundleUse> llvmOpBundles;
+    llvmOpBundles.reserve(inst->getNumOperandBundles());
+    for (unsigned i = 0; i < inst->getNumOperandBundles(); ++i)
+      llvmOpBundles.push_back(inst->getOperandBundleAt(i));
+
 #include "mlir/Dialect/LLVMIR/NVVMFromLLVMIRConversions.inc"
   }
 

diff  --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index bd861f3a69e53c..6e97b2a50af8a1 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1311,7 +1311,8 @@ ModuleImport::convertValues(ArrayRef<llvm::Value *> values) {
 }
 
 LogicalResult ModuleImport::convertIntrinsicArguments(
-    ArrayRef<llvm::Value *> values, ArrayRef<unsigned> immArgPositions,
+    ArrayRef<llvm::Value *> values, ArrayRef<llvm::OperandBundleUse> opBundles,
+    bool requiresOpBundles, ArrayRef<unsigned> immArgPositions,
     ArrayRef<StringLiteral> immArgAttrNames, SmallVectorImpl<Value> &valuesOut,
     SmallVectorImpl<NamedAttribute> &attrsOut) {
   assert(immArgPositions.size() == immArgAttrNames.size() &&
@@ -1341,6 +1342,35 @@ LogicalResult ModuleImport::convertIntrinsicArguments(
     valuesOut.push_back(*mlirValue);
   }
 
+  SmallVector<int> opBundleSizes;
+  SmallVector<Attribute> opBundleTagAttrs;
+  if (requiresOpBundles) {
+    opBundleSizes.reserve(opBundles.size());
+    opBundleTagAttrs.reserve(opBundles.size());
+
+    for (const llvm::OperandBundleUse &bundle : opBundles) {
+      opBundleSizes.push_back(bundle.Inputs.size());
+      opBundleTagAttrs.push_back(StringAttr::get(context, bundle.getTagName()));
+
+      for (const llvm::Use &opBundleOperand : bundle.Inputs) {
+        auto operandMlirValue = convertValue(opBundleOperand.get());
+        if (failed(operandMlirValue))
+          return failure();
+        valuesOut.push_back(*operandMlirValue);
+      }
+    }
+
+    auto opBundleSizesAttr = DenseI32ArrayAttr::get(context, opBundleSizes);
+    auto opBundleSizesAttrNameAttr =
+        StringAttr::get(context, LLVMDialect::getOpBundleSizesAttrName());
+    attrsOut.push_back({opBundleSizesAttrNameAttr, opBundleSizesAttr});
+
+    auto opBundleTagsAttr = ArrayAttr::get(context, opBundleTagAttrs);
+    auto opBundleTagsAttrNameAttr =
+        StringAttr::get(context, LLVMDialect::getOpBundleTagsAttrName());
+    attrsOut.push_back({opBundleTagsAttrNameAttr, opBundleTagsAttr});
+  }
+
   return success();
 }
 

diff  --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 6e005f9ec5df85..ceb8ba3b33818b 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -55,6 +55,7 @@
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/Cloning.h"
 #include "llvm/Transforms/Utils/ModuleUtils.h"
+#include <numeric>
 #include <optional>
 
 #define DEBUG_TYPE "llvm-dialect-to-llvm-ir"
@@ -854,8 +855,40 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
          "LLVM `immArgPositions` and MLIR `immArgAttrNames` should have equal "
          "length");
 
+  SmallVector<llvm::OperandBundleDef> opBundles;
+  size_t numOpBundleOperands = 0;
+  auto opBundleSizesAttr = cast_if_present<DenseI32ArrayAttr>(
+      intrOp->getAttr(LLVMDialect::getOpBundleSizesAttrName()));
+  auto opBundleTagsAttr = cast_if_present<ArrayAttr>(
+      intrOp->getAttr(LLVMDialect::getOpBundleTagsAttrName()));
+
+  if (opBundleSizesAttr && opBundleTagsAttr) {
+    ArrayRef<int> opBundleSizes = opBundleSizesAttr.asArrayRef();
+    assert(opBundleSizes.size() == opBundleTagsAttr.size() &&
+           "operand bundles and tags do not match");
+
+    numOpBundleOperands =
+        std::accumulate(opBundleSizes.begin(), opBundleSizes.end(), size_t(0));
+    assert(numOpBundleOperands <= intrOp->getNumOperands() &&
+           "operand bundle operands is more than the number of operands");
+
+    ValueRange operands = intrOp->getOperands().take_back(numOpBundleOperands);
+    size_t nextOperandIdx = 0;
+    opBundles.reserve(opBundleSizesAttr.size());
+
+    for (auto [opBundleTagAttr, bundleSize] :
+         llvm::zip(opBundleTagsAttr, opBundleSizes)) {
+      auto bundleTag = cast<StringAttr>(opBundleTagAttr).str();
+      auto bundleOperands = moduleTranslation.lookupValues(
+          operands.slice(nextOperandIdx, bundleSize));
+      opBundles.emplace_back(std::move(bundleTag), std::move(bundleOperands));
+      nextOperandIdx += bundleSize;
+    }
+  }
+
   // Map operands and attributes to LLVM values.
-  auto operands = moduleTranslation.lookupValues(intrOp->getOperands());
+  auto opOperands = intrOp->getOperands().drop_back(numOpBundleOperands);
+  auto operands = moduleTranslation.lookupValues(opOperands);
   SmallVector<llvm::Value *> args(immArgPositions.size() + operands.size());
   for (auto [immArgPos, immArgName] :
        llvm::zip(immArgPositions, immArgAttrNames)) {
@@ -890,7 +923,7 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
   llvm::Function *llvmIntr = llvm::Intrinsic::getOrInsertDeclaration(
       module, intrinsic, overloadedTypes);
 
-  return builder.CreateCall(llvmIntr, args);
+  return builder.CreateCall(llvmIntr, args, opBundles);
 }
 
 /// Given a single MLIR operation, create the corresponding LLVM IR operation

diff  --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
index b86103422b0745..55b1bc9c545a85 100644
--- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
@@ -684,7 +684,7 @@ func.func @collapse_static_shape_with_non_identity_layout(%arg: memref<1x1x8x8xf
 // CHECK: %[[INT_TO_PTR:.*]] = llvm.ptrtoint %[[BUFF_ADDR]] : !llvm.ptr to i64
 // CHECK: %[[AND:.*]] = llvm.and %[[INT_TO_PTR]], {{.*}}  : i64
 // CHECK: %[[CMP:.*]] = llvm.icmp "eq" %[[AND]], {{.*}} : i64
-// CHECK: "llvm.intr.assume"(%[[CMP]]) : (i1) -> ()
+// CHECK: llvm.intr.assume %[[CMP]] : i1
 // CHECK: %[[LD_ADDR:.*]] = llvm.getelementptr %[[BUFF_ADDR]][%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32
 // CHECK: %[[VAL:.*]] = llvm.load %[[LD_ADDR]] : !llvm.ptr -> f32
 // CHECK: return %[[VAL]] : f32

diff  --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 9dc22abf143bf0..48dc9079333d4f 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -160,7 +160,7 @@ func.func @assume_alignment(%0 : memref<4x4xf16>) {
   // CHECK-NEXT: %[[INT:.*]] = llvm.ptrtoint %[[PTR]] : !llvm.ptr to i64
   // CHECK-NEXT: %[[MASKED_PTR:.*]] = llvm.and %[[INT]], %[[MASK:.*]] : i64
   // CHECK-NEXT: %[[CONDITION:.*]] = llvm.icmp "eq" %[[MASKED_PTR]], %[[ZERO]] : i64
-  // CHECK-NEXT: "llvm.intr.assume"(%[[CONDITION]]) : (i1) -> ()
+  // CHECK-NEXT: llvm.intr.assume %[[CONDITION]] : i1
   memref.assume_alignment %0, 16 : memref<4x4xf16>
   return
 }
@@ -177,7 +177,7 @@ func.func @assume_alignment_w_offset(%0 : memref<4x4xf16, strided<[?, ?], offset
   // CHECK-NEXT: %[[INT:.*]] = llvm.ptrtoint %[[BUFF_ADDR]] : !llvm.ptr to i64
   // CHECK-NEXT: %[[MASKED_PTR:.*]] = llvm.and %[[INT]], %[[MASK:.*]] : i64
   // CHECK-NEXT: %[[CONDITION:.*]] = llvm.icmp "eq" %[[MASKED_PTR]], %[[ZERO]] : i64
-  // CHECK-NEXT: "llvm.intr.assume"(%[[CONDITION]]) : (i1) -> ()
+  // CHECK-NEXT: llvm.intr.assume %[[CONDITION]] : i1
   memref.assume_alignment %0, 16 : memref<4x4xf16, strided<[?, ?], offset: ?>>
   return
 }

diff  --git a/mlir/test/Dialect/LLVMIR/inlining.mlir b/mlir/test/Dialect/LLVMIR/inlining.mlir
index f9551e311df59f..0b7ca3f2bb048a 100644
--- a/mlir/test/Dialect/LLVMIR/inlining.mlir
+++ b/mlir/test/Dialect/LLVMIR/inlining.mlir
@@ -18,7 +18,7 @@ func.func @inner_func_inlinable(%ptr : !llvm.ptr) -> i32 {
   "llvm.intr.memset"(%ptr, %byte, %0) <{isVolatile = true}> : (!llvm.ptr, i8, i32) -> ()
   "llvm.intr.memmove"(%ptr, %ptr, %0) <{isVolatile = true}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
   "llvm.intr.memcpy"(%ptr, %ptr, %0) <{isVolatile = true}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
-  "llvm.intr.assume"(%true) : (i1) -> ()
+  llvm.intr.assume %true : i1
   llvm.fence release
   %2 = llvm.atomicrmw add %ptr, %0 monotonic : !llvm.ptr, i32
   %3 = llvm.cmpxchg %ptr, %0, %1 acq_rel monotonic : !llvm.ptr, i32
@@ -44,7 +44,7 @@ func.func @inner_func_inlinable(%ptr : !llvm.ptr) -> i32 {
 // CHECK: "llvm.intr.memset"(%[[PTR]]
 // CHECK: "llvm.intr.memmove"(%[[PTR]], %[[PTR]]
 // CHECK: "llvm.intr.memcpy"(%[[PTR]], %[[PTR]]
-// CHECK: "llvm.intr.assume"
+// CHECK: llvm.intr.assume
 // CHECK: llvm.fence release
 // CHECK: llvm.atomicrmw add %[[PTR]], %[[CST]] monotonic
 // CHECK: llvm.cmpxchg %[[PTR]], %[[CST]], %[[RES]] acq_rel monotonic

diff  --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index 3062cdc38c0abb..b8ce7db795a1d1 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -836,3 +836,30 @@ llvm.func @test_call_intrin_with_opbundle(%arg0 : !llvm.ptr) {
   llvm.call_intrinsic "llvm.assume"(%0) ["align"(%arg0, %1 : !llvm.ptr, i32)] : (i1) -> ()
   llvm.return
 }
+
+// CHECK-LABEL: @test_assume_intr_no_opbundle
+llvm.func @test_assume_intr_no_opbundle(%arg0 : !llvm.ptr) {
+  %0 = llvm.mlir.constant(1 : i1) : i1
+  // CHECK: llvm.intr.assume %0 : i1
+  llvm.intr.assume %0 : i1
+  llvm.return
+}
+
+// CHECK-LABEL: @test_assume_intr_empty_opbundle
+llvm.func @test_assume_intr_empty_opbundle(%arg0 : !llvm.ptr) {
+  %0 = llvm.mlir.constant(1 : i1) : i1
+  // CHECK: llvm.intr.assume %0 : i1
+  llvm.intr.assume %0 [] : i1
+  llvm.return
+}
+
+// CHECK-LABEL: @test_assume_intr_with_opbundles
+llvm.func @test_assume_intr_with_opbundles(%arg0 : !llvm.ptr) {
+  %0 = llvm.mlir.constant(1 : i1) : i1
+  %1 = llvm.mlir.constant(2 : i32) : i32
+  %2 = llvm.mlir.constant(3 : i32) : i32
+  %3 = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: llvm.intr.assume %0 ["tag1"(%1, %2 : i32, i32), "tag2"(%3 : i32)] : i1
+  llvm.intr.assume %0 ["tag1"(%1, %2 : i32, i32), "tag2"(%3 : i32)] : i1
+  llvm.return
+}

diff  --git a/mlir/test/Target/LLVMIR/Import/intrinsic.ll b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
index 28a1bd21c82a38..606b11175f572f 100644
--- a/mlir/test/Target/LLVMIR/Import/intrinsic.ll
+++ b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
@@ -630,11 +630,21 @@ define void @va_intrinsics_test(ptr %0, ptr %1, ...) {
 ; CHECK-LABEL: @assume
 ; CHECK-SAME:  %[[TRUE:[a-zA-Z0-9]+]]
 define void @assume(i1 %true) {
-  ; CHECK:  "llvm.intr.assume"(%[[TRUE]]) : (i1) -> ()
+  ; CHECK:  llvm.intr.assume %[[TRUE]] : i1
   call void @llvm.assume(i1 %true)
   ret void
 }
 
+; CHECK-LABEL: @assume_with_opbundles
+; CHECK-SAME:  %[[TRUE:[a-zA-Z0-9]+]]
+; CHECK-SAME:  %[[PTR:[a-zA-Z0-9]+]]
+define void @assume_with_opbundles(i1 %true, ptr %p) {
+  ; CHECK: %[[ALIGN:.+]] = llvm.mlir.constant(8 : i32) : i32
+  ; CHECK:  llvm.intr.assume %[[TRUE]] ["align"(%[[PTR]], %[[ALIGN]] : !llvm.ptr, i32)] : i1
+  call void @llvm.assume(i1 %true) ["align"(ptr %p, i32 8)]
+  ret void
+}
+
 ; CHECK-LABEL: @is_constant
 ; CHECK-SAME:  %[[VAL:[a-zA-Z0-9]+]]
 define void @is_constant(i32 %0) {

diff  --git a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
index 0634a7ba907f1e..cb712eb4e1262d 100644
--- a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
@@ -363,6 +363,21 @@ llvm.func @umin_test(%arg0: i32, %arg1: i32, %arg2: vector<8xi32>, %arg3: vector
   llvm.return
 }
 
+// CHECK-LABEL: @assume_without_opbundles
+llvm.func @assume_without_opbundles(%cond: i1) {
+  // CHECK: call void @llvm.assume(i1 %{{.+}})
+  llvm.intr.assume %cond : i1
+  llvm.return
+}
+
+// CHECK-LABEL: @assume_with_opbundles
+llvm.func @assume_with_opbundles(%cond: i1, %p: !llvm.ptr) {
+  %0 = llvm.mlir.constant(8 : i32) : i32
+  // CHECK: call void @llvm.assume(i1 %{{.+}}) [ "align"(ptr %{{.+}}, i32 8) ]
+  llvm.intr.assume %cond ["align"(%p, %0 : !llvm.ptr, i32)] : i1
+  llvm.return
+}
+
 // CHECK-LABEL: @vector_reductions
 llvm.func @vector_reductions(%arg0: f32, %arg1: vector<8xf32>, %arg2: vector<8xi32>) {
   // CHECK: call i32 @llvm.vector.reduce.add.v8i32

diff  --git a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
index af0981440a1776..15658ea6068121 100644
--- a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
@@ -188,7 +188,7 @@ llvm.func @sadd_overflow_intr_wrong_type(%arg0 : i32, %arg1 : f32) -> !llvm.stru
 
 llvm.func @assume_intr_wrong_type(%cond : i16) {
   // expected-error @below{{op operand #0 must be 1-bit signless integer, but got 'i16'}}
-  "llvm.intr.assume"(%cond) : (i16) -> ()
+  llvm.intr.assume %cond : i16
   llvm.return
 }
 


        


More information about the Mlir-commits mailing list