[Mlir-commits] [mlir] f5e175f - [mlir][linalg] Genericize MapOp (#162742)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Oct 30 08:20:23 PDT 2025
Author: srcarroll
Date: 2025-10-30T10:20:19-05:00
New Revision: f5e175f06dacf6751f27f7a4bd9e9a3489e95a5e
URL: https://github.com/llvm/llvm-project/commit/f5e175f06dacf6751f27f7a4bd9e9a3489e95a5e
DIFF: https://github.com/llvm/llvm-project/commit/f5e175f06dacf6751f27f7a4bd9e9a3489e95a5e.diff
LOG: [mlir][linalg] Genericize MapOp (#162742)
This PR modifies the definition of `linalg::MapOp` so that it has the
same structure as `linalg::GenericOp` and all other linalg ops. Mainly,
it adds an `out` block argument (bbarg) to the body of the op. Although the `out` arg is
never used in the body, there doesn't seem to be much benefit in
specializing the op to exclude it. In fact it only makes things more
complicated because it doesn't align with the `GenericOp` structure. For
example, `linalg-generalize-named-ops` avoided converting `linalg.map`
purely because it didn't have the structure to do so. Moreover, although
some fusion patterns are applied explicitly to `GenericOp`, we can
change them to be applied to the base `LinalgOp`, which will enable
fusion for any fusion-compatible linalg op, but that requires the op
to have a generic structure. So these changes will enable us to use
existing generic transformation patterns on `MapOp` that weren't
possible before. They can either be applied to `MapOp` directly or
applied after converting to `GenericOp`.
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
mlir/test/Dialect/Linalg/generalize-named-ops.mlir
mlir/test/Dialect/Linalg/invalid.mlir
mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
mlir/test/Dialect/Linalg/roundtrip.mlir
mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
mlir/test/Dialect/Tensor/bufferize.mlir
mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index f3674c3eecfe6..ecd036d452b27 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -293,10 +293,6 @@ def MapOp : LinalgStructuredBase_Op<"map", [
// Implement functions necessary for DestinationStyleOpInterface.
MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }
- SmallVector<OpOperand *> getOpOperandsMatchingBBargs() {
- return getDpsInputOperands();
- }
-
bool payloadUsesValueFromOperand(OpOperand * opOperand) {
if (isDpsInit(opOperand)) return false;
return !getMatchingBlockArgument(opOperand).use_empty();
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index cbc565b0c8cbd..3dc45edf4a23f 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1474,6 +1474,8 @@ void MapOp::getAsmBlockArgumentNames(Region &region,
OpAsmSetValueNameFn setNameFn) {
for (Value v : getRegionInputArgs())
setNameFn(v, "in");
+ for (Value v : getRegionOutputArgs())
+ setNameFn(v, "init");
}
void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
@@ -1495,14 +1497,14 @@ void MapOp::build(
if (bodyBuild)
buildGenericRegion(builder, result.location, *result.regions.front(),
- inputs, /*outputs=*/{}, bodyBuild);
+ inputs, /*outputs=*/{init}, bodyBuild);
}
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
const OperationName &payloadOpName,
const NamedAttrList &payloadOpAttrs,
ArrayRef<Value> operands,
- bool initFirst = false) {
+ bool initFirst = false, bool mapInit = true) {
OpBuilder b(parser.getContext());
Region *body = result.addRegion();
Block &block = body->emplaceBlock();
@@ -1516,12 +1518,13 @@ static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
// If initFirst flag is enabled, we consider init as the first position of
// payload operands.
if (initFirst) {
- payloadOpOperands.push_back(block.getArguments().back());
+ if (mapInit)
+ payloadOpOperands.push_back(block.getArguments().back());
for (const auto &arg : block.getArguments().drop_back())
payloadOpOperands.push_back(arg);
} else {
payloadOpOperands = {block.getArguments().begin(),
- block.getArguments().end()};
+ block.getArguments().end() - int(!mapInit)};
}
Operation *payloadOp = b.create(
@@ -1553,8 +1556,8 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
if (payloadOpName.has_value()) {
if (!result.operands.empty())
addBodyWithPayloadOp(parser, result, payloadOpName.value(),
- payloadOpAttrs,
- ArrayRef(result.operands).drop_back());
+ payloadOpAttrs, ArrayRef(result.operands), false,
+ false);
else
result.addRegion();
} else {
@@ -1570,7 +1573,11 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
return success();
}
-static bool canUseShortForm(Block *body, bool initFirst = false) {
+static bool canUseShortForm(Block *body, bool initFirst = false,
+ bool mapInit = true) {
+ // `intFirst == true` implies that we want to map init arg
+ if (initFirst && !mapInit)
+ return false;
// Check if the body can be printed in short form. The following 4 conditions
// must be satisfied:
@@ -1582,7 +1589,7 @@ static bool canUseShortForm(Block *body, bool initFirst = false) {
// 2) The payload op must have the same number of operands as the number of
// block arguments.
if (payload.getNumOperands() == 0 ||
- payload.getNumOperands() != body->getNumArguments())
+ payload.getNumOperands() != body->getNumArguments() - int(!mapInit))
return false;
// 3) If `initFirst` is true (e.g., for reduction ops), the init block
@@ -1600,7 +1607,8 @@ static bool canUseShortForm(Block *body, bool initFirst = false) {
}
} else {
for (const auto &[operand, bbArg] :
- llvm::zip(payload.getOperands(), body->getArguments())) {
+ llvm::zip(payload.getOperands(),
+ body->getArguments().drop_back(int(!mapInit)))) {
if (bbArg != operand)
return false;
}
@@ -1632,7 +1640,8 @@ static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
void MapOp::print(OpAsmPrinter &p) {
Block *mapper = getBody();
- bool useShortForm = canUseShortForm(mapper);
+ bool useShortForm =
+ canUseShortForm(mapper, /*initFirst=*/false, /*mapInit*/ false);
if (useShortForm) {
printShortForm(p, &mapper->getOperations().front());
}
@@ -1658,11 +1667,13 @@ LogicalResult MapOp::verify() {
auto *bodyBlock = getBody();
auto blockArgs = bodyBlock->getArguments();
- // Checks if the number of `inputs` match the arity of the `mapper` region.
- if (getInputs().size() != blockArgs.size())
+ // Checks if the number of `inputs` + `init` match the arity of the `mapper`
+ // region.
+ if (getInputs().size() + 1 != blockArgs.size())
return emitOpError() << "expects number of operands to match the arity of "
"mapper, but got: "
- << getInputs().size() << " and " << blockArgs.size();
+ << getInputs().size() + 1 << " and "
+ << blockArgs.size();
// The parameters of mapper should all match the element type of inputs.
for (const auto &[bbArgType, inputArg] :
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
index 3e31393fd51ed..75bb1757a55f5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
@@ -31,10 +31,8 @@ using namespace mlir;
using namespace mlir::linalg;
static LogicalResult generalizeNamedOpPrecondition(LinalgOp linalgOp) {
- // Bailout if `linalgOp` is already a generic or a linalg.map. We cannot
- // trivially generalize a `linalg.map`, as it does not use the output as
- // region arguments in the block.
- if (isa<GenericOp>(linalgOp) || isa<MapOp>(linalgOp))
+ // Bailout if `linalgOp` is already a generic.
+ if (isa<GenericOp>(linalgOp))
return failure();
// Check if the operation has exactly one region.
if (linalgOp->getNumRegions() != 1) {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index bce964e47a3be..c607ece418dff 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -579,6 +579,7 @@ static Value lowerGenerateLikeOpBody(RewriterBase &rewriter, Location loc,
linalg::MapOp::create(rewriter, loc, tensorType, /*inputs=*/ValueRange(),
/*init=*/tensorDestination);
Block &linalgBody = linalgOp.getMapper().emplaceBlock();
+ linalgBody.addArgument(tensorType.getElementType(), loc);
// Create linalg::IndexOps.
rewriter.setInsertionPointToStart(&linalgBody);
@@ -1068,6 +1069,7 @@ struct SplatOpInterface
/*inputs=*/ValueRange(),
/*init=*/*tensorAlloc);
Block &linalgBody = linalgOp.getMapper().emplaceBlock();
+ linalgBody.addArgument(tensorType.getElementType(), loc);
// Create linalg::IndexOps.
rewriter.setInsertionPointToStart(&linalgBody);
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 26d2d98572f47..f4020ede4854e 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1423,7 +1423,7 @@ func.func @transpose_buffer(%input: memref<?xf32>,
func.func @recursive_effect(%arg : tensor<1xf32>) {
%init = arith.constant dense<0.0> : tensor<1xf32>
%mapped = linalg.map ins(%arg:tensor<1xf32>) outs(%init :tensor<1xf32>)
- (%in : f32) {
+ (%in : f32, %out: f32) {
vector.print %in : f32
linalg.yield %in : f32
}
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index ae07b1b82228c..dcdd6c8db4b21 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -386,18 +386,24 @@ func.func @generalize_batch_reduce_gemm_bf16(%lhs: memref<7x8x9xbf16>, %rhs: mem
// -----
-// CHECK-LABEL: generalize_linalg_map
-func.func @generalize_linalg_map(%arg0: memref<1x8x8x8xf32>) {
+func.func @generalize_linalg_map(%arg0: memref<1x8x8x8xf32>, %arg1: memref<1x8x8x8xf32>, %arg2: memref<1x8x8x8xf32>) {
%cst = arith.constant 0.000000e+00 : f32
- // CHECK: linalg.map
- // CHECK-NOT: linalg.generic
- linalg.map outs(%arg0 : memref<1x8x8x8xf32>)
- () {
- linalg.yield %cst : f32
- }
+ linalg.map {arith.addf} ins(%arg0, %arg1: memref<1x8x8x8xf32>, memref<1x8x8x8xf32>) outs(%arg2 : memref<1x8x8x8xf32>)
return
}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+// CHECK: @generalize_linalg_map
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP0]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x8x8x8xf32>, memref<1x8x8x8xf32>) outs(%{{.+}} : memref<1x8x8x8xf32>
+// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
+// CHECK: %[[ADD:.+]] = arith.addf %[[BBARG0]], %[[BBARG1]] : f32
+// CHECK: linalg.yield %[[ADD]] : f32
+
// -----
func.func @generalize_add(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 40bf4d19d6b91..fabc8e610612d 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -681,7 +681,7 @@ func.func @map_binary_wrong_yield_operands(
%add = linalg.map
ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
outs(%init:tensor<64xf32>)
- (%lhs_elem: f32, %rhs_elem: f32) {
+ (%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
%0 = arith.addf %lhs_elem, %rhs_elem: f32
// expected-error @+1{{'linalg.yield' op expected number of yield values (2) to match the number of inits / outs operands of the enclosing LinalgOp (1)}}
linalg.yield %0, %0: f32, f32
@@ -694,11 +694,11 @@ func.func @map_binary_wrong_yield_operands(
func.func @map_input_mapper_arity_mismatch(
%lhs: tensor<64xf32>, %rhs: tensor<64xf32>, %init: tensor<64xf32>)
-> tensor<64xf32> {
- // expected-error at +1{{'linalg.map' op expects number of operands to match the arity of mapper, but got: 2 and 3}}
+ // expected-error at +1{{'linalg.map' op expects number of operands to match the arity of mapper, but got: 3 and 4}}
%add = linalg.map
ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
outs(%init:tensor<64xf32>)
- (%lhs_elem: f32, %rhs_elem: f32, %extra_elem: f32) {
+ (%lhs_elem: f32, %rhs_elem: f32, %out: f32, %extra_elem: f32) {
%0 = arith.addf %lhs_elem, %rhs_elem: f32
linalg.yield %0: f32
}
@@ -714,7 +714,7 @@ func.func @map_input_mapper_type_mismatch(
%add = linalg.map
ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
outs(%init:tensor<64xf32>)
- (%lhs_elem: f64, %rhs_elem: f64) {
+ (%lhs_elem: f64, %rhs_elem: f64, %out: f32) {
%0 = arith.addf %lhs_elem, %rhs_elem: f64
linalg.yield %0: f64
}
@@ -730,7 +730,7 @@ func.func @map_input_output_shape_mismatch(
%add = linalg.map
ins(%lhs, %rhs : tensor<64x64xf32>, tensor<64x64xf32>)
outs(%init:tensor<32xf32>)
- (%lhs_elem: f32, %rhs_elem: f32) {
+ (%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
%0 = arith.addf %lhs_elem, %rhs_elem: f32
linalg.yield %0: f32
}
diff --git a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
index 1df15e85bac17..85cc1ffc2029e 100644
--- a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
@@ -339,7 +339,7 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
%add = linalg.map
ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
outs(%init:tensor<64xf32>)
- (%lhs_elem: f32, %rhs_elem: f32) {
+ (%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
%0 = arith.addf %lhs_elem, %rhs_elem: f32
linalg.yield %0: f32
}
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 563013d4083af..74928920c695a 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -341,7 +341,7 @@ func.func @mixed_parallel_reduced_results(%arg0 : tensor<?x?x?xf32>,
func.func @map_no_inputs(%init: tensor<64xf32>) -> tensor<64xf32> {
%add = linalg.map
outs(%init:tensor<64xf32>)
- () {
+ (%out: f32) {
%0 = arith.constant 0.0: f32
linalg.yield %0: f32
}
@@ -349,7 +349,7 @@ func.func @map_no_inputs(%init: tensor<64xf32>) -> tensor<64xf32> {
}
// CHECK-LABEL: func @map_no_inputs
// CHECK: linalg.map outs
-// CHECK-NEXT: () {
+// CHECK-NEXT: (%[[OUT:.*]]: f32) {
// CHECK-NEXT: arith.constant
// CHECK-NEXT: linalg.yield
// CHECK-NEXT: }
@@ -361,7 +361,7 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
%add = linalg.map
ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
outs(%init:tensor<64xf32>)
- (%lhs_elem: f32, %rhs_elem: f32) {
+ (%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
%0 = arith.addf %lhs_elem, %rhs_elem: f32
linalg.yield %0: f32
}
@@ -378,7 +378,7 @@ func.func @map_binary_memref(%lhs: memref<64xf32>, %rhs: memref<64xf32>,
linalg.map
ins(%lhs, %rhs: memref<64xf32>, memref<64xf32>)
outs(%init:memref<64xf32>)
- (%lhs_elem: f32, %rhs_elem: f32) {
+ (%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
%0 = arith.addf %lhs_elem, %rhs_elem: f32
linalg.yield %0: f32
}
@@ -393,7 +393,7 @@ func.func @map_unary(%input: tensor<64xf32>, %init: tensor<64xf32>) -> tensor<64
%abs = linalg.map
ins(%input:tensor<64xf32>)
outs(%init:tensor<64xf32>)
- (%input_elem: f32) {
+ (%input_elem: f32, %out: f32) {
%0 = math.absf %input_elem: f32
linalg.yield %0: f32
}
@@ -408,7 +408,7 @@ func.func @map_unary_memref(%input: memref<64xf32>, %init: memref<64xf32>) {
linalg.map
ins(%input:memref<64xf32>)
outs(%init:memref<64xf32>)
- (%input_elem: f32) {
+ (%input_elem: f32, %out: f32) {
%0 = math.absf %input_elem: f32
linalg.yield %0: f32
}
@@ -604,7 +604,7 @@ func.func @map_arith_with_attr(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
%add = linalg.map
ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
outs(%init:tensor<64xf32>)
- (%lhs_elem: f32, %rhs_elem: f32) {
+ (%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
%0 = arith.addf %lhs_elem, %rhs_elem fastmath<fast> : f32
linalg.yield %0: f32
}
@@ -622,7 +622,7 @@ func.func @map_arith_with_attr(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
func.func @map_not_short_form_compatible(%lhs: tensor<1x32xf32>, %rhs: tensor<1x32xf32>, %init: tensor<1x32xf32>) -> tensor<1x32xf32> {
%mapped = linalg.map ins(%lhs, %rhs : tensor<1x32xf32>, tensor<1x32xf32>) outs(%init : tensor<1x32xf32>)
- (%in_1: f32, %in_2: f32) {
+ (%in_1: f32, %in_2: f32, %out: f32) {
%1 = arith.maximumf %in_1, %in_2 : f32
linalg.yield %in_1 : f32
}
@@ -634,7 +634,7 @@ func.func @map_not_short_form_compatible(%lhs: tensor<1x32xf32>, %rhs: tensor<1x
// CHECK-NOT: linalg.map { arith.maximumf } ins(%[[LHS]] : tensor<1x32xf32>
// CHECK: linalg.map ins(%[[LHS]], %[[RHS]] : tensor<1x32xf32>, tensor<1x32xf32>)
// CHECK-SAME: outs(%[[INIT]] : tensor<1x32xf32>)
-// CHECK-NEXT: (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32) {
+// CHECK-NEXT: (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32, %[[OUT:.*]]: f32) {
// CHECK-NEXT: %[[MAX_RESULT:.*]] = arith.maximumf %[[IN1]], %[[IN2]] : f32
// CHECK-NEXT: linalg.yield %[[IN1]] : f32
// CHECK-NEXT: }
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
index 93a03369be239..aa2c1da4b6274 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
@@ -356,7 +356,7 @@ func.func @vectorize_map(%arg0: memref<64xf32>,
%arg1: memref<64xf32>, %arg2: memref<64xf32>) {
linalg.map ins(%arg0, %arg1 : memref<64xf32>, memref<64xf32>)
outs(%arg2 : memref<64xf32>)
- (%in: f32, %in_0: f32) {
+ (%in: f32, %in_0: f32, %out: f32) {
%0 = arith.addf %in, %in_0 : f32
linalg.yield %0 : f32
}
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index 296ca02564e35..5eb2360a29b8f 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -728,7 +728,7 @@ func.func @tensor.concat_dynamic_nonconcat_dim(%f: tensor<?x?xf32>, %g: tensor<?
// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc(%[[M]], %[[N]]) {{.*}} : memref<?x3x?xf32>
// CHECK: %[[ALLOC_T:.*]] = bufferization.to_tensor %[[ALLOC]]
// CHECK: %[[MAPPED:.*]] = linalg.map outs(%[[ALLOC_T]] : tensor<?x3x?xf32>)
-// CHECK: () {
+// CHECK: (%[[INIT:.*]]: f32) {
// CHECK: linalg.yield %[[F]] : f32
// CHECK: }
// CHECK: return %[[MAPPED]] : tensor<?x3x?xf32>
diff --git a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
index 8cbee3cbb758b..aa8882d21698c 100644
--- a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
@@ -257,10 +257,10 @@ module attributes {transform.with_named_sequence} {
// -----
func.func @map(%lhs: memref<64xf32>,
- %rhs: memref<64xf32>, %out: memref<64xf32>) {
+ %rhs: memref<64xf32>, %init: memref<64xf32>) {
linalg.map ins(%lhs, %rhs : memref<64xf32>, memref<64xf32>)
- outs(%out : memref<64xf32>)
- (%in: f32, %in_0: f32) {
+ outs(%init : memref<64xf32>)
+ (%in: f32, %in_0: f32, %out: f32) {
%0 = arith.addf %in, %in_0 : f32
linalg.yield %0 : f32
}
More information about the Mlir-commits
mailing list