[Mlir-commits] [mlir] 99d03f0 - [MLIR][LLVMDialect] Added branch weights attribute to CondBrOp

George Mitenkov llvmlistbot at llvm.org
Fri Jul 24 00:14:15 PDT 2020


Author: George Mitenkov
Date: 2020-07-24T10:11:13+03:00
New Revision: 99d03f03919498b688a8921b2ec669057772803f

URL: https://github.com/llvm/llvm-project/commit/99d03f03919498b688a8921b2ec669057772803f
DIFF: https://github.com/llvm/llvm-project/commit/99d03f03919498b688a8921b2ec669057772803f.diff

LOG: [MLIR][LLVMDialect] Added branch weights attribute to CondBrOp

This patch introduces branch weights metadata to the `llvm.cond_br` op in
the LLVM dialect. It is modelled as an optional `ElementsAttr`, for example:
```
llvm.cond_br %cond weights(dense<[1, 3]> : vector<2xi32>), ^bb1, ^bb2
```
When exporting to LLVM IR, this attribute is transformed into a `!prof`
metadata node. The test for metadata creation is added to `mlir/test/Target/llvmir.mlir`.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D83658

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
    mlir/test/Target/llvmir.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index f421d2e46463..5322e243427a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -514,21 +514,29 @@ def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br",
      NoSideEffect]> {
   let arguments = (ins LLVMI1:$condition,
                    Variadic<LLVM_Type>:$trueDestOperands,
-                   Variadic<LLVM_Type>:$falseDestOperands);
+                   Variadic<LLVM_Type>:$falseDestOperands,
+                   OptionalAttr<ElementsAttr>:$branch_weights);
   let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);
   let assemblyFormat = [{
-    $condition `,`
+    $condition ( `weights` `(` $branch_weights^ `)` )? `,`
     $trueDest (`(` $trueDestOperands^ `:` type($trueDestOperands) `)`)? `,`
     $falseDest (`(` $falseDestOperands^ `:` type($falseDestOperands) `)`)?
     attr-dict
   }];
 
   let builders = [OpBuilder<
-    "OpBuilder &builder, OperationState &result, Value condition,"
-    "Block *trueDest, ValueRange trueOperands,"
-    "Block *falseDest, ValueRange falseOperands", [{
-      build(builder, result, condition, trueOperands, falseOperands, trueDest,
-            falseDest);
+     "OpBuilder &builder, OperationState &result, Value condition,"
+     "Block *trueDest, ValueRange trueOperands,"
+     "Block *falseDest, ValueRange falseOperands,"
+     "Optional<std::pair<uint32_t, uint32_t>> weights = {}", [{
+        ElementsAttr weightsAttr;
+        if (weights) {
+          weightsAttr =
+              builder.getI32VectorAttr({static_cast<int32_t>(weights->first),
+                                       static_cast<int32_t>(weights->second)});
+        }
+        build(builder, result, condition, trueOperands, falseOperands, weightsAttr,
+              trueDest, falseDest);
   }]>, OpBuilder<
     "OpBuilder &builder, OperationState &result, Value condition,"
     "Block *trueDest, Block *falseDest, ValueRange falseOperands = {}", [{

diff  --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 12d3d4009bed..3a70dd3932e9 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -30,6 +30,7 @@
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Module.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/Cloning.h"
@@ -594,9 +595,22 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
     return success();
   }
   if (auto condbrOp = dyn_cast<LLVM::CondBrOp>(opInst)) {
+    auto weights = condbrOp.branch_weights();
+    llvm::MDNode *branchWeights = nullptr;
+    if (weights) {
+      // Map weight attributes to LLVM metadata.
+      auto trueWeight =
+          weights.getValue().getValue(0).cast<IntegerAttr>().getInt();
+      auto falseWeight =
+          weights.getValue().getValue(1).cast<IntegerAttr>().getInt();
+      branchWeights =
+          llvm::MDBuilder(llvmModule->getContext())
+              .createBranchWeights(static_cast<uint32_t>(trueWeight),
+                                   static_cast<uint32_t>(falseWeight));
+    }
     builder.CreateCondBr(valueMapping.lookup(condbrOp.getOperand(0)),
                          blockMapping[condbrOp.getSuccessor(0)],
-                         blockMapping[condbrOp.getSuccessor(1)]);
+                         blockMapping[condbrOp.getSuccessor(1)], branchWeights);
     return success();
   }
 

diff  --git a/mlir/test/Target/llvmir.mlir b/mlir/test/Target/llvmir.mlir
index 566212aa124f..954b5b134541 100644
--- a/mlir/test/Target/llvmir.mlir
+++ b/mlir/test/Target/llvmir.mlir
@@ -1252,3 +1252,17 @@ llvm.mlir.global internal constant @taker_of_address() : !llvm<"void()*"> {
   %0 = llvm.mlir.addressof @address_taken : !llvm<"void()*">
   llvm.return %0 : !llvm<"void()*">
 }
+
+// -----
+
+// Check that llvm.cond_br `weights` are exported as `!prof` branch_weights metadata.
+llvm.func @cond_br_weights(%cond : !llvm.i1, %arg0 : !llvm.i32,  %arg1 : !llvm.i32) -> !llvm.i32 {
+  // CHECK: !prof ![[NODE:[0-9]+]]
+  llvm.cond_br %cond weights(dense<[5, 10]> : vector<2xi32>), ^bb1, ^bb2
+^bb1:  // pred: ^bb0
+  llvm.return %arg0 : !llvm.i32
+^bb2:  // pred: ^bb0
+  llvm.return %arg1 : !llvm.i32
+}
+
+// CHECK: ![[NODE]] = !{!"branch_weights", i32 5, i32 10}


        


More information about the Mlir-commits mailing list