[Mlir-commits] [mlir] [mlir][Tensor] Use output_shape for DimOp->ExpandShapeOp folding (PR #118203)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Dec 1 02:08:02 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-tensor
Author: Kunwar Grover (Groverkss)
<details>
<summary>Changes</summary>
The `tensor.expand_shape` op already carries its result shape in the `output_shape` operand, so there is no need to recompute it arithmetically (via `tensor.dim` on the source plus an `affine.apply` floordiv). This PR makes the `tensor.dim` folding use the available `output_shape` values directly.
---
Full diff: https://github.com/llvm/llvm-project/pull/118203.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+6-26)
- (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+2-6)
``````````diff
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 616d4a7d0a0ab5..a6ae728b20fa47 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1971,32 +1971,12 @@ struct FoldDimOfExpandShape : public OpRewritePattern<DimOp> {
if (!dim.has_value())
return failure();
- // Skip static dims. These are folded to constant ops.
- RankedTensorType resultType = expandShapeOp.getResultType();
- if (!resultType.isDynamicDim(*dim))
- return failure();
-
- // Find reassociation group that contains this result dimension.
- int64_t srcDim = expandShapeOp.getCorrespondingSourceDim(*dim);
-
- // `dim` is the only dynamic dimension in `group`. (Otherwise, the
- // ExpandShapeOp would be ambiguous.)
- int64_t product = 1;
- ReassociationIndices grp = expandShapeOp.getReassociationIndices()[srcDim];
- for (int64_t d : grp) {
- if (d != dim) {
- assert(!resultType.isDynamicDim(d) && "expected static dim");
- product *= resultType.getDimSize(d);
- }
- }
-
- // result dim size = src dim size / (product(other dims in reassoc group))
- Value srcDimSz =
- rewriter.create<DimOp>(dimOp.getLoc(), expandShapeOp.getSrc(), srcDim);
- AffineExpr expr;
- bindSymbols(dimOp.getContext(), expr);
- rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(
- dimOp, expr.floorDiv(product), srcDimSz);
+ SmallVector<OpFoldResult> outputShape =
+ getMixedValues(expandShapeOp.getStaticOutputShape(),
+ expandShapeOp.getOutputShape(), rewriter);
+ OpFoldResult outputDim = outputShape[dim.value()];
+ rewriter.replaceOp(dimOp, getValueOrCreateConstantIndexOp(
+ rewriter, dimOp.getLoc(), outputDim));
return success();
}
};
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 0b54c207dea84e..3a0f8e0e073acd 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2278,13 +2278,9 @@ func.func @empty_tensor_canonicalize(%i : index) {
// -----
-// CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 floordiv 40)>
// CHECK-LABEL: func @dim_of_expand_shape(
-// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>
-// CHECK: %[[c1:.*]] = arith.constant 1 : index
-// CHECK: %[[dim:.*]] = tensor.dim %[[t]], %[[c1]] : tensor<?x?xf32>
-// CHECK: %[[apply:.*]] = affine.apply #[[$map]]()[%[[dim]]]
-// CHECK: return %[[apply]]
+// CHECK-SAME: %{{.*}}: tensor<?x?xf32>, %{{.*}}: index, %[[ARG2:.+]]: index
+// CHECK: return %[[ARG2]]
func.func @dim_of_expand_shape(%t: tensor<?x?xf32>, %sz0: index, %sz1: index) -> index {
%c2 = arith.constant 2 : index
%0 = tensor.expand_shape %t [[0], [1, 2, 3, 4, 5]] output_shape [%sz0, 1, %sz1, 5, 1, 8]
``````````
</details>
https://github.com/llvm/llvm-project/pull/118203
More information about the Mlir-commits
mailing list