[Mlir-commits] [mlir] [MLIR][Linalg] Specialize more binary elementwise ops (PR #192290)

Julian Oppermann llvmlistbot at llvm.org
Fri May 22 01:06:10 PDT 2026


https://github.com/jopperm updated https://github.com/llvm/llvm-project/pull/192290

>From 825ef6c0de39dc8fa2868538f40e34fce50bb49d Mon Sep 17 00:00:00 2001
From: Julian Oppermann <julian.oppermann at intel.com>
Date: Wed, 15 Apr 2026 09:30:13 -0700
Subject: [PATCH] [MLIR][Linalg] Specialize more binary elementwise ops

---
 .../mlir/Dialect/Linalg/IR/LinalgInterfaces.h |   6 +-
 .../Dialect/Linalg/IR/LinalgInterfaces.cpp    |   7 +-
 .../Dialect/Linalg/Transforms/Specialize.cpp  | 252 +++++---
 .../Linalg/linalg-morph-multi-step.mlir       | 142 ++++-
 ...oundtrip-morphism-linalg-category-ops.mlir | 199 +++++++
 .../roundtrip-morphism-linalg-named-ops.mlir  | 136 ++++-
 .../Linalg/specialize-generic-ops.mlir        | 553 +++++++++++++++++-
 ...ansform-op-specialize-elemwise-binary.mlir | 239 +++++++-
 8 files changed, 1406 insertions(+), 128 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index c565781402e3c..3c7ebd8277dbd 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -148,7 +148,11 @@ bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp,
 
 /// Checks whether `genericOp` is semantically equivalent to a single linalg
 /// elementwise binary op e.g. linalg.sub.
-bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp);
+/// If `allowNonIdentityMaps` is true, ops whose indexing maps are not all
+/// identities are accepted too; such ops can only be specialized to the
+/// `linalg.elementwise` category op, not to a named op such as `linalg.sub`.
+bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp,
+                                        bool allowNonIdentityMaps = false);
 
 /// Checks whether `genericOp` is semantically equivalent to a `linalg.fill`.
 /// Supports two patterns:
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 2ba77cea8f16e..238bddcb3b2bd 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -279,9 +279,10 @@ bool linalg::isaElemwiseSingleUnaryOpInterface(linalg::GenericOp op,
   return true;
 }
 
-bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp op) {
-  if (!isaElemwiseSingleUnaryOrBinaryOpInterface(
-          op, 2, /*allowNonIdentityMaps=*/false))
+bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp op,
+                                                bool allowNonIdentityMaps) {
+  // Common unary/binary elementwise checks (arity 2), forwarding the map flag.
+  if (!isaElemwiseSingleUnaryOrBinaryOpInterface(op, 2, allowNonIdentityMaps))
     return false;
 
   // Check both inputs are used (elementwise).
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index a7cd57cf4ed9e..cc8a69cef8e5d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -28,13 +28,6 @@ namespace mlir {
 
 #define DEBUG_TYPE "linalg-specialization"
 
-#define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP)                                \
-  (rewriter.replaceOpWithNewOp<NEWOP>(                                         \
-      genericOp,                                                               \
-      ValueRange{genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 1 : 0],            \
-                 genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 0 : 1]},           \
-      ValueRange{genericOp.getDpsInits()[0]}))
-
 using namespace mlir;
 using namespace mlir::linalg;
 
@@ -42,7 +35,7 @@ using namespace mlir::linalg;
 // Specialize linalg generic to elementwise ops.
 //===----------------------------------------------------------------------===//
 
-// Given a elementwise single binary linalg generic op, checks whether the
+// Given an elementwise single binary linalg generic op, checks whether the
 // binary op accesses operands as swapped. e.g.
 // this differentiates between a linalg-generic body that contains:
 //    ^bb0(%a: f32, %b: f32, %c : f32):
@@ -66,8 +59,39 @@ static bool areBinOpsSwapped(GenericOp genericOp) {
   return swapped;
 }
 
-// Attempt to specialize linalg.generic to named elementwise ops or
-// linalg.elementwise.
+// Given an elementwise single unary linalg generic op whose body operation has
+// two operands, look for an operand that is not a block argument of the body,
+// i.e. a scalar captured from outside the generic op. If there is one, store
+// its operand position in `index` and return true; otherwise return false and
+// leave `index` untouched. The position is unique since the body op also uses
+// the input block argument (see `isaElemwiseSingleUnaryOpInterface`).
+//
+// Example:
+//   %cst = arith.constant 3.14 : f32
+//   %0 = linalg.generic { indexing_maps = [#mapA, #mapRes], ... }
+//          ins(%A : tensor<?xf32>) outs(...) {
+//   ^bb0(%a: f32, %out : f32):
+//     %0 = arith.mulf %a, %cst : f32
+//     linalg.yield %0: f32
+//   } -> tensor<?xf32>
+// Here, the returned index is 1, and the generic op can be represented as
+//   %0 = linalg.elementwise kind=#linalg.elementwise_kind<mul>
+//          indexing_maps = [#mapA, affine_map<(d0) -> ()>, #mapRes]
+//          ins(%A, %cst : tensor<?xf32>, f32) outs(...) -> tensor<?xf32>
+static bool findIndexOfScalarOperand(GenericOp genericOp, int &index) {
+  Block *body = genericOp.getBody();
+  for (OpOperand &operand : body->front().getOpOperands()) {
+    auto arg = dyn_cast<BlockArgument>(operand.get());
+    if (arg && arg.getOwner() == body)
+      continue;
+    index = operand.getOperandNumber();
+    return true;
+  }
+  return false;
+}
+
+// Attempt to specialize unary or binary linalg.generic ops to named elementwise
+// ops or linalg.elementwise.
 //
 // Example:
 //   %0 = linalg.generic {
@@ -87,9 +111,16 @@ static bool areBinOpsSwapped(GenericOp genericOp) {
 //
 // Only the category op can carry non-identity indexing maps; these are
 // transferred verbatim from the `genericOp`.
-static FailureOr<LinalgOp>
-specializeLinalgUnaryElementwise(RewriterBase &rewriter, GenericOp genericOp,
-                                 bool emitCategoryOp) {
+//
+// In addition to the canonical forms used by the generalization path, this
+// function can handle the following variations:
+//
+// 1) Swapped operands in binary ops (see the `areBinOpsSwapped` helper)
+// 2) Unary generic ops with a binary body op (see the
+//    `findIndexOfScalarOperand` helper)
+static FailureOr<LinalgOp> specializeLinalgElementwise(RewriterBase &rewriter,
+                                                       GenericOp genericOp,
+                                                       bool emitCategoryOp) {
   bool hasNonIdentityMaps =
       !llvm::all_of(genericOp.getIndexingMapsArray(),
                     [](AffineMap map) { return map.isIdentity(); });
@@ -100,62 +131,153 @@ specializeLinalgUnaryElementwise(RewriterBase &rewriter, GenericOp genericOp,
         genericOp,
         "non-identity indexing maps prevent specialization to named op");
 
+  // Classify the generic op.
+  bool isUnary = genericOp.getNumDpsInputs() == 1;
+  bool isBinary = genericOp.getNumDpsInputs() == 2;
+
+  // Will inspect the body operation to determine named op or elementwise kind.
+  Operation *op = &genericOp.getBody()->front();
+
+  // Detect variations from canonical forms.
+  bool hasSwappedOperands = isBinary && areBinOpsSwapped(genericOp);
+  int scalarOprIdx = -1;
+  bool hasScalarOperand = isUnary && op->getNumOperands() == 2 &&
+                          findIndexOfScalarOperand(genericOp, scalarOprIdx);
+
   // Helper to dispatch between named op and `linalg.elementwise`.
   // Lambdas with explicit template parameter list are a C++20 feature, hence
   // the dummy op object.
-  auto replaceUnaryOp = [&](auto namedOp, ElementwiseKind kind) -> LinalgOp {
+  auto replaceOp = [&](auto namedOp, ElementwiseKind kind,
+                       bool mayHoistScalarOperand = true) -> LinalgOp {
+    SmallVector<Value> inputs = genericOp.getDpsInputs();
+    if (hasSwappedOperands)
+      std::swap(inputs[0], inputs[1]);
+
     LinalgOp newOp;
-    if (!emitCategoryOp)
-      newOp = decltype(namedOp)::create(
-          rewriter, genericOp.getLoc(), genericOp.getDpsInputs(),
-          genericOp.getDpsInits(), ArrayRef<NamedAttribute>{});
-    else
+    if (!emitCategoryOp) {
+      using NamedOpTy = decltype(namedOp);
+      if constexpr (!std::is_null_pointer_v<NamedOpTy>)
+        newOp = NamedOpTy::create(rewriter, genericOp.getLoc(), inputs,
+                                  genericOp.getDpsInits(),
+                                  ArrayRef<NamedAttribute>{});
+      else
+        llvm_unreachable("Missing named op type");
+    } else {
+      SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
+      // Swap indexing maps, too.
+      if (hasSwappedOperands)
+        std::swap(indexingMaps[0], indexingMaps[1]);
+
+      // Represent unary generic op as a binary `linalg.elementwise` with a
+      // scalar operand and broadcasting map.
+      if (hasScalarOperand && mayHoistScalarOperand) {
+        // Insert the scalar and a map with one dim per loop and no results.
+        inputs.insert(inputs.begin() + scalarOprIdx,
+                      op->getOperand(scalarOprIdx));
+        auto scalarBroadcastMap =
+            AffineMap::get(genericOp.getNumLoops(), /*symbolCount=*/0,
+                           rewriter.getContext());
+        indexingMaps.insert(indexingMaps.begin() + scalarOprIdx,
+                            scalarBroadcastMap);
+      }
       newOp = ElementwiseOp::create(
-          rewriter, genericOp.getLoc(), genericOp.getDpsInputs(),
-          genericOp.getDpsInits(),
+          rewriter, genericOp.getLoc(), inputs, genericOp.getDpsInits(),
           ElementwiseKindAttr::get(rewriter.getContext(), kind),
-          genericOp.getIndexingMaps());
+          rewriter.getAffineMapArrayAttr(indexingMaps));
+    }
 
     rewriter.replaceOp(genericOp, newOp);
     return newOp;
   };
 
-  // Inspect body operation to determine named op or elementwise kind.
-  Operation *op = &genericOp.getBody()->front();
+  if (isUnary) {
+    if (isa<math::ExpOp>(op))
+      return replaceOp(ExpOp{}, ElementwiseKind::exp);
+    if (isa<math::LogOp>(op))
+      return replaceOp(LogOp{}, ElementwiseKind::log);
+    if (isa<math::AbsFOp>(op))
+      return replaceOp(AbsOp{}, ElementwiseKind::abs);
+    if (isa<math::CeilOp>(op))
+      return replaceOp(CeilOp{}, ElementwiseKind::ceil);
+    if (isa<math::FloorOp>(op))
+      return replaceOp(FloorOp{}, ElementwiseKind::floor);
+    if (isa<arith::NegFOp>(op))
+      return replaceOp(NegFOp{}, ElementwiseKind::negf);
+    if (auto divOp = dyn_cast<arith::DivFOp>(op)) {
+      // The constant need not carry a scalar `FloatAttr` (e.g. a dense splat
+      // for vector element types), so test with `dyn_cast`, not `cast`.
+      if (auto constOp = dyn_cast_if_present<arith::ConstantOp>(
+              divOp.getLhs().getDefiningOp()))
+        if (auto floatAttr = dyn_cast<FloatAttr>(constOp.getValue());
+            floatAttr && floatAttr.getValue().isExactlyValue(1.0))
+          return replaceOp(ReciprocalOp{}, ElementwiseKind::reciprocal,
+                           /*mayHoistScalarOperand=*/false);
+    }
+    if (isa<math::RoundOp>(op))
+      return replaceOp(RoundOp{}, ElementwiseKind::round);
+    if (isa<math::SqrtOp>(op))
+      return replaceOp(SqrtOp{}, ElementwiseKind::sqrt);
+    if (isa<math::RsqrtOp>(op))
+      return replaceOp(RsqrtOp{}, ElementwiseKind::rsqrt);
+    if (auto mulOp = dyn_cast<arith::MulFOp>(op);
+        mulOp && mulOp.getLhs() == mulOp.getRhs())
+      return replaceOp(SquareOp{}, ElementwiseKind::square);
+    if (isa<math::TanhOp>(op))
+      return replaceOp(TanhOp{}, ElementwiseKind::tanh);
+    if (isa<math::ErfOp>(op))
+      return replaceOp(ErfOp{}, ElementwiseKind::erf);
+
+    // At this point, we exhaustively checked the available unary named ops. The
+    // 1-input generic op might be representable as a `linalg.elementwise` that
+    // broadcasts a scalar operand. But if we can't emit the category op or
+    // don't have a scalar operand, exit now.
+    if (!emitCategoryOp || !hasScalarOperand)
+      return rewriter.notifyMatchFailure(
+          genericOp, "unary elementwise operation cannot be specialized to "
+                     "named or category op");
+  }
 
-  if (isa<math::ExpOp>(op))
-    return replaceUnaryOp(ExpOp{}, ElementwiseKind::exp);
-  if (isa<math::LogOp>(op))
-    return replaceUnaryOp(LogOp{}, ElementwiseKind::log);
-  if (isa<math::AbsFOp>(op))
-    return replaceUnaryOp(AbsOp{}, ElementwiseKind::abs);
-  if (isa<math::CeilOp>(op))
-    return replaceUnaryOp(CeilOp{}, ElementwiseKind::ceil);
-  if (isa<math::FloorOp>(op))
-    return replaceUnaryOp(FloorOp{}, ElementwiseKind::floor);
-  if (isa<arith::NegFOp>(op))
-    return replaceUnaryOp(NegFOp{}, ElementwiseKind::negf);
-  if (auto divOp = dyn_cast<arith::DivFOp>(op)) {
-    if (auto constOp = dyn_cast_if_present<arith::ConstantOp>(
-            divOp.getLhs().getDefiningOp()))
-      if (cast<FloatAttr>(constOp.getValue()).getValue().isExactlyValue(1.0))
-        return replaceUnaryOp(ReciprocalOp{}, ElementwiseKind::reciprocal);
+  // Boolean-typed `linalg.add` and `linalg.mul` require special handling: on
+  // i1 they correspond to `arith.ori` and `arith.andi`. `arith.addi` on i1
+  // computes XOR (not OR), so it must not become `linalg.add`; and there are
+  // no boolean forms of `linalg.sub`, `linalg.div` or `linalg.div_unsigned`.
+  // Reject these integer ops on i1.
+  bool allBool = llvm::all_of(op->getOperands(),
+                              [](Value v) { return v.getType().isInteger(1); });
+  if (allBool &&
+      isa<arith::AddIOp, arith::SubIOp, arith::DivSIOp, arith::DivUIOp>(op))
+    return rewriter.notifyMatchFailure(
+        genericOp, "integer op on i1 has no boolean linalg equivalent");
+
+  if (isa<arith::AddIOp, arith::AddFOp, complex::AddOp>(op) ||
+      (allBool && isa<arith::OrIOp>(op)))
+    return replaceOp(AddOp{}, ElementwiseKind::add);
+  if (isa<arith::SubIOp, arith::SubFOp, complex::SubOp>(op))
+    return replaceOp(SubOp{}, ElementwiseKind::sub);
+  if (isa<arith::MulIOp, arith::MulFOp, complex::MulOp>(op) ||
+      (allBool && isa<arith::AndIOp>(op)))
+    return replaceOp(MulOp{}, ElementwiseKind::mul);
+  if (isa<arith::DivSIOp, arith::DivFOp, complex::DivOp>(op))
+    return replaceOp(DivOp{}, ElementwiseKind::div);
+  if (isa<arith::DivUIOp>(op))
+    return replaceOp(DivUnsignedOp{}, ElementwiseKind::div_unsigned);
+  if (isa<arith::MaxSIOp, arith::MaximumFOp>(op))
+    return replaceOp(MaxOp{}, ElementwiseKind::max_signed);
+  if (isa<arith::MinSIOp, arith::MinimumFOp>(op))
+    return replaceOp(MinOp{}, ElementwiseKind::min_signed);
+  if (emitCategoryOp) {
+    // No named ops for unsigned maximum/minimum.
+    if (isa<arith::MaxUIOp>(op))
+      return replaceOp(nullptr, ElementwiseKind::max_unsigned);
+    if (isa<arith::MinUIOp>(op))
+      return replaceOp(nullptr, ElementwiseKind::min_unsigned);
   }
-  if (isa<math::RoundOp>(op))
-    return replaceUnaryOp(RoundOp{}, ElementwiseKind::round);
-  if (isa<math::SqrtOp>(op))
-    return replaceUnaryOp(SqrtOp{}, ElementwiseKind::sqrt);
-  if (isa<math::RsqrtOp>(op))
-    return replaceUnaryOp(RsqrtOp{}, ElementwiseKind::rsqrt);
-  if (auto mulOp = dyn_cast<arith::MulFOp>(op);
-      mulOp && mulOp.getLhs() == mulOp.getRhs())
-    return replaceUnaryOp(SquareOp{}, ElementwiseKind::square);
-  if (isa<math::TanhOp>(op))
-    return replaceUnaryOp(TanhOp{}, ElementwiseKind::tanh);
-  if (isa<math::ErfOp>(op))
-    return replaceUnaryOp(ErfOp{}, ElementwiseKind::erf);
+  if (isa<math::PowFOp>(op))
+    return replaceOp(PowFOp{}, ElementwiseKind::powf);
 
-  return failure();
+  return rewriter.notifyMatchFailure(
+      genericOp,
+      "elementwise operation cannot be specialized to named or category op");
 }
 
 //===----------------------------------------------------------------------===//
@@ -580,10 +691,11 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
 FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(
     RewriterBase &rewriter, GenericOp genericOp,
     const GenericOpSpecializationOptions &options) {
-  // Unary elementwise - e.g. exp
-  if (isaElemwiseSingleUnaryOpInterface(genericOp, options.emitCategoryOps)) {
-    return specializeLinalgUnaryElementwise(rewriter, genericOp,
-                                            options.emitCategoryOps);
+  // Unary/binary elementwise - e.g. exp, add. No fallthrough on failure.
+  if (isaElemwiseSingleUnaryOpInterface(genericOp, options.emitCategoryOps) ||
+      isaElemwiseSingleBinaryOpInterface(genericOp, options.emitCategoryOps)) {
+    return specializeLinalgElementwise(rewriter, genericOp,
+                                       options.emitCategoryOps);
   }
 
   // Contraction - e.g. matmul
@@ -636,28 +748,6 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(
     return namedOp;
   }
 
-  // Elementwise Binary
-  if (isaElemwiseSingleBinaryOpInterface(genericOp)) {
-    bool swap = areBinOpsSwapped(genericOp);
-    Operation *op = &genericOp.getBody()->front();
-    if (isa<arith::AddFOp>(op)) {
-      LinalgOp namedOp = REPLACE_BINARY_OP(AddOp, swap);
-      return namedOp;
-    }
-    if (isa<arith::SubFOp>(op)) {
-      LinalgOp namedOp = REPLACE_BINARY_OP(SubOp, swap);
-      return namedOp;
-    }
-    if (isa<arith::MulFOp>(op)) {
-      LinalgOp namedOp = REPLACE_BINARY_OP(MulOp, swap);
-      return namedOp;
-    }
-    if (isa<arith::DivFOp>(op)) {
-      LinalgOp namedOp = REPLACE_BINARY_OP(DivOp, swap);
-      return namedOp;
-    }
-  }
-
   // Convolution - e.g. *conv/pooling*
   if (isaConvolutionOpInterface(genericOp))
     return specializeLinalgConvolutions(rewriter, genericOp);
diff --git a/mlir/test/Dialect/Linalg/linalg-morph-multi-step.mlir b/mlir/test/Dialect/Linalg/linalg-morph-multi-step.mlir
index 7bad1b7a44d92..0a0ddbcd85a0a 100644
--- a/mlir/test/Dialect/Linalg/linalg-morph-multi-step.mlir
+++ b/mlir/test/Dialect/Linalg/linalg-morph-multi-step.mlir
@@ -1,6 +1,8 @@
-// RUN: mlir-opt %s -linalg-morph-ops=named-to-generic |  FileCheck %s  --check-prefix=NAMED_TO_GENERIC
-// RUN: mlir-opt %s -linalg-morph-ops=named-to-generic |  mlir-opt -linalg-morph-ops=generic-to-named | \
-// RUN:   FileCheck %s  --check-prefix=ROUND_TRIP
+// RUN: mlir-opt %s -linalg-morph-ops=named-to-generic -split-input-file | \
+// RUN:   FileCheck %s  --check-prefixes=ALL,NAMED_TO_GENERIC
+// RUN: mlir-opt %s -linalg-morph-ops=named-to-generic -split-input-file | \
+// RUN:   mlir-opt -linalg-morph-ops=generic-to-named -split-input-file | \
+// RUN:     FileCheck %s  --check-prefixes=ALL,ROUND_TRIP
 
 func.func @unary_ops(%A : tensor<16x8xf32>, %B : tensor<16x8xf32>) ->  tensor<16x8xf32> {
   %exp = linalg.exp ins(%A : tensor<16x8xf32>) outs(%B :  tensor<16x8xf32>) -> tensor<16x8xf32>
@@ -19,6 +21,8 @@ func.func @unary_ops(%A : tensor<16x8xf32>, %B : tensor<16x8xf32>) ->  tensor<16
   return %erf :  tensor<16x8xf32>
 }
 
+// ALL-LABEL: unary_ops
+
 // NAMED_TO_GENERIC-COUNT-13: linalg.generic
 // NAMED_TO_GENERIC-NOT: linalg.exp
 // NAMED_TO_GENERIC-NOT: linalg.log
@@ -48,3 +52,135 @@ func.func @unary_ops(%A : tensor<16x8xf32>, %B : tensor<16x8xf32>) ->  tensor<16
 // ROUND_TRIP: linalg.tanh
 // ROUND_TRIP: linalg.erf
 // ROUND_TRIP-NOT: linalg.generic
+
+// -----
+
+func.func @binary_ops_int(%A: tensor<?x?xi32>, %B: tensor<?x?xi32>,
+                          %Out: tensor<?x?xi32>) -> tensor<?x?xi32> {
+  %0 = linalg.add ins(%A, %B : tensor<?x?xi32>, tensor<?x?xi32>)
+                  outs(%Out : tensor<?x?xi32>) -> tensor<?x?xi32>
+  %1 = linalg.sub ins(%0, %B : tensor<?x?xi32>, tensor<?x?xi32>)
+                  outs(%Out : tensor<?x?xi32>) -> tensor<?x?xi32>
+  %2 = linalg.mul ins(%1, %B : tensor<?x?xi32>, tensor<?x?xi32>)
+                  outs(%Out : tensor<?x?xi32>) -> tensor<?x?xi32>
+  %3 = linalg.div ins(%2, %B : tensor<?x?xi32>, tensor<?x?xi32>)
+                  outs(%Out : tensor<?x?xi32>) -> tensor<?x?xi32>
+  %4 = linalg.div_unsigned ins(%3, %B : tensor<?x?xi32>, tensor<?x?xi32>)
+                  outs(%Out : tensor<?x?xi32>) -> tensor<?x?xi32>
+  %5 = linalg.max ins(%4, %B : tensor<?x?xi32>, tensor<?x?xi32>)
+                  outs(%Out : tensor<?x?xi32>) -> tensor<?x?xi32>
+  %6 = linalg.min ins(%5, %B : tensor<?x?xi32>, tensor<?x?xi32>)
+                  outs(%Out : tensor<?x?xi32>) -> tensor<?x?xi32>
+  return %6 : tensor<?x?xi32>
+}
+
+// ALL-LABEL: binary_ops_int
+
+// NAMED_TO_GENERIC-COUNT-7: linalg.generic
+// NAMED_TO_GENERIC-NOT: linalg.add
+// NAMED_TO_GENERIC-NOT: linalg.sub
+// NAMED_TO_GENERIC-NOT: linalg.mul
+// NAMED_TO_GENERIC-NOT: linalg.div
+// NAMED_TO_GENERIC-NOT: linalg.div_unsigned
+// NAMED_TO_GENERIC-NOT: linalg.max
+// NAMED_TO_GENERIC-NOT: linalg.min
+
+// ROUND_TRIP: linalg.add
+// ROUND_TRIP: linalg.sub
+// ROUND_TRIP: linalg.mul
+// ROUND_TRIP: linalg.div
+// ROUND_TRIP: linalg.div_unsigned
+// ROUND_TRIP: linalg.max
+// ROUND_TRIP: linalg.min
+// ROUND_TRIP-NOT: linalg.generic
+
+// -----
+
+func.func @binary_ops_float(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+                            %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.add ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+              outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %1 = linalg.sub ins(%0, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+              outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %2 = linalg.mul ins(%1, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+              outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %3 = linalg.div ins(%2, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+              outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %4 = linalg.max ins(%3, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+              outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %5 = linalg.min ins(%4, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+              outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %6 = linalg.powf ins(%5, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+              outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %6 : tensor<?x?xf32>
+}
+
+// ALL-LABEL: binary_ops_float
+
+// NAMED_TO_GENERIC-COUNT-7: linalg.generic
+// NAMED_TO_GENERIC-NOT: linalg.add
+// NAMED_TO_GENERIC-NOT: linalg.sub
+// NAMED_TO_GENERIC-NOT: linalg.mul
+// NAMED_TO_GENERIC-NOT: linalg.div
+// NAMED_TO_GENERIC-NOT: linalg.max
+// NAMED_TO_GENERIC-NOT: linalg.min
+// NAMED_TO_GENERIC-NOT: linalg.powf
+
+// ROUND_TRIP: linalg.add
+// ROUND_TRIP: linalg.sub
+// ROUND_TRIP: linalg.mul
+// ROUND_TRIP: linalg.div
+// ROUND_TRIP: linalg.max
+// ROUND_TRIP: linalg.min
+// ROUND_TRIP: linalg.powf
+// ROUND_TRIP-NOT: linalg.generic
+
+// -----
+
+func.func @binary_ops_complex(%A: tensor<?x?xcomplex<f32>>, %B: tensor<?x?xcomplex<f32>>,
+                              %Out: tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>> {
+  %0 = linalg.add ins(%A, %B : tensor<?x?xcomplex<f32>>, tensor<?x?xcomplex<f32>>)
+                  outs(%Out : tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>>
+  %1 = linalg.sub ins(%0, %B : tensor<?x?xcomplex<f32>>, tensor<?x?xcomplex<f32>>)
+                  outs(%Out : tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>>
+  %2 = linalg.mul ins(%1, %B : tensor<?x?xcomplex<f32>>, tensor<?x?xcomplex<f32>>)
+                  outs(%Out : tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>>
+  %3 = linalg.div ins(%2, %B : tensor<?x?xcomplex<f32>>, tensor<?x?xcomplex<f32>>)
+                  outs(%Out : tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>>
+  return %3 : tensor<?x?xcomplex<f32>>
+}
+
+// ALL-LABEL: binary_ops_complex
+
+// NAMED_TO_GENERIC-COUNT-4: linalg.generic
+// NAMED_TO_GENERIC-NOT: linalg.add
+// NAMED_TO_GENERIC-NOT: linalg.sub
+// NAMED_TO_GENERIC-NOT: linalg.mul
+// NAMED_TO_GENERIC-NOT: linalg.div
+
+// ROUND_TRIP: linalg.add
+// ROUND_TRIP: linalg.sub
+// ROUND_TRIP: linalg.mul
+// ROUND_TRIP: linalg.div
+// ROUND_TRIP-NOT: linalg.generic
+
+// -----
+
+func.func @binary_ops_bool(%A: tensor<?x?xi1>, %B: tensor<?x?xi1>,
+                           %Out: tensor<?x?xi1>) -> tensor<?x?xi1> {
+  %0 = linalg.add ins(%A, %B : tensor<?x?xi1>, tensor<?x?xi1>)
+                  outs(%Out : tensor<?x?xi1>) -> tensor<?x?xi1>
+  %1 = linalg.mul ins(%0, %B : tensor<?x?xi1>, tensor<?x?xi1>)
+                  outs(%Out : tensor<?x?xi1>) -> tensor<?x?xi1>
+  return %1 : tensor<?x?xi1>
+}
+
+// ALL-LABEL: binary_ops_bool
+
+// NAMED_TO_GENERIC-COUNT-2: linalg.generic
+// NAMED_TO_GENERIC-NOT: linalg.add
+// NAMED_TO_GENERIC-NOT: linalg.mul
+
+// ROUND_TRIP: linalg.add
+// ROUND_TRIP: linalg.mul
+// ROUND_TRIP-NOT: linalg.generic
diff --git a/mlir/test/Dialect/Linalg/roundtrip-morphism-linalg-category-ops.mlir b/mlir/test/Dialect/Linalg/roundtrip-morphism-linalg-category-ops.mlir
index e7bdad8e56883..7fd7bebdb5249 100644
--- a/mlir/test/Dialect/Linalg/roundtrip-morphism-linalg-category-ops.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip-morphism-linalg-category-ops.mlir
@@ -98,6 +98,205 @@ func.func @unary_ops_non_identity(%A: tensor<?xf32>, %Out: tensor<?x?xf32>) -> t
 // CHECK-SAME: ins(%[[A]] : tensor<?xf32>)
 // CHECK-SAME: outs(%[[OUT]] : tensor<?x?xf32>) -> tensor<?x?xf32>
 
+// -----
+
+func.func @binary_ops_int(%A: memref<10xi32>, %B: memref<10xi32>,
+                          %Out: memref<10xi32>) {
+  linalg.elementwise kind=#linalg.elementwise_kind<add>
+    ins(%A, %B : memref<10xi32>, memref<10xi32>) outs(%Out : memref<10xi32>)
+  linalg.elementwise kind=#linalg.elementwise_kind<sub>
+    ins(%A, %B : memref<10xi32>, memref<10xi32>) outs(%Out : memref<10xi32>)
+  linalg.elementwise kind=#linalg.elementwise_kind<mul>
+    ins(%A, %B : memref<10xi32>, memref<10xi32>) outs(%Out : memref<10xi32>)
+  linalg.elementwise kind=#linalg.elementwise_kind<div>
+    ins(%A, %B : memref<10xi32>, memref<10xi32>) outs(%Out : memref<10xi32>)
+  linalg.elementwise kind=#linalg.elementwise_kind<div_unsigned>
+    ins(%A, %B : memref<10xi32>, memref<10xi32>) outs(%Out : memref<10xi32>)
+  linalg.elementwise kind=#linalg.elementwise_kind<max_signed>
+    ins(%A, %B : memref<10xi32>, memref<10xi32>) outs(%Out : memref<10xi32>)
+  linalg.elementwise kind=#linalg.elementwise_kind<min_signed>
+    ins(%A, %B : memref<10xi32>, memref<10xi32>) outs(%Out : memref<10xi32>)
+  return
+}
+
+// CHECK-LABEL: binary_ops_int
+// CHECK-SAME: %[[A:.+]]: memref<10xi32>, %[[B:.+]]: memref<10xi32>,
+// CHECK-SAME: %[[OUT:.+]]: memref<10xi32>)
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<add>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xi32>, memref<10xi32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xi32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<sub>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xi32>, memref<10xi32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xi32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<mul>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xi32>, memref<10xi32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xi32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<div>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xi32>, memref<10xi32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xi32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<div_unsigned>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xi32>, memref<10xi32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xi32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<max_signed>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xi32>, memref<10xi32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xi32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<min_signed>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xi32>, memref<10xi32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xi32>)
+
+// -----
+
+func.func @binary_ops_float(%A: memref<10xf32>, %B: memref<10xf32>,
+                            %Out: memref<10xf32>) {
+  linalg.elementwise kind=#linalg.elementwise_kind<add>
+    ins(%A, %B : memref<10xf32>, memref<10xf32>) outs(%Out : memref<10xf32>)
+  linalg.elementwise kind=#linalg.elementwise_kind<sub>
+    ins(%A, %B : memref<10xf32>, memref<10xf32>) outs(%Out : memref<10xf32>)
+  linalg.elementwise kind=#linalg.elementwise_kind<mul>
+    ins(%A, %B : memref<10xf32>, memref<10xf32>) outs(%Out : memref<10xf32>)
+  linalg.elementwise kind=#linalg.elementwise_kind<div>
+    ins(%A, %B : memref<10xf32>, memref<10xf32>) outs(%Out : memref<10xf32>)
+  linalg.elementwise kind=#linalg.elementwise_kind<max_signed>
+    ins(%A, %B : memref<10xf32>, memref<10xf32>) outs(%Out : memref<10xf32>)
+  linalg.elementwise kind=#linalg.elementwise_kind<min_signed>
+    ins(%A, %B : memref<10xf32>, memref<10xf32>) outs(%Out : memref<10xf32>)
+  linalg.elementwise kind=#linalg.elementwise_kind<powf>
+    ins(%A, %B : memref<10xf32>, memref<10xf32>) outs(%Out : memref<10xf32>)
+  return
+}
+
+// CHECK-LABEL: binary_ops_float
+// CHECK-SAME: %[[A:.+]]: memref<10xf32>, %[[B:.+]]: memref<10xf32>,
+// CHECK-SAME: %[[OUT:.+]]: memref<10xf32>)
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<add>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xf32>, memref<10xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xf32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<sub>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xf32>, memref<10xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xf32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<mul>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xf32>, memref<10xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xf32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<div>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xf32>, memref<10xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xf32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<max_signed>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xf32>, memref<10xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xf32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<min_signed>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xf32>, memref<10xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xf32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<powf>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xf32>, memref<10xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xf32>)
+
+// -----
+
+func.func @binary_ops_complex(%A: memref<10xcomplex<f32>>, %B: memref<10xcomplex<f32>>,
+                              %Out: memref<10xcomplex<f32>>) {
+  linalg.elementwise kind=#linalg.elementwise_kind<add>
+    ins(%A, %B : memref<10xcomplex<f32>>, memref<10xcomplex<f32>>)
+    outs(%Out : memref<10xcomplex<f32>>)
+  linalg.elementwise kind=#linalg.elementwise_kind<sub>
+    ins(%A, %B : memref<10xcomplex<f32>>, memref<10xcomplex<f32>>)
+    outs(%Out : memref<10xcomplex<f32>>)
+  linalg.elementwise kind=#linalg.elementwise_kind<mul>
+    ins(%A, %B : memref<10xcomplex<f32>>, memref<10xcomplex<f32>>)
+    outs(%Out : memref<10xcomplex<f32>>)
+  linalg.elementwise kind=#linalg.elementwise_kind<div>
+    ins(%A, %B : memref<10xcomplex<f32>>, memref<10xcomplex<f32>>)
+    outs(%Out : memref<10xcomplex<f32>>)
+  return
+}
+
+// CHECK-LABEL: binary_ops_complex
+// CHECK-SAME: %[[A:.+]]: memref<10xcomplex<f32>>, %[[B:.+]]: memref<10xcomplex<f32>>,
+// CHECK-SAME: %[[OUT:.+]]: memref<10xcomplex<f32>>)
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<add>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xcomplex<f32>>, memref<10xcomplex<f32>>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xcomplex<f32>>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<sub>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xcomplex<f32>>, memref<10xcomplex<f32>>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xcomplex<f32>>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<mul>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xcomplex<f32>>, memref<10xcomplex<f32>>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xcomplex<f32>>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<div>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xcomplex<f32>>, memref<10xcomplex<f32>>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xcomplex<f32>>)
+
+// -----
+
+func.func @binary_ops_bool(%A: memref<10xi1>, %B: memref<10xi1>,
+                           %Out: memref<10xi1>) {
+  linalg.elementwise kind=#linalg.elementwise_kind<add>
+    ins(%A, %B : memref<10xi1>, memref<10xi1>) outs(%Out : memref<10xi1>)
+  linalg.elementwise kind=#linalg.elementwise_kind<mul>
+    ins(%A, %B : memref<10xi1>, memref<10xi1>) outs(%Out : memref<10xi1>)
+  return
+}
+
+// CHECK-LABEL: binary_ops_bool
+// CHECK-SAME: %[[A:.+]]: memref<10xi1>, %[[B:.+]]: memref<10xi1>,
+// CHECK-SAME: %[[OUT:.+]]: memref<10xi1>)
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<add>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xi1>, memref<10xi1>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xi1>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<mul>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xi1>, memref<10xi1>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xi1>)
+
+// -----
+
+func.func @binary_ops_uint(%A: memref<10xi32>, %B: memref<10xi32>,
+                           %Out: memref<10xi32>) {
+  linalg.elementwise kind=#linalg.elementwise_kind<max_unsigned>
+    ins(%A, %B : memref<10xi32>, memref<10xi32>) outs(%Out : memref<10xi32>)
+  linalg.elementwise kind=#linalg.elementwise_kind<min_unsigned>
+    ins(%A, %B : memref<10xi32>, memref<10xi32>) outs(%Out : memref<10xi32>)
+  return
+}
+
+// CHECK-LABEL: binary_ops_uint
+// CHECK-SAME: %[[A:.+]]: memref<10xi32>, %[[B:.+]]: memref<10xi32>,
+// CHECK-SAME: %[[OUT:.+]]: memref<10xi32>)
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<max_unsigned>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xi32>, memref<10xi32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xi32>)
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<min_unsigned>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xi32>, memref<10xi32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xi32>)
+
+// -----
+
+func.func @binary_ops_non_identity(%A: tensor<?xf32>, %B: tensor<?x?xf32>,
+                                   %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.elementwise
+    kind=#linalg.elementwise_kind<add>
+    indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>,
+                     affine_map<(d0, d1) -> (d0, d1)>]
+    ins(%A, %B : tensor<?xf32>, tensor<?x?xf32>)
+    outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK-DAG: #[[MAP_BC:.+]] = affine_map<(d0, d1) -> (d1)>
+// CHECK-DAG: #[[MAP_TP:.+]] = affine_map<(d0, d1) -> (d1, d0)>
+// CHECK-DAG: #[[MAP_ID:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: binary_ops_non_identity
+// CHECK-SAME: %[[A:.+]]: tensor<?xf32>, %[[B:.+]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[OUT:.+]]: tensor<?x?xf32>)
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<add>
+// CHECK-SAME: indexing_maps = [#[[MAP_BC]], #[[MAP_TP]], #[[MAP_ID]]]
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?xf32>, tensor<?x?xf32>)
+// CHECK-SAME: outs(%[[OUT]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
 // -----
 
 #map = affine_map<(d0, d1, d2) -> (d0, d2)>
diff --git a/mlir/test/Dialect/Linalg/roundtrip-morphism-linalg-named-ops.mlir b/mlir/test/Dialect/Linalg/roundtrip-morphism-linalg-named-ops.mlir
index 69a1a7f650810..604077f6f7834 100644
--- a/mlir/test/Dialect/Linalg/roundtrip-morphism-linalg-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip-morphism-linalg-named-ops.mlir
@@ -66,22 +66,136 @@ func.func @unary_ops(%A: memref<7x14x21xf32>, %Out: memref<7x14x21xf32>) {
 // CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
 
 // -----
+
+func.func @binary_ops_int(%A: memref<10xi32>, %B: memref<10xi32>,
+                          %Out: memref<10xi32>) {
+  linalg.add ins(%A, %B : memref<10xi32>, memref<10xi32>) outs(%Out : memref<10xi32>)
+  linalg.sub ins(%A, %B : memref<10xi32>, memref<10xi32>) outs(%Out : memref<10xi32>)
+  linalg.mul ins(%A, %B : memref<10xi32>, memref<10xi32>) outs(%Out : memref<10xi32>)
+  linalg.div ins(%A, %B : memref<10xi32>, memref<10xi32>) outs(%Out : memref<10xi32>)
+  linalg.div_unsigned ins(%A, %B : memref<10xi32>, memref<10xi32>) outs(%Out : memref<10xi32>)
+  linalg.max ins(%A, %B : memref<10xi32>, memref<10xi32>) outs(%Out : memref<10xi32>)
+  linalg.min ins(%A, %B : memref<10xi32>, memref<10xi32>) outs(%Out : memref<10xi32>)
+  return
+}
 
-func.func @binary_add(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
-    %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
-  %0 = linalg.add
-    ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
-    outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
-  return %0 : tensor<?x?xf32>
+// CHECK-LABEL: binary_ops_int
+// CHECK-SAME: %[[A:.+]]: memref<10xi32>, %[[B:.+]]: memref<10xi32>,
+// CHECK-SAME: %[[OUT:.+]]: memref<10xi32>)
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.add
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xi32>, memref<10xi32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xi32>)
+// CHECK: linalg.sub
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xi32>, memref<10xi32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xi32>)
+// CHECK: linalg.mul
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xi32>, memref<10xi32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xi32>)
+// CHECK: linalg.div
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xi32>, memref<10xi32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xi32>)
+// CHECK: linalg.div_unsigned
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xi32>, memref<10xi32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xi32>)
+// CHECK: linalg.max
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xi32>, memref<10xi32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xi32>)
+// CHECK: linalg.min
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xi32>, memref<10xi32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xi32>)
+
+// -----
+
+func.func @binary_ops_float(%A: memref<10xf32>, %B: memref<10xf32>,
+                            %Out: memref<10xf32>) {
+  linalg.add ins(%A, %B : memref<10xf32>, memref<10xf32>) outs(%Out : memref<10xf32>)
+  linalg.sub ins(%A, %B : memref<10xf32>, memref<10xf32>) outs(%Out : memref<10xf32>)
+  linalg.mul ins(%A, %B : memref<10xf32>, memref<10xf32>) outs(%Out : memref<10xf32>)
+  linalg.div ins(%A, %B : memref<10xf32>, memref<10xf32>) outs(%Out : memref<10xf32>)
+  linalg.max ins(%A, %B : memref<10xf32>, memref<10xf32>) outs(%Out : memref<10xf32>)
+  linalg.min ins(%A, %B : memref<10xf32>, memref<10xf32>) outs(%Out : memref<10xf32>)
+  linalg.powf ins(%A, %B : memref<10xf32>, memref<10xf32>) outs(%Out : memref<10xf32>)
+  return
 }
 
-// CHECK-LABEL: binary_add
-// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>,
-// CHECK-SAME: %[[OUT:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-LABEL: binary_ops_float
+// CHECK-SAME: %[[A:.+]]: memref<10xf32>, %[[B:.+]]: memref<10xf32>,
+// CHECK-SAME: %[[OUT:.+]]: memref<10xf32>)
 // CHECK-NOT: linalg.generic
 // CHECK: linalg.add
-// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>)
-// CHECK-SAME: outs(%[[OUT]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xf32>, memref<10xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xf32>)
+// CHECK: linalg.sub
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xf32>, memref<10xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xf32>)
+// CHECK: linalg.mul
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xf32>, memref<10xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xf32>)
+// CHECK: linalg.div
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xf32>, memref<10xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xf32>)
+// CHECK: linalg.max
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xf32>, memref<10xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xf32>)
+// CHECK: linalg.min
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xf32>, memref<10xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xf32>)
+// CHECK: linalg.powf
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xf32>, memref<10xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xf32>)
+
+// -----
+
+func.func @binary_ops_complex(%A: memref<10xcomplex<f32>>, %B: memref<10xcomplex<f32>>,
+                              %Out: memref<10xcomplex<f32>>) {
+  linalg.add ins(%A, %B : memref<10xcomplex<f32>>, memref<10xcomplex<f32>>)
+             outs(%Out : memref<10xcomplex<f32>>)
+  linalg.sub ins(%A, %B : memref<10xcomplex<f32>>, memref<10xcomplex<f32>>)
+             outs(%Out : memref<10xcomplex<f32>>)
+  linalg.mul ins(%A, %B : memref<10xcomplex<f32>>, memref<10xcomplex<f32>>)
+             outs(%Out : memref<10xcomplex<f32>>)
+  linalg.div ins(%A, %B : memref<10xcomplex<f32>>, memref<10xcomplex<f32>>)
+             outs(%Out : memref<10xcomplex<f32>>)
+  return
+}
+
+// CHECK-LABEL: binary_ops_complex
+// CHECK-SAME: %[[A:.+]]: memref<10xcomplex<f32>>, %[[B:.+]]: memref<10xcomplex<f32>>,
+// CHECK-SAME: %[[OUT:.+]]: memref<10xcomplex<f32>>)
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.add
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xcomplex<f32>>, memref<10xcomplex<f32>>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xcomplex<f32>>)
+// CHECK: linalg.sub
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xcomplex<f32>>, memref<10xcomplex<f32>>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xcomplex<f32>>)
+// CHECK: linalg.mul
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xcomplex<f32>>, memref<10xcomplex<f32>>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xcomplex<f32>>)
+// CHECK: linalg.div
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xcomplex<f32>>, memref<10xcomplex<f32>>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xcomplex<f32>>)
+
+// -----
+
+func.func @binary_ops_bool(%A: memref<10xi1>, %B: memref<10xi1>,
+                           %Out: memref<10xi1>) {
+  linalg.add ins(%A, %B : memref<10xi1>, memref<10xi1>) outs(%Out : memref<10xi1>)
+  linalg.mul ins(%A, %B : memref<10xi1>, memref<10xi1>) outs(%Out : memref<10xi1>)
+  return
+}
+
+// CHECK-LABEL: binary_ops_bool
+// CHECK-SAME: %[[A:.+]]: memref<10xi1>, %[[B:.+]]: memref<10xi1>,
+// CHECK-SAME: %[[OUT:.+]]: memref<10xi1>)
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.add
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xi1>, memref<10xi1>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xi1>)
+// CHECK: linalg.mul
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<10xi1>, memref<10xi1>)
+// CHECK-SAME: outs(%[[OUT]] : memref<10xi1>)
 
 // -----
 
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
index 1cca2b86ddc25..3d6c2962731c9 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -221,31 +221,558 @@ func.func @unary_ops_non_identity(%A: tensor<?xf32>, %Out: tensor<?x?xf32>) -> t
 // -----
 
 #map = affine_map<(d0, d1) -> (d0, d1)>
-func.func @binary_op_div(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
-                         %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+func.func @binary_ops_int(%A: tensor<?x?xi32>, %B: tensor<?x?xi32>,
+                          %Out: tensor<?x?xi32>) -> tensor<?x?xi32> {
+  %0 = linalg.generic
+    {indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%A, %B : tensor<?x?xi32>, tensor<?x?xi32>)
+    outs(%Out : tensor<?x?xi32>) {
+  ^bb0(%in: i32, %in_0: i32, %out: i32):
+    %v = arith.addi %in, %in_0 : i32
+    linalg.yield %v : i32
+  } -> tensor<?x?xi32>
+  %1 = linalg.generic
+    {indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%0, %B : tensor<?x?xi32>, tensor<?x?xi32>)
+    outs(%Out : tensor<?x?xi32>) {
+  ^bb0(%in: i32, %in_0: i32, %out: i32):
+    %v = arith.subi %in, %in_0 : i32
+    linalg.yield %v : i32
+  } -> tensor<?x?xi32>
+  %2 = linalg.generic
+    {indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%1, %B : tensor<?x?xi32>, tensor<?x?xi32>)
+    outs(%Out : tensor<?x?xi32>) {
+  ^bb0(%in: i32, %in_0: i32, %out: i32):
+    %v = arith.muli %in, %in_0 : i32
+    linalg.yield %v : i32
+  } -> tensor<?x?xi32>
+  %3 = linalg.generic
+    {indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%2, %B : tensor<?x?xi32>, tensor<?x?xi32>)
+    outs(%Out : tensor<?x?xi32>) {
+  ^bb0(%in: i32, %in_0: i32, %out: i32):
+    %v = arith.divsi %in, %in_0 : i32  // signed division -> linalg.div
+    linalg.yield %v : i32
+  } -> tensor<?x?xi32>
+  %4 = linalg.generic
+    {indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%3, %B : tensor<?x?xi32>, tensor<?x?xi32>)
+    outs(%Out : tensor<?x?xi32>) {
+  ^bb0(%in: i32, %in_0: i32, %out: i32):
+    %v = arith.divui %in, %in_0 : i32  // unsigned division -> linalg.div_unsigned
+    linalg.yield %v : i32
+  } -> tensor<?x?xi32>
+  %5 = linalg.generic
+    {indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%4, %B : tensor<?x?xi32>, tensor<?x?xi32>)
+    outs(%Out : tensor<?x?xi32>) {
+  ^bb0(%in: i32, %in_0: i32, %out: i32):
+    %v = arith.maxsi %in, %in_0 : i32  // signed max -> linalg.max / max_signed kind
+    linalg.yield %v : i32
+  } -> tensor<?x?xi32>
+  %6 = linalg.generic
+    {indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%5, %B : tensor<?x?xi32>, tensor<?x?xi32>)
+    outs(%Out : tensor<?x?xi32>) {
+  ^bb0(%in: i32, %in_0: i32, %out: i32):
+    %v = arith.minsi %in, %in_0 : i32  // signed min -> linalg.min / min_signed kind
+    linalg.yield %v : i32
+  } -> tensor<?x?xi32>
+  return %6 : tensor<?x?xi32>
+}
+
+// ALL-LABEL: binary_ops_int
+// ALL-SAME: %[[A:.+]]: [[TTY:tensor<\?x\?xi32>]], %[[B:.+]]: [[TTY]],
+// ALL-SAME: %[[OUT:.+]]: [[TTY]]) -> [[TTY]]
+
+// NAMED-NOT: linalg.generic
+// NAMED: %[[RES0:.+]] = linalg.add
+// NAMED-SAME: ins(%[[A]], %[[B]] : [[TTY]], [[TTY]])
+// NAMED-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// NAMED: %[[RES1:.+]] = linalg.sub
+// NAMED-SAME: ins(%[[RES0]], %[[B]] : [[TTY]], [[TTY]])
+// NAMED-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// NAMED: %[[RES2:.+]] = linalg.mul
+// NAMED-SAME: ins(%[[RES1]], %[[B]] : [[TTY]], [[TTY]])
+// NAMED-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// NAMED: %[[RES3:.+]] = linalg.div
+// NAMED-SAME: ins(%[[RES2]], %[[B]] : [[TTY]], [[TTY]])
+// NAMED-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// NAMED: %[[RES4:.+]] = linalg.div_unsigned
+// NAMED-SAME: ins(%[[RES3]], %[[B]] : [[TTY]], [[TTY]])
+// NAMED-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// NAMED: %[[RES5:.+]] = linalg.max
+// NAMED-SAME: ins(%[[RES4]], %[[B]] : [[TTY]], [[TTY]])
+// NAMED-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// NAMED: %[[RES6:.+]] = linalg.min
+// NAMED-SAME: ins(%[[RES5]], %[[B]] : [[TTY]], [[TTY]])
+// NAMED-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: %[[RES0:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
+// CATEGORY-SAME: ins(%[[A]], %[[B]] : [[TTY]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// CATEGORY: %[[RES1:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<sub>
+// CATEGORY-SAME: ins(%[[RES0]], %[[B]] : [[TTY]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// CATEGORY: %[[RES2:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<mul>
+// CATEGORY-SAME: ins(%[[RES1]], %[[B]] : [[TTY]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// CATEGORY: %[[RES3:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<div>
+// CATEGORY-SAME: ins(%[[RES2]], %[[B]] : [[TTY]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// CATEGORY: %[[RES4:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<div_unsigned>
+// CATEGORY-SAME: ins(%[[RES3]], %[[B]] : [[TTY]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// CATEGORY: %[[RES5:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<max_signed>
+// CATEGORY-SAME: ins(%[[RES4]], %[[B]] : [[TTY]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// CATEGORY: %[[RES6:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<min_signed>
+// CATEGORY-SAME: ins(%[[RES5]], %[[B]] : [[TTY]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @binary_ops_float(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+                            %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
   %0 = linalg.generic
     {indexing_maps = [#map, #map, #map],
     iterator_types = ["parallel", "parallel"]}
     ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
     outs(%Out : tensor<?x?xf32>) {
   ^bb0(%in: f32, %in_0: f32, %out: f32):
-    %1 = arith.divf %in, %in_0 : f32
-    linalg.yield %1 : f32
+    %v = arith.addf %in, %in_0 : f32
+    linalg.yield %v : f32
   } -> tensor<?x?xf32>
-  return %0 : tensor<?x?xf32>
+  %1 = linalg.generic
+    {indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%0, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%Out : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %v = arith.subf %in, %in_0 : f32
+    linalg.yield %v : f32
+  } -> tensor<?x?xf32>
+  %2 = linalg.generic
+    {indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%1, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%Out : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %v = arith.mulf %in, %in_0 : f32
+    linalg.yield %v : f32
+  } -> tensor<?x?xf32>
+  %3 = linalg.generic
+    {indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%2, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%Out : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %v = arith.divf %in, %in_0 : f32
+    linalg.yield %v : f32
+  } -> tensor<?x?xf32>
+  %4 = linalg.generic
+    {indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%3, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%Out : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %v = arith.maximumf %in, %in_0 : f32  // float maximumf is specialized to the max_signed kind
+    linalg.yield %v : f32
+  } -> tensor<?x?xf32>
+  %5 = linalg.generic
+    {indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%4, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%Out : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %v = arith.minimumf %in, %in_0 : f32  // float minimumf is specialized to the min_signed kind
+    linalg.yield %v : f32
+  } -> tensor<?x?xf32>
+  %6 = linalg.generic
+    {indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%5, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%Out : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %v = math.powf %in, %in_0 : f32  // math.powf -> linalg.powf / powf kind
+    linalg.yield %v : f32
+  } -> tensor<?x?xf32>
+  return %6 : tensor<?x?xf32>
 }
 
-// ALL-LABEL: binary_op_div
-// ALL-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>,
-// ALL-SAME: %[[OUT:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// ALL-LABEL: binary_ops_float
+// ALL-SAME: %[[A:.+]]: [[TTY:tensor<\?x\?xf32>]], %[[B:.+]]: [[TTY]],
+// ALL-SAME: %[[OUT:.+]]: [[TTY]]) -> [[TTY]]
 
 // NAMED-NOT: linalg.generic
-// NAMED: linalg.div
-// NAMED-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>)
-// NAMED-SAME: outs(%[[OUT]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// NAMED: %[[RES0:.+]] = linalg.add
+// NAMED-SAME: ins(%[[A]], %[[B]] : [[TTY]], [[TTY]])
+// NAMED-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// NAMED: %[[RES1:.+]] = linalg.sub
+// NAMED-SAME: ins(%[[RES0]], %[[B]] : [[TTY]], [[TTY]])
+// NAMED-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// NAMED: %[[RES2:.+]] = linalg.mul
+// NAMED-SAME: ins(%[[RES1]], %[[B]] : [[TTY]], [[TTY]])
+// NAMED-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// NAMED: %[[RES3:.+]] = linalg.div
+// NAMED-SAME: ins(%[[RES2]], %[[B]] : [[TTY]], [[TTY]])
+// NAMED-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// NAMED: %[[RES4:.+]] = linalg.max
+// NAMED-SAME: ins(%[[RES3]], %[[B]] : [[TTY]], [[TTY]])
+// NAMED-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// NAMED: %[[RES5:.+]] = linalg.min
+// NAMED-SAME: ins(%[[RES4]], %[[B]] : [[TTY]], [[TTY]])
+// NAMED-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// NAMED: %[[RES6:.+]] = linalg.powf
+// NAMED-SAME: ins(%[[RES5]], %[[B]] : [[TTY]], [[TTY]])
+// NAMED-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
 
-// Not supported yet.
-// CATEGORY: linalg.generic
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: %[[RES0:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
+// CATEGORY-SAME: ins(%[[A]], %[[B]] : [[TTY]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// CATEGORY: %[[RES1:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<sub>
+// CATEGORY-SAME: ins(%[[RES0]], %[[B]] : [[TTY]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// CATEGORY: %[[RES2:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<mul>
+// CATEGORY-SAME: ins(%[[RES1]], %[[B]] : [[TTY]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// CATEGORY: %[[RES3:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<div>
+// CATEGORY-SAME: ins(%[[RES2]], %[[B]] : [[TTY]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// CATEGORY: %[[RES4:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<max_signed>
+// CATEGORY-SAME: ins(%[[RES3]], %[[B]] : [[TTY]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// CATEGORY: %[[RES5:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<min_signed>
+// CATEGORY-SAME: ins(%[[RES4]], %[[B]] : [[TTY]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// CATEGORY: %[[RES6:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<powf>
+// CATEGORY-SAME: ins(%[[RES5]], %[[B]] : [[TTY]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @binary_ops_complex(%A: tensor<?x?xcomplex<f32>>,
+                              %B: tensor<?x?xcomplex<f32>>,
+                              %Out: tensor<?x?xcomplex<f32>>)
+                               -> tensor<?x?xcomplex<f32>> {
+  %0 = linalg.generic
+    {indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%A, %B : tensor<?x?xcomplex<f32>>, tensor<?x?xcomplex<f32>>)
+    outs(%Out : tensor<?x?xcomplex<f32>>) {
+  ^bb0(%in: complex<f32>, %in_0: complex<f32>, %out: complex<f32>):
+    %v = complex.add %in, %in_0 : complex<f32>  // complex.* ops map to the same add/sub/mul/div specializations
+    linalg.yield %v : complex<f32>
+  } -> tensor<?x?xcomplex<f32>>
+  %1 = linalg.generic
+    {indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%0, %B : tensor<?x?xcomplex<f32>>, tensor<?x?xcomplex<f32>>)
+    outs(%Out : tensor<?x?xcomplex<f32>>) {
+  ^bb0(%in: complex<f32>, %in_0: complex<f32>, %out: complex<f32>):
+    %v = complex.sub %in, %in_0 : complex<f32>
+    linalg.yield %v : complex<f32>
+  } -> tensor<?x?xcomplex<f32>>
+  %2 = linalg.generic
+    {indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%1, %B : tensor<?x?xcomplex<f32>>, tensor<?x?xcomplex<f32>>)
+    outs(%Out : tensor<?x?xcomplex<f32>>) {
+  ^bb0(%in: complex<f32>, %in_0: complex<f32>, %out: complex<f32>):
+    %v = complex.mul %in, %in_0 : complex<f32>
+    linalg.yield %v : complex<f32>
+  } -> tensor<?x?xcomplex<f32>>
+  %3 = linalg.generic
+    {indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%2, %B : tensor<?x?xcomplex<f32>>, tensor<?x?xcomplex<f32>>)
+    outs(%Out : tensor<?x?xcomplex<f32>>) {
+  ^bb0(%in: complex<f32>, %in_0: complex<f32>, %out: complex<f32>):
+    %v = complex.div %in, %in_0 : complex<f32>
+    linalg.yield %v : complex<f32>
+  } -> tensor<?x?xcomplex<f32>>
+  return %3 : tensor<?x?xcomplex<f32>>
+}
+
+// ALL-LABEL: binary_ops_complex
+// ALL-SAME: %[[A:.+]]: [[TTY:tensor<\?x\?xcomplex<f32>>]], %[[B:.+]]: [[TTY]],
+// ALL-SAME: %[[OUT:.+]]: [[TTY]]) -> [[TTY]]
+
+// NAMED-NOT: linalg.generic
+// NAMED: %[[RES0:.+]] = linalg.add
+// NAMED-SAME: ins(%[[A]], %[[B]] : [[TTY]], [[TTY]])
+// NAMED-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// NAMED: %[[RES1:.+]] = linalg.sub
+// NAMED-SAME: ins(%[[RES0]], %[[B]] : [[TTY]], [[TTY]])
+// NAMED-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// NAMED: %[[RES2:.+]] = linalg.mul
+// NAMED-SAME: ins(%[[RES1]], %[[B]] : [[TTY]], [[TTY]])
+// NAMED-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// NAMED: %[[RES3:.+]] = linalg.div
+// NAMED-SAME: ins(%[[RES2]], %[[B]] : [[TTY]], [[TTY]])
+// NAMED-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: %[[RES0:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
+// CATEGORY-SAME: ins(%[[A]], %[[B]] : [[TTY]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// CATEGORY: %[[RES1:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<sub>
+// CATEGORY-SAME: ins(%[[RES0]], %[[B]] : [[TTY]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// CATEGORY: %[[RES2:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<mul>
+// CATEGORY-SAME: ins(%[[RES1]], %[[B]] : [[TTY]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// CATEGORY: %[[RES3:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<div>
+// CATEGORY-SAME: ins(%[[RES2]], %[[B]] : [[TTY]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @binary_ops_bool(%A: tensor<?x?xi1>, %B: tensor<?x?xi1>,
+                           %Out: tensor<?x?xi1>) -> tensor<?x?xi1> {
+  %0 = linalg.generic
+    {indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%A, %B : tensor<?x?xi1>, tensor<?x?xi1>)
+    outs(%Out : tensor<?x?xi1>) {
+  ^bb0(%in: i1, %in_0: i1, %out: i1):
+    %v = arith.ori %in, %in_0 : i1  // on i1, ori is specialized as add
+    linalg.yield %v : i1
+  } -> tensor<?x?xi1>
+  %1 = linalg.generic
+    {indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%0, %B : tensor<?x?xi1>, tensor<?x?xi1>)
+    outs(%Out : tensor<?x?xi1>) {
+  ^bb0(%in: i1, %in_0: i1, %out: i1):
+    %v = arith.andi %in, %in_0 : i1  // on i1, andi is specialized as mul
+    linalg.yield %v : i1
+  } -> tensor<?x?xi1>
+  return %1 : tensor<?x?xi1>
+}
+
+// ALL-LABEL: binary_ops_bool
+// ALL-SAME: %[[A:.+]]: [[TTY:tensor<\?x\?xi1>]], %[[B:.+]]: [[TTY]],
+// ALL-SAME: %[[OUT:.+]]: [[TTY]]) -> [[TTY]]
+
+// NAMED-NOT: linalg.generic
+// NAMED: %[[RES0:.+]] = linalg.add
+// NAMED-SAME: ins(%[[A]], %[[B]] : [[TTY]], [[TTY]])
+// NAMED-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// NAMED: %[[RES1:.+]] = linalg.mul
+// NAMED-SAME: ins(%[[RES0]], %[[B]] : [[TTY]], [[TTY]])
+// NAMED-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: %[[RES0:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
+// CATEGORY-SAME: ins(%[[A]], %[[B]] : [[TTY]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// CATEGORY: %[[RES1:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<mul>
+// CATEGORY-SAME: ins(%[[RES0]], %[[B]] : [[TTY]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @binary_ops_uint(%A: tensor<?x?xi32>, %B: tensor<?x?xi32>,
+                           %Out: tensor<?x?xi32>) -> tensor<?x?xi32> {
+  %0 = linalg.generic
+    {indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%A, %B : tensor<?x?xi32>, tensor<?x?xi32>)
+    outs(%Out : tensor<?x?xi32>) {
+  ^bb0(%in: i32, %in_0: i32, %out: i32):
+    %v = arith.maxui %in, %in_0 : i32  // unsigned max -> max_unsigned kind
+    linalg.yield %v : i32
+  } -> tensor<?x?xi32>
+  %1 = linalg.generic
+    {indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%0, %B : tensor<?x?xi32>, tensor<?x?xi32>)
+    outs(%Out : tensor<?x?xi32>) {
+  ^bb0(%in: i32, %in_0: i32, %out: i32):
+    %v = arith.minui %in, %in_0 : i32  // unsigned min -> min_unsigned kind
+    linalg.yield %v : i32
+  } -> tensor<?x?xi32>
+  return %1 : tensor<?x?xi32>
+}
+
+// ALL-LABEL: binary_ops_uint
+// ALL-SAME: %[[A:.+]]: [[TTY:tensor<\?x\?xi32>]], %[[B:.+]]: [[TTY]],
+// ALL-SAME: %[[OUT:.+]]: [[TTY]]) -> [[TTY]]
+
+// No named ops yet for unsigned max/min -> expect no change.
+// NAMED-NOT: linalg.{{max|min}}
+// NAMED: linalg.generic
+
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: %[[RES0:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<max_unsigned>
+// CATEGORY-SAME: ins(%[[A]], %[[B]] : [[TTY]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// CATEGORY: %[[RES1:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<min_unsigned>
+// CATEGORY-SAME: ins(%[[RES0]], %[[B]] : [[TTY]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+
+// -----
+
+func.func @binary_ops_non_identity(%A: tensor<?xf32>, %B: tensor<?x?xf32>,
+                                   %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic
+    {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>,
+                      affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%A, %B : tensor<?xf32>, tensor<?x?xf32>)
+    outs(%Out : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %v = arith.addf %in, %in_0 : f32  // %A broadcast along d0, %B transposed
+    linalg.yield %v : f32
+  } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// ALL-DAG: #[[MAP_BC:.+]] = affine_map<(d0, d1) -> (d1)>
+// ALL-DAG: #[[MAP_TP:.+]] = affine_map<(d0, d1) -> (d1, d0)>
+// ALL-DAG: #[[MAP_ID:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// ALL: binary_ops_non_identity
+// ALL-SAME: %[[A:.+]]: [[TTY1D:tensor<\?xf32>]], %[[B:.+]]: [[TTY:tensor<\?x\?xf32>]],
+// ALL-SAME: %[[OUT:.+]]: [[TTY]]) -> [[TTY]]
+
+// Named ops cannot carry user-defined indexing maps -> expect no change.
+// NAMED-NOT: linalg.add
+// NAMED: linalg.generic
+
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.elementwise kind=#linalg.elementwise_kind<add>
+// CATEGORY-SAME: indexing_maps = [#[[MAP_BC]], #[[MAP_TP]], #[[MAP_ID]]]
+// CATEGORY-SAME: ins(%[[A]], %[[B]] : [[TTY1D]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#bcast = affine_map<(d0, d1) -> (d0)>
+func.func @binary_ops_swapped(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+                              %C: tensor<?xf32>, %Out: tensor<?x?xf32>)
+                               -> tensor<?x?xf32> {
+  %0 = linalg.generic
+    {indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%Out : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %v = arith.addf %in_0, %in : f32  // operands deliberately swapped w.r.t. the ins order
+    linalg.yield %v : f32
+  } -> tensor<?x?xf32>
+  %1 = linalg.generic
+    {indexing_maps = [#map, #bcast, #map],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%0, %C : tensor<?x?xf32>, tensor<?xf32>)
+    outs(%Out : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %v = arith.subf %in_0, %in : f32  // swapped; %in_0 is the broadcast %C
+    linalg.yield %v : f32
+  } -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// ALL-DAG: #[[MAP_BC:.+]] = affine_map<(d0, d1) -> (d0)>
+// ALL-DAG: #[[MAP_ID:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// ALL: binary_ops_swapped
+// ALL-SAME: %[[A:.+]]: [[TTY:tensor<\?x\?xf32>]], %[[B:.+]]: [[TTY]],
+// ALL-SAME: %[[C:.+]]: [[TTY1D:tensor<\?xf32>]],
+// ALL-SAME: %[[OUT:.+]]: [[TTY]]) -> [[TTY]]
+
+// NAMED: %[[RES0:.+]] = linalg.add
+// NAMED-SAME: ins(%[[B]], %[[A]] : [[TTY]], [[TTY]])
+// NAMED-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// NAMED-NOT: linalg.sub
+// NAMED: linalg.generic
+// NAMED-SAME: ins(%[[RES0]], %[[C]] : [[TTY]], [[TTY1D]])
+
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: %[[RES0:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
+// CATEGORY-SAME: ins(%[[B]], %[[A]] : [[TTY]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+// CATEGORY: %[[RES1:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<sub>
+// CATEGORY-SAME: indexing_maps = [#[[MAP_BC]], #[[MAP_ID]], #[[MAP_ID]]]
+// CATEGORY-SAME: ins(%[[C]], %[[RES0]] : [[TTY1D]], [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+func.func @unary_op_with_scalar(%A: tensor<?xi32>, %Out: tensor<?xi32>)
+                                 -> tensor<?xi32> {
+  %cst = arith.constant 123 : i32
+  %0 = linalg.generic
+    {indexing_maps = [#map, #map],
+    iterator_types = ["parallel"]}
+    ins(%A : tensor<?xi32>)
+    outs(%Out : tensor<?xi32>) {
+  ^bb0(%in: i32, %out: i32):
+    %v = arith.addi %cst, %in : i32  // %cst is captured from outside the region
+    linalg.yield %v : i32
+  } -> tensor<?xi32>
+  return %0 : tensor<?xi32>
+}
+
+// CATEGORY-DAG: #[[MAP_ID:.+]] = affine_map<(d0) -> (d0)>
+// CATEGORY-DAG: #[[MAP_BC:.+]] = affine_map<(d0) -> ()>
+// ALL: unary_op_with_scalar
+// CATEGORY-SAME: %[[A:.+]]: [[TTY:tensor<\?xi32>]],
+// CATEGORY-SAME: %[[OUT:.+]]: [[TTY]]) -> [[TTY]]
+
+// Named ops cannot broadcast from a scalar operand -> expect no change.
+// NAMED-NOT: linalg.add
+// NAMED: linalg.generic
+
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: %[[CST:.+]] = arith.constant 123 : i32
+// CATEGORY: linalg.elementwise kind=#linalg.elementwise_kind<add>
+// CATEGORY-SAME: indexing_maps = [#[[MAP_BC]], #[[MAP_ID]], #[[MAP_ID]]]
+// CATEGORY-SAME: ins(%[[CST]], %[[A]] : i32, [[TTY]])
+// CATEGORY-SAME: outs(%[[OUT]] : [[TTY]]) -> [[TTY]]
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+func.func @negative_unary_op_using_block_arg_twice(%A: tensor<?xi32>,
+                                                   %Out: tensor<?xi32>)
+                                                    -> tensor<?xi32> {
+  %0 = linalg.generic
+    {indexing_maps = [#map, #map],
+    iterator_types = ["parallel"]}
+    ins(%A : tensor<?xi32>)
+    outs(%Out : tensor<?xi32>) {
+  ^bb0(%in: i32, %out: i32):
+    %v = arith.addi %in, %in : i32  // the single input is used for both operands
+    linalg.yield %v : i32
+  } -> tensor<?xi32>
+  return %0 : tensor<?xi32>
+}
+
+// ALL-LABEL: negative_unary_op_using_block_arg_twice
+
+// A single-input generic is not rewritten into a binary named op -> expect no change.
+// NAMED-NOT: linalg.add
+
+// There is no scalar operand to hoist -> expect no change.
+// CATEGORY-NOT: linalg.elementwise kind=#linalg.elementwise_kind<add>
+
+// ALL: linalg.generic
 
 // -----
 
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize-elemwise-binary.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize-elemwise-binary.mlir
index d45025de931cd..0f444a2d6a71b 100644
--- a/mlir/test/Dialect/Linalg/transform-op-specialize-elemwise-binary.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize-elemwise-binary.mlir
@@ -1,7 +1,98 @@
 // RUN: mlir-opt --transform-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s
 
 #map = affine_map<(d0, d1) -> (d0, d1)>
-func.func @specialize_add(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+func.func @specialize_add_int(%arg0: tensor<?x?xi32>, %arg1: tensor<?x?xi32>, %arg2: tensor<?x?xi32>) -> tensor<?x?xi32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xi32>, tensor<?x?xi32>) outs(%arg2 : tensor<?x?xi32>) {
+  ^bb0(%in: i32, %in_0: i32, %out: i32):
+    %1 = arith.addi %in, %in_0 : i32  // integer add, expected to become linalg.add
+    linalg.yield %1 : i32
+  } -> tensor<?x?xi32>
+  return %0 : tensor<?x?xi32>
+}
+// CHECK-LABEL: specialize_add_int
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi32>, %[[ARG1:.+]]: tensor<?x?xi32>,  %[[ARG2:.+]]: tensor<?x?xi32>) -> tensor<?x?xi32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.add ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xi32>, tensor<?x?xi32>) outs(%[[ARG2]] : tensor<?x?xi32>) -> tensor<?x?xi32>
+
+func.func @specialize_sub_int(%arg0: tensor<?x?xi32>, %arg1: tensor<?x?xi32>, %arg2: tensor<?x?xi32>) -> tensor<?x?xi32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xi32>, tensor<?x?xi32>) outs(%arg2 : tensor<?x?xi32>) {
+  ^bb0(%in: i32, %in_0: i32, %out: i32):
+    %1 = arith.subi %in, %in_0 : i32  // integer subtract, expected to become linalg.sub
+    linalg.yield %1 : i32
+  } -> tensor<?x?xi32>
+  return %0 : tensor<?x?xi32>
+}
+// CHECK-LABEL: specialize_sub_int
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi32>, %[[ARG1:.+]]: tensor<?x?xi32>,  %[[ARG2:.+]]: tensor<?x?xi32>) -> tensor<?x?xi32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.sub ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xi32>, tensor<?x?xi32>) outs(%[[ARG2]] : tensor<?x?xi32>) -> tensor<?x?xi32>
+
+func.func @specialize_mul_int(%arg0: tensor<?x?xi32>, %arg1: tensor<?x?xi32>, %arg2: tensor<?x?xi32>) -> tensor<?x?xi32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xi32>, tensor<?x?xi32>) outs(%arg2 : tensor<?x?xi32>) {
+  ^bb0(%in: i32, %in_0: i32, %out: i32):
+    %1 = arith.muli %in, %in_0 : i32  // integer multiply, expected to become linalg.mul
+    linalg.yield %1 : i32
+  } -> tensor<?x?xi32>
+  return %0 : tensor<?x?xi32>
+}
+// CHECK-LABEL: specialize_mul_int
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi32>, %[[ARG1:.+]]: tensor<?x?xi32>,  %[[ARG2:.+]]: tensor<?x?xi32>) -> tensor<?x?xi32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.mul ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xi32>, tensor<?x?xi32>) outs(%[[ARG2]] : tensor<?x?xi32>) -> tensor<?x?xi32>
+
+func.func @specialize_div_int(%arg0: tensor<?x?xi32>, %arg1: tensor<?x?xi32>, %arg2: tensor<?x?xi32>) -> tensor<?x?xi32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xi32>, tensor<?x?xi32>) outs(%arg2 : tensor<?x?xi32>) {
+  ^bb0(%in: i32, %in_0: i32, %out: i32):
+    %1 = arith.divsi %in, %in_0 : i32  // signed division (divsi), expected to become linalg.div
+    linalg.yield %1 : i32
+  } -> tensor<?x?xi32>
+  return %0 : tensor<?x?xi32>
+}
+// CHECK-LABEL: specialize_div_int
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi32>, %[[ARG1:.+]]: tensor<?x?xi32>,  %[[ARG2:.+]]: tensor<?x?xi32>) -> tensor<?x?xi32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.div ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xi32>, tensor<?x?xi32>) outs(%[[ARG2]] : tensor<?x?xi32>) -> tensor<?x?xi32>
+
+func.func @specialize_div_unsigned_int(%arg0: tensor<?x?xi32>, %arg1: tensor<?x?xi32>, %arg2: tensor<?x?xi32>) -> tensor<?x?xi32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xi32>, tensor<?x?xi32>) outs(%arg2 : tensor<?x?xi32>) {
+  ^bb0(%in: i32, %in_0: i32, %out: i32):
+    %1 = arith.divui %in, %in_0 : i32  // unsigned division (divui), expected to become linalg.div_unsigned
+    linalg.yield %1 : i32
+  } -> tensor<?x?xi32>
+  return %0 : tensor<?x?xi32>
+}
+// CHECK-LABEL: specialize_div_unsigned_int
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi32>, %[[ARG1:.+]]: tensor<?x?xi32>,  %[[ARG2:.+]]: tensor<?x?xi32>) -> tensor<?x?xi32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.div_unsigned ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xi32>, tensor<?x?xi32>) outs(%[[ARG2]] : tensor<?x?xi32>) -> tensor<?x?xi32>
+
+func.func @specialize_max_int(%arg0: tensor<?x?xi32>, %arg1: tensor<?x?xi32>, %arg2: tensor<?x?xi32>) -> tensor<?x?xi32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xi32>, tensor<?x?xi32>) outs(%arg2 : tensor<?x?xi32>) {
+  ^bb0(%in: i32, %in_0: i32, %out: i32):
+    %1 = arith.maxsi %in, %in_0 : i32  // signed max (maxsi), expected to become linalg.max
+    linalg.yield %1 : i32
+  } -> tensor<?x?xi32>
+  return %0 : tensor<?x?xi32>
+}
+// CHECK-LABEL: specialize_max_int
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi32>, %[[ARG1:.+]]: tensor<?x?xi32>,  %[[ARG2:.+]]: tensor<?x?xi32>) -> tensor<?x?xi32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.max ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xi32>, tensor<?x?xi32>) outs(%[[ARG2]] : tensor<?x?xi32>) -> tensor<?x?xi32>
+
+func.func @specialize_min_int(%arg0: tensor<?x?xi32>, %arg1: tensor<?x?xi32>, %arg2: tensor<?x?xi32>) -> tensor<?x?xi32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xi32>, tensor<?x?xi32>) outs(%arg2 : tensor<?x?xi32>) {
+  ^bb0(%in: i32, %in_0: i32, %out: i32):
+    %1 = arith.minsi %in, %in_0 : i32  // signed min (minsi), expected to become linalg.min
+    linalg.yield %1 : i32
+  } -> tensor<?x?xi32>
+  return %0 : tensor<?x?xi32>
+}
+// CHECK-LABEL: specialize_min_int
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi32>, %[[ARG1:.+]]: tensor<?x?xi32>,  %[[ARG2:.+]]: tensor<?x?xi32>) -> tensor<?x?xi32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.min ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xi32>, tensor<?x?xi32>) outs(%[[ARG2]] : tensor<?x?xi32>) -> tensor<?x?xi32>
+
+func.func @specialize_add_float(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {  // f32 arith.addf, expected to become linalg.add
   %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
   ^bb0(%in: f32, %in_0: f32, %out: f32):
     %1 = arith.addf %in, %in_0 : f32
@@ -9,12 +100,12 @@ func.func @specialize_add(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2:
   } -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
-// CHECK-LABEL: specialize_add
+// CHECK-LABEL: specialize_add_float
 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>,  %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK-NOT: linalg.generic
 // CHECK: linalg.add ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
 
-func.func @specialize_sub(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+func.func @specialize_sub_float(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {  // f32 arith.subf, expected to become linalg.sub
   %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
   ^bb0(%in: f32, %in_0: f32, %out: f32):
     %1 = arith.subf %in, %in_0 : f32
@@ -22,50 +113,166 @@ func.func @specialize_sub(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2:
   } -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
-// CHECK-LABEL: specialize_sub
+// CHECK-LABEL: specialize_sub_float
 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>,  %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK-NOT: linalg.generic
 // CHECK: linalg.sub ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
 
-func.func @specialize_sub_swapped_operands(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+func.func @specialize_mul_float(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
   %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
   ^bb0(%in: f32, %in_0: f32, %out: f32):
-    %1 = arith.subf %in_0, %in : f32
+    %1 = arith.mulf %in, %in_0 : f32  // f32 multiply, expected to become linalg.mul
     linalg.yield %1 : f32
   } -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
-// CHECK-LABEL: specialize_sub
+// CHECK-LABEL: specialize_mul_float
 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>,  %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK-NOT: linalg.generic
-// CHECK: linalg.sub ins(%[[ARG1]], %[[ARG0]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: linalg.mul ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
 
-func.func @specialize_mul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+func.func @specialize_div_float(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
   %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
   ^bb0(%in: f32, %in_0: f32, %out: f32):
-    %1 = arith.mulf %in, %in_0 : f32
+    %1 = arith.divf %in, %in_0 : f32  // f32 division, expected to become linalg.div
     linalg.yield %1 : f32
   } -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
-// CHECK-LABEL: specialize_mul
+// CHECK-LABEL: specialize_div_float
 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>,  %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK-NOT: linalg.generic
-// CHECK: linalg.mul ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: linalg.div ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
 
-func.func @specialize_div(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+func.func @specialize_max_float(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
   %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
   ^bb0(%in: f32, %in_0: f32, %out: f32):
-    %1 = arith.divf %in, %in_0 : f32
+    %1 = arith.maximumf %in, %in_0 : f32  // maximumf is the float max form expected to become linalg.max
     linalg.yield %1 : f32
   } -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
-// CHECK-LABEL: specialize_div
+// CHECK-LABEL: specialize_max_float
 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>,  %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK-NOT: linalg.generic
-// CHECK: linalg.div ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: linalg.max ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
 
+func.func @specialize_min_float(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.minimumf %in, %in_0 : f32  // minimumf is the float min form expected to become linalg.min
+    linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: specialize_min_float
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>,  %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.min ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+func.func @specialize_powf_float(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = math.powf %in, %in_0 : f32  // math dialect op, expected to become linalg.powf
+    linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: specialize_powf_float
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>,  %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.powf ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+func.func @specialize_add_complex(%arg0: tensor<?x?xcomplex<f32>>, %arg1: tensor<?x?xcomplex<f32>>, %arg2: tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>> {
+  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xcomplex<f32>>, tensor<?x?xcomplex<f32>>) outs(%arg2 : tensor<?x?xcomplex<f32>>) {
+  ^bb0(%in: complex<f32>, %in_0: complex<f32>, %out: complex<f32>):
+    %1 = complex.add %in, %in_0 : complex<f32>  // complex dialect add, expected to become linalg.add
+    linalg.yield %1 : complex<f32>
+  } -> tensor<?x?xcomplex<f32>>
+  return %0 : tensor<?x?xcomplex<f32>>
+}
+// CHECK-LABEL: specialize_add_complex
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xcomplex<f32>>, %[[ARG1:.+]]: tensor<?x?xcomplex<f32>>,  %[[ARG2:.+]]: tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.add ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xcomplex<f32>>, tensor<?x?xcomplex<f32>>) outs(%[[ARG2]] : tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>>
+
+func.func @specialize_sub_complex(%arg0: tensor<?x?xcomplex<f32>>, %arg1: tensor<?x?xcomplex<f32>>, %arg2: tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>> {
+  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xcomplex<f32>>, tensor<?x?xcomplex<f32>>) outs(%arg2 : tensor<?x?xcomplex<f32>>) {
+  ^bb0(%in: complex<f32>, %in_0: complex<f32>, %out: complex<f32>):
+    %1 = complex.sub %in, %in_0 : complex<f32>  // complex dialect sub, expected to become linalg.sub
+    linalg.yield %1 : complex<f32>
+  } -> tensor<?x?xcomplex<f32>>
+  return %0 : tensor<?x?xcomplex<f32>>
+}
+// CHECK-LABEL: specialize_sub_complex
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xcomplex<f32>>, %[[ARG1:.+]]: tensor<?x?xcomplex<f32>>,  %[[ARG2:.+]]: tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.sub ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xcomplex<f32>>, tensor<?x?xcomplex<f32>>) outs(%[[ARG2]] : tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>>
+
+func.func @specialize_mul_complex(%arg0: tensor<?x?xcomplex<f32>>, %arg1: tensor<?x?xcomplex<f32>>, %arg2: tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>> {
+  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xcomplex<f32>>, tensor<?x?xcomplex<f32>>) outs(%arg2 : tensor<?x?xcomplex<f32>>) {
+  ^bb0(%in: complex<f32>, %in_0: complex<f32>, %out: complex<f32>):
+    %1 = complex.mul %in, %in_0 : complex<f32>  // complex dialect mul, expected to become linalg.mul
+    linalg.yield %1 : complex<f32>
+  } -> tensor<?x?xcomplex<f32>>
+  return %0 : tensor<?x?xcomplex<f32>>
+}
+// CHECK-LABEL: specialize_mul_complex
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xcomplex<f32>>, %[[ARG1:.+]]: tensor<?x?xcomplex<f32>>,  %[[ARG2:.+]]: tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.mul ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xcomplex<f32>>, tensor<?x?xcomplex<f32>>) outs(%[[ARG2]] : tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>>
+
+func.func @specialize_div_complex(%arg0: tensor<?x?xcomplex<f32>>, %arg1: tensor<?x?xcomplex<f32>>, %arg2: tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>> {
+  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xcomplex<f32>>, tensor<?x?xcomplex<f32>>) outs(%arg2 : tensor<?x?xcomplex<f32>>) {
+  ^bb0(%in: complex<f32>, %in_0: complex<f32>, %out: complex<f32>):
+    %1 = complex.div %in, %in_0 : complex<f32>  // complex dialect div, expected to become linalg.div
+    linalg.yield %1 : complex<f32>
+  } -> tensor<?x?xcomplex<f32>>
+  return %0 : tensor<?x?xcomplex<f32>>
+}
+// CHECK-LABEL: specialize_div_complex
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xcomplex<f32>>, %[[ARG1:.+]]: tensor<?x?xcomplex<f32>>,  %[[ARG2:.+]]: tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.div ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xcomplex<f32>>, tensor<?x?xcomplex<f32>>) outs(%[[ARG2]] : tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>>
+
+func.func @specialize_add_bool(%arg0: tensor<?x?xi1>, %arg1: tensor<?x?xi1>, %arg2: tensor<?x?xi1>) -> tensor<?x?xi1> {
+  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xi1>, tensor<?x?xi1>) outs(%arg2 : tensor<?x?xi1>) {
+  ^bb0(%in: i1, %in_0: i1, %out: i1):
+    %1 = arith.ori %in, %in_0 : i1  // on i1 elements, ori is the form expected to become linalg.add
+    linalg.yield %1 : i1
+  } -> tensor<?x?xi1>
+  return %0 : tensor<?x?xi1>
+}
+// CHECK-LABEL: specialize_add_bool
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi1>, %[[ARG1:.+]]: tensor<?x?xi1>,  %[[ARG2:.+]]: tensor<?x?xi1>) -> tensor<?x?xi1>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.add ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xi1>, tensor<?x?xi1>) outs(%[[ARG2]] : tensor<?x?xi1>) -> tensor<?x?xi1>
+
+func.func @specialize_mul_bool(%arg0: tensor<?x?xi1>, %arg1: tensor<?x?xi1>, %arg2: tensor<?x?xi1>) -> tensor<?x?xi1> {
+  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xi1>, tensor<?x?xi1>) outs(%arg2 : tensor<?x?xi1>) {
+  ^bb0(%in: i1, %in_0: i1, %out: i1):
+    %1 = arith.andi %in, %in_0 : i1  // on i1 elements, andi is the form expected to become linalg.mul
+    linalg.yield %1 : i1
+  } -> tensor<?x?xi1>
+  return %0 : tensor<?x?xi1>
+}
+// CHECK-LABEL: specialize_mul_bool
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi1>, %[[ARG1:.+]]: tensor<?x?xi1>,  %[[ARG2:.+]]: tensor<?x?xi1>) -> tensor<?x?xi1>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.mul ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xi1>, tensor<?x?xi1>) outs(%[[ARG2]] : tensor<?x?xi1>) -> tensor<?x?xi1>
+
+func.func @specialize_sub_swapped_operands(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.subf %in_0, %in : f32  // operands reversed, so the named op ins are expected in swapped order
+    linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: specialize_sub_swapped_operands
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>,  %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.sub ins(%[[ARG1]], %[[ARG0]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {



More information about the Mlir-commits mailing list