[Mlir-commits] [mlir] 3efac5c - [MLIR][Linalg] Add pass to convert linalg.generic back to named ops (#95656)
llvmlistbot at llvm.org
Sun Jun 30 11:37:55 PDT 2024
Author: Javed Absar
Date: 2024-06-30T19:37:51+01:00
New Revision: 3efac5c68ac3117e8488a7fa247e45951e52936f
URL: https://github.com/llvm/llvm-project/commit/3efac5c68ac3117e8488a7fa247e45951e52936f
DIFF: https://github.com/llvm/llvm-project/commit/3efac5c68ac3117e8488a7fa247e45951e52936f.diff
LOG: [MLIR][Linalg] Add pass to convert linalg.generic back to named ops (#95656)
Add a new mlir-opt pass `--linalg-specialize-generic-ops` which lifts
linalg.generic ops, where possible, back to linalg named ops, much like
`-linalg-generalize-named-ops` lowers named ops to linalg.generic.
Also add patterns to recognize contractions that can be specialized from
linalg.generic to a named op: `linalg.{batch_}?matmul{_transpose_(a|b)}?`
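For illustration (adapted from the tests added below), a linalg.generic
carrying a matmul body:

  #map = affine_map<(d0, d1, d2) -> (d0, d2)>
  #map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
  #map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
  %0 = linalg.generic
    {indexing_maps = [#map, #map1, #map2],
     iterator_types = ["parallel", "parallel", "reduction"]}
    ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%C : tensor<?x?xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %1 = arith.mulf %in, %in_0 : f32
    %2 = arith.addf %out, %1 : f32
    linalg.yield %2 : f32
  } -> tensor<?x?xf32>

is rewritten by the new pass to:

  %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
                     outs(%C : tensor<?x?xf32>) -> tensor<?x?xf32>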
Added:
mlir/test/Dialect/Linalg/roundtrip-linalg-named-ops.mlir
mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
mlir/test/Dialect/Linalg/transform-op-specialize-elemwise-binary.mlir
mlir/test/Dialect/Linalg/transform-op-specialize-elemwise-unary.mlir
mlir/test/Dialect/Linalg/transform-op-specialize-matmul.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/Passes.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
Removed:
mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_binary.mlir
mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_unary.mlir
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 0621a9f33ba1e..d96ad919b65f0 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -94,6 +94,11 @@ def LinalgGeneralizeNamedOpsPass : Pass<"linalg-generalize-named-ops"> {
let dependentDialects = ["linalg::LinalgDialect"];
}
+def LinalgSpecializeGenericOpsPass : Pass<"linalg-specialize-generic-ops"> {
+ let summary = "Convert generic ops back to named ops";
+ let dependentDialects = ["linalg::LinalgDialect"];
+}
+
def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInterface"> {
let summary = "Detensorize linalg ops";
let dependentDialects = [];
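With this definition the pass becomes available on the mlir-opt command
line; the round trip exercised by the new tests is (input.mlir stands for
any file of linalg named ops):

  mlir-opt input.mlir -linalg-generalize-named-ops | \
    mlir-opt --linalg-specialize-generic-ops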
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index b0871a5dff5da..3812eb50095d5 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1416,6 +1416,20 @@ struct LinalgGeneralizationPattern
}
};
+struct LinalgSpecializationPattern : public OpRewritePattern<GenericOp> {
+ using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+ FailureOr<GenericOp>
+ returningMatchAndRewrite(GenericOp op, PatternRewriter &rewriter) const {
+ return specializeGenericOp(rewriter, op);
+ }
+
+ LogicalResult matchAndRewrite(GenericOp op,
+ PatternRewriter &rewriter) const override {
+ return returningMatchAndRewrite(op, rewriter);
+ }
+};
+
/// Vectorization pattern for memref::CopyOp.
struct CopyVectorizationPattern : public OpRewritePattern<memref::CopyOp> {
using OpRewritePattern<memref::CopyOp>::OpRewritePattern;
@@ -1567,6 +1581,15 @@ void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns);
/// linalg.generic ops.
void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns);
+/// Populates `patterns` with patterns to convert linalg.generic ops to named
+/// ops where possible. A linalg.generic can represent a wide range of complex
+/// computations for which an equivalent linalg named op may not exist, e.g. a
+/// linalg.generic that takes a tensor and computes a polynomial such as:
+///   p(x) = an*x^n + ... + a1*x + a0
+/// There is no equivalent named op to convert to. Many such cases exist.
+void populateLinalgGenericOpsSpecializationPatterns(
+ RewritePatternSet &patterns);
+
/// Linalg decompose convolutions patterns
/// Populates patterns to decompose high-D convolution ops into low-D ones.
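Downstream users can also drive these patterns directly. A minimal sketch,
mirroring what LinalgSpecializeGenericOpsPass does in Specialize.cpp further
below (the helper name specializeGenerics is hypothetical):

  #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

  // Hypothetical helper: populate the specialization patterns and run the
  // greedy rewrite driver to fixpoint, as the new pass itself does.
  static mlir::LogicalResult specializeGenerics(mlir::Operation *root) {
    mlir::RewritePatternSet patterns(root->getContext());
    mlir::linalg::populateLinalgGenericOpsSpecializationPatterns(patterns);
    return mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
  }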
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index f35ab3b856b4e..6ee1810c2ff2b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -105,9 +105,9 @@ std::optional<Value> linalg::isaFillOpInterface(GenericOp genericOp) {
static bool
isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp genericOp,
unsigned arity) {
- // Check all loops are parallel, and have only tensor semantics.
+ // Check all loops are parallel.
if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
- genericOp.getNumLoops() < 1 || !genericOp.hasPureTensorSemantics())
+ genericOp.getNumLoops() < 1)
return false;
// Check there are arity-inputs, 1-output and all are identity-maps.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 2bc4d7fbfadcc..78bfa383d25a2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -11,12 +11,22 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/TypeID.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"
+namespace mlir {
+#define GEN_PASS_DEF_LINALGSPECIALIZEGENERICOPSPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
#define DEBUG_TYPE "linalg-specialization"
#define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP) \
@@ -58,6 +68,197 @@ static bool areBinOpsSwapped(GenericOp genericOp) {
return swapped;
}
+//===----------------------------------------------------------------------===//
+// Specialize linalg generic to matmul variants.
+//===----------------------------------------------------------------------===//
+// Identifies a linalg.generic that is essentially a named op of the form:
+//   `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?`
+//
+// It is possible that a linalg.generic may implement a matmul, but not in a
+// straightforward way; e.g. below is a matrix multiply over some slice:
+// ```
+// %0 = linalg.generic {
+// indexing_maps = [affine_map<(d0, d1, d2) -> (3, d1, d0)>,
+// affine_map<(d0, d1, d2) -> (d0, 5, d2)>,
+// affine_map<(d0, d1, d2) -> (d2, d1, 13)>],
+// iterator_types = ["parallel", "parallel", "parallel"]}
+// ins(%A, %B : tensor<20x20x20xf32>, tensor<20x20x20xf32>)
+// outs(%C : tensor<20x20x20xf32>) {
+// ^bb0(%a: f32, %b: f32, %c : f32):
+// %mul = arith.mulf %a, %b : f32
+// %add = arith.addf %mul, %c : f32
+// linalg.yield %add : f32
+// } -> tensor<20x20x20xf32>
+// ```
+// It is not possible to represent the above as a named op, e.g.
+// linalg.batch_matmul(%A, %B : tensor<20x20x20xf32>, ...) is not the same
+// as the linalg.generic above.
+namespace {
+enum class IndexMatchResult {
+ Match = 0, // identity map.
+ Transposed, // transposed map.
+ Mismatch // none of the above.
+};
+
+// Checks whether the input affine `map` contains two consecutive dims that
+// can be interpreted as accessing a 2D matrix. It is assumed that the row
+// and column dimensions are adjacent axes (in this order) and start at
+// `rowDimIdx` in the input map.
+//
+// e.g. consider the A matrix in `C[M,N] = A[M,K] * B[K,N]`. We will check
+// whether the map of A is identity (match), transposed, or something
+// completely different (mismatch). Similarly for B and C.
+static IndexMatchResult matchOperandMap(AffineMap map, unsigned rowDimIdx,
+ unsigned expectedPosOfRowDim,
+ unsigned expectedPosOfColDim) {
+ // Get the matrix multiply indices. They are past the batch indices.
+ auto exprOfRowDim = map.getResults()[rowDimIdx];
+ auto exprOfColDim = map.getResults()[rowDimIdx + 1];
+
+ // They should be pure dimension ids.
+ if (exprOfRowDim.getKind() != AffineExprKind::DimId ||
+ exprOfColDim.getKind() != AffineExprKind::DimId)
+ return IndexMatchResult::Mismatch;
+
+ auto posRowDim = cast<AffineDimExpr>(exprOfRowDim).getPosition();
+ auto posColDim = cast<AffineDimExpr>(exprOfColDim).getPosition();
+
+ if (expectedPosOfRowDim == posRowDim && expectedPosOfColDim == posColDim)
+ return IndexMatchResult::Match;
+
+ if (expectedPosOfRowDim == posColDim && expectedPosOfColDim == posRowDim)
+ return IndexMatchResult::Transposed;
+
+ return IndexMatchResult::Mismatch;
+}
+
+// Replaces genericOp with `NamedOpTy` op, supplied as a template arg.
+// All the variants, expressed as the pseudo regular expression
+//   `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?`,
+// have the same number of ins/outs, so it is easy to stamp out the
+// different versions.
+template <typename NamedOpTy>
+static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op) {
+ LinalgOp namedOp = rewriter.replaceOpWithNewOp<NamedOpTy>(
+ op, ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
+ ValueRange{op.getDpsInits()[0]});
+ return namedOp;
+}
+
+// Converts linalg.generic to named linalg.*matmul* where possible.
+static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
+ GenericOp genericOp) {
+ if (genericOp.getNumDpsInputs() != 2 || genericOp.getNumDpsInits() != 1)
+ return failure();
+
+ // Early exit if not projected permutations.
+ auto mapRange = genericOp.getIndexingMapsArray();
+ if (llvm::any_of(mapRange,
+ [](AffineMap m) { return !m.isProjectedPermutation(); }))
+ return failure();
+
+  // A linalg.generic contraction can be across multiple axes, e.g.
+ // ```
+ // linalg.generic
+ // {indexing_maps = [affine_map<(m, n, k1, k2) -> (m, k1, k2)>,
+ // affine_map<(m, n, k1, k2) -> (k2, k1, n)>,
+ // affine_map<(m, n, k1, k2) -> (m, n)>],
+ // iterator_types = ["parallel", "parallel",
+ // "reduction", "reduction"]}
+ // ins(%A, %B : tensor<10x20x30xf32>, tensor<30x20x40xf32>)
+ // outs(%C : tensor<10x40xf32>) {
+ // ^bb0(%a: f32, %b: f32, %c: f32):
+ // %1 = arith.mulf %a, %b : f32
+ // %2 = arith.addf %c, %1 : f32
+ // linalg.yield %2 : f32
+ // } -> tensor<10x40xf32>
+ // ```
+  // In the above contraction there are two reduction dimensions {k1, k2};
+  // although it is a valid linalg contraction, it is not a named-op
+  // matrix-multiply kind. Therefore, reject multi-dim reductions.
+ auto res = inferContractionDims(genericOp);
+ if (!succeeded(res))
+ return failure();
+ auto dims = *res;
+ if (dims.m.size() != 1 || dims.n.size() != 1 || dims.k.size() != 1)
+ return failure();
+
+ if (!mlir::linalg::detail::isContractionBody(
+ *genericOp.getBlock(), [](Operation *first, Operation *second) {
+ if ((isa<arith::MulFOp>(first) && isa<arith::AddFOp>(second)) ||
+ (isa<arith::MulIOp>(first) && isa<arith::AddIOp>(second)) ||
+ (isa<complex::MulOp>(first) && isa<complex::AddOp>(second)))
+ return true;
+ return false;
+ }))
+ return failure();
+
+ // Check rank of operands
+ auto indexingMaps = genericOp.getIndexingMapsArray();
+ if (llvm::any_of(indexingMaps, [&dims](AffineMap m) {
+ return m.getResults().size() !=
+ dims.batch.size() + 2 /* any two of {m,n,k} */;
+ }))
+ return failure();
+
+ auto numOfBatchDims = dims.batch.size();
+ if (indexingMaps[0].getNumDims() != numOfBatchDims + 3)
+ return failure();
+
+ if (numOfBatchDims) {
+    // Each operand in a linalg.generic contraction could express different
+    // permutations for its batch dimensions. But for a named op they must be
+    // the identity, since separate maps are not specified.
+ if (llvm::any_of(indexingMaps, [numOfBatchDims](AffineMap m) {
+ for (unsigned i = 0; i < numOfBatchDims; ++i) {
+ auto expr = m.getResults()[i];
+ if (expr.getKind() != AffineExprKind::DimId ||
+ cast<AffineDimExpr>(expr).getPosition() != i)
+ return true;
+ }
+ return false;
+ }))
+ return failure();
+ }
+
+ auto a =
+ matchOperandMap(indexingMaps[0], numOfBatchDims, dims.m[0], dims.k[0]);
+ auto b =
+ matchOperandMap(indexingMaps[1], numOfBatchDims, dims.k[0], dims.n[0]);
+ auto c =
+ matchOperandMap(indexingMaps[2], numOfBatchDims, dims.m[0], dims.n[0]);
+
+ if (llvm::any_of(ArrayRef<IndexMatchResult>{a, b, c}, [](IndexMatchResult r) {
+ return r == IndexMatchResult::Mismatch;
+ }))
+ return failure();
+
+ if (c != IndexMatchResult::Match ||
+ (a == IndexMatchResult::Transposed && b == IndexMatchResult::Transposed))
+ return failure();
+
+  // Codegen the different matmul variants.
+ if (numOfBatchDims) {
+ if (a == IndexMatchResult::Transposed)
+ return replaceWithMatmulVariant<BatchMatmulTransposeAOp>(rewriter,
+ genericOp);
+ if (b == IndexMatchResult::Transposed)
+ return replaceWithMatmulVariant<BatchMatmulTransposeBOp>(rewriter,
+ genericOp);
+ return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp);
+ }
+
+ if (a == IndexMatchResult::Transposed)
+ return replaceWithMatmulVariant<MatmulTransposeAOp>(rewriter, genericOp);
+ if (b == IndexMatchResult::Transposed)
+ return replaceWithMatmulVariant<MatmulTransposeBOp>(rewriter, genericOp);
+ return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
+}
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Categorize a linalg.generic and rewrite it as a named op where possible.
+//===----------------------------------------------------------------------===//
FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
GenericOp genericOp) {
if (isaCopyOpInterface(genericOp)) {
@@ -100,5 +301,33 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
return namedOp;
}
}
+
+ if (isaContractionOpInterface(genericOp)) {
+ return specializeLinalgContractions(rewriter, genericOp);
+ }
return failure();
}
+
+namespace {
+struct LinalgSpecializeGenericOpsPass
+ : public impl::LinalgSpecializeGenericOpsPassBase<
+ LinalgSpecializeGenericOpsPass> {
+
+ using impl::LinalgSpecializeGenericOpsPassBase<
+ LinalgSpecializeGenericOpsPass>::LinalgSpecializeGenericOpsPassBase;
+ void runOnOperation() override;
+};
+} // namespace
+
+void LinalgSpecializeGenericOpsPass::runOnOperation() {
+ RewritePatternSet patterns(&getContext());
+ populateLinalgGenericOpsSpecializationPatterns(patterns);
+
+ if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+ signalPassFailure();
+}
+
+void mlir::linalg::populateLinalgGenericOpsSpecializationPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<LinalgSpecializationPattern>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Linalg/roundtrip-linalg-named-ops.mlir b/mlir/test/Dialect/Linalg/roundtrip-linalg-named-ops.mlir
new file mode 100644
index 0000000000000..1fb520c5982e6
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/roundtrip-linalg-named-ops.mlir
@@ -0,0 +1,52 @@
+// The following tests check examples of linalg named ops lowered to
+// linalg.generic and then lifted back up to the named op.
+// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s
+
+func.func @unary_exp(%A: memref<7x14x21xf32>, %Out: memref<7x14x21xf32>) {
+ linalg.exp ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+ return
+}
+
+// CHECK-LABEL: unary_exp
+// CHECK-SAME: %[[A:.+]]: memref<7x14x21xf32>, %[[Out:.+]]: memref<7x14x21xf32>)
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.exp ins(%[[A]] : memref<7x14x21xf32>) outs(%[[Out]] : memref<7x14x21xf32>)
+
+// -----
+
+func.func @binary_add(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.add ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: binary_add
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.add ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// -----
+
+func.func @matmul(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: @matmul
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// -----
+
+func.func @mixed_named_ops(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+ %C: tensor<?x?xf32>, %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %AB = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %1 = linalg.add ins(%AB, %C : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: @mixed_named_ops
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[C:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: %[[AB:.+]] = linalg.matmul ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: linalg.add ins(%[[AB]], %[[C]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
new file mode 100644
index 0000000000000..cf495a7d29b70
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -0,0 +1,126 @@
+// RUN: mlir-opt %s -split-input-file --linalg-specialize-generic-ops | FileCheck %s
+
+#umap = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func.func @unary_op_exp(%A: tensor<?x?x?xf32>, %Out: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+ ins(%A : tensor<?x?x?xf32>) outs(%Out : tensor<?x?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %1 = math.exp %in : f32
+ linalg.yield %1 : f32
+ } -> tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
+}
+
+// CHECK-LABEL: unary_op_exp
+// CHECK-SAME: %[[A:.+]]: tensor<?x?x?xf32>, %[[Out:.+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.exp ins(%[[A]] : tensor<?x?x?xf32>) outs(%[[Out]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @binary_op_div(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]}
+ ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.divf %in, %in_0 : f32
+ linalg.yield %1 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: binary_op_div
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.div ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @op_matmul(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.mulf %in, %in_0 : f32
+ %2 = arith.addf %out, %1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: op_matmul
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+func.func @op_batch_matmul(%A: tensor<2x16x8xf32>, %B: tensor<2x8x16xf32>, %Out: tensor<2x16x16xf32>) -> tensor<2x16x16xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+ ins(%A, %B : tensor<2x16x8xf32>, tensor<2x8x16xf32>) outs(%Out : tensor<2x16x16xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.mulf %in, %in_0 : f32
+ %2 = arith.addf %out, %1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<2x16x16xf32>
+ return %0 : tensor<2x16x16xf32>
+}
+
+// CHECK-LABEL: op_batch_matmul
+// CHECK-SAME: %[[A:.+]]: tensor<2x16x8xf32>, %[[B:.+]]: tensor<2x8x16xf32>, %[[Out:.+]]: tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.batch_matmul ins(%[[A]], %[[B]] : tensor<2x16x8xf32>, tensor<2x8x16xf32>) outs(%[[Out]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
+
+// -----
+
+// This is a multi-reduction linalg.generic and cannot be lifted to a matrix multiply.
+#mapA = affine_map<(m, n, k1, k2) -> (m, k1, k2)>
+#mapB = affine_map<(m, n, k1, k2) -> (k2, k1, n)>
+#mapC = affine_map<(m, n, k1, k2) -> (m, n)>
+func.func @negative_op_multi_reduction(%A: tensor<10x20x30xf32>,
+ %B: tensor<30x20x40xf32>,
+ %C: tensor<10x40xf32>) -> tensor<10x40xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#mapA, #mapB, #mapC],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ ins(%A, %B : tensor<10x20x30xf32>, tensor<30x20x40xf32>)
+ outs(%C : tensor<10x40xf32>) {
+ ^bb0(%a: f32, %b: f32, %c: f32):
+ %1 = arith.mulf %a, %b : f32
+ %2 = arith.addf %c, %1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<10x40xf32>
+ return %0 : tensor<10x40xf32>
+}
+
+// CHECK-LABEL: negative_op_multi_reduction
+// CHECK: linalg.generic
+
+// -----
+
+// TODO: matvec
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1)>
+#map2 = affine_map<(d0, d1) -> (d0)>
+func.func @op_matvec(%A: tensor<?x?xf32>, %B: tensor<?xf32>, %Out: tensor<?xf32>) -> tensor<?xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction"]}
+ ins(%A, %B : tensor<?x?xf32>, tensor<?xf32>) outs(%Out : tensor<?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.mulf %in, %in_0 : f32
+ %2 = arith.addf %out, %1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
+// CHECK-LABEL: op_matvec
+// CHECK: linalg.generic
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_binary.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize-elemwise-binary.mlir
similarity index 100%
rename from mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_binary.mlir
rename to mlir/test/Dialect/Linalg/transform-op-specialize-elemwise-binary.mlir
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_unary.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize-elemwise-unary.mlir
similarity index 100%
rename from mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_unary.mlir
rename to mlir/test/Dialect/Linalg/transform-op-specialize-elemwise-unary.mlir
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize-matmul.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize-matmul.mlir
new file mode 100644
index 0000000000000..f64953bceefe1
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize-matmul.mlir
@@ -0,0 +1,148 @@
+// RUN: mlir-opt --transform-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @specialize_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %0 = arith.mulf %in, %in_0 : f32
+ %1 = arith.addf %out, %0 : f32
+ linalg.yield %1 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: @specialize_matmul
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d2, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @matmul_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+ linalg.generic
+ {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>) outs(%arg2 : memref<3x7xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %0 = arith.mulf %in, %in_0 : f32
+ %1 = arith.addf %out, %0 : f32
+ linalg.yield %1 : f32
+ }
+ return
+}
+
+// CHECK-LABEL: @matmul_transpose_a
+// CHECK-SAME: %[[ARG0:.+]]: memref<5x3xf32>, %[[ARG1:.+]]: memref<5x7xf32>, %[[ARG2:.+]]: memref<3x7xf32>) {
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul_transpose_a ins(%[[ARG0]], %[[ARG1]] : memref<5x3xf32>, memref<5x7xf32>) outs(%[[ARG2]] : memref<3x7xf32>)
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @matmul_transpose_b(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.mulf %in, %in_0 : f32
+ %2 = arith.addf %out, %1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: @matmul_transpose_b
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul_transpose_b ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+func.func @batch_matmul(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x16xf32>, %arg2: tensor<2x16x16xf32>) -> tensor<2x16x16xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+ ins(%arg0, %arg1 : tensor<2x16x8xf32>, tensor<2x8x16xf32>) outs(%arg2 : tensor<2x16x16xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.mulf %in, %in_0 : f32
+ %2 = arith.addf %out, %1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<2x16x16xf32>
+ return %0 : tensor<2x16x16xf32>
+}
+
+// CHECK-LABEL: @batch_matmul
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x16x8xf32>, %[[ARG1:.+]]: tensor<2x8x16xf32>, %[[ARG2:.+]]: tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.batch_matmul ins(%[[ARG0]], %[[ARG1]] : tensor<2x16x8xf32>, tensor<2x8x16xf32>) outs(%[[ARG2]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+func.func @batch_matmul_transpose_b(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+ ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%arg2 : tensor<?x?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.mulf %in, %in_0 : f32
+ %2 = arith.addf %out, %1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
+}
+
+// CHECK-LABEL: @batch_matmul_transpose_b
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?xf32>, %[[ARG2:.+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.batch_matmul_transpose_b ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[ARG2]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}