[Mlir-commits] [mlir] [MLIR][Linalg] Add aggregate ops decomposition pass and softmax decom… (PR #97582)
llvmlistbot at llvm.org
Tue Jul 23 12:04:14 PDT 2024
MaheshRavishankar wrote:
There are broadly two things we need to do to make progress here:
1) This is really two separate PRs: one that changes the softmax decomposition and one that adds a pass for decomposition. The latter should be easy to land.
2) The change to the decomposition of softmax is, IMO, an inefficient lowering of softmax and will require "something" to recover the lost state. This should be part of the PR that changes the decomposition. It moves from the more succinct representation that Linalg allows to something that is (artificially) hamstrung by the current definitions of the named ops. I don't expect the issue with named ops to be fixed as a precursor (though that would be the right thing to do IMO), but I don't see how we can land this PR without an option to choose how to decompose softmax (with the default being what it is today, and an option to lower to a sequence of named ops, as sketched below). On top of that, adding a generalization that converts everything to `linalg.generic`s is a non-starter IMO: you would force all downstream users to rely heavily on "recognizers" to retrieve the information lost by generalization, without giving them any control over when to generalize.
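To make the "sequence of named ops" alternative concrete, here is a rough sketch of what a decomposition of softmax over dimension 1 of a `tensor<?x?xf32>` could look like. This is purely illustrative and not the output of this PR; `%input`, `%max_init`, `%sum_init`, and `%empty` are assumed to be suitably initialized init tensors defined elsewhere.
```
// max-reduce along dim 1, then broadcast back to rank 2.
%max = linalg.reduce { arith.maximumf }
         ins(%input : tensor<?x?xf32>)
         outs(%max_init : tensor<?xf32>)
         dimensions = [1]
%max_bcast = linalg.broadcast
               ins(%max : tensor<?xf32>)
               outs(%empty : tensor<?x?xf32>)
               dimensions = [1]
// numerically stable exponentiation: exp(x - max).
%shifted = linalg.sub ins(%input, %max_bcast : tensor<?x?xf32>, tensor<?x?xf32>)
                      outs(%empty : tensor<?x?xf32>) -> tensor<?x?xf32>
%exp = linalg.exp ins(%shifted : tensor<?x?xf32>)
                  outs(%empty : tensor<?x?xf32>) -> tensor<?x?xf32>
// sum-reduce along dim 1, broadcast, and divide.
%sum = linalg.reduce { arith.addf }
         ins(%exp : tensor<?x?xf32>)
         outs(%sum_init : tensor<?xf32>)
         dimensions = [1]
%sum_bcast = linalg.broadcast
               ins(%sum : tensor<?xf32>)
               outs(%empty : tensor<?x?xf32>)
               dimensions = [1]
%softmax = linalg.div ins(%exp, %sum_bcast : tensor<?x?xf32>, tensor<?x?xf32>)
                      outs(%empty : tensor<?x?xf32>) -> tensor<?x?xf32>
```
Note how every broadcast has to be materialized as its own `linalg.broadcast` just to satisfy the same-rank requirement of the elementwise named ops.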
> This seems to be a problem for @MaheshRavishankar and I want to understand it better. My guess is that there are pattern matchers that won't work with the generic version of `fill` (and why we want named ops in the first place).
>
Just to map back to what I said above: we can "recognize" that it's a fill, but that seems like an unnecessary burden on downstream users, caused by generalizing too early and without any control. I can go into details about why I think "fill" is special, but that's a separate issue IMO.
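To make the `fill` point concrete, here is a small sketch (value names are hypothetical) of what generalization throws away: the named form carries its meaning in the op name, while the generalized form has to be pattern-matched back from the indexing maps and region.
```
// A named fill: trivially recognizable by its op name.
%filled = linalg.fill ins(%cst : f32) outs(%empty : tensor<?x?xf32>) -> tensor<?x?xf32>

// The generalized form: a downstream "recognizer" now has to inspect the
// indexing maps and the region to rediscover that this is just a fill.
%filled_generic = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> ()>,
                     affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]}
    ins(%cst : f32) outs(%empty : tensor<?x?xf32>) {
  ^bb0(%in : f32, %out : f32):
    linalg.yield %in : f32
} -> tensor<?x?xf32>
```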
> Regarding the `fill` issue, I think customizable decomposition would be a reasonable solution - helps preserve the downstream usage and doesn't hold back the upstream.
>
> Regarding `broadcast`, I could work on setting the semantics for implicit casting. One thing that is unclear to me though is whether having implicit cast semantics for named ops is beneficial. Wasn't the whole point of named ops to have a very explicit IR that is easy to analyze? In that regard, the absence of implicit casts is actually a good thing (I also don't see how it is ambiguous, could you please clarify?). Is there any real problem with broadcasts except for not being succinct? Wouldn't implicit casting just add unnecessary burden for analyses and transforms to handle various cases of arguments?
I want to clarify this: this is NOT implicit broadcasting. It is a completely unambiguous broadcast representation. For example, `linalg.generic` lets you represent a broadcast-add this way:
```
%2 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d0)>,
                     affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]}
    ins(%0, %1 : tensor<?x?xf32>, tensor<?xf32>) outs(%empty : tensor<?x?xf32>) {
  ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
    %3 = arith.addf %b0, %b1 : f32
    linalg.yield %3 : f32
} -> tensor<?x?xf32>
```
There is nothing ambiguous or implicit about this broadcasting. The problem with named ops is that they force all operands to have the same rank, which is an unnecessary requirement at the Linalg level. The fix is to let named ops use the broadcast representation that Linalg inherently allows. In the name of "explicit broadcasting" we have an artificial requirement that all operands be brought to the same rank, which is unnecessary IMO. It is also strictly easier to go from this representation to one that requires all operands to have the same rank (it's essentially a lowering: you break the operation up into multiple ops, as in the sketch below). Going from a representation where all operands are "broadcasted" to the same rank back to the representation above is, IMO, a lifting.
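For comparison, here is a sketch of the same broadcast-add in the "all operands at the same rank" form that named ops currently require (again illustrative; the `%empty` init tensors are assumed to be defined elsewhere): the broadcast becomes its own op and the add sees two rank-2 operands.
```
%bcast = linalg.broadcast
           ins(%1 : tensor<?xf32>)
           outs(%empty : tensor<?x?xf32>)
           dimensions = [1]
%2 = linalg.add ins(%0, %bcast : tensor<?x?xf32>, tensor<?x?xf32>)
                outs(%empty : tensor<?x?xf32>) -> tensor<?x?xf32>
```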
Actually, that brings me to a potential solution: you could take the existing lowering for softmax and then add a pass that explicitly splits out the broadcast and then generalizes. Would that get you to the state you want here?
https://github.com/llvm/llvm-project/pull/97582