[Mlir-commits] [mlir] [MLIR][Linalg] Add aggregate ops decomposition pass and softmax decom… (PR #97582)
Oleksandr Alex Zinenko
llvmlistbot at llvm.org
Mon Jul 8 13:29:16 PDT 2024
================
@@ -430,29 +430,27 @@ transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter,
// Decompose the target operation if it implements the AggregatedOpInterface.
// Push the decomposed operations (the ones that replaces the values produced by
-// \p target) in the `results`.
-DiagnosedSilenceableFailure transform::DecomposeInterfaceOp::applyToOne(
- transform::TransformRewriter &rewriter, Operation *target,
- transform::ApplyToEachResultList &results,
- transform::TransformState &state) {
- auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
- if (!decomposableOp) {
- failed(rewriter.notifyMatchFailure(target,
- "payload is not a decomposable op"));
- return emitDefaultSilenceableFailure(target);
- }
+// \p target) in the `results`. Decompositions for all targets bind to the same
+// single ouptut value, thus the information about the original targets is lost.
+DiagnosedSilenceableFailure
+transform::DecomposeInterfaceOp::apply(transform::TransformRewriter &rewriter,
+ TransformResults &transformResults,
+ TransformState &state) {
+ SmallVector<Operation *> allDecomposedOps;
+ for (auto [i, target] : llvm::enumerate(state.getPayloadOps(getTarget()))) {
+ auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
+ if (!decomposableOp)
+ continue;
- FailureOr<SmallVector<Value>> maybeNewResults =
- decomposableOp.decomposeOperation(rewriter);
- if (failed(maybeNewResults))
- return emitDefaultSilenceableFailure(target);
+ FailureOr<DecompositionResult> maybeNewResults =
+ decomposableOp.decomposeOperation(rewriter);
+ if (failed(maybeNewResults))
+ return emitDefaultSilenceableFailure(target);
- rewriter.replaceOp(decomposableOp, *maybeNewResults);
- for (Value val : *maybeNewResults) {
- Operation *definition = val.getDefiningOp();
- if (definition)
- results.push_back(definition);
+ rewriter.replaceOp(decomposableOp, maybeNewResults->decomposedValues);
+ allDecomposedOps.append(maybeNewResults->decomposedOps);
----------------
ftynse wrote:
https://mlir.llvm.org/docs/Dialects/Transform/#transformforeach_match-transformforeachmatchop has `flatten_results`. You can `git blame` your way back to the PR that added it. The implementation there is rather long, but in this case it's just a matter of having an attribute and a conditional, as described above.
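
For illustration only, a rough standalone sketch of the "attribute and a conditional" idea, in plain C++ with no MLIR dependency. Everything here is hypothetical and not taken from the PR or the transform dialect: `collectResults`, the `flattenResults` flag, and the use of strings as stand-ins for operations. It also assumes that the non-flattened mode requires exactly one produced op per target so that the mapping stays one-to-one; the actual semantics would be up to the op definition.

```cpp
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for the ops produced by decomposing one target.
using DecomposedOps = std::vector<std::string>;

// Gathers the ops produced for each target into a single result list.
// - flattenResults == true: concatenate everything; the association
//   between a produced op and its original target is lost.
// - flattenResults == false: (assumed) require exactly one op per target,
//   so the result stays one-to-one with the targets; otherwise fail.
std::optional<std::vector<std::string>>
collectResults(const std::vector<DecomposedOps> &perTarget,
               bool flattenResults) {
  std::vector<std::string> all;
  for (const DecomposedOps &ops : perTarget) {
    // The single conditional: reject non-1:1 mappings unless flattening.
    if (!flattenResults && ops.size() != 1)
      return std::nullopt;
    all.insert(all.end(), ops.begin(), ops.end());
  }
  return all;
}
```

In an actual transform op the flag would presumably be a unit attribute on the op, and the failure branch would report a silenceable failure rather than return an empty optional.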
https://github.com/llvm/llvm-project/pull/97582