[Mlir-commits] [mlir] [MLIR][Linalg] Add aggregate ops decomposition pass and softmax decom… (PR #97582)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Mon Jul 8 13:29:16 PDT 2024


================
@@ -430,29 +430,27 @@ transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter,
 
 // Decompose the target operation if it implements the AggregatedOpInterface.
 // Push the decomposed operations (the ones that replaces the values produced by
-// \p target) in the `results`.
-DiagnosedSilenceableFailure transform::DecomposeInterfaceOp::applyToOne(
-    transform::TransformRewriter &rewriter, Operation *target,
-    transform::ApplyToEachResultList &results,
-    transform::TransformState &state) {
-  auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
-  if (!decomposableOp) {
-    failed(rewriter.notifyMatchFailure(target,
-                                       "payload is not a decomposable op"));
-    return emitDefaultSilenceableFailure(target);
-  }
+// \p target) in the `results`. Decompositions for all targets bind to the same
+// single output value, thus the information about the original targets is lost.
+DiagnosedSilenceableFailure
+transform::DecomposeInterfaceOp::apply(transform::TransformRewriter &rewriter,
+                                       TransformResults &transformResults,
+                                       TransformState &state) {
+  SmallVector<Operation *> allDecomposedOps;
+  for (auto [i, target] : llvm::enumerate(state.getPayloadOps(getTarget()))) {
+    auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
+    if (!decomposableOp)
+      continue;
 
-  FailureOr<SmallVector<Value>> maybeNewResults =
-      decomposableOp.decomposeOperation(rewriter);
-  if (failed(maybeNewResults))
-    return emitDefaultSilenceableFailure(target);
+    FailureOr<DecompositionResult> maybeNewResults =
+        decomposableOp.decomposeOperation(rewriter);
+    if (failed(maybeNewResults))
+      return emitDefaultSilenceableFailure(target);
 
-  rewriter.replaceOp(decomposableOp, *maybeNewResults);
-  for (Value val : *maybeNewResults) {
-    Operation *definition = val.getDefiningOp();
-    if (definition)
-      results.push_back(definition);
+    rewriter.replaceOp(decomposableOp, maybeNewResults->decomposedValues);
+    allDecomposedOps.append(maybeNewResults->decomposedOps);
----------------
ftynse wrote:

`transform.foreach_match` (https://mlir.llvm.org/docs/Dialects/Transform/#transformforeach_match-transformforeachmatchop) has a `flatten_results` attribute. You can `git blame` your way back to the PR that added it. The implementation there is rather long, but in this case it's just a matter of adding an attribute and a conditional, as described above.
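To make the "attribute and a conditional" idea concrete, here is a minimal, hypothetical sketch in plain C++. It deliberately avoids MLIR APIs. `OpList`, `bindDecomposedOps`, and the `flattenResults` flag are illustrative stand-ins, with strings standing in for the decomposed ops of each payload target. Assuming the flag mirrors `flatten_results`, the flattened mode concatenates everything into one list (what the diff above does unconditionally). The default mode keeps one group per target, so the association with the original targets is not lost.

```cpp
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for the ops produced by decomposing one payload target.
using OpList = std::vector<std::string>;

// Hypothetical helper: decides how decomposed ops are grouped before binding.
// - flattenResults == true: all groups are concatenated into a single list,
//   losing the per-target association (like the current diff).
// - flattenResults == false: one group per original target is preserved.
std::vector<OpList> bindDecomposedOps(const std::vector<OpList> &perTarget,
                                      bool flattenResults) {
  if (!flattenResults)
    return perTarget;

  OpList flat;
  for (const OpList &ops : perTarget)
    flat.insert(flat.end(), ops.begin(), ops.end());

  std::vector<OpList> result;
  result.push_back(std::move(flat));
  return result;
}
```

For example, with per-target groups `{"a", "b"}` and `{"c"}`, the flattened mode yields a single list of three ops, while the default mode keeps two separate groups. In the real transform op, the conditional would presumably sit where the decomposed ops are accumulated and bound to the result handle.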

https://github.com/llvm/llvm-project/pull/97582

