[Mlir-commits] [mlir] [MLIR][LINALG] Add more specialize patterns (PR #91153)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun May 5 16:25:34 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg
Author: Javed Absar (javedabsar1)
<details>
<summary>Changes</summary>
Currently only linalg.copy is recognized when trying to specialize linalg.generics back to named ops. This diff enables recognition of more generic ops as named ops, e.g. linalg.fill and elementwise unary/binary ops.
---
Full diff: https://github.com/llvm/llvm-project/pull/91153.diff
6 Files Affected:
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h (+12)
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp (+99)
- (modified) mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp (+47)
- (modified) mlir/test/Dialect/Linalg/transform-op-specialize.mlir (+25-1)
- (added) mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_binary.mlir (+63)
- (added) mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_unary.mlir (+25)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index f92843a1dcb987..7a67525c1ba674 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -28,6 +28,7 @@ namespace mlir {
namespace linalg {
class IteratorTypeAttr;
class LinalgOp;
+class GenericOp;
namespace detail {
/// Implementation of the method that check if given operands
@@ -115,6 +116,17 @@ bool isaConvolutionOpInterface(LinalgOp linalgOp);
/// Checks whether `linalgOp` is semantically equivalent to a `linalg.copyOp`.
bool isaCopyOpInterface(LinalgOp linalgOp);
+/// Checks whether `genericOp` is semantically equivalent to a linalg
+/// elementwise unary op, e.g. linalg.exp.
+bool isaElementwiseUnaryOpInterface(GenericOp genericOp);
+
+/// Checks whether `genericOp` is semantically equivalent to a linalg
+/// elementwise binary op, e.g. linalg.sub.
+bool isaElementwiseBinaryOpInterface(GenericOp genericOp);
+
+/// Checks whether `genericOp` is semantically equivalent to a `linalg.fill`.
+bool isaFillOpInterface(GenericOp genericOp);
+
namespace detail {
/// Returns true if the block contains a contraction of the following form:
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 3627ff6617eda3..e6611e496a4a2e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -70,6 +70,105 @@ bool linalg::isaCopyOpInterface(LinalgOp linalgOp) {
return llvm::hasSingleElement(linalgOp.getBlock()->getOperations());
}
+//===----------------------------------------------------------------------===//
+// FillOpInterface implementation
+//===----------------------------------------------------------------------===//
+bool linalg::isaFillOpInterface(GenericOp genericOp) {
+ // Structural.
+ if (genericOp.getNumParallelLoops() != genericOp.getNumLoops())
+ return false;
+
+ if (genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1)
+ return false;
+
+ // Input should be referenced and init should not.
+ if (!genericOp.payloadUsesValueFromOperand(genericOp.getDpsInputOperand(0)) ||
+ genericOp.payloadUsesValueFromOperand(genericOp.getDpsInitOperand(0)))
+ return false;
+
+ OpOperand *value = genericOp.getDpsInputOperand(0);
+ if (!genericOp.isScalar(value))
+ return false;
+
+ Block *body = genericOp.getBody();
+ if (body->getOperations().size() != 1)
+ return false;
+
+ auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
+ if (!yieldOp || yieldOp.getNumOperands() != 1 ||
+ yieldOp->getOperand(0) != body->getArgument(0))
+ return false;
+ return true;
+}
+
+//===----------------------------------------------------------------------===//
+// Elementwise-Unary/Binary-OpInterface implementation
+//===----------------------------------------------------------------------===//
+static bool isaElementwiseUnaryOrBinaryOpInterface(linalg::GenericOp genericOp,
+ unsigned arity) {
+ // Check all loops are parallel, and have only tensor semantics.
+ if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
+ genericOp.getNumLoops() < 1 || !genericOp.hasPureTensorSemantics())
+ return false;
+
+ // Check there are arity-inputs, 1-output and all are identity-maps.
+ if (genericOp.getNumDpsInputs() != arity || genericOp.getNumDpsInits() != 1 ||
+ !llvm::all_of(genericOp.getIndexingMapsArray(),
+ [](AffineMap map) { return map.isIdentity(); }))
+ return false;
+
+ // Init should not be referenced for elementwise operations.
+ if (genericOp.payloadUsesValueFromOperand(genericOp.getDpsInitOperand(0)))
+ return false;
+
+ // Expect two ops: first one possibly unary/binary op and the second one must
+ // yield the nary-op result.
+ Block *body = genericOp.getBody();
+ if (body->getOperations().size() != 2)
+ return false;
+
+ Operation *op = &body->front();
+ if (op->getNumOperands() != arity || op->getNumResults() != 1)
+ return false;
+
+ auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
+ if (!yieldOp || yieldOp.getNumOperands() != 1 ||
+ yieldOp->getOperand(0).getDefiningOp() != op)
+ return false;
+ return true;
+}
+
+bool linalg::isaElementwiseUnaryOpInterface(linalg::GenericOp genericOp) {
+ // All basic elemwise checks.
+ if (!isaElementwiseUnaryOrBinaryOpInterface(genericOp, 1))
+ return false;
+
+  // Check the input is actually used.
+ if (!genericOp.payloadUsesValueFromOperand(genericOp.getDpsInputOperand(0)))
+ return false;
+ return true;
+}
+
+bool linalg::isaElementwiseBinaryOpInterface(linalg::GenericOp genericOp) {
+ if (!isaElementwiseUnaryOrBinaryOpInterface(genericOp, 2))
+ return false;
+
+ // Check both inputs are used (elementwise).
+ OpOperand *inputOpOperand0 = genericOp.getDpsInputOperand(0);
+ OpOperand *inputOpOperand1 = genericOp.getDpsInputOperand(1);
+ if (!genericOp.payloadUsesValueFromOperand(inputOpOperand0) ||
+ !genericOp.payloadUsesValueFromOperand(inputOpOperand1))
+ return false;
+
+  // Check that args are not swapped (not all elementwise ops are commutative,
+  // e.g. sub and div).
+ Block *body = genericOp.getBody();
+ Operation *op = &body->front();
+ if (op->getOpOperand(0).get() != body->getArgument(0) ||
+ op->getOpOperand(1).get() != body->getArgument(1))
+ return false;
+ return true;
+}
+
//===----------------------------------------------------------------------===//
// ContractionOpInterface implementation
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 4c437b5db2c7b0..d3782287289a7b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -12,12 +12,25 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "linalg-specialization"
+#define REPLACE_BINARY_OP(NEWOP) \
+ (rewriter.replaceOpWithNewOp<NEWOP>( \
+ genericOp, \
+ ValueRange{genericOp.getDpsInputs()[0], genericOp.getDpsInputs()[1]}, \
+ ValueRange{genericOp.getDpsInits()[0]}))
+
+#define REPLACE_UNARY_OP(NEWOP) \
+ (rewriter.replaceOpWithNewOp<NEWOP>( \
+ genericOp, \
+ ValueRange{genericOp.getDpsInputs()[0]}, \
+ ValueRange{genericOp.getDpsInits()[0]}))
+
using namespace mlir;
using namespace mlir::linalg;
@@ -28,5 +41,39 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
return namedOp;
}
+
+ if (isaFillOpInterface(genericOp)) {
+ LinalgOp namedOp = rewriter.replaceOpWithNewOp<FillOp>(
+ genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
+ return namedOp;
+ }
+
+ if (isaElementwiseUnaryOpInterface(genericOp)) {
+ Operation *op = &genericOp.getBody()->front();
+ if (isa<math::ExpOp>(op)) {
+ LinalgOp namedOp = REPLACE_UNARY_OP(ExpOp);
+ return namedOp;
+ }
+ }
+
+ if (isaElementwiseBinaryOpInterface(genericOp)) {
+ Operation *op = &genericOp.getBody()->front();
+ if (isa<arith::AddFOp>(op)) {
+ LinalgOp namedOp = REPLACE_BINARY_OP(AddOp);
+ return namedOp;
+ }
+ if (isa<arith::SubFOp>(op)) {
+ LinalgOp namedOp = REPLACE_BINARY_OP(SubOp);
+ return namedOp;
+ }
+ if (isa<arith::MulFOp>(op)) {
+ LinalgOp namedOp = REPLACE_BINARY_OP(MulOp);
+ return namedOp;
+ }
+ if (isa<arith::DivFOp>(op)) {
+ LinalgOp namedOp = REPLACE_BINARY_OP(DivOp);
+ return namedOp;
+ }
+ }
return failure();
}
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
index 8a22c115f31170..21dd1fb56789f2 100644
--- a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
@@ -3,7 +3,6 @@
#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d0)>
#map2 = affine_map<(d0, d1) -> (d1, d0)>
-
func.func @broadcast_copy_expect_no_match(%arg0: memref<?xf32>, %arg1: memref<?x?xf32>) {
// expected-note @below {{when applied to this op}}
linalg.generic {
@@ -141,3 +140,28 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+#map = affine_map<(d0, d1) -> ()>
+#map1 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @linalg_generic_fill(%arg0: tensor<7x7xf32>) -> tensor<7x7xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%cst : f32) outs(%arg0 : tensor<7x7xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ linalg.yield %in : f32
+ } -> tensor<7x7xf32>
+ return %0 : tensor<7x7xf32>
+}
+// CHECK-LABEL: linalg_generic_fill
+// CHECK-SAME: %[[ARG0:.+]]: tensor<7x7xf32>) -> tensor<7x7xf32>
+// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %{{.*}} = linalg.fill ins(%[[CST]] : f32) outs(%[[ARG0]] : tensor<7x7xf32>) -> tensor<7x7xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_binary.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_binary.mlir
new file mode 100644
index 00000000000000..7bd3b1a1a4a4ca
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_binary.mlir
@@ -0,0 +1,63 @@
+// RUN: mlir-opt --transform-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @specialize_add(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.addf %in, %in_0 : f32
+ linalg.yield %1 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: specialize_add
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.add ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+func.func @specialize_sub(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.subf %in, %in_0 : f32
+ linalg.yield %1 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: specialize_sub
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.sub ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+func.func @specialize_mul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.mulf %in, %in_0 : f32
+ linalg.yield %1 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: specialize_mul
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.mul ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+func.func @specialize_div(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.divf %in, %in_0 : f32
+ linalg.yield %1 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: specialize_div
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.div ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_unary.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_unary.mlir
new file mode 100644
index 00000000000000..89a8baa453e905
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_unary.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt --transform-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s
+
+#umap = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func.func @specialize_exp(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+ ins(%arg0 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %1 = math.exp %in : f32
+ linalg.yield %1 : f32
+ } -> tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: specialize_exp
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.exp ins(%[[ARG0]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/91153
More information about the Mlir-commits
mailing list