[llvm] [AArch64] Improve lowering of truncating build vectors (PR #81960)
David Green via llvm-commits
llvm-commits at lists.llvm.org
Sat Feb 17 00:56:50 PST 2024
================
@@ -11369,54 +11369,105 @@ static bool isSingletonEXTMask(ArrayRef<int> M, EVT VT, unsigned &Imm) {
return true;
}
-// Detect patterns of a0,a1,a2,a3,b0,b1,b2,b3,c0,c1,c2,c3,d0,d1,d2,d3 from
-// v4i32s. This is really a truncate, which we can construct out of (legal)
-// concats and truncate nodes.
-static SDValue ReconstructTruncateFromBuildVector(SDValue V, SelectionDAG &DAG) {
- if (V.getValueType() != MVT::v16i8)
- return SDValue();
- assert(V.getNumOperands() == 16 && "Expected 16 operands on the BUILDVECTOR");
-
- for (unsigned X = 0; X < 4; X++) {
- // Check the first item in each group is an extract from lane 0 of a v4i32
- // or v4i16.
- SDValue BaseExt = V.getOperand(X * 4);
- if (BaseExt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
- (BaseExt.getOperand(0).getValueType() != MVT::v4i16 &&
- BaseExt.getOperand(0).getValueType() != MVT::v4i32) ||
- !isa<ConstantSDNode>(BaseExt.getOperand(1)) ||
- BaseExt.getConstantOperandVal(1) != 0)
+// Detect patterns like a0,a1,a2,a3,b0,b1,b2,b3,c0,c1,c2,c3,d0,d1,d2,d3, that
+// are truncates, which we can construct out of (legal) concats and truncate
+// nodes.
+static SDValue ReconstructTruncateFromBuildVector(SDValue V,
+ SelectionDAG &DAG) {
+ EVT BVTy = V.getValueType();
+ if (BVTy != MVT::v16i8 && BVTy != MVT::v8i16 && BVTy != MVT::v8i8 &&
+ BVTy != MVT::v4i16)
+ return SDValue();
+
+ // Only handle truncating BVs.
+ if (V.getOperand(0).getValueType().getSizeInBits() ==
+ BVTy.getScalarSizeInBits())
+ return SDValue();
+
+ SmallVector<SDValue, 4> Sources;
+ uint64_t LastIdx = 0;
+ uint64_t MaxIdx = 0;
+ // Check for sequential indices e.g. i=0, i+1, ..., i=0, i+1, ...
+ for (SDValue Extr : V->ops()) {
+ SDValue SourceVec = Extr.getOperand(0);
+ EVT SourceVecTy = SourceVec.getValueType();
+
+ if (!DAG.getTargetLoweringInfo().isTypeLegal(SourceVecTy))
return SDValue();
- SDValue Base = BaseExt.getOperand(0);
- // And check the other items are extracts from the same vector.
- for (unsigned Y = 1; Y < 4; Y++) {
- SDValue Ext = V.getOperand(X * 4 + Y);
- if (Ext.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
- Ext.getOperand(0) != Base ||
- !isa<ConstantSDNode>(Ext.getOperand(1)) ||
- Ext.getConstantOperandVal(1) != Y)
+ if (!isa<ConstantSDNode>(Extr.getOperand(1)))
+ return SDValue();
+
+ uint64_t CurIdx = Extr.getConstantOperandVal(1);
+ // Allow repeat of sources.
+ if (CurIdx == 0) {
+ // Check if all lanes are used by the BV.
+ if (Sources.size() && Sources[Sources.size() - 1]
+ .getValueType()
+ .getVectorMinNumElements() != LastIdx + 1)
return SDValue();
- }
+ Sources.push_back(SourceVec);
+ } else if (CurIdx != LastIdx + 1)
+ return SDValue();
+
+ LastIdx = CurIdx;
+ MaxIdx = std::max(MaxIdx, CurIdx);
----------------
davemgreen wrote:
MaxIdx doesn't seem to be needed.
https://github.com/llvm/llvm-project/pull/81960
More information about the llvm-commits mailing list