[Mlir-commits] [mlir] c645eb0 - [mlir][memref] Bufferize memref.tensor_store op

Matthias Springer llvmlistbot at llvm.org
Wed Feb 15 06:28:45 PST 2023


Author: Matthias Springer
Date: 2023-02-15T15:26:57+01:00
New Revision: c645eb0d03bd2a04f71a01d3eac1c392539dc207

URL: https://github.com/llvm/llvm-project/commit/c645eb0d03bd2a04f71a01d3eac1c392539dc207
DIFF: https://github.com/llvm/llvm-project/commit/c645eb0d03bd2a04f71a01d3eac1c392539dc207.diff

LOG: [mlir][memref] Bufferize memref.tensor_store op

This change adds the BufferizableOpInterface implementation for memref.tensor_store.

Differential Revision: https://reviews.llvm.org/D144080

Added: 
    mlir/include/mlir/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.h
    mlir/lib/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/test/Dialect/MemRef/bufferize.mlir

Modified: 
    mlir/include/mlir/InitAllDialects.h
    mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
    mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.h
new file mode 100644
index 0000000000000..7e532100e4eeb
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.h
@@ -0,0 +1,21 @@
+//===- BufferizableOpInterfaceImpl.h - Impl. of BufferizableOpInterface ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MEMREF_BUFFERIZABLEOPINTERFACEIMPL_H
+#define MLIR_DIALECT_MEMREF_BUFFERIZABLEOPINTERFACEIMPL_H
+
+namespace mlir {
+
+class DialectRegistry;
+
+namespace memref {
+void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
+} // namespace memref
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MEMREF_BUFFERIZABLEOPINTERFACEIMPL_H

diff  --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index adbbb847adfb7..aa7abb0e19b5d 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -45,6 +45,7 @@
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
+#include "mlir/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/OpenACC/OpenACC.h"
@@ -131,6 +132,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
       registry);
   linalg::registerBufferizableOpInterfaceExternalModels(registry);
   linalg::registerTilingInterfaceExternalModels(registry);
+  memref::registerBufferizableOpInterfaceExternalModels(registry);
   memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
   scf::registerBufferizableOpInterfaceExternalModels(registry);
   shape::registerBufferizableOpInterfaceExternalModels(registry);

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.cpp
new file mode 100644
index 0000000000000..c11bce95b7ead
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -0,0 +1,63 @@
+//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.h"
+
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Operation.h"
+
+using namespace mlir;
+using namespace mlir::bufferization;
+
+namespace {
+/// Bufferizes memref.tensor_store into a copy built by options.createMemCpy.
+struct TensorStoreOpInterface
+    : public BufferizableOpInterface::ExternalModel<TensorStoreOpInterface,
+                                                    memref::TensorStoreOp> {
+  AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand,
+                                            const AnalysisState &state) const {
+    return {};
+  }
+
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    assert(opOperand.getOperandNumber() == 0 && "expected src operand");
+    return true;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const AnalysisState &state) const {
+    // The memref operand is written but not the tensor operand.
+    assert(opOperand.getOperandNumber() == 0 && "expected src operand");
+    return false;
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options) const {
+    auto tensorStoreOp = cast<memref::TensorStoreOp>(op);
+    auto srcBuffer = getBuffer(rewriter, tensorStoreOp.getTensor(), options);
+    if (failed(srcBuffer))
+      return failure();
+    if (failed(options.createMemCpy(rewriter, op->getLoc(), *srcBuffer,
+                                    tensorStoreOp.getMemref())))
+      return failure();
+    rewriter.eraseOp(tensorStoreOp);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::memref::registerBufferizableOpInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, MemRefDialect *dialect) {
+    TensorStoreOp::attachInterface<TensorStoreOpInterface>(*ctx);
+  });
+}

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
index ceccc51f3813d..3cffe05995f0c 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRMemRefTransforms
+  BufferizableOpInterfaceImpl.cpp
   ComposeSubView.cpp
   ExpandOps.cpp
   ExpandStridedMetadata.cpp
@@ -20,6 +21,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
   MLIRAffineUtils
   MLIRArithDialect
   MLIRArithTransforms
+  MLIRBufferizationDialect
   MLIRFuncDialect
   MLIRInferTypeOpInterface
   MLIRLoopLikeInterface

diff  --git a/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir b/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
index d6a282d2e175d..be277ba6578f5 100644
--- a/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
@@ -38,6 +38,34 @@ transform.sequence failures(propagate) {
 
 // -----
 
+// CHECK-LABEL: func @tensor_pad_constant(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x10xindex>
+//       CHECK:   %[[src:.*]] = bufferization.to_memref %[[t]]
+//       CHECK:   %[[alloc:.*]] = memref.alloc
+//       CHECK:   %[[subview:.*]] = memref.subview %[[alloc]]
+//       CHECK:   memref.copy %[[src]], %[[subview]]
+//       CHECK:   bufferization.to_tensor %[[alloc]] restrict writable
+func.func @tensor_pad_constant(%t: tensor<?x10xindex>, %l2: index, %h1: index,
+                               %h2: index) -> tensor<?x?xindex> {
+  %0 = tensor.pad %t low[5, %l2] high[%h1, %h2] {
+  ^bb0(%arg0: index, %arg1: index):
+    %c = arith.constant 50 : index
+    tensor.yield %c : index
+  } : tensor<?x10xindex> to tensor<?x?xindex>
+  return %0 : tensor<?x?xindex>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  %1 = transform.get_result %0[0] : (!pdl.operation) -> !transform.any_value
+  %2 = transform.structured.bufferize_to_allocation %1
+  // Make sure that One-Shot Bufferize can bufferize the rest.
+  transform.bufferization.one_shot_bufferize %arg1
+}
+
+// -----
+
 // CHECK-LABEL: func @materialization_of_bbarg(
 //  CHECK-SAME:     %[[t:.*]]: tensor<?x10xindex>
 //       CHECK:   %[[c0:.*]] = arith.constant 0 : index
@@ -59,3 +87,26 @@ transform.sequence failures(propagate) {
   %1 = test_produce_value_handle_to_argument_of_parent_block %0, 0 : (!pdl.operation) -> !transform.any_value
   %2 = transform.structured.bufferize_to_allocation %1 {memory_space = 4}
 }
+
+// -----
+
+// CHECK-LABEL: func @materialization_of_bbarg(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x10xindex>
+//       CHECK:   %[[m:.*]] = bufferization.to_memref %[[t]]
+//       CHECK:   %[[alloc:.*]] = memref.alloc(%{{.*}}) : memref<?x10xindex, 4>
+//       CHECK:   memref.copy %[[m]], %[[alloc]]
+//       CHECK:   %[[r:.*]] = memref.load %[[alloc]]
+//       CHECK:   return %[[r]]
+func.func @materialization_of_bbarg(%t: tensor<?x10xindex>, %idx: index) -> index {
+  %r = tensor.extract %t[%idx, %idx] : tensor<?x10xindex>
+  return %r : index
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["tensor.extract"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  %1 = test_produce_value_handle_to_argument_of_parent_block %0, 0 : (!pdl.operation) -> !transform.any_value
+  %2 = transform.structured.bufferize_to_allocation %1 {memory_space = 4}
+  // Make sure that One-Shot Bufferize can bufferize the rest.
+  transform.bufferization.one_shot_bufferize %arg1
+}

diff  --git a/mlir/test/Dialect/MemRef/bufferize.mlir b/mlir/test/Dialect/MemRef/bufferize.mlir
new file mode 100644
index 0000000000000..d44cfc2f734cf
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/bufferize.mlir
@@ -0,0 +1,11 @@
+// RUN: mlir-opt -one-shot-bufferize %s | FileCheck %s
+
+// CHECK-LABEL: func @tensor_store(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?xf32>, %[[m:.*]]: memref<?xf32>
+//       CHECK:   %[[src:.*]] = bufferization.to_memref %[[t]]
+//       CHECK:   memref.copy %[[src]], %[[m]]
+//       CHECK:   return
+func.func @tensor_store(%t: tensor<?xf32>, %m: memref<?xf32>) {
+  memref.tensor_store %t, %m : memref<?xf32>
+  return
+}

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 94c30effadcee..8108cf96fd812 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -9904,6 +9904,7 @@ cc_library(
         ":ArithDialect",
         ":ArithTransforms",
         ":ArithUtils",
+        ":BufferizationDialect",
         ":ControlFlowDialect",
         ":DialectUtils",
         ":FuncDialect",


        


More information about the Mlir-commits mailing list