[Mlir-commits] [mlir] [MLIR][Linalg] Introduce SpecializeOp (PR #70326)
lorenzo chelini
llvmlistbot at llvm.org
Tue Oct 31 00:47:19 PDT 2023
https://github.com/chelini updated https://github.com/llvm/llvm-project/pull/70326
>From eb537ffc34fde8be26d666eaff2f8792f3e0eefa Mon Sep 17 00:00:00 2001
From: Lorenzo Chelini <l.chelini at icloud.com>
Date: Wed, 25 Oct 2023 15:26:42 +0200
Subject: [PATCH] [MLIR][Linalg] Introduce SpecializeOp
Introduce an operation to specialize linalg.generic ops: for example,
detect a linalg.generic that is semantically equivalent to a
linalg.copy and replace the former with the latter. This is helpful
after code generation, because named operations can then be lowered
to vendor-optimized libraries.
---
.../mlir/Dialect/Linalg/IR/LinalgInterfaces.h | 3 +
.../Linalg/TransformOps/LinalgTransformOps.td | 37 +++++
.../Dialect/Linalg/Transforms/Transforms.h | 5 +
.../Dialect/Linalg/IR/LinalgInterfaces.cpp | 22 +++
.../TransformOps/LinalgTransformOps.cpp | 24 +++
.../Dialect/Linalg/Transforms/CMakeLists.txt | 1 +
.../Dialect/Linalg/Transforms/Specialize.cpp | 32 ++++
.../Linalg/transform-op-specialize.mlir | 143 ++++++++++++++++++
8 files changed, 267 insertions(+)
create mode 100644 mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
create mode 100644 mlir/test/Dialect/Linalg/transform-op-specialize.mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index f6ba6586a81a244..6c8240267e7d050 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -110,6 +110,9 @@ FailureOr<ConvolutionDimensions> inferConvolutionDims(LinalgOp linalgOp);
// TODO: embed within `isa<ConvolutionOpInterface>` if possible / natural.
bool isaConvolutionOpInterface(LinalgOp linalgOp);
+/// Checks whether `linalgOp` is semantically equivalent to a `linalg.copy`.
+bool isaCopyOpInterface(LinalgOp linalgOp);
+
namespace detail {
/// Returns true if the block contains a contraction of the following form:
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 1ff88d036bc036c..9e3f79e64bb1d79 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -390,6 +390,43 @@ def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",
}];
}
+//===----------------------------------------------------------------------===//
+// SpecializeOp
+//===----------------------------------------------------------------------===//
+
+def SpecializeOp : Op<Transform_Dialect, "structured.specialize",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Transforms a generic operation into the equivalent named form.
+
+    #### Return modes
+
+    This operation ignores non-Linalg ops and drops them in the return. If all
+    the Linalg ops referred to by the `target` handle are specialized or already
+    named, the transform succeeds; otherwise, it produces a silenceable failure.
+    The return handle points to only the subset of successfully produced named
+    operations, which can be empty or contain the original ops if they were
+    already in named form. The supported named Linalg specializations are:
+    - linalg.copy of any rank.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+  let assemblyFormat =
+      "$target attr-dict `:` "
+      "custom<SemiFunctionType>(type($target), type($transformed))";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::linalg::LinalgOp target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
//===----------------------------------------------------------------------===//
// InterchangeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index fbe2923c710aabb..122f73562852101 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -668,6 +668,11 @@ FailureOr<GenericOp> interchangeGenericOp(RewriterBase &rewriter,
FailureOr<GenericOp> generalizeNamedOp(RewriterBase &rewriter,
LinalgOp namedOp);
+/// Replace `genericOp` with an equivalent named op and return it. Only
+/// trivial copies (`linalg.copy`) are supported; otherwise returns failure.
+FailureOr<LinalgOp> specializeGenericOp(RewriterBase &rewriter,
+ GenericOp genericOp);
+
/// Create a new buffer using the `allocationFn` provided. The size of this
/// buffer is the smallest constant bounding size along each dimension that
/// can be computed for the size of the result of `subView`. Returns the
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 5fde8d71cac3e75..dfd6b991e7da159 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -32,6 +32,7 @@ using namespace mlir::linalg;
//===----------------------------------------------------------------------===//
// Interface utility functions
//===----------------------------------------------------------------------===//
+
bool linalg::detail::canOpOperandsBeDroppedImpl(
linalg::LinalgOp linalgOp, ArrayRef<OpOperand *> droppedOperands) {
SmallVector<AffineMap> indexingMaps;
@@ -48,6 +49,27 @@ bool linalg::detail::canOpOperandsBeDroppedImpl(
return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
}
+//===----------------------------------------------------------------------===//
+// CopyOpInterface implementation
+//===----------------------------------------------------------------------===//
+
+bool linalg::isaCopyOpInterface(LinalgOp linalgOp) {
+  // Structural: every loop must be parallel.
+  if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
+    return false;
+  // Operands and maps: one input, one init, both with identity maps.
+  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
+    return false;
+  auto mapRange = linalgOp.getIndexingMapsArray();
+  if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
+      !mapRange.back().isIdentity())
+    return false;
+  // Region: a lone yield of the input argument (yielding the init is a no-op).
+  Block *body = linalgOp.getBlock();
+  return llvm::hasSingleElement(body->getOperations()) &&
+         body->getTerminator()->getOperand(0) == body->getArgument(0);
+}
+
//===----------------------------------------------------------------------===//
// ContractionOpInterface implementation
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 8508507871d0c6c..87be3bb85b6e788 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1018,6 +1018,30 @@ transform::GeneralizeOp::applyToOne(transform::TransformRewriter &rewriter,
return emitDefaultSilenceableFailure(target);
}
+//===----------------------------------------------------------------------===//
+// SpecializeOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::SpecializeOp::applyToOne(transform::TransformRewriter &rewriter,
+                                    LinalgOp target,
+                                    transform::ApplyToEachResultList &results,
+                                    transform::TransformState &state) {
+  // Named (non-generic) Linalg ops are already specialized: forward them as-is.
+  auto genericOp = dyn_cast<GenericOp>(target.getOperation());
+  if (!genericOp) {
+    results.push_back(target);
+    return DiagnosedSilenceableFailure::success();
+  }
+  rewriter.setInsertionPoint(genericOp);
+  FailureOr<LinalgOp> named = specializeGenericOp(rewriter, genericOp);
+  if (failed(named))
+    return emitDefaultSilenceableFailure(target);
+  // Expose the newly created named op through the result handle.
+  results.push_back(named->getOperation());
+  return DiagnosedSilenceableFailure::success();
+}
+
//===----------------------------------------------------------------------===//
// InterchangeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index bad246c262979b7..e0a43a29c32d88b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -24,6 +24,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
NamedOpConversions.cpp
Padding.cpp
Promotion.cpp
+ Specialize.cpp
Split.cpp
SplitReduction.cpp
SubsetHoisting.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
new file mode 100644
index 000000000000000..4c437b5db2c7b08
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -0,0 +1,32 @@
+//===- Specialize.cpp - linalg generic ops to named ops ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a method to specialize generic operations to named
+// operations. Conceptually, it is the inverse of `generalizeNamedOp`.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "linalg-specialization"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
+                                                      GenericOp genericOp) {
+  // Only trivial copies can be specialized for now; reject everything else.
+  if (!isaCopyOpInterface(genericOp))
+    return failure();
+  Value src = genericOp.getDpsInputs()[0];
+  Value dst = genericOp.getDpsInits()[0];
+  return LinalgOp(rewriter.replaceOpWithNewOp<CopyOp>(genericOp, src, dst));
+}
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
new file mode 100644
index 000000000000000..8a22c115f31170f
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
@@ -0,0 +1,143 @@
+// RUN: mlir-opt --transform-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+#map2 = affine_map<(d0, d1) -> (d1, d0)>
+// Negative cases, none is a plain copy. Broadcast: input map not identity.
+func.func @broadcast_copy_expect_no_match(%arg0: memref<?xf32>, %arg1: memref<?x?xf32>) {
+ // expected-note @below {{when applied to this op}}
+ linalg.generic {
+ indexing_maps = [#map1, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0 : memref<?xf32>) outs(%arg1 : memref<?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ linalg.yield %in : f32
+ }
+ return
+}
+// The body adds %in and %out instead of just forwarding the input.
+func.func @not_a_copy_expect_no_match(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>) {
+ // expected-note @below {{when applied to this op}}
+ linalg.generic {
+ indexing_maps = [#map, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0 : memref<?x?xf32>) outs(%arg1 : memref<?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %0 = arith.addf %in, %out : f32
+ linalg.yield %0 : f32
+ }
+ return
+}
+// The output map (#map2) swaps the dimensions: a transpose, not a copy.
+func.func @transpose_op_expect_no_match(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>) {
+ // expected-note @below {{when applied to this op}}
+ linalg.generic {
+ indexing_maps = [#map, #map2],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0 : memref<?x?xf32>) outs(%arg1 : memref<?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ linalg.yield %in : f32
+ }
+ return
+}
+// The body widens f16 to f32 (arith.extf), so it is not a plain copy.
+func.func @copy_with_up_cast(%arg0: memref<?x?xf16>, %arg1: memref<?x?xf32>) {
+ // expected-note @below {{when applied to this op}}
+ linalg.generic {
+ indexing_maps = [#map, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0 : memref<?x?xf16>) outs(%arg1 : memref<?x?xf32>) {
+ ^bb0(%in: f16, %out: f32):
+ %0 = arith.extf %in : f16 to f32
+ linalg.yield %0 : f32
+ }
+ return
+}
+// The body narrows f32 to f16 (arith.truncf), so it is not a plain copy.
+func.func @copy_with_down_cast(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf16>) {
+ // expected-note @below {{when applied to this op}}
+ linalg.generic {
+ indexing_maps = [#map, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0 : memref<?x?xf32>) outs(%arg1 : memref<?x?xf16>) {
+ ^bb0(%in: f32, %out: f16):
+ %0 = arith.truncf %in : f32 to f16
+ linalg.yield %0 : f16
+ }
+ return
+}
+// Every generic above fails to specialize, so the specialize op reports an error.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error @below {{failed to apply}}
+ %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+// Positive cases: generics that only forward their input become linalg.copy.
+func.func @specialize_trivial_copy_memref(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>) {
+ linalg.generic {
+ indexing_maps = [#map, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0 : memref<?x?xf32>) outs(%arg1 : memref<?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ linalg.yield %in : f32
+ }
+ return
+}
+
+// CHECK-LABEL: specialize_trivial_copy_memref
+// CHECK-SAME: %[[ARG0:.+]]: memref<?x?xf32>, %[[ARG1:.+]]: memref<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.copy ins(%[[ARG0]] : memref<?x?xf32>) outs(%[[ARG1]] : memref<?x?xf32>)
+// Same pattern on rank-3 tensors; the copy produces the returned value.
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+func.func @specialize_trivial_copy_tensor(%arg0: tensor<?x?x?xf32>,
+ %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+ %0 = linalg.generic {
+ indexing_maps = [#map1, #map1],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%arg0 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ linalg.yield %in : f32
+ } -> tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
+}
+
+// CHECK-LABEL: specialize_trivial_copy_tensor
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: %{{.+}} = linalg.copy ins(%[[ARG0]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>)
+// Ops already in named form are forwarded unchanged by the transform.
+func.func @already_trivial_copy_memref(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>) {
+ linalg.copy ins(%arg0: memref<?x?xf32>) outs(%arg1: memref<?x?xf32>)
+ return
+}
+
+// CHECK-LABEL: already_trivial_copy_memref
+// CHECK-SAME: %[[ARG0:.+]]: memref<?x?xf32>, %[[ARG1:.+]]: memref<?x?xf32>
+// CHECK: linalg.copy ins(%[[ARG0]] : memref<?x?xf32>) outs(%[[ARG1]] : memref<?x?xf32>)
+
+func.func @already_trivial_copy_tensor(%arg0: tensor<?x?x?xf32>,
+ %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+ %0 = linalg.copy ins(%arg0: tensor<?x?x?xf32>) outs(%arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
+}
+
+// CHECK-LABEL: already_trivial_copy_tensor
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?xf32>
+// CHECK: %{{.+}} = linalg.copy ins(%[[ARG0]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>)
+// No failure is expected here: every Linalg op above is specialized or kept.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
More information about the Mlir-commits
mailing list