[Mlir-commits] [mlir] [MLIR][Linalg] Add aggregate ops decomposition pass and softmax decom… (PR #97582)

Petr Kurapov llvmlistbot at llvm.org
Thu Jul 4 03:22:34 PDT 2024


================
@@ -1317,25 +1317,21 @@ def ConvertToLoopsOp : Op<Transform_Dialect, "structured.convert_to_loops",
 def DecomposeInterfaceOp : Op<Transform_Dialect, "structured.decompose_interface",
     [FunctionalStyleTransformOpTrait,
      MemoryEffectsOpInterface,
-     TransformOpInterface,
-     TransformEachOpTrait,
+     DeclareOpInterfaceMethods<TransformOpInterface>,
      ReportTrackingListenerFailuresOpTrait]> {
   let description = [{
-    TODO
+    Decomposes high-level named ops into a sequence of non-aggregate named ops
+    via `AggregatedOpInterface`.
+
+    The operation ignores non-decomposable ops. The return handles point to
+    a sequence of named ops produced by the decomposition.
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target);
-  let results = (outs TransformHandleTypeInterface:$transformed);
+  let results = (outs Variadic<TransformHandleTypeInterface>:$transformed);
----------------
kurapov-peter wrote:

I went with binding all of the resulting payload ops to a single output handle (reverting the op change). This was the easiest option: it doesn't modify the original interface, and usage is less verbose:

```mlir
func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
  %out = tensor.empty() : tensor<2x16x32xf32>
  %1 = linalg.softmax dimension(2) ins(%arg0 : tensor<2x16x32xf32>) outs(%out: tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
  %2 = linalg.softmax dimension(1) ins(%1: tensor<2x16x32xf32>) outs(%dst: tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
  return %2 : tensor<2x16x32xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %2 = transform.structured.match ops{["linalg.softmax"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %3 = transform.structured.decompose_interface %2 : (!transform.any_op) -> (!transform.any_op)
    %5 = transform.structured.generalize %3 : (!transform.any_op) -> !transform.any_op

    transform.yield
  }
}
```
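For background on what the softmax case breaks down into, here is a plain-Python sketch of the usual numerically stable softmax decomposition. It has four steps: a max reduction, an elementwise subtract-and-exp, a sum reduction, and an elementwise division. It is only an illustration of the math, not the PR's C++ implementation or the exact IR it emits. All function names are hypothetical. `softmax_last_dim` mimics `dimension(2)` on a rank-3 tensor by applying the 1-D version along the innermost axis of nested lists.

```python
import math


def softmax_1d(x):
    """Softmax over a 1-D list, written as four separate (non-aggregate) steps."""
    # Step 1: max reduction, subtracted later for numerical stability.
    m = max(x)
    # Step 2: elementwise sub + exp gives the numerator.
    num = [math.exp(v - m) for v in x]
    # Step 3: sum reduction gives the denominator.
    den = sum(num)
    # Step 4: elementwise division.
    return [n / den for n in num]


def softmax_naive(x):
    """Direct formula without the max step (overflows for large inputs)."""
    e = [math.exp(v) for v in x]
    s = sum(e)
    return [v / s for v in e]


def softmax_last_dim(t):
    """Apply softmax_1d along the innermost dimension of a nested list.

    For a rank-3 input this corresponds to softmax with dimension(2).
    """
    if t and isinstance(t[0], list):
        return [softmax_last_dim(sub) for sub in t]
    return softmax_1d(t)


if __name__ == "__main__":
    print(softmax_1d([1.0, 2.0, 3.0]))
    print(softmax_1d([1000.0, 1000.0]))  # the max step keeps exp() in range
```

Subtracting the max before exponentiating does not change the result mathematically. It does keep `math.exp` from overflowing, which is why the decomposition includes a separate max reduction instead of going straight to exp/sum/div.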

https://github.com/llvm/llvm-project/pull/97582

