[PATCH] D116039: [X86] Combine reduce (add (mul x, y)) to VNNI instruction.

LuoYuanke via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 28 19:13:36 PST 2021


LuoYuanke added inline comments.


================
Comment at: llvm/lib/Target/X86/X86ISelLowering.cpp:41864
+      MVT::getVectorVT(MVT::i8, Ext0.getValueType().getVectorElementCount());
+  if (Ext0.getOperand(0).getValueType().getScalarType() != MVT::i8)
+    ZExt0 = DAG.getZExtOrTrunc(Ext0.getOperand(0), DL, Vi8VT);
----------------
craig.topper wrote:
> Can we just TRUNCATE the Ext nodes without assuming they are extend nodes? That way it just works when you support constants in the future.

Good idea. :) I'll update my patch.
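Here is a rough way to see why truncation is enough. If a value is either extension (zero or sign) of an i8, truncating it back to i8 recovers the original byte. Truncating a constant simply takes its low 8 bits. The scalar C++ model below is an illustration, not code from the patch. It uses the hypothetical helpers `zext8`, `sext8`, and `trunc8` as stand-ins for `ISD::ZERO_EXTEND`, `ISD::SIGN_EXTEND`, and `ISD::TRUNCATE`.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical scalar stand-ins for i8 -> i32 ZERO_EXTEND and SIGN_EXTEND.
int32_t zext8(uint8_t v) { return static_cast<int32_t>(v); }
int32_t sext8(uint8_t v) {
  return static_cast<int32_t>(static_cast<int8_t>(v));
}

// Hypothetical stand-in for TRUNCATE to i8: keep only the low 8 bits.
uint8_t trunc8(int32_t v) { return static_cast<uint8_t>(v); }

// Truncation undoes either extension, so the caller does not need to know
// which opcode produced the wide value.
bool truncRecoversByte(uint8_t v) {
  return trunc8(zext8(v)) == v && trunc8(sext8(v)) == v;
}
```

Truncation only gives the right answer when the discarded high bits really are extension bits. The combine therefore presumably still has to establish that before truncating an arbitrary node, for example through SelectionDAG's known-bits or sign-bits queries.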


================
Comment at: llvm/lib/Target/X86/X86ISelLowering.cpp:41870
+  // Find the appropriate width for the DotProduct.
+  EVT InVT = (ZExt0 == Ext0) ? Ext0.getOperand(0).getValueType():
+                               ZExt0.getValueType();
----------------
craig.topper wrote:
> Can we do this without assuming the node is a SIGN/ZERO_EXTEND? Just truncate the original node to Vi8VT.

Yes, that's better.


================
Comment at: llvm/lib/Target/X86/X86ISelLowering.cpp:42164
+  // Verify the type we're extracting is i32, as the output element type of
+  // vpdpbusd and vpdpwssd is i32.
+  if (ExtractVT != MVT::i32)
----------------
craig.topper wrote:
> This code isn't handling vpdpwssd so why mention it here?

My original code covered both vpdpbusd and vpdpwssd. I'll clean it up.
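For reference, here is a scalar C++ model of a single 32-bit lane of vpdpbusd (the non-saturating form). It is written from the documented instruction semantics rather than taken from the patch. Four unsigned-by-signed byte products are formed as signed 16-bit values, sign-extended, and added into an i32 accumulator, which is why the extracted reduction result has to be i32. The function name `vpdpbusdLane` is hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical scalar model of one i32 lane of vpdpbusd (no saturation).
// `a` holds the four unsigned bytes, `b` the four signed bytes.
int32_t vpdpbusdLane(int32_t acc, const uint8_t a[4], const int8_t b[4]) {
  // Accumulate in uint32_t so overflow wraps modulo 2^32 like the hardware.
  uint32_t sum = static_cast<uint32_t>(acc);
  for (int i = 0; i < 4; ++i) {
    // u8 * s8 lies in [-32640, 32385], so it always fits in a signed i16.
    int16_t product = static_cast<int16_t>(int32_t(a[i]) * int32_t(b[i]));
    // Sign-extend the 16-bit product to 32 bits before adding it.
    sum += static_cast<uint32_t>(static_cast<int32_t>(product));
  }
  return static_cast<int32_t>(sum);
}
```

Under this model, each lane's dot-product result is an i32 no matter how the byte inputs were produced.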


================
Comment at: llvm/lib/Target/X86/X86ISelLowering.cpp:42176
+
+  if (Root && (Root.getOpcode() == ISD::SIGN_EXTEND ||
+               Root.getOpcode() == ISD::ZERO_EXTEND ||
----------------
craig.topper wrote:
> craig.topper wrote:
> > craig.topper wrote:
> > > Is this code valid for this transform? There's a large comment justifying why it is OK for SAD. I think I only saw a test for the SIGN_EXTEND case?
> > Oops, I see the other test. I need to think about the math.
> I don't think we can do this if the multiply result is zero extended. Each of the 4 multiplies done by vpdpbusd computes a signed 16-bit product that will be sign extended before adding into the accumulator.
> 
> I think we also need to verify that the multiply has at least 2x the number of bits of the input. We shouldn't match (sign_extend (mul (vXi9 (zext (vXi8 X))), (vXi9 (zext (vXi8 Y))))). Does anything prevent that right now?

Really good catch. Thanks.
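Both concerns raised above can be checked numerically with a small scalar C++ sketch. It assumes, as the comment states, that vpdpbusd sign-extends each signed 16-bit byte-pair product. It models the narrow multiply by wrapping the product to the given bit width. The helper names are hypothetical, and this is not code from the patch.

```cpp
#include <cassert>
#include <cstdint>

// Signed 16-bit product of one u8 x s8 pair, as vpdpbusd forms it.
int16_t bytePairProduct(uint8_t a, int8_t b) {
  return static_cast<int16_t>(int32_t(a) * int32_t(b));
}

// What vpdpbusd adds to the accumulator: the product sign-extended to i32.
int32_t vnniContribution(uint8_t a, int8_t b) {
  return static_cast<int32_t>(bytePairProduct(a, b));
}

// What (zext i32 (mul i16 ...)) would add instead: the same bits zero-extended.
int32_t zextContribution(uint8_t a, int8_t b) {
  return static_cast<int32_t>(static_cast<uint16_t>(bytePairProduct(a, b)));
}

// Models (sext i32 (mul iN (zext x), (zext y))): the product wraps to `bits`
// bits and is then sign-extended to 32 bits.
int32_t sextOfNarrowMul(uint8_t x, uint8_t y, unsigned bits) {
  assert(bits >= 2 && bits <= 31);
  uint32_t mask = (1u << bits) - 1;
  uint32_t wrapped = (uint32_t(x) * uint32_t(y)) & mask;
  uint32_t signBit = 1u << (bits - 1);
  // Standard (v ^ signBit) - signBit sign-extension of a `bits`-wide value.
  return static_cast<int32_t>(wrapped ^ signBit) -
         static_cast<int32_t>(signBit);
}
```

For a = 1 and b = -1, `vnniContribution` gives -1 but `zextContribution` gives 65535, so a zero-extended multiply result cannot be matched. For x = y = 255, `sextOfNarrowMul(255, 255, 9)` yields 1 instead of 65025, because the i9 multiply wraps. A multiply narrower than 16 bits (2x the i8 inputs) cannot hold the full byte-pair product.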


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D116039/new/

https://reviews.llvm.org/D116039


