[Mlir-commits] [mlir] [mlir][Tensor] Rework `ReifyRankedShapedTypeInterface` implementation for `tensor.expand_shape` op. (PR #113501)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Oct 23 15:45:02 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-memref
Author: None (MaheshRavishankar)
<details>
<summary>Changes</summary>
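The op already carries its output shape (the `output_shape` operands together with `static_output_shape`), so the `ReifyRankedShapedTypeOpInterface` implementation can return it as-is instead of recomputing it from the source shape. This PR also adds `ExpandShapeOp::getMixedOutputShape()`, which returns that shape as a `SmallVector<OpFoldResult>`.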
---
Full diff: https://github.com/llvm/llvm-project/pull/113501.diff
9 Files Affected:
- (modified) mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td (+3)
- (modified) mlir/include/mlir/Dialect/Utils/StaticValueUtils.h (+3)
- (modified) mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp (+1)
- (modified) mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp (+19-8)
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+4)
- (modified) mlir/lib/Dialect/Utils/StaticValueUtils.cpp (+8-2)
- (modified) mlir/lib/Interfaces/InferTypeOpInterface.cpp (-8)
- (modified) mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir (+2-5)
- (modified) mlir/test/Dialect/Tensor/fold-empty-op.mlir (+3-6)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 3170115883e2be..8203b9c0fab437 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1160,6 +1160,9 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
let extraClassDeclaration = commonExtraClassDeclaration # [{
int64_t getCorrespondingSourceDim(int64_t resultDim);
+ // Return output shape as mixed static/dynamic values.
+ SmallVector<OpFoldResult> getMixedOutputShape();
+
// Infer the output shape for a tensor.expand_shape when it is possible
// to do so.
static FailureOr<SmallVector<OpFoldResult>> inferOutputShape(
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 4d7aa1ae17fdb1..9f0e01f1d8ca00 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -125,6 +125,9 @@ bool isEqualConstantIntOrValueArray(ArrayRef<OpFoldResult> ofrs1,
/// Return a vector of OpFoldResults with the same size a staticValues, but
/// all elements for which ShapedType::isDynamic is true, will be replaced by
/// dynamicValues.
+SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
+ ValueRange dynamicValues,
+ MLIRContext *context);
SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
ValueRange dynamicValues, Builder &b);
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index 792e7229183064..416aac7d64aad5 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -20,6 +20,7 @@
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
index 7ff435a033985c..ebc458170337d6 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
@@ -144,15 +144,14 @@ getReshapeOutputShapeFromInputShape(OpBuilder &builder, Location loc, Value src,
builder, loc, src, dstStaticShape, reassocation);
}
-template <typename OpTy>
-struct ReifyExpandOrCollapseShapeOp
+struct ReifyCollapseShapeOp
: public ReifyRankedShapedTypeOpInterface::ExternalModel<
- ReifyExpandOrCollapseShapeOp<OpTy>, OpTy> {
+ ReifyCollapseShapeOp, CollapseShapeOp> {
LogicalResult
reifyResultShapes(Operation *op, OpBuilder &b,
ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
auto loc = op->getLoc();
- auto reshapeOp = cast<OpTy>(op);
+ auto reshapeOp = cast<tensor::CollapseShapeOp>(op);
reifiedReturnShapes.push_back(getReshapeOutputShapeFromInputShape(
b, loc, reshapeOp.getSrc(), reshapeOp.getResultType().getShape(),
reshapeOp.getReassociationMaps()));
@@ -162,6 +161,20 @@ struct ReifyExpandOrCollapseShapeOp
namespace {
+struct ReifyExpandShapeOp
+ : public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp,
+ ExpandShapeOp> {
+ LogicalResult
+ reifyResultShapes(Operation *op, OpBuilder &b,
+ ReifiedRankedShapedTypeDims &reifyResultShapes) const {
+ auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
+ SmallVector<OpFoldResult> resultShapes =
+ expandShapeOp.getMixedOutputShape();
+ reifyResultShapes.emplace_back(std::move(resultShapes));
+ return success();
+ }
+};
+
struct ReifyPadOp
: public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyPadOp,
PadOp> {
@@ -202,10 +215,8 @@ struct ReifyPadOp
void mlir::tensor::registerInferTypeOpInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
- ExpandShapeOp::attachInterface<
- ReifyExpandOrCollapseShapeOp<tensor::ExpandShapeOp>>(*ctx);
- CollapseShapeOp::attachInterface<
- ReifyExpandOrCollapseShapeOp<tensor::CollapseShapeOp>>(*ctx);
+ ExpandShapeOp::attachInterface<ReifyExpandShapeOp>(*ctx);
+ CollapseShapeOp::attachInterface<ReifyCollapseShapeOp>(*ctx);
PadOp::attachInterface<ReifyPadOp>(*ctx);
});
}
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 603e86ca3d7668..f1f33bd940f7d7 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1675,6 +1675,10 @@ ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
return *outputShape;
}
+SmallVector<OpFoldResult> ExpandShapeOp::getMixedOutputShape() {
+ return getMixedValues(getStaticOutputShape(), getOutputShape(), getContext());
+}
+
void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
Type resultType, Value src,
ArrayRef<ReassociationIndices> reassociation,
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 3eb6215a7a0b9b..f1166269f0a400 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -177,7 +177,8 @@ bool isEqualConstantIntOrValueArray(ArrayRef<OpFoldResult> ofrs1,
/// elements for which ShapedType::isDynamic is true, will be replaced by
/// dynamicValues.
SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
- ValueRange dynamicValues, Builder &b) {
+ ValueRange dynamicValues,
+ MLIRContext *context) {
SmallVector<OpFoldResult> res;
res.reserve(staticValues.size());
unsigned numDynamic = 0;
@@ -186,10 +187,15 @@ SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
int64_t value = staticValues[idx];
res.push_back(ShapedType::isDynamic(value)
? OpFoldResult{dynamicValues[numDynamic++]}
- : OpFoldResult{b.getI64IntegerAttr(staticValues[idx])});
+ : OpFoldResult{IntegerAttr::get(
+ IntegerType::get(context, 64), staticValues[idx])});
}
return res;
}
+SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
+ ValueRange dynamicValues, Builder &b) {
+ return getMixedValues(staticValues, dynamicValues, b.getContext());
+}
/// Decompose a vector of mixed static or dynamic values into the corresponding
/// pair of arrays. This is the inverse function of `getMixedValues`.
diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
index 8cc4206dae6edf..c7f5fcb1d21fc8 100644
--- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp
+++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
@@ -48,14 +48,6 @@ mlir::reifyResultShapes(OpBuilder &b, Operation *op,
assert(shapedType.getRank() ==
static_cast<int64_t>(reifiedReturnShapes[resultIdx].size()) &&
"incorrect implementation of ReifyRankedShapedTypeOpInterface");
- for (int64_t dim = 0; dim < shapedType.getRank(); ++dim) {
- // reifyResultShapes must return:
- // * Attribute for static dimensions
- // * Value for dynamic dimensions
- assert(shapedType.isDynamicDim(dim) ==
- reifiedReturnShapes[resultIdx][dim].is<Value>() &&
- "incorrect implementation of ReifyRankedShapedTypeOpInterface");
- }
++resultIdx;
}
// Assert that every shaped value result was reified.
diff --git a/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir b/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
index 8fb84248c9613b..3bc1f56d816d73 100644
--- a/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
+++ b/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
@@ -210,15 +210,12 @@ func.func @dim_reshape_expansion(%arg0 : tensor<6x5x?xf32>, %sz0: index) -> (ind
%3 = tensor.dim %0, %c4 : tensor<2x3x5x4x?x7xf32>
return %1, %2, %3 : index, index, index
}
-// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
// CHECK: func @dim_reshape_expansion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<6x5x?xf32>
-// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-SAME: %[[ARG1:.+]]: index
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
-// CHECK: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C2]]
-// CHECK: %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
-// CHECK: return %[[C3]], %[[C4]], %[[D1]]
+// CHECK: return %[[C3]], %[[C4]], %[[ARG1]]
// -----
diff --git a/mlir/test/Dialect/Tensor/fold-empty-op.mlir b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
index 65ceb4ff3e3df4..850bbcee340203 100644
--- a/mlir/test/Dialect/Tensor/fold-empty-op.mlir
+++ b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
@@ -10,7 +10,6 @@ module attributes {transform.with_named_sequence} {
}
}
-// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
// CHECK: #[[$MAP2:.+]] = affine_map<()[s0] -> (s0 * 28)>
func.func @empty_reshape_expansion(%arg0 : index, %sz0: index) -> tensor<2x3x5x4x?x7xf32> {
@@ -19,11 +18,9 @@ func.func @empty_reshape_expansion(%arg0 : index, %sz0: index) -> tensor<2x3x5x4
return %1 : tensor<2x3x5x4x?x7xf32>
}
// CHECK-LABEL: func @empty_reshape_expansion
-// CHECK-SAME: %[[ARG0:.+]]: index
-// CHECK: %[[OLD_INIT:.+]] = tensor.empty(%{{.*}}) : tensor<6x5x?xf32>
-// CHECK-NEXT: %[[DIM:.*]] = tensor.dim %[[OLD_INIT]]
-// CHECK-NEXT: %[[D:.+]] = affine.apply #[[$MAP]]()[%[[DIM]]]
-// CHECK-NEXT: %[[INIT:.+]] = tensor.empty(%[[D]])
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-NEXT: %[[INIT:.+]] = tensor.empty(%[[ARG1]])
// CHECK-NEXT: return %[[INIT]]
func.func @empty_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> {
``````````
</details>
https://github.com/llvm/llvm-project/pull/113501
More information about the Mlir-commits
mailing list