[Mlir-commits] [mlir] 018e048 - [MLIR][Linalg] Generic to category specialization for unary elementwise ops (#187217)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Apr 2 01:50:25 PDT 2026
Author: Julian Oppermann
Date: 2026-04-02T10:50:21+02:00
New Revision: 018e048daf9dcc5e21f1372b85e9f0a6e4597c64
URL: https://github.com/llvm/llvm-project/commit/018e048daf9dcc5e21f1372b85e9f0a6e4597c64
DIFF: https://github.com/llvm/llvm-project/commit/018e048daf9dcc5e21f1372b85e9f0a6e4597c64.diff
LOG: [MLIR][Linalg] Generic to category specialization for unary elementwise ops (#187217)
Handle specialization of `linalg.generic` ops representing a unary
elementwise computation to the `linalg.elementwise` category op. This
implements a previously absent path in the linalg morphism.
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
mlir/test/Dialect/Linalg/roundtrip-morphism-linalg-category-ops.mlir
mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index e068c8a5002fc..c565781402e3c 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -135,11 +135,16 @@ std::optional<SmallVector<int64_t>>
isaTransposeOpInterface(GenericOp genericOp);
/// Checks whether a given `genericOp` is semantically equivalent to a single
-/// linalgelementwise unary op. e.g. linalg.exp.
+/// linalg elementwise unary op, e.g. `linalg.exp` or
+/// `linalg.elementwise kind=#linalg.elementwise_kind<exp>`.
+/// If `allowNonIdentityMaps` is true, operations with custom indexing maps are
+/// included in the check. Note that these operations can only be represented by
+/// the category op.
/// A linalg.generic body could be a series of unary elementwise ops e.g.
/// `exp(neg(x))`, such as formed by linalg op fusion. Here we restrict it to
/// detecting cases where body is is a single computation op.
-bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp);
+bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp,
+ bool allowNonIdentityMaps = false);
/// Checks whether `genericOp` is semantically equivalent to a single linalg
/// elementwise binary op e.g. linalg.sub.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 1bddc286b3dd4..2ba77cea8f16e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -227,16 +227,19 @@ linalg::isaTransposeOpInterface(GenericOp op) {
//===----------------------------------------------------------------------===//
// Elementwise Single Unary/Binary-OpInterface implementation
//===----------------------------------------------------------------------===//
-static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op,
- unsigned arity) {
+static bool
+isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op, unsigned arity,
+ bool allowNonIdentityMaps) {
// Check all loops are parallel.
if (!op.isAllParallelLoops() || op.getNumLoops() < 1)
return false;
- // Check there are arity-inputs, 1-output and all are identity-maps.
+ // Check there are arity-inputs, 1-output and all are identity-maps (unless
+ // requested otherwise).
if (op.getNumDpsInputs() != arity || op.getNumDpsInits() != 1 ||
- !llvm::all_of(op.getIndexingMapsArray(),
- [](AffineMap map) { return map.isIdentity(); }))
+ (!allowNonIdentityMaps &&
+ !llvm::all_of(op.getIndexingMapsArray(),
+ [](AffineMap map) { return map.isIdentity(); })))
return false;
// Init should not be referenced for elementwise operations.
@@ -264,19 +267,21 @@ static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op,
yieldOp->getOperand(0).getDefiningOp() != oper);
}
-bool linalg::isaElemwiseSingleUnaryOpInterface(linalg::GenericOp op) {
+bool linalg::isaElemwiseSingleUnaryOpInterface(linalg::GenericOp op,
+ bool allowNonIdentityMaps) {
// All basic elemwise checks.
- if (!isaElemwiseSingleUnaryOrBinaryOpInterface(op, 1))
+ if (!isaElemwiseSingleUnaryOrBinaryOpInterface(op, 1, allowNonIdentityMaps))
return false;
- // Check input is actully used.
+ // Check input is actually used.
if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)))
return false;
return true;
}
bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp op) {
- if (!isaElemwiseSingleUnaryOrBinaryOpInterface(op, 2))
+ if (!isaElemwiseSingleUnaryOrBinaryOpInterface(
+ op, 2, /*allowNonIdentityMaps=*/false))
return false;
// Check both inputs are used (elementwise).
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 60b18fb2e8d93..a764d1705e85c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -35,14 +35,13 @@ namespace mlir {
genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 0 : 1]}, \
ValueRange{genericOp.getDpsInits()[0]}))
-#define REPLACE_UNARY_OP(NEWOP) \
- static_cast<LinalgOp>(rewriter.replaceOpWithNewOp<NEWOP>( \
- genericOp, ValueRange{genericOp.getDpsInputs()[0]}, \
- ValueRange{genericOp.getDpsInits()[0]}))
-
using namespace mlir;
using namespace mlir::linalg;
+//===----------------------------------------------------------------------===//
+// Specialize linalg generic to elementwise ops.
+//===----------------------------------------------------------------------===//
+
// Given a elementwise single binary linalg generic op, checks whether the
// binary op accesses operands as swapped. e.g.
// this
// differentiates between a linalg-generic body that contains:
@@ -67,6 +66,98 @@ static bool areBinOpsSwapped(GenericOp genericOp) {
return swapped;
}
+// Attempt to specialize a unary elementwise `linalg.generic` to either a named
+// elementwise op (e.g. `linalg.exp`) or the `linalg.elementwise` category op.
+//
+// Example:
+//   %0 = linalg.generic {
+//     indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+//                      affine_map<(d0, d1) -> (d0, d1)>],
+//     iterator_types = ["parallel", "parallel"]
+//   } ins(%In : tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) {
+//   ^bb0(%in: f32, %out: f32):
+//     %1 = math.exp %in : f32
+//     linalg.yield %1 : f32
+//   } -> tensor<?x?xf32>
+//
+// is specialized to either
+//   linalg.exp ins(...) outs(...) -> ...
+// or
+//   linalg.elementwise kind=#linalg.elementwise_kind<exp> ...
+//
+// Only the category op can carry non-identity indexing maps; these are
+// transferred verbatim from the `genericOp`.
+//
+// `specializeGenericOp` only calls this after
+// `isaElemwiseSingleUnaryOpInterface` succeeded, so the computation is read
+// from the first operation of the body.
+//
+// Parameters:
+//   rewriter       - used to create the replacement op and replace `genericOp`.
+//   genericOp      - the unary elementwise generic to specialize.
+//   emitCategoryOp - if false, emit the named op; if true, emit
+//                    `linalg.elementwise` with the generic's indexing maps.
+// Returns the newly created op (with `genericOp` replaced), or a match failure
+// (with `genericOp` untouched) if the maps or the body op are not supported.
+static FailureOr<LinalgOp>
+specializeLinalgUnaryElementwise(RewriterBase &rewriter, GenericOp genericOp,
+                                 bool emitCategoryOp) {
+  bool hasNonIdentityMaps =
+      !llvm::all_of(genericOp.getIndexingMapsArray(),
+                    [](AffineMap map) { return map.isIdentity(); });
+
+  // Early exit: Named ops cannot carry user-defined maps.
+  if (hasNonIdentityMaps && !emitCategoryOp)
+    return rewriter.notifyMatchFailure(
+        genericOp,
+        "non-identity indexing maps prevent specialization to named op");
+
+  // Helper to dispatch between named op and `linalg.elementwise`.
+  // Lambdas with explicit template parameter list are a C++20 feature, hence
+  // the dummy op object, which only serves to carry the named op type.
+  auto replaceUnaryOp = [&](auto namedOp, ElementwiseKind kind) -> LinalgOp {
+    LinalgOp newOp;
+    if (!emitCategoryOp)
+      newOp = decltype(namedOp)::create(
+          rewriter, genericOp.getLoc(), genericOp.getDpsInputs(),
+          genericOp.getDpsInits(), ArrayRef<NamedAttribute>{});
+    else
+      newOp = ElementwiseOp::create(
+          rewriter, genericOp.getLoc(), genericOp.getDpsInputs(),
+          genericOp.getDpsInits(),
+          ElementwiseKindAttr::get(rewriter.getContext(), kind),
+          genericOp.getIndexingMaps());
+
+    rewriter.replaceOp(genericOp, newOp);
+    return newOp;
+  };
+
+  // Inspect body operation to determine named op or elementwise kind.
+  Operation *op = &genericOp.getBody()->front();
+
+  if (isa<math::ExpOp>(op))
+    return replaceUnaryOp(ExpOp{}, ElementwiseKind::exp);
+  if (isa<math::LogOp>(op))
+    return replaceUnaryOp(LogOp{}, ElementwiseKind::log);
+  if (isa<math::AbsFOp>(op))
+    return replaceUnaryOp(AbsOp{}, ElementwiseKind::abs);
+  if (isa<math::CeilOp>(op))
+    return replaceUnaryOp(CeilOp{}, ElementwiseKind::ceil);
+  if (isa<math::FloorOp>(op))
+    return replaceUnaryOp(FloorOp{}, ElementwiseKind::floor);
+  if (isa<arith::NegFOp>(op))
+    return replaceUnaryOp(NegFOp{}, ElementwiseKind::negf);
+  if (auto divOp = dyn_cast<arith::DivFOp>(op)) {
+    // `1.0 / x` is a reciprocal. Only a scalar `FloatAttr` constant is
+    // recognized: for shaped (e.g. vector) element types the constant holds an
+    // elements attribute, which must not be force-cast to `FloatAttr`.
+    if (auto constOp = dyn_cast_if_present<arith::ConstantOp>(
+            divOp.getLhs().getDefiningOp()))
+      if (auto one = dyn_cast<FloatAttr>(constOp.getValue());
+          one && one.getValue().isExactlyValue(1.0))
+        return replaceUnaryOp(ReciprocalOp{}, ElementwiseKind::reciprocal);
+  }
+  if (isa<math::RoundOp>(op))
+    return replaceUnaryOp(RoundOp{}, ElementwiseKind::round);
+  if (isa<math::SqrtOp>(op))
+    return replaceUnaryOp(SqrtOp{}, ElementwiseKind::sqrt);
+  if (isa<math::RsqrtOp>(op))
+    return replaceUnaryOp(RsqrtOp{}, ElementwiseKind::rsqrt);
+  if (auto mulOp = dyn_cast<arith::MulFOp>(op);
+      mulOp && mulOp.getLhs() == mulOp.getRhs())
+    return replaceUnaryOp(SquareOp{}, ElementwiseKind::square);
+  if (isa<math::TanhOp>(op))
+    return replaceUnaryOp(TanhOp{}, ElementwiseKind::tanh);
+  if (isa<math::ErfOp>(op))
+    return replaceUnaryOp(ErfOp{}, ElementwiseKind::erf);
+
+  // Report why nothing matched, consistent with the early exit above.
+  return rewriter.notifyMatchFailure(
+      genericOp, "unsupported unary elementwise body operation");
+}
+
//===----------------------------------------------------------------------===//
// Specialize linalg generic to matmul variants.
//===----------------------------------------------------------------------===//
@@ -455,6 +546,12 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(
RewriterBase &rewriter, GenericOp genericOp,
const GenericOpSpecializationOptions &options) {
+ // Unary elementwise - e.g. exp
+ if (isaElemwiseSingleUnaryOpInterface(genericOp, options.emitCategoryOps)) {
+ return specializeLinalgUnaryElementwise(rewriter, genericOp,
+ options.emitCategoryOps);
+ }
+
// Contraction - e.g. matmul
if (isaContractionOpInterface(genericOp)) {
return specializeLinalgContractions(rewriter, genericOp,
@@ -505,42 +602,6 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(
return namedOp;
}
- // Elementwise Unary
- if (isaElemwiseSingleUnaryOpInterface(genericOp)) {
- Operation *op = &genericOp.getBody()->front();
- if (isa<math::ExpOp>(op))
- return REPLACE_UNARY_OP(ExpOp);
- if (isa<math::LogOp>(op))
- return REPLACE_UNARY_OP(LogOp);
- if (isa<math::AbsFOp>(op))
- return REPLACE_UNARY_OP(AbsOp);
- if (isa<math::CeilOp>(op))
- return REPLACE_UNARY_OP(CeilOp);
- if (isa<math::FloorOp>(op))
- return REPLACE_UNARY_OP(FloorOp);
- if (isa<arith::NegFOp>(op))
- return REPLACE_UNARY_OP(NegFOp);
- if (auto divOp = dyn_cast<arith::DivFOp>(op)) {
- if (auto constOp = dyn_cast_if_present<arith::ConstantOp>(
- divOp.getLhs().getDefiningOp()))
- if (cast<FloatAttr>(constOp.getValue()).getValue().isExactlyValue(1.0))
- return REPLACE_UNARY_OP(ReciprocalOp);
- }
- if (isa<math::RoundOp>(op))
- return REPLACE_UNARY_OP(RoundOp);
- if (isa<math::SqrtOp>(op))
- return REPLACE_UNARY_OP(SqrtOp);
- if (isa<math::RsqrtOp>(op))
- return REPLACE_UNARY_OP(RsqrtOp);
- if (auto mulOp = dyn_cast<arith::MulFOp>(op);
- mulOp && mulOp.getLhs() == mulOp.getRhs())
- return REPLACE_UNARY_OP(SquareOp);
- if (isa<math::TanhOp>(op))
- return REPLACE_UNARY_OP(TanhOp);
- if (isa<math::ErfOp>(op))
- return REPLACE_UNARY_OP(ErfOp);
- }
-
// Elementwise Binary
if (isaElemwiseSingleBinaryOpInterface(genericOp)) {
bool swap = areBinOpsSwapped(genericOp);
diff --git a/mlir/test/Dialect/Linalg/roundtrip-morphism-linalg-category-ops.mlir b/mlir/test/Dialect/Linalg/roundtrip-morphism-linalg-category-ops.mlir
index d5e49a866eaec..e7bdad8e56883 100644
--- a/mlir/test/Dialect/Linalg/roundtrip-morphism-linalg-category-ops.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip-morphism-linalg-category-ops.mlir
@@ -5,6 +5,101 @@
// RUN: | mlir-opt -split-input-file -linalg-morph-ops=generic-to-category \
// RUN: | FileCheck %s
+// Applies each unary `linalg.elementwise` kind that the generic-to-category
+// path recognizes to the same operands. The FileCheck directives below expect
+// every op to come back as `linalg.elementwise` with the same kind, operands
+// and types, and no `linalg.generic` left behind.
+func.func @unary_ops(%A: memref<7x14x21xf32>, %Out: memref<7x14x21xf32>) {
+ linalg.elementwise kind=#linalg.elementwise_kind<exp>
+ ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+ linalg.elementwise kind=#linalg.elementwise_kind<log>
+ ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+ linalg.elementwise kind=#linalg.elementwise_kind<abs>
+ ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+ linalg.elementwise kind=#linalg.elementwise_kind<ceil>
+ ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+ linalg.elementwise kind=#linalg.elementwise_kind<floor>
+ ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+ linalg.elementwise kind=#linalg.elementwise_kind<negf>
+ ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+ linalg.elementwise kind=#linalg.elementwise_kind<reciprocal>
+ ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+ linalg.elementwise kind=#linalg.elementwise_kind<round>
+ ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+ linalg.elementwise kind=#linalg.elementwise_kind<sqrt>
+ ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+ linalg.elementwise kind=#linalg.elementwise_kind<rsqrt>
+ ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+ linalg.elementwise kind=#linalg.elementwise_kind<square>
+ ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+ linalg.elementwise kind=#linalg.elementwise_kind<tanh>
+ ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+ linalg.elementwise kind=#linalg.elementwise_kind<erf>
+ ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
+ return
+}
+
+// CHECK-LABEL: unary_ops
+// CHECK-SAME: %[[A:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<exp>
+// CHECK-SAME: ins(%[[A]] : memref<7x14x21xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<log>
+// CHECK-SAME: ins(%[[A]] : memref<7x14x21xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<abs>
+// CHECK-SAME: ins(%[[A]] : memref<7x14x21xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<ceil>
+// CHECK-SAME: ins(%[[A]] : memref<7x14x21xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<floor>
+// CHECK-SAME: ins(%[[A]] : memref<7x14x21xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<negf>
+// CHECK-SAME: ins(%[[A]] : memref<7x14x21xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<reciprocal>
+// CHECK-SAME: ins(%[[A]] : memref<7x14x21xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<round>
+// CHECK-SAME: ins(%[[A]] : memref<7x14x21xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<sqrt>
+// CHECK-SAME: ins(%[[A]] : memref<7x14x21xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<rsqrt>
+// CHECK-SAME: ins(%[[A]] : memref<7x14x21xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<square>
+// CHECK-SAME: ins(%[[A]] : memref<7x14x21xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<tanh>
+// CHECK-SAME: ins(%[[A]] : memref<7x14x21xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<erf>
+// CHECK-SAME: ins(%[[A]] : memref<7x14x21xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
+
+// -----
+
+// Unary `linalg.elementwise` with non-identity indexing maps (`%A` broadcast
+// along d0, transposed output). There is no named-op equivalent, so the
+// FileCheck directives below expect the maps to be carried through verbatim
+// on the category op.
+func.func @unary_ops_non_identity(%A: tensor<?xf32>, %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.elementwise
+ kind=#linalg.elementwise_kind<log>
+ indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>]
+ ins(%A : tensor<?xf32>) outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK-DAG: #[[MAP_BC:.+]] = affine_map<(d0, d1) -> (d1)>
+// CHECK-DAG: #[[MAP_TP:.+]] = affine_map<(d0, d1) -> (d1, d0)>
+// CHECK: unary_ops_non_identity
+// CHECK-SAME: %[[A:.+]]: tensor<?xf32>, %[[OUT:.+]]: tensor<?x?xf32>)
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<log>
+// CHECK-SAME: indexing_maps = [#[[MAP_BC]], #[[MAP_TP]]]
+// CHECK-SAME: ins(%[[A]] : tensor<?xf32>)
+// CHECK-SAME: outs(%[[OUT]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// -----
+
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
index 9cc24dd07ae47..37dec828687bd 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -147,8 +147,76 @@ func.func @unary_ops(%A: tensor<?x?x?xf32>, %Out: tensor<?x?x?xf32>) -> tensor<?
// NAMED-SAME: ins(%[[RES11]] : tensor<?x?x?xf32>)
// NAMED-SAME: outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
-// Not supported yet.
-// CATEGORY: linalg.generic
+// CATEGORY: %[[RES0:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<exp>
+// CATEGORY-SAME: ins(%[[A]] : tensor<?x?x?xf32>)
+// CATEGORY-SAME: outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CATEGORY: %[[RES1:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<log>
+// CATEGORY-SAME: ins(%[[RES0]] : tensor<?x?x?xf32>)
+// CATEGORY-SAME: outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CATEGORY: %[[RES2:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<abs>
+// CATEGORY-SAME: ins(%[[RES1]] : tensor<?x?x?xf32>)
+// CATEGORY-SAME: outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CATEGORY: %[[RES3:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<ceil>
+// CATEGORY-SAME: ins(%[[RES2]] : tensor<?x?x?xf32>)
+// CATEGORY-SAME: outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CATEGORY: %[[RES4:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<floor>
+// CATEGORY-SAME: ins(%[[RES3]] : tensor<?x?x?xf32>)
+// CATEGORY-SAME: outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CATEGORY: %[[RES5:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<negf>
+// CATEGORY-SAME: ins(%[[RES4]] : tensor<?x?x?xf32>)
+// CATEGORY-SAME: outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CATEGORY: %[[RES6:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<reciprocal>
+// CATEGORY-SAME: ins(%[[RES5]] : tensor<?x?x?xf32>)
+// CATEGORY-SAME: outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CATEGORY: %[[RES7:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<round>
+// CATEGORY-SAME: ins(%[[RES6]] : tensor<?x?x?xf32>)
+// CATEGORY-SAME: outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CATEGORY: %[[RES8:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<sqrt>
+// CATEGORY-SAME: ins(%[[RES7]] : tensor<?x?x?xf32>)
+// CATEGORY-SAME: outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CATEGORY: %[[RES9:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<rsqrt>
+// CATEGORY-SAME: ins(%[[RES8]] : tensor<?x?x?xf32>)
+// CATEGORY-SAME: outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CATEGORY: %[[RES10:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<square>
+// CATEGORY-SAME: ins(%[[RES9]] : tensor<?x?x?xf32>)
+// CATEGORY-SAME: outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CATEGORY: %[[RES11:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<tanh>
+// CATEGORY-SAME: ins(%[[RES10]] : tensor<?x?x?xf32>)
+// CATEGORY-SAME: outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CATEGORY: %[[RES12:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<erf>
+// CATEGORY-SAME: ins(%[[RES11]] : tensor<?x?x?xf32>)
+// CATEGORY-SAME: outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+
+// -----
+
+// Unary `math.exp` generic with non-identity indexing maps (`%A` broadcast
+// along d0, transposed output). Only the category op can carry these maps, so
+// named-op specialization must leave the generic unchanged, while category
+// specialization must produce `linalg.elementwise` with the same maps.
+func.func @unary_ops_non_identity(%A: tensor<?xf32>, %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%A : tensor<?xf32>)
+ outs(%Out : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %v = math.exp %in : f32
+ linalg.yield %v : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+
+// ALL-DAG: #[[MAP_BC:.+]] = affine_map<(d0, d1) -> (d1)>
+// ALL-DAG: #[[MAP_TP:.+]] = affine_map<(d0, d1) -> (d1, d0)>
+// ALL: unary_ops_non_identity
+// ALL-SAME: %[[A:.+]]: tensor<?xf32>, %[[OUT:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// Named ops cannot carry user-defined indexing maps -> expect no change.
+// NAMED-NOT: linalg.exp
+// NAMED: linalg.generic
+
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.elementwise kind=#linalg.elementwise_kind<exp>
+// CATEGORY-SAME: indexing_maps = [#[[MAP_BC]], #[[MAP_TP]]]
+// CATEGORY-SAME: ins(%[[A]] : tensor<?xf32>)
+// CATEGORY-SAME: outs(%[[OUT]] : tensor<?x?xf32>) -> tensor<?x?xf32>
// -----
More information about the Mlir-commits
mailing list