[llvm] [NVPTX] Optimize v2x16 BUILD_VECTORs to PRMT (PR #116675)

Fraser Cormack via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 25 07:27:57 PST 2024


================
@@ -6176,6 +6176,57 @@ static SDValue PerformLOADCombine(SDNode *N,
       DL);
 }
 
+static SDValue
+PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
+  auto VT = N->getValueType(0);
+  if (!DCI.isAfterLegalizeDAG() || !Isv2x16VT(VT))
+    return SDValue();
+
+  auto Op0 = N->getOperand(0);
+  auto Op1 = N->getOperand(1);
+
+  // Start out by assuming we want to take the lower 2 bytes of each i32
+  // operand.
+  uint64_t Op0Bytes = 0x10;
+  uint64_t Op1Bytes = 0x54;
+
+  std::pair<SDValue *, uint64_t *> OpData[2] = {{&Op0, &Op0Bytes},
+                                                {&Op1, &Op1Bytes}};
+
+  // Check that each operand is an i16, truncated from an i32 operand. We'll
+  // select individual bytes from those original operands. Optionally, fold in a
+  // shift right of that original operand.
+  for (auto &[Op, OpBytes] : OpData) {
+    // Eat up any bitcast
+    if (Op->getOpcode() == ISD::BITCAST)
+      *Op = Op->getOperand(0);
+
+    if (Op->getValueType() != MVT::i16 || Op->getOpcode() != ISD::TRUNCATE ||
+        Op->getOperand(0).getValueType() != MVT::i32)
+      return SDValue();
+
+    *Op = Op->getOperand(0);
+
+    // Optionally, fold in a shift-right of the original operand and permute
+    // the two higher bytes from the shifted operand
+    if (Op->getOpcode() == ISD::SRL && isa<ConstantSDNode>(Op->getOperand(1))) {
+      if (cast<ConstantSDNode>(Op->getOperand(1))->getZExtValue() == 16) {
+        *OpBytes += 0x22;
+        *Op = Op->getOperand(0);
+      }
+    }
----------------
frasercrmck wrote:

How do you see this idea working in the case of something like `generic_2xi16` [here](https://github.com/llvm/llvm-project/blob/main/llvm/test/CodeGen/NVPTX/load-store.ll#L197-L217)?

In this case, it's better to extract the vector into actual 16-bit registers using `mov.b32 {%rs1, %rs2}, %r1`, because we then perform scalarized 16-bit operations on them. We can't detect that at the point where we do `LowerEXTRACT_VECTOR_ELEMENT`. Well, we can, but looking at a node's uses while lowering it is uncommon, and it's a bit of a red flag to me. The problem with eagerly creating PRMTs when lowering the vector extracts is that it leaves us with awkward 32-bit registers, which we then have to undo.

We'd probably have to do what `NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT` is doing: try to find two `truncate`s of two `PRMT`s from the same original vector and replace them all with a single `NVPTX::I32ToV2I16`. This starts to sound a little strange, especially since we haven't proven that using PRMTs in this way brings any tangible benefit.

What do you think?

https://github.com/llvm/llvm-project/pull/116675

