[Mlir-commits] [mlir] 3444996 - [mlir] Add a pattern to bufferize std.index_cast.
Alexander Belyaev
llvmlistbot at llvm.org
Fri May 7 12:33:25 PDT 2021
Author: Alexander Belyaev
Date: 2021-05-07T21:32:02+02:00
New Revision: 3444996b4c45f6efdd731100e8ca6c6105407045
URL: https://github.com/llvm/llvm-project/commit/3444996b4c45f6efdd731100e8ca6c6105407045
DIFF: https://github.com/llvm/llvm-project/commit/3444996b4c45f6efdd731100e8ca6c6105407045.diff
LOG: [mlir] Add a pattern to bufferize std.index_cast.
Differential Revision: https://reviews.llvm.org/D102088
Added:
Modified:
mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
mlir/test/Dialect/Standard/bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
index 040bdc81f23b4..ad5bf057204b2 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
@@ -34,9 +34,22 @@ class BufferizeDimOp : public OpConversionPattern<memref::DimOp> {
return success();
}
};
-} // namespace
-namespace {
+/// Bufferizes an `IndexCastOp` whose result is a tensor. The op is replaced by
+/// a new `IndexCastOp` that takes its input from `operands` and yields a
+/// memref with the tensor result's shape and element type.
+class BufferizeIndexCastOp : public OpConversionPattern<IndexCastOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(IndexCastOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ // Read the input through an adaptor built from `operands`, so the new op
+ // consumes the values handed to this pattern, not `op`'s original
+ // (pre-conversion) input.
+ IndexCastOp::Adaptor adaptor(operands);
+ // NOTE(review): `cast<>` assumes the result is a ranked tensor. The pass
+ // marks IndexCastOp illegal whenever the type converter rejects its result
+ // type. Confirm that unranked-tensor results cannot reach this pattern, or
+ // use `dyn_cast` and return failure() instead of asserting.
+ auto tensorType = op.getType().cast<RankedTensorType>();
+ // MemRefType::get receives only the shape and element type; no layout or
+ // memory-space arguments are passed.
+ rewriter.replaceOpWithNewOp<IndexCastOp>(
+ op, adaptor.in(),
+ MemRefType::get(tensorType.getShape(), tensorType.getElementType()));
+ return success();
+ }
+};
+
class BufferizeSelectOp : public OpConversionPattern<SelectOp> {
public:
using OpConversionPattern::OpConversionPattern;
@@ -56,8 +69,8 @@ class BufferizeSelectOp : public OpConversionPattern<SelectOp> {
+/// Adds the BufferizeDimOp, BufferizeSelectOp and BufferizeIndexCastOp
+/// conversion patterns to `patterns`. Each pattern is constructed with
+/// `typeConverter` and the pattern set's MLIRContext.
void mlir::populateStdBufferizePatterns(BufferizeTypeConverter &typeConverter,
RewritePatternSet &patterns) {
- patterns.add<BufferizeDimOp, BufferizeSelectOp>(typeConverter,
- patterns.getContext());
+ patterns.add<BufferizeDimOp, BufferizeSelectOp, BufferizeIndexCastOp>(
+ typeConverter, patterns.getContext());
}
namespace {
@@ -68,14 +81,15 @@ struct StdBufferizePass : public StdBufferizeBase<StdBufferizePass> {
RewritePatternSet patterns(context);
ConversionTarget target(*context);
- target.addLegalDialect<memref::MemRefDialect>();
- target.addLegalDialect<StandardOpsDialect>();
- target.addLegalDialect<scf::SCFDialect>();
+ target.addLegalDialect<scf::SCFDialect, StandardOpsDialect,
+ memref::MemRefDialect>();
populateStdBufferizePatterns(typeConverter, patterns);
// We only bufferize the case of tensor selected type and scalar condition,
// as that boils down to a select over memref descriptors (don't need to
// touch the data).
+ target.addDynamicallyLegalOp<IndexCastOp>(
+ [&](IndexCastOp op) { return typeConverter.isLegal(op.getType()); });
target.addDynamicallyLegalOp<SelectOp>([&](SelectOp op) {
return typeConverter.isLegal(op.getType()) ||
!op.condition().getType().isa<IntegerType>();
diff --git a/mlir/test/Dialect/Standard/bufferize.mlir b/mlir/test/Dialect/Standard/bufferize.mlir
index 7ed51ca9293ff..cc6725a8c1a05 100644
--- a/mlir/test/Dialect/Standard/bufferize.mlir
+++ b/mlir/test/Dialect/Standard/bufferize.mlir
@@ -24,3 +24,16 @@ func @select(%arg0: i1, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> {
%0 = select %arg0, %arg1, %arg2 : tensor<f32>
return %0 : tensor<f32>
}
+
+// Expected lowering of a tensor index_cast: the tensor argument is turned
+// into a memref with memref.buffer_cast, an index_cast is applied to that
+// memref, and memref.tensor_load turns the result back into a tensor.
+// NOTE(review): the scalar i32 -> index cast is not verified below. It is
+// presumably left untouched because its result type is already legal;
+// consider asserting that explicitly.
+// CHECK-LABEL: func @index_cast(
+// CHECK-SAME: %[[TENSOR:.*]]: tensor<i32>, %[[SCALAR:.*]]: i32
+func @index_cast(%tensor: tensor<i32>, %scalar: i32) -> (tensor<index>, index) {
+ %index_tensor = index_cast %tensor : tensor<i32> to tensor<index>
+ %index_scalar = index_cast %scalar : i32 to index
+ return %index_tensor, %index_scalar : tensor<index>, index
+}
+// CHECK: %[[MEMREF:.*]] = memref.buffer_cast %[[TENSOR]] : memref<i32>
+// CHECK-NEXT: %[[INDEX_MEMREF:.*]] = index_cast %[[MEMREF]]
+// CHECK-SAME: memref<i32> to memref<index>
+// CHECK-NEXT: %[[INDEX_TENSOR:.*]] = memref.tensor_load %[[INDEX_MEMREF]]
+// CHECK: return %[[INDEX_TENSOR]]
More information about the Mlir-commits mailing list