[Mlir-commits] [mlir] 51e36f2 - [mlir][llvm] Add branch weights to call and invoke
Tobias Gysi
llvmlistbot at llvm.org
Mon Jan 9 01:27:09 PST 2023
Author: Christian Ulmann
Date: 2023-01-09T10:25:07+01:00
New Revision: 51e36f217f868183059aa72e2f69d1f8492330e3
URL: https://github.com/llvm/llvm-project/commit/51e36f217f868183059aa72e2f69d1f8492330e3
DIFF: https://github.com/llvm/llvm-project/commit/51e36f217f868183059aa72e2f69d1f8492330e3.diff
LOG: [mlir][llvm] Add branch weights to call and invoke
This commit introduces branch weight attributes to LLVM::CallOp and
LLVM::InvokeOp and adds support for both importing and exporting them.
Reviewed By: gysit
Differential Revision: https://reviews.llvm.org/D141122
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
mlir/test/Target/LLVMIR/Import/profiling-metadata.ll
mlir/test/Target/LLVMIR/llvmir.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index ca4cc4853543d..bfb222625e25f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -483,7 +483,8 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [
let arguments = (ins OptionalAttr<FlatSymbolRefAttr>:$callee,
Variadic<LLVM_Type>:$callee_operands,
Variadic<LLVM_Type>:$normalDestOperands,
- Variadic<LLVM_Type>:$unwindDestOperands);
+ Variadic<LLVM_Type>:$unwindDestOperands,
+ OptionalAttr<ElementsAttr>:$branch_weights);
let results = (outs Variadic<LLVM_Type>);
let successors = (successor AnySuccessor:$normalDest,
AnySuccessor:$unwindDest);
@@ -500,7 +501,7 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [
"ValueRange":$normalOps, "Block*":$unwind, "ValueRange":$unwindOps),
[{
build($_builder, $_state, tys, /*callee=*/FlatSymbolRefAttr(), ops, normalOps,
- unwindOps, normal, unwind);
+ unwindOps, nullptr, normal, unwind);
}]>];
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
@@ -553,13 +554,16 @@ def LLVM_CallOp : LLVM_Op<"call",
let arguments = (ins OptionalAttr<FlatSymbolRefAttr>:$callee,
Variadic<LLVM_Type>,
DefaultValuedAttr<LLVM_FastmathFlagsAttr,
- "{}">:$fastmathFlags);
+ "{}">:$fastmathFlags,
+ OptionalAttr<ElementsAttr>:$branch_weights);
let results = (outs Optional<LLVM_Type>:$result);
let builders = [
OpBuilder<(ins "LLVMFuncOp":$func, "ValueRange":$args)>,
OpBuilder<(ins "TypeRange":$results, "StringAttr":$callee,
CArg<"ValueRange", "{}">:$args)>,
+ OpBuilder<(ins "TypeRange":$results, "FlatSymbolRefAttr":$callee,
+ CArg<"ValueRange", "{}">:$args)>,
OpBuilder<(ins "TypeRange":$results, "StringRef":$callee,
CArg<"ValueRange", "{}">:$args)>
];
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 58abf02d7a3ed..c32ca2b8f5a0b 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1156,7 +1156,13 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
StringAttr callee, ValueRange args) {
- build(builder, state, results, SymbolRefAttr::get(callee), args, nullptr);
+ build(builder, state, results, SymbolRefAttr::get(callee), args, nullptr,
+ nullptr);
+}
+
+void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
+ FlatSymbolRefAttr callee, ValueRange args) {
+ build(builder, state, results, callee, args, nullptr, nullptr);
}
void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
@@ -1165,7 +1171,8 @@ void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
Type resultType = func.getFunctionType().getReturnType();
if (!resultType.isa<LLVM::LLVMVoidType>())
results.push_back(resultType);
- build(builder, state, results, SymbolRefAttr::get(func), args, nullptr);
+ build(builder, state, results, SymbolRefAttr::get(func), args, nullptr,
+ nullptr);
}
CallInterfaceCallable CallOp::getCallableForCallee() {
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
index aed832173ef02..24e5ab37f8d0e 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
@@ -18,6 +18,7 @@
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/StringSet.h"
+#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/InlineAsm.h"
#include "llvm/IR/Instructions.h"
@@ -116,15 +117,13 @@ static LogicalResult setProfilingAttrs(OpBuilder &builder, llvm::MDNode *node,
}
// Attach the branch weights to the operations that support it.
- if (auto condBrOp = dyn_cast<CondBrOp>(op)) {
- condBrOp.setBranchWeightsAttr(builder.getI32VectorAttr(branchWeights));
- return success();
- }
- if (auto switchOp = dyn_cast<SwitchOp>(op)) {
- switchOp.setBranchWeightsAttr(builder.getI32VectorAttr(branchWeights));
- return success();
- }
- return failure();
+ return llvm::TypeSwitch<Operation *, LogicalResult>(op)
+ .Case<CondBrOp, SwitchOp, CallOp, InvokeOp>([&](auto branchWeightOp) {
+ branchWeightOp.setBranchWeightsAttr(
+ builder.getI32VectorAttr(branchWeights));
+ return success();
+ })
+ .Default([](auto) { return failure(); });
}
/// Attaches the given TBAA metadata `node` to the imported operation.
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index 697121e380c00..7f44db50f820f 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -322,6 +322,21 @@ convertCallLLVMIntrinsicOp(CallIntrinsicOp &op, llvm::IRBuilderBase &builder,
return success();
}
+/// Constructs branch weights metadata if the provided `weights` hold a value,
+/// otherwise returns nullptr.
+static llvm::MDNode *
+convertBranchWeights(std::optional<ElementsAttr> weights,
+ LLVM::ModuleTranslation &moduleTranslation) {
+ if (!weights)
+ return nullptr;
+ SmallVector<uint32_t> weightValues;
+ weightValues.reserve(weights->size());
+ for (APInt weight : weights->cast<DenseIntElementsAttr>())
+ weightValues.push_back(weight.getLimitedValue());
+ return llvm::MDBuilder(moduleTranslation.getLLVMContext())
+ .createBranchWeights(weightValues);
+}
+
static LogicalResult
convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
@@ -336,32 +351,34 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
// Emit function calls. If the "callee" attribute is present, this is a
// direct function call and we also need to look up the remapped function
// itself. Otherwise, this is an indirect call and the callee is the first
- // operand, look it up as a normal value. Return the llvm::Value
- // representing the function result, which may be of llvm::VoidTy type.
- auto convertCall = [&](Operation &op) -> llvm::Value * {
- auto operands = moduleTranslation.lookupValues(op.getOperands());
+ // operand, look it up as a normal value.
+ if (auto callOp = dyn_cast<LLVM::CallOp>(opInst)) {
+ auto operands = moduleTranslation.lookupValues(callOp.getOperands());
ArrayRef<llvm::Value *> operandsRef(operands);
- if (auto attr = op.getAttrOfType<FlatSymbolRefAttr>("callee"))
- return builder.CreateCall(
+ llvm::CallInst *call;
+ if (auto attr = callOp.getCalleeAttr()) {
+ call = builder.CreateCall(
moduleTranslation.lookupFunction(attr.getValue()), operandsRef);
- auto calleeType =
- op.getOperands().front().getType().cast<LLVMPointerType>();
- auto *calleeFunctionType = cast<llvm::FunctionType>(
- moduleTranslation.convertType(calleeType.getElementType()));
- return builder.CreateCall(calleeFunctionType, operandsRef.front(),
- operandsRef.drop_front());
- };
-
- // Emit calls. If the called function has a result, remap the corresponding
- // value. Note that LLVM IR dialect CallOp has either 0 or 1 result.
- if (isa<LLVM::CallOp>(opInst)) {
- llvm::Value *result = convertCall(opInst);
+ } else {
+ auto calleeType =
+ callOp->getOperands().front().getType().cast<LLVMPointerType>();
+ auto *calleeFunctionType = cast<llvm::FunctionType>(
+ moduleTranslation.convertType(calleeType.getElementType()));
+ call = builder.CreateCall(calleeFunctionType, operandsRef.front(),
+ operandsRef.drop_front());
+ }
+ llvm::MDNode *branchWeights =
+ convertBranchWeights(callOp.getBranchWeights(), moduleTranslation);
+ if (branchWeights)
+ call->setMetadata(llvm::LLVMContext::MD_prof, branchWeights);
+ // If the called function has a result, remap the corresponding value. Note
+ // that LLVM IR dialect CallOp has either 0 or 1 result.
if (opInst.getNumResults() != 0) {
- moduleTranslation.mapValue(opInst.getResult(0), result);
+ moduleTranslation.mapValue(opInst.getResult(0), call);
return success();
}
// Check that LLVM call returns void for 0-result functions.
- return success(result->getType()->isVoidTy());
+ return success(call->getType()->isVoidTy());
}
if (auto inlineAsmOp = dyn_cast<LLVM::InlineAsmOp>(opInst)) {
@@ -442,6 +459,10 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
moduleTranslation.lookupBlock(invOp.getSuccessor(1)),
operandsRef.drop_front());
}
+ llvm::MDNode *branchWeights =
+ convertBranchWeights(invOp.getBranchWeights(), moduleTranslation);
+ if (branchWeights)
+ result->setMetadata(llvm::LLVMContext::MD_prof, branchWeights);
moduleTranslation.mapBranch(invOp, result);
// InvokeOp can only have 0 or 1 result
if (invOp->getNumResults() != 0) {
@@ -478,17 +499,8 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
return success();
}
if (auto condbrOp = dyn_cast<LLVM::CondBrOp>(opInst)) {
- llvm::MDNode *branchWeights = nullptr;
- if (auto weights = condbrOp.getBranchWeights()) {
- // Map weight attributes to LLVM metadata.
- auto weightValues = weights->getValues<APInt>();
- auto trueWeight = weightValues[0].getSExtValue();
- auto falseWeight = weightValues[1].getSExtValue();
- branchWeights =
- llvm::MDBuilder(moduleTranslation.getLLVMContext())
- .createBranchWeights(static_cast<uint32_t>(trueWeight),
- static_cast<uint32_t>(falseWeight));
- }
+ llvm::MDNode *branchWeights =
+ convertBranchWeights(condbrOp.getBranchWeights(), moduleTranslation);
llvm::BranchInst *branch = builder.CreateCondBr(
moduleTranslation.lookupValue(condbrOp.getOperand(0)),
moduleTranslation.lookupBlock(condbrOp.getSuccessor(0)),
@@ -498,16 +510,8 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
return success();
}
if (auto switchOp = dyn_cast<LLVM::SwitchOp>(opInst)) {
- llvm::MDNode *branchWeights = nullptr;
- if (auto weights = switchOp.getBranchWeights()) {
- llvm::SmallVector<uint32_t> weightValues;
- weightValues.reserve(weights->size());
- for (llvm::APInt weight : weights->cast<DenseIntElementsAttr>())
- weightValues.push_back(weight.getLimitedValue());
- branchWeights = llvm::MDBuilder(moduleTranslation.getLLVMContext())
- .createBranchWeights(weightValues);
- }
-
+ llvm::MDNode *branchWeights =
+ convertBranchWeights(switchOp.getBranchWeights(), moduleTranslation);
llvm::SwitchInst *switchInst = builder.CreateSwitch(
moduleTranslation.lookupValue(switchOp.getValue()),
moduleTranslation.lookupBlock(switchOp.getDefaultDestination()),
diff --git a/mlir/test/Target/LLVMIR/Import/profiling-metadata.ll b/mlir/test/Target/LLVMIR/Import/profiling-metadata.ll
index 402271a6a2b51..70a66b645e5a7 100644
--- a/mlir/test/Target/LLVMIR/Import/profiling-metadata.ll
+++ b/mlir/test/Target/LLVMIR/Import/profiling-metadata.ll
@@ -33,3 +33,36 @@ bbd:
}
!0 = !{!"branch_weights", i32 42, i32 3, i32 5}
+
+; // -----
+
+; CHECK: llvm.func @fn()
+declare void @fn()
+
+; CHECK-LABEL: @call_branch_weights
+define void @call_branch_weights() {
+ ; CHECK: llvm.call @fn() {branch_weights = dense<42> : vector<1xi32>}
+ call void @fn(), !prof !0
+ ret void
+}
+
+!0 = !{!"branch_weights", i32 42}
+
+; // -----
+
+declare void @foo()
+declare i32 @__gxx_personality_v0(...)
+
+; CHECK-LABEL: @invoke_branch_weights
+define i32 @invoke_branch_weights() personality i8* bitcast (i32 (...)* @__gxx_personality_v0 to i8*) {
+ ; CHECK: llvm.invoke @foo() to ^bb2 unwind ^bb1 {branch_weights = dense<[42, 99]> : vector<2xi32>} : () -> ()
+ invoke void @foo() to label %bb2 unwind label %bb1, !prof !0
+bb1:
+ %1 = landingpad { i8*, i32 } cleanup
+ br label %bb2
+bb2:
+ ret i32 1
+
+}
+
+!0 = !{!"branch_weights", i32 42, i32 99}
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 7e65b879323aa..eb6738e1497e4 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -1631,6 +1631,38 @@ llvm.func @cond_br_weights(%cond : i1, %arg0 : i32, %arg1 : i32) -> i32 {
// -----
+llvm.func @fn()
+
+// CHECK-LABEL: @call_branch_weights
+llvm.func @call_branch_weights() {
+ // CHECK: !prof ![[NODE:[0-9]+]]
+ llvm.call @fn() {branch_weights = dense<42> : vector<1xi32>} : () -> ()
+ llvm.return
+}
+
+// CHECK: ![[NODE]] = !{!"branch_weights", i32 42}
+
+// -----
+
+llvm.func @foo()
+llvm.func @__gxx_personality_v0(...) -> i32
+
+// CHECK-LABEL: @invoke_branch_weights
+llvm.func @invoke_branch_weights() -> i32 attributes {personality = @__gxx_personality_v0} {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: !prof ![[NODE:[0-9]+]]
+ llvm.invoke @foo() to ^bb2 unwind ^bb1 {branch_weights = dense<[42, 99]> : vector<2xi32>} : () -> ()
+^bb1: // pred: ^bb0
+ %1 = llvm.landingpad cleanup : !llvm.struct<(ptr<i8>, i32)>
+ llvm.br ^bb2
+^bb2: // 2 preds: ^bb0, ^bb1
+ llvm.return %0 : i32
+}
+
+// CHECK: ![[NODE]] = !{!"branch_weights", i32 42, i32 99}
+
+// -----
+
llvm.func @volatile_store_and_load() {
%val = llvm.mlir.constant(5 : i32) : i32
%size = llvm.mlir.constant(1 : i64) : i64
More information about the Mlir-commits mailing list