[PATCH] D116039: [X86] Combine reduce (add (mul x, y)) to VNNI instruction.

Craig Topper via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 27 10:34:19 PST 2021


craig.topper added inline comments.


================
Comment at: llvm/lib/Target/X86/X86ISelLowering.cpp:41818
+             Op.getOpcode() == ISD::SIGN_EXTEND) &&
+            Op.getOperand(0).getValueType().getScalarSizeInBits() <= 8)
+          return true;
----------------
You can use `Op.getOperand(0).getScalarValueSizeInBits()` to simplify this.


================
Comment at: llvm/lib/Target/X86/X86ISelLowering.cpp:41828
+  // value, so we just check the signed bits.
+  if ((IsFreeTruncation(Op0) && DAG.ComputeMinSignedBits(Op0) <= 9) &&
+      (IsFreeTruncation(Op1) && DAG.ComputeMinSignedBits(Op1) <= 8))
----------------
Can we use `DAG.computeKnownBits(Op0).countMaxActiveBits() <= 8` to make this more readable?
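To illustrate what the suggested check measures, here is a minimal standalone sketch. `KnownBits32` and the helper functions are hypothetical stand-ins written for this note, not LLVM's `KnownBits` class; the sketch only models the idea that "max active bits" is the bit width minus the number of leading bits proven to be zero.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for a known-bits result on a 32-bit value.
// A set bit in `zero` means that bit has been proven to be 0.
struct KnownBits32 {
  uint32_t zero;

  // Upper bound on the number of bits needed to hold the unsigned value:
  // 32 minus the number of leading bits known to be zero.
  unsigned countMaxActiveBits() const {
    unsigned leadingKnownZeros = 0;
    for (int bit = 31; bit >= 0 && ((zero >> bit) & 1u); --bit)
      ++leadingKnownZeros;
    return 32 - leadingKnownZeros;
  }
};

// Models a value zero-extended from 8 bits: the top 24 bits are known zero.
inline KnownBits32 knownBitsOfZExtFromI8() { return KnownBits32{0xFFFFFF00u}; }

// Models a value about which nothing is known.
inline KnownBits32 knownBitsOfUnknown() { return KnownBits32{0u}; }

// The suggested predicate, expressed on the model above.
inline bool fitsInUnsignedByte(const KnownBits32 &known) {
  return known.countMaxActiveBits() <= 8;
}
```

With this model, a value zero-extended from i8 passes the check, while a value with no known-zero bits does not.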


================
Comment at: llvm/lib/Target/X86/X86ISelLowering.cpp:41856
 
+static SDValue createVPDPBUSD(SelectionDAG &DAG, const SDValue &Ext0,
+                              const SDValue &Ext1, unsigned &LogBias,
----------------
Why are Ext0 and Ext1 passed by const reference? SDValue should be passed by value.
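For context, SDValue is a small handle (a node pointer plus a result number), which is why copying it is cheap. The sketch below uses a hypothetical `ValueHandle` type with the same shape to show the by-value convention; it is an illustration only and does not reproduce LLVM's actual class.

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>

struct Node;  // hypothetical stand-in for the DAG node type; never defined here

// Hypothetical handle with the same shape as the real one:
// a pointer to the node plus the index of the result being referenced.
struct ValueHandle {
  Node *node;
  unsigned resultNo;
};

// A handle like this is trivially copyable and no larger than two pointers,
// so passing it by value costs about the same as passing a reference.
static_assert(std::is_trivially_copyable<ValueHandle>::value,
              "handle should be trivially copyable");
static_assert(sizeof(ValueHandle) <= 2 * sizeof(void *),
              "handle should fit in two pointer-sized words");

// By-value parameter: the callee gets its own copy and may reassign it freely.
inline unsigned resultIndexOf(ValueHandle v) { return v.resultNo; }
```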


================
Comment at: llvm/lib/Target/X86/X86ISelLowering.cpp:41864
+      MVT::getVectorVT(MVT::i8, Ext0.getValueType().getVectorElementCount());
+  if (Ext0.getOperand(0).getValueType().getScalarType() != MVT::i8)
+    ZExt0 = DAG.getZExtOrTrunc(Ext0.getOperand(0), DL, Vi8VT);
----------------
Can we just TRUNCATE the Ext nodes without assuming they are extend nodes? That way it just works when you support constants in the future.
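A scalar sketch of why truncating works regardless of how the wide value was produced: once a value is known to fit in 8 bits, keeping only its low byte recovers the narrow value, whether it came from a zero extend, a sign extend, or a constant. The helper names below are hypothetical and model single scalars rather than DAG nodes.

```cpp
#include <cassert>
#include <cstdint>

// Scalar models of the two extend operations (hypothetical helpers).
inline int32_t zextFromI8(uint8_t x) { return static_cast<int32_t>(x); }
inline int32_t sextFromI8(int8_t x) { return static_cast<int32_t>(x); }

// Scalar model of a truncate to 8 bits: keep only the low byte.
// It does not care which operation produced `wide`.
inline uint8_t truncToI8(int32_t wide) {
  return static_cast<uint8_t>(static_cast<uint32_t>(wide) & 0xFFu);
}
```

The same `truncToI8` call handles all three producers, so no opcode check on the operand is needed.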


================
Comment at: llvm/lib/Target/X86/X86ISelLowering.cpp:41870
+  // Find the appropriate width for the DotProduct.
+  EVT InVT = (ZExt0 == Ext0) ? Ext0.getOperand(0).getValueType():
+                               ZExt0.getValueType();
----------------
Can we do this without assuming the node is a SIGN/ZERO_EXTEND? Just truncate the original node to Vi8VT.


================
Comment at: llvm/lib/Target/X86/X86ISelLowering.cpp:42164
+  // Verify the type we're extracting is i32, as the output element type of
+  // vpdpbusd and vpdpwssd is i32.
+  if (ExtractVT != MVT::i32)
----------------
This code isn't handling vpdpwssd, so why mention it here?
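For readers unfamiliar with the instruction, the following is a simplified scalar model of one 32-bit lane of the non-saturating vpdpbusd: four unsigned bytes are multiplied by four signed bytes and the products are added to an i32 accumulator. The function name is hypothetical, and the sketch is based on the instruction's documented behaviour rather than on code from this patch.

```cpp
#include <cassert>
#include <cstdint>

// Scalar model of one i32 lane: acc + sum(a[i] * b[i]) for i in 0..3,
// where a[i] is treated as unsigned and b[i] as signed.
inline int32_t dotProductLaneU8S8(int32_t acc, const uint8_t a[4],
                                  const int8_t b[4]) {
  // Accumulate in unsigned arithmetic so that overflow wraps instead of
  // being undefined behaviour.
  uint32_t sum = static_cast<uint32_t>(acc);
  for (int i = 0; i < 4; ++i) {
    // Each product fits comfortably in 32 bits (at most 255 * 128 in magnitude).
    int32_t product = static_cast<int32_t>(a[i]) * static_cast<int32_t>(b[i]);
    sum += static_cast<uint32_t>(product);
  }
  return static_cast<int32_t>(sum);
}
```

The asymmetry between the operands (unsigned on one side, signed on the other) is what motivates the different bit-width checks on Op0 and Op1 in the patch.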


================
Comment at: llvm/lib/Target/X86/X86ISelLowering.cpp:42176
+
+  if (Root && (Root.getOpcode() == ISD::SIGN_EXTEND ||
+               Root.getOpcode() == ISD::ZERO_EXTEND ||
----------------
Is this code valid for this transform? There's a large comment of justification for why it is ok for SAD. I think I only saw a test for the SIGN_EXTEND case?


================
Comment at: llvm/lib/Target/X86/X86ISelLowering.cpp:41813-41814
+      [](SDValue &Op) -> bool {
+        if ((Op.getOpcode() == ISD::ZERO_EXTEND ||
+             Op.getOpcode() == ISD::SIGN_EXTEND) &&
+            Op.getOperand(0).getValueType().getScalarSizeInBits() <= 8)
----------------
LuoYuanke wrote:
> pengfei wrote:
> > How about `ANY_EXTEND` ? The same below.
> This checks the opcode, so we need to check both zero extend and sign extend. I'm not sure if any extend also works, because the upper bits are undefined. What's the sign bit for any extend?
It is undefined, and ComputeMinSignedBits will return BitWidth for it (BitWidth - ComputeNumSignBits + 1, with only 1 sign bit known).


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D116039/new/

https://reviews.llvm.org/D116039


