[Mlir-commits] [mlir] b00ee46 - [mlir][bufferize][NFC] Implement BufferizableOpInterface on bufferization ops directly

Matthias Springer llvmlistbot at llvm.org
Mon Jan 24 08:23:42 PST 2022


Author: Matthias Springer
Date: 2022-01-25T01:23:26+09:00
New Revision: b00ee46b5e4bf5f0b5700373ca6302c3c50b10b9

URL: https://github.com/llvm/llvm-project/commit/b00ee46b5e4bf5f0b5700373ca6302c3c50b10b9
DIFF: https://github.com/llvm/llvm-project/commit/b00ee46b5e4bf5f0b5700373ca6302c3c50b10b9.diff

LOG: [mlir][bufferize][NFC] Implement BufferizableOpInterface on bufferization ops directly

No longer go through an external model. Also put BufferizableOpInterface into the same build target as the BufferizationDialect. This allows for some code reuse between BufferizationOps canonicalizers and BufferizableOpInterface implementations.

Differential Revision: https://reviews.llvm.org/D117987

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
    mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
    mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
    mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
    mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
    mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
    mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
    mlir/test/lib/Dialect/Linalg/CMakeLists.txt
    mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
    utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel

Removed: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizationInterfaceImpl.h
    mlir/lib/Dialect/Bufferization/IR/BufferizationInterfaceImpl.cpp


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
index 21aeb91ff2290..2cbfc901f239b 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
@@ -10,6 +10,7 @@
 #define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATION_H_
 
 #include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 

diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationInterfaceImpl.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationInterfaceImpl.h
deleted file mode 100644
index 7b903b59f1769..0000000000000
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationInterfaceImpl.h
+++ /dev/null
@@ -1,25 +0,0 @@
-//===- BufferizationInterfaceImpl.h - Bufferization Impl. of Op Interface -===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONINTERFACEIMPL_H_
-#define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONINTERFACEIMPL_H_
-
-namespace mlir {
-
-class DialectRegistry;
-
-namespace bufferization {
-namespace bufferization_ext {
-
-void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
-
-} // namespace bufferization_ext
-} // namespace bufferization
-} // namespace mlir
-
-#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONINTERFACEIMPL_H_

diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 9b977a7d250f5..1a76f8b3eea00 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -10,6 +10,7 @@
 #define BUFFERIZATION_OPS
 
 include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
+include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
 include "mlir/Dialect/Bufferization/IR/BufferizationBase.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/CopyOpInterface.td"
@@ -64,11 +65,14 @@ def Bufferization_CloneOp : Bufferization_Op<"clone", [
 // ToTensorOp
 //===----------------------------------------------------------------------===//
 
-def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor",
-    [SameOperandsAndResultShape, SameOperandsAndResultElementType,
-     TypesMatchWith<"result type matches tensor equivalent of 'memref'",
-                    "memref", "result",
-                    "memref::getTensorTypeFromMemRefType($_self)">]> {
+def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
+    BufferizableOpInterface,
+    SameOperandsAndResultShape,
+    SameOperandsAndResultElementType,
+    TypesMatchWith<"result type matches tensor equivalent of 'memref'",
+                   "memref", "result",
+                   "memref::getTensorTypeFromMemRefType($_self)">
+  ]> {
   let summary = "memref to tensor operation";
   let description = [{
     Create a tensor from a memref, making an independent copy of the element
@@ -110,6 +114,35 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor",
         return resultType.cast<TensorType>();
       return {};
     }
+
+    //===------------------------------------------------------------------===//
+    // BufferizableOpInterface implementation
+    //===------------------------------------------------------------------===//
+
+    // ToTensorOp conceptually loads a tensor from a memory location. The
+    // One-Shot analysis has no information about the memref that is loaded from
+    // by ToTensorOp. We have to assume that the loaded tensor may after
+    // bufferization potentially alias with any other bufferized tensor. Since
+    // ToTensorOp and ToMemrefOp have no aliasing OpOperand/OpResult pairs, this
+    // cannot be encoded directly in the analysis. However, declaring ToTensorOp
+    // results as not writable enforces a buffer copy and has the same effect.
+
+    LogicalResult bufferize(RewriterBase &rewriter,
+                            const BufferizationState &state) const {
+      // to_tensor cannot be bufferized. However, other ops that are using
+      // to_tensor's result will eventually be bufferized. At that point, they
+      // will start using to_tensor's memref operand. Once all users of
+      // to_tensor are bufferized, the op will not have any users anymore and
+      // DCE away. In case of partial bufferization, to_memref(to_tensor(x))
+      // constructs may be left over. These are folded by the canonicalizer or
+      // FinalizingBufferize.
+      return failure();
+    }
+
+    bool isWritable(Value value, const BufferizationState &state) const {
+      // It is unknown whether the memref operand is writable or not.
+      return false;
+    }
   }];
 
   let assemblyFormat = "$memref attr-dict `:` type($memref)";
@@ -123,11 +156,15 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor",
 // ToMemrefOp
 //===----------------------------------------------------------------------===//
 
-def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref",
-    [SameOperandsAndResultShape, SameOperandsAndResultElementType, NoSideEffect,
-     TypesMatchWith<"type of 'tensor' is the tensor equivalent of 'memref'",
-                    "memref", "tensor",
-                    "memref::getTensorTypeFromMemRefType($_self)">]> {
+def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
+    BufferizableOpInterface,
+    SameOperandsAndResultShape,
+    SameOperandsAndResultElementType,
+    NoSideEffect,
+    TypesMatchWith<"type of 'tensor' is the tensor equivalent of 'memref'",
+                   "memref", "tensor",
+                   "memref::getTensorTypeFromMemRefType($_self)">
+  ]> {
   let summary = "tensor to memref cast operation";
   let description = [{
     Casts a tensor to a memref.
@@ -150,6 +187,44 @@ def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref",
   // This op is fully verified by traits.
   let verifier = ?;
 
+  let extraClassDeclaration = [{
+    //===------------------------------------------------------------------===//
+    // BufferizableOpInterface implementation
+    //===------------------------------------------------------------------===//
+
+    // Note: ToMemrefOp / ToTensorOp are temporary ops that are inserted at the
+    // bufferization boundary. When One-Shot bufferization is complete, there
+    // should be no such ops left over. If `allowUnknownOps` (or after running a
+    // partial bufferization pass), such ops may be part of the resulting IR,
+    // but such IR may no longer be analyzable by One-Shot analysis.
+
+    bool bufferizesToMemoryRead(OpOperand &opOperand,
+                                const BufferizationState &state) const {
+      // It is unknown whether the resulting memref will be read or not.
+      return true;
+    }
+
+    bool bufferizesToMemoryWrite(OpOperand &opOperand,
+                                 const BufferizationState &state) const {
+      // It is unknown whether the resulting MemRef will be written or not.
+      return true;
+    }
+
+    bool mustBufferizeInPlace(OpOperand &opOperand,
+                              const BufferizationState &state) const {
+      // ToMemrefOps always bufferize inplace.
+      return true;
+    }
+
+    OpResult getAliasingOpResult(OpOperand &opOperand,
+                                 const BufferizationState &state) const {
+      return OpResult();
+    }
+
+    LogicalResult bufferize(RewriterBase &rewriter,
+                            const BufferizationState &state);
+  }];
+
   let assemblyFormat = "$tensor attr-dict `:` type($memref)";
 
   let hasFolder = 1;

diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationInterfaceImpl.cpp
deleted file mode 100644
index 835a153eb8548..0000000000000
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationInterfaceImpl.cpp
+++ /dev/null
@@ -1,127 +0,0 @@
-//===- BufferizationInterfaceImpl.cpp - Bufferization Impl. of Interface --===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Bufferization/IR/BufferizationInterfaceImpl.h"
-#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
-#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/IR/Dialect.h"
-#include "mlir/IR/Operation.h"
-
-using namespace mlir;
-using namespace mlir::bufferization;
-
-namespace mlir {
-namespace bufferization {
-namespace bufferization_ext {
-
-// TODO: These ops should implement BufferizableOpInterface.
-
-/// Bufferization of bufferization.to_memref. to_memref(to_tensor(x)) is folded
-/// to x. Other to_memref ops are ignored during bufferization.
-///
-/// ToMemrefOp casts a tensor into a memref. The resulting memref is the memory
-/// location of the incoming tensor once it will be bufferized. In the anlysis,
-/// the incoming tensor is assumed to bufferize to a memory read and to an
-/// inplace memory write, since it is unknown what will happen to the resulting
-/// memref.
-///
-/// Note: ToMemrefOp / ToTensorOp are temporary ops that are inserted at the
-/// bufferization boundary. When bufferization is complete, there should be no
-/// such ops left over. If `allowUnknownOps`, such ops may be part of the
-/// resulting IR, but such IR may no longer be bufferizable by Comprehensive
-/// Bufferize.
-struct ToMemrefOpInterface
-    : public BufferizableOpInterface::ExternalModel<ToMemrefOpInterface,
-                                                    bufferization::ToMemrefOp> {
-  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
-                              const BufferizationState &state) const {
-    // It is unknown whether the resulting memref will be read or not.
-    return true;
-  }
-
-  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
-                               const BufferizationState &state) const {
-    // It is unknown whether the resulting MemRef will be written or not.
-    return true;
-  }
-
-  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
-                            const BufferizationState &state) const {
-    // ToMemrefOps always bufferize inplace.
-    return true;
-  }
-
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
-                               const BufferizationState &state) const {
-    return OpResult();
-  }
-
-  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationState &state) const {
-    auto toMemrefOp = cast<bufferization::ToMemrefOp>(op);
-
-    // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
-    if (auto toTensorOp =
-            toMemrefOp.tensor().getDefiningOp<bufferization::ToTensorOp>()) {
-      Value buffer = toTensorOp.memref();
-
-      // Insert cast in case to_memref(to_tensor(x))'s type is different from
-      // x's type.
-      if (toTensorOp.memref().getType() != toMemrefOp.getType()) {
-        assert(memref::CastOp::areCastCompatible(buffer.getType(),
-                                                 toMemrefOp.getType()) &&
-               "ToMemrefOp::bufferize : cast incompatible");
-        buffer = rewriter.create<memref::CastOp>(toMemrefOp.getLoc(), buffer,
-                                                 toMemrefOp.getType());
-      }
-      replaceOpWithBufferizedValues(rewriter, toMemrefOp, buffer);
-      return success();
-    }
-
-    return failure();
-  }
-};
-
-/// Bufferization of bufferization.to_tensor. Such ops cannot be bufferized.
-/// However, other ops that are using to_tensor's result will eventually be
-/// bufferized. At that point, they will start using to_tensor's memref operand.
-/// Once all users of to_tensor are bufferized, the op will not have any users
-/// anymore and DCE away.
-///
-/// ToTensorOp conceptually loads a tensor from a memory location. The analysis
-/// has no information about the memref that is loaded from by ToTensorOp. We
-/// have to assume that the loaded tensor may after bufferization potentially
-/// alias with any other bufferized tensor. Since ToTensorOp and ToMemrefOp have
-/// no aliasing OpOperand/OpResult pairs, this cannot be encoded directly in the
-/// analysis. However, declaring ToTensorOp results as not writable enforces a
-/// buffer copy and has the same effect.
-struct ToTensorOpInterface
-    : public BufferizableOpInterface::ExternalModel<ToTensorOpInterface,
-                                                    bufferization::ToTensorOp> {
-  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationState &state) const {
-    return failure();
-  }
-
-  bool isWritable(Operation *op, Value value,
-                  const BufferizationState &state) const {
-    // It is unknown whether the memref operand is writable or not.
-    return false;
-  }
-};
-
-} // namespace bufferization_ext
-} // namespace bufferization
-} // namespace mlir
-
-void bufferization_ext::registerBufferizableOpInterfaceExternalModels(
-    DialectRegistry &registry) {
-  registry.addOpInterface<ToMemrefOp, bufferization_ext::ToMemrefOpInterface>();
-  registry.addOpInterface<ToTensorOp, bufferization_ext::ToTensorOpInterface>();
-}

diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index f81d726108d15..93770c9da5aa6 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -182,6 +182,79 @@ struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
   }
 };
 
+/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
+/// to_memref op are different, a memref.cast is needed.
+static LogicalResult foldToMemrefToTensorPair(RewriterBase &rewriter,
+                                              ToMemrefOp toMemref,
+                                              bool allowSameType = true) {
+  auto memrefToTensor = toMemref.tensor().getDefiningOp<ToTensorOp>();
+  if (!memrefToTensor)
+    return failure();
+
+  // A memref_to_tensor + tensor_to_memref with same types can be folded without
+  // inserting a cast.
+  if (memrefToTensor.memref().getType() == toMemref.getType()) {
+    if (!allowSameType)
+      // Function can be configured to only handle cases where a cast is needed.
+      return failure();
+    rewriter.replaceOp(toMemref, memrefToTensor.memref());
+    return success();
+  }
+
+  // If types are definitely not cast-compatible, bail.
+  if (!memref::CastOp::areCastCompatible(memrefToTensor.memref().getType(),
+                                         toMemref.getType()))
+    return failure();
+
+  // We already know that the types are potentially cast-compatible. However
+  // in case the affine maps are different, we may need to use a copy if we go
+  // from dynamic to static offset or stride (the canonicalization cannot know
+  // at this point that it is really cast compatible).
+  auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
+    int64_t sourceOffset, targetOffset;
+    SmallVector<int64_t, 4> sourceStrides, targetStrides;
+    if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
+        failed(getStridesAndOffset(target, targetStrides, targetOffset)))
+      return false;
+    auto dynamicToStatic = [](int64_t a, int64_t b) {
+      return a == MemRefType::getDynamicStrideOrOffset() &&
+             b != MemRefType::getDynamicStrideOrOffset();
+    };
+    if (dynamicToStatic(sourceOffset, targetOffset))
+      return false;
+    for (auto it : zip(sourceStrides, targetStrides))
+      if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
+        return false;
+    return true;
+  };
+
+  auto memrefToTensorType =
+      memrefToTensor.memref().getType().dyn_cast<MemRefType>();
+  auto toMemrefType = toMemref.getType().dyn_cast<MemRefType>();
+  if (memrefToTensorType && toMemrefType &&
+      !isGuaranteedCastCompatible(memrefToTensorType, toMemrefType)) {
+    MemRefType resultType = toMemrefType;
+    auto loc = toMemref.getLoc();
+    SmallVector<Value, 4> dynamicOperands;
+    for (int i = 0; i < resultType.getRank(); ++i) {
+      if (resultType.getShape()[i] != ShapedType::kDynamicSize)
+        continue;
+      auto index = rewriter.createOrFold<arith::ConstantIndexOp>(loc, i);
+      Value size = rewriter.create<tensor::DimOp>(loc, memrefToTensor, index);
+      dynamicOperands.push_back(size);
+    }
+    // TODO: Use alloc/memcpy callback from BufferizationOptions if called via
+    // BufferizableOpInterface impl of ToMemrefOp.
+    auto copy =
+        rewriter.create<memref::AllocOp>(loc, resultType, dynamicOperands);
+    rewriter.create<memref::CopyOp>(loc, memrefToTensor.memref(), copy);
+    rewriter.replaceOp(toMemref, {copy});
+  } else
+    rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
+                                                memrefToTensor.memref());
+  return success();
+}
+
 /// Canonicalize bufferization.to_tensor + bufferization.to_memref to
 /// memref.cast when type mismatches prevent `ToMemrefOp::fold` to kick in.
 struct TensorLoadToMemref : public OpRewritePattern<ToMemrefOp> {
@@ -189,62 +262,10 @@ struct TensorLoadToMemref : public OpRewritePattern<ToMemrefOp> {
 
   LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                 PatternRewriter &rewriter) const final {
-    auto memrefToTensor = toMemref.tensor().getDefiningOp<ToTensorOp>();
-    // Bail unless we have a memref_to_tensor + tensor_to_memref with different
-    // types. `ToMemrefOp::fold` handles the same type case.
-    if (!memrefToTensor ||
-        memrefToTensor.memref().getType() == toMemref.getType())
-      return failure();
-    // If types are definitely not cast-compatible, bail.
-    if (!memref::CastOp::areCastCompatible(memrefToTensor.memref().getType(),
-                                           toMemref.getType()))
-      return failure();
-
-    // We already know that the types are potentially cast-compatible. However
-    // in case the affine maps are different, we may need to use a copy if we go
-    // from dynamic to static offset or stride (the canonicalization cannot know
-    // at this point that it is really cast compatible).
-    auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
-      int64_t sourceOffset, targetOffset;
-      SmallVector<int64_t, 4> sourceStrides, targetStrides;
-      if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
-          failed(getStridesAndOffset(target, targetStrides, targetOffset)))
-        return false;
-      auto dynamicToStatic = [](int64_t a, int64_t b) {
-        return a == MemRefType::getDynamicStrideOrOffset() &&
-               b != MemRefType::getDynamicStrideOrOffset();
-      };
-      if (dynamicToStatic(sourceOffset, targetOffset))
-        return false;
-      for (auto it : zip(sourceStrides, targetStrides))
-        if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
-          return false;
-      return true;
-    };
-
-    auto memrefToTensorType =
-        memrefToTensor.memref().getType().dyn_cast<MemRefType>();
-    auto toMemrefType = toMemref.getType().dyn_cast<MemRefType>();
-    if (memrefToTensorType && toMemrefType &&
-        !isGuaranteedCastCompatible(memrefToTensorType, toMemrefType)) {
-      MemRefType resultType = toMemrefType;
-      auto loc = toMemref.getLoc();
-      SmallVector<Value, 4> dynamicOperands;
-      for (int i = 0; i < resultType.getRank(); ++i) {
-        if (resultType.getShape()[i] != ShapedType::kDynamicSize)
-          continue;
-        auto index = rewriter.createOrFold<arith::ConstantIndexOp>(loc, i);
-        Value size = rewriter.create<tensor::DimOp>(loc, memrefToTensor, index);
-        dynamicOperands.push_back(size);
-      }
-      auto copy =
-          rewriter.create<memref::AllocOp>(loc, resultType, dynamicOperands);
-      rewriter.create<memref::CopyOp>(loc, memrefToTensor.memref(), copy);
-      rewriter.replaceOp(toMemref, {copy});
-    } else
-      rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
-                                                  memrefToTensor.memref());
-    return success();
+    // Only handle cases where a cast is needed. The other case is handled by
+    // the folder.
+    return foldToMemrefToTensorPair(rewriter, toMemref,
+                                    /*allowSameType=*/false);
   }
 };
 
@@ -288,6 +309,12 @@ void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
       context);
 }
 
+LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
+                                    const BufferizationState &state) {
+  // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
+  return foldToMemrefToTensorPair(rewriter, *this);
+}
+
 Optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder, Value alloc) {
   return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
       .getOperation();

diff  --git a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
index cdb6656f0f0ae..8ec23af66eac3 100644
--- a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
@@ -1,6 +1,6 @@
 add_mlir_dialect_library(MLIRBufferization
-  PARTIAL_SOURCES_INTENDED
   AllocationOpInterface.cpp
+  BufferizableOpInterface.cpp
   BufferizationOps.cpp
   BufferizationDialect.cpp
 
@@ -17,17 +17,3 @@ add_mlir_dialect_library(MLIRBufferization
   MLIRTensor
   MLIRMemRef
   )
-
-add_mlir_dialect_library(MLIRBufferizableOpInterface
-  PARTIAL_SOURCES_INTENDED
-  BufferizableOpInterface.cpp
-  BufferizationInterfaceImpl.cpp
-
-  DEPENDS
-  MLIRBufferizableOpInterfaceIncGen
-
-  LINK_LIBS PUBLIC
-  MLIRIR
-  MLIRBufferization
-  MLIRMemRef
-)

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
index b3f4fb38d003a..b212ef952a05e 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
@@ -10,7 +10,6 @@ add_mlir_dialect_library(MLIRBufferizationTransforms
   MLIRBufferizationPassIncGen
 
   LINK_LIBS PUBLIC
-  MLIRBufferizableOpInterface
   MLIRBufferization
   MLIRControlFlowInterfaces
   MLIRInferTypeOpInterface

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
index 21e22de2eee24..8b64809e4b97c 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
@@ -13,7 +13,7 @@ add_mlir_dialect_library(MLIRAffineBufferizableOpInterfaceImpl
 
   LINK_LIBS PUBLIC
   MLIRAffine
-  MLIRBufferizableOpInterface
+  MLIRBufferization
 )
 
 add_mlir_dialect_library(MLIRArithBufferizableOpInterfaceImpl
@@ -21,7 +21,7 @@ add_mlir_dialect_library(MLIRArithBufferizableOpInterfaceImpl
 
   LINK_LIBS PUBLIC
   MLIRArithmetic
-  MLIRBufferizableOpInterface
+  MLIRBufferization
   MLIRIR
   MLIRMemRef
   MLIRStandardOpsTransforms
@@ -31,7 +31,7 @@ add_mlir_dialect_library(MLIRLinalgBufferizableOpInterfaceImpl
   LinalgInterfaceImpl.cpp
 
   LINK_LIBS PUBLIC
-  MLIRBufferizableOpInterface
+  MLIRBufferization
   MLIRBufferizationTransforms
   MLIRIR
   MLIRLinalg
@@ -42,7 +42,7 @@ add_mlir_dialect_library(MLIRSCFBufferizableOpInterfaceImpl
   SCFInterfaceImpl.cpp
 
   LINK_LIBS PUBLIC
-  MLIRBufferizableOpInterface
+  MLIRBufferization
   MLIRBufferizationTransforms
   MLIRIR
   MLIRSCF
@@ -52,7 +52,7 @@ add_mlir_dialect_library(MLIRStdBufferizableOpInterfaceImpl
   StdInterfaceImpl.cpp
 
   LINK_LIBS PUBLIC
-  MLIRBufferizableOpInterface
+  MLIRBufferization
   MLIRStandard
 )
 
@@ -60,7 +60,7 @@ add_mlir_dialect_library(MLIRVectorBufferizableOpInterfaceImpl
   VectorInterfaceImpl.cpp
 
   LINK_LIBS PUBLIC
-  MLIRBufferizableOpInterface
+  MLIRBufferization
   MLIRIR
   MLIRVector
 )
@@ -69,7 +69,7 @@ add_mlir_dialect_library(MLIRModuleBufferization
   ModuleBufferization.cpp
 
   LINK_LIBS PUBLIC
-  MLIRBufferizableOpInterface
+  MLIRBufferization
   MLIRBufferizationTransforms
   MLIRIR
   MLIRMemRef

diff  --git a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
index b88e57dcd13b9..0b9142f8d9e0f 100644
--- a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
@@ -14,7 +14,7 @@ add_mlir_dialect_library(MLIRLinalg
   LINK_LIBS PUBLIC
   MLIRAffine
   MLIRArithmetic
-  MLIRBufferizableOpInterface
+  MLIRBufferization
   MLIRDialectUtils
   MLIRInferTypeOpInterface
   MLIRIR

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 6059d51260975..366600f6d33de 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -36,7 +36,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   MLIRAnalysis
   MLIRArithBufferizableOpInterfaceImpl
   MLIRArithmetic
-  MLIRBufferizableOpInterface
+  MLIRBufferization
   MLIRComplex
   MLIRInferTypeOpInterface
   MLIRIR

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index 90335f952d559..12d43300aacf7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -10,7 +10,6 @@
 
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
-#include "mlir/Dialect/Bufferization/IR/BufferizationInterfaceImpl.h"
 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h"
@@ -54,7 +53,6 @@ struct LinalgComprehensiveModuleBufferize
                 arith::ArithmeticDialect, StandardOpsDialect, AffineDialect>();
     affine_ext::registerBufferizableOpInterfaceExternalModels(registry);
     arith_ext::registerBufferizableOpInterfaceExternalModels(registry);
-    bufferization_ext::registerBufferizableOpInterfaceExternalModels(registry);
     linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
     scf_ext::registerBufferizableOpInterfaceExternalModels(registry);
     std_ext::registerModuleBufferizationExternalModels(registry);

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index 5c4ce30042d69..0c04e729e4cbe 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -12,7 +12,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
 
   LINK_LIBS PUBLIC
   MLIRArithmetic
-  MLIRBufferizableOpInterface
+  MLIRBufferization
   MLIRIR
   MLIRLLVMIR
   MLIRLinalg

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
index 98787344c582b..d36e556fd7723 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -10,7 +10,7 @@ add_mlir_dialect_library(MLIRTensorTransforms
 
   LINK_LIBS PUBLIC
   MLIRArithmetic
-  MLIRBufferizableOpInterface
+  MLIRBufferization
   MLIRBufferizationTransforms
   MLIRIR
   MLIRMemRef

diff  --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
index c784461c2dc3d..b45786123f723 100644
--- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
@@ -16,7 +16,7 @@ add_mlir_library(MLIRLinalgTestPasses
   MLIRAffineBufferizableOpInterfaceImpl
   MLIRArithBufferizableOpInterfaceImpl
   MLIRArithmetic
-  MLIRBufferizableOpInterface
+  MLIRBufferization
   MLIRBufferizationTransforms
   MLIRGPUTransforms
   MLIRLinalg

diff  --git a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
index 3e65330addb4d..a9b5ab206d42f 100644
--- a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
@@ -14,7 +14,6 @@
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
-#include "mlir/Dialect/Bufferization/IR/BufferizationInterfaceImpl.h"
 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h"
@@ -61,7 +60,6 @@ struct TestComprehensiveFunctionBufferize
                     arith::ArithmeticDialect, AffineDialect>();
     affine_ext::registerBufferizableOpInterfaceExternalModels(registry);
     arith_ext::registerBufferizableOpInterfaceExternalModels(registry);
-    bufferization_ext::registerBufferizableOpInterfaceExternalModels(registry);
     linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
     scf_ext::registerBufferizableOpInterfaceExternalModels(registry);
     std_ext::registerBufferizableOpInterfaceExternalModels(registry);

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 74d10d9190b5d..a3876c0b71f4f 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1958,7 +1958,6 @@ cc_library(
     deps = [
         ":Affine",
         ":ArithmeticDialect",
-        ":BufferizableOpInterface",
         ":BufferizationDialect",
         ":IR",
         ":LLVMDialect",
@@ -4423,7 +4422,6 @@ cc_library(
     deps = [
         ":ArithmeticDialect",
         ":Async",
-        ":BufferizableOpInterface",
         ":BufferizationDialect",
         ":BufferizationTransforms",
         ":IR",
@@ -6573,27 +6571,6 @@ gentbl_cc_library(
     ],
 )
 
-cc_library(
-    name = "BufferizableOpInterface",
-    srcs = [
-        "lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp",
-        "lib/Dialect/Bufferization/IR/BufferizationInterfaceImpl.cpp",
-    ],
-    hdrs = [
-        "include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h",
-        "include/mlir/Dialect/Bufferization/IR/BufferizationInterfaceImpl.h",
-    ],
-    includes = ["include"],
-    deps = [
-        ":BufferizableOpInterfaceIncGen",
-        ":BufferizationDialect",
-        ":IR",
-        ":MemRefDialect",
-        ":Support",
-        "//llvm:Support",
-    ],
-)
-
 cc_library(
     name = "AffineBufferizableOpInterfaceImpl",
     srcs = [
@@ -6605,7 +6582,7 @@ cc_library(
     includes = ["include"],
     deps = [
         ":Affine",
-        ":BufferizableOpInterface",
+        ":BufferizationDialect",
         "//llvm:Support",
     ],
 )
@@ -6621,7 +6598,7 @@ cc_library(
     includes = ["include"],
     deps = [
         ":ArithmeticDialect",
-        ":BufferizableOpInterface",
+        ":BufferizationDialect",
         ":IR",
         ":MemRefDialect",
         ":Support",
@@ -6640,7 +6617,6 @@ cc_library(
     ],
     includes = ["include"],
     deps = [
-        ":BufferizableOpInterface",
         ":BufferizationDialect",
         ":BufferizationTransforms",
         ":IR",
@@ -6660,7 +6636,6 @@ cc_library(
     ],
     includes = ["include"],
     deps = [
-        ":BufferizableOpInterface",
         ":BufferizationDialect",
         ":BufferizationTransforms",
         ":IR",
@@ -6680,7 +6655,7 @@ cc_library(
     ],
     includes = ["include"],
     deps = [
-        ":BufferizableOpInterface",
+        ":BufferizationDialect",
         ":IR",
         ":StandardOps",
         ":Support",
@@ -6698,7 +6673,7 @@ cc_library(
     ],
     includes = ["include"],
     deps = [
-        ":BufferizableOpInterface",
+        ":BufferizationDialect",
         ":IR",
         ":Support",
         ":VectorOps",
@@ -6827,7 +6802,7 @@ cc_library(
     deps = [
         ":Affine",
         ":ArithmeticDialect",
-        ":BufferizableOpInterface",
+        ":BufferizationDialect",
         ":CopyOpInterface",
         ":DialectUtils",
         ":IR",
@@ -6909,7 +6884,6 @@ cc_library(
         ":Analysis",
         ":ArithBufferizableOpInterfaceImpl",
         ":ArithmeticDialect",
-        ":BufferizableOpInterface",
         ":BufferizationDialect",
         ":BufferizationTransforms",
         ":ComplexDialect",
@@ -6953,7 +6927,6 @@ cc_library(
     ],
     includes = ["include"],
     deps = [
-        ":BufferizableOpInterface",
         ":BufferizationDialect",
         ":BufferizationTransforms",
         ":DialectUtils",
@@ -7968,19 +7941,27 @@ gentbl_cc_library(
     ],
     tblgen = ":mlir-tblgen",
     td_file = "include/mlir/Dialect/Bufferization/IR/BufferizationOps.td",
-    deps = [":BufferizationOpsTdFiles"],
+    deps = [
+        ":BufferizableOpInterfaceTdFiles",
+        ":BufferizationOpsTdFiles",
+    ],
 )
 
 cc_library(
     name = "BufferizationDialect",
     srcs = [
+        "lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp",
         "lib/Dialect/Bufferization/IR/BufferizationDialect.cpp",
         "lib/Dialect/Bufferization/IR/BufferizationOps.cpp",
     ],
-    hdrs = ["include/mlir/Dialect/Bufferization/IR/Bufferization.h"],
+    hdrs = [
+        "include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h",
+        "include/mlir/Dialect/Bufferization/IR/Bufferization.h",
+    ],
     includes = ["include"],
     deps = [
         ":AllocationOpInterface",
+        ":BufferizableOpInterfaceIncGen",
         ":BufferizationBaseIncGen",
         ":BufferizationOpsIncGen",
         ":ControlFlowInterfaces",
@@ -7989,6 +7970,7 @@ cc_library(
         ":InferTypeOpInterface",
         ":MemRefDialect",
         ":StandardOps",
+        ":Support",
         ":TensorDialect",
         ":ViewLikeInterface",
         "//llvm:Support",
@@ -8025,7 +8007,6 @@ cc_library(
     deps = [
         ":AllocationOpInterface",
         ":Analysis",
-        ":BufferizableOpInterface",
         ":BufferizationDialect",
         ":BufferizationPassIncGen",
         ":ControlFlowInterfaces",

diff  --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
index 8c24a43000c78..7bc92df875c46 100644
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -388,7 +388,6 @@ cc_library(
         "//mlir:AffineBufferizableOpInterfaceImpl",
         "//mlir:ArithBufferizableOpInterfaceImpl",
         "//mlir:ArithmeticDialect",
-        "//mlir:BufferizableOpInterface",
         "//mlir:BufferizationDialect",
         "//mlir:BufferizationTransforms",
         "//mlir:GPUDialect",


        


More information about the Mlir-commits mailing list