[Mlir-commits] [mlir] daf1810 - [mlir][tensor] Replace tensor-bufferize with BufferizableOpInterface impl

Matthias Springer llvmlistbot at llvm.org
Thu Jan 27 02:30:58 PST 2022


Author: Matthias Springer
Date: 2022-01-27T19:30:45+09:00
New Revision: daf18108ecc959ce20c25af33618c934b470f350

URL: https://github.com/llvm/llvm-project/commit/daf18108ecc959ce20c25af33618c934b470f350
DIFF: https://github.com/llvm/llvm-project/commit/daf18108ecc959ce20c25af33618c934b470f350.diff

LOG: [mlir][tensor] Replace tensor-bufferize with BufferizableOpInterface impl

This commit switches the `tensor-bufferize` pass over to BufferizableOpInterface-based bufferization.

Differential Revision: https://reviews.llvm.org/D118246

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
    mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
    mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h
    mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td
    mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
    mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
    mlir/test/Dialect/Tensor/bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 5107710413ead..534f664483be2 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -255,7 +255,7 @@ class BufferizationState {
   const BufferizationOptions &getOptions() const { return options; }
 
 protected:
-  BufferizationState(const BufferizationOptions &options);
+  explicit BufferizationState(const BufferizationOptions &options);
 
   // BufferizationState should be passed as a reference.
   BufferizationState(const BufferizationState &) = delete;
@@ -270,6 +270,24 @@ class BufferizationState {
   const BufferizationOptions &options;
 };
 
+/// This is a "no analysis, always copy" BufferizationState. In the absence of an
+/// analysis, a buffer must be copied each time it is written to. Therefore, all
+/// OpOperands that bufferize to a memory write must bufferize out-of-place.
+class AlwaysCopyBufferizationState : public BufferizationState {
+public:
+  explicit AlwaysCopyBufferizationState(const BufferizationOptions &options);
+
+  AlwaysCopyBufferizationState(const AlwaysCopyBufferizationState &) = delete;
+
+  virtual ~AlwaysCopyBufferizationState() = default;
+
+  /// Return `true` if the given OpOperand was decided to bufferize in-place.
+  bool isInPlace(OpOperand &opOperand) const override;
+
+  /// Return true if `v1` and `v2` bufferize to equivalent buffers.
+  bool areEquivalentBufferizedValues(Value v1, Value v2) const override;
+};
+
 /// Replace an op with replacement values. The op is deleted. Tensor OpResults
 /// must be replaced with memref values.
 void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op,

diff  --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
index b587955d65b13..a56287995aa96 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
@@ -69,6 +69,21 @@ void populateEliminateBufferizeMaterializationsPatterns(
 // TODO: Extract `options` from `state` and pass as separate argument.
 LogicalResult bufferizeOp(Operation *op, const BufferizationState &state);
 
+/// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`.
+/// Buffers are duplicated and copied before any tensor use that bufferizes to
+/// a memory write.
+///
+/// Note: This function bufferizes ops without utilizing analysis results. It
+/// can be used to implement partial bufferization passes.
+LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options);
+
+/// Populate the pattern set with a pattern that bufferizes ops that implement
+/// `BufferizableOpInterface`.
+void populateBufferizationPattern(const BufferizationState &state,
+                                  RewritePatternSet &patterns);
+
+std::unique_ptr<BufferizationOptions> getPartialBufferizationOptions();
+
 } // namespace bufferization
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h
index f90d02bda22d3..a346577e2f569 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h
@@ -12,16 +12,6 @@
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
-namespace bufferization {
-class BufferizeTypeConverter;
-} // namespace bufferization
-
-class RewritePatternSet;
-
-void populateTensorBufferizePatterns(
-    bufferization::BufferizeTypeConverter &typeConverter,
-    RewritePatternSet &patterns);
-
 /// Creates an instance of `tensor` dialect bufferization pass.
 std::unique_ptr<Pass> createTensorBufferizePass();
 

diff  --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td
index 57c754c84c321..77743134325f5 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td
@@ -14,11 +14,6 @@ include "mlir/Pass/PassBase.td"
 def TensorBufferize : Pass<"tensor-bufferize", "FuncOp"> {
   let summary = "Bufferize the `tensor` dialect";
   let constructor = "mlir::createTensorBufferizePass()";
-  let dependentDialects = [
-    "bufferization::BufferizationDialect",
-    "memref::MemRefDialect",
-    "scf::SCFDialect"
-  ];
 }
 
 #endif // MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES

diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 276950d4ef19e..3af1c37594f96 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -318,6 +318,25 @@ void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
   rewriter.eraseOp(op);
 }
 
+AlwaysCopyBufferizationState::AlwaysCopyBufferizationState(
+    const BufferizationOptions &options)
+    : BufferizationState(options) {}
+
+/// Return `true` if the given OpOperand was decided to bufferize in-place.
+bool AlwaysCopyBufferizationState::isInPlace(OpOperand &opOperand) const {
+  // OpOperands that bufferize to a memory write are out-of-place, i.e., an
+  // alloc and copy is inserted.
+  return !bufferizesToMemoryWrite(opOperand);
+}
+
+/// Return true if `v1` and `v2` bufferize to equivalent buffers.
+bool AlwaysCopyBufferizationState::areEquivalentBufferizedValues(
+    Value v1, Value v2) const {
+  // There is no analysis, so we do not know if the values are equivalent. The
+  // conservative answer is "false".
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // Bufferization-specific scoped alloc/dealloc insertion support.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index f202a7a5bbd9c..d31c3532509e7 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -207,9 +207,59 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
                                          const BufferizationState &state) {
   // Bufferize the op and its nested ops.
   RewritePatternSet patterns(op->getContext());
-  patterns.add<BufferizationPattern>(op->getContext(), state);
+  populateBufferizationPattern(state, patterns);
   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
     return failure();
 
   return checkBufferizationResult(op, state.getOptions());
 }
+
+namespace {
+/// This is a "no analysis, always copy" BufferizationState. In the absence of an
+/// analysis, a buffer must be copied each time it is written to. Therefore, all
+/// OpOperands that bufferize to a memory write must bufferize out-of-place.
+class AlwaysCopyBufferizationState : public BufferizationState {
+public:
+  AlwaysCopyBufferizationState(const BufferizationOptions &options)
+      : BufferizationState(options) {}
+
+  AlwaysCopyBufferizationState(const AlwaysCopyBufferizationState &) = delete;
+
+  virtual ~AlwaysCopyBufferizationState() = default;
+
+  /// Return `true` if the given OpOperand was decided to bufferize in-place.
+  bool isInPlace(OpOperand &opOperand) const override {
+    // OpOperands that bufferize to a memory write are out-of-place, i.e., an
+    // alloc and copy is inserted.
+    return !bufferizesToMemoryWrite(opOperand);
+  }
+
+  /// Return true if `v1` and `v2` bufferize to equivalent buffers.
+  bool areEquivalentBufferizedValues(Value v1, Value v2) const override {
+    // There is no analysis, so we do not know if the values are equivalent. The
+    // conservative answer is "false".
+    return false;
+  }
+};
+} // namespace
+
+LogicalResult bufferization::bufferizeOp(Operation *op,
+                                         const BufferizationOptions &options) {
+  AlwaysCopyBufferizationState state(options);
+  return bufferizeOp(op, state);
+}
+
+void bufferization::populateBufferizationPattern(
+    const BufferizationState &state, RewritePatternSet &patterns) {
+  patterns.add<BufferizationPattern>(patterns.getContext(), state);
+}
+
+std::unique_ptr<BufferizationOptions>
+bufferization::getPartialBufferizationOptions() {
+  auto options = std::make_unique<BufferizationOptions>();
+  options->allowReturnMemref = true;
+  options->allowUnknownOps = true;
+  options->createDeallocs = false;
+  options->fullyDynamicLayoutMaps = false;
+  return options;
+}

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
index 0bd01dabdfef5..1d435c64d9287 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
@@ -13,223 +13,40 @@
 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
 #include "PassDetail.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/Tensor/Transforms/Passes.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 using namespace mlir;
+using namespace bufferization;
 
 namespace {
-struct BufferizeCastOp : public OpConversionPattern<tensor::CastOp> {
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto resultType = getTypeConverter()->convertType(op.getType());
-    rewriter.replaceOpWithNewOp<memref::CastOp>(op, resultType,
-                                                adaptor.getOperands()[0]);
-    return success();
-  }
-};
-
-struct BufferizeDimOp : public OpConversionPattern<tensor::DimOp> {
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<memref::DimOp>(op, adaptor.source(),
-                                               adaptor.index());
-    return success();
-  }
-};
-
-struct BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> {
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(tensor::ExtractOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, adaptor.tensor(),
-                                                adaptor.indices());
-    return success();
-  }
-};
-
-struct BufferizeFromElementsOp
-    : public OpConversionPattern<tensor::FromElementsOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(tensor::FromElementsOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-    auto tensorType = op.getType().cast<RankedTensorType>();
-    auto shape = tensorType.getShape();
-
-    // Allocate a buffer for the result.
-    auto resultType =
-        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
-    Value buffer = rewriter.create<memref::AllocOp>(loc, resultType);
-
-    // Case: tensor<0xelem_type>.
-    if (op.elements().empty()) {
-      rewriter.replaceOp(op, {buffer});
-      return success();
-    }
-
-    // Case: tensor<elem_type>.
-    if (shape.empty()) {
-      rewriter.create<memref::StoreOp>(loc, op.elements().front(), buffer);
-      rewriter.replaceOp(op, {buffer});
-      return success();
-    }
-
-    // Create constants for the range of possible indices [0, max{shape_i}).
-    auto maxDim = *std::max_element(shape.begin(), shape.end());
-    SmallVector<Value, 2> constants;
-    constants.reserve(maxDim);
-    for (int i = 0; i < maxDim; ++i)
-      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
-
-    // Traverse all `elements` and create `memref.store` ops.
-    ImplicitLocOpBuilder b(loc, rewriter);
-    auto elementIt = adaptor.elements().begin();
-    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
-    createStores(/*dim=*/0, buffer, shape, constants, elementIt, indices, b);
-
-    rewriter.replaceOp(op, {buffer});
-    return success();
-  }
-
-private:
-  // Implements backtracking to traverse indices of the output buffer while
-  // iterating over op.elements().
-  void createStores(int dim, Value buffer, ArrayRef<int64_t> shape,
-                    ArrayRef<Value> constants, ValueRange::iterator &elementIt,
-                    SmallVectorImpl<Value> &indices,
-                    ImplicitLocOpBuilder b) const {
-    if (dim == static_cast<int>(shape.size()) - 1) {
-      for (int i = 0; i < shape.back(); ++i) {
-        indices.back() = constants[i];
-        b.create<memref::StoreOp>(*elementIt, buffer, indices);
-        ++elementIt;
-      }
-      return;
-    }
-    for (int i = 0; i < shape[dim]; ++i) {
-      indices[dim] = constants[i];
-      createStores(dim + 1, buffer, shape, constants, elementIt, indices, b);
-    }
-  }
-};
-
-struct BufferizeGenerateOp : public OpConversionPattern<tensor::GenerateOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(tensor::GenerateOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const final {
-    // Allocate memory.
-    Location loc = op.getLoc();
-    RankedTensorType tensorType = op.getType().cast<RankedTensorType>();
-    MemRefType memrefType =
-        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
-    Value result = rewriter.create<memref::AllocOp>(loc, memrefType,
-                                                    adaptor.dynamicExtents());
-
-    // Collect loop bounds.
-    int64_t rank = tensorType.getRank();
-    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-    SmallVector<Value, 4> lowerBounds(rank, zero);
-    SmallVector<Value, 4> steps(rank, one);
-    SmallVector<Value, 4> upperBounds;
-    int nextDynamicIndex = 0;
-    for (int i = 0; i < rank; i++) {
-      Value upperBound = tensorType.isDynamicDim(i)
-                             ? adaptor.dynamicExtents()[nextDynamicIndex++]
-                             : rewriter.create<arith::ConstantIndexOp>(
-                                   loc, memrefType.getDimSize(i));
-      upperBounds.push_back(upperBound);
-    }
-
-    // Generate tensor elements with a parallel loop that stores into
-    // each element of the resulting memref.
-    //
-    // This is a bit tricky. We cannot simply clone the ops because when an op
-    // is cloned, it must be legalized. However, we want to allow arbitrary ops
-    // in the body that we don't necessarily have legalization patterns for as
-    // part of this dialect conversion invocation.
-    //
-    // To accomplish this, we use mergeBlockBefore to "move" this op's body
-    // into the scf.parallel's body.
-    auto parallel =
-        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
-    Block *parallelBody = parallel.getBody();
-    rewriter.mergeBlockBefore(op.getBody(), parallelBody->getTerminator(),
-                              parallelBody->getArguments());
-    // Replace the inlined yield op with a store op. The scf.parallel's builder
-    // already populated an scf.yield at the end, so we don't need to worry
-    // about creating that.
-    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
-    rewriter.setInsertionPointAfter(elementYield);
-    rewriter.replaceOpWithNewOp<memref::StoreOp>(
-        elementYield, elementYield->getOperands()[0], result,
-        parallelBody->getArguments());
-
-    rewriter.replaceOp(op, {result});
-    return success();
-  }
-};
-
-struct BufferizeRankOp : public OpConversionPattern<tensor::RankOp> {
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(tensor::RankOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<memref::RankOp>(op, op.getType(),
-                                                adaptor.tensor());
-    return success();
-  }
-};
-
 struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
   void runOnOperation() override {
-    auto *context = &getContext();
-    bufferization::BufferizeTypeConverter typeConverter;
+    std::unique_ptr<BufferizationOptions> options =
+        getPartialBufferizationOptions();
+    options->addToDialectFilter<tensor::TensorDialect>();
 
-    ConversionTarget target(*context);
-    target.addLegalDialect<scf::SCFDialect, memref::MemRefDialect>();
-    target.addDynamicallyLegalDialect<arith::ArithmeticDialect,
-                                      StandardOpsDialect>(
-        [&](Operation *op) { return typeConverter.isLegal(op); });
-    target.addLegalOp<CallOp, ReturnOp>();
-    target.addIllegalOp<tensor::CastOp, tensor::ExtractOp,
-                        tensor::FromElementsOp, tensor::GenerateOp>();
-    bufferization::populateBufferizeMaterializationLegality(target);
-
-    RewritePatternSet patterns(context);
-    populateTensorBufferizePatterns(typeConverter, patterns);
-    if (failed(applyPartialConversion(getOperation(), target,
-                                      std::move(patterns))))
+    if (failed(bufferizeOp(getOperation(), *options)))
       signalPassFailure();
   }
-};
 
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect,
+                    tensor::TensorDialect, scf::SCFDialect,
+                    arith::ArithmeticDialect>();
+    tensor::registerBufferizableOpInterfaceExternalModels(registry);
+  }
+};
 } // namespace
 
-void mlir::populateTensorBufferizePatterns(
-    bufferization::BufferizeTypeConverter &typeConverter,
-    RewritePatternSet &patterns) {
-  patterns.add<BufferizeCastOp, BufferizeDimOp, BufferizeExtractOp,
-               BufferizeFromElementsOp, BufferizeGenerateOp, BufferizeRankOp>(
-      typeConverter, patterns.getContext());
-}
-
 std::unique_ptr<Pass> mlir::createTensorBufferizePass() {
   return std::make_unique<TensorBufferizePass>();
 }

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index 42734ae1e9ad5..c8dcf748df11d 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -1355,55 +1355,3 @@ func @write_after_select_read_one(
   // CHECK: return %[[f]], %[[select]]
   return %f, %w : f32, tensor<?xf32>
 }
-
-// -----
-
-// CHECK-LABEL: func @tensor_rank(
-//  CHECK-SAME:     %[[arg0:.*]]: memref<*xf32>
-func @tensor_rank(%arg0: tensor<*xf32>) -> index {
-  // CHECK: %[[r:.*]] = memref.rank %[[arg0]]
-  %0 = tensor.rank %arg0 : tensor<*xf32>
-  // CHECK: return %[[r]] : index
-  return %0 : index
-}
-
-// -----
-
-// CHECK-LABEL: func @tensor_generate_static_and_dynamic(
-//  CHECK-SAME:     %[[arg0:.*]]: index
-func @tensor_generate_static_and_dynamic(%arg0: index) -> tensor<16x?xindex> {
-  // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
-  // CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index
-  // CHECK: %[[alloc:.*]] = memref.alloc(%[[arg0]]) {{.*}} : memref<16x?xindex>
-  // CHECK: scf.parallel (%[[arg1:.*]], %[[arg2:.*]]) = (%[[c0]], %[[c0]]) to (%[[c16]], %[[arg0]]) {{.*}} {
-  %result = tensor.generate %arg0 {
-  ^bb0(%i: index, %j: index):
-    %sum = arith.addi %i, %j : index
-    // CHECK: memref.store {{.*}}, %[[alloc]][%[[arg1]], %[[arg2]]]
-    // CHECK: scf.yield
-    tensor.yield %sum : index
-  } : tensor<16x?xindex>
-  // CHECK: }
-  return %result : tensor<16x?xindex>
-}
-
-// -----
-
-// CHECK-LABEL: func @tensor_from_elements_2d(
-//  CHECK-SAME:     %[[ELEM0:.*]]: index, %[[ELEM1:.*]]: index
-func @tensor_from_elements_2d(%arg0: index, %arg1: index) -> tensor<3x2xindex> {
-  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-  // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-  // CHECK-DAG: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<3x2xindex>
-  //     CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C0]], %[[C0]]]
-  //     CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C0]], %[[C1]]]
-  //     CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C1]], %[[C0]]]
-  //     CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C1]], %[[C1]]]
-  //     CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C2]], %[[C0]]]
-  //     CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C2]], %[[C1]]]
-  %0 = tensor.from_elements %arg0, %arg1, %arg0, %arg1, %arg0, %arg1
-         : tensor<3x2xindex>
-  //     CHECK: return %[[MEMREF]]
-  return %0 : tensor<3x2xindex>
-}

diff  --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index c6dd6b9310d92..b0415ce1464ce 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -1,5 +1,7 @@
 // RUN: mlir-opt %s -tensor-bufferize | FileCheck %s
 
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+
 // CHECK-LABEL:   func @dim(
 // CHECK-SAME:              %[[TENSOR:.*]]: tensor<f32>,
 // CHECK-SAME:              %[[INDEX:.*]]: index) -> index {
@@ -66,8 +68,7 @@ func @tensor.extract(%arg0: tensor<?xf32>, %arg1: index) -> f32 {
 }
 
 // CHECK-LABEL:   func @tensor.from_elements_no_elements() -> tensor<0xindex> {
-// CHECK:           %[[MEMREF:.*]] = memref.alloc() : memref<0xindex>
-// CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
+// CHECK:           %[[RET:.*]] = arith.constant dense<> : tensor<0xindex>
 // CHECK:           return %[[RET]] : tensor<0xindex>
 func @tensor.from_elements_no_elements() -> tensor<0xindex> {
   %0 = tensor.from_elements : tensor<0xindex>
@@ -76,7 +77,7 @@ func @tensor.from_elements_no_elements() -> tensor<0xindex> {
 
 // CHECK-LABEL:   func @tensor.from_elements_0d(
 // CHECK-SAME:        %[[ELEM0:.*]]: index) -> tensor<index> {
-// CHECK:           %[[MEMREF:.*]] = memref.alloc() : memref<index>
+// CHECK:           %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<index>
 // CHECK:           store %[[ELEM0]], %[[MEMREF]]
 // CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
 // CHECK:           return %[[RET]] : tensor<index>
@@ -88,9 +89,9 @@ func @tensor.from_elements_0d(%arg0: index) -> tensor<index> {
 // CHECK-LABEL:   func @tensor.from_elements_1d(
 // CHECK-SAME:                               %[[ELEM0:.*]]: index,
 // CHECK-SAME:                               %[[ELEM1:.*]]: index) -> tensor<2xindex> {
-// CHECK:           %[[MEMREF:.*]] = memref.alloc() : memref<2xindex>
 // CHECK:           %[[C0:.*]] = arith.constant 0 : index
 // CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<2xindex>
 // CHECK:           store %[[ELEM0]], %[[MEMREF]][%[[C0]]]
 // CHECK:           store %[[ELEM1]], %[[MEMREF]][%[[C1]]]
 // CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
@@ -103,10 +104,10 @@ func @tensor.from_elements_1d(%arg0: index, %arg1: index) -> tensor<2xindex> {
 // CHECK-LABEL: func @tensor.from_elements_2d(
 // CHECK-SAME:      %[[ELEM0:.*]]: index, %[[ELEM1:.*]]: index)
 // CHECK-SAME:      -> tensor<3x2xindex> {
-// CHECK:         %[[MEMREF:.*]] = memref.alloc() : memref<3x2xindex>
-// CHECK:         %[[C0:.*]] = arith.constant 0 : index
-// CHECK:         %[[C1:.*]] = arith.constant 1 : index
-// CHECK:         %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C2:.*]] = arith.constant 2 : index
+// CHECK:         %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<3x2xindex>
 // CHECK:         store %[[ELEM0]], %[[MEMREF]][%[[C0]], %[[C0]]]
 // CHECK:         store %[[ELEM1]], %[[MEMREF]][%[[C0]], %[[C1]]]
 // CHECK:         store %[[ELEM0]], %[[MEMREF]][%[[C1]], %[[C0]]]
@@ -121,9 +122,9 @@ func @tensor.from_elements_2d(%arg0: index, %arg1: index) -> tensor<3x2xindex> {
   return %0 : tensor<3x2xindex>
 }
 
-// CHECK-LABEL: func @tensor.from_elements_3d()
+// CHECK-LABEL: func @tensor.from_elements_3d(
+//  CHECK-SAME:     %[[F0:.*]]: f32
 
-// CHECK-DAG: %[[F0:.*]] = arith.constant 0.0
 // CHECK-DAG: %[[F1:.*]] = arith.constant 1.0{{0+}}e+00
 // CHECK-DAG: %[[F2:.*]] = arith.constant 2.0
 // CHECK-DAG: %[[F3:.*]] = arith.constant 3.0
@@ -136,11 +137,11 @@ func @tensor.from_elements_2d(%arg0: index, %arg1: index) -> tensor<3x2xindex> {
 // CHECK-DAG: %[[F10:.*]] = arith.constant 1.0{{0+}}e+01
 // CHECK-DAG: %[[F11:.*]] = arith.constant 1.1{{0+}}e+01
 
-// CHECK: %[[MEMREF:.*]] = memref.alloc() : memref<3x2x2xf32>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
 
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<3x2x2xf32>
 
 // CHECK: store %[[F0]], %[[MEMREF]][%[[C0]], %[[C0]], %[[C0]]]
 // CHECK: store %[[F1]], %[[MEMREF]][%[[C0]], %[[C0]], %[[C1]]]
@@ -157,8 +158,7 @@ func @tensor.from_elements_2d(%arg0: index, %arg1: index) -> tensor<3x2xindex> {
 
 // CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
 // CHECK: return %[[RET]] : tensor<3x2x2xf32>
-func @tensor.from_elements_3d() -> tensor<3x2x2xf32> {
-  %f0 = arith.constant 0.0 : f32
+func @tensor.from_elements_3d(%f0 : f32) -> tensor<3x2x2xf32> {
   %f1 = arith.constant 1.0 : f32
   %f2 = arith.constant 2.0 : f32
   %f3 = arith.constant 3.0 : f32
@@ -179,9 +179,9 @@ func @tensor.from_elements_3d() -> tensor<3x2x2xf32> {
 // CHECK-SAME:                                       %[[ARG:.*]]: tensor<*xf32>,
 // CHECK-SAME:                                       %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<?xindex> {
 // CHECK:           %[[CASTED:.*]] = bufferization.to_memref %[[ARG]] : memref<*xf32>
-// CHECK:           %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) : memref<?xindex>
-// CHECK:           %[[C0:.*]] = arith.constant 0 : index
-// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) {{.*}} : memref<?xindex>
 // CHECK:           scf.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[DYNAMIC_EXTENT]]) step (%[[C1]]) {
 // CHECK:             %[[ELEM:.*]] = memref.dim %[[CASTED]], %[[I]] : memref<*xf32>
 // CHECK:             store %[[ELEM]], %[[MEMREF]][%[[I]]] : memref<?xindex>
@@ -203,11 +203,11 @@ func @tensor.generate(%arg: tensor<*xf32>, %dynamic_extent: index) -> tensor<?xi
 // extents.
 //
 // CHECK-LABEL:   func @tensor.generate_static_and_dynamic(
-// CHECK-SAME:                                                          %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<16x?xindex> {
-// CHECK:           %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) : memref<16x?xindex>
-// CHECK:           %[[C0:.*]] = arith.constant 0 : index
-// CHECK:           %[[C1:.*]] = arith.constant 1 : index
-// CHECK:           %[[C16:.*]] = arith.constant 16 : index
+// CHECK-SAME:        %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<16x?xindex> {
+// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[C16:.*]] = arith.constant 16 : index
+// CHECK:           %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) {{.*}} : memref<16x?xindex>
 // CHECK:           scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) to (%[[C16]], %[[DYNAMIC_EXTENT]]) step (%[[C1]], %[[C1]]) {
 // CHECK:             %[[VAL_7:.*]] = arith.addi %[[I]], %[[J]] : index
 // CHECK:             store %[[VAL_7]], %[[MEMREF]][%[[I]], %[[J]]] : memref<16x?xindex>
@@ -225,12 +225,6 @@ func @tensor.generate_static_and_dynamic(%arg0: index) -> tensor<16x?xindex> {
   return %result : tensor<16x?xindex>
 }
 
-// The tensor.generate op needs to put its body into the
-// resulting scf.parallel. To handle unknown ops in the body, it cannot clone
-// the body because that would require the cloned ops to be legalized
-// immediately, which is usually not possible since they might be from various
-// other dialects.
-//
 // CHECK-LABEL: func @tensor.generate_unknown_ops_in_body
 func @tensor.generate_unknown_ops_in_body(%arg0: index) -> tensor<?xindex> {
   // CHECK-NOT: tensor.generate
@@ -242,3 +236,68 @@ func @tensor.generate_unknown_ops_in_body(%arg0: index) -> tensor<?xindex> {
   } : tensor<?xindex>
   return %tensor : tensor<?xindex>
 }
+
+// CHECK-LABEL: func @tensor.extract_slice(
+//  CHECK-SAME:     %[[t1:.*]]: tensor<?x?xf32>, %[[idx1:.*]]: index, %[[idx2:.*]]: index
+func @tensor.extract_slice(
+    %t1: tensor<?x?xf32>, %idx1: index, %idx2: index) -> tensor<?x10xf32> {
+  // CHECK: %[[m:.*]] = bufferization.to_memref %[[t1]] : memref<?x?xf32>
+  // CHECK: %[[r:.*]] = memref.subview %[[m]][5, %[[idx2]]] [%[[idx1]], 10] [1, 1] : memref<?x?xf32> to memref<?x10xf32, #[[$MAP]]>
+  %0 = tensor.extract_slice %t1[5, %idx2][%idx1, 10][1, 1]
+      : tensor<?x?xf32> to tensor<?x10xf32>
+  // CHECK: %[[r_tensor:.*]] = bufferization.to_tensor %[[r]]
+  // CHECK: return %[[r_tensor]]
+  return %0 : tensor<?x10xf32>
+}
+
+// CHECK-LABEL: func @tensor.extract_slice_rank_reducing(
+//  CHECK-SAME:     %[[t1:.*]]: tensor<?x10x?xf32>, %[[idx1:.*]]: index,
+//  CHECK-SAME:     %[[idx2:.*]]: index
+func @tensor.extract_slice_rank_reducing(
+    %t1: tensor<?x10x?xf32>, %idx1: index, %idx2: index) -> tensor<?x15xf32> {
+  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x10x?xf32>
+  // CHECK: %[[r:.*]] = memref.subview %[[m1]][5, %[[idx1]], 10] [%[[idx2]], 1, 15] [1, 1, 1] : memref<?x10x?xf32> to memref<?x15xf32, #[[$MAP]]>
+  %0 = tensor.extract_slice %t1[5, %idx1, 10][%idx2, 1, 15][1, 1, 1]
+      : tensor<?x10x?xf32> to tensor<?x15xf32>
+  // CHECK: %[[r_tensor:.*]] = bufferization.to_tensor %[[r]]
+  // CHECK: return %[[r_tensor]]
+  return %0 : tensor<?x15xf32>
+}
+
+// CHECK-LABEL: func @tensor.insert_slice(
+//  CHECK-SAME:     %[[t1:.*]]: tensor<?x?xf32>, %[[t2:.*]]: tensor<?x10xf32>,
+//  CHECK-SAME:     %[[idx1:.*]]: index, %[[idx2:.*]]: index
+func @tensor.insert_slice(%t1: tensor<?x?xf32>, %t2: tensor<?x10xf32>,
+                          %idx1: index, %idx2: index) -> tensor<?x?xf32> {
+  // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+  // CHECK-DAG: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x?xf32>
+  // CHECK-DAG: %[[m2:.*]] = bufferization.to_memref %[[t2]] : memref<?x10xf32>
+  //     CHECK: %[[dim0:.*]] = memref.dim %[[m1]], %[[c0]]
+  //     CHECK: %[[dim1:.*]] = memref.dim %[[m1]], %[[c1]]
+  //     CHECK: %[[alloc:.*]] = memref.alloc(%[[dim0]], %[[dim1]])
+  //     CHECK: memref.copy %[[m1]], %[[alloc]]
+  //     CHECK: %[[subview:.*]] = memref.subview %[[alloc]][%[[idx1]], 5] [%[[idx2]], 10] [1, 1]
+  //     CHECK: memref.copy %[[m2]], %[[subview]]
+  %0 = tensor.insert_slice %t2 into %t1[%idx1, 5][%idx2, 10][1, 1]
+      : tensor<?x10xf32> into tensor<?x?xf32>
+
+  //     CHECK: %[[r:.*]] = bufferization.to_tensor %[[alloc]]
+  //     CHECK: return %[[r]]
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func @tensor.insert(
+//  CHECK-SAME:     %[[t1:.*]]: tensor<5xf32>, %[[idx1:.*]]: index,
+//  CHECK-SAME:     %[[f:.*]]: f32
+func @tensor.insert(%t1: tensor<5xf32>, %idx1: index, %f: f32) -> tensor<5xf32> {
+  // CHECK: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<5xf32>
+  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<5xf32>
+  // CHECK: memref.copy %[[m1]], %[[alloc]]
+  // CHECK: memref.store %[[f]], %[[alloc]][%[[idx1]]]
+  %0 = tensor.insert %f into %t1[%idx1] : tensor<5xf32>
+
+  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[alloc]]
+  // CHECK: return %[[r]]
+  return %0 : tensor<5xf32>
+}


        


More information about the Mlir-commits mailing list