[llvm] [ScalarizeMaskedMemIntr] Optimize splat non-constant masks (PR #104537)

Krzysztof Drewniak via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 15 18:53:34 PDT 2024


https://github.com/krzysz00 created https://github.com/llvm/llvm-project/pull/104537

In cases (like the ones added in the tests) where the condition of a masked load or store is a splat but not a constant (that is, a masked operation is being used to implement patterns like "load if the current lane is in-bounds, otherwise return 0"), optimize the 'scalarized' code to perform an aligned vector load/store if the splatted value is true.

Additionally, take a few steps to preserve aliasing information and names when nothing is scalarized while I'm here.

As motivation, some LLVM IR users will generate masked load/store in cases that map to this kind of predicated operation (where either the vector is loaded/stored or it isn't) in order to take advantage of hardware primitives, but on AMDGPU, where we don't have a masked load or store, this pass would scalarize a load or store that was intended to be - and can be - vectorized while also introducing expensive branches.

Fixes #104520

Pre-commit tests at #104527

>From 04ae0cb77ccbaea3ab49098e09f61321ee698ac7 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Fri, 16 Aug 2024 01:38:18 +0000
Subject: [PATCH] [ScalarizeMaskedMemIntr] Optimize splat non-constant masks

In cases (like the ones added in the tests) where the condition of a
masked load or store is a splat but not a constant (that is, a masked
operation is being used to implement patterns like "load if the
current lane is in-bounds, otherwise return 0"), optimize the
'scalarized' code to perform an aligned vector load/store if the
splatted value is true.

Additionally, take a few steps to preserve aliasing information and
names when nothing is scalarized while I'm here.

As motivation, some LLVM IR users will generate masked load/store in
cases that map to this kind of predicated operation (where either the
vector is loaded/stored or it isn't) in order to take advantage of
hardware primitives, but on AMDGPU, where we don't have a masked load
or store, this pass would scalarize a load or store that was intended
to be - and can be - vectorized while also introducing expensive branches.

Fixes #104520

Pre-commit tests at #104527
---
 .../Scalar/ScalarizeMaskedMemIntrin.cpp       | 64 ++++++++++++++++++-
 .../X86/expand-masked-load.ll                 | 34 +++-------
 .../X86/expand-masked-store.ll                | 25 ++------
 3 files changed, 75 insertions(+), 48 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
index 8eadf8900020d9..9cb7bad94c20bc 100644
--- a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
+++ b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
@@ -17,6 +17,7 @@
 #include "llvm/ADT/Twine.h"
 #include "llvm/Analysis/DomTreeUpdater.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Constant.h"
 #include "llvm/IR/Constants.h"
@@ -161,7 +162,9 @@ static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,
 
   // Short-cut if the mask is all-true.
   if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
-    Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
+    LoadInst *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
+    NewI->copyMetadata(*CI);
+    NewI->takeName(CI);
     CI->replaceAllUsesWith(NewI);
     CI->eraseFromParent();
     return;
@@ -188,8 +191,39 @@ static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,
     return;
   }
 
+  // Optimize the case where the "masked load" is a predicated load - that is,
+  // where the mask is the splat of a non-constant scalar boolean. In that case,
+  // use that splatted value as the guard on a conditional vector load.
+  if (isSplatValue(Mask, /*Index=*/0)) {
+    Value *Predicate = Builder.CreateExtractElement(Mask, uint64_t(0ull),
+                                                    Mask->getName() + ".first");
+    Instruction *ThenTerm =
+        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
+                                  /*BranchWeights=*/nullptr, DTU);
+
+    BasicBlock *CondBlock = ThenTerm->getParent();
+    CondBlock->setName("cond.load");
+    Builder.SetInsertPoint(CondBlock->getTerminator());
+    LoadInst *Load = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal,
+                                               CI->getName() + ".cond.load");
+    Load->copyMetadata(*CI);
+
+    BasicBlock *PostLoad = ThenTerm->getSuccessor(0);
+    Builder.SetInsertPoint(PostLoad, PostLoad->begin());
+    PHINode *Phi = Builder.CreatePHI(VecType, /*NumReservedValues=*/2);
+    Phi->addIncoming(Load, CondBlock);
+    Phi->addIncoming(Src0, IfBlock);
+    Phi->takeName(CI);
+
+    CI->replaceAllUsesWith(Phi);
+    CI->eraseFromParent();
+    ModifiedDT = true;
+    return;
+  }
   // If the mask is not v1i1, use scalar bit test operations. This generates
   // better results on X86 at least.
+  // Note: this produces worse code on AMDGPU, where the "i1" is implicitly SIMD
+  // - what's a good way to detect this?
   Value *SclrMask;
   if (VectorWidth != 1) {
     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
@@ -297,7 +331,9 @@ static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
 
   // Short-cut if the mask is all-true.
   if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
-    Builder.CreateAlignedStore(Src, Ptr, AlignVal);
+    StoreInst *Store = Builder.CreateAlignedStore(Src, Ptr, AlignVal);
+    Store->takeName(CI);
+    Store->copyMetadata(*CI);
     CI->eraseFromParent();
     return;
   }
@@ -319,8 +355,31 @@ static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
     return;
   }
 
+  // Optimize the case where the "masked store" is a predicated store - that is,
+  // when the mask is the splat of a non-constant scalar boolean. In that case,
+  // optimize to a conditional store.
+  if (isSplatValue(Mask, /*Index=*/0)) {
+    Value *Predicate = Builder.CreateExtractElement(Mask, uint64_t(0ull),
+                                                    Mask->getName() + ".first");
+    Instruction *ThenTerm =
+        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
+                                  /*BranchWeights=*/nullptr, DTU);
+    BasicBlock *CondBlock = ThenTerm->getParent();
+    CondBlock->setName("cond.store");
+    Builder.SetInsertPoint(CondBlock->getTerminator());
+
+    StoreInst *Store = Builder.CreateAlignedStore(Src, Ptr, AlignVal);
+    Store->takeName(CI);
+    Store->copyMetadata(*CI);
+
+    CI->eraseFromParent();
+    ModifiedDT = true;
+    return;
+  }
+
   // If the mask is not v1i1, use scalar bit test operations. This generates
   // better results on X86 at least.
+
   Value *SclrMask;
   if (VectorWidth != 1) {
     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
@@ -997,7 +1056,6 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
         any_of(II->args(),
                [](Value *V) { return isa<ScalableVectorType>(V->getType()); }))
       return false;
-
     switch (II->getIntrinsicID()) {
     default:
       break;
diff --git a/llvm/test/Transforms/ScalarizeMaskedMemIntrin/X86/expand-masked-load.ll b/llvm/test/Transforms/ScalarizeMaskedMemIntrin/X86/expand-masked-load.ll
index 9b1c59829b9ffb..fffb5f021e52d4 100644
--- a/llvm/test/Transforms/ScalarizeMaskedMemIntrin/X86/expand-masked-load.ll
+++ b/llvm/test/Transforms/ScalarizeMaskedMemIntrin/X86/expand-masked-load.ll
@@ -32,8 +32,8 @@ define <2 x i64> @scalarize_v2i64(ptr %p, <2 x i1> %mask, <2 x i64> %passthru) {
 
 define <2 x i64> @scalarize_v2i64_ones_mask(ptr %p, <2 x i64> %passthru) {
 ; CHECK-LABEL: @scalarize_v2i64_ones_mask(
-; CHECK-NEXT:    [[TMP1:%.*]] = load <2 x i64>, ptr [[P:%.*]], align 8
-; CHECK-NEXT:    ret <2 x i64> [[TMP1]]
+; CHECK-NEXT:    [[RET:%.*]] = load <2 x i64>, ptr [[P:%.*]], align 8
+; CHECK-NEXT:    ret <2 x i64> [[RET]]
 ;
   %ret = call <2 x i64> @llvm.masked.load.v2i64.p0(ptr %p, i32 8, <2 x i1> <i1 true, i1 true>, <2 x i64> %passthru)
   ret <2 x i64> %ret
@@ -58,34 +58,18 @@ define <2 x i64> @scalarize_v2i64_const_mask(ptr %p, <2 x i64> %passthru) {
   ret <2 x i64> %ret
 }
 
-; To be fixed: If the mask is the splat/broadcast of a non-constant value, use a
-; vector load
 define <2 x i64> @scalarize_v2i64_splat_mask(ptr %p, i1 %mask, <2 x i64> %passthrough) {
 ; CHECK-LABEL: @scalarize_v2i64_splat_mask(
 ; CHECK-NEXT:    [[MASK_VEC:%.*]] = insertelement <2 x i1> poison, i1 [[MASK:%.*]], i32 0
 ; CHECK-NEXT:    [[MASK_SPLAT:%.*]] = shufflevector <2 x i1> [[MASK_VEC]], <2 x i1> poison, <2 x i32> zeroinitializer
-; CHECK-NEXT:    [[SCALAR_MASK:%.*]] = bitcast <2 x i1> [[MASK_SPLAT]] to i2
-; CHECK-NEXT:    [[TMP1:%.*]] = and i2 [[SCALAR_MASK]], 1
-; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne i2 [[TMP1]], 0
-; CHECK-NEXT:    br i1 [[TMP2]], label [[COND_LOAD:%.*]], label [[ELSE:%.*]]
+; CHECK-NEXT:    [[MASK_SPLAT_FIRST:%.*]] = extractelement <2 x i1> [[MASK_SPLAT]], i64 0
+; CHECK-NEXT:    br i1 [[MASK_SPLAT_FIRST]], label [[COND_LOAD:%.*]], label [[TMP1:%.*]]
 ; CHECK:       cond.load:
-; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i64, ptr [[P:%.*]], i32 0
-; CHECK-NEXT:    [[TMP4:%.*]] = load i64, ptr [[TMP3]], align 8
-; CHECK-NEXT:    [[TMP5:%.*]] = insertelement <2 x i64> [[PASSTHROUGH:%.*]], i64 [[TMP4]], i64 0
-; CHECK-NEXT:    br label [[ELSE]]
-; CHECK:       else:
-; CHECK-NEXT:    [[RES_PHI_ELSE:%.*]] = phi <2 x i64> [ [[TMP5]], [[COND_LOAD]] ], [ [[PASSTHROUGH]], [[TMP0:%.*]] ]
-; CHECK-NEXT:    [[TMP6:%.*]] = and i2 [[SCALAR_MASK]], -2
-; CHECK-NEXT:    [[TMP7:%.*]] = icmp ne i2 [[TMP6]], 0
-; CHECK-NEXT:    br i1 [[TMP7]], label [[COND_LOAD1:%.*]], label [[ELSE2:%.*]]
-; CHECK:       cond.load1:
-; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds i64, ptr [[P]], i32 1
-; CHECK-NEXT:    [[TMP9:%.*]] = load i64, ptr [[TMP8]], align 8
-; CHECK-NEXT:    [[TMP10:%.*]] = insertelement <2 x i64> [[RES_PHI_ELSE]], i64 [[TMP9]], i64 1
-; CHECK-NEXT:    br label [[ELSE2]]
-; CHECK:       else2:
-; CHECK-NEXT:    [[RES_PHI_ELSE3:%.*]] = phi <2 x i64> [ [[TMP10]], [[COND_LOAD1]] ], [ [[RES_PHI_ELSE]], [[ELSE]] ]
-; CHECK-NEXT:    ret <2 x i64> [[RES_PHI_ELSE3]]
+; CHECK-NEXT:    [[RET_COND_LOAD:%.*]] = load <2 x i64>, ptr [[P:%.*]], align 8
+; CHECK-NEXT:    br label [[TMP1]]
+; CHECK:       1:
+; CHECK-NEXT:    [[RET:%.*]] = phi <2 x i64> [ [[RET_COND_LOAD]], [[COND_LOAD]] ], [ [[PASSTHROUGH:%.*]], [[TMP0:%.*]] ]
+; CHECK-NEXT:    ret <2 x i64> [[RET]]
 ;
   %mask.vec = insertelement <2 x i1> poison, i1 %mask, i32 0
   %mask.splat = shufflevector <2 x i1> %mask.vec, <2 x i1> poison, <2 x i32> zeroinitializer
diff --git a/llvm/test/Transforms/ScalarizeMaskedMemIntrin/X86/expand-masked-store.ll b/llvm/test/Transforms/ScalarizeMaskedMemIntrin/X86/expand-masked-store.ll
index cd2815e67e6720..4e3679dc5da99e 100644
--- a/llvm/test/Transforms/ScalarizeMaskedMemIntrin/X86/expand-masked-store.ll
+++ b/llvm/test/Transforms/ScalarizeMaskedMemIntrin/X86/expand-masked-store.ll
@@ -56,31 +56,16 @@ define void @scalarize_v2i64_const_mask(ptr %p, <2 x i64> %data) {
   ret void
 }
 
-; To be fixed: If the mask is the splat/broadcast of a non-constant value, use a
-; vector store
 define void @scalarize_v2i64_splat_mask(ptr %p, <2 x i64> %data, i1 %mask) {
 ; CHECK-LABEL: @scalarize_v2i64_splat_mask(
 ; CHECK-NEXT:    [[MASK_VEC:%.*]] = insertelement <2 x i1> poison, i1 [[MASK:%.*]], i32 0
 ; CHECK-NEXT:    [[MASK_SPLAT:%.*]] = shufflevector <2 x i1> [[MASK_VEC]], <2 x i1> poison, <2 x i32> zeroinitializer
-; CHECK-NEXT:    [[SCALAR_MASK:%.*]] = bitcast <2 x i1> [[MASK_SPLAT]] to i2
-; CHECK-NEXT:    [[TMP1:%.*]] = and i2 [[SCALAR_MASK]], 1
-; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne i2 [[TMP1]], 0
-; CHECK-NEXT:    br i1 [[TMP2]], label [[COND_STORE:%.*]], label [[ELSE:%.*]]
+; CHECK-NEXT:    [[MASK_SPLAT_FIRST:%.*]] = extractelement <2 x i1> [[MASK_SPLAT]], i64 0
+; CHECK-NEXT:    br i1 [[MASK_SPLAT_FIRST]], label [[COND_STORE:%.*]], label [[TMP1:%.*]]
 ; CHECK:       cond.store:
-; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <2 x i64> [[DATA:%.*]], i64 0
-; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds i64, ptr [[P:%.*]], i32 0
-; CHECK-NEXT:    store i64 [[TMP3]], ptr [[TMP4]], align 8
-; CHECK-NEXT:    br label [[ELSE]]
-; CHECK:       else:
-; CHECK-NEXT:    [[TMP5:%.*]] = and i2 [[SCALAR_MASK]], -2
-; CHECK-NEXT:    [[TMP6:%.*]] = icmp ne i2 [[TMP5]], 0
-; CHECK-NEXT:    br i1 [[TMP6]], label [[COND_STORE1:%.*]], label [[ELSE2:%.*]]
-; CHECK:       cond.store1:
-; CHECK-NEXT:    [[TMP7:%.*]] = extractelement <2 x i64> [[DATA]], i64 1
-; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds i64, ptr [[P]], i32 1
-; CHECK-NEXT:    store i64 [[TMP7]], ptr [[TMP8]], align 8
-; CHECK-NEXT:    br label [[ELSE2]]
-; CHECK:       else2:
+; CHECK-NEXT:    store <2 x i64> [[DATA:%.*]], ptr [[P:%.*]], align 8
+; CHECK-NEXT:    br label [[TMP1]]
+; CHECK:       1:
 ; CHECK-NEXT:    ret void
 ;
   %mask.vec = insertelement <2 x i1> poison, i1 %mask, i32 0



More information about the llvm-commits mailing list