[Mlir-commits] [mlir] [mlir][linalg] Specialize transform op - emit category ops (PR #187506)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Mar 19 07:12:39 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: Adam Siemieniuk (adam-smnk)

<details>
<summary>Changes</summary>

Adds an optional `emit_category` attribute that allows specialization into category Linalg ops.

The default behavior of the transform op remains unchanged.

---
Full diff: https://github.com/llvm/llvm-project/pull/187506.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+8-6) 
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+3-1) 
- (modified) mlir/test/Dialect/Linalg/transform-op-specialize-matmul.mlir (+37) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index cd842fb1c5392..cb61177bc7533 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -588,20 +588,22 @@ def SpecializeOp : Op<Transform_Dialect, "structured.specialize",
      TransformOpInterface, TransformEachOpTrait,
      ReportTrackingListenerFailuresOpTrait]> {
   let description = [{
-    Transforms a generic operation into the equivalent named form.
+    Transforms a generic operation into the equivalent named or category form.
+
+    By default, operations are specialized into named forms.
 
     #### Return modes
 
     This operation ignores non-Linalg ops and drops them in the return. If all
     the operations referred to by the `target` handle specialize, the transform
-    succeeds; otherwise, the operation produces a silenceable failure.  The return
-    handle points to only the subset of successfully produced equivalent named
+    succeeds; otherwise, the operation produces a silenceable failure. The return
+    handle points to only the subset of successfully produced equivalent specialized
     operations, which can be empty or contain the original ops if they were already
-    in named form. The supported specialization to named Linalg operations are:
-    - linalg.copy of any rank.
+    in the target form.
   }];
 
-  let arguments = (ins TransformHandleTypeInterface:$target);
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       DefaultValuedAttr<BoolAttr, "false">:$emit_category);
   let results = (outs TransformHandleTypeInterface:$transformed);
   let assemblyFormat = [{
       $target attr-dict `:` 
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index d751488d186ad..5f530a585ddb9 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1393,8 +1393,10 @@ transform::SpecializeOp::applyToOne(transform::TransformRewriter &rewriter,
     return DiagnosedSilenceableFailure::success();
   }
   rewriter.setInsertionPoint(target);
+  GenericOpSpecializationOptions opts;
+  opts.emitCategoryOps = getEmitCategory();
   FailureOr<LinalgOp> named =
-      specializeGenericOp(rewriter, cast<GenericOp>(target));
+      specializeGenericOp(rewriter, cast<GenericOp>(target), opts);
   if (succeeded(named)) {
     results.push_back(named->getOperation());
     return DiagnosedSilenceableFailure::success();
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize-matmul.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize-matmul.mlir
index bd4c65512dfa6..a7764da4b8a66 100644
--- a/mlir/test/Dialect/Linalg/transform-op-specialize-matmul.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize-matmul.mlir
@@ -57,3 +57,40 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @category_contract(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic
+          {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+          ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+    ^bb0(%in: f32, %in_0: f32, %out: f32):
+      %0 = arith.mulf %in, %in_0 : f32
+      %1 = arith.addf %out, %0 : f32
+      linalg.yield %1 : f32
+    } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK-DAG: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK-DAG: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: @category_contract
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.contract
+// CHECK-SAME: indexing_maps = {{\[}}#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]{{\]}}
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SAME: outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.specialize %0 {emit_category = true} : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/187506


More information about the Mlir-commits mailing list