[Mlir-commits] [mlir] daf1810 - [mlir][tensor] Replace tensor-bufferize with BufferizableOpInterface impl
Matthias Springer
llvmlistbot at llvm.org
Thu Jan 27 02:30:58 PST 2022
Author: Matthias Springer
Date: 2022-01-27T19:30:45+09:00
New Revision: daf18108ecc959ce20c25af33618c934b470f350
URL: https://github.com/llvm/llvm-project/commit/daf18108ecc959ce20c25af33618c934b470f350
DIFF: https://github.com/llvm/llvm-project/commit/daf18108ecc959ce20c25af33618c934b470f350.diff
LOG: [mlir][tensor] Replace tensor-bufferize with BufferizableOpInterface impl
This commit switches the `tensor-bufferize` pass over to BufferizableOpInterface-based bufferization.
Differential Revision: https://reviews.llvm.org/D118246
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h
mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td
mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
mlir/test/Dialect/Tensor/bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 5107710413ead..534f664483be2 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -255,7 +255,7 @@ class BufferizationState {
const BufferizationOptions &getOptions() const { return options; }
protected:
- BufferizationState(const BufferizationOptions &options);
+ explicit BufferizationState(const BufferizationOptions &options);
// BufferizationState should be passed as a reference.
BufferizationState(const BufferizationState &) = delete;
@@ -270,6 +270,24 @@ class BufferizationState {
const BufferizationOptions &options;
};
+/// This a "no analysis, always copy" BufferizationState. In the absence of an
+/// analysis, a buffer must be copied each time it is written to. Therefore, all
+/// OpOperands that bufferize to a memory write must bufferize out-of-place.
+class AlwaysCopyBufferizationState : public BufferizationState {
+public:
+ explicit AlwaysCopyBufferizationState(const BufferizationOptions &options);
+
+ AlwaysCopyBufferizationState(const AlwaysCopyBufferizationState &) = delete;
+
+ virtual ~AlwaysCopyBufferizationState() = default;
+
+ /// Return `true` if the given OpResult has been decided to bufferize inplace.
+ bool isInPlace(OpOperand &opOperand) const override;
+
+ /// Return true if `v1` and `v2` bufferize to equivalent buffers.
+ bool areEquivalentBufferizedValues(Value v1, Value v2) const override;
+};
+
/// Replace an op with replacement values. The op is deleted. Tensor OpResults
/// must be replaced with memref values.
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op,
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
index b587955d65b13..a56287995aa96 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
@@ -69,6 +69,21 @@ void populateEliminateBufferizeMaterializationsPatterns(
// TODO: Extract `options` from `state` and pass as separate argument.
LogicalResult bufferizeOp(Operation *op, const BufferizationState &state);
+/// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`.
+/// Buffers are duplicated and copied before any tensor use that bufferizes to
+/// a memory write.
+///
+/// Note: This function bufferizes ops without utilizing analysis results. It
+/// can be used to implement partial bufferization passes.
+LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options);
+
+/// Populate the pattern set with a pattern that bufferizes ops that implement
+/// `BufferizableOpInterface`.
+void populateBufferizationPattern(const BufferizationState &state,
+ RewritePatternSet &patterns);
+
+std::unique_ptr<BufferizationOptions> getPartialBufferizationOptions();
+
} // namespace bufferization
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h
index f90d02bda22d3..a346577e2f569 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h
@@ -12,16 +12,6 @@
#include "mlir/Pass/Pass.h"
namespace mlir {
-namespace bufferization {
-class BufferizeTypeConverter;
-} // namespace bufferization
-
-class RewritePatternSet;
-
-void populateTensorBufferizePatterns(
- bufferization::BufferizeTypeConverter &typeConverter,
- RewritePatternSet &patterns);
-
/// Creates an instance of `tensor` dialect bufferization pass.
std::unique_ptr<Pass> createTensorBufferizePass();
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td
index 57c754c84c321..77743134325f5 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td
@@ -14,11 +14,6 @@ include "mlir/Pass/PassBase.td"
def TensorBufferize : Pass<"tensor-bufferize", "FuncOp"> {
let summary = "Bufferize the `tensor` dialect";
let constructor = "mlir::createTensorBufferizePass()";
- let dependentDialects = [
- "bufferization::BufferizationDialect",
- "memref::MemRefDialect",
- "scf::SCFDialect"
- ];
}
#endif // MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 276950d4ef19e..3af1c37594f96 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -318,6 +318,25 @@ void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
rewriter.eraseOp(op);
}
+AlwaysCopyBufferizationState::AlwaysCopyBufferizationState(
+ const BufferizationOptions &options)
+ : BufferizationState(options) {}
+
+/// Return `true` if the given OpResult has been decided to bufferize inplace.
+bool AlwaysCopyBufferizationState::isInPlace(OpOperand &opOperand) const {
+ // OpOperands that bufferize to a memory write are out-of-place, i.e., an
+ // alloc and copy is inserted.
+ return !bufferizesToMemoryWrite(opOperand);
+}
+
+/// Return true if `v1` and `v2` bufferize to equivalent buffers.
+bool AlwaysCopyBufferizationState::areEquivalentBufferizedValues(
+ Value v1, Value v2) const {
+ // There is no analysis, so we do not know if the values are equivalent. The
+ // conservative answer is "false".
+ return false;
+}
+
//===----------------------------------------------------------------------===//
// Bufferization-specific scoped alloc/dealloc insertion support.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index f202a7a5bbd9c..d31c3532509e7 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -207,9 +207,59 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
const BufferizationState &state) {
// Bufferize the op and its nested ops.
RewritePatternSet patterns(op->getContext());
- patterns.add<BufferizationPattern>(op->getContext(), state);
+ populateBufferizationPattern(state, patterns);
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
return failure();
return checkBufferizationResult(op, state.getOptions());
}
+
+namespace {
+/// This a "no analysis, always copy" BufferizationState. In the absence of an
+/// analysis, a buffer must be copied each time it is written to. Therefore, all
+/// OpOperands that bufferize to a memory write must bufferize out-of-place.
+class AlwaysCopyBufferizationState : public BufferizationState {
+public:
+ AlwaysCopyBufferizationState(const BufferizationOptions &options)
+ : BufferizationState(options) {}
+
+ AlwaysCopyBufferizationState(const AlwaysCopyBufferizationState &) = delete;
+
+ virtual ~AlwaysCopyBufferizationState() = default;
+
+ /// Return `true` if the given OpResult has been decided to bufferize inplace.
+ bool isInPlace(OpOperand &opOperand) const override {
+ // OpOperands that bufferize to a memory write are out-of-place, i.e., an
+ // alloc and copy is inserted.
+ return !bufferizesToMemoryWrite(opOperand);
+ }
+
+ /// Return true if `v1` and `v2` bufferize to equivalent buffers.
+ bool areEquivalentBufferizedValues(Value v1, Value v2) const override {
+ // There is no analysis, so we do not know if the values are equivalent. The
+ // conservative answer is "false".
+ return false;
+ }
+};
+} // namespace
+
+LogicalResult bufferization::bufferizeOp(Operation *op,
+ const BufferizationOptions &options) {
+ AlwaysCopyBufferizationState state(options);
+ return bufferizeOp(op, state);
+}
+
+void bufferization::populateBufferizationPattern(
+ const BufferizationState &state, RewritePatternSet &patterns) {
+ patterns.add<BufferizationPattern>(patterns.getContext(), state);
+}
+
+std::unique_ptr<BufferizationOptions>
+bufferization::getPartialBufferizationOptions() {
+ auto options = std::make_unique<BufferizationOptions>();
+ options->allowReturnMemref = true;
+ options->allowUnknownOps = true;
+ options->createDeallocs = false;
+ options->fullyDynamicLayoutMaps = false;
+ return options;
+}
diff --git a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
index 0bd01dabdfef5..1d435c64d9287 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
@@ -13,223 +13,40 @@
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
+using namespace bufferization;
namespace {
-struct BufferizeCastOp : public OpConversionPattern<tensor::CastOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto resultType = getTypeConverter()->convertType(op.getType());
- rewriter.replaceOpWithNewOp<memref::CastOp>(op, resultType,
- adaptor.getOperands()[0]);
- return success();
- }
-};
-
-struct BufferizeDimOp : public OpConversionPattern<tensor::DimOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<memref::DimOp>(op, adaptor.source(),
- adaptor.index());
- return success();
- }
-};
-
-struct BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(tensor::ExtractOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<memref::LoadOp>(op, adaptor.tensor(),
- adaptor.indices());
- return success();
- }
-};
-
-struct BufferizeFromElementsOp
- : public OpConversionPattern<tensor::FromElementsOp> {
-public:
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(tensor::FromElementsOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- Location loc = op.getLoc();
- auto tensorType = op.getType().cast<RankedTensorType>();
- auto shape = tensorType.getShape();
-
- // Allocate a buffer for the result.
- auto resultType =
- MemRefType::get(tensorType.getShape(), tensorType.getElementType());
- Value buffer = rewriter.create<memref::AllocOp>(loc, resultType);
-
- // Case: tensor<0xelem_type>.
- if (op.elements().empty()) {
- rewriter.replaceOp(op, {buffer});
- return success();
- }
-
- // Case: tensor<elem_type>.
- if (shape.empty()) {
- rewriter.create<memref::StoreOp>(loc, op.elements().front(), buffer);
- rewriter.replaceOp(op, {buffer});
- return success();
- }
-
- // Create constants for the range of possible indices [0, max{shape_i}).
- auto maxDim = *std::max_element(shape.begin(), shape.end());
- SmallVector<Value, 2> constants;
- constants.reserve(maxDim);
- for (int i = 0; i < maxDim; ++i)
- constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
-
- // Traverse all `elements` and create `memref.store` ops.
- ImplicitLocOpBuilder b(loc, rewriter);
- auto elementIt = adaptor.elements().begin();
- SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
- createStores(/*dim=*/0, buffer, shape, constants, elementIt, indices, b);
-
- rewriter.replaceOp(op, {buffer});
- return success();
- }
-
-private:
- // Implements backtracking to traverse indices of the output buffer while
- // iterating over op.elements().
- void createStores(int dim, Value buffer, ArrayRef<int64_t> shape,
- ArrayRef<Value> constants, ValueRange::iterator &elementIt,
- SmallVectorImpl<Value> &indices,
- ImplicitLocOpBuilder b) const {
- if (dim == static_cast<int>(shape.size()) - 1) {
- for (int i = 0; i < shape.back(); ++i) {
- indices.back() = constants[i];
- b.create<memref::StoreOp>(*elementIt, buffer, indices);
- ++elementIt;
- }
- return;
- }
- for (int i = 0; i < shape[dim]; ++i) {
- indices[dim] = constants[i];
- createStores(dim + 1, buffer, shape, constants, elementIt, indices, b);
- }
- }
-};
-
-struct BufferizeGenerateOp : public OpConversionPattern<tensor::GenerateOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(tensor::GenerateOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const final {
- // Allocate memory.
- Location loc = op.getLoc();
- RankedTensorType tensorType = op.getType().cast<RankedTensorType>();
- MemRefType memrefType =
- MemRefType::get(tensorType.getShape(), tensorType.getElementType());
- Value result = rewriter.create<memref::AllocOp>(loc, memrefType,
- adaptor.dynamicExtents());
-
- // Collect loop bounds.
- int64_t rank = tensorType.getRank();
- Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- SmallVector<Value, 4> lowerBounds(rank, zero);
- SmallVector<Value, 4> steps(rank, one);
- SmallVector<Value, 4> upperBounds;
- int nextDynamicIndex = 0;
- for (int i = 0; i < rank; i++) {
- Value upperBound = tensorType.isDynamicDim(i)
- ? adaptor.dynamicExtents()[nextDynamicIndex++]
- : rewriter.create<arith::ConstantIndexOp>(
- loc, memrefType.getDimSize(i));
- upperBounds.push_back(upperBound);
- }
-
- // Generate tensor elements with a parallel loop that stores into
- // each element of the resulting memref.
- //
- // This is a bit tricky. We cannot simply clone the ops because when an op
- // is cloned, it must be legalized. However, we want to allow arbitrary ops
- // in the body that we don't necessarily have legalization patterns for as
- // part of this dialect conversion invocation.
- //
- // To accomplish this, we use mergeBlockBefore to "move" this op's body
- // into the scf.parallel's body.
- auto parallel =
- rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
- Block *parallelBody = parallel.getBody();
- rewriter.mergeBlockBefore(op.getBody(), parallelBody->getTerminator(),
- parallelBody->getArguments());
- // Replace the inlined yield op with a store op. The scf.parallel's builder
- // already populated an scf.yield at the end, so we don't need to worry
- // about creating that.
- Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
- rewriter.setInsertionPointAfter(elementYield);
- rewriter.replaceOpWithNewOp<memref::StoreOp>(
- elementYield, elementYield->getOperands()[0], result,
- parallelBody->getArguments());
-
- rewriter.replaceOp(op, {result});
- return success();
- }
-};
-
-struct BufferizeRankOp : public OpConversionPattern<tensor::RankOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(tensor::RankOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<memref::RankOp>(op, op.getType(),
- adaptor.tensor());
- return success();
- }
-};
-
struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
void runOnOperation() override {
- auto *context = &getContext();
- bufferization::BufferizeTypeConverter typeConverter;
+ std::unique_ptr<BufferizationOptions> options =
+ getPartialBufferizationOptions();
+ options->addToDialectFilter<tensor::TensorDialect>();
- ConversionTarget target(*context);
- target.addLegalDialect<scf::SCFDialect, memref::MemRefDialect>();
- target.addDynamicallyLegalDialect<arith::ArithmeticDialect,
- StandardOpsDialect>(
- [&](Operation *op) { return typeConverter.isLegal(op); });
- target.addLegalOp<CallOp, ReturnOp>();
- target.addIllegalOp<tensor::CastOp, tensor::ExtractOp,
- tensor::FromElementsOp, tensor::GenerateOp>();
- bufferization::populateBufferizeMaterializationLegality(target);
-
- RewritePatternSet patterns(context);
- populateTensorBufferizePatterns(typeConverter, patterns);
- if (failed(applyPartialConversion(getOperation(), target,
- std::move(patterns))))
+ if (failed(bufferizeOp(getOperation(), *options)))
signalPassFailure();
}
-};
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect,
+ tensor::TensorDialect, scf::SCFDialect,
+ arith::ArithmeticDialect>();
+ tensor::registerBufferizableOpInterfaceExternalModels(registry);
+ }
+};
} // namespace
-void mlir::populateTensorBufferizePatterns(
- bufferization::BufferizeTypeConverter &typeConverter,
- RewritePatternSet &patterns) {
- patterns.add<BufferizeCastOp, BufferizeDimOp, BufferizeExtractOp,
- BufferizeFromElementsOp, BufferizeGenerateOp, BufferizeRankOp>(
- typeConverter, patterns.getContext());
-}
-
std::unique_ptr<Pass> mlir::createTensorBufferizePass() {
return std::make_unique<TensorBufferizePass>();
}
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index 42734ae1e9ad5..c8dcf748df11d 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -1355,55 +1355,3 @@ func @write_after_select_read_one(
// CHECK: return %[[f]], %[[select]]
return %f, %w : f32, tensor<?xf32>
}
-
-// -----
-
-// CHECK-LABEL: func @tensor_rank(
-// CHECK-SAME: %[[arg0:.*]]: memref<*xf32>
-func @tensor_rank(%arg0: tensor<*xf32>) -> index {
- // CHECK: %[[r:.*]] = memref.rank %[[arg0]]
- %0 = tensor.rank %arg0 : tensor<*xf32>
- // CHECK: return %[[r]] : index
- return %0 : index
-}
-
-// -----
-
-// CHECK-LABEL: func @tensor_generate_static_and_dynamic(
-// CHECK-SAME: %[[arg0:.*]]: index
-func @tensor_generate_static_and_dynamic(%arg0: index) -> tensor<16x?xindex> {
- // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
- // CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index
- // CHECK: %[[alloc:.*]] = memref.alloc(%[[arg0]]) {{.*}} : memref<16x?xindex>
- // CHECK: scf.parallel (%[[arg1:.*]], %[[arg2:.*]]) = (%[[c0]], %[[c0]]) to (%[[c16]], %[[arg0]]) {{.*}} {
- %result = tensor.generate %arg0 {
- ^bb0(%i: index, %j: index):
- %sum = arith.addi %i, %j : index
- // CHECK: memref.store {{.*}}, %[[alloc]][%[[arg1]], %[[arg2]]]
- // CHECK: scf.yield
- tensor.yield %sum : index
- } : tensor<16x?xindex>
- // CHECK: }
- return %result : tensor<16x?xindex>
-}
-
-// -----
-
-// CHECK-LABEL: func @tensor_from_elements_2d(
-// CHECK-SAME: %[[ELEM0:.*]]: index, %[[ELEM1:.*]]: index
-func @tensor_from_elements_2d(%arg0: index, %arg1: index) -> tensor<3x2xindex> {
- // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
- // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
- // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
- // CHECK-DAG: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<3x2xindex>
- // CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C0]], %[[C0]]]
- // CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C0]], %[[C1]]]
- // CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C1]], %[[C0]]]
- // CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C1]], %[[C1]]]
- // CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C2]], %[[C0]]]
- // CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C2]], %[[C1]]]
- %0 = tensor.from_elements %arg0, %arg1, %arg0, %arg1, %arg0, %arg1
- : tensor<3x2xindex>
- // CHECK: return %[[MEMREF]]
- return %0 : tensor<3x2xindex>
-}
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index c6dd6b9310d92..b0415ce1464ce 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -1,5 +1,7 @@
// RUN: mlir-opt %s -tensor-bufferize | FileCheck %s
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+
// CHECK-LABEL: func @dim(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<f32>,
// CHECK-SAME: %[[INDEX:.*]]: index) -> index {
@@ -66,8 +68,7 @@ func @tensor.extract(%arg0: tensor<?xf32>, %arg1: index) -> f32 {
}
// CHECK-LABEL: func @tensor.from_elements_no_elements() -> tensor<0xindex> {
-// CHECK: %[[MEMREF:.*]] = memref.alloc() : memref<0xindex>
-// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
+// CHECK: %[[RET:.*]] = arith.constant dense<> : tensor<0xindex>
// CHECK: return %[[RET]] : tensor<0xindex>
func @tensor.from_elements_no_elements() -> tensor<0xindex> {
%0 = tensor.from_elements : tensor<0xindex>
@@ -76,7 +77,7 @@ func @tensor.from_elements_no_elements() -> tensor<0xindex> {
// CHECK-LABEL: func @tensor.from_elements_0d(
// CHECK-SAME: %[[ELEM0:.*]]: index) -> tensor<index> {
-// CHECK: %[[MEMREF:.*]] = memref.alloc() : memref<index>
+// CHECK: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<index>
// CHECK: store %[[ELEM0]], %[[MEMREF]]
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
// CHECK: return %[[RET]] : tensor<index>
@@ -88,9 +89,9 @@ func @tensor.from_elements_0d(%arg0: index) -> tensor<index> {
// CHECK-LABEL: func @tensor.from_elements_1d(
// CHECK-SAME: %[[ELEM0:.*]]: index,
// CHECK-SAME: %[[ELEM1:.*]]: index) -> tensor<2xindex> {
-// CHECK: %[[MEMREF:.*]] = memref.alloc() : memref<2xindex>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<2xindex>
// CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C0]]]
// CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C1]]]
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
@@ -103,10 +104,10 @@ func @tensor.from_elements_1d(%arg0: index, %arg1: index) -> tensor<2xindex> {
// CHECK-LABEL: func @tensor.from_elements_2d(
// CHECK-SAME: %[[ELEM0:.*]]: index, %[[ELEM1:.*]]: index)
// CHECK-SAME: -> tensor<3x2xindex> {
-// CHECK: %[[MEMREF:.*]] = memref.alloc() : memref<3x2xindex>
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<3x2xindex>
// CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C0]], %[[C0]]]
// CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C0]], %[[C1]]]
// CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C1]], %[[C0]]]
@@ -121,9 +122,9 @@ func @tensor.from_elements_2d(%arg0: index, %arg1: index) -> tensor<3x2xindex> {
return %0 : tensor<3x2xindex>
}
-// CHECK-LABEL: func @tensor.from_elements_3d()
+// CHECK-LABEL: func @tensor.from_elements_3d(
+// CHECK-SAME: %[[F0:.*]]: f32
-// CHECK-DAG: %[[F0:.*]] = arith.constant 0.0
// CHECK-DAG: %[[F1:.*]] = arith.constant 1.0{{0+}}e+00
// CHECK-DAG: %[[F2:.*]] = arith.constant 2.0
// CHECK-DAG: %[[F3:.*]] = arith.constant 3.0
@@ -136,11 +137,11 @@ func @tensor.from_elements_2d(%arg0: index, %arg1: index) -> tensor<3x2xindex> {
// CHECK-DAG: %[[F10:.*]] = arith.constant 1.0{{0+}}e+01
// CHECK-DAG: %[[F11:.*]] = arith.constant 1.1{{0+}}e+01
-// CHECK: %[[MEMREF:.*]] = memref.alloc() : memref<3x2x2xf32>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<3x2x2xf32>
// CHECK: store %[[F0]], %[[MEMREF]][%[[C0]], %[[C0]], %[[C0]]]
// CHECK: store %[[F1]], %[[MEMREF]][%[[C0]], %[[C0]], %[[C1]]]
@@ -157,8 +158,7 @@ func @tensor.from_elements_2d(%arg0: index, %arg1: index) -> tensor<3x2xindex> {
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
// CHECK: return %[[RET]] : tensor<3x2x2xf32>
-func @tensor.from_elements_3d() -> tensor<3x2x2xf32> {
- %f0 = arith.constant 0.0 : f32
+func @tensor.from_elements_3d(%f0 : f32) -> tensor<3x2x2xf32> {
%f1 = arith.constant 1.0 : f32
%f2 = arith.constant 2.0 : f32
%f3 = arith.constant 3.0 : f32
@@ -179,9 +179,9 @@ func @tensor.from_elements_3d() -> tensor<3x2x2xf32> {
// CHECK-SAME: %[[ARG:.*]]: tensor<*xf32>,
// CHECK-SAME: %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<?xindex> {
// CHECK: %[[CASTED:.*]] = bufferization.to_memref %[[ARG]] : memref<*xf32>
-// CHECK: %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) : memref<?xindex>
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) {{.*}} : memref<?xindex>
// CHECK: scf.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[DYNAMIC_EXTENT]]) step (%[[C1]]) {
// CHECK: %[[ELEM:.*]] = memref.dim %[[CASTED]], %[[I]] : memref<*xf32>
// CHECK: store %[[ELEM]], %[[MEMREF]][%[[I]]] : memref<?xindex>
@@ -203,11 +203,11 @@ func @tensor.generate(%arg: tensor<*xf32>, %dynamic_extent: index) -> tensor<?xi
// extents.
//
// CHECK-LABEL: func @tensor.generate_static_and_dynamic(
-// CHECK-SAME: %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<16x?xindex> {
-// CHECK: %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) : memref<16x?xindex>
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[C16:.*]] = arith.constant 16 : index
+// CHECK-SAME: %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<16x?xindex> {
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
+// CHECK: %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) {{.*}} : memref<16x?xindex>
// CHECK: scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) to (%[[C16]], %[[DYNAMIC_EXTENT]]) step (%[[C1]], %[[C1]]) {
// CHECK: %[[VAL_7:.*]] = arith.addi %[[I]], %[[J]] : index
// CHECK: store %[[VAL_7]], %[[MEMREF]][%[[I]], %[[J]]] : memref<16x?xindex>
@@ -225,12 +225,6 @@ func @tensor.generate_static_and_dynamic(%arg0: index) -> tensor<16x?xindex> {
return %result : tensor<16x?xindex>
}
-// The tensor.generate op needs to put its body into the
-// resulting scf.parallel. To handle unknown ops in the body, it cannot clone
-// the body because that would require the cloned ops to be legalized
-// immediately, which is usually not possible since they might be from various
-// other dialects.
-//
// CHECK-LABEL: func @tensor.generate_unknown_ops_in_body
func @tensor.generate_unknown_ops_in_body(%arg0: index) -> tensor<?xindex> {
// CHECK-NOT: tensor.generate
@@ -242,3 +236,68 @@ func @tensor.generate_unknown_ops_in_body(%arg0: index) -> tensor<?xindex> {
} : tensor<?xindex>
return %tensor : tensor<?xindex>
}
+
+// CHECK-LABEL: func @tensor.extract_slice(
+// CHECK-SAME: %[[t1:.*]]: tensor<?x?xf32>, %[[idx1:.*]]: index, %[[idx2:.*]]: index
+func @tensor.extract_slice(
+ %t1: tensor<?x?xf32>, %idx1: index, %idx2: index) -> tensor<?x10xf32> {
+ // CHECK: %[[m:.*]] = bufferization.to_memref %[[t1]] : memref<?x?xf32>
+ // CHECK: %[[r:.*]] = memref.subview %[[m]][5, %[[idx2]]] [%[[idx1]], 10] [1, 1] : memref<?x?xf32> to memref<?x10xf32, #[[$MAP]]>
+ %0 = tensor.extract_slice %t1[5, %idx2][%idx1, 10][1, 1]
+ : tensor<?x?xf32> to tensor<?x10xf32>
+ // CHECK: %[[r_tensor:.*]] = bufferization.to_tensor %[[r]]
+ // CHECK: return %[[r_tensor]]
+ return %0 : tensor<?x10xf32>
+}
+
+// CHECK-LABEL: func @tensor.extract_slice_rank_reducing(
+// CHECK-SAME: %[[t1:.*]]: tensor<?x10x?xf32>, %[[idx1:.*]]: index,
+// CHECK-SAME: %[[idx2:.*]]: index
+func @tensor.extract_slice_rank_reducing(
+ %t1: tensor<?x10x?xf32>, %idx1: index, %idx2: index) -> tensor<?x15xf32> {
+ // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x10x?xf32>
+ // CHECK: %[[r:.*]] = memref.subview %[[m1]][5, %[[idx1]], 10] [%[[idx2]], 1, 15] [1, 1, 1] : memref<?x10x?xf32> to memref<?x15xf32, #[[$MAP]]>
+ %0 = tensor.extract_slice %t1[5, %idx1, 10][%idx2, 1, 15][1, 1, 1]
+ : tensor<?x10x?xf32> to tensor<?x15xf32>
+ // CHECK: %[[r_tensor:.*]] = bufferization.to_tensor %[[r]]
+ // CHECK: return %[[r_tensor]]
+ return %0 : tensor<?x15xf32>
+}
+
+// CHECK-LABEL: func @tensor.insert_slice(
+// CHECK-SAME: %[[t1:.*]]: tensor<?x?xf32>, %[[t2:.*]]: tensor<?x10xf32>,
+// CHECK-SAME: %[[idx1:.*]]: index, %[[idx2:.*]]: index
+func @tensor.insert_slice(%t1: tensor<?x?xf32>, %t2: tensor<?x10xf32>,
+ %idx1: index, %idx2: index) -> tensor<?x?xf32> {
+ // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+ // CHECK-DAG: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x?xf32>
+ // CHECK-DAG: %[[m2:.*]] = bufferization.to_memref %[[t2]] : memref<?x10xf32>
+ // CHECK: %[[dim0:.*]] = memref.dim %[[m1]], %[[c0]]
+ // CHECK: %[[dim1:.*]] = memref.dim %[[m1]], %[[c1]]
+ // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim0]], %[[dim1]])
+ // CHECK: memref.copy %[[m1]], %[[alloc]]
+ // CHECK: %[[subview:.*]] = memref.subview %[[alloc]][%[[idx1]], 5] [%[[idx2]], 10] [1, 1]
+ // CHECK: memref.copy %[[m2]], %[[subview]]
+ %0 = tensor.insert_slice %t2 into %t1[%idx1, 5][%idx2, 10][1, 1]
+ : tensor<?x10xf32> into tensor<?x?xf32>
+
+ // CHECK: %[[r:.*]] = bufferization.to_tensor %[[alloc]]
+ // CHECK: return %[[r]]
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func @tensor.insert(
+// CHECK-SAME: %[[t1:.*]]: tensor<5xf32>, %[[idx1:.*]]: index,
+// CHECK-SAME: %[[f:.*]]: f32
+func @tensor.insert(%t1: tensor<5xf32>, %idx1: index, %f: f32) -> tensor<5xf32> {
+ // CHECK: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<5xf32>
+ // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<5xf32>
+ // CHECK: memref.copy %[[m1]], %[[alloc]]
+ // CHECK: memref.store %[[f]], %[[alloc]][%[[idx1]]]
+ %0 = tensor.insert %f into %t1[%idx1] : tensor<5xf32>
+
+ // CHECK: %[[r:.*]] = bufferization.to_tensor %[[alloc]]
+ // CHECK: return %[[r]]
+ return %0 : tensor<5xf32>
+}
More information about the Mlir-commits
mailing list