[llvm] [SelectionDAG] Ensure that we don't create `UCMP`/`SCMP` nodes with operands being scalars and result being a 1-element vector during scalarization (PR #98687)

via llvm-commits llvm-commits at lists.llvm.org
Fri Jul 12 12:42:49 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-aarch64

Author: Volodymyr Vasylkun (Poseydon42)

<details>
<summary>Changes</summary>

This patch fixes a problem where a `UCMP`/`SCMP` node operating on 1-element vectors could have a legal result type (e.g. `v1i64` on AArch64) but illegal operand types (e.g. `v1i65`). In that case operand scalarization was performed on the node: the operands were changed to a scalar type, but the result type was left as a 1-element vector. This led to `UCMP`/`SCMP` nodes whose operands and result differed in vector-ness appearing in the SDAG. This patch addresses the issue by fully scalarizing the `UCMP`/`SCMP` node and then turning its result back into a 1-element vector using a `SCALAR_TO_VECTOR` node.
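The difference between the old and new scalarization can be sketched with a toy model. This is not LLVM code: `ToyVT`, `ToyNode`, `scalarizeCmpBefore`, `scalarizeCmpAfter` and `sameVectorness` are hypothetical names, and the model only tracks whether a type is a vector, how many elements it has, and the element width.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for a value type: only the properties relevant here.
struct ToyVT {
  bool IsVector;
  unsigned NumElts; // 1 for scalars
  unsigned Bits;    // width of one element
};

// Hypothetical stand-in for a DAG node with one result and uniform operands.
struct ToyNode {
  std::string Opcode;
  ToyVT ResultVT;
  ToyVT OperandVT;
};

// Element type of a (1-element) vector type, as a scalar.
ToyVT scalarOf(ToyVT VT) { return {false, 1, VT.Bits}; }

// Old behaviour: operands become scalars, the vector result type is reused.
ToyNode scalarizeCmpBefore(const ToyNode &N) {
  return {N.Opcode, N.ResultVT, scalarOf(N.OperandVT)};
}

// New behaviour: a fully scalar compare, wrapped back into a vector.
struct Scalarized {
  ToyNode Cmp;  // scalar UCMP/SCMP
  ToyNode Wrap; // SCALAR_TO_VECTOR producing the original result type
};

Scalarized scalarizeCmpAfter(const ToyNode &N) {
  ToyVT EltVT = scalarOf(N.ResultVT);
  ToyNode Cmp{N.Opcode, EltVT, scalarOf(N.OperandVT)};
  ToyNode Wrap{"SCALAR_TO_VECTOR", N.ResultVT, EltVT};
  return {Cmp, Wrap};
}

// True when the result and the operands are both scalars or both vectors.
bool sameVectorness(const ToyNode &N) {
  return N.ResultVT.IsVector == N.OperandVT.IsVector;
}
```

For a node modelled as `{"UCMP", {true, 1, 64}, {true, 1, 65}}`, the old path yields a compare with scalar operands and a vector result, while the new path yields a scalar compare whose value is then rebuilt into a 1-element vector.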

It also adds several assertions to `SelectionDAG::getNode()` to avoid this or a similar issue arising in the future. I wasn't sure if these two changes are unrelated enough to warrant two small separate PRs, but I'm happy to split this PR into two if that's deemed more appropriate.
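As a rough illustration of what the new assertions enforce, here is a standalone sketch of the same three checks on a toy type model. `ToyVT`, `sameType` and `checkCmpTypes` are hypothetical names, not LLVM APIs, and the sketch uses a plain element count, so it ignores scalable vectors.

```cpp
#include <cassert>

// Hypothetical stand-in for a value type: only the properties relevant here.
struct ToyVT {
  bool IsVector;
  unsigned NumElts; // 1 for scalars
  unsigned Bits;    // width of one element
};

bool sameType(ToyVT A, ToyVT B) {
  return A.IsVector == B.IsVector && A.NumElts == B.NumElts &&
         A.Bits == B.Bits;
}

// Returns nullptr when the types are acceptable for UCMP/SCMP, otherwise a
// message describing the first violated rule.
const char *checkCmpTypes(ToyVT Result, ToyVT LHS, ToyVT RHS) {
  if (!sameType(LHS, RHS))
    return "Types of operands of UCMP/SCMP must match";
  if (LHS.IsVector != Result.IsVector)
    return "Operands and return type must both be scalars or vectors";
  if (Result.IsVector && Result.NumElts != LHS.NumElts)
    return "Result and operands must have the same number of elements";
  return nullptr;
}
```

Note that the result and operand element widths are allowed to differ (for example `i64` from `i65`); only vector-ness and element count have to agree.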

---
Full diff: https://github.com/llvm/llvm-project/pull/98687.diff


3 Files Affected:

- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp (+4-1) 
- (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (+11) 
- (modified) llvm/test/CodeGen/AArch64/ucmp.ll (+17) 


``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index bbf08e862da12..4282f440abff7 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -1029,7 +1029,10 @@ SDValue DAGTypeLegalizer::ScalarizeVecOp_VECREDUCE_SEQ(SDNode *N) {
 SDValue DAGTypeLegalizer::ScalarizeVecOp_CMP(SDNode *N) {
   SDValue LHS = GetScalarizedVector(N->getOperand(0));
   SDValue RHS = GetScalarizedVector(N->getOperand(1));
-  return DAG.getNode(N->getOpcode(), SDLoc(N), N->getValueType(0), LHS, RHS);
+
+  EVT ResVT = N->getValueType(0).getVectorElementType();
+  SDValue Cmp = DAG.getNode(N->getOpcode(), SDLoc(N), ResVT, LHS, RHS);
+  return DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), N->getValueType(0), Cmp);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 79f90bae1d8d6..adbe8fc9466b1 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -6983,6 +6983,17 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
         return getNode(ISD::AND, DL, VT, N1, getNOT(DL, N2, VT));
     }
     break;
+  case ISD::SCMP:
+  case ISD::UCMP:
+    assert(N1.getValueType() == N2.getValueType() &&
+           "Types of operands of UCMP/SCMP must match");
+    assert(N1.getValueType().isVector() == VT.isVector() &&
+           "Operands and return type must both be scalars or vectors");
+    if (VT.isVector())
+      assert(VT.getVectorElementCount() ==
+                 N1.getValueType().getVectorElementCount() &&
+             "Result and operands must have the same number of elements");
+    break;
   case ISD::AVGFLOORS:
   case ISD::AVGFLOORU:
   case ISD::AVGCEILS:
diff --git a/llvm/test/CodeGen/AArch64/ucmp.ll b/llvm/test/CodeGen/AArch64/ucmp.ll
index 39a32194147eb..351d440243b70 100644
--- a/llvm/test/CodeGen/AArch64/ucmp.ll
+++ b/llvm/test/CodeGen/AArch64/ucmp.ll
@@ -93,3 +93,20 @@ define i64 @ucmp.64.64(i64 %x, i64 %y) nounwind {
   %1 = call i64 @llvm.ucmp(i64 %x, i64 %y)
   ret i64 %1
 }
+
+define <1 x i64> @ucmp.1.64.65(<1 x i65> %x, <1 x i65> %y) {
+; CHECK-LABEL: ucmp.1.64.65:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    and x8, x1, #0x1
+; CHECK-NEXT:    and x9, x3, #0x1
+; CHECK-NEXT:    cmp x2, x0
+; CHECK-NEXT:    sbcs xzr, x9, x8
+; CHECK-NEXT:    cset x10, lo
+; CHECK-NEXT:    cmp x0, x2
+; CHECK-NEXT:    sbcs xzr, x8, x9
+; CHECK-NEXT:    csinv x8, x10, xzr, hs
+; CHECK-NEXT:    fmov d0, x8
+; CHECK-NEXT:    ret
+  %1 = call <1 x i64> @llvm.ucmp(<1 x i65> %x, <1 x i65> %y)
+  ret <1 x i64> %1
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/98687

