[PATCH] D158079: [InstCombine] Contracting x^2 + 2*x*y + y^2 to (x + y)^2 (float)
Noah Goldstein via Phabricator via llvm-commits
llvm-commits at lists.llvm.org
Sat Aug 19 11:25:23 PDT 2023
goldstein.w.n added inline comments.
================
Comment at: llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp:1059
+ m_FMul(m_Deferred(B), m_Deferred(B))))));
+ }
+
----------------
rainerzufalldererste wrote:
> goldstein.w.n wrote:
> > rainerzufalldererste wrote:
> > > rainerzufalldererste wrote:
> > > > goldstein.w.n wrote:
> > > > > rainerzufalldererste wrote:
> > > > > > rainerzufalldererste wrote:
> > > > > > > goldstein.w.n wrote:
> > > > > > > > This match code is basically identical to `foldSquareSumInts`. The only difference, other than `FMul` vs. `Mul`, is that you match `FMul(A, 2)` for floats and `m_Shl(A, 1)` for ints.
> > > > > > > > Can you make the match code a helper that takes either the fmul/2x matcher (or just a lambda wrapping it) so it can be used for both SumFloat and SumInt?
> > > > > > > Does that imply that `m_c_FAdd` can simply be replaced with `m_c_Add` and will continue to match properly for floating-point values as well?
> > > > > > > I presume that would entail partially matching another pattern and then deferring the actual check for the mul2 match, since `BinaryOp_match<LHS, RHS, OpCode>` would have different `OpCode`s for `FMul` and `Shl`, which sounds like a huge mess to me; or is there a cleaner way to do that?
> > > > > > >
> > > > > > > Something like this sadly doesn't compile (as the lambda return type is ambiguous):
> > > > > > > ```
> > > > > > > const auto FpMul2Matcher = [](auto &value) {
> > > > > > >   return m_FMul(value, m_SpecificFP(2.0));
> > > > > > > };
> > > > > > > const auto IntMul2Matcher = [](auto &value) {
> > > > > > >   return m_Shl(value, m_SpecificInt(1));
> > > > > > > };
> > > > > > > const auto Mul2Matcher = FP ? FpMul2Matcher : IntMul2Matcher;
> > > > > > > ```
> > > > > > Even something like the following doesn't work:
> > > > > >
> > > > > > ```
> > > > > > template <typename TMul2, typename TCAdd, typename TMul>
> > > > > > static bool MatchesSquareSum(BinaryOperator &I, Value *&A, Value *&B,
> > > > > >                              const TMul2 &Mul2, const TCAdd &CAdd,
> > > > > >                              const TMul &Mul) {
> > > > > >   // (a * a) + (((a * 2) + b) * b)
> > > > > >   bool Matches =
> > > > > >       match(&I, CAdd(m_OneUse(Mul(m_Value(A), m_Deferred(A))),
> > > > > >                      m_OneUse(Mul(CAdd(Mul2(m_Deferred(A)), m_Value(B)),
> > > > > >                                   m_Deferred(B)))));
> > > > > >
> > > > > >   // ((a * b) * 2) or ((a * 2) * b)
> > > > > >   // +
> > > > > >   // (a * a + b * b) or (b * b + a * a)
> > > > > >   if (!Matches) {
> > > > > >     Matches =
> > > > > >         match(&I, CAdd(m_CombineOr(m_OneUse(Mul2(Mul(m_Value(A), m_Value(B)))),
> > > > > >                                    m_OneUse(Mul(Mul2(m_Value(A)), m_Value(B)))),
> > > > > >                        m_OneUse(CAdd(Mul(m_Deferred(A), m_Deferred(A)),
> > > > > >                                      Mul(m_Deferred(B), m_Deferred(B))))));
> > > > > >   }
> > > > > >
> > > > > >   return Matches;
> > > > > > }
> > > > > > ```
> > > > > >
> > > > > > I agree that it's messy to have duplicate code, but with the way opcodes are used as template parameters I don't see a way to do this nicely without template specialization; and with template specialization it's even more of a beast.
> > > > > > Am I missing some obvious way built into llvm/InstCombine to do this?
> > > > > Why doesn't that code work?
> > > > Assuming `TMul2` etc. to be lambdas, the return type couldn't be consistent: for both `m_FMul` and `m_Shl` it'd be `BinaryOp_match<LHS, RHS, OpCode>`, with the same `OpCode` for each invocation but different `LHS` and `RHS`. One could make this work with macros, but I don't know the LLVM stance on macros, or with template specialization, where there'd be a specialized struct with three functions (`Mul2`, `Mul`, `CAdd`) that simply map to the correct functions for `FAdd`/`Add` etc.
> > > > However, I honestly think that the current implementation is the cleanest way to do it. I'm also not a big fan of code duplication, but the discussed alternatives seem a lot messier to me.
> > > Have you been able to come up with some better ideas? Maybe it's not _that_ terrible to go down the template specialization route, as many of the integer optimizations may have similar counterparts in FP with `nsz` and `reassoc`. Not sure how many of them are already handled twice, but there's a chance one could simplify this process by providing template specialized `m_XAdd<IsFP>(LHS, RHS)` etc. However, I'm not sure if I'm the right person to pass judgement on something that large, as I'm still very new to both LLVM and InstCombine.
> > > Assuming `TMul2` etc. to be lambdas, the return type couldn't be consistent: for both `m_FMul` and `m_Shl` it'd be `BinaryOp_match<LHS, RHS, OpCode>`, with the same `OpCode` for each invocation but different `LHS` and `RHS`. One could make this work with macros, but I don't know the LLVM stance on macros, or with template specialization, where there'd be a specialized struct with three functions (`Mul2`, `Mul`, `CAdd`) that simply map to the correct functions for `FAdd`/`Add` etc.
> >
> > For `TMul2`, don't you only need a single `Value`?
> > Instead of passing a `BinaryOperator`, you could just pass a lambda, i.e.:
> >
> > ```
> > auto FPMul2 = [](Value *&A) {
> >   return match(m_FMul(m_Value(A), m_SpecificFP(2.0)));
> > };
> >
> > ...
> > auto IntMul2 = [](Value *&A) {
> >   return match(m_Shl(m_Value(A), m_SpecificInt(1)));
> > };
> > ```
> >
> > I don't see why the same isn't true for mul/add (although with two values then).
> > > However, I honestly think that the current implementation is the cleanest way to do it. I'm also not a big fan of code duplication, but the discussed alternatives seem a lot messier to me.
> Regarding the LHS and RHS, you are correct, I misspoke. The `OpCode` and `RHS` are consistent, but the `LHS` isn't. There are multiple cases where `TMul2` is used:
>
> `Mul2(m_Deferred(A))`
> `Mul2(Mul(m_Value(A), m_Value(B)))`
> `Mul2(m_Value(A))`
>
> All of these arguments have different types, so the return type of this lambda would also be different in every case. If the parameter were `Value *&`, this wouldn't be a problem at all, but that's simply not the case. Is there a way to cast these types to `Value *&` somehow (without capturing them separately and then matching things again against the sub-match lambda)?
>
> `m_Deferred` returns `deferredval_ty<Value>`.
> `Mul(m_Value(), m_Value())` returns either `BinaryOp_match<bind_ty<Value>, bind_ty<Value>, Instruction::FMul>` or `BinaryOp_match<bind_ty<Value>, bind_ty<Value>, Instruction::Mul>`.
> `m_Value` returns `bind_ty<Value>`.
>
> These types aren't compatible, so the template can't deduce a consistent type even from `auto`-parameter lambdas. The same goes for `Mul` and `CAdd`.
>
> Apart from that, I'm a bit confused by the `match` in your comment: it doesn't quite apply unless we first match parts of the pattern and then check them against this follow-up matcher lambda. Even if we did that, it would end up a large mess, since the same issue applies not only to `Mul2` but also to `CAdd` and `Mul`, turning these two large matches into a ton of tiny ones.
>
> Otherwise, I'm not quite sure why I'm explaining compilation errors here, unless I'm missing something very obvious or am completely missing the point.
>
> This, however, isn't valid C++ code:
> ```
> template <typename TMul2, typename TCAdd, typename TMul>
> static std::tuple<bool, Value *, Value *>
> MatchesSquareSum(BinaryOperator &I, const TMul2 &Mul2, const TCAdd &CAdd,
>                  const TMul &Mul) {
>   Value *A, *B;
>
>   // (a * a) + (((a * 2) + b) * b)
>   if (match(&I, CAdd(m_OneUse(Mul(m_Value(A), m_Deferred(A))),
>                      m_OneUse(Mul(CAdd(Mul2(m_Deferred(A)), m_Value(B)),
>                                   m_Deferred(B))))))
>     return std::make_tuple(true, A, B);
>
>   // ((a * b) * 2) or ((a * 2) * b)
>   // +
>   // (a * a + b * b) or (b * b + a * a)
>   return std::make_tuple(
>       match(&I, CAdd(m_CombineOr(m_OneUse(Mul2(Mul(m_Value(A), m_Value(B)))),
>                                  m_OneUse(Mul(Mul2(m_Value(A)), m_Value(B)))),
>                      m_OneUse(CAdd(Mul(m_Deferred(A), m_Deferred(A)),
>                                    Mul(m_Deferred(B), m_Deferred(B)))))),
>       A, B);
> }
>
> // Fold variations of a^2 + 2*a*b + b^2 -> (a + b)^2
> // if `FP`: requires `nsz` and `reassoc`.
> Instruction *InstCombinerImpl::foldSquareSum(BinaryOperator &I, const bool FP) {
>   if (FP) {
>     assert(I.hasAllowReassoc() && I.hasNoSignedZeros() &&
>            "Assumption mismatch");
>   }
>
>   std::tuple<bool, Value *, Value *> Match;
>
>   if (FP) {
>     Match = MatchesSquareSum(
>         I, [](auto &V) { return m_FMul(V, m_SpecificFP(2.0)); },
>         [](auto &L, auto &R) { return m_c_FAdd(L, R); },
>         [](auto &L, auto &R) { return m_FMul(L, R); });
>   } else {
>     Match = MatchesSquareSum(
>         I, [](auto &V) { return m_Shl(V, m_SpecificInt(1)); },
>         [](auto &L, auto &R) { return m_c_Add(L, R); },
>         [](auto &L, auto &R) { return m_Mul(L, R); });
>   }
>
>   // if one of them matches: -> (a + b)^2
>   if (std::get<0>(Match)) {
>     Value *AB =
>         Builder.CreateFAddFMF(std::get<1>(Match), std::get<2>(Match), &I);
>     return BinaryOperator::CreateFMulFMF(AB, AB, &I);
>   }
>
>   return nullptr;
> }
> ```
>
> This _is_ valid C++ code, but uses template specialization to get around the previous type-ambiguity issues:
> ```
> template <bool IsFP> struct XMul;
>
> template <> struct XMul<false> {
>   template <typename LHS, typename RHS>
>   inline auto operator()(const LHS &L, const RHS &R) const {
>     return m_Mul(L, R);
>   }
> };
>
> template <> struct XMul<true> {
>   template <typename LHS, typename RHS>
>   inline auto operator()(const LHS &L, const RHS &R) const {
>     return m_FMul(L, R);
>   }
> };
>
> template <bool IsFP> struct XCAdd;
>
> template <> struct XCAdd<false> {
>   template <typename LHS, typename RHS>
>   inline auto operator()(const LHS &L, const RHS &R) const {
>     return m_c_Add(L, R);
>   }
> };
>
> template <> struct XCAdd<true> {
>   template <typename LHS, typename RHS>
>   inline auto operator()(const LHS &L, const RHS &R) const {
>     return m_c_FAdd(L, R);
>   }
> };
>
> template <bool IsFP> struct XMul2;
>
> template <> struct XMul2<false> {
>   template <typename LHS> inline auto operator()(const LHS &L) const {
>     return m_Shl(L, m_SpecificInt(1));
>   }
> };
>
> template <> struct XMul2<true> {
>   template <typename LHS> inline auto operator()(const LHS &L) const {
>     return m_FMul(L, m_SpecificFP(2.0));
>   }
> };
>
> template <typename TMul2, typename TCAdd, typename TMul>
> static std::tuple<bool, Value *, Value *>
> MatchesSquareSum(BinaryOperator &I, const TMul2 &Mul2, const TCAdd &CAdd,
>                  const TMul &Mul) {
>   Value *A, *B;
>
>   // (a * a) + (((a * 2) + b) * b)
>   if (match(&I, CAdd(m_OneUse(Mul(m_Value(A), m_Deferred(A))),
>                      m_OneUse(Mul(CAdd(Mul2(m_Deferred(A)), m_Value(B)),
>                                   m_Deferred(B))))))
>     return std::make_tuple(true, A, B);
>
>   // ((a * b) * 2) or ((a * 2) * b)
>   // +
>   // (a * a + b * b) or (b * b + a * a)
>   return std::make_tuple(
>       match(&I, CAdd(m_CombineOr(m_OneUse(Mul2(Mul(m_Value(A), m_Value(B)))),
>                                  m_OneUse(Mul(Mul2(m_Value(A)), m_Value(B)))),
>                      m_OneUse(CAdd(Mul(m_Deferred(A), m_Deferred(A)),
>                                    Mul(m_Deferred(B), m_Deferred(B)))))),
>       A, B);
> }
>
> // Fold variations of a^2 + 2*a*b + b^2 -> (a + b)^2
> // if `FP`: requires `nsz` and `reassoc`.
> Instruction *InstCombinerImpl::foldSquareSum(BinaryOperator &I, const bool FP) {
>   if (FP) {
>     assert(I.hasAllowReassoc() && I.hasNoSignedZeros() &&
>            "Assumption mismatch");
>   }
>
>   const std::tuple<bool, Value *, Value *> Match =
>       FP ? MatchesSquareSum(I, XMul2<true>(), XCAdd<true>(), XMul<true>())
>          : MatchesSquareSum(I, XMul2<false>(), XCAdd<false>(), XMul<false>());
>
>   // if one of them matches: -> (a + b)^2
>   if (std::get<0>(Match)) {
>     Value *AB =
>         Builder.CreateFAddFMF(std::get<1>(Match), std::get<2>(Match), &I);
>     return BinaryOperator::CreateFMulFMF(AB, AB, &I);
>   }
>
>   return nullptr;
> }
> ```
How about something along the lines of:
```
template <unsigned OpcMul, unsigned OpcAdd, unsigned OpcMul2, typename Mul2Rhs>
static bool foldSquareSum(BinaryOperator &I, Mul2Rhs MRhs, Value *&AOut,
                          Value *&BOut) {
  Value *A, *B;
  bool Matches = match(
      &I,
      m_c_BinOp(OpcAdd, m_OneUse(m_BinOp(OpcMul, m_Value(A), m_Deferred(A))),
                m_OneUse(m_BinOp(
                    OpcMul,
                    m_c_BinOp(OpcAdd, m_BinOp(OpcMul2, m_Deferred(A), MRhs),
                              m_Value(B)),
                    m_Deferred(B)))));
  if (!Matches) {
    Matches = match(
        &I,
        m_c_BinOp(
            OpcAdd,
            m_CombineOr(
                m_OneUse(m_BinOp(
                    OpcMul2, m_BinOp(OpcMul, m_Value(A), m_Value(B)), MRhs)),
                m_OneUse(m_BinOp(OpcMul, m_BinOp(OpcMul2, m_Value(A), MRhs),
                                 m_Value(B)))),
            m_OneUse(
                m_c_BinOp(OpcAdd, m_BinOp(OpcMul, m_Deferred(A), m_Deferred(A)),
                          m_BinOp(OpcMul, m_Deferred(B), m_Deferred(B))))));
  }
  AOut = A;
  BOut = B;
  return Matches;
}

// Fold variations of a^2 + 2*a*b + b^2 -> (a + b)^2
Instruction *InstCombinerImpl::foldSquareSumInts(BinaryOperator &I) {
  Value *A, *B;
  bool Matches =
      foldSquareSum<Instruction::Mul, Instruction::Add, Instruction::Shl>(
          I, m_SpecificInt(1), A, B);
  // if one of them matches: -> (a + b)^2
  if (Matches) {
    Value *AB = Builder.CreateAdd(A, B);
    return BinaryOperator::CreateMul(AB, AB);
  }
  return nullptr;
}

// Fold variations of a^2 + 2*a*b + b^2 -> (a + b)^2
// Requires `nsz` and `reassoc`.
Instruction *InstCombinerImpl::foldSquareSumFloat(BinaryOperator &I) {
  Value *A, *B;
  assert(I.hasAllowReassoc() && I.hasNoSignedZeros() && "Assumption mismatch");
  bool Matches =
      foldSquareSum<Instruction::FMul, Instruction::FAdd, Instruction::FMul>(
          I, m_SpecificFP(2.0), A, B);
  // if one of them matches: -> (a + b)^2
  if (Matches) {
    Value *AB = Builder.CreateFAddFMF(A, B, &I);
    return BinaryOperator::CreateFMulFMF(AB, AB, &I);
  }
  return nullptr;
}
```
This needs comments and whatnot, but I don't see why it would fall short.
All the InstCombine tests pass with this (which I assume includes all the tests relevant to the int/FP versions of this fold).
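For intuition, the dispatch trick in the template above can be sketched in a standalone way, completely outside of LLVM; every name below is a hypothetical stand-in (not an InstCombine API), chosen only to show how one definition parameterized on a non-type template argument replaces the duplicated int/FP functions:

```cpp
#include <cassert>
#include <cmath>

// Hypothetical flavor selector, analogous to passing Instruction::Mul vs.
// Instruction::FMul as template parameters in the sketch above.
enum Opc { IntOps, FpOps };

template <typename T> struct OpsBase {
  static T mul(T a, T b) { return a * b; }
  static T add(T a, T b) { return a + b; }
};

template <Opc O, typename T> struct Ops;

// Integer flavor: "times two" is a left shift, analogous to m_Shl(A, 1).
template <typename T> struct Ops<IntOps, T> : OpsBase<T> {
  static T mul2(T a) { return a << 1; }
};

// FP flavor: "times two" is a multiply, analogous to m_FMul(A, 2.0).
template <typename T> struct Ops<FpOps, T> : OpsBase<T> {
  static T mul2(T a) { return a * T(2); }
};

// One template serves both flavors: it evaluates x^2 + 2*x*y + y^2
// using whichever operation set the caller selects.
template <Opc O, typename T> T squareSumExpanded(T x, T y) {
  using Op = Ops<O, T>;
  return Op::add(Op::add(Op::mul(x, x), Op::mul(Op::mul2(x), y)),
                 Op::mul(y, y));
}

// The contracted (x + y)^2 form, with the same flavor selection.
template <Opc O, typename T> T squareSumContracted(T x, T y) {
  using Op = Ops<O, T>;
  T S = Op::add(x, y);
  return Op::mul(S, S);
}
```

The point is that `squareSumExpanded` and `squareSumContracted` are each written once, yet cover both the shift-based integer variant and the multiply-based FP variant, mirroring how a single `foldSquareSum<OpcMul, OpcAdd, OpcMul2>` can back both `foldSquareSumInts` and `foldSquareSumFloat`.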
Repository:
rG LLVM Github Monorepo
CHANGES SINCE LAST ACTION
https://reviews.llvm.org/D158079/new/
https://reviews.llvm.org/D158079