[Mlir-commits] [mlir] [MLIR][Linalg] Add pass to convert linalg.generic back to named ops (PR #95656)

Javed Absar llvmlistbot at llvm.org
Sat Jun 15 05:20:08 PDT 2024


https://github.com/javedabsar1 created https://github.com/llvm/llvm-project/pull/95656

The existing `-linalg-generalize-named-ops` pass lowers named ops to `linalg.generic`.

This patch adds a `--linalg-specialize-generic-ops` pass, which converts `linalg.generic` ops back to named ops where possible. It also adds patterns that recognize contractions which can be specialized from `linalg.generic` to one of the named ops:
 `linalg.{batch_}?matmul{_transpose_(a|b)}?`

Patterns that recognize elementwise unary/binary ops, fill, and copy were added previously and already exist.

>From a06814c17556ca3af1135652441bf5ee0f9c36f9 Mon Sep 17 00:00:00 2001
From: Javed Absar <javed.absar at gmail.com>
Date: Wed, 12 Jun 2024 19:14:38 -0400
Subject: [PATCH] [MLIR][Linalg] Add pass to convert linalg.generic back to
 named ops

The existing `-linalg-generalize-named-ops` pass lowers named ops to
linalg.generic. This patch adds a `--linalg-specialize-generic-ops`
pass, which converts linalg.generic back to named ops where possible.
It also adds patterns that recognize contractions which can be
specialized from linalg.generic to one of the named ops:
 `linalg.{batch_}?matmul{_transpose_(a|b)}?`
Patterns that recognize elementwise unary/binary ops, fill, and copy
were added previously and already exist.
---
 mlir/include/mlir/Dialect/Linalg/Passes.td    |   5 +
 .../Dialect/Linalg/Transforms/Transforms.h    |  23 ++
 .../Dialect/Linalg/IR/LinalgInterfaces.cpp    |   2 +-
 .../Dialect/Linalg/Transforms/Specialize.cpp  | 205 ++++++++++++++++++
 .../Linalg/roundtrip-linalg-named-ops.mlir    |  49 +++++
 .../Linalg/specialize-generic-ops.mlir        |  37 ++++
 .../transform-op-specialize_matmul.mlir       | 148 +++++++++++++
 7 files changed, 468 insertions(+), 1 deletion(-)
 create mode 100644 mlir/test/Dialect/Linalg/roundtrip-linalg-named-ops.mlir
 create mode 100644 mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
 create mode 100644 mlir/test/Dialect/Linalg/transform-op-specialize_matmul.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 0a4ce8953136d..6a60f7f3ea9f1 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -104,6 +104,17 @@ def LinalgGeneralizeNamedOpsPass : Pass<"linalg-generalize-named-ops"> {
   let dependentDialects = ["linalg::LinalgDialect"];
 }
 
+def LinalgSpecializeGenericOpsPass : Pass<"linalg-specialize-generic-ops"> {
+  let summary = "Convert generic ops back to named ops";
+  let description = [{
+    Converts `linalg.generic` ops back to equivalent named ops where
+    possible; this is the inverse of `-linalg-generalize-named-ops`.
+    Generic ops that do not correspond to a recognized named op are left
+    unchanged.
+  }];
+  let dependentDialects = ["linalg::LinalgDialect"];
+}
+
 def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInterface"> {
   let summary = "Detensorize linalg ops";
   let dependentDialects = [];
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 308ce92e35520..912f9778a40e4 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1395,6 +1395,24 @@ struct LinalgGeneralizationPattern
   }
 };
 
+/// Linalg specialization pattern: applies `specializeGenericOp` to rewrite a
+/// `linalg.generic` into an equivalent named op, and fails when none matches.
+/// The pattern is rooted on GenericOp rather than the LinalgOp interface so it
+/// is only tried on generic ops.
+struct LinalgSpecializationPattern : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  FailureOr<LinalgOp>
+  returningMatchAndRewrite(GenericOp op, PatternRewriter &rewriter) const {
+    return specializeGenericOp(rewriter, op);
+  }
+
+  LogicalResult matchAndRewrite(GenericOp op,
+                                PatternRewriter &rewriter) const override {
+    return returningMatchAndRewrite(op, rewriter);
+  }
+};
+
 /// Vectorization pattern for memref::CopyOp.
 struct CopyVectorizationPattern : public OpRewritePattern<memref::CopyOp> {
   using OpRewritePattern<memref::CopyOp>::OpRewritePattern;
@@ -1546,6 +1564,11 @@ void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns);
 /// linalg.generic ops.
 void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns);
 
+/// Populates `patterns` with patterns that convert linalg.generic ops back to
+/// equivalent named ops where possible (the inverse of generalization).
+void populateLinalgGenericOpsSpecializationPatterns(
+    RewritePatternSet &patterns);
+
 /// Linalg decompose convolutions patterns
 
 /// Populates patterns to decompose high-D convolution ops into low-D ones.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index f35ab3b856b4e..8ca76ec43193d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -107,7 +107,7 @@ isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp genericOp,
                                           unsigned arity) {
   // Check all loops are parallel, and have only tensor semantics.
   if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
-      genericOp.getNumLoops() < 1 || !genericOp.hasPureTensorSemantics())
+      genericOp.getNumLoops() < 1)
     return false;
 
   // Check there are arity-inputs, 1-output and all are identity-maps.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 2bc4d7fbfadcc..7fac3feba98c9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -11,12 +11,22 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/TypeID.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/Debug.h"
 
+namespace mlir {
+#define GEN_PASS_DEF_LINALGSPECIALIZEGENERICOPSPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
 #define DEBUG_TYPE "linalg-specialization"
 
 #define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP)                                \
@@ -58,6 +68,179 @@ static bool areBinOpsSwapped(GenericOp genericOp) {
   return swapped;
 }
 
+//===----------------------------------------------------------------------===//
+// Specialize linalg generic to matmul variants.
+//===----------------------------------------------------------------------===//
+namespace {
+/// Classifies how an operand's indexing map accesses the two non-batch
+/// (matmul) dimensions of a contraction.
+enum class IndexMatchResult {
+  Match = 0,  // identity map.
+  Transposed, // transposed map.
+  Mismatch    // none of the above.
+};
+} // namespace
+
+/// Looks at the affine map `m` of an operand and works out whether the generic
+/// accesses the element as an identity map (`Match`), transposed
+/// (`Transposed`), or neither (`Mismatch`), relative to loop dims `i` and `j`.
+/// The first `offset` (batch) results are skipped; `m` must have at least
+/// `offset + 2` results.
+static IndexMatchResult matchOperandMap(AffineMap m, unsigned offset,
+                                        unsigned i, unsigned j) {
+  auto exprI = dyn_cast<AffineDimExpr>(m.getResult(offset));
+  auto exprJ = dyn_cast<AffineDimExpr>(m.getResult(offset + 1));
+  if (!exprI || !exprJ)
+    return IndexMatchResult::Mismatch;
+
+  unsigned posI = exprI.getPosition();
+  unsigned posJ = exprJ.getPosition();
+
+  if (posI == i && posJ == j)
+    return IndexMatchResult::Match;
+
+  if (posI == j && posJ == i)
+    return IndexMatchResult::Transposed;
+
+  return IndexMatchResult::Mismatch;
+}
+
+/// Replaces `op` with the named op `Variant`, forwarding the generic's two
+/// inputs and single init. All variants of
+/// `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?` take the same
+/// operands, so one helper serves them all.
+template <typename Variant>
+static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op) {
+  LinalgOp namedOp = rewriter.replaceOpWithNewOp<Variant>(
+      op, ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
+      ValueRange{op.getDpsInits()[0]});
+  return namedOp;
+}
+
+/// Converts a linalg.generic that is essentially a named op of the form
+/// `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?` into that named op.
+///
+/// A linalg.generic may implement one of the matmul variants in a way that is
+/// not straightforward, or its per-operand affine maps may capture more
+/// semantics than the named op can express (a named op's maps are implied by
+/// its name); such generics are rejected. A named matmul variant that was
+/// 'generalized' should be convertible back here.
+static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
+                                                        GenericOp genericOp) {
+  if (genericOp.getNumDpsInputs() != 2 || genericOp.getNumDpsInits() != 1)
+    return failure();
+
+  // A linalg generic contraction can reduce across multiple dims, but the
+  // matmul variants reduce along exactly one.
+  if (genericOp.getNumReductionLoops() != 1)
+    return failure();
+
+  // All indexing maps must be projected permutations.
+  auto indexingMaps = genericOp.getIndexingMapsArray();
+  if (llvm::any_of(indexingMaps,
+                   [](AffineMap m) { return !m.isProjectedPermutation(); }))
+    return failure();
+
+  //  matmul contractions are of the form:
+  //  %0 = <elemwise>(permutation-of(cu(block-argument-0),
+  //                                 cu(block-argument-1)))
+  //  %1 = <reduce>(permutation-of(cu(%0), cu(block-argument-2)))
+  //
+  //  where <elemwise> and <reduce> are binary operations constituting a
+  //  contraction (in the canonical case, <elemwise> is a multiplication and
+  //  <reduce> is an addition). All operands of all operations may be supplied
+  //  through a chain of side effect-free unary operations, such as casts,
+  //  which is denoted as `cu` above.
+  if (!mlir::linalg::detail::isContractionBody(
+          *genericOp.getBlock(), [](Operation *first, Operation *second) {
+            return (isa<arith::MulFOp>(first) && isa<arith::AddFOp>(second)) ||
+                   (isa<arith::MulIOp>(first) && isa<arith::AddIOp>(second)) ||
+                   (isa<complex::MulOp>(first) && isa<complex::AddOp>(second));
+          }))
+    return failure();
+
+  // Finds 2 parallel (m and n) and 1 reduction (k) dimension candidates that
+  // form a matmul subcomputation. These dimensions are such that:
+  //   1. The m dimension is involved in an outer-product along LHS
+  //      (i.e. it is a permutation on RES and LHS and does not appear in RHS).
+  //   2. The n dimension is involved in an outer-product along RHS
+  //      (i.e. it is a permutation on RES and RHS and does not appear in LHS).
+  //   3. The k dimension appears as a permutation on LHS and RHS.
+  //   4. m, n and k appear only once in any given indexing.
+  //   5. Optional batch dimensions that appear in all operands are captured.
+  // The caller has already checked isaContractionOpInterface, so this is not
+  // expected to fail; bail out rather than dereference a failure regardless.
+  auto res = inferContractionDims(genericOp);
+  if (failed(res))
+    return failure();
+  auto dims = *res;
+
+  // Other than `batch`, there must be exactly one m, n and k dim for the named
+  // matmul variants.
+  if (dims.m.size() != 1 || dims.n.size() != 1 || dims.k.size() != 1)
+    return failure();
+
+  // Each operand's rank must be the number of batch dims plus two (two of
+  // {m, n, k}).
+  unsigned batchSize = dims.batch.size();
+  if (llvm::any_of(indexingMaps, [batchSize](AffineMap m) {
+        return m.getNumResults() != batchSize + 2;
+      }))
+    return failure();
+
+  // The iteration space must consist of exactly the batch dims plus m, n, k.
+  if (indexingMaps[0].getNumDims() != batchSize + 3)
+    return failure();
+
+  // Each operand in a linalg generic contraction could express a different
+  // permutation of its batch dims. The named ops have no explicit maps, so
+  // every operand must access the batch dims as a leading identity.
+  if (llvm::any_of(indexingMaps, [batchSize](AffineMap m) {
+        for (unsigned i = 0; i < batchSize; ++i) {
+          auto expr = dyn_cast<AffineDimExpr>(m.getResult(i));
+          if (!expr || expr.getPosition() != i)
+            return true;
+        }
+        return false;
+      }))
+    return failure();
+
+  auto a = matchOperandMap(indexingMaps[0], batchSize, dims.m[0], dims.k[0]);
+  auto b = matchOperandMap(indexingMaps[1], batchSize, dims.k[0], dims.n[0]);
+  auto c = matchOperandMap(indexingMaps[2], batchSize, dims.m[0], dims.n[0]);
+
+  if (llvm::any_of(ArrayRef<IndexMatchResult>{a, b, c}, [](IndexMatchResult r) {
+        return r == IndexMatchResult::Mismatch;
+      }))
+    return failure();
+
+  // The output must use the identity layout, and no named variant transposes
+  // both inputs.
+  if (c != IndexMatchResult::Match ||
+      (a == IndexMatchResult::Transposed && b == IndexMatchResult::Transposed))
+    return failure();
+
+  // Codegen the different matmul variants.
+  if (batchSize) {
+    if (a == IndexMatchResult::Transposed)
+      return replaceWithMatmulVariant<BatchMatmulTransposeAOp>(rewriter,
+                                                               genericOp);
+    if (b == IndexMatchResult::Transposed)
+      return replaceWithMatmulVariant<BatchMatmulTransposeBOp>(rewriter,
+                                                               genericOp);
+    return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp);
+  }
+
+  if (a == IndexMatchResult::Transposed)
+    return replaceWithMatmulVariant<MatmulTransposeAOp>(rewriter, genericOp);
+  if (b == IndexMatchResult::Transposed)
+    return replaceWithMatmulVariant<MatmulTransposeBOp>(rewriter, genericOp);
+  return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
+}
+
+//===----------------------------------------------------------------------===//
+// Categorize linalg generic to named op where possible.
+//===----------------------------------------------------------------------===//
 FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
                                                       GenericOp genericOp) {
   if (isaCopyOpInterface(genericOp)) {
@@ -100,5 +279,31 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
       return namedOp;
     }
   }
+
+  if (isaContractionOpInterface(genericOp))
+    return specializeLinalgContractions(rewriter, genericOp);
   return failure();
 }
+
+namespace {
+/// Greedily rewrites linalg.generic ops into equivalent named ops.
+struct LinalgSpecializeGenericOpsPass
+    : public impl::LinalgSpecializeGenericOpsPassBase<
+          LinalgSpecializeGenericOpsPass> {
+  using impl::LinalgSpecializeGenericOpsPassBase<
+      LinalgSpecializeGenericOpsPass>::LinalgSpecializeGenericOpsPassBase;
+  void runOnOperation() override;
+};
+} // namespace
+
+void LinalgSpecializeGenericOpsPass::runOnOperation() {
+  RewritePatternSet patterns(&getContext());
+  populateLinalgGenericOpsSpecializationPatterns(patterns);
+  // Convergence failure is ignored; unmatched generic ops are left as-is.
+  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+}
+
+void mlir::linalg::populateLinalgGenericOpsSpecializationPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<LinalgSpecializationPattern>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Linalg/roundtrip-linalg-named-ops.mlir b/mlir/test/Dialect/Linalg/roundtrip-linalg-named-ops.mlir
new file mode 100644
index 0000000000000..d258d9f518534
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/roundtrip-linalg-named-ops.mlir
@@ -0,0 +1,49 @@
+// RUN: mlir-opt %s -split-input-file -linalg-generalize-named-ops | mlir-opt -split-input-file --linalg-specialize-generic-ops | FileCheck %s
+
+func.func @roundtrip_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: @roundtrip_matmul
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// -----
+
+func.func @roundtrip_add(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.add ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: @roundtrip_add
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.add ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// -----
+
+func.func @roundtrip_exp(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
+  linalg.exp ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
+  return
+}
+
+// CHECK-LABEL: @roundtrip_exp
+// CHECK-SAME: %[[A:.+]]: memref<7x14x21xf32>, %[[Out:.+]]: memref<7x14x21xf32>)
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.exp ins(%[[A]] : memref<7x14x21xf32>) outs(%[[Out]] : memref<7x14x21xf32>)
+
+// -----
+
+func.func @roundtrip_gemm(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>, %arg3: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg3 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %1 = linalg.add ins(%0, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg3 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: @roundtrip_gemm
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[C:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: %[[AB:.+]] = linalg.matmul ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: linalg.add ins(%[[AB]], %[[C]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
new file mode 100644
index 0000000000000..0ec2dc3a92ec7
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -0,0 +1,61 @@
+// RUN: mlir-opt %s -split-input-file --linalg-specialize-generic-ops | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @specialize_div(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic
+         {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]}
+         ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.divf %in, %in_0 : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: @specialize_div
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.div ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// -----
+
+#umap = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func.func @specialize_exp(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = linalg.generic
+         {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel", "parallel"]}
+         ins(%arg0 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %1 = math.exp %in : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+
+// CHECK-LABEL: @specialize_exp
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.exp ins(%[[ARG0]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+
+// -----
+
+// Both inputs are transposed. No named matmul variant expresses that, so the
+// generic must be left unchanged.
+#mapA = affine_map<(d0, d1, d2) -> (d2, d0)>
+#mapB = affine_map<(d0, d1, d2) -> (d1, d2)>
+#mapC = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @negative_matmul_transpose_a_and_b(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic
+         {indexing_maps = [#mapA, #mapB, #mapC], iterator_types = ["parallel", "parallel", "reduction"]}
+         ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.mulf %in, %in_0 : f32
+    %2 = arith.addf %out, %1 : f32
+    linalg.yield %2 : f32
+  } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: @negative_matmul_transpose_a_and_b
+// CHECK-NOT: linalg.matmul
+// CHECK: linalg.generic
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize_matmul.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize_matmul.mlir
new file mode 100644
index 0000000000000..f64953bceefe1
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize_matmul.mlir
@@ -0,0 +1,148 @@
+// RUN: mlir-opt --transform-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @specialize_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic
+         {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+         ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.mulf %in, %in_0 : f32
+    %2 = arith.addf %out, %1 : f32
+    linalg.yield %2 : f32
+  } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: @specialize_matmul
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d2, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @matmul_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+  linalg.generic
+         {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+         ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>) outs(%arg2 : memref<3x7xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %0 = arith.mulf %in, %in_0 : f32
+    %1 = arith.addf %out, %0 : f32
+    linalg.yield %1 : f32
+  }
+  return
+}
+
+// CHECK-LABEL: @matmul_transpose_a
+// CHECK-SAME: %[[ARG0:.+]]: memref<5x3xf32>, %[[ARG1:.+]]: memref<5x7xf32>, %[[ARG2:.+]]: memref<3x7xf32>) {
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul_transpose_a ins(%[[ARG0]], %[[ARG1]] : memref<5x3xf32>, memref<5x7xf32>) outs(%[[ARG2]] : memref<3x7xf32>)
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @matmul_transpose_b(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic
+         {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+         ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.mulf %in, %in_0 : f32
+    %2 = arith.addf %out, %1 : f32
+    linalg.yield %2 : f32
+  } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: @matmul_transpose_b
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul_transpose_b ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+func.func @batch_matmul(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x16xf32>, %arg2: tensor<2x16x16xf32>) -> tensor<2x16x16xf32> {
+  %0 = linalg.generic
+         {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+         ins(%arg0, %arg1 : tensor<2x16x8xf32>, tensor<2x8x16xf32>) outs(%arg2 : tensor<2x16x16xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.mulf %in, %in_0 : f32
+    %2 = arith.addf %out, %1 : f32
+    linalg.yield %2 : f32
+  } -> tensor<2x16x16xf32>
+  return %0 : tensor<2x16x16xf32>
+}
+
+// CHECK-LABEL: @batch_matmul
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x16x8xf32>, %[[ARG1:.+]]: tensor<2x8x16xf32>, %[[ARG2:.+]]: tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.batch_matmul ins(%[[ARG0]], %[[ARG1]] : tensor<2x16x8xf32>, tensor<2x8x16xf32>) outs(%[[ARG2]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+func.func @batch_matmul_transpose_b(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = linalg.generic
+         {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+         ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%arg2 : tensor<?x?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.mulf %in, %in_0 : f32
+    %2 = arith.addf %out, %1 : f32
+    linalg.yield %2 : f32
+  } -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+
+// CHECK-LABEL: @batch_matmul_transpose_b
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?xf32>, %[[ARG2:.+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.batch_matmul_transpose_b ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[ARG2]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}



More information about the Mlir-commits mailing list