[llvm] [InstCombine] Fold max(max(x, c1) << c2, c3) —> max(x << c2, c3) when c3 >= c1 * 2 ^ c2 (PR #140526)
via llvm-commits
llvm-commits at lists.llvm.org
Tue May 20 01:03:07 PDT 2025
https://github.com/Charukesh827 updated https://github.com/llvm/llvm-project/pull/140526
>From 7e4e9d20707e54f756bbaa888ab7652cb66ee071 Mon Sep 17 00:00:00 2001
From: Charukesh827 <charue222 at gmail.com>
Date: Mon, 19 May 2025 15:17:33 +0530
Subject: [PATCH 1/6] =?UTF-8?q?Add=20test=20for=20shifting=20binop=20=20in?=
=?UTF-8?q?=20InstCombine=20transformation=20for=20the=20issue=20=20Missed?=
=?UTF-8?q?=20Optimization:=20max(max(x,=20c1)=20<<=20c2,=20c3)=20?=
=?UTF-8?q?=E2=80=94>=20max(x=20<<=20c2,=20c3)=20when=20c3=20>=3D=20c1=20*?=
=?UTF-8?q?=202=20^=20c2=20#139786?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
.../Transforms/InstCombine/shift-binop.ll | 27 +++++++++++++++++++
1 file changed, 27 insertions(+)
create mode 100644 llvm/test/Transforms/InstCombine/shift-binop.ll
diff --git a/llvm/test/Transforms/InstCombine/shift-binop.ll b/llvm/test/Transforms/InstCombine/shift-binop.ll
new file mode 100644
index 0000000000000..78e9c5ea21181
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/shift-binop.ll
@@ -0,0 +1,27 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+define i8 @src(i8 %arg0) {
+; CHECK-LABEL: @src(
+; CHECK-NEXT: [[TMP1:%.*]] = shl nuw i8 [[ARG0:%.*]], 1
+; CHECK-NEXT: [[TMP2:%.*]] = call i8 @llvm.umax.i8(i8 [[TMP1]], i8 16)
+; CHECK-NEXT: ret i8 [[TMP2]]
+;
+ %1 = call i8 @llvm.umax.i8(i8 %arg0, i8 1)
+ %2 = shl nuw i8 %1, 1
+ %3 = call i8 @llvm.umax.i8(i8 %2, i8 16)
+ ret i8 %3
+}
+
+define i8 @tgt(i8 %arg0) {
+; CHECK-LABEL: @tgt(
+; CHECK-NEXT: [[TMP1:%.*]] = shl nuw i8 [[ARG0:%.*]], 1
+; CHECK-NEXT: [[TMP2:%.*]] = call i8 @llvm.umax.i8(i8 [[TMP1]], i8 16)
+; CHECK-NEXT: ret i8 [[TMP2]]
+;
+ %1 = shl nuw i8 %arg0, 1
+ %2 = call i8 @llvm.umax.i8(i8 %1, i8 16)
+ ret i8 %2
+}
+
+declare i8 @llvm.umax.i8(i8, i8)
>From 4e2ddfca186233798b88b7845a5f3400c0400349 Mon Sep 17 00:00:00 2001
From: Charukesh827 <charue222 at gmail.com>
Date: Mon, 19 May 2025 15:55:15 +0530
Subject: [PATCH 2/6] =?UTF-8?q?[InstCombine]=20Fix=20Missed=20Optimization?=
=?UTF-8?q?:=20max(max(x,=20c1)=20<<=20c2,=20c3)=20=E2=80=94>=20max(x=20<<?=
=?UTF-8?q?=20c2,=20c3)=20when=20c3=20>=3D=20c1=20*=202=20^=20c2?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
This patch fixes issue #139786,
where InstCombine missed the optimization max(max(x, c1) << c2, c3) —> max(x << c2, c3) when c3 >= c1 * 2 ^ c2. Pre-committed test in <commit-hash>.
Alive2: https://alive2.llvm.org/ce/z/on2tJE
---
.../InstCombine/InstCombineCalls.cpp | 85 +++++++++++++++++++
1 file changed, 85 insertions(+)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 3d35bf753c40e..ab05faa5b8f7a 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -76,6 +76,8 @@
#include <utility>
#include <vector>
+#include<iostream>
+
#define DEBUG_TYPE "instcombine"
#include "llvm/Transforms/Utils/InstructionWorklist.h"
@@ -1171,6 +1173,84 @@ static Instruction *moveAddAfterMinMax(IntrinsicInst *II,
return IsSigned ? BinaryOperator::CreateNSWAdd(NewMinMax, Add->getOperand(1))
: BinaryOperator::CreateNUWAdd(NewMinMax, Add->getOperand(1));
}
+
+//Try canonicalize min/max(x << shamt, c<<shamt) into max(x, c) << shamt
+static Instruction *moveShiftAfterMinMax(IntrinsicInst *II, InstCombiner::BuilderTy &Builder) {
+ Intrinsic::ID MinMaxID = II->getIntrinsicID();
+ assert((MinMaxID == Intrinsic::smax || MinMaxID == Intrinsic::smin ||
+ MinMaxID == Intrinsic::umax || MinMaxID == Intrinsic::umin) &&
+ "Expected a min or max intrinsic");
+
+ Value *Op0 = II->getArgOperand(0), *Op1 = II->getArgOperand(1);
+ Value *InnerMax;
+ const APInt *C;
+ if (!match(Op0, m_OneUse(m_BinOp(m_Value(InnerMax), m_APInt(C)))) ||
+ !match(Op1, m_APInt(C)))
+ return nullptr;
+
+ auto* BinOpInst = cast<BinaryOperator>(Op0);
+ Instruction::BinaryOps BinOp = BinOpInst->getOpcode();
+ Value *X;
+ InnerMax = BinOpInst->getOperand(0);
+ // std::cout<< InnerMax->dump() <<std::endl;
+ if(!match(InnerMax,m_OneUse(m_Intrinsic<Intrinsic::umax>(m_Value(X),m_APInt(C))))){
+ if(!match(InnerMax,m_OneUse(m_Intrinsic<Intrinsic::smax>(m_Value(X),m_APInt(C))))){
+ if(!match(InnerMax,m_OneUse(m_Intrinsic<Intrinsic::umin>(m_Value(X),m_APInt(C))))){
+ if(!match(InnerMax,m_OneUse(m_Intrinsic<Intrinsic::smin>(m_Value(X),m_APInt(C))))){
+ return nullptr;
+ }}}}
+
+ auto *InnerMaxInst = cast<IntrinsicInst>(InnerMax);
+
+ bool IsSigned = MinMaxID == Intrinsic::smax || MinMaxID == Intrinsic::smin;
+ if((IsSigned && !BinOpInst->hasNoSignedWrap()) ||
+ (!IsSigned && !BinOpInst->hasNoUnsignedWrap()))
+ return nullptr;
+
+ // Check if BinOp is a left shift
+ if (BinOp != Instruction::Shl) {
+ return nullptr;
+ }
+
+ APInt C2=llvm::dyn_cast<llvm::ConstantInt>(BinOpInst->getOperand(1))->getValue() ;
+ APInt C3=llvm::dyn_cast<llvm::ConstantInt>(II->getArgOperand(1))->getValue();
+ APInt C1=llvm::dyn_cast<llvm::ConstantInt>(InnerMaxInst->getOperand(1))->getValue();
+
+ // Compute C1 * 2^C2
+ APInt Two = APInt(C2.getBitWidth(), 2);
+ APInt Pow2C2 = Two.shl(C2); // 2^C2
+ APInt C1TimesPow2C2 = C1 * Pow2C2; // C1 * 2^C2
+
+ // Check C3 >= C1 * 2^C2
+ if (C3.ult(C1TimesPow2C2)) {
+ return nullptr;
+ }
+
+ //Create new x binop c2
+ Value *NewBinOp = Builder.CreateBinOp(BinOp, InnerMaxInst->getOperand(0), BinOpInst->getOperand(1) );
+
+ //Create new min/max intrinsic with new binop and c3
+
+ if(IsSigned){
+ cast<Instruction>(NewBinOp) -> setHasNoSignedWrap(true);
+ cast<Instruction>(NewBinOp) -> setHasNoUnsignedWrap(false);
+ }else{
+ cast<Instruction>(NewBinOp) -> setHasNoUnsignedWrap(true);
+ cast<Instruction>(NewBinOp) -> setHasNoSignedWrap(false);
+ }
+
+
+ // Get the intrinsic function for MinMaxID
+ Type *Ty = II->getType();
+ Function *MinMaxFn = Intrinsic::getDeclaration(II->getModule(), MinMaxID, {Ty});
+
+ // Create new min/max intrinsic: MinMaxID(NewBinOp, C3) (not inserted)
+ Value *Args[] = {NewBinOp, Op1};
+ Instruction *NewMax = CallInst::Create(MinMaxFn, Args, "", nullptr);
+
+ return NewMax;
+}
+
/// Match a sadd_sat or ssub_sat which is using min/max to clamp the value.
Instruction *InstCombinerImpl::matchSAddSubSat(IntrinsicInst &MinMax1) {
Type *Ty = MinMax1.getType();
@@ -2035,6 +2115,11 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
if (Instruction *I = moveAddAfterMinMax(II, Builder))
return I;
+ // minmax(x << shamt , c << shamt) -> minmax(x, c) << shamt
+ if (Instruction *I = moveShiftAfterMinMax(II, Builder))
+ return I;
+
+
// minmax (X & NegPow2C, Y & NegPow2C) --> minmax(X, Y) & NegPow2C
const APInt *RHSC;
if (match(I0, m_OneUse(m_And(m_Value(X), m_NegatedPower2(RHSC)))) &&
>From b035dcb7232cfcad9f34c1dac1e907c79669c8c0 Mon Sep 17 00:00:00 2001
From: Charukesh827 <charue222 at gmail.com>
Date: Mon, 19 May 2025 16:00:01 +0530
Subject: [PATCH 3/6] removed unwanted header (iostream)
---
llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp | 2 --
1 file changed, 2 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index ab05faa5b8f7a..53dd5f803f97b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -76,8 +76,6 @@
#include <utility>
#include <vector>
-#include<iostream>
-
#define DEBUG_TYPE "instcombine"
#include "llvm/Transforms/Utils/InstructionWorklist.h"
>From 19f60a92a0b03b6030b8d0bff8e85572b0013c5b Mon Sep 17 00:00:00 2001
From: Charukesh827 <charue222 at gmail.com>
Date: Tue, 20 May 2025 13:16:31 +0530
Subject: [PATCH 4/6] added negative test
---
.../Transforms/InstCombine/shift-binop.ll | 28 ++++++++++---------
1 file changed, 15 insertions(+), 13 deletions(-)
diff --git a/llvm/test/Transforms/InstCombine/shift-binop.ll b/llvm/test/Transforms/InstCombine/shift-binop.ll
index 78e9c5ea21181..a84a30002d917 100644
--- a/llvm/test/Transforms/InstCombine/shift-binop.ll
+++ b/llvm/test/Transforms/InstCombine/shift-binop.ll
@@ -1,11 +1,11 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt < %s -passes=instcombine -S | FileCheck %s
-define i8 @src(i8 %arg0) {
-; CHECK-LABEL: @src(
-; CHECK-NEXT: [[TMP1:%.*]] = shl nuw i8 [[ARG0:%.*]], 1
-; CHECK-NEXT: [[TMP2:%.*]] = call i8 @llvm.umax.i8(i8 [[TMP1]], i8 16)
-; CHECK-NEXT: ret i8 [[TMP2]]
+define i8 @src1(i8 %arg0) {
+; CHECK-LABEL: @src1(
+; CHECK-NEXT: [[SHL:%.*]] = shl nuw i8 [[ARG0:%.*]], 1
+; CHECK-NEXT: [[OUTMAX:%.*]] = call i8 @llvm.umax.i8(i8 [[SHL]], i8 16)
+; CHECK-NEXT: ret i8 [[OUTMAX]]
;
%1 = call i8 @llvm.umax.i8(i8 %arg0, i8 1)
%2 = shl nuw i8 %1, 1
@@ -13,15 +13,17 @@ define i8 @src(i8 %arg0) {
ret i8 %3
}
-define i8 @tgt(i8 %arg0) {
-; CHECK-LABEL: @tgt(
-; CHECK-NEXT: [[TMP1:%.*]] = shl nuw i8 [[ARG0:%.*]], 1
-; CHECK-NEXT: [[TMP2:%.*]] = call i8 @llvm.umax.i8(i8 [[TMP1]], i8 16)
-; CHECK-NEXT: ret i8 [[TMP2]]
+define i8 @src2(i8 %arg0) {
+; CHECK-LABEL: @src2(
+; CHECK-NEXT: [[INMAX:%.*]] = call i8 @llvm.umax.i8(i8 [[ARG0:%.*]], i8 4)
+; CHECK-NEXT: [[SHL:%.*]] = shl nuw i8 [[INMAX:%.*]], 1
+; CHECK-NEXT: [[OUTMAX:%.*]] = call i8 @llvm.umax.i8(i8 [[SHL]], i8 10)
+; CHECK-NEXT: ret i8 [[OUTMAX]]
;
- %1 = shl nuw i8 %arg0, 1
- %2 = call i8 @llvm.umax.i8(i8 %1, i8 16)
- ret i8 %2
+ %1 = call i8 @llvm.umax.i8(i8 %arg0, i8 4)
+ %2 = shl nuw i8 %1, 1
+ %3 = call i8 @llvm.umax.i8(i8 %2, i8 10)
+ ret i8 %3
}
declare i8 @llvm.umax.i8(i8, i8)
>From e317ee2caede9c33a4962bb970c3d5358a74a2f9 Mon Sep 17 00:00:00 2001
From: Charukesh827 <charue222 at gmail.com>
Date: Tue, 20 May 2025 13:20:36 +0530
Subject: [PATCH 5/6] Made suggested changes but couldn't generalize to
other binops (excluding div/rem) as suggested
"If it is the case, you should generalize it to handle most of binops (excluding div/rem), then use simplifyBinOp and simplifyBinaryIntrinsic to check if min/max(c1 binop c2, c3) folds to c3."
---
.../InstCombine/InstCombineCalls.cpp | 84 ++++++++++---------
1 file changed, 43 insertions(+), 41 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 53dd5f803f97b..3d1c8a32dc3af 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1172,51 +1172,50 @@ static Instruction *moveAddAfterMinMax(IntrinsicInst *II,
: BinaryOperator::CreateNUWAdd(NewMinMax, Add->getOperand(1));
}
-//Try canonicalize min/max(x << shamt, c<<shamt) into max(x, c) << shamt
-static Instruction *moveShiftAfterMinMax(IntrinsicInst *II, InstCombiner::BuilderTy &Builder) {
+// Try canonicalize max(max(X,C1) binop C2, C3) -> max(X binop C2, C3)
+static Instruction *moveShiftAfterMinMax(IntrinsicInst *II,
+ InstCombiner::BuilderTy &Builder) {
Intrinsic::ID MinMaxID = II->getIntrinsicID();
- assert((MinMaxID == Intrinsic::smax || MinMaxID == Intrinsic::smin ||
- MinMaxID == Intrinsic::umax || MinMaxID == Intrinsic::umin) &&
- "Expected a min or max intrinsic");
-
+ assert(isa<MinMaxIntrinsic>(II) &&
+ "Expected a min or max intrinsic");
+
Value *Op0 = II->getArgOperand(0), *Op1 = II->getArgOperand(1);
Value *InnerMax;
const APInt *C;
- if (!match(Op0, m_OneUse(m_BinOp(m_Value(InnerMax), m_APInt(C)))) ||
+ if (!match(Op0, m_OneUse(m_BinOp(m_Value(InnerMax), m_APInt(C)))) ||
!match(Op1, m_APInt(C)))
- return nullptr;
-
- auto* BinOpInst = cast<BinaryOperator>(Op0);
+ return nullptr;
+
+ auto *BinOpInst = cast<BinaryOperator>(Op0);
Instruction::BinaryOps BinOp = BinOpInst->getOpcode();
- Value *X;
+
InnerMax = BinOpInst->getOperand(0);
- // std::cout<< InnerMax->dump() <<std::endl;
- if(!match(InnerMax,m_OneUse(m_Intrinsic<Intrinsic::umax>(m_Value(X),m_APInt(C))))){
- if(!match(InnerMax,m_OneUse(m_Intrinsic<Intrinsic::smax>(m_Value(X),m_APInt(C))))){
- if(!match(InnerMax,m_OneUse(m_Intrinsic<Intrinsic::umin>(m_Value(X),m_APInt(C))))){
- if(!match(InnerMax,m_OneUse(m_Intrinsic<Intrinsic::smin>(m_Value(X),m_APInt(C))))){
- return nullptr;
- }}}}
-
- auto *InnerMaxInst = cast<IntrinsicInst>(InnerMax);
- bool IsSigned = MinMaxID == Intrinsic::smax || MinMaxID == Intrinsic::smin;
- if((IsSigned && !BinOpInst->hasNoSignedWrap()) ||
- (!IsSigned && !BinOpInst->hasNoUnsignedWrap()))
- return nullptr;
+ auto *InnerMinMaxInst = dyn_cast<MinMaxIntrinsic>(BinOpInst->getOperand(0));
+
+ if (!InnerMinMaxInst || !InnerMinMaxInst->hasOneUse())
+ return nullptr;
+
+ bool IsSigned = InnerMinMaxInst->isSigned();
+ if (InnerMinMaxInst->getIntrinsicID() != MinMaxID)
+ return nullptr;
+
+ if ((IsSigned && !BinOpInst->hasNoSignedWrap()) ||
+ (!IsSigned && !BinOpInst->hasNoUnsignedWrap()))
+ return nullptr;
// Check if BinOp is a left shift
if (BinOp != Instruction::Shl) {
return nullptr;
}
- APInt C2=llvm::dyn_cast<llvm::ConstantInt>(BinOpInst->getOperand(1))->getValue() ;
- APInt C3=llvm::dyn_cast<llvm::ConstantInt>(II->getArgOperand(1))->getValue();
- APInt C1=llvm::dyn_cast<llvm::ConstantInt>(InnerMaxInst->getOperand(1))->getValue();
+ APInt C2 =llvm::dyn_cast<llvm::ConstantInt>(BinOpInst->getOperand(1))->getValue();
+ APInt C3 =llvm::dyn_cast<llvm::ConstantInt>(II->getArgOperand(1))->getValue();
+ APInt C1 = llvm::dyn_cast<llvm::ConstantInt>(InnerMinMaxInst->getOperand(1))->getValue();
// Compute C1 * 2^C2
APInt Two = APInt(C2.getBitWidth(), 2);
- APInt Pow2C2 = Two.shl(C2); // 2^C2
+ APInt Pow2C2 = Two.shl(C2); // 2^C2
APInt C1TimesPow2C2 = C1 * Pow2C2; // C1 * 2^C2
// Check C3 >= C1 * 2^C2
@@ -1224,23 +1223,24 @@ static Instruction *moveShiftAfterMinMax(IntrinsicInst *II, InstCombiner::Builde
return nullptr;
}
- //Create new x binop c2
- Value *NewBinOp = Builder.CreateBinOp(BinOp, InnerMaxInst->getOperand(0), BinOpInst->getOperand(1) );
-
- //Create new min/max intrinsic with new binop and c3
-
- if(IsSigned){
- cast<Instruction>(NewBinOp) -> setHasNoSignedWrap(true);
- cast<Instruction>(NewBinOp) -> setHasNoUnsignedWrap(false);
- }else{
- cast<Instruction>(NewBinOp) -> setHasNoUnsignedWrap(true);
- cast<Instruction>(NewBinOp) -> setHasNoSignedWrap(false);
+ // Create new x binop c2
+ Value *NewBinOp = Builder.CreateBinOp(BinOp, InnerMinMaxInst->getOperand(0),
+ BinOpInst->getOperand(1));
+
+ // Create new min/max intrinsic with new binop and c3
+ if (auto *NewBinInst = dyn_cast<Instruction>(NewBinOp))
+ if (IsSigned) {
+ cast<Instruction>(NewBinOp)->setHasNoSignedWrap(true);
+ cast<Instruction>(NewBinOp)->setHasNoUnsignedWrap(false);
+ } else {
+ cast<Instruction>(NewBinOp)->setHasNoUnsignedWrap(true);
+ cast<Instruction>(NewBinOp)->setHasNoSignedWrap(false);
}
-
// Get the intrinsic function for MinMaxID
Type *Ty = II->getType();
- Function *MinMaxFn = Intrinsic::getDeclaration(II->getModule(), MinMaxID, {Ty});
+ Function *MinMaxFn =
+ Intrinsic::getDeclaration(II->getModule(), MinMaxID, {Ty});
// Create new min/max intrinsic: MinMaxID(NewBinOp, C3) (not inserted)
Value *Args[] = {NewBinOp, Op1};
@@ -1249,6 +1249,8 @@ static Instruction *moveShiftAfterMinMax(IntrinsicInst *II, InstCombiner::Builde
return NewMax;
}
+
+
/// Match a sadd_sat or ssub_sat which is using min/max to clamp the value.
Instruction *InstCombinerImpl::matchSAddSubSat(IntrinsicInst &MinMax1) {
Type *Ty = MinMax1.getType();
>From 6f1d83f15bf6e3057f374824e96d370addedcdb9 Mon Sep 17 00:00:00 2001
From: Charukesh827 <charue222 at gmail.com>
Date: Tue, 20 May 2025 13:30:04 +0530
Subject: [PATCH 6/6] added motivation to shift-binop.ll
---
llvm/test/Transforms/InstCombine/shift-binop.ll | 13 +++++++++++++
1 file changed, 13 insertions(+)
diff --git a/llvm/test/Transforms/InstCombine/shift-binop.ll b/llvm/test/Transforms/InstCombine/shift-binop.ll
index a84a30002d917..38b6dfa43e7d6 100644
--- a/llvm/test/Transforms/InstCombine/shift-binop.ll
+++ b/llvm/test/Transforms/InstCombine/shift-binop.ll
@@ -1,6 +1,19 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+;define i32 @src(i32 %x, i32 %shamt, i32 %c) {
+; %shl = shl nuw i32 %x, %shamt
+; %c2 = shl nuw i32 %c, %shamt
+; %max = call i32 @llvm.umin(i32 %shl, i32 %c2)
+; ret i32 %max
+;}
+
+;define i32 @tgt(i32 %x, i32 %shamt, i32 %c) {
+; %max = call i32 @llvm.umin(i32 %x, i32 %c)
+; %shl = shl i32 %max, %shamt
+; ret i32 %shl
+;}
+
define i8 @src1(i8 %arg0) {
; CHECK-LABEL: @src1(
; CHECK-NEXT: [[SHL:%.*]] = shl nuw i8 [[ARG0:%.*]], 1
More information about the llvm-commits
mailing list