[Mlir-commits] [mlir] 5a31403 - [MLIR][SCF] Create selects from if yield results which are not defined in the body

William S. Moses llvmlistbot at llvm.org
Fri Mar 18 08:38:34 PDT 2022


Author: William S. Moses
Date: 2022-03-18T11:38:29-04:00
New Revision: 5a31403b513d7449c4f045347740056800996357

URL: https://github.com/llvm/llvm-project/commit/5a31403b513d7449c4f045347740056800996357
DIFF: https://github.com/llvm/llvm-project/commit/5a31403b513d7449c4f045347740056800996357.diff

LOG: [MLIR][SCF] Create selects from if yield results which are not defined in the body

Previously, the canonicalization that creates selects from ifs would only fire
if the if had no body other than its yield. This patch upgrades the functionality
to create selects from any if result whose yielded operands are not defined
within the if's body.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D121943

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/SCF.cpp
    mlir/test/Dialect/SCF/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index cd2b8a0b45055..d888a341f0781 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -1326,6 +1326,8 @@ struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
   }
 };
 
+/// Hoist any yielded results whose operands are defined outside
+/// the if, to a select instruction.
 struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
   using OpRewritePattern<IfOp>::OpRewritePattern;
 
@@ -1334,31 +1336,58 @@ struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
     if (op->getNumResults() == 0)
       return failure();
 
-    if (!llvm::hasSingleElement(op.getThenRegion().front()) ||
-        !llvm::hasSingleElement(op.getElseRegion().front()))
+    auto cond = op.getCondition();
+    auto thenYieldArgs = op.thenYield().getOperands();
+    auto elseYieldArgs = op.elseYield().getOperands();
+
+    SmallVector<Type> nonHoistable;
+    for (const auto &it :
+         llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
+      Value trueVal = std::get<0>(it.value());
+      Value falseVal = std::get<1>(it.value());
+      if (&op.getThenRegion() == trueVal.getParentRegion() ||
+          &op.getElseRegion() == falseVal.getParentRegion())
+        nonHoistable.push_back(trueVal.getType());
+    }
+    // Early exit if there aren't any yielded values we can
+    // hoist outside the if.
+    if (nonHoistable.size() == op->getNumResults())
       return failure();
 
-    auto cond = op.getCondition();
-    auto thenYieldArgs =
-        cast<scf::YieldOp>(op.getThenRegion().front().getTerminator())
-            .getOperands();
-    auto elseYieldArgs =
-        cast<scf::YieldOp>(op.getElseRegion().front().getTerminator())
-            .getOperands();
+    IfOp replacement = rewriter.create<IfOp>(op.getLoc(), nonHoistable, cond);
+    if (replacement.thenBlock())
+      rewriter.eraseBlock(replacement.thenBlock());
+    replacement.getThenRegion().takeBody(op.getThenRegion());
+    replacement.getElseRegion().takeBody(op.getElseRegion());
+
     SmallVector<Value> results(op->getNumResults());
     assert(thenYieldArgs.size() == results.size());
     assert(elseYieldArgs.size() == results.size());
+
+    SmallVector<Value> trueYields;
+    SmallVector<Value> falseYields;
     for (const auto &it :
          llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
       Value trueVal = std::get<0>(it.value());
       Value falseVal = std::get<1>(it.value());
-      if (trueVal == falseVal)
+      if (&replacement.getThenRegion() == trueVal.getParentRegion() ||
+          &replacement.getElseRegion() == falseVal.getParentRegion()) {
+        results[it.index()] = replacement.getResult(trueYields.size());
+        trueYields.push_back(trueVal);
+        falseYields.push_back(falseVal);
+      } else if (trueVal == falseVal)
         results[it.index()] = trueVal;
       else
         results[it.index()] = rewriter.create<arith::SelectOp>(
             op.getLoc(), cond, trueVal, falseVal);
     }
 
+    rewriter.setInsertionPointToEnd(replacement.thenBlock());
+    rewriter.replaceOpWithNewOp<YieldOp>(replacement.thenYield(), trueYields);
+
+    rewriter.setInsertionPointToEnd(replacement.elseBlock());
+    rewriter.replaceOpWithNewOp<YieldOp>(replacement.elseYield(), falseYields);
+
     rewriter.replaceOp(op, results);
     return success();
   }

diff  --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 6edf9ff9ddff7..f06e9833d1211 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -136,26 +136,26 @@ func @nested_parallel(%0: memref<?x?x?xf64>) -> memref<?x?x?xf64> {
 
 func private @side_effect()
 func @one_unused(%cond: i1) -> (index) {
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %c2 = arith.constant 2 : index
-  %c3 = arith.constant 3 : index
   %0, %1 = scf.if %cond -> (index, index) {
     call @side_effect() : () -> ()
+    %c0 = "test.value0"() : () -> (index)
+    %c1 = "test.value1"() : () -> (index)
     scf.yield %c0, %c1 : index, index
   } else {
+    %c2 = "test.value2"() : () -> (index)
+    %c3 = "test.value3"() : () -> (index)
     scf.yield %c2, %c3 : index, index
   }
   return %1 : index
 }
 
 // CHECK-LABEL:   func @one_unused
-// CHECK-DAG:       [[C0:%.*]] = arith.constant 1 : index
-// CHECK-DAG:       [[C3:%.*]] = arith.constant 3 : index
 // CHECK:           [[V0:%.*]] = scf.if %{{.*}} -> (index) {
 // CHECK:             call @side_effect() : () -> ()
-// CHECK:             scf.yield [[C0]] : index
+// CHECK:             [[C1:%.*]] = "test.value1"
+// CHECK:             scf.yield [[C1]] : index
 // CHECK:           } else
+// CHECK:             [[C3:%.*]] = "test.value3"
 // CHECK:             scf.yield [[C3]] : index
 // CHECK:           }
 // CHECK:           return [[V0]] : index
@@ -164,37 +164,40 @@ func @one_unused(%cond: i1) -> (index) {
 
 func private @side_effect()
 func @nested_unused(%cond1: i1, %cond2: i1) -> (index) {
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %c2 = arith.constant 2 : index
-  %c3 = arith.constant 3 : index
   %0, %1 = scf.if %cond1 -> (index, index) {
     %2, %3 = scf.if %cond2 -> (index, index) {
       call @side_effect() : () -> ()
+      %c0 = "test.value0"() : () -> (index)
+      %c1 = "test.value1"() : () -> (index)
       scf.yield %c0, %c1 : index, index
     } else {
+      %c2 = "test.value2"() : () -> (index)
+      %c3 = "test.value3"() : () -> (index)
       scf.yield %c2, %c3 : index, index
     }
     scf.yield %2, %3 : index, index
   } else {
+    %c0 = "test.value0_2"() : () -> (index)
+    %c1 = "test.value1_2"() : () -> (index)
     scf.yield %c0, %c1 : index, index
   }
   return %1 : index
 }
 
 // CHECK-LABEL:   func @nested_unused
-// CHECK-DAG:       [[C0:%.*]] = arith.constant 1 : index
-// CHECK-DAG:       [[C3:%.*]] = arith.constant 3 : index
 // CHECK:           [[V0:%.*]] = scf.if {{.*}} -> (index) {
 // CHECK:             [[V1:%.*]] = scf.if {{.*}} -> (index) {
 // CHECK:               call @side_effect() : () -> ()
-// CHECK:               scf.yield [[C0]] : index
+// CHECK:               [[C1:%.*]] = "test.value1"
+// CHECK:               scf.yield [[C1]] : index
 // CHECK:             } else
+// CHECK:               [[C3:%.*]] = "test.value3"
 // CHECK:               scf.yield [[C3]] : index
 // CHECK:             }
 // CHECK:             scf.yield [[V1]] : index
 // CHECK:           } else
-// CHECK:             scf.yield [[C0]] : index
+// CHECK:             [[C1_2:%.*]] = "test.value1_2"
+// CHECK:             scf.yield [[C1_2]] : index
 // CHECK:           }
 // CHECK:           return [[V0]] : index
 
@@ -302,6 +305,27 @@ func @to_select_same_val(%cond: i1) -> (index, index) {
 // CHECK:           [[V0:%.*]] = arith.select {{.*}}, [[C0]], [[C1]]
 // CHECK:           return [[V0]], [[C1]] : index, index
 
+
+func @to_select_with_body(%cond: i1) -> index {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = scf.if %cond -> index {
+    "test.op"() : () -> ()
+    scf.yield %c0 : index
+  } else {
+    scf.yield %c1 : index
+  }
+  return %0 : index
+}
+
+// CHECK-LABEL:   func @to_select_with_body
+// CHECK-DAG:       [[C0:%.*]] = arith.constant 0 : index
+// CHECK-DAG:       [[C1:%.*]] = arith.constant 1 : index
+// CHECK:           scf.if {{.*}} {
+// CHECK:             "test.op"() : () -> ()
+// CHECK:           } 
+// CHECK:           [[V0:%.*]] = arith.select {{.*}}, [[C0]], [[C1]]
+// CHECK:           return [[V0]] : index
 // -----
 
 func @to_select2(%cond: i1) -> (index, index) {
@@ -731,38 +755,32 @@ func @matmul_on_tensors(%t0: tensor<32x1024xf32>, %t1: tensor<1024x1024xf32>) ->
 
 // CHECK-LABEL: @cond_prop
 func @cond_prop(%arg0 : i1) -> index {
-  %c1 = arith.constant 1 : index
-  %c2 = arith.constant 2 : index
-  %c3 = arith.constant 3 : index
-  %c4 = arith.constant 4 : index
   %res = scf.if %arg0 -> index {
     %res1 = scf.if %arg0 -> index {
-      %v1 = "test.get_some_value"() : () -> i32
-      scf.yield %c1 : index
+      %v1 = "test.get_some_value1"() : () -> index
+      scf.yield %v1 : index
     } else {
-      %v2 = "test.get_some_value"() : () -> i32
-      scf.yield %c2 : index
+      %v2 = "test.get_some_value2"() : () -> index
+      scf.yield %v2 : index
     }
     scf.yield %res1 : index
   } else {
     %res2 = scf.if %arg0 -> index {
-      %v3 = "test.get_some_value"() : () -> i32
-      scf.yield %c3 : index
+      %v3 = "test.get_some_value3"() : () -> index
+      scf.yield %v3 : index
     } else {
-      %v4 = "test.get_some_value"() : () -> i32
-      scf.yield %c4 : index
+      %v4 = "test.get_some_value4"() : () -> index
+      scf.yield %v4 : index
     }
     scf.yield %res2 : index
   }
   return %res : index
 }
-// CHECK-DAG:  %[[c1:.+]] = arith.constant 1 : index
-// CHECK-DAG:  %[[c4:.+]] = arith.constant 4 : index
 // CHECK-NEXT:  %[[if:.+]] = scf.if %arg0 -> (index) {
-// CHECK-NEXT:    %{{.+}} = "test.get_some_value"() : () -> i32
+// CHECK-NEXT:    %[[c1:.+]] = "test.get_some_value1"() : () -> index
 // CHECK-NEXT:    scf.yield %[[c1]] : index
 // CHECK-NEXT:  } else {
-// CHECK-NEXT:    %{{.+}} = "test.get_some_value"() : () -> i32
+// CHECK-NEXT:    %[[c4:.+]] = "test.get_some_value4"() : () -> index
 // CHECK-NEXT:    scf.yield %[[c4]] : index
 // CHECK-NEXT:  }
 // CHECK-NEXT:  return %[[if]] : index
@@ -808,7 +826,6 @@ func @replace_if_with_cond2(%arg0 : i1) -> (i32, i1) {
   return %res#0, %res#1 : i32, i1
 }
 // CHECK-NEXT:     %true = arith.constant true
-// CHECK-NEXT:     %[[toret:.+]] = arith.xori %arg0, %true : i1
 // CHECK-NEXT:     %[[if:.+]] = scf.if %arg0 -> (i32) {
 // CHECK-NEXT:       %[[sv1:.+]] = "test.get_some_value"() : () -> i32
 // CHECK-NEXT:       scf.yield %[[sv1]] : i32
@@ -816,6 +833,7 @@ func @replace_if_with_cond2(%arg0 : i1) -> (i32, i1) {
 // CHECK-NEXT:       %[[sv2:.+]] = "test.get_some_value"() : () -> i32
 // CHECK-NEXT:       scf.yield %[[sv2]] : i32
 // CHECK-NEXT:     }
+// CHECK-NEXT:     %[[toret:.+]] = arith.xori %arg0, %true : i1
 // CHECK-NEXT:     return %[[if]], %[[toret]] : i32, i1
 
 // -----


        


More information about the Mlir-commits mailing list