[llvm] [DAG] Extend input types if needed in combineShiftToAVG. (PR #76791)

via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 3 00:50:05 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-selectiondag

Author: David Green (davemgreen)

<details>
<summary>Changes</summary>

This attempts to fix #76734, a crash in combineShiftToAVG caused by creating TRUNCATE nodes with invalid types from unoptimized input code. NVT can be VT if the larger type was legal and the adds will not overflow, in which case the inputs need to be extended rather than truncated.

From what I can tell, this appears to be valid (if not optimal for this case): https://alive2.llvm.org/ce/z/fRieHR

The result has also been changed to use getExtOrTrunc, to handle the case where VT == NVT, which is not handled by SEXT/ZEXT.

---
Full diff: https://github.com/llvm/llvm-project/pull/76791.diff


2 Files Affected:

- (modified) llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp (+3-4) 
- (modified) llvm/test/CodeGen/AArch64/arm64-vhadd.ll (+18) 


``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 4581bb19e97ec3..66cdd752590875 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -1064,10 +1064,9 @@ static SDValue combineShiftToAVG(SDValue Op, SelectionDAG &DAG,
 
   SDLoc DL(Op);
   SDValue ResultAVG =
-      DAG.getNode(AVGOpc, DL, NVT, DAG.getNode(ISD::TRUNCATE, DL, NVT, ExtOpA),
-                  DAG.getNode(ISD::TRUNCATE, DL, NVT, ExtOpB));
-  return DAG.getNode(IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND, DL, VT,
-                     ResultAVG);
+      DAG.getNode(AVGOpc, DL, NVT, DAG.getExtOrTrunc(IsSigned, ExtOpA, DL, NVT),
+                  DAG.getExtOrTrunc(IsSigned, ExtOpB, DL, NVT));
+  return DAG.getExtOrTrunc(IsSigned, ResultAVG, DL, VT);
 }
 
 /// Look at Op. At this point, we know that only the OriginalDemandedBits of the
diff --git a/llvm/test/CodeGen/AArch64/arm64-vhadd.ll b/llvm/test/CodeGen/AArch64/arm64-vhadd.ll
index e287eff5abb946..2224d46705253b 100644
--- a/llvm/test/CodeGen/AArch64/arm64-vhadd.ll
+++ b/llvm/test/CodeGen/AArch64/arm64-vhadd.ll
@@ -1392,6 +1392,24 @@ define <8 x i8> @sextmask3v8i8(<8 x i16> %src1, <8 x i8> %src2) {
   ret <8 x i8> %result
 }
 
+define <4 x i16> @ext_via_i19(i64 %0, <4 x i16> %1) {
+; CHECK-LABEL: ext_via_i19:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    movi.4s v1, #1
+; CHECK-NEXT:    uaddw.4s v0, v1, v0
+; CHECK-NEXT:    uhadd.4s v0, v0, v1
+; CHECK-NEXT:    xtn.4h v0, v0
+; CHECK-NEXT:    ret
+  %3 = zext <4 x i16> %1 to <4 x i32>
+  %4 = add <4 x i32> %3, <i32 1, i32 1, i32 1, i32 1>
+  %5 = trunc <4 x i32> %4 to <4 x i19>
+  %new0 = add <4 x i19> %5, <i19 1, i19 1, i19 1, i19 1>
+  %new1 = lshr <4 x i19> %new0, <i19 1, i19 1, i19 1, i19 1>
+  %last = zext <4 x i19> %new1 to <4 x i32>
+  %6 = trunc <4 x i32> %last to <4 x i16>
+  ret <4 x i16> %6
+}
+
 declare <8 x i8> @llvm.aarch64.neon.srhadd.v8i8(<8 x i8>, <8 x i8>)
 declare <4 x i16> @llvm.aarch64.neon.srhadd.v4i16(<4 x i16>, <4 x i16>)
 declare <2 x i32> @llvm.aarch64.neon.srhadd.v2i32(<2 x i32>, <2 x i32>)

``````````

</details>


https://github.com/llvm/llvm-project/pull/76791


More information about the llvm-commits mailing list