[llvm] [X86][CodeGen] Add IsStore parameter to hasConditionalLoadStoreForType (PR #132153)

Phoebe Wang via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 19 23:31:56 PDT 2025


https://github.com/phoebewang updated https://github.com/llvm/llvm-project/pull/132153

>From f6cd49693df779e43a82cc8022f27a63f82461eb Mon Sep 17 00:00:00 2001
From: "Wang, Phoebe" <phoebe.wang at intel.com>
Date: Thu, 20 Mar 2025 14:24:33 +0800
Subject: [PATCH] [X86][CodeGen] Add IsStore parameter to
 hasConditionalLoadStoreForType

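Address the review comment at
https://github.com/llvm/llvm-project/pull/132032#issuecomment-2736936769
by adding an IsStore parameter to hasConditionalLoadStoreForType, so each
caller states whether it is asking about a conditional load or a
conditional store.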
---
 .../llvm/Analysis/TargetTransformInfo.h       |  8 ++---
 .../llvm/Analysis/TargetTransformInfoImpl.h   |  6 ++--
 llvm/lib/Analysis/TargetTransformInfo.cpp     |  5 +--
 .../SelectionDAG/SelectionDAGBuilder.cpp      |  8 ++---
 .../lib/Target/X86/X86TargetTransformInfo.cpp | 34 +++++++++++++------
 llvm/lib/Target/X86/X86TargetTransformInfo.h  |  2 +-
 llvm/lib/Transforms/Utils/SimplifyCFG.cpp     |  4 ++-
 7 files changed, 43 insertions(+), 24 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 7ae4913d7c037..99e21aca97631 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1162,7 +1162,7 @@ class TargetTransformInfo {
 
   /// \return true if the target supports load/store that enables fault
   /// suppression of memory operands when the source condition is false.
-  bool hasConditionalLoadStoreForType(Type *Ty = nullptr) const;
+  bool hasConditionalLoadStoreForType(Type *Ty, bool IsStore) const;
 
   /// \return the target-provided register class ID for the provided type,
   /// accounting for type promotion and other type-legalization techniques that
@@ -2107,7 +2107,7 @@ class TargetTransformInfo::Concept {
   virtual bool preferToKeepConstantsAttached(const Instruction &Inst,
                                              const Function &Fn) const = 0;
   virtual unsigned getNumberOfRegisters(unsigned ClassID) const = 0;
-  virtual bool hasConditionalLoadStoreForType(Type *Ty = nullptr) const = 0;
+  virtual bool hasConditionalLoadStoreForType(Type *Ty, bool IsStore) const = 0;
   virtual unsigned getRegisterClassForType(bool Vector,
                                            Type *Ty = nullptr) const = 0;
   virtual const char *getRegisterClassName(unsigned ClassID) const = 0;
@@ -2770,8 +2770,8 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
   unsigned getNumberOfRegisters(unsigned ClassID) const override {
     return Impl.getNumberOfRegisters(ClassID);
   }
-  bool hasConditionalLoadStoreForType(Type *Ty = nullptr) const override {
-    return Impl.hasConditionalLoadStoreForType(Ty);
+  bool hasConditionalLoadStoreForType(Type *Ty, bool IsStore) const override {
+    return Impl.hasConditionalLoadStoreForType(Ty, IsStore);
   }
   unsigned getRegisterClassForType(bool Vector,
                                    Type *Ty = nullptr) const override {
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 31ab919080744..745758426c714 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -504,11 +504,13 @@ class TargetTransformInfoImplBase {
   }
 
   unsigned getNumberOfRegisters(unsigned ClassID) const { return 8; }
-  bool hasConditionalLoadStoreForType(Type *Ty) const { return false; }
+  bool hasConditionalLoadStoreForType(Type *Ty, bool IsStore) const {
+    return false;
+  }
 
   unsigned getRegisterClassForType(bool Vector, Type *Ty = nullptr) const {
     return Vector ? 1 : 0;
-  };
+  }
 
   const char *getRegisterClassName(unsigned ClassID) const {
     switch (ClassID) {
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index bd1312d8c2d0b..b57624ba811d3 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -759,8 +759,9 @@ unsigned TargetTransformInfo::getNumberOfRegisters(unsigned ClassID) const {
   return TTIImpl->getNumberOfRegisters(ClassID);
 }
 
-bool TargetTransformInfo::hasConditionalLoadStoreForType(Type *Ty) const {
-  return TTIImpl->hasConditionalLoadStoreForType(Ty);
+bool TargetTransformInfo::hasConditionalLoadStoreForType(Type *Ty,
+                                                         bool IsStore) const {
+  return TTIImpl->hasConditionalLoadStoreForType(Ty, IsStore);
 }
 
 unsigned TargetTransformInfo::getRegisterClassForType(bool Vector,
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 14bb1d943d2d6..988efa54ecade 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -4790,8 +4790,8 @@ void SelectionDAGBuilder::visitMaskedStore(const CallInst &I,
   const auto &TTI =
       TLI.getTargetMachine().getTargetTransformInfo(*I.getFunction());
   SDValue StoreNode =
-      !IsCompressing &&
-              TTI.hasConditionalLoadStoreForType(I.getArgOperand(0)->getType())
+      !IsCompressing && TTI.hasConditionalLoadStoreForType(
+                            I.getArgOperand(0)->getType(), /*IsStore=*/true)
           ? TLI.visitMaskedStore(DAG, sdl, getMemoryRoot(), MMO, Ptr, Src0,
                                  Mask)
           : DAG.getMaskedStore(getMemoryRoot(), sdl, Src0, Ptr, Offset, Mask,
@@ -4976,8 +4976,8 @@ void SelectionDAGBuilder::visitMaskedLoad(const CallInst &I, bool IsExpanding) {
   // variables.
   SDValue Load;
   SDValue Res;
-  if (!IsExpanding &&
-      TTI.hasConditionalLoadStoreForType(Src0Operand->getType()))
+  if (!IsExpanding && TTI.hasConditionalLoadStoreForType(Src0Operand->getType(),
+                                                         /*IsStore=*/false))
     Res = TLI.visitMaskedLoad(DAG, sdl, InChain, MMO, Load, Ptr, Src0, Mask);
   else
     Res = Load =
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index fb0d6d1ffa414..8bee87a22db16 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -175,7 +175,7 @@ unsigned X86TTIImpl::getNumberOfRegisters(unsigned ClassID) const {
   return 8;
 }
 
-bool X86TTIImpl::hasConditionalLoadStoreForType(Type *Ty) const {
+bool X86TTIImpl::hasConditionalLoadStoreForType(Type *Ty, bool IsStore) const {
   if (!ST->hasCF())
     return false;
   if (!Ty)
@@ -6229,13 +6229,7 @@ bool X86TTIImpl::canMacroFuseCmp() {
   return ST->hasMacroFusion() || ST->hasBranchFusion();
 }
 
-bool X86TTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment) {
-  Type *ScalarTy = DataTy->getScalarType();
-
-  // The backend can't handle a single element vector w/o CFCMOV.
-  if (isa<VectorType>(DataTy) && cast<FixedVectorType>(DataTy)->getNumElements() == 1)
-    return ST->hasCF() && hasConditionalLoadStoreForType(ScalarTy);
-
+static bool isLegalMaskedLoadStore(Type *ScalarTy, const X86Subtarget *ST) {
   if (!ST->hasAVX())
     return false;
 
@@ -6259,8 +6253,28 @@ bool X86TTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment) {
          ((IntWidth == 8 || IntWidth == 16) && ST->hasBWI());
 }
 
-bool X86TTIImpl::isLegalMaskedStore(Type *DataType, Align Alignment) {
-  return isLegalMaskedLoad(DataType, Alignment);
+bool X86TTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment) {
+  Type *ScalarTy = DataTy->getScalarType();
+
+  // The backend can't handle a single element vector w/o CFCMOV.
+  if (isa<VectorType>(DataTy) &&
+      cast<FixedVectorType>(DataTy)->getNumElements() == 1)
+    return ST->hasCF() &&
+           hasConditionalLoadStoreForType(ScalarTy, /*IsStore=*/false);
+
+  return isLegalMaskedLoadStore(ScalarTy, ST);
+}
+
+bool X86TTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment) {
+  Type *ScalarTy = DataTy->getScalarType();
+
+  // The backend can't handle a single element vector w/o CFCMOV.
+  if (isa<VectorType>(DataTy) &&
+      cast<FixedVectorType>(DataTy)->getNumElements() == 1)
+    return ST->hasCF() &&
+           hasConditionalLoadStoreForType(ScalarTy, /*IsStore=*/true);
+
+  return isLegalMaskedLoadStore(ScalarTy, ST);
 }
 
 bool X86TTIImpl::isLegalNTLoad(Type *DataType, Align Alignment) {
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
index c916da7f275d7..9a427d4388d0b 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -132,7 +132,7 @@ class X86TTIImpl : public BasicTTIImplBase<X86TTIImpl> {
   /// @{
 
   unsigned getNumberOfRegisters(unsigned ClassID) const;
-  bool hasConditionalLoadStoreForType(Type *Ty = nullptr) const;
+  bool hasConditionalLoadStoreForType(Type *Ty, bool IsStore) const;
   TypeSize getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const;
   unsigned getLoadStoreVecRegBitWidth(unsigned AS) const;
   unsigned getMaxInterleaveFactor(ElementCount VF);
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 1e4342c65494d..2de966e00542d 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -1768,19 +1768,21 @@ static void hoistConditionalLoadsStores(
 static bool isSafeCheapLoadStore(const Instruction *I,
                                  const TargetTransformInfo &TTI) {
   // Not handle volatile or atomic.
+  bool IsStore = false;
   if (auto *L = dyn_cast<LoadInst>(I)) {
     if (!L->isSimple())
       return false;
   } else if (auto *S = dyn_cast<StoreInst>(I)) {
     if (!S->isSimple())
       return false;
+    IsStore = true;
   } else
     return false;
 
   // llvm.masked.load/store use i32 for alignment while load/store use i64.
   // That's why we have the alignment limitation.
   // FIXME: Update the prototype of the intrinsics?
-  return TTI.hasConditionalLoadStoreForType(getLoadStoreType(I)) &&
+  return TTI.hasConditionalLoadStoreForType(getLoadStoreType(I), IsStore) &&
          getLoadStoreAlignment(I) < Value::MaximumAlignment;
 }
 



More information about the llvm-commits mailing list