[Mlir-commits] [mlir] e6f6916 - [mlir][bufferize] Support tensor.expand_shape and tensor.collapse_shape

Matthias Springer llvmlistbot at llvm.org
Tue Feb 15 02:57:53 PST 2022


Author: Matthias Springer
Date: 2022-02-15T19:53:49+09:00
New Revision: e6f691615e481d74cad8c6369bc0116192630ea1

URL: https://github.com/llvm/llvm-project/commit/e6f691615e481d74cad8c6369bc0116192630ea1
DIFF: https://github.com/llvm/llvm-project/commit/e6f691615e481d74cad8c6369bc0116192630ea1.diff

LOG: [mlir][bufferize] Support tensor.expand_shape and tensor.collapse_shape

Differential Revision: https://reviews.llvm.org/D112512

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/test/Dialect/Tensor/bufferize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 91e10916124ab..ad5485de63317 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -80,6 +80,46 @@ struct CastOpInterface
   }
 };
 
+/// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
+struct CollapseShapeOpInterface
+    : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
+                                                    tensor::CollapseShapeOp> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const BufferizationState &state) const {
+    return false;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const BufferizationState &state) const {
+    return false;
+  }
+
+  SmallVector<OpResult>
+  getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                      const BufferizationState &state) const {
+    if (&opOperand == &op->getOpOperand(0) /*src*/)
+      return {op->getOpResult(0)};
+    return {};
+  }
+
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const BufferizationState &state) const {
+    return BufferRelation::Equivalent;
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationState &state) const {
+    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
+    Value buffer =
+        *state.getBuffer(rewriter, collapseShapeOp->getOpOperand(0) /*src*/);
+    Type resultType =
+        getMemRefType(collapseShapeOp.getResultType(), state.getOptions());
+    replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
+        rewriter, op, resultType, buffer, collapseShapeOp.reassociation());
+    return success();
+  }
+};
+
 /// Bufferization of tensor.dim. Replace with memref.dim.
 struct DimOpInterface
     : public BufferizableOpInterface::ExternalModel<DimOpInterface,
@@ -109,6 +149,46 @@ struct DimOpInterface
   }
 };
 
+/// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
+struct ExpandShapeOpInterface
+    : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
+                                                    tensor::ExpandShapeOp> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const BufferizationState &state) const {
+    return false;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const BufferizationState &state) const {
+    return false;
+  }
+
+  SmallVector<OpResult>
+  getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                      const BufferizationState &state) const {
+    if (&opOperand == &op->getOpOperand(0) /*src*/)
+      return {op->getOpResult(0)};
+    return {};
+  }
+
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const BufferizationState &state) const {
+    return BufferRelation::Equivalent;
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationState &state) const {
+    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
+    Value buffer =
+        *state.getBuffer(rewriter, expandShapeOp->getOpOperand(0) /*src*/);
+    Type resultType =
+        getMemRefType(expandShapeOp.getResultType(), state.getOptions());
+    replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
+        rewriter, op, resultType, buffer, expandShapeOp.reassociation());
+    return success();
+  }
+};
+
 /// Bufferization of tensor.extract_slice. Replace with memref.subview.
 struct ExtractSliceOpInterface
     : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
@@ -635,7 +715,9 @@ struct RankOpInterface
 void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addOpInterface<CastOp, CastOpInterface>();
+  registry.addOpInterface<CollapseShapeOp, CollapseShapeOpInterface>();
   registry.addOpInterface<DimOp, DimOpInterface>();
+  registry.addOpInterface<ExpandShapeOp, ExpandShapeOpInterface>();
   registry.addOpInterface<ExtractSliceOp, ExtractSliceOpInterface>();
   registry.addOpInterface<ExtractOp, ExtractOpInterface>();
   registry.addOpInterface<FromElementsOp, FromElementsOpInterface>();

diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index b0415ce1464ce..7d3084d9d024c 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -301,3 +301,31 @@ func @tensor.insert(%t1: tensor<5xf32>, %idx1: index, %f: f32) -> tensor<5xf32>
   // CHECK: return %[[r]]
   return %0 : tensor<5xf32>
 }
+
+// CHECK-LABEL: func @tensor.expand_shape(
+//  CHECK-SAME:     %[[t1:.*]]: tensor<?x10xf32>
+func @tensor.expand_shape(%t1: tensor<?x10xf32>) -> tensor<2x?x10xf32> {
+  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x10xf32>
+  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[m1]] [
+  // CHECK-SAME: [0, 1], [2]] : memref<?x10xf32> into memref<2x?x10xf32>
+  %0 = tensor.expand_shape %t1 [[0, 1], [2]]
+      : tensor<?x10xf32> into tensor<2x?x10xf32>
+
+  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
+  // CHECK: return %[[r]]
+  return %0 : tensor<2x?x10xf32>
+}
+
+// CHECK-LABEL: func @tensor.collapse_shape(
+//  CHECK-SAME:     %[[t1:.*]]: tensor<2x?x?xf32>
+func @tensor.collapse_shape(%t1: tensor<2x?x?xf32>) -> tensor<?x?xf32> {
+  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<2x?x?xf32>
+  // CHECK: %[[collapsed:.*]] = memref.collapse_shape %[[m1]] [
+  // CHECK-SAME: [0, 1], [2]] : memref<2x?x?xf32> into memref<?x?xf32>
+  %0 = tensor.collapse_shape %t1 [[0, 1], [2]]
+      : tensor<2x?x?xf32> into tensor<?x?xf32>
+
+  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[collapsed]]
+  // CHECK: return %[[r]]
+  return %0 : tensor<?x?xf32>
+}


        


More information about the Mlir-commits mailing list