[llvm] fix for Issue #139786 - Missed Optimization: max(max(x, c1) << c2, c3) —> max(x << c2, c3) when c3 >= c1 * 2 ^ c2 (PR #140526)

via llvm-commits llvm-commits at lists.llvm.org
Mon May 19 03:33:14 PDT 2025


https://github.com/Charukesh827 created https://github.com/llvm/llvm-project/pull/140526

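As suggested, this folds max(max(x, c1) << c2, c3) —> max(x << c2, c3) when c3 >= c1 * 2^c2 (that is, c3 >= c1 << c2). The fold was meant to generalize to max(max(x, c1) binop c2, c3) —> max(x binop c2, c3), but only `shl` is handled so far.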

define i8 @src(i8 %x) {
  ; Inner clamp to 1, shift left by 1 (nuw), then outer clamp to 16.
  ; Because 16 >= (1 << 1), the inner clamp can never change the result.
  %clamped = call i8 @llvm.umax.i8(i8 %x, i8 1)
  %shifted = shl nuw i8 %clamped, 1
  %res = call i8 @llvm.umax.i8(i8 %shifted, i8 16)
  ret i8 %res
}

define i8 @tgt(i8 %x) {
  ; The same computation as @src with the redundant inner clamp removed.
  %shifted = shl nuw i8 %x, 1
  %res = call i8 @llvm.umax.i8(i8 %shifted, i8 16)
  ret i8 %res
}

>From 7e4e9d20707e54f756bbaa888ab7652cb66ee071 Mon Sep 17 00:00:00 2001
From: Charukesh827 <charue222 at gmail.com>
Date: Mon, 19 May 2025 15:17:33 +0530
Subject: [PATCH 1/3] =?UTF-8?q?Add=20test=20for=20shifting=20binop=20=20in?=
 =?UTF-8?q?=20InstCombine=20transformation=20for=20the=20issue=20=20Missed?=
 =?UTF-8?q?=20Optimization:=20max(max(x,=20c1)=20<<=20c2,=20c3)=20?=
 =?UTF-8?q?=E2=80=94>=20max(x=20<<=20c2,=20c3)=20when=20c3=20>=3D=20c1=20*?=
 =?UTF-8?q?=202=20^=20c2=20#139786?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../Transforms/InstCombine/shift-binop.ll     | 27 +++++++++++++++++++
 1 file changed, 27 insertions(+)
 create mode 100644 llvm/test/Transforms/InstCombine/shift-binop.ll

diff --git a/llvm/test/Transforms/InstCombine/shift-binop.ll b/llvm/test/Transforms/InstCombine/shift-binop.ll
new file mode 100644
index 0000000000000..78e9c5ea21181
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/shift-binop.ll
@@ -0,0 +1,27 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+define i8 @src(i8 %arg0) {
+; CHECK-LABEL: @src(
+; CHECK-NEXT:    [[TMP1:%.*]] = shl nuw i8 [[ARG0:%.*]], 1
+; CHECK-NEXT:    [[TMP2:%.*]] = call i8 @llvm.umax.i8(i8 [[TMP1]], i8 16)
+; CHECK-NEXT:    ret i8 [[TMP2]]
+;
+  %1 = call i8 @llvm.umax.i8(i8 %arg0, i8 1) ; inner clamp: umax(x, 1)
+  %2 = shl nuw i8 %1, 1 ; umax(x, 1) << 1 with no unsigned wrap
+  %3 = call i8 @llvm.umax.i8(i8 %2, i8 16) ; 16 >= (1 << 1), so the inner clamp is redundant
+  ret i8 %3 ; expected output above is umax(x << 1, 16); NOTE(review): it only holds once the InstCombine fold is applied
+}
+
+define i8 @tgt(i8 %arg0) {
+; CHECK-LABEL: @tgt(
+; CHECK-NEXT:    [[TMP1:%.*]] = shl nuw i8 [[ARG0:%.*]], 1
+; CHECK-NEXT:    [[TMP2:%.*]] = call i8 @llvm.umax.i8(i8 [[TMP1]], i8 16)
+; CHECK-NEXT:    ret i8 [[TMP2]]
+;
+  %1 = shl nuw i8 %arg0, 1 ; already the folded form of @src
+  %2 = call i8 @llvm.umax.i8(i8 %1, i8 16) ; expected output above equals the input: left unchanged
+  ret i8 %2
+}
+
+declare i8 @llvm.umax.i8(i8, i8) ; intrinsic used by both tests

>From 4e2ddfca186233798b88b7845a5f3400c0400349 Mon Sep 17 00:00:00 2001
From: Charukesh827 <charue222 at gmail.com>
Date: Mon, 19 May 2025 15:55:15 +0530
Subject: [PATCH 2/3] =?UTF-8?q?[InstCombine]=20Fix=20Missed=20Optimization?=
 =?UTF-8?q?:=20max(max(x,=20c1)=20<<=20c2,=20c3)=20=E2=80=94>=20max(x=20<<?=
 =?UTF-8?q?=20c2,=20c3)=20when=20c3=20>=3D=20c1=20*=202=20^=20c2?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This patch fixes issue #139786, where InstCombine missed the optimization
max(max(x, c1) << c2, c3) —> max(x << c2, c3) when c3 >= c1 * 2^c2. The test was pre-committed in 7e4e9d20707e (patch 1/3 of this series).

Alive2: https://alive2.llvm.org/ce/z/on2tJE
---
 .../InstCombine/InstCombineCalls.cpp          | 60 +++++++++++++++++++
 1 file changed, 60 insertions(+)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 3d35bf753c40e..ab05faa5b8f7a 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -76,6 +76,8 @@
 #include <utility>
 #include <vector>
 
+#include<iostream>
+
 #define DEBUG_TYPE "instcombine"
 #include "llvm/Transforms/Utils/InstructionWorklist.h"
 
@@ -1171,6 +1173,60 @@ static Instruction *moveAddAfterMinMax(IntrinsicInst *II,
   return IsSigned ? BinaryOperator::CreateNSWAdd(NewMinMax, Add->getOperand(1))
                   : BinaryOperator::CreateNUWAdd(NewMinMax, Add->getOperand(1));
 }
+
+/// Fold umax(umax(X, C1) << C2, C3) --> umax(X << C2, C3) when
+/// C3 >= (C1 << C2), i.e. C3 >= C1 * 2^C2 (issue #139786).
+///
+/// Preconditions, and why each one is needed:
+///  * Both min/max calls must be umax. For smax, a negative X that the inner
+///    clamp would have replaced by C1 can wrap to a large positive value once
+///    shifted; for umin the redundancy condition is a different one.
+///  * The shift must be 'shl nuw'. Then umax(X, C1) << C2 equals
+///    umax(X << C2, C1 << C2), and since X <= umax(X, C1), X << C2 cannot
+///    wrap either, so the new shift keeps 'nuw'.
+///  * C1 << C2 must not overflow (otherwise the source is always poison) and
+///    must be <= C3, so that the outer clamp subsumes the inner one.
+///
+/// Constants are bound with m_APInt, which also matches splat vectors, so no
+/// cast to ConstantInt is needed. Returns the new, not yet inserted, umax
+/// call, or nullptr if the fold does not apply.
+static Instruction *moveShiftAfterMinMax(IntrinsicInst *II,
+                                         InstCombiner::BuilderTy &Builder) {
+  Intrinsic::ID MinMaxID = II->getIntrinsicID();
+  assert((MinMaxID == Intrinsic::smax || MinMaxID == Intrinsic::smin ||
+          MinMaxID == Intrinsic::umax || MinMaxID == Intrinsic::umin) &&
+         "Expected a min or max intrinsic");
+  if (MinMaxID != Intrinsic::umax)
+    return nullptr;
+
+  Value *Op0 = II->getArgOperand(0), *Op1 = II->getArgOperand(1);
+  Value *X;
+  const APInt *C1, *C2, *C3;
+  // Op0 = shl nuw (umax X, C1), C2. Both the shift and the inner umax must be
+  // single-use so that the rewrite actually removes instructions.
+  if (!match(Op0, m_OneUse(m_NUWShl(
+                      m_OneUse(m_Intrinsic<Intrinsic::umax>(m_Value(X),
+                                                            m_APInt(C1))),
+                      m_APInt(C2)))) ||
+      !match(Op1, m_APInt(C3)))
+    return nullptr;
+
+  // C1 << C2 is the smallest value the shifted inner umax can produce. If
+  // the constant shift overflows (including C2 >= bit width), the source is
+  // always poison; leave it to other folds.
+  bool Overflow;
+  APInt ShiftedC1 = C1->ushl_ov(*C2, Overflow);
+  if (Overflow || C3->ult(ShiftedC1))
+    return nullptr;
+
+  // Reuse the original shift-amount operand so vector splats are preserved.
+  Value *ShAmt = cast<Instruction>(Op0)->getOperand(1);
+  Value *NewShl = Builder.CreateShl(X, ShAmt, "", /*HasNUW=*/true);
+
+  // umax(X << C2, C3), reusing the outer call's umax declaration.
+  return CallInst::Create(II->getCalledFunction(), {NewShl, Op1});
+}
+
 /// Match a sadd_sat or ssub_sat which is using min/max to clamp the value.
 Instruction *InstCombinerImpl::matchSAddSubSat(IntrinsicInst &MinMax1) {
   Type *Ty = MinMax1.getType();
@@ -2035,6 +2091,10 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
     if (Instruction *I = moveAddAfterMinMax(II, Builder))
       return I;
 
+    // umax(umax(X, C1) << C2, C3) --> umax(X << C2, C3) if C3 >= (C1 << C2)
+    if (Instruction *I = moveShiftAfterMinMax(II, Builder))
+      return I;
+
     // minmax (X & NegPow2C, Y & NegPow2C) --> minmax(X, Y) & NegPow2C
     const APInt *RHSC;
     if (match(I0, m_OneUse(m_And(m_Value(X), m_NegatedPower2(RHSC)))) &&

>From b035dcb7232cfcad9f34c1dac1e907c79669c8c0 Mon Sep 17 00:00:00 2001
From: Charukesh827 <charue222 at gmail.com>
Date: Mon, 19 May 2025 16:00:01 +0530
Subject: [PATCH 3/3] removed unwanted header (iostream)

---
 llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp | 2 --
 1 file changed, 2 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index ab05faa5b8f7a..53dd5f803f97b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -76,8 +76,6 @@
 #include <utility>
 #include <vector>
 
-#include<iostream>
-
 #define DEBUG_TYPE "instcombine"
 #include "llvm/Transforms/Utils/InstructionWorklist.h"
 



More information about the llvm-commits mailing list