[Mlir-commits] [mlir] Add category-to-named specialization to linalg-morph-ops (PR #190116)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 1 23:21:33 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: Lekkala_Sravya-mcw (LekkalaSravya3)

<details>
<summary>Changes</summary>

This PR adds support for the **category-to-named** option by implementing the [TODO](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Linalg/Passes.td#L76) in the `linalg-morph-ops` transformation pass. It specializes `linalg.elementwise` operations into named operations (e.g., `linalg.add`, `linalg.exp`) and converts `linalg.contract` into named contraction ops (e.g., `linalg.matmul`, `linalg.batch_matmul`) by reusing the existing generic-op specialization infrastructure.



---
Full diff: https://github.com/llvm/llvm-project/pull/190116.diff


6 Files Affected:

- (modified) mlir/include/mlir/Dialect/Linalg/Passes.td (+3-2) 
- (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+4) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt (+1) 
- (added) mlir/lib/Dialect/Linalg/Transforms/CategoryToNamed.cpp (+160) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/MorphOps.cpp (+2-3) 
- (added) mlir/test/Dialect/Linalg/linalg-morph-category-to-named.mlir (+211) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index b873f260e7d92..db7b342373b2e 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -67,13 +67,14 @@ def LinalgMorphOpsPass : Pass<"linalg-morph-ops"> {
            "convert category ops e.g. `linalg.elementwise` to `linalg.generic`">,
     Option<"namedToGeneric", "named-to-generic", "bool", /*default=*/"false",
            "convert named ops e.g. `linalg.add` to `linalg.generic`">,
-    
+
     // Specialization path is not guaranteed.
+    Option<"categoryToNamed", "category-to-named", "bool", /*default=*/"false",
+           "convert category ops to equivalent named ops where possible">,
     Option<"genericToNamed", "generic-to-named", "bool", /*default=*/"false",
            "convert linalg.generic to equivalent named ops">,
     Option<"genericToCategory", "generic-to-category", "bool", /*default=*/"false",
            "convert linalg.generic to equivalent category ops"> ];
-    //  TODOs: `category-to-named`
 }
 
 def LinalgGeneralizeNamedOpsPass : Pass<"linalg-generalize-named-ops">,
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 486ef75b76859..e86ba7fa59936 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1926,6 +1926,10 @@ void populateLinalgNamedToElementwisePatterns(RewritePatternSet &patterns);
 /// `linalg.transform` into elementwise op map.
 void populateLinalgFoldIntoElementwisePatterns(RewritePatternSet &patterns);
 
+/// Populates `patterns` that rewrite `linalg.elementwise` / `linalg.contract`
+/// into equivalent named ops where such a named op exists.
+void populateLinalgCategoryToNamedPatterns(RewritePatternSet &patterns);
+
 /// Linalg decompose convolutions patterns
 
 /// Populates patterns to decompose high-D convolution ops into low-D ones.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index a2149478e4c2d..38ae99cf25a96 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   AllInterfaces.cpp
   BubbleUpExtractSlice.cpp
   BufferizableOpInterfaceImpl.cpp
+  CategoryToNamed.cpp
   ConstantFold.cpp
   ConvertToDestinationStyle.cpp
   ConvertConv2DToImg2Col.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CategoryToNamed.cpp b/mlir/lib/Dialect/Linalg/Transforms/CategoryToNamed.cpp
new file mode 100644
index 0000000000000..f5c3f7374be59
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/CategoryToNamed.cpp
@@ -0,0 +1,182 @@
+//===- CategoryToNamed.cpp - convert linalg category ops into named ops ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements rewriting those linalg category ops that can be
+// represented by named ops, e.g. `linalg.elementwise<exp>` to `linalg.exp` or
+// `linalg.contract` to `linalg.matmul`.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Verifier.h"
+#include "llvm/ADT/STLExtras.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+#define DEBUG_TYPE "linalg-category-to-named"
+
+/// Returns true if every indexing map of `op` is an identity map, i.e. no
+/// operand is broadcast or transposed.
+static bool hasOnlyIdentityMaps(ElementwiseOp op) {
+  return llvm::all_of(op.getIndexingMapsArray(),
+                      [](AffineMap map) { return map.isIdentity(); });
+}
+
+/// Replaces `op` with the named op `NamedOpTy`, reusing its inputs and inits.
+///
+/// The named op is built without explicit indexing maps, i.e. with the
+/// identity maps of its definition, so `op` is only rewritten when all of its
+/// own maps are identities; otherwise a broadcast or transpose expressed by
+/// `op` would be silently dropped. Returns failure, leaving the IR unchanged,
+/// when that precondition does not hold or when the new op does not verify
+/// (e.g. because `NamedOpTy` does not accept the operand types).
+template <typename NamedOpTy>
+static FailureOr<LinalgOp> replaceElementwiseOp(ElementwiseOp op,
+                                                PatternRewriter &rewriter) {
+  if (!hasOnlyIdentityMaps(op))
+    return rewriter.notifyMatchFailure(
+        op, "non-identity indexing maps cannot be expressed by the named op");
+
+  auto namedOp = NamedOpTy::create(rewriter, op.getLoc(), op.getDpsInputs(),
+                                   op.getDpsInits(),
+                                   /*attributes=*/ArrayRef<NamedAttribute>{});
+
+  {
+    // A verification failure only means that this op has no `NamedOpTy`
+    // form, so swallow the diagnostics instead of reporting them as errors.
+    ScopedDiagnosticHandler silenceDiagnostics(op.getContext(),
+                                               [](Diagnostic &) {});
+    if (failed(verify(namedOp.getOperation()))) {
+      rewriter.eraseOp(namedOp);
+      return rewriter.notifyMatchFailure(
+          op, "elementwise op does not satisfy named op constraints");
+    }
+  }
+
+  rewriter.replaceOp(op, namedOp->getResults());
+  return cast<LinalgOp>(namedOp.getOperation());
+}
+
+/// Rewrites `op` into the named op that matches its elementwise kind, if any.
+static FailureOr<LinalgOp> specializeElementwiseOp(ElementwiseOp op,
+                                                   PatternRewriter &rewriter) {
+  switch (op.getKind()) {
+  case ElementwiseKind::select:
+    return replaceElementwiseOp<SelectOp>(op, rewriter);
+  case ElementwiseKind::add:
+    return replaceElementwiseOp<AddOp>(op, rewriter);
+  case ElementwiseKind::sub:
+    return replaceElementwiseOp<SubOp>(op, rewriter);
+  case ElementwiseKind::mul:
+    return replaceElementwiseOp<MulOp>(op, rewriter);
+  case ElementwiseKind::div:
+    return replaceElementwiseOp<DivOp>(op, rewriter);
+  case ElementwiseKind::div_unsigned:
+    return replaceElementwiseOp<DivUnsignedOp>(op, rewriter);
+  case ElementwiseKind::max_signed:
+    return replaceElementwiseOp<MaxOp>(op, rewriter);
+  case ElementwiseKind::min_signed:
+    return replaceElementwiseOp<MinOp>(op, rewriter);
+  // No named op is mapped for the unsigned variants; reported below.
+  case ElementwiseKind::max_unsigned:
+  case ElementwiseKind::min_unsigned:
+    break;
+  case ElementwiseKind::powf:
+    return replaceElementwiseOp<PowFOp>(op, rewriter);
+  case ElementwiseKind::exp:
+    return replaceElementwiseOp<ExpOp>(op, rewriter);
+  case ElementwiseKind::log:
+    return replaceElementwiseOp<LogOp>(op, rewriter);
+  case ElementwiseKind::abs:
+    return replaceElementwiseOp<AbsOp>(op, rewriter);
+  case ElementwiseKind::ceil:
+    return replaceElementwiseOp<CeilOp>(op, rewriter);
+  case ElementwiseKind::floor:
+    return replaceElementwiseOp<FloorOp>(op, rewriter);
+  case ElementwiseKind::negf:
+    return replaceElementwiseOp<NegFOp>(op, rewriter);
+  case ElementwiseKind::reciprocal:
+    return replaceElementwiseOp<ReciprocalOp>(op, rewriter);
+  case ElementwiseKind::round:
+    return replaceElementwiseOp<RoundOp>(op, rewriter);
+  case ElementwiseKind::sqrt:
+    return replaceElementwiseOp<SqrtOp>(op, rewriter);
+  case ElementwiseKind::rsqrt:
+    return replaceElementwiseOp<RsqrtOp>(op, rewriter);
+  case ElementwiseKind::square:
+    return replaceElementwiseOp<SquareOp>(op, rewriter);
+  case ElementwiseKind::tanh:
+    return replaceElementwiseOp<TanhOp>(op, rewriter);
+  case ElementwiseKind::erf:
+    return replaceElementwiseOp<ErfOp>(op, rewriter);
+  }
+
+  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
+    diag << "unsupported elementwise kind for named specialization: "
+         << stringifyElementwiseKind(op.getKind());
+  });
+}
+
+namespace {
+
+/// Converts `linalg.elementwise` into the equivalent named op, if any.
+struct ElementwiseToNamedPattern : public OpRewritePattern<ElementwiseOp> {
+  using OpRewritePattern<ElementwiseOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ElementwiseOp op,
+                                PatternRewriter &rewriter) const override {
+    return success(succeeded(specializeElementwiseOp(op, rewriter)));
+  }
+};
+
+/// Converts `linalg.contract` into a named contraction op (e.g.
+/// `linalg.matmul`) by generalizing a clone to `linalg.generic` and reusing
+/// the generic-to-named specialization.
+struct ContractToNamedPattern : public OpRewritePattern<ContractOp> {
+  using OpRewritePattern<ContractOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ContractOp op,
+                                PatternRewriter &rewriter) const override {
+    // Work on a clone so that the original op is left untouched when no named
+    // op matches; every temporary op is erased again on the failure paths.
+    Operation *clonedOp = rewriter.clone(*op.getOperation());
+    FailureOr<GenericOp> genericOp =
+        generalizeNamedOp(rewriter, cast<LinalgOp>(clonedOp));
+    if (failed(genericOp)) {
+      rewriter.eraseOp(clonedOp);
+      return rewriter.notifyMatchFailure(op, "failed to generalize contract");
+    }
+
+    GenericOpSpecializationOptions options;
+    FailureOr<LinalgOp> namedOp =
+        specializeGenericOp(rewriter, *genericOp, options);
+    if (failed(namedOp)) {
+      rewriter.eraseOp(*genericOp);
+      return rewriter.notifyMatchFailure(
+          op, "no named op matches the generalized contraction");
+    }
+
+    // For buffer semantics both ops have zero results and the named op already
+    // writes into the original buffers, so this degenerates to erasing `op`.
+    rewriter.replaceOp(op, (*namedOp)->getResults());
+    return success();
+  }
+};
+
+} // namespace
+
+/// Registers the `linalg.elementwise` and `linalg.contract` rewrites.
+void mlir::linalg::populateLinalgCategoryToNamedPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<ElementwiseToNamedPattern, ContractToNamedPattern>(
+      patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/MorphOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/MorphOps.cpp
index fee293647deda..2c7c320eff866 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/MorphOps.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/MorphOps.cpp
@@ -15,8 +15,6 @@
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
-#include "mlir/Dialect/Math/IR/Math.h"
-#include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 namespace mlir {
@@ -55,7 +53,8 @@ void LinalgMorphOpsPass::runOnOperation() {
     opts.emitCategoryOps = genericToCategory;
     populateLinalgGenericOpsSpecializationPatterns(patterns, opts);
   }
-
+  if (categoryToNamed)
+    populateLinalgCategoryToNamedPatterns(patterns);
   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
     signalPassFailure();
 }
diff --git a/mlir/test/Dialect/Linalg/linalg-morph-category-to-named.mlir b/mlir/test/Dialect/Linalg/linalg-morph-category-to-named.mlir
new file mode 100644
index 0000000000000..51a5e016bcde4
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/linalg-morph-category-to-named.mlir
@@ -0,0 +1,211 @@
+// RUN: mlir-opt %s -split-input-file -linalg-morph-ops=category-to-named | \
+// RUN:   FileCheck %s
+
+func.func @elementwise_unary(%arg0: tensor<?x?xf32>,
+    %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.elementwise kind = #linalg.elementwise_kind<exp>
+    ins(%arg0 : tensor<?x?xf32>)
+    outs(%arg1 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: @elementwise_unary
+// CHECK-SAME: %[[IN:.+]]: tensor<?x?xf32>, %[[OUT:.+]]: tensor<?x?xf32>)
+// CHECK-NOT: linalg.elementwise
+// CHECK: linalg.exp
+// CHECK-SAME: ins(%[[IN]] : tensor<?x?xf32>)
+// CHECK-SAME: outs(%[[OUT]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// -----
+
+func.func @elementwise_binary(%arg0: tensor<?x?xf32>,
+    %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.elementwise
+      kind = #linalg.elementwise_kind<powf>
+      ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: @elementwise_binary
+// CHECK-SAME: %[[LHS:.+]]: tensor<?x?xf32>, %[[RHS:.+]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[OUT:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.elementwise
+// CHECK: linalg.powf
+// CHECK-SAME: ins(%[[LHS]], %[[RHS]] : tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SAME: outs(%[[OUT]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// -----
+
+#map_a = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map_b = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map_c = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+func.func @contract_to_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
+    %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.contract indexing_maps = [#map_a, #map_b, #map_c]
+    ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: @contract_to_matmul
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[OUT:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.contract
+// CHECK: linalg.matmul
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SAME: outs(%[[OUT]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// -----
+
+#cast_map_a = affine_map<(d0, d1, d2) -> (d0, d2)>
+#cast_map_b = affine_map<(d0, d1, d2) -> (d2, d1)>
+#cast_map_c = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+func.func @contract_to_matmul_unsigned_cast(%arg0: tensor<16x8xi16>,
+    %arg1: tensor<8x32xi64>, %arg2: tensor<16x32xi32>) -> tensor<16x32xi32> {
+  %0 = linalg.contract indexing_maps = [#cast_map_a, #cast_map_b, #cast_map_c]
+      {cast = #linalg.type_fn<cast_unsigned>}
+      ins(%arg0, %arg1 : tensor<16x8xi16>, tensor<8x32xi64>)
+      outs(%arg2 : tensor<16x32xi32>) -> tensor<16x32xi32>
+  return %0 : tensor<16x32xi32>
+}
+
+// CHECK-LABEL: @contract_to_matmul_unsigned_cast
+// CHECK-SAME: %[[A:.+]]: tensor<16x8xi16>, %[[B:.+]]: tensor<8x32xi64>,
+// CHECK-SAME: %[[OUT:.+]]: tensor<16x32xi32>) -> tensor<16x32xi32>
+// CHECK-NOT: linalg.contract
+// CHECK: linalg.matmul {cast = #linalg.type_fn<cast_unsigned>}
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<16x8xi16>, tensor<8x32xi64>)
+// CHECK-SAME: outs(%[[OUT]] : tensor<16x32xi32>) -> tensor<16x32xi32>
+
+// -----
+
+#map_ta = affine_map<(d0, d1, d2) -> (d2, d0)>
+#map_tb_base = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map_tc_base = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+func.func @contract_to_matmul_transpose_a(%arg0: tensor<8x16xf32>,
+    %arg1: tensor<8x32xf32>, %arg2: tensor<16x32xf32>) -> tensor<16x32xf32> {
+  %0 = linalg.contract indexing_maps = [#map_ta, #map_tb_base, #map_tc_base]
+    ins(%arg0, %arg1 : tensor<8x16xf32>, tensor<8x32xf32>)
+    outs(%arg2 : tensor<16x32xf32>) -> tensor<16x32xf32>
+  return %0 : tensor<16x32xf32>
+}
+
+// CHECK-DAG: #[[$MAP_TA:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK-DAG: #[[$MAP_TB_BASE:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK-DAG: #[[$MAP_TC_BASE:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: @contract_to_matmul_transpose_a
+// CHECK-SAME: %[[A:.+]]: tensor<8x16xf32>, %[[B:.+]]: tensor<8x32xf32>,
+// CHECK-SAME: %[[OUT:.+]]: tensor<16x32xf32>) -> tensor<16x32xf32>
+// CHECK-NOT: linalg.contract
+// CHECK: linalg.matmul
+// CHECK-SAME: indexing_maps = [#[[$MAP_TA]], #[[$MAP_TB_BASE]], #[[$MAP_TC_BASE]]]
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<8x16xf32>, tensor<8x32xf32>)
+// CHECK-SAME: outs(%[[OUT]] : tensor<16x32xf32>) -> tensor<16x32xf32>
+
+// -----
+
+#batch_map_a = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#batch_map_b = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#batch_map_c = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+func.func @contract_to_batch_matmul(%arg0: tensor<2x16x8xf32>,
+    %arg1: tensor<2x8x16xf32>, %arg2: tensor<2x16x16xf32>)
+    -> tensor<2x16x16xf32> {
+  %0 = linalg.contract indexing_maps = [#batch_map_a, #batch_map_b, #batch_map_c]
+    ins(%arg0, %arg1 : tensor<2x16x8xf32>, tensor<2x8x16xf32>)
+    outs(%arg2 : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
+  return %0 : tensor<2x16x16xf32>
+}
+
+// CHECK-LABEL: @contract_to_batch_matmul
+// CHECK-SAME: %[[A:.+]]: tensor<2x16x8xf32>, %[[B:.+]]: tensor<2x8x16xf32>,
+// CHECK-SAME: %[[OUT:.+]]: tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
+// CHECK-NOT: linalg.contract
+// CHECK: linalg.batch_matmul
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<2x16x8xf32>, tensor<2x8x16xf32>)
+// CHECK-SAME: outs(%[[OUT]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
+
+// -----
+
+#batch_map_ta = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
+#batch_map_tb = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#batch_map_tc = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+func.func @contract_to_batch_matmul_transpose_a(%arg0: tensor<2x8x16xf32>,
+    %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x16x32xf32>)
+    -> tensor<2x16x32xf32> {
+  %0 = linalg.contract indexing_maps = [#batch_map_ta, #batch_map_tb, #batch_map_tc]
+    ins(%arg0, %arg1 : tensor<2x8x16xf32>, tensor<2x8x32xf32>)
+    outs(%arg2 : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
+  return %0 : tensor<2x16x32xf32>
+}
+
+// CHECK-DAG: #[[$BATCH_MAP_TA:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
+// CHECK-DAG: #[[$BATCH_MAP_TB:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK-DAG: #[[$BATCH_MAP_TC:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK-LABEL: @contract_to_batch_matmul_transpose_a
+// CHECK-SAME: %[[A:.+]]: tensor<2x8x16xf32>, %[[B:.+]]: tensor<2x8x32xf32>,
+// CHECK-SAME: %[[OUT:.+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
+// CHECK-NOT: linalg.contract
+// CHECK: linalg.batch_matmul
+// CHECK-SAME: indexing_maps = [#[[$BATCH_MAP_TA]], #[[$BATCH_MAP_TB]], #[[$BATCH_MAP_TC]]]
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<2x8x16xf32>, tensor<2x8x32xf32>)
+// CHECK-SAME: outs(%[[OUT]] : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
+
+// -----
+
+#non_identity_batch_a = affine_map<(d0, d1, d2, d3) -> (d1, d0, d3)>
+#non_identity_batch_b = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#non_identity_batch_c = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// This stays as `linalg.contract` because the batch dimension is
+// not in identity position across the operand/result maps. Named
+// `linalg.batch_matmul` does not model such non-identity batch permutations.
+func.func @contract_non_identity_batch(%arg0: tensor<4x2x8xf32>,
+    %arg1: tensor<2x8x16xf32>, %arg2: tensor<2x4x16xf32>)
+    -> tensor<2x4x16xf32> {
+  %0 = linalg.contract indexing_maps = [#non_identity_batch_a, #non_identity_batch_b, #non_identity_batch_c]
+    ins(%arg0, %arg1 : tensor<4x2x8xf32>, tensor<2x8x16xf32>)
+    outs(%arg2 : tensor<2x4x16xf32>) -> tensor<2x4x16xf32>
+  return %0 : tensor<2x4x16xf32>
+}
+
+// CHECK-DAG: #[[$NON_ID_BATCH_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d0, d3)>
+// CHECK-DAG: #[[$NON_ID_BATCH_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK-DAG: #[[$NON_ID_BATCH_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK-LABEL: @contract_non_identity_batch
+// CHECK-SAME: %[[A:.+]]: tensor<4x2x8xf32>, %[[B:.+]]: tensor<2x8x16xf32>,
+// CHECK-SAME: %[[OUT:.+]]: tensor<2x4x16xf32>) -> tensor<2x4x16xf32>
+// CHECK: linalg.contract
+// CHECK-SAME: indexing_maps = [#[[$NON_ID_BATCH_A]], #[[$NON_ID_BATCH_B]], #[[$NON_ID_BATCH_C]]]
+// CHECK-NOT: linalg.batch_matmul
+
+// -----
+
+#map_d = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#map_e = affine_map<(d0, d1, d2, d3) -> (d3, d2, d1)>
+#map_f = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+
+// This stays as `linalg.contract` because it has two reduction
+// dimensions. Named matmul-like ops require exactly one M dim, one N dim, and
+// one K dim.
+func.func @contract_multi_reduction(%arg0: tensor<10x20x30xf32>,
+    %arg1: tensor<30x20x40xf32>,
+    %arg2: tensor<10x40xf32>) -> tensor<10x40xf32> {
+  %0 = linalg.contract indexing_maps = [#map_d, #map_e, #map_f]
+    ins(%arg0, %arg1 : tensor<10x20x30xf32>, tensor<30x20x40xf32>)
+    outs(%arg2 : tensor<10x40xf32>) -> tensor<10x40xf32>
+  return %0 : tensor<10x40xf32>
+}
+
+// CHECK-LABEL: @contract_multi_reduction
+// CHECK-NOT: linalg.matmul
+// CHECK: linalg.contract
+// CHECK-SAME: indexing_maps = [#{{.+}}, #{{.+}}, #{{.+}}]

``````````

</details>


https://github.com/llvm/llvm-project/pull/190116


More information about the Mlir-commits mailing list