[Mlir-commits] [mlir] e3de249 - [mlir] Add a subtensor operation

Nicolas Vasilache llvmlistbot at llvm.org
Fri Oct 2 02:37:46 PDT 2020


Author: Nicolas Vasilache
Date: 2020-10-02T05:35:30-04:00
New Revision: e3de249a4c94d6962b36c2b4747c134d152bed37

URL: https://github.com/llvm/llvm-project/commit/e3de249a4c94d6962b36c2b4747c134d152bed37
DIFF: https://github.com/llvm/llvm-project/commit/e3de249a4c94d6962b36c2b4747c134d152bed37.diff

LOG: [mlir] Add a subtensor operation

This revision introduces a `subtensor` op, which is the counterpart of `subview` for a tensor operand. This also refactors the relevant pieces to allow reusing the `subview` implementation where appropriate.

This operation will be used to implement tiling for Linalg on tensors.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
    mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/lib/Dialect/Linalg/Utils/Utils.cpp
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/IR/core-ops.mlir
    mlir/test/IR/invalid-ops.mlir
    mlir/test/lib/Transforms/TestLinalgTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 76ce4eb30e7f..b4e5be58bad7 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -185,7 +185,7 @@ struct ProcInfo {
   Value nprocs;
 };
 using ProcInfoCallBackFn = std::function<SmallVector<ProcInfo, 2>(
-    OpBuilder &b, Location loc, ArrayRef<SubViewOp::Range> parallelLoopRanges)>;
+    OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges)>;
 
 /// Options that allow distribution of loops generated in Linalg transforms to
 /// processors while generating the loops.
@@ -216,7 +216,7 @@ struct GenerateLoopNest {
                                 AffineIndexedValue, StdIndexedValue>::type;
 
   static void
-  doit(ArrayRef<SubViewOp::Range> loopRanges, ValueRange iterArgInitValues,
+  doit(ArrayRef<Range> loopRanges, ValueRange iterArgInitValues,
        ArrayRef<Attribute> iteratorTypes,
        function_ref<scf::ValueVector(ValueRange, ValueRange)> bodyBuilderFn,
        Optional<LinalgLoopDistributionOptions> = None);

diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
index 2500343c0af3..fbe735e31cff 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -33,6 +33,17 @@ class Builder;
 class FuncOp;
 class OpBuilder;
 
+/// Auxiliary range data structure to unpack the offset, size and stride
+/// operands of the SubViewOp / SubTensorOp into a list of triples.
+/// Such a list of triples is sometimes more convenient to manipulate.
+struct Range {
+  Value offset; // Start of the range (a loop lower bound in Linalg loops).
+  Value size;   // Extent; Linalg loop generation feeds it in as upper bound.
+  Value stride; // Step between consecutive elements (the loop step).
+};
+
+raw_ostream &operator<<(raw_ostream &os, Range &range);
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/StandardOps/IR/Ops.h.inc"
 
@@ -300,8 +311,6 @@ ParseResult parseDimAndSymbolList(OpAsmParser &parser,
                                   SmallVectorImpl<Value> &operands,
                                   unsigned &numDims);
 
-raw_ostream &operator<<(raw_ostream &os, SubViewOp::Range &range);
-
 /// Determines whether MemRefCastOp casts to a more dynamic version of the
 /// source memref. This is useful to to fold a memref_cast into a consuming op
 /// and implement canonicalization patterns for ops in different dialects that

diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index ff1a82c26561..dbc3e4ca521b 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2706,11 +2706,214 @@ def SubIOp : IntArithmeticOp<"subi"> {
 // SubViewOp
 //===----------------------------------------------------------------------===//
 
-def SubViewOp : Std_Op<"subview", [
-    AttrSizedOperandSegments,
-    DeclareOpInterfaceMethods<ViewLikeOpInterface>,
-    NoSideEffect,
-  ]> {
+class BaseOpWithOffsetSizesAndStrides<string mnemonic, list<OpTrait> traits = []> :
+    Std_Op<mnemonic,
+           !listconcat(traits, [NoSideEffect, AttrSizedOperandSegments])> {
+  let builders = [
+    // Build an op with mixed static and dynamic entries.
+    OpBuilder<
+      "Value source, ArrayRef<int64_t> staticOffsets, "
+      "ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides, "
+      "ValueRange offsets, ValueRange sizes, ValueRange strides, "
+      "ArrayRef<NamedAttribute> attrs = {}">,
+    // Build an op with all dynamic entries.
+    OpBuilder<
+      "Value source, ValueRange offsets, ValueRange sizes, ValueRange strides, "
+      "ArrayRef<NamedAttribute> attrs = {}">
+  ];
+
+  code extraBaseClassDeclaration = [{
+    /// Returns the number of dynamic offset operands.
+    int64_t getNumOffsets() { return llvm::size(offsets()); }
+
+    /// Returns the number of dynamic size operands.
+    int64_t getNumSizes() { return llvm::size(sizes()); }
+
+    /// Returns the number of dynamic stride operands.
+    int64_t getNumStrides() { return llvm::size(strides()); }
+
+    /// Returns the dynamic sizes for this operation if specified.
+    operand_range getDynamicSizes() { return sizes(); }
+
+    /// Returns in `staticStrides` the static value of the stride
+    /// operands. Returns failure() if any stride operand is dynamic, in
+    /// which case the static values cannot be retrieved.
+    LogicalResult getStaticStrides(SmallVectorImpl<int64_t> &staticStrides) {
+      if (!strides().empty())
+        return failure();
+      staticStrides.reserve(static_strides().size());
+      for (auto s : static_strides().getAsValueRange<IntegerAttr>())
+        staticStrides.push_back(s.getSExtValue());
+      return success();
+    }
+
+    /// Return the list of Range (i.e. offset, size, stride). Each
+    /// Range entry contains either the dynamic value or a ConstantIndexOp
+    /// constructed with `b` at location `loc`.
+    SmallVector<Range, 8> getOrCreateRanges(OpBuilder &b, Location loc);
+
+    /// Return the offsets as Values. Each Value is either the dynamic
+    /// value specified in the op or a ConstantIndexOp constructed
+    /// with `b` at location `loc`.
+    SmallVector<Value, 4> getOrCreateOffsets(OpBuilder &b, Location loc) {
+      unsigned dynamicIdx = 1; // Operand #0 is `source`.
+      return llvm::to_vector<4>(llvm::map_range(
+        static_offsets().cast<ArrayAttr>(), [&](Attribute a) -> Value {
+          int64_t staticOffset = a.cast<IntegerAttr>().getInt();
+          if (ShapedType::isDynamicStrideOrOffset(staticOffset))
+            return getOperand(dynamicIdx++);
+          else
+            return b.create<ConstantOp>(
+              loc, b.getIndexType(), b.getIndexAttr(staticOffset));
+        }));
+    }
+
+    /// Return the sizes as Values. Each Value is either the dynamic
+    /// value specified in the op or a ConstantIndexOp constructed
+    /// with `b` at location `loc`.
+    SmallVector<Value, 4> getOrCreateSizes(OpBuilder &b, Location loc) {
+      unsigned dynamicIdx = 1 + offsets().size();
+      return llvm::to_vector<4>(llvm::map_range(
+        static_sizes().cast<ArrayAttr>(), [&](Attribute a) -> Value {
+          int64_t staticSize = a.cast<IntegerAttr>().getInt();
+          if (ShapedType::isDynamic(staticSize))
+            return getOperand(dynamicIdx++);
+          else
+            return b.create<ConstantOp>(
+              loc, b.getIndexType(), b.getIndexAttr(staticSize));
+        }));
+    }
+
+    /// Return the strides as Values. Each Value is either the dynamic
+    /// value specified in the op or a ConstantIndexOp constructed with
+    /// `b` at location `loc`.
+    SmallVector<Value, 4> getOrCreateStrides(OpBuilder &b, Location loc) {
+      unsigned dynamicIdx = 1 + offsets().size() + sizes().size();
+      return llvm::to_vector<4>(llvm::map_range(
+        static_strides().cast<ArrayAttr>(), [&](Attribute a) -> Value {
+          int64_t staticStride = a.cast<IntegerAttr>().getInt();
+          if (ShapedType::isDynamicStrideOrOffset(staticStride))
+            return getOperand(dynamicIdx++);
+          else
+            return b.create<ConstantOp>(
+              loc, b.getIndexType(), b.getIndexAttr(staticStride));
+        }));
+    }
+
+    /// Return the rank of the source ShapedType.
+    unsigned getSourceRank() {
+      return source().getType().cast<ShapedType>().getRank();
+    }
+
+    /// Return the rank of the result ShapedType.
+    unsigned getResultRank() { return getType().getRank(); }
+
+    /// Return true if the offset `idx` is dynamic.
+    bool isDynamicOffset(unsigned idx) {
+      APInt v = *(static_offsets().getAsValueRange<IntegerAttr>().begin() + idx);
+      return ShapedType::isDynamicStrideOrOffset(v.getSExtValue());
+    }
+    /// Return true if the size `idx` is dynamic.
+    bool isDynamicSize(unsigned idx) {
+      APInt v = *(static_sizes().getAsValueRange<IntegerAttr>().begin() + idx);
+      return ShapedType::isDynamic(v.getSExtValue());
+    }
+
+    /// Return true if the stride `idx` is dynamic.
+    bool isDynamicStride(unsigned idx) {
+      APInt v = *(static_strides().getAsValueRange<IntegerAttr>().begin() + idx);
+      return ShapedType::isDynamicStrideOrOffset(v.getSExtValue());
+    }
+
+    /// Assert the offset `idx` is a static constant and return its value.
+    int64_t getStaticOffset(unsigned idx) {
+      assert(!isDynamicOffset(idx) && "expected static offset");
+      APInt v = *(static_offsets().getAsValueRange<IntegerAttr>().begin() + idx);
+      return v.getSExtValue();
+    }
+    /// Assert the size `idx` is a static constant and return its value.
+    int64_t getStaticSize(unsigned idx) {
+      assert(!isDynamicSize(idx) && "expected static size");
+      APInt v = *(static_sizes().getAsValueRange<IntegerAttr>().begin() + idx);
+      return v.getSExtValue();
+    }
+    /// Assert the stride `idx` is a static constant and return its value.
+    int64_t getStaticStride(unsigned idx) {
+      assert(!isDynamicStride(idx) && "expected static stride");
+      APInt v = *(static_strides().getAsValueRange<IntegerAttr>().begin() + idx);
+      return v.getSExtValue();
+    }
+    /// Count the entries of `attr` before `idx` that satisfy `isDynamic`.
+    unsigned getNumDynamicEntriesUpToIdx(ArrayAttr attr,
+        llvm::function_ref<bool(int64_t)> isDynamic, unsigned idx) {
+      return std::count_if(
+        attr.getValue().begin(), attr.getValue().begin() + idx,
+        [&](Attribute a) {
+          return isDynamic(a.cast<IntegerAttr>().getInt());
+        });
+    }
+    /// Assert the offset `idx` is dynamic and return the position of the
+    /// corresponding operand.
+    unsigned getIndexOfDynamicOffset(unsigned idx) {
+      assert(isDynamicOffset(idx) && "expected dynamic offset");
+      auto numDynamic =
+          getNumDynamicEntriesUpToIdx(static_offsets().cast<ArrayAttr>(),
+                                      ShapedType::isDynamicStrideOrOffset, idx);
+      return 1 + numDynamic; // Skip operand #0 (`source`).
+    }
+    /// Assert the size `idx` is dynamic and return the position of the
+    /// corresponding operand.
+    unsigned getIndexOfDynamicSize(unsigned idx) {
+      assert(isDynamicSize(idx) && "expected dynamic size");
+      auto numDynamic = getNumDynamicEntriesUpToIdx(
+          static_sizes().cast<ArrayAttr>(), ShapedType::isDynamic, idx);
+      return 1 + offsets().size() + numDynamic;
+    }
+    /// Assert the stride `idx` is dynamic and return the position of the
+    /// corresponding operand.
+    unsigned getIndexOfDynamicStride(unsigned idx) {
+      assert(isDynamicStride(idx) && "expected dynamic stride");
+      auto numDynamic =
+          getNumDynamicEntriesUpToIdx(static_strides().cast<ArrayAttr>(),
+                                      ShapedType::isDynamicStrideOrOffset, idx);
+      return 1 + offsets().size() + sizes().size() + numDynamic;
+    }
+
+    /// Assert the offset `idx` is dynamic and return its value.
+    Value getDynamicOffset(unsigned idx) {
+      return getOperand(getIndexOfDynamicOffset(idx));
+    }
+    /// Assert the size `idx` is dynamic and return its value.
+    Value getDynamicSize(unsigned idx) {
+      return getOperand(getIndexOfDynamicSize(idx));
+    }
+    /// Assert the stride `idx` is dynamic and return its value.
+    Value getDynamicStride(unsigned idx) {
+      return getOperand(getIndexOfDynamicStride(idx));
+    }
+
+    static StringRef getStaticOffsetsAttrName() {
+      return "static_offsets";
+    }
+    static StringRef getStaticSizesAttrName() {
+      return "static_sizes";
+    }
+    static StringRef getStaticStridesAttrName() {
+      return "static_strides";
+    }
+    static ArrayRef<StringRef> getSpecialAttrNames() {
+      static SmallVector<StringRef, 4> names{
+        getStaticOffsetsAttrName(),
+        getStaticSizesAttrName(),
+        getStaticStridesAttrName(),
+        getOperandSegmentSizeAttr()};
+      return names;
+    }
+  }];
+}
+
+def SubViewOp : BaseOpWithOffsetSizesAndStrides<
+    "subview", [DeclareOpInterfaceMethods<ViewLikeOpInterface>] >  {
   let summary = "memref subview operation";
   let description = [{
     The "subview" operation converts a memref type to another memref type
@@ -2726,8 +2929,11 @@ def SubViewOp : Std_Op<"subview", [
     * Sizes: memref-rank number of dynamic sizes or static integer attributes
              which specify the sizes of the result "view" memref type.
     * Strides: memref-rank number of dynamic strides or static integer
-               attributes multiplicatively to the base memref strides in each
-               dimension.
+               attributes that compose multiplicatively with the base memref
+               strides in each dimension.
+
+    A subview operation may additionally reduce the rank of the resulting view
+    by removing dimensions that are statically known to be of size 1.
 
     Example 1:
 
@@ -2817,6 +3023,15 @@ def SubViewOp : Std_Op<"subview", [
     // memref is "inbounds" w.r.t to base memref. It is upto the client
     // to ensure that the subview is accessed in a manner that is
     // in-bounds.
+
+    Example 5:
+
+    ```
+    // Rank-reducing subview.
+    %1 = subview %0[0, 0, 0][1, 16, 4][1, 1, 1] :
+      memref<8x16x4xf32> to memref<16x4xf32>
+    %3 = subview %2[3, 4, 2][1, 6, 3][1, 1, 1] :
+      memref<8x16x4xf32> to memref<6x3xf32, offset: 210, strides: [4, 1]>
     ```
     }
   }];
@@ -2859,137 +3074,97 @@ def SubViewOp : Std_Op<"subview", [
       "ArrayRef<NamedAttribute> attrs = {}">
   ];
 
-  let extraClassDeclaration = [{
+  let extraClassDeclaration = extraBaseClassDeclaration # [{
     /// Returns the type of the base memref operand.
-    MemRefType getBaseMemRefType() {
+    MemRefType getSourceMemRefType() {
       return source().getType().cast<MemRefType>();
     }
 
     /// The result of a subview is always a memref.
     MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
 
-    /// Returns as integer value the number of offset operands.
-    int64_t getNumOffsets() { return llvm::size(offsets()); }
+    /// A subview result type can be fully inferred from the source type and the
+    /// static representation of offsets, sizes and strides. Special sentinels
+    /// encode the dynamic case.
+    static Type inferResultType(MemRefType sourceMemRefType,
+                                ArrayRef<int64_t> staticOffsets,
+                                ArrayRef<int64_t> staticSizes,
+                                ArrayRef<int64_t> staticStrides);
+  }];
 
-    /// Returns as integer value the number of size operands.
-    int64_t getNumSizes() { return llvm::size(sizes()); }
+  let hasCanonicalizer = 1;
+}
 
-    /// Returns as integer value the number of stride operands.
-    int64_t getNumStrides() { return llvm::size(strides()); }
+//===----------------------------------------------------------------------===//
+// SubTensorOp
+//===----------------------------------------------------------------------===//
 
-    /// Returns the dynamic sizes for this subview operation if specified.
-    operand_range getDynamicSizes() { return sizes(); }
+def SubTensorOp : BaseOpWithOffsetSizesAndStrides<"subtensor"> {
+  let summary = "subtensor operation";
+  let description = [{
+    The "subtensor" operation extracts a tensor from another tensor as
+    specified by the operation's offsets, sizes and strides arguments.
 
-    /// Returns in `staticStrides` the static value of the stride
-    /// operands. Returns failure() if the static value of the stride
-    /// operands could not be retrieved.
-    LogicalResult getStaticStrides(SmallVectorImpl<int64_t> &staticStrides);
-
-    /// Auxiliary range data structure and helper function that unpacks the
-    /// offset, size and stride operands of the SubViewOp into a list of triples.
-    /// Such a list of triple is sometimes more convenient to manipulate.
-    struct Range {
-      Value offset, size, stride;
-    };
-    /// Return the list of SubViewOp::Range (i.e. offset, size, stride). Each
-    /// Range entry contains either the dynamic value or a ConstantIndexOp
-    /// constructed with `b` at location `loc`.
-    SmallVector<Range, 8> getOrCreateRanges(OpBuilder &b, Location loc);
+    The subtensor operation supports the following arguments:
 
-    /// Return the offsets as Values. Each Value is either the dynamic
-    /// value specified in the op or a ConstantIndexOp constructed
-    /// with `b` at location `loc`
-    SmallVector<Value, 4> getOrCreateOffsets(OpBuilder &b, Location loc);
+    * tensor: the "base" tensor from which to extract a subtensor.
+    * offsets: tensor-rank number of dynamic offsets or static integer
+               attributes into the "base" tensor from which to extract the
+               subtensor.
+    * sizes: tensor-rank number of dynamic sizes or static integer attributes
+             which specify the sizes of the result tensor type.
+    * strides: tensor-rank number of dynamic strides or static integer
+               attributes specifying subsampling in each dimension.
 
-    /// Return the sizes as Values. Each Value is either the dynamic
-    /// value specified in the op or a ConstantIndexOp constructed
-    /// with `b` at location `loc`
-    SmallVector<Value, 4> getOrCreateSizes(OpBuilder &b, Location loc);
+    After buffer-allocation, the "subtensor" op is expected to lower into a
+    "subview" op.
 
-    /// Return the strides as Values. Each Value is either the dynamic
-    /// value specified in the op or a ConstantIndexOp constructed with
-    /// `b` at location `loc`
-    SmallVector<Value, 4> getOrCreateStrides(OpBuilder &b, Location loc);
+    A subtensor operation may additionally reduce the rank of the resulting
+    tensor by removing dimensions that are statically known to be of size 1.
 
-    /// A subview result type can be fully inferred from the source type and the
-    /// static representation of offsets, sizes and strides. Special sentinels
-    /// encode the dynamic case.
-    static Type inferSubViewResultType(MemRefType sourceMemRefType,
-                                       ArrayRef<int64_t> staticOffsets,
-                                       ArrayRef<int64_t> staticSizes,
-                                       ArrayRef<int64_t> staticStrides);
+    Example:
 
-    /// Return the rank of the result MemRefType.
-    unsigned getRank() { return getType().getRank(); }
+    ```
+    // Rank-reducing subtensor.
+    %1 = subtensor %0[0, 0, 0][1, 16, 4][1, 1, 1] :
+      tensor<8x16x4xf32> to tensor<16x4xf32>
+    %3 = subtensor %2[3, 4, 2][1, 6, 3][1, 1, 1] :
+      tensor<8x16x4xf32> to tensor<6x3xf32>
+    ```
+  }];
 
-    /// Return true if the offset `idx` is a static constant.
-    bool isDynamicOffset(unsigned idx);
-    /// Return true if the size `idx` is a static constant.
-    bool isDynamicSize(unsigned idx);
-    /// Return true if the stride `idx` is a static constant.
-    bool isDynamicStride(unsigned idx);
+  let arguments = (ins
+    AnyRankedTensor:$source,
+    Variadic<Index>:$offsets,
+    Variadic<Index>:$sizes,
+    Variadic<Index>:$strides,
+    I64ArrayAttr:$static_offsets,
+    I64ArrayAttr:$static_sizes,
+    I64ArrayAttr:$static_strides
+  );
+  let results = (outs AnyRankedTensor:$result);
 
-    /// Assert the offset `idx` is a static constant and return its value.
-    int64_t getStaticOffset(unsigned idx) {
-      assert(!isDynamicOffset(idx) && "expected static offset");
-      return
-        static_offsets().cast<ArrayAttr>()[idx].cast<IntegerAttr>().getInt();
-    }
-    /// Assert the size `idx` is a static constant and return its value.
-    int64_t getStaticSize(unsigned idx) {
-      assert(!isDynamicSize(idx) && "expected static size");
-      return static_sizes().cast<ArrayAttr>()[idx].cast<IntegerAttr>().getInt();
-    }
-    /// Assert the stride `idx` is a static constant and return its value.
-    int64_t getStaticStride(unsigned idx) {
-      assert(!isDynamicStride(idx) && "expected static stride");
-      return
-        static_strides().cast<ArrayAttr>()[idx].cast<IntegerAttr>().getInt();
+  let extraClassDeclaration = extraBaseClassDeclaration # [{
+    /// Returns the type of the base tensor operand.
+    RankedTensorType getSourceRankedTensorType() {
+      return source().getType().cast<RankedTensorType>();
     }
 
-    /// Assert the offset `idx` is dynamic and return the position of the
-    /// corresponding operand.
-    unsigned getIndexOfDynamicOffset(unsigned idx);
-    /// Assert the size `idx` is dynamic and return the position of the
-    /// corresponding operand.
-    unsigned getIndexOfDynamicSize(unsigned idx);
-    /// Assert the stride `idx` is dynamic and return the position of the
-    /// corresponding operand.
-    unsigned getIndexOfDynamicStride(unsigned idx);
-
-    /// Assert the offset `idx` is dynamic and return its value.
-    Value getDynamicOffset(unsigned idx) {
-      return getOperand(getIndexOfDynamicOffset(idx));
-    }
-    /// Assert the size `idx` is dynamic and return its value.
-    Value getDynamicSize(unsigned idx) {
-      return getOperand(getIndexOfDynamicSize(idx));
-    }
-    /// Assert the stride `idx` is dynamic and return its value.
-    Value getDynamicStride(unsigned idx) {
-      return getOperand(getIndexOfDynamicStride(idx));
+    /// The result of a subtensor is always a tensor.
+    RankedTensorType getType() {
+      return getResult().getType().cast<RankedTensorType>();
     }
 
-    static StringRef getStaticOffsetsAttrName() {
-      return "static_offsets";
-    }
-    static StringRef getStaticSizesAttrName() {
-      return "static_sizes";
-    }
-    static StringRef getStaticStridesAttrName() {
-      return "static_strides";
-    }
-    static ArrayRef<StringRef> getSpecialAttrNames() {
-      static SmallVector<StringRef, 4> names{
-        getStaticOffsetsAttrName(),
-        getStaticSizesAttrName(),
-        getStaticStridesAttrName(),
-        getOperandSegmentSizeAttr()};
-      return names;
-   }
+    /// A subtensor result type can be fully inferred from the source type and the
+    /// static representation of offsets, sizes and strides. Special sentinels
+    /// encode the dynamic case.
+    static Type inferResultType(RankedTensorType sourceRankedTensorType,
+                                ArrayRef<int64_t> staticOffsets,
+                                ArrayRef<int64_t> staticSizes,
+                                ArrayRef<int64_t> staticStrides);
   }];
 
-  let hasCanonicalizer = 1;
+  // let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index c964c2466d5c..7b16a9197f11 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -60,7 +60,7 @@ using llvm::dbgs;
 // This is achieved by applying the `loopToOperandRangesMaps` permutation maps
 // to the `loopRanges` in order to obtain view ranges.
 static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
-                                    ArrayRef<SubViewOp::Range> loopRanges) {
+                                    ArrayRef<Range> loopRanges) {
   assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
   auto maps = op.indexing_maps();
   SmallVector<Value, 8> clonedViews;
@@ -73,7 +73,7 @@ static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
     auto map = maps[idx].cast<AffineMapAttr>().getValue();
     LLVM_DEBUG(dbgs() << "map: " << map << "\n");
     Value view = en.value();
-    SmallVector<SubViewOp::Range, 4> viewRanges(map.getNumResults());
+    SmallVector<Range, 4> viewRanges(map.getNumResults());
     for (auto en2 : llvm::enumerate(map.getResults())) {
       unsigned d = en2.index();
       // loopToOperandRangesMaps are permutations-only.
@@ -182,7 +182,7 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
   unsigned nPar = producer.getNumParallelLoops();
   unsigned nRed = producer.getNumReductionLoops();
   unsigned nWin = producer.getNumWindowLoops();
-  SmallVector<SubViewOp::Range, 8> loopRanges(nPar + nRed + nWin);
+  SmallVector<Range, 8> loopRanges(nPar + nRed + nWin);
 
   // Iterate over dimensions identified by the producer map for `producerIdx`.
   // This defines a subset of the loop ranges that we need to complete later.
@@ -202,9 +202,9 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
                  << "existing LoopRange: " << loopRanges[i] << "\n");
     else {
       auto viewDim = getViewDefiningLoopRange(producer, i);
-      loopRanges[i] = SubViewOp::Range{folded_std_constant_index(folder, 0),
-                                       std_dim(viewDim.view, viewDim.dimension),
-                                       folded_std_constant_index(folder, 1)};
+      loopRanges[i] = Range{folded_std_constant_index(folder, 0),
+                            std_dim(viewDim.view, viewDim.dimension),
+                            folded_std_constant_index(folder, 1)};
       LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
     }
   }
@@ -300,8 +300,6 @@ static bool isSameSubView(Value a, Value b) {
     return false;
   if (sva.getType() != svb.getType())
     return false;
-  if (sva.getRank() != svb.getRank())
-    return false;
   if (sva.getNumOperands() != svb.getNumOperands())
     return false;
   if (sva.static_offsets() != svb.static_offsets())

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index eb452cc40305..a9e7a8660230 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -65,22 +65,21 @@ static SmallVector<Value, 4> permuteIvs(ArrayRef<Value> ivs,
 /// DimExpr or (DimExpr + DimExpr - SymbolExpr floordiv ConstExpr).
 /// It expects a non-inverted, concatenated map and last values in
 /// allViewSizes will be applied to the symbols in the map if it contains any.
-static SmallVector<SubViewOp::Range, 4> emitLoopRanges(OpBuilder &b,
-                                                       Location loc,
-                                                       AffineMap map,
-                                                       ValueRange viewSizes) {
+static SmallVector<Range, 4> emitLoopRanges(OpBuilder &b, Location loc,
+                                            AffineMap map,
+                                            ValueRange viewSizes) {
   unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
   unsigned numSym = map.getNumSymbols();
   assert(viewSizes.size() == numRes + numSym &&
          "viewSizes must contain sizes of all views and values for symbols");
-  SmallVector<SubViewOp::Range, 4> res(numDims);
+  SmallVector<Range, 4> res(numDims);
   for (unsigned idx = 0; idx < numRes; ++idx) {
     auto result = map.getResult(idx);
     if (auto d = result.dyn_cast<AffineDimExpr>()) {
       if (res[d.getPosition()].offset)
         continue;
-      res[d.getPosition()] = SubViewOp::Range{
-          std_constant_index(0), viewSizes[idx], std_constant_index(1)};
+      res[d.getPosition()] =
+          Range{std_constant_index(0), viewSizes[idx], std_constant_index(1)};
     }
 
     // If the access pattern is of form (m, n)[s] -> (m + n - s floordiv 2),
@@ -124,7 +123,7 @@ static SmallVector<SubViewOp::Range, 4> emitLoopRanges(OpBuilder &b,
       // Construction of the lower bound (s floordiv 2).
       Value from = applyMapToValues(b, loc, fromMap, values).front();
       Value to = applyMapToValues(b, loc, toMap, values).front();
-      res[mPos] = SubViewOp::Range{from, to, std_constant_index(1)};
+      res[mPos] = Range{from, to, std_constant_index(1)};
     }
   }
   return res;

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 68d69549611c..3e8e0b74c145 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -54,7 +54,7 @@ using LoopIndexToRangeIndexMap = DenseMap<int, int>;
 // are tiled and for which new loops will be created. Also the function returns
 // a map from loop indices of the LinalgOp to the corresponding non-empty range
 // indices of newly created loops.
-static std::tuple<SmallVector<SubViewOp::Range, 4>, LoopIndexToRangeIndexMap>
+static std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
 makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
                     ArrayRef<Value> allViewSizes,
                     ArrayRef<Value> allTileSizes) {
@@ -76,10 +76,9 @@ makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
   }
 
   // Create a new range with the applied tile sizes.
-  SmallVector<SubViewOp::Range, 4> res;
+  SmallVector<Range, 4> res;
   for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx)
-    res.push_back(SubViewOp::Range{std_constant_index(0), viewSizes[idx],
-                                   tileSizes[idx]});
+    res.push_back(Range{std_constant_index(0), viewSizes[idx], tileSizes[idx]});
   return std::make_tuple(res, loopIndexToRangeIndex);
 }
 
@@ -346,7 +345,7 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizes,
   if (!viewSizesToLoopsMap)
     return llvm::None;
 
-  SmallVector<SubViewOp::Range, 4> loopRanges;
+  SmallVector<Range, 4> loopRanges;
   LoopIndexToRangeIndexMap loopIndexToRangeIndex;
   std::tie(loopRanges, loopIndexToRangeIndex) = makeTiledLoopRanges(
       b, op.getLoc(), viewSizesToLoopsMap, allViewSizes, tileSizes);

diff  --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 204716b40746..f9ea9092d55d 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -133,11 +133,10 @@ template struct mlir::linalg::GenerateLoopNest<AffineForOp>;
 
 /// Given a list of subview ranges, extract individual values for lower, upper
 /// bounds and steps and put them into the corresponding vectors.
-static void unpackRanges(ArrayRef<SubViewOp::Range> ranges,
-                         SmallVectorImpl<Value> &lbs,
+static void unpackRanges(ArrayRef<Range> ranges, SmallVectorImpl<Value> &lbs,
                          SmallVectorImpl<Value> &ubs,
                          SmallVectorImpl<Value> &steps) {
-  for (SubViewOp::Range range : ranges) {
+  for (Range range : ranges) {
     lbs.emplace_back(range.offset);
     ubs.emplace_back(range.size);
     steps.emplace_back(range.stride);
@@ -194,7 +193,7 @@ getLoopRanges(OpBuilder &builder, LinalgOp linalgOp, OperationFolder *folder) {
 /// Specialization to build an scf "for" nest.
 template <>
 void GenerateLoopNest<scf::ForOp>::doit(
-    ArrayRef<SubViewOp::Range> loopRanges, ValueRange iterArgInitValues,
+    ArrayRef<Range> loopRanges, ValueRange iterArgInitValues,
     ArrayRef<Attribute> iteratorTypes,
     function_ref<scf::ValueVector(ValueRange, ValueRange)> bodyBuilderFn,
     Optional<LinalgLoopDistributionOptions>) {
@@ -206,7 +205,7 @@ void GenerateLoopNest<scf::ForOp>::doit(
 /// Specialization to build affine "for" nest.
 template <>
 void GenerateLoopNest<AffineForOp>::doit(
-    ArrayRef<SubViewOp::Range> loopRanges, ValueRange iterArgInitValues,
+    ArrayRef<Range> loopRanges, ValueRange iterArgInitValues,
     ArrayRef<Attribute> iteratorTypes,
     function_ref<scf::ValueVector(ValueRange, ValueRange)> bodyBuilderFn,
     Optional<LinalgLoopDistributionOptions>) {
@@ -364,7 +363,7 @@ generateParallelLoopNest(ValueRange lbs, ValueRange ubs, ValueRange steps,
 /// Specialization for generating a mix of parallel and sequential scf loops.
 template <>
 void GenerateLoopNest<scf::ParallelOp>::doit(
-    ArrayRef<SubViewOp::Range> loopRanges, ValueRange iterArgInitValues,
+    ArrayRef<Range> loopRanges, ValueRange iterArgInitValues,
     ArrayRef<Attribute> iteratorTypes,
     function_ref<scf::ValueVector(ValueRange, ValueRange)> bodyBuilderFn,
     Optional<LinalgLoopDistributionOptions> distributionOptions) {
@@ -391,7 +390,7 @@ void GenerateLoopNest<scf::ParallelOp>::doit(
     Location loc = edsc::ScopedContext::getLocation();
     distributionMethod.assign(distributionOptions->distributionMethod.begin(),
                               distributionOptions->distributionMethod.end());
-    SmallVector<SubViewOp::Range, 2> parallelLoopRanges;
+    SmallVector<Range, 2> parallelLoopRanges;
     for (auto iteratorType : enumerate(iteratorTypes)) {
       if (isParallelIteratorType(iteratorType.value()))
         parallelLoopRanges.push_back(loopRanges[iteratorType.index()]);

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 1cabf172b7fc..d684a4b98e55 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2587,10 +2587,10 @@ Wrapper operator*(Wrapper a, int64_t b) {
 /// A subview result type can be fully inferred from the source type and the
 /// static representation of offsets, sizes and strides. Special sentinels
 /// encode the dynamic case.
-Type SubViewOp::inferSubViewResultType(MemRefType sourceMemRefType,
-                                       ArrayRef<int64_t> staticOffsets,
-                                       ArrayRef<int64_t> staticSizes,
-                                       ArrayRef<int64_t> staticStrides) {
+Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
+                                ArrayRef<int64_t> staticOffsets,
+                                ArrayRef<int64_t> staticSizes,
+                                ArrayRef<int64_t> staticStrides) {
   unsigned rank = sourceMemRefType.getRank();
   (void)rank;
   assert(staticOffsets.size() == rank &&
@@ -2638,7 +2638,8 @@ Type SubViewOp::inferSubViewResultType(MemRefType sourceMemRefType,
 ///   subview ssa-name `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
 ///     `:` strided-memref-type `to` strided-memref-type
 /// ```
-static void print(OpAsmPrinter &p, SubViewOp op) {
+template <typename OpType>
+static void printOpWithOffsetsSizesAndStrides(OpAsmPrinter &p, OpType op) {
   int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
   p << op.getOperation()->getName().getStringRef().drop_front(stdDotLen) << ' ';
   p << op.getOperand(0);
@@ -2649,16 +2650,22 @@ static void print(OpAsmPrinter &p, SubViewOp op) {
   printSubViewListOfOperandsOrIntegers(p, op.strides(), op.static_strides(),
                                        ShapedType::isDynamicStrideOrOffset);
   p.printOptionalAttrDict(op.getAttrs(),
-                          /*elidedAttrs=*/{SubViewOp::getSpecialAttrNames()});
+                          /*elidedAttrs=*/{OpType::getSpecialAttrNames()});
   p << " : " << op.getOperand(0).getType() << " to " << op.getType();
 }
 
+static void print(OpAsmPrinter &p, SubViewOp op) {
+  return printOpWithOffsetsSizesAndStrides<SubViewOp>(p, op);
+}
+
 /// Parse SubViewOp of the form:
 /// ```
-///   subview ssa-name `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
+///   `name` ssa-name `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
 ///     `:` strided-memref-type `to` strided-memref-type
 /// ```
-static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
+template <typename OpType>
+static ParseResult parseOpWithOffsetsSizesAndStrides(OpAsmParser &parser,
+                                                     OperationState &result) {
   OpAsmParser::OperandType srcInfo;
   SmallVector<OpAsmParser::OperandType, 4> offsetsInfo, sizesInfo, stridesInfo;
   auto indexType = parser.getBuilder().getIndexType();
@@ -2666,13 +2673,13 @@ static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
   if (parser.parseOperand(srcInfo))
     return failure();
   if (parseListOfOperandsOrIntegers(
-          parser, result, SubViewOp::getStaticOffsetsAttrName(),
+          parser, result, OpType::getStaticOffsetsAttrName(),
           ShapedType::kDynamicStrideOrOffset, offsetsInfo) ||
       parseListOfOperandsOrIntegers(parser, result,
-                                    SubViewOp::getStaticSizesAttrName(),
+                                    OpType::getStaticSizesAttrName(),
                                     ShapedType::kDynamicSize, sizesInfo) ||
       parseListOfOperandsOrIntegers(
-          parser, result, SubViewOp::getStaticStridesAttrName(),
+          parser, result, OpType::getStaticStridesAttrName(),
           ShapedType::kDynamicStrideOrOffset, stridesInfo))
     return failure();
 
@@ -2680,7 +2687,7 @@ static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
   SmallVector<int, 4> segmentSizes{1, static_cast<int>(offsetsInfo.size()),
                                    static_cast<int>(sizesInfo.size()),
                                    static_cast<int>(stridesInfo.size())};
-  result.addAttribute(SubViewOp::getOperandSegmentSizeAttr(),
+  result.addAttribute(OpType::getOperandSegmentSizeAttr(),
                       b.getI32VectorAttr(segmentSizes));
 
   return failure(
@@ -2694,6 +2701,10 @@ static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
       parser.addTypeToList(dstType, result.types));
 }
 
+static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
+  return parseOpWithOffsetsSizesAndStrides<SubViewOp>(parser, result);
+}
+
 void mlir::SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                             ArrayRef<int64_t> staticOffsets,
                             ArrayRef<int64_t> staticSizes,
@@ -2701,8 +2712,8 @@ void mlir::SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                             ValueRange sizes, ValueRange strides,
                             ArrayRef<NamedAttribute> attrs) {
   auto sourceMemRefType = source.getType().cast<MemRefType>();
-  auto resultType = inferSubViewResultType(sourceMemRefType, staticOffsets,
-                                           staticSizes, staticStrides);
+  auto resultType = inferResultType(sourceMemRefType, staticOffsets,
+                                    staticSizes, staticStrides);
   build(b, result, resultType, source, offsets, sizes, strides,
         b.getI64ArrayAttr(staticOffsets), b.getI64ArrayAttr(staticSizes),
         b.getI64ArrayAttr(staticStrides));
@@ -2760,15 +2771,18 @@ void mlir::SubViewOp::build(OpBuilder &b, OperationState &result,
         staticStridesVector, offsets, sizes, strides, attrs);
 }
 
+/// For ViewLikeOpInterface.
+Value SubViewOp::getViewSource() { return source(); }
+
 /// Verify that a particular offset/size/stride static attribute is well-formed.
-static LogicalResult
-verifySubViewOpPart(SubViewOp op, StringRef name, StringRef attrName,
-                    ArrayAttr attr, llvm::function_ref<bool(int64_t)> isDynamic,
-                    ValueRange values) {
+template <typename OpType>
+static LogicalResult verifyOpWithOffsetSizesAndStridesPart(
+    OpType op, StringRef name, StringRef attrName, ArrayAttr attr,
+    llvm::function_ref<bool(int64_t)> isDynamic, ValueRange values) {
   /// Check static and dynamic offsets/sizes/strides breakdown.
-  size_t inputRank = op.source().getType().cast<MemRefType>().getRank();
-  if (attr.size() != inputRank)
-    return op.emitError("expected ") << inputRank << " " << name << " values";
+  if (attr.size() != op.getSourceRank())
+    return op.emitError("expected ")
+           << op.getSourceRank() << " " << name << " values";
   unsigned expectedNumDynamicEntries =
       llvm::count_if(attr.getValue(), [&](Attribute attr) {
         return isDynamic(attr.cast<IntegerAttr>().getInt());
@@ -2787,17 +2801,26 @@ static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
       }));
 }
 
-/// Checks if `original` MemRef type can be rank reduced to `reduced` type.
+/// Checks if the `original` type can be rank-reduced to the `reduced` type.
 /// This function is slight variant of `is subsequence` algorithm where
 /// not matching dimension must be 1.
 static bool isRankReducedType(Type originalType, Type reducedType) {
   if (originalType == reducedType)
     return true;
+  if (!originalType.isa<RankedTensorType>() && !originalType.isa<MemRefType>())
+    return true;
+  if (originalType.isa<RankedTensorType>() &&
+      !reducedType.isa<RankedTensorType>())
+    return true;
+  if (originalType.isa<MemRefType>() && !reducedType.isa<MemRefType>())
+    return true;
 
-  MemRefType original = originalType.cast<MemRefType>();
-  MemRefType reduced = reducedType.cast<MemRefType>();
-  ArrayRef<int64_t> originalShape = original.getShape();
-  ArrayRef<int64_t> reducedShape = reduced.getShape();
+  ShapedType originalShapedType = originalType.cast<ShapedType>();
+  ShapedType reducedShapedType = reducedType.cast<ShapedType>();
+
+  // Rank and size logic is valid for all ShapedTypes.
+  ArrayRef<int64_t> originalShape = originalShapedType.getShape();
+  ArrayRef<int64_t> reducedShape = reducedShapedType.getShape();
   unsigned originalRank = originalShape.size(),
            reducedRank = reducedShape.size();
   if (reducedRank > originalRank)
@@ -2819,6 +2842,13 @@ static bool isRankReducedType(Type originalType, Type reducedType) {
   if (reducedIdx != reducedRank)
     return false;
 
+  // We are done for the tensor case.
+  if (originalType.isa<RankedTensorType>())
+    return true;
+
+  // Strided layout logic is relevant for MemRefType only.
+  MemRefType original = originalType.cast<MemRefType>();
+  MemRefType reduced = reducedType.cast<MemRefType>();
   MLIRContext *c = original.getContext();
   int64_t originalOffset, symCounter = 0, dimCounter = 0;
   SmallVector<int64_t, 4> originalStrides;
@@ -2843,10 +2873,29 @@ static bool isRankReducedType(Type originalType, Type reducedType) {
           reducedMap == reduced.getAffineMaps().front());
 }
 
+template <typename OpType>
+static LogicalResult verifyOpWithOffsetSizesAndStrides(OpType op) {
+  // Verify static attributes offsets/sizes/strides.
+  if (failed(verifyOpWithOffsetSizesAndStridesPart(
+          op, "offset", op.getStaticOffsetsAttrName(), op.static_offsets(),
+          ShapedType::isDynamicStrideOrOffset, op.offsets())))
+    return failure();
+
+  if (failed(verifyOpWithOffsetSizesAndStridesPart(
+          op, "size", op.getStaticSizesAttrName(), op.static_sizes(),
+          ShapedType::isDynamic, op.sizes())))
+    return failure();
+  if (failed(verifyOpWithOffsetSizesAndStridesPart(
+          op, "stride", op.getStaticStridesAttrName(), op.static_strides(),
+          ShapedType::isDynamicStrideOrOffset, op.strides())))
+    return failure();
+  return success();
+}
+
 /// Verifier for SubViewOp.
 static LogicalResult verify(SubViewOp op) {
-  auto baseType = op.getBaseMemRefType().cast<MemRefType>();
-  auto subViewType = op.getType();
+  MemRefType baseType = op.getSourceMemRefType();
+  MemRefType subViewType = op.getType();
 
   // The base memref and the view memref should be in the same memory space.
   if (baseType.getMemorySpace() != subViewType.getMemorySpace())
@@ -2858,24 +2907,12 @@ static LogicalResult verify(SubViewOp op) {
   if (!isStrided(baseType))
     return op.emitError("base type ") << baseType << " is not strided";
 
-  // Verify static attributes offsets/sizes/strides.
-  if (failed(verifySubViewOpPart(
-          op, "offset", op.getStaticOffsetsAttrName(), op.static_offsets(),
-          ShapedType::isDynamicStrideOrOffset, op.offsets())))
-    return failure();
-
-  if (failed(verifySubViewOpPart(op, "size", op.getStaticSizesAttrName(),
-                                 op.static_sizes(), ShapedType::isDynamic,
-                                 op.sizes())))
-    return failure();
-  if (failed(verifySubViewOpPart(
-          op, "stride", op.getStaticStridesAttrName(), op.static_strides(),
-          ShapedType::isDynamicStrideOrOffset, op.strides())))
+  if (failed(verifyOpWithOffsetSizesAndStrides(op)))
     return failure();
 
   // Verify result type against inferred type.
-  auto expectedType = SubViewOp::inferSubViewResultType(
-      op.getBaseMemRefType(), extractFromI64ArrayAttr(op.static_offsets()),
+  auto expectedType = SubViewOp::inferResultType(
+      baseType, extractFromI64ArrayAttr(op.static_offsets()),
       extractFromI64ArrayAttr(op.static_sizes()),
       extractFromI64ArrayAttr(op.static_strides()));
   if (!isRankReducedType(expectedType, subViewType))
@@ -2885,123 +2922,41 @@ static LogicalResult verify(SubViewOp op) {
   return success();
 }
 
-raw_ostream &mlir::operator<<(raw_ostream &os, SubViewOp::Range &range) {
+raw_ostream &mlir::operator<<(raw_ostream &os, Range &range) {
   return os << "range " << range.offset << ":" << range.size << ":"
             << range.stride;
 }
 
-static unsigned getNumDynamicEntriesUpToIdx(
-    ArrayAttr attr, llvm::function_ref<bool(int64_t)> isDynamic, unsigned idx) {
-  return std::count_if(attr.getValue().begin(), attr.getValue().begin() + idx,
-                       [&](Attribute attr) {
-                         return isDynamic(attr.cast<IntegerAttr>().getInt());
-                       });
-}
-
-bool SubViewOp::isDynamicOffset(unsigned idx) {
-  return ShapedType::isDynamicStrideOrOffset(
-      extractFromI64ArrayAttr(static_offsets())[idx]);
-}
-bool SubViewOp::isDynamicSize(unsigned idx) {
-  return ShapedType::isDynamic(extractFromI64ArrayAttr(static_sizes())[idx]);
-}
-bool SubViewOp::isDynamicStride(unsigned idx) {
-  return ShapedType::isDynamicStrideOrOffset(
-      extractFromI64ArrayAttr(static_strides())[idx]);
-}
-
-unsigned SubViewOp::getIndexOfDynamicOffset(unsigned idx) {
-  assert(isDynamicOffset(idx) && "expected static offset");
-  auto numDynamic =
-      getNumDynamicEntriesUpToIdx(static_offsets().cast<ArrayAttr>(),
-                                  ShapedType::isDynamicStrideOrOffset, idx);
-  return 1 + numDynamic;
-}
-unsigned SubViewOp::getIndexOfDynamicSize(unsigned idx) {
-  assert(isDynamicSize(idx) && "expected static size");
-  auto numDynamic = getNumDynamicEntriesUpToIdx(
-      static_sizes().cast<ArrayAttr>(), ShapedType::isDynamic, idx);
-  return 1 + offsets().size() + numDynamic;
-}
-unsigned SubViewOp::getIndexOfDynamicStride(unsigned idx) {
-  assert(isDynamicStride(idx) && "expected static stride");
-  auto numDynamic =
-      getNumDynamicEntriesUpToIdx(static_strides().cast<ArrayAttr>(),
-                                  ShapedType::isDynamicStrideOrOffset, idx);
-  return 1 + offsets().size() + sizes().size() + numDynamic;
-}
-
-/// Return the list of SubViewOp::Range (i.e. offset, size, stride). Each Range
+/// Return the list of Range (i.e. offset, size, stride). Each Range
 /// entry contains either the dynamic value or a ConstantIndexOp constructed
 /// with `b` at location `loc`.
-SmallVector<SubViewOp::Range, 8> SubViewOp::getOrCreateRanges(OpBuilder &b,
-                                                              Location loc) {
+template <typename OpType>
+static SmallVector<Range, 8> getOrCreateRangesImpl(OpType op, OpBuilder &b,
+                                                   Location loc) {
   SmallVector<Range, 8> res;
-  unsigned rank = getType().getRank();
+  unsigned rank = op.getSourceRank();
   res.reserve(rank);
   for (unsigned idx = 0; idx < rank; ++idx) {
-    auto offset = isDynamicOffset(idx)
-                      ? getDynamicOffset(idx)
-                      : b.create<ConstantIndexOp>(loc, getStaticOffset(idx));
-    auto size = isDynamicSize(idx)
-                    ? getDynamicSize(idx)
-                    : b.create<ConstantIndexOp>(loc, getStaticSize(idx));
-    auto stride = isDynamicStride(idx)
-                      ? getDynamicStride(idx)
-                      : b.create<ConstantIndexOp>(loc, getStaticStride(idx));
+    Value offset =
+        op.isDynamicOffset(idx)
+            ? op.getDynamicOffset(idx)
+            : b.create<ConstantIndexOp>(loc, op.getStaticOffset(idx));
+    Value size = op.isDynamicSize(idx)
+                     ? op.getDynamicSize(idx)
+                     : b.create<ConstantIndexOp>(loc, op.getStaticSize(idx));
+    Value stride =
+        op.isDynamicStride(idx)
+            ? op.getDynamicStride(idx)
+            : b.create<ConstantIndexOp>(loc, op.getStaticStride(idx));
     res.emplace_back(Range{offset, size, stride});
   }
   return res;
 }
 
-SmallVector<Value, 4> SubViewOp::getOrCreateOffsets(OpBuilder &b,
-                                                    Location loc) {
-  unsigned dynamicIdx = 1;
-  return llvm::to_vector<4>(llvm::map_range(
-      static_offsets().cast<ArrayAttr>(), [&](Attribute a) -> Value {
-        int64_t staticOffset = a.cast<IntegerAttr>().getInt();
-        if (ShapedType::isDynamicStrideOrOffset(staticOffset))
-          return getOperand(dynamicIdx++);
-        else
-          return b.create<ConstantIndexOp>(loc, staticOffset);
-      }));
-}
-
-SmallVector<Value, 4> SubViewOp::getOrCreateSizes(OpBuilder &b, Location loc) {
-  unsigned dynamicIdx = 1 + offsets().size();
-  return llvm::to_vector<4>(llvm::map_range(
-      static_sizes().cast<ArrayAttr>(), [&](Attribute a) -> Value {
-        int64_t staticSize = a.cast<IntegerAttr>().getInt();
-        if (ShapedType::isDynamic(staticSize))
-          return getOperand(dynamicIdx++);
-        else
-          return b.create<ConstantIndexOp>(loc, staticSize);
-      }));
-}
-
-SmallVector<Value, 4> SubViewOp::getOrCreateStrides(OpBuilder &b,
-                                                    Location loc) {
-  unsigned dynamicIdx = 1 + offsets().size() + sizes().size();
-  return llvm::to_vector<4>(llvm::map_range(
-      static_strides().cast<ArrayAttr>(), [&](Attribute a) -> Value {
-        int64_t staticStride = a.cast<IntegerAttr>().getInt();
-        if (ShapedType::isDynamicStrideOrOffset(staticStride))
-          return getOperand(dynamicIdx++);
-        else
-          return b.create<ConstantIndexOp>(loc, staticStride);
-      }));
+SmallVector<Range, 8> SubViewOp::getOrCreateRanges(OpBuilder &b, Location loc) {
+  return ::getOrCreateRangesImpl(*this, b, loc);
 }
 
-LogicalResult
-SubViewOp::getStaticStrides(SmallVectorImpl<int64_t> &staticStrides) {
-  if (!strides().empty())
-    return failure();
-  staticStrides = extractFromI64ArrayAttr(static_strides());
-  return success();
-}
-
-Value SubViewOp::getViewSource() { return source(); }
-
 namespace {
 
 /// Take a list of `values` with potential new constant to extract and a list
@@ -3053,20 +3008,20 @@ class SubViewOpConstantArgumentFolder final
     SmallVector<Value, 8> newOffsets(subViewOp.offsets());
     SmallVector<int64_t, 8> newStaticOffsets =
         extractFromI64ArrayAttr(subViewOp.static_offsets());
-    assert(newStaticOffsets.size() == subViewOp.getRank());
+    assert(newStaticOffsets.size() == subViewOp.getSourceRank());
     canonicalizeSubViewPart(newOffsets, newStaticOffsets,
                             ShapedType::isDynamicStrideOrOffset);
 
     SmallVector<Value, 8> newSizes(subViewOp.sizes());
     SmallVector<int64_t, 8> newStaticSizes =
         extractFromI64ArrayAttr(subViewOp.static_sizes());
-    assert(newStaticOffsets.size() == subViewOp.getRank());
+    assert(newStaticOffsets.size() == subViewOp.getSourceRank());
     canonicalizeSubViewPart(newSizes, newStaticSizes, ShapedType::isDynamic);
 
     SmallVector<Value, 8> newStrides(subViewOp.strides());
     SmallVector<int64_t, 8> newStaticStrides =
         extractFromI64ArrayAttr(subViewOp.static_strides());
-    assert(newStaticOffsets.size() == subViewOp.getRank());
+    assert(newStaticOffsets.size() == subViewOp.getSourceRank());
     canonicalizeSubViewPart(newStrides, newStaticStrides,
                             ShapedType::isDynamicStrideOrOffset);
 
@@ -3210,7 +3165,7 @@ class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
     /// Deduce the resultType of the SubViewOp using `inferSubViewResultType` on
     /// the cast source operand type and the SubViewOp static information. This
     /// is the resulting type if the MemRefCastOp were folded.
-    Type resultType = SubViewOp::inferSubViewResultType(
+    Type resultType = SubViewOp::inferResultType(
         castOp.source().getType().cast<MemRefType>(),
         extractFromI64ArrayAttr(subViewOp.static_offsets()),
         extractFromI64ArrayAttr(subViewOp.static_sizes()),
@@ -3232,6 +3187,94 @@ void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
       context);
 }
 
+//===----------------------------------------------------------------------===//
+// SubTensorOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter &p, SubTensorOp op) {
+  return printOpWithOffsetsSizesAndStrides<SubTensorOp>(p, op);
+}
+
+static ParseResult parseSubTensorOp(OpAsmParser &parser,
+                                    OperationState &result) {
+  return parseOpWithOffsetsSizesAndStrides<SubTensorOp>(parser, result);
+}
+
+/// A subtensor result type can be fully inferred from the source type and the
+/// static representation of offsets, sizes and strides. Special sentinels
+/// encode the dynamic case.
+Type SubTensorOp::inferResultType(RankedTensorType sourceRankedTensorType,
+                                  ArrayRef<int64_t> staticOffsets,
+                                  ArrayRef<int64_t> staticSizes,
+                                  ArrayRef<int64_t> staticStrides) {
+  unsigned rank = sourceRankedTensorType.getRank();
+  (void)rank;
+  assert(staticOffsets.size() == rank &&
+         "unexpected staticOffsets size mismatch");
+  assert(staticSizes.size() == rank && "unexpected staticSizes size mismatch");
+  assert(staticStrides.size() == rank &&
+         "unexpected staticStrides size mismatch");
+  return RankedTensorType::get(staticSizes,
+                               sourceRankedTensorType.getElementType());
+}
+
+void mlir::SubTensorOp::build(OpBuilder &b, OperationState &result,
+                              Value source, ArrayRef<int64_t> staticOffsets,
+                              ArrayRef<int64_t> staticSizes,
+                              ArrayRef<int64_t> staticStrides,
+                              ValueRange offsets, ValueRange sizes,
+                              ValueRange strides,
+                              ArrayRef<NamedAttribute> attrs) {
+  auto sourceRankedTensorType = source.getType().cast<RankedTensorType>();
+  auto resultType = inferResultType(sourceRankedTensorType, staticOffsets,
+                                    staticSizes, staticStrides);
+  build(b, result, resultType, source, offsets, sizes, strides,
+        b.getI64ArrayAttr(staticOffsets), b.getI64ArrayAttr(staticSizes),
+        b.getI64ArrayAttr(staticStrides));
+  result.addAttributes(attrs);
+}
+
+/// Build a SubTensorOp with all dynamic entries: `staticOffsets`, `staticSizes`
+/// and `staticStrides` are automatically filled with sentinel values that
+/// encode dynamic entries.
+void mlir::SubTensorOp::build(OpBuilder &b, OperationState &result,
+                              Value source, ValueRange offsets,
+                              ValueRange sizes, ValueRange strides,
+                              ArrayRef<NamedAttribute> attrs) {
+  auto sourceRankedTensorType = source.getType().cast<RankedTensorType>();
+  unsigned rank = sourceRankedTensorType.getRank();
+  SmallVector<int64_t, 4> staticOffsetsVector(
+      rank, ShapedType::kDynamicStrideOrOffset);
+  SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize);
+  SmallVector<int64_t, 4> staticStridesVector(
+      rank, ShapedType::kDynamicStrideOrOffset);
+  build(b, result, source, staticOffsetsVector, staticSizesVector,
+        staticStridesVector, offsets, sizes, strides, attrs);
+}
+
+SmallVector<Range, 8> SubTensorOp::getOrCreateRanges(OpBuilder &b,
+                                                     Location loc) {
+  return ::getOrCreateRangesImpl(*this, b, loc);
+}
+
+/// Verifier for SubTensorOp.
+static LogicalResult verify(SubTensorOp op) {
+  if (failed(verifyOpWithOffsetSizesAndStrides(op)))
+    return failure();
+
+  // Verify result type against inferred type.
+  auto expectedType = SubTensorOp::inferResultType(
+      op.getSourceRankedTensorType(),
+      extractFromI64ArrayAttr(op.static_offsets()),
+      extractFromI64ArrayAttr(op.static_sizes()),
+      extractFromI64ArrayAttr(op.static_strides()));
+  if (!isRankReducedType(expectedType, op.getType()))
+    return op.emitError("expected result type to be ")
+           << expectedType << " or a rank-reduced version.";
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TensorCastOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
index 5e3959af29dd..72a063ff9d51 100644
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -900,3 +900,27 @@ func @assume_alignment(%0: memref<4x4xf16>) {
   assume_alignment %0, 16 : memref<4x4xf16>
   return
 }
+
+
+// CHECK-LABEL: func @subtensor({{.*}}) {
+func @subtensor(%t: tensor<8x16x4xf32>, %idx : index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+
+  // CHECK: subtensor
+  // CHECK-SAME: tensor<8x16x4xf32> to tensor<?x?x?xf32>
+  %1 = subtensor %t[%c0, %c0, %c0][%idx, %idx, %idx][%c1, %c1, %c1]
+    : tensor<8x16x4xf32> to tensor<?x?x?xf32>
+
+  // CHECK: subtensor
+  // CHECK-SAME: tensor<8x16x4xf32> to tensor<4x4x4xf32>
+  %2 = subtensor %t[0, 2, 0][4, 4, 4][1, 1, 1]
+    : tensor<8x16x4xf32> to tensor<4x4x4xf32>
+
+  // CHECK: subtensor
+  // CHECK-SAME: tensor<8x16x4xf32> to tensor<4x4xf32>
+  %3 = subtensor %t[0, 2, 0][4, 1, 4][1, 1, 1]
+    : tensor<8x16x4xf32> to tensor<4x4xf32>
+
+  return
+}

diff  --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index ab18845bdb53..7356c07577db 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -1255,3 +1255,23 @@ func @imaginary_part_from_incompatible_complex_type(%cplx: complex<f64>) {
   std.re %cplx : complex<f32>
   return
 }
+
+// -----
+
+func @subtensor_wrong_dynamic_type(%t: tensor<8x16x4xf32>, %idx : index) {
+      // expected-error @+1 {{expected result type to be 'tensor<4x4x4xf32>'}}
+  %0 = subtensor %t[0, 2, 0][4, 4, 4][1, 1, 1]
+    : tensor<8x16x4xf32> to tensor<?x4x4xf32>
+
+  return
+}
+
+// -----
+
+func @subtensor_wrong_static_type(%t: tensor<8x16x4xf32>, %idx : index) {
+      // expected-error @+1 {{expected result type to be 'tensor<?x3x?xf32>'}}
+  %0 = subtensor %t[0, 0, 0][%idx, 3, %idx][1, 1, 1]
+    : tensor<8x16x4xf32> to tensor<4x4x4xf32>
+
+  return
+}

diff  --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
index edcc66c9b6a6..ffb0f92dae99 100644
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -301,8 +301,7 @@ static void fillPromotionCallBackPatterns(MLIRContext *ctx,
 
 template <typename IdOp, typename NProcsOp>
 static SmallVector<ProcInfo, 2>
-getGpuProcIds(OpBuilder &b, Location loc,
-              ArrayRef<SubViewOp::Range> parallelLoopRanges) {
+getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) {
   Type indexType = b.getIndexType();
   SmallVector<ProcInfo, 2> procInfo(2);
   procInfo[0] = {b.create<IdOp>(loc, indexType, b.getStringAttr("y")),


        


More information about the Mlir-commits mailing list