[Mlir-commits] [mlir] [mlir][scf] Add a scope op to the scf dialect (PR #89274)
Gil Rapaport
llvmlistbot at llvm.org
Thu Apr 18 10:37:36 PDT 2024
https://github.com/aniragil created https://github.com/llvm/llvm-project/pull/89274
Add to the scf dialect an operation modeling an isolated-from-above
single basic block that is executed once. It provides a localized,
hierarchical alternative to outlining code into a function/call pair.
From 57b873eaf7f21038e13dc8b2c64a6dc9209c56d1 Mon Sep 17 00:00:00 2001
From: Gil Rapaport <gil.rapaport at mobileye.com>
Date: Tue, 16 Apr 2024 17:58:18 +0300
Subject: [PATCH] [mlir][scf] Add a scope op to the scf dialect
Add to the scf dialect an operation modeling an isolated-from-above
single basic block that is executed once. It provides a localized,
hierarchical alternative to outlining code into a function/call pair.
---
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 37 +++++++++++++++++++++-
mlir/lib/Dialect/SCF/IR/SCF.cpp | 19 +++++++++++
mlir/test/Dialect/SCF/invalid.mlir | 27 +++++++++++++++-
mlir/test/Dialect/SCF/ops.mlir | 25 +++++++++++++++
4 files changed, 106 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index b3d085bfff1af9..6c04fe4cdd651d 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -931,6 +931,41 @@ def ReduceReturnOp :
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// ScopeOp
+//===----------------------------------------------------------------------===//
+
+def ScopeOp : SCF_Op<"scope",
+ [AutomaticAllocationScope,
+ RecursiveMemoryEffects,
+ IsolatedFromAbove,
+ SingleBlockImplicitTerminator<"scf::YieldOp">]> {
+ let summary = "isolated code scope";
+ let description = [{
+ The 'scope' op encapsulates computations by providing an isolated-from-above,
+ executed-once single basic block. The op takes any number of operands, and
+ its return values are defined by its terminating `scf.yield`. For example:
+
+ ```mlir
+ %p:2 = scf.scope %arg0, %c77, %arg1, %arg2 : (i32, i32, f32, f32) -> (i32, f32) {
+ ^bb0(%a : i32, %b : i32, %c : f32, %d : f32):
+ %add = arith.addi %a, %b : i32
+ %mul = arith.mulf %c, %d : f32
+ scf.yield %add, %mul : i32, f32
+ }
+ ```
+ }];
+
+ let arguments = (ins Variadic<AnyType>:$operands);
+ let results = (outs Variadic<AnyType>:$results);
+ let regions = (region SizedRegion<1>:$body); // exactly one block: an empty region has no `scf.yield` to define the results
+
+ let assemblyFormat = [{
+ $operands attr-dict `:` functional-type($operands, $results) $body
+ }];
+ let hasVerifier = 1;
+}
+
//===----------------------------------------------------------------------===//
// WhileOp
//===----------------------------------------------------------------------===//
@@ -1155,7 +1190,7 @@ def IndexSwitchOp : SCF_Op<"index_switch", [RecursiveMemoryEffects,
def YieldOp : SCF_Op<"yield", [Pure, ReturnLike, Terminator,
ParentOneOf<["ExecuteRegionOp", "ForOp", "IfOp", "IndexSwitchOp",
- "WhileOp"]>]> {
+ "ScopeOp", "WhileOp"]>]> {
let summary = "loop yield and termination operation";
let description = [{
"scf.yield" yields an SSA value from the SCF dialect op region and
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 5bca8e85f889d9..6a8d28afdbc251 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -3079,6 +3079,25 @@ LogicalResult ReduceReturnOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// ScopeOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult scf::ScopeOp::verify() {
+ // Generic-form IR may carry an empty region; front() would be invalid then.
+ if (getBody().empty())
+ return emitOpError("expects a non-empty region");
+ Operation *terminator = getBody().front().getTerminator();
+ if (terminator->getOperands().getTypes() == getResults().getTypes())
+ return success();
+ // Report on the op itself and point at the offending terminator in a note.
+ InFlightDiagnostic diag = emitOpError()
+ << "expects terminator operands to have the "
+ "same type as results of the operation";
+ diag.attachNote(terminator->getLoc()) << "terminator";
+ return diag;
+}
+
//===----------------------------------------------------------------------===//
// WhileOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index 337eb9eeb8fa57..6871881a49458c 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -476,7 +476,7 @@ func.func @parallel_invalid_yield(
func.func @yield_invalid_parent_op() {
"my.op"() ({
- // expected-error@+1 {{'scf.yield' op expects parent op to be one of 'scf.execute_region, scf.for, scf.if, scf.index_switch, scf.while'}}
+ // expected-error@+1 {{'scf.yield' op expects parent op to be one of 'scf.execute_region, scf.for, scf.if, scf.index_switch, scf.scope, scf.while'}}
scf.yield
}) : () -> ()
return
@@ -747,3 +747,28 @@ func.func @parallel_missing_terminator(%0 : index) {
return
}
+// -----
+
+func.func @scope_not_isolated_from_above(%arg0 : i32, %arg1 : i32) -> (i32) { // negative test: the scope takes no operands, yet its body reads %arg0/%arg1 of the enclosing function
+ // expected-note @below {{required by region isolation constraints}}
+ %p = scf.scope : () -> (i32) {
+ ^bb0():
+ // expected-error @below {{'arith.addi' op using value defined outside the region}}
+ %add = arith.addi %arg0, %arg1 : i32
+ scf.yield %add : i32
+ }
+ return %p : i32
+}
+
+// -----
+
+func.func @scope_yield_results_mismatch(%arg0 : i32, %arg1 : i32) -> (i32) { // negative test: the terminator yields two i32 values while the op declares a single i32 result
+ // expected-error @below {{'scf.scope' op expects terminator operands to have the same type as results of the operation}}
+ %p = scf.scope %arg0, %arg1 : (i32, i32) -> (i32) {
+ ^bb0(%k : i32, %t : i32):
+ %add = arith.addi %k, %t : i32
+ // expected-note @below {{terminator}}
+ scf.yield %add, %add : i32, i32
+ }
+ return %p : i32
+}
diff --git a/mlir/test/Dialect/SCF/ops.mlir b/mlir/test/Dialect/SCF/ops.mlir
index 7f457ef3b6ba0c..70cb72a9ec0bf6 100644
--- a/mlir/test/Dialect/SCF/ops.mlir
+++ b/mlir/test/Dialect/SCF/ops.mlir
@@ -441,3 +441,28 @@ func.func @switch(%arg0: index) -> i32 {
return %0 : i32
}
+
+// CHECK-LABEL: @scope
+// CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32
+func.func @scope(%arg0 : i32, %arg1 : f32, %arg2 : f32) -> (f32) {
+ // CHECK: %[[VAL_3:.*]] = arith.constant 77 : i32
+ %c77 = arith.constant 77 : i32
+
+ // CHECK: %[[VAL_4:.*]]:2 = scf.scope %[[VAL_0]], %[[VAL_3]], %[[VAL_1]], %[[VAL_2]] : (i32, i32, f32, f32) -> (i32, f32)
+ %p:2 = scf.scope %arg0, %c77, %arg1, %arg2 : (i32, i32, f32, f32) -> (i32, f32) {
+ // CHECK: ^bb0(%[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: i32, %[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32):
+ ^bb0(%a : i32, %b : i32, %c : f32, %d : f32):
+ // CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_5]], %[[VAL_6]] : i32
+ %add = arith.addi %a, %b : i32
+ // CHECK: %[[VAL_10:.*]] = arith.mulf %[[VAL_7]], %[[VAL_8]] : f32
+ %mul = arith.mulf %c, %d : f32
+ // CHECK: scf.yield %[[VAL_9]], %[[VAL_10]] : i32, f32
+ scf.yield %add, %mul : i32, f32
+ }
+
+ // CHECK: %[[VAL_11:.*]] = arith.sitofp %[[VAL_4]]#0 : i32 to f32
+ %m = arith.sitofp %p#0 : i32 to f32
+ // CHECK: %[[VAL_12:.*]] = arith.subf %[[VAL_11]], %[[VAL_4]]#1 : f32
+ %r = arith.subf %m, %p#1 : f32
+ return %r : f32
+}
More information about the Mlir-commits
mailing list