[Mlir-commits] [mlir] b34ea97 - [mlir][linalg][bufferize][NFC] Remove remaining Comprehensive Bufferize code
Matthias Springer
llvmlistbot at llvm.org
Wed May 4 01:20:32 PDT 2022
Author: Matthias Springer
Date: 2022-05-04T17:19:44+09:00
New Revision: b34ea97f557165011e21ecd934d23f3f8461ffdb
URL: https://github.com/llvm/llvm-project/commit/b34ea97f557165011e21ecd934d23f3f8461ffdb
DIFF: https://github.com/llvm/llvm-project/commit/b34ea97f557165011e21ecd934d23f3f8461ffdb.diff
LOG: [mlir][linalg][bufferize][NFC] Remove remaining Comprehensive Bufferize code
This commit removes the LinalgComprehensiveModuleBufferize pass, which has been superseded by One-Shot Bufferize. The init_tensor elimination pass that shared its implementation file moves into its own InitTensorElimination.cpp, and the remaining tests are ported to -one-shot-bufferize.
Differential Revision: https://reviews.llvm.org/D124854
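Editor's note: for downstream users of the removed pass, the replacement is One-Shot Bufferize with function boundary bufferization enabled. A minimal migration sketch based on the RUN-line updates in the diff below (input.mlir is a placeholder; the exact set of additional flags depends on your pipeline):

  # Before (pass removed by this commit):
  mlir-opt input.mlir -linalg-comprehensive-module-bufferize="allow-return-allocs"

  # After:
  mlir-opt input.mlir -one-shot-bufferize="bufferize-function-boundaries allow-return-allocs"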
Added:
mlir/lib/Dialect/Linalg/Transforms/InitTensorElimination.cpp
mlir/test/Dialect/Linalg/one-shot-bufferize-aliasing-in.mlir
mlir/test/Dialect/Linalg/one-shot-bufferize-analysis-aliasing-in.mlir
mlir/test/Integration/Dialect/Linalg/CPU/test-one-shot-bufferize.mlir
Modified:
mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
mlir/include/mlir/Dialect/Linalg/Passes.h
mlir/include/mlir/Dialect/Linalg/Passes.td
mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
Removed:
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
mlir/test/Dialect/Linalg/comprehensive-module-bufferize-aliasing-in.mlir
mlir/test/Dialect/Linalg/comprehensive-module-bufferize-alloca.mlir
mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis-aliasing-in.mlir
mlir/test/Integration/Dialect/Linalg/CPU/test-comprehensive-bufferize.mlir
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index 0c7a2b86e75f4..37d4a3e7fb529 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -236,6 +236,10 @@ def OneShotBufferize : Pass<"one-shot-bufferize", "ModuleOp"> {
Option<"allowUnknownOps", "allow-unknown-ops", "bool",
/*default=*/"false",
"Allows unknown (not bufferizable) ops in the input IR.">,
+ Option<"alwaysAliasingWithDest", "always-aliasing-with-dest", "bool",
+ /*default=*/"true",
+ "Tensor OpResult cannot bufferize inplace OpOperands other than "
+ "out/dest OpOperands (if the op has such operands; experimental)">,
Option<"analysisFuzzerSeed", "analysis-fuzzer-seed", "unsigned",
/*default=*/"0",
"Test only: Analyze ops in random order with a given seed (fuzzer)">,
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index 3510b2f1f984a..93d01a738dd88 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -62,17 +62,6 @@ createConvertLinalgToParallelLoopsPass();
std::unique_ptr<OperationPass<func::FuncOp>>
createConvertLinalgToAffineLoopsPass();
-/// This pass implements a cross-dialect bufferization approach and performs an
-/// analysis to determine which op operands and results may be bufferized in the
-/// same buffers. The analysis is performed on topologically sorted CallOp and
-/// FuncOp within a module. It provides analyses and bufferization across
-/// function boundaries. Within a function boundary, the analysis is performed
-/// on SSA use-def chains starting from function operands that are annotated
-/// with the 'inplaceable' attribute.
-std::unique_ptr<Pass> createLinalgComprehensiveModuleBufferizePass();
-std::unique_ptr<Pass> createLinalgComprehensiveModuleBufferizePass(
- const bufferization::OneShotBufferizationOptions &options);
-
/// Create a pass that tries to eliminate init_tensor ops that are anchored on
/// insert_slice ops.
std::unique_ptr<Pass> createLinalgInitTensorEliminationPass();
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 2c0287de0fcac..da48c651597e5 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -24,51 +24,6 @@ def ConvertElementwiseToLinalg : Pass<"convert-elementwise-to-linalg", ""> {
let dependentDialects = ["linalg::LinalgDialect", "memref::MemRefDialect"];
}
-def LinalgComprehensiveModuleBufferize :
- Pass<"linalg-comprehensive-module-bufferize", "ModuleOp"> {
- let summary = "Bufferize (tensor into memref) for a Module.";
- let description = [{
- This pass implements a cross-dialect bufferization approach and performs an
- analysis to determine which op operands and results may be bufferized in the
- same buffers. The analysis is performed on topologically sorted CallOp and
- FuncOp within a module. It provides analyses and bufferization across
- function boundaries. Within a function boundary, the analysis is performed
- on SSA use-def chains starting from function operands that are annotated
- with the 'inplaceable' attribute.
- }];
- let options = [
- Option<"testAnalysisOnly", "test-analysis-only", "bool",
- /*default=*/"false",
- "Only runs inplaceability analysis (for testing purposes only)">,
- Option<"printConflicts", "print-conflicts", "bool",
- /*default=*/"false",
- "Annotates IR with RaW conflicts. Requires test-analysis-only.">,
- Option<"allowReturnAllocs", "allow-return-allocs", "bool",
- /*default=*/"false",
- "Allows returning/yielding new allocations from a block.">,
- Option<"allowUnknownOps", "allow-unknown-ops", "bool",
- /*default=*/"false",
- "Allows unknown (not bufferizable) ops in the input IR.">,
- Option<"alwaysAliasingWithDest", "always-aliasing-with-dest", "bool",
- /*default=*/"true",
- "Tensor OpResult cannot bufferize inplace OpOperands other than "
- "out or dest OpOperands (if the op has a notion of such operands)">,
- Option<"useAlloca", "use-alloca", "bool",
- /*default=*/"false",
- "Use stack allocations for memrefs (for testing purposes only)">,
- Option<"fullyDynamicLayoutMaps", "fully-dynamic-layout-maps", "bool",
- /*default=*/"true",
- "Generate MemRef types with dynamic offset+strides by default.">,
- Option<"analysisFuzzerSeed", "analysis-fuzzer-seed", "unsigned",
- /*default=*/"0",
- "Analyze ops in random order with a given seed (fuzzer)">,
- Option<"createDeallocs", "create-deallocs", "bool", /*default=*/"true",
- "Specify if buffers should be deallocated. For compatibility with "
- "core bufferization passes.">,
- ];
- let constructor = "mlir::createLinalgComprehensiveModuleBufferizePass()";
-}
-
def LinalgInitTensorElimination : Pass<"linalg-eliminate-init-tensors"> {
let summary = "Try to eliminate all init_tensor ops.";
let description = [{
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 29026eb95b62c..a16b19148bb8b 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -171,6 +171,7 @@ struct OneShotBufferizePass
// pass.
opt.allowReturnAllocs = allowReturnAllocs;
opt.allowUnknownOps = allowUnknownOps;
+ opt.alwaysAliasingWithDest = alwaysAliasingWithDest;
opt.analysisFuzzerSeed = analysisFuzzerSeed;
opt.createDeallocs = createDeallocs;
opt.fullyDynamicLayoutMaps = fullyDynamicLayoutMaps;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 943efc5a228f5..2955391973688 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -3,7 +3,6 @@ add_mlir_dialect_library(MLIRLinalgTransforms
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
CodegenStrategy.cpp
- ComprehensiveBufferizePass.cpp
ConstantFold.cpp
Detensorize.cpp
DropUnitDims.cpp
@@ -14,6 +13,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
Generalization.cpp
Hoisting.cpp
HoistPadding.cpp
+ InitTensorElimination.cpp
InlineScalarOperands.cpp
Interchange.cpp
Loops.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
deleted file mode 100644
index bbb013d955332..0000000000000
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ /dev/null
@@ -1,161 +0,0 @@
-//===- ComprehensiveBufferize.cpp - Single pass bufferization -------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "PassDetail.h"
-
-#include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h"
-#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
-#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
-#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
-#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
-#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
-#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Linalg/Passes.h"
-#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
-#include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h"
-#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
-#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassManager.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "mlir/Transforms/Passes.h"
-
-using namespace mlir;
-using namespace mlir::bufferization;
-using namespace mlir::linalg;
-
-namespace {
-struct LinalgComprehensiveModuleBufferize
- : public LinalgComprehensiveModuleBufferizeBase<
- LinalgComprehensiveModuleBufferize> {
- LinalgComprehensiveModuleBufferize() = default;
-
- LinalgComprehensiveModuleBufferize(
- const LinalgComprehensiveModuleBufferize &p) = default;
-
- explicit LinalgComprehensiveModuleBufferize(
- const OneShotBufferizationOptions &options)
- : options(options) {}
-
- void runOnOperation() override;
-
- void getDependentDialects(DialectRegistry &registry) const override {
- registry
- .insert<bufferization::BufferizationDialect, linalg::LinalgDialect,
- memref::MemRefDialect, tensor::TensorDialect,
- vector::VectorDialect, scf::SCFDialect,
- arith::ArithmeticDialect, func::FuncDialect, AffineDialect>();
- arith::registerBufferizableOpInterfaceExternalModels(registry);
- bufferization::registerAllocationOpInterfaceExternalModels(registry);
- linalg::registerBufferizableOpInterfaceExternalModels(registry);
- scf::registerBufferizableOpInterfaceExternalModels(registry);
- func_ext::registerBufferizableOpInterfaceExternalModels(registry);
- tensor::registerBufferizableOpInterfaceExternalModels(registry);
- vector::registerBufferizableOpInterfaceExternalModels(registry);
- }
-
-private:
- llvm::Optional<OneShotBufferizationOptions> options;
-};
-
-struct LinalgInitTensorElimination
- : public LinalgInitTensorEliminationBase<LinalgInitTensorElimination> {
- LinalgInitTensorElimination() = default;
-
- void runOnOperation() override;
-
- void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<linalg::LinalgDialect, tensor::TensorDialect>();
- }
-};
-} // namespace
-
-static void applyEnablingTransformations(ModuleOp moduleOp) {
- RewritePatternSet patterns(moduleOp.getContext());
- patterns.add<GeneralizePadOpPattern>(moduleOp.getContext());
- (void)applyPatternsAndFoldGreedily(moduleOp, std::move(patterns));
-}
-
-static FailureOr<Value> allocationFnUsingAlloca(OpBuilder &b, Location loc,
- MemRefType type,
- ValueRange dynShape,
- unsigned int bufferAlignment) {
- Value allocated = b.create<memref::AllocaOp>(
- loc, type, dynShape, b.getI64IntegerAttr(bufferAlignment));
- return allocated;
-}
-
-void LinalgComprehensiveModuleBufferize::runOnOperation() {
- OneShotBufferizationOptions opt;
- if (!options) {
- // Make new bufferization options if none were provided when creating the
- // pass.
- if (useAlloca) {
- opt.allocationFn = allocationFnUsingAlloca;
- opt.deallocationFn = [](OpBuilder &b, Location loc, Value v) {
- return success();
- };
- }
- opt.allowReturnAllocs = allowReturnAllocs;
- opt.allowUnknownOps = allowUnknownOps;
- opt.analysisFuzzerSeed = analysisFuzzerSeed;
- opt.createDeallocs = createDeallocs;
- opt.fullyDynamicLayoutMaps = fullyDynamicLayoutMaps;
- opt.printConflicts = printConflicts;
- opt.testAnalysisOnly = testAnalysisOnly;
- opt.alwaysAliasingWithDest = alwaysAliasingWithDest;
- opt.bufferizeFunctionBoundaries = true;
- } else {
- opt = *options;
- }
-
- ModuleOp moduleOp = getOperation();
- applyEnablingTransformations(moduleOp);
-
- if (failed(runOneShotModuleBufferize(moduleOp, opt))) {
- signalPassFailure();
- return;
- }
-
- if (opt.testAnalysisOnly)
- return;
-
- OpPassManager cleanupPipeline("builtin.module");
- cleanupPipeline.addPass(createCanonicalizerPass());
- cleanupPipeline.addPass(createCSEPass());
- cleanupPipeline.addPass(createLoopInvariantCodeMotionPass());
- (void)runPipeline(cleanupPipeline, moduleOp);
-}
-
-void LinalgInitTensorElimination::runOnOperation() {
- Operation *op = getOperation();
- OneShotBufferizationOptions options;
- OneShotAnalysisState state(op, options);
- if (failed(analyzeOp(op, state))) {
- signalPassFailure();
- return;
- }
-
- IRRewriter rewriter(op->getContext());
- if (failed(insertSliceAnchoredInitTensorEliminationStep(rewriter, op, state)))
- signalPassFailure();
-}
-
-std::unique_ptr<Pass> mlir::createLinalgComprehensiveModuleBufferizePass() {
- return std::make_unique<LinalgComprehensiveModuleBufferize>();
-}
-
-std::unique_ptr<Pass> mlir::createLinalgComprehensiveModuleBufferizePass(
- const OneShotBufferizationOptions &options) {
- return std::make_unique<LinalgComprehensiveModuleBufferize>(options);
-}
-
-std::unique_ptr<Pass> mlir::createLinalgInitTensorEliminationPass() {
- return std::make_unique<LinalgInitTensorElimination>();
-}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/InitTensorElimination.cpp b/mlir/lib/Dialect/Linalg/Transforms/InitTensorElimination.cpp
new file mode 100644
index 0000000000000..f48f9c83e2cd5
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/InitTensorElimination.cpp
@@ -0,0 +1,50 @@
+//===- ComprehensiveBufferize.cpp - Single pass bufferization -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+using namespace mlir::bufferization;
+using namespace mlir::linalg;
+
+namespace {
+struct LinalgInitTensorElimination
+ : public LinalgInitTensorEliminationBase<LinalgInitTensorElimination> {
+ LinalgInitTensorElimination() = default;
+
+ void runOnOperation() override;
+
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<linalg::LinalgDialect, tensor::TensorDialect>();
+ }
+};
+} // namespace
+
+void LinalgInitTensorElimination::runOnOperation() {
+ Operation *op = getOperation();
+ OneShotBufferizationOptions options;
+ OneShotAnalysisState state(op, options);
+ if (failed(analyzeOp(op, state))) {
+ signalPassFailure();
+ return;
+ }
+
+ IRRewriter rewriter(op->getContext());
+ if (failed(insertSliceAnchoredInitTensorEliminationStep(rewriter, op, state)))
+ signalPassFailure();
+}
+
+std::unique_ptr<Pass> mlir::createLinalgInitTensorEliminationPass() {
+ return std::make_unique<LinalgInitTensorElimination>();
+}
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-alloca.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-alloca.mlir
deleted file mode 100644
index 88613a29b1cc8..0000000000000
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-alloca.mlir
+++ /dev/null
@@ -1,65 +0,0 @@
-// RUN: mlir-opt %s -pass-pipeline="linalg-comprehensive-module-bufferize{allow-return-allocs use-alloca}" -split-input-file | FileCheck %s
-
-// CHECK-DAG: #[[$DYN_0D_MAP:.*]] = affine_map<()[s0] -> (s0)>
-// CHECK-DAG: #[[$DYN_1D_MAP:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
-
-// CHECK: func @init_and_dot(
-// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref<64xf32, #[[$DYN_1D_MAP]]>
-// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: memref<64xf32, #[[$DYN_1D_MAP]]>
-// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: memref<f32, #[[$DYN_0D_MAP]]>
-func.func @init_and_dot(%a: tensor<64xf32>, %b: tensor<64xf32>, %c: tensor<f32>) -> tensor<f32> {
- // CHECK-NEXT: %[[C0:.*]] = arith.constant 0{{.*}} : f32
- %v0 = arith.constant 0.0 : f32
-
- // CHECK-NEXT: linalg.fill ins(%[[C0]] : f32) outs(%[[C]] : memref<f32, #[[$DYN_0D_MAP]]>)
- %d = linalg.fill ins(%v0 : f32) outs(%c : tensor<f32>) -> tensor<f32>
-
- // CHECK-NEXT: linalg.dot ins(%[[A]], %[[B]] : memref<64xf32, #[[$DYN_1D_MAP]]>, memref<64xf32, #[[$DYN_1D_MAP]]>) outs(%[[C]] : memref<f32, #[[$DYN_0D_MAP]]>)
- %e = linalg.dot ins(%a, %b : tensor<64xf32>,tensor<64xf32>)
- outs(%d: tensor<f32>) -> tensor<f32>
-
- // CHECK-NEXT: return
- return %e : tensor<f32>
-}
-
-// CHECK: func @main()
-func.func @main() {
- // CHECK-DAG: %[[C0:.*]] = arith.constant 0{{.*}} : f32
- // CHECK-DAG: %[[C1:.*]] = arith.constant 1{{.*}} : f32
- // CHECK-DAG: %[[C2:.*]] = arith.constant 2{{.*}} : f32
- %v0 = arith.constant 0.0 : f32
- %v1 = arith.constant 1.0 : f32
- %v2 = arith.constant 2.0 : f32
-
- // CHECK-NEXT: %[[A:.*]] = memref.alloca() {alignment = 128 : i64} : memref<64xf32>
- // CHECK-NEXT: %[[B:.*]] = memref.alloca() {alignment = 128 : i64} : memref<64xf32>
- // CHECK-NEXT: %[[C:.*]] = memref.alloca() {alignment = 128 : i64} : memref<f32>
- // CHECK-DAG: %[[cA:.*]] = memref.cast %[[A]] : memref<64xf32> to memref<64xf32, #[[$DYN_1D_MAP]]>
- // CHECK-DAG: %[[cB:.*]] = memref.cast %[[B]] : memref<64xf32> to memref<64xf32, #[[$DYN_1D_MAP]]>
- // CHECK-DAG: %[[cC:.*]] = memref.cast %[[C]] : memref<f32> to memref<f32, #[[$DYN_0D_MAP]]>
- %A = linalg.init_tensor [64] : tensor<64xf32>
- %B = linalg.init_tensor [64] : tensor<64xf32>
- %C = linalg.init_tensor [] : tensor<f32>
-
- // CHECK-DAG: linalg.fill ins(%[[C1]] : f32) outs(%[[A]] : memref<64xf32>)
- // CHECK-DAG: linalg.fill ins(%[[C2]] : f32) outs(%[[B]] : memref<64xf32>)
- // CHECK-DAG: linalg.fill ins(%[[C0]] : f32) outs(%[[C]] : memref<f32>)
- %AA = linalg.fill ins(%v1 : f32) outs(%A : tensor<64xf32>) -> tensor<64xf32>
- %BB = linalg.fill ins(%v2 : f32) outs(%B : tensor<64xf32>) -> tensor<64xf32>
- %CC = linalg.fill ins(%v0 : f32) outs(%C : tensor<f32>) -> tensor<f32>
-
- // CHECK-NEXT: call @init_and_dot(%[[cA]], %[[cB]], %[[cC]])
- %res = call @init_and_dot(%AA, %BB, %CC) :
- (tensor<64xf32>, tensor<64xf32>, tensor<f32>) -> tensor<f32>
-
- // CHECK-NEXT: %[[dC:.*]] = memref.cast %[[C]] : memref<f32> to memref<*xf32>
- %res2 = tensor.cast %res: tensor<f32> to tensor<*xf32>
-
- // CHECK-NEXT: call @print_memref_f32(%[[dC]]) : (memref<*xf32>) -> ()
- call @print_memref_f32(%res2) : (tensor<*xf32>) -> ()
-
- return
-}
-
-// CHECK: func private @print_memref_f32(memref<*xf32>)
-func.func private @print_memref_f32(tensor<*xf32>)
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-aliasing-in.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize-aliasing-in.mlir
similarity index 95%
rename from mlir/test/Dialect/Linalg/comprehensive-module-bufferize-aliasing-in.mlir
rename to mlir/test/Dialect/Linalg/one-shot-bufferize-aliasing-in.mlir
index a2bb0700ed373..6d475bac61a08 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-aliasing-in.mlir
+++ b/mlir/test/Dialect/Linalg/one-shot-bufferize-aliasing-in.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="allow-return-allocs always-aliasing-with-dest=0" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries allow-return-allocs always-aliasing-with-dest=0" -split-input-file | FileCheck %s
// CHECK-LABEL: func @linalg_op_bufferizes_inplace_with_input
// CHECK-SAME: %[[t1:.*]]: memref<?x?xf32, #{{.*}}>, %[[t2:.*]]: memref<?xf32, #{{.*}}>, %[[t3:.*]]: memref<?x?xf32, #{{.*}}>
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis-aliasing-in.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize-analysis-aliasing-in.mlir
similarity index 94%
rename from mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis-aliasing-in.mlir
rename to mlir/test/Dialect/Linalg/one-shot-bufferize-analysis-aliasing-in.mlir
index 4974f676b37f8..1d9d066f45855 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis-aliasing-in.mlir
+++ b/mlir/test/Dialect/Linalg/one-shot-bufferize-analysis-aliasing-in.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="test-analysis-only allow-return-allocs always-aliasing-with-dest=0" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries test-analysis-only allow-return-allocs always-aliasing-with-dest=0" -split-input-file | FileCheck %s
// This is a test case for alwaysAliasingWithDest = 0. In that case, an OpResult
// may bufferize in-place with an "in" OpOperand or any non-"out" OpOperand.
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-comprehensive-bufferize.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-one-shot-bufferize.mlir
similarity index 98%
rename from mlir/test/Integration/Dialect/Linalg/CPU/test-comprehensive-bufferize.mlir
rename to mlir/test/Integration/Dialect/Linalg/CPU/test-one-shot-bufferize.mlir
index 1a48db1c87d18..326f0dc5d6253 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/test-comprehensive-bufferize.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-one-shot-bufferize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -pass-pipeline="func.func(canonicalize,cse),linalg-comprehensive-module-bufferize" |\
+// RUN: mlir-opt %s -pass-pipeline="func.func(canonicalize,cse),one-shot-bufferize{bufferize-function-boundaries}" |\
// RUN: mlir-opt -pass-pipeline="func.func(buffer-deallocation,convert-vector-to-scf,lower-affine,convert-linalg-to-loops)" |\
// RUN: mlir-opt -pass-pipeline="func.func(canonicalize,convert-scf-to-cf),convert-vector-to-llvm,convert-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts" | \
@@ -22,7 +22,7 @@ func.func @init_and_dot(%arg0: tensor<64xf32>, %arg1: tensor<64xf32>, %arg2: ten
%9 = tensor.extract_slice %arg1[%arg3] [2] [1] : tensor<64xf32> to tensor<2xf32>
%10 = tensor.cast %9 : tensor<2xf32> to tensor<?xf32>
%11 = tensor.pad %10 low[%c0] high[%c0] {
- ^bb0(%arg5: index):
+ ^bb0(%arg5: index):
tensor.yield %cst : f32
} : tensor<?xf32> to tensor<2xf32>
%12 = tensor.insert_slice %11 into %arg4[%8, 0] [1, 2] [1, 1] : tensor<2xf32> into tensor<?x2xf32>
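Editor's note: the init_tensor elimination logic formerly bundled into ComprehensiveBufferizePass.cpp now lives in its own file but stays reachable through the existing pass. A usage sketch (input.mlir is a placeholder):

  # Tries to eliminate init_tensor ops that are anchored on insert_slice ops:
  mlir-opt input.mlir -linalg-eliminate-init-tensors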