[llvm] [AArch64][NEON][SVE] Lower mixed sign/zero extended partial reductions to usdot (PR #107566)

Sam Tebbs via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 10 07:05:09 PDT 2024


================
@@ -21824,37 +21824,59 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
 
   auto ExtA = MulOp->getOperand(0);
   auto ExtB = MulOp->getOperand(1);
-  bool IsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND;
-  bool IsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND;
-  if (ExtA->getOpcode() != ExtB->getOpcode() || (!IsSExt && !IsZExt))
-    return SDValue();
-
   auto A = ExtA->getOperand(0);
   auto B = ExtB->getOperand(0);
   if (A.getValueType() != B.getValueType())
     return SDValue();
 
-  unsigned Opcode = 0;
-
-  if (IsSExt)
-    Opcode = AArch64ISD::SDOT;
-  else if (IsZExt)
-    Opcode = AArch64ISD::UDOT;
-
-  assert(Opcode != 0 && "Unexpected dot product case encountered.");
-
   EVT ReducedType = N->getValueType(0);
   EVT MulSrcType = A.getValueType();
 
   // Dot products operate on chunks of four elements so there must be four times
   // as many elements in the wide type
-  if ((ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) ||
-      (ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) ||
-      (ReducedType == MVT::v4i32 && MulSrcType == MVT::v16i8) ||
-      (ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8))
-    return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
+  if (!(ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) &&
+      !(ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) &&
+      !(ReducedType == MVT::v4i32 && MulSrcType == MVT::v16i8) &&
+      !(ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8))
+    return SDValue();
 
-  return SDValue();
+  bool AIsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND;
+  bool AIsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND;
+  bool BIsSExt = ExtB->getOpcode() == ISD::SIGN_EXTEND;
+  bool BIsZExt = ExtB->getOpcode() == ISD::ZERO_EXTEND;
+  if (!(AIsSExt || AIsZExt) || !(BIsSExt || BIsZExt))
+    return SDValue();
+
+  // If the extensions are mixed, we should lower it to a usdot instead
+  if (AIsZExt != BIsZExt) {
+    if (!Subtarget->hasMatMulInt8())
+      return SDValue();
+    bool Scalable = N->getValueType(0).isScalableVT();
+
+    // There's no nxv2i64 version of usdot
+    if (Scalable && ReducedType != MVT::nxv4i32)
+      return SDValue();
+
+    unsigned IntrinsicID =
+        Scalable ? Intrinsic::aarch64_sve_usdot : Intrinsic::aarch64_neon_usdot;
+    // USDOT expects the first operand to be unsigned, so swap the operands if
----------------
SamTebbs33 wrote:

Yeah, I agree this can be done separately.

https://github.com/llvm/llvm-project/pull/107566


More information about the llvm-commits mailing list