[Mlir-commits] [mlir] efcf098 - [mlir] [EDSC] Add interface for yield-for loops.
Alex Zinenko
llvmlistbot at llvm.org
Wed Apr 15 09:39:41 PDT 2020
Author: Pierre Oechsel
Date: 2020-04-15T18:39:30+02:00
New Revision: efcf0985eef69127af0e5576f5977b0bb3f1a4a8
URL: https://github.com/llvm/llvm-project/commit/efcf0985eef69127af0e5576f5977b0bb3f1a4a8
DIFF: https://github.com/llvm/llvm-project/commit/efcf0985eef69127af0e5576f5977b0bb3f1a4a8.diff
LOG: [mlir] [EDSC] Add interface for yield-for loops.
Summary:
ModelBuilder was missing an API to easily generate yield-for loops.
This diff implements an interface that allows writing:
```
%2:2 = loop.for %i = %start to %end step %step iter_args(%arg0 = %init0, %arg1 = %init1) -> (f32, f32) {
%sum = addf %arg0, %arg1 : f32
loop.yield %arg1, %sum : f32, f32
}
%3 = addf %2#0, %2#1 : f32
```
as
```
auto results =
LoopNestBuilder(&i, start, end, step, {&arg0, &arg1}, {init0, init1})([&] {
auto sum = arg0 + arg1;
loop_yield(ArrayRef<ValueHandle>{arg1, sum});
});
// Add the two values accumulated by the yield-for-loop:
ValueHandle(results[0]) + ValueHandle(results[1]);
```
Differential Revision: https://reviews.llvm.org/D78093
Added:
mlir/include/mlir/Dialect/LoopOps/EDSC/Intrinsics.h
Modified:
mlir/include/mlir/Dialect/LoopOps/EDSC/Builders.h
mlir/include/mlir/EDSC/Builders.h
mlir/lib/Dialect/LoopOps/EDSC/Builders.cpp
mlir/test/EDSC/builder-api-test.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LoopOps/EDSC/Builders.h b/mlir/include/mlir/Dialect/LoopOps/EDSC/Builders.h
index 41b0e1a972bf..f4b30aa54879 100644
--- a/mlir/include/mlir/Dialect/LoopOps/EDSC/Builders.h
+++ b/mlir/include/mlir/Dialect/LoopOps/EDSC/Builders.h
@@ -33,7 +33,9 @@ LoopBuilder makeParallelLoopBuilder(ArrayRef<ValueHandle *> ivs,
/// variable. A ValueHandle pointer is passed as the first argument and is the
/// *only* way to capture the loop induction variable.
LoopBuilder makeLoopBuilder(ValueHandle *iv, ValueHandle lbHandle,
- ValueHandle ubHandle, ValueHandle stepHandle);
+ ValueHandle ubHandle, ValueHandle stepHandle,
+ ArrayRef<ValueHandle *> iter_args_handles = {},
+ ValueRange iter_args_init_values = {});
/// Helper class to sugar building loop.parallel loop nests from lower/upper
/// bounds and step sizes.
@@ -54,9 +56,13 @@ class ParallelLoopNestBuilder {
/// loop.for.
class LoopNestBuilder {
public:
- LoopNestBuilder(ArrayRef<edsc::ValueHandle *> ivs, ArrayRef<ValueHandle> lbs,
+ LoopNestBuilder(ValueHandle *iv, ValueHandle lb, ValueHandle ub,
+ ValueHandle step,
+ ArrayRef<ValueHandle *> iter_args_handles = {},
+ ValueRange iter_args_init_values = {});
+ LoopNestBuilder(ArrayRef<ValueHandle *> ivs, ArrayRef<ValueHandle> lbs,
ArrayRef<ValueHandle> ubs, ArrayRef<ValueHandle> steps);
- void operator()(std::function<void(void)> fun = nullptr);
+ Operation::result_range operator()(std::function<void(void)> fun = nullptr);
private:
SmallVector<LoopBuilder, 4> loops;
diff --git a/mlir/include/mlir/Dialect/LoopOps/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/LoopOps/EDSC/Intrinsics.h
new file mode 100644
index 000000000000..21803e2bf13b
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LoopOps/EDSC/Intrinsics.h
@@ -0,0 +1,24 @@
+//===- Intrinsics.h - MLIR EDSC Intrinsics for Linalg -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_LOOPOPS_EDSC_INTRINSICS_H_
+#define MLIR_DIALECT_LOOPOPS_EDSC_INTRINSICS_H_
+
+#include "mlir/Dialect/LoopOps/EDSC/Builders.h"
+#include "mlir/EDSC/Intrinsics.h"
+
+namespace mlir {
+namespace edsc {
+namespace intrinsics {
+using loop_yield = OperationBuilder<loop::YieldOp>;
+
+} // namespace intrinsics
+} // namespace edsc
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LOOPOPS_EDSC_INTRINSICS_H_
diff --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h
index c907f2d1ea5e..9c5a8e72c66e 100644
--- a/mlir/include/mlir/EDSC/Builders.h
+++ b/mlir/include/mlir/EDSC/Builders.h
@@ -152,6 +152,8 @@ class LoopBuilder : public NestedBuilder {
/// the evaluation of `fun` (which build IR snippets in a scoped fashion) is
/// scoped within a LoopBuilder.
void operator()(function_ref<void(void)> fun = nullptr);
+ void setOp(Operation *op) { this->op = op; }
+ Operation *getOp() { return op; }
private:
LoopBuilder() = default;
@@ -166,7 +168,10 @@ class LoopBuilder : public NestedBuilder {
ArrayRef<ValueHandle> steps);
friend LoopBuilder makeLoopBuilder(ValueHandle *iv, ValueHandle lbHandle,
ValueHandle ubHandle,
- ValueHandle stepHandle);
+ ValueHandle stepHandle,
+ ArrayRef<ValueHandle *> iter_args_handles,
+ ValueRange iter_args_init_values);
+ Operation *op;
};
// This class exists solely to handle the C++ vexing parse case when
diff --git a/mlir/lib/Dialect/LoopOps/EDSC/Builders.cpp b/mlir/lib/Dialect/LoopOps/EDSC/Builders.cpp
index ad9e1b74ef4d..b7af6635ac5e 100644
--- a/mlir/lib/Dialect/LoopOps/EDSC/Builders.cpp
+++ b/mlir/lib/Dialect/LoopOps/EDSC/Builders.cpp
@@ -45,19 +45,32 @@ mlir::edsc::LoopNestBuilder::LoopNestBuilder(ArrayRef<ValueHandle *> ivs,
assert(ivs.size() == steps.size() &&
"expected size of ivs and steps to match");
loops.reserve(ivs.size());
- for (auto it : llvm::zip(ivs, lbs, ubs, steps)) {
+ for (auto it : llvm::zip(ivs, lbs, ubs, steps))
loops.emplace_back(makeLoopBuilder(std::get<0>(it), std::get<1>(it),
std::get<2>(it), std::get<3>(it)));
- }
assert(loops.size() == ivs.size() && "Mismatch loops vs ivs size");
}
-void mlir::edsc::LoopNestBuilder::LoopNestBuilder::operator()(
+mlir::edsc::LoopNestBuilder::LoopNestBuilder(
+ ValueHandle *iv, ValueHandle lb, ValueHandle ub, ValueHandle step,
+ ArrayRef<ValueHandle *> iter_args_handles,
+ ValueRange iter_args_init_values) {
+ assert(iter_args_init_values.size() == iter_args_handles.size() &&
+ "expected size of arguments and argument_handles to match");
+ loops.emplace_back(makeLoopBuilder(iv, lb, ub, step, iter_args_handles,
+ iter_args_init_values));
+}
+
+Operation::result_range
+mlir::edsc::LoopNestBuilder::LoopNestBuilder::operator()(
std::function<void(void)> fun) {
if (fun)
fun();
+
for (auto &lit : reverse(loops))
lit({});
+
+ return loops[0].getOp()->getResults();
}
LoopBuilder mlir::edsc::makeParallelLoopBuilder(ArrayRef<ValueHandle *> ivs,
@@ -78,15 +91,21 @@ LoopBuilder mlir::edsc::makeParallelLoopBuilder(ArrayRef<ValueHandle *> ivs,
return result;
}
-mlir::edsc::LoopBuilder mlir::edsc::makeLoopBuilder(ValueHandle *iv,
- ValueHandle lbHandle,
- ValueHandle ubHandle,
- ValueHandle stepHandle) {
+mlir::edsc::LoopBuilder
+mlir::edsc::makeLoopBuilder(ValueHandle *iv, ValueHandle lbHandle,
+ ValueHandle ubHandle, ValueHandle stepHandle,
+ ArrayRef<ValueHandle *> iter_args_handles,
+ ValueRange iter_args_init_values) {
mlir::edsc::LoopBuilder result;
- auto forOp =
- OperationHandle::createOp<loop::ForOp>(lbHandle, ubHandle, stepHandle);
+ auto forOp = OperationHandle::createOp<loop::ForOp>(
+ lbHandle, ubHandle, stepHandle, iter_args_init_values);
*iv = ValueHandle(forOp.getInductionVar());
auto *body = loop::getForInductionVarOwner(iv->getValue()).getBody();
+ for (size_t i = 0, e = iter_args_handles.size(); i < e; ++i) {
+ // Skipping the induction variable.
+ *(iter_args_handles[i]) = ValueHandle(body->getArgument(i + 1));
+ }
+ result.setOp(forOp);
result.enter(body, /*prev=*/1);
return result;
}
diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp
index dd90511a0472..594040d60ae1 100644
--- a/mlir/test/EDSC/builder-api-test.cpp
+++ b/mlir/test/EDSC/builder-api-test.cpp
@@ -10,7 +10,7 @@
#include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
-#include "mlir/Dialect/LoopOps/EDSC/Builders.h"
+#include "mlir/Dialect/LoopOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
#include "mlir/EDSC/Builders.h"
@@ -1074,6 +1074,44 @@ TEST_FUNC(memref_vector_matmul_test) {
f.erase();
}
+TEST_FUNC(builder_loop_for_yield) {
+ auto indexType = IndexType::get(&globalContext());
+ auto f32Type = FloatType::getF32(&globalContext());
+ auto f = makeFunction("builder_loop_for_yield", {},
+ {indexType, indexType, indexType, indexType});
+
+ OpBuilder builder(f.getBody());
+ ScopedContext scope(builder, f.getLoc());
+ ValueHandle init0 = std_constant_float(llvm::APFloat(1.0f), f32Type);
+ ValueHandle init1 = std_constant_float(llvm::APFloat(2.0f), f32Type);
+ ValueHandle i(indexType), a(f.getArgument(0)), b(f.getArgument(1)),
+ c(f.getArgument(2)), d(f.getArgument(3));
+ ValueHandle arg0(f32Type);
+ ValueHandle arg1(f32Type);
+ using namespace edsc::op;
+ auto results =
+ LoopNestBuilder(&i, a - b, c + d, a, {&arg0, &arg1}, {init0, init1})([&] {
+ auto sum = arg0 + arg1;
+ loop_yield(ArrayRef<ValueHandle>{arg1, sum});
+ });
+ ValueHandle(results[0]) + ValueHandle(results[1]);
+
+ // clang-format off
+ // CHECK-LABEL: func @builder_loop_for_yield(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
+ // CHECK: [[init0:%.*]] = constant
+ // CHECK: [[init1:%.*]] = constant
+ // CHECK-DAG: [[r0:%[0-9]+]] = affine.apply affine_map<()[s0, s1] -> (s0 - s1)>()[%{{.*}}, %{{.*}}]
+ // CHECK-DAG: [[r1:%[0-9]+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%{{.*}}, %{{.*}}]
+ // CHECK-NEXT: [[res:%[0-9]+]]:2 = loop.for %{{.*}} = [[r0]] to [[r1]] step {{.*}} iter_args([[arg0:%.*]] = [[init0]], [[arg1:%.*]] = [[init1]]) -> (f32, f32) {
+ // CHECK: [[sum:%[0-9]+]] = addf [[arg0]], [[arg1]] : f32
+ // CHECK: loop.yield [[arg1]], [[sum]] : f32, f32
+ // CHECK: addf [[res]]#0, [[res]]#1 : f32
+
+ // clang-format on
+ f.print(llvm::outs());
+ f.erase();
+}
+
int main() {
RUN_TESTS();
return 0;
More information about the Mlir-commits
mailing list