[Mlir-commits] [mlir] Revert "[mlir] Remove dialect specific bufferization passes" (PR #93528)

Kunwar Grover llvmlistbot at llvm.org
Tue May 28 03:21:28 PDT 2024


https://github.com/Groverkss created https://github.com/llvm/llvm-project/pull/93528

Reverts llvm/llvm-project#93488

Buildbot failure: https://lab.llvm.org/buildbot/#/builders/220/builds/39911

>From ff4df79c7525ff673ebccdd2f5c828196c464f31 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Tue, 28 May 2024 11:20:44 +0100
Subject: [PATCH] Revert "[mlir] Remove dialect specific bufferization passes
 (#93488)"

This reverts commit 2fc510643747dc70abdf8f2f7efcc7763d1392cb.
---
 .../mlir/Dialect/Arith/Transforms/Passes.h    |  3 +
 .../mlir/Dialect/Arith/Transforms/Passes.td   | 16 +++++
 .../Dialect/Bufferization/Transforms/Passes.h |  3 +
 .../Bufferization/Transforms/Passes.td        |  5 ++
 mlir/include/mlir/Dialect/Linalg/Passes.h     |  4 ++
 mlir/include/mlir/Dialect/Linalg/Passes.td    | 10 +++
 .../mlir/Dialect/Shape/Transforms/Passes.h    |  7 ++
 .../mlir/Dialect/Shape/Transforms/Passes.td   |  7 ++
 .../mlir/Dialect/Tensor/Transforms/Passes.h   |  3 +
 .../mlir/Dialect/Tensor/Transforms/Passes.td  |  5 ++
 .../mlir/Dialect/Vector/Transforms/Passes.h   |  3 +
 .../mlir/Dialect/Vector/Transforms/Passes.td  |  5 ++
 .../Dialect/Arith/Transforms/Bufferize.cpp    | 67 +++++++++++++++++++
 .../Dialect/Arith/Transforms/CMakeLists.txt   |  1 +
 .../Bufferization/Transforms/Bufferize.cpp    | 23 +++++++
 .../Dialect/Linalg/Transforms/Bufferize.cpp   | 52 ++++++++++++++
 .../Dialect/Linalg/Transforms/CMakeLists.txt  |  1 +
 .../Dialect/Shape/Transforms/Bufferize.cpp    | 49 ++++++++++++++
 .../Dialect/Shape/Transforms/CMakeLists.txt   |  1 +
 .../Dialect/Tensor/Transforms/Bufferize.cpp   | 58 ++++++++++++++++
 .../Dialect/Tensor/Transforms/CMakeLists.txt  |  1 +
 .../Dialect/Vector/Transforms/Bufferize.cpp   | 55 +++++++++++++++
 .../Dialect/Vector/Transforms/CMakeLists.txt  |  1 +
 mlir/test/Dialect/Arith/bufferize.mlir        |  8 ++-
 mlir/test/Dialect/Linalg/bufferize.mlir       | 30 ++++++++-
 mlir/test/Dialect/Shape/bufferize.mlir        |  2 +-
 .../Dialect/SparseTensor/sparse_lower.mlir    |  3 +-
 .../SparseTensor/sparse_lower_col.mlir        |  3 +-
 .../SparseTensor/sparse_lower_inplace.mlir    |  3 +-
 mlir/test/Dialect/Tensor/bufferize.mlir       |  2 +-
 .../Dialect/Vector/bufferize-invalid.mlir     |  3 +-
 mlir/test/Dialect/Vector/bufferize.mlir       |  2 +-
 32 files changed, 426 insertions(+), 10 deletions(-)
 create mode 100644 mlir/lib/Dialect/Arith/Transforms/Bufferize.cpp
 create mode 100644 mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
 create mode 100644 mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
 create mode 100644 mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
 create mode 100644 mlir/lib/Dialect/Vector/Transforms/Bufferize.cpp

diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index 9dc262cc72ed0..cbc6147cb81e2 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -24,6 +24,9 @@ namespace arith {
 class WideIntEmulationConverter;
 class NarrowTypeEmulationConverter;
 
+/// Create a pass to bufferize arith.constant ops.
+std::unique_ptr<Pass> createConstantBufferizePass(uint64_t alignment = 0);
+
 /// Adds patterns to emulate wide Arith and Function ops over integer
 /// types into supported ones. This is done by splitting original power-of-two
 /// i2N integer types into two iN halves.
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index 550c5c0cf4f60..4096e309199e9 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -11,6 +11,22 @@
 
 include "mlir/Pass/PassBase.td"
 
+def ArithBufferizePass : Pass<"arith-bufferize", "ModuleOp"> {
+  let summary = "Bufferize Arith dialect ops.";
+  let description = [{
+    This pass bufferizes arith dialect ops.
+
+    This pass needs to be a module pass because it inserts memref.global
+    ops into the module, which cannot be done safely from a function pass due to
+    multi-threading. Most other bufferization passes can run in parallel at
+    function granularity.
+  }];
+  let options = [
+    Option<"alignment", "alignment", "unsigned", /*default=*/"0",
+           "Create global memrefs with a specified alignment">,
+  ];
+}
+
 def ArithExpandOpsPass : Pass<"arith-expand"> {
   let summary = "Legalize Arith ops to be convertible to LLVM.";
   let dependentDialects = ["vector::VectorDialect"];
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index e053e6c97e143..459c252b70712 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -221,6 +221,9 @@ createPromoteBuffersToStackPass(std::function<bool(Value)> isSmallAlloc);
 /// insert_slice ops.
 std::unique_ptr<Pass> createEmptyTensorEliminationPass();
 
+/// Create a pass that bufferizes ops from the bufferization dialect.
+std::unique_ptr<Pass> createBufferizationBufferizePass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index 8f8826b9ad56b..75ce85c9128c9 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -350,6 +350,11 @@ def FinalizingBufferize : Pass<"finalizing-bufferize", "func::FuncOp"> {
   let constructor = "mlir::bufferization::createFinalizingBufferizePass()";
 }
 
+def BufferizationBufferize : Pass<"bufferization-bufferize", "func::FuncOp"> {
+  let summary = "Bufferize the `bufferization` dialect";
+  let constructor = "mlir::bufferization::createBufferizationBufferizePass()";
+}
+
 def DropEquivalentBufferResults : Pass<"drop-equivalent-buffer-results", "ModuleOp">  {
   let summary = "Remove MemRef return values that are equivalent to a bbArg";
   let description = [{
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index f2955d55e59ec..d36d1e70f0b14 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -22,6 +22,10 @@ namespace func {
 class FuncOp;
 } // namespace func
 
+namespace bufferization {
+struct OneShotBufferizationOptions;
+} // namespace bufferization
+
 #define GEN_PASS_DECL
 #include "mlir/Dialect/Linalg/Passes.h.inc" // IWYU pragma: keep
 
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 0621a9f33ba1e..0a4ce8953136d 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -89,6 +89,16 @@ def LinalgInlineScalarOperandsPass : Pass<"linalg-inline-scalar-operands"> {
   ];
 }
 
+def LinalgBufferizePass : Pass<"linalg-bufferize"> {
+  let summary = "Bufferize the linalg dialect";
+  let dependentDialects = [
+    "affine::AffineDialect",
+    "bufferization::BufferizationDialect",
+    "linalg::LinalgDialect",
+    "memref::MemRefDialect",
+  ];
+}
+
 def LinalgGeneralizeNamedOpsPass : Pass<"linalg-generalize-named-ops"> {
   let summary = "Convert named ops into generic ops";
   let dependentDialects = ["linalg::LinalgDialect"];
diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
index 28e17459ff962..cfb637f133f54 100644
--- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
@@ -47,6 +47,13 @@ void populateShapeRewritePatterns(RewritePatternSet &patterns);
 void populateRemoveShapeConstraintsPatterns(RewritePatternSet &patterns);
 std::unique_ptr<OperationPass<func::FuncOp>> createRemoveShapeConstraintsPass();
 
+/// Creates a function pass that bufferizes shape dialect ops.
+///
+/// Note that most shape dialect ops are expected to be lowered to other
+/// dialects (historically "std") before bufferization, and are bufferized at
+/// that level rather than by this pass.
+std::unique_ptr<OperationPass<func::FuncOp>> createShapeBufferizePass();
+
 /// Outline the shape computation part by adding shape.func and populate
 /// conrresponding mapping infomation into ShapeMappingAnalysis.
 std::unique_ptr<OperationPass<ModuleOp>> createOutlineShapeComputationPass();
diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
index 83834509b4a35..9dfda9ea33615 100644
--- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
@@ -103,4 +103,11 @@ def ShapeToShapeLowering : Pass<"shape-to-shape-lowering", "func::FuncOp"> {
   let constructor = "mlir::createShapeToShapeLowering()";
 }
 
+// TODO: Generalize this to allow any type conversions desired.
+def ShapeBufferize : Pass<"shape-bufferize", "func::FuncOp"> {
+  let summary = "Bufferize the shape dialect.";
+  let constructor = "mlir::createShapeBufferizePass()";
+  let dependentDialects = ["bufferization::BufferizationDialect",
+                           "memref::MemRefDialect"];
+}
 #endif // MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h
index 964c35b3f15b8..48f9066934a25 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h
@@ -21,6 +21,9 @@ namespace tensor {
 /// Creates an instance of the `tensor` subset folding pass.
 std::unique_ptr<Pass> createFoldTensorSubsetOpsPass();
 
+/// Creates an instance of the `tensor` dialect bufferization pass.
+std::unique_ptr<Pass> createTensorBufferizePass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td
index be4c333836ec0..4cc3844f29120 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td
@@ -27,4 +27,9 @@ def FoldTensorSubsetOps : Pass<"fold-tensor-subset-ops"> {
   ];
 }
 
+def TensorBufferize : Pass<"tensor-bufferize", "func::FuncOp"> {
+  let summary = "Bufferize the `tensor` dialect";
+  let constructor = "mlir::tensor::createTensorBufferizePass()";
+}
+
 #endif // MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
index 5667f4fa95ace..911402551e14d 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
@@ -17,6 +17,9 @@ namespace vector {
 #define GEN_PASS_DECL
 #include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
 
+/// Creates an instance of the `vector` dialect bufferization pass.
+std::unique_ptr<Pass> createVectorBufferizePass();
+
 /// Creates an instance of the `vector.mask` lowering pass.
 std::unique_ptr<Pass> createLowerVectorMaskPass();
 
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
index 7436998749791..31a0b3b2f0c53 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
@@ -11,6 +11,11 @@
 
 include "mlir/Pass/PassBase.td"
 
+def VectorBufferize : Pass<"vector-bufferize", "func::FuncOp"> {
+  let summary = "Bufferize Vector dialect ops";
+  let constructor = "mlir::vector::createVectorBufferizePass()";
+}
+
 def LowerVectorMaskPass : Pass<"lower-vector-mask", "func::FuncOp"> {
   let summary = "Lower 'vector.mask' operations";
   let constructor = "mlir::vector::createLowerVectorMaskPass()";
diff --git a/mlir/lib/Dialect/Arith/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Arith/Transforms/Bufferize.cpp
new file mode 100644
index 0000000000000..9a066756f429c
--- /dev/null
+++ b/mlir/lib/Dialect/Arith/Transforms/Bufferize.cpp
@@ -0,0 +1,77 @@
+//===- Bufferize.cpp - Bufferization for Arith ops ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+
+namespace mlir {
+namespace arith {
+#define GEN_PASS_DEF_ARITHBUFFERIZEPASS
+#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
+} // namespace arith
+} // namespace mlir
+
+using namespace mlir;
+using namespace bufferization;
+
+namespace {
+/// Pass to bufferize Arith ops.
+///
+/// If `constantOpOnly` is set, the op filter only admits `arith.constant`;
+/// otherwise all ops of the Arith dialect are bufferized.
+struct ArithBufferizePass
+    : public arith::impl::ArithBufferizePassBase<ArithBufferizePass> {
+  using ArithBufferizePassBase::ArithBufferizePassBase;
+
+  ArithBufferizePass(uint64_t alignment = 0, bool constantOpOnly = false)
+      : constantOpOnly(constantOpOnly) {
+    this->alignment = alignment;
+  }
+
+  void runOnOperation() override {
+    BufferizationOptions options = getPartialBufferizationOptions();
+    if (constantOpOnly) {
+      options.opFilter.allowOperation<arith::ConstantOp>();
+    } else {
+      options.opFilter.allowDialect<arith::ArithDialect>();
+    }
+    // Alignment used for the memref.global ops this pass creates.
+    options.bufferAlignment = alignment;
+
+    if (failed(bufferizeOp(getOperation(), options)))
+      signalPassFailure();
+  }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect,
+                    arith::ArithDialect>();
+    arith::registerBufferizableOpInterfaceExternalModels(registry);
+  }
+
+private:
+  /// Restrict bufferization to `arith.constant` ops. The default member
+  /// initializer matters: instances created through the inherited base
+  /// (options) constructor never set this field, and would otherwise read an
+  /// indeterminate value in runOnOperation().
+  bool constantOpOnly = false;
+};
+} // namespace
+
+/// Creates a pass that bufferizes only `arith.constant` ops, emitting globals
+/// with the given `alignment` (0 leaves the alignment unspecified).
+std::unique_ptr<Pass>
+mlir::arith::createConstantBufferizePass(uint64_t alignment) {
+  return std::make_unique<ArithBufferizePass>(alignment,
+                                              /*constantOpOnly=*/true);
+}
diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
index 6b8bde8dc2aaf..12659eaba1fa5 100644
--- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRArithTransforms
   BufferDeallocationOpInterfaceImpl.cpp
   BufferizableOpInterfaceImpl.cpp
+  Bufferize.cpp
   BufferViewFlowOpInterfaceImpl.cpp
   EmulateUnsupportedFloats.cpp
   EmulateWideInt.cpp
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 0fddd60eb8140..7ba347a1f15e4 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -320,6 +320,29 @@ struct OneShotBufferizePass
 };
 } // namespace
 
+namespace {
+/// Bufferizes `bufferization` dialect ops; the op filter admits no others.
+struct BufferizationBufferizePass
+    : public bufferization::impl::BufferizationBufferizeBase<
+          BufferizationBufferizePass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry
+        .insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
+  }
+
+  void runOnOperation() override {
+    BufferizationOptions bufferizationOpts = getPartialBufferizationOptions();
+    bufferizationOpts.opFilter.allowDialect<BufferizationDialect>();
+    if (failed(bufferizeOp(getOperation(), bufferizationOpts)))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::bufferization::createBufferizationBufferizePass() {
+  return std::make_unique<BufferizationBufferizePass>();
+}
+
 std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass() {
   return std::make_unique<OneShotBufferizePass>();
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
new file mode 100644
index 0000000000000..8812ca14ba610
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -0,0 +1,59 @@
+//===- Bufferize.cpp - Bufferization of linalg ops ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Passes.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGBUFFERIZEPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace bufferization;
+
+namespace {
+/// Converts Linalg operations that work on tensor-type operands or results to
+/// work on buffers.
+struct LinalgBufferizePass
+    : public impl::LinalgBufferizePassBase<LinalgBufferizePass> {
+  using impl::LinalgBufferizePassBase<
+      LinalgBufferizePass>::LinalgBufferizePassBase;
+
+  void runOnOperation() override {
+    BufferizationOptions options = getPartialBufferizationOptions();
+    // Only Linalg ops are bufferized; ops of other dialects (e.g. func) are
+    // left untouched.
+    options.opFilter.allowDialect<linalg::LinalgDialect>();
+
+    if (failed(bufferizeOp(getOperation(), options)))
+      signalPassFailure();
+  }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    // Register the `dependentDialects` declared in Passes.td (affine,
+    // bufferization, linalg, memref); overriding this hook would otherwise
+    // drop them.
+    impl::LinalgBufferizePassBase<LinalgBufferizePass>::getDependentDialects(
+        registry);
+    registry.insert<tensor::TensorDialect>();
+    linalg::registerBufferizableOpInterfaceExternalModels(registry);
+  }
+};
+} // namespace
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 7e3dc56e0acdc..ed9f40089282a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   AllInterfaces.cpp
   BubbleUpExtractSlice.cpp
   BufferizableOpInterfaceImpl.cpp
+  Bufferize.cpp
   ConstantFold.cpp
   ConvertToDestinationStyle.cpp
   ConvertConv2DToImg2Col.cpp
diff --git a/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
new file mode 100644
index 0000000000000..9dadbdbc91eca
--- /dev/null
+++ b/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
@@ -0,0 +1,49 @@
+//====----- Bufferize.cpp - Bufferization of shape ops  ---------*- C++-*--===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Shape/Transforms/Passes.h"
+
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_SHAPEBUFFERIZE
+#include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace bufferization;
+
+namespace {
+/// Bufferizes `shape` dialect ops; the op filter admits no other dialect.
+struct ShapeBufferizePass
+    : public impl::ShapeBufferizeBase<ShapeBufferizePass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect,
+                    shape::ShapeDialect>();
+    shape::registerBufferizableOpInterfaceExternalModels(registry);
+  }
+
+  void runOnOperation() override {
+    BufferizationOptions shapeOnlyOptions = getPartialBufferizationOptions();
+    shapeOnlyOptions.opFilter.allowDialect<shape::ShapeDialect>();
+    if (failed(bufferizeOp(getOperation(), shapeOnlyOptions)))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<func::FuncOp>> mlir::createShapeBufferizePass() {
+  return std::make_unique<ShapeBufferizePass>();
+}
diff --git a/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
index a51c6780c2866..7c9b0d2e5e3a8 100644
--- a/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRShapeOpsTransforms
   BufferizableOpInterfaceImpl.cpp
+  Bufferize.cpp
   OutlineShapeComputation.cpp
   RemoveShapeConstraints.cpp
   ShapeToShapeLowering.cpp
diff --git a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
new file mode 100644
index 0000000000000..d27c4576a8b7a
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
@@ -0,0 +1,58 @@
+//===- Bufferize.cpp - Bufferization for `tensor` dialect ops -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements bufferization of `tensor` dialect ops
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Tensor/Transforms/Passes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace tensor {
+#define GEN_PASS_DEF_TENSORBUFFERIZE
+#include "mlir/Dialect/Tensor/Transforms/Passes.h.inc"
+} // namespace tensor
+} // namespace mlir
+
+using namespace mlir;
+using namespace bufferization;
+
+namespace {
+/// Bufferizes `tensor` dialect ops; the op filter admits no other dialect.
+struct TensorBufferizePass
+    : public tensor::impl::TensorBufferizeBase<TensorBufferizePass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect,
+                    tensor::TensorDialect, scf::SCFDialect,
+                    arith::ArithDialect>();
+    tensor::registerBufferizableOpInterfaceExternalModels(registry);
+  }
+
+  void runOnOperation() override {
+    BufferizationOptions tensorOnlyOptions = getPartialBufferizationOptions();
+    tensorOnlyOptions.opFilter.allowDialect<tensor::TensorDialect>();
+    if (failed(bufferizeOp(getOperation(), tensorOnlyOptions)))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::tensor::createTensorBufferizePass() {
+  return std::make_unique<TensorBufferizePass>();
+}
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
index ce32dea09bb0b..0aabdaf667b9d 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRTensorTransforms
   BufferizableOpInterfaceImpl.cpp
+  Bufferize.cpp
   ConcatOpPatterns.cpp
   EmptyOpPatterns.cpp
   ExtractSliceFromReshapeUtils.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Vector/Transforms/Bufferize.cpp
new file mode 100644
index 0000000000000..ee99a99b56109
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/Bufferize.cpp
@@ -0,0 +1,55 @@
+//===- Bufferize.cpp - Bufferization for `vector` dialect ops -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements bufferization of `vector` dialect ops
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
+
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Vector/Transforms/Passes.h"
+
+namespace mlir {
+namespace vector {
+#define GEN_PASS_DEF_VECTORBUFFERIZE
+#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
+} // namespace vector
+} // namespace mlir
+
+using namespace mlir;
+using namespace bufferization;
+
+namespace {
+/// Bufferizes `vector` dialect ops; the op filter admits no other dialect.
+struct VectorBufferizePass
+    : public vector::impl::VectorBufferizeBase<VectorBufferizePass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect,
+                    tensor::TensorDialect, vector::VectorDialect>();
+    vector::registerBufferizableOpInterfaceExternalModels(registry);
+  }
+
+  void runOnOperation() override {
+    BufferizationOptions vectorOnlyOptions = getPartialBufferizationOptions();
+    vectorOnlyOptions.opFilter.allowDialect<vector::VectorDialect>();
+    if (failed(bufferizeOp(getOperation(), vectorOnlyOptions)))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::vector::createVectorBufferizePass() {
+  return std::make_unique<VectorBufferizePass>();
+}
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 4dbefdd376a8b..c4b6abd3e2361 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRVectorTransforms
   BufferizableOpInterfaceImpl.cpp
+  Bufferize.cpp
   LowerVectorBroadcast.cpp
   LowerVectorContract.cpp
   LowerVectorGather.cpp
diff --git a/mlir/test/Dialect/Arith/bufferize.mlir b/mlir/test/Dialect/Arith/bufferize.mlir
index a3b1454fb68f6..944954e9e4edd 100644
--- a/mlir/test/Dialect/Arith/bufferize.mlir
+++ b/mlir/test/Dialect/Arith/bufferize.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s --one-shot-bufferize="dialect-filter=arith,bufferization copy-before-write unknown-type-conversion=identity-layout-map" -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -arith-bufferize -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -arith-bufferize=alignment=64 -split-input-file -verify-diagnostics | FileCheck --check-prefix=ALIGNED %s
 
 // CHECK-LABEL:   func @index_cast(
 // CHECK-SAME:  %[[TENSOR:.*]]: tensor<i32>, %[[SCALAR:.*]]: i32
@@ -21,7 +22,10 @@ func.func @index_cast(%tensor: tensor<i32>, %scalar: i32) -> (tensor<index>, ind
 // The name isn't load-bearing though.
 
 // CHECK: memref.global "private" constant @__constant_3x4xf32 : memref<3x4xf32> = dense<7.000000e+00>
-// CHECK-SAME: {alignment = 64 : i64}
+// CHECK-NOT: alignment
+
+// ALIGNED: memref.global "private" constant @__constant_3x4xf32 : memref<3x4xf32> = dense<7.000000e+00>
+// ALIGNED-SAME: {alignment = 64 : i64}
 
 // CHECK: @basic
 func.func @basic() -> tensor<3x4xf32> {
diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index e8ab1184b1fd2..29f27e6838e66 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --one-shot-bufferize="dialect-filter=linalg,bufferization copy-before-write unknown-type-conversion=identity-layout-map" -canonicalize -cse -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -linalg-bufferize -canonicalize -cse -split-input-file %s | FileCheck %s
 
 #map0 = affine_map<(d0) -> (d0)>
 
@@ -189,3 +189,31 @@ func.func @bufferize_dot(%in: tensor<4xf32>, %out: tensor<f32>) -> tensor<f32> {
   // CHECK: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ALLOC]] : memref<f32>
   // CHECK: return %[[OUT_TENSOR]]
 }
+
+// -----
+
+// This is a regression test. The linalg-bufferize pass should ignore all func
+// dialect ops.
+
+// CHECK-LABEL: func private @csum(tensor<6xi64>) -> tensor<6xi64>
+func.func private @csum(%arg0: tensor<6xi64>) -> tensor<6xi64>
+
+// CHECK: func public @main(%[[arg0:.*]]: tensor<2x3xi1>)
+// CHECK:   %[[collapse:.*]] = tensor.collapse_shape %[[arg0]]
+// CHECK:   %[[collapse_m:.*]] = bufferization.to_memref %[[collapse]]
+// CHECK:   %[[alloc:.*]] = memref.alloc()
+// CHECK:   linalg.generic {{.*}} ins(%[[collapse_m]] : memref<6xi1>) outs(%[[alloc]] : memref<6xi64>)
+// CHECK:   %[[generic_t:.*]] = bufferization.to_tensor %[[alloc]]
+// CHECK:   %[[call:.*]] = call @csum(%[[generic_t]])
+// CHECK:   return %[[call]]
+func.func public @main(%arg0: tensor<2x3xi1>) -> tensor<6xi64> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<2x3xi1> into tensor<6xi1>
+  %1 = tensor.empty() : tensor<6xi64>
+  %2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%0 : tensor<6xi1>) outs(%1 : tensor<6xi64>) {
+  ^bb0(%arg1: i1, %arg2: i64):
+    %4 = arith.extui %arg1 : i1 to i64
+    linalg.yield %4 : i64
+  } -> tensor<6xi64>
+  %3 = func.call @csum(%2) : (tensor<6xi64>) -> tensor<6xi64>
+  return %3 : tensor<6xi64>
+}
diff --git a/mlir/test/Dialect/Shape/bufferize.mlir b/mlir/test/Dialect/Shape/bufferize.mlir
index 9f30a052208f0..963a5e8bcf578 100644
--- a/mlir/test/Dialect/Shape/bufferize.mlir
+++ b/mlir/test/Dialect/Shape/bufferize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -split-input-file --one-shot-bufferize="dialect-filter=shape,bufferization copy-before-write unknown-type-conversion=identity-layout-map allow-unknown-ops" <%s | FileCheck %s
+// RUN: mlir-opt -split-input-file -shape-bufferize <%s | FileCheck %s
 
 // -----
 
diff --git a/mlir/test/Dialect/SparseTensor/sparse_lower.mlir b/mlir/test/Dialect/SparseTensor/sparse_lower.mlir
index c27df00785522..6112856fbf293 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_lower.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_lower.mlir
@@ -4,7 +4,8 @@
 // RUN: FileCheck %s --check-prefix=CHECK-MIR
 //
 // RUN: mlir-opt %s --sparse-reinterpret-map -sparsification --sparse-tensor-conversion --cse \
-// RUN: --one-shot-bufferize="copy-before-write bufferize-function-boundaries function-boundary-type-conversion=identity-layout-map" | \
+// RUN: --func-bufferize --arith-bufferize           \
+// RUN: --tensor-bufferize --finalizing-bufferize |  \
 // RUN: FileCheck %s --check-prefix=CHECK-LIR
 
 #CSR = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : dense, d1 : compressed)}>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_lower_col.mlir b/mlir/test/Dialect/SparseTensor/sparse_lower_col.mlir
index 9fbb9dd0a26d1..401da152a8bdb 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_lower_col.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_lower_col.mlir
@@ -4,7 +4,8 @@
 // RUN: FileCheck %s --check-prefix=CHECK-MIR
 //
 // RUN: mlir-opt %s --sparse-reinterpret-map -sparsification --sparse-tensor-conversion --cse \
-// RUN: --one-shot-bufferize="copy-before-write bufferize-function-boundaries function-boundary-type-conversion=identity-layout-map" | \
+// RUN: --func-bufferize --arith-bufferize           \
+// RUN: --tensor-bufferize --finalizing-bufferize |  \
 // RUN: FileCheck %s --check-prefix=CHECK-LIR
 
 #CSC = #sparse_tensor.encoding<{
diff --git a/mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir b/mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir
index a827360abb426..d769876d8ee8e 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir
@@ -4,7 +4,8 @@
 // RUN: FileCheck %s --check-prefix=CHECK-MIR
 //
 // RUN: mlir-opt %s --sparse-reinterpret-map -sparsification --sparse-tensor-conversion --cse \
-// RUN: --one-shot-bufferize="copy-before-write bufferize-function-boundaries function-boundary-type-conversion=identity-layout-map" | \
+// RUN: --func-bufferize --arith-bufferize           \
+// RUN: --tensor-bufferize --finalizing-bufferize |  \
 // RUN: FileCheck %s --check-prefix=CHECK-LIR
 
 #CSR = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : dense, d1 : compressed)}>
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index e85d9e740adf4..4f553adcc500f 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --one-shot-bufferize="dialect-filter=tensor,bufferization copy-before-write unknown-type-conversion=identity-layout-map" -cse -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -tensor-bufferize -cse -split-input-file | FileCheck %s
 
 // CHECK-LABEL:   func @dim(
 // CHECK-SAME:              %[[TENSOR:.*]]: tensor<*xf32>,
diff --git a/mlir/test/Dialect/Vector/bufferize-invalid.mlir b/mlir/test/Dialect/Vector/bufferize-invalid.mlir
index bcca50a0fe79a..1ae3e312c868f 100644
--- a/mlir/test/Dialect/Vector/bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Vector/bufferize-invalid.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s --one-shot-bufferize="dialect-filter=vector,bufferization copy-before-write unknown-type-conversion=identity-layout-map allow-unknown-ops" -split-input-file -verify-diagnostics
+// RUN: mlir-opt %s -vector-bufferize -split-input-file -verify-diagnostics
+// | FileCheck %s
 
 // CHECK-LABEL: func @mask(
 func.func @mask(%t0: tensor<?xf32>, %val: vector<16xf32>, %idx: index, %m0: vector<16xi1>) -> tensor<?xf32> {
diff --git a/mlir/test/Dialect/Vector/bufferize.mlir b/mlir/test/Dialect/Vector/bufferize.mlir
index 3399f60a2c3bf..6a6a8fa8938bc 100644
--- a/mlir/test/Dialect/Vector/bufferize.mlir
+++ b/mlir/test/Dialect/Vector/bufferize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --one-shot-bufferize="dialect-filter=vector,bufferization copy-before-write unknown-type-conversion=identity-layout-map" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -vector-bufferize -split-input-file | FileCheck %s
 
 // CHECK-LABEL: func @transfer_read(
 //  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>, %[[o1:.*]]: index, %[[o2:.*]]: index, %[[pad:.*]]: f32)



More information about the Mlir-commits mailing list