[Mlir-commits] [mlir] 3590840 - [mlir][scf] Canonicalize scf.for last tensor iteration result.

Nicolas Vasilache llvmlistbot at llvm.org
Fri Mar 5 01:52:50 PST 2021


Author: Nicolas Vasilache
Date: 2021-03-05T09:42:19Z
New Revision: 35908406dc69415de392600bfb93f15865135584

URL: https://github.com/llvm/llvm-project/commit/35908406dc69415de392600bfb93f15865135584
DIFF: https://github.com/llvm/llvm-project/commit/35908406dc69415de392600bfb93f15865135584.diff

LOG: [mlir][scf] Canonicalize scf.for last tensor iteration result.

Canonicalize the iter_args of an scf::ForOp that involve a tensor_load and
for which only the last loop iteration is actually visible outside of the
loop. The canonicalization looks for a pattern such as:
```
   %t0 = ... : tensor_type
   %0 = scf.for ... iter_args(%bb0 : %t0) -> (tensor_type) {
     ...
     // %m is either tensor_to_memref(%bb0) or defined above the loop
     %m... : memref_type
     ... // uses of %m with potential inplace updates
     %new_tensor = tensor_load %m : memref_type
     ...
     scf.yield %new_tensor : tensor_type
   }
```

`%bb0` may have either 0 or 1 use. If it has 1 use it must be exactly a
`%m = tensor_to_memref %bb0` op that feeds into the yielded `tensor_load`
op.

If no aliasing write to the memref `%m`, from which `%new_tensor` is loaded,
occurs between tensor_load and yield, then the value `%0` visible outside of
the loop is the last `tensor_load` produced in the loop.

For now, we approximate the absence of aliasing by only supporting the case
when the tensor_load is the operation immediately preceding the yield.

The canonicalization rewrites the pattern as:
```
   // %m is either tensor_to_memref(%t0) or defined above the loop
   %m... : memref_type
   scf.for ... { // no iter_args
     ... // uses of %m with potential inplace updates
   }
   %0 = tensor_load %m : memref_type
```

Differential revision: https://reviews.llvm.org/D97953

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/SCF.cpp
    mlir/test/Dialect/SCF/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index d0b6d9f9fb51..57315754a910 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -560,11 +560,155 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
     return failure();
   }
 };
+
+/// Canonicalize the iter_args of an scf::ForOp that involve a tensor_load and
+/// for which only the last loop iteration is actually visible outside of the
+/// loop. The canonicalization looks for a pattern such as:
+/// ```
+///    %t0 = ... : tensor_type
+///    %0 = scf.for ... iter_args(%bb0 : %t0) -> (tensor_type) {
+///      ...
+///      // %m is either tensor_to_memref(%bb0) or defined above the loop
+///      %m... : memref_type
+///      ... // uses of %m with potential inplace updates
+///      %new_tensor = tensor_load %m : memref_type
+///      ...
+///      scf.yield %new_tensor : tensor_type
+///    }
+/// ```
+///
+/// `%bb0` may have either 0 or 1 use. If it has 1 use it must be exactly a
+/// `%m = tensor_to_memref %bb0` op that feeds into the yielded `tensor_load`
+/// op. If it has no use, `%m` must be defined above the loop so that it is
+/// still in scope once the `tensor_load` is moved after the loop.
+///
+/// If no aliasing write to the memref `%m`, from which `%new_tensor` is
+/// loaded, occurs between tensor_load and yield then the value %0 visible
+/// outside of the loop is the last `tensor_load` produced in the loop.
+///
+/// For now, we approximate the absence of aliasing by only supporting the case
+/// when the tensor_load is the operation immediately preceding the yield.
+///
+/// The canonicalization rewrites the pattern as:
+/// ```
+///    // %m is either tensor_to_memref(%t0) or defined above the loop
+///    %m... : memref_type
+///    scf.for ... iter_args(%bb0 : %t0) -> (tensor_type) {
+///      ... // uses of %m with potential inplace updates
+///      scf.yield %bb0: tensor_type
+///    }
+///    %0 = tensor_load %m : memref_type
+/// ```
+///
+/// A later bbArg canonicalization will further rewrite as:
+/// ```
+///    // %m is either tensor_to_memref(%t0) or defined above the loop
+///    %m... : memref_type
+///    scf.for ... { // no iter_args
+///      ... // uses of %m with potential inplace updates
+///    }
+///    %0 = tensor_load %m : memref_type
+/// ```
+struct LastTensorLoadCanonicalization : public OpRewritePattern<ForOp> {
+  using OpRewritePattern<ForOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ForOp forOp,
+                                PatternRewriter &rewriter) const override {
+    assert(std::next(forOp.region().begin()) == forOp.region().end() &&
+           "unexpected multiple blocks");
+
+    Location loc = forOp.getLoc();
+    // Only the operands of the terminator are updated below, never the op
+    // itself, so it can be fetched once.
+    auto yieldOp = cast<scf::YieldOp>(forOp.region().front().getTerminator());
+    DenseMap<Value, Value> replacements;
+    for (BlockArgument bbArg : forOp.getRegionIterArgs()) {
+      // Region arguments are (iv, iter_args...): skip the induction variable.
+      unsigned idx = bbArg.getArgNumber() - /*numIv=*/1;
+      if (!bbArg.getType().isa<TensorType>())
+        continue;
+      auto tensorLoadOp =
+          yieldOp->getOperand(idx).getDefiningOp<TensorLoadOp>();
+      if (!tensorLoadOp)
+        continue;
+
+      // Either bbArg has no use or it has a single tensor_to_memref use.
+      TensorToMemrefOp tensorToMemRefOp;
+      if (bbArg.hasOneUse())
+        tensorToMemRefOp =
+            dyn_cast<TensorToMemrefOp>(*bbArg.getUsers().begin());
+      if (!bbArg.use_empty() && !tensorToMemRefOp)
+        continue;
+
+      Value loadedMemRef = tensorLoadOp.memref();
+      if (tensorToMemRefOp) {
+        // tensorToMemRefOp must feed into the `tensorLoadOp`.
+        if (loadedMemRef != tensorToMemRefOp.memref())
+          continue;
+      } else if (forOp.region().isAncestor(loadedMemRef.getParentRegion())) {
+        // Without tensorToMemRefOp, the memref must be defined above `forOp`:
+        // the tensor_load is recreated after the loop and must still see it.
+        continue;
+      }
+      // TODO: Any aliasing write of tensorLoadOp.memref() nested under `forOp`
+      // must be before `tensorLoadOp` in the block so that the lastWrite
+      // property is not subject to additional side-effects.
+      // For now, we only support the case when tensorLoadOp appears immediately
+      // before the terminator.
+      if (tensorLoadOp->getNextNode() != yieldOp)
+        continue;
+
+      // Hoist the optional tensorToMemRefOp before forOp. `bbArg` is not in
+      // scope there, so convert the matching init operand instead: once the
+      // yield below forwards `bbArg`, the iter_arg holds that init value on
+      // every iteration.
+      if (tensorToMemRefOp) {
+        rewriter.setInsertionPoint(forOp);
+        rewriter.replaceOpWithNewOp<TensorToMemrefOp>(
+            tensorToMemRefOp, tensorToMemRefOp.memref().getType(),
+            forOp.getIterOperands()[idx]);
+      }
+
+      // Clone the tensorLoad after forOp. Re-query the memref instead of using
+      // `loadedMemRef`: the hoisting above replaced and erased its producer.
+      rewriter.setInsertionPointAfter(forOp);
+      Value newTensorLoad =
+          rewriter.create<TensorLoadOp>(loc, tensorLoadOp.memref());
+      replacements.insert(std::make_pair(forOp.getResult(idx), newTensorLoad));
+
+      // Make the terminator just yield the bbArg, the old tensorLoadOp + the
+      // old bbArg (that is now directly yielded) will canonicalize away.
+      rewriter.startRootUpdate(yieldOp);
+      yieldOp->setOperand(idx, bbArg);
+      rewriter.finalizeRootUpdate(yieldOp);
+    }
+    if (replacements.empty())
+      return failure();
+
+    // We want to replace a subset of the results of `forOp`. rewriter.replaceOp
+    // replaces the whole op and erases it unconditionally. This is wrong for
+    // `forOp` as it generally contains ops with side effects.
+    // Instead, use `rewriter.replaceOpWithIf`.
+    SmallVector<Value> newResults;
+    newResults.reserve(forOp.getNumResults());
+    for (Value v : forOp.getResults()) {
+      Value replacement = replacements.lookup(v);
+      newResults.push_back(replacement ? replacement : v);
+    }
+    // The functor is invoked once per *use* of the results (not once per
+    // result), so decide from the used value itself rather than a counter.
+    rewriter.replaceOpWithIf(forOp, newResults, [&](OpOperand &use) {
+      return replacements.count(use.get()) != 0;
+    });
+    return success();
+  }
+};
 } // namespace
 
 void ForOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                         MLIRContext *context) {
-  results.insert<ForOpIterArgsFolder, SimplifyTrivialLoops>(context);
+  results.insert<ForOpIterArgsFolder, SimplifyTrivialLoops,
+                 LastTensorLoadCanonicalization>(context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index f0638d16105b..0d7c4eefae25 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1,4 +1,7 @@
-// RUN: mlir-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s
+// RUN: mlir-opt %s -pass-pipeline='func(canonicalize)' -split-input-file | FileCheck %s
+
+
+// -----
 
 func @single_iteration(%A: memref<?x?x?xi32>) {
   %c0 = constant 0 : index
@@ -143,6 +146,8 @@ func @for_yields_3(%lb : index, %ub : index, %step : index) -> (i32, i32, i32) {
 //  CHECK-NEXT:     }
 //  CHECK-NEXT:     return %[[a]], %[[r1]], %[[b]] : i32, i32, i32
 
+// -----
+
 // CHECK-LABEL: @replace_true_if
 func @replace_true_if() {
   %true = constant true
@@ -155,6 +160,8 @@ func @replace_true_if() {
   return
 }
 
+// -----
+
 // CHECK-LABEL: @remove_false_if
 func @remove_false_if() {
   %false = constant false
@@ -167,6 +174,8 @@ func @remove_false_if() {
   return
 }
 
+// -----
+
 // CHECK-LABEL: @replace_true_if_with_values
 func @replace_true_if_with_values() {
   %true = constant true
@@ -184,6 +193,8 @@ func @replace_true_if_with_values() {
   return
 }
 
+// -----
+
 // CHECK-LABEL: @replace_false_if_with_values
 func @replace_false_if_with_values() {
   %false = constant false
@@ -201,6 +212,8 @@ func @replace_false_if_with_values() {
   return
 }
 
+// -----
+
 // CHECK-LABEL: @remove_zero_iteration_loop
 func @remove_zero_iteration_loop() {
   %c42 = constant 42 : index
@@ -217,6 +230,8 @@ func @remove_zero_iteration_loop() {
   return
 }
 
+// -----
+
 // CHECK-LABEL: @remove_zero_iteration_loop_vals
 func @remove_zero_iteration_loop_vals(%arg0: index) {
   %c2 = constant 2 : index
@@ -233,6 +248,8 @@ func @remove_zero_iteration_loop_vals(%arg0: index) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: @replace_single_iteration_loop_1
 func @replace_single_iteration_loop_1() {
   // CHECK: %[[LB:.*]] = constant 42
@@ -252,6 +269,8 @@ func @replace_single_iteration_loop_1() {
   return
 }
 
+// -----
+
 // CHECK-LABEL: @replace_single_iteration_loop_2
 func @replace_single_iteration_loop_2() {
   // CHECK: %[[LB:.*]] = constant 5
@@ -271,6 +290,7 @@ func @replace_single_iteration_loop_2() {
   return
 }
 
+// -----
 
 // CHECK-LABEL: @replace_single_iteration_loop_non_unit_step
 func @replace_single_iteration_loop_non_unit_step() {
@@ -291,6 +311,8 @@ func @replace_single_iteration_loop_non_unit_step() {
   return
 }
 
+// -----
+
 // CHECK-LABEL: @remove_empty_parallel_loop
 func @remove_empty_parallel_loop(%lb: index, %ub: index, %s: index) {
   // CHECK: %[[INIT:.*]] = "test.init"
@@ -311,3 +333,52 @@ func @remove_empty_parallel_loop(%lb: index, %ub: index, %s: index) {
   "test.consume"(%0) : (f32) -> ()
   return
 }
+
+// -----
+func private @process(%0 : memref<128x128xf32>)
+func private @process_tensor(%0 : tensor<128x128xf32>) -> memref<128x128xf32>
+
+// CHECK-LABEL: last_value
+//  CHECK-SAME:   %[[T0:[0-9a-z]*]]: tensor<128x128xf32>
+//  CHECK-SAME:   %[[T1:[0-9a-z]*]]: tensor<128x128xf32>
+//  CHECK-SAME:   %[[T2:[0-9a-z]*]]: tensor<128x128xf32>
+//  CHECK-SAME:   %[[M0:[0-9a-z]*]]: memref<128x128xf32>
+func @last_value(%t0: tensor<128x128xf32>, %t1: tensor<128x128xf32>,
+                 %t2: tensor<128x128xf32>, %m0: memref<128x128xf32>,
+                 %lb : index, %ub : index, %step : index)
+  -> (tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>)
+{
+  // CHECK-NEXT: %[[M1:.*]] = tensor_to_memref %[[T1]] : memref<128x128xf32>
+  // CHECK-NEXT: %[[FOR_RES:.*]] = scf.for {{.*}} iter_args(%[[BBARG_T2:.*]] = %[[T2]]) -> (tensor<128x128xf32>) {
+  %0:3 = scf.for %arg0 = %lb to %ub step %step iter_args(%arg1 = %t0, %arg2 = %t1, %arg3 = %t2)
+    -> (tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>)
+  {
+    %m1 = tensor_to_memref %arg2 : memref<128x128xf32>
+
+    // CHECK-NEXT:   call @process(%[[M0]]) : (memref<128x128xf32>) -> ()
+    call @process(%m0) : (memref<128x128xf32>) -> ()
+
+    // CHECK-NEXT:   call @process(%[[M1]]) : (memref<128x128xf32>) -> ()
+    call @process(%m1) : (memref<128x128xf32>) -> ()
+
+    // This does not hoist: %arg3 has a single use but it is not a tensor_to_memref.
+    // CHECK-NEXT:   %[[T:.*]] = call @process_tensor(%[[BBARG_T2]]) : (tensor<128x128xf32>) -> memref<128x128xf32>
+    // CHECK-NEXT:   %[[YIELD_T:.*]] = tensor_load %[[T]]
+    %m2 = call @process_tensor(%arg3): (tensor<128x128xf32>) -> memref<128x128xf32>
+    %3 = tensor_load %m2 : memref<128x128xf32>
+
+    // All this stuff goes away, incrementally
+    %1 = tensor_load %m0 : memref<128x128xf32>
+    %2 = tensor_load %m1 : memref<128x128xf32>
+
+    // CHECK-NEXT:   scf.yield %[[YIELD_T]] : tensor<128x128xf32>
+    scf.yield %1, %2, %3 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
+
+  // CHECK-NEXT: }
+  }
+
+  // CHECK-NEXT: %[[R0:.*]] = tensor_load %[[M0]] : memref<128x128xf32>
+  // CHECK-NEXT: %[[R1:.*]] = tensor_load %[[M1]] : memref<128x128xf32>
+  // CHECK-NEXT: return %[[R0]], %[[R1]], %[[FOR_RES]] : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
+  return %0#0, %0#1, %0#2 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
+}


        


More information about the Mlir-commits mailing list