[Mlir-commits] [mlir] 4281946 - [mlir][Tensor] Add ReifyRankedShapedTypeOpInterface to tensor.extract_slice.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Oct 7 17:10:59 PDT 2021
Author: MaheshRavishankar
Date: 2021-10-07T17:10:35-07:00
New Revision: 4281946390989a0392e42e3f02b9d93a80234674
URL: https://github.com/llvm/llvm-project/commit/4281946390989a0392e42e3f02b9d93a80234674
DIFF: https://github.com/llvm/llvm-project/commit/4281946390989a0392e42e3f02b9d93a80234674.diff
LOG: [mlir][Tensor] Add ReifyRankedShapedTypeOpInterface to tensor.extract_slice.
Differential Revision: https://reviews.llvm.org/D111263
Added:
Modified:
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/resolve-shaped-type-result-dims.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 78755c898835a..4f84bcf7270d4 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -158,8 +158,10 @@ def Tensor_ExtractOp : Tensor_Op<"extract",
//===----------------------------------------------------------------------===//
def Tensor_ExtractSliceOp : BaseOpWithOffsetSizesAndStrides<
- Tensor_Dialect, "extract_slice", [NoSideEffect, AttrSizedOperandSegments,
- OffsetSizeAndStrideOpInterface]> {
+ Tensor_Dialect, "extract_slice",
+ [NoSideEffect, AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+ OffsetSizeAndStrideOpInterface]> {
let summary = "extract slice operation";
let description = [{
The "extract_slice" operation extracts a tensor from another tensor as
@@ -284,6 +286,11 @@ def Tensor_ExtractSliceOp : BaseOpWithOffsetSizesAndStrides<
/// Return the number of leading operands before the `offsets`, `sizes`
/// and `strides` operands.
static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; }
+
+ /// Return the dimensions of the source that are dropped in the
+ /// result when the result is rank-reduced.
+ llvm::SmallDenseSet<unsigned> getDroppedDims();
+
}];
let hasCanonicalizer = 1;
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 2a55223ca9a8b..dc94c27c818c8 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -277,9 +277,12 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
unsigned unsignedIndex = index.getValue().getZExtValue();
if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
- assert(sliceOp.isDynamicSize(unsignedIndex) &&
- "Expected dynamic slice size");
- return sliceOp.getDynamicSize(unsignedIndex);
+ // Fold only for non-rank reduced ops. For the rank-reduced version, rely on
+ // `resolve-shaped-type-result-dims` pass.
+ if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
+ sliceOp.isDynamicSize(unsignedIndex)) {
+ return {sliceOp.getDynamicSize(unsignedIndex)};
+ }
}
// dim(cast) -> dim
@@ -895,6 +898,46 @@ getCanonicalSliceResultType(unsigned resultRank, RankedTensorType sourceType,
return resultType;
}
+/// Returns the indices of the source dimensions that have no counterpart in
+/// the (possibly rank-reduced) result type.
+///
+/// Source sizes are walked left to right while a cursor tracks the next
+/// unmatched result dimension. A source size is treated as dropped only when
+/// it is the constant 1 and the result dimension under the cursor is not a
+/// static 1; every other source size consumes one result dimension. When
+/// several unit dimensions could be matched, the earliest ones are kept.
+llvm::SmallDenseSet<unsigned> ExtractSliceOp::getDroppedDims() {
+  ArrayRef<int64_t> resultShape = getType().getShape();
+  SmallVector<OpFoldResult> sizes = getMixedSizes();
+  llvm::SmallDenseSet<unsigned> dropped;
+  unsigned resultPos = 0;
+  for (unsigned srcDim = 0, e = sizes.size(); srcDim < e; ++srcDim) {
+    Optional<int64_t> constSize = getConstantIntValue(sizes[srcDim]);
+    bool isUnitSize = constSize && *constSize == 1;
+    bool resultHasUnitHere =
+        resultPos < resultShape.size() && resultShape[resultPos] == 1;
+    if (isUnitSize && !resultHasUnitHere) {
+      // A unit size with no static-1 result dim to pair with was squeezed out.
+      dropped.insert(srcDim);
+      continue;
+    }
+    // Any other size (non-unit, dynamic, or a unit paired with a static 1)
+    // maps onto the current result dimension.
+    ++resultPos;
+  }
+  return dropped;
+}
+
+/// Reifies the shape of the single result: one index value per dimension of
+/// the (possibly rank-reduced) result, taken from the slice sizes of the
+/// source dimensions that are not reported by `getDroppedDims()`. Static
+/// sizes are materialized as `ConstantIndexOp`s created through `builder`;
+/// dynamic sizes reuse the existing size operand. Always returns success.
+LogicalResult ExtractSliceOp::reifyResultShapes(
+    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+  // Single result, hence a single list of dimension values.
+  reifiedReturnShapes.resize(1);
+  auto &resultDims = reifiedReturnShapes[0];
+  resultDims.reserve(getType().getRank());
+  SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
+  llvm::SmallDenseSet<unsigned> droppedDims = getDroppedDims();
+  Location loc = getLoc();
+  for (const auto &indexedSize : enumerate(mixedSizes)) {
+    // Dimensions squeezed out by rank reduction have no result counterpart.
+    if (droppedDims.count(indexedSize.index()))
+      continue;
+    OpFoldResult size = indexedSize.value();
+    if (auto attr = size.dyn_cast<Attribute>()) {
+      resultDims.push_back(builder.create<ConstantIndexOp>(
+          loc, attr.cast<IntegerAttr>().getInt()));
+    } else {
+      resultDims.push_back(size.get<Value>());
+    }
+  }
+  return success();
+}
+
namespace {
/// Pattern to rewrite an extract_slice op with tensor::Cast arguments.
/// This essentially pushes memref_cast past its consuming slice when
diff --git a/mlir/test/Dialect/Tensor/resolve-shaped-type-result-dims.mlir b/mlir/test/Dialect/Tensor/resolve-shaped-type-result-dims.mlir
index 7717970eb9af2..81df392481aca 100644
--- a/mlir/test/Dialect/Tensor/resolve-shaped-type-result-dims.mlir
+++ b/mlir/test/Dialect/Tensor/resolve-shaped-type-result-dims.mlir
@@ -25,3 +25,120 @@ func @insert_slice(
// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG1]], %[[C1]]
// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[ARG1]], %[[C2]]
// CHECK: return %[[D0]], %[[D1]], %[[D2]]
+
+// -----
+
+// Non-rank-reducing slice with all-dynamic sizes: each `tensor.dim` of the
+// result is expected to resolve to the corresponding size operand.
+func @extract_slice(%arg0 : tensor<?x?x?xf32>, %arg1 : index, %arg2 : index,
+ %arg3 : index) -> (index, index, index) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c2 = constant 2 : index
+ %0 = tensor.extract_slice %arg0[0, 0, 0] [%arg1, %arg2, %arg3] [1, 1, 1] :
+ tensor<?x?x?xf32> to tensor<?x?x?xf32>
+ %1 = tensor.dim %0, %c0 : tensor<?x?x?xf32>
+ %2 = tensor.dim %0, %c1 : tensor<?x?x?xf32>
+ %3 = tensor.dim %0, %c2 : tensor<?x?x?xf32>
+ return %1, %2, %3 : index, index, index
+}
+// CHECK-LABEL: func @extract_slice(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index
+// CHECK: return %[[ARG1]], %[[ARG2]], %[[ARG3]]
+
+// -----
+
+// Rank-reducing 3-D -> 1-D slice: the unit sizes at source dims 0 and 2 are
+// dropped, so result dim 0 is expected to resolve to %arg1.
+func @extract_slice_rank_reduced_1(%arg0 : tensor<?x?x?xf32>,
+ %arg1 : index) -> index {
+ %c0 = constant 0 : index
+ %0 = tensor.extract_slice %arg0[0, 0, 0] [1, %arg1, 1] [1, 1, 1] :
+ tensor<?x?x?xf32> to tensor<?xf32>
+ %1 = tensor.dim %0, %c0 : tensor<?xf32>
+ return %1 : index
+}
+// CHECK-LABEL: func @extract_slice_rank_reduced_1(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK: return %[[ARG1]]
+
+// -----
+
+// Rank-reducing 3-D -> 2-D slice to `?x1`: the leading unit size is dropped
+// (result dim 0 is dynamic, not a static 1) while the trailing unit size is
+// kept; result dim 0 is expected to resolve to %arg1.
+func @extract_slice_rank_reduced_2(%arg0 : tensor<?x?x?xf32>,
+ %arg1 : index) -> index {
+ %c0 = constant 0 : index
+ %0 = tensor.extract_slice %arg0[0, 0, 0] [1, %arg1, 1] [1, 1, 1] :
+ tensor<?x?x?xf32> to tensor<?x1xf32>
+ %1 = tensor.dim %0, %c0 : tensor<?x1xf32>
+ return %1 : index
+}
+// CHECK-LABEL: func @extract_slice_rank_reduced_2(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK: return %[[ARG1]]
+
+// -----
+
+// Rank-reducing 3-D -> 2-D slice to `1x?`: the leading unit size pairs with
+// the result's static 1 and is kept, the trailing unit size is dropped; result
+// dim 1 is expected to resolve to %arg1.
+func @extract_slice_rank_reduced_3(%arg0 : tensor<?x?x?xf32>,
+ %arg1 : index) -> index {
+ %c1 = constant 1 : index
+ %0 = tensor.extract_slice %arg0[0, 0, 0] [1, %arg1, 1] [1, 1, 1] :
+ tensor<?x?x?xf32> to tensor<1x?xf32>
+ %1 = tensor.dim %0, %c1 : tensor<1x?xf32>
+ return %1 : index
+}
+// CHECK-LABEL: func @extract_slice_rank_reduced_3(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK: return %[[ARG1]]
+
+// -----
+
+// Despite the name, the result `1x?x1` has the same rank as the source, so no
+// dimension is dropped; result dim 1 is expected to resolve to %arg1.
+func @extract_slice_rank_reduced_4(%arg0 : tensor<?x?x?xf32>,
+ %arg1 : index) -> index {
+ %c1 = constant 1 : index
+ %0 = tensor.extract_slice %arg0[0, 0, 0] [1, %arg1, 1] [1, 1, 1] :
+ tensor<?x?x?xf32> to tensor<1x?x1xf32>
+ %1 = tensor.dim %0, %c1 : tensor<1x?x1xf32>
+ return %1 : index
+}
+// CHECK-LABEL: func @extract_slice_rank_reduced_4(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK: return %[[ARG1]]
+
+// -----
+
+// Rank-reducing 3-D -> 2-D slice: the unit size at source dim 1 is dropped
+// (result dim 1 is dynamic, not a static 1); result dims 0 and 1 are expected
+// to resolve to %arg1 and %arg2.
+func @extract_slice_rank_reduced_5(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
+ %arg2 : index) -> (index, index) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %0 = tensor.extract_slice %arg0[0, 0, 0] [%arg1, 1, %arg2] [1, 1, 1] :
+ tensor<?x?x?xf32> to tensor<?x?xf32>
+ %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
+ %2 = tensor.dim %0, %c1 : tensor<?x?xf32>
+ return %1, %2 : index, index
+}
+// CHECK-LABEL: func @extract_slice_rank_reduced_5(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK: return %[[ARG1]], %[[ARG2]]
+
+// -----
+
+// Sizes [%arg1, 1, %arg2] with the unit dimension kept in the result
+// (`?x1x?`, same rank as the source); result dims 0 and 2 are expected to
+// resolve to %arg1 and %arg2.
+func @extract_slice_rank_reduced_6(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
+ %arg2 : index) -> (index, index) {
+ %c0 = constant 0 : index
+ %c2 = constant 2 : index
+ %0 = tensor.extract_slice %arg0[0, 0, 0] [%arg1, 1, %arg2] [1, 1, 1] :
+ tensor<?x?x?xf32> to tensor<?x1x?xf32>
+ %1 = tensor.dim %0, %c0 : tensor<?x1x?xf32>
+ %2 = tensor.dim %0, %c2 : tensor<?x1x?xf32>
+ return %1, %2 : index, index
+}
+// CHECK-LABEL: func @extract_slice_rank_reduced_6(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK: return %[[ARG1]], %[[ARG2]]
More information about the Mlir-commits
mailing list