[Mlir-commits] [mlir] b7528f5 - [mlir][Tensor] Use helper function for `getDroppedDims`
Matthias Springer
llvmlistbot at llvm.org
Wed Mar 29 00:17:41 PDT 2023
Author: Matthias Springer
Date: 2023-03-29T09:17:28+02:00
New Revision: b7528f52c7d8c22d06a9a386b58e52cb76cfa54c
URL: https://github.com/llvm/llvm-project/commit/b7528f52c7d8c22d06a9a386b58e52cb76cfa54c
DIFF: https://github.com/llvm/llvm-project/commit/b7528f52c7d8c22d06a9a386b58e52cb76cfa54c.diff
LOG: [mlir][Tensor] Use helper function for `getDroppedDims`
This helper function is used by both ExtractSliceOp and InsertSliceOp. It also fixes a bug in `InsertSliceOp::getDroppedDims`, which matched the slice sizes against the op's result type instead of the (possibly rank-reduced) source type.
Differential Revision: https://reviews.llvm.org/D147048
Added:
Modified:
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 93db7da27abd..e7fb28794567 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -110,6 +110,48 @@ LogicalResult tensor::getOrCreateDestinations(OpBuilder &b, Location loc,
return success();
}
+/// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or
+/// rank-extending tensor.insert_slice op.
+///
+/// `mixedSizes` has one entry per dimension of the slice before rank
+/// reduction; `reducedShape` is the shape of the rank-reduced tensor. Sizes
+/// are matched against `reducedShape` from left to right. A size that is not
+/// a static 1 always consumes the next dim of `reducedShape`; a static 1
+/// consumes it only if that dim is also 1, and is otherwise reported as
+/// dropped. Sizes left over after `reducedShape` is exhausted are dropped.
+///
+/// Returns a bit vector with one bit per entry of `mixedSizes`; a set bit
+/// marks a dropped dimension. If the two arguments are inconsistent (a
+/// leftover size that is not a static 1, or unmatched dims in
+/// `reducedShape`), an assertion fires in builds with assertions enabled.
+static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
+ ArrayRef<OpFoldResult> mixedSizes) {
+ llvm::SmallBitVector droppedDims(mixedSizes.size());
+ // Index of the next not-yet-matched dim in `reducedShape`.
+ int64_t shapePos = 0;
+
+ for (const auto &size : enumerate(mixedSizes)) {
+ // Rank-reduced dims must have a static unit dimension. Only a static size
+ // (an Attribute) equal to 1 qualifies; a dynamic SSA value is never
+ // treated as droppable.
+ bool isStaticUnitSize =
+ size.value().is<Attribute>() &&
+ size.value().get<Attribute>().cast<IntegerAttr>().getInt() == 1;
+
+ if (shapePos == static_cast<int64_t>(reducedShape.size())) {
+ // There are no more dims in the reduced shape. All remaining sizes must
+ // be rank-reduced dims.
+ assert(isStaticUnitSize && "expected unit dim");
+ droppedDims.set(size.index());
+ continue;
+ }
+
+ // Dim is preserved if the size is not a static 1.
+ if (!isStaticUnitSize) {
+ ++shapePos;
+ continue;
+ }
+
+ // Dim is preserved if the reduced shape dim is also 1. When the mapping is
+ // ambiguous (e.g. sizes [1, 1] vs. reduced shape [1]), this greedy choice
+ // keeps the leftmost unit dims and reports the later ones as dropped.
+ if (reducedShape[shapePos] == 1) {
+ ++shapePos;
+ continue;
+ }
+
+ // Otherwise: Dim is dropped.
+ droppedDims.set(size.index());
+ }
+
+ // Every dim of the reduced shape must have been matched to some size.
+ assert(shapePos == static_cast<int64_t>(reducedShape.size()) &&
+ "dimension mismatch");
+ return droppedDims;
+}
+
//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//
@@ -1740,23 +1782,7 @@ LogicalResult ExtractSliceOp::verify() {
}
llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
- ArrayRef<int64_t> resultShape = getType().getShape();
- SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
- llvm::SmallBitVector droppedDims(mixedSizes.size());
- unsigned shapePos = 0;
- for (const auto &size : enumerate(mixedSizes)) {
- std::optional<int64_t> sizeVal = getConstantIntValue(size.value());
- // If the size is not 1, or if the current matched dimension of the result
- // is the same static shape as the size value (which is 1), then the
- // dimension is preserved.
- if (!sizeVal || *sizeVal != 1 ||
- (shapePos < resultShape.size() && resultShape[shapePos] == 1)) {
- shapePos++;
- continue;
- }
- droppedDims.set(size.index());
- }
- return droppedDims;
+ // For extract_slice, the result type is the (possibly rank-reduced) shape
+ // that the slice sizes are matched against. `::` selects the file-scope
+ // helper rather than this member function.
+ return ::getDroppedDims(getType().getShape(), getMixedSizes());
}
FailureOr<Value>
@@ -2397,23 +2423,7 @@ struct InsertSliceOpSourceCastInserter final
} // namespace
llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
- ArrayRef<int64_t> resultShape = getType().getShape();
- SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
- llvm::SmallBitVector droppedDims(mixedSizes.size());
- unsigned shapePos = 0;
- for (const auto &size : enumerate(mixedSizes)) {
- std::optional<int64_t> sizeVal = getConstantIntValue(size.value());
- // If the size is not 1, or if the current matched dimension of the result
- // is the same static shape as the size value (which is 1), then the
- // dimension is preserved.
- if (!sizeVal || *sizeVal != 1 ||
- (shapePos < resultShape.size() && resultShape[shapePos] == 1)) {
- shapePos++;
- continue;
- }
- droppedDims.set(size.index());
- }
- return droppedDims;
+ // For insert_slice, the sizes are matched against the source type, i.e. the
+ // (possibly rank-reduced) tensor being inserted; the op's result type is not
+ // the reduced shape. `::` selects the file-scope helper rather than this
+ // member function.
+ return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
}
void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
More information about the Mlir-commits
mailing list