[Mlir-commits] [mlir] 4281946 - [mlir][Tensor] Add ReifyRankedShapedTypeOpInterface to tensor.extract_slice.

llvmlistbot at llvm.org
Thu Oct 7 17:10:59 PDT 2021


Author: MaheshRavishankar
Date: 2021-10-07T17:10:35-07:00
New Revision: 4281946390989a0392e42e3f02b9d93a80234674

URL: https://github.com/llvm/llvm-project/commit/4281946390989a0392e42e3f02b9d93a80234674
DIFF: https://github.com/llvm/llvm-project/commit/4281946390989a0392e42e3f02b9d93a80234674.diff

LOG: [mlir][Tensor] Add ReifyRankedShapedTypeOpInterface to tensor.extract_slice.

Differential Revision: https://reviews.llvm.org/D111263
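
With the interface in place, passes such as `resolve-shaped-type-result-dims`
can rewrite `tensor.dim` ops on the result of a `tensor.extract_slice` in terms
of the slice's size operands. A minimal sketch of the effect (function and
value names are illustrative, not from this patch):

  func @sketch(%t : tensor<?x?xf32>, %sz0 : index, %sz1 : index) -> index {
    %c0 = constant 0 : index
    %0 = tensor.extract_slice %t[0, 0] [%sz0, %sz1] [1, 1] :
        tensor<?x?xf32> to tensor<?x?xf32>
    // reifyResultShapes lets the pass resolve this tensor.dim to %sz0.
    %d0 = tensor.dim %0, %c0 : tensor<?x?xf32>
    return %d0 : index
  }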

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/Tensor/resolve-shaped-type-result-dims.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 78755c898835a..4f84bcf7270d4 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -158,8 +158,10 @@ def Tensor_ExtractOp : Tensor_Op<"extract",
 //===----------------------------------------------------------------------===//
 
 def Tensor_ExtractSliceOp : BaseOpWithOffsetSizesAndStrides<
-    Tensor_Dialect, "extract_slice", [NoSideEffect, AttrSizedOperandSegments,
-                                      OffsetSizeAndStrideOpInterface]> {
+    Tensor_Dialect, "extract_slice",
+    [NoSideEffect, AttrSizedOperandSegments,
+     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+     OffsetSizeAndStrideOpInterface]> {
   let summary = "extract slice operation";
   let description = [{
     The "extract_slice" operation extracts a tensor from another tensor as
@@ -284,6 +286,11 @@ def Tensor_ExtractSliceOp : BaseOpWithOffsetSizesAndStrides<
     /// Return the number of leading operands before the `offsets`, `sizes`
     /// and `strides` operands.
     static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; }
+
+    /// Return the dimensions of the source that are dropped in the
+    /// result when the result is rank-reduced.
+    llvm::SmallDenseSet<unsigned> getDroppedDims();
+
   }];
 
   let hasCanonicalizer = 1;

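The `DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>` entry above
declares `reifyResultShapes` on the generated op class; the definition is added
to TensorOps.cpp below. The subtlety is rank reduction: unit-size dimensions of
the slice may be dropped from the result type. For example (adapted from the
new tests):

  %0 = tensor.extract_slice %arg0[0, 0, 0] [1, %arg1, 1] [1, 1, 1] :
      tensor<?x?x?xf32> to tensor<?xf32>
  // Source dimensions 0 and 2 are dropped here: getDroppedDims() would
  // return {0, 2}, so only the size %arg1 is reified for the result.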
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 2a55223ca9a8b..dc94c27c818c8 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -277,9 +277,12 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
   unsigned unsignedIndex = index.getValue().getZExtValue();
 
   if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
-    assert(sliceOp.isDynamicSize(unsignedIndex) &&
-           "Expected dynamic slice size");
-    return sliceOp.getDynamicSize(unsignedIndex);
+    // Fold only for non-rank-reduced ops. For the rank-reduced version, rely
+    // on the `resolve-shaped-type-result-dims` pass.
+    if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
+        sliceOp.isDynamicSize(unsignedIndex)) {
+      return {sliceOp.getDynamicSize(unsignedIndex)};
+    }
   }
 
   // dim(cast) -> dim
@@ -895,6 +898,46 @@ getCanonicalSliceResultType(unsigned resultRank, RankedTensorType sourceType,
   return resultType;
 }
 
+llvm::SmallDenseSet<unsigned> ExtractSliceOp::getDroppedDims() {
+  llvm::SmallDenseSet<unsigned> droppedDims;
+  ArrayRef<int64_t> resultShape = getType().getShape();
+  SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
+  unsigned shapePos = 0;
+  for (auto size : enumerate(mixedSizes)) {
+    Optional<int64_t> sizeVal = getConstantIntValue(size.value());
+    // A dimension is preserved if its size is not statically 1, or if it is 1
+    // but the matching result dimension is also statically 1 (i.e., the unit
+    // dimension is kept in the result rather than dropped).
+    if (!sizeVal || sizeVal.getValue() != 1 ||
+        (shapePos < resultShape.size() && resultShape[shapePos] == 1)) {
+      shapePos++;
+      continue;
+    }
+    droppedDims.insert(size.index());
+  }
+  return droppedDims;
+}
+
+LogicalResult ExtractSliceOp::reifyResultShapes(
+    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+  reifiedReturnShapes.resize(1);
+  reifiedReturnShapes[0].reserve(getType().getRank());
+  SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
+  llvm::SmallDenseSet<unsigned> droppedDims = getDroppedDims();
+  Location loc = getLoc();
+  for (auto size : enumerate(mixedSizes)) {
+    if (droppedDims.count(size.index()))
+      continue;
+    if (auto attr = size.value().dyn_cast<Attribute>()) {
+      reifiedReturnShapes[0].push_back(builder.create<ConstantIndexOp>(
+          loc, attr.cast<IntegerAttr>().getInt()));
+      continue;
+    }
+    reifiedReturnShapes[0].push_back(size.value().get<Value>());
+  }
+  return success();
+}
+
 namespace {
 /// Pattern to rewrite an extract_slice op with tensor::Cast arguments.
 /// This essentially pushes memref_cast past its consuming slice when

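A static slice size of 1 is ambiguous on its own: the unit dimension may be
dropped by rank reduction, or kept as a static 1 in the result type.
getDroppedDims() disambiguates by walking the result shape in parallel with the
mixed sizes, as in this case from the tests:

  // Sizes are [1, %arg1, 1], but the result keeps a leading unit dimension,
  // so only source dimension 2 is dropped: getDroppedDims() returns {2}.
  %0 = tensor.extract_slice %arg0[0, 0, 0] [1, %arg1, 1] [1, 1, 1] :
      tensor<?x?x?xf32> to tensor<1x?xf32>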
diff --git a/mlir/test/Dialect/Tensor/resolve-shaped-type-result-dims.mlir b/mlir/test/Dialect/Tensor/resolve-shaped-type-result-dims.mlir
index 7717970eb9af2..81df392481aca 100644
--- a/mlir/test/Dialect/Tensor/resolve-shaped-type-result-dims.mlir
+++ b/mlir/test/Dialect/Tensor/resolve-shaped-type-result-dims.mlir
@@ -25,3 +25,120 @@ func @insert_slice(
 //   CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[ARG1]], %[[C1]]
 //   CHECK-DAG:   %[[D2:.+]] = tensor.dim %[[ARG1]], %[[C2]]
 //       CHECK:   return %[[D0]], %[[D1]], %[[D2]]
+
+// -----
+
+func @extract_slice(%arg0 : tensor<?x?x?xf32>, %arg1 : index, %arg2 : index,
+    %arg3 : index) -> (index, index, index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %0 = tensor.extract_slice %arg0[0, 0, 0] [%arg1, %arg2, %arg3] [1, 1, 1] :
+      tensor<?x?x?xf32> to tensor<?x?x?xf32>
+  %1 = tensor.dim %0, %c0 : tensor<?x?x?xf32>
+  %2 = tensor.dim %0, %c1 : tensor<?x?x?xf32>
+  %3 = tensor.dim %0, %c2 : tensor<?x?x?xf32>
+  return %1, %2, %3 : index, index, index
+}
+// CHECK-LABEL: func @extract_slice(
+//  CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?x?xf32>
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9_]+]]: index
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9_]+]]: index
+//  CHECK-SAME:     %[[ARG3:[a-zA-Z0-9_]+]]: index
+//       CHECK:   return %[[ARG1]], %[[ARG2]], %[[ARG3]]
+
+// -----
+
+func @extract_slice_rank_reduced_1(%arg0 : tensor<?x?x?xf32>,
+    %arg1 : index) -> index {
+  %c0 = constant 0 : index
+  %0 = tensor.extract_slice %arg0[0, 0, 0] [1, %arg1, 1] [1, 1, 1] :
+     tensor<?x?x?xf32> to tensor<?xf32>
+  %1 = tensor.dim %0, %c0 : tensor<?xf32>
+  return %1 : index
+}
+// CHECK-LABEL: func @extract_slice_rank_reduced_1(
+//  CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?x?xf32>
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9_]+]]: index
+//       CHECK:   return %[[ARG1]]
+
+// -----
+
+func @extract_slice_rank_reduced_2(%arg0 : tensor<?x?x?xf32>,
+    %arg1 : index) -> index {
+  %c0 = constant 0 : index
+  %0 = tensor.extract_slice %arg0[0, 0, 0] [1, %arg1, 1] [1, 1, 1] :
+     tensor<?x?x?xf32> to tensor<?x1xf32>
+  %1 = tensor.dim %0, %c0 : tensor<?x1xf32>
+  return %1 : index
+}
+// CHECK-LABEL: func @extract_slice_rank_reduced_2(
+//  CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?x?xf32>
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9_]+]]: index
+//       CHECK:   return %[[ARG1]]
+
+// -----
+
+func @extract_slice_rank_reduced_3(%arg0 : tensor<?x?x?xf32>,
+    %arg1 : index) -> index {
+  %c1 = constant 1 : index
+  %0 = tensor.extract_slice %arg0[0, 0, 0] [1, %arg1, 1] [1, 1, 1] :
+     tensor<?x?x?xf32> to tensor<1x?xf32>
+  %1 = tensor.dim %0, %c1 : tensor<1x?xf32>
+  return %1 : index
+}
+// CHECK-LABEL: func @extract_slice_rank_reduced_3(
+//  CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?x?xf32>
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9_]+]]: index
+//       CHECK:   return %[[ARG1]]
+
+// -----
+
+func @extract_slice_rank_reduced_4(%arg0 : tensor<?x?x?xf32>,
+    %arg1 : index) -> index {
+  %c1 = constant 1 : index
+  %0 = tensor.extract_slice %arg0[0, 0, 0] [1, %arg1, 1] [1, 1, 1] :
+     tensor<?x?x?xf32> to tensor<1x?x1xf32>
+  %1 = tensor.dim %0, %c1 : tensor<1x?x1xf32>
+  return %1 : index
+}
+// CHECK-LABEL: func @extract_slice_rank_reduced_4(
+//  CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?x?xf32>
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9_]+]]: index
+//       CHECK:   return %[[ARG1]]
+
+// -----
+
+func @extract_slice_rank_reduced_5(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
+    %arg2 : index) -> (index, index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %0 = tensor.extract_slice %arg0[0, 0, 0] [%arg1, 1, %arg2] [1, 1, 1] :
+     tensor<?x?x?xf32> to tensor<?x?xf32>
+  %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
+  %2 = tensor.dim %0, %c1 : tensor<?x?xf32>
+  return %1, %2 : index, index
+}
+// CHECK-LABEL: func @extract_slice_rank_reduced_5(
+//  CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?x?xf32>
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9_]+]]: index
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9_]+]]: index
+//       CHECK:   return %[[ARG1]], %[[ARG2]]
+
+// -----
+
+func @extract_slice_rank_reduced_6(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
+    %arg2 : index) -> (index, index) {
+  %c0 = constant 0 : index
+  %c2 = constant 2 : index
+  %0 = tensor.extract_slice %arg0[0, 0, 0] [%arg1, 1, %arg2] [1, 1, 1] :
+     tensor<?x?x?xf32> to tensor<?x1x?xf32>
+  %1 = tensor.dim %0, %c0 : tensor<?x1x?xf32>
+  %2 = tensor.dim %0, %c2 : tensor<?x1x?xf32>
+  return %1, %2 : index, index
+}
+// CHECK-LABEL: func @extract_slice_rank_reduced_6(
+//  CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?x?xf32>
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9_]+]]: index
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9_]+]]: index
+//       CHECK:   return %[[ARG1]], %[[ARG2]]
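
Each case above checks that, after shape reification (via the file's RUN line,
which is not shown in this diff), the `tensor.dim` ops on the slice result fold
away entirely. For the first test, the output reconstructed from the CHECK
lines reduces to:

  func @extract_slice(%arg0: tensor<?x?x?xf32>, %arg1: index, %arg2: index,
      %arg3: index) -> (index, index, index) {
    %0 = tensor.extract_slice %arg0[0, 0, 0] [%arg1, %arg2, %arg3] [1, 1, 1] :
        tensor<?x?x?xf32> to tensor<?x?x?xf32>
    // The three tensor.dim ops are gone; the size operands are returned.
    return %arg1, %arg2, %arg3 : index, index, index
  }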