[PATCH] D116039: [X86] Combine reduce (add (mul x, y)) to VNNI instruction.
LuoYuanke via Phabricator via llvm-commits
llvm-commits at lists.llvm.org
Tue Dec 28 19:13:36 PST 2021
LuoYuanke added inline comments.
================
Comment at: llvm/lib/Target/X86/X86ISelLowering.cpp:41864
+ MVT::getVectorVT(MVT::i8, Ext0.getValueType().getVectorElementCount());
+ if (Ext0.getOperand(0).getValueType().getScalarType() != MVT::i8)
+ ZExt0 = DAG.getZExtOrTrunc(Ext0.getOperand(0), DL, Vi8VT);
----------------
craig.topper wrote:
> Can we just TRUNCATE the Ext nodes without assuming they are extend nodes? That way it just works when you support constants in the future.
Good idea. :) I'll update my patch.
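
For reference, the reasoning behind the suggestion can be sketched in scalar C++. This is only an illustrative model, not code from the patch: the helper names (zextI8ToI32, sextI8ToI32, truncI32ToI8, truncateUndoesExtend) are hypothetical, and i8/i32 stand in for the vector element types. It shows that truncating to 8 bits recovers the original byte after either a zero or a sign extension, so the combine would not need to look through the extend node itself.

```cpp
#include <cstdint>

// Scalar stand-ins for ISD::ZERO_EXTEND and ISD::SIGN_EXTEND from i8 to i32.
// (Hypothetical helpers for illustration only; not LLVM API.)
inline uint32_t zextI8ToI32(uint8_t V) { return V; }
inline uint32_t sextI8ToI32(uint8_t V) {
  return static_cast<uint32_t>(static_cast<int32_t>(static_cast<int8_t>(V)));
}

// Scalar stand-in for ISD::TRUNCATE from i32 to i8: keep the low 8 bits.
inline uint8_t truncI32ToI8(uint32_t V) {
  return static_cast<uint8_t>(V & 0xFF);
}

// Returns true if truncation recovers every i8 value after either extension.
inline bool truncateUndoesExtend() {
  for (unsigned I = 0; I < 256; ++I) {
    uint8_t V = static_cast<uint8_t>(I);
    if (truncI32ToI8(zextI8ToI32(V)) != V)
      return false;
    if (truncI32ToI8(sextI8ToI32(V)) != V)
      return false;
  }
  return true;
}
```

The same truncation also applies to a value that was never an extend node, such as a constant whose value fits in 8 bits, which is presumably why it would carry over to constant operands later.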
================
Comment at: llvm/lib/Target/X86/X86ISelLowering.cpp:41870
+ // Find the appropriate width for the DotProduct.
+ EVT InVT = (ZExt0 == Ext0) ? Ext0.getOperand(0).getValueType():
+ ZExt0.getValueType();
----------------
craig.topper wrote:
> Can we do this without assuming the node is a SIGN/ZERO_EXTEND? Just truncate the original node to Vi8VT.
================
Comment at: llvm/lib/Target/X86/X86ISelLowering.cpp:41870
+ // Find the appropriate width for the DotProduct.
+ EVT InVT = (ZExt0 == Ext0) ? Ext0.getOperand(0).getValueType():
+ ZExt0.getValueType();
----------------
LuoYuanke wrote:
> craig.topper wrote:
> > Can we do this without assuming the node is a SIGN/ZERO_EXTEND? Just truncate the original node to Vi8VT.
>
>
> Can we do this without assuming the node is a SIGN/ZERO_EXTEND? Just truncate the original node to Vi8VT.
Yes, that's better.
================
Comment at: llvm/lib/Target/X86/X86ISelLowering.cpp:42164
+ // Verify the type we're extracting is i32, as the output element type of
+ // vpdpbusd and vpdpwssd is i32.
+ if (ExtractVT != MVT::i32)
----------------
craig.topper wrote:
> This code isn't handling vpdpwssd so why mention it here?
My original code covered both vpdpbusd and vpdpwssd. I'll clean up the comment.
================
Comment at: llvm/lib/Target/X86/X86ISelLowering.cpp:42176
+
+ if (Root && (Root.getOpcode() == ISD::SIGN_EXTEND ||
+ Root.getOpcode() == ISD::ZERO_EXTEND ||
----------------
craig.topper wrote:
> craig.topper wrote:
> > craig.topper wrote:
> > > Is this code valid for this transform? There's a large comment of justification for why it is ok for SAD. I think I only saw a test for the SIGN_EXTEND case?
> > Oops I see the other test. I need to think about the math.
> I don't think we can do this if the multiply result is zero extended. Each of the 4 multiplies done by vpdpbusd computes a signed 16-bit product that will be sign extended before adding into the accumulator.
>
> I think we also need to verify that the multiply has at least 2x the number of bits of the input. We shouldn't match (sign_extend (mul (vXi9 (zext (vXi8 X))), (vXi9 (zext (vXi8 Y))))). Does anything prevent that right now?
Really good catch. Thanks.
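
To make the two concerns concrete, here is a small scalar sketch in C++. It is a simplified model under stated assumptions, not the patch's code: the function names (dotLaneBUSD, dotLaneZExtMul, mulInI9ThenSExt, checkExamples) are hypothetical, one 32-bit lane is modelled at a time, and accumulator wraparound is ignored. It models vpdpbusd as unsigned bytes times signed bytes with sign-extended 16-bit products, then compares that against a zero-extended multiply result and against a multiply performed in only 9 bits.

```cpp
#include <cassert>
#include <cstdint>

// One 32-bit lane of vpdpbusd in scalar form: unsigned bytes A[i] times
// signed bytes B[i]; each product fits in a signed 16-bit value, which is
// sign extended and added to the accumulator. Accumulator wraparound is
// not modelled here (assumption: inputs are small enough not to overflow).
inline int32_t dotLaneBUSD(int32_t Acc, const uint8_t A[4], const int8_t B[4]) {
  for (int I = 0; I < 4; ++I) {
    int16_t Prod = static_cast<int16_t>(static_cast<int32_t>(A[I]) * B[I]);
    Acc += Prod; // implicit sign extension to 32 bits
  }
  return Acc;
}

// The same sum, but with each 16-bit multiply result zero extended instead.
inline int32_t dotLaneZExtMul(int32_t Acc, const uint8_t A[4], const int8_t B[4]) {
  for (int I = 0; I < 4; ++I) {
    uint16_t Prod = static_cast<uint16_t>(static_cast<int32_t>(A[I]) * B[I]);
    Acc += static_cast<int32_t>(Prod); // zero extension to 32 bits
  }
  return Acc;
}

// (sign_extend (mul (i9 (zext X)), (i9 (zext Y)))) on a single element:
// the product wraps to 9 bits before it is sign extended from bit 8.
inline int32_t mulInI9ThenSExt(uint8_t X, uint8_t Y) {
  uint32_t Wrapped = (static_cast<uint32_t>(X) * Y) & 0x1FF;
  return (Wrapped & 0x100) ? static_cast<int32_t>(Wrapped) - 0x200
                           : static_cast<int32_t>(Wrapped);
}

// Small self-check of the cases discussed above.
inline void checkExamples() {
  const uint8_t A[4] = {1, 0, 0, 0};
  const int8_t B[4] = {-1, 0, 0, 0};
  assert(dotLaneBUSD(0, A, B) == -1);       // what the instruction computes
  assert(dotLaneZExtMul(0, A, B) == 65535); // zero-extended product differs
  assert(mulInI9ThenSExt(16, 16) == -256);  // true product is 256
  assert(mulInI9ThenSExt(255, 255) == 1);   // true product is 65025
}
```

In this model a single negative product is enough to make the zero-extended form disagree with the instruction, and a 9-bit multiply loses the high bits of the product before the extension is applied, so neither pattern would be safe to match.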
Repository:
rG LLVM Github Monorepo
CHANGES SINCE LAST ACTION
https://reviews.llvm.org/D116039/new/
https://reviews.llvm.org/D116039