[Mlir-commits] [mlir] 26c8f90 - [mlir][vector] Extend Transfer read/write ops to support tensor types.

Thomas Raoux llvmlistbot at llvm.org
Mon Dec 21 08:55:41 PST 2020


Author: Thomas Raoux
Date: 2020-12-21T08:55:04-08:00
New Revision: 26c8f9081b6b1ca9358ac2ca38e8e603fb6f7d64

URL: https://github.com/llvm/llvm-project/commit/26c8f9081b6b1ca9358ac2ca38e8e603fb6f7d64
DIFF: https://github.com/llvm/llvm-project/commit/26c8f9081b6b1ca9358ac2ca38e8e603fb6f7d64.diff

LOG: [mlir][vector] Extend Transfer read/write ops to support tensor types.

Transfer ops can now work on both buffers and tensors. Lowering of the
tensor case is not supported yet.

Differential Revision: https://reviews.llvm.org/D93500

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.h
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/include/mlir/Dialect/Vector/VectorUtils.h
    mlir/include/mlir/Interfaces/VectorInterfaces.td
    mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
    mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
    mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/lib/Dialect/Vector/VectorUtils.cpp
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index 95964665ced6..5540a56a4043 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -126,7 +126,7 @@ namespace impl {
 /// Build the default minor identity map suitable for a vector transfer. This
 /// also handles the case memref<... x vector<...>> -> vector<...> in which the
 /// rank of the identity map must take the vector element type into account.
-AffineMap getTransferMinorIdentityMap(MemRefType memRefType,
+AffineMap getTransferMinorIdentityMap(ShapedType shapedType,
                                       VectorType vectorType);
 } // namespace impl
 } // end namespace vector

diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index de77e3b03483..13aba2076ee9 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1056,7 +1056,7 @@ def Vector_TransferReadOp :
       DeclareOpInterfaceMethods<VectorTransferOpInterface>,
       DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
     ]>,
-    Arguments<(ins AnyMemRef:$memref, Variadic<Index>:$indices,
+    Arguments<(ins AnyShaped:$source, Variadic<Index>:$indices,
                AffineMapAttr:$permutation_map, AnyType:$padding,
                OptionalAttr<BoolArrayAttr>:$masked)>,
     Results<(outs AnyVector:$vector)> {
@@ -1065,15 +1065,16 @@ def Vector_TransferReadOp :
 
   let description = [{
     The `vector.transfer_read` op performs a read from a slice within a
-    [MemRef](../LangRef.md#memref-type) supplied as its first operand
-    into a [vector](../LangRef.md#vector-type) of the same base elemental type.
+    [MemRef](../LangRef.md#memref-type) or a Ranked
+    [Tensor](../LangRef.md#tensor-type) supplied as its first operand into a
+    [vector](../LangRef.md#vector-type) of the same base elemental type.
 
-    A memref operand with vector element type, must have its vector element
-    type match a suffix (shape and element type) of the vector (e.g.
+    A memref/tensor operand with vector element type, must have its vector
+    element type match a suffix (shape and element type) of the vector (e.g.
     memref<3x2x6x4x3xf32>, vector<1x1x4x3xf32>).
 
-    The slice is further defined by a full-rank index within the MemRef,
-    supplied as the operands `2 .. 1 + rank(memref)`.
+    The slice is further defined by a full-rank index within the MemRef/Tensor,
+    supplied as the operands `2 .. 1 + rank(memref/tensor)`.
 
     The permutation_map [attribute](../LangRef.md#attributes) is an
     [affine-map](Affine.md#affine-maps) which specifies the transposition on the
@@ -1084,8 +1085,9 @@ def Vector_TransferReadOp :
     The size of the slice is specified by the size of the vector, given as the
     return type.
 
-    An `ssa-value` of the same elemental type as the MemRef is provided as the
-    last operand to specify padding in the case of out-of-bounds accesses.
+    An `ssa-value` of the same elemental type as the MemRef/Tensor is provided
+    as the last operand to specify padding in the case of out-of-bounds
+    accesses.
 
     An optional boolean array attribute is provided to specify which dimensions
     of the transfer need masking. When a dimension is specified as not requiring
@@ -1196,17 +1198,22 @@ def Vector_TransferReadOp :
     %4 = vector.transfer_read %arg1[%c3, %c3], %vf0
       {permutation_map = (d0, d1)->(d0, d1)}
         : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
+
+    // Read from a tensor with vector element type.
+    %5 = vector.transfer_read %arg1[%c3, %c3], %vf0
+      {permutation_map = (d0, d1)->(d0, d1)}
+        : tensor<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
     ```
   }];
 
   let builders = [
     // Builder that sets padding to zero.
-    OpBuilderDAG<(ins "VectorType":$vector, "Value":$memref,
+    OpBuilderDAG<(ins "VectorType":$vector, "Value":$source,
       "ValueRange":$indices, "AffineMap":$permutationMap,
       CArg<"ArrayRef<bool>", "{}">:$maybeMasked)>,
     // Builder that sets permutation map (resp. padding) to
     // 'getMinorIdentityMap' (resp. zero).
-    OpBuilderDAG<(ins "VectorType":$vector, "Value":$memref,
+    OpBuilderDAG<(ins "VectorType":$vector, "Value":$source,
       "ValueRange":$indices, CArg<"ArrayRef<bool>", "{}">:$maybeMasked)>
   ];
 
@@ -1217,26 +1224,29 @@ def Vector_TransferWriteOp :
   Vector_Op<"transfer_write", [
       DeclareOpInterfaceMethods<VectorTransferOpInterface>,
       DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
-    ]>,
-    Arguments<(ins AnyVector:$vector, AnyMemRef:$memref,
+  ]>,
+    Arguments<(ins AnyVector:$vector, AnyShaped:$source,
                Variadic<Index>:$indices,
                AffineMapAttr:$permutation_map,
-               OptionalAttr<BoolArrayAttr>:$masked)> {
+               OptionalAttr<BoolArrayAttr>:$masked)>,
+    Results<(outs Optional<AnyRankedTensor>:$result)> {
 
   let summary = "The vector.transfer_write op writes a supervector to memory.";
 
   let description = [{
     The `vector.transfer_write` op performs a write from a
     [vector](../LangRef.md#vector-type), supplied as its first operand, into a
-    slice within a [MemRef](../LangRef.md#memref-type) of the same base
-    elemental type, supplied as its second operand.
+    slice within a [MemRef](../LangRef.md#memref-type) or a Ranked
+    [Tensor](../LangRef.md#tensor-type) of the same base elemental type,
+    supplied as its second operand.
 
-    A vector memref operand must have its vector element type match a suffix
-    (shape and element type) of the vector (e.g. memref<3x2x6x4x3xf32>,
-    vector<1x1x4x3xf32>).
+    A vector memref/tensor operand must have its vector element type match a
+    suffix (shape and element type) of the vector (e.g. memref<3x2x6x4x3xf32>,
+    vector<1x1x4x3xf32>). If the operand is a tensor, the operation returns a
+    new tensor of the same type.
 
-    The slice is further defined by a full-rank index within the MemRef,
-    supplied as the operands `3 .. 2 + rank(memref)`.
+    The slice is further defined by a full-rank index within the MemRef/Tensor,
+    supplied as the operands `3 .. 2 + rank(memref/tensor)`.
 
     The permutation_map [attribute](../LangRef.md#attributes) is an
     [affine-map](Affine.md#affine-maps) which specifies the transposition on the
@@ -1280,15 +1290,24 @@ def Vector_TransferWriteOp :
     vector.transfer_write %4, %arg1[%c3, %c3]
       {permutation_map = (d0, d1)->(d0, d1)}
         : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
+
+    // return a tensor where the vector is inserted into the source tensor.
+    %5 = vector.transfer_write %4, %arg1[%c3, %c3]
+      {permutation_map = (d0, d1)->(d0, d1)}
+        : vector<1x1x4x3xf32>, tensor<?x?xvector<4x3xf32>>
     ```
   }];
 
   let builders = [
     // Builder that sets permutation map to 'getMinorIdentityMap'.
-    OpBuilderDAG<(ins "Value":$vector, "Value":$memref, "ValueRange":$indices,
+    OpBuilderDAG<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
       CArg<"ArrayRef<bool>", "{}">:$maybeMasked)>,
-    OpBuilderDAG<(ins "Value":$vector, "Value":$memref, "ValueRange":$indices,
+    OpBuilderDAG<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
       "AffineMap":$permutationMap)>,
+    OpBuilderDAG<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
+      "AffineMapAttr":$permutationMap, "ArrayAttr":$masked)>,
+    OpBuilderDAG<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
+      "AffineMap":$permutationMap, "ArrayAttr":$masked)>,
   ];
 
   let hasFolder = 1;

diff  --git a/mlir/include/mlir/Dialect/Vector/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/VectorUtils.h
index f70fba819b66..a06bc8cf6562 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorUtils.h
@@ -20,9 +20,9 @@ class AffineApplyOp;
 class AffineForOp;
 class AffineMap;
 class Location;
-class MemRefType;
 class OpBuilder;
 class Operation;
+class ShapedType;
 class Value;
 class VectorType;
 class VectorTransferOpInterface;
@@ -157,7 +157,7 @@ makePermutationMap(Operation *op, ArrayRef<Value> indices,
 /// Build the default minor identity map suitable for a vector transfer. This
 /// also handles the case memref<... x vector<...>> -> vector<...> in which the
 /// rank of the identity map must take the vector element type into account.
-AffineMap getTransferMinorIdentityMap(MemRefType memRefType,
+AffineMap getTransferMinorIdentityMap(ShapedType shapedType,
                                       VectorType vectorType);
 
 /// Return true if we can prove that the transfer operations access disjoint

diff  --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td
index 73332afd8825..3f60de5831c9 100644
--- a/mlir/include/mlir/Interfaces/VectorInterfaces.td
+++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td
@@ -47,7 +47,7 @@ def VectorUnrollOpInterface : OpInterface<"VectorUnrollOpInterface"> {
 
 def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
   let description = [{
-    Encodes properties of an operation on vectors that can be unrolled.
+    Encodes properties of a transfer read or write operation.
   }];
   let cppNamespace = "::mlir";
 
@@ -83,11 +83,11 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
       }]
     >,
     InterfaceMethod<
-      /*desc=*/"Return the memref operand.",
+      /*desc=*/"Return the memref or ranked tensor operand.",
       /*retTy=*/"Value",
-      /*methodName=*/"memref",
+      /*methodName=*/"source",
       /*args=*/(ins),
-      /*methodBody=*/"return $_op.memref();"
+      /*methodBody=*/"return $_op.source();"
       /*defaultImplementation=*/
     >,
     InterfaceMethod<
@@ -123,13 +123,13 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
       /*defaultImplementation=*/
     >,
     InterfaceMethod<
-      /*desc=*/"Return the MemRefType.",
-      /*retTy=*/"MemRefType",
-      /*methodName=*/"getMemRefType",
+      /*desc=*/"Return the ShapedType.",
+      /*retTy=*/"ShapedType",
+      /*methodName=*/"getShapedType",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/
-        "return $_op.memref().getType().template cast<MemRefType>();"
+        "return $_op.source().getType().template cast<ShapedType>();"
     >,
     InterfaceMethod<
       /*desc=*/"Return the VectorType.",
@@ -152,14 +152,14 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
         "return $_op.permutation_map().getNumResults();"
     >,
     InterfaceMethod<
-      /*desc=*/[{ Return the number of leading memref dimensions that do not
+      /*desc=*/[{ Return the number of leading shaped dimensions that do not
                   participate in the permutation map.}],
       /*retTy=*/"unsigned",
-      /*methodName=*/"getLeadingMemRefRank",
+      /*methodName=*/"getLeadingShapedRank",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/
-        "return $_op.getMemRefType().getRank() - $_op.getTransferRank();"
+        "return $_op.getShapedType().getRank() - $_op.getTransferRank();"
     >,
     InterfaceMethod<
       /*desc=*/[{ Returns true if at least one of the dimensions is masked.}],
@@ -178,8 +178,8 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
       /*desc=*/[{
       Helper function to account for the fact that `permutationMap` results and
       `op.indices` sizes may not match and may not be aligned. The first
-      `getLeadingMemRefRank()` indices may just be indexed and not transferred
-      from/into the vector.
+      `getLeadingShapedRank()` indices may just be indexed and not
+      transferred from/into the vector.
       For example:
       ```
          vector.transfer %0[%i, %j, %k, %c0] :
@@ -195,7 +195,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         for (int64_t resultIdx = 0,
-                   indicesIdx = $_op.getLeadingMemRefRank(),
+                   indicesIdx = $_op.getLeadingShapedRank(),
                    eResult = $_op.getTransferRank();
            resultIdx < eResult;
            ++resultIdx, ++indicesIdx)

diff  --git a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
index ea483aa6abae..10d727df701b 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
@@ -22,6 +22,17 @@
 
 using namespace mlir;
 
+/// Helpers to access the memref operand for each op.
+static Value getMemRefOperand(LoadOp op) { return op.memref(); }
+
+static Value getMemRefOperand(vector::TransferReadOp op) { return op.source(); }
+
+static Value getMemRefOperand(StoreOp op) { return op.memref(); }
+
+static Value getMemRefOperand(vector::TransferWriteOp op) {
+  return op.source();
+}
+
 namespace {
 /// Merges subview operation with load/transferRead operation.
 template <typename OpTy>
@@ -141,7 +152,7 @@ template <typename OpTy>
 LogicalResult
 LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp,
                                              PatternRewriter &rewriter) const {
-  auto subViewOp = loadOp.memref().template getDefiningOp<SubViewOp>();
+  auto subViewOp = getMemRefOperand(loadOp).template getDefiningOp<SubViewOp>();
   if (!subViewOp) {
     return failure();
   }
@@ -162,7 +173,8 @@ template <typename OpTy>
 LogicalResult
 StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
                                               PatternRewriter &rewriter) const {
-  auto subViewOp = storeOp.memref().template getDefiningOp<SubViewOp>();
+  auto subViewOp =
+      getMemRefOperand(storeOp).template getDefiningOp<SubViewOp>();
   if (!subViewOp) {
     return failure();
   }

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index ebe07366f6ec..a982b90e0e93 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -141,12 +141,10 @@ static Value buildVectorComparison(ConversionPatternRewriter &rewriter,
   return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds);
 }
 
-// Helper that returns data layout alignment of an operation with memref.
-template <typename T>
-LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
-                                 unsigned &align) {
-  Type elementTy =
-      typeConverter.convertType(op.getMemRefType().getElementType());
+// Helper that returns data layout alignment of a memref.
+LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
+                                 MemRefType memrefType, unsigned &align) {
+  Type elementTy = typeConverter.convertType(memrefType.getElementType());
   if (!elementTy)
     return failure();
 
@@ -222,7 +220,8 @@ replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                  TransferReadOp xferOp,
                                  ArrayRef<Value> operands, Value dataPtr) {
   unsigned align;
-  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
+  if (failed(getMemRefAlignment(
+          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
     return failure();
   rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
   return success();
@@ -243,7 +242,8 @@ replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
     return failure();
 
   unsigned align;
-  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
+  if (failed(getMemRefAlignment(
+          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
     return failure();
 
   rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
@@ -258,7 +258,8 @@ replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                  TransferWriteOp xferOp,
                                  ArrayRef<Value> operands, Value dataPtr) {
   unsigned align;
-  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
+  if (failed(getMemRefAlignment(
+          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
     return failure();
   auto adaptor = TransferWriteOpAdaptor(operands);
   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
@@ -272,7 +273,8 @@ replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
                             TransferWriteOp xferOp, ArrayRef<Value> operands,
                             Value dataPtr, Value mask) {
   unsigned align;
-  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
+  if (failed(getMemRefAlignment(
+          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
     return failure();
 
   auto adaptor = TransferWriteOpAdaptor(operands);
@@ -345,7 +347,8 @@ class VectorMaskedLoadOpConversion
 
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefAlignment(*getTypeConverter(), load, align)))
+    if (failed(getMemRefAlignment(*getTypeConverter(), load.getMemRefType(),
+                                  align)))
       return failure();
 
     auto vtype = typeConverter->convertType(load.getResultVectorType());
@@ -375,7 +378,8 @@ class VectorMaskedStoreOpConversion
 
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefAlignment(*getTypeConverter(), store, align)))
+    if (failed(getMemRefAlignment(*getTypeConverter(), store.getMemRefType(),
+                                  align)))
       return failure();
 
     auto vtype = typeConverter->convertType(store.getValueVectorType());
@@ -405,7 +409,8 @@ class VectorGatherOpConversion
 
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefAlignment(*getTypeConverter(), gather, align)))
+    if (failed(getMemRefAlignment(*getTypeConverter(), gather.getMemRefType(),
+                                  align)))
       return failure();
 
     // Get index ptrs.
@@ -438,7 +443,8 @@ class VectorScatterOpConversion
 
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefAlignment(*getTypeConverter(), scatter, align)))
+    if (failed(getMemRefAlignment(*getTypeConverter(), scatter.getMemRefType(),
+                                  align)))
       return failure();
 
     // Get index ptrs.
@@ -1182,8 +1188,11 @@ class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
                                        xferOp.getVectorType().getRank(),
                                        xferOp->getContext()))
       return failure();
+    auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
+    if (!memRefType)
+      return failure();
     // Only contiguous source tensors supported atm.
-    auto strides = computeContiguousStrides(xferOp.getMemRefType());
+    auto strides = computeContiguousStrides(memRefType);
     if (!strides)
       return failure();
 
@@ -1192,10 +1201,9 @@ class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
     };
 
     Location loc = xferOp->getLoc();
-    MemRefType memRefType = xferOp.getMemRefType();
 
     if (auto memrefVectorElementType =
-            memRefType.getElementType().dyn_cast<VectorType>()) {
+            memRefType.getElementType().template dyn_cast<VectorType>()) {
       // Memref has vector element type.
       if (memrefVectorElementType.getElementType() !=
           xferOp.getVectorType().getElementType())
@@ -1222,7 +1230,7 @@ class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
     //    address space 0.
     // TODO: support alignment when possible.
     Value dataPtr = this->getStridedElementPtr(
-        loc, memRefType, adaptor.memref(), adaptor.indices(), rewriter);
+        loc, memRefType, adaptor.source(), adaptor.indices(), rewriter);
     auto vecTy =
         toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
     Value vectorDataPtr;
@@ -1248,7 +1256,7 @@ class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
     unsigned vecWidth = vecTy.getVectorNumElements();
     unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
     Value off = xferOp.indices()[lastIndex];
-    Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), lastIndex);
+    Value dim = rewriter.create<DimOp>(loc, xferOp.source(), lastIndex);
     Value mask = buildVectorComparison(
         rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off);
 

diff  --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
index e5474abfd3e3..973b116ef498 100644
--- a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
+++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
@@ -89,7 +89,9 @@ class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
       return failure();
 
     // Obtain dataPtr and elementType from the memref.
-    MemRefType memRefType = xferOp.getMemRefType();
+    auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
+    if (!memRefType)
+      return failure();
     // MUBUF instruction operate only on addresspace 0(unified) or 1(global)
     // In case of 3(LDS): fall back to vector->llvm pass
     // In case of 5(VGPR): wrong
@@ -101,7 +103,7 @@ class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
     // indices, so no need to calculate offset size in bytes again in
     // the MUBUF instruction.
     Value dataPtr = this->getStridedElementPtr(
-        loc, memRefType, adaptor.memref(), adaptor.indices(), rewriter);
+        loc, memRefType, adaptor.source(), adaptor.indices(), rewriter);
 
     // 1. Create and fill a <4 x i32> dwordConfig with:
     //    1st two elements holding the address of dataPtr.

diff  --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index c05e9da2c949..b0f1b46b2459 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -107,7 +107,7 @@ class NDTransferOpHelper {
     // TODO: when we go to k > 1-D vectors adapt minorRank.
     minorRank = 1;
     majorRank = vectorType.getRank() - minorRank;
-    leadingRank = xferOp.getLeadingMemRefRank();
+    leadingRank = xferOp.getLeadingShapedRank();
     majorVectorType =
         VectorType::get(vectorType.getShape().take_front(majorRank),
                         vectorType.getElementType());
@@ -115,9 +115,9 @@ class NDTransferOpHelper {
         VectorType::get(vectorType.getShape().take_back(minorRank),
                         vectorType.getElementType());
     /// Memref of minor vector type is used for individual transfers.
-    memRefMinorVectorType =
-        MemRefType::get(majorVectorType.getShape(), minorVectorType, {},
-                        xferOp.getMemRefType().getMemorySpace());
+    memRefMinorVectorType = MemRefType::get(
+        majorVectorType.getShape(), minorVectorType, {},
+        xferOp.getShapedType().template cast<MemRefType>().getMemorySpace());
   }
 
   LogicalResult doReplace();
@@ -155,7 +155,7 @@ void NDTransferOpHelper<ConcreteOp>::emitLoops(
                             const MemRefBoundsCapture &)>
         loopBodyBuilder) {
   /// Loop nest operates on the major dimensions
-  MemRefBoundsCapture memrefBoundsCapture(xferOp.memref());
+  MemRefBoundsCapture memrefBoundsCapture(xferOp.source());
 
   if (options.unroll) {
     auto shape = majorVectorType.getShape();
@@ -272,9 +272,9 @@ LogicalResult NDTransferOpHelper<TransferReadOp>::doReplace() {
       indexing.append(leadingOffsets.begin(), leadingOffsets.end());
       indexing.append(majorIvsPlusOffsets.begin(), majorIvsPlusOffsets.end());
       indexing.append(minorOffsets.begin(), minorOffsets.end());
-      Value memref = xferOp.memref();
+      Value memref = xferOp.source();
       auto map =
-          getTransferMinorIdentityMap(xferOp.getMemRefType(), minorVectorType);
+          getTransferMinorIdentityMap(xferOp.getShapedType(), minorVectorType);
       ArrayAttr masked;
       if (!xferOp.isMaskedDim(xferOp.getVectorType().getRank() - 1)) {
         OpBuilder &b = ScopedContext::getBuilderRef();
@@ -379,13 +379,13 @@ LogicalResult NDTransferOpHelper<TransferWriteOp>::doReplace() {
       else
         result = std_load(alloc, majorIvs);
       auto map =
-          getTransferMinorIdentityMap(xferOp.getMemRefType(), minorVectorType);
+          getTransferMinorIdentityMap(xferOp.getShapedType(), minorVectorType);
       ArrayAttr masked;
       if (!xferOp.isMaskedDim(xferOp.getVectorType().getRank() - 1)) {
         OpBuilder &b = ScopedContext::getBuilderRef();
         masked = b.getBoolArrayAttr({false});
       }
-      vector_transfer_write(result, xferOp.memref(), indexing,
+      vector_transfer_write(result, xferOp.source(), indexing,
                             AffineMapAttr::get(map), masked);
     };
 
@@ -422,7 +422,7 @@ template <typename TransferOpTy>
 static int computeCoalescedIndex(TransferOpTy transfer) {
   // rank of the remote memory access, coalescing behavior occurs on the
   // innermost memory dimension.
-  auto remoteRank = transfer.getMemRefType().getRank();
+  auto remoteRank = transfer.getShapedType().getRank();
   // Iterate over the results expressions of the permutation map to determine
   // the loop order for creating pointwise copies between remote and local
   // memories.
@@ -536,13 +536,14 @@ LogicalResult VectorTransferRewriter<TransferReadOp>::matchAndRewrite(
   using namespace mlir::edsc::op;
 
   TransferReadOp transfer = cast<TransferReadOp>(op);
-
+  auto memRefType = transfer.getShapedType().dyn_cast<MemRefType>();
+  if (!memRefType)
+    return failure();
   // Fall back to a loop if the fastest varying stride is not 1 or it is
   // permuted.
   int64_t offset;
   SmallVector<int64_t, 4> strides;
-  auto successStrides =
-      getStridesAndOffset(transfer.getMemRefType(), strides, offset);
+  auto successStrides = getStridesAndOffset(memRefType, strides, offset);
   if (succeeded(successStrides) && strides.back() == 1 &&
       transfer.permutation_map().isMinorIdentity()) {
     // If > 1D, emit a bunch of loops around 1-D vector transfers.
@@ -557,8 +558,8 @@ LogicalResult VectorTransferRewriter<TransferReadOp>::matchAndRewrite(
   // Conservative lowering to scalar load / stores.
   // 1. Setup all the captures.
   ScopedContext scope(rewriter, transfer.getLoc());
-  StdIndexedValue remote(transfer.memref());
-  MemRefBoundsCapture memRefBoundsCapture(transfer.memref());
+  StdIndexedValue remote(transfer.source());
+  MemRefBoundsCapture memRefBoundsCapture(transfer.source());
   VectorBoundsCapture vectorBoundsCapture(transfer.vector());
   int coalescedIdx = computeCoalescedIndex(transfer);
   // Swap the vectorBoundsCapture which will reorder loop bounds.
@@ -621,13 +622,15 @@ LogicalResult VectorTransferRewriter<TransferWriteOp>::matchAndRewrite(
   using namespace edsc::op;
 
   TransferWriteOp transfer = cast<TransferWriteOp>(op);
+  auto memRefType = transfer.getShapedType().template dyn_cast<MemRefType>();
+  if (!memRefType)
+    return failure();
 
   // Fall back to a loop if the fastest varying stride is not 1 or it is
   // permuted.
   int64_t offset;
   SmallVector<int64_t, 4> strides;
-  auto successStrides =
-      getStridesAndOffset(transfer.getMemRefType(), strides, offset);
+  auto successStrides = getStridesAndOffset(memRefType, strides, offset);
   if (succeeded(successStrides) && strides.back() == 1 &&
       transfer.permutation_map().isMinorIdentity()) {
     // If > 1D, emit a bunch of loops around 1-D vector transfers.
@@ -641,8 +644,8 @@ LogicalResult VectorTransferRewriter<TransferWriteOp>::matchAndRewrite(
 
   // 1. Setup all the captures.
   ScopedContext scope(rewriter, transfer.getLoc());
-  StdIndexedValue remote(transfer.memref());
-  MemRefBoundsCapture memRefBoundsCapture(transfer.memref());
+  StdIndexedValue remote(transfer.source());
+  MemRefBoundsCapture memRefBoundsCapture(transfer.source());
   Value vectorValue(transfer.vector());
   VectorBoundsCapture vectorBoundsCapture(transfer.vector());
   int coalescedIdx = computeCoalescedIndex(transfer);

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 9e7e7efdd136..a1797fde7da6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -111,7 +111,7 @@ void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
       vector::TransferWriteOp transferWrite;
       for (auto *sliceOp : llvm::reverse(forwardSlice)) {
         auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp);
-        if (!candidateWrite || candidateWrite.memref() != transferRead.memref())
+        if (!candidateWrite || candidateWrite.source() != transferRead.source())
           continue;
         transferWrite = candidateWrite;
       }
@@ -142,7 +142,7 @@ void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
       DominanceInfo dom(loop);
       if (!dom.properlyDominates(transferRead.getOperation(), transferWrite))
         return WalkResult::advance();
-      for (auto &use : transferRead.memref().getUses()) {
+      for (auto &use : transferRead.source().getUses()) {
         if (!dom.properlyDominates(loop, use.getOwner()))
           continue;
         if (use.getOwner() == transferRead.getOperation() ||

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 2df1a9469eab..7165ee775e9c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -411,7 +411,7 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
     vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
 
   // Transfer into `view`.
-  Value viewOrAlloc = xferOp.memref();
+  Value viewOrAlloc = xferOp.source();
   if (!viewOrAlloc.getDefiningOp<ViewOp>() &&
       !viewOrAlloc.getDefiningOp<AllocOp>())
     return failure();
@@ -487,7 +487,7 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
 LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
     vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
   // Transfer into `viewOrAlloc`.
-  Value viewOrAlloc = xferOp.memref();
+  Value viewOrAlloc = xferOp.source();
   if (!viewOrAlloc.getDefiningOp<ViewOp>() &&
       !viewOrAlloc.getDefiningOp<AllocOp>())
     return failure();

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 5c1f377e589e..a3ad355d30b2 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1890,41 +1890,43 @@ static LogicalResult verifyPermutationMap(AffineMap permutationMap,
   return success();
 }
 
-static LogicalResult verifyTransferOp(Operation *op, MemRefType memrefType,
+static LogicalResult verifyTransferOp(Operation *op, ShapedType shapedType,
                                       VectorType vectorType,
                                       AffineMap permutationMap,
                                       ArrayAttr optionalMasked) {
-  auto memrefElementType = memrefType.getElementType();
-  if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
-    // Memref has vector element type.
-
-    unsigned memrefVecSize = memrefVectorElementType.getElementTypeBitWidth() *
-                             memrefVectorElementType.getShape().back();
+  if (!shapedType.isa<MemRefType, RankedTensorType>())
+    return op->emitOpError(
+        "requires source to be a memref or ranked tensor type");
+  auto elementType = shapedType.getElementType();
+  if (auto vectorElementType = elementType.dyn_cast<VectorType>()) {
+    // Memref or tensor has vector element type.
+    unsigned sourceVecSize = vectorElementType.getElementTypeBitWidth() *
+                             vectorElementType.getShape().back();
     unsigned resultVecSize =
         vectorType.getElementTypeBitWidth() * vectorType.getShape().back();
-    if (resultVecSize % memrefVecSize != 0)
+    if (resultVecSize % sourceVecSize != 0)
       return op->emitOpError(
           "requires the bitwidth of the minor 1-D vector to be an integral "
-          "multiple of the bitwidth of the minor 1-D vector of the memref");
+          "multiple of the bitwidth of the minor 1-D vector of the source");
 
-    unsigned memrefVecEltRank = memrefVectorElementType.getRank();
+    unsigned sourceVecEltRank = vectorElementType.getRank();
     unsigned resultVecRank = vectorType.getRank();
-    if (memrefVecEltRank > resultVecRank)
+    if (sourceVecEltRank > resultVecRank)
       return op->emitOpError(
-          "requires memref vector element and vector result ranks to match.");
-    unsigned rankOffset = resultVecRank - memrefVecEltRank;
+          "requires source vector element and vector result ranks to match.");
+    unsigned rankOffset = resultVecRank - sourceVecEltRank;
     // Check that permutation map results match 'rankOffset' of vector type.
     if (permutationMap.getNumResults() != rankOffset)
       return op->emitOpError("requires a permutation_map with result dims of "
                              "the same rank as the vector type");
   } else {
-    // Memref has scalar element type.
+    // Memref or tensor has scalar element type.
     unsigned resultVecSize =
         vectorType.getElementTypeBitWidth() * vectorType.getShape().back();
-    if (resultVecSize % memrefElementType.getIntOrFloatBitWidth() != 0)
+    if (resultVecSize % elementType.getIntOrFloatBitWidth() != 0)
       return op->emitOpError(
           "requires the bitwidth of the minor 1-D vector to be an integral "
-          "multiple of the bitwidth of the memref element type");
+          "multiple of the bitwidth of the source element type");
 
     // Check that permutation map results match rank of vector type.
     if (permutationMap.getNumResults() != vectorType.getRank())
@@ -1934,9 +1936,9 @@ static LogicalResult verifyTransferOp(Operation *op, MemRefType memrefType,
 
   if (permutationMap.getNumSymbols() != 0)
     return op->emitOpError("requires permutation_map without symbols");
-  if (permutationMap.getNumInputs() != memrefType.getRank())
+  if (permutationMap.getNumInputs() != shapedType.getRank())
     return op->emitOpError("requires a permutation_map with input dims of the "
-                           "same rank as the memref type");
+                           "same rank as the source type");
 
   if (optionalMasked) {
     if (permutationMap.getNumResults() !=
@@ -1978,7 +1980,7 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
 static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
   SmallVector<StringRef, 2> elidedAttrs;
   if (op.permutation_map() ==
-      getTransferMinorIdentityMap(op.getMemRefType(), op.getVectorType()))
+      getTransferMinorIdentityMap(op.getShapedType(), op.getVectorType()))
     elidedAttrs.push_back(op.getPermutationMapAttrName());
   bool elideMasked = true;
   if (auto maybeMasked = op.masked()) {
@@ -1995,21 +1997,21 @@ static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
 }
 
 static void print(OpAsmPrinter &p, TransferReadOp op) {
-  p << op.getOperationName() << " " << op.memref() << "[" << op.indices()
+  p << op.getOperationName() << " " << op.source() << "[" << op.indices()
     << "], " << op.padding();
   printTransferAttrs(p, cast<VectorTransferOpInterface>(op.getOperation()));
-  p << " : " << op.getMemRefType() << ", " << op.getVectorType();
+  p << " : " << op.getShapedType() << ", " << op.getVectorType();
 }
 
 static ParseResult parseTransferReadOp(OpAsmParser &parser,
                                        OperationState &result) {
   llvm::SMLoc typesLoc;
-  OpAsmParser::OperandType memrefInfo;
+  OpAsmParser::OperandType sourceInfo;
   SmallVector<OpAsmParser::OperandType, 8> indexInfo;
   OpAsmParser::OperandType paddingInfo;
   SmallVector<Type, 2> types;
   // Parsing with support for paddingValue.
-  if (parser.parseOperand(memrefInfo) ||
+  if (parser.parseOperand(sourceInfo) ||
       parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
       parser.parseComma() || parser.parseOperand(paddingInfo) ||
       parser.parseOptionalAttrDict(result.attributes) ||
@@ -2018,48 +2020,48 @@ static ParseResult parseTransferReadOp(OpAsmParser &parser,
   if (types.size() != 2)
     return parser.emitError(typesLoc, "requires two types");
   auto indexType = parser.getBuilder().getIndexType();
-  MemRefType memRefType = types[0].dyn_cast<MemRefType>();
-  if (!memRefType)
-    return parser.emitError(typesLoc, "requires memref type");
+  auto shapedType = types[0].dyn_cast<ShapedType>();
+  if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
+    return parser.emitError(typesLoc, "requires memref or ranked tensor type");
   VectorType vectorType = types[1].dyn_cast<VectorType>();
   if (!vectorType)
     return parser.emitError(typesLoc, "requires vector type");
   auto permutationAttrName = TransferReadOp::getPermutationMapAttrName();
   auto attr = result.attributes.get(permutationAttrName);
   if (!attr) {
-    auto permMap = getTransferMinorIdentityMap(memRefType, vectorType);
+    auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
     result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap));
   }
   return failure(
-      parser.resolveOperand(memrefInfo, memRefType, result.operands) ||
+      parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
       parser.resolveOperands(indexInfo, indexType, result.operands) ||
-      parser.resolveOperand(paddingInfo, memRefType.getElementType(),
+      parser.resolveOperand(paddingInfo, shapedType.getElementType(),
                             result.operands) ||
       parser.addTypeToList(vectorType, result.types));
 }
 
 static LogicalResult verify(TransferReadOp op) {
-  // Consistency of elemental types in memref and vector.
-  MemRefType memrefType = op.getMemRefType();
+  // Consistency of elemental types in source and vector.
+  ShapedType shapedType = op.getShapedType();
   VectorType vectorType = op.getVectorType();
   auto paddingType = op.padding().getType();
   auto permutationMap = op.permutation_map();
-  auto memrefElementType = memrefType.getElementType();
+  auto sourceElementType = shapedType.getElementType();
 
-  if (static_cast<int64_t>(op.indices().size()) != memrefType.getRank())
-    return op.emitOpError("requires ") << memrefType.getRank() << " indices";
+  if (static_cast<int64_t>(op.indices().size()) != shapedType.getRank())
+    return op.emitOpError("requires ") << shapedType.getRank() << " indices";
 
-  if (failed(verifyTransferOp(op.getOperation(), memrefType, vectorType,
+  if (failed(verifyTransferOp(op.getOperation(), shapedType, vectorType,
                               permutationMap,
                               op.masked() ? *op.masked() : ArrayAttr())))
     return failure();
 
-  if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
-    // Memref has vector element type.
-    // Check that 'memrefVectorElementType' and 'paddingType' types match.
-    if (memrefVectorElementType != paddingType)
+  if (auto sourceVectorElementType = sourceElementType.dyn_cast<VectorType>()) {
+    // Source has vector element type.
+    // Check that 'sourceVectorElementType' and 'paddingType' types match.
+    if (sourceVectorElementType != paddingType)
       return op.emitOpError(
-          "requires memref element type and padding type to match.");
+          "requires source element type and padding type to match.");
 
   } else {
     // Check that 'paddingType' is valid to store in a vector type.
@@ -2067,9 +2069,9 @@ static LogicalResult verify(TransferReadOp op) {
       return op.emitOpError("requires valid padding vector elemental type");
 
     // Check that padding type and vector element types match.
-    if (paddingType != memrefElementType)
+    if (paddingType != sourceElementType)
       return op.emitOpError(
-          "requires formal padding and memref of the same elemental type");
+          "requires formal padding and source of the same elemental type");
   }
 
   return verifyPermutationMap(permutationMap,
@@ -2096,18 +2098,18 @@ static LogicalResult foldMemRefCast(Operation *op) {
 template <typename TransferOp>
 static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
   // TODO: support more aggressive createOrFold on:
-  // `op.indices()[indicesIdx] + vectorType < dim(op.memref(), indicesIdx)`
-  if (op.getMemRefType().isDynamicDim(indicesIdx))
+  // `op.indices()[indicesIdx] + vectorType < dim(op.source(), indicesIdx)`
+  if (op.getShapedType().isDynamicDim(indicesIdx))
     return false;
   Value index = op.indices()[indicesIdx];
   auto cstOp = index.getDefiningOp<ConstantIndexOp>();
   if (!cstOp)
     return false;
 
-  int64_t memrefSize = op.getMemRefType().getDimSize(indicesIdx);
+  int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
   int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
 
-  return cstOp.getValue() + vectorSize <= memrefSize;
+  return cstOp.getValue() + vectorSize <= sourceSize;
 }
 
 template <typename TransferOp>
@@ -2159,33 +2161,51 @@ Optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
 
 /// Builder that sets permutation map to 'getMinorIdentityMap'.
 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
-                            Value vector, Value memref, ValueRange indices,
+                            Value vector, Value source, ValueRange indices,
                             ArrayRef<bool> maybeMasked) {
   auto vectorType = vector.getType().cast<VectorType>();
   auto permMap = getTransferMinorIdentityMap(
-      memref.getType().cast<MemRefType>(), vectorType);
+      source.getType().cast<MemRefType>(), vectorType);
   if (maybeMasked.empty())
-    return build(builder, result, vector, memref, indices, permMap,
+    return build(builder, result, vector, source, indices, permMap,
                  ArrayAttr());
   ArrayAttr maskedArrayAttr = builder.getBoolArrayAttr(maybeMasked);
-  build(builder, result, vector, memref, indices, permMap, maskedArrayAttr);
+  build(builder, result, vector, source, indices, permMap, maskedArrayAttr);
 }
 
 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
-                            Value vector, Value memref, ValueRange indices,
+                            Value vector, Value source, ValueRange indices,
                             AffineMap permutationMap) {
-  build(builder, result, vector, memref, indices, permutationMap,
+  build(builder, result, vector, source, indices, permutationMap,
         /*maybeMasked=*/ArrayAttr());
 }
 
+void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
+                            Value vector, Value source, ValueRange indices,
+                            AffineMapAttr permutationMap,
+                            /*optional*/ ArrayAttr masked) {
+  Type resultType = source.getType().dyn_cast<RankedTensorType>();
+  build(builder, result, resultType, vector, source, indices, permutationMap,
+        masked);
+}
+
+void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
+                            Value vector, Value source, ValueRange indices,
+                            AffineMap permutationMap,
+                            /*optional*/ ArrayAttr masked) {
+  Type resultType = source.getType().dyn_cast<RankedTensorType>();
+  build(builder, result, resultType, vector, source, indices, permutationMap,
+        masked);
+}
+
 static ParseResult parseTransferWriteOp(OpAsmParser &parser,
                                         OperationState &result) {
   llvm::SMLoc typesLoc;
-  OpAsmParser::OperandType vectorInfo, memrefInfo;
+  OpAsmParser::OperandType vectorInfo, sourceInfo;
   SmallVector<OpAsmParser::OperandType, 8> indexInfo;
   SmallVector<Type, 2> types;
   if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
-      parser.parseOperand(memrefInfo) ||
+      parser.parseOperand(sourceInfo) ||
       parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
       parser.parseOptionalAttrDict(result.attributes) ||
       parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
@@ -2196,38 +2216,40 @@ static ParseResult parseTransferWriteOp(OpAsmParser &parser,
   VectorType vectorType = types[0].dyn_cast<VectorType>();
   if (!vectorType)
     return parser.emitError(typesLoc, "requires vector type");
-  MemRefType memRefType = types[1].dyn_cast<MemRefType>();
-  if (!memRefType)
-    return parser.emitError(typesLoc, "requires memref type");
+  ShapedType shapedType = types[1].dyn_cast<ShapedType>();
+  if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
+    return parser.emitError(typesLoc, "requires memref or ranked tensor type");
   auto permutationAttrName = TransferWriteOp::getPermutationMapAttrName();
   auto attr = result.attributes.get(permutationAttrName);
   if (!attr) {
-    auto permMap = getTransferMinorIdentityMap(memRefType, vectorType);
+    auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
     result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap));
   }
   return failure(
       parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
-      parser.resolveOperand(memrefInfo, memRefType, result.operands) ||
-      parser.resolveOperands(indexInfo, indexType, result.operands));
+      parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
+      parser.resolveOperands(indexInfo, indexType, result.operands) ||
+      (shapedType.isa<RankedTensorType>() &&
+       parser.addTypeToList(shapedType, result.types)));
 }
 
 static void print(OpAsmPrinter &p, TransferWriteOp op) {
-  p << op.getOperationName() << " " << op.vector() << ", " << op.memref() << "["
+  p << op.getOperationName() << " " << op.vector() << ", " << op.source() << "["
     << op.indices() << "]";
   printTransferAttrs(p, cast<VectorTransferOpInterface>(op.getOperation()));
-  p << " : " << op.getVectorType() << ", " << op.getMemRefType();
+  p << " : " << op.getVectorType() << ", " << op.getShapedType();
 }
 
 static LogicalResult verify(TransferWriteOp op) {
   // Consistency of elemental types in memref and vector.
-  MemRefType memrefType = op.getMemRefType();
+  ShapedType shapedType = op.getShapedType();
   VectorType vectorType = op.getVectorType();
   auto permutationMap = op.permutation_map();
 
-  if (llvm::size(op.indices()) != memrefType.getRank())
-    return op.emitOpError("requires ") << memrefType.getRank() << " indices";
+  if (llvm::size(op.indices()) != shapedType.getRank())
+    return op.emitOpError("requires ") << shapedType.getRank() << " indices";
 
-  if (failed(verifyTransferOp(op.getOperation(), memrefType, vectorType,
+  if (failed(verifyTransferOp(op.getOperation(), shapedType, vectorType,
                               permutationMap,
                               op.masked() ? *op.masked() : ArrayAttr())))
     return failure();

diff  --git a/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
index b7de983dd3b1..ea1189d53b31 100644
--- a/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
@@ -94,7 +94,7 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
                     << "\n");
   llvm::SmallVector<Operation *, 8> reads;
   Operation *firstOverwriteCandidate = nullptr;
-  for (auto *user : write.memref().getUsers()) {
+  for (auto *user : write.source().getUsers()) {
     if (user == write.getOperation())
       continue;
     if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
@@ -163,7 +163,7 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
                     << "\n");
   SmallVector<Operation *, 8> blockingWrites;
   vector::TransferWriteOp lastwrite = nullptr;
-  for (Operation *user : read.memref().getUsers()) {
+  for (Operation *user : read.source().getUsers()) {
     if (isa<vector::TransferReadOp>(user))
       continue;
     if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 664426960beb..1e58a759d305 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -597,7 +597,7 @@ static Value unrollTransferReadOp(vector::TransferReadOp readOp,
 
   Location loc = readOp.getLoc();
   auto memrefElementType =
-      readOp.memref().getType().cast<MemRefType>().getElementType();
+      readOp.source().getType().cast<MemRefType>().getElementType();
   auto tupleType = generateExtractSlicesOpResultType(
       sourceVectorType, targetShape, strides, builder);
   int64_t numSlices = tupleType.size();
@@ -612,7 +612,7 @@ static Value unrollTransferReadOp(vector::TransferReadOp readOp,
     // `masked` attribute propagates conservatively: if the coarse op didn't
     // need masking, the fine op doesn't either.
     vectorTupleValues[index] = builder.create<vector::TransferReadOp>(
-        loc, sliceVectorType, readOp.memref(), sliceIndices,
+        loc, sliceVectorType, readOp.source(), sliceIndices,
         readOp.permutation_map(), readOp.padding(),
         readOp.masked() ? *readOp.masked() : ArrayAttr());
   };
@@ -644,14 +644,14 @@ mlir::vector::unrollTransferWriteOp(OpBuilder &builder, Operation *op,
   Value tuple = builder.create<vector::ExtractSlicesOp>(
       loc, tupleType, writeOp.vector(), targetShape, strides);
   auto memrefElementType =
-      writeOp.memref().getType().cast<MemRefType>().getElementType();
+      writeOp.source().getType().cast<MemRefType>().getElementType();
   SmallVector<Value, 4> indices(writeOp.indices().begin(),
                                 writeOp.indices().end());
   auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
     auto element = builder.create<vector::TupleGetOp>(
         loc, tupleType.getType(index), tuple, builder.getI64IntegerAttr(index));
     builder.create<vector::TransferWriteOp>(
-        loc, element.getResult(), writeOp.memref(), sliceIndices,
+        loc, element.getResult(), writeOp.source(), sliceIndices,
         writeOp.permutation_map(),
         writeOp.masked() ? *writeOp.masked() : ArrayAttr());
   };
@@ -760,7 +760,7 @@ struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
 
     Location loc = xferWriteOp.getLoc();
     auto memrefElementType =
-        xferWriteOp.memref().getType().cast<MemRefType>().getElementType();
+        xferWriteOp.source().getType().cast<MemRefType>().getElementType();
     SmallVector<Value, 4> indices(xferWriteOp.indices().begin(),
                                   xferWriteOp.indices().end());
     auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
@@ -768,7 +768,7 @@ struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
       // `masked` attribute propagates conservatively: if the coarse op didn't
       // need masking, the fine op doesn't either.
       rewriter.create<vector::TransferWriteOp>(
-          loc, tupleOp.getOperand(index), xferWriteOp.memref(), sliceIndices,
+          loc, tupleOp.getOperand(index), xferWriteOp.source(), sliceIndices,
           xferWriteOp.permutation_map(),
           xferWriteOp.masked() ? *xferWriteOp.masked() : ArrayAttr());
     };
@@ -2142,7 +2142,7 @@ static Value createScopedInBoundsCond(VectorTransferOpInterface xferOp) {
     // Fold or create the check that `index + vector_size` <= `memref_size`.
     Value sum = xferOp.indices()[indicesIdx] + std_constant_index(vectorSize);
     Value cond =
-        createScopedFoldedSLE(sum, std_dim(xferOp.memref(), indicesIdx));
+        createScopedFoldedSLE(sum, std_dim(xferOp.source(), indicesIdx));
     if (!cond)
       return;
     // Conjunction over all dims for which we are in-bounds.
@@ -2207,23 +2207,23 @@ static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
 }
 
 /// Operates under a scoped context to build the intersection between the
-/// view `xferOp.memref()` @ `xferOp.indices()` and the view `alloc`.
+/// view `xferOp.source()` @ `xferOp.indices()` and the view `alloc`.
 // TODO: view intersection/union/differences should be a proper std op.
 static Value createScopedSubViewIntersection(VectorTransferOpInterface xferOp,
                                              Value alloc) {
   using namespace edsc::intrinsics;
-  int64_t memrefRank = xferOp.getMemRefType().getRank();
+  int64_t memrefRank = xferOp.getShapedType().getRank();
   // TODO: relax this precondition, will require rank-reducing subviews.
   assert(memrefRank == alloc.getType().cast<MemRefType>().getRank() &&
          "Expected memref rank to match the alloc rank");
   Value one = std_constant_index(1);
   ValueRange leadingIndices =
-      xferOp.indices().take_front(xferOp.getLeadingMemRefRank());
+      xferOp.indices().take_front(xferOp.getLeadingShapedRank());
   SmallVector<Value, 4> sizes;
   sizes.append(leadingIndices.begin(), leadingIndices.end());
   xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
     using MapList = ArrayRef<ArrayRef<AffineExpr>>;
-    Value dimMemRef = std_dim(xferOp.memref(), indicesIdx);
+    Value dimMemRef = std_dim(xferOp.source(), indicesIdx);
     Value dimAlloc = std_dim(alloc, resultIdx);
     Value index = xferOp.indices()[indicesIdx];
     AffineExpr i, j, k;
@@ -2235,7 +2235,7 @@ static Value createScopedSubViewIntersection(VectorTransferOpInterface xferOp,
                                  ValueRange{dimMemRef, index, dimAlloc});
     sizes.push_back(affineMin);
   });
-  return std_sub_view(xferOp.memref(), xferOp.indices(), sizes,
+  return std_sub_view(xferOp.source(), xferOp.indices(), sizes,
                       SmallVector<Value, 4>(memrefRank, one));
 }
 
@@ -2263,12 +2263,12 @@ static scf::IfOp createScopedFullPartialLinalgCopy(
   using namespace edsc::intrinsics;
   scf::IfOp fullPartialIfOp;
   Value zero = std_constant_index(0);
-  Value memref = xferOp.memref();
+  Value memref = xferOp.source();
   conditionBuilder(
       returnTypes, inBoundsCond,
       [&]() -> scf::ValueVector {
         Value res = memref;
-        if (compatibleMemRefType != xferOp.getMemRefType())
+        if (compatibleMemRefType != xferOp.getShapedType())
           res = std_memref_cast(memref, compatibleMemRefType);
         scf::ValueVector viewAndIndices{res};
         viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
@@ -2317,12 +2317,12 @@ static scf::IfOp createScopedFullPartialVectorTransferRead(
   using namespace edsc::intrinsics;
   scf::IfOp fullPartialIfOp;
   Value zero = std_constant_index(0);
-  Value memref = xferOp.memref();
+  Value memref = xferOp.source();
   conditionBuilder(
       returnTypes, inBoundsCond,
       [&]() -> scf::ValueVector {
         Value res = memref;
-        if (compatibleMemRefType != xferOp.getMemRefType())
+        if (compatibleMemRefType != xferOp.getShapedType())
           res = std_memref_cast(memref, compatibleMemRefType);
         scf::ValueVector viewAndIndices{res};
         viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
@@ -2376,7 +2376,7 @@ static scf::IfOp createScopedFullPartialVectorTransferRead(
 ///
 /// Preconditions:
 ///  1. `xferOp.permutation_map()` must be a minor identity map
-///  2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()`
+///  2. the rank of the `xferOp.source()` and the rank of the `xferOp.vector()`
 ///  must be equal. This will be relaxed in the future but requires
 ///  rank-reducing subviews.
 LogicalResult mlir::vector::splitFullAndPartialTransfer(
@@ -2404,8 +2404,8 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
     return failure();
 
   OpBuilder::InsertionGuard guard(b);
-  if (xferOp.memref().getDefiningOp())
-    b.setInsertionPointAfter(xferOp.memref().getDefiningOp());
+  if (Operation *sourceOp = xferOp.source().getDefiningOp())
+    b.setInsertionPointAfter(sourceOp);
   else
     b.setInsertionPoint(xferOp);
   ScopedContext scope(b, xferOp.getLoc());
@@ -2426,8 +2426,9 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
                        b.getI64IntegerAttr(32));
   }
 
-  MemRefType compatibleMemRefType = getCastCompatibleMemRefType(
-      xferOp.getMemRefType(), alloc.getType().cast<MemRefType>());
+  MemRefType compatibleMemRefType =
+      getCastCompatibleMemRefType(xferOp.getShapedType().cast<MemRefType>(),
+                                  alloc.getType().cast<MemRefType>());
 
   // Read case: full fill + partial copy -> unmasked vector.xfer_read.
   SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
@@ -2543,7 +2544,7 @@ struct TransferReadExtractPattern
           extract.ids()[idCount++] *
               std_constant_index(extract.getResultType().getDimSize(pos));
     }
-    Value newRead = vector_transfer_read(extract.getType(), read.memref(),
+    Value newRead = vector_transfer_read(extract.getType(), read.source(),
                                          indices, read.permutation_map(),
                                          read.padding(), read.maskedAttr());
     Value dest = rewriter.create<ConstantOp>(
@@ -2579,7 +2580,7 @@ struct TransferWriteInsertPattern
           insert.ids()[idCount++] *
               std_constant_index(insert.getSourceVectorType().getDimSize(pos));
     }
-    vector_transfer_write(insert.vector(), write.memref(), indices,
+    vector_transfer_write(insert.vector(), write.source(), indices,
                           write.permutation_map(), write.maskedAttr());
     rewriter.eraseOp(write);
     return success();

diff  --git a/mlir/lib/Dialect/Vector/VectorUtils.cpp b/mlir/lib/Dialect/Vector/VectorUtils.cpp
index 3ab1f500f5d1..fc08d21b27a5 100644
--- a/mlir/lib/Dialect/Vector/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/VectorUtils.cpp
@@ -243,16 +243,16 @@ AffineMap mlir::makePermutationMap(
   return ::makePermutationMap(indices, enclosingLoopToVectorDim);
 }
 
-AffineMap mlir::getTransferMinorIdentityMap(MemRefType memRefType,
+AffineMap mlir::getTransferMinorIdentityMap(ShapedType shapedType,
                                             VectorType vectorType) {
   int64_t elementVectorRank = 0;
   VectorType elementVectorType =
-      memRefType.getElementType().dyn_cast<VectorType>();
+      shapedType.getElementType().dyn_cast<VectorType>();
   if (elementVectorType)
     elementVectorRank += elementVectorType.getRank();
   return AffineMap::getMinorIdentityMap(
-      memRefType.getRank(), vectorType.getRank() - elementVectorRank,
-      memRefType.getContext());
+      shapedType.getRank(), vectorType.getRank() - elementVectorRank,
+      shapedType.getContext());
 }
 
 bool matcher::operatesOnSuperVectorsOf(Operation &op,
@@ -314,12 +314,12 @@ bool matcher::operatesOnSuperVectorsOf(Operation &op,
 
 bool mlir::isDisjointTransferSet(VectorTransferOpInterface transferA,
                                  VectorTransferOpInterface transferB) {
-  if (transferA.memref() != transferB.memref())
+  if (transferA.source() != transferB.source())
     return false;
   // For simplicity only look at transfer of same type.
   if (transferA.getVectorType() != transferB.getVectorType())
     return false;
-  unsigned rankOffset = transferA.getLeadingMemRefRank();
+  unsigned rankOffset = transferA.getLeadingShapedRank();
   for (unsigned i = 0, e = transferA.indices().size(); i < e; i++) {
     auto indexA = transferA.indices()[i].getDefiningOp<ConstantOp>();
     auto indexB = transferB.indices()[i].getDefiningOp<ConstantOp>();

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 73b1f9e1e06e..62eaa4e3a14e 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -269,7 +269,7 @@ func @test_vector.transfer_read(%arg0: vector<4x3xf32>) {
   %c3 = constant 3 : index
   %f0 = constant 0.0 : f32
   %vf0 = splat %f0 : vector<4x3xf32>
-  // expected-error at +1 {{ requires memref type}}
+  // expected-error at +1 {{ requires memref or ranked tensor type}}
   %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 : vector<4x3xf32>, vector<1x1x2x3xf32>
 }
 
@@ -297,7 +297,7 @@ func @test_vector.transfer_read(%arg0: memref<?x?xf32>) {
 func @test_vector.transfer_read(%arg0: memref<?x?xf32>) {
   %c3 = constant 3 : index
   %cst = constant 3.0 : f32
-  // expected-error at +1 {{requires a permutation_map with input dims of the same rank as the memref type}}
+  // expected-error at +1 {{requires a permutation_map with input dims of the same rank as the source type}}
   %0 = vector.transfer_read %arg0[%c3, %c3], %cst {permutation_map = affine_map<(d0)->(d0)>} : memref<?x?xf32>, vector<128xf32>
 }
 
@@ -343,7 +343,7 @@ func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
   %c3 = constant 3 : index
   %f0 = constant 0.0 : f32
   %vf0 = splat %f0 : vector<4x3xf32>
-  // expected-error at +1 {{requires memref vector element and vector result ranks to match}}
+  // expected-error at +1 {{requires source vector element and vector result ranks to match}}
   %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<4x3xf32>>, vector<3xf32>
 }
 
@@ -353,7 +353,7 @@ func @test_vector.transfer_read(%arg0: memref<?x?xvector<6xf32>>) {
   %c3 = constant 3 : index
   %f0 = constant 0.0 : f32
   %vf0 = splat %f0 : vector<6xf32>
-  // expected-error at +1 {{requires the bitwidth of the minor 1-D vector to be an integral multiple of the bitwidth of the minor 1-D vector of the memref}}
+  // expected-error at +1 {{requires the bitwidth of the minor 1-D vector to be an integral multiple of the bitwidth of the minor 1-D vector of the source}}
   %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 : memref<?x?xvector<6xf32>>, vector<3xf32>
 }
 
@@ -392,7 +392,7 @@ func @test_vector.transfer_write(%arg0: vector<4x3xf32>) {
   %c3 = constant 3 : index
   %f0 = constant 0.0 : f32
   %vf0 = splat %f0 : vector<4x3xf32>
-  // expected-error at +1 {{ requires memref type}}
+  // expected-error at +1 {{ requires memref or ranked tensor type}}
   vector.transfer_write %arg0, %arg0[%c3, %c3] : vector<4x3xf32>, f32
 }
 
@@ -419,7 +419,7 @@ func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
 func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
   %c3 = constant 3 : index
   %cst = constant dense<3.0> : vector<128 x f32>
-  // expected-error at +1 {{requires a permutation_map with input dims of the same rank as the memref type}}
+  // expected-error at +1 {{requires a permutation_map with input dims of the same rank as the source type}}
   vector.transfer_write %cst, %arg0[%c3, %c3] {permutation_map = affine_map<(d0)->(d0)>} : vector<128xf32>, memref<?x?xf32>
 }
 

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index aab6cabf759d..07e9d8de3f49 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -43,6 +43,54 @@ func @vector_transfer_ops(%arg0: memref<?x?xf32>,
   return
 }
 
+
+// CHECK-LABEL: func @vector_transfer_ops_tensor(
+func @vector_transfer_ops_tensor(%arg0: tensor<?x?xf32>,
+                          %arg1 : tensor<?x?xvector<4x3xf32>>,
+                          %arg2 : tensor<?x?xvector<4x3xi32>>) ->
+  (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xvector<4x3xf32>>,
+   tensor<?x?xvector<4x3xf32>>, tensor<?x?xvector<4x3xi32>>){
+  // CHECK: %[[C3:.*]] = constant 3 : index
+  %c3 = constant 3 : index
+  %cst = constant 3.0 : f32
+  %f0 = constant 0.0 : f32
+  %c0 = constant 0 : i32
+  %vf0 = splat %f0 : vector<4x3xf32>
+  %v0 = splat %c0 : vector<4x3xi32>
+
+  //
+  // CHECK: vector.transfer_read
+  %0 = vector.transfer_read %arg0[%c3, %c3], %f0 {permutation_map = affine_map<(d0, d1)->(d0)>} : tensor<?x?xf32>, vector<128xf32>
+  // CHECK: vector.transfer_read
+  %1 = vector.transfer_read %arg0[%c3, %c3], %f0 {permutation_map = affine_map<(d0, d1)->(d1, d0)>} : tensor<?x?xf32>, vector<3x7xf32>
+  // CHECK: vector.transfer_read
+  %2 = vector.transfer_read %arg0[%c3, %c3], %cst {permutation_map = affine_map<(d0, d1)->(d0)>} : tensor<?x?xf32>,  vector<128xf32>
+  // CHECK: vector.transfer_read
+  %3 = vector.transfer_read %arg0[%c3, %c3], %cst {permutation_map = affine_map<(d0, d1)->(d1)>} : tensor<?x?xf32>,  vector<128xf32>
+  // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} : tensor<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
+  %4 = vector.transfer_read %arg1[%c3, %c3], %vf0 {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : tensor<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
+  // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} {masked = [true, false]} : tensor<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
+  %5 = vector.transfer_read %arg1[%c3, %c3], %vf0 {masked = [true, false]} : tensor<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
+  // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} : tensor<?x?xvector<4x3xi32>>, vector<5x24xi8>
+  %6 = vector.transfer_read %arg2[%c3, %c3], %v0 : tensor<?x?xvector<4x3xi32>>, vector<5x24xi8>
+
+
+  // CHECK: vector.transfer_write
+  %7 = vector.transfer_write %0, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0)>} : vector<128xf32>, tensor<?x?xf32>
+  // CHECK: vector.transfer_write
+  %8 = vector.transfer_write %1, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d1, d0)>} : vector<3x7xf32>, tensor<?x?xf32>
+  // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<1x1x4x3xf32>, tensor<?x?xvector<4x3xf32>>
+  %9 = vector.transfer_write %4, %arg1[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : vector<1x1x4x3xf32>, tensor<?x?xvector<4x3xf32>>
+  // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<1x1x4x3xf32>, tensor<?x?xvector<4x3xf32>>
+  %10 = vector.transfer_write %5, %arg1[%c3, %c3] {masked = [true, true]} : vector<1x1x4x3xf32>, tensor<?x?xvector<4x3xf32>>
+  // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<5x24xi8>, tensor<?x?xvector<4x3xi32>>
+  %11 = vector.transfer_write %6, %arg2[%c3, %c3] : vector<5x24xi8>, tensor<?x?xvector<4x3xi32>>
+
+  return %7, %8, %9, %10, %11 :
+    tensor<?x?xf32>, tensor<?x?xf32>,  tensor<?x?xvector<4x3xf32>>,
+    tensor<?x?xvector<4x3xf32>>, tensor<?x?xvector<4x3xi32>>
+}
+
 // CHECK-LABEL: @vector_broadcast
 func @vector_broadcast(%a: f32, %b: vector<16xf32>, %c: vector<1x16xf32>, %d: vector<8x1xf32>) -> vector<8x16xf32> {
   // CHECK: vector.broadcast %{{.*}} : f32 to vector<16xf32>


        


More information about the Mlir-commits mailing list