[Mlir-commits] [mlir] b7528f5 - [mlir][Tensor] Use helper function for `getDroppedDims`

Matthias Springer llvmlistbot at llvm.org
Wed Mar 29 00:17:41 PDT 2023


Author: Matthias Springer
Date: 2023-03-29T09:17:28+02:00
New Revision: b7528f52c7d8c22d06a9a386b58e52cb76cfa54c

URL: https://github.com/llvm/llvm-project/commit/b7528f52c7d8c22d06a9a386b58e52cb76cfa54c
DIFF: https://github.com/llvm/llvm-project/commit/b7528f52c7d8c22d06a9a386b58e52cb76cfa54c.diff

LOG: [mlir][Tensor] Use helper function for `getDroppedDims`

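This helper function is used for both ExtractSliceOp and InsertSliceOp. Also fixes a bug in the implementation of `InsertSliceOp::getDroppedDims`: the old code matched the sizes against the result shape (`getType()`), whereas the helper is now called with the source shape (`getSourceType()`).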

Differential Revision: https://reviews.llvm.org/D147048

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 93db7da27abd..e7fb28794567 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -110,6 +110,48 @@ LogicalResult tensor::getOrCreateDestinations(OpBuilder &b, Location loc,
   return success();
 }
 
+/// Compute the dims of `mixedSizes` that are dropped in `reducedShape`, for a
+/// rank-reducing tensor.extract_slice or rank-extending tensor.insert_slice op.
+static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
+                                           ArrayRef<OpFoldResult> mixedSizes) {
+  llvm::SmallBitVector droppedDims(mixedSizes.size());
+  int64_t shapePos = 0;
+
+  for (const auto &size : enumerate(mixedSizes)) {
+    // Only a size given as a static attribute equal to 1 can be a dropped dim.
+    bool isStaticUnitSize =
+        size.value().is<Attribute>() &&
+        size.value().get<Attribute>().cast<IntegerAttr>().getInt() == 1;
+
+    if (shapePos == static_cast<int64_t>(reducedShape.size())) {
+      // The reduced shape is exhausted, so every remaining size must belong to
+      // a dropped dim.
+      assert(isStaticUnitSize && "expected unit dim");
+      droppedDims.set(size.index());
+      continue;
+    }
+
+    // A size that is not a static 1 is preserved and consumes a reduced dim.
+    if (!isStaticUnitSize) {
+      ++shapePos;
+      continue;
+    }
+
+    // A static 1 matched by a unit dim in the reduced shape is also preserved.
+    if (reducedShape[shapePos] == 1) {
+      ++shapePos;
+      continue;
+    }
+
+    // Otherwise: a static 1 with no matching unit dim, so the dim is dropped.
+    droppedDims.set(size.index());
+  }
+
+  assert(shapePos == static_cast<int64_t>(reducedShape.size()) &&
+         "dimension mismatch");
+  return droppedDims;
+}
+
 //===----------------------------------------------------------------------===//
 // CastOp
 //===----------------------------------------------------------------------===//
@@ -1740,23 +1782,7 @@ LogicalResult ExtractSliceOp::verify() {
 }
 
 llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
-  ArrayRef<int64_t> resultShape = getType().getShape();
-  SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
-  llvm::SmallBitVector droppedDims(mixedSizes.size());
-  unsigned shapePos = 0;
-  for (const auto &size : enumerate(mixedSizes)) {
-    std::optional<int64_t> sizeVal = getConstantIntValue(size.value());
-    // If the size is not 1, or if the current matched dimension of the result
-    // is the same static shape as the size value (which is 1), then the
-    // dimension is preserved.
-    if (!sizeVal || *sizeVal != 1 ||
-        (shapePos < resultShape.size() && resultShape[shapePos] == 1)) {
-      shapePos++;
-      continue;
-    }
-    droppedDims.set(size.index());
-  }
-  return droppedDims;
+  return ::getDroppedDims(getType().getShape(), getMixedSizes());
 }
 
 FailureOr<Value>
@@ -2397,23 +2423,7 @@ struct InsertSliceOpSourceCastInserter final
 } // namespace
 
 llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
-  ArrayRef<int64_t> resultShape = getType().getShape();
-  SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
-  llvm::SmallBitVector droppedDims(mixedSizes.size());
-  unsigned shapePos = 0;
-  for (const auto &size : enumerate(mixedSizes)) {
-    std::optional<int64_t> sizeVal = getConstantIntValue(size.value());
-    // If the size is not 1, or if the current matched dimension of the result
-    // is the same static shape as the size value (which is 1), then the
-    // dimension is preserved.
-    if (!sizeVal || *sizeVal != 1 ||
-        (shapePos < resultShape.size() && resultShape[shapePos] == 1)) {
-      shapePos++;
-      continue;
-    }
-    droppedDims.set(size.index());
-  }
-  return droppedDims;
+  return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
 }
 
 void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,


        


More information about the Mlir-commits mailing list