[Mlir-commits] [mlir] 51df623 - [mlir][tensor] Fix bufferization of CollapseShapeOp / ExpandShapeOp
Matthias Springer
llvmlistbot at llvm.org
Thu Mar 31 01:11:55 PDT 2022
Author: Matthias Springer
Date: 2022-03-31T17:11:45+09:00
New Revision: 51df62388e83a406e4a946ff8aae1f7299a2d92b
URL: https://github.com/llvm/llvm-project/commit/51df62388e83a406e4a946ff8aae1f7299a2d92b
DIFF: https://github.com/llvm/llvm-project/commit/51df62388e83a406e4a946ff8aae1f7299a2d92b.diff
LOG: [mlir][tensor] Fix bufferization of CollapseShapeOp / ExpandShapeOp
Infer a tighter MemRef type instead of always falling back to the most dynamic MemRef type. Always falling back to the most dynamic type is inefficient and caused op verification errors.
Differential Revision: https://reviews.llvm.org/D122649
Added:
Modified:
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/test/Dialect/Tensor/bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index a9519b98803cc..6f86270e5adda 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -108,12 +108,27 @@ struct CollapseShapeOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
BufferizationState &state) const {
auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
+ RankedTensorType tensorResultType = collapseShapeOp.getResultType();
Value buffer =
*state.getBuffer(rewriter, collapseShapeOp->getOpOperand(0) /*src*/);
- Type resultType =
- getMemRefType(collapseShapeOp.getResultType(), state.getOptions());
+
+ if (tensorResultType.getRank() == 0) {
+ // 0-d collapses must go through a different op builder.
+ auto bufferType = buffer.getType().cast<MemRefType>();
+ // Assume identity layout: No offset.
+ assert(bufferType.getLayout().isIdentity() &&
+ "non-zero offset for 0-d collapse not supported");
+ MemRefLayoutAttrInterface layout;
+ auto resultType = MemRefType::get({}, tensorResultType.getElementType(),
+ layout, bufferType.getMemorySpace());
+ replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
+ rewriter, op, resultType, buffer, collapseShapeOp.reassociation());
+ return success();
+ }
+
+ // Result type is inferred by the builder.
replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
- rewriter, op, resultType, buffer, collapseShapeOp.reassociation());
+ rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
return success();
}
};
@@ -175,12 +190,15 @@ struct ExpandShapeOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
BufferizationState &state) const {
auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
+ auto tensorResultType = expandShapeOp.getResultType();
Value buffer =
*state.getBuffer(rewriter, expandShapeOp->getOpOperand(0) /*src*/);
- Type resultType =
- getMemRefType(expandShapeOp.getResultType(), state.getOptions());
+
+ // Memref result type is inferred by the builder based on reassociation
+ // indices and result shape.
replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
- rewriter, op, resultType, buffer, expandShapeOp.reassociation());
+ rewriter, op, tensorResultType.getShape(), buffer,
+ expandShapeOp.getReassociationIndices());
return success();
}
};
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index cbb05473807b0..5fa8e3f8a2a46 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -1,6 +1,8 @@
// RUN: mlir-opt %s -tensor-bufferize | FileCheck %s
-// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 20 + s0 + d1)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 140 + d1 * 20 + d2 * 5 + d3 + s0)>
// CHECK-LABEL: func @dim(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<f32>,
@@ -242,7 +244,7 @@ func @tensor.generate_unknown_ops_in_body(%arg0: index) -> tensor<?xindex> {
func @tensor.extract_slice(
%t1: tensor<?x?xf32>, %idx1: index, %idx2: index) -> tensor<?x10xf32> {
// CHECK: %[[m:.*]] = bufferization.to_memref %[[t1]] : memref<?x?xf32>
- // CHECK: %[[r:.*]] = memref.subview %[[m]][5, %[[idx2]]] [%[[idx1]], 10] [1, 1] : memref<?x?xf32> to memref<?x10xf32, #[[$MAP]]>
+ // CHECK: %[[r:.*]] = memref.subview %[[m]][5, %[[idx2]]] [%[[idx1]], 10] [1, 1] : memref<?x?xf32> to memref<?x10xf32, #[[$MAP0]]>
%0 = tensor.extract_slice %t1[5, %idx2][%idx1, 10][1, 1]
: tensor<?x?xf32> to tensor<?x10xf32>
// CHECK: %[[r_tensor:.*]] = bufferization.to_tensor %[[r]]
@@ -256,7 +258,7 @@ func @tensor.extract_slice(
func @tensor.extract_slice_rank_reducing(
%t1: tensor<?x10x?xf32>, %idx1: index, %idx2: index) -> tensor<?x15xf32> {
// CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x10x?xf32>
- // CHECK: %[[r:.*]] = memref.subview %[[m1]][5, %[[idx1]], 10] [%[[idx2]], 1, 15] [1, 1, 1] : memref<?x10x?xf32> to memref<?x15xf32, #[[$MAP]]>
+ // CHECK: %[[r:.*]] = memref.subview %[[m1]][5, %[[idx1]], 10] [%[[idx2]], 1, 15] [1, 1, 1] : memref<?x10x?xf32> to memref<?x15xf32, #[[$MAP0]]>
%0 = tensor.extract_slice %t1[5, %idx1, 10][%idx2, 1, 15][1, 1, 1]
: tensor<?x10x?xf32> to tensor<?x15xf32>
// CHECK: %[[r_tensor:.*]] = bufferization.to_tensor %[[r]]
@@ -316,6 +318,23 @@ func @tensor.expand_shape(%t1: tensor<?x10xf32>) -> tensor<2x?x10xf32> {
return %0 : tensor<2x?x10xf32>
}
+// CHECK-LABEL: func @tensor.expand_shape_of_slice(
+// CHECK-SAME: %[[t1:.*]]: tensor<?x20xf32>
+func @tensor.expand_shape_of_slice(
+ %t1: tensor<?x20xf32>, %o1: index, %s1: index) -> tensor<?x7x2x5xf32> {
+ // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x20xf32>
+ // CHECK: %[[subview:.*]] = memref.subview %[[m1]][%{{.*}}, 5] [%{{.*}}, 10] [1, 1] : memref<?x20xf32> to memref<?x10xf32, #[[$MAP1]]>
+ %0 = tensor.extract_slice %t1[%o1, 5][%s1, 10][1, 1] :
+ tensor<?x20xf32> to tensor<?x10xf32>
+ // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] [
+ // CHECK-SAME: [0, 1], [2, 3]] : memref<?x10xf32, #[[$MAP1]]> into memref<?x7x2x5xf32, #[[$MAP2]]>
+ %1 = tensor.expand_shape %0 [[0, 1], [2, 3]] :
+ tensor<?x10xf32> into tensor<?x7x2x5xf32>
+ // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
+ // CHECK: return %[[r]]
+ return %1 : tensor<?x7x2x5xf32>
+}
+
// CHECK-LABEL: func @tensor.collapse_shape(
// CHECK-SAME: %[[t1:.*]]: tensor<2x?x?xf32>
func @tensor.collapse_shape(%t1: tensor<2x?x?xf32>) -> tensor<?x?xf32> {
@@ -329,3 +348,16 @@ func @tensor.collapse_shape(%t1: tensor<2x?x?xf32>) -> tensor<?x?xf32> {
// CHECK: return %[[r]]
return %0 : tensor<?x?xf32>
}
+
+// CHECK-LABEL: func @tensor.collapse_shape_to_scalar(
+// CHECK-SAME: %[[t1:.*]]: tensor<1x1x1xf32>
+func @tensor.collapse_shape_to_scalar(%t1: tensor<1x1x1xf32>) -> tensor<f32> {
+ // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<1x1x1xf32>
+ // CHECK: %[[collapsed:.*]] = memref.collapse_shape %[[m1]] [] : memref<1x1x1xf32> into memref<f32>
+ %0 = tensor.collapse_shape %t1 []
+ : tensor<1x1x1xf32> into tensor<f32>
+
+ // CHECK: %[[r:.*]] = bufferization.to_tensor %[[collapsed]]
+ // CHECK: return %[[r]]
+ return %0 : tensor<f32>
+}
More information about the Mlir-commits
mailing list