[llvm] AMDGPU: Don't bitcast float typed atomic store in IR (PR #90116)

via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 25 13:29:03 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-selectiondag

Author: Matt Arsenault (arsenm)

<details>
<summary>Changes</summary>

Implement the promotion in the DAG instead of bitcasting in IR.

Depends on #<!-- -->90113


---
Full diff: https://github.com/llvm/llvm-project/pull/90116.diff


5 Files Affected:

- (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+20) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp (+20-9) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp (+34) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h (+2) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp (+13) 


``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index aa746f1c7b7b3b..f115a39a6953ce 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -530,6 +530,7 @@ namespace {
     bool refineExtractVectorEltIntoMultipleNarrowExtractVectorElts(SDNode *N);
 
     SDValue visitSTORE(SDNode *N);
+    SDValue visitATOMIC_STORE(SDNode *N);
     SDValue visitLIFETIME_END(SDNode *N);
     SDValue visitINSERT_VECTOR_ELT(SDNode *N);
     SDValue visitEXTRACT_VECTOR_ELT(SDNode *N);
@@ -1909,6 +1910,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
   case ISD::BR_CC:              return visitBR_CC(N);
   case ISD::LOAD:               return visitLOAD(N);
   case ISD::STORE:              return visitSTORE(N);
+  case ISD::ATOMIC_STORE:       return visitATOMIC_STORE(N);
   case ISD::INSERT_VECTOR_ELT:  return visitINSERT_VECTOR_ELT(N);
   case ISD::EXTRACT_VECTOR_ELT: return visitEXTRACT_VECTOR_ELT(N);
   case ISD::BUILD_VECTOR:       return visitBUILD_VECTOR(N);
@@ -21096,6 +21098,24 @@ SDValue DAGCombiner::replaceStoreOfInsertLoad(StoreSDNode *ST) {
                       ST->getMemOperand()->getFlags());
 }
 
+SDValue DAGCombiner::visitATOMIC_STORE(SDNode *N) {
+  AtomicSDNode *ST = cast<AtomicSDNode>(N);
+  SDValue Val = ST->getVal();
+  EVT VT = Val.getValueType();
+  EVT MemVT = ST->getMemoryVT();
+
+  if (MemVT.bitsLT(VT)) { // Is truncating store
+    APInt TruncDemandedBits = APInt::getLowBitsSet(VT.getScalarSizeInBits(),
+                                                   MemVT.getScalarSizeInBits());
+    // See if we can simplify the operation with SimplifyDemandedBits, which
+    // only works if the value has a single use.
+    if (SimplifyDemandedBits(Val, TruncDemandedBits))
+      return SDValue(N, 0);
+  }
+
+  return SDValue();
+}
+
 SDValue DAGCombiner::visitSTORE(SDNode *N) {
   StoreSDNode *ST  = cast<StoreSDNode>(N);
   SDValue Chain = ST->getChain();
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 24f69ea1b742a6..e2a22cbeb18c53 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -4946,7 +4946,8 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
       Node->getOpcode() == ISD::INSERT_VECTOR_ELT) {
     OVT = Node->getOperand(0).getSimpleValueType();
   }
-  if (Node->getOpcode() == ISD::STRICT_UINT_TO_FP ||
+  if (Node->getOpcode() == ISD::ATOMIC_STORE ||
+      Node->getOpcode() == ISD::STRICT_UINT_TO_FP ||
       Node->getOpcode() == ISD::STRICT_SINT_TO_FP ||
       Node->getOpcode() == ISD::STRICT_FSETCC ||
       Node->getOpcode() == ISD::STRICT_FSETCCS)
@@ -5557,7 +5558,8 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
     Results.push_back(CvtVec);
     break;
   }
-  case ISD::ATOMIC_SWAP: {
+  case ISD::ATOMIC_SWAP:
+  case ISD::ATOMIC_STORE: {
     AtomicSDNode *AM = cast<AtomicSDNode>(Node);
     SDLoc SL(Node);
     SDValue CastVal = DAG.getNode(ISD::BITCAST, SL, NVT, AM->getVal());
@@ -5566,13 +5568,22 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
     assert(AM->getMemoryVT().getSizeInBits() == NVT.getSizeInBits() &&
            "unexpected atomic_swap with illegal type");
 
-    SDValue NewAtomic
-      = DAG.getAtomic(ISD::ATOMIC_SWAP, SL, NVT,
-                      DAG.getVTList(NVT, MVT::Other),
-                      { AM->getChain(), AM->getBasePtr(), CastVal },
-                      AM->getMemOperand());
-    Results.push_back(DAG.getNode(ISD::BITCAST, SL, OVT, NewAtomic));
-    Results.push_back(NewAtomic.getValue(1));
+    SDValue Op0 = AM->getBasePtr();
+    SDValue Op1 = CastVal;
+
+    // ATOMIC_STORE uses a swapped operand order from every other AtomicSDNode,
+    // but really it should merge with ISD::STORE.
+    if (AM->getOpcode() == ISD::ATOMIC_STORE)
+      std::swap(Op0, Op1);
+
+    SDValue NewAtomic = DAG.getAtomic(AM->getOpcode(), SL, NVT, AM->getChain(),
+                                      Op0, Op1, AM->getMemOperand());
+
+    if (AM->getOpcode() != ISD::ATOMIC_STORE) {
+      Results.push_back(DAG.getNode(ISD::BITCAST, SL, OVT, NewAtomic));
+      Results.push_back(NewAtomic.getValue(1));
+    } else
+      Results.push_back(NewAtomic);
     break;
   }
   case ISD::SPLAT_VECTOR: {
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
index 7685bc73cf9652..5f21f65cab6273 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
@@ -2249,6 +2249,7 @@ bool DAGTypeLegalizer::PromoteFloatOperand(SDNode *N, unsigned OpNo) {
     case ISD::SELECT_CC:  R = PromoteFloatOp_SELECT_CC(N, OpNo); break;
     case ISD::SETCC:      R = PromoteFloatOp_SETCC(N, OpNo); break;
     case ISD::STORE:      R = PromoteFloatOp_STORE(N, OpNo); break;
+    case ISD::ATOMIC_STORE: R = PromoteFloatOp_ATOMIC_STORE(N, OpNo); break;
   }
   // clang-format on
 
@@ -2371,6 +2372,23 @@ SDValue DAGTypeLegalizer::PromoteFloatOp_STORE(SDNode *N, unsigned OpNo) {
                       ST->getMemOperand());
 }
 
+SDValue DAGTypeLegalizer::PromoteFloatOp_ATOMIC_STORE(SDNode *N,
+                                                      unsigned OpNo) {
+  AtomicSDNode *ST = cast<AtomicSDNode>(N);
+  SDValue Val = ST->getVal();
+  SDLoc DL(N);
+
+  SDValue Promoted = GetPromotedFloat(Val);
+  EVT VT = ST->getOperand(1).getValueType();
+  EVT IVT = EVT::getIntegerVT(*DAG.getContext(), VT.getSizeInBits());
+
+  SDValue NewVal = DAG.getNode(GetPromotionOpcode(Promoted.getValueType(), VT),
+                               DL, IVT, Promoted);
+
+  return DAG.getAtomic(ISD::ATOMIC_STORE, DL, IVT, ST->getChain(), NewVal,
+                       ST->getBasePtr(), ST->getMemOperand());
+}
+
 //===----------------------------------------------------------------------===//
 //  Float Result Promotion
 //===----------------------------------------------------------------------===//
@@ -3154,6 +3172,9 @@ bool DAGTypeLegalizer::SoftPromoteHalfOperand(SDNode *N, unsigned OpNo) {
   case ISD::SELECT_CC:  Res = SoftPromoteHalfOp_SELECT_CC(N, OpNo); break;
   case ISD::SETCC:      Res = SoftPromoteHalfOp_SETCC(N); break;
   case ISD::STORE:      Res = SoftPromoteHalfOp_STORE(N, OpNo); break;
+  case ISD::ATOMIC_STORE:
+    Res = SoftPromoteHalfOp_ATOMIC_STORE(N, OpNo);
+    break;
   case ISD::STACKMAP:
     Res = SoftPromoteHalfOp_STACKMAP(N, OpNo);
     break;
@@ -3307,6 +3328,19 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfOp_STORE(SDNode *N, unsigned OpNo) {
                       ST->getMemOperand());
 }
 
+SDValue DAGTypeLegalizer::SoftPromoteHalfOp_ATOMIC_STORE(SDNode *N,
+                                                         unsigned OpNo) {
+  assert(OpNo == 1 && "Can only soften the stored value!");
+  AtomicSDNode *ST = cast<AtomicSDNode>(N);
+  SDValue Val = ST->getVal();
+  SDLoc dl(N);
+
+  SDValue Promoted = GetSoftPromotedHalf(Val);
+  return DAG.getAtomic(ISD::ATOMIC_STORE, dl, Promoted.getValueType(),
+                       ST->getChain(), Promoted, ST->getBasePtr(),
+                       ST->getMemOperand());
+}
+
 SDValue DAGTypeLegalizer::SoftPromoteHalfOp_STACKMAP(SDNode *N, unsigned OpNo) {
   assert(OpNo > 1); // Because the first two arguments are guaranteed legal.
   SmallVector<SDValue> NewOps(N->ops().begin(), N->ops().end());
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index 9c855e55855312..fad0067564bf4d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -708,6 +708,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue PromoteFloatOp_UnaryOp(SDNode *N, unsigned OpNo);
   SDValue PromoteFloatOp_FP_TO_XINT_SAT(SDNode *N, unsigned OpNo);
   SDValue PromoteFloatOp_STORE(SDNode *N, unsigned OpNo);
+  SDValue PromoteFloatOp_ATOMIC_STORE(SDNode *N, unsigned OpNo);
   SDValue PromoteFloatOp_SELECT_CC(SDNode *N, unsigned OpNo);
   SDValue PromoteFloatOp_SETCC(SDNode *N, unsigned OpNo);
 
@@ -751,6 +752,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue SoftPromoteHalfOp_SETCC(SDNode *N);
   SDValue SoftPromoteHalfOp_SELECT_CC(SDNode *N, unsigned OpNo);
   SDValue SoftPromoteHalfOp_STORE(SDNode *N, unsigned OpNo);
+  SDValue SoftPromoteHalfOp_ATOMIC_STORE(SDNode *N, unsigned OpNo);
   SDValue SoftPromoteHalfOp_STACKMAP(SDNode *N, unsigned OpNo);
   SDValue SoftPromoteHalfOp_PATCHPOINT(SDNode *N, unsigned OpNo);
 
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
index f4a747784d1fd2..9e78b5e367d864 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -148,6 +148,19 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM,
   setOperationAction(ISD::LOAD, MVT::i128, Promote);
   AddPromotedToType(ISD::LOAD, MVT::i128, MVT::v4i32);
 
+  setOperationAction(ISD::ATOMIC_STORE, MVT::f32, Promote);
+  AddPromotedToType(ISD::ATOMIC_STORE, MVT::f32, MVT::i32);
+
+  setOperationAction(ISD::ATOMIC_STORE, MVT::f64, Promote);
+  AddPromotedToType(ISD::ATOMIC_STORE, MVT::f64, MVT::i64);
+
+  setOperationAction(ISD::ATOMIC_STORE, MVT::f16, Promote);
+  AddPromotedToType(ISD::ATOMIC_STORE, MVT::f16, MVT::i16);
+
+  setOperationAction(ISD::ATOMIC_STORE, MVT::bf16, Promote);
+  AddPromotedToType(ISD::ATOMIC_STORE, MVT::bf16, MVT::i16);
+
+
   // There are no 64-bit extloads. These should be done as a 32-bit extload and
   // an extension to 64-bit.
   for (MVT VT : MVT::integer_valuetypes())

``````````

</details>


https://github.com/llvm/llvm-project/pull/90116


More information about the llvm-commits mailing list