[llvm] [InstCombine] Combine or-disjoint (and->mul), (and->mul) to and->mul (PR #136013)
Jeffrey Byrnes via llvm-commits
llvm-commits at lists.llvm.org
Wed Apr 16 12:38:40 PDT 2025
https://github.com/jrbyrnes created https://github.com/llvm/llvm-project/pull/136013
The canonical form of a bitmasked mul is currently:
%val = and %x, %bitMask // where %bitMask is some constant
%cmp = icmp eq %val, 0
%sel = select %cmp, 0, %C // where %C is some constant
In certain cases, where we combine several of these bitmasked muls that share a common factor, we can optimize them into and->mul (see https://github.com/llvm/llvm-project/pull/135274 ).
That optimization enables further simplifications; this PR addresses one of them.
In cases where we have
`or-disjoint ( mul(and (X, C1), D) , mul (and (X, C2), D))`
we can combine into
`mul( and (X, (C1 + C2)), D) `
provided that C1 and C2 are disjoint (C1 & C2 == 0). This gives (X & C1) + (X & C2) == X & (C1 + C2), and C1 + C2 == C1 | C2.
Generalized proof: https://alive2.llvm.org/ce/z/MQYMui
>From d040ee7c7f4b93a6f7d360725bbcd8bd7a54a0d1 Mon Sep 17 00:00:00 2001
From: Jeffrey Byrnes <Jeffrey.Byrnes at amd.com>
Date: Wed, 16 Apr 2025 12:13:10 -0700
Subject: [PATCH] [InstCombine] Combine or-disjoint (and->mul), (and->mul) to
and->mul
Change-Id: I0bc0be96050803f6f5ce303e82e1ad758f830d7d
---
.../InstCombine/InstCombineAndOrXor.cpp | 21 +++++
llvm/test/Transforms/InstCombine/or.ll | 87 ++++++++++++++++++-
2 files changed, 104 insertions(+), 4 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 6cc241781d112..206131ab4a6a7 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3643,6 +3643,27 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
foldAddLikeCommutative(I.getOperand(1), I.getOperand(0),
/*NSW=*/true, /*NUW=*/true))
return R;
+
+ Value *LHSOp = nullptr, *RHSOp = nullptr;
+ const APInt *LHSConst = nullptr, *RHSConst = nullptr;
+
+ // ((X & C1) * D) + ((X & C2) * D) -> (X & (C1 + C2) * D)
+ if (match(I.getOperand(0), m_Mul(m_Value(LHSOp), m_APInt(LHSConst))) &&
+ match(I.getOperand(1), m_Mul(m_Value(RHSOp), m_APInt(RHSConst))) &&
+ LHSConst == RHSConst) {
+ Value *LHSBase = nullptr, *RHSBase = nullptr;
+ const APInt *LHSMask = nullptr, *RHSMask = nullptr;
+ if (match(LHSOp, m_And(m_Value(LHSBase), m_APInt(LHSMask))) &&
+ match(RHSOp, m_And(m_Value(RHSBase), m_APInt(RHSMask))) &&
+ LHSBase == RHSBase &&
+ ((*LHSMask & *RHSMask) == APInt::getZero(LHSMask->getBitWidth()))) {
+ auto NewAnd = Builder.CreateAnd(
+ LHSBase, ConstantInt::get(LHSOp->getType(), (*LHSMask + *RHSMask)));
+
+ return BinaryOperator::CreateMul(
+ NewAnd, ConstantInt::get(NewAnd->getType(), *LHSConst));
+ }
+ }
}
Value *X, *Y;
diff --git a/llvm/test/Transforms/InstCombine/or.ll b/llvm/test/Transforms/InstCombine/or.ll
index 95f89e4ce11cd..777387cc662d6 100644
--- a/llvm/test/Transforms/InstCombine/or.ll
+++ b/llvm/test/Transforms/InstCombine/or.ll
@@ -1281,10 +1281,10 @@ define <16 x i1> @test51(<16 x i1> %arg, <16 x i1> %arg1) {
; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <16 x i1> [[ARG:%.*]], <16 x i1> [[ARG1:%.*]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 20, i32 5, i32 6, i32 23, i32 24, i32 9, i32 10, i32 27, i32 28, i32 29, i32 30, i32 31>
; CHECK-NEXT: ret <16 x i1> [[TMP3]]
;
- %tmp = and <16 x i1> %arg, <i1 true, i1 true, i1 true, i1 true, i1 false, i1 true, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false, i1 false, i1 false, i1 false>
- %tmp2 = and <16 x i1> %arg1, <i1 false, i1 false, i1 false, i1 false, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false, i1 true, i1 true, i1 true, i1 true, i1 true>
- %tmp3 = or <16 x i1> %tmp, %tmp2
- ret <16 x i1> %tmp3
+ %out = and <16 x i1> %arg, <i1 true, i1 true, i1 true, i1 true, i1 false, i1 true, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false, i1 false, i1 false, i1 false>
+ %out2 = and <16 x i1> %arg1, <i1 false, i1 false, i1 false, i1 false, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false, i1 true, i1 true, i1 true, i1 true, i1 true>
+ %out3 = or <16 x i1> %out, %out2
+ ret <16 x i1> %out3
}
; This would infinite loop because it reaches a transform
@@ -2035,3 +2035,82 @@ define i32 @or_xor_and_commuted3(i32 %x, i32 %y, i32 %z) {
%or1 = or i32 %xor, %yy
ret i32 %or1
}
+
+define i32 @or_combine_mul_and1(i32 %in) { ; Positive test: disjoint masks 2 and 4 share the factor 72, so the or folds to (in & 6) * 72.
+; CHECK-LABEL: @or_combine_mul_and1(
+; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 6
+; CHECK-NEXT: [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT: ret i32 [[OUT]]
+;
+ %bitop0 = and i32 %in, 2
+ %out0 = mul i32 %bitop0, 72
+ %bitop1 = and i32 %in, 4
+ %out1 = mul i32 %bitop1, 72
+ %out = or disjoint i32 %out0, %out1 ; (in & 2) * 72 | (in & 4) * 72
+ ret i32 %out
+}
+
+define i32 @or_combine_mul_and2(i32 %in) { ; Positive test: non-adjacent disjoint masks 2 and 8 with factor 72 fold to (in & 10) * 72.
+; CHECK-LABEL: @or_combine_mul_and2(
+; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 10
+; CHECK-NEXT: [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT: ret i32 [[OUT]]
+;
+ %bitop0 = and i32 %in, 2
+ %out0 = mul i32 %bitop0, 72
+ %bitop1 = and i32 %in, 8
+ %out1 = mul i32 %bitop1, 72
+ %out = or disjoint i32 %out0, %out1 ; (in & 2) * 72 | (in & 8) * 72
+ ret i32 %out
+}
+
+define i32 @or_combine_mul_and_diff_factor(i32 %in) { ; Negative test: the masked values use different factors (36 vs 72), so the or must not be folded.
+; CHECK-LABEL: @or_combine_mul_and_diff_factor(
+; CHECK-NEXT: [[BITOP0:%.*]] = and i32 [[IN:%.*]], 2
+; CHECK-NEXT: [[TMP0:%.*]] = mul nuw nsw i32 [[BITOP0]], 36
+; CHECK-NEXT: [[BITOP1:%.*]] = and i32 [[IN]], 4
+; CHECK-NEXT: [[TMP1:%.*]] = mul nuw nsw i32 [[BITOP1]], 72
+; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[TMP0]], [[TMP1]]
+; CHECK-NEXT: ret i32 [[OUT]]
+;
+ %bitop0 = and i32 %in, 2
+ %out0 = mul i32 %bitop0, 36 ; factor differs from %out1
+ %bitop1 = and i32 %in, 4
+ %out1 = mul i32 %bitop1, 72
+ %out = or disjoint i32 %out0, %out1
+ ret i32 %out
+}
+
+define i32 @or_combine_mul_and_diff_base(i32 %in, i32 %in1) { ; Negative test: the masks apply to different base values (%in vs %in1), so the or must not be folded.
+; CHECK-LABEL: @or_combine_mul_and_diff_base(
+; CHECK-NEXT: [[BITOP0:%.*]] = and i32 [[IN:%.*]], 2
+; CHECK-NEXT: [[TMP0:%.*]] = mul nuw nsw i32 [[BITOP0]], 72
+; CHECK-NEXT: [[BITOP1:%.*]] = and i32 [[IN1:%.*]], 4
+; CHECK-NEXT: [[TMP1:%.*]] = mul nuw nsw i32 [[BITOP1]], 72
+; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[TMP0]], [[TMP1]]
+; CHECK-NEXT: ret i32 [[OUT]]
+;
+ %bitop0 = and i32 %in, 2
+ %out0 = mul i32 %bitop0, 72
+ %bitop1 = and i32 %in1, 4 ; different base than %bitop0
+ %out1 = mul i32 %bitop1, 72
+ %out = or disjoint i32 %out0, %out1
+ ret i32 %out
+}
+
+define i32 @or_combine_mul_and_decomposed(i32 %in) { ; Mask-1 case: the expected output keeps (in & 1) * 72 as trunc+select and leaves the or unmerged, so the fold does not fire here (presumably the single-bit product is canonicalized to a select first).
+; CHECK-LABEL: @or_combine_mul_and_decomposed(
+; CHECK-NEXT: [[TMP2:%.*]] = trunc i32 [[IN:%.*]] to i1
+; CHECK-NEXT: [[OUT0:%.*]] = select i1 [[TMP2]], i32 72, i32 0
+; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN]], 4
+; CHECK-NEXT: [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT: [[OUT1:%.*]] = or disjoint i32 [[OUT0]], [[OUT]]
+; CHECK-NEXT: ret i32 [[OUT1]]
+;
+ %bitop0 = and i32 %in, 1 ; low-bit mask; becomes trunc+select in the expected output
+ %out0 = mul i32 %bitop0, 72
+ %bitop1 = and i32 %in, 4
+ %out1 = mul i32 %bitop1, 72
+ %out = or disjoint i32 %out0, %out1
+ ret i32 %out
+}
More information about the llvm-commits
mailing list