[Mlir-commits] [mlir] 7377ef9 - [mlir] Add a builder to `linalg.tiled_loop`.

Alexander Belyaev llvmlistbot at llvm.org
Wed Feb 24 05:47:47 PST 2021


Author: Alexander Belyaev
Date: 2021-02-24T14:47:27+01:00
New Revision: 7377ef9357191a2c540ba0c20375a9f92233e0f6

URL: https://github.com/llvm/llvm-project/commit/7377ef9357191a2c540ba0c20375a9f92233e0f6
DIFF: https://github.com/llvm/llvm-project/commit/7377ef9357191a2c540ba0c20375a9f92233e0f6.diff

LOG: [mlir] Add a builder to `linalg.tiled_loop`.

https://llvm.discourse.group/t/rfc-add-linalg-tileop/2833

Differential Revision: https://reviews.llvm.org/D97372

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 067cf95e1e39..a295cdc591da 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -484,6 +484,7 @@ def Linalg_YieldOp : Linalg_Op<"yield", [NoSideEffect, ReturnLike, Terminator]>,
     linalg.yield %f0, %f1 : f32, f32
     ```
   }];
+  let builders = [OpBuilderDAG<(ins), [{ /* No operands: builds an empty yield. */ }]>];
 }
 
 def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
@@ -537,6 +538,21 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
                        ArrayAttr:$iterator_types);
   let results = (outs Variadic<AnyRankedTensor>:$results);
   let regions = (region SizedRegion<1>:$region);
+  // Builder taking loop bounds/steps, inputs, outputs and iterator types.
+  let builders = [
+    OpBuilderDAG<(ins "ValueRange":$lowerBounds, "ValueRange":$upperBounds,
+      "ValueRange":$steps, "ValueRange":$inputs, "ValueRange":$outputs,
+      "ArrayRef<StringRef>":$iteratorTypes,
+      CArg<"function_ref<void (OpBuilder &, Location, ValueRange)>",
+           "nullptr">:$bodyBuilderFn)>, // Optional body callback; defaults to nullptr.
+  ];
+  // C++ helpers: induction vars are the body block's args; one loop per step.
+  let extraClassDeclaration = [{
+    ValueRange getInductionVars() {
+      return getBody()->getArguments();
+    }
+    unsigned getNumLoops() { return step().size(); }  // One loop per step operand.
+  }];
 }
 
 

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index f6e2994b9718..68b5e0654a5e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1701,6 +1701,58 @@ static LogicalResult verify(linalg::YieldOp op) {
 // TiledLoopOp
 //===----------------------------------------------------------------------===//
 
+/// Builds a `linalg.tiled_loop` with one loop per entry of `steps`.
+///
+/// Operands are appended in the order lower bounds, upper bounds, steps,
+/// inputs, outputs, and the operand-segment-size attribute records the size
+/// of each group. The body block gets one `index` argument per loop (the
+/// induction variables). If `bodyBuilderFn` is non-null it is called with the
+/// insertion point at the start of the body and those induction variables;
+/// afterwards `ensureTerminator` ensures the body ends with its terminator.
+void TiledLoopOp::build(
+    OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
+    ValueRange upperBounds, ValueRange steps, ValueRange inputs,
+    ValueRange outputs, ArrayRef<StringRef> iteratorTypes,
+    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
+  assert(lowerBounds.size() == upperBounds.size() &&
+         upperBounds.size() == steps.size() &&
+         "expected one lower bound, upper bound and step per loop");
+  result.addOperands(lowerBounds);
+  result.addOperands(upperBounds);
+  result.addOperands(steps);
+  result.addOperands(inputs);
+  result.addOperands(outputs);
+  result.addAttribute(
+      TiledLoopOp::getOperandSegmentSizeAttr(),
+      builder.getI32VectorAttr({static_cast<int32_t>(lowerBounds.size()),
+                                static_cast<int32_t>(upperBounds.size()),
+                                static_cast<int32_t>(steps.size()),
+                                static_cast<int32_t>(inputs.size()),
+                                static_cast<int32_t>(outputs.size())}));
+  result.addAttribute(getIteratorTypesAttrName(),
+                      builder.getStrArrayAttr(iteratorTypes));
+
+  // Results are constrained to ranked tensors, so only ranked-tensor outputs
+  // contribute a result type. Forwarding every output type would produce
+  // non-tensor (e.g. memref) results that violate the op's result constraint.
+  for (Type outputType : outputs.getTypes())
+    if (outputType.isa<RankedTensorType>())
+      result.addTypes(outputType);
+
+  // Restore the caller's insertion point once the body has been built.
+  OpBuilder::InsertionGuard guard(builder);
+  unsigned numIVs = steps.size();
+  SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
+  Region *bodyRegion = result.addRegion();
+  Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes);
+
+  if (bodyBuilderFn) {
+    builder.setInsertionPointToStart(bodyBlock);
+    bodyBuilderFn(builder, result.location, bodyBlock->getArguments());
+  }
+  TiledLoopOp::ensureTerminator(*bodyRegion, builder, result.location);
+}
+
 static void print(OpAsmPrinter &p, TiledLoopOp op) {
   p << op.getOperationName() << " (" << op.getBody()->getArguments() << ") = ("
     << op.lowerBound() << ") to (" << op.upperBound() << ") step (" << op.step()


        


More information about the Mlir-commits mailing list