[Mlir-commits] [mlir] 24db783 - [mlir] NFC - Extend inferResultType API for SubViewOp and SubTensorOp
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Feb 10 15:04:58 PST 2021
Author: Nicolas Vasilache
Date: 2021-02-10T22:55:28Z
New Revision: 24db78393804cc255baaf474ac405aaad458a84d
URL: https://github.com/llvm/llvm-project/commit/24db78393804cc255baaf474ac405aaad458a84d
DIFF: https://github.com/llvm/llvm-project/commit/24db78393804cc255baaf474ac405aaad458a84d.diff
LOG: [mlir] NFC - Extend inferResultType API for SubViewOp and SubTensorOp
Added:
Modified:
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index b9545e21d390..3d6eee4633b6 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -3001,6 +3001,10 @@ def SubViewOp : BaseOpWithOffsetSizesAndStrides<
ArrayRef<int64_t> staticOffsets,
ArrayRef<int64_t> staticSizes,
ArrayRef<int64_t> staticStrides);
+ static Type inferResultType(MemRefType sourceMemRefType,
+ ArrayRef<OpFoldResult> staticOffsets,
+ ArrayRef<OpFoldResult> staticSizes,
+ ArrayRef<OpFoldResult> staticStrides);
/// Return the expected rank of each of the`static_offsets`, `static_sizes`
/// and `static_strides` attributes.
@@ -3123,6 +3127,10 @@ def SubTensorOp : BaseOpWithOffsetSizesAndStrides<
ArrayRef<int64_t> staticOffsets,
ArrayRef<int64_t> staticSizes,
ArrayRef<int64_t> staticStrides);
+ static Type inferResultType(RankedTensorType sourceRankedTensorType,
+ ArrayRef<OpFoldResult> staticOffsets,
+ ArrayRef<OpFoldResult> staticSizes,
+ ArrayRef<OpFoldResult> staticStrides);
/// Return the expected rank of each of the`static_offsets`, `static_sizes`
/// and `static_strides` attributes.
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index ca2e2731df03..9af00be6368e 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2831,6 +2831,23 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
sourceMemRefType.getMemorySpace());
}
+// Overload of SubViewOp::inferResultType that takes offsets, sizes and strides
+// as OpFoldResult lists instead of int64_t arrays.
+// Each list is split by dispatchIndexOpFoldResults into a static int64_t
+// vector and a dynamic Value vector. Offsets and strides are split using the
+// ShapedType::kDynamicStrideOrOffset sentinel; sizes use ShapedType::kDynamicSize.
+// The sentinel presumably fills the static slot of each non-constant entry
+// (confirm against dispatchIndexOpFoldResults).
+// Only the static vectors are forwarded to the int64_t overload. The result is
+// cast to MemRefType before being returned as a Type, so a non-memref result
+// from that overload would fail the cast.
+Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
+ ArrayRef<OpFoldResult> leadingStaticOffsets,
+ ArrayRef<OpFoldResult> leadingStaticSizes,
+ ArrayRef<OpFoldResult> leadingStaticStrides) {
+ SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
+ // Out-parameters required by dispatchIndexOpFoldResults; the collected
+ // dynamic Values are not used by this function.
+ SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
+ dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
+ staticOffsets, ShapedType::kDynamicStrideOrOffset);
+ dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
+ ShapedType::kDynamicSize);
+ dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
+ staticStrides, ShapedType::kDynamicStrideOrOffset);
+ // Delegate to the int64_t overload; the cast checks the inferred type is a
+ // MemRefType.
+ return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
+ staticSizes, staticStrides)
+ .cast<MemRefType>();
+}
+
// Build a SubViewOp with mixed static and dynamic entries and custom result
// type. If the type passed is nullptr, it is inferred.
void mlir::SubViewOp::build(OpBuilder &b, OperationState &result,
@@ -3386,6 +3403,23 @@ Type SubTensorOp::inferResultType(RankedTensorType sourceRankedTensorType,
sourceRankedTensorType.getElementType());
}
+// Overload of SubTensorOp::inferResultType that takes offsets, sizes and
+// strides as OpFoldResult lists instead of int64_t arrays.
+// Each list is split by dispatchIndexOpFoldResults into a static int64_t
+// vector and a dynamic Value vector. Offsets and strides are split using the
+// ShapedType::kDynamicStrideOrOffset sentinel; sizes use ShapedType::kDynamicSize.
+// The sentinel presumably fills the static slot of each non-constant entry
+// (confirm against dispatchIndexOpFoldResults).
+// Only the static vectors are forwarded to the int64_t overload. The result is
+// cast to RankedTensorType before being returned as a Type, so a non-ranked-
+// tensor result from that overload would fail the cast.
+Type SubTensorOp::inferResultType(RankedTensorType sourceRankedTensorType,
+ ArrayRef<OpFoldResult> leadingStaticOffsets,
+ ArrayRef<OpFoldResult> leadingStaticSizes,
+ ArrayRef<OpFoldResult> leadingStaticStrides) {
+ SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
+ // Out-parameters required by dispatchIndexOpFoldResults; the collected
+ // dynamic Values are not used by this function.
+ SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
+ dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
+ staticOffsets, ShapedType::kDynamicStrideOrOffset);
+ dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
+ ShapedType::kDynamicSize);
+ dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
+ staticStrides, ShapedType::kDynamicStrideOrOffset);
+ // Delegate to the int64_t overload; the cast checks the inferred type is a
+ // RankedTensorType.
+ return SubTensorOp::inferResultType(sourceRankedTensorType, staticOffsets,
+ staticSizes, staticStrides)
+ .cast<RankedTensorType>();
+}
+
// Build a SubTensorOp with mixed static and dynamic entries and custom result
// type. If the type passed is nullptr, it is inferred.
void mlir::SubTensorOp::build(OpBuilder &b, OperationState &result,
More information about the Mlir-commits
mailing list