[Mlir-commits] [mlir] 018e048 - [MLIR][Linalg] Generic to category specialization for unary elementwise ops (#187217)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Apr 2 01:50:25 PDT 2026


Author: Julian Oppermann
Date: 2026-04-02T10:50:21+02:00
New Revision: 018e048daf9dcc5e21f1372b85e9f0a6e4597c64

URL: https://github.com/llvm/llvm-project/commit/018e048daf9dcc5e21f1372b85e9f0a6e4597c64
DIFF: https://github.com/llvm/llvm-project/commit/018e048daf9dcc5e21f1372b85e9f0a6e4597c64.diff

LOG: [MLIR][Linalg] Generic to category specialization for unary elementwise ops (#187217)

Handle specialization of `linalg.generic` ops representing a unary
elementwise computation to the `linalg.elementwise` category op. This
implements a previously absent path in the linalg morphism.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
    mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
    mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
    mlir/test/Dialect/Linalg/roundtrip-morphism-linalg-category-ops.mlir
    mlir/test/Dialect/Linalg/specialize-generic-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index e068c8a5002fc..c565781402e3c 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -135,11 +135,16 @@ std::optional<SmallVector<int64_t>>
 isaTransposeOpInterface(GenericOp genericOp);
 
 /// Checks whether a given `genericOp` is semantically equivalent to a single
-/// linalgelementwise unary op. e.g. linalg.exp.
+/// linalg elementwise unary op, e.g. `linalg.exp` or
+/// `linalg.elementwise kind=#linalg.elementwise_kind<exp>`.
+/// If `allowNonIdentityMaps` is true, the requirement that every indexing map
+/// be an identity is dropped; such ops can only be represented by the category
+/// op, not by a named op.
 /// A linalg.generic body could be a series of unary elementwise ops e.g.
 /// `exp(neg(x))`, such as formed by linalg op fusion. Here we restrict it to
 /// detecting cases where body is is a single computation op.
-bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp);
+bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp,
+                                       bool allowNonIdentityMaps = false);
 
 /// Checks whether `genericOp` is semantically equivalent to a single linalg
 /// elementwise binary op e.g. linalg.sub.

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 1bddc286b3dd4..2ba77cea8f16e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -227,16 +227,19 @@ linalg::isaTransposeOpInterface(GenericOp op) {
 //===----------------------------------------------------------------------===//
 // Elementwise Single Unary/Binary-OpInterface implementation
 //===----------------------------------------------------------------------===//
-static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op,
-                                                      unsigned arity) {
+static bool
+isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op, unsigned arity,
+                                          bool allowNonIdentityMaps) {
   // Check all loops are parallel.
   if (!op.isAllParallelLoops() || op.getNumLoops() < 1)
     return false;
 
-  // Check there are arity-inputs, 1-output and all are identity-maps.
+  // Check there are arity-inputs, 1-output and all are identity-maps (unless
+  // `allowNonIdentityMaps` is set, which skips the indexing-map check).
   if (op.getNumDpsInputs() != arity || op.getNumDpsInits() != 1 ||
-      !llvm::all_of(op.getIndexingMapsArray(),
-                    [](AffineMap map) { return map.isIdentity(); }))
+      (!allowNonIdentityMaps &&
+       !llvm::all_of(op.getIndexingMapsArray(),
+                     [](AffineMap map) { return map.isIdentity(); })))
     return false;
 
   // Init should not be referenced for elementwise operations.
@@ -264,19 +267,21 @@ static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op,
            yieldOp->getOperand(0).getDefiningOp() != oper);
 }
 
-bool linalg::isaElemwiseSingleUnaryOpInterface(linalg::GenericOp op) {
+bool linalg::isaElemwiseSingleUnaryOpInterface(linalg::GenericOp op,
+                                               bool allowNonIdentityMaps) {
   // All basic elemwise checks.
-  if (!isaElemwiseSingleUnaryOrBinaryOpInterface(op, 1))
+  if (!isaElemwiseSingleUnaryOrBinaryOpInterface(op, 1, allowNonIdentityMaps))
     return false;
 
-  // Check input is actully used.
+  // Check input is actually used.
   if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)))
     return false;
   return true;
 }
 
 bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp op) {
-  if (!isaElemwiseSingleUnaryOrBinaryOpInterface(op, 2))
+  if (!isaElemwiseSingleUnaryOrBinaryOpInterface(
+          op, 2, /*allowNonIdentityMaps=*/false))
     return false;
 
   // Check both inputs are used (elementwise).

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 60b18fb2e8d93..a764d1705e85c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -35,14 +35,13 @@ namespace mlir {
                  genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 0 : 1]},           \
       ValueRange{genericOp.getDpsInits()[0]}))
 
-#define REPLACE_UNARY_OP(NEWOP)                                                \
-  static_cast<LinalgOp>(rewriter.replaceOpWithNewOp<NEWOP>(                    \
-      genericOp, ValueRange{genericOp.getDpsInputs()[0]},                      \
-      ValueRange{genericOp.getDpsInits()[0]}))
-
 using namespace mlir;
 using namespace mlir::linalg;
 
+//===----------------------------------------------------------------------===//
+// Specialize linalg generic to elementwise ops.
+//===----------------------------------------------------------------------===//
+
 // Given a elementwise single binary linalg generic op, checks whether the
 // binary op accesses operands as swapped. e.g.
 // this differentiates between a linalg-generic body that contains:
@@ -67,6 +66,98 @@ static bool areBinOpsSwapped(GenericOp genericOp) {
   return swapped;
 }
 
+// Attempt to specialize linalg.generic to named elementwise ops or
+// linalg.elementwise.
+//
+// Example:
+//   %0 = linalg.generic {
+//       indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+//                        affine_map<(d0, d1) -> (d0, d1)>],
+//       iterator_types = ["parallel", "parallel"]
+//     } ins(%In : tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) {
+//     ^bb0(%in: f32, %out: f32):
+//       %1 = math.exp %in : f32
+//       linalg.yield %1 : f32
+//     } -> tensor<?x?xf32>
+//
+// is specialized to either
+//   linalg.exp ins(...) outs(...) -> ...
+// or
+//   linalg.elementwise kind=#linalg.elementwise_kind<exp> ...
+//
+// Only the category op can carry non-identity indexing maps; these are
+// transferred verbatim from the `genericOp`.
+static FailureOr<LinalgOp>
+specializeLinalgUnaryElementwise(RewriterBase &rewriter, GenericOp genericOp,
+                                 bool emitCategoryOp) {
+  bool hasNonIdentityMaps =
+      !llvm::all_of(genericOp.getIndexingMapsArray(),
+                    [](AffineMap map) { return map.isIdentity(); });
+
+  // Early exit: Named ops cannot carry user-defined maps.
+  if (hasNonIdentityMaps && !emitCategoryOp)
+    return rewriter.notifyMatchFailure(
+        genericOp,
+        "non-identity indexing maps prevent specialization to named op");
+
+  // Helper to dispatch between named op and `linalg.elementwise`.
+  // Lambdas with explicit template parameter list are a C++20 feature, hence
+  // the dummy op object.
+  auto replaceUnaryOp = [&](auto namedOp, ElementwiseKind kind) -> LinalgOp {
+    LinalgOp newOp;
+    if (!emitCategoryOp)
+      newOp = decltype(namedOp)::create(
+          rewriter, genericOp.getLoc(), genericOp.getDpsInputs(),
+          genericOp.getDpsInits(), ArrayRef<NamedAttribute>{});
+    else
+      newOp = ElementwiseOp::create(
+          rewriter, genericOp.getLoc(), genericOp.getDpsInputs(),
+          genericOp.getDpsInits(),
+          ElementwiseKindAttr::get(rewriter.getContext(), kind),
+          genericOp.getIndexingMaps());
+
+    rewriter.replaceOp(genericOp, newOp);
+    return newOp;
+  };
+
+  // Inspect body operation to determine named op or elementwise kind.
+  Operation *op = &genericOp.getBody()->front();
+
+  if (isa<math::ExpOp>(op))
+    return replaceUnaryOp(ExpOp{}, ElementwiseKind::exp);
+  if (isa<math::LogOp>(op))
+    return replaceUnaryOp(LogOp{}, ElementwiseKind::log);
+  if (isa<math::AbsFOp>(op))
+    return replaceUnaryOp(AbsOp{}, ElementwiseKind::abs);
+  if (isa<math::CeilOp>(op))
+    return replaceUnaryOp(CeilOp{}, ElementwiseKind::ceil);
+  if (isa<math::FloorOp>(op))
+    return replaceUnaryOp(FloorOp{}, ElementwiseKind::floor);
+  if (isa<arith::NegFOp>(op))
+    return replaceUnaryOp(NegFOp{}, ElementwiseKind::negf);
+  // Reciprocal (1.0 / x); the constant may be a non-FloatAttr (splat) value.
+  if (auto divOp = dyn_cast<arith::DivFOp>(op)) {
+    auto constOp = divOp.getLhs().getDefiningOp<arith::ConstantOp>();
+    auto attr = constOp ? dyn_cast<FloatAttr>(constOp.getValue()) : FloatAttr();
+    if (attr && attr.getValue().isExactlyValue(1.0))
+      return replaceUnaryOp(ReciprocalOp{}, ElementwiseKind::reciprocal);
+  }
+  if (isa<math::RoundOp>(op))
+    return replaceUnaryOp(RoundOp{}, ElementwiseKind::round);
+  if (isa<math::SqrtOp>(op))
+    return replaceUnaryOp(SqrtOp{}, ElementwiseKind::sqrt);
+  if (isa<math::RsqrtOp>(op))
+    return replaceUnaryOp(RsqrtOp{}, ElementwiseKind::rsqrt);
+  if (auto mulOp = dyn_cast<arith::MulFOp>(op);
+      mulOp && mulOp.getLhs() == mulOp.getRhs())
+    return replaceUnaryOp(SquareOp{}, ElementwiseKind::square);
+  if (isa<math::TanhOp>(op))
+    return replaceUnaryOp(TanhOp{}, ElementwiseKind::tanh);
+  if (isa<math::ErfOp>(op))
+    return replaceUnaryOp(ErfOp{}, ElementwiseKind::erf);
+  return failure();
+}
+
 //===----------------------------------------------------------------------===//
 // Specialize linalg generic to matmul variants.
 //===----------------------------------------------------------------------===//
@@ -455,6 +546,12 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
 FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(
     RewriterBase &rewriter, GenericOp genericOp,
     const GenericOpSpecializationOptions &options) {
+  // Unary elementwise - e.g. exp. Non-identity maps only with category ops.
+  if (isaElemwiseSingleUnaryOpInterface(genericOp, options.emitCategoryOps)) {
+    return specializeLinalgUnaryElementwise(rewriter, genericOp,
+                                            options.emitCategoryOps);
+  }
+
   // Contraction - e.g. matmul
   if (isaContractionOpInterface(genericOp)) {
     return specializeLinalgContractions(rewriter, genericOp,
@@ -505,42 +602,6 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(
     return namedOp;
   }
 
-  // Elementwise Unary
-  if (isaElemwiseSingleUnaryOpInterface(genericOp)) {
-    Operation *op = &genericOp.getBody()->front();
-    if (isa<math::ExpOp>(op))
-      return REPLACE_UNARY_OP(ExpOp);
-    if (isa<math::LogOp>(op))
-      return REPLACE_UNARY_OP(LogOp);
-    if (isa<math::AbsFOp>(op))
-      return REPLACE_UNARY_OP(AbsOp);
-    if (isa<math::CeilOp>(op))
-      return REPLACE_UNARY_OP(CeilOp);
-    if (isa<math::FloorOp>(op))
-      return REPLACE_UNARY_OP(FloorOp);
-    if (isa<arith::NegFOp>(op))
-      return REPLACE_UNARY_OP(NegFOp);
-    if (auto divOp = dyn_cast<arith::DivFOp>(op)) {
-      if (auto constOp = dyn_cast_if_present<arith::ConstantOp>(
-              divOp.getLhs().getDefiningOp()))
-        if (cast<FloatAttr>(constOp.getValue()).getValue().isExactlyValue(1.0))
-          return REPLACE_UNARY_OP(ReciprocalOp);
-    }
-    if (isa<math::RoundOp>(op))
-      return REPLACE_UNARY_OP(RoundOp);
-    if (isa<math::SqrtOp>(op))
-      return REPLACE_UNARY_OP(SqrtOp);
-    if (isa<math::RsqrtOp>(op))
-      return REPLACE_UNARY_OP(RsqrtOp);
-    if (auto mulOp = dyn_cast<arith::MulFOp>(op);
-        mulOp && mulOp.getLhs() == mulOp.getRhs())
-      return REPLACE_UNARY_OP(SquareOp);
-    if (isa<math::TanhOp>(op))
-      return REPLACE_UNARY_OP(TanhOp);
-    if (isa<math::ErfOp>(op))
-      return REPLACE_UNARY_OP(ErfOp);
-  }
-
   // Elementwise Binary
   if (isaElemwiseSingleBinaryOpInterface(genericOp)) {
     bool swap = areBinOpsSwapped(genericOp);

diff  --git a/mlir/test/Dialect/Linalg/roundtrip-morphism-linalg-category-ops.mlir b/mlir/test/Dialect/Linalg/roundtrip-morphism-linalg-category-ops.mlir
index d5e49a866eaec..e7bdad8e56883 100644
--- a/mlir/test/Dialect/Linalg/roundtrip-morphism-linalg-category-ops.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip-morphism-linalg-category-ops.mlir
@@ -5,6 +5,101 @@
 // RUN: | mlir-opt -split-input-file -linalg-morph-ops=generic-to-category \
 // RUN: | FileCheck %s
 
+func.func @unary_ops(%A: memref<7x14x21xf32>, %Out: memref<7x14x21xf32>) {
+  linalg.elementwise kind=#linalg.elementwise_kind<exp>
+    ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+  linalg.elementwise kind=#linalg.elementwise_kind<log>
+    ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+  linalg.elementwise kind=#linalg.elementwise_kind<abs>
+    ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+  linalg.elementwise kind=#linalg.elementwise_kind<ceil>
+    ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+  linalg.elementwise kind=#linalg.elementwise_kind<floor>
+    ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+  linalg.elementwise kind=#linalg.elementwise_kind<negf>
+    ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+  linalg.elementwise kind=#linalg.elementwise_kind<reciprocal>
+    ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+  linalg.elementwise kind=#linalg.elementwise_kind<round>
+    ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+  linalg.elementwise kind=#linalg.elementwise_kind<sqrt>
+    ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+  linalg.elementwise kind=#linalg.elementwise_kind<rsqrt>
+    ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+  linalg.elementwise kind=#linalg.elementwise_kind<square>
+    ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+  linalg.elementwise kind=#linalg.elementwise_kind<tanh>
+    ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+  linalg.elementwise kind=#linalg.elementwise_kind<erf>
+    ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+  return
+}
+
+// CHECK-LABEL: unary_ops
+// CHECK-SAME: %[[A:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<exp>
+// CHECK-SAME: ins(%[[A]] : memref<7x14x21xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<log>
+// CHECK-SAME: ins(%[[A]] : memref<7x14x21xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<abs>
+// CHECK-SAME: ins(%[[A]] : memref<7x14x21xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<ceil>
+// CHECK-SAME: ins(%[[A]] : memref<7x14x21xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<floor>
+// CHECK-SAME: ins(%[[A]] : memref<7x14x21xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<negf>
+// CHECK-SAME: ins(%[[A]] : memref<7x14x21xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<reciprocal>
+// CHECK-SAME: ins(%[[A]] : memref<7x14x21xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<round>
+// CHECK-SAME: ins(%[[A]] : memref<7x14x21xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<sqrt>
+// CHECK-SAME: ins(%[[A]] : memref<7x14x21xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<rsqrt>
+// CHECK-SAME: ins(%[[A]] : memref<7x14x21xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<square>
+// CHECK-SAME: ins(%[[A]] : memref<7x14x21xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<tanh>
+// CHECK-SAME: ins(%[[A]] : memref<7x14x21xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<erf>
+// CHECK-SAME: ins(%[[A]] : memref<7x14x21xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
+
+// -----
+// Non-identity (broadcast + transpose) indexing maps survive the round trip.
+func.func @unary_ops_non_identity(%A: tensor<?xf32>, %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.elementwise
+    kind=#linalg.elementwise_kind<log>
+    indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>]
+    ins(%A : tensor<?xf32>) outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK-DAG: #[[MAP_BC:.+]] = affine_map<(d0, d1) -> (d1)>
+// CHECK-DAG: #[[MAP_TP:.+]] = affine_map<(d0, d1) -> (d1, d0)>
+// CHECK: unary_ops_non_identity
+// CHECK-SAME: %[[A:.+]]: tensor<?xf32>, %[[OUT:.+]]: tensor<?x?xf32>)
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<log>
+// CHECK-SAME: indexing_maps = [#[[MAP_BC]], #[[MAP_TP]]]
+// CHECK-SAME: ins(%[[A]] : tensor<?xf32>)
+// CHECK-SAME: outs(%[[OUT]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// -----
+
 #map = affine_map<(d0, d1, d2) -> (d0, d2)>
 #map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
 #map2 = affine_map<(d0, d1, d2) -> (d0, d1)>

diff  --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
index 9cc24dd07ae47..37dec828687bd 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -147,8 +147,76 @@ func.func @unary_ops(%A: tensor<?x?x?xf32>, %Out: tensor<?x?x?xf32>) -> tensor<?
 // NAMED-SAME: ins(%[[RES11]] : tensor<?x?x?xf32>)
 // NAMED-SAME: outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
 
-// Not supported yet.
-// CATEGORY: linalg.generic
+// CATEGORY: %[[RES0:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<exp>
+// CATEGORY-SAME: ins(%[[A]] : tensor<?x?x?xf32>)
+// CATEGORY-SAME: outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CATEGORY: %[[RES1:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<log>
+// CATEGORY-SAME: ins(%[[RES0]] : tensor<?x?x?xf32>)
+// CATEGORY-SAME: outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CATEGORY: %[[RES2:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<abs>
+// CATEGORY-SAME: ins(%[[RES1]] : tensor<?x?x?xf32>)
+// CATEGORY-SAME: outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CATEGORY: %[[RES3:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<ceil>
+// CATEGORY-SAME: ins(%[[RES2]] : tensor<?x?x?xf32>)
+// CATEGORY-SAME: outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CATEGORY: %[[RES4:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<floor>
+// CATEGORY-SAME: ins(%[[RES3]] : tensor<?x?x?xf32>)
+// CATEGORY-SAME: outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CATEGORY: %[[RES5:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<negf>
+// CATEGORY-SAME: ins(%[[RES4]] : tensor<?x?x?xf32>)
+// CATEGORY-SAME: outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CATEGORY: %[[RES6:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<reciprocal>
+// CATEGORY-SAME: ins(%[[RES5]] : tensor<?x?x?xf32>)
+// CATEGORY-SAME: outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CATEGORY: %[[RES7:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<round>
+// CATEGORY-SAME: ins(%[[RES6]] : tensor<?x?x?xf32>)
+// CATEGORY-SAME: outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CATEGORY: %[[RES8:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<sqrt>
+// CATEGORY-SAME: ins(%[[RES7]] : tensor<?x?x?xf32>)
+// CATEGORY-SAME: outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CATEGORY: %[[RES9:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<rsqrt>
+// CATEGORY-SAME: ins(%[[RES8]] : tensor<?x?x?xf32>)
+// CATEGORY-SAME: outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CATEGORY: %[[RES10:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<square>
+// CATEGORY-SAME: ins(%[[RES9]] : tensor<?x?x?xf32>)
+// CATEGORY-SAME: outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CATEGORY: %[[RES11:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<tanh>
+// CATEGORY-SAME: ins(%[[RES10]] : tensor<?x?x?xf32>)
+// CATEGORY-SAME: outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CATEGORY: %[[RES12:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<erf>
+// CATEGORY-SAME: ins(%[[RES11]] : tensor<?x?x?xf32>)
+// CATEGORY-SAME: outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+
+// -----
+
+func.func @unary_ops_non_identity(%A: tensor<?xf32>, %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic
+    {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%A : tensor<?xf32>)
+    outs(%Out : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):  
+    %v = math.exp %in : f32
+    linalg.yield %v : f32
+  } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// In both pipelines the non-identity maps must still be present afterwards.
+// ALL-DAG: #[[MAP_BC:.+]] = affine_map<(d0, d1) -> (d1)>
+// ALL-DAG: #[[MAP_TP:.+]] = affine_map<(d0, d1) -> (d1, d0)>
+// ALL: unary_ops_non_identity
+// ALL-SAME: %[[A:.+]]: tensor<?xf32>, %[[OUT:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// Named ops cannot carry user-defined indexing maps -> expect no change.
+// NAMED-NOT: linalg.exp
+// NAMED: linalg.generic
+
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.elementwise kind=#linalg.elementwise_kind<exp>
+// CATEGORY-SAME: indexing_maps = [#[[MAP_BC]], #[[MAP_TP]]]
+// CATEGORY-SAME: ins(%[[A]] : tensor<?xf32>)
+// CATEGORY-SAME: outs(%[[OUT]] : tensor<?x?xf32>) -> tensor<?x?xf32>
 
 // -----
 


        


More information about the Mlir-commits mailing list