[Mlir-commits] [mlir] [MLIR][Linalg] Introduce SpecializeOp (PR #70326)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Oct 26 06:19:06 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: lorenzo chelini (chelini)
<details>
<summary>Changes</summary>
Introduce a transform operation that specializes `linalg.generic` ops into named ops. For example, it detects a `linalg.generic` that is semantically equivalent to a `linalg.copy` and replaces the former with the latter. This is useful after code generation, because named operations can then be lowered to vendor-optimized library calls.
---
Full diff: https://github.com/llvm/llvm-project/pull/70326.diff
6 Files Affected:
- (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+36)
- (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+5)
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+24)
- (modified) mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt (+1)
- (added) mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp (+52)
- (added) mlir/test/Dialect/Linalg/transform-op-specialize.mlir (+77)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 1ff88d036bc036c..2d86d443a28ebbb 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -390,6 +390,42 @@ def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",
}];
}
+//===----------------------------------------------------------------------===//
+// SpecializeOp
+//===----------------------------------------------------------------------===//
+
+def SpecializeOp : Op<Transform_Dialect, "structured.specialize",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Transforms a generic operation into the equivalent named form.
+
+    Linalg operations that are not `linalg.generic` are already in named form
+    and are forwarded unchanged. The only specialization currently supported
+    rewrites a trivial copy into `linalg.copy`. A trivial copy is a
+    `linalg.generic` with:
+    - all loops parallel;
+    - exactly one input and one init, both with identity indexing maps;
+    - a body that only yields the input element.
+
+    #### Return modes
+
+    This operation ignores non-Linalg ops and drops them in the return.
+    If all the operations referred to by the `target` handle specialize
+    properly, the transform succeeds. Otherwise, the transform produces a
+    silenceable failure. The return handle points to only the subset of
+    successfully produced equivalent named operations, which can be empty or
+    contain the original ops if they were already in named form.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+  let assemblyFormat =
+    "$target attr-dict `:` "
+    "custom<SemiFunctionType>(type($target), type($transformed))";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::linalg::LinalgOp target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
//===----------------------------------------------------------------------===//
// InterchangeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index fbe2923c710aabb..122f73562852101 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -668,6 +668,11 @@ FailureOr<GenericOp> interchangeGenericOp(RewriterBase &rewriter,
FailureOr<GenericOp> generalizeNamedOp(RewriterBase &rewriter,
LinalgOp namedOp);
+/// Create a named op equivalent to `genericOp` and replace `genericOp` with it
+/// through `rewriter`.
+///
+/// Currently only trivial copies are specialized, into `linalg.copy`. A trivial
+/// copy has all loops parallel, exactly one input and one init, identity
+/// indexing maps on both, and a body holding nothing but its terminator.
+///
+/// Returns the newly created named op on success. Returns failure, without
+/// modifying the IR, when no specialization applies.
+FailureOr<LinalgOp> specializeGenericOp(RewriterBase &rewriter,
+ GenericOp genericOp);
+
/// Create a new buffer using the `allocationFn` provided. The size of this
/// buffer is the smallest constant bounding size along each dimension that
/// can be computed for the size of the result of `subView`. Returns the
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 8508507871d0c6c..87be3bb85b6e788 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1018,6 +1018,30 @@ transform::GeneralizeOp::applyToOne(transform::TransformRewriter &rewriter,
return emitDefaultSilenceableFailure(target);
}
+//===----------------------------------------------------------------------===//
+// SpecializeOp
+//===----------------------------------------------------------------------===//
+
+/// Specializes a single Linalg op.
+///
+/// Non-generic Linalg ops are already in named form, so they are forwarded to
+/// the result handle unchanged. A generic is passed to `specializeGenericOp`.
+/// On success, the new named op is pushed to the results. Otherwise the
+/// default silenceable failure is emitted for `target`.
+DiagnosedSilenceableFailure
+transform::SpecializeOp::applyToOne(transform::TransformRewriter &rewriter,
+                                    LinalgOp target,
+                                    transform::ApplyToEachResultList &results,
+                                    transform::TransformState &state) {
+  // One dyn_cast replaces the previous isa<> + cast<> pair.
+  auto genericOp = dyn_cast<GenericOp>(target.getOperation());
+  if (!genericOp) {
+    // Already a named op: keep it in the result handle.
+    results.push_back(target);
+    return DiagnosedSilenceableFailure::success();
+  }
+
+  // Build the replacement right where the generic currently sits.
+  rewriter.setInsertionPoint(genericOp);
+  FailureOr<LinalgOp> named = specializeGenericOp(rewriter, genericOp);
+  if (failed(named))
+    return emitDefaultSilenceableFailure(target);
+
+  results.push_back(named->getOperation());
+  return DiagnosedSilenceableFailure::success();
+}
+
//===----------------------------------------------------------------------===//
// InterchangeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 4e094609afa6a03..5ec1fd5dd7e91db 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -24,6 +24,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
NamedOpConversions.cpp
Padding.cpp
Promotion.cpp
+ Specialize.cpp
Split.cpp
SplitReduction.cpp
SubsetHoisting.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
new file mode 100644
index 000000000000000..6c7be63069dad1d
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -0,0 +1,52 @@
+//===- Specialize.cpp - linalg generic ops to named ops ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a method to specialize generic operations to named
+// operations. Conceptually it is the opposite of Generalization.cpp.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "linalg-specialization"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+/// Returns true if `genericOp` is a plain element-wise copy of its single
+/// input into its single init. Such an op can be replaced by `linalg.copy`.
+///
+/// Requirements:
+///  - every loop is parallel (no reductions);
+///  - exactly one input and one init, both with identity indexing maps
+///    (this rules out broadcasts and transposes);
+///  - the body is a single block containing only a `linalg.yield` that
+///    forwards the input element.
+static bool isaCopyOp(GenericOp genericOp) {
+  // Structural.
+  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops())
+    return false;
+
+  // Operands and maps.
+  if (genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1)
+    return false;
+  auto maps = genericOp.getIndexingMapsArray();
+  if (maps.size() != 2 || !maps.front().isIdentity() ||
+      !maps.back().isIdentity())
+    return false;
+
+  // Region: a single block whose only operation is the terminator.
+  Region &body = genericOp.getRegion();
+  if (!llvm::hasSingleElement(body))
+    return false;
+  Block &block = body.front();
+  if (!llvm::hasSingleElement(block.getOperations()))
+    return false;
+
+  // The terminator must forward the *input* element (block argument 0).
+  // Yielding the init argument instead leaves the output unchanged, so that
+  // op is a no-op, not a copy. Rewriting it into linalg.copy would change
+  // semantics.
+  auto yieldOp = dyn_cast<linalg::YieldOp>(block.getTerminator());
+  return yieldOp && yieldOp.getNumOperands() == 1 &&
+         yieldOp.getOperand(0) == block.getArgument(0);
+}
+
+/// Replaces `genericOp` with an equivalent named Linalg op when one is
+/// recognized. Only trivial copies (see isaCopyOp) are handled today; they
+/// become `linalg.copy`. Any other generic yields failure and is left as is.
+FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
+                                                      GenericOp genericOp) {
+  // Bail out early on anything that is not recognized as a copy.
+  if (!isaCopyOp(genericOp))
+    return failure();
+
+  Value source = genericOp.getDpsInputs()[0];
+  Value destination = genericOp.getDpsInits()[0];
+  LinalgOp copy =
+      rewriter.replaceOpWithNewOp<CopyOp>(genericOp, source, destination);
+  return copy;
+}
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
new file mode 100644
index 000000000000000..a125d2dc3ca29e6
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
@@ -0,0 +1,77 @@
+// RUN: mlir-opt --transform-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s
+
+// Negative cases: none of the generics below is a trivial copy, so
+// transform.structured.specialize reports a silenceable failure for them.
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+#map2 = affine_map<(d0, d1) -> (d1, d0)>
+
+func.func @broadcast_copy_expect_no_match(%arg0: memref<?xf32>, %arg1: memref<?x?xf32>) {
+ // The input map (d0, d1) -> (d0) broadcasts; it is not an identity map.
+ // expected-note @below {{when applied to this op}}
+ linalg.generic {
+ indexing_maps = [#map1, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0 : memref<?xf32>) outs(%arg1 : memref<?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ linalg.yield %in : f32
+ }
+ return
+}
+
+func.func @not_a_copy_expect_no_match(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>) {
+ // The body computes arith.addf in addition to the yield, so it is not a copy.
+ // expected-note @below {{when applied to this op}}
+ linalg.generic {
+ indexing_maps = [#map, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0 : memref<?x?xf32>) outs(%arg1 : memref<?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %0 = arith.addf %in, %out : f32
+ linalg.yield %0 : f32
+ }
+ return
+}
+
+func.func @transpose_op_expect_no_match(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>) {
+ // The output map (d0, d1) -> (d1, d0) is a transpose, not an identity map.
+ // expected-note @below {{when applied to this op}}
+ linalg.generic {
+ indexing_maps = [#map, #map2],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0 : memref<?x?xf32>) outs(%arg1 : memref<?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ linalg.yield %in : f32
+ }
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // None of the matched generics can be specialized, so the op fails to apply.
+ // expected-error @below {{failed to apply}}
+ %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+
+// A generic with identity maps whose body only forwards the input element is
+// a plain copy. It must become linalg.copy, with the input as the source and
+// the init as the destination.
+// CHECK-LABEL: func.func @specialize_trivial_copy
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9_]+]]: memref<?x?xf32>, %[[DST:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+func.func @specialize_trivial_copy(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>) {
+  // CHECK: linalg.copy ins(%[[SRC]] : {{.*}}) outs(%[[DST]] : {{.*}})
+  // CHECK-NOT: linalg.generic
+  linalg.generic {
+    indexing_maps = [#map, #map],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%arg0 : memref<?x?xf32>) outs(%arg1 : memref<?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/70326
More information about the Mlir-commits
mailing list