[Mlir-commits] [mlir] [MLIR][Linalg] Add aggregate ops decomposition pass and softmax decom… (PR #97582)
Petr Kurapov
llvmlistbot at llvm.org
Thu Jul 4 03:22:34 PDT 2024
================
@@ -1317,25 +1317,21 @@ def ConvertToLoopsOp : Op<Transform_Dialect, "structured.convert_to_loops",
def DecomposeInterfaceOp : Op<Transform_Dialect, "structured.decompose_interface",
[FunctionalStyleTransformOpTrait,
MemoryEffectsOpInterface,
- TransformOpInterface,
- TransformEachOpTrait,
+ DeclareOpInterfaceMethods<TransformOpInterface>,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
- TODO
+ Decomposes high-level named ops into a sequence of non-aggregate named ops
+ via `AggregatedOpInterface`.
+
+ The operation ignores non-decomposable ops. The return handles point to
+ a sequence of named ops produced by the decomposition.
}];
let arguments = (ins TransformHandleTypeInterface:$target);
- let results = (outs TransformHandleTypeInterface:$transformed);
+ let results = (outs Variadic<TransformHandleTypeInterface>:$transformed);
----------------
kurapov-peter wrote:
I went with binding all of the payload to a single output value (reverting the op change). This was the easiest option: it doesn't modify the original interface, and the usage is less verbose:
```mlir
func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
  %out = tensor.empty() : tensor<2x16x32xf32>
  %1 = linalg.softmax dimension(2) ins(%arg0 : tensor<2x16x32xf32>) outs(%out : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
  %2 = linalg.softmax dimension(1) ins(%1 : tensor<2x16x32xf32>) outs(%dst : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
  return %2 : tensor<2x16x32xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %2 = transform.structured.match ops{["linalg.softmax"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %3 = transform.structured.decompose_interface %2 : (!transform.any_op) -> (!transform.any_op)
    %5 = transform.structured.generalize %3 : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}
```
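The single-output choice can be pictured with a small toy model in plain Python. This is only a conceptual sketch, not MLIR API: the `DECOMPOSITIONS` table, the step names in it, and the `decompose_interface` function are hypothetical stand-ins, and the listed steps are illustrative rather than the exact ops the pass emits. It shows every op produced from every matched target being collected into one flat list (the single handle), with non-decomposable targets skipped.

```python
# Toy model only: ops are represented by their names as plain strings.
# The decomposition table is a hypothetical illustration, not the
# actual sequence produced for linalg.softmax.
DECOMPOSITIONS = {
    "linalg.softmax": ["max", "sub", "exp", "sum", "div"],
}


def decompose_interface(targets):
    """Return one flat list (a single 'handle') of all produced ops."""
    produced = []
    for op in targets:
        steps = DECOMPOSITIONS.get(op)
        if steps is None:
            # Non-decomposable ops are ignored, as the op description states.
            continue
        # Results of every target are appended to the same output list.
        produced.extend(steps)
    return produced


if __name__ == "__main__":
    handle = decompose_interface(
        ["linalg.softmax", "linalg.matmul", "linalg.softmax"]
    )
    print(handle)
```

In this model a follow-up step such as generalization receives the one list and visits every produced op, which mirrors how `%3` is passed directly to `transform.structured.generalize` in the example above.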
https://github.com/llvm/llvm-project/pull/97582