[Mlir-commits] [mlir] a82a19c - [mlir] Add a missing pattern to bufferize tensor.rank.

Alexander Belyaev llvmlistbot at llvm.org
Tue Dec 14 11:05:22 PST 2021


Author: Alexander Belyaev
Date: 2021-12-14T20:04:57+01:00
New Revision: a82a19c137ad0b966847241c40546b3e145a17b5

URL: https://github.com/llvm/llvm-project/commit/a82a19c137ad0b966847241c40546b3e145a17b5
DIFF: https://github.com/llvm/llvm-project/commit/a82a19c137ad0b966847241c40546b3e145a17b5.diff

LOG: [mlir] Add a missing pattern to bufferize tensor.rank.

Differential Revision: https://reviews.llvm.org/D115745
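
For context, the new BufferizeRankOp pattern lowers tensor.rank to memref.rank on the
bufferized operand. A rough before/after sketch of the rewrite, mirroring the new test
case in the diff below (value names are illustrative only):

    // Before bufferization:
    %rank = tensor.rank %t : tensor<*xf32>

    // After bufferization (sketch):
    %m = bufferization.to_memref %t : memref<*xf32>
    %rank = memref.rank %m : memref<*xf32>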

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
    mlir/test/Dialect/Tensor/bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
index d02328e4230db..0fd5b2d75d677 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
@@ -24,8 +24,7 @@
 using namespace mlir;
 
 namespace {
-class BufferizeCastOp : public OpConversionPattern<tensor::CastOp> {
-public:
+struct BufferizeCastOp : public OpConversionPattern<tensor::CastOp> {
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
   matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
@@ -36,11 +35,8 @@ class BufferizeCastOp : public OpConversionPattern<tensor::CastOp> {
     return success();
   }
 };
-} // namespace
 
-namespace {
-class BufferizeDimOp : public OpConversionPattern<tensor::DimOp> {
-public:
+struct BufferizeDimOp : public OpConversionPattern<tensor::DimOp> {
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
   matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
@@ -50,11 +46,8 @@ class BufferizeDimOp : public OpConversionPattern<tensor::DimOp> {
     return success();
   }
 };
-} // namespace
 
-namespace {
-class BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> {
-public:
+struct BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> {
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
   matchAndRewrite(tensor::ExtractOp op, OpAdaptor adaptor,
@@ -64,10 +57,8 @@ class BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> {
     return success();
   }
 };
-} // namespace
 
-namespace {
-class BufferizeFromElementsOp
+struct BufferizeFromElementsOp
     : public OpConversionPattern<tensor::FromElementsOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -88,11 +79,8 @@ class BufferizeFromElementsOp
     return success();
   }
 };
-} // namespace
 
-namespace {
-class BufferizeGenerateOp : public OpConversionPattern<tensor::GenerateOp> {
-public:
+struct BufferizeGenerateOp : public OpConversionPattern<tensor::GenerateOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
@@ -150,44 +138,51 @@ class BufferizeGenerateOp : public OpConversionPattern<tensor::GenerateOp> {
     return success();
   }
 };
-} // namespace
 
-void mlir::populateTensorBufferizePatterns(
-    bufferization::BufferizeTypeConverter &typeConverter,
-    RewritePatternSet &patterns) {
-  patterns.add<BufferizeCastOp, BufferizeDimOp, BufferizeExtractOp,
-               BufferizeFromElementsOp, BufferizeGenerateOp>(
-      typeConverter, patterns.getContext());
-}
+struct BufferizeRankOp : public OpConversionPattern<tensor::RankOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(tensor::RankOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<memref::RankOp>(op, op.getType(),
+                                                adaptor.tensor());
+    return success();
+  }
+};
 
-namespace {
 struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
   void runOnFunction() override {
     auto *context = &getContext();
     bufferization::BufferizeTypeConverter typeConverter;
-    RewritePatternSet patterns(context);
-    ConversionTarget target(*context);
-
-    bufferization::populateBufferizeMaterializationLegality(target);
 
-    populateTensorBufferizePatterns(typeConverter, patterns);
-    target.addIllegalOp<tensor::CastOp, tensor::ExtractOp,
-                        tensor::FromElementsOp, tensor::GenerateOp>();
-    target.addLegalDialect<memref::MemRefDialect>();
+    ConversionTarget target(*context);
+    target.addLegalDialect<scf::SCFDialect, memref::MemRefDialect>();
     target.addDynamicallyLegalDialect<arith::ArithmeticDialect,
                                       StandardOpsDialect>(
         [&](Operation *op) { return typeConverter.isLegal(op); });
-    target.addLegalOp<CallOp>();
-    target.addLegalOp<ReturnOp>();
-    target.addLegalDialect<scf::SCFDialect>();
+    target.addLegalOp<CallOp, ReturnOp>();
+    target.addIllegalOp<tensor::CastOp, tensor::ExtractOp,
+                        tensor::FromElementsOp, tensor::GenerateOp>();
+    bufferization::populateBufferizeMaterializationLegality(target);
 
+    RewritePatternSet patterns(context);
+    populateTensorBufferizePatterns(typeConverter, patterns);
     if (failed(
             applyPartialConversion(getFunction(), target, std::move(patterns))))
       signalPassFailure();
   }
 };
+
 } // namespace
 
+void mlir::populateTensorBufferizePatterns(
+    bufferization::BufferizeTypeConverter &typeConverter,
+    RewritePatternSet &patterns) {
+  patterns.add<BufferizeCastOp, BufferizeDimOp, BufferizeExtractOp,
+               BufferizeFromElementsOp, BufferizeGenerateOp, BufferizeRankOp>(
+      typeConverter, patterns.getContext());
+}
+
 std::unique_ptr<Pass> mlir::createTensorBufferizePass() {
   return std::make_unique<TensorBufferizePass>();
 }

diff  --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index 91642f06d0f26..5b3bb149d6180 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -11,6 +11,15 @@ func @dim(%arg0: tensor<f32>, %arg1: index) -> index {
   return %0 : index
 }
 
+// CHECK-LABEL: func @rank(
+// CHECK-SAME:    %[[TENSOR:.*]]: tensor<*xf32>) -> index {
+// CHECK:           %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]]
+// CHECK:           %[[EXTENT:.*]] = memref.rank %[[MEMREF]] : memref<*xf32>
+func @rank(%arg0: tensor<*xf32>) -> index {
+  %0 = tensor.rank %arg0 : tensor<*xf32>
+  return %0 : index
+}
+
 // CHECK-LABEL:   func @tensor.cast(
 // CHECK-SAME:                      %[[TENSOR:.*]]: tensor<?xindex>) -> tensor<2xindex> {
 // CHECK:           %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]]
