[Mlir-commits] [mlir] 39ec46b - [mlir][bufferize] Extract buffer hoisting into separate function

Matthias Springer llvmlistbot at llvm.org
Tue Mar 15 05:31:53 PDT 2022


Author: Matthias Springer
Date: 2022-03-15T21:25:03+09:00
New Revision: 39ec46bd83703364d6d4da1e2ca3b09fa12d7a6b

URL: https://github.com/llvm/llvm-project/commit/39ec46bd83703364d6d4da1e2ca3b09fa12d7a6b
DIFF: https://github.com/llvm/llvm-project/commit/39ec46bd83703364d6d4da1e2ca3b09fa12d7a6b.diff

LOG: [mlir][bufferize] Extract buffer hoisting into separate function

This improves the modularity of the bufferization.

From now on, all ops that do not implement BufferizableOpInterface are considered hoisting barriers. Previously, ops that did not implement the interface were not considered barriers, so such ops had to be marked as barriers explicitly (e.g., via AllocationHoistingBarrierOnly). This was unsafe: allocations could be hoisted across unknown ops for which hoisting was not safe.

As a side effect, this allows for cleaning up AffineBufferizableOpInterfaceImpl. This build unit is no longer needed and has been deleted.

Differential Revision: https://reviews.llvm.org/D121519

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
    mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
    mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
    mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir
    mlir/test/Dialect/Linalg/bufferize.mlir
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize-alloca.mlir
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
    mlir/test/Dialect/Tensor/bufferize.mlir
    mlir/test/lib/Dialect/Linalg/CMakeLists.txt
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
    utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel

Removed: 
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.cpp


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 500f7e89b1758..6860bec2386ab 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -220,6 +220,14 @@ struct BufferizationOptions {
   /// computation. Whether this pays off or not can be very input IR-specific.
   bool alwaysAliasingWithDest = true;
 
+  /// If set to `true`, try to hoist allocations out of blocks as much as
+  /// possible. An allocation is not hoisted across allocation hoisting barriers
+  /// as indicated by `BufferizableOpInterface::isAllocationHoistingBarrier`.
+  ///
+  /// Examples of allocation hoisting barriers are parallel loops or ops where
+  /// SSA values cannot be captured from the outside.
+  bool hoistAllocations = true;
+
   /// Buffer alignment for new memory allocations.
   unsigned int bufferAlignment = 128;
 
@@ -495,8 +503,9 @@ LogicalResult createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer,
 LogicalResult createMemCpy(OpBuilder &b, Location loc, Value from, Value to,
                            const BufferizationOptions &options);
 
-/// Finalize all buffer allocations, i.e., create alloc ops as specified in the
-/// bufferization options and deallocate all buffers.
+/// Finalize all buffer allocations.
+/// * Hoist buffer allocations as much as possible.
+/// * Create alloc/dealloc ops as specified by the bufferization options.
 LogicalResult finalizeBuffers(Operation *op,
                               const BufferizationOptions &options);
 } // namespace bufferization
@@ -504,57 +513,4 @@ LogicalResult finalizeBuffers(Operation *op,
 
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h.inc"
 
-namespace mlir {
-namespace bufferization {
-
-/// AllocationHoistingBarrierOnly is an external implementation of
-/// BufferizableOpInterface for ops that are (not yet) bufferizable, but are
-/// known to be allocation hoisting barriers. All interface methods (except for
-/// `isAllocationHoistingBarrier`) are implemented conservatively.
-template <typename OpTy>
-struct AllocationHoistingBarrierOnly
-    : public BufferizableOpInterface::ExternalModel<
-          AllocationHoistingBarrierOnly<OpTy>, OpTy> {
-  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
-                              const AnalysisState &state) const {
-    return true;
-  }
-
-  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
-                               const AnalysisState &state) const {
-    return true;
-  }
-
-  SmallVector<OpOperand *>
-  getAliasingOpOperand(Operation *op, OpResult opResult,
-                       const AnalysisState &state) const {
-    return {};
-  }
-
-  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
-                                            const AnalysisState &state) const {
-    return {};
-  }
-
-  BufferRelation bufferRelation(Operation *op, OpResult opResult,
-                                const AnalysisState &state) const {
-    return BufferRelation::None;
-  }
-
-  bool isWritable(Operation *op, Value value,
-                  const AnalysisState &state) const {
-    return false;
-  }
-
-  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
-    return failure();
-  }
-
-  bool isAllocationHoistingBarrier(Operation *op) const { return true; }
-};
-
-} // namespace bufferization
-} // namespace mlir
-
 #endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZABLEOPINTERFACE_H_

diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h
deleted file mode 100644
index 877234b60b5a1..0000000000000
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h
+++ /dev/null
@@ -1,27 +0,0 @@
-//===- AffineInterfaceImpl.h - Affine Impl. of BufferizableOpInterface ----===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_AFFINEINTERFACEIMPL_H
-#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_AFFINEINTERFACEIMPL_H
-
-namespace mlir {
-
-class DialectRegistry;
-
-namespace linalg {
-namespace comprehensive_bufferize {
-namespace affine_ext {
-
-void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
-
-} // namespace affine_ext
-} // namespace comprehensive_bufferize
-} // namespace linalg
-} // namespace mlir
-
-#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_AFFINEINTERFACEIMPL_H

diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 8a3dbfc960e9b..cc697487b07fb 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -387,40 +387,9 @@ bufferization::createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer,
   return success();
 }
 
-/// Move the insertion point of the given builder to the beginning of a
-/// surrounding block as much as possible, while not crossing any allocation
-/// hoisting barriers.
-static void moveInsertionPointToAllocationHoistingBarrier(OpBuilder &b) {
-  Operation *op = b.getInsertionBlock()->getParentOp();
-  while (op) {
-    if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
-      if (bufferizableOp.isAllocationHoistingBarrier())
-        break;
-    op = op->getParentOp();
-  }
-
-  if (!op) {
-    // No allocation hoisting barrier found. Hoist to FuncOp.
-    op = b.getInsertionBlock()->getParentOp();
-    if (!isa<FuncOp>(op))
-      op = op->getParentOfType<FuncOp>();
-    assert(op && "could not find enclosing FuncOp");
-  }
-
-  // TODO: Handle cases where allocation hoisting barrier has more than one
-  // region or block.
-  assert(op->getNumRegions() == 1 &&
-         "allocation hoisting barriers with >1 regions not supported");
-  assert(op->getRegion(0).getBlocks().size() == 1 &&
-         "allocation hoisting barriers with >1 blocks not supported");
-  b.setInsertionPointToStart(&(op->getRegion(0).front()));
-}
-
 /// Compute the type of the `memref` to use for allocating the buffer for
 /// `shapedValue`. Also returns (by reference in `dynShape`), the value for the
-/// dynamic dimensions in the returned `memref` type. The function may also set
-/// the insertion point to an earlier location, where the allocation should
-/// happen ("allocation hoisting").
+/// dynamic dimensions in the returned `memref` type.
 static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
                                             Value shapedValue,
                                             SmallVectorImpl<Value> &dynShape) {
@@ -453,15 +422,6 @@ static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
       }
   }
 
-  // If the buffer is statically shaped, try to hoist it to the first enclosing
-  // parallel region.
-  // TODO: also hoist in the dynamic case. For now this relies on subsequent
-  // calls to LICM and buffer hoisting which will most likely not succeed.
-  // TODO: when packing, allocate a static bounding box which will enable more
-  // hoisting.
-  if (dynShape.empty())
-    moveInsertionPointToAllocationHoistingBarrier(b);
-
   return allocMemRefType;
 }
 
@@ -481,7 +441,6 @@ FailureOr<Value> BufferizationState::createAlloc(OpBuilder &b, Location loc,
   assert(shapedValue.getType().isa<ShapedType>());
   MemRefType memRefType = shapedValue.getType().dyn_cast<MemRefType>();
   SmallVector<Value> dynShape;
-  // Note: getAllocationTypeAndShape also sets the insertion point.
   MemRefType allocMemRefType =
       getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
   Value alloc = createBufferAllocation(b, loc, allocMemRefType, dynShape);
@@ -511,9 +470,8 @@ LogicalResult bufferization::createMemCpy(OpBuilder &b, Location loc,
   return success();
 }
 
-LogicalResult
-bufferization::finalizeBuffers(Operation *op,
-                               const BufferizationOptions &options) {
+static LogicalResult
+createAllocDeallocOps(Operation *op, const BufferizationOptions &options) {
   IRRewriter rewriter(op->getContext());
 
   // Bufferization creates memref.alloca ops. After bufferization, these must be
@@ -546,6 +504,73 @@ bufferization::finalizeBuffers(Operation *op,
   return success(!status.wasInterrupted());
 }
 
+/// Try to hoist all new buffer allocations until the next hoisting barrier.
+// TODO: Consolidate this function with the existing buffer hoisting pass.
+static LogicalResult
+hoistBufferAllocations(Operation *op, const BufferizationOptions &options) {
+  // Nothing to do if allocation hoisting is deactivated.
+  if (!options.hoistAllocations)
+    return success();
+
+  // Gather all buffer allocations that were created by the bufferization.
+  SmallVector<Operation *> allocaOps;
+  op->walk([&](memref::AllocaOp allocaOp) {
+    if (allocaOp->hasAttr(kBufferAllocationAttr))
+      allocaOps.push_back(allocaOp);
+  });
+
+  for (Operation *allocaOp : allocaOps) {
+    // TODO: Hoisting of allocs with dynamic shape not implemented.
+    if (!allocaOp->getOpOperands().empty())
+      continue;
+
+    Operation *op = allocaOp->getParentOp();
+    while (op) {
+      if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op)) {
+        if (bufferizableOp.isAllocationHoistingBarrier()) {
+          break;
+        }
+      } else {
+        // Op is not bufferizable: It may not be safe to hoist across this op.
+        break;
+      }
+      op = op->getParentOp();
+    }
+
+    // FuncOp is an allocation hoisting barrier, so this should never happen.
+    assert(op && "allocation hoisting barrier not found");
+
+    // Nothing to do if the insertion point is in the same block.
+    if (op == allocaOp->getParentOp())
+      continue;
+
+    // `op` may have multiple blocks. Make sure that we insert in the right one.
+    SmallVector<Block *> blocks;
+    for (Region &r : op->getRegions())
+      for (Block &b : r.getBlocks())
+        blocks.push_back(&b);
+    auto *insertionBlock = llvm::find_if(
+        blocks, [&](Block *b) { return b->findAncestorOpInBlock(*allocaOp); });
+    assert(insertionBlock != blocks.end() && "owning block not found");
+
+    // Move to the beginning of the block.
+    allocaOp->moveBefore(&(*insertionBlock)->front());
+  }
+
+  return success();
+}
+
+LogicalResult
+bufferization::finalizeBuffers(Operation *op,
+                               const BufferizationOptions &options) {
+  if (failed(hoistBufferAllocations(op, options)))
+    return failure();
+  if (failed(createAllocDeallocOps(op, options)))
+    return failure();
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Bufferization-specific BlockAndValueMapping support with debugging.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.cpp
deleted file mode 100644
index c1fad16126e40..0000000000000
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.cpp
+++ /dev/null
@@ -1,22 +0,0 @@
-//===- AffineInterfaceImpl.cpp - Affine Impl. of BufferizableOpInterface --===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h"
-
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
-
-using namespace mlir::bufferization;
-
-void mlir::linalg::comprehensive_bufferize::affine_ext::
-    registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
-  // AffineParallelOp bufferization not implemented yet. However, never hoist
-  // memref allocations across AffineParallelOp boundaries.
-  registry.addOpInterface<AffineParallelOp,
-                          AllocationHoistingBarrierOnly<AffineParallelOp>>();
-}

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
index 066204d8b65d2..92b473bd382ec 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
@@ -1,16 +1,3 @@
-set(LLVM_OPTIONAL_SOURCES
-  AffineInterfaceImpl.cpp
-  ModuleBufferization.cpp
-)
-
-add_mlir_dialect_library(MLIRAffineBufferizableOpInterfaceImpl
-  AffineInterfaceImpl.cpp
-
-  LINK_LIBS PUBLIC
-  MLIRAffine
-  MLIRBufferization
-)
-
 add_mlir_dialect_library(MLIRModuleBufferization
   ModuleBufferization.cpp
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index f530fe2b64676..7048a414aa829 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -32,7 +32,6 @@ add_mlir_dialect_library(MLIRLinalgTransforms
 
   LINK_LIBS PUBLIC
   MLIRAffine
-  MLIRAffineBufferizableOpInterfaceImpl
   MLIRAffineUtils
   MLIRAnalysis
   MLIRArithmetic

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index a7e86e11fbc3a..495ab974f1d16 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -13,7 +13,6 @@
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
 #include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
@@ -51,7 +50,6 @@ struct LinalgComprehensiveModuleBufferize
                 memref::MemRefDialect, tensor::TensorDialect,
                 vector::VectorDialect, scf::SCFDialect,
                 arith::ArithmeticDialect, func::FuncDialect, AffineDialect>();
-    affine_ext::registerBufferizableOpInterfaceExternalModels(registry);
     arith::registerBufferizableOpInterfaceExternalModels(registry);
     linalg::registerBufferizableOpInterfaceExternalModels(registry);
     scf::registerBufferizableOpInterfaceExternalModels(registry);

diff  --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index b15d3460fa105..d885216a4391f 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -469,6 +469,4 @@ void mlir::scf::registerBufferizableOpInterfaceExternalModels(
   registry.addOpInterface<ForOp, ForOpInterface>();
   registry.addOpInterface<IfOp, IfOpInterface>();
   registry.addOpInterface<YieldOp, YieldOpInterface>();
-  registry
-      .addOpInterface<ParallelOp, AllocationHoistingBarrierOnly<ParallelOp>>();
 }

diff  --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
index ca102a2dbddd9..8796f4cc378e7 100644
--- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -118,6 +118,11 @@ struct AssumingOpInterface
                                 const AnalysisState &state) const {
     return BufferRelation::Equivalent;
   }
+
+  bool isAllocationHoistingBarrier(Operation *op) const {
+    // Allocations should not be hoisted out of AssumingOps.
+    return true;
+  }
 };
 
 /// Bufferization of shape.assuming_yield. Bufferized as part of their enclosing

diff  --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir
index b49ef6edcfca2..100172836319d 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir
@@ -140,8 +140,8 @@ func @unknown_op_may_read(%v: vector<5xf32>)
 
   // One alloc for the init_tensor, another one because the transfer_write
   // bufferizes out-of-place.
-  // CHECK: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<10xf32>
   // CHECK: %[[m1:.*]] = memref.alloc() {{.*}} : memref<10xf32>
+  // CHECK: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<10xf32>
   %t1 = linalg.init_tensor [10] : tensor<10xf32>
 
   // CHECK: linalg.fill ins(%{{.*}}{{.*}}outs(%[[m1]]

diff  --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index a027112ee05c0..6776197a91fe3 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -140,9 +140,9 @@ func @dynamic_results(%arg0: tensor<?x?xf32>)
 // CHECK-LABEL:   func @generic_with_init_tensor(
 // CHECK-SAME:                                   %[[ARG0_TENSOR:.*]]: tensor<2x3x4xvector<3x4xi4>>,
 // CHECK-SAME:                                   %[[ARG1_TENSOR:.*]]: tensor<3x2xf32>) -> tensor<3x2xf32> {
-// CHECK:           %[[INIT_BUFFER:.*]] = memref.alloc() {{.*}} : memref<3x2xf32>
-// CHECK-DAG:           %[[ARG0_MEMREF:.*]] = bufferization.to_memref %[[ARG0_TENSOR]] : memref<2x3x4xvector<3x4xi4>>
-// CHECK-DAG:           %[[ARG1_MEMREF:.*]] = bufferization.to_memref %[[ARG1_TENSOR]] : memref<3x2xf32>
+// CHECK-DAG:       %[[INIT_BUFFER:.*]] = memref.alloc() {{.*}} : memref<3x2xf32>
+// CHECK-DAG:       %[[ARG0_MEMREF:.*]] = bufferization.to_memref %[[ARG0_TENSOR]] : memref<2x3x4xvector<3x4xi4>>
+// CHECK-DAG:       %[[ARG1_MEMREF:.*]] = bufferization.to_memref %[[ARG1_TENSOR]] : memref<3x2xf32>
 // CHECK:           memref.copy %[[ARG1_MEMREF]], %[[INIT_BUFFER]] : memref<3x2xf32> to memref<3x2xf32>
 // CHECK:           linalg.generic
 // CHECK-SAME:      ins(%[[ARG0_MEMREF]] : memref<2x3x4xvector<3x4xi4>>)

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-alloca.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-alloca.mlir
index 5aab59311305b..0520d579cc562 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-alloca.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-alloca.mlir
@@ -31,9 +31,9 @@ func @main() {
   %v1 = arith.constant 1.0 : f32
   %v2 = arith.constant 2.0 : f32
 
-  // CHECK-NEXT:   %[[C:.*]] = memref.alloca() {alignment = 128 : i64} : memref<f32>
-  // CHECK-NEXT:   %[[B:.*]] = memref.alloca() {alignment = 128 : i64} : memref<64xf32>
   // CHECK-NEXT:   %[[A:.*]] = memref.alloca() {alignment = 128 : i64} : memref<64xf32>
+  // CHECK-NEXT:   %[[B:.*]] = memref.alloca() {alignment = 128 : i64} : memref<64xf32>
+  // CHECK-NEXT:   %[[C:.*]] = memref.alloca() {alignment = 128 : i64} : memref<f32>
   //  CHECK-DAG:   %[[cA:.*]] = memref.cast %[[A]] : memref<64xf32> to memref<64xf32, #[[$DYN_1D_MAP]]>
   //  CHECK-DAG:   %[[cB:.*]] = memref.cast %[[B]] : memref<64xf32> to memref<64xf32, #[[$DYN_1D_MAP]]>
   //  CHECK-DAG:   %[[cC:.*]] = memref.cast %[[C]] : memref<f32> to memref<f32, #[[$DYN_0D_MAP]]>

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index 5d0087ff29910..53c3a603ca03d 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -601,9 +601,9 @@ func @main() {
   %v1 = arith.constant 1.0 : f32
   %v2 = arith.constant 2.0 : f32
 
-  // CHECK-NEXT:   %[[C:.*]] = memref.alloc() {alignment = 128 : i64} : memref<f32>
-  // CHECK-NEXT:   %[[B:.*]] = memref.alloc() {alignment = 128 : i64} : memref<64xf32>
   // CHECK-NEXT:   %[[A:.*]] = memref.alloc() {alignment = 128 : i64} : memref<64xf32>
+  // CHECK-NEXT:   %[[B:.*]] = memref.alloc() {alignment = 128 : i64} : memref<64xf32>
+  // CHECK-NEXT:   %[[C:.*]] = memref.alloc() {alignment = 128 : i64} : memref<f32>
   //  CHECK-DAG:   %[[cA:.*]] = memref.cast %[[A]] : memref<64xf32> to memref<64xf32, #[[$DYN_1D_MAP]]>
   //  CHECK-DAG:   %[[cB:.*]] = memref.cast %[[B]] : memref<64xf32> to memref<64xf32, #[[$DYN_1D_MAP]]>
   //  CHECK-DAG:   %[[cC:.*]] = memref.cast %[[C]] : memref<f32> to memref<f32, #[[$DYN_0D_MAP]]>

diff  --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index 7d3084d9d024c..cbb05473807b0 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -89,9 +89,9 @@ func @tensor.from_elements_0d(%arg0: index) -> tensor<index> {
 // CHECK-LABEL:   func @tensor.from_elements_1d(
 // CHECK-SAME:                               %[[ELEM0:.*]]: index,
 // CHECK-SAME:                               %[[ELEM1:.*]]: index) -> tensor<2xindex> {
-// CHECK:           %[[C0:.*]] = arith.constant 0 : index
-// CHECK:           %[[C1:.*]] = arith.constant 1 : index
-// CHECK:           %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<2xindex>
+// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<2xindex>
 // CHECK:           store %[[ELEM0]], %[[MEMREF]][%[[C0]]]
 // CHECK:           store %[[ELEM1]], %[[MEMREF]][%[[C1]]]
 // CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
@@ -107,7 +107,7 @@ func @tensor.from_elements_1d(%arg0: index, %arg1: index) -> tensor<2xindex> {
 // CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:     %[[C2:.*]] = arith.constant 2 : index
-// CHECK:         %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<3x2xindex>
+// CHECK-DAG:     %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<3x2xindex>
 // CHECK:         store %[[ELEM0]], %[[MEMREF]][%[[C0]], %[[C0]]]
 // CHECK:         store %[[ELEM1]], %[[MEMREF]][%[[C0]], %[[C1]]]
 // CHECK:         store %[[ELEM0]], %[[MEMREF]][%[[C1]], %[[C0]]]
@@ -141,7 +141,7 @@ func @tensor.from_elements_2d(%arg0: index, %arg1: index) -> tensor<3x2xindex> {
 // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
 
-// CHECK: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<3x2x2xf32>
+// CHECK-DAG: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<3x2x2xf32>
 
 // CHECK: store %[[F0]], %[[MEMREF]][%[[C0]], %[[C0]], %[[C0]]]
 // CHECK: store %[[F1]], %[[MEMREF]][%[[C0]], %[[C0]], %[[C1]]]
@@ -291,8 +291,8 @@ func @tensor.insert_slice(%t1: tensor<?x?xf32>, %t2: tensor<?x10xf32>,
 //  CHECK-SAME:     %[[t1:.*]]: tensor<5xf32>, %[[idx1:.*]]: index,
 //  CHECK-SAME:     %[[f:.*]]: f32
 func @tensor.insert(%t1: tensor<5xf32>, %idx1: index, %f: f32) -> tensor<5xf32> {
-  // CHECK: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<5xf32>
-  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<5xf32>
+  // CHECK-DAG: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<5xf32>
+  // CHECK-DAG: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<5xf32>
   // CHECK: memref.copy %[[m1]], %[[alloc]]
   // CHECK: memref.store %[[f]], %[[alloc]][%[[idx1]]]
   %0 = tensor.insert %f into %t1[%idx1] : tensor<5xf32>

diff  --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
index 1f2c63f402548..45ebc78c361c2 100644
--- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
@@ -11,7 +11,6 @@ add_mlir_library(MLIRLinalgTestPasses
 
   LINK_LIBS PUBLIC
   MLIRAffine
-  MLIRAffineBufferizableOpInterfaceImpl
   MLIRArithmetic
   MLIRArithmeticTransforms
   MLIRBufferization

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 32b5cce6fbafb..621a54b9e64d4 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6823,22 +6823,6 @@ gentbl_cc_library(
     ],
 )
 
-cc_library(
-    name = "AffineBufferizableOpInterfaceImpl",
-    srcs = [
-        "lib/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.cpp",
-    ],
-    hdrs = [
-        "include/mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h",
-    ],
-    includes = ["include"],
-    deps = [
-        ":Affine",
-        ":BufferizationDialect",
-        "//llvm:Support",
-    ],
-)
-
 td_library(
     name = "LinalgDocTdFiles",
     srcs = ["include/mlir/Dialect/Linalg/IR/LinalgDoc.td"],
@@ -7040,7 +7024,6 @@ cc_library(
     deps = [
         ":Affine",
         ":AffineAnalysis",
-        ":AffineBufferizableOpInterfaceImpl",
         ":AffineUtils",
         ":Analysis",
         ":ArithmeticDialect",

diff  --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
index 6f2cfee731df8..9d8ca48fe36a2 100644
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -390,7 +390,6 @@ cc_library(
     deps = [
         "//llvm:Support",
         "//mlir:Affine",
-        "//mlir:AffineBufferizableOpInterfaceImpl",
         "//mlir:ArithmeticDialect",
         "//mlir:ArithmeticTransforms",
         "//mlir:BufferizationDialect",


        


More information about the Mlir-commits mailing list