[Mlir-commits] [mlir] 0a9f6b8 - [mlir][tensor/linalg] Fix bug in reifyResultShapes
Matthias Springer
llvmlistbot at llvm.org
Fri Mar 10 02:38:08 PST 2023
Author: Matthias Springer
Date: 2023-03-10T11:37:54+01:00
New Revision: 0a9f6b8ca3fa9449cc0accba1bc11e98d6dbc6b6
URL: https://github.com/llvm/llvm-project/commit/0a9f6b8ca3fa9449cc0accba1bc11e98d6dbc6b6
DIFF: https://github.com/llvm/llvm-project/commit/0a9f6b8ca3fa9449cc0accba1bc11e98d6dbc6b6.diff
LOG: [mlir][tensor/linalg] Fix bug in reifyResultShapes
`reifyResultShapes` should return an IntegerAttr if and only if the corresponding dimension is static; dynamic dimensions must be returned as a Value.
Differential Revision: https://reviews.llvm.org/D145702
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 6844b685f6fe2..22a3c04ff3745 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -644,10 +644,18 @@ LinalgOp::reifyResultShapes(OpBuilder &b,
for (OpOperand *opOperand : getDpsInitOperands()) {
SmallVector<OpFoldResult> shapes;
for (int64_t dim : llvm::seq<int64_t>(0, getRank(opOperand))) {
- if (checkDimExpr.visit(shapeExprs[pos]))
- shapes.push_back(createOrFoldDimOp(b, loc, opOperand->get(), dim));
- else
- shapes.push_back(allResultDimValues[pos]);
+ auto shapedType = opOperand->get().getType().cast<ShapedType>();
+ if (!shapedType.isDynamicDim(dim)) {
+ // Static dim: Return IntegerAttr.
+ shapes.push_back(b.getIndexAttr(shapedType.getDimSize(dim)));
+ } else {
+ // Dynamic dim: Return Value.
+ OpFoldResult ofr =
+ checkDimExpr.visit(shapeExprs[pos])
+ ? createOrFoldDimOp(b, loc, opOperand->get(), dim)
+ : allResultDimValues[pos];
+ shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
+ }
pos++;
}
reifiedReturnShapes.emplace_back(std::move(shapes));
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index baf213006a6dc..e1bf889088a9e 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2205,8 +2205,13 @@ LogicalResult InsertSliceOp::reifyResultShapes(
OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
- reifiedReturnShapes[0][dim] =
- builder.createOrFold<tensor::DimOp>(getLoc(), getDest(), dim);
+ if (getType().isDynamicDim(dim)) {
+ reifiedReturnShapes[0][dim] =
+ builder.createOrFold<tensor::DimOp>(getLoc(), getDest(), dim);
+ } else {
+ reifiedReturnShapes[0][dim] =
+ builder.getIndexAttr(getType().getDimSize(dim));
+ }
}
return success();
}
@@ -3154,9 +3159,15 @@ reifyResultShapesImpl(OpTy op, OpBuilder &builder,
"applies to only pack or unpack operations");
int64_t destRank = op.getDestRank();
reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
+ ShapedType resultType = op.getResult().getType().template cast<ShapedType>();
for (auto dim : llvm::seq<int64_t>(0, destRank)) {
- reifiedReturnShapes[0][dim] =
- builder.createOrFold<tensor::DimOp>(op.getLoc(), op.getDest(), dim);
+ if (resultType.isDynamicDim(dim)) {
+ reifiedReturnShapes[0][dim] =
+ builder.createOrFold<tensor::DimOp>(op.getLoc(), op.getDest(), dim);
+ } else {
+ reifiedReturnShapes[0][dim] =
+ builder.getIndexAttr(resultType.getDimSize(dim));
+ }
}
return success();
}
More information about the Mlir-commits mailing list