[llvm] 19e5235 - [AArch64][InstCombine] Fold MLOAD and zero extensions into MLOAD

via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 6 05:54:16 PDT 2022


Author: zhongyunde
Date: 2022-04-06T20:50:42+08:00
New Revision: 19e523514714688e4eb6588925eab079997e21a6

URL: https://github.com/llvm/llvm-project/commit/19e523514714688e4eb6588925eab079997e21a6
DIFF: https://github.com/llvm/llvm-project/commit/19e523514714688e4eb6588925eab079997e21a6.diff

LOG: [AArch64][InstCombine] Fold MLOAD and zero extensions into MLOAD

As discussed in D122281, we are missing an ISD::AND combine for MLOAD
because it relies on BuildVectorSDNode, which fails for scalable vectors.
This patch handles that case, so the MVT::nxv2i32 load-extension
workaround can be reverted.
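For illustration, a minimal IR sketch of the affected pattern (hypothetical
function name, modeled on the existing sve-masked-ldst-zext.ll tests): the
zext of a masked load is lowered as an any-extending MLOAD followed by an
ISD::AND with a splatted 0xffff mask, and the combine folds that pair back
into a single zero-extending MLOAD even when the splat is a SPLAT_VECTOR
rather than a BUILD_VECTOR.

  define <vscale x 2 x i32> @masked_zload_nxv2i16(<vscale x 2 x i16>* %in, <vscale x 2 x i1> %mask) {
    %load = call <vscale x 2 x i16> @llvm.masked.load.nxv2i16(<vscale x 2 x i16>* %in, i32 2, <vscale x 2 x i1> %mask, <vscale x 2 x i16> undef)
    %zext = zext <vscale x 2 x i16> %load to <vscale x 2 x i32>
    ret <vscale x 2 x i32> %zext
  }
  declare <vscale x 2 x i16> @llvm.masked.load.nxv2i16(<vscale x 2 x i16>*, i32, <vscale x 2 x i1>, <vscale x 2 x i16>)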

Reviewed By: paulwalker-arm

Differential Revision: https://reviews.llvm.org/D122703

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index bd38e1669144a..7bf89ca1580d4 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -6067,27 +6067,25 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
     if (ISD::isConstantSplatVectorAllOnes(N1.getNode()))
       return N0;
 
-    // fold (and (masked_load) (build_vec (x, ...))) to zext_masked_load
+    // fold (and (masked_load) (splat_vec (x, ...))) to zext_masked_load
     auto *MLoad = dyn_cast<MaskedLoadSDNode>(N0);
-    auto *BVec = dyn_cast<BuildVectorSDNode>(N1);
-    if (MLoad && BVec && MLoad->getExtensionType() == ISD::EXTLOAD &&
-        N0.hasOneUse() && N1.hasOneUse()) {
+    ConstantSDNode *Splat = isConstOrConstSplat(N1, true, true);
+    if (MLoad && MLoad->getExtensionType() == ISD::EXTLOAD && N0.hasOneUse() &&
+        Splat && N1.hasOneUse()) {
       EVT LoadVT = MLoad->getMemoryVT();
       EVT ExtVT = VT;
       if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, ExtVT, LoadVT)) {
         // For this AND to be a zero extension of the masked load the elements
         // of the BuildVec must mask the bottom bits of the extended element
         // type
-        if (ConstantSDNode *Splat = BVec->getConstantSplatNode()) {
-          uint64_t ElementSize =
-              LoadVT.getVectorElementType().getScalarSizeInBits();
-          if (Splat->getAPIntValue().isMask(ElementSize)) {
-            return DAG.getMaskedLoad(
-                ExtVT, SDLoc(N), MLoad->getChain(), MLoad->getBasePtr(),
-                MLoad->getOffset(), MLoad->getMask(), MLoad->getPassThru(),
-                LoadVT, MLoad->getMemOperand(), MLoad->getAddressingMode(),
-                ISD::ZEXTLOAD, MLoad->isExpandingLoad());
-          }
+        uint64_t ElementSize =
+            LoadVT.getVectorElementType().getScalarSizeInBits();
+        if (Splat->getAPIntValue().isMask(ElementSize)) {
+          return DAG.getMaskedLoad(
+              ExtVT, SDLoc(N), MLoad->getChain(), MLoad->getBasePtr(),
+              MLoad->getOffset(), MLoad->getMask(), MLoad->getPassThru(),
+              LoadVT, MLoad->getMemOperand(), MLoad->getAddressingMode(),
+              ISD::ZEXTLOAD, MLoad->isExpandingLoad());
         }
       }
     }
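
(Note, an inference from the code rather than the commit log:
isConstOrConstSplat(N1, /*AllowUndefs=*/true, /*AllowTruncation=*/true)
recognizes both fixed-width BUILD_VECTOR splats and scalable SPLAT_VECTOR
nodes, so the fold no longer depends on BuildVectorSDNode and applies to
scalable vectors as well.)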

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 68c8d73fbcb6f..a00d2a5661ce9 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1230,7 +1230,6 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setLoadExtAction(Op, MVT::nxv2i64, MVT::nxv2i16, Legal);
       setLoadExtAction(Op, MVT::nxv2i64, MVT::nxv2i32, Legal);
       setLoadExtAction(Op, MVT::nxv4i32, MVT::nxv4i8, Legal);
-      setLoadExtAction(Op, MVT::nxv2i32, MVT::nxv2i16, Legal);
       setLoadExtAction(Op, MVT::nxv4i32, MVT::nxv4i16, Legal);
       setLoadExtAction(Op, MVT::nxv8i16, MVT::nxv8i8, Legal);
     }
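
(The deleted setLoadExtAction entry was the nxv2i16 -> nxv2i32 stop-gap from
D122281; with the generic combine above handling splats on scalable vectors,
the extra load-extension legality entry is no longer needed, as the test
update below shows.)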

diff --git a/llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll b/llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll
index 66c4798f3059a..1ddcd16f1c121 100644
--- a/llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll
+++ b/llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll
@@ -97,7 +97,7 @@ define <vscale x 2 x double> @masked_zload_2i16_2f64(<vscale x 2 x i16>* noalias
 ; CHECK-LABEL: masked_zload_2i16_2f64:
 ; CHECK:       ld1h { z0.d }, p0/z, [x0]
 ; CHECK-NEXT:  ptrue p0.d
-; CHECK-NEXT:  ucvtf z0.d, p0/m, z0.s
+; CHECK-NEXT:  ucvtf z0.d, p0/m, z0.d
 ; CHECK-NEXT:  ret
   %wide.load = call <vscale x 2 x i16> @llvm.masked.load.nxv2i16(<vscale x 2 x i16>* %in, i32 2, <vscale x 2 x i1> %mask, <vscale x 2 x i16> undef)
   %zext = zext <vscale x 2 x i16> %wide.load to <vscale x 2 x i32>

More information about the llvm-commits mailing list