[Mlir-commits] [mlir] 5291a7a - [mlir] Add block arguments for input/output operands of `linalg.tiled_loop`.

Alexander Belyaev llvmlistbot at llvm.org
Fri Apr 23 11:55:44 PDT 2021


Author: Alexander Belyaev
Date: 2021-04-23T20:55:20+02:00
New Revision: 5291a7a3c70c578fe3797b1116a8f74990f3750a

URL: https://github.com/llvm/llvm-project/commit/5291a7a3c70c578fe3797b1116a8f74990f3750a
DIFF: https://github.com/llvm/llvm-project/commit/5291a7a3c70c578fe3797b1116a8f74990f3750a.diff

LOG: [mlir] Add block arguments for input/output operands of `linalg.tiled_loop`.

Differential Revision: https://reviews.llvm.org/D101186

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
    mlir/include/mlir/IR/OpImplementation.h
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Parser/Parser.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir
    mlir/test/Dialect/Linalg/invalid.mlir
    mlir/test/Dialect/Linalg/roundtrip.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index ca4488d2f20a5..51c24cda736f1 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -579,8 +579,9 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
     OpBuilder<(ins "ValueRange":$lowerBounds, "ValueRange":$upperBounds,
       "ValueRange":$steps, "ValueRange":$inputs, "ValueRange":$outputs,
       "ArrayAttr":$iteratorTypes,
-      CArg<"function_ref<void (OpBuilder &, Location, ValueRange)>",
-           "nullptr">:$bodyBuilderFn)>,
+      CArg<"function_ref<void (OpBuilder &, Location, /*ivs=*/ValueRange,"
+        "/*inputs=*/ValueRange, /*outputs=*/ValueRange)>",
+        "nullptr">:$bodyBuilderFn)>,
   ];
 
   let extraClassDeclaration = [{
@@ -588,7 +589,13 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
     unsigned getNumControlOperands() { return 3 * getNumLoops(); }
 
     ValueRange getInductionVars() {
-      return getBody()->getArguments();
+      return getBody()->getArguments().take_front(getNumLoops());
+    }
+    ValueRange getRegionInputArgs() { // Block args right after the IVs.
+      return getBody()->getArguments().slice(getNumLoops(), inputs().size());
+    }
+    ValueRange getRegionOutputArgs() { // Last outputs().size() block args.
+      return getBody()->getArguments().take_back(outputs().size());
     }
 
 

diff  --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 5250899a9b447..f41a406caa18d 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -834,6 +834,22 @@ class OpAsmParser {
   parseOptionalAssignmentList(SmallVectorImpl<OperandType> &lhs,
                               SmallVectorImpl<OperandType> &rhs) = 0;
 
+  /// Parse a list of assignments of the form
+  ///   (%x1 = %y1 : type1, ...); emits "expected '('" if '(' is missing.
+  ParseResult parseAssignmentListWithTypes(SmallVectorImpl<OperandType> &lhs,
+                                           SmallVectorImpl<OperandType> &rhs,
+                                           SmallVectorImpl<Type> &types) {
+    OptionalParseResult result =
+        parseOptionalAssignmentListWithTypes(lhs, rhs, types);
+    if (!result.hasValue())
+      return emitError(getCurrentLocation(), "expected '('");
+    return result.getValue();
+  }
+  /// Like parseAssignmentListWithTypes, but returns None if there is no '('.
+  virtual OptionalParseResult
+  parseOptionalAssignmentListWithTypes(SmallVectorImpl<OperandType> &lhs,
+                                       SmallVectorImpl<OperandType> &rhs,
+                                       SmallVectorImpl<Type> &types) = 0;
   /// Parse a keyword followed by a type.
   ParseResult parseKeywordType(const char *keyword, Type &result) {
     return failure(parseKeyword(keyword) || parseType(result));

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 45e5e07a2961a..5a6d498a65b49 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1805,11 +1805,13 @@ static LogicalResult verify(linalg::YieldOp op) {
 // TiledLoopOp
 //===----------------------------------------------------------------------===//
 
-void TiledLoopOp::build(
-    OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
-    ValueRange upperBounds, ValueRange steps, ValueRange inputs,
-    ValueRange outputs, ArrayAttr iteratorTypes,
-    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
+void TiledLoopOp::build(OpBuilder &builder, OperationState &result,
+                        ValueRange lowerBounds, ValueRange upperBounds,
+                        ValueRange steps, ValueRange inputs, ValueRange outputs,
+                        ArrayAttr iteratorTypes,
+                        function_ref<void(OpBuilder &, Location, ValueRange,
+                                          ValueRange, ValueRange)>
+                            bodyBuilderFn) {
   result.addOperands(lowerBounds);
   result.addOperands(upperBounds);
   result.addOperands(steps);
@@ -1834,25 +1836,46 @@ void TiledLoopOp::build(
   OpBuilder::InsertionGuard guard(builder);
   unsigned numIVs = steps.size();
   SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
+  for (Type type : TypeRange(inputs))
+    argTypes.push_back(type);
+  for (Type type : TypeRange(outputs))
+    argTypes.push_back(type);
   Region *bodyRegion = result.addRegion();
   Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes);
 
   if (bodyBuilderFn) {
     builder.setInsertionPointToStart(bodyBlock);
-    bodyBuilderFn(builder, result.location, bodyBlock->getArguments());
+    bodyBuilderFn(builder, result.location,
+                  bodyBlock->getArguments().take_front(numIVs),
+                  bodyBlock->getArguments().slice(numIVs, inputs.size()),
+                  bodyBlock->getArguments().take_back(outputs.size()));
     TiledLoopOp::ensureTerminator(*bodyRegion, builder, result.location);
   }
 }
 
 static void print(OpAsmPrinter &p, TiledLoopOp op) {
-  p << op.getOperationName() << " (" << op.getBody()->getArguments() << ") = ("
+  p << op.getOperationName() << " (" << op.getInductionVars() << ") = ("
     << op.lowerBound() << ") to (" << op.upperBound() << ") step (" << op.step()
     << ")";
 
-  if (!op.inputs().empty())
-    p << " ins (" << op.inputs() << ": " << TypeRange(op.inputs()) << ")";
-  if (!op.outputs().empty())
-    p << " outs (" << op.outputs() << ":" << TypeRange(op.outputs()) << ")";
+  if (!op.inputs().empty()) {
+    p << " ins (";
+    llvm::interleaveComma(llvm::zip(op.getRegionInputArgs(), op.inputs()), p,
+                          [&](auto it) {
+                            p << std::get<0>(it) << " = " << std::get<1>(it)
+                              << ": " << std::get<1>(it).getType();
+                          });
+    p << ")";
+  }
+  if (!op.outputs().empty()) {
+    p << " outs (";
+    llvm::interleaveComma(llvm::zip(op.getRegionOutputArgs(), op.outputs()), p,
+                          [&](auto it) {
+                            p << std::get<0>(it) << " = " << std::get<1>(it)
+                              << ": " << std::get<1>(it).getType();
+                          });
+    p << ")";
+  }
 
   if (llvm::any_of(op.iterator_types(), [](Attribute attr) {
         return attr.cast<StringAttr>().getValue() !=
@@ -1900,13 +1923,13 @@ static ParseResult parseTiledLoopOp(OpAsmParser &parser,
     return failure();
 
   // Parse input tensors.
-  SmallVector<OpAsmParser::OperandType, 4> inputs;
+  SmallVector<OpAsmParser::OperandType, 4> inputs, input_region_args;
+  SmallVector<Type, 4> inputTypes;
   if (succeeded(parser.parseOptionalKeyword("ins"))) {
-    SmallVector<Type, 4> inputTypes;
     llvm::SMLoc inputsOperandsLoc = parser.getCurrentLocation();
 
-    if (parser.parseLParen() || parser.parseOperandList(inputs) ||
-        parser.parseColonTypeList(inputTypes) || parser.parseRParen())
+    if (parser.parseAssignmentListWithTypes(input_region_args, inputs,
+                                            inputTypes))
       return failure();
 
     if (parser.resolveOperands(inputs, inputTypes, inputsOperandsLoc,
@@ -1915,13 +1938,13 @@ static ParseResult parseTiledLoopOp(OpAsmParser &parser,
   }
 
   // Parse output tensors.
-  SmallVector<OpAsmParser::OperandType, 4> outputs;
+  SmallVector<OpAsmParser::OperandType, 4> outputs, output_region_args;
+  SmallVector<Type, 4> outputTypes;
   if (succeeded(parser.parseOptionalKeyword("outs"))) {
-    SmallVector<Type, 4> outputTypes;
     llvm::SMLoc outputsOperandsLoc = parser.getCurrentLocation();
 
-    if (parser.parseLParen() || parser.parseOperandList(outputs) ||
-        parser.parseColonTypeList(outputTypes) || parser.parseRParen())
+    if (parser.parseAssignmentListWithTypes(output_region_args, outputs,
+                                            outputTypes))
       return failure();
 
     if (parser.resolveOperands(outputs, outputTypes, outputsOperandsLoc,
@@ -1963,8 +1986,16 @@ static ParseResult parseTiledLoopOp(OpAsmParser &parser,
 
   // Parse the body.
   Region *body = result.addRegion();
-  SmallVector<Type, 4> types(ivs.size(), builder.getIndexType());
-  if (parser.parseRegion(*body, ivs, types))
+
+  SmallVector<Type, 4> region_types(ivs.size(), builder.getIndexType());
+  region_types.append(inputTypes);
+  region_types.append(outputTypes);
+
+  SmallVector<OpAsmParser::OperandType, 4> region_args(ivs);
+  region_args.append(input_region_args);
+  region_args.append(output_region_args);
+
+  if (parser.parseRegion(*body, region_args, region_types))
     return failure();
 
   // Parse optional attributes.
@@ -1991,6 +2022,33 @@ static LogicalResult verify(TiledLoopOp op) {
     return op.emitOpError("expected iterator types array attribute size = ")
            << op.iterator_types().size()
            << " to match the number of loops = " << op.getNumLoops();
+
+  // Check if types of input arguments match region args types.
+  for (auto &item :
+       llvm::enumerate(llvm::zip(op.inputs(), op.getRegionInputArgs()))) {
+    Value input, inputRegionArg;
+    unsigned index = item.index();
+    std::tie(input, inputRegionArg) = item.value();
+    if (input.getType() != inputRegionArg.getType())
+      return op.emitOpError("expected input arg ")
+             << index << " with type = " << input.getType()
+             << " to match region arg " << index + op.getNumLoops()
+             << " type = " << inputRegionArg.getType();
+  }
+
+  // Check if types of output arguments match region args types.
+  for (auto &item :
+       llvm::enumerate(llvm::zip(op.outputs(), op.getRegionOutputArgs()))) {
+    Value output, outputRegionArg;
+    unsigned index = item.index();
+    std::tie(output, outputRegionArg) = item.value();
+    if (output.getType() != outputRegionArg.getType())
+      return op.emitOpError("expected output arg ")
+             << index << " with type = " << output.getType()
+             << " to match region arg "
+             << index + op.getNumLoops() + op.inputs().size()
+             << " type = " << outputRegionArg.getType();
+  }
   return success();
 }
 
@@ -2002,14 +2060,15 @@ namespace {
 //
 // Example:
 //
-// %0 = linalg.tiled_loop ...  outs (%out, %out_buf:tensor<...>, memref<...>) {
+// %0 = linalg.tiled_loop ...  outs (%o_ = %out: tensor<...>,
+//                                   %obuf_ = %out_buf: memref<...>) {
 //   ...
-//   linalg.yield %out : tensor ...
+//   linalg.yield %o_ : tensor ...
 // }
 //
 // Becomes
 //
-// linalg.tiled_loop ...  outs (%out_buf:memref<...>) {
+// linalg.tiled_loop ...  outs (%obuf_ = %out_buf: memref<...>) {
 //   ...
 //   linalg.yield
 // }
@@ -2026,16 +2085,27 @@ struct TiledLoopResultsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
 
     // Match the pattern and collect output buffers that will replace the output
     // tensors and also the ops that will be ignored when cloning the body.
-    SmallVector<Value, 2> newOutputOperands, newYieldArgs;
+    SmallVector<Value, 2> newOutputOperands, newYieldArgs,
+        regionOutputTensorArgs;
     int resultId = 0;
-    for (Value out : tiledLoop.outputs()) {
+    // Store ids of the corresponding old and new output operands.
+    SmallVector<std::pair<size_t, size_t>, 2> old_out_id_to_new;
+    for (auto item : llvm::enumerate(
+             llvm::zip(tiledLoop.outputs(), tiledLoop.getRegionOutputArgs()))) {
+      size_t index = item.index();
+      Value out = std::get<0>(item.value());
+      Value outRegionArg = std::get<1>(item.value());
+
       if (!out.getType().isa<RankedTensorType>()) {
+        old_out_id_to_new.push_back({index, newOutputOperands.size()});
         newOutputOperands.push_back(out);
+        regionOutputTensorArgs.push_back(outRegionArg);
         continue;
       }
       Value result = tiledLoop.getResult(resultId);
       Value yieldArg = yieldOp.getOperand(resultId);
-      if (yieldArg != out || !result.use_empty()) {
+      if (yieldArg != outRegionArg || !result.use_empty()) {
+        old_out_id_to_new.push_back({index, newOutputOperands.size()});
         newOutputOperands.push_back(out);
         newYieldArgs.push_back(yieldArg);
       }
@@ -2053,6 +2123,10 @@ struct TiledLoopResultsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
     // unnecessary `subtensor_insert`, `tensor_load` and `cast` ops.
     BlockAndValueMapping bvm;
     bvm.map(tiledLoop.getInductionVars(), newTiledLoop.getInductionVars());
+    bvm.map(tiledLoop.getRegionInputArgs(), newTiledLoop.getRegionInputArgs());
+    for (const auto &item : old_out_id_to_new)
+      bvm.map(tiledLoop.getRegionOutputArgs()[item.first],
+              newTiledLoop.getRegionOutputArgs()[item.second]);
     OpBuilder innerBuilder =
         OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener());
     for (auto &op : tiledLoop.getBody()->without_terminator())

diff  --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index df3b01d682356..7cb0d75000f01 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -1694,6 +1694,29 @@ class CustomOpAsmParser : public OpAsmParser {
     return parser.parseCommaSeparatedListUntil(Token::r_paren, parseElt);
   }
 
+  /// Parse `(%x1 = %y1 : type1, ...)`, appending entries to lhs/rhs/types.
+  /// Returns None when there is no leading '('; otherwise parses up to ')'.
+  OptionalParseResult
+  parseOptionalAssignmentListWithTypes(SmallVectorImpl<OperandType> &lhs,
+                                       SmallVectorImpl<OperandType> &rhs,
+                                       SmallVectorImpl<Type> &types) override {
+    if (failed(parseOptionalLParen()))
+      return llvm::None;
+    // Each element has the form `%regionArg = %operand : type`.
+    auto parseElt = [&]() -> ParseResult {
+      OperandType regionArg, operand;
+      Type type;
+      if (parseRegionArgument(regionArg) || parseEqual() ||
+          parseOperand(operand) || parseColon() || parseType(type))
+        return failure();
+      lhs.push_back(regionArg);
+      rhs.push_back(operand);
+      types.push_back(type);
+      return success();
+    };
+    return parser.parseCommaSeparatedListUntil(Token::r_paren, parseElt);
+  }
+
 private:
   /// The source location of the operation name.
   SMLoc nameLoc;

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index c086b212c5c68..afdfe6fb98a81 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -861,10 +861,12 @@ func @fold_tiled_loop_results(%A: memref<192x192xf32>, %B: memref<192x192xf32>,
   %c192 = constant 192 : index
   %useless = linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192)
       step (%c24, %c16)
-      ins (%A, %B: memref<192x192xf32>, memref<192x192xf32>)
-      outs (%C_tensor, %C :tensor<192x192xf32>, memref<192x192xf32>) {
-        call @foo(%A, %B, %C) : (memref<192x192xf32>, memref<192x192xf32>, memref<192x192xf32>)-> ()
-    linalg.yield %C_tensor : tensor<192x192xf32>
+      ins (%A_ = %A: memref<192x192xf32>, %B_ = %B: memref<192x192xf32>)
+      outs (%CT_ = %C_tensor: tensor<192x192xf32>,
+            %C_ = %C: memref<192x192xf32>) {
+        call @foo(%A_, %B_, %C_)
+          : (memref<192x192xf32>, memref<192x192xf32>, memref<192x192xf32>)-> ()
+    linalg.yield %CT_ : tensor<192x192xf32>
   }
   return
 }
@@ -880,9 +882,9 @@ func @fold_tiled_loop_results(%A: memref<192x192xf32>, %B: memref<192x192xf32>,
 // CHECK-NOT: %{{.*}} = linalg.tiled_loop
 // CHECK:  linalg.tiled_loop (%{{.*}}, %{{.*}}) = (%[[C0]], %[[C0]])
 // CHECK-SAME: to (%[[C192]], %[[C192]]) step (%[[C24]], %[[C16]])
-// CHECK-SAME: ins (%[[A]], %[[B]]: memref<192x192xf32>, memref<192x192xf32>)
-// CHECK-SAME: outs (%[[C]]:memref<192x192xf32>) {
-// CHECK-NEXT:   call @foo(%[[A]], %[[B]], %[[C]])
+// CHECK-SAME: ins (%[[A_:.*]] = %[[A]]: memref<192x192xf32>, %[[B_:.*]] = %[[B]]: memref<192x192xf32>)
+// CHECK-SAME: outs (%[[C_:.*]] = %[[C]]: memref<192x192xf32>) {
+// CHECK-NEXT:   call @foo(%[[A_]], %[[B_]], %[[C_]])
 // CHECK-NEXT:   linalg.yield
 
 // -----

diff  --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index fbcc11e900860..796e511e9db15 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -776,9 +776,10 @@ func @tiled_loop_incorrent_num_yield_operands(%A: memref<192x192xf32>,
   %c192 = constant 192 : index
   %0 = linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192)
       step (%c24, %c24)
-      ins (%A, %B: memref<192x192xf32>, memref<192x192xf32>)
-      outs (%C_tensor, %C :tensor<192x192xf32>, memref<192x192xf32>) {
-        call @foo(%A, %B, %C)
+      ins (%A_ = %A: memref<192x192xf32>, %B_ = %B: memref<192x192xf32>)
+      outs (%CT_ = %C_tensor: tensor<192x192xf32>,
+            %C_ = %C: memref<192x192xf32>) {
+        call @foo(%A_, %B_, %C_)
           : (memref<192x192xf32>, memref<192x192xf32>, memref<192x192xf32>)-> ()
     // expected-error @+1 {{expected number of tensor output args = 1 to match the number of yield operands = 0}}
     linalg.yield
@@ -803,9 +804,10 @@ func @tiled_loop_incorrent_yield_operand_type(%A: memref<192x192xf32>,
   %c192 = constant 192 : index
   %0 = linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192)
       step (%c24, %c24)
-      ins (%A, %B: memref<192x192xf32>, memref<192x192xf32>)
-      outs (%C_tensor, %C :tensor<192x192xf32>, memref<192x192xf32>) {
-        %1 = call @foo(%A, %B, %C)
+      ins (%A_ = %A: memref<192x192xf32>, %B_ = %B: memref<192x192xf32>)
+      outs (%CT_ = %C_tensor: tensor<192x192xf32>,
+            %C_ = %C: memref<192x192xf32>) {
+        %1 = call @foo(%A_, %B_, %C_)
           : (memref<192x192xf32>, memref<192x192xf32>, memref<192x192xf32>)-> tensor<f32>
     // expected-error @+1 {{expected yield operand 0 with type = 'tensor<f32>' to match output arg type = 'tensor<192x192xf32>}}
     linalg.yield %1 : tensor<f32>
@@ -815,10 +817,6 @@ func @tiled_loop_incorrent_yield_operand_type(%A: memref<192x192xf32>,
 
 // -----
 
-#map0 = affine_map<(d0) -> (24, -d0 + 192)>
-#map1 = affine_map<(d0, d1)[s0] -> (d0 * 192 + s0 + d1)>
-#map2 = affine_map<(d0) -> (16, -d0 + 192)>
-
 func private @foo(%A: memref<192x192xf32>, %B: memref<192x192xf32>,
                   %C: memref<192x192xf32>) -> ()
 
@@ -830,10 +828,12 @@ func @tiled_loop_incorrent_iterator_types_count(%A: memref<192x192xf32>,
   %c192 = constant 192 : index
   // expected-error @+1 {{expected iterator types array attribute size = 1 to match the number of loops = 2}}
   %0 = "linalg.tiled_loop"(%c0, %c0, %c192, %c192, %c24, %c24, %A, %B, %C_tensor, %C) ( {
-    ^bb0(%arg4: index, %arg5: index):  // no predecessors
-      call @foo(%A, %B, %C)
+    ^bb0(%arg4: index, %arg5: index, %A_: memref<192x192xf32>,
+         %B_: memref<192x192xf32>, %CT_: tensor<192x192xf32>,
+         %C_: memref<192x192xf32>):
+      call @foo(%A_, %B_, %C_)
           : (memref<192x192xf32>, memref<192x192xf32>, memref<192x192xf32>)-> ()
-      linalg.yield %C_tensor : tensor<192x192xf32>
+      linalg.yield %CT_ : tensor<192x192xf32>
     }) {
       iterator_types = ["parallel"],
       operand_segment_sizes = dense<2> : vector<5xi32>
@@ -842,3 +842,23 @@ func @tiled_loop_incorrent_iterator_types_count(%A: memref<192x192xf32>,
     ) -> tensor<192x192xf32>
   return
 }
+
+// -----
+
+func private @foo(%A: memref<100xf32>) -> ()
+
+func @tiled_loop_incorrent_block_arg_type(%A: memref<192xf32>) {
+  %c0 = constant 0 : index
+  %c192 = constant 192 : index
+  %c24 = constant 24 : index
+  // expected-error @+1 {{expected output arg 0 with type = 'memref<192xf32>' to match region arg 1 type = 'memref<100xf32>'}}
+  "linalg.tiled_loop"(%c0, %c192, %c24, %A) ( {
+    ^bb0(%arg4: index, %A_: memref<100xf32>):
+      call @foo(%A_) : (memref<100xf32>)-> ()
+      linalg.yield
+    }) {
+      iterator_types = ["parallel"],
+      operand_segment_sizes = dense<[1, 1, 1, 0, 1]> : vector<5xi32>
+    } : (index, index, index, memref<192xf32>) -> ()
+  return
+}

diff  --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index d03161541e2f6..441987a11bbdc 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -804,13 +804,13 @@ func @tiled_loop(%lhs: tensor<24x64xi8>, %rhs: tensor<24x64xi8>,
  %c24 = constant 24 : index
  %c64 = constant 64 : index
  %prod = linalg.tiled_loop (%i) = (%c0) to (%c24) step (%c4)
-      ins(%lhs, %rhs : tensor<24x64xi8>, tensor<24x64xi8>)
-      outs(%out : tensor<24x64xi8>) {
-    %lhs_sub = subtensor %lhs[%i, 0] [%c4, %c64] [1, 1]
+      ins(%lhs_ = %lhs: tensor<24x64xi8>, %rhs_ = %rhs: tensor<24x64xi8>)
+      outs(%out_ = %out: tensor<24x64xi8>) {
+    %lhs_sub = subtensor %lhs_[%i, 0] [%c4, %c64] [1, 1]
         : tensor<24x64xi8> to tensor<?x?xi8>
-    %rhs_sub = subtensor %rhs[%i, 0] [%c4, %c64] [1, 1]
+    %rhs_sub = subtensor %rhs_[%i, 0] [%c4, %c64] [1, 1]
         : tensor<24x64xi8> to tensor<?x?xi8>
-    %out_sub = subtensor %out[%i, 0] [%c4, %c64] [1, 1]
+    %out_sub = subtensor %out_[%i, 0] [%c4, %c64] [1, 1]
         : tensor<24x64xi8> to tensor<?x?xi8>
 
     %sum = linalg.generic #trait_4
@@ -821,7 +821,7 @@ func @tiled_loop(%lhs: tensor<24x64xi8>, %rhs: tensor<24x64xi8>,
         linalg.yield %s : i8
       } -> tensor<?x?xi8>
 
-    %sum_sub = subtensor_insert %sum into %out[%i, 0][%c4, %c64][1, 1]
+    %sum_sub = subtensor_insert %sum into %out_[%i, 0][%c4, %c64][1, 1]
       : tensor<?x?xi8> into tensor<24x64xi8>
     linalg.yield %sum_sub : tensor<24x64xi8>
   }
@@ -860,16 +860,18 @@ func @tiled_loop_reduction(%input_3d: tensor<16x24x32xf32>,
   %Z = memref.dim %input_3d, %c2 : tensor<16x24x32xf32>
   %result = linalg.tiled_loop (%i, %j, %k)
       = (%c0, %c0, %c0) to (%X, %Y, %Z) step (%c2, %c4, %c8)
-      ins(%input_3d, %input_2d: tensor<16x24x32xf32>, tensor<16x32xf32>)
-      outs( %output: tensor<24xf32>)
+      ins(%i3d_ = %input_3d: tensor<16x24x32xf32>,
+          %i2d_ = %input_2d: tensor<16x32xf32>,
+          %i1d_ = %input_1d: tensor<24xf32>)
+      outs(%o_ =  %output: tensor<24xf32>)
       iterators["reduction", "parallel", "reduction"] {
-    %sub_3d = subtensor %input_3d[%i, %j, %k][2, 4, 8][1, 1, 1]
+    %sub_3d = subtensor %i3d_[%i, %j, %k][2, 4, 8][1, 1, 1]
       : tensor<16x24x32xf32> to tensor<2x4x8xf32>
-    %sub_2d = subtensor %input_2d[%i, %k][2, 8][1, 1]
+    %sub_2d = subtensor %i2d_[%i, %k][2, 8][1, 1]
       : tensor<16x32xf32> to tensor<2x8xf32>
-    %sub_1d = subtensor %input_1d[%j] [4] [1]
+    %sub_1d = subtensor %i1d_[%j] [4] [1]
       : tensor<24xf32> to tensor<4xf32>
-    %sub_out = subtensor %output[%j] [4] [1]
+    %sub_out = subtensor %o_[%j] [4] [1]
       : tensor<24xf32> to tensor<4xf32>
     %acc = linalg.generic #trait_5
       ins(%sub_3d, %sub_2d, %sub_1d
@@ -881,7 +883,7 @@ func @tiled_loop_reduction(%input_3d: tensor<16x24x32xf32>,
       linalg.yield %1 : f32
     } -> tensor<4xf32>
 
-    %sum_sub = subtensor_insert %acc into %output[%j][%c4][1]
+    %sum_sub = subtensor_insert %acc into %o_[%j][%c4][1]
       : tensor<4xf32> into tensor<24xf32>
     linalg.yield %sum_sub : tensor<24xf32>
   }
@@ -919,16 +921,18 @@ func @tiled_loop_on_buffers(%input_3d: memref<16x24x32xf32>,
   %Z = memref.dim %input_3d, %c2 : memref<16x24x32xf32>
   linalg.tiled_loop (%i, %j, %k) = (%c0, %c0, %c0)
       to (%X, %Y, %Z) step (%c2, %c4, %c8)
-      ins(%input_3d, %input_2d: memref<16x24x32xf32>, memref<16x32xf32>)
-      outs( %output: memref<24xf32>)
+      ins(%i3d_ = %input_3d: memref<16x24x32xf32>,
+          %i2d_ = %input_2d: memref<16x32xf32>,
+          %i1d_ = %input_1d: memref<24xf32>)
+      outs(%o_ =  %output: memref<24xf32>)
       iterators["reduction", "parallel", "reduction"] {
-    %sub_3d = memref.subview %input_3d[%i, %j, %k][2, 4, 8][1, 1, 1]
+    %sub_3d = memref.subview %i3d_[%i, %j, %k][2, 4, 8][1, 1, 1]
       : memref<16x24x32xf32> to memref<2x4x8xf32, #map_1>
-    %sub_2d = memref.subview %input_2d[%i, %k][2, 8][1, 1]
+    %sub_2d = memref.subview %i2d_[%i, %k][2, 8][1, 1]
       : memref<16x32xf32> to memref<2x8xf32, #map_2>
-    %sub_1d = memref.subview %input_1d[%j] [4] [1]
+    %sub_1d = memref.subview %i1d_[%j] [4] [1]
       : memref<24xf32> to memref<4xf32, #map_3>
-    %sub_out = memref.subview %output[%j] [4] [1]
+    %sub_out = memref.subview %o_[%j] [4] [1]
       : memref<24xf32> to memref<4xf32, #map_3>
     linalg.generic #trait_6
       ins(%sub_3d, %sub_2d, %sub_1d


        


More information about the Mlir-commits mailing list