[llvm] [InstCombine] Fold max(max(x, c1) << c2, c3) —> max(x << c2, c3) when c3 >= c1 * 2 ^ c2 (PR #140526)

via llvm-commits llvm-commits at lists.llvm.org
Tue May 27 09:49:50 PDT 2025


https://github.com/Charukesh827 updated https://github.com/llvm/llvm-project/pull/140526

>From 7e4e9d20707e54f756bbaa888ab7652cb66ee071 Mon Sep 17 00:00:00 2001
From: Charukesh827 <charue222 at gmail.com>
Date: Mon, 19 May 2025 15:17:33 +0530
Subject: [PATCH 1/8] Add test for shifting binop in InstCombine for issue
 #139786 (missed optimization: max(max(x, c1) << c2, c3) -> max(x << c2, c3)
 when c3 >= c1 * 2^c2)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../Transforms/InstCombine/shift-binop.ll     | 27 +++++++++++++++++++
 1 file changed, 27 insertions(+)
 create mode 100644 llvm/test/Transforms/InstCombine/shift-binop.ll

diff --git a/llvm/test/Transforms/InstCombine/shift-binop.ll b/llvm/test/Transforms/InstCombine/shift-binop.ll
new file mode 100644
index 0000000000000..78e9c5ea21181
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/shift-binop.ll
@@ -0,0 +1,27 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+define i8 @src(i8 %arg0) {
+; CHECK-LABEL: @src(
+; CHECK-NEXT:    [[TMP1:%.*]] = shl nuw i8 [[ARG0:%.*]], 1
+; CHECK-NEXT:    [[TMP2:%.*]] = call i8 @llvm.umax.i8(i8 [[TMP1]], i8 16)
+; CHECK-NEXT:    ret i8 [[TMP2]]
+;
+  %1 = call i8 @llvm.umax.i8(i8 %arg0, i8 1)
+  %2 = shl nuw i8 %1, 1
+  %3 = call i8 @llvm.umax.i8(i8 %2, i8 16)
+  ret i8 %3
+}
+
+define i8 @tgt(i8 %arg0) {
+; CHECK-LABEL: @tgt(
+; CHECK-NEXT:    [[TMP1:%.*]] = shl nuw i8 [[ARG0:%.*]], 1
+; CHECK-NEXT:    [[TMP2:%.*]] = call i8 @llvm.umax.i8(i8 [[TMP1]], i8 16)
+; CHECK-NEXT:    ret i8 [[TMP2]]
+;
+  %1 = shl nuw i8 %arg0, 1
+  %2 = call i8 @llvm.umax.i8(i8 %1, i8 16)
+  ret i8 %2
+}
+
+declare i8 @llvm.umax.i8(i8, i8)

>From 4e2ddfca186233798b88b7845a5f3400c0400349 Mon Sep 17 00:00:00 2001
From: Charukesh827 <charue222 at gmail.com>
Date: Mon, 19 May 2025 15:55:15 +0530
Subject: [PATCH 2/8] [InstCombine] Fix missed optimization:
 max(max(x, c1) << c2, c3) -> max(x << c2, c3) when c3 >= c1 * 2^c2
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This patch fixes issue #139786, where InstCombine missed the optimization
max(max(x, c1) << c2, c3) -> max(x << c2, c3) when c3 >= c1 * 2^c2
(i.e. when c3 >= c1 << c2). The test is pre-committed in the preceding
commit of this series (PATCH 1/8).

Alive2: https://alive2.llvm.org/ce/z/on2tJE
---
 .../InstCombine/InstCombineCalls.cpp          | 85 +++++++++++++++++++
 1 file changed, 85 insertions(+)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 3d35bf753c40e..ab05faa5b8f7a 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -76,6 +76,8 @@
 #include <utility>
 #include <vector>
 
+#include<iostream>
+
 #define DEBUG_TYPE "instcombine"
 #include "llvm/Transforms/Utils/InstructionWorklist.h"
 
@@ -1171,6 +1173,84 @@ static Instruction *moveAddAfterMinMax(IntrinsicInst *II,
   return IsSigned ? BinaryOperator::CreateNSWAdd(NewMinMax, Add->getOperand(1))
                   : BinaryOperator::CreateNUWAdd(NewMinMax, Add->getOperand(1));
 }
+
+//Try canonicalize min/max(x << shamt, c<<shamt) into max(x, c) << shamt
+static Instruction *moveShiftAfterMinMax(IntrinsicInst *II, InstCombiner::BuilderTy &Builder) {
+  Intrinsic::ID MinMaxID = II->getIntrinsicID();
+  assert((MinMaxID == Intrinsic::smax || MinMaxID == Intrinsic::smin ||
+    MinMaxID == Intrinsic::umax || MinMaxID == Intrinsic::umin) &&
+   "Expected a min or max intrinsic");
+  
+  Value *Op0 = II->getArgOperand(0), *Op1 = II->getArgOperand(1);
+  Value *InnerMax;
+  const APInt *C;
+  if (!match(Op0, m_OneUse(m_BinOp(m_Value(InnerMax), m_APInt(C)))) || 
+      !match(Op1, m_APInt(C)))
+      return nullptr;
+  
+  auto* BinOpInst = cast<BinaryOperator>(Op0);
+  Instruction::BinaryOps BinOp = BinOpInst->getOpcode();
+  Value *X;
+  InnerMax = BinOpInst->getOperand(0);
+  // std::cout<< InnerMax->dump() <<std::endl;
+  if(!match(InnerMax,m_OneUse(m_Intrinsic<Intrinsic::umax>(m_Value(X),m_APInt(C))))){
+  if(!match(InnerMax,m_OneUse(m_Intrinsic<Intrinsic::smax>(m_Value(X),m_APInt(C))))){
+  if(!match(InnerMax,m_OneUse(m_Intrinsic<Intrinsic::umin>(m_Value(X),m_APInt(C))))){
+  if(!match(InnerMax,m_OneUse(m_Intrinsic<Intrinsic::smin>(m_Value(X),m_APInt(C))))){
+     return nullptr;
+  }}}}
+  
+  auto *InnerMaxInst = cast<IntrinsicInst>(InnerMax);
+
+  bool IsSigned = MinMaxID == Intrinsic::smax || MinMaxID == Intrinsic::smin;
+  if((IsSigned && !BinOpInst->hasNoSignedWrap()) ||
+     (!IsSigned && !BinOpInst->hasNoUnsignedWrap())) 
+     return nullptr;
+
+  // Check if BinOp is a left shift
+  if (BinOp != Instruction::Shl) {
+    return nullptr;
+  }
+
+  APInt C2=llvm::dyn_cast<llvm::ConstantInt>(BinOpInst->getOperand(1))->getValue() ;
+  APInt C3=llvm::dyn_cast<llvm::ConstantInt>(II->getArgOperand(1))->getValue();
+  APInt C1=llvm::dyn_cast<llvm::ConstantInt>(InnerMaxInst->getOperand(1))->getValue();
+
+  // Compute C1 * 2^C2
+  APInt Two = APInt(C2.getBitWidth(), 2);
+  APInt Pow2C2 = Two.shl(C2); // 2^C2
+  APInt C1TimesPow2C2 = C1 * Pow2C2; // C1 * 2^C2
+
+  // Check C3 >= C1 * 2^C2
+  if (C3.ult(C1TimesPow2C2)) {
+    return nullptr;
+  }
+
+  //Create new x binop c2
+  Value *NewBinOp = Builder.CreateBinOp(BinOp, InnerMaxInst->getOperand(0), BinOpInst->getOperand(1) );
+  
+  //Create new min/max intrinsic with new binop and c3
+  
+    if(IsSigned){
+      cast<Instruction>(NewBinOp) -> setHasNoSignedWrap(true);
+      cast<Instruction>(NewBinOp) -> setHasNoUnsignedWrap(false);
+    }else{
+      cast<Instruction>(NewBinOp) -> setHasNoUnsignedWrap(true);
+      cast<Instruction>(NewBinOp) -> setHasNoSignedWrap(false);
+    }
+  
+
+  // Get the intrinsic function for MinMaxID
+  Type *Ty = II->getType();
+  Function *MinMaxFn = Intrinsic::getDeclaration(II->getModule(), MinMaxID, {Ty});
+
+  // Create new min/max intrinsic: MinMaxID(NewBinOp, C3) (not inserted)
+  Value *Args[] = {NewBinOp, Op1};
+  Instruction *NewMax = CallInst::Create(MinMaxFn, Args, "", nullptr);
+
+  return NewMax;
+}
+
 /// Match a sadd_sat or ssub_sat which is using min/max to clamp the value.
 Instruction *InstCombinerImpl::matchSAddSubSat(IntrinsicInst &MinMax1) {
   Type *Ty = MinMax1.getType();
@@ -2035,6 +2115,11 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
     if (Instruction *I = moveAddAfterMinMax(II, Builder))
       return I;
 
+    // minmax(x << shamt , c << shamt) -> minmax(x, c) << shamt
+    if (Instruction *I = moveShiftAfterMinMax(II, Builder))  
+      return I;
+
+
     // minmax (X & NegPow2C, Y & NegPow2C) --> minmax(X, Y) & NegPow2C
     const APInt *RHSC;
     if (match(I0, m_OneUse(m_And(m_Value(X), m_NegatedPower2(RHSC)))) &&

>From b035dcb7232cfcad9f34c1dac1e907c79669c8c0 Mon Sep 17 00:00:00 2001
From: Charukesh827 <charue222 at gmail.com>
Date: Mon, 19 May 2025 16:00:01 +0530
Subject: [PATCH 3/8] removed unwanted header (iostream)

---
 llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp | 2 --
 1 file changed, 2 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index ab05faa5b8f7a..53dd5f803f97b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -76,8 +76,6 @@
 #include <utility>
 #include <vector>
 
-#include<iostream>
-
 #define DEBUG_TYPE "instcombine"
 #include "llvm/Transforms/Utils/InstructionWorklist.h"
 

>From 19f60a92a0b03b6030b8d0bff8e85572b0013c5b Mon Sep 17 00:00:00 2001
From: Charukesh827 <charue222 at gmail.com>
Date: Tue, 20 May 2025 13:16:31 +0530
Subject: [PATCH 4/8] added negative test

---
 .../Transforms/InstCombine/shift-binop.ll     | 28 ++++++++++---------
 1 file changed, 15 insertions(+), 13 deletions(-)

diff --git a/llvm/test/Transforms/InstCombine/shift-binop.ll b/llvm/test/Transforms/InstCombine/shift-binop.ll
index 78e9c5ea21181..a84a30002d917 100644
--- a/llvm/test/Transforms/InstCombine/shift-binop.ll
+++ b/llvm/test/Transforms/InstCombine/shift-binop.ll
@@ -1,11 +1,11 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
 ; RUN: opt < %s -passes=instcombine -S | FileCheck %s
 
-define i8 @src(i8 %arg0) {
-; CHECK-LABEL: @src(
-; CHECK-NEXT:    [[TMP1:%.*]] = shl nuw i8 [[ARG0:%.*]], 1
-; CHECK-NEXT:    [[TMP2:%.*]] = call i8 @llvm.umax.i8(i8 [[TMP1]], i8 16)
-; CHECK-NEXT:    ret i8 [[TMP2]]
+define i8 @src1(i8 %arg0) {
+; CHECK-LABEL: @src1(
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i8 [[ARG0:%.*]], 1
+; CHECK-NEXT:    [[OUTMAX:%.*]] = call i8 @llvm.umax.i8(i8 [[SHL]], i8 16)
+; CHECK-NEXT:    ret i8 [[OUTMAX]]
 ;
   %1 = call i8 @llvm.umax.i8(i8 %arg0, i8 1)
   %2 = shl nuw i8 %1, 1
@@ -13,15 +13,17 @@ define i8 @src(i8 %arg0) {
   ret i8 %3
 }
 
-define i8 @tgt(i8 %arg0) {
-; CHECK-LABEL: @tgt(
-; CHECK-NEXT:    [[TMP1:%.*]] = shl nuw i8 [[ARG0:%.*]], 1
-; CHECK-NEXT:    [[TMP2:%.*]] = call i8 @llvm.umax.i8(i8 [[TMP1]], i8 16)
-; CHECK-NEXT:    ret i8 [[TMP2]]
+define i8 @src2(i8 %arg0) {
+; CHECK-LABEL: @src2(
+; CHECK-NEXT:    [[INMAX:%.*]] = call i8 @llvm.umax.i8(i8 [[ARG0:%.*]], i8 4)
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i8 [[INMAX:%.*]], 1
+; CHECK-NEXT:    [[OUTMAX:%.*]] = call i8 @llvm.umax.i8(i8 [[SHL]], i8 10)
+; CHECK-NEXT:    ret i8 [[OUTMAX]]
 ;
-  %1 = shl nuw i8 %arg0, 1
-  %2 = call i8 @llvm.umax.i8(i8 %1, i8 16)
-  ret i8 %2
+  %1 = call i8 @llvm.umax.i8(i8 %arg0, i8 4)
+  %2 = shl nuw i8 %1, 1
+  %3 = call i8 @llvm.umax.i8(i8 %2, i8 10)
+  ret i8 %3
 }
 
 declare i8 @llvm.umax.i8(i8, i8)

>From e317ee2caede9c33a4962bb970c3d5358a74a2f9 Mon Sep 17 00:00:00 2001
From: Charukesh827 <charue222 at gmail.com>
Date: Tue, 20 May 2025 13:20:36 +0530
Subject: [PATCH 5/8] Made suggested changes, but couldn't yet generalize to
 the other binops (excluding div/rem) as suggested

"If it is the case, you should generalize it to handle most of binops (excluding div/rem), then use simplifyBinOp and simplifyBinaryIntrinsic to check if min/max(c1 binop c2, c3) folds to c3."
---
 .../InstCombine/InstCombineCalls.cpp          | 84 ++++++++++---------
 1 file changed, 43 insertions(+), 41 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 53dd5f803f97b..3d1c8a32dc3af 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1172,51 +1172,50 @@ static Instruction *moveAddAfterMinMax(IntrinsicInst *II,
                   : BinaryOperator::CreateNUWAdd(NewMinMax, Add->getOperand(1));
 }
 
-//Try canonicalize min/max(x << shamt, c<<shamt) into max(x, c) << shamt
-static Instruction *moveShiftAfterMinMax(IntrinsicInst *II, InstCombiner::BuilderTy &Builder) {
+// Try canonicalize max(max(X,C1) binop C2, C3) -> max(X binop C2, C3)
+static Instruction *moveShiftAfterMinMax(IntrinsicInst *II,
+                                         InstCombiner::BuilderTy &Builder) {
   Intrinsic::ID MinMaxID = II->getIntrinsicID();
-  assert((MinMaxID == Intrinsic::smax || MinMaxID == Intrinsic::smin ||
-    MinMaxID == Intrinsic::umax || MinMaxID == Intrinsic::umin) &&
-   "Expected a min or max intrinsic");
-  
+  assert(isa<MinMaxIntrinsic>(II) &&
+         "Expected a min or max intrinsic");
+
   Value *Op0 = II->getArgOperand(0), *Op1 = II->getArgOperand(1);
   Value *InnerMax;
   const APInt *C;
-  if (!match(Op0, m_OneUse(m_BinOp(m_Value(InnerMax), m_APInt(C)))) || 
+  if (!match(Op0, m_OneUse(m_BinOp(m_Value(InnerMax), m_APInt(C)))) ||
       !match(Op1, m_APInt(C)))
-      return nullptr;
-  
-  auto* BinOpInst = cast<BinaryOperator>(Op0);
+    return nullptr;
+    
+  auto *BinOpInst = cast<BinaryOperator>(Op0);
   Instruction::BinaryOps BinOp = BinOpInst->getOpcode();
-  Value *X;
+
   InnerMax = BinOpInst->getOperand(0);
-  // std::cout<< InnerMax->dump() <<std::endl;
-  if(!match(InnerMax,m_OneUse(m_Intrinsic<Intrinsic::umax>(m_Value(X),m_APInt(C))))){
-  if(!match(InnerMax,m_OneUse(m_Intrinsic<Intrinsic::smax>(m_Value(X),m_APInt(C))))){
-  if(!match(InnerMax,m_OneUse(m_Intrinsic<Intrinsic::umin>(m_Value(X),m_APInt(C))))){
-  if(!match(InnerMax,m_OneUse(m_Intrinsic<Intrinsic::smin>(m_Value(X),m_APInt(C))))){
-     return nullptr;
-  }}}}
-  
-  auto *InnerMaxInst = cast<IntrinsicInst>(InnerMax);
 
-  bool IsSigned = MinMaxID == Intrinsic::smax || MinMaxID == Intrinsic::smin;
-  if((IsSigned && !BinOpInst->hasNoSignedWrap()) ||
-     (!IsSigned && !BinOpInst->hasNoUnsignedWrap())) 
-     return nullptr;
+  auto *InnerMinMaxInst = dyn_cast<MinMaxIntrinsic>(BinOpInst->getOperand(0));
+
+  if (!InnerMinMaxInst || !InnerMinMaxInst->hasOneUse()) 
+    return nullptr;
+
+  bool IsSigned = InnerMinMaxInst->isSigned();
+  if (InnerMinMaxInst->getIntrinsicID() != MinMaxID)
+    return nullptr;
+
+  if ((IsSigned && !BinOpInst->hasNoSignedWrap()) ||
+      (!IsSigned && !BinOpInst->hasNoUnsignedWrap()))
+    return nullptr;
 
   // Check if BinOp is a left shift
   if (BinOp != Instruction::Shl) {
     return nullptr;
   }
 
-  APInt C2=llvm::dyn_cast<llvm::ConstantInt>(BinOpInst->getOperand(1))->getValue() ;
-  APInt C3=llvm::dyn_cast<llvm::ConstantInt>(II->getArgOperand(1))->getValue();
-  APInt C1=llvm::dyn_cast<llvm::ConstantInt>(InnerMaxInst->getOperand(1))->getValue();
+  APInt C2 =llvm::dyn_cast<llvm::ConstantInt>(BinOpInst->getOperand(1))->getValue();
+  APInt C3 =llvm::dyn_cast<llvm::ConstantInt>(II->getArgOperand(1))->getValue();
+  APInt C1 = llvm::dyn_cast<llvm::ConstantInt>(InnerMinMaxInst->getOperand(1))->getValue();
 
   // Compute C1 * 2^C2
   APInt Two = APInt(C2.getBitWidth(), 2);
-  APInt Pow2C2 = Two.shl(C2); // 2^C2
+  APInt Pow2C2 = Two.shl(C2);        // 2^C2
   APInt C1TimesPow2C2 = C1 * Pow2C2; // C1 * 2^C2
 
   // Check C3 >= C1 * 2^C2
@@ -1224,23 +1223,24 @@ static Instruction *moveShiftAfterMinMax(IntrinsicInst *II, InstCombiner::Builde
     return nullptr;
   }
 
-  //Create new x binop c2
-  Value *NewBinOp = Builder.CreateBinOp(BinOp, InnerMaxInst->getOperand(0), BinOpInst->getOperand(1) );
-  
-  //Create new min/max intrinsic with new binop and c3
-  
-    if(IsSigned){
-      cast<Instruction>(NewBinOp) -> setHasNoSignedWrap(true);
-      cast<Instruction>(NewBinOp) -> setHasNoUnsignedWrap(false);
-    }else{
-      cast<Instruction>(NewBinOp) -> setHasNoUnsignedWrap(true);
-      cast<Instruction>(NewBinOp) -> setHasNoSignedWrap(false);
+  // Create new x binop c2
+  Value *NewBinOp = Builder.CreateBinOp(BinOp, InnerMinMaxInst->getOperand(0),
+                                        BinOpInst->getOperand(1));
+
+  // Create new min/max intrinsic with new binop and c3
+  if (auto *NewBinInst = dyn_cast<Instruction>(NewBinOp))
+    if (IsSigned) {
+      cast<Instruction>(NewBinOp)->setHasNoSignedWrap(true);
+      cast<Instruction>(NewBinOp)->setHasNoUnsignedWrap(false);
+    } else {
+      cast<Instruction>(NewBinOp)->setHasNoUnsignedWrap(true);
+      cast<Instruction>(NewBinOp)->setHasNoSignedWrap(false);
     }
-  
 
   // Get the intrinsic function for MinMaxID
   Type *Ty = II->getType();
-  Function *MinMaxFn = Intrinsic::getDeclaration(II->getModule(), MinMaxID, {Ty});
+  Function *MinMaxFn =
+      Intrinsic::getDeclaration(II->getModule(), MinMaxID, {Ty});
 
   // Create new min/max intrinsic: MinMaxID(NewBinOp, C3) (not inserted)
   Value *Args[] = {NewBinOp, Op1};
@@ -1249,6 +1249,8 @@ static Instruction *moveShiftAfterMinMax(IntrinsicInst *II, InstCombiner::Builde
   return NewMax;
 }
 
+
+
 /// Match a sadd_sat or ssub_sat which is using min/max to clamp the value.
 Instruction *InstCombinerImpl::matchSAddSubSat(IntrinsicInst &MinMax1) {
   Type *Ty = MinMax1.getType();

>From 6f1d83f15bf6e3057f374824e96d370addedcdb9 Mon Sep 17 00:00:00 2001
From: Charukesh827 <charue222 at gmail.com>
Date: Tue, 20 May 2025 13:30:04 +0530
Subject: [PATCH 6/8] added motivation to shift-binop.ll

---
 llvm/test/Transforms/InstCombine/shift-binop.ll | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/llvm/test/Transforms/InstCombine/shift-binop.ll b/llvm/test/Transforms/InstCombine/shift-binop.ll
index a84a30002d917..38b6dfa43e7d6 100644
--- a/llvm/test/Transforms/InstCombine/shift-binop.ll
+++ b/llvm/test/Transforms/InstCombine/shift-binop.ll
@@ -1,6 +1,19 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
 ; RUN: opt < %s -passes=instcombine -S | FileCheck %s
 
+;define i32 @src(i32 %x, i32 %shamt, i32 %c) {
+;  %shl = shl nuw i32 %x, %shamt
+;  %c2 = shl nuw i32 %c, %shamt
+;  %max = call i32 @llvm.umin(i32 %shl, i32 %c2)
+;  ret i32 %max
+;}
+
+;define i32 @tgt(i32 %x, i32 %shamt, i32 %c) {
+;  %max = call i32 @llvm.umin(i32 %x, i32 %c)
+;  %shl = shl i32 %max, %shamt
+;  ret i32 %shl
+;}
+
 define i8 @src1(i8 %arg0) {
 ; CHECK-LABEL: @src1(
 ; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i8 [[ARG0:%.*]], 1

>From daabf95cb67d3793727bfe5b7c52f5232899cf70 Mon Sep 17 00:00:00 2001
From: Charukesh827 <charue222 at gmail.com>
Date: Tue, 27 May 2025 22:02:29 +0530
Subject: [PATCH 7/8] Made the suggested changes, as follows:

1) max(X, C1) binop C2 -> max(X binop C2, C1 binop C2) is not safe for all
   binops. You can reuse the helper leftDistributesOverRight.
2) The function name should be updated.

In fact, this fold can be decomposed into three steps:

max(max(X, C1) binop C2, C3) ->          // Distributive law
max(max(X binop C2, C1 binop C2), C3) -> // Associative law
max(X binop C2, max(C1 binop C2, C3)) -> // Constant fold
max(X binop C2, C4)

The distributive step is only valid when X binop C2 cannot wrap. With a nuw
add/mul/shl and umax, this follows from X <= umax(X, C1). It does not hold in
general for umin or for the signed min/max intrinsics.
---
 .../InstCombine/InstCombineCalls.cpp          | 147 +++++++++++++-----
 1 file changed, 112 insertions(+), 35 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index aa57f59b23af4..04c26bad76d4b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1175,12 +1175,45 @@ static Instruction *moveAddAfterMinMax(IntrinsicInst *II,
                   : BinaryOperator::CreateNUWAdd(NewMinMax, Add->getOperand(1));
 }
 
-// Try canonicalize max(max(X,C1) binop C2, C3) -> max(X binop C2, C3)
-static Instruction *moveShiftAfterMinMax(IntrinsicInst *II,
-                                         InstCombiner::BuilderTy &Builder) {
+
+static bool rightDistributesOverLeft(Instruction::BinaryOps ROp, bool HasNUW,
+                                     bool HasNSW, Intrinsic::ID LOp) {
+  switch (LOp) {
+  case Intrinsic::umax:
+  case Intrinsic::umin:
+    // Unsigned min/max distribute over addition and left shift if no unsigned
+    // wrap.
+    if (HasNUW && (ROp == Instruction::Add || ROp == Instruction::Shl))
+      return true;
+    // Multiplication preserves order for unsigned min/max with no unsigned
+    // wrap.
+    if (HasNUW && ROp == Instruction::Mul)
+      return true;
+    return false;
+  case Intrinsic::smax:
+  case Intrinsic::smin:
+    // Signed min/max distribute over addition if no signed wrap.
+    if (HasNSW && ROp == Instruction::Add)
+      return true;
+    // Multiplication preserves order for signed min/max with no signed wrap.
+    if (HasNSW && ROp == Instruction::Mul)
+      return true;
+    return false;
+  default:
+    return false;
+  }
+}
+
+///  Try canonicalize max(max(X,C1) binop C2, C3) -> max(X binop C2, max(C1
+///  binop C2, C3)) -> max(X binop C2, C4) max(max(X,C1) binop C2, C3) -> //
+///  Associative laws max(max(X binop C2, C1 binop C2), C3) -> // Commutative
+///  laws max(X binop C2, max(C1 binop C2, C3)) -> // Constant fold max(X binop
+///  C2, C4)
+
+static Instruction *reduceMinMax(IntrinsicInst *II,
+                                 InstCombiner::BuilderTy &Builder) {
   Intrinsic::ID MinMaxID = II->getIntrinsicID();
-  assert(isa<MinMaxIntrinsic>(II) &&
-         "Expected a min or max intrinsic");
+  assert(isa<MinMaxIntrinsic>(II) && "Expected a min or max intrinsic");
 
   Value *Op0 = II->getArgOperand(0), *Op1 = II->getArgOperand(1);
   Value *InnerMax;
@@ -1188,15 +1221,14 @@ static Instruction *moveShiftAfterMinMax(IntrinsicInst *II,
   if (!match(Op0, m_OneUse(m_BinOp(m_Value(InnerMax), m_APInt(C)))) ||
       !match(Op1, m_APInt(C)))
     return nullptr;
-    
+
   auto *BinOpInst = cast<BinaryOperator>(Op0);
   Instruction::BinaryOps BinOp = BinOpInst->getOpcode();
 
   InnerMax = BinOpInst->getOperand(0);
 
   auto *InnerMinMaxInst = dyn_cast<MinMaxIntrinsic>(BinOpInst->getOperand(0));
-
-  if (!InnerMinMaxInst || !InnerMinMaxInst->hasOneUse()) 
+  if (!InnerMinMaxInst || !InnerMinMaxInst->hasOneUse())
     return nullptr;
 
   bool IsSigned = InnerMinMaxInst->isSigned();
@@ -1207,53 +1239,98 @@ static Instruction *moveShiftAfterMinMax(IntrinsicInst *II,
       (!IsSigned && !BinOpInst->hasNoUnsignedWrap()))
     return nullptr;
 
-  // Check if BinOp is a left shift
-  if (BinOp != Instruction::Shl) {
+  if (!rightDistributesOverLeft(BinOp, BinOpInst->hasNoUnsignedWrap(),
+                                BinOpInst->hasNoSignedWrap(),
+                                InnerMinMaxInst->getIntrinsicID()))
     return nullptr;
-  }
 
-  APInt C2 =llvm::dyn_cast<llvm::ConstantInt>(BinOpInst->getOperand(1))->getValue();
-  APInt C3 =llvm::dyn_cast<llvm::ConstantInt>(II->getArgOperand(1))->getValue();
-  APInt C1 = llvm::dyn_cast<llvm::ConstantInt>(InnerMinMaxInst->getOperand(1))->getValue();
-
-  // Compute C1 * 2^C2
-  APInt Two = APInt(C2.getBitWidth(), 2);
-  APInt Pow2C2 = Two.shl(C2);        // 2^C2
-  APInt C1TimesPow2C2 = C1 * Pow2C2; // C1 * 2^C2
+  // Get constant values
+  APInt C1 = llvm::dyn_cast<llvm::ConstantInt>(InnerMinMaxInst->getOperand(1))
+                 ->getValue();
+  APInt C2 =
+      llvm::dyn_cast<llvm::ConstantInt>(BinOpInst->getOperand(1))->getValue();
+  APInt C3 =
+      llvm::dyn_cast<llvm::ConstantInt>(II->getArgOperand(1))->getValue();
+
+  // Constant fold: Compute C1 binop C2
+  APInt C1BinOpC2, Two, Pow2C2, C1TimesPow2C2;
+  bool overflow = false;
+  switch (BinOp) {
+  case Instruction::Add:
+    C1BinOpC2 = IsSigned ? C1.sadd_ov(C2, overflow) : C1.uadd_ov(C2, overflow);
+    break;
+  case Instruction::Mul:
+    C1BinOpC2 = IsSigned ? C1.smul_ov(C2, overflow) : C1.umul_ov(C2, overflow);
+    break;
+  case Instruction::Sub:
+    C1BinOpC2 = IsSigned ? C1.ssub_ov(C2, overflow) : C1.usub_ov(C2, overflow);
+    break;
+  case Instruction::Shl:
+    // Compute C1 * 2^C2
+    Two = APInt(C2.getBitWidth(), 2);
+    Pow2C2 = Two.shl(C2);        // 2^C2
+    C1TimesPow2C2 = C1 * Pow2C2; // C1 * 2^C2
+
+    // Check C3 >= C1 * 2^C2
+    if (C3.ult(C1TimesPow2C2)) {
+      return nullptr;
+    } else {
+      C1BinOpC2 = C1.shl(C2);
+    }
+    break;
+  default:
+    return nullptr; // Unsupported binary operation
+  }
 
-  // Check C3 >= C1 * 2^C2
-  if (C3.ult(C1TimesPow2C2)) {
-    return nullptr;
+  // Constant fold: Compute MinMaxID(C1 binop C2, C3) to get C4
+  APInt C4;
+  switch (MinMaxID) {
+  case Intrinsic::umax:
+    C4 = APIntOps::umax(C1BinOpC2, C3);
+    break;
+  case Intrinsic::umin:
+    C4 = APIntOps::umin(C1BinOpC2, C3);
+    break;
+  case Intrinsic::smax:
+    C4 = APIntOps::smax(C1BinOpC2, C3);
+    break;
+  case Intrinsic::smin:
+    C4 = APIntOps::smin(C1BinOpC2, C3);
+    break;
+  default:
+    return nullptr; // Unsupported intrinsic
   }
 
-  // Create new x binop c2
+  // Create new X binop C2
   Value *NewBinOp = Builder.CreateBinOp(BinOp, InnerMinMaxInst->getOperand(0),
                                         BinOpInst->getOperand(1));
 
-  // Create new min/max intrinsic with new binop and c3
-  if (auto *NewBinInst = dyn_cast<Instruction>(NewBinOp))
+  // Set overflow flags on new binary operation
+  if (auto *NewBinInst = dyn_cast<Instruction>(NewBinOp)) {
     if (IsSigned) {
-      cast<Instruction>(NewBinOp)->setHasNoSignedWrap(true);
-      cast<Instruction>(NewBinOp)->setHasNoUnsignedWrap(false);
+      NewBinInst->setHasNoSignedWrap(true);
+      NewBinInst->setHasNoUnsignedWrap(false);
     } else {
-      cast<Instruction>(NewBinOp)->setHasNoUnsignedWrap(true);
-      cast<Instruction>(NewBinOp)->setHasNoSignedWrap(false);
+      NewBinInst->setHasNoUnsignedWrap(true);
+      NewBinInst->setHasNoSignedWrap(false);
     }
+  }
+
+  // Create constant for C4
+  Value *C4Val = ConstantInt::get(II->getType(), C4);
 
   // Get the intrinsic function for MinMaxID
   Type *Ty = II->getType();
   Function *MinMaxFn =
       Intrinsic::getDeclaration(II->getModule(), MinMaxID, {Ty});
 
-  // Create new min/max intrinsic: MinMaxID(NewBinOp, C3) (not inserted)
-  Value *Args[] = {NewBinOp, Op1};
+  // Create new min/max intrinsic: MinMaxID(NewBinOp, C4)
+  Value *Args[] = {NewBinOp, C4Val};
   Instruction *NewMax = CallInst::Create(MinMaxFn, Args, "", nullptr);
 
   return NewMax;
 }
 
-
-
 /// Match a sadd_sat or ssub_sat which is using min/max to clamp the value.
 Instruction *InstCombinerImpl::matchSAddSubSat(IntrinsicInst &MinMax1) {
   Type *Ty = MinMax1.getType();
@@ -2118,8 +2195,8 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
     if (Instruction *I = moveAddAfterMinMax(II, Builder))
       return I;
 
-    // minmax(x << shamt , c << shamt) -> minmax(x, c) << shamt
-    if (Instruction *I = moveShiftAfterMinMax(II, Builder))  
+    // max(max(X,C1) binop C2, C3) -> max(X binop C2, max(C1 binop C2, C3)) -> max(X binop C2, C4)
+    if (Instruction *I = reduceMinMax(II, Builder))  
       return I;
 
 

>From f92f9ce3b823081b6b02074b119fe112cea57be5 Mon Sep 17 00:00:00 2001
From: Charukesh827 <charue222 at gmail.com>
Date: Tue, 27 May 2025 22:19:29 +0530
Subject: [PATCH 8/8] updated the test for the changed opt

---
 .../Transforms/InstCombine/shift-binop.ll     | 49 +++++++------------
 1 file changed, 17 insertions(+), 32 deletions(-)

diff --git a/llvm/test/Transforms/InstCombine/shift-binop.ll b/llvm/test/Transforms/InstCombine/shift-binop.ll
index 38b6dfa43e7d6..d53901e3d6349 100644
--- a/llvm/test/Transforms/InstCombine/shift-binop.ll
+++ b/llvm/test/Transforms/InstCombine/shift-binop.ll
@@ -1,42 +1,27 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
 ; RUN: opt < %s -passes=instcombine -S | FileCheck %s
 
-;define i32 @src(i32 %x, i32 %shamt, i32 %c) {
-;  %shl = shl nuw i32 %x, %shamt
-;  %c2 = shl nuw i32 %c, %shamt
-;  %max = call i32 @llvm.umin(i32 %shl, i32 %c2)
-;  ret i32 %max
-;}
-
-;define i32 @tgt(i32 %x, i32 %shamt, i32 %c) {
-;  %max = call i32 @llvm.umin(i32 %x, i32 %c)
-;  %shl = shl i32 %max, %shamt
-;  ret i32 %shl
-;}
-
-define i8 @src1(i8 %arg0) {
+define i32 @src1(i32 %arg0) {
 ; CHECK-LABEL: @src1(
-; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i8 [[ARG0:%.*]], 1
-; CHECK-NEXT:    [[OUTMAX:%.*]] = call i8 @llvm.umax.i8(i8 [[SHL]], i8 16)
-; CHECK-NEXT:    ret i8 [[OUTMAX]]
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i32 [[ARG0:%.*]], 2
+; CHECK-NEXT:    [[OUTMIN:%.*]] = call i32 @llvm.umin.i32(i32 [[SHL]], i32 8)
+; CHECK-NEXT:    ret i32 [[OUTMIN]]
 ;
-  %1 = call i8 @llvm.umax.i8(i8 %arg0, i8 1)
-  %2 = shl nuw i8 %1, 1
-  %3 = call i8 @llvm.umax.i8(i8 %2, i8 16)
-  ret i8 %3
+  %1 = call i32 @llvm.umin.i32(i32 %arg0, i32 2)
+  %2 = shl nuw i32 %1, 2
+  %3 = call i32 @llvm.umin.i32(i32 %2, i32 16)
+  ret i32 %3
 }
 
-define i8 @src2(i8 %arg0) {
+define i32 @src2(i32 %arg0) {
 ; CHECK-LABEL: @src2(
-; CHECK-NEXT:    [[INMAX:%.*]] = call i8 @llvm.umax.i8(i8 [[ARG0:%.*]], i8 4)
-; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i8 [[INMAX:%.*]], 1
-; CHECK-NEXT:    [[OUTMAX:%.*]] = call i8 @llvm.umax.i8(i8 [[SHL]], i8 10)
-; CHECK-NEXT:    ret i8 [[OUTMAX]]
+; CHECK-NEXT:    [[INMAX:%.*]] = call i32 @llvm.smax.i32(i32 [[ARG0:%.*]], i32 2)
+; CHECK-NEXT:    [[SHL:%.*]] = shl nsw i32 [[INMAX]], 18
+; CHECK-NEXT:    [[OUTMAX:%.*]] = call i32 @llvm.smax.i32(i32 [[SHL]], i32 10)
+; CHECK-NEXT:    ret i32 [[OUTMAX]]
 ;
-  %1 = call i8 @llvm.umax.i8(i8 %arg0, i8 4)
-  %2 = shl nuw i8 %1, 1
-  %3 = call i8 @llvm.umax.i8(i8 %2, i8 10)
-  ret i8 %3
+  %1 = call i32 @llvm.smax.i32(i32 %arg0, i32 2)
+  %2 = shl nsw i32 %1, 18
+  %3 = call i32 @llvm.smax.i32(i32 %2, i32 10)
+  ret i32 %3
 }
-
-declare i8 @llvm.umax.i8(i8, i8)



More information about the llvm-commits mailing list