[llvm] caa124b - [InstCombine] Zero-extend shift amounts in narrow funnel shift ops

Antonio Frighetto via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 7 05:16:20 PST 2023


Author: Antonio Frighetto
Date: 2023-11-07T14:15:32+01:00
New Revision: caa124b58d9f59b36ec751c98ca1bd1d1ceb0a73

URL: https://github.com/llvm/llvm-project/commit/caa124b58d9f59b36ec751c98ca1bd1d1ceb0a73
DIFF: https://github.com/llvm/llvm-project/commit/caa124b58d9f59b36ec751c98ca1bd1d1ceb0a73.diff

LOG: [InstCombine] Zero-extend shift amounts in narrow funnel shift ops

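An issue arose in the handling of shift amounts when simplifying a
wide rotate/funnel-shift pattern into a narrowed funnel shift
intrinsic. Specifically, the shift amount was always truncated to
the destination type, even when its type was already narrower than
the destination bit width; a `trunc` must strictly narrow its
operand, so this was invalid. This has been addressed by
zero-extending `ShAmt` in such cases (and still truncating it
otherwise).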

Fixes: https://github.com/llvm/llvm-project/issues/71463.

Proof: https://alive2.llvm.org/ce/z/5draKz.

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
    llvm/test/Transforms/InstCombine/rotate.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index efd18b44657e5da..46ef17d0e628276 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -502,11 +502,20 @@ Instruction *InstCombinerImpl::narrowFunnelShift(TruncInst &Trunc) {
   if (!MaskedValueIsZero(ShVal1, HiBitMask, 0, &Trunc))
     return nullptr;
 
-  // We have an unnecessarily wide rotate!
-  // trunc (or (shl ShVal0, ShAmt), (lshr ShVal1, BitWidth - ShAmt))
-  // Narrow the inputs and convert to funnel shift intrinsic:
-  // llvm.fshl.i8(trunc(ShVal), trunc(ShVal), trunc(ShAmt))
-  Value *NarrowShAmt = Builder.CreateTrunc(ShAmt, DestTy);
+  // Adjust the width of ShAmt for narrowed funnel shift operation:
+  // - Zero-extend if ShAmt is narrower than the destination type.
+  // - Truncate if ShAmt is wider, discarding non-significant high-order bits.
+  // This prepares ShAmt for llvm.fshl.i8(trunc(ShVal), trunc(ShVal),
+  // zext/trunc(ShAmt)).
+  Value *NarrowShAmt;
+  if (ShAmt->getType()->getScalarSizeInBits() < NarrowWidth) {
+    // If ShAmt is narrower than the destination type, zero-extend it.
+    NarrowShAmt = Builder.CreateZExt(ShAmt, DestTy, "shamt.zext");
+  } else {
+    // If ShAmt is wider than the destination type, truncate it.
+    NarrowShAmt = Builder.CreateTrunc(ShAmt, DestTy, "shamt.trunc");
+  }
+
   Value *X, *Y;
   X = Y = Builder.CreateTrunc(ShVal0, DestTy);
   if (ShVal0 != ShVal1)

diff  --git a/llvm/test/Transforms/InstCombine/rotate.ll b/llvm/test/Transforms/InstCombine/rotate.ll
index ed5145255b2f072..1668da9d20c12a7 100644
--- a/llvm/test/Transforms/InstCombine/rotate.ll
+++ b/llvm/test/Transforms/InstCombine/rotate.ll
@@ -421,8 +421,8 @@ define <2 x i16> @rotate_left_commute_16bit_vec(<2 x i16> %v, <2 x i32> %shift)
 
 define i8 @rotate_right_8bit(i8 %v, i3 %shift) {
 ; CHECK-LABEL: @rotate_right_8bit(
-; CHECK-NEXT:    [[TMP1:%.*]] = zext i3 [[SHIFT:%.*]] to i8
-; CHECK-NEXT:    [[CONV2:%.*]] = call i8 @llvm.fshr.i8(i8 [[V:%.*]], i8 [[V]], i8 [[TMP1]])
+; CHECK-NEXT:    [[SHAMT_TRUNC:%.*]] = zext i3 [[SHIFT:%.*]] to i8
+; CHECK-NEXT:    [[CONV2:%.*]] = call i8 @llvm.fshr.i8(i8 [[V:%.*]], i8 [[V]], i8 [[SHAMT_TRUNC]])
 ; CHECK-NEXT:    ret i8 [[CONV2]]
 ;
   %and = zext i3 %shift to i32
@@ -441,10 +441,10 @@ define i8 @rotate_right_8bit(i8 %v, i3 %shift) {
 define i8 @rotate_right_commute_8bit_unmasked_shl(i32 %v, i32 %shift) {
 ; CHECK-LABEL: @rotate_right_commute_8bit_unmasked_shl(
 ; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 [[SHIFT:%.*]] to i8
-; CHECK-NEXT:    [[TMP2:%.*]] = and i8 [[TMP1]], 3
-; CHECK-NEXT:    [[TMP3:%.*]] = trunc i32 [[V:%.*]] to i8
-; CHECK-NEXT:    [[TMP4:%.*]] = trunc i32 [[V]] to i8
-; CHECK-NEXT:    [[CONV2:%.*]] = call i8 @llvm.fshr.i8(i8 [[TMP3]], i8 [[TMP4]], i8 [[TMP2]])
+; CHECK-NEXT:    [[SHAMT_TRUNC:%.*]] = and i8 [[TMP1]], 3
+; CHECK-NEXT:    [[TMP2:%.*]] = trunc i32 [[V:%.*]] to i8
+; CHECK-NEXT:    [[TMP3:%.*]] = trunc i32 [[V]] to i8
+; CHECK-NEXT:    [[CONV2:%.*]] = call i8 @llvm.fshr.i8(i8 [[TMP2]], i8 [[TMP3]], i8 [[SHAMT_TRUNC]])
 ; CHECK-NEXT:    ret i8 [[CONV2]]
 ;
   %and = and i32 %shift, 3
@@ -462,10 +462,10 @@ define i8 @rotate_right_commute_8bit_unmasked_shl(i32 %v, i32 %shift) {
 define i8 @rotate_right_commute_8bit(i32 %v, i32 %shift) {
 ; CHECK-LABEL: @rotate_right_commute_8bit(
 ; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 [[SHIFT:%.*]] to i8
-; CHECK-NEXT:    [[TMP2:%.*]] = and i8 [[TMP1]], 3
-; CHECK-NEXT:    [[TMP3:%.*]] = trunc i32 [[V:%.*]] to i8
-; CHECK-NEXT:    [[TMP4:%.*]] = trunc i32 [[V]] to i8
-; CHECK-NEXT:    [[CONV2:%.*]] = call i8 @llvm.fshr.i8(i8 [[TMP3]], i8 [[TMP4]], i8 [[TMP2]])
+; CHECK-NEXT:    [[SHAMT_TRUNC:%.*]] = and i8 [[TMP1]], 3
+; CHECK-NEXT:    [[TMP2:%.*]] = trunc i32 [[V:%.*]] to i8
+; CHECK-NEXT:    [[TMP3:%.*]] = trunc i32 [[V]] to i8
+; CHECK-NEXT:    [[CONV2:%.*]] = call i8 @llvm.fshr.i8(i8 [[TMP2]], i8 [[TMP3]], i8 [[SHAMT_TRUNC]])
 ; CHECK-NEXT:    ret i8 [[CONV2]]
 ;
   %and = and i32 %shift, 3
@@ -483,8 +483,8 @@ define i8 @rotate_right_commute_8bit(i32 %v, i32 %shift) {
 
 define i8 @rotate8_not_safe(i8 %v, i32 %shamt) {
 ; CHECK-LABEL: @rotate8_not_safe(
-; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 [[SHAMT:%.*]] to i8
-; CHECK-NEXT:    [[RET:%.*]] = call i8 @llvm.fshl.i8(i8 [[V:%.*]], i8 [[V]], i8 [[TMP1]])
+; CHECK-NEXT:    [[SHAMT_TRUNC:%.*]] = trunc i32 [[SHAMT:%.*]] to i8
+; CHECK-NEXT:    [[RET:%.*]] = call i8 @llvm.fshl.i8(i8 [[V:%.*]], i8 [[V]], i8 [[SHAMT_TRUNC]])
 ; CHECK-NEXT:    ret i8 [[RET]]
 ;
   %conv = zext i8 %v to i32
@@ -597,8 +597,8 @@ define i8 @rotateright_8_neg_mask_commute(i8 %v, i8 %shamt) {
 
 define i16 @rotateright_16_neg_mask_wide_amount(i16 %v, i32 %shamt) {
 ; CHECK-LABEL: @rotateright_16_neg_mask_wide_amount(
-; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 [[SHAMT:%.*]] to i16
-; CHECK-NEXT:    [[RET:%.*]] = call i16 @llvm.fshr.i16(i16 [[V:%.*]], i16 [[V]], i16 [[TMP1]])
+; CHECK-NEXT:    [[SHAMT_TRUNC:%.*]] = trunc i32 [[SHAMT:%.*]] to i16
+; CHECK-NEXT:    [[RET:%.*]] = call i16 @llvm.fshr.i16(i16 [[V:%.*]], i16 [[V]], i16 [[SHAMT_TRUNC]])
 ; CHECK-NEXT:    ret i16 [[RET]]
 ;
   %neg = sub i32 0, %shamt
@@ -614,8 +614,8 @@ define i16 @rotateright_16_neg_mask_wide_amount(i16 %v, i32 %shamt) {
 
 define i16 @rotateright_16_neg_mask_wide_amount_commute(i16 %v, i32 %shamt) {
 ; CHECK-LABEL: @rotateright_16_neg_mask_wide_amount_commute(
-; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 [[SHAMT:%.*]] to i16
-; CHECK-NEXT:    [[RET:%.*]] = call i16 @llvm.fshr.i16(i16 [[V:%.*]], i16 [[V]], i16 [[TMP1]])
+; CHECK-NEXT:    [[SHAMT_TRUNC:%.*]] = trunc i32 [[SHAMT:%.*]] to i16
+; CHECK-NEXT:    [[RET:%.*]] = call i16 @llvm.fshr.i16(i16 [[V:%.*]], i16 [[V]], i16 [[SHAMT_TRUNC]])
 ; CHECK-NEXT:    ret i16 [[RET]]
 ;
   %neg = sub i32 0, %shamt
@@ -648,8 +648,8 @@ define i64 @rotateright_64_zext_neg_mask_amount(i64 %0, i32 %1) {
 
 define i8 @rotateleft_8_neg_mask_wide_amount(i8 %v, i32 %shamt) {
 ; CHECK-LABEL: @rotateleft_8_neg_mask_wide_amount(
-; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 [[SHAMT:%.*]] to i8
-; CHECK-NEXT:    [[RET:%.*]] = call i8 @llvm.fshl.i8(i8 [[V:%.*]], i8 [[V]], i8 [[TMP1]])
+; CHECK-NEXT:    [[SHAMT_TRUNC:%.*]] = trunc i32 [[SHAMT:%.*]] to i8
+; CHECK-NEXT:    [[RET:%.*]] = call i8 @llvm.fshl.i8(i8 [[V:%.*]], i8 [[V]], i8 [[SHAMT_TRUNC]])
 ; CHECK-NEXT:    ret i8 [[RET]]
 ;
   %neg = sub i32 0, %shamt
@@ -665,8 +665,8 @@ define i8 @rotateleft_8_neg_mask_wide_amount(i8 %v, i32 %shamt) {
 
 define i8 @rotateleft_8_neg_mask_wide_amount_commute(i8 %v, i32 %shamt) {
 ; CHECK-LABEL: @rotateleft_8_neg_mask_wide_amount_commute(
-; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 [[SHAMT:%.*]] to i8
-; CHECK-NEXT:    [[RET:%.*]] = call i8 @llvm.fshl.i8(i8 [[V:%.*]], i8 [[V]], i8 [[TMP1]])
+; CHECK-NEXT:    [[SHAMT_TRUNC:%.*]] = trunc i32 [[SHAMT:%.*]] to i8
+; CHECK-NEXT:    [[RET:%.*]] = call i8 @llvm.fshl.i8(i8 [[V:%.*]], i8 [[V]], i8 [[SHAMT_TRUNC]])
 ; CHECK-NEXT:    ret i8 [[RET]]
 ;
   %neg = sub i32 0, %shamt
@@ -957,3 +957,24 @@ define i8 @unmasked_shlop_unmasked_shift_amount(i32 %x, i32 %shamt) {
   %t8 = trunc i32 %t7 to i8
   ret i8 %t8
 }
+
+define i1 @check_rotate_masked_16bit(i8 %0, i32 %1) {
+; CHECK-LABEL: @check_rotate_masked_16bit(
+; CHECK-NEXT:    [[TMP3:%.*]] = and i32 [[TMP1:%.*]], 1
+; CHECK-NEXT:    [[TMP4:%.*]] = icmp eq i32 [[TMP3]], 0
+; CHECK-NEXT:    ret i1 [[TMP4]]
+;
+  %3 = and i32 %1, 1
+  %4 = and i8 %0, 15
+  %5 = zext i8 %4 to i32
+  %6 = lshr i32 %3, %5
+  %7 = sub i8 0, %0
+  %8 = and i8 %7, 15
+  %9 = zext i8 %8 to i32
+  %10 = shl nuw nsw i32 %3, %9
+  %11 = or i32 %6, %10
+  %12 = trunc i32 %11 to i16
+  %13 = sext i16 %12 to i64
+  %14 = icmp uge i64 0, %13
+  ret i1 %14
+}


        


More information about the llvm-commits mailing list