[Mlir-commits] [mlir] [MLIR][Tensor] Enhance bufferization of tensor.expand_shape op (PR #128871)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Feb 26 04:36:50 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Arnab Dutta  (arnab-polymage)

<details>
<summary>Changes</summary>

Instead of inferring the output shape argument of the
memref.expand_shape op, use the output_shape argument of the tensor.expand_shape op. This adds support for bufferizing tensor.expand_shape when there is more than one dynamic dimension within a reassociation group.

---
Full diff: https://github.com/llvm/llvm-project/pull/128871.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+21-8) 
- (modified) mlir/test/Dialect/Tensor/bufferize.mlir (+20-11) 


``````````diff
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 81404fa664cd4..efbe09f4d2419 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -337,14 +337,16 @@ struct ExpandShapeOpInterface
     if (failed(buffer))
       return failure();
 
-    // Memref result type is inferred by the builder based on reassociation
-    // indices and result shape.
-    // TODO: Instead of inferring the output shape argument of
-    // memref.expand_shape op, use output_shape argument of tensor.expand_shape
-    // op.
-    replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
-        rewriter, op, tensorResultType.getShape(), *buffer,
-        expandShapeOp.getReassociationIndices());
+    // The memref result type is inferred by the builder from the
+    // reassociation indices and the result shape. The output_shape is
+    // forwarded from tensor.expand_shape in its mixed (static + dynamic)
+    // form, which is aligned with the result dimensions. No shape inference
+    // is needed, so reassociation groups with more than one dynamic dim are
+    // supported.
+    replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
+        rewriter, op, tensorResultType.getShape(), *buffer,
+        expandShapeOp.getReassociationIndices(),
+        expandShapeOp.getMixedOutputShape());
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index 9ea0a15f31185..c1beed95f2006 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -366,14 +366,10 @@ func.func @tensor.insert(%t1: tensor<5xf32>, %idx1: index, %f: f32) -> tensor<5x
 // -----

 // CHECK-LABEL: func @tensor.expand_shape(
-//  CHECK-SAME:     %[[t1:.*]]: tensor<?x10xf32>
+//  CHECK-SAME:     %[[t1:.*]]: tensor<?x10xf32>, %[[sz0:.*]]: index
 func.func @tensor.expand_shape(%t1: tensor<?x10xf32>, %sz0: index) -> tensor<2x?x10xf32> {
   // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]]
-  // CHECK: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK: %[[DIM:.*]] = memref.dim %[[m1]], %[[C0]] : memref<?x10xf32>
-  // CHECK: %[[C2:.*]] = arith.constant 2 : index
-  // CHECK: %[[VAL_1:.*]] = arith.divsi %[[DIM]], %[[C2]] : index
-  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[m1]] {{\[\[}}0, 1], [2]] output_shape [2, %[[VAL_1]], 10] : memref<?x10xf32> into memref<2x?x10xf32>
+  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[m1]] {{\[\[}}0, 1], [2]] output_shape [2, %[[sz0]], 10] : memref<?x10xf32> into memref<2x?x10xf32>
   %0 = tensor.expand_shape %t1 [[0, 1], [2]] output_shape [2, %sz0, 10]
       : tensor<?x10xf32> into tensor<2x?x10xf32>

@@ -385,23 +381,21 @@ func.func @tensor.expand_shape(%t1: tensor<?x10xf32>, %sz0: index) -> tensor<2x?
 // -----

 // CHECK-LABEL: func @tensor.expand_shape_of_slice(
-//  CHECK-SAME:     %[[t1:.*]]: tensor<?x20xf32>
+//  CHECK-SAME:     %[[t1:.*]]: tensor<?x20xf32>, %{{.*}}: index, %{{.*}}: index, %[[sz0:.*]]: index
 func.func @tensor.expand_shape_of_slice(
     %t1: tensor<?x20xf32>, %o1: index, %s1: index, %sz0: index) -> tensor<?x7x2x5xf32> {
   // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] :
   // CHECK: %[[subview:.*]] = memref.subview %[[m1]][%{{.*}}, 5] [%{{.*}}, 10] [1, 1] : memref<?x20xf32> to memref<?x10xf32, strided<[20, 1], offset: ?>>
   %0 = tensor.extract_slice %t1[%o1, 5][%s1, 10][1, 1] :
       tensor<?x20xf32> to tensor<?x10xf32>
-  // CHECK: %[[C7:.*]] = arith.constant 7 : index
-  // CHECK: %[[VAL_1:.*]] = arith.divsi %{{.*}}, %[[C7]] : index
-  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_1]], 7, 2, 5] : memref<?x10xf32, strided<[20, 1], offset: ?>> into memref<?x7x2x5xf32, strided<[140, 20, 5, 1], offset: ?>>
+  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[sz0]], 7, 2, 5] : memref<?x10xf32, strided<[20, 1], offset: ?>> into memref<?x7x2x5xf32, strided<[140, 20, 5, 1], offset: ?>>
   %1 = tensor.expand_shape %0 [[0, 1], [2, 3]] output_shape [%sz0, 7, 2, 5] :
       tensor<?x10xf32> into tensor<?x7x2x5xf32>
   // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
   // CHECK: return %[[r]]
   return %1 : tensor<?x7x2x5xf32>
 }

 // -----

 // CHECK-LABEL: func @tensor.expand_shape_of_scalar_slice(
@@ -417,7 +411,22 @@ func.func @tensor.expand_shape_of_scalar_slice(
   // CHECK: return %[[r]]
   return %1 : tensor<1xf32>
 }

+// -----
+
+// CHECK-LABEL: func @tensor.expand_shape_multiple_dynamic_indices(
+// CHECK-SAME: %[[t1:.*]]: tensor<?x256xf32>, %[[sz0:.*]]: index, %[[sz1:.*]]: index, %[[sz2:.*]]: index
+func.func @tensor.expand_shape_multiple_dynamic_indices(%t1: tensor<?x256xf32>, %sz0: index, %sz1: index, %sz2: index) -> tensor<?x?x?x256xf32> {
+  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]]
+  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[m1]] {{\[\[}}0, 1, 2], [3]] output_shape [%[[sz0]], %[[sz1]], %[[sz2]], 256] : memref<?x256xf32> into memref<?x?x?x256xf32>
+  %0 = tensor.expand_shape %t1 [[0, 1, 2], [3]] output_shape [%sz0, %sz1, %sz2, 256]
+      : tensor<?x256xf32> into tensor<?x?x?x256xf32>
+
+  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
+  // CHECK: return %[[r]]
+  return %0 : tensor<?x?x?x256xf32>
+}
+
 // -----

 // CHECK-LABEL: func @tensor.collapse_shape(

``````````

</details>


https://github.com/llvm/llvm-project/pull/128871


More information about the Mlir-commits mailing list