[llvm] [AArch64] Improve lowering of truncating build vectors (PR #81960)

David Green via llvm-commits llvm-commits at lists.llvm.org
Sat Feb 17 00:56:50 PST 2024


================
@@ -11369,54 +11369,105 @@ static bool isSingletonEXTMask(ArrayRef<int> M, EVT VT, unsigned &Imm) {
   return true;
 }
 
-// Detect patterns of a0,a1,a2,a3,b0,b1,b2,b3,c0,c1,c2,c3,d0,d1,d2,d3 from
-// v4i32s. This is really a truncate, which we can construct out of (legal)
-// concats and truncate nodes.
-static SDValue ReconstructTruncateFromBuildVector(SDValue V, SelectionDAG &DAG) {
-  if (V.getValueType() != MVT::v16i8)
-    return SDValue();
-  assert(V.getNumOperands() == 16 && "Expected 16 operands on the BUILDVECTOR");
-
-  for (unsigned X = 0; X < 4; X++) {
-    // Check the first item in each group is an extract from lane 0 of a v4i32
-    // or v4i16.
-    SDValue BaseExt = V.getOperand(X * 4);
-    if (BaseExt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
-        (BaseExt.getOperand(0).getValueType() != MVT::v4i16 &&
-         BaseExt.getOperand(0).getValueType() != MVT::v4i32) ||
-        !isa<ConstantSDNode>(BaseExt.getOperand(1)) ||
-        BaseExt.getConstantOperandVal(1) != 0)
+// Detect patterns like a0,a1,a2,a3,b0,b1,b2,b3,c0,c1,c2,c3,d0,d1,d2,d3, that
+// are truncates, which we can construct out of (legal) concats and truncate
+// nodes.
+static SDValue ReconstructTruncateFromBuildVector(SDValue V,
+                                                  SelectionDAG &DAG) {
+  EVT BVTy = V.getValueType();
+  if (BVTy != MVT::v16i8 && BVTy != MVT::v8i16 && BVTy != MVT::v8i8 &&
+      BVTy != MVT::v4i16)
+    return SDValue();
+
+  // Only handle truncating BVs.
+  if (V.getOperand(0).getValueType().getSizeInBits() ==
+      BVTy.getScalarSizeInBits())
+    return SDValue();
+
+  SmallVector<SDValue, 4> Sources;
+  uint64_t LastIdx = 0;
+  uint64_t MaxIdx = 0;
+  // Check for sequential indices e.g. i=0, i+1, ..., i=0, i+1, ...
+  for (SDValue Extr : V->ops()) {
+    SDValue SourceVec = Extr.getOperand(0);
+    EVT SourceVecTy = SourceVec.getValueType();
+
+    if (!DAG.getTargetLoweringInfo().isTypeLegal(SourceVecTy))
       return SDValue();
-    SDValue Base = BaseExt.getOperand(0);
-    // And check the other items are extracts from the same vector.
-    for (unsigned Y = 1; Y < 4; Y++) {
-      SDValue Ext = V.getOperand(X * 4 + Y);
-      if (Ext.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
-          Ext.getOperand(0) != Base ||
-          !isa<ConstantSDNode>(Ext.getOperand(1)) ||
-          Ext.getConstantOperandVal(1) != Y)
+    if (!isa<ConstantSDNode>(Extr.getOperand(1)))
+      return SDValue();
+
+    uint64_t CurIdx = Extr.getConstantOperandVal(1);
+    // Allow repeat of sources.
+    if (CurIdx == 0) {
+      // Check if all lanes are used by the BV.
+      if (Sources.size() && Sources[Sources.size() - 1]
+                                    .getValueType()
+                                    .getVectorMinNumElements() != LastIdx + 1)
         return SDValue();
-    }
+      Sources.push_back(SourceVec);
+    } else if (CurIdx != LastIdx + 1)
+      return SDValue();
+
+    LastIdx = CurIdx;
+    MaxIdx = std::max(MaxIdx, CurIdx);
   }
 
-  // Turn the buildvector into a series of truncates and concates, which will
-  // become uzip1's. Any v4i32s we found get truncated to v4i16, which are
-  // concat together to produce 2 v8i16. These are both truncated and concat
-  // together.
+  // Check if all lanes are used by the BV.
+  if (Sources[Sources.size() - 1].getValueType().getVectorMinNumElements() !=
+      LastIdx + 1)
+    return SDValue();
+  if (Sources.size() % 2 != 0)
----------------
davemgreen wrote:

Is it possible for Sources to contain 6 entries, or is its size always <= 4 for the given types?

https://github.com/llvm/llvm-project/pull/81960


More information about the llvm-commits mailing list