[Mlir-commits] [mlir] f3f25ff - [mlir][linalg] Fix result type in FoldSourceTensorCast
Matthias Springer
llvmlistbot at llvm.org
Fri Sep 24 00:47:31 PDT 2021
Author: Matthias Springer
Date: 2021-09-24T16:47:18+09:00
New Revision: f3f25ffc04c0cbcc9a9bfc1b32b61750e8934ea8
URL: https://github.com/llvm/llvm-project/commit/f3f25ffc04c0cbcc9a9bfc1b32b61750e8934ea8
DIFF: https://github.com/llvm/llvm-project/commit/f3f25ffc04c0cbcc9a9bfc1b32b61750e8934ea8.diff
LOG: [mlir][linalg] Fix result type in FoldSourceTensorCast
* Do not discard static result type information that cannot be inferred from lower/upper padding.
* Add optional argument to `PadTensorOp::inferResultType` for specifying known result dimensions.
Differential Revision: https://reviews.llvm.org/D110380
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 4c82eafc9c973..dd568ba367067 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -226,10 +226,14 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
}
// Infer the shape of the result tensor given the type of the source tensor
- // and paddings.
- static RankedTensorType inferResultType(RankedTensorType sourceType,
+ // and paddings. Known result dimensions that cannot necessarily be inferred
+ // from low/high padding sizes can be optionally specified. Those will be
+ // considered when computing the result type.
+ static RankedTensorType inferResultType(
+ RankedTensorType sourceType,
ArrayRef<int64_t> staticLow,
- ArrayRef<int64_t> staticHigh);
+ ArrayRef<int64_t> staticHigh,
+ ArrayRef<int64_t> resultShape = {});
// Return a PadTensorOp that pads `source` to `type` size where the static
// sizes are assumed to be greater than the dynamic sizes. The op performs
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b3eeaabc780ed..75e4a1c91bcda 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1055,24 +1055,31 @@ static LogicalResult verify(PadTensorOp op) {
RankedTensorType PadTensorOp::inferResultType(RankedTensorType sourceType,
ArrayRef<int64_t> staticLow,
- ArrayRef<int64_t> staticHigh) {
+ ArrayRef<int64_t> staticHigh,
+ ArrayRef<int64_t> resultShape) {
unsigned rank = sourceType.getRank();
assert(staticLow.size() == rank && "unexpected staticLow size mismatch");
assert(staticHigh.size() == rank && "unexpected staticHigh size mismatch");
+ assert((resultShape.empty() || resultShape.size() == rank) &&
+ "unexpected resultShape size mismatch");
- SmallVector<int64_t, 4> resultShape;
+ SmallVector<int64_t, 4> inferredShape;
for (auto i : llvm::seq<unsigned>(0, rank)) {
if (sourceType.isDynamicDim(i) ||
staticLow[i] == ShapedType::kDynamicSize ||
staticHigh[i] == ShapedType::kDynamicSize) {
- resultShape.push_back(ShapedType::kDynamicSize);
+ inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamicSize
+ : resultShape[i]);
} else {
int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
- resultShape.push_back(size);
+ assert((resultShape.empty() || size == resultShape[i] ||
+ resultShape[i] == ShapedType::kDynamicSize) &&
+ "mismatch between inferred shape and result shape");
+ inferredShape.push_back(size);
}
}
- return RankedTensorType::get(resultShape, sourceType.getElementType());
+ return RankedTensorType::get(inferredShape, sourceType.getElementType());
}
void PadTensorOp::build(OpBuilder &b, OperationState &result, Value source,
@@ -1454,7 +1461,8 @@ struct FoldSourceTensorCast : public OpRewritePattern<PadTensorOp> {
auto newResultType = PadTensorOp::inferResultType(
castOp.source().getType().cast<RankedTensorType>(),
extractFromI64ArrayAttr(padTensorOp.static_low()),
- extractFromI64ArrayAttr(padTensorOp.static_high()));
+ extractFromI64ArrayAttr(padTensorOp.static_high()),
+ padTensorOp.getResultType().getShape());
if (newResultType == padTensorOp.getResultType()) {
rewriter.updateRootInPlace(padTensorOp, [&]() {
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 3d434c2d6ebc0..fce08a1e04dca 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -629,7 +629,8 @@ func @pad_tensor_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
}
// -----
-// CHECK-LABEL: func @pad_tensor_after_cast_differnt_shape(
+
+// CHECK-LABEL: func @pad_tensor_after_cast_different_shape(
// CHECK-SAME: %[[INPUT:.*]]: tensor<?x64x?x?xf32>) -> tensor<?x?x?x?xf32> {
// CHECK: %[[CST:.*]] = constant 0.000000e+00 : f32
// CHECK: %[[PADDED:.*]] = linalg.pad_tensor %[[INPUT]]
@@ -641,7 +642,7 @@ func @pad_tensor_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
// CHECK-SAME: tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
// CHECK: return %[[DYNAMIC]] : tensor<?x?x?x?xf32>
// CHECK: }
-func @pad_tensor_after_cast_differnt_shape(%arg0: tensor<?x64x?x?xf32>)
+func @pad_tensor_after_cast_different_shape(%arg0: tensor<?x64x?x?xf32>)
-> tensor<?x?x?x?xf32> {
%cst = constant 0.000000e+00 : f32
%dynamic = tensor.cast %arg0 : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
@@ -653,6 +654,7 @@ func @pad_tensor_after_cast_differnt_shape(%arg0: tensor<?x64x?x?xf32>)
}
// -----
+
// CHECK-LABEL: func @pad_tensor_after_cast_same_shape(
// CHECK-SAME: %[[INPUT:.*]]: tensor<?x64x?x?xf32>,
// CHECK-SAME: %[[PADDING:.*]]: index) -> tensor<?x?x?x?xf32> {
@@ -676,6 +678,24 @@ func @pad_tensor_after_cast_same_shape(%arg0: tensor<?x64x?x?xf32>, %padding : i
}
// -----
+
+// CHECK-LABEL: func @pad_tensor_of_cast(
+// CHECK-NOT: tensor.cast
+// CHECK: linalg.pad_tensor
+// CHECK: tensor<8x?xf32> to tensor<8x32xf32>
+func @pad_tensor_of_cast(%t: tensor<8x?xf32>, %s: index) -> tensor<8x32xf32> {
+ %c0 = constant 0 : index
+ %cst = constant 0.000000e+00 : f32
+ %0 = tensor.cast %t : tensor<8x?xf32> to tensor<?x?xf32>
+ %1 = linalg.pad_tensor %0 low[%c0, %c0] high[%c0, %s] {
+ ^bb0(%arg9: index, %arg10: index): // no predecessors
+ linalg.yield %cst : f32
+ } : tensor<?x?xf32> to tensor<8x32xf32>
+ return %1 : tensor<8x32xf32>
+}
+
+// -----
+
func @propogate_casts(%arg0 : tensor<?x?xf32>, %arg1 : f32, %arg2 : index,
%arg3 : index) -> tensor<?x?xf32> {
%c0 = constant 0 : index
More information about the Mlir-commits mailing list