[Mlir-commits] [mlir] 092372d - [mlir][Tensor] Rework `ReifyRankedShapedTypeInterface` implementation for `tensor.expand_shape` op. (#113501)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jan 27 07:05:38 PST 2025
Author: MaheshRavishankar
Date: 2025-01-27T07:05:34-08:00
New Revision: 092372da15e5165be14cdbb7cac3cf4976fd82d0
URL: https://github.com/llvm/llvm-project/commit/092372da15e5165be14cdbb7cac3cf4976fd82d0
DIFF: https://github.com/llvm/llvm-project/commit/092372da15e5165be14cdbb7cac3cf4976fd82d0.diff
LOG: [mlir][Tensor] Rework `ReifyRankedShapedTypeInterface` implementation for `tensor.expand_shape` op. (#113501)
The op carries its output shape as an operand, so the shape can be used directly when reifying result shapes.
Also adds a method to get the shape as a `SmallVector<OpFoldResult>`.
Signed-off-by: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Added:
Modified:
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/lib/Dialect/Utils/StaticValueUtils.cpp
mlir/lib/Interfaces/InferTypeOpInterface.cpp
mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
mlir/test/Dialect/Tensor/fold-empty-op.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 8ad1b23cb2bfe2..3ef7c74fd3af16 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1165,6 +1165,9 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
let extraClassDeclaration = commonExtraClassDeclaration # [{
int64_t getCorrespondingSourceDim(int64_t resultDim);
+ // Return the output shape as mixed static/dynamic values.
+ SmallVector<OpFoldResult> getMixedOutputShape();
+
// Infer the output shape for a tensor.expand_shape when it is possible
// to do so.
static FailureOr<SmallVector<OpFoldResult>> inferOutputShape(
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index d1f7ab1156248f..2a3a2defb810da 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -144,6 +144,9 @@ bool isEqualConstantIntOrValueArray(ArrayRef<OpFoldResult> ofrs1,
/// Return a vector of OpFoldResults with the same size a staticValues, but
/// all elements for which ShapedType::isDynamic is true, will be replaced by
/// dynamicValues.
+SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
+ ValueRange dynamicValues,
+ MLIRContext *context);
SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
ValueRange dynamicValues, Builder &b);
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
index 7ff435a033985c..f6fea08e2e717f 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
@@ -16,24 +16,6 @@
using namespace mlir;
using namespace mlir::tensor;
-/// Compute a map that for a given dimension of the expanded type gives the
-/// dimension in the collapsed type it maps to. Essentially its the inverse of
-/// the `reassocation` maps.
-static llvm::DenseMap<int64_t, int64_t>
-getExpandedDimToCollapsedDimMap(ArrayRef<AffineMap> reassociation) {
- llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
- for (const auto &map : enumerate(reassociation)) {
- unsigned startPos =
- cast<AffineDimExpr>(map.value().getResults().front()).getPosition();
- unsigned endPos =
- cast<AffineDimExpr>(map.value().getResults().back()).getPosition();
- for (auto dim : llvm::seq_inclusive(startPos, endPos)) {
- expandedDimToCollapsedDim[dim] = map.index();
- }
- }
- return expandedDimToCollapsedDim;
-}
-
/// For reshape op compute the shape at dimension `dimIndex` of the output in
/// terms of shape of the `src`, when the reshape op is a collapsing
/// operation. It is the product of the shape of the collapsed dimensions of the
@@ -76,84 +58,15 @@ static SmallVector<OpFoldResult, 4> getCollapsedOutputShapeFromInputShape(
}));
}
-/// For an expanding reshape op, compute the value for a dimension of the output
-/// from the shape of the input.
-static OpFoldResult getExpandedOutputDimFromInputShape(
- OpBuilder &builder, Location loc, int64_t dimIndex, Value src,
- ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation,
- llvm::DenseMap<int64_t, int64_t> &expandedDimToCollapsedDim) {
- if (!ShapedType::isDynamic(dstStaticShape[dimIndex])) {
- // Static dimension: return Attribute.
- return builder.getIndexAttr(dstStaticShape[dimIndex]);
- }
- unsigned sourceDimPos = expandedDimToCollapsedDim[dimIndex];
- unsigned startPos =
- cast<AffineDimExpr>(reassociation[sourceDimPos].getResults().front())
- .getPosition();
- unsigned endPos =
- cast<AffineDimExpr>(reassociation[sourceDimPos].getResults().back())
- .getPosition();
- int64_t linearizedStaticDim = 1;
- for (auto d :
- llvm::enumerate(dstStaticShape.slice(startPos, endPos - startPos + 1))) {
- if (d.index() + startPos == static_cast<unsigned>(dimIndex))
- continue;
- assert(!ShapedType::isDynamic(d.value()) &&
- "single dimension cannot be expanded into multiple dynamic "
- "dimensions");
- linearizedStaticDim *= d.value();
- }
- OpFoldResult sourceDim =
- builder.create<tensor::DimOp>(loc, src, sourceDimPos).getResult();
-
- // Dynamic dimension: return Value.
- return affine::makeComposedAffineApply(
- builder, loc,
- AffineMap::get(
- 0, 1,
- builder.getAffineSymbolExpr(0).floorDiv(linearizedStaticDim)),
- sourceDim)
- ->getResult(0);
-}
-
-/// Given the `src` of an expanding reshape op, the reassociation maps and the
-/// result type, compute the shape of the result of the reshape.
-static SmallVector<OpFoldResult, 4> getExpandedOutputShapeFromInputShape(
- OpBuilder &builder, Location loc, Value src,
- ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
- llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim =
- getExpandedDimToCollapsedDimMap(reassociation);
- return llvm::to_vector<4>(llvm::map_range(
- llvm::seq<int64_t>(0, dstStaticShape.size()), [&](int64_t dim) {
- return getExpandedOutputDimFromInputShape(builder, loc, dim, src,
- dstStaticShape, reassociation,
- expandedDimToCollapsedDim);
- }));
-}
-
-static SmallVector<OpFoldResult, 4>
-getReshapeOutputShapeFromInputShape(OpBuilder &builder, Location loc, Value src,
- ArrayRef<int64_t> dstStaticShape,
- ArrayRef<AffineMap> reassocation) {
- return dstStaticShape.size() >
- static_cast<size_t>(
- llvm::cast<ShapedType>(src.getType()).getRank())
- ? getExpandedOutputShapeFromInputShape(
- builder, loc, src, dstStaticShape, reassocation)
- : getCollapsedOutputShapeFromInputShape(
- builder, loc, src, dstStaticShape, reassocation);
-}
-
-template <typename OpTy>
-struct ReifyExpandOrCollapseShapeOp
+struct ReifyCollapseShapeOp
: public ReifyRankedShapedTypeOpInterface::ExternalModel<
- ReifyExpandOrCollapseShapeOp<OpTy>, OpTy> {
+ ReifyCollapseShapeOp, CollapseShapeOp> {
LogicalResult
reifyResultShapes(Operation *op, OpBuilder &b,
ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
auto loc = op->getLoc();
- auto reshapeOp = cast<OpTy>(op);
- reifiedReturnShapes.push_back(getReshapeOutputShapeFromInputShape(
+ auto reshapeOp = cast<tensor::CollapseShapeOp>(op);
+ reifiedReturnShapes.push_back(getCollapsedOutputShapeFromInputShape(
b, loc, reshapeOp.getSrc(), reshapeOp.getResultType().getShape(),
reshapeOp.getReassociationMaps()));
return success();
@@ -162,6 +75,20 @@ struct ReifyExpandOrCollapseShapeOp
namespace {
+struct ReifyExpandShapeOp
+ : public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp,
+ ExpandShapeOp> {
+ LogicalResult
+ reifyResultShapes(Operation *op, OpBuilder &b,
+ ReifiedRankedShapedTypeDims &reifyResultShapes) const {
+ auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
+ SmallVector<OpFoldResult> resultShapes =
+ expandShapeOp.getMixedOutputShape();
+ reifyResultShapes.emplace_back(std::move(resultShapes));
+ return success();
+ }
+};
+
struct ReifyPadOp
: public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyPadOp,
PadOp> {
@@ -202,10 +129,8 @@ struct ReifyPadOp
void mlir::tensor::registerInferTypeOpInterfaceExternalModels(
DialectRegistry ®istry) {
registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
- ExpandShapeOp::attachInterface<
- ReifyExpandOrCollapseShapeOp<tensor::ExpandShapeOp>>(*ctx);
- CollapseShapeOp::attachInterface<
- ReifyExpandOrCollapseShapeOp<tensor::CollapseShapeOp>>(*ctx);
+ ExpandShapeOp::attachInterface<ReifyExpandShapeOp>(*ctx);
+ CollapseShapeOp::attachInterface<ReifyCollapseShapeOp>(*ctx);
PadOp::attachInterface<ReifyPadOp>(*ctx);
});
}
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 24a1d553153198..117908129561f2 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1732,6 +1732,10 @@ ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
return *outputShape;
}
+SmallVector<OpFoldResult> ExpandShapeOp::getMixedOutputShape() {
+ return getMixedValues(getStaticOutputShape(), getOutputShape(), getContext());
+}
+
void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
Type resultType, Value src,
ArrayRef<ReassociationIndices> reassociation,
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 5c8f6ded39ba4e..fcb736aa031f36 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -191,7 +191,8 @@ bool isEqualConstantIntOrValueArray(ArrayRef<OpFoldResult> ofrs1,
/// elements for which ShapedType::isDynamic is true, will be replaced by
/// dynamicValues.
SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
- ValueRange dynamicValues, Builder &b) {
+ ValueRange dynamicValues,
+ MLIRContext *context) {
SmallVector<OpFoldResult> res;
res.reserve(staticValues.size());
unsigned numDynamic = 0;
@@ -200,10 +201,15 @@ SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
int64_t value = staticValues[idx];
res.push_back(ShapedType::isDynamic(value)
? OpFoldResult{dynamicValues[numDynamic++]}
- : OpFoldResult{b.getI64IntegerAttr(staticValues[idx])});
+ : OpFoldResult{IntegerAttr::get(
+ IntegerType::get(context, 64), staticValues[idx])});
}
return res;
}
+SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
+ ValueRange dynamicValues, Builder &b) {
+ return getMixedValues(staticValues, dynamicValues, b.getContext());
+}
/// Decompose a vector of mixed static or dynamic values into the corresponding
/// pair of arrays. This is the inverse function of `getMixedValues`.
diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
index 3eb401c4499805..6b5e103cd36c2b 100644
--- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp
+++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
@@ -48,14 +48,6 @@ mlir::reifyResultShapes(OpBuilder &b, Operation *op,
assert(shapedType.getRank() ==
static_cast<int64_t>(reifiedReturnShapes[resultIdx].size()) &&
"incorrect implementation of ReifyRankedShapedTypeOpInterface");
- for (int64_t dim = 0; dim < shapedType.getRank(); ++dim) {
- // reifyResultShapes must return:
- // * Attribute for static dimensions
- // * Value for dynamic dimensions
- assert(shapedType.isDynamicDim(dim) ==
- isa<Value>(reifiedReturnShapes[resultIdx][dim]) &&
- "incorrect implementation of ReifyRankedShapedTypeOpInterface");
- }
++resultIdx;
}
// Assert that every shaped value result was reified.
diff --git a/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir b/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
index 8fb84248c9613b..3bc1f56d816d73 100644
--- a/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
+++ b/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
@@ -210,15 +210,12 @@ func.func @dim_reshape_expansion(%arg0 : tensor<6x5x?xf32>, %sz0: index) -> (ind
%3 = tensor.dim %0, %c4 : tensor<2x3x5x4x?x7xf32>
return %1, %2, %3 : index, index, index
}
-// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
// CHECK: func @dim_reshape_expansion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<6x5x?xf32>
-// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-SAME: %[[ARG1:.+]]: index
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
-// CHECK: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C2]]
-// CHECK: %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
-// CHECK: return %[[C3]], %[[C4]], %[[D1]]
+// CHECK: return %[[C3]], %[[C4]], %[[ARG1]]
// -----
diff --git a/mlir/test/Dialect/Tensor/fold-empty-op.mlir b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
index 65ceb4ff3e3df4..850bbcee340203 100644
--- a/mlir/test/Dialect/Tensor/fold-empty-op.mlir
+++ b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
@@ -10,7 +10,6 @@ module attributes {transform.with_named_sequence} {
}
}
-// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
// CHECK: #[[$MAP2:.+]] = affine_map<()[s0] -> (s0 * 28)>
func.func @empty_reshape_expansion(%arg0 : index, %sz0: index) -> tensor<2x3x5x4x?x7xf32> {
@@ -19,11 +18,9 @@ func.func @empty_reshape_expansion(%arg0 : index, %sz0: index) -> tensor<2x3x5x4
return %1 : tensor<2x3x5x4x?x7xf32>
}
// CHECK-LABEL: func @empty_reshape_expansion
-// CHECK-SAME: %[[ARG0:.+]]: index
-// CHECK: %[[OLD_INIT:.+]] = tensor.empty(%{{.*}}) : tensor<6x5x?xf32>
-// CHECK-NEXT: %[[DIM:.*]] = tensor.dim %[[OLD_INIT]]
-// CHECK-NEXT: %[[D:.+]] = affine.apply #[[$MAP]]()[%[[DIM]]]
-// CHECK-NEXT: %[[INIT:.+]] = tensor.empty(%[[D]])
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-NEXT: %[[INIT:.+]] = tensor.empty(%[[ARG1]])
// CHECK-NEXT: return %[[INIT]]
func.func @empty_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> {
More information about the Mlir-commits
mailing list