[Mlir-commits] [mlir] [MLIR][Linalg] Add matchers to specialize more unary ops (PR #183259)

Julian Oppermann llvmlistbot at llvm.org
Mon Mar 16 00:10:52 PDT 2026


https://github.com/jopperm updated https://github.com/llvm/llvm-project/pull/183259

>From d337be027da565f1d81fe89b7b81e30f41216c04 Mon Sep 17 00:00:00 2001
From: Julian Oppermann <julian.oppermann at intel.com>
Date: Wed, 25 Feb 2026 00:13:34 -0800
Subject: [PATCH] [MLIR][Linalg] Add matchers to specialize more unary ops

---
 .../Dialect/Linalg/IR/LinalgInterfaces.cpp    |   6 +-
 .../Dialect/Linalg/Transforms/Specialize.cpp  |  39 ++++++-
 .../Linalg/linalg-morph-multi-step.mlir       |  40 ++++++-
 .../Linalg/roundtrip-linalg-named-ops.mlir    |  28 ++++-
 .../Linalg/specialize-generic-ops.mlir        | 109 +++++++++++++++++-
 ...ransform-op-specialize-elemwise-unary.mlir | 105 ++++++++++++++++-
 6 files changed, 306 insertions(+), 21 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index c27749c36887d..1bddc286b3dd4 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -251,8 +251,12 @@ static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op,
   if (body->getOperations().size() != 2)
     return false;

+  // The payload op must have one result and at least `arity` operands (fewer
+  // could not consume every input). For both unary and binary payloads, extra
+  // operands may come from outside the generic op (e.g. div(1, x) for
+  // linalg.reciprocal) or repeat an input (e.g. mul(x, x) for linalg.square).
   Operation *oper = &body->front();
-  if (oper->getNumOperands() != arity || oper->getNumResults() != 1)
+  if (oper->getNumOperands() < arity || oper->getNumResults() != 1)
     return false;

   auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index a71f84dee3bb0..25ea1293c6540 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -36,9 +36,9 @@ namespace mlir {
       ValueRange{genericOp.getDpsInits()[0]}))

 #define REPLACE_UNARY_OP(NEWOP)                                                \
-  (rewriter.replaceOpWithNewOp<NEWOP>(genericOp,                               \
-                                      ValueRange{genericOp.getDpsInputs()[0]}, \
-                                      ValueRange{genericOp.getDpsInits()[0]}))
+  static_cast<LinalgOp>(rewriter.replaceOpWithNewOp<NEWOP>(                    \
+      genericOp, ValueRange{genericOp.getDpsInputs()[0]},                      \
+      ValueRange{genericOp.getDpsInits()[0]}))

 using namespace mlir;
 using namespace mlir::linalg;
@@ -446,10 +446,38 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
   // Elementwise Unary
   if (isaElemwiseSingleUnaryOpInterface(genericOp)) {
     Operation *op = &genericOp.getBody()->front();
-    if (isa<math::ExpOp>(op)) {
-      LinalgOp namedOp = REPLACE_UNARY_OP(ExpOp);
-      return namedOp;
+    if (isa<math::ExpOp>(op))
+      return REPLACE_UNARY_OP(ExpOp);
+    if (isa<math::LogOp>(op))
+      return REPLACE_UNARY_OP(LogOp);
+    if (isa<math::AbsFOp>(op))
+      return REPLACE_UNARY_OP(AbsOp);
+    if (isa<math::CeilOp>(op))
+      return REPLACE_UNARY_OP(CeilOp);
+    if (isa<math::FloorOp>(op))
+      return REPLACE_UNARY_OP(FloorOp);
+    if (isa<arith::NegFOp>(op))
+      return REPLACE_UNARY_OP(NegFOp);
+    if (auto divOp = dyn_cast<arith::DivFOp>(op)) {
+      // Match 1.0 / x; dyn_cast since e.g. a vector splat is not a FloatAttr.
+      auto cst = divOp.getLhs().getDefiningOp<arith::ConstantOp>();
+      auto one = cst ? dyn_cast<FloatAttr>(cst.getValue()) : FloatAttr();
+      if (one && one.getValue().isExactlyValue(1.0))
+        return REPLACE_UNARY_OP(ReciprocalOp);
     }
+    if (isa<math::RoundOp>(op))
+      return REPLACE_UNARY_OP(RoundOp);
+    if (isa<math::SqrtOp>(op))
+      return REPLACE_UNARY_OP(SqrtOp);
+    if (isa<math::RsqrtOp>(op))
+      return REPLACE_UNARY_OP(RsqrtOp);
+    if (auto mulOp = dyn_cast<arith::MulFOp>(op);
+        mulOp && mulOp.getLhs() == mulOp.getRhs())
+      return REPLACE_UNARY_OP(SquareOp);
+    if (isa<math::TanhOp>(op))
+      return REPLACE_UNARY_OP(TanhOp);
+    if (isa<math::ErfOp>(op))
+      return REPLACE_UNARY_OP(ErfOp);
   }

   // Elementwise Binary
diff --git a/mlir/test/Dialect/Linalg/linalg-morph-multi-step.mlir b/mlir/test/Dialect/Linalg/linalg-morph-multi-step.mlir
index bdd29b96346e1..bfaeee4717874 100644
--- a/mlir/test/Dialect/Linalg/linalg-morph-multi-step.mlir
+++ b/mlir/test/Dialect/Linalg/linalg-morph-multi-step.mlir
@@ -4,11 +4,48 @@

 func.func @exp(%A : tensor<16x8xf32>, %B : tensor<16x8xf32>) ->  tensor<16x8xf32> {
   %exp = linalg.exp ins(%A : tensor<16x8xf32>) outs(%B :  tensor<16x8xf32>) -> tensor<16x8xf32>
-  return %exp :  tensor<16x8xf32>
+  %log = linalg.log ins(%exp : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>
+  %abs = linalg.abs ins(%log : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>
+  %ceil = linalg.ceil ins(%abs : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>
+  %floor = linalg.floor ins(%ceil : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>
+  %negf = linalg.negf ins(%floor : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>
+  %recip = linalg.reciprocal ins(%negf : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>
+  %round = linalg.round ins(%recip : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>
+  %sqrt = linalg.sqrt ins(%round : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>
+  %rsqrt = linalg.rsqrt ins(%sqrt : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>
+  %square = linalg.square ins(%rsqrt : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>
+  %tanh = linalg.tanh ins(%square : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>
+  %erf = linalg.erf ins(%tanh : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>
+  return %erf :  tensor<16x8xf32>
 }

-// NAMED_TO_GENERIC: linalg.generic
+// Despite its name, @exp chains all 13 unary named ops, exp through erf.
+// NAMED_TO_GENERIC-COUNT-13: linalg.generic
 // NAMED_TO_GENERIC-NOT: linalg.exp
+// NAMED_TO_GENERIC-NOT: linalg.log
+// NAMED_TO_GENERIC-NOT: linalg.abs
+// NAMED_TO_GENERIC-NOT: linalg.ceil
+// NAMED_TO_GENERIC-NOT: linalg.floor
+// NAMED_TO_GENERIC-NOT: linalg.negf
+// NAMED_TO_GENERIC-NOT: linalg.reciprocal
+// NAMED_TO_GENERIC-NOT: linalg.round
+// NAMED_TO_GENERIC-NOT: linalg.sqrt
+// NAMED_TO_GENERIC-NOT: linalg.rsqrt
+// NAMED_TO_GENERIC-NOT: linalg.square
+// NAMED_TO_GENERIC-NOT: linalg.tanh
+// NAMED_TO_GENERIC-NOT: linalg.erf

 // ROUND_TRIP: linalg.exp
+// ROUND_TRIP: linalg.log
+// ROUND_TRIP: linalg.abs
+// ROUND_TRIP: linalg.ceil
+// ROUND_TRIP: linalg.floor
+// ROUND_TRIP: linalg.negf
+// ROUND_TRIP: linalg.reciprocal
+// ROUND_TRIP: linalg.round
+// ROUND_TRIP: linalg.sqrt
+// ROUND_TRIP: linalg.rsqrt
+// ROUND_TRIP: linalg.square
+// ROUND_TRIP: linalg.tanh
+// ROUND_TRIP: linalg.erf
 // ROUND_TRIP-NOT: linalg.generic
diff --git a/mlir/test/Dialect/Linalg/roundtrip-linalg-named-ops.mlir b/mlir/test/Dialect/Linalg/roundtrip-linalg-named-ops.mlir
index f15ae646e5765..29083dfe0ae1e 100644
--- a/mlir/test/Dialect/Linalg/roundtrip-linalg-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip-linalg-named-ops.mlir
@@ -2,15 +2,39 @@
 // lifted back up to named op.
 // RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s

-func.func @unary_exp(%A: memref<7x14x21xf32>, %Out: memref<7x14x21xf32>) {
+func.func @unary(%A: memref<7x14x21xf32>, %Out: memref<7x14x21xf32>) {
   linalg.exp ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+  linalg.log ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+  linalg.abs ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+  linalg.ceil ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+  linalg.floor ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+  linalg.negf ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+  linalg.reciprocal ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+  linalg.round ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+  linalg.sqrt ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+  linalg.rsqrt ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+  linalg.square ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+  linalg.tanh ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+  linalg.erf ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
   return
 }

-// CHECK-LABEL: unary_exp
+// CHECK-LABEL: func.func @unary(
 // CHECK-SAME: %[[A:.+]]: memref<7x14x21xf32>, %[[Out:.+]]: memref<7x14x21xf32>)
 // CHECK-NOT: linalg.generic
 // CHECK: linalg.exp ins(%[[A]] : memref<7x14x21xf32>) outs(%[[Out]] : memref<7x14x21xf32>)
+// CHECK: linalg.log ins(%[[A]] : memref<7x14x21xf32>) outs(%[[Out]] : memref<7x14x21xf32>)
+// CHECK: linalg.abs ins(%[[A]] : memref<7x14x21xf32>) outs(%[[Out]] : memref<7x14x21xf32>)
+// CHECK: linalg.ceil ins(%[[A]] : memref<7x14x21xf32>) outs(%[[Out]] : memref<7x14x21xf32>)
+// CHECK: linalg.floor ins(%[[A]] : memref<7x14x21xf32>) outs(%[[Out]] : memref<7x14x21xf32>)
+// CHECK: linalg.negf ins(%[[A]] : memref<7x14x21xf32>) outs(%[[Out]] : memref<7x14x21xf32>)
+// CHECK: linalg.reciprocal ins(%[[A]] : memref<7x14x21xf32>) outs(%[[Out]] : memref<7x14x21xf32>)
+// CHECK: linalg.round ins(%[[A]] : memref<7x14x21xf32>) outs(%[[Out]] : memref<7x14x21xf32>)
+// CHECK: linalg.sqrt ins(%[[A]] : memref<7x14x21xf32>) outs(%[[Out]] : memref<7x14x21xf32>)
+// CHECK: linalg.rsqrt ins(%[[A]] : memref<7x14x21xf32>) outs(%[[Out]] : memref<7x14x21xf32>)
+// CHECK: linalg.square ins(%[[A]] : memref<7x14x21xf32>) outs(%[[Out]] : memref<7x14x21xf32>)
+// CHECK: linalg.tanh ins(%[[A]] : memref<7x14x21xf32>) outs(%[[Out]] : memref<7x14x21xf32>)
+// CHECK: linalg.erf ins(%[[A]] : memref<7x14x21xf32>) outs(%[[Out]] : memref<7x14x21xf32>)

 // -----

diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
index 6acf1ca0d4e30..5c7120f3fe41c 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -1,21 +1,158 @@
 // RUN: mlir-opt %s -split-input-file --linalg-specialize-generic-ops | FileCheck %s

 #umap = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-func.func @unary_op_exp(%A: tensor<?x?x?xf32>, %Out: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+func.func @unary_ops(%A: tensor<?x?x?xf32>, %Out: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
   %0 = linalg.generic
           {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
           ins(%A : tensor<?x?x?xf32>) outs(%Out : tensor<?x?x?xf32>) {
   ^bb0(%in: f32, %out: f32):
-    %1 = math.exp %in : f32
-    linalg.yield %1 : f32
+    %v = math.exp %in : f32
+    linalg.yield %v : f32
+  } -> tensor<?x?x?xf32>
+  %1 = linalg.generic
+          {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+          ins(%0 : tensor<?x?x?xf32>) outs(%Out : tensor<?x?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %v = math.log %in : f32
+    linalg.yield %v : f32
+  } -> tensor<?x?x?xf32>
+  %2 = linalg.generic
+          {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+          ins(%1 : tensor<?x?x?xf32>) outs(%Out : tensor<?x?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %v = math.absf %in : f32
+    linalg.yield %v : f32
+  } -> tensor<?x?x?xf32>
+  %3 = linalg.generic
+          {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+          ins(%2 : tensor<?x?x?xf32>) outs(%Out : tensor<?x?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %v = math.ceil %in : f32
+    linalg.yield %v : f32
+  } -> tensor<?x?x?xf32>
+  %4 = linalg.generic
+          {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+          ins(%3 : tensor<?x?x?xf32>) outs(%Out : tensor<?x?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %v = math.floor %in : f32
+    linalg.yield %v : f32
+  } -> tensor<?x?x?xf32>
+  %5 = linalg.generic
+          {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+          ins(%4 : tensor<?x?x?xf32>) outs(%Out : tensor<?x?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %v = arith.negf %in : f32
+    linalg.yield %v : f32
+  } -> tensor<?x?x?xf32>
+  %cst_1 = arith.constant 1.0 : f32
+  %6 = linalg.generic
+          {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+          ins(%5 : tensor<?x?x?xf32>) outs(%Out : tensor<?x?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %v = arith.divf %cst_1, %in : f32
+    linalg.yield %v : f32
+  } -> tensor<?x?x?xf32>
+  %7 = linalg.generic
+          {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+          ins(%6 : tensor<?x?x?xf32>) outs(%Out : tensor<?x?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %v = math.round %in : f32
+    linalg.yield %v : f32
+  } -> tensor<?x?x?xf32>
+  %8 = linalg.generic
+          {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+          ins(%7 : tensor<?x?x?xf32>) outs(%Out : tensor<?x?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %v = math.sqrt %in : f32
+    linalg.yield %v : f32
+  } -> tensor<?x?x?xf32>
+  %9 = linalg.generic
+          {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+          ins(%8 : tensor<?x?x?xf32>) outs(%Out : tensor<?x?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %v = math.rsqrt %in : f32
+    linalg.yield %v : f32
+  } -> tensor<?x?x?xf32>
+  %10 = linalg.generic
+          {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+          ins(%9 : tensor<?x?x?xf32>) outs(%Out : tensor<?x?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %v = arith.mulf %in, %in : f32
+    linalg.yield %v : f32
+  } -> tensor<?x?x?xf32>
+  %11 = linalg.generic
+          {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+          ins(%10 : tensor<?x?x?xf32>) outs(%Out : tensor<?x?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %v = math.tanh %in : f32
+    linalg.yield %v : f32
+  } -> tensor<?x?x?xf32>
+  %12 = linalg.generic
+          {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+          ins(%11 : tensor<?x?x?xf32>) outs(%Out : tensor<?x?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %v = math.erf %in : f32
+    linalg.yield %v : f32
   } -> tensor<?x?x?xf32>
-  return %0 : tensor<?x?x?xf32>
+  return %12 : tensor<?x?x?xf32>
 }

-// CHECK-LABEL: unary_op_exp
+// CHECK-LABEL: unary_ops
 // CHECK-SAME: %[[A:.+]]: tensor<?x?x?xf32>, %[[Out:.+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
 // CHECK-NOT: linalg.generic
-// CHECK: linalg.exp ins(%[[A]] : tensor<?x?x?xf32>) outs(%[[Out]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES0:.+]] = linalg.exp ins(%[[A]] : tensor<?x?x?xf32>) outs(%[[Out]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES1:.+]] = linalg.log ins(%[[RES0]] : tensor<?x?x?xf32>) outs(%[[Out]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES2:.+]] = linalg.abs ins(%[[RES1]] : tensor<?x?x?xf32>) outs(%[[Out]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES3:.+]] = linalg.ceil ins(%[[RES2]] : tensor<?x?x?xf32>) outs(%[[Out]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES4:.+]] = linalg.floor ins(%[[RES3]] : tensor<?x?x?xf32>) outs(%[[Out]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES5:.+]] = linalg.negf ins(%[[RES4]] : tensor<?x?x?xf32>) outs(%[[Out]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES6:.+]] = linalg.reciprocal ins(%[[RES5]] : tensor<?x?x?xf32>) outs(%[[Out]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES7:.+]] = linalg.round ins(%[[RES6]] : tensor<?x?x?xf32>) outs(%[[Out]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES8:.+]] = linalg.sqrt ins(%[[RES7]] : tensor<?x?x?xf32>) outs(%[[Out]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES9:.+]] = linalg.rsqrt ins(%[[RES8]] : tensor<?x?x?xf32>) outs(%[[Out]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES10:.+]] = linalg.square ins(%[[RES9]] : tensor<?x?x?xf32>) outs(%[[Out]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES11:.+]] = linalg.tanh ins(%[[RES10]] : tensor<?x?x?xf32>) outs(%[[Out]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES12:.+]] = linalg.erf ins(%[[RES11]] : tensor<?x?x?xf32>) outs(%[[Out]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+
+// -----
+
+#umap = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// Near-misses of the reciprocal (1.0 / x) and square (x * x) matchers must
+// stay linalg.generic.
+func.func @unary_ops_no_match(%A: tensor<?x?x?xf32>, %Out: tensor<?x?x?xf32>, %c: f32) -> tensor<?x?x?xf32> {
+  %cst_1 = arith.constant 1.0 : f32
+  %cst_2 = arith.constant 2.0 : f32
+  %0 = linalg.generic
+          {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+          ins(%A : tensor<?x?x?xf32>) outs(%Out : tensor<?x?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %v = arith.divf %cst_2, %in : f32
+    linalg.yield %v : f32
+  } -> tensor<?x?x?xf32>
+  %1 = linalg.generic
+          {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+          ins(%0 : tensor<?x?x?xf32>) outs(%Out : tensor<?x?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %v = arith.divf %in, %cst_1 : f32
+    linalg.yield %v : f32
+  } -> tensor<?x?x?xf32>
+  %2 = linalg.generic
+          {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+          ins(%1 : tensor<?x?x?xf32>) outs(%Out : tensor<?x?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %v = arith.mulf %in, %c : f32
+    linalg.yield %v : f32
+  } -> tensor<?x?x?xf32>
+  return %2 : tensor<?x?x?xf32>
+}
+
+// CHECK-LABEL: unary_ops_no_match
+// CHECK: linalg.generic
+// CHECK: arith.divf
+// CHECK: linalg.generic
+// CHECK: arith.divf
+// CHECK: linalg.generic
+// CHECK: arith.mulf

 // -----

diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize-elemwise-unary.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize-elemwise-unary.mlir
index 89a8baa453e90..3a2c7c9965287 100644
--- a/mlir/test/Dialect/Linalg/transform-op-specialize-elemwise-unary.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize-elemwise-unary.mlir
@@ -6,15 +6,113 @@ func.func @specialize_exp(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) ->
           {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
           ins(%arg0 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) {
   ^bb0(%in: f32, %out: f32):
-    %1 = math.exp %in : f32
-    linalg.yield %1 : f32
+    %v = math.exp %in : f32
+    linalg.yield %v : f32
   } -> tensor<?x?x?xf32>
-  return %0 : tensor<?x?x?xf32>
+  %1 = linalg.generic
+          {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+          ins(%0 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %v = math.log %in : f32
+    linalg.yield %v : f32
+  } -> tensor<?x?x?xf32>
+  %2 = linalg.generic
+          {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+          ins(%1 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %v = math.absf %in : f32
+    linalg.yield %v : f32
+  } -> tensor<?x?x?xf32>
+  %3 = linalg.generic
+          {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+          ins(%2 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %v = math.ceil %in : f32
+    linalg.yield %v : f32
+  } -> tensor<?x?x?xf32>
+  %4 = linalg.generic
+          {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+          ins(%3 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %v = math.floor %in : f32
+    linalg.yield %v : f32
+  } -> tensor<?x?x?xf32>
+  %5 = linalg.generic
+          {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+          ins(%4 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %v = arith.negf %in : f32
+    linalg.yield %v : f32
+  } -> tensor<?x?x?xf32>
+  %cst_1 = arith.constant 1.0 : f32
+  %6 = linalg.generic
+          {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+          ins(%5 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %v = arith.divf %cst_1, %in : f32
+    linalg.yield %v : f32
+  } -> tensor<?x?x?xf32>
+  %7 = linalg.generic
+          {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+          ins(%6 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %v = math.round %in : f32
+    linalg.yield %v : f32
+  } -> tensor<?x?x?xf32>
+  %8 = linalg.generic
+          {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+          ins(%7 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %v = math.sqrt %in : f32
+    linalg.yield %v : f32
+  } -> tensor<?x?x?xf32>
+  %9 = linalg.generic
+          {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+          ins(%8 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %v = math.rsqrt %in : f32
+    linalg.yield %v : f32
+  } -> tensor<?x?x?xf32>
+  %10 = linalg.generic
+          {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+          ins(%9 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %v = arith.mulf %in, %in : f32
+    linalg.yield %v : f32
+  } -> tensor<?x?x?xf32>
+  %11 = linalg.generic
+          {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+          ins(%10 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %v = math.tanh %in : f32
+    linalg.yield %v : f32
+  } -> tensor<?x?x?xf32>
+  %12 = linalg.generic
+          {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+          ins(%11 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %v = math.erf %in : f32
+    linalg.yield %v : f32
+  } -> tensor<?x?x?xf32>
+  return %12 : tensor<?x?x?xf32>
 }
+// Despite its name, @specialize_exp chains all 13 unary generics, exp to erf.
 // CHECK-LABEL: specialize_exp
 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
 // CHECK-NOT: linalg.generic
-// CHECK: linalg.exp ins(%[[ARG0]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES0:.+]] = linalg.exp ins(%[[ARG0]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES1:.+]] = linalg.log ins(%[[RES0]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES2:.+]] = linalg.abs ins(%[[RES1]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES3:.+]] = linalg.ceil ins(%[[RES2]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES4:.+]] = linalg.floor ins(%[[RES3]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES5:.+]] = linalg.negf ins(%[[RES4]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES6:.+]] = linalg.reciprocal ins(%[[RES5]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES7:.+]] = linalg.round ins(%[[RES6]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES8:.+]] = linalg.sqrt ins(%[[RES7]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES9:.+]] = linalg.rsqrt ins(%[[RES8]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES10:.+]] = linalg.square ins(%[[RES9]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES11:.+]] = linalg.tanh ins(%[[RES10]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RES12:.+]] = linalg.erf ins(%[[RES11]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>

 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {



More information about the Mlir-commits mailing list