[Mlir-commits] [mlir] [mlir][linalg] Move transpose_matmul to targeted transform op (PR #89717)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Apr 23 01:38:58 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Cullen Rhodes (c-rhodes)

<details>
<summary>Changes</summary>

This is more targeted than a blanket "apply everywhere" pattern. Follow-up to #89075, addressing @ftynse's feedback.

---
Full diff: https://github.com/llvm/llvm-project/pull/89717.diff


6 Files Affected:

- (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+46-17) 
- (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+8) 
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+28-6) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp (+98-77) 
- (modified) mlir/test/Dialect/Linalg/transpose-matmul-a.mlir (+2-3) 
- (modified) mlir/test/Dialect/Linalg/transpose-matmul-b.mlir (+2-3) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index beb4cb076f4947..d0ad4ccdf031d9 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -73,23 +73,6 @@ def ApplyTilingCanonicalizationPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
-def ApplyTransposeMatmulPatternsOp : Op<Transform_Dialect,
-    "apply_patterns.linalg.transpose_matmul",
-    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
-  let description = [{
-    Collects patterns to convert Linalg matmul ops to transposed variants.
-
-    By default the LHS matrix is transposed. Set `inputToTranspose=<rhs>` to
-    instead transpose RHS matrix.
-  }];
-
-  let arguments = (ins
-    DefaultValuedAttr<TransposeMatmulInput,
-                      "TransposeMatmulInput::lhs">:$inputToTranspose);
-
-  let assemblyFormat = "(`<` $inputToTranspose^ `>`)? attr-dict";
-}
-
 //===----------------------------------------------------------------------===//
 // BufferizeToAllocationOp
 //===----------------------------------------------------------------------===//
@@ -2429,6 +2412,52 @@ def TransposeConv2DOp : Op<Transform_Dialect,
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// TransposeMatmulOp
+//===----------------------------------------------------------------------===//
+
+def TransposeMatmulOp : Op<Transform_Dialect,
+    "structured.transpose_matmul",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Convert Linalg matmul ops to transposed variants.
+
+    By default the LHS matrix is transposed. Specify `<rhs>` to instead
+    transpose RHS matrix.
+
+    #### Return modes:
+
+    This operation fails if `target` is unsupported, i.e., not a
+    `linalg.matmul` or `linalg.batch_matmul`. Otherwise, the operation succeeds
+    and returns a handle to the transposed matmul op.
+  }];
+
+  let arguments = (ins
+    TransformHandleTypeInterface:$target,
+    DefaultValuedAttr<TransposeMatmulInput,
+                      "TransposeMatmulInput::lhs">:$inputToTranspose);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+
+  let assemblyFormat = [{
+    $target (`<` $inputToTranspose^ `>`)?
+    attr-dict `:` functional-type($target, results)
+  }];
+
+  let builders = [
+    OpBuilder<(ins "Value":$target)>
+  ];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::linalg::LinalgOp target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // InsertSliceToCopyOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 3bee911ca282ea..5ecf84fa9c7012 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1244,6 +1244,14 @@ FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter,
 FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter,
                                        linalg::Conv2DNhwcFhwcQOp op);
 
+/// Convert Linalg matmul ops to transposed variants.
+FailureOr<Operation *> transposeMatmul(RewriterBase &rewriter,
+                                       linalg::MatmulOp op,
+                                       bool transposeLHS = true);
+FailureOr<Operation *> transposeBatchMatmul(RewriterBase &rewriter,
+                                            linalg::BatchMatmulOp op,
+                                            bool transposeLHS = true);
+
 //===----------------------------------------------------------------------===//
 // Rewrite patterns wrapping transformations.
 // TODO: every single such pattern should be a close to noop wrapper around a
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 8f1faa83cbb9cc..b4463c1912d518 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -199,12 +199,6 @@ void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
   linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
 }
 
-void transform::ApplyTransposeMatmulPatternsOp::populatePatterns(
-    RewritePatternSet &patterns) {
-  bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
-  linalg::populateTransposeMatmulPatterns(patterns, transposeLHS);
-}
-
 //===----------------------------------------------------------------------===//
 // BufferizeToAllocationOp
 //===----------------------------------------------------------------------===//
@@ -3422,6 +3416,34 @@ DiagnosedSilenceableFailure transform::TransposeConv2DOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// TransposeMatmulOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::TransposeMatmulOp::applyToOne(
+    transform::TransformRewriter &rewriter, linalg::LinalgOp target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  rewriter.setInsertionPoint(target);
+  bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
+  auto maybeTransformed =
+      TypeSwitch<Operation *, FailureOr<Operation *>>(target)
+          .Case([&](linalg::MatmulOp op) {
+            return transposeMatmul(rewriter, op, transposeLHS);
+          })
+          .Case([&](linalg::BatchMatmulOp op) {
+            return transposeBatchMatmul(rewriter, op, transposeLHS);
+          })
+          .Default([&](Operation *op) {
+            return rewriter.notifyMatchFailure(op, "not supported");
+          });
+  if (failed(maybeTransformed))
+    return emitDefaultSilenceableFailure(target);
+  // Handle to the new Matmul operation with transposed filters
+  results.push_back(*maybeTransformed);
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // InsertSliceToCopyOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
index a4a05b243ad2b4..aa0052ce47fa7b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
@@ -18,7 +18,6 @@
 using namespace mlir;
 using namespace mlir::linalg;
 
-namespace {
 /// Pattern to replace
 ///
 ///   linalg.matmul(a, b)
@@ -29,44 +28,107 @@ namespace {
 ///
 /// By default the LHS is transposed. Set `transposeLHS=false` to
 /// transpose RHS instead.
+FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
+                                                     linalg::MatmulOp matmulOp,
+                                                     bool transposeLHS) {
+  if (!bufferization::hasTensorSemantics(matmulOp))
+    return rewriter.notifyMatchFailure(
+        matmulOp, "only matmul ops with tensors are supported");
+
+  Location loc = matmulOp.getLoc();
+  Value input = matmulOp.getInputs()[transposeLHS ? 0 : 1];
+  auto type = cast<ShapedType>(input.getType());
+
+  SmallVector<Value> dynamicDims;
+  if (type.isDynamicDim(1))
+    dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
+  if (type.isDynamicDim(0))
+    dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
+
+  ArrayRef<int64_t> shape = type.getShape();
+  Value empty = rewriter.create<tensor::EmptyOp>(
+      loc, ArrayRef<int64_t>{shape[1], shape[0]}, type.getElementType(),
+      dynamicDims);
+  auto transposeOp = rewriter.create<linalg::TransposeOp>(
+      loc, input, empty, ArrayRef<int64_t>{1, 0});
+  Operation *newMatmulOp;
+  if (transposeLHS) {
+    newMatmulOp = rewriter.create<linalg::MatmulTransposeAOp>(
+        loc, matmulOp.getResultTypes(),
+        ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
+        matmulOp.getOutputs());
+  } else {
+    newMatmulOp = rewriter.create<linalg::MatmulTransposeBOp>(
+        loc, matmulOp.getResultTypes(),
+        ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
+        matmulOp.getOutputs());
+  }
+  rewriter.replaceOp(matmulOp, newMatmulOp);
+  return newMatmulOp;
+}
+
+/// Pattern to replace
+///
+///   linalg.batch_matmul(a, b)
+///
+/// with
+///
+///   linalg.batch_matmul_transpose_a(linalg.transpose(a), b)
+///
+/// Only the non-batch dimensions are transposed. By default the LHS is
+/// transposed. Set `transposeLHS=false` to transpose RHS instead.
+FailureOr<Operation *>
+mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter,
+                                   linalg::BatchMatmulOp batchMatmulOp,
+                                   bool transposeLHS) {
+  if (!bufferization::hasTensorSemantics(batchMatmulOp))
+    return rewriter.notifyMatchFailure(
+        batchMatmulOp, "only matmul ops with tensors are supported");
+
+  Location loc = batchMatmulOp.getLoc();
+  Value input = batchMatmulOp.getInputs()[transposeLHS ? 0 : 1];
+  auto type = cast<ShapedType>(input.getType());
+
+  SmallVector<Value> dynamicDims;
+  if (type.isDynamicDim(0))
+    dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
+  if (type.isDynamicDim(2))
+    dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 2));
+  if (type.isDynamicDim(1))
+    dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
+
+  ArrayRef<int64_t> shape = type.getShape();
+  Value empty = rewriter.create<tensor::EmptyOp>(
+      loc, ArrayRef<int64_t>{shape[0], shape[2], shape[1]},
+      type.getElementType(), dynamicDims);
+  auto transposeOp = rewriter.create<linalg::TransposeOp>(
+      loc, input, empty, ArrayRef<int64_t>{0, 2, 1});
+  Operation *newMatmulOp;
+  if (transposeLHS) {
+    newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeAOp>(
+        loc, batchMatmulOp.getResultTypes(),
+        ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
+        batchMatmulOp.getOutputs());
+  } else {
+    newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeBOp>(
+        loc, batchMatmulOp.getResultTypes(),
+        ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
+        batchMatmulOp.getOutputs());
+  }
+  rewriter.replaceOp(batchMatmulOp, newMatmulOp);
+  return newMatmulOp;
+}
+
+namespace {
 struct TransposeMatmul final : public OpRewritePattern<linalg::MatmulOp> {
   TransposeMatmul(MLIRContext *ctx, bool transposeLHS)
       : OpRewritePattern(ctx), transposeLHS(transposeLHS) {}
 
-  LogicalResult matchAndRewrite(linalg::MatmulOp matmulOp,
+  LogicalResult matchAndRewrite(linalg::MatmulOp op,
                                 PatternRewriter &rewriter) const override {
-    if (!bufferization::hasTensorSemantics(matmulOp))
-      return rewriter.notifyMatchFailure(
-          matmulOp, "only matmul ops with tensors are supported");
-
-    Location loc = matmulOp.getLoc();
-    Value input = matmulOp.getInputs()[transposeLHS ? 0 : 1];
-    auto type = cast<ShapedType>(input.getType());
-
-    SmallVector<Value> dynamicDims;
-    if (type.isDynamicDim(1))
-      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
-    if (type.isDynamicDim(0))
-      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
-
-    ArrayRef<int64_t> shape = type.getShape();
-    Value empty = rewriter.create<tensor::EmptyOp>(
-        loc, ArrayRef<int64_t>{shape[1], shape[0]}, type.getElementType(),
-        dynamicDims);
-    auto transposeOp = rewriter.create<linalg::TransposeOp>(
-        loc, input, empty, ArrayRef<int64_t>{1, 0});
-    if (transposeLHS) {
-      rewriter.replaceOpWithNewOp<linalg::MatmulTransposeAOp>(
-          matmulOp, matmulOp.getResultTypes(),
-          ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
-          matmulOp.getOutputs());
-    } else {
-      rewriter.replaceOpWithNewOp<linalg::MatmulTransposeBOp>(
-          matmulOp, matmulOp.getResultTypes(),
-          ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
-          matmulOp.getOutputs());
+    if (failed(transposeMatmul(rewriter, op, transposeLHS))) {
+      return failure();
     }
-
     return success();
   }
 
@@ -74,57 +136,16 @@ struct TransposeMatmul final : public OpRewritePattern<linalg::MatmulOp> {
   bool transposeLHS;
 };
 
-/// Pattern to replace
-///
-///   linalg.batch_matmul(a, b)
-///
-/// with
-///
-///   linalg.batch_matmul_transpose_a(linalg.transpose(a), b)
-///
-/// Only the non-batch dimensions are transposed. By default the LHS is
-/// transposed. Set `transposeLHS=false` to transpose RHS instead.
 struct TransposeBatchMatmul final
     : public OpRewritePattern<linalg::BatchMatmulOp> {
   TransposeBatchMatmul(MLIRContext *ctx, bool transposeLHS)
       : OpRewritePattern(ctx), transposeLHS(transposeLHS) {}
 
-  LogicalResult matchAndRewrite(linalg::BatchMatmulOp batchMatmulOp,
+  LogicalResult matchAndRewrite(linalg::BatchMatmulOp op,
                                 PatternRewriter &rewriter) const override {
-    if (!bufferization::hasTensorSemantics(batchMatmulOp))
-      return rewriter.notifyMatchFailure(
-          batchMatmulOp, "only matmul ops with tensors are supported");
-
-    Location loc = batchMatmulOp.getLoc();
-    Value input = batchMatmulOp.getInputs()[transposeLHS ? 0 : 1];
-    auto type = cast<ShapedType>(input.getType());
-
-    SmallVector<Value> dynamicDims;
-    if (type.isDynamicDim(0))
-      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
-    if (type.isDynamicDim(2))
-      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 2));
-    if (type.isDynamicDim(1))
-      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
-
-    ArrayRef<int64_t> shape = type.getShape();
-    Value empty = rewriter.create<tensor::EmptyOp>(
-        loc, ArrayRef<int64_t>{shape[0], shape[2], shape[1]},
-        type.getElementType(), dynamicDims);
-    auto transposeOp = rewriter.create<linalg::TransposeOp>(
-        loc, input, empty, ArrayRef<int64_t>{0, 2, 1});
-    if (transposeLHS) {
-      rewriter.replaceOpWithNewOp<linalg::BatchMatmulTransposeAOp>(
-          batchMatmulOp, batchMatmulOp.getResultTypes(),
-          ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
-          batchMatmulOp.getOutputs());
-    } else {
-      rewriter.replaceOpWithNewOp<linalg::BatchMatmulTransposeBOp>(
-          batchMatmulOp, batchMatmulOp.getResultTypes(),
-          ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
-          batchMatmulOp.getOutputs());
+    if (failed(transposeBatchMatmul(rewriter, op, transposeLHS))) {
+      return failure();
     }
-
     return success();
   }
 
diff --git a/mlir/test/Dialect/Linalg/transpose-matmul-a.mlir b/mlir/test/Dialect/Linalg/transpose-matmul-a.mlir
index 1d2460f5467a5d..b1f33cfa56327e 100644
--- a/mlir/test/Dialect/Linalg/transpose-matmul-a.mlir
+++ b/mlir/test/Dialect/Linalg/transpose-matmul-a.mlir
@@ -2,10 +2,9 @@
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %matmul = transform.structured.match ops{["linalg.matmul", "linalg.batch_matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.transpose_matmul %matmul : (!transform.any_op) -> (!transform.any_op)
     %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    transform.apply_patterns to %0 {
-      transform.apply_patterns.linalg.transpose_matmul
-    } : !transform.any_op
     transform.apply_cse to %0 : !transform.any_op
     transform.apply_patterns to %0 {
       transform.apply_patterns.canonicalization
diff --git a/mlir/test/Dialect/Linalg/transpose-matmul-b.mlir b/mlir/test/Dialect/Linalg/transpose-matmul-b.mlir
index eecd76f1ecca7d..41e64c04dc6e59 100644
--- a/mlir/test/Dialect/Linalg/transpose-matmul-b.mlir
+++ b/mlir/test/Dialect/Linalg/transpose-matmul-b.mlir
@@ -2,10 +2,9 @@
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %matmul = transform.structured.match ops{["linalg.matmul", "linalg.batch_matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.transpose_matmul %matmul <rhs> : (!transform.any_op) -> (!transform.any_op)
     %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    transform.apply_patterns to %0 {
-      transform.apply_patterns.linalg.transpose_matmul <rhs>
-    } : !transform.any_op
     transform.apply_cse to %0 : !transform.any_op
     transform.apply_patterns to %0 {
       transform.apply_patterns.canonicalization

``````````

</details>


https://github.com/llvm/llvm-project/pull/89717


More information about the Mlir-commits mailing list