[llvm] [mlir] [mlir][SCF] Retire SCF-specific `to_memref`/`to_tensor` canonicalization patterns (PR #74551)

Matthias Springer via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 5 19:46:46 PST 2023


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/74551

>From a00b999e3b46dbb184c82c773b7455c071c2ec38 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Wed, 6 Dec 2023 12:45:51 +0900
Subject: [PATCH] [mlir][SCF] Retire SCF-specific `to_memref`/`to_tensor`
 canonicalization patterns

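The partial bufferization framework has been replaced with One-Shot Bufferize, so the SCF-specific canonicalization pattern for `to_memref`/`to_tensor` (`LastTensorLoadCanonicalization` on `scf.for`) is no longer needed. This patch removes the pattern and its test, and drops the SCF dialect's dependency on the Bufferization dialect.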
---
 mlir/lib/Dialect/SCF/IR/CMakeLists.txt        |   3 +-
 mlir/lib/Dialect/SCF/IR/SCF.cpp               | 132 +-----------------
 mlir/test/Dialect/SCF/canonicalize.mlir       |  50 -------
 .../llvm-project-overlay/mlir/BUILD.bazel     |   1 -
 4 files changed, 4 insertions(+), 182 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
index 9882b843c285e..423e1c3e1e042 100644
--- a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
@@ -11,12 +11,13 @@ add_mlir_dialect_library(MLIRSCFDialect
 
   LINK_LIBS PUBLIC
   MLIRArithDialect
-  MLIRBufferizationDialect
   MLIRControlFlowDialect
+  MLIRDialectUtils
   MLIRFunctionInterfaces
   MLIRIR
   MLIRLoopLikeInterface
   MLIRSideEffectInterfaces
+  MLIRTensorDialect
   MLIRValueBoundsOpInterface
   )
 
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 3b55704c4ea07..cf807a2adc10e 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -9,7 +9,6 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
-#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
@@ -1082,139 +1081,12 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
   }
 };
 
-/// Canonicalize the iter_args of an scf::ForOp that involve a
-/// `bufferization.to_tensor` and for which only the last loop iteration is
-/// actually visible outside of the loop. The canonicalization looks for a
-/// pattern such as:
-/// ```
-///    %t0 = ... : tensor_type
-///    %0 = scf.for ... iter_args(%bb0 : %t0) -> (tensor_type) {
-///      ...
-///      // %m is either buffer_cast(%bb00) or defined above the loop
-///      %m... : memref_type
-///      ... // uses of %m with potential inplace updates
-///      %new_tensor = bufferization.to_tensor %m : memref_type
-///      ...
-///      scf.yield %new_tensor : tensor_type
-///    }
-/// ```
-///
-/// `%bb0` may have either 0 or 1 use. If it has 1 use it must be exactly a
-/// `%m = buffer_cast %bb0` op that feeds into the yielded
-/// `bufferization.to_tensor` op.
-///
-/// If no aliasing write to the memref `%m`, from which `%new_tensor`is loaded,
-/// occurs between `bufferization.to_tensor and yield then the value %0
-/// visible outside of the loop is the last `bufferization.to_tensor`
-/// produced in the loop.
-///
-/// For now, we approximate the absence of aliasing by only supporting the case
-/// when the bufferization.to_tensor is the operation immediately preceding
-/// the yield.
-//
-/// The canonicalization rewrites the pattern as:
-/// ```
-///    // %m is either a buffer_cast or defined above
-///    %m... : memref_type
-///    scf.for ... iter_args(%bb0 : %t0) -> (tensor_type) {
-///      ... // uses of %m with potential inplace updates
-///      scf.yield %bb0: tensor_type
-///    }
-///    %0 = bufferization.to_tensor %m : memref_type
-/// ```
-///
-/// A later bbArg canonicalization will further rewrite as:
-/// ```
-///    // %m is either a buffer_cast or defined above
-///    %m... : memref_type
-///    scf.for ... { // no iter_args
-///      ... // uses of %m with potential inplace updates
-///    }
-///    %0 = bufferization.to_tensor %m : memref_type
-/// ```
-struct LastTensorLoadCanonicalization : public OpRewritePattern<ForOp> {
-  using OpRewritePattern<ForOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ForOp forOp,
-                                PatternRewriter &rewriter) const override {
-    assert(std::next(forOp.getRegion().begin()) == forOp.getRegion().end() &&
-           "unexpected multiple blocks");
-
-    Location loc = forOp.getLoc();
-    DenseMap<Value, Value> replacements;
-    for (BlockArgument bbArg : forOp.getRegionIterArgs()) {
-      unsigned idx = bbArg.getArgNumber() - /*numIv=*/1;
-      auto yieldOp =
-          cast<scf::YieldOp>(forOp.getRegion().front().getTerminator());
-      Value yieldVal = yieldOp->getOperand(idx);
-      auto tensorLoadOp = yieldVal.getDefiningOp<bufferization::ToTensorOp>();
-      bool isTensor = llvm::isa<TensorType>(bbArg.getType());
-
-      bufferization::ToMemrefOp tensorToMemref;
-      // Either bbArg has no use or it has a single buffer_cast use.
-      if (bbArg.hasOneUse())
-        tensorToMemref =
-            dyn_cast<bufferization::ToMemrefOp>(*bbArg.getUsers().begin());
-      if (!isTensor || !tensorLoadOp || (!bbArg.use_empty() && !tensorToMemref))
-        continue;
-      // If tensorToMemref is present, it must feed into the `ToTensorOp`.
-      if (tensorToMemref && tensorLoadOp.getMemref() != tensorToMemref)
-        continue;
-      // TODO: Any aliasing write of tensorLoadOp.memref() nested under `forOp`
-      // must be before `ToTensorOp` in the block so that the lastWrite
-      // property is not subject to additional side-effects.
-      // For now, we only support the case when ToTensorOp appears
-      // immediately before the terminator.
-      if (tensorLoadOp->getNextNode() != yieldOp)
-        continue;
-
-      // Clone the optional tensorToMemref before forOp.
-      if (tensorToMemref) {
-        rewriter.setInsertionPoint(forOp);
-        rewriter.replaceOpWithNewOp<bufferization::ToMemrefOp>(
-            tensorToMemref, tensorToMemref.getMemref().getType(),
-            tensorToMemref.getTensor());
-      }
-
-      // Clone the tensorLoad after forOp.
-      rewriter.setInsertionPointAfter(forOp);
-      Value newTensorLoad = rewriter.create<bufferization::ToTensorOp>(
-          loc, tensorLoadOp.getMemref());
-      Value forOpResult = forOp.getResult(bbArg.getArgNumber() - /*iv=*/1);
-      replacements.insert(std::make_pair(forOpResult, newTensorLoad));
-
-      // Make the terminator just yield the bbArg, the old tensorLoadOp + the
-      // old bbArg (that is now directly yielded) will canonicalize away.
-      rewriter.startRootUpdate(yieldOp);
-      yieldOp.setOperand(idx, bbArg);
-      rewriter.finalizeRootUpdate(yieldOp);
-    }
-    if (replacements.empty())
-      return failure();
-
-    // We want to replace a subset of the results of `forOp`. rewriter.replaceOp
-    // replaces the whole op and erase it unconditionally. This is wrong for
-    // `forOp` as it generally contains ops with side effects.
-    // Instead, use `rewriter.replaceOpWithIf`.
-    SmallVector<Value> newResults;
-    newResults.reserve(forOp.getNumResults());
-    for (Value v : forOp.getResults()) {
-      auto it = replacements.find(v);
-      newResults.push_back((it != replacements.end()) ? it->second : v);
-    }
-    unsigned idx = 0;
-    rewriter.replaceOpWithIf(forOp, newResults, [&](OpOperand &op) {
-      return op.get() != newResults[idx++];
-    });
-    return success();
-  }
-};
 } // namespace
 
 void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
-  results.add<ForOpIterArgsFolder, SimplifyTrivialLoops,
-              LastTensorLoadCanonicalization, ForOpTensorCastFolder>(context);
+  results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
+      context);
 }
 
 std::optional<APInt> ForOp::getConstantStep() {
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 9dbf8d5dab11a..41e028028616a 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -773,56 +773,6 @@ func.func @remove_empty_parallel_loop(%lb: index, %ub: index, %s: index) {
 
 // -----
 
-func.func private @process(%0 : memref<128x128xf32>)
-func.func private @process_tensor(%0 : tensor<128x128xf32>) -> memref<128x128xf32>
-
-// CHECK-LABEL: last_value
-//  CHECK-SAME:   %[[T0:[0-9a-z]*]]: tensor<128x128xf32>
-//  CHECK-SAME:   %[[T1:[0-9a-z]*]]: tensor<128x128xf32>
-//  CHECK-SAME:   %[[T2:[0-9a-z]*]]: tensor<128x128xf32>
-//  CHECK-SAME:   %[[M0:[0-9a-z]*]]: memref<128x128xf32>
-func.func @last_value(%t0: tensor<128x128xf32>, %t1: tensor<128x128xf32>,
-                 %t2: tensor<128x128xf32>, %m0: memref<128x128xf32>,
-                 %lb : index, %ub : index, %step : index)
-  -> (tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>)
-{
-  // CHECK-NEXT: %[[M1:.*]] = bufferization.to_memref %[[T1]] : memref<128x128xf32>
-  // CHECK-NEXT: %[[FOR_RES:.*]] = scf.for {{.*}} iter_args(%[[BBARG_T2:.*]] = %[[T2]]) -> (tensor<128x128xf32>) {
-  %0:3 = scf.for %arg0 = %lb to %ub step %step iter_args(%arg1 = %t0, %arg2 = %t1, %arg3 = %t2)
-    -> (tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>)
-  {
-    %m1 = bufferization.to_memref %arg2 : memref<128x128xf32>
-
-    // CHECK-NEXT:   call @process(%[[M0]]) : (memref<128x128xf32>) -> ()
-    func.call @process(%m0) : (memref<128x128xf32>) -> ()
-
-    // CHECK-NEXT:   call @process(%[[M1]]) : (memref<128x128xf32>) -> ()
-    func.call @process(%m1) : (memref<128x128xf32>) -> ()
-
-    // This does not hoist (fails the bbArg has at most a single check).
-    // CHECK-NEXT:   %[[T:.*]] = func.call @process_tensor(%[[BBARG_T2]]) : (tensor<128x128xf32>) -> memref<128x128xf32>
-    // CHECK-NEXT:   %[[YIELD_T:.*]] = bufferization.to_tensor %[[T:.*]]
-    %m2 = func.call @process_tensor(%arg3): (tensor<128x128xf32>) -> memref<128x128xf32>
-    %3 = bufferization.to_tensor %m2 : memref<128x128xf32>
-
-    // All this stuff goes away, incrementally
-    %1 = bufferization.to_tensor %m0 : memref<128x128xf32>
-    %2 = bufferization.to_tensor %m1 : memref<128x128xf32>
-
-    // CHECK-NEXT:   scf.yield %[[YIELD_T]] : tensor<128x128xf32>
-    scf.yield %1, %2, %3 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
-
-  // CHECK-NEXT: }
-  }
-
-  // CHECK-NEXT: %[[R0:.*]] = bufferization.to_tensor %[[M0]] : memref<128x128xf32>
-  // CHECK-NEXT: %[[R1:.*]] = bufferization.to_tensor %[[M1]] : memref<128x128xf32>
-  // CHECK-NEXT: return %[[R0]], %[[R1]], %[[FOR_RES]] : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
-  return %0#0, %0#1, %0#2 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
-}
-
-// -----
-
 // CHECK-LABEL: fold_away_iter_with_no_use_and_yielded_input
 //  CHECK-SAME:   %[[A0:[0-9a-z]*]]: i32
 func.func @fold_away_iter_with_no_use_and_yielded_input(%arg0 : i32,
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 4fb6a50a174c2..2a3ebbba02384 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -3994,7 +3994,6 @@ cc_library(
     deps = [
         ":ArithDialect",
         ":ArithUtils",
-        ":BufferizationDialect",
         ":ControlFlowDialect",
         ":ControlFlowInterfaces",
         ":DestinationStyleOpInterface",



More information about the llvm-commits mailing list