[Mlir-commits] [mlir] [mlir][linalg] Implement patterns for reducing rank of named linalg contraction ops (PR #95710)
llvmlistbot at llvm.org
Fri Jun 21 10:30:33 PDT 2024
https://github.com/srcarroll updated https://github.com/llvm/llvm-project/pull/95710
>From 0418e51cf33bc59cc6f19ed00edc8c2d62e4d9df Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sat, 15 Jun 2024 10:46:44 -0500
Subject: [PATCH 01/14] implement canonicalizer for batched linalg operations
---
.../Linalg/IR/LinalgNamedStructuredOps.yaml | 49 ++-----
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 121 ++++++++++++++++++
.../linalg/opdsl/ops/core_named_ops.py | 5 +
3 files changed, 138 insertions(+), 37 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index fad234a9dcae9..3cbfb58ed8506 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -304,41 +304,6 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: I
--- !LinalgOpConfig
-metadata: !LinalgOpMetadata
- name: reciprocal
- cpp_class_name: ReciprocalOp
- doc: |-
- Applies reciprocal(x) elementwise.
-
- No numeric casting is performed on the input operand.
-structured_op: !LinalgStructuredOpConfig
- args:
- - !LinalgOperandDefConfig
- name: I
- kind: input_tensor
- type_var: T1
- shape_map: affine_map<() -> ()>
- - !LinalgOperandDefConfig
- name: O
- kind: output_tensor
- type_var: T1
- shape_map: affine_map<() -> ()>
- indexing_maps: !LinalgIndexingMapsConfig
- static_indexing_maps:
- - affine_map<() -> ()>
- - affine_map<() -> ()>
- iterator_types: []
- assignments:
- - !ScalarAssign
- arg: O
- value: !ScalarExpression
- scalar_fn:
- kind: unary
- fn_name: reciprocal
- operands:
- - !ScalarExpression
- scalar_arg: I
---- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: round
cpp_class_name: RoundOp
@@ -516,7 +481,7 @@ structured_op: !LinalgStructuredOpConfig
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: erf
- cpp_class_name: erfOp
+ cpp_class_name: ErfOp
doc: |-
Applies erf(x) elementwise.
@@ -959,7 +924,7 @@ structured_op: !LinalgStructuredOpConfig
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: powf
- cpp_class_name: PowFOp
+ cpp_class_name: PowfOp
doc: |-
Takes the powf(lhs, rhs) between two inputs, elementwise. For powf(arg, 2) use `linalg.square`.
@@ -1622,6 +1587,8 @@ metadata: !LinalgOpMetadata
them to the same data type as the accumulator/output.
implements:
- LinalgContractionOpInterface
+ defines:
+ - hasCanonicalizer
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
@@ -1692,6 +1659,8 @@ metadata: !LinalgOpMetadata
them to the same data type as the accumulator/output.
implements:
- LinalgContractionOpInterface
+ defines:
+ - hasCanonicalizer
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
@@ -1762,6 +1731,8 @@ metadata: !LinalgOpMetadata
them to the same data type as the accumulator/output.
implements:
- LinalgContractionOpInterface
+ defines:
+ - hasCanonicalizer
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
@@ -2140,6 +2111,8 @@ metadata: !LinalgOpMetadata
them to the same data type as the accumulator/output.
implements:
- LinalgContractionOpInterface
+ defines:
+ - hasCanonicalizer
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
@@ -2208,6 +2181,8 @@ metadata: !LinalgOpMetadata
them to the same data type as the accumulator/output.
implements:
- LinalgContractionOpInterface
+ defines:
+ - hasCanonicalizer
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b79afebfa8158..ecd669165efc7 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
@@ -42,6 +43,7 @@
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
+#include <numeric>
#include <optional>
using namespace mlir;
@@ -578,6 +580,125 @@ class RegionBuilderHelper {
} // namespace
+//===----------------------------------------------------------------------===//
+// BatchMatmulOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+template <typename BatchOpTy, typename OpTy>
+struct BatchMatmulToMatmul : OpRewritePattern<BatchOpTy> {
+ using OpRewritePattern<BatchOpTy>::OpRewritePattern;
+ LogicalResult matchAndRewrite(BatchOpTy batchMatmulOp,
+ PatternRewriter &rewriter) const override {
+
+ auto loc = batchMatmulOp.getLoc();
+ auto inputs = batchMatmulOp.getDpsInputs();
+ auto inits = batchMatmulOp.getDpsInits();
+ if (inputs.size() != 2 || inits.size() != 1)
+ return rewriter.notifyMatchFailure(batchMatmulOp,
+ "expected 2 inputs and 1 init");
+ auto lhs = inputs[0];
+ auto rhs = inputs[1];
+ auto init = inits[0];
+
+ auto lhsType = cast<ShapedType>(lhs.getType());
+ auto rhsType = cast<ShapedType>(rhs.getType());
+ auto initType = cast<ShapedType>(init.getType());
+ if (ShapedType::isDynamic(lhsType.getShape()[0]) ||
+ lhsType.getShape()[0] != rhsType.getShape()[0] ||
+ rhsType.getShape()[0] != initType.getShape()[0])
+ return rewriter.notifyMatchFailure(
+ batchMatmulOp, "expected batch sizes of all operands to be same");
+
+ auto results = batchMatmulOp.getResults();
+ if (results.size() > 1)
+ return rewriter.notifyMatchFailure(batchMatmulOp,
+ "expected at most one result");
+
+ SmallVector<Type, 1> resultType;
+ if (results.size() == 1) {
+ auto oldResultType = cast<RankedTensorType>(results[0].getType());
+ resultType.push_back(
+ RankedTensorType::get(oldResultType.getShape().drop_front(1),
+ oldResultType.getElementType()));
+ }
+
+ auto collapseSingletonDim = [&](Value val) -> Value {
+ SmallVector<ReassociationIndices> reassociation({{0, 1}});
+ auto valType = cast<ShapedType>(val.getType());
+ for (auto i = 2; i < valType.getRank(); i++)
+ reassociation.push_back({i});
+ if (isa<RankedTensorType>(valType)) {
+ RankedTensorType collapsedType = RankedTensorType::get(
+ valType.getShape().drop_front(1), valType.getElementType());
+ return rewriter.create<tensor::CollapseShapeOp>(loc, collapsedType, val,
+ reassociation);
+ }
+ MemRefType collapsedType = MemRefType::get(
+ valType.getShape().drop_front(1), valType.getElementType());
+ return rewriter.create<memref::CollapseShapeOp>(loc, collapsedType, val,
+ reassociation);
+ };
+
+ auto collapsedLhs = collapseSingletonDim(lhs);
+ auto collapsedRhs = collapseSingletonDim(rhs);
+ auto collapsedInit = collapseSingletonDim(init);
+
+ auto collapsedOp = rewriter.create<OpTy>(
+ loc, resultType, ValueRange{collapsedLhs, collapsedRhs},
+ ValueRange{collapsedInit});
+ for (auto attr : batchMatmulOp->getAttrs()) {
+ if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
+ continue;
+ collapsedOp->setAttr(attr.getName(), attr.getValue());
+ }
+
+ if (results.size() < 1) {
+ rewriter.replaceOp(batchMatmulOp, collapsedOp);
+ } else {
+ SmallVector<ReassociationIndices> reassociation({{0, 1}});
+ auto resultType = cast<ShapedType>(results[0].getType());
+ for (auto i = 2; i < resultType.getRank(); i++)
+ reassociation.push_back({i});
+ Value expandedResult = rewriter.create<tensor::ExpandShapeOp>(
+ loc, resultType, collapsedOp.getResultTensors()[0], reassociation);
+ rewriter.replaceOp(batchMatmulOp, expandedResult);
+ }
+
+ return success();
+ }
+};
+
+} // namespace
+
+void BatchMatmulOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<BatchMatmulToMatmul<BatchMatmulOp, MatmulOp>>(context);
+}
+
+void BatchMatmulTransposeAOp::getCanonicalizationPatterns(
+ RewritePatternSet &results, MLIRContext *context) {
+ results.add<BatchMatmulToMatmul<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
+ context);
+}
+
+void BatchMatmulTransposeBOp::getCanonicalizationPatterns(
+ RewritePatternSet &results, MLIRContext *context) {
+ results.add<BatchMatmulToMatmul<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
+ context);
+}
+
+void BatchMatvecOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<BatchMatmulToMatmul<BatchMatvecOp, MatvecOp>>(context);
+}
+
+void BatchVecmatOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<BatchMatmulToMatmul<BatchVecmatOp, VecmatOp>>(context);
+}
+
//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//
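For reference, the rewrite above turns a unit-batch contraction into a
collapse/compute/expand sequence. A sketch of the intended IR on tensors
(static shapes chosen for readability; the tests added in a later commit
exercise the dynamic-shape case):
  %0 = linalg.batch_matmul ins(%lhs, %rhs : tensor<1x4x8xf32>, tensor<1x8x16xf32>)
                           outs(%init : tensor<1x4x16xf32>) -> tensor<1x4x16xf32>
becomes, schematically:
  %clhs = tensor.collapse_shape %lhs [[0, 1], [2]] : tensor<1x4x8xf32> into tensor<4x8xf32>
  %crhs = tensor.collapse_shape %rhs [[0, 1], [2]] : tensor<1x8x16xf32> into tensor<8x16xf32>
  %cinit = tensor.collapse_shape %init [[0, 1], [2]] : tensor<1x4x16xf32> into tensor<4x16xf32>
  %mm = linalg.matmul ins(%clhs, %crhs : tensor<4x8xf32>, tensor<8x16xf32>)
                      outs(%cinit : tensor<4x16xf32>) -> tensor<4x16xf32>
  %0 = tensor.expand_shape %mm [[0, 1], [2]] output_shape [1, 4, 16] : tensor<4x16xf32> into tensor<1x4x16xf32>
On memrefs the same collapses are emitted with memref.collapse_shape, and no
expand_shape is needed since the op has no results.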
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 43410aaa6af1b..b4b36ba0bfe51 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -518,6 +518,7 @@ def batch_matmul(
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
"""
+ defines(Canonicalizer)
domain(D.b, D.m, D.n, D.k)
implements(ContractionOpInterface)
C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
@@ -537,6 +538,7 @@ def batch_matmul_transpose_a(
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
"""
+ defines(Canonicalizer)
domain(D.b, D.m, D.n, D.k)
implements(ContractionOpInterface)
C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.k, D.m]) * TypeFn.cast_signed(
@@ -556,6 +558,7 @@ def batch_matmul_transpose_b(
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
"""
+ defines(Canonicalizer)
domain(D.b, D.m, D.n, D.k)
implements(ContractionOpInterface)
C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
@@ -642,6 +645,7 @@ def batch_matvec(
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
"""
+ defines(Canonicalizer)
domain(D.b, D.m, D.k)
implements(ContractionOpInterface)
C[D.b, D.m] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
@@ -660,6 +664,7 @@ def batch_vecmat(
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
"""
+ defines(Canonicalizer)
domain(D.b, D.n, D.k)
implements(ContractionOpInterface)
C[D.b, D.n] += TypeFn.cast_signed(U, A[D.b, D.k]) * TypeFn.cast_signed(
>From 02b2ca083d145fc88a9498480e4a831affdebf10 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sun, 16 Jun 2024 12:01:33 -0500
Subject: [PATCH 02/14] add tests
---
.../Linalg/IR/LinalgNamedStructuredOps.yaml | 34 +++++
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 12 +-
mlir/test/Dialect/Linalg/canonicalize.mlir | 137 +++++++++++++++++-
3 files changed, 174 insertions(+), 9 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 3cbfb58ed8506..41f90483c93b3 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -304,6 +304,40 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: I
--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+ name: reciprocal
+ cpp_class_name: ReciprocalOp
+ doc: |-
+ Applies reciprocal(x) elementwise.
+ No numeric casting is performed on the input operand.
+structured_op: !LinalgStructuredOpConfig
+ args:
+ - !LinalgOperandDefConfig
+ name: I
+ kind: input_tensor
+ type_var: T1
+ shape_map: affine_map<() -> ()>
+ - !LinalgOperandDefConfig
+ name: O
+ kind: output_tensor
+ type_var: T1
+ shape_map: affine_map<() -> ()>
+ indexing_maps: !LinalgIndexingMapsConfig
+ static_indexing_maps:
+ - affine_map<() -> ()>
+ - affine_map<() -> ()>
+ iterator_types: []
+ assignments:
+ - !ScalarAssign
+ arg: O
+ value: !ScalarExpression
+ scalar_fn:
+ kind: unary
+ fn_name: reciprocal
+ operands:
+ - !ScalarExpression
+ scalar_arg: I
+--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: round
cpp_class_name: RoundOp
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index ecd669165efc7..4e47b6018c445 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -605,16 +605,12 @@ struct BatchMatmulToMatmul : OpRewritePattern<BatchOpTy> {
auto lhsType = cast<ShapedType>(lhs.getType());
auto rhsType = cast<ShapedType>(rhs.getType());
auto initType = cast<ShapedType>(init.getType());
- if (ShapedType::isDynamic(lhsType.getShape()[0]) ||
- lhsType.getShape()[0] != rhsType.getShape()[0] ||
- rhsType.getShape()[0] != initType.getShape()[0])
- return rewriter.notifyMatchFailure(
- batchMatmulOp, "expected batch sizes of all operands to be same");
+ if (lhsType.getShape()[0] != 1 || rhsType.getShape()[0] != 1 ||
+ initType.getShape()[0] != 1)
+ return rewriter.notifyMatchFailure(batchMatmulOp, "batch size is not 1");
auto results = batchMatmulOp.getResults();
- if (results.size() > 1)
- return rewriter.notifyMatchFailure(batchMatmulOp,
- "expected at most one result");
+ assert(results.size() < 2 && "expected at most one result");
SmallVector<Type, 1> resultType;
if (results.size() == 1) {
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 928030a81dc02..8514bcb089891 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1017,7 +1017,7 @@ func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>)
return %0 : tensor<2x3xf32>
}
-// ----
+// -----
func.func @transpose_1d(%input: tensor<16xf32>,
%init: tensor<16xf32>) -> tensor<16xf32> {
@@ -1096,3 +1096,138 @@ func.func @transpose_transpose_fold(%input: tensor<5x4x3xf32>,
func.return %transpose2 : tensor<3x4x5xf32>
}
+// -----
+
+func.func @singleton_batch_matmul_tensor(%arg0 : tensor<1x?x?xf32>, %arg1 : tensor<1x?x?xf32>, %arg2: tensor<1x?x?xf32>) -> tensor<1x?x?xf32> {
+ // CHECK-LABEL: @singleton_batch_matmul_tensor
+ // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
+ // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
+ // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1
+ // CHECK-DAG: %[[C2:.*]] = arith.constant 2
+ // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+ // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+ // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
+ // CHECK-NEXT: %[[MATMUL:.+]] = linalg.matmul ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?x?xf32>)
+ // CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
+ // CHECK-NEXT: %[[DIM2:.*]] = tensor.dim %[[INIT]], %[[C2]]
+ // CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1], [2]] output_shape [1, %[[DIM1]], %[[DIM2]]]
+ // CHECK-NEXT: return %[[RES]]
+ %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x?x?xf32>, tensor<1x?x?xf32>)
+ outs(%arg2 : tensor<1x?x?xf32>) -> tensor<1x?x?xf32>
+ return %1 : tensor<1x?x?xf32>
+}
+
+// -----
+
+func.func @singleton_batch_matmul_memref(%arg0 : memref<1x?x?xf32>, %arg1 : memref<1x?x?xf32>, %arg2: memref<1x?x?xf32>) {
+  // CHECK-LABEL: @singleton_batch_matmul_memref
+ // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: memref<1x?x?xf32>
+ // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: memref<1x?x?xf32>
+ // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: memref<1x?x?xf32>
+ // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+ // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = memref.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+ // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
+ // CHECK-NEXT: linalg.matmul ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : memref<?x?xf32>, memref<?x?xf32>) outs(%[[COLLAPSED_INIT]] : memref<?x?xf32>)
+ // CHECK-NEXT: return
+ linalg.batch_matmul ins(%arg0, %arg1 : memref<1x?x?xf32>, memref<1x?x?xf32>)
+ outs(%arg2 : memref<1x?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @singleton_batch_matvec(%arg0 : tensor<1x?x?xf32>, %arg1 : tensor<1x?xf32>, %arg2: tensor<1x?xf32>) -> tensor<1x?xf32> {
+  // CHECK-LABEL: @singleton_batch_matvec
+ // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
+ // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: tensor<1x?xf32>
+ // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?xf32>
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1
+ // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+ // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1]]
+ // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
+ // CHECK-NEXT: %[[MATMUL:.+]] = linalg.matvec ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
+ // CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
+ // CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]]
+ // CHECK-NEXT: return %[[RES]]
+ %1 = linalg.batch_matvec ins(%arg0, %arg1 : tensor<1x?x?xf32>, tensor<1x?xf32>)
+ outs(%arg2 : tensor<1x?xf32>) -> tensor<1x?xf32>
+ return %1 : tensor<1x?xf32>
+}
+
+// -----
+
+func.func @singleton_batch_vecmat(%arg0 : tensor<1x?xf32>, %arg1 : tensor<1x?x?xf32>, %arg2: tensor<1x?xf32>) -> tensor<1x?xf32> {
+  // CHECK-LABEL: @singleton_batch_vecmat
+ // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?xf32>
+ // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
+ // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?xf32>
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1
+ // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1]]
+ // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+ // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
+ // CHECK-NEXT: %[[MATMUL:.+]] = linalg.vecmat ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
+ // CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
+ // CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]]
+ // CHECK-NEXT: return %[[RES]]
+ %1 = linalg.batch_vecmat ins(%arg0, %arg1 : tensor<1x?xf32>, tensor<1x?x?xf32>)
+ outs(%arg2 : tensor<1x?xf32>) -> tensor<1x?xf32>
+ return %1 : tensor<1x?xf32>
+}
+
+// -----
+
+func.func @singleton_batchmatmul_transpose_a(%arg0: memref<1x5x3xf32>, %arg1: memref<1x5x7xf32>, %arg2: memref<1x3x7xf32>) {
+  // CHECK-LABEL: @singleton_batchmatmul_transpose_a
+ // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: memref<1x5x3xf32>
+ // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: memref<1x5x7xf32>
+ // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: memref<1x3x7xf32>
+ // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+ // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = memref.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+ // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
+ // CHECK-NEXT: linalg.matmul_transpose_a ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : memref<5x3xf32>, memref<5x7xf32>) outs(%[[COLLAPSED_INIT]] : memref<3x7xf32>)
+ // CHECK-NEXT: return
+ linalg.batch_matmul_transpose_a ins(%arg0, %arg1 : memref<1x5x3xf32>, memref<1x5x7xf32>) outs(%arg2: memref<1x3x7xf32>)
+ return
+}
+
+// -----
+
+func.func @singleton_batchmatmul_transpose_b(%arg0: memref<1x3x5xf32>, %arg1: memref<1x7x5xf32>, %arg2: memref<1x3x7xf32>) {
+  // CHECK-LABEL: @singleton_batchmatmul_transpose_b
+ // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: memref<1x3x5xf32>
+ // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: memref<1x7x5xf32>
+ // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: memref<1x3x7xf32>
+ // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+ // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = memref.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+ // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
+ // CHECK-NEXT: linalg.matmul_transpose_b ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : memref<3x5xf32>, memref<7x5xf32>) outs(%[[COLLAPSED_INIT]] : memref<3x7xf32>)
+ // CHECK-NEXT: return
+ linalg.batch_matmul_transpose_b ins(%arg0, %arg1 : memref<1x3x5xf32>, memref<1x7x5xf32>) outs(%arg2: memref<1x3x7xf32>)
+ return
+}
+
+// -----
+
+func.func @nonsingleton_batch_matmul(%arg0 : tensor<2x?x?xf32>, %arg1 : tensor<2x?x?xf32>, %arg2: tensor<2x?x?xf32>) -> tensor<2x?x?xf32> {
+ // CHECK-LABEL: @nonsingleton_batch_matmul
+ // CHECK-NOT: collapse_shape
+ // CHECK: linalg.batch_matmul
+ // CHECK-NOT: expand_shape
+ %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<2x?x?xf32>, tensor<2x?x?xf32>)
+ outs(%arg2 : tensor<2x?x?xf32>) -> tensor<2x?x?xf32>
+ return %1 : tensor<2x?x?xf32>
+}
+
+// -----
+
+func.func @nonsingleton_batch_matmul_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+ // CHECK-LABEL: @nonsingleton_batch_matmul_dynamic
+ // CHECK-NOT: collapse_shape
+ // CHECK: linalg.batch_matmul
+ // CHECK-NOT: expand_shape
+ %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+ outs(%arg2 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ return %1 : tensor<?x?x?xf32>
+}
+
>From 543b0d643506d12c658e2984d943873ed4c8b78b Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sun, 16 Jun 2024 12:12:35 -0500
Subject: [PATCH 03/14] remove unnecessary changes
---
.../mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml | 1 +
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 2 --
2 files changed, 1 insertion(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 41f90483c93b3..3f0aa33767a75 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -309,6 +309,7 @@ metadata: !LinalgOpMetadata
cpp_class_name: ReciprocalOp
doc: |-
Applies reciprocal(x) elementwise.
+
No numeric casting is performed on the input operand.
structured_op: !LinalgStructuredOpConfig
args:
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 4e47b6018c445..8df33a107c2cb 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -17,7 +17,6 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
@@ -43,7 +42,6 @@
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
-#include <numeric>
#include <optional>
using namespace mlir;
>From 5732d87375c942306ddf5c7a6661b8123f423b1c Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Tue, 18 Jun 2024 20:28:31 -0500
Subject: [PATCH 04/14] Move patterns to a populate function and implement test
pass
---
.../Linalg/IR/LinalgNamedStructuredOps.yaml | 14 +-
.../Dialect/Linalg/Transforms/Transforms.h | 7 +
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 115 ---------------
.../Linalg/Transforms/DropUnitDims.cpp | 98 +++++++++++++
.../linalg/opdsl/ops/core_named_ops.py | 5 -
mlir/test/Dialect/Linalg/canonicalize.mlir | 137 +-----------------
mlir/test/lib/Dialect/Linalg/CMakeLists.txt | 1 +
.../TestLinalgRankReduceContractionOps.cpp | 68 +++++++++
mlir/tools/mlir-opt/mlir-opt.cpp | 2 +
9 files changed, 179 insertions(+), 268 deletions(-)
create mode 100644 mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 3f0aa33767a75..fad234a9dcae9 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -516,7 +516,7 @@ structured_op: !LinalgStructuredOpConfig
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: erf
- cpp_class_name: ErfOp
+ cpp_class_name: erfOp
doc: |-
Applies erf(x) elementwise.
@@ -959,7 +959,7 @@ structured_op: !LinalgStructuredOpConfig
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: powf
- cpp_class_name: PowfOp
+ cpp_class_name: PowFOp
doc: |-
Takes the powf(lhs, rhs) between two inputs, elementwise. For powf(arg, 2) use `linalg.square`.
@@ -1622,8 +1622,6 @@ metadata: !LinalgOpMetadata
them to the same data type as the accumulator/output.
implements:
- LinalgContractionOpInterface
- defines:
- - hasCanonicalizer
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
@@ -1694,8 +1692,6 @@ metadata: !LinalgOpMetadata
them to the same data type as the accumulator/output.
implements:
- LinalgContractionOpInterface
- defines:
- - hasCanonicalizer
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
@@ -1766,8 +1762,6 @@ metadata: !LinalgOpMetadata
them to the same data type as the accumulator/output.
implements:
- LinalgContractionOpInterface
- defines:
- - hasCanonicalizer
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
@@ -2146,8 +2140,6 @@ metadata: !LinalgOpMetadata
them to the same data type as the accumulator/output.
implements:
- LinalgContractionOpInterface
- defines:
- - hasCanonicalizer
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
@@ -2216,8 +2208,6 @@ metadata: !LinalgOpMetadata
them to the same data type as the accumulator/output.
implements:
- LinalgContractionOpInterface
- defines:
- - hasCanonicalizer
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 308ce92e35520..c49383c600a57 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1692,6 +1692,13 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
const ControlBlockPackMatmulFn &controlFn);
+/// Adds patterns that reduce the rank of named contraction ops that have unit
+/// dimensions in the operand(s) by converting to a sequence of `collapse_shape`,
+/// `<corresponding linalg named op>`, `expand_shape` (if on tensors). For example,
+/// a `linalg.batch_matmul` with unit batch size converts to `linalg.matmul`, and a
+/// `linalg.matvec` with a unit spatial dim in its lhs converts to a `linalg.dot`.
+void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns);
+
} // namespace linalg
} // namespace mlir
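To illustrate the second example in the comment above (a sketch with static
shapes; the `linalg.matvec` to `linalg.dot` reduction is added by a later
commit in this PR):
  %y = linalg.matvec ins(%A, %x : tensor<1x8xf32>, tensor<8xf32>)
                     outs(%init : tensor<1xf32>) -> tensor<1xf32>
would reduce to:
  %cA = tensor.collapse_shape %A [[0, 1]] : tensor<1x8xf32> into tensor<8xf32>
  %cinit = tensor.collapse_shape %init [] : tensor<1xf32> into tensor<f32>
  %d = linalg.dot ins(%cA, %x : tensor<8xf32>, tensor<8xf32>)
                  outs(%cinit : tensor<f32>) -> tensor<f32>
  %y = tensor.expand_shape %d [] output_shape [1] : tensor<f32> into tensor<1xf32>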
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 8df33a107c2cb..b79afebfa8158 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -578,121 +578,6 @@ class RegionBuilderHelper {
} // namespace
-//===----------------------------------------------------------------------===//
-// BatchMatmulOp
-//===----------------------------------------------------------------------===//
-
-namespace {
-
-template <typename BatchOpTy, typename OpTy>
-struct BatchMatmulToMatmul : OpRewritePattern<BatchOpTy> {
- using OpRewritePattern<BatchOpTy>::OpRewritePattern;
- LogicalResult matchAndRewrite(BatchOpTy batchMatmulOp,
- PatternRewriter &rewriter) const override {
-
- auto loc = batchMatmulOp.getLoc();
- auto inputs = batchMatmulOp.getDpsInputs();
- auto inits = batchMatmulOp.getDpsInits();
- if (inputs.size() != 2 || inits.size() != 1)
- return rewriter.notifyMatchFailure(batchMatmulOp,
- "expected 2 inputs and 1 init");
- auto lhs = inputs[0];
- auto rhs = inputs[1];
- auto init = inits[0];
-
- auto lhsType = cast<ShapedType>(lhs.getType());
- auto rhsType = cast<ShapedType>(rhs.getType());
- auto initType = cast<ShapedType>(init.getType());
- if (lhsType.getShape()[0] != 1 || rhsType.getShape()[0] != 1 ||
- initType.getShape()[0] != 1)
- return rewriter.notifyMatchFailure(batchMatmulOp, "batch size is not 1");
-
- auto results = batchMatmulOp.getResults();
- assert(results.size() < 2 && "expected at most one result");
-
- SmallVector<Type, 1> resultType;
- if (results.size() == 1) {
- auto oldResultType = cast<RankedTensorType>(results[0].getType());
- resultType.push_back(
- RankedTensorType::get(oldResultType.getShape().drop_front(1),
- oldResultType.getElementType()));
- }
-
- auto collapseSingletonDim = [&](Value val) -> Value {
- SmallVector<ReassociationIndices> reassociation({{0, 1}});
- auto valType = cast<ShapedType>(val.getType());
- for (auto i = 2; i < valType.getRank(); i++)
- reassociation.push_back({i});
- if (isa<RankedTensorType>(valType)) {
- RankedTensorType collapsedType = RankedTensorType::get(
- valType.getShape().drop_front(1), valType.getElementType());
- return rewriter.create<tensor::CollapseShapeOp>(loc, collapsedType, val,
- reassociation);
- }
- MemRefType collapsedType = MemRefType::get(
- valType.getShape().drop_front(1), valType.getElementType());
- return rewriter.create<memref::CollapseShapeOp>(loc, collapsedType, val,
- reassociation);
- };
-
- auto collapsedLhs = collapseSingletonDim(lhs);
- auto collapsedRhs = collapseSingletonDim(rhs);
- auto collapsedInit = collapseSingletonDim(init);
-
- auto collapsedOp = rewriter.create<OpTy>(
- loc, resultType, ValueRange{collapsedLhs, collapsedRhs},
- ValueRange{collapsedInit});
- for (auto attr : batchMatmulOp->getAttrs()) {
- if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
- continue;
- collapsedOp->setAttr(attr.getName(), attr.getValue());
- }
-
- if (results.size() < 1) {
- rewriter.replaceOp(batchMatmulOp, collapsedOp);
- } else {
- SmallVector<ReassociationIndices> reassociation({{0, 1}});
- auto resultType = cast<ShapedType>(results[0].getType());
- for (auto i = 2; i < resultType.getRank(); i++)
- reassociation.push_back({i});
- Value expandedResult = rewriter.create<tensor::ExpandShapeOp>(
- loc, resultType, collapsedOp.getResultTensors()[0], reassociation);
- rewriter.replaceOp(batchMatmulOp, expandedResult);
- }
-
- return success();
- }
-};
-
-} // namespace
-
-void BatchMatmulOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.add<BatchMatmulToMatmul<BatchMatmulOp, MatmulOp>>(context);
-}
-
-void BatchMatmulTransposeAOp::getCanonicalizationPatterns(
- RewritePatternSet &results, MLIRContext *context) {
- results.add<BatchMatmulToMatmul<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
- context);
-}
-
-void BatchMatmulTransposeBOp::getCanonicalizationPatterns(
- RewritePatternSet &results, MLIRContext *context) {
- results.add<BatchMatmulToMatmul<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
- context);
-}
-
-void BatchMatvecOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.add<BatchMatmulToMatmul<BatchMatvecOp, MatvecOp>>(context);
-}
-
-void BatchVecmatOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.add<BatchMatmulToMatmul<BatchVecmatOp, VecmatOp>>(context);
-}
-
//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index c0829397f1f85..9248710d5afc9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -812,6 +812,103 @@ void mlir::linalg::populateMoveInitOperandsToInputPattern(
patterns.add<MoveInitOperandsToInput>(patterns.getContext());
}
+namespace {
+
+template <typename BatchOpTy, typename OpTy>
+struct BatchMatmulToMatmul : OpRewritePattern<BatchOpTy> {
+ using OpRewritePattern<BatchOpTy>::OpRewritePattern;
+ LogicalResult matchAndRewrite(BatchOpTy batchMatmulOp,
+ PatternRewriter &rewriter) const override {
+
+ auto loc = batchMatmulOp.getLoc();
+ auto inputs = batchMatmulOp.getDpsInputs();
+ auto inits = batchMatmulOp.getDpsInits();
+ if (inputs.size() != 2 || inits.size() != 1)
+ return rewriter.notifyMatchFailure(batchMatmulOp,
+ "expected 2 inputs and 1 init");
+ auto lhs = inputs[0];
+ auto rhs = inputs[1];
+ auto init = inits[0];
+
+ auto lhsType = cast<ShapedType>(lhs.getType());
+ auto rhsType = cast<ShapedType>(rhs.getType());
+ auto initType = cast<ShapedType>(init.getType());
+ if (lhsType.getShape()[0] != 1 || rhsType.getShape()[0] != 1 ||
+ initType.getShape()[0] != 1)
+ return rewriter.notifyMatchFailure(batchMatmulOp, "batch size is not 1");
+
+ auto results = batchMatmulOp.getResults();
+ assert(results.size() < 2 && "expected at most one result");
+
+ SmallVector<Type, 1> resultType;
+ if (results.size() == 1) {
+ auto oldResultType = cast<RankedTensorType>(results[0].getType());
+ resultType.push_back(
+ RankedTensorType::get(oldResultType.getShape().drop_front(1),
+ oldResultType.getElementType()));
+ }
+
+ auto collapseSingletonDim = [&](Value val) -> Value {
+ SmallVector<ReassociationIndices> reassociation({{0, 1}});
+ auto valType = cast<ShapedType>(val.getType());
+ for (auto i = 2; i < valType.getRank(); i++)
+ reassociation.push_back({i});
+ if (isa<RankedTensorType>(valType)) {
+ RankedTensorType collapsedType = RankedTensorType::get(
+ valType.getShape().drop_front(1), valType.getElementType());
+ return rewriter.create<tensor::CollapseShapeOp>(loc, collapsedType, val,
+ reassociation);
+ }
+ MemRefType collapsedType = MemRefType::get(
+ valType.getShape().drop_front(1), valType.getElementType());
+ return rewriter.create<memref::CollapseShapeOp>(loc, collapsedType, val,
+ reassociation);
+ };
+
+ auto collapsedLhs = collapseSingletonDim(lhs);
+ auto collapsedRhs = collapseSingletonDim(rhs);
+ auto collapsedInit = collapseSingletonDim(init);
+
+ auto collapsedOp = rewriter.create<OpTy>(
+ loc, resultType, ValueRange{collapsedLhs, collapsedRhs},
+ ValueRange{collapsedInit});
+ for (auto attr : batchMatmulOp->getAttrs()) {
+ if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
+ continue;
+ collapsedOp->setAttr(attr.getName(), attr.getValue());
+ }
+
+ if (results.size() < 1) {
+ rewriter.replaceOp(batchMatmulOp, collapsedOp);
+ } else {
+ SmallVector<ReassociationIndices> reassociation({{0, 1}});
+ auto resultType = cast<ShapedType>(results[0].getType());
+ for (auto i = 2; i < resultType.getRank(); i++)
+ reassociation.push_back({i});
+ Value expandedResult = rewriter.create<tensor::ExpandShapeOp>(
+ loc, resultType, collapsedOp.getResultTensors()[0], reassociation);
+ rewriter.replaceOp(batchMatmulOp, expandedResult);
+ }
+
+ return success();
+ }
+};
+} // namespace
+
+void mlir::linalg::populateContractionOpRankReducingPatterns(
+ RewritePatternSet &patterns) {
+ MLIRContext *context = patterns.getContext();
+ patterns.add<BatchMatmulToMatmul<BatchMatmulOp, MatmulOp>>(context);
+ patterns
+ .add<BatchMatmulToMatmul<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
+ context);
+ patterns
+ .add<BatchMatmulToMatmul<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
+ context);
+ patterns.add<BatchMatmulToMatmul<BatchMatvecOp, MatvecOp>>(context);
+ patterns.add<BatchMatmulToMatmul<BatchVecmatOp, VecmatOp>>(context);
+}
+
namespace {
/// Pass that removes unit-extent dims within generic ops.
struct LinalgFoldUnitExtentDimsPass
@@ -833,4 +930,5 @@ struct LinalgFoldUnitExtentDimsPass
(void)applyPatternsAndFoldGreedily(op, std::move(patterns));
}
};
+
} // namespace
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index b4b36ba0bfe51..43410aaa6af1b 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -518,7 +518,6 @@ def batch_matmul(
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
"""
- defines(Canonicalizer)
domain(D.b, D.m, D.n, D.k)
implements(ContractionOpInterface)
C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
@@ -538,7 +537,6 @@ def batch_matmul_transpose_a(
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
"""
- defines(Canonicalizer)
domain(D.b, D.m, D.n, D.k)
implements(ContractionOpInterface)
C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.k, D.m]) * TypeFn.cast_signed(
@@ -558,7 +556,6 @@ def batch_matmul_transpose_b(
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
"""
- defines(Canonicalizer)
domain(D.b, D.m, D.n, D.k)
implements(ContractionOpInterface)
C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
@@ -645,7 +642,6 @@ def batch_matvec(
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
"""
- defines(Canonicalizer)
domain(D.b, D.m, D.k)
implements(ContractionOpInterface)
C[D.b, D.m] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
@@ -664,7 +660,6 @@ def batch_vecmat(
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
"""
- defines(Canonicalizer)
domain(D.b, D.n, D.k)
implements(ContractionOpInterface)
C[D.b, D.n] += TypeFn.cast_signed(U, A[D.b, D.k]) * TypeFn.cast_signed(
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 8514bcb089891..928030a81dc02 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1017,7 +1017,7 @@ func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>)
return %0 : tensor<2x3xf32>
}
-// -----
+// ----
func.func @transpose_1d(%input: tensor<16xf32>,
%init: tensor<16xf32>) -> tensor<16xf32> {
@@ -1096,138 +1096,3 @@ func.func @transpose_transpose_fold(%input: tensor<5x4x3xf32>,
func.return %transpose2 : tensor<3x4x5xf32>
}
-// -----
-
-func.func @singleton_batch_matmul_tensor(%arg0 : tensor<1x?x?xf32>, %arg1 : tensor<1x?x?xf32>, %arg2: tensor<1x?x?xf32>) -> tensor<1x?x?xf32> {
- // CHECK-LABEL: @singleton_batch_matmul_tensor
- // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
- // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
- // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
- // CHECK-DAG: %[[C1:.*]] = arith.constant 1
- // CHECK-DAG: %[[C2:.*]] = arith.constant 2
- // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
- // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
- // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
- // CHECK-NEXT: %[[MATMUL:.+]] = linalg.matmul ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?x?xf32>)
- // CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
- // CHECK-NEXT: %[[DIM2:.*]] = tensor.dim %[[INIT]], %[[C2]]
- // CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1], [2]] output_shape [1, %[[DIM1]], %[[DIM2]]]
- // CHECK-NEXT: return %[[RES]]
- %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x?x?xf32>, tensor<1x?x?xf32>)
- outs(%arg2 : tensor<1x?x?xf32>) -> tensor<1x?x?xf32>
- return %1 : tensor<1x?x?xf32>
-}
-
-// -----
-
-func.func @singleton_batch_matmul_memref(%arg0 : memref<1x?x?xf32>, %arg1 : memref<1x?x?xf32>, %arg2: memref<1x?x?xf32>) {
-  // CHECK-LABEL: @singleton_batch_matmul_memref
- // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: memref<1x?x?xf32>
- // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: memref<1x?x?xf32>
- // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: memref<1x?x?xf32>
- // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
- // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = memref.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
- // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
- // CHECK-NEXT: linalg.matmul ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : memref<?x?xf32>, memref<?x?xf32>) outs(%[[COLLAPSED_INIT]] : memref<?x?xf32>)
- // CHECK-NEXT: return
- linalg.batch_matmul ins(%arg0, %arg1 : memref<1x?x?xf32>, memref<1x?x?xf32>)
- outs(%arg2 : memref<1x?x?xf32>)
- return
-}
-
-// -----
-
-func.func @singleton_batch_matvec(%arg0 : tensor<1x?x?xf32>, %arg1 : tensor<1x?xf32>, %arg2: tensor<1x?xf32>) -> tensor<1x?xf32> {
-  // CHECK-LABEL: @singleton_batch_matvec
- // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
- // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: tensor<1x?xf32>
- // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?xf32>
- // CHECK-DAG: %[[C1:.*]] = arith.constant 1
- // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
- // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1]]
- // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
- // CHECK-NEXT: %[[MATMUL:.+]] = linalg.matvec ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
- // CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
- // CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]]
- // CHECK-NEXT: return %[[RES]]
- %1 = linalg.batch_matvec ins(%arg0, %arg1 : tensor<1x?x?xf32>, tensor<1x?xf32>)
- outs(%arg2 : tensor<1x?xf32>) -> tensor<1x?xf32>
- return %1 : tensor<1x?xf32>
-}
-
-// -----
-
-func.func @singleton_batch_vecmat(%arg0 : tensor<1x?xf32>, %arg1 : tensor<1x?x?xf32>, %arg2: tensor<1x?xf32>) -> tensor<1x?xf32> {
-  // CHECK-LABEL: @singleton_batch_vecmat
- // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?xf32>
- // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
- // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?xf32>
- // CHECK-DAG: %[[C1:.*]] = arith.constant 1
- // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1]]
- // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
- // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
- // CHECK-NEXT: %[[MATMUL:.+]] = linalg.vecmat ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
- // CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
- // CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]]
- // CHECK-NEXT: return %[[RES]]
- %1 = linalg.batch_vecmat ins(%arg0, %arg1 : tensor<1x?xf32>, tensor<1x?x?xf32>)
- outs(%arg2 : tensor<1x?xf32>) -> tensor<1x?xf32>
- return %1 : tensor<1x?xf32>
-}
-
-// -----
-
-func.func @singleton_batchmatmul_transpose_a(%arg0: memref<1x5x3xf32>, %arg1: memref<1x5x7xf32>, %arg2: memref<1x3x7xf32>) {
-  // CHECK-LABEL: @singleton_batchmatmul_transpose_a
- // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: memref<1x5x3xf32>
- // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: memref<1x5x7xf32>
- // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: memref<1x3x7xf32>
- // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
- // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = memref.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
- // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
- // CHECK-NEXT: linalg.matmul_transpose_a ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : memref<5x3xf32>, memref<5x7xf32>) outs(%[[COLLAPSED_INIT]] : memref<3x7xf32>)
- // CHECK-NEXT: return
- linalg.batch_matmul_transpose_a ins(%arg0, %arg1 : memref<1x5x3xf32>, memref<1x5x7xf32>) outs(%arg2: memref<1x3x7xf32>)
- return
-}
-
-// -----
-
-func.func @singleton_batchmatmul_transpose_b(%arg0: memref<1x3x5xf32>, %arg1: memref<1x7x5xf32>, %arg2: memref<1x3x7xf32>) {
-  // CHECK-LABEL: @singleton_batchmatmul_transpose_b
- // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: memref<1x3x5xf32>
- // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: memref<1x7x5xf32>
- // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: memref<1x3x7xf32>
- // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
- // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = memref.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
- // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
- // CHECK-NEXT: linalg.matmul_transpose_b ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : memref<3x5xf32>, memref<7x5xf32>) outs(%[[COLLAPSED_INIT]] : memref<3x7xf32>)
- // CHECK-NEXT: return
- linalg.batch_matmul_transpose_b ins(%arg0, %arg1 : memref<1x3x5xf32>, memref<1x7x5xf32>) outs(%arg2: memref<1x3x7xf32>)
- return
-}
-
-// -----
-
-func.func @nonsingleton_batch_matmul(%arg0 : tensor<2x?x?xf32>, %arg1 : tensor<2x?x?xf32>, %arg2: tensor<2x?x?xf32>) -> tensor<2x?x?xf32> {
- // CHECK-LABEL: @nonsingleton_batch_matmul
- // CHECK-NOT: collapse_shape
- // CHECK: linalg.batch_matmul
- // CHECK-NOT: expand_shape
- %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<2x?x?xf32>, tensor<2x?x?xf32>)
- outs(%arg2 : tensor<2x?x?xf32>) -> tensor<2x?x?xf32>
- return %1 : tensor<2x?x?xf32>
-}
-
-// -----
-
-func.func @nonsingleton_batch_matmul_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
- // CHECK-LABEL: @nonsingleton_batch_matmul_dynamic
- // CHECK-NOT: collapse_shape
- // CHECK: linalg.batch_matmul
- // CHECK-NOT: expand_shape
- %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
- outs(%arg2 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
- return %1 : tensor<?x?x?xf32>
-}
-
diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
index b28f2b3564662..283e426b4e594 100644
--- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_library(MLIRLinalgTestPasses
TestLinalgDropUnitDims.cpp
TestLinalgElementwiseFusion.cpp
TestLinalgFusionTransforms.cpp
+ TestLinalgRankReduceContractionOps.cpp
TestLinalgTransforms.cpp
TestPadFusion.cpp
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp
new file mode 100644
index 0000000000000..5ca27be30a687
--- /dev/null
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp
@@ -0,0 +1,68 @@
+//===- TestLinalgRankReduceContractionOps.cpp - Test Linalg rank reduce
+// contractions ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass for testing rank reducing patterns for named
+// contraction ops with unit dims.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+
+struct TestLinalgRankReduceContractionOps
+ : public PassWrapper<TestLinalgRankReduceContractionOps,
+ OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+ TestLinalgRankReduceContractionOps)
+
+ TestLinalgRankReduceContractionOps() = default;
+ TestLinalgRankReduceContractionOps(
+ const TestLinalgRankReduceContractionOps &pass)
+ : PassWrapper(pass) {}
+  void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<affine::AffineDialect, linalg::LinalgDialect,
+ memref::MemRefDialect, tensor::TensorDialect>();
+ }
+ StringRef getArgument() const final {
+ return "test-linalg-rank-reduce-contraction-ops";
+ }
+ StringRef getDescription() const final {
+ return "Test Linalg rank reduce contraction ops with unit dims";
+ }
+
+ void runOnOperation() override {
+ MLIRContext *context = &this->getContext();
+ func::FuncOp funcOp = this->getOperation();
+
+ RewritePatternSet patterns(context);
+ linalg::populateContractionOpRankReducingPatterns(patterns);
+ if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
+ std::move(patterns))))
+ return signalPassFailure();
+ return;
+ }
+};
+
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestLinalgRankReduceContractionOps() {
+ PassRegistration<TestLinalgRankReduceContractionOps>();
+}
+} // namespace test
+} // namespace mlir
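With the pass registered, the new patterns can be exercised in isolation from
a lit test via the flag defined in getArgument() above, e.g. (illustrative
RUN line, not part of the patch):
  // RUN: mlir-opt %s -test-linalg-rank-reduce-contraction-ops -split-input-file | FileCheck %s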
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 0e8b161d51345..d4ea7a9cae0d2 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -106,6 +106,7 @@ void registerTestLinalgDecomposeOps();
void registerTestLinalgDropUnitDims();
void registerTestLinalgElementwiseFusion();
void registerTestLinalgGreedyFusion();
+void registerTestLinalgRankReduceContractionOps();
void registerTestLinalgTransforms();
void registerTestLivenessAnalysisPass();
void registerTestLivenessPass();
@@ -235,6 +236,7 @@ void registerTestPasses() {
mlir::test::registerTestLinalgDropUnitDims();
mlir::test::registerTestLinalgElementwiseFusion();
mlir::test::registerTestLinalgGreedyFusion();
+ mlir::test::registerTestLinalgRankReduceContractionOps();
mlir::test::registerTestLinalgTransforms();
mlir::test::registerTestLivenessAnalysisPass();
mlir::test::registerTestLivenessPass();
>From 28078405788b799cf64bcf5a7a4059c0eb739875 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Wed, 19 Jun 2024 10:24:39 -0500
Subject: [PATCH 05/14] refactor common logic into abstract base class
---
.../Linalg/Transforms/DropUnitDims.cpp | 230 +++++++++++++-----
1 file changed, 166 insertions(+), 64 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 9248710d5afc9..07b0cdea40c92 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -814,10 +814,66 @@ void mlir::linalg::populateMoveInitOperandsToInputPattern(
namespace {
-template <typename BatchOpTy, typename OpTy>
-struct BatchMatmulToMatmul : OpRewritePattern<BatchOpTy> {
- using OpRewritePattern<BatchOpTy>::OpRewritePattern;
- LogicalResult matchAndRewrite(BatchOpTy batchMatmulOp,
+static SmallVector<ReassociationIndices>
+getReassociationsForTrailingDims(int64_t rank) {
+ SmallVector<ReassociationIndices> reassociation(rank - 1, {});
+ if (rank > 1) {
+ reassociation[rank - 2] =
+ (rank == 1) ? ReassociationIndices{0} : ReassociationIndices{0, 1};
+ for (int64_t i = 0; i < rank - 2; i++)
+ reassociation[i] = {i};
+ }
+ return reassociation;
+}
+
+static SmallVector<ReassociationIndices>
+getReassociationsForLeadingDims(int64_t rank) {
+ SmallVector<ReassociationIndices> reassociation(rank - 1, {});
+ if (rank > 1) {
+ reassociation[0] =
+ (rank == 1) ? ReassociationIndices{0} : ReassociationIndices{0, 1};
+ for (int64_t i = 1; i < rank - 1; i++)
+ reassociation[i] = {i + rank - 2};
+ }
+ return reassociation;
+}
+
+static Value collapseLeadingSingletonDim(PatternRewriter &rewriter, Value val) {
+ auto valType = cast<ShapedType>(val.getType());
+ return collapseValue(
+ rewriter, val.getLoc(), val, valType.getShape().drop_front(1),
+ getReassociationsForLeadingDims(valType.getRank()),
+ ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape);
+}
+
+static Value collapseTrailingSingletonDim(PatternRewriter &rewriter,
+ Value val) {
+ auto valType = cast<ShapedType>(val.getType());
+ return collapseValue(
+ rewriter, val.getLoc(), val, valType.getShape().drop_back(1),
+ getReassociationsForTrailingDims(valType.getRank()),
+ ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape);
+}
+
+static Value expandLeadingSingletonDim(PatternRewriter &rewriter, Value val,
+ RankedTensorType expandedType) {
+ return rewriter.create<tensor::ExpandShapeOp>(
+ val.getLoc(), expandedType, val,
+ getReassociationsForLeadingDims(expandedType.getRank()));
+}
+
+static Value expandTrailingSingletonDim(PatternRewriter &rewriter, Value val,
+ RankedTensorType expandedType) {
+ return rewriter.create<tensor::ExpandShapeOp>(
+ val.getLoc(), expandedType, val,
+ getReassociationsForTrailingDims(expandedType.getRank()));
+}
+
+template <typename FromOpTy, typename ToOpTy>
+struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
+ using OpRewritePattern<FromOpTy>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(FromOpTy batchMatmulOp,
PatternRewriter &rewriter) const override {
auto loc = batchMatmulOp.getLoc();
@@ -830,47 +886,19 @@ struct BatchMatmulToMatmul : OpRewritePattern<BatchOpTy> {
auto rhs = inputs[1];
auto init = inits[0];
- auto lhsType = cast<ShapedType>(lhs.getType());
- auto rhsType = cast<ShapedType>(rhs.getType());
- auto initType = cast<ShapedType>(init.getType());
- if (lhsType.getShape()[0] != 1 || rhsType.getShape()[0] != 1 ||
- initType.getShape()[0] != 1)
- return rewriter.notifyMatchFailure(batchMatmulOp, "batch size is not 1");
-
- auto results = batchMatmulOp.getResults();
- assert(results.size() < 2 && "expected at most one result");
-
- SmallVector<Type, 1> resultType;
- if (results.size() == 1) {
- auto oldResultType = cast<RankedTensorType>(results[0].getType());
- resultType.push_back(
- RankedTensorType::get(oldResultType.getShape().drop_front(1),
- oldResultType.getElementType()));
- }
-
- auto collapseSingletonDim = [&](Value val) -> Value {
- SmallVector<ReassociationIndices> reassociation({{0, 1}});
- auto valType = cast<ShapedType>(val.getType());
- for (auto i = 2; i < valType.getRank(); i++)
- reassociation.push_back({i});
- if (isa<RankedTensorType>(valType)) {
- RankedTensorType collapsedType = RankedTensorType::get(
- valType.getShape().drop_front(1), valType.getElementType());
- return rewriter.create<tensor::CollapseShapeOp>(loc, collapsedType, val,
- reassociation);
- }
- MemRefType collapsedType = MemRefType::get(
- valType.getShape().drop_front(1), valType.getElementType());
- return rewriter.create<memref::CollapseShapeOp>(loc, collapsedType, val,
- reassociation);
- };
-
- auto collapsedLhs = collapseSingletonDim(lhs);
- auto collapsedRhs = collapseSingletonDim(rhs);
- auto collapsedInit = collapseSingletonDim(init);
-
- auto collapsedOp = rewriter.create<OpTy>(
- loc, resultType, ValueRange{collapsedLhs, collapsedRhs},
+ if (!checkTypes(lhs, rhs, init))
+ return rewriter.notifyMatchFailure(batchMatmulOp,
+ "no reducable dims found");
+
+ auto collapsedOperands = collapseOperands(rewriter, lhs, rhs, init);
+ auto collapsedLhs = collapsedOperands[0];
+ auto collapsedRhs = collapsedOperands[1];
+ auto collapsedInit = collapsedOperands[2];
+ SmallVector<Type, 1> collapsedResultTy;
+ if (isa<RankedTensorType>(collapsedInit.getType()))
+ collapsedResultTy.push_back(collapsedInit.getType());
+ auto collapsedOp = rewriter.create<ToOpTy>(
+ loc, collapsedResultTy, ValueRange{collapsedLhs, collapsedRhs},
ValueRange{collapsedInit});
for (auto attr : batchMatmulOp->getAttrs()) {
if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
@@ -878,35 +906,109 @@ struct BatchMatmulToMatmul : OpRewritePattern<BatchOpTy> {
collapsedOp->setAttr(attr.getName(), attr.getValue());
}
- if (results.size() < 1) {
+ auto results = batchMatmulOp.getResults();
+ assert(results.size() < 2 && "expected at most one result");
+ if (results.size() < 1)
rewriter.replaceOp(batchMatmulOp, collapsedOp);
- } else {
- SmallVector<ReassociationIndices> reassociation({{0, 1}});
- auto resultType = cast<ShapedType>(results[0].getType());
- for (auto i = 2; i < resultType.getRank(); i++)
- reassociation.push_back({i});
- Value expandedResult = rewriter.create<tensor::ExpandShapeOp>(
- loc, resultType, collapsedOp.getResultTensors()[0], reassociation);
- rewriter.replaceOp(batchMatmulOp, expandedResult);
- }
+ else
+ rewriter.replaceOp(
+ batchMatmulOp,
+ expandResult(rewriter, collapsedOp.getResultTensors()[0],
+ cast<RankedTensorType>(results[0].getType())));
return success();
}
+
+ virtual bool checkTypes(Value lhs, Value rhs, Value init) const = 0;
+ virtual SmallVector<Value, 3> collapseOperands(PatternRewriter &rewriter,
+ Value lhs, Value rhs,
+ Value init) const = 0;
+ virtual Value expandResult(PatternRewriter &rewriter, Value result,
+ RankedTensorType expandedType) const = 0;
+};
+
+template <typename FromOpTy, typename ToOpTy>
+struct RankReduceBatched : RankReduceContractionOps<FromOpTy, ToOpTy> {
+ using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
+
+ bool checkTypes(Value lhs, Value rhs, Value init) const override {
+ auto lhsType = cast<ShapedType>(lhs.getType());
+ auto rhsType = cast<ShapedType>(rhs.getType());
+ auto initType = cast<ShapedType>(init.getType());
+ return lhsType.getShape()[0] == 1 && rhsType.getShape()[0] == 1 &&
+ initType.getShape()[0] == 1;
+ }
+
+ SmallVector<Value, 3> collapseOperands(PatternRewriter &rewriter, Value lhs,
+ Value rhs, Value init) const override {
+ auto collapsedLhs = collapseLeadingSingletonDim(rewriter, lhs);
+ auto collapsedRhs = collapseLeadingSingletonDim(rewriter, rhs);
+ auto collapsedInit = collapseLeadingSingletonDim(rewriter, init);
+ return SmallVector<Value, 3>{collapsedLhs, collapsedRhs, collapsedInit};
+ }
+ Value expandResult(PatternRewriter &rewriter, Value result,
+ RankedTensorType expandedType) const override {
+ return expandLeadingSingletonDim(rewriter, result, expandedType);
+ }
+};
+
+template <typename FromOpTy, typename ToOpTy>
+struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
+ using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
+
+ static bool constexpr reduceLeading =
+ (std::is_same<FromOpTy, MatmulOp>::value &&
+ std::is_same<ToOpTy, VecmatOp>::value) ||
+ (std::is_same<FromOpTy, MatvecOp>::value &&
+ std::is_same<ToOpTy, DotOp>::value);
+
+ bool checkTypes(Value lhs, Value rhs, Value init) const override {
+ auto lhsType = cast<ShapedType>(lhs.getType());
+ auto rhsType = cast<ShapedType>(rhs.getType());
+ auto initType = cast<ShapedType>(init.getType());
+ if (reduceLeading)
+ return lhsType.getShape()[0] == 1 && initType.getShape()[0] == 1;
+ else
+ return rhsType.getShape().back() == 1 && initType.getShape().back() == 1;
+ }
+
+ SmallVector<Value, 3> collapseOperands(PatternRewriter &rewriter, Value lhs,
+ Value rhs, Value init) const override {
+ if (reduceLeading) {
+ auto collapsedLhs = collapseLeadingSingletonDim(rewriter, lhs);
+ auto collapsedInit = collapseLeadingSingletonDim(rewriter, init);
+ return SmallVector<Value, 3>{collapsedLhs, rhs, collapsedInit};
+ } else {
+ auto collapsedRhs = collapseTrailingSingletonDim(rewriter, rhs);
+ auto collapsedInit = collapseTrailingSingletonDim(rewriter, init);
+ return SmallVector<Value, 3>{lhs, collapsedRhs, collapsedInit};
+ }
+ }
+ Value expandResult(PatternRewriter &rewriter, Value result,
+ RankedTensorType expandedType) const override {
+ if (reduceLeading)
+ return expandLeadingSingletonDim(rewriter, result, expandedType);
+ else
+ return expandTrailingSingletonDim(rewriter, result, expandedType);
+ }
};
+
} // namespace
void mlir::linalg::populateContractionOpRankReducingPatterns(
RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
- patterns.add<BatchMatmulToMatmul<BatchMatmulOp, MatmulOp>>(context);
- patterns
- .add<BatchMatmulToMatmul<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
- context);
- patterns
- .add<BatchMatmulToMatmul<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
- context);
- patterns.add<BatchMatmulToMatmul<BatchMatvecOp, MatvecOp>>(context);
- patterns.add<BatchMatmulToMatmul<BatchVecmatOp, VecmatOp>>(context);
+ patterns.add<RankReduceBatched<BatchMatmulOp, MatmulOp>>(context);
+ patterns.add<RankReduceBatched<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
+ context);
+ patterns.add<RankReduceBatched<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
+ context);
+ patterns.add<RankReduceBatched<BatchMatvecOp, MatvecOp>>(context);
+ patterns.add<RankReduceBatched<BatchVecmatOp, VecmatOp>>(context);
+ patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>>(context);
+ patterns.add<RankReduceMatmul<MatmulOp, MatvecOp>>(context);
+ patterns.add<RankReduceMatmul<MatvecOp, DotOp>>(context);
+ patterns.add<RankReduceMatmul<VecmatOp, DotOp>>(context);
}
namespace {
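For readers who want to exercise these patterns outside the test pass, the populate function above is meant to be driven by a greedy rewrite. A minimal sketch, assuming the usual MLIR includes are in scope (the wrapper name is invented for illustration):

    static void applyRankReducingPatterns(Operation *op) {
      RewritePatternSet patterns(op->getContext());
      linalg::populateContractionOpRankReducingPatterns(patterns);
      // Greedy application, mirroring how the passes in this file drive
      // their own pattern sets.
      (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
    }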
>From 679192b56ac231fce54824d0189a52f39ea2a63b Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Wed, 19 Jun 2024 17:07:46 -0500
Subject: [PATCH 06/14] add regression test
---
.../Linalg/rank-reduce-contraction-ops.mlir | 197 ++++++++++++++++++
1 file changed, 197 insertions(+)
create mode 100644 mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
diff --git a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
new file mode 100644
index 0000000000000..279a1d52ae72b
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
@@ -0,0 +1,197 @@
+// RUN: mlir-opt -test-linalg-rank-reduce-contraction-ops --canonicalize -split-input-file %s | FileCheck %s
+
+func.func @singleton_batch_matmul_tensor(%arg0 : tensor<1x?x?xf32>, %arg1 : tensor<1x?x?xf32>, %arg2: tensor<1x?x?xf32>) -> tensor<1x?x?xf32> {
+ // CHECK-LABEL: @singleton_batch_matmul_tensor
+ // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
+ // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
+ // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1
+ // CHECK-DAG: %[[C2:.*]] = arith.constant 2
+ // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+ // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+ // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
+ // CHECK-NEXT: %[[MATMUL:.+]] = linalg.matmul ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?x?xf32>)
+ // CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
+ // CHECK-NEXT: %[[DIM2:.*]] = tensor.dim %[[INIT]], %[[C2]]
+ // CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1], [2]] output_shape [1, %[[DIM1]], %[[DIM2]]]
+ // CHECK-NEXT: return %[[RES]]
+ %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x?x?xf32>, tensor<1x?x?xf32>)
+ outs(%arg2 : tensor<1x?x?xf32>) -> tensor<1x?x?xf32>
+ return %1 : tensor<1x?x?xf32>
+}
+
+// -----
+
+func.func @singleton_batch_matmul_memref(%arg0 : memref<1x?x?xf32>, %arg1 : memref<1x?x?xf32>, %arg2: memref<1x?x?xf32>) {
+ // CHECK-LABEL: @singleton_batch_matmul_memref
+ // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: memref<1x?x?xf32>
+ // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: memref<1x?x?xf32>
+ // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: memref<1x?x?xf32>
+ // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+ // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = memref.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+ // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
+ // CHECK-NEXT: linalg.matmul ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : memref<?x?xf32>, memref<?x?xf32>) outs(%[[COLLAPSED_INIT]] : memref<?x?xf32>)
+ // CHECK-NEXT: return
+ linalg.batch_matmul ins(%arg0, %arg1 : memref<1x?x?xf32>, memref<1x?x?xf32>)
+ outs(%arg2 : memref<1x?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @singleton_batch_matvec(%arg0 : tensor<1x?x?xf32>, %arg1 : tensor<1x?xf32>, %arg2: tensor<1x?xf32>) -> tensor<1x?xf32> {
+ // CHECK-LABEL: @singleton_batch_matvec
+ // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
+ // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: tensor<1x?xf32>
+ // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?xf32>
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1
+ // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+ // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1]]
+ // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
+ // CHECK-NEXT: %[[MATMUL:.+]] = linalg.matvec ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
+ // CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
+ // CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]]
+ // CHECK-NEXT: return %[[RES]]
+ %1 = linalg.batch_matvec ins(%arg0, %arg1 : tensor<1x?x?xf32>, tensor<1x?xf32>)
+ outs(%arg2 : tensor<1x?xf32>) -> tensor<1x?xf32>
+ return %1 : tensor<1x?xf32>
+}
+
+// -----
+
+func.func @singleton_batch_vecmat(%arg0 : tensor<1x?xf32>, %arg1 : tensor<1x?x?xf32>, %arg2: tensor<1x?xf32>) -> tensor<1x?xf32> {
+ // CHECK-LABEL: @singleton_batch_vecmat
+ // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?xf32>
+ // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
+ // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?xf32>
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1
+ // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1]]
+ // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+ // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
+ // CHECK-NEXT: %[[MATMUL:.+]] = linalg.vecmat ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
+ // CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
+ // CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]]
+ // CHECK-NEXT: return %[[RES]]
+ %1 = linalg.batch_vecmat ins(%arg0, %arg1 : tensor<1x?xf32>, tensor<1x?x?xf32>)
+ outs(%arg2 : tensor<1x?xf32>) -> tensor<1x?xf32>
+ return %1 : tensor<1x?xf32>
+}
+
+// -----
+
+func.func @singleton_batchmatmul_transpose_a(%arg0: memref<1x5x3xf32>, %arg1: memref<1x5x7xf32>, %arg2: memref<1x3x7xf32>) {
+ // CHECK-LABEL: @singleton_batchmatmul_transpose_a
+ // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: memref<1x5x3xf32>
+ // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: memref<1x5x7xf32>
+ // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: memref<1x3x7xf32>
+ // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+ // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = memref.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+ // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
+ // CHECK-NEXT: linalg.matmul_transpose_a ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : memref<5x3xf32>, memref<5x7xf32>) outs(%[[COLLAPSED_INIT]] : memref<3x7xf32>)
+ // CHECK-NEXT: return
+ linalg.batch_matmul_transpose_a ins(%arg0, %arg1 : memref<1x5x3xf32>, memref<1x5x7xf32>) outs(%arg2: memref<1x3x7xf32>)
+ return
+}
+
+// -----
+
+func.func @singleton_batchmatmul_transpose_b(%arg0: memref<1x3x5xf32>, %arg1: memref<1x7x5xf32>, %arg2: memref<1x3x7xf32>) {
+ // CHECK-LABEL: @singleton_batchmatmul_transpose_b
+ // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: memref<1x3x5xf32>
+ // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: memref<1x7x5xf32>
+ // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: memref<1x3x7xf32>
+ // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+ // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = memref.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+ // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
+ // CHECK-NEXT: linalg.matmul_transpose_b ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : memref<3x5xf32>, memref<7x5xf32>) outs(%[[COLLAPSED_INIT]] : memref<3x7xf32>)
+ // CHECK-NEXT: return
+ linalg.batch_matmul_transpose_b ins(%arg0, %arg1 : memref<1x3x5xf32>, memref<1x7x5xf32>) outs(%arg2: memref<1x3x7xf32>)
+ return
+}
+
+// -----
+
+func.func @matmul_to_vecmat(%arg0: memref<1x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<1x?xf32>) {
+ // CHECK-LABEL: @matmul_to_vecmat
+ // CHECK: linalg.vecmat
+ linalg.matmul ins(%arg0, %arg1: memref<1x?xf32>, memref<?x?xf32>) outs(%arg2: memref<1x?xf32>)
+ return
+}
+
+// -----
+
+func.func @batch_matmul_to_vecmat(%arg0: memref<1x1x?xf32>, %arg1: memref<1x?x?xf32>, %arg2: memref<1x1x?xf32>) {
+ // CHECK-LABEL: @batch_matmul_to_vecmat
+ // CHECK: linalg.vecmat
+ linalg.batch_matmul ins(%arg0, %arg1: memref<1x1x?xf32>, memref<1x?x?xf32>) outs(%arg2: memref<1x1x?xf32>)
+ return
+}
+
+// -----
+
+func.func @matvec_to_dot(%arg0: memref<1x?xf32>, %arg1: memref<?xf32>, %arg2: memref<1xf32>) {
+ // CHECK-LABEL: @matvec_to_dot
+ // CHECK: linalg.dot
+ linalg.matvec ins(%arg0, %arg1: memref<1x?xf32>, memref<?xf32>) outs(%arg2: memref<1xf32>)
+ return
+}
+
+// -----
+
+func.func @vecmat_to_dot(%arg0: memref<?xf32>, %arg1: memref<?x1xf32>, %arg2: memref<1xf32>) {
+ // CHECK-LABEL: @vecmat_to_dot
+ // CHECK: linalg.dot
+ linalg.vecmat ins(%arg0, %arg1: memref<?xf32>, memref<?x1xf32>) outs(%arg2: memref<1xf32>)
+ return
+}
+
+// -----
+
+func.func @matvec_to_dot_tensor(%arg0: tensor<1x?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<1xf32>) -> tensor<1xf32> {
+ // CHECK-LABEL: @matvec_to_dot_tensor
+ // CHECK: linalg.dot
+ %0 = linalg.matvec ins(%arg0, %arg1: tensor<1x?xf32>, tensor<?xf32>) outs(%arg2: tensor<1xf32>) -> tensor<1xf32>
+ return %0 : tensor<1xf32>
+}
+
+// -----
+
+func.func @matmul_to_matvec_tensor(%arg0: tensor<?x?xf32>, %arg1: tensor<?x1xf32>, %arg2: tensor<?x1xf32>) -> tensor<?x1xf32> {
+ // CHECK-LABEL: @matmul_to_matvec_tensor
+ // CHECK: linalg.matvec
+ %0 = linalg.matmul ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x1xf32>) outs(%arg2: tensor<?x1xf32>) -> tensor<?x1xf32>
+ return %0 : tensor<?x1xf32>
+}
+
+// -----
+
+func.func @matmul_to_matvec(%arg0: memref<?x?xf32>, %arg1: memref<?x1xf32>, %arg2: memref<?x1xf32>) {
+ // CHECK-LABEL: @matmul_to_matvec
+ // CHECK: linalg.matvec
+ linalg.matmul ins(%arg0, %arg1: memref<?x?xf32>, memref<?x1xf32>) outs(%arg2: memref<?x1xf32>)
+ return
+}
+
+// -----
+
+func.func @nonsingleton_batch_matmul(%arg0 : tensor<2x?x?xf32>, %arg1 : tensor<2x?x?xf32>, %arg2: tensor<2x?x?xf32>) -> tensor<2x?x?xf32> {
+ // CHECK-LABEL: @nonsingleton_batch_matmul
+ // CHECK-NOT: collapse_shape
+ // CHECK: linalg.batch_matmul
+ // CHECK-NOT: expand_shape
+ %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<2x?x?xf32>, tensor<2x?x?xf32>)
+ outs(%arg2 : tensor<2x?x?xf32>) -> tensor<2x?x?xf32>
+ return %1 : tensor<2x?x?xf32>
+}
+
+// -----
+
+func.func @nonsingleton_batch_matmul_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+ // CHECK-LABEL: @nonsingleton_batch_matmul_dynamic
+ // CHECK-NOT: collapse_shape
+ // CHECK: linalg.batch_matmul
+ // CHECK-NOT: expand_shape
+ %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+ outs(%arg2 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ return %1 : tensor<?x?x?xf32>
+}
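A note on the reassociation notation in the CHECK lines above: `[[0, 1], [2]]` means dims 0 and 1 of the source collapse into a single result dim while dim 2 is kept. For the rank-3 batch case this is exactly what the helper used by the patterns produces; evaluated by hand (illustrative, not an official test):

    // Folds the unit batch dim into its neighbor; dim 2 survives unchanged.
    SmallVector<ReassociationIndices> reassoc =
        getReassociationsForLeadingDims(/*rank=*/3);
    // reassoc == {{0, 1}, {2}}, matching the collapse_shape checks above.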
>From ce02e9d206cf718927854338636b6f357c62101c Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Wed, 19 Jun 2024 18:13:39 -0500
Subject: [PATCH 07/14] flesh out some tests
---
.../Linalg/rank-reduce-contraction-ops.mlir | 69 ++++++++++++-------
1 file changed, 46 insertions(+), 23 deletions(-)
diff --git a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
index 279a1d52ae72b..79003670d2726 100644
--- a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
+++ b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
@@ -111,15 +111,51 @@ func.func @singleton_batchmatmul_transpose_b(%arg0: memref<1x3x5xf32>, %arg1: me
// -----
-func.func @matmul_to_vecmat(%arg0: memref<1x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<1x?xf32>) {
- // CHECK-LABEL: @matmul_to_vecmat
- // CHECK: linalg.vecmat
- linalg.matmul ins(%arg0, %arg1: memref<1x?xf32>, memref<?x?xf32>) outs(%arg2: memref<1x?xf32>)
+func.func @matmul_to_matvec_tensor(%arg0: tensor<?x?xf32>, %arg1: tensor<?x1xf32>, %arg2: tensor<?x1xf32>) -> tensor<?x1xf32> {
+ // CHECK-LABEL: @matmul_to_matvec_tensor
+ // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+ // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: tensor<?x1xf32>
+ // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<?x1xf32>
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0
+ // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1]]
+ // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
+ // CHECK-NEXT: %[[MATMUL:.+]] = linalg.matvec ins(%[[LHS]], %[[COLLAPSED_RHS]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
+ // CHECK-NEXT: %[[DIM0:.*]] = tensor.dim %[[INIT]], %[[C0]]
+ // CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [%[[DIM0]], 1]
+ // CHECK-NEXT: return %[[RES]]
+ %0 = linalg.matmul ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x1xf32>) outs(%arg2: tensor<?x1xf32>) -> tensor<?x1xf32>
+ return %0 : tensor<?x1xf32>
+}
+
+// -----
+
+func.func @matmul_to_matvec(%arg0: memref<?x?xf32>, %arg1: memref<?x1xf32>, %arg2: memref<?x1xf32>) {
+ // CHECK-LABEL: @matmul_to_matvec
+ // CHECK: linalg.matvec
+ linalg.matmul ins(%arg0, %arg1: memref<?x?xf32>, memref<?x1xf32>) outs(%arg2: memref<?x1xf32>)
return
}
// -----
+func.func @matmul_to_vecmat_tensor(%arg0: tensor<1x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<1x?xf32>) -> tensor<1x?xf32> {
+ // CHECK-LABEL: @matmul_to_vecmat
+ // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?xf32>
+ // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+ // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?xf32>
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1
+ // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1]]
+ // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
+ // CHECK-NEXT: %[[RESULT:.*]] = linalg.vecmat ins(%[[COLLAPSED_LHS]], %[[RHS]] : tensor<?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
+ // CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
+ // CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]]
+ // CHECK-NEXT: return %[[RES]]
+ %0 = linalg.matmul ins(%arg0, %arg1: tensor<1x?xf32>, tensor<?x?xf32>) outs(%arg2: tensor<1x?xf32>) -> tensor<1x?xf32>
+ return %0 : tensor<1x?xf32>
+}
+
+// -----
+
func.func @batch_matmul_to_vecmat(%arg0: memref<1x1x?xf32>, %arg1: memref<1x?x?xf32>, %arg2: memref<1x1x?xf32>) {
// CHECK-LABEL: @batch_matmul_to_vecmat
// CHECK: linalg.vecmat
@@ -131,7 +167,12 @@ func.func @batch_matmul_to_vecmat(%arg0: memref<1x1x?xf32>, %arg1: memref<1x?x?x
func.func @matvec_to_dot(%arg0: memref<1x?xf32>, %arg1: memref<?xf32>, %arg2: memref<1xf32>) {
// CHECK-LABEL: @matvec_to_dot
- // CHECK: linalg.dot
+ // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: memref<1x?xf32>
+ // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: memref<?xf32>
+ // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: memref<1xf32>
+ // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1]]
+ // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] []
+ // CHECK-NEXT: linalg.dot ins(%[[COLLAPSED_LHS]], %[[RHS]] : memref<?xf32>, memref<?xf32>) outs(%[[COLLAPSED_INIT]] : memref<f32>)
linalg.matvec ins(%arg0, %arg1: memref<1x?xf32>, memref<?xf32>) outs(%arg2: memref<1xf32>)
return
}
@@ -156,24 +197,6 @@ func.func @matvec_to_dot_tensor(%arg0: tensor<1x?xf32>, %arg1: tensor<?xf32>, %a
// -----
-func.func @matmul_to_matvec_tensor(%arg0: tensor<?x?xf32>, %arg1: tensor<?x1xf32>, %arg2: tensor<?x1xf32>) -> tensor<?x1xf32> {
- // CHECK-LABEL: @matmul_to_matvec_tensor
- // CHECK: linalg.matvec
- %0 = linalg.matmul ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x1xf32>) outs(%arg2: tensor<?x1xf32>) -> tensor<?x1xf32>
- return %0 : tensor<?x1xf32>
-}
-
-// -----
-
-func.func @matmul_to_matvec(%arg0: memref<?x?xf32>, %arg1: memref<?x1xf32>, %arg2: memref<?x1xf32>) {
- // CHECK-LABEL: @matmul_to_matvec
- // CHECK: linalg.matvec
- linalg.matmul ins(%arg0, %arg1: memref<?x?xf32>, memref<?x1xf32>) outs(%arg2: memref<?x1xf32>)
- return
-}
-
-// -----
-
func.func @nonsingleton_batch_matmul(%arg0 : tensor<2x?x?xf32>, %arg1 : tensor<2x?x?xf32>, %arg2: tensor<2x?x?xf32>) -> tensor<2x?x?xf32> {
// CHECK-LABEL: @nonsingleton_batch_matmul
// CHECK-NOT: collapse_shape
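One detail behind the updated checks: `tensor.expand_shape` carries explicit `output_shape` operands, so when the init has dynamic extents the rewrite materializes `tensor.dim` ops on the original init right before the expand, which is what the `%[[DIM0]]`/`%[[DIM1]]` captures above match. The extents are presumably derived by the ExpandShapeOp builder invoked by the patterns, e.g. (verbatim from this series):

    rewriter.create<tensor::ExpandShapeOp>(
        result.getLoc(), expandedType, result,
        getReassociationsForLeadingDims(expandedType.getRank()));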
>From 9b98efdf911f806f3ae15f2f9e74d99af1d8775e Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Wed, 19 Jun 2024 20:13:19 -0500
Subject: [PATCH 08/14] support transpose matmul conversion
---
.../Linalg/Transforms/DropUnitDims.cpp | 119 ++++++++++--------
.../Linalg/rank-reduce-contraction-ops.mlir | 18 +++
2 files changed, 84 insertions(+), 53 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 07b0cdea40c92..d9230c6127e00 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -812,6 +812,30 @@ void mlir::linalg::populateMoveInitOperandsToInputPattern(
patterns.add<MoveInitOperandsToInput>(patterns.getContext());
}
+namespace {
+/// Pass that removes unit-extent dims within generic ops.
+struct LinalgFoldUnitExtentDimsPass
+ : public impl::LinalgFoldUnitExtentDimsPassBase<
+ LinalgFoldUnitExtentDimsPass> {
+ using impl::LinalgFoldUnitExtentDimsPassBase<
+ LinalgFoldUnitExtentDimsPass>::LinalgFoldUnitExtentDimsPassBase;
+ void runOnOperation() override {
+ Operation *op = getOperation();
+ MLIRContext *context = op->getContext();
+ RewritePatternSet patterns(context);
+ ControlDropUnitDims options;
+ if (useRankReducingSlices) {
+ options.rankReductionStrategy = linalg::ControlDropUnitDims::
+ RankReductionStrategy::ExtractInsertSlice;
+ }
+ linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
+ populateMoveInitOperandsToInputPattern(patterns);
+ (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
+ }
+};
+
+} // namespace
+
namespace {
static SmallVector<ReassociationIndices>
@@ -855,20 +879,6 @@ static Value collapseTrailingSingletonDim(PatternRewriter &rewriter,
ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape);
}
-static Value expandLeadingSingletonDim(PatternRewriter &rewriter, Value val,
- RankedTensorType expandedType) {
- return rewriter.create<tensor::ExpandShapeOp>(
- val.getLoc(), expandedType, val,
- getReassociationsForLeadingDims(expandedType.getRank()));
-}
-
-static Value expandTrailingSingletonDim(PatternRewriter &rewriter, Value val,
- RankedTensorType expandedType) {
- return rewriter.create<tensor::ExpandShapeOp>(
- val.getLoc(), expandedType, val,
- getReassociationsForTrailingDims(expandedType.getRank()));
-}
-
template <typename FromOpTy, typename ToOpTy>
struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
using OpRewritePattern<FromOpTy>::OpRewritePattern;
@@ -948,7 +958,9 @@ struct RankReduceBatched : RankReduceContractionOps<FromOpTy, ToOpTy> {
}
Value expandResult(PatternRewriter &rewriter, Value result,
RankedTensorType expandedType) const override {
- return expandLeadingSingletonDim(rewriter, result, expandedType);
+ return rewriter.create<tensor::ExpandShapeOp>(
+ result.getLoc(), expandedType, result,
+ getReassociationsForLeadingDims(expandedType.getRank()));
}
};
@@ -956,9 +968,15 @@ template <typename FromOpTy, typename ToOpTy>
struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
- static bool constexpr reduceLeading =
+ static bool constexpr isTranspose =
+ std::is_same<FromOpTy, MatmulTransposeAOp>::value ||
+ std::is_same<FromOpTy, MatmulTransposeBOp>::value;
+
+ static bool constexpr reduceLeft =
(std::is_same<FromOpTy, MatmulOp>::value &&
std::is_same<ToOpTy, VecmatOp>::value) ||
+ (std::is_same<FromOpTy, MatmulTransposeAOp>::value &&
+ std::is_same<ToOpTy, VecmatOp>::value) ||
(std::is_same<FromOpTy, MatvecOp>::value &&
std::is_same<ToOpTy, DotOp>::value);
@@ -966,30 +984,47 @@ struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
auto lhsType = cast<ShapedType>(lhs.getType());
auto rhsType = cast<ShapedType>(rhs.getType());
auto initType = cast<ShapedType>(init.getType());
- if (reduceLeading)
- return lhsType.getShape()[0] == 1 && initType.getShape()[0] == 1;
+ int constexpr offset = (int)isTranspose;
+ if (reduceLeft)
+ return lhsType.getShape().begin()[offset] == 1 &&
+ initType.getShape().begin()[offset] == 1;
else
- return rhsType.getShape().back() == 1 && initType.getShape().back() == 1;
+ return rhsType.getShape().rbegin()[offset] == 1 &&
+ initType.getShape().rbegin()[offset] == 1;
}
SmallVector<Value, 3> collapseOperands(PatternRewriter &rewriter, Value lhs,
Value rhs, Value init) const override {
- if (reduceLeading) {
- auto collapsedLhs = collapseLeadingSingletonDim(rewriter, lhs);
- auto collapsedInit = collapseLeadingSingletonDim(rewriter, init);
- return SmallVector<Value, 3>{collapsedLhs, rhs, collapsedInit};
+ if (reduceLeft) {
+ if (isTranspose) {
+ lhs = collapseTrailingSingletonDim(rewriter, lhs);
+ init = collapseTrailingSingletonDim(rewriter, init);
+ } else {
+ lhs = collapseLeadingSingletonDim(rewriter, lhs);
+ init = collapseLeadingSingletonDim(rewriter, init);
+ }
} else {
- auto collapsedRhs = collapseTrailingSingletonDim(rewriter, rhs);
- auto collapsedInit = collapseTrailingSingletonDim(rewriter, init);
- return SmallVector<Value, 3>{lhs, collapsedRhs, collapsedInit};
+ if (isTranspose) {
+ rhs = collapseLeadingSingletonDim(rewriter, rhs);
+ init = collapseLeadingSingletonDim(rewriter, init);
+ } else {
+ rhs = collapseTrailingSingletonDim(rewriter, rhs);
+ init = collapseTrailingSingletonDim(rewriter, init);
+ }
}
+ return SmallVector<Value, 3>{lhs, rhs, init};
}
+
Value expandResult(PatternRewriter &rewriter, Value result,
RankedTensorType expandedType) const override {
- if (reduceLeading)
- return expandLeadingSingletonDim(rewriter, result, expandedType);
+ if (reduceLeft)
+ return rewriter.create<tensor::ExpandShapeOp>(
+ result.getLoc(), expandedType, result,
+ getReassociationsForLeadingDims(expandedType.getRank()));
else
- return expandTrailingSingletonDim(rewriter, result, expandedType);
+ return rewriter.create<tensor::ExpandShapeOp>(
+ result.getLoc(), expandedType, result,
+ getReassociationsForTrailingDims(expandedType.getRank()));
}
};
@@ -1007,30 +1042,8 @@ void mlir::linalg::populateContractionOpRankReducingPatterns(
patterns.add<RankReduceBatched<BatchVecmatOp, VecmatOp>>(context);
patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>>(context);
patterns.add<RankReduceMatmul<MatmulOp, MatvecOp>>(context);
+ patterns.add<RankReduceMatmul<MatmulTransposeAOp, VecmatOp>>(context);
+ patterns.add<RankReduceMatmul<MatmulTransposeBOp, MatvecOp>>(context);
patterns.add<RankReduceMatmul<MatvecOp, DotOp>>(context);
patterns.add<RankReduceMatmul<VecmatOp, DotOp>>(context);
}
-
-namespace {
-/// Pass that removes unit-extent dims within generic ops.
-struct LinalgFoldUnitExtentDimsPass
- : public impl::LinalgFoldUnitExtentDimsPassBase<
- LinalgFoldUnitExtentDimsPass> {
- using impl::LinalgFoldUnitExtentDimsPassBase<
- LinalgFoldUnitExtentDimsPass>::LinalgFoldUnitExtentDimsPassBase;
- void runOnOperation() override {
- Operation *op = getOperation();
- MLIRContext *context = op->getContext();
- RewritePatternSet patterns(context);
- ControlDropUnitDims options;
- if (useRankReducingSlices) {
- options.rankReductionStrategy = linalg::ControlDropUnitDims::
- RankReductionStrategy::ExtractInsertSlice;
- }
- linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
- populateMoveInitOperandsToInputPattern(patterns);
- (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
- }
-};
-
-} // namespace
diff --git a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
index 79003670d2726..0548f3f860a89 100644
--- a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
+++ b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
@@ -197,6 +197,24 @@ func.func @matvec_to_dot_tensor(%arg0: tensor<1x?xf32>, %arg1: tensor<?xf32>, %a
// -----
+func.func @matmul_transpose_a_to_vecmat(%arg0: memref<?x1xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x1xf32>) {
+ // CHECK-LABEL: @matmul_transpose_a_to_vecmat
+ // CHECK: linalg.vecmat
+ linalg.matmul_transpose_a ins(%arg0, %arg1: memref<?x1xf32>, memref<?x?xf32>) outs(%arg2: memref<?x1xf32>)
+ return
+}
+
+// -----
+
+func.func @matmul_transpose_b_to_matvec(%arg0: memref<?x?xf32>, %arg1: memref<1x?xf32>, %arg2: memref<1x?xf32>) {
+ // CHECK-LABEL: @matmul_transpose_b_to_matvec
+ // CHECK: linalg.matvec
+ linalg.matmul_transpose_b ins(%arg0, %arg1: memref<?x?xf32>, memref<1x?xf32>) outs(%arg2: memref<1x?xf32>)
+ return
+}
+
+// -----
+
func.func @nonsingleton_batch_matmul(%arg0 : tensor<2x?x?xf32>, %arg1 : tensor<2x?x?xf32>, %arg2: tensor<2x?x?xf32>) -> tensor<2x?x?xf32> {
// CHECK-LABEL: @nonsingleton_batch_matmul
// CHECK-NOT: collapse_shape
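The `reduceLeft` trait encodes, per (FromOpTy, ToOpTy) pair, whether the unit dimension being dropped sits on the M side (left) or the N side (right) of the contraction. A standalone sketch of the same compile-time dispatch, abridged to the non-transposed pairs (the tag types are local to this example):

    #include <type_traits>

    struct MatmulOp; struct MatvecOp; struct VecmatOp; struct DotOp;

    template <typename FromOpTy, typename ToOpTy>
    constexpr bool reduceLeft =
        (std::is_same<FromOpTy, MatmulOp>::value &&
         std::is_same<ToOpTy, VecmatOp>::value) ||
        (std::is_same<FromOpTy, MatvecOp>::value &&
         std::is_same<ToOpTy, DotOp>::value);

    static_assert(reduceLeft<MatmulOp, VecmatOp>, "unit M collapses");
    static_assert(!reduceLeft<MatmulOp, MatvecOp>, "unit N collapses");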
>From cbf8eddc377b7970cd12846a8d95ea740b5e3bf1 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Wed, 19 Jun 2024 20:28:48 -0500
Subject: [PATCH 09/14] const conditionals
---
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index d9230c6127e00..327c46b965c87 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -985,7 +985,7 @@ struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
auto rhsType = cast<ShapedType>(rhs.getType());
auto initType = cast<ShapedType>(init.getType());
int constexpr offset = (int)isTranspose;
- if (reduceLeft)
+ if constexpr (reduceLeft)
return lhsType.getShape().begin()[offset] == 1 &&
initType.getShape().begin()[offset] == 1;
else
@@ -995,8 +995,8 @@ struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
SmallVector<Value, 3> collapseOperands(PatternRewriter &rewriter, Value lhs,
Value rhs, Value init) const override {
- if (reduceLeft) {
- if (isTranspose) {
+ if constexpr (reduceLeft) {
+ if constexpr (isTranspose) {
lhs = collapseTrailingSingletonDim(rewriter, lhs);
init = collapseTrailingSingletonDim(rewriter, init);
} else {
@@ -1004,7 +1004,7 @@ struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
init = collapseLeadingSingletonDim(rewriter, init);
}
} else {
- if (isTranspose) {
+ if constexpr (isTranspose) {
rhs = collapseLeadingSingletonDim(rewriter, rhs);
init = collapseLeadingSingletonDim(rewriter, init);
} else {
@@ -1017,7 +1017,7 @@ struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
Value expandResult(PatternRewriter &rewriter, Value result,
RankedTensorType expandedType) const override {
- if (reduceLeft)
+ if constexpr (reduceLeft)
return rewriter.create<tensor::ExpandShapeOp>(
result.getLoc(), expandedType, result,
getReassociationsForLeadingDims(expandedType.getRank()));
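Since `reduceLeft` and `isTranspose` are per-instantiation constants, `if constexpr` both documents that the branch is chosen at compile time and guarantees the dead branch is never instantiated. A minimal standalone illustration of the idiom (not from the patch):

    template <bool ReduceLeft>
    int selectDim(int leftDim, int rightDim) {
      if constexpr (ReduceLeft)
        return leftDim;   // the other branch is discarded at compile time
      else
        return rightDim;
    }

    int m = selectDim<true>(0, 1);  // m == 0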
>From 29ac128442379767bca2583409f58bd605bb06d1 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 20 Jun 2024 14:19:38 -0500
Subject: [PATCH 10/14] refactor and add more patterns
---
.../Linalg/Transforms/DropUnitDims.cpp | 143 +++++++++++-------
1 file changed, 85 insertions(+), 58 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 327c46b965c87..26a85e8225b24 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -839,43 +839,31 @@ struct LinalgFoldUnitExtentDimsPass
namespace {
static SmallVector<ReassociationIndices>
-getReassociationsForTrailingDims(int64_t rank) {
- SmallVector<ReassociationIndices> reassociation(rank - 1, {});
- if (rank > 1) {
- reassociation[rank - 2] =
- (rank == 1) ? ReassociationIndices{0} : ReassociationIndices{0, 1};
- for (int64_t i = 0; i < rank - 2; i++)
- reassociation[i] = {i};
- }
- return reassociation;
-}
-
-static SmallVector<ReassociationIndices>
-getReassociationsForLeadingDims(int64_t rank) {
- SmallVector<ReassociationIndices> reassociation(rank - 1, {});
- if (rank > 1) {
- reassociation[0] =
- (rank == 1) ? ReassociationIndices{0} : ReassociationIndices{0, 1};
- for (int64_t i = 1; i < rank - 1; i++)
- reassociation[i] = {i + rank - 2};
+getReassociationForReshapeAtDim(int64_t rank, int64_t pos,
+ bool fromRight = false) {
+ SmallVector<ReassociationIndices> reassociation(rank - 1, {0, 1});
+ if (rank > 2) {
+ int64_t offsetPos = pos - (int64_t)fromRight;
+ for (int64_t i = 0; i < rank - 1; i++) {
+ if (i == offsetPos)
+ reassociation[i] = ReassociationIndices{i, i + 1};
+ else if (i < offsetPos)
+ reassociation[i] = ReassociationIndices{i};
+ else
+ reassociation[i] = ReassociationIndices{i + offsetPos + 1};
+ }
}
return reassociation;
}
-static Value collapseLeadingSingletonDim(PatternRewriter &rewriter, Value val) {
+static Value collapseSingletonDimAt(PatternRewriter &rewriter, Value val,
+ int64_t pos, bool fromRight = false) {
auto valType = cast<ShapedType>(val.getType());
+ SmallVector<int64_t> collapsedShape(valType.getShape());
+ collapsedShape.erase(collapsedShape.begin() + pos);
return collapseValue(
- rewriter, val.getLoc(), val, valType.getShape().drop_front(1),
- getReassociationsForLeadingDims(valType.getRank()),
- ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape);
-}
-
-static Value collapseTrailingSingletonDim(PatternRewriter &rewriter,
- Value val) {
- auto valType = cast<ShapedType>(val.getType());
- return collapseValue(
- rewriter, val.getLoc(), val, valType.getShape().drop_back(1),
- getReassociationsForTrailingDims(valType.getRank()),
+ rewriter, val.getLoc(), val, collapsedShape,
+ getReassociationForReshapeAtDim(valType.getRank(), pos, fromRight),
ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape);
}
@@ -951,16 +939,16 @@ struct RankReduceBatched : RankReduceContractionOps<FromOpTy, ToOpTy> {
SmallVector<Value, 3> collapseOperands(PatternRewriter &rewriter, Value lhs,
Value rhs, Value init) const override {
- auto collapsedLhs = collapseLeadingSingletonDim(rewriter, lhs);
- auto collapsedRhs = collapseLeadingSingletonDim(rewriter, rhs);
- auto collapsedInit = collapseLeadingSingletonDim(rewriter, init);
+ auto collapsedLhs = collapseSingletonDimAt(rewriter, lhs, 0);
+ auto collapsedRhs = collapseSingletonDimAt(rewriter, rhs, 0);
+ auto collapsedInit = collapseSingletonDimAt(rewriter, init, 0);
return SmallVector<Value, 3>{collapsedLhs, collapsedRhs, collapsedInit};
}
Value expandResult(PatternRewriter &rewriter, Value result,
RankedTensorType expandedType) const override {
- return rewriter.create<tensor::ExpandShapeOp>(
- result.getLoc(), expandedType, result,
- getReassociationsForLeadingDims(expandedType.getRank()));
+ return rewriter.create<tensor::ExpandShapeOp>(
+ result.getLoc(), expandedType, result,
+ getReassociationForReshapeAtDim(expandedType.getRank(), 0));
}
};
@@ -968,11 +956,32 @@ template <typename FromOpTy, typename ToOpTy>
struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
+ static bool constexpr isBatched =
+ std::is_same<FromOpTy, BatchMatmulOp>::value ||
+ std::is_same<FromOpTy, BatchMatvecOp>::value ||
+ std::is_same<FromOpTy, BatchVecmatOp>::value ||
+ std::is_same<FromOpTy, BatchMatmulTransposeAOp>::value ||
+ std::is_same<FromOpTy, BatchMatmulTransposeBOp>::value;
+
+ static bool constexpr isLHSTransposed =
+ std::is_same<FromOpTy, BatchMatmulTransposeAOp>::value ||
+ std::is_same<FromOpTy, MatmulTransposeAOp>::value;
+
+ static bool constexpr isRHSTransposed =
+ std::is_same<FromOpTy, BatchMatmulTransposeBOp>::value ||
+ std::is_same<FromOpTy, MatmulTransposeBOp>::value;
+
static bool constexpr isTranspose =
+ std::is_same<FromOpTy, BatchMatmulTransposeAOp>::value ||
+ std::is_same<FromOpTy, BatchMatmulTransposeBOp>::value ||
std::is_same<FromOpTy, MatmulTransposeAOp>::value ||
std::is_same<FromOpTy, MatmulTransposeBOp>::value;
static bool constexpr reduceLeft =
+ (std::is_same<FromOpTy, BatchMatmulOp>::value &&
+ std::is_same<ToOpTy, BatchVecmatOp>::value) ||
+ (std::is_same<FromOpTy, BatchMatmulTransposeAOp>::value &&
+ std::is_same<ToOpTy, BatchVecmatOp>::value) ||
(std::is_same<FromOpTy, MatmulOp>::value &&
std::is_same<ToOpTy, VecmatOp>::value) ||
(std::is_same<FromOpTy, MatmulTransposeAOp>::value &&
@@ -980,37 +989,41 @@ struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
(std::is_same<FromOpTy, MatvecOp>::value &&
std::is_same<ToOpTy, DotOp>::value);
+ static int constexpr lhsTransposeOffset = (int)isLHSTransposed;
+ static int constexpr rhsTransposeOffset = (int)isRHSTransposed;
+ static int constexpr batchOffset = (int)isBatched;
+
bool checkTypes(Value lhs, Value rhs, Value init) const override {
auto lhsType = cast<ShapedType>(lhs.getType());
auto rhsType = cast<ShapedType>(rhs.getType());
auto initType = cast<ShapedType>(init.getType());
- int constexpr offset = (int)isTranspose;
if constexpr (reduceLeft)
- return lhsType.getShape().begin()[offset] == 1 &&
- initType.getShape().begin()[offset] == 1;
+ return lhsType.getShape().begin()[lhsTransposeOffset + batchOffset] ==
+ 1 &&
+ initType.getShape().begin()[lhsTransposeOffset + batchOffset] == 1;
else
- return rhsType.getShape().rbegin()[offset] == 1 &&
- initType.getShape().rbegin()[offset] == 1;
+ return rhsType.getShape().rbegin()[rhsTransposeOffset] == 1 &&
+ initType.getShape().rbegin()[rhsTransposeOffset] == 1;
}
SmallVector<Value, 3> collapseOperands(PatternRewriter &rewriter, Value lhs,
Value rhs, Value init) const override {
+
if constexpr (reduceLeft) {
- if constexpr (isTranspose) {
- lhs = collapseTrailingSingletonDim(rewriter, lhs);
- init = collapseTrailingSingletonDim(rewriter, init);
- } else {
- lhs = collapseLeadingSingletonDim(rewriter, lhs);
- init = collapseLeadingSingletonDim(rewriter, init);
- }
+ lhs = collapseSingletonDimAt(rewriter, lhs,
+ lhsTransposeOffset + batchOffset,
+ /*fromRight=*/isLHSTransposed);
+ init = collapseSingletonDimAt(rewriter, init,
+ lhsTransposeOffset + batchOffset,
+                                  /*fromRight=*/isLHSTransposed);
} else {
- if constexpr (isTranspose) {
- rhs = collapseLeadingSingletonDim(rewriter, rhs);
- init = collapseLeadingSingletonDim(rewriter, init);
- } else {
- rhs = collapseTrailingSingletonDim(rewriter, rhs);
- init = collapseTrailingSingletonDim(rewriter, init);
- }
+ auto rhsRank = cast<ShapedType>(rhs.getType()).getRank();
+ auto initRank = cast<ShapedType>(init.getType()).getRank();
+ rhs = collapseSingletonDimAt(
+ rewriter, rhs, rhsRank - rhsTransposeOffset - 1, /*fromRight=*/true);
+ init = collapseSingletonDimAt(rewriter, init,
+ initRank - rhsTransposeOffset - 1,
+ /*fromRight=*/true);
}
return SmallVector<Value, 3>{lhs, rhs, init};
}
@@ -1020,11 +1033,13 @@ struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
if constexpr (reduceLeft)
return rewriter.create<tensor::ExpandShapeOp>(
result.getLoc(), expandedType, result,
- getReassociationsForLeadingDims(expandedType.getRank()));
+ getReassociationForReshapeAtDim(expandedType.getRank(), 0));
else
return rewriter.create<tensor::ExpandShapeOp>(
result.getLoc(), expandedType, result,
- getReassociationsForTrailingDims(expandedType.getRank()));
+ getReassociationForReshapeAtDim(expandedType.getRank(),
+ expandedType.getRank() - 1,
+ /*fromRight=*/true));
}
};
@@ -1033,6 +1048,7 @@ struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
void mlir::linalg::populateContractionOpRankReducingPatterns(
RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
+ // Unbatching patterns for unit batch size
patterns.add<RankReduceBatched<BatchMatmulOp, MatmulOp>>(context);
patterns.add<RankReduceBatched<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
context);
@@ -1040,10 +1056,21 @@ void mlir::linalg::populateContractionOpRankReducingPatterns(
context);
patterns.add<RankReduceBatched<BatchMatvecOp, MatvecOp>>(context);
patterns.add<RankReduceBatched<BatchVecmatOp, VecmatOp>>(context);
+
+ // Non-batch rank 1 reducing patterns
patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>>(context);
patterns.add<RankReduceMatmul<MatmulOp, MatvecOp>>(context);
patterns.add<RankReduceMatmul<MatmulTransposeAOp, VecmatOp>>(context);
patterns.add<RankReduceMatmul<MatmulTransposeBOp, MatvecOp>>(context);
+ // Batch rank 1 reducing patterns
+ patterns.add<RankReduceMatmul<BatchMatmulOp, BatchVecmatOp>>(context);
+ patterns.add<RankReduceMatmul<BatchMatmulOp, BatchMatvecOp>>(context);
+ patterns.add<RankReduceMatmul<BatchMatmulTransposeAOp, BatchVecmatOp>>(
+ context);
+ patterns.add<RankReduceMatmul<BatchMatmulTransposeBOp, BatchMatvecOp>>(
+ context);
+
+ // Non-batch rank 0 reducing patterns
patterns.add<RankReduceMatmul<MatvecOp, DotOp>>(context);
patterns.add<RankReduceMatmul<VecmatOp, DotOp>>(context);
}
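Hand-evaluating the generalized helper above for the two cases the earlier leading/trailing helpers covered (values derived from the code; not an official test):

    SmallVector<ReassociationIndices> leading =
        getReassociationForReshapeAtDim(/*rank=*/3, /*pos=*/0);
    // leading == {{0, 1}, {2}}: the leading unit dim folds into its neighbor,
    // as in the batch patterns.
    SmallVector<ReassociationIndices> trailing =
        getReassociationForReshapeAtDim(/*rank=*/3, /*pos=*/2,
                                        /*fromRight=*/true);
    // trailing == {{0}, {1, 2}}: the trailing unit dim folds leftward,
    // as in the right-reducing matmul patterns.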
>From 7851ae12b78d9b4fbb947de3a6c9fa1654fe0ab6 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 20 Jun 2024 19:13:36 -0500
Subject: [PATCH 11/14] more refactor
---
.../Linalg/Transforms/DropUnitDims.cpp | 229 +++++++++---------
.../Linalg/rank-reduce-contraction-ops.mlir | 55 ++---
2 files changed, 141 insertions(+), 143 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 26a85e8225b24..771b40d6a2001 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -839,31 +839,32 @@ struct LinalgFoldUnitExtentDimsPass
namespace {
static SmallVector<ReassociationIndices>
-getReassociationForReshapeAtDim(int64_t rank, int64_t pos,
- bool fromRight = false) {
+getReassociationForReshapeAtDim(int64_t rank, int64_t pos) {
SmallVector<ReassociationIndices> reassociation(rank - 1, {0, 1});
+ auto lastDim = pos == rank - 1;
if (rank > 2) {
- int64_t offsetPos = pos - (int64_t)fromRight;
for (int64_t i = 0; i < rank - 1; i++) {
- if (i == offsetPos)
+ if (i == pos || (lastDim && i == pos - 1))
reassociation[i] = ReassociationIndices{i, i + 1};
- else if (i < offsetPos)
+ else if (i < pos)
reassociation[i] = ReassociationIndices{i};
else
- reassociation[i] = ReassociationIndices{i + offsetPos + 1};
+ reassociation[i] = ReassociationIndices{i + 1};
}
}
return reassociation;
}
static Value collapseSingletonDimAt(PatternRewriter &rewriter, Value val,
- int64_t pos, bool fromRight = false) {
+ int64_t pos) {
+ if (pos < 0)
+ return val;
auto valType = cast<ShapedType>(val.getType());
SmallVector<int64_t> collapsedShape(valType.getShape());
collapsedShape.erase(collapsedShape.begin() + pos);
return collapseValue(
rewriter, val.getLoc(), val, collapsedShape,
- getReassociationForReshapeAtDim(valType.getRank(), pos, fromRight),
+ getReassociationForReshapeAtDim(valType.getRank(), pos),
ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape);
}
@@ -871,24 +872,52 @@ template <typename FromOpTy, typename ToOpTy>
struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
using OpRewritePattern<FromOpTy>::OpRewritePattern;
- LogicalResult matchAndRewrite(FromOpTy batchMatmulOp,
+ SmallVector<Value, 3>
+ collapseOperands(PatternRewriter &rewriter, ArrayRef<Value> operands,
+ ArrayRef<int64_t> operandCollapseDims) const {
+ assert(operandCollapseDims.size() == 3 && operands.size() == 3 &&
+ "expected 3 operands and dims");
+ return llvm::to_vector(llvm::map_range(
+ llvm::zip(operands, operandCollapseDims), [&](auto pair) {
+ return collapseSingletonDimAt(rewriter, std::get<0>(pair),
+ std::get<1>(pair));
+ }));
+ }
+
+ Value expandResult(PatternRewriter &rewriter, Value result,
+ RankedTensorType expandedType, int64_t dim) const {
+ return rewriter.create<tensor::ExpandShapeOp>(
+ result.getLoc(), expandedType, result,
+ getReassociationForReshapeAtDim(expandedType.getRank(), dim));
+ }
+
+ LogicalResult matchAndRewrite(FromOpTy contractionOp,
PatternRewriter &rewriter) const override {
- auto loc = batchMatmulOp.getLoc();
- auto inputs = batchMatmulOp.getDpsInputs();
- auto inits = batchMatmulOp.getDpsInits();
+ auto loc = contractionOp.getLoc();
+ auto inputs = contractionOp.getDpsInputs();
+ auto inits = contractionOp.getDpsInits();
if (inputs.size() != 2 || inits.size() != 1)
- return rewriter.notifyMatchFailure(batchMatmulOp,
+ return rewriter.notifyMatchFailure(contractionOp,
"expected 2 inputs and 1 init");
auto lhs = inputs[0];
auto rhs = inputs[1];
auto init = inits[0];
+ SmallVector<Value> operands{lhs, rhs, init};
- if (!checkTypes(lhs, rhs, init))
- return rewriter.notifyMatchFailure(batchMatmulOp,
+ auto maybeContractionDims = inferContractionDims(contractionOp);
+ if (failed(maybeContractionDims))
+ return rewriter.notifyMatchFailure(contractionOp,
+ "could not infer contraction dims");
+
+ auto contractionDims = maybeContractionDims.value();
+ SmallVector<int64_t> operandUnitDims;
+ if (failed(getOperandUnitDims(contractionOp, operandUnitDims)))
+ return rewriter.notifyMatchFailure(contractionOp,
"no reducable dims found");
- auto collapsedOperands = collapseOperands(rewriter, lhs, rhs, init);
+ auto collapsedOperands =
+ collapseOperands(rewriter, operands, operandUnitDims);
auto collapsedLhs = collapsedOperands[0];
auto collapsedRhs = collapsedOperands[1];
auto collapsedInit = collapsedOperands[2];
@@ -898,57 +927,63 @@ struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
auto collapsedOp = rewriter.create<ToOpTy>(
loc, collapsedResultTy, ValueRange{collapsedLhs, collapsedRhs},
ValueRange{collapsedInit});
- for (auto attr : batchMatmulOp->getAttrs()) {
+ for (auto attr : contractionOp->getAttrs()) {
if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
continue;
collapsedOp->setAttr(attr.getName(), attr.getValue());
}
- auto results = batchMatmulOp.getResults();
+ auto results = contractionOp.getResults();
assert(results.size() < 2 && "expected at most one result");
if (results.size() < 1)
- rewriter.replaceOp(batchMatmulOp, collapsedOp);
+ rewriter.replaceOp(contractionOp, collapsedOp);
else
rewriter.replaceOp(
- batchMatmulOp,
+ contractionOp,
expandResult(rewriter, collapsedOp.getResultTensors()[0],
- cast<RankedTensorType>(results[0].getType())));
+ cast<RankedTensorType>(results[0].getType()),
+ operandUnitDims[2]));
return success();
}
- virtual bool checkTypes(Value lhs, Value rhs, Value init) const = 0;
- virtual SmallVector<Value, 3> collapseOperands(PatternRewriter &rewriter,
- Value lhs, Value rhs,
- Value init) const = 0;
- virtual Value expandResult(PatternRewriter &rewriter, Value result,
- RankedTensorType expandedType) const = 0;
+ virtual LogicalResult
+ getOperandUnitDims(LinalgOp op,
+                     SmallVectorImpl<int64_t> &operandUnitDims) const = 0;
};
template <typename FromOpTy, typename ToOpTy>
struct RankReduceBatched : RankReduceContractionOps<FromOpTy, ToOpTy> {
using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
- bool checkTypes(Value lhs, Value rhs, Value init) const override {
- auto lhsType = cast<ShapedType>(lhs.getType());
- auto rhsType = cast<ShapedType>(rhs.getType());
- auto initType = cast<ShapedType>(init.getType());
- return lhsType.getShape()[0] == 1 && rhsType.getShape()[0] == 1 &&
- initType.getShape()[0] == 1;
- }
+ LogicalResult getOperandUnitDims(
+ LinalgOp op,
+      SmallVectorImpl<int64_t> &operandUnitDims) const override {
+ auto inputs = op.getDpsInputs();
+ auto inits = op.getDpsInits();
+ if (inputs.size() != 2 || inits.size() != 1)
+ return failure();
- SmallVector<Value, 3> collapseOperands(PatternRewriter &rewriter, Value lhs,
- Value rhs, Value init) const override {
- auto collapsedLhs = collapseSingletonDimAt(rewriter, lhs, 0);
- auto collapsedRhs = collapseSingletonDimAt(rewriter, rhs, 0);
- auto collapsedInit = collapseSingletonDimAt(rewriter, init, 0);
- return SmallVector<Value, 3>{collapsedLhs, collapsedRhs, collapsedInit};
- }
- Value expandResult(PatternRewriter &rewriter, Value result,
- RankedTensorType expandedType) const override {
- return rewriter.create<tensor::ExpandShapeOp>(
- result.getLoc(), expandedType, result,
- getReassociationForReshapeAtDim(expandedType.getRank(), 0));
+ auto maybeContractionDims = inferContractionDims(op);
+ if (failed(maybeContractionDims))
+ return failure();
+ auto contractionDims = maybeContractionDims.value();
+
+ if (contractionDims.batch.size() != 1)
+ return failure();
+ auto batchDim = contractionDims.batch[0];
+ SmallVector<std::pair<Value, unsigned>, 2> bOperands;
+ op.mapIterationSpaceDimToAllOperandDims(batchDim, bOperands);
+ if (bOperands.size() != 3 || llvm::any_of(bOperands, [](auto pair) {
+ return cast<ShapedType>(std::get<0>(pair).getType())
+ .getShape()[std::get<1>(pair)] != 1;
+ }))
+ return failure();
+
+    operandUnitDims = SmallVector<int64_t>{std::get<1>(bOperands[0]),
+                                           std::get<1>(bOperands[1]),
+                                           std::get<1>(bOperands[2])};
+ return success();
}
};
@@ -956,27 +991,6 @@ template <typename FromOpTy, typename ToOpTy>
struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
- static bool constexpr isBatched =
- std::is_same<FromOpTy, BatchMatmulOp>::value ||
- std::is_same<FromOpTy, BatchMatvecOp>::value ||
- std::is_same<FromOpTy, BatchVecmatOp>::value ||
- std::is_same<FromOpTy, BatchMatmulTransposeAOp>::value ||
- std::is_same<FromOpTy, BatchMatmulTransposeBOp>::value;
-
- static bool constexpr isLHSTransposed =
- std::is_same<FromOpTy, BatchMatmulTransposeAOp>::value ||
- std::is_same<FromOpTy, MatmulTransposeAOp>::value;
-
- static bool constexpr isRHSTransposed =
- std::is_same<FromOpTy, BatchMatmulTransposeBOp>::value ||
- std::is_same<FromOpTy, MatmulTransposeBOp>::value;
-
- static bool constexpr isTranspose =
- std::is_same<FromOpTy, BatchMatmulTransposeAOp>::value ||
- std::is_same<FromOpTy, BatchMatmulTransposeBOp>::value ||
- std::is_same<FromOpTy, MatmulTransposeAOp>::value ||
- std::is_same<FromOpTy, MatmulTransposeBOp>::value;
-
static bool constexpr reduceLeft =
(std::is_same<FromOpTy, BatchMatmulOp>::value &&
std::is_same<ToOpTy, BatchVecmatOp>::value) ||
@@ -989,57 +1003,44 @@ struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
(std::is_same<FromOpTy, MatvecOp>::value &&
std::is_same<ToOpTy, DotOp>::value);
- static int constexpr lhsTransposeOffset = (int)isLHSTransposed;
- static int constexpr rhsTransposeOffset = (int)isRHSTransposed;
- static int constexpr batchOffset = (int)isBatched;
-
- bool checkTypes(Value lhs, Value rhs, Value init) const override {
- auto lhsType = cast<ShapedType>(lhs.getType());
- auto rhsType = cast<ShapedType>(rhs.getType());
- auto initType = cast<ShapedType>(init.getType());
- if constexpr (reduceLeft)
- return lhsType.getShape().begin()[lhsTransposeOffset + batchOffset] ==
- 1 &&
- initType.getShape().begin()[lhsTransposeOffset + batchOffset] == 1;
- else
- return rhsType.getShape().rbegin()[rhsTransposeOffset] == 1 &&
- initType.getShape().rbegin()[rhsTransposeOffset] == 1;
- }
-
- SmallVector<Value, 3> collapseOperands(PatternRewriter &rewriter, Value lhs,
- Value rhs, Value init) const override {
+ LogicalResult getOperandUnitDims(
+ LinalgOp op,
+      SmallVectorImpl<int64_t> &operandUnitDims) const override {
+ auto maybeContractionDims = inferContractionDims(op);
+ if (failed(maybeContractionDims))
+ return failure();
+ auto contractionDims = maybeContractionDims.value();
if constexpr (reduceLeft) {
- lhs = collapseSingletonDimAt(rewriter, lhs,
- lhsTransposeOffset + batchOffset,
- /*fromRight=*/isLHSTransposed);
- init = collapseSingletonDimAt(rewriter, init,
- lhsTransposeOffset + batchOffset,
- /*fromRight*/ isLHSTransposed);
+ auto m = contractionDims.m[0];
+ SmallVector<std::pair<Value, unsigned>, 2> mOperands;
+ op.mapIterationSpaceDimToAllOperandDims(m, mOperands);
+ if (mOperands.size() != 2)
+ return failure();
+ if (llvm::all_of(mOperands, [](auto pair) {
+ return cast<ShapedType>(std::get<0>(pair).getType())
+ .getShape()[std::get<1>(pair)] == 1;
+ })) {
+        operandUnitDims = SmallVector<int64_t>{
+ std::get<1>(mOperands[0]), -1, std::get<1>(mOperands[1])};
+ return success();
+ }
} else {
- auto rhsRank = cast<ShapedType>(rhs.getType()).getRank();
- auto initRank = cast<ShapedType>(init.getType()).getRank();
- rhs = collapseSingletonDimAt(
- rewriter, rhs, rhsRank - rhsTransposeOffset - 1, /*fromRight=*/true);
- init = collapseSingletonDimAt(rewriter, init,
- initRank - rhsTransposeOffset - 1,
- /*fromRight=*/true);
+ auto n = contractionDims.n[0];
+ SmallVector<std::pair<Value, unsigned>, 2> nOperands;
+ op.mapIterationSpaceDimToAllOperandDims(n, nOperands);
+ if (nOperands.size() != 2)
+ return failure();
+ if (llvm::all_of(nOperands, [](auto pair) {
+ return cast<ShapedType>(std::get<0>(pair).getType())
+ .getShape()[std::get<1>(pair)] == 1;
+ })) {
+        operandUnitDims = SmallVector<int64_t>{
+ -1, std::get<1>(nOperands[0]), std::get<1>(nOperands[1])};
+ return success();
+ }
}
- return SmallVector<Value, 3>{lhs, rhs, init};
- }
-
- Value expandResult(PatternRewriter &rewriter, Value result,
- RankedTensorType expandedType) const override {
- if constexpr (reduceLeft)
- return rewriter.create<tensor::ExpandShapeOp>(
- result.getLoc(), expandedType, result,
- getReassociationForReshapeAtDim(expandedType.getRank(), 0));
- else
- return rewriter.create<tensor::ExpandShapeOp>(
- result.getLoc(), expandedType, result,
- getReassociationForReshapeAtDim(expandedType.getRank(),
- expandedType.getRank() - 1,
- /*fromRight=*/true));
+ return failure();
}
};
diff --git a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
index 0548f3f860a89..70568be99474e 100644
--- a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
+++ b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
@@ -1,23 +1,19 @@
// RUN: mlir-opt -test-linalg-rank-reduce-contraction-ops --canonicalize -split-input-file %s | FileCheck %s
-func.func @singleton_batch_matmul_tensor(%arg0 : tensor<1x?x?xf32>, %arg1 : tensor<1x?x?xf32>, %arg2: tensor<1x?x?xf32>) -> tensor<1x?x?xf32> {
+func.func @singleton_batch_matmul_tensor(%arg0 : tensor<1x128x512xf32>, %arg1 : tensor<1x512x256xf32>, %arg2: tensor<1x128x256xf32>) -> tensor<1x128x256xf32> {
// CHECK-LABEL: @singleton_batch_matmul_tensor
- // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
- // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
- // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
- // CHECK-DAG: %[[C1:.*]] = arith.constant 1
- // CHECK-DAG: %[[C2:.*]] = arith.constant 2
+ // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<1x128x512xf32>
+ // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: tensor<1x512x256xf32>
+ // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<1x128x256xf32>
// CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
// CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
// CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
- // CHECK-NEXT: %[[MATMUL:.+]] = linalg.matmul ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?x?xf32>)
- // CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
- // CHECK-NEXT: %[[DIM2:.*]] = tensor.dim %[[INIT]], %[[C2]]
- // CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1], [2]] output_shape [1, %[[DIM1]], %[[DIM2]]]
+ // CHECK-NEXT: %[[MATMUL:.+]] = linalg.matmul ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<128x512xf32>, tensor<512x256xf32>) outs(%[[COLLAPSED_INIT]] : tensor<128x256xf32>)
+ // CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1], [2]] output_shape [1, 128, 256]
// CHECK-NEXT: return %[[RES]]
- %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x?x?xf32>, tensor<1x?x?xf32>)
- outs(%arg2 : tensor<1x?x?xf32>) -> tensor<1x?x?xf32>
- return %1 : tensor<1x?x?xf32>
+ %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x128x512xf32>, tensor<1x512x256xf32>)
+ outs(%arg2 : tensor<1x128x256xf32>) -> tensor<1x128x256xf32>
+ return %1 : tensor<1x128x256xf32>
}
// -----
@@ -39,22 +35,20 @@ func.func @singleton_batch_matmul_memref(%arg0 : memref<1x?x?xf32>, %arg1 : memr
// -----
-func.func @singleton_batch_matvec(%arg0 : tensor<1x?x?xf32>, %arg1 : tensor<1x?xf32>, %arg2: tensor<1x?xf32>) -> tensor<1x?xf32> {
+func.func @singleton_batch_matvec(%arg0 : tensor<1x128x512xf32>, %arg1 : tensor<1x512xf32>, %arg2: tensor<1x128xf32>) -> tensor<1x128xf32> {
// CHECK-LABEL: @singleton_batch_matvec
- // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
- // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: tensor<1x?xf32>
- // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?xf32>
- // CHECK-DAG: %[[C1:.*]] = arith.constant 1
+ // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<1x128x512xf32>
+ // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: tensor<1x512xf32>
+ // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<1x128xf32>
// CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
// CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1]]
// CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
- // CHECK-NEXT: %[[MATMUL:.+]] = linalg.matvec ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
- // CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
- // CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]]
+ // CHECK-NEXT: %[[MATMUL:.+]] = linalg.matvec ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<128x512xf32>, tensor<512xf32>) outs(%[[COLLAPSED_INIT]] : tensor<128xf32>)
+ // CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [1, 128]
// CHECK-NEXT: return %[[RES]]
- %1 = linalg.batch_matvec ins(%arg0, %arg1 : tensor<1x?x?xf32>, tensor<1x?xf32>)
- outs(%arg2 : tensor<1x?xf32>) -> tensor<1x?xf32>
- return %1 : tensor<1x?xf32>
+ %1 = linalg.batch_matvec ins(%arg0, %arg1 : tensor<1x128x512xf32>, tensor<1x512xf32>)
+ outs(%arg2 : tensor<1x128xf32>) -> tensor<1x128xf32>
+ return %1 : tensor<1x128xf32>
}
// -----
@@ -197,19 +191,22 @@ func.func @matvec_to_dot_tensor(%arg0: tensor<1x?xf32>, %arg1: tensor<?xf32>, %a
// -----
-func.func @matmul_transpose_a_to_vecmat(%arg0: memref<?x1xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x1xf32>) {
+func.func @matmul_transpose_a_to_vecmat(%arg0: tensor<256x1xf32>, %arg1: tensor<256x512xf32>, %arg2: tensor<1x512xf32>) -> tensor<1x512xf32> {
// CHECK-LABEL: @matmul_transpose_a_to_vecmat
+ // CHECK: collapse_shape {{.*}} into tensor<256xf32>
+ // CHECK: collapse_shape {{.*}} into tensor<512xf32>
// CHECK: linalg.vecmat
- linalg.matmul_transpose_a ins(%arg0, %arg1: memref<?x1xf32>, memref<?x?xf32>) outs(%arg2: memref<?x1xf32>)
- return
+ // CHECK: expand_shape {{.*}} into tensor<1x512xf32>
+ %0 = linalg.matmul_transpose_a ins(%arg0, %arg1: tensor<256x1xf32>, tensor<256x512xf32>) outs(%arg2: tensor<1x512xf32>) -> tensor<1x512xf32>
+ return %0 : tensor<1x512xf32>
}
// -----
-func.func @matmul_transpose_b_to_matvec(%arg0: memref<?x?xf32>, %arg1: memref<1x?xf32>, %arg2: memref<1x?xf32>) {
+func.func @matmul_transpose_b_to_matvec(%arg0: memref<?x?xf32>, %arg1: memref<1x?xf32>, %arg2: memref<?x1xf32>) {
// CHECK-LABEL: @matmul_transpose_b_to_matvec
// CHECK: linalg.matvec
- linalg.matmul_transpose_b ins(%arg0, %arg1: memref<?x?xf32>, memref<1x?xf32>) outs(%arg2: memref<1x?xf32>)
+ linalg.matmul_transpose_b ins(%arg0, %arg1: memref<?x?xf32>, memref<1x?xf32>) outs(%arg2: memref<?x1xf32>)
return
}
>From acca39bc577ec6077ef71740b4176ccf65826a05 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 21 Jun 2024 11:14:03 -0500
Subject: [PATCH 12/14] address comments and extra cleanup
---
.../Dialect/Linalg/Transforms/Transforms.h | 4 +-
.../Linalg/Transforms/DropUnitDims.cpp | 126 ++++++++++--------
.../Linalg/rank-reduce-contraction-ops.mlir | 9 ++
.../TestLinalgRankReduceContractionOps.cpp | 3 +-
4 files changed, 84 insertions(+), 58 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index c49383c600a57..3682a68b0e2c8 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1692,8 +1692,8 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
const ControlBlockPackMatmulFn &controlFn);
-/// Adds patterns that that reduce the rank of named contraction ops that have
-/// unit dimensions in the operand(s) by converting to a senquence of `collapse_shape`,
+/// Adds patterns that reduce the rank of named contraction ops that have
+/// unit dimensions in the operand(s) by converting to a sequence of `collapse_shape`,
+/// `<corresponding linalg named op>`, `expand_shape` (if on tensors). For example, a
 /// `linalg.batch_matmul` with unit batch size will convert to `linalg.matmul`
 /// and a `linalg.matvec` with a unit spatial dim in the lhs will convert to a `linalg.dot`.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 771b40d6a2001..e1daeb3ad666e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -838,10 +838,12 @@ struct LinalgFoldUnitExtentDimsPass
namespace {
+/// Returns reassociation indices for collapsing/expanding a
+/// tensor of rank `rank` at position `pos`.
static SmallVector<ReassociationIndices>
getReassociationForReshapeAtDim(int64_t rank, int64_t pos) {
SmallVector<ReassociationIndices> reassociation(rank - 1, {0, 1});
- auto lastDim = pos == rank - 1;
+ bool lastDim = pos == rank - 1;
if (rank > 2) {
for (int64_t i = 0; i < rank - 1; i++) {
if (i == pos || (lastDim && i == pos - 1))
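
For reference, given the semantics documented above, the reassociations come out as follows (a sketch; the values mirror the `collapse_shape`/`expand_shape` groupings checked in the tests below):

    getReassociationForReshapeAtDim(/*rank=*/3, /*pos=*/0) -> [[0, 1], [2]]
    getReassociationForReshapeAtDim(/*rank=*/3, /*pos=*/2) -> [[0], [1, 2]]
    getReassociationForReshapeAtDim(/*rank=*/2, /*pos=*/0) -> [[0, 1]]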
@@ -855,6 +857,8 @@ getReassociationForReshapeAtDim(int64_t rank, int64_t pos) {
return reassociation;
}
+/// Returns a collapsed `val` where the collapsing occurs at dim `pos`.
+/// If `pos < 0`, `val` is returned unchanged.
static Value collapseSingletonDimAt(PatternRewriter &rewriter, Value val,
int64_t pos) {
if (pos < 0)
@@ -868,22 +872,30 @@ static Value collapseSingletonDimAt(PatternRewriter &rewriter, Value val,
ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape);
}
+/// Base class for all rank reduction patterns for contraction ops
+/// with unit dimensions. All patterns should convert one named op
+/// to another named op. Intended to reduce only one iteration space dim
+/// at a time.
+/// Reducing multiple dims happens through recursive application of the
+/// pattern rewrites.
template <typename FromOpTy, typename ToOpTy>
struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
using OpRewritePattern<FromOpTy>::OpRewritePattern;
- SmallVector<Value, 3>
+  /// Collapse all collapsible operands.
+ SmallVector<Value>
collapseOperands(PatternRewriter &rewriter, ArrayRef<Value> operands,
ArrayRef<int64_t> operandCollapseDims) const {
assert(operandCollapseDims.size() == 3 && operands.size() == 3 &&
"expected 3 operands and dims");
- return llvm::to_vector(llvm::map_range(
+ return llvm::map_to_vector(
llvm::zip(operands, operandCollapseDims), [&](auto pair) {
return collapseSingletonDimAt(rewriter, std::get<0>(pair),
std::get<1>(pair));
- }));
+ });
}
+ /// Expand result tensor.
Value expandResult(PatternRewriter &rewriter, Value result,
RankedTensorType expandedType, int64_t dim) const {
return rewriter.create<tensor::ExpandShapeOp>(
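
A minimal sketch of the `llvm::map_to_vector` idiom adopted above, which fuses the previous `llvm::to_vector(llvm::map_range(...))` pair into a single call (assuming llvm/ADT/SmallVectorExtras.h):

    #include "llvm/ADT/SmallVectorExtras.h"

    SmallVector<int64_t> dims = {4, 1, 8};
    // Yields {8, 2, 16}; the element type is deduced from the mapped function.
    SmallVector<int64_t> doubled =
        llvm::map_to_vector(dims, [](int64_t d) { return 2 * d; });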
@@ -905,12 +917,6 @@ struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
auto init = inits[0];
SmallVector<Value> operands{lhs, rhs, init};
- auto maybeContractionDims = inferContractionDims(contractionOp);
- if (failed(maybeContractionDims))
- return rewriter.notifyMatchFailure(contractionOp,
- "could not infer contraction dims");
-
- auto contractionDims = maybeContractionDims.value();
SmallVector<int64_t> operandUnitDims;
if (failed(getOperandUnitDims(contractionOp, operandUnitDims)))
return rewriter.notifyMatchFailure(contractionOp,
@@ -935,80 +941,89 @@ struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
auto results = contractionOp.getResults();
assert(results.size() < 2 && "expected at most one result");
- if (results.size() < 1)
+ if (results.empty()) {
rewriter.replaceOp(contractionOp, collapsedOp);
- else
+ } else {
rewriter.replaceOp(
contractionOp,
expandResult(rewriter, collapsedOp.getResultTensors()[0],
cast<RankedTensorType>(results[0].getType()),
operandUnitDims[2]));
+ }
return success();
}
+ /// Populate `operandUnitDims` with 3 indices indicating the unit dim
+ /// for each operand that should be collapsed in this pattern. If an
+ /// operand shouldn't be collapsed, the index should be negative.
virtual LogicalResult
getOperandUnitDims(LinalgOp op,
- SmallVectorImpl<int64_t> &operandUnitDindices) const = 0;
+ SmallVectorImpl<int64_t> &operandUnitDims) const = 0;
};
+/// Patterns for unbatching batched contraction ops.
template <typename FromOpTy, typename ToOpTy>
-struct RankReduceBatched : RankReduceContractionOps<FromOpTy, ToOpTy> {
+struct RankReduceToUnBatched : RankReduceContractionOps<FromOpTy, ToOpTy> {
using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
- LogicalResult getOperandUnitDims(
- LinalgOp op,
- SmallVectorImpl<int64_t> &operandUnitDindices) const override {
- auto inputs = op.getDpsInputs();
- auto inits = op.getDpsInits();
- if (inputs.size() != 2 || inits.size() != 1)
- return failure();
-
+ /// Look for unit batch dims to collapse.
+ LogicalResult
+ getOperandUnitDims(LinalgOp op,
+ SmallVectorImpl<int64_t> &operandUnitDims) const override {
auto maybeContractionDims = inferContractionDims(op);
- if (failed(maybeContractionDims))
+ if (failed(maybeContractionDims)) {
+ LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims");
return failure();
+ }
auto contractionDims = maybeContractionDims.value();
if (contractionDims.batch.size() != 1)
return failure();
auto batchDim = contractionDims.batch[0];
- SmallVector<std::pair<Value, unsigned>, 2> bOperands;
+ SmallVector<std::pair<Value, unsigned>, 3> bOperands;
op.mapIterationSpaceDimToAllOperandDims(batchDim, bOperands);
if (bOperands.size() != 3 || llvm::any_of(bOperands, [](auto pair) {
return cast<ShapedType>(std::get<0>(pair).getType())
.getShape()[std::get<1>(pair)] != 1;
- }))
+ })) {
+ LLVM_DEBUG(llvm::dbgs() << "specified unit dims not found");
return failure();
+ }
- operandUnitDindices = SmallVector<int64_t>{std::get<1>(bOperands[0]),
- std::get<1>(bOperands[1]),
- std::get<1>(bOperands[2])};
+ operandUnitDims = SmallVector<int64_t>{std::get<1>(bOperands[0]),
+ std::get<1>(bOperands[1]),
+ std::get<1>(bOperands[2])};
return success();
}
};
+/// Patterns for reducing non-batch dimensions.
template <typename FromOpTy, typename ToOpTy>
struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
+ /// Helper for determining whether the lhs/init or rhs/init are reduced.
static bool constexpr reduceLeft =
- (std::is_same<FromOpTy, BatchMatmulOp>::value &&
- std::is_same<ToOpTy, BatchVecmatOp>::value) ||
- (std::is_same<FromOpTy, BatchMatmulTransposeAOp>::value &&
- std::is_same<ToOpTy, BatchVecmatOp>::value) ||
- (std::is_same<FromOpTy, MatmulOp>::value &&
- std::is_same<ToOpTy, VecmatOp>::value) ||
- (std::is_same<FromOpTy, MatmulTransposeAOp>::value &&
- std::is_same<ToOpTy, VecmatOp>::value) ||
- (std::is_same<FromOpTy, MatvecOp>::value &&
- std::is_same<ToOpTy, DotOp>::value);
-
- LogicalResult getOperandUnitDims(
- LinalgOp op,
- SmallVectorImpl<int64_t> &operandUnitDindices) const override {
+ (std::is_same_v<FromOpTy, BatchMatmulOp> &&
+ std::is_same_v<ToOpTy, BatchVecmatOp>) ||
+ (std::is_same_v<FromOpTy, BatchMatmulTransposeAOp> &&
+ std::is_same_v<ToOpTy, BatchVecmatOp>) ||
+ (std::is_same_v<FromOpTy, MatmulOp> &&
+ std::is_same_v<ToOpTy, VecmatOp>) ||
+ (std::is_same_v<FromOpTy, MatmulTransposeAOp> &&
+ std::is_same_v<ToOpTy, VecmatOp>) ||
+ (std::is_same_v<FromOpTy, MatvecOp> && std::is_same_v<ToOpTy, DotOp>);
+
+ /// Look for non-batch spatial dims to collapse.
+ LogicalResult
+ getOperandUnitDims(LinalgOp op,
+ SmallVectorImpl<int64_t> &operandUnitDims) const override {
auto maybeContractionDims = inferContractionDims(op);
- if (failed(maybeContractionDims))
+ if (failed(maybeContractionDims)) {
+ LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims");
return failure();
+ }
auto contractionDims = maybeContractionDims.value();
if constexpr (reduceLeft) {
@@ -1021,8 +1036,8 @@ struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
return cast<ShapedType>(std::get<0>(pair).getType())
.getShape()[std::get<1>(pair)] == 1;
})) {
- operandUnitDindices = SmallVector<int64_t>{
- std::get<1>(mOperands[0]), -1, std::get<1>(mOperands[1])};
+ operandUnitDims = SmallVector<int64_t>{std::get<1>(mOperands[0]), -1,
+ std::get<1>(mOperands[1])};
return success();
}
} else {
@@ -1035,11 +1050,12 @@ struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
return cast<ShapedType>(std::get<0>(pair).getType())
.getShape()[std::get<1>(pair)] == 1;
})) {
- operandUnitDindices = SmallVector<int64_t>{
- -1, std::get<1>(nOperands[0]), std::get<1>(nOperands[1])};
+ operandUnitDims = SmallVector<int64_t>{-1, std::get<1>(nOperands[0]),
+ std::get<1>(nOperands[1])};
return success();
}
}
+ LLVM_DEBUG(llvm::dbgs() << "specified unit dims not found");
return failure();
}
};
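
To make `reduceLeft` concrete: for an instance like RankReduceMatmul<MatmulOp, VecmatOp>, the unit dim is the m dim shared by the lhs and the init, so a sketch of the rewrite (illustrative shapes) is:

    %0 = linalg.matmul ins(%lhs, %rhs : tensor<1x512xf32>, tensor<512x256xf32>)
                       outs(%init : tensor<1x256xf32>) -> tensor<1x256xf32>
    // lhs and init collapse at the unit m dim:
    %v = linalg.vecmat ins(%l, %rhs : tensor<512xf32>, tensor<512x256xf32>)
                       outs(%i : tensor<256xf32>) -> tensor<256xf32>
    // followed by a tensor.expand_shape back to tensor<1x256xf32>.

In the non-`reduceLeft` case, the unit n dim is shared by the rhs and the init instead.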
@@ -1050,13 +1066,15 @@ void mlir::linalg::populateContractionOpRankReducingPatterns(
RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
// Unbatching patterns for unit batch size
- patterns.add<RankReduceBatched<BatchMatmulOp, MatmulOp>>(context);
- patterns.add<RankReduceBatched<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
- context);
- patterns.add<RankReduceBatched<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
- context);
- patterns.add<RankReduceBatched<BatchMatvecOp, MatvecOp>>(context);
- patterns.add<RankReduceBatched<BatchVecmatOp, VecmatOp>>(context);
+ patterns.add<RankReduceToUnBatched<BatchMatmulOp, MatmulOp>>(context);
+ patterns
+ .add<RankReduceToUnBatched<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
+ context);
+ patterns
+ .add<RankReduceToUnBatched<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
+ context);
+ patterns.add<RankReduceToUnBatched<BatchMatvecOp, MatvecOp>>(context);
+ patterns.add<RankReduceToUnBatched<BatchVecmatOp, VecmatOp>>(context);
// Non-batch rank 1 reducing patterns
patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>>(context);
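
For context, a hedged sketch of how a pass would typically drive these patterns (the test pass below presumably does something similar), assuming the standard greedy driver from mlir/Transforms/GreedyPatternRewriteDriver.h:

    void runOnOperation() override {
      RewritePatternSet patterns(&getContext());
      linalg::populateContractionOpRankReducingPatterns(patterns);
      if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                              std::move(patterns))))
        return signalPassFailure();
    }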
diff --git a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
index 70568be99474e..fd59a4a52e378 100644
--- a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
+++ b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
@@ -212,6 +212,15 @@ func.func @matmul_transpose_b_to_matvec(%arg0: memref<?x?xf32>, %arg1: memref<1x
// -----
+func.func @batchmatmul_transpose_b_to_dot(%arg0: tensor<1x1x?xf32>, %arg1: tensor<1x1x?xf32>, %arg2: tensor<1x1x1xf32>) -> tensor<1x1x1xf32> {
+  // CHECK-LABEL: @batchmatmul_transpose_b_to_dot
+ // CHECK: linalg.dot
+ %0 = linalg.batch_matmul_transpose_b ins(%arg0, %arg1: tensor<1x1x?xf32>, tensor<1x1x?xf32>) outs(%arg2: tensor<1x1x1xf32>) -> tensor<1x1x1xf32>
+ return %0 : tensor<1x1x1xf32>
+}
+
+// -----
+
func.func @nonsingleton_batch_matmul(%arg0 : tensor<2x?x?xf32>, %arg1 : tensor<2x?x?xf32>, %arg2: tensor<2x?x?xf32>) -> tensor<2x?x?xf32> {
// CHECK-LABEL: @nonsingleton_batch_matmul
// CHECK-NOT: collapse_shape
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp
index 5ca27be30a687..8b455d7d68c30 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp
@@ -1,5 +1,4 @@
-//===- TestLinalgRankReduceContractionOps.cpp - Test Linalg rank reduce
-//contractions ---===//
+//===- TestLinalgRankReduceContractionOps.cpp -----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
>From 01197ef9f661bc40348c48fd907aab7f1dc118b3 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 21 Jun 2024 11:29:46 -0500
Subject: [PATCH 13/14] add more tests
---
.../Linalg/rank-reduce-contraction-ops.mlir | 23 +++++++++++++++++++
1 file changed, 23 insertions(+)
diff --git a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
index fd59a4a52e378..c086d0fd7e633 100644
--- a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
+++ b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
@@ -203,6 +203,18 @@ func.func @matmul_transpose_a_to_vecmat(%arg0: tensor<256x1xf32>, %arg1: tensor<
// -----
+func.func @batch_matmul_transpose_a_to_batch_vecmat(%arg0: tensor<64x256x1xf32>, %arg1: tensor<64x256x512xf32>, %arg2: tensor<64x1x512xf32>) -> tensor<64x1x512xf32> {
+ // CHECK-LABEL: @batch_matmul_transpose_a_to_batch_vecmat
+ // CHECK: collapse_shape {{.*}} into tensor<64x256xf32>
+ // CHECK: collapse_shape {{.*}} into tensor<64x512xf32>
+ // CHECK: linalg.batch_vecmat
+ // CHECK: expand_shape {{.*}} into tensor<64x1x512xf32>
+ %0 = linalg.batch_matmul_transpose_a ins(%arg0, %arg1: tensor<64x256x1xf32>, tensor<64x256x512xf32>) outs(%arg2: tensor<64x1x512xf32>) -> tensor<64x1x512xf32>
+ return %0 : tensor<64x1x512xf32>
+}
+
+// -----
+
func.func @matmul_transpose_b_to_matvec(%arg0: memref<?x?xf32>, %arg1: memref<1x?xf32>, %arg2: memref<?x1xf32>) {
// CHECK-LABEL: @matmul_transpose_b_to_matvec
// CHECK: linalg.matvec
@@ -212,6 +224,17 @@ func.func @matmul_transpose_b_to_matvec(%arg0: memref<?x?xf32>, %arg1: memref<1x
// -----
+func.func @batchmatmul_transpose_b_to_batchmatvec_tensor(%arg0: tensor<64x128x256xf32>, %arg1: tensor<64x1x256xf32>, %arg2: tensor<64x128x1xf32>) -> tensor<64x128x1xf32> {
+  // CHECK-LABEL: @batchmatmul_transpose_b_to_batchmatvec_tensor
+  // CHECK: collapse_shape {{.*}} into tensor<64x256xf32>
+ // CHECK: collapse_shape {{.*}} into tensor<64x128xf32>
+ // CHECK: linalg.batch_matvec
+ // CHECK: expand_shape {{.*}} into tensor<64x128x1xf32>
+ %0 = linalg.batch_matmul_transpose_b ins(%arg0, %arg1: tensor<64x128x256xf32>, tensor<64x1x256xf32>) outs(%arg2: tensor<64x128x1xf32>) -> tensor<64x128x1xf32>
+ return %0 : tensor<64x128x1xf32>
+}
+
+// -----
+
 func.func @batchmatmul_transpose_b_to_dot(%arg0: tensor<1x1x?xf32>, %arg1: tensor<1x1x?xf32>, %arg2: tensor<1x1x1xf32>) -> tensor<1x1x1xf32> {
   // CHECK-LABEL: @batchmatmul_transpose_b_to_dot
// CHECK: linalg.dot
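
Since each pattern drops a single unit dim, this all-unit case is presumably reached through a chain of rewrites under the greedy driver, roughly:

    linalg.batch_matmul_transpose_b   // drop the unit batch dim
      -> linalg.matmul_transpose_b    // drop the unit n dim
      -> linalg.matvec                // drop the unit m dim
      -> linalg.dot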
>From 52151f41149e6c80bf8e1e68eb9176ba1f3de341 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 21 Jun 2024 12:30:12 -0500
Subject: [PATCH 14/14] expand more autos
---
.../Dialect/Linalg/Transforms/DropUnitDims.cpp | 18 ++++++++++--------
1 file changed, 10 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index e1daeb3ad666e..36f8696bf1b27 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -922,11 +922,11 @@ struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
return rewriter.notifyMatchFailure(contractionOp,
"no reducable dims found");
- auto collapsedOperands =
+ SmallVector<Value> collapsedOperands =
collapseOperands(rewriter, operands, operandUnitDims);
- auto collapsedLhs = collapsedOperands[0];
- auto collapsedRhs = collapsedOperands[1];
- auto collapsedInit = collapsedOperands[2];
+ Value collapsedLhs = collapsedOperands[0];
+ Value collapsedRhs = collapsedOperands[1];
+ Value collapsedInit = collapsedOperands[2];
SmallVector<Type, 1> collapsedResultTy;
if (isa<RankedTensorType>(collapsedInit.getType()))
collapsedResultTy.push_back(collapsedInit.getType());
@@ -971,12 +971,13 @@ struct RankReduceToUnBatched : RankReduceContractionOps<FromOpTy, ToOpTy> {
LogicalResult
getOperandUnitDims(LinalgOp op,
SmallVectorImpl<int64_t> &operandUnitDims) const override {
- auto maybeContractionDims = inferContractionDims(op);
+ FailureOr<ContractionDimensions> maybeContractionDims =
+ inferContractionDims(op);
if (failed(maybeContractionDims)) {
LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims");
return failure();
}
- auto contractionDims = maybeContractionDims.value();
+ ContractionDimensions contractionDims = maybeContractionDims.value();
if (contractionDims.batch.size() != 1)
return failure();
@@ -1019,12 +1020,13 @@ struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
LogicalResult
getOperandUnitDims(LinalgOp op,
SmallVectorImpl<int64_t> &operandUnitDims) const override {
- auto maybeContractionDims = inferContractionDims(op);
+ FailureOr<ContractionDimensions> maybeContractionDims =
+ inferContractionDims(op);
if (failed(maybeContractionDims)) {
LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims");
return failure();
}
- auto contractionDims = maybeContractionDims.value();
+ ContractionDimensions contractionDims = maybeContractionDims.value();
if constexpr (reduceLeft) {
auto m = contractionDims.m[0];