[Mlir-commits] [mlir] 2837991 - [mlir][linalg] Add support for memref inputs/outputs for `linalg.tiled_loop`.

Alexander Belyaev llvmlistbot at llvm.org
Thu Mar 18 08:11:21 PDT 2021


Author: Alexander Belyaev
Date: 2021-03-18T16:11:03+01:00
New Revision: 283799157e504597fc3034cc5fa02faa4e11fa58

URL: https://github.com/llvm/llvm-project/commit/283799157e504597fc3034cc5fa02faa4e11fa58
DIFF: https://github.com/llvm/llvm-project/commit/283799157e504597fc3034cc5fa02faa4e11fa58.diff

LOG: [mlir][linalg] Add support for memref inputs/outputs for `linalg.tiled_loop`.

Also use `ArrayAttr` to pass iterator types to the TiledLoopOp builder.

Differential Revision: https://reviews.llvm.org/D98871

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index a706d67d2988..5a906ff2dafd 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -15,6 +15,8 @@
 
 include "mlir/IR/OpBase.td"
 
+def LinalgOperand: AnyTypeOf<[AnyRankedTensor, AnyStridedMemRef]>;
+
 def Linalg_Dialect : Dialect {
   let name = "linalg";
   let description = [{

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 63bee92ded7c..d54efbe37a57 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -496,21 +496,25 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
   let summary = "Linalg tiled loop operation";
   let description = [{
     This is a loop-like operation with additional properties. The arguments
-    also include the input and the output tensors and the attributes to specify
-    the iterator types. The body region of the loop contains `subtensor`
-    operations applied to every tensor argument of TiledLoopOp.
+    also include the input and the output tensors or memrefs and the attributes
+    to specify the iterator types.
+
+    Parsing TiledLoopOp will set all elements of the `iterator_types` attribute
+    to "parallel" type, when it is absent from the custom format.
+
+    Tensor-based version:
+
+    The body region of the loop contains `subtensor` operations applied to
+    every tensor argument of TiledLoopOp.
 
     The body region must contain exactly one block that terminates with
     `linalg.yield` with the operands resulting from `subtensor_insert`
     operations.
 
-    Parsing TiledLoopOp will set all elements of the `iterator_types` attribute
-    to "parallel" type, when it is absent from the custom format.
-
     Example:
 
     ```mlir
-    linalg.tiled_loop (%i) = (%c0) to (%c24) step (%c4)
+    %0 = linalg.tiled_loop (%i) = (%c0) to (%c24) step (%c4)
         ins(%lhs, %rhs : tensor<24x64xi8>, tensor<24x64xi8>)
         outs(%out : tensor<24x64xi8>)
         iterators("parallel") {
@@ -528,13 +532,40 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
       linalg.yield %result : tensor<24x64xi8>
     }
     ```
+
+    MemRef-based version:
+
+    The body region of the loop contains `subview` operations applied to
+    every memref argument of TiledLoopOp.
+
+    The body region must contain exactly one block that terminates with
+    `linalg.yield` with no operands.
+
+    Example:
+
+    ```mlir
+    linalg.tiled_loop (%i) = (%c0) to (%c24) step (%c4)
+        ins(%lhs, %rhs : memref<24x64xi8>, memref<24x64xi8>)
+        outs(%out : memref<24x64xi8>)
+        iterators("parallel") {
+      %lhs_sub = subview %lhs[%i, 0] [%c4, %c64] [1, 1]
+          : memref<24x64xi8> to memref<?x?xi8>
+      %rhs_sub = subview %rhs[%i, 0] [%c4, %c64] [1, 1]
+          : memref<24x64xi8> to memref<?x?xi8>
+      %out_sub = subview %out[%i, 0] [%c4, %c64] [1, 1]
+          : memref<24x64xi8> to memref<?x?xi8>
+
+      linalg.generic ...
+      linalg.yield
+    }
+    ```
   }];
 
   let arguments = (ins Variadic<Index>:$lowerBound,
                        Variadic<Index>:$upperBound,
                        Variadic<Index>:$step,
-                       Variadic<AnyRankedTensor>:$inputs,
-                       Variadic<AnyRankedTensor>:$outputs,
+                       Variadic<LinalgOperand>:$inputs,
+                       Variadic<LinalgOperand>:$outputs,
                        ArrayAttr:$iterator_types);
   let results = (outs Variadic<AnyRankedTensor>:$results);
   let regions = (region SizedRegion<1>:$region);
@@ -542,7 +573,7 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
   let builders = [
     OpBuilder<(ins "ValueRange":$lowerBounds, "ValueRange":$upperBounds,
       "ValueRange":$steps, "ValueRange":$inputs, "ValueRange":$outputs,
-      "ArrayRef<StringRef>":$iteratorTypes,
+      "ArrayAttr":$iteratorTypes,
       CArg<"function_ref<void (OpBuilder &, Location, ValueRange)>",
            "nullptr">:$bodyBuilderFn)>,
   ];

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index f87a1eaeac8f..69aa7659b81c 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -496,8 +496,6 @@ def PoolingSumOp: SingleInputPoolingBase_Op<"pooling_sum"> {
 //===----------------------------------------------------------------------===//
 // Generic Linalg ops.
 //===----------------------------------------------------------------------===//
-def LinalgOperand: AnyTypeOf<[AnyRankedTensor, AnyStridedMemRef]>;
-
 class LinalgOperandOfRank<int rank>: Type<
   And<[
     LinalgOperand.predicate,

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 3b268d703a74..13cca7f19ee7 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1744,7 +1744,7 @@ static LogicalResult verify(linalg::YieldOp op) {
 void TiledLoopOp::build(
     OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
     ValueRange upperBounds, ValueRange steps, ValueRange inputs,
-    ValueRange outputs, ArrayRef<StringRef> iteratorTypes,
+    ValueRange outputs, ArrayAttr iteratorTypes,
     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
   result.addOperands(lowerBounds);
   result.addOperands(upperBounds);
@@ -1758,9 +1758,14 @@ void TiledLoopOp::build(
                                 static_cast<int32_t>(steps.size()),
                                 static_cast<int32_t>(inputs.size()),
                                 static_cast<int32_t>(outputs.size())}));
-  result.addAttribute(getIteratorTypesAttrName(),
-                      builder.getStrArrayAttr(iteratorTypes));
-  result.addTypes(outputs.getTypes());
+  result.addAttribute(getIteratorTypesAttrName(), iteratorTypes);
+
+  // Add output types for `RankedTensorType` output arguments.
+  for (Value output : outputs) {
+    Type outputType = output.getType();
+    if (outputType.isa<RankedTensorType>())
+      result.addTypes(outputType);
+  }
 
   OpBuilder::InsertionGuard guard(builder);
   unsigned numIVs = steps.size();
@@ -1771,8 +1776,8 @@ void TiledLoopOp::build(
   if (bodyBuilderFn) {
     builder.setInsertionPointToStart(bodyBlock);
     bodyBuilderFn(builder, result.location, bodyBlock->getArguments());
+    TiledLoopOp::ensureTerminator(*bodyRegion, builder, result.location);
   }
-  TiledLoopOp::ensureTerminator(*bodyRegion, builder, result.location);
 }
 
 static void print(OpAsmPrinter &p, TiledLoopOp op) {


        


More information about the Mlir-commits mailing list