[Mlir-commits] [mlir] 2ea6e13 - [mlir] Add an optional distributionTypes attribute to TiledLoopOp.

Alexander Belyaev llvmlistbot at llvm.org
Tue May 25 11:05:16 PDT 2021


Author: Alexander Belyaev
Date: 2021-05-25T20:04:41+02:00
New Revision: 2ea6e13bf8189efb09cec89184b21f1db3de0d1c

URL: https://github.com/llvm/llvm-project/commit/2ea6e13bf8189efb09cec89184b21f1db3de0d1c
DIFF: https://github.com/llvm/llvm-project/commit/2ea6e13bf8189efb09cec89184b21f1db3de0d1c.diff

LOG: [mlir] Add an optional distributionTypes attribute to TiledLoopOp.

Differential Revision: https://reviews.llvm.org/D103104

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
    mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/roundtrip.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index c78624b70b7b5..6c46c0abfcb46 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -521,7 +521,8 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
     %0 = linalg.tiled_loop (%i) = (%c0) to (%c24) step (%c4)
         ins(%lhs, %rhs : tensor<24x64xi8>, tensor<24x64xi8>)
         outs(%out : tensor<24x64xi8>)
-        iterators("parallel") {
+        iterators["parallel"]
+        distribution["block_x"] {
       %lhs_sub = subtensor %lhs[%i, 0] [%c4, %c64] [1, 1]
           : tensor<24x64xi8> to tensor<?x?xi8>
       %rhs_sub = subtensor %rhs[%i, 0] [%c4, %c64] [1, 1]
@@ -551,7 +552,8 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
     linalg.tiled_loop (%i) = (%c0) to (%c24) step (%c4)
         ins(%lhs, %rhs : memref<24x64xi8>, memref<24x64xi8>)
         outs(%out : memref<24x64xi8>)
-        iterators("parallel") {
+        iterators["parallel"]
+        distribution["block_x"] {
       %lhs_sub = subview %lhs[%i, 0] [%c4, %c64] [1, 1]
           : memref<24x64xi8> to memref<?x?xi8>
       %rhs_sub = subview %rhs[%i, 0] [%c4, %c64] [1, 1]
@@ -570,11 +572,18 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
                        Variadic<Index>:$step,
                        Variadic<LinalgOperand>:$inputs,
                        Variadic<LinalgOperand>:$outputs,
-                       ArrayAttr:$iterator_types);
+                       ArrayAttr:$iterator_types,
+                       OptionalAttr<ArrayAttr>:$distribution_types);
   let results = (outs Variadic<AnyRankedTensor>:$results);
   let regions = (region SizedRegion<1>:$region);
 
   let builders = [
+    OpBuilder<(ins "ValueRange":$lowerBounds, "ValueRange":$upperBounds,
+      "ValueRange":$steps, "ValueRange":$inputs, "ValueRange":$outputs,
+      "ArrayAttr":$iteratorTypes, "Optional<ArrayAttr>":$distributionTypes,
+      CArg<"function_ref<void (OpBuilder &, Location, /*ivs=*/ValueRange,"
+        "/*inputs=*/ValueRange, /*outputs=*/ValueRange)>",
+        "nullptr">:$bodyBuilderFn)>,
     OpBuilder<(ins "ValueRange":$lowerBounds, "ValueRange":$upperBounds,
       "ValueRange":$steps, "ValueRange":$inputs, "ValueRange":$outputs,
       "ArrayAttr":$iteratorTypes,

diff  --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
index d6ccea1df7a4e..14f8b29f66892 100644
--- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
@@ -50,6 +50,12 @@ constexpr StringRef getIndexingMapsAttrName() { return "indexing_maps"; }
 /// op's iterators.
 constexpr StringRef getIteratorTypesAttrName() { return "iterator_types"; }
 
+/// Attribute name for the StrArrayAttr which encodes the distribution type for
+/// `linalg.tiled_loop`.
+constexpr StringRef getDistributionTypesAttrName() {
+  return "distribution_types";
+}
+
 /// Attribute name for the StringAttr which encodes an optional documentation
 /// string of the structured op.
 constexpr StringRef getDocAttrName() { return "doc"; }

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index f3de81801e672..034bf188d175e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2075,6 +2075,18 @@ void TiledLoopOp::build(OpBuilder &builder, OperationState &result,
                         function_ref<void(OpBuilder &, Location, ValueRange,
                                           ValueRange, ValueRange)>
                             bodyBuilderFn) {
+  build(builder, result, lowerBounds, upperBounds, steps, inputs, outputs,
+        iteratorTypes, llvm::None, bodyBuilderFn);
+}
+
+void TiledLoopOp::build(OpBuilder &builder, OperationState &result,
+                        ValueRange lowerBounds, ValueRange upperBounds,
+                        ValueRange steps, ValueRange inputs, ValueRange outputs,
+                        ArrayAttr iteratorTypes,
+                        Optional<ArrayAttr> distributionTypes,
+                        function_ref<void(OpBuilder &, Location, ValueRange,
+                                          ValueRange, ValueRange)>
+                            bodyBuilderFn) {
   result.addOperands(lowerBounds);
   result.addOperands(upperBounds);
   result.addOperands(steps);
@@ -2089,6 +2101,10 @@ void TiledLoopOp::build(OpBuilder &builder, OperationState &result,
                                 static_cast<int32_t>(outputs.size())}));
   result.addAttribute(getIteratorTypesAttrName(), iteratorTypes);
 
+  if (distributionTypes.hasValue())
+    result.addAttribute(getDistributionTypesAttrName(),
+                        distributionTypes.getValue());
+
   // Add output types for `RankedTensorType` output arguments.
   for (Value output : outputs) {
     Type outputType = output.getType();
@@ -2143,14 +2159,17 @@ static void print(OpAsmPrinter &p, TiledLoopOp op) {
   if (llvm::any_of(op.iterator_types(), [](Attribute attr) {
         return attr.cast<StringAttr>().getValue() !=
                getParallelIteratorTypeName();
-      })) {
+      }))
     p << " iterators" << op.iterator_types() << "";
-  }
+
+  if (op.distribution_types().hasValue())
+    p << " distribution" << op.distribution_types().getValue() << "";
 
   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
   p.printOptionalAttrDict(
       op->getAttrs(), /*elidedAttrs=*/{TiledLoopOp::getOperandSegmentSizeAttr(),
-                                       getIteratorTypesAttrName()});
+                                       getIteratorTypesAttrName(),
+                                       getDistributionTypesAttrName()});
 }
 
 static ParseResult parseTiledLoopOp(OpAsmParser &parser,
@@ -2219,26 +2238,38 @@ static ParseResult parseTiledLoopOp(OpAsmParser &parser,
   }
 
   // Parse attributes.
-  SmallVector<Attribute, 4> iterTypes;
-  if (succeeded(parser.parseOptionalKeyword("iterators"))) {
-    StringAttr iterType;
+  SmallVector<Attribute, 4> iterTypes, distributionTypes;
+  auto parseAttr = [&](StringRef keyword, SmallVector<Attribute, 4> *attrs) {
+    if (succeeded(parser.parseOptionalKeyword(keyword))) {
+      StringAttr attr;
 
-    if (parser.parseLSquare() || parser.parseAttribute(iterType))
-      return failure();
-    iterTypes.push_back(iterType);
-    for (int i = 1, e = ivs.size(); i < e; ++i) {
-      if (parser.parseComma() || parser.parseAttribute(iterType))
+      if (parser.parseLSquare() || parser.parseAttribute(attr))
+        return failure();
+      attrs->push_back(attr);
+      for (int i = 1, e = ivs.size(); i < e; ++i) {
+        if (parser.parseComma() || parser.parseAttribute(attr))
+          return failure();
+        attrs->push_back(attr);
+      }
+      if (parser.parseRSquare())
         return failure();
-      iterTypes.push_back(iterType);
     }
-    if (parser.parseRSquare())
-      return failure();
-  } else {
+    return success();
+  };
+  if (failed(parseAttr("iterators", &iterTypes)) ||
+      failed(parseAttr("distribution", &distributionTypes)))
+    return failure();
+
+  // Set all loop iterator types to "parallel" if they are not printed in IR.
+  if (iterTypes.empty()) {
     auto parallelIter = builder.getStringAttr(getParallelIteratorTypeName());
     iterTypes = SmallVector<Attribute, 4>(ivs.size(), parallelIter);
   }
   result.addAttribute(getIteratorTypesAttrName(),
                       builder.getArrayAttr(iterTypes));
+  if (!distributionTypes.empty())
+    result.addAttribute(getDistributionTypesAttrName(),
+                        builder.getArrayAttr(distributionTypes));
   result.addAttribute(
       TiledLoopOp::getOperandSegmentSizeAttr(),
       builder.getI32VectorAttr({static_cast<int32_t>(lower.size()),
@@ -2352,7 +2383,8 @@ struct TiledLoopInputsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
     Location loc = tiledLoop.getLoc();
     auto newTiledLoop = rewriter.create<TiledLoopOp>(
         loc, tiledLoop.lowerBound(), tiledLoop.upperBound(), tiledLoop.step(),
-        newInputs, tiledLoop.outputs(), tiledLoop.iterator_types());
+        newInputs, tiledLoop.outputs(), tiledLoop.iterator_types(),
+        tiledLoop.distribution_types());
 
     // Clone the region.
     BlockAndValueMapping bvm;
@@ -2441,7 +2473,8 @@ struct TiledLoopResultsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
     Location loc = tiledLoop.getLoc();
     auto newTiledLoop = rewriter.create<TiledLoopOp>(
         loc, tiledLoop.lowerBound(), tiledLoop.upperBound(), tiledLoop.step(),
-        tiledLoop.inputs(), newOutputOperands, tiledLoop.iterator_types());
+        tiledLoop.inputs(), newOutputOperands, tiledLoop.iterator_types(),
+        tiledLoop.distribution_types());
 
     // Clone the region.
     BlockAndValueMapping bvm;

diff  --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 1435f7b6bdb1b..211e97855c662 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -827,7 +827,8 @@ func @tiled_loop_reduction(%input_3d: tensor<16x24x32xf32>,
           %i2d_ = %input_2d: tensor<16x32xf32>,
           %i1d_ = %input_1d: tensor<24xf32>)
       outs(%o_ =  %output: tensor<24xf32>)
-      iterators["reduction", "parallel", "reduction"] {
+      iterators["reduction", "parallel", "reduction"]
+      distribution["block_x", "block_y", "none"] {
     %sub_3d = subtensor %i3d_[%i, %j, %k][2, 4, 8][1, 1, 1]
       : tensor<16x24x32xf32> to tensor<2x4x8xf32>
     %sub_2d = subtensor %i2d_[%i, %k][2, 8][1, 1]


        


More information about the Mlir-commits mailing list