[Mlir-commits] [mlir] 3444996 - [mlir] Add a pattern to bufferize std.index_cast.

Alexander Belyaev llvmlistbot at llvm.org
Fri May 7 12:33:25 PDT 2021


Author: Alexander Belyaev
Date: 2021-05-07T21:32:02+02:00
New Revision: 3444996b4c45f6efdd731100e8ca6c6105407045

URL: https://github.com/llvm/llvm-project/commit/3444996b4c45f6efdd731100e8ca6c6105407045
DIFF: https://github.com/llvm/llvm-project/commit/3444996b4c45f6efdd731100e8ca6c6105407045.diff

LOG: [mlir] Add a pattern to bufferize std.index_cast.

Differential Revision: https://reviews.llvm.org/D102088

Added: 
    

Modified: 
    mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
    mlir/test/Dialect/Standard/bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
index 040bdc81f23b4..ad5bf057204b2 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
@@ -34,9 +34,29 @@ class BufferizeDimOp : public OpConversionPattern<memref::DimOp> {
     return success();
   }
 };
-} // namespace
 
-namespace {
+/// Bufferizes an `index_cast` whose result is a ranked tensor: the cast is
+/// re-created on the converted (buffer) operand, producing a memref with the
+/// same shape as the tensor result and the tensor result's element type.
+class BufferizeIndexCastOp : public OpConversionPattern<IndexCastOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(IndexCastOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    IndexCastOp::Adaptor adaptor(operands);
+    // Only ranked results are supported; an unranked tensor must not reach
+    // `cast<RankedTensorType>`, so report a match failure instead of asserting.
+    auto tensorType = op.getType().dyn_cast<RankedTensorType>();
+    if (!tensorType)
+      return rewriter.notifyMatchFailure(op, "expected ranked tensor result");
+    rewriter.replaceOpWithNewOp<IndexCastOp>(
+        op, adaptor.in(),
+        MemRefType::get(tensorType.getShape(), tensorType.getElementType()));
+    return success();
+  }
+};
+
 class BufferizeSelectOp : public OpConversionPattern<SelectOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -56,8 +76,8 @@ class BufferizeSelectOp : public OpConversionPattern<SelectOp> {
 
 void mlir::populateStdBufferizePatterns(BufferizeTypeConverter &typeConverter,
                                         RewritePatternSet &patterns) {
-  patterns.add<BufferizeDimOp, BufferizeSelectOp>(typeConverter,
-                                                  patterns.getContext());
+  patterns.add<BufferizeDimOp, BufferizeSelectOp, BufferizeIndexCastOp>(
+      typeConverter, patterns.getContext());
 }
 
 namespace {
@@ -68,14 +88,16 @@ struct StdBufferizePass : public StdBufferizeBase<StdBufferizePass> {
     RewritePatternSet patterns(context);
     ConversionTarget target(*context);
 
-    target.addLegalDialect<memref::MemRefDialect>();
-    target.addLegalDialect<StandardOpsDialect>();
-    target.addLegalDialect<scf::SCFDialect>();
+    target.addLegalDialect<scf::SCFDialect, StandardOpsDialect,
+                           memref::MemRefDialect>();
 
     populateStdBufferizePatterns(typeConverter, patterns);
+    // Bufferize index_cast only when its result type still needs conversion.
+    target.addDynamicallyLegalOp<IndexCastOp>(
+        [&](IndexCastOp op) { return typeConverter.isLegal(op.getType()); });
     // We only bufferize the case of tensor selected type and scalar condition,
     // as that boils down to a select over memref descriptors (don't need to
     // touch the data).
     target.addDynamicallyLegalOp<SelectOp>([&](SelectOp op) {
       return typeConverter.isLegal(op.getType()) ||
              !op.condition().getType().isa<IntegerType>();

diff  --git a/mlir/test/Dialect/Standard/bufferize.mlir b/mlir/test/Dialect/Standard/bufferize.mlir
index 7ed51ca9293ff..cc6725a8c1a05 100644
--- a/mlir/test/Dialect/Standard/bufferize.mlir
+++ b/mlir/test/Dialect/Standard/bufferize.mlir
@@ -24,3 +24,18 @@ func @select(%arg0: i1, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> {
   %0 = select %arg0, %arg1, %arg2 : tensor<f32>
   return %0 : tensor<f32>
 }
+
+// CHECK-LABEL:   func @index_cast(
+// CHECK-SAME:  %[[TENSOR:.*]]: tensor<i32>, %[[SCALAR:.*]]: i32
+func @index_cast(%tensor: tensor<i32>, %scalar: i32) -> (tensor<index>, index) {
+  %index_tensor = index_cast %tensor : tensor<i32> to tensor<index>
+  %index_scalar = index_cast %scalar : i32 to index
+  return %index_tensor, %index_scalar : tensor<index>, index
+}
+// CHECK:  %[[MEMREF:.*]] = memref.buffer_cast %[[TENSOR]] : memref<i32>
+// CHECK-NEXT: %[[INDEX_MEMREF:.*]] = index_cast %[[MEMREF]]
+// CHECK-SAME:   memref<i32> to memref<index>
+// CHECK-NEXT: %[[INDEX_TENSOR:.*]] = memref.tensor_load %[[INDEX_MEMREF]]
+// The scalar cast has a legal (non-tensor) result and must be left untouched.
+// CHECK: %[[INDEX_SCALAR:.*]] = index_cast %[[SCALAR]] : i32 to index
+// CHECK: return %[[INDEX_TENSOR]], %[[INDEX_SCALAR]]


        


More information about the Mlir-commits mailing list