[Mlir-commits] [mlir] [MLIR][Linalg] Add pass to convert linalg.generic back to named ops (PR #95656)
Javed Absar
llvmlistbot at llvm.org
Fri Jun 28 16:03:48 PDT 2024
https://github.com/javedabsar1 updated https://github.com/llvm/llvm-project/pull/95656
>From a06814c17556ca3af1135652441bf5ee0f9c36f9 Mon Sep 17 00:00:00 2001
From: Javed Absar <javed.absar at gmail.com>
Date: Wed, 12 Jun 2024 19:14:38 -0400
Subject: [PATCH 1/5] [MLIR][Linalg] Add pass to convert linalg.generic back to
named ops
Existing `-linalg-generalize-named-ops` lowers named ops to
linalg.generic. This patch adds `--linalg-specialize-generic-ops`,
which converts linalg.generic back to named ops where possible.
It also adds patterns to recognize contractions that can be
specialized from linalg.generic to the named ops
`linalg.{batch_}?matmul{_transpose_(a|b)}?`.
Patterns to recognize elementwise unary/binary ops, fill and copy
were added previously and already exist.
---
mlir/include/mlir/Dialect/Linalg/Passes.td | 5 +
.../Dialect/Linalg/Transforms/Transforms.h | 23 ++
.../Dialect/Linalg/IR/LinalgInterfaces.cpp | 2 +-
.../Dialect/Linalg/Transforms/Specialize.cpp | 205 ++++++++++++++++++
.../Linalg/roundtrip-linalg-named-ops.mlir | 49 +++++
.../Linalg/specialize-generic-ops.mlir | 37 ++++
.../transform-op-specialize_matmul.mlir | 148 +++++++++++++
7 files changed, 468 insertions(+), 1 deletion(-)
create mode 100644 mlir/test/Dialect/Linalg/roundtrip-linalg-named-ops.mlir
create mode 100644 mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
create mode 100644 mlir/test/Dialect/Linalg/transform-op-specialize_matmul.mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 0a4ce8953136d..6a60f7f3ea9f1 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -104,6 +104,11 @@ def LinalgGeneralizeNamedOpsPass : Pass<"linalg-generalize-named-ops"> {
let dependentDialects = ["linalg::LinalgDialect"];
}
+// Counterpart of `linalg-generalize-named-ops`: rewrites linalg.generic ops
+// into equivalent named ops where one exists (elementwise unary/binary ops,
+// fill, copy and matmul variants); generics without a named equivalent are
+// left unchanged.
+def LinalgSpecializeGenericOpsPass : Pass<"linalg-specialize-generic-ops"> {
+ let summary = "Convert generic ops back to named ops";
+ let dependentDialects = ["linalg::LinalgDialect"];
+}
+
def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInterface"> {
let summary = "Detensorize linalg ops";
let dependentDialects = [];
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 308ce92e35520..912f9778a40e4 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1395,6 +1395,24 @@ struct LinalgGeneralizationPattern
}
};
+/// Rewrite pattern that turns a linalg.generic into an equivalent named op by
+/// delegating to `specializeGenericOp`. It is registered on the LinalgOp
+/// interface, so any LinalgOp that is not a linalg.generic is rejected first.
+struct LinalgSpecializationPattern
+ : public OpInterfaceRewritePattern<LinalgOp> {
+ using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
+
+ /// Returns the newly created named op on success; fails if `op` is not a
+ /// linalg.generic or if `specializeGenericOp` finds no named equivalent.
+ FailureOr<LinalgOp>
+ returningMatchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const {
+ auto genericOp = dyn_cast<GenericOp>(op.getOperation());
+ if (!genericOp)
+ return failure();
+ return specializeGenericOp(rewriter, genericOp);
+ }
+
+ /// Pattern-driver entry point; only success/failure is reported.
+ LogicalResult matchAndRewrite(LinalgOp op,
+ PatternRewriter &rewriter) const override {
+ return returningMatchAndRewrite(op, rewriter);
+ }
+};
+
/// Vectorization pattern for memref::CopyOp.
struct CopyVectorizationPattern : public OpRewritePattern<memref::CopyOp> {
using OpRewritePattern<memref::CopyOp>::OpRewritePattern;
@@ -1546,6 +1564,11 @@ void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns);
/// linalg.generic ops.
void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns);
+/// Populates `patterns` with patterns to convert linalg.generic ops to named
+/// ops where possible.
+void populateLinalgGenericOpsSpecializationPatterns(
+ RewritePatternSet &patterns);
+
/// Linalg decompose convolutions patterns
/// Populates patterns to decompose high-D convolution ops into low-D ones.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index f35ab3b856b4e..8ca76ec43193d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -107,7 +107,7 @@ isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp genericOp,
unsigned arity) {
// Check all loops are parallel, and have only tensor semantics.
if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
- genericOp.getNumLoops() < 1 || !genericOp.hasPureTensorSemantics())
+ genericOp.getNumLoops() < 1)
return false;
// Check there are arity-inputs, 1-output and all are identity-maps.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 2bc4d7fbfadcc..7fac3feba98c9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -11,12 +11,22 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/TypeID.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"
+namespace mlir {
+#define GEN_PASS_DEF_LINALGSPECIALIZEGENERICOPSPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
#define DEBUG_TYPE "linalg-specialization"
#define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP) \
@@ -58,6 +68,175 @@ static bool areBinOpsSwapped(GenericOp genericOp) {
return swapped;
}
+//===----------------------------------------------------------------------===//
+// Specialize linalg generic to matmul variants.
+//===----------------------------------------------------------------------===//
+/// Identifies linalg.generic that is essentially named op of the form:
+// ` linalg.{batch_}?matmul{_transpose_a | _transpose_b}? `
+//
+// It is possible that a linalg.generic may be implementing one of matmul
+// variants but not in a straight-forward way, or the linalg.generic's
+// affine map per operand capture more semantics than is possible with
+// named op (which has implicit map interpreted via name).
+//
+// But a named linalg matmul variant that was 'generalized' should be
+// convertible back to named op here.
+//
+namespace {
+/// How an operand's two non-batch map results relate to an expected pair
+/// (i, j) of loop dimensions: in order, swapped, or neither.
+enum class IndexMatchResult {
+ Match = 0, // identity map.
+ Transposed, // transposed map.
+ Mismatch // none of the above.
+};
+
+// Looks at the affine map of an operand and works out if generic accesses
+// the element as identity-map, transposed, or 'cant work out'.
+// This check skips the `offset` batch indices and focuses on the matmul part.
+// Precondition: `m` has at least `offset + 2` results; the caller rejects
+// maps whose result count differs from (#batch dims + 2) before calling.
+static IndexMatchResult matchOperandMap(AffineMap m, unsigned offset,
+ unsigned i, unsigned j) {
+ auto expr_ei = dyn_cast<AffineDimExpr>(m.getResults()[offset]);
+ auto expr_ej = dyn_cast<AffineDimExpr>(m.getResults()[offset + 1]);
+ if (!expr_ei || !expr_ej)
+ return IndexMatchResult::Mismatch;
+
+ auto ei = expr_ei.getPosition();
+ auto ej = expr_ej.getPosition();
+
+ if (ei == i && ej == j)
+ return IndexMatchResult::Match;
+
+ if (ei == j && ej == i)
+ return IndexMatchResult::Transposed;
+
+ return IndexMatchResult::Mismatch;
+}
+
+// All the variants `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?`
+// have same number of input/output.
+// Replaces `op` with a `Variant` built from its first two DPS inputs and its
+// first DPS init; the caller has already checked there are exactly 2 and 1.
+template <typename Variant>
+static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op) {
+ LinalgOp namedOp = rewriter.replaceOpWithNewOp<Variant>(
+ op, ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
+ ValueRange{op.getDpsInits()[0]});
+ return namedOp;
+}
+
+// Converts linalg.generic to named linalg.*matmul* where possible.
+// Returns the replacing named op, or failure if the generic does not match
+// one of the supported matmul/batch_matmul (optionally transposed) forms.
+static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
+ GenericOp genericOp) {
+ if (genericOp.getNumDpsInputs() != 2 || genericOp.getNumDpsInits() != 1)
+ return failure();
+
+ // Linalg generic contraction can be across multiple axis but for matmul
+ // variants it must be one.
+ if (genericOp.getNumReductionLoops() != 1)
+ return failure();
+
+ // Must be projected permutations.
+ auto mapRange = genericOp.getIndexingMapsArray();
+ if (llvm::any_of(mapRange,
+ [](AffineMap m) { return !m.isProjectedPermutation(); }))
+ return failure();
+
+ // matmul contractions are of the form:
+ // %0 = <elemwise>(permutation-of(cu(block-argument-0),
+ // cu(block-argument-1)))
+ // %1 = <reduce>(permutation-of(cu(%0), cu(block-argument-2)))
+ //
+ // where <elemwise> and <reduce> are binary operations constituting a
+ // contraction (in the canonical case, <elemwise> is a multiplication and
+ // <reduce> is an addition). All operands of all operations may be supplied
+ // through a chain of side effect-free unary operations, such as casts,
+ // which is denoted as `cu` above.
+ if (!mlir::linalg::detail::isContractionBody(
+ *genericOp.getBlock(), [](Operation *first, Operation *second) {
+ if ((isa<arith::MulFOp>(first) && isa<arith::AddFOp>(second)) ||
+ (isa<arith::MulIOp>(first) && isa<arith::AddIOp>(second)) ||
+ (isa<complex::MulOp>(first) && isa<complex::AddOp>(second)))
+ return true;
+ return false;
+ }))
+ return failure();
+
+ // Finds 2 parallel (m and n) and 1 reduction (k) dimension candidates that
+ // form a matmul subcomputation. These dimensions are such that:
+ // 1. The m dimension is involved in an outer-product along LHS
+ // (i.e. it is a permutation on RES and LHS and does not appear in RHS).
+ // 2. The n dimension is involved in an outer-product along RHS
+ // (i.e. it is a permutation on RES and RHS and does not appear in LHS).
+ // 3. The k dimension appears as a permutation on LHS and RHS.
+ // 4. m, n and k appear only once in any given indexing.
+ // 5. Optional batch dimensions that appear in all operands are captured.
+ // NOTE(review): the assert below only protects assertion-enabled builds.
+ // The only caller checks isaContractionOpInterface first, which presumably
+ // implies inference succeeds (TODO confirm). If it ever fails, `*res` is
+ // undefined in release builds; returning failure() would be safer.
+ auto res = inferContractionDims(genericOp);
+ assert(succeeded(res) && "unexpected failure to infer contraction dims");
+ auto dims = *res;
+
+ // Other than `batch`, other dim sizes must be 1 for linalg.*_matmul_*.
+ if (dims.m.size() != 1 || dims.n.size() != 1 || dims.k.size() != 1)
+ return failure();
+
+ // Check rank of operands
+ auto indexingMaps = genericOp.getIndexingMapsArray();
+ if (llvm::any_of(indexingMaps, [&dims](AffineMap m) {
+ return m.getResults().size() !=
+ dims.batch.size() + 2 /*two from {m,n,k}*/;
+ }))
+ return failure();
+
+ auto batchSize = dims.batch.size();
+ // NOTE(review): this `if` has an empty body, so the loop-count check has no
+ // effect; it presumably should `return failure();`.
+ if (indexingMaps[0].getNumDims() != batchSize + 3) {
+ }
+ if (batchSize) {
+ // Each operand in a linalg generic contraction could express different
+ // permutations for its batch dimension. But for named op it must be
+ // identity since separate maps are not specified.
+ if (llvm::any_of(indexingMaps, [batchSize](AffineMap m) {
+ for (unsigned i = 0; i < batchSize; ++i) {
+ auto expr = dyn_cast<AffineDimExpr>(m.getResults()[i]);
+ if (!expr || expr.getPosition() != i)
+ return true;
+ }
+ return false;
+ }))
+ return failure();
+ }
+
+ // Classify each operand's non-batch access: LHS against (m, k), RHS
+ // against (k, n), and the result against (m, n).
+ auto a = matchOperandMap(indexingMaps[0], batchSize, dims.m[0], dims.k[0]);
+ auto b = matchOperandMap(indexingMaps[1], batchSize, dims.k[0], dims.n[0]);
+ auto c = matchOperandMap(indexingMaps[2], batchSize, dims.m[0], dims.n[0]);
+
+ if (llvm::any_of(ArrayRef<IndexMatchResult>{a, b, c}, [](IndexMatchResult r) {
+ return r == IndexMatchResult::Mismatch;
+ }))
+ return failure();
+
+ // The result must be accessed as (m, n), and at most one input may be
+ // transposed: no named variant below transposes both LHS and RHS.
+ if (c != IndexMatchResult::Match ||
+ (a == IndexMatchResult::Transposed && b == IndexMatchResult::Transposed))
+ return failure();
+
+ /// Codegen the different matmul variants.
+ if (batchSize) {
+ if (a == IndexMatchResult::Transposed)
+ return replaceWithMatmulVariant<BatchMatmulTransposeAOp>(rewriter,
+ genericOp);
+ if (b == IndexMatchResult::Transposed)
+ return replaceWithMatmulVariant<BatchMatmulTransposeBOp>(rewriter,
+ genericOp);
+ return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp);
+ }
+
+ if (a == IndexMatchResult::Transposed)
+ return replaceWithMatmulVariant<MatmulTransposeAOp>(rewriter, genericOp);
+ if (b == IndexMatchResult::Transposed)
+ return replaceWithMatmulVariant<MatmulTransposeBOp>(rewriter, genericOp);
+ return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
+}
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Categorize linalg generic to named op where possible.
+//===----------------------------------------------------------------------===//
FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
GenericOp genericOp) {
if (isaCopyOpInterface(genericOp)) {
@@ -100,5 +279,31 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
return namedOp;
}
}
+
+ if (isaContractionOpInterface(genericOp)) {
+ return specializeLinalgContractions(rewriter, genericOp);
+ }
return failure();
}
+
+namespace {
+/// Pass that applies the generic-to-named-op specialization patterns. Its
+/// base class is generated (GEN_PASS_DEF) from the TableGen definition of
+/// `LinalgSpecializeGenericOpsPass`.
+struct LinalgSpecializeGenericOpsPass
+ : public impl::LinalgSpecializeGenericOpsPassBase<
+ LinalgSpecializeGenericOpsPass> {
+
+ using impl::LinalgSpecializeGenericOpsPassBase<
+ LinalgSpecializeGenericOpsPass>::LinalgSpecializeGenericOpsPassBase;
+ void runOnOperation() override;
+};
+} // namespace
+
+/// Collects the specialization patterns and applies them greedily to the
+/// pass's root operation.
+void LinalgSpecializeGenericOpsPass::runOnOperation() {
+ RewritePatternSet patterns(&getContext());
+ populateLinalgGenericOpsSpecializationPatterns(patterns);
+ // The driver's result is discarded, so non-convergence does not fail the
+ // pass.
+ (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+}
+
+/// Registers LinalgSpecializationPattern, which rewrites linalg.generic ops
+/// into named ops via `specializeGenericOp`.
+void mlir::linalg::populateLinalgGenericOpsSpecializationPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<LinalgSpecializationPattern>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Linalg/roundtrip-linalg-named-ops.mlir b/mlir/test/Dialect/Linalg/roundtrip-linalg-named-ops.mlir
new file mode 100644
index 0000000000000..d258d9f518534
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/roundtrip-linalg-named-ops.mlir
@@ -0,0 +1,49 @@
+// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s
+
+// Round-trip test: every named op below is first generalized to
+// linalg.generic and then specialized back, so the output should again use
+// the named ops. Note: the run line does not pass -split-input-file, so the
+// split markers in this file are plain comments and all functions are
+// processed as a single module.
+
+func.func @roundtrip_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: @roundtrip_matmul
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// -----
+
+func.func @roundtrip_add(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.add ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: roundtrip_add
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.add ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// -----
+
+func.func @roundtrip_exp(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
+ linalg.exp ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
+ return
+}
+
+// CHECK-LABEL: roundtrip_exp
+// CHECK-SAME: %[[A:.+]]: memref<7x14x21xf32>, %[[Out:.+]]: memref<7x14x21xf32>)
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.exp ins(%[[A]] : memref<7x14x21xf32>) outs(%[[Out]] : memref<7x14x21xf32>)
+
+// -----
+
+// Matmul followed by add: both ops should be recovered, with the add
+// consuming the matmul result.
+func.func @roundtrip_gemm(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>, %arg3: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg3 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %1 = linalg.add ins(%0, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg3 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: @roundtrip_gemm
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[C:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: %[[AB:.+]] = linalg.matmul ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: linalg.add ins(%[[AB]], %[[C]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
new file mode 100644
index 0000000000000..0ec2dc3a92ec7
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -0,0 +1,37 @@
+// RUN: mlir-opt %s -split-input-file --linalg-specialize-generic-ops | FileCheck %s
+
+// All-parallel linalg.generic ops with identity maps whose body is a single
+// arith/math op on the block arguments should be rewritten into the
+// corresponding named elementwise op.
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @specialize_div(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]}
+ ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.divf %in, %in_0 : f32
+ linalg.yield %1 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: specialize_div
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.div ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// -----
+
+#umap = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func.func @specialize_exp(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+ ins(%arg0 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %1 = math.exp %in : f32
+ linalg.yield %1 : f32
+ } -> tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
+}
+
+// CHECK-LABEL: specialize_exp
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.exp ins(%[[ARG0]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize_matmul.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize_matmul.mlir
new file mode 100644
index 0000000000000..f64953bceefe1
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize_matmul.mlir
@@ -0,0 +1,148 @@
+// RUN: mlir-opt --transform-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s
+
+// Each case is a linalg.generic computing a matmul variant (expressed through
+// its indexing maps); transform.structured.specialize should rewrite it into
+// the corresponding named op: matmul, matmul_transpose_a, matmul_transpose_b,
+// batch_matmul or batch_matmul_transpose_b.
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @specialize_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %0 = arith.mulf %in, %in_0 : f32
+ %1 = arith.addf %out, %0 : f32
+ linalg.yield %1 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: @specialize_matmul
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d2, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @matmul_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+ linalg.generic
+ {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>) outs(%arg2 : memref<3x7xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %0 = arith.mulf %in, %in_0 : f32
+ %1 = arith.addf %out, %0 : f32
+ linalg.yield %1 : f32
+ }
+ return
+}
+
+// CHECK-LABEL: @matmul_transpose_a
+// CHECK-SAME: %[[ARG0:.+]]: memref<5x3xf32>, %[[ARG1:.+]]: memref<5x7xf32>, %[[ARG2:.+]]: memref<3x7xf32>) {
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul_transpose_a ins(%[[ARG0]], %[[ARG1]] : memref<5x3xf32>, memref<5x7xf32>) outs(%[[ARG2]] : memref<3x7xf32>)
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @matmul_transpose_b(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.mulf %in, %in_0 : f32
+ %2 = arith.addf %out, %1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: @matmul_transpose_b
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul_transpose_b ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+func.func @batch_matmul(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x16xf32>, %arg2: tensor<2x16x16xf32>) -> tensor<2x16x16xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+ ins(%arg0, %arg1 : tensor<2x16x8xf32>, tensor<2x8x16xf32>) outs(%arg2 : tensor<2x16x16xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.mulf %in, %in_0 : f32
+ %2 = arith.addf %out, %1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<2x16x16xf32>
+ return %0 : tensor<2x16x16xf32>
+}
+
+// CHECK-LABEL: @batch_matmul
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x16x8xf32>, %[[ARG1:.+]]: tensor<2x8x16xf32>, %[[ARG2:.+]]: tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.batch_matmul ins(%[[ARG0]], %[[ARG1]] : tensor<2x16x8xf32>, tensor<2x8x16xf32>) outs(%[[ARG2]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+func.func @batch_matmul_transpose_b(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+ ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%arg2 : tensor<?x?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.mulf %in, %in_0 : f32
+ %2 = arith.addf %out, %1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
+}
+
+// CHECK-LABEL: @batch_matmul_transpose_b
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?xf32>, %[[ARG2:.+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.batch_matmul_transpose_b ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[ARG2]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
>From 8e7987495ca78b4925e1d4a01a7f3df5232c3569 Mon Sep 17 00:00:00 2001
From: Javed Absar <javed.absar at gmail.com>
Date: Tue, 18 Jun 2024 05:11:40 -0400
Subject: [PATCH 2/5] [MLIR][Linalg] Address review comments
---
mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp | 2 +-
.../lib/Dialect/Linalg/Transforms/Specialize.cpp | 5 +++--
.../Linalg/roundtrip-linalg-named-ops.mlir | 16 ++++++++--------
3 files changed, 12 insertions(+), 11 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 8ca76ec43193d..6ee1810c2ff2b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -105,7 +105,7 @@ std::optional<Value> linalg::isaFillOpInterface(GenericOp genericOp) {
static bool
isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp genericOp,
unsigned arity) {
- // Check all loops are parallel, and have only tensor semantics.
+ // Check all loops are parallel.
if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
genericOp.getNumLoops() < 1)
return false;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 7fac3feba98c9..035a31050a674 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -184,8 +184,9 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
return failure();
auto batchSize = dims.batch.size();
- if (indexingMaps[0].getNumDims() != batchSize + 3) {
- }
+ if (indexingMaps[0].getNumDims() != batchSize + 3)
+ return failure();
+
if (batchSize) {
// Each operand in a linalg generic contraction could express different
// permutations for its batch dimension. But for named op it must be
diff --git a/mlir/test/Dialect/Linalg/roundtrip-linalg-named-ops.mlir b/mlir/test/Dialect/Linalg/roundtrip-linalg-named-ops.mlir
index d258d9f518534..c38c39617204d 100644
--- a/mlir/test/Dialect/Linalg/roundtrip-linalg-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip-linalg-named-ops.mlir
@@ -1,48 +1,48 @@
// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s
-func.func @roundtrip_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+func.func @matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
-// CHECK-LABEL: @roundtrip_matmul
+// CHECK-LABEL: @matmul
// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: linalg.generic
// CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
// -----
-func.func @roundtrip_add(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+func.func @add(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = linalg.add ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
-// CHECK-LABEL: roundtrip_add
+// CHECK-LABEL: add
// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: linalg.generic
// CHECK: linalg.add ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
// -----
-func.func @roundtrip_exp(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
+func.func @exp(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
linalg.exp ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
return
}
-// CHECK-LABEL: roundtrip_exp
+// CHECK-LABEL: exp
// CHECK-SAME: %[[A:.+]]: memref<7x14x21xf32>, %[[Out:.+]]: memref<7x14x21xf32>)
// CHECK-NOT: linalg.generic
// CHECK: linalg.exp ins(%[[A]] : memref<7x14x21xf32>) outs(%[[Out]] : memref<7x14x21xf32>)
// -----
-func.func @roundtrip_gemm(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>, %arg3: tensor<?x?xf32>) -> tensor<?x?xf32> {
+func.func @gemm(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>, %arg3: tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg3 : tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = linalg.add ins(%0, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg3 : tensor<?x?xf32>) -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
-// CHECK-LABEL: @roundtrip_gemm
+// CHECK-LABEL: @gemm
// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[C:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: linalg.generic
// CHECK: %[[AB:.+]] = linalg.matmul ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
>From faf725ec3c8633312f22e22bccac1532365eb4b2 Mon Sep 17 00:00:00 2001
From: Javed Absar <javed.absar at gmail.com>
Date: Sat, 22 Jun 2024 19:27:34 -0400
Subject: [PATCH 3/5] [MLIR][Linalg] More review comments changes.
---
.../Dialect/Linalg/Transforms/Transforms.h | 21 +++---
.../Dialect/Linalg/Transforms/Specialize.cpp | 69 ++++++++++++-------
.../Linalg/roundtrip-linalg-named-ops.mlir | 41 ++++++-----
.../Linalg/specialize-generic-ops.mlir | 46 ++++++-------
4 files changed, 101 insertions(+), 76 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 912f9778a40e4..166e029c00939 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1396,18 +1396,15 @@ struct LinalgGeneralizationPattern
};
struct LinalgSpecializationPattern
- : public OpInterfaceRewritePattern<LinalgOp> {
- using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
+ : public OpRewritePattern<GenericOp> {
+ using OpRewritePattern<GenericOp>::OpRewritePattern;
- FailureOr<LinalgOp>
- returningMatchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const {
- auto genericOp = dyn_cast<GenericOp>(op.getOperation());
- if (!genericOp)
- return failure();
- return specializeGenericOp(rewriter, genericOp);
+ FailureOr<GenericOp>
+ returningMatchAndRewrite(GenericOp op, PatternRewriter &rewriter) const {
+ return specializeGenericOp(rewriter, op);
}
- LogicalResult matchAndRewrite(LinalgOp op,
+ LogicalResult matchAndRewrite(GenericOp op,
PatternRewriter &rewriter) const override {
return returningMatchAndRewrite(op, rewriter);
}
@@ -1565,7 +1562,11 @@ void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns);
void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns);
/// Populates `patterns` with patterns to convert linalg.generic ops to named
-/// ops where possible.
+/// ops where possible. A linalg.generic can represent a wide range of complex
+/// computations for which an equivalent linalg named op may not exist, e.g. a
+/// linalg.generic that takes a tensor and computes a polynomial such as:
+/// p(x) = an*x^n + ... + a1*x + a0
+/// There is no equivalent named op to convert to. Many such cases exist.
void populateLinalgGenericOpsSpecializationPatterns(
RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 035a31050a674..2206dbd260452 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -74,14 +74,25 @@ static bool areBinOpsSwapped(GenericOp genericOp) {
/// Identifies linalg.generic that is essentially named op of the form:
// ` linalg.{batch_}?matmul{_transpose_a | _transpose_b}? `
//
-// It is possible that a linalg.generic may be implementing one of matmul
-// variants but not in a straight-forward way, or the linalg.generic's
-// affine map per operand capture more semantics than is possible with
-// named op (which has implicit map interpreted via name).
-//
-// But a named linalg matmul variant that was 'generalized' should be
-// convertible back to named op here.
-//
+// It is possible that a linalg.generic may be implementing a matmul but not
+// in a straightforward way, e.g. below is a matrix multiply over some slice:
+// ```
+// %0 = linalg.generic {
+// indexing_maps = [affine_map<(d0, d1, d2) -> (3, d1, d0)>,
+// affine_map<(d0, d1, d2) -> (d0, 5, d2)>,
+// affine_map<(d0, d1, d2) -> (d2, d1, 13)>],
+// iterator_types = ["parallel", "parallel", "parallel"]}
+// ins(%A, %B : tensor<20x20x20xf32>, tensor<20x20x20xf32>)
+// outs(%C : tensor<20x20x20xf32>) {
+// ^bb0(%a: f32, %b: f32, %c : f32):
+// %mul = arith.mulf %a, %b : f32
+// %add = arith.addf %mul, %c : f32
+// linalg.yield %add : f32
+// } -> tensor<20x20x20xf32>
+// ```
+// It is not possible to represent the above as a named op.
+// e.g. linalg.batch_matmul(%A, %B : tensor<20x20x20xf32>, ...) is
+// not the same as the linalg.generic above.
namespace {
enum class IndexMatchResult {
Match = 0, // identity map.
@@ -89,23 +100,30 @@ enum class IndexMatchResult {
Mismatch // none of the above.
};
-// Looks at the affine map of an operand and works out if generic accesses
-// the element as identity-map, transposed, or 'cant work out'.
-// This check skips the `offset` batch indices and focuses on the matmul part.
-static IndexMatchResult matchOperandMap(AffineMap m, unsigned offset,
- unsigned i, unsigned j) {
- auto expr_ei = dyn_cast<AffineDimExpr>(m.getResults()[offset]);
- auto expr_ej = dyn_cast<AffineDimExpr>(m.getResults()[offset + 1]);
- if (!expr_ei || !expr_ej)
+// Consider the A matrix in `C[M,N] = A[M,K] * B[K,N]`. Below, we
+// check whether the index map of A is identity (match), transposed, or
+// something completely different (mis-match).
+// The naming and explanation is in terms of A, but the function checks
+// effectively maps for all A, B, C i.e. <M,N>, <M, K>, <K,N>.
+static IndexMatchResult matchOperandMap(AffineMap map, unsigned batchSize,
+ unsigned expectedPosOfM,
+ unsigned expectedPosOfK) {
+ // Get the matrix multiply indices. They are past the batch indices.
+ auto exprOfM = map.getResults()[batchSize];
+ auto exprOfK = map.getResults()[batchSize + 1];
+
+ // They should be pure dim ids.
+ if (exprOfM.getKind() != AffineExprKind::DimId ||
+ exprOfK.getKind() != AffineExprKind::DimId)
return IndexMatchResult::Mismatch;
- auto ei = expr_ei.getPosition();
- auto ej = expr_ej.getPosition();
+ auto posM = cast<AffineDimExpr>(exprOfM).getPosition();
+ auto posK = cast<AffineDimExpr>(exprOfK).getPosition();
- if (ei == i && ej == j)
+ if (expectedPosOfM == posM && expectedPosOfK == posK)
return IndexMatchResult::Match;
- if (ei == j && ej == i)
+ if (expectedPosOfM == posK && expectedPosOfK == posM)
return IndexMatchResult::Transposed;
return IndexMatchResult::Mismatch;
@@ -179,7 +197,7 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
auto indexingMaps = genericOp.getIndexingMapsArray();
if (llvm::any_of(indexingMaps, [&dims](AffineMap m) {
return m.getResults().size() !=
- dims.batch.size() + 2 /*two from {m,n,k}*/;
+ dims.batch.size() + 2 /* any two of {m,n,k} */;
}))
return failure();
@@ -193,8 +211,9 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
// identity since separate maps are not specified.
if (llvm::any_of(indexingMaps, [batchSize](AffineMap m) {
for (unsigned i = 0; i < batchSize; ++i) {
- auto expr = dyn_cast<AffineDimExpr>(m.getResults()[i]);
- if (!expr || expr.getPosition() != i)
+ auto expr = m.getResults()[i];
+ if (expr.getKind() != AffineExprKind::DimId ||
+ cast<AffineDimExpr>(expr).getPosition() != i)
return true;
}
return false;
@@ -301,7 +320,9 @@ struct LinalgSpecializeGenericOpsPass
void LinalgSpecializeGenericOpsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
populateLinalgGenericOpsSpecializationPatterns(patterns);
- (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+
+ if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+ signalPassFailure();
}
void mlir::linalg::populateLinalgGenericOpsSpecializationPatterns(
diff --git a/mlir/test/Dialect/Linalg/roundtrip-linalg-named-ops.mlir b/mlir/test/Dialect/Linalg/roundtrip-linalg-named-ops.mlir
index c38c39617204d..1fb520c5982e6 100644
--- a/mlir/test/Dialect/Linalg/roundtrip-linalg-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip-linalg-named-ops.mlir
@@ -1,48 +1,51 @@
+// The following tests check linalg named ops that are lowered to linalg.generic
+// and then lifted back up to named ops.
// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s
-func.func @matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
- %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
- return %0 : tensor<?x?xf32>
+func.func @unary_exp(%A: memref<7x14x21xf32>, %Out: memref<7x14x21xf32>) {
+ linalg.exp ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+ return
}
-// CHECK-LABEL: @matmul
-// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-LABEL: unary_exp
+// CHECK-SAME: %[[A:.+]]: memref<7x14x21xf32>, %[[Out:.+]]: memref<7x14x21xf32>)
// CHECK-NOT: linalg.generic
-// CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: linalg.exp ins(%[[A]] : memref<7x14x21xf32>) outs(%[[Out]] : memref<7x14x21xf32>)
// -----
-func.func @add(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
- %0 = linalg.add ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+func.func @binary_add(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.add ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
-// CHECK-LABEL: add
+// CHECK-LABEL: binary_add
// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: linalg.generic
// CHECK: linalg.add ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
// -----
-func.func @exp(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
- linalg.exp ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
- return
+func.func @matmul(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
}
-// CHECK-LABEL: exp
-// CHECK-SAME: %[[A:.+]]: memref<7x14x21xf32>, %[[Out:.+]]: memref<7x14x21xf32>)
+// CHECK-LABEL: @matmul
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: linalg.generic
-// CHECK: linalg.exp ins(%[[A]] : memref<7x14x21xf32>) outs(%[[Out]] : memref<7x14x21xf32>)
+// CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
// -----
-func.func @gemm(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>, %arg3: tensor<?x?xf32>) -> tensor<?x?xf32> {
- %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg3 : tensor<?x?xf32>) -> tensor<?x?xf32>
- %1 = linalg.add ins(%0, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg3 : tensor<?x?xf32>) -> tensor<?x?xf32>
+func.func @mixed_named_ops(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+ %C: tensor<?x?xf32>, %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %AB = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %1 = linalg.add ins(%AB, %C : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
-// CHECK-LABEL: @gemm
+// CHECK-LABEL: @mixed_named_ops
// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[C:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: linalg.generic
// CHECK: %[[AB:.+]] = linalg.matmul ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
index 0ec2dc3a92ec7..fbb24acbd8d43 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -1,37 +1,37 @@
// RUN: mlir-opt %s -split-input-file --linalg-specialize-generic-ops | FileCheck %s
-#map = affine_map<(d0, d1) -> (d0, d1)>
-func.func @specialize_div(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
- %0 = linalg.generic
- {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]}
- ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
- ^bb0(%in: f32, %in_0: f32, %out: f32):
- %1 = arith.divf %in, %in_0 : f32
+#umap = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func.func @unary_op_exp(%A: tensor<?x?x?xf32>, %Out: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+ ins(%A : tensor<?x?x?xf32>) outs(%Out : tensor<?x?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %1 = math.exp %in : f32
linalg.yield %1 : f32
- } -> tensor<?x?xf32>
- return %0 : tensor<?x?xf32>
+ } -> tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
}
-// CHECK-LABEL: specialize_div
-// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-LABEL: unary_op_exp
+// CHECK-SAME: %[[A:.+]]: tensor<?x?x?xf32>, %[[Out:.+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK-NOT: linalg.generic
-// CHECK: linalg.div ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: linalg.exp ins(%[[A]] : tensor<?x?x?xf32>) outs(%[[Out]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// -----
-#umap = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-func.func @specialize_exp(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @binary_op_div(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = linalg.generic
- {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
- ins(%arg0 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) {
- ^bb0(%in: f32, %out: f32):
- %1 = math.exp %in : f32
+ {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]}
+ ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.divf %in, %in_0 : f32
linalg.yield %1 : f32
- } -> tensor<?x?x?xf32>
- return %0 : tensor<?x?x?xf32>
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
}
-// CHECK-LABEL: specialize_exp
-// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK-LABEL: binary_op_div
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: linalg.generic
-// CHECK: linalg.exp ins(%[[ARG0]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: linalg.div ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
>From 2aaa615dee265610c7e7f379d3a1928c4893615f Mon Sep 17 00:00:00 2001
From: Javed Absar <javed.absar at gmail.com>
Date: Wed, 26 Jun 2024 14:15:53 -0400
Subject: [PATCH 4/5] [MLIR][Linalg] More changes based on review comments
---
.../Dialect/Linalg/Transforms/Specialize.cpp | 40 ++++++-----------
.../Linalg/specialize-generic-ops.mlir | 44 +++++++++++++++++++
2 files changed, 57 insertions(+), 27 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 2206dbd260452..9d6b9985dfdea 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -100,11 +100,13 @@ enum class IndexMatchResult {
Mismatch // none of the above.
};
-// Consider the A matrix in `C[M,N] = A[M,K] * B[K,N]`. Below, we
-// check whether the index map of A is identity (match), transposed, or
-// something completely different (mis-match).
+// Matches position of indices appearing the affine map of operand
+// with what is expected in non-transposed case. e.g.
+// consider the A matrix in `C[M,N] = A[M,K] * B[K,N]`. Below, we
+// check whether the index map of A is identity (match), transposed, or
+// something completely different (mis-match).
// The naming and explanation is in terms of A, but the function checks
-// effectively maps for all A, B, C i.e. <M,N>, <M, K>, <K,N>.
+// effectively maps for all A, B, C i.e. C<M,N>, A<M, K>, B<K,N>.
static IndexMatchResult matchOperandMap(AffineMap map, unsigned batchSize,
unsigned expectedPosOfM,
unsigned expectedPosOfK) {
@@ -129,11 +131,13 @@ static IndexMatchResult matchOperandMap(AffineMap map, unsigned batchSize,
return IndexMatchResult::Mismatch;
}
-// All the variants `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?`
-// have same number of input/output.
-template <typename Variant>
+// Replaces genericOp with `NamedOpTy` op, supplied as a template arg.
+// All the variants, expressed as a pseudo regular expression:
+// `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?`
+// have the same number of ins/outs, so it's easy to stamp different versions.
+template <typename NamedOpTy>
static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op) {
- LinalgOp namedOp = rewriter.replaceOpWithNewOp<Variant>(
+ LinalgOp namedOp = rewriter.replaceOpWithNewOp<NamedOpTy>(
op, ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
ValueRange{op.getDpsInits()[0]});
return namedOp;
@@ -156,16 +160,6 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
[](AffineMap m) { return !m.isProjectedPermutation(); }))
return failure();
- // matmul contractions are of the form:
- // %0 = <elemwise>(permutation-of(cu(block-argument-0),
- // cu(block-argument-1)))
- // %1 = <reduce>(permutation-of(cu(%0), cu(block-argument-2)))
- //
- // where <elemwise> and <reduce> are binary operations constituting a
- // contraction (in the canonical case, <elemwise> is a multiplication and
- // <reduce> is an addition). All operands of all operations may be supplied
- // through a chain of side effect-free unary operations, such as casts,
- // which is denoted as `cu` above.
if (!mlir::linalg::detail::isContractionBody(
*genericOp.getBlock(), [](Operation *first, Operation *second) {
if ((isa<arith::MulFOp>(first) && isa<arith::AddFOp>(second)) ||
@@ -176,20 +170,12 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
}))
return failure();
- // Finds 2 parallel (m and n) and 1 reduction (k) dimension candidates that
- // form a matmul subcomputation. These dimensions are such that:
- // 1. The m dimension is involved in an outer-product along LHS
- // (i.e. it is a permutation on RES and LHS and does not appear in RHS).
- // 2. The n dimension is involved in an outer-product along RHS
- // (i.e. it is a permutation on RES and RHS and does not appear in LHS).
- // 3. The k dimension appears as a permutation on LHS and RHS.
- // 4. m, n and k appear only once in any given indexing.
- // 5. Optional batch dimensions that appear in all operands are captured.
auto res = inferContractionDims(genericOp);
assert(succeeded(res) && "unexpected failure to infer contraction dims");
auto dims = *res;
// Other than `batch`, other dim sizes must be 1 for linalg.*_matmul_*.
+ // Note that linalg contraction can have more than one contraction dimension.
if (dims.m.size() != 1 || dims.n.size() != 1 || dims.k.size() != 1)
return failure();
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
index fbb24acbd8d43..f95cd8b754c54 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -35,3 +35,47 @@ func.func @binary_op_div(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %Out: tensor<
// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: linalg.generic
// CHECK: linalg.div ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @op_matmul(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.mulf %in, %in_0 : f32
+ %2 = arith.addf %out, %1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: op_matmul
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+func.func @op_batch_matmul(%A: tensor<2x16x8xf32>, %B: tensor<2x8x16xf32>, %Out: tensor<2x16x16xf32>) -> tensor<2x16x16xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+ ins(%A, %B : tensor<2x16x8xf32>, tensor<2x8x16xf32>) outs(%Out : tensor<2x16x16xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.mulf %in, %in_0 : f32
+ %2 = arith.addf %out, %1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<2x16x16xf32>
+ return %0 : tensor<2x16x16xf32>
+}
+
+// CHECK-LABEL: op_batch_matmul
+// CHECK-SAME: %[[A:.+]]: tensor<2x16x8xf32>, %[[B:.+]]: tensor<2x8x16xf32>, %[[Out:.+]]: tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.batch_matmul ins(%[[A]], %[[B]] : tensor<2x16x8xf32>, tensor<2x8x16xf32>) outs(%[[Out]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
>From bc4215bed059c10edf0d3823a941cfe84c3421d1 Mon Sep 17 00:00:00 2001
From: Javed Absar <javed.absar at gmail.com>
Date: Fri, 28 Jun 2024 18:45:53 -0400
Subject: [PATCH 5/5] [MLIR][Linalg] Changes based on review comments:
- rename batchSize to numOfBatchDims - add an example of multiple contraction
dims corresponding to K as inferred by inferContractionDims - add a
linalg.matvec test and mark it as TODO - use hyphens (-) rather than
underscores (_) in filenames - implement banach's suggestion to replace M, K
in the explanation with expectedPosOfRowDim etc.
---
.../Dialect/Linalg/Transforms/Specialize.cpp | 102 ++++++++++--------
.../Linalg/specialize-generic-ops.mlir | 44 ++++++++
...nsform-op-specialize-elemwise-binary.mlir} | 0
...ansform-op-specialize-elemwise-unary.mlir} | 0
...ir => transform-op-specialize-matmul.mlir} | 0
5 files changed, 103 insertions(+), 43 deletions(-)
rename mlir/test/Dialect/Linalg/{transform-op-specialize_elemwise_binary.mlir => transform-op-specialize-elemwise-binary.mlir} (100%)
rename mlir/test/Dialect/Linalg/{transform-op-specialize_elemwise_unary.mlir => transform-op-specialize-elemwise-unary.mlir} (100%)
rename mlir/test/Dialect/Linalg/{transform-op-specialize_matmul.mlir => transform-op-specialize-matmul.mlir} (100%)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 9d6b9985dfdea..78bfa383d25a2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -100,32 +100,33 @@ enum class IndexMatchResult {
Mismatch // none of the above.
};
-// Matches position of indices appearing the affine map of operand
-// with what is expected in non-transposed case. e.g.
-// consider the A matrix in `C[M,N] = A[M,K] * B[K,N]`. Below, we
-// check whether the index map of A is identity (match), transposed, or
-// something completely different (mis-match).
-// The naming and explanation is in terms of A, but the function checks
-// effectively maps for all A, B, C i.e. C<M,N>, A<M, K>, B<K,N>.
-static IndexMatchResult matchOperandMap(AffineMap map, unsigned batchSize,
- unsigned expectedPosOfM,
- unsigned expectedPosOfK) {
+// Checks whether the input Affine `map` contains two consecutive dims that
+// can be interpreted as accessing a 2D matrix. It is assumed that the row
+// and column dimensions are adjacent axes (in that order) and start at
+// `rowDimIdx` in the input map.
+//
+// E.g. consider the A matrix in `C[M,N] = A[M,K] * B[K,N]`. We check
+// whether the map of A is identity (match), transposed, or something
+// completely different (mismatch). Similarly for B and C.
+static IndexMatchResult matchOperandMap(AffineMap map, unsigned rowDimIdx,
+ unsigned expectedPosOfRowDim,
+ unsigned expectedPosOfColDim) {
// Get the matrix multiply indices. They are past the batch indices.
- auto exprOfM = map.getResults()[batchSize];
- auto exprOfK = map.getResults()[batchSize + 1];
+ auto exprOfRowDim = map.getResults()[rowDimIdx];
+ auto exprOfColDim = map.getResults()[rowDimIdx + 1];
- // They should be pure dim ids.
- if (exprOfM.getKind() != AffineExprKind::DimId ||
- exprOfK.getKind() != AffineExprKind::DimId)
+ // They should be pure dimension ids.
+ if (exprOfRowDim.getKind() != AffineExprKind::DimId ||
+ exprOfColDim.getKind() != AffineExprKind::DimId)
return IndexMatchResult::Mismatch;
- auto posM = cast<AffineDimExpr>(exprOfM).getPosition();
- auto posK = cast<AffineDimExpr>(exprOfK).getPosition();
+ auto posRowDim = cast<AffineDimExpr>(exprOfRowDim).getPosition();
+ auto posColDim = cast<AffineDimExpr>(exprOfColDim).getPosition();
- if (expectedPosOfM == posM && expectedPosOfK == posK)
+ if (expectedPosOfRowDim == posRowDim && expectedPosOfColDim == posColDim)
return IndexMatchResult::Match;
- if (expectedPosOfM == posK && expectedPosOfK == posM)
+ if (expectedPosOfRowDim == posColDim && expectedPosOfColDim == posRowDim)
return IndexMatchResult::Transposed;
return IndexMatchResult::Mismatch;
@@ -149,17 +150,38 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
if (genericOp.getNumDpsInputs() != 2 || genericOp.getNumDpsInits() != 1)
return failure();
- // Linalg generic contraction can be across multiple axis but for matmul
- // variants it must be one.
- if (genericOp.getNumReductionLoops() != 1)
- return failure();
-
- // Must be projected permutations.
+ // Early exit if not projected permutations.
auto mapRange = genericOp.getIndexingMapsArray();
if (llvm::any_of(mapRange,
[](AffineMap m) { return !m.isProjectedPermutation(); }))
return failure();
+ // Linalg generic contraction can be across multiple axes, e.g.
+ // ```
+ // linalg.generic
+ // {indexing_maps = [affine_map<(m, n, k1, k2) -> (m, k1, k2)>,
+ // affine_map<(m, n, k1, k2) -> (k2, k1, n)>,
+ // affine_map<(m, n, k1, k2) -> (m, n)>],
+ // iterator_types = ["parallel", "parallel",
+ // "reduction", "reduction"]}
+ // ins(%A, %B : tensor<10x20x30xf32>, tensor<30x20x40xf32>)
+ // outs(%C : tensor<10x40xf32>) {
+ // ^bb0(%a: f32, %b: f32, %c: f32):
+ // %1 = arith.mulf %a, %b : f32
+ // %2 = arith.addf %c, %1 : f32
+ // linalg.yield %2 : f32
+ // } -> tensor<10x40xf32>
+ // ```
+ // In the above contraction there are two reduction dimensions {k1, k2};
+ // although it is a valid linalg contraction, it is not a named-op
+ // matrix multiply kind. Therefore, reject multi-dim reduction.
+ auto res = inferContractionDims(genericOp);
+ if (!succeeded(res))
+ return failure();
+ auto dims = *res;
+ if (dims.m.size() != 1 || dims.n.size() != 1 || dims.k.size() != 1)
+ return failure();
+
if (!mlir::linalg::detail::isContractionBody(
*genericOp.getBlock(), [](Operation *first, Operation *second) {
if ((isa<arith::MulFOp>(first) && isa<arith::AddFOp>(second)) ||
@@ -170,15 +192,6 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
}))
return failure();
- auto res = inferContractionDims(genericOp);
- assert(succeeded(res) && "unexpected failure to infer contraction dims");
- auto dims = *res;
-
- // Other than `batch`, other dim sizes must be 1 for linalg.*_matmul_*.
- // Note that linalg contraction can have more than one contraction dimension.
- if (dims.m.size() != 1 || dims.n.size() != 1 || dims.k.size() != 1)
- return failure();
-
// Check rank of operands
auto indexingMaps = genericOp.getIndexingMapsArray();
if (llvm::any_of(indexingMaps, [&dims](AffineMap m) {
@@ -187,16 +200,16 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
}))
return failure();
- auto batchSize = dims.batch.size();
- if (indexingMaps[0].getNumDims() != batchSize + 3)
+ auto numOfBatchDims = dims.batch.size();
+ if (indexingMaps[0].getNumDims() != numOfBatchDims + 3)
return failure();
- if (batchSize) {
+ if (numOfBatchDims) {
// Each operand in a linalg generic contraction could express different
// permutations for its batch dimension. But for named op it must be
// identity since separate maps are not specified.
- if (llvm::any_of(indexingMaps, [batchSize](AffineMap m) {
- for (unsigned i = 0; i < batchSize; ++i) {
+ if (llvm::any_of(indexingMaps, [numOfBatchDims](AffineMap m) {
+ for (unsigned i = 0; i < numOfBatchDims; ++i) {
auto expr = m.getResults()[i];
if (expr.getKind() != AffineExprKind::DimId ||
cast<AffineDimExpr>(expr).getPosition() != i)
@@ -207,9 +220,12 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
return failure();
}
- auto a = matchOperandMap(indexingMaps[0], batchSize, dims.m[0], dims.k[0]);
- auto b = matchOperandMap(indexingMaps[1], batchSize, dims.k[0], dims.n[0]);
- auto c = matchOperandMap(indexingMaps[2], batchSize, dims.m[0], dims.n[0]);
+ auto a =
+ matchOperandMap(indexingMaps[0], numOfBatchDims, dims.m[0], dims.k[0]);
+ auto b =
+ matchOperandMap(indexingMaps[1], numOfBatchDims, dims.k[0], dims.n[0]);
+ auto c =
+ matchOperandMap(indexingMaps[2], numOfBatchDims, dims.m[0], dims.n[0]);
if (llvm::any_of(ArrayRef<IndexMatchResult>{a, b, c}, [](IndexMatchResult r) {
return r == IndexMatchResult::Mismatch;
@@ -221,7 +237,7 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
return failure();
/// Codegen the different matmul variants.
- if (batchSize) {
+ if (numOfBatchDims) {
if (a == IndexMatchResult::Transposed)
return replaceWithMatmulVariant<BatchMatmulTransposeAOp>(rewriter,
genericOp);
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
index f95cd8b754c54..03336a4d0c74b 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -79,3 +79,47 @@ func.func @op_batch_matmul(%A: tensor<2x16x8xf32>, %B: tensor<2x8x16xf32>, %Out:
// CHECK-SAME: %[[A:.+]]: tensor<2x16x8xf32>, %[[B:.+]]: tensor<2x8x16xf32>, %[[Out:.+]]: tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
// CHECK-NOT: linalg.generic
// CHECK: linalg.batch_matmul ins(%[[A]], %[[B]] : tensor<2x16x8xf32>, tensor<2x8x16xf32>) outs(%[[Out]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
+
+// -----
+
+// This is a multi-reduction linalg.generic and cannot be lifted to a matrix multiply.
+#mapA = affine_map<(m, n, k1, k2) -> (m, k1, k2)>
+#mapB = affine_map<(m, n, k1, k2) -> (k2, k1, n)>
+#mapC = affine_map<(m, n, k1, k2) -> (m, n)>
+func.func @op_multi_reduction(%A: tensor<10x20x30xf32>, %B: tensor<30x20x40xf32>,
+ %C: tensor<10x40xf32>) -> tensor<10x40xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#mapA, #mapB, #mapC],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ ins(%A, %B : tensor<10x20x30xf32>, tensor<30x20x40xf32>)
+ outs(%C : tensor<10x40xf32>) {
+ ^bb0(%a: f32, %b: f32, %c: f32):
+ %1 = arith.mulf %a, %b : f32
+ %2 = arith.addf %c, %1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<10x40xf32>
+ return %0 : tensor<10x40xf32>
+}
+
+// CHECK-LABEL: op_multi_reduction
+// CHECK: linalg.generic
+
+// -----
+
+// TODO: matvec is not yet specialized; the linalg.generic is expected to remain.
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1)>
+#map2 = affine_map<(d0, d1) -> (d0)>
+func.func @op_matvec(%A: tensor<?x?xf32>, %B: tensor<?xf32>, %Out: tensor<?xf32>) -> tensor<?xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction"]}
+ ins(%A, %B : tensor<?x?xf32>, tensor<?xf32>) outs(%Out : tensor<?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.mulf %in, %in_0 : f32
+ %2 = arith.addf %out, %1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
+// CHECK-LABEL: op_matvec
+// CHECK: linalg.generic
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_binary.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize-elemwise-binary.mlir
similarity index 100%
rename from mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_binary.mlir
rename to mlir/test/Dialect/Linalg/transform-op-specialize-elemwise-binary.mlir
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_unary.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize-elemwise-unary.mlir
similarity index 100%
rename from mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_unary.mlir
rename to mlir/test/Dialect/Linalg/transform-op-specialize-elemwise-unary.mlir
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize_matmul.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize-matmul.mlir
similarity index 100%
rename from mlir/test/Dialect/Linalg/transform-op-specialize_matmul.mlir
rename to mlir/test/Dialect/Linalg/transform-op-specialize-matmul.mlir
More information about the Mlir-commits
mailing list