[Mlir-commits] [mlir] [MLIR][Linalg] Specialize more binary elementwise ops (PR #192290)
Julian Oppermann
llvmlistbot at llvm.org
Fri May 22 01:06:10 PDT 2026
https://github.com/jopperm updated https://github.com/llvm/llvm-project/pull/192290
>From 825ef6c0de39dc8fa2868538f40e34fce50bb49d Mon Sep 17 00:00:00 2001
From: Julian Oppermann <julian.oppermann at intel.com>
Date: Wed, 15 Apr 2026 09:30:13 -0700
Subject: [PATCH] [MLIR][Linalg] Specialize more binary elementwise ops
---
.../mlir/Dialect/Linalg/IR/LinalgInterfaces.h | 6 +-
.../Dialect/Linalg/IR/LinalgInterfaces.cpp | 7 +-
.../Dialect/Linalg/Transforms/Specialize.cpp | 252 +++++---
.../Linalg/linalg-morph-multi-step.mlir | 142 ++++-
...oundtrip-morphism-linalg-category-ops.mlir | 199 +++++++
.../roundtrip-morphism-linalg-named-ops.mlir | 136 ++++-
.../Linalg/specialize-generic-ops.mlir | 553 +++++++++++++++++-
...ansform-op-specialize-elemwise-binary.mlir | 239 +++++++-
8 files changed, 1406 insertions(+), 128 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index c565781402e3c..3c7ebd8277dbd 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -148,7 +148,11 @@ bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp,
/// Checks whether `genericOp` is semantically equivalent to a single linalg
/// elementwise binary op e.g. linalg.sub.
-bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp);
+/// If `allowNonIdentityMaps` is true, operations with custom indexing maps are
+/// included in the check. Note that these operations can only be represented by
+/// the category op.
+bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp,
+ bool allowNonIdentityMaps = false);
/// Checks whether `genericOp` is semantically equivalent to a `linalg.fill`.
/// Supports two patterns:
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 2ba77cea8f16e..238bddcb3b2bd 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -279,9 +279,10 @@ bool linalg::isaElemwiseSingleUnaryOpInterface(linalg::GenericOp op,
return true;
}
-bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp op) {
- if (!isaElemwiseSingleUnaryOrBinaryOpInterface(
- op, 2, /*allowNonIdentityMaps=*/false))
+bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp op,
+ bool allowNonIdentityMaps) {
+ // All basic elemwise checks.
+ if (!isaElemwiseSingleUnaryOrBinaryOpInterface(op, 2, allowNonIdentityMaps))
return false;
// Check both inputs are used (elementwise).
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index a7cd57cf4ed9e..cc8a69cef8e5d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -28,13 +28,6 @@ namespace mlir {
#define DEBUG_TYPE "linalg-specialization"
-#define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP) \
- (rewriter.replaceOpWithNewOp<NEWOP>( \
- genericOp, \
- ValueRange{genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 1 : 0], \
- genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 0 : 1]}, \
- ValueRange{genericOp.getDpsInits()[0]}))
-
using namespace mlir;
using namespace mlir::linalg;
@@ -42,7 +35,7 @@ using namespace mlir::linalg;
// Specialize linalg generic to elementwise ops.
//===----------------------------------------------------------------------===//
-// Given a elementwise single binary linalg generic op, checks whether the
+// Given an elementwise single binary linalg generic op, checks whether the
// binary op accesses operands as swapped. e.g.
// this differentiates between a linalg-generic body that contains:
// ^bb0(%a: f32, %b: f32, %c : f32):
@@ -66,8 +59,39 @@ static bool areBinOpsSwapped(GenericOp genericOp) {
return swapped;
}
-// Attempt to specialize linalg.generic to named elementwise ops or
-// linalg.elementwise.
+// Given an elementwise single unary linalg generic op whose body operation is a
+// binary operation, check whether one of its operands is a scalar value defined
+// outside the generic op. If so, store its index in `index` and return true;
+// otherwise return false. The index is unique because the input block argument
+// is used by at least one operand (see `isaElemwiseSingleUnaryOpInterface`).
+//
+// Example:
+// %cst = arith.constant 3.14 : f32
+// %0 = linalg.generic { indexing_maps = [#mapA, #mapRes], ... }
+// ins(%A : tensor<?xf32>) outs(...) {
+// ^bb0(%a: f32, %out : f32):
+// %0 = arith.mulf %a, %cst : f32
+// linalg.yield %0: f32
+// } -> tensor<?xf32>
+// Here, the returned index is 1, and the generic op can be represented as
+// %0 = linalg.elementwise kind=#linalg.elementwise_kind<mul>
+// indexing_maps = [#mapA, affine_map<(d0) -> ()>, #mapRes]
+// ins(%A, %cst : tensor<?xf32>, f32) outs(...) -> tensor<?xf32>
+static bool findIndexOfScalarOperand(GenericOp genericOp, int &index) {
+ Block *body = genericOp.getBody();
+ Operation *op = &body->front();
+ for (auto [i, v] : llvm::enumerate(op->getOperands())) {
+ if (auto blockArg = dyn_cast<BlockArgument>(v);
+ blockArg && blockArg.getOwner() == body)
+ continue; // Block argument of the body, i.e. not an outside value.
+ index = i;
+ return true;
+ }
+ return false;
+}
+
+// Attempt to specialize unary or binary linalg.generic ops to named elementwise
+// ops or linalg.elementwise.
//
// Example:
// %0 = linalg.generic {
@@ -87,9 +111,16 @@ static bool areBinOpsSwapped(GenericOp genericOp) {
//
// Only the category op can carry non-identity indexing maps; these are
// transferred verbatim from the `genericOp`.
-static FailureOr<LinalgOp>
-specializeLinalgUnaryElementwise(RewriterBase &rewriter, GenericOp genericOp,
- bool emitCategoryOp) {
+//
+// In addition to the canonical forms used by the generalization path, this
+// function can handle the following variations:
+//
+// 1) Swapped operands in binary ops (see the `areBinOpsSwapped` helper)
+// 2) Unary generic ops with a binary body op (see the
+// `findIndexOfScalarOperand` helper)
+static FailureOr<LinalgOp> specializeLinalgElementwise(RewriterBase &rewriter,
+ GenericOp genericOp,
+ bool emitCategoryOp) {
bool hasNonIdentityMaps =
!llvm::all_of(genericOp.getIndexingMapsArray(),
[](AffineMap map) { return map.isIdentity(); });
@@ -100,62 +131,142 @@ specializeLinalgUnaryElementwise(RewriterBase &rewriter, GenericOp genericOp,
genericOp,
"non-identity indexing maps prevent specialization to named op");
+ // Classify the generic op.
+ bool isUnary = genericOp.getNumDpsInputs() == 1;
+ bool isBinary = genericOp.getNumDpsInputs() == 2;
+
+ // Will inspect the body operation to determine named op or elementwise kind.
+ Operation *op = &genericOp.getBody()->front();
+
+ // Detect variations from canonical forms.
+ bool hasSwappedOperands = isBinary && areBinOpsSwapped(genericOp);
+ int scalarOprIdx = -1;
+ bool hasScalarOperand = isUnary && op->getNumOperands() == 2 &&
+ findIndexOfScalarOperand(genericOp, scalarOprIdx);
+
// Helper to dispatch between named op and `linalg.elementwise`.
// Lambdas with explicit template parameter list are a C++20 feature, hence
// the dummy op object.
- auto replaceUnaryOp = [&](auto namedOp, ElementwiseKind kind) -> LinalgOp {
+ auto replaceOp = [&](auto namedOp, ElementwiseKind kind,
+ bool mayHoistScalarOperand = true) -> LinalgOp {
+ SmallVector<Value> inputs = genericOp.getDpsInputs();
+ if (hasSwappedOperands)
+ std::swap(inputs[0], inputs[1]);
+
LinalgOp newOp;
- if (!emitCategoryOp)
- newOp = decltype(namedOp)::create(
- rewriter, genericOp.getLoc(), genericOp.getDpsInputs(),
- genericOp.getDpsInits(), ArrayRef<NamedAttribute>{});
- else
+ if (!emitCategoryOp) {
+ using NamedOpTy = decltype(namedOp);
+ if constexpr (!std::is_null_pointer_v<NamedOpTy>)
+ newOp = NamedOpTy::create(rewriter, genericOp.getLoc(), inputs,
+ genericOp.getDpsInits(),
+ ArrayRef<NamedAttribute>{});
+ else
+ llvm_unreachable("Missing named op type");
+ } else {
+ SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
+ // Swap indexing maps, too.
+ if (hasSwappedOperands)
+ std::swap(indexingMaps[0], indexingMaps[1]);
+
+ // Represent unary generic op as a binary `linalg.elementwise` with a
+ // scalar operand and broadcasting map.
+ if (hasScalarOperand && mayHoistScalarOperand) {
+ // Adjust inputs and indexing maps accordingly.
+ inputs.insert(inputs.begin() + scalarOprIdx,
+ op->getOperand(scalarOprIdx));
+ auto scalarBroadcastMap =
+ AffineMap::get(genericOp.getNumParallelLoops(), /*symbolCount=*/0,
+ rewriter.getContext());
+ indexingMaps.insert(indexingMaps.begin() + scalarOprIdx,
+ scalarBroadcastMap);
+ }
newOp = ElementwiseOp::create(
- rewriter, genericOp.getLoc(), genericOp.getDpsInputs(),
- genericOp.getDpsInits(),
+ rewriter, genericOp.getLoc(), inputs, genericOp.getDpsInits(),
ElementwiseKindAttr::get(rewriter.getContext(), kind),
- genericOp.getIndexingMaps());
+ rewriter.getAffineMapArrayAttr(indexingMaps));
+ }
rewriter.replaceOp(genericOp, newOp);
return newOp;
};
- // Inspect body operation to determine named op or elementwise kind.
- Operation *op = &genericOp.getBody()->front();
+ if (isUnary) {
+ if (isa<math::ExpOp>(op))
+ return replaceOp(ExpOp{}, ElementwiseKind::exp);
+ if (isa<math::LogOp>(op))
+ return replaceOp(LogOp{}, ElementwiseKind::log);
+ if (isa<math::AbsFOp>(op))
+ return replaceOp(AbsOp{}, ElementwiseKind::abs);
+ if (isa<math::CeilOp>(op))
+ return replaceOp(CeilOp{}, ElementwiseKind::ceil);
+ if (isa<math::FloorOp>(op))
+ return replaceOp(FloorOp{}, ElementwiseKind::floor);
+ if (isa<arith::NegFOp>(op))
+ return replaceOp(NegFOp{}, ElementwiseKind::negf);
+ if (auto divOp = dyn_cast<arith::DivFOp>(op)) {
+ if (auto constOp = dyn_cast_if_present<arith::ConstantOp>(
+ divOp.getLhs().getDefiningOp()))
+ if (cast<FloatAttr>(constOp.getValue()).getValue().isExactlyValue(1.0))
+ return replaceOp(ReciprocalOp{}, ElementwiseKind::reciprocal,
+ /*mayHoistScalarOperand=*/false);
+ }
+ if (isa<math::RoundOp>(op))
+ return replaceOp(RoundOp{}, ElementwiseKind::round);
+ if (isa<math::SqrtOp>(op))
+ return replaceOp(SqrtOp{}, ElementwiseKind::sqrt);
+ if (isa<math::RsqrtOp>(op))
+ return replaceOp(RsqrtOp{}, ElementwiseKind::rsqrt);
+ if (auto mulOp = dyn_cast<arith::MulFOp>(op);
+ mulOp && mulOp.getLhs() == mulOp.getRhs())
+ return replaceOp(SquareOp{}, ElementwiseKind::square);
+ if (isa<math::TanhOp>(op))
+ return replaceOp(TanhOp{}, ElementwiseKind::tanh);
+ if (isa<math::ErfOp>(op))
+ return replaceOp(ErfOp{}, ElementwiseKind::erf);
+
+ // At this point, we exhaustively checked the available unary named ops. The
+ // 1-input generic op might be representable as a `linalg.elementwise` that
+ // broadcasts a scalar operand. But if we can't emit the category op or
+ // don't have a scalar operand, exit now.
+ if (!emitCategoryOp || !hasScalarOperand)
+ return rewriter.notifyMatchFailure(
+ genericOp, "unary elementwise operation cannot be specialized to "
+ "named or category op");
+ }
- if (isa<math::ExpOp>(op))
- return replaceUnaryOp(ExpOp{}, ElementwiseKind::exp);
- if (isa<math::LogOp>(op))
- return replaceUnaryOp(LogOp{}, ElementwiseKind::log);
- if (isa<math::AbsFOp>(op))
- return replaceUnaryOp(AbsOp{}, ElementwiseKind::abs);
- if (isa<math::CeilOp>(op))
- return replaceUnaryOp(CeilOp{}, ElementwiseKind::ceil);
- if (isa<math::FloorOp>(op))
- return replaceUnaryOp(FloorOp{}, ElementwiseKind::floor);
- if (isa<arith::NegFOp>(op))
- return replaceUnaryOp(NegFOp{}, ElementwiseKind::negf);
- if (auto divOp = dyn_cast<arith::DivFOp>(op)) {
- if (auto constOp = dyn_cast_if_present<arith::ConstantOp>(
- divOp.getLhs().getDefiningOp()))
- if (cast<FloatAttr>(constOp.getValue()).getValue().isExactlyValue(1.0))
- return replaceUnaryOp(ReciprocalOp{}, ElementwiseKind::reciprocal);
+ // Boolean-typed `linalg.add` and `linalg.mul` require special handling.
+ bool allBool = llvm::all_of(op->getOperands(),
+ [](Value v) { return v.getType().isInteger(1); });
+
+ if (isa<arith::AddIOp, arith::AddFOp, complex::AddOp>(op) ||
+ (allBool && isa<arith::OrIOp>(op)))
+ return replaceOp(AddOp{}, ElementwiseKind::add);
+ if (isa<arith::SubIOp, arith::SubFOp, complex::SubOp>(op))
+ return replaceOp(SubOp{}, ElementwiseKind::sub);
+ if (isa<arith::MulIOp, arith::MulFOp, complex::MulOp>(op) ||
+ (allBool && isa<arith::AndIOp>(op)))
+ return replaceOp(MulOp{}, ElementwiseKind::mul);
+ if (isa<arith::DivSIOp, arith::DivFOp, complex::DivOp>(op))
+ return replaceOp(DivOp{}, ElementwiseKind::div);
+ if (isa<arith::DivUIOp>(op))
+ return replaceOp(DivUnsignedOp{}, ElementwiseKind::div_unsigned);
+ if (isa<arith::MaxSIOp, arith::MaximumFOp>(op))
+ return replaceOp(MaxOp{}, ElementwiseKind::max_signed);
+ if (isa<arith::MinSIOp, arith::MinimumFOp>(op))
+ return replaceOp(MinOp{}, ElementwiseKind::min_signed);
+ if (emitCategoryOp) {
+ // No named ops for unsigned maximum/minimum.
+ if (isa<arith::MaxUIOp>(op))
+ return replaceOp(nullptr, ElementwiseKind::max_unsigned);
+ if (isa<arith::MinUIOp>(op))
+ return replaceOp(nullptr, ElementwiseKind::min_unsigned);
}
- if (isa<math::RoundOp>(op))
- return replaceUnaryOp(RoundOp{}, ElementwiseKind::round);
- if (isa<math::SqrtOp>(op))
- return replaceUnaryOp(SqrtOp{}, ElementwiseKind::sqrt);
- if (isa<math::RsqrtOp>(op))
- return replaceUnaryOp(RsqrtOp{}, ElementwiseKind::rsqrt);
- if (auto mulOp = dyn_cast<arith::MulFOp>(op);
- mulOp && mulOp.getLhs() == mulOp.getRhs())
- return replaceUnaryOp(SquareOp{}, ElementwiseKind::square);
- if (isa<math::TanhOp>(op))
- return replaceUnaryOp(TanhOp{}, ElementwiseKind::tanh);
- if (isa<math::ErfOp>(op))
- return replaceUnaryOp(ErfOp{}, ElementwiseKind::erf);
+ if (isa<math::PowFOp>(op))
+ return replaceOp(PowFOp{}, ElementwiseKind::powf);
- return failure();
+ return rewriter.notifyMatchFailure(
+ genericOp,
+ "elementwise operation cannot be specialized to named or category op");
}
//===----------------------------------------------------------------------===//
@@ -580,10 +691,11 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(
RewriterBase &rewriter, GenericOp genericOp,
const GenericOpSpecializationOptions &options) {
- // Unary elementwise - e.g. exp
- if (isaElemwiseSingleUnaryOpInterface(genericOp, options.emitCategoryOps)) {
- return specializeLinalgUnaryElementwise(rewriter, genericOp,
- options.emitCategoryOps);
+ // Elementwise - e.g. exp, add
+ if (isaElemwiseSingleUnaryOpInterface(genericOp, options.emitCategoryOps) ||
+ isaElemwiseSingleBinaryOpInterface(genericOp, options.emitCategoryOps)) {
+ return specializeLinalgElementwise(rewriter, genericOp,
+ options.emitCategoryOps);
}
// Contraction - e.g. matmul
@@ -636,28 +748,6 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(
return namedOp;
}
- // Elementwise Binary
- if (isaElemwiseSingleBinaryOpInterface(genericOp)) {
- bool swap = areBinOpsSwapped(genericOp);
- Operation *op = &genericOp.getBody()->front();
- if (isa<arith::AddFOp>(op)) {
- LinalgOp namedOp = REPLACE_BINARY_OP(AddOp, swap);
- return namedOp;
- }
- if (isa<arith::SubFOp>(op)) {
- LinalgOp namedOp = REPLACE_BINARY_OP(SubOp, swap);
- return namedOp;
- }
- if (isa<arith::MulFOp>(op)) {
- LinalgOp namedOp = REPLACE_BINARY_OP(MulOp, swap);
- return namedOp;
- }
- if (isa<arith::DivFOp>(op)) {
- LinalgOp namedOp = REPLACE_BINARY_OP(DivOp, swap);
- return namedOp;
- }
- }
-
// Convolution - e.g. *conv/pooling*
if (isaConvolutionOpInterface(genericOp))
return specializeLinalgConvolutions(rewriter, genericOp);
diff --git a/mlir/test/Dialect/Linalg/linalg-morph-multi-step.mlir b/mlir/test/Dialect/Linalg/linalg-morph-multi-step.mlir
index 7bad1b7a44d92..0a0ddbcd85a0a 100644
--- a/mlir/test/Dialect/Linalg/linalg-morph-multi-step.mlir
+++ b/mlir/test/Dialect/Linalg/linalg-morph-multi-step.mlir
@@ -1,6 +1,8 @@
-// RUN: mlir-opt %s -linalg-morph-ops=named-to-generic | FileCheck %s --check-prefix=NAMED_TO_GENERIC
-// RUN: mlir-opt %s -linalg-morph-ops=named-to-generic | mlir-opt -linalg-morph-ops=generic-to-named | \
-// RUN: FileCheck %s --check-prefix=ROUND_TRIP
+// RUN: mlir-opt %s -linalg-morph-ops=named-to-generic -split-input-file | \
+// RUN: FileCheck %s --check-prefixes=ALL,NAMED_TO_GENERIC
+// RUN: mlir-opt %s -linalg-morph-ops=named-to-generic -split-input-file | \
+// RUN: mlir-opt -linalg-morph-ops=generic-to-named -split-input-file | \
+// RUN: FileCheck %s --check-prefixes=ALL,ROUND_TRIP
func.func @unary_ops(%A : tensor<16x8xf32>, %B : tensor<16x8xf32>) -> tensor<16x8xf32> {
%exp = linalg.exp ins(%A : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>
@@ -19,6 +21,8 @@ func.func @unary_ops(%A : tensor<16x8xf32>, %B : tensor<16x8xf32>) -> tensor<16
return %erf : tensor<16x8xf32>
}
+// ALL-LABEL: unary_ops
+
// NAMED_TO_GENERIC-COUNT-13: linalg.generic
// NAMED_TO_GENERIC-NOT: linalg.exp
// NAMED_TO_GENERIC-NOT: linalg.log
@@ -48,3 +52,135 @@ func.func @unary_ops(%A : tensor<16x8xf32>, %B : tensor<16x8xf32>) -> tensor<16
// ROUND_TRIP: linalg.tanh
// ROUND_TRIP: linalg.erf
// ROUND_TRIP-NOT: linalg.generic
+
+// -----
+
+func.func @binary_ops_int(%A: tensor<?x?xi32>, %B: tensor<?x?xi32>,
+ %Out: tensor<?x?xi32>) -> tensor<?x?xi32> {
+ %0 = linalg.add ins(%A, %B : tensor<?x?xi32>, tensor<?x?xi32>)
+ outs(%Out : tensor<?x?xi32>) -> tensor<?x?xi32>
+ %1 = linalg.sub ins(%0, %B : tensor<?x?xi32>, tensor<?x?xi32>)
+ outs(%Out : tensor<?x?xi32>) -> tensor<?x?xi32>
+ %2 = linalg.mul ins(%1, %B : tensor<?x?xi32>, tensor<?x?xi32>)
+ outs(%Out : tensor<?x?xi32>) -> tensor<?x?xi32>
+ %3 = linalg.div ins(%2, %B : tensor<?x?xi32>, tensor<?x?xi32>)
+ outs(%Out : tensor<?x?xi32>) -> tensor<?x?xi32>
+ %4 = linalg.div_unsigned ins(%3, %B : tensor<?x?xi32>, tensor<?x?xi32>)
+ outs(%Out : tensor<?x?xi32>) -> tensor<?x?xi32>
+ %5 = linalg.max ins(%4, %B : tensor<?x?xi32>, tensor<?x?xi32>)
+ outs(%Out : tensor<?x?xi32>) -> tensor<?x?xi32>
+ %6 = linalg.min ins(%5, %B : tensor<?x?xi32>, tensor<?x?xi32>)
+ outs(%Out : tensor<?x?xi32>) -> tensor<?x?xi32>
+ return %6 : tensor<?x?xi32>
+}
+
+// ALL-LABEL: binary_ops_int
+
+// NAMED_TO_GENERIC-COUNT-7: linalg.generic
+// NAMED_TO_GENERIC-NOT: linalg.add
+// NAMED_TO_GENERIC-NOT: linalg.sub
+// NAMED_TO_GENERIC-NOT: linalg.mul
+// NAMED_TO_GENERIC-NOT: linalg.div
+// NAMED_TO_GENERIC-NOT: linalg.div_unsigned
+// NAMED_TO_GENERIC-NOT: linalg.max
+// NAMED_TO_GENERIC-NOT: linalg.min
+
+// ROUND_TRIP: linalg.add
+// ROUND_TRIP: linalg.sub
+// ROUND_TRIP: linalg.mul
+// ROUND_TRIP: linalg.div
+// ROUND_TRIP: linalg.div_unsigned
+// ROUND_TRIP: linalg.max
+// ROUND_TRIP: linalg.min
+// ROUND_TRIP-NOT: linalg.generic
+
+// -----
+
+func.func @binary_ops_float(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+ %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.add ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %1 = linalg.sub ins(%0, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %2 = linalg.mul ins(%1, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %3 = linalg.div ins(%2, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %4 = linalg.max ins(%3, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %5 = linalg.min ins(%4, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %6 = linalg.powf ins(%5, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %6 : tensor<?x?xf32>
+}
+
+// ALL-LABEL: binary_ops_float
+
+// NAMED_TO_GENERIC-COUNT-7: linalg.generic
+// NAMED_TO_GENERIC-NOT: linalg.add
+// NAMED_TO_GENERIC-NOT: linalg.sub
+// NAMED_TO_GENERIC-NOT: linalg.mul
+// NAMED_TO_GENERIC-NOT: linalg.div
+// NAMED_TO_GENERIC-NOT: linalg.max
+// NAMED_TO_GENERIC-NOT: linalg.min
+// NAMED_TO_GENERIC-NOT: linalg.powf
+
+// ROUND_TRIP: linalg.add
+// ROUND_TRIP: linalg.sub
+// ROUND_TRIP: linalg.mul
+// ROUND_TRIP: linalg.div
+// ROUND_TRIP: linalg.max
+// ROUND_TRIP: linalg.min
+// ROUND_TRIP: linalg.powf
+// ROUND_TRIP-NOT: linalg.generic
+
+// -----
+
+func.func @binary_ops_complex(%A: tensor<?x?xcomplex<f32>>, %B: tensor<?x?xcomplex<f32>>,
+ %Out: tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>> {
+ %0 = linalg.add ins(%A, %B : tensor<?x?xcomplex<f32>>, tensor<?x?xcomplex<f32>>)
+ outs(%Out : tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>>
+ %1 = linalg.sub ins(%0, %B : tensor<?x?xcomplex<f32>>, tensor<?x?xcomplex<f32>>)
+ outs(%Out : tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>>
+ %2 = linalg.mul ins(%1, %B : tensor<?x?xcomplex<f32>>, tensor<?x?xcomplex<f32>>)
+ outs(%Out : tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>>
+ %3 = linalg.div ins(%2, %B : tensor<?x?xcomplex<f32>>, tensor<?x?xcomplex<f32>>)
+ outs(%Out : tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>>
+ return %3 : tensor<?x?xcomplex<f32>>
+}
+
+// ALL-LABEL: binary_ops_complex
+
+// NAMED_TO_GENERIC-COUNT-4: linalg.generic
+// NAMED_TO_GENERIC-NOT: linalg.add
+// NAMED_TO_GENERIC-NOT: linalg.sub
+// NAMED_TO_GENERIC-NOT: linalg.mul
+// NAMED_TO_GENERIC-NOT: linalg.div
+
+// ROUND_TRIP: linalg.add
+// ROUND_TRIP: linalg.sub
+// ROUND_TRIP: linalg.mul
+// ROUND_TRIP: linalg.div
+// ROUND_TRIP-NOT: linalg.generic
+
+// -----
+
+func.func @binary_ops_bool(%A: tensor<?x?xi1>, %B: tensor<?x?xi1>,
+ %Out: tensor<?x?xi1>) -> tensor<?x?xi1> {
+ %0 = linalg.add ins(%A, %B : tensor<?x?xi1>, tensor<?x?xi1>)
+ outs(%Out : tensor<?x?xi1>) -> tensor<?x?xi1>
+ %1 = linalg.mul ins(%0, %B : tensor<?x?xi1>, tensor<?x?xi1>)
+ outs(%Out : tensor<?x?xi1>) -> tensor<?x?xi1>
+ return %1 : tensor<?x?xi1>
+}
+
+// ALL-LABEL: binary_ops_bool
+
+// NAMED_TO_GENERIC-COUNT-2: linalg.generic
+// NAMED_TO_GENERIC-NOT: linalg.add
+// NAMED_TO_GENERIC-NOT: linalg.mul
+
+// ROUND_TRIP: linalg.add
+// ROUND_TRIP: linalg.mul
+// ROUND_TRIP-NOT: linalg.generic
diff --git a/mlir/test/Dialect/Linalg/roundtrip-morphism-linalg-category-ops.mlir b/mlir/test/Dialect/Linalg/roundtrip-morphism-linalg-category-ops.mlir
index e7bdad8e56883..7fd7bebdb5249 100644
--- a/mlir/test/Dialect/Linalg/roundtrip-morphism-linalg-category-ops.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip-morphism-linalg-category-ops.mlir
@@ -98,6 +98,205 @@ func.func @unary_ops_non_identity(%A: tensor<?xf32>, %Out: tensor<?x?xf32>) -> t
// CHECK-SAME: ins(%[[A]] : tensor<?xf32>)
// CHECK-SAME: outs(%[[OUT]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// -----
+
+func.func @binary_ops_int(%A: memref<10xi32>, %B: memref<10xi32>,
+ %Out: memref<10xi32>) {
+ linalg.elementwise kind=#linalg.elementwise_kind<add>
+ ins(%A, %B : memref<10xi32>, memref<10xi32>) outs(%Out : memref<10xi32>)
+ linalg.elementwise kind=#linalg.elementwise_kind<sub>
+ ins(%A, %B : memref<10xi32>, memref<10xi32>) outs(%Out : memref<10xi32>)
+ linalg.elementwise kind=#linalg.elementwise_kind<mul>
+ ins(%A, %B : memref<10xi32>, memref<10xi32>) outs(%Out : memref<10xi32>)
+ linalg.elementwise kind=#linalg.elementwise_kind<div>
+ ins(%A, %B : memref<10xi32>, memref<10xi32>) outs(%Out : memref<10xi32>)
+ linalg.elementwise kind=#linalg.elementwise_kind<div_unsigned>
+ ins(%A, %B : memref<10xi32>, memref<10xi32>) outs(%Out : memref<10xi32>)
+ linalg.elementwise kind=#linalg.elementwise_kind<max_signed>
+ ins(%A, %B : memref<10xi32>, memref<10xi32>) outs(%Out : memref<10xi32>)
+ linalg.elementwise kind=#linalg.elementwise_kind<min_signed>
+ ins(%A, %B : memref<10xi32>, memref<10xi32>) outs(%Out : memref<10xi32>)
+ return
+}
+
+// CHECK-LABEL: binary_ops_int
+// CHECK-SAME: %[[A:.+]]: memref<10xi32>, %[[B:.+]]: memref<10xi32>,
+// CHECK-SAME: %[[OUT:.+]]: memref<10xi32>)
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<add>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xi32>, memref<10xi32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xi32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<sub>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xi32>, memref<10xi32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xi32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<mul>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xi32>, memref<10xi32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xi32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<div>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xi32>, memref<10xi32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xi32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<div_unsigned>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xi32>, memref<10xi32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xi32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<max_signed>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xi32>, memref<10xi32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xi32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<min_signed>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xi32>, memref<10xi32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xi32>)
+
+// -----
+
+func.func @binary_ops_float(%A: memref<10xf32>, %B: memref<10xf32>,
+ %Out: memref<10xf32>) {
+ linalg.elementwise kind=#linalg.elementwise_kind<add>
+ ins(%A, %B : memref<10xf32>, memref<10xf32>) outs(%Out : memref<10xf32>)
+ linalg.elementwise kind=#linalg.elementwise_kind<sub>
+ ins(%A, %B : memref<10xf32>, memref<10xf32>) outs(%Out : memref<10xf32>)
+ linalg.elementwise kind=#linalg.elementwise_kind<mul>
+ ins(%A, %B : memref<10xf32>, memref<10xf32>) outs(%Out : memref<10xf32>)
+ linalg.elementwise kind=#linalg.elementwise_kind<div>
+ ins(%A, %B : memref<10xf32>, memref<10xf32>) outs(%Out : memref<10xf32>)
+ linalg.elementwise kind=#linalg.elementwise_kind<max_signed>
+ ins(%A, %B : memref<10xf32>, memref<10xf32>) outs(%Out : memref<10xf32>)
+ linalg.elementwise kind=#linalg.elementwise_kind<min_signed>
+ ins(%A, %B : memref<10xf32>, memref<10xf32>) outs(%Out : memref<10xf32>)
+ linalg.elementwise kind=#linalg.elementwise_kind<powf>
+ ins(%A, %B : memref<10xf32>, memref<10xf32>) outs(%Out : memref<10xf32>)
+ return
+}
+
+// CHECK-LABEL: binary_ops_float
+// CHECK-SAME: %[[A:.+]]: memref<10xf32>, %[[B:.+]]: memref<10xf32>,
+// CHECK-SAME: %[[OUT:.+]]: memref<10xf32>)
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<add>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xf32>, memref<10xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xf32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<sub>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xf32>, memref<10xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xf32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<mul>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xf32>, memref<10xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xf32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<div>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xf32>, memref<10xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xf32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<max_signed>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xf32>, memref<10xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xf32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<min_signed>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xf32>, memref<10xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xf32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<powf>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xf32>, memref<10xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xf32>)
+
+// -----
+
+func.func @binary_ops_complex(%A: memref<10xcomplex<f32>>, %B: memref<10xcomplex<f32>>,
+ %Out: memref<10xcomplex<f32>>) {
+ linalg.elementwise kind=#linalg.elementwise_kind<add>
+ ins(%A, %B : memref<10xcomplex<f32>>, memref<10xcomplex<f32>>)
+ outs(%Out : memref<10xcomplex<f32>>)
+ linalg.elementwise kind=#linalg.elementwise_kind<sub>
+ ins(%A, %B : memref<10xcomplex<f32>>, memref<10xcomplex<f32>>)
+ outs(%Out : memref<10xcomplex<f32>>)
+ linalg.elementwise kind=#linalg.elementwise_kind<mul>
+ ins(%A, %B : memref<10xcomplex<f32>>, memref<10xcomplex<f32>>)
+ outs(%Out : memref<10xcomplex<f32>>)
+ linalg.elementwise kind=#linalg.elementwise_kind<div>
+ ins(%A, %B : memref<10xcomplex<f32>>, memref<10xcomplex<f32>>)
+ outs(%Out : memref<10xcomplex<f32>>)
+ return
+}
+
+// CHECK-LABEL: binary_ops_complex
+// CHECK-SAME: %[[A:.+]]: memref<10xcomplex<f32>>, %[[B:.+]]: memref<10xcomplex<f32>>,
+// CHECK-SAME: %[[OUT:.+]]: memref<10xcomplex<f32>>)
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<add>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xcomplex<f32>>, memref<10xcomplex<f32>>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xcomplex<f32>>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<sub>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xcomplex<f32>>, memref<10xcomplex<f32>>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xcomplex<f32>>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<mul>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xcomplex<f32>>, memref<10xcomplex<f32>>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xcomplex<f32>>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<div>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xcomplex<f32>>, memref<10xcomplex<f32>>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xcomplex<f32>>)
+
+// -----
+
+func.func @binary_ops_bool(%A: memref<10xi1>, %B: memref<10xi1>,
+ %Out: memref<10xi1>) {
+ linalg.elementwise kind=#linalg.elementwise_kind<add>
+ ins(%A, %B : memref<10xi1>, memref<10xi1>) outs(%Out : memref<10xi1>)
+ linalg.elementwise kind=#linalg.elementwise_kind<mul>
+ ins(%A, %B : memref<10xi1>, memref<10xi1>) outs(%Out : memref<10xi1>)
+ return
+}
+
+// CHECK-LABEL: binary_ops_bool
+// CHECK-SAME: %[[A:.+]]: memref<10xi1>, %[[B:.+]]: memref<10xi1>,
+// CHECK-SAME: %[[OUT:.+]]: memref<10xi1>)
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<add>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xi1>, memref<10xi1>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xi1>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<mul>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xi1>, memref<10xi1>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xi1>)
+
+// -----
+
+func.func @binary_ops_uint(%A: memref<10xi32>, %B: memref<10xi32>,
+ %Out: memref<10xi32>) {
+ linalg.elementwise kind=#linalg.elementwise_kind<max_unsigned>
+ ins(%A, %B : memref<10xi32>, memref<10xi32>) outs(%Out : memref<10xi32>)
+ linalg.elementwise kind=#linalg.elementwise_kind<min_unsigned>
+ ins(%A, %B : memref<10xi32>, memref<10xi32>) outs(%Out : memref<10xi32>)
+ return
+}
+
+// CHECK-LABEL: binary_ops_uint
+// CHECK-SAME: %[[A:.+]]: memref<10xi32>, %[[B:.+]]: memref<10xi32>,
+// CHECK-SAME: %[[OUT:.+]]: memref<10xi32>)
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<max_unsigned>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xi32>, memref<10xi32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xi32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<min_unsigned>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xi32>, memref<10xi32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xi32>)
+
+// -----
+
+func.func @binary_ops_non_identity(%A: tensor<?xf32>, %B: tensor<?x?xf32>,
+ %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.elementwise
+ kind=#linalg.elementwise_kind<add>
+ indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>,
+ affine_map<(d0, d1) -> (d0, d1)>]
+ ins(%A, %B : tensor<?xf32>, tensor<?x?xf32>)
+ outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK-DAG: #[[MAP_BC:.+]] = affine_map<(d0, d1) -> (d1)>
+// CHECK-DAG: #[[MAP_TP:.+]] = affine_map<(d0, d1) -> (d1, d0)>
+// CHECK-DAG: #[[MAP_ID:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: binary_ops_non_identity
+// CHECK-SAME: %[[A:.+]]: tensor<?xf32>, %[[B:.+]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[OUT:.+]]: tensor<?x?xf32>)
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<add>
+// CHECK-SAME: indexing_maps = [#[[MAP_BC]], #[[MAP_TP]], #[[MAP_ID]]]
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?xf32>, tensor<?x?xf32>)
+// CHECK-SAME: outs(%[[OUT]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
// -----
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
diff --git a/mlir/test/Dialect/Linalg/roundtrip-morphism-linalg-named-ops.mlir b/mlir/test/Dialect/Linalg/roundtrip-morphism-linalg-named-ops.mlir
index 69a1a7f650810..604077f6f7834 100644
--- a/mlir/test/Dialect/Linalg/roundtrip-morphism-linalg-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip-morphism-linalg-named-ops.mlir
@@ -66,22 +66,136 @@ func.func @unary_ops(%A: memref<7x14x21xf32>, %Out: memref<7x14x21xf32>) {
// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
// -----
+// Integer binary named ops (incl. div_unsigned) are expected to round-trip unchanged.
+func.func @binary_ops_int(%A: memref<10xi32>, %B: memref<10xi32>,
+ %Out: memref<10xi32>) {
+ linalg.add ins(%A, %B : memref<10xi32>, memref<10xi32>) outs(%Out : memref<10xi32>)
+ linalg.sub ins(%A, %B : memref<10xi32>, memref<10xi32>) outs(%Out : memref<10xi32>)
+ linalg.mul ins(%A, %B : memref<10xi32>, memref<10xi32>) outs(%Out : memref<10xi32>)
+ linalg.div ins(%A, %B : memref<10xi32>, memref<10xi32>) outs(%Out : memref<10xi32>)
+ linalg.div_unsigned ins(%A, %B : memref<10xi32>, memref<10xi32>) outs(%Out : memref<10xi32>)
+ linalg.max ins(%A, %B : memref<10xi32>, memref<10xi32>) outs(%Out : memref<10xi32>)
+ linalg.min ins(%A, %B : memref<10xi32>, memref<10xi32>) outs(%Out : memref<10xi32>)
+ return
+}
-func.func @binary_add(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
- %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
- %0 = linalg.add
- ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
- return %0 : tensor<?x?xf32>
+// CHECK-LABEL: binary_ops_int
+// CHECK-SAME: %[[A:.+]]: memref<10xi32>, %[[B:.+]]: memref<10xi32>,
+// CHECK-SAME: %[[OUT:.+]]: memref<10xi32>)
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.add
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xi32>, memref<10xi32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xi32>)
+// CHECK: linalg.sub
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xi32>, memref<10xi32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xi32>)
+// CHECK: linalg.mul
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xi32>, memref<10xi32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xi32>)
+// CHECK: linalg.div
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xi32>, memref<10xi32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xi32>)
+// CHECK: linalg.div_unsigned
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xi32>, memref<10xi32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xi32>)
+// CHECK: linalg.max
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xi32>, memref<10xi32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xi32>)
+// CHECK: linalg.min
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xi32>, memref<10xi32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xi32>)
+
+// -----
+// Float binary named ops, including powf, are expected to round-trip unchanged.
+func.func @binary_ops_float(%A: memref<10xf32>, %B: memref<10xf32>,
+ %Out: memref<10xf32>) {
+ linalg.add ins(%A, %B : memref<10xf32>, memref<10xf32>) outs(%Out : memref<10xf32>)
+ linalg.sub ins(%A, %B : memref<10xf32>, memref<10xf32>) outs(%Out : memref<10xf32>)
+ linalg.mul ins(%A, %B : memref<10xf32>, memref<10xf32>) outs(%Out : memref<10xf32>)
+ linalg.div ins(%A, %B : memref<10xf32>, memref<10xf32>) outs(%Out : memref<10xf32>)
+ linalg.max ins(%A, %B : memref<10xf32>, memref<10xf32>) outs(%Out : memref<10xf32>)
+ linalg.min ins(%A, %B : memref<10xf32>, memref<10xf32>) outs(%Out : memref<10xf32>)
+ linalg.powf ins(%A, %B : memref<10xf32>, memref<10xf32>) outs(%Out : memref<10xf32>)
+ return
}
-// CHECK-LABEL: binary_add
-// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>,
-// CHECK-SAME: %[[OUT:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-LABEL: binary_ops_float
+// CHECK-SAME: %[[A:.+]]: memref<10xf32>, %[[B:.+]]: memref<10xf32>,
+// CHECK-SAME: %[[OUT:.+]]: memref<10xf32>)
// CHECK-NOT: linalg.generic
// CHECK: linalg.add
-// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>)
-// CHECK-SAME: outs(%[[OUT]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xf32>, memref<10xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xf32>)
+// CHECK: linalg.sub
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xf32>, memref<10xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xf32>)
+// CHECK: linalg.mul
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xf32>, memref<10xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xf32>)
+// CHECK: linalg.div
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xf32>, memref<10xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xf32>)
+// CHECK: linalg.max
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xf32>, memref<10xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xf32>)
+// CHECK: linalg.min
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xf32>, memref<10xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xf32>)
+// CHECK: linalg.powf
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xf32>, memref<10xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xf32>)
+
+// -----
+
+func.func @binary_ops_complex(%A: memref<10xcomplex<f32>>, %B: memref<10xcomplex<f32>>,
+ %Out: memref<10xcomplex<f32>>) {
+ linalg.add ins(%A, %B : memref<10xcomplex<f32>>, memref<10xcomplex<f32>>)
+ outs(%Out : memref<10xcomplex<f32>>)
+ linalg.sub ins(%A, %B : memref<10xcomplex<f32>>, memref<10xcomplex<f32>>)
+ outs(%Out : memref<10xcomplex<f32>>)
+ linalg.mul ins(%A, %B : memref<10xcomplex<f32>>, memref<10xcomplex<f32>>)
+ outs(%Out : memref<10xcomplex<f32>>)
+ linalg.div ins(%A, %B : memref<10xcomplex<f32>>, memref<10xcomplex<f32>>)
+ outs(%Out : memref<10xcomplex<f32>>)
+ return
+}
+// Complex-typed add/sub/mul/div are expected to round-trip unchanged.
+// CHECK-LABEL: binary_ops_complex
+// CHECK-SAME: %[[A:.+]]: memref<10xcomplex<f32>>, %[[B:.+]]: memref<10xcomplex<f32>>,
+// CHECK-SAME: %[[OUT:.+]]: memref<10xcomplex<f32>>)
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.add
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xcomplex<f32>>, memref<10xcomplex<f32>>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xcomplex<f32>>)
+// CHECK: linalg.sub
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xcomplex<f32>>, memref<10xcomplex<f32>>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xcomplex<f32>>)
+// CHECK: linalg.mul
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xcomplex<f32>>, memref<10xcomplex<f32>>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xcomplex<f32>>)
+// CHECK: linalg.div
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xcomplex<f32>>, memref<10xcomplex<f32>>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xcomplex<f32>>)
+
+// -----
+
+func.func @binary_ops_bool(%A: memref<10xi1>, %B: memref<10xi1>,
+ %Out: memref<10xi1>) {
+ linalg.add ins(%A, %B : memref<10xi1>, memref<10xi1>) outs(%Out : memref<10xi1>)
+ linalg.mul ins(%A, %B : memref<10xi1>, memref<10xi1>) outs(%Out : memref<10xi1>)
+ return
+}
+// On i1 element types, add and mul are expected to round-trip unchanged.
+// CHECK-LABEL: binary_ops_bool
+// CHECK-SAME: %[[A:.+]]: memref<10xi1>, %[[B:.+]]: memref<10xi1>,
+// CHECK-SAME: %[[OUT:.+]]: memref<10xi1>)
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.add
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xi1>, memref<10xi1>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xi1>)
+// CHECK: linalg.mul
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xi1>, memref<10xi1>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xi1>)
// -----
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
index 1cca2b86ddc25..3d6c2962731c9 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -221,31 +221,558 @@ func.func @unary_ops_non_identity(%A: tensor<?xf32>, %Out: tensor<?x?xf32>) -> t
// -----
#map = affine_map<(d0, d1) -> (d0, d1)>
-func.func @binary_op_div(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
- %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+func.func @binary_ops_int(%A: tensor<?x?xi32>, %B: tensor<?x?xi32>,
+ %Out: tensor<?x?xi32>) -> tensor<?x?xi32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map, #map, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%A, %B : tensor<?x?xi32>, tensor<?x?xi32>)
+ outs(%Out : tensor<?x?xi32>) {
+ ^bb0(%in: i32, %in_0: i32, %out: i32):
+ %v = arith.addi %in, %in_0 : i32
+ linalg.yield %v : i32
+ } -> tensor<?x?xi32>
+ %1 = linalg.generic
+ {indexing_maps = [#map, #map, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%0, %B : tensor<?x?xi32>, tensor<?x?xi32>)
+ outs(%Out : tensor<?x?xi32>) {
+ ^bb0(%in: i32, %in_0: i32, %out: i32):
+ %v = arith.subi %in, %in_0 : i32
+ linalg.yield %v : i32
+ } -> tensor<?x?xi32>
+ %2 = linalg.generic
+ {indexing_maps = [#map, #map, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%1, %B : tensor<?x?xi32>, tensor<?x?xi32>)
+ outs(%Out : tensor<?x?xi32>) {
+ ^bb0(%in: i32, %in_0: i32, %out: i32):
+ %v = arith.muli %in, %in_0 : i32
+ linalg.yield %v : i32
+ } -> tensor<?x?xi32>
+ %3 = linalg.generic
+ {indexing_maps = [#map, #map, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%2, %B : tensor<?x?xi32>, tensor<?x?xi32>)
+ outs(%Out : tensor<?x?xi32>) {
+ ^bb0(%in: i32, %in_0: i32, %out: i32):
+ %v = arith.divsi %in, %in_0 : i32
+ linalg.yield %v : i32
+ } -> tensor<?x?xi32>
+ %4 = linalg.generic
+ {indexing_maps = [#map, #map, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%3, %B : tensor<?x?xi32>, tensor<?x?xi32>)
+ outs(%Out : tensor<?x?xi32>) {
+ ^bb0(%in: i32, %in_0: i32, %out: i32):
+ %v = arith.divui %in, %in_0 : i32
+ linalg.yield %v : i32
+ } -> tensor<?x?xi32>
+ %5 = linalg.generic
+ {indexing_maps = [#map, #map, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%4, %B : tensor<?x?xi32>, tensor<?x?xi32>)
+ outs(%Out : tensor<?x?xi32>) {
+ ^bb0(%in: i32, %in_0: i32, %out: i32):
+ %v = arith.maxsi %in, %in_0 : i32
+ linalg.yield %v : i32
+ } -> tensor<?x?xi32>
+ %6 = linalg.generic
+ {indexing_maps = [#map, #map, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%5, %B : tensor<?x?xi32>, tensor<?x?xi32>)
+ outs(%Out : tensor<?x?xi32>) {
+ ^bb0(%in: i32, %in_0: i32, %out: i32):
+ %v = arith.minsi %in, %in_0 : i32
+ linalg.yield %v : i32
+ } -> tensor<?x?xi32>
+ return %6 : tensor<?x?xi32>
+}
+// A chain of integer bodies (addi/subi/muli/divsi/divui/maxsi/minsi), each expected to specialize.
+// ALL-LABEL: binary_ops_int
+// ALL-SAME: %[[A:.+]]: [[TTY:tensor<\?x\?xi32>]], %[[B:.+]]: [[TTY]],
+// ALL-SAME: %[[OUT:.+]]: [[TTY]]) -> [[TTY]]
+
+// NAMED-NOT: linalg.generic
+// NAMED: %[[RES0:.+]] = linalg.add
+// NAMED-SAME: ins(%[[A]], %[[B]] : [[TTY]], [[TTY]])
+// NAMED-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// NAMED: %[[RES1:.+]] = linalg.sub
+// NAMED-SAME: ins(%[[RES0]], %[[B]] : [[TTY]], [[TTY]])
+// NAMED-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// NAMED: %[[RES2:.+]] = linalg.mul
+// NAMED-SAME: ins(%[[RES1]], %[[B]] : [[TTY]], [[TTY]])
+// NAMED-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// NAMED: %[[RES3:.+]] = linalg.div
+// NAMED-SAME: ins(%[[RES2]], %[[B]] : [[TTY]], [[TTY]])
+// NAMED-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// NAMED: %[[RES4:.+]] = linalg.div_unsigned
+// NAMED-SAME: ins(%[[RES3]], %[[B]] : [[TTY]], [[TTY]])
+// NAMED-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// NAMED: %[[RES5:.+]] = linalg.max
+// NAMED-SAME: ins(%[[RES4]], %[[B]] : [[TTY]], [[TTY]])
+// NAMED-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// NAMED: %[[RES6:.+]] = linalg.min
+// NAMED-SAME: ins(%[[RES5]], %[[B]] : [[TTY]], [[TTY]])
+// NAMED-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: %[[RES0:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
+// CATEGORY-SAME: ins(%[[A]], %[[B]] : [[TTY]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// CATEGORY: %[[RES1:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<sub>
+// CATEGORY-SAME: ins(%[[RES0]], %[[B]] : [[TTY]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// CATEGORY: %[[RES2:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<mul>
+// CATEGORY-SAME: ins(%[[RES1]], %[[B]] : [[TTY]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// CATEGORY: %[[RES3:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<div>
+// CATEGORY-SAME: ins(%[[RES2]], %[[B]] : [[TTY]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// CATEGORY: %[[RES4:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<div_unsigned>
+// CATEGORY-SAME: ins(%[[RES3]], %[[B]] : [[TTY]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// CATEGORY: %[[RES5:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<max_signed>
+// CATEGORY-SAME: ins(%[[RES4]], %[[B]] : [[TTY]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// CATEGORY: %[[RES6:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<min_signed>
+// CATEGORY-SAME: ins(%[[RES5]], %[[B]] : [[TTY]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+
+// -----
+// A chain of float bodies (addf/subf/mulf/divf/maximumf/minimumf/powf), each expected to specialize.
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @binary_ops_float(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+ %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = linalg.generic
{indexing_maps = [#map, #map, #map],
iterator_types = ["parallel", "parallel"]}
ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%Out : tensor<?x?xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
- %1 = arith.divf %in, %in_0 : f32
- linalg.yield %1 : f32
+ %v = arith.addf %in, %in_0 : f32
+ linalg.yield %v : f32
} -> tensor<?x?xf32>
- return %0 : tensor<?x?xf32>
+ %1 = linalg.generic
+ {indexing_maps = [#map, #map, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%0, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%Out : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %v = arith.subf %in, %in_0 : f32
+ linalg.yield %v : f32
+ } -> tensor<?x?xf32>
+ %2 = linalg.generic
+ {indexing_maps = [#map, #map, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%1, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%Out : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %v = arith.mulf %in, %in_0 : f32
+ linalg.yield %v : f32
+ } -> tensor<?x?xf32>
+ %3 = linalg.generic
+ {indexing_maps = [#map, #map, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%2, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%Out : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %v = arith.divf %in, %in_0 : f32
+ linalg.yield %v : f32
+ } -> tensor<?x?xf32>
+ %4 = linalg.generic
+ {indexing_maps = [#map, #map, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%3, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%Out : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %v = arith.maximumf %in, %in_0 : f32
+ linalg.yield %v : f32
+ } -> tensor<?x?xf32>
+ %5 = linalg.generic
+ {indexing_maps = [#map, #map, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%4, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%Out : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %v = arith.minimumf %in, %in_0 : f32
+ linalg.yield %v : f32
+ } -> tensor<?x?xf32>
+ %6 = linalg.generic
+ {indexing_maps = [#map, #map, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%5, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%Out : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %v = math.powf %in, %in_0 : f32
+ linalg.yield %v : f32
+ } -> tensor<?x?xf32>
+ return %6 : tensor<?x?xf32>
}
-// ALL-LABEL: binary_op_div
-// ALL-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>,
-// ALL-SAME: %[[OUT:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// ALL-LABEL: binary_ops_float
+// ALL-SAME: %[[A:.+]]: [[TTY:tensor<\?x\?xf32>]], %[[B:.+]]: [[TTY]],
+// ALL-SAME: %[[OUT:.+]]: [[TTY]]) -> [[TTY]]
// NAMED-NOT: linalg.generic
-// NAMED: linalg.div
-// NAMED-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>)
-// NAMED-SAME: outs(%[[OUT]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// NAMED: %[[RES0:.+]] = linalg.add
+// NAMED-SAME: ins(%[[A]], %[[B]] : [[TTY]], [[TTY]])
+// NAMED-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// NAMED: %[[RES1:.+]] = linalg.sub
+// NAMED-SAME: ins(%[[RES0]], %[[B]] : [[TTY]], [[TTY]])
+// NAMED-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// NAMED: %[[RES2:.+]] = linalg.mul
+// NAMED-SAME: ins(%[[RES1]], %[[B]] : [[TTY]], [[TTY]])
+// NAMED-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// NAMED: %[[RES3:.+]] = linalg.div
+// NAMED-SAME: ins(%[[RES2]], %[[B]] : [[TTY]], [[TTY]])
+// NAMED-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// NAMED: %[[RES4:.+]] = linalg.max
+// NAMED-SAME: ins(%[[RES3]], %[[B]] : [[TTY]], [[TTY]])
+// NAMED-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// NAMED: %[[RES5:.+]] = linalg.min
+// NAMED-SAME: ins(%[[RES4]], %[[B]] : [[TTY]], [[TTY]])
+// NAMED-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// NAMED: %[[RES6:.+]] = linalg.powf
+// NAMED-SAME: ins(%[[RES5]], %[[B]] : [[TTY]], [[TTY]])
+// NAMED-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
-// Not supported yet.
-// CATEGORY: linalg.generic
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: %[[RES0:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
+// CATEGORY-SAME: ins(%[[A]], %[[B]] : [[TTY]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// CATEGORY: %[[RES1:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<sub>
+// CATEGORY-SAME: ins(%[[RES0]], %[[B]] : [[TTY]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// CATEGORY: %[[RES2:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<mul>
+// CATEGORY-SAME: ins(%[[RES1]], %[[B]] : [[TTY]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// CATEGORY: %[[RES3:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<div>
+// CATEGORY-SAME: ins(%[[RES2]], %[[B]] : [[TTY]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// CATEGORY: %[[RES4:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<max_signed>
+// CATEGORY-SAME: ins(%[[RES3]], %[[B]] : [[TTY]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// CATEGORY: %[[RES5:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<min_signed>
+// CATEGORY-SAME: ins(%[[RES4]], %[[B]] : [[TTY]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// CATEGORY: %[[RES6:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<powf>
+// CATEGORY-SAME: ins(%[[RES5]], %[[B]] : [[TTY]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @binary_ops_complex(%A: tensor<?x?xcomplex<f32>>,
+ %B: tensor<?x?xcomplex<f32>>,
+ %Out: tensor<?x?xcomplex<f32>>)
+ -> tensor<?x?xcomplex<f32>> {
+ %0 = linalg.generic
+ {indexing_maps = [#map, #map, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%A, %B : tensor<?x?xcomplex<f32>>, tensor<?x?xcomplex<f32>>)
+ outs(%Out : tensor<?x?xcomplex<f32>>) {
+ ^bb0(%in: complex<f32>, %in_0: complex<f32>, %out: complex<f32>):
+ %v = complex.add %in, %in_0 : complex<f32>
+ linalg.yield %v : complex<f32>
+ } -> tensor<?x?xcomplex<f32>>
+ %1 = linalg.generic
+ {indexing_maps = [#map, #map, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%0, %B : tensor<?x?xcomplex<f32>>, tensor<?x?xcomplex<f32>>)
+ outs(%Out : tensor<?x?xcomplex<f32>>) {
+ ^bb0(%in: complex<f32>, %in_0: complex<f32>, %out: complex<f32>):
+ %v = complex.sub %in, %in_0 : complex<f32>
+ linalg.yield %v : complex<f32>
+ } -> tensor<?x?xcomplex<f32>>
+ %2 = linalg.generic
+ {indexing_maps = [#map, #map, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%1, %B : tensor<?x?xcomplex<f32>>, tensor<?x?xcomplex<f32>>)
+ outs(%Out : tensor<?x?xcomplex<f32>>) {
+ ^bb0(%in: complex<f32>, %in_0: complex<f32>, %out: complex<f32>):
+ %v = complex.mul %in, %in_0 : complex<f32>
+ linalg.yield %v : complex<f32>
+ } -> tensor<?x?xcomplex<f32>>
+ %3 = linalg.generic
+ {indexing_maps = [#map, #map, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%2, %B : tensor<?x?xcomplex<f32>>, tensor<?x?xcomplex<f32>>)
+ outs(%Out : tensor<?x?xcomplex<f32>>) {
+ ^bb0(%in: complex<f32>, %in_0: complex<f32>, %out: complex<f32>):
+ %v = complex.div %in, %in_0 : complex<f32>
+ linalg.yield %v : complex<f32>
+ } -> tensor<?x?xcomplex<f32>>
+ return %3 : tensor<?x?xcomplex<f32>>
+}
+// A chain of complex-dialect bodies (add/sub/mul/div), each expected to specialize.
+// ALL-LABEL: binary_ops_complex
+// ALL-SAME: %[[A:.+]]: [[TTY:tensor<\?x\?xcomplex<f32>>]], %[[B:.+]]: [[TTY]],
+// ALL-SAME: %[[OUT:.+]]: [[TTY]]) -> [[TTY]]
+
+// NAMED-NOT: linalg.generic
+// NAMED: %[[RES0:.+]] = linalg.add
+// NAMED-SAME: ins(%[[A]], %[[B]] : [[TTY]], [[TTY]])
+// NAMED-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// NAMED: %[[RES1:.+]] = linalg.sub
+// NAMED-SAME: ins(%[[RES0]], %[[B]] : [[TTY]], [[TTY]])
+// NAMED-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// NAMED: %[[RES2:.+]] = linalg.mul
+// NAMED-SAME: ins(%[[RES1]], %[[B]] : [[TTY]], [[TTY]])
+// NAMED-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// NAMED: %[[RES3:.+]] = linalg.div
+// NAMED-SAME: ins(%[[RES2]], %[[B]] : [[TTY]], [[TTY]])
+// NAMED-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: %[[RES0:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
+// CATEGORY-SAME: ins(%[[A]], %[[B]] : [[TTY]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// CATEGORY: %[[RES1:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<sub>
+// CATEGORY-SAME: ins(%[[RES0]], %[[B]] : [[TTY]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// CATEGORY: %[[RES2:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<mul>
+// CATEGORY-SAME: ins(%[[RES1]], %[[B]] : [[TTY]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// CATEGORY: %[[RES3:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<div>
+// CATEGORY-SAME: ins(%[[RES2]], %[[B]] : [[TTY]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @binary_ops_bool(%A: tensor<?x?xi1>, %B: tensor<?x?xi1>,
+ %Out: tensor<?x?xi1>) -> tensor<?x?xi1> {
+ %0 = linalg.generic
+ {indexing_maps = [#map, #map, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%A, %B : tensor<?x?xi1>, tensor<?x?xi1>)
+ outs(%Out : tensor<?x?xi1>) {
+ ^bb0(%in: i1, %in_0: i1, %out: i1):
+ %v = arith.ori %in, %in_0 : i1
+ linalg.yield %v : i1
+ } -> tensor<?x?xi1>
+ %1 = linalg.generic
+ {indexing_maps = [#map, #map, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%0, %B : tensor<?x?xi1>, tensor<?x?xi1>)
+ outs(%Out : tensor<?x?xi1>) {
+ ^bb0(%in: i1, %in_0: i1, %out: i1):
+ %v = arith.andi %in, %in_0 : i1
+ linalg.yield %v : i1
+ } -> tensor<?x?xi1>
+ return %1 : tensor<?x?xi1>
+}
+// On i1, arith.ori is expected to specialize to add and arith.andi to mul.
+// ALL-LABEL: binary_ops_bool
+// ALL-SAME: %[[A:.+]]: [[TTY:tensor<\?x\?xi1>]], %[[B:.+]]: [[TTY]],
+// ALL-SAME: %[[OUT:.+]]: [[TTY]]) -> [[TTY]]
+
+// NAMED-NOT: linalg.generic
+// NAMED: %[[RES0:.+]] = linalg.add
+// NAMED-SAME: ins(%[[A]], %[[B]] : [[TTY]], [[TTY]])
+// NAMED-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// NAMED: %[[RES1:.+]] = linalg.mul
+// NAMED-SAME: ins(%[[RES0]], %[[B]] : [[TTY]], [[TTY]])
+// NAMED-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: %[[RES0:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
+// CATEGORY-SAME: ins(%[[A]], %[[B]] : [[TTY]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// CATEGORY: %[[RES1:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<mul>
+// CATEGORY-SAME: ins(%[[RES0]], %[[B]] : [[TTY]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @binary_ops_uint(%A: tensor<?x?xi32>, %B: tensor<?x?xi32>,
+ %Out: tensor<?x?xi32>) -> tensor<?x?xi32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map, #map, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%A, %B : tensor<?x?xi32>, tensor<?x?xi32>)
+ outs(%Out : tensor<?x?xi32>) {
+ ^bb0(%in: i32, %in_0: i32, %out: i32):
+ %v = arith.maxui %in, %in_0 : i32
+ linalg.yield %v : i32
+ } -> tensor<?x?xi32>
+ %1 = linalg.generic
+ {indexing_maps = [#map, #map, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%0, %B : tensor<?x?xi32>, tensor<?x?xi32>)
+ outs(%Out : tensor<?x?xi32>) {
+ ^bb0(%in: i32, %in_0: i32, %out: i32):
+ %v = arith.minui %in, %in_0 : i32
+ linalg.yield %v : i32
+ } -> tensor<?x?xi32>
+ return %1 : tensor<?x?xi32>
+}
+// ALL-LABEL: binary_ops_uint
+// ALL-SAME: %[[A:.+]]: [[TTY:tensor<\?x\?xi32>]], %[[B:.+]]: [[TTY]],
+// ALL-SAME: %[[OUT:.+]]: [[TTY]]) -> [[TTY]]
+
+// No named ops yet for unsigned max/min -> both generics must stay unchanged.
+// NAMED-NOT: linalg.{{max|min}}
+// NAMED: linalg.generic
+// NAMED-NOT: linalg.{{max|min}}
+// NAMED: linalg.generic
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: %[[RES0:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<max_unsigned>
+// CATEGORY-SAME: ins(%[[A]], %[[B]] : [[TTY]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// CATEGORY: %[[RES1:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<min_unsigned>
+// CATEGORY-SAME: ins(%[[RES0]], %[[B]] : [[TTY]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+
+// -----
+
+func.func @binary_ops_non_identity(%A: tensor<?xf32>, %B: tensor<?x?xf32>,
+ %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%A, %B : tensor<?xf32>, tensor<?x?xf32>)
+ outs(%Out : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %v = arith.addf %in, %in_0 : f32
+ linalg.yield %v : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// Broadcast/transpose maps: only the category op can carry them, with maps preserved.
+// ALL-DAG: #[[MAP_BC:.+]] = affine_map<(d0, d1) -> (d1)>
+// ALL-DAG: #[[MAP_TP:.+]] = affine_map<(d0, d1) -> (d1, d0)>
+// ALL-DAG: #[[MAP_ID:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// ALL: binary_ops_non_identity
+// ALL-SAME: %[[A:.+]]: [[TTY1D:tensor<\?xf32>]], %[[B:.+]]: [[TTY:tensor<\?x\?xf32>]],
+// ALL-SAME: %[[OUT:.+]]: [[TTY]]) -> [[TTY]]
+
+// Named ops cannot carry user-defined indexing maps -> expect no change.
+// NAMED-NOT: linalg.add
+// NAMED: linalg.generic
+
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.elementwise kind=#linalg.elementwise_kind<add>
+// CATEGORY-SAME: indexing_maps = [#[[MAP_BC]], #[[MAP_TP]], #[[MAP_ID]]]
+// CATEGORY-SAME: ins(%[[A]], %[[B]] : [[TTY1D]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#bcast = affine_map<(d0, d1) -> (d0)>
+func.func @binary_ops_swapped(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+ %C: tensor<?xf32>, %Out: tensor<?x?xf32>)
+ -> tensor<?x?xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map, #map, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%Out : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %v = arith.addf %in_0, %in : f32
+ linalg.yield %v : f32
+ } -> tensor<?x?xf32>
+ %1 = linalg.generic
+ {indexing_maps = [#map, #bcast, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%0, %C : tensor<?x?xf32>, tensor<?xf32>)
+ outs(%Out : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %v = arith.subf %in_0, %in : f32
+ linalg.yield %v : f32
+ } -> tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
+}
+// Bodies use the block args in swapped order; the specialized ins() are expected to be reordered to match.
+// ALL-DAG: #[[MAP_BC:.+]] = affine_map<(d0, d1) -> (d0)>
+// ALL-DAG: #[[MAP_ID:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// ALL: binary_ops_swapped
+// ALL-SAME: %[[A:.+]]: [[TTY:tensor<\?x\?xf32>]], %[[B:.+]]: [[TTY]],
+// ALL-SAME: %[[C:.+]]: [[TTY1D:tensor<\?xf32>]],
+// ALL-SAME: %[[OUT:.+]]: [[TTY]]) -> [[TTY]]
+
+// NAMED: %[[RES0:.+]] = linalg.add
+// NAMED-SAME: ins(%[[B]], %[[A]] : [[TTY]], [[TTY]])
+// NAMED-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// NAMED-NOT: linalg.sub
+// NAMED: linalg.generic
+// NAMED-SAME: ins(%[[RES0]], %[[C]] : [[TTY]], [[TTY1D]])
+
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: %[[RES0:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
+// CATEGORY-SAME: ins(%[[B]], %[[A]] : [[TTY]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// CATEGORY: %[[RES1:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<sub>
+// CATEGORY-SAME: indexing_maps = [#[[MAP_BC]], #[[MAP_ID]], #[[MAP_ID]]]
+// CATEGORY-SAME: ins(%[[C]], %[[RES0]] : [[TTY1D]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+func.func @unary_op_with_scalar(%A: tensor<?xi32>, %Out: tensor<?xi32>)
+ -> tensor<?xi32> {
+ %cst = arith.constant 123 : i32
+ %0 = linalg.generic
+ {indexing_maps = [#map, #map],
+ iterator_types = ["parallel"]}
+ ins(%A : tensor<?xi32>)
+ outs(%Out : tensor<?xi32>) {
+ ^bb0(%in: i32, %out: i32):
+ %v = arith.addi %cst, %in : i32
+ linalg.yield %v : i32
+ } -> tensor<?xi32>
+ return %0 : tensor<?xi32>
+}
+// The constant captured by the body is expected to become a scalar input with a (d0) -> () map.
+// CATEGORY-DAG: #[[MAP_ID:.+]] = affine_map<(d0) -> (d0)>
+// CATEGORY-DAG: #[[MAP_BC:.+]] = affine_map<(d0) -> ()>
+// ALL: unary_op_with_scalar
+// CATEGORY-SAME: %[[A:.+]]: [[TTY:tensor<\?xi32>]],
+// CATEGORY-SAME: %[[OUT:.+]]: [[TTY]]) -> [[TTY]]
+
+// Named ops cannot broadcast from a scalar operand -> expect no change.
+// NAMED-NOT: linalg.add
+// NAMED: linalg.generic
+
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: %[[CST:.+]] = arith.constant 123 : i32
+// CATEGORY: linalg.elementwise kind=#linalg.elementwise_kind<add>
+// CATEGORY-SAME: indexing_maps = [#[[MAP_BC]], #[[MAP_ID]], #[[MAP_ID]]]
+// CATEGORY-SAME: ins(%[[CST]], %[[A]] : i32, [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+func.func @negative_unary_op_using_block_arg_twice(%A: tensor<?xi32>,
+ %Out: tensor<?xi32>)
+ -> tensor<?xi32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map, #map],
+ iterator_types = ["parallel"]}
+ ins(%A : tensor<?xi32>)
+ outs(%Out : tensor<?xi32>) {
+ ^bb0(%in: i32, %out: i32):
+ %v = arith.addi %in, %in : i32
+ linalg.yield %v : i32
+ } -> tensor<?xi32>
+ return %0 : tensor<?xi32>
+}
+
+// ALL-LABEL: negative_unary_op_using_block_arg_twice
+
+// A single input feeds both addi operands, so no binary named op is formed -> expect no change.
+// NAMED-NOT: linalg.add
+
+// There is no scalar operand to hoist -> expect no change (no category op of any kind).
+// CATEGORY-NOT: linalg.elementwise
+
+// ALL: linalg.generic
// -----
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize-elemwise-binary.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize-elemwise-binary.mlir
index d45025de931cd..0f444a2d6a71b 100644
--- a/mlir/test/Dialect/Linalg/transform-op-specialize-elemwise-binary.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize-elemwise-binary.mlir
@@ -1,7 +1,98 @@
// RUN: mlir-opt --transform-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s
#map = affine_map<(d0, d1) -> (d0, d1)>
-func.func @specialize_add(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+func.func @specialize_add_int(%lhs: tensor<?x?xi32>, %rhs: tensor<?x?xi32>, %init: tensor<?x?xi32>) -> tensor<?x?xi32> { // arith.addi on i32 -> linalg.add
+ %res = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%lhs, %rhs : tensor<?x?xi32>, tensor<?x?xi32>) outs(%init : tensor<?x?xi32>) {
+ ^bb0(%a: i32, %b: i32, %acc: i32):
+ %sum = arith.addi %a, %b : i32
+ linalg.yield %sum : i32
+ } -> tensor<?x?xi32>
+ return %res : tensor<?x?xi32>
+}
+// CHECK-LABEL: specialize_add_int
+// CHECK-SAME: %[[LHS:.+]]: tensor<?x?xi32>, %[[RHS:.+]]: tensor<?x?xi32>, %[[INIT:.+]]: tensor<?x?xi32>) -> tensor<?x?xi32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.add ins(%[[LHS]], %[[RHS]] : tensor<?x?xi32>, tensor<?x?xi32>) outs(%[[INIT]] : tensor<?x?xi32>) -> tensor<?x?xi32>
+
+func.func @specialize_sub_int(%lhs: tensor<?x?xi32>, %rhs: tensor<?x?xi32>, %init: tensor<?x?xi32>) -> tensor<?x?xi32> { // arith.subi on i32 -> linalg.sub
+ %res = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%lhs, %rhs : tensor<?x?xi32>, tensor<?x?xi32>) outs(%init : tensor<?x?xi32>) {
+ ^bb0(%a: i32, %b: i32, %acc: i32):
+ %diff = arith.subi %a, %b : i32
+ linalg.yield %diff : i32
+ } -> tensor<?x?xi32>
+ return %res : tensor<?x?xi32>
+}
+// CHECK-LABEL: specialize_sub_int
+// CHECK-SAME: %[[LHS:.+]]: tensor<?x?xi32>, %[[RHS:.+]]: tensor<?x?xi32>, %[[INIT:.+]]: tensor<?x?xi32>) -> tensor<?x?xi32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.sub ins(%[[LHS]], %[[RHS]] : tensor<?x?xi32>, tensor<?x?xi32>) outs(%[[INIT]] : tensor<?x?xi32>) -> tensor<?x?xi32>
+
+func.func @specialize_mul_int(%lhs: tensor<?x?xi32>, %rhs: tensor<?x?xi32>, %init: tensor<?x?xi32>) -> tensor<?x?xi32> { // arith.muli on i32 -> linalg.mul
+ %res = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%lhs, %rhs : tensor<?x?xi32>, tensor<?x?xi32>) outs(%init : tensor<?x?xi32>) {
+ ^bb0(%a: i32, %b: i32, %acc: i32):
+ %prod = arith.muli %a, %b : i32
+ linalg.yield %prod : i32
+ } -> tensor<?x?xi32>
+ return %res : tensor<?x?xi32>
+}
+// CHECK-LABEL: specialize_mul_int
+// CHECK-SAME: %[[LHS:.+]]: tensor<?x?xi32>, %[[RHS:.+]]: tensor<?x?xi32>, %[[INIT:.+]]: tensor<?x?xi32>) -> tensor<?x?xi32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.mul ins(%[[LHS]], %[[RHS]] : tensor<?x?xi32>, tensor<?x?xi32>) outs(%[[INIT]] : tensor<?x?xi32>) -> tensor<?x?xi32>
+
+func.func @specialize_div_int(%lhs: tensor<?x?xi32>, %rhs: tensor<?x?xi32>, %init: tensor<?x?xi32>) -> tensor<?x?xi32> { // signed arith.divsi -> linalg.div
+ %res = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%lhs, %rhs : tensor<?x?xi32>, tensor<?x?xi32>) outs(%init : tensor<?x?xi32>) {
+ ^bb0(%a: i32, %b: i32, %acc: i32):
+ %quot = arith.divsi %a, %b : i32
+ linalg.yield %quot : i32
+ } -> tensor<?x?xi32>
+ return %res : tensor<?x?xi32>
+}
+// CHECK-LABEL: specialize_div_int
+// CHECK-SAME: %[[LHS:.+]]: tensor<?x?xi32>, %[[RHS:.+]]: tensor<?x?xi32>, %[[INIT:.+]]: tensor<?x?xi32>) -> tensor<?x?xi32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.div ins(%[[LHS]], %[[RHS]] : tensor<?x?xi32>, tensor<?x?xi32>) outs(%[[INIT]] : tensor<?x?xi32>) -> tensor<?x?xi32>
+
+func.func @specialize_div_unsigned_int(%lhs: tensor<?x?xi32>, %rhs: tensor<?x?xi32>, %init: tensor<?x?xi32>) -> tensor<?x?xi32> { // arith.divui -> linalg.div_unsigned
+ %res = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%lhs, %rhs : tensor<?x?xi32>, tensor<?x?xi32>) outs(%init : tensor<?x?xi32>) {
+ ^bb0(%a: i32, %b: i32, %acc: i32):
+ %quot = arith.divui %a, %b : i32
+ linalg.yield %quot : i32
+ } -> tensor<?x?xi32>
+ return %res : tensor<?x?xi32>
+}
+// CHECK-LABEL: specialize_div_unsigned_int
+// CHECK-SAME: %[[LHS:.+]]: tensor<?x?xi32>, %[[RHS:.+]]: tensor<?x?xi32>, %[[INIT:.+]]: tensor<?x?xi32>) -> tensor<?x?xi32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.div_unsigned ins(%[[LHS]], %[[RHS]] : tensor<?x?xi32>, tensor<?x?xi32>) outs(%[[INIT]] : tensor<?x?xi32>) -> tensor<?x?xi32>
+
+func.func @specialize_max_int(%lhs: tensor<?x?xi32>, %rhs: tensor<?x?xi32>, %init: tensor<?x?xi32>) -> tensor<?x?xi32> { // signed arith.maxsi -> linalg.max
+ %res = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%lhs, %rhs : tensor<?x?xi32>, tensor<?x?xi32>) outs(%init : tensor<?x?xi32>) {
+ ^bb0(%a: i32, %b: i32, %acc: i32):
+ %larger = arith.maxsi %a, %b : i32
+ linalg.yield %larger : i32
+ } -> tensor<?x?xi32>
+ return %res : tensor<?x?xi32>
+}
+// CHECK-LABEL: specialize_max_int
+// CHECK-SAME: %[[LHS:.+]]: tensor<?x?xi32>, %[[RHS:.+]]: tensor<?x?xi32>, %[[INIT:.+]]: tensor<?x?xi32>) -> tensor<?x?xi32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.max ins(%[[LHS]], %[[RHS]] : tensor<?x?xi32>, tensor<?x?xi32>) outs(%[[INIT]] : tensor<?x?xi32>) -> tensor<?x?xi32>
+
+func.func @specialize_min_int(%lhs: tensor<?x?xi32>, %rhs: tensor<?x?xi32>, %init: tensor<?x?xi32>) -> tensor<?x?xi32> { // signed arith.minsi -> linalg.min
+ %res = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%lhs, %rhs : tensor<?x?xi32>, tensor<?x?xi32>) outs(%init : tensor<?x?xi32>) {
+ ^bb0(%a: i32, %b: i32, %acc: i32):
+ %smaller = arith.minsi %a, %b : i32
+ linalg.yield %smaller : i32
+ } -> tensor<?x?xi32>
+ return %res : tensor<?x?xi32>
+}
+// CHECK-LABEL: specialize_min_int
+// CHECK-SAME: %[[LHS:.+]]: tensor<?x?xi32>, %[[RHS:.+]]: tensor<?x?xi32>, %[[INIT:.+]]: tensor<?x?xi32>) -> tensor<?x?xi32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.min ins(%[[LHS]], %[[RHS]] : tensor<?x?xi32>, tensor<?x?xi32>) outs(%[[INIT]] : tensor<?x?xi32>) -> tensor<?x?xi32>
+
+func.func @specialize_add_float(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> { // arith.addf on f32 -> linalg.add
%0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%1 = arith.addf %in, %in_0 : f32
@@ -9,12 +100,12 @@ func.func @specialize_add(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2:
} -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
-// CHECK-LABEL: specialize_add
+// CHECK-LABEL: specialize_add_float
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: linalg.generic
// CHECK: linalg.add ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
-func.func @specialize_sub(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+func.func @specialize_sub_float(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> { // arith.subf on f32 -> linalg.sub
%0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%1 = arith.subf %in, %in_0 : f32
@@ -22,50 +113,166 @@ func.func @specialize_sub(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2:
} -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
-// CHECK-LABEL: specialize_sub
+// CHECK-LABEL: specialize_sub_float
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: linalg.generic
// CHECK: linalg.sub ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
-func.func @specialize_sub_swapped_operands(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+func.func @specialize_mul_float(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> { // arith.mulf on f32 -> linalg.mul
%0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
- %1 = arith.subf %in_0, %in : f32
+ %1 = arith.mulf %in, %in_0 : f32
linalg.yield %1 : f32
} -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
-// CHECK-LABEL: specialize_sub
+// CHECK-LABEL: specialize_mul_float
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: linalg.generic
-// CHECK: linalg.sub ins(%[[ARG1]], %[[ARG0]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: linalg.mul ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
-func.func @specialize_mul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+func.func @specialize_div_float(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> { // arith.divf on f32 -> linalg.div
%0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
- %1 = arith.mulf %in, %in_0 : f32
+ %1 = arith.divf %in, %in_0 : f32
linalg.yield %1 : f32
} -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
-// CHECK-LABEL: specialize_mul
+// CHECK-LABEL: specialize_div_float
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: linalg.generic
-// CHECK: linalg.mul ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: linalg.div ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
-func.func @specialize_div(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+func.func @specialize_max_float(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> { // arith.maximumf on f32 -> linalg.max
%0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
- %1 = arith.divf %in, %in_0 : f32
+ %1 = arith.maximumf %in, %in_0 : f32
linalg.yield %1 : f32
} -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
-// CHECK-LABEL: specialize_div
+// CHECK-LABEL: specialize_max_float
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: linalg.generic
-// CHECK: linalg.div ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: linalg.max ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+func.func @specialize_min_float(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>, %init: tensor<?x?xf32>) -> tensor<?x?xf32> { // arith.minimumf on f32 -> linalg.min
+ %res = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%lhs, %rhs : tensor<?x?xf32>, tensor<?x?xf32>) outs(%init : tensor<?x?xf32>) {
+ ^bb0(%a: f32, %b: f32, %acc: f32):
+ %smaller = arith.minimumf %a, %b : f32
+ linalg.yield %smaller : f32
+ } -> tensor<?x?xf32>
+ return %res : tensor<?x?xf32>
+}
+// CHECK-LABEL: specialize_min_float
+// CHECK-SAME: %[[LHS:.+]]: tensor<?x?xf32>, %[[RHS:.+]]: tensor<?x?xf32>, %[[INIT:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.min ins(%[[LHS]], %[[RHS]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[INIT]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+func.func @specialize_powf_float(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>, %init: tensor<?x?xf32>) -> tensor<?x?xf32> { // math.powf on f32 -> linalg.powf
+ %res = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%lhs, %rhs : tensor<?x?xf32>, tensor<?x?xf32>) outs(%init : tensor<?x?xf32>) {
+ ^bb0(%a: f32, %b: f32, %acc: f32):
+ %pow = math.powf %a, %b : f32
+ linalg.yield %pow : f32
+ } -> tensor<?x?xf32>
+ return %res : tensor<?x?xf32>
+}
+// CHECK-LABEL: specialize_powf_float
+// CHECK-SAME: %[[LHS:.+]]: tensor<?x?xf32>, %[[RHS:.+]]: tensor<?x?xf32>, %[[INIT:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.powf ins(%[[LHS]], %[[RHS]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[INIT]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+func.func @specialize_add_complex(%lhs: tensor<?x?xcomplex<f32>>, %rhs: tensor<?x?xcomplex<f32>>, %init: tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>> { // complex.add -> linalg.add
+ %res = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%lhs, %rhs : tensor<?x?xcomplex<f32>>, tensor<?x?xcomplex<f32>>) outs(%init : tensor<?x?xcomplex<f32>>) {
+ ^bb0(%a: complex<f32>, %b: complex<f32>, %acc: complex<f32>):
+ %sum = complex.add %a, %b : complex<f32>
+ linalg.yield %sum : complex<f32>
+ } -> tensor<?x?xcomplex<f32>>
+ return %res : tensor<?x?xcomplex<f32>>
+}
+// CHECK-LABEL: specialize_add_complex
+// CHECK-SAME: %[[LHS:.+]]: tensor<?x?xcomplex<f32>>, %[[RHS:.+]]: tensor<?x?xcomplex<f32>>, %[[INIT:.+]]: tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.add ins(%[[LHS]], %[[RHS]] : tensor<?x?xcomplex<f32>>, tensor<?x?xcomplex<f32>>) outs(%[[INIT]] : tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>>
+
+func.func @specialize_sub_complex(%lhs: tensor<?x?xcomplex<f32>>, %rhs: tensor<?x?xcomplex<f32>>, %init: tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>> { // complex.sub -> linalg.sub
+ %res = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%lhs, %rhs : tensor<?x?xcomplex<f32>>, tensor<?x?xcomplex<f32>>) outs(%init : tensor<?x?xcomplex<f32>>) {
+ ^bb0(%a: complex<f32>, %b: complex<f32>, %acc: complex<f32>):
+ %diff = complex.sub %a, %b : complex<f32>
+ linalg.yield %diff : complex<f32>
+ } -> tensor<?x?xcomplex<f32>>
+ return %res : tensor<?x?xcomplex<f32>>
+}
+// CHECK-LABEL: specialize_sub_complex
+// CHECK-SAME: %[[LHS:.+]]: tensor<?x?xcomplex<f32>>, %[[RHS:.+]]: tensor<?x?xcomplex<f32>>, %[[INIT:.+]]: tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.sub ins(%[[LHS]], %[[RHS]] : tensor<?x?xcomplex<f32>>, tensor<?x?xcomplex<f32>>) outs(%[[INIT]] : tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>>
+
+func.func @specialize_mul_complex(%lhs: tensor<?x?xcomplex<f32>>, %rhs: tensor<?x?xcomplex<f32>>, %init: tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>> { // complex.mul -> linalg.mul
+ %res = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%lhs, %rhs : tensor<?x?xcomplex<f32>>, tensor<?x?xcomplex<f32>>) outs(%init : tensor<?x?xcomplex<f32>>) {
+ ^bb0(%a: complex<f32>, %b: complex<f32>, %acc: complex<f32>):
+ %prod = complex.mul %a, %b : complex<f32>
+ linalg.yield %prod : complex<f32>
+ } -> tensor<?x?xcomplex<f32>>
+ return %res : tensor<?x?xcomplex<f32>>
+}
+// CHECK-LABEL: specialize_mul_complex
+// CHECK-SAME: %[[LHS:.+]]: tensor<?x?xcomplex<f32>>, %[[RHS:.+]]: tensor<?x?xcomplex<f32>>, %[[INIT:.+]]: tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.mul ins(%[[LHS]], %[[RHS]] : tensor<?x?xcomplex<f32>>, tensor<?x?xcomplex<f32>>) outs(%[[INIT]] : tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>>
+
+func.func @specialize_div_complex(%lhs: tensor<?x?xcomplex<f32>>, %rhs: tensor<?x?xcomplex<f32>>, %init: tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>> { // complex.div -> linalg.div
+ %res = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%lhs, %rhs : tensor<?x?xcomplex<f32>>, tensor<?x?xcomplex<f32>>) outs(%init : tensor<?x?xcomplex<f32>>) {
+ ^bb0(%a: complex<f32>, %b: complex<f32>, %acc: complex<f32>):
+ %quot = complex.div %a, %b : complex<f32>
+ linalg.yield %quot : complex<f32>
+ } -> tensor<?x?xcomplex<f32>>
+ return %res : tensor<?x?xcomplex<f32>>
+}
+// CHECK-LABEL: specialize_div_complex
+// CHECK-SAME: %[[LHS:.+]]: tensor<?x?xcomplex<f32>>, %[[RHS:.+]]: tensor<?x?xcomplex<f32>>, %[[INIT:.+]]: tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.div ins(%[[LHS]], %[[RHS]] : tensor<?x?xcomplex<f32>>, tensor<?x?xcomplex<f32>>) outs(%[[INIT]] : tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>>
+
+func.func @specialize_add_bool(%lhs: tensor<?x?xi1>, %rhs: tensor<?x?xi1>, %init: tensor<?x?xi1>) -> tensor<?x?xi1> { // arith.ori on i1 is boolean addition -> linalg.add
+ %res = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%lhs, %rhs : tensor<?x?xi1>, tensor<?x?xi1>) outs(%init : tensor<?x?xi1>) {
+ ^bb0(%a: i1, %b: i1, %acc: i1):
+ %any = arith.ori %a, %b : i1
+ linalg.yield %any : i1
+ } -> tensor<?x?xi1>
+ return %res : tensor<?x?xi1>
+}
+// CHECK-LABEL: specialize_add_bool
+// CHECK-SAME: %[[LHS:.+]]: tensor<?x?xi1>, %[[RHS:.+]]: tensor<?x?xi1>, %[[INIT:.+]]: tensor<?x?xi1>) -> tensor<?x?xi1>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.add ins(%[[LHS]], %[[RHS]] : tensor<?x?xi1>, tensor<?x?xi1>) outs(%[[INIT]] : tensor<?x?xi1>) -> tensor<?x?xi1>
+
+func.func @specialize_mul_bool(%lhs: tensor<?x?xi1>, %rhs: tensor<?x?xi1>, %init: tensor<?x?xi1>) -> tensor<?x?xi1> { // arith.andi on i1 is boolean multiplication -> linalg.mul
+ %res = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%lhs, %rhs : tensor<?x?xi1>, tensor<?x?xi1>) outs(%init : tensor<?x?xi1>) {
+ ^bb0(%a: i1, %b: i1, %acc: i1):
+ %both = arith.andi %a, %b : i1
+ linalg.yield %both : i1
+ } -> tensor<?x?xi1>
+ return %res : tensor<?x?xi1>
+}
+// CHECK-LABEL: specialize_mul_bool
+// CHECK-SAME: %[[LHS:.+]]: tensor<?x?xi1>, %[[RHS:.+]]: tensor<?x?xi1>, %[[INIT:.+]]: tensor<?x?xi1>) -> tensor<?x?xi1>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.mul ins(%[[LHS]], %[[RHS]] : tensor<?x?xi1>, tensor<?x?xi1>) outs(%[[INIT]] : tensor<?x?xi1>) -> tensor<?x?xi1>
+
+func.func @specialize_sub_swapped_operands(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>, %init: tensor<?x?xf32>) -> tensor<?x?xf32> { // arith.subf with swapped operands -> linalg.sub with swapped ins
+ %res = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%lhs, %rhs : tensor<?x?xf32>, tensor<?x?xf32>) outs(%init : tensor<?x?xf32>) {
+ ^bb0(%a: f32, %b: f32, %acc: f32):
+ %diff = arith.subf %b, %a : f32
+ linalg.yield %diff : f32
+ } -> tensor<?x?xf32>
+ return %res : tensor<?x?xf32>
+}
+// CHECK-LABEL: specialize_sub_swapped_operands
+// CHECK-SAME: %[[LHS:.+]]: tensor<?x?xf32>, %[[RHS:.+]]: tensor<?x?xf32>, %[[INIT:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.sub ins(%[[RHS]], %[[LHS]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[INIT]] : tensor<?x?xf32>) -> tensor<?x?xf32>
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
More information about the Mlir-commits
mailing list