[llvm] [RISCV] Check that the stack adjust immediate for cm.push/pop* has the correct sign and is divisible by 16. (PR #85295)

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Tue Mar 26 10:40:03 PDT 2024


https://github.com/topperc updated https://github.com/llvm/llvm-project/pull/85295

>From 6ff27c1f1e49b81f7b7d0d3d25544e8ac5999364 Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Thu, 14 Mar 2024 11:47:51 -0700
Subject: [PATCH 1/3] [RISCV] Check that the stack adjust immediate for
 cm.push/pop* has the correct sign and is divisible by 16.

To do this I've added a new AsmOperand for cm.push to expect a
negative value. We also use that to customize the print function
so that we don't need to detect cm.push opcode to add the negative
sign.

I've renamed some places that used Spimm to be StackAdj since that's
what is being parsed. I'm still not sure about where we should use Spimm
or StackAdj.

I've removed the printSpimm helper function which in one usage printed
the sp[5:4]<<4 value and the other usage printed the full stack
adjustment. There wasn't anything interesting about how it was printed;
it just passed the value to the raw_ostream. If there was something
special needed, it's unclear whether it would be the same for the two
different usages so I inlined it.

I realize this patch may have snowballed a bit with the renames. I
can split it up if reviewers prefer.

One open question is whether we need to support stack adjustments
expressed as an expression rather than a literal integer.
---
 .../Target/RISCV/AsmParser/RISCVAsmParser.cpp | 19 ++++---
 .../RISCV/MCTargetDesc/RISCVBaseInfo.cpp      |  2 -
 .../Target/RISCV/MCTargetDesc/RISCVBaseInfo.h |  8 +--
 .../RISCV/MCTargetDesc/RISCVInstPrinter.cpp   | 18 +++----
 .../RISCV/MCTargetDesc/RISCVInstPrinter.h     |  8 ++-
 llvm/lib/Target/RISCV/RISCVInstrInfoZc.td     | 51 ++++++++++++++-----
 llvm/test/MC/RISCV/rv32zcmp-invalid.s         | 12 +++++
 llvm/test/MC/RISCV/rv64zcmp-invalid.s         | 12 +++++
 8 files changed, 94 insertions(+), 36 deletions(-)

diff --git a/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp b/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp
index 1779959324da2a..c73789d8a4c7d0 100644
--- a/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp
+++ b/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp
@@ -213,7 +213,11 @@ class RISCVAsmParser : public MCTargetAsmParser {
   ParseStatus parseReglist(OperandVector &Operands);
   ParseStatus parseRegReg(OperandVector &Operands);
   ParseStatus parseRetval(OperandVector &Operands);
-  ParseStatus parseZcmpSpimm(OperandVector &Operands);
+  ParseStatus parseZcmpStackAdj(OperandVector &Operands,
+                                bool ExpectNegative = false);
+  ParseStatus parseZcmpNegStackAdj(OperandVector &Operands) {
+    return parseZcmpStackAdj(Operands, /*ExpectNegative*/ true);
+  }
 
   bool parseOperand(OperandVector &Operands, StringRef Mnemonic);
 
@@ -1062,7 +1066,7 @@ struct RISCVOperand final : public MCParsedAsmOperand {
       break;
     case KindTy::Spimm:
       OS << "<Spimm: ";
-      RISCVZC::printSpimm(Spimm.Val, OS);
+      OS << Spimm.Val;
       OS << '>';
       break;
     case KindTy::RegReg:
@@ -1608,7 +1612,7 @@ bool RISCVAsmParser::MatchAndEmitInstruction(SMLoc IDLoc, unsigned &Opcode,
         ErrorLoc,
         "operand must be {ra [, s0[-sN]]} or {x1 [, x8[-x9][, x18[-xN]]]}");
   }
-  case Match_InvalidSpimm: {
+  case Match_InvalidStackAdj: {
     SMLoc ErrorLoc = ((RISCVOperand &)*Operands[ErrorInfo]).getStartLoc();
     return Error(
         ErrorLoc,
@@ -2583,8 +2587,9 @@ ParseStatus RISCVAsmParser::parseReglist(OperandVector &Operands) {
   return ParseStatus::Success;
 }
 
-ParseStatus RISCVAsmParser::parseZcmpSpimm(OperandVector &Operands) {
-  (void)parseOptionalToken(AsmToken::Minus);
+ParseStatus RISCVAsmParser::parseZcmpStackAdj(OperandVector &Operands,
+                                              bool ExpectNegative) {
+  bool Negative = parseOptionalToken(AsmToken::Minus);
 
   SMLoc S = getLoc();
   int64_t StackAdjustment = getLexer().getTok().getIntVal();
@@ -2592,8 +2597,10 @@ ParseStatus RISCVAsmParser::parseZcmpSpimm(OperandVector &Operands) {
   unsigned RlistVal = static_cast<RISCVOperand *>(Operands[1].get())->Rlist.Val;
 
   bool IsEABI = isRVE();
-  if (!RISCVZC::getSpimm(RlistVal, Spimm, StackAdjustment, isRV64(), IsEABI))
+  if (Negative != ExpectNegative ||
+      !RISCVZC::getSpimm(RlistVal, Spimm, StackAdjustment, isRV64(), IsEABI)) {
     return ParseStatus::NoMatch;
+  }
   Operands.push_back(RISCVOperand::createSpimm(Spimm << 4, S));
   getLexer().Lex();
   return ParseStatus::Success;
diff --git a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.cpp b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.cpp
index 61f8e71710377e..5d9a58babe606c 100644
--- a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.cpp
+++ b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.cpp
@@ -236,6 +236,4 @@ void RISCVZC::printRlist(unsigned SlistEncode, raw_ostream &OS) {
   OS << "}";
 }
 
-void RISCVZC::printSpimm(int64_t Spimm, raw_ostream &OS) { OS << Spimm; }
-
 } // namespace llvm
diff --git a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
index 6d0381c30d3e86..c65b5121877254 100644
--- a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
+++ b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
@@ -580,15 +580,17 @@ inline static bool getSpimm(unsigned RlistVal, unsigned &SpimmVal,
                             int64_t StackAdjustment, bool IsRV64, bool IsEABI) {
   if (RlistVal == RLISTENCODE::INVALID_RLIST)
     return false;
-  unsigned stackAdj = getStackAdjBase(RlistVal, IsRV64, IsEABI);
-  SpimmVal = (StackAdjustment - stackAdj) / 16;
+  unsigned StackAdjBase = getStackAdjBase(RlistVal, IsRV64, IsEABI);
+  StackAdjustment -= StackAdjBase;
+  if (StackAdjustment % 16 != 0)
+    return false;
+  SpimmVal = StackAdjustment / 16;
   if (SpimmVal > 3)
     return false;
   return true;
 }
 
 void printRlist(unsigned SlistEncode, raw_ostream &OS);
-void printSpimm(int64_t Spimm, raw_ostream &OS);
 } // namespace RISCVZC
 
 } // namespace llvm
diff --git a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVInstPrinter.cpp b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVInstPrinter.cpp
index bd899495812f44..db7760f12f9367 100644
--- a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVInstPrinter.cpp
+++ b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVInstPrinter.cpp
@@ -292,24 +292,24 @@ void RISCVInstPrinter::printRegReg(const MCInst *MI, unsigned OpNo,
   O << ")";
 }
 
-void RISCVInstPrinter::printSpimm(const MCInst *MI, unsigned OpNo,
-                                  const MCSubtargetInfo &STI, raw_ostream &O) {
+void RISCVInstPrinter::printStackAdj(const MCInst *MI, unsigned OpNo,
+                                     const MCSubtargetInfo &STI, raw_ostream &O,
+                                     bool Negate) {
   int64_t Imm = MI->getOperand(OpNo).getImm();
-  unsigned Opcode = MI->getOpcode();
   bool IsRV64 = STI.hasFeature(RISCV::Feature64Bit);
   bool IsEABI = STI.hasFeature(RISCV::FeatureRVE);
-  int64_t Spimm = 0;
+  int64_t StackAdj = 0;
   auto RlistVal = MI->getOperand(0).getImm();
   assert(RlistVal != 16 && "Incorrect rlist.");
   auto Base = RISCVZC::getStackAdjBase(RlistVal, IsRV64, IsEABI);
-  Spimm = Imm + Base;
-  assert((Spimm >= Base && Spimm <= Base + 48) && "Incorrect spimm");
-  if (Opcode == RISCV::CM_PUSH)
-    Spimm = -Spimm;
+  StackAdj = Imm + Base;
+  assert((StackAdj >= Base && StackAdj <= Base + 48) && "Incorrect stack adjust");
+  if (Negate)
+    StackAdj = -StackAdj;
 
   // RAII guard for ANSI color escape sequences
   WithMarkup ScopedMarkup = markup(O, Markup::Immediate);
-  RISCVZC::printSpimm(Spimm, O);
+  O << StackAdj;
 }
 
 void RISCVInstPrinter::printVMaskReg(const MCInst *MI, unsigned OpNo,
diff --git a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVInstPrinter.h b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVInstPrinter.h
index 4512bd5f4c4b7a..c444e51b5bc40c 100644
--- a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVInstPrinter.h
+++ b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVInstPrinter.h
@@ -52,8 +52,12 @@ class RISCVInstPrinter : public MCInstPrinter {
                      const MCSubtargetInfo &STI, raw_ostream &O);
   void printRlist(const MCInst *MI, unsigned OpNo, const MCSubtargetInfo &STI,
                   raw_ostream &O);
-  void printSpimm(const MCInst *MI, unsigned OpNo, const MCSubtargetInfo &STI,
-                  raw_ostream &O);
+  void printStackAdj(const MCInst *MI, unsigned OpNo, const MCSubtargetInfo &STI,
+                     raw_ostream &O, bool Negate = false);
+  void printNegStackAdj(const MCInst *MI, unsigned OpNo, const MCSubtargetInfo &STI,
+                        raw_ostream &O) {
+    return printStackAdj(MI, OpNo, STI, O, /*Negate*/ true);
+  }
   void printRegReg(const MCInst *MI, unsigned OpNo, const MCSubtargetInfo &STI,
                    raw_ostream &O);
   // Autogenerated by tblgen.
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZc.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZc.td
index 2c8451c5c4ceb2..a327bd3d0c28ae 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoZc.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZc.td
@@ -41,10 +41,20 @@ def RlistAsmOperand : AsmOperandClass {
   let DiagnosticType = "InvalidRlist";
 }
 
-def SpimmAsmOperand : AsmOperandClass {
-  let Name = "Spimm";
-  let ParserMethod = "parseZcmpSpimm";
-  let DiagnosticType = "InvalidSpimm";
+def StackAdjAsmOperand : AsmOperandClass {
+  let Name = "StackAdj";
+  let ParserMethod = "parseZcmpStackAdj";
+  let DiagnosticType = "InvalidStackAdj";
+  let PredicateMethod = "isSpimm";
+  let RenderMethod = "addSpimmOperands";
+}
+
+def NegStackAdjAsmOperand : AsmOperandClass {
+  let Name = "NegStackAdj";
+  let ParserMethod = "parseZcmpNegStackAdj";
+  let DiagnosticType = "InvalidStackAdj";
+  let PredicateMethod = "isSpimm";
+  let RenderMethod = "addSpimmOperands";
 }
 
 def rlist : Operand<OtherVT> {
@@ -59,11 +69,23 @@ def rlist : Operand<OtherVT> {
     // 0~3 Reserved for EABI
     return isUInt<4>(Imm) && Imm >= 4;
   }];
- }
+}
+
+def stackadj : Operand<OtherVT> {
+  let ParserMatchClass = StackAdjAsmOperand;
+  let PrintMethod = "printStackAdj";
+  let DecoderMethod = "decodeZcmpSpimm";
+  let MCOperandPredicate = [{
+    int64_t Imm;
+    if (!MCOp.evaluateAsConstantImm(Imm))
+      return false;
+    return isShiftedUInt<2, 4>(Imm);
+  }];
+}
 
-def spimm : Operand<OtherVT> {
-  let ParserMatchClass = SpimmAsmOperand;
-  let PrintMethod = "printSpimm";
+def negstackadj : Operand<OtherVT> {
+  let ParserMatchClass = NegStackAdjAsmOperand;
+  let PrintMethod = "printNegStackAdj";
   let DecoderMethod = "decodeZcmpSpimm";
   let MCOperandPredicate = [{
     int64_t Imm;
@@ -124,14 +146,15 @@ class RVZcArith_r<bits<5> funct5, string OpcodeStr> :
   let Constraints = "$rd = $rd_wb";
 }
 
-class RVInstZcCPPP<bits<5> funct5, string opcodestr>
-    : RVInst16<(outs), (ins rlist:$rlist, spimm:$spimm),
-               opcodestr, "$rlist, $spimm", [], InstFormatOther> {
+class RVInstZcCPPP<bits<5> funct5, string opcodestr,
+                   DAGOperand immtype = stackadj>
+    : RVInst16<(outs), (ins rlist:$rlist, immtype:$stackadj),
+               opcodestr, "$rlist, $stackadj", [], InstFormatOther> {
   bits<4> rlist;
-  bits<16> spimm;
+  bits<16> stackadj;
 
   let Inst{1-0} = 0b10;
-  let Inst{3-2} = spimm{5-4};
+  let Inst{3-2} = stackadj{5-4};
   let Inst{7-4} = rlist;
   let Inst{12-8} = funct5;
   let Inst{15-13} = 0b101;
@@ -195,7 +218,7 @@ def CM_MVSA01 : RVInst16CA<0b101011, 0b01, 0b10, (outs SR07:$rs1, SR07:$rs2),
 
 let DecoderNamespace = "RVZcmp", Predicates = [HasStdExtZcmp] in {
 let hasSideEffects = 0, mayLoad = 0, mayStore = 1, Uses = [X2], Defs = [X2] in
-def CM_PUSH : RVInstZcCPPP<0b11000, "cm.push">,
+def CM_PUSH : RVInstZcCPPP<0b11000, "cm.push", negstackadj>,
               Sched<[WriteIALU, ReadIALU, ReadStoreData, ReadStoreData,
                      ReadStoreData, ReadStoreData, ReadStoreData, ReadStoreData,
                      ReadStoreData, ReadStoreData, ReadStoreData, ReadStoreData,
diff --git a/llvm/test/MC/RISCV/rv32zcmp-invalid.s b/llvm/test/MC/RISCV/rv32zcmp-invalid.s
index cb99bba0aaa1e8..1acea187585f88 100644
--- a/llvm/test/MC/RISCV/rv32zcmp-invalid.s
+++ b/llvm/test/MC/RISCV/rv32zcmp-invalid.s
@@ -15,3 +15,15 @@ cm.popretz {ra, s0-s10}, 112
 
 # CHECK-ERROR: error: stack adjustment is invalid for this instruction and register list; refer to Zc spec for a detailed range of stack adjustment
 cm.popretz {ra, s0-s1}, 112
+
+# CHECK-ERROR: error: stack adjustment is invalid for this instruction and register list; refer to Zc spec for a detailed range of stack adjustment
+cm.push {ra}, 16
+
+# CHECK-ERROR: error: stack adjustment is invalid for this instruction and register list; refer to Zc spec for a detailed range of stack adjustment
+cm.pop {ra, s0-s1}, -32
+
+# CHECK-ERROR: error: stack adjustment is invalid for this instruction and register list; refer to Zc spec for a detailed range of stack adjustment
+cm.push {ra}, -8
+
+# CHECK-ERROR: error: stack adjustment is invalid for this instruction and register list; refer to Zc spec for a detailed range of stack adjustment
+cm.pop {ra, s0-s1}, -40
diff --git a/llvm/test/MC/RISCV/rv64zcmp-invalid.s b/llvm/test/MC/RISCV/rv64zcmp-invalid.s
index 103934583495f3..bf34554095ea5b 100644
--- a/llvm/test/MC/RISCV/rv64zcmp-invalid.s
+++ b/llvm/test/MC/RISCV/rv64zcmp-invalid.s
@@ -15,3 +15,15 @@ cm.popretz {ra, s0-s10}, 112
 
 # CHECK-ERROR: error: stack adjustment is invalid for this instruction and register list; refer to Zc spec for a detailed range of stack adjustment
 cm.popretz {ra, s0-s1}, 112
+
+# CHECK-ERROR: error: stack adjustment is invalid for this instruction and register list; refer to Zc spec for a detailed range of stack adjustment
+cm.push {ra}, 16
+
+# CHECK-ERROR: error: stack adjustment is invalid for this instruction and register list; refer to Zc spec for a detailed range of stack adjustment
+cm.pop {ra, s0-s1}, -32
+
+# CHECK-ERROR: error: stack adjustment is invalid for this instruction and register list; refer to Zc spec for a detailed range of stack adjustment
+cm.push {ra}, -15
+
+# CHECK-ERROR: error: stack adjustment is invalid for this instruction and register list; refer to Zc spec for a detailed range of stack adjustment
+cm.pop {ra, s0-s1}, -33

>From a922056b43de2e0df27c80c91497c89594c8a3a3 Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Thu, 14 Mar 2024 12:01:21 -0700
Subject: [PATCH 2/3] fixup! clang-format

---
 llvm/lib/Target/RISCV/MCTargetDesc/RISCVInstPrinter.cpp | 3 ++-
 llvm/lib/Target/RISCV/MCTargetDesc/RISCVInstPrinter.h   | 9 +++++----
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVInstPrinter.cpp b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVInstPrinter.cpp
index db7760f12f9367..04e02e9fa0ab73 100644
--- a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVInstPrinter.cpp
+++ b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVInstPrinter.cpp
@@ -303,7 +303,8 @@ void RISCVInstPrinter::printStackAdj(const MCInst *MI, unsigned OpNo,
   assert(RlistVal != 16 && "Incorrect rlist.");
   auto Base = RISCVZC::getStackAdjBase(RlistVal, IsRV64, IsEABI);
   StackAdj = Imm + Base;
-  assert((StackAdj >= Base && StackAdj <= Base + 48) && "Incorrect stack adjust");
+  assert((StackAdj >= Base && StackAdj <= Base + 48) &&
+         "Incorrect stack adjust");
   if (Negate)
     StackAdj = -StackAdj;
 
diff --git a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVInstPrinter.h b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVInstPrinter.h
index c444e51b5bc40c..77cc7a67e88920 100644
--- a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVInstPrinter.h
+++ b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVInstPrinter.h
@@ -52,10 +52,11 @@ class RISCVInstPrinter : public MCInstPrinter {
                      const MCSubtargetInfo &STI, raw_ostream &O);
   void printRlist(const MCInst *MI, unsigned OpNo, const MCSubtargetInfo &STI,
                   raw_ostream &O);
-  void printStackAdj(const MCInst *MI, unsigned OpNo, const MCSubtargetInfo &STI,
-                     raw_ostream &O, bool Negate = false);
-  void printNegStackAdj(const MCInst *MI, unsigned OpNo, const MCSubtargetInfo &STI,
-                        raw_ostream &O) {
+  void printStackAdj(const MCInst *MI, unsigned OpNo,
+                     const MCSubtargetInfo &STI, raw_ostream &O,
+                     bool Negate = false);
+  void printNegStackAdj(const MCInst *MI, unsigned OpNo,
+                        const MCSubtargetInfo &STI, raw_ostream &O) {
     return printStackAdj(MI, OpNo, STI, O, /*Negate*/ true);
   }
   void printRegReg(const MCInst *MI, unsigned OpNo, const MCSubtargetInfo &STI,

>From 92998b3f6e48a09fdeac952903f9d71c2f95f74d Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Tue, 26 Mar 2024 10:39:47 -0700
Subject: [PATCH 3/3] fixup! remove curly braces.

---
 llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp b/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp
index c73789d8a4c7d0..e5203c23c31e89 100644
--- a/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp
+++ b/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp
@@ -2598,9 +2598,8 @@ ParseStatus RISCVAsmParser::parseZcmpStackAdj(OperandVector &Operands,
 
   bool IsEABI = isRVE();
   if (Negative != ExpectNegative ||
-      !RISCVZC::getSpimm(RlistVal, Spimm, StackAdjustment, isRV64(), IsEABI)) {
+      !RISCVZC::getSpimm(RlistVal, Spimm, StackAdjustment, isRV64(), IsEABI))
     return ParseStatus::NoMatch;
-  }
   Operands.push_back(RISCVOperand::createSpimm(Spimm << 4, S));
   getLexer().Lex();
   return ParseStatus::Success;



More information about the llvm-commits mailing list