[Mlir-commits] [mlir] [mlir][linalg] Specialize transform op - emit category ops (PR #187506)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Mar 19 07:12:39 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg
Author: Adam Siemieniuk (adam-smnk)
<details>
<summary>Changes</summary>
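Adds an optional `emit_category` attribute to `transform.structured.specialize` that allows specializing `linalg.generic` ops into category Linalg ops (e.g. `linalg.contract`).
The attribute defaults to `false`, so the default behavior of the transform op remains unchanged.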
---
Full diff: https://github.com/llvm/llvm-project/pull/187506.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+8-6)
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+3-1)
- (modified) mlir/test/Dialect/Linalg/transform-op-specialize-matmul.mlir (+37)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index cd842fb1c5392..cb61177bc7533 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -588,20 +588,22 @@ def SpecializeOp : Op<Transform_Dialect, "structured.specialize",
TransformOpInterface, TransformEachOpTrait,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
- Transforms a generic operation into the equivalent named form.
+ Transforms a generic operation into the equivalent named or category form.
+
+ By default, operations are specialized into named forms. Setting `emit_category` enables specialization into category ops (e.g. `linalg.contract`).
#### Return modes
This operation ignores non-Linalg ops and drops them in the return. If all
the operations referred to by the `target` handle specialize, the transform
- succeeds; otherwise, the operation produces a silenceable failure. The return
- handle points to only the subset of successfully produced equivalent named
+ succeeds; otherwise, the operation produces a silenceable failure. The return
+ handle points to only the subset of successfully produced equivalent specialized
operations, which can be empty or contain the original ops if they were already
- in named form. The supported specialization to named Linalg operations are:
- - linalg.copy of any rank.
+ in the target form.
}];
- let arguments = (ins TransformHandleTypeInterface:$target);
+ let arguments = (ins TransformHandleTypeInterface:$target,
+ DefaultValuedAttr<BoolAttr, "false">:$emit_category);
let results = (outs TransformHandleTypeInterface:$transformed);
let assemblyFormat = [{
$target attr-dict `:`
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index d751488d186ad..5f530a585ddb9 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1393,8 +1393,10 @@ transform::SpecializeOp::applyToOne(transform::TransformRewriter &rewriter,
return DiagnosedSilenceableFailure::success();
}
rewriter.setInsertionPoint(target);
+ GenericOpSpecializationOptions opts;
+ opts.emitCategoryOps = getEmitCategory();
FailureOr<LinalgOp> named =
- specializeGenericOp(rewriter, cast<GenericOp>(target));
+ specializeGenericOp(rewriter, cast<GenericOp>(target), opts);
if (succeeded(named)) {
results.push_back(named->getOperation());
return DiagnosedSilenceableFailure::success();
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize-matmul.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize-matmul.mlir
index bd4c65512dfa6..a7764da4b8a66 100644
--- a/mlir/test/Dialect/Linalg/transform-op-specialize-matmul.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize-matmul.mlir
@@ -57,3 +57,40 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+#lhs_map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#rhs_map = affine_map<(d0, d1, d2) -> (d2, d1)>
+#out_map = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @category_contract(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>, %acc: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %res = linalg.generic
+    {indexing_maps = [#lhs_map, #rhs_map, #out_map], iterator_types = ["parallel", "parallel", "reduction"]}
+    ins(%lhs, %rhs : tensor<?x?xf32>, tensor<?x?xf32>) outs(%acc : tensor<?x?xf32>) {
+  ^bb0(%a: f32, %b: f32, %c: f32):
+    %mul = arith.mulf %a, %b : f32
+    %add = arith.addf %c, %mul : f32
+    linalg.yield %add : f32
+  } -> tensor<?x?xf32>
+  return %res : tensor<?x?xf32>
+}
+// With emit_category = true, the matmul-shaped generic above is specialized into linalg.contract.
+// CHECK-DAG: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK-DAG: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: @category_contract
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.contract
+// CHECK-SAME: indexing_maps = {{\[}}#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]{{\]}}
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SAME: outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %linalg_ops = transform.structured.match interface{LinalgOp} in %root : (!transform.any_op) -> !transform.any_op
+    %specialized = transform.structured.specialize %linalg_ops {emit_category = true} : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/187506
More information about the Mlir-commits mailing list