[Mlir-commits] [mlir] efcf098 - [mlir] [EDSC] Add interface for yield-for loops.

Alex Zinenko llvmlistbot at llvm.org
Wed Apr 15 09:39:41 PDT 2020


Author: Pierre Oechsel
Date: 2020-04-15T18:39:30+02:00
New Revision: efcf0985eef69127af0e5576f5977b0bb3f1a4a8

URL: https://github.com/llvm/llvm-project/commit/efcf0985eef69127af0e5576f5977b0bb3f1a4a8
DIFF: https://github.com/llvm/llvm-project/commit/efcf0985eef69127af0e5576f5977b0bb3f1a4a8.diff

LOG: [mlir] [EDSC] Add interface for yield-for loops.

Summary:
ModelBuilder was missing an API to easily generate yield-for loops.
This diff implements an interface that allows writing:
```
%2:2 = loop.for %i = %start to %end step %step iter_args(%arg0 = %init0, %arg1 = %init1) -> (f32, f32) {
  %sum = addf %arg0, %arg1 : f32
  loop.yield %arg1, %sum : f32, f32
}
%3 = addf %2#0, %2#1 : f32
```

as

```
auto results =
    LoopNestBuilder(&i, start, end, step, {&arg0, &arg1},  {init0, init1})([&] {
      auto sum = arg0 + arg1;
      loop_yield(ArrayRef<ValueHandle>{arg1, sum});
    });

// Add the two values accumulated by the yield-for-loop:
ValueHandle(results[0]) + ValueHandle(results[1]);
```

Differential Revision: https://reviews.llvm.org/D78093

Added: 
    mlir/include/mlir/Dialect/LoopOps/EDSC/Intrinsics.h

Modified: 
    mlir/include/mlir/Dialect/LoopOps/EDSC/Builders.h
    mlir/include/mlir/EDSC/Builders.h
    mlir/lib/Dialect/LoopOps/EDSC/Builders.cpp
    mlir/test/EDSC/builder-api-test.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LoopOps/EDSC/Builders.h b/mlir/include/mlir/Dialect/LoopOps/EDSC/Builders.h
index 41b0e1a972bf..f4b30aa54879 100644
--- a/mlir/include/mlir/Dialect/LoopOps/EDSC/Builders.h
+++ b/mlir/include/mlir/Dialect/LoopOps/EDSC/Builders.h
@@ -33,7 +33,11 @@ LoopBuilder makeParallelLoopBuilder(ArrayRef<ValueHandle *> ivs,
 /// variable. A ValueHandle pointer is passed as the first argument and is the
 /// *only* way to capture the loop induction variable.
+/// Each handle in `iter_args_handles` is bound to the matching loop-carried
+/// block argument; `iter_args_init_values` provides their initial values.
 LoopBuilder makeLoopBuilder(ValueHandle *iv, ValueHandle lbHandle,
-                            ValueHandle ubHandle, ValueHandle stepHandle);
+                            ValueHandle ubHandle, ValueHandle stepHandle,
+                            ArrayRef<ValueHandle *> iter_args_handles = {},
+                            ValueRange iter_args_init_values = {});
 
 /// Helper class to sugar building loop.parallel loop nests from lower/upper
 /// bounds and step sizes.
@@ -54,9 +58,17 @@ class ParallelLoopNestBuilder {
 /// loop.for.
 class LoopNestBuilder {
 public:
-  LoopNestBuilder(ArrayRef<edsc::ValueHandle *> ivs, ArrayRef<ValueHandle> lbs,
+  /// Builds a single loop.for. Each handle in `iter_args_handles` is bound to
+  /// the matching loop-carried value, initialized from `iter_args_init_values`.
+  LoopNestBuilder(ValueHandle *iv, ValueHandle lb, ValueHandle ub,
+                  ValueHandle step,
+                  ArrayRef<ValueHandle *> iter_args_handles = {},
+                  ValueRange iter_args_init_values = {});
+  LoopNestBuilder(ArrayRef<ValueHandle *> ivs, ArrayRef<ValueHandle> lbs,
                   ArrayRef<ValueHandle> ubs, ArrayRef<ValueHandle> steps);
-  void operator()(std::function<void(void)> fun = nullptr);
+  /// Builds the body with `fun`, finalizes the loops and returns the results
+  /// of the first (outermost) loop.
+  Operation::result_range operator()(std::function<void(void)> fun = nullptr);
 
 private:
   SmallVector<LoopBuilder, 4> loops;

diff  --git a/mlir/include/mlir/Dialect/LoopOps/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/LoopOps/EDSC/Intrinsics.h
new file mode 100644
index 000000000000..21803e2bf13b
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LoopOps/EDSC/Intrinsics.h
@@ -0,0 +1,26 @@
+//===- Intrinsics.h - MLIR EDSC Intrinsics for LoopOps ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_LOOPOPS_EDSC_INTRINSICS_H_
+#define MLIR_DIALECT_LOOPOPS_EDSC_INTRINSICS_H_
+
+#include "mlir/Dialect/LoopOps/EDSC/Builders.h"
+#include "mlir/EDSC/Intrinsics.h"
+
+namespace mlir {
+namespace edsc {
+namespace intrinsics {
+/// Builds a `loop.yield` operation with the given operands, e.g.
+/// `loop_yield(ArrayRef<ValueHandle>{arg1, sum});`.
+using loop_yield = OperationBuilder<loop::YieldOp>;
+
+} // namespace intrinsics
+} // namespace edsc
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LOOPOPS_EDSC_INTRINSICS_H_

diff  --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h
index c907f2d1ea5e..9c5a8e72c66e 100644
--- a/mlir/include/mlir/EDSC/Builders.h
+++ b/mlir/include/mlir/EDSC/Builders.h
@@ -152,6 +152,11 @@ class LoopBuilder : public NestedBuilder {
   /// the evaluation of `fun` (which build IR snippets in a scoped fashion) is
   /// scoped within a LoopBuilder.
   void operator()(function_ref<void(void)> fun = nullptr);
+  /// Records the operation built for this loop (makeLoopBuilder stores the
+  /// loop.for it creates) so that its results can be queried afterwards.
+  void setOp(Operation *op) { this->op = op; }
+  /// Returns the recorded operation, or null if setOp was never called.
+  Operation *getOp() { return op; }
 
 private:
   LoopBuilder() = default;
@@ -166,7 +171,11 @@ class LoopBuilder : public NestedBuilder {
                                              ArrayRef<ValueHandle> steps);
   friend LoopBuilder makeLoopBuilder(ValueHandle *iv, ValueHandle lbHandle,
                                      ValueHandle ubHandle,
-                                     ValueHandle stepHandle);
+                                     ValueHandle stepHandle,
+                                     ArrayRef<ValueHandle *> iter_args_handles,
+                                     ValueRange iter_args_init_values);
+  /// Operation built by this builder; null until setOp is called.
+  Operation *op = nullptr;
 };
 
 // This class exists solely to handle the C++ vexing parse case when

diff  --git a/mlir/lib/Dialect/LoopOps/EDSC/Builders.cpp b/mlir/lib/Dialect/LoopOps/EDSC/Builders.cpp
index ad9e1b74ef4d..b7af6635ac5e 100644
--- a/mlir/lib/Dialect/LoopOps/EDSC/Builders.cpp
+++ b/mlir/lib/Dialect/LoopOps/EDSC/Builders.cpp
@@ -45,19 +45,35 @@ mlir::edsc::LoopNestBuilder::LoopNestBuilder(ArrayRef<ValueHandle *> ivs,
   assert(ivs.size() == steps.size() &&
          "expected size of ivs and steps to match");
   loops.reserve(ivs.size());
-  for (auto it : llvm::zip(ivs, lbs, ubs, steps)) {
+  for (auto it : llvm::zip(ivs, lbs, ubs, steps))
     loops.emplace_back(makeLoopBuilder(std::get<0>(it), std::get<1>(it),
                                        std::get<2>(it), std::get<3>(it)));
-  }
   assert(loops.size() == ivs.size() && "Mismatch loops vs ivs size");
 }
 
-void mlir::edsc::LoopNestBuilder::LoopNestBuilder::operator()(
+mlir::edsc::LoopNestBuilder::LoopNestBuilder(
+    ValueHandle *iv, ValueHandle lb, ValueHandle ub, ValueHandle step,
+    ArrayRef<ValueHandle *> iter_args_handles,
+    ValueRange iter_args_init_values) {
+  assert(iter_args_init_values.size() == iter_args_handles.size() &&
+         "expected size of arguments and argument_handles to match");
+  loops.emplace_back(makeLoopBuilder(iv, lb, ub, step, iter_args_handles,
+                                     iter_args_init_values));
+}
+
+Operation::result_range
+mlir::edsc::LoopNestBuilder::operator()(
     std::function<void(void)> fun) {
   if (fun)
     fun();
+
   for (auto &lit : reverse(loops))
     lit({});
+
+  // Only the outermost loop's results escape the nest; an empty nest has no
+  // such loop, so assert rather than index past the end of `loops`.
+  assert(!loops.empty() && "expected at least one loop to return results");
+  return loops[0].getOp()->getResults();
 }
 
 LoopBuilder mlir::edsc::makeParallelLoopBuilder(ArrayRef<ValueHandle *> ivs,
@@ -78,15 +94,25 @@ LoopBuilder mlir::edsc::makeParallelLoopBuilder(ArrayRef<ValueHandle *> ivs,
   return result;
 }
 
-mlir::edsc::LoopBuilder mlir::edsc::makeLoopBuilder(ValueHandle *iv,
-                                                    ValueHandle lbHandle,
-                                                    ValueHandle ubHandle,
-                                                    ValueHandle stepHandle) {
+mlir::edsc::LoopBuilder
+mlir::edsc::makeLoopBuilder(ValueHandle *iv, ValueHandle lbHandle,
+                            ValueHandle ubHandle, ValueHandle stepHandle,
+                            ArrayRef<ValueHandle *> iter_args_handles,
+                            ValueRange iter_args_init_values) {
+  // Handle i is bound to body argument i + 1 (argument 0 is the induction
+  // variable), so require exactly one initial value per handle.
+  assert(iter_args_handles.size() == iter_args_init_values.size() &&
+         "expected as many iter_args handles as initial values");
   mlir::edsc::LoopBuilder result;
-  auto forOp =
-      OperationHandle::createOp<loop::ForOp>(lbHandle, ubHandle, stepHandle);
+  auto forOp = OperationHandle::createOp<loop::ForOp>(
+      lbHandle, ubHandle, stepHandle, iter_args_init_values);
   *iv = ValueHandle(forOp.getInductionVar());
   auto *body = loop::getForInductionVarOwner(iv->getValue()).getBody();
+  for (size_t i = 0, e = iter_args_handles.size(); i < e; ++i) {
+    // Skipping the induction variable.
+    *(iter_args_handles[i]) = ValueHandle(body->getArgument(i + 1));
+  }
+  result.setOp(forOp);
   result.enter(body, /*prev=*/1);
   return result;
 }

diff  --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp
index dd90511a0472..594040d60ae1 100644
--- a/mlir/test/EDSC/builder-api-test.cpp
+++ b/mlir/test/EDSC/builder-api-test.cpp
@@ -10,7 +10,7 @@
 
 #include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
 #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
-#include "mlir/Dialect/LoopOps/EDSC/Builders.h"
+#include "mlir/Dialect/LoopOps/EDSC/Intrinsics.h"
 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
 #include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
 #include "mlir/EDSC/Builders.h"
@@ -1074,6 +1074,46 @@ TEST_FUNC(memref_vector_matmul_test) {
   f.erase();
 }
 
+// Builds a loop.for with two iter_args through LoopNestBuilder, yields them
+// back via loop_yield and adds the two loop results after the loop.
+TEST_FUNC(builder_loop_for_yield) {
+  auto indexType = IndexType::get(&globalContext());
+  auto f32Type = FloatType::getF32(&globalContext());
+  auto f = makeFunction("builder_loop_for_yield", {},
+                        {indexType, indexType, indexType, indexType});
+
+  OpBuilder builder(f.getBody());
+  ScopedContext scope(builder, f.getLoc());
+  ValueHandle init0 = std_constant_float(llvm::APFloat(1.0f), f32Type);
+  ValueHandle init1 = std_constant_float(llvm::APFloat(2.0f), f32Type);
+  ValueHandle i(indexType), a(f.getArgument(0)), b(f.getArgument(1)),
+      c(f.getArgument(2)), d(f.getArgument(3));
+  ValueHandle arg0(f32Type);
+  ValueHandle arg1(f32Type);
+  using namespace edsc::op;
+  auto results =
+      LoopNestBuilder(&i, a - b, c + d, a, {&arg0, &arg1}, {init0, init1})([&] {
+        auto sum = arg0 + arg1;
+        loop_yield(ArrayRef<ValueHandle>{arg1, sum});
+      });
+  ValueHandle(results[0]) + ValueHandle(results[1]);
+
+  // clang-format off
+  // CHECK-LABEL: func @builder_loop_for_yield(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
+  // CHECK:     [[init0:%.*]] = constant
+  // CHECK:     [[init1:%.*]] = constant
+  // CHECK-DAG:    [[r0:%[0-9]+]] = affine.apply affine_map<()[s0, s1] -> (s0 - s1)>()[%{{.*}}, %{{.*}}]
+  // CHECK-DAG:    [[r1:%[0-9]+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%{{.*}}, %{{.*}}]
+  // CHECK-NEXT: [[res:%[0-9]+]]:2 = loop.for %{{.*}} = [[r0]] to [[r1]] step {{.*}} iter_args([[arg0:%.*]] = [[init0]], [[arg1:%.*]] = [[init1]]) -> (f32, f32) {
+  // CHECK:     [[sum:%[0-9]+]] = addf [[arg0]], [[arg1]] : f32
+  // CHECK:     loop.yield [[arg1]], [[sum]] : f32, f32
+  // CHECK:     addf [[res]]#0, [[res]]#1 : f32
+
+  // clang-format on
+  f.print(llvm::outs());
+  f.erase();
+}
+
 int main() {
   RUN_TESTS();
   return 0;


        


More information about the Mlir-commits mailing list