[Mlir-commits] [mlir] implement canonicalizer for batched linalg operations (PR #95710)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Jun 16 09:16:57 PDT 2024
https://github.com/srcarroll created https://github.com/llvm/llvm-project/pull/95710
None
>From 0418e51cf33bc59cc6f19ed00edc8c2d62e4d9df Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sat, 15 Jun 2024 10:46:44 -0500
Subject: [PATCH] implement canonicalizer for batched linalg operations
---
.../Linalg/IR/LinalgNamedStructuredOps.yaml | 49 ++-----
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 121 ++++++++++++++++++
.../linalg/opdsl/ops/core_named_ops.py | 5 +
3 files changed, 138 insertions(+), 37 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index fad234a9dcae9..3cbfb58ed8506 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -304,41 +304,6 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: I
--- !LinalgOpConfig
-metadata: !LinalgOpMetadata
- name: reciprocal
- cpp_class_name: ReciprocalOp
- doc: |-
- Applies reciprocal(x) elementwise.
-
- No numeric casting is performed on the input operand.
-structured_op: !LinalgStructuredOpConfig
- args:
- - !LinalgOperandDefConfig
- name: I
- kind: input_tensor
- type_var: T1
- shape_map: affine_map<() -> ()>
- - !LinalgOperandDefConfig
- name: O
- kind: output_tensor
- type_var: T1
- shape_map: affine_map<() -> ()>
- indexing_maps: !LinalgIndexingMapsConfig
- static_indexing_maps:
- - affine_map<() -> ()>
- - affine_map<() -> ()>
- iterator_types: []
- assignments:
- - !ScalarAssign
- arg: O
- value: !ScalarExpression
- scalar_fn:
- kind: unary
- fn_name: reciprocal
- operands:
- - !ScalarExpression
- scalar_arg: I
---- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: round
cpp_class_name: RoundOp
@@ -516,7 +481,7 @@ structured_op: !LinalgStructuredOpConfig
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: erf
- cpp_class_name: erfOp
+ cpp_class_name: ErfOp
doc: |-
Applies erf(x) elementwise.
@@ -959,7 +924,7 @@ structured_op: !LinalgStructuredOpConfig
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: powf
- cpp_class_name: PowFOp
+ cpp_class_name: PowfOp
doc: |-
Takes the powf(lhs, rhs) between two inputs, elementwise. For powf(arg, 2) use `linalg.square`.
@@ -1622,6 +1587,8 @@ metadata: !LinalgOpMetadata
them to the same data type as the accumulator/output.
implements:
- LinalgContractionOpInterface
+ defines:
+ - hasCanonicalizer
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
@@ -1692,6 +1659,8 @@ metadata: !LinalgOpMetadata
them to the same data type as the accumulator/output.
implements:
- LinalgContractionOpInterface
+ defines:
+ - hasCanonicalizer
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
@@ -1762,6 +1731,8 @@ metadata: !LinalgOpMetadata
them to the same data type as the accumulator/output.
implements:
- LinalgContractionOpInterface
+ defines:
+ - hasCanonicalizer
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
@@ -2140,6 +2111,8 @@ metadata: !LinalgOpMetadata
them to the same data type as the accumulator/output.
implements:
- LinalgContractionOpInterface
+ defines:
+ - hasCanonicalizer
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
@@ -2208,6 +2181,8 @@ metadata: !LinalgOpMetadata
them to the same data type as the accumulator/output.
implements:
- LinalgContractionOpInterface
+ defines:
+ - hasCanonicalizer
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b79afebfa8158..ecd669165efc7 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
@@ -42,6 +43,7 @@
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
+#include <numeric>
#include <optional>
using namespace mlir;
@@ -578,6 +580,125 @@ class RegionBuilderHelper {
} // namespace
+//===----------------------------------------------------------------------===//
+// BatchMatmulOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Canonicalize a batched linalg contraction whose leading batch dimension is
+/// statically 1 into its non-batched equivalent: collapse the unit batch dim
+/// of every operand, emit the non-batched op, and (for tensor semantics)
+/// expand the result back to the original batched type.
+template <typename BatchOpTy, typename OpTy>
+struct BatchMatmulToMatmul : OpRewritePattern<BatchOpTy> {
+  using OpRewritePattern<BatchOpTy>::OpRewritePattern;
+  LogicalResult matchAndRewrite(BatchOpTy batchMatmulOp,
+                                PatternRewriter &rewriter) const override {
+
+    auto loc = batchMatmulOp.getLoc();
+    auto inputs = batchMatmulOp.getDpsInputs();
+    auto inits = batchMatmulOp.getDpsInits();
+    if (inputs.size() != 2 || inits.size() != 1)
+      return rewriter.notifyMatchFailure(batchMatmulOp,
+                                         "expected 2 inputs and 1 init");
+    auto lhs = inputs[0];
+    auto rhs = inputs[1];
+    auto init = inits[0];
+
+    auto lhsType = cast<ShapedType>(lhs.getType());
+    auto rhsType = cast<ShapedType>(rhs.getType());
+    auto initType = cast<ShapedType>(init.getType());
+    // The collapse below reassociates dims {0, 1}; that only preserves the
+    // element count (and thus verifies) when the batch dimension is a static
+    // 1. A dynamic or non-unit batch cannot be folded away.
+    if (lhsType.getShape()[0] != 1 || rhsType.getShape()[0] != 1 ||
+        initType.getShape()[0] != 1)
+      return rewriter.notifyMatchFailure(
+          batchMatmulOp, "expected a static unit batch size on all operands");
+
+    auto results = batchMatmulOp.getResults();
+    if (results.size() > 1)
+      return rewriter.notifyMatchFailure(batchMatmulOp,
+                                         "expected at most one result");
+
+    // Result type of the collapsed op (tensor semantics only): the batched
+    // result type with the leading unit batch dimension dropped.
+    SmallVector<Type, 1> collapsedResultTy;
+    if (results.size() == 1) {
+      auto oldResultType = cast<RankedTensorType>(results[0].getType());
+      collapsedResultTy.push_back(
+          RankedTensorType::get(oldResultType.getShape().drop_front(1),
+                                oldResultType.getElementType()));
+    }
+
+    // Drop the unit batch dimension of `val` via a rank-reducing collapse,
+    // dispatching on tensor vs. memref semantics.
+    auto collapseSingletonDim = [&](Value val) -> Value {
+      SmallVector<ReassociationIndices> reassociation({{0, 1}});
+      auto valType = cast<ShapedType>(val.getType());
+      for (int64_t i = 2; i < valType.getRank(); i++)
+        reassociation.push_back({i});
+      if (isa<RankedTensorType>(valType)) {
+        RankedTensorType collapsedType = RankedTensorType::get(
+            valType.getShape().drop_front(1), valType.getElementType());
+        return rewriter.create<tensor::CollapseShapeOp>(loc, collapsedType, val,
+                                                        reassociation);
+      }
+      MemRefType collapsedType = MemRefType::get(
+          valType.getShape().drop_front(1), valType.getElementType());
+      return rewriter.create<memref::CollapseShapeOp>(loc, collapsedType, val,
+                                                      reassociation);
+    };
+
+    auto collapsedLhs = collapseSingletonDim(lhs);
+    auto collapsedRhs = collapseSingletonDim(rhs);
+    auto collapsedInit = collapseSingletonDim(init);
+
+    auto collapsedOp = rewriter.create<OpTy>(
+        loc, collapsedResultTy, ValueRange{collapsedLhs, collapsedRhs},
+        ValueRange{collapsedInit});
+    // Carry over user attributes, but drop the memoized indexing maps: they
+    // describe the batched iteration space and would be stale on the
+    // collapsed op.
+    for (auto attr : batchMatmulOp->getAttrs()) {
+      if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
+        continue;
+      collapsedOp->setAttr(attr.getName(), attr.getValue());
+    }
+
+    if (results.empty()) {
+      // Memref semantics: there is no result value to reconcile.
+      rewriter.replaceOp(batchMatmulOp, collapsedOp);
+    } else {
+      // Tensor semantics: expand the collapsed result back to the original
+      // batched type so existing uses remain valid.
+      SmallVector<ReassociationIndices> reassociation({{0, 1}});
+      auto origResultType = cast<ShapedType>(results[0].getType());
+      for (int64_t i = 2; i < origResultType.getRank(); i++)
+        reassociation.push_back({i});
+      Value expandedResult = rewriter.create<tensor::ExpandShapeOp>(
+          loc, origResultType, collapsedOp.getResultTensors()[0],
+          reassociation);
+      rewriter.replaceOp(batchMatmulOp, expandedResult);
+    }
+
+    return success();
+  }
+};
+
+} // namespace
+
+/// Register the unit-batch-collapsing rewrite for linalg.batch_matmul.
+void BatchMatmulOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                MLIRContext *context) {
+  results.add<BatchMatmulToMatmul<BatchMatmulOp, MatmulOp>>(context);
+}
+
+/// Register the unit-batch-collapsing rewrite for
+/// linalg.batch_matmul_transpose_a.
+void BatchMatmulTransposeAOp::getCanonicalizationPatterns(
+    RewritePatternSet &results, MLIRContext *context) {
+  results.add<BatchMatmulToMatmul<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
+      context);
+}
+
+/// Register the unit-batch-collapsing rewrite for
+/// linalg.batch_matmul_transpose_b.
+void BatchMatmulTransposeBOp::getCanonicalizationPatterns(
+    RewritePatternSet &results, MLIRContext *context) {
+  results.add<BatchMatmulToMatmul<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
+      context);
+}
+
+/// Register the unit-batch-collapsing rewrite for linalg.batch_matvec.
+void BatchMatvecOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                MLIRContext *context) {
+  results.add<BatchMatmulToMatmul<BatchMatvecOp, MatvecOp>>(context);
+}
+
+/// Register the unit-batch-collapsing rewrite for linalg.batch_vecmat.
+void BatchVecmatOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                MLIRContext *context) {
+  results.add<BatchMatmulToMatmul<BatchVecmatOp, VecmatOp>>(context);
+}
+
//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 43410aaa6af1b..b4b36ba0bfe51 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -518,6 +518,7 @@ def batch_matmul(
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
"""
+ defines(Canonicalizer)
domain(D.b, D.m, D.n, D.k)
implements(ContractionOpInterface)
C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
@@ -537,6 +538,7 @@ def batch_matmul_transpose_a(
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
"""
+ defines(Canonicalizer)
domain(D.b, D.m, D.n, D.k)
implements(ContractionOpInterface)
C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.k, D.m]) * TypeFn.cast_signed(
@@ -556,6 +558,7 @@ def batch_matmul_transpose_b(
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
"""
+ defines(Canonicalizer)
domain(D.b, D.m, D.n, D.k)
implements(ContractionOpInterface)
C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
@@ -642,6 +645,7 @@ def batch_matvec(
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
"""
+ defines(Canonicalizer)
domain(D.b, D.m, D.k)
implements(ContractionOpInterface)
C[D.b, D.m] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
@@ -660,6 +664,7 @@ def batch_vecmat(
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
"""
+ defines(Canonicalizer)
domain(D.b, D.n, D.k)
implements(ContractionOpInterface)
C[D.b, D.n] += TypeFn.cast_signed(U, A[D.b, D.k]) * TypeFn.cast_signed(
More information about the Mlir-commits
mailing list