[Mlir-commits] [mlir] [MLIR][Linalg] Specialize more binary elementwise ops (PR #192290)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Apr 15 11:06:44 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg
Author: Julian Oppermann (jopperm)
<details>
<summary>Changes</summary>
Extends the matching logic for `linalg.generic` ops that can be represented as a named op or as `linalg.elementwise` so that it covers all variants currently supported by `RegionBuilderHelper::buildBinaryFn`. Previously, only `add`, `sub`, `mul` and `div` were detected, and only for floating-point types.
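I combined the detection of unary and binary functions, which makes it tractable to morph operations such as the following one-input `linalg.generic`, whose body applies a binary op to a scalar defined outside the region: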
```mlir
#map = affine_map<(d0) -> (d0)>
// ...
%c123_i32 = arith.constant 123 : i32
%0 = linalg.generic
{indexing_maps = [#map, #map], iterator_types = ["parallel"]}
ins(%A : tensor<?xi32>) outs(%Out : tensor<?xi32>) {
^bb0(%in: i32, %out: i32):
%v = arith.addi %c123_i32 , %in : i32
linalg.yield %v : i32
} -> tensor<?xi32>
```
into
```mlir
#map = affine_map<(d0) -> ()>
#map1 = affine_map<(d0) -> (d0)>
// ...
%0 = linalg.elementwise kind=#linalg.elementwise_kind<add>
indexing_maps = [#map, #map1, #map1]
ins(%c123_i32, %A: i32, tensor<?xi32>) outs(%Out: tensor<?xi32>) -> tensor<?xi32>
```
---
The patch is 83.55 KiB and is truncated to 20.00 KiB below; the full version is at: https://github.com/llvm/llvm-project/pull/192290.diff
8 Files Affected:
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h (+5-1)
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp (+4-3)
- (modified) mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp (+171-81)
- (modified) mlir/test/Dialect/Linalg/linalg-morph-multi-step.mlir (+139-3)
- (modified) mlir/test/Dialect/Linalg/roundtrip-morphism-linalg-category-ops.mlir (+199)
- (modified) mlir/test/Dialect/Linalg/roundtrip-morphism-linalg-named-ops.mlir (+125-11)
- (modified) mlir/test/Dialect/Linalg/specialize-generic-ops.mlir (+540-13)
- (modified) mlir/test/Dialect/Linalg/transform-op-specialize-elemwise-binary.mlir (+223-16)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index c565781402e3c..3c7ebd8277dbd 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -148,7 +148,11 @@ bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp,
/// Checks whether `genericOp` is semantically equivalent to a single linalg
/// elementwise binary op e.g. linalg.sub.
-bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp);
+/// If `allowNonIdentityMaps` is true, operations with custom indexing maps are
+/// included in the check. Note that these operations can only be represented by
+/// the category op.
+bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp,
+ bool allowNonIdentityMaps = false);
/// Checks whether `genericOp` is semantically equivalent to a `linalg.fill`.
/// Supports two patterns:
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 2ba77cea8f16e..238bddcb3b2bd 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -279,9 +279,10 @@ bool linalg::isaElemwiseSingleUnaryOpInterface(linalg::GenericOp op,
return true;
}
-bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp op) {
- if (!isaElemwiseSingleUnaryOrBinaryOpInterface(
- op, 2, /*allowNonIdentityMaps=*/false))
+bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp op,
+ bool allowNonIdentityMaps) {
+ // All basic elemwise checks.
+ if (!isaElemwiseSingleUnaryOrBinaryOpInterface(op, 2, allowNonIdentityMaps))
return false;
// Check both inputs are used (elementwise).
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index a7cd57cf4ed9e..cc8a69cef8e5d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -28,13 +28,6 @@ namespace mlir {
#define DEBUG_TYPE "linalg-specialization"
-#define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP) \
- (rewriter.replaceOpWithNewOp<NEWOP>( \
- genericOp, \
- ValueRange{genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 1 : 0], \
- genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 0 : 1]}, \
- ValueRange{genericOp.getDpsInits()[0]}))
-
using namespace mlir;
using namespace mlir::linalg;
@@ -42,7 +35,7 @@ using namespace mlir::linalg;
// Specialize linalg generic to elementwise ops.
//===----------------------------------------------------------------------===//
-// Given a elementwise single binary linalg generic op, checks whether the
+// Given an elementwise single binary linalg generic op, checks whether the
// binary op accesses operands as swapped. e.g.
// this differentiates between a linalg-generic body that contains:
// ^bb0(%a: f32, %b: f32, %c : f32):
@@ -66,8 +59,39 @@ static bool areBinOpsSwapped(GenericOp genericOp) {
return swapped;
}
-// Attempt to specialize linalg.generic to named elementwise ops or
-// linalg.elementwise.
+// Given an elementwise single unary linalg generic op whose body operation is a
+// binary operation, check if one of its operands is a scalar value defined
+// outside the generic op, set its index, and return true. Otherwise return
+// false. The index is unique because the block argument is used at
+// least by one operand, as checked in `isaElemwiseSingleUnaryOpInterface`.
+//
+// Example:
+// %cst = arith.constant 3.14 : f32
+// %0 = linalg.generic { indexing_maps = [#mapA, #mapRes], ... }
+// ins(%A : tensor<?xf32>) outs(...) {
+// ^bb0(%a: f32, %out : f32):
+// %0 = arith.mulf %a, %cst : f32
+// linalg.yield %0: f32
+// } -> tensor<?xf32>
+// Here, the returned index is 1, and the generic op can be represented as
+// %0 = linalg.elementwise kind=#linalg.elementwise_kind<mul>
+// indexing_maps = [#mapA, affine_map<(d0) -> ()>, #mapRes]
+// ins(%A, %cst : tensor<?xf32>, f32) outs(...) -> tensor<?xf32>
+static bool findIndexOfScalarOperand(GenericOp genericOp, int &index) {
+ Block *body = genericOp.getBody();
+ Operation *op = &body->front();
+ for (auto [i, v] : llvm::enumerate(op->getOperands())) {
+ if (auto blockArg = dyn_cast<BlockArgument>(v);
+ blockArg && blockArg.getOwner() == body)
+ continue; // not an outside value...
+ index = i;
+ return true;
+ }
+ return false;
+}
+
+// Attempt to specialize unary or binary linalg.generic ops to named elementwise
+// ops or linalg.elementwise.
//
// Example:
// %0 = linalg.generic {
@@ -87,9 +111,16 @@ static bool areBinOpsSwapped(GenericOp genericOp) {
//
// Only the category op can carry non-identity indexing maps; these are
// transferred verbatim from the `genericOp`.
-static FailureOr<LinalgOp>
-specializeLinalgUnaryElementwise(RewriterBase &rewriter, GenericOp genericOp,
- bool emitCategoryOp) {
+//
+// In addition to the canonical forms used by the generalization path, this
+// function can handle the following variations:
+//
+// 1) Swapped operands in binary ops (see the `areBinOpsSwapped` helper)
+// 2) Unary generic ops with a binary body op (see the
+// `findIndexOfScalarOperand` helper)
+static FailureOr<LinalgOp> specializeLinalgElementwise(RewriterBase &rewriter,
+ GenericOp genericOp,
+ bool emitCategoryOp) {
bool hasNonIdentityMaps =
!llvm::all_of(genericOp.getIndexingMapsArray(),
[](AffineMap map) { return map.isIdentity(); });
@@ -100,62 +131,142 @@ specializeLinalgUnaryElementwise(RewriterBase &rewriter, GenericOp genericOp,
genericOp,
"non-identity indexing maps prevent specialization to named op");
+ // Classify the generic op.
+ bool isUnary = genericOp.getNumDpsInputs() == 1;
+ bool isBinary = genericOp.getNumDpsInputs() == 2;
+
+ // Will inspect the body operation to determine named op or elementwise kind.
+ Operation *op = &genericOp.getBody()->front();
+
+ // Detect variations from canonical forms.
+ bool hasSwappedOperands = isBinary && areBinOpsSwapped(genericOp);
+ int scalarOprIdx = -1;
+ bool hasScalarOperand = isUnary && op->getNumOperands() == 2 &&
+ findIndexOfScalarOperand(genericOp, scalarOprIdx);
+
// Helper to dispatch between named op and `linalg.elementwise`.
// Lambdas with explicit template parameter list are a C++20 feature, hence
// the dummy op object.
- auto replaceUnaryOp = [&](auto namedOp, ElementwiseKind kind) -> LinalgOp {
+ auto replaceOp = [&](auto namedOp, ElementwiseKind kind,
+ bool mayHoistScalarOperand = true) -> LinalgOp {
+ SmallVector<Value> inputs = genericOp.getDpsInputs();
+ if (hasSwappedOperands)
+ std::swap(inputs[0], inputs[1]);
+
LinalgOp newOp;
- if (!emitCategoryOp)
- newOp = decltype(namedOp)::create(
- rewriter, genericOp.getLoc(), genericOp.getDpsInputs(),
- genericOp.getDpsInits(), ArrayRef<NamedAttribute>{});
- else
+ if (!emitCategoryOp) {
+ using NamedOpTy = decltype(namedOp);
+ if constexpr (!std::is_null_pointer_v<NamedOpTy>)
+ newOp = NamedOpTy::create(rewriter, genericOp.getLoc(), inputs,
+ genericOp.getDpsInits(),
+ ArrayRef<NamedAttribute>{});
+ else
+ llvm_unreachable("Missing named op type");
+ } else {
+ SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
+ // Swap indexing maps, too.
+ if (hasSwappedOperands)
+ std::swap(indexingMaps[0], indexingMaps[1]);
+
+ // Represent unary generic op as a binary `linalg.elementwise` with a
+ // scalar operand and broadcasting map.
+ if (hasScalarOperand && mayHoistScalarOperand) {
+ // Adjust inputs and indexing maps accordingly.
+ inputs.insert(inputs.begin() + scalarOprIdx,
+ op->getOperand(scalarOprIdx));
+ auto scalarBroadcastMap =
+ AffineMap::get(genericOp.getNumParallelLoops(), /*symbolCount=*/0,
+ rewriter.getContext());
+ indexingMaps.insert(indexingMaps.begin() + scalarOprIdx,
+ scalarBroadcastMap);
+ }
newOp = ElementwiseOp::create(
- rewriter, genericOp.getLoc(), genericOp.getDpsInputs(),
- genericOp.getDpsInits(),
+ rewriter, genericOp.getLoc(), inputs, genericOp.getDpsInits(),
ElementwiseKindAttr::get(rewriter.getContext(), kind),
- genericOp.getIndexingMaps());
+ rewriter.getAffineMapArrayAttr(indexingMaps));
+ }
rewriter.replaceOp(genericOp, newOp);
return newOp;
};
- // Inspect body operation to determine named op or elementwise kind.
- Operation *op = &genericOp.getBody()->front();
+ if (isUnary) {
+ if (isa<math::ExpOp>(op))
+ return replaceOp(ExpOp{}, ElementwiseKind::exp);
+ if (isa<math::LogOp>(op))
+ return replaceOp(LogOp{}, ElementwiseKind::log);
+ if (isa<math::AbsFOp>(op))
+ return replaceOp(AbsOp{}, ElementwiseKind::abs);
+ if (isa<math::CeilOp>(op))
+ return replaceOp(CeilOp{}, ElementwiseKind::ceil);
+ if (isa<math::FloorOp>(op))
+ return replaceOp(FloorOp{}, ElementwiseKind::floor);
+ if (isa<arith::NegFOp>(op))
+ return replaceOp(NegFOp{}, ElementwiseKind::negf);
+ if (auto divOp = dyn_cast<arith::DivFOp>(op)) {
+ if (auto constOp = dyn_cast_if_present<arith::ConstantOp>(
+ divOp.getLhs().getDefiningOp()))
+ if (cast<FloatAttr>(constOp.getValue()).getValue().isExactlyValue(1.0))
+ return replaceOp(ReciprocalOp{}, ElementwiseKind::reciprocal,
+ /*mayHoistScalarOperand=*/false);
+ }
+ if (isa<math::RoundOp>(op))
+ return replaceOp(RoundOp{}, ElementwiseKind::round);
+ if (isa<math::SqrtOp>(op))
+ return replaceOp(SqrtOp{}, ElementwiseKind::sqrt);
+ if (isa<math::RsqrtOp>(op))
+ return replaceOp(RsqrtOp{}, ElementwiseKind::rsqrt);
+ if (auto mulOp = dyn_cast<arith::MulFOp>(op);
+ mulOp && mulOp.getLhs() == mulOp.getRhs())
+ return replaceOp(SquareOp{}, ElementwiseKind::square);
+ if (isa<math::TanhOp>(op))
+ return replaceOp(TanhOp{}, ElementwiseKind::tanh);
+ if (isa<math::ErfOp>(op))
+ return replaceOp(ErfOp{}, ElementwiseKind::erf);
+
+ // At this point, we exhaustively checked the available unary named ops. The
+ // 1-input generic op might be representable as a `linalg.elementwise` that
+ // broadcasts a scalar operand. But if we can't emit the category op or
+ // don't have a scalar operand, exit now.
+ if (!emitCategoryOp || !hasScalarOperand)
+ return rewriter.notifyMatchFailure(
+ genericOp, "unary elementwise operation cannot be specialized to "
+ "named or category op");
+ }
- if (isa<math::ExpOp>(op))
- return replaceUnaryOp(ExpOp{}, ElementwiseKind::exp);
- if (isa<math::LogOp>(op))
- return replaceUnaryOp(LogOp{}, ElementwiseKind::log);
- if (isa<math::AbsFOp>(op))
- return replaceUnaryOp(AbsOp{}, ElementwiseKind::abs);
- if (isa<math::CeilOp>(op))
- return replaceUnaryOp(CeilOp{}, ElementwiseKind::ceil);
- if (isa<math::FloorOp>(op))
- return replaceUnaryOp(FloorOp{}, ElementwiseKind::floor);
- if (isa<arith::NegFOp>(op))
- return replaceUnaryOp(NegFOp{}, ElementwiseKind::negf);
- if (auto divOp = dyn_cast<arith::DivFOp>(op)) {
- if (auto constOp = dyn_cast_if_present<arith::ConstantOp>(
- divOp.getLhs().getDefiningOp()))
- if (cast<FloatAttr>(constOp.getValue()).getValue().isExactlyValue(1.0))
- return replaceUnaryOp(ReciprocalOp{}, ElementwiseKind::reciprocal);
+ // Boolean-typed `linalg.add` and `linalg.mul` require special handling.
+ bool allBool = llvm::all_of(op->getOperands(),
+ [](Value v) { return v.getType().isInteger(1); });
+
+ if (isa<arith::AddIOp, arith::AddFOp, complex::AddOp>(op) ||
+ (allBool && isa<arith::OrIOp>(op)))
+ return replaceOp(AddOp{}, ElementwiseKind::add);
+ if (isa<arith::SubIOp, arith::SubFOp, complex::SubOp>(op))
+ return replaceOp(SubOp{}, ElementwiseKind::sub);
+ if (isa<arith::MulIOp, arith::MulFOp, complex::MulOp>(op) ||
+ (allBool && isa<arith::AndIOp>(op)))
+ return replaceOp(MulOp{}, ElementwiseKind::mul);
+ if (isa<arith::DivSIOp, arith::DivFOp, complex::DivOp>(op))
+ return replaceOp(DivOp{}, ElementwiseKind::div);
+ if (isa<arith::DivUIOp>(op))
+ return replaceOp(DivUnsignedOp{}, ElementwiseKind::div_unsigned);
+ if (isa<arith::MaxSIOp, arith::MaximumFOp>(op))
+ return replaceOp(MaxOp{}, ElementwiseKind::max_signed);
+ if (isa<arith::MinSIOp, arith::MinimumFOp>(op))
+ return replaceOp(MinOp{}, ElementwiseKind::min_signed);
+ if (emitCategoryOp) {
+ // No named ops for unsigned maximum/minimum.
+ if (isa<arith::MaxUIOp>(op))
+ return replaceOp(nullptr, ElementwiseKind::max_unsigned);
+ if (isa<arith::MinUIOp>(op))
+ return replaceOp(nullptr, ElementwiseKind::min_unsigned);
}
- if (isa<math::RoundOp>(op))
- return replaceUnaryOp(RoundOp{}, ElementwiseKind::round);
- if (isa<math::SqrtOp>(op))
- return replaceUnaryOp(SqrtOp{}, ElementwiseKind::sqrt);
- if (isa<math::RsqrtOp>(op))
- return replaceUnaryOp(RsqrtOp{}, ElementwiseKind::rsqrt);
- if (auto mulOp = dyn_cast<arith::MulFOp>(op);
- mulOp && mulOp.getLhs() == mulOp.getRhs())
- return replaceUnaryOp(SquareOp{}, ElementwiseKind::square);
- if (isa<math::TanhOp>(op))
- return replaceUnaryOp(TanhOp{}, ElementwiseKind::tanh);
- if (isa<math::ErfOp>(op))
- return replaceUnaryOp(ErfOp{}, ElementwiseKind::erf);
+ if (isa<math::PowFOp>(op))
+ return replaceOp(PowFOp{}, ElementwiseKind::powf);
- return failure();
+ return rewriter.notifyMatchFailure(
+ genericOp,
+ "elementwise operation cannot be specialized to named or category op");
}
//===----------------------------------------------------------------------===//
@@ -580,10 +691,11 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(
RewriterBase &rewriter, GenericOp genericOp,
const GenericOpSpecializationOptions &options) {
- // Unary elementwise - e.g. exp
- if (isaElemwiseSingleUnaryOpInterface(genericOp, options.emitCategoryOps)) {
- return specializeLinalgUnaryElementwise(rewriter, genericOp,
- options.emitCategoryOps);
+ // Elementwise - e.g. exp, add
+ if (isaElemwiseSingleUnaryOpInterface(genericOp, options.emitCategoryOps) ||
+ isaElemwiseSingleBinaryOpInterface(genericOp, options.emitCategoryOps)) {
+ return specializeLinalgElementwise(rewriter, genericOp,
+ options.emitCategoryOps);
}
// Contraction - e.g. matmul
@@ -636,28 +748,6 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(
return namedOp;
}
- // Elementwise Binary
- if (isaElemwiseSingleBinaryOpInterface(genericOp)) {
- bool swap = areBinOpsSwapped(genericOp);
- Operation *op = &genericOp.getBody()->front();
- if (isa<arith::AddFOp>(op)) {
- LinalgOp namedOp = REPLACE_BINARY_OP(AddOp, swap);
- return namedOp;
- }
- if (isa<arith::SubFOp>(op)) {
- LinalgOp namedOp = REPLACE_BINARY_OP(SubOp, swap);
- return namedOp;
- }
- if (isa<arith::MulFOp>(op)) {
- LinalgOp namedOp = REPLACE_BINARY_OP(MulOp, swap);
- return namedOp;
- }
- if (isa<arith::DivFOp>(op)) {
- LinalgOp namedOp = REPLACE_BINARY_OP(DivOp, swap);
- return namedOp;
- }
- }
-
// Convolution - e.g. *conv/pooling*
if (isaConvolutionOpInterface(genericOp))
return specializeLinalgConvolutions(rewriter, genericOp);
diff --git a/mlir/test/Dialect/Linalg/linalg-morph-multi-step.mlir b/mlir/test/Dialect/Linalg/linalg-morph-multi-step.mlir
index 7bad1b7a44d92..0a0ddbcd85a0a 100644
--- a/mlir/test/Dialect/Linalg/linalg-morph-multi-step.mlir
+++ b/mlir/test/Dialect/Linalg/linalg-morph-multi-step.mlir
@@ -1,6 +1,8 @@
-// RUN: mlir-opt %s -linalg-morph-ops=named-to-generic | FileCheck %s --check-prefix=NAMED_TO_GENERIC
-// RUN: mlir-opt %s -linalg-morph-ops=named-to-generic | mlir-opt -linalg-morph-ops=generic-to-named | \
-// RUN: FileCheck %s --check-prefix=ROUND_TRIP
+// RUN: mlir-opt %s -linalg-morph-ops=named-to-generic -split-input-file | \
+// RUN: FileCheck %s --check-prefix=ALL,NAMED_TO_GENERIC
+// RUN: mlir-opt %s -linalg-morph-ops=named-to-generic -split-input-file | \
+// RUN: mlir-opt -linalg-morph-ops=generic-to-named -split-input-file | \
+// RUN: FileCheck %s --check-prefix=ALL,ROUND_TRIP
func.func @unary_ops(%A : tensor<16x8xf32>, %B : tensor<16x8xf32>) -> tensor<16x8xf32> {
%exp = linalg.exp ins(%A : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>
@@ -19,6 +21,8 @@ func.func @unary_ops(%A : tensor<16x8xf32>, %B : tensor<16x8xf32>) -> tensor<16
return %erf : tensor<16x8xf32>
}
+// ALL-LABEL: unary_ops
+
// NAMED_TO_GENERIC-COUNT-13: linalg.generic
// NAMED_TO_GENERIC-NOT: linalg.exp
// NAMED_TO_GENERIC-NOT: linalg.log
@@ -48,3 +52,135 @@ func.func @unary_ops(%A : tensor<16x8xf32>, %B : tensor<16x8xf32>) -> tensor<16
// ROUND_TRIP: linalg.tanh
// ROUND_TRIP: linalg.erf
// ROUND_TRIP-NOT: linalg.generic
+
+// -----
+
+func.func @binary_ops_int(%A: tensor<?x?xi32>, %B: tensor<?x?xi32>,
+ %Out: tensor<?x?xi32>) -> tensor<?x?xi32> {
+ %0 = linalg.add ins(%A, %B : tensor<?x?xi32>, tensor<?x?xi32>)
+ outs(%Out : tensor<?x?xi32>) -> tensor<?x?xi32>
+ %1 = linalg.sub ins(%0, %B : tensor<?x?xi32>, tensor<?x?xi32>)
+ outs(%Out : tensor<?x?xi32>) -> tensor<?x?xi32>
+ %2 = linalg.mul ins(%1, %B : tensor<?x?xi32>, tensor<?x?xi32>)
+ outs(%Out : tensor<?x?xi32>) -> tensor<?x?xi32>
+ %3 = linalg.div ins(%2, %B : tensor<?x?xi32>, tensor<?x?xi32>)
+ outs(%Out : tensor<?x?xi32>) -> tensor<?x?xi32>
+ %4 = linalg.div_unsigned ins(%3, %B : tensor<?x?xi32>, tensor<?x?xi32>)
+ outs(%Out : tensor<?x?xi32>) -> tensor<?x?xi32>
+ %5 = linalg.max ins(%4, %B : tensor<?x?xi32>, tensor<?x?xi32>)
+ outs(%Out : tensor<?x?xi32>) -> tensor<?x?xi32>
+ %6 = linalg.min ins(%5, %B : tensor<?x?xi32>, tensor<?x?xi32>)
+ outs(%Out : tensor<?x?xi32>) -> tensor<?x?xi32>
+ return %6 : tensor<?x?xi32>
+}
+
+// ALL-LABEL: binary_ops_int
+
+// NAMED_TO_GENERIC-COUNT-7: linalg.generic
+// NAMED_TO_GENERIC-NOT: linalg.add
+// NAMED_TO_GENERIC-NOT: linalg.sub
+// NAMED_TO_GENERIC-NOT: linalg.mul
+// NAMED_TO_GENERIC-NOT: linalg.div
+// NAMED_TO_GENERIC-NOT: linalg.div_unsigned
+// NAMED_TO_GENERIC-NOT: linalg.max
+// NAMED_TO_GENERIC-NOT: linalg.min
+
+// ROUND_TRIP: linalg.add
+// ROUND_TRIP: linalg.sub
+// ROUND_TRIP: linalg.mul
+// ROUND_TRIP: linalg.div
+// ROUND_TRIP: linalg.div_unsigned
+// ROUND_TRIP: linalg.max
+// ROUND_TRIP: linalg.min
+// ROUND_TRIP-NOT: linalg.generic
+
+// -----
+
+func.func @binary_ops_float(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+ %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.add ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %1 = linalg.sub ins(%0, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %2 = linalg.mul ins(%1, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %3 = linalg.div ins(%2, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %4 = linalg.max ins(%3, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%Out : tensor<?x?xf32>...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/192290
More information about the Mlir-commits mailing list