[llvm] b3cb950 - [PGO]Implement metadata combine for 'branch_weights' of direct

Mingming Liu via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 27 13:05:27 PDT 2023


Author: Mingming Liu
Date: 2023-04-27T13:04:17-07:00
New Revision: b3cb950cf3d162459a31a9ab533af0f403f344eb

URL: https://github.com/llvm/llvm-project/commit/b3cb950cf3d162459a31a9ab533af0f403f344eb
DIFF: https://github.com/llvm/llvm-project/commit/b3cb950cf3d162459a31a9ab533af0f403f344eb.diff

LOG: [PGO]Implement metadata combine for 'branch_weights' of direct
callsites when none of the instructions folds the rest away.

- Merge cases are added for simplify-cfg {sink,hoist}, based on https://gcc.godbolt.org/z/avGvc38W7 and https://gcc.godbolt.org/z/dbWbjGhaE
- When one instruction folds the others in, do not update branch_weights
  with sum (see test/Transforms/GVN/calls-readonly.ll)

Differential Revision: https://reviews.llvm.org/D148877

Added: 
    llvm/test/Transforms/SimplifyCFG/merge-direct-call-branch-weights-in-hoist.ll
    llvm/test/Transforms/SimplifyCFG/merge-direct-call-branch-weights-in-sink.ll

Modified: 
    llvm/include/llvm/IR/Metadata.h
    llvm/lib/IR/Metadata.cpp
    llvm/lib/Transforms/Utils/Local.cpp
    llvm/test/Transforms/GVN/calls-readonly.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/IR/Metadata.h b/llvm/include/llvm/IR/Metadata.h
index 262a148039a94..1a584f2ba599b 100644
--- a/llvm/include/llvm/IR/Metadata.h
+++ b/llvm/include/llvm/IR/Metadata.h
@@ -1274,6 +1274,11 @@ class MDNode : public Metadata {
   template <class NodeTy>
   static void dispatchResetHash(NodeTy *, std::false_type) {}
 
+  /// Merge branch weights from two direct callsites.
+  static MDNode *mergeDirectCallProfMetadata(MDNode *A, MDNode *B,
+                                             const Instruction *AInstr,
+                                             const Instruction *BInstr);
+
 public:
   using op_iterator = const MDOperand *;
   using op_range = iterator_range<op_iterator>;
@@ -1319,6 +1324,11 @@ class MDNode : public Metadata {
   static MDNode *getMostGenericRange(MDNode *A, MDNode *B);
   static MDNode *getMostGenericAliasScope(MDNode *A, MDNode *B);
   static MDNode *getMostGenericAlignmentOrDereferenceable(MDNode *A, MDNode *B);
+  /// Merge !prof metadata from two instructions.
+  /// Currently only implemented with direct callsites with branch weights.
+  static MDNode *getMergedProfMetadata(MDNode *A, MDNode *B,
+                                       const Instruction *AInstr,
+                                       const Instruction *BInstr);
 };
 
 /// Tuple of metadata.

diff  --git a/llvm/lib/IR/Metadata.cpp b/llvm/lib/IR/Metadata.cpp
index cfcfcd762fdc3..6ffeec1f21d33 100644
--- a/llvm/lib/IR/Metadata.cpp
+++ b/llvm/lib/IR/Metadata.cpp
@@ -1072,6 +1072,70 @@ MDNode *MDNode::getMostGenericFPMath(MDNode *A, MDNode *B) {
   return B;
 }
 
+// Call instructions with branch weights are only used in SamplePGO, as
+// documented in
+// https://llvm.org/docs/BranchWeightMetadata.html#callinst.
+MDNode *MDNode::mergeDirectCallProfMetadata(MDNode *A, MDNode *B,
+                                            const Instruction *AInstr,
+                                            const Instruction *BInstr) {
+  assert(A && B && AInstr && BInstr && "Caller should guarantee");
+  auto &Ctx = AInstr->getContext();
+  MDBuilder MDHelper(Ctx);
+
+  // LLVM IR verifier verifies !prof metadata has at least 2 operands.
+  assert(A->getNumOperands() >= 2 && B->getNumOperands() >= 2 &&
+         "!prof annotations should have no less than 2 operands");
+  MDString *AMDS = dyn_cast<MDString>(A->getOperand(0));
+  MDString *BMDS = dyn_cast<MDString>(B->getOperand(0));
+  // The LLVM IR verifier ensures the first operand is an MDString.
+  assert(AMDS != nullptr && BMDS != nullptr &&
+         "first operand should be a non-null MDString");
+  StringRef AProfName = AMDS->getString();
+  StringRef BProfName = BMDS->getString();
+  if (AProfName.equals("branch_weights") &&
+      BProfName.equals("branch_weights")) {
+    ConstantInt *AInstrWeight =
+        mdconst::dyn_extract<ConstantInt>(A->getOperand(1));
+    ConstantInt *BInstrWeight =
+        mdconst::dyn_extract<ConstantInt>(B->getOperand(1));
+    assert(AInstrWeight && BInstrWeight && "verified by LLVM verifier");
+    return MDNode::get(Ctx,
+                       {MDHelper.createString("branch_weights"),
+                        MDHelper.createConstant(ConstantInt::get(
+                            Type::getInt64Ty(Ctx),
+                            SaturatingAdd(AInstrWeight->getZExtValue(),
+                                          BInstrWeight->getZExtValue())))});
+  }
+  return nullptr;
+}
+
+// Pass in both instructions and nodes. Instruction information (e.g.,
+// instruction type) helps interpret profiles and make implementation clearer.
+MDNode *MDNode::getMergedProfMetadata(MDNode *A, MDNode *B,
+                                      const Instruction *AInstr,
+                                      const Instruction *BInstr) {
+  if (!(A && B)) {
+    return A ? A : B;
+  }
+
+  assert(AInstr->getMetadata(LLVMContext::MD_prof) == A &&
+         "Caller should guarantee");
+  assert(BInstr->getMetadata(LLVMContext::MD_prof) == B &&
+         "Caller should guarantee");
+
+  const CallInst *ACall = dyn_cast<CallInst>(AInstr);
+  const CallInst *BCall = dyn_cast<CallInst>(BInstr);
+
+  // Both ACall and BCall are direct callsites.
+  if (ACall && BCall && ACall->getCalledFunction() &&
+      BCall->getCalledFunction())
+    return mergeDirectCallProfMetadata(A, B, AInstr, BInstr);
+
+  // The rest of the cases are not implemented but could be added
+  // when there are use cases.
+  return nullptr;
+}
+
 static bool isContiguous(const ConstantRange &A, const ConstantRange &B) {
   return A.getUpper() == B.getLower() || A.getLower() == B.getUpper();
 }

diff  --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp
index 53d1f8b62d1b8..0f5d2ce841f19 100644
--- a/llvm/lib/Transforms/Utils/Local.cpp
+++ b/llvm/lib/Transforms/Utils/Local.cpp
@@ -2709,6 +2709,10 @@ void llvm::combineMetadata(Instruction *K, const Instruction *J,
         // Preserve !nontemporal if it is present on both instructions.
         K->setMetadata(Kind, JMD);
         break;
+      case LLVMContext::MD_prof:
+        if (DoesKMove)
+          K->setMetadata(Kind, MDNode::getMergedProfMetadata(KMD, JMD, K, J));
+        break;
     }
   }
   // Set !invariant.group from J if J has it. If both instructions have it
@@ -2737,6 +2741,7 @@ void llvm::combineMetadataForCSE(Instruction *K, const Instruction *J,
                          LLVMContext::MD_dereferenceable_or_null,
                          LLVMContext::MD_access_group,
                          LLVMContext::MD_preserve_access_index,
+                         LLVMContext::MD_prof,
                          LLVMContext::MD_nontemporal,
                          LLVMContext::MD_noundef};
   combineMetadata(K, J, KnownIDs, KDominatesJ);

diff  --git a/llvm/test/Transforms/GVN/calls-readonly.ll b/llvm/test/Transforms/GVN/calls-readonly.ll
index 5c24740c881b4..b4855e41a64f5 100644
--- a/llvm/test/Transforms/GVN/calls-readonly.ll
+++ b/llvm/test/Transforms/GVN/calls-readonly.ll
@@ -6,7 +6,7 @@ target triple = "i386-apple-darwin7"
 
 define ptr @test(ptr %P, ptr %Q, i32 %x, i32 %y) nounwind readonly {
 entry:
-  %0 = tail call i32 @strlen(ptr %P)              ; <i32> [#uses=2]
+  %0 = tail call i32 @strlen(ptr %P), !prof !0    ; <i32> [#uses=2]
   %1 = icmp eq i32 %0, 0                          ; <i1> [#uses=1]
   br i1 %1, label %bb, label %bb1
 
@@ -17,7 +17,7 @@ bb:                                               ; preds = %entry
 bb1:                                              ; preds = %bb, %entry
   %x_addr.0 = phi i32 [ %2, %bb ], [ %x, %entry ] ; <i32> [#uses=1]
   %3 = tail call ptr @strchr(ptr %Q, i32 97)      ; <ptr> [#uses=1]
-  %4 = tail call i32 @strlen(ptr %P)              ; <i32> [#uses=1]
+  %4 = tail call i32 @strlen(ptr %P) , !prof !1   ; <i32> [#uses=1]
   %5 = add i32 %x_addr.0, %0                      ; <i32> [#uses=1]
   %.sum = sub i32 %5, %4                          ; <i32> [#uses=1]
   %6 = getelementptr i8, ptr %3, i32 %.sum            ; <ptr> [#uses=1]
@@ -26,7 +26,7 @@ bb1:                                              ; preds = %bb, %entry
 
 ; CHECK: define ptr @test(ptr %P, ptr %Q, i32 %x, i32 %y) #0 {
 ; CHECK: entry:
-; CHECK-NEXT:   %0 = tail call i32 @strlen(ptr %P)
+; CHECK-NEXT:   %0 = tail call i32 @strlen(ptr %P), !prof !0
 ; CHECK-NEXT:   %1 = icmp eq i32 %0, 0
 ; CHECK-NEXT:   br i1 %1, label %bb, label %bb1
 ; CHECK: bb:
@@ -43,3 +43,6 @@ bb1:                                              ; preds = %bb, %entry
 declare i32 @strlen(ptr) nounwind readonly
 
 declare ptr @strchr(ptr, i32) nounwind readonly
+
+!0 = !{!"branch_weights", i32 95}
+!1 = !{!"branch_weights", i32 95}

diff  --git a/llvm/test/Transforms/SimplifyCFG/merge-direct-call-branch-weights-in-hoist.ll b/llvm/test/Transforms/SimplifyCFG/merge-direct-call-branch-weights-in-hoist.ll
new file mode 100644
index 0000000000000..e57033b345384
--- /dev/null
+++ b/llvm/test/Transforms/SimplifyCFG/merge-direct-call-branch-weights-in-hoist.ll
@@ -0,0 +1,62 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals --version 2
+; RUN: opt < %s -passes='simplifycfg<no-sink-common-insts;hoist-common-insts>' -simplifycfg-require-and-preserve-domtree=1 -S | FileCheck %s --check-prefix=HOIST
+
+; Test case based on C++ code with manually annotated !prof metadata.
+; This is to test that when calls to 'func1' from 'if.then' block
+; and 'if.else' block are hoisted, the branch_weights are merged and
+; attached to the merged call rather than dropped.
+;
+; int func1(int a, int b) ;
+; int func2(int a, int b) ;
+
+; int func(int a, int b, bool c) {
+;    int sum= 0;
+;    if(c) {
+;        sum += func1(a, b);
+;    } else {
+;        sum += func1(a, b);
+;        sum -= func2(a, b);
+;    }
+;    return sum;
+; }
+define i32 @_Z4funciib(i32 %a, i32 %b, i1 %c) {
+; HOIST-LABEL: define i32 @_Z4funciib
+; HOIST-SAME: (i32 [[A:%.*]], i32 [[B:%.*]], i1 [[C:%.*]]) {
+; HOIST-NEXT:  entry:
+; HOIST-NEXT:    [[CALL:%.*]] = tail call i32 @_Z5func1ii(i32 [[A]], i32 [[B]]), !prof [[PROF0:![0-9]+]]
+; HOIST-NEXT:    br i1 [[C]], label [[IF_END:%.*]], label [[IF_ELSE:%.*]]
+; HOIST:       if.else:
+; HOIST-NEXT:    [[CALL3:%.*]] = tail call i32 @_Z5func2ii(i32 [[A]], i32 [[B]])
+; HOIST-NEXT:    [[SUB:%.*]] = sub i32 [[CALL]], [[CALL3]]
+; HOIST-NEXT:    br label [[IF_END]]
+; HOIST:       if.end:
+; HOIST-NEXT:    [[SUM_0:%.*]] = phi i32 [ [[SUB]], [[IF_ELSE]] ], [ [[CALL]], [[ENTRY:%.*]] ]
+; HOIST-NEXT:    ret i32 [[SUM_0]]
+;
+entry:
+  br i1 %c, label %if.then, label %if.else
+
+if.then:                                          ; preds = %entry
+  %call = tail call i32 @_Z5func1ii(i32 %a, i32 %b), !prof !0
+  br label %if.end
+
+if.else:                                          ; preds = %entry
+  %call1 = tail call i32 @_Z5func1ii(i32 %a, i32 %b), !prof !1
+  %call3 = tail call i32 @_Z5func2ii(i32 %a, i32 %b)
+  %sub = sub i32 %call1, %call3
+  br label %if.end
+
+if.end:                                           ; preds = %if.else, %if.then
+  %sum.0 = phi i32 [ %call, %if.then ], [ %sub, %if.else ]
+  ret i32 %sum.0
+}
+
+declare i32 @_Z5func1ii(i32, i32)
+
+declare i32 @_Z5func2ii(i32, i32)
+
+!0 = !{!"branch_weights", i32 10}
+!1 = !{!"branch_weights", i32 90}
+;.
+; HOIST: [[PROF0]] = !{!"branch_weights", i64 100}
+;.

diff  --git a/llvm/test/Transforms/SimplifyCFG/merge-direct-call-branch-weights-in-sink.ll b/llvm/test/Transforms/SimplifyCFG/merge-direct-call-branch-weights-in-sink.ll
new file mode 100644
index 0000000000000..3206746b13a33
--- /dev/null
+++ b/llvm/test/Transforms/SimplifyCFG/merge-direct-call-branch-weights-in-sink.ll
@@ -0,0 +1,63 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals --version 2
+; RUN: opt < %s -passes='simplifycfg<sink-common-insts;no-hoist-common-insts>' -simplifycfg-require-and-preserve-domtree=1 -S | FileCheck %s --check-prefix=SINK
+
+
+; Test case based on the following C++ code with manually annotated !prof metadata.
+; This is to test that when calls to 'func1' from 'if.then' and 'if.else' are
+; sunk, the branch weights are merged and attached to the sunk call.
+;
+; int func1(int a, int b) ;
+; int func2(int a, int b) ;
+
+; int func(int a, int b, bool c) {
+;    int sum = 0;
+;    if (c) {
+;        sum += func1(a,b);
+;    } else {
+;        b -= func2(a,b);
+;        sum += func1(a,b);
+;    }
+;    return sum;
+; }
+
+define i32 @_Z4funciib(i32 %a, i32 %b, i1 %c) {
+; SINK-LABEL: define i32 @_Z4funciib
+; SINK-SAME: (i32 [[A:%.*]], i32 [[B:%.*]], i1 [[C:%.*]]) {
+; SINK-NEXT:  entry:
+; SINK-NEXT:    br i1 [[C]], label [[IF_END:%.*]], label [[IF_ELSE:%.*]]
+; SINK:       if.else:
+; SINK-NEXT:    [[CALL1:%.*]] = tail call i32 @_Z5func2ii(i32 [[A]], i32 [[B]])
+; SINK-NEXT:    [[SUB:%.*]] = sub i32 [[B]], [[CALL1]]
+; SINK-NEXT:    br label [[IF_END]]
+; SINK:       if.end:
+; SINK-NEXT:    [[SUB_SINK:%.*]] = phi i32 [ [[SUB]], [[IF_ELSE]] ], [ [[B]], [[ENTRY:%.*]] ]
+; SINK-NEXT:    [[CALL2:%.*]] = tail call i32 @_Z5func1ii(i32 [[A]], i32 [[SUB_SINK]]), !prof [[PROF0:![0-9]+]]
+; SINK-NEXT:    ret i32 [[CALL2]]
+;
+entry:
+  br i1 %c, label %if.then, label %if.else
+
+if.then:                                          ; preds = %entry
+  %call = tail call i32 @_Z5func1ii(i32 %a, i32 %b), !prof !0
+  br label %if.end
+
+if.else:                                          ; preds = %entry
+  %call1 = tail call i32 @_Z5func2ii(i32 %a, i32 %b)
+  %sub = sub i32 %b, %call1
+  %call2 = tail call i32 @_Z5func1ii(i32 %a, i32 %sub), !prof !1
+  br label %if.end
+
+if.end:                                           ; preds = %if.else, %if.then
+  %sum.0 = phi i32 [ %call, %if.then ], [ %call2, %if.else ]
+  ret i32 %sum.0
+}
+
+declare i32 @_Z5func1ii(i32, i32)
+
+declare i32 @_Z5func2ii(i32, i32)
+
+!0 = !{!"branch_weights", i32 10}
+!1 = !{!"branch_weights", i32 90}
+;.
+; SINK: [[PROF0]] = !{!"branch_weights", i64 100}
+;.


        


More information about the llvm-commits mailing list