[Mlir-commits] [mlir] [mlir][linalg] Move transpose_matmul to targeted transform op (PR #89717)
llvmlistbot at llvm.org
Tue Apr 23 01:38:58 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Cullen Rhodes (c-rhodes)
<details>
<summary>Changes</summary>
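Replaces the `transform.apply_patterns.linalg.transpose_matmul` pattern op with a targeted `transform.structured.transpose_matmul` op, which is more precise than a blanket "apply everywhere" pattern. This is a follow-up to #89075 that addresses @ftynse's feedback.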
---
Full diff: https://github.com/llvm/llvm-project/pull/89717.diff
6 Files Affected:
- (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+46-17)
- (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+8)
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+28-6)
- (modified) mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp (+98-77)
- (modified) mlir/test/Dialect/Linalg/transpose-matmul-a.mlir (+2-3)
- (modified) mlir/test/Dialect/Linalg/transpose-matmul-b.mlir (+2-3)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index beb4cb076f4947..d0ad4ccdf031d9 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -73,23 +73,6 @@ def ApplyTilingCanonicalizationPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
-def ApplyTransposeMatmulPatternsOp : Op<Transform_Dialect,
- "apply_patterns.linalg.transpose_matmul",
- [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
- let description = [{
- Collects patterns to convert Linalg matmul ops to transposed variants.
-
- By default the LHS matrix is transposed. Set `inputToTranspose=<rhs>` to
- instead transpose RHS matrix.
- }];
-
- let arguments = (ins
- DefaultValuedAttr<TransposeMatmulInput,
- "TransposeMatmulInput::lhs">:$inputToTranspose);
-
- let assemblyFormat = "(`<` $inputToTranspose^ `>`)? attr-dict";
-}
-
//===----------------------------------------------------------------------===//
// BufferizeToAllocationOp
//===----------------------------------------------------------------------===//
@@ -2429,6 +2412,52 @@ def TransposeConv2DOp : Op<Transform_Dialect,
}];
}
+//===----------------------------------------------------------------------===//
+// TransposeMatmulOp
+//===----------------------------------------------------------------------===//
+
+def TransposeMatmulOp : Op<Transform_Dialect,
+ "structured.transpose_matmul",
+ [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+ TransformOpInterface, TransformEachOpTrait,
+ ReportTrackingListenerFailuresOpTrait]> {
+ let description = [{
+ Convert Linalg matmul ops to transposed variants.
+
+ By default the LHS matrix is transposed. Specify `<rhs>` to transpose the
+ RHS matrix instead.
+
+ #### Return modes:
+
+ This operation produces a silenceable failure if `target` is unsupported,
+ i.e., not a `linalg.matmul` or `linalg.batch_matmul` with tensor
+ semantics. Otherwise, the operation succeeds and returns a handle to the
+ transposed matmul op.
+ }];
+
+ let arguments = (ins
+ TransformHandleTypeInterface:$target,
+ DefaultValuedAttr<TransposeMatmulInput,
+ "TransposeMatmulInput::lhs">:$inputToTranspose);
+ let results = (outs TransformHandleTypeInterface:$transformed);
+
+ let assemblyFormat = [{
+ $target (`<` $inputToTranspose^ `>`)?
+ attr-dict `:` functional-type($target, results)
+ }];
+
+ // NOTE(review): no C++ definition of this builder is added anywhere in this
+ // diff; confirm one exists elsewhere or drop the declaration.
+ let builders = [
+ OpBuilder<(ins "Value":$target)>
+ ];
+
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
+ ::mlir::linalg::LinalgOp target,
+ ::mlir::transform::ApplyToEachResultList &results,
+ ::mlir::transform::TransformState &state);
+ }];
+}
+
//===----------------------------------------------------------------------===//
// InsertSliceToCopyOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 3bee911ca282ea..5ecf84fa9c7012 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1244,6 +1244,14 @@ FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter,
FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter,
linalg::Conv2DNhwcFhwcQOp op);
+/// Convert a Linalg matmul op to its transposed variant.
+///
+/// Replaces `op` with `linalg.matmul_transpose_a(linalg.transpose(a), b)`, or
+/// with `linalg.matmul_transpose_b(a, linalg.transpose(b))` when
+/// `transposeLHS` is false, and returns the newly created op. Only ops with
+/// tensor semantics are supported; for other ops this returns failure without
+/// modifying the IR.
+FailureOr<Operation *> transposeMatmul(RewriterBase &rewriter,
+ linalg::MatmulOp op,
+ bool transposeLHS = true);
+/// Same as `transposeMatmul`, but for `linalg.batch_matmul`: produces
+/// `linalg.batch_matmul_transpose_a` / `_b` and transposes only the non-batch
+/// dimensions of the chosen operand.
+FailureOr<Operation *> transposeBatchMatmul(RewriterBase &rewriter,
+ linalg::BatchMatmulOp op,
+ bool transposeLHS = true);
+
//===----------------------------------------------------------------------===//
// Rewrite patterns wrapping transformations.
// TODO: every single such pattern should be a close to noop wrapper around a
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 8f1faa83cbb9cc..b4463c1912d518 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -199,12 +199,6 @@ void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
}
-void transform::ApplyTransposeMatmulPatternsOp::populatePatterns(
- RewritePatternSet &patterns) {
- bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
- linalg::populateTransposeMatmulPatterns(patterns, transposeLHS);
-}
-
//===----------------------------------------------------------------------===//
// BufferizeToAllocationOp
//===----------------------------------------------------------------------===//
@@ -3422,6 +3416,34 @@ DiagnosedSilenceableFailure transform::TransposeConv2DOp::applyToOne(
return DiagnosedSilenceableFailure::success();
}
+//===----------------------------------------------------------------------===//
+// TransposeMatmulOp
+//===----------------------------------------------------------------------===//
+
+/// Transposes one operand of a `linalg.matmul` or `linalg.batch_matmul`
+/// target (the LHS by default, the RHS when `<rhs>` is given) and pushes a
+/// handle to the replacement op into `results`. Any other op kind, or a
+/// rewrite that fails (e.g. the op does not have tensor semantics), yields the
+/// default silenceable failure.
+DiagnosedSilenceableFailure transform::TransposeMatmulOp::applyToOne(
+ transform::TransformRewriter &rewriter, linalg::LinalgOp target,
+ transform::ApplyToEachResultList &results,
+ transform::TransformState &state) {
+ // The helpers create the transpose and the new matmul at the current
+ // insertion point, so place them right before the op being replaced.
+ rewriter.setInsertionPoint(target);
+ bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
+ auto maybeTransformed =
+ TypeSwitch<Operation *, FailureOr<Operation *>>(target)
+ .Case([&](linalg::MatmulOp op) {
+ return transposeMatmul(rewriter, op, transposeLHS);
+ })
+ .Case([&](linalg::BatchMatmulOp op) {
+ return transposeBatchMatmul(rewriter, op, transposeLHS);
+ })
+ .Default([&](Operation *op) {
+ return rewriter.notifyMatchFailure(op, "not supported");
+ });
+ if (failed(maybeTransformed))
+ return emitDefaultSilenceableFailure(target);
+ // Handle to the new (batch_)matmul_transpose_{a,b} op that replaced target.
+ results.push_back(*maybeTransformed);
+ return DiagnosedSilenceableFailure::success();
+}
+
//===----------------------------------------------------------------------===//
// InsertSliceToCopyOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
index a4a05b243ad2b4..aa0052ce47fa7b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
@@ -18,7 +18,6 @@
using namespace mlir;
using namespace mlir::linalg;
-namespace {
/// Pattern to replace
///
/// linalg.matmul(a, b)
@@ -29,44 +28,107 @@ namespace {
///
/// By default the LHS is transposed. Set `transposeLHS=false` to
/// transpose RHS instead.
+///
+/// Returns the newly created `linalg.matmul_transpose_a` (or
+/// `linalg.matmul_transpose_b` when `transposeLHS` is false), which replaces
+/// `matmulOp`. Fails without modifying the IR if `matmulOp` does not have
+/// tensor semantics.
+FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
+ linalg::MatmulOp matmulOp,
+ bool transposeLHS) {
+ if (!bufferization::hasTensorSemantics(matmulOp))
+ return rewriter.notifyMatchFailure(
+ matmulOp, "only matmul ops with tensors are supported");
+
+ Location loc = matmulOp.getLoc();
+ Value input = matmulOp.getInputs()[transposeLHS ? 0 : 1];
+ auto type = cast<ShapedType>(input.getType());
+
+ // Dynamic sizes for tensor.empty are listed in result-dimension order. The
+ // result shape is {shape[1], shape[0]}, so input dim 1 comes first.
+ SmallVector<Value> dynamicDims;
+ if (type.isDynamicDim(1))
+ dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
+ if (type.isDynamicDim(0))
+ dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
+
+ ArrayRef<int64_t> shape = type.getShape();
+ Value empty = rewriter.create<tensor::EmptyOp>(
+ loc, ArrayRef<int64_t>{shape[1], shape[0]}, type.getElementType(),
+ dynamicDims);
+ auto transposeOp = rewriter.create<linalg::TransposeOp>(
+ loc, input, empty, ArrayRef<int64_t>{1, 0});
+ Operation *newMatmulOp;
+ if (transposeLHS) {
+ newMatmulOp = rewriter.create<linalg::MatmulTransposeAOp>(
+ loc, matmulOp.getResultTypes(),
+ ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
+ matmulOp.getOutputs());
+ } else {
+ newMatmulOp = rewriter.create<linalg::MatmulTransposeBOp>(
+ loc, matmulOp.getResultTypes(),
+ ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
+ matmulOp.getOutputs());
+ }
+ rewriter.replaceOp(matmulOp, newMatmulOp);
+ return newMatmulOp;
+}
+
+/// Replace
+///
+/// linalg.batch_matmul(a, b)
+///
+/// with
+///
+/// linalg.batch_matmul_transpose_a(linalg.transpose(a), b)
+///
+/// Only the non-batch dimensions are transposed. By default the LHS is
+/// transposed. Set `transposeLHS=false` to transpose RHS instead, producing
+/// linalg.batch_matmul_transpose_b(a, linalg.transpose(b)).
+///
+/// Returns the newly created op, which replaces `batchMatmulOp`. Fails
+/// without modifying the IR if `batchMatmulOp` does not have tensor semantics.
+FailureOr<Operation *>
+mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter,
+ linalg::BatchMatmulOp batchMatmulOp,
+ bool transposeLHS) {
+ if (!bufferization::hasTensorSemantics(batchMatmulOp))
+ return rewriter.notifyMatchFailure(
+ batchMatmulOp, "only matmul ops with tensors are supported");
+
+ Location loc = batchMatmulOp.getLoc();
+ Value input = batchMatmulOp.getInputs()[transposeLHS ? 0 : 1];
+ auto type = cast<ShapedType>(input.getType());
+
+ // Dynamic sizes for tensor.empty are listed in result-dimension order. The
+ // result shape is {shape[0], shape[2], shape[1]}, hence the 0, 2, 1 order.
+ SmallVector<Value> dynamicDims;
+ if (type.isDynamicDim(0))
+ dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
+ if (type.isDynamicDim(2))
+ dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 2));
+ if (type.isDynamicDim(1))
+ dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
+
+ ArrayRef<int64_t> shape = type.getShape();
+ Value empty = rewriter.create<tensor::EmptyOp>(
+ loc, ArrayRef<int64_t>{shape[0], shape[2], shape[1]},
+ type.getElementType(), dynamicDims);
+ // Keep dim 0 (the batch dimension) in place; swap the two matrix dims.
+ auto transposeOp = rewriter.create<linalg::TransposeOp>(
+ loc, input, empty, ArrayRef<int64_t>{0, 2, 1});
+ Operation *newMatmulOp;
+ if (transposeLHS) {
+ newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeAOp>(
+ loc, batchMatmulOp.getResultTypes(),
+ ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
+ batchMatmulOp.getOutputs());
+ } else {
+ newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeBOp>(
+ loc, batchMatmulOp.getResultTypes(),
+ ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
+ batchMatmulOp.getOutputs());
+ }
+ rewriter.replaceOp(batchMatmulOp, newMatmulOp);
+ return newMatmulOp;
+}
+
+namespace {
struct TransposeMatmul final : public OpRewritePattern<linalg::MatmulOp> {
TransposeMatmul(MLIRContext *ctx, bool transposeLHS)
: OpRewritePattern(ctx), transposeLHS(transposeLHS) {}
- LogicalResult matchAndRewrite(linalg::MatmulOp matmulOp,
+ LogicalResult matchAndRewrite(linalg::MatmulOp op,
PatternRewriter &rewriter) const override {
- if (!bufferization::hasTensorSemantics(matmulOp))
- return rewriter.notifyMatchFailure(
- matmulOp, "only matmul ops with tensors are supported");
-
- Location loc = matmulOp.getLoc();
- Value input = matmulOp.getInputs()[transposeLHS ? 0 : 1];
- auto type = cast<ShapedType>(input.getType());
-
- SmallVector<Value> dynamicDims;
- if (type.isDynamicDim(1))
- dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
- if (type.isDynamicDim(0))
- dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
-
- ArrayRef<int64_t> shape = type.getShape();
- Value empty = rewriter.create<tensor::EmptyOp>(
- loc, ArrayRef<int64_t>{shape[1], shape[0]}, type.getElementType(),
- dynamicDims);
- auto transposeOp = rewriter.create<linalg::TransposeOp>(
- loc, input, empty, ArrayRef<int64_t>{1, 0});
- if (transposeLHS) {
- rewriter.replaceOpWithNewOp<linalg::MatmulTransposeAOp>(
- matmulOp, matmulOp.getResultTypes(),
- ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
- matmulOp.getOutputs());
- } else {
- rewriter.replaceOpWithNewOp<linalg::MatmulTransposeBOp>(
- matmulOp, matmulOp.getResultTypes(),
- ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
- matmulOp.getOutputs());
+ if (failed(transposeMatmul(rewriter, op, transposeLHS))) {
+ return failure();
}
-
return success();
}
@@ -74,57 +136,16 @@ struct TransposeMatmul final : public OpRewritePattern<linalg::MatmulOp> {
bool transposeLHS;
};
-/// Pattern to replace
-///
-/// linalg.batch_matmul(a, b)
-///
-/// with
-///
-/// linalg.batch_matmul_transpose_a(linalg.transpose(a), b)
-///
-/// Only the non-batch dimensions are transposed. By default the LHS is
-/// transposed. Set `transposeLHS=false` to transpose RHS instead.
struct TransposeBatchMatmul final
: public OpRewritePattern<linalg::BatchMatmulOp> {
TransposeBatchMatmul(MLIRContext *ctx, bool transposeLHS)
: OpRewritePattern(ctx), transposeLHS(transposeLHS) {}
- LogicalResult matchAndRewrite(linalg::BatchMatmulOp batchMatmulOp,
+ LogicalResult matchAndRewrite(linalg::BatchMatmulOp op,
PatternRewriter &rewriter) const override {
- if (!bufferization::hasTensorSemantics(batchMatmulOp))
- return rewriter.notifyMatchFailure(
- batchMatmulOp, "only matmul ops with tensors are supported");
-
- Location loc = batchMatmulOp.getLoc();
- Value input = batchMatmulOp.getInputs()[transposeLHS ? 0 : 1];
- auto type = cast<ShapedType>(input.getType());
-
- SmallVector<Value> dynamicDims;
- if (type.isDynamicDim(0))
- dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
- if (type.isDynamicDim(2))
- dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 2));
- if (type.isDynamicDim(1))
- dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
-
- ArrayRef<int64_t> shape = type.getShape();
- Value empty = rewriter.create<tensor::EmptyOp>(
- loc, ArrayRef<int64_t>{shape[0], shape[2], shape[1]},
- type.getElementType(), dynamicDims);
- auto transposeOp = rewriter.create<linalg::TransposeOp>(
- loc, input, empty, ArrayRef<int64_t>{0, 2, 1});
- if (transposeLHS) {
- rewriter.replaceOpWithNewOp<linalg::BatchMatmulTransposeAOp>(
- batchMatmulOp, batchMatmulOp.getResultTypes(),
- ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
- batchMatmulOp.getOutputs());
- } else {
- rewriter.replaceOpWithNewOp<linalg::BatchMatmulTransposeBOp>(
- batchMatmulOp, batchMatmulOp.getResultTypes(),
- ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
- batchMatmulOp.getOutputs());
+ if (failed(transposeBatchMatmul(rewriter, op, transposeLHS))) {
+ return failure();
}
-
return success();
}
diff --git a/mlir/test/Dialect/Linalg/transpose-matmul-a.mlir b/mlir/test/Dialect/Linalg/transpose-matmul-a.mlir
index 1d2460f5467a5d..b1f33cfa56327e 100644
--- a/mlir/test/Dialect/Linalg/transpose-matmul-a.mlir
+++ b/mlir/test/Dialect/Linalg/transpose-matmul-a.mlir
@@ -2,10 +2,9 @@
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %matmul = transform.structured.match ops{["linalg.matmul", "linalg.batch_matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.transpose_matmul %matmul : (!transform.any_op) -> (!transform.any_op)
%0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- transform.apply_patterns to %0 {
- transform.apply_patterns.linalg.transpose_matmul
- } : !transform.any_op
transform.apply_cse to %0 : !transform.any_op
transform.apply_patterns to %0 {
transform.apply_patterns.canonicalization
diff --git a/mlir/test/Dialect/Linalg/transpose-matmul-b.mlir b/mlir/test/Dialect/Linalg/transpose-matmul-b.mlir
index eecd76f1ecca7d..41e64c04dc6e59 100644
--- a/mlir/test/Dialect/Linalg/transpose-matmul-b.mlir
+++ b/mlir/test/Dialect/Linalg/transpose-matmul-b.mlir
@@ -2,10 +2,9 @@
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %matmul = transform.structured.match ops{["linalg.matmul", "linalg.batch_matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.transpose_matmul %matmul <rhs> : (!transform.any_op) -> (!transform.any_op)
%0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- transform.apply_patterns to %0 {
- transform.apply_patterns.linalg.transpose_matmul <rhs>
- } : !transform.any_op
transform.apply_cse to %0 : !transform.any_op
transform.apply_patterns to %0 {
transform.apply_patterns.canonicalization
``````````
</details>
https://github.com/llvm/llvm-project/pull/89717
More information about the Mlir-commits mailing list