[llvm] [NVPTX] Support vectors for AND combine (PR #154165)
Kevin McAfee via llvm-commits
llvm-commits at lists.llvm.org
Mon Aug 18 16:52:06 PDT 2025
https://github.com/kalxr updated https://github.com/llvm/llvm-project/pull/154165
>From 04756ad0d7649a50b09976cf063b8414cc550258 Mon Sep 17 00:00:00 2001
From: Kevin McAfee <kmcafee at nvidia.com>
Date: Mon, 18 Aug 2025 17:36:45 +0000
Subject: [PATCH 1/3] pre-commit test
---
llvm/test/CodeGen/NVPTX/i8x2-instructions.ll | 34 ++++++++++++++++++++
1 file changed, 34 insertions(+)
diff --git a/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll b/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll
index 98f94bb7b3ac1..4767a7da1aadc 100644
--- a/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll
@@ -103,5 +103,39 @@ define <2 x i8> @test_call_2xi8(<2 x i8> %a) {
%res = call <2 x i8> @test_call_2xi8(<2 x i8> %a)
ret <2 x i8> %res
}
+
+define <2 x float> @test_uitofp_2xi8(<2 x i8> %a) {
+; O0-LABEL: test_uitofp_2xi8(
+; O0: {
+; O0-NEXT: .reg .b16 %rs<5>;
+; O0-NEXT: .reg .b32 %r<5>;
+; O0-EMPTY:
+; O0-NEXT: // %bb.0:
+; O0-NEXT: ld.param.v2.b8 {%rs1, %rs2}, [test_uitofp_2xi8_param_0];
+; O0-NEXT: mov.b32 %r1, {%rs1, %rs2};
+; O0-NEXT: and.b32 %r2, %r1, 16711935;
+; O0-NEXT: mov.b32 {%rs3, %rs4}, %r2;
+; O0-NEXT: cvt.rn.f32.u16 %r3, %rs4;
+; O0-NEXT: cvt.rn.f32.u16 %r4, %rs3;
+; O0-NEXT: st.param.v2.b32 [func_retval0], {%r4, %r3};
+; O0-NEXT: ret;
+;
+; O3-LABEL: test_uitofp_2xi8(
+; O3: {
+; O3-NEXT: .reg .b16 %rs<5>;
+; O3-NEXT: .reg .b32 %r<5>;
+; O3-EMPTY:
+; O3-NEXT: // %bb.0:
+; O3-NEXT: ld.param.v2.b8 {%rs1, %rs2}, [test_uitofp_2xi8_param_0];
+; O3-NEXT: mov.b32 %r1, {%rs1, %rs2};
+; O3-NEXT: and.b32 %r2, %r1, 16711935;
+; O3-NEXT: mov.b32 {%rs3, %rs4}, %r2;
+; O3-NEXT: cvt.rn.f32.u16 %r3, %rs4;
+; O3-NEXT: cvt.rn.f32.u16 %r4, %rs3;
+; O3-NEXT: st.param.v2.b32 [func_retval0], {%r4, %r3};
+; O3-NEXT: ret;
+ %1 = uitofp <2 x i8> %a to <2 x float>
+ ret <2 x float> %1
+}
;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
; COMMON: {{.*}}
>From 5ace8e2c589bceed16a89d4e30738f7e0e8bc4f4 Mon Sep 17 00:00:00 2001
From: Kevin McAfee <kmcafee at nvidia.com>
Date: Mon, 18 Aug 2025 17:38:29 +0000
Subject: [PATCH 2/3] [NVPTX] Support vectors for AND combine
---
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 114 ++++++++++++-------
llvm/test/CodeGen/NVPTX/i8x2-instructions.ll | 25 ++--
llvm/test/CodeGen/NVPTX/shift-opt.ll | 9 +-
3 files changed, 84 insertions(+), 64 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 74e6c139c610d..8190c23407250 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -5242,6 +5242,58 @@ static SDValue PerformFADDCombine(SDNode *N,
return PerformFADDCombineWithOperands(N, N1, N0, DCI, OptLevel);
}
+// Helper function to check if an AND operation on a load can be eliminated.
+// Returns a replacement value if the load can be eliminated, else nullopt.
+static std::optional<SDValue>
+canEliminateLoadAND(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
+ SDValue Val, SDValue Mask, SDValue AExt) {
+ if (Val->getOpcode() != NVPTXISD::LoadV2 &&
+ Val->getOpcode() != NVPTXISD::LoadV4) {
+ return std::nullopt;
+ }
+
+ ConstantSDNode *MaskCnst = dyn_cast<ConstantSDNode>(Mask);
+ if (!MaskCnst) {
+ // Not an AND with a constant
+ return std::nullopt;
+ }
+
+ uint64_t MaskVal = MaskCnst->getZExtValue();
+ if (MaskVal != 0xff) {
+ // Not an AND that chops off top 8 bits
+ return std::nullopt;
+ }
+
+ MemSDNode *Mem = dyn_cast<MemSDNode>(Val);
+ if (!Mem) {
+ // Not a MemSDNode
+ return std::nullopt;
+ }
+
+ EVT MemVT = Mem->getMemoryVT();
+ if (MemVT != MVT::v2i8 && MemVT != MVT::v4i8) {
+ // We only handle the i8 case
+ return std::nullopt;
+ }
+
+ unsigned ExtType = Val->getConstantOperandVal(Val->getNumOperands() - 1);
+ if (ExtType == ISD::SEXTLOAD) {
+ // If the load is a sextload, the AND is needed to zero out the high 8 bits
+ return std::nullopt;
+ }
+
+ SDValue Result = Val;
+
+ if (AExt) {
+ // Re-insert the ext as a zext.
+ Result =
+ DCI.DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), AExt.getValueType(), Val);
+ }
+
+ // If we get here, the AND is unnecessary. Replace it with the load.
+ return Result;
+}
+
static SDValue PerformANDCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI) {
// The type legalizer turns a vector load of i8 values into a zextload to i16
@@ -5252,9 +5304,8 @@ static SDValue PerformANDCombine(SDNode *N,
SDValue Val = N->getOperand(0);
SDValue Mask = N->getOperand(1);
- if (isa<ConstantSDNode>(Val)) {
+ if (isa<ConstantSDNode>(Val))
std::swap(Val, Mask);
- }
SDValue AExt;
@@ -5264,49 +5315,24 @@ static SDValue PerformANDCombine(SDNode *N,
Val = Val->getOperand(0);
}
- if (Val->getOpcode() == NVPTXISD::LoadV2 ||
- Val->getOpcode() == NVPTXISD::LoadV4) {
- ConstantSDNode *MaskCnst = dyn_cast<ConstantSDNode>(Mask);
- if (!MaskCnst) {
- // Not an AND with a constant
- return SDValue();
- }
-
- uint64_t MaskVal = MaskCnst->getZExtValue();
- if (MaskVal != 0xff) {
- // Not an AND that chops off top 8 bits
- return SDValue();
- }
-
- MemSDNode *Mem = dyn_cast<MemSDNode>(Val);
- if (!Mem) {
- // Not a MemSDNode?!?
- return SDValue();
- }
-
- EVT MemVT = Mem->getMemoryVT();
- if (MemVT != MVT::v2i8 && MemVT != MVT::v4i8) {
- // We only handle the i8 case
- return SDValue();
- }
-
- unsigned ExtType = Val->getConstantOperandVal(Val->getNumOperands() - 1);
- if (ExtType == ISD::SEXTLOAD) {
- // If for some reason the load is a sextload, the and is needed to zero
- // out the high 8 bits
- return SDValue();
- }
-
- bool AddTo = false;
- if (AExt.getNode() != nullptr) {
- // Re-insert the ext as a zext.
- Val = DCI.DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N),
- AExt.getValueType(), Val);
- AddTo = true;
+ if (Val.getOpcode() == ISD::BUILD_VECTOR &&
+ Mask.getOpcode() == ISD::BUILD_VECTOR) {
+ assert(Val->getNumOperands() == Mask->getNumOperands() && !AExt);
+ for (unsigned I = 0; I < Val->getNumOperands(); ++I) {
+ // We know that the AExt is null and therefore the result of this call
+ // will be the BUILD_VECTOR operand or nullopt. Rather than create a new
+ // BUILD_VECTOR with the collection of operands, we can just use the
+ // original and ignore the result.
+ if (!canEliminateLoadAND(N, DCI, Val->getOperand(I), Mask->getOperand(I),
+ AExt)
+ .has_value())
+ return SDValue();
}
-
- // If we get here, the AND is unnecessary. Just replace it with the load
- DCI.CombineTo(N, Val, AddTo);
+ DCI.CombineTo(N, Val, false);
+ } else {
+ auto Result = canEliminateLoadAND(N, DCI, Val, Mask, AExt);
+ if (Result.has_value())
+ DCI.CombineTo(N, Result.value(), AExt.getNode() != nullptr);
}
return SDValue();
diff --git a/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll b/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll
index 4767a7da1aadc..77398b5fa41cb 100644
--- a/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll
@@ -107,32 +107,27 @@ define <2 x i8> @test_call_2xi8(<2 x i8> %a) {
define <2 x float> @test_uitofp_2xi8(<2 x i8> %a) {
; O0-LABEL: test_uitofp_2xi8(
; O0: {
-; O0-NEXT: .reg .b16 %rs<5>;
-; O0-NEXT: .reg .b32 %r<5>;
+; O0-NEXT: .reg .b16 %rs<3>;
+; O0-NEXT: .reg .b32 %r<4>;
; O0-EMPTY:
; O0-NEXT: // %bb.0:
; O0-NEXT: ld.param.v2.b8 {%rs1, %rs2}, [test_uitofp_2xi8_param_0];
; O0-NEXT: mov.b32 %r1, {%rs1, %rs2};
-; O0-NEXT: and.b32 %r2, %r1, 16711935;
-; O0-NEXT: mov.b32 {%rs3, %rs4}, %r2;
-; O0-NEXT: cvt.rn.f32.u16 %r3, %rs4;
-; O0-NEXT: cvt.rn.f32.u16 %r4, %rs3;
-; O0-NEXT: st.param.v2.b32 [func_retval0], {%r4, %r3};
+; O0-NEXT: cvt.rn.f32.u16 %r2, %rs2;
+; O0-NEXT: cvt.rn.f32.u16 %r3, %rs1;
+; O0-NEXT: st.param.v2.b32 [func_retval0], {%r3, %r2};
; O0-NEXT: ret;
;
; O3-LABEL: test_uitofp_2xi8(
; O3: {
-; O3-NEXT: .reg .b16 %rs<5>;
-; O3-NEXT: .reg .b32 %r<5>;
+; O3-NEXT: .reg .b16 %rs<3>;
+; O3-NEXT: .reg .b32 %r<3>;
; O3-EMPTY:
; O3-NEXT: // %bb.0:
; O3-NEXT: ld.param.v2.b8 {%rs1, %rs2}, [test_uitofp_2xi8_param_0];
-; O3-NEXT: mov.b32 %r1, {%rs1, %rs2};
-; O3-NEXT: and.b32 %r2, %r1, 16711935;
-; O3-NEXT: mov.b32 {%rs3, %rs4}, %r2;
-; O3-NEXT: cvt.rn.f32.u16 %r3, %rs4;
-; O3-NEXT: cvt.rn.f32.u16 %r4, %rs3;
-; O3-NEXT: st.param.v2.b32 [func_retval0], {%r4, %r3};
+; O3-NEXT: cvt.rn.f32.u16 %r1, %rs2;
+; O3-NEXT: cvt.rn.f32.u16 %r2, %rs1;
+; O3-NEXT: st.param.v2.b32 [func_retval0], {%r2, %r1};
; O3-NEXT: ret;
%1 = uitofp <2 x i8> %a to <2 x float>
ret <2 x float> %1
diff --git a/llvm/test/CodeGen/NVPTX/shift-opt.ll b/llvm/test/CodeGen/NVPTX/shift-opt.ll
index e7866b01064c7..e0d22c62993ba 100644
--- a/llvm/test/CodeGen/NVPTX/shift-opt.ll
+++ b/llvm/test/CodeGen/NVPTX/shift-opt.ll
@@ -71,18 +71,17 @@ define <2 x i16> @test_vec(<2 x i16> %x, <2 x i8> %y) {
; CHECK-LABEL: test_vec(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<7>;
-; CHECK-NEXT: .reg .b32 %r<5>;
+; CHECK-NEXT: .reg .b32 %r<4>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.v2.b16 {%rs1, %rs2}, [test_vec_param_0];
; CHECK-NEXT: ld.param.v2.b8 {%rs3, %rs4}, [test_vec_param_1];
; CHECK-NEXT: mov.b32 %r1, {%rs3, %rs4};
-; CHECK-NEXT: and.b32 %r2, %r1, 16711935;
; CHECK-NEXT: shr.u16 %rs5, %rs2, 5;
; CHECK-NEXT: shr.u16 %rs6, %rs1, 5;
-; CHECK-NEXT: mov.b32 %r3, {%rs6, %rs5};
-; CHECK-NEXT: or.b32 %r4, %r3, %r2;
-; CHECK-NEXT: st.param.b32 [func_retval0], %r4;
+; CHECK-NEXT: mov.b32 %r2, {%rs6, %rs5};
+; CHECK-NEXT: or.b32 %r3, %r2, %r1;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
; CHECK-NEXT: ret;
%ext = zext <2 x i8> %y to <2 x i16>
%shl = shl <2 x i16> %ext, splat(i16 5)
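The patch above factors the scalar checks into canEliminateLoadAND so the same predicate can run once per element when both AND operands are BUILD_VECTORs; the AND is dropped only if every element qualifies, and since no ANY_EXTEND is present on the vector path, the original BUILD_VECTOR is reused unchanged. A simplified, self-contained C++ sketch of that all-or-nothing shape (Elt and the helpers are illustrative stand-ins, not LLVM's API):

#include <cstdint>
#include <optional>
#include <vector>

// Illustrative stand-in for one BUILD_VECTOR element feeding the AND.
struct Elt {
  uint64_t Mask;        // the per-element AND mask
  bool IsZextByteLoad;  // models "LoadV2/LoadV4 of i8, not a sextload"
};

// Mirrors canEliminateLoadAND for one element: nullopt means "keep the AND".
static std::optional<Elt> canEliminate(const Elt &E) {
  if (!E.IsZextByteLoad || E.Mask != 0xff)
    return std::nullopt;
  return E; // no ANY_EXTEND in the vector path, so the element is unchanged
}

// Mirrors the BUILD_VECTOR branch: the AND is dropped only if every element
// qualifies, so the original vector can be reused without rebuilding it.
static bool andIsRedundant(const std::vector<Elt> &Elts) {
  for (const Elt &E : Elts)
    if (!canEliminate(E))
      return false;
  return true;
}

int main() {
  std::vector<Elt> V{{0xff, true}, {0xff, true}};
  return andIsRedundant(V) ? 0 : 1; // both lanes qualify -> AND is redundant
}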
>From c6e3981ed405bdc64bd17499a2702755629ca7ef Mon Sep 17 00:00:00 2001
From: Kevin McAfee <kmcafee at nvidia.com>
Date: Mon, 18 Aug 2025 23:51:51 +0000
Subject: [PATCH 3/3] Remove AND combine and implement computeKnownBits for LoadV2/4
---
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 123 ++++----------------
1 file changed, 25 insertions(+), 98 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 8190c23407250..67a5f2ae84c31 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -5242,102 +5242,6 @@ static SDValue PerformFADDCombine(SDNode *N,
return PerformFADDCombineWithOperands(N, N1, N0, DCI, OptLevel);
}
-// Helper function to check if an AND operation on a load can be eliminated.
-// Returns a replacement value if the load can be eliminated, else nullopt.
-static std::optional<SDValue>
-canEliminateLoadAND(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
- SDValue Val, SDValue Mask, SDValue AExt) {
- if (Val->getOpcode() != NVPTXISD::LoadV2 &&
- Val->getOpcode() != NVPTXISD::LoadV4) {
- return std::nullopt;
- }
-
- ConstantSDNode *MaskCnst = dyn_cast<ConstantSDNode>(Mask);
- if (!MaskCnst) {
- // Not an AND with a constant
- return std::nullopt;
- }
-
- uint64_t MaskVal = MaskCnst->getZExtValue();
- if (MaskVal != 0xff) {
- // Not an AND that chops off top 8 bits
- return std::nullopt;
- }
-
- MemSDNode *Mem = dyn_cast<MemSDNode>(Val);
- if (!Mem) {
- // Not a MemSDNode
- return std::nullopt;
- }
-
- EVT MemVT = Mem->getMemoryVT();
- if (MemVT != MVT::v2i8 && MemVT != MVT::v4i8) {
- // We only handle the i8 case
- return std::nullopt;
- }
-
- unsigned ExtType = Val->getConstantOperandVal(Val->getNumOperands() - 1);
- if (ExtType == ISD::SEXTLOAD) {
- // If the load is a sextload, the AND is needed to zero out the high 8 bits
- return std::nullopt;
- }
-
- SDValue Result = Val;
-
- if (AExt) {
- // Re-insert the ext as a zext.
- Result =
- DCI.DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), AExt.getValueType(), Val);
- }
-
- // If we get here, the AND is unnecessary. Replace it with the load.
- return Result;
-}
-
-static SDValue PerformANDCombine(SDNode *N,
- TargetLowering::DAGCombinerInfo &DCI) {
- // The type legalizer turns a vector load of i8 values into a zextload to i16
- // registers, optionally ANY_EXTENDs it (if target type is integer),
- // and ANDs off the high 8 bits. Since we turn this load into a
- // target-specific DAG node, the DAG combiner fails to eliminate these AND
- // nodes. Do that here.
- SDValue Val = N->getOperand(0);
- SDValue Mask = N->getOperand(1);
-
- if (isa<ConstantSDNode>(Val))
- std::swap(Val, Mask);
-
- SDValue AExt;
-
- // Generally, we will see zextload -> IMOV16rr -> ANY_EXTEND -> and
- if (Val.getOpcode() == ISD::ANY_EXTEND) {
- AExt = Val;
- Val = Val->getOperand(0);
- }
-
- if (Val.getOpcode() == ISD::BUILD_VECTOR &&
- Mask.getOpcode() == ISD::BUILD_VECTOR) {
- assert(Val->getNumOperands() == Mask->getNumOperands() && !AExt);
- for (unsigned I = 0; I < Val->getNumOperands(); ++I) {
- // We know that the AExt is null and therefore the result of this call
- // will be the BUILD_VECTOR operand or nullopt. Rather than create a new
- // BUILD_VECTOR with the collection of operands, we can just use the
- // original and ignore the result.
- if (!canEliminateLoadAND(N, DCI, Val->getOperand(I), Mask->getOperand(I),
- AExt)
- .has_value())
- return SDValue();
- }
- DCI.CombineTo(N, Val, false);
- } else {
- auto Result = canEliminateLoadAND(N, DCI, Val, Mask, AExt);
- if (Result.has_value())
- DCI.CombineTo(N, Result.value(), AExt.getNode() != nullptr);
- }
-
- return SDValue();
-}
-
static SDValue PerformREMCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
CodeGenOptLevel OptLevel) {
@@ -6009,8 +5913,6 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
return PerformADDCombine(N, DCI, OptLevel);
case ISD::ADDRSPACECAST:
return combineADDRSPACECAST(N, DCI);
- case ISD::AND:
- return PerformANDCombine(N, DCI);
case ISD::SIGN_EXTEND:
case ISD::ZERO_EXTEND:
return combineMulWide(N, DCI, OptLevel);
@@ -6635,6 +6537,27 @@ static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known,
}
}
+static void computeKnownBitsFori8VLoad(const SDValue Op, KnownBits &Known) {
+ MemSDNode *Mem = dyn_cast<MemSDNode>(Op);
+ if (!Mem) {
+ return;
+ }
+
+ EVT MemVT = Mem->getMemoryVT();
+ if (MemVT != MVT::v2i8 && MemVT != MVT::v4i8) {
+ return;
+ }
+
+ unsigned ExtType = Mem->getConstantOperandVal(Mem->getNumOperands() - 1);
+ if (ExtType == ISD::SEXTLOAD) {
+ Known = Known.sext(Known.getBitWidth());
+ return;
+ }
+ KnownBits HighZeros(Known.getBitWidth() - 8);
+ HighZeros.setAllZero();
+ Known.insertBits(HighZeros, 8);
+}
+
void NVPTXTargetLowering::computeKnownBitsForTargetNode(
const SDValue Op, KnownBits &Known, const APInt &DemandedElts,
const SelectionDAG &DAG, unsigned Depth) const {
@@ -6644,6 +6567,10 @@ void NVPTXTargetLowering::computeKnownBitsForTargetNode(
case NVPTXISD::PRMT:
computeKnownBitsForPRMT(Op, Known, DAG, Depth);
break;
+ case NVPTXISD::LoadV2:
+ case NVPTXISD::LoadV4:
+ computeKnownBitsFori8VLoad(Op, Known);
+ break;
default:
break;
}
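Patch 3 drops the bespoke AND combine entirely: by reporting through computeKnownBitsForTargetNode that a LoadV2/LoadV4 of v2i8/v4i8 (other than a sextload) produces lanes whose bits above bit 7 are zero, the generic DAG combiner can prove `and x, 0xff` redundant on its own, for scalars and vectors alike. A minimal sketch of the KnownBits manipulation the new hook performs (assumes an LLVM development setup and links against LLVMSupport; the printed count is for demonstration only):

#include "llvm/Support/KnownBits.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

int main() {
  KnownBits Known(16);            // one i16 lane produced by the v2i8 load
  KnownBits HighZeros(16 - 8);    // the lane's bits above the loaded byte
  HighZeros.setAllZero();         // all of those bits are known to be zero
  Known.insertBits(HighZeros, 8); // splice them in above the low 8 bits
  // With bits [8,16) known zero, SimplifyDemandedBits can fold `and x, 0xff`.
  outs() << "known leading zeros: " << Known.countMinLeadingZeros() << "\n";
  return 0;
}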