[llvm] [DAGCombiner] Require same type of splat & element for build_vector (PR #88284)

via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 10 08:37:39 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-selectiondag

Author: Feng Zou (fzou1)

<details>
<summary>Changes</summary>

Only allow changing a build_vector into a concat_vectors when the splat value type is the same as the vector element type. This fixes an assertion failure caused by attempting to bitcast between types of different sizes.

---
Full diff: https://github.com/llvm/llvm-project/pull/88284.diff


2 Files Affected:

- (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+16-11) 
- (added) llvm/test/CodeGen/X86/buildvec-bitcast.ll (+24) 


``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 8fe074666a3dc9..ee90ca8eaa7d7f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -23429,17 +23429,22 @@ SDValue DAGCombiner::visitBUILD_VECTOR(SDNode *N) {
   // TODO: Maybe this is useful for non-splat too?
   if (!LegalOperations) {
     if (SDValue Splat = cast<BuildVectorSDNode>(N)->getSplatValue()) {
-      Splat = peekThroughBitcasts(Splat);
-      EVT SrcVT = Splat.getValueType();
-      if (SrcVT.isVector()) {
-        unsigned NumElts = N->getNumOperands() * SrcVT.getVectorNumElements();
-        EVT NewVT = EVT::getVectorVT(*DAG.getContext(),
-                                     SrcVT.getVectorElementType(), NumElts);
-        if (!LegalTypes || TLI.isTypeLegal(NewVT)) {
-          SmallVector<SDValue, 8> Ops(N->getNumOperands(), Splat);
-          SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N),
-                                       NewVT, Ops);
-          return DAG.getBitcast(VT, Concat);
+      EVT SplatVT = Splat.getValueType();
+      // Only change build_vector to a concat_vector if the splat value type is
+      // same as the vector element type.
+      if (SplatVT == VT.getVectorElementType()) {
+        Splat = peekThroughBitcasts(Splat);
+        EVT SrcVT = Splat.getValueType();
+        if (SrcVT.isVector()) {
+          unsigned NumElts = N->getNumOperands() * SrcVT.getVectorNumElements();
+          EVT NewVT = EVT::getVectorVT(*DAG.getContext(),
+                                       SrcVT.getVectorElementType(), NumElts);
+          if (!LegalTypes || TLI.isTypeLegal(NewVT)) {
+            SmallVector<SDValue, 8> Ops(N->getNumOperands(), Splat);
+            SDValue Concat =
+                DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), NewVT, Ops);
+            return DAG.getBitcast(VT, Concat);
+          }
         }
       }
     }
diff --git a/llvm/test/CodeGen/X86/buildvec-bitcast.ll b/llvm/test/CodeGen/X86/buildvec-bitcast.ll
new file mode 100644
index 00000000000000..9f7167f2946bf9
--- /dev/null
+++ b/llvm/test/CodeGen/X86/buildvec-bitcast.ll
@@ -0,0 +1,24 @@
+; RUN: llc < %s -mtriple=x86_64 -mattr=avx512bw | FileCheck %s
+
+; Verify that the DAGCombiner doesn't change build_vector to concat_vectors if
+; the vector element type is different than splat type. The example here:
+;   v8i1 = build_vector (i8 (bitcast (v8i1 X))), ..., (i8 (bitcast (v8i1 X))))
+
+; CHECK:      foo:
+; CHECK:      # %bb.0: # %entry
+; CHECK-NEXT: retq
+
+define void @foo(<8 x i1> %mask.i1) {
+entry:
+  %0 = and <8 x i1> %mask.i1, <i1 true, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false>
+  %1 = bitcast <8 x i1> %0 to i8
+  %2 = icmp ne i8 %1, 0
+  %insert54 = insertelement <8 x i1> zeroinitializer, i1 %2, i64 0
+  %splat55 = shufflevector <8 x i1> %insert54, <8 x i1> zeroinitializer, <8 x i32> zeroinitializer
+  %3 = and <8 x i1> %0, %splat55
+  br label %end
+
+end:                           ; preds = %entry
+  %4 = select <8 x i1> %3, <8 x i1> zeroinitializer, <8 x i1> zeroinitializer
+  ret void
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/88284


More information about the llvm-commits mailing list