[llvm] [ScalarizeMaskedMemIntr] Don't use a scalar mask on GPUs (PR #104842)

Krzysztof Drewniak via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 22 10:58:44 PDT 2024


https://github.com/krzysz00 updated https://github.com/llvm/llvm-project/pull/104842

>From be404168b927119c303c2ab814ca424b0e44634c Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Fri, 16 Aug 2024 22:42:17 +0000
Subject: [PATCH 1/2] [ScalarizeMaskedMemIntr] Don't use a scalar mask on GPUs

ScalarizeMaskedMemIntrin contains an optimization where the <N x i1>
mask is bitcast to an iN and bit tests against powers of two are then
used to determine whether each lane should be loaded/stored/etc.

However, on machines with branch divergence (mainly GPUs), this is a
mis-optimization, since each i1 in the mask will be stored in a
condition register - that is, each of these "i1"s is likely to be a
word or two wide, making these bit operations counterproductive.

Therefore, amend this pass to skip the optimization on targets where
it pessimizes the generated code.

Pre-commit tests #104645
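
For illustration, a hand-written sketch (not taken from the patch or its
test files; function and value names are made up) of the two predicate
forms for lane 0 of a <2 x i1> mask. The bit-test form is what the pass
keeps emitting for non-divergent targets such as X86; the extractelement
form is what it emits once the target reports branch divergence:

  ; Bit-test form: the whole mask is collapsed into one scalar integer,
  ; and each lane is tested with a power-of-two AND.
  define i1 @lane0_bittest(<2 x i1> %mask) {
    %scalar_mask = bitcast <2 x i1> %mask to i2
    %bit0 = and i2 %scalar_mask, 1
    %pred = icmp ne i2 %bit0, 0
    ret i1 %pred
  }

  ; Extractelement form: each lane's i1 is read directly, so no scalar
  ; mask has to be materialized from the per-lane condition registers.
  define i1 @lane0_extract(<2 x i1> %mask) {
    %pred = extractelement <2 x i1> %mask, i64 0
    ret i1 %pred
  }

On a divergent target the extractelement form keeps each lane's predicate
in its natural per-lane register, whereas building the i2 scalar mask would
first have to gather those lanes together.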
---
 .../Scalar/ScalarizeMaskedMemIntrin.cpp       | 120 +++++++++++-------
 .../AMDGPU/expamd-masked-load.ll              |  33 ++---
 .../AMDGPU/expand-masked-gather.ll            |  11 +-
 .../AMDGPU/expand-masked-scatter.ll           |  11 +-
 .../AMDGPU/expand-masked-store.ll             |  33 ++---
 5 files changed, 105 insertions(+), 103 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
index 9cb7bad94c20bc..ffa415ebd37048 100644
--- a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
+++ b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
@@ -141,7 +141,8 @@ static unsigned adjustForEndian(const DataLayout &DL, unsigned VectorWidth,
 //  %10 = extractelement <16 x i1> %mask, i32 2
 //  br i1 %10, label %cond.load4, label %else5
 //
-static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,
+static void scalarizeMaskedLoad(const DataLayout &DL,
+                                const TargetTransformInfo &TTI, CallInst *CI,
                                 DomTreeUpdater *DTU, bool &ModifiedDT) {
   Value *Ptr = CI->getArgOperand(0);
   Value *Alignment = CI->getArgOperand(1);
@@ -221,11 +222,10 @@ static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,
     return;
   }
   // If the mask is not v1i1, use scalar bit test operations. This generates
-  // better results on X86 at least.
-  // Note: this produces worse code on AMDGPU, where the "i1" is implicitly SIMD
-  // - what's a good way to detect this?
-  Value *SclrMask;
-  if (VectorWidth != 1) {
+  // better results on X86 at least. However, don't do this on GPUs and other
+  // machines with divergence, as there each i1 needs a vector register.
+  Value *SclrMask = nullptr;
+  if (!TTI.hasBranchDivergence() && VectorWidth != 1) {
     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
   }
@@ -233,13 +233,15 @@ static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,
   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
     // Fill the "else" block, created in the previous iteration
     //
-    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
-    //  %mask_1 = and i16 %scalar_mask, i32 1 << Idx
-    //  %cond = icmp ne i16 %mask_1, 0
-    //  br i1 %mask_1, label %cond.load, label %else
+    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else,
+    //  %else ] %mask_1 = and i16 %scalar_mask, i32 1 << Idx %cond = icmp ne i16
+    //  %mask_1, 0 br i1 %mask_1, label %cond.load, label %else
     //
+    // On GPUs, use
+    //  %cond = extractelement %mask, Idx
+    // instead
     Value *Predicate;
-    if (VectorWidth != 1) {
+    if (SclrMask != nullptr) {
       Value *Mask = Builder.getInt(APInt::getOneBitSet(
           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
@@ -312,7 +314,8 @@ static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,
 //   store i32 %6, i32* %7
 //   br label %else2
 //   . . .
-static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
+static void scalarizeMaskedStore(const DataLayout &DL,
+                                 const TargetTransformInfo &TTI, CallInst *CI,
                                  DomTreeUpdater *DTU, bool &ModifiedDT) {
   Value *Src = CI->getArgOperand(0);
   Value *Ptr = CI->getArgOperand(1);
@@ -378,10 +381,10 @@ static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
   }
 
   // If the mask is not v1i1, use scalar bit test operations. This generates
-  // better results on X86 at least.
-
-  Value *SclrMask;
-  if (VectorWidth != 1) {
+  // better results on X86 at least. However, don't do this on GPUs or other
+  // machines with branch divergence, as there each i1 takes up a register.
+  Value *SclrMask = nullptr;
+  if (!TTI.hasBranchDivergence() && VectorWidth != 1) {
     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
   }
@@ -393,8 +396,11 @@ static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
     //  %cond = icmp ne i16 %mask_1, 0
     //  br i1 %mask_1, label %cond.store, label %else
     //
+    // On GPUs, use
+    //  %cond = extractelement %mask, Idx
+    // instead
     Value *Predicate;
-    if (VectorWidth != 1) {
+    if (SclrMask != nullptr) {
       Value *Mask = Builder.getInt(APInt::getOneBitSet(
           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
@@ -461,7 +467,8 @@ static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
 // . . .
 // %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
 // ret <16 x i32> %Result
-static void scalarizeMaskedGather(const DataLayout &DL, CallInst *CI,
+static void scalarizeMaskedGather(const DataLayout &DL,
+                                  const TargetTransformInfo &TTI, CallInst *CI,
                                   DomTreeUpdater *DTU, bool &ModifiedDT) {
   Value *Ptrs = CI->getArgOperand(0);
   Value *Alignment = CI->getArgOperand(1);
@@ -500,9 +507,10 @@ static void scalarizeMaskedGather(const DataLayout &DL, CallInst *CI,
   }
 
   // If the mask is not v1i1, use scalar bit test operations. This generates
-  // better results on X86 at least.
-  Value *SclrMask;
-  if (VectorWidth != 1) {
+  // better results on X86 at least. However, don't do this on GPUs or other
+  // machines with branch divergence, as there, each i1 takes up a register.
+  Value *SclrMask = nullptr;
+  if (!TTI.hasBranchDivergence() && VectorWidth != 1) {
     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
   }
@@ -514,9 +522,12 @@ static void scalarizeMaskedGather(const DataLayout &DL, CallInst *CI,
     //  %cond = icmp ne i16 %mask_1, 0
     //  br i1 %Mask1, label %cond.load, label %else
     //
+    // On GPUs, use
+    //  %cond = extractelement %mask, Idx
+    // instead
 
     Value *Predicate;
-    if (VectorWidth != 1) {
+    if (SclrMask != nullptr) {
       Value *Mask = Builder.getInt(APInt::getOneBitSet(
           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
@@ -591,7 +602,8 @@ static void scalarizeMaskedGather(const DataLayout &DL, CallInst *CI,
 // store i32 %Elt1, i32* %Ptr1, align 4
 // br label %else2
 //   . . .
-static void scalarizeMaskedScatter(const DataLayout &DL, CallInst *CI,
+static void scalarizeMaskedScatter(const DataLayout &DL,
+                                   const TargetTransformInfo &TTI, CallInst *CI,
                                    DomTreeUpdater *DTU, bool &ModifiedDT) {
   Value *Src = CI->getArgOperand(0);
   Value *Ptrs = CI->getArgOperand(1);
@@ -629,8 +641,8 @@ static void scalarizeMaskedScatter(const DataLayout &DL, CallInst *CI,
 
   // If the mask is not v1i1, use scalar bit test operations. This generates
   // better results on X86 at least.
-  Value *SclrMask;
-  if (VectorWidth != 1) {
+  Value *SclrMask = nullptr;
+  if (!TTI.hasBranchDivergence() && VectorWidth != 1) {
     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
   }
@@ -642,8 +654,11 @@ static void scalarizeMaskedScatter(const DataLayout &DL, CallInst *CI,
     //  %cond = icmp ne i16 %mask_1, 0
     //  br i1 %Mask1, label %cond.store, label %else
     //
+    // On GPUs, use
+    //  %cond = extractelement %mask, Idx
+    // instead
     Value *Predicate;
-    if (VectorWidth != 1) {
+    if (SclrMask != nullptr) {
       Value *Mask = Builder.getInt(APInt::getOneBitSet(
           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
@@ -681,8 +696,10 @@ static void scalarizeMaskedScatter(const DataLayout &DL, CallInst *CI,
   ModifiedDT = true;
 }
 
-static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI,
-                                      DomTreeUpdater *DTU, bool &ModifiedDT) {
+static void scalarizeMaskedExpandLoad(const DataLayout &DL,
+                                      const TargetTransformInfo &TTI,
+                                      CallInst *CI, DomTreeUpdater *DTU,
+                                      bool &ModifiedDT) {
   Value *Ptr = CI->getArgOperand(0);
   Value *Mask = CI->getArgOperand(1);
   Value *PassThru = CI->getArgOperand(2);
@@ -738,9 +755,10 @@ static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI,
   }
 
   // If the mask is not v1i1, use scalar bit test operations. This generates
-  // better results on X86 at least.
-  Value *SclrMask;
-  if (VectorWidth != 1) {
+  // better results on X86 at least. However, don't do this on GPUs or other
+  // machines with branch divergence, as there, each i1 takes up a register.
+  Value *SclrMask = nullptr;
+  if (!TTI.hasBranchDivergence() && VectorWidth != 1) {
     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
   }
@@ -748,13 +766,16 @@ static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI,
   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
     // Fill the "else" block, created in the previous iteration
     //
-    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
-    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
-    //  br i1 %mask_1, label %cond.load, label %else
+    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else,
+    //  %else ] %mask_1 = extractelement <16 x i1> %mask, i32 Idx br i1 %mask_1,
+    //  label %cond.load, label %else
     //
+    // On GPUs, use
+    //  %cond = extractelement %mask, Idx
+    // instead
 
     Value *Predicate;
-    if (VectorWidth != 1) {
+    if (SclrMask != nullptr) {
       Value *Mask = Builder.getInt(APInt::getOneBitSet(
           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
@@ -813,8 +834,9 @@ static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI,
   ModifiedDT = true;
 }
 
-static void scalarizeMaskedCompressStore(const DataLayout &DL, CallInst *CI,
-                                         DomTreeUpdater *DTU,
+static void scalarizeMaskedCompressStore(const DataLayout &DL,
+                                         const TargetTransformInfo &TTI,
+                                         CallInst *CI, DomTreeUpdater *DTU,
                                          bool &ModifiedDT) {
   Value *Src = CI->getArgOperand(0);
   Value *Ptr = CI->getArgOperand(1);
@@ -855,9 +877,10 @@ static void scalarizeMaskedCompressStore(const DataLayout &DL, CallInst *CI,
   }
 
   // If the mask is not v1i1, use scalar bit test operations. This generates
-  // better results on X86 at least.
-  Value *SclrMask;
-  if (VectorWidth != 1) {
+  // better results on X86 at least. However, don't do this on GPUs or other
+  // machines with branch divergence, as there, each i1 takes up a register.
+  Value *SclrMask = nullptr;
+  if (!TTI.hasBranchDivergence() && VectorWidth != 1) {
     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
   }
@@ -868,8 +891,11 @@ static void scalarizeMaskedCompressStore(const DataLayout &DL, CallInst *CI,
     //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
     //  br i1 %mask_1, label %cond.store, label %else
     //
+    // On GPUs, use
+    //  %cond = extractelement %mask, Idx
+    // instead
     Value *Predicate;
-    if (VectorWidth != 1) {
+    if (SclrMask != nullptr) {
       Value *Mask = Builder.getInt(APInt::getOneBitSet(
           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
@@ -1071,14 +1097,14 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
               CI->getType(),
               cast<ConstantInt>(CI->getArgOperand(1))->getAlignValue()))
         return false;
-      scalarizeMaskedLoad(DL, CI, DTU, ModifiedDT);
+      scalarizeMaskedLoad(DL, TTI, CI, DTU, ModifiedDT);
       return true;
     case Intrinsic::masked_store:
       if (TTI.isLegalMaskedStore(
               CI->getArgOperand(0)->getType(),
               cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue()))
         return false;
-      scalarizeMaskedStore(DL, CI, DTU, ModifiedDT);
+      scalarizeMaskedStore(DL, TTI, CI, DTU, ModifiedDT);
       return true;
     case Intrinsic::masked_gather: {
       MaybeAlign MA =
@@ -1089,7 +1115,7 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
       if (TTI.isLegalMaskedGather(LoadTy, Alignment) &&
           !TTI.forceScalarizeMaskedGather(cast<VectorType>(LoadTy), Alignment))
         return false;
-      scalarizeMaskedGather(DL, CI, DTU, ModifiedDT);
+      scalarizeMaskedGather(DL, TTI, CI, DTU, ModifiedDT);
       return true;
     }
     case Intrinsic::masked_scatter: {
@@ -1102,7 +1128,7 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
           !TTI.forceScalarizeMaskedScatter(cast<VectorType>(StoreTy),
                                            Alignment))
         return false;
-      scalarizeMaskedScatter(DL, CI, DTU, ModifiedDT);
+      scalarizeMaskedScatter(DL, TTI, CI, DTU, ModifiedDT);
       return true;
     }
     case Intrinsic::masked_expandload:
@@ -1110,14 +1136,14 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
               CI->getType(),
               CI->getAttributes().getParamAttrs(0).getAlignment().valueOrOne()))
         return false;
-      scalarizeMaskedExpandLoad(DL, CI, DTU, ModifiedDT);
+      scalarizeMaskedExpandLoad(DL, TTI, CI, DTU, ModifiedDT);
       return true;
     case Intrinsic::masked_compressstore:
       if (TTI.isLegalMaskedCompressStore(
               CI->getArgOperand(0)->getType(),
               CI->getAttributes().getParamAttrs(1).getAlignment().valueOrOne()))
         return false;
-      scalarizeMaskedCompressStore(DL, CI, DTU, ModifiedDT);
+      scalarizeMaskedCompressStore(DL, TTI, CI, DTU, ModifiedDT);
       return true;
     }
   }
diff --git a/llvm/test/Transforms/ScalarizeMaskedMemIntrin/AMDGPU/expamd-masked-load.ll b/llvm/test/Transforms/ScalarizeMaskedMemIntrin/AMDGPU/expamd-masked-load.ll
index 35e5bcde4c0dbd..faee9f95ebdac0 100644
--- a/llvm/test/Transforms/ScalarizeMaskedMemIntrin/AMDGPU/expamd-masked-load.ll
+++ b/llvm/test/Transforms/ScalarizeMaskedMemIntrin/AMDGPU/expamd-masked-load.ll
@@ -8,10 +8,8 @@
 define <2 x i32> @scalarize_v2i32(ptr %p, <2 x i1> %mask, <2 x i32> %passthru) {
 ; CHECK-LABEL: define <2 x i32> @scalarize_v2i32(
 ; CHECK-SAME: ptr [[P:%.*]], <2 x i1> [[MASK:%.*]], <2 x i32> [[PASSTHRU:%.*]]) {
-; CHECK-NEXT:    [[SCALAR_MASK:%.*]] = bitcast <2 x i1> [[MASK]] to i2
-; CHECK-NEXT:    [[TMP1:%.*]] = and i2 [[SCALAR_MASK]], 1
-; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne i2 [[TMP1]], 0
-; CHECK-NEXT:    br i1 [[TMP2]], label %[[COND_LOAD:.*]], label %[[ELSE:.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <2 x i1> [[MASK]], i64 0
+; CHECK-NEXT:    br i1 [[TMP1]], label %[[COND_LOAD:.*]], label %[[ELSE:.*]]
 ; CHECK:       [[COND_LOAD]]:
 ; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i32, ptr [[P]], i32 0
 ; CHECK-NEXT:    [[TMP4:%.*]] = load i32, ptr [[TMP3]], align 4
@@ -19,9 +17,8 @@ define <2 x i32> @scalarize_v2i32(ptr %p, <2 x i1> %mask, <2 x i32> %passthru) {
 ; CHECK-NEXT:    br label %[[ELSE]]
 ; CHECK:       [[ELSE]]:
 ; CHECK-NEXT:    [[RES_PHI_ELSE:%.*]] = phi <2 x i32> [ [[TMP5]], %[[COND_LOAD]] ], [ [[PASSTHRU]], [[TMP0:%.*]] ]
-; CHECK-NEXT:    [[TMP6:%.*]] = and i2 [[SCALAR_MASK]], -2
-; CHECK-NEXT:    [[TMP7:%.*]] = icmp ne i2 [[TMP6]], 0
-; CHECK-NEXT:    br i1 [[TMP7]], label %[[COND_LOAD1:.*]], label %[[ELSE2:.*]]
+; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <2 x i1> [[MASK]], i64 1
+; CHECK-NEXT:    br i1 [[TMP6]], label %[[COND_LOAD1:.*]], label %[[ELSE2:.*]]
 ; CHECK:       [[COND_LOAD1]]:
 ; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds i32, ptr [[P]], i32 1
 ; CHECK-NEXT:    [[TMP9:%.*]] = load i32, ptr [[TMP8]], align 4
@@ -58,10 +55,8 @@ define <2 x i32> @scalarize_v2i32_splat_mask(ptr %p, i1 %mask, <2 x i32> %passth
 define <2 x half> @scalarize_v2f16(ptr %p, <2 x i1> %mask, <2 x half> %passthru) {
 ; CHECK-LABEL: define <2 x half> @scalarize_v2f16(
 ; CHECK-SAME: ptr [[P:%.*]], <2 x i1> [[MASK:%.*]], <2 x half> [[PASSTHRU:%.*]]) {
-; CHECK-NEXT:    [[SCALAR_MASK:%.*]] = bitcast <2 x i1> [[MASK]] to i2
-; CHECK-NEXT:    [[TMP1:%.*]] = and i2 [[SCALAR_MASK]], 1
-; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne i2 [[TMP1]], 0
-; CHECK-NEXT:    br i1 [[TMP2]], label %[[COND_LOAD:.*]], label %[[ELSE:.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <2 x i1> [[MASK]], i64 0
+; CHECK-NEXT:    br i1 [[TMP1]], label %[[COND_LOAD:.*]], label %[[ELSE:.*]]
 ; CHECK:       [[COND_LOAD]]:
 ; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds half, ptr [[P]], i32 0
 ; CHECK-NEXT:    [[TMP4:%.*]] = load half, ptr [[TMP3]], align 2
@@ -69,9 +64,8 @@ define <2 x half> @scalarize_v2f16(ptr %p, <2 x i1> %mask, <2 x half> %passthru)
 ; CHECK-NEXT:    br label %[[ELSE]]
 ; CHECK:       [[ELSE]]:
 ; CHECK-NEXT:    [[RES_PHI_ELSE:%.*]] = phi <2 x half> [ [[TMP5]], %[[COND_LOAD]] ], [ [[PASSTHRU]], [[TMP0:%.*]] ]
-; CHECK-NEXT:    [[TMP6:%.*]] = and i2 [[SCALAR_MASK]], -2
-; CHECK-NEXT:    [[TMP7:%.*]] = icmp ne i2 [[TMP6]], 0
-; CHECK-NEXT:    br i1 [[TMP7]], label %[[COND_LOAD1:.*]], label %[[ELSE2:.*]]
+; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <2 x i1> [[MASK]], i64 1
+; CHECK-NEXT:    br i1 [[TMP6]], label %[[COND_LOAD1:.*]], label %[[ELSE2:.*]]
 ; CHECK:       [[COND_LOAD1]]:
 ; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds half, ptr [[P]], i32 1
 ; CHECK-NEXT:    [[TMP9:%.*]] = load half, ptr [[TMP8]], align 2
@@ -88,10 +82,8 @@ define <2 x half> @scalarize_v2f16(ptr %p, <2 x i1> %mask, <2 x half> %passthru)
 define <2 x i32> @scalarize_v2i32_p3(ptr addrspace(3) %p, <2 x i1> %mask, <2 x i32> %passthru) {
 ; CHECK-LABEL: define <2 x i32> @scalarize_v2i32_p3(
 ; CHECK-SAME: ptr addrspace(3) [[P:%.*]], <2 x i1> [[MASK:%.*]], <2 x i32> [[PASSTHRU:%.*]]) {
-; CHECK-NEXT:    [[SCALAR_MASK:%.*]] = bitcast <2 x i1> [[MASK]] to i2
-; CHECK-NEXT:    [[TMP1:%.*]] = and i2 [[SCALAR_MASK]], 1
-; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne i2 [[TMP1]], 0
-; CHECK-NEXT:    br i1 [[TMP2]], label %[[COND_LOAD:.*]], label %[[ELSE:.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <2 x i1> [[MASK]], i64 0
+; CHECK-NEXT:    br i1 [[TMP1]], label %[[COND_LOAD:.*]], label %[[ELSE:.*]]
 ; CHECK:       [[COND_LOAD]]:
 ; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i32, ptr addrspace(3) [[P]], i32 0
 ; CHECK-NEXT:    [[TMP4:%.*]] = load i32, ptr addrspace(3) [[TMP3]], align 4
@@ -99,9 +91,8 @@ define <2 x i32> @scalarize_v2i32_p3(ptr addrspace(3) %p, <2 x i1> %mask, <2 x i
 ; CHECK-NEXT:    br label %[[ELSE]]
 ; CHECK:       [[ELSE]]:
 ; CHECK-NEXT:    [[RES_PHI_ELSE:%.*]] = phi <2 x i32> [ [[TMP5]], %[[COND_LOAD]] ], [ [[PASSTHRU]], [[TMP0:%.*]] ]
-; CHECK-NEXT:    [[TMP6:%.*]] = and i2 [[SCALAR_MASK]], -2
-; CHECK-NEXT:    [[TMP7:%.*]] = icmp ne i2 [[TMP6]], 0
-; CHECK-NEXT:    br i1 [[TMP7]], label %[[COND_LOAD1:.*]], label %[[ELSE2:.*]]
+; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <2 x i1> [[MASK]], i64 1
+; CHECK-NEXT:    br i1 [[TMP6]], label %[[COND_LOAD1:.*]], label %[[ELSE2:.*]]
 ; CHECK:       [[COND_LOAD1]]:
 ; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds i32, ptr addrspace(3) [[P]], i32 1
 ; CHECK-NEXT:    [[TMP9:%.*]] = load i32, ptr addrspace(3) [[TMP8]], align 4
diff --git a/llvm/test/Transforms/ScalarizeMaskedMemIntrin/AMDGPU/expand-masked-gather.ll b/llvm/test/Transforms/ScalarizeMaskedMemIntrin/AMDGPU/expand-masked-gather.ll
index 94d0e2943d9366..8c4408bfa527b7 100644
--- a/llvm/test/Transforms/ScalarizeMaskedMemIntrin/AMDGPU/expand-masked-gather.ll
+++ b/llvm/test/Transforms/ScalarizeMaskedMemIntrin/AMDGPU/expand-masked-gather.ll
@@ -8,10 +8,8 @@
 define <2 x i32> @scalarize_v2i32(<2 x ptr> %p, <2 x i1> %mask, <2 x i32> %passthru) {
 ; CHECK-LABEL: define <2 x i32> @scalarize_v2i32(
 ; CHECK-SAME: <2 x ptr> [[P:%.*]], <2 x i1> [[MASK:%.*]], <2 x i32> [[PASSTHRU:%.*]]) {
-; CHECK-NEXT:    [[SCALAR_MASK:%.*]] = bitcast <2 x i1> [[MASK]] to i2
-; CHECK-NEXT:    [[TMP1:%.*]] = and i2 [[SCALAR_MASK]], 1
-; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne i2 [[TMP1]], 0
-; CHECK-NEXT:    br i1 [[TMP2]], label %[[COND_LOAD:.*]], label %[[ELSE:.*]]
+; CHECK-NEXT:    [[MASK0:%.*]] = extractelement <2 x i1> [[MASK]], i64 0
+; CHECK-NEXT:    br i1 [[MASK0]], label %[[COND_LOAD:.*]], label %[[ELSE:.*]]
 ; CHECK:       [[COND_LOAD]]:
 ; CHECK-NEXT:    [[PTR0:%.*]] = extractelement <2 x ptr> [[P]], i64 0
 ; CHECK-NEXT:    [[LOAD0:%.*]] = load i32, ptr [[PTR0]], align 8
@@ -19,9 +17,8 @@ define <2 x i32> @scalarize_v2i32(<2 x ptr> %p, <2 x i1> %mask, <2 x i32> %passt
 ; CHECK-NEXT:    br label %[[ELSE]]
 ; CHECK:       [[ELSE]]:
 ; CHECK-NEXT:    [[RES_PHI_ELSE:%.*]] = phi <2 x i32> [ [[RES0]], %[[COND_LOAD]] ], [ [[PASSTHRU]], [[TMP0:%.*]] ]
-; CHECK-NEXT:    [[TMP3:%.*]] = and i2 [[SCALAR_MASK]], -2
-; CHECK-NEXT:    [[TMP4:%.*]] = icmp ne i2 [[TMP3]], 0
-; CHECK-NEXT:    br i1 [[TMP4]], label %[[COND_LOAD1:.*]], label %[[ELSE2:.*]]
+; CHECK-NEXT:    [[MASK1:%.*]] = extractelement <2 x i1> [[MASK]], i64 1
+; CHECK-NEXT:    br i1 [[MASK1]], label %[[COND_LOAD1:.*]], label %[[ELSE2:.*]]
 ; CHECK:       [[COND_LOAD1]]:
 ; CHECK-NEXT:    [[PTR1:%.*]] = extractelement <2 x ptr> [[P]], i64 1
 ; CHECK-NEXT:    [[LOAD1:%.*]] = load i32, ptr [[PTR1]], align 8
diff --git a/llvm/test/Transforms/ScalarizeMaskedMemIntrin/AMDGPU/expand-masked-scatter.ll b/llvm/test/Transforms/ScalarizeMaskedMemIntrin/AMDGPU/expand-masked-scatter.ll
index 45debf35d06e4f..448ee18e9b4dec 100644
--- a/llvm/test/Transforms/ScalarizeMaskedMemIntrin/AMDGPU/expand-masked-scatter.ll
+++ b/llvm/test/Transforms/ScalarizeMaskedMemIntrin/AMDGPU/expand-masked-scatter.ll
@@ -8,19 +8,16 @@
 define void @scalarize_v2i32(<2 x ptr> %p, <2 x i1> %mask, <2 x i32> %value) {
 ; CHECK-LABEL: define void @scalarize_v2i32(
 ; CHECK-SAME: <2 x ptr> [[P:%.*]], <2 x i1> [[MASK:%.*]], <2 x i32> [[VALUE:%.*]]) {
-; CHECK-NEXT:    [[SCALAR_MASK:%.*]] = bitcast <2 x i1> [[MASK]] to i2
-; CHECK-NEXT:    [[TMP1:%.*]] = and i2 [[SCALAR_MASK]], 1
-; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne i2 [[TMP1]], 0
-; CHECK-NEXT:    br i1 [[TMP2]], label %[[COND_STORE:.*]], label %[[ELSE:.*]]
+; CHECK-NEXT:    [[MASK0:%.*]] = extractelement <2 x i1> [[MASK]], i64 0
+; CHECK-NEXT:    br i1 [[MASK0]], label %[[COND_STORE:.*]], label %[[ELSE:.*]]
 ; CHECK:       [[COND_STORE]]:
 ; CHECK-NEXT:    [[ELT0:%.*]] = extractelement <2 x i32> [[VALUE]], i64 0
 ; CHECK-NEXT:    [[PTR0:%.*]] = extractelement <2 x ptr> [[P]], i64 0
 ; CHECK-NEXT:    store i32 [[ELT0]], ptr [[PTR0]], align 8
 ; CHECK-NEXT:    br label %[[ELSE]]
 ; CHECK:       [[ELSE]]:
-; CHECK-NEXT:    [[TMP3:%.*]] = and i2 [[SCALAR_MASK]], -2
-; CHECK-NEXT:    [[TMP4:%.*]] = icmp ne i2 [[TMP3]], 0
-; CHECK-NEXT:    br i1 [[TMP4]], label %[[COND_STORE1:.*]], label %[[ELSE2:.*]]
+; CHECK-NEXT:    [[MASK1:%.*]] = extractelement <2 x i1> [[MASK]], i64 1
+; CHECK-NEXT:    br i1 [[MASK1]], label %[[COND_STORE1:.*]], label %[[ELSE2:.*]]
 ; CHECK:       [[COND_STORE1]]:
 ; CHECK-NEXT:    [[ELT1:%.*]] = extractelement <2 x i32> [[VALUE]], i64 1
 ; CHECK-NEXT:    [[PTR1:%.*]] = extractelement <2 x ptr> [[P]], i64 1
diff --git a/llvm/test/Transforms/ScalarizeMaskedMemIntrin/AMDGPU/expand-masked-store.ll b/llvm/test/Transforms/ScalarizeMaskedMemIntrin/AMDGPU/expand-masked-store.ll
index 1efd008b77e1c0..2eb86f20374d87 100644
--- a/llvm/test/Transforms/ScalarizeMaskedMemIntrin/AMDGPU/expand-masked-store.ll
+++ b/llvm/test/Transforms/ScalarizeMaskedMemIntrin/AMDGPU/expand-masked-store.ll
@@ -8,19 +8,16 @@
 define void @scalarize_v2i32(ptr %p, <2 x i1> %mask, <2 x i32> %data) {
 ; CHECK-LABEL: define void @scalarize_v2i32(
 ; CHECK-SAME: ptr [[P:%.*]], <2 x i1> [[MASK:%.*]], <2 x i32> [[DATA:%.*]]) {
-; CHECK-NEXT:    [[SCALAR_MASK:%.*]] = bitcast <2 x i1> [[MASK]] to i2
-; CHECK-NEXT:    [[TMP1:%.*]] = and i2 [[SCALAR_MASK]], 1
-; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne i2 [[TMP1]], 0
-; CHECK-NEXT:    br i1 [[TMP2]], label %[[COND_STORE:.*]], label %[[ELSE:.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <2 x i1> [[MASK]], i64 0
+; CHECK-NEXT:    br i1 [[TMP1]], label %[[COND_STORE:.*]], label %[[ELSE:.*]]
 ; CHECK:       [[COND_STORE]]:
 ; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <2 x i32> [[DATA]], i64 0
 ; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds i32, ptr [[P]], i32 0
 ; CHECK-NEXT:    store i32 [[TMP3]], ptr [[TMP4]], align 4
 ; CHECK-NEXT:    br label %[[ELSE]]
 ; CHECK:       [[ELSE]]:
-; CHECK-NEXT:    [[TMP5:%.*]] = and i2 [[SCALAR_MASK]], -2
-; CHECK-NEXT:    [[TMP6:%.*]] = icmp ne i2 [[TMP5]], 0
-; CHECK-NEXT:    br i1 [[TMP6]], label %[[COND_STORE1:.*]], label %[[ELSE2:.*]]
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <2 x i1> [[MASK]], i64 1
+; CHECK-NEXT:    br i1 [[TMP5]], label %[[COND_STORE1:.*]], label %[[ELSE2:.*]]
 ; CHECK:       [[COND_STORE1]]:
 ; CHECK-NEXT:    [[TMP7:%.*]] = extractelement <2 x i32> [[DATA]], i64 1
 ; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds i32, ptr [[P]], i32 1
@@ -55,19 +52,16 @@ define void @scalarize_v2i32_splat_mask(ptr %p, <2 x i32> %data, i1 %mask) {
 define void @scalarize_v2f16(ptr %p, <2 x i1> %mask, <2 x half> %data) {
 ; CHECK-LABEL: define void @scalarize_v2f16(
 ; CHECK-SAME: ptr [[P:%.*]], <2 x i1> [[MASK:%.*]], <2 x half> [[DATA:%.*]]) {
-; CHECK-NEXT:    [[SCALAR_MASK:%.*]] = bitcast <2 x i1> [[MASK]] to i2
-; CHECK-NEXT:    [[TMP1:%.*]] = and i2 [[SCALAR_MASK]], 1
-; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne i2 [[TMP1]], 0
-; CHECK-NEXT:    br i1 [[TMP2]], label %[[COND_STORE:.*]], label %[[ELSE:.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <2 x i1> [[MASK]], i64 0
+; CHECK-NEXT:    br i1 [[TMP1]], label %[[COND_STORE:.*]], label %[[ELSE:.*]]
 ; CHECK:       [[COND_STORE]]:
 ; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <2 x half> [[DATA]], i64 0
 ; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds half, ptr [[P]], i32 0
 ; CHECK-NEXT:    store half [[TMP3]], ptr [[TMP4]], align 2
 ; CHECK-NEXT:    br label %[[ELSE]]
 ; CHECK:       [[ELSE]]:
-; CHECK-NEXT:    [[TMP5:%.*]] = and i2 [[SCALAR_MASK]], -2
-; CHECK-NEXT:    [[TMP6:%.*]] = icmp ne i2 [[TMP5]], 0
-; CHECK-NEXT:    br i1 [[TMP6]], label %[[COND_STORE1:.*]], label %[[ELSE2:.*]]
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <2 x i1> [[MASK]], i64 1
+; CHECK-NEXT:    br i1 [[TMP5]], label %[[COND_STORE1:.*]], label %[[ELSE2:.*]]
 ; CHECK:       [[COND_STORE1]]:
 ; CHECK-NEXT:    [[TMP7:%.*]] = extractelement <2 x half> [[DATA]], i64 1
 ; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds half, ptr [[P]], i32 1
@@ -83,19 +77,16 @@ define void @scalarize_v2f16(ptr %p, <2 x i1> %mask, <2 x half> %data) {
 define void @scalarize_v2i32_p3(ptr addrspace(3) %p, <2 x i1> %mask, <2 x i32> %data) {
 ; CHECK-LABEL: define void @scalarize_v2i32_p3(
 ; CHECK-SAME: ptr addrspace(3) [[P:%.*]], <2 x i1> [[MASK:%.*]], <2 x i32> [[DATA:%.*]]) {
-; CHECK-NEXT:    [[SCALAR_MASK:%.*]] = bitcast <2 x i1> [[MASK]] to i2
-; CHECK-NEXT:    [[TMP1:%.*]] = and i2 [[SCALAR_MASK]], 1
-; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne i2 [[TMP1]], 0
-; CHECK-NEXT:    br i1 [[TMP2]], label %[[COND_STORE:.*]], label %[[ELSE:.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <2 x i1> [[MASK]], i64 0
+; CHECK-NEXT:    br i1 [[TMP1]], label %[[COND_STORE:.*]], label %[[ELSE:.*]]
 ; CHECK:       [[COND_STORE]]:
 ; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <2 x i32> [[DATA]], i64 0
 ; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds i32, ptr addrspace(3) [[P]], i32 0
 ; CHECK-NEXT:    store i32 [[TMP3]], ptr addrspace(3) [[TMP4]], align 4
 ; CHECK-NEXT:    br label %[[ELSE]]
 ; CHECK:       [[ELSE]]:
-; CHECK-NEXT:    [[TMP5:%.*]] = and i2 [[SCALAR_MASK]], -2
-; CHECK-NEXT:    [[TMP6:%.*]] = icmp ne i2 [[TMP5]], 0
-; CHECK-NEXT:    br i1 [[TMP6]], label %[[COND_STORE1:.*]], label %[[ELSE2:.*]]
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <2 x i1> [[MASK]], i64 1
+; CHECK-NEXT:    br i1 [[TMP5]], label %[[COND_STORE1:.*]], label %[[ELSE2:.*]]
 ; CHECK:       [[COND_STORE1]]:
 ; CHECK-NEXT:    [[TMP7:%.*]] = extractelement <2 x i32> [[DATA]], i64 1
 ; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds i32, ptr addrspace(3) [[P]], i32 1

>From 97810880789edd8aca061911290cb9cc45d6756b Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Thu, 22 Aug 2024 17:58:34 +0000
Subject: [PATCH 2/2] Address review comments

---
 .../Scalar/ScalarizeMaskedMemIntrin.cpp       | 66 ++++++++++---------
 1 file changed, 35 insertions(+), 31 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
index ffa415ebd37048..63fcc1760ccafe 100644
--- a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
+++ b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
@@ -69,10 +69,11 @@ class ScalarizeMaskedMemIntrinLegacyPass : public FunctionPass {
 
 static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
                           const TargetTransformInfo &TTI, const DataLayout &DL,
-                          DomTreeUpdater *DTU);
+                          bool HasBranchDivergence, DomTreeUpdater *DTU);
 static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
                              const TargetTransformInfo &TTI,
-                             const DataLayout &DL, DomTreeUpdater *DTU);
+                             const DataLayout &DL, bool HasBranchDivergence,
+                             DomTreeUpdater *DTU);
 
 char ScalarizeMaskedMemIntrinLegacyPass::ID = 0;
 
@@ -141,9 +142,9 @@ static unsigned adjustForEndian(const DataLayout &DL, unsigned VectorWidth,
 //  %10 = extractelement <16 x i1> %mask, i32 2
 //  br i1 %10, label %cond.load4, label %else5
 //
-static void scalarizeMaskedLoad(const DataLayout &DL,
-                                const TargetTransformInfo &TTI, CallInst *CI,
-                                DomTreeUpdater *DTU, bool &ModifiedDT) {
+static void scalarizeMaskedLoad(const DataLayout &DL, bool HasBranchDivergence,
+                                CallInst *CI, DomTreeUpdater *DTU,
+                                bool &ModifiedDT) {
   Value *Ptr = CI->getArgOperand(0);
   Value *Alignment = CI->getArgOperand(1);
   Value *Mask = CI->getArgOperand(2);
@@ -225,7 +226,7 @@ static void scalarizeMaskedLoad(const DataLayout &DL,
   // better results on X86 at least. However, don't do this on GPUs and other
   // machines with divergence, as there each i1 needs a vector register.
   Value *SclrMask = nullptr;
-  if (!TTI.hasBranchDivergence() && VectorWidth != 1) {
+  if (VectorWidth != 1 && !HasBranchDivergence) {
     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
   }
@@ -314,9 +315,9 @@ static void scalarizeMaskedLoad(const DataLayout &DL,
 //   store i32 %6, i32* %7
 //   br label %else2
 //   . . .
-static void scalarizeMaskedStore(const DataLayout &DL,
-                                 const TargetTransformInfo &TTI, CallInst *CI,
-                                 DomTreeUpdater *DTU, bool &ModifiedDT) {
+static void scalarizeMaskedStore(const DataLayout &DL, bool HasBranchDivergence,
+                                 CallInst *CI, DomTreeUpdater *DTU,
+                                 bool &ModifiedDT) {
   Value *Src = CI->getArgOperand(0);
   Value *Ptr = CI->getArgOperand(1);
   Value *Alignment = CI->getArgOperand(2);
@@ -384,7 +385,7 @@ static void scalarizeMaskedStore(const DataLayout &DL,
   // better results on X86 at least. However, don't do this on GPUs or other
   // machines with branch divergence, as there each i1 takes up a register.
   Value *SclrMask = nullptr;
-  if (!TTI.hasBranchDivergence() && VectorWidth != 1) {
+  if (VectorWidth != 1 && !HasBranchDivergence) {
     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
   }
@@ -468,7 +469,7 @@ static void scalarizeMaskedStore(const DataLayout &DL,
 // %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
 // ret <16 x i32> %Result
 static void scalarizeMaskedGather(const DataLayout &DL,
-                                  const TargetTransformInfo &TTI, CallInst *CI,
+                                  bool HasBranchDivergence, CallInst *CI,
                                   DomTreeUpdater *DTU, bool &ModifiedDT) {
   Value *Ptrs = CI->getArgOperand(0);
   Value *Alignment = CI->getArgOperand(1);
@@ -510,7 +511,7 @@ static void scalarizeMaskedGather(const DataLayout &DL,
   // better results on X86 at least. However, don't do this on GPUs or other
   // machines with branch divergence, as there, each i1 takes up a register.
   Value *SclrMask = nullptr;
-  if (!TTI.hasBranchDivergence() && VectorWidth != 1) {
+  if (VectorWidth != 1 && !HasBranchDivergence) {
     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
   }
@@ -603,7 +604,7 @@ static void scalarizeMaskedGather(const DataLayout &DL,
 // br label %else2
 //   . . .
 static void scalarizeMaskedScatter(const DataLayout &DL,
-                                   const TargetTransformInfo &TTI, CallInst *CI,
+                                   bool HasBranchDivergence, CallInst *CI,
                                    DomTreeUpdater *DTU, bool &ModifiedDT) {
   Value *Src = CI->getArgOperand(0);
   Value *Ptrs = CI->getArgOperand(1);
@@ -642,7 +643,7 @@ static void scalarizeMaskedScatter(const DataLayout &DL,
   // If the mask is not v1i1, use scalar bit test operations. This generates
   // better results on X86 at least.
   Value *SclrMask = nullptr;
-  if (!TTI.hasBranchDivergence() && VectorWidth != 1) {
+  if (VectorWidth != 1 && !HasBranchDivergence) {
     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
   }
@@ -697,9 +698,8 @@ static void scalarizeMaskedScatter(const DataLayout &DL,
 }
 
 static void scalarizeMaskedExpandLoad(const DataLayout &DL,
-                                      const TargetTransformInfo &TTI,
-                                      CallInst *CI, DomTreeUpdater *DTU,
-                                      bool &ModifiedDT) {
+                                      bool HasBranchDivergence, CallInst *CI,
+                                      DomTreeUpdater *DTU, bool &ModifiedDT) {
   Value *Ptr = CI->getArgOperand(0);
   Value *Mask = CI->getArgOperand(1);
   Value *PassThru = CI->getArgOperand(2);
@@ -758,7 +758,7 @@ static void scalarizeMaskedExpandLoad(const DataLayout &DL,
   // better results on X86 at least. However, don't do this on GPUs or other
   // machines with branch divergence, as there, each i1 takes up a register.
   Value *SclrMask = nullptr;
-  if (!TTI.hasBranchDivergence() && VectorWidth != 1) {
+  if (VectorWidth != 1 && !HasBranchDivergence) {
     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
   }
@@ -835,8 +835,8 @@ static void scalarizeMaskedExpandLoad(const DataLayout &DL,
 }
 
 static void scalarizeMaskedCompressStore(const DataLayout &DL,
-                                         const TargetTransformInfo &TTI,
-                                         CallInst *CI, DomTreeUpdater *DTU,
+                                         bool HasBranchDivergence, CallInst *CI,
+                                         DomTreeUpdater *DTU,
                                          bool &ModifiedDT) {
   Value *Src = CI->getArgOperand(0);
   Value *Ptr = CI->getArgOperand(1);
@@ -880,7 +880,7 @@ static void scalarizeMaskedCompressStore(const DataLayout &DL,
   // better results on X86 at least. However, don't do this on GPUs or other
   // machines with branch divergence, as there, each i1 takes up a register.
   Value *SclrMask = nullptr;
-  if (!TTI.hasBranchDivergence() && VectorWidth != 1) {
+  if (VectorWidth != 1 && !HasBranchDivergence) {
     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
   }
@@ -1019,12 +1019,13 @@ static bool runImpl(Function &F, const TargetTransformInfo &TTI,
   bool EverMadeChange = false;
   bool MadeChange = true;
   auto &DL = F.getDataLayout();
+  bool HasBranchDivergence = TTI.hasBranchDivergence(&F);
   while (MadeChange) {
     MadeChange = false;
     for (BasicBlock &BB : llvm::make_early_inc_range(F)) {
       bool ModifiedDTOnIteration = false;
       MadeChange |= optimizeBlock(BB, ModifiedDTOnIteration, TTI, DL,
-                                  DTU ? &*DTU : nullptr);
+                                  HasBranchDivergence, DTU ? &*DTU : nullptr);
 
       // Restart BB iteration if the dominator tree of the Function was changed
       if (ModifiedDTOnIteration)
@@ -1058,13 +1059,14 @@ ScalarizeMaskedMemIntrinPass::run(Function &F, FunctionAnalysisManager &AM) {
 
 static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
                           const TargetTransformInfo &TTI, const DataLayout &DL,
-                          DomTreeUpdater *DTU) {
+                          bool HasBranchDivergence, DomTreeUpdater *DTU) {
   bool MadeChange = false;
 
   BasicBlock::iterator CurInstIterator = BB.begin();
   while (CurInstIterator != BB.end()) {
     if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
-      MadeChange |= optimizeCallInst(CI, ModifiedDT, TTI, DL, DTU);
+      MadeChange |=
+          optimizeCallInst(CI, ModifiedDT, TTI, DL, HasBranchDivergence, DTU);
     if (ModifiedDT)
       return true;
   }
@@ -1074,7 +1076,8 @@ static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
 
 static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
                              const TargetTransformInfo &TTI,
-                             const DataLayout &DL, DomTreeUpdater *DTU) {
+                             const DataLayout &DL, bool HasBranchDivergence,
+                             DomTreeUpdater *DTU) {
   IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
   if (II) {
     // The scalarization code below does not work for scalable vectors.
@@ -1097,14 +1100,14 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
               CI->getType(),
               cast<ConstantInt>(CI->getArgOperand(1))->getAlignValue()))
         return false;
-      scalarizeMaskedLoad(DL, TTI, CI, DTU, ModifiedDT);
+      scalarizeMaskedLoad(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
       return true;
     case Intrinsic::masked_store:
       if (TTI.isLegalMaskedStore(
               CI->getArgOperand(0)->getType(),
               cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue()))
         return false;
-      scalarizeMaskedStore(DL, TTI, CI, DTU, ModifiedDT);
+      scalarizeMaskedStore(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
       return true;
     case Intrinsic::masked_gather: {
       MaybeAlign MA =
@@ -1115,7 +1118,7 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
       if (TTI.isLegalMaskedGather(LoadTy, Alignment) &&
           !TTI.forceScalarizeMaskedGather(cast<VectorType>(LoadTy), Alignment))
         return false;
-      scalarizeMaskedGather(DL, TTI, CI, DTU, ModifiedDT);
+      scalarizeMaskedGather(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
       return true;
     }
     case Intrinsic::masked_scatter: {
@@ -1128,7 +1131,7 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
           !TTI.forceScalarizeMaskedScatter(cast<VectorType>(StoreTy),
                                            Alignment))
         return false;
-      scalarizeMaskedScatter(DL, TTI, CI, DTU, ModifiedDT);
+      scalarizeMaskedScatter(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
       return true;
     }
     case Intrinsic::masked_expandload:
@@ -1136,14 +1139,15 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
               CI->getType(),
               CI->getAttributes().getParamAttrs(0).getAlignment().valueOrOne()))
         return false;
-      scalarizeMaskedExpandLoad(DL, TTI, CI, DTU, ModifiedDT);
+      scalarizeMaskedExpandLoad(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
       return true;
     case Intrinsic::masked_compressstore:
       if (TTI.isLegalMaskedCompressStore(
               CI->getArgOperand(0)->getType(),
               CI->getAttributes().getParamAttrs(1).getAlignment().valueOrOne()))
         return false;
-      scalarizeMaskedCompressStore(DL, TTI, CI, DTU, ModifiedDT);
+      scalarizeMaskedCompressStore(DL, HasBranchDivergence, CI, DTU,
+                                   ModifiedDT);
       return true;
     }
   }


