[Mlir-commits] [mlir] f895e95 - [mlir][linalg] Make padding work for rank-reducing slice ops.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Dec 13 04:48:02 PST 2021
Author: gysit
Date: 2021-12-13T12:34:20Z
New Revision: f895e9513860243fdbceaa7d323447c3720ac6df
URL: https://github.com/llvm/llvm-project/commit/f895e9513860243fdbceaa7d323447c3720ac6df
DIFF: https://github.com/llvm/llvm-project/commit/f895e9513860243fdbceaa7d323447c3720ac6df.diff
LOG: [mlir][linalg] Make padding work for rank-reducing slice ops.
Adapt the computation of a static bounding box to take rank-reducing slice operations into account by filtering out the dropped size-one dimensions. This revision is needed to make padding work for decomposed convolution operations. The decomposition introduces rank-reducing extract slice operations that previously caused padding to fail.
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D115336
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/Dialect/Linalg/pad.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 4dbca8a308a1..87626221b29e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -186,26 +186,34 @@ static LogicalResult padOperandToSmallestStaticBoundingBox(
if (!sliceOp)
return failure(hasDynamicShape);
+ // Compute the dropped dimensions if `sliceOp` is rank-reducing.
+ llvm::SmallDenseSet<unsigned> droppedDims = sliceOp.getDroppedDims();
+
// Upper bound the `sliceOp` sizes to obtain a static bounding box.
SmallVector<int64_t> staticSizes;
- staticSizes.reserve(opToPad.getRank(opOperand));
+ staticSizes.reserve(shape.size());
auto shapedOp = cast<OffsetSizeAndStrideOpInterface>(sliceOp.getOperation());
- for (auto size : shapedOp.getMixedSizes()) {
+ for (auto en : enumerate(shapedOp.getMixedSizes())) {
+ // Skip dropped dimensions.
+ if (droppedDims.contains(en.index()))
+ continue;
// If the size is an attribute add it directly to `staticSizes`.
- if (size.is<Attribute>()) {
+ if (en.value().is<Attribute>()) {
staticSizes.push_back(
- size.get<Attribute>().dyn_cast<IntegerAttr>().getInt());
+ en.value().get<Attribute>().dyn_cast<IntegerAttr>().getInt());
continue;
}
// Otherwise, try to compute a constant upper bound for the size value.
FailureOr<int64_t> upperBound =
- getConstantUpperBoundForIndex(size.get<Value>());
+ getConstantUpperBoundForIndex(en.value().get<Value>());
if (failed(upperBound)) {
LLVM_DEBUG(DBGS() << "No constant bounding box can be found for padding");
return failure();
}
staticSizes.push_back(upperBound.getValue());
}
+ assert(staticSizes.size() == shape.size() &&
+ "expect the dynamic and static ranks to match");
// Pad the operand to the bounding box defined by `staticSizes`.
auto staticTensorType = RankedTensorType::get(
diff --git a/mlir/test/Dialect/Linalg/pad.mlir b/mlir/test/Dialect/Linalg/pad.mlir
index f2154a8f3584..d67635c8b26a 100644
--- a/mlir/test/Dialect/Linalg/pad.mlir
+++ b/mlir/test/Dialect/Linalg/pad.mlir
@@ -426,3 +426,24 @@ func @dynamic_input_padding_only(%arg0: tensor<24x12xf32>,
%5 = tensor.insert_slice %4 into %arg2[%iv0, %iv1] [4, %0] [1, 1] : tensor<4x?xf32> into tensor<24x25xf32>
return %5 : tensor<24x25xf32>
}
+
+// -----
+
+#map0 = affine_map<()[s0] -> (64, s0)>
+
+// FILL: rank_reducing
+// FILL-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<1x64x1x64xf32>
+func @rank_reducing(%arg0: tensor<1x64x1x64xf32>,
+ %iv0 : index) -> tensor<1x?x?xf32> {
+ %cst = arith.constant 0.0 : f32
+ %size = affine.min #map0()[%iv0]
+ %0 = tensor.extract_slice %arg0[0, 0, 0, 0] [1, %size, 1, %size] [1, 1, 1, 1] : tensor<1x64x1x64xf32> to tensor<1x?x?xf32>
+
+ // Check the fill is padded despite the rank-reducing slice operation.
+ // FILL: %[[T0:.*]] = linalg.pad_tensor
+ // FILL: %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]])
+ // FILL-SAME: tensor<1x64x64xf32>
+ // FILL: = tensor.extract_slice %[[T1]]
+ %1 = linalg.fill(%cst, %0) : f32, tensor<1x?x?xf32> -> tensor<1x?x?xf32>
+ return %1 : tensor<1x?x?xf32>
+}
More information about the Mlir-commits
mailing list