[Mlir-commits] [mlir] [MLIR][Linalg] Specialize more binary elementwise ops (PR #192290)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 15 11:06:44 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: Julian Oppermann (jopperm)

<details>
<summary>Changes</summary>

Extends the matching logic for `linalg.generic` ops that can be represented as a named op or as `linalg.elementwise` to cover all variants currently supported by `RegionBuilderHelper::buildBinaryFn`. Previously, only `add`, `sub`, `mul` and `div` on floating-point types were detected.

I combined the detection for unary and binary functions to make it tractable to morph operations such as
```mlir
#map = affine_map<(d0) -> (d0)>
// ...
%c123_i32 = arith.constant 123 : i32
%0 = linalg.generic
  {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
  ins(%A : tensor<?xi32>) outs(%Out : tensor<?xi32>) {
^bb0(%in: i32, %out: i32):
  %v = arith.addi %c123_i32 , %in : i32
  linalg.yield %v : i32
} -> tensor<?xi32>
 ```
to 
```mlir
#map = affine_map<(d0) -> ()>
#map1 = affine_map<(d0) -> (d0)>
// ...
%0 = linalg.elementwise kind=#linalg.elementwise_kind<add>
  indexing_maps = [#map, #map1, #map1]
  ins(%c123_i32, %A: i32, tensor<?xi32>) outs(%Out: tensor<?xi32>) -> tensor<?xi32>
```

---

Patch is 83.55 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/192290.diff


8 Files Affected:

- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h (+5-1) 
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp (+4-3) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp (+171-81) 
- (modified) mlir/test/Dialect/Linalg/linalg-morph-multi-step.mlir (+139-3) 
- (modified) mlir/test/Dialect/Linalg/roundtrip-morphism-linalg-category-ops.mlir (+199) 
- (modified) mlir/test/Dialect/Linalg/roundtrip-morphism-linalg-named-ops.mlir (+125-11) 
- (modified) mlir/test/Dialect/Linalg/specialize-generic-ops.mlir (+540-13) 
- (modified) mlir/test/Dialect/Linalg/transform-op-specialize-elemwise-binary.mlir (+223-16) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index c565781402e3c..3c7ebd8277dbd 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -148,7 +148,11 @@ bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp,
 
 /// Checks whether `genericOp` is semantically equivalent to a single linalg
 /// elementwise binary op e.g. linalg.sub.
-bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp);
+/// If `allowNonIdentityMaps` is true, operations with custom indexing maps are
+/// included in the check. Note that these operations can only be represented by
+/// the category op.
+bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp,
+                                        bool allowNonIdentityMaps = false);
 
 /// Checks whether `genericOp` is semantically equivalent to a `linalg.fill`.
 /// Supports two patterns:
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 2ba77cea8f16e..238bddcb3b2bd 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -279,9 +279,10 @@ bool linalg::isaElemwiseSingleUnaryOpInterface(linalg::GenericOp op,
   return true;
 }
 
-bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp op) {
-  if (!isaElemwiseSingleUnaryOrBinaryOpInterface(
-          op, 2, /*allowNonIdentityMaps=*/false))
+bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp op,
+                                                bool allowNonIdentityMaps) {
+  // All basic elemwise checks.
+  if (!isaElemwiseSingleUnaryOrBinaryOpInterface(op, 2, allowNonIdentityMaps))
     return false;
 
   // Check both inputs are used (elementwise).
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index a7cd57cf4ed9e..cc8a69cef8e5d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -28,13 +28,6 @@ namespace mlir {
 
 #define DEBUG_TYPE "linalg-specialization"
 
-#define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP)                                \
-  (rewriter.replaceOpWithNewOp<NEWOP>(                                         \
-      genericOp,                                                               \
-      ValueRange{genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 1 : 0],            \
-                 genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 0 : 1]},           \
-      ValueRange{genericOp.getDpsInits()[0]}))
-
 using namespace mlir;
 using namespace mlir::linalg;
 
@@ -42,7 +35,7 @@ using namespace mlir::linalg;
 // Specialize linalg generic to elementwise ops.
 //===----------------------------------------------------------------------===//
 
-// Given a elementwise single binary linalg generic op, checks whether the
+// Given an elementwise single binary linalg generic op, checks whether the
 // binary op accesses operands as swapped. e.g.
 // this differentiates between a linalg-generic body that contains:
 //    ^bb0(%a: f32, %b: f32, %c : f32):
@@ -66,8 +59,39 @@ static bool areBinOpsSwapped(GenericOp genericOp) {
   return swapped;
 }
 
-// Attempt to specialize linalg.generic to named elementwise ops or
-// linalg.elementwise.
+// Given an elementwise single unary linalg generic op whose body operation is a
+// binary operation, check if one of its operands is a scalar value defined
+// outside the generic op, set its index, and return true. Otherwise return
+// false. The index is unique because the block argument is used at
+// least by one operand, as checked in `isaElemwiseSingleUnaryOpInterface`.
+//
+// Example:
+//   %cst = arith.constant 3.14 : f32
+//   %0 = linalg.generic { indexing_maps = [#mapA, #mapRes], ... }
+//          ins(%A : tensor<?xf32>) outs(...) {
+//   ^bb0(%a: f32, %out : f32):
+//     %0 = arith.mulf %a, %cst : f32
+//     linalg.yield %0: f32
+//   } -> tensor<?xf32>
+// Here, the returned index is 1, and the generic op can be represented as
+//   %0 = linalg.elementwise kind=#linalg.elementwise_kind<mul>
+//          indexing_maps = [#mapA, affine_map<(d0) -> ()>, #mapRes]
+//          ins(%A, %cst : tensor<?xf32>, f32) outs(...) -> tensor<?xf32>
+static bool findIndexOfScalarOperand(GenericOp genericOp, int &index) {
+  Block *body = genericOp.getBody();
+  Operation *op = &body->front();
+  for (auto [i, v] : llvm::enumerate(op->getOperands())) {
+    if (auto blockArg = dyn_cast<BlockArgument>(v);
+        blockArg && blockArg.getOwner() == body)
+      continue; // not an outside value...
+    index = i;
+    return true;
+  }
+  return false;
+}
+
+// Attempt to specialize unary or binary linalg.generic ops to named elementwise
+// ops or linalg.elementwise.
 //
 // Example:
 //   %0 = linalg.generic {
@@ -87,9 +111,16 @@ static bool areBinOpsSwapped(GenericOp genericOp) {
 //
 // Only the category op can carry non-identity indexing maps; these are
 // transferred verbatim from the `genericOp`.
-static FailureOr<LinalgOp>
-specializeLinalgUnaryElementwise(RewriterBase &rewriter, GenericOp genericOp,
-                                 bool emitCategoryOp) {
+//
+// In addition to the canonical forms used by the generalization path, this
+// function can handle the following variations:
+//
+// 1) Swapped operands in binary ops (see the `areBinOpsSwapped` helper)
+// 2) Unary generic ops with a binary body op (see the
+//    `findIndexOfScalarOperand` helper)
+static FailureOr<LinalgOp> specializeLinalgElementwise(RewriterBase &rewriter,
+                                                       GenericOp genericOp,
+                                                       bool emitCategoryOp) {
   bool hasNonIdentityMaps =
       !llvm::all_of(genericOp.getIndexingMapsArray(),
                     [](AffineMap map) { return map.isIdentity(); });
@@ -100,62 +131,142 @@ specializeLinalgUnaryElementwise(RewriterBase &rewriter, GenericOp genericOp,
         genericOp,
         "non-identity indexing maps prevent specialization to named op");
 
+  // Classify the generic op.
+  bool isUnary = genericOp.getNumDpsInputs() == 1;
+  bool isBinary = genericOp.getNumDpsInputs() == 2;
+
+  // Will inspect the body operation to determine named op or elementwise kind.
+  Operation *op = &genericOp.getBody()->front();
+
+  // Detect variations from canonical forms.
+  bool hasSwappedOperands = isBinary && areBinOpsSwapped(genericOp);
+  int scalarOprIdx = -1;
+  bool hasScalarOperand = isUnary && op->getNumOperands() == 2 &&
+                          findIndexOfScalarOperand(genericOp, scalarOprIdx);
+
   // Helper to dispatch between named op and `linalg.elementwise`.
   // Lambdas with explicit template parameter list are a C++20 feature, hence
   // the dummy op object.
-  auto replaceUnaryOp = [&](auto namedOp, ElementwiseKind kind) -> LinalgOp {
+  auto replaceOp = [&](auto namedOp, ElementwiseKind kind,
+                       bool mayHoistScalarOperand = true) -> LinalgOp {
+    SmallVector<Value> inputs = genericOp.getDpsInputs();
+    if (hasSwappedOperands)
+      std::swap(inputs[0], inputs[1]);
+
     LinalgOp newOp;
-    if (!emitCategoryOp)
-      newOp = decltype(namedOp)::create(
-          rewriter, genericOp.getLoc(), genericOp.getDpsInputs(),
-          genericOp.getDpsInits(), ArrayRef<NamedAttribute>{});
-    else
+    if (!emitCategoryOp) {
+      using NamedOpTy = decltype(namedOp);
+      if constexpr (!std::is_null_pointer_v<NamedOpTy>)
+        newOp = NamedOpTy::create(rewriter, genericOp.getLoc(), inputs,
+                                  genericOp.getDpsInits(),
+                                  ArrayRef<NamedAttribute>{});
+      else
+        llvm_unreachable("Missing named op type");
+    } else {
+      SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
+      // Swap indexing maps, too.
+      if (hasSwappedOperands)
+        std::swap(indexingMaps[0], indexingMaps[1]);
+
+      // Represent unary generic op as a binary `linalg.elementwise` with a
+      // scalar operand and broadcasting map.
+      if (hasScalarOperand && mayHoistScalarOperand) {
+        // Adjust inputs and indexing maps accordingly.
+        inputs.insert(inputs.begin() + scalarOprIdx,
+                      op->getOperand(scalarOprIdx));
+        auto scalarBroadcastMap =
+            AffineMap::get(genericOp.getNumParallelLoops(), /*symbolCount=*/0,
+                           rewriter.getContext());
+        indexingMaps.insert(indexingMaps.begin() + scalarOprIdx,
+                            scalarBroadcastMap);
+      }
       newOp = ElementwiseOp::create(
-          rewriter, genericOp.getLoc(), genericOp.getDpsInputs(),
-          genericOp.getDpsInits(),
+          rewriter, genericOp.getLoc(), inputs, genericOp.getDpsInits(),
           ElementwiseKindAttr::get(rewriter.getContext(), kind),
-          genericOp.getIndexingMaps());
+          rewriter.getAffineMapArrayAttr(indexingMaps));
+    }
 
     rewriter.replaceOp(genericOp, newOp);
     return newOp;
   };
 
-  // Inspect body operation to determine named op or elementwise kind.
-  Operation *op = &genericOp.getBody()->front();
+  if (isUnary) {
+    if (isa<math::ExpOp>(op))
+      return replaceOp(ExpOp{}, ElementwiseKind::exp);
+    if (isa<math::LogOp>(op))
+      return replaceOp(LogOp{}, ElementwiseKind::log);
+    if (isa<math::AbsFOp>(op))
+      return replaceOp(AbsOp{}, ElementwiseKind::abs);
+    if (isa<math::CeilOp>(op))
+      return replaceOp(CeilOp{}, ElementwiseKind::ceil);
+    if (isa<math::FloorOp>(op))
+      return replaceOp(FloorOp{}, ElementwiseKind::floor);
+    if (isa<arith::NegFOp>(op))
+      return replaceOp(NegFOp{}, ElementwiseKind::negf);
+    if (auto divOp = dyn_cast<arith::DivFOp>(op)) {
+      if (auto constOp = dyn_cast_if_present<arith::ConstantOp>(
+              divOp.getLhs().getDefiningOp()))
+        if (cast<FloatAttr>(constOp.getValue()).getValue().isExactlyValue(1.0))
+          return replaceOp(ReciprocalOp{}, ElementwiseKind::reciprocal,
+                           /*mayHoistScalarOperand=*/false);
+    }
+    if (isa<math::RoundOp>(op))
+      return replaceOp(RoundOp{}, ElementwiseKind::round);
+    if (isa<math::SqrtOp>(op))
+      return replaceOp(SqrtOp{}, ElementwiseKind::sqrt);
+    if (isa<math::RsqrtOp>(op))
+      return replaceOp(RsqrtOp{}, ElementwiseKind::rsqrt);
+    if (auto mulOp = dyn_cast<arith::MulFOp>(op);
+        mulOp && mulOp.getLhs() == mulOp.getRhs())
+      return replaceOp(SquareOp{}, ElementwiseKind::square);
+    if (isa<math::TanhOp>(op))
+      return replaceOp(TanhOp{}, ElementwiseKind::tanh);
+    if (isa<math::ErfOp>(op))
+      return replaceOp(ErfOp{}, ElementwiseKind::erf);
+
+    // At this point, we exhaustively checked the available unary named ops. The
+    // 1-input generic op might be representable as a `linalg.elementwise` that
+    // broadcasts a scalar operand. But if we can't emit the category op or
+    // don't have a scalar operand, exit now.
+    if (!emitCategoryOp || !hasScalarOperand)
+      return rewriter.notifyMatchFailure(
+          genericOp, "unary elementwise operation cannot be specialized to "
+                     "named or category op");
+  }
 
-  if (isa<math::ExpOp>(op))
-    return replaceUnaryOp(ExpOp{}, ElementwiseKind::exp);
-  if (isa<math::LogOp>(op))
-    return replaceUnaryOp(LogOp{}, ElementwiseKind::log);
-  if (isa<math::AbsFOp>(op))
-    return replaceUnaryOp(AbsOp{}, ElementwiseKind::abs);
-  if (isa<math::CeilOp>(op))
-    return replaceUnaryOp(CeilOp{}, ElementwiseKind::ceil);
-  if (isa<math::FloorOp>(op))
-    return replaceUnaryOp(FloorOp{}, ElementwiseKind::floor);
-  if (isa<arith::NegFOp>(op))
-    return replaceUnaryOp(NegFOp{}, ElementwiseKind::negf);
-  if (auto divOp = dyn_cast<arith::DivFOp>(op)) {
-    if (auto constOp = dyn_cast_if_present<arith::ConstantOp>(
-            divOp.getLhs().getDefiningOp()))
-      if (cast<FloatAttr>(constOp.getValue()).getValue().isExactlyValue(1.0))
-        return replaceUnaryOp(ReciprocalOp{}, ElementwiseKind::reciprocal);
+  // Boolean-typed `linalg.add` and `linalg.mul` require special handling.
+  bool allBool = llvm::all_of(op->getOperands(),
+                              [](Value v) { return v.getType().isInteger(1); });
+
+  if (isa<arith::AddIOp, arith::AddFOp, complex::AddOp>(op) ||
+      (allBool && isa<arith::OrIOp>(op)))
+    return replaceOp(AddOp{}, ElementwiseKind::add);
+  if (isa<arith::SubIOp, arith::SubFOp, complex::SubOp>(op))
+    return replaceOp(SubOp{}, ElementwiseKind::sub);
+  if (isa<arith::MulIOp, arith::MulFOp, complex::MulOp>(op) ||
+      (allBool && isa<arith::AndIOp>(op)))
+    return replaceOp(MulOp{}, ElementwiseKind::mul);
+  if (isa<arith::DivSIOp, arith::DivFOp, complex::DivOp>(op))
+    return replaceOp(DivOp{}, ElementwiseKind::div);
+  if (isa<arith::DivUIOp>(op))
+    return replaceOp(DivUnsignedOp{}, ElementwiseKind::div_unsigned);
+  if (isa<arith::MaxSIOp, arith::MaximumFOp>(op))
+    return replaceOp(MaxOp{}, ElementwiseKind::max_signed);
+  if (isa<arith::MinSIOp, arith::MinimumFOp>(op))
+    return replaceOp(MinOp{}, ElementwiseKind::min_signed);
+  if (emitCategoryOp) {
+    // No named ops for unsigned maximum/minimum.
+    if (isa<arith::MaxUIOp>(op))
+      return replaceOp(nullptr, ElementwiseKind::max_unsigned);
+    if (isa<arith::MinUIOp>(op))
+      return replaceOp(nullptr, ElementwiseKind::min_unsigned);
   }
-  if (isa<math::RoundOp>(op))
-    return replaceUnaryOp(RoundOp{}, ElementwiseKind::round);
-  if (isa<math::SqrtOp>(op))
-    return replaceUnaryOp(SqrtOp{}, ElementwiseKind::sqrt);
-  if (isa<math::RsqrtOp>(op))
-    return replaceUnaryOp(RsqrtOp{}, ElementwiseKind::rsqrt);
-  if (auto mulOp = dyn_cast<arith::MulFOp>(op);
-      mulOp && mulOp.getLhs() == mulOp.getRhs())
-    return replaceUnaryOp(SquareOp{}, ElementwiseKind::square);
-  if (isa<math::TanhOp>(op))
-    return replaceUnaryOp(TanhOp{}, ElementwiseKind::tanh);
-  if (isa<math::ErfOp>(op))
-    return replaceUnaryOp(ErfOp{}, ElementwiseKind::erf);
+  if (isa<math::PowFOp>(op))
+    return replaceOp(PowFOp{}, ElementwiseKind::powf);
 
-  return failure();
+  return rewriter.notifyMatchFailure(
+      genericOp,
+      "elementwise operation cannot be specialized to named or category op");
 }
 
 //===----------------------------------------------------------------------===//
@@ -580,10 +691,11 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
 FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(
     RewriterBase &rewriter, GenericOp genericOp,
     const GenericOpSpecializationOptions &options) {
-  // Unary elementwise - e.g. exp
-  if (isaElemwiseSingleUnaryOpInterface(genericOp, options.emitCategoryOps)) {
-    return specializeLinalgUnaryElementwise(rewriter, genericOp,
-                                            options.emitCategoryOps);
+  // Elementwise - e.g. exp, add
+  if (isaElemwiseSingleUnaryOpInterface(genericOp, options.emitCategoryOps) ||
+      isaElemwiseSingleBinaryOpInterface(genericOp, options.emitCategoryOps)) {
+    return specializeLinalgElementwise(rewriter, genericOp,
+                                       options.emitCategoryOps);
   }
 
   // Contraction - e.g. matmul
@@ -636,28 +748,6 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(
     return namedOp;
   }
 
-  // Elementwise Binary
-  if (isaElemwiseSingleBinaryOpInterface(genericOp)) {
-    bool swap = areBinOpsSwapped(genericOp);
-    Operation *op = &genericOp.getBody()->front();
-    if (isa<arith::AddFOp>(op)) {
-      LinalgOp namedOp = REPLACE_BINARY_OP(AddOp, swap);
-      return namedOp;
-    }
-    if (isa<arith::SubFOp>(op)) {
-      LinalgOp namedOp = REPLACE_BINARY_OP(SubOp, swap);
-      return namedOp;
-    }
-    if (isa<arith::MulFOp>(op)) {
-      LinalgOp namedOp = REPLACE_BINARY_OP(MulOp, swap);
-      return namedOp;
-    }
-    if (isa<arith::DivFOp>(op)) {
-      LinalgOp namedOp = REPLACE_BINARY_OP(DivOp, swap);
-      return namedOp;
-    }
-  }
-
   // Convolution - e.g. *conv/pooling*
   if (isaConvolutionOpInterface(genericOp))
     return specializeLinalgConvolutions(rewriter, genericOp);
diff --git a/mlir/test/Dialect/Linalg/linalg-morph-multi-step.mlir b/mlir/test/Dialect/Linalg/linalg-morph-multi-step.mlir
index 7bad1b7a44d92..0a0ddbcd85a0a 100644
--- a/mlir/test/Dialect/Linalg/linalg-morph-multi-step.mlir
+++ b/mlir/test/Dialect/Linalg/linalg-morph-multi-step.mlir
@@ -1,6 +1,8 @@
-// RUN: mlir-opt %s -linalg-morph-ops=named-to-generic |  FileCheck %s  --check-prefix=NAMED_TO_GENERIC
-// RUN: mlir-opt %s -linalg-morph-ops=named-to-generic |  mlir-opt -linalg-morph-ops=generic-to-named | \
-// RUN:   FileCheck %s  --check-prefix=ROUND_TRIP
+// RUN: mlir-opt %s -linalg-morph-ops=named-to-generic -split-input-file | \
+// RUN:   FileCheck %s  --check-prefix=ALL,NAMED_TO_GENERIC
+// RUN: mlir-opt %s -linalg-morph-ops=named-to-generic -split-input-file | \
+// RUN:   mlir-opt -linalg-morph-ops=generic-to-named -split-input-file | \
+// RUN:     FileCheck %s  --check-prefix=ALL,ROUND_TRIP
 
 func.func @unary_ops(%A : tensor<16x8xf32>, %B : tensor<16x8xf32>) ->  tensor<16x8xf32> {
   %exp = linalg.exp ins(%A : tensor<16x8xf32>) outs(%B :  tensor<16x8xf32>) -> tensor<16x8xf32>
@@ -19,6 +21,8 @@ func.func @unary_ops(%A : tensor<16x8xf32>, %B : tensor<16x8xf32>) ->  tensor<16
   return %erf :  tensor<16x8xf32>
 }
 
+// ALL-LABEL: unary_ops
+
 // NAMED_TO_GENERIC-COUNT-13: linalg.generic
 // NAMED_TO_GENERIC-NOT: linalg.exp
 // NAMED_TO_GENERIC-NOT: linalg.log
@@ -48,3 +52,135 @@ func.func @unary_ops(%A : tensor<16x8xf32>, %B : tensor<16x8xf32>) ->  tensor<16
 // ROUND_TRIP: linalg.tanh
 // ROUND_TRIP: linalg.erf
 // ROUND_TRIP-NOT: linalg.generic
+
+// -----
+
+func.func @binary_ops_int(%A: tensor<?x?xi32>, %B: tensor<?x?xi32>,
+                          %Out: tensor<?x?xi32>) -> tensor<?x?xi32> {
+  %0 = linalg.add ins(%A, %B : tensor<?x?xi32>, tensor<?x?xi32>)
+                  outs(%Out : tensor<?x?xi32>) -> tensor<?x?xi32>
+  %1 = linalg.sub ins(%0, %B : tensor<?x?xi32>, tensor<?x?xi32>)
+                  outs(%Out : tensor<?x?xi32>) -> tensor<?x?xi32>
+  %2 = linalg.mul ins(%1, %B : tensor<?x?xi32>, tensor<?x?xi32>)
+                  outs(%Out : tensor<?x?xi32>) -> tensor<?x?xi32>
+  %3 = linalg.div ins(%2, %B : tensor<?x?xi32>, tensor<?x?xi32>)
+                  outs(%Out : tensor<?x?xi32>) -> tensor<?x?xi32>
+  %4 = linalg.div_unsigned ins(%3, %B : tensor<?x?xi32>, tensor<?x?xi32>)
+                  outs(%Out : tensor<?x?xi32>) -> tensor<?x?xi32>
+  %5 = linalg.max ins(%4, %B : tensor<?x?xi32>, tensor<?x?xi32>)
+                  outs(%Out : tensor<?x?xi32>) -> tensor<?x?xi32>
+  %6 = linalg.min ins(%5, %B : tensor<?x?xi32>, tensor<?x?xi32>)
+                  outs(%Out : tensor<?x?xi32>) -> tensor<?x?xi32>
+  return %6 : tensor<?x?xi32>
+}
+
+// ALL-LABEL: binary_ops_int
+
+// NAMED_TO_GENERIC-COUNT-7: linalg.generic
+// NAMED_TO_GENERIC-NOT: linalg.add
+// NAMED_TO_GENERIC-NOT: linalg.sub
+// NAMED_TO_GENERIC-NOT: linalg.mul
+// NAMED_TO_GENERIC-NOT: linalg.div
+// NAMED_TO_GENERIC-NOT: linalg.div_unsigned
+// NAMED_TO_GENERIC-NOT: linalg.max
+// NAMED_TO_GENERIC-NOT: linalg.min
+
+// ROUND_TRIP: linalg.add
+// ROUND_TRIP: linalg.sub
+// ROUND_TRIP: linalg.mul
+// ROUND_TRIP: linalg.div
+// ROUND_TRIP: linalg.div_unsigned
+// ROUND_TRIP: linalg.max
+// ROUND_TRIP: linalg.min
+// ROUND_TRIP-NOT: linalg.generic
+
+// -----
+
+func.func @binary_ops_float(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+                            %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.add ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+              outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %1 = linalg.sub ins(%0, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+              outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %2 = linalg.mul ins(%1, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+              outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %3 = linalg.div ins(%2, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+              outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %4 = linalg.max ins(%3, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+              outs(%Out : tensor<?x?xf32>...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/192290


More information about the Mlir-commits mailing list