[Mlir-commits] [mlir] e287d64 - [mlir] Add translation from tensor.reshape to memref.reshape

Matthias Springer llvmlistbot at llvm.org
Mon May 9 08:45:13 PDT 2022


Author: Ashay Rane
Date: 2022-05-09T17:45:07+02:00
New Revision: e287d647c61f5fbd054410d6ed9c37d5271f29ef

URL: https://github.com/llvm/llvm-project/commit/e287d647c61f5fbd054410d6ed9c37d5271f29ef
DIFF: https://github.com/llvm/llvm-project/commit/e287d647c61f5fbd054410d6ed9c37d5271f29ef.diff

LOG: [mlir] Add translation from tensor.reshape to memref.reshape

This patch augments the `tensor-bufferize` pass by adding a conversion
rule to translate ReshapeOp from the `tensor` dialect to the `memref`
dialect, in addition to adding a unit test to validate the translation.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D125031

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/test/Dialect/Tensor/bufferize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 5be820536dab6..b00d87ba54034 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -743,6 +743,54 @@ struct RankOpInterface
   }
 };
 
+/// Bufferization of tensor.reshape. Replace with memref.reshape.
+struct ReshapeOpInterface
+    : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
+                                                    tensor::ReshapeOp> {
+  /// Only the shape operand (operand #1) is reported as a memory read; the
+  /// source operand (#0) aliases the result instead of being read here.
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    if (&opOperand == &op->getOpOperand(1) /* shape */)
+      return true;
+    return false;
+  }
+
+  /// tensor.reshape never writes through any of its operands' buffers.
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const AnalysisState &state) const {
+    return false;
+  }
+
+  /// Every operand is reported as aliasing the op's single result.
+  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                                            const AnalysisState &state) const {
+    return {op->getOpResult(0)};
+  }
+
+  /// The result buffer is declared equivalent to the aliasing operand's
+  /// buffer (BufferRelation::Equivalent), i.e. no copy is introduced.
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const AnalysisState &state) const {
+    return BufferRelation::Equivalent;
+  }
+
+  /// Rewrite tensor.reshape into memref.reshape: obtain buffers for the
+  /// source and shape operands, derive the memref equivalent of the result
+  /// tensor type, and replace the op with a memref.reshape on those buffers.
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          BufferizationState &state) const {
+    auto reshapeOp = cast<tensor::ReshapeOp>(op);
+    // Bufferize the source tensor (operand #0); bail out on failure.
+    auto &srcOperand = reshapeOp->getOpOperand(0);
+    auto srcBuffer = state.getBuffer(rewriter, srcOperand);
+    if (failed(srcBuffer))
+      return failure();
+
+    // Bufferize the shape tensor (operand #1); bail out on failure.
+    auto &shapeOperand = reshapeOp->getOpOperand(1);
+    auto shapeBuffer = state.getBuffer(rewriter, shapeOperand);
+    if (failed(shapeBuffer))
+      return failure();
+
+    // Compute the memref type corresponding to the op's result tensor type,
+    // honoring the active bufferization options.
+    auto resultTensorType = reshapeOp.getResult().getType().cast<TensorType>();
+    auto resultMemRefType = getMemRefType(resultTensorType, state.getOptions());
+
+    replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
+        rewriter, op, resultMemRefType, *srcBuffer, *shapeBuffer);
+    return success();
+  }
+};
+
+
 } // namespace
 } // namespace tensor
 } // namespace mlir
@@ -761,5 +809,6 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
     InsertOp::attachInterface<InsertOpInterface>(*ctx);
     InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
     RankOp::attachInterface<RankOpInterface>(*ctx);
+    ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
   });
 }

diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index 587508c698e3f..cd88f2fd4afab 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -408,3 +408,30 @@ func.func @tensor.collapse_shape_of_slice4(%arg0: tensor<?x2x4xf32>, %offset: in
   %ret = tensor.collapse_shape %0 [[0, 1, 2]] : tensor<4x2x1xf32> into tensor<8xf32>
   return %ret: tensor<8xf32>
 }
+
+// Verify that tensor.reshape bufferizes to memref.reshape: the source tensor
+// is converted with bufferization.to_memref, the shape operand (built via
+// tensor.from_elements) becomes an allocated memref<3xi64> filled with
+// memref.store, and the memref.reshape result is converted back to a tensor.
+// CHECK-LABEL: func @tensor.reshape(
+//  CHECK-SAME:     %[[t1:.*]]: tensor<?x10xf32>
+func.func @tensor.reshape(%t1: tensor<?x10xf32>) -> tensor<2x2x5xf32> {
+  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x10xf32>
+
+  // CHECK: %[[two:.*]] = arith.constant 2 : i64
+  %two = arith.constant 2 : i64
+  // CHECK: %[[five:.*]] = arith.constant 5 : i64
+  %five = arith.constant 5 : i64
+
+  // CHECK: %[[alloc:.*]] = memref.alloc() {alignment = 128 : i64} : memref<3xi64>
+  // CHECK: %[[zero_idx:.*]] = arith.constant 0 : index
+  // CHECK: %[[one_idx:.*]] = arith.constant 1 : index
+  // CHECK: %[[two_idx:.*]] = arith.constant 2 : index
+  // CHECK: memref.store %[[two]], %[[alloc]][%[[zero_idx]]] : memref<3xi64>
+  // CHECK: memref.store %[[two]], %[[alloc]][%[[one_idx]]] : memref<3xi64>
+  // CHECK: memref.store %[[five]], %[[alloc]][%[[two_idx]]] : memref<3xi64>
+  %shape = tensor.from_elements %two, %two, %five : tensor<3xi64>
+
+  // CHECK: %[[reshaped:.*]] = memref.reshape %[[m1]](%[[alloc]]) : (memref<?x10xf32>, memref<3xi64>) -> memref<2x2x5xf32>
+  %reshaped = tensor.reshape %t1(%shape) : (tensor<?x10xf32>, tensor<3xi64>) -> tensor<2x2x5xf32>
+
+  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[reshaped]]
+  // CHECK: return %[[r]]
+  return %reshaped : tensor<2x2x5xf32>
+}


        


More information about the Mlir-commits mailing list