[Mlir-commits] [mlir] 3cccb20 - [MLIR][Tensor] Enhance bufferization of tensor.expand_shape op (#128871)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 27 21:15:42 PST 2025


Author: Arnab Dutta
Date: 2025-02-28T10:45:38+05:30
New Revision: 3cccb2017ff96d67b0e737eeddb58ff054cedc6e

URL: https://github.com/llvm/llvm-project/commit/3cccb2017ff96d67b0e737eeddb58ff054cedc6e
DIFF: https://github.com/llvm/llvm-project/commit/3cccb2017ff96d67b0e737eeddb58ff054cedc6e.diff

LOG: [MLIR][Tensor] Enhance bufferization of tensor.expand_shape op (#128871)

Instead of inferring the output_shape argument of the
memref.expand_shape op, use the output_shape argument of the
tensor.expand_shape op. This adds dynamic-dimension support to the
bufferization of tensor.expand_shape when there is more than one dynamic
dimension within a reassociation set.

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/test/Dialect/Tensor/bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 81404fa664cd4..a9ba662348a52 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -337,14 +337,12 @@ struct ExpandShapeOpInterface
     if (failed(buffer))
       return failure();
 
-    // Memref result type is inferred by the builder based on reassociation
-    // indices and result shape.
-    // TODO: Instead of inferring the output shape argument of
-    // memref.expand_shape op, use output_shape argument of tensor.expand_shape
-    // op.
-    replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
-        rewriter, op, tensorResultType.getShape(), *buffer,
-        expandShapeOp.getReassociationIndices());
+    auto memrefExpandShape = rewriter.create<memref::ExpandShapeOp>(
+        op->getLoc(), tensorResultType.getShape(), *buffer,
+        expandShapeOp.getReassociationIndices(),
+        expandShapeOp.getMixedOutputShape());
+    replaceOpWithBufferizedValues(rewriter, op,
+                                  memrefExpandShape->getResults());
     return success();
   }
 };

diff  --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index 9ea0a15f31185..c1beed95f2006 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -366,14 +366,10 @@ func.func @tensor.insert(%t1: tensor<5xf32>, %idx1: index, %f: f32) -> tensor<5x
 // -----
 
 // CHECK-LABEL: func @tensor.expand_shape(
-//  CHECK-SAME:     %[[t1:.*]]: tensor<?x10xf32>
+//  CHECK-SAME:     %[[t1:.*]]: tensor<?x10xf32>, %[[sz0:.*]]: index
 func.func @tensor.expand_shape(%t1: tensor<?x10xf32>, %sz0: index) -> tensor<2x?x10xf32> {
   // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]]
-  // CHECK: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK: %[[DIM:.*]] = memref.dim %[[m1]], %[[C0]] : memref<?x10xf32>
-  // CHECK: %[[C2:.*]] = arith.constant 2 : index
-  // CHECK: %[[VAL_1:.*]] = arith.divsi %[[DIM]], %[[C2]] : index
-  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[m1]] {{\[\[}}0, 1], [2]] output_shape [2, %[[VAL_1]], 10] : memref<?x10xf32> into memref<2x?x10xf32>
+  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[m1]] {{\[\[}}0, 1], [2]] output_shape [2, %[[sz0]], 10] : memref<?x10xf32> into memref<2x?x10xf32>
   %0 = tensor.expand_shape %t1 [[0, 1], [2]] output_shape [2, %sz0, 10]
       : tensor<?x10xf32> into tensor<2x?x10xf32>
 
@@ -385,23 +381,20 @@ func.func @tensor.expand_shape(%t1: tensor<?x10xf32>, %sz0: index) -> tensor<2x?
 // -----
 
 // CHECK-LABEL: func @tensor.expand_shape_of_slice(
-//  CHECK-SAME:     %[[t1:.*]]: tensor<?x20xf32>
+//  CHECK-SAME:     %[[t1:.*]]: tensor<?x20xf32>, %{{.*}}: index, %{{.*}}: index, %[[sz0:.*]]: index
 func.func @tensor.expand_shape_of_slice(
     %t1: tensor<?x20xf32>, %o1: index, %s1: index, %sz0: index) -> tensor<?x7x2x5xf32> {
   // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] :
   // CHECK: %[[subview:.*]] = memref.subview %[[m1]][%{{.*}}, 5] [%{{.*}}, 10] [1, 1] : memref<?x20xf32> to memref<?x10xf32, strided<[20, 1], offset: ?>>
   %0 = tensor.extract_slice %t1[%o1, 5][%s1, 10][1, 1] :
       tensor<?x20xf32> to tensor<?x10xf32>
-  // CHECK: %[[C7:.*]] = arith.constant 7 : index
-  // CHECK: %[[VAL_1:.*]] = arith.divsi %{{.*}}, %[[C7]] : index
-  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_1]], 7, 2, 5] : memref<?x10xf32, strided<[20, 1], offset: ?>> into memref<?x7x2x5xf32, strided<[140, 20, 5, 1], offset: ?>>
+  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[sz0]], 7, 2, 5] : memref<?x10xf32, strided<[20, 1], offset: ?>> into memref<?x7x2x5xf32, strided<[140, 20, 5, 1], offset: ?>>
   %1 = tensor.expand_shape %0 [[0, 1], [2, 3]] output_shape [%sz0, 7, 2, 5] :
       tensor<?x10xf32> into tensor<?x7x2x5xf32>
   // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
   // CHECK: return %[[r]]
   return %1 : tensor<?x7x2x5xf32>
 }
-
 // -----
 
 // CHECK-LABEL: func @tensor.expand_shape_of_scalar_slice(
@@ -417,7 +410,20 @@ func.func @tensor.expand_shape_of_scalar_slice(
   // CHECK: return %[[r]]
   return %1 : tensor<1xf32>
 }
+// -----
 
+// CHECK-LABEL: func @tensor.expand_shape_multiple_dynamic_indices(
+// CHECK-SAME: %[[t1:.*]]: tensor<?x256xf32>, %[[sz0:.*]]: index, %[[sz1:.*]]: index, %[[sz2:.*]]: index
+func.func @tensor.expand_shape_multiple_dynamic_indices(%t1: tensor<?x256xf32>, %sz0: index, %sz1: index, %sz2: index) -> tensor<?x?x?x256xf32> {
+  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]]
+  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[m1]] {{\[\[}}0, 1, 2], [3]] output_shape [%[[sz0]], %[[sz1]], %[[sz2]], 256] : memref<?x256xf32> into memref<?x?x?x256xf32>
+  %0 = tensor.expand_shape %t1 [[0, 1, 2], [3]] output_shape [%sz0, %sz1, %sz2, 256]
+      : tensor<?x256xf32> into tensor<?x?x?x256xf32>
+
+  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
+  // CHECK: return %[[r]]
+  return %0 : tensor<?x?x?x256xf32>
+}
 // -----
 
 // CHECK-LABEL: func @tensor.collapse_shape(
@@ -646,3 +652,6 @@ func.func @parallel_insert_slice_copy_before_write(%in: tensor<4xf32>, %out: ten
   // CHECK: }
   return
 }
+
+// -----
+


        


More information about the Mlir-commits mailing list