[Mlir-commits] [mlir] 99d03f0 - [MLIR][LLVMDialect] Added branch weights attribute to CondBrOp
George Mitenkov
llvmlistbot at llvm.org
Fri Jul 24 00:14:15 PDT 2020
Author: George Mitenkov
Date: 2020-07-24T10:11:13+03:00
New Revision: 99d03f03919498b688a8921b2ec669057772803f
URL: https://github.com/llvm/llvm-project/commit/99d03f03919498b688a8921b2ec669057772803f
DIFF: https://github.com/llvm/llvm-project/commit/99d03f03919498b688a8921b2ec669057772803f.diff
LOG: [MLIR][LLVMDialect] Added branch weights attribute to CondBrOp
This patch adds branch weights metadata to the `llvm.cond_br` op in the
LLVM dialect. The weights are modelled as an optional `ElementsAttr`, for example:
```
llvm.cond_br %cond weights(dense<[1, 3]> : vector<2xi32>), ^bb1, ^bb2
```
When exporting to LLVM IR, this attribute is translated into a `branch_weights`
metadata node. A test for the metadata creation is added to `mlir/test/Target/llvmir.mlir`.
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D83658
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
mlir/test/Target/llvmir.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index f421d2e46463..5322e243427a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -514,21 +514,29 @@ def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br",
NoSideEffect]> {
let arguments = (ins LLVMI1:$condition,
Variadic<LLVM_Type>:$trueDestOperands,
- Variadic<LLVM_Type>:$falseDestOperands);
+ Variadic<LLVM_Type>:$falseDestOperands,
+ OptionalAttr<ElementsAttr>:$branch_weights);
let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);
let assemblyFormat = [{
- $condition `,`
+ $condition ( `weights` `(` $branch_weights^ `)` )? `,`
$trueDest (`(` $trueDestOperands^ `:` type($trueDestOperands) `)`)? `,`
$falseDest (`(` $falseDestOperands^ `:` type($falseDestOperands) `)`)?
attr-dict
}];
let builders = [OpBuilder<
- "OpBuilder &builder, OperationState &result, Value condition,"
- "Block *trueDest, ValueRange trueOperands,"
- "Block *falseDest, ValueRange falseOperands", [{
- build(builder, result, condition, trueOperands, falseOperands, trueDest,
- falseDest);
+ "OpBuilder &builder, OperationState &result, Value condition,"
+ "Block *trueDest, ValueRange trueOperands,"
+ "Block *falseDest, ValueRange falseOperands,"
+ "Optional<std::pair<uint32_t, uint32_t>> weights = {}", [{
+ ElementsAttr weightsAttr;
+ if (weights) {
+ weightsAttr =
+ builder.getI32VectorAttr({static_cast<int32_t>(weights->first),
+ static_cast<int32_t>(weights->second)});
+ }
+ build(builder, result, condition, trueOperands, falseOperands, weightsAttr,
+ trueDest, falseDest);
}]>, OpBuilder<
"OpBuilder &builder, OperationState &result, Value condition,"
"Block *trueDest, Block *falseDest, ValueRange falseOperands = {}", [{
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 12d3d4009bed..3a70dd3932e9 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -30,6 +30,7 @@
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
@@ -594,9 +595,22 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
return success();
}
if (auto condbrOp = dyn_cast<LLVM::CondBrOp>(opInst)) {
+ auto weights = condbrOp.branch_weights();
+ llvm::MDNode *branchWeights = nullptr;
+ if (weights) {
+ // Map weight attributes to LLVM metadata.
+ auto trueWeight =
+ weights.getValue().getValue(0).cast<IntegerAttr>().getInt();
+ auto falseWeight =
+ weights.getValue().getValue(1).cast<IntegerAttr>().getInt();
+ branchWeights =
+ llvm::MDBuilder(llvmModule->getContext())
+ .createBranchWeights(static_cast<uint32_t>(trueWeight),
+ static_cast<uint32_t>(falseWeight));
+ }
builder.CreateCondBr(valueMapping.lookup(condbrOp.getOperand(0)),
blockMapping[condbrOp.getSuccessor(0)],
- blockMapping[condbrOp.getSuccessor(1)]);
+ blockMapping[condbrOp.getSuccessor(1)], branchWeights);
return success();
}
diff --git a/mlir/test/Target/llvmir.mlir b/mlir/test/Target/llvmir.mlir
index 566212aa124f..954b5b134541 100644
--- a/mlir/test/Target/llvmir.mlir
+++ b/mlir/test/Target/llvmir.mlir
@@ -1252,3 +1252,17 @@ llvm.mlir.global internal constant @taker_of_address() : !llvm<"void()*"> {
%0 = llvm.mlir.addressof @address_taken : !llvm<"void()*">
llvm.return %0 : !llvm<"void()*">
}
+
+// -----
+
+// Check that llvm.cond_br weights are exported as branch_weights metadata (true-branch weight first, then false-branch weight).
+llvm.func @cond_br_weights(%cond : !llvm.i1, %arg0 : !llvm.i32, %arg1 : !llvm.i32) -> !llvm.i32 {
+ // CHECK: !prof ![[NODE:[0-9]+]]
+ llvm.cond_br %cond weights(dense<[5, 10]> : vector<2xi32>), ^bb1, ^bb2
+^bb1: // pred: ^bb0
+ llvm.return %arg0 : !llvm.i32
+^bb2: // pred: ^bb0
+ llvm.return %arg1 : !llvm.i32
+}
+
+// CHECK: ![[NODE]] = !{!"branch_weights", i32 5, i32 10}
More information about the Mlir-commits
mailing list