[llvm] b7c2e57 - [RISCV] Add custom type legalization to form MULHSU when possible.
Craig Topper via llvm-commits
llvm-commits at lists.llvm.org
Thu Apr 1 10:16:18 PDT 2021
Author: Craig Topper
Date: 2021-04-01T10:15:55-07:00
New Revision: b7c2e577cc8f9f92b7ce206ea7d6cba3eaa3f98c
URL: https://github.com/llvm/llvm-project/commit/b7c2e577cc8f9f92b7ce206ea7d6cba3eaa3f98c
DIFF: https://github.com/llvm/llvm-project/commit/b7c2e577cc8f9f92b7ce206ea7d6cba3eaa3f98c.diff
LOG: [RISCV] Add custom type legalization to form MULHSU when possible.
There's no target-independent ISD opcode for MULHSU, so we custom
legalize 2*XLen multiplies ourselves. We have to be a little
careful to still prefer MULHU or MULH when either of those applies.
I thought about doing this in isel by pattern matching the
(add (mul X, (srai Y, XLen-1)), (mulhu X, Y)) pattern. I decided
against this because the add might become part of a chain of adds.
I don't trust DAG combine not to reassociate it with other adds, making
it difficult to find both pieces again.
Reviewed By: asb
Differential Revision: https://reviews.llvm.org/D99479
Added:
Modified:
llvm/lib/Target/RISCV/RISCVISelLowering.cpp
llvm/lib/Target/RISCV/RISCVISelLowering.h
llvm/lib/Target/RISCV/RISCVInstrInfoM.td
llvm/test/CodeGen/RISCV/mul.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 16a781751017..6cafa2791ed6 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -219,20 +219,23 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::UDIV, XLenVT, Expand);
setOperationAction(ISD::SREM, XLenVT, Expand);
setOperationAction(ISD::UREM, XLenVT, Expand);
- }
-
- if (Subtarget.is64Bit() && Subtarget.hasStdExtM()) {
- setOperationAction(ISD::MUL, MVT::i32, Custom);
-
- setOperationAction(ISD::SDIV, MVT::i8, Custom);
- setOperationAction(ISD::UDIV, MVT::i8, Custom);
- setOperationAction(ISD::UREM, MVT::i8, Custom);
- setOperationAction(ISD::SDIV, MVT::i16, Custom);
- setOperationAction(ISD::UDIV, MVT::i16, Custom);
- setOperationAction(ISD::UREM, MVT::i16, Custom);
- setOperationAction(ISD::SDIV, MVT::i32, Custom);
- setOperationAction(ISD::UDIV, MVT::i32, Custom);
- setOperationAction(ISD::UREM, MVT::i32, Custom);
+ } else {
+ if (Subtarget.is64Bit()) {
+ setOperationAction(ISD::MUL, MVT::i32, Custom);
+ setOperationAction(ISD::MUL, MVT::i128, Custom);
+
+ setOperationAction(ISD::SDIV, MVT::i8, Custom);
+ setOperationAction(ISD::UDIV, MVT::i8, Custom);
+ setOperationAction(ISD::UREM, MVT::i8, Custom);
+ setOperationAction(ISD::SDIV, MVT::i16, Custom);
+ setOperationAction(ISD::UDIV, MVT::i16, Custom);
+ setOperationAction(ISD::UREM, MVT::i16, Custom);
+ setOperationAction(ISD::SDIV, MVT::i32, Custom);
+ setOperationAction(ISD::UDIV, MVT::i32, Custom);
+ setOperationAction(ISD::UREM, MVT::i32, Custom);
+ } else {
+ setOperationAction(ISD::MUL, MVT::i64, Custom);
+ }
}
setOperationAction(ISD::SDIVREM, XLenVT, Expand);
@@ -3868,9 +3871,47 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
Results.push_back(RCW.getValue(2));
break;
}
+ case ISD::MUL: {
+ unsigned Size = N->getSimpleValueType(0).getSizeInBits();
+ unsigned XLen = Subtarget.getXLen();
+ // This multiply needs to be expanded, try to use MULHSU+MUL if possible.
+ if (Size > XLen) {
+ assert(Size == (XLen * 2) && "Unexpected custom legalisation");
+ SDValue LHS = N->getOperand(0);
+ SDValue RHS = N->getOperand(1);
+ APInt HighMask = APInt::getHighBitsSet(Size, XLen);
+
+ bool LHSIsU = DAG.MaskedValueIsZero(LHS, HighMask);
+ bool RHSIsU = DAG.MaskedValueIsZero(RHS, HighMask);
+ // We need exactly one side to be unsigned.
+ if (LHSIsU == RHSIsU)
+ return;
+
+ auto MakeMULPair = [&](SDValue S, SDValue U) {
+ MVT XLenVT = Subtarget.getXLenVT();
+ S = DAG.getNode(ISD::TRUNCATE, DL, XLenVT, S);
+ U = DAG.getNode(ISD::TRUNCATE, DL, XLenVT, U);
+ SDValue Lo = DAG.getNode(ISD::MUL, DL, XLenVT, S, U);
+ SDValue Hi = DAG.getNode(RISCVISD::MULHSU, DL, XLenVT, S, U);
+ return DAG.getNode(ISD::BUILD_PAIR, DL, N->getValueType(0), Lo, Hi);
+ };
+
+ bool LHSIsS = DAG.ComputeNumSignBits(LHS) > XLen;
+ bool RHSIsS = DAG.ComputeNumSignBits(RHS) > XLen;
+
+ // The other operand should be signed, but still prefer MULH when
+ // possible.
+ if (RHSIsU && LHSIsS && !RHSIsS)
+ Results.push_back(MakeMULPair(LHS, RHS));
+ else if (LHSIsU && RHSIsS && !LHSIsS)
+ Results.push_back(MakeMULPair(RHS, LHS));
+
+ return;
+ }
+ LLVM_FALLTHROUGH;
+ }
case ISD::ADD:
case ISD::SUB:
- case ISD::MUL:
assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() &&
"Unexpected custom legalisation");
if (N->getOperand(1).getOpcode() == ISD::Constant)
@@ -6784,6 +6825,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(BuildPairF64)
NODE_NAME_CASE(SplitF64)
NODE_NAME_CASE(TAIL)
+ NODE_NAME_CASE(MULHSU)
NODE_NAME_CASE(SLLW)
NODE_NAME_CASE(SRAW)
NODE_NAME_CASE(SRLW)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 20e96c625339..b17aa1527b79 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -40,6 +40,8 @@ enum NodeType : unsigned {
BuildPairF64,
SplitF64,
TAIL,
+ // Multiply high for signedxunsigned.
+ MULHSU,
// RV64I shifts, directly matching the semantics of the named RISC-V
// instructions.
SLLW,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoM.td b/llvm/lib/Target/RISCV/RISCVInstrInfoM.td
index e841d7fdea0b..f654ed1949a4 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoM.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoM.td
@@ -15,6 +15,7 @@
// RISC-V specific DAG Nodes.
//===----------------------------------------------------------------------===//
+def riscv_mulhsu : SDNode<"RISCVISD::MULHSU", SDTIntBinOp>;
def riscv_divw : SDNode<"RISCVISD::DIVW", SDT_RISCVIntBinOpW>;
def riscv_divuw : SDNode<"RISCVISD::DIVUW", SDT_RISCVIntBinOpW>;
def riscv_remuw : SDNode<"RISCVISD::REMUW", SDT_RISCVIntBinOpW>;
@@ -63,7 +64,7 @@ let Predicates = [HasStdExtM] in {
def : PatGprGpr<mul, MUL>;
def : PatGprGpr<mulhs, MULH>;
def : PatGprGpr<mulhu, MULHU>;
-// No ISDOpcode for mulhsu
+def : PatGprGpr<riscv_mulhsu, MULHSU>;
def : PatGprGpr<sdiv, DIV>;
def : PatGprGpr<udiv, DIVU>;
def : PatGprGpr<srem, REM>;
diff --git a/llvm/test/CodeGen/RISCV/mul.ll b/llvm/test/CodeGen/RISCV/mul.ll
index 00df918d6f63..2260233a4559 100644
--- a/llvm/test/CodeGen/RISCV/mul.ll
+++ b/llvm/test/CodeGen/RISCV/mul.ll
@@ -398,10 +398,7 @@ define i32 @mulhsu(i32 %a, i32 %b) nounwind {
;
; RV32IM-LABEL: mulhsu:
; RV32IM: # %bb.0:
-; RV32IM-NEXT: srai a2, a1, 31
-; RV32IM-NEXT: mulhu a1, a0, a1
-; RV32IM-NEXT: mul a0, a0, a2
-; RV32IM-NEXT: add a0, a1, a0
+; RV32IM-NEXT: mulhsu a0, a1, a0
; RV32IM-NEXT: ret
;
; RV64I-LABEL: mulhsu:
@@ -1423,10 +1420,7 @@ define i64 @mulhsu_i64(i64 %a, i64 %b) nounwind {
;
; RV64IM-LABEL: mulhsu_i64:
; RV64IM: # %bb.0:
-; RV64IM-NEXT: srai a2, a1, 63
-; RV64IM-NEXT: mulhu a1, a0, a1
-; RV64IM-NEXT: mul a0, a0, a2
-; RV64IM-NEXT: add a0, a1, a0
+; RV64IM-NEXT: mulhsu a0, a1, a0
; RV64IM-NEXT: ret
%1 = zext i64 %a to i128
%2 = sext i64 %b to i128
More information about the llvm-commits
mailing list