[Mlir-commits] [mlir] b3c227a - [mlir] Better support for rank-reducing subview / subtensor type inference.

Nicolas Vasilache llvmlistbot at llvm.org
Fri Feb 19 00:35:10 PST 2021


Author: Nicolas Vasilache
Date: 2021-02-19T08:30:50Z
New Revision: b3c227a25a231248e3752918c2cac1a7b9414ef1

URL: https://github.com/llvm/llvm-project/commit/b3c227a25a231248e3752918c2cac1a7b9414ef1
DIFF: https://github.com/llvm/llvm-project/commit/b3c227a25a231248e3752918c2cac1a7b9414ef1.diff

LOG: [mlir] Better support for rank-reducing subview / subtensor type inference.

Differential Revision: https://reviews.llvm.org/D96995

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Dialect/Standard/canonicalize.mlir
    mlir/test/Transforms/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 64279c8fce3c..82b4717b6bc1 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2788,6 +2788,16 @@ def SubViewOp : BaseOpWithOffsetSizesAndStrides<
                                 ArrayRef<OpFoldResult> staticOffsets,
                                 ArrayRef<OpFoldResult> staticSizes,
                                 ArrayRef<OpFoldResult> staticStrides);
+    static Type inferRankReducedResultType(unsigned resultRank,
+                                           MemRefType sourceMemRefType,
+                                           ArrayRef<int64_t> staticOffsets,
+                                           ArrayRef<int64_t> staticSizes,
+                                           ArrayRef<int64_t> staticStrides);
+    static Type inferRankReducedResultType(unsigned resultRank,
+                                           MemRefType sourceMemRefType,
+                                           ArrayRef<OpFoldResult> staticOffsets,
+                                           ArrayRef<OpFoldResult> staticSizes,
+                                           ArrayRef<OpFoldResult> staticStrides);
 
     /// Return the expected rank of each of the`static_offsets`, `static_sizes`
     /// and `static_strides` attributes.
@@ -2914,6 +2924,16 @@ def SubTensorOp : BaseOpWithOffsetSizesAndStrides<
                                 ArrayRef<OpFoldResult> staticOffsets,
                                 ArrayRef<OpFoldResult> staticSizes,
                                 ArrayRef<OpFoldResult> staticStrides);
+    static Type inferRankReducedResultType(unsigned resultRank,
+                                           RankedTensorType sourceRankedTensorType,
+                                           ArrayRef<int64_t> staticOffsets,
+                                           ArrayRef<int64_t> staticSizes,
+                                           ArrayRef<int64_t> staticStrides);
+    static Type inferRankReducedResultType(unsigned resultRank,
+                                           RankedTensorType sourceRankedTensorType,
+                                           ArrayRef<OpFoldResult> staticOffsets,
+                                           ArrayRef<OpFoldResult> staticSizes,
+                                           ArrayRef<OpFoldResult> staticStrides);
 
     /// Return the expected rank of each of the`static_offsets`, `static_sizes`
     /// and `static_strides` attributes.
@@ -3027,6 +3047,7 @@ def SubTensorInsertOp : BaseOpWithOffsetSizesAndStrides<
     /// and `strides` operands.
     static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 2; }
   }];
+
   let hasFolder = 1;
 }
 

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index e916de4c0658..084d3fdfb2bf 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2891,8 +2891,68 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
   dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
                              staticStrides, ShapedType::kDynamicStrideOrOffset);
   return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
-                                    staticSizes, staticStrides)
-      .cast<MemRefType>();
+                                    staticSizes, staticStrides);
+}
+
+/// Record in `dimsToProject` the indices of the first `rank` entries of
+/// `shape` whose static extent is exactly 1, scanning from index 0 upwards.
+/// The scan stops once `rank` indices have been recorded. If `shape` contains
+/// fewer unit entries than `rank`, every unit entry is recorded and the set
+/// ends up with fewer than `rank` elements; callers must check for that.
+static void
+getPositionsOfShapeOne(unsigned rank, ArrayRef<int64_t> shape,
+                       llvm::SmallDenseSet<unsigned> &dimsToProject) {
+  dimsToProject.reserve(rank);
+  unsigned stillNeeded = rank;
+  for (unsigned idx = 0, n = shape.size(); idx < n && stillNeeded != 0; ++idx) {
+    if (shape[idx] != 1)
+      continue;
+    dimsToProject.insert(idx);
+    --stillNeeded;
+  }
+}
+
+/// Infer the type of a subview of `sourceMemRefType` with the given static
+/// offsets, sizes and strides (dynamic entries are encoded with ShapedType
+/// sentinels), then rank-reduce it to `resultRank` by dropping the
+/// lowest-index dimensions of static size 1. When the inferred type carries a
+/// layout map, that map is projected onto the remaining dimensions.
+Type SubViewOp::inferRankReducedResultType(
+    unsigned resultRank, MemRefType sourceMemRefType,
+    ArrayRef<int64_t> leadingStaticOffsets,
+    ArrayRef<int64_t> leadingStaticSizes,
+    ArrayRef<int64_t> leadingStaticStrides) {
+  auto inferredType =
+      inferResultType(sourceMemRefType, leadingStaticOffsets,
+                      leadingStaticSizes, leadingStaticStrides)
+          .cast<MemRefType>();
+  assert(inferredType.getRank() >= static_cast<int64_t>(resultRank) &&
+         "expected the inferred rank to be at least the requested result rank");
+  // getRank() is int64_t; keep the difference in the same width.
+  int64_t rankDiff = inferredType.getRank() - resultRank;
+  if (rankDiff > 0) {
+    auto shape = inferredType.getShape();
+    llvm::SmallDenseSet<unsigned> dimsToProject;
+    getPositionsOfShapeOne(static_cast<unsigned>(rankDiff), shape,
+                           dimsToProject);
+    SmallVector<int64_t> projectedShape;
+    for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
+      if (!dimsToProject.contains(pos))
+        projectedShape.push_back(shape[pos]);
+    // Only static unit dimensions can be dropped; if the inferred shape has
+    // fewer of them than `rankDiff`, the requested rank is unreachable.
+    assert(projectedShape.size() == resultRank &&
+           "expected enough unit dimensions to reach the result rank");
+
+    // Project the layout map (if any) onto the surviving dimensions.
+    AffineMap map;
+    auto maps = inferredType.getAffineMaps();
+    if (!maps.empty() && maps.front())
+      map = getProjectedMap(maps.front(), dimsToProject);
+    inferredType =
+        MemRefType::get(projectedShape, inferredType.getElementType(), map,
+                        inferredType.getMemorySpace());
+  }
+  return inferredType;
+}
+
+/// Overload of `inferRankReducedResultType` taking mixed static/dynamic
+/// offsets, sizes and strides. The static parts are split out with
+/// `dispatchIndexOpFoldResults`, which is passed
+/// `ShapedType::kDynamicStrideOrOffset` / `ShapedType::kDynamicSize` as the
+/// sentinel for dynamic entries. The result is then forwarded to the
+/// int64_t-based overload.
+Type SubViewOp::inferRankReducedResultType(
+    unsigned resultRank, MemRefType sourceMemRefType,
+    ArrayRef<OpFoldResult> leadingStaticOffsets,
+    ArrayRef<OpFoldResult> leadingStaticSizes,
+    ArrayRef<OpFoldResult> leadingStaticStrides) {
+  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
+  // Type inference only needs the static vectors; these only receive the
+  // dynamic Values that dispatchIndexOpFoldResults splits out.
+  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
+  dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
+                             staticOffsets, ShapedType::kDynamicStrideOrOffset);
+  dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
+                             ShapedType::kDynamicSize);
+  dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
+                             staticStrides, ShapedType::kDynamicStrideOrOffset);
+  return SubViewOp::inferRankReducedResultType(
+      resultRank, sourceMemRefType, staticOffsets, staticSizes,
+      staticStrides);
 }
 
 // Build a SubViewOp with mixed static and dynamic entries and custom result
@@ -3407,29 +3467,11 @@ class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
     /// Deduce the resultType of the SubViewOp using `inferSubViewResultType` on
     /// the cast source operand type and the SubViewOp static information. This
     /// is the resulting type if the MemRefCastOp were folded.
-    auto resultType = SubViewOp::inferResultType(
-                          castOp.source().getType().cast<MemRefType>(),
-                          extractFromI64ArrayAttr(subViewOp.static_offsets()),
-                          extractFromI64ArrayAttr(subViewOp.static_sizes()),
-                          extractFromI64ArrayAttr(subViewOp.static_strides()))
-                          .cast<MemRefType>();
-    uint32_t rankDiff =
-        subViewOp.getSourceType().getRank() - subViewOp.getType().getRank();
-    if (rankDiff > 0) {
-      auto shape = resultType.getShape();
-      auto projectedShape = shape.drop_front(rankDiff);
-      AffineMap map;
-      auto maps = resultType.getAffineMaps();
-      if (!maps.empty() && maps.front()) {
-        auto optionalUnusedDimsMask =
-            computeRankReductionMask(shape, projectedShape);
-        llvm::SmallDenseSet<unsigned> dimsToProject =
-            optionalUnusedDimsMask.getValue();
-        map = getProjectedMap(maps.front(), dimsToProject);
-      }
-      resultType = MemRefType::get(projectedShape, resultType.getElementType(),
-                                   map, resultType.getMemorySpace());
-    }
+    auto resultType = SubViewOp::inferRankReducedResultType(
+        subViewOp.getType().getRank(),
+        castOp.source().getType().cast<MemRefType>(),
+        subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
+        subViewOp.getMixedStrides());
     Value newSubView = rewriter.create<SubViewOp>(
         subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
         subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
@@ -3492,8 +3534,52 @@ Type SubTensorOp::inferResultType(RankedTensorType sourceRankedTensorType,
   dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
                              staticStrides, ShapedType::kDynamicStrideOrOffset);
   return SubTensorOp::inferResultType(sourceRankedTensorType, staticOffsets,
-                                      staticSizes, staticStrides)
-      .cast<RankedTensorType>();
+                                      staticSizes, staticStrides);
+}
+
+/// Infer the type of a subtensor of `sourceRankedTensorType` with the given
+/// static offsets, sizes and strides (dynamic entries are encoded with
+/// ShapedType sentinels), then rank-reduce it to `resultRank` by dropping the
+/// lowest-index dimensions of static size 1.
+Type SubTensorOp::inferRankReducedResultType(
+    unsigned resultRank, RankedTensorType sourceRankedTensorType,
+    ArrayRef<int64_t> leadingStaticOffsets,
+    ArrayRef<int64_t> leadingStaticSizes,
+    ArrayRef<int64_t> leadingStaticStrides) {
+  auto inferredType =
+      inferResultType(sourceRankedTensorType, leadingStaticOffsets,
+                      leadingStaticSizes, leadingStaticStrides)
+          .cast<RankedTensorType>();
+  assert(inferredType.getRank() >= static_cast<int64_t>(resultRank) &&
+         "expected the inferred rank to be at least the requested result rank");
+  // getRank() is int64_t; keep the difference in the same width.
+  int64_t rankDiff = inferredType.getRank() - resultRank;
+  if (rankDiff > 0) {
+    auto shape = inferredType.getShape();
+    llvm::SmallDenseSet<unsigned> dimsToProject;
+    getPositionsOfShapeOne(static_cast<unsigned>(rankDiff), shape,
+                           dimsToProject);
+    SmallVector<int64_t> projectedShape;
+    for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
+      if (!dimsToProject.contains(pos))
+        projectedShape.push_back(shape[pos]);
+    // Only static unit dimensions can be dropped; if the inferred shape has
+    // fewer of them than `rankDiff`, the requested rank is unreachable.
+    assert(projectedShape.size() == resultRank &&
+           "expected enough unit dimensions to reach the result rank");
+    inferredType =
+        RankedTensorType::get(projectedShape, inferredType.getElementType());
+  }
+  return inferredType;
+}
+
+/// Convenience overload that accepts mixed static/dynamic offsets, sizes and
+/// strides. It extracts their static parts, using ShapedType sentinels for
+/// dynamic entries, and delegates to the int64_t-based overload.
+Type SubTensorOp::inferRankReducedResultType(
+    unsigned resultRank, RankedTensorType sourceRankedTensorType,
+    ArrayRef<OpFoldResult> leadingStaticOffsets,
+    ArrayRef<OpFoldResult> leadingStaticSizes,
+    ArrayRef<OpFoldResult> leadingStaticStrides) {
+  // The dynamic Values are collected only because the helper requires a sink.
+  SmallVector<Value> ignoredDynOffsets, ignoredDynSizes, ignoredDynStrides;
+  SmallVector<int64_t> offsetInts, sizeInts, strideInts;
+  dispatchIndexOpFoldResults(leadingStaticOffsets, ignoredDynOffsets,
+                             offsetInts, ShapedType::kDynamicStrideOrOffset);
+  dispatchIndexOpFoldResults(leadingStaticSizes, ignoredDynSizes, sizeInts,
+                             ShapedType::kDynamicSize);
+  dispatchIndexOpFoldResults(leadingStaticStrides, ignoredDynStrides,
+                             strideInts, ShapedType::kDynamicStrideOrOffset);
+  return inferRankReducedResultType(resultRank, sourceRankedTensorType,
+                                    offsetInts, sizeInts, strideInts);
 }
 
 // Build a SubTensorOp with mixed static and dynamic entries and custom result
@@ -3571,11 +3657,65 @@ static LogicalResult verify(SubTensorOp op) {
   return produceSubViewErrorMsg(result, op, expectedType);
 }
 
+namespace {
+/// Pattern to rewrite a subtensor op whose source is produced by a
+/// tensor::CastOp. When `canFoldIntoConsumerOp` is true, this pushes the
+/// tensor.cast past its consuming subtensor, so the subtensor reads the cast's
+/// (more static) source.
+///
+/// Example:
+/// ```
+///   %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
+///   %1 = subtensor %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to tensor<3x4xf32>
+/// ```
+/// is rewritten into:
+/// ```
+///   %0 = subtensor %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to tensor<3x4xf32>
+///   %1 = tensor.cast %0: tensor<3x4xf32> to tensor<3x4xf32>
+/// ```
+class SubTensorOpCastFolder final : public OpRewritePattern<SubTensorOp> {
+public:
+  using OpRewritePattern<SubTensorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SubTensorOp subTensorOp,
+                                PatternRewriter &rewriter) const override {
+    // If any operand is a constant index, bail out and let
+    // OpWithOffsetSizesAndStridesConstantArgumentFolder fold it into the
+    // static attributes first.
+    if (llvm::any_of(subTensorOp.getOperands(), [](Value operand) {
+          return matchPattern(operand, m_ConstantIndex());
+        }))
+      return failure();
+
+    auto castOp = subTensorOp.source().getDefiningOp<tensor::CastOp>();
+    if (!castOp)
+      return failure();
+
+    if (!canFoldIntoConsumerOp(castOp))
+      return failure();
+
+    // Type inference below needs a ranked source; bail out instead of
+    // asserting if the cast source is not a RankedTensorType.
+    auto castSourceType =
+        castOp.source().getType().dyn_cast<RankedTensorType>();
+    if (!castSourceType)
+      return failure();
+
+    /// Deduce the resultType of SubTensorOp with `inferRankReducedResultType`
+    /// on the cast source operand type and the SubTensorOp static information.
+    /// This is the resulting type if the tensor::CastOp were folded and
+    /// rank-reduced to the desired result rank.
+    auto resultType = SubTensorOp::inferRankReducedResultType(
+        subTensorOp.getType().getRank(), castSourceType,
+        subTensorOp.getMixedOffsets(), subTensorOp.getMixedSizes(),
+        subTensorOp.getMixedStrides());
+    Value newSubTensor = rewriter.create<SubTensorOp>(
+        subTensorOp.getLoc(), resultType, castOp.source(),
+        subTensorOp.offsets(), subTensorOp.sizes(), subTensorOp.strides(),
+        subTensorOp.static_offsets(), subTensorOp.static_sizes(),
+        subTensorOp.static_strides());
+    // Cast back to the original result type so existing users are unaffected.
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(
+        subTensorOp, subTensorOp.getType(), newSubTensor);
+    return success();
+  }
+};
+} // namespace
+
 void SubTensorOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                               MLIRContext *context) {
-  results
-      .insert<OpWithOffsetSizesAndStridesConstantArgumentFolder<SubTensorOp>>(
-          context);
+  // Register the constant offset/size/stride argument folder together with
+  // SubTensorOpCastFolder, which moves a producing tensor.cast past the
+  // subtensor when the cast can be folded into its consumer.
+  results.insert<OpWithOffsetSizesAndStridesConstantArgumentFolder<SubTensorOp>,
+                 SubTensorOpCastFolder>(context);
 }
 
 //

diff  --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index 9247152e8677..5c437ae3dda4 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -146,13 +146,13 @@ func @tensor_cast_to_memref(%arg0 : tensor<4x6x16x32xi8>) ->
 
 // CHECK-LABEL: func @subview_of_memcast
 //  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
-//       CHECK:   %[[S:.+]] = subview %arg0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}>
+//       CHECK:   %[[S:.+]] = subview %arg0[0, 1, 0] [1, 1, 16] [1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}>
 //       CHECK:   %[[M:.+]] = memref_cast %[[S]] : memref<16x32xi8, #{{.*}}> to memref<16x32xi8, #{{.*}}>
 //       CHECK:   return %[[M]] : memref<16x32xi8, #{{.*}}>
 func @subview_of_memcast(%arg : memref<4x6x16x32xi8>) ->
   memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>{
   %0 = memref_cast %arg : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
-  %1 = subview %0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] :
+  %1 = subview %0[0, 1, 0] [1, 1, 16] [1, 1, 1] :
     memref<?x?x16x32xi8> to
     memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
   return %1 : memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
@@ -176,3 +176,14 @@ func @trivial_subtensor_insert(%arg0 : tensor<4x6x16x32xi8>, %arg1 : tensor<4x6x
   return %0 : tensor<4x6x16x32xi8>
 }
 
+// CHECK-LABEL: func @rank_reducing_tensor_of_cast
+//  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
+//       CHECK:   %[[S:.+]] = subtensor %[[ARG0]][0, 1, 0] [1, 1, 16] [1, 1, 1] : tensor<4x6x16x32xi8> to tensor<16x32xi8>
+// Tensor cast is moved after subtensor and then gets canonicalized away.
+//   CHECK-NOT:   tensor.cast
+//       CHECK:   return %[[S]] : tensor<16x32xi8>
+func @rank_reducing_tensor_of_cast(%arg : tensor<4x6x16x32xi8>) -> tensor<16x32xi8> {
+  %0 = tensor.cast %arg : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
+  %1 = subtensor %0[0, 1, 0] [1, 1, 16] [1, 1, 1] : tensor<?x?x16x32xi8> to tensor<16x32xi8>
+  return %1 : tensor<16x32xi8>
+}

diff  --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index 62c07dd8a063..3bc3eeee8354 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -1034,8 +1034,8 @@ func @memref_cast_folding_subview_static(%V: memref<16x16xf32>, %a: index, %b: i
 
 // CHECK-LABEL: func @subtensor
 // CHECK-SAME: %[[ARG0:[0-9a-z]*]]: index, %[[ARG1:[0-9a-z]*]]: index
-func @subtensor(%t: tensor<8x16x4xf32>, %arg0 : index, %arg1 : index) 
-  -> tensor<?x?x?xf32> 
+func @subtensor(%t: tensor<8x16x4xf32>, %arg0 : index, %arg1 : index)
+  -> tensor<?x?x?xf32>
 {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
@@ -1045,16 +1045,18 @@ func @subtensor(%t: tensor<8x16x4xf32>, %arg0 : index, %arg1 : index)
 
   // CHECK: subtensor %{{.*}}[0, 0, 0] [7, 11, 2] [1, 1, 1] :
   // CHECK-SAME: tensor<8x16x4xf32> to tensor<7x11x2xf32>
-  // CHECK: tensor.cast %{{.*}} : tensor<7x11x2xf32> to tensor<?x?x?xf32>
+  // tensor.cast gets folded away in consumer.
+  //  CHECK-NOT: tensor.cast
   %1 = subtensor %t[%c0, %c0, %c0] [%c7, %c11, %c2] [%c1, %c1, %c1]
     : tensor<8x16x4xf32> to tensor<?x?x?xf32>
 
   // Test: subtensor with one dynamic operand can also be folded.
   // CHECK: subtensor %{{.*}}[0, 0, 0] [2, %[[ARG0]], 2] [1, 1, 1] :
-  // CHECK-SAME: tensor<?x?x?xf32> to tensor<2x?x2xf32>
+  // CHECK-SAME: tensor<7x11x2xf32> to tensor<2x?x2xf32>
   // CHECK: tensor.cast %{{.*}} : tensor<2x?x2xf32> to tensor<?x?x?xf32>
   %2 = subtensor %1[%c0, %c0, %c0] [%c2, %arg0, %c2] [%c1, %c1, %c1]
     : tensor<?x?x?xf32> to tensor<?x?x?xf32>
 
   return %2 : tensor<?x?x?xf32>
 }
+


        


More information about the Mlir-commits mailing list