[Mlir-commits] [mlir] [mlir][Tensor] Use output_shape for DimOp->ExpandShapeOp folding (PR #118203)
Kunwar Grover
llvmlistbot at llvm.org
Sun Dec 1 02:07:30 PST 2024
https://github.com/Groverkss created https://github.com/llvm/llvm-project/pull/118203
We already have the output shape available in the operation, so there is no need to do any arithmetic to figure it out. This PR makes the tensor.dim folding directly use the available output shape.
From dcb87b2c2cab37abdad57ad669c428bc6a6287c6 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Sun, 1 Dec 2024 10:05:58 +0000
Subject: [PATCH] [mlir][Tensor] Use output_shape for DimOp->ExpandShapeOp
folding
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 32 ++++------------------
mlir/test/Dialect/Tensor/canonicalize.mlir | 8 ++----
2 files changed, 8 insertions(+), 32 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 616d4a7d0a0ab5..a6ae728b20fa47 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1971,32 +1971,12 @@ struct FoldDimOfExpandShape : public OpRewritePattern<DimOp> {
if (!dim.has_value())
return failure();
- // Skip static dims. These are folded to constant ops.
- RankedTensorType resultType = expandShapeOp.getResultType();
- if (!resultType.isDynamicDim(*dim))
- return failure();
-
- // Find reassociation group that contains this result dimension.
- int64_t srcDim = expandShapeOp.getCorrespondingSourceDim(*dim);
-
- // `dim` is the only dynamic dimension in `group`. (Otherwise, the
- // ExpandShapeOp would be ambiguous.)
- int64_t product = 1;
- ReassociationIndices grp = expandShapeOp.getReassociationIndices()[srcDim];
- for (int64_t d : grp) {
- if (d != dim) {
- assert(!resultType.isDynamicDim(d) && "expected static dim");
- product *= resultType.getDimSize(d);
- }
- }
-
- // result dim size = src dim size / (product(other dims in reassoc group))
- Value srcDimSz =
- rewriter.create<DimOp>(dimOp.getLoc(), expandShapeOp.getSrc(), srcDim);
- AffineExpr expr;
- bindSymbols(dimOp.getContext(), expr);
- rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(
- dimOp, expr.floorDiv(product), srcDimSz);
+ SmallVector<OpFoldResult> outputShape =
+ getMixedValues(expandShapeOp.getStaticOutputShape(),
+ expandShapeOp.getOutputShape(), rewriter);
+ OpFoldResult outputDim = outputShape[dim.value()];
+ rewriter.replaceOp(dimOp, getValueOrCreateConstantIndexOp(
+ rewriter, dimOp.getLoc(), outputDim));
return success();
}
};
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 0b54c207dea84e..3a0f8e0e073acd 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2278,13 +2278,9 @@ func.func @empty_tensor_canonicalize(%i : index) {
// -----
-// CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 floordiv 40)>
// CHECK-LABEL: func @dim_of_expand_shape(
-// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>
-// CHECK: %[[c1:.*]] = arith.constant 1 : index
-// CHECK: %[[dim:.*]] = tensor.dim %[[t]], %[[c1]] : tensor<?x?xf32>
-// CHECK: %[[apply:.*]] = affine.apply #[[$map]]()[%[[dim]]]
-// CHECK: return %[[apply]]
+// CHECK-SAME: %{{.*}}: tensor<?x?xf32>, %{{.*}}: index, %[[ARG2:.+]]: index
+// CHECK: return %[[ARG2]]
func.func @dim_of_expand_shape(%t: tensor<?x?xf32>, %sz0: index, %sz1: index) -> index {
%c2 = arith.constant 2 : index
%0 = tensor.expand_shape %t [[0], [1, 2, 3, 4, 5]] output_shape [%sz0, 1, %sz1, 5, 1, 8]
More information about the Mlir-commits mailing list.