[flang-commits] [flang] 70343c8 - [mlir][flang] Added Weighted[Region]BranchOpInterface's. (#142079)
via flang-commits
flang-commits at lists.llvm.org
Tue Jun 17 16:14:18 PDT 2025
Author: Slava Zakharin
Date: 2025-06-17T16:14:13-07:00
New Revision: 70343c8d44273c187e3f7fa5e2037fbc41307077
URL: https://github.com/llvm/llvm-project/commit/70343c8d44273c187e3f7fa5e2037fbc41307077
DIFF: https://github.com/llvm/llvm-project/commit/70343c8d44273c187e3f7fa5e2037fbc41307077.diff
LOG: [mlir][flang] Added Weighted[Region]BranchOpInterface's. (#142079)
The new interfaces provide getters and setters for the weight
information about the branches of BranchOpInterface and
RegionBranchOpInterface operations.
These interfaces are done the same way as LLVM dialect's
BranchWeightOpInterface.
The plan is to produce this information in Flang, e.g. mark
code that is most probably "cold" as such and allow LLVM to order
basic blocks accordingly. An example of such code is the
copy loops generated for array repacking - we can mark them
as "cold", assuming that the copy will not happen dynamically.
If the copy actually happens, the overhead of the copy is probably high
enough that we may not care about the small overhead
of jumping to the "cold" code and fetching it.
Added:
flang/test/Fir/cfg-conversion-if.fir
Modified:
flang/include/flang/Optimizer/Dialect/FIROps.td
flang/lib/Optimizer/Dialect/FIROps.cpp
flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp
flang/test/Fir/fir-ops.fir
flang/test/Fir/invalid.fir
mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Interfaces/ControlFlowInterfaces.cpp
mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
mlir/test/Conversion/ControlFlowToLLVM/branch.mlir
mlir/test/Dialect/ControlFlow/invalid.mlir
mlir/test/Dialect/ControlFlow/ops.mlir
mlir/test/Target/LLVMIR/Import/metadata-profiling.ll
mlir/test/Target/LLVMIR/llvmir-invalid.mlir
mlir/test/Target/LLVMIR/llvmir.mlir
Removed:
################################################################################
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 90e05ce3d5ca6..27a6ca4ebdb4e 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2323,9 +2323,13 @@ def fir_DoLoopOp : region_Op<"do_loop", [AttrSizedOperandSegments,
}];
}
-def fir_IfOp : region_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
- "getRegionInvocationBounds", "getEntrySuccessorRegions"]>, RecursiveMemoryEffects,
- NoRegionArguments]> {
+def fir_IfOp
+ : region_Op<
+ "if", [DeclareOpInterfaceMethods<
+ RegionBranchOpInterface, ["getRegionInvocationBounds",
+ "getEntrySuccessorRegions"]>,
+ RecursiveMemoryEffects, NoRegionArguments,
+ WeightedRegionBranchOpInterface]> {
let summary = "if-then-else conditional operation";
let description = [{
Used to conditionally execute operations. This operation is the FIR
@@ -2342,7 +2346,8 @@ def fir_IfOp : region_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterfac
```
}];
- let arguments = (ins I1:$condition);
+ let arguments = (ins I1:$condition,
+ OptionalAttr<DenseI32ArrayAttr>:$region_weights);
let results = (outs Variadic<AnyType>:$results);
let regions = (region
@@ -2371,6 +2376,11 @@ def fir_IfOp : region_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterfac
void resultToSourceOps(llvm::SmallVectorImpl<mlir::Value> &results,
unsigned resultNum);
+
+ /// Returns the display name string for the region_weights attribute.
+ static constexpr llvm::StringRef getWeightsAttrAssemblyName() {
+ return "weights";
+ }
}];
}
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 6181e1fad4240..ecfa2939e96a6 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -4418,6 +4418,19 @@ mlir::ParseResult fir::IfOp::parse(mlir::OpAsmParser &parser,
parser.resolveOperand(cond, i1Type, result.operands))
return mlir::failure();
+ if (mlir::succeeded(
+ parser.parseOptionalKeyword(getWeightsAttrAssemblyName()))) {
+ if (parser.parseLParen())
+ return mlir::failure();
+ mlir::DenseI32ArrayAttr weights;
+ if (parser.parseCustomAttributeWithFallback(weights, mlir::Type{}))
+ return mlir::failure();
+ if (weights)
+ result.addAttribute(getRegionWeightsAttrName(result.name), weights);
+ if (parser.parseRParen())
+ return mlir::failure();
+ }
+
if (parser.parseOptionalArrowTypeList(result.types))
return mlir::failure();
@@ -4449,6 +4462,11 @@ llvm::LogicalResult fir::IfOp::verify() {
void fir::IfOp::print(mlir::OpAsmPrinter &p) {
bool printBlockTerminators = false;
p << ' ' << getCondition();
+ if (auto weights = getRegionWeightsAttr()) {
+ p << ' ' << getWeightsAttrAssemblyName() << '(';
+ p.printStrippedAttrOrType(weights);
+ p << ')';
+ }
if (!getResults().empty()) {
p << " -> (" << getResultTypes() << ')';
printBlockTerminators = true;
@@ -4464,7 +4482,8 @@ void fir::IfOp::print(mlir::OpAsmPrinter &p) {
p.printRegion(otherReg, /*printEntryBlockArgs=*/false,
printBlockTerminators);
}
- p.printOptionalAttrDict((*this)->getAttrs());
+ p.printOptionalAttrDict((*this)->getAttrs(),
+ /*elideAttrs=*/{getRegionWeightsAttrName()});
}
void fir::IfOp::resultToSourceOps(llvm::SmallVectorImpl<mlir::Value> &results,
diff --git a/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp b/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp
index 8a9e9b80134b8..3d35803e6a2d3 100644
--- a/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp
+++ b/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp
@@ -212,9 +212,12 @@ class CfgIfConv : public mlir::OpRewritePattern<fir::IfOp> {
}
rewriter.setInsertionPointToEnd(condBlock);
- rewriter.create<mlir::cf::CondBranchOp>(
+ auto branchOp = rewriter.create<mlir::cf::CondBranchOp>(
loc, ifOp.getCondition(), ifOpBlock, llvm::ArrayRef<mlir::Value>(),
otherwiseBlock, llvm::ArrayRef<mlir::Value>());
+ llvm::ArrayRef<int32_t> weights = ifOp.getWeights();
+ if (!weights.empty())
+ branchOp.setWeights(weights);
rewriter.replaceOp(ifOp, continueBlock->getArguments());
return success();
}
diff --git a/flang/test/Fir/cfg-conversion-if.fir b/flang/test/Fir/cfg-conversion-if.fir
new file mode 100644
index 0000000000000..1e30ee8e64f02
--- /dev/null
+++ b/flang/test/Fir/cfg-conversion-if.fir
@@ -0,0 +1,46 @@
+// RUN: fir-opt --split-input-file --cfg-conversion %s | FileCheck %s
+
+func.func private @callee() -> none
+
+// CHECK-LABEL: func.func @if_then(
+// CHECK-SAME: %[[ARG0:.*]]: i1) {
+// CHECK: cf.cond_br %[[ARG0]] weights([10, 90]), ^bb1, ^bb2
+// CHECK: ^bb1:
+// CHECK: %[[VAL_0:.*]] = fir.call @callee() : () -> none
+// CHECK: cf.br ^bb2
+// CHECK: ^bb2:
+// CHECK: return
+// CHECK: }
+func.func @if_then(%cond: i1) {
+ fir.if %cond weights([10, 90]) {
+ fir.call @callee() : () -> none
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @if_then_else(
+// CHECK-SAME: %[[ARG0:.*]]: i1) -> i32 {
+// CHECK: %[[VAL_0:.*]] = arith.constant 0 : i32
+// CHECK: %[[VAL_1:.*]] = arith.constant 1 : i32
+// CHECK: cf.cond_br %[[ARG0]] weights([90, 10]), ^bb1, ^bb2
+// CHECK: ^bb1:
+// CHECK: cf.br ^bb3(%[[VAL_0]] : i32)
+// CHECK: ^bb2:
+// CHECK: cf.br ^bb3(%[[VAL_1]] : i32)
+// CHECK: ^bb3(%[[VAL_2:.*]]: i32):
+// CHECK: cf.br ^bb4
+// CHECK: ^bb4:
+// CHECK: return %[[VAL_2]] : i32
+// CHECK: }
+func.func @if_then_else(%cond: i1) -> i32 {
+ %c0 = arith.constant 0 : i32
+ %c1 = arith.constant 1 : i32
+ %result = fir.if %cond weights([90, 10]) -> i32 {
+ fir.result %c0 : i32
+ } else {
+ fir.result %c1 : i32
+ }
+ return %result : i32
+}
diff --git a/flang/test/Fir/fir-ops.fir b/flang/test/Fir/fir-ops.fir
index 9c444d2f4e0bc..3585bf9efca3e 100644
--- a/flang/test/Fir/fir-ops.fir
+++ b/flang/test/Fir/fir-ops.fir
@@ -1015,3 +1015,19 @@ func.func @test_box_total_elements(%arg0: !fir.class<!fir.type<sometype{i:i32}>>
%6 = arith.addi %2, %5 : index
return %6 : index
}
+
+// CHECK-LABEL: func.func @test_if_weights(
+// CHECK-SAME: %[[ARG0:.*]]: i1) {
+func.func @test_if_weights(%cond: i1) {
+// CHECK: fir.if %[[ARG0]] weights([99, 1]) {
+// CHECK: }
+ fir.if %cond weights([99, 1]) {
+ }
+// CHECK: fir.if %[[ARG0]] weights([99, 1]) {
+// CHECK: } else {
+// CHECK: }
+ fir.if %cond weights ([99,1]) {
+ } else {
+ }
+ return
+}
diff --git a/flang/test/Fir/invalid.fir b/flang/test/Fir/invalid.fir
index 45cae1f82cb8e..aca0ecc1abdc1 100644
--- a/flang/test/Fir/invalid.fir
+++ b/flang/test/Fir/invalid.fir
@@ -1393,3 +1393,31 @@ fir.local {type = local_init} @x.localizer : f32 init {
^bb0(%arg0: f32, %arg1: f32):
fir.yield(%arg0 : f32)
}
+
+// -----
+
+func.func @wrong_weights_number_in_if_then(%cond: i1) {
+// expected-error @below {{expects number of region weights to match number of regions: 1 vs 2}}
+ fir.if %cond weights([50]) {
+ }
+ return
+}
+
+// -----
+
+func.func @wrong_weights_number_in_if_then_else(%cond: i1) {
+// expected-error @below {{expects number of region weights to match number of regions: 3 vs 2}}
+ fir.if %cond weights([50, 40, 10]) {
+ } else {
+ }
+ return
+}
+
+// -----
+
+func.func @negative_weight_in_if_then(%cond: i1) {
+// expected-error @below {{weight #0 must be non-negative}}
+ fir.if %cond weights([-1, 101]) {
+ }
+ return
+}
diff --git a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
index 48f12b46a57f1..79da81ba049dd 100644
--- a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
+++ b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
@@ -112,10 +112,11 @@ def BranchOp : CF_Op<"br", [
// CondBranchOp
//===----------------------------------------------------------------------===//
-def CondBranchOp : CF_Op<"cond_br",
- [AttrSizedOperandSegments,
- DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
- Pure, Terminator]> {
+def CondBranchOp
+ : CF_Op<"cond_br", [AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<
+ BranchOpInterface, ["getSuccessorForOperands"]>,
+ WeightedBranchOpInterface, Pure, Terminator]> {
let summary = "Conditional branch operation";
let description = [{
The `cf.cond_br` terminator operation represents a conditional branch on a
@@ -144,20 +145,23 @@ def CondBranchOp : CF_Op<"cond_br",
```
}];
- let arguments = (ins I1:$condition,
- Variadic<AnyType>:$trueDestOperands,
- Variadic<AnyType>:$falseDestOperands);
+ let arguments = (ins I1:$condition, Variadic<AnyType>:$trueDestOperands,
+ Variadic<AnyType>:$falseDestOperands,
+ OptionalAttr<DenseI32ArrayAttr>:$branch_weights);
let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);
- let builders = [
- OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
- "ValueRange":$trueOperands, "Block *":$falseDest,
- "ValueRange":$falseOperands), [{
- build($_builder, $_state, condition, trueOperands, falseOperands, trueDest,
+ let builders = [OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
+ "ValueRange":$trueOperands,
+ "Block *":$falseDest,
+ "ValueRange":$falseOperands),
+ [{
+ build($_builder, $_state, condition, trueOperands, falseOperands, /*branch_weights=*/{}, trueDest,
falseDest);
}]>,
- OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
- "Block *":$falseDest, CArg<"ValueRange", "{}">:$falseOperands), [{
+ OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
+ "Block *":$falseDest,
+ CArg<"ValueRange", "{}">:$falseOperands),
+ [{
build($_builder, $_state, condition, trueDest, ValueRange(), falseDest,
falseOperands);
}]>];
@@ -216,7 +220,7 @@ def CondBranchOp : CF_Op<"cond_br",
let hasCanonicalizer = 1;
let assemblyFormat = [{
- $condition `,`
+ $condition (`weights` `(` $branch_weights^ `)` )? `,`
$trueDest (`(` $trueDestOperands^ `:` type($trueDestOperands) `)`)? `,`
$falseDest (`(` $falseDestOperands^ `:` type($falseDestOperands) `)`)?
attr-dict
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index 2824f09dab6ce..138170f8c8762 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -168,42 +168,6 @@ def NonNegFlagInterface : OpInterface<"NonNegFlagInterface"> {
];
}
-def BranchWeightOpInterface : OpInterface<"BranchWeightOpInterface"> {
- let description = [{
- An interface for operations that can carry branch weights metadata. It
- provides setters and getters for the operation's branch weights attribute.
- The default implementation of the interface methods expect the operation to
- have an attribute of type DenseI32ArrayAttr named branch_weights.
- }];
-
- let cppNamespace = "::mlir::LLVM";
-
- let methods = [
- InterfaceMethod<
- /*desc=*/ "Returns the branch weights attribute or nullptr",
- /*returnType=*/ "::mlir::DenseI32ArrayAttr",
- /*methodName=*/ "getBranchWeightsOrNull",
- /*args=*/ (ins),
- /*methodBody=*/ [{}],
- /*defaultImpl=*/ [{
- auto op = cast<ConcreteOp>(this->getOperation());
- return op.getBranchWeightsAttr();
- }]
- >,
- InterfaceMethod<
- /*desc=*/ "Sets the branch weights attribute",
- /*returnType=*/ "void",
- /*methodName=*/ "setBranchWeights",
- /*args=*/ (ins "::mlir::DenseI32ArrayAttr":$attr),
- /*methodBody=*/ [{}],
- /*defaultImpl=*/ [{
- auto op = cast<ConcreteOp>(this->getOperation());
- op.setBranchWeightsAttr(attr);
- }]
- >
- ];
-}
-
def AccessGroupOpInterface : OpInterface<"AccessGroupOpInterface"> {
let description = [{
An interface for memory operations that can carry access groups metadata.
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 68fa620d239b9..939e7a09a73ad 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -660,12 +660,12 @@ def LLVM_FPTruncOp : LLVM_CastOp<"fptrunc", "FPTrunc",
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>>;
// Call-related operations.
-def LLVM_InvokeOp : LLVM_Op<"invoke", [
- AttrSizedOperandSegments,
- DeclareOpInterfaceMethods<BranchOpInterface>,
- DeclareOpInterfaceMethods<CallOpInterface>,
- DeclareOpInterfaceMethods<BranchWeightOpInterface>,
- Terminator]> {
+def LLVM_InvokeOp
+ : LLVM_Op<"invoke", [AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<BranchOpInterface>,
+ DeclareOpInterfaceMethods<CallOpInterface>,
+ DeclareOpInterfaceMethods<WeightedBranchOpInterface>,
+ Terminator]> {
let arguments = (ins
OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$var_callee_type,
OptionalAttr<FlatSymbolRefAttr>:$callee,
@@ -734,12 +734,12 @@ def LLVM_VaArgOp : LLVM_Op<"va_arg"> {
// CallOp
//===----------------------------------------------------------------------===//
-def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
- [AttrSizedOperandSegments,
- DeclareOpInterfaceMethods<FastmathFlagsInterface>,
- DeclareOpInterfaceMethods<CallOpInterface>,
- DeclareOpInterfaceMethods<SymbolUserOpInterface>,
- DeclareOpInterfaceMethods<BranchWeightOpInterface>]> {
+def LLVM_CallOp
+ : LLVM_MemAccessOpBase<
+ "call", [AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<FastmathFlagsInterface>,
+ DeclareOpInterfaceMethods<CallOpInterface>,
+ DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let summary = "Call to an LLVM function.";
let description = [{
In LLVM IR, functions may return either 0 or 1 value. LLVM IR dialect
@@ -788,21 +788,16 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
OptionalAttr<FlatSymbolRefAttr>:$callee,
Variadic<LLVM_Type>:$callee_operands,
DefaultValuedAttr<LLVM_FastmathFlagsAttr, "{}">:$fastmathFlags,
- OptionalAttr<DenseI32ArrayAttr>:$branch_weights,
DefaultValuedAttr<CConv, "CConv::C">:$CConv,
DefaultValuedAttr<TailCallKind, "TailCallKind::None">:$TailCallKind,
OptionalAttr<LLVM_MemoryEffectsAttr>:$memory_effects,
- UnitAttr:$convergent,
- UnitAttr:$no_unwind,
- UnitAttr:$will_return,
+ UnitAttr:$convergent, UnitAttr:$no_unwind, UnitAttr:$will_return,
VariadicOfVariadic<LLVM_Type, "op_bundle_sizes">:$op_bundle_operands,
DenseI32ArrayAttr:$op_bundle_sizes,
OptionalAttr<ArrayAttr>:$op_bundle_tags,
OptionalAttr<DictArrayAttr>:$arg_attrs,
- OptionalAttr<DictArrayAttr>:$res_attrs,
- UnitAttr:$no_inline,
- UnitAttr:$always_inline,
- UnitAttr:$inline_hint);
+ OptionalAttr<DictArrayAttr>:$res_attrs, UnitAttr:$no_inline,
+ UnitAttr:$always_inline, UnitAttr:$inline_hint);
// Append the aliasing related attributes defined in LLVM_MemAccessOpBase.
let arguments = !con(args, aliasAttrs);
let results = (outs Optional<LLVM_Type>:$result);
@@ -1047,11 +1042,12 @@ def LLVM_BrOp : LLVM_TerminatorOp<"br",
LLVM_TerminatorPassthroughOpBuilder
];
}
-def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br",
- [AttrSizedOperandSegments,
- DeclareOpInterfaceMethods<BranchOpInterface>,
- DeclareOpInterfaceMethods<BranchWeightOpInterface>,
- Pure]> {
+def LLVM_CondBrOp
+ : LLVM_TerminatorOp<
+ "cond_br", [AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<BranchOpInterface>,
+ DeclareOpInterfaceMethods<WeightedBranchOpInterface>,
+ Pure]> {
let arguments = (ins I1:$condition,
Variadic<LLVM_Type>:$trueDestOperands,
Variadic<LLVM_Type>:$falseDestOperands,
@@ -1136,11 +1132,12 @@ def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable"> {
}];
}
-def LLVM_SwitchOp : LLVM_TerminatorOp<"switch",
- [AttrSizedOperandSegments,
- DeclareOpInterfaceMethods<BranchOpInterface>,
- DeclareOpInterfaceMethods<BranchWeightOpInterface>,
- Pure]> {
+def LLVM_SwitchOp
+ : LLVM_TerminatorOp<
+ "switch", [AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<BranchOpInterface>,
+ DeclareOpInterfaceMethods<WeightedBranchOpInterface>,
+ Pure]> {
let arguments = (ins
AnySignlessInteger:$value,
Variadic<AnyType>:$defaultOperands,
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index 7f6967f11444f..d63800c12d132 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -142,6 +142,26 @@ LogicalResult verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
const SuccessorOperands &operands);
} // namespace detail
+//===----------------------------------------------------------------------===//
+// WeightedBranchOpInterface
+//===----------------------------------------------------------------------===//
+
+namespace detail {
+/// Verify that the branch weights attached to an operation
+/// implementing WeightedBranchOpInterface are correct.
+LogicalResult verifyBranchWeights(Operation *op);
+} // namespace detail
+
+//===----------------------------------------------------------------------===//
+// WeightedRegionBranchOpInterface
+//===----------------------------------------------------------------------===//
+
+namespace detail {
+/// Verify that the region weights attached to an operation
+/// implementing WeightedRegionBranchOpInterface are correct.
+LogicalResult verifyRegionBranchWeights(Operation *op);
+} // namespace detail
+
//===----------------------------------------------------------------------===//
// RegionBranchOpInterface
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 69bce78e946c8..46ab0b9ebbc6b 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -375,6 +375,118 @@ def SelectLikeOpInterface : OpInterface<"SelectLikeOpInterface"> {
];
}
+//===----------------------------------------------------------------------===//
+// WeightedBranchOpInterface
+//===----------------------------------------------------------------------===//
+
+def WeightedBranchOpInterface : OpInterface<"WeightedBranchOpInterface"> {
+ let description = [{
+ This interface provides weight information for branching terminator
+ operations, i.e. terminator operations with successors.
+
+ This interface provides methods for getting/setting integer non-negative
+ weight of each branch. The probability of executing a branch
+ is computed as the ratio between the branch's weight and the total
+ sum of the weights (which cannot be zero).
+ The weights are optional. If they are provided, then their number
+ must match the number of successors of the operation.
+
+ The default implementations of the methods expect the operation
+ to have an attribute of type DenseI32ArrayAttr named branch_weights.
+ }];
+ let cppNamespace = "::mlir";
+
+ let methods = [InterfaceMethod<
+ /*desc=*/"Returns the branch weights",
+ /*returnType=*/"::llvm::ArrayRef<int32_t>",
+ /*methodName=*/"getWeights",
+ /*args=*/(ins),
+ /*methodBody=*/[{}],
+ /*defaultImpl=*/[{
+ auto op = cast<ConcreteOp>(this->getOperation());
+ if (auto attr = op.getBranchWeightsAttr())
+ return attr.asArrayRef();
+ return {};
+ }]>,
+ InterfaceMethod<
+ /*desc=*/"Sets the branch weights",
+ /*returnType=*/"void",
+ /*methodName=*/"setWeights",
+ /*args=*/(ins "::llvm::ArrayRef<int32_t>":$weights),
+ /*methodBody=*/[{}],
+ /*defaultImpl=*/[{
+ auto op = cast<ConcreteOp>(this->getOperation());
+ op.setBranchWeightsAttr(::mlir::DenseI32ArrayAttr::get(op->getContext(), weights));
+ }]>,
+ ];
+
+ let verify = [{
+ return ::mlir::detail::verifyBranchWeights($_op);
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// WeightedRegionBranchOpInterface
+//===----------------------------------------------------------------------===//
+
+// TODO: the probabilities of entering a particular region seem
+// to correlate with the values returned by
+// RegionBranchOpInterface::getRegionInvocationBounds(), and we should
+// probably verify that the values are consistent. In that case, should
+// WeightedRegionBranchOpInterface extend RegionBranchOpInterface?
+def WeightedRegionBranchOpInterface
+ : OpInterface<"WeightedRegionBranchOpInterface"> {
+ let description = [{
+ This interface provides weight information for region operations
+ that exhibit branching behavior between held regions.
+
+ This interface provides methods for getting/setting integer non-negative
+ weight of each branch. The probability of executing a region is computed
+ as the ratio between the region branch's weight and the total sum
+ of the weights (which cannot be zero).
+ The weights are optional. If they are provided, then their number
+ must match the number of regions held by the operation
+ (including empty regions).
+
+ The weights specify the probability of branching to a particular
+ region when first executing the operation.
+ For example, for loop-like operations with a single region
+ the weight specifies the probability of entering the loop.
+
+ The default implementations of the methods expect the operation
+ to have an attribute of type DenseI32ArrayAttr named region_weights.
+ }];
+ let cppNamespace = "::mlir";
+
+ let methods = [InterfaceMethod<
+ /*desc=*/"Returns the region weights",
+ /*returnType=*/"::llvm::ArrayRef<int32_t>",
+ /*methodName=*/"getWeights",
+ /*args=*/(ins),
+ /*methodBody=*/[{}],
+ /*defaultImpl=*/[{
+ auto op = cast<ConcreteOp>(this->getOperation());
+ if (auto attr = op.getRegionWeightsAttr())
+ return attr.asArrayRef();
+ return {};
+ }]>,
+ InterfaceMethod<
+ /*desc=*/"Sets the region weights",
+ /*returnType=*/"void",
+ /*methodName=*/"setWeights",
+ /*args=*/(ins "::llvm::ArrayRef<int32_t>":$weights),
+ /*methodBody=*/[{}],
+ /*defaultImpl=*/[{
+ auto op = cast<ConcreteOp>(this->getOperation());
+ op.setRegionWeightsAttr(::mlir::DenseI32ArrayAttr::get(op->getContext(), weights));
+ }]>,
+ ];
+
+ let verify = [{
+ return ::mlir::detail::verifyRegionBranchWeights($_op);
+ }];
+}
+
//===----------------------------------------------------------------------===//
// ControlFlow Traits
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index 97ae14aa0d6af..0f136c5c46d79 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -189,7 +189,7 @@ class ModuleTranslation {
llvm::Instruction *inst);
/// Sets LLVM profiling metadata for operations that have branch weights.
- void setBranchWeightsMetadata(BranchWeightOpInterface op);
+ void setBranchWeightsMetadata(WeightedBranchOpInterface op);
/// Sets LLVM loop metadata for branch operations that have a loop annotation
/// attribute.
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index debfd003bd5b5..d31d7d801e149 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -166,10 +166,15 @@ struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
TypeRange(adaptor.getFalseDestOperands()));
if (failed(convertedFalseBlock))
return failure();
- Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
+ auto newOp = rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
op, adaptor.getCondition(), *convertedTrueBlock,
adaptor.getTrueDestOperands(), *convertedFalseBlock,
adaptor.getFalseDestOperands());
+ ArrayRef<int32_t> weights = op.getWeights();
+ if (!weights.empty()) {
+ newOp.setWeights(weights);
+ op.removeBranchWeightsAttr();
+ }
// TODO: We should not just forward all attributes like that. But there are
// existing Flang tests that depend on this behavior.
newOp->setAttrs(op->getAttrDictionary());
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index c7528c970a4ba..a12aef0dfad38 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -589,10 +589,6 @@ LogicalResult SwitchOp::verify() {
static_cast<int64_t>(getCaseDestinations().size())))
return emitOpError("expects number of case values to match number of "
"case destinations");
- if (getBranchWeights() && getBranchWeights()->size() != getNumSuccessors())
- return emitError("expects number of branch weights to match number of "
- "successors: ")
- << getBranchWeights()->size() << " vs " << getNumSuccessors();
if (getCaseValues() &&
getValue().getType() != getCaseValues()->getElementType())
return emitError("expects case value type to match condition value type");
@@ -962,7 +958,6 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
assert(callee && "expected non-null callee in direct call builder");
build(builder, state, results,
/*var_callee_type=*/nullptr, callee, args, /*fastmathFlags=*/nullptr,
- /*branch_weights=*/nullptr,
/*CConv=*/nullptr, /*TailCallKind=*/nullptr,
/*memory_effects=*/nullptr,
/*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
@@ -992,7 +987,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
build(builder, state, getCallOpResultTypes(calleeType),
getCallOpVarCalleeType(calleeType), callee, args,
/*fastmathFlags=*/nullptr,
- /*branch_weights=*/nullptr, /*CConv=*/nullptr,
+ /*CConv=*/nullptr,
/*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
/*convergent=*/nullptr,
/*no_unwind=*/nullptr, /*will_return=*/nullptr,
@@ -1009,7 +1004,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
build(builder, state, getCallOpResultTypes(calleeType),
getCallOpVarCalleeType(calleeType),
/*callee=*/nullptr, args,
- /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
+ /*fastmathFlags=*/nullptr,
/*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
/*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
/*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
@@ -1025,7 +1020,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
auto calleeType = func.getFunctionType();
build(builder, state, getCallOpResultTypes(calleeType),
getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), args,
- /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
+ /*fastmathFlags=*/nullptr,
/*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
/*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
/*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index 2ae334b517a31..3a63db35eec0f 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -9,6 +9,7 @@
#include <utility>
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "llvm/ADT/SmallPtrSet.h"
@@ -80,6 +81,51 @@ detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
return success();
}
+//===----------------------------------------------------------------------===//
+// WeightedBranchOpInterface
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyWeights(Operation *op,
+ llvm::ArrayRef<int32_t> weights,
+ std::size_t expectedWeightsNum,
+ llvm::StringRef weightAnchorName,
+ llvm::StringRef weightRefName) {
+ if (weights.empty())
+ return success();
+
+ if (weights.size() != expectedWeightsNum)
+ return op->emitError() << "expects number of " << weightAnchorName
+ << " weights to match number of " << weightRefName
+ << ": " << weights.size() << " vs "
+ << expectedWeightsNum;
+
+ for (auto [index, weight] : llvm::enumerate(weights))
+ if (weight < 0)
+ return op->emitError() << "weight #" << index << " must be non-negative";
+
+ if (llvm::all_of(weights, [](int32_t value) { return value == 0; }))
+ return op->emitError() << "branch weights cannot all be zero";
+
+ return success();
+}
+
+LogicalResult detail::verifyBranchWeights(Operation *op) {
+ llvm::ArrayRef<int32_t> weights =
+ cast<WeightedBranchOpInterface>(op).getWeights();
+ return verifyWeights(op, weights, op->getNumSuccessors(), "branch",
+ "successors");
+}
+
+//===----------------------------------------------------------------------===//
+// WeightedRegionBranchOpInterface
+//===----------------------------------------------------------------------===//
+
+LogicalResult detail::verifyRegionBranchWeights(Operation *op) {
+ llvm::ArrayRef<int32_t> weights =
+ cast<WeightedRegionBranchOpInterface>(op).getWeights();
+ return verifyWeights(op, weights, op->getNumRegions(), "region", "regions");
+}
+
//===----------------------------------------------------------------------===//
// RegionBranchOpInterface
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
index 1b5ce868b5c77..e67aa892afe09 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
@@ -146,8 +146,15 @@ static LogicalResult setProfilingAttr(OpBuilder &builder, llvm::MDNode *node,
branchWeights.push_back(branchWeight->getZExtValue());
}
- if (auto iface = dyn_cast<BranchWeightOpInterface>(op)) {
- iface.setBranchWeights(builder.getDenseI32ArrayAttr(branchWeights));
+ if (auto iface = dyn_cast<WeightedBranchOpInterface>(op)) {
+ // LLVM allows attaching a single weight to call instructions.
+ // This is used for carrying the execution count information
+ // in PGO modes. MLIR WeightedBranchOpInterface does not allow this,
+ // so we drop the metadata in this case.
+ // LLVM should probably use the VP form of MD_prof metadata
+ // for such cases.
+ if (op->getNumSuccessors() != 0)
+ iface.setWeights(branchWeights);
return success();
}
return failure();
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index e5ca147ea98f8..3eaa24eb5c95b 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1055,7 +1055,7 @@ LogicalResult ModuleTranslation::convertBlockImpl(Block &bb,
return failure();
// Set the branch weight metadata on the translated instruction.
- if (auto iface = dyn_cast<BranchWeightOpInterface>(op))
+ if (auto iface = dyn_cast<WeightedBranchOpInterface>(op))
setBranchWeightsMetadata(iface);
}
@@ -2026,14 +2026,15 @@ void ModuleTranslation::setDereferenceableMetadata(
inst->setMetadata(kindId, derefSizeNode);
}
-void ModuleTranslation::setBranchWeightsMetadata(BranchWeightOpInterface op) {
- DenseI32ArrayAttr weightsAttr = op.getBranchWeightsOrNull();
- if (!weightsAttr)
+void ModuleTranslation::setBranchWeightsMetadata(WeightedBranchOpInterface op) {
+ SmallVector<uint32_t> weights;
+ llvm::transform(op.getWeights(), std::back_inserter(weights),
+ [](int32_t value) { return static_cast<uint32_t>(value); });
+ if (weights.empty())
return;
llvm::Instruction *inst = isa<CallOp>(op) ? lookupCall(op) : lookupBranch(op);
assert(inst && "expected the operation to have a mapping to an instruction");
- SmallVector<uint32_t> weights(weightsAttr.asArrayRef());
inst->setMetadata(
llvm::LLVMContext::MD_prof,
llvm::MDBuilder(getLLVMContext()).createBranchWeights(weights));
diff --git a/mlir/test/Conversion/ControlFlowToLLVM/branch.mlir b/mlir/test/Conversion/ControlFlowToLLVM/branch.mlir
index 9a0f2b7714544..7c78211d59010 100644
--- a/mlir/test/Conversion/ControlFlowToLLVM/branch.mlir
+++ b/mlir/test/Conversion/ControlFlowToLLVM/branch.mlir
@@ -67,3 +67,17 @@ func.func @unreachable_block() {
^bb1(%arg0: index):
cf.br ^bb1(%arg0 : index)
}
+
+// -----
+
+// Test case for cf.cond_br with weights.
+// The branch_weights attribute attached to cf.cond_br is expected to appear
+// as the weights([90, 10]) clause of the resulting llvm.cond_br, in the same
+// order, with the index block arguments printed as i64.
+
+// CHECK-LABEL: func.func @cf_cond_br_with_weights(
+func.func @cf_cond_br_with_weights(%cond: i1, %a: index, %b: index) -> index {
+// CHECK: llvm.cond_br %{{.*}} weights([90, 10]), ^bb1(%{{.*}} : i64), ^bb2(%{{.*}} : i64)
+ cf.cond_br %cond, ^bb1(%a : index), ^bb2(%b : index) {branch_weights = array<i32: 90, 10>}
+^bb1(%arg1: index):
+ return %arg1 : index
+^bb2(%arg2: index):
+ return %arg2 : index
+}
diff --git a/mlir/test/Dialect/ControlFlow/invalid.mlir b/mlir/test/Dialect/ControlFlow/invalid.mlir
index b51d8095c9974..1b8de22a9ff9f 100644
--- a/mlir/test/Dialect/ControlFlow/invalid.mlir
+++ b/mlir/test/Dialect/ControlFlow/invalid.mlir
@@ -67,3 +67,39 @@ func.func @switch_missing_default(%flag : i32, %caseOperand : i32) {
^bb3(%bb3arg : i32):
return
}
+
+// -----
+
+// cf.cond_br has two successors, so attaching a single weight must be
+// rejected by the weights verifier.
+// CHECK-LABEL: func @wrong_weights_number
+func.func @wrong_weights_number(%cond: i1) {
+  // expected-error@+1 {{expects number of branch weights to match number of successors: 1 vs 2}}
+  cf.cond_br %cond weights([100]), ^bb1, ^bb2
+  ^bb1:
+    return
+  ^bb2:
+    return
+}
+
+// -----
+
+// A negative weight must be rejected; the diagnostic names the index of the
+// first offending weight.
+// CHECK-LABEL: func @negative_weight
+func.func @negative_weight(%cond: i1) {
+  // expected-error@+1 {{weight #0 must be non-negative}}
+  cf.cond_br %cond weights([-1, 101]), ^bb1, ^bb2
+  ^bb1:
+    return
+  ^bb2:
+    return
+}
+
+// -----
+
+// Weights that are all zero must be rejected even though their number
+// matches the number of successors.
+// CHECK-LABEL: func @zero_weights
+func.func @zero_weights(%cond: i1) {
+  // expected-error@+1 {{branch weights cannot all be zero}}
+  cf.cond_br %cond weights([0, 0]), ^bb1, ^bb2
+  ^bb1:
+    return
+  ^bb2:
+    return
+}
diff --git a/mlir/test/Dialect/ControlFlow/ops.mlir b/mlir/test/Dialect/ControlFlow/ops.mlir
index c9317c7613972..160534240e0fa 100644
--- a/mlir/test/Dialect/ControlFlow/ops.mlir
+++ b/mlir/test/Dialect/ControlFlow/ops.mlir
@@ -51,3 +51,13 @@ func.func @switch_result_number(%arg0: i32) {
^bb2:
return
}
+
+// The weights([60, 40]) clause written on cf.cond_br in the custom assembly
+// form is expected to be present, unchanged, in the printed output.
+// CHECK-LABEL: func @cond_weights
+func.func @cond_weights(%cond: i1) {
+// CHECK: cf.cond_br %{{.*}} weights([60, 40]), ^{{.*}}, ^{{.*}}
+ cf.cond_br %cond weights([60, 40]), ^bb1, ^bb2
+ ^bb1:
+ return
+ ^bb2:
+ return
+}
diff --git a/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll b/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll
index cc3b47a54dfe9..c623df0b605b2 100644
--- a/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll
+++ b/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll
@@ -36,14 +36,17 @@ bbd:
; // -----
+; Verify that a single weight attached to a call is not translated.
+; The MLIR WeightedBranchOpInterface does not support this case.
+
; CHECK: llvm.func @fn()
-declare void @fn()
+declare i32 @fn()
; CHECK-LABEL: @call_branch_weights
-define void @call_branch_weights() {
- ; CHECK: llvm.call @fn() {branch_weights = array<i32: 42>}
- call void @fn(), !prof !0
- ret void
+define i32 @call_branch_weights() {
+ ; CHECK: llvm.call @fn() : () -> i32
+ %1 = call i32 @fn(), !prof !0
+ ret i32 %1
}
!0 = !{!"branch_weights", i32 42}
diff --git a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
index 24a7b42557278..a8ef401fff27e 100644
--- a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
@@ -448,3 +448,19 @@ llvm.mlir.global external constant @const() {addr_space = 0 : i32, dso_local} :
}
llvm.func extern_weak @extern_func()
+
+// -----
+
+// llvm.invoke has two successors (the normal destination and the unwind
+// destination), so a single branch weight is expected to be diagnosed as a
+// mismatch between the number of weights and the number of successors.
+llvm.func @invoke_branch_weights_callee()
+llvm.func @__gxx_personality_v0(...) -> i32
+
+llvm.func @invoke_branch_weights() -> i32 attributes {personality = @__gxx_personality_v0} {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // expected-error @below{{expects number of branch weights to match number of successors: 1 vs 2}}
+ llvm.invoke @invoke_branch_weights_callee() to ^bb2 unwind ^bb1 {branch_weights = array<i32 : 42>} : () -> ()
+^bb1: // pred: ^bb0
+ %1 = llvm.landingpad cleanup : !llvm.struct<(ptr, i32)>
+ llvm.br ^bb2
+^bb2: // 2 preds: ^bb0, ^bb1
+ llvm.return %0 : i32
+}
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 7742259e7a478..fc1993b50ba2d 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -1906,32 +1906,6 @@ llvm.func @cond_br_weights(%cond : i1, %arg0 : i32, %arg1 : i32) -> i32 {
// -----
-llvm.func @fn()
-
-// CHECK-LABEL: @call_branch_weights
-llvm.func @call_branch_weights() {
- // CHECK: !prof ![[NODE:[0-9]+]]
- llvm.call @fn() {branch_weights = array<i32 : 42>} : () -> ()
- llvm.return
-}
-
-// CHECK: ![[NODE]] = !{!"branch_weights", i32 42}
-
-// -----
-
-llvm.func @fn() -> i32
-
-// CHECK-LABEL: @call_branch_weights
-llvm.func @call_branch_weights() {
- // CHECK: !prof ![[NODE:[0-9]+]]
- %res = llvm.call @fn() {branch_weights = array<i32 : 42>} : () -> i32
- llvm.return
-}
-
-// CHECK: ![[NODE]] = !{!"branch_weights", i32 42}
-
-// -----
-
llvm.func @foo()
llvm.func @__gxx_personality_v0(...) -> i32
More information about the flang-commits
mailing list