[llvm-branch-commits] [llvm] release/19.x: [SDAG][ISel][TableGen][LoongArch] Report error for trivial bitcasts when there are predicate calls (#116075) (PR #116797)

Tobias Hieta via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Nov 25 00:36:48 PST 2024


https://github.com/tru updated https://github.com/llvm/llvm-project/pull/116797

>From ea960400213ad6a619d40536ae8965fca70eac89 Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Sat, 16 Nov 2024 20:55:33 -0800
Subject: [PATCH 1/2] [Mips] Change vsplat_imm_eq_1 to a ComplexPattern.
 (#116471)

Resolves a FIXME and avoids needing to work around #116075.

Wrapping each use in parentheses, i.e. writing (vsplat_imm_eq_1), fixes
the error cited in the FIXME by changing the ComplexPattern from a leaf
node to an operator.

(cherry picked from commit 2f4572f5e7e2d7f4626e825404c11f07d191fb05)
---
 llvm/lib/Target/Mips/MipsISelDAGToDAG.cpp   |  4 ++
 llvm/lib/Target/Mips/MipsISelDAGToDAG.h     |  3 ++
 llvm/lib/Target/Mips/MipsMSAInstrInfo.td    | 59 ++++++++-------------
 llvm/lib/Target/Mips/MipsSEISelDAGToDAG.cpp | 12 +++++
 llvm/lib/Target/Mips/MipsSEISelDAGToDAG.h   |  3 ++
 5 files changed, 43 insertions(+), 38 deletions(-)

diff --git a/llvm/lib/Target/Mips/MipsISelDAGToDAG.cpp b/llvm/lib/Target/Mips/MipsISelDAGToDAG.cpp
index f6f32fde3b7778..a9ffd2bedf21e0 100644
--- a/llvm/lib/Target/Mips/MipsISelDAGToDAG.cpp
+++ b/llvm/lib/Target/Mips/MipsISelDAGToDAG.cpp
@@ -220,6 +220,10 @@ bool MipsDAGToDAGISel::selectVSplatMaskR(SDValue N, SDValue &Imm) const {
   return false;
 }
 
+bool MipsDAGToDAGISel::selectVSplatImmEq1(SDValue N) const {
+  llvm_unreachable("Unimplemented function.");
+}
+
 /// Convert vector addition with vector subtraction if that allows to encode
 /// constant as an immediate and thus avoid extra 'ldi' instruction.
 /// add X, <-1, -1...> --> sub X, <1, 1...>
diff --git a/llvm/lib/Target/Mips/MipsISelDAGToDAG.h b/llvm/lib/Target/Mips/MipsISelDAGToDAG.h
index 6135f968078542..3485300a782c94 100644
--- a/llvm/lib/Target/Mips/MipsISelDAGToDAG.h
+++ b/llvm/lib/Target/Mips/MipsISelDAGToDAG.h
@@ -120,6 +120,9 @@ class MipsDAGToDAGISel : public SelectionDAGISel {
   /// starting at bit zero.
   virtual bool selectVSplatMaskR(SDValue N, SDValue &Imm) const;
 
+  /// Select constant vector splats whose value is 1.
+  virtual bool selectVSplatImmEq1(SDValue N) const;
+
   /// Convert vector addition with vector subtraction if that allows to encode
   /// constant as an immediate and thus avoid extra 'ldi' instruction.
   /// add X, <-1, -1...> --> sub X, <1, 1...>
diff --git a/llvm/lib/Target/Mips/MipsMSAInstrInfo.td b/llvm/lib/Target/Mips/MipsMSAInstrInfo.td
index c4abccb24c6f35..f4c32c9dcd4212 100644
--- a/llvm/lib/Target/Mips/MipsMSAInstrInfo.td
+++ b/llvm/lib/Target/Mips/MipsMSAInstrInfo.td
@@ -198,14 +198,8 @@ def vsplati32 : PatFrag<(ops node:$e0),
                         (v4i32 (build_vector node:$e0, node:$e0,
                                              node:$e0, node:$e0))>;
 
-def vsplati64_imm_eq_1 : PatLeaf<(bitconvert (v4i32 (build_vector))), [{
-  APInt Imm;
-  SDNode *BV = N->getOperand(0).getNode();
-  EVT EltTy = N->getValueType(0).getVectorElementType();
-
-  return selectVSplat(BV, Imm, EltTy.getSizeInBits()) &&
-         Imm.getBitWidth() == EltTy.getSizeInBits() && Imm == 1;
-}]>;
+// Any build_vector that is a constant splat with a value that equals 1
+def vsplat_imm_eq_1 : ComplexPattern<vAny, 0, "selectVSplatImmEq1">;
 
 def vsplati64 : PatFrag<(ops node:$e0),
                         (v2i64 (build_vector node:$e0, node:$e0))>;
@@ -217,7 +211,7 @@ def vsplati64_splat_d : PatFrag<(ops node:$e0),
                                                                 node:$e0,
                                                                 node:$e0,
                                                                 node:$e0)),
-                                           vsplati64_imm_eq_1))))>;
+                                           (vsplat_imm_eq_1)))))>;
 
 def vsplatf32 : PatFrag<(ops node:$e0),
                         (v4f32 (build_vector node:$e0, node:$e0,
@@ -352,46 +346,35 @@ def vsplat_maskr_bits_uimm6
     : SplatComplexPattern<vsplat_uimm6, vAny, 1, "selectVSplatMaskR",
                           [build_vector, bitconvert]>;
 
-// Any build_vector that is a constant splat with a value that equals 1
-// FIXME: These should be a ComplexPattern but we can't use them because the
-//        ISel generator requires the uses to have a name, but providing a name
-//        causes other errors ("used in pattern but not operand list")
-def vsplat_imm_eq_1 : PatLeaf<(build_vector), [{
-  APInt Imm;
-  EVT EltTy = N->getValueType(0).getVectorElementType();
-
-  return selectVSplat(N, Imm, EltTy.getSizeInBits()) &&
-         Imm.getBitWidth() == EltTy.getSizeInBits() && Imm == 1;
-}]>;
 
 def vbclr_b : PatFrag<(ops node:$ws, node:$wt),
-                      (and node:$ws, (vnot (shl vsplat_imm_eq_1, node:$wt)))>;
+                      (and node:$ws, (vnot (shl (vsplat_imm_eq_1), node:$wt)))>;
 def vbclr_h : PatFrag<(ops node:$ws, node:$wt),
-                      (and node:$ws, (vnot (shl vsplat_imm_eq_1, node:$wt)))>;
+                      (and node:$ws, (vnot (shl (vsplat_imm_eq_1), node:$wt)))>;
 def vbclr_w : PatFrag<(ops node:$ws, node:$wt),
-                      (and node:$ws, (vnot (shl vsplat_imm_eq_1, node:$wt)))>;
+                      (and node:$ws, (vnot (shl (vsplat_imm_eq_1), node:$wt)))>;
 def vbclr_d : PatFrag<(ops node:$ws, node:$wt),
-                      (and node:$ws, (vnot (shl (v2i64 vsplati64_imm_eq_1),
+                      (and node:$ws, (vnot (shl (v2i64 (vsplat_imm_eq_1)),
                                                node:$wt)))>;
 
 def vbneg_b : PatFrag<(ops node:$ws, node:$wt),
-                      (xor node:$ws, (shl vsplat_imm_eq_1, node:$wt))>;
+                      (xor node:$ws, (shl (vsplat_imm_eq_1), node:$wt))>;
 def vbneg_h : PatFrag<(ops node:$ws, node:$wt),
-                      (xor node:$ws, (shl vsplat_imm_eq_1, node:$wt))>;
+                      (xor node:$ws, (shl (vsplat_imm_eq_1), node:$wt))>;
 def vbneg_w : PatFrag<(ops node:$ws, node:$wt),
-                      (xor node:$ws, (shl vsplat_imm_eq_1, node:$wt))>;
+                      (xor node:$ws, (shl (vsplat_imm_eq_1), node:$wt))>;
 def vbneg_d : PatFrag<(ops node:$ws, node:$wt),
-                      (xor node:$ws, (shl (v2i64 vsplati64_imm_eq_1),
+                      (xor node:$ws, (shl (v2i64 (vsplat_imm_eq_1)),
                                           node:$wt))>;
 
 def vbset_b : PatFrag<(ops node:$ws, node:$wt),
-                      (or node:$ws, (shl vsplat_imm_eq_1, node:$wt))>;
+                      (or node:$ws, (shl (vsplat_imm_eq_1), node:$wt))>;
 def vbset_h : PatFrag<(ops node:$ws, node:$wt),
-                      (or node:$ws, (shl vsplat_imm_eq_1, node:$wt))>;
+                      (or node:$ws, (shl (vsplat_imm_eq_1), node:$wt))>;
 def vbset_w : PatFrag<(ops node:$ws, node:$wt),
-                      (or node:$ws, (shl vsplat_imm_eq_1, node:$wt))>;
+                      (or node:$ws, (shl (vsplat_imm_eq_1), node:$wt))>;
 def vbset_d : PatFrag<(ops node:$ws, node:$wt),
-                      (or node:$ws, (shl (v2i64 vsplati64_imm_eq_1),
+                      (or node:$ws, (shl (v2i64 (vsplat_imm_eq_1)),
                                          node:$wt))>;
 
 def muladd : PatFrag<(ops node:$wd, node:$ws, node:$wt),
@@ -3842,7 +3825,7 @@ class MSAShiftPat<SDNode Node, ValueType VT, MSAInst Insn, dag Vec> :
          (VT (Insn VT:$ws, VT:$wt))>;
 
 class MSABitPat<SDNode Node, ValueType VT, MSAInst Insn, PatFrag Frag> :
-  MSAPat<(VT (Node VT:$ws, (shl vsplat_imm_eq_1, (Frag VT:$wt)))),
+  MSAPat<(VT (Node VT:$ws, (shl (vsplat_imm_eq_1), (Frag VT:$wt)))),
          (VT (Insn VT:$ws, VT:$wt))>;
 
 multiclass MSAShiftPats<SDNode Node, string Insn> {
@@ -3861,7 +3844,7 @@ multiclass MSABitPats<SDNode Node, string Insn> {
   def : MSABitPat<Node, v16i8, !cast<MSAInst>(Insn#_B), vsplati8imm7>;
   def : MSABitPat<Node, v8i16, !cast<MSAInst>(Insn#_H), vsplati16imm15>;
   def : MSABitPat<Node, v4i32, !cast<MSAInst>(Insn#_W), vsplati32imm31>;
-  def : MSAPat<(Node v2i64:$ws, (shl (v2i64 vsplati64_imm_eq_1),
+  def : MSAPat<(Node v2i64:$ws, (shl (v2i64 (vsplat_imm_eq_1)),
                                      (vsplati64imm63 v2i64:$wt))),
                (v2i64 (!cast<MSAInst>(Insn#_D) v2i64:$ws, v2i64:$wt))>;
 }
@@ -3872,16 +3855,16 @@ defm : MSAShiftPats<sra, "SRA">;
 defm : MSABitPats<xor, "BNEG">;
 defm : MSABitPats<or, "BSET">;
 
-def : MSAPat<(and v16i8:$ws, (vnot (shl vsplat_imm_eq_1,
+def : MSAPat<(and v16i8:$ws, (vnot (shl (vsplat_imm_eq_1),
                                         (vsplati8imm7 v16i8:$wt)))),
              (v16i8 (BCLR_B v16i8:$ws, v16i8:$wt))>;
-def : MSAPat<(and v8i16:$ws, (vnot (shl vsplat_imm_eq_1,
+def : MSAPat<(and v8i16:$ws, (vnot (shl (vsplat_imm_eq_1),
                                         (vsplati16imm15 v8i16:$wt)))),
              (v8i16 (BCLR_H v8i16:$ws, v8i16:$wt))>;
-def : MSAPat<(and v4i32:$ws, (vnot (shl vsplat_imm_eq_1,
+def : MSAPat<(and v4i32:$ws, (vnot (shl (vsplat_imm_eq_1),
                                         (vsplati32imm31 v4i32:$wt)))),
              (v4i32 (BCLR_W v4i32:$ws, v4i32:$wt))>;
-def : MSAPat<(and v2i64:$ws, (vnot (shl (v2i64 vsplati64_imm_eq_1),
+def : MSAPat<(and v2i64:$ws, (vnot (shl (v2i64 (vsplat_imm_eq_1)),
                                         (vsplati64imm63 v2i64:$wt)))),
              (v2i64 (BCLR_D v2i64:$ws, v2i64:$wt))>;
 
diff --git a/llvm/lib/Target/Mips/MipsSEISelDAGToDAG.cpp b/llvm/lib/Target/Mips/MipsSEISelDAGToDAG.cpp
index 7ad300c6cccd45..66c034a889c600 100644
--- a/llvm/lib/Target/Mips/MipsSEISelDAGToDAG.cpp
+++ b/llvm/lib/Target/Mips/MipsSEISelDAGToDAG.cpp
@@ -730,6 +730,18 @@ bool MipsSEDAGToDAGISel::selectVSplatUimmInvPow2(SDValue N,
   return false;
 }
 
+// Select const vector splat of 1.
+bool MipsSEDAGToDAGISel::selectVSplatImmEq1(SDValue N) const {
+  APInt ImmValue;
+  EVT EltTy = N->getValueType(0).getVectorElementType();
+
+  if (N->getOpcode() == ISD::BITCAST)
+    N = N->getOperand(0);
+
+  return selectVSplat(N.getNode(), ImmValue, EltTy.getSizeInBits()) &&
+         ImmValue.getBitWidth() == EltTy.getSizeInBits() && ImmValue == 1;
+}
+
 bool MipsSEDAGToDAGISel::trySelect(SDNode *Node) {
   unsigned Opcode = Node->getOpcode();
   SDLoc DL(Node);
diff --git a/llvm/lib/Target/Mips/MipsSEISelDAGToDAG.h b/llvm/lib/Target/Mips/MipsSEISelDAGToDAG.h
index 7b843b0e0b2552..22d8e924ac534f 100644
--- a/llvm/lib/Target/Mips/MipsSEISelDAGToDAG.h
+++ b/llvm/lib/Target/Mips/MipsSEISelDAGToDAG.h
@@ -124,6 +124,9 @@ class MipsSEDAGToDAGISel : public MipsDAGToDAGISel {
   /// starting at bit zero.
   bool selectVSplatMaskR(SDValue N, SDValue &Imm) const override;
 
+  /// Select constant vector splats whose value is 1.
+  bool selectVSplatImmEq1(SDValue N) const override;
+
   bool trySelect(SDNode *Node) override;
 
   // Emits proper ABI for _mcount profiling calls.

>From 3d12f45e50b68ac908ef05571e5cc52f4b966d94 Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Tue, 19 Nov 2024 21:24:40 +0800
Subject: [PATCH 2/2] [SDAG][ISel][TableGen][LoongArch] Report error for
 trivial bitcasts when there are predicate calls (#116075)

On loongarch64 with the LSX extension, we select `VBITREV_W` for
`v4i32 (xor X, (shl splat(1), Y))`:

https://github.com/llvm/llvm-project/blob/8e6630391699116641cf390a10476295b7d4b95c/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td#L1583-L1584

And `vsplat_imm_eq_1` is defined as:

https://github.com/llvm/llvm-project/blob/8e6630391699116641cf390a10476295b7d4b95c/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td#L77-L87

For the `(bitconvert (v4i32 (build_vector)))` case, the pattern is
expected to be:
```
PATTERN: (xor:{ *:[v4i32] } v4i32:{ *:[v4i32] }:$vj, (shl:{ *:[v4i32] } (bitconvert:{ *:[v4i32] } (build_vector:{ *:[v4i32] }))<<P:Predicate_vsplat_imm_eq_1>>, v4i32:{ *:[v4i32] }:$vk))
RESULT:  (VBITREV_W:{ *:[v4i32] } v4i32:{ *:[v4i32] }:$vj, v4i32:{ *:[v4i32] }:$vk)
```

However, `SimplifyTree` drops the `bitconvert` node and its predicates:

https://github.com/llvm/llvm-project/blob/8e6630391699116641cf390a10476295b7d4b95c/llvm/utils/TableGen/Common/CodeGenDAGPatterns.cpp#L3036-L3062

As a result, LLVM matches `vsplat_imm_eq_1` for any v4i32 splat, which
causes a miscompilation:
```
PATTERN: (xor:{ *:[v4i32] } v4i32:{ *:[v4i32] }:$vj, (shl:{ *:[v4i32] } (build_vector:{ *:[v4i32] }), v4i32:{ *:[v4i32] }:$vk))
RESULT:  (VBITREV_W:{ *:[v4i32] } v4i32:{ *:[v4i32] }:$vj, v4i32:{ *:[v4i32] }:$vk)
```

This patch adds a check to `SimplifyTree` so that TableGen reports a
fatal error when a trivial bitconvert node has predicate calls attached,
instead of silently dropping them. Unused patterns in the LoongArch
target are also removed.

Fixes https://github.com/llvm/llvm-project/issues/116008.

(cherry picked from commit c727b48287cc96888f9e262f23d53cf635cf3b3d)
---
 .../Target/LoongArch/LoongArchLSXInstrInfo.td   |  6 ++----
 llvm/test/CodeGen/LoongArch/lsx/pr116008.ll     | 17 +++++++++++++++++
 .../TableGen/Common/CodeGenDAGPatterns.cpp      |  8 ++++++++
 3 files changed, 27 insertions(+), 4 deletions(-)
 create mode 100644 llvm/test/CodeGen/LoongArch/lsx/pr116008.ll

diff --git a/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
index 0580683c3ce303..0233baecf6dd9c 100644
--- a/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
+++ b/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
@@ -67,8 +67,7 @@ class VecCond<SDPatternOperator OpNode, ValueType TyNode,
   let usesCustomInserter = 1;
 }
 
-def vsplat_imm_eq_1 : PatFrags<(ops), [(build_vector),
-                                       (bitconvert (v4i32 (build_vector)))], [{
+def vsplat_imm_eq_1 : PatFrags<(ops), [(build_vector)], [{
   APInt Imm;
   EVT EltTy = N->getValueType(0).getVectorElementType();
 
@@ -109,8 +108,7 @@ def vsplati32_imm_eq_31 : PatFrags<(ops), [(build_vector)], [{
   return selectVSplat(N, Imm, EltTy.getSizeInBits()) &&
          Imm.getBitWidth() == EltTy.getSizeInBits() && Imm == 31;
 }]>;
-def vsplati64_imm_eq_63 : PatFrags<(ops), [(build_vector),
-                                           (bitconvert (v4i32 (build_vector)))], [{
+def vsplati64_imm_eq_63 : PatFrags<(ops), [(build_vector)], [{
   APInt Imm;
   EVT EltTy = N->getValueType(0).getVectorElementType();
 
diff --git a/llvm/test/CodeGen/LoongArch/lsx/pr116008.ll b/llvm/test/CodeGen/LoongArch/lsx/pr116008.ll
new file mode 100644
index 00000000000000..ba8ffc34931893
--- /dev/null
+++ b/llvm/test/CodeGen/LoongArch/lsx/pr116008.ll
@@ -0,0 +1,17 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc --mtriple=loongarch64 --mattr=+lsx < %s | FileCheck %s
+
+define <4 x i32> @xor_shl_splat_vec_one(i32 %x, <4 x i32> %y) nounwind {
+; CHECK-LABEL: xor_shl_splat_vec_one:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vreplgr2vr.w $vr1, $a0
+; CHECK-NEXT:    vsll.w $vr0, $vr1, $vr0
+; CHECK-NEXT:    vbitrevi.w $vr0, $vr0, 0
+; CHECK-NEXT:    ret
+entry:
+  %ins = insertelement <4 x i32> poison, i32 %x, i64 0
+  %splat = shufflevector <4 x i32> %ins, <4 x i32> poison, <4 x i32> zeroinitializer
+  %shl = shl <4 x i32> %splat, %y
+  %xor = xor <4 x i32> %shl, splat (i32 1)
+  ret <4 x i32> %xor
+}
diff --git a/llvm/utils/TableGen/Common/CodeGenDAGPatterns.cpp b/llvm/utils/TableGen/Common/CodeGenDAGPatterns.cpp
index a8cecca0d4a54f..ca71569008d5ec 100644
--- a/llvm/utils/TableGen/Common/CodeGenDAGPatterns.cpp
+++ b/llvm/utils/TableGen/Common/CodeGenDAGPatterns.cpp
@@ -3042,6 +3042,14 @@ static bool SimplifyTree(TreePatternNodePtr &N) {
       !N->getExtType(0).empty() &&
       N->getExtType(0) == N->getChild(0).getExtType(0) &&
       N->getName().empty()) {
+    if (!N->getPredicateCalls().empty()) {
+      std::string Str;
+      raw_string_ostream OS(Str);
+      OS << *N
+         << "\n trivial bitconvert node should not have predicate calls\n";
+      PrintFatalError(Str);
+      return false;
+    }
     N = N->getChildShared(0);
     SimplifyTree(N);
     return true;



More information about the llvm-branch-commits mailing list