[Mlir-commits] [mlir] 10fa277 - [mlir][llvm] Add branch weight op interface

Tobias Gysi llvmlistbot at llvm.org
Thu Jul 20 03:55:36 PDT 2023


Author: Tobias Gysi
Date: 2023-07-20T10:46:04Z
New Revision: 10fa27704b3165ddc4efbcf7964042b137e7fa7e

URL: https://github.com/llvm/llvm-project/commit/10fa27704b3165ddc4efbcf7964042b137e7fa7e
DIFF: https://github.com/llvm/llvm-project/commit/10fa27704b3165ddc4efbcf7964042b137e7fa7e.diff

LOG: [mlir][llvm] Add branch weight op interface

This revision adds a branch weight op interface for the call / branch
operations that support branch weights. It can be used in the LLVM IR
import and export to simplify the branch weight conversion. An
additional mapping between call operations and the LLVM call
instructions they are translated to ensures the actual conversion can
be done in the module translation itself, rather than in the dialect
translation interface. It also has the benefit that downstream users
can attach custom metadata to the translated call instruction during
the export to LLVM IR.

Reviewed By: zero9178, definelicht

Differential Revision: https://reviews.llvm.org/D155702

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
    mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
    mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
    mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
    mlir/test/Conversion/SPIRVToLLVM/control-flow-ops-to-llvm.mlir
    mlir/test/Dialect/LLVMIR/invalid.mlir
    mlir/test/Target/LLVMIR/Import/metadata-profiling.ll
    mlir/test/Target/LLVMIR/llvmir.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index 9f230bf0be87ea..7b33ec8bb0c30e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -30,7 +30,7 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> {
       /*args=*/        (ins),
       /*methodBody=*/  [{}],
       /*defaultImpl=*/ [{
-        ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+        auto op = cast<ConcreteOp>(this->getOperation());
         return op.getFastmathFlagsAttr();
       }]
       >,
@@ -48,6 +48,42 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> {
   ];
 }
 
+def BranchWeightOpInterface : OpInterface<"BranchWeightOpInterface"> {
+  let description = [{
+    An interface for operations that can carry branch weights metadata. It
+    provides setters and getters for the operation's branch weights attribute.
+    The default implementation of the interface methods expects the operation
+    to have an attribute of type DenseI32ArrayAttr named branch_weights.
+  }];
+
+  let cppNamespace = "::mlir::LLVM";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/        "Returns the branch weights attribute or nullptr",
+      /*returnType=*/  "DenseI32ArrayAttr",
+      /*methodName=*/  "getBranchWeightsOrNull",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        return op.getBranchWeightsAttr();
+      }]
+      >,
+    InterfaceMethod<
+      /*desc=*/        "Sets the branch weights attribute",
+      /*returnType=*/  "void",
+      /*methodName=*/  "setBranchWeights",
+      /*args=*/        (ins "DenseI32ArrayAttr":$attr),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        op.setBranchWeightsAttr(attr);
+      }]
+      >
+  ];
+}
+
 def AccessGroupOpInterface : OpInterface<"AccessGroupOpInterface"> {
   let description = [{
     An interface for memory operations that can carry access groups metadata.
@@ -67,7 +103,7 @@ def AccessGroupOpInterface : OpInterface<"AccessGroupOpInterface"> {
       /*args=*/        (ins),
       /*methodBody=*/  [{}],
       /*defaultImpl=*/ [{
-        ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+        auto op = cast<ConcreteOp>(this->getOperation());
         return op.getAccessGroupsAttr();
       }]
       >,
@@ -78,7 +114,7 @@ def AccessGroupOpInterface : OpInterface<"AccessGroupOpInterface"> {
       /*args=*/        (ins "const ArrayAttr":$attr),
       /*methodBody=*/  [{}],
       /*defaultImpl=*/ [{
-        ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+        auto op = cast<ConcreteOp>(this->getOperation());
         op.setAccessGroupsAttr(attr);
       }]
       >
@@ -105,7 +141,7 @@ def AliasAnalysisOpInterface : OpInterface<"AliasAnalysisOpInterface"> {
       /*args=*/        (ins),
       /*methodBody=*/  [{}],
       /*defaultImpl=*/ [{
-        ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+        auto op = cast<ConcreteOp>(this->getOperation());
         return op.getAliasScopesAttr();
       }]
       >,
@@ -116,7 +152,7 @@ def AliasAnalysisOpInterface : OpInterface<"AliasAnalysisOpInterface"> {
       /*args=*/        (ins "const ArrayAttr":$attr),
       /*methodBody=*/  [{}],
       /*defaultImpl=*/ [{
-        ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+        auto op = cast<ConcreteOp>(this->getOperation());
         op.setAliasScopesAttr(attr);
       }]
       >,
@@ -127,7 +163,7 @@ def AliasAnalysisOpInterface : OpInterface<"AliasAnalysisOpInterface"> {
       /*args=*/        (ins),
       /*methodBody=*/  [{}],
       /*defaultImpl=*/ [{
-        ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+        auto op = cast<ConcreteOp>(this->getOperation());
         return op.getNoaliasScopesAttr();
       }]
       >,
@@ -138,7 +174,7 @@ def AliasAnalysisOpInterface : OpInterface<"AliasAnalysisOpInterface"> {
       /*args=*/        (ins "const ArrayAttr":$attr),
       /*methodBody=*/  [{}],
       /*defaultImpl=*/ [{
-        ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+        auto op = cast<ConcreteOp>(this->getOperation());
         op.setNoaliasScopesAttr(attr);
       }]
       >,
@@ -149,7 +185,7 @@ def AliasAnalysisOpInterface : OpInterface<"AliasAnalysisOpInterface"> {
       /*args=*/        (ins),
       /*methodBody=*/  [{}],
       /*defaultImpl=*/ [{
-        ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+        auto op = cast<ConcreteOp>(this->getOperation());
         return op.getTbaaAttr();
       }]
       >,
@@ -160,7 +196,7 @@ def AliasAnalysisOpInterface : OpInterface<"AliasAnalysisOpInterface"> {
       /*args=*/        (ins "const ArrayAttr":$attr),
       /*methodBody=*/  [{}],
       /*defaultImpl=*/ [{
-        ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+        auto op = cast<ConcreteOp>(this->getOperation());
         op.setTbaaAttr(attr);
       }]
       >

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 01070fe74bff9a..2d0ca913e8137b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -536,12 +536,14 @@ def LLVM_FPTruncOp : LLVM_CastOp<"fptrunc", "FPTrunc",
 def LLVM_InvokeOp : LLVM_Op<"invoke", [
                       AttrSizedOperandSegments,
                       DeclareOpInterfaceMethods<BranchOpInterface>,
-                      DeclareOpInterfaceMethods<CallOpInterface>, Terminator]> {
+                      DeclareOpInterfaceMethods<CallOpInterface>,
+                      DeclareOpInterfaceMethods<BranchWeightOpInterface>,
+                      Terminator]> {
   let arguments = (ins OptionalAttr<FlatSymbolRefAttr>:$callee,
                    Variadic<LLVM_Type>:$callee_operands,
                    Variadic<LLVM_Type>:$normalDestOperands,
                    Variadic<LLVM_Type>:$unwindDestOperands,
-                   OptionalAttr<ElementsAttr>:$branch_weights);
+                   OptionalAttr<DenseI32ArrayAttr>:$branch_weights);
   let results = (outs Variadic<LLVM_Type>);
   let successors = (successor AnySuccessor:$normalDest,
                               AnySuccessor:$unwindDest);
@@ -582,7 +584,8 @@ def LLVM_LandingpadOp : LLVM_Op<"landingpad"> {
 def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
                     [DeclareOpInterfaceMethods<FastmathFlagsInterface>,
                      DeclareOpInterfaceMethods<CallOpInterface>,
-                     DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+                     DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+                     DeclareOpInterfaceMethods<BranchWeightOpInterface>]> {
   let summary = "Call to an LLVM function.";
   let description = [{
     In LLVM IR, functions may return either 0 or 1 value. LLVM IR dialect
@@ -616,7 +619,7 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
                   Variadic<LLVM_Type>,
                   DefaultValuedAttr<LLVM_FastmathFlagsAttr,
                                    "{}">:$fastmathFlags,
-                  OptionalAttr<ElementsAttr>:$branch_weights);
+                  OptionalAttr<DenseI32ArrayAttr>:$branch_weights);
   // Append the aliasing related attributes defined in LLVM_MemAccessOpBase.
   let arguments = !con(args, aliasAttrs);
   let results = (outs Optional<LLVM_Type>:$result);
@@ -847,12 +850,14 @@ def LLVM_BrOp : LLVM_TerminatorOp<"br",
   ];
 }
 def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br",
-    [AttrSizedOperandSegments, DeclareOpInterfaceMethods<BranchOpInterface>,
+    [AttrSizedOperandSegments,
+     DeclareOpInterfaceMethods<BranchOpInterface>,
+     DeclareOpInterfaceMethods<BranchWeightOpInterface>,
      Pure]> {
   let arguments = (ins I1:$condition,
                    Variadic<LLVM_Type>:$trueDestOperands,
                    Variadic<LLVM_Type>:$falseDestOperands,
-                   OptionalAttr<ElementsAttr>:$branch_weights,
+                   OptionalAttr<DenseI32ArrayAttr>:$branch_weights,
                    OptionalAttr<LoopAnnotationAttr>:$loop_annotation);
   let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);
   let assemblyFormat = [{
@@ -874,7 +879,7 @@ def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br",
             falseOperands);
   }]>,
   OpBuilder<(ins "Value":$condition, "ValueRange":$trueOperands, "ValueRange":$falseOperands,
-    "ElementsAttr":$branchWeights, "Block *":$trueDest, "Block *":$falseDest),
+    "DenseI32ArrayAttr":$branchWeights, "Block *":$trueDest, "Block *":$falseDest),
   [{
       build($_builder, $_state, condition, trueOperands, falseOperands, branchWeights,
       {}, trueDest, falseDest);
@@ -934,7 +939,9 @@ def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable"> {
 }
 
 def LLVM_SwitchOp : LLVM_TerminatorOp<"switch",
-    [AttrSizedOperandSegments, DeclareOpInterfaceMethods<BranchOpInterface>,
+    [AttrSizedOperandSegments,
+     DeclareOpInterfaceMethods<BranchOpInterface>,
+     DeclareOpInterfaceMethods<BranchWeightOpInterface>,
      Pure]> {
   let arguments = (ins
     AnyInteger:$value,
@@ -942,7 +949,7 @@ def LLVM_SwitchOp : LLVM_TerminatorOp<"switch",
     VariadicOfVariadic<AnyType, "case_operand_segments">:$caseOperands,
     OptionalAttr<AnyIntElementsAttr>:$case_values,
     DenseI32ArrayAttr:$case_operand_segments,
-    OptionalAttr<ElementsAttr>:$branch_weights
+    OptionalAttr<DenseI32ArrayAttr>:$branch_weights
   );
   let successors = (successor
     AnySuccessor:$defaultDestination,

diff  --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index da4d43ac9ac844..0d296aac055914 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -118,6 +118,20 @@ class ModuleTranslation {
     return branchMapping.lookup(op);
   }
 
+  /// Records that the MLIR call operation `mlir` corresponds to the LLVM call
+  /// instruction `llvm`. Asserts that `mlir` has not been mapped before.
+  void mapCall(Operation *mlir, llvm::CallInst *llvm) {
+    auto result = callMapping.try_emplace(mlir, llvm);
+    (void)result; // Only read by the assert below, which NDEBUG compiles out.
+    assert(result.second && "attempting to map a call that is already mapped");
+  }
+
+  /// Returns the LLVM call instruction registered for `op` via mapCall, or
+  /// nullptr if `op` has no such mapping.
+  llvm::CallInst *lookupCall(Operation *op) const {
+    return callMapping.lookup(op);
+  }
+
   /// Removes the mapping for blocks contained in the region and values defined
   /// in these blocks.
   void forgetMapping(Region &region);
@@ -141,6 +155,9 @@ class ModuleTranslation {
   /// Sets LLVM TBAA metadata for memory operations that have TBAA attributes.
   void setTBAAMetadata(AliasAnalysisOpInterface op, llvm::Instruction *inst);
 
+  /// Attaches `op`'s branch weights, if any, to its LLVM instruction (!prof).
+  void setBranchWeightsMetadata(BranchWeightOpInterface op);
+
   /// Sets LLVM loop metadata for branch operations that have a loop annotation
   /// attribute.
   void setLoopMetadata(Operation *op, llvm::Instruction *inst);
@@ -328,6 +345,11 @@ class ModuleTranslation {
   /// values after all operations are converted.
   DenseMap<Operation *, llvm::Instruction *> branchMapping;
 
+  /// A mapping between MLIR LLVM dialect call operations and LLVM IR call
+  /// instructions. This allows for adding branch weights after the operations
+  /// have been converted.
+  DenseMap<Operation *, llvm::CallInst *> callMapping;
+
   /// Mapping from an alias scope metadata operation to its LLVM metadata.
   /// This map is populated on module entry.
   DenseMap<Attribute, llvm::MDNode *> aliasScopeMetadataMapping;

diff  --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index 28e587a066e4ea..1d32e6e55f6ae4 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -553,10 +553,12 @@ class BranchConditionalConversionPattern
   matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // If branch weights exist, map them to 32-bit integer vector.
-    ElementsAttr branchWeights = nullptr;
+    DenseI32ArrayAttr branchWeights = nullptr;
     if (auto weights = op.getBranchWeights()) {
-      VectorType weightType = VectorType::get(2, rewriter.getI32Type());
-      branchWeights = DenseElementsAttr::get(weightType, weights->getValue());
+      SmallVector<int32_t> weightValues;
+      for (auto weight : weights->getAsRange<IntegerAttr>())
+        weightValues.push_back(weight.getInt());
+      branchWeights = DenseI32ArrayAttr::get(getContext(), weightValues);
     }
 
     rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 8eee3b2afc8414..f4d9c95e4179fc 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -310,11 +310,11 @@ void CondBrOp::build(OpBuilder &builder, OperationState &result,
                      Value condition, Block *trueDest, ValueRange trueOperands,
                      Block *falseDest, ValueRange falseOperands,
                      std::optional<std::pair<uint32_t, uint32_t>> weights) {
-  ElementsAttr weightsAttr;
+  DenseI32ArrayAttr weightsAttr;
   if (weights)
     weightsAttr =
-        builder.getI32VectorAttr({static_cast<int32_t>(weights->first),
-                                  static_cast<int32_t>(weights->second)});
+        builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights->first),
+                                      static_cast<int32_t>(weights->second)});
 
   build(builder, result, condition, trueOperands, falseOperands, weightsAttr,
         /*loop_annotation=*/{}, trueDest, falseDest);
@@ -330,9 +330,9 @@ void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
                      BlockRange caseDestinations,
                      ArrayRef<ValueRange> caseOperands,
                      ArrayRef<int32_t> branchWeights) {
-  ElementsAttr weightsAttr;
+  DenseI32ArrayAttr weightsAttr;
   if (!branchWeights.empty())
-    weightsAttr = builder.getI32VectorAttr(llvm::to_vector<4>(branchWeights));
+    weightsAttr = builder.getDenseI32ArrayAttr(branchWeights);
 
   build(builder, result, value, defaultOperands, caseOperands, caseValues,
         weightsAttr, defaultDestination, caseDestinations);

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
index a6f0ebe54aac2f..40d8253d822f64 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
@@ -125,13 +125,11 @@ static LogicalResult setProfilingAttr(OpBuilder &builder, llvm::MDNode *node,
     branchWeights.push_back(branchWeight->getZExtValue());
   }
 
-  return TypeSwitch<Operation *, LogicalResult>(op)
-      .Case<CondBrOp, SwitchOp, CallOp, InvokeOp>([&](auto branchWeightOp) {
-        branchWeightOp.setBranchWeightsAttr(
-            builder.getI32VectorAttr(branchWeights));
-        return success();
-      })
-      .Default([](auto) { return failure(); });
+  if (auto iface = dyn_cast<BranchWeightOpInterface>(op)) {
+    iface.setBranchWeights(builder.getDenseI32ArrayAttr(branchWeights));
+    return success();
+  }
+  return failure();
 }
 
 /// Searches for the attribute that maps to the given TBAA metadata `node` and

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index a044930a0cf8bd..8f7c5d8b799e27 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -124,21 +124,6 @@ convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder,
   return success();
 }
 
-/// Constructs branch weights metadata if the provided `weights` hold a value,
-/// otherwise returns nullptr.
-static llvm::MDNode *
-convertBranchWeights(std::optional<ElementsAttr> weights,
-                     LLVM::ModuleTranslation &moduleTranslation) {
-  if (!weights)
-    return nullptr;
-  SmallVector<uint32_t> weightValues;
-  weightValues.reserve(weights->size());
-  for (APInt weight : llvm::cast<DenseIntElementsAttr>(*weights))
-    weightValues.push_back(weight.getLimitedValue());
-  return llvm::MDBuilder(moduleTranslation.getLLVMContext())
-      .createBranchWeights(weightValues);
-}
-
 static LogicalResult
 convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
                      LLVM::ModuleTranslation &moduleTranslation) {
@@ -182,10 +167,6 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
                                                       callOp.getArgOperands()),
                                 operandsRef.front(), operandsRef.drop_front());
     }
-    llvm::MDNode *branchWeights =
-        convertBranchWeights(callOp.getBranchWeights(), moduleTranslation);
-    if (branchWeights)
-      call->setMetadata(llvm::LLVMContext::MD_prof, branchWeights);
     moduleTranslation.setAccessGroupsMetadata(callOp, call);
     moduleTranslation.setAliasScopeMetadata(callOp, call);
     moduleTranslation.setTBAAMetadata(callOp, call);
@@ -196,7 +177,10 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
       return success();
     }
     // Check that LLVM call returns void for 0-result functions.
-    return success(call->getType()->isVoidTy());
+    if (!call->getType()->isVoidTy())
+      return failure();
+    moduleTranslation.mapCall(callOp, call);
+    return success();
   }
 
   if (auto inlineAsmOp = dyn_cast<LLVM::InlineAsmOp>(opInst)) {
@@ -274,10 +258,6 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
           moduleTranslation.lookupBlock(invOp.getSuccessor(1)),
           operandsRef.drop_front());
     }
-    llvm::MDNode *branchWeights =
-        convertBranchWeights(invOp.getBranchWeights(), moduleTranslation);
-    if (branchWeights)
-      result->setMetadata(llvm::LLVMContext::MD_prof, branchWeights);
     moduleTranslation.mapBranch(invOp, result);
     // InvokeOp can only have 0 or 1 result
     if (invOp->getNumResults() != 0) {
@@ -314,23 +294,19 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
     return success();
   }
   if (auto condbrOp = dyn_cast<LLVM::CondBrOp>(opInst)) {
-    llvm::MDNode *branchWeights =
-        convertBranchWeights(condbrOp.getBranchWeights(), moduleTranslation);
     llvm::BranchInst *branch = builder.CreateCondBr(
         moduleTranslation.lookupValue(condbrOp.getOperand(0)),
         moduleTranslation.lookupBlock(condbrOp.getSuccessor(0)),
-        moduleTranslation.lookupBlock(condbrOp.getSuccessor(1)), branchWeights);
+        moduleTranslation.lookupBlock(condbrOp.getSuccessor(1)));
     moduleTranslation.mapBranch(&opInst, branch);
     moduleTranslation.setLoopMetadata(&opInst, branch);
     return success();
   }
   if (auto switchOp = dyn_cast<LLVM::SwitchOp>(opInst)) {
-    llvm::MDNode *branchWeights =
-        convertBranchWeights(switchOp.getBranchWeights(), moduleTranslation);
     llvm::SwitchInst *switchInst = builder.CreateSwitch(
         moduleTranslation.lookupValue(switchOp.getValue()),
         moduleTranslation.lookupBlock(switchOp.getDefaultDestination()),
-        switchOp.getCaseDestinations().size(), branchWeights);
+        switchOp.getCaseDestinations().size());
 
     auto *ty = llvm::cast<llvm::IntegerType>(
         moduleTranslation.convertType(switchOp.getValue().getType()));

diff  --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index d363fb8d91862d..cd3a645a18c686 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -664,6 +664,10 @@ LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments,
 
     if (failed(convertOperation(op, builder)))
       return failure();
+
+    // Set the branch weight metadata on the translated instruction.
+    if (auto iface = dyn_cast<BranchWeightOpInterface>(op))
+      setBranchWeightsMetadata(iface);
   }
 
   return success();
@@ -1183,6 +1187,19 @@ void ModuleTranslation::setTBAAMetadata(AliasAnalysisOpInterface op,
   inst->setMetadata(llvm::LLVMContext::MD_tbaa, node);
 }
 
+void ModuleTranslation::setBranchWeightsMetadata(BranchWeightOpInterface op) {
+  DenseI32ArrayAttr weightsAttr = op.getBranchWeightsOrNull();
+  if (!weightsAttr)
+    return; // No branch_weights attribute: nothing to export.
+  // NOTE(review): non-void calls seem not to reach mapCall; inst may be null.
+  llvm::Instruction *inst = isa<CallOp>(op) ? lookupCall(op) : lookupBranch(op);
+  assert(inst && "expected the operation to have a mapping to an instruction");
+  SmallVector<uint32_t> weights(weightsAttr.asArrayRef());
+  inst->setMetadata(
+      llvm::LLVMContext::MD_prof,
+      llvm::MDBuilder(getLLVMContext()).createBranchWeights(weights));
+}
+
 LogicalResult ModuleTranslation::createTBAAMetadata() {
   llvm::LLVMContext &ctx = llvmModule->getContext();
   llvm::IntegerType *offsetTy = llvm::IntegerType::get(ctx, 64);

diff  --git a/mlir/test/Conversion/SPIRVToLLVM/control-flow-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/control-flow-ops-to-llvm.mlir
index 8c58d59e86d7e4..54ef71f75f528f 100644
--- a/mlir/test/Conversion/SPIRVToLLVM/control-flow-ops-to-llvm.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/control-flow-ops-to-llvm.mlir
@@ -68,7 +68,7 @@ spirv.module Logical GLSL450 {
   }
 
   spirv.func @cond_branch_with_weights(%cond: i1) -> () "None" {
-    // CHECK: llvm.cond_br %{{.*}} weights(dense<[1, 2]> : vector<2xi32>), ^bb1, ^bb2
+    // CHECK: llvm.cond_br %{{.*}} weights([1, 2]), ^bb1, ^bb2
     spirv.BranchConditional %cond [1, 2], ^true, ^false
   // CHECK: ^bb1:
   ^true:

diff  --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index da4799d8a26392..09bbc5a4739657 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -874,7 +874,7 @@ func.func @switch_wrong_number_of_weights(%arg0 : i32) {
   // expected-error @+1 {{expects number of branch weights to match number of successors: 3 vs 2}}
   llvm.switch %arg0 : i32, ^bb1 [
     42: ^bb2(%arg0, %arg0 : i32, i32)
-  ] {branch_weights = dense<[13, 17, 19]> : vector<3xi32>}
+  ] {branch_weights = array<i32: 13, 17, 19>}
 
 ^bb1: // pred: ^bb0
   llvm.return

diff  --git a/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll b/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll
index 688dd100f98235..cc3b47a54dfe9f 100644
--- a/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll
+++ b/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll
@@ -4,7 +4,7 @@
 define i64 @cond_br(i1 %arg1, i64 %arg2) {
 entry:
   ; CHECK: llvm.cond_br
-  ; CHECK-SAME: weights(dense<[0, 3]> : vector<2xi32>)
+  ; CHECK-SAME: weights([0, 3])
   br i1 %arg1, label %bb1, label %bb2, !prof !0
 bb1:
   ret i64 %arg2
@@ -19,7 +19,7 @@ bb2:
 ; CHECK-LABEL: @simple_switch(
 define i32 @simple_switch(i32 %arg1) {
   ; CHECK: llvm.switch
-  ; CHECK: {branch_weights = dense<[42, 3, 5]> : vector<3xi32>}
+  ; CHECK: {branch_weights = array<i32: 42, 3, 5>}
   switch i32 %arg1, label %bbd [
     i32 0, label %bb1
     i32 9, label %bb2
@@ -41,7 +41,7 @@ declare void @fn()
 
 ; CHECK-LABEL: @call_branch_weights
 define void @call_branch_weights() {
-  ; CHECK:  llvm.call @fn() {branch_weights = dense<42> : vector<1xi32>}
+  ; CHECK:  llvm.call @fn() {branch_weights = array<i32: 42>}
   call void @fn(), !prof !0
   ret void
 }
@@ -55,7 +55,7 @@ declare i32 @__gxx_personality_v0(...)
 
 ; CHECK-LABEL: @invoke_branch_weights
 define i32 @invoke_branch_weights() personality ptr @__gxx_personality_v0 {
-  ; CHECK: llvm.invoke @foo() to ^bb2 unwind ^bb1 {branch_weights = dense<[42, 99]> : vector<2xi32>} : () -> ()
+  ; CHECK: llvm.invoke @foo() to ^bb2 unwind ^bb1 {branch_weights = array<i32: 42, 99>} : () -> ()
   invoke void @foo() to label %bb2 unwind label %bb1, !prof !0
 bb1:
   %1 = landingpad { ptr, i32 } cleanup

diff  --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 2500de25f49891..3f97ebd9aa363f 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -1802,7 +1802,7 @@ llvm.func @foo() {
 // Check that branch weight attributes are exported properly as metadata.
 llvm.func @cond_br_weights(%cond : i1, %arg0 : i32,  %arg1 : i32) -> i32 {
   // CHECK: !prof ![[NODE:[0-9]+]]
-  llvm.cond_br %cond weights(dense<[5, 10]> : vector<2xi32>), ^bb1, ^bb2
+  llvm.cond_br %cond weights([5, 10]), ^bb1, ^bb2
 ^bb1:  // pred: ^bb0
   llvm.return %arg0 : i32
 ^bb2:  // pred: ^bb0
@@ -1818,7 +1818,7 @@ llvm.func @fn()
 // CHECK-LABEL: @call_branch_weights
 llvm.func @call_branch_weights() {
   // CHECK: !prof ![[NODE:[0-9]+]]
-  llvm.call @fn() {branch_weights = dense<42> : vector<1xi32>} : () -> ()
+  llvm.call @fn() {branch_weights = array<i32 : 42>} : () -> ()
   llvm.return
 }
 
@@ -1833,7 +1833,7 @@ llvm.func @__gxx_personality_v0(...) -> i32
 llvm.func @invoke_branch_weights() -> i32 attributes {personality = @__gxx_personality_v0} {
   %0 = llvm.mlir.constant(1 : i32) : i32
   // CHECK: !prof ![[NODE:[0-9]+]]
-  llvm.invoke @foo() to ^bb2 unwind ^bb1 {branch_weights = dense<[42, 99]> : vector<2xi32>} : () -> ()
+  llvm.invoke @foo() to ^bb2 unwind ^bb1 {branch_weights = array<i32 : 42, 99>} : () -> ()
 ^bb1:  // pred: ^bb0
   %1 = llvm.landingpad cleanup : !llvm.struct<(ptr<i8>, i32)>
   llvm.br ^bb2
@@ -2062,7 +2062,7 @@ llvm.func @switch_weights(%arg0: i32) -> i32 {
   llvm.switch %arg0 : i32, ^bb1(%0 : i32) [
     9: ^bb2(%1, %2 : i32, i32),
     99: ^bb3
-  ] {branch_weights = dense<[13, 17, 19]> : vector<3xi32>}
+  ] {branch_weights = array<i32 : 13, 17, 19>}
 
 ^bb1(%3: i32):  // pred: ^bb0
   llvm.return %3 : i32


        


More information about the Mlir-commits mailing list