[Mlir-commits] [mlir] [mlir][linalg] Genericize MapOp (PR #162742)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 30 07:19:16 PDT 2025


https://github.com/srcarroll updated https://github.com/llvm/llvm-project/pull/162742

>From 7d4273a64897f54f67a3b55f086fe6613dba659e Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 9 Oct 2025 16:41:03 -0500
Subject: [PATCH 1/4] [mlir][linalg] Genericize MapOp

---
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  |  4 ----
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 21 +++++++++---------
 .../Linalg/Transforms/Generalization.cpp      |  6 ++---
 .../Dialect/Linalg/generalize-named-ops.mlir  | 22 ++++++++++++-------
 .../lower-to-loops-using-interface.mlir       |  6 ++---
 5 files changed, 29 insertions(+), 30 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index f3674c3eecfe6..ecd036d452b27 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -293,10 +293,6 @@ def MapOp : LinalgStructuredBase_Op<"map", [
     // Implement functions necessary for DestinationStyleOpInterface.
     MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }
 
-    SmallVector<OpOperand *> getOpOperandsMatchingBBargs() {
-      return getDpsInputOperands();
-    }
-
     bool payloadUsesValueFromOperand(OpOperand * opOperand) {
       if (isDpsInit(opOperand)) return false;
       return !getMatchingBlockArgument(opOperand).use_empty();
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 59013a23b3e3b..6134a7535808d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1474,6 +1474,8 @@ void MapOp::getAsmBlockArgumentNames(Region &region,
                                      OpAsmSetValueNameFn setNameFn) {
   for (Value v : getRegionInputArgs())
     setNameFn(v, "in");
+  for (Value v : getRegionOutputArgs())
+    setNameFn(v, "init");
 }
 
 void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
@@ -1495,14 +1497,14 @@ void MapOp::build(
 
   if (bodyBuild)
     buildGenericRegion(builder, result.location, *result.regions.front(),
-                       inputs, /*outputs=*/{}, bodyBuild);
+                       inputs, /*outputs=*/{init}, bodyBuild);
 }
 
 static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
                                  const OperationName &payloadOpName,
                                  const NamedAttrList &payloadOpAttrs,
                                  ArrayRef<Value> operands,
-                                 bool initFirst = false) {
+                                 bool initFirst = false, bool mapInit = true) {
   OpBuilder b(parser.getContext());
   Region *body = result.addRegion();
   Block &block = body->emplaceBlock();
@@ -1516,12 +1518,13 @@ static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
   // If initFirst flag is enabled, we consider init as the first position of
   // payload operands.
   if (initFirst) {
-    payloadOpOperands.push_back(block.getArguments().back());
+    if (mapInit)
+      payloadOpOperands.push_back(block.getArguments().back());
     for (const auto &arg : block.getArguments().drop_back())
       payloadOpOperands.push_back(arg);
   } else {
     payloadOpOperands = {block.getArguments().begin(),
-                         block.getArguments().end()};
+                         block.getArguments().end() - int(!mapInit)};
   }
 
   Operation *payloadOp = b.create(
@@ -1551,12 +1554,8 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
     return failure();
 
   if (payloadOpName.has_value()) {
-    if (!result.operands.empty())
-      addBodyWithPayloadOp(parser, result, payloadOpName.value(),
-                           payloadOpAttrs,
-                           ArrayRef(result.operands).drop_back());
-    else
-      result.addRegion();
+    addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
+                         ArrayRef(result.operands), false, false);
   } else {
     SmallVector<OpAsmParser::Argument> regionArgs;
     if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
@@ -1659,7 +1658,7 @@ LogicalResult MapOp::verify() {
   auto blockArgs = bodyBlock->getArguments();
 
   // Checks if the number of `inputs` match the arity of the `mapper` region.
-  if (getInputs().size() != blockArgs.size())
+  if (getInputs().size() + 1 != blockArgs.size())
     return emitOpError() << "expects number of operands to match the arity of "
                             "mapper, but got: "
                          << getInputs().size() << " and " << blockArgs.size();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
index 3e31393fd51ed..75bb1757a55f5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
@@ -31,10 +31,8 @@ using namespace mlir;
 using namespace mlir::linalg;
 
 static LogicalResult generalizeNamedOpPrecondition(LinalgOp linalgOp) {
-  // Bailout if `linalgOp` is already a generic or a linalg.map. We cannot
-  // trivially generalize a `linalg.map`, as it does not use the output as
-  // region arguments in the block.
-  if (isa<GenericOp>(linalgOp) || isa<MapOp>(linalgOp))
+  // Bailout if `linalgOp` is already a generic.
+  if (isa<GenericOp>(linalgOp))
     return failure();
   // Check if the operation has exactly one region.
   if (linalgOp->getNumRegions() != 1) {
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index ae07b1b82228c..dcdd6c8db4b21 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -386,18 +386,24 @@ func.func @generalize_batch_reduce_gemm_bf16(%lhs: memref<7x8x9xbf16>, %rhs: mem
 
 // -----
 
-// CHECK-LABEL: generalize_linalg_map
-func.func @generalize_linalg_map(%arg0: memref<1x8x8x8xf32>) {
+func.func @generalize_linalg_map(%arg0: memref<1x8x8x8xf32>, %arg1: memref<1x8x8x8xf32>, %arg2: memref<1x8x8x8xf32>) {
   %cst = arith.constant 0.000000e+00 : f32
-  // CHECK: linalg.map
-  // CHECK-NOT: linalg.generic
-  linalg.map outs(%arg0 : memref<1x8x8x8xf32>)
-    () {
-      linalg.yield %cst : f32
-    }
+  linalg.map {arith.addf} ins(%arg0, %arg1: memref<1x8x8x8xf32>, memref<1x8x8x8xf32>) outs(%arg2 : memref<1x8x8x8xf32>)
   return
 }
 
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+// CHECK: @generalize_linalg_map
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP0]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x8x8x8xf32>, memref<1x8x8x8xf32>) outs(%{{.+}} : memref<1x8x8x8xf32>
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
+// CHECK:         %[[ADD:.+]] = arith.addf %[[BBARG0]], %[[BBARG1]] : f32
+// CHECK:         linalg.yield %[[ADD]] : f32
+
 // -----
 
 func.func @generalize_add(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
diff --git a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
index 8cbee3cbb758b..aa8882d21698c 100644
--- a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
@@ -257,10 +257,10 @@ module attributes {transform.with_named_sequence} {
 // -----
 
 func.func @map(%lhs: memref<64xf32>,
-    %rhs: memref<64xf32>, %out: memref<64xf32>) {
+    %rhs: memref<64xf32>, %init: memref<64xf32>) {
   linalg.map ins(%lhs, %rhs : memref<64xf32>, memref<64xf32>)
-             outs(%out : memref<64xf32>)
-    (%in: f32, %in_0: f32) {
+             outs(%init : memref<64xf32>)
+    (%in: f32, %in_0: f32, %out: f32) {
       %0 = arith.addf %in, %in_0 : f32
       linalg.yield %0 : f32
     }

>From 7a1d5f3cfdb7714e425da84e97a0855982b34831 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 10 Oct 2025 09:27:56 -0500
Subject: [PATCH 2/4] revert unnecessary change

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 6134a7535808d..4a3b7ad49eb59 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1554,8 +1554,12 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
     return failure();
 
   if (payloadOpName.has_value()) {
-    addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
-                         ArrayRef(result.operands), false, false);
+    if (!result.operands.empty())
+      addBodyWithPayloadOp(parser, result, payloadOpName.value(),
+                           payloadOpAttrs, ArrayRef(result.operands), false,
+                           false);
+    else
+      result.addRegion();
   } else {
     SmallVector<OpAsmParser::Argument> regionArgs;
     if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
@@ -1657,11 +1661,13 @@ LogicalResult MapOp::verify() {
   auto *bodyBlock = getBody();
   auto blockArgs = bodyBlock->getArguments();
 
-  // Checks if the number of `inputs` match the arity of the `mapper` region.
+  // Checks if the number of `inputs` + `init` match the arity of the `mapper`
+  // region.
   if (getInputs().size() + 1 != blockArgs.size())
     return emitOpError() << "expects number of operands to match the arity of "
                             "mapper, but got: "
-                         << getInputs().size() << " and " << blockArgs.size();
+                         << getInputs().size() + 1 << " and "
+                         << blockArgs.size();
 
   // The parameters of mapper should all match the element type of inputs.
   for (const auto &[bbArgType, inputArg] :

>From 5a1622dc3d790792bf2b0ba2b3828b74c9e81bf5 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 10 Oct 2025 09:47:38 -0500
Subject: [PATCH 3/4] fix bufferizations that generate MapOp

---
 .../Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp   | 2 ++
 mlir/test/Dialect/Tensor/bufferize.mlir                         | 2 +-
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index bce964e47a3be..c607ece418dff 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -579,6 +579,7 @@ static Value lowerGenerateLikeOpBody(RewriterBase &rewriter, Location loc,
       linalg::MapOp::create(rewriter, loc, tensorType, /*inputs=*/ValueRange(),
                             /*init=*/tensorDestination);
   Block &linalgBody = linalgOp.getMapper().emplaceBlock();
+  linalgBody.addArgument(tensorType.getElementType(), loc);
 
   // Create linalg::IndexOps.
   rewriter.setInsertionPointToStart(&linalgBody);
@@ -1068,6 +1069,7 @@ struct SplatOpInterface
                                           /*inputs=*/ValueRange(),
                                           /*init=*/*tensorAlloc);
     Block &linalgBody = linalgOp.getMapper().emplaceBlock();
+    linalgBody.addArgument(tensorType.getElementType(), loc);
 
     // Create linalg::IndexOps.
     rewriter.setInsertionPointToStart(&linalgBody);
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index 296ca02564e35..5eb2360a29b8f 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -728,7 +728,7 @@ func.func @tensor.concat_dynamic_nonconcat_dim(%f: tensor<?x?xf32>, %g: tensor<?
 // CHECK-DAG:     %[[ALLOC:.*]] = memref.alloc(%[[M]], %[[N]]) {{.*}} : memref<?x3x?xf32>
 // CHECK:         %[[ALLOC_T:.*]] = bufferization.to_tensor %[[ALLOC]]
 // CHECK:         %[[MAPPED:.*]] = linalg.map outs(%[[ALLOC_T]] : tensor<?x3x?xf32>)
-// CHECK:         () {
+// CHECK:         (%[[INIT:.*]]: f32) {
 // CHECK:           linalg.yield %[[F]] : f32
 // CHECK:         }
 // CHECK:         return %[[MAPPED]] : tensor<?x3x?xf32>

>From b9335070844ca327c1e2a57618d7df2f6f394b88 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sat, 11 Oct 2025 16:18:49 -0500
Subject: [PATCH 4/4] fix short form print and update remaining tests

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp       | 14 ++++++++++----
 mlir/test/Dialect/Linalg/canonicalize.mlir     |  2 +-
 mlir/test/Dialect/Linalg/invalid.mlir          | 10 +++++-----
 .../Dialect/Linalg/one-shot-bufferize.mlir     |  2 +-
 mlir/test/Dialect/Linalg/roundtrip.mlir        | 18 +++++++++---------
 .../linalg-ops-with-patterns.mlir              |  2 +-
 6 files changed, 27 insertions(+), 21 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 4a3b7ad49eb59..7ccba6143637e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1573,7 +1573,11 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
   return success();
 }
 
-static bool canUseShortForm(Block *body, bool initFirst = false) {
+static bool canUseShortForm(Block *body, bool initFirst = false,
+                            bool mapInit = true) {
+  // `initFirst == true` implies that we want to map init arg
+  if (initFirst && !mapInit)
+    return false;
   // Check if the body can be printed in short form. The following 4 conditions
   // must be satisfied:
 
@@ -1585,7 +1589,7 @@ static bool canUseShortForm(Block *body, bool initFirst = false) {
   // 2) The payload op must have the same number of operands as the number of
   //    block arguments.
   if (payload.getNumOperands() == 0 ||
-      payload.getNumOperands() != body->getNumArguments())
+      payload.getNumOperands() != body->getNumArguments() - int(!mapInit))
     return false;
 
   // 3) If `initFirst` is true (e.g., for reduction ops), the init block
@@ -1603,7 +1607,8 @@ static bool canUseShortForm(Block *body, bool initFirst = false) {
     }
   } else {
     for (const auto &[operand, bbArg] :
-         llvm::zip(payload.getOperands(), body->getArguments())) {
+         llvm::zip(payload.getOperands(),
+                   body->getArguments().drop_back(int(!mapInit)))) {
       if (bbArg != operand)
         return false;
     }
@@ -1635,7 +1640,8 @@ static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
 
 void MapOp::print(OpAsmPrinter &p) {
   Block *mapper = getBody();
-  bool useShortForm = canUseShortForm(mapper);
+  bool useShortForm =
+      canUseShortForm(mapper, /*initFirst=*/false, /*mapInit*/ false);
   if (useShortForm) {
     printShortForm(p, &mapper->getOperations().front());
   }
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 26d2d98572f47..f4020ede4854e 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1423,7 +1423,7 @@ func.func @transpose_buffer(%input: memref<?xf32>,
 func.func @recursive_effect(%arg : tensor<1xf32>) {
   %init = arith.constant dense<0.0> : tensor<1xf32>
   %mapped = linalg.map ins(%arg:tensor<1xf32>) outs(%init :tensor<1xf32>)
-            (%in : f32) {
+            (%in : f32, %out: f32) {
               vector.print %in : f32
               linalg.yield %in : f32
             }
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 40bf4d19d6b91..fabc8e610612d 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -681,7 +681,7 @@ func.func @map_binary_wrong_yield_operands(
    %add = linalg.map
           ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
           outs(%init:tensor<64xf32>)
-          (%lhs_elem: f32, %rhs_elem: f32) {
+          (%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
             %0 = arith.addf %lhs_elem, %rhs_elem: f32
             // expected-error @+1{{'linalg.yield' op expected number of yield values (2) to match the number of inits / outs operands of the enclosing LinalgOp (1)}}
             linalg.yield %0, %0: f32, f32
@@ -694,11 +694,11 @@ func.func @map_binary_wrong_yield_operands(
 func.func @map_input_mapper_arity_mismatch(
     %lhs: tensor<64xf32>, %rhs: tensor<64xf32>, %init: tensor<64xf32>)
     -> tensor<64xf32> {
-  // expected-error at +1{{'linalg.map' op expects number of operands to match the arity of mapper, but got: 2 and 3}}
+  // expected-error at +1{{'linalg.map' op expects number of operands to match the arity of mapper, but got: 3 and 4}}
   %add = linalg.map
       ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
       outs(%init:tensor<64xf32>)
-      (%lhs_elem: f32, %rhs_elem: f32, %extra_elem: f32) {
+      (%lhs_elem: f32, %rhs_elem: f32, %out: f32, %extra_elem: f32) {
         %0 = arith.addf %lhs_elem, %rhs_elem: f32
         linalg.yield %0: f32
       }
@@ -714,7 +714,7 @@ func.func @map_input_mapper_type_mismatch(
   %add = linalg.map
       ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
       outs(%init:tensor<64xf32>)
-      (%lhs_elem: f64, %rhs_elem: f64) {
+      (%lhs_elem: f64, %rhs_elem: f64, %out: f32) {
         %0 = arith.addf %lhs_elem, %rhs_elem: f64
         linalg.yield %0: f64
       }
@@ -730,7 +730,7 @@ func.func @map_input_output_shape_mismatch(
   %add = linalg.map
       ins(%lhs, %rhs : tensor<64x64xf32>, tensor<64x64xf32>)
       outs(%init:tensor<32xf32>)
-      (%lhs_elem: f32, %rhs_elem: f32) {
+      (%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
         %0 = arith.addf %lhs_elem, %rhs_elem: f32
         linalg.yield %0: f32
       }
diff --git a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
index 9616a3e32a064..28d7fdc041766 100644
--- a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
@@ -339,7 +339,7 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
    %add = linalg.map
           ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
           outs(%init:tensor<64xf32>)
-          (%lhs_elem: f32, %rhs_elem: f32) {
+          (%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
             %0 = arith.addf %lhs_elem, %rhs_elem: f32
             linalg.yield %0: f32
           }
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 563013d4083af..74928920c695a 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -341,7 +341,7 @@ func.func @mixed_parallel_reduced_results(%arg0 : tensor<?x?x?xf32>,
 func.func @map_no_inputs(%init: tensor<64xf32>) -> tensor<64xf32> {
    %add = linalg.map
       outs(%init:tensor<64xf32>)
-      () {
+      (%out: f32) {
         %0 = arith.constant 0.0: f32
         linalg.yield %0: f32
       }
@@ -349,7 +349,7 @@ func.func @map_no_inputs(%init: tensor<64xf32>) -> tensor<64xf32> {
 }
 // CHECK-LABEL: func @map_no_inputs
 //       CHECK:   linalg.map outs
-//  CHECK-NEXT:   () {
+//  CHECK-NEXT:   (%[[OUT:.*]]: f32) {
 //  CHECK-NEXT:     arith.constant
 //  CHECK-NEXT:     linalg.yield
 //  CHECK-NEXT:   }
@@ -361,7 +361,7 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
    %add = linalg.map
           ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
           outs(%init:tensor<64xf32>)
-          (%lhs_elem: f32, %rhs_elem: f32) {
+          (%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
             %0 = arith.addf %lhs_elem, %rhs_elem: f32
             linalg.yield %0: f32
           }
@@ -378,7 +378,7 @@ func.func @map_binary_memref(%lhs: memref<64xf32>, %rhs: memref<64xf32>,
    linalg.map
       ins(%lhs, %rhs: memref<64xf32>, memref<64xf32>)
       outs(%init:memref<64xf32>)
-      (%lhs_elem: f32, %rhs_elem: f32) {
+      (%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
         %0 = arith.addf %lhs_elem, %rhs_elem: f32
         linalg.yield %0: f32
       }
@@ -393,7 +393,7 @@ func.func @map_unary(%input: tensor<64xf32>, %init: tensor<64xf32>) -> tensor<64
    %abs = linalg.map
           ins(%input:tensor<64xf32>)
           outs(%init:tensor<64xf32>)
-          (%input_elem: f32) {
+          (%input_elem: f32, %out: f32) {
             %0 = math.absf %input_elem: f32
             linalg.yield %0: f32
           }
@@ -408,7 +408,7 @@ func.func @map_unary_memref(%input: memref<64xf32>, %init: memref<64xf32>) {
    linalg.map
       ins(%input:memref<64xf32>)
       outs(%init:memref<64xf32>)
-      (%input_elem: f32) {
+      (%input_elem: f32, %out: f32) {
         %0 = math.absf %input_elem: f32
         linalg.yield %0: f32
       }
@@ -604,7 +604,7 @@ func.func @map_arith_with_attr(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
   %add = linalg.map
           ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
           outs(%init:tensor<64xf32>)
-          (%lhs_elem: f32, %rhs_elem: f32) {
+          (%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
             %0 = arith.addf %lhs_elem, %rhs_elem fastmath<fast> : f32
             linalg.yield %0: f32
           }
@@ -622,7 +622,7 @@ func.func @map_arith_with_attr(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
 
 func.func @map_not_short_form_compatible(%lhs: tensor<1x32xf32>, %rhs: tensor<1x32xf32>, %init: tensor<1x32xf32>) -> tensor<1x32xf32> {
   %mapped = linalg.map ins(%lhs, %rhs : tensor<1x32xf32>, tensor<1x32xf32>) outs(%init : tensor<1x32xf32>)
-    (%in_1: f32, %in_2: f32) {
+    (%in_1: f32, %in_2: f32, %out: f32) {
       %1 = arith.maximumf %in_1, %in_2 : f32
       linalg.yield %in_1 : f32
     }
@@ -634,7 +634,7 @@ func.func @map_not_short_form_compatible(%lhs: tensor<1x32xf32>, %rhs: tensor<1x
 // CHECK-NOT:     linalg.map { arith.maximumf } ins(%[[LHS]] : tensor<1x32xf32>
 // CHECK:         linalg.map ins(%[[LHS]], %[[RHS]] : tensor<1x32xf32>, tensor<1x32xf32>) 
 // CHECK-SAME:               outs(%[[INIT]] : tensor<1x32xf32>)
-// CHECK-NEXT:      (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32) {
+// CHECK-NEXT:      (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32, %[[OUT:.*]]: f32) {
 // CHECK-NEXT:        %[[MAX_RESULT:.*]] = arith.maximumf %[[IN1]], %[[IN2]] : f32
 // CHECK-NEXT:        linalg.yield %[[IN1]] : f32
 // CHECK-NEXT:    }
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
index 35f520a9f22a8..704ad10130fc8 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
@@ -381,7 +381,7 @@ func.func @vectorize_map(%arg0: memref<64xf32>,
     %arg1: memref<64xf32>, %arg2: memref<64xf32>) {
   linalg.map ins(%arg0, %arg1 : memref<64xf32>, memref<64xf32>)
              outs(%arg2 : memref<64xf32>)
-    (%in: f32, %in_0: f32) {
+    (%in: f32, %in_0: f32, %out: f32) {
       %0 = arith.addf %in, %in_0 : f32
       linalg.yield %0 : f32
     }



More information about the Mlir-commits mailing list