[Mlir-commits] [mlir] [MLIR][Tensor] Enhance bufferization of tensor.expand_shape op (PR #128871)
Arnab Dutta
llvmlistbot at llvm.org
Wed Feb 26 21:15:23 PST 2025
https://github.com/arnab-polymage updated https://github.com/llvm/llvm-project/pull/128871
>From d9b25210f385acb7a090627383880d99ef4aa1d9 Mon Sep 17 00:00:00 2001
From: Arnab Dutta <arnab at polymagelabs.com>
Date: Thu, 27 Feb 2025 10:44:15 +0530
Subject: [PATCH] [MLIR][Tensor] Enhance bufferization of tensor.expand_shape
op
Instead of inferring the output shape argument of
memref.expand_shape op, use the output_shape argument of
tensor.expand_shape op. This adds dynamic dimension support
for bufferization of tensor.expand_shape when there is
more than one dynamic dim within a reassociation set.
---
.../BufferizableOpInterfaceImpl.cpp | 14 ++++-----
mlir/test/Dialect/Tensor/bufferize.mlir | 31 ++++++++++++-------
2 files changed, 26 insertions(+), 19 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 81404fa664cd4..a9ba662348a52 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -337,14 +337,12 @@ struct ExpandShapeOpInterface
if (failed(buffer))
return failure();
- // Memref result type is inferred by the builder based on reassociation
- // indices and result shape.
- // TODO: Instead of inferring the output shape argument of
- // memref.expand_shape op, use output_shape argument of tensor.expand_shape
- // op.
- replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
- rewriter, op, tensorResultType.getShape(), *buffer,
- expandShapeOp.getReassociationIndices());
+ auto memrefExpandShape = rewriter.create<memref::ExpandShapeOp>(
+ op->getLoc(), tensorResultType.getShape(), *buffer,
+ expandShapeOp.getReassociationIndices(),
+ expandShapeOp.getMixedOutputShape());
+ replaceOpWithBufferizedValues(rewriter, op,
+ memrefExpandShape->getResults());
return success();
}
};
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index 9ea0a15f31185..c1beed95f2006 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -366,14 +366,10 @@ func.func @tensor.insert(%t1: tensor<5xf32>, %idx1: index, %f: f32) -> tensor<5x
// -----
// CHECK-LABEL: func @tensor.expand_shape(
-// CHECK-SAME: %[[t1:.*]]: tensor<?x10xf32>
+// CHECK-SAME: %[[t1:.*]]: tensor<?x10xf32>, %[[sz0:.*]]: index
func.func @tensor.expand_shape(%t1: tensor<?x10xf32>, %sz0: index) -> tensor<2x?x10xf32> {
// CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]]
- // CHECK: %[[C0:.*]] = arith.constant 0 : index
- // CHECK: %[[DIM:.*]] = memref.dim %[[m1]], %[[C0]] : memref<?x10xf32>
- // CHECK: %[[C2:.*]] = arith.constant 2 : index
- // CHECK: %[[VAL_1:.*]] = arith.divsi %[[DIM]], %[[C2]] : index
- // CHECK: %[[expanded:.*]] = memref.expand_shape %[[m1]] {{\[\[}}0, 1], [2]] output_shape [2, %[[VAL_1]], 10] : memref<?x10xf32> into memref<2x?x10xf32>
+ // CHECK: %[[expanded:.*]] = memref.expand_shape %[[m1]] {{\[\[}}0, 1], [2]] output_shape [2, %[[sz0]], 10] : memref<?x10xf32> into memref<2x?x10xf32>
%0 = tensor.expand_shape %t1 [[0, 1], [2]] output_shape [2, %sz0, 10]
: tensor<?x10xf32> into tensor<2x?x10xf32>
@@ -385,23 +381,20 @@ func.func @tensor.expand_shape(%t1: tensor<?x10xf32>, %sz0: index) -> tensor<2x?
// -----
// CHECK-LABEL: func @tensor.expand_shape_of_slice(
-// CHECK-SAME: %[[t1:.*]]: tensor<?x20xf32>
+// CHECK-SAME: %[[t1:.*]]: tensor<?x20xf32>, %{{.*}}: index, %{{.*}}: index, %[[sz0:.*]]: index
func.func @tensor.expand_shape_of_slice(
%t1: tensor<?x20xf32>, %o1: index, %s1: index, %sz0: index) -> tensor<?x7x2x5xf32> {
// CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] :
// CHECK: %[[subview:.*]] = memref.subview %[[m1]][%{{.*}}, 5] [%{{.*}}, 10] [1, 1] : memref<?x20xf32> to memref<?x10xf32, strided<[20, 1], offset: ?>>
%0 = tensor.extract_slice %t1[%o1, 5][%s1, 10][1, 1] :
tensor<?x20xf32> to tensor<?x10xf32>
- // CHECK: %[[C7:.*]] = arith.constant 7 : index
- // CHECK: %[[VAL_1:.*]] = arith.divsi %{{.*}}, %[[C7]] : index
- // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_1]], 7, 2, 5] : memref<?x10xf32, strided<[20, 1], offset: ?>> into memref<?x7x2x5xf32, strided<[140, 20, 5, 1], offset: ?>>
+ // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[sz0]], 7, 2, 5] : memref<?x10xf32, strided<[20, 1], offset: ?>> into memref<?x7x2x5xf32, strided<[140, 20, 5, 1], offset: ?>>
%1 = tensor.expand_shape %0 [[0, 1], [2, 3]] output_shape [%sz0, 7, 2, 5] :
tensor<?x10xf32> into tensor<?x7x2x5xf32>
// CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
// CHECK: return %[[r]]
return %1 : tensor<?x7x2x5xf32>
}
-
// -----
// CHECK-LABEL: func @tensor.expand_shape_of_scalar_slice(
@@ -417,7 +410,20 @@ func.func @tensor.expand_shape_of_scalar_slice(
// CHECK: return %[[r]]
return %1 : tensor<1xf32>
}
+// -----
+// CHECK-LABEL: func @tensor.expand_shape_multiple_dynamic_indices(
+// CHECK-SAME: %[[t1:.*]]: tensor<?x256xf32>, %[[sz0:.*]]: index, %[[sz1:.*]]: index, %[[sz2:.*]]: index
+func.func @tensor.expand_shape_multiple_dynamic_indices(%t1: tensor<?x256xf32>, %sz0: index, %sz1: index, %sz2: index) -> tensor<?x?x?x256xf32> {
+ // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]]
+ // CHECK: %[[expanded:.*]] = memref.expand_shape %[[m1]] {{\[\[}}0, 1, 2], [3]] output_shape [%[[sz0]], %[[sz1]], %[[sz2]], 256] : memref<?x256xf32> into memref<?x?x?x256xf32>
+ %0 = tensor.expand_shape %t1 [[0, 1, 2], [3]] output_shape [%sz0, %sz1, %sz2, 256]
+ : tensor<?x256xf32> into tensor<?x?x?x256xf32>
+
+ // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
+ // CHECK: return %[[r]]
+ return %0 : tensor<?x?x?x256xf32>
+}
// -----
// CHECK-LABEL: func @tensor.collapse_shape(
@@ -646,3 +652,6 @@ func.func @parallel_insert_slice_copy_before_write(%in: tensor<4xf32>, %out: ten
// CHECK: }
return
}
+
+// -----
+
More information about the Mlir-commits
mailing list