[Mlir-commits] [mlir] 5291a7a - [mlir] Add block arguments for input/output operands of `linalg.tiled_loop`.
Alexander Belyaev
llvmlistbot at llvm.org
Fri Apr 23 11:55:44 PDT 2021
Author: Alexander Belyaev
Date: 2021-04-23T20:55:20+02:00
New Revision: 5291a7a3c70c578fe3797b1116a8f74990f3750a
URL: https://github.com/llvm/llvm-project/commit/5291a7a3c70c578fe3797b1116a8f74990f3750a
DIFF: https://github.com/llvm/llvm-project/commit/5291a7a3c70c578fe3797b1116a8f74990f3750a.diff
LOG: [mlir] Add block arguments for input/output operands of `linalg.tiled_loop`.
Differential Revision: https://reviews.llvm.org/D101186
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
mlir/include/mlir/IR/OpImplementation.h
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Parser/Parser.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
mlir/test/Dialect/Linalg/invalid.mlir
mlir/test/Dialect/Linalg/roundtrip.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index ca4488d2f20a5..51c24cda736f1 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -579,8 +579,9 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
OpBuilder<(ins "ValueRange":$lowerBounds, "ValueRange":$upperBounds,
"ValueRange":$steps, "ValueRange":$inputs, "ValueRange":$outputs,
"ArrayAttr":$iteratorTypes,
- CArg<"function_ref<void (OpBuilder &, Location, ValueRange)>",
- "nullptr">:$bodyBuilderFn)>,
+ CArg<"function_ref<void (OpBuilder &, Location, /*ivs=*/ValueRange,"
+ "/*inputs=*/ValueRange, /*outputs=*/ValueRange)>",
+ "nullptr">:$bodyBuilderFn)>,
];
let extraClassDeclaration = [{
@@ -588,7 +589,13 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
unsigned getNumControlOperands() { return 3 * getNumLoops(); }
ValueRange getInductionVars() {
- return getBody()->getArguments();
+ return getBody()->getArguments().take_front(getNumLoops());
+ }
+ /// Block arguments that correspond to the `inputs` operands; they directly
+ /// follow the induction variables.
+ ValueRange getRegionInputArgs() {
+ return getBody()->getArguments().slice(getNumLoops(), inputs().size());
+ }
+ /// Block arguments that correspond to the `outputs` operands; they are the
+ /// trailing block arguments.
+ ValueRange getRegionOutputArgs() {
+ return getBody()->getArguments().take_back(outputs().size());
}
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 5250899a9b447..f41a406caa18d 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -834,6 +834,22 @@ class OpAsmParser {
parseOptionalAssignmentList(SmallVectorImpl<OperandType> &lhs,
SmallVectorImpl<OperandType> &rhs) = 0;
+ /// Parse a list of assignments of the form
+ /// (%x1 = %y1 : type1, %x2 = %y2 : type2, ...)
+ /// and emit an "expected '('" error if the list is absent.
+ ParseResult parseAssignmentListWithTypes(SmallVectorImpl<OperandType> &lhs,
+ SmallVectorImpl<OperandType> &rhs,
+ SmallVectorImpl<Type> &types) {
+ OptionalParseResult result =
+ parseOptionalAssignmentListWithTypes(lhs, rhs, types);
+ if (!result.hasValue())
+ return emitError(getCurrentLocation(), "expected '('");
+ return result.getValue();
+ }
+
+ /// Parse an optional list of assignments of the form
+ /// (%x1 = %y1 : type1, %x2 = %y2 : type2, ...).
+ /// Returns no value when the list is absent, otherwise the result of parsing
+ /// it. The parsed left-hand sides, right-hand sides and types are appended to
+ /// `lhs`, `rhs` and `types` respectively.
+ virtual OptionalParseResult
+ parseOptionalAssignmentListWithTypes(SmallVectorImpl<OperandType> &lhs,
+ SmallVectorImpl<OperandType> &rhs,
+ SmallVectorImpl<Type> &types) = 0;
/// Parse a keyword followed by a type.
ParseResult parseKeywordType(const char *keyword, Type &result) {
return failure(parseKeyword(keyword) || parseType(result));
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 45e5e07a2961a..5a6d498a65b49 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1805,11 +1805,13 @@ static LogicalResult verify(linalg::YieldOp op) {
// TiledLoopOp
//===----------------------------------------------------------------------===//
-void TiledLoopOp::build(
- OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
- ValueRange upperBounds, ValueRange steps, ValueRange inputs,
- ValueRange outputs, ArrayAttr iteratorTypes,
- function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
+void TiledLoopOp::build(OpBuilder &builder, OperationState &result,
+ ValueRange lowerBounds, ValueRange upperBounds,
+ ValueRange steps, ValueRange inputs, ValueRange outputs,
+ ArrayAttr iteratorTypes,
+ function_ref<void(OpBuilder &, Location, ValueRange,
+ ValueRange, ValueRange)>
+ bodyBuilderFn) {
result.addOperands(lowerBounds);
result.addOperands(upperBounds);
result.addOperands(steps);
@@ -1834,25 +1836,46 @@ void TiledLoopOp::build(
OpBuilder::InsertionGuard guard(builder);
unsigned numIVs = steps.size();
SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
+ for (Type type : TypeRange(inputs))
+ argTypes.push_back(type);
+ for (Type type : TypeRange(outputs))
+ argTypes.push_back(type);
Region *bodyRegion = result.addRegion();
Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes);
if (bodyBuilderFn) {
builder.setInsertionPointToStart(bodyBlock);
- bodyBuilderFn(builder, result.location, bodyBlock->getArguments());
+ bodyBuilderFn(builder, result.location,
+ bodyBlock->getArguments().take_front(numIVs),
+ bodyBlock->getArguments().slice(numIVs, inputs.size()),
+ bodyBlock->getArguments().take_back(outputs.size()));
TiledLoopOp::ensureTerminator(*bodyRegion, builder, result.location);
}
}
static void print(OpAsmPrinter &p, TiledLoopOp op) {
- p << op.getOperationName() << " (" << op.getBody()->getArguments() << ") = ("
+ p << op.getOperationName() << " (" << op.getInductionVars() << ") = ("
<< op.lowerBound() << ") to (" << op.upperBound() << ") step (" << op.step()
<< ")";
- if (!op.inputs().empty())
- p << " ins (" << op.inputs() << ": " << TypeRange(op.inputs()) << ")";
- if (!op.outputs().empty())
- p << " outs (" << op.outputs() << ":" << TypeRange(op.outputs()) << ")";
+ if (!op.inputs().empty()) {
+ p << " ins (";
+ llvm::interleaveComma(llvm::zip(op.getRegionInputArgs(), op.inputs()), p,
+ [&](auto it) {
+ p << std::get<0>(it) << " = " << std::get<1>(it)
+ << ": " << std::get<1>(it).getType();
+ });
+ p << ")";
+ }
+ if (!op.outputs().empty()) {
+ p << " outs (";
+ llvm::interleaveComma(llvm::zip(op.getRegionOutputArgs(), op.outputs()), p,
+ [&](auto it) {
+ p << std::get<0>(it) << " = " << std::get<1>(it)
+ << ": " << std::get<1>(it).getType();
+ });
+ p << ")";
+ }
if (llvm::any_of(op.iterator_types(), [](Attribute attr) {
return attr.cast<StringAttr>().getValue() !=
@@ -1900,13 +1923,13 @@ static ParseResult parseTiledLoopOp(OpAsmParser &parser,
return failure();
// Parse input tensors.
- SmallVector<OpAsmParser::OperandType, 4> inputs;
+ SmallVector<OpAsmParser::OperandType, 4> inputs, inputRegionArgs;
+ SmallVector<Type, 4> inputTypes;
if (succeeded(parser.parseOptionalKeyword("ins"))) {
- SmallVector<Type, 4> inputTypes;
llvm::SMLoc inputsOperandsLoc = parser.getCurrentLocation();
- if (parser.parseLParen() || parser.parseOperandList(inputs) ||
- parser.parseColonTypeList(inputTypes) || parser.parseRParen())
+ if (parser.parseAssignmentListWithTypes(inputRegionArgs, inputs,
+ inputTypes))
return failure();
if (parser.resolveOperands(inputs, inputTypes, inputsOperandsLoc,
@@ -1915,13 +1938,13 @@ static ParseResult parseTiledLoopOp(OpAsmParser &parser,
}
// Parse output tensors.
- SmallVector<OpAsmParser::OperandType, 4> outputs;
+ SmallVector<OpAsmParser::OperandType, 4> outputs, outputRegionArgs;
+ SmallVector<Type, 4> outputTypes;
if (succeeded(parser.parseOptionalKeyword("outs"))) {
- SmallVector<Type, 4> outputTypes;
llvm::SMLoc outputsOperandsLoc = parser.getCurrentLocation();
- if (parser.parseLParen() || parser.parseOperandList(outputs) ||
- parser.parseColonTypeList(outputTypes) || parser.parseRParen())
+ if (parser.parseAssignmentListWithTypes(outputRegionArgs, outputs,
+ outputTypes))
return failure();
if (parser.resolveOperands(outputs, outputTypes, outputsOperandsLoc,
@@ -1963,8 +1986,16 @@ static ParseResult parseTiledLoopOp(OpAsmParser &parser,
// Parse the body.
Region *body = result.addRegion();
- SmallVector<Type, 4> types(ivs.size(), builder.getIndexType());
- if (parser.parseRegion(*body, ivs, types))
+
+ SmallVector<Type, 4> regionTypes(ivs.size(), builder.getIndexType());
+ regionTypes.append(inputTypes);
+ regionTypes.append(outputTypes);
+
+ SmallVector<OpAsmParser::OperandType, 4> regionArgs(ivs);
+ regionArgs.append(inputRegionArgs);
+ regionArgs.append(outputRegionArgs);
+
+ if (parser.parseRegion(*body, regionArgs, regionTypes))
return failure();
// Parse optional attributes.
@@ -1991,6 +2022,45 @@ static LogicalResult verify(TiledLoopOp op) {
return op.emitOpError("expected iterator types array attribute size = ")
<< op.iterator_types().size()
<< " to match the number of loops = " << op.getNumLoops();
+
+ // Check that the body has exactly one argument per induction variable,
+ // input and output. getRegionInputArgs()/getRegionOutputArgs() slice the
+ // block arguments by these counts, so a mismatch has to be reported before
+ // they are used below.
+ unsigned expectedNumArgs =
+ op.getNumLoops() + op.inputs().size() + op.outputs().size();
+ if (op.getBody()->getNumArguments() != expectedNumArgs)
+ return op.emitOpError("expected number of region args = ")
+ << op.getBody()->getNumArguments()
+ << " to match the number of loops + inputs + outputs = "
+ << expectedNumArgs;
+
+ // Check if types of input arguments match region args types.
+ for (auto &item :
+ llvm::enumerate(llvm::zip(op.inputs(), op.getRegionInputArgs()))) {
+ Value input, inputRegionArg;
+ unsigned index = item.index();
+ std::tie(input, inputRegionArg) = item.value();
+ if (input.getType() != inputRegionArg.getType())
+ return op.emitOpError("expected input arg ")
+ << index << " with type = " << input.getType()
+ << " to match region arg " << index + op.getNumLoops()
+ << " type = " << inputRegionArg.getType();
+ }
+
+ // Check if types of output arguments match region args types.
+ for (auto &item :
+ llvm::enumerate(llvm::zip(op.outputs(), op.getRegionOutputArgs()))) {
+ Value output, outputRegionArg;
+ unsigned index = item.index();
+ std::tie(output, outputRegionArg) = item.value();
+ if (output.getType() != outputRegionArg.getType())
+ return op.emitOpError("expected output arg ")
+ << index << " with type = " << output.getType()
+ << " to match region arg "
+ << index + op.getNumLoops() + op.inputs().size()
+ << " type = " << outputRegionArg.getType();
+ }
return success();
}
@@ -2002,14 +2060,15 @@ namespace {
//
// Example:
//
-// %0 = linalg.tiled_loop ... outs (%out, %out_buf:tensor<...>, memref<...>) {
+// %0 = linalg.tiled_loop ... outs (%o_ = %out: tensor<...>,
+// %obuf_ = %out_buf: memref<...>) {
// ...
-// linalg.yield %out : tensor ...
+// linalg.yield %o_ : tensor ...
// }
//
// Becomes
//
-// linalg.tiled_loop ... outs (%out_buf:memref<...>) {
+// linalg.tiled_loop ... outs (%obuf_ = %out_buf: memref<...>) {
// ...
// linalg.yield
// }
@@ -2026,16 +2085,25 @@ struct TiledLoopResultsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
// Match the pattern and collect output buffers that will replace the output
// tensors and also the ops that will be ignored when cloning the body.
SmallVector<Value, 2> newOutputOperands, newYieldArgs;
int resultId = 0;
- for (Value out : tiledLoop.outputs()) {
+ // Store ids of the corresponding old and new output operands.
+ SmallVector<std::pair<size_t, size_t>, 2> oldOutIdToNew;
+ for (auto item : llvm::enumerate(
+ llvm::zip(tiledLoop.outputs(), tiledLoop.getRegionOutputArgs()))) {
+ size_t index = item.index();
+ Value out = std::get<0>(item.value());
+ Value outRegionArg = std::get<1>(item.value());
+
if (!out.getType().isa<RankedTensorType>()) {
+ oldOutIdToNew.push_back({index, newOutputOperands.size()});
newOutputOperands.push_back(out);
continue;
}
Value result = tiledLoop.getResult(resultId);
Value yieldArg = yieldOp.getOperand(resultId);
- if (yieldArg != out || !result.use_empty()) {
+ if (yieldArg != outRegionArg || !result.use_empty()) {
+ oldOutIdToNew.push_back({index, newOutputOperands.size()});
newOutputOperands.push_back(out);
newYieldArgs.push_back(yieldArg);
}
@@ -2053,6 +2121,17 @@ struct TiledLoopResultsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
// unnecessary `subtensor_insert`, `tensor_load` and `cast` ops.
BlockAndValueMapping bvm;
bvm.map(tiledLoop.getInductionVars(), newTiledLoop.getInductionVars());
+ bvm.map(tiledLoop.getRegionInputArgs(), newTiledLoop.getRegionInputArgs());
+ // A tensor output is folded away only when the body yields its region arg
+ // unchanged and the result is unused, so remaining body uses of that arg can
+ // read the outer operand. Map every output region arg to its operand first;
+ // the loop below then remaps the kept outputs to the new region args.
+ for (auto it :
+ llvm::zip(tiledLoop.getRegionOutputArgs(), tiledLoop.outputs()))
+ bvm.map(std::get<0>(it), std::get<1>(it));
+ for (const auto &item : oldOutIdToNew)
+ bvm.map(tiledLoop.getRegionOutputArgs()[item.first],
+ newTiledLoop.getRegionOutputArgs()[item.second]);
OpBuilder innerBuilder =
OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener());
for (auto &op : tiledLoop.getBody()->without_terminator())
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index df3b01d682356..7cb0d75000f01 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -1694,6 +1694,29 @@ class CustomOpAsmParser : public OpAsmParser {
return parser.parseCommaSeparatedListUntil(Token::r_paren, parseElt);
}
+ /// Parse a list of assignments of the form
+ /// (%x1 = %y1 : type1, %x2 = %y2 : type2, ...).
+ /// Each left-hand side is parsed as a region argument and each right-hand
+ /// side as an operand; they and their types are appended to `lhs`, `rhs`
+ /// and `types` in list order. Returns llvm::None if there is no leading '('.
+ OptionalParseResult
+ parseOptionalAssignmentListWithTypes(SmallVectorImpl<OperandType> &lhs,
+ SmallVectorImpl<OperandType> &rhs,
+ SmallVectorImpl<Type> &types) override {
+ if (failed(parseOptionalLParen()))
+ return llvm::None;
+
+ // Parses one `%regionArg = %operand : type` element.
+ auto parseElt = [&]() -> ParseResult {
+ OperandType regionArg, operand;
+ Type type;
+ if (parseRegionArgument(regionArg) || parseEqual() ||
+ parseOperand(operand) || parseColon() || parseType(type))
+ return failure();
+ lhs.push_back(regionArg);
+ rhs.push_back(operand);
+ types.push_back(type);
+ return success();
+ };
+ return parser.parseCommaSeparatedListUntil(Token::r_paren, parseElt);
+ }
+
private:
/// The source location of the operation name.
SMLoc nameLoc;
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index c086b212c5c68..afdfe6fb98a81 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -861,10 +861,12 @@ func @fold_tiled_loop_results(%A: memref<192x192xf32>, %B: memref<192x192xf32>,
%c192 = constant 192 : index
%useless = linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192)
step (%c24, %c16)
- ins (%A, %B: memref<192x192xf32>, memref<192x192xf32>)
- outs (%C_tensor, %C :tensor<192x192xf32>, memref<192x192xf32>) {
- call @foo(%A, %B, %C) : (memref<192x192xf32>, memref<192x192xf32>, memref<192x192xf32>)-> ()
- linalg.yield %C_tensor : tensor<192x192xf32>
+ ins (%A_ = %A: memref<192x192xf32>, %B_ = %B: memref<192x192xf32>)
+ outs (%CT_ = %C_tensor: tensor<192x192xf32>,
+ %C_ = %C: memref<192x192xf32>) {
+ call @foo(%A_, %B_, %C_)
+ : (memref<192x192xf32>, memref<192x192xf32>, memref<192x192xf32>)-> ()
+ linalg.yield %CT_ : tensor<192x192xf32>
}
return
}
@@ -880,9 +882,9 @@ func @fold_tiled_loop_results(%A: memref<192x192xf32>, %B: memref<192x192xf32>,
// CHECK-NOT: %{{.*}} = linalg.tiled_loop
// CHECK: linalg.tiled_loop (%{{.*}}, %{{.*}}) = (%[[C0]], %[[C0]])
// CHECK-SAME: to (%[[C192]], %[[C192]]) step (%[[C24]], %[[C16]])
-// CHECK-SAME: ins (%[[A]], %[[B]]: memref<192x192xf32>, memref<192x192xf32>)
-// CHECK-SAME: outs (%[[C]]:memref<192x192xf32>) {
-// CHECK-NEXT: call @foo(%[[A]], %[[B]], %[[C]])
+// CHECK-SAME: ins (%[[A_:.*]] = %[[A]]: memref<192x192xf32>, %[[B_:.*]] = %[[B]]: memref<192x192xf32>)
+// CHECK-SAME: outs (%[[C_:.*]] = %[[C]]: memref<192x192xf32>) {
+// CHECK-NEXT: call @foo(%[[A_]], %[[B_]], %[[C_]])
// CHECK-NEXT: linalg.yield
// -----
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index fbcc11e900860..796e511e9db15 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -776,9 +776,10 @@ func @tiled_loop_incorrent_num_yield_operands(%A: memref<192x192xf32>,
%c192 = constant 192 : index
%0 = linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192)
step (%c24, %c24)
- ins (%A, %B: memref<192x192xf32>, memref<192x192xf32>)
- outs (%C_tensor, %C :tensor<192x192xf32>, memref<192x192xf32>) {
- call @foo(%A, %B, %C)
+ ins (%A_ = %A: memref<192x192xf32>, %B_ = %B: memref<192x192xf32>)
+ outs (%CT_ = %C_tensor: tensor<192x192xf32>,
+ %C_ = %C: memref<192x192xf32>) {
+ call @foo(%A_, %B_, %C_)
: (memref<192x192xf32>, memref<192x192xf32>, memref<192x192xf32>)-> ()
// expected-error @+1 {{expected number of tensor output args = 1 to match the number of yield operands = 0}}
linalg.yield
@@ -803,9 +804,10 @@ func @tiled_loop_incorrent_yield_operand_type(%A: memref<192x192xf32>,
%c192 = constant 192 : index
%0 = linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192)
step (%c24, %c24)
- ins (%A, %B: memref<192x192xf32>, memref<192x192xf32>)
- outs (%C_tensor, %C :tensor<192x192xf32>, memref<192x192xf32>) {
- %1 = call @foo(%A, %B, %C)
+ ins (%A_ = %A: memref<192x192xf32>, %B_ = %B: memref<192x192xf32>)
+ outs (%CT_ = %C_tensor: tensor<192x192xf32>,
+ %C_ = %C: memref<192x192xf32>) {
+ %1 = call @foo(%A_, %B_, %C_)
: (memref<192x192xf32>, memref<192x192xf32>, memref<192x192xf32>)-> tensor<f32>
// expected-error @+1 {{expected yield operand 0 with type = 'tensor<f32>' to match output arg type = 'tensor<192x192xf32>}}
linalg.yield %1 : tensor<f32>
@@ -815,10 +817,6 @@ func @tiled_loop_incorrent_yield_operand_type(%A: memref<192x192xf32>,
// -----
-#map0 = affine_map<(d0) -> (24, -d0 + 192)>
-#map1 = affine_map<(d0, d1)[s0] -> (d0 * 192 + s0 + d1)>
-#map2 = affine_map<(d0) -> (16, -d0 + 192)>
-
func private @foo(%A: memref<192x192xf32>, %B: memref<192x192xf32>,
%C: memref<192x192xf32>) -> ()
@@ -830,10 +828,12 @@ func @tiled_loop_incorrent_iterator_types_count(%A: memref<192x192xf32>,
%c192 = constant 192 : index
// expected-error @+1 {{expected iterator types array attribute size = 1 to match the number of loops = 2}}
%0 = "linalg.tiled_loop"(%c0, %c0, %c192, %c192, %c24, %c24, %A, %B, %C_tensor, %C) ( {
- ^bb0(%arg4: index, %arg5: index): // no predecessors
- call @foo(%A, %B, %C)
+ ^bb0(%arg4: index, %arg5: index, %A_: memref<192x192xf32>,
+ %B_: memref<192x192xf32>, %CT_: tensor<192x192xf32>,
+ %C_: memref<192x192xf32>):
+ call @foo(%A_, %B_, %C_)
: (memref<192x192xf32>, memref<192x192xf32>, memref<192x192xf32>)-> ()
- linalg.yield %C_tensor : tensor<192x192xf32>
+ linalg.yield %CT_ : tensor<192x192xf32>
}) {
iterator_types = ["parallel"],
operand_segment_sizes = dense<2> : vector<5xi32>
@@ -842,3 +842,23 @@ func @tiled_loop_incorrent_iterator_types_count(%A: memref<192x192xf32>,
) -> tensor<192x192xf32>
return
}
+
+// -----
+
+func private @foo(%A: memref<100xf32>) -> ()
+
+func @tiled_loop_incorrect_block_arg_type(%A: memref<192xf32>) {
+ %c0 = constant 0 : index
+ %c192 = constant 192 : index
+ %c24 = constant 24 : index
+ // expected-error @+1 {{expected output arg 0 with type = 'memref<192xf32>' to match region arg 1 type = 'memref<100xf32>'}}
+ "linalg.tiled_loop"(%c0, %c192, %c24, %A) ( {
+ ^bb0(%arg4: index, %A_: memref<100xf32>):
+ call @foo(%A_) : (memref<100xf32>)-> ()
+ linalg.yield
+ }) {
+ iterator_types = ["parallel"],
+ operand_segment_sizes = dense<[1, 1, 1, 0, 1]> : vector<5xi32>
+ } : (index, index, index, memref<192xf32>) -> ()
+ return
+}
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index d03161541e2f6..441987a11bbdc 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -804,13 +804,13 @@ func @tiled_loop(%lhs: tensor<24x64xi8>, %rhs: tensor<24x64xi8>,
%c24 = constant 24 : index
%c64 = constant 64 : index
%prod = linalg.tiled_loop (%i) = (%c0) to (%c24) step (%c4)
- ins(%lhs, %rhs : tensor<24x64xi8>, tensor<24x64xi8>)
- outs(%out : tensor<24x64xi8>) {
- %lhs_sub = subtensor %lhs[%i, 0] [%c4, %c64] [1, 1]
+ ins(%lhs_ = %lhs: tensor<24x64xi8>, %rhs_ = %rhs: tensor<24x64xi8>)
+ outs(%out_ = %out: tensor<24x64xi8>) {
+ %lhs_sub = subtensor %lhs_[%i, 0] [%c4, %c64] [1, 1]
: tensor<24x64xi8> to tensor<?x?xi8>
- %rhs_sub = subtensor %rhs[%i, 0] [%c4, %c64] [1, 1]
+ %rhs_sub = subtensor %rhs_[%i, 0] [%c4, %c64] [1, 1]
: tensor<24x64xi8> to tensor<?x?xi8>
- %out_sub = subtensor %out[%i, 0] [%c4, %c64] [1, 1]
+ %out_sub = subtensor %out_[%i, 0] [%c4, %c64] [1, 1]
: tensor<24x64xi8> to tensor<?x?xi8>
%sum = linalg.generic #trait_4
@@ -821,7 +821,7 @@ func @tiled_loop(%lhs: tensor<24x64xi8>, %rhs: tensor<24x64xi8>,
linalg.yield %s : i8
} -> tensor<?x?xi8>
- %sum_sub = subtensor_insert %sum into %out[%i, 0][%c4, %c64][1, 1]
+ %sum_sub = subtensor_insert %sum into %out_[%i, 0][%c4, %c64][1, 1]
: tensor<?x?xi8> into tensor<24x64xi8>
linalg.yield %sum_sub : tensor<24x64xi8>
}
@@ -860,16 +860,18 @@ func @tiled_loop_reduction(%input_3d: tensor<16x24x32xf32>,
%Z = memref.dim %input_3d, %c2 : tensor<16x24x32xf32>
%result = linalg.tiled_loop (%i, %j, %k)
= (%c0, %c0, %c0) to (%X, %Y, %Z) step (%c2, %c4, %c8)
- ins(%input_3d, %input_2d: tensor<16x24x32xf32>, tensor<16x32xf32>)
- outs( %output: tensor<24xf32>)
+ ins(%i3d_ = %input_3d: tensor<16x24x32xf32>,
+ %i2d_ = %input_2d: tensor<16x32xf32>,
+ %i1d_ = %input_1d: tensor<24xf32>)
+ outs(%o_ = %output: tensor<24xf32>)
iterators["reduction", "parallel", "reduction"] {
- %sub_3d = subtensor %input_3d[%i, %j, %k][2, 4, 8][1, 1, 1]
+ %sub_3d = subtensor %i3d_[%i, %j, %k][2, 4, 8][1, 1, 1]
: tensor<16x24x32xf32> to tensor<2x4x8xf32>
- %sub_2d = subtensor %input_2d[%i, %k][2, 8][1, 1]
+ %sub_2d = subtensor %i2d_[%i, %k][2, 8][1, 1]
: tensor<16x32xf32> to tensor<2x8xf32>
- %sub_1d = subtensor %input_1d[%j] [4] [1]
+ %sub_1d = subtensor %i1d_[%j] [4] [1]
: tensor<24xf32> to tensor<4xf32>
- %sub_out = subtensor %output[%j] [4] [1]
+ %sub_out = subtensor %o_[%j] [4] [1]
: tensor<24xf32> to tensor<4xf32>
%acc = linalg.generic #trait_5
ins(%sub_3d, %sub_2d, %sub_1d
@@ -881,7 +883,7 @@ func @tiled_loop_reduction(%input_3d: tensor<16x24x32xf32>,
linalg.yield %1 : f32
} -> tensor<4xf32>
- %sum_sub = subtensor_insert %acc into %output[%j][%c4][1]
+ %sum_sub = subtensor_insert %acc into %o_[%j][%c4][1]
: tensor<4xf32> into tensor<24xf32>
linalg.yield %sum_sub : tensor<24xf32>
}
@@ -919,16 +921,18 @@ func @tiled_loop_on_buffers(%input_3d: memref<16x24x32xf32>,
%Z = memref.dim %input_3d, %c2 : memref<16x24x32xf32>
linalg.tiled_loop (%i, %j, %k) = (%c0, %c0, %c0)
to (%X, %Y, %Z) step (%c2, %c4, %c8)
- ins(%input_3d, %input_2d: memref<16x24x32xf32>, memref<16x32xf32>)
- outs( %output: memref<24xf32>)
+ ins(%i3d_ = %input_3d: memref<16x24x32xf32>,
+ %i2d_ = %input_2d: memref<16x32xf32>,
+ %i1d_ = %input_1d: memref<24xf32>)
+ outs(%o_ = %output: memref<24xf32>)
iterators["reduction", "parallel", "reduction"] {
- %sub_3d = memref.subview %input_3d[%i, %j, %k][2, 4, 8][1, 1, 1]
+ %sub_3d = memref.subview %i3d_[%i, %j, %k][2, 4, 8][1, 1, 1]
: memref<16x24x32xf32> to memref<2x4x8xf32, #map_1>
- %sub_2d = memref.subview %input_2d[%i, %k][2, 8][1, 1]
+ %sub_2d = memref.subview %i2d_[%i, %k][2, 8][1, 1]
: memref<16x32xf32> to memref<2x8xf32, #map_2>
- %sub_1d = memref.subview %input_1d[%j] [4] [1]
+ %sub_1d = memref.subview %i1d_[%j] [4] [1]
: memref<24xf32> to memref<4xf32, #map_3>
- %sub_out = memref.subview %output[%j] [4] [1]
+ %sub_out = memref.subview %o_[%j] [4] [1]
: memref<24xf32> to memref<4xf32, #map_3>
linalg.generic #trait_6
ins(%sub_3d, %sub_2d, %sub_1d
More information about the Mlir-commits
mailing list