[Mlir-commits] [flang] [mlir] [mlir][LLVM] Add operand bundle support (PR #108933)
Sirui Mu
llvmlistbot at llvm.org
Wed Sep 25 09:38:59 PDT 2024
https://github.com/Lancern updated https://github.com/llvm/llvm-project/pull/108933
>From 668dfb2d335f3446bb0ae2e9529d1e84f5eac8ac Mon Sep 17 00:00:00 2001
From: Sirui Mu <msrlancern at gmail.com>
Date: Mon, 23 Sep 2024 23:19:47 +0800
Subject: [PATCH 1/8] [mlir][LLVM] Add operand bundle support
---
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 44 ++-
mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp | 4 +
.../Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp | 10 +-
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 251 ++++++++++++++++--
.../LLVMIR/LLVMToLLVMIRTranslation.cpp | 68 ++++-
mlir/test/Dialect/LLVMIR/invalid.mlir | 18 +-
mlir/test/Target/LLVMIR/llvmir.mlir | 49 ++++
7 files changed, 401 insertions(+), 43 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 8584a25f8b3d6c..030160821bd823 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -551,7 +551,15 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [
Variadic<LLVM_Type>:$normalDestOperands,
Variadic<LLVM_Type>:$unwindDestOperands,
OptionalAttr<DenseI32ArrayAttr>:$branch_weights,
- DefaultValuedAttr<CConv, "CConv::C">:$CConv);
+ DefaultValuedAttr<CConv, "CConv::C">:$CConv,
+ VariadicOfVariadic<LLVM_Type,
+ "op_bundle_sizes">:$op_bundle_operands,
+ DenseI32ArrayAttr:$op_bundle_sizes,
+ DefaultValuedProperty<
+ ArrayProperty<StringProperty, "operand bundle tags">,
+ "ArrayRef<std::string>{}",
+ "SmallVector<std::string>{}"
+ >:$op_bundle_tags);
let results = (outs Optional<LLVM_Type>:$result);
let successors = (successor AnySuccessor:$normalDest,
AnySuccessor:$unwindDest);
@@ -607,7 +615,8 @@ def LLVM_VaArgOp : LLVM_Op<"va_arg"> {
//===----------------------------------------------------------------------===//
def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
- [DeclareOpInterfaceMethods<FastmathFlagsInterface>,
+ [AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<FastmathFlagsInterface>,
DeclareOpInterfaceMethods<CallOpInterface>,
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
DeclareOpInterfaceMethods<BranchWeightOpInterface>]> {
@@ -661,8 +670,15 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
OptionalAttr<LLVM_MemoryEffectsAttr>:$memory_effects,
OptionalAttr<UnitAttr>:$convergent,
OptionalAttr<UnitAttr>:$no_unwind,
- OptionalAttr<UnitAttr>:$will_return
- );
+ OptionalAttr<UnitAttr>:$will_return,
+ VariadicOfVariadic<LLVM_Type,
+ "op_bundle_sizes">:$op_bundle_operands,
+ DenseI32ArrayAttr:$op_bundle_sizes,
+ DefaultValuedProperty<
+ ArrayProperty<StringProperty, "operand bundle tags">,
+ "ArrayRef<std::string>{}",
+ "SmallVector<std::string>{}"
+ >:$op_bundle_tags);
// Append the aliasing related attributes defined in LLVM_MemAccessOpBase.
let arguments = !con(args, aliasAttrs);
let results = (outs Optional<LLVM_Type>:$result);
@@ -682,6 +698,7 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
OpBuilder<(ins "LLVMFunctionType":$calleeType, "StringRef":$callee,
CArg<"ValueRange", "{}">:$args)>
];
+ let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
let extraClassDeclaration = [{
/// Returns the callee function type.
@@ -1895,7 +1912,8 @@ def LLVM_InlineAsmOp : LLVM_Op<"inline_asm", [DeclareOpInterfaceMethods<MemoryEf
def LLVM_CallIntrinsicOp
: LLVM_Op<"call_intrinsic",
- [DeclareOpInterfaceMethods<FastmathFlagsInterface>]> {
+ [AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<FastmathFlagsInterface>]> {
let summary = "Call to an LLVM intrinsic function.";
let description = [{
Call the specified llvm intrinsic. If the intrinsic is overloaded, use
@@ -1903,13 +1921,25 @@ def LLVM_CallIntrinsicOp
}];
let arguments = (ins StrAttr:$intrin, Variadic<LLVM_Type>:$args,
DefaultValuedAttr<LLVM_FastmathFlagsAttr,
- "{}">:$fastmathFlags);
+ "{}">:$fastmathFlags,
+ VariadicOfVariadic<LLVM_Type,
+ "op_bundle_sizes">:$op_bundle_operands,
+ DenseI32ArrayAttr:$op_bundle_sizes,
+ DefaultValuedProperty<
+ ArrayProperty<StringProperty, "operand bundle tags">,
+ "ArrayRef<std::string>{}",
+ "SmallVector<std::string>{}"
+ >:$op_bundle_tags);
let results = (outs Optional<LLVM_Type>:$results);
let llvmBuilder = [{
return convertCallLLVMIntrinsicOp(op, builder, moduleTranslation);
}];
let assemblyFormat = [{
- $intrin `(` $args `)` `:` functional-type($args, $results) attr-dict
+ $intrin `(` $args `)`
+ ( custom<OpBundles>($op_bundle_operands, type($op_bundle_operands),
+ $op_bundle_tags)^ )?
+ `:` functional-type($args, $results)
+ attr-dict
}];
let hasVerifier = 1;
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 4c2e8682285c52..2cc77e8fd41b9a 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -544,6 +544,10 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
promoted, callOp->getAttrs());
+ newOp.getProperties().operandSegmentSizes = {
+ static_cast<int32_t>(promoted.size()), 0};
+ newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
+
SmallVector<Value, 4> results;
if (numResults < 2) {
// If < 2 results, packing did not do anything and we can just return.
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index ca786316324198..6ae607f75adbd5 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -837,8 +837,11 @@ class FunctionCallPattern
matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (callOp.getNumResults() == 0) {
- rewriter.replaceOpWithNewOp<LLVM::CallOp>(
+ auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
callOp, std::nullopt, adaptor.getOperands(), callOp->getAttrs());
+ newOp.getProperties().operandSegmentSizes = {
+ static_cast<int32_t>(adaptor.getOperands().size()), 0};
+ newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
return success();
}
@@ -846,8 +849,11 @@ class FunctionCallPattern
auto dstType = typeConverter.convertType(callOp.getType(0));
if (!dstType)
return rewriter.notifyMatchFailure(callOp, "type conversion failed");
- rewriter.replaceOpWithNewOp<LLVM::CallOp>(
+ auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
+ newOp.getProperties().operandSegmentSizes = {
+ static_cast<int32_t>(adaptor.getOperands().size()), 0};
+ newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
return success();
}
};
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 205d7494d4378c..837e0e41800d81 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -220,6 +220,88 @@ static RetTy parseOptionalLLVMKeyword(OpAsmParser &parser,
return static_cast<RetTy>(index);
}
+//===----------------------------------------------------------------------===//
+// Operand bundle helpers.
+//===----------------------------------------------------------------------===//
+
+/// Prints a single operand bundle as `"tag"(%a, %b : t0, t1)`. A bundle with
+/// no operands is printed as `"tag"()` so that parseOneOpBundle can read it
+/// back (an empty type list after `:` would not parse).
+static void printOneOpBundle(OpAsmPrinter &p, OperandRange operands,
+                             TypeRange operandTypes, StringRef tag) {
+  p.printString(tag);
+  p << "(";
+  if (!operands.empty()) {
+    p.printOperands(operands);
+    p << " : ";
+    llvm::interleaveComma(operandTypes, p);
+  }
+  p << ")";
+}
+
+/// Prints the operand bundle list as `[bundle0, bundle1, ...]`, pairing the
+/// i-th operand group with the i-th type group and the i-th tag. `op` is
+/// currently unused.
+static void printOpBundles(OpAsmPrinter &p, Operation *op,
+                           OperandRangeRange opBundleOperands,
+                           TypeRangeRange opBundleOperandTypes,
+                           ArrayRef<std::string> opBundleTags) {
+  p << "[";
+  llvm::interleaveComma(
+      llvm::zip(opBundleOperands, opBundleOperandTypes, opBundleTags), p,
+      [&p](auto bundle) {
+        printOneOpBundle(p, std::get<0>(bundle), std::get<1>(bundle),
+                         std::get<2>(bundle));
+      });
+  p << "]";
+}
+
+/// Parses a single operand bundle of the form
+///   string-literal `(` (ssa-use-list `:` type-list)? `)`
+/// and appends its unresolved operands, types and tag to the output vectors.
+/// The operand/type counts are checked later, when the operands are resolved.
+static ParseResult parseOneOpBundle(
+    OpAsmParser &p,
+    SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> &opBundleOperands,
+    SmallVector<SmallVector<Type>> &opBundleOperandTypes,
+    SmallVector<std::string> &opBundleTags) {
+  auto currentParserLoc = p.getCurrentLocation();
+  SmallVector<OpAsmParser::UnresolvedOperand> operands;
+  SmallVector<Type> types;
+  std::string tag;
+
+  if (p.parseString(&tag))
+    return p.emitError(currentParserLoc, "expect operand bundle tag");
+
+  if (p.parseLParen())
+    return failure();
+
+  // An empty bundle is written `"tag"()`. Otherwise the operand list is
+  // followed by `:`, the operand types and the closing paren.
+  if (p.parseOptionalRParen()) {
+    if (p.parseOperandList(operands) || p.parseColon() ||
+        p.parseTypeList(types) || p.parseRParen())
+      return failure();
+  }
+
+  opBundleOperands.push_back(std::move(operands));
+  opBundleOperandTypes.push_back(std::move(types));
+  opBundleTags.push_back(std::move(tag));
+
+  return success();
+}
+
+/// Parses an optional operand bundle list `[bundle (`,` bundle)*]`.
+/// Returns std::nullopt when the list is absent (no `[`), failure() when it
+/// is malformed, and success() otherwise.
+static std::optional<ParseResult> parseOpBundles(
+    OpAsmParser &p,
+    SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> &opBundleOperands,
+    SmallVector<SmallVector<Type>> &opBundleOperandTypes,
+    SmallVector<std::string> &opBundleTags) {
+  if (p.parseOptionalLSquare())
+    return std::nullopt;
+
+  auto bundleParser = [&] {
+    return parseOneOpBundle(p, opBundleOperands, opBundleOperandTypes,
+                            opBundleTags);
+  };
+  if (p.parseCommaSeparatedList(bundleParser))
+    return failure();
+
+  if (p.parseRSquare())
+    return failure();
+
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// Printing, parsing, folding and builder for LLVM::CmpOp.
//===----------------------------------------------------------------------===//
@@ -954,6 +1036,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
/*CConv=*/nullptr, /*TailCallKind=*/nullptr,
/*memory_effects=*/nullptr,
/*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
+ /*op_bundle_operands=*/{}, /*op_bundle_tags=*/std::nullopt,
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
/*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
}
@@ -980,6 +1063,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
/*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
/*convergent=*/nullptr,
/*no_unwind=*/nullptr, /*will_return=*/nullptr,
+ /*op_bundle_operands=*/{}, /*op_bundle_tags=*/std::nullopt,
/*access_groups=*/nullptr,
/*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
}
@@ -992,6 +1076,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
/*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
/*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
/*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
+ /*op_bundle_operands=*/{}, /*op_bundle_tags=*/std::nullopt,
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
/*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
}
@@ -1004,6 +1089,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
/*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
/*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
/*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
+ /*op_bundle_operands=*/{}, /*op_bundle_tags=*/std::nullopt,
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
/*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
}
@@ -1027,7 +1113,7 @@ void CallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
}
Operation::operand_range CallOp::getArgOperands() {
- return getOperands().drop_front(getCallee().has_value() ? 0 : 1);
+ return getCalleeOperands().drop_front(getCallee().has_value() ? 0 : 1);
}
MutableOperandRange CallOp::getArgOperandsMutable() {
@@ -1100,6 +1186,27 @@ LogicalResult verifyCallOpVarCalleeType(OpTy callOp) {
return success();
}
+/// Checks that an operation carrying operand bundles has exactly one tag per
+/// operand bundle. If the op has no tags at all, it must not have any
+/// operand bundle operands either.
+template <typename OpType>
+static LogicalResult verifyOperandBundles(OpType &op) {
+  OperandRangeRange bundleGroups = op.getOpBundleOperands();
+  std::optional<ArrayRef<std::string>> bundleTags = op.getOpBundleTags();
+
+  if (!bundleTags) {
+    if (bundleGroups.empty())
+      return success();
+    return op.emitError("expected operand bundle tags");
+  }
+
+  auto numGroups = bundleGroups.size();
+  auto numTags = bundleTags->size();
+  if (numTags == numGroups)
+    return success();
+  return op.emitError("expected ")
+         << numGroups << " operand bundle tags, but actually got " << numTags;
+}
+
+LogicalResult CallOp::verify() {
+  // Only the operand bundle structure is verified here.
+  return verifyOperandBundles(*this);
+}
+
LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
if (failed(verifyCallOpVarCalleeType(*this)))
return failure();
@@ -1150,15 +1257,15 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// Verify that the operand and result types match the callee.
if (!funcType.isVarArg() &&
- funcType.getNumParams() != (getNumOperands() - isIndirect))
+ funcType.getNumParams() != (getCalleeOperands().size() - isIndirect))
return emitOpError() << "incorrect number of operands ("
- << (getNumOperands() - isIndirect)
+ << (getCalleeOperands().size() - isIndirect)
<< ") for callee (expecting: "
<< funcType.getNumParams() << ")";
- if (funcType.getNumParams() > (getNumOperands() - isIndirect))
+ if (funcType.getNumParams() > (getCalleeOperands().size() - isIndirect))
return emitOpError() << "incorrect number of operands ("
- << (getNumOperands() - isIndirect)
+ << (getCalleeOperands().size() - isIndirect)
<< ") for varargs callee (expecting at least: "
<< funcType.getNumParams() << ")";
@@ -1208,16 +1315,24 @@ void CallOp::print(OpAsmPrinter &p) {
else
p << getOperand(0);
- auto args = getOperands().drop_front(isDirect ? 0 : 1);
+ auto args = getCalleeOperands().drop_front(isDirect ? 0 : 1);
p << '(' << args << ')';
// Print the variadic callee type if the call is variadic.
if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
p << " vararg(" << *varCalleeType << ")";
+ if (!getOpBundleOperands().empty()) {
+ p << " ";
+ printOpBundles(p, *this, getOpBundleOperands(),
+ getOpBundleOperands().getTypes(), getOpBundleTags());
+ }
+
p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
{getCalleeAttrName(), getTailCallKindAttrName(),
- getVarCalleeTypeAttrName(), getCConvAttrName()});
+ getVarCalleeTypeAttrName(), getCConvAttrName(),
+ getOperandSegmentSizesAttrName(),
+ getOpBundleSizesAttrName()});
p << " : ";
if (!isDirect)
@@ -1285,14 +1400,53 @@ static ParseResult parseOptionalCallFuncPtr(
return success();
}
+/// Resolves the parsed operand bundle operands against their parsed types,
+/// appending the resolved values to `state.operands`, and records the number
+/// of operands of each bundle in the `opBundleSizesAttrName` attribute.
+/// `loc` is used for all diagnostics.
+static ParseResult resolveOpBundleOperands(
+    OpAsmParser &parser, SMLoc loc, OperationState &state,
+    ArrayRef<SmallVector<OpAsmParser::UnresolvedOperand>> opBundleOperands,
+    ArrayRef<SmallVector<Type>> opBundleOperandTypes,
+    StringAttr opBundleSizesAttrName) {
+  assert(opBundleOperands.size() == opBundleOperandTypes.size() &&
+         "operand bundle operand groups and type groups should match");
+
+  SmallVector<int32_t> opBundleSizes;
+  opBundleSizes.reserve(opBundleOperands.size());
+
+  // Track the bundle position so diagnostics name the offending bundle.
+  unsigned opBundleIndex = 0;
+  for (const auto &[operands, types] :
+       llvm::zip(opBundleOperands, opBundleOperandTypes)) {
+    if (operands.size() != types.size())
+      return parser.emitError(loc, "expected ")
+             << operands.size()
+             << " types for operand bundle operands for operand bundle #"
+             << opBundleIndex << ", but actually got " << types.size();
+    if (parser.resolveOperands(operands, types, loc, state.operands))
+      return failure();
+    opBundleSizes.push_back(static_cast<int32_t>(operands.size()));
+    ++opBundleIndex;
+  }
+
+  state.addAttribute(
+      opBundleSizesAttrName,
+      DenseI32ArrayAttr::get(parser.getContext(), opBundleSizes));
+
+  return success();
+}
+
// <operation> ::= `llvm.call` (cconv)? (tailcallkind)? (function-id | ssa-use)
// `(` ssa-use-list `)`
// ( `vararg(` var-callee-type `)` )?
+// ( `[` op-bundle (`,` op-bundle)* `]` )?
// attribute-dict? `:` (type `,`)? function-type
+// where op-bundle ::= str-literal `(` ssa-use-list `:` type-list `)`
ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
SymbolRefAttr funcAttr;
TypeAttr varCalleeType;
SmallVector<OpAsmParser::UnresolvedOperand> operands;
+ SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> opBundleOperands;
+ SmallVector<SmallVector<Type>> opBundleOperandTypes;
+ SmallVector<std::string> opBundleTags;
// Default to C Calling Convention if no keyword is provided.
result.addAttribute(
@@ -1333,11 +1487,35 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
return failure();
}
+ auto opBundlesLoc = parser.getCurrentLocation();
+ if (auto result = parseOpBundles(parser, opBundleOperands,
+ opBundleOperandTypes, opBundleTags);
+ result.has_value() && failed(*result))
+ return failure();
+ if (!opBundleTags.empty())
+ result.getOrAddProperties<CallOp::Properties>().op_bundle_tags =
+ std::move(opBundleTags);
+
if (parser.parseOptionalAttrDict(result.attributes))
return failure();
// Parse the trailing type list and resolve the operands.
- return parseCallTypeAndResolveOperands(parser, result, isDirect, operands);
+ if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands))
+ return failure();
+ if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands,
+ opBundleOperandTypes,
+ getOpBundleSizesAttrName(result.name)))
+ return failure();
+
+ int32_t numOpBundleOperands = 0;
+ for (const auto &operands : opBundleOperands)
+ numOpBundleOperands += operands.size();
+
+ result.addAttribute(
+ CallOp::getOperandSegmentSizeAttr(),
+ parser.getBuilder().getDenseI32ArrayAttr(
+ {static_cast<int32_t>(operands.size()), numOpBundleOperands}));
+ return success();
}
LLVMFunctionType CallOp::getCalleeFunctionType() {
@@ -1356,7 +1534,8 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
auto calleeType = func.getFunctionType();
build(builder, state, getCallOpResultTypes(calleeType),
getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), ops,
- normalOps, unwindOps, nullptr, nullptr, normal, unwind);
+ normalOps, unwindOps, nullptr, nullptr, {}, std::nullopt, normal,
+ unwind);
}
void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
@@ -1365,7 +1544,7 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
ValueRange unwindOps) {
build(builder, state, tys,
/*var_callee_type=*/nullptr, callee, ops, normalOps, unwindOps, nullptr,
- nullptr, normal, unwind);
+ nullptr, {}, std::nullopt, normal, unwind);
}
void InvokeOp::build(OpBuilder &builder, OperationState &state,
@@ -1374,7 +1553,7 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state,
Block *unwind, ValueRange unwindOps) {
build(builder, state, getCallOpResultTypes(calleeType),
getCallOpVarCalleeType(calleeType), callee, ops, normalOps, unwindOps,
- nullptr, nullptr, normal, unwind);
+ nullptr, nullptr, {}, std::nullopt, normal, unwind);
}
SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) {
@@ -1402,7 +1581,7 @@ void InvokeOp::setCalleeFromCallable(CallInterfaceCallable callee) {
}
Operation::operand_range InvokeOp::getArgOperands() {
- return getOperands().drop_front(getCallee().has_value() ? 0 : 1);
+ return getCalleeOperands().drop_front(getCallee().has_value() ? 0 : 1);
}
MutableOperandRange InvokeOp::getArgOperandsMutable() {
@@ -1423,6 +1602,9 @@ LogicalResult InvokeOp::verify() {
return emitError("first operation in unwind destination should be a "
"llvm.landingpad operation");
+ if (failed(verifyOperandBundles(*this)))
+ return failure();
+
return success();
}
@@ -1452,9 +1634,16 @@ void InvokeOp::print(OpAsmPrinter &p) {
if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
p << " vararg(" << *varCalleeType << ")";
+ if (!getOpBundleOperands().empty()) {
+ p << " ";
+ printOpBundles(p, *this, getOpBundleOperands(),
+ getOpBundleOperands().getTypes(), getOpBundleTags());
+ }
+
p.printOptionalAttrDict((*this)->getAttrs(),
{getCalleeAttrName(), getOperandSegmentSizeAttr(),
- getCConvAttrName(), getVarCalleeTypeAttrName()});
+ getCConvAttrName(), getVarCalleeTypeAttrName(),
+ getOpBundleSizesAttrName()});
p << " : ";
if (!isDirect)
@@ -1468,11 +1657,17 @@ void InvokeOp::print(OpAsmPrinter &p) {
// `to` bb-id (`[` ssa-use-and-type-list `]`)?
// `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
// ( `vararg(` var-callee-type `)` )?
+// ( `[` op-bundle (`,` op-bundle)* `]` )?
// attribute-dict? `:` (type `,`)? function-type
+// where op-bundle ::= str-literal `(` ssa-use-list `:` type-list `)`
ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::UnresolvedOperand, 8> operands;
SymbolRefAttr funcAttr;
TypeAttr varCalleeType;
+ SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> opBundleOperands;
+ SmallVector<SmallVector<Type>> opBundleOperandTypes;
+ SmallVector<std::string> opBundleTags;
Block *normalDest, *unwindDest;
SmallVector<Value, 4> normalOperands, unwindOperands;
Builder &builder = parser.getBuilder();
@@ -1513,22 +1708,40 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
return failure();
}
+ auto opBundlesLoc = parser.getCurrentLocation();
+ if (auto result = parseOpBundles(parser, opBundleOperands,
+ opBundleOperandTypes, opBundleTags);
+ result.has_value() && failed(*result))
+ return failure();
+ if (!opBundleTags.empty())
+ result.getOrAddProperties<InvokeOp::Properties>().op_bundle_tags =
+ std::move(opBundleTags);
+
if (parser.parseOptionalAttrDict(result.attributes))
return failure();
// Parse the trailing type list and resolve the function operands.
if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands))
return failure();
+ if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands,
+ opBundleOperandTypes,
+ getOpBundleSizesAttrName(result.name)))
+ return failure();
result.addSuccessors({normalDest, unwindDest});
result.addOperands(normalOperands);
result.addOperands(unwindOperands);
- result.addAttribute(InvokeOp::getOperandSegmentSizeAttr(),
- builder.getDenseI32ArrayAttr(
- {static_cast<int32_t>(operands.size()),
- static_cast<int32_t>(normalOperands.size()),
- static_cast<int32_t>(unwindOperands.size())}));
+ int32_t numOpBundleOperands = 0;
+ for (const auto &operands : opBundleOperands)
+ numOpBundleOperands += operands.size();
+
+ result.addAttribute(
+ InvokeOp::getOperandSegmentSizeAttr(),
+ builder.getDenseI32ArrayAttr({static_cast<int32_t>(operands.size()),
+ static_cast<int32_t>(normalOperands.size()),
+ static_cast<int32_t>(unwindOperands.size()),
+ numOpBundleOperands}));
return success();
}
@@ -3108,6 +3321,8 @@ OpFoldResult LLVM::OrOp::fold(FoldAdaptor adaptor) {
LogicalResult CallIntrinsicOp::verify() {
if (!getIntrin().starts_with("llvm."))
return emitOpError() << "intrinsic name must start with 'llvm.'";
+ if (failed(verifyOperandBundles(*this)))
+ return failure();
return success();
}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index d948ff5eaf1769..53ca302518a90e 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -102,6 +102,40 @@ getOverloadedDeclaration(CallIntrinsicOp op, llvm::Intrinsic::ID id,
return llvm::Intrinsic::getDeclaration(module, id, overloadedArgTysRef);
}
+/// Converts one MLIR operand bundle into an llvm::OperandBundleDef carrying
+/// `bundleTag` and the LLVM values mapped to each of `bundleOperands`.
+static llvm::OperandBundleDef
+convertOperandBundle(OperandRange bundleOperands, StringRef bundleTag,
+                     LLVM::ModuleTranslation &moduleTranslation) {
+  std::vector<llvm::Value *> llvmInputs;
+  llvmInputs.reserve(bundleOperands.size());
+  for (Value operand : bundleOperands) {
+    llvm::Value *mapped = moduleTranslation.lookupValue(operand);
+    llvmInputs.push_back(mapped);
+  }
+  return llvm::OperandBundleDef(bundleTag.str(), std::move(llvmInputs));
+}
+
+/// Converts all operand bundles of an operation into LLVM operand bundle
+/// definitions, pairing the i-th operand group with the i-th tag.
+static SmallVector<llvm::OperandBundleDef>
+convertOperandBundles(OperandRangeRange bundleOperands,
+                      ArrayRef<std::string> bundleTags,
+                      LLVM::ModuleTranslation &moduleTranslation) {
+  assert(bundleOperands.size() == bundleTags.size() &&
+         "operand bundles and tags do not match");
+
+  SmallVector<llvm::OperandBundleDef> converted;
+  converted.reserve(bundleOperands.size());
+  for (const auto &[groupOperands, groupTag] :
+       llvm::zip(bundleOperands, bundleTags))
+    converted.push_back(
+        convertOperandBundle(groupOperands, groupTag, moduleTranslation));
+  return converted;
+}
+
+/// Overload taking optional tags: missing tags are treated as an empty tag
+/// list before delegating to the ArrayRef overload.
+static SmallVector<llvm::OperandBundleDef>
+convertOperandBundles(OperandRangeRange bundleOperands,
+                      std::optional<ArrayRef<std::string>> bundleTags,
+                      LLVM::ModuleTranslation &moduleTranslation) {
+  ArrayRef<std::string> tags =
+      bundleTags.has_value() ? *bundleTags : ArrayRef<std::string>();
+  return convertOperandBundles(bundleOperands, tags, moduleTranslation);
+}
+
/// Builder for LLVM_CallIntrinsicOp
static LogicalResult
convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder,
@@ -138,15 +172,15 @@ convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder,
// Check the argument types of the call. If the function is variadic, check
// the subrange of required arguments.
if (!fn->getFunctionType()->isVarArg() &&
- op.getNumOperands() != fn->arg_size()) {
+ op.getArgs().size() != fn->arg_size()) {
return mlir::emitError(op.getLoc(), "intrinsic call has ")
- << op.getNumOperands() << " operands but " << op.getIntrinAttr()
+ << op.getArgs().size() << " operands but " << op.getIntrinAttr()
<< " expects " << fn->arg_size();
}
if (fn->getFunctionType()->isVarArg() &&
- op.getNumOperands() < fn->arg_size()) {
+ op.getArgs().size() < fn->arg_size()) {
return mlir::emitError(op.getLoc(), "intrinsic call has ")
- << op.getNumOperands() << " operands but variadic "
+ << op.getArgs().size() << " operands but variadic "
<< op.getIntrinAttr() << " expects at least " << fn->arg_size();
}
// Check the arguments up to the number the function requires.
@@ -164,8 +198,10 @@ convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder,
FastmathFlagsInterface itf = op;
builder.setFastMathFlags(getFastmathFlags(itf));
- auto *inst =
- builder.CreateCall(fn, moduleTranslation.lookupValues(op.getOperands()));
+ auto *inst = builder.CreateCall(
+ fn, moduleTranslation.lookupValues(op.getArgs()),
+ convertOperandBundles(op.getOpBundleOperands(), op.getOpBundleTags(),
+ moduleTranslation));
if (op.getNumResults() == 1)
moduleTranslation.mapValue(op->getResults().front()) = inst;
return success();
@@ -205,17 +241,21 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
// itself. Otherwise, this is an indirect call and the callee is the first
// operand, look it up as a normal value.
if (auto callOp = dyn_cast<LLVM::CallOp>(opInst)) {
- auto operands = moduleTranslation.lookupValues(callOp.getOperands());
+ auto operands = moduleTranslation.lookupValues(callOp.getCalleeOperands());
+ SmallVector<llvm::OperandBundleDef> opBundles =
+ convertOperandBundles(callOp.getOpBundleOperands(),
+ callOp.getOpBundleTags(), moduleTranslation);
ArrayRef<llvm::Value *> operandsRef(operands);
llvm::CallInst *call;
if (auto attr = callOp.getCalleeAttr()) {
- call = builder.CreateCall(
- moduleTranslation.lookupFunction(attr.getValue()), operandsRef);
+ call =
+ builder.CreateCall(moduleTranslation.lookupFunction(attr.getValue()),
+ operandsRef, opBundles);
} else {
llvm::FunctionType *calleeType = llvm::cast<llvm::FunctionType>(
moduleTranslation.convertType(callOp.getCalleeFunctionType()));
call = builder.CreateCall(calleeType, operandsRef.front(),
- operandsRef.drop_front());
+ operandsRef.drop_front(), opBundles);
}
call->setCallingConv(convertCConvToLLVM(callOp.getCConv()));
call->setTailCallKind(convertTailCallKindToLLVM(callOp.getTailCallKind()));
@@ -312,13 +352,17 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
if (auto invOp = dyn_cast<LLVM::InvokeOp>(opInst)) {
auto operands = moduleTranslation.lookupValues(invOp.getCalleeOperands());
+ SmallVector<llvm::OperandBundleDef> opBundles =
+ convertOperandBundles(invOp.getOpBundleOperands(),
+ invOp.getOpBundleTags(), moduleTranslation);
ArrayRef<llvm::Value *> operandsRef(operands);
llvm::InvokeInst *result;
if (auto attr = opInst.getAttrOfType<FlatSymbolRefAttr>("callee")) {
result = builder.CreateInvoke(
moduleTranslation.lookupFunction(attr.getValue()),
moduleTranslation.lookupBlock(invOp.getSuccessor(0)),
- moduleTranslation.lookupBlock(invOp.getSuccessor(1)), operandsRef);
+ moduleTranslation.lookupBlock(invOp.getSuccessor(1)), operandsRef,
+ opBundles);
} else {
llvm::FunctionType *calleeType = llvm::cast<llvm::FunctionType>(
moduleTranslation.convertType(invOp.getCalleeFunctionType()));
@@ -326,7 +370,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
calleeType, operandsRef.front(),
moduleTranslation.lookupBlock(invOp.getSuccessor(0)),
moduleTranslation.lookupBlock(invOp.getSuccessor(1)),
- operandsRef.drop_front());
+ operandsRef.drop_front(), opBundles);
}
result->setCallingConv(convertCConvToLLVM(invOp.getCConv()));
moduleTranslation.mapBranch(invOp, result);
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 6670e4b186c397..1121691133108f 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -218,7 +218,7 @@ func.func @store_unaligned_atomic(%val : f32, %ptr : !llvm.ptr) {
func.func @invalid_call() {
// expected-error at +1 {{'llvm.call' op must have either a `callee` attribute or at least an operand}}
- "llvm.call"() : () -> ()
+ "llvm.call"() {op_bundle_sizes = array<i32>} : () -> ()
llvm.return
}
@@ -286,7 +286,7 @@ func.func @call_non_llvm() {
func.func @call_non_llvm_arg(%arg0 : tensor<*xi32>) {
// expected-error at +1 {{'llvm.call' op operand #0 must be variadic of LLVM dialect-compatible type}}
- "llvm.call"(%arg0) : (tensor<*xi32>) -> ()
+ "llvm.call"(%arg0) {operandSegmentSizes = array<i32: 1, 0>, op_bundle_sizes = array<i32>} : (tensor<*xi32>) -> ()
llvm.return
}
@@ -1588,7 +1588,7 @@ llvm.func @variadic(...)
llvm.func @invalid_variadic_call(%arg: i32) {
// expected-error at +1 {{missing var_callee_type attribute for vararg call}}
- "llvm.call"(%arg) <{callee = @variadic}> : (i32) -> ()
+ "llvm.call"(%arg) <{callee = @variadic}> {operandSegmentSizes = array<i32: 1, 0>, op_bundle_sizes = array<i32>} : (i32) -> ()
llvm.return
}
@@ -1598,7 +1598,7 @@ llvm.func @variadic(...)
llvm.func @invalid_variadic_call(%arg: i32) {
// expected-error at +1 {{missing var_callee_type attribute for vararg call}}
- "llvm.call"(%arg) <{callee = @variadic}> : (i32) -> ()
+ "llvm.call"(%arg) <{callee = @variadic}> {operandSegmentSizes = array<i32: 1, 0>, op_bundle_sizes = array<i32>} : (i32) -> ()
llvm.return
}
@@ -1655,3 +1655,13 @@ llvm.func @alwaysinline_noinline() attributes { always_inline, no_inline } {
llvm.func @optnone_requires_noinline() attributes { optimize_none } {
llvm.return
}
+
+// -----
+
+llvm.func @foo()
+llvm.func @wrong_number_of_bundle_types() {
+  %0 = llvm.mlir.constant(0 : i32) : i32
+  // The bundle lists one operand but two types.
+  // expected-error at +1 {{expected 1 types for operand bundle operands for operand bundle #0, but actually got 2}}
+  llvm.call @foo() ["tag"(%0 : i32, i32)] : () -> ()
+  llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 966a00f9e3c675..025ff4a35a5522 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -2626,3 +2626,52 @@ llvm.func @reqd_work_group_size() attributes {reqd_work_group_size = array<i32:
llvm.func @intel_reqd_sub_group_size() attributes {intel_reqd_sub_group_size = 32 : i32}
// CHECK: ![[#INTEL_REQD_SUB_GROUP_SIZE]] = !{i32 32}
+
+// -----
+
+llvm.func @foo()
+
+llvm.func @call_with_opbundle() {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ %1 = llvm.mlir.constant(2 : i32) : i32
+ %2 = llvm.mlir.constant(3 : i32) : i32
+ llvm.call @foo() ["tag1"(%0, %1 : i32, i32), "tag2"(%2 : i32)] : () -> ()
+ llvm.return
+}
+
+// CHECK: define void @call_with_opbundle() {
+// CHECK-NEXT: call void @foo() [ "tag1"(i32 1, i32 2), "tag2"(i32 3) ]
+// CHECK-NEXT: ret void
+// CHECK-NEXT: }
+
+llvm.func @__gxx_personality_v0(...) -> i32
+llvm.func @invoke_with_opbundle() attributes { personality = @__gxx_personality_v0 } {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ %1 = llvm.mlir.constant(2 : i32) : i32
+ %2 = llvm.mlir.constant(3 : i32) : i32
+ llvm.invoke @foo() to ^bb2 unwind ^bb1 ["tag1"(%0, %1 : i32, i32), "tag2"(%2 : i32)] : () -> ()
+
+^bb1:
+ %3 = llvm.landingpad cleanup : !llvm.struct<(ptr, i32)>
+ llvm.return
+
+^bb2:
+ llvm.return
+}
+
+// CHECK: define void @invoke_with_opbundle() personality ptr @__gxx_personality_v0 {
+// CHECK-NEXT: invoke void @foo() [ "tag1"(i32 1, i32 2), "tag2"(i32 3) ]
+// CHECK-NEXT: to label %{{.+}} unwind label %{{.+}}
+// CHECK: }
+
+llvm.func @call_intrin_with_opbundle(%arg0 : !llvm.ptr) {
+ %0 = llvm.mlir.constant(1 : i1) : i1
+ %1 = llvm.mlir.constant(16 : i32) : i32
+ llvm.call_intrinsic "llvm.assume"(%0) ["align"(%arg0, %1 : !llvm.ptr, i32)] : (i1) -> ()
+ llvm.return
+}
+
+// CHECK: define void @call_intrin_with_opbundle(ptr %0) {
+// CHECK-NEXT: call void @llvm.assume(i1 true) [ "align"(ptr %0, i32 16) ]
+// CHECK-NEXT: ret void
+// CHECK-NEXT: }
>From 3a8fc34d90d4e5071bfc7e47bda55e78c3c0a8e2 Mon Sep 17 00:00:00 2001
From: Sirui Mu <msrlancern at gmail.com>
Date: Tue, 24 Sep 2024 20:28:32 +0800
Subject: [PATCH 2/8] resolve nits before adding more tests
---
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 46 +++++++------------
.../LLVMIR/LLVMToLLVMIRTranslation.cpp | 4 +-
mlir/test/Dialect/LLVMIR/invalid.mlir | 2 +-
3 files changed, 19 insertions(+), 33 deletions(-)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 837e0e41800d81..e7550b653d27be 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -253,7 +253,7 @@ static ParseResult parseOneOpBundle(
SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> &opBundleOperands,
SmallVector<SmallVector<Type>> &opBundleOperandTypes,
SmallVector<std::string> &opBundleTags) {
- auto currentParserLoc = p.getCurrentLocation();
+ SMLoc currentParserLoc = p.getCurrentLocation();
SmallVector<OpAsmParser::UnresolvedOperand> operands;
SmallVector<Type> types;
std::string tag;
@@ -1189,18 +1189,12 @@ LogicalResult verifyCallOpVarCalleeType(OpTy callOp) {
template <typename OpType>
static LogicalResult verifyOperandBundles(OpType &op) {
OperandRangeRange opBundleOperands = op.getOpBundleOperands();
- std::optional<ArrayRef<std::string>> opBundleTags = op.getOpBundleTags();
+ ArrayRef<std::string> opBundleTags = op.getOpBundleTags();
- if (!opBundleTags.has_value()) {
- if (!opBundleOperands.empty())
- return op.emitError("expected operand bundle tags");
- return success();
- }
-
- if (opBundleTags->size() != opBundleOperands.size())
+ if (opBundleTags.size() != opBundleOperands.size())
return op.emitError("expected ")
<< opBundleOperands.size()
- << " operand bundle tags, but actually got " << opBundleTags->size();
+ << " operand bundle tags, but actually got " << opBundleTags.size();
return success();
}
@@ -1405,12 +1399,9 @@ static ParseResult resolveOpBundleOperands(
ArrayRef<SmallVector<OpAsmParser::UnresolvedOperand>> opBundleOperands,
ArrayRef<SmallVector<Type>> opBundleOperandTypes,
StringAttr opBundleSizesAttrName) {
- assert(opBundleOperands.size() == opBundleOperandTypes.size() &&
- "operand bundle operand groups and type groups should match");
-
unsigned opBundleIndex = 0;
for (const auto &[operands, types] :
- llvm::zip(opBundleOperands, opBundleOperandTypes)) {
+ llvm::zip_equal(opBundleOperands, opBundleOperandTypes)) {
if (operands.size() != types.size())
return parser.emitError(loc, "expected ")
<< operands.size()
@@ -1422,9 +1413,8 @@ static ParseResult resolveOpBundleOperands(
SmallVector<int32_t> opBundleSizes;
opBundleSizes.reserve(opBundleOperands.size());
- for (const auto &operands : opBundleOperands) {
+ for (const auto &operands : opBundleOperands)
opBundleSizes.push_back(operands.size());
- }
state.addAttribute(
opBundleSizesAttrName,
@@ -1436,10 +1426,8 @@ static ParseResult resolveOpBundleOperands(
// <operation> ::= `llvm.call` (cconv)? (tailcallkind)? (function-id | ssa-use)
// `(` ssa-use-list `)`
// ( `vararg(` var-callee-type `)` )?
-// ( `bundlearg(` ssa-use-list-list `)` )?
-// ( `bundletags(` str-elements-attr `) )
+// ( `[` op-bundles-list `]` )?
// attribute-dict? `:` (type `,`)? function-type
-// (`,` `bundletype(` type-list-list `)`)?
ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
SymbolRefAttr funcAttr;
TypeAttr varCalleeType;
@@ -1487,10 +1475,10 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
return failure();
}
- auto opBundlesLoc = parser.getCurrentLocation();
- if (auto result = parseOpBundles(parser, opBundleOperands,
- opBundleOperandTypes, opBundleTags);
- result.has_value() && failed(*result))
+ SMLoc opBundlesLoc = parser.getCurrentLocation();
+ if (std::optional<ParseResult> result = parseOpBundles(
+ parser, opBundleOperands, opBundleOperandTypes, opBundleTags);
+ result && failed(*result))
return failure();
if (!opBundleTags.empty())
result.getOrAddProperties<CallOp::Properties>().op_bundle_tags =
@@ -1657,10 +1645,8 @@ void InvokeOp::print(OpAsmPrinter &p) {
// `to` bb-id (`[` ssa-use-and-type-list `]`)?
// `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
// ( `vararg(` var-callee-type `)` )?
-// ( `bundlearg(` ssa-use-list-list `)` )?
-// ( `bundletags(` str-elements-attr `) )
+// ( `[` op-bundles-list `]` )?
// attribute-dict? `:` (type `,`)? function-type
-// (`,` `bundletype(` type-list-list `)`)?
ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::UnresolvedOperand, 8> operands;
SymbolRefAttr funcAttr;
@@ -1708,10 +1694,10 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
return failure();
}
- auto opBundlesLoc = parser.getCurrentLocation();
- if (auto result = parseOpBundles(parser, opBundleOperands,
- opBundleOperandTypes, opBundleTags);
- result.has_value() && failed(*result))
+ SMLoc opBundlesLoc = parser.getCurrentLocation();
+ if (std::optional<ParseResult> result = parseOpBundles(
+ parser, opBundleOperands, opBundleOperandTypes, opBundleTags);
+ result && failed(*result))
return failure();
if (!opBundleTags.empty())
result.getOrAddProperties<InvokeOp::Properties>().op_bundle_tags =
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index 53ca302518a90e..cd4e760c3b4bcf 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -107,7 +107,7 @@ convertOperandBundle(OperandRange bundleOperands, StringRef bundleTag,
LLVM::ModuleTranslation &moduleTranslation) {
std::vector<llvm::Value *> operands;
operands.reserve(bundleOperands.size());
- for (auto bundleArg : bundleOperands)
+ for (Value bundleArg : bundleOperands)
operands.push_back(moduleTranslation.lookupValue(bundleArg));
return llvm::OperandBundleDef(bundleTag.str(), std::move(operands));
}
@@ -131,7 +131,7 @@ static SmallVector<llvm::OperandBundleDef>
convertOperandBundles(OperandRangeRange bundleOperands,
std::optional<ArrayRef<std::string>> bundleTags,
LLVM::ModuleTranslation &moduleTranslation) {
- if (!bundleTags.has_value())
+ if (!bundleTags)
bundleTags.emplace();
return convertOperandBundles(bundleOperands, *bundleTags, moduleTranslation);
}
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 1121691133108f..afe01d3ff89d68 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1662,6 +1662,6 @@ llvm.func @foo()
llvm.func @wrong_number_of_bundle_types() {
%0 = llvm.mlir.constant(0 : i32) : i32
// expected-error@+1 {{expected 1 types for operand bundle operands for operand bundle #0, but actually got 2}}
- llvm.call @foo() ["tag"(%0 : i32, i32)] : () -> () bundletype((i32, i32))
+ llvm.call @foo() ["tag"(%0 : i32, i32)] : () -> ()
llvm.return
}
>From ac42d461064eea3a428dd22410baeb4cb79e5ae0 Mon Sep 17 00:00:00 2001
From: Sirui Mu <msrlancern at gmail.com>
Date: Tue, 24 Sep 2024 20:51:38 +0800
Subject: [PATCH 3/8] parse empty operand bundles list
---
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 3 +++
mlir/test/Target/LLVMIR/llvmir.mlir | 10 ++++++++++
2 files changed, 13 insertions(+)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index e7550b653d27be..80f6ae7a224c8b 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -289,6 +289,9 @@ static std::optional<ParseResult> parseOpBundles(
if (p.parseOptionalLSquare())
return std::nullopt;
+ if (succeeded(p.parseOptionalRSquare()))
+ return success();
+
auto bundleParser = [&] {
return parseOneOpBundle(p, opBundleOperands, opBundleOperandTypes,
opBundleTags);
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 025ff4a35a5522..189e541e5fc334 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -2631,6 +2631,16 @@ llvm.func @intel_reqd_sub_group_size() attributes {intel_reqd_sub_group_size = 3
llvm.func @foo()
+llvm.func @call_with_empty_opbundle() {
+ llvm.call @foo() [] : () -> ()
+ llvm.return
+}
+
+// CHECK: define void @call_with_empty_opbundle() {
+// CHECK-NEXT: call void @foo()
+// CHECK-NEXT: ret void
+// CHECK-NEXT: }
+
llvm.func @call_with_opbundle() {
%0 = llvm.mlir.constant(1 : i32) : i32
%1 = llvm.mlir.constant(2 : i32) : i32
>From 9acfb951dc47bffc9057b74b6ce4669e6e30a712 Mon Sep 17 00:00:00 2001
From: Sirui Mu <msrlancern at gmail.com>
Date: Tue, 24 Sep 2024 21:02:35 +0800
Subject: [PATCH 4/8] add test for operand bundle verifier
---
mlir/test/Dialect/LLVMIR/invalid.mlir | 15 +++++++++++++++
1 file changed, 15 insertions(+)
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index afe01d3ff89d68..9388d7ef24936e 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1665,3 +1665,18 @@ llvm.func @wrong_number_of_bundle_types() {
llvm.call @foo() ["tag"(%0 : i32, i32)] : () -> ()
llvm.return
}
+
+// -----
+
+llvm.func @foo()
+llvm.func @wrong_number_of_bundle_tags() {
+ %0 = llvm.mlir.constant(0 : i32) : i32
+ %1 = llvm.mlir.constant(1 : i32) : i32
+ // expected-error@+1 {{expected 2 operand bundle tags, but actually got 1}}
+ "llvm.call"(%0, %1) <{ op_bundle_tags = ["tag"] }> {
+ callee = @foo,
+ operandSegmentSizes = array<i32: 0, 2>,
+ op_bundle_sizes = array<i32: 1, 1>
+ } : (i32, i32) -> ()
+ llvm.return
+}
>From cc0e132be3b4face9a1f3d9fd9062f878bc548e6 Mon Sep 17 00:00:00 2001
From: Sirui Mu <msrlancern at gmail.com>
Date: Tue, 24 Sep 2024 21:29:34 +0800
Subject: [PATCH 5/8] parse empty operands within a bundle
---
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 14 +++++---------
mlir/test/Target/LLVMIR/llvmir.mlir | 10 ++++++++++
2 files changed, 15 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 80f6ae7a224c8b..4b95e5486e74ac 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -264,15 +264,11 @@ static ParseResult parseOneOpBundle(
if (p.parseLParen())
return failure();
- if (p.parseOperandList(operands))
- return failure();
- if (p.parseColon())
- return failure();
- if (p.parseTypeList(types))
- return failure();
-
- if (p.parseRParen())
- return failure();
+ if (p.parseOptionalRParen()) {
+ if (p.parseOperandList(operands) || p.parseColon() ||
+ p.parseTypeList(types) || p.parseRParen())
+ return failure();
+ }
opBundleOperands.push_back(std::move(operands));
opBundleOperandTypes.push_back(std::move(types));
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 189e541e5fc334..007284d0ca4435 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -2641,6 +2641,16 @@ llvm.func @call_with_empty_opbundle() {
// CHECK-NEXT: ret void
// CHECK-NEXT: }
+llvm.func @call_with_empty_opbundle_operands() {
+ llvm.call @foo() ["tag"()] : () -> ()
+ llvm.return
+}
+
+// CHECK: define void @call_with_empty_opbundle_operands() {
+// CHECK-NEXT: call void @foo() [ "tag"() ]
+// CHECK-NEXT: ret void
+// CHECK-NEXT: }
+
llvm.func @call_with_opbundle() {
%0 = llvm.mlir.constant(1 : i32) : i32
%1 = llvm.mlir.constant(2 : i32) : i32
>From 2290580a6edf301b92c418509ce5aaa380d7d7d4 Mon Sep 17 00:00:00 2001
From: Sirui Mu <msrlancern at gmail.com>
Date: Tue, 24 Sep 2024 21:32:43 +0800
Subject: [PATCH 6/8] nit: replace llvm::zip with llvm::zip_equal
---
.../Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp | 5 +----
1 file changed, 1 insertion(+), 4 deletions(-)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index cd4e760c3b4bcf..78a3f1809aec31 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -116,13 +116,10 @@ static SmallVector<llvm::OperandBundleDef>
convertOperandBundles(OperandRangeRange bundleOperands,
ArrayRef<std::string> bundleTags,
LLVM::ModuleTranslation &moduleTranslation) {
- assert(bundleOperands.size() == bundleTags.size() &&
- "operand bundles and tags do not match");
-
SmallVector<llvm::OperandBundleDef> bundles;
bundles.reserve(bundleOperands.size());
- for (auto [operands, tag] : llvm::zip(bundleOperands, bundleTags))
+ for (auto [operands, tag] : llvm::zip_equal(bundleOperands, bundleTags))
bundles.push_back(convertOperandBundle(operands, tag, moduleTranslation));
return bundles;
}
>From 43513981c203ef56eddfa6b10e82cb7835086ee3 Mon Sep 17 00:00:00 2001
From: Sirui Mu <msrlancern at gmail.com>
Date: Tue, 24 Sep 2024 21:58:14 +0800
Subject: [PATCH 7/8] add roundtrip test for operand bundle syntax
---
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 17 +++--
mlir/test/Dialect/LLVMIR/roundtrip.mlir | 83 ++++++++++++++++++++++
2 files changed, 94 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 4b95e5486e74ac..0561c364c7d591 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -228,9 +228,13 @@ static void printOneOpBundle(OpAsmPrinter &p, OperandRange operands,
TypeRange operandTypes, StringRef tag) {
p.printString(tag);
p << "(";
- p.printOperands(operands);
- p << " : ";
- llvm::interleaveComma(operandTypes, p);
+
+ if (!operands.empty()) {
+ p.printOperands(operands);
+ p << " : ";
+ llvm::interleaveComma(operandTypes, p);
+ }
+
p << ")";
}
@@ -1611,7 +1615,7 @@ void InvokeOp::print(OpAsmPrinter &p) {
else
p << getOperand(0);
- p << '(' << getOperands().drop_front(isDirect ? 0 : 1) << ')';
+ p << '(' << getCalleeOperands().drop_front(isDirect ? 0 : 1) << ')';
p << " to ";
p.printSuccessorAndUseList(getNormalDest(), getNormalDestOperands());
p << " unwind ";
@@ -1635,8 +1639,9 @@ void InvokeOp::print(OpAsmPrinter &p) {
p << " : ";
if (!isDirect)
p << getOperand(0).getType() << ", ";
- p.printFunctionalType(llvm::drop_begin(getOperandTypes(), isDirect ? 0 : 1),
- getResultTypes());
+ p.printFunctionalType(
+ llvm::drop_begin(getCalleeOperands().getTypes(), isDirect ? 0 : 1),
+ getResultTypes());
}
// <operation> ::= `llvm.invoke` (cconv)? (function-id | ssa-use)
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index 89d303fcac8ff2..62f1de2b7fe7d4 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -751,3 +751,86 @@ llvm.func @vector_predication_intrinsics(%A: vector<8xi32>, %B: vector<8xi32>,
(vector<8xi32>, vector<8xi32>, vector<8xi1>, i32) -> vector<8xi32>
llvm.return
}
+
+llvm.func @op_bundle_target()
+
+// CHECK-LABEL: @test_call_with_empty_opbundle
+llvm.func @test_call_with_empty_opbundle() {
+ // CHECK: llvm.call @op_bundle_target() : () -> ()
+ llvm.call @op_bundle_target() [] : () -> ()
+ llvm.return
+}
+
+// CHECK-LABEL: @test_call_with_empty_opbundle_operands
+llvm.func @test_call_with_empty_opbundle_operands() {
+ // CHECK: llvm.call @op_bundle_target() ["tag"()] : () -> ()
+ llvm.call @op_bundle_target() ["tag"()] : () -> ()
+ llvm.return
+}
+
+// CHECK-LABEL: @test_call_with_opbundle
+llvm.func @test_call_with_opbundle() {
+ %0 = llvm.mlir.constant(0 : i32) : i32
+ %1 = llvm.mlir.constant(1 : i32) : i32
+ %2 = llvm.mlir.constant(2 : i32) : i32
+ // CHECK: llvm.call @op_bundle_target() ["tag1"(%{{.+}}, %{{.+}} : i32, i32), "tag2"(%{{.+}} : i32)] : () -> ()
+ llvm.call @op_bundle_target() ["tag1"(%0, %1 : i32, i32), "tag2"(%2 : i32)] : () -> ()
+ llvm.return
+}
+
+// CHECK-LABEL: @test_invoke_with_empty_opbundle
+llvm.func @test_invoke_with_empty_opbundle() attributes { personality = @__gxx_personality_v0 } {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ %1 = llvm.mlir.constant(2 : i32) : i32
+ %2 = llvm.mlir.constant(3 : i32) : i32
+ // CHECK: llvm.invoke @op_bundle_target() to ^{{.+}} unwind ^{{.+}} : () -> ()
+ llvm.invoke @op_bundle_target() to ^bb2 unwind ^bb1 [] : () -> ()
+
+^bb1:
+ %3 = llvm.landingpad cleanup : !llvm.struct<(ptr, i32)>
+ llvm.return
+
+^bb2:
+ llvm.return
+}
+
+// CHECK-LABEL: @test_invoke_with_empty_opbundle_operands
+llvm.func @test_invoke_with_empty_opbundle_operands() attributes { personality = @__gxx_personality_v0 } {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ %1 = llvm.mlir.constant(2 : i32) : i32
+ %2 = llvm.mlir.constant(3 : i32) : i32
+ // CHECK: llvm.invoke @op_bundle_target() to ^{{.+}} unwind ^{{.+}} ["tag"()] : () -> ()
+ llvm.invoke @op_bundle_target() to ^bb2 unwind ^bb1 ["tag"()] : () -> ()
+
+^bb1:
+ %3 = llvm.landingpad cleanup : !llvm.struct<(ptr, i32)>
+ llvm.return
+
+^bb2:
+ llvm.return
+}
+
+// CHECK-LABEL: @test_invoke_with_opbundle
+llvm.func @test_invoke_with_opbundle() attributes { personality = @__gxx_personality_v0 } {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ %1 = llvm.mlir.constant(2 : i32) : i32
+ %2 = llvm.mlir.constant(3 : i32) : i32
+ // CHECK: llvm.invoke @op_bundle_target() to ^{{.+}} unwind ^{{.+}} ["tag1"(%{{.+}}, %{{.+}} : i32, i32), "tag2"(%{{.+}} : i32)] : () -> ()
+ llvm.invoke @op_bundle_target() to ^bb2 unwind ^bb1 ["tag1"(%0, %1 : i32, i32), "tag2"(%2 : i32)] : () -> ()
+
+^bb1:
+ %3 = llvm.landingpad cleanup : !llvm.struct<(ptr, i32)>
+ llvm.return
+
+^bb2:
+ llvm.return
+}
+
+// CHECK-LABEL: @test_call_intrin_with_opbundle
+llvm.func @test_call_intrin_with_opbundle(%arg0 : !llvm.ptr) {
+ %0 = llvm.mlir.constant(1 : i1) : i1
+ %1 = llvm.mlir.constant(16 : i32) : i32
+ // CHECK: llvm.call_intrinsic "llvm.assume"(%{{.+}}) ["align"(%{{.+}}, %{{.+}} : !llvm.ptr, i32)] : (i1) -> ()
+ llvm.call_intrinsic "llvm.assume"(%0) ["align"(%arg0, %1 : !llvm.ptr, i32)] : (i1) -> ()
+ llvm.return
+}
>From bf4da2a7c95882011489cb39a0fe31c095ec5cc6 Mon Sep 17 00:00:00 2001
From: Sirui Mu <msrlancern at gmail.com>
Date: Thu, 26 Sep 2024 00:38:29 +0800
Subject: [PATCH 8/8] fix flang regressions
---
flang/lib/Optimizer/CodeGen/CodeGen.cpp | 39 ++++++++++++++++++++-----
1 file changed, 32 insertions(+), 7 deletions(-)
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 88293bcf36a780..efc28e9708e197 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -110,6 +110,26 @@ static unsigned getLenParamFieldId(mlir::Type ty) {
return getTypeDescFieldId(ty) + 1;
}
+static llvm::SmallVector<mlir::NamedAttribute>
+addLLVMOpBundleAttrs(mlir::ConversionPatternRewriter &rewriter,
+ llvm::ArrayRef<mlir::NamedAttribute> attrs,
+ int32_t numCallOperands) {
+ llvm::SmallVector<mlir::NamedAttribute> newAttrs;
+ newAttrs.reserve(attrs.size() + 2);
+
+ for (mlir::NamedAttribute attr : attrs) {
+ if (attr.getName() != "operandSegmentSizes")
+ newAttrs.push_back(attr);
+ }
+
+ newAttrs.push_back(rewriter.getNamedAttr(
+ "operandSegmentSizes",
+ rewriter.getDenseI32ArrayAttr({numCallOperands, 0})));
+ newAttrs.push_back(rewriter.getNamedAttr("op_bundle_sizes",
+ rewriter.getDenseI32ArrayAttr({})));
+ return newAttrs;
+}
+
namespace {
/// Lower `fir.address_of` operation to `llvm.address_of` operation.
struct AddrOfOpConversion : public fir::FIROpConversion<fir::AddrOfOp> {
@@ -229,7 +249,8 @@ struct AllocaOpConversion : public fir::FIROpConversion<fir::AllocaOp> {
mlir::NamedAttribute attr = rewriter.getNamedAttr(
"callee", mlir::SymbolRefAttr::get(memSizeFn));
auto call = rewriter.create<mlir::LLVM::CallOp>(
- loc, ity, lenParams, llvm::ArrayRef<mlir::NamedAttribute>{attr});
+ loc, ity, lenParams,
+ addLLVMOpBundleAttrs(rewriter, {attr}, lenParams.size()));
size = call.getResult();
llvmObjectType = ::getI8Type(alloc.getContext());
} else {
@@ -559,7 +580,9 @@ struct CallOpConversion : public fir::FIROpConversion<fir::CallOp> {
mlir::arith::AttrConvertFastMathToLLVM<fir::CallOp, mlir::LLVM::CallOp>
attrConvert(call);
rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
- call, resultTys, adaptor.getOperands(), attrConvert.getAttrs());
+ call, resultTys, adaptor.getOperands(),
+ addLLVMOpBundleAttrs(rewriter, attrConvert.getAttrs(),
+ adaptor.getOperands().size()));
return mlir::success();
}
};
@@ -980,7 +1003,8 @@ struct AllocMemOpConversion : public fir::FIROpConversion<fir::AllocMemOp> {
loc, ity, size, integerCast(loc, rewriter, ity, opnd));
heap->setAttr("callee", getMalloc(heap, rewriter));
rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
- heap, ::getLlvmPtrType(heap.getContext()), size, heap->getAttrs());
+ heap, ::getLlvmPtrType(heap.getContext()), size,
+ addLLVMOpBundleAttrs(rewriter, heap->getAttrs(), 1));
return mlir::success();
}
@@ -1037,9 +1061,9 @@ struct FreeMemOpConversion : public fir::FIROpConversion<fir::FreeMemOp> {
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::Location loc = freemem.getLoc();
freemem->setAttr("callee", getFree(freemem, rewriter));
- rewriter.create<mlir::LLVM::CallOp>(loc, mlir::TypeRange{},
- mlir::ValueRange{adaptor.getHeapref()},
- freemem->getAttrs());
+ rewriter.create<mlir::LLVM::CallOp>(
+ loc, mlir::TypeRange{}, mlir::ValueRange{adaptor.getHeapref()},
+ addLLVMOpBundleAttrs(rewriter, freemem->getAttrs(), 1));
rewriter.eraseOp(freemem);
return mlir::success();
}
@@ -2671,7 +2695,8 @@ struct FieldIndexOpConversion : public fir::FIROpConversion<fir::FieldIndexOp> {
"field", mlir::IntegerAttr::get(lowerTy().indexType(), index));
rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
field, lowerTy().offsetType(), adaptor.getOperands(),
- llvm::ArrayRef<mlir::NamedAttribute>{callAttr, fieldAttr});
+ addLLVMOpBundleAttrs(rewriter, {callAttr, fieldAttr},
+ adaptor.getOperands().size()));
return mlir::success();
}
More information about the Mlir-commits
mailing list