[Mlir-commits] [mlir] 1f4aa30 - [MLIR][SPIRVToLLVM] Branch weights support for BranchConditional conversion

George Mitenkov llvmlistbot at llvm.org
Wed Jul 29 00:12:13 PDT 2020


Author: George Mitenkov
Date: 2020-07-29T10:11:10+03:00
New Revision: 1f4aa30a4f8a3fb869baa741662b9f6b3c73a0e3

URL: https://github.com/llvm/llvm-project/commit/1f4aa30a4f8a3fb869baa741662b9f6b3c73a0e3
DIFF: https://github.com/llvm/llvm-project/commit/1f4aa30a4f8a3fb869baa741662b9f6b3c73a0e3.diff

LOG: [MLIR][SPIRVToLLVM] Branch weights support for BranchConditional conversion

Conversion of `spv.BranchConditional` now supports branch weights,
which are mapped to the weights vector of `llvm.cond_br`.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D84657

Added: 
    

Modified: 
    mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
    mlir/test/Conversion/SPIRVToLLVM/control-flow-ops-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
index 161fda0fc353..25a3ac07d5f4 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
@@ -459,13 +459,18 @@ class BranchConditionalConversionPattern
   LogicalResult
   matchAndRewrite(spirv::BranchConditionalOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    // There is no support of branch weights in LLVM dialect at the moment.
-    if (auto weights = op.branch_weights())
-      return failure();
+    // If branch weights exist, map them to 32-bit integer vector.
+    ElementsAttr branchWeights = nullptr;
+    if (auto weights = op.branch_weights()) {
+      VectorType weightType = VectorType::get(2, rewriter.getI32Type());
+      branchWeights =
+          DenseElementsAttr::get(weightType, weights.getValue().getValue());
+    }
 
     rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
-        op, op.condition(), op.getTrueBlock(), op.getTrueBlockArguments(),
-        op.getFalseBlock(), op.getFalseBlockArguments());
+        op, op.condition(), op.getTrueBlockArguments(),
+        op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
+        op.getFalseBlock());
     return success();
   }
 };

diff  --git a/mlir/test/Conversion/SPIRVToLLVM/control-flow-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/control-flow-ops-to-llvm.mlir
index 0a2f6d681769..3c92040a17ed 100644
--- a/mlir/test/Conversion/SPIRVToLLVM/control-flow-ops-to-llvm.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/control-flow-ops-to-llvm.mlir
@@ -66,16 +66,14 @@ spv.module Logical GLSL450 {
   ^inner_false(%arg3: i32, %arg4: i32):
     spv.Return
   }
-}
-
-// -----
 
-spv.module Logical GLSL450 {
   spv.func @cond_branch_with_weights(%cond: i1) -> () "None" {
-    // expected-error at +1 {{failed to legalize operation 'spv.BranchConditional' that was explicitly marked illegal}}
+    // CHECK: llvm.cond_br %{{.*}} weights(dense<[1, 2]> : vector<2xi32>), ^bb1, ^bb2
     spv.BranchConditional %cond [1, 2], ^true, ^false
+  // CHECK: ^bb1:
   ^true:
     spv.Return
+  // CHECK: ^bb2:
   ^false:
     spv.Return
   }


        


More information about the Mlir-commits mailing list