[Mlir-commits] [mlir] b2b3569 - [mlir] Fix scf.for single iteration canonicalization check

Lei Zhang llvmlistbot at llvm.org
Tue Feb 2 08:12:07 PST 2021


Author: Lei Zhang
Date: 2021-02-02T11:08:56-05:00
New Revision: b2b35697dc5172ab1e815e08c0a2714f2a1a9330

URL: https://github.com/llvm/llvm-project/commit/b2b35697dc5172ab1e815e08c0a2714f2a1a9330
DIFF: https://github.com/llvm/llvm-project/commit/b2b35697dc5172ab1e815e08c0a2714f2a1a9330.diff

LOG: [mlir] Fix scf.for single iteration canonicalization check

We should check whether lb + step >= ub to determine
whether the loop has exactly one iteration. Previously we
were checking lb + lb >= ub.

Differential Revision: https://reviews.llvm.org/D95440

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/SCF.cpp
    mlir/test/Dialect/Linalg/fusion.mlir
    mlir/test/Dialect/SCF/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index 1ea0571cf690..eaaf90108f61 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -547,7 +547,7 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
 
     // If the loop is known to have 1 iteration, inline its body and remove the
     // loop.
-    llvm::APInt stepValue = lb.getValue().cast<IntegerAttr>().getValue();
+    llvm::APInt stepValue = step.getValue().cast<IntegerAttr>().getValue();
     if ((lbValue + stepValue).sge(ubValue)) {
       SmallVector<Value, 4> blockArgs;
       blockArgs.reserve(op.getNumIterOperands() + 1);

diff  --git a/mlir/test/Dialect/Linalg/fusion.mlir b/mlir/test/Dialect/Linalg/fusion.mlir
index 9a4a1c5f3f6f..977782ed5dbd 100644
--- a/mlir/test/Dialect/Linalg/fusion.mlir
+++ b/mlir/test/Dialect/Linalg/fusion.mlir
@@ -665,9 +665,9 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
 #map1 = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4)>
 #map2 = affine_map<()[s0] -> (s0 + 3)>
 
-func @fill_and_conv(%arg0: memref<1x4x5x1xf32>, %arg1: memref<2x3x1x1xf32>, %arg2: memref<1x4x5x1xf32>) {
+func @fill_and_conv(%arg0: memref<?x?x?x?xf32>, %arg1: memref<2x3x1x1xf32>, %arg2: memref<?x?x?x?xf32>) {
   %cst = constant 0.000000e+00 : f32
-  linalg.fill(%arg2, %cst) : memref<1x4x5x1xf32>, f32
+  linalg.fill(%arg2, %cst) : memref<?x?x?x?xf32>, f32
 
   %c4 = constant 4 : index
   %c1 = constant 1 : index
@@ -676,13 +676,13 @@ func @fill_and_conv(%arg0: memref<1x4x5x1xf32>, %arg1: memref<2x3x1x1xf32>, %arg
   %c3 = constant 3 : index
   %4 = dim %arg1, %c0 : memref<2x3x1x1xf32>
   %5 = dim %arg1, %c1 : memref<2x3x1x1xf32>
-  %6 = dim %arg0, %c0 : memref<1x4x5x1xf32>
-  %7 = dim %arg0, %c1 : memref<1x4x5x1xf32>
-  %8 = dim %arg0, %c3 : memref<1x4x5x1xf32>
-  %9 = dim %arg2, %c0 : memref<1x4x5x1xf32>
-  %10 = dim %arg2, %c1 : memref<1x4x5x1xf32>
-  %11 = dim %arg2, %c2 : memref<1x4x5x1xf32>
-  %12 = dim %arg2, %c3 : memref<1x4x5x1xf32>
+  %6 = dim %arg0, %c0 : memref<?x?x?x?xf32>
+  %7 = dim %arg0, %c1 : memref<?x?x?x?xf32>
+  %8 = dim %arg0, %c3 : memref<?x?x?x?xf32>
+  %9 = dim %arg2, %c0 : memref<?x?x?x?xf32>
+  %10 = dim %arg2, %c1 : memref<?x?x?x?xf32>
+  %11 = dim %arg2, %c2 : memref<?x?x?x?xf32>
+  %12 = dim %arg2, %c3 : memref<?x?x?x?xf32>
   %13 = linalg.range %c0 : %6 : %c2 : !linalg.range
   %14 = linalg.range %c0 : %10 : %c3 : !linalg.range
   scf.for %arg3 = %c0 to %6 step %c2 {
@@ -690,14 +690,14 @@ func @fill_and_conv(%arg0: memref<1x4x5x1xf32>, %arg1: memref<2x3x1x1xf32>, %arg
       %15 = affine.min #map0(%c2, %c1, %arg3)
       %16 = affine.apply #map2()[%7]
       %17 = affine.min #map0(%16, %c4, %arg4)
-      %18 = dim %arg0, %c2 : memref<1x4x5x1xf32>
-      %19 = dim %arg0, %c3 : memref<1x4x5x1xf32>
-      %20 = subview %arg0[%arg3, %arg4, %c0, %c0] [%15, %17, %18, %19] [%c1, %c1, %c1, %c1] : memref<1x4x5x1xf32> to memref<?x?x?x?xf32, #map1>
+      %18 = dim %arg0, %c2 : memref<?x?x?x?xf32>
+      %19 = dim %arg0, %c3 : memref<?x?x?x?xf32>
+      %20 = subview %arg0[%arg3, %arg4, %c0, %c0] [%15, %17, %18, %19] [%c1, %c1, %c1, %c1] : memref<?x?x?x?xf32> to memref<?x?x?x?xf32, #map1>
       %21 = affine.min #map0(%c2, %c1, %arg3)
       %22 = affine.min #map0(%c3, %c4, %arg4)
-      %23 = dim %arg2, %c2 : memref<1x4x5x1xf32>
-      %24 = dim %arg2, %c3 : memref<1x4x5x1xf32>
-      %25 = subview %arg2[%arg3, %arg4, %c0, %c0] [%21, %22, %23, %24] [%c1, %c1, %c1, %c1] : memref<1x4x5x1xf32> to memref<?x?x?x?xf32, #map1>
+      %23 = dim %arg2, %c2 : memref<?x?x?x?xf32>
+      %24 = dim %arg2, %c3 : memref<?x?x?x?xf32>
+      %25 = subview %arg2[%arg3, %arg4, %c0, %c0] [%21, %22, %23, %24] [%c1, %c1, %c1, %c1] : memref<?x?x?x?xf32> to memref<?x?x?x?xf32, #map1>
       linalg.conv(%arg1, %20, %25) {dilations = [1, 1], strides = [1, 1]} : memref<2x3x1x1xf32>, memref<?x?x?x?xf32, #map1>, memref<?x?x?x?xf32, #map1>
     }
   }

diff  --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 05b75a55da8f..f0638d16105b 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -233,8 +233,8 @@ func @remove_zero_iteration_loop_vals(%arg0: index) {
   return
 }
 
-// CHECK-LABEL: @replace_single_iteration_loop
-func @replace_single_iteration_loop() {
+// CHECK-LABEL: @replace_single_iteration_loop_1
+func @replace_single_iteration_loop_1() {
   // CHECK: %[[LB:.*]] = constant 42
   %c42 = constant 42 : index
   %c43 = constant 43 : index
@@ -252,6 +252,26 @@ func @replace_single_iteration_loop() {
   return
 }
 
+// CHECK-LABEL: @replace_single_iteration_loop_2
+func @replace_single_iteration_loop_2() {
+  // CHECK: %[[LB:.*]] = constant 5
+	%c5 = constant 5 : index
+	%c6 = constant 6 : index
+	%c11 = constant 11 : index
+  // CHECK: %[[INIT:.*]] = "test.init"
+  %init = "test.init"() : () -> i32
+  // CHECK-NOT: scf.for
+  // CHECK: %[[VAL:.*]] = "test.op"(%[[LB]], %[[INIT]])
+  %0 = scf.for %i = %c5 to %c11 step %c6 iter_args(%arg = %init) -> (i32) {
+    %1 = "test.op"(%i, %arg) : (index, i32) -> i32
+    scf.yield %1 : i32
+  }
+  // CHECK: "test.consume"(%[[VAL]])
+  "test.consume"(%0) : (i32) -> ()
+  return
+}
+
+
 // CHECK-LABEL: @replace_single_iteration_loop_non_unit_step
 func @replace_single_iteration_loop_non_unit_step() {
   // CHECK: %[[LB:.*]] = constant 42


        


More information about the Mlir-commits mailing list