[Mlir-commits] [mlir] [mlir][SCF] Report error when lowering to cf in single block op (PR #65305)

Matthias Springer llvmlistbot at llvm.org
Tue Sep 5 00:46:39 PDT 2023


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/65305:

Report an error when lowering an "scf" op inside a parent that supports only a single block (e.g., "scf.forall", which currently does not have a "cf" lowering). This indicates a problem in the pass pipeline: scf-to-cf is typically one of the last passes to run, and ops that do not support unstructured control flow should have been lowered already.

Before this change, the lowering produced IR that failed verification.

>From 5dd2411406d2a6d16a3205404dc570e78ce4c526 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Tue, 5 Sep 2023 09:10:35 +0200
Subject: [PATCH] [mlir][SCF] Report error when lowering to cf in single block
 op

Report an error when lowering an "scf" op inside a parent that supports only a single block (e.g., "scf.forall", which currently does not have a "cf" lowering). This indicates a problem in the pass pipeline: scf-to-cf is typically one of the last passes to run, and ops that do not support unstructured control flow should have been lowered already.

Before this change, the lowering produced IR that failed verification.
---
 .../SCFToControlFlow/SCFToControlFlow.cpp     | 16 ++++++
 .../SCFToControlFlow/convert-to-cfg.mlir      | 57 ++++++++++++++++++-
 2 files changed, 71 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
index 91dbdb429f948e..d6080334fd7142 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -303,6 +303,9 @@ struct IndexSwitchLowering : public OpRewritePattern<IndexSwitchOp> {
 LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
                                            PatternRewriter &rewriter) const {
   Location loc = forOp.getLoc();
+  if (forOp->getParentOp()->hasTrait<OpTrait::SingleBlock>())
+    return forOp->emitError(
+        "cannot lower op inside parent that expects a single block");
 
   // Start by splitting the block containing the 'scf.for' into two parts.
   // The part before will get the init code, the part after will be the end
@@ -370,6 +373,9 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
 LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
                                           PatternRewriter &rewriter) const {
   auto loc = ifOp.getLoc();
+  if (ifOp->getParentOp()->hasTrait<OpTrait::SingleBlock>())
+    return ifOp->emitError(
+        "cannot lower op inside parent that expects a single block");
 
   // Start by splitting the block containing the 'scf.if' into two parts.
   // The part before will contain the condition, the part after will be the
@@ -427,6 +433,9 @@ LogicalResult
 ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op,
                                        PatternRewriter &rewriter) const {
   auto loc = op.getLoc();
+  if (op->getParentOp()->hasTrait<OpTrait::SingleBlock>())
+    return op->emitError(
+        "cannot lower op inside parent that expects a single block");
 
   auto *condBlock = rewriter.getInsertionBlock();
   auto opPosition = rewriter.getInsertionPoint();
@@ -535,6 +544,9 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
                                              PatternRewriter &rewriter) const {
   OpBuilder::InsertionGuard guard(rewriter);
   Location loc = whileOp.getLoc();
+  if (whileOp->getParentOp()->hasTrait<OpTrait::SingleBlock>())
+    return whileOp->emitError(
+        "cannot lower op inside parent that expects a single block");
 
   // Split the current block before the WhileOp to create the inlining point.
   Block *currentBlock = rewriter.getInsertionBlock();
@@ -618,6 +630,10 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp,
 LogicalResult
 IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
                                      PatternRewriter &rewriter) const {
+  if (op->getParentOp()->hasTrait<OpTrait::SingleBlock>())
+    return op->emitError(
+        "cannot lower op inside parent that expects a single block");
+
   // Split the block at the op.
   Block *condBlock = rewriter.getInsertionBlock();
   Block *continueBlock = rewriter.splitBlock(condBlock, Block::iterator(op));
diff --git a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
index 36307a910a6cad..7278b35f554b23 100644
--- a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
+++ b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-cf %s | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-cf -verify-diagnostics -split-input-file %s | FileCheck %s
 
 // CHECK-LABEL: func @simple_std_for_loop(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
 //  CHECK-NEXT:  cf.br ^bb1(%{{.*}} : index)
@@ -18,6 +18,8 @@ func.func @simple_std_for_loop(%arg0 : index, %arg1 : index, %arg2 : index) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @simple_std_2_for_loops(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
 //  CHECK-NEXT:    cf.br ^bb1(%{{.*}} : index)
 //  CHECK-NEXT:  ^bb1(%[[ub0:.*]]: index):    // 2 preds: ^bb0, ^bb5
@@ -48,6 +50,8 @@ func.func @simple_std_2_for_loops(%arg0 : index, %arg1 : index, %arg2 : index) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @simple_std_if(%{{.*}}: i1) {
 //  CHECK-NEXT:   cf.cond_br %{{.*}}, ^bb1, ^bb2
 //  CHECK-NEXT:   ^bb1:   // pred: ^bb0
@@ -62,6 +66,8 @@ func.func @simple_std_if(%arg0: i1) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @simple_std_if_else(%{{.*}}: i1) {
 //  CHECK-NEXT:   cf.cond_br %{{.*}}, ^bb1, ^bb2
 //  CHECK-NEXT:   ^bb1:   // pred: ^bb0
@@ -81,6 +87,8 @@ func.func @simple_std_if_else(%arg0: i1) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @simple_std_2_ifs(%{{.*}}: i1) {
 //  CHECK-NEXT:   cf.cond_br %{{.*}}, ^bb1, ^bb5
 //  CHECK-NEXT: ^bb1:   // pred: ^bb0
@@ -108,6 +116,8 @@ func.func @simple_std_2_ifs(%arg0: i1) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @simple_std_for_loop_with_2_ifs(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: i1) {
 //  CHECK-NEXT:   cf.br ^bb1(%{{.*}} : index)
 //  CHECK-NEXT:   ^bb1(%{{.*}}: index):    // 2 preds: ^bb0, ^bb7
@@ -148,6 +158,8 @@ func.func @simple_std_for_loop_with_2_ifs(%arg0 : index, %arg1 : index, %arg2 :
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @simple_if_yield
 func.func @simple_if_yield(%arg0: i1) -> (i1, i1) {
 // CHECK:   cf.cond_br %{{.*}}, ^[[then:.*]], ^[[else:.*]]
@@ -175,6 +187,8 @@ func.func @simple_if_yield(%arg0: i1) -> (i1, i1) {
   return %0#0, %0#1 : i1, i1
 }
 
+// -----
+
 // CHECK-LABEL: func @nested_if_yield
 func.func @nested_if_yield(%arg0: i1) -> (index) {
 // CHECK:   cf.cond_br %{{.*}}, ^[[first_then:.*]], ^[[first_else:.*]]
@@ -225,6 +239,8 @@ func.func @nested_if_yield(%arg0: i1) -> (index) {
   return %1 : index
 }
 
+// -----
+
 // CHECK-LABEL:   func @parallel_loop(
 // CHECK-SAME:                        [[VAL_0:%.*]]: index, [[VAL_1:%.*]]: index, [[VAL_2:%.*]]: index, [[VAL_3:%.*]]: index, [[VAL_4:%.*]]: index) {
 // CHECK:           [[VAL_5:%.*]] = arith.constant 1 : index
@@ -258,6 +274,8 @@ func.func @parallel_loop(%arg0 : index, %arg1 : index, %arg2 : index,
   return
 }
 
+// -----
+
 // CHECK-LABEL: @for_yield
 // CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index)
 // CHECK:        %[[INIT0:.*]] = arith.constant 0
@@ -285,6 +303,8 @@ func.func @for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> (f32, f32)
   return %result#0, %result#1 : f32, f32
 }
 
+// -----
+
 // CHECK-LABEL: @nested_for_yield
 // CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index)
 // CHECK:         %[[INIT:.*]] = arith.constant
@@ -314,7 +334,7 @@ func.func @nested_for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> f32
   return %r : f32
 }
 
-func.func private @generate() -> i64
+// -----
 
 // CHECK-LABEL: @simple_parallel_reduce_loop
 // CHECK-SAME: %[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index, %[[INIT:.*]]: f32
@@ -356,6 +376,10 @@ func.func @simple_parallel_reduce_loop(%arg0: index, %arg1: index,
   return %0 : f32
 }
 
+// -----
+
+func.func private @generate() -> i64
+
 // CHECK-LABEL: parallel_reduce_loop
 // CHECK-SAME: %[[INIT1:[0-9A-Za-z_]*]]: f32)
 func.func @parallel_reduce_loop(%arg0 : index, %arg1 : index, %arg2 : index,
@@ -399,6 +423,8 @@ func.func @parallel_reduce_loop(%arg0 : index, %arg1 : index, %arg2 : index,
   return %0#0, %0#1 : f32, i64
 }
 
+// -----
+
 // Check that the conversion is not overly conservative wrt unknown ops, i.e.
 // that the presence of unknown ops does not prevent the conversion from being
 // applied.
@@ -413,6 +439,8 @@ func.func @unknown_op_inside_loop(%arg0: index, %arg1: index, %arg2: index) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: @minimal_while
 func.func @minimal_while() {
   // CHECK:   %[[COND:.*]] = "test.make_condition"() : () -> i1
@@ -434,6 +462,8 @@ func.func @minimal_while() {
   return
 }
 
+// -----
+
 // CHECK-LABEL: @do_while
 func.func @do_while(%arg0: f32) {
   // CHECK:   cf.br ^[[BEFORE:.*]]({{.*}}: f32)
@@ -453,6 +483,8 @@ func.func @do_while(%arg0: f32) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: @while_values
 // CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: f32)
 func.func @while_values(%arg0: i32, %arg1: f32) {
@@ -482,6 +514,8 @@ func.func @while_values(%arg0: i32, %arg1: f32) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: @nested_while_ops
 func.func @nested_while_ops(%arg0: f32) -> i64 {
   // CHECK:       cf.br ^[[OUTER_BEFORE:.*]](%{{.*}} : f32)
@@ -546,6 +580,8 @@ func.func @nested_while_ops(%arg0: f32) -> i64 {
   return %0 : i64
 }
 
+// -----
+
 // CHECK-LABEL: @ifs_in_parallel
 // CHECK: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: i1, %[[ARG4:.*]]: i1)
 func.func @ifs_in_parallel(%arg1: index, %arg2: index, %arg3: index, %arg4: i1, %arg5: i1) {
@@ -588,6 +624,8 @@ func.func @ifs_in_parallel(%arg1: index, %arg2: index, %arg3: index, %arg4: i1,
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @func_execute_region_elim_multi_yield
 func.func @func_execute_region_elim_multi_yield() {
     "test.foo"() : () -> ()
@@ -621,6 +659,8 @@ func.func @func_execute_region_elim_multi_yield() {
 // CHECK:     "test.bar"(%[[z]])
 // CHECK:     return
 
+// -----
+
 // CHECK-LABEL: @index_switch
 func.func @index_switch(%i: index, %a: i32, %b: i32, %c: i32) -> i32 {
   // CHECK: %[[CASE:.*]] = arith.index_cast %arg0 : index to i32
@@ -648,3 +688,16 @@ func.func @index_switch(%i: index, %a: i32, %b: i32, %c: i32) -> i32 {
   // CHECK-NEXT: return %[[V]]
   return %0 : i32
 }
+
+// -----
+
+func.func @parent_has_single_block(%c: i1) {
+  test.single_no_terminator_custom_asm_op {
+    // expected-error @below{{cannot lower op inside parent that expects a single block}}
+    // expected-error @below{{failed to legalize operation 'scf.if' that was explicitly marked illegal}}
+    scf.if %c {
+      "test.foo"() : () -> ()
+    }
+  }
+  return
+}



More information about the Mlir-commits mailing list