[Mlir-commits] [mlir] a4b227c - [mlir] Fix loop unrolling: properly replace the arguments of the epilogue loop.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed May 11 18:54:55 PDT 2022


Author: grosul1
Date: 2022-05-12T01:54:39Z
New Revision: a4b227c28aba3487ffbfde12fb59fdb69a6b5bfe

URL: https://github.com/llvm/llvm-project/commit/a4b227c28aba3487ffbfde12fb59fdb69a6b5bfe
DIFF: https://github.com/llvm/llvm-project/commit/a4b227c28aba3487ffbfde12fb59fdb69a6b5bfe.diff

LOG: [mlir] Fix loop unrolling: properly replace the arguments of the epilogue loop.

Using "replaceUsesOfWith" is incorrect because it replaces every operand of the op that matches the given value, and the same initializer value may appear multiple times.

For example, if the epilogue is needed when this loop is unrolled
```
%x:2 = scf.for ... iter_args(%arg1 = %c1, %arg2 = %c1) {
  ...
}
```
then both of the epilogue's arguments are incorrectly renamed to use the same result index (note #1 in both cases):
```
%x_unrolled:2 = scf.for ... iter_args(%arg1 = %c1, %arg2 = %c1) {
  ...
}
%x_epilogue:2 = scf.for ... iter_args(%arg1 = %x_unrolled#1, %arg2 = %x_unrolled#1) {
  ...
}
```

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/Utils/Utils.cpp
    mlir/test/Dialect/SCF/loop-unroll.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 206e61c6ac921..d4c96e51d549a 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -474,12 +474,12 @@ LogicalResult mlir::loopUnrollByFactor(
     // Update uses of loop results.
     auto results = forOp.getResults();
     auto epilogueResults = epilogueForOp.getResults();
-    auto epilogueIterOperands = epilogueForOp.getIterOperands();
 
-    for (auto e : llvm::zip(results, epilogueResults, epilogueIterOperands)) {
+    for (auto e : llvm::zip(results, epilogueResults)) {
       std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
-      epilogueForOp->replaceUsesOfWith(std::get<2>(e), std::get<0>(e));
     }
+    epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
+                               epilogueForOp.getNumIterOperands(), results);
     (void)promoteIfSingleIteration(epilogueForOp);
   }
 

diff  --git a/mlir/test/Dialect/SCF/loop-unroll.mlir b/mlir/test/Dialect/SCF/loop-unroll.mlir
index 6a832578d581b..dc2c07f291e76 100644
--- a/mlir/test/Dialect/SCF/loop-unroll.mlir
+++ b/mlir/test/Dialect/SCF/loop-unroll.mlir
@@ -276,3 +276,41 @@ func.func @static_loop_unroll_up_to_factor(%arg0 : memref<?xf32>) {
 //   UNROLL-UP-TO-NEXT: affine.store %{{.*}}, %[[MEM]][%[[V1]]] : memref<?xf32>
 //   UNROLL-UP-TO-NEXT: return
 
+// Test that epilogue's arguments are correctly renamed.
+func.func @static_loop_unroll_by_3_rename_epilogue_arguments() -> (f32, f32) {
+  %0 = arith.constant 7.0 : f32
+  %lb = arith.constant 0 : index
+  %ub = arith.constant 20 : index
+  %step = arith.constant 1 : index
+  %result:2 = scf.for %i0 = %lb to %ub step %step iter_args(%arg0 = %0, %arg1 = %0) -> (f32, f32) {
+    %add = arith.addf %arg0, %arg1 : f32
+    %mul = arith.mulf %arg0, %arg1 : f32
+    scf.yield %add, %mul : f32, f32
+  }
+  return %result#0, %result#1 : f32, f32
+}
+// UNROLL-BY-3-LABEL: func @static_loop_unroll_by_3_rename_epilogue_arguments
+//
+//   UNROLL-BY-3-DAG:   %[[CST:.*]] = arith.constant {{.*}} : f32
+//   UNROLL-BY-3-DAG:   %[[C0:.*]] = arith.constant 0 : index
+//   UNROLL-BY-3-DAG:   %[[C1:.*]] = arith.constant 1 : index
+//   UNROLL-BY-3-DAG:   %[[C20:.*]] = arith.constant 20 : index
+//   UNROLL-BY-3-DAG:   %[[C18:.*]] = arith.constant 18 : index
+//   UNROLL-BY-3-DAG:   %[[C3:.*]] = arith.constant 3 : index
+//       UNROLL-BY-3:   %[[FOR:.*]]:2 = scf.for %[[IV:.*]] = %[[C0]] to %[[C18]] step %[[C3]]
+//  UNROLL-BY-3-SAME:     iter_args(%[[ARG0:.*]] = %[[CST]], %[[ARG1:.*]] = %[[CST]]) -> (f32, f32) {
+//  UNROLL-BY-3-NEXT:     %[[ADD0:.*]] = arith.addf %[[ARG0]], %[[ARG1]] : f32
+//  UNROLL-BY-3-NEXT:     %[[MUL0:.*]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32
+//  UNROLL-BY-3-NEXT:     %[[ADD1:.*]] = arith.addf %[[ADD0]], %[[MUL0]] : f32
+//  UNROLL-BY-3-NEXT:     %[[MUL1:.*]] = arith.mulf %[[ADD0]], %[[MUL0]] : f32
+//  UNROLL-BY-3-NEXT:     %[[ADD2:.*]] = arith.addf %[[ADD1]], %[[MUL1]] : f32
+//  UNROLL-BY-3-NEXT:     %[[MUL2:.*]] = arith.mulf %[[ADD1]], %[[MUL1]] : f32
+//  UNROLL-BY-3-NEXT:     scf.yield %[[ADD2]], %[[MUL2]] : f32, f32
+//  UNROLL-BY-3-NEXT:   }
+//       UNROLL-BY-3:   %[[EFOR:.*]]:2 = scf.for %[[EIV:.*]] = %[[C18]] to %[[C20]] step %[[C1]]
+//  UNROLL-BY-3-SAME:     iter_args(%[[EARG0:.*]] = %[[FOR]]#0, %[[EARG1:.*]] = %[[FOR]]#1) -> (f32, f32) {
+//  UNROLL-BY-3-NEXT:     %[[EADD:.*]] = arith.addf %[[EARG0]], %[[EARG1]] : f32
+//  UNROLL-BY-3-NEXT:     %[[EMUL:.*]] = arith.mulf %[[EARG0]], %[[EARG1]] : f32
+//  UNROLL-BY-3-NEXT:     scf.yield %[[EADD]], %[[EMUL]] : f32, f32
+//  UNROLL-BY-3-NEXT:   }
+//  UNROLL-BY-3-NEXT:   return %[[EFOR]]#0, %[[EFOR]]#1 : f32, f32


        


More information about the Mlir-commits mailing list