[PATCH] D158079: [InstCombine] Contracting x^2 + 2*x*y + y^2 to (x + y)^2 (float)

Christoph Stiller via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Sat Aug 19 07:24:02 PDT 2023


rainerzufalldererste marked an inline comment as done.
rainerzufalldererste added inline comments.


================
Comment at: llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp:1059
+                                       m_FMul(m_Deferred(B), m_Deferred(B))))));
+  }
+
----------------
goldstein.w.n wrote:
> rainerzufalldererste wrote:
> > rainerzufalldererste wrote:
> > > goldstein.w.n wrote:
> > > > rainerzufalldererste wrote:
> > > > > rainerzufalldererste wrote:
> > > > > > goldstein.w.n wrote:
> > > > > > > This match code is basically identical to `foldSquareSumInts`. The only difference other than `FMul` vs `Mul` is you do match `FMul(A, 2)` for floats and `m_Shl(A, 1)` for ints.
> > > > > > > Can you make the match code a helper that takes either fmul/2x matcher (or just lambda wrapping) so it can be used for SumFloat / SumInt?
> > > > > > Does that imply that `m_c_FAdd` can simply be replaced with `m_c_Add` and will continue to match properly for floating point values as well?
> > > > > > I presume that would entail partially matching another pattern and then deferring the actual check for the mul2 match, as `BinaryOp_match<LHS, RHS, OpCode>` would have different `OpCode`s for `FMul` and `Shl`, which sounds like a huge mess to me; or is there a cleaner way to do that?
> > > > > > 
> > > > > > Something like this sadly doesn't compile (as the lambda return type is ambiguous):
> > > > > > ```
> > > > > >   const auto FpMul2Matcher = [](auto &value) {
> > > > > >     return m_FMul(value, m_SpecificFP(2.0));
> > > > > >   };
> > > > > >   const auto IntMul2Matcher = [](auto &value) {
> > > > > >     return m_Shl(value, m_SpecificInt(1));
> > > > > >   };
> > > > > >   const auto Mul2Matcher = FP ? FpMul2Matcher : IntMul2Matcher;
> > > > > > ```
> > > > > Even something like this shouldn't work.
> > > > > 
> > > > > ```
> > > > > template <typename TMul2, typename TCAdd, typename TMul>
> > > > > static bool MatchesSquareSum(BinaryOperator &I, Value *&A, Value *&B,
> > > > >                              const TMul2 &Mul2, const TCAdd &CAdd,
> > > > >                              const TMul &Mul) {
> > > > > 
> > > > >   // (a * a) + (((a * 2) + b) * b)
> > > > >   bool Matches =
> > > > >       match(&I, CAdd(m_OneUse(Mul(m_Value(A), m_Deferred(A))),
> > > > >                      m_OneUse(Mul(CAdd(Mul2(m_Deferred(A)), m_Value(B)),
> > > > >                                   m_Deferred(B)))));
> > > > > 
> > > > >   // ((a * b) * 2)  or ((a * 2) * b)
> > > > >   // +
> > > > >   // (a * a + b * b) or (b * b + a * a)
> > > > >   if (!Matches) {
> > > > >     Matches =
> > > > >         match(&I, CAdd(m_CombineOr(m_OneUse(Mul2(Mul(m_Value(A), m_Value(B)))),
> > > > >                                    m_OneUse(Mul(Mul2(m_Value(A)), m_Value(B)))),
> > > > >                        m_OneUse(CAdd(Mul(m_Deferred(A), m_Deferred(A)),
> > > > >                                      Mul(m_Deferred(B), m_Deferred(B))))));
> > > > >   }
> > > > > 
> > > > >   return Matches;
> > > > > }
> > > > > ```
> > > > > 
> > > > > I agree that it's messy to have duplicate code, but with the way op-codes are used as template parameters I don't see a way without template specialization to do this nicely; and with template specialization it's even more of a beast.
> > > > > Am I missing some obvious way built into llvm/InstCombine to do this nicely?
> > > > Why doesn't that code work?
> > > Assuming `TMul2` etc. to be a lambda, the return type couldn't be consistent, as for both `m_FMul` and `m_Shl` it'd be `BinaryOp_match<LHS, RHS, OpCode>`, with the same `OpCode` for each invocation, but different `RHS` and `LHS`. One could make this work with macros, but I don't know the LLVM stance on macros, or with template specialization, where there'd be a specialized struct with three functions (`Mul2`, `Mul`, `CAdd`) that simply map to the correct functions for `FAdd`/`Add` etc.
> > > However, I honestly think that the current implementation is the cleanest way to do it. I'm also not a big fan of code duplication, but the discussed alternatives seem a lot messier to me.
> > Have you been able to come up with some better ideas? Maybe it's not _that_ terrible to go down the template specialization route, as many of the integer optimizations may have similar counterparts in FP with `nsz` and `reassoc`. Not sure how many of them are already handled twice, but there's a chance one could simplify this process by providing template specialized `m_XAdd<IsFP>(LHS, RHS)` etc. However, I'm not sure if I'm the right person to pass judgement on something that large, as I'm still very new to both LLVM and InstCombine.
> > Assuming `TMul2` etc. to be a lambda, the return type couldn't be consistent, as for both `m_FMul` and `m_Shl` it'd be `BinaryOp_match<LHS, RHS, OpCode>`, with the same `OpCode` for each invocation, but different `RHS` and `LHS`. One could make this work with macros, but I don't know the LLVM stance on macros, or with template specialization, where there'd be a specialized struct with three functions (`Mul2`, `Mul`, `CAdd`) that simply map to the correct functions for `FAdd`/`Add` etc.
> 
> For the TMul2 don't you only need a single Value?
> Instead of passing a BinaryOperator, you could just pass a lambda i.e:
> 
> ```
> auto FPMul2 = [](Value *& A) {
>    return match(m_FMul(m_Value(A), m_SpecificFP(2));
> };
> 
> ...
> auto IntMul2 = [](Value *&A) {
>   return match(m_Shl(m_Value(A), m_SpecificInt(1));
> };
> ```
> 
> Don't see why the same isn't true for mul/add (although two values then).
> > However, I honestly think that the current implementation is the cleanest way to do it. I'm also not a big fan of code duplication, but the discussed alternatives seem a lot messier to me.
> 
> 
Regarding the LHS and RHS, you are correct, I misspoke. The `OpCode` and `RHS` are consistent, but `LHS` isn't. There are multiple cases where `TMul2` is used:

`Mul2(m_Deferred(A))`
`Mul2(Mul(m_Value(A), m_Value(B)))`
`Mul2(m_Value(A))`

All of these parameters have different types, therefore the return type of this lambda would also be different in every case. So if the parameter were `Value *&`, this wouldn't be a problem at all, but that's simply not the case. Is there a way to cast these types to `Value *&` somehow (without capturing them separately and then matching things again against the sub-match-lambda)?

`m_Deferred` returns `deferredval_ty<Value>`.
`Mul(m_Value(), m_Value())` returns either `BinaryOp_match<bind_ty<Value>, bind_ty<Value>, Instruction::FMul>` or `BinaryOp_match<bind_ty<Value>, bind_ty<Value>, Instruction::Mul>`.
`m_Value` returns `bind_ty<Value>`.

These types aren't compatible, so the template can't deduce a consistent type even from `auto`-parameter lambdas. Same with `Mul` & `CAdd`.
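
To make the type mismatch concrete, here is a minimal standalone sketch. It does not use LLVM's headers: `Bind`, `Deferred`, `Two`, `BinOp`, `mul` and `mul2` are hypothetical stand-ins for `bind_ty<Value>`, `deferredval_ty<Value>`, the constant-2 matcher, `BinaryOp_match`, `Mul` and `Mul2`. All it shows is that the result type of wrapping a sub-matcher depends on the sub-matcher's type.

```cpp
#include <type_traits>

// Hypothetical stand-ins for PatternMatch types; not LLVM's real definitions.
struct Bind {};     // plays the role of bind_ty<Value>
struct Deferred {}; // plays the role of deferredval_ty<Value>
struct Two {};      // plays the role of the constant-2 matcher

template <typename LHS, typename RHS, unsigned Opcode> struct BinOp {
  LHS L;
  RHS R;
};

constexpr unsigned FMulOp = 1; // arbitrary opcode tag for this sketch

// Analogue of Mul(L, R): combines two sub-matchers.
template <typename LHS, typename RHS>
BinOp<LHS, RHS, FMulOp> mul(const LHS &L, const RHS &R) {
  return {L, R};
}

// Analogue of Mul2(X): wraps any sub-matcher X into "X * 2".
template <typename LHS> BinOp<LHS, Two, FMulOp> mul2(const LHS &L) {
  return {L, Two{}};
}

// The three call shapes listed above.
using FromDeferred = decltype(mul2(Deferred{}));
using FromMul = decltype(mul2(mul(Bind{}, Bind{})));
using FromBind = decltype(mul2(Bind{}));

// Opcode and RHS agree in all three; only LHS differs, so the types differ.
static_assert(!std::is_same<FromDeferred, FromBind>::value, "distinct types");
static_assert(!std::is_same<FromDeferred, FromMul>::value, "distinct types");
static_assert(!std::is_same<FromBind, FromMul>::value, "distinct types");
```

The sketch only checks types at compile time, so compiling it cleanly is the whole test.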

Apart from that, I'm a bit confused by the `match` in your comment. As far as I can tell it only applies if we first capture parts of the expression and then check each of them against a follow-up matcher lambda. Even if we were to do that, it would have to be done for `CAdd` and `Mul` as well as `Mul2`, which would turn these two large matches into a ton of tiny matches.

Otherwise, I'm not quite sure why I'm explaining compilation errors here, unless I'm missing something very obvious or am completely missing the point.

This version, which passes generic lambdas to a shared helper, isn't valid C++ code:
```
template <typename TMul2, typename TCAdd, typename TMul>
static std::tuple<bool, Value *, Value *>
MatchesSquareSum(BinaryOperator &I, const TMul2 &Mul2, const TCAdd &CAdd,
                 const TMul &Mul) {
  // Initialized so the tuple never carries indeterminate pointers.
  Value *A = nullptr, *B = nullptr;

  // (a * a) + (((a * 2) + b) * b)
  if (match(&I, CAdd(m_OneUse(Mul(m_Value(A), m_Deferred(A))),
                     m_OneUse(Mul(CAdd(Mul2(m_Deferred(A)), m_Value(B)),
                                  m_Deferred(B))))))
    return std::make_tuple(true, A, B);

  // ((a * b) * 2)  or ((a * 2) * b)
  // +
  // (a * a + b * b) or (b * b + a * a)
  return std::make_tuple(
      match(&I, CAdd(m_CombineOr(m_OneUse(Mul2(Mul(m_Value(A), m_Value(B)))),
                                 m_OneUse(Mul(Mul2(m_Value(A)), m_Value(B)))),
                     m_OneUse(CAdd(Mul(m_Deferred(A), m_Deferred(A)),
                                   Mul(m_Deferred(B), m_Deferred(B)))))),
      A, B);
}

// Fold variations of a^2 + 2*a*b + b^2 -> (a + b)^2
// if `FP`: requires `nsz` and `reassoc`.
Instruction *InstCombinerImpl::foldSquareSum(BinaryOperator &I, const bool FP) {
  if (FP) {
    assert(I.hasAllowReassoc() && I.hasNoSignedZeros() &&
           "Assumption mismatch");
  }

  std::tuple<bool, Value *, Value *> Match;

  if (FP) {
    Match = MatchesSquareSum(
        I, [](auto &V) { return m_FMul(V, m_SpecificFP(2.0)); },
        [](auto &L, auto &R) { return m_c_FAdd(L, R); },
        [](auto &L, auto &R) { return m_FMul(L, R); });
  } else {
    Match = MatchesSquareSum(
        I, [](auto &V) { return m_Shl(V, m_SpecificInt(1)); },
        [](auto &L, auto &R) { return m_c_Add(L, R); },
        [](auto &L, auto &R) { return m_Mul(L, R); });
  }

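  // if one of them matches: -> (a + b)^2
  if (std::get<0>(Match)) {
    Value *A = std::get<1>(Match), *B = std::get<2>(Match);
    if (FP) {
      Value *AB = Builder.CreateFAddFMF(A, B, &I);
      return BinaryOperator::CreateFMulFMF(AB, AB, &I);
    }
    // Integer case: plain add/mul, no fast-math flags.
    Value *AB = Builder.CreateAdd(A, B);
    return BinaryOperator::CreateMul(AB, AB);
  }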

  return nullptr;
}
```
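
A side note on the block above, offered as an observation rather than a full diagnosis of its errors: separately from the return-type question, a lambda parameter declared `auto &` cannot bind to a temporary, and calls such as `Mul2(m_Deferred(A))` pass the matcher as a temporary. The standalone sketch below illustrates only that language rule. `Deferred`, `TimesTwo`, `timesTwo`, `ByRef`, `ByConstRef` and `CanCall` are hypothetical names for this example and are not part of LLVM.

```cpp
#include <type_traits>
#include <utility>

struct Deferred {};                                 // stand-in for a matcher object
template <typename LHS> struct TimesTwo { LHS L; }; // stand-in for "LHS * 2"

template <typename LHS> TimesTwo<LHS> timesTwo(const LHS &L) { return {L}; }

// Same parameter shape as the lambdas above: non-const lvalue reference.
const auto ByRef = [](auto &V) { return timesTwo(V); };
// Alternative shape: const lvalue reference, which also accepts temporaries.
const auto ByConstRef = [](const auto &V) { return timesTwo(V); };

// Detects whether F can be called with an argument of type (and value
// category) Arg. An rvalue is spelled `T`, an lvalue is spelled `T &`.
template <typename F, typename Arg, typename = void>
struct CanCall : std::false_type {};
template <typename F, typename Arg>
struct CanCall<F, Arg,
               decltype(void(std::declval<const F &>()(std::declval<Arg>())))>
    : std::true_type {};

// A temporary cannot bind to `auto &`...
static_assert(!CanCall<decltype(ByRef), Deferred>::value, "rvalue rejected");
// ...but a named object can.
static_assert(CanCall<decltype(ByRef), Deferred &>::value, "lvalue accepted");
// `const auto &` accepts both.
static_assert(CanCall<decltype(ByConstRef), Deferred>::value, "rvalue accepted");
static_assert(CanCall<decltype(ByConstRef), Deferred &>::value, "lvalue accepted");
```

Whether switching the lambda parameters to `const auto &` would be enough for the LLVM version is something this sketch does not establish; it was not tested against the real matchers.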

This _is_ valid C++ code, but uses template specialization to get around the previous type-ambiguity issues:
```
template <bool IsFP> struct XMul;

template <> struct XMul<false> {
  template <typename LHS, typename RHS>
  inline auto operator()(const LHS &L, const RHS &R) const {
    return m_Mul(L, R);
  }
};

template <> struct XMul<true> {
  template <typename LHS, typename RHS>
  inline auto operator()(const LHS &L, const RHS &R) const {
    return m_FMul(L, R);
  }
};

template <bool IsFP> struct XCAdd;

template <> struct XCAdd<false> {
  template <typename LHS, typename RHS>
  inline auto operator()(const LHS &L, const RHS &R) const {
    return m_c_Add(L, R);
  }
};

template <> struct XCAdd<true> {
  template <typename LHS, typename RHS>
  inline auto operator()(const LHS &L, const RHS &R) const {
    return m_c_FAdd(L, R);
  }
};

template <bool IsFP> struct XMul2;

template <> struct XMul2<false> {
  template <typename LHS> inline auto operator()(const LHS &L) const {
    return m_Shl(L, m_SpecificInt(1));
  }
};

template <> struct XMul2<true> {
  template <typename LHS> inline auto operator()(const LHS &L) const {
    return m_FMul(L, m_SpecificFP(2.0));
  }
};

template <typename TMul2, typename TCAdd, typename TMul>
static std::tuple<bool, Value *, Value *>
MatchesSquareSum(BinaryOperator &I, const TMul2 &Mul2, const TCAdd &CAdd,
                 const TMul &Mul) {
  // Initialized so the tuple never carries indeterminate pointers.
  Value *A = nullptr, *B = nullptr;

  // (a * a) + (((a * 2) + b) * b)
  if (match(&I, CAdd(m_OneUse(Mul(m_Value(A), m_Deferred(A))),
                     m_OneUse(Mul(CAdd(Mul2(m_Deferred(A)), m_Value(B)),
                                  m_Deferred(B))))))
    return std::make_tuple(true, A, B);

  // ((a * b) * 2)  or ((a * 2) * b)
  // +
  // (a * a + b * b) or (b * b + a * a)
  return std::make_tuple(
      match(&I, CAdd(m_CombineOr(m_OneUse(Mul2(Mul(m_Value(A), m_Value(B)))),
                                 m_OneUse(Mul(Mul2(m_Value(A)), m_Value(B)))),
                     m_OneUse(CAdd(Mul(m_Deferred(A), m_Deferred(A)),
                                   Mul(m_Deferred(B), m_Deferred(B)))))),
      A, B);
}

// Fold variations of a^2 + 2*a*b + b^2 -> (a + b)^2
// if `FP`: requires `nsz` and `reassoc`.
Instruction *InstCombinerImpl::foldSquareSum(BinaryOperator &I, const bool FP) {
  if (FP) {
    assert(I.hasAllowReassoc() && I.hasNoSignedZeros() &&
           "Assumption mismatch");
  }

  const std::tuple<bool, Value *, Value *> Match =
      FP ? MatchesSquareSum(I, XMul2<true>(), XCAdd<true>(), XMul<true>())
         : MatchesSquareSum(I, XMul2<false>(), XCAdd<false>(), XMul<false>());

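  // if one of them matches: -> (a + b)^2
  if (std::get<0>(Match)) {
    Value *A = std::get<1>(Match), *B = std::get<2>(Match);
    if (FP) {
      Value *AB = Builder.CreateFAddFMF(A, B, &I);
      return BinaryOperator::CreateFMulFMF(AB, AB, &I);
    }
    // Integer case: plain add/mul, no fast-math flags.
    Value *AB = Builder.CreateAdd(A, B);
    return BinaryOperator::CreateMul(AB, AB);
  }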

  return nullptr;
}
```
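
For anyone who wants to try the dispatch idea without an LLVM checkout, here is a reduced standalone sketch of the same structure. It assumes nothing from PatternMatch: `Leaf`, `BinOp`, `XMulSketch` and `buildNested` are hypothetical names for this example. A functor specialized on `IsFP` has a templated `operator()`, so one functor object can be applied to differently typed operands inside a single pattern.

```cpp
#include <type_traits>

// Hypothetical stand-ins; not LLVM's PatternMatch API.
struct Leaf {}; // plays the role of a leaf matcher such as m_Value(A)

template <typename LHS, typename RHS, unsigned Opcode> struct BinOp {
  LHS L;
  RHS R;
};

constexpr unsigned IntMulOp = 0; // arbitrary opcode tags for this sketch
constexpr unsigned FpMulOp = 1;

template <bool IsFP> struct XMulSketch;

template <> struct XMulSketch<false> {
  template <typename LHS, typename RHS>
  BinOp<LHS, RHS, IntMulOp> operator()(const LHS &L, const RHS &R) const {
    return {L, R};
  }
};

template <> struct XMulSketch<true> {
  template <typename LHS, typename RHS>
  BinOp<LHS, RHS, FpMulOp> operator()(const LHS &L, const RHS &R) const {
    return {L, R};
  }
};

// Generic pattern builder: x * (x * x). The inner and outer calls pass
// different operand types to the same functor.
template <typename TMul> auto buildNested(const TMul &Mul) {
  return Mul(Leaf{}, Mul(Leaf{}, Leaf{}));
}

using IntPattern = decltype(buildNested(XMulSketch<false>{}));
using FpPattern = decltype(buildNested(XMulSketch<true>{}));

static_assert(
    std::is_same<IntPattern,
                 BinOp<Leaf, BinOp<Leaf, Leaf, IntMulOp>, IntMulOp>>::value,
    "integer pattern uses the integer opcode throughout");
static_assert(!std::is_same<IntPattern, FpPattern>::value,
              "the FP and integer patterns are distinct types");
```

Because the two patterns are distinct types, the choice between them has to be made at a point where both branches produce a common result, as the `std::tuple` return of `MatchesSquareSum` does above.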


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D158079/new/

https://reviews.llvm.org/D158079


