[llvm] Reduce shl64 to shl32 if shift range is [63-32] (PR #125574)

via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 3 12:34:27 PST 2025


https://github.com/LU-JOHN created https://github.com/llvm/llvm-project/pull/125574

Reduce:

   DST = shl i64 X, Y

where Y is known to be in the range [32, 63], to:

   DST = [shl i32 X, (Y - 32), 0]

i.e. the low 32 bits of DST are zero and the high 32 bits are the 32-bit
shift result X << (Y - 32).


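For a concrete picture, the new InstCombine test added by this patch
(llvm/test/Transforms/InstCombine/shl64-reduce.ll, target amdgcn-amd-amdhsa,
where a 32-bit shift is cheaper than a 64-bit one) exercises exactly this
rewrite. The before/after IR below is lifted from that test; the !range
metadata !0 = !{i64 32, i64 64} is what bounds the shift amount:

   ; before: 64-bit shift with amount known to be in [32, 63]
   %shift.amt = load i64, ptr %arg1.ptr, !range !0
   %shl = shl i64 %arg0, %shift.amt
   ret i64 %shl

   ; after: 32-bit shift placed in the high half of the result
   %shift.amt = load i64, ptr %arg1.ptr, align 8, !range !0
   %1 = trunc i64 %arg0 to i32
   %2 = trunc nuw nsw i64 %shift.amt to i32
   %3 = add nsw i32 %2, -32
   %4 = shl i32 %1, %3
   %5 = insertelement <2 x i32> <i32 0, i32 poison>, i32 %4, i64 1
   %shl = bitcast <2 x i32> %5 to i64
   ret i64 %shl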

From 8c63a11a568bd54b12e241abd214988dbd27c0da Mon Sep 17 00:00:00 2001
From: John Lu <John.Lu at amd.com>
Date: Mon, 3 Feb 2025 09:33:06 -0600
Subject: [PATCH 1/3] Reduce shl64 to shl32 if shift range is [63-32]

Signed-off-by: John Lu <John.Lu at amd.com>
---
 .../Transforms/InstCombine/InstCombiner.h     |  2 ++
 .../InstCombine/InstCombineShifts.cpp         | 29 +++++++++++++++++
 .../InstCombine/InstructionCombining.cpp      |  7 +++++
 .../amdgpu-simplify-libcall-pow-codegen.ll    |  4 +--
 .../AMDGPU/amdgpu-simplify-libcall-pown.ll    | 17 +++++-----
 .../Transforms/InstCombine/shl64-reduce.ll    | 31 +++++++++++++++++++
 6 files changed, 80 insertions(+), 10 deletions(-)
 create mode 100644 llvm/test/Transforms/InstCombine/shl64-reduce.ll

diff --git a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
index fa6b60cba15aaf..dfd275b020ed75 100644
--- a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
+++ b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
@@ -521,6 +521,8 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner {
                              bool AllowMultipleUsers = false) = 0;
 
   bool isValidAddrSpaceCast(unsigned FromAS, unsigned ToAS) const;
+
+  bool shouldReduceShl64ToShl32();
 };
 
 } // namespace llvm
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 7ef95800975dba..31f57b307c8fa4 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -1032,6 +1032,31 @@ static bool setShiftFlags(BinaryOperator &I, const SimplifyQuery &Q) {
   return Changed;
 }
 
+static Instruction *transformClampedShift64(BinaryOperator &I,
+					    const SimplifyQuery &Q,
+					    InstCombiner::BuilderTy &Builder) {
+  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+  Type *I32Type = Type::getInt32Ty(I.getContext());
+  Type *I64Type = Type::getInt64Ty(I.getContext());
+
+  if (I.getType() == I64Type) {
+    KnownBits KnownAmt = computeKnownBits(Op1, /* Depth */ 0, Q);
+    if (KnownAmt.getMinValue().uge(32)) {
+      Value *TruncVal         = Builder.CreateTrunc(Op0, I32Type);
+      Value *TruncShiftAmt    = Builder.CreateTrunc(Op1, I32Type);
+      Value *AdjustedShiftAmt = Builder.CreateSub  (TruncShiftAmt,
+                                                   ConstantInt::get(I32Type, 32));
+      Value *Shl32   = Builder.CreateShl(TruncVal, AdjustedShiftAmt);
+      Value *VResult = Builder.CreateVectorSplat(2, ConstantInt::get(I32Type, 0));
+
+      VResult = Builder.CreateInsertElement(VResult, Shl32,
+                                           ConstantInt::get(I32Type, 1));
+      return CastInst::Create(Instruction::BitCast, VResult, I64Type);
+    }
+  }
+  return nullptr;
+}
+
 Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
   const SimplifyQuery Q = SQ.getWithInstruction(&I);
 
@@ -1266,6 +1291,10 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
     }
   }
 
+  if (this->shouldReduceShl64ToShl32())
+    if (Instruction *V = transformClampedShift64(I, Q, Builder))
+      return V;
+
   return nullptr;
 }
 
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 5621511570b581..2c19261152ae04 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -194,6 +194,13 @@ bool InstCombiner::isValidAddrSpaceCast(unsigned FromAS, unsigned ToAS) const {
   return TTIForTargetIntrinsicsOnly.isValidAddrSpaceCast(FromAS, ToAS);
 }
 
+bool InstCombiner::shouldReduceShl64ToShl32() {
+  InstructionCost costShl32 = TTIForTargetIntrinsicsOnly.getArithmeticInstrCost(Instruction::Shl, Builder.getInt32Ty(), TTI::TCK_Latency);
+  InstructionCost costShl64 = TTIForTargetIntrinsicsOnly.getArithmeticInstrCost(Instruction::Shl, Builder.getInt64Ty(), TTI::TCK_Latency);
+
+  return costShl32<costShl64;
+}
+
 Value *InstCombinerImpl::EmitGEPOffset(GEPOperator *GEP, bool RewriteGEP) {
   if (!RewriteGEP)
     return llvm::emitGEPOffset(&Builder, DL, GEP);
diff --git a/llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-pow-codegen.ll b/llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-pow-codegen.ll
index ab2363860af9de..84ac4af6584677 100644
--- a/llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-pow-codegen.ll
+++ b/llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-pow-codegen.ll
@@ -174,7 +174,7 @@ define double @test_pow_fast_f64__integral_y(double %x, i32 %y.i) {
 ; CHECK-NEXT:    s_waitcnt lgkmcnt(0)
 ; CHECK-NEXT:    s_swappc_b64 s[30:31], s[16:17]
 ; CHECK-NEXT:    v_lshlrev_b32_e32 v2, 31, v41
-; CHECK-NEXT:    v_and_b32_e32 v2, v2, v42
+; CHECK-NEXT:    v_and_b32_e32 v2, v42, v2
 ; CHECK-NEXT:    buffer_load_dword v42, off, s[0:3], s33 ; 4-byte Folded Reload
 ; CHECK-NEXT:    buffer_load_dword v41, off, s[0:3], s33 offset:4 ; 4-byte Folded Reload
 ; CHECK-NEXT:    buffer_load_dword v40, off, s[0:3], s33 offset:8 ; 4-byte Folded Reload
@@ -458,7 +458,7 @@ define double @test_pown_fast_f64(double %x, i32 %y) {
 ; CHECK-NEXT:    s_waitcnt lgkmcnt(0)
 ; CHECK-NEXT:    s_swappc_b64 s[30:31], s[16:17]
 ; CHECK-NEXT:    v_lshlrev_b32_e32 v2, 31, v41
-; CHECK-NEXT:    v_and_b32_e32 v2, v2, v42
+; CHECK-NEXT:    v_and_b32_e32 v2, v42, v2
 ; CHECK-NEXT:    buffer_load_dword v42, off, s[0:3], s33 ; 4-byte Folded Reload
 ; CHECK-NEXT:    buffer_load_dword v41, off, s[0:3], s33 offset:4 ; 4-byte Folded Reload
 ; CHECK-NEXT:    buffer_load_dword v40, off, s[0:3], s33 offset:8 ; 4-byte Folded Reload
diff --git a/llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-pown.ll b/llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-pown.ll
index f9c359bc114ed3..5155e42fef3cbf 100644
--- a/llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-pown.ll
+++ b/llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-pown.ll
@@ -720,14 +720,15 @@ define double @test_pown_afn_nnan_ninf_f64(double %x, i32 %y) {
 ; CHECK-NEXT:    [[POWNI2F:%.*]] = sitofp i32 [[Y]] to double
 ; CHECK-NEXT:    [[__YLOGX:%.*]] = fmul nnan ninf afn double [[__LOG2]], [[POWNI2F]]
 ; CHECK-NEXT:    [[__EXP2:%.*]] = call nnan ninf afn double @_Z4exp2d(double [[__YLOGX]])
-; CHECK-NEXT:    [[__YTOU:%.*]] = zext i32 [[Y]] to i64
-; CHECK-NEXT:    [[__YEVEN:%.*]] = shl i64 [[__YTOU]], 63
-; CHECK-NEXT:    [[TMP0:%.*]] = bitcast double [[X]] to i64
-; CHECK-NEXT:    [[__POW_SIGN:%.*]] = and i64 [[__YEVEN]], [[TMP0]]
-; CHECK-NEXT:    [[TMP1:%.*]] = bitcast double [[__EXP2]] to i64
-; CHECK-NEXT:    [[TMP2:%.*]] = or i64 [[__POW_SIGN]], [[TMP1]]
-; CHECK-NEXT:    [[TMP3:%.*]] = bitcast i64 [[TMP2]] to double
-; CHECK-NEXT:    ret double [[TMP3]]
+; CHECK-NEXT:    [[TMP0:%.*]] = shl i32 [[Y]], 31
+; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <2 x i32> <i32 0, i32 poison>, i32 [[TMP0]], i64 1
+; CHECK-NEXT:    [[__YEVEN:%.*]] = bitcast <2 x i32> [[TMP1]] to i64
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast double [[X]] to i64
+; CHECK-NEXT:    [[__POW_SIGN:%.*]] = and i64 [[TMP2]], [[__YEVEN]]
+; CHECK-NEXT:    [[TMP3:%.*]] = bitcast double [[__EXP2]] to i64
+; CHECK-NEXT:    [[TMP4:%.*]] = or i64 [[__POW_SIGN]], [[TMP3]]
+; CHECK-NEXT:    [[TMP5:%.*]] = bitcast i64 [[TMP4]] to double
+; CHECK-NEXT:    ret double [[TMP5]]
 ;
 entry:
   %call = tail call nnan ninf afn double @_Z4powndi(double %x, i32 %y)
diff --git a/llvm/test/Transforms/InstCombine/shl64-reduce.ll b/llvm/test/Transforms/InstCombine/shl64-reduce.ll
new file mode 100644
index 00000000000000..4ab827757242a9
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/shl64-reduce.ll
@@ -0,0 +1,31 @@
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+
+target triple = "amdgcn-amd-amdhsa"
+
+define i64 @func_range(i64 noundef %arg0, ptr %arg1.ptr) {
+  %shift.amt = load i64, ptr %arg1.ptr, !range !0
+  %shl = shl i64 %arg0, %shift.amt
+  ret i64 %shl
+
+; CHECK:  define i64 @func_range(i64 noundef %arg0, ptr %arg1.ptr) {
+; CHECK:  %shift.amt = load i64, ptr %arg1.ptr, align 8, !range !0
+; CHECK:  %1 = trunc i64 %arg0 to i32
+; CHECK:  %2 = trunc nuw nsw i64 %shift.amt to i32
+; CHECK:  %3 = add nsw i32 %2, -32
+; CHECK:  %4 = shl i32 %1, %3
+; CHECK:  %5 = insertelement <2 x i32> <i32 0, i32 poison>, i32 %4, i64 1
+; CHECK:  %shl = bitcast <2 x i32> %5 to i64
+; CHECK:  ret i64 %shl
+
+}
+!0 = !{i64 32, i64 64}
+
+define i64 @func_max(i64 noundef %arg0, i64 noundef %arg1) {
+  %max = call i64 @llvm.umax.i64(i64 %arg1, i64 32)
+  %min = call i64 @llvm.umin.i64(i64 %max,  i64 63)  
+  %shl = shl i64 %arg0, %min
+  ret i64 %shl
+}
+  
+

From 736c1e16c348cd17be1fa9e112662d7170ddb4a3 Mon Sep 17 00:00:00 2001
From: John Lu <John.Lu at amd.com>
Date: Mon, 3 Feb 2025 10:02:28 -0600
Subject: [PATCH 2/3] Add checks for case where range comes from min/max calls

Signed-off-by: John Lu <John.Lu at amd.com>
---
 .../test/Transforms/InstCombine/shl64-reduce.ll | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/llvm/test/Transforms/InstCombine/shl64-reduce.ll b/llvm/test/Transforms/InstCombine/shl64-reduce.ll
index 4ab827757242a9..00f6c82fae9ad0 100644
--- a/llvm/test/Transforms/InstCombine/shl64-reduce.ll
+++ b/llvm/test/Transforms/InstCombine/shl64-reduce.ll
@@ -1,8 +1,17 @@
+;; Test reduction of:
+;;
+;;   DST = shl i64 X, Y
+;;
+;; where Y is in the range [63-32] to:
+;;
+;;   DST = [shl i32 X, (Y - 32), 0]
+
 ; RUN: opt < %s -passes=instcombine -S | FileCheck %s
 
 
 target triple = "amdgcn-amd-amdhsa"
 
+; Test reduction where range information comes from meta-data
 define i64 @func_range(i64 noundef %arg0, ptr %arg1.ptr) {
   %shift.amt = load i64, ptr %arg1.ptr, !range !0
   %shl = shl i64 %arg0, %shift.amt
@@ -21,11 +30,19 @@ define i64 @func_range(i64 noundef %arg0, ptr %arg1.ptr) {
 }
 !0 = !{i64 32, i64 64}
 
+; FIXME: This case should be reduced too, but computeKnownBits() cannot
+;        determine the range.  Match current results for now.
 define i64 @func_max(i64 noundef %arg0, i64 noundef %arg1) {
   %max = call i64 @llvm.umax.i64(i64 %arg1, i64 32)
   %min = call i64 @llvm.umin.i64(i64 %max,  i64 63)  
   %shl = shl i64 %arg0, %min
   ret i64 %shl
+
+; CHECK:  define i64 @func_max(i64 noundef %arg0, i64 noundef %arg1) {
+; CHECK:    %max = call i64 @llvm.umax.i64(i64 %arg1, i64 32)
+; CHECK:    %min = call i64 @llvm.umin.i64(i64 %max,  i64 63)
+; CHECK:    %shl = shl i64 %arg0, %min
+; CHECK:    ret i64 %shl
 }
   
 

From 9ad122ffdc5c41654a12cc72e160cb237d066505 Mon Sep 17 00:00:00 2001
From: John Lu <John.Lu at amd.com>
Date: Mon, 3 Feb 2025 10:03:27 -0600
Subject: [PATCH 3/3] Apply clang-format

Signed-off-by: John Lu <John.Lu at amd.com>
---
 .../InstCombine/InstCombineShifts.cpp         | 19 ++++++++++---------
 .../InstCombine/InstructionCombining.cpp      |  8 +++++---
 2 files changed, 15 insertions(+), 12 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 31f57b307c8fa4..3ced23671f11a8 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -1033,8 +1033,8 @@ static bool setShiftFlags(BinaryOperator &I, const SimplifyQuery &Q) {
 }
 
 static Instruction *transformClampedShift64(BinaryOperator &I,
-					    const SimplifyQuery &Q,
-					    InstCombiner::BuilderTy &Builder) {
+                                            const SimplifyQuery &Q,
+                                            InstCombiner::BuilderTy &Builder) {
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
   Type *I32Type = Type::getInt32Ty(I.getContext());
   Type *I64Type = Type::getInt64Ty(I.getContext());
@@ -1042,15 +1042,16 @@ static Instruction *transformClampedShift64(BinaryOperator &I,
   if (I.getType() == I64Type) {
     KnownBits KnownAmt = computeKnownBits(Op1, /* Depth */ 0, Q);
     if (KnownAmt.getMinValue().uge(32)) {
-      Value *TruncVal         = Builder.CreateTrunc(Op0, I32Type);
-      Value *TruncShiftAmt    = Builder.CreateTrunc(Op1, I32Type);
-      Value *AdjustedShiftAmt = Builder.CreateSub  (TruncShiftAmt,
-                                                   ConstantInt::get(I32Type, 32));
-      Value *Shl32   = Builder.CreateShl(TruncVal, AdjustedShiftAmt);
-      Value *VResult = Builder.CreateVectorSplat(2, ConstantInt::get(I32Type, 0));
+      Value *TruncVal = Builder.CreateTrunc(Op0, I32Type);
+      Value *TruncShiftAmt = Builder.CreateTrunc(Op1, I32Type);
+      Value *AdjustedShiftAmt =
+          Builder.CreateSub(TruncShiftAmt, ConstantInt::get(I32Type, 32));
+      Value *Shl32 = Builder.CreateShl(TruncVal, AdjustedShiftAmt);
+      Value *VResult =
+          Builder.CreateVectorSplat(2, ConstantInt::get(I32Type, 0));
 
       VResult = Builder.CreateInsertElement(VResult, Shl32,
-                                           ConstantInt::get(I32Type, 1));
+                                            ConstantInt::get(I32Type, 1));
       return CastInst::Create(Instruction::BitCast, VResult, I64Type);
     }
   }
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 2c19261152ae04..d356741fcdf21e 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -195,10 +195,12 @@ bool InstCombiner::isValidAddrSpaceCast(unsigned FromAS, unsigned ToAS) const {
 }
 
 bool InstCombiner::shouldReduceShl64ToShl32() {
-  InstructionCost costShl32 = TTIForTargetIntrinsicsOnly.getArithmeticInstrCost(Instruction::Shl, Builder.getInt32Ty(), TTI::TCK_Latency);
-  InstructionCost costShl64 = TTIForTargetIntrinsicsOnly.getArithmeticInstrCost(Instruction::Shl, Builder.getInt64Ty(), TTI::TCK_Latency);
+  InstructionCost costShl32 = TTIForTargetIntrinsicsOnly.getArithmeticInstrCost(
+      Instruction::Shl, Builder.getInt32Ty(), TTI::TCK_Latency);
+  InstructionCost costShl64 = TTIForTargetIntrinsicsOnly.getArithmeticInstrCost(
+      Instruction::Shl, Builder.getInt64Ty(), TTI::TCK_Latency);
 
-  return costShl32<costShl64;
+  return costShl32 < costShl64;
 }
 
 Value *InstCombinerImpl::EmitGEPOffset(GEPOperator *GEP, bool RewriteGEP) {

