[Mlir-commits] [mlir] [mlir][scf] Add a scope op to the scf dialect (PR #89274)

Gil Rapaport llvmlistbot at llvm.org
Thu Apr 18 10:37:36 PDT 2024


https://github.com/aniragil created https://github.com/llvm/llvm-project/pull/89274

Add to the scf dialect an operation modeling an isolated-from-above
single basic block that is executed once. It provides a localized,
hierarchical alternative to outlining code into a function/call pair.


>From 57b873eaf7f21038e13dc8b2c64a6dc9209c56d1 Mon Sep 17 00:00:00 2001
From: Gil Rapaport <gil.rapaport at mobileye.com>
Date: Tue, 16 Apr 2024 17:58:18 +0300
Subject: [PATCH] [mlir][scf] Add a scope op to the scf dialect

Add to the scf dialect an operation modeling an isolated-from-above
single basic block that is executed once. It provides a localized,
hierarchical alternative to outlining code into a function/call pair.
---
 mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 37 +++++++++++++++++++++-
 mlir/lib/Dialect/SCF/IR/SCF.cpp            | 19 +++++++++++
 mlir/test/Dialect/SCF/invalid.mlir         | 27 +++++++++++++++-
 mlir/test/Dialect/SCF/ops.mlir             | 25 +++++++++++++++
 4 files changed, 106 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index b3d085bfff1af9..6c04fe4cdd651d 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -931,6 +931,41 @@ def ReduceReturnOp :
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// ScopeOp
+//===----------------------------------------------------------------------===//
+
+def ScopeOp : SCF_Op<"scope",
+    [AutomaticAllocationScope,
+     RecursiveMemoryEffects,
+     IsolatedFromAbove,
+     SingleBlockImplicitTerminator<"scf::YieldOp">]> {
+  let summary = "isolated code scope";
+  let description = [{
+    The 'scope' op encapsulates computations by providing an isolated-from-above,
+    executed-once single basic block. Its operands are passed to the block as
+    block arguments, and its results are defined by its terminating `scf.yield`:
+
+    ```mlir
+    %p:2 = scf.scope %arg0, %c77, %arg1, %arg2 : (i32, i32, f32, f32) -> (i32, f32) {
+    ^bb0(%a : i32, %b : i32, %c : f32, %d : f32):
+      %add = arith.addi %a, %b : i32
+      %mul = arith.mulf %c, %d : f32
+      scf.yield %add, %mul : i32, f32
+    }
+    ```
+  }];
+
+  let arguments = (ins Variadic<AnyType>:$operands);
+  let results = (outs Variadic<AnyType>:$results);
+  // Exactly one block: rejects an empty body, e.g. built via the generic form.
+  let regions = (region SizedRegion<1>:$body);
+  let assemblyFormat = [{
+    $operands attr-dict `:` functional-type($operands, $results) $body
+  }];
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // WhileOp
 //===----------------------------------------------------------------------===//
@@ -1155,7 +1190,7 @@ def IndexSwitchOp : SCF_Op<"index_switch", [RecursiveMemoryEffects,
 
 def YieldOp : SCF_Op<"yield", [Pure, ReturnLike, Terminator,
     ParentOneOf<["ExecuteRegionOp", "ForOp", "IfOp", "IndexSwitchOp",
-                 "WhileOp"]>]> {
+                 "ScopeOp", "WhileOp"]>]> {
   let summary = "loop yield and termination operation";
   let description = [{
     "scf.yield" yields an SSA value from the SCF dialect op region and
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 5bca8e85f889d9..6a8d28afdbc251 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -3079,6 +3079,38 @@ LogicalResult ReduceReturnOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ScopeOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies the structural invariants of `scf.scope`: the body has an entry
+/// block, the entry block arguments match the op operands in number and type,
+/// and the terminator operands match the op results in number and type.
+LogicalResult scf::ScopeOp::verify() {
+  Region &body = getBody();
+  // Bail out before touching the entry block of a region that has none.
+  if (body.empty())
+    return emitOpError() << "expects a non-empty body region";
+
+  // The body is isolated from above, so the operands are its only inputs and
+  // must line up one-to-one with the entry block arguments.
+  Block &block = body.front();
+  if (block.getArgumentTypes() != getOperands().getTypes())
+    return emitOpError() << "expects body block arguments to have the same "
+                            "type as operands of the operation";
+
+  Operation *terminator = block.getTerminator();
+  if (terminator->getOperands().getTypes() != getResults().getTypes()) {
+    InFlightDiagnostic diag = emitOpError()
+                              << "expects terminator operands to have the "
+                                 "same type as results of the operation";
+    diag.attachNote(terminator->getLoc()) << "terminator";
+    return diag;
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // WhileOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index 337eb9eeb8fa57..6871881a49458c 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -476,7 +476,7 @@ func.func @parallel_invalid_yield(
 
 func.func @yield_invalid_parent_op() {
   "my.op"() ({
-   // expected-error@+1 {{'scf.yield' op expects parent op to be one of 'scf.execute_region, scf.for, scf.if, scf.index_switch, scf.while'}}
+   // expected-error@+1 {{'scf.yield' op expects parent op to be one of 'scf.execute_region, scf.for, scf.if, scf.index_switch, scf.scope, scf.while'}}
    scf.yield
   }) : () -> ()
   return
@@ -747,3 +747,28 @@ func.func @parallel_missing_terminator(%0 : index) {
   return
 }
 
+// -----
+
+func.func @scope_not_isolated_from_above(%lhs : i32, %rhs : i32) -> i32 {
+  // expected-note @below {{required by region isolation constraints}}
+  %res = scf.scope : () -> (i32) {
+  ^bb0():
+    // expected-error @below {{'arith.addi' op using value defined outside the region}}
+    %sum = arith.addi %lhs, %rhs : i32
+    scf.yield %sum : i32
+  }
+  return %res : i32
+}
+
+// -----
+
+func.func @scope_yield_results_mismatch(%lhs : i32, %rhs : i32) -> i32 {
+  // expected-error @below {{'scf.scope' op expects terminator operands to have the same type as results of the operation}}
+  %res = scf.scope %lhs, %rhs : (i32, i32) -> (i32) {
+  ^bb0(%x : i32, %y : i32):
+    %sum = arith.addi %x, %y : i32
+    // expected-note @below {{terminator}}
+    scf.yield %sum, %sum : i32, i32
+  }
+  return %res : i32
+}
diff --git a/mlir/test/Dialect/SCF/ops.mlir b/mlir/test/Dialect/SCF/ops.mlir
index 7f457ef3b6ba0c..70cb72a9ec0bf6 100644
--- a/mlir/test/Dialect/SCF/ops.mlir
+++ b/mlir/test/Dialect/SCF/ops.mlir
@@ -441,3 +441,28 @@ func.func @switch(%arg0: index) -> i32 {
 
   return %0 : i32
 }
+
+// CHECK-LABEL: @scope
+// CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32
+func.func @scope(%arg0 : i32, %arg1 : f32, %arg2 : f32) -> (f32) {
+  // CHECK: %[[VAL_3:.*]] = arith.constant 77 : i32
+  %c77 = arith.constant 77 : i32
+
+  // CHECK: %[[VAL_4:.*]]:2 = scf.scope %[[VAL_0]], %[[VAL_3]], %[[VAL_1]], %[[VAL_2]] : (i32, i32, f32, f32) -> (i32, f32)
+  %p:2 = scf.scope %arg0, %c77, %arg1, %arg2 : (i32, i32, f32, f32) -> (i32, f32) {
+  // CHECK: ^bb0(%[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: i32, %[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32):
+  ^bb0(%a : i32, %b : i32, %c : f32, %d : f32):
+    // CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_5]], %[[VAL_6]] : i32
+    %add = arith.addi %a, %b : i32
+    // CHECK: %[[VAL_10:.*]] = arith.mulf %[[VAL_7]], %[[VAL_8]] : f32
+    %mul = arith.mulf %c, %d : f32
+    // CHECK: scf.yield %[[VAL_9]], %[[VAL_10]] : i32, f32
+    scf.yield %add, %mul : i32, f32
+  }
+  // The users below must reference the scope results captured as VAL_4.
+  // CHECK: %[[VAL_11:.*]] = arith.sitofp %[[VAL_4]]#0 : i32 to f32
+  %m = arith.sitofp %p#0 : i32 to f32
+  // CHECK: %[[VAL_12:.*]] = arith.subf %[[VAL_11]], %[[VAL_4]]#1 : f32
+  %r = arith.subf %m, %p#1 : f32
+  return %r : f32
+}



More information about the Mlir-commits mailing list