[llvm] 68cc35d - [InstCombine] Matrix multiplication negation optimisation

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 20 11:51:01 PDT 2022


Author: Zain Jaffal
Date: 2022-09-20T19:50:39+01:00
New Revision: 68cc35d52cff2d8345c6dffbed0d1b36b20f824f

URL: https://github.com/llvm/llvm-project/commit/68cc35d52cff2d8345c6dffbed0d1b36b20f824f
DIFF: https://github.com/llvm/llvm-project/commit/68cc35d52cff2d8345c6dffbed0d1b36b20f824f.diff

LOG: [InstCombine] Matrix multiplication negation optimisation

If one of the operands of a matrix multiplication is negated, we can optimise the computation by moving the negation to whichever of the two operands or the result has the fewest elements, so that fewer values need to be negated.

Reviewed By: spatel, fhahn

Differential Revision: https://reviews.llvm.org/D133300

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
    llvm/test/Transforms/InstCombine/matrix-multiplication-negation.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index a6dd83cb199fe..f28ce858a169a 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1830,6 +1830,8 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
     break;
   }
   case Intrinsic::matrix_multiply: {
+    // Optimize negation in matrix multiplication.
+
     // -A * -B -> A * B
     Value *A, *B;
     if (match(II->getArgOperand(0), m_FNeg(m_Value(A))) &&
@@ -1838,6 +1840,50 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
       replaceOperand(*II, 1, B);
       return II;
     }
+
+    Value *Op0 = II->getOperand(0);
+    Value *Op1 = II->getOperand(1);
+    Value *OpNotNeg, *NegatedOp;
+    unsigned NegatedOpArg, OtherOpArg;
+    if (match(Op0, m_FNeg(m_Value(OpNotNeg)))) {
+      NegatedOp = Op0;
+      NegatedOpArg = 0;
+      OtherOpArg = 1;
+    } else if (match(Op1, m_FNeg(m_Value(OpNotNeg)))) {
+      NegatedOp = Op1;
+      NegatedOpArg = 1;
+      OtherOpArg = 0;
+    } else
+      // Multiplication doesn't have a negated operand.
+      break;
+
+    // Only optimize if the negated operand has only one use.
+    if (!NegatedOp->hasOneUse())
+      break;
+
+    Value *OtherOp = II->getOperand(OtherOpArg);
+    VectorType *RetTy = cast<VectorType>(II->getType());
+    VectorType *NegatedOpTy = cast<VectorType>(NegatedOp->getType());
+    VectorType *OtherOpTy = cast<VectorType>(OtherOp->getType());
+    ElementCount NegatedCount = NegatedOpTy->getElementCount();
+    ElementCount OtherCount = OtherOpTy->getElementCount();
+    ElementCount RetCount = RetTy->getElementCount();
+    // (-A) * B -> A * (-B), if it is cheaper to negate B and vice versa.
+    if (ElementCount::isKnownGT(NegatedCount, OtherCount) &&
+        ElementCount::isKnownLT(OtherCount, RetCount)) {
+      Value *InverseOtherOp = Builder.CreateFNeg(OtherOp);
+      replaceOperand(*II, NegatedOpArg, OpNotNeg);
+      replaceOperand(*II, OtherOpArg, InverseOtherOp);
+      return II;
+    }
+    // (-A) * B -> -(A * B), if it is cheaper to negate the result
+    if (ElementCount::isKnownGT(NegatedCount, RetCount)) {
+      SmallVector<Value *, 5> NewArgs(II->args());
+      NewArgs[NegatedOpArg] = OpNotNeg;
+      Instruction *NewMul =
+          Builder.CreateIntrinsic(II->getType(), IID, NewArgs, II);
+      return replaceInstUsesWith(*II, Builder.CreateFNegFMF(NewMul, II));
+    }
     break;
   }
   case Intrinsic::fmuladd: {

diff  --git a/llvm/test/Transforms/InstCombine/matrix-multiplication-negation.ll b/llvm/test/Transforms/InstCombine/matrix-multiplication-negation.ll
index b0eecd9d5255e..bd1050efc160f 100644
--- a/llvm/test/Transforms/InstCombine/matrix-multiplication-negation.ll
+++ b/llvm/test/Transforms/InstCombine/matrix-multiplication-negation.ll
@@ -4,9 +4,9 @@
 ; The result has the fewest vector elements between the result and the two operands so the negation can be moved there
 define <2 x double> @test_negation_move_to_result(<6 x double> %a, <3 x double> %b) {
 ; CHECK-LABEL: @test_negation_move_to_result(
-; CHECK-NEXT:    [[A_NEG:%.*]] = fneg <6 x double> [[A:%.*]]
-; CHECK-NEXT:    [[RES:%.*]] = tail call <2 x double> @llvm.matrix.multiply.v2f64.v6f64.v3f64(<6 x double> [[A_NEG]], <3 x double> [[B:%.*]], i32 2, i32 3, i32 1)
-; CHECK-NEXT:    ret <2 x double> [[RES]]
+; CHECK-NEXT:    [[TMP1:%.*]] = call <2 x double> @llvm.matrix.multiply.v2f64.v6f64.v3f64(<6 x double> [[A:%.*]], <3 x double> [[B:%.*]], i32 2, i32 3, i32 1)
+; CHECK-NEXT:    [[TMP2:%.*]] = fneg <2 x double> [[TMP1]]
+; CHECK-NEXT:    ret <2 x double> [[TMP2]]
 ;
   %a.neg = fneg <6 x double> %a
   %res = tail call <2 x double> @llvm.matrix.multiply.v2f64.v6f64.v3f64(<6 x double> %a.neg, <3 x double> %b, i32 2, i32 3, i32 1)
@@ -17,20 +17,53 @@ define <2 x double> @test_negation_move_to_result(<6 x double> %a, <3 x double>
 ; Fast flag should be preserved
 define <2 x double> @test_negation_move_to_result_with_fastflags(<6 x double> %a, <3 x double> %b) {
 ; CHECK-LABEL: @test_negation_move_to_result_with_fastflags(
-; CHECK-NEXT:    [[A_NEG:%.*]] = fneg <6 x double> [[A:%.*]]
-; CHECK-NEXT:    [[RES:%.*]] = tail call fast <2 x double> @llvm.matrix.multiply.v2f64.v6f64.v3f64(<6 x double> [[A_NEG]], <3 x double> [[B:%.*]], i32 2, i32 3, i32 1)
-; CHECK-NEXT:    ret <2 x double> [[RES]]
+; CHECK-NEXT:    [[TMP1:%.*]] = call fast <2 x double> @llvm.matrix.multiply.v2f64.v6f64.v3f64(<6 x double> [[A:%.*]], <3 x double> [[B:%.*]], i32 2, i32 3, i32 1)
+; CHECK-NEXT:    [[TMP2:%.*]] = fneg fast <2 x double> [[TMP1]]
+; CHECK-NEXT:    ret <2 x double> [[TMP2]]
 ;
   %a.neg = fneg <6 x double> %a
   %res = tail call fast <2 x double> @llvm.matrix.multiply.v2f64.v6f64.v3f64(<6 x double> %a.neg, <3 x double> %b, i32 2, i32 3, i32 1)
   ret <2 x double> %res
 }
 
+define <2 x double> @test_negation_move_to_result_with_nnan_flag(<6 x double> %a, <3 x double> %b) {
+; CHECK-LABEL: @test_negation_move_to_result_with_nnan_flag(
+; CHECK-NEXT:    [[TMP1:%.*]] = call nnan <2 x double> @llvm.matrix.multiply.v2f64.v6f64.v3f64(<6 x double> [[A:%.*]], <3 x double> [[B:%.*]], i32 2, i32 3, i32 1)
+; CHECK-NEXT:    [[TMP2:%.*]] = fneg nnan <2 x double> [[TMP1]]
+; CHECK-NEXT:    ret <2 x double> [[TMP2]]
+;
+  %a.neg = fneg <6 x double> %a
+  %res = tail call nnan <2 x double> @llvm.matrix.multiply.v2f64.v6f64.v3f64(<6 x double> %a.neg, <3 x double> %b, i32 2, i32 3, i32 1)
+  ret <2 x double> %res
+}
+
+define <2 x double> @test_negation_move_to_result_with_nsz_flag(<6 x double> %a, <3 x double> %b) {
+; CHECK-LABEL: @test_negation_move_to_result_with_nsz_flag(
+; CHECK-NEXT:    [[TMP1:%.*]] = call nsz <2 x double> @llvm.matrix.multiply.v2f64.v6f64.v3f64(<6 x double> [[A:%.*]], <3 x double> [[B:%.*]], i32 2, i32 3, i32 1)
+; CHECK-NEXT:    [[TMP2:%.*]] = fneg nsz <2 x double> [[TMP1]]
+; CHECK-NEXT:    ret <2 x double> [[TMP2]]
+;
+  %a.neg = fneg <6 x double> %a
+  %res = tail call nsz <2 x double> @llvm.matrix.multiply.v2f64.v6f64.v3f64(<6 x double> %a.neg, <3 x double> %b, i32 2, i32 3, i32 1)
+  ret <2 x double> %res
+}
+
+define <2 x double> @test_negation_move_to_result_with_fastflag_on_negation(<6 x double> %a, <3 x double> %b) {
+; CHECK-LABEL: @test_negation_move_to_result_with_fastflag_on_negation(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <2 x double> @llvm.matrix.multiply.v2f64.v6f64.v3f64(<6 x double> [[A:%.*]], <3 x double> [[B:%.*]], i32 2, i32 3, i32 1)
+; CHECK-NEXT:    [[TMP2:%.*]] = fneg <2 x double> [[TMP1]]
+; CHECK-NEXT:    ret <2 x double> [[TMP2]]
+;
+  %a.neg = fneg fast<6 x double> %a
+  %res = tail call <2 x double> @llvm.matrix.multiply.v2f64.v6f64.v3f64(<6 x double> %a.neg, <3 x double> %b, i32 2, i32 3, i32 1)
+  ret <2 x double> %res
+}
+
 ; %b has the fewest vector elements between the result and the two operands so the negation can be moved there
 define <9 x double> @test_move_negation_to_second_operand(<27 x double> %a, <3 x double> %b) {
 ; CHECK-LABEL: @test_move_negation_to_second_operand(
-; CHECK-NEXT:    [[A_NEG:%.*]] = fneg <27 x double> [[A:%.*]]
-; CHECK-NEXT:    [[RES:%.*]] = tail call <9 x double> @llvm.matrix.multiply.v9f64.v27f64.v3f64(<27 x double> [[A_NEG]], <3 x double> [[B:%.*]], i32 9, i32 3, i32 1)
+; CHECK-NEXT:    [[TMP1:%.*]] = fneg <3 x double> [[B:%.*]]
+; CHECK-NEXT:    [[RES:%.*]] = tail call <9 x double> @llvm.matrix.multiply.v9f64.v27f64.v3f64(<27 x double> [[A:%.*]], <3 x double> [[TMP1]], i32 9, i32 3, i32 1)
 ; CHECK-NEXT:    ret <9 x double> [[RES]]
 ;
   %a.neg = fneg <27 x double> %a
@@ -42,8 +75,8 @@ define <9 x double> @test_move_negation_to_second_operand(<27 x double> %a, <3 x
 ; Fast flag should be preserved
 define <9 x double> @test_move_negation_to_second_operand_with_fast_flags(<27 x double> %a, <3 x double> %b) {
 ; CHECK-LABEL: @test_move_negation_to_second_operand_with_fast_flags(
-; CHECK-NEXT:    [[A_NEG:%.*]] = fneg <27 x double> [[A:%.*]]
-; CHECK-NEXT:    [[RES:%.*]] = tail call fast <9 x double> @llvm.matrix.multiply.v9f64.v27f64.v3f64(<27 x double> [[A_NEG]], <3 x double> [[B:%.*]], i32 9, i32 3, i32 1)
+; CHECK-NEXT:    [[TMP1:%.*]] = fneg <3 x double> [[B:%.*]]
+; CHECK-NEXT:    [[RES:%.*]] = tail call fast <9 x double> @llvm.matrix.multiply.v9f64.v27f64.v3f64(<27 x double> [[A:%.*]], <3 x double> [[TMP1]], i32 9, i32 3, i32 1)
 ; CHECK-NEXT:    ret <9 x double> [[RES]]
 ;
   %a.neg = fneg <27 x double> %a
@@ -54,9 +87,9 @@ define <9 x double> @test_move_negation_to_second_operand_with_fast_flags(<27 x
 ; The result has the fewest vector elements between the result and the two operands so the negation can be moved there
 define <2 x double> @test_negation_move_to_result_from_second_operand(<3 x double> %a, <6 x double> %b){
 ; CHECK-LABEL: @test_negation_move_to_result_from_second_operand(
-; CHECK-NEXT:    [[B_NEG:%.*]] = fneg <6 x double> [[B:%.*]]
-; CHECK-NEXT:    [[RES:%.*]] = tail call <2 x double> @llvm.matrix.multiply.v2f64.v3f64.v6f64(<3 x double> [[A:%.*]], <6 x double> [[B_NEG]], i32 1, i32 3, i32 2)
-; CHECK-NEXT:    ret <2 x double> [[RES]]
+; CHECK-NEXT:    [[TMP1:%.*]] = call <2 x double> @llvm.matrix.multiply.v2f64.v3f64.v6f64(<3 x double> [[A:%.*]], <6 x double> [[B:%.*]], i32 1, i32 3, i32 2)
+; CHECK-NEXT:    [[TMP2:%.*]] = fneg <2 x double> [[TMP1]]
+; CHECK-NEXT:    ret <2 x double> [[TMP2]]
 ;
   %b.neg = fneg <6 x double> %b
   %res = tail call <2 x double> @llvm.matrix.multiply.v2f64.v3f64.v6f64(<3 x double> %a, <6 x double> %b.neg, i32 1, i32 3, i32 2)
@@ -66,8 +99,8 @@ define <2 x double> @test_negation_move_to_result_from_second_operand(<3 x doubl
 ; %a has the fewest vector elements between the result and the two operands so the negation can be moved there
 define <9 x double> @test_move_negation_to_first_operand(<3 x double> %a, <27 x double> %b) {
 ; CHECK-LABEL: @test_move_negation_to_first_operand(
-; CHECK-NEXT:    [[B_NEG:%.*]] = fneg <27 x double> [[B:%.*]]
-; CHECK-NEXT:    [[RES:%.*]] = tail call <9 x double> @llvm.matrix.multiply.v9f64.v3f64.v27f64(<3 x double> [[A:%.*]], <27 x double> [[B_NEG]], i32 1, i32 3, i32 9)
+; CHECK-NEXT:    [[TMP1:%.*]] = fneg <3 x double> [[A:%.*]]
+; CHECK-NEXT:    [[RES:%.*]] = tail call <9 x double> @llvm.matrix.multiply.v9f64.v3f64.v27f64(<3 x double> [[TMP1]], <27 x double> [[B:%.*]], i32 1, i32 3, i32 9)
 ; CHECK-NEXT:    ret <9 x double> [[RES]]
 ;
   %b.neg = fneg <27 x double> %b
@@ -234,8 +267,8 @@ define <12 x double> @fneg_with_multiple_uses_2(<15 x double> %a, <20 x double>
 ; negation should be moved to the second operand given it has the smallest operand count
 define <72 x double> @chain_of_matrix_mutliplies(<27 x double> %a, <3 x double> %b, <8 x double> %c) {
 ; CHECK-LABEL: @chain_of_matrix_mutliplies(
-; CHECK-NEXT:    [[A_NEG:%.*]] = fneg <27 x double> [[A:%.*]]
-; CHECK-NEXT:    [[RES:%.*]] = tail call <9 x double> @llvm.matrix.multiply.v9f64.v27f64.v3f64(<27 x double> [[A_NEG]], <3 x double> [[B:%.*]], i32 9, i32 3, i32 1)
+; CHECK-NEXT:    [[TMP1:%.*]] = fneg <3 x double> [[B:%.*]]
+; CHECK-NEXT:    [[RES:%.*]] = tail call <9 x double> @llvm.matrix.multiply.v9f64.v27f64.v3f64(<27 x double> [[A:%.*]], <3 x double> [[TMP1]], i32 9, i32 3, i32 1)
 ; CHECK-NEXT:    [[RES_2:%.*]] = tail call <72 x double> @llvm.matrix.multiply.v72f64.v9f64.v8f64(<9 x double> [[RES]], <8 x double> [[C:%.*]], i32 9, i32 1, i32 8)
 ; CHECK-NEXT:    ret <72 x double> [[RES_2]]
 ;
@@ -249,11 +282,11 @@ define <72 x double> @chain_of_matrix_mutliplies(<27 x double> %a, <3 x double>
 ; second negation should be moved to the result of the second multipication
 define <6 x double> @chain_of_matrix_mutliplies_with_two_negations(<3 x double> %a, <5 x double> %b, <10 x double> %c) {
 ; CHECK-LABEL: @chain_of_matrix_mutliplies_with_two_negations(
-; CHECK-NEXT:    [[B_NEG:%.*]] = fneg <5 x double> [[B:%.*]]
-; CHECK-NEXT:    [[RES:%.*]] = tail call <15 x double> @llvm.matrix.multiply.v15f64.v3f64.v5f64(<3 x double> [[A:%.*]], <5 x double> [[B_NEG]], i32 3, i32 1, i32 5)
-; CHECK-NEXT:    [[RES_NEG:%.*]] = fneg <15 x double> [[RES]]
-; CHECK-NEXT:    [[RES_2:%.*]] = tail call <6 x double> @llvm.matrix.multiply.v6f64.v15f64.v10f64(<15 x double> [[RES_NEG]], <10 x double> [[C:%.*]], i32 3, i32 5, i32 2)
-; CHECK-NEXT:    ret <6 x double> [[RES_2]]
+; CHECK-NEXT:    [[TMP1:%.*]] = fneg <3 x double> [[A:%.*]]
+; CHECK-NEXT:    [[RES:%.*]] = tail call <15 x double> @llvm.matrix.multiply.v15f64.v3f64.v5f64(<3 x double> [[TMP1]], <5 x double> [[B:%.*]], i32 3, i32 1, i32 5)
+; CHECK-NEXT:    [[TMP2:%.*]] = call <6 x double> @llvm.matrix.multiply.v6f64.v15f64.v10f64(<15 x double> [[RES]], <10 x double> [[C:%.*]], i32 3, i32 5, i32 2)
+; CHECK-NEXT:    [[TMP3:%.*]] = fneg <6 x double> [[TMP2]]
+; CHECK-NEXT:    ret <6 x double> [[TMP3]]
 ;
   %b.neg = fneg <5 x double> %b
   %res = tail call <15 x double> @llvm.matrix.multiply.v15f64.v3f64.v5f64(<3 x double> %a, <5 x double> %b.neg, i32 3, i32 1, i32 5)
@@ -265,10 +298,10 @@ define <6 x double> @chain_of_matrix_mutliplies_with_two_negations(<3 x double>
 ; negation should be propagated to the result of the second matrix multiplication
 define <6 x double> @chain_of_matrix_mutliplies_propagation(<15 x double> %a, <20 x double> %b, <8 x double> %c){
 ; CHECK-LABEL: @chain_of_matrix_mutliplies_propagation(
-; CHECK-NEXT:    [[A_NEG:%.*]] = fneg <15 x double> [[A:%.*]]
-; CHECK-NEXT:    [[RES:%.*]] = tail call <12 x double> @llvm.matrix.multiply.v12f64.v15f64.v20f64(<15 x double> [[A_NEG]], <20 x double> [[B:%.*]], i32 3, i32 5, i32 4)
-; CHECK-NEXT:    [[RES_2:%.*]] = tail call <6 x double> @llvm.matrix.multiply.v6f64.v12f64.v8f64(<12 x double> [[RES]], <8 x double> [[C:%.*]], i32 3, i32 4, i32 2)
-; CHECK-NEXT:    ret <6 x double> [[RES_2]]
+; CHECK-NEXT:    [[TMP1:%.*]] = call <12 x double> @llvm.matrix.multiply.v12f64.v15f64.v20f64(<15 x double> [[A:%.*]], <20 x double> [[B:%.*]], i32 3, i32 5, i32 4)
+; CHECK-NEXT:    [[TMP2:%.*]] = call <6 x double> @llvm.matrix.multiply.v6f64.v12f64.v8f64(<12 x double> [[TMP1]], <8 x double> [[C:%.*]], i32 3, i32 4, i32 2)
+; CHECK-NEXT:    [[TMP3:%.*]] = fneg <6 x double> [[TMP2]]
+; CHECK-NEXT:    ret <6 x double> [[TMP3]]
 ;
   %a.neg = fneg <15 x double> %a
   %res = tail call <12 x double> @llvm.matrix.multiply.v12f64.v15f64.v20f64(<15 x double> %a.neg, <20 x double> %b, i32 3, i32 5, i32 4)


        


More information about the llvm-commits mailing list