[llvm] b7c2e57 - [RISCV] Add custom type legalization to form MULHSU when possible.

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 1 10:16:18 PDT 2021


Author: Craig Topper
Date: 2021-04-01T10:15:55-07:00
New Revision: b7c2e577cc8f9f92b7ce206ea7d6cba3eaa3f98c

URL: https://github.com/llvm/llvm-project/commit/b7c2e577cc8f9f92b7ce206ea7d6cba3eaa3f98c
DIFF: https://github.com/llvm/llvm-project/commit/b7c2e577cc8f9f92b7ce206ea7d6cba3eaa3f98c.diff

LOG: [RISCV] Add custom type legalization to form MULHSU when possible.

There's no target independent ISD opcode for MULHSU, so we custom
legalize 2*XLen multiplies ourselves. We have to be a little
careful to still prefer MULHU or MULH over MULHSU when they apply.

I thought about doing this in isel by pattern matching the
(add (mul X, (srai Y, XLen-1)), (mulhu X, Y)) pattern. I decided
against this because the add might become part of a chain of adds.
I don't trust DAG combine not to reassociate with other adds making
it difficult to find both pieces again.

Reviewed By: asb

Differential Revision: https://reviews.llvm.org/D99479

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/lib/Target/RISCV/RISCVISelLowering.h
    llvm/lib/Target/RISCV/RISCVInstrInfoM.td
    llvm/test/CodeGen/RISCV/mul.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 16a781751017..6cafa2791ed6 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -219,20 +219,23 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
     setOperationAction(ISD::UDIV, XLenVT, Expand);
     setOperationAction(ISD::SREM, XLenVT, Expand);
     setOperationAction(ISD::UREM, XLenVT, Expand);
-  }
-
-  if (Subtarget.is64Bit() && Subtarget.hasStdExtM()) {
-    setOperationAction(ISD::MUL, MVT::i32, Custom);
-
-    setOperationAction(ISD::SDIV, MVT::i8, Custom);
-    setOperationAction(ISD::UDIV, MVT::i8, Custom);
-    setOperationAction(ISD::UREM, MVT::i8, Custom);
-    setOperationAction(ISD::SDIV, MVT::i16, Custom);
-    setOperationAction(ISD::UDIV, MVT::i16, Custom);
-    setOperationAction(ISD::UREM, MVT::i16, Custom);
-    setOperationAction(ISD::SDIV, MVT::i32, Custom);
-    setOperationAction(ISD::UDIV, MVT::i32, Custom);
-    setOperationAction(ISD::UREM, MVT::i32, Custom);
+  } else {
+    if (Subtarget.is64Bit()) {
+      setOperationAction(ISD::MUL, MVT::i32, Custom);
+      setOperationAction(ISD::MUL, MVT::i128, Custom);
+
+      setOperationAction(ISD::SDIV, MVT::i8, Custom);
+      setOperationAction(ISD::UDIV, MVT::i8, Custom);
+      setOperationAction(ISD::UREM, MVT::i8, Custom);
+      setOperationAction(ISD::SDIV, MVT::i16, Custom);
+      setOperationAction(ISD::UDIV, MVT::i16, Custom);
+      setOperationAction(ISD::UREM, MVT::i16, Custom);
+      setOperationAction(ISD::SDIV, MVT::i32, Custom);
+      setOperationAction(ISD::UDIV, MVT::i32, Custom);
+      setOperationAction(ISD::UREM, MVT::i32, Custom);
+    } else {
+      setOperationAction(ISD::MUL, MVT::i64, Custom);
+    }
   }
 
   setOperationAction(ISD::SDIVREM, XLenVT, Expand);
@@ -3868,9 +3871,47 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
     Results.push_back(RCW.getValue(2));
     break;
   }
+  case ISD::MUL: {
+    unsigned Size = N->getSimpleValueType(0).getSizeInBits();
+    unsigned XLen = Subtarget.getXLen();
+    // This multiply needs to be expanded, try to use MULHSU+MUL if possible.
+    if (Size > XLen) {
+      assert(Size == (XLen * 2) && "Unexpected custom legalisation");
+      SDValue LHS = N->getOperand(0);
+      SDValue RHS = N->getOperand(1);
+      APInt HighMask = APInt::getHighBitsSet(Size, XLen);
+
+      bool LHSIsU = DAG.MaskedValueIsZero(LHS, HighMask);
+      bool RHSIsU = DAG.MaskedValueIsZero(RHS, HighMask);
+      // We need exactly one side to be unsigned.
+      if (LHSIsU == RHSIsU)
+        return;
+
+      auto MakeMULPair = [&](SDValue S, SDValue U) {
+        MVT XLenVT = Subtarget.getXLenVT();
+        S = DAG.getNode(ISD::TRUNCATE, DL, XLenVT, S);
+        U = DAG.getNode(ISD::TRUNCATE, DL, XLenVT, U);
+        SDValue Lo = DAG.getNode(ISD::MUL, DL, XLenVT, S, U);
+        SDValue Hi = DAG.getNode(RISCVISD::MULHSU, DL, XLenVT, S, U);
+        return DAG.getNode(ISD::BUILD_PAIR, DL, N->getValueType(0), Lo, Hi);
+      };
+
+      bool LHSIsS = DAG.ComputeNumSignBits(LHS) > XLen;
+      bool RHSIsS = DAG.ComputeNumSignBits(RHS) > XLen;
+
+      // The other operand should be signed, but still prefer MULH when
+      // possible.
+      if (RHSIsU && LHSIsS && !RHSIsS)
+        Results.push_back(MakeMULPair(LHS, RHS));
+      else if (LHSIsU && RHSIsS && !LHSIsS)
+        Results.push_back(MakeMULPair(RHS, LHS));
+
+      return;
+    }
+    LLVM_FALLTHROUGH;
+  }
   case ISD::ADD:
   case ISD::SUB:
-  case ISD::MUL:
     assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() &&
            "Unexpected custom legalisation");
     if (N->getOperand(1).getOpcode() == ISD::Constant)
@@ -6784,6 +6825,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(BuildPairF64)
   NODE_NAME_CASE(SplitF64)
   NODE_NAME_CASE(TAIL)
+  NODE_NAME_CASE(MULHSU)
   NODE_NAME_CASE(SLLW)
   NODE_NAME_CASE(SRAW)
   NODE_NAME_CASE(SRLW)

diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 20e96c625339..b17aa1527b79 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -40,6 +40,8 @@ enum NodeType : unsigned {
   BuildPairF64,
   SplitF64,
   TAIL,
+  // Multiply high for signedxunsigned.
+  MULHSU,
   // RV64I shifts, directly matching the semantics of the named RISC-V
   // instructions.
   SLLW,

diff  --git a/llvm/lib/Target/RISCV/RISCVInstrInfoM.td b/llvm/lib/Target/RISCV/RISCVInstrInfoM.td
index e841d7fdea0b..f654ed1949a4 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoM.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoM.td
@@ -15,6 +15,7 @@
 // RISC-V specific DAG Nodes.
 //===----------------------------------------------------------------------===//
 
+def riscv_mulhsu : SDNode<"RISCVISD::MULHSU", SDTIntBinOp>;
 def riscv_divw  : SDNode<"RISCVISD::DIVW",  SDT_RISCVIntBinOpW>;
 def riscv_divuw : SDNode<"RISCVISD::DIVUW", SDT_RISCVIntBinOpW>;
 def riscv_remuw : SDNode<"RISCVISD::REMUW", SDT_RISCVIntBinOpW>;
@@ -63,7 +64,7 @@ let Predicates = [HasStdExtM] in {
 def : PatGprGpr<mul, MUL>;
 def : PatGprGpr<mulhs, MULH>;
 def : PatGprGpr<mulhu, MULHU>;
-// No ISDOpcode for mulhsu
+def : PatGprGpr<riscv_mulhsu, MULHSU>;
 def : PatGprGpr<sdiv, DIV>;
 def : PatGprGpr<udiv, DIVU>;
 def : PatGprGpr<srem, REM>;

diff  --git a/llvm/test/CodeGen/RISCV/mul.ll b/llvm/test/CodeGen/RISCV/mul.ll
index 00df918d6f63..2260233a4559 100644
--- a/llvm/test/CodeGen/RISCV/mul.ll
+++ b/llvm/test/CodeGen/RISCV/mul.ll
@@ -398,10 +398,7 @@ define i32 @mulhsu(i32 %a, i32 %b) nounwind {
 ;
 ; RV32IM-LABEL: mulhsu:
 ; RV32IM:       # %bb.0:
-; RV32IM-NEXT:    srai a2, a1, 31
-; RV32IM-NEXT:    mulhu a1, a0, a1
-; RV32IM-NEXT:    mul a0, a0, a2
-; RV32IM-NEXT:    add a0, a1, a0
+; RV32IM-NEXT:    mulhsu a0, a1, a0
 ; RV32IM-NEXT:    ret
 ;
 ; RV64I-LABEL: mulhsu:
@@ -1423,10 +1420,7 @@ define i64 @mulhsu_i64(i64 %a, i64 %b) nounwind {
 ;
 ; RV64IM-LABEL: mulhsu_i64:
 ; RV64IM:       # %bb.0:
-; RV64IM-NEXT:    srai a2, a1, 63
-; RV64IM-NEXT:    mulhu a1, a0, a1
-; RV64IM-NEXT:    mul a0, a0, a2
-; RV64IM-NEXT:    add a0, a1, a0
+; RV64IM-NEXT:    mulhsu a0, a1, a0
 ; RV64IM-NEXT:    ret
   %1 = zext i64 %a to i128
   %2 = sext i64 %b to i128


        


More information about the llvm-commits mailing list