[llvm] [DAG] Replace DAGCombiner::ConstantFoldBITCASTofBUILD_VECTOR with SelectionDAG::FoldConstantBuildVector (PR #147037)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 7 00:01:57 PDT 2025


https://github.com/RKSimon updated https://github.com/llvm/llvm-project/pull/147037

>From 14bfeb6d8d1287f1b3c727c5cc0c399cf5b59964 Mon Sep 17 00:00:00 2001
From: Simon Pilgrim <llvm-dev at redking.me.uk>
Date: Fri, 4 Jul 2025 11:18:39 +0100
Subject: [PATCH] [DAG] Replace DAGCombiner::ConstantFoldBITCASTofBUILD_VECTOR
 with SelectionDAG::FoldConstantBuildVector

DAGCombiner can already constant fold build vectors of constants/undefs to a new vector type, but it has to be incredibly careful after legalization to not affect a target's canonicalized constants.

This patch proposes we move the implementation inside SelectionDAG to make it easier for targets to manually use the constant folding whenever they deem it safe to do so.
---
 llvm/include/llvm/CodeGen/SelectionDAG.h      |  5 ++
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 82 +------------------
 .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 71 ++++++++++++++++
 ...llvm.amdgcn.iglp.AFLCustomIRMutator.opt.ll | 10 +--
 4 files changed, 83 insertions(+), 85 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index a98e46c587273..7d8a0c4ce8e45 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -2030,6 +2030,11 @@ class SelectionDAG {
   LLVM_ABI SDValue foldConstantFPMath(unsigned Opcode, const SDLoc &DL, EVT VT,
                                       ArrayRef<SDValue> Ops);
 
+  /// Fold a BUILD_VECTOR of constant/undef elements into a BUILD_VECTOR of
+  /// constant/undef elements whose element type is DstEltVT.
+  LLVM_ABI SDValue FoldConstantBuildVector(BuildVectorSDNode *BV,
+                                           const SDLoc &DL, EVT DstEltVT);
+
   /// Constant fold a setcc to true or false.
   LLVM_ABI SDValue FoldSetCC(EVT VT, SDValue N1, SDValue N2, ISD::CondCode Cond,
                              const SDLoc &dl);
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 586eb2f3cf45e..cc258b0983e3c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -638,7 +638,6 @@ namespace {
     SDValue mergeInsertEltWithShuffle(SDNode *N, unsigned InsIndex);
     SDValue combineInsertEltToShuffle(SDNode *N, unsigned InsIndex);
     SDValue combineInsertEltToLoad(SDNode *N, unsigned InsIndex);
-    SDValue ConstantFoldBITCASTofBUILD_VECTOR(SDNode *, EVT);
     SDValue BuildSDIV(SDNode *N);
     SDValue BuildSDIVPow2(SDNode *N);
     SDValue BuildUDIV(SDNode *N);
@@ -16424,8 +16423,8 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) {
         TLI.isTypeLegal(VT.getVectorElementType()))) &&
       N0.getOpcode() == ISD::BUILD_VECTOR && N0->hasOneUse() &&
       cast<BuildVectorSDNode>(N0)->isConstant())
-    return ConstantFoldBITCASTofBUILD_VECTOR(N0.getNode(),
-                                             VT.getVectorElementType());
+    return DAG.FoldConstantBuildVector(cast<BuildVectorSDNode>(N0), SDLoc(N),
+                                       VT.getVectorElementType());
 
   // If the input is a constant, let getNode fold it.
   if (isIntOrFPConstant(N0)) {
@@ -16818,83 +16817,6 @@ SDValue DAGCombiner::visitFREEZE(SDNode *N) {
   return DAG.getNode(N0.getOpcode(), DL, N0->getVTList(), Ops, SafeFlags);
 }
 
-/// We know that BV is a build_vector node with Constant, ConstantFP or Undef
-/// operands. DstEltVT indicates the destination element value type.
-SDValue DAGCombiner::
-ConstantFoldBITCASTofBUILD_VECTOR(SDNode *BV, EVT DstEltVT) {
-  EVT SrcEltVT = BV->getValueType(0).getVectorElementType();
-
-  // If this is already the right type, we're done.
-  if (SrcEltVT == DstEltVT) return SDValue(BV, 0);
-
-  unsigned SrcBitSize = SrcEltVT.getSizeInBits();
-  unsigned DstBitSize = DstEltVT.getSizeInBits();
-
-  // If this is a conversion of N elements of one type to N elements of another
-  // type, convert each element.  This handles FP<->INT cases.
-  if (SrcBitSize == DstBitSize) {
-    SmallVector<SDValue, 8> Ops;
-    for (SDValue Op : BV->op_values()) {
-      // If the vector element type is not legal, the BUILD_VECTOR operands
-      // are promoted and implicitly truncated.  Make that explicit here.
-      if (Op.getValueType() != SrcEltVT)
-        Op = DAG.getNode(ISD::TRUNCATE, SDLoc(BV), SrcEltVT, Op);
-      Ops.push_back(DAG.getBitcast(DstEltVT, Op));
-      AddToWorklist(Ops.back().getNode());
-    }
-    EVT VT = EVT::getVectorVT(*DAG.getContext(), DstEltVT,
-                              BV->getValueType(0).getVectorNumElements());
-    return DAG.getBuildVector(VT, SDLoc(BV), Ops);
-  }
-
-  // Otherwise, we're growing or shrinking the elements.  To avoid having to
-  // handle annoying details of growing/shrinking FP values, we convert them to
-  // int first.
-  if (SrcEltVT.isFloatingPoint()) {
-    // Convert the input float vector to a int vector where the elements are the
-    // same sizes.
-    EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), SrcEltVT.getSizeInBits());
-    BV = ConstantFoldBITCASTofBUILD_VECTOR(BV, IntVT).getNode();
-    SrcEltVT = IntVT;
-  }
-
-  // Now we know the input is an integer vector.  If the output is a FP type,
-  // convert to integer first, then to FP of the right size.
-  if (DstEltVT.isFloatingPoint()) {
-    EVT TmpVT = EVT::getIntegerVT(*DAG.getContext(), DstEltVT.getSizeInBits());
-    SDNode *Tmp = ConstantFoldBITCASTofBUILD_VECTOR(BV, TmpVT).getNode();
-
-    // Next, convert to FP elements of the same size.
-    return ConstantFoldBITCASTofBUILD_VECTOR(Tmp, DstEltVT);
-  }
-
-  // Okay, we know the src/dst types are both integers of differing types.
-  assert(SrcEltVT.isInteger() && DstEltVT.isInteger());
-
-  // TODO: Should ConstantFoldBITCASTofBUILD_VECTOR always take a
-  // BuildVectorSDNode?
-  auto *BVN = cast<BuildVectorSDNode>(BV);
-
-  // Extract the constant raw bit data.
-  BitVector UndefElements;
-  SmallVector<APInt> RawBits;
-  bool IsLE = DAG.getDataLayout().isLittleEndian();
-  if (!BVN->getConstantRawBits(IsLE, DstBitSize, RawBits, UndefElements))
-    return SDValue();
-
-  SDLoc DL(BV);
-  SmallVector<SDValue, 8> Ops;
-  for (unsigned I = 0, E = RawBits.size(); I != E; ++I) {
-    if (UndefElements[I])
-      Ops.push_back(DAG.getUNDEF(DstEltVT));
-    else
-      Ops.push_back(DAG.getConstant(RawBits[I], DL, DstEltVT));
-  }
-
-  EVT VT = EVT::getVectorVT(*DAG.getContext(), DstEltVT, Ops.size());
-  return DAG.getBuildVector(VT, DL, Ops);
-}
-
 // Returns true if floating point contraction is allowed on the FMUL-SDValue
 // `N`
 static bool isContractableFMUL(const TargetOptions &Options, SDValue N) {
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 2a3c8e2b011ad..388487c598e3e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -7280,6 +7280,77 @@ SDValue SelectionDAG::foldConstantFPMath(unsigned Opcode, const SDLoc &DL,
   return SDValue();
 }
 
+SDValue SelectionDAG::FoldConstantBuildVector(BuildVectorSDNode *BV,
+                                              const SDLoc &DL, EVT DstEltVT) {
+  EVT SrcEltVT = BV->getValueType(0).getVectorElementType();
+
+  // If this is already the right type, we're done.
+  if (SrcEltVT == DstEltVT)
+    return SDValue(BV, 0);
+
+  unsigned SrcBitSize = SrcEltVT.getSizeInBits();
+  unsigned DstBitSize = DstEltVT.getSizeInBits();
+
+  // If this is a conversion of N elements of one type to N elements of another
+  // type, convert each element.  This handles FP<->INT cases.
+  if (SrcBitSize == DstBitSize) {
+    SmallVector<SDValue, 8> Ops;
+    for (SDValue Op : BV->op_values()) {
+      // If the vector element type is not legal, the BUILD_VECTOR operands
+      // are promoted and implicitly truncated.  Make that explicit here.
+      if (Op.getValueType() != SrcEltVT)
+        Op = getNode(ISD::TRUNCATE, DL, SrcEltVT, Op);
+      Ops.push_back(getBitcast(DstEltVT, Op));
+    }
+    EVT VT = EVT::getVectorVT(*getContext(), DstEltVT,
+                              BV->getValueType(0).getVectorNumElements());
+    return getBuildVector(VT, DL, Ops);
+  }
+
+  // Otherwise, we're growing or shrinking the elements.  To avoid having to
+  // handle annoying details of growing/shrinking FP values, we convert them to
+  // int first.
+  if (SrcEltVT.isFloatingPoint()) {
+    // Convert the input float vector to a int vector where the elements are the
+    // same sizes.
+    EVT IntEltVT = EVT::getIntegerVT(*getContext(), SrcEltVT.getSizeInBits());
+    if (SDValue Tmp = FoldConstantBuildVector(BV, DL, IntEltVT))
+      return FoldConstantBuildVector(cast<BuildVectorSDNode>(Tmp), DL,
+                                     DstEltVT);
+    return SDValue();
+  }
+
+  // Now we know the input is an integer vector.  If the output is a FP type,
+  // convert to integer first, then to FP of the right size.
+  if (DstEltVT.isFloatingPoint()) {
+    EVT IntEltVT = EVT::getIntegerVT(*getContext(), DstEltVT.getSizeInBits());
+    if (SDValue Tmp = FoldConstantBuildVector(BV, DL, IntEltVT))
+      return FoldConstantBuildVector(cast<BuildVectorSDNode>(Tmp), DL,
+                                     DstEltVT);
+  }
+
+  // Okay, we know the src/dst types are both integers of differing types.
+  assert(SrcEltVT.isInteger() && DstEltVT.isInteger());
+
+  // Extract the constant raw bit data.
+  BitVector UndefElements;
+  SmallVector<APInt> RawBits;
+  bool IsLE = getDataLayout().isLittleEndian();
+  if (!BV->getConstantRawBits(IsLE, DstBitSize, RawBits, UndefElements))
+    return SDValue();
+
+  SmallVector<SDValue, 8> Ops;
+  for (unsigned I = 0, E = RawBits.size(); I != E; ++I) {
+    if (UndefElements[I])
+      Ops.push_back(getUNDEF(DstEltVT));
+    else
+      Ops.push_back(getConstant(RawBits[I], DL, DstEltVT));
+  }
+
+  EVT VT = EVT::getVectorVT(*getContext(), DstEltVT, Ops.size());
+  return getBuildVector(VT, DL, Ops);
+}
+
 SDValue SelectionDAG::getAssertAlign(const SDLoc &DL, SDValue Val, Align A) {
   assert(Val.getValueType().isInteger() && "Invalid AssertAlign!");
 
diff --git a/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.iglp.AFLCustomIRMutator.opt.ll b/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.iglp.AFLCustomIRMutator.opt.ll
index a319f1260d870..85dd275207902 100644
--- a/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.iglp.AFLCustomIRMutator.opt.ll
+++ b/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.iglp.AFLCustomIRMutator.opt.ll
@@ -5,18 +5,19 @@ define amdgpu_kernel void @test_iglp_opt_rev_mfma_gemm(<1 x i64> %L1) {
 ; GCN-LABEL: test_iglp_opt_rev_mfma_gemm:
 ; GCN:       ; %bb.0: ; %entry
 ; GCN-NEXT:    v_mov_b32_e32 v32, 0
+; GCN-NEXT:    ds_read_b128 v[0:3], v32
 ; GCN-NEXT:    s_load_dwordx2 s[0:1], s[8:9], 0x0
 ; GCN-NEXT:    ds_read_b128 v[28:31], v32 offset:112
 ; GCN-NEXT:    ds_read_b128 v[24:27], v32 offset:96
 ; GCN-NEXT:    ds_read_b128 v[20:23], v32 offset:80
 ; GCN-NEXT:    ds_read_b128 v[16:19], v32 offset:64
-; GCN-NEXT:    ds_read_b128 v[0:3], v32
 ; GCN-NEXT:    ds_read_b128 v[4:7], v32 offset:16
 ; GCN-NEXT:    ds_read_b128 v[8:11], v32 offset:32
 ; GCN-NEXT:    ds_read_b128 v[12:15], v32 offset:48
-; GCN-NEXT:    v_mov_b32_e32 v34, 0
-; GCN-NEXT:    v_mov_b32_e32 v35, v34
 ; GCN-NEXT:    s_waitcnt lgkmcnt(0)
+; GCN-NEXT:    ds_write_b128 v32, v[0:3]
+; GCN-NEXT:    v_mov_b32_e32 v0, 0
+; GCN-NEXT:    v_mov_b32_e32 v1, v0
 ; GCN-NEXT:    s_cmp_lg_u64 s[0:1], 0
 ; GCN-NEXT:    ; iglp_opt mask(0x00000001)
 ; GCN-NEXT:    ds_write_b128 v32, v[28:31] offset:112
@@ -26,8 +27,7 @@ define amdgpu_kernel void @test_iglp_opt_rev_mfma_gemm(<1 x i64> %L1) {
 ; GCN-NEXT:    ds_write_b128 v32, v[12:15] offset:48
 ; GCN-NEXT:    ds_write_b128 v32, v[8:11] offset:32
 ; GCN-NEXT:    ds_write_b128 v32, v[4:7] offset:16
-; GCN-NEXT:    ds_write_b128 v32, v[0:3]
-; GCN-NEXT:    ds_write_b64 v32, v[34:35]
+; GCN-NEXT:    ds_write_b64 v32, v[0:1]
 ; GCN-NEXT:    s_endpgm
 entry:
   call void @llvm.amdgcn.iglp.opt(i32 1)



More information about the llvm-commits mailing list