[Mlir-commits] [mlir] 24db783 - [mlir] NFC - Extend inferResultType API for SubViewOp and SubTensorOp

Nicolas Vasilache llvmlistbot at llvm.org
Wed Feb 10 15:04:58 PST 2021


Author: Nicolas Vasilache
Date: 2021-02-10T22:55:28Z
New Revision: 24db78393804cc255baaf474ac405aaad458a84d

URL: https://github.com/llvm/llvm-project/commit/24db78393804cc255baaf474ac405aaad458a84d
DIFF: https://github.com/llvm/llvm-project/commit/24db78393804cc255baaf474ac405aaad458a84d.diff

LOG: [mlir] NFC - Extend inferResultType API for SubViewOp and SubTensorOp

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index b9545e21d390..3d6eee4633b6 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -3001,6 +3001,10 @@ def SubViewOp : BaseOpWithOffsetSizesAndStrides<
                                 ArrayRef<int64_t> staticOffsets,
                                 ArrayRef<int64_t> staticSizes,
                                 ArrayRef<int64_t> staticStrides);
+    static Type inferResultType(MemRefType sourceMemRefType,
+                                ArrayRef<OpFoldResult> staticOffsets,
+                                ArrayRef<OpFoldResult> staticSizes,
+                                ArrayRef<OpFoldResult> staticStrides);
 
     /// Return the expected rank of each of the`static_offsets`, `static_sizes`
     /// and `static_strides` attributes.
@@ -3123,6 +3127,10 @@ def SubTensorOp : BaseOpWithOffsetSizesAndStrides<
                                 ArrayRef<int64_t> staticOffsets,
                                 ArrayRef<int64_t> staticSizes,
                                 ArrayRef<int64_t> staticStrides);
+    static Type inferResultType(RankedTensorType sourceRankedTensorType,
+                                ArrayRef<OpFoldResult> staticOffsets,
+                                ArrayRef<OpFoldResult> staticSizes,
+                                ArrayRef<OpFoldResult> staticStrides);
 
     /// Return the expected rank of each of the`static_offsets`, `static_sizes`
     /// and `static_strides` attributes.

diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index ca2e2731df03..9af00be6368e 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2831,6 +2831,23 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
       sourceMemRefType.getMemorySpace());
 }
 
+Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
+                                ArrayRef<OpFoldResult> leadingStaticOffsets,
+                                ArrayRef<OpFoldResult> leadingStaticSizes,
+                                ArrayRef<OpFoldResult> leadingStaticStrides) {
+  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
+  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
+  dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
+                             staticOffsets, ShapedType::kDynamicStrideOrOffset);
+  dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
+                             ShapedType::kDynamicSize);
+  dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
+                             staticStrides, ShapedType::kDynamicStrideOrOffset);
+  return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
+                                    staticSizes, staticStrides)
+      .cast<MemRefType>();
+}
+
 // Build a SubViewOp with mixed static and dynamic entries and custom result
 // type. If the type passed is nullptr, it is inferred.
 void mlir::SubViewOp::build(OpBuilder &b, OperationState &result,
@@ -3386,6 +3403,23 @@ Type SubTensorOp::inferResultType(RankedTensorType sourceRankedTensorType,
                                sourceRankedTensorType.getElementType());
 }
 
+Type SubTensorOp::inferResultType(RankedTensorType sourceRankedTensorType,
+                                  ArrayRef<OpFoldResult> leadingStaticOffsets,
+                                  ArrayRef<OpFoldResult> leadingStaticSizes,
+                                  ArrayRef<OpFoldResult> leadingStaticStrides) {
+  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
+  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
+  dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
+                             staticOffsets, ShapedType::kDynamicStrideOrOffset);
+  dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
+                             ShapedType::kDynamicSize);
+  dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
+                             staticStrides, ShapedType::kDynamicStrideOrOffset);
+  return SubTensorOp::inferResultType(sourceRankedTensorType, staticOffsets,
+                                      staticSizes, staticStrides)
+      .cast<RankedTensorType>();
+}
+
 // Build a SubTensorOp with mixed static and dynamic entries and custom result
 // type. If the type passed is nullptr, it is inferred.
 void mlir::SubTensorOp::build(OpBuilder &b, OperationState &result,


        


More information about the Mlir-commits mailing list