[Mlir-commits] [mlir] 1dfb104 - [mlir][LLVMIR] Add operand bundle support for llvm.intr.assume (#112143)
Sirui Mu
llvmlistbot at llvm.org
Wed Oct 16 05:50:03 PDT 2024
Author: Sirui Mu
Date: 2024-10-16T20:49:02+08:00
New Revision: 1dfb104eac73863b06751bea225ffa6ef589577f
URL: https://github.com/llvm/llvm-project/commit/1dfb104eac73863b06751bea225ffa6ef589577f
DIFF: https://github.com/llvm/llvm-project/commit/1dfb104eac73863b06751bea225ffa6ef589577f.diff
LOG: [mlir][LLVMIR] Add operand bundle support for llvm.intr.assume (#112143)
This patch adds operand bundle support for `llvm.intr.assume`.
This patch actually contains two parts:
- `llvm.intr.assume` now accepts operand bundle related attributes and
operands. `llvm.intr.assume` does not place any constraints on the operand
bundles, but obviously only a small set of operand bundles is meaningful.
I plan to add some of those (e.g. `aligned` and `separate_storage` are
what interest me but other people may be interested in other operand
bundles as well) in future patches.
- The definitions of `llvm.call`, `llvm.invoke`, and
`llvm.call_intrinsic` actually define `op_bundle_tags` as an operation
property. It turns out that this approach would impose an unnecessary
burden if applied to the intrinsic operations as well, because properties
are not available through `Operation *`, yet we have to operate on
`Operation *` during the import/export of intrinsics. This PR therefore
changes it from a property to an array attribute.
This patch relands commit d8fadad07c952c4aea967aefb0900e4e43ad0555.
Added:
Modified:
mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td
mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
mlir/include/mlir/Target/LLVMIR/ModuleImport.h
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp
mlir/lib/Target/LLVMIR/ModuleImport.cpp
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
mlir/test/Dialect/LLVMIR/inlining.mlir
mlir/test/Dialect/LLVMIR/roundtrip.mlir
mlir/test/Target/LLVMIR/Import/intrinsic.ll
mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
mlir/test/Target/LLVMIR/llvmir-invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
index 0e38325f9891ac..e81db32bcaad03 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
@@ -71,6 +71,7 @@ class ArmSME_IntrOp<string mnemonic,
/*bit requiresAccessGroup=*/0,
/*bit requiresAliasAnalysis=*/0,
/*bit requiresFastmath=*/0,
+ /*bit requiresOpBundles=*/0,
/*list<int> immArgPositions=*/immArgPositions,
/*list<string> immArgAttrNames=*/immArgAttrNames>;
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td
index 27a2b418aadb2a..ea82f7f7b8e124 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td
@@ -59,6 +59,8 @@ def LLVM_Dialect : Dialect {
static StringRef getStructRetAttrName() { return "llvm.sret"; }
static StringRef getWriteOnlyAttrName() { return "llvm.writeonly"; }
static StringRef getZExtAttrName() { return "llvm.zeroext"; }
+ static StringRef getOpBundleSizesAttrName() { return "op_bundle_sizes"; }
+ static StringRef getOpBundleTagsAttrName() { return "op_bundle_tags"; }
// TODO Restrict the usage of this to parameter attributes once there is an
// alternative way of modeling memory effects on FunctionOpInterface.
/// Name of the attribute that will cause the creation of a readnone memory
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index ab40c8ec4b6588..845c88b1be7750 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -120,7 +120,8 @@ def LLVM_Log2Op : LLVM_UnaryIntrOpF<"log2">;
def LLVM_LogOp : LLVM_UnaryIntrOpF<"log">;
def LLVM_Prefetch : LLVM_ZeroResultIntrOp<"prefetch", [0],
/*traits=*/[], /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
- /*immArgPositions=*/[1, 2, 3], /*immArgAttrNames=*/["rw", "hint", "cache"]
+ /*requiresOpBundles=*/0, /*immArgPositions=*/[1, 2, 3],
+ /*immArgAttrNames=*/["rw", "hint", "cache"]
> {
let arguments = (ins LLVM_AnyPointer:$addr, I32Attr:$rw, I32Attr:$hint, I32Attr:$cache);
}
@@ -176,7 +177,8 @@ class LLVM_MemcpyIntrOpBase<string name> :
DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>],
/*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1,
- /*immArgPositions=*/[3], /*immArgAttrNames=*/["isVolatile"]> {
+ /*requiresOpBundles=*/0, /*immArgPositions=*/[3],
+ /*immArgAttrNames=*/["isVolatile"]> {
dag args = (ins Arg<LLVM_AnyPointer,"",[MemWrite]>:$dst,
Arg<LLVM_AnyPointer,"",[MemRead]>:$src,
AnySignlessInteger:$len, I1Attr:$isVolatile);
@@ -206,7 +208,8 @@ def LLVM_MemcpyInlineOp :
DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>],
/*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1,
- /*immArgPositions=*/[2, 3], /*immArgAttrNames=*/["len", "isVolatile"]> {
+ /*requiresOpBundles=*/0, /*immArgPositions=*/[2, 3],
+ /*immArgAttrNames=*/["len", "isVolatile"]> {
dag args = (ins Arg<LLVM_AnyPointer,"",[MemWrite]>:$dst,
Arg<LLVM_AnyPointer,"",[MemRead]>:$src,
APIntAttr:$len, I1Attr:$isVolatile);
@@ -232,7 +235,8 @@ def LLVM_MemsetOp : LLVM_ZeroResultIntrOp<"memset", [0, 2],
DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>],
/*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1,
- /*immArgPositions=*/[3], /*immArgAttrNames=*/["isVolatile"]> {
+ /*requiresOpBundles=*/0, /*immArgPositions=*/[3],
+ /*immArgAttrNames=*/["isVolatile"]> {
dag args = (ins Arg<LLVM_AnyPointer,"",[MemWrite]>:$dst,
I8:$val, AnySignlessInteger:$len, I1Attr:$isVolatile);
// Append the alias attributes defined by LLVM_IntrOpBase.
@@ -286,7 +290,8 @@ def LLVM_NoAliasScopeDeclOp
class LLVM_LifetimeBaseOp<string opName> : LLVM_ZeroResultIntrOp<opName, [1],
[DeclareOpInterfaceMethods<PromotableOpInterface>],
/*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
- /*immArgPositions=*/[0], /*immArgAttrNames=*/["size"]> {
+ /*requiresOpBundles=*/0, /*immArgPositions=*/[0],
+ /*immArgAttrNames=*/["size"]> {
let arguments = (ins I64Attr:$size, LLVM_AnyPointer:$ptr);
let assemblyFormat = "$size `,` $ptr attr-dict `:` qualified(type($ptr))";
}
@@ -306,7 +311,8 @@ def LLVM_InvariantStartOp : LLVM_OneResultIntrOp<"invariant.start", [], [1],
def LLVM_InvariantEndOp : LLVM_ZeroResultIntrOp<"invariant.end", [2],
[DeclareOpInterfaceMethods<PromotableOpInterface>],
/*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
- /*immArgPositions=*/[1], /*immArgAttrNames=*/["size"]> {
+ /*requiresOpBundles=*/0, /*immArgPositions=*/[1],
+ /*immArgAttrNames=*/["size"]> {
let arguments = (ins LLVM_DefaultPointer:$start,
I64Attr:$size,
LLVM_AnyPointer:$ptr);
@@ -368,7 +374,7 @@ class LLVM_ConstrainedIntr<string mnem, int numArgs,
SmallVector<Value> mlirOperands;
SmallVector<NamedAttribute> mlirAttrs;
if (failed(moduleImport.convertIntrinsicArguments(
- llvmOperands.take_front( }] # numArgs # [{),
+ llvmOperands.take_front( }] # numArgs # [{), {}, false,
{}, {}, mlirOperands, mlirAttrs))) {
return failure();
}
@@ -429,7 +435,26 @@ def LLVM_USHLSat : LLVM_BinarySameArgsIntrOpI<"ushl.sat">;
//
def LLVM_AssumeOp
- : LLVM_ZeroResultIntrOp<"assume", []>, Arguments<(ins I1:$cond)>;
+ : LLVM_ZeroResultIntrOp<"assume", /*overloadedOperands=*/[], /*traits=*/[],
+ /*requiresAccessGroup=*/0,
+ /*requiresAliasAnalysis=*/0,
+ /*requiresOpBundles=*/1> {
+ dag args = (ins I1:$cond);
+ let arguments = !con(args, opBundleArgs);
+
+ let assemblyFormat = [{
+ $cond
+ ( custom<OpBundles>($op_bundle_operands, type($op_bundle_operands),
+ $op_bundle_tags)^ )?
+ `:` type($cond) attr-dict
+ }];
+
+ let builders = [
+ OpBuilder<(ins "Value":$cond)>
+ ];
+
+ let hasVerifier = 1;
+}
def LLVM_SSACopyOp : LLVM_OneResultIntrOp<"ssa.copy", [], [0],
[Pure, SameOperandsAndResultType]> {
@@ -992,7 +1017,8 @@ def LLVM_DebugTrap : LLVM_ZeroResultIntrOp<"debugtrap">;
def LLVM_UBSanTrap : LLVM_ZeroResultIntrOp<"ubsantrap",
/*overloadedOperands=*/[], /*traits=*/[],
/*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
- /*immArgPositions=*/[0], /*immArgAttrNames=*/["failureKind"]> {
+ /*requiresOpBundles=*/0, /*immArgPositions=*/[0],
+ /*immArgAttrNames=*/["failureKind"]> {
let arguments = (ins I8Attr:$failureKind);
}
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index c3d352d8d0dd48..a38dafa4d9cf34 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -291,7 +291,7 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
list<int> overloadedResults, list<int> overloadedOperands,
list<Trait> traits, int numResults,
bit requiresAccessGroup = 0, bit requiresAliasAnalysis = 0,
- bit requiresFastmath = 0,
+ bit requiresFastmath = 0, bit requiresOpBundles = 0,
list<int> immArgPositions = [],
list<string> immArgAttrNames = []>
: LLVM_OpBase<dialect, opName, !listconcat(
@@ -313,6 +313,12 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
OptionalAttr<LLVM_AliasScopeArrayAttr>:$noalias_scopes,
OptionalAttr<LLVM_TBAATagArrayAttr>:$tbaa),
(ins )));
+ dag opBundleArgs = !if(!gt(requiresOpBundles, 0),
+ (ins VariadicOfVariadic<LLVM_Type,
+ "op_bundle_sizes">:$op_bundle_operands,
+ DenseI32ArrayAttr:$op_bundle_sizes,
+ OptionalAttr<ArrayAttr>:$op_bundle_tags),
+ (ins ));
string llvmEnumName = enumName;
string overloadedResultsCpp = "{" # !interleave(overloadedResults, ", ") # "}";
string overloadedOperandsCpp = "{" # !interleave(overloadedOperands, ", ") # "}";
@@ -336,6 +342,8 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
SmallVector<NamedAttribute> mlirAttrs;
if (failed(moduleImport.convertIntrinsicArguments(
llvmOperands,
+ llvmOpBundles,
+ }] # !if(!gt(requiresOpBundles, 0), "true", "false") # [{,
}] # immArgPositionsCpp # [{,
}] # immArgAttrNamesCpp # [{,
mlirOperands,
@@ -381,12 +389,14 @@ class LLVM_IntrOp<string mnem, list<int> overloadedResults,
list<int> overloadedOperands, list<Trait> traits,
int numResults, bit requiresAccessGroup = 0,
bit requiresAliasAnalysis = 0, bit requiresFastmath = 0,
+ bit requiresOpBundles = 0,
list<int> immArgPositions = [],
list<string> immArgAttrNames = []>
: LLVM_IntrOpBase<LLVM_Dialect, "intr." # mnem, !subst(".", "_", mnem),
overloadedResults, overloadedOperands, traits,
numResults, requiresAccessGroup, requiresAliasAnalysis,
- requiresFastmath, immArgPositions, immArgAttrNames>;
+ requiresFastmath, requiresOpBundles, immArgPositions,
+ immArgAttrNames>;
// Base class for LLVM intrinsic operations returning no results. Places the
// intrinsic into the LLVM dialect and prefixes its name with "intr.".
@@ -406,11 +416,13 @@ class LLVM_ZeroResultIntrOp<string mnem, list<int> overloadedOperands = [],
list<Trait> traits = [],
bit requiresAccessGroup = 0,
bit requiresAliasAnalysis = 0,
+ bit requiresOpBundles = 0,
list<int> immArgPositions = [],
list<string> immArgAttrNames = []>
: LLVM_IntrOp<mnem, [], overloadedOperands, traits, /*numResults=*/0,
requiresAccessGroup, requiresAliasAnalysis,
- /*requiresFastMath=*/0, immArgPositions, immArgAttrNames>;
+ /*requiresFastMath=*/0, requiresOpBundles, immArgPositions,
+ immArgAttrNames>;
// Base class for LLVM intrinsic operations returning one result. Places the
// intrinsic into the LLVM dialect and prefixes its name with "intr.". This is
@@ -422,11 +434,12 @@ class LLVM_OneResultIntrOp<string mnem, list<int> overloadedResults = [],
list<int> overloadedOperands = [],
list<Trait> traits = [],
bit requiresFastmath = 0,
- list<int> immArgPositions = [],
- list<string> immArgAttrNames = []>
+ list<int> immArgPositions = [],
+ list<string> immArgAttrNames = []>
: LLVM_IntrOp<mnem, overloadedResults, overloadedOperands, traits, 1,
/*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
- requiresFastmath, immArgPositions, immArgAttrNames>;
+ requiresFastmath, /*requiresOpBundles=*/0, immArgPositions,
+ immArgAttrNames>;
def LLVM_OneResultOpBuilder :
OpBuilder<(ins "Type":$resultType, "ValueRange":$operands,
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index bbca7bc7286acb..d5def510a904d3 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -559,11 +559,7 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [
VariadicOfVariadic<LLVM_Type,
"op_bundle_sizes">:$op_bundle_operands,
DenseI32ArrayAttr:$op_bundle_sizes,
- DefaultValuedProperty<
- ArrayProperty<StringProperty, "operand bundle tags">,
- "ArrayRef<std::string>{}",
- "SmallVector<std::string>{}"
- >:$op_bundle_tags);
+ OptionalAttr<ArrayAttr>:$op_bundle_tags);
let results = (outs Optional<LLVM_Type>:$result);
let successors = (successor AnySuccessor:$normalDest,
AnySuccessor:$unwindDest);
@@ -678,11 +674,7 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
VariadicOfVariadic<LLVM_Type,
"op_bundle_sizes">:$op_bundle_operands,
DenseI32ArrayAttr:$op_bundle_sizes,
- DefaultValuedProperty<
- ArrayProperty<StringProperty, "operand bundle tags">,
- "ArrayRef<std::string>{}",
- "SmallVector<std::string>{}"
- >:$op_bundle_tags);
+ OptionalAttr<ArrayAttr>:$op_bundle_tags);
// Append the aliasing related attributes defined in LLVM_MemAccessOpBase.
let arguments = !con(args, aliasAttrs);
let results = (outs Optional<LLVM_Type>:$result);
@@ -1930,11 +1922,7 @@ def LLVM_CallIntrinsicOp
VariadicOfVariadic<LLVM_Type,
"op_bundle_sizes">:$op_bundle_operands,
DenseI32ArrayAttr:$op_bundle_sizes,
- DefaultValuedProperty<
- ArrayProperty<StringProperty, "operand bundle tags">,
- "ArrayRef<std::string>{}",
- "SmallVector<std::string>{}"
- >:$op_bundle_tags);
+ OptionalAttr<ArrayAttr>:$op_bundle_tags);
let results = (outs Optional<LLVM_Type>:$results);
let llvmBuilder = [{
return convertCallLLVMIntrinsicOp(op, builder, moduleTranslation);
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index c40ae4b1016b49..3695708439d91f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -98,7 +98,7 @@ class ROCDL_IntrOp<string mnemonic, list<int> overloadedResults,
LLVM_IntrOpBase<ROCDL_Dialect, mnemonic,
"amdgcn_" # !subst(".", "_", mnemonic), overloadedResults,
overloadedOperands, traits, numResults, requiresAccessGroup,
- requiresAliasAnalysis, 0, immArgPositions, immArgAttrNames>;
+ requiresAliasAnalysis, 0, 0, immArgPositions, immArgAttrNames>;
//===----------------------------------------------------------------------===//
// ROCDL special register op definitions
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index 9f300bcafea537..bbb7af58d27393 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -243,6 +243,8 @@ class ModuleImport {
/// corresponding MLIR attribute names.
LogicalResult
convertIntrinsicArguments(ArrayRef<llvm::Value *> values,
+ ArrayRef<llvm::OperandBundleUse> opBundles,
+ bool requiresOpBundles,
ArrayRef<unsigned> immArgPositions,
ArrayRef<StringLiteral> immArgAttrNames,
SmallVectorImpl<Value> &valuesOut,
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 12ed8cc88ae7b7..cc73878a64ff67 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -241,13 +241,18 @@ static void printOneOpBundle(OpAsmPrinter &p, OperandRange operands,
static void printOpBundles(OpAsmPrinter &p, Operation *op,
OperandRangeRange opBundleOperands,
TypeRangeRange opBundleOperandTypes,
- ArrayRef<std::string> opBundleTags) {
+ std::optional<ArrayAttr> opBundleTags) {
+ if (opBundleOperands.empty())
+ return;
+ assert(opBundleTags && "expect operand bundle tags");
+
p << "[";
llvm::interleaveComma(
- llvm::zip(opBundleOperands, opBundleOperandTypes, opBundleTags), p,
+ llvm::zip(opBundleOperands, opBundleOperandTypes, *opBundleTags), p,
[&p](auto bundle) {
+ auto bundleTag = cast<StringAttr>(std::get<2>(bundle)).getValue();
printOneOpBundle(p, std::get<0>(bundle), std::get<1>(bundle),
- std::get<2>(bundle));
+ bundleTag);
});
p << "]";
}
@@ -256,7 +261,7 @@ static ParseResult parseOneOpBundle(
OpAsmParser &p,
SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> &opBundleOperands,
SmallVector<SmallVector<Type>> &opBundleOperandTypes,
- SmallVector<std::string> &opBundleTags) {
+ SmallVector<Attribute> &opBundleTags) {
SMLoc currentParserLoc = p.getCurrentLocation();
SmallVector<OpAsmParser::UnresolvedOperand> operands;
SmallVector<Type> types;
@@ -276,7 +281,7 @@ static ParseResult parseOneOpBundle(
opBundleOperands.push_back(std::move(operands));
opBundleOperandTypes.push_back(std::move(types));
- opBundleTags.push_back(std::move(tag));
+ opBundleTags.push_back(StringAttr::get(p.getContext(), tag));
return success();
}
@@ -285,16 +290,17 @@ static std::optional<ParseResult> parseOpBundles(
OpAsmParser &p,
SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> &opBundleOperands,
SmallVector<SmallVector<Type>> &opBundleOperandTypes,
- SmallVector<std::string> &opBundleTags) {
+ ArrayAttr &opBundleTags) {
if (p.parseOptionalLSquare())
return std::nullopt;
if (succeeded(p.parseOptionalRSquare()))
return success();
+ SmallVector<Attribute> opBundleTagAttrs;
auto bundleParser = [&] {
return parseOneOpBundle(p, opBundleOperands, opBundleOperandTypes,
- opBundleTags);
+ opBundleTagAttrs);
};
if (p.parseCommaSeparatedList(bundleParser))
return failure();
@@ -302,6 +308,8 @@ static std::optional<ParseResult> parseOpBundles(
if (p.parseRSquare())
return failure();
+ opBundleTags = ArrayAttr::get(p.getContext(), opBundleTagAttrs);
+
return success();
}
@@ -1039,7 +1047,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
/*CConv=*/nullptr, /*TailCallKind=*/nullptr,
/*memory_effects=*/nullptr,
/*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
- /*op_bundle_operands=*/{}, /*op_bundle_tags=*/std::nullopt,
+ /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
/*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
}
@@ -1066,7 +1074,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
/*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
/*convergent=*/nullptr,
/*no_unwind=*/nullptr, /*will_return=*/nullptr,
- /*op_bundle_operands=*/{}, /*op_bundle_tags=*/std::nullopt,
+ /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
/*access_groups=*/nullptr,
/*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
}
@@ -1079,7 +1087,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
/*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
/*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
/*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
- /*op_bundle_operands=*/{}, /*op_bundle_tags=*/std::nullopt,
+ /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
/*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
}
@@ -1092,7 +1100,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
/*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
/*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
/*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
- /*op_bundle_operands=*/{}, /*op_bundle_tags=*/std::nullopt,
+ /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
/*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
}
@@ -1192,12 +1200,20 @@ LogicalResult verifyCallOpVarCalleeType(OpTy callOp) {
template <typename OpType>
static LogicalResult verifyOperandBundles(OpType &op) {
OperandRangeRange opBundleOperands = op.getOpBundleOperands();
- ArrayRef<std::string> opBundleTags = op.getOpBundleTags();
+ std::optional<ArrayAttr> opBundleTags = op.getOpBundleTags();
- if (opBundleTags.size() != opBundleOperands.size())
+ auto isStringAttr = [](Attribute tagAttr) {
+ return isa<StringAttr>(tagAttr);
+ };
+ if (opBundleTags && !llvm::all_of(*opBundleTags, isStringAttr))
+ return op.emitError("operand bundle tag must be a StringAttr");
+
+ size_t numOpBundles = opBundleOperands.size();
+ size_t numOpBundleTags = opBundleTags ? opBundleTags->size() : 0;
+ if (numOpBundles != numOpBundleTags)
return op.emitError("expected ")
- << opBundleOperands.size()
- << " operand bundle tags, but actually got " << opBundleTags.size();
+ << numOpBundles << " operand bundle tags, but actually got "
+ << numOpBundleTags;
return success();
}
@@ -1329,7 +1345,8 @@ void CallOp::print(OpAsmPrinter &p) {
{getCalleeAttrName(), getTailCallKindAttrName(),
getVarCalleeTypeAttrName(), getCConvAttrName(),
getOperandSegmentSizesAttrName(),
- getOpBundleSizesAttrName()});
+ getOpBundleSizesAttrName(),
+ getOpBundleTagsAttrName()});
p << " : ";
if (!isDirect)
@@ -1437,7 +1454,7 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::UnresolvedOperand> operands;
SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> opBundleOperands;
SmallVector<SmallVector<Type>> opBundleOperandTypes;
- SmallVector<std::string> opBundleTags;
+ ArrayAttr opBundleTags;
// Default to C Calling Convention if no keyword is provided.
result.addAttribute(
@@ -1483,9 +1500,9 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
parser, opBundleOperands, opBundleOperandTypes, opBundleTags);
result && failed(*result))
return failure();
- if (!opBundleTags.empty())
- result.getOrAddProperties<CallOp::Properties>().op_bundle_tags =
- std::move(opBundleTags);
+ if (opBundleTags && !opBundleTags.empty())
+ result.addAttribute(CallOp::getOpBundleTagsAttrName(result.name).getValue(),
+ opBundleTags);
if (parser.parseOptionalAttrDict(result.attributes))
return failure();
@@ -1525,8 +1542,7 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
auto calleeType = func.getFunctionType();
build(builder, state, getCallOpResultTypes(calleeType),
getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), ops,
- normalOps, unwindOps, nullptr, nullptr, {}, std::nullopt, normal,
- unwind);
+ normalOps, unwindOps, nullptr, nullptr, {}, {}, normal, unwind);
}
void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
@@ -1535,7 +1551,7 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
ValueRange unwindOps) {
build(builder, state, tys,
/*var_callee_type=*/nullptr, callee, ops, normalOps, unwindOps, nullptr,
- nullptr, {}, std::nullopt, normal, unwind);
+ nullptr, {}, {}, normal, unwind);
}
void InvokeOp::build(OpBuilder &builder, OperationState &state,
@@ -1544,7 +1560,7 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state,
Block *unwind, ValueRange unwindOps) {
build(builder, state, getCallOpResultTypes(calleeType),
getCallOpVarCalleeType(calleeType), callee, ops, normalOps, unwindOps,
- nullptr, nullptr, {}, std::nullopt, normal, unwind);
+ nullptr, nullptr, {}, {}, normal, unwind);
}
SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) {
@@ -1634,7 +1650,8 @@ void InvokeOp::print(OpAsmPrinter &p) {
p.printOptionalAttrDict((*this)->getAttrs(),
{getCalleeAttrName(), getOperandSegmentSizeAttr(),
getCConvAttrName(), getVarCalleeTypeAttrName(),
- getOpBundleSizesAttrName()});
+ getOpBundleSizesAttrName(),
+ getOpBundleTagsAttrName()});
p << " : ";
if (!isDirect)
@@ -1657,7 +1674,7 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
TypeAttr varCalleeType;
SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> opBundleOperands;
SmallVector<SmallVector<Type>> opBundleOperandTypes;
- SmallVector<std::string> opBundleTags;
+ ArrayAttr opBundleTags;
Block *normalDest, *unwindDest;
SmallVector<Value, 4> normalOperands, unwindOperands;
Builder &builder = parser.getBuilder();
@@ -1703,9 +1720,10 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
parser, opBundleOperands, opBundleOperandTypes, opBundleTags);
result && failed(*result))
return failure();
- if (!opBundleTags.empty())
- result.getOrAddProperties<InvokeOp::Properties>().op_bundle_tags =
- std::move(opBundleTags);
+ if (opBundleTags && !opBundleTags.empty())
+ result.addAttribute(
+ InvokeOp::getOpBundleTagsAttrName(result.name).getValue(),
+ opBundleTags);
if (parser.parseOptionalAttrDict(result.attributes))
return failure();
@@ -3333,7 +3351,7 @@ void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
mlir::StringAttr intrin, mlir::ValueRange args) {
build(builder, state, /*resultTypes=*/TypeRange{}, intrin, args,
FastmathFlagsAttr{},
- /*op_bundle_operands=*/{});
+ /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{});
}
void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
@@ -3341,14 +3359,14 @@ void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
mlir::LLVM::FastmathFlagsAttr fastMathFlags) {
build(builder, state, /*resultTypes=*/TypeRange{}, intrin, args,
fastMathFlags,
- /*op_bundle_operands=*/{});
+ /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{});
}
void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
mlir::Type resultType, mlir::StringAttr intrin,
mlir::ValueRange args) {
build(builder, state, {resultType}, intrin, args, FastmathFlagsAttr{},
- /*op_bundle_operands=*/{});
+ /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{});
}
void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
@@ -3356,7 +3374,7 @@ void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
mlir::StringAttr intrin, mlir::ValueRange args,
mlir::LLVM::FastmathFlagsAttr fastMathFlags) {
build(builder, state, resultTypes, intrin, args, fastMathFlags,
- /*op_bundle_operands=*/{});
+ /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{});
}
//===----------------------------------------------------------------------===//
@@ -3413,6 +3431,18 @@ void InlineAsmOp::getEffects(
}
}
+//===----------------------------------------------------------------------===//
+// AssumeOp (intrinsic)
+//===----------------------------------------------------------------------===//
+
+void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state,
+ mlir::Value cond) {
+ return build(builder, state, cond, /*op_bundle_operands=*/{},
+ /*op_bundle_tags=*/{});
+}
+
+LogicalResult LLVM::AssumeOp::verify() { return verifyOperandBundles(*this); }
+
//===----------------------------------------------------------------------===//
// masked_gather (intrinsic)
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
index d034e576dfc579..4fd043c7c93e68 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
@@ -68,6 +68,12 @@ static LogicalResult convertIntrinsicImpl(OpBuilder &odsBuilder,
if (isConvertibleIntrinsic(intrinsicID)) {
SmallVector<llvm::Value *> args(inst->args());
ArrayRef<llvm::Value *> llvmOperands(args);
+
+ SmallVector<llvm::OperandBundleUse> llvmOpBundles;
+ llvmOpBundles.reserve(inst->getNumOperandBundles());
+ for (unsigned i = 0; i < inst->getNumOperandBundles(); ++i)
+ llvmOpBundles.push_back(inst->getOperandBundleAt(i));
+
#include "mlir/Dialect/LLVMIR/LLVMIntrinsicFromLLVMIRConversions.inc"
}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index a8595d14ccf2e5..2084e527773ca8 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -114,17 +114,27 @@ convertOperandBundle(OperandRange bundleOperands, StringRef bundleTag,
}
static SmallVector<llvm::OperandBundleDef>
-convertOperandBundles(OperandRangeRange bundleOperands,
- ArrayRef<std::string> bundleTags,
+convertOperandBundles(OperandRangeRange bundleOperands, ArrayAttr bundleTags,
LLVM::ModuleTranslation &moduleTranslation) {
SmallVector<llvm::OperandBundleDef> bundles;
bundles.reserve(bundleOperands.size());
- for (auto [operands, tag] : llvm::zip_equal(bundleOperands, bundleTags))
+ for (auto [operands, tagAttr] : llvm::zip_equal(bundleOperands, bundleTags)) {
+ StringRef tag = cast<StringAttr>(tagAttr).getValue();
bundles.push_back(convertOperandBundle(operands, tag, moduleTranslation));
+ }
return bundles;
}
+static SmallVector<llvm::OperandBundleDef>
+convertOperandBundles(OperandRangeRange bundleOperands,
+ std::optional<ArrayAttr> bundleTags,
+ LLVM::ModuleTranslation &moduleTranslation) {
+ if (!bundleTags)
+ return {};
+ return convertOperandBundles(bundleOperands, *bundleTags, moduleTranslation);
+}
+
/// Builder for LLVM_CallIntrinsicOp
static LogicalResult
convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder,
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp
index bc830a77f3c580..2c0b665ad0d833 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp
@@ -50,6 +50,12 @@ static LogicalResult convertIntrinsicImpl(OpBuilder &odsBuilder,
if (isConvertibleIntrinsic(intrinsicID)) {
SmallVector<llvm::Value *> args(inst->args());
ArrayRef<llvm::Value *> llvmOperands(args);
+
+ SmallVector<llvm::OperandBundleUse> llvmOpBundles;
+ llvmOpBundles.reserve(inst->getNumOperandBundles());
+ for (unsigned i = 0; i < inst->getNumOperandBundles(); ++i)
+ llvmOpBundles.push_back(inst->getOperandBundleAt(i));
+
#include "mlir/Dialect/LLVMIR/NVVMFromLLVMIRConversions.inc"
}
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index bd861f3a69e53c..6e97b2a50af8a1 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1311,7 +1311,8 @@ ModuleImport::convertValues(ArrayRef<llvm::Value *> values) {
}
LogicalResult ModuleImport::convertIntrinsicArguments(
- ArrayRef<llvm::Value *> values, ArrayRef<unsigned> immArgPositions,
+ ArrayRef<llvm::Value *> values, ArrayRef<llvm::OperandBundleUse> opBundles,
+ bool requiresOpBundles, ArrayRef<unsigned> immArgPositions,
ArrayRef<StringLiteral> immArgAttrNames, SmallVectorImpl<Value> &valuesOut,
SmallVectorImpl<NamedAttribute> &attrsOut) {
assert(immArgPositions.size() == immArgAttrNames.size() &&
@@ -1341,6 +1342,35 @@ LogicalResult ModuleImport::convertIntrinsicArguments(
valuesOut.push_back(*mlirValue);
}
+ SmallVector<int> opBundleSizes;
+ SmallVector<Attribute> opBundleTagAttrs;
+ if (requiresOpBundles) {
+ opBundleSizes.reserve(opBundles.size());
+ opBundleTagAttrs.reserve(opBundles.size());
+
+ for (const llvm::OperandBundleUse &bundle : opBundles) {
+ opBundleSizes.push_back(bundle.Inputs.size());
+ opBundleTagAttrs.push_back(StringAttr::get(context, bundle.getTagName()));
+
+ for (const llvm::Use &opBundleOperand : bundle.Inputs) {
+ auto operandMlirValue = convertValue(opBundleOperand.get());
+ if (failed(operandMlirValue))
+ return failure();
+ valuesOut.push_back(*operandMlirValue);
+ }
+ }
+
+ auto opBundleSizesAttr = DenseI32ArrayAttr::get(context, opBundleSizes);
+ auto opBundleSizesAttrNameAttr =
+ StringAttr::get(context, LLVMDialect::getOpBundleSizesAttrName());
+ attrsOut.push_back({opBundleSizesAttrNameAttr, opBundleSizesAttr});
+
+ auto opBundleTagsAttr = ArrayAttr::get(context, opBundleTagAttrs);
+ auto opBundleTagsAttrNameAttr =
+ StringAttr::get(context, LLVMDialect::getOpBundleTagsAttrName());
+ attrsOut.push_back({opBundleTagsAttrNameAttr, opBundleTagsAttr});
+ }
+
return success();
}
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 6e005f9ec5df85..ceb8ba3b33818b 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -55,6 +55,7 @@
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
+#include <numeric>
#include <optional>
#define DEBUG_TYPE "llvm-dialect-to-llvm-ir"
@@ -854,8 +855,40 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
"LLVM `immArgPositions` and MLIR `immArgAttrNames` should have equal "
"length");
+ SmallVector<llvm::OperandBundleDef> opBundles;
+ size_t numOpBundleOperands = 0;
+ auto opBundleSizesAttr = cast_if_present<DenseI32ArrayAttr>(
+ intrOp->getAttr(LLVMDialect::getOpBundleSizesAttrName()));
+ auto opBundleTagsAttr = cast_if_present<ArrayAttr>(
+ intrOp->getAttr(LLVMDialect::getOpBundleTagsAttrName()));
+
+ if (opBundleSizesAttr && opBundleTagsAttr) {
+ ArrayRef<int> opBundleSizes = opBundleSizesAttr.asArrayRef();
+ assert(opBundleSizes.size() == opBundleTagsAttr.size() &&
+ "operand bundles and tags do not match");
+
+ numOpBundleOperands =
+ std::accumulate(opBundleSizes.begin(), opBundleSizes.end(), size_t(0));
+ assert(numOpBundleOperands <= intrOp->getNumOperands() &&
+ "operand bundle operands is more than the number of operands");
+
+ ValueRange operands = intrOp->getOperands().take_back(numOpBundleOperands);
+ size_t nextOperandIdx = 0;
+ opBundles.reserve(opBundleSizesAttr.size());
+
+ for (auto [opBundleTagAttr, bundleSize] :
+ llvm::zip(opBundleTagsAttr, opBundleSizes)) {
+ auto bundleTag = cast<StringAttr>(opBundleTagAttr).str();
+ auto bundleOperands = moduleTranslation.lookupValues(
+ operands.slice(nextOperandIdx, bundleSize));
+ opBundles.emplace_back(std::move(bundleTag), std::move(bundleOperands));
+ nextOperandIdx += bundleSize;
+ }
+ }
+
// Map operands and attributes to LLVM values.
- auto operands = moduleTranslation.lookupValues(intrOp->getOperands());
+ auto opOperands = intrOp->getOperands().drop_back(numOpBundleOperands);
+ auto operands = moduleTranslation.lookupValues(opOperands);
SmallVector<llvm::Value *> args(immArgPositions.size() + operands.size());
for (auto [immArgPos, immArgName] :
llvm::zip(immArgPositions, immArgAttrNames)) {
@@ -890,7 +923,7 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
llvm::Function *llvmIntr = llvm::Intrinsic::getOrInsertDeclaration(
module, intrinsic, overloadedTypes);
- return builder.CreateCall(llvmIntr, args);
+ return builder.CreateCall(llvmIntr, args, opBundles);
}
/// Given a single MLIR operation, create the corresponding LLVM IR operation
diff --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
index b86103422b0745..55b1bc9c545a85 100644
--- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
@@ -684,7 +684,7 @@ func.func @collapse_static_shape_with_non_identity_layout(%arg: memref<1x1x8x8xf
// CHECK: %[[INT_TO_PTR:.*]] = llvm.ptrtoint %[[BUFF_ADDR]] : !llvm.ptr to i64
// CHECK: %[[AND:.*]] = llvm.and %[[INT_TO_PTR]], {{.*}} : i64
// CHECK: %[[CMP:.*]] = llvm.icmp "eq" %[[AND]], {{.*}} : i64
-// CHECK: "llvm.intr.assume"(%[[CMP]]) : (i1) -> ()
+// CHECK: llvm.intr.assume %[[CMP]] : i1
// CHECK: %[[LD_ADDR:.*]] = llvm.getelementptr %[[BUFF_ADDR]][%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: %[[VAL:.*]] = llvm.load %[[LD_ADDR]] : !llvm.ptr -> f32
// CHECK: return %[[VAL]] : f32
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 9dc22abf143bf0..48dc9079333d4f 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -160,7 +160,7 @@ func.func @assume_alignment(%0 : memref<4x4xf16>) {
// CHECK-NEXT: %[[INT:.*]] = llvm.ptrtoint %[[PTR]] : !llvm.ptr to i64
// CHECK-NEXT: %[[MASKED_PTR:.*]] = llvm.and %[[INT]], %[[MASK:.*]] : i64
// CHECK-NEXT: %[[CONDITION:.*]] = llvm.icmp "eq" %[[MASKED_PTR]], %[[ZERO]] : i64
- // CHECK-NEXT: "llvm.intr.assume"(%[[CONDITION]]) : (i1) -> ()
+ // CHECK-NEXT: llvm.intr.assume %[[CONDITION]] : i1
memref.assume_alignment %0, 16 : memref<4x4xf16>
return
}
@@ -177,7 +177,7 @@ func.func @assume_alignment_w_offset(%0 : memref<4x4xf16, strided<[?, ?], offset
// CHECK-NEXT: %[[INT:.*]] = llvm.ptrtoint %[[BUFF_ADDR]] : !llvm.ptr to i64
// CHECK-NEXT: %[[MASKED_PTR:.*]] = llvm.and %[[INT]], %[[MASK:.*]] : i64
// CHECK-NEXT: %[[CONDITION:.*]] = llvm.icmp "eq" %[[MASKED_PTR]], %[[ZERO]] : i64
- // CHECK-NEXT: "llvm.intr.assume"(%[[CONDITION]]) : (i1) -> ()
+ // CHECK-NEXT: llvm.intr.assume %[[CONDITION]] : i1
memref.assume_alignment %0, 16 : memref<4x4xf16, strided<[?, ?], offset: ?>>
return
}
diff --git a/mlir/test/Dialect/LLVMIR/inlining.mlir b/mlir/test/Dialect/LLVMIR/inlining.mlir
index f9551e311df59f..0b7ca3f2bb048a 100644
--- a/mlir/test/Dialect/LLVMIR/inlining.mlir
+++ b/mlir/test/Dialect/LLVMIR/inlining.mlir
@@ -18,7 +18,7 @@ func.func @inner_func_inlinable(%ptr : !llvm.ptr) -> i32 {
"llvm.intr.memset"(%ptr, %byte, %0) <{isVolatile = true}> : (!llvm.ptr, i8, i32) -> ()
"llvm.intr.memmove"(%ptr, %ptr, %0) <{isVolatile = true}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
"llvm.intr.memcpy"(%ptr, %ptr, %0) <{isVolatile = true}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
- "llvm.intr.assume"(%true) : (i1) -> ()
+ llvm.intr.assume %true : i1
llvm.fence release
%2 = llvm.atomicrmw add %ptr, %0 monotonic : !llvm.ptr, i32
%3 = llvm.cmpxchg %ptr, %0, %1 acq_rel monotonic : !llvm.ptr, i32
@@ -44,7 +44,7 @@ func.func @inner_func_inlinable(%ptr : !llvm.ptr) -> i32 {
// CHECK: "llvm.intr.memset"(%[[PTR]]
// CHECK: "llvm.intr.memmove"(%[[PTR]], %[[PTR]]
// CHECK: "llvm.intr.memcpy"(%[[PTR]], %[[PTR]]
-// CHECK: "llvm.intr.assume"
+// CHECK: llvm.intr.assume
// CHECK: llvm.fence release
// CHECK: llvm.atomicrmw add %[[PTR]], %[[CST]] monotonic
// CHECK: llvm.cmpxchg %[[PTR]], %[[CST]], %[[RES]] acq_rel monotonic
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index 3062cdc38c0abb..b8ce7db795a1d1 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -836,3 +836,30 @@ llvm.func @test_call_intrin_with_opbundle(%arg0 : !llvm.ptr) {
llvm.call_intrinsic "llvm.assume"(%0) ["align"(%arg0, %1 : !llvm.ptr, i32)] : (i1) -> ()
llvm.return
}
+
+// CHECK-LABEL: @test_assume_intr_no_opbundle
+llvm.func @test_assume_intr_no_opbundle(%arg0 : !llvm.ptr) { // Round-trip of the bundle-free custom assembly form.
+ %0 = llvm.mlir.constant(1 : i1) : i1
+ // CHECK: llvm.intr.assume %0 : i1
+ llvm.intr.assume %0 : i1
+ llvm.return
+}
+
+// CHECK-LABEL: @test_assume_intr_empty_opbundle
+llvm.func @test_assume_intr_empty_opbundle(%arg0 : !llvm.ptr) {
+ %0 = llvm.mlir.constant(1 : i1) : i1
+ // CHECK: llvm.intr.assume %0 : i1
+ llvm.intr.assume %0 [] : i1 // explicit empty bundle list; the expected output above omits the brackets
+ llvm.return
+}
+
+// CHECK-LABEL: @test_assume_intr_with_opbundles
+llvm.func @test_assume_intr_with_opbundles(%arg0 : !llvm.ptr) {
+ %0 = llvm.mlir.constant(1 : i1) : i1
+ %1 = llvm.mlir.constant(2 : i32) : i32
+ %2 = llvm.mlir.constant(3 : i32) : i32
+ %3 = llvm.mlir.constant(4 : i32) : i32
+ // CHECK: llvm.intr.assume %0 ["tag1"(%1, %2 : i32, i32), "tag2"(%3 : i32)] : i1
+ llvm.intr.assume %0 ["tag1"(%1, %2 : i32, i32), "tag2"(%3 : i32)] : i1 // two bundles: "tag1" with two i32 operands, "tag2" with one; expected to print back unchanged
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/Import/intrinsic.ll b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
index 28a1bd21c82a38..606b11175f572f 100644
--- a/mlir/test/Target/LLVMIR/Import/intrinsic.ll
+++ b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
@@ -630,11 +630,21 @@ define void @va_intrinsics_test(ptr %0, ptr %1, ...) {
; CHECK-LABEL: @assume
; CHECK-SAME: %[[TRUE:[a-zA-Z0-9]+]]
define void @assume(i1 %true) {
- ; CHECK: "llvm.intr.assume"(%[[TRUE]]) : (i1) -> ()
+ ; CHECK: llvm.intr.assume %[[TRUE]] : i1
call void @llvm.assume(i1 %true)
ret void
}
+; CHECK-LABEL: @assume_with_opbundles
+; CHECK-SAME: %[[TRUE:[a-zA-Z0-9]+]]
+; CHECK-SAME: %[[PTR:[a-zA-Z0-9]+]]
+define void @assume_with_opbundles(i1 %true, ptr %p) {
+ ; CHECK: %[[ALIGN:.+]] = llvm.mlir.constant(8 : i32) : i32
+ ; CHECK: llvm.intr.assume %[[TRUE]] ["align"(%[[PTR]], %[[ALIGN]] : !llvm.ptr, i32)] : i1
+ call void @llvm.assume(i1 %true) ["align"(ptr %p, i32 8)] ; the literal i32 8 bundle input is expected to import as an llvm.mlir.constant feeding the bundle
+ ret void
+}
+
; CHECK-LABEL: @is_constant
; CHECK-SAME: %[[VAL:[a-zA-Z0-9]+]]
define void @is_constant(i32 %0) {
diff --git a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
index 0634a7ba907f1e..cb712eb4e1262d 100644
--- a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
@@ -363,6 +363,21 @@ llvm.func @umin_test(%arg0: i32, %arg1: i32, %arg2: vector<8xi32>, %arg3: vector
llvm.return
}
+// CHECK-LABEL: @assume_without_opbundles
+llvm.func @assume_without_opbundles(%cond: i1) {
+ // CHECK: call void @llvm.assume(i1 %{{.+}})
+ llvm.intr.assume %cond : i1 // bundle-free assume; expected to translate to a call to @llvm.assume
+ llvm.return
+}
+
+// CHECK-LABEL: @assume_with_opbundles
+llvm.func @assume_with_opbundles(%cond: i1, %p: !llvm.ptr) {
+ %0 = llvm.mlir.constant(8 : i32) : i32
+ // CHECK: call void @llvm.assume(i1 %{{.+}}) [ "align"(ptr %{{.+}}, i32 8) ]
+ llvm.intr.assume %cond ["align"(%p, %0 : !llvm.ptr, i32)] : i1 // the "align" bundle and both of its operands are expected on the emitted call
+ llvm.return
+}
+
// CHECK-LABEL: @vector_reductions
llvm.func @vector_reductions(%arg0: f32, %arg1: vector<8xf32>, %arg2: vector<8xi32>) {
// CHECK: call i32 @llvm.vector.reduce.add.v8i32
diff --git a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
index af0981440a1776..15658ea6068121 100644
--- a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
@@ -188,7 +188,7 @@ llvm.func @sadd_overflow_intr_wrong_type(%arg0 : i32, %arg1 : f32) -> !llvm.stru
llvm.func @assume_intr_wrong_type(%cond : i16) {
 // expected-error @below{{op operand #0 must be 1-bit signless integer, but got 'i16'}}
- "llvm.intr.assume"(%cond) : (i16) -> ()
+ llvm.intr.assume %cond : i16 // custom assembly form; still expected to trigger the operand-type error above
 llvm.return
}
More information about the Mlir-commits
mailing list