[Mlir-commits] [mlir] [mlir][linalg] Genericize MapOp (PR #162742)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Oct 11 14:23:16 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: None (srcarroll)

<details>
<summary>Changes</summary>

This PR modifies the definition of `linalg::MapOp` so that it has the same structure as `linalg::GenericOp` and all other linalg ops. Mainly, it adds an `out` block argument (bbarg), corresponding to the `init` operand, to the body of the op. Although the `out` arg is never used in the body, there doesn't seem to be much benefit in specializing the op to exclude it. In fact, excluding it only makes things more complicated because it doesn't align with the `GenericOp` structure. For example, `linalg-generalize-named-ops` skipped `linalg.map` purely because the op lacked the block-argument structure needed to convert it. If `GenericOp` can have unused bbargs, then ALL linalg ops should be allowed to have them as well.

---
Full diff: https://github.com/llvm/llvm-project/pull/162742.diff


12 Files Affected:

- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td (-4) 
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+24-13) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp (+2-4) 
- (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+2) 
- (modified) mlir/test/Dialect/Linalg/canonicalize.mlir (+1-1) 
- (modified) mlir/test/Dialect/Linalg/generalize-named-ops.mlir (+14-8) 
- (modified) mlir/test/Dialect/Linalg/invalid.mlir (+5-5) 
- (modified) mlir/test/Dialect/Linalg/one-shot-bufferize.mlir (+1-1) 
- (modified) mlir/test/Dialect/Linalg/roundtrip.mlir (+9-9) 
- (modified) mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir (+1-1) 
- (modified) mlir/test/Dialect/Tensor/bufferize.mlir (+1-1) 
- (modified) mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir (+3-3) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index f3674c3eecfe6..ecd036d452b27 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -293,10 +293,6 @@ def MapOp : LinalgStructuredBase_Op<"map", [
     // Implement functions necessary for DestinationStyleOpInterface.
     MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }
 
-    SmallVector<OpOperand *> getOpOperandsMatchingBBargs() {
-      return getDpsInputOperands();
-    }
-
     bool payloadUsesValueFromOperand(OpOperand * opOperand) {
       if (isDpsInit(opOperand)) return false;
       return !getMatchingBlockArgument(opOperand).use_empty();
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 59013a23b3e3b..7ccba6143637e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1474,6 +1474,8 @@ void MapOp::getAsmBlockArgumentNames(Region &region,
                                      OpAsmSetValueNameFn setNameFn) {
   for (Value v : getRegionInputArgs())
     setNameFn(v, "in");
+  for (Value v : getRegionOutputArgs())
+    setNameFn(v, "init");
 }
 
 void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
@@ -1495,14 +1497,14 @@ void MapOp::build(
 
   if (bodyBuild)
     buildGenericRegion(builder, result.location, *result.regions.front(),
-                       inputs, /*outputs=*/{}, bodyBuild);
+                       inputs, /*outputs=*/{init}, bodyBuild);
 }
 
 static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
                                  const OperationName &payloadOpName,
                                  const NamedAttrList &payloadOpAttrs,
                                  ArrayRef<Value> operands,
-                                 bool initFirst = false) {
+                                 bool initFirst = false, bool mapInit = true) {
   OpBuilder b(parser.getContext());
   Region *body = result.addRegion();
   Block &block = body->emplaceBlock();
@@ -1516,12 +1518,13 @@ static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
   // If initFirst flag is enabled, we consider init as the first position of
   // payload operands.
   if (initFirst) {
-    payloadOpOperands.push_back(block.getArguments().back());
+    if (mapInit)
+      payloadOpOperands.push_back(block.getArguments().back());
     for (const auto &arg : block.getArguments().drop_back())
       payloadOpOperands.push_back(arg);
   } else {
     payloadOpOperands = {block.getArguments().begin(),
-                         block.getArguments().end()};
+                         block.getArguments().end() - int(!mapInit)};
   }
 
   Operation *payloadOp = b.create(
@@ -1553,8 +1556,8 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
   if (payloadOpName.has_value()) {
     if (!result.operands.empty())
       addBodyWithPayloadOp(parser, result, payloadOpName.value(),
-                           payloadOpAttrs,
-                           ArrayRef(result.operands).drop_back());
+                           payloadOpAttrs, ArrayRef(result.operands), false,
+                           false);
     else
       result.addRegion();
   } else {
@@ -1570,7 +1573,11 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
   return success();
 }
 
-static bool canUseShortForm(Block *body, bool initFirst = false) {
+static bool canUseShortForm(Block *body, bool initFirst = false,
+                            bool mapInit = true) {
+  // `initFirst == true` implies that we want to map the init arg.
+  if (initFirst && !mapInit)
+    return false;
   // Check if the body can be printed in short form. The following 4 conditions
   // must be satisfied:
 
@@ -1582,7 +1589,7 @@ static bool canUseShortForm(Block *body, bool initFirst = false) {
   // 2) The payload op must have the same number of operands as the number of
   //    block arguments.
   if (payload.getNumOperands() == 0 ||
-      payload.getNumOperands() != body->getNumArguments())
+      payload.getNumOperands() != body->getNumArguments() - int(!mapInit))
     return false;
 
   // 3) If `initFirst` is true (e.g., for reduction ops), the init block
@@ -1600,7 +1607,8 @@ static bool canUseShortForm(Block *body, bool initFirst = false) {
     }
   } else {
     for (const auto &[operand, bbArg] :
-         llvm::zip(payload.getOperands(), body->getArguments())) {
+         llvm::zip(payload.getOperands(),
+                   body->getArguments().drop_back(int(!mapInit)))) {
       if (bbArg != operand)
         return false;
     }
@@ -1632,7 +1640,8 @@ static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
 
 void MapOp::print(OpAsmPrinter &p) {
   Block *mapper = getBody();
-  bool useShortForm = canUseShortForm(mapper);
+  bool useShortForm =
+      canUseShortForm(mapper, /*initFirst=*/false, /*mapInit*/ false);
   if (useShortForm) {
     printShortForm(p, &mapper->getOperations().front());
   }
@@ -1658,11 +1667,13 @@ LogicalResult MapOp::verify() {
   auto *bodyBlock = getBody();
   auto blockArgs = bodyBlock->getArguments();
 
-  // Checks if the number of `inputs` match the arity of the `mapper` region.
-  if (getInputs().size() != blockArgs.size())
+  // Checks if the number of `inputs` + `init` match the arity of the `mapper`
+  // region.
+  if (getInputs().size() + 1 != blockArgs.size())
     return emitOpError() << "expects number of operands to match the arity of "
                             "mapper, but got: "
-                         << getInputs().size() << " and " << blockArgs.size();
+                         << getInputs().size() + 1 << " and "
+                         << blockArgs.size();
 
   // The parameters of mapper should all match the element type of inputs.
   for (const auto &[bbArgType, inputArg] :
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
index 3e31393fd51ed..75bb1757a55f5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
@@ -31,10 +31,8 @@ using namespace mlir;
 using namespace mlir::linalg;
 
 static LogicalResult generalizeNamedOpPrecondition(LinalgOp linalgOp) {
-  // Bailout if `linalgOp` is already a generic or a linalg.map. We cannot
-  // trivially generalize a `linalg.map`, as it does not use the output as
-  // region arguments in the block.
-  if (isa<GenericOp>(linalgOp) || isa<MapOp>(linalgOp))
+  // Bailout if `linalgOp` is already a generic.
+  if (isa<GenericOp>(linalgOp))
     return failure();
   // Check if the operation has exactly one region.
   if (linalgOp->getNumRegions() != 1) {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index bce964e47a3be..c607ece418dff 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -579,6 +579,7 @@ static Value lowerGenerateLikeOpBody(RewriterBase &rewriter, Location loc,
       linalg::MapOp::create(rewriter, loc, tensorType, /*inputs=*/ValueRange(),
                             /*init=*/tensorDestination);
   Block &linalgBody = linalgOp.getMapper().emplaceBlock();
+  linalgBody.addArgument(tensorType.getElementType(), loc);
 
   // Create linalg::IndexOps.
   rewriter.setInsertionPointToStart(&linalgBody);
@@ -1068,6 +1069,7 @@ struct SplatOpInterface
                                           /*inputs=*/ValueRange(),
                                           /*init=*/*tensorAlloc);
     Block &linalgBody = linalgOp.getMapper().emplaceBlock();
+    linalgBody.addArgument(tensorType.getElementType(), loc);
 
     // Create linalg::IndexOps.
     rewriter.setInsertionPointToStart(&linalgBody);
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 26d2d98572f47..f4020ede4854e 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1423,7 +1423,7 @@ func.func @transpose_buffer(%input: memref<?xf32>,
 func.func @recursive_effect(%arg : tensor<1xf32>) {
   %init = arith.constant dense<0.0> : tensor<1xf32>
   %mapped = linalg.map ins(%arg:tensor<1xf32>) outs(%init :tensor<1xf32>)
-            (%in : f32) {
+            (%in : f32, %out: f32) {
               vector.print %in : f32
               linalg.yield %in : f32
             }
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index ae07b1b82228c..dcdd6c8db4b21 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -386,18 +386,24 @@ func.func @generalize_batch_reduce_gemm_bf16(%lhs: memref<7x8x9xbf16>, %rhs: mem
 
 // -----
 
-// CHECK-LABEL: generalize_linalg_map
-func.func @generalize_linalg_map(%arg0: memref<1x8x8x8xf32>) {
+func.func @generalize_linalg_map(%arg0: memref<1x8x8x8xf32>, %arg1: memref<1x8x8x8xf32>, %arg2: memref<1x8x8x8xf32>) {
   %cst = arith.constant 0.000000e+00 : f32
-  // CHECK: linalg.map
-  // CHECK-NOT: linalg.generic
-  linalg.map outs(%arg0 : memref<1x8x8x8xf32>)
-    () {
-      linalg.yield %cst : f32
-    }
+  linalg.map {arith.addf} ins(%arg0, %arg1: memref<1x8x8x8xf32>, memref<1x8x8x8xf32>) outs(%arg2 : memref<1x8x8x8xf32>)
   return
 }
 
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+// CHECK: @generalize_linalg_map
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP0]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x8x8x8xf32>, memref<1x8x8x8xf32>) outs(%{{.+}} : memref<1x8x8x8xf32>
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
+// CHECK:         %[[ADD:.+]] = arith.addf %[[BBARG0]], %[[BBARG1]] : f32
+// CHECK:         linalg.yield %[[ADD]] : f32
+
 // -----
 
 func.func @generalize_add(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 40bf4d19d6b91..fabc8e610612d 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -681,7 +681,7 @@ func.func @map_binary_wrong_yield_operands(
    %add = linalg.map
           ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
           outs(%init:tensor<64xf32>)
-          (%lhs_elem: f32, %rhs_elem: f32) {
+          (%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
             %0 = arith.addf %lhs_elem, %rhs_elem: f32
             // expected-error @+1{{'linalg.yield' op expected number of yield values (2) to match the number of inits / outs operands of the enclosing LinalgOp (1)}}
             linalg.yield %0, %0: f32, f32
@@ -694,11 +694,11 @@ func.func @map_binary_wrong_yield_operands(
 func.func @map_input_mapper_arity_mismatch(
     %lhs: tensor<64xf32>, %rhs: tensor<64xf32>, %init: tensor<64xf32>)
     -> tensor<64xf32> {
-  // expected-error@+1{{'linalg.map' op expects number of operands to match the arity of mapper, but got: 2 and 3}}
+  // expected-error@+1{{'linalg.map' op expects number of operands to match the arity of mapper, but got: 3 and 4}}
   %add = linalg.map
       ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
       outs(%init:tensor<64xf32>)
-      (%lhs_elem: f32, %rhs_elem: f32, %extra_elem: f32) {
+      (%lhs_elem: f32, %rhs_elem: f32, %out: f32, %extra_elem: f32) {
         %0 = arith.addf %lhs_elem, %rhs_elem: f32
         linalg.yield %0: f32
       }
@@ -714,7 +714,7 @@ func.func @map_input_mapper_type_mismatch(
   %add = linalg.map
       ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
       outs(%init:tensor<64xf32>)
-      (%lhs_elem: f64, %rhs_elem: f64) {
+      (%lhs_elem: f64, %rhs_elem: f64, %out: f32) {
         %0 = arith.addf %lhs_elem, %rhs_elem: f64
         linalg.yield %0: f64
       }
@@ -730,7 +730,7 @@ func.func @map_input_output_shape_mismatch(
   %add = linalg.map
       ins(%lhs, %rhs : tensor<64x64xf32>, tensor<64x64xf32>)
       outs(%init:tensor<32xf32>)
-      (%lhs_elem: f32, %rhs_elem: f32) {
+      (%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
         %0 = arith.addf %lhs_elem, %rhs_elem: f32
         linalg.yield %0: f32
       }
diff --git a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
index 9616a3e32a064..28d7fdc041766 100644
--- a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
@@ -339,7 +339,7 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
    %add = linalg.map
           ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
           outs(%init:tensor<64xf32>)
-          (%lhs_elem: f32, %rhs_elem: f32) {
+          (%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
             %0 = arith.addf %lhs_elem, %rhs_elem: f32
             linalg.yield %0: f32
           }
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 563013d4083af..74928920c695a 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -341,7 +341,7 @@ func.func @mixed_parallel_reduced_results(%arg0 : tensor<?x?x?xf32>,
 func.func @map_no_inputs(%init: tensor<64xf32>) -> tensor<64xf32> {
    %add = linalg.map
       outs(%init:tensor<64xf32>)
-      () {
+      (%out: f32) {
         %0 = arith.constant 0.0: f32
         linalg.yield %0: f32
       }
@@ -349,7 +349,7 @@ func.func @map_no_inputs(%init: tensor<64xf32>) -> tensor<64xf32> {
 }
 // CHECK-LABEL: func @map_no_inputs
 //       CHECK:   linalg.map outs
-//  CHECK-NEXT:   () {
+//  CHECK-NEXT:   (%[[OUT:.*]]: f32) {
 //  CHECK-NEXT:     arith.constant
 //  CHECK-NEXT:     linalg.yield
 //  CHECK-NEXT:   }
@@ -361,7 +361,7 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
    %add = linalg.map
           ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
           outs(%init:tensor<64xf32>)
-          (%lhs_elem: f32, %rhs_elem: f32) {
+          (%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
             %0 = arith.addf %lhs_elem, %rhs_elem: f32
             linalg.yield %0: f32
           }
@@ -378,7 +378,7 @@ func.func @map_binary_memref(%lhs: memref<64xf32>, %rhs: memref<64xf32>,
    linalg.map
       ins(%lhs, %rhs: memref<64xf32>, memref<64xf32>)
       outs(%init:memref<64xf32>)
-      (%lhs_elem: f32, %rhs_elem: f32) {
+      (%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
         %0 = arith.addf %lhs_elem, %rhs_elem: f32
         linalg.yield %0: f32
       }
@@ -393,7 +393,7 @@ func.func @map_unary(%input: tensor<64xf32>, %init: tensor<64xf32>) -> tensor<64
    %abs = linalg.map
           ins(%input:tensor<64xf32>)
           outs(%init:tensor<64xf32>)
-          (%input_elem: f32) {
+          (%input_elem: f32, %out: f32) {
             %0 = math.absf %input_elem: f32
             linalg.yield %0: f32
           }
@@ -408,7 +408,7 @@ func.func @map_unary_memref(%input: memref<64xf32>, %init: memref<64xf32>) {
    linalg.map
       ins(%input:memref<64xf32>)
       outs(%init:memref<64xf32>)
-      (%input_elem: f32) {
+      (%input_elem: f32, %out: f32) {
         %0 = math.absf %input_elem: f32
         linalg.yield %0: f32
       }
@@ -604,7 +604,7 @@ func.func @map_arith_with_attr(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
   %add = linalg.map
           ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
           outs(%init:tensor<64xf32>)
-          (%lhs_elem: f32, %rhs_elem: f32) {
+          (%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
             %0 = arith.addf %lhs_elem, %rhs_elem fastmath<fast> : f32
             linalg.yield %0: f32
           }
@@ -622,7 +622,7 @@ func.func @map_arith_with_attr(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
 
 func.func @map_not_short_form_compatible(%lhs: tensor<1x32xf32>, %rhs: tensor<1x32xf32>, %init: tensor<1x32xf32>) -> tensor<1x32xf32> {
   %mapped = linalg.map ins(%lhs, %rhs : tensor<1x32xf32>, tensor<1x32xf32>) outs(%init : tensor<1x32xf32>)
-    (%in_1: f32, %in_2: f32) {
+    (%in_1: f32, %in_2: f32, %out: f32) {
       %1 = arith.maximumf %in_1, %in_2 : f32
       linalg.yield %in_1 : f32
     }
@@ -634,7 +634,7 @@ func.func @map_not_short_form_compatible(%lhs: tensor<1x32xf32>, %rhs: tensor<1x
 // CHECK-NOT:     linalg.map { arith.maximumf } ins(%[[LHS]] : tensor<1x32xf32>
 // CHECK:         linalg.map ins(%[[LHS]], %[[RHS]] : tensor<1x32xf32>, tensor<1x32xf32>) 
 // CHECK-SAME:               outs(%[[INIT]] : tensor<1x32xf32>)
-// CHECK-NEXT:      (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32) {
+// CHECK-NEXT:      (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32, %[[OUT:.*]]: f32) {
 // CHECK-NEXT:        %[[MAX_RESULT:.*]] = arith.maximumf %[[IN1]], %[[IN2]] : f32
 // CHECK-NEXT:        linalg.yield %[[IN1]] : f32
 // CHECK-NEXT:    }
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
index 35f520a9f22a8..704ad10130fc8 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
@@ -381,7 +381,7 @@ func.func @vectorize_map(%arg0: memref<64xf32>,
     %arg1: memref<64xf32>, %arg2: memref<64xf32>) {
   linalg.map ins(%arg0, %arg1 : memref<64xf32>, memref<64xf32>)
              outs(%arg2 : memref<64xf32>)
-    (%in: f32, %in_0: f32) {
+    (%in: f32, %in_0: f32, %out: f32) {
       %0 = arith.addf %in, %in_0 : f32
       linalg.yield %0 : f32
     }
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index 296ca02564e35..5eb2360a29b8f 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -728,7 +728,7 @@ func.func @tensor.concat_dynamic_nonconcat_dim(%f: tensor<?x?xf32>, %g: tensor<?
 // CHECK-DAG:     %[[ALLOC:.*]] = memref.alloc(%[[M]], %[[N]]) {{.*}} : memref<?x3x?xf32>
 // CHECK:         %[[ALLOC_T:.*]] = bufferization.to_tensor %[[ALLOC]]
 // CHECK:         %[[MAPPED:.*]] = linalg.map outs(%[[ALLOC_T]] : tensor<?x3x?xf32>)
-// CHECK:         () {
+// CHECK:         (%[[INIT:.*]]: f32) {
 // CHECK:           linalg.yield %[[F]] : f32
 // CHECK:         }
 // CHECK:         return %[[MAPPED]] : tensor<?x3x?xf32>
diff --git a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
index 8cbee3cbb758b..aa8882d21698c 100644
--- a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
@@ -257,10 +257,10 @@ module attributes {transform.with_named_sequence} {
 // -----
 
 func.func @map(%lhs: memref<64xf32>,
-    %rhs: memref<64xf32>, %out: memref<64xf32>) {
+    %rhs: memref<64xf32>, %init: memref<64xf32>) {
   linalg.map ins(%lhs, %rhs : memref<64xf32>, memref<64xf32>)
-             outs(%out : memref<64xf32>)
-    (%in: f32, %in_0: f32) {
+             outs(%init : memref<64xf32>)
+    (%in: f32, %in_0: f32, %out: f32) {
       %0 = arith.addf %in, %in_0 : f32
       linalg.yield %0 : f32
     }

``````````

</details>


https://github.com/llvm/llvm-project/pull/162742


More information about the Mlir-commits mailing list