[Mlir-commits] [mlir] 741f8f2 - [mlir][Tensor][NFC] Better document rank-reducing behavior of ExtractSliceOp and cleanup

Nicolas Vasilache llvmlistbot at llvm.org
Wed Jun 29 07:39:25 PDT 2022


Author: Nicolas Vasilache
Date: 2022-06-29T07:37:58-07:00
New Revision: 741f8f2bede58573560372bc219b2dec9a1d6643

URL: https://github.com/llvm/llvm-project/commit/741f8f2bede58573560372bc219b2dec9a1d6643
DIFF: https://github.com/llvm/llvm-project/commit/741f8f2bede58573560372bc219b2dec9a1d6643.diff

LOG: [mlir][Tensor][NFC] Better document rank-reducing behavior of ExtractSliceOp and cleanup

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
    mlir/lib/Dialect/Bufferization/Transforms/AllocTensorElimination.cpp
    mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index bd9ab3545ccb0..a6de3e9597d72 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -210,6 +210,25 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
     flexibility allows to progressively drop unit dimensions while lowering
     between different flavors of ops on that operate on tensors.
 
+    Verification vs Inference in the rank-reduced case:
+    ===================================================
+    Note that there may be multiple ways to infer a resulting rank-reduced type.
+      e.g. 1x6x1 could potentially rank-reduce to either 1x6 or 6x1 2-D shapes.
+
+    To disambiguate, the inference helpers `inferCanonicalRankReducedResultType`
+    only drop the first unit dimensions, in order:
+      e.g. 1x6x1 rank-reduced to 2-D will infer the 6x1 2-D shape, but not 1x6.
+
+    Verification, however, has access to the result type and need not infer it.
+    The verifier calls `isRankReducedType(getSource(), getResult())` to
+    determine whether the result type is rank-reduced from the source type.
+    This computes a so-called rank-reduction mask, consisting of dropped unit
+    dims, to map the rank-reduced type to the source type by dropping ones:
+      e.g. 1x6 is a rank-reduced version of 1x6x1 by mask {2}
+           6x1 is a rank-reduced version of 1x6x1 by mask {0}
+           1x2x1x4 is a rank-reduced version of 1x1x2x1x1x4x1 by mask {1, 4, 6}
+             (remaining common 1 dimensions are matched eagerly)
+
     Example:
 
     ```
@@ -274,26 +293,43 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
       return getResult().getType().cast<RankedTensorType>();
     }
 
-    /// An extract_slice result type can be fully inferred from the source type
-    /// and the static representation of offsets, sizes and strides. Special
-    /// sentinels encode the dynamic case.
+    /// Compute the rank-reduction mask that can be applied to map the source
+    /// tensor type to the result tensor type by dropping unit dims.
+    llvm::Optional<llvm::SmallDenseSet<unsigned>>
+    computeRankReductionMask() {
+      return ::mlir::computeRankReductionMask(getSourceType().getShape(), 
+                                              getType().getShape());
+    };
+
+    /// An extract_slice result type can be inferred, when it is not
+    /// rank-reduced, from the source type and the static representation of 
+    /// offsets, sizes and strides. Special sentinels encode the dynamic case.
     static RankedTensorType inferResultType(
-      RankedTensorType sourceRankedTensorType,
+      ShapedType sourceShapedTensorType,
       ArrayRef<int64_t> staticOffsets,
       ArrayRef<int64_t> staticSizes,
       ArrayRef<int64_t> staticStrides);
     static RankedTensorType inferResultType(
-      RankedTensorType sourceRankedTensorType,
+      ShapedType sourceShapedTensorType,
       ArrayRef<OpFoldResult> staticOffsets,
       ArrayRef<OpFoldResult> staticSizes,
       ArrayRef<OpFoldResult> staticStrides);
-    static RankedTensorType inferRankReducedResultType(
+
+    /// If the rank is reduced (i.e. the desired result rank is smaller than the
+    /// number of sizes), drop as many size-1 dimensions as needed to produce an
+    /// inferred type with the desired rank.
+    ///
+    /// Note that there may be multiple ways to compute this rank-reduced type:
+    ///   e.g. 1x6x1 can rank-reduce to either 1x6 or 6x1 2-D tensors.
+    ///
+    /// To disambiguate, this function always drops the first size-1 dimensions.
+    static RankedTensorType inferCanonicalRankReducedResultType(
       unsigned resultRank,
       RankedTensorType sourceRankedTensorType,
       ArrayRef<int64_t> staticOffsets,
       ArrayRef<int64_t> staticSizes,
       ArrayRef<int64_t> staticStrides);
-    static RankedTensorType inferRankReducedResultType(
+    static RankedTensorType inferCanonicalRankReducedResultType(
       unsigned resultRank,
       RankedTensorType sourceRankedTensorType,
       ArrayRef<OpFoldResult> staticOffsets,

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/AllocTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/AllocTensorElimination.cpp
index 10760685301fd..6c6bcabb499d9 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/AllocTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/AllocTensorElimination.cpp
@@ -228,7 +228,7 @@ mlir::bufferization::insertSliceAnchoredAllocTensorEliminationStep(
                 return b.create<tensor::DimOp>(loc, target, dim).getResult();
               return b.getIndexAttr(shapedType.getDimSize(dim));
             });
-        auto t = tensor::ExtractSliceOp::inferRankReducedResultType(
+        auto t = tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
             insertOp.getSourceType().getRank(),
             insertOp.getDest().getType().cast<RankedTensorType>(), mixedOffsets,
             mixedSizes, mixedStrides);

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 970b628a15c4b..e1e7ed76d23cc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -499,10 +499,11 @@ struct UseRankReducedExtractSliceOp
     if (!reassociation ||
         reassociation->size() == static_cast<size_t>(resultType.getRank()))
       return failure();
-    auto rankReducedType = tensor::ExtractSliceOp::inferRankReducedResultType(
-                               reassociation->size(), sliceOp.getSourceType(),
-                               offsets, sizes, strides)
-                               .cast<RankedTensorType>();
+    auto rankReducedType =
+        tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
+            reassociation->size(), sliceOp.getSourceType(), offsets, sizes,
+            strides)
+            .cast<RankedTensorType>();
 
     Location loc = sliceOp.getLoc();
     Value newSlice = rewriter.create<tensor::ExtractSliceOp>(

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 897af9fcee6f3..305e8f7e42394 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -957,25 +957,24 @@ OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
 // ExtractSliceOp
 //===----------------------------------------------------------------------===//
 
-/// An extract_slice op result type can be fully inferred from the source type
-/// and the static representation of offsets, sizes and strides. Special
-/// sentinels encode the dynamic case.
+/// An extract_slice result type can be inferred, when it is not
+/// rank-reduced, from the source type and the static representation of
+/// offsets, sizes and strides. Special sentinels encode the dynamic case.
 RankedTensorType ExtractSliceOp::inferResultType(
-    RankedTensorType sourceRankedTensorType, ArrayRef<int64_t> staticOffsets,
+    ShapedType sourceShapedTensorType, ArrayRef<int64_t> staticOffsets,
     ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
   // An extract_slice op may specify only a leading subset of offset/sizes/
   // strides in which case we complete with offset=0, sizes from memref type and
   // strides=1.
-  unsigned rank = sourceRankedTensorType.getRank();
-  (void)rank;
-  assert(staticSizes.size() == rank &&
+  assert(static_cast<int64_t>(staticSizes.size()) ==
+             sourceShapedTensorType.getRank() &&
          "unexpected staticSizes not equal to rank of source");
   return RankedTensorType::get(staticSizes,
-                               sourceRankedTensorType.getElementType());
+                               sourceShapedTensorType.getElementType());
 }
 
 RankedTensorType ExtractSliceOp::inferResultType(
-    RankedTensorType sourceRankedTensorType, ArrayRef<OpFoldResult> offsets,
+    ShapedType sourceShapedTensorType, ArrayRef<OpFoldResult> offsets,
     ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
@@ -985,26 +984,33 @@ RankedTensorType ExtractSliceOp::inferResultType(
                              ShapedType::kDynamicSize);
   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                              ShapedType::kDynamicStrideOrOffset);
-  return ExtractSliceOp::inferResultType(sourceRankedTensorType, staticOffsets,
+  return ExtractSliceOp::inferResultType(sourceShapedTensorType, staticOffsets,
                                          staticSizes, staticStrides);
 }
 
-/// An extract_slice op result type can be fully inferred from the source type
-/// and the static representation of offsets, sizes and strides. Special
-/// sentinels encode the dynamic case.
-RankedTensorType ExtractSliceOp::inferRankReducedResultType(
-    unsigned resultRank, RankedTensorType sourceRankedTensorType,
+/// If the rank is reduced (i.e. the desiredResultRank is smaller than the
+/// number of sizes), drop as many size-1 dimensions as needed to produce an
+/// inferred type with the desired rank.
+///
+/// Note that there may be multiple ways to compute this rank-reduced type:
+///   e.g. 1x6x1 can rank-reduce to either 1x6 or 6x1 2-D tensors.
+///
+/// To disambiguate, this function always drops the first size-1 dimensions.
+RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
+    unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
     ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
     ArrayRef<int64_t> strides) {
+  // Type inferred in the absence of rank-reducing behavior.
   auto inferredType =
       inferResultType(sourceRankedTensorType, offsets, sizes, strides)
           .cast<RankedTensorType>();
-  int rankDiff = inferredType.getRank() - resultRank;
+  int rankDiff = inferredType.getRank() - desiredResultRank;
   if (rankDiff > 0) {
     auto shape = inferredType.getShape();
     llvm::SmallBitVector dimsToProject =
         getPositionsOfShapeOne(rankDiff, shape);
     SmallVector<int64_t> projectedShape;
+    // Best effort rank-reducing: drop 1s in order.
     for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
       if (!dimsToProject.test(pos))
         projectedShape.push_back(shape[pos]);
@@ -1014,8 +1020,8 @@ RankedTensorType ExtractSliceOp::inferRankReducedResultType(
   return inferredType;
 }
 
-RankedTensorType ExtractSliceOp::inferRankReducedResultType(
-    unsigned resultRank, RankedTensorType sourceRankedTensorType,
+RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
+    unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
     ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
     ArrayRef<OpFoldResult> strides) {
   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
@@ -1026,8 +1032,8 @@ RankedTensorType ExtractSliceOp::inferRankReducedResultType(
                              ShapedType::kDynamicSize);
   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                              ShapedType::kDynamicStrideOrOffset);
-  return ExtractSliceOp::inferRankReducedResultType(
-      resultRank, sourceRankedTensorType, staticOffsets, staticSizes,
+  return ExtractSliceOp::inferCanonicalRankReducedResultType(
+      desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
       staticStrides);
 }
 
@@ -1123,26 +1129,6 @@ LogicalResult ExtractSliceOp::verify() {
   return produceSliceErrorMsg(result, *this, expectedType);
 }
 
-/// Infer the canonical type of the result of an extract_slice op. Returns a
-/// type with rank `resultRank` that is either the rank of the rank-reduced
-/// type, or the non-rank-reduced type.
-static RankedTensorType
-getCanonicalSliceResultType(unsigned resultRank, RankedTensorType sourceType,
-                            ArrayRef<OpFoldResult> mixedOffsets,
-                            ArrayRef<OpFoldResult> mixedSizes,
-                            ArrayRef<OpFoldResult> mixedStrides) {
-  auto resultType =
-      ExtractSliceOp::inferRankReducedResultType(
-          resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
-          .cast<RankedTensorType>();
-  if (resultType.getRank() != resultRank) {
-    resultType = ExtractSliceOp::inferResultType(sourceType, mixedOffsets,
-                                                 mixedSizes, mixedStrides)
-                     .cast<RankedTensorType>();
-  }
-  return resultType;
-}
-
 llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
   ArrayRef<int64_t> resultShape = getType().getShape();
   SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
@@ -1205,7 +1191,7 @@ class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
 
   LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
                                 PatternRewriter &rewriter) const override {
-    // Any constant operand, just return to let SubViewOpConstantFolder kick in.
+    // Any constant operand, just return to let the constant folder kick in.
     if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
           return matchPattern(operand, matchConstantIndex());
         }))
@@ -1219,10 +1205,11 @@ class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
       return failure();
 
     /// Deduce the type of the result to use for the canonicalized operation.
-    RankedTensorType resultType = getCanonicalSliceResultType(
-        sliceOp.getType().getRank(), sliceOp.getSourceType(),
-        sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
-        sliceOp.getMixedStrides());
+    RankedTensorType resultType =
+        ExtractSliceOp::inferCanonicalRankReducedResultType(
+            sliceOp.getType().getRank(), sliceOp.getSourceType(),
+            sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
+            sliceOp.getMixedStrides());
     Value newSlice = rewriter.create<ExtractSliceOp>(
         sliceOp.getLoc(), resultType, castOp.getSource(), sliceOp.getOffsets(),
         sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
@@ -1366,9 +1353,9 @@ struct SliceReturnTypeCanonicalizer {
                               ArrayRef<OpFoldResult> mixedOffsets,
                               ArrayRef<OpFoldResult> mixedSizes,
                               ArrayRef<OpFoldResult> mixedStrides) {
-    return getCanonicalSliceResultType(op.getType().getRank(),
-                                       op.getSourceType(), mixedOffsets,
-                                       mixedSizes, mixedStrides);
+    return ExtractSliceOp::inferCanonicalRankReducedResultType(
+        op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
+        mixedStrides);
   }
 };
 
@@ -1506,9 +1493,8 @@ verifyInsertSliceOp(ShapedType srcType, ShapedType dstType,
                     ArrayAttr staticStrides,
                     ShapedType *expectedType = nullptr) {
   // insert_slice is the inverse of extract_slice, use the same type inference.
-  auto expected = ExtractSliceOp::inferRankReducedResultType(
-                      srcType.getRank(), dstType.cast<RankedTensorType>(),
-                      extractFromI64ArrayAttr(staticOffsets),
+  auto expected = ExtractSliceOp::inferResultType(
+                      dstType, extractFromI64ArrayAttr(staticOffsets),
                       extractFromI64ArrayAttr(staticSizes),
                       extractFromI64ArrayAttr(staticStrides))
                       .cast<ShapedType>();
@@ -1600,7 +1586,7 @@ class InsertSliceOpConstantArgumentFolder final
     canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
 
     // Create the new op in canonical form.
-    auto sourceType = ExtractSliceOp::inferRankReducedResultType(
+    auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
         insertSliceOp.getSourceType().getRank(), insertSliceOp.getType(),
         mixedOffsets, mixedSizes, mixedStrides);
     Value toInsert = insertSliceOp.getSource();


        


More information about the Mlir-commits mailing list