[llvm] 0be6fd4 - [SDAG] Use MMO flags in MemSDNode folding

Stanislav Mekhanoshin via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 9 14:44:11 PST 2022


Author: Stanislav Mekhanoshin
Date: 2022-03-09T14:25:22-08:00
New Revision: 0be6fd44f363d5756a3c796e1875567a037a0920

URL: https://github.com/llvm/llvm-project/commit/0be6fd44f363d5756a3c796e1875567a037a0920
DIFF: https://github.com/llvm/llvm-project/commit/0be6fd44f363d5756a3c796e1875567a037a0920.diff

LOG: [SDAG] Use MMO flags in MemSDNode folding

Before this change, SDNodes with different MMO target flags could be
folded together, rightfully triggering the assertion in refineAlignment.
Folding nodes with different target flags may also result in the
wrong load instructions being produced, at least on AMDGPU.

Fixes: SWDEV-326805

Differential Revision: https://reviews.llvm.org/D121335

Added: 
    llvm/test/CodeGen/AMDGPU/mmo-target-flags-folding.ll

Modified: 
    llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Removed: 
    


################################################################################
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 2f8fd513e586e..59fc217713614 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -714,6 +714,7 @@ static void AddNodeIDCustom(FoldingSetNodeID &ID, const SDNode *N) {
     ID.AddInteger(LD->getMemoryVT().getRawBits());
     ID.AddInteger(LD->getRawSubclassData());
     ID.AddInteger(LD->getPointerInfo().getAddrSpace());
+    ID.AddInteger(LD->getMemOperand()->getFlags());
     break;
   }
   case ISD::STORE: {
@@ -721,6 +722,7 @@ static void AddNodeIDCustom(FoldingSetNodeID &ID, const SDNode *N) {
     ID.AddInteger(ST->getMemoryVT().getRawBits());
     ID.AddInteger(ST->getRawSubclassData());
     ID.AddInteger(ST->getPointerInfo().getAddrSpace());
+    ID.AddInteger(ST->getMemOperand()->getFlags());
     break;
   }
   case ISD::VP_LOAD: {
@@ -728,6 +730,7 @@ static void AddNodeIDCustom(FoldingSetNodeID &ID, const SDNode *N) {
     ID.AddInteger(ELD->getMemoryVT().getRawBits());
     ID.AddInteger(ELD->getRawSubclassData());
     ID.AddInteger(ELD->getPointerInfo().getAddrSpace());
+    ID.AddInteger(ELD->getMemOperand()->getFlags());
     break;
   }
   case ISD::VP_STORE: {
@@ -735,6 +738,7 @@ static void AddNodeIDCustom(FoldingSetNodeID &ID, const SDNode *N) {
     ID.AddInteger(EST->getMemoryVT().getRawBits());
     ID.AddInteger(EST->getRawSubclassData());
     ID.AddInteger(EST->getPointerInfo().getAddrSpace());
+    ID.AddInteger(EST->getMemOperand()->getFlags());
     break;
   }
   case ISD::VP_GATHER: {
@@ -742,6 +746,7 @@ static void AddNodeIDCustom(FoldingSetNodeID &ID, const SDNode *N) {
     ID.AddInteger(EG->getMemoryVT().getRawBits());
     ID.AddInteger(EG->getRawSubclassData());
     ID.AddInteger(EG->getPointerInfo().getAddrSpace());
+    ID.AddInteger(EG->getMemOperand()->getFlags());
     break;
   }
   case ISD::VP_SCATTER: {
@@ -749,6 +754,7 @@ static void AddNodeIDCustom(FoldingSetNodeID &ID, const SDNode *N) {
     ID.AddInteger(ES->getMemoryVT().getRawBits());
     ID.AddInteger(ES->getRawSubclassData());
     ID.AddInteger(ES->getPointerInfo().getAddrSpace());
+    ID.AddInteger(ES->getMemOperand()->getFlags());
     break;
   }
   case ISD::MLOAD: {
@@ -756,6 +762,7 @@ static void AddNodeIDCustom(FoldingSetNodeID &ID, const SDNode *N) {
     ID.AddInteger(MLD->getMemoryVT().getRawBits());
     ID.AddInteger(MLD->getRawSubclassData());
     ID.AddInteger(MLD->getPointerInfo().getAddrSpace());
+    ID.AddInteger(MLD->getMemOperand()->getFlags());
     break;
   }
   case ISD::MSTORE: {
@@ -763,6 +770,7 @@ static void AddNodeIDCustom(FoldingSetNodeID &ID, const SDNode *N) {
     ID.AddInteger(MST->getMemoryVT().getRawBits());
     ID.AddInteger(MST->getRawSubclassData());
     ID.AddInteger(MST->getPointerInfo().getAddrSpace());
+    ID.AddInteger(MST->getMemOperand()->getFlags());
     break;
   }
   case ISD::MGATHER: {
@@ -770,6 +778,7 @@ static void AddNodeIDCustom(FoldingSetNodeID &ID, const SDNode *N) {
     ID.AddInteger(MG->getMemoryVT().getRawBits());
     ID.AddInteger(MG->getRawSubclassData());
     ID.AddInteger(MG->getPointerInfo().getAddrSpace());
+    ID.AddInteger(MG->getMemOperand()->getFlags());
     break;
   }
   case ISD::MSCATTER: {
@@ -777,6 +786,7 @@ static void AddNodeIDCustom(FoldingSetNodeID &ID, const SDNode *N) {
     ID.AddInteger(MS->getMemoryVT().getRawBits());
     ID.AddInteger(MS->getRawSubclassData());
     ID.AddInteger(MS->getPointerInfo().getAddrSpace());
+    ID.AddInteger(MS->getMemOperand()->getFlags());
     break;
   }
   case ISD::ATOMIC_CMP_SWAP:
@@ -799,11 +809,13 @@ static void AddNodeIDCustom(FoldingSetNodeID &ID, const SDNode *N) {
     ID.AddInteger(AT->getMemoryVT().getRawBits());
     ID.AddInteger(AT->getRawSubclassData());
     ID.AddInteger(AT->getPointerInfo().getAddrSpace());
+    ID.AddInteger(AT->getMemOperand()->getFlags());
     break;
   }
   case ISD::PREFETCH: {
     const MemSDNode *PF = cast<MemSDNode>(N);
     ID.AddInteger(PF->getPointerInfo().getAddrSpace());
+    ID.AddInteger(PF->getMemOperand()->getFlags());
     break;
   }
   case ISD::VECTOR_SHUFFLE: {
@@ -823,9 +835,13 @@ static void AddNodeIDCustom(FoldingSetNodeID &ID, const SDNode *N) {
   }
   } // end switch (N->getOpcode())
 
-  // Target specific memory nodes could also have address spaces to check.
-  if (N->isTargetMemoryOpcode())
-    ID.AddInteger(cast<MemSDNode>(N)->getPointerInfo().getAddrSpace());
+  // Target specific memory nodes could also have address spaces and flags
+  // to check.
+  if (N->isTargetMemoryOpcode()) {
+    const MemSDNode *MN = cast<MemSDNode>(N);
+    ID.AddInteger(MN->getPointerInfo().getAddrSpace());
+    ID.AddInteger(MN->getMemOperand()->getFlags());
+  }
 }
 
 /// AddNodeIDNode - Generic routine for adding a nodes info to the NodeID
@@ -7315,6 +7331,7 @@ SDValue SelectionDAG::getAtomic(unsigned Opcode, const SDLoc &dl, EVT MemVT,
   ID.AddInteger(MemVT.getRawBits());
   AddNodeIDNode(ID, Opcode, VTList, Ops);
   ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
+  ID.AddInteger(MMO->getFlags());
   void* IP = nullptr;
   if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
     cast<AtomicSDNode>(E)->refineAlignment(MMO);
@@ -7427,6 +7444,7 @@ SDValue SelectionDAG::getMemIntrinsicNode(unsigned Opcode, const SDLoc &dl,
     ID.AddInteger(getSyntheticNodeSubclassData<MemIntrinsicSDNode>(
         Opcode, dl.getIROrder(), VTList, MemVT, MMO));
     ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
+    ID.AddInteger(MMO->getFlags());
     void *IP = nullptr;
     if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
       cast<MemIntrinsicSDNode>(E)->refineAlignment(MMO);
@@ -7599,6 +7617,7 @@ SDValue SelectionDAG::getLoad(ISD::MemIndexedMode AM, ISD::LoadExtType ExtType,
   ID.AddInteger(getSyntheticNodeSubclassData<LoadSDNode>(
       dl.getIROrder(), VTs, AM, ExtType, MemVT, MMO));
   ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
+  ID.AddInteger(MMO->getFlags());
   void *IP = nullptr;
   if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
     cast<LoadSDNode>(E)->refineAlignment(MMO);
@@ -7700,6 +7719,7 @@ SDValue SelectionDAG::getStore(SDValue Chain, const SDLoc &dl, SDValue Val,
   ID.AddInteger(getSyntheticNodeSubclassData<StoreSDNode>(
       dl.getIROrder(), VTs, ISD::UNINDEXED, false, VT, MMO));
   ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
+  ID.AddInteger(MMO->getFlags());
   void *IP = nullptr;
   if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
     cast<StoreSDNode>(E)->refineAlignment(MMO);
@@ -7766,6 +7786,7 @@ SDValue SelectionDAG::getTruncStore(SDValue Chain, const SDLoc &dl, SDValue Val,
   ID.AddInteger(getSyntheticNodeSubclassData<StoreSDNode>(
       dl.getIROrder(), VTs, ISD::UNINDEXED, true, SVT, MMO));
   ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
+  ID.AddInteger(MMO->getFlags());
   void *IP = nullptr;
   if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
     cast<StoreSDNode>(E)->refineAlignment(MMO);
@@ -7794,6 +7815,7 @@ SDValue SelectionDAG::getIndexedStore(SDValue OrigStore, const SDLoc &dl,
   ID.AddInteger(ST->getMemoryVT().getRawBits());
   ID.AddInteger(ST->getRawSubclassData());
   ID.AddInteger(ST->getPointerInfo().getAddrSpace());
+  ID.AddInteger(ST->getMemOperand()->getFlags());
   void *IP = nullptr;
   if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP))
     return SDValue(E, 0);
@@ -7851,6 +7873,7 @@ SDValue SelectionDAG::getLoadVP(ISD::MemIndexedMode AM,
   ID.AddInteger(getSyntheticNodeSubclassData<VPLoadSDNode>(
       dl.getIROrder(), VTs, AM, ExtType, IsExpanding, MemVT, MMO));
   ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
+  ID.AddInteger(MMO->getFlags());
   void *IP = nullptr;
   if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
     cast<VPLoadSDNode>(E)->refineAlignment(MMO);
@@ -7943,6 +7966,7 @@ SDValue SelectionDAG::getStoreVP(SDValue Chain, const SDLoc &dl, SDValue Val,
   ID.AddInteger(getSyntheticNodeSubclassData<VPStoreSDNode>(
       dl.getIROrder(), VTs, AM, IsTruncating, IsCompressing, MemVT, MMO));
   ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
+  ID.AddInteger(MMO->getFlags());
   void *IP = nullptr;
   if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
     cast<VPStoreSDNode>(E)->refineAlignment(MMO);
@@ -8013,6 +8037,7 @@ SDValue SelectionDAG::getTruncStoreVP(SDValue Chain, const SDLoc &dl,
   ID.AddInteger(getSyntheticNodeSubclassData<VPStoreSDNode>(
       dl.getIROrder(), VTs, ISD::UNINDEXED, true, IsCompressing, SVT, MMO));
   ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
+  ID.AddInteger(MMO->getFlags());
   void *IP = nullptr;
   if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
     cast<VPStoreSDNode>(E)->refineAlignment(MMO);
@@ -8043,6 +8068,7 @@ SDValue SelectionDAG::getIndexedStoreVP(SDValue OrigStore, const SDLoc &dl,
   ID.AddInteger(ST->getMemoryVT().getRawBits());
   ID.AddInteger(ST->getRawSubclassData());
   ID.AddInteger(ST->getPointerInfo().getAddrSpace());
+  ID.AddInteger(ST->getMemOperand()->getFlags());
   void *IP = nullptr;
   if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP))
     return SDValue(E, 0);
@@ -8070,6 +8096,7 @@ SDValue SelectionDAG::getGatherVP(SDVTList VTs, EVT VT, const SDLoc &dl,
   ID.AddInteger(getSyntheticNodeSubclassData<VPGatherSDNode>(
       dl.getIROrder(), VTs, VT, MMO, IndexType));
   ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
+  ID.AddInteger(MMO->getFlags());
   void *IP = nullptr;
   if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
     cast<VPGatherSDNode>(E)->refineAlignment(MMO);
@@ -8113,6 +8140,7 @@ SDValue SelectionDAG::getScatterVP(SDVTList VTs, EVT VT, const SDLoc &dl,
   ID.AddInteger(getSyntheticNodeSubclassData<VPScatterSDNode>(
       dl.getIROrder(), VTs, VT, MMO, IndexType));
   ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
+  ID.AddInteger(MMO->getFlags());
   void *IP = nullptr;
   if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
     cast<VPScatterSDNode>(E)->refineAlignment(MMO);
@@ -8162,6 +8190,7 @@ SDValue SelectionDAG::getMaskedLoad(EVT VT, const SDLoc &dl, SDValue Chain,
   ID.AddInteger(getSyntheticNodeSubclassData<MaskedLoadSDNode>(
       dl.getIROrder(), VTs, AM, ExtTy, isExpanding, MemVT, MMO));
   ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
+  ID.AddInteger(MMO->getFlags());
   void *IP = nullptr;
   if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
     cast<MaskedLoadSDNode>(E)->refineAlignment(MMO);
@@ -8209,6 +8238,7 @@ SDValue SelectionDAG::getMaskedStore(SDValue Chain, const SDLoc &dl,
   ID.AddInteger(getSyntheticNodeSubclassData<MaskedStoreSDNode>(
       dl.getIROrder(), VTs, AM, IsTruncating, IsCompressing, MemVT, MMO));
   ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
+  ID.AddInteger(MMO->getFlags());
   void *IP = nullptr;
   if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
     cast<MaskedStoreSDNode>(E)->refineAlignment(MMO);
@@ -8250,6 +8280,7 @@ SDValue SelectionDAG::getMaskedGather(SDVTList VTs, EVT MemVT, const SDLoc &dl,
   ID.AddInteger(getSyntheticNodeSubclassData<MaskedGatherSDNode>(
       dl.getIROrder(), VTs, MemVT, MMO, IndexType, ExtTy));
   ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
+  ID.AddInteger(MMO->getFlags());
   void *IP = nullptr;
   if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
     cast<MaskedGatherSDNode>(E)->refineAlignment(MMO);
@@ -8297,6 +8328,7 @@ SDValue SelectionDAG::getMaskedScatter(SDVTList VTs, EVT MemVT, const SDLoc &dl,
   ID.AddInteger(getSyntheticNodeSubclassData<MaskedScatterSDNode>(
       dl.getIROrder(), VTs, MemVT, MMO, IndexType, IsTrunc));
   ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
+  ID.AddInteger(MMO->getFlags());
   void *IP = nullptr;
   if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
     cast<MaskedScatterSDNode>(E)->refineAlignment(MMO);

diff --git a/llvm/test/CodeGen/AMDGPU/mmo-target-flags-folding.ll b/llvm/test/CodeGen/AMDGPU/mmo-target-flags-folding.ll
new file mode 100644
index 0000000000000..3746810ac942e
--- /dev/null
+++ b/llvm/test/CodeGen/AMDGPU/mmo-target-flags-folding.ll
@@ -0,0 +1,26 @@
+; RUN: llc -march=amdgcn -mcpu=gfx900 < %s | FileCheck --check-prefix=GCN %s
+
+; This is used to crash due to mismatch of MMO target flags when folding
+; a LOAD SDNodes with different flags.
+
+; GCN-LABEL: {{^}}test_load_folding_mmo_flags:
+; GCN: global_load_dwordx2
+define amdgpu_kernel void @test_load_folding_mmo_flags(<2 x float> addrspace(1)* %arg) {
+entry:
+  %id = tail call i32 @llvm.amdgcn.workitem.id.x()
+  %arrayidx = getelementptr inbounds <2 x float>, <2 x float> addrspace(1)* %arg, i32 %id
+  %i1 = bitcast <2 x float> addrspace(1)* %arrayidx to i64 addrspace(1)*
+  %i2 = getelementptr <2 x float>, <2 x float> addrspace(1)* %arrayidx, i64 0, i32 0
+  %i3 = load float, float addrspace(1)* %i2, align 4
+  %idx = getelementptr inbounds <2 x float>, <2 x float> addrspace(1)* %arrayidx, i64 0, i32 1
+  %i4 = load float, float addrspace(1)* %idx, align 4
+  %i5 = load i64, i64 addrspace(1)* %i1, align 4, !amdgpu.noclobber !0
+  store i64 %i5, i64 addrspace(1)* undef, align 4
+  %mul = fmul float %i3, %i4
+  store float %mul, float addrspace(1)* undef, align 4
+  unreachable
+}
+
+declare i32 @llvm.amdgcn.workitem.id.x()
+
+!0 = !{}


        


More information about the llvm-commits mailing list