[Mlir-commits] [mlir] 33b7833 - [MLIR][Linalg] Add more specialize patterns (#91153)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon May 20 15:10:55 PDT 2024


Author: Javed Absar
Date: 2024-05-20T23:10:51+01:00
New Revision: 33b7833891dcacf9e81e911ed59932fd55113fff

URL: https://github.com/llvm/llvm-project/commit/33b7833891dcacf9e81e911ed59932fd55113fff
DIFF: https://github.com/llvm/llvm-project/commit/33b7833891dcacf9e81e911ed59932fd55113fff.diff

LOG: [MLIR][Linalg] Add more specialize patterns (#91153)

Currently only linalg.copy is recognized when trying to specialize
linalg.generics back to named op. This diff enables recognition of more
generic to named op e.g. linalg.fill, elemwise unary/binary.

Added: 
    mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_binary.mlir
    mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_unary.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
    mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
    mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
    mlir/test/Dialect/Linalg/transform-op-specialize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index f92843a1dcb98..08afdf373f014 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -28,6 +28,7 @@ namespace mlir {
 namespace linalg {
 class IteratorTypeAttr;
 class LinalgOp;
+class GenericOp;
 
 namespace detail {
 /// Implementation of the method that check if given operands
@@ -115,6 +116,21 @@ bool isaConvolutionOpInterface(LinalgOp linalgOp);
 /// Checks whether `linalgOp` is semantically equivalent to a `linalg.copyOp`.
 bool isaCopyOpInterface(LinalgOp linalgOp);
 
+/// Checks whether a given `genericOp` is semantically equivalent to a single
+/// linalg elementwise unary op. e.g. linalg.exp.
+/// A linalg.generic body could be a series of unary elementwise ops e.g.
+/// `exp(neg(x))`, such as formed by linalg op fusion. Here we restrict it to
+/// detecting cases where body is a single computation op.
+bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp);
+
+/// Checks whether `genericOp` is semantically equivalent to a single linalg
+/// elementwise binary op e.g. linalg.sub.
+bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp);
+
+/// Checks whether `genericOp` is semantically equivalent to a `linalg.fill`.
+/// Returns the scalar fill value if true.
+std::optional<Value> isaFillOpInterface(GenericOp genericOp);
+
 namespace detail {
 
 /// Returns true if the block contains a contraction of the following form:

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 3627ff6617eda..f35ab3b856b4e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -70,6 +70,99 @@ bool linalg::isaCopyOpInterface(LinalgOp linalgOp) {
   return llvm::hasSingleElement(linalgOp.getBlock()->getOperations());
 }
 
+//===----------------------------------------------------------------------===//
+// FillOpInterface implementation
+//===----------------------------------------------------------------------===//
+std::optional<Value> linalg::isaFillOpInterface(GenericOp genericOp) {
+  // Structural: all loops are parallel, with exactly one input and one init.
+  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
+      genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1)
+    return std::nullopt;
+
+  // The payload must read the input and must not read the init.
+  if (!genericOp.payloadUsesValueFromOperand(genericOp.getDpsInputOperand(0)) ||
+      genericOp.payloadUsesValueFromOperand(genericOp.getDpsInitOperand(0)))
+    return std::nullopt;
+
+  OpOperand *value = genericOp.getDpsInputOperand(0);
+  if (!genericOp.isScalar(value)) // The fill value must be a scalar operand.
+    return std::nullopt;
+
+  Block *body = genericOp.getBody();
+  if (body->getOperations().size() != 1) // Body must hold a single op.
+    return std::nullopt;
+
+  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
+  if (!yieldOp || yieldOp.getNumOperands() != 1 ||
+      yieldOp->getOperand(0) != body->getArgument(0)) // Must yield the input.
+    return std::nullopt;
+  return value->get(); // The scalar fill value.
+}
+
+//===----------------------------------------------------------------------===//
+// Elementwise Single Unary/Binary-OpInterface implementation
+//===----------------------------------------------------------------------===//
+static bool
+isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp genericOp,
+                                          unsigned arity) {
+  // Check >= 1 loop, all loops are parallel, and only tensor semantics.
+  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
+      genericOp.getNumLoops() < 1 || !genericOp.hasPureTensorSemantics())
+    return false;
+
+  // Check there are `arity` inputs, one init, and only identity indexing maps.
+  if (genericOp.getNumDpsInputs() != arity || genericOp.getNumDpsInits() != 1 ||
+      !llvm::all_of(genericOp.getIndexingMapsArray(),
+                    [](AffineMap map) { return map.isIdentity(); }))
+    return false;
+
+  // Init should not be referenced for elementwise operations.
+  if (genericOp.payloadUsesValueFromOperand(genericOp.getDpsInitOperand(0)))
+    return false;
+
+  // A linalg.generic could be a series of elementwise ops, e.g. exp(neg(x)),
+  // such as resulting from producer-consumer fusion. Here, we restrict to two
+  // ops in the body, where the first is the elementwise single op and the
+  // second a yield.
+  Block *body = genericOp.getBody();
+  if (body->getOperations().size() != 2)
+    return false;
+
+  Operation *op = &body->front();
+  if (op->getNumOperands() != arity || op->getNumResults() != 1)
+    return false;
+
+  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
+  if (!yieldOp || yieldOp.getNumOperands() != 1 ||
+      yieldOp->getOperand(0).getDefiningOp() != op) // Must yield op's result.
+    return false;
+  return true;
+}
+
+bool linalg::isaElemwiseSingleUnaryOpInterface(linalg::GenericOp genericOp) {
+  // All basic elemwise checks, with exactly one input.
+  if (!isaElemwiseSingleUnaryOrBinaryOpInterface(genericOp, 1))
+    return false;
+
+  // Check the input is actually used by the payload.
+  if (!genericOp.payloadUsesValueFromOperand(genericOp.getDpsInputOperand(0)))
+    return false;
+  return true;
+}
+
+bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp genericOp) {
+  if (!isaElemwiseSingleUnaryOrBinaryOpInterface(genericOp, 2)) // Two inputs.
+    return false;
+
+  // Check both inputs are used by the payload (elementwise).
+  OpOperand *inputOpOperand0 = genericOp.getDpsInputOperand(0);
+  OpOperand *inputOpOperand1 = genericOp.getDpsInputOperand(1);
+  if (!genericOp.payloadUsesValueFromOperand(inputOpOperand0) ||
+      !genericOp.payloadUsesValueFromOperand(inputOpOperand1))
+    return false;
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 // ContractionOpInterface implementation
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 4c437b5db2c7b..2bc4d7fbfadcc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -14,13 +14,50 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Math/IR/Math.h"
 #include "llvm/Support/Debug.h"
 
 #define DEBUG_TYPE "linalg-specialization"
 
+#define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP)                                \
+  (rewriter.replaceOpWithNewOp<NEWOP>(                                         \
+      genericOp,                                                               \
+      ValueRange{genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 1 : 0],            \
+                 genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 0 : 1]},           \
+      ValueRange{genericOp.getDpsInits()[0]})) // Uses `rewriter`/`genericOp` from the expansion site; inputs are reversed when OPERANDS_SWAP is true.
+
+#define REPLACE_UNARY_OP(NEWOP)                                                \
+  (rewriter.replaceOpWithNewOp<NEWOP>(genericOp,                               \
+                                      ValueRange{genericOp.getDpsInputs()[0]}, \
+                                      ValueRange{genericOp.getDpsInits()[0]})) // Uses `rewriter`/`genericOp` from the expansion site; one input, one init.
+
 using namespace mlir;
 using namespace mlir::linalg;
 
+// Given an elementwise single binary linalg generic op, checks whether the
+// binary op accesses operands as swapped. e.g.
+// this differentiates between a linalg-generic body that contains:
+//    ^bb0(%a: f32, %b: f32, %c : f32):
+//         %0 = arith.subf %a, %b : f32
+//         linalg.yield %0: f32
+// against:
+//    ^bb0(%a: f32, %b: f32, %c : f32):
+//         %0 = arith.subf %b, %a : f32
+//         linalg.yield %0: f32
+// Former is linalg.sub(a,b), latter is linalg.sub(b,a).
+static bool areBinOpsSwapped(GenericOp genericOp) {
+  Block *body = genericOp.getBody();
+  Operation *op = &body->front(); // First op in the payload body.
+  bool swapped = false;
+  if (op->getOpOperand(0).get() != body->getArgument(0)) {
+    swapped = true; // Operand 0 is not the first block argument.
+    assert(op->getOpOperand(0).get() == body->getArgument(1) &&
+           op->getOpOperand(1).get() == body->getArgument(0) &&
+           "binary op uses just one block arg"); // Checked only when asserts are enabled.
+  }
+  return swapped;
+}
+
 FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
                                                       GenericOp genericOp) {
   if (isaCopyOpInterface(genericOp)) {
@@ -28,5 +65,40 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
         genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
     return namedOp;
   }
+
+  if (isaFillOpInterface(genericOp)) {
+    LinalgOp namedOp = rewriter.replaceOpWithNewOp<FillOp>(
+        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
+    return namedOp;
+  }
+
+  if (isaElemwiseSingleUnaryOpInterface(genericOp)) {
+    Operation *op = &genericOp.getBody()->front();
+    if (isa<math::ExpOp>(op)) {
+      LinalgOp namedOp = REPLACE_UNARY_OP(ExpOp);
+      return namedOp;
+    }
+  }
+
+  if (isaElemwiseSingleBinaryOpInterface(genericOp)) {
+    bool swap = areBinOpsSwapped(genericOp);
+    Operation *op = &genericOp.getBody()->front();
+    if (isa<arith::AddFOp>(op)) {
+      LinalgOp namedOp = REPLACE_BINARY_OP(AddOp, swap);
+      return namedOp;
+    }
+    if (isa<arith::SubFOp>(op)) {
+      LinalgOp namedOp = REPLACE_BINARY_OP(SubOp, swap);
+      return namedOp;
+    }
+    if (isa<arith::MulFOp>(op)) {
+      LinalgOp namedOp = REPLACE_BINARY_OP(MulOp, swap);
+      return namedOp;
+    }
+    if (isa<arith::DivFOp>(op)) {
+      LinalgOp namedOp = REPLACE_BINARY_OP(DivOp, swap);
+      return namedOp;
+    }
+  }
   return failure();
 }

diff  --git a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
index 8a22c115f3117..35679db7412f3 100644
--- a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
@@ -141,3 +141,28 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+#map = affine_map<(d0, d1) -> ()>
+#map1 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @linalg_generic_fill(%arg0: tensor<7x7xf32>) -> tensor<7x7xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%cst : f32) outs(%arg0 : tensor<7x7xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<7x7xf32>
+  return %0 : tensor<7x7xf32>
+}
+// CHECK-LABEL: linalg_generic_fill
+// CHECK-SAME: %[[ARG0:.+]]: tensor<7x7xf32>) -> tensor<7x7xf32>
+// CHECK:  %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %{{.*}} = linalg.fill ins(%[[CST]] : f32) outs(%[[ARG0]] : tensor<7x7xf32>) -> tensor<7x7xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}

diff  --git a/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_binary.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_binary.mlir
new file mode 100644
index 0000000000000..d45025de931cd
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_binary.mlir
@@ -0,0 +1,76 @@
+// RUN: mlir-opt --transform-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @specialize_add(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.addf %in, %in_0 : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: specialize_add
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>,  %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.add ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+func.func @specialize_sub(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.subf %in, %in_0 : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: specialize_sub
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>,  %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.sub ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// The payload computes %in_0 - %in, so the named op must take its inputs in swapped order.
+func.func @specialize_sub_swapped_operands(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.subf %in_0, %in : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: specialize_sub_swapped_operands
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>,  %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.sub ins(%[[ARG1]], %[[ARG0]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+func.func @specialize_mul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.mulf %in, %in_0 : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: specialize_mul
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>,  %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.mul ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+func.func @specialize_div(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.divf %in, %in_0 : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: specialize_div
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>,  %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.div ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}

diff  --git a/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_unary.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_unary.mlir
new file mode 100644
index 0000000000000..89a8baa453e90
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_unary.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt --transform-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s
+
+#umap = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func.func @specialize_exp(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = linalg.generic
+          {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+          ins(%arg0 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %1 = math.exp %in : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: specialize_exp
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.exp ins(%[[ARG0]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}


        


More information about the Mlir-commits mailing list