[Mlir-commits] [mlir] ee491ac - [mlir] Add std.tensor_to_memref op and teach the infra about it

Sean Silva llvmlistbot at llvm.org
Thu Oct 15 12:20:48 PDT 2020


Author: Sean Silva
Date: 2020-10-15T12:19:20-07:00
New Revision: ee491ac91e123b90eeec3cce7e494936ea8cb85d

URL: https://github.com/llvm/llvm-project/commit/ee491ac91e123b90eeec3cce7e494936ea8cb85d
DIFF: https://github.com/llvm/llvm-project/commit/ee491ac91e123b90eeec3cce7e494936ea8cb85d.diff

LOG: [mlir] Add std.tensor_to_memref op and teach the infra about it

The opposite of tensor_to_memref is tensor_load.

- Add some basic tensor_load/tensor_to_memref folding.
- Add source/target materializations to BufferizeTypeConverter.
- Add an example std bufferization pattern/pass that shows how the
  materializations work together (more std bufferization patterns to come
  in subsequent commits).
  - In coming commits, I'll document how to write composable
    bufferization passes/patterns and update the other in-tree
    bufferization passes to match this convention. The populate* functions
    will of course continue to be exposed for power users.
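The two folds added in Ops.cpp (TensorLoadOp::fold and TensorToMemrefOp::fold) can be modeled outside MLIR. The sketch below uses a hypothetical toy IR: `Op`, `Kind`, `makeArg`, `makeOp`, `foldTensorLoad` and `foldTensorToMemref` are invented for illustration and are not MLIR API. It mirrors the rules in the diff: tensor_load(tensor_to_memref(t)) folds to t unconditionally, while tensor_to_memref(tensor_load(m)) folds to m only when m already has the requested memref type.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>

// Hypothetical toy IR (not the MLIR API): every value is produced by an op
// or is a block argument, and carries its type as a string.
enum class Kind { BlockArg, TensorLoad, TensorToMemref };

struct Op {
  Kind kind;
  std::string type;             // e.g. "memref<?xf32, 2>" or "tensor<?xf32>"
  std::shared_ptr<Op> operand;  // null for block arguments
};

using OpPtr = std::shared_ptr<Op>;

OpPtr makeArg(const std::string &type) {
  return std::make_shared<Op>(Op{Kind::BlockArg, type, nullptr});
}

OpPtr makeOp(Kind kind, const std::string &type, OpPtr operand) {
  return std::make_shared<Op>(Op{kind, type, std::move(operand)});
}

// Models TensorLoadOp::fold: tensor_load(tensor_to_memref(%t)) -> %t.
// Returns null when no fold applies.
OpPtr foldTensorLoad(const OpPtr &load) {
  const OpPtr &src = load->operand;
  if (src && src->kind == Kind::TensorToMemref)
    return src->operand;
  return nullptr;
}

// Models TensorToMemrefOp::fold: tensor_to_memref(tensor_load(%m)) -> %m,
// but only if %m already has the requested memref type (layout map and
// memory space included); otherwise folding would change the result type.
OpPtr foldTensorToMemref(const OpPtr &toMemref) {
  const OpPtr &src = toMemref->operand;
  if (!src || src->kind != Kind::TensorLoad)
    return nullptr;
  assert(src->operand && "tensor_load always has a memref operand");
  if (src->operand->type == toMemref->type)
    return src->operand;
  return nullptr;
}
```

The asymmetry is deliberate: the tensor type of a tensor_to_memref operand is fully determined by its memref result type, so the first fold can never change a type, whereas two memrefs of different memory spaces (as in the `no_fold_tensor_to_memref_of_tensor_load` test) share one tensor type and must not be folded together.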

The names tensor_load/tensor_to_memref and their pretty forms are
not very intuitive. I'm open to any suggestions here. One key
observation is that the memref type must always be the one specified in
the pretty form, since the tensor type can be inferred from the memref
type but not vice versa.
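Why only one direction is inferable can be shown with a small model of the ranked case. The structs and the function below are hypothetical simplifications (they are not MLIR's MemRefType, RankedTensorType, or getTensorTypeFromMemRefType, which also handles unranked types); -1 stands for a dynamic dimension. Mapping a memref type to a tensor type keeps the shape and element type and drops the layout map and memory space, so many memref types collapse onto a single tensor type.

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical, simplified stand-in for a ranked memref type.
struct MemRefDesc {
  std::vector<int64_t> shape;  // -1 marks a dynamic ("?") dimension
  std::string elementType;
  std::string layout;          // "" means the identity layout
  unsigned memorySpace = 0;
};

// Hypothetical, simplified stand-in for a ranked tensor type. Tensors have
// no layout map or memory space, so only shape and element type exist.
struct TensorDesc {
  std::vector<int64_t> shape;
  std::string elementType;
  bool operator==(const TensorDesc &other) const {
    return shape == other.shape && elementType == other.elementType;
  }
};

// Memref -> tensor is always well defined: drop layout and memory space.
// The reverse would have to invent a layout and a memory space, which is
// why the pretty form of tensor_to_memref spells out the memref type.
TensorDesc tensorTypeFromMemRef(const MemRefDesc &m) {
  return TensorDesc{m.shape, m.elementType};
}
```

For example, `memref<4x?xf32, #map0, 42>` and `memref<4x?xf32>` both map to `tensor<4x?xf32>`, so given only the tensor type there is no way to choose between them.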

With this, I've been able to replace all my custom bufferization type
converters in npcomp with BufferizeTypeConverter!

Part of the plan discussed in:
https://llvm.discourse.group/t/what-is-the-strategy-for-tensor-memref-conversion-bufferization/1938/17

Differential Revision: https://reviews.llvm.org/D89437

Added: 
    mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
    mlir/test/Dialect/Standard/bufferize.mlir
    mlir/test/Dialect/Standard/canonicalize.mlir

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
    mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
    mlir/include/mlir/Transforms/Bufferize.h
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
    mlir/lib/Transforms/Bufferize.cpp
    mlir/test/Dialect/Standard/ops.mlir



################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index c024e19b5009..97ab739890be 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -3363,6 +3363,10 @@ def TensorLoadOp : Std_Op<"tensor_load",
     data. The result value is a tensor whose shape and element type match the
     memref operand.
 
+    The opposite of this op is tensor_to_memref. Together, these two ops are
+    useful for source/target materializations when doing type conversions
+    involving tensors and memrefs.
+
     Example:
 
     ```mlir
@@ -3394,6 +3398,8 @@ def TensorLoadOp : Std_Op<"tensor_load",
   }];
 
   let assemblyFormat = "$memref attr-dict `:` type($memref)";
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -3428,6 +3434,47 @@ def TensorStoreOp : Std_Op<"tensor_store",
   let assemblyFormat = "$tensor `,` $memref attr-dict `:` type($memref)";
 }
 
+//===----------------------------------------------------------------------===//
+// TensorToMemrefOp
+//===----------------------------------------------------------------------===//
+
+def TensorToMemrefOp : Std_Op<"tensor_to_memref",
+    [SameOperandsAndResultShape, SameOperandsAndResultElementType,
+     TypesMatchWith<"type of 'tensor' is the tensor equivalent of 'memref'",
+                    "memref", "tensor",
+                    "getTensorTypeFromMemRefType($_self)">]> {
+  let summary = "tensor to memref operation";
+  let description = [{
+    Create a memref from a tensor. This is equivalent to allocating a new
+    memref of the appropriate (possibly dynamic) shape, and then copying the
+    elements (as if by a tensor_store op) into the newly allocated memref.
+
+    The opposite of this op is tensor_load. Together, these two ops are useful
+    for source/target materializations when doing type conversions involving
+    tensors and memrefs.
+
+    Note: This op takes the memref type in its pretty form because the tensor
+    type can always be inferred from the memref type, but the reverse is not
+    true. For example, the memref might have a layout map or memory space which
+    cannot be inferred from the tensor type.
+
+    ```mlir
+    // %10 has type tensor<4x?xf32>; the result has the memref type below.
+    %12 = tensor_to_memref %10 : memref<4x?xf32, #map0, 42>
+    ```
+  }];
+
+  let arguments = (ins AnyTensor:$tensor);
+  let results = (outs Res<AnyRankedOrUnrankedMemRef,
+                      "the memref to create", [MemAlloc]>:$memref);
+  // This op is fully verified by traits.
+  let verifier = ?;
+
+  let assemblyFormat = "$tensor attr-dict `:` type($memref)";
+
+  let hasFolder = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // TransposeOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
index fba5f4b32043..65a39dc9ad9b 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
@@ -16,6 +16,7 @@
 #define MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES_H_
 
 #include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Bufferize.h"
 
 namespace mlir {
 
@@ -27,6 +28,13 @@ std::unique_ptr<Pass> createExpandAtomicPass();
 void populateExpandTanhPattern(OwningRewritePatternList &patterns,
                                MLIRContext *ctx);
 
+void populateStdBufferizePatterns(MLIRContext *context,
+                                  BufferizeTypeConverter &typeConverter,
+                                  OwningRewritePatternList &patterns);
+
+/// Creates an instance of std bufferization pass.
+std::unique_ptr<Pass> createStdBufferizePass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
index b65c03d33fc1..ff5a5a63b24b 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
@@ -16,4 +16,9 @@ def ExpandAtomic : FunctionPass<"expand-atomic"> {
   let constructor = "mlir::createExpandAtomicPass()";
 }
 
+def StdBufferize : FunctionPass<"std-bufferize"> {
+  let summary = "Bufferize the std dialect";
+  let constructor = "mlir::createStdBufferizePass()";
+}
+
 #endif // MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES

diff  --git a/mlir/include/mlir/Transforms/Bufferize.h b/mlir/include/mlir/Transforms/Bufferize.h
index 128d4e7ebb07..eebebaae54e3 100644
--- a/mlir/include/mlir/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Transforms/Bufferize.h
@@ -13,6 +13,16 @@
 // pattern needs to be written. The infrastructure in this file assists in
 // defining these conversion patterns in a composable way.
 //
+// Bufferization conversion patterns should generally use the ordinary
+// conversion pattern classes (e.g. OpConversionPattern). A TypeConverter
+// (accessible with getTypeConverter()) available on such patterns is sufficient
+// for most cases (if needed at all).
+//
+// But some patterns require access to the extra functions on
+// BufferizeTypeConverter that don't exist on the base TypeConverter class. For
+// those cases, BufferizeConversionPattern and its related classes should be
+// used, which provide access to a BufferizeTypeConverter directly.
+//
 //===----------------------------------------------------------------------===//
 
 #ifndef MLIR_TRANSFORMS_BUFFERIZE_H

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 8fe45cbb1a13..b5d1429829e5 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -3601,7 +3601,7 @@ void TensorCastOp::getCanonicalizationPatterns(
 }
 
 //===----------------------------------------------------------------------===//
-// Helpers for Tensor[Load|Store]Op
+// Helpers for Tensor[Load|Store]Op and TensorToMemrefOp
 //===----------------------------------------------------------------------===//
 
 static Type getTensorTypeFromMemRefType(Type type) {
@@ -3612,6 +3612,27 @@ static Type getTensorTypeFromMemRefType(Type type) {
   return NoneType::get(type.getContext());
 }
 
+//===----------------------------------------------------------------------===//
+// TensorLoadOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult TensorLoadOp::fold(ArrayRef<Attribute>) {
+  if (auto tensorToMemref = memref().getDefiningOp<TensorToMemrefOp>())
+    return tensorToMemref.tensor();
+  return {};
+}
+
+//===----------------------------------------------------------------------===//
+// TensorToMemrefOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult TensorToMemrefOp::fold(ArrayRef<Attribute>) {
+  if (auto tensorLoad = tensor().getDefiningOp<TensorLoadOp>())
+    if (tensorLoad.memref().getType() == getType())
+      return tensorLoad.memref();
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // TransposeOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
new file mode 100644
index 000000000000..95a8b75e2c2b
--- /dev/null
+++ b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
@@ -0,0 +1,62 @@
+//===- Bufferize.cpp - Bufferization for std ops --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements bufferization of std ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/Bufferize.h"
+#include "PassDetail.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+class BufferizeTensorCastOp : public OpConversionPattern<TensorCastOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(TensorCastOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto resultType = getTypeConverter()->convertType(op.getType());
+    rewriter.replaceOpWithNewOp<MemRefCastOp>(op, resultType, operands[0]);
+    return success();
+  }
+};
+} // namespace
+
+void mlir::populateStdBufferizePatterns(MLIRContext *context,
+                                        BufferizeTypeConverter &typeConverter,
+                                        OwningRewritePatternList &patterns) {
+  patterns.insert<BufferizeTensorCastOp>(typeConverter, context);
+}
+
+namespace {
+struct StdBufferizePass : public StdBufferizeBase<StdBufferizePass> {
+  void runOnFunction() override {
+    auto *context = &getContext();
+    BufferizeTypeConverter typeConverter;
+    OwningRewritePatternList patterns;
+    ConversionTarget target(*context);
+
+    target.addLegalDialect<StandardOpsDialect>();
+
+    populateStdBufferizePatterns(context, typeConverter, patterns);
+    target.addIllegalOp<TensorCastOp>();
+
+    if (failed(mlir::applyPartialConversion(getFunction(), target, patterns)))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createStdBufferizePass() {
+  return std::make_unique<StdBufferizePass>();
+}

diff  --git a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
index d1204df2de76..182f03e8dd55 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRStandardOpsTransforms
+  Bufferize.cpp
   ExpandAtomic.cpp
   ExpandTanh.cpp
   FuncConversions.cpp

diff  --git a/mlir/lib/Transforms/Bufferize.cpp b/mlir/lib/Transforms/Bufferize.cpp
index 9f6473d2df20..682fd9ff6719 100644
--- a/mlir/lib/Transforms/Bufferize.cpp
+++ b/mlir/lib/Transforms/Bufferize.cpp
@@ -27,6 +27,18 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
   addConversion([](UnrankedTensorType type) -> Type {
     return UnrankedMemRefType::get(type.getElementType(), 0);
   });
+  addSourceMaterialization([](OpBuilder &builder, RankedTensorType type,
+                              ValueRange inputs, Location loc) -> Value {
+    assert(inputs.size() == 1);
+    assert(inputs[0].getType().isa<BaseMemRefType>());
+    return builder.create<TensorLoadOp>(loc, type, inputs[0]);
+  });
+  addTargetMaterialization([](OpBuilder &builder, MemRefType type,
+                              ValueRange inputs, Location loc) -> Value {
+    assert(inputs.size() == 1);
+    assert(inputs[0].getType().isa<TensorType>());
+    return builder.create<TensorToMemrefOp>(loc, type, inputs[0]);
+  });
 }
 
 /// This method tries to decompose a value of a certain type using provided

diff  --git a/mlir/test/Dialect/Standard/bufferize.mlir b/mlir/test/Dialect/Standard/bufferize.mlir
new file mode 100644
index 000000000000..981237d78cdd
--- /dev/null
+++ b/mlir/test/Dialect/Standard/bufferize.mlir
@@ -0,0 +1,12 @@
+// RUN: mlir-opt %s -std-bufferize | FileCheck %s
+
+// CHECK-LABEL:   func @tensor_cast(
+// CHECK-SAME:                      %[[TENSOR:.*]]: tensor<?xindex>) -> tensor<2xindex> {
+// CHECK:           %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]]
+// CHECK:           %[[CASTED:.*]] = memref_cast %[[MEMREF]] : memref<?xindex> to memref<2xindex>
+// CHECK:           %[[RET:.*]] = tensor_load %[[CASTED]]
+// CHECK:           return %[[RET]] : tensor<2xindex>
+func @tensor_cast(%arg0: tensor<?xindex>) -> tensor<2xindex> {
+  %0 = tensor_cast %arg0 : tensor<?xindex> to tensor<2xindex>
+  return %0 : tensor<2xindex>
+}

diff  --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
new file mode 100644
index 000000000000..cd22014e0de0
--- /dev/null
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s -canonicalize | FileCheck %s
+
+// Test case: Basic folding of tensor_load(tensor_to_memref(t)) -> t
+// CHECK-LABEL:   func @tensor_load_of_tensor_to_memref(
+// CHECK-SAME:                                          %[[TENSOR:.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK:           return %[[TENSOR]]
+func @tensor_load_of_tensor_to_memref(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+    %0 = tensor_to_memref %arg0 : memref<?xf32>
+    %1 = tensor_load %0 : memref<?xf32>
+    return %1 : tensor<?xf32>
+}
+
+// Test case: Basic folding of tensor_to_memref(tensor_load(m)) -> m
+// CHECK-LABEL:   func @tensor_to_memref_of_tensor_load(
+// CHECK-SAME:                                          %[[MEMREF:.*]]: memref<?xf32>) -> memref<?xf32> {
+// CHECK:           return %[[MEMREF]]
+func @tensor_to_memref_of_tensor_load(%arg0: memref<?xf32>) -> memref<?xf32> {
+    %0 = tensor_load %arg0 : memref<?xf32>
+    %1 = tensor_to_memref %0 : memref<?xf32>
+    return %1 : memref<?xf32>
+}
+
+// Test case: If the memrefs are not the same type, don't fold them.
+// CHECK-LABEL:   func @no_fold_tensor_to_memref_of_tensor_load(
+// CHECK-SAME:                                                  %[[MEMREF_ADDRSPACE2:.*]]: memref<?xf32, 2>) -> memref<?xf32, 7> {
+// CHECK:           %[[TENSOR:.*]] = tensor_load %[[MEMREF_ADDRSPACE2]] : memref<?xf32, 2>
+// CHECK:           %[[MEMREF_ADDRSPACE7:.*]] = tensor_to_memref %[[TENSOR]] : memref<?xf32, 7>
+// CHECK:           return %[[MEMREF_ADDRSPACE7]]
+func @no_fold_tensor_to_memref_of_tensor_load(%arg0: memref<?xf32, 2>) -> memref<?xf32, 7> {
+    %0 = tensor_load %arg0 : memref<?xf32, 2>
+    %1 = tensor_to_memref %0 : memref<?xf32, 7>
+    return %1 : memref<?xf32, 7>
+}

diff  --git a/mlir/test/Dialect/Standard/ops.mlir b/mlir/test/Dialect/Standard/ops.mlir
index 64474e391b81..b11c9534cc2d 100644
--- a/mlir/test/Dialect/Standard/ops.mlir
+++ b/mlir/test/Dialect/Standard/ops.mlir
@@ -19,6 +19,13 @@ func @test_index_cast_tensor_reverse(%arg0 : tensor<i64>) -> tensor<index> {
   return %0 : tensor<index>
 }
 
+// CHECK-LABEL: test_tensor_to_memref
+func @test_tensor_to_memref(%arg0: tensor<?xi64>, %arg1: tensor<*xi64>) -> (memref<?xi64, affine_map<(d0) -> (d0 + 7)>>, memref<*xi64, 1>) {
+  %0 = tensor_to_memref %arg0 : memref<?xi64, affine_map<(d0) -> (d0 + 7)>>
+  %1 = tensor_to_memref %arg1 : memref<*xi64, 1>
+  return %0, %1 : memref<?xi64, affine_map<(d0) -> (d0 + 7)>>, memref<*xi64, 1>
+}
+
 // CHECK-LABEL: @assert
 func @assert(%arg : i1) {
   assert %arg, "Some message in case this assertion fails."


        

