[Mlir-commits] [mlir] [MLIR][LINALG] Add more specialize patterns (PR #91153)

Javed Absar llvmlistbot at llvm.org
Sun May 5 16:39:06 PDT 2024


https://github.com/javedabsar1 updated https://github.com/llvm/llvm-project/pull/91153

>From 468bdd0a08fa3afb03eb388ba938d5fac9f9f591 Mon Sep 17 00:00:00 2001
From: Javed Absar <javed.absar at gmail.com>
Date: Sat, 4 May 2024 17:31:31 -0400
Subject: [PATCH 1/2] [MLIR][LINALG] Add more specialize patterns

Currently only linalg.copy is recognized when trying to specialize
linalg.generic ops back to named ops. This diff enables recognition
of more generic-to-named-op conversions, e.g. linalg.fill and
elementwise unary/binary ops.
---
 .../mlir/Dialect/Linalg/IR/LinalgInterfaces.h | 12 +++
 .../Dialect/Linalg/IR/LinalgInterfaces.cpp    | 99 +++++++++++++++++++
 .../Dialect/Linalg/Transforms/Specialize.cpp  | 47 +++++++++
 .../Linalg/transform-op-specialize.mlir       | 26 ++++-
 ...ansform-op-specialize_elemwise_binary.mlir | 63 ++++++++++++
 ...ransform-op-specialize_elemwise_unary.mlir | 25 +++++
 6 files changed, 271 insertions(+), 1 deletion(-)
 create mode 100644 mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_binary.mlir
 create mode 100644 mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_unary.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index f92843a1dcb987..7a67525c1ba674 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -28,6 +28,7 @@ namespace mlir {
 namespace linalg {
 class IteratorTypeAttr;
 class LinalgOp;
+class GenericOp;
 
 namespace detail {
 /// Implementation of the method that check if given operands
@@ -115,6 +116,17 @@ bool isaConvolutionOpInterface(LinalgOp linalgOp);
 /// Checks whether `linalgOp` is semantically equivalent to a `linalg.copyOp`.
 bool isaCopyOpInterface(LinalgOp linalgOp);
 
+/// Checks whether `genericOp` is semantically equivalent to a linalg
+/// elementwise unary op, e.g. `linalg.exp`.
+bool isaElementwiseUnaryOpInterface(GenericOp genericOp);
+
+/// Checks whether `genericOp` is semantically equivalent to a linalg
+/// elementwise binary op, e.g. `linalg.sub`.
+bool isaElementwiseBinaryOpInterface(GenericOp genericOp);
+
+/// Checks whether `genericOp` is semantically equivalent to a `linalg.fill`.
+bool isaFillOpInterface(GenericOp genericOp);
+
 namespace detail {
 
 /// Returns true if the block contains a contraction of the following form:
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 3627ff6617eda3..e6611e496a4a2e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -70,6 +70,105 @@ bool linalg::isaCopyOpInterface(LinalgOp linalgOp) {
   return llvm::hasSingleElement(linalgOp.getBlock()->getOperations());
 }
 
+//===----------------------------------------------------------------------===//
+// FillOpInterface implementation
+//===----------------------------------------------------------------------===//
+bool linalg::isaFillOpInterface(GenericOp genericOp) {
+  // Structural.
+  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops())
+     return false;
+
+  if (genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1)
+    return false;
+
+  // Input should be referenced and init should not.
+  if (!genericOp.payloadUsesValueFromOperand(genericOp.getDpsInputOperand(0)) ||
+       genericOp.payloadUsesValueFromOperand(genericOp.getDpsInitOperand(0)))
+    return false;
+
+  OpOperand *value = genericOp.getDpsInputOperand(0);
+  if (!genericOp.isScalar(value))
+    return false;
+  // The body must be nothing but a yield that forwards the scalar input.
+  Block *body = genericOp.getBody();
+  if (body->getOperations().size() != 1)
+    return false;
+
+  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
+  if (!yieldOp || yieldOp.getNumOperands() != 1 ||
+      yieldOp->getOperand(0) != body->getArgument(0))
+    return false;
+  return true;
+}
+
+//===----------------------------------------------------------------------===//
+// Elementwise-Unary/Binary-OpInterface implementation
+//===----------------------------------------------------------------------===//
+static bool isaElementwiseUnaryOrBinaryOpInterface(linalg::GenericOp genericOp,
+                                                   unsigned arity) {
+  // Check all loops are parallel, and have only tensor semantics.
+  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
+      genericOp.getNumLoops() < 1 || !genericOp.hasPureTensorSemantics())
+    return false;
+
+  // Check there are arity-inputs, 1-output and all are identity-maps.
+  if (genericOp.getNumDpsInputs() != arity || genericOp.getNumDpsInits() != 1 ||
+      !llvm::all_of(genericOp.getIndexingMapsArray(),
+                    [](AffineMap map) { return map.isIdentity(); }))
+    return false;
+
+  // Init should not be referenced for elementwise operations.
+  if (genericOp.payloadUsesValueFromOperand(genericOp.getDpsInitOperand(0)))
+    return false;
+
+  // Expect exactly two ops: the unary/binary op itself, followed by a yield
+  // of that op's result.
+  Block *body = genericOp.getBody();
+  if (body->getOperations().size() != 2)
+    return false;
+
+  Operation *op = &body->front();
+  if (op->getNumOperands() != arity || op->getNumResults() != 1)
+    return false;
+
+  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
+  if (!yieldOp || yieldOp.getNumOperands() != 1 ||
+      yieldOp->getOperand(0).getDefiningOp() != op)
+    return false;
+  return true;
+}
+
+bool linalg::isaElementwiseUnaryOpInterface(linalg::GenericOp genericOp) {
+  // All basic elemwise checks.
+  if (!isaElementwiseUnaryOrBinaryOpInterface(genericOp, 1))
+    return false;
+
+  // Check the input is actually used.
+  if (!genericOp.payloadUsesValueFromOperand(genericOp.getDpsInputOperand(0)))
+    return false;
+  return true;
+}
+
+bool linalg::isaElementwiseBinaryOpInterface(linalg::GenericOp genericOp) {
+  if (!isaElementwiseUnaryOrBinaryOpInterface(genericOp, 2))
+    return false;
+
+  // Check both inputs are used (elementwise).
+  OpOperand *inputOpOperand0 = genericOp.getDpsInputOperand(0);
+  OpOperand *inputOpOperand1 = genericOp.getDpsInputOperand(1);
+  if (!genericOp.payloadUsesValueFromOperand(inputOpOperand0) ||
+      !genericOp.payloadUsesValueFromOperand(inputOpOperand1))
+    return false;
+
+  // Check that args are not swapped (not all elementwise ops are commutative).
+  Block *body = genericOp.getBody();
+  Operation *op = &body->front();
+  if (op->getOpOperand(0).get() != body->getArgument(0) ||
+      op->getOpOperand(1).get() != body->getArgument(1))
+    return false;
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 // ContractionOpInterface implementation
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 4c437b5db2c7b0..d3782287289a7b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -12,12 +12,25 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Math/IR/Math.h"
 #include "llvm/Support/Debug.h"
 
 #define DEBUG_TYPE "linalg-specialization"
 
+/// Replaces `genericOp` with a new `NamedOpTy` that takes the generic's DPS
+/// inputs and inits unchanged. Callers must already have checked that the
+/// operand counts match the named op (e.g. one input for unary ops).
+template <typename NamedOpTy>
+static mlir::linalg::LinalgOp
+replaceWithNamedOp(mlir::RewriterBase &rewriter,
+                   mlir::linalg::GenericOp genericOp) {
+  return rewriter.replaceOpWithNewOp<NamedOpTy>(
+      genericOp, mlir::ValueRange(genericOp.getDpsInputs()),
+      mlir::ValueRange(genericOp.getDpsInits()));
+}
+
 using namespace mlir;
 using namespace mlir::linalg;
 
@@ -28,5 +41,39 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
         genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
     return namedOp;
   }
+
+  if (isaFillOpInterface(genericOp)) {
+    LinalgOp namedOp = rewriter.replaceOpWithNewOp<FillOp>(
+        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
+    return namedOp;
+  }
+
+  if (isaElementwiseUnaryOpInterface(genericOp)) {
+    Operation *op = &genericOp.getBody()->front();
+    if (isa<math::ExpOp>(op)) {
+      LinalgOp namedOp = replaceWithNamedOp<ExpOp>(rewriter, genericOp);
+      return namedOp;
+    }
+  }
+
+  if (isaElementwiseBinaryOpInterface(genericOp)) {
+    Operation *op = &genericOp.getBody()->front();
+    if (isa<arith::AddFOp>(op)) {
+      LinalgOp namedOp = replaceWithNamedOp<AddOp>(rewriter, genericOp);
+      return namedOp;
+    }
+    if (isa<arith::SubFOp>(op)) {
+      LinalgOp namedOp = replaceWithNamedOp<SubOp>(rewriter, genericOp);
+      return namedOp;
+    }
+    if (isa<arith::MulFOp>(op)) {
+      LinalgOp namedOp = replaceWithNamedOp<MulOp>(rewriter, genericOp);
+      return namedOp;
+    }
+    if (isa<arith::DivFOp>(op)) {
+      LinalgOp namedOp = replaceWithNamedOp<DivOp>(rewriter, genericOp);
+      return namedOp;
+    }
+  }
   return failure();
 }
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
index 8a22c115f31170..21dd1fb56789f2 100644
--- a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
@@ -3,7 +3,6 @@
 #map = affine_map<(d0, d1) -> (d0, d1)>
 #map1 = affine_map<(d0, d1) -> (d0)>
 #map2 = affine_map<(d0, d1) -> (d1, d0)>
-
 func.func @broadcast_copy_expect_no_match(%arg0: memref<?xf32>, %arg1: memref<?x?xf32>) {
   // expected-note @below {{when applied to this op}}
   linalg.generic {
@@ -141,3 +140,28 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+// A generic that yields a scalar input into every output element should become linalg.fill.
+#map = affine_map<(d0, d1) -> ()>
+#map1 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @linalg_generic_fill(%arg0: tensor<7x7xf32>) -> tensor<7x7xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%cst : f32) outs(%arg0 : tensor<7x7xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<7x7xf32>
+  return %0 : tensor<7x7xf32>
+}
+// CHECK-LABEL: linalg_generic_fill
+// CHECK-SAME: %[[ARG0:.+]]: tensor<7x7xf32>) -> tensor<7x7xf32>
+// CHECK:  %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %{{.*}} = linalg.fill ins(%[[CST]] : f32) outs(%[[ARG0]] : tensor<7x7xf32>) -> tensor<7x7xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_binary.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_binary.mlir
new file mode 100644
index 00000000000000..7bd3b1a1a4a4ca
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_binary.mlir
@@ -0,0 +1,63 @@
+// RUN: mlir-opt --transform-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s
+// Elementwise generics wrapping arith.{addf,subf,mulf,divf} should become linalg.{add,sub,mul,div}.
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @specialize_add(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.addf %in, %in_0 : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: specialize_add
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>,  %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.add ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+func.func @specialize_sub(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.subf %in, %in_0 : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: specialize_sub
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>,  %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.sub ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+func.func @specialize_mul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.mulf %in, %in_0 : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: specialize_mul
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>,  %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.mul ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+func.func @specialize_div(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.divf %in, %in_0 : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: specialize_div
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>,  %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.div ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// Transform script: match every LinalgOp in the payload and specialize it.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_unary.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_unary.mlir
new file mode 100644
index 00000000000000..89a8baa453e905
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_unary.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt --transform-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s
+// An elementwise generic wrapping math.exp should become linalg.exp.
+#umap = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func.func @specialize_exp(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = linalg.generic
+          {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+          ins(%arg0 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %1 = math.exp %in : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: specialize_exp
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.exp ins(%[[ARG0]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}

>From 27bc71ce07c0b582b181443c0c3b7850105c55fa Mon Sep 17 00:00:00 2001
From: Javed Absar <javed.absar at gmail.com>
Date: Sun, 5 May 2024 19:30:23 -0400
Subject: [PATCH 2/2] [MLIR][LINALG] Fix formatting error

---
 mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index e6611e496a4a2e..34093a22153221 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -76,14 +76,14 @@ bool linalg::isaCopyOpInterface(LinalgOp linalgOp) {
 bool linalg::isaFillOpInterface(GenericOp genericOp) {
   // Structural.
   if (genericOp.getNumParallelLoops() != genericOp.getNumLoops())
-     return false;
+    return false;
 
   if (genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1)
     return false;
 
   // Input should be referenced and init should not.
   if (!genericOp.payloadUsesValueFromOperand(genericOp.getDpsInputOperand(0)) ||
-       genericOp.payloadUsesValueFromOperand(genericOp.getDpsInitOperand(0)))
+      genericOp.payloadUsesValueFromOperand(genericOp.getDpsInitOperand(0)))
     return false;
 
   OpOperand *value = genericOp.getDpsInputOperand(0);



More information about the Mlir-commits mailing list