[PATCH] D141693: [AArch64] turn extended vecreduce bigger than v16i8 into udot/sdot

Dave Green via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Mon Jan 16 07:32:29 PST 2023


dmgreen added a comment.

Do you have any tests for cases that are a multiple of 8 but not of 16, like `<24 x ...`? And can you make sure we have the `load <4 x i8>` test case?



================
Comment at: llvm/lib/Target/AArch64/AArch64ISelLowering.cpp:15257
+    SDValue Op0 =
+        DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v16i8, A.getOperand(0),
+                    DAG.getConstant(I * 16, DL, MVT::i64));
----------------
Some of these types and constants (the `MVT::v16i8` result type and the `I * 16` index) might be incorrect for element counts that are a multiple of 8 but not of 16?
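
To illustrate the concern, here is a minimal standalone sketch in plain C++ with no LLVM types. `Chunk` and `splitIntoChunks` are hypothetical names, not part of the patch. It shows one way an element count that is a multiple of 8 could be split into 16-wide pieces plus a single 8-wide tail. With a fixed `v16i8` type and an `I * 16` index, the second extract from a 24-element input would start at element 16 and run past the end of the vector.

```cpp
#include <cassert>
#include <vector>

// Hypothetical description of one extracted subvector.
struct Chunk {
  unsigned Offset; // starting element index of the subvector
  unsigned Width;  // number of i8 elements in the subvector: 16 or 8
};

// Split NumElts (assumed to be a multiple of 8) into as many 16-wide
// chunks as fit, followed by at most one 8-wide tail chunk.
std::vector<Chunk> splitIntoChunks(unsigned NumElts) {
  assert(NumElts % 8 == 0 && "expected a multiple of 8 elements");
  std::vector<Chunk> Chunks;
  unsigned Offset = 0;
  for (; Offset + 16 <= NumElts; Offset += 16)
    Chunks.push_back({Offset, 16});
  if (Offset < NumElts)
    Chunks.push_back({Offset, 8}); // remaining 8 elements
  return Chunks;
}
```

For 24 elements this gives a 16-wide chunk at offset 0 and an 8-wide chunk at offset 16, so both the subvector type and the index constant would need to vary per chunk.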


================
Comment at: llvm/lib/Target/AArch64/AArch64ISelLowering.cpp:15226
+  bool IsValidSize = Op0VT.getScalarSizeInBits() == 8;
+  if (Op0VT != MVT::v8i8 && !IsValidElementCount && !IsValidSize)
     return SDValue();
----------------
dmgreen wrote:
> zjaffal wrote:
> > dmgreen wrote:
> > > I think this should be something like !(IsValidElementCount && IsValidSize).
> > > It is worth adding a v4i8 test if one doesn't exist already:
> > > ```
> > > define i32 @src(ptr %p, i32 %b) {
> > > entry:
> > >   %a64 = load <4 x i8>, ptr %p
> > >   %a65 = sext <4 x i8> %a64 to <4 x i32>
> > >   %a66 = mul nsw <4 x i32> %a65, %a65
> > >   %a67 = tail call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %a66)
> > >   %a = add i32 %a67, %b
> > >   ret i32 %a
> > > }
> > > ```
> > >I think this should be something like !(IsValidElementCount && IsValidSize).
> > But then we won't cover the v8i8 case. Or shall we change `IsValidElementCount` to be
> > `Op0VT.getVectorNumElements() % 16 == 0 || Op0VT.getVectorNumElements() % 8 == 0;`
> Sorry - I meant with the `Op0VT != MVT::v8i8` too. The condition as written here will only bail out if both `!IsValidElementCount` and `!IsValidSize` hold, but it seems like it should bail if either `IsValidElementCount` or `IsValidSize` is false. So:
> ```
> if (Op0VT != MVT::v8i8 && (!IsValidElementCount || !IsValidSize))
> ```
> It could also do `bool IsValidElementCount = Op0VT == MVT::v8i8 || Op0VT.getVectorNumElements() % 16 == 0;` and then check that `if (!IsValidElementCount || !IsValidSize)`, if you think that is cleaner.
Using `Op0VT.getVectorNumElements() % 16 == 0 || Op0VT.getVectorNumElements() == 8;` would be simpler if you did not care about `<24 x ` types (which might be better as a 16+8, not 3*8).
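
As a rough sanity check of the variants discussed above, the sketch below models them in plain C++. `VecTy` and the `bails*` helpers are hypothetical stand-ins for `Op0VT` and the bail-out condition, and it assumes `IsValidElementCount` is currently `NumElements % 16 == 0`, which is not shown in the quoted hunk.

```cpp
#include <cassert>

// Hypothetical stand-in for the properties of Op0VT used by the check.
struct VecTy {
  unsigned NumElts;
  unsigned ScalarBits;
};

bool isV8i8(VecTy T) { return T.NumElts == 8 && T.ScalarBits == 8; }

// Condition as written in the patch: bails only when BOTH checks fail.
bool bailsAsWritten(VecTy T) {
  bool IsValidElementCount = T.NumElts % 16 == 0; // assumed definition
  bool IsValidSize = T.ScalarBits == 8;
  return !isV8i8(T) && !IsValidElementCount && !IsValidSize;
}

// Suggested condition: bails when EITHER check fails.
bool bailsAsSuggested(VecTy T) {
  bool IsValidElementCount = T.NumElts % 16 == 0; // assumed definition
  bool IsValidSize = T.ScalarBits == 8;
  return !isV8i8(T) && (!IsValidElementCount || !IsValidSize);
}

// Variant with the 8-element case folded into IsValidElementCount.
bool bailsFolded(VecTy T) {
  bool IsValidElementCount = T.NumElts % 16 == 0 || T.NumElts == 8;
  bool IsValidSize = T.ScalarBits == 8;
  return !IsValidElementCount || !IsValidSize;
}
```

Under these assumptions, `<4 x i8>` and `<16 x i16>` both slip through the condition as written but are rejected by the other two forms, while `<8 x i8>`, `<16 x i8>` and `<32 x i8>` are accepted by all three. `<24 x i8>` is rejected by the suggested and folded forms.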


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D141693/new/

https://reviews.llvm.org/D141693


