[Mlir-commits] [mlir] fc08d1c - [mlir][tensor][bufferize] Support tensor.rank in BufferizableOpInterfaceImpl

Matthias Springer llvmlistbot at llvm.org
Mon Jan 24 07:31:54 PST 2022


Author: Matthias Springer
Date: 2022-01-25T00:31:20+09:00
New Revision: fc08d1c2940609d26a534d7a12e6c6a528891830

URL: https://github.com/llvm/llvm-project/commit/fc08d1c2940609d26a534d7a12e6c6a528891830
DIFF: https://github.com/llvm/llvm-project/commit/fc08d1c2940609d26a534d7a12e6c6a528891830.diff

LOG: [mlir][tensor][bufferize] Support tensor.rank in BufferizableOpInterfaceImpl

This is the only op that is not supported via BufferizableOpInterfaceImpl bufferization. Once this op is supported, we can switch `tensor-bufferize` over to the new unified bufferization.

Differential Revision: https://reviews.llvm.org/D117985

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
    mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index f4a2a5d692152..0fe79862a69d0 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -457,16 +457,22 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
     Value memref = frontBlock.addArgument(memrefType, bbArg.getLoc());
     OpBuilder b(funcOp->getContext());
     b.setInsertionPointToStart(&frontBlock);
-    // Replace all uses of bbArg through a ToMemRefOp by a memref::CastOp.
+    // Replace all uses of bbArg through a ToMemRefOp.
     for (auto &use : llvm::make_early_inc_range(bbArg.getUses())) {
       if (auto toMemrefOp =
               dyn_cast<bufferization::ToMemrefOp>(use.getOwner())) {
-        assert(memref::CastOp::areCastCompatible(
-                   memref.getType(), toMemrefOp.memref().getType()) &&
-               "bufferizeFuncOpBoundary: cast incompatible");
-        auto castOp = b.create<memref::CastOp>(
-            funcOp.getLoc(), toMemrefOp.memref().getType(), memref);
-        toMemrefOp.memref().replaceAllUsesWith(castOp);
+        if (memref.getType() != toMemrefOp.memref().getType()) {
+          // Type has changed, insert a cast.
+          assert(memref::CastOp::areCastCompatible(
+                     memref.getType(), toMemrefOp.memref().getType()) &&
+                 "bufferizeFuncOpBoundary: cast incompatible");
+          auto castOp = b.create<memref::CastOp>(
+              funcOp.getLoc(), toMemrefOp.memref().getType(), memref);
+          toMemrefOp.memref().replaceAllUsesWith(castOp);
+        } else {
+          // Type did not change, replace directly.
+          toMemrefOp.memref().replaceAllUsesWith(memref);
+        }
       }
     }
     // Replace all remaining uses by a to_tensor.

diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 03fa45f04c68f..ea9d885736f90 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -463,6 +463,35 @@ struct InsertSliceOpInterface
   }
 };
 
+/// Bufferization of tensor.rank. Replace with memref.rank.
+struct RankOpInterface
+    : public BufferizableOpInterface::ExternalModel<RankOpInterface,
+                                                    tensor::RankOp> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const BufferizationState &state) const {
+    return true;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const BufferizationState &state) const {
+    return false;
+  }
+
+  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                               const BufferizationState &state) const {
+    return OpResult();
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationState &state) const {
+    auto rankOp = cast<tensor::RankOp>(op);
+    Value v = *state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/);
+    replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
+                                                 v);
+    return success();
+  }
+};
+
 } // namespace
 } // namespace tensor
 } // namespace mlir
@@ -475,4 +504,5 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
   registry.addOpInterface<ExtractOp, ExtractOpInterface>();
   registry.addOpInterface<InsertOp, InsertOpInterface>();
   registry.addOpInterface<InsertSliceOp, InsertSliceOpInterface>();
+  registry.addOpInterface<RankOp, RankOpInterface>();
 }

diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index a739fc4645ed0..1f301a14c11e1 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -1348,3 +1348,14 @@ func @write_after_select_read_one(
   // CHECK: return %[[f]], %[[select]]
   return %f, %w : f32, tensor<?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @tensor_rank(
+//  CHECK-SAME:     %[[arg0:.*]]: memref<*xf32>
+func @tensor_rank(%arg0: tensor<*xf32>) -> index {
+  // CHECK: %[[r:.*]] = memref.rank %[[arg0]]
+  %0 = tensor.rank %arg0 : tensor<*xf32>
+  // CHECK: return %[[r]] : index
+  return %0 : index
+}


        


More information about the Mlir-commits mailing list