[PATCH] D158079: [InstCombine] Contracting x^2 + 2*x*y + y^2 to (x + y)^2 (float)

Noah Goldstein via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Sat Aug 19 11:25:23 PDT 2023


goldstein.w.n added inline comments.


================
Comment at: llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp:1059
+                                       m_FMul(m_Deferred(B), m_Deferred(B))))));
+  }
+
----------------
rainerzufalldererste wrote:
> goldstein.w.n wrote:
> > rainerzufalldererste wrote:
> > > rainerzufalldererste wrote:
> > > > goldstein.w.n wrote:
> > > > > rainerzufalldererste wrote:
> > > > > > rainerzufalldererste wrote:
> > > > > > > goldstein.w.n wrote:
> > > > > > > > This match code is basically identical to `foldSquareSumInts`. The only difference other than `FMul` vs `Mul` is you do match `FMul(A, 2)` for floats and `m_Shl(A, 1)` for ints.
> > > > > > > > Can you make the match code a helper that takes either fmul/2x matcher (or just lambda wrapping) so it can be used for SumFloat / SumInt?
> > > > > > > Does that imply that `m_c_FAdd` can simply be replaced with `m_c_Add` and will continue to match properly for floating point values as well?
> > > > > > > I presume that would entail partially matching another pattern and then deferring the actual check for the mul2 match, as `BinaryOp_match<RHS, LHS, OpCode>` would have different `OpCode`s for `FMul` and `Shl`, which sounds like a huge mess to me; or is there a cleaner way to do that?
> > > > > > > 
> > > > > > > Something like this sadly doesn't compile (as the lambda return type is ambiguous):
> > > > > > > ```
> > > > > > >   const auto FpMul2Matcher = [](auto &value) {
> > > > > > >     return m_FMul(value, m_SpecificFP(2.0));
> > > > > > >   };
> > > > > > >   const auto IntMul2Matcher = [](auto &value) {
> > > > > > >     return m_Shl(value, m_SpecificInt(1));
> > > > > > >   };
> > > > > > >   const auto Mul2Matcher = FP ? FpMul2Matcher : IntMul2Matcher;
> > > > > > > ```
> > > > > > Even something like this shouldn't work.
> > > > > > 
> > > > > > ```
> > > > > > template <typename TMul2, typename TCAdd, typename TMul>
> > > > > > static bool MatchesSquareSum(BinaryOperator &I, Value *&A, Value *&B,
> > > > > >                              const TMul2 &Mul2, const TCAdd &CAdd,
> > > > > >                              const TMul &Mul) {
> > > > > > 
> > > > > >   // (a * a) + (((a * 2) + b) * b)
> > > > > >   bool Matches =
> > > > > >       match(&I, CAdd(m_OneUse(Mul(m_Value(A), m_Deferred(A))),
> > > > > >                      m_OneUse(Mul(CAdd(Mul2(m_Deferred(A)), m_Value(B)),
> > > > > >                                   m_Deferred(B)))));
> > > > > > 
> > > > > >   // ((a * b) * 2)  or ((a * 2) * b)
> > > > > >   // +
> > > > > >   // (a * a + b * b) or (b * b + a * a)
> > > > > >   if (!Matches) {
> > > > > >     Matches =
> > > > > >         match(&I, CAdd(m_CombineOr(m_OneUse(Mul2(Mul(m_Value(A), m_Value(B)))),
> > > > > >                                    m_OneUse(Mul(Mul2(m_Value(A)), m_Value(B)))),
> > > > > >                        m_OneUse(CAdd(Mul(m_Deferred(A), m_Deferred(A)),
> > > > > >                                      Mul(m_Deferred(B), m_Deferred(B))))));
> > > > > >   }
> > > > > > 
> > > > > >   return Matches;
> > > > > > }
> > > > > > ```
> > > > > > 
> > > > > > I agree that it's messy to have duplicate code, but with the way op-codes are used as template parameters I don't see a way without template specialization to do this nicely; and with template specialization it's even more of a beast.
> > > > > > Am I missing some obvious way built into llvm/InstCombine to do this nicely?
> > > > > Why doesn't that code work?
> > > > Assuming `TMul2` etc. to be a lambda, the return type couldn't be consistent, as for both `m_FMul` and `m_Shl` it'd be `BinaryOp_match<RHS, LHS, OpCode>`, with the same `OpCode` for each invocation, but different `RHS` and `LHS`. One could make this work with macros, but I don't know the LLVM stance on macros, or with template specialization, where there'd be a specialized struct with three functions (`Mul2`, `Mul`, `CAdd`) that simply map to the correct functions for `FAdd`/`Add` etc.
> > > > However, I honestly think that the current implementation is the cleanest way to do it. I'm also not a big fan of code duplication, but the discussed alternatives seem a lot messier to me.
> > > Have you been able to come up with some better ideas? Maybe it's not _that_ terrible to go down the template specialization route, as many of the integer optimizations may have similar counterparts in FP with `nsz` and `reassoc`. Not sure how many of them are already handled twice, but there's a chance one could simplify this process by providing template specialized `m_XAdd<IsFP>(LHS, RHS)` etc. However, I'm not sure if I'm the right person to pass judgement on something that large, as I'm still very new to both LLVM and InstCombine.
> > > Assuming `TMul2` etc. to be a lambda, the return type couldn't be consistent, as for both `m_FMul` and `m_Shl` it'd be `BinaryOp_match<RHS, LHS, OpCode>`, with the same `OpCode` for each invocation, but different `RHS` and `LHS`. One could make this work with macros, but I don't know the LLVM stance on macros, or with template specialization, where there'd be a specialized struct with three functions (`Mul2`, `Mul`, `CAdd`) that simply map to the correct functions for `FAdd`/`Add` etc.
> > 
> > For the TMul2 don't you only need a single Value?
> > Instead of passing a BinaryOperator, you could just pass a lambda i.e:
> > 
> > ```
> > auto FPMul2 = [](Value *&A) {
> >   return match(m_FMul(m_Value(A), m_SpecificFP(2)));
> > };
> > 
> > ...
> > auto IntMul2 = [](Value *&A) {
> >   return match(m_Shl(m_Value(A), m_SpecificInt(1)));
> > };
> > ```
> > 
> > Don't see why the same isn't true for mul/add (although two values then).
> > > However, I honestly think that the current implementation is the cleanest way to do it. I'm also not a big fan of code duplication, but the discussed alternatives seem a lot messier to me.
> > 
> > 
> Regarding the LHS and RHS, you are correct, I misspoke. The `OpCode` and `RHS` are consistent, but `LHS` isn't. There are multiple cases where `TMul2` is used:
> 
> `Mul2(m_Deferred(A))`
> `Mul2(Mul(m_Value(A), m_Value(B)))`
> `Mul2(m_Value(A))`
> 
> All of these parameters have different types, therefore the return type of this lambda would also be different in every case. So if the parameter were `Value *&`, this wouldn't be a problem at all, but that's simply not the case. Is there a way to cast these types to `Value *&` somehow (without capturing them separately and then matching things again against the sub-match-lambda)?
> 
> `m_Deferred` returns `deferredval_ty<Value>`.
> `Mul(m_Value(), m_Value())` returns either `BinaryOp_match<bind_ty<Value>, bind_ty<Value>, Instruction::FMul>` or `BinaryOp_match<bind_ty<Value>, bind_ty<Value>, Instruction::Mul>`.
> `m_Value` returns `bind_ty<Value>`.
> 
> These types aren't compatible, so the template can't deduce a consistent type even from `auto`-parameter lambdas. Same with `Mul` & `CAdd`.
> 
> Apart from that, I'm a bit confused about the `match` in your comment, as it doesn't quite apply here unless we first match parts of the pattern and then check them against this follow-up matcher lambda. Even if we did that, it would end up a large mess: the same would be needed not only for `Mul2` but also for `CAdd` and `Mul`, turning these two large matches into a ton of tiny ones.
> 
> Otherwise, I'm not quite sure why I'm explaining compilation errors here, unless I'm missing something very obvious or am completely missing the point.
> 
> This, however, isn't valid C++ code:
> ```
> template <typename TMul2, typename TCAdd, typename TMul>
> static std::tuple<bool, Value *, Value *>
> MatchesSquareSum(BinaryOperator &I, const TMul2 &Mul2, const TCAdd &CAdd,
>                  const TMul &Mul) {
>   Value *A, *B;
> 
>   // (a * a) + (((a * 2) + b) * b)
>   if (match(&I, CAdd(m_OneUse(Mul(m_Value(A), m_Deferred(A))),
>                      m_OneUse(Mul(CAdd(Mul2(m_Deferred(A)), m_Value(B)),
>                                   m_Deferred(B))))))
>     return std::make_tuple(true, A, B);
> 
>   // ((a * b) * 2)  or ((a * 2) * b)
>   // +
>   // (a * a + b * b) or (b * b + a * a)
>   return std::make_tuple(
>       match(&I, CAdd(m_CombineOr(m_OneUse(Mul2(Mul(m_Value(A), m_Value(B)))),
>                                  m_OneUse(Mul(Mul2(m_Value(A)), m_Value(B)))),
>                      m_OneUse(CAdd(Mul(m_Deferred(A), m_Deferred(A)),
>                                    Mul(m_Deferred(B), m_Deferred(B)))))),
>       A, B);
> }
> 
> // Fold variations of a^2 + 2*a*b + b^2 -> (a + b)^2
> // if `FP`: requires `nsz` and `reassoc`.
> Instruction *InstCombinerImpl::foldSquareSum(BinaryOperator &I, const bool FP) {
>   if (FP) {
>     assert(I.hasAllowReassoc() && I.hasNoSignedZeros() &&
>            "Assumption mismatch");
>   }
> 
>   std::tuple<bool, Value *, Value *> Match;
> 
>   if (FP) {
>     Match = MatchesSquareSum(
>         I, [](auto &V) { return m_FMul(V, m_SpecificFP(2.0)); },
>         [](auto &L, auto &R) { return m_c_FAdd(L, R); },
>         [](auto &L, auto &R) { return m_FMul(L, R); });
>   } else {
>     Match = MatchesSquareSum(
>         I, [](auto &V) { return m_Shl(V, m_SpecificInt(1)); },
>         [](auto &L, auto &R) { return m_c_Add(L, R); },
>         [](auto &L, auto &R) { return m_Mul(L, R); });
>   }
> 
>   // if one of them matches: -> (a + b)^2
>   if (std::get<0>(Match)) {
>     Value *AB =
>         Builder.CreateFAddFMF(std::get<1>(Match), std::get<2>(Match), &I);
>     return BinaryOperator::CreateFMulFMF(AB, AB, &I);
>   }
> 
>   return nullptr;
> }
> ```
> 
> This _is_ valid C++ code, but uses template specialization to get around the previous type-ambiguity issues:
> ```
> template <bool IsFP> struct XMul;
> 
> template <> struct XMul<false> {
>   template <typename LHS, typename RHS>
>   inline auto operator()(const LHS &L, const RHS &R) const {
>     return m_Mul(L, R);
>   }
> };
> 
> template <> struct XMul<true> {
>   template <typename LHS, typename RHS>
>   inline auto operator()(const LHS &L, const RHS &R) const {
>     return m_FMul(L, R);
>   }
> };
> 
> template <bool IsFP> struct XCAdd;
> 
> template <> struct XCAdd<false> {
>   template <typename LHS, typename RHS>
>   inline auto operator()(const LHS &L, const RHS &R) const {
>     return m_c_Add(L, R);
>   }
> };
> 
> template <> struct XCAdd<true> {
>   template <typename LHS, typename RHS>
>   inline auto operator()(const LHS &L, const RHS &R) const {
>     return m_c_FAdd(L, R);
>   }
> };
> 
> template <bool IsFP> struct XMul2;
> 
> template <> struct XMul2<false> {
>   template <typename LHS> inline auto operator()(const LHS &L) const {
>     return m_Shl(L, m_SpecificInt(1));
>   }
> };
> 
> template <> struct XMul2<true> {
>   template <typename LHS> inline auto operator()(const LHS &L) const {
>     return m_FMul(L, m_SpecificFP(2.0));
>   }
> };
> 
> template <typename TMul2, typename TCAdd, typename TMul>
> static std::tuple<bool, Value *, Value *>
> MatchesSquareSum(BinaryOperator &I, const TMul2 &Mul2, const TCAdd &CAdd,
>                  const TMul &Mul) {
>   Value *A, *B;
> 
>   // (a * a) + (((a * 2) + b) * b)
>   if (match(&I, CAdd(m_OneUse(Mul(m_Value(A), m_Deferred(A))),
>                      m_OneUse(Mul(CAdd(Mul2(m_Deferred(A)), m_Value(B)),
>                                   m_Deferred(B))))))
>     return std::make_tuple(true, A, B);
> 
>   // ((a * b) * 2)  or ((a * 2) * b)
>   // +
>   // (a * a + b * b) or (b * b + a * a)
>   return std::make_tuple(
>       match(&I, CAdd(m_CombineOr(m_OneUse(Mul2(Mul(m_Value(A), m_Value(B)))),
>                                  m_OneUse(Mul(Mul2(m_Value(A)), m_Value(B)))),
>                      m_OneUse(CAdd(Mul(m_Deferred(A), m_Deferred(A)),
>                                    Mul(m_Deferred(B), m_Deferred(B)))))),
>       A, B);
> }
> 
> // Fold variations of a^2 + 2*a*b + b^2 -> (a + b)^2
> // if `FP`: requires `nsz` and `reassoc`.
> Instruction *InstCombinerImpl::foldSquareSum(BinaryOperator &I, const bool FP) {
>   if (FP) {
>     assert(I.hasAllowReassoc() && I.hasNoSignedZeros() &&
>            "Assumption mismatch");
>   }
> 
>   const std::tuple<bool, Value *, Value *> Match =
>       FP ? MatchesSquareSum(I, XMul2<true>(), XCAdd<true>(), XMul<true>())
>          : MatchesSquareSum(I, XMul2<false>(), XCAdd<false>(), XMul<false>());
> 
>   // if one of them matches: -> (a + b)^2
>   if (std::get<0>(Match)) {
>     Value *AB =
>         Builder.CreateFAddFMF(std::get<1>(Match), std::get<2>(Match), &I);
>     return BinaryOperator::CreateFMulFMF(AB, AB, &I);
>   }
> 
>   return nullptr;
> }
> ```
How about something along the lines of:
```
template <unsigned OpcMul, unsigned OpcAdd, unsigned OpcMul2, typename Mul2Rhs>
static bool foldSquareSum(BinaryOperator &I, Mul2Rhs MRhs, Value *&AOut,
                          Value *&BOut) {
  Value *A, *B;
  bool Matches = match(
      &I,
      m_c_BinOp(OpcAdd, m_OneUse(m_BinOp(OpcMul, m_Value(A), m_Deferred(A))),
                m_OneUse(m_BinOp(
                    OpcMul,
                    m_c_BinOp(OpcAdd, m_BinOp(OpcMul2, m_Deferred(A), MRhs),
                              m_Value(B)),
                    m_Deferred(B)))));
  if (!Matches) {
    Matches = match(
        &I,
        m_c_BinOp(
            OpcAdd,
            m_CombineOr(
                m_OneUse(m_BinOp(
                    OpcMul2, m_BinOp(OpcMul, m_Value(A), m_Value(B)), MRhs)),
                m_OneUse(m_BinOp(OpcMul, m_BinOp(OpcMul2, m_Value(A), MRhs),
                                 m_Value(B)))),
            m_OneUse(
                m_c_BinOp(OpcAdd, m_BinOp(OpcMul, m_Deferred(A), m_Deferred(A)),
                          m_BinOp(OpcMul, m_Deferred(B), m_Deferred(B))))));
  }
  AOut = A;
  BOut = B;
  return Matches;
}


// Fold variations of a^2 + 2*a*b + b^2 -> (a + b)^2
Instruction *InstCombinerImpl::foldSquareSumInts(BinaryOperator &I) {
  Value *A, *B;

  bool Matches =
      foldSquareSum<Instruction::Mul, Instruction::Add, Instruction::Shl>(
          I, m_SpecificInt(1), A, B);
  // if one of them matches: -> (a + b)^2
  if (Matches) {
    Value *AB = Builder.CreateAdd(A, B);
    return BinaryOperator::CreateMul(AB, AB);
  }

  return nullptr;
}


// Fold variations of a^2 + 2*a*b + b^2 -> (a + b)^2
// Requires `nsz` and `reassoc`.
Instruction *InstCombinerImpl::foldSquareSumFloat(BinaryOperator &I) {
  Value *A, *B;

  assert(I.hasAllowReassoc() && I.hasNoSignedZeros() && "Assumption mismatch");

  bool Matches =
      foldSquareSum<Instruction::FMul, Instruction::FAdd, Instruction::FMul>(
          I, m_SpecificFP(2.0), A, B);

  // if one of them matches: -> (a + b)^2
  if (Matches) {
    Value *AB = Builder.CreateFAddFMF(A, B, &I);
    return BinaryOperator::CreateFMulFMF(AB, AB, &I);
  }

  return nullptr;
}
```

Needs comments/whatnot, but I don't see why this would fall short.
All the InstCombine tests pass with this (I assume including all the tests relevant to int/fp version of this).
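
To illustrate why passing the opcode as a plain value sidesteps the type problem, here is a toy sketch that does not depend on LLVM. `Node`, `Bind`, `Same`, `binOp`, `matchSquare` and `squareDemo` are made-up stand-ins (loosely modelled on `Value`, `m_Value`, `m_Deferred` and the opcode-taking `m_BinOp`), not the real PatternMatch API. Because the opcode is a field rather than a template argument, the int and FP matchers share one type and one helper can serve both.

```cpp
#include <cassert>
#include <type_traits>

// Toy stand-ins only; none of this is the real LLVM PatternMatch API.
enum Opcode { Leaf, Mul, FMul };

struct Node {
  Opcode Op;
  const Node *L;
  const Node *R;
};

// Binds whatever node it sees (simplified analogue of m_Value).
struct Bind {
  const Node *&Out;
  bool match(const Node *N) const {
    Out = N;
    return true;
  }
};

// Matches only the node bound earlier (simplified analogue of m_Deferred).
struct Same {
  const Node *const &Ref;
  bool match(const Node *N) const { return N == Ref; }
};

// Binary matcher that stores the opcode as a field, so the opcode is not
// part of the matcher's type.
template <typename LHS, typename RHS> struct BinOp {
  Opcode Op;
  LHS L;
  RHS R;
  bool match(const Node *N) const {
    return N->Op == Op && L.match(N->L) && R.match(N->R);
  }
};

template <typename LHS, typename RHS>
BinOp<LHS, RHS> binOp(Opcode Op, LHS L, RHS R) {
  return BinOp<LHS, RHS>{Op, L, R};
}

// One helper for both flavours: matches X * X for the given multiply opcode.
template <Opcode MulOpc> bool matchSquare(const Node *N, const Node *&X) {
  return binOp(MulOpc, Bind{X}, Same{X}).match(N);
}

// Small self-check exercising the int and FP instantiations.
bool squareDemo() {
  Node A{Leaf, nullptr, nullptr}, B{Leaf, nullptr, nullptr};
  Node IntSq{Mul, &A, &A}, FpSq{FMul, &A, &A}, Mixed{Mul, &A, &B};
  const Node *X = nullptr;
  // Both opcodes yield the same matcher type.
  static_assert(std::is_same<decltype(binOp(Mul, Bind{X}, Same{X})),
                             decltype(binOp(FMul, Bind{X}, Same{X}))>::value,
                "opcode must not affect the matcher type");
  bool IntOk = matchSquare<Mul>(&IntSq, X) && X == &A;
  bool FpOk = matchSquare<FMul>(&FpSq, X) && X == &A;
  bool Rejects = !matchSquare<FMul>(&IntSq, X) && !matchSquare<Mul>(&Mixed, X);
  assert(IntOk && FpOk && Rejects);
  return IntOk && FpOk && Rejects;
}
```

The trade-off in this toy version is that the opcode check happens when `match` runs rather than being encoded in the matcher type, which is what lets the same template body be reused.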


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D158079/new/

https://reviews.llvm.org/D158079


