[llvm] e9d4e34 - [AArch64][SVE] Add legalization support for i32/i64 vector srem/urem

Eli Friedman via llvm-commits <llvm-commits at lists.llvm.org>
Tue Jun 23 16:28:09 PDT 2020


Author: Eli Friedman
Date: 2020-06-23T16:27:52-07:00
New Revision: e9d4e34ab8a4223de41fbf1881fd6a531880dda9

URL: https://github.com/llvm/llvm-project/commit/e9d4e34ab8a4223de41fbf1881fd6a531880dda9
DIFF: https://github.com/llvm/llvm-project/commit/e9d4e34ab8a4223de41fbf1881fd6a531880dda9.diff

LOG: [AArch64][SVE] Add legalization support for i32/i64 vector srem/urem

Implement srem/urem on top of sdiv/udiv, similar to what we already do for
scalar integer types.

Potential future work: implementing i8/i16 srem/urem, adding optimizations for
constant divisors, and folding the mul+sub into an mls.

Differential Revision: https://reviews.llvm.org/D81511
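
The expansion relies on the identity X % Y == X - (X / Y) * Y for any nonzero
divisor, which maps directly onto the SDIV/UDIV + MUL + SUB node sequence built
below. A minimal standalone C++ sketch of that identity (illustrative only, not
part of the patch; the function names are made up):

    #include <cassert>
    #include <cstdint>

    // Signed remainder built from division: SDIV, then MUL, then SUB.
    int32_t srem_via_sdiv(int32_t X, int32_t Y) {
      int32_t Quot = X / Y; // truncates toward zero, like ISD::SDIV
      return X - Quot * Y;
    }

    // Unsigned remainder built the same way: UDIV, then MUL, then SUB.
    uint32_t urem_via_udiv(uint32_t X, uint32_t Y) {
      uint32_t Quot = X / Y;
      return X - Quot * Y;
    }

    int main() {
      assert(srem_via_sdiv(-7, 3) == -7 % 3);   // both are -1
      assert(urem_via_udiv(7u, 3u) == 7u % 3u); // both are 1
      return 0;
    }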

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/TargetLowering.h
    llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
    llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
    llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
    llvm/lib/Target/ARM/ARMISelLowering.cpp
    llvm/test/CodeGen/AArch64/llvm-ir-to-intrinsic.ll

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 4a3aeae7f08a..f920d3016a95 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -4421,6 +4421,10 @@ class TargetLowering : public TargetLoweringBase {
   /// only the first Count elements of the vector are used.
   SDValue expandVecReduce(SDNode *Node, SelectionDAG &DAG) const;
 
+  /// Expand an SREM or UREM using SDIV/UDIV or SDIVREM/UDIVREM, if legal.
+  /// Returns true if the expansion was successful.
+  bool expandREM(SDNode *Node, SDValue &Result, SelectionDAG &DAG) const;
+
   //===--------------------------------------------------------------------===//
   // Instruction Emitting Hooks
   //

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 9969786d8d43..fadac7bb41c4 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -3343,26 +3343,10 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
     break;
   }
   case ISD::UREM:
-  case ISD::SREM: {
-    EVT VT = Node->getValueType(0);
-    bool isSigned = Node->getOpcode() == ISD::SREM;
-    unsigned DivOpc = isSigned ? ISD::SDIV : ISD::UDIV;
-    unsigned DivRemOpc = isSigned ? ISD::SDIVREM : ISD::UDIVREM;
-    Tmp2 = Node->getOperand(0);
-    Tmp3 = Node->getOperand(1);
-    if (TLI.isOperationLegalOrCustom(DivRemOpc, VT)) {
-      SDVTList VTs = DAG.getVTList(VT, VT);
-      Tmp1 = DAG.getNode(DivRemOpc, dl, VTs, Tmp2, Tmp3).getValue(1);
-      Results.push_back(Tmp1);
-    } else if (TLI.isOperationLegalOrCustom(DivOpc, VT)) {
-      // X % Y -> X-X/Y*Y
-      Tmp1 = DAG.getNode(DivOpc, dl, VT, Tmp2, Tmp3);
-      Tmp1 = DAG.getNode(ISD::MUL, dl, VT, Tmp1, Tmp3);
-      Tmp1 = DAG.getNode(ISD::SUB, dl, VT, Tmp2, Tmp1);
+  case ISD::SREM:
+    if (TLI.expandREM(Node, Tmp1, DAG))
       Results.push_back(Tmp1);
-    }
     break;
-  }
   case ISD::UDIV:
   case ISD::SDIV: {
     bool isSigned = Node->getOpcode() == ISD::SDIV;

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 93ce338ff232..6409f924920d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -145,6 +145,7 @@ class VectorLegalizer {
   void ExpandFixedPointDiv(SDNode *Node, SmallVectorImpl<SDValue> &Results);
   SDValue ExpandStrictFPOp(SDNode *Node);
   void ExpandStrictFPOp(SDNode *Node, SmallVectorImpl<SDValue> &Results);
+  void ExpandREM(SDNode *Node, SmallVectorImpl<SDValue> &Results);
 
   void UnrollStrictFPOp(SDNode *Node, SmallVectorImpl<SDValue> &Results);
 
@@ -867,6 +868,10 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
   case ISD::VECREDUCE_FMIN:
     Results.push_back(TLI.expandVecReduce(Node, DAG));
     return;
+  case ISD::SREM:
+  case ISD::UREM:
+    ExpandREM(Node, Results);
+    return;
   }
 
   Results.push_back(DAG.UnrollVectorOp(Node));
@@ -1353,6 +1358,17 @@ void VectorLegalizer::ExpandStrictFPOp(SDNode *Node,
   UnrollStrictFPOp(Node, Results);
 }
 
+void VectorLegalizer::ExpandREM(SDNode *Node,
+                                SmallVectorImpl<SDValue> &Results) {
+  assert((Node->getOpcode() == ISD::SREM || Node->getOpcode() == ISD::UREM) &&
+         "Expected REM node");
+
+  SDValue Result;
+  if (!TLI.expandREM(Node, Result, DAG))
+    Result = DAG.UnrollVectorOp(Node);
+  Results.push_back(Result);
+}
+
 void VectorLegalizer::UnrollStrictFPOp(SDNode *Node,
                                        SmallVectorImpl<SDValue> &Results) {
   EVT VT = Node->getValueType(0);

diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 88faa9b13c9e..fab77fb93687 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -7823,3 +7823,26 @@ SDValue TargetLowering::expandVecReduce(SDNode *Node, SelectionDAG &DAG) const {
     Res = DAG.getNode(ISD::ANY_EXTEND, dl, Node->getValueType(0), Res);
   return Res;
 }
+
+bool TargetLowering::expandREM(SDNode *Node, SDValue &Result,
+                               SelectionDAG &DAG) const {
+  EVT VT = Node->getValueType(0);
+  SDLoc dl(Node);
+  bool isSigned = Node->getOpcode() == ISD::SREM;
+  unsigned DivOpc = isSigned ? ISD::SDIV : ISD::UDIV;
+  unsigned DivRemOpc = isSigned ? ISD::SDIVREM : ISD::UDIVREM;
+  SDValue Dividend = Node->getOperand(0);
+  SDValue Divisor = Node->getOperand(1);
+  if (isOperationLegalOrCustom(DivRemOpc, VT)) {
+    SDVTList VTs = DAG.getVTList(VT, VT);
+    Result = DAG.getNode(DivRemOpc, dl, VTs, Dividend, Divisor).getValue(1);
+    return true;
+  } else if (isOperationLegalOrCustom(DivOpc, VT)) {
+    // X % Y -> X-X/Y*Y
+    SDValue Divide = DAG.getNode(DivOpc, dl, VT, Dividend, Divisor);
+    SDValue Mul = DAG.getNode(ISD::MUL, dl, VT, Divide, Divisor);
+    Result = DAG.getNode(ISD::SUB, dl, VT, Dividend, Mul);
+    return true;
+  }
+  return false;
+}

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 94df3de50dbd..26a100e4886a 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -199,6 +199,10 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::UADDSAT, VT, Legal);
       setOperationAction(ISD::SSUBSAT, VT, Legal);
       setOperationAction(ISD::USUBSAT, VT, Legal);
+      setOperationAction(ISD::UREM, VT, Expand);
+      setOperationAction(ISD::SREM, VT, Expand);
+      setOperationAction(ISD::SDIVREM, VT, Expand);
+      setOperationAction(ISD::UDIVREM, VT, Expand);
     }
 
     for (auto VT :

diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
index d856649092b5..8762b71cbf1a 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -443,7 +443,7 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM,
     setOperationAction(ISD::UREM, VT, Expand);
     setOperationAction(ISD::SMUL_LOHI, VT, Expand);
     setOperationAction(ISD::UMUL_LOHI, VT, Expand);
-    setOperationAction(ISD::SDIVREM, VT, Custom);
+    setOperationAction(ISD::SDIVREM, VT, Expand);
     setOperationAction(ISD::UDIVREM, VT, Expand);
     setOperationAction(ISD::SELECT, VT, Expand);
     setOperationAction(ISD::VSELECT, VT, Expand);

diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
index a646f63e0839..65bd90cb5bac 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -210,6 +210,8 @@ void ARMTargetLowering::addTypeForNEON(MVT VT, MVT PromotedLdStVT,
   setOperationAction(ISD::SREM, VT, Expand);
   setOperationAction(ISD::UREM, VT, Expand);
   setOperationAction(ISD::FREM, VT, Expand);
+  setOperationAction(ISD::SDIVREM, VT, Expand);
+  setOperationAction(ISD::UDIVREM, VT, Expand);
 
   if (!VT.isFloatingPoint() &&
       VT != MVT::v2i64 && VT != MVT::v1i64)
@@ -284,6 +286,8 @@ void ARMTargetLowering::addMVEVectorTypes(bool HasMVEFP) {
     setOperationAction(ISD::SDIV, VT, Expand);
     setOperationAction(ISD::UREM, VT, Expand);
     setOperationAction(ISD::SREM, VT, Expand);
+    setOperationAction(ISD::UDIVREM, VT, Expand);
+    setOperationAction(ISD::SDIVREM, VT, Expand);
     setOperationAction(ISD::CTPOP, VT, Expand);
 
     // Vector reductions

diff --git a/llvm/test/CodeGen/AArch64/llvm-ir-to-intrinsic.ll b/llvm/test/CodeGen/AArch64/llvm-ir-to-intrinsic.ll
index d1d940734f55..bc4778d66004 100644
--- a/llvm/test/CodeGen/AArch64/llvm-ir-to-intrinsic.ll
+++ b/llvm/test/CodeGen/AArch64/llvm-ir-to-intrinsic.ll
@@ -59,6 +59,36 @@ define <vscale x 4 x i64> @sdiv_split_i64(<vscale x 4 x i64> %a, <vscale x 4 x i
   ret <vscale x 4 x i64> %div
 }
 
+;
+; SREM
+;
+
+define <vscale x 4 x i32> @srem_i32(<vscale x 4 x i32> %a, <vscale x 4 x i32> %b) {
+; CHECK-LABEL: srem_i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    mov z2.d, z0.d
+; CHECK-NEXT:    sdiv z2.s, p0/m, z2.s, z1.s
+; CHECK-NEXT:    mul z2.s, p0/m, z2.s, z1.s
+; CHECK-NEXT:    sub z0.s, z0.s, z2.s
+; CHECK-NEXT:    ret
+  %div = srem <vscale x 4 x i32> %a, %b
+  ret <vscale x 4 x i32> %div
+}
+
+define <vscale x 2 x i64> @srem_i64(<vscale x 2 x i64> %a, <vscale x 2 x i64> %b) {
+; CHECK-LABEL: srem_i64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    mov z2.d, z0.d
+; CHECK-NEXT:    sdiv z2.d, p0/m, z2.d, z1.d
+; CHECK-NEXT:    mul z2.d, p0/m, z2.d, z1.d
+; CHECK-NEXT:    sub z0.d, z0.d, z2.d
+; CHECK-NEXT:    ret
+  %div = srem <vscale x 2 x i64> %a, %b
+  ret <vscale x 2 x i64> %div
+}
+
 ;
 ; UDIV
 ;
@@ -117,6 +147,37 @@ define <vscale x 4 x i64> @udiv_split_i64(<vscale x 4 x i64> %a, <vscale x 4 x i
   ret <vscale x 4 x i64> %div
 }
 
+
+;
+; UREM
+;
+
+define <vscale x 4 x i32> @urem_i32(<vscale x 4 x i32> %a, <vscale x 4 x i32> %b) {
+; CHECK-LABEL: urem_i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    mov z2.d, z0.d
+; CHECK-NEXT:    udiv z2.s, p0/m, z2.s, z1.s
+; CHECK-NEXT:    mul z2.s, p0/m, z2.s, z1.s
+; CHECK-NEXT:    sub z0.s, z0.s, z2.s
+; CHECK-NEXT:    ret
+  %div = urem <vscale x 4 x i32> %a, %b
+  ret <vscale x 4 x i32> %div
+}
+
+define <vscale x 2 x i64> @urem_i64(<vscale x 2 x i64> %a, <vscale x 2 x i64> %b) {
+; CHECK-LABEL: urem_i64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    mov z2.d, z0.d
+; CHECK-NEXT:    udiv z2.d, p0/m, z2.d, z1.d
+; CHECK-NEXT:    mul z2.d, p0/m, z2.d, z1.d
+; CHECK-NEXT:    sub z0.d, z0.d, z2.d
+; CHECK-NEXT:    ret
+  %div = urem <vscale x 2 x i64> %a, %b
+  ret <vscale x 2 x i64> %div
+}
+
 ;
 ; SMIN
 ;


        

