[Mlir-commits] [mlir] [MLIR][Linalg] Add more specialize patterns (PR #91153)

Javed Absar llvmlistbot at llvm.org
Mon May 20 14:59:28 PDT 2024


https://github.com/javedabsar1 updated https://github.com/llvm/llvm-project/pull/91153

>From 468bdd0a08fa3afb03eb388ba938d5fac9f9f591 Mon Sep 17 00:00:00 2001
From: Javed Absar <javed.absar at gmail.com>
Date: Sat, 4 May 2024 17:31:31 -0400
Subject: [PATCH 1/3] [MLIR][LINALG] Add more specialize patterns

Currently only linalg.copy is recognized when trying to specialize
linalg.generic ops back to named ops. This patch enables recognition
of more generic-to-named-op cases, e.g. linalg.fill and elementwise
unary/binary ops.
---
 .../mlir/Dialect/Linalg/IR/LinalgInterfaces.h | 12 +++
 .../Dialect/Linalg/IR/LinalgInterfaces.cpp    | 99 +++++++++++++++++++
 .../Dialect/Linalg/Transforms/Specialize.cpp  | 47 +++++++++
 .../Linalg/transform-op-specialize.mlir       | 26 ++++-
 ...ansform-op-specialize_elemwise_binary.mlir | 63 ++++++++++++
 ...ransform-op-specialize_elemwise_unary.mlir | 25 +++++
 6 files changed, 271 insertions(+), 1 deletion(-)
 create mode 100644 mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_binary.mlir
 create mode 100644 mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_unary.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index f92843a1dcb98..7a67525c1ba67 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -28,6 +28,7 @@ namespace mlir {
 namespace linalg {
 class IteratorTypeAttr;
 class LinalgOp;
+class GenericOp;
 
 namespace detail {
 /// Implementation of the method that check if given operands
@@ -115,6 +116,17 @@ bool isaConvolutionOpInterface(LinalgOp linalgOp);
 /// Checks whether `linalgOp` is semantically equivalent to a `linalg.copyOp`.
 bool isaCopyOpInterface(LinalgOp linalgOp);
 
+/// Checks whether `genericOp` is semantically equivalent to a linalg
+// elementwise unary op e.g. linalg.exp.
+bool isaElementwiseUnaryOpInterface(GenericOp genericOp);
+
+/// Checks whether `genericOp` is semantically equivalent to a linalg
+// elementwise binary op e.g. linalg.sub.
+bool isaElementwiseBinaryOpInterface(GenericOp genericOp);
+
+/// Checks whether `genericOp` is semantically equivalent to a `linalg.fill`.
+bool isaFillOpInterface(GenericOp genericOp);
+
 namespace detail {
 
 /// Returns true if the block contains a contraction of the following form:
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 3627ff6617eda..e6611e496a4a2 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -70,6 +70,105 @@ bool linalg::isaCopyOpInterface(LinalgOp linalgOp) {
   return llvm::hasSingleElement(linalgOp.getBlock()->getOperations());
 }
 
+//===----------------------------------------------------------------------===//
+// FillOpInterface implementation
+//===----------------------------------------------------------------------===//
+bool linalg::isaFillOpInterface(GenericOp genericOp) {
+  // Structural.
+  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops())
+     return false;
+
+  if (genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1)
+    return false;
+
+  // Input should be referenced and init should not.
+  if (!genericOp.payloadUsesValueFromOperand(genericOp.getDpsInputOperand(0)) ||
+       genericOp.payloadUsesValueFromOperand(genericOp.getDpsInitOperand(0)))
+    return false;
+
+  OpOperand *value = genericOp.getDpsInputOperand(0);
+  if (!genericOp.isScalar(value))
+    return false;
+
+  Block *body = genericOp.getBody();
+  if (body->getOperations().size() != 1)
+    return false;
+
+  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
+  if (!yieldOp || yieldOp.getNumOperands() != 1 ||
+      yieldOp->getOperand(0) != body->getArgument(0))
+    return false;
+  return true;
+}
+
+//===----------------------------------------------------------------------===//
+// Elementwise-Unary/Binary-OpInterface implementation
+//===----------------------------------------------------------------------===//
+static bool isaElementwiseUnaryOrBinaryOpInterface(linalg::GenericOp genericOp,
+                                                   unsigned arity) {
+  // Check all loops are parallel, and have only tensor semantics.
+  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
+      genericOp.getNumLoops() < 1 || !genericOp.hasPureTensorSemantics())
+    return false;
+
+  // Check for exactly `arity` inputs and one init, all with identity maps.
+  if (genericOp.getNumDpsInputs() != arity || genericOp.getNumDpsInits() != 1 ||
+      !llvm::all_of(genericOp.getIndexingMapsArray(),
+                    [](AffineMap map) { return map.isIdentity(); }))
+    return false;
+
+  // Init should not be referenced for elementwise operations.
+  if (genericOp.payloadUsesValueFromOperand(genericOp.getDpsInitOperand(0)))
+    return false;
+
+  // Expect two ops: first one possibly unary/binary op and the second one must
+  // yield the nary-op result.
+  Block *body = genericOp.getBody();
+  if (body->getOperations().size() != 2)
+    return false;
+
+  Operation *op = &body->front();
+  if (op->getNumOperands() != arity || op->getNumResults() != 1)
+    return false;
+
+  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
+  if (!yieldOp || yieldOp.getNumOperands() != 1 ||
+      yieldOp->getOperand(0).getDefiningOp() != op)
+    return false;
+  return true;
+}
+
+bool linalg::isaElementwiseUnaryOpInterface(linalg::GenericOp genericOp) {
+  // All basic elemwise checks.
+  if (!isaElementwiseUnaryOrBinaryOpInterface(genericOp, 1))
+    return false;
+
+  // Check input is actully used.
+  if (!genericOp.payloadUsesValueFromOperand(genericOp.getDpsInputOperand(0)))
+    return false;
+  return true;
+}
+
+bool linalg::isaElementwiseBinaryOpInterface(linalg::GenericOp genericOp) {
+  if (!isaElementwiseUnaryOrBinaryOpInterface(genericOp, 2))
+    return false;
+
+  // Check both inputs are used (elementwise).
+  OpOperand *inputOpOperand0 = genericOp.getDpsInputOperand(0);
+  OpOperand *inputOpOperand1 = genericOp.getDpsInputOperand(1);
+  if (!genericOp.payloadUsesValueFromOperand(inputOpOperand0) ||
+      !genericOp.payloadUsesValueFromOperand(inputOpOperand1))
+    return false;
+
+  // Check that args are not swapped (all elemwise ops are not commutative).
+  Block *body = genericOp.getBody();
+  Operation *op = &body->front();
+  if (op->getOpOperand(0).get() != body->getArgument(0) ||
+      op->getOpOperand(1).get() != body->getArgument(1))
+    return false;
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 // ContractionOpInterface implementation
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 4c437b5db2c7b..d3782287289a7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -12,12 +12,25 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "llvm/Support/Debug.h"
 
 #define DEBUG_TYPE "linalg-specialization"
 
+#define REPLACE_BINARY_OP(NEWOP)                                               \
+  (rewriter.replaceOpWithNewOp<NEWOP>(                                         \
+      genericOp,                                                               \
+      ValueRange{genericOp.getDpsInputs()[0], genericOp.getDpsInputs()[1]},    \
+      ValueRange{genericOp.getDpsInits()[0]}))
+
+#define REPLACE_UNARY_OP(NEWOP)                                                \
+  (rewriter.replaceOpWithNewOp<NEWOP>(                                         \
+      genericOp,                                                               \
+      ValueRange{genericOp.getDpsInputs()[0]},                                 \
+      ValueRange{genericOp.getDpsInits()[0]}))
+
 using namespace mlir;
 using namespace mlir::linalg;
 
@@ -28,5 +41,39 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
         genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
     return namedOp;
   }
+
+  if (isaFillOpInterface(genericOp)) {
+    LinalgOp namedOp = rewriter.replaceOpWithNewOp<FillOp>(
+        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
+    return namedOp;
+  }
+
+  if (isaElementwiseUnaryOpInterface(genericOp)) {
+    Operation *op = &genericOp.getBody()->front();
+    if (isa<math::ExpOp>(op)) {
+      LinalgOp namedOp = REPLACE_UNARY_OP(ExpOp);
+      return namedOp;
+    }
+  }
+
+  if (isaElementwiseBinaryOpInterface(genericOp)) {
+    Operation *op = &genericOp.getBody()->front();
+    if (isa<arith::AddFOp>(op)) {
+      LinalgOp namedOp = REPLACE_BINARY_OP(AddOp);
+      return namedOp;
+    }
+    if (isa<arith::SubFOp>(op)) {
+      LinalgOp namedOp = REPLACE_BINARY_OP(SubOp);
+      return namedOp;
+    }
+    if (isa<arith::MulFOp>(op)) {
+      LinalgOp namedOp = REPLACE_BINARY_OP(MulOp);
+      return namedOp;
+    }
+    if (isa<arith::DivFOp>(op)) {
+      LinalgOp namedOp = REPLACE_BINARY_OP(DivOp);
+      return namedOp;
+    }
+  }
   return failure();
 }
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
index 8a22c115f3117..21dd1fb56789f 100644
--- a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
@@ -3,7 +3,6 @@
 #map = affine_map<(d0, d1) -> (d0, d1)>
 #map1 = affine_map<(d0, d1) -> (d0)>
 #map2 = affine_map<(d0, d1) -> (d1, d0)>
-
 func.func @broadcast_copy_expect_no_match(%arg0: memref<?xf32>, %arg1: memref<?x?xf32>) {
   // expected-note @below {{when applied to this op}}
   linalg.generic {
@@ -141,3 +140,28 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+#map = affine_map<(d0, d1) -> ()>
+#map1 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @linalg_generic_fill(%arg0: tensor<7x7xf32>) -> tensor<7x7xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%cst : f32) outs(%arg0 : tensor<7x7xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<7x7xf32>
+  return %0 : tensor<7x7xf32>
+}
+// CHECK-LABEL: linalg_generic_fill
+// CHECK-SAME: %[[ARG0:.+]]: tensor<7x7xf32>) -> tensor<7x7xf32>
+// CHECK:  %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %{{.*}} = linalg.fill ins(%[[CST]] : f32) outs(%[[ARG0]] : tensor<7x7xf32>) -> tensor<7x7xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_binary.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_binary.mlir
new file mode 100644
index 0000000000000..7bd3b1a1a4a4c
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_binary.mlir
@@ -0,0 +1,63 @@
+// RUN: mlir-opt --transform-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @specialize_add(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.addf %in, %in_0 : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: specialize_add
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>,  %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.add ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+func.func @specialize_sub(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.subf %in, %in_0 : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: specialize_sub
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>,  %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.sub ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+func.func @specialize_mul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.mulf %in, %in_0 : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: specialize_mul
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>,  %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.mul ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+func.func @specialize_div(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.divf %in, %in_0 : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: specialize_div
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>,  %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.div ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_unary.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_unary.mlir
new file mode 100644
index 0000000000000..89a8baa453e90
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_unary.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt --transform-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s
+
+#umap = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func.func @specialize_exp(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = linalg.generic
+          {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+          ins(%arg0 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %1 = math.exp %in : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: specialize_exp
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.exp ins(%[[ARG0]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}

>From 27bc71ce07c0b582b181443c0c3b7850105c55fa Mon Sep 17 00:00:00 2001
From: Javed Absar <javed.absar at gmail.com>
Date: Sun, 5 May 2024 19:30:23 -0400
Subject: [PATCH 2/3] [MLIR][LINALG] Fix formatting error

---
 mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index e6611e496a4a2..34093a2215322 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -76,14 +76,14 @@ bool linalg::isaCopyOpInterface(LinalgOp linalgOp) {
 bool linalg::isaFillOpInterface(GenericOp genericOp) {
   // Structural.
   if (genericOp.getNumParallelLoops() != genericOp.getNumLoops())
-     return false;
+    return false;
 
   if (genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1)
     return false;
 
   // Input should be referenced and init should not.
   if (!genericOp.payloadUsesValueFromOperand(genericOp.getDpsInputOperand(0)) ||
-       genericOp.payloadUsesValueFromOperand(genericOp.getDpsInitOperand(0)))
+      genericOp.payloadUsesValueFromOperand(genericOp.getDpsInitOperand(0)))
     return false;
 
   OpOperand *value = genericOp.getDpsInputOperand(0);

>From d4aa1ec1298ac56e4703cc0dd55db6eb9222c38c Mon Sep 17 00:00:00 2001
From: Javed Absar <javed.absar at gmail.com>
Date: Mon, 20 May 2024 16:26:54 -0400
Subject: [PATCH 3/3] [MLIR][Linalg] Address review comments.

---
 .../mlir/Dialect/Linalg/IR/LinalgInterfaces.h | 18 ++++---
 .../Dialect/Linalg/IR/LinalgInterfaces.cpp    | 48 ++++++++---------
 .../Dialect/Linalg/Transforms/Specialize.cpp  | 51 ++++++++++++++-----
 .../Linalg/transform-op-specialize.mlir       |  1 +
 ...ansform-op-specialize_elemwise_binary.mlir | 13 +++++
 5 files changed, 84 insertions(+), 47 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index 7a67525c1ba67..08afdf373f014 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -116,16 +116,20 @@ bool isaConvolutionOpInterface(LinalgOp linalgOp);
 /// Checks whether `linalgOp` is semantically equivalent to a `linalg.copyOp`.
 bool isaCopyOpInterface(LinalgOp linalgOp);
 
-/// Checks whether `genericOp` is semantically equivalent to a linalg
-// elementwise unary op e.g. linalg.exp.
-bool isaElementwiseUnaryOpInterface(GenericOp genericOp);
+/// Checks whether a given `genericOp` is semantically equivalent to a single
+/// linalg elementwise unary op, e.g. linalg.exp.
+/// A linalg.generic body could be a series of unary elementwise ops, e.g.
+/// `exp(neg(x))`, such as formed by linalg op fusion. Here we restrict it to
+/// detecting cases where the body is a single computation op.
+bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp);
 
-/// Checks whether `genericOp` is semantically equivalent to a linalg
-// elementwise binary op e.g. linalg.sub.
-bool isaElementwiseBinaryOpInterface(GenericOp genericOp);
+/// Checks whether `genericOp` is semantically equivalent to a single linalg
+/// elementwise binary op, e.g. linalg.sub.
+bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp);
 
 /// Checks whether `genericOp` is semantically equivalent to a `linalg.fill`.
-bool isaFillOpInterface(GenericOp genericOp);
+/// Returns the scalar fill value on a match, std::nullopt otherwise.
+std::optional<Value> isaFillOpInterface(GenericOp genericOp);
 
 namespace detail {
 
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 34093a2215322..f35ab3b856b4e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -73,39 +73,38 @@ bool linalg::isaCopyOpInterface(LinalgOp linalgOp) {
 //===----------------------------------------------------------------------===//
 // FillOpInterface implementation
 //===----------------------------------------------------------------------===//
-bool linalg::isaFillOpInterface(GenericOp genericOp) {
+std::optional<Value> linalg::isaFillOpInterface(GenericOp genericOp) {
   // Structural.
-  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops())
-    return false;
-
-  if (genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1)
-    return false;
+  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
+      genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1)
+    return std::nullopt;
 
   // Input should be referenced and init should not.
   if (!genericOp.payloadUsesValueFromOperand(genericOp.getDpsInputOperand(0)) ||
       genericOp.payloadUsesValueFromOperand(genericOp.getDpsInitOperand(0)))
-    return false;
+    return std::nullopt;
 
   OpOperand *value = genericOp.getDpsInputOperand(0);
   if (!genericOp.isScalar(value))
-    return false;
+    return std::nullopt;
 
   Block *body = genericOp.getBody();
   if (body->getOperations().size() != 1)
-    return false;
+    return std::nullopt;
 
   auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
   if (!yieldOp || yieldOp.getNumOperands() != 1 ||
       yieldOp->getOperand(0) != body->getArgument(0))
-    return false;
-  return true;
+    return std::nullopt;
+  return value->get();
 }
 
 //===----------------------------------------------------------------------===//
-// Elementwise-Unary/Binary-OpInterface implementation
+// Elementwise Single Unary/Binary-OpInterface implementation
 //===----------------------------------------------------------------------===//
-static bool isaElementwiseUnaryOrBinaryOpInterface(linalg::GenericOp genericOp,
-                                                   unsigned arity) {
+static bool
+isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp genericOp,
+                                          unsigned arity) {
   // Check all loops are parallel, and have only tensor semantics.
   if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
       genericOp.getNumLoops() < 1 || !genericOp.hasPureTensorSemantics())
@@ -121,8 +120,10 @@ static bool isaElementwiseUnaryOrBinaryOpInterface(linalg::GenericOp genericOp,
   if (genericOp.payloadUsesValueFromOperand(genericOp.getDpsInitOperand(0)))
     return false;
 
-  // Expect two ops: first one possibly unary/binary op and the second one must
-  // yield the nary-op result.
+  // A linalg.generic could be a series of elementwise ops, e.g. exp(neg(x)),
+  // such as those resulting from producer-consumer fusion. Here, we restrict
+  // to two ops in the body, where the first is the single elementwise op and
+  // the second a yield.
   Block *body = genericOp.getBody();
   if (body->getOperations().size() != 2)
     return false;
@@ -138,9 +139,9 @@ static bool isaElementwiseUnaryOrBinaryOpInterface(linalg::GenericOp genericOp,
   return true;
 }
 
-bool linalg::isaElementwiseUnaryOpInterface(linalg::GenericOp genericOp) {
+bool linalg::isaElemwiseSingleUnaryOpInterface(linalg::GenericOp genericOp) {
   // All basic elemwise checks.
-  if (!isaElementwiseUnaryOrBinaryOpInterface(genericOp, 1))
+  if (!isaElemwiseSingleUnaryOrBinaryOpInterface(genericOp, 1))
     return false;
 
   // Check input is actully used.
@@ -149,8 +150,8 @@ bool linalg::isaElementwiseUnaryOpInterface(linalg::GenericOp genericOp) {
   return true;
 }
 
-bool linalg::isaElementwiseBinaryOpInterface(linalg::GenericOp genericOp) {
-  if (!isaElementwiseUnaryOrBinaryOpInterface(genericOp, 2))
+bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp genericOp) {
+  if (!isaElemwiseSingleUnaryOrBinaryOpInterface(genericOp, 2))
     return false;
 
   // Check both inputs are used (elementwise).
@@ -159,13 +160,6 @@ bool linalg::isaElementwiseBinaryOpInterface(linalg::GenericOp genericOp) {
   if (!genericOp.payloadUsesValueFromOperand(inputOpOperand0) ||
       !genericOp.payloadUsesValueFromOperand(inputOpOperand1))
     return false;
-
-  // Check that args are not swapped (all elemwise ops are not commutative).
-  Block *body = genericOp.getBody();
-  Operation *op = &body->front();
-  if (op->getOpOperand(0).get() != body->getArgument(0) ||
-      op->getOpOperand(1).get() != body->getArgument(1))
-    return false;
   return true;
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index d3782287289a7..2bc4d7fbfadcc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -12,28 +12,52 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Math/IR/Math.h"
 #include "llvm/Support/Debug.h"
 
 #define DEBUG_TYPE "linalg-specialization"
 
-#define REPLACE_BINARY_OP(NEWOP)                                               \
+#define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP)                                \
   (rewriter.replaceOpWithNewOp<NEWOP>(                                         \
       genericOp,                                                               \
-      ValueRange{genericOp.getDpsInputs()[0], genericOp.getDpsInputs()[1]},    \
+      ValueRange{genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 1 : 0],            \
+                 genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 0 : 1]},           \
       ValueRange{genericOp.getDpsInits()[0]}))
 
 #define REPLACE_UNARY_OP(NEWOP)                                                \
-  (rewriter.replaceOpWithNewOp<NEWOP>(                                         \
-      genericOp,                                                               \
-      ValueRange{genericOp.getDpsInputs()[0]},                                 \
-      ValueRange{genericOp.getDpsInits()[0]}))
+  (rewriter.replaceOpWithNewOp<NEWOP>(genericOp,                               \
+                                      ValueRange{genericOp.getDpsInputs()[0]}, \
+                                      ValueRange{genericOp.getDpsInits()[0]}))
 
 using namespace mlir;
 using namespace mlir::linalg;
 
+// Given an elementwise single binary linalg generic op, checks whether the
+// binary op uses the block arguments in swapped order, e.g.
+// this differentiates between a linalg.generic body that contains:
+//    ^bb0(%a: f32, %b: f32, %c : f32):
+//         %0 = arith.subf %a, %b : f32
+//         linalg.yield %0: f32
+// against:
+//    ^bb0(%a: f32, %b: f32, %c : f32):
+//         %0 = arith.subf %b, %a : f32
+//         linalg.yield %0: f32
+// Former is linalg.sub(a,b), latter is linalg.sub(b,a).
+static bool areBinOpsSwapped(GenericOp genericOp) {
+  Block *body = genericOp.getBody();
+  Operation *op = &body->front();
+  bool swapped = false;
+  if (op->getOpOperand(0).get() != body->getArgument(0)) {
+    swapped = true;
+    assert(op->getOpOperand(0).get() == body->getArgument(1) &&
+           op->getOpOperand(1).get() == body->getArgument(0) &&
+           "binary op uses just one block arg");
+  }
+  return swapped;
+}
+
 FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
                                                       GenericOp genericOp) {
   if (isaCopyOpInterface(genericOp)) {
@@ -48,7 +72,7 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
     return namedOp;
   }
 
-  if (isaElementwiseUnaryOpInterface(genericOp)) {
+  if (isaElemwiseSingleUnaryOpInterface(genericOp)) {
     Operation *op = &genericOp.getBody()->front();
     if (isa<math::ExpOp>(op)) {
       LinalgOp namedOp = REPLACE_UNARY_OP(ExpOp);
@@ -56,22 +80,23 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
     }
   }
 
-  if (isaElementwiseBinaryOpInterface(genericOp)) {
+  if (isaElemwiseSingleBinaryOpInterface(genericOp)) {
+    bool swap = areBinOpsSwapped(genericOp);
     Operation *op = &genericOp.getBody()->front();
     if (isa<arith::AddFOp>(op)) {
-      LinalgOp namedOp = REPLACE_BINARY_OP(AddOp);
+      LinalgOp namedOp = REPLACE_BINARY_OP(AddOp, swap);
       return namedOp;
     }
     if (isa<arith::SubFOp>(op)) {
-      LinalgOp namedOp = REPLACE_BINARY_OP(SubOp);
+      LinalgOp namedOp = REPLACE_BINARY_OP(SubOp, swap);
       return namedOp;
     }
     if (isa<arith::MulFOp>(op)) {
-      LinalgOp namedOp = REPLACE_BINARY_OP(MulOp);
+      LinalgOp namedOp = REPLACE_BINARY_OP(MulOp, swap);
       return namedOp;
     }
     if (isa<arith::DivFOp>(op)) {
-      LinalgOp namedOp = REPLACE_BINARY_OP(DivOp);
+      LinalgOp namedOp = REPLACE_BINARY_OP(DivOp, swap);
       return namedOp;
     }
   }
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
index 21dd1fb56789f..35679db7412f3 100644
--- a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
@@ -3,6 +3,7 @@
 #map = affine_map<(d0, d1) -> (d0, d1)>
 #map1 = affine_map<(d0, d1) -> (d0)>
 #map2 = affine_map<(d0, d1) -> (d1, d0)>
+
 func.func @broadcast_copy_expect_no_match(%arg0: memref<?xf32>, %arg1: memref<?x?xf32>) {
   // expected-note @below {{when applied to this op}}
   linalg.generic {
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_binary.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_binary.mlir
index 7bd3b1a1a4a4c..d45025de931cd 100644
--- a/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_binary.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_binary.mlir
@@ -27,6 +27,19 @@ func.func @specialize_sub(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2:
 // CHECK-NOT: linalg.generic
 // CHECK: linalg.sub ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
 
+func.func @specialize_sub_swapped_operands(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.subf %in_0, %in : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: specialize_sub_swapped_operands
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>,  %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.sub ins(%[[ARG1]], %[[ARG0]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
 func.func @specialize_mul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
   %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
   ^bb0(%in: f32, %in_0: f32, %out: f32):



More information about the Mlir-commits mailing list