[llvm] [ScalarizeMaskedMemIntr] Don't use a scalar mask on GPUs (PR #104842)

Krzysztof Drewniak via llvm-commits llvm-commits at lists.llvm.org
Mon Aug 19 12:53:59 PDT 2024


https://github.com/krzysz00 created https://github.com/llvm/llvm-project/pull/104842

ScalarizeMaskedMemIntrin contains an optimization where the <N x i1> mask is bitcast into an iN and then bit-tests with powers of two are used to determine whether to load/store/... or not.

However, on machines with branch divergence (mainly GPUs), this is a mis-optimization, since each i1 in the mask will be stored in a condition register - that is, each of these "i1"s is likely to be a word or two wide, making these bit operations counterproductive.

Therefore, amend this pass to skip the optimization on targets that it pessimizes.

Pre-commit tests #104645

>From be404168b927119c303c2ab814ca424b0e44634c Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Fri, 16 Aug 2024 22:42:17 +0000
Subject: [PATCH] [ScalarizeMaskedMemIntr] Don't use a scalar mask on GPUs

ScalarizeMaskedMemIntrin contains an optimization where the <N x i1>
mask is bitcast into an iN and then bit-tests with powers of two are
used to determine whether to load/store/... or not.

However, on machines with branch divergence (mainly GPUs), this is a
mis-optimization, since each i1 in the mask will be stored in a
condition register - that is, each of these "i1"s is likely to be a
word or two wide, making these bit operations counterproductive.

Therefore, amend this pass to skip the optimization on targets that it
pessimizes.

Pre-commit tests #104645
---
 .../Scalar/ScalarizeMaskedMemIntrin.cpp       | 120 +++++++++++-------
 .../AMDGPU/expamd-masked-load.ll              |  33 ++---
 .../AMDGPU/expand-masked-gather.ll            |  11 +-
 .../AMDGPU/expand-masked-scatter.ll           |  11 +-
 .../AMDGPU/expand-masked-store.ll             |  33 ++---
 5 files changed, 105 insertions(+), 103 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
index 9cb7bad94c20bc..ffa415ebd37048 100644
--- a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
+++ b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
@@ -141,7 +141,8 @@ static unsigned adjustForEndian(const DataLayout &DL, unsigned VectorWidth,
 //  %10 = extractelement <16 x i1> %mask, i32 2
 //  br i1 %10, label %cond.load4, label %else5
 //
-static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,
+static void scalarizeMaskedLoad(const DataLayout &DL,
+                                const TargetTransformInfo &TTI, CallInst *CI,
                                 DomTreeUpdater *DTU, bool &ModifiedDT) {
   Value *Ptr = CI->getArgOperand(0);
   Value *Alignment = CI->getArgOperand(1);
@@ -221,11 +222,10 @@ static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,
     return;
   }
   // If the mask is not v1i1, use scalar bit test operations. This generates
-  // better results on X86 at least.
-  // Note: this produces worse code on AMDGPU, where the "i1" is implicitly SIMD
-  // - what's a good way to detect this?
-  Value *SclrMask;
-  if (VectorWidth != 1) {
+  // better results on X86 at least. However, don't do this on GPUs and other
+  // machines with divergence, as there each i1 needs a vector register.
+  Value *SclrMask = nullptr;
+  if (!TTI.hasBranchDivergence() && VectorWidth != 1) {
     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
   }
@@ -233,13 +233,15 @@ static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,
   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
     // Fill the "else" block, created in the previous iteration
     //
-    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
-    //  %mask_1 = and i16 %scalar_mask, i32 1 << Idx
-    //  %cond = icmp ne i16 %mask_1, 0
-    //  br i1 %mask_1, label %cond.load, label %else
+    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else,
+    //  %else ] %mask_1 = and i16 %scalar_mask, i32 1 << Idx %cond = icmp ne i16
+    //  %mask_1, 0 br i1 %mask_1, label %cond.load, label %else
     //
+    // On GPUs, use
+    //  %cond = extractelement %mask, Idx
+    // instead
     Value *Predicate;
-    if (VectorWidth != 1) {
+    if (SclrMask != nullptr) {
       Value *Mask = Builder.getInt(APInt::getOneBitSet(
           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
@@ -312,7 +314,8 @@ static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,
 //   store i32 %6, i32* %7
 //   br label %else2
 //   . . .
-static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
+static void scalarizeMaskedStore(const DataLayout &DL,
+                                 const TargetTransformInfo &TTI, CallInst *CI,
                                  DomTreeUpdater *DTU, bool &ModifiedDT) {
   Value *Src = CI->getArgOperand(0);
   Value *Ptr = CI->getArgOperand(1);
@@ -378,10 +381,10 @@ static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
   }
 
   // If the mask is not v1i1, use scalar bit test operations. This generates
-  // better results on X86 at least.
-
-  Value *SclrMask;
-  if (VectorWidth != 1) {
+  // better results on X86 at least. However, don't do this on GPUs or other
+  // machines with branch divergence, as there each i1 takes up a register.
+  Value *SclrMask = nullptr;
+  if (!TTI.hasBranchDivergence() && VectorWidth != 1) {
     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
   }
@@ -393,8 +396,11 @@ static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
     //  %cond = icmp ne i16 %mask_1, 0
     //  br i1 %mask_1, label %cond.store, label %else
     //
+    // On GPUs, use
+    //  %cond = extractelement %mask, Idx
+    // instead
     Value *Predicate;
-    if (VectorWidth != 1) {
+    if (SclrMask != nullptr) {
       Value *Mask = Builder.getInt(APInt::getOneBitSet(
           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
@@ -461,7 +467,8 @@ static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
 // . . .
 // %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
 // ret <16 x i32> %Result
-static void scalarizeMaskedGather(const DataLayout &DL, CallInst *CI,
+static void scalarizeMaskedGather(const DataLayout &DL,
+                                  const TargetTransformInfo &TTI, CallInst *CI,
                                   DomTreeUpdater *DTU, bool &ModifiedDT) {
   Value *Ptrs = CI->getArgOperand(0);
   Value *Alignment = CI->getArgOperand(1);
@@ -500,9 +507,10 @@ static void scalarizeMaskedGather(const DataLayout &DL, CallInst *CI,
   }
 
   // If the mask is not v1i1, use scalar bit test operations. This generates
-  // better results on X86 at least.
-  Value *SclrMask;
-  if (VectorWidth != 1) {
+  // better results on X86 at least. However, don't do this on GPUs or other
+  // machines with branch divergence, as there, each i1 takes up a register.
+  Value *SclrMask = nullptr;
+  if (!TTI.hasBranchDivergence() && VectorWidth != 1) {
     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
   }
@@ -514,9 +522,12 @@ static void scalarizeMaskedGather(const DataLayout &DL, CallInst *CI,
     //  %cond = icmp ne i16 %mask_1, 0
     //  br i1 %Mask1, label %cond.load, label %else
     //
+    // On GPUs, use
+    //  %cond = extractelement %mask, Idx
+    // instead
 
     Value *Predicate;
-    if (VectorWidth != 1) {
+    if (SclrMask != nullptr) {
       Value *Mask = Builder.getInt(APInt::getOneBitSet(
           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
@@ -591,7 +602,8 @@ static void scalarizeMaskedGather(const DataLayout &DL, CallInst *CI,
 // store i32 %Elt1, i32* %Ptr1, align 4
 // br label %else2
 //   . . .
-static void scalarizeMaskedScatter(const DataLayout &DL, CallInst *CI,
+static void scalarizeMaskedScatter(const DataLayout &DL,
+                                   const TargetTransformInfo &TTI, CallInst *CI,
                                    DomTreeUpdater *DTU, bool &ModifiedDT) {
   Value *Src = CI->getArgOperand(0);
   Value *Ptrs = CI->getArgOperand(1);
@@ -629,8 +641,8 @@ static void scalarizeMaskedScatter(const DataLayout &DL, CallInst *CI,
 
   // If the mask is not v1i1, use scalar bit test operations. This generates
   // better results on X86 at least.
-  Value *SclrMask;
-  if (VectorWidth != 1) {
+  Value *SclrMask = nullptr;
+  if (!TTI.hasBranchDivergence() && VectorWidth != 1) {
     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
   }
@@ -642,8 +654,11 @@ static void scalarizeMaskedScatter(const DataLayout &DL, CallInst *CI,
     //  %cond = icmp ne i16 %mask_1, 0
     //  br i1 %Mask1, label %cond.store, label %else
     //
+    // On GPUs, use
+    //  %cond = extractelement %mask, Idx
+    // instead
     Value *Predicate;
-    if (VectorWidth != 1) {
+    if (SclrMask != nullptr) {
       Value *Mask = Builder.getInt(APInt::getOneBitSet(
           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
@@ -681,8 +696,10 @@ static void scalarizeMaskedScatter(const DataLayout &DL, CallInst *CI,
   ModifiedDT = true;
 }
 
-static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI,
-                                      DomTreeUpdater *DTU, bool &ModifiedDT) {
+static void scalarizeMaskedExpandLoad(const DataLayout &DL,
+                                      const TargetTransformInfo &TTI,
+                                      CallInst *CI, DomTreeUpdater *DTU,
+                                      bool &ModifiedDT) {
   Value *Ptr = CI->getArgOperand(0);
   Value *Mask = CI->getArgOperand(1);
   Value *PassThru = CI->getArgOperand(2);
@@ -738,9 +755,10 @@ static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI,
   }
 
   // If the mask is not v1i1, use scalar bit test operations. This generates
-  // better results on X86 at least.
-  Value *SclrMask;
-  if (VectorWidth != 1) {
+  // better results on X86 at least. However, don't do this on GPUs or other
+  // machines with branch divergence, as there, each i1 takes up a register.
+  Value *SclrMask = nullptr;
+  if (!TTI.hasBranchDivergence() && VectorWidth != 1) {
     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
   }
@@ -748,13 +766,16 @@ static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI,
   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
     // Fill the "else" block, created in the previous iteration
     //
-    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
-    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
-    //  br i1 %mask_1, label %cond.load, label %else
+    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else,
+    //  %else ] %mask_1 = extractelement <16 x i1> %mask, i32 Idx br i1 %mask_1,
+    //  label %cond.load, label %else
     //
+    // On GPUs, use
+    //  %cond = extractelement %mask, Idx
+    // instead
 
     Value *Predicate;
-    if (VectorWidth != 1) {
+    if (SclrMask != nullptr) {
       Value *Mask = Builder.getInt(APInt::getOneBitSet(
           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
@@ -813,8 +834,9 @@ static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI,
   ModifiedDT = true;
 }
 
-static void scalarizeMaskedCompressStore(const DataLayout &DL, CallInst *CI,
-                                         DomTreeUpdater *DTU,
+static void scalarizeMaskedCompressStore(const DataLayout &DL,
+                                         const TargetTransformInfo &TTI,
+                                         CallInst *CI, DomTreeUpdater *DTU,
                                          bool &ModifiedDT) {
   Value *Src = CI->getArgOperand(0);
   Value *Ptr = CI->getArgOperand(1);
@@ -855,9 +877,10 @@ static void scalarizeMaskedCompressStore(const DataLayout &DL, CallInst *CI,
   }
 
   // If the mask is not v1i1, use scalar bit test operations. This generates
-  // better results on X86 at least.
-  Value *SclrMask;
-  if (VectorWidth != 1) {
+  // better results on X86 at least. However, don't do this on GPUs or other
+  // machines with branch divergence, as there, each i1 takes up a register.
+  Value *SclrMask = nullptr;
+  if (!TTI.hasBranchDivergence() && VectorWidth != 1) {
     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
   }
@@ -868,8 +891,11 @@ static void scalarizeMaskedCompressStore(const DataLayout &DL, CallInst *CI,
     //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
     //  br i1 %mask_1, label %cond.store, label %else
     //
+    // On GPUs, use
+    //  %cond = extractelement %mask, Idx
+    // instead
     Value *Predicate;
-    if (VectorWidth != 1) {
+    if (SclrMask != nullptr) {
       Value *Mask = Builder.getInt(APInt::getOneBitSet(
           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
@@ -1071,14 +1097,14 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
               CI->getType(),
               cast<ConstantInt>(CI->getArgOperand(1))->getAlignValue()))
         return false;
-      scalarizeMaskedLoad(DL, CI, DTU, ModifiedDT);
+      scalarizeMaskedLoad(DL, TTI, CI, DTU, ModifiedDT);
       return true;
     case Intrinsic::masked_store:
       if (TTI.isLegalMaskedStore(
               CI->getArgOperand(0)->getType(),
               cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue()))
         return false;
-      scalarizeMaskedStore(DL, CI, DTU, ModifiedDT);
+      scalarizeMaskedStore(DL, TTI, CI, DTU, ModifiedDT);
       return true;
     case Intrinsic::masked_gather: {
       MaybeAlign MA =
@@ -1089,7 +1115,7 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
       if (TTI.isLegalMaskedGather(LoadTy, Alignment) &&
           !TTI.forceScalarizeMaskedGather(cast<VectorType>(LoadTy), Alignment))
         return false;
-      scalarizeMaskedGather(DL, CI, DTU, ModifiedDT);
+      scalarizeMaskedGather(DL, TTI, CI, DTU, ModifiedDT);
       return true;
     }
     case Intrinsic::masked_scatter: {
@@ -1102,7 +1128,7 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
           !TTI.forceScalarizeMaskedScatter(cast<VectorType>(StoreTy),
                                            Alignment))
         return false;
-      scalarizeMaskedScatter(DL, CI, DTU, ModifiedDT);
+      scalarizeMaskedScatter(DL, TTI, CI, DTU, ModifiedDT);
       return true;
     }
     case Intrinsic::masked_expandload:
@@ -1110,14 +1136,14 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
               CI->getType(),
               CI->getAttributes().getParamAttrs(0).getAlignment().valueOrOne()))
         return false;
-      scalarizeMaskedExpandLoad(DL, CI, DTU, ModifiedDT);
+      scalarizeMaskedExpandLoad(DL, TTI, CI, DTU, ModifiedDT);
       return true;
     case Intrinsic::masked_compressstore:
       if (TTI.isLegalMaskedCompressStore(
               CI->getArgOperand(0)->getType(),
               CI->getAttributes().getParamAttrs(1).getAlignment().valueOrOne()))
         return false;
-      scalarizeMaskedCompressStore(DL, CI, DTU, ModifiedDT);
+      scalarizeMaskedCompressStore(DL, TTI, CI, DTU, ModifiedDT);
       return true;
     }
   }
diff --git a/llvm/test/Transforms/ScalarizeMaskedMemIntrin/AMDGPU/expamd-masked-load.ll b/llvm/test/Transforms/ScalarizeMaskedMemIntrin/AMDGPU/expamd-masked-load.ll
index 35e5bcde4c0dbd..faee9f95ebdac0 100644
--- a/llvm/test/Transforms/ScalarizeMaskedMemIntrin/AMDGPU/expamd-masked-load.ll
+++ b/llvm/test/Transforms/ScalarizeMaskedMemIntrin/AMDGPU/expamd-masked-load.ll
@@ -8,10 +8,8 @@
 define <2 x i32> @scalarize_v2i32(ptr %p, <2 x i1> %mask, <2 x i32> %passthru) {
 ; CHECK-LABEL: define <2 x i32> @scalarize_v2i32(
 ; CHECK-SAME: ptr [[P:%.*]], <2 x i1> [[MASK:%.*]], <2 x i32> [[PASSTHRU:%.*]]) {
-; CHECK-NEXT:    [[SCALAR_MASK:%.*]] = bitcast <2 x i1> [[MASK]] to i2
-; CHECK-NEXT:    [[TMP1:%.*]] = and i2 [[SCALAR_MASK]], 1
-; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne i2 [[TMP1]], 0
-; CHECK-NEXT:    br i1 [[TMP2]], label %[[COND_LOAD:.*]], label %[[ELSE:.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <2 x i1> [[MASK]], i64 0
+; CHECK-NEXT:    br i1 [[TMP1]], label %[[COND_LOAD:.*]], label %[[ELSE:.*]]
 ; CHECK:       [[COND_LOAD]]:
 ; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i32, ptr [[P]], i32 0
 ; CHECK-NEXT:    [[TMP4:%.*]] = load i32, ptr [[TMP3]], align 4
@@ -19,9 +17,8 @@ define <2 x i32> @scalarize_v2i32(ptr %p, <2 x i1> %mask, <2 x i32> %passthru) {
 ; CHECK-NEXT:    br label %[[ELSE]]
 ; CHECK:       [[ELSE]]:
 ; CHECK-NEXT:    [[RES_PHI_ELSE:%.*]] = phi <2 x i32> [ [[TMP5]], %[[COND_LOAD]] ], [ [[PASSTHRU]], [[TMP0:%.*]] ]
-; CHECK-NEXT:    [[TMP6:%.*]] = and i2 [[SCALAR_MASK]], -2
-; CHECK-NEXT:    [[TMP7:%.*]] = icmp ne i2 [[TMP6]], 0
-; CHECK-NEXT:    br i1 [[TMP7]], label %[[COND_LOAD1:.*]], label %[[ELSE2:.*]]
+; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <2 x i1> [[MASK]], i64 1
+; CHECK-NEXT:    br i1 [[TMP6]], label %[[COND_LOAD1:.*]], label %[[ELSE2:.*]]
 ; CHECK:       [[COND_LOAD1]]:
 ; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds i32, ptr [[P]], i32 1
 ; CHECK-NEXT:    [[TMP9:%.*]] = load i32, ptr [[TMP8]], align 4
@@ -58,10 +55,8 @@ define <2 x i32> @scalarize_v2i32_splat_mask(ptr %p, i1 %mask, <2 x i32> %passth
 define <2 x half> @scalarize_v2f16(ptr %p, <2 x i1> %mask, <2 x half> %passthru) {
 ; CHECK-LABEL: define <2 x half> @scalarize_v2f16(
 ; CHECK-SAME: ptr [[P:%.*]], <2 x i1> [[MASK:%.*]], <2 x half> [[PASSTHRU:%.*]]) {
-; CHECK-NEXT:    [[SCALAR_MASK:%.*]] = bitcast <2 x i1> [[MASK]] to i2
-; CHECK-NEXT:    [[TMP1:%.*]] = and i2 [[SCALAR_MASK]], 1
-; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne i2 [[TMP1]], 0
-; CHECK-NEXT:    br i1 [[TMP2]], label %[[COND_LOAD:.*]], label %[[ELSE:.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <2 x i1> [[MASK]], i64 0
+; CHECK-NEXT:    br i1 [[TMP1]], label %[[COND_LOAD:.*]], label %[[ELSE:.*]]
 ; CHECK:       [[COND_LOAD]]:
 ; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds half, ptr [[P]], i32 0
 ; CHECK-NEXT:    [[TMP4:%.*]] = load half, ptr [[TMP3]], align 2
@@ -69,9 +64,8 @@ define <2 x half> @scalarize_v2f16(ptr %p, <2 x i1> %mask, <2 x half> %passthru)
 ; CHECK-NEXT:    br label %[[ELSE]]
 ; CHECK:       [[ELSE]]:
 ; CHECK-NEXT:    [[RES_PHI_ELSE:%.*]] = phi <2 x half> [ [[TMP5]], %[[COND_LOAD]] ], [ [[PASSTHRU]], [[TMP0:%.*]] ]
-; CHECK-NEXT:    [[TMP6:%.*]] = and i2 [[SCALAR_MASK]], -2
-; CHECK-NEXT:    [[TMP7:%.*]] = icmp ne i2 [[TMP6]], 0
-; CHECK-NEXT:    br i1 [[TMP7]], label %[[COND_LOAD1:.*]], label %[[ELSE2:.*]]
+; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <2 x i1> [[MASK]], i64 1
+; CHECK-NEXT:    br i1 [[TMP6]], label %[[COND_LOAD1:.*]], label %[[ELSE2:.*]]
 ; CHECK:       [[COND_LOAD1]]:
 ; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds half, ptr [[P]], i32 1
 ; CHECK-NEXT:    [[TMP9:%.*]] = load half, ptr [[TMP8]], align 2
@@ -88,10 +82,8 @@ define <2 x half> @scalarize_v2f16(ptr %p, <2 x i1> %mask, <2 x half> %passthru)
 define <2 x i32> @scalarize_v2i32_p3(ptr addrspace(3) %p, <2 x i1> %mask, <2 x i32> %passthru) {
 ; CHECK-LABEL: define <2 x i32> @scalarize_v2i32_p3(
 ; CHECK-SAME: ptr addrspace(3) [[P:%.*]], <2 x i1> [[MASK:%.*]], <2 x i32> [[PASSTHRU:%.*]]) {
-; CHECK-NEXT:    [[SCALAR_MASK:%.*]] = bitcast <2 x i1> [[MASK]] to i2
-; CHECK-NEXT:    [[TMP1:%.*]] = and i2 [[SCALAR_MASK]], 1
-; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne i2 [[TMP1]], 0
-; CHECK-NEXT:    br i1 [[TMP2]], label %[[COND_LOAD:.*]], label %[[ELSE:.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <2 x i1> [[MASK]], i64 0
+; CHECK-NEXT:    br i1 [[TMP1]], label %[[COND_LOAD:.*]], label %[[ELSE:.*]]
 ; CHECK:       [[COND_LOAD]]:
 ; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i32, ptr addrspace(3) [[P]], i32 0
 ; CHECK-NEXT:    [[TMP4:%.*]] = load i32, ptr addrspace(3) [[TMP3]], align 4
@@ -99,9 +91,8 @@ define <2 x i32> @scalarize_v2i32_p3(ptr addrspace(3) %p, <2 x i1> %mask, <2 x i
 ; CHECK-NEXT:    br label %[[ELSE]]
 ; CHECK:       [[ELSE]]:
 ; CHECK-NEXT:    [[RES_PHI_ELSE:%.*]] = phi <2 x i32> [ [[TMP5]], %[[COND_LOAD]] ], [ [[PASSTHRU]], [[TMP0:%.*]] ]
-; CHECK-NEXT:    [[TMP6:%.*]] = and i2 [[SCALAR_MASK]], -2
-; CHECK-NEXT:    [[TMP7:%.*]] = icmp ne i2 [[TMP6]], 0
-; CHECK-NEXT:    br i1 [[TMP7]], label %[[COND_LOAD1:.*]], label %[[ELSE2:.*]]
+; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <2 x i1> [[MASK]], i64 1
+; CHECK-NEXT:    br i1 [[TMP6]], label %[[COND_LOAD1:.*]], label %[[ELSE2:.*]]
 ; CHECK:       [[COND_LOAD1]]:
 ; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds i32, ptr addrspace(3) [[P]], i32 1
 ; CHECK-NEXT:    [[TMP9:%.*]] = load i32, ptr addrspace(3) [[TMP8]], align 4
diff --git a/llvm/test/Transforms/ScalarizeMaskedMemIntrin/AMDGPU/expand-masked-gather.ll b/llvm/test/Transforms/ScalarizeMaskedMemIntrin/AMDGPU/expand-masked-gather.ll
index 94d0e2943d9366..8c4408bfa527b7 100644
--- a/llvm/test/Transforms/ScalarizeMaskedMemIntrin/AMDGPU/expand-masked-gather.ll
+++ b/llvm/test/Transforms/ScalarizeMaskedMemIntrin/AMDGPU/expand-masked-gather.ll
@@ -8,10 +8,8 @@
 define <2 x i32> @scalarize_v2i32(<2 x ptr> %p, <2 x i1> %mask, <2 x i32> %passthru) {
 ; CHECK-LABEL: define <2 x i32> @scalarize_v2i32(
 ; CHECK-SAME: <2 x ptr> [[P:%.*]], <2 x i1> [[MASK:%.*]], <2 x i32> [[PASSTHRU:%.*]]) {
-; CHECK-NEXT:    [[SCALAR_MASK:%.*]] = bitcast <2 x i1> [[MASK]] to i2
-; CHECK-NEXT:    [[TMP1:%.*]] = and i2 [[SCALAR_MASK]], 1
-; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne i2 [[TMP1]], 0
-; CHECK-NEXT:    br i1 [[TMP2]], label %[[COND_LOAD:.*]], label %[[ELSE:.*]]
+; CHECK-NEXT:    [[MASK0:%.*]] = extractelement <2 x i1> [[MASK]], i64 0
+; CHECK-NEXT:    br i1 [[MASK0]], label %[[COND_LOAD:.*]], label %[[ELSE:.*]]
 ; CHECK:       [[COND_LOAD]]:
 ; CHECK-NEXT:    [[PTR0:%.*]] = extractelement <2 x ptr> [[P]], i64 0
 ; CHECK-NEXT:    [[LOAD0:%.*]] = load i32, ptr [[PTR0]], align 8
@@ -19,9 +17,8 @@ define <2 x i32> @scalarize_v2i32(<2 x ptr> %p, <2 x i1> %mask, <2 x i32> %passt
 ; CHECK-NEXT:    br label %[[ELSE]]
 ; CHECK:       [[ELSE]]:
 ; CHECK-NEXT:    [[RES_PHI_ELSE:%.*]] = phi <2 x i32> [ [[RES0]], %[[COND_LOAD]] ], [ [[PASSTHRU]], [[TMP0:%.*]] ]
-; CHECK-NEXT:    [[TMP3:%.*]] = and i2 [[SCALAR_MASK]], -2
-; CHECK-NEXT:    [[TMP4:%.*]] = icmp ne i2 [[TMP3]], 0
-; CHECK-NEXT:    br i1 [[TMP4]], label %[[COND_LOAD1:.*]], label %[[ELSE2:.*]]
+; CHECK-NEXT:    [[MASK1:%.*]] = extractelement <2 x i1> [[MASK]], i64 1
+; CHECK-NEXT:    br i1 [[MASK1]], label %[[COND_LOAD1:.*]], label %[[ELSE2:.*]]
 ; CHECK:       [[COND_LOAD1]]:
 ; CHECK-NEXT:    [[PTR1:%.*]] = extractelement <2 x ptr> [[P]], i64 1
 ; CHECK-NEXT:    [[LOAD1:%.*]] = load i32, ptr [[PTR1]], align 8
diff --git a/llvm/test/Transforms/ScalarizeMaskedMemIntrin/AMDGPU/expand-masked-scatter.ll b/llvm/test/Transforms/ScalarizeMaskedMemIntrin/AMDGPU/expand-masked-scatter.ll
index 45debf35d06e4f..448ee18e9b4dec 100644
--- a/llvm/test/Transforms/ScalarizeMaskedMemIntrin/AMDGPU/expand-masked-scatter.ll
+++ b/llvm/test/Transforms/ScalarizeMaskedMemIntrin/AMDGPU/expand-masked-scatter.ll
@@ -8,19 +8,16 @@
 define void @scalarize_v2i32(<2 x ptr> %p, <2 x i1> %mask, <2 x i32> %value) {
 ; CHECK-LABEL: define void @scalarize_v2i32(
 ; CHECK-SAME: <2 x ptr> [[P:%.*]], <2 x i1> [[MASK:%.*]], <2 x i32> [[VALUE:%.*]]) {
-; CHECK-NEXT:    [[SCALAR_MASK:%.*]] = bitcast <2 x i1> [[MASK]] to i2
-; CHECK-NEXT:    [[TMP1:%.*]] = and i2 [[SCALAR_MASK]], 1
-; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne i2 [[TMP1]], 0
-; CHECK-NEXT:    br i1 [[TMP2]], label %[[COND_STORE:.*]], label %[[ELSE:.*]]
+; CHECK-NEXT:    [[MASK0:%.*]] = extractelement <2 x i1> [[MASK]], i64 0
+; CHECK-NEXT:    br i1 [[MASK0]], label %[[COND_STORE:.*]], label %[[ELSE:.*]]
 ; CHECK:       [[COND_STORE]]:
 ; CHECK-NEXT:    [[ELT0:%.*]] = extractelement <2 x i32> [[VALUE]], i64 0
 ; CHECK-NEXT:    [[PTR0:%.*]] = extractelement <2 x ptr> [[P]], i64 0
 ; CHECK-NEXT:    store i32 [[ELT0]], ptr [[PTR0]], align 8
 ; CHECK-NEXT:    br label %[[ELSE]]
 ; CHECK:       [[ELSE]]:
-; CHECK-NEXT:    [[TMP3:%.*]] = and i2 [[SCALAR_MASK]], -2
-; CHECK-NEXT:    [[TMP4:%.*]] = icmp ne i2 [[TMP3]], 0
-; CHECK-NEXT:    br i1 [[TMP4]], label %[[COND_STORE1:.*]], label %[[ELSE2:.*]]
+; CHECK-NEXT:    [[MASK1:%.*]] = extractelement <2 x i1> [[MASK]], i64 1
+; CHECK-NEXT:    br i1 [[MASK1]], label %[[COND_STORE1:.*]], label %[[ELSE2:.*]]
 ; CHECK:       [[COND_STORE1]]:
 ; CHECK-NEXT:    [[ELT1:%.*]] = extractelement <2 x i32> [[VALUE]], i64 1
 ; CHECK-NEXT:    [[PTR1:%.*]] = extractelement <2 x ptr> [[P]], i64 1
diff --git a/llvm/test/Transforms/ScalarizeMaskedMemIntrin/AMDGPU/expand-masked-store.ll b/llvm/test/Transforms/ScalarizeMaskedMemIntrin/AMDGPU/expand-masked-store.ll
index 1efd008b77e1c0..2eb86f20374d87 100644
--- a/llvm/test/Transforms/ScalarizeMaskedMemIntrin/AMDGPU/expand-masked-store.ll
+++ b/llvm/test/Transforms/ScalarizeMaskedMemIntrin/AMDGPU/expand-masked-store.ll
@@ -8,19 +8,16 @@
 define void @scalarize_v2i32(ptr %p, <2 x i1> %mask, <2 x i32> %data) {
 ; CHECK-LABEL: define void @scalarize_v2i32(
 ; CHECK-SAME: ptr [[P:%.*]], <2 x i1> [[MASK:%.*]], <2 x i32> [[DATA:%.*]]) {
-; CHECK-NEXT:    [[SCALAR_MASK:%.*]] = bitcast <2 x i1> [[MASK]] to i2
-; CHECK-NEXT:    [[TMP1:%.*]] = and i2 [[SCALAR_MASK]], 1
-; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne i2 [[TMP1]], 0
-; CHECK-NEXT:    br i1 [[TMP2]], label %[[COND_STORE:.*]], label %[[ELSE:.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <2 x i1> [[MASK]], i64 0
+; CHECK-NEXT:    br i1 [[TMP1]], label %[[COND_STORE:.*]], label %[[ELSE:.*]]
 ; CHECK:       [[COND_STORE]]:
 ; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <2 x i32> [[DATA]], i64 0
 ; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds i32, ptr [[P]], i32 0
 ; CHECK-NEXT:    store i32 [[TMP3]], ptr [[TMP4]], align 4
 ; CHECK-NEXT:    br label %[[ELSE]]
 ; CHECK:       [[ELSE]]:
-; CHECK-NEXT:    [[TMP5:%.*]] = and i2 [[SCALAR_MASK]], -2
-; CHECK-NEXT:    [[TMP6:%.*]] = icmp ne i2 [[TMP5]], 0
-; CHECK-NEXT:    br i1 [[TMP6]], label %[[COND_STORE1:.*]], label %[[ELSE2:.*]]
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <2 x i1> [[MASK]], i64 1
+; CHECK-NEXT:    br i1 [[TMP5]], label %[[COND_STORE1:.*]], label %[[ELSE2:.*]]
 ; CHECK:       [[COND_STORE1]]:
 ; CHECK-NEXT:    [[TMP7:%.*]] = extractelement <2 x i32> [[DATA]], i64 1
 ; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds i32, ptr [[P]], i32 1
@@ -55,19 +52,16 @@ define void @scalarize_v2i32_splat_mask(ptr %p, <2 x i32> %data, i1 %mask) {
 define void @scalarize_v2f16(ptr %p, <2 x i1> %mask, <2 x half> %data) {
 ; CHECK-LABEL: define void @scalarize_v2f16(
 ; CHECK-SAME: ptr [[P:%.*]], <2 x i1> [[MASK:%.*]], <2 x half> [[DATA:%.*]]) {
-; CHECK-NEXT:    [[SCALAR_MASK:%.*]] = bitcast <2 x i1> [[MASK]] to i2
-; CHECK-NEXT:    [[TMP1:%.*]] = and i2 [[SCALAR_MASK]], 1
-; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne i2 [[TMP1]], 0
-; CHECK-NEXT:    br i1 [[TMP2]], label %[[COND_STORE:.*]], label %[[ELSE:.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <2 x i1> [[MASK]], i64 0
+; CHECK-NEXT:    br i1 [[TMP1]], label %[[COND_STORE:.*]], label %[[ELSE:.*]]
 ; CHECK:       [[COND_STORE]]:
 ; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <2 x half> [[DATA]], i64 0
 ; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds half, ptr [[P]], i32 0
 ; CHECK-NEXT:    store half [[TMP3]], ptr [[TMP4]], align 2
 ; CHECK-NEXT:    br label %[[ELSE]]
 ; CHECK:       [[ELSE]]:
-; CHECK-NEXT:    [[TMP5:%.*]] = and i2 [[SCALAR_MASK]], -2
-; CHECK-NEXT:    [[TMP6:%.*]] = icmp ne i2 [[TMP5]], 0
-; CHECK-NEXT:    br i1 [[TMP6]], label %[[COND_STORE1:.*]], label %[[ELSE2:.*]]
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <2 x i1> [[MASK]], i64 1
+; CHECK-NEXT:    br i1 [[TMP5]], label %[[COND_STORE1:.*]], label %[[ELSE2:.*]]
 ; CHECK:       [[COND_STORE1]]:
 ; CHECK-NEXT:    [[TMP7:%.*]] = extractelement <2 x half> [[DATA]], i64 1
 ; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds half, ptr [[P]], i32 1
@@ -83,19 +77,16 @@ define void @scalarize_v2f16(ptr %p, <2 x i1> %mask, <2 x half> %data) {
 define void @scalarize_v2i32_p3(ptr addrspace(3) %p, <2 x i1> %mask, <2 x i32> %data) {
 ; CHECK-LABEL: define void @scalarize_v2i32_p3(
 ; CHECK-SAME: ptr addrspace(3) [[P:%.*]], <2 x i1> [[MASK:%.*]], <2 x i32> [[DATA:%.*]]) {
-; CHECK-NEXT:    [[SCALAR_MASK:%.*]] = bitcast <2 x i1> [[MASK]] to i2
-; CHECK-NEXT:    [[TMP1:%.*]] = and i2 [[SCALAR_MASK]], 1
-; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne i2 [[TMP1]], 0
-; CHECK-NEXT:    br i1 [[TMP2]], label %[[COND_STORE:.*]], label %[[ELSE:.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <2 x i1> [[MASK]], i64 0
+; CHECK-NEXT:    br i1 [[TMP1]], label %[[COND_STORE:.*]], label %[[ELSE:.*]]
 ; CHECK:       [[COND_STORE]]:
 ; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <2 x i32> [[DATA]], i64 0
 ; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds i32, ptr addrspace(3) [[P]], i32 0
 ; CHECK-NEXT:    store i32 [[TMP3]], ptr addrspace(3) [[TMP4]], align 4
 ; CHECK-NEXT:    br label %[[ELSE]]
 ; CHECK:       [[ELSE]]:
-; CHECK-NEXT:    [[TMP5:%.*]] = and i2 [[SCALAR_MASK]], -2
-; CHECK-NEXT:    [[TMP6:%.*]] = icmp ne i2 [[TMP5]], 0
-; CHECK-NEXT:    br i1 [[TMP6]], label %[[COND_STORE1:.*]], label %[[ELSE2:.*]]
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <2 x i1> [[MASK]], i64 1
+; CHECK-NEXT:    br i1 [[TMP5]], label %[[COND_STORE1:.*]], label %[[ELSE2:.*]]
 ; CHECK:       [[COND_STORE1]]:
 ; CHECK-NEXT:    [[TMP7:%.*]] = extractelement <2 x i32> [[DATA]], i64 1
 ; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds i32, ptr addrspace(3) [[P]], i32 1



More information about the llvm-commits mailing list