[Mlir-commits] [mlir] 33b7833 - [MLIR][Linalg] Add more specialize patterns (#91153)
llvmlistbot at llvm.org
Mon May 20 15:10:55 PDT 2024
Author: Javed Absar
Date: 2024-05-20T23:10:51+01:00
New Revision: 33b7833891dcacf9e81e911ed59932fd55113fff
URL: https://github.com/llvm/llvm-project/commit/33b7833891dcacf9e81e911ed59932fd55113fff
DIFF: https://github.com/llvm/llvm-project/commit/33b7833891dcacf9e81e911ed59932fd55113fff.diff
LOG: [MLIR][Linalg] Add more specialize patterns (#91153)
Currently, only linalg.copy is recognized when trying to specialize
linalg.generic ops back to named ops. This diff enables recognition of more
generic-to-named-op conversions, e.g. linalg.fill and elementwise unary/binary ops.
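For context, the new queries can be driven roughly as below. This is a
minimal usage sketch, not part of the patch: the walk-based driver
`specializeAllGenerics` is hypothetical, while the `isa*OpInterface` queries
and `specializeGenericOp` are the entry points declared in the diff.

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Hypothetical driver: walk a function and specialize every linalg.generic
// that the new queries recognize back to its named-op form.
static void specializeAllGenerics(func::FuncOp func, RewriterBase &rewriter) {
  func.walk([&](linalg::GenericOp genericOp) {
    // isaFillOpInterface returns the scalar fill value as an optional; the
    // other two queries return bool.
    if (linalg::isaFillOpInterface(genericOp) ||
        linalg::isaElemwiseSingleUnaryOpInterface(genericOp) ||
        linalg::isaElemwiseSingleBinaryOpInterface(genericOp)) {
      rewriter.setInsertionPoint(genericOp);
      (void)linalg::specializeGenericOp(rewriter, genericOp);
    }
  });
}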
Added:
mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_binary.mlir
mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_unary.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
mlir/test/Dialect/Linalg/transform-op-specialize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index f92843a1dcb98..08afdf373f014 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -28,6 +28,7 @@ namespace mlir {
namespace linalg {
class IteratorTypeAttr;
class LinalgOp;
+class GenericOp;
namespace detail {
/// Implementation of the method that checks if given operands
@@ -115,6 +116,21 @@ bool isaConvolutionOpInterface(LinalgOp linalgOp);
/// Checks whether `linalgOp` is semantically equivalent to a `linalg.copyOp`.
bool isaCopyOpInterface(LinalgOp linalgOp);
+/// Checks whether a given `genericOp` is semantically equivalent to a single
+/// linalg elementwise unary op, e.g. linalg.exp.
+/// A linalg.generic body could be a series of unary elementwise ops, e.g.
+/// `exp(neg(x))`, such as formed by linalg op fusion. Here we restrict it to
+/// detecting cases where the body is a single computation op.
+bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp);
+
+/// Checks whether `genericOp` is semantically equivalent to a single linalg
+/// elementwise binary op, e.g. linalg.sub.
+bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp);
+
+/// Checks whether `genericOp` is semantically equivalent to a `linalg.fill`.
+/// Returns the scalar fill value if so, `std::nullopt` otherwise.
+std::optional<Value> isaFillOpInterface(GenericOp genericOp);
+
namespace detail {
/// Returns true if the block contains a contraction of the following form:
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 3627ff6617eda..f35ab3b856b4e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -70,6 +70,99 @@ bool linalg::isaCopyOpInterface(LinalgOp linalgOp) {
return llvm::hasSingleElement(linalgOp.getBlock()->getOperations());
}
+//===----------------------------------------------------------------------===//
+// FillOpInterface implementation
+//===----------------------------------------------------------------------===//
+std::optional<Value> linalg::isaFillOpInterface(GenericOp genericOp) {
+ // Structural.
+ if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
+ genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1)
+ return std::nullopt;
+
+ // Input should be referenced and init should not.
+ if (!genericOp.payloadUsesValueFromOperand(genericOp.getDpsInputOperand(0)) ||
+ genericOp.payloadUsesValueFromOperand(genericOp.getDpsInitOperand(0)))
+ return std::nullopt;
+
+ OpOperand *value = genericOp.getDpsInputOperand(0);
+ if (!genericOp.isScalar(value))
+ return std::nullopt;
+
+ Block *body = genericOp.getBody();
+ if (body->getOperations().size() != 1)
+ return std::nullopt;
+
+ auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
+ if (!yieldOp || yieldOp.getNumOperands() != 1 ||
+ yieldOp->getOperand(0) != body->getArgument(0))
+ return std::nullopt;
+ return value->get();
+}
+
+//===----------------------------------------------------------------------===//
+// Elementwise Single Unary/Binary-OpInterface implementation
+//===----------------------------------------------------------------------===//
+static bool
+isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp genericOp,
+ unsigned arity) {
+  // Check that all loops are parallel and the op has pure tensor semantics.
+ if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
+ genericOp.getNumLoops() < 1 || !genericOp.hasPureTensorSemantics())
+ return false;
+
+  // Check there are `arity` inputs, one init, and that all indexing maps are identities.
+ if (genericOp.getNumDpsInputs() != arity || genericOp.getNumDpsInits() != 1 ||
+ !llvm::all_of(genericOp.getIndexingMapsArray(),
+ [](AffineMap map) { return map.isIdentity(); }))
+ return false;
+
+ // Init should not be referenced for elementwise operations.
+ if (genericOp.payloadUsesValueFromOperand(genericOp.getDpsInitOperand(0)))
+ return false;
+
+  // A linalg.generic could be a series of elementwise ops, e.g. exp(neg(x)),
+  // such as resulting from producer-consumer fusion. Here, we restrict to two
+  // ops in the body, where the first is the elementwise single op and the
+  // second a yield.
+ Block *body = genericOp.getBody();
+ if (body->getOperations().size() != 2)
+ return false;
+
+ Operation *op = &body->front();
+ if (op->getNumOperands() != arity || op->getNumResults() != 1)
+ return false;
+
+ auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
+ if (!yieldOp || yieldOp.getNumOperands() != 1 ||
+ yieldOp->getOperand(0).getDefiningOp() != op)
+ return false;
+ return true;
+}
+
+bool linalg::isaElemwiseSingleUnaryOpInterface(linalg::GenericOp genericOp) {
+ // All basic elemwise checks.
+ if (!isaElemwiseSingleUnaryOrBinaryOpInterface(genericOp, 1))
+ return false;
+
+  // Check the input is actually used.
+ if (!genericOp.payloadUsesValueFromOperand(genericOp.getDpsInputOperand(0)))
+ return false;
+ return true;
+}
+
+bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp genericOp) {
+ if (!isaElemwiseSingleUnaryOrBinaryOpInterface(genericOp, 2))
+ return false;
+
+ // Check both inputs are used (elementwise).
+ OpOperand *inputOpOperand0 = genericOp.getDpsInputOperand(0);
+ OpOperand *inputOpOperand1 = genericOp.getDpsInputOperand(1);
+ if (!genericOp.payloadUsesValueFromOperand(inputOpOperand0) ||
+ !genericOp.payloadUsesValueFromOperand(inputOpOperand1))
+ return false;
+ return true;
+}
+
//===----------------------------------------------------------------------===//
// ContractionOpInterface implementation
//===----------------------------------------------------------------------===//
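To make the single-computation-op restriction concrete: a fused body such as
exp(neg(x)) is deliberately rejected, because the block holds two computation
ops plus the yield. A hand-written MLIR sketch of such a rejected generic
(names and shapes are illustrative, not taken from the tests):

#map = affine_map<(d0) -> (d0)>
%y = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
       ins(%x : tensor<?xf32>) outs(%init : tensor<?xf32>) {
  ^bb0(%in: f32, %out: f32):
    %0 = arith.negf %in : f32   // first computation op
    %1 = math.exp %0 : f32      // second computation op
    linalg.yield %1 : f32       // body holds 3 ops, so size() != 2 fails
} -> tensor<?xf32>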
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 4c437b5db2c7b..2bc4d7fbfadcc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -14,13 +14,50 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Math/IR/Math.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "linalg-specialization"
+#define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP) \
+ (rewriter.replaceOpWithNewOp<NEWOP>( \
+ genericOp, \
+ ValueRange{genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 1 : 0], \
+ genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 0 : 1]}, \
+ ValueRange{genericOp.getDpsInits()[0]}))
+
+#define REPLACE_UNARY_OP(NEWOP) \
+ (rewriter.replaceOpWithNewOp<NEWOP>(genericOp, \
+ ValueRange{genericOp.getDpsInputs()[0]}, \
+ ValueRange{genericOp.getDpsInits()[0]}))
+
using namespace mlir;
using namespace mlir::linalg;
+// Given an elementwise single binary linalg generic op, checks whether the
+// binary op accesses operands as swapped, e.g.
+// this differentiates between a linalg-generic body that contains:
+// ^bb0(%a: f32, %b: f32, %c : f32):
+// %0 = arith.subf %a, %b : f32
+// linalg.yield %0: f32
+// against:
+// ^bb0(%a: f32, %b: f32, %c : f32):
+// %0 = arith.subf %b, %a : f32
+// linalg.yield %0: f32
+// The former is linalg.sub(a,b), the latter is linalg.sub(b,a).
+static bool areBinOpsSwapped(GenericOp genericOp) {
+ Block *body = genericOp.getBody();
+ Operation *op = &body->front();
+ bool swapped = false;
+ if (op->getOpOperand(0).get() != body->getArgument(0)) {
+ swapped = true;
+ assert(op->getOpOperand(0).get() == body->getArgument(1) &&
+ op->getOpOperand(1).get() == body->getArgument(0) &&
+ "binary op uses just one block arg");
+ }
+ return swapped;
+}
+
FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
GenericOp genericOp) {
if (isaCopyOpInterface(genericOp)) {
@@ -28,5 +65,40 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
return namedOp;
}
+
+ if (isaFillOpInterface(genericOp)) {
+ LinalgOp namedOp = rewriter.replaceOpWithNewOp<FillOp>(
+ genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
+ return namedOp;
+ }
+
+ if (isaElemwiseSingleUnaryOpInterface(genericOp)) {
+ Operation *op = &genericOp.getBody()->front();
+ if (isa<math::ExpOp>(op)) {
+ LinalgOp namedOp = REPLACE_UNARY_OP(ExpOp);
+ return namedOp;
+ }
+ }
+
+ if (isaElemwiseSingleBinaryOpInterface(genericOp)) {
+ bool swap = areBinOpsSwapped(genericOp);
+ Operation *op = &genericOp.getBody()->front();
+ if (isa<arith::AddFOp>(op)) {
+ LinalgOp namedOp = REPLACE_BINARY_OP(AddOp, swap);
+ return namedOp;
+ }
+ if (isa<arith::SubFOp>(op)) {
+ LinalgOp namedOp = REPLACE_BINARY_OP(SubOp, swap);
+ return namedOp;
+ }
+ if (isa<arith::MulFOp>(op)) {
+ LinalgOp namedOp = REPLACE_BINARY_OP(MulOp, swap);
+ return namedOp;
+ }
+ if (isa<arith::DivFOp>(op)) {
+ LinalgOp namedOp = REPLACE_BINARY_OP(DivOp, swap);
+ return namedOp;
+ }
+ }
return failure();
}
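For reference, the REPLACE_BINARY_OP macro above hand-expands as follows for
the swapped-subtract case, i.e. REPLACE_BINARY_OP(SubOp, /*swap=*/true); this
is a sketch of the expansion, not additional patch code:

// Swapped operands: the named op computes sub(b, a) to match a generic
// body of the form `arith.subf %b, %a`.
rewriter.replaceOpWithNewOp<SubOp>(
    genericOp,
    ValueRange{genericOp.getDpsInputs()[1], genericOp.getDpsInputs()[0]},
    ValueRange{genericOp.getDpsInits()[0]});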
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
index 8a22c115f3117..35679db7412f3 100644
--- a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
@@ -141,3 +141,28 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+#map = affine_map<(d0, d1) -> ()>
+#map1 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @linalg_generic_fill(%arg0: tensor<7x7xf32>) -> tensor<7x7xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%cst : f32) outs(%arg0 : tensor<7x7xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ linalg.yield %in : f32
+ } -> tensor<7x7xf32>
+ return %0 : tensor<7x7xf32>
+}
+// CHECK-LABEL: linalg_generic_fill
+// CHECK-SAME: %[[ARG0:.+]]: tensor<7x7xf32>) -> tensor<7x7xf32>
+// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %{{.*}} = linalg.fill ins(%[[CST]] : f32) outs(%[[ARG0]] : tensor<7x7xf32>) -> tensor<7x7xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_binary.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_binary.mlir
new file mode 100644
index 0000000000000..d45025de931cd
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_binary.mlir
@@ -0,0 +1,76 @@
+// RUN: mlir-opt --transform-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @specialize_add(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.addf %in, %in_0 : f32
+ linalg.yield %1 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: specialize_add
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.add ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+func.func @specialize_sub(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.subf %in, %in_0 : f32
+ linalg.yield %1 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: specialize_sub
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.sub ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+func.func @specialize_sub_swapped_operands(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.subf %in_0, %in : f32
+ linalg.yield %1 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: specialize_sub_swapped_operands
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.sub ins(%[[ARG1]], %[[ARG0]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+func.func @specialize_mul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.mulf %in, %in_0 : f32
+ linalg.yield %1 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: specialize_mul
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.mul ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+func.func @specialize_div(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.divf %in, %in_0 : f32
+ linalg.yield %1 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: specialize_div
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.div ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_unary.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_unary.mlir
new file mode 100644
index 0000000000000..89a8baa453e90
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_unary.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt --transform-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s
+
+#umap = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func.func @specialize_exp(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+ ins(%arg0 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %1 = math.exp %in : f32
+ linalg.yield %1 : f32
+ } -> tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: specialize_exp
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.exp ins(%[[ARG0]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}