[llvm] [LegalizeTypes][RISCV] Use signed promotion for UADDSAT if that's what the target prefers. (PR #102842)

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Sun Aug 11 20:00:56 PDT 2024


https://github.com/topperc created https://github.com/llvm/llvm-project/pull/102842

As noted in #102781, we can promote UADDSAT by keeping the UADDSAT opcode at the promoted type, provided we sign-extend the operands instead of zero-extending them.

The custom RISC-V handler was already using SIGN_EXTEND when the Zbb extension was enabled. With this change, that Zbb-specific custom code is no longer needed; without Zbb, RISC-V still custom-legalizes i32 UADDSAT.

From ac39f401c44e4080e8980bdb7f91c29ea64cae56 Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Sun, 11 Aug 2024 19:12:05 -0700
Subject: [PATCH] [LegalizeTypes][RISCV] Use signed promotion for UADDSAT if
 that's what the target prefers.

---
 .../SelectionDAG/LegalizeIntegerTypes.cpp     | 30 ++++++++++++-------
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp   | 23 ++++----------
 llvm/test/CodeGen/RISCV/uadd_sat_plus.ll      |  2 +-
 3 files changed, 25 insertions(+), 30 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 3635bc7a965804..174b3820693f5f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -1056,6 +1056,25 @@ SDValue DAGTypeLegalizer::PromoteIntRes_ADDSUBSHLSAT(SDNode *N) {
     return matcher.getNode(ISD::USUBSAT, dl, Op1.getValueType(), Op1, Op2);
   }
 
+  if (Opcode == ISD::UADDSAT) {
+    EVT OVT = Op1.getValueType();
+    EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), OVT);
+    // We can promote if we use sign-extend. Do this if the target prefers.
+    if (TLI.isSExtCheaperThanZExt(OVT, NVT)) {
+      Op1 = SExtPromotedInteger(Op1);
+      Op2 = SExtPromotedInteger(Op2);
+      return matcher.getNode(ISD::UADDSAT, dl, NVT, Op1, Op2);
+    }
+
+    Op1 = ZExtPromotedInteger(Op1);
+    Op2 = ZExtPromotedInteger(Op2);
+    unsigned NewBits = NVT.getScalarSizeInBits();
+    APInt MaxVal = APInt::getLowBitsSet(NewBits, OldBits);
+    SDValue SatMax = DAG.getConstant(MaxVal, dl, NVT);
+    SDValue Add = matcher.getNode(ISD::ADD, dl, NVT, Op1, Op2);
+    return matcher.getNode(ISD::UMIN, dl, NVT, Add, SatMax);
+  }
+
   bool IsShift = Opcode == ISD::USHLSAT || Opcode == ISD::SSHLSAT;
 
   // FIXME: We need vp-aware PromotedInteger functions.
@@ -1063,9 +1082,6 @@ SDValue DAGTypeLegalizer::PromoteIntRes_ADDSUBSHLSAT(SDNode *N) {
   if (IsShift) {
     Op1Promoted = GetPromotedInteger(Op1);
     Op2Promoted = ZExtPromotedInteger(Op2);
-  } else if (Opcode == ISD::UADDSAT) {
-    Op1Promoted = ZExtPromotedInteger(Op1);
-    Op2Promoted = ZExtPromotedInteger(Op2);
   } else {
     Op1Promoted = SExtPromotedInteger(Op1);
     Op2Promoted = SExtPromotedInteger(Op2);
@@ -1073,14 +1089,6 @@ SDValue DAGTypeLegalizer::PromoteIntRes_ADDSUBSHLSAT(SDNode *N) {
   EVT PromotedType = Op1Promoted.getValueType();
   unsigned NewBits = PromotedType.getScalarSizeInBits();
 
-  if (Opcode == ISD::UADDSAT) {
-    APInt MaxVal = APInt::getLowBitsSet(NewBits, OldBits);
-    SDValue SatMax = DAG.getConstant(MaxVal, dl, PromotedType);
-    SDValue Add =
-        matcher.getNode(ISD::ADD, dl, PromotedType, Op1Promoted, Op2Promoted);
-    return matcher.getNode(ISD::UMIN, dl, PromotedType, Add, SatMax);
-  }
-
   // Shift cannot use a min/max expansion, we can't detect overflow if all of
   // the bits have been shifted out.
   if (IsShift || matcher.isOperationLegal(Opcode, PromotedType)) {
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 872c288f5ad8cd..732d4fed036d8a 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -268,11 +268,11 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
     setOperationAction(ISD::LOAD, MVT::i32, Custom);
     setOperationAction({ISD::ADD, ISD::SUB, ISD::SHL, ISD::SRA, ISD::SRL},
                        MVT::i32, Custom);
-    setOperationAction({ISD::UADDO, ISD::USUBO, ISD::UADDSAT}, MVT::i32,
-                       Custom);
+    setOperationAction({ISD::UADDO, ISD::USUBO}, MVT::i32, Custom);
     if (!Subtarget.hasStdExtZbb())
-      setOperationAction({ISD::SADDSAT, ISD::SSUBSAT, ISD::USUBSAT}, MVT::i32,
-                         Custom);
+      setOperationAction(
+          {ISD::SADDSAT, ISD::SSUBSAT, ISD::UADDSAT, ISD::USUBSAT}, MVT::i32,
+          Custom);
     setOperationAction(ISD::SADDO, MVT::i32, Custom);
   }
   if (!Subtarget.hasStdExtZmmul()) {
@@ -12173,20 +12173,7 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
   case ISD::UADDSAT:
   case ISD::USUBSAT: {
     assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() &&
-           "Unexpected custom legalisation");
-    if (Subtarget.hasStdExtZbb()) {
-      // With Zbb we can sign extend and let LegalizeDAG use minu/maxu. Using
-      // sign extend allows overflow of the lower 32 bits to be detected on
-      // the promoted size.
-      SDValue LHS =
-          DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i64, N->getOperand(0));
-      SDValue RHS =
-          DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i64, N->getOperand(1));
-      SDValue Res = DAG.getNode(N->getOpcode(), DL, MVT::i64, LHS, RHS);
-      Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Res));
-      return;
-    }
-
+           !Subtarget.hasStdExtZbb() && "Unexpected custom legalisation");
     // Without Zbb, expand to UADDO/USUBO+select which will trigger our custom
     // promotion for UADDO/USUBO.
     Results.push_back(expandAddSubSat(N, DAG));
diff --git a/llvm/test/CodeGen/RISCV/uadd_sat_plus.ll b/llvm/test/CodeGen/RISCV/uadd_sat_plus.ll
index 219f0daf270bc1..23875a7ec56211 100644
--- a/llvm/test/CodeGen/RISCV/uadd_sat_plus.ll
+++ b/llvm/test/CodeGen/RISCV/uadd_sat_plus.ll
@@ -40,9 +40,9 @@ define i32 @func32(i32 %x, i32 %y, i32 %z) nounwind {
 ;
 ; RV64IZbb-LABEL: func32:
 ; RV64IZbb:       # %bb.0:
+; RV64IZbb-NEXT:    sext.w a0, a0
 ; RV64IZbb-NEXT:    mulw a1, a1, a2
 ; RV64IZbb-NEXT:    not a2, a1
-; RV64IZbb-NEXT:    sext.w a0, a0
 ; RV64IZbb-NEXT:    minu a0, a0, a2
 ; RV64IZbb-NEXT:    add a0, a0, a1
 ; RV64IZbb-NEXT:    ret



More information about the llvm-commits mailing list