[Mlir-commits] [mlir] [MLIR][Linalg] Add matchers to specialize more unary ops (PR #183259)
Julian Oppermann
llvmlistbot at llvm.org
Mon Mar 16 00:10:52 PDT 2026
https://github.com/jopperm updated https://github.com/llvm/llvm-project/pull/183259
>From d337be027da565f1d81fe89b7b81e30f41216c04 Mon Sep 17 00:00:00 2001
From: Julian Oppermann <julian.oppermann at intel.com>
Date: Wed, 25 Feb 2026 00:13:34 -0800
Subject: [PATCH] [MLIR][Linalg] Add matchers to specialize more unary ops
---
.../Dialect/Linalg/IR/LinalgInterfaces.cpp | 6 +-
.../Dialect/Linalg/Transforms/Specialize.cpp | 39 ++++++-
.../Linalg/linalg-morph-multi-step.mlir | 40 ++++++-
.../Linalg/roundtrip-linalg-named-ops.mlir | 28 ++++-
.../Linalg/specialize-generic-ops.mlir | 109 +++++++++++++++++-
...ransform-op-specialize-elemwise-unary.mlir | 105 ++++++++++++++++-
6 files changed, 306 insertions(+), 21 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index c27749c36887d..1bddc286b3dd4 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -251,8 +251,12 @@ static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op,
if (body->getOperations().size() != 2)
return false;
+ // The payload op must have one result and at least arity-many operands
+ // (otherwise not all inputs can be used). It can have additional operands
+ // from outside of the generic op (e.g. div(1, x) for linalg.reciprocal) or
+ // use an input more than once (e.g. mul(x, x) for linalg.square).
Operation *oper = &body->front();
- if (oper->getNumOperands() != arity || oper->getNumResults() != 1)
+ if (oper->getNumOperands() < arity || oper->getNumResults() != 1)
return false;
auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index a71f84dee3bb0..25ea1293c6540 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -36,9 +36,9 @@ namespace mlir {
ValueRange{genericOp.getDpsInits()[0]}))
#define REPLACE_UNARY_OP(NEWOP) \
- (rewriter.replaceOpWithNewOp<NEWOP>(genericOp, \
- ValueRange{genericOp.getDpsInputs()[0]}, \
- ValueRange{genericOp.getDpsInits()[0]}))
+ static_cast<LinalgOp>(rewriter.replaceOpWithNewOp<NEWOP>( \
+ genericOp, ValueRange{genericOp.getDpsInputs()[0]}, \
+ ValueRange{genericOp.getDpsInits()[0]}))
using namespace mlir;
using namespace mlir::linalg;
@@ -446,10 +446,37 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
// Elementwise Unary
if (isaElemwiseSingleUnaryOpInterface(genericOp)) {
Operation *op = &genericOp.getBody()->front();
- if (isa<math::ExpOp>(op)) {
- LinalgOp namedOp = REPLACE_UNARY_OP(ExpOp);
- return namedOp;
+ if (isa<math::ExpOp>(op))
+ return REPLACE_UNARY_OP(ExpOp);
+ if (isa<math::LogOp>(op))
+ return REPLACE_UNARY_OP(LogOp);
+ if (isa<math::AbsFOp>(op))
+ return REPLACE_UNARY_OP(AbsOp);
+ if (isa<math::CeilOp>(op))
+ return REPLACE_UNARY_OP(CeilOp);
+ if (isa<math::FloorOp>(op))
+ return REPLACE_UNARY_OP(FloorOp);
+ if (isa<arith::NegFOp>(op))
+ return REPLACE_UNARY_OP(NegFOp);
+ if (auto divOp = dyn_cast<arith::DivFOp>(op)) {
+ if (auto constOp = dyn_cast_if_present<arith::ConstantOp>(
+ divOp.getLhs().getDefiningOp()))
+ if (cast<FloatAttr>(constOp.getValue()).getValue().isExactlyValue(1.0))
+ return REPLACE_UNARY_OP(ReciprocalOp);
}
+ if (isa<math::RoundOp>(op))
+ return REPLACE_UNARY_OP(RoundOp);
+ if (isa<math::SqrtOp>(op))
+ return REPLACE_UNARY_OP(SqrtOp);
+ if (isa<math::RsqrtOp>(op))
+ return REPLACE_UNARY_OP(RsqrtOp);
+ if (auto mulOp = dyn_cast<arith::MulFOp>(op);
+ mulOp && mulOp.getLhs() == mulOp.getRhs())
+ return REPLACE_UNARY_OP(SquareOp);
+ if (isa<math::TanhOp>(op))
+ return REPLACE_UNARY_OP(TanhOp);
+ if (isa<math::ErfOp>(op))
+ return REPLACE_UNARY_OP(ErfOp);
}
// Elementwise Binary
diff --git a/mlir/test/Dialect/Linalg/linalg-morph-multi-step.mlir b/mlir/test/Dialect/Linalg/linalg-morph-multi-step.mlir
index bdd29b96346e1..bfaeee4717874 100644
--- a/mlir/test/Dialect/Linalg/linalg-morph-multi-step.mlir
+++ b/mlir/test/Dialect/Linalg/linalg-morph-multi-step.mlir
@@ -4,11 +4,47 @@
func.func @exp(%A : tensor<16x8xf32>, %B : tensor<16x8xf32>) -> tensor<16x8xf32> {
%exp = linalg.exp ins(%A : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>
- return %exp : tensor<16x8xf32>
+ %log = linalg.log ins(%exp : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>
+ %abs = linalg.abs ins(%log : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>
+ %ceil = linalg.ceil ins(%abs : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>
+ %floor = linalg.floor ins(%ceil : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>
+ %negf = linalg.negf ins(%floor : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>
+ %recip = linalg.reciprocal ins(%negf : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>
+ %round = linalg.round ins(%recip : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>
+ %sqrt = linalg.sqrt ins(%round : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>
+ %rsqrt = linalg.rsqrt ins(%sqrt : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>
+ %square = linalg.square ins(%rsqrt : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>
+ %tanh = linalg.tanh ins(%square : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>
+ %erf = linalg.erf ins(%tanh : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>
+ return %erf : tensor<16x8xf32>
}
-// NAMED_TO_GENERIC: linalg.generic
+// NAMED_TO_GENERIC-COUNT-13: linalg.generic
// NAMED_TO_GENERIC-NOT: linalg.exp
+// NAMED_TO_GENERIC-NOT: linalg.log
+// NAMED_TO_GENERIC-NOT: linalg.abs
+// NAMED_TO_GENERIC-NOT: linalg.ceil
+// NAMED_TO_GENERIC-NOT: linalg.floor
+// NAMED_TO_GENERIC-NOT: linalg.negf
+// NAMED_TO_GENERIC-NOT: linalg.reciprocal
+// NAMED_TO_GENERIC-NOT: linalg.round
+// NAMED_TO_GENERIC-NOT: linalg.sqrt
+// NAMED_TO_GENERIC-NOT: linalg.rsqrt
+// NAMED_TO_GENERIC-NOT: linalg.square
+// NAMED_TO_GENERIC-NOT: linalg.tanh
+// NAMED_TO_GENERIC-NOT: linalg.erf
// ROUND_TRIP: linalg.exp
+// ROUND_TRIP: linalg.log
+// ROUND_TRIP: linalg.abs
+// ROUND_TRIP: linalg.ceil
+// ROUND_TRIP: linalg.floor
+// ROUND_TRIP: linalg.negf
+// ROUND_TRIP: linalg.reciprocal
+// ROUND_TRIP: linalg.round
+// ROUND_TRIP: linalg.sqrt
+// ROUND_TRIP: linalg.rsqrt
+// ROUND_TRIP: linalg.square
+// ROUND_TRIP: linalg.tanh
+// ROUND_TRIP: linalg.erf
// ROUND_TRIP-NOT: linalg.generic
diff --git a/mlir/test/Dialect/Linalg/roundtrip-linalg-named-ops.mlir b/mlir/test/Dialect/Linalg/roundtrip-linalg-named-ops.mlir
index f15ae646e5765..29083dfe0ae1e 100644
--- a/mlir/test/Dialect/Linalg/roundtrip-linalg-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip-linalg-named-ops.mlir
@@ -2,15 +2,39 @@
// lifted back up to named op.
// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s
-func.func @unary_exp(%A: memref<7x14x21xf32>, %Out: memref<7x14x21xf32>) {
+func.func @unary(%A: memref<7x14x21xf32>, %Out: memref<7x14x21xf32>) {
linalg.exp ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+ linalg.log ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+ linalg.abs ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+ linalg.ceil ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+ linalg.floor ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+ linalg.negf ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+ linalg.reciprocal ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+ linalg.round ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+ linalg.sqrt ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+ linalg.rsqrt ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+ linalg.square ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+ linalg.tanh ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+ linalg.erf ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
return
}
-// CHECK-LABEL: unary_exp
+// CHECK-LABEL: unary
// CHECK-SAME: %[[A:.+]]: memref<7x14x21xf32>, %[[Out:.+]]: memref<7x14x21xf32>)
// CHECK-NOT: linalg.generic
// CHECK: linalg.exp ins(%[[A]] : memref<7x14x21xf32>) outs(%[[Out]] : memref<7x14x21xf32>)
+// CHECK: linalg.log ins(%[[A]] : memref<7x14x21xf32>) outs(%[[Out]] : memref<7x14x21xf32>)
+// CHECK: linalg.abs ins(%[[A]] : memref<7x14x21xf32>) outs(%[[Out]] : memref<7x14x21xf32>)
+// CHECK: linalg.ceil ins(%[[A]] : memref<7x14x21xf32>) outs(%[[Out]] : memref<7x14x21xf32>)
+// CHECK: linalg.floor ins(%[[A]] : memref<7x14x21xf32>) outs(%[[Out]] : memref<7x14x21xf32>)
+// CHECK: linalg.negf ins(%[[A]] : memref<7x14x21xf32>) outs(%[[Out]] : memref<7x14x21xf32>)
+// CHECK: linalg.reciprocal ins(%[[A]] : memref<7x14x21xf32>) outs(%[[Out]] : memref<7x14x21xf32>)
+// CHECK: linalg.round ins(%[[A]] : memref<7x14x21xf32>) outs(%[[Out]] : memref<7x14x21xf32>)
+// CHECK: linalg.sqrt ins(%[[A]] : memref<7x14x21xf32>) outs(%[[Out]] : memref<7x14x21xf32>)
+// CHECK: linalg.rsqrt ins(%[[A]] : memref<7x14x21xf32>) outs(%[[Out]] : memref<7x14x21xf32>)
+// CHECK: linalg.square ins(%[[A]] : memref<7x14x21xf32>) outs(%[[Out]] : memref<7x14x21xf32>)
+// CHECK: linalg.tanh ins(%[[A]] : memref<7x14x21xf32>) outs(%[[Out]] : memref<7x14x21xf32>)
+// CHECK: linalg.erf ins(%[[A]] : memref<7x14x21xf32>) outs(%[[Out]] : memref<7x14x21xf32>)
// -----
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
index 6acf1ca0d4e30..5c7120f3fe41c 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -1,21 +1,118 @@
// RUN: mlir-opt %s -split-input-file --linalg-specialize-generic-ops | FileCheck %s
#umap = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-func.func @unary_op_exp(%A: tensor<?x?x?xf32>, %Out: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+func.func @unary_ops(%A: tensor<?x?x?xf32>, %Out: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
%0 = linalg.generic
{indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
ins(%A : tensor<?x?x?xf32>) outs(%Out : tensor<?x?x?xf32>) {
^bb0(%in: f32, %out: f32):
- %1 = math.exp %in : f32
- linalg.yield %1 : f32
+ %v = math.exp %in : f32
+ linalg.yield %v : f32
+ } -> tensor<?x?x?xf32>
+ %1 = linalg.generic
+ {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+ ins(%0 : tensor<?x?x?xf32>) outs(%Out : tensor<?x?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %v = math.log %in : f32
+ linalg.yield %v : f32
+ } -> tensor<?x?x?xf32>
+ %2 = linalg.generic
+ {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+ ins(%1 : tensor<?x?x?xf32>) outs(%Out : tensor<?x?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %v = math.absf %in : f32
+ linalg.yield %v : f32
+ } -> tensor<?x?x?xf32>
+ %3 = linalg.generic
+ {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+ ins(%2 : tensor<?x?x?xf32>) outs(%Out : tensor<?x?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %v = math.ceil %in : f32
+ linalg.yield %v : f32
+ } -> tensor<?x?x?xf32>
+ %4 = linalg.generic
+ {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+ ins(%3 : tensor<?x?x?xf32>) outs(%Out : tensor<?x?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %v = math.floor %in : f32
+ linalg.yield %v : f32
+ } -> tensor<?x?x?xf32>
+ %5 = linalg.generic
+ {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+ ins(%4 : tensor<?x?x?xf32>) outs(%Out : tensor<?x?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %v = arith.negf %in : f32
+ linalg.yield %v : f32
+ } -> tensor<?x?x?xf32>
+ %cst_1 = arith.constant 1.0 : f32
+ %6 = linalg.generic
+ {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+ ins(%5 : tensor<?x?x?xf32>) outs(%Out : tensor<?x?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %v = arith.divf %cst_1, %in : f32
+ linalg.yield %v : f32
+ } -> tensor<?x?x?xf32>
+ %7 = linalg.generic
+ {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+ ins(%6 : tensor<?x?x?xf32>) outs(%Out : tensor<?x?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %v = math.round %in : f32
+ linalg.yield %v : f32
+ } -> tensor<?x?x?xf32>
+ %8 = linalg.generic
+ {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+ ins(%7 : tensor<?x?x?xf32>) outs(%Out : tensor<?x?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %v = math.sqrt %in : f32
+ linalg.yield %v : f32
+ } -> tensor<?x?x?xf32>
+ %9 = linalg.generic
+ {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+ ins(%8 : tensor<?x?x?xf32>) outs(%Out : tensor<?x?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %v = math.rsqrt %in : f32
+ linalg.yield %v : f32
+ } -> tensor<?x?x?xf32>
+ %10 = linalg.generic
+ {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+ ins(%9 : tensor<?x?x?xf32>) outs(%Out : tensor<?x?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %v = arith.mulf %in, %in : f32
+ linalg.yield %v : f32
+ } -> tensor<?x?x?xf32>
+ %11 = linalg.generic
+ {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+ ins(%10 : tensor<?x?x?xf32>) outs(%Out : tensor<?x?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %v = math.tanh %in : f32
+ linalg.yield %v : f32
+ } -> tensor<?x?x?xf32>
+ %12 = linalg.generic
+ {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+ ins(%11 : tensor<?x?x?xf32>) outs(%Out : tensor<?x?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %v = math.erf %in : f32
+ linalg.yield %v : f32
} -> tensor<?x?x?xf32>
- return %0 : tensor<?x?x?xf32>
+ return %12 : tensor<?x?x?xf32>
}
-// CHECK-LABEL: unary_op_exp
+// CHECK-LABEL: unary_ops
// CHECK-SAME: %[[A:.+]]: tensor<?x?x?xf32>, %[[Out:.+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK-NOT: linalg.generic
-// CHECK: linalg.exp ins(%[[A]] : tensor<?x?x?xf32>) outs(%[[Out]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES0:.+]] = linalg.exp ins(%[[A]] : tensor<?x?x?xf32>) outs(%[[Out]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES1:.+]] = linalg.log ins(%[[RES0]] : tensor<?x?x?xf32>) outs(%[[Out]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES2:.+]] = linalg.abs ins(%[[RES1]] : tensor<?x?x?xf32>) outs(%[[Out]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES3:.+]] = linalg.ceil ins(%[[RES2]] : tensor<?x?x?xf32>) outs(%[[Out]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES4:.+]] = linalg.floor ins(%[[RES3]] : tensor<?x?x?xf32>) outs(%[[Out]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES5:.+]] = linalg.negf ins(%[[RES4]] : tensor<?x?x?xf32>) outs(%[[Out]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES6:.+]] = linalg.reciprocal ins(%[[RES5]] : tensor<?x?x?xf32>) outs(%[[Out]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES7:.+]] = linalg.round ins(%[[RES6]] : tensor<?x?x?xf32>) outs(%[[Out]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES8:.+]] = linalg.sqrt ins(%[[RES7]] : tensor<?x?x?xf32>) outs(%[[Out]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES9:.+]] = linalg.rsqrt ins(%[[RES8]] : tensor<?x?x?xf32>) outs(%[[Out]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES10:.+]] = linalg.square ins(%[[RES9]] : tensor<?x?x?xf32>) outs(%[[Out]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES11:.+]] = linalg.tanh ins(%[[RES10]] : tensor<?x?x?xf32>) outs(%[[Out]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES12:.+]] = linalg.erf ins(%[[RES11]] : tensor<?x?x?xf32>) outs(%[[Out]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// -----
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize-elemwise-unary.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize-elemwise-unary.mlir
index 89a8baa453e90..3a2c7c9965287 100644
--- a/mlir/test/Dialect/Linalg/transform-op-specialize-elemwise-unary.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize-elemwise-unary.mlir
@@ -6,15 +6,112 @@ func.func @specialize_exp(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) ->
{indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
ins(%arg0 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) {
^bb0(%in: f32, %out: f32):
- %1 = math.exp %in : f32
- linalg.yield %1 : f32
+ %v = math.exp %in : f32
+ linalg.yield %v : f32
} -> tensor<?x?x?xf32>
- return %0 : tensor<?x?x?xf32>
+ %1 = linalg.generic
+ {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+ ins(%0 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %v = math.log %in : f32
+ linalg.yield %v : f32
+ } -> tensor<?x?x?xf32>
+ %2 = linalg.generic
+ {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+ ins(%1 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %v = math.absf %in : f32
+ linalg.yield %v : f32
+ } -> tensor<?x?x?xf32>
+ %3 = linalg.generic
+ {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+ ins(%2 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %v = math.ceil %in : f32
+ linalg.yield %v : f32
+ } -> tensor<?x?x?xf32>
+ %4 = linalg.generic
+ {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+ ins(%3 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %v = math.floor %in : f32
+ linalg.yield %v : f32
+ } -> tensor<?x?x?xf32>
+ %5 = linalg.generic
+ {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+ ins(%4 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %v = arith.negf %in : f32
+ linalg.yield %v : f32
+ } -> tensor<?x?x?xf32>
+ %cst_1 = arith.constant 1.0 : f32
+ %6 = linalg.generic
+ {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+ ins(%5 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %v = arith.divf %cst_1, %in : f32
+ linalg.yield %v : f32
+ } -> tensor<?x?x?xf32>
+ %7 = linalg.generic
+ {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+ ins(%6 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %v = math.round %in : f32
+ linalg.yield %v : f32
+ } -> tensor<?x?x?xf32>
+ %8 = linalg.generic
+ {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+ ins(%7 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %v = math.sqrt %in : f32
+ linalg.yield %v : f32
+ } -> tensor<?x?x?xf32>
+ %9 = linalg.generic
+ {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+ ins(%8 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %v = math.rsqrt %in : f32
+ linalg.yield %v : f32
+ } -> tensor<?x?x?xf32>
+ %10 = linalg.generic
+ {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+ ins(%9 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %v = arith.mulf %in, %in : f32
+ linalg.yield %v : f32
+ } -> tensor<?x?x?xf32>
+ %11 = linalg.generic
+ {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+ ins(%10 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %v = math.tanh %in : f32
+ linalg.yield %v : f32
+ } -> tensor<?x?x?xf32>
+ %12 = linalg.generic
+ {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+ ins(%11 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %v = math.erf %in : f32
+ linalg.yield %v : f32
+ } -> tensor<?x?x?xf32>
+ return %12 : tensor<?x?x?xf32>
}
// CHECK-LABEL: specialize_exp
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK-NOT: linalg.generic
-// CHECK: linalg.exp ins(%[[ARG0]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES0:.+]] = linalg.exp ins(%[[ARG0]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES1:.+]] = linalg.log ins(%[[RES0]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES2:.+]] = linalg.abs ins(%[[RES1]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES3:.+]] = linalg.ceil ins(%[[RES2]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES4:.+]] = linalg.floor ins(%[[RES3]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES5:.+]] = linalg.negf ins(%[[RES4]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES6:.+]] = linalg.reciprocal ins(%[[RES5]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES7:.+]] = linalg.round ins(%[[RES6]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES8:.+]] = linalg.sqrt ins(%[[RES7]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES9:.+]] = linalg.rsqrt ins(%[[RES8]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES10:.+]] = linalg.square ins(%[[RES9]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES11:.+]] = linalg.tanh ins(%[[RES10]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES12:.+]] = linalg.erf ins(%[[RES11]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
More information about the Mlir-commits
mailing list