[Mlir-commits] [mlir] [MLIR][Tensor] Enhance bufferization of tensor.expand_shape op (PR #128871)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Feb 26 04:36:50 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Arnab Dutta (arnab-polymage)
<details>
<summary>Changes</summary>
Instead of inferring the output shape argument of the
memref.expand_shape op, use the output_shape argument of the tensor.expand_shape op, adding support for bufferization of tensor.expand_shape when there is more than one dynamic dim within a reassociation set.
---
Full diff: https://github.com/llvm/llvm-project/pull/128871.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+21-8)
- (modified) mlir/test/Dialect/Tensor/bufferize.mlir (+20-11)
``````````diff
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 81404fa664cd4..efbe09f4d2419 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -337,14 +337,27 @@ struct ExpandShapeOpInterface
if (failed(buffer))
return failure();
- // Memref result type is inferred by the builder based on reassociation
- // indices and result shape.
- // TODO: Instead of inferring the output shape argument of
- // memref.expand_shape op, use output_shape argument of tensor.expand_shape
- // op.
- replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
- rewriter, op, tensorResultType.getShape(), *buffer,
- expandShapeOp.getReassociationIndices());
+ // Use output_shape argument of tensor.expand_shape op to get the result
+ // shapes of the memref.expand_shape op to be created.
+ SmallVector<OpFoldResult> outShape;
+ unsigned dynDimCount = 0;
+ for (unsigned i = 0, e = tensorResultType.getRank(); i < e; i++) {
+ if (tensorResultType.isDynamicDim(i))
+ outShape.push_back(expandShapeOp.getOutputShape()[dynDimCount++]);
+ }
+ auto memrefExpandShape = rewriter.create<memref::ExpandShapeOp>(
+ op->getLoc(), tensorResultType.getShape(), *buffer,
+ expandShapeOp.getReassociationIndices(), outShape);
+ SmallVector<int64_t> staticShape;
+ for (unsigned i = 0, e = tensorResultType.getRank(); i < e; i++) {
+ if (tensorResultType.isDynamicDim(i))
+ staticShape.push_back(ShapedType::kDynamic);
+ else
+ staticShape.push_back(tensorResultType.getDimSize(i));
+ }
+ memrefExpandShape.setStaticOutputShape(staticShape);
+ replaceOpWithBufferizedValues(rewriter, op,
+ memrefExpandShape->getResults());
return success();
}
};
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index 9ea0a15f31185..c1beed95f2006 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -366,14 +366,10 @@ func.func @tensor.insert(%t1: tensor<5xf32>, %idx1: index, %f: f32) -> tensor<5x
// -----
// CHECK-LABEL: func @tensor.expand_shape(
-// CHECK-SAME: %[[t1:.*]]: tensor<?x10xf32>
+// CHECK-SAME: %[[t1:.*]]: tensor<?x10xf32>, %[[sz0:.*]]: index
func.func @tensor.expand_shape(%t1: tensor<?x10xf32>, %sz0: index) -> tensor<2x?x10xf32> {
// CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]]
- // CHECK: %[[C0:.*]] = arith.constant 0 : index
- // CHECK: %[[DIM:.*]] = memref.dim %[[m1]], %[[C0]] : memref<?x10xf32>
- // CHECK: %[[C2:.*]] = arith.constant 2 : index
- // CHECK: %[[VAL_1:.*]] = arith.divsi %[[DIM]], %[[C2]] : index
- // CHECK: %[[expanded:.*]] = memref.expand_shape %[[m1]] {{\[\[}}0, 1], [2]] output_shape [2, %[[VAL_1]], 10] : memref<?x10xf32> into memref<2x?x10xf32>
+ // CHECK: %[[expanded:.*]] = memref.expand_shape %[[m1]] {{\[\[}}0, 1], [2]] output_shape [2, %[[sz0]], 10] : memref<?x10xf32> into memref<2x?x10xf32>
%0 = tensor.expand_shape %t1 [[0, 1], [2]] output_shape [2, %sz0, 10]
: tensor<?x10xf32> into tensor<2x?x10xf32>
@@ -385,23 +381,20 @@ func.func @tensor.expand_shape(%t1: tensor<?x10xf32>, %sz0: index) -> tensor<2x?
// -----
// CHECK-LABEL: func @tensor.expand_shape_of_slice(
-// CHECK-SAME: %[[t1:.*]]: tensor<?x20xf32>
+// CHECK-SAME: %[[t1:.*]]: tensor<?x20xf32>, %{{.*}}: index, %{{.*}}: index, %[[sz0:.*]]: index
func.func @tensor.expand_shape_of_slice(
%t1: tensor<?x20xf32>, %o1: index, %s1: index, %sz0: index) -> tensor<?x7x2x5xf32> {
// CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] :
// CHECK: %[[subview:.*]] = memref.subview %[[m1]][%{{.*}}, 5] [%{{.*}}, 10] [1, 1] : memref<?x20xf32> to memref<?x10xf32, strided<[20, 1], offset: ?>>
%0 = tensor.extract_slice %t1[%o1, 5][%s1, 10][1, 1] :
tensor<?x20xf32> to tensor<?x10xf32>
- // CHECK: %[[C7:.*]] = arith.constant 7 : index
- // CHECK: %[[VAL_1:.*]] = arith.divsi %{{.*}}, %[[C7]] : index
- // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_1]], 7, 2, 5] : memref<?x10xf32, strided<[20, 1], offset: ?>> into memref<?x7x2x5xf32, strided<[140, 20, 5, 1], offset: ?>>
+ // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[sz0]], 7, 2, 5] : memref<?x10xf32, strided<[20, 1], offset: ?>> into memref<?x7x2x5xf32, strided<[140, 20, 5, 1], offset: ?>>
%1 = tensor.expand_shape %0 [[0, 1], [2, 3]] output_shape [%sz0, 7, 2, 5] :
tensor<?x10xf32> into tensor<?x7x2x5xf32>
// CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
// CHECK: return %[[r]]
return %1 : tensor<?x7x2x5xf32>
}
-
// -----
// CHECK-LABEL: func @tensor.expand_shape_of_scalar_slice(
@@ -417,7 +410,20 @@ func.func @tensor.expand_shape_of_scalar_slice(
// CHECK: return %[[r]]
return %1 : tensor<1xf32>
}
+// -----
+// CHECK-LABEL: func @tensor.expand_shape_multiple_dynamic_indices(
+// CHECK-SAME: %[[t1:.*]]: tensor<?x256xf32>, %[[sz0:.*]]: index, %[[sz1:.*]]: index, %[[sz2:.*]]: index
+func.func @tensor.expand_shape_multiple_dynamic_indices(%t1: tensor<?x256xf32>, %sz0: index, %sz1: index, %sz2: index) -> tensor<?x?x?x256xf32> {
+ // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]]
+ // CHECK: %[[expanded:.*]] = memref.expand_shape %[[m1]] {{\[\[}}0, 1, 2], [3]] output_shape [%[[sz0]], %[[sz1]], %[[sz2]], 256] : memref<?x256xf32> into memref<?x?x?x256xf32>
+ %0 = tensor.expand_shape %t1 [[0, 1, 2], [3]] output_shape [%sz0, %sz1, %sz2, 256]
+ : tensor<?x256xf32> into tensor<?x?x?x256xf32>
+
+ // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
+ // CHECK: return %[[r]]
+ return %0 : tensor<?x?x?x256xf32>
+}
// -----
// CHECK-LABEL: func @tensor.collapse_shape(
@@ -646,3 +652,6 @@ func.func @parallel_insert_slice_copy_before_write(%in: tensor<4xf32>, %out: ten
// CHECK: }
return
}
+
+// -----
+
``````````
</details>
https://github.com/llvm/llvm-project/pull/128871
More information about the Mlir-commits
mailing list