[llvm] [RISCV] Improve performCONCAT_VECTORSCombine stride matching (PR #68726)

Michael Maitland via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 16 11:18:44 PDT 2023


================
@@ -13798,60 +13798,48 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
         Ld->getValueType(0) != BaseLdVT)
       return SDValue();
 
-    Ptrs.push_back(Ld->getBasePtr());
+    Lds.push_back(Ld);
 
     // The common alignment is the most restrictive (smallest) of all the loads
     Align = std::min(Align, Ld->getAlign());
   }
 
-  auto matchForwardStrided = [](ArrayRef<SDValue> Ptrs) {
-    SDValue Stride;
-    for (auto Idx : enumerate(Ptrs)) {
-      if (Idx.index() == 0)
-        continue;
-      SDValue Ptr = Idx.value();
-      // Check that each load's pointer is (add LastPtr, Stride)
-      if (Ptr.getOpcode() != ISD::ADD ||
-          Ptr.getOperand(0) != Ptrs[Idx.index()-1])
-        return SDValue();
-      SDValue Offset = Ptr.getOperand(1);
-      if (!Stride)
-        Stride = Offset;
-      else if (Offset != Stride)
-        return SDValue();
-    }
-    return Stride;
-  };
-  auto matchReverseStrided = [](ArrayRef<SDValue> Ptrs) {
-    SDValue Stride;
-    for (auto Idx : enumerate(Ptrs)) {
-      if (Idx.index() == Ptrs.size() - 1)
-        continue;
-      SDValue Ptr = Idx.value();
-      // Check that each load's pointer is (add NextPtr, Stride)
-      if (Ptr.getOpcode() != ISD::ADD ||
-          Ptr.getOperand(0) != Ptrs[Idx.index()+1])
-        return SDValue();
-      SDValue Offset = Ptr.getOperand(1);
-      if (!Stride)
-        Stride = Offset;
-      else if (Offset != Stride)
-        return SDValue();
-    }
-    return Stride;
+  using PtrDiff = std::pair<SDValue, bool>;
+  auto GetPtrDiff = [&DAG, &DL](LoadSDNode *Ld1,
+                                LoadSDNode *Ld2) -> std::optional<PtrDiff> {
+    // If the load ptrs can be decomposed into a common (Base + Index) with a
+    // common constant stride, then return the constant stride.
+    BaseIndexOffset BIO1 = BaseIndexOffset::match(Ld1, DAG);
+    BaseIndexOffset BIO2 = BaseIndexOffset::match(Ld2, DAG);
+    if (BIO1.equalBaseIndex(BIO2, DAG))
+      return {{DAG.getConstant(BIO2.getOffset() - BIO1.getOffset(), DL,
----------------
michaelmaitland wrote:

I think what you propose makes sense. I considered using `std::function` to defer creating the constant or the negation, but that would not let us compare `getPtrDiff() != BaseDiff`, because `std::function` objects cannot be compared by value. We need a by-value comparison that distinguishes three cases:

1. SDValue, true
2. SDValue, false
3. int64_t, false

The `pair<variant, bool>` structure lets us perform this check.

https://github.com/llvm/llvm-project/pull/68726


More information about the llvm-commits mailing list