[Mlir-commits] [mlir] [MLIR][Linalg] Add aggregate ops decomposition pass and softmax decom… (PR #97582)
    Petr Kurapov 
    llvmlistbot at llvm.org
       
    Thu Jul  4 03:22:34 PDT 2024
    
    
  
================
@@ -1317,25 +1317,21 @@ def ConvertToLoopsOp : Op<Transform_Dialect, "structured.convert_to_loops",
 def DecomposeInterfaceOp : Op<Transform_Dialect, "structured.decompose_interface",
     [FunctionalStyleTransformOpTrait,
      MemoryEffectsOpInterface,
-     TransformOpInterface,
-     TransformEachOpTrait,
+     DeclareOpInterfaceMethods<TransformOpInterface>,
      ReportTrackingListenerFailuresOpTrait]> {
   let description = [{
-    TODO
+    Decomposes high-level named ops into a sequence of non-aggregate named ops
+    via `AggregatedOpInterface`.
+
+    The operation ignores non-decomposable ops. The return handles point to
+    a sequence of named ops produced by the decomposition.
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target);
-  let results = (outs TransformHandleTypeInterface:$transformed);
+  let results = (outs Variadic<TransformHandleTypeInterface>:$transformed);
----------------
kurapov-peter wrote:
I went with binding all of the payload to a single output value (reverting the op change). This was the easiest option: it doesn't modify the original interface, and it makes usage less verbose:
```mlir
func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
  %out = tensor.empty() : tensor<2x16x32xf32>
  %1 = linalg.softmax dimension(2) ins(%arg0 : tensor<2x16x32xf32>) outs(%out: tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
  %2 = linalg.softmax dimension(1) ins(%1: tensor<2x16x32xf32>) outs(%dst: tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
  return %2 : tensor<2x16x32xf32>
}
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %2 = transform.structured.match ops{["linalg.softmax"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %3 = transform.structured.decompose_interface %2 : (!transform.any_op) -> (!transform.any_op)
    %5 = transform.structured.generalize %3: (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}
```
https://github.com/llvm/llvm-project/pull/97582
    
    