[Mlir-commits] [mlir] [MLIR][Linalg] Add aggregate ops decomposition pass and softmax decom… (PR #97582)
Oleksandr Alex Zinenko
llvmlistbot at llvm.org
Fri Jul 5 07:59:13 PDT 2024
================
@@ -430,29 +430,27 @@ transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter,
// Decompose the target operation if it implements the AggregatedOpInterface.
// Push the decomposed operations (the ones that replaces the values produced by
-// \p target) in the `results`.
-DiagnosedSilenceableFailure transform::DecomposeInterfaceOp::applyToOne(
- transform::TransformRewriter &rewriter, Operation *target,
- transform::ApplyToEachResultList &results,
- transform::TransformState &state) {
- auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
- if (!decomposableOp) {
- failed(rewriter.notifyMatchFailure(target,
- "payload is not a decomposable op"));
- return emitDefaultSilenceableFailure(target);
- }
+// \p target) in the `results`. Decompositions for all targets bind to the same
+// single output value, thus the information about the original targets is lost.
+DiagnosedSilenceableFailure
+transform::DecomposeInterfaceOp::apply(transform::TransformRewriter &rewriter,
+ TransformResults &transformResults,
+ TransformState &state) {
+ SmallVector<Operation *> allDecomposedOps;
+ for (auto [i, target] : llvm::enumerate(state.getPayloadOps(getTarget()))) {
+ auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
+ if (!decomposableOp)
+ continue;
- FailureOr<SmallVector<Value>> maybeNewResults =
- decomposableOp.decomposeOperation(rewriter);
- if (failed(maybeNewResults))
- return emitDefaultSilenceableFailure(target);
+ FailureOr<DecompositionResult> maybeNewResults =
+ decomposableOp.decomposeOperation(rewriter);
+ if (failed(maybeNewResults))
+ return emitDefaultSilenceableFailure(target);
- rewriter.replaceOp(decomposableOp, *maybeNewResults);
- for (Value val : *maybeNewResults) {
- Operation *definition = val.getDefiningOp();
- if (definition)
- results.push_back(definition);
+ rewriter.replaceOp(decomposableOp, maybeNewResults->decomposedValues);
+ allDecomposedOps.append(maybeNewResults->decomposedOps);
----------------
ftynse wrote:
I'm not a fan of mixing all results together. It basically makes the result handle useless for further composition outside of the single-payload operand case. The most recent attempt at "fixing" this is to introduce a `UnitAttr:$flatten`, the presence of which explicitly requests that the results be flattened into a single list. When it is absent, the op checks that the operand is associated with at most one payload op. This makes the behavior visible to the user and thus less surprising. When needed, they can wrap the logic in `transform.foreach`.
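
To make the suggested behavior concrete, here is a minimal standalone sketch of the `flatten` rule. It is not MLIR code: `DecomposedOps`, `HandleResult`, and `collectResults` are hypothetical names that only model the logic with plain strings standing in for operations.

```cpp
#include <string>
#include <vector>

// Hypothetical stand-in for the ops produced by decomposing one payload op.
using DecomposedOps = std::vector<std::string>;

// Hypothetical result type: either an error message or the ops bound to the
// single result handle.
struct HandleResult {
  bool ok;
  std::string error;
  std::vector<std::string> ops;
};

// Models the proposed rule: without `flatten`, the operand handle may be
// associated with at most one payload op; with `flatten`, the results of all
// payload ops are concatenated into one list.
HandleResult collectResults(const std::vector<DecomposedOps> &perTarget,
                            bool flatten) {
  if (!flatten && perTarget.size() > 1)
    return {false, "expected at most one payload op unless 'flatten' is set",
            {}};

  HandleResult result{true, "", {}};
  for (const DecomposedOps &ops : perTarget)
    result.ops.insert(result.ops.end(), ops.begin(), ops.end());
  return result;
}
```

With this shape, the multi-payload case fails loudly by default, and a user who wants per-payload results keeps that option by iterating with `transform.foreach` instead of setting `flatten`.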
https://github.com/llvm/llvm-project/pull/97582