[Mlir-commits] [mlir] [MLIR][Linalg] Introduce SpecializeOp (PR #70326)

lorenzo chelini llvmlistbot at llvm.org
Tue Oct 31 00:47:19 PDT 2023


https://github.com/chelini updated https://github.com/llvm/llvm-project/pull/70326

>From eb537ffc34fde8be26d666eaff2f8792f3e0eefa Mon Sep 17 00:00:00 2001
From: Lorenzo Chelini <l.chelini at icloud.com>
Date: Wed, 25 Oct 2023 15:26:42 +0200
Subject: [PATCH] [MLIR][Linalg] Introduce SpecializeOp

Introduce an operation to specialize linalg.generics, for example,
detecting a linalg.generic that is semantically equivalent to a
linalg.copy and replacing the former with the latter. After code
generation, it is helpful to lower named operations to vendor-optimized
libraries.
---
 .../mlir/Dialect/Linalg/IR/LinalgInterfaces.h |   3 +
 .../Linalg/TransformOps/LinalgTransformOps.td |  37 +++++
 .../Dialect/Linalg/Transforms/Transforms.h    |   5 +
 .../Dialect/Linalg/IR/LinalgInterfaces.cpp    |  22 +++
 .../TransformOps/LinalgTransformOps.cpp       |  24 +++
 .../Dialect/Linalg/Transforms/CMakeLists.txt  |   1 +
 .../Dialect/Linalg/Transforms/Specialize.cpp  |  32 ++++
 .../Linalg/transform-op-specialize.mlir       | 143 ++++++++++++++++++
 8 files changed, 267 insertions(+)
 create mode 100644 mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
 create mode 100644 mlir/test/Dialect/Linalg/transform-op-specialize.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index f6ba6586a81a244..6c8240267e7d050 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -110,6 +110,9 @@ FailureOr<ConvolutionDimensions> inferConvolutionDims(LinalgOp linalgOp);
 // TODO: embed within `isa<ConvolutionOpInterface>` if possible / natural.
 bool isaConvolutionOpInterface(LinalgOp linalgOp);
 
+/// Checks whether `linalgOp` is semantically equivalent to a `linalg.copy`.
+bool isaCopyOpInterface(LinalgOp linalgOp);
+
 namespace detail {
 
 /// Returns true if the block contains a contraction of the following form:
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 1ff88d036bc036c..9e3f79e64bb1d79 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -390,6 +390,43 @@ def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// SpecializeOp
+//===----------------------------------------------------------------------===//
+
+def SpecializeOp : Op<Transform_Dialect, "structured.specialize",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Transforms a generic operation into the equivalent named form.
+
+    #### Return modes
+
+    This operation ignores non-Linalg ops and drops them in the return. If all
+    the operations referred to by the `target` handle specialize, the transform
+    succeeds; otherwise, the operation produces a silenceable failure.  The return
+    handle points to only the subset of successfully produced equivalent named
+    operations, which can be empty or contain the original ops if they were already
+    in named form. The supported specializations to named Linalg operations are:
+    - linalg.copy of any rank.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+  let assemblyFormat =
+      "$target attr-dict `:` "
+      "custom<SemiFunctionType>(type($target), type($transformed))";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::linalg::LinalgOp target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // InterchangeOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index fbe2923c710aabb..122f73562852101 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -668,6 +668,11 @@ FailureOr<GenericOp> interchangeGenericOp(RewriterBase &rewriter,
 FailureOr<GenericOp> generalizeNamedOp(RewriterBase &rewriter,
                                        LinalgOp namedOp);
 
+/// Create a namedOp from the given GenericOp and replace the GenericOp.
+/// Currently we can specialize only trivial linalg copy operations.
+FailureOr<LinalgOp> specializeGenericOp(RewriterBase &rewriter,
+                                        GenericOp genericOp);
+
 /// Create a new buffer using the `allocationFn` provided. The size of this
 /// buffer is the smallest constant bounding size along each dimension that
 /// can be computed for the size of the result of `subView`. Returns the
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 5fde8d71cac3e75..dfd6b991e7da159 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -32,6 +32,7 @@ using namespace mlir::linalg;
 //===----------------------------------------------------------------------===//
 // Interface utility functions
 //===----------------------------------------------------------------------===//
+
 bool linalg::detail::canOpOperandsBeDroppedImpl(
     linalg::LinalgOp linalgOp, ArrayRef<OpOperand *> droppedOperands) {
   SmallVector<AffineMap> indexingMaps;
@@ -48,6 +49,27 @@ bool linalg::detail::canOpOperandsBeDroppedImpl(
   return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
 }
 
+//===----------------------------------------------------------------------===//
+// CopyOpInterface implementation
+//===----------------------------------------------------------------------===//
+
+bool linalg::isaCopyOpInterface(LinalgOp linalgOp) {
+  // A copy has only parallel loops, exactly one input and exactly one init.
+  if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops() ||
+      linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
+    return false;
+  // Both operands must be accessed through the identity map.
+  auto mapRange = linalgOp.getIndexingMapsArray();
+  if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
+      !mapRange.back().isIdentity())
+    return false;
+  // Region: a lone terminator that yields the input (not `%out` or a capture).
+  Block *body = linalgOp.getBlock();
+  return llvm::hasSingleElement(body->getOperations()) &&
+         body->front().getNumOperands() == 1 &&
+         body->front().getOperand(0) == body->getArgument(0);
+}
+
 //===----------------------------------------------------------------------===//
 // ContractionOpInterface implementation
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 8508507871d0c6c..87be3bb85b6e788 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1018,6 +1018,30 @@ transform::GeneralizeOp::applyToOne(transform::TransformRewriter &rewriter,
   return emitDefaultSilenceableFailure(target);
 }
 
+//===----------------------------------------------------------------------===//
+// SpecializeOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::SpecializeOp::applyToOne(transform::TransformRewriter &rewriter,
+                                    LinalgOp target,
+                                    transform::ApplyToEachResultList &results,
+                                    transform::TransformState &state) {
+  // Anything that is not a linalg.generic is forwarded to the results as is.
+  auto genericOp = dyn_cast<GenericOp>(target.getOperation());
+  if (!genericOp) {
+    results.push_back(target);
+    return DiagnosedSilenceableFailure::success();
+  }
+  rewriter.setInsertionPoint(target);
+  FailureOr<LinalgOp> named = specializeGenericOp(rewriter, genericOp);
+  // No named equivalent was found: report a silenceable failure on the target.
+  if (failed(named))
+    return emitDefaultSilenceableFailure(target);
+  results.push_back(named->getOperation());
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // InterchangeOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index bad246c262979b7..e0a43a29c32d88b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -24,6 +24,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   NamedOpConversions.cpp
   Padding.cpp
   Promotion.cpp
+  Specialize.cpp
   Split.cpp
   SplitReduction.cpp
   SubsetHoisting.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
new file mode 100644
index 000000000000000..4c437b5db2c7b08
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -0,0 +1,32 @@
+//===- Specialize.cpp - linalg generic ops to named ops  ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a method to specialize generic operations to named
+// operations. Conceptually it is the opposite of generalize.cpp.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "linalg-specialization"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
+                                                      GenericOp genericOp) {
+  // Copy-like generics are the only pattern recognized here; fail otherwise.
+  if (!isaCopyOpInterface(genericOp))
+    return failure();
+  Value src = genericOp.getDpsInputs()[0], dst = genericOp.getDpsInits()[0];
+  LinalgOp copy = rewriter.replaceOpWithNewOp<CopyOp>(genericOp, src, dst);
+  return copy;
+}
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
new file mode 100644
index 000000000000000..8a22c115f31170f
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
@@ -0,0 +1,143 @@
+// RUN: mlir-opt --transform-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+#map2 = affine_map<(d0, d1) -> (d1, d0)>
+
+func.func @broadcast_copy_expect_no_match(%arg0: memref<?xf32>, %arg1: memref<?x?xf32>) {
+  // expected-note @below {{when applied to this op}}
+  linalg.generic {
+    indexing_maps = [#map1, #map], 
+    iterator_types = ["parallel", "parallel"]}
+    ins(%arg0 : memref<?xf32>) outs(%arg1 : memref<?x?xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      linalg.yield %in : f32
+  }
+  return
+}
+
+func.func @not_a_copy_expect_no_match(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>) {
+  // expected-note @below {{when applied to this op}}
+  linalg.generic {
+    indexing_maps = [#map, #map], 
+    iterator_types = ["parallel", "parallel"]}
+    ins(%arg0 : memref<?x?xf32>) outs(%arg1 : memref<?x?xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      %0 = arith.addf %in, %out : f32
+      linalg.yield %0 : f32
+  }
+  return
+}
+
+func.func @transpose_op_expect_no_match(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>) {
+  // expected-note @below {{when applied to this op}}
+  linalg.generic {
+    indexing_maps = [#map, #map2], 
+    iterator_types = ["parallel", "parallel"]} 
+    ins(%arg0 : memref<?x?xf32>) outs(%arg1 : memref<?x?xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      linalg.yield %in : f32
+  }
+  return
+}
+
+func.func @copy_with_up_cast(%arg0: memref<?x?xf16>, %arg1: memref<?x?xf32>) {
+  // expected-note @below {{when applied to this op}}
+  linalg.generic {
+    indexing_maps = [#map, #map], 
+    iterator_types = ["parallel", "parallel"]} 
+    ins(%arg0 : memref<?x?xf16>) outs(%arg1 : memref<?x?xf32>) {
+    ^bb0(%in: f16, %out: f32):
+      %0 = arith.extf %in : f16 to f32
+      linalg.yield %0 : f32
+  }
+  return
+}
+
+func.func @copy_with_down_cast(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf16>) {
+  // expected-note @below {{when applied to this op}}
+  linalg.generic {
+    indexing_maps = [#map, #map], 
+    iterator_types = ["parallel", "parallel"]} 
+    ins(%arg0 : memref<?x?xf32>) outs(%arg1 : memref<?x?xf16>) {
+    ^bb0(%in: f32, %out: f16):
+      %0 = arith.truncf %in : f32 to f16
+      linalg.yield %0 : f16
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error @below {{failed to apply}}
+    %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+
+func.func @specialize_trivial_copy_memref(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>) {
+  linalg.generic {
+    indexing_maps = [#map, #map], 
+    iterator_types = ["parallel", "parallel"]} 
+    ins(%arg0 : memref<?x?xf32>) outs(%arg1 : memref<?x?xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      linalg.yield %in : f32
+  }
+  return
+}
+
+// CHECK-LABEL: specialize_trivial_copy_memref
+// CHECK-SAME: %[[ARG0:.+]]: memref<?x?xf32>, %[[ARG1:.+]]: memref<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.copy ins(%[[ARG0]] : memref<?x?xf32>) outs(%[[ARG1]] : memref<?x?xf32>)
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+func.func @specialize_trivial_copy_tensor(%arg0: tensor<?x?x?xf32>, 
+    %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = linalg.generic {
+    indexing_maps = [#map1, #map1], 
+    iterator_types = ["parallel", "parallel", "parallel"]}
+    ins(%arg0 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      linalg.yield %in : f32
+  } -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+
+// CHECK-LABEL: specialize_trivial_copy_tensor
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: %{{.+}} = linalg.copy ins(%[[ARG0]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>)
+
+func.func @already_trivial_copy_memref(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>) {
+  linalg.copy ins(%arg0: memref<?x?xf32>) outs(%arg1: memref<?x?xf32>)
+  return
+}
+
+// CHECK-LABEL: already_trivial_copy_memref
+// CHECK-SAME: %[[ARG0:.+]]: memref<?x?xf32>, %[[ARG1:.+]]: memref<?x?xf32>
+// CHECK: linalg.copy ins(%[[ARG0]] : memref<?x?xf32>) outs(%[[ARG1]] : memref<?x?xf32>)
+
+func.func @already_trivial_copy_tensor(%arg0: tensor<?x?x?xf32>,
+    %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = linalg.copy ins(%arg0: tensor<?x?x?xf32>) outs(%arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+
+// CHECK-LABEL: already_trivial_copy_tensor
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?xf32>
+// CHECK: %{{.+}} = linalg.copy ins(%[[ARG0]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>)
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}



More information about the Mlir-commits mailing list