[Mlir-commits] [mlir] 3cccb20 - [MLIR][Tensor] Enhance bufferization of tensor.expand_shape op (#128871)
llvmlistbot at llvm.org
Thu Feb 27 21:15:42 PST 2025
Author: Arnab Dutta
Date: 2025-02-28T10:45:38+05:30
New Revision: 3cccb2017ff96d67b0e737eeddb58ff054cedc6e
URL: https://github.com/llvm/llvm-project/commit/3cccb2017ff96d67b0e737eeddb58ff054cedc6e
DIFF: https://github.com/llvm/llvm-project/commit/3cccb2017ff96d67b0e737eeddb58ff054cedc6e.diff
LOG: [MLIR][Tensor] Enhance bufferization of tensor.expand_shape op (#128871)
Instead of inferring the output shape argument of the
memref.expand_shape op, use the output_shape argument of the
tensor.expand_shape op. This adds dynamic-dimension support to the
bufferization of tensor.expand_shape when there is more than one
dynamic dimension within a reassociation set.
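
For illustration, a minimal sketch of the effect, mirroring the test added
below (the function name @expand and value names are placeholders): with
multiple dynamic dimensions in one reassociation group, bufferization now
forwards the output_shape operands directly instead of dividing the input
dimension.

  // Input: two of the three expanded sizes in the first reassociation
  // group are dynamic.
  func.func @expand(%t: tensor<?x256xf32>, %sz0: index, %sz1: index,
                    %sz2: index) -> tensor<?x?x?x256xf32> {
    %0 = tensor.expand_shape %t [[0, 1, 2], [3]]
        output_shape [%sz0, %sz1, %sz2, 256]
        : tensor<?x256xf32> into tensor<?x?x?x256xf32>
    return %0 : tensor<?x?x?x256xf32>
  }

  // After bufferization, the dynamic sizes are passed through unchanged:
  //   %m = bufferization.to_memref %t
  //   %e = memref.expand_shape %m [[0, 1, 2], [3]]
  //       output_shape [%sz0, %sz1, %sz2, 256]
  //       : memref<?x256xf32> into memref<?x?x?x256xf32>
  //   %r = bufferization.to_tensor %e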
Added:
Modified:
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/test/Dialect/Tensor/bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 81404fa664cd4..a9ba662348a52 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -337,14 +337,12 @@ struct ExpandShapeOpInterface
if (failed(buffer))
return failure();
- // Memref result type is inferred by the builder based on reassociation
- // indices and result shape.
- // TODO: Instead of inferring the output shape argument of
- // memref.expand_shape op, use output_shape argument of tensor.expand_shape
- // op.
- replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
- rewriter, op, tensorResultType.getShape(), *buffer,
- expandShapeOp.getReassociationIndices());
+ auto memrefExpandShape = rewriter.create<memref::ExpandShapeOp>(
+ op->getLoc(), tensorResultType.getShape(), *buffer,
+ expandShapeOp.getReassociationIndices(),
+ expandShapeOp.getMixedOutputShape());
+ replaceOpWithBufferizedValues(rewriter, op,
+ memrefExpandShape->getResults());
return success();
}
};
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index 9ea0a15f31185..c1beed95f2006 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -366,14 +366,10 @@ func.func @tensor.insert(%t1: tensor<5xf32>, %idx1: index, %f: f32) -> tensor<5x
// -----
// CHECK-LABEL: func @tensor.expand_shape(
-// CHECK-SAME: %[[t1:.*]]: tensor<?x10xf32>
+// CHECK-SAME: %[[t1:.*]]: tensor<?x10xf32>, %[[sz0:.*]]: index
func.func @tensor.expand_shape(%t1: tensor<?x10xf32>, %sz0: index) -> tensor<2x?x10xf32> {
// CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]]
- // CHECK: %[[C0:.*]] = arith.constant 0 : index
- // CHECK: %[[DIM:.*]] = memref.dim %[[m1]], %[[C0]] : memref<?x10xf32>
- // CHECK: %[[C2:.*]] = arith.constant 2 : index
- // CHECK: %[[VAL_1:.*]] = arith.divsi %[[DIM]], %[[C2]] : index
- // CHECK: %[[expanded:.*]] = memref.expand_shape %[[m1]] {{\[\[}}0, 1], [2]] output_shape [2, %[[VAL_1]], 10] : memref<?x10xf32> into memref<2x?x10xf32>
+ // CHECK: %[[expanded:.*]] = memref.expand_shape %[[m1]] {{\[\[}}0, 1], [2]] output_shape [2, %[[sz0]], 10] : memref<?x10xf32> into memref<2x?x10xf32>
%0 = tensor.expand_shape %t1 [[0, 1], [2]] output_shape [2, %sz0, 10]
: tensor<?x10xf32> into tensor<2x?x10xf32>
@@ -385,23 +381,20 @@ func.func @tensor.expand_shape(%t1: tensor<?x10xf32>, %sz0: index) -> tensor<2x?
// -----
// CHECK-LABEL: func @tensor.expand_shape_of_slice(
-// CHECK-SAME: %[[t1:.*]]: tensor<?x20xf32>
+// CHECK-SAME: %[[t1:.*]]: tensor<?x20xf32>, %{{.*}}: index, %{{.*}}: index, %[[sz0:.*]]: index
func.func @tensor.expand_shape_of_slice(
%t1: tensor<?x20xf32>, %o1: index, %s1: index, %sz0: index) -> tensor<?x7x2x5xf32> {
// CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] :
// CHECK: %[[subview:.*]] = memref.subview %[[m1]][%{{.*}}, 5] [%{{.*}}, 10] [1, 1] : memref<?x20xf32> to memref<?x10xf32, strided<[20, 1], offset: ?>>
%0 = tensor.extract_slice %t1[%o1, 5][%s1, 10][1, 1] :
tensor<?x20xf32> to tensor<?x10xf32>
- // CHECK: %[[C7:.*]] = arith.constant 7 : index
- // CHECK: %[[VAL_1:.*]] = arith.divsi %{{.*}}, %[[C7]] : index
- // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_1]], 7, 2, 5] : memref<?x10xf32, strided<[20, 1], offset: ?>> into memref<?x7x2x5xf32, strided<[140, 20, 5, 1], offset: ?>>
+ // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[sz0]], 7, 2, 5] : memref<?x10xf32, strided<[20, 1], offset: ?>> into memref<?x7x2x5xf32, strided<[140, 20, 5, 1], offset: ?>>
%1 = tensor.expand_shape %0 [[0, 1], [2, 3]] output_shape [%sz0, 7, 2, 5] :
tensor<?x10xf32> into tensor<?x7x2x5xf32>
// CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
// CHECK: return %[[r]]
return %1 : tensor<?x7x2x5xf32>
}
-
// -----
// CHECK-LABEL: func @tensor.expand_shape_of_scalar_slice(
@@ -417,7 +410,20 @@ func.func @tensor.expand_shape_of_scalar_slice(
// CHECK: return %[[r]]
return %1 : tensor<1xf32>
}
+// -----
+// CHECK-LABEL: func @tensor.expand_shape_multiple_dynamic_indices(
+// CHECK-SAME: %[[t1:.*]]: tensor<?x256xf32>, %[[sz0:.*]]: index, %[[sz1:.*]]: index, %[[sz2:.*]]: index
+func.func @tensor.expand_shape_multiple_dynamic_indices(%t1: tensor<?x256xf32>, %sz0: index, %sz1: index, %sz2: index) -> tensor<?x?x?x256xf32> {
+ // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]]
+ // CHECK: %[[expanded:.*]] = memref.expand_shape %[[m1]] {{\[\[}}0, 1, 2], [3]] output_shape [%[[sz0]], %[[sz1]], %[[sz2]], 256] : memref<?x256xf32> into memref<?x?x?x256xf32>
+ %0 = tensor.expand_shape %t1 [[0, 1, 2], [3]] output_shape [%sz0, %sz1, %sz2, 256]
+ : tensor<?x256xf32> into tensor<?x?x?x256xf32>
+
+ // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
+ // CHECK: return %[[r]]
+ return %0 : tensor<?x?x?x256xf32>
+}
// -----
// CHECK-LABEL: func @tensor.collapse_shape(
@@ -646,3 +652,6 @@ func.func @parallel_insert_slice_copy_before_write(%in: tensor<4xf32>, %out: ten
// CHECK: }
return
}
+
+// -----
+