[Mlir-commits] [mlir] 431213c - [mlir][linalg] Implement patterns for reducing rank of named linalg contraction ops (#95710)
llvmlistbot at llvm.org
Mon Jun 24 11:06:35 PDT 2024
Author: srcarroll
Date: 2024-06-24T13:06:31-05:00
New Revision: 431213c99d7707114d8e7956073a057cf1607160
URL: https://github.com/llvm/llvm-project/commit/431213c99d7707114d8e7956073a057cf1607160
DIFF: https://github.com/llvm/llvm-project/commit/431213c99d7707114d8e7956073a057cf1607160.diff
LOG: [mlir][linalg] Implement patterns for reducing rank of named linalg contraction ops (#95710)
This patch introduces pattern rewrites for reducing the rank of named
linalg contraction ops with unit spatial dim(s) to other named
contraction ops. For example, `linalg.batch_matmul` with batch size 1 ->
`linalg.matmul`, and `linalg.matmul` with a unit LHS spatial dim ->
`linalg.vecmat`, etc. These patterns don't support reducing the rank
along a reduction dimension, as that case doesn't convert to another
named contraction op.
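For illustration, here is a minimal before/after sketch of the unit-batch case, adapted from the @singleton_batch_matmul_tensor test added in this patch (SSA names are illustrative):

  // Before: batch_matmul whose batch size is 1.
  %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x128x512xf32>, tensor<1x512x256xf32>)
                           outs(%arg2 : tensor<1x128x256xf32>) -> tensor<1x128x256xf32>

  // After: collapse the unit batch dim, run matmul, expand the result back.
  %lhs  = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<1x128x512xf32> into tensor<128x512xf32>
  %rhs  = tensor.collapse_shape %arg1 [[0, 1], [2]] : tensor<1x512x256xf32> into tensor<512x256xf32>
  %init = tensor.collapse_shape %arg2 [[0, 1], [2]] : tensor<1x128x256xf32> into tensor<128x256xf32>
  %mm   = linalg.matmul ins(%lhs, %rhs : tensor<128x512xf32>, tensor<512x256xf32>)
                        outs(%init : tensor<128x256xf32>) -> tensor<128x256xf32>
  %1    = tensor.expand_shape %mm [[0, 1], [2]] output_shape [1, 128, 256]
            : tensor<128x256xf32> into tensor<1x128x256xf32>

On memrefs the trailing expand_shape is not emitted, since the collapsed buffer is updated in place.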
Added:
mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/test/lib/Dialect/Linalg/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 8424207ea47e5..b0871a5dff5da 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1713,6 +1713,13 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
const ControlBlockPackMatmulFn &controlFn);
+/// Adds patterns that reduce the rank of named contraction ops that have
+/// unit dimensions in the operand(s) by converting to a sequence of
+/// `collapse_shape`, `<corresponding linalg named op>`, `expand_shape` (if on
+/// tensors). For example, a `linalg.batch_matmul` with unit batch size will
+/// convert to `linalg.matmul`, and a `linalg.matvec` with a unit spatial dim
+/// in the lhs will convert to a `linalg.dot`.
+void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns);
+
} // namespace linalg
} // namespace mlir
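
To make the new doc comment above concrete, here is a hedged sketch of the `linalg.matvec` -> `linalg.dot` case on memrefs, adapted from the @matvec_to_dot test added by this patch (SSA names are illustrative):

  // Before: matvec whose lhs has a unit spatial dim.
  linalg.matvec ins(%arg0, %arg1 : memref<1x?xf32>, memref<?xf32>) outs(%arg2 : memref<1xf32>)

  // After: collapse the lhs and the init; no expand_shape is needed on memrefs.
  %lhs  = memref.collapse_shape %arg0 [[0, 1]] : memref<1x?xf32> into memref<?xf32>
  %init = memref.collapse_shape %arg2 [] : memref<1xf32> into memref<f32>
  linalg.dot ins(%lhs, %arg1 : memref<?xf32>, memref<?xf32>) outs(%init : memref<f32>)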
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index c0829397f1f85..36f8696bf1b27 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -833,4 +833,265 @@ struct LinalgFoldUnitExtentDimsPass
(void)applyPatternsAndFoldGreedily(op, std::move(patterns));
}
};
+
+} // namespace
+
+namespace {
+
+/// Returns reassociation indices for collapsing/expanding a
+/// tensor of rank `rank` at position `pos`.
+static SmallVector<ReassociationIndices>
+getReassociationForReshapeAtDim(int64_t rank, int64_t pos) {
+ SmallVector<ReassociationIndices> reassociation(rank - 1, {0, 1});
+ bool lastDim = pos == rank - 1;
+ if (rank > 2) {
+ for (int64_t i = 0; i < rank - 1; i++) {
+ if (i == pos || (lastDim && i == pos - 1))
+ reassociation[i] = ReassociationIndices{i, i + 1};
+ else if (i < pos)
+ reassociation[i] = ReassociationIndices{i};
+ else
+ reassociation[i] = ReassociationIndices{i + 1};
+ }
+ }
+ return reassociation;
+}
+
+/// Returns a collapsed `val` where the collapsing occurs at dim `pos`.
+/// If `pos < 0`, then don't collapse.
+static Value collapseSingletonDimAt(PatternRewriter &rewriter, Value val,
+ int64_t pos) {
+ if (pos < 0)
+ return val;
+ auto valType = cast<ShapedType>(val.getType());
+ SmallVector<int64_t> collapsedShape(valType.getShape());
+ collapsedShape.erase(collapsedShape.begin() + pos);
+ return collapseValue(
+ rewriter, val.getLoc(), val, collapsedShape,
+ getReassociationForReshapeAtDim(valType.getRank(), pos),
+ ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape);
+}
+
+/// Base class for all rank reduction patterns for contraction ops
+/// with unit dimensions. All patterns should convert one named op
+/// to another named op. Intended to reduce only one iteration space dim
+/// at a time.
+/// Reducing multiple dims will happen with recursive application of
+/// pattern rewrites.
+template <typename FromOpTy, typename ToOpTy>
+struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
+ using OpRewritePattern<FromOpTy>::OpRewritePattern;
+
+  /// Collapse all collapsible operands.
+ SmallVector<Value>
+ collapseOperands(PatternRewriter &rewriter, ArrayRef<Value> operands,
+ ArrayRef<int64_t> operandCollapseDims) const {
+ assert(operandCollapseDims.size() == 3 && operands.size() == 3 &&
+ "expected 3 operands and dims");
+ return llvm::map_to_vector(
+ llvm::zip(operands, operandCollapseDims), [&](auto pair) {
+ return collapseSingletonDimAt(rewriter, std::get<0>(pair),
+ std::get<1>(pair));
+ });
+ }
+
+ /// Expand result tensor.
+ Value expandResult(PatternRewriter &rewriter, Value result,
+ RankedTensorType expandedType, int64_t dim) const {
+ return rewriter.create<tensor::ExpandShapeOp>(
+ result.getLoc(), expandedType, result,
+ getReassociationForReshapeAtDim(expandedType.getRank(), dim));
+ }
+
+ LogicalResult matchAndRewrite(FromOpTy contractionOp,
+ PatternRewriter &rewriter) const override {
+
+ auto loc = contractionOp.getLoc();
+ auto inputs = contractionOp.getDpsInputs();
+ auto inits = contractionOp.getDpsInits();
+ if (inputs.size() != 2 || inits.size() != 1)
+ return rewriter.notifyMatchFailure(contractionOp,
+ "expected 2 inputs and 1 init");
+ auto lhs = inputs[0];
+ auto rhs = inputs[1];
+ auto init = inits[0];
+ SmallVector<Value> operands{lhs, rhs, init};
+
+ SmallVector<int64_t> operandUnitDims;
+ if (failed(getOperandUnitDims(contractionOp, operandUnitDims)))
+ return rewriter.notifyMatchFailure(contractionOp,
+ "no reducable dims found");
+
+ SmallVector<Value> collapsedOperands =
+ collapseOperands(rewriter, operands, operandUnitDims);
+ Value collapsedLhs = collapsedOperands[0];
+ Value collapsedRhs = collapsedOperands[1];
+ Value collapsedInit = collapsedOperands[2];
+ SmallVector<Type, 1> collapsedResultTy;
+ if (isa<RankedTensorType>(collapsedInit.getType()))
+ collapsedResultTy.push_back(collapsedInit.getType());
+ auto collapsedOp = rewriter.create<ToOpTy>(
+ loc, collapsedResultTy, ValueRange{collapsedLhs, collapsedRhs},
+ ValueRange{collapsedInit});
+ for (auto attr : contractionOp->getAttrs()) {
+ if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
+ continue;
+ collapsedOp->setAttr(attr.getName(), attr.getValue());
+ }
+
+ auto results = contractionOp.getResults();
+ assert(results.size() < 2 && "expected at most one result");
+ if (results.empty()) {
+ rewriter.replaceOp(contractionOp, collapsedOp);
+ } else {
+ rewriter.replaceOp(
+ contractionOp,
+ expandResult(rewriter, collapsedOp.getResultTensors()[0],
+ cast<RankedTensorType>(results[0].getType()),
+ operandUnitDims[2]));
+ }
+
+ return success();
+ }
+
+ /// Populate `operandUnitDims` with 3 indices indicating the unit dim
+ /// for each operand that should be collapsed in this pattern. If an
+ /// operand shouldn't be collapsed, the index should be negative.
+ virtual LogicalResult
+ getOperandUnitDims(LinalgOp op,
+ SmallVectorImpl<int64_t> &operandUnitDims) const = 0;
+};
+
+/// Patterns for unbatching batched contraction ops
+template <typename FromOpTy, typename ToOpTy>
+struct RankReduceToUnBatched : RankReduceContractionOps<FromOpTy, ToOpTy> {
+ using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
+
+ /// Look for unit batch dims to collapse.
+ LogicalResult
+ getOperandUnitDims(LinalgOp op,
+ SmallVectorImpl<int64_t> &operandUnitDims) const override {
+ FailureOr<ContractionDimensions> maybeContractionDims =
+ inferContractionDims(op);
+ if (failed(maybeContractionDims)) {
+ LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims");
+ return failure();
+ }
+ ContractionDimensions contractionDims = maybeContractionDims.value();
+
+ if (contractionDims.batch.size() != 1)
+ return failure();
+ auto batchDim = contractionDims.batch[0];
+ SmallVector<std::pair<Value, unsigned>, 3> bOperands;
+ op.mapIterationSpaceDimToAllOperandDims(batchDim, bOperands);
+ if (bOperands.size() != 3 || llvm::any_of(bOperands, [](auto pair) {
+ return cast<ShapedType>(std::get<0>(pair).getType())
+ .getShape()[std::get<1>(pair)] != 1;
+ })) {
+ LLVM_DEBUG(llvm::dbgs() << "specified unit dims not found");
+ return failure();
+ }
+
+ operandUnitDims = SmallVector<int64_t>{std::get<1>(bOperands[0]),
+ std::get<1>(bOperands[1]),
+ std::get<1>(bOperands[2])};
+ return success();
+ }
+};
+
+/// Patterns for reducing non-batch dimensions
+template <typename FromOpTy, typename ToOpTy>
+struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
+ using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
+
+ /// Helper for determining whether the lhs/init or rhs/init are reduced.
+ static bool constexpr reduceLeft =
+ (std::is_same_v<FromOpTy, BatchMatmulOp> &&
+ std::is_same_v<ToOpTy, BatchVecmatOp>) ||
+ (std::is_same_v<FromOpTy, BatchMatmulTransposeAOp> &&
+ std::is_same_v<ToOpTy, BatchVecmatOp>) ||
+ (std::is_same_v<FromOpTy, MatmulOp> &&
+ std::is_same_v<ToOpTy, VecmatOp>) ||
+ (std::is_same_v<FromOpTy, MatmulTransposeAOp> &&
+ std::is_same_v<ToOpTy, VecmatOp>) ||
+ (std::is_same_v<FromOpTy, MatvecOp> && std::is_same_v<ToOpTy, DotOp>);
+
+ /// Look for non-batch spatial dims to collapse.
+ LogicalResult
+ getOperandUnitDims(LinalgOp op,
+ SmallVectorImpl<int64_t> &operandUnitDims) const override {
+ FailureOr<ContractionDimensions> maybeContractionDims =
+ inferContractionDims(op);
+ if (failed(maybeContractionDims)) {
+ LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims");
+ return failure();
+ }
+ ContractionDimensions contractionDims = maybeContractionDims.value();
+
+ if constexpr (reduceLeft) {
+ auto m = contractionDims.m[0];
+ SmallVector<std::pair<Value, unsigned>, 2> mOperands;
+ op.mapIterationSpaceDimToAllOperandDims(m, mOperands);
+ if (mOperands.size() != 2)
+ return failure();
+ if (llvm::all_of(mOperands, [](auto pair) {
+ return cast<ShapedType>(std::get<0>(pair).getType())
+ .getShape()[std::get<1>(pair)] == 1;
+ })) {
+ operandUnitDims = SmallVector<int64_t>{std::get<1>(mOperands[0]), -1,
+ std::get<1>(mOperands[1])};
+ return success();
+ }
+ } else {
+ auto n = contractionDims.n[0];
+ SmallVector<std::pair<Value, unsigned>, 2> nOperands;
+ op.mapIterationSpaceDimToAllOperandDims(n, nOperands);
+ if (nOperands.size() != 2)
+ return failure();
+ if (llvm::all_of(nOperands, [](auto pair) {
+ return cast<ShapedType>(std::get<0>(pair).getType())
+ .getShape()[std::get<1>(pair)] == 1;
+ })) {
+ operandUnitDims = SmallVector<int64_t>{-1, std::get<1>(nOperands[0]),
+ std::get<1>(nOperands[1])};
+ return success();
+ }
+ }
+ LLVM_DEBUG(llvm::dbgs() << "specified unit dims not found");
+ return failure();
+ }
+};
+
} // namespace
+
+void mlir::linalg::populateContractionOpRankReducingPatterns(
+ RewritePatternSet &patterns) {
+ MLIRContext *context = patterns.getContext();
+ // Unbatching patterns for unit batch size
+ patterns.add<RankReduceToUnBatched<BatchMatmulOp, MatmulOp>>(context);
+ patterns
+ .add<RankReduceToUnBatched<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
+ context);
+ patterns
+ .add<RankReduceToUnBatched<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
+ context);
+ patterns.add<RankReduceToUnBatched<BatchMatvecOp, MatvecOp>>(context);
+ patterns.add<RankReduceToUnBatched<BatchVecmatOp, VecmatOp>>(context);
+
+ // Non-batch rank 1 reducing patterns
+ patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>>(context);
+ patterns.add<RankReduceMatmul<MatmulOp, MatvecOp>>(context);
+ patterns.add<RankReduceMatmul<MatmulTransposeAOp, VecmatOp>>(context);
+ patterns.add<RankReduceMatmul<MatmulTransposeBOp, MatvecOp>>(context);
+ // Batch rank 1 reducing patterns
+ patterns.add<RankReduceMatmul<BatchMatmulOp, BatchVecmatOp>>(context);
+ patterns.add<RankReduceMatmul<BatchMatmulOp, BatchMatvecOp>>(context);
+ patterns.add<RankReduceMatmul<BatchMatmulTransposeAOp, BatchVecmatOp>>(
+ context);
+ patterns.add<RankReduceMatmul<BatchMatmulTransposeBOp, BatchMatvecOp>>(
+ context);
+
+ // Non-batch rank 0 reducing patterns
+ patterns.add<RankReduceMatmul<MatvecOp, DotOp>>(context);
+ patterns.add<RankReduceMatmul<VecmatOp, DotOp>>(context);
+}
diff --git a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
new file mode 100644
index 0000000000000..c086d0fd7e633
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
@@ -0,0 +1,267 @@
+//RUN: mlir-opt -test-linalg-rank-reduce-contraction-ops --canonicalize -split-input-file %s | FileCheck %s
+
+func.func @singleton_batch_matmul_tensor(%arg0 : tensor<1x128x512xf32>, %arg1 : tensor<1x512x256xf32>, %arg2: tensor<1x128x256xf32>) -> tensor<1x128x256xf32> {
+ // CHECK-LABEL: @singleton_batch_matmul_tensor
+ // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<1x128x512xf32>
+ // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: tensor<1x512x256xf32>
+ // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<1x128x256xf32>
+ // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+ // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+ // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
+ // CHECK-NEXT: %[[MATMUL:.+]] = linalg.matmul ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<128x512xf32>, tensor<512x256xf32>) outs(%[[COLLAPSED_INIT]] : tensor<128x256xf32>)
+ // CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1], [2]] output_shape [1, 128, 256]
+ // CHECK-NEXT: return %[[RES]]
+ %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x128x512xf32>, tensor<1x512x256xf32>)
+ outs(%arg2 : tensor<1x128x256xf32>) -> tensor<1x128x256xf32>
+ return %1 : tensor<1x128x256xf32>
+}
+
+// -----
+
+func.func @singleton_batch_matmul_memref(%arg0 : memref<1x?x?xf32>, %arg1 : memref<1x?x?xf32>, %arg2: memref<1x?x?xf32>) {
+ // CHECK-LABEL: @singleton_batch_matmul_memref
+ // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: memref<1x?x?xf32>
+ // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: memref<1x?x?xf32>
+ // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: memref<1x?x?xf32>
+ // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+ // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = memref.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+ // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
+ // CHECK-NEXT: linalg.matmul ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : memref<?x?xf32>, memref<?x?xf32>) outs(%[[COLLAPSED_INIT]] : memref<?x?xf32>)
+ // CHECK-NEXT: return
+ linalg.batch_matmul ins(%arg0, %arg1 : memref<1x?x?xf32>, memref<1x?x?xf32>)
+ outs(%arg2 : memref<1x?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @singleton_batch_matvec(%arg0 : tensor<1x128x512xf32>, %arg1 : tensor<1x512xf32>, %arg2: tensor<1x128xf32>) -> tensor<1x128xf32> {
+ // CHECK-LABEL: @singleton_batch_matvec
+ // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<1x128x512xf32>
+ // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: tensor<1x512xf32>
+ // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<1x128xf32>
+ // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+ // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1]]
+ // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
+ // CHECK-NEXT: %[[MATMUL:.+]] = linalg.matvec ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<128x512xf32>, tensor<512xf32>) outs(%[[COLLAPSED_INIT]] : tensor<128xf32>)
+ // CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [1, 128]
+ // CHECK-NEXT: return %[[RES]]
+ %1 = linalg.batch_matvec ins(%arg0, %arg1 : tensor<1x128x512xf32>, tensor<1x512xf32>)
+ outs(%arg2 : tensor<1x128xf32>) -> tensor<1x128xf32>
+ return %1 : tensor<1x128xf32>
+}
+
+// -----
+
+func.func @singleton_batch_vecmat(%arg0 : tensor<1x?xf32>, %arg1 : tensor<1x?x?xf32>, %arg2: tensor<1x?xf32>) -> tensor<1x?xf32> {
+ // CHECK-LABEL: @singleton_batch_vecmat
+ // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?xf32>
+ // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
+ // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?xf32>
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1
+ // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1]]
+ // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+ // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
+ // CHECK-NEXT: %[[MATMUL:.+]] = linalg.vecmat ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
+ // CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
+ // CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]]
+ // CHECK-NEXT: return %[[RES]]
+ %1 = linalg.batch_vecmat ins(%arg0, %arg1 : tensor<1x?xf32>, tensor<1x?x?xf32>)
+ outs(%arg2 : tensor<1x?xf32>) -> tensor<1x?xf32>
+ return %1 : tensor<1x?xf32>
+}
+
+// -----
+
+func.func @singleton_batchmatmul_transpose_a(%arg0: memref<1x5x3xf32>, %arg1: memref<1x5x7xf32>, %arg2: memref<1x3x7xf32>) {
+ // CHECK-LABEL: @singleton_batchmatmul_transpose_a
+ // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: memref<1x5x3xf32>
+ // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: memref<1x5x7xf32>
+ // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: memref<1x3x7xf32>
+ // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+ // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = memref.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+ // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
+ // CHECK-NEXT: linalg.matmul_transpose_a ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : memref<5x3xf32>, memref<5x7xf32>) outs(%[[COLLAPSED_INIT]] : memref<3x7xf32>)
+ // CHECK-NEXT: return
+ linalg.batch_matmul_transpose_a ins(%arg0, %arg1 : memref<1x5x3xf32>, memref<1x5x7xf32>) outs(%arg2: memref<1x3x7xf32>)
+ return
+}
+
+// -----
+
+func.func @singleton_batchmatmul_transpose_b(%arg0: memref<1x3x5xf32>, %arg1: memref<1x7x5xf32>, %arg2: memref<1x3x7xf32>) {
+ // CHECK-LABEL: @singleton_batchmatmul_transpose_b
+ // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: memref<1x3x5xf32>
+ // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: memref<1x7x5xf32>
+ // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: memref<1x3x7xf32>
+ // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+ // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = memref.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+ // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
+ // CHECK-NEXT: linalg.matmul_transpose_b ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : memref<3x5xf32>, memref<7x5xf32>) outs(%[[COLLAPSED_INIT]] : memref<3x7xf32>)
+ // CHECK-NEXT: return
+ linalg.batch_matmul_transpose_b ins(%arg0, %arg1 : memref<1x3x5xf32>, memref<1x7x5xf32>) outs(%arg2: memref<1x3x7xf32>)
+ return
+}
+
+// -----
+
+func.func @matmul_to_matvec_tensor(%arg0: tensor<?x?xf32>, %arg1: tensor<?x1xf32>, %arg2: tensor<?x1xf32>) -> tensor<?x1xf32> {
+ // CHECK-LABEL: @matmul_to_matvec_tensor
+ // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+ // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: tensor<?x1xf32>
+ // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<?x1xf32>
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0
+ // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1]]
+ // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
+ // CHECK-NEXT: %[[MATMUL:.+]] = linalg.matvec ins(%[[LHS]], %[[COLLAPSED_RHS]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
+ // CHECK-NEXT: %[[DIM0:.*]] = tensor.dim %[[INIT]], %[[C0]]
+ // CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [%[[DIM0]], 1]
+ // CHECK-NEXT: return %[[RES]]
+ %0 = linalg.matmul ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x1xf32>) outs(%arg2: tensor<?x1xf32>) -> tensor<?x1xf32>
+ return %0 : tensor<?x1xf32>
+}
+
+// -----
+
+func.func @matmul_to_matvec(%arg0: memref<?x?xf32>, %arg1: memref<?x1xf32>, %arg2: memref<?x1xf32>) {
+ // CHECK-LABEL: @matmul_to_matvec
+ // CHECK: linalg.matvec
+ linalg.matmul ins(%arg0, %arg1: memref<?x?xf32>, memref<?x1xf32>) outs(%arg2: memref<?x1xf32>)
+ return
+}
+
+// -----
+
+func.func @matmul_to_vecmat_tensor(%arg0: tensor<1x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<1x?xf32>) -> tensor<1x?xf32> {
+  // CHECK-LABEL: @matmul_to_vecmat_tensor
+ // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?xf32>
+ // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+ // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?xf32>
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1
+ // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1]]
+ // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
+ // CHECK-NEXT: %[[RESULT:.*]] = linalg.vecmat ins(%[[COLLAPSED_LHS]], %[[RHS]] : tensor<?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
+ // CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
+ // CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]]
+ // CHECK-NEXT: return %[[RES]]
+ %0 = linalg.matmul ins(%arg0, %arg1: tensor<1x?xf32>, tensor<?x?xf32>) outs(%arg2: tensor<1x?xf32>) -> tensor<1x?xf32>
+ return %0 : tensor<1x?xf32>
+}
+
+// -----
+
+func.func @batch_matmul_to_vecmat(%arg0: memref<1x1x?xf32>, %arg1: memref<1x?x?xf32>, %arg2: memref<1x1x?xf32>) {
+ // CHECK-LABEL: @batch_matmul_to_vecmat
+ // CHECK: linalg.vecmat
+ linalg.batch_matmul ins(%arg0, %arg1: memref<1x1x?xf32>, memref<1x?x?xf32>) outs(%arg2: memref<1x1x?xf32>)
+ return
+}
+
+// -----
+
+func.func @matvec_to_dot(%arg0: memref<1x?xf32>, %arg1: memref<?xf32>, %arg2: memref<1xf32>) {
+ // CHECK-LABEL: @matvec_to_dot
+ // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: memref<1x?xf32>
+ // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: memref<?xf32>
+ // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: memref<1xf32>
+ // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1]]
+ // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] []
+ // CHECK-NEXT: linalg.dot ins(%[[COLLAPSED_LHS]], %[[RHS]] : memref<?xf32>, memref<?xf32>) outs(%[[COLLAPSED_INIT]] : memref<f32>)
+ linalg.matvec ins(%arg0, %arg1: memref<1x?xf32>, memref<?xf32>) outs(%arg2: memref<1xf32>)
+ return
+}
+
+// -----
+
+func.func @vecmat_to_dot(%arg0: memref<?xf32>, %arg1: memref<?x1xf32>, %arg2: memref<1xf32>) {
+ // CHECK-LABEL: @vecmat_to_dot
+ // CHECK: linalg.dot
+ linalg.vecmat ins(%arg0, %arg1: memref<?xf32>, memref<?x1xf32>) outs(%arg2: memref<1xf32>)
+ return
+}
+
+// -----
+
+func.func @matvec_to_dot_tensor(%arg0: tensor<1x?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<1xf32>) -> tensor<1xf32> {
+ // CHECK-LABEL: @matvec_to_dot_tensor
+ // CHECK: linalg.dot
+ %0 = linalg.matvec ins(%arg0, %arg1: tensor<1x?xf32>, tensor<?xf32>) outs(%arg2: tensor<1xf32>) -> tensor<1xf32>
+ return %0 : tensor<1xf32>
+}
+
+// -----
+
+func.func @matmul_transpose_a_to_vecmat(%arg0: tensor<256x1xf32>, %arg1: tensor<256x512xf32>, %arg2: tensor<1x512xf32>) -> tensor<1x512xf32> {
+ // CHECK-LABEL: @matmul_transpose_a_to_vecmat
+ // CHECK: collapse_shape {{.*}} into tensor<256xf32>
+ // CHECK: collapse_shape {{.*}} into tensor<512xf32>
+ // CHECK: linalg.vecmat
+ // CHECK: expand_shape {{.*}} into tensor<1x512xf32>
+ %0 = linalg.matmul_transpose_a ins(%arg0, %arg1: tensor<256x1xf32>, tensor<256x512xf32>) outs(%arg2: tensor<1x512xf32>) -> tensor<1x512xf32>
+ return %0 : tensor<1x512xf32>
+}
+
+// -----
+
+func.func @batch_matmul_transpose_a_to_batch_vecmat(%arg0: tensor<64x256x1xf32>, %arg1: tensor<64x256x512xf32>, %arg2: tensor<64x1x512xf32>) -> tensor<64x1x512xf32> {
+ // CHECK-LABEL: @batch_matmul_transpose_a_to_batch_vecmat
+ // CHECK: collapse_shape {{.*}} into tensor<64x256xf32>
+ // CHECK: collapse_shape {{.*}} into tensor<64x512xf32>
+ // CHECK: linalg.batch_vecmat
+ // CHECK: expand_shape {{.*}} into tensor<64x1x512xf32>
+ %0 = linalg.batch_matmul_transpose_a ins(%arg0, %arg1: tensor<64x256x1xf32>, tensor<64x256x512xf32>) outs(%arg2: tensor<64x1x512xf32>) -> tensor<64x1x512xf32>
+ return %0 : tensor<64x1x512xf32>
+}
+
+// -----
+
+func.func @matmul_transpose_b_to_matvec(%arg0: memref<?x?xf32>, %arg1: memref<1x?xf32>, %arg2: memref<?x1xf32>) {
+ // CHECK-LABEL: @matmul_transpose_b_to_matvec
+ // CHECK: linalg.matvec
+ linalg.matmul_transpose_b ins(%arg0, %arg1: memref<?x?xf32>, memref<1x?xf32>) outs(%arg2: memref<?x1xf32>)
+ return
+}
+
+// -----
+
+func.func @batchmatmul_transpose_b_to_batchmatvec_tensor(%arg0: tensor<64x128x256xf32>, %arg1: tensor<64x1x256xf32>, %arg2: tensor<64x128x1xf32>) -> tensor<64x128x1xf32> {
+ // CHECK: collapse_shape {{.*}} into tensor<64x256xf32>
+ // CHECK: collapse_shape {{.*}} into tensor<64x128xf32>
+ // CHECK: linalg.batch_matvec
+ // CHECK: expand_shape {{.*}} into tensor<64x128x1xf32>
+ %0 = linalg.batch_matmul_transpose_b ins(%arg0, %arg1: tensor<64x128x256xf32>, tensor<64x1x256xf32>) outs(%arg2: tensor<64x128x1xf32>) -> tensor<64x128x1xf32>
+ return %0 : tensor<64x128x1xf32>
+}
+
+// -----
+
+func.func @batchmatmul_transpose_b_to_dot(%arg0: tensor<1x1x?xf32>, %arg1: tensor<1x1x?xf32>, %arg2: tensor<1x1x1xf32>) -> tensor<1x1x1xf32> {
+  // CHECK-LABEL: @batchmatmul_transpose_b_to_dot
+ // CHECK: linalg.dot
+ %0 = linalg.batch_matmul_transpose_b ins(%arg0, %arg1: tensor<1x1x?xf32>, tensor<1x1x?xf32>) outs(%arg2: tensor<1x1x1xf32>) -> tensor<1x1x1xf32>
+ return %0 : tensor<1x1x1xf32>
+}
+
+// -----
+
+func.func @nonsingleton_batch_matmul(%arg0 : tensor<2x?x?xf32>, %arg1 : tensor<2x?x?xf32>, %arg2: tensor<2x?x?xf32>) -> tensor<2x?x?xf32> {
+ // CHECK-LABEL: @nonsingleton_batch_matmul
+ // CHECK-NOT: collapse_shape
+ // CHECK: linalg.batch_matmul
+ // CHECK-NOT: expand_shape
+ %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<2x?x?xf32>, tensor<2x?x?xf32>)
+ outs(%arg2 : tensor<2x?x?xf32>) -> tensor<2x?x?xf32>
+ return %1 : tensor<2x?x?xf32>
+}
+
+// -----
+
+func.func @nonsingleton_batch_matmul_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+ // CHECK-LABEL: @nonsingleton_batch_matmul_dynamic
+ // CHECK-NOT: collapse_shape
+ // CHECK: linalg.batch_matmul
+ // CHECK-NOT: expand_shape
+ %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+ outs(%arg2 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ return %1 : tensor<?x?x?xf32>
+}
diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
index b28f2b3564662..283e426b4e594 100644
--- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_library(MLIRLinalgTestPasses
TestLinalgDropUnitDims.cpp
TestLinalgElementwiseFusion.cpp
TestLinalgFusionTransforms.cpp
+ TestLinalgRankReduceContractionOps.cpp
TestLinalgTransforms.cpp
TestPadFusion.cpp
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp
new file mode 100644
index 0000000000000..8b455d7d68c30
--- /dev/null
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp
@@ -0,0 +1,67 @@
+//===- TestLinalgRankReduceContractionOps.cpp -----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass for testing rank-reducing patterns for named
+// contraction ops with unit dims.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+
+struct TestLinalgRankReduceContractionOps
+ : public PassWrapper<TestLinalgRankReduceContractionOps,
+ OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+ TestLinalgRankReduceContractionOps)
+
+ TestLinalgRankReduceContractionOps() = default;
+ TestLinalgRankReduceContractionOps(
+ const TestLinalgRankReduceContractionOps &pass)
+ : PassWrapper(pass) {}
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<affine::AffineDialect, linalg::LinalgDialect,
+ memref::MemRefDialect, tensor::TensorDialect>();
+ }
+ StringRef getArgument() const final {
+ return "test-linalg-rank-reduce-contraction-ops";
+ }
+ StringRef getDescription() const final {
+ return "Test Linalg rank reduce contraction ops with unit dims";
+ }
+
+ void runOnOperation() override {
+ MLIRContext *context = &this->getContext();
+ func::FuncOp funcOp = this->getOperation();
+
+ RewritePatternSet patterns(context);
+ linalg::populateContractionOpRankReducingPatterns(patterns);
+ if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
+ std::move(patterns))))
+ return signalPassFailure();
+ return;
+ }
+};
+
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestLinalgRankReduceContractionOps() {
+ PassRegistration<TestLinalgRankReduceContractionOps>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index e4fbd03d8c678..8cafb0afac9ae 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -108,6 +108,7 @@ void registerTestLinalgDecomposeOps();
void registerTestLinalgDropUnitDims();
void registerTestLinalgElementwiseFusion();
void registerTestLinalgGreedyFusion();
+void registerTestLinalgRankReduceContractionOps();
void registerTestLinalgTransforms();
void registerTestLivenessAnalysisPass();
void registerTestLivenessPass();
@@ -239,6 +240,7 @@ void registerTestPasses() {
mlir::test::registerTestLinalgDropUnitDims();
mlir::test::registerTestLinalgElementwiseFusion();
mlir::test::registerTestLinalgGreedyFusion();
+ mlir::test::registerTestLinalgRankReduceContractionOps();
mlir::test::registerTestLinalgTransforms();
mlir::test::registerTestLivenessAnalysisPass();
mlir::test::registerTestLivenessPass();