[Mlir-commits] [mlir] [MLIR][Linalg] Add aggregate ops decomposition pass and softmax decom… (PR #97582)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Fri Jul 5 07:59:13 PDT 2024


================
@@ -430,29 +430,27 @@ transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter,
 
 // Decompose the target operation if it implements the AggregatedOpInterface.
 // Push the decomposed operations (the ones that replaces the values produced by
-// \p target) in the `results`.
-DiagnosedSilenceableFailure transform::DecomposeInterfaceOp::applyToOne(
-    transform::TransformRewriter &rewriter, Operation *target,
-    transform::ApplyToEachResultList &results,
-    transform::TransformState &state) {
-  auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
-  if (!decomposableOp) {
-    failed(rewriter.notifyMatchFailure(target,
-                                       "payload is not a decomposable op"));
-    return emitDefaultSilenceableFailure(target);
-  }
+// \p target) in the `results`. Decompositions for all targets bind to the same
+// single output value, thus the information about the original targets is lost.
+DiagnosedSilenceableFailure
+transform::DecomposeInterfaceOp::apply(transform::TransformRewriter &rewriter,
+                                       TransformResults &transformResults,
+                                       TransformState &state) {
+  SmallVector<Operation *> allDecomposedOps;
+  for (auto [i, target] : llvm::enumerate(state.getPayloadOps(getTarget()))) {
+    auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
+    if (!decomposableOp)
+      continue;
 
-  FailureOr<SmallVector<Value>> maybeNewResults =
-      decomposableOp.decomposeOperation(rewriter);
-  if (failed(maybeNewResults))
-    return emitDefaultSilenceableFailure(target);
+    FailureOr<DecompositionResult> maybeNewResults =
+        decomposableOp.decomposeOperation(rewriter);
+    if (failed(maybeNewResults))
+      return emitDefaultSilenceableFailure(target);
 
-  rewriter.replaceOp(decomposableOp, *maybeNewResults);
-  for (Value val : *maybeNewResults) {
-    Operation *definition = val.getDefiningOp();
-    if (definition)
-      results.push_back(definition);
+    rewriter.replaceOp(decomposableOp, maybeNewResults->decomposedValues);
+    allDecomposedOps.append(maybeNewResults->decomposedOps);
----------------
ftynse wrote:

I'm not a fan of mixing all results together. It basically makes the result handle useless for further composition outside of the single-payload-operand case. The most recent attempt at "fixing" this is to introduce a `UnitAttr:$flatten`, the presence of which explicitly requests that the results be flattened into a single list. Its absence enables a check that the operand is associated with at most one payload. This makes the behavior visible to the user and thus less surprising. When needed, they can wrap the logic in `transform.foreach`.
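
To make the suggested semantics concrete, here is a minimal standalone sketch in plain C++. It does not use MLIR. `Payload`, `decompose`, and `applyDecompose` are hypothetical stand-ins, not names from the patch or from the transform dialect. It only models the rule: without `flatten`, a handle with more than one payload is rejected; with `flatten`, all decomposed ops are merged into one list.

```cpp
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in: a "payload op" is just a name.
using Payload = std::string;

// Hypothetical decomposition: each op is replaced by two named parts.
std::vector<std::string> decompose(const Payload &op) {
  return {op + ".part0", op + ".part1"};
}

// Models the proposed behavior of the transform op.
// Without `flatten`, the operand must be associated with at most one payload,
// so the result handle always maps back to a single known target.
// With `flatten`, the caller explicitly accepts one merged result list.
std::optional<std::vector<std::string>>
applyDecompose(const std::vector<Payload> &targets, bool flatten) {
  if (!flatten && targets.size() > 1)
    return std::nullopt; // stands in for a silenceable failure

  std::vector<std::string> results;
  for (const Payload &target : targets) {
    std::vector<std::string> parts = decompose(target);
    results.insert(results.end(), parts.begin(), parts.end());
  }
  return results;
}
```

With this shape, a caller that has several targets and still wants per-target results would invoke the single-payload form once per target, which is the role `transform.foreach` plays in the comment above.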

https://github.com/llvm/llvm-project/pull/97582

