[Mlir-commits] [mlir] 58ddeba - Revert "[mlir] Introduce `linalg.tiled_yield` terminator for `linalg.tiled_loop`."
Alexander Belyaev llvmlistbot at llvm.org
Mon Jul 19 05:20:02 PDT 2021
Author: Alexander Belyaev
Date: 2021-07-19T14:19:49+02:00
New Revision: 58ddeba3e0de504039add9b5a10a4546de25c7a9
URL: https://github.com/llvm/llvm-project/commit/58ddeba3e0de504039add9b5a10a4546de25c7a9
DIFF: https://github.com/llvm/llvm-project/commit/58ddeba3e0de504039add9b5a10a4546de25c7a9.diff
LOG: Revert "[mlir] Introduce `linalg.tiled_yield` terminator for `linalg.tiled_loop`."
This reverts commit 3b03d9b874aa902f7f969e7ffdefde23c2758eeb.
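For orientation, here is a minimal sketch of the IR-level difference, paraphrased from the op documentation examples updated in the diff below (the loop bounds and the `%lhs`/`%out` names are illustrative, and the `linalg.generic` body is elided as in those examples). The reverted `linalg.tiled_yield` terminator yielded a tile together with the subset of the output it updates; after this revert, the loop body reassembles the full output tensor with `tensor.insert_slice` and passes it to `linalg.yield`:

```mlir
// Terminator removed by this revert: the tile and its destination subset
// are yielded together.
linalg.tiled_loop (%i) = (%c0) to (%c24) step (%c4)
    ins (%lhs_ = %lhs : tensor<24x64xi8>)
    outs (%out_ = %out : tensor<24x64xi8>) {
  %out_sub = tensor.extract_slice %out_[%i, 0] [%c4, %c64] [1, 1]
      : tensor<24x64xi8> to tensor<?x?xi8>
  %result_sub = linalg.generic ...
  linalg.tiled_yield %result_sub in %out_sub : tensor<?x?xi8>
}

// Restored form: the body re-inserts the computed tile into the full
// output tensor and yields the whole tensor.
linalg.tiled_loop (%i) = (%c0) to (%c24) step (%c4)
    ins (%lhs_ = %lhs : tensor<24x64xi8>)
    outs (%out_ = %out : tensor<24x64xi8>) {
  %out_sub = tensor.extract_slice %out_[%i, 0] [%c4, %c64] [1, 1]
      : tensor<24x64xi8> to tensor<?x?xi8>
  %result_sub = linalg.generic ...
  %result = tensor.insert_slice %result_sub into %out_[%i, 0] [%c4, %c64] [1, 1]
      : tensor<?x?xi8> into tensor<24x64xi8>
  linalg.yield %result : tensor<24x64xi8>
}
```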
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/lib/Dialect/Linalg/Utils/Utils.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
mlir/test/Dialect/Linalg/distribute-tiled-loop.mlir
mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
mlir/test/Dialect/Linalg/invalid.mlir
mlir/test/Dialect/Linalg/roundtrip.mlir
mlir/test/Dialect/Linalg/tile-tensors.mlir
mlir/test/Dialect/Linalg/tiled-loops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index f4c7bd03bc2c0..8d14880c148bd 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -492,7 +492,7 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
AttrSizedOperandSegments,
DeclareOpInterfaceMethods<LoopLikeOpInterface>,
RecursiveSideEffects,
- SingleBlockImplicitTerminator<"linalg::TiledYieldOp">
+ SingleBlockImplicitTerminator<"linalg::YieldOp">
]> {
let summary = "Linalg tiled loop operation";
let description = [{
@@ -509,7 +509,7 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
every tensor argument of TiledLoopOp.
The body region must contain exactly one block that terminates with
- `linalg.tiled_yield`.
+ `linalg.yield` with the operands resulting from `insert_slice` operations.
Example:
@@ -528,7 +528,9 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
%result_sub = linalg.generic ...
- linalg.tiled_yield %result_sub to %out_sub : tensor<?x?xi8>
+ %result = tensor.insert_slice %result_sub into %out[%i, 0][%c4, %c64][1, 1]
+ : tensor<?x?xi8> into tensor<24x64xi8>
+ linalg.yield %result : tensor<24x64xi8>
}
```
@@ -538,7 +540,7 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
every memref argument of TiledLoopOp.
The body region must contain exactly one block that terminates with
- `linalg.tiled_yield` with no operands.
+ `linalg.yield` with no operands.
Example:
@@ -556,7 +558,7 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
: memref<24x64xi8> to memref<?x?xi8>
%result_sub = linalg.generic ...
- linalg.tiled_yield
+ linalg.yield
}
```
}];
@@ -745,28 +747,6 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
let hasFolder = 1;
}
-def Linalg_TiledYieldOp : Linalg_Op<"tiled_yield",
- [NoSideEffect, ReturnLike, Terminator, SameVariadicOperandSize]>,
- Arguments<(ins Variadic<AnyType>:$tiles, Variadic<AnyType>:$outputs)> {
- let summary = "Linalg tiled yield operation";
- let description = [{
- `linalg.tiled_yield` is a special terminator operation for the block inside
- the region of `linalg.tiled_loop` op. It updates the part of the enclosing
- `linalg.tiled_loop` result specifies by the `outputs` operand with the
- values from the `tiles` operand.
-
- Example:
-
- ```mlir
- linalg.tiled_loop ... outs(%out_ = %out : tensor<?x?xf32>) {
- %output = tensor.extract_slice %out_... // or %output = %out_
- %tile = "some_computation"
- linalg.tiled_yield %tile in %output : tensor<?x?xf32>
- ```
- }];
- let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>];
-}
-
def Linalg_IndexOp : Linalg_Op<"index", [NoSideEffect]>,
Arguments<(ins Confined<I64Attr, [IntMinValue<0>]>:$dim)>,
Results<(outs Index:$result)> {
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index cee854593185e..93d330d8d846a 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1497,6 +1497,30 @@ static LogicalResult verify(linalg::YieldOp op) {
return success();
}
+ if (auto tiledLoopOp = dyn_cast<linalg::TiledLoopOp>(parentOp)) {
+ // Check if output args with tensor types match results types.
+ SmallVector<Value, 2> tensorOuts;
+ llvm::copy_if(
+ tiledLoopOp.outputs(), std::back_inserter(tensorOuts),
+ [&](Value out) { return out.getType().isa<RankedTensorType>(); });
+ if (tensorOuts.size() != op.values().size())
+ return op.emitOpError("expected number of tensor output args = ")
+ << tensorOuts.size() << " to match the number of yield operands = "
+ << op.values().size();
+
+ TypeRange tensorTypes(llvm::makeArrayRef(tensorOuts));
+ for (auto &item :
+ llvm::enumerate(llvm::zip(tensorTypes, op.getOperandTypes()))) {
+ Type outType, resultType;
+ unsigned index = item.index();
+ std::tie(outType, resultType) = item.value();
+ if (outType != resultType)
+ return op.emitOpError("expected yield operand ")
+ << index << " with type = " << resultType
+ << " to match output arg type = " << outType;
+ }
+ return success();
+ }
return op.emitOpError("expected parent op with LinalgOp interface");
}
@@ -1868,11 +1892,11 @@ struct TiledLoopResultsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
return failure();
Block *block = tiledLoop.getBody();
- auto yieldOp = cast<linalg::TiledYieldOp>(block->getTerminator());
+ auto yieldOp = cast<linalg::YieldOp>(block->getTerminator());
// Match the pattern and collect output buffers that will replace the output
// tensors and also the ops that will be ignored when cloning the body.
- SmallVector<Value, 2> newOutputOperands, newYieldTileArgs, newYieldOutArgs;
+ SmallVector<Value, 2> newOutputOperands, newYieldArgs;
int resultId = 0;
// Store ids of the corresponding old and new output operands.
SmallVector<int64_t, 2> oldOutputIdToNew(tiledLoop.outputs().size(),
@@ -1893,15 +1917,13 @@ struct TiledLoopResultsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
continue;
}
Value result = tiledLoop.getResult(resultId);
- Value yieldTileArg = yieldOp.tiles()[resultId];
- Value yieldOutArg = yieldOp.outputs()[resultId];
- if (yieldTileArg != outRegionArg || !result.use_empty()) {
+ Value yieldArg = yieldOp.getOperand(resultId);
+ if (yieldArg != outRegionArg || !result.use_empty()) {
oldOutputIdToNew[index] = newOutputOperands.size();
- oldResultIdToNew[resultId] = newYieldTileArgs.size();
+ oldResultIdToNew[resultId] = newYieldArgs.size();
resultReplacement[resultId] = out;
newOutputOperands.push_back(out);
- newYieldTileArgs.push_back(yieldTileArg);
- newYieldOutArgs.push_back(yieldOutArg);
+ newYieldArgs.push_back(yieldArg);
}
++resultId;
}
@@ -1930,12 +1952,9 @@ struct TiledLoopResultsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener());
for (auto &op : tiledLoop.getBody()->without_terminator())
innerBuilder.clone(op, bvm);
- innerBuilder.create<linalg::TiledYieldOp>(
- loc,
- llvm::to_vector<2>(llvm::map_range(
- newYieldTileArgs, [&](Value arg) { return bvm.lookup(arg); })),
- llvm::to_vector<2>(llvm::map_range(
- newYieldOutArgs, [&](Value arg) { return bvm.lookup(arg); })));
+ innerBuilder.create<linalg::YieldOp>(
+ loc, llvm::to_vector<2>(llvm::map_range(
+ newYieldArgs, [&](Value arg) { return bvm.lookup(arg); })));
for (const auto &en : llvm::enumerate(oldResultIdToNew))
if (en.value() != kNoMatch)
@@ -1957,92 +1976,6 @@ LogicalResult TiledLoopOp::fold(ArrayRef<Attribute>,
return foldMemRefCastInTiledLoopOp(*this);
}
-//===----------------------------------------------------------------------===//
-// TiledYieldOp
-//===----------------------------------------------------------------------===//
-
-static void print(OpAsmPrinter &p, TiledYieldOp op) {
- p << op.getOperationName();
-
- if (!op.tiles().empty()) {
- llvm::interleaveComma(llvm::zip(op.tiles(), op.outputs()), p, [&](auto it) {
- p << ' ' << std::get<0>(it) << " in " << std::get<1>(it) << " : "
- << std::get<1>(it).getType();
- });
- }
- p.printOptionalAttrDict(op->getAttrs());
-}
-
-static ParseResult parseTiledYieldOp(OpAsmParser &parser,
- OperationState &result) {
- SmallVector<OpAsmParser::OperandType, 4> tiles, outputs;
- SmallVector<Type, 4> types;
-
- OpAsmParser::OperandType tile;
- while (parser.parseOptionalOperand(tile).hasValue()) {
- Type type;
- OpAsmParser::OperandType output;
- if (parser.parseKeyword("in") || parser.parseOperand(output) ||
- parser.parseColon() || parser.parseType(type))
- return failure();
- tiles.push_back(tile);
- outputs.push_back(output);
- types.push_back(type);
- parser.parseOptionalComma();
- }
- llvm::SMLoc loc = parser.getCurrentLocation();
- if (parser.resolveOperands(tiles, types, loc, result.operands) ||
- parser.resolveOperands(outputs, types, loc, result.operands))
- return failure();
-
- // Parse optional attributes.
- parser.parseOptionalAttrDict(result.attributes);
-
- return success();
-}
-
-static LogicalResult verify(TiledYieldOp op) {
- // Check if output args with tensor types match results types.
- auto loop = op->getParentOfType<TiledLoopOp>();
- SmallVector<Value, 2> loopTensorOuts;
- llvm::copy_if(
- loop.outputs(), std::back_inserter(loopTensorOuts),
- [&](Value out) { return out.getType().isa<RankedTensorType>(); });
- if (loopTensorOuts.size() != op.tiles().size())
- return op.emitOpError("expected number of tensor output args = ")
- << loopTensorOuts.size()
- << " to match the number of yield operands = " << op.tiles().size();
-
- // Check if the `tiles` args types match the `outputs` args types.
- SmallVector<Value, 2> loopTensorOutsBlockArgs;
- llvm::copy_if(
- loop.getRegionOutputArgs(), std::back_inserter(loopTensorOutsBlockArgs),
- [&](Value out) { return out.getType().isa<RankedTensorType>(); });
- for (auto en : llvm::enumerate(
- llvm::zip(op.tiles(), op.outputs(), loopTensorOutsBlockArgs))) {
- size_t index = en.index();
- Type tileType = std::get<0>(en.value()).getType();
- Value yieldOut = std::get<1>(en.value());
- Type yieldOutType = yieldOut.getType();
-
- if (tileType != yieldOutType)
- return op.emitOpError("expected tile operand with type = ")
- << tileType << " to match output type = " << yieldOutType;
-
- // Check if yieldOut is either an output bbArg or a slice of it.
- Value src = yieldOut;
- if (auto extractSlice = llvm::dyn_cast_or_null<tensor::ExtractSliceOp>(
- yieldOut.getDefiningOp()))
- src = extractSlice.source();
-
- Value loopBlockArg = std::get<2>(en.value());
- if (src != loopBlockArg)
- return op.emitOpError("expected output ")
- << index << " to be a subset of the corresponding block argument";
- }
- return success();
-}
-
//===----------------------------------------------------------------------===//
// IndexOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index 93b40253e7128..1ff39fc2e3e63 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -372,7 +372,6 @@ static bool hasKnownBufferizationAliasingBehavior(Operation *op) {
ReturnOp,
TiledLoopOp,
VectorTransferOpInterface,
- linalg::TiledYieldOp,
linalg::YieldOp,
scf::YieldOp>(op)
// clang-format on
@@ -520,7 +519,7 @@ static Optional<OpResult> getAliasingOpResult(OpOperand &opOperand) {
return None;
return TypeSwitch<Operation *, OpResult>(opOperand.getOwner())
// These terminators legitimately have no result.
- .Case<ReturnOp, linalg::TiledYieldOp, linalg::YieldOp, scf::YieldOp>(
+ .Case<ReturnOp, linalg::YieldOp, scf::YieldOp>(
[&](auto op) { return OpResult(); })
// ConstantOp is never inplaceable.
.Case([&](ConstantOp op) { return op->getResult(0); })
@@ -571,11 +570,6 @@ static bool bufferizesToMemoryRead(OpOperand &opOperand) {
if (auto linalgOp = dyn_cast<LinalgOp>(opOperand.getOwner()))
return linalgOp.isInputTensor(&opOperand) ||
linalgOp.isInitTensor(&opOperand);
- // This is questionable. Should we consider TiledYieldOp as an op that
- // bufferizes to "read" for the `tile` args and to "write" for the `output`
- // args?
- if (isa<TiledYieldOp>(opOperand.getOwner()))
- return false;
// All other cases are considered to bufferize to memory reads.
// In particular, terminators are often the last use and need to be considered
// as reads to return the proper value and avoid WAW clobbers.
@@ -589,8 +583,7 @@ static bool
bufferizesToMemoryWrite(OpOperand &opOperand,
InPlaceSpec inPlaceSpec = InPlaceSpec::None) {
// These terminators are not writes.
- if (isa<ReturnOp, linalg::TiledYieldOp, linalg::YieldOp, scf::YieldOp>(
- opOperand.getOwner()))
+ if (isa<ReturnOp, linalg::YieldOp, scf::YieldOp>(opOperand.getOwner()))
return false;
// ExtractSliceOp alone doesn't bufferize to a memory write, one of its uses
// may.
@@ -2117,6 +2110,9 @@ static LogicalResult bufferize(OpBuilder &b, linalg::YieldOp yieldOp,
// No tensors -> success.
if (!llvm::any_of(yieldOp.getOperandTypes(), isaTensor))
return success();
+ // linalg::YieldOp nested under TiledLoop must just canonicalize.
+ if (yieldOp->getParentOfType<TiledLoopOp>())
+ return success();
llvm_unreachable("unexpected yieldOp");
}
@@ -2135,15 +2131,6 @@ static LogicalResult bufferize(OpBuilder &b, tensor::ExtractOp extractOp,
extractOp.replaceAllUsesWith(l);
return success();
}
-
-/// Bufferization for linalg::TiledYieldOp just results in later
-/// canonicalization.
-static LogicalResult bufferize(OpBuilder &b, linalg::TiledYieldOp yieldOp,
- BlockAndValueMapping &bvm,
- BufferizationAliasInfo &aliasInfo) {
- return success();
-}
-
//===----------------------------------------------------------------------===//
// Bufferization analyses.
//===----------------------------------------------------------------------===//
@@ -2345,7 +2332,6 @@ static LogicalResult bufferizeFuncOpInternals(
TiledLoopOp,
VectorTransferOpInterface,
linalg::YieldOp,
- linalg::TiledYieldOp,
scf::YieldOp>([&](auto op) {
LDBG("Begin bufferize:\n" << op << '\n');
return bufferize(b, op, bvm, aliasInfo);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index f17b381d4faa8..5418bc3e38555 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -164,47 +164,6 @@ static Value insertSliceIntoTensor(OpBuilder &b, Location loc,
sliceOp.static_sizes(), sliceOp.static_strides());
}
-template <typename LoopTy>
-static SmallVector<Value, 4>
-collectLoopYieldArgs(OpBuilder &b, LinalgOp clonedOp,
- ArrayRef<Value> tiledOperands,
- SmallVectorImpl<Value> &tensorResults) {
-
- Location loc = clonedOp.getLoc();
- SmallVector<Value, 4> yieldArgs;
- unsigned resultIdx = 0;
- for (OpOperand *opOperand : clonedOp.getOutputTensorOperands()) {
- // TODO: use an interface/adaptor to avoid leaking position in
- // `tiledOperands`.
- Value outputTensor = tiledOperands[opOperand->getOperandNumber()];
- // Insert a insert_slice for each output tensor.
- if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
- yieldArgs.push_back(insertSliceIntoTensor(
- b, loc, sliceOp, clonedOp->getResult(resultIdx), sliceOp.source()));
- } else {
- yieldArgs.push_back(clonedOp->getResult(resultIdx));
- }
- ++resultIdx;
- }
- tensorResults = yieldArgs;
- return yieldArgs;
-}
-
-template <>
-SmallVector<Value, 4>
-collectLoopYieldArgs<TiledLoopOp>(OpBuilder &b, LinalgOp clonedOp,
- ArrayRef<Value> tiledOperands,
- SmallVectorImpl<Value> &tensorResults) {
- auto outputTensorOperands = clonedOp.getOutputTensorOperands();
- size_t numOutputTensors = outputTensorOperands.size();
-
- SmallVector<Value, 4> yieldArgs(clonedOp->getResults());
- auto tiledOutputOperands = tiledOperands.take_back(numOutputTensors);
- yieldArgs.append(tiledOutputOperands.begin(), tiledOutputOperands.end());
-
- return yieldArgs;
-}
-
template <typename LoopTy>
static Optional<TiledLinalgOp>
tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
@@ -265,7 +224,7 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
}
// 2. Create the tiled loops.
- LinalgOp clonedOp = op;
+ LinalgOp res = op;
SmallVector<Value, 4> ivs, tensorResults;
auto tiledLoopBodyBuilder = [&](OpBuilder &b, Location loc,
ValueRange localIvs,
@@ -303,18 +262,30 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
resultTensorTypes.push_back(
tiledOperands[opOperand->getOperandNumber()].getType());
- clonedOp = op.clone(b, loc, resultTensorTypes, tiledOperands);
+ res = op.clone(b, loc, resultTensorTypes, tiledOperands);
- auto yieldArgs =
- collectLoopYieldArgs<LoopTy>(b, clonedOp, tiledOperands, tensorResults);
- return {yieldArgs.begin(), yieldArgs.end()};
+ // Insert a insert_slice for each output tensor.
+ unsigned resultIdx = 0;
+ for (OpOperand *opOperand : op.getOutputTensorOperands()) {
+ // TODO: use an interface/adaptor to avoid leaking position in
+ // `tiledOperands`.
+ Value outputTensor = tiledOperands[opOperand->getOperandNumber()];
+ if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
+ tensorResults.push_back(insertSliceIntoTensor(
+ b, loc, sliceOp, res->getResult(resultIdx), sliceOp.source()));
+ } else {
+ tensorResults.push_back(res->getResult(resultIdx));
+ }
+ ++resultIdx;
+ }
+ return scf::ValueVector(tensorResults.begin(), tensorResults.end());
};
GenerateLoopNest<LoopTy>::doit(b, op.getLoc(), loopRanges, op, iteratorTypes,
tiledLoopBodyBuilder, options.distribution,
options.distributionTypes);
// 3. Transform IndexOp results w.r.t. the tiling.
- transformIndexOps(b, clonedOp, ivs, loopIndexToRangeIndex);
+ transformIndexOps(b, res, ivs, loopIndexToRangeIndex);
// 4. Gather the newly created loops and return them with the new op.
SmallVector<Operation *, 8> loops;
@@ -337,9 +308,8 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
if ((outermostLoop = loop))
break;
- return TiledLinalgOp{clonedOp, loops,
- outermostLoop ? outermostLoop->getResults()
- : tensorResults};
+ return TiledLinalgOp{
+ res, loops, outermostLoop ? outermostLoop->getResults() : tensorResults};
}
template <typename LoopTy>
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index bccc1b3e7f433..1620a047390be 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -311,12 +311,9 @@ void GenerateLoopNest<TiledLoopOp>::doit(
ValueRange ivs, ValueRange inputs,
ValueRange outputs) {
SmallVector<Value> outputTensors = linalgOp.getOutputTensorOperands();
- scf::ValueVector yieldArgs =
+ scf::ValueVector results =
bodyBuilderFn(nestedBuilder, nestedLoc, ivs, outputTensors);
- auto yieldArgsRef = llvm::makeArrayRef(yieldArgs);
- nestedBuilder.create<linalg::TiledYieldOp>(
- nestedLoc, yieldArgsRef.take_front(outputTensors.size()),
- yieldArgsRef.drop_front(outputTensors.size()));
+ nestedBuilder.create<linalg::YieldOp>(nestedLoc, results);
};
SmallVector<Value> inputOperands = linalgOp.getInputOperands();
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index ff10319bc8504..c453255d39485 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -36,7 +36,7 @@ func @memref_cast_into_tiled_loop(%arg0: memref<192xf32>) {
%16 = memref.subview %out[%arg3] [%14] [1]
: memref<192xf32, #map> to memref<?xf32, #map>
linalg.fill(%cst, %16) : f32, memref<?xf32, #map>
- linalg.tiled_yield
+ linalg.yield
}
return
}
@@ -706,9 +706,8 @@ func @fold_tiled_loop_results(%A: memref<48xf32>, %B: tensor<48xf32>,
%CT_ = %C_tensor: tensor<48xf32>,
%C_ = %C: memref<48xf32>) {
%result = call @foo(%A_, %B_, %C_)
- : (memref<48xf32>, tensor<48xf32>, memref<48xf32>) -> (tensor<48xf32>)
- linalg.tiled_yield %result in %B_ : tensor<48xf32>,
- %CT_ in %CT_ : tensor<48xf32>
+ : (memref<48xf32>, tensor<48xf32>, memref<48xf32>)-> (tensor<48xf32>)
+ linalg.yield %result, %CT_ : tensor<48xf32>, tensor<48xf32>
}
return %useful : tensor<48xf32>
}
@@ -727,7 +726,7 @@ func @fold_tiled_loop_results(%A: memref<48xf32>, %B: tensor<48xf32>,
// CHECK-SAME: ins (%[[A_:.*]] = %[[A]]: [[BUF_TY]])
// CHECK-SAME: outs (%[[B_:.*]] = %[[B]]: [[TY]], %[[C_:.*]] = %[[C]]: [[BUF_TY]]) {
// CHECK-NEXT: %[[RES:.*]] = call @foo(%[[A_]], %[[B_]], %[[C_]])
-// CHECK-NEXT: linalg.tiled_yield %[[RES]] in %[[B_]]
+// CHECK-NEXT: linalg.yield %[[RES]] :
// CHECK: return %[[RESULT]]
@@ -744,7 +743,7 @@ func @fold_tiled_loop_inputs(%A: memref<192xf32>, %A_tensor: tensor<192xf32>,
ins (%A_ = %A: memref<192xf32>, %AT_ = %A_tensor: tensor<192xf32>)
outs (%BT_ = %B_tensor: tensor<192xf32>) {
%0 = call @foo(%A_, %BT_) : (memref<192xf32>, tensor<192xf32>) -> tensor<192xf32>
- linalg.tiled_yield %0 in %BT_ : tensor<192xf32>
+ linalg.yield %0 : tensor<192xf32>
}
return %result : tensor<192xf32>
}
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
index 068a8d6c24001..2c2a14ead0f32 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
@@ -507,25 +507,25 @@ func @scf_for_deps(%A : tensor<?xf32> {linalg.inplaceable = true},
// of %r3 is read.
// CHECK: linalg.tiled_loop
// CHECK-NEXT: call
- // CHECK-NEXT: linalg.tiled_yield
+ // CHECK-NEXT: linalg.yield
// CHECK-NEXT: {__inplace_results_attr__ = ["false"]}
%r2 = linalg.tiled_loop (%i) = (%lb) to (%ub) step (%step)
ins()
outs(%t = %B: tensor<?xf32>) {
call @some_use(%t) : (tensor<?xf32>) -> ()
- linalg.tiled_yield %t in %t : tensor<?xf32>
+ linalg.yield %t : tensor<?xf32>
}
// %r3 bufferizes inplace fine.
// CHECK: linalg.tiled_loop
// CHECK-NEXT: call
- // CHECK-NEXT: linalg.tiled_yield
+ // CHECK-NEXT: linalg.yield
// CHECK-NEXT: {__inplace_results_attr__ = ["true"]}
%r3 = linalg.tiled_loop (%i) = (%lb) to (%ub) step (%step)
ins()
outs(%t = %B: tensor<?xf32>) {
call @some_use(%t) : (tensor<?xf32>) -> ()
- linalg.tiled_yield %t in %t : tensor<?xf32>
+ linalg.yield %t : tensor<?xf32>
}
return %r1, %r3: tensor<?xf32>, tensor<?xf32>
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index 6a12a488d8915..bab271108da08 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -550,11 +550,10 @@ func @tiled_dot(%A: tensor<?xf32>, %B: tensor<?xf32>, %c: tensor<f32> {linalg.in
// CHECK: linalg.tiled_loop {{.*}} to (%[[M]]) {{.*}} %[[A]]{{.*}}%[[B]]{{.*}}outs{{.*}}%[[c]]
%1 = linalg.tiled_loop (%arg3) = (%c0) to (%0) step (%c3)
- ins (%arg4 = %A: tensor<?xf32>,
- %use = %effecting : memref<?xf32>,
- %arg5 = %B: tensor<?xf32>)
+ ins (%arg4 = %A: tensor<?xf32>, %use = %effecting : memref<?xf32>, %arg5 = %B: tensor<?xf32>)
outs (%arg6 = %c: tensor<f32>)
- iterators["reduction"] {
+ iterators["reduction"]
+ {
// CHECK-NOT: alloc
%2 = tensor.dim %arg4, %c0 : tensor<?xf32>
@@ -574,8 +573,8 @@ func @tiled_dot(%A: tensor<?xf32>, %B: tensor<?xf32>, %c: tensor<f32> {linalg.in
// CHECK: call @some_use(%{{.*}}) : (memref<?xf32>) -> ()
call @some_use(%use) : (memref<?xf32>) -> ()
- linalg.tiled_yield %8 in %arg6 : tensor<f32>
- // CHECK: linalg.tiled_yield
+ linalg.yield %8 : tensor<f32>
+ // CHECK: linalg.yield
// CHECK-NOT: tensor
}
diff --git a/mlir/test/Dialect/Linalg/distribute-tiled-loop.mlir b/mlir/test/Dialect/Linalg/distribute-tiled-loop.mlir
index b31923632177a..564db5ab4fbe7 100644
--- a/mlir/test/Dialect/Linalg/distribute-tiled-loop.mlir
+++ b/mlir/test/Dialect/Linalg/distribute-tiled-loop.mlir
@@ -14,7 +14,7 @@ func @distribute_for_gpu(%A: tensor<64x64xf32>,
distribution ["block_x", "block_y"] {
%0 = call @foo(%A_, %B_)
: (tensor<64x64xf32>, tensor<64x64xf32>) -> tensor<64x64xf32>
- linalg.tiled_yield %0 in %B_ : tensor<64x64xf32>
+ linalg.yield %0 : tensor<64x64xf32>
}
return %0 : tensor<64x64xf32>
}
diff --git a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
index 651076142480e..de7c0b7c4820f 100644
--- a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
@@ -1,5 +1,5 @@
// RUN: mlir-opt %s -test-linalg-tensor-fusion-transform-patterns -resolve-shaped-type-result-dims -canonicalize -cse --split-input-file | FileCheck %s
-// RUN: mlir-opt %s -test-linalg-tiled-loop-fusion-transform-patterns -resolve-shaped-type-result-dims -canonicalize -cse --split-input-file --mlir-disable-threading | FileCheck %s --check-prefix=TLOOP
+// RUN: mlir-opt %s -test-linalg-tiled-loop-fusion-transform-patterns -resolve-shaped-type-result-dims -canonicalize -cse --split-input-file | FileCheck %s --check-prefix=TLOOP
module {
func @matmul_fusion(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
@@ -124,7 +124,7 @@ module {
// TLOOP: %[[DIM_B_1:.*]] = tensor.dim %[[B_]], %[[C1]] : [[TY]]
// TLOOP: %[[DIM_C_1:.*]] = tensor.dim %[[C_]], %[[C1]] : [[TY]]
-// TLOOP: %[[ABC_SUB:.*]] = linalg.tiled_loop (%[[IV1:.*]], %[[IV2:.*]]) =
+// TLOOP: %[[ABC_SUB_:.*]] = linalg.tiled_loop (%[[IV1:.*]], %[[IV2:.*]]) =
// TLOOP-SAME: (%[[C0]], %[[C0]]) to (%[[DIM_C_1]], %[[DIM_B_1]])
// TLOOP-SAME: step (%[[C64]], %[[C16]])
// TLOOP-SAME: ins (%[[AB_SUB_:.*]] = %[[AB_SUB]]: [[TY]],
@@ -134,15 +134,18 @@ module {
// TLOOP: %[[AB_SUB_SUB:.*]] = tensor.extract_slice %[[AB_SUB_]][0, %[[IV2]]]
// TLOOP: %[[C__SUB:.*]] = tensor.extract_slice %[[C__]][%[[IV2]], %[[IV1]]]
-// TLOOP: %[[ABC_INIT_SUB_SUB:.*]] = tensor.extract_slice %[[ABC_INIT_SUB_]][0, %[[IV1]]]
+// TLOOP: %[[ABS_INIT_SUB_SUB:.*]] = tensor.extract_slice %[[ABC_INIT_SUB_]][0, %[[IV1]]]
// TLOOP: %[[ABC_SUB_SUB:.*]] = linalg.matmul
// TLOOP-SAME: ins(%[[AB_SUB_SUB]], %[[C__SUB]] : [[TY]], [[TY]])
-// TLOOP-SAME: outs(%[[ABC_INIT_SUB_SUB]] : [[TY]]) -> [[TY]]
+// TLOOP-SAME: outs(%[[ABS_INIT_SUB_SUB]] : [[TY]]) -> [[TY]]
-// TLOOP: linalg.tiled_yield %[[ABC_SUB_SUB]] in %[[ABC_INIT_SUB_SUB]] : [[TY]]
+// TLOOP: %[[RES0:.*]] = tensor.insert_slice %[[ABC_SUB_SUB]]
+// TLOOP-SAME: into %[[ABC_INIT_SUB_]][0, %[[IV1]]]
+// TLOOP: linalg.yield %[[RES0]] : [[TY]]
// TLOOP: }
-// TLOOP: linalg.tiled_yield %[[ABC_SUB]] in %[[ABC_INIT_SUB]] : [[TY]]
+// TLOOP: %[[RES1:.*]] = tensor.insert_slice %[[ABC_SUB_]] into %[[ABC_INIT_]][%[[IV0]], 0]
+// TLOOP: linalg.yield %[[RES1]] : [[TY]]
// TLOOP: }
// TLOOP: return %[[ABC]] : [[TY]]
@@ -235,7 +238,10 @@ module {
// TLOOP: %[[DOUBLE_AB:.*]] = linalg.generic
// TLOOP-SAME: ins(%[[AB_SUB]] : [[TY]]) outs(%[[INIT_SUB]] : [[TY]])
-// TLOOP: linalg.tiled_yield %[[DOUBLE_AB]] in %[[INIT_SUB]] : [[TY]]
+// TLOOP: %[[RESULT_SUB:.*]] = tensor.insert_slice
+// TLOOP-SAME: %[[DOUBLE_AB:.*]] into %[[INIT_]][%[[IV0]], %[[IV1]]]
+
+// TLOOP: linalg.yield %[[RESULT_SUB]] : [[TY]]
// TLOOP: }
// TLOOP: return %[[RESULT]] : [[TY]]
@@ -298,8 +304,7 @@ module {
// TLOOP: %[[A_SUB:.*]] = tensor.extract_slice %[[A_]][%[[I]], 0]
// TLOOP: %[[B_SUB:.*]] = tensor.extract_slice %[[B_]][0, %[[J]]]
// TLOOP: %[[OUT_SUB:.*]] = tensor.extract_slice %[[OUT_]][%[[I]], %[[J]]]
-// TLOOP: %[[OUT_SUB_2:.*]] = tensor.extract_slice %[[OUT_]][%[[I]], %[[J]]]
-// TLOOP: %[[INIT_SUB:.*]] = linalg.fill(%[[C0_F32_]], %[[OUT_SUB_2]])
+// TLOOP: %[[INIT_SUB:.*]] = linalg.fill(%[[C0_F32_]], %[[OUT_SUB]])
// TLOOP: %[[AB_SUB:.*]] = linalg.tiled_loop (%[[K:.*]]) = (%[[C0]])
// TLOOP-SAME: to (%[[DIM_A__1]]) step (%[[C16]])
@@ -314,9 +319,11 @@ module {
// TLOOP: %[[AB_SUB_SUB:.*]] = linalg.matmul
// TLOOP-SAME: ins(%[[A_SUB_SUB]], %[[B_SUB_SUB]] : [[TY]], [[TY]])
// TLOOP-SAME: outs(%[[INIT_SUB_]] : [[TY]]) -> [[TY]]
-// TLOOP: linalg.tiled_yield %[[AB_SUB_SUB]] in %[[INIT_SUB_]] : [[TY]]
+// TLOOP: linalg.yield %[[AB_SUB_SUB]] : [[TY]]
// TLOOP: }
-// TLOOP: linalg.tiled_yield %[[AB_SUB]] in %[[OUT_SUB]] : [[TY]]
+// TLOOP: %[[SUB_RESULT:.*]] = tensor.insert_slice %[[AB_SUB]]
+// TLOOP-SAME: into %[[OUT_]][%[[I]], %[[J]]]
+// TLOOP: linalg.yield %[[SUB_RESULT]] : [[TY]]
// TLOOP: }
// TLOOP: return %[[AB]] : [[TY]]
@@ -368,10 +375,9 @@ module {
// TLOOP: %[[A_SUB:.*]] = tensor.extract_slice %[[A_]][%[[I]], 0]
// TLOOP: %[[B_SUB:.*]] = tensor.extract_slice %[[B_]][0, %[[J]]]
// TLOOP: %[[OUT_SUB:.*]] = tensor.extract_slice %[[OUT_]][%[[I]], %[[J]]]
-// TLOOP: %[[OUT_SUB_2:.*]] = tensor.extract_slice %[[OUT_]][%[[I]], %[[J]]]
// TLOOP: %[[INIT_SUB:.*]] = linalg.generic
// TLOOP-SAME: ins(%[[C0_F32_]]
-// TLOOP-SAME: outs(%[[OUT_SUB_2]]
+// TLOOP-SAME: outs(%[[OUT_SUB]]
// TLOOP: %[[AB_SUB:.*]] = linalg.tiled_loop (%[[K:.*]]) = (%[[C0]])
// TLOOP-SAME: to (%[[DIM_A__1]]) step (%[[C16]])
@@ -386,9 +392,11 @@ module {
// TLOOP: %[[AB_SUB_SUB:.*]] = linalg.matmul
// TLOOP-SAME: ins(%[[A_SUB_SUB]], %[[B_SUB_SUB]] : [[TY]], [[TY]])
// TLOOP-SAME: outs(%[[INIT_SUB_]] : [[TY]]) -> [[TY]]
-// TLOOP: linalg.tiled_yield %[[AB_SUB_SUB]] in %[[INIT_SUB_]] : [[TY]]
+// TLOOP: linalg.yield %[[AB_SUB_SUB]] : [[TY]]
// TLOOP: }
-// TLOOP: linalg.tiled_yield %[[AB_SUB]] in %[[OUT_SUB]] : [[TY]]
+// TLOOP: %[[SUB_RESULT:.*]] = tensor.insert_slice %[[AB_SUB]]
+// TLOOP-SAME: into %[[OUT_]][%[[I]], %[[J]]]
+// TLOOP: linalg.yield %[[SUB_RESULT]] : [[TY]]
// TLOOP: }
// TLOOP: return %[[AB]] : [[TY]]
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 87bf98346bc86..569b9a1b387db 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -582,6 +582,10 @@ func @invalid_static_2d_conv(%input : memref<1x3x4x2xf32>, %filter: memref<3x2x2
// -----
+#map0 = affine_map<(d0) -> (24, -d0 + 192)>
+#map1 = affine_map<(d0, d1)[s0] -> (d0 * 192 + s0 + d1)>
+#map2 = affine_map<(d0) -> (16, -d0 + 192)>
+
func private @foo(%A: memref<192x192xf32>, %B: memref<192x192xf32>,
%C: memref<192x192xf32>) -> ()
@@ -599,34 +603,11 @@ func @tiled_loop_incorrent_num_yield_operands(%A: memref<192x192xf32>,
call @foo(%A_, %B_, %C_)
: (memref<192x192xf32>, memref<192x192xf32>, memref<192x192xf32>)-> ()
// expected-error @+1 {{expected number of tensor output args = 1 to match the number of yield operands = 0}}
- linalg.tiled_yield
+ linalg.yield
}
return
}
-// -----
-
-func @tiled_loop_incorrect_destination_for_tile(%A: tensor<4xf32>,
- %B: tensor<4xf32>) {
- %c2 = constant 2 : index
- %c4 = constant 2 : index
- %c0 = constant 0 : index
- %0 = linalg.tiled_loop (%i) = (%c0) to (%c4) step (%c2)
- ins (%A_ = %A: tensor<4xf32>)
- outs (%B_ = %B: tensor<4xf32>) {
- %A_sub = tensor.extract_slice %A_[%i][2][1]
- : tensor<4xf32> to tensor<2xf32>
- %B_sub = tensor.extract_slice %B_[%i][2][1]
- : tensor<4xf32> to tensor<2xf32>
- %c0_f32 = constant 0.0 : f32
- %tile = linalg.fill(%c0_f32, %A_sub) : f32, tensor<2xf32> -> tensor<2xf32>
- // expected-error @+1 {{expected output 0 to be a subset of the corresponding block argument}}
- linalg.tiled_yield %tile in %A_sub : tensor<2xf32>
- }
- return
-}
-
-
// -----
#map0 = affine_map<(d0) -> (24, -d0 + 192)>
@@ -649,8 +630,8 @@ func @tiled_loop_incorrent_yield_operand_type(%A: memref<192x192xf32>,
%C_ = %C: memref<192x192xf32>) {
%1 = call @foo(%A_, %B_, %C_)
: (memref<192x192xf32>, memref<192x192xf32>, memref<192x192xf32>)-> tensor<f32>
- // expected-error @+1 {{expected tile operand with type = 'tensor<f32>' to match output type = 'tensor<192x192xf32>}}
- "linalg.tiled_yield" (%1, %CT_) : (tensor<f32>, tensor<192x192xf32>) -> ()
+ // expected-error @+1 {{expected yield operand 0 with type = 'tensor<f32>' to match output arg type = 'tensor<192x192xf32>}}
+ linalg.yield %1 : tensor<f32>
}
return
}
@@ -658,7 +639,7 @@ func @tiled_loop_incorrent_yield_operand_type(%A: memref<192x192xf32>,
// -----
func private @foo(%A: memref<192x192xf32>, %B: memref<192x192xf32>,
- %C: memref<192x192xf32>) -> (tensor<192x192xf32>)
+ %C: memref<192x192xf32>) -> ()
func @tiled_loop_incorrent_iterator_types_count(%A: memref<192x192xf32>,
%B: memref<192x192xf32>, %C: memref<192x192xf32>,
@@ -671,10 +652,9 @@ func @tiled_loop_incorrent_iterator_types_count(%A: memref<192x192xf32>,
^bb0(%arg4: index, %arg5: index, %A_: memref<192x192xf32>,
%B_: memref<192x192xf32>, %CT_: tensor<192x192xf32>,
%C_: memref<192x192xf32>):
- %tile = call @foo(%A_, %B_, %C_)
- : (memref<192x192xf32>, memref<192x192xf32>, memref<192x192xf32>)
- -> (tensor<192x192xf32>)
- linalg.tiled_yield %tile in %CT_ : tensor<192x192xf32>
+ call @foo(%A_, %B_, %C_)
+ : (memref<192x192xf32>, memref<192x192xf32>, memref<192x192xf32>)-> ()
+ linalg.yield %CT_ : tensor<192x192xf32>
}) {
iterator_types = ["parallel"],
operand_segment_sizes = dense<2> : vector<5xi32>
@@ -696,7 +676,7 @@ func @tiled_loop_incorrent_block_arg_type(%A: memref<192xf32>) {
"linalg.tiled_loop"(%c0, %c192, %c24, %A) ( {
^bb0(%arg4: index, %A_: memref<100xf32>):
call @foo(%A_) : (memref<100xf32>)-> ()
- linalg.tiled_yield
+ linalg.yield
}) {
iterator_types = ["parallel"],
operand_segment_sizes = dense<[1, 1, 1, 0, 1]> : vector<5xi32>
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 24f5aa5652a06..e0d7ab2dfb24f 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -648,7 +648,9 @@ func @tiled_loop(%lhs: tensor<24x64xi8>, %rhs: tensor<24x64xi8>,
linalg.yield %s : i8
} -> tensor<?x?xi8>
- linalg.tiled_yield %sum in %out_sub : tensor<?x?xi8>
+ %sum_sub = tensor.insert_slice %sum into %out_[%i, 0][%c4, %c64][1, 1]
+ : tensor<?x?xi8> into tensor<24x64xi8>
+ linalg.yield %sum_sub : tensor<24x64xi8>
}
return %prod : tensor<24x64xi8>
}
@@ -709,7 +711,9 @@ func @tiled_loop_reduction(%input_3d: tensor<16x24x32xf32>,
linalg.yield %1 : f32
} -> tensor<4xf32>
- linalg.tiled_yield %acc in %sub_out : tensor<4xf32>
+ %sum_sub = tensor.insert_slice %acc into %o_[%j][%c4][1]
+ : tensor<4xf32> into tensor<24xf32>
+ linalg.yield %sum_sub : tensor<24xf32>
}
return %result : tensor<24xf32>
}
@@ -769,7 +773,7 @@ func @tiled_loop_on_buffers(%input_3d: memref<16x24x32xf32>,
%1 = addf %0, %i1d : f32
linalg.yield %1 : f32
}
- linalg.tiled_yield
+ linalg.yield
}
return
}
diff --git a/mlir/test/Dialect/Linalg/tile-tensors.mlir b/mlir/test/Dialect/Linalg/tile-tensors.mlir
index 07eab8f9cf083..2cd04668a475c 100644
--- a/mlir/test/Dialect/Linalg/tile-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-tensors.mlir
@@ -58,7 +58,8 @@ func @matmul_tensors(
// TLOOP: %[[PROD:.*]] = linalg.matmul ins(%[[SUB_ARG_0]], %[[SUB_ARG_1]]
// TLOOP-SE: outs(%[[SUB_ARG_2]] : [[TY]]) -> [[TY]]
-// TLOOP: linalg.tiled_yield %[[PROD]] in %[[SUB_ARG_2]] : [[TY]]
+// TLOOP: %[[O:.*]] = tensor.insert_slice %[[PROD]] into %[[A2]][%[[I]], %[[J]]]
+// TLOOP: linalg.yield %[[O]] : [[TY]]
// -----
diff --git a/mlir/test/Dialect/Linalg/tiled-loops.mlir b/mlir/test/Dialect/Linalg/tiled-loops.mlir
index fd1809c3db04d..5798883ba2550 100644
--- a/mlir/test/Dialect/Linalg/tiled-loops.mlir
+++ b/mlir/test/Dialect/Linalg/tiled-loops.mlir
@@ -29,7 +29,7 @@ func @tiled_loop(%A: memref<192x192xf32>,
linalg.matmul ins(%1, %3 : memref<?x192xf32, #map1>,
memref<192x?xf32, #map1>)
outs(%4 : memref<?x?xf32, #map1>)
- linalg.tiled_yield
+ linalg.yield
}
return
}
@@ -64,7 +64,7 @@ func @tiled_loop_reduction(%A: memref<192x192xf32>,
outs (%C_ = %C: memref<f32>)
iterators["reduction", "reduction"] {
linalg.fill(%cst, %A_) : f32, memref<192x192xf32>
- linalg.tiled_yield
+ linalg.yield
}
return
}