[Mlir-commits] [mlir] [MLIR][Tensor] Enhance bufferization of tensor.expand_shape op (PR #128871)

llvmlistbot at llvm.org
Wed Feb 26 10:54:46 PST 2025


================
@@ -337,14 +337,27 @@ struct ExpandShapeOpInterface
     if (failed(buffer))
       return failure();
 
-    // Memref result type is inferred by the builder based on reassociation
-    // indices and result shape.
-    // TODO: Instead of inferring the output shape argument of
-    // memref.expand_shape op, use output_shape argument of tensor.expand_shape
-    // op.
-    replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
-        rewriter, op, tensorResultType.getShape(), *buffer,
-        expandShapeOp.getReassociationIndices());
+    // Use output_shape argument of tensor.expand_shape op to get the result
+    // shapes of the memref.expand_shape op to be created.
+    SmallVector<OpFoldResult> outShape;
+    unsigned dynDimCount = 0;
+    for (unsigned i = 0, e = tensorResultType.getRank(); i < e; i++) {
+      if (tensorResultType.isDynamicDim(i))
+        outShape.push_back(expandShapeOp.getOutputShape()[dynDimCount++]);
+    }
+    auto memrefExpandShape = rewriter.create<memref::ExpandShapeOp>(
+        op->getLoc(), tensorResultType.getShape(), *buffer,
+        expandShapeOp.getReassociationIndices(), outShape);
+    SmallVector<int64_t> staticShape;
----------------
MaheshRavishankar wrote:

You can use `dispatchIndexOpFoldResult`.
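
For context, a minimal sketch of how that helper could replace a hand-written static/dynamic split (assuming `outShape` is the mixed `OpFoldResult` vector built above; `dynamicShape` is an illustrative name, not taken from the patch):

```cpp
#include "mlir/Dialect/Utils/StaticValueUtils.h"

// Sketch only: split the mixed OpFoldResult shape into its static and dynamic
// parts. Constant entries land in `staticShape`; SSA values land in
// `dynamicShape`, with ShapedType::kDynamic recorded as the static
// placeholder for those positions.
SmallVector<int64_t> staticShape;
SmallVector<Value> dynamicShape;
dispatchIndexOpFoldResults(outShape, dynamicShape, staticShape);
```

The plural `dispatchIndexOpFoldResults` walks the whole vector; the singular form named above does the same dispatch for a single `OpFoldResult`.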

https://github.com/llvm/llvm-project/pull/128871

