[llvm] [ScalarizeMaskedMemIntr] Optimize splat non-constant masks (PR #104537)
Krzysztof Drewniak via llvm-commits
llvm-commits at lists.llvm.org
Thu Aug 15 18:53:34 PDT 2024
https://github.com/krzysz00 created https://github.com/llvm/llvm-project/pull/104537
In cases (like the ones added in the tests) where the condition of a masked load or store is a splat but not a constant (that is, a masked operation is being used to implement patterns like "load if the current lane is in-bounds, otherwise return 0"), optimize the 'scalarized' code to perform an aligned vector load/store when the splatted value is true.
Additionally, while here, take a few steps to preserve aliasing information and value names when nothing is scalarized.
As motivation, some producers of LLVM IR will generate masked loads/stores for cases that map to this kind of predicated operation (where either the whole vector is loaded/stored or it isn't) in order to take advantage of hardware primitives. However, on AMDGPU, which does not have masked load or store instructions, this pass would scalarize a load or store that was intended to be - and can be - vectorized, while also introducing expensive branches.
Fixes #104520
Pre-commit tests at #104527
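For illustration, the load side of the new rewrite produces IR of roughly the following shape, sketched here as a standalone function based on the CHECK lines in the updated expand-masked-load.ll test (the function name and the "exit" block name are invented for this example; the pass itself leaves the join block unnamed):

define <2 x i64> @splat_mask_load_after(ptr %p, i1 %mask, <2 x i64> %passthru) {
entry:
  %mask.vec = insertelement <2 x i1> poison, i1 %mask, i32 0
  %mask.splat = shufflevector <2 x i1> %mask.vec, <2 x i1> poison, <2 x i32> zeroinitializer
  ; extract the splatted lane and branch on it instead of testing each lane
  %mask.splat.first = extractelement <2 x i1> %mask.splat, i64 0
  br i1 %mask.splat.first, label %cond.load, label %exit

cond.load:
  ; one aligned vector load replaces the per-lane branch-and-insertelement chain
  %ret.cond.load = load <2 x i64>, ptr %p, align 8
  br label %exit

exit:
  ; select between the loaded vector and the passthru operand
  %ret = phi <2 x i64> [ %ret.cond.load, %cond.load ], [ %passthru, %entry ]
  ret <2 x i64> %ret
}

The store side is analogous: the extracted first lane guards a single aligned vector store of the data operand, and no phi is needed afterwards.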
From 04ae0cb77ccbaea3ab49098e09f61321ee698ac7 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Fri, 16 Aug 2024 01:38:18 +0000
Subject: [PATCH] [ScalarizeMaskedMemIntr] Optimize splat non-constant masks
In cases (like the ones added in the tests) where the condition of a
masked load or store is a splat but not a constant (that is, a masked
operation is being used to implement patterns like "load if the
current lane is in-bounds, otherwise return 0"), optimize the
'scalarized' code to perform an aligned vector load/store when the
splatted value is true.
Additionally, while here, take a few steps to preserve aliasing
information and value names when nothing is scalarized.
As motivation, some producers of LLVM IR will generate masked
loads/stores for cases that map to this kind of predicated operation
(where either the whole vector is loaded/stored or it isn't) in order
to take advantage of hardware primitives. However, on AMDGPU, which
does not have masked load or store instructions, this pass would
scalarize a load or store that was intended to be - and can be -
vectorized, while also introducing expensive branches.
Fixes #104520
Pre-commit tests at #104527
---
.../Scalar/ScalarizeMaskedMemIntrin.cpp | 64 ++++++++++++++++++-
.../X86/expand-masked-load.ll | 34 +++-------
.../X86/expand-masked-store.ll | 25 ++------
3 files changed, 75 insertions(+), 48 deletions(-)
diff --git a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
index 8eadf8900020d9..9cb7bad94c20bc 100644
--- a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
+++ b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
@@ -17,6 +17,7 @@
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
@@ -161,7 +162,9 @@ static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,
// Short-cut if the mask is all-true.
if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
- Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
+ LoadInst *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
+ NewI->copyMetadata(*CI);
+ NewI->takeName(CI);
CI->replaceAllUsesWith(NewI);
CI->eraseFromParent();
return;
@@ -188,8 +191,39 @@ static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,
return;
}
+ // Optimize the case where the "masked load" is a predicated load - that is,
+ // where the mask is the splat of a non-constant scalar boolean. In that case,
+ // use that splatted value as the guard on a conditional vector load.
+ if (isSplatValue(Mask, /*Index=*/0)) {
+ Value *Predicate = Builder.CreateExtractElement(Mask, uint64_t(0ull),
+ Mask->getName() + ".first");
+ Instruction *ThenTerm =
+ SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
+ /*BranchWeights=*/nullptr, DTU);
+
+ BasicBlock *CondBlock = ThenTerm->getParent();
+ CondBlock->setName("cond.load");
+ Builder.SetInsertPoint(CondBlock->getTerminator());
+ LoadInst *Load = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal,
+ CI->getName() + ".cond.load");
+ Load->copyMetadata(*CI);
+
+ BasicBlock *PostLoad = ThenTerm->getSuccessor(0);
+ Builder.SetInsertPoint(PostLoad, PostLoad->begin());
+ PHINode *Phi = Builder.CreatePHI(VecType, /*NumReservedValues=*/2);
+ Phi->addIncoming(Load, CondBlock);
+ Phi->addIncoming(Src0, IfBlock);
+ Phi->takeName(CI);
+
+ CI->replaceAllUsesWith(Phi);
+ CI->eraseFromParent();
+ ModifiedDT = true;
+ return;
+ }
// If the mask is not v1i1, use scalar bit test operations. This generates
// better results on X86 at least.
+ // Note: this produces worse code on AMDGPU, where the "i1" is implicitly SIMD
+ // - what's a good way to detect this?
Value *SclrMask;
if (VectorWidth != 1) {
Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
@@ -297,7 +331,9 @@ static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
// Short-cut if the mask is all-true.
if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
- Builder.CreateAlignedStore(Src, Ptr, AlignVal);
+ StoreInst *Store = Builder.CreateAlignedStore(Src, Ptr, AlignVal);
+ Store->takeName(CI);
+ Store->copyMetadata(*CI);
CI->eraseFromParent();
return;
}
@@ -319,8 +355,31 @@ static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
return;
}
+ // Optimize the case where the "masked store" is a predicated store - that is,
+ // when the mask is the splat of a non-constant scalar boolean. In that case,
+ // optimize to a conditional store.
+ if (isSplatValue(Mask, /*Index=*/0)) {
+ Value *Predicate = Builder.CreateExtractElement(Mask, uint64_t(0ull),
+ Mask->getName() + ".first");
+ Instruction *ThenTerm =
+ SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
+ /*BranchWeights=*/nullptr, DTU);
+ BasicBlock *CondBlock = ThenTerm->getParent();
+ CondBlock->setName("cond.store");
+ Builder.SetInsertPoint(CondBlock->getTerminator());
+
+ StoreInst *Store = Builder.CreateAlignedStore(Src, Ptr, AlignVal);
+ Store->takeName(CI);
+ Store->copyMetadata(*CI);
+
+ CI->eraseFromParent();
+ ModifiedDT = true;
+ return;
+ }
+
// If the mask is not v1i1, use scalar bit test operations. This generates
// better results on X86 at least.
+
Value *SclrMask;
if (VectorWidth != 1) {
Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
@@ -997,7 +1056,6 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
any_of(II->args(),
[](Value *V) { return isa<ScalableVectorType>(V->getType()); }))
return false;
-
switch (II->getIntrinsicID()) {
default:
break;
diff --git a/llvm/test/Transforms/ScalarizeMaskedMemIntrin/X86/expand-masked-load.ll b/llvm/test/Transforms/ScalarizeMaskedMemIntrin/X86/expand-masked-load.ll
index 9b1c59829b9ffb..fffb5f021e52d4 100644
--- a/llvm/test/Transforms/ScalarizeMaskedMemIntrin/X86/expand-masked-load.ll
+++ b/llvm/test/Transforms/ScalarizeMaskedMemIntrin/X86/expand-masked-load.ll
@@ -32,8 +32,8 @@ define <2 x i64> @scalarize_v2i64(ptr %p, <2 x i1> %mask, <2 x i64> %passthru) {
define <2 x i64> @scalarize_v2i64_ones_mask(ptr %p, <2 x i64> %passthru) {
; CHECK-LABEL: @scalarize_v2i64_ones_mask(
-; CHECK-NEXT: [[TMP1:%.*]] = load <2 x i64>, ptr [[P:%.*]], align 8
-; CHECK-NEXT: ret <2 x i64> [[TMP1]]
+; CHECK-NEXT: [[RET:%.*]] = load <2 x i64>, ptr [[P:%.*]], align 8
+; CHECK-NEXT: ret <2 x i64> [[RET]]
;
%ret = call <2 x i64> @llvm.masked.load.v2i64.p0(ptr %p, i32 8, <2 x i1> <i1 true, i1 true>, <2 x i64> %passthru)
ret <2 x i64> %ret
@@ -58,34 +58,18 @@ define <2 x i64> @scalarize_v2i64_const_mask(ptr %p, <2 x i64> %passthru) {
ret <2 x i64> %ret
}
-; To be fixed: If the mask is the splat/broadcast of a non-constant value, use a
-; vector load
define <2 x i64> @scalarize_v2i64_splat_mask(ptr %p, i1 %mask, <2 x i64> %passthrough) {
; CHECK-LABEL: @scalarize_v2i64_splat_mask(
; CHECK-NEXT: [[MASK_VEC:%.*]] = insertelement <2 x i1> poison, i1 [[MASK:%.*]], i32 0
; CHECK-NEXT: [[MASK_SPLAT:%.*]] = shufflevector <2 x i1> [[MASK_VEC]], <2 x i1> poison, <2 x i32> zeroinitializer
-; CHECK-NEXT: [[SCALAR_MASK:%.*]] = bitcast <2 x i1> [[MASK_SPLAT]] to i2
-; CHECK-NEXT: [[TMP1:%.*]] = and i2 [[SCALAR_MASK]], 1
-; CHECK-NEXT: [[TMP2:%.*]] = icmp ne i2 [[TMP1]], 0
-; CHECK-NEXT: br i1 [[TMP2]], label [[COND_LOAD:%.*]], label [[ELSE:%.*]]
+; CHECK-NEXT: [[MASK_SPLAT_FIRST:%.*]] = extractelement <2 x i1> [[MASK_SPLAT]], i64 0
+; CHECK-NEXT: br i1 [[MASK_SPLAT_FIRST]], label [[COND_LOAD:%.*]], label [[TMP1:%.*]]
; CHECK: cond.load:
-; CHECK-NEXT: [[TMP3:%.*]] = getelementptr inbounds i64, ptr [[P:%.*]], i32 0
-; CHECK-NEXT: [[TMP4:%.*]] = load i64, ptr [[TMP3]], align 8
-; CHECK-NEXT: [[TMP5:%.*]] = insertelement <2 x i64> [[PASSTHROUGH:%.*]], i64 [[TMP4]], i64 0
-; CHECK-NEXT: br label [[ELSE]]
-; CHECK: else:
-; CHECK-NEXT: [[RES_PHI_ELSE:%.*]] = phi <2 x i64> [ [[TMP5]], [[COND_LOAD]] ], [ [[PASSTHROUGH]], [[TMP0:%.*]] ]
-; CHECK-NEXT: [[TMP6:%.*]] = and i2 [[SCALAR_MASK]], -2
-; CHECK-NEXT: [[TMP7:%.*]] = icmp ne i2 [[TMP6]], 0
-; CHECK-NEXT: br i1 [[TMP7]], label [[COND_LOAD1:%.*]], label [[ELSE2:%.*]]
-; CHECK: cond.load1:
-; CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds i64, ptr [[P]], i32 1
-; CHECK-NEXT: [[TMP9:%.*]] = load i64, ptr [[TMP8]], align 8
-; CHECK-NEXT: [[TMP10:%.*]] = insertelement <2 x i64> [[RES_PHI_ELSE]], i64 [[TMP9]], i64 1
-; CHECK-NEXT: br label [[ELSE2]]
-; CHECK: else2:
-; CHECK-NEXT: [[RES_PHI_ELSE3:%.*]] = phi <2 x i64> [ [[TMP10]], [[COND_LOAD1]] ], [ [[RES_PHI_ELSE]], [[ELSE]] ]
-; CHECK-NEXT: ret <2 x i64> [[RES_PHI_ELSE3]]
+; CHECK-NEXT: [[RET_COND_LOAD:%.*]] = load <2 x i64>, ptr [[P:%.*]], align 8
+; CHECK-NEXT: br label [[TMP1]]
+; CHECK: 1:
+; CHECK-NEXT: [[RET:%.*]] = phi <2 x i64> [ [[RET_COND_LOAD]], [[COND_LOAD]] ], [ [[PASSTHROUGH:%.*]], [[TMP0:%.*]] ]
+; CHECK-NEXT: ret <2 x i64> [[RET]]
;
%mask.vec = insertelement <2 x i1> poison, i1 %mask, i32 0
%mask.splat = shufflevector <2 x i1> %mask.vec, <2 x i1> poison, <2 x i32> zeroinitializer
diff --git a/llvm/test/Transforms/ScalarizeMaskedMemIntrin/X86/expand-masked-store.ll b/llvm/test/Transforms/ScalarizeMaskedMemIntrin/X86/expand-masked-store.ll
index cd2815e67e6720..4e3679dc5da99e 100644
--- a/llvm/test/Transforms/ScalarizeMaskedMemIntrin/X86/expand-masked-store.ll
+++ b/llvm/test/Transforms/ScalarizeMaskedMemIntrin/X86/expand-masked-store.ll
@@ -56,31 +56,16 @@ define void @scalarize_v2i64_const_mask(ptr %p, <2 x i64> %data) {
ret void
}
-; To be fixed: If the mask is the splat/broadcast of a non-constant value, use a
-; vector store
define void @scalarize_v2i64_splat_mask(ptr %p, <2 x i64> %data, i1 %mask) {
; CHECK-LABEL: @scalarize_v2i64_splat_mask(
; CHECK-NEXT: [[MASK_VEC:%.*]] = insertelement <2 x i1> poison, i1 [[MASK:%.*]], i32 0
; CHECK-NEXT: [[MASK_SPLAT:%.*]] = shufflevector <2 x i1> [[MASK_VEC]], <2 x i1> poison, <2 x i32> zeroinitializer
-; CHECK-NEXT: [[SCALAR_MASK:%.*]] = bitcast <2 x i1> [[MASK_SPLAT]] to i2
-; CHECK-NEXT: [[TMP1:%.*]] = and i2 [[SCALAR_MASK]], 1
-; CHECK-NEXT: [[TMP2:%.*]] = icmp ne i2 [[TMP1]], 0
-; CHECK-NEXT: br i1 [[TMP2]], label [[COND_STORE:%.*]], label [[ELSE:%.*]]
+; CHECK-NEXT: [[MASK_SPLAT_FIRST:%.*]] = extractelement <2 x i1> [[MASK_SPLAT]], i64 0
+; CHECK-NEXT: br i1 [[MASK_SPLAT_FIRST]], label [[COND_STORE:%.*]], label [[TMP1:%.*]]
; CHECK: cond.store:
-; CHECK-NEXT: [[TMP3:%.*]] = extractelement <2 x i64> [[DATA:%.*]], i64 0
-; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds i64, ptr [[P:%.*]], i32 0
-; CHECK-NEXT: store i64 [[TMP3]], ptr [[TMP4]], align 8
-; CHECK-NEXT: br label [[ELSE]]
-; CHECK: else:
-; CHECK-NEXT: [[TMP5:%.*]] = and i2 [[SCALAR_MASK]], -2
-; CHECK-NEXT: [[TMP6:%.*]] = icmp ne i2 [[TMP5]], 0
-; CHECK-NEXT: br i1 [[TMP6]], label [[COND_STORE1:%.*]], label [[ELSE2:%.*]]
-; CHECK: cond.store1:
-; CHECK-NEXT: [[TMP7:%.*]] = extractelement <2 x i64> [[DATA]], i64 1
-; CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds i64, ptr [[P]], i32 1
-; CHECK-NEXT: store i64 [[TMP7]], ptr [[TMP8]], align 8
-; CHECK-NEXT: br label [[ELSE2]]
-; CHECK: else2:
+; CHECK-NEXT: store <2 x i64> [[DATA:%.*]], ptr [[P:%.*]], align 8
+; CHECK-NEXT: br label [[TMP1]]
+; CHECK: 1:
; CHECK-NEXT: ret void
;
%mask.vec = insertelement <2 x i1> poison, i1 %mask, i32 0