[llvm] [ScalarizeMaskedMemIntr][ProfCheck] Correctly annotate branch weights (PR #181568)

Aiden Grossman via llvm-commits llvm-commits at lists.llvm.org
Sun Feb 15 12:13:38 PST 2026


https://github.com/boomanaiden154 created https://github.com/llvm/llvm-project/pull/181568

There are two cases in ScalarizeMaskedMemIntr where conditional branches are created using conditions derived from the mask. Given these are synthesized and we do not have VP metadata for them, we need to mark their branch weights as explicitly unknown.

>From c639d3b8474f7481d90cfed8b98eb813c82932ee Mon Sep 17 00:00:00 2001
From: Aiden Grossman <aidengrossman at google.com>
Date: Sun, 15 Feb 2026 20:09:28 +0000
Subject: [PATCH] [ScalarizeMaskedMemIntr][ProfCheck] Correctly annotate branch
 weights

There are two cases in ScalarizeMaskedMemIntr where conditional branches
are created using conditions derived from the mask. Given these are
synthesized and we do not have VP metadata for them, we need to mark
their branch weights as explicitly unknown.
---
 .../Scalar/ScalarizeMaskedMemIntrin.cpp         | 13 +++++++++++--
 llvm/test/CodeGen/X86/masked_gather_scatter.ll  | 17 ++++++++++++-----
 2 files changed, 23 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
index b7b08ae61ec52..da9ceb4f440e5 100644
--- a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
+++ b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
@@ -28,6 +28,7 @@
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Value.h"
 #include "llvm/InitializePasses.h"
@@ -540,9 +541,13 @@ static void scalarizeMaskedGather(const DataLayout &DL,
     //  %Elt = load i32* %EltAddr
     //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
     //
+    // We mark the branch weights as explicitly unknown given they would only
+    // be derivable from the mask which we do not have VP information for.
     Instruction *ThenTerm =
         SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
-                                  /*BranchWeights=*/nullptr, DTU);
+                                  getExplicitlyUnknownBranchWeightsIfProfiled(
+                                      *CI->getFunction(), DEBUG_TYPE),
+                                  DTU);
 
     BasicBlock *CondBlock = ThenTerm->getParent();
     CondBlock->setName("cond.load");
@@ -670,9 +675,13 @@ static void scalarizeMaskedScatter(const DataLayout &DL,
     //  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
     //  %store i32 %Elt1, i32* %Ptr1
     //
+    // We mark the branch weights as explicitly unknown given they would only
+    // be derivable from the mask which we do not have VP information for.
     Instruction *ThenTerm =
         SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
-                                  /*BranchWeights=*/nullptr, DTU);
+                                  getExplicitlyUnknownBranchWeightsIfProfiled(
+                                      *CI->getFunction(), DEBUG_TYPE),
+                                  DTU);
 
     BasicBlock *CondBlock = ThenTerm->getParent();
     CondBlock->setName("cond.store");
diff --git a/llvm/test/CodeGen/X86/masked_gather_scatter.ll b/llvm/test/CodeGen/X86/masked_gather_scatter.ll
index cf49ac1e4886b..2e57a5140694d 100644
--- a/llvm/test/CodeGen/X86/masked_gather_scatter.ll
+++ b/llvm/test/CodeGen/X86/masked_gather_scatter.ll
@@ -50,6 +50,7 @@ declare <8 x i32> @llvm.masked.gather.v8i32.v8p0(<8 x ptr> , i32, <8 x i1> , <8
 
 
 ; SCALAR-LABEL: test2
+; SCALAR:      !prof !0
 ; SCALAR:      extractelement <16 x ptr>
 ; SCALAR-NEXT: load float
 ; SCALAR-NEXT: insertelement <16 x float>
@@ -58,9 +59,9 @@ declare <8 x i32> @llvm.masked.gather.v8i32.v8p0(<8 x ptr> , i32, <8 x i1> , <8
 ; SCALAR-NEXT:  %res.phi.else = phi
 ; SCALAR-NEXT:  and i16 %{{.*}}, 2
 ; SCALAR-NEXT:  icmp ne i16 %{{.*}}, 0
-; SCALAR-NEXT:  br i1 %{{.*}}, label %cond.load1, label %else2
+; SCALAR-NEXT:  br i1 %{{.*}}, label %cond.load1, label %else2, !prof !1
 
-define <16 x float> @test2(ptr %base, <16 x i32> %ind, i16 %mask) {
+define <16 x float> @test2(ptr %base, <16 x i32> %ind, i16 %mask) !prof !0 {
 ; X64-LABEL: test2:
 ; X64:       # %bb.0:
 ; X64-NEXT:    kmovw %esi, %k1
@@ -151,9 +152,10 @@ define <16 x i32> @test4(ptr %base, <16 x i32> %ind, i16 %mask) {
 
 
 ; SCALAR-LABEL: test5
+; SCALAR:        !prof !0
 ; SCALAR:        and i16 %scalar_mask, 1
 ; SCALAR-NEXT:   icmp ne i16 %{{.*}}, 0
-; SCALAR-NEXT:   br i1 %{{.*}}, label %cond.store, label %else
+; SCALAR-NEXT:   br i1 %{{.*}}, label %cond.store, label %else, !prof !1
 ; SCALAR: cond.store:
 ; SCALAR-NEXT:  %Elt0 = extractelement <16 x i32> %val, i64 0
 ; SCALAR-NEXT:  %Ptr0 = extractelement <16 x ptr> %gep.random, i64 0
@@ -162,9 +164,9 @@ define <16 x i32> @test4(ptr %base, <16 x i32> %ind, i16 %mask) {
 ; SCALAR: else:
 ; SCALAR-NEXT:   and i16 %scalar_mask, 2
 ; SCALAR-NEXT:   icmp ne i16 %{{.*}}, 0
-; SCALAR-NEXT:  br i1 %{{.*}}, label %cond.store1, label %else2
+; SCALAR-NEXT:  br i1 %{{.*}}, label %cond.store1, label %else2, !prof !1
 
-define void @test5(ptr %base, <16 x i32> %ind, i16 %mask, <16 x i32>%val) {
+define void @test5(ptr %base, <16 x i32> %ind, i16 %mask, <16 x i32>%val) !prof !0 {
 ; X64-LABEL: test5:
 ; X64:       # %bb.0:
 ; X64-NEXT:    kmovw %esi, %k1
@@ -5565,3 +5567,8 @@ define {<16 x float>, <16 x float>} @test_gather_structpt2_16f32_mask_index_pair
   %pair2 = insertvalue {<16 x float>, <16 x float>} %pair1, <16 x float> %res, 1
   ret {<16 x float>, <16 x float>} %pair2
 }
+
+!0 = !{!"function_entry_count", i64 1000}
+
+; SCALAR:      !0 = !{!"function_entry_count", i64 1000}
+; SCALAR-NEXT: !1 = !{!"unknown", !"scalarize-masked-mem-intrin"}



More information about the llvm-commits mailing list