[llvm] 124c93c - [RISCV] When matching SROIW, check all 64 bits of the OR mask

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 16 10:08:52 PST 2020


Author: Craig Topper
Date: 2020-11-16T10:08:15-08:00
New Revision: 124c93c528758071fccfce68f6b633081a19c226

URL: https://github.com/llvm/llvm-project/commit/124c93c528758071fccfce68f6b633081a19c226
DIFF: https://github.com/llvm/llvm-project/commit/124c93c528758071fccfce68f6b633081a19c226.diff

LOG: [RISCV] When matching SROIW, check all 64 bits of the OR mask

We need to make sure the upper 32 bits are all ones to ensure the result is properly sign-extended. Previously, we only checked the lower 32 bits of the mask. I've also added a check that the shift amount is less than 32. Without that check, the original code asserts inside maskLeadingOnes if the SROI check is removed or the SROIW pattern is checked first. I've refactored the code to use early outs to reduce nesting.

I've also updated SLOIW matching with the same changes, but I couldn't find a broken test case with the existing code.

Differential Revision: https://reviews.llvm.org/D90961

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
    llvm/test/CodeGen/RISCV/rv64Zbb.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
index 54219e902d55..765775c03587 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -314,62 +314,66 @@ bool RISCVDAGToDAGISel::SelectSLLIUW(SDValue N, SDValue &RS1, SDValue &Shamt) {
 // and then we check that VC1, the mask used to fill with ones, is compatible
 // with VC2, the shamt:
 //
-//  VC1 == maskTrailingOnes<uint32_t>(VC2)
+//  VC2 < 32
+//  VC1 == maskTrailingOnes<uint64_t>(VC2)
 
 bool RISCVDAGToDAGISel::SelectSLOIW(SDValue N, SDValue &RS1, SDValue &Shamt) {
-  if (Subtarget->getXLenVT() == MVT::i64 &&
-      N.getOpcode() == ISD::SIGN_EXTEND_INREG &&
-      cast<VTSDNode>(N.getOperand(1))->getVT() == MVT::i32) {
-    if (N.getOperand(0).getOpcode() == ISD::OR) {
-      SDValue Or = N.getOperand(0);
-      if (Or.getOperand(0).getOpcode() == ISD::SHL) {
-        SDValue Shl = Or.getOperand(0);
-        if (isa<ConstantSDNode>(Shl.getOperand(1)) &&
-            isa<ConstantSDNode>(Or.getOperand(1))) {
-          uint32_t VC1 = Or.getConstantOperandVal(1);
-          uint32_t VC2 = Shl.getConstantOperandVal(1);
-          if (VC1 == maskTrailingOnes<uint32_t>(VC2)) {
-            RS1 = Shl.getOperand(0);
-            Shamt = CurDAG->getTargetConstant(VC2, SDLoc(N),
-                                              Shl.getOperand(1).getValueType());
-            return true;
-          }
-        }
-      }
-    }
-  }
-  return false;
+  assert(Subtarget->is64Bit() && "SLOIW should only be matched on RV64");
+  if (N.getOpcode() != ISD::SIGN_EXTEND_INREG ||
+      cast<VTSDNode>(N.getOperand(1))->getVT() != MVT::i32)
+    return false;
+
+   SDValue Or = N.getOperand(0);
+
+   if (Or.getOpcode() != ISD::OR || !isa<ConstantSDNode>(Or.getOperand(1)))
+     return false;
+
+   SDValue Shl = Or.getOperand(0);
+   if (Shl.getOpcode() != ISD::SHL || !isa<ConstantSDNode>(Shl.getOperand(1)))
+     return false;
+
+   uint64_t VC1 = Or.getConstantOperandVal(1);
+   uint64_t VC2 = Shl.getConstantOperandVal(1);
+
+   if (VC2 >= 32 || VC1 != maskTrailingOnes<uint64_t>(VC2))
+     return false;
+
+  RS1 = Shl.getOperand(0);
+  Shamt = CurDAG->getTargetConstant(VC2, SDLoc(N),
+                                    Shl.getOperand(1).getValueType());
+  return true;
 }
 
 // Check that it is a SROIW (Shift Right Ones Immediate i32 on RV64).
 // We first check that it is the right node tree:
 //
-//  (OR (SHL RS1, VC2), VC1)
+//  (OR (SRL RS1, VC2), VC1)
 //
 // and then we check that VC1, the mask used to fill with ones, is compatible
 // with VC2, the shamt:
 //
-//  VC1 == maskLeadingOnes<uint32_t>(VC2)
-
+//  VC2 < 32
+//  VC1 == maskTrailingZeros<uint64_t>(32 - VC2)
+//
 bool RISCVDAGToDAGISel::SelectSROIW(SDValue N, SDValue &RS1, SDValue &Shamt) {
-  if (N.getOpcode() == ISD::OR && Subtarget->getXLenVT() == MVT::i64) {
-    SDValue Or = N;
-    if (Or.getOperand(0).getOpcode() == ISD::SRL) {
-      SDValue Srl = Or.getOperand(0);
-      if (isa<ConstantSDNode>(Srl.getOperand(1)) &&
-          isa<ConstantSDNode>(Or.getOperand(1))) {
-        uint32_t VC1 = Or.getConstantOperandVal(1);
-        uint32_t VC2 = Srl.getConstantOperandVal(1);
-        if (VC1 == maskLeadingOnes<uint32_t>(VC2)) {
-          RS1 = Srl.getOperand(0);
-          Shamt = CurDAG->getTargetConstant(VC2, SDLoc(N),
-                                            Srl.getOperand(1).getValueType());
-          return true;
-        }
-      }
-    }
-  }
-  return false;
+  assert(Subtarget->is64Bit() && "SROIW should only be matched on RV64");
+  if (N.getOpcode() != ISD::OR || !isa<ConstantSDNode>(N.getOperand(1)))
+    return false;
+
+  SDValue Srl = N.getOperand(0);
+  if (Srl.getOpcode() != ISD::SRL || !isa<ConstantSDNode>(Srl.getOperand(1)))
+    return false;
+
+  uint64_t VC1 = N.getConstantOperandVal(1);
+  uint64_t VC2 = Srl.getConstantOperandVal(1);
+
+  if (VC2 >= 32 || VC1 != maskTrailingZeros<uint64_t>(32 - VC2))
+    return false;
+
+  RS1 = Srl.getOperand(0);
+  Shamt = CurDAG->getTargetConstant(VC2, SDLoc(N),
+                                    Srl.getOperand(1).getValueType());
+  return true;
 }
 
 // Check that it is a RORIW (i32 Right Rotate Immediate on RV64).

diff  --git a/llvm/test/CodeGen/RISCV/rv64Zbb.ll b/llvm/test/CodeGen/RISCV/rv64Zbb.ll
index a1d0b8a74b26..66985c565370 100644
--- a/llvm/test/CodeGen/RISCV/rv64Zbb.ll
+++ b/llvm/test/CodeGen/RISCV/rv64Zbb.ll
@@ -166,7 +166,6 @@ define signext i32 @sroi_i32(i32 signext %a) nounwind {
 ; This is similar to the type legalized version of sroiw but the mask is 0 in
 ; the upper bits instead of 1 so the result is not sign extended. Make sure we
 ; don't match it to sroiw.
-; FIXME: We're matching it to sroiw.
 define i64 @sroiw_bug(i64 %a) nounwind {
 ; RV64I-LABEL: sroiw_bug:
 ; RV64I:       # %bb.0:
@@ -178,12 +177,18 @@ define i64 @sroiw_bug(i64 %a) nounwind {
 ;
 ; RV64IB-LABEL: sroiw_bug:
 ; RV64IB:       # %bb.0:
-; RV64IB-NEXT:    sroiw a0, a0, 1
+; RV64IB-NEXT:    srli a0, a0, 1
+; RV64IB-NEXT:    addi a1, zero, 1
+; RV64IB-NEXT:    slli a1, a1, 31
+; RV64IB-NEXT:    or a0, a0, a1
 ; RV64IB-NEXT:    ret
 ;
 ; RV64IBB-LABEL: sroiw_bug:
 ; RV64IBB:       # %bb.0:
-; RV64IBB-NEXT:    sroiw a0, a0, 1
+; RV64IBB-NEXT:    srli a0, a0, 1
+; RV64IBB-NEXT:    addi a1, zero, 1
+; RV64IBB-NEXT:    slli a1, a1, 31
+; RV64IBB-NEXT:    or a0, a0, a1
 ; RV64IBB-NEXT:    ret
   %neg = lshr i64 %a, 1
   %neg12 = or i64 %neg, 2147483648


        


More information about the llvm-commits mailing list