[llvm] [InstCombine] enable more factorization in SimplifyUsingDistributiveLaws (PR #69892)

Alex Cameron via llvm-commits llvm-commits at lists.llvm.org
Sun Oct 22 19:28:55 PDT 2023


https://github.com/tetsuo-cpp created https://github.com/llvm/llvm-project/pull/69892

None

>From fb70687c6ff9c74f15067dd6eabdf413d3306047 Mon Sep 17 00:00:00 2001
From: Alex Cameron <asc at tetsuo.sh>
Date: Sun, 22 Oct 2023 18:08:54 +1100
Subject: [PATCH] [InstCombine] enable more factorization in
 SimplifyUsingDistributiveLaws

---
 .../InstCombine/InstructionCombining.cpp      |  38 ++++
 llvm/test/Transforms/InstCombine/add4.ll      | 183 ++++++++++++++++++
 llvm/test/Transforms/PGOProfile/chr.ll        |   7 +-
 3 files changed, 224 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 559eb2ef4795eb1..7a8795f12fd4fc4 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -984,6 +984,44 @@ Value *InstCombinerImpl::tryFactorizationFolds(BinaryOperator &I) {
               tryFactorization(I, SQ, Builder, RHSOpcode, LHS, Ident, C, D))
         return V;
 
+  // The two folds below peel a constant off an associative, commutative op
+  // (in practice an add) so that the multiplication can be factorized:
+  //   "(A * B) op (C op D)" --> "((A * B) op C) op D"
+  //   "(A op B) op (C * D)" --> "((C * D) op A) op B"
+  // tryFactorization derives the no-wrap flags of its result as if the
+  // operands of I were exactly the two multiplications. That keeps 'nuw'
+  // valid, but not 'nsw' (i8 X = -26: (X * 4) + (X + 3) == -127, yet X * 5
+  // overflows), so 'nsw' is dropped before the constant is added back.
+  auto AddBackConstant = [&](Value *Factored, Instruction::BinaryOps Opc,
+                             Value *K) -> Value * {
+    if (auto *FI = dyn_cast<Instruction>(Factored))
+      if (isa<OverflowingBinaryOperator>(FI))
+        FI->setHasNoSignedWrap(false);
+    return Builder.CreateBinOp(Opc, Factored, K);
+  };
+
+  // "(A * B) op (C op D)" with a constant D, e.g.
+  // "(X * 4) + (X + 3)" --> "(X * 5) + 3".
+  if (Op0 && Op1 && LHSOpcode == Instruction::Mul && isa<Constant>(D) &&
+      LHS->hasOneUse() && RHS->hasOneUse() && TopLevelOpcode == RHSOpcode &&
+      Instruction::isAssociative(RHSOpcode) &&
+      Instruction::isCommutative(RHSOpcode))
+    if (Value *Ident = getIdentityValue(LHSOpcode, C))
+      if (Value *V =
+              tryFactorization(I, SQ, Builder, LHSOpcode, A, B, C, Ident))
+        return AddBackConstant(V, RHSOpcode, D);
+
+  // Mirrored form "(A op B) op (C * D)" with constants B and D, e.g.
+  // "(X + 3) + (X * 4)" --> "(X * 5) + 3".
+  if (Op0 && Op1 && RHSOpcode == Instruction::Mul && isa<Constant>(B) &&
+      isa<Constant>(D) && LHS->hasOneUse() && RHS->hasOneUse() &&
+      TopLevelOpcode == LHSOpcode && Instruction::isAssociative(LHSOpcode) &&
+      Instruction::isCommutative(LHSOpcode))
+    if (Value *Ident = getIdentityValue(RHSOpcode, A))
+      if (Value *V =
+              tryFactorization(I, SQ, Builder, RHSOpcode, C, D, A, Ident))
+        return AddBackConstant(V, LHSOpcode, B);
+
   return nullptr;
 }

diff --git a/llvm/test/Transforms/InstCombine/add4.ll b/llvm/test/Transforms/InstCombine/add4.ll
index 7fc164c8b9a7c99..3af0df57cc2e5d9 100644
--- a/llvm/test/Transforms/InstCombine/add4.ll
+++ b/llvm/test/Transforms/InstCombine/add4.ll
@@ -1,6 +1,8 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
 ; RUN: opt < %s -passes=instcombine -S | FileCheck %s
 
+declare void @use32(i32)
+
 define i64 @match_unsigned(i64 %x) {
 ; CHECK-LABEL: @match_unsigned(
 ; CHECK-NEXT:    [[UREM:%.*]] = urem i64 [[X:%.*]], 19136
@@ -127,3 +129,184 @@ define i32 @not_match_overflow(i32 %x) {
   %t4 = add i32 %t, %t3
   ret i32 %t4
 }
+
+; (x + (-1)) + (x * 5) --> (x * 6) + (-1), with the mul as the right operand
+define i8 @mul_add_common_factor_0(i8 %x) {
+; CHECK-LABEL: @mul_add_common_factor_0(
+; CHECK-NEXT:    [[A1:%.*]] = mul i8 [[X:%.*]], 6
+; CHECK-NEXT:    [[TMP1:%.*]] = add i8 [[A1]], -1
+; CHECK-NEXT:    ret i8 [[TMP1]]
+;
+  %a0 = add i8 %x, -1
+  %m = mul i8 %x, 5
+  %a1 = add i8 %a0, %m ; the mul is the right operand of the outer add
+  ret i8 %a1
+}
+
+; (x * 4) + (x + (-1)) --> (x * 5) + (-1), with the mul as the left operand
+define i16 @mul_add_common_factor_1(i16 %x) {
+; CHECK-LABEL: @mul_add_common_factor_1(
+; CHECK-NEXT:    [[A1:%.*]] = mul i16 [[X:%.*]], 5
+; CHECK-NEXT:    [[TMP1:%.*]] = add i16 [[A1]], -1
+; CHECK-NEXT:    ret i16 [[TMP1]]
+;
+  %a0 = add i16 %x, -1
+  %m = mul i16 %x, 4
+  %a1 = add i16 %m, %a0 ; the mul operand is the left operand
+  ret i16 %a1
+}
+
+; Negative test: no fold to (x * 5) + y for (x + y) + (x * 4); y is not a constant
+define i32 @mul_add_common_factor_2(i32 %x, i32 %y) {
+; CHECK-LABEL: @mul_add_common_factor_2(
+; CHECK-NEXT:    [[A0:%.*]] = add i32 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[M:%.*]] = shl i32 [[X]], 2
+; CHECK-NEXT:    [[A1:%.*]] = add i32 [[A0]], [[M]]
+; CHECK-NEXT:    ret i32 [[A1]]
+;
+  %a0 = add i32 %x, %y
+  %m = mul i32 %x, 4 ; canonicalized to shl by 2, as the CHECK lines show
+  %a1 = add i32 %a0, %m
+  ret i32 %a1
+}
+
+; Negative test: no fold to (x * 5) + y for (y + x) + (x * 4); y is not a constant
+define i32 @mul_add_common_factor_2_commute(i32 %x, i32 %y) {
+; CHECK-LABEL: @mul_add_common_factor_2_commute(
+; CHECK-NEXT:    [[A0:%.*]] = add i32 [[Y:%.*]], [[X:%.*]]
+; CHECK-NEXT:    [[M:%.*]] = shl i32 [[X]], 2
+; CHECK-NEXT:    [[A1:%.*]] = add i32 [[A0]], [[M]]
+; CHECK-NEXT:    ret i32 [[A1]]
+;
+  %a0 = add i32 %y, %x ; operands commuted relative to @mul_add_common_factor_2
+  %m = mul i32 %x, 4
+  %a1 = add i32 %a0, %m
+  ret i32 %a1
+}
+
+; Negative test: no fold to (x * (t + 1)) + 2 for (x + 2) + (x * t); t is not a constant
+define i128 @mul_add_common_factor_3(i128 %x, i128 %t) {
+; CHECK-LABEL: @mul_add_common_factor_3(
+; CHECK-NEXT:    [[A0:%.*]] = add i128 [[X:%.*]], 2
+; CHECK-NEXT:    [[M:%.*]] = mul i128 [[X]], [[T:%.*]]
+; CHECK-NEXT:    [[A1:%.*]] = add i128 [[A0]], [[M]]
+; CHECK-NEXT:    ret i128 [[A1]]
+;
+  %a0 = add i128 %x, 2
+  %m = mul i128 %x, %t ; non-constant multiplier
+  %a1 = add i128 %a0, %m
+  ret i128 %a1
+}
+
+; Negative test: no fold to (x * (t + 1)) + 2 for (x + 2) + (t * x); t is not a constant
+define i128 @mul_add_common_factor_3_commute(i128 %x, i128 %t) {
+; CHECK-LABEL: @mul_add_common_factor_3_commute(
+; CHECK-NEXT:    [[A0:%.*]] = add i128 [[X:%.*]], 2
+; CHECK-NEXT:    [[M:%.*]] = mul i128 [[T:%.*]], [[X]]
+; CHECK-NEXT:    [[A1:%.*]] = add i128 [[A0]], [[M]]
+; CHECK-NEXT:    ret i128 [[A1]]
+;
+  %a0 = add i128 %x, 2
+  %m = mul i128 %t, %x ; mul operands commuted relative to @mul_add_common_factor_3
+  %a1 = add i128 %a0, %m
+  ret i128 %a1
+}
+
+; Negative test: %a0 has an extra use, so (x * 4) + (x - 1) is left unfolded
+define i32 @mul_add_common_factor_4(i32 %x) {
+; CHECK-LABEL: @mul_add_common_factor_4(
+; CHECK-NEXT:    [[A0:%.*]] = add i32 [[X:%.*]], -1
+; CHECK-NEXT:    [[M:%.*]] = shl i32 [[X]], 2
+; CHECK-NEXT:    call void @use32(i32 [[A0]])
+; CHECK-NEXT:    [[A1:%.*]] = add i32 [[M]], [[A0]]
+; CHECK-NEXT:    ret i32 [[A1]]
+;
+  %a0 = add i32 %x, -1
+  %m = mul i32 %x, 4
+  call void @use32(i32 %a0) ; an extra use
+  %a1 = add i32 %m, %a0
+  ret i32 %a1
+}
+
+; (x * 4) + (x + 3) --> (x * 5) + 3, on a splat-constant vector
+define <2 x i8> @mul_add_common_factor_5(<2 x i8> %x) {
+; CHECK-LABEL: @mul_add_common_factor_5(
+; CHECK-NEXT:    [[A1:%.*]] = mul <2 x i8> [[X:%.*]], <i8 5, i8 5>
+; CHECK-NEXT:    [[TMP1:%.*]] = add <2 x i8> [[A1]], <i8 3, i8 3>
+; CHECK-NEXT:    ret <2 x i8> [[TMP1]]
+;
+  %a0 = add <2 x i8> %x, <i8 3, i8 3>
+  %m = mul <2 x i8> %x, <i8 4, i8 4> ; vector type
+  %a1 = add <2 x i8> %m, %a0
+  ret <2 x i8> %a1
+}
+
+; (x << 2) + (x + (-1)) --> (x * 5) + (-1); the shl is factored as a mul by 4
+define i16 @shl_add_common_factor_1(i16 %x) {
+; CHECK-LABEL: @shl_add_common_factor_1(
+; CHECK-NEXT:    [[A1:%.*]] = mul i16 [[X:%.*]], 5
+; CHECK-NEXT:    [[TMP1:%.*]] = add i16 [[A1]], -1
+; CHECK-NEXT:    ret i16 [[TMP1]]
+;
+  %a0 = add i16 %x, -1
+  %s = shl i16 %x, 2
+  %a1 = add i16 %s, %a0 ; the shl operand is the left operand
+  ret i16 %a1
+}
+
+; Negative test: no fold for (x + y) + (x << 2); y is not a constant
+define i32 @shl_add_common_factor_2(i32 %x, i32 %y) {
+; CHECK-LABEL: @shl_add_common_factor_2(
+; CHECK-NEXT:    [[A0:%.*]] = add i32 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[S:%.*]] = shl i32 [[X]], 2
+; CHECK-NEXT:    [[A1:%.*]] = add i32 [[A0]], [[S]]
+; CHECK-NEXT:    ret i32 [[A1]]
+;
+  %a0 = add i32 %x, %y
+  %s = shl i32 %x, 2
+  %a1 = add i32 %a0, %s ; the shl operand is the right operand
+  ret i32 %a1
+}
+
+; Negative test: no fold for (y + x) + (x << 2); y is not a constant
+define i32 @shl_add_common_factor_2_commute(i32 %x, i32 %y) {
+; CHECK-LABEL: @shl_add_common_factor_2_commute(
+; CHECK-NEXT:    [[A0:%.*]] = add i32 [[Y:%.*]], [[X:%.*]]
+; CHECK-NEXT:    [[S:%.*]] = shl i32 [[X]], 2
+; CHECK-NEXT:    [[A1:%.*]] = add i32 [[A0]], [[S]]
+; CHECK-NEXT:    ret i32 [[A1]]
+;
+  %a0 = add i32 %y, %x ; operands swapped relative to @shl_add_common_factor_2
+  %s = shl i32 %x, 2
+  %a1 = add i32 %a0, %s
+  ret i32 %a1
+}
+
+; Negative test: %s has an extra use, so (x << 2) + (x - 1) is left unfolded
+define i32 @shl_add_common_factor_4(i32 %x) {
+; CHECK-LABEL: @shl_add_common_factor_4(
+; CHECK-NEXT:    [[A0:%.*]] = add i32 [[X:%.*]], -1
+; CHECK-NEXT:    [[S:%.*]] = shl i32 [[X]], 2
+; CHECK-NEXT:    call void @use32(i32 [[S]])
+; CHECK-NEXT:    [[A1:%.*]] = add i32 [[S]], [[A0]]
+; CHECK-NEXT:    ret i32 [[A1]]
+;
+  %a0 = add i32 %x, -1
+  %s = shl i32 %x, 2
+  call void @use32(i32 %s) ; an extra use
+  %a1 = add i32 %s, %a0
+  ret i32 %a1
+}
+
+; (x << 2) + (x + 3) --> (x * 5) + 3, on a splat-constant vector
+define <2 x i64> @shl_add_common_factor_5(<2 x i64> %x) {
+; CHECK-LABEL: @shl_add_common_factor_5(
+; CHECK-NEXT:    [[A1:%.*]] = mul <2 x i64> [[X:%.*]], <i64 5, i64 5>
+; CHECK-NEXT:    [[TMP1:%.*]] = add <2 x i64> [[A1]], <i64 3, i64 3>
+; CHECK-NEXT:    ret <2 x i64> [[TMP1]]
+;
+  %a0 = add <2 x i64> %x, <i64 3, i64 3>
+  %s = shl <2 x i64> %x, <i64 2, i64 2> ; vector type
+  %a1 = add <2 x i64> %s, %a0
+  ret <2 x i64> %a1
+}
diff --git a/llvm/test/Transforms/PGOProfile/chr.ll b/llvm/test/Transforms/PGOProfile/chr.ll
index c82800ec11a12e9..0579beb18887b55 100644
--- a/llvm/test/Transforms/PGOProfile/chr.ll
+++ b/llvm/test/Transforms/PGOProfile/chr.ll
@@ -1060,10 +1060,9 @@ define i32 @test_chr_10(ptr %i, ptr %j) !prof !14 {
 ; CHECK-NEXT:    br label [[BB3]]
 ; CHECK:       bb3:
 ; CHECK-NEXT:    [[TMP8:%.*]] = phi i32 [ [[TMP3]], [[BB0]] ], [ [[TMP5]], [[BB2_NONCHR]] ], [ [[TMP5]], [[BB1_NONCHR]] ]
-; CHECK-NEXT:    [[TMP9:%.*]] = mul i32 [[TMP8]], 42
-; CHECK-NEXT:    [[TMP10:%.*]] = add i32 [[TMP8]], -99
-; CHECK-NEXT:    [[TMP11:%.*]] = add i32 [[TMP9]], [[TMP10]]
-; CHECK-NEXT:    ret i32 [[TMP11]]
+; CHECK-NEXT:    [[TMP9:%.*]] = mul i32 [[TMP8]], 43
+; CHECK-NEXT:    [[TMP10:%.*]] = add i32 [[TMP9]], -99
+; CHECK-NEXT:    ret i32 [[TMP10]]
 ;
 entry:
   %0 = load i32, ptr %i



More information about the llvm-commits mailing list