[Mlir-commits] [mlir] a552deb - [mlir] Add patterns for vector.transfer_read/write to Linalg bufferization.
Alexander Belyaev
llvmlistbot at llvm.org
Fri Aug 6 11:25:03 PDT 2021
Author: Alexander Belyaev
Date: 2021-08-06T20:24:44+02:00
New Revision: a552debdcf01d422ef8d28c0bdc153cff350710a
URL: https://github.com/llvm/llvm-project/commit/a552debdcf01d422ef8d28c0bdc153cff350710a
DIFF: https://github.com/llvm/llvm-project/commit/a552debdcf01d422ef8d28c0bdc153cff350710a.diff
LOG: [mlir] Add patterns for vector.transfer_read/write to Linalg bufferization.
Differential Revision: https://reviews.llvm.org/D107643
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
mlir/test/Dialect/Linalg/bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
index b918b98a76b80..cd0445f1ff87d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -309,6 +309,47 @@ class InsertSliceOpConverter
return success();
}
};
+
+class VectorTransferReadOpConverter // Conversion pattern that rebuilds a non-memref vector.transfer_read from the operands passed to matchAndRewrite.
+ : public OpConversionPattern<vector::TransferReadOp> {
+public:
+ using OpConversionPattern<vector::TransferReadOp>::OpConversionPattern; // Inherit the base-class constructors.
+
+ LogicalResult // failure() when the shaped type is a memref; success() once the op has been replaced.
+ matchAndRewrite(vector::TransferReadOp readOp, ArrayRef<Value> operands, // `operands`: presumably the already type-converted operands — confirm against the dialect-conversion docs.
+ ConversionPatternRewriter &rewriter) const final {
+ if (readOp.getShapedType().isa<MemRefType>()) // This pattern does not apply when the shaped type is already a memref.
+ return failure();
+ vector::TransferReadOp::Adaptor adaptor(operands, // Named accessors over `operands` ...
+ readOp->getAttrDictionary()); // ... combined with the attribute dictionary of the original op.
+ rewriter.replaceOpWithNewOp<vector::TransferReadOp>( // New read keeps readOp's result type; every other argument comes from the adaptor.
+ readOp, readOp.getType(), adaptor.source(), adaptor.indices(),
+ adaptor.permutation_map(), adaptor.padding(), adaptor.mask(),
+ adaptor.in_bounds());
+ return success();
+ }
+};
+
+class VectorTransferWriteOpConverter // Conversion pattern that re-emits a non-memref vector.transfer_write and replaces the original op with the adaptor's source value.
+ : public OpConversionPattern<vector::TransferWriteOp> {
+public:
+ using OpConversionPattern<vector::TransferWriteOp>::OpConversionPattern; // Inherit the base-class constructors.
+
+ LogicalResult // failure() when the shaped type is a memref; success() after the rewrite.
+ matchAndRewrite(vector::TransferWriteOp writeOp, ArrayRef<Value> operands, // `operands`: presumably the already type-converted operands — confirm against the dialect-conversion docs.
+ ConversionPatternRewriter &rewriter) const final {
+ if (writeOp.getShapedType().isa<MemRefType>()) // This pattern does not apply when the shaped type is already a memref.
+ return failure();
+ vector::TransferWriteOp::Adaptor adaptor(operands, // Named accessors over `operands` ...
+ writeOp->getAttrDictionary()); // ... combined with the attribute dictionary of the original op.
+ rewriter.create<vector::TransferWriteOp>( // The value returned by create() is intentionally not captured.
+ writeOp.getLoc(), adaptor.vector(), adaptor.source(), adaptor.indices(),
+ adaptor.permutation_map(), // NOTE(review): no mask argument is passed to this call — confirm masked writes never reach this pattern.
+ adaptor.in_bounds() ? adaptor.in_bounds() : ArrayAttr()); // Passes a default-constructed ArrayAttr when in_bounds() is falsy. NOTE(review): the ternary looks redundant — verify.
+ rewriter.replaceOp(writeOp, adaptor.source()); // The original write is replaced by adaptor.source(), not by the newly created op.
+ return success();
+ }
+};
} // namespace
namespace {
@@ -332,10 +373,10 @@ struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
return typeConverter.isLegal(op);
};
target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
- target.addDynamicallyLegalOp<ConstantOp>(isLegalOperation);
+ target.addDynamicallyLegalOp<ConstantOp, vector::TransferReadOp,
+ vector::TransferWriteOp>(isLegalOperation);
RewritePatternSet patterns(&context);
- patterns.add<GeneralizePadTensorOpPattern>(patterns.getContext());
populateLinalgBufferizePatterns(typeConverter, patterns);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
@@ -359,7 +400,10 @@ void mlir::linalg::populateLinalgBufferizePatterns(
BufferizeTensorReshapeOp<TensorExpandShapeOp>,
BufferizeTensorReshapeOp<TensorCollapseShapeOp>,
ExtractSliceOpConverter,
- InsertSliceOpConverter
+ InsertSliceOpConverter,
+ VectorTransferReadOpConverter,
+ VectorTransferWriteOpConverter
>(typeConverter, patterns.getContext());
// clang-format on
+ patterns.add<GeneralizePadTensorOpPattern>(patterns.getContext());
}
diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index fa95deb1dbaa3..93dbf4a563675 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -299,3 +299,20 @@ func @pad_tensor_dynamic_shape(%arg0: tensor<4x?x2x?xf32>, %arg1: index) -> tens
// CHECK: %[[OUT_TENSOR:.*]] = memref.tensor_load %[[OUT]] : memref<4x?x?x?xf32>
// CHECK: return %[[OUT_TENSOR]] : tensor<4x?x?x?xf32>
// CHECK: }
+
+
+// -----
+
+// CHECK-LABEL: func @vector_transfer
+func @vector_transfer(%in: tensor<4xf32>, %out: tensor<4xf32>) { // Test input: tensor-typed transfer_read -> math.tanh -> transfer_write.
+ %c0 = constant 0 : index
+ %cst = constant 0.000000e+00 : f32
+ %read = vector.transfer_read %in[%c0], %cst {in_bounds = [true]}
+ : tensor<4xf32>, vector<4xf32>
+ %tanh = math.tanh %read : vector<4xf32>
+ %write = vector.transfer_write %tanh, %out[%c0] {in_bounds = [true]}
+ : vector<4xf32>, tensor<4xf32>
+ return // %write is never used and nothing is returned; the expectations below only match the memref-typed forms of both transfer ops.
+ // CHECK: vector.transfer_read {{.*}} : memref<4xf32>, vector<4xf32>
+ // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, memref<4xf32>
+}
More information about the Mlir-commits
mailing list