[Mlir-commits] [mlir] [MLIR][Linalg] Introduce SpecializeOp (PR #70326)

lorenzo chelini llvmlistbot at llvm.org
Thu Oct 26 06:18:01 PDT 2023


https://github.com/chelini created https://github.com/llvm/llvm-project/pull/70326

Introduce a transform operation that specializes linalg.generic ops into named ops, for example by detecting a linalg.generic that is semantically equivalent to a linalg.copy and replacing the former with the latter. This is helpful after code generation, when named operations can be lowered to vendor-optimized libraries.

>From 3b609f352f84f76d3e8b3c9478d8f9e55ae5c2b9 Mon Sep 17 00:00:00 2001
From: Lorenzo Chelini <l.chelini at icloud.com>
Date: Wed, 25 Oct 2023 15:26:42 +0200
Subject: [PATCH] [MLIR][Linalg] Introduce SpecializeOp

Introduce an operation to specialize linalg.generics, for example,
detecting a linalg.generic that is semantically equivalent to a
linalg.copy and replacing the former with the latter. After code
generation, it is helpful to lower named operations to vendor-optimized
libraries.
---
 .../Linalg/TransformOps/LinalgTransformOps.td | 36 +++++++++
 .../Dialect/Linalg/Transforms/Transforms.h    |  5 ++
 .../TransformOps/LinalgTransformOps.cpp       | 24 ++++++
 .../Dialect/Linalg/Transforms/CMakeLists.txt  |  1 +
 .../Dialect/Linalg/Transforms/Specialize.cpp  | 52 +++++++++++++
 .../Linalg/transform-op-specialize.mlir       | 77 +++++++++++++++++++
 6 files changed, 195 insertions(+)
 create mode 100644 mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
 create mode 100644 mlir/test/Dialect/Linalg/transform-op-specialize.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 1ff88d036bc036c..2d86d443a28ebbb 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -390,6 +390,42 @@ def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// SpecializeOp
+//===----------------------------------------------------------------------===//
+
+def SpecializeOp : Op<Transform_Dialect, "structured.specialize",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Transforms a generic operation into the equivalent named form.
+
+    #### Return modes
+
+    This operation ignores non-Linalg ops and drops them in the return.
+    If all the operations referred to by the `target` handle specialize
+    properly, the transform succeeds. Otherwise it produces a silenceable
+    failure. The return handle points to only the subset of successfully
+    produced equivalent named operations, which can be empty or contain the
+    original ops if they were already in named form.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+  let assemblyFormat =
+      "$target attr-dict `:` "
+      "custom<SemiFunctionType>(type($target), type($transformed))";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::linalg::LinalgOp target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // InterchangeOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index fbe2923c710aabb..122f73562852101 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -668,6 +668,11 @@ FailureOr<GenericOp> interchangeGenericOp(RewriterBase &rewriter,
 FailureOr<GenericOp> generalizeNamedOp(RewriterBase &rewriter,
                                        LinalgOp namedOp);
 
+/// Replace `genericOp` with an equivalent named op, or return failure if no
+/// named form matches. Currently only trivial linalg.copy ops are recognized.
+FailureOr<LinalgOp> specializeGenericOp(RewriterBase &rewriter,
+                                        GenericOp genericOp);
+
 /// Create a new buffer using the `allocationFn` provided. The size of this
 /// buffer is the smallest constant bounding size along each dimension that
 /// can be computed for the size of the result of `subView`. Returns the
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 8508507871d0c6c..87be3bb85b6e788 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1018,6 +1018,30 @@ transform::GeneralizeOp::applyToOne(transform::TransformRewriter &rewriter,
   return emitDefaultSilenceableFailure(target);
 }
 
+//===----------------------------------------------------------------------===//
+// SpecializeOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::SpecializeOp::applyToOne(transform::TransformRewriter &rewriter,
+                                    LinalgOp target,
+                                    transform::ApplyToEachResultList &results,
+                                    transform::TransformState &state) {
+  // Non-generic Linalg ops are forwarded to the result handle untouched.
+  auto genericOp = dyn_cast<GenericOp>(target.getOperation());
+  if (!genericOp) {
+    results.push_back(target);
+    return DiagnosedSilenceableFailure::success();
+  }
+  rewriter.setInsertionPoint(target);
+  FailureOr<LinalgOp> maybeNamed = specializeGenericOp(rewriter, genericOp);
+  // No named equivalent was found: report a silenceable failure on `target`.
+  if (failed(maybeNamed))
+    return emitDefaultSilenceableFailure(target);
+  results.push_back(maybeNamed->getOperation());
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // InterchangeOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 4e094609afa6a03..5ec1fd5dd7e91db 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -24,6 +24,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   NamedOpConversions.cpp
   Padding.cpp
   Promotion.cpp
+  Specialize.cpp
   Split.cpp
   SplitReduction.cpp
   SubsetHoisting.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
new file mode 100644
index 000000000000000..6c7be63069dad1d
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -0,0 +1,52 @@
+//===- Specialize.cpp - linalg generic ops to named ops  ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a method to specialize generic operations to named
+// operations. Conceptually it is the opposite of Generalization.cpp.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "linalg-specialization"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+/// Return true if `genericOp` is an all-parallel generic with one input and
+/// one init, identity indexing maps, and a body that only yields its input.
+static bool isaCopyOp(GenericOp genericOp) {
+  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops())
+    return false;
+  if (genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1)
+    return false;
+  auto mapRange = genericOp.getIndexingMapsArray();
+  if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
+      !mapRange.back().isIdentity())
+    return false;
+  Region &reg = genericOp.getRegion();
+  if (!llvm::hasSingleElement(reg))
+    return false;
+  // The lone op must yield the input argument, not the init or a capture.
+  Block &body = reg.front();
+  return llvm::hasSingleElement(body.getOperations()) &&
+         body.back().getNumOperands() == 1 &&
+         body.back().getOperand(0) == body.getArgument(0);
+}
+
+FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
+                                                      GenericOp genericOp) {
+  // A trivial copy is the only pattern recognized so far.
+  if (!isaCopyOp(genericOp))
+    return failure();
+  LinalgOp copyOp = rewriter.replaceOpWithNewOp<CopyOp>(
+      genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
+  return copyOp;
+}
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
new file mode 100644
index 000000000000000..a125d2dc3ca29e6
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
@@ -0,0 +1,77 @@
+// RUN: mlir-opt --transform-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+#map2 = affine_map<(d0, d1) -> (d1, d0)>
+
+func.func @broadcast_copy_expect_no_match(%arg0: memref<?xf32>, %arg1: memref<?x?xf32>) {
+  // expected-note @below {{when applied to this op}}
+  linalg.generic {
+    indexing_maps = [#map1, #map], 
+    iterator_types = ["parallel", "parallel"]}
+    ins(%arg0 : memref<?xf32>) outs(%arg1 : memref<?x?xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      linalg.yield %in : f32
+  }
+  return
+}
+
+func.func @not_a_copy_expect_no_match(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>) {
+  // expected-note @below {{when applied to this op}}
+  linalg.generic {
+    indexing_maps = [#map, #map], 
+    iterator_types = ["parallel", "parallel"]}
+    ins(%arg0 : memref<?x?xf32>) outs(%arg1 : memref<?x?xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      %0 = arith.addf %in, %out : f32
+      linalg.yield %0 : f32
+  }
+  return
+}
+
+func.func @transpose_op_expect_no_match(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>) {
+  // expected-note @below {{when applied to this op}}
+  linalg.generic {
+    indexing_maps = [#map, #map2], 
+    iterator_types = ["parallel", "parallel"]} 
+    ins(%arg0 : memref<?x?xf32>) outs(%arg1 : memref<?x?xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      linalg.yield %in : f32
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error @below {{failed to apply}}
+    %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+
+// CHECK-LABEL: specialize_trivial_copy
+func.func @specialize_trivial_copy(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>) {
+  // CHECK: linalg.copy
+  // CHECK-NOT: linalg.generic
+  linalg.generic {
+    indexing_maps = [#map, #map],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%arg0 : memref<?x?xf32>) outs(%arg1 : memref<?x?xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      linalg.yield %in : f32
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}



More information about the Mlir-commits mailing list