[llvm] 69d4890 - [InstCombine] Fold abs(a * abs(b)) --> abs(a * b) (#78110)

via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 19 02:54:55 PST 2024


Author: elhewaty
Date: 2024-02-19T11:54:51+01:00
New Revision: 69d4890f801d6821e969a992d88d0bf48f9249cf

URL: https://github.com/llvm/llvm-project/commit/69d4890f801d6821e969a992d88d0bf48f9249cf
DIFF: https://github.com/llvm/llvm-project/commit/69d4890f801d6821e969a992d88d0bf48f9249cf.diff

LOG: [InstCombine] Fold abs(a * abs(b)) --> abs(a * b) (#78110)

Proof: https://alive2.llvm.org/ce/z/hfbEra
Fixes: https://github.com/llvm/llvm-project/issues/73211

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
    llvm/test/Transforms/InstCombine/abs-intrinsic.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 5266808c5abab4..0be1495083cebb 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1569,6 +1569,17 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
     if (match(IIOperand, m_Select(m_Value(), m_Neg(m_Value(X)), m_Deferred(X))))
       return replaceOperand(*II, 0, X);
 
+    Value *Y;
+    // abs(a * abs(b)) -> abs(a * b)
+    if (match(IIOperand,
+              m_OneUse(m_c_Mul(m_Value(X),
+                               m_Intrinsic<Intrinsic::abs>(m_Value(Y)))))) {
+      bool NSW =
+          cast<Instruction>(IIOperand)->hasNoSignedWrap() && IntMinIsPoison;
+      auto *XY = NSW ? Builder.CreateNSWMul(X, Y) : Builder.CreateMul(X, Y);
+      return replaceOperand(*II, 0, XY);
+    }
+
     if (std::optional<bool> Known =
             getKnownSignOrZero(IIOperand, II, DL, &AC, &DT)) {
       // abs(x) -> x if x >= 0 (include abs(x-y) --> x - y where x >= y)

diff  --git a/llvm/test/Transforms/InstCombine/abs-intrinsic.ll b/llvm/test/Transforms/InstCombine/abs-intrinsic.ll
index 7fe34d92376485..b6740374c07913 100644
--- a/llvm/test/Transforms/InstCombine/abs-intrinsic.ll
+++ b/llvm/test/Transforms/InstCombine/abs-intrinsic.ll
@@ -5,7 +5,136 @@ declare i8 @llvm.abs.i8(i8, i1)
 declare i32 @llvm.abs.i32(i32, i1)
 declare <4 x i32> @llvm.abs.v4i32(<4 x i32>, i1)
 declare <3 x i82> @llvm.abs.v3i82(<3 x i82>, i1)
+declare <2 x i8> @llvm.abs.v2i8(<2 x i8>, i1)
 declare void @llvm.assume(i1)
+declare void @use(i32)
+
+define i8 @test_abs_abs_a_mul_b_i8(i8 %a, i8 %b) {
+; CHECK-LABEL: @test_abs_abs_a_mul_b_i8(
+; CHECK-NEXT:    [[TMP1:%.*]] = mul i8 [[B:%.*]], [[A:%.*]]
+; CHECK-NEXT:    [[ABS2:%.*]] = call i8 @llvm.abs.i8(i8 [[TMP1]], i1 true)
+; CHECK-NEXT:    ret i8 [[ABS2]]
+;
+  %abs1 = call i8 @llvm.abs.i8(i8 %a, i1 true)
+  %mul = mul i8 %abs1, %b
+  %abs2 = call i8 @llvm.abs.i8(i8 %mul, i1 true)
+  ret i8 %abs2
+}
+
+define i8 @test_abs_a_mul_abs_b_i8(i8 %a, i8 %b) {
+; CHECK-LABEL: @test_abs_a_mul_abs_b_i8(
+; CHECK-NEXT:    [[A1:%.*]] = urem i8 123, [[A:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = mul i8 [[A1]], [[B:%.*]]
+; CHECK-NEXT:    [[ABS2:%.*]] = call i8 @llvm.abs.i8(i8 [[TMP1]], i1 true)
+; CHECK-NEXT:    ret i8 [[ABS2]]
+;
+  %a1 = urem i8 123, %a  ; thwart complexity-based canonicalization
+  %abs1 = call i8 @llvm.abs.i8(i8 %b, i1 true)
+  %mul = mul i8 %a1, %abs1
+  %abs2 = call i8 @llvm.abs.i8(i8 %mul, i1 true)
+  ret i8 %abs2
+}
+
+define i32 @test_abs_abs_a_mul_b_i32(i32 %a, i32 %b) {
+; CHECK-LABEL: @test_abs_abs_a_mul_b_i32(
+; CHECK-NEXT:    [[TMP1:%.*]] = mul i32 [[B:%.*]], [[A:%.*]]
+; CHECK-NEXT:    [[ABS2:%.*]] = call i32 @llvm.abs.i32(i32 [[TMP1]], i1 true)
+; CHECK-NEXT:    ret i32 [[ABS2]]
+;
+  %abs1 = call i32 @llvm.abs.i32(i32 %a, i1 true)
+  %mul = mul i32 %abs1, %b
+  %abs2 = call i32 @llvm.abs.i32(i32 %mul, i1 true)
+  ret i32 %abs2
+}
+
+define i32 @test_abs_abs_a_mul_b_i32_abs_false_true(i32 %a, i32 %b) {
+; CHECK-LABEL: @test_abs_abs_a_mul_b_i32_abs_false_true(
+; CHECK-NEXT:    [[TMP1:%.*]] = mul i32 [[B:%.*]], [[A:%.*]]
+; CHECK-NEXT:    [[ABS2:%.*]] = call i32 @llvm.abs.i32(i32 [[TMP1]], i1 true)
+; CHECK-NEXT:    ret i32 [[ABS2]]
+;
+  %abs1 = call i32 @llvm.abs.i32(i32 %a, i1 false)
+  %mul = mul i32 %abs1, %b
+  %abs2 = call i32 @llvm.abs.i32(i32 %mul, i1 true)
+  ret i32 %abs2
+}
+
+define i32 @test_abs_abs_a_mul_b_i32_abs_true_false(i32 %a, i32 %b) {
+; CHECK-LABEL: @test_abs_abs_a_mul_b_i32_abs_true_false(
+; CHECK-NEXT:    [[TMP1:%.*]] = mul i32 [[B:%.*]], [[A:%.*]]
+; CHECK-NEXT:    [[ABS2:%.*]] = call i32 @llvm.abs.i32(i32 [[TMP1]], i1 false)
+; CHECK-NEXT:    ret i32 [[ABS2]]
+;
+  %abs1 = call i32 @llvm.abs.i32(i32 %a, i1 true)
+  %mul = mul i32 %abs1, %b
+  %abs2 = call i32 @llvm.abs.i32(i32 %mul, i1 false)
+  ret i32 %abs2
+}
+
+define i32 @test_abs_abs_a_mul_b_i32_abs_false_false(i32 %a, i32 %b) {
+; CHECK-LABEL: @test_abs_abs_a_mul_b_i32_abs_false_false(
+; CHECK-NEXT:    [[TMP1:%.*]] = mul i32 [[B:%.*]], [[A:%.*]]
+; CHECK-NEXT:    [[ABS2:%.*]] = call i32 @llvm.abs.i32(i32 [[TMP1]], i1 false)
+; CHECK-NEXT:    ret i32 [[ABS2]]
+;
+  %abs1 = call i32 @llvm.abs.i32(i32 %a, i1 false)
+  %mul = mul i32 %abs1, %b
+  %abs2 = call i32 @llvm.abs.i32(i32 %mul, i1 false)
+  ret i32 %abs2
+}
+
+; this should propagate the nsw (the outer abs is called with i1 true)
+define i8 @test_nsw_with_true(i8 %a, i8 %b) {
+; CHECK-LABEL: @test_nsw_with_true(
+; CHECK-NEXT:    [[TMP1:%.*]] = mul nsw i8 [[B:%.*]], [[A:%.*]]
+; CHECK-NEXT:    [[ABS2:%.*]] = call i8 @llvm.abs.i8(i8 [[TMP1]], i1 true)
+; CHECK-NEXT:    ret i8 [[ABS2]]
+;
+  %abs1 = call i8 @llvm.abs.i8(i8 %a, i1 false)
+  %mul = mul nsw i8 %abs1, %b
+  %abs2 = call i8 @llvm.abs.i8(i8 %mul, i1 true)
+  ret i8 %abs2
+}
+
+; this shouldn't propagate the nsw (the outer abs is called with i1 false)
+define i8 @test_nsw_with_false(i8 %a, i8 %b) {
+; CHECK-LABEL: @test_nsw_with_false(
+; CHECK-NEXT:    [[TMP1:%.*]] = mul i8 [[B:%.*]], [[A:%.*]]
+; CHECK-NEXT:    [[ABS2:%.*]] = call i8 @llvm.abs.i8(i8 [[TMP1]], i1 false)
+; CHECK-NEXT:    ret i8 [[ABS2]]
+;
+  %abs1 = call i8 @llvm.abs.i8(i8 %a, i1 false)
+  %mul = mul nsw i8 %abs1, %b
+  %abs2 = call i8 @llvm.abs.i8(i8 %mul, i1 false)
+  ret i8 %abs2
+}
+
+define i32 @test_abs_abs_a_mul_b_more_one_use(i32 %a, i32 %b) {
+; CHECK-LABEL: @test_abs_abs_a_mul_b_more_one_use(
+; CHECK-NEXT:    [[ABS1:%.*]] = call i32 @llvm.abs.i32(i32 [[A:%.*]], i1 true)
+; CHECK-NEXT:    [[MUL:%.*]] = mul i32 [[ABS1]], [[B:%.*]]
+; CHECK-NEXT:    [[ABS2:%.*]] = call i32 @llvm.abs.i32(i32 [[MUL]], i1 false)
+; CHECK-NEXT:    call void @use(i32 [[MUL]])
+; CHECK-NEXT:    ret i32 [[ABS2]]
+;
+  %abs1 = call i32 @llvm.abs.i32(i32 %a, i1 true)
+  %mul = mul i32 %abs1, %b
+  %abs2 = call i32 @llvm.abs.i32(i32 %mul, i1 false)
+  call void @use(i32 %mul)
+  ret i32 %abs2
+}
+
+define <2 x i8> @test_abs_abs_a_mul_b_vector_i8(<2 x i8> %a, <2 x i8> %b) {
+; CHECK-LABEL: @test_abs_abs_a_mul_b_vector_i8(
+; CHECK-NEXT:    [[TMP1:%.*]] = mul <2 x i8> [[B:%.*]], [[A:%.*]]
+; CHECK-NEXT:    [[ABS2:%.*]] = call <2 x i8> @llvm.abs.v2i8(<2 x i8> [[TMP1]], i1 true)
+; CHECK-NEXT:    ret <2 x i8> [[ABS2]]
+;
+  %abs = call <2 x i8> @llvm.abs.v2i8(<2 x i8> %a, i1 true)
+  %mul = mul <2 x i8> %abs, %b
+  %abs2 = call <2 x i8> @llvm.abs.v2i8(<2 x i8> %mul, i1 true)
+  ret <2 x i8> %abs2
+}
 
 ; abs preserves trailing zeros so the second and is unneeded
 define i32 @abs_trailing_zeros(i32 %x) {


        


More information about the llvm-commits mailing list