[Mlir-commits] [mlir] [mlir][linalg] Generic to category specialization (PR #184624)

Adam Siemieniuk llvmlistbot at llvm.org
Thu Mar 5 01:55:33 PST 2026


https://github.com/adam-smnk updated https://github.com/llvm/llvm-project/pull/184624

>From fe91256fb0796cd6e8052c410c7d998b61a259c6 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Wed, 4 Mar 2026 13:40:40 +0100
Subject: [PATCH 1/6] [mlir][linalg] Generic to category specialization

Adds initial support for the linalg generic-to-category morphism.
Only conversion to the contraction category op (linalg.contract) is supported for now.
---
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  |   7 +
 mlir/include/mlir/Dialect/Linalg/Passes.td    |   6 +-
 .../Dialect/Linalg/Transforms/Transforms.h    |  34 ++--
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      |  10 +-
 .../Dialect/Linalg/Transforms/MorphOps.cpp    |  12 +-
 .../Dialect/Linalg/Transforms/Specialize.cpp  |  85 ++++++----
 .../Linalg/specialize-generic-ops.mlir        | 150 +++++++++++++++++-
 7 files changed, 246 insertions(+), 58 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 4948bfffad5e0..5998f736ced34 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -889,6 +889,13 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
 
   let skipDefaultBuilders = 1;
   let builders = [
+    OpBuilder<
+      (ins "ValueRange":$inputs, "ValueRange":$outputs,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
+                          attributes, regionBuilder);
+      }]>,
     OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
       "ValueRange":$outputs, "ArrayAttr":$indexingMaps,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index f48ea9849e237..26638b2a644c4 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -70,8 +70,10 @@ def LinalgMorphOpsPass : Pass<"linalg-morph-ops"> {
     
     // Specialization path is not guaranteed.
     Option<"genericToNamed", "generic-to-named", "bool", /*default=*/"false",
-           "convert linalg.generic to equivalent named ops"> ];
-    //  TODOs: `generic-to-category`, `category-to-named`
+           "convert linalg.generic to equivalent named ops">,
+    Option<"genericToCategory", "generic-to-category", "bool", /*default=*/"false",
+           "convert linalg.generic to equivalent category ops"> ];
+    //  TODOs: `category-to-named`
 }
 
 def LinalgGeneralizeNamedOpsPass : Pass<"linalg-generalize-named-ops">,
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index fb9cede670801..1e63455fae096 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -923,10 +923,15 @@ FailureOr<GenericOp> interchangeGenericOp(RewriterBase &rewriter,
 FailureOr<GenericOp> generalizeNamedOp(RewriterBase &rewriter,
                                        LinalgOp linalgOp);
 
-/// Create a namedOp from the given GenericOp and replace the GenericOp.
-/// Currently we can specialize only trivial linalg copy operations.
-FailureOr<LinalgOp> specializeGenericOp(RewriterBase &rewriter,
-                                        GenericOp genericOp);
+struct SpecializationOptions {
+  // Specialize generics to category ops.
+  bool emitCategoryOps = false;
+};
+
+/// Replace the given GenericOp with a namedOp or categoryOp.
+FailureOr<LinalgOp>
+specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp,
+                    const SpecializationOptions options = {});
 
 /// Create a new buffer using the `allocationFn` provided. The size of this
 /// buffer is either the original subview size when 'useOriginalSubviewSize' is
@@ -1718,17 +1723,24 @@ struct LinalgGeneralizationPattern
 };
 
 struct LinalgSpecializationPattern : public OpRewritePattern<GenericOp> {
-  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LinalgSpecializationPattern(MLIRContext *context,
+                              const SpecializationOptions &options = {},
+                              PatternBenefit benefit = 1)
+      : OpRewritePattern<GenericOp>(context, benefit), options(options) {}
 
   FailureOr<GenericOp>
   returningMatchAndRewrite(GenericOp op, PatternRewriter &rewriter) const {
-    return specializeGenericOp(rewriter, op);
+    return specializeGenericOp(rewriter, op, options);
   }
 
   LogicalResult matchAndRewrite(GenericOp op,
                                 PatternRewriter &rewriter) const override {
     return returningMatchAndRewrite(op, rewriter);
   }
+
+private:
+  SpecializationOptions options;
 };
 
 /// Vectorization pattern for memref::CopyOp.
@@ -1938,13 +1950,13 @@ void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns);
 void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns);
 
 /// Populates `patterns` with patterns to convert linalg.generic ops to named
-/// ops where possible. A linalg.generic can represent wide range and complex
-/// computations for which equivalent linalg named op may not exist e.g.
-/// linalg.generic that takes a tensor and computes a polynomial such as:
+/// or category ops where possible. A linalg.generic can represent wide range
+/// and complex computations for which equivalent linalg named op may not exist
+/// e.g. linalg.generic that takes a tensor and computes a polynomial such as:
 ///     p(x) = an*x^n + ... + a1x + a0
-/// There is no equivalent named op to convert to. Many such cases exist.
+/// There is no equivalent ops to convert to. Many such cases exist.
 void populateLinalgGenericOpsSpecializationPatterns(
-    RewritePatternSet &patterns);
+    RewritePatternSet &patterns, const SpecializationOptions &options = {});
 
 /// Populates `patterns` that convert linalg named ops e.g. `linalg.add`
 /// to equivalent `linalg.elementwise`.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index bfc03cc7436df..67d7406987569 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -200,7 +200,10 @@ static void buildMatmulOp(OpBuilder &b, OperationState &state,
       llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
         return AffineMapAttr::get(map);
       });
-  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
+  if (none_of(attributes, [](NamedAttribute attr) {
+        return attr.getName() == "indexing_maps";
+      }))
+    state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
   return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                            attributes, regionBuilder);
 }
@@ -217,7 +220,10 @@ static void buildBatchMatmulOp(OpBuilder &b, OperationState &state,
       llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
         return AffineMapAttr::get(map);
       });
-  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
+  if (none_of(attributes, [](NamedAttribute attr) {
+        return attr.getName() == "indexing_maps";
+      }))
+    state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
   return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                            attributes, regionBuilder);
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/MorphOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/MorphOps.cpp
index f261ccb1415fe..17416b42c47ab 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/MorphOps.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/MorphOps.cpp
@@ -44,16 +44,16 @@ void LinalgMorphOpsPass::runOnOperation() {
   RewritePatternSet patterns(&getContext());
 
   // Lowering paths (named -> category -> generic)
-  if (namedToCategory) {
+  if (namedToCategory)
     populateLinalgNamedToElementwisePatterns(patterns);
-  }
-  if (namedToGeneric || categoryToGeneric) {
+  if (namedToGeneric || categoryToGeneric)
     populateLinalgNamedOpsGeneralizationPatterns(patterns);
-  }
 
   // Lifting paths (named <- category <- generic)
-  if (genericToNamed) {
-    populateLinalgGenericOpsSpecializationPatterns(patterns);
+  if (genericToNamed || genericToCategory) {
+    SpecializationOptions opts;
+    opts.emitCategoryOps = genericToCategory;
+    populateLinalgGenericOpsSpecializationPatterns(patterns, opts);
   }
 
   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index d74335e3c08c9..24a02b48427ea 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -140,21 +140,23 @@ static IndexMatchResult matchOperandMap(AffineMap map, unsigned rowDimIdx,
 template <typename NamedOpTy>
 static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op,
                                          std::optional<TypeFn> castTy) {
-  SmallVector<NamedAttribute> castAttrVec;
+  SmallVector<NamedAttribute> attributes;
   // Only explicitly specify the cast attribute for unsigned cast; signed is
   // the default for linalg.matmul/linalg.batch_matmul.
-  if (castTy.has_value() && *castTy == TypeFn::cast_unsigned)
-    castAttrVec = {rewriter.getNamedAttr(
-        "cast", TypeFnAttr::get(rewriter.getContext(), *castTy))};
+  if (castTy.has_value() && *castTy == TypeFn::cast_unsigned) {
+    auto castAttr = rewriter.getNamedAttr(
+        "cast", TypeFnAttr::get(rewriter.getContext(), *castTy));
+    attributes.push_back(castAttr);
+  }
 
-  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // Set the original generic's maps to preserve transposed operand semantics.
+  auto indexingMapsAttr =
+      rewriter.getNamedAttr("indexing_maps", op.getIndexingMapsAttr());
+  attributes.push_back(indexingMapsAttr);
 
   LinalgOp namedOp = rewriter.replaceOpWithNewOp<NamedOpTy>(
       op, ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
-      ValueRange{op.getDpsInits()[0]}, castAttrVec);
-
-  // Set the original generic's maps to preserve transposed operand semantics.
-  namedOp->setAttr("indexing_maps", indexingMaps);
+      ValueRange{op.getDpsInits()[0]}, attributes);
 
   return namedOp;
 }
@@ -210,7 +212,8 @@ static std::optional<TypeFn> getCastTypeForMatmulLikeOp(GenericOp genericOp) {
 
 // Converts linalg.generic to named linalg.*matmul* where possible.
 static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
-                                                        GenericOp genericOp) {
+                                                        GenericOp genericOp,
+                                                        bool emitCategoryOp) {
   if (genericOp.getNumDpsInputs() != 2 || genericOp.getNumDpsInits() != 1)
     return failure();
 
@@ -220,6 +223,28 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
                    [](AffineMap m) { return !m.isProjectedPermutation(); }))
     return failure();
 
+  // Only mul+add contraction is supported.
+  if (!mlir::linalg::detail::isContractionBody(
+          *genericOp.getBlock(), [](Operation *first, Operation *second) {
+            return (isa<arith::MulFOp>(first) && isa<arith::AddFOp>(second)) ||
+                   (isa<arith::MulIOp>(first) && isa<arith::AddIOp>(second)) ||
+                   (isa<complex::MulOp>(first) && isa<complex::AddOp>(second));
+          }))
+    return failure();
+
+  // Determine the cast type for the named matmul op, or bail out if casts
+  // cannot be represented by the named op.
+  std::optional<TypeFn> castTy = getCastTypeForMatmulLikeOp(genericOp);
+  if (!castTy)
+    return rewriter.notifyMatchFailure(
+        genericOp, "contains invalid cast ops for the named matmul op");
+
+  // In case of category op, wider range of representation is supported.
+  if (emitCategoryOp)
+    return replaceWithMatmulVariant<ContractOp>(rewriter, genericOp, castTy);
+
+  // Further checks for named variants.
+  //
   // Linalg generic contraction can be across multiple axis e.g.
   // ```
   //      linalg.generic
@@ -246,14 +271,6 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
   if (dims.m.size() != 1 || dims.n.size() != 1 || dims.k.size() != 1)
     return failure();
 
-  if (!mlir::linalg::detail::isContractionBody(
-          *genericOp.getBlock(), [](Operation *first, Operation *second) {
-            return (isa<arith::MulFOp>(first) && isa<arith::AddFOp>(second)) ||
-                   (isa<arith::MulIOp>(first) && isa<arith::AddIOp>(second)) ||
-                   (isa<complex::MulOp>(first) && isa<complex::AddOp>(second));
-          }))
-    return failure();
-
   // Check rank of operands
   auto indexingMaps = genericOp.getIndexingMapsArray();
   if (llvm::any_of(indexingMaps, [&dims](AffineMap m) {
@@ -292,13 +309,6 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
   if (llvm::is_contained({a, b, c}, IndexMatchResult::Mismatch))
     return failure();
 
-  // Determine the cast type for the named matmul op, or bail out if casts
-  // cannot be represented by the named op.
-  std::optional<TypeFn> castTy = getCastTypeForMatmulLikeOp(genericOp);
-  if (!castTy)
-    return rewriter.notifyMatchFailure(
-        genericOp, "contains invalid cast ops for the named matmul op");
-
   /// Codegen the different matmul variants.
   if (numOfBatchDims) {
     return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp, castTy);
@@ -406,8 +416,20 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
 //===----------------------------------------------------------------------===//
 // Categorize linalg generic to named op where possible.
 //===----------------------------------------------------------------------===//
-FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
-                                                      GenericOp genericOp) {
+FailureOr<LinalgOp>
+mlir::linalg::specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp,
+                                  const SpecializationOptions options) {
+  // Contraction - e.g. matmul
+  if (isaContractionOpInterface(genericOp)) {
+    return specializeLinalgContractions(rewriter, genericOp,
+                                        options.emitCategoryOps);
+  }
+
+  // Early exit in case of category specialization.
+  // TODO: Remove when all variants account for both named and category.
+  if (options.emitCategoryOps)
+    return failure();
+
   // Copy
   if (isaCopyOpInterface(genericOp)) {
     LinalgOp namedOp = rewriter.replaceOpWithNewOp<CopyOp>(
@@ -476,11 +498,6 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
     }
   }
 
-  // Contraction - e.g. matmul
-  if (isaContractionOpInterface(genericOp)) {
-    return specializeLinalgContractions(rewriter, genericOp);
-  }
-
   // Convolution - e.g. *conv/pooling*
   if (isaConvolutionOpInterface(genericOp)) {
     return specializeLinalgConvolutions(rewriter, genericOp);
@@ -509,6 +526,6 @@ void LinalgSpecializeGenericOpsPass::runOnOperation() {
 }
 
 void mlir::linalg::populateLinalgGenericOpsSpecializationPatterns(
-    RewritePatternSet &patterns) {
-  patterns.add<LinalgSpecializationPattern>(patterns.getContext());
+    RewritePatternSet &patterns, const SpecializationOptions &options) {
+  patterns.add<LinalgSpecializationPattern>(patterns.getContext(), options);
 }
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
index 87218844c5c39..3f62f27d33dba 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s -split-input-file --linalg-specialize-generic-ops | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -linalg-morph-ops=generic-to-named | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -linalg-morph-ops=generic-to-category | FileCheck %s --check-prefix=CATEGORY
 
 #umap = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 func.func @unary_op_exp(%A: tensor<?x?x?xf32>, %Out: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
@@ -17,6 +18,10 @@ func.func @unary_op_exp(%A: tensor<?x?x?xf32>, %Out: tensor<?x?x?xf32>) -> tenso
 // CHECK-NOT: linalg.generic
 // CHECK: linalg.exp ins(%[[A]] : tensor<?x?x?xf32>) outs(%[[Out]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
 
+// Not supported yet.
+// CATEGORY-LABEL: unary_op_exp
+// CATEGORY: linalg.generic
+
 // -----
 
 #map = affine_map<(d0, d1) -> (d0, d1)>
@@ -36,6 +41,10 @@ func.func @binary_op_div(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %Out: tensor<
 // CHECK-NOT: linalg.generic
 // CHECK: linalg.div ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
 
+// Not supported yet.
+// CATEGORY-LABEL: binary_op_div
+// CATEGORY: linalg.generic
+
 // -----
 
 ///----------------------------------------------------------------------------------------
@@ -62,6 +71,17 @@ func.func @op_matmul(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %Out: tensor<?x?x
 // CHECK-NOT: linalg.generic
 // CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
 
+// CATEGORY-DAG: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CATEGORY-DAG: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CATEGORY-DAG: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CATEGORY-LABEL: op_matmul
+// CATEGORY-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>,  %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.contract indexing_maps = {{\[}}#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]{{\]}}
+// CATEGORY-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>)
+// CATEGORY-SAME: outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
 // Cast-auditing tests: ensure we only specialize when the cast semantics can
 // be expressed by linalg.matmul, and use the cast attribute when needed.
 
@@ -84,6 +104,11 @@ func.func @op_matmul_unsigned_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi32>,
 // CHECK-NOT: linalg.generic
 // CHECK: linalg.matmul {cast = #linalg.type_fn<cast_unsigned>}
 
+// CATEGORY-LABEL: op_matmul_unsigned_cast
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.contract indexing_maps = {{\[}}#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]{{\]}}
+// CATEGORY-SAME: {cast = #linalg.type_fn<cast_unsigned>}
+
 // Ensures truncation rounding is tolerated with unsigned cases.
 // Note: We only consider casts as conflicting if they have different
 // signedness behaviours, and then we do not specialize if they do
@@ -110,6 +135,11 @@ func.func @op_matmul_unsigned_cast_and_truncate(%A: tensor<16x8xi16>, %B: tensor
 // CHECK-NOT: linalg.generic
 // CHECK: linalg.matmul {cast = #linalg.type_fn<cast_unsigned>}
 
+// CATEGORY-LABEL: op_matmul_unsigned_cast_and_truncate
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.contract indexing_maps = {{\[}}#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]{{\]}}
+// CATEGORY-SAME: {cast = #linalg.type_fn<cast_unsigned>}
+
 // Signed casts are the default, no cast attribute is required.
 func.func @op_matmul_signed_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi16>,
                                  %Out: tensor<16x32xi32>) -> tensor<16x32xi32> {
@@ -131,6 +161,11 @@ func.func @op_matmul_signed_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi16>,
 // CHECK-NOT: linalg.matmul {cast = #linalg.type_fn<cast_unsigned>}
 // CHECK: linalg.matmul
 
+// CATEGORY-LABEL: op_matmul_signed_cast
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.contract indexing_maps = {{\[}}#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]{{\]}}
+// CATEGORY-NOT: {cast =
+
 // Mixed signed/unsigned inputs cannot be encoded with a single cast attribute.
 func.func @negative_op_matmul_mixed_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi16>,
                                 %Out: tensor<16x32xi32>) -> tensor<16x32xi32> {
@@ -151,6 +186,10 @@ func.func @negative_op_matmul_mixed_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi
 // CHECK: linalg.generic
 // CHECK-NOT: linalg.matmul
 
+// CATEGORY-LABEL: negative_op_matmul_mixed_cast
+// CATEGORY: linalg.generic
+// CATEGORY-NOT: linalg.contract
+
 // Output-side casts are not representable by the named matmul ops.
 func.func @negative_op_matmul_output_cast(%A: tensor<16x8xi32>, %B: tensor<8x32xi32>,
                                  %Out: tensor<16x32xi64>) -> tensor<16x32xi64> {
@@ -171,6 +210,10 @@ func.func @negative_op_matmul_output_cast(%A: tensor<16x8xi32>, %B: tensor<8x32x
 // CHECK: linalg.generic
 // CHECK-NOT: linalg.matmul
 
+// CATEGORY-LABEL: negative_op_matmul_output_cast
+// CATEGORY: linalg.generic
+// CATEGORY-NOT: linalg.contract
+
 // Bitcasts are not modeled by the cast attribute, but should not block
 // specialization.
 // NOTE: Bitcasts are not preserved by the matmul named op during
@@ -196,6 +239,10 @@ func.func @op_matmul_bitcast_int_to_float(%A: tensor<16x8xi32>,
 // CHECK-NOT: linalg.generic
 // CHECK: linalg.matmul
 
+// CATEGORY-LABEL: op_matmul_bitcast_int_to_float
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.contract
+
 // Signed float casts only use sitofp, which defaults to signed semantics.
 func.func @op_matmul_signed_cast_float(%A: tensor<16x8xi16>, %B: tensor<8x32xi16>,
                                        %Out: tensor<16x32xf32>) -> tensor<16x32xf32> {
@@ -217,6 +264,11 @@ func.func @op_matmul_signed_cast_float(%A: tensor<16x8xi16>, %B: tensor<8x32xi16
 // CHECK-NOT: linalg.matmul {cast = #linalg.type_fn<cast_unsigned>}
 // CHECK: linalg.matmul
 
+// CATEGORY-LABEL: op_matmul_signed_cast_float
+// CATEGORY-NOT: linalg.generic
+// CATEGORY-NOT: linalg.contract{{.*}}{cast =
+// CATEGORY: linalg.contract
+
 // Unsigned float casts are expressed via uitofp and use the unsigned cast attr.
 func.func @op_matmul_unsigned_cast_float(%A: tensor<16x8xi16>, %B: tensor<8x32xi16>,
                                          %Out: tensor<16x32xf32>) -> tensor<16x32xf32> {
@@ -237,6 +289,10 @@ func.func @op_matmul_unsigned_cast_float(%A: tensor<16x8xi16>, %B: tensor<8x32xi
 // CHECK-NOT: linalg.generic
 // CHECK: linalg.matmul {cast = #linalg.type_fn<cast_unsigned>}
 
+// CATEGORY-LABEL: op_matmul_unsigned_cast_float
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.contract{{.*}}{cast = #linalg.type_fn<cast_unsigned>}
+
 // -----
 
 ///----------------------------------------------------------------------------------------
@@ -263,6 +319,16 @@ func.func @op_batch_matmul(%A: tensor<2x16x8xf32>, %B: tensor<2x8x16xf32>, %Out:
 // CHECK-NOT: linalg.generic
 // CHECK: linalg.batch_matmul ins(%[[A]], %[[B]] : tensor<2x16x8xf32>, tensor<2x8x16xf32>) outs(%[[Out]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
 
+// CATEGORY-DAG: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CATEGORY-DAG: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CATEGORY-DAG: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CATEGORY-LABEL: op_batch_matmul
+// CATEGORY-SAME: %[[A:.+]]: tensor<2x16x8xf32>, %[[B:.+]]: tensor<2x8x16xf32>,  %[[Out:.+]]: tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.contract indexing_maps = {{\[}}#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]{{\]}}
+// CATEGORY-SAME: ins(%[[A]], %[[B]] : tensor<2x16x8xf32>, tensor<2x8x16xf32>)
+// CATEGORY-SAME: outs(%[[Out]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
+
 // Ensure that the unsigned cast path for cast detection is exercised for
 // batch_matmul as well.
 func.func @op_batch_matmul_unsigned_cast(%A: tensor<2x16x8xi16>,
@@ -287,9 +353,14 @@ func.func @op_batch_matmul_unsigned_cast(%A: tensor<2x16x8xi16>,
 // CHECK-NOT: linalg.generic
 // CHECK: linalg.batch_matmul {cast = #linalg.type_fn<cast_unsigned>}
 
+// CATEGORY-LABEL: op_batch_matmul_unsigned_cast
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.contract indexing_maps = {{\[}}#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]{{\]}}
+// CATEGORY-SAME: {cast = #linalg.type_fn<cast_unsigned>}
+
 // -----
 
-// This is a multi-reduction linalg.generic and cannot be lifted to matrix multiply
+// A multi-reduction contraction.
 #mapA = affine_map<(m, n, k1, k2) -> (m, k1, k2)>
 #mapB = affine_map<(m, n, k1, k2) -> (k2, k1, n)>
 #mapC = affine_map<(m, n, k1, k2) -> (m, n)>
@@ -309,9 +380,14 @@ func.func @negative_op_multi_reduction(%A: tensor<10x20x30xf32>,
   return %0 : tensor<10x40xf32>
 }
 
+// Cannot be lifted to named matrix multiply.
 // CHECK-LABEL: negative_op_multi_reduction
 // CHECK: linalg.generic
 
+// CATEGORY-LABEL: negative_op_multi_reduction
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.contract
+
 // -----
 
 // Batch dim not in identity position: batch dim d0 appears at result
@@ -332,12 +408,17 @@ func.func @negative_batch_matmul_non_identity_batch(%A: tensor<4x2x8xf32>, %B: t
   return %0 : tensor<2x4x16xf32>
 }
 
+// Cannot be lifted to named matrix multiply.
 // CHECK-LABEL: negative_batch_matmul_non_identity_batch
 // CHECK: linalg.generic
 
+// CATEGORY-LABEL: negative_batch_matmul_non_identity_batch
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.contract
+
 // -----
 
-// TODO: matvec
+// TODO: named matvec
 #map = affine_map<(d0, d1) -> (d0, d1)>
 #map1 = affine_map<(d0, d1) -> (d1)>
 #map2 = affine_map<(d0, d1) -> (d0)>
@@ -355,6 +436,10 @@ func.func @op_matvec(%A: tensor<?x?xf32>, %B: tensor<?xf32>, %Out: tensor<?xf32>
 // CHECK-LABEL: op_matvec
 // CHECK: linalg.generic
 
+// CATEGORY-LABEL: op_matvec
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.contract
+
 // -----
 
 // Matmul transpose A: A is accessed as (k, m) instead of (m, k)
@@ -384,6 +469,10 @@ func.func @op_matmul_transpose_a(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %Out:
 // CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>)
 // CHECK-SAME: outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
 
+// CATEGORY-LABEL: op_matmul_transpose_a
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.contract
+
 // -----
 
 // Matmul transpose B: B is accessed as (n, k) instead of (k, n)
@@ -413,6 +502,17 @@ func.func @op_matmul_transpose_b(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %Out:
 // CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>)
 // CHECK-SAME: outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
 
+// CATEGORY-DAG: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CATEGORY-DAG: #[[$MAP_TB:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CATEGORY-DAG: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CATEGORY-LABEL: op_matmul_transpose_b
+// CATEGORY-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.contract
+// CATEGORY-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_TB]], #[[$MAP_C]]]
+// CATEGORY-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>)
+// CATEGORY-SAME: outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
 // -----
 
 // Batch matmul transpose A: A is accessed as (b, k, m) instead of (b, m, k)
@@ -442,6 +542,17 @@ func.func @op_batch_matmul_transpose_a(%A: tensor<2x8x4xf32>, %B: tensor<2x8x16x
 // CHECK-SAME: ins(%[[A]], %[[B]] : tensor<2x8x4xf32>, tensor<2x8x16xf32>)
 // CHECK-SAME: outs(%[[Out]] : tensor<2x4x16xf32>) -> tensor<2x4x16xf32>
 
+// CATEGORY-DAG: #[[$MAP_TA:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
+// CATEGORY-DAG: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CATEGORY-DAG: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CATEGORY-LABEL: op_batch_matmul_transpose_a
+// CATEGORY-SAME: %[[A:.+]]: tensor<2x8x4xf32>, %[[B:.+]]: tensor<2x8x16xf32>, %[[Out:.+]]: tensor<2x4x16xf32>
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.contract
+// CATEGORY-SAME: indexing_maps = [#[[$MAP_TA]], #[[$MAP_B]], #[[$MAP_C]]]
+// CATEGORY-SAME: ins(%[[A]], %[[B]] : tensor<2x8x4xf32>, tensor<2x8x16xf32>)
+// CATEGORY-SAME: outs(%[[Out]] : tensor<2x4x16xf32>) -> tensor<2x4x16xf32>
+
 // -----
 
 // Batch matmul transpose B: B is accessed as (b, n, k) instead of (b, k, n)
@@ -471,6 +582,17 @@ func.func @op_batch_matmul_transpose_b(%A: tensor<2x4x8xf32>, %B: tensor<2x16x8x
 // CHECK-SAME: ins(%[[A]], %[[B]] : tensor<2x4x8xf32>, tensor<2x16x8xf32>)
 // CHECK-SAME: outs(%[[Out]] : tensor<2x4x16xf32>) -> tensor<2x4x16xf32>
 
+// CATEGORY-DAG: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CATEGORY-DAG: #[[$MAP_TB:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// CATEGORY-DAG: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CATEGORY-LABEL: op_batch_matmul_transpose_b
+// CATEGORY-SAME: %[[A:.+]]: tensor<2x4x8xf32>, %[[B:.+]]: tensor<2x16x8xf32>, %[[Out:.+]]: tensor<2x4x16xf32>
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.contract
+// CATEGORY-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_TB]], #[[$MAP_C]]]
+// CATEGORY-SAME: ins(%[[A]], %[[B]] : tensor<2x4x8xf32>, tensor<2x16x8xf32>)
+// CATEGORY-SAME: outs(%[[Out]] : tensor<2x4x16xf32>) -> tensor<2x4x16xf32>
+
 // -----
 
 // Both A and B transposed.
@@ -501,6 +623,17 @@ func.func @op_matmul_transpose_a_and_b(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
 // CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>)
 // CHECK-SAME: outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
 
+// CATEGORY-DAG: #[[$MAP_TA:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CATEGORY-DAG: #[[$MAP_TB:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CATEGORY-DAG: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CATEGORY-LABEL: op_matmul_transpose_a_and_b
+// CATEGORY-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.contract
+// CATEGORY-SAME: indexing_maps = [#[[$MAP_TA]], #[[$MAP_TB]], #[[$MAP_C]]]
+// CATEGORY-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>)
+// CATEGORY-SAME: outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
 // -----
 
 // Output transposed: C is accessed as (n, m) instead of (m, n).
@@ -530,3 +663,14 @@ func.func @op_matmul_transposed_output(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
 // CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_TC]]]
 // CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>)
 // CHECK-SAME: outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// CATEGORY-DAG: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CATEGORY-DAG: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CATEGORY-DAG: #[[$MAP_TC:.+]] = affine_map<(d0, d1, d2) -> (d1, d0)>
+// CATEGORY-LABEL: op_matmul_transposed_output
+// CATEGORY-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.contract
+// CATEGORY-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_TC]]]
+// CATEGORY-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>)
+// CATEGORY-SAME: outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>

>From b5e6e58060cd94aca20573468f0abe4e4ef8cb20 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Wed, 4 Mar 2026 15:48:04 +0100
Subject: [PATCH 2/6] Pass opts by reference

---
 mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h | 2 +-
 mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp        | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 1e63455fae096..53d6f1b1be0c1 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -931,7 +931,7 @@ struct SpecializationOptions {
 /// Replace the given GenericOp with a namedOp or categoryOp.
 FailureOr<LinalgOp>
 specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp,
-                    const SpecializationOptions options = {});
+                    const SpecializationOptions &options = {});
 
 /// Create a new buffer using the `allocationFn` provided. The size of this
 /// buffer is either the original subview size when 'useOriginalSubviewSize' is
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 24a02b48427ea..4f12745b6e57b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -418,7 +418,7 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
 //===----------------------------------------------------------------------===//
 FailureOr<LinalgOp>
 mlir::linalg::specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp,
-                                  const SpecializationOptions options) {
+                                  const SpecializationOptions &options) {
   // Contraction - e.g. matmul
   if (isaContractionOpInterface(genericOp)) {
     return specializeLinalgContractions(rewriter, genericOp,

>From f53356eaf7de2174fabe3cf3639c466dae7c32c8 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Wed, 4 Mar 2026 16:00:51 +0100
Subject: [PATCH 3/6] Improve docs phrasing

---
 mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h | 2 +-
 mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp        | 5 +++--
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 53d6f1b1be0c1..4a0e0d4eb50b8 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1954,7 +1954,7 @@ void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns);
 /// and complex computations for which equivalent linalg named op may not exist
 /// e.g. linalg.generic that takes a tensor and computes a polynomial such as:
 ///     p(x) = an*x^n + ... + a1x + a0
-/// There is no equivalent ops to convert to. Many such cases exist.
+/// There is no equivalent named op to convert to. Many such cases exist.
 void populateLinalgGenericOpsSpecializationPatterns(
     RewritePatternSet &patterns, const SpecializationOptions &options = {});
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 4f12745b6e57b..dd3f5c060edfe 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -149,7 +149,8 @@ static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op,
     attributes.push_back(castAttr);
   }
 
-  // Set the original generic's maps to preserve transposed operand semantics.
+  // Set the original generic's maps to preserve operand indexing semantics like
+  // transposition.
   auto indexingMapsAttr =
       rewriter.getNamedAttr("indexing_maps", op.getIndexingMapsAttr());
   attributes.push_back(indexingMapsAttr);
@@ -239,7 +240,7 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
     return rewriter.notifyMatchFailure(
         genericOp, "contains invalid cast ops for the named matmul op");
 
-  // In case of category op, wider range of representation is supported.
+  // For category ops, a wider range of variants is supported.
   if (emitCategoryOp)
     return replaceWithMatmulVariant<ContractOp>(rewriter, genericOp, castTy);
 

>From db05edf0978471985706662ec112479df81aaa96 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Wed, 4 Mar 2026 16:03:16 +0100
Subject: [PATCH 4/6] Mention default opt

---
 mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 4a0e0d4eb50b8..2fc083bf7b871 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -924,7 +924,7 @@ FailureOr<GenericOp> generalizeNamedOp(RewriterBase &rewriter,
                                        LinalgOp linalgOp);
 
 struct SpecializationOptions {
-  // Specialize generics to category ops.
+  // Specialize generics to category ops (default: named ops).
   bool emitCategoryOps = false;
 };
 

>From bdc79802ba8730dffcfdc738b670a28cc262bc35 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Thu, 5 Mar 2026 10:32:19 +0100
Subject: [PATCH 5/6] Rename matching test cases

---
 .../Dialect/Linalg/specialize-generic-ops.mlir | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
index 3f62f27d33dba..92756ea90316d 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -364,9 +364,9 @@ func.func @op_batch_matmul_unsigned_cast(%A: tensor<2x16x8xi16>,
 #mapA = affine_map<(m, n, k1, k2) -> (m, k1, k2)>
 #mapB = affine_map<(m, n, k1, k2) -> (k2, k1, n)>
 #mapC = affine_map<(m, n, k1, k2) -> (m, n)>
-func.func @negative_op_multi_reduction(%A: tensor<10x20x30xf32>,
-                                       %B: tensor<30x20x40xf32>,
-                                       %C: tensor<10x40xf32>) -> tensor<10x40xf32> {
+func.func @op_multi_reduction(%A: tensor<10x20x30xf32>,
+                              %B: tensor<30x20x40xf32>,
+                              %C: tensor<10x40xf32>) -> tensor<10x40xf32> {
   %0 = linalg.generic
            {indexing_maps = [#mapA, #mapB, #mapC],
             iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
@@ -381,10 +381,10 @@ func.func @negative_op_multi_reduction(%A: tensor<10x20x30xf32>,
 }
 
 // Cannot be lifted to named matrix multiply.
-// CHECK-LABEL: negative_op_multi_reduction
+// CHECK-LABEL: op_multi_reduction
 // CHECK: linalg.generic
 
-// CATEGORY-LABEL: negative_op_multi_reduction
+// CATEGORY-LABEL: op_multi_reduction
 // CATEGORY-NOT: linalg.generic
 // CATEGORY: linalg.contract
 
@@ -395,8 +395,8 @@ func.func @negative_op_multi_reduction(%A: tensor<10x20x30xf32>,
 #mapBni0 = affine_map<(d0, d1, d2, d3) -> (d1, d0, d3)>
 #mapBni1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
 #mapBni2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-func.func @negative_batch_matmul_non_identity_batch(%A: tensor<4x2x8xf32>, %B: tensor<2x8x16xf32>,
-                                                     %Out: tensor<2x4x16xf32>) -> tensor<2x4x16xf32> {
+func.func @batch_matmul_non_identity_batch(%A: tensor<4x2x8xf32>, %B: tensor<2x8x16xf32>,
+                                           %Out: tensor<2x4x16xf32>) -> tensor<2x4x16xf32> {
   %0 = linalg.generic
            {indexing_maps = [#mapBni0, #mapBni1, #mapBni2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
            ins(%A, %B : tensor<4x2x8xf32>, tensor<2x8x16xf32>) outs(%Out : tensor<2x4x16xf32>) {
@@ -409,10 +409,10 @@ func.func @negative_batch_matmul_non_identity_batch(%A: tensor<4x2x8xf32>, %B: t
 }
 
 // Cannot be lifted to named matrix multiply.
-// CHECK-LABEL: negative_batch_matmul_non_identity_batch
+// CHECK-LABEL: batch_matmul_non_identity_batch
 // CHECK: linalg.generic
 
-// CATEGORY-LABEL: negative_batch_matmul_non_identity_batch
+// CATEGORY-LABEL: batch_matmul_non_identity_batch
 // CATEGORY-NOT: linalg.generic
 // CATEGORY: linalg.contract
 

>From 793ea77fb19a21f087fcceab6be3b76035a73e61 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Thu, 5 Mar 2026 10:55:15 +0100
Subject: [PATCH 6/6] Roundtrip test

---
 .../Linalg/roundtrip-linalg-category-ops.mlir | 101 ++++++++++++++++++
 1 file changed, 101 insertions(+)
 create mode 100644 mlir/test/Dialect/Linalg/roundtrip-linalg-category-ops.mlir

diff --git a/mlir/test/Dialect/Linalg/roundtrip-linalg-category-ops.mlir b/mlir/test/Dialect/Linalg/roundtrip-linalg-category-ops.mlir
new file mode 100644
index 0000000000000..bfecf28a33c70
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/roundtrip-linalg-category-ops.mlir
@@ -0,0 +1,101 @@
+// The following tests check examples of linalg category ops lowered to
+// linalg.generic and then lifted back up to the original category op.
+// RUN: mlir-opt %s -split-input-file -linalg-morph-ops=category-to-generic \
+// RUN: | mlir-opt -split-input-file -linalg-morph-ops=generic-to-category \
+// RUN: | FileCheck %s
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @contract_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
+    %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.contract indexing_maps = [#map, #map1, #map2]
+    ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK-DAG: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK-DAG: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: contract_matmul
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.contract indexing_maps = {{\[}}#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]{{\]}}
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SAME: outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+
+func.func @contract_matmul_memref(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
+    %arg2: memref<?x?xf32>) {
+  linalg.contract indexing_maps = [#map, #map1, #map2]
+    ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
+    outs(%arg2 : memref<?x?xf32>)
+  return
+}
+
+// CHECK-LABEL: contract_matmul_memref
+// CHECK-SAME: %[[A:.+]]: memref<?x?xf32>, %[[B:.+]]: memref<?x?xf32>,
+// CHECK-SAME: %[[Out:.+]]: memref<?x?xf32>)
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.contract indexing_maps = {{\[}}#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]{{\]}}
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<?x?xf32>, memref<?x?xf32>)
+// CHECK-SAME: outs(%[[Out]] : memref<?x?xf32>)
+
+func.func @contract_matmul_bitcast_int_to_float(%arg0: tensor<16x8xi32>,
+    %arg1: tensor<8x32xi32>, %arg2: tensor<16x32xf32>) -> tensor<16x32xf32> {
+  %0 = linalg.contract indexing_maps = [#map, #map1, #map2]
+    ins(%arg0, %arg1 : tensor<16x8xi32>, tensor<8x32xi32>)
+    outs(%arg2 : tensor<16x32xf32>) -> tensor<16x32xf32>
+  return %0 : tensor<16x32xf32>
+}
+
+// CHECK-LABEL: contract_matmul_bitcast_int_to_float
+// CHECK-SAME: %[[A:.+]]: tensor<16x8xi32>, %[[B:.+]]: tensor<8x32xi32>,
+// CHECK-SAME: %[[Out:.+]]: tensor<16x32xf32>) -> tensor<16x32xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.contract indexing_maps = {{\[}}#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]{{\]}}
+// CHECK-NOT: cast =
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<16x8xi32>, tensor<8x32xi32>)
+// CHECK-SAME: outs(%[[Out]] : tensor<16x32xf32>) -> tensor<16x32xf32>
+
+func.func @contract_matmul_unsigned_cast_float(%arg0: tensor<16x8xi16>,
+    %arg1: tensor<8x32xi16>, %arg2: tensor<16x32xf32>) -> tensor<16x32xf32> {
+  %0 = linalg.contract indexing_maps = [#map, #map1, #map2]
+    {cast = #linalg.type_fn<cast_unsigned>}
+    ins(%arg0, %arg1 : tensor<16x8xi16>, tensor<8x32xi16>)
+    outs(%arg2 : tensor<16x32xf32>) -> tensor<16x32xf32>
+  return %0 : tensor<16x32xf32>
+}
+
+// CHECK-LABEL: contract_matmul_unsigned_cast_float
+// CHECK-SAME: %[[A:.+]]: tensor<16x8xi16>, %[[B:.+]]: tensor<8x32xi16>,
+// CHECK-SAME: %[[Out:.+]]: tensor<16x32xf32>) -> tensor<16x32xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.contract indexing_maps = {{\[}}#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]{{\]}}
+// CHECK-SAME: cast = #linalg.type_fn<cast_unsigned>
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<16x8xi16>, tensor<8x32xi16>)
+// CHECK-SAME: outs(%[[Out]] : tensor<16x32xf32>) -> tensor<16x32xf32>
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2, d1)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+func.func @contract_multi_reduction(%arg0: tensor<10x20x30xf32>,
+    %arg1: tensor<30x20x40xf32>, %arg2: tensor<10x40xf32>) -> tensor<10x40xf32> {
+  %0 = linalg.contract indexing_maps = [#map, #map1, #map2]
+    ins(%arg0, %arg1 : tensor<10x20x30xf32>, tensor<30x20x40xf32>)
+    outs(%arg2 : tensor<10x40xf32>) -> tensor<10x40xf32>
+  return %0 : tensor<10x40xf32>
+}
+
+// CHECK-DAG: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// CHECK-DAG: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2, d1)>
+// CHECK-DAG: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+
+// CHECK-LABEL: contract_multi_reduction
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.contract indexing_maps = {{\[}}#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]{{\]}}



More information about the Mlir-commits mailing list