[llvm] [RISCV] add computeKnownBitsForTargetNode for RISCVISD::SRLW (PR #155995)
Shreeyash Pandey via llvm-commits
llvm-commits at lists.llvm.org
Sat Aug 30 10:20:47 PDT 2025
https://github.com/bojle updated https://github.com/llvm/llvm-project/pull/155995
>From 7c8e165ca068cb9508928ed76611cc0582c0ef70 Mon Sep 17 00:00:00 2001
From: Shreeyash Pandey <shreeyash335 at gmail.com>
Date: Mon, 25 Aug 2025 16:59:13 +0530
Subject: [PATCH 1/3] [RISCV] add computeKnownBitsForTargetNode for
RISCVISD::SRLW
Signed-off-by: Shreeyash Pandey <shreeyash335 at gmail.com>
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 9 ++
llvm/unittests/Target/RISCV/CMakeLists.txt | 1 +
.../Target/RISCV/RISCVSelectionDAGTest.cpp | 109 ++++++++++++++++++
3 files changed, 119 insertions(+)
create mode 100644 llvm/unittests/Target/RISCV/RISCVSelectionDAGTest.cpp
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 4a1db80076530..6dacd83bfe550 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -21340,6 +21340,15 @@ void RISCVTargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
Known = Known.sext(BitWidth);
break;
}
+ case RISCVISD::SRLW: {
+ KnownBits Known2;
+ Known = DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
+ Known2 = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
+ Known = KnownBits::lshr(Known.trunc(32), Known2.trunc(5).zext(32));
+ // Restore the original width by sign extending.
+ Known = Known.sext(BitWidth);
+ break;
+ }
case RISCVISD::CTZW: {
KnownBits Known2 = DAG.computeKnownBits(Op.getOperand(0), Depth + 1);
unsigned PossibleTZ = Known2.trunc(32).countMaxTrailingZeros();
diff --git a/llvm/unittests/Target/RISCV/CMakeLists.txt b/llvm/unittests/Target/RISCV/CMakeLists.txt
index 8da8c3896faf1..701bbee55da71 100644
--- a/llvm/unittests/Target/RISCV/CMakeLists.txt
+++ b/llvm/unittests/Target/RISCV/CMakeLists.txt
@@ -19,4 +19,5 @@ set(LLVM_LINK_COMPONENTS
add_llvm_target_unittest(RISCVTests
MCInstrAnalysisTest.cpp
RISCVInstrInfoTest.cpp
+ RISCVSelectionDAGTest.cpp
)
diff --git a/llvm/unittests/Target/RISCV/RISCVSelectionDAGTest.cpp b/llvm/unittests/Target/RISCV/RISCVSelectionDAGTest.cpp
new file mode 100644
index 0000000000000..a13f88484c00c
--- /dev/null
+++ b/llvm/unittests/Target/RISCV/RISCVSelectionDAGTest.cpp
@@ -0,0 +1,109 @@
+//===----------------------------------------------------------------------===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "RISCVISelLowering.h"
+#include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/CodeGen/MachineModuleInfo.h"
+#include "llvm/CodeGen/SelectionDAG.h"
+#include "llvm/CodeGen/TargetLowering.h"
+#include "llvm/IR/MDBuilder.h"
+#include "llvm/IR/Module.h"
+#include "llvm/MC/TargetRegistry.h"
+#include "llvm/Support/KnownBits.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/TargetSelect.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Target/TargetMachine.h"
+#include "gtest/gtest.h"
+
+namespace llvm {
+
+class RISCVSelectionDAGTest : public testing::Test {
+ // Fixture that builds a riscv64 TargetMachine, a one-function module, and a SelectionDAG.
+protected:
+ static void SetUpTestCase() {
+ LLVMInitializeRISCVTargetInfo();
+ LLVMInitializeRISCVTarget();
+ LLVMInitializeRISCVTargetMC();
+ }
+ // Per-test setup: parse @f, create the TargetMachine/MachineFunction, then init the DAG.
+ void SetUp() override {
+ StringRef Assembly = "define void @f() { ret void }";
+
+ Triple TargetTriple("riscv64", "unknown", "linux");
+
+ std::string Error;
+ const Target *T = TargetRegistry::lookupTarget("", TargetTriple, Error);
+ // NOTE(review): T is dereferenced below without a null check and Error is never read.
+ TargetOptions Options;
+ TM = std::unique_ptr<TargetMachine>(T->createTargetMachine(
+ TargetTriple, "generic", "", Options, std::nullopt, std::nullopt,
+ CodeGenOptLevel::Default));
+
+ SMDiagnostic SMError;
+ M = parseAssemblyString(Assembly, SMError, Context);
+ if (!M)
+ report_fatal_error(SMError.getMessage());
+ M->setDataLayout(TM->createDataLayout());
+
+ F = M->getFunction("f");
+ if (!F)
+ report_fatal_error("Function 'f' not found");
+ // NOTE(review): MMI is local to SetUp() yet is passed to MF and DAG (members); confirm neither keeps a reference to it.
+ MachineModuleInfo MMI(TM.get());
+
+ MF = std::make_unique<MachineFunction>(*F, *TM, *TM->getSubtargetImpl(*F),
+ MMI.getContext(), /*FunctionNum*/ 0);
+ // NOTE(review): std::make_unique throws on failure instead of returning null, so the check below looks unreachable.
+ DAG = std::make_unique<SelectionDAG>(*TM, CodeGenOptLevel::None);
+ if (!DAG)
+ report_fatal_error("SelectionDAG allocation failed");
+ // NOTE(review): ORE is also a SetUp() local handed to DAG->init; confirm DAG does not retain it.
+ OptimizationRemarkEmitter ORE(F);
+ DAG->init(*MF, ORE, /*LibInfo*/ nullptr, /*AA*/ nullptr,
+ /*AC*/ nullptr, /*MDT*/ nullptr, /*MSDT*/ nullptr, MMI, nullptr);
+ }
+ // State used by the test bodies: the LLVMContext plus the objects created in SetUp().
+ LLVMContext Context;
+ std::unique_ptr<TargetMachine> TM;
+ std::unique_ptr<Module> M;
+ Function *F = nullptr;
+ std::unique_ptr<MachineFunction> MF;
+ std::unique_ptr<SelectionDAG> DAG;
+};
+
+/// SRLW: Logical Shift Right
+TEST_F(RISCVSelectionDAGTest, computeKnownBits_SRLW) {
+ // The DAG built below mirrors this IR snippet:
+ //
+ // define i64 @f(i32 %x, i32 %y) {
+ // %a = and i32 %x, 2147483647 ; zeros the MSB for %x
+ // %b = lshr i32 %a, %y
+ // %c = zext i32 %b to i64 ; makes the most significant 32 bits 0
+ // ret i64 %c
+ // }
+ SDLoc Loc;
+ auto IntVT = EVT::getIntegerVT(Context, 32);
+ auto Int64VT = EVT::getIntegerVT(Context, 64);
+ auto Px = DAG->getRegister(0, IntVT);
+ auto Py = DAG->getConstant(2147483647, Loc, IntVT);
+ auto N1 = DAG->getNode(ISD::AND, Loc, IntVT, Px, Py);
+ auto Qx = DAG->getRegister(0, IntVT);
+ auto N2 = DAG->getNode(ISD::SRL, Loc, IntVT, N1, Qx);
+ auto N3 = DAG->getNode(ISD::ZERO_EXTEND, Loc, Int64VT, N2);
+ // N1 = 0???????????????????????????????
+ // N2 = 0???????????????????????????????
+ // N3 = 000000000000000000000000000000000???????????????????????????????
+ // After the zero extend, we expect the 33 most significant bits to be known
+ // zero: 32 from the zero extension and 1 from the AND operation.
+ KnownBits Known = DAG->computeKnownBits(N3);
+ EXPECT_EQ(Known.Zero, APInt(64, -2147483648)); // 0xFFFFFFFF80000000: bits 63..31 set
+ EXPECT_EQ(Known.One, APInt(64, 0)); // no bit is known to be one
+}
+
+} // end namespace llvm
>From d713682b7a6a5fcc81428ae2fcc958d14f83e770 Mon Sep 17 00:00:00 2001
From: Shreeyash Pandey <shreeyash335 at gmail.com>
Date: Sat, 30 Aug 2025 17:36:17 +0530
Subject: [PATCH 2/3] [RISCV] use RISCVISD::SRLW instead of the generic ISD::SRL
Signed-off-by: Shreeyash Pandey <shreeyash335 at gmail.com>
---
llvm/unittests/Target/RISCV/RISCVSelectionDAGTest.cpp | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/llvm/unittests/Target/RISCV/RISCVSelectionDAGTest.cpp b/llvm/unittests/Target/RISCV/RISCVSelectionDAGTest.cpp
index a13f88484c00c..c6779e6f7b86a 100644
--- a/llvm/unittests/Target/RISCV/RISCVSelectionDAGTest.cpp
+++ b/llvm/unittests/Target/RISCV/RISCVSelectionDAGTest.cpp
@@ -6,6 +6,7 @@
//===----------------------------------------------------------------------===//
#include "RISCVISelLowering.h"
+#include "RISCVSelectionDAGInfo.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
@@ -94,7 +95,7 @@ TEST_F(RISCVSelectionDAGTest, computeKnownBits_SRLW) {
auto Py = DAG->getConstant(2147483647, Loc, IntVT);
auto N1 = DAG->getNode(ISD::AND, Loc, IntVT, Px, Py);
auto Qx = DAG->getRegister(0, IntVT);
- auto N2 = DAG->getNode(ISD::SRL, Loc, IntVT, N1, Qx);
+ auto N2 = DAG->getNode(RISCVISD::SRLW, Loc, IntVT, N1, Qx);
auto N3 = DAG->getNode(ISD::ZERO_EXTEND, Loc, Int64VT, N2);
// N1 = 0???????????????????????????????
// N2 = 0???????????????????????????????
>From f3833cd75cc20f231c4f6dd2f32790eaad8922b1 Mon Sep 17 00:00:00 2001
From: Shreeyash Pandey <shreeyash335 at gmail.com>
Date: Sat, 30 Aug 2025 22:49:49 +0530
Subject: [PATCH 3/3] [RISCV] make SRLW io types be 64 bits
Signed-off-by: Shreeyash Pandey <shreeyash335 at gmail.com>
---
llvm/unittests/Target/RISCV/RISCVSelectionDAGTest.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/llvm/unittests/Target/RISCV/RISCVSelectionDAGTest.cpp b/llvm/unittests/Target/RISCV/RISCVSelectionDAGTest.cpp
index c6779e6f7b86a..55fe665dcbadd 100644
--- a/llvm/unittests/Target/RISCV/RISCVSelectionDAGTest.cpp
+++ b/llvm/unittests/Target/RISCV/RISCVSelectionDAGTest.cpp
@@ -94,8 +94,8 @@ TEST_F(RISCVSelectionDAGTest, computeKnownBits_SRLW) {
auto Px = DAG->getRegister(0, IntVT);
auto Py = DAG->getConstant(2147483647, Loc, IntVT);
auto N1 = DAG->getNode(ISD::AND, Loc, IntVT, Px, Py);
- auto Qx = DAG->getRegister(0, IntVT);
- auto N2 = DAG->getNode(RISCVISD::SRLW, Loc, IntVT, N1, Qx);
+ auto Qx = DAG->getRegister(0, Int64VT);
+ auto N2 = DAG->getNode(RISCVISD::SRLW, Loc, Int64VT, N1, Qx);
auto N3 = DAG->getNode(ISD::ZERO_EXTEND, Loc, Int64VT, N2);
// N1 = 0???????????????????????????????
// N2 = 0???????????????????????????????
More information about the llvm-commits
mailing list