[Mlir-commits] [mlir] c537a94 - [mlir][Vector] Thread 0-d vectors through vector.transfer ops

Nicolas Vasilache llvmlistbot at llvm.org
Wed Dec 1 08:49:48 PST 2021


Author: Nicolas Vasilache
Date: 2021-12-01T16:49:43Z
New Revision: c537a943342be66d0876c6440a2df317b572c092

URL: https://github.com/llvm/llvm-project/commit/c537a943342be66d0876c6440a2df317b572c092
DIFF: https://github.com/llvm/llvm-project/commit/c537a943342be66d0876c6440a2df317b572c092.diff

LOG: [mlir][Vector] Thread 0-d vectors through vector.transfer ops

This revision adds 0-d vector support to vector.transfer ops.
In the process, numerous cleanups are applied, in particular around normalizing
and reducing the number of builders.

Reviewed By: ThomasRaoux, springerm

Differential Revision: https://reviews.llvm.org/D114803

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/include/mlir/Interfaces/VectorInterfaces.td
    mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
    mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
    mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
    mlir/lib/Dialect/Vector/VectorDropLeadUnitDim.cpp
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/lib/Dialect/Vector/VectorTransferPermutationMapRewritePatterns.cpp
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/lib/Interfaces/VectorInterfaces.cpp
    mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
    mlir/test/Dialect/Linalg/vectorization.mlir
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/ops.mlir
    mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 1bab07e77325c..8eaf785319578 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1133,26 +1133,28 @@ def Vector_TransferReadOp :
       DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
       AttrSizedOperandSegments
     ]>,
-    Arguments<(ins AnyShaped:$source, Variadic<Index>:$indices,
-               AffineMapAttr:$permutation_map, AnyType:$padding,
-               Optional<VectorOf<[I1]>>:$mask,
-               OptionalAttr<BoolArrayAttr>:$in_bounds)>,
-    Results<(outs AnyVector:$vector)> {
+    Arguments<(ins AnyShaped:$source,
+                   Variadic<Index>:$indices,
+                   AffineMapAttr:$permutation_map,
+                   AnyType:$padding,
+                   Optional<VectorOf<[I1]>>:$mask,
+                   OptionalAttr<BoolArrayAttr>:$in_bounds)>,
+    Results<(outs AnyVectorOfAnyRank:$vector)> {
 
   let summary = "Reads a supervector from memory into an SSA vector value.";
 
   let description = [{
     The `vector.transfer_read` op performs a read from a slice within a
     [MemRef](../LangRef.md#memref-type) or a Ranked
-    [Tensor](../LangRef.md#tensor-type) supplied as its first operand into a
-    [vector](../LangRef.md#vector-type) of the same base elemental type.
+    [Tensor](../LangRef.md#tensor-type) supplied as its first operand
+    into a [vector](../LangRef.md#vector-type) of the same base elemental type.
 
     A memref/tensor operand with vector element type, must have its vector
     element type match a suffix (shape and element type) of the vector (e.g.
     memref<3x2x6x4x3xf32>, vector<1x1x4x3xf32>).
 
     The slice is further defined by a full-rank index within the MemRef/Tensor,
-    supplied as the operands `2 .. 1 + rank(memref/tensor)`.
+    supplied as the operands `[1 .. 1 + rank(memref/tensor))`.
 
     The permutation_map [attribute](../LangRef.md#attributes) is an
     [affine-map](Affine.md#affine-maps) which specifies the transposition on the
@@ -1301,39 +1303,31 @@ def Vector_TransferReadOp :
   }];
 
   let builders = [
-    // Builder that sets padding to zero.
-    OpBuilder<(ins "VectorType":$vector, "Value":$source,
-      "ValueRange":$indices, "AffineMap":$permutationMap,
-      CArg<"ArrayRef<bool>", "{}">:$inBounds)>,
-    // Builder that sets permutation map to 'getMinorIdentityMap'.
-    OpBuilder<(ins "VectorType":$vector, "Value":$source,
-      "ValueRange":$indices, "Value":$padding,
-      CArg<"ArrayRef<bool>", "{}">:$inBounds)>,
-    // Builder that sets permutation map (resp. padding) to
-    // 'getMinorIdentityMap' (resp. zero).
-    OpBuilder<(ins "VectorType":$vector, "Value":$source,
-      "ValueRange":$indices, CArg<"ArrayRef<bool>", "{}">:$inBounds)>,
-    // Builder that does not set mask.
-    OpBuilder<(ins "Type":$vector, "Value":$source,
-      "ValueRange":$indices, "AffineMapAttr":$permutationMap, "Value":$padding,
-      "ArrayAttr":$inBounds)>,
-    // Builder that does not set mask.
-    OpBuilder<(ins "Type":$vector, "Value":$source,
-      "ValueRange":$indices, "AffineMap":$permutationMap, "Value":$padding,
-      "ArrayAttr":$inBounds)>
+    /// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
+    OpBuilder<(ins "VectorType":$vectorType,
+                   "Value":$source,
+                   "ValueRange":$indices,
+                   "AffineMapAttr":$permutationMapAttr,
+                   "ArrayAttr":$inBoundsAttr)>,
+    /// 2. Builder that sets padding to zero and an empty mask (variant without attrs).
+    OpBuilder<(ins "VectorType":$vectorType,
+                   "Value":$source,
+                   "ValueRange":$indices,
+                   "AffineMap":$permutationMap,
+                   CArg<"Optional<ArrayRef<bool>>", "::llvm::None">:$inBounds)>,
+    /// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
+    OpBuilder<(ins "VectorType":$vectorType,
+                   "Value":$source,
+                   "ValueRange":$indices,
+                   "Value":$padding,
+                   CArg<"Optional<ArrayRef<bool>>", "::llvm::None">:$inBounds)>,
+    /// 4. Builder that sets padding to zero and permutation map to
+    /// 'getMinorIdentityMap'.
+    OpBuilder<(ins "VectorType":$vectorType,
+                   "Value":$source,
+                   "ValueRange":$indices,
+                   CArg<"Optional<ArrayRef<bool>>", "::llvm::None">:$inBounds)>,
   ];
-
-  let extraClassDeclaration = [{
-    /// Temporary convenience builders to account for the fact that we do not
-    /// have 0-d vectors atm. These create a constant `vector<1xt>` and
-    /// insert/extract into it.
-    // Builder that sets permutation map (resp. padding) to
-    // 'getMinorIdentityMap' (resp. zero).
-    static Value createScalarOp(OpBuilder &builder, Location loc, Value source,
-                                ValueRange indices,
-                                ArrayRef<bool> inBounds = ArrayRef<bool>{});
-  }];
-
   let hasCanonicalizer = 1;
   let hasFolder = 1;
 }
@@ -1345,11 +1339,12 @@ def Vector_TransferWriteOp :
       DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
       AttrSizedOperandSegments
   ]>,
-    Arguments<(ins AnyVector:$vector, AnyShaped:$source,
-               Variadic<Index>:$indices,
-               AffineMapAttr:$permutation_map,
-               Optional<VectorOf<[I1]>>:$mask,
-               OptionalAttr<BoolArrayAttr>:$in_bounds)>,
+    Arguments<(ins AnyVectorOfAnyRank:$vector,
+                   AnyShaped:$source,
+                   Variadic<Index>:$indices,
+                   AffineMapAttr:$permutation_map,
+                   Optional<VectorOf<[I1]>>:$mask,
+                   OptionalAttr<BoolArrayAttr>:$in_bounds)>,
     Results<(outs Optional<AnyRankedTensor>:$result)> {
 
   let summary = "The vector.transfer_write op writes a supervector to memory.";
@@ -1367,7 +1362,7 @@ def Vector_TransferWriteOp :
     new tensor of the same type.
 
     The slice is further defined by a full-rank index within the MemRef/Tensor,
-    supplied as the operands `3 .. 2 + rank(memref/tensor)`.
+    supplied as the operands `[2 .. 2 + rank(memref/tensor))`.
 
     The permutation_map [attribute](../LangRef.md#attributes) is an
     [affine-map](Affine.md#affine-maps) which specifies the transposition on the
@@ -1444,32 +1439,32 @@ def Vector_TransferWriteOp :
   }];
 
   let builders = [
-    // Builder that sets an empty mask.
-    OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
-      "AffineMap":$permutationMap, CArg<"ArrayRef<bool>", "{}">:$inBounds)>,
-    // Builder that sets permutation map to 'getMinorIdentityMap'.
-    OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
-      CArg<"ArrayRef<bool>", "{}">:$inBounds)>,
-    OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
-      "AffineMapAttr":$permutationMap, "ArrayAttr":$inBounds)>,
-    OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
-      "AffineMap":$permutationMap, "Value":$mask, "ArrayAttr":$inBounds)>,
-    OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
-      "AffineMap":$permutationMap, "ArrayAttr":$inBounds)>,
+    /// 1. Builder with type inference.
+    OpBuilder<(ins "Value":$vector,
+                   "Value":$dest,
+                   "ValueRange":$indices,
+                   "AffineMapAttr":$permutationMapAttr,
+                   "Value":$mask,
+                   "ArrayAttr":$inBoundsAttr)>,
+    /// 2. Builder with type inference that sets an empty mask (variant with attrs).
+    OpBuilder<(ins "Value":$vector,
+                   "Value":$dest,
+                   "ValueRange":$indices,
+                   "AffineMapAttr":$permutationMapAttr,
+                   "ArrayAttr":$inBoundsAttr)>,
+    /// 3. Builder with type inference that sets an empty mask (variant without attrs).
+    OpBuilder<(ins "Value":$vector,
+                   "Value":$dest,
+                   "ValueRange":$indices,
+                   "AffineMap":$permutationMap,
+                   CArg<"Optional<ArrayRef<bool>>", "::llvm::None">:$inBounds)>,
+    /// 4. Builder with type inference that sets an empty mask and sets permutation
+    /// map to 'getMinorIdentityMap'.
+    OpBuilder<(ins "Value":$vector,
+                   "Value":$dest,
+                   "ValueRange":$indices,
+                   CArg<"Optional<ArrayRef<bool>>", "::llvm::None">:$inBounds)>,
   ];
-
-  let extraClassDeclaration = [{
-    /// Temporary convenience builders to account for the fact that we do not
-    /// have 0-d vectors atm. These create a constant `vector<1xt>` and
-    /// insert/extract into it.
-    // Builder that sets permutation map (resp. padding) to
-    // 'getMinorIdentityMap' (resp. zero).
-    static Operation *createScalarOp(
-      OpBuilder &builder, Location loc, Value value,
-      Value dest, ValueRange indices,
-      ArrayRef<bool> inBounds = ArrayRef<bool>{});
-  }];
-
   let hasFolder = 1;
   let hasCanonicalizer = 1;
 }

diff --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td
index c713f1806d1f1..68b88860b2ff3 100644
--- a/mlir/include/mlir/Interfaces/VectorInterfaces.td
+++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td
@@ -114,29 +114,6 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
       /*methodBody=*/"return $_op.permutation_map();"
       /*defaultImplementation=*/
     >,
-    InterfaceMethod<
-      /*desc=*/[{
-        Returns true if op involves a 0-d tensor/memref and a vector
-        of shape {1}. This is temporary until we have 0-d vectors.
-        // TODO: turn this into 0-d vectors + empty permutation_map.
-      }],
-      /*retTy=*/"bool",
-      /*methodName=*/"isZeroD",
-      /*args=*/(ins),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        if (getShapedType().getRank() > 0)
-          return false;
-        if (getVectorType().getShape() != ArrayRef<int64_t>{1})
-          return false;
-        AffineMap map = AffineMap::get(
-          /*numDims=*/0, /*numSymbols=*/0,
-          getAffineConstantExpr(0, $_op->getContext()));
-        if ($_op.permutation_map() != map)
-          return false;
-        return true;
-      }]
-    >,
     InterfaceMethod<
       /*desc=*/[{ Returns true if the specified dimension is a broadcast. }],
       /*retTy=*/"bool",
@@ -157,10 +134,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        // 0-d transfers are not considered broadcasts but they need to be 
-        // represented with a vector<1xt> until we have 0-d vectors.
-        if ($_op.isZeroD()) return false;
-        for (unsigned i = 0; i < $_op.permutation_map().getNumResults(); ++i) {
+        for (unsigned i = 0, rank = getTransferRank(); i < rank; ++i) {
           if ($_op.isBroadcastDim(i))
             return true;
         }

diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 84e9cae77dd11..4c50b4f699365 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -92,6 +92,10 @@ static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
 // Return true if the transfer op can be converted to a MMA matrix store.
 static bool
 transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
+  // TODO: support 0-d corner case.
+  if (writeOp.getTransferRank() == 0)
+    return false;
+
   if (writeOp.mask() || writeOp.hasOutOfBoundsDim() ||
       writeOp.getVectorType().getRank() != 2)
     return false;
@@ -295,6 +299,11 @@ struct CombineTransferReadOpTranspose final
     auto transferReadOp = op.vector().getDefiningOp<vector::TransferReadOp>();
     if (!transferReadOp)
       return failure();
+
+    // TODO: support 0-d corner case.
+    if (transferReadOp.getTransferRank() == 0)
+      return failure();
+
     if (transferReadOp.mask() || transferReadOp.hasOutOfBoundsDim())
       return failure();
     SmallVector<int64_t, 2> perm;
@@ -307,8 +316,8 @@ struct CombineTransferReadOpTranspose final
     AffineMap newMap = permutationMap.compose(transferReadOp.permutation_map());
     rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
         op, op.getType(), transferReadOp.source(), transferReadOp.indices(),
-        newMap, transferReadOp.padding(), transferReadOp.mask(),
-        transferReadOp.in_boundsAttr());
+        AffineMapAttr::get(newMap), transferReadOp.padding(),
+        transferReadOp.mask(), transferReadOp.in_boundsAttr());
     return success();
   }
 };
@@ -335,6 +344,7 @@ static const char *inferFragType(OpTy op) {
 
 static void convertTransferReadOp(vector::TransferReadOp op,
                                   llvm::DenseMap<Value, Value> &valueMapping) {
+  assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
   assert(transferReadSupportsMMAMatrixType(op));
   Optional<int64_t> stride =
       getMemrefConstantHorizontalStride(op.getShapedType());

diff --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
index cc54b7f8bd2ed..50b4c3a8dd2e7 100644
--- a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
+++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
@@ -64,6 +64,10 @@ class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
   LogicalResult
   matchAndRewrite(ConcreteOp xferOp, typename ConcreteOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    // TODO: support 0-d corner case.
+    if (xferOp.getTransferRank() == 0)
+      return failure();
+
     if (xferOp.getVectorType().getRank() > 1 ||
         llvm::size(xferOp.indices()) == 0)
       return failure();

diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 6d2c91f19bb68..4709bed076377 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -52,6 +52,8 @@ struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
 /// A return value of None indicates a broadcast.
 template <typename OpTy>
 static Optional<int64_t> unpackedDim(OpTy xferOp) {
+  // TODO: support 0-d corner case.
+  assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
   auto map = xferOp.permutation_map();
   if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
     return expr.getPosition();
@@ -66,6 +68,8 @@ static Optional<int64_t> unpackedDim(OpTy xferOp) {
 /// omitted.
 template <typename OpTy>
 static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) {
+  // TODO: support 0-d corner case.
+  assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
   auto map = xferOp.permutation_map();
   return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
                         b.getContext());
@@ -1081,6 +1085,7 @@ get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv,
                    SmallVector<Value, 8> &memrefIndices) {
   auto indices = xferOp.indices();
   auto map = xferOp.permutation_map();
+  assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
 
   memrefIndices.append(indices.begin(), indices.end());
   assert(map.getNumResults() == 1 &&
@@ -1206,6 +1211,9 @@ struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
 
   LogicalResult matchAndRewrite(OpTy xferOp,
                                 PatternRewriter &rewriter) const override {
+    // TODO: support 0-d corner case.
+    if (xferOp.getTransferRank() == 0)
+      return failure();
     auto map = xferOp.permutation_map();
     auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
 

diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
index b10857e13c9c8..a8736bd3ca1c2 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
@@ -101,8 +101,7 @@ struct TransferWriteOpInterface
       return failure();
     b.create<vector::TransferWriteOp>(
         writeOp.getLoc(), writeOp.vector(), resultBuffer, writeOp.indices(),
-        writeOp.permutation_map(),
-        writeOp.in_bounds() ? *writeOp.in_bounds() : ArrayAttr());
+        writeOp.permutation_mapAttr(), writeOp.in_boundsAttr());
     state.mapBuffer(op->getResult(0), resultBuffer);
 
     return success();

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 206daf9c81650..42811da8f59ec 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -115,8 +115,6 @@ struct VectorizationResult {
 /// ShapedType of `v`.
 static VectorType extractVectorTypeFromShapedValue(Value v) {
   auto st = v.getType().cast<ShapedType>();
-  if (st.getShape().empty())
-    return VectorType();
   return VectorType::get(st.getShape(), st.getElementType());
 }
 
@@ -179,21 +177,6 @@ static Value broadcastIfNeeded(OpBuilder &b, Value value,
   return b.createOrFold<vector::BroadcastOp>(loc, targetVectorType, value);
 }
 
-/// Build a vector.transfer_read from `source` at indices set to all `0`.
-/// If source has rank zero, build a `vector<1xt> transfer_read + extract`.
-/// Return the produced value.
-static Value buildVectorRead(OpBuilder &b, Value source, Type readType,
-                             AffineMap map) {
-  Location loc = source.getLoc();
-  auto shapedType = source.getType().cast<ShapedType>();
-  SmallVector<Value> indices(shapedType.getRank(),
-                             b.create<arith::ConstantIndexOp>(loc, 0));
-  if (auto vectorType = readType.dyn_cast<VectorType>())
-    return b.create<vector::TransferReadOp>(loc, vectorType, source, indices,
-                                            map);
-  return vector::TransferReadOp::createScalarOp(b, loc, source, indices);
-}
-
 /// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
 /// assumes that `reductionOp` has two operands and one of them is the reduction
 /// initial value.
@@ -226,8 +209,11 @@ static Value buildVectorWrite(OpBuilder &b, Value value,
   Operation *write;
   Location loc = value.getLoc();
   auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
-  if (VectorType vectorType =
-          extractVectorTypeFromShapedValue(outputOperand->get())) {
+  ArrayRef<int64_t> shape = linalgOp.getShape(outputOperand);
+  auto vectorType = VectorType::get(
+      shape, getElementTypeOrSelf(outputOperand->get().getType()));
+  if (vectorType.getRank() > 0) {
+    // 0-d case is still special: do not invert the reindexing map.
     AffineMap map =
         reindexIndexingMap(linalgOp.getTiedIndexingMap(outputOperand));
     SmallVector<int64_t> transposeShape =
@@ -240,8 +226,11 @@ static Value buildVectorWrite(OpBuilder &b, Value value,
     write = b.create<vector::TransferWriteOp>(loc, value, outputOperand->get(),
                                               indices, map);
   } else {
-    write = vector::TransferWriteOp::createScalarOp(
-        b, loc, value, outputOperand->get(), ValueRange{});
+    if (!value.getType().isa<VectorType>())
+      value = b.create<vector::BroadcastOp>(loc, vectorType, value);
+    assert(value.getType() == vectorType && "incorrect type");
+    write = b.create<vector::TransferWriteOp>(loc, value, outputOperand->get(),
+                                              ValueRange{});
   }
   LDBG("vectorized op: " << *write);
   if (!write->getResults().empty())
@@ -515,32 +504,42 @@ vectorizeAsLinalgGeneric(OpBuilder &b, LinalgOp linalgOp,
   SmallVector<int64_t> commonVectorShape = linalgOp.computeStaticLoopSizes();
 
   // 3. Turn all BBArgs into vector.transfer_read / load.
-  SmallVector<AffineMap> indexings;
+  Location loc = linalgOp.getLoc();
+  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
   for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
     BlockArgument bbarg = block->getArgument(opOperand->getOperandNumber());
     if (linalgOp.isScalar(opOperand)) {
       bvm.map(bbarg, opOperand->get());
       continue;
     }
-    // TODO: 0-d vectors.
-    Type readType;
+    VectorType readType;
     AffineMap map;
-    if (linalgOp.getShape(opOperand).empty()) {
-      readType = bbarg.getType();
+    // TODO: can we keep this simplification?
+    // if (linalgOp.getShape(opOperand).empty()) {
+    //   readType = VectorType::get({}, bbarg.getType());
+    // } else {
+    if (opOperand->getOperandNumber() < linalgOp.getNumInputs()) {
+      map = inverseAndBroadcastProjectedPermuation(
+          linalgOp.getTiedIndexingMap(opOperand));
+      readType = VectorType::get(commonVectorShape,
+                                 getElementTypeOrSelf(opOperand->get()));
     } else {
-      if (opOperand->getOperandNumber() < linalgOp.getNumInputs()) {
-        map = inverseAndBroadcastProjectedPermuation(
-            linalgOp.getTiedIndexingMap(opOperand));
-        readType = VectorType::get(commonVectorShape,
-                                   getElementTypeOrSelf(opOperand->get()));
-      } else {
-        map = inversePermutation(
-            reindexIndexingMap(linalgOp.getTiedIndexingMap(opOperand)));
-        readType = VectorType::get(map.compose(linalgOp.getShape(opOperand)),
-                                   getElementTypeOrSelf(opOperand->get()));
-      }
+      map = inversePermutation(
+          reindexIndexingMap(linalgOp.getTiedIndexingMap(opOperand)));
+      readType = VectorType::get(map.compose(linalgOp.getShape(opOperand)),
+                                 getElementTypeOrSelf(opOperand->get()));
     }
-    Value readValue = buildVectorRead(b, opOperand->get(), readType, map);
+    // }
+
+    auto shape = linalgOp.getShape(opOperand);
+    SmallVector<Value> indices(shape.size(), zero);
+    Value readValue = b.create<vector::TransferReadOp>(
+        loc, readType, opOperand->get(), indices, map);
+    // Not all ops support 0-d vectors, extract the scalar for now.
+    // TODO: remove this.
+    if (readValue.getType().cast<VectorType>().getRank() == 0)
+      readValue = b.create<vector::ExtractElementOp>(loc, readValue);
+
     LDBG("new vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue);
     bvm.map(bbarg, readValue);
     bvm.map(opOperand->get(), readValue);
@@ -752,7 +751,7 @@ struct GenericPadTensorOpVectorizationPattern
         rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
     auto read = rewriter.create<vector::TransferReadOp>(
         padOp.getLoc(), vecType, padOp.source(), readIndices, padValue,
-        readInBounds);
+        ArrayRef<bool>{readInBounds});
 
     // If `dest` is a FillOp and the TransferWriteOp would overwrite the entire
     // tensor, write directly to the FillOp's operand.
@@ -765,7 +764,7 @@ struct GenericPadTensorOpVectorizationPattern
     auto writeIndices =
         ofrToIndexValues(rewriter, padOp.getLoc(), padOp.getMixedLowPad());
     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-        padOp, read, dest, writeIndices, writeInBounds);
+        padOp, read, dest, writeIndices, ArrayRef<bool>{writeInBounds});
 
     return success();
   }
@@ -878,6 +877,10 @@ struct PadTensorOpVectorizationWithTransferWritePattern
 
   LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp,
                             vector::TransferWriteOp xferOp) const override {
+    // TODO: support 0-d corner case.
+    if (xferOp.getTransferRank() == 0)
+      return failure();
+
     // Low padding must be static 0.
     if (!padOp.hasZeroLowPad())
       return failure();
@@ -1072,7 +1075,8 @@ struct PadTensorOpVectorizationWithInsertSlicePattern
         ofrToIndexValues(rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
     SmallVector<bool> inBounds(vecRank, true);
     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-        insertOp, read, insertOp.dest(), writeIndices, inBounds);
+        insertOp, read, insertOp.dest(), writeIndices,
+        ArrayRef<bool>{inBounds});
 
     return success();
   }
@@ -1266,6 +1270,10 @@ static memref::SubViewOp getSubViewUseIfUnique(Value v) {
 LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
     vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
 
+  // TODO: support mask.
+  if (xferOp.mask())
+    return failure();
+
   // Transfer into `view`.
   Value viewOrAlloc = xferOp.source();
   if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
@@ -1328,7 +1336,9 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
   // conservatively.
   Value res = rewriter.create<vector::TransferReadOp>(
       xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.indices(),
-      xferOp.permutation_map(), xferOp.padding(), ArrayAttr());
+      xferOp.permutation_mapAttr(), xferOp.padding(), xferOp.mask(),
+      // in_bounds is explicitly reset
+      /*inBoundsAttr=*/ArrayAttr());
 
   if (maybeFillOp)
     rewriter.eraseOp(maybeFillOp);
@@ -1342,6 +1352,10 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
 /// when available.
 LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
     vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
+  // TODO: support mask.
+  if (xferOp.mask())
+    return failure();
+
   // Transfer into `viewOrAlloc`.
   Value viewOrAlloc = xferOp.source();
   if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
@@ -1380,7 +1394,9 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
   // conservatively.
   rewriter.create<vector::TransferWriteOp>(
       xferOp.getLoc(), xferOp.vector(), out, xferOp.indices(),
-      xferOp.permutation_map(), ArrayAttr());
+      xferOp.permutation_mapAttr(), xferOp.mask(),
+      // in_bounds is explicitly reset
+      /*inBoundsAttr=*/ArrayAttr());
 
   rewriter.eraseOp(copyOp);
   rewriter.eraseOp(xferOp);

diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
index 95bab64e63ed2..1feda57d8de03 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
@@ -103,9 +103,9 @@ static Value getMemRefOperand(vector::TransferWriteOp op) {
 /// Given the permutation map of the original
 /// `vector.transfer_read`/`vector.transfer_write` operations compute the
 /// permutation map to use after the subview is folded with it.
-static AffineMap getPermutationMap(MLIRContext *context,
-                                   memref::SubViewOp subViewOp,
-                                   AffineMap currPermutationMap) {
+static AffineMapAttr getPermutationMapAttr(MLIRContext *context,
+                                           memref::SubViewOp subViewOp,
+                                           AffineMap currPermutationMap) {
   llvm::SmallDenseSet<unsigned> unusedDims = subViewOp.getDroppedDims();
   SmallVector<AffineExpr> exprs;
   int64_t sourceRank = subViewOp.getSourceType().getRank();
@@ -115,7 +115,8 @@ static AffineMap getPermutationMap(MLIRContext *context,
     exprs.push_back(getAffineDimExpr(dim, context));
   }
   auto resultDimToSourceDimMap = AffineMap::get(sourceRank, 0, exprs, context);
-  return currPermutationMap.compose(resultDimToSourceDimMap);
+  return AffineMapAttr::get(
+      currPermutationMap.compose(resultDimToSourceDimMap));
 }
 
 //===----------------------------------------------------------------------===//
@@ -163,13 +164,18 @@ void LoadOpOfSubViewFolder<memref::LoadOp>::replaceOp(
 
 template <>
 void LoadOpOfSubViewFolder<vector::TransferReadOp>::replaceOp(
-    vector::TransferReadOp loadOp, memref::SubViewOp subViewOp,
+    vector::TransferReadOp transferReadOp, memref::SubViewOp subViewOp,
     ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
+  // TODO: support 0-d corner case.
+  if (transferReadOp.getTransferRank() == 0)
+    return;
   rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
-      loadOp, loadOp.getVectorType(), subViewOp.source(), sourceIndices,
-      getPermutationMap(rewriter.getContext(), subViewOp,
-                        loadOp.permutation_map()),
-      loadOp.padding(), loadOp.in_boundsAttr());
+      transferReadOp, transferReadOp.getVectorType(), subViewOp.source(),
+      sourceIndices,
+      getPermutationMapAttr(rewriter.getContext(), subViewOp,
+                            transferReadOp.permutation_map()),
+      transferReadOp.padding(),
+      /*mask=*/Value(), transferReadOp.in_boundsAttr());
 }
 
 template <>
@@ -184,11 +190,14 @@ template <>
 void StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp(
     vector::TransferWriteOp transferWriteOp, memref::SubViewOp subViewOp,
     ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
+  // TODO: support 0-d corner case.
+  if (transferWriteOp.getTransferRank() == 0)
+    return;
   rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
       transferWriteOp, transferWriteOp.vector(), subViewOp.source(),
       sourceIndices,
-      getPermutationMap(rewriter.getContext(), subViewOp,
-                        transferWriteOp.permutation_map()),
+      getPermutationMapAttr(rewriter.getContext(), subViewOp,
+                            transferWriteOp.permutation_map()),
       transferWriteOp.in_boundsAttr());
 }
 } // namespace

diff --git a/mlir/lib/Dialect/Vector/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/VectorDropLeadUnitDim.cpp
index f00a4f808c695..9739c7c792e63 100644
--- a/mlir/lib/Dialect/Vector/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/VectorDropLeadUnitDim.cpp
@@ -133,6 +133,10 @@ struct CastAwayTransferReadLeadingOneDim
 
   LogicalResult matchAndRewrite(vector::TransferReadOp read,
                                 PatternRewriter &rewriter) const override {
+    // TODO: support 0-d corner case.
+    if (read.getTransferRank() == 0)
+      return failure();
+
     if (read.mask())
       return failure();
 
@@ -153,14 +157,15 @@ struct CastAwayTransferReadLeadingOneDim
         AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
                        rewriter.getContext());
 
-    ArrayAttr inBounds;
+    ArrayAttr inBoundsAttr;
     if (read.in_bounds())
-      inBounds = rewriter.getArrayAttr(
+      inBoundsAttr = rewriter.getArrayAttr(
           read.in_boundsAttr().getValue().take_back(newType.getRank()));
 
     auto newRead = rewriter.create<vector::TransferReadOp>(
-        read.getLoc(), newType, read.source(), read.indices(), newMap,
-        read.padding(), inBounds);
+        read.getLoc(), newType, read.source(), read.indices(),
+        AffineMapAttr::get(newMap), read.padding(), /*mask=*/Value(),
+        inBoundsAttr);
     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead);
 
     return success();
@@ -176,6 +181,10 @@ struct CastAwayTransferWriteLeadingOneDim
 
   LogicalResult matchAndRewrite(vector::TransferWriteOp write,
                                 PatternRewriter &rewriter) const override {
+    // TODO: support 0-d corner case.
+    if (write.getTransferRank() == 0)
+      return failure();
+
     if (write.mask())
       return failure();
 
@@ -196,15 +205,16 @@ struct CastAwayTransferWriteLeadingOneDim
         AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
                        rewriter.getContext());
 
-    ArrayAttr inBounds;
+    ArrayAttr inBoundsAttr;
     if (write.in_bounds())
-      inBounds = rewriter.getArrayAttr(
+      inBoundsAttr = rewriter.getArrayAttr(
           write.in_boundsAttr().getValue().take_back(newType.getRank()));
 
     auto newVector = rewriter.create<vector::ExtractOp>(
         write.getLoc(), write.vector(), splatZero(dropDim));
     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-        write, newVector, write.source(), write.indices(), newMap, inBounds);
+        write, newVector, write.source(), write.indices(),
+        AffineMapAttr::get(newMap), inBoundsAttr);
 
     return success();
   }

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 474b08276933b..859067b2bffe8 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1613,8 +1613,8 @@ static LogicalResult verify(InsertOp op) {
        static_cast<unsigned>(destVectorType.getRank())))
     return op.emitOpError("expected position attribute rank + source rank to "
                           "match dest vector rank");
-  if (!srcVectorType && (positionAttr.size() !=
-                              static_cast<unsigned>(destVectorType.getRank())))
+  if (!srcVectorType &&
+      (positionAttr.size() != static_cast<unsigned>(destVectorType.getRank())))
     return op.emitOpError(
         "expected position attribute rank to match the dest vector rank");
   for (auto en : llvm::enumerate(positionAttr)) {
@@ -2314,6 +2314,59 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
 // TransferReadOp
 //===----------------------------------------------------------------------===//
 
+/// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
+void TransferReadOp::build(OpBuilder &builder, OperationState &result,
+                           VectorType vectorType, Value source,
+                           ValueRange indices, AffineMapAttr permutationMapAttr,
+                           /*optional*/ ArrayAttr inBoundsAttr) {
+  Type elemType = source.getType().cast<ShapedType>().getElementType();
+  Value padding = builder.create<arith::ConstantOp>(
+      result.location, elemType, builder.getZeroAttr(elemType));
+  build(builder, result, vectorType, source, indices, permutationMapAttr,
+        padding, /*mask=*/Value(), inBoundsAttr);
+}
+
+/// 2. Builder that sets padding to zero and an empty mask (variant without attrs).
+void TransferReadOp::build(OpBuilder &builder, OperationState &result,
+                           VectorType vectorType, Value source,
+                           ValueRange indices, AffineMap permutationMap,
+                           Optional<ArrayRef<bool>> inBounds) {
+  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
+  auto inBoundsAttr = (inBounds && !inBounds.getValue().empty())
+                          ? builder.getBoolArrayAttr(inBounds.getValue())
+                          : ArrayAttr();
+  build(builder, result, vectorType, source, indices, permutationMapAttr,
+        inBoundsAttr);
+}
+
+/// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
+void TransferReadOp::build(OpBuilder &builder, OperationState &result,
+                           VectorType vectorType, Value source,
+                           ValueRange indices, Value padding,
+                           Optional<ArrayRef<bool>> inBounds) {
+  AffineMap permutationMap = getTransferMinorIdentityMap(
+      source.getType().cast<ShapedType>(), vectorType);
+  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
+  auto inBoundsAttr = (inBounds && !inBounds.getValue().empty())
+                          ? builder.getBoolArrayAttr(inBounds.getValue())
+                          : ArrayAttr();
+  build(builder, result, vectorType, source, indices, permutationMapAttr,
+        padding,
+        /*mask=*/Value(), inBoundsAttr);
+}
+
+/// 4. Builder that sets padding to zero and permutation map to
+/// 'getMinorIdentityMap'.
+void TransferReadOp::build(OpBuilder &builder, OperationState &result,
+                           VectorType vectorType, Value source,
+                           ValueRange indices,
+                           Optional<ArrayRef<bool>> inBounds) {
+  Type elemType = source.getType().cast<ShapedType>().getElementType();
+  Value padding = builder.create<arith::ConstantOp>(
+      result.location, elemType, builder.getZeroAttr(elemType));
+  build(builder, result, vectorType, source, indices, padding, inBounds);
+}
+
 template <typename EmitFun>
 static LogicalResult verifyPermutationMap(AffineMap permutationMap,
                                           EmitFun emitOpError) {
@@ -2347,10 +2400,6 @@ static LogicalResult
 verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
                  VectorType vectorType, VectorType maskType,
                  AffineMap permutationMap, ArrayAttr inBounds) {
-  if (shapedType.getRank() == 0 && !op.isZeroD())
-    return op->emitOpError("0-d transfer requires vector<1xt> shape and () -> "
-                           "(0) permutation_map");
-
   if (op->hasAttr("masked")) {
     return op->emitOpError("masked attribute has been removed. "
                            "Use in_bounds instead.");
@@ -2359,6 +2408,7 @@ verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
   if (!shapedType.isa<MemRefType, RankedTensorType>())
     return op->emitOpError(
         "requires source to be a memref or ranked tensor type");
+
   auto elementType = shapedType.getElementType();
   DataLayout dataLayout = DataLayout::closest(op);
   if (auto vectorElementType = elementType.dyn_cast<VectorType>()) {
@@ -2389,9 +2439,10 @@ verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
       return op->emitOpError("does not support masks with vector element type");
   } else {
     // Memref or tensor has scalar element type.
+    unsigned minorSize =
+        vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
     unsigned resultVecSize =
-        dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
-        vectorType.getShape().back();
+        dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
     if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
       return op->emitOpError(
           "requires the bitwidth of the minor 1-D vector to be an integral "
@@ -2412,8 +2463,8 @@ verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
 
   if (permutationMap.getNumSymbols() != 0)
     return op->emitOpError("requires permutation_map without symbols");
-  // TODO: implement 0-d vector corner cases.
-  if (!op.isZeroD() && permutationMap.getNumInputs() != shapedType.getRank())
+
+  if (permutationMap.getNumInputs() != shapedType.getRank())
     return op->emitOpError("requires a permutation_map with input dims of the "
                            "same rank as the source type");
 
@@ -2421,7 +2472,8 @@ verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
     if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
       return op->emitOpError("expects the optional in_bounds attr of same rank "
                              "as permutation_map results: ")
-             << AffineMapAttr::get(permutationMap);
+             << AffineMapAttr::get(permutationMap)
+             << " vs inBounds of size: " << inBounds.size();
     for (unsigned int i = 0; i < permutationMap.getNumResults(); ++i)
       if (permutationMap.getResult(i).isa<AffineConstantExpr>() &&
           !inBounds.getValue()[i].cast<BoolAttr>().getValue())
@@ -2431,77 +2483,6 @@ verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
   return success();
 }
 
-/// Builder that sets padding to zero.
-void TransferReadOp::build(OpBuilder &builder, OperationState &result,
-                           VectorType vectorType, Value source,
-                           ValueRange indices, AffineMap permutationMap,
-                           ArrayRef<bool> inBounds) {
-  Type elemType = source.getType().cast<ShapedType>().getElementType();
-  Value padding = builder.create<arith::ConstantOp>(
-      result.location, elemType, builder.getZeroAttr(elemType));
-  if (inBounds.empty())
-    return build(builder, result, vectorType, source, indices, permutationMap,
-                 padding, ArrayAttr());
-  ArrayAttr inBoundsArrayAttr = builder.getBoolArrayAttr(inBounds);
-  build(builder, result, vectorType, source, indices, permutationMap, padding,
-        inBoundsArrayAttr);
-}
-
-/// Builder that sets permutation map to 'getMinorIdentityMap'.
-void TransferReadOp::build(OpBuilder &builder, OperationState &result,
-                           VectorType vectorType, Value source,
-                           ValueRange indices, Value padding,
-                           ArrayRef<bool> inBounds) {
-  auto permMap = getTransferMinorIdentityMap(
-      source.getType().cast<ShapedType>(), vectorType);
-  if (inBounds.empty())
-    return build(builder, result, vectorType, source, indices, permMap, padding,
-                 ArrayAttr());
-  ArrayAttr inBoundsArrayAttr = builder.getBoolArrayAttr(inBounds);
-  build(builder, result, vectorType, source, indices, permMap, padding,
-        inBoundsArrayAttr);
-}
-
-/// Builder that sets permutation map (resp. padding) to 'getMinorIdentityMap'
-/// (resp. zero).
-void TransferReadOp::build(OpBuilder &builder, OperationState &result,
-                           VectorType vectorType, Value source,
-                           ValueRange indices, ArrayRef<bool> inBounds) {
-  auto permMap = getTransferMinorIdentityMap(
-      source.getType().cast<ShapedType>(), vectorType);
-  build(builder, result, vectorType, source, indices, permMap, inBounds);
-}
-
-/// Builder that does not provide a mask.
-void TransferReadOp::build(OpBuilder &builder, OperationState &result,
-                           Type vectorType, Value source, ValueRange indices,
-                           AffineMap permutationMap, Value padding,
-                           ArrayAttr inBounds) {
-  build(builder, result, vectorType, source, indices, permutationMap, padding,
-        /*mask=*/Value(), inBounds);
-}
-
-/// Builder that does not provide a mask.
-void TransferReadOp::build(OpBuilder &builder, OperationState &result,
-                           Type vectorType, Value source, ValueRange indices,
-                           AffineMapAttr permutationMap, Value padding,
-                           ArrayAttr inBounds) {
-  build(builder, result, vectorType, source, indices, permutationMap, padding,
-        /*mask=*/Value(), inBounds);
-}
-
-Value TransferReadOp::createScalarOp(OpBuilder &builder, Location loc,
-                                     Value source, ValueRange indices,
-                                     ArrayRef<bool> inBounds) {
-  Type elemType = source.getType().cast<ShapedType>().getElementType();
-  auto vectorType = VectorType::get(ArrayRef<int64_t>{1}, elemType);
-  AffineMap map = AffineMap::get(/*numDims=*/0, /*numSymbols=*/0,
-                                 getAffineConstantExpr(0, loc.getContext()));
-  Value read = builder.create<vector::TransferReadOp>(loc, vectorType, source,
-                                                      indices, map, inBounds);
-  return builder.create<vector::ExtractOp>(loc, read, ArrayRef<int64_t>{0});
-}
-
 static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
   SmallVector<StringRef, 3> elidedAttrs;
   elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
@@ -2563,6 +2544,7 @@ static ParseResult parseTransferReadOp(OpAsmParser &parser,
   Attribute mapAttr = result.attributes.get(permutationAttrName);
   if (!mapAttr) {
     auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
+    // Update `mapAttr` that is used later to determine mask type.
     mapAttr = AffineMapAttr::get(permMap);
     result.attributes.set(permutationAttrName, mapAttr);
   }
@@ -2677,8 +2659,9 @@ static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
 
 template <typename TransferOp>
 static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
-  // TODO: Be less conservative once we have 0-d vectors.
-  if (op.isZeroD())
+  // TODO: support 0-d corner case.
+  // TODO: Be less conservative.
+  if (op.getTransferRank() == 0)
     return failure();
   AffineMap permutationMap = op.permutation_map();
   bool changed = false;
@@ -2783,6 +2766,9 @@ struct FoldExtractSliceIntoTransferRead
 
   LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                 PatternRewriter &rewriter) const override {
+    // TODO: support 0-d corner case.
+    if (xferOp.getTransferRank() == 0)
+      return failure();
     if (xferOp.hasOutOfBoundsDim())
       return failure();
     if (!xferOp.permutation_map().isIdentity())
@@ -2814,9 +2800,9 @@ struct FoldExtractSliceIntoTransferRead
                                           offset)));
     }
     SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
-    rewriter.replaceOpWithNewOp<TransferReadOp>(xferOp, xferOp.getVectorType(),
-                                                extractOp.source(), newIndices,
-                                                xferOp.padding(), inBounds);
+    rewriter.replaceOpWithNewOp<TransferReadOp>(
+        xferOp, xferOp.getVectorType(), extractOp.source(), newIndices,
+        xferOp.padding(), ArrayRef<bool>{inBounds});
 
     return success();
   }
@@ -2832,69 +2818,49 @@ void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
 // TransferWriteOp
 //===----------------------------------------------------------------------===//
 
+/// 1. Builder with type inference.
 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                             Value vector, Value dest, ValueRange indices,
-                            AffineMap permutationMap, ArrayRef<bool> inBounds) {
-  if (inBounds.empty())
-    return build(builder, result, vector, dest, indices, permutationMap,
-                 /*mask=*/Value(), ArrayAttr());
-  build(builder, result, vector, dest, indices, permutationMap,
-        /*mask=*/Value(), builder.getBoolArrayAttr(inBounds));
-}
-
-/// Builder that sets permutation map to 'getMinorIdentityMap'.
-void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
-                            Value vector, Value source, ValueRange indices,
-                            ArrayRef<bool> inBounds) {
-  auto vectorType = vector.getType().cast<VectorType>();
-  auto permMap = getTransferMinorIdentityMap(
-      source.getType().cast<ShapedType>(), vectorType);
-  if (inBounds.empty())
-    return build(builder, result, vector, source, indices, permMap,
-                 ArrayAttr());
-  ArrayAttr inBoundsArrayAttr = builder.getBoolArrayAttr(inBounds);
-  build(builder, result, vector, source, indices, permMap, inBoundsArrayAttr);
+                            AffineMapAttr permutationMapAttr,
+                            /*optional*/ Value mask,
+                            /*optional*/ ArrayAttr inBoundsAttr) {
+  Type resultType = dest.getType().dyn_cast<RankedTensorType>();
+  build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
+        mask, inBoundsAttr);
 }
 
+/// 2. Builder with type inference that sets an empty mask (variant with attrs).
 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
-                            Value vector, Value source, ValueRange indices,
-                            AffineMapAttr permutationMap,
-                            /*optional*/ ArrayAttr inBounds) {
-  Type resultType = source.getType().dyn_cast<RankedTensorType>();
-  build(builder, result, resultType, vector, source, indices, permutationMap,
-        /*mask=*/Value(), inBounds);
+                            Value vector, Value dest, ValueRange indices,
+                            AffineMapAttr permutationMapAttr,
+                            /*optional*/ ArrayAttr inBoundsAttr) {
+  build(builder, result, vector, dest, indices, permutationMapAttr,
+        /*mask=*/Value(), inBoundsAttr);
 }
 
+/// 3. Builder with type inference that sets an empty mask (variant without
+/// attrs).
 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
-                            Value vector, Value source, ValueRange indices,
+                            Value vector, Value dest, ValueRange indices,
                             AffineMap permutationMap,
-                            /*optional*/ ArrayAttr inBounds) {
-  Type resultType = source.getType().dyn_cast<RankedTensorType>();
-  build(builder, result, resultType, vector, source, indices, permutationMap,
-        /*mask=*/Value(), inBounds);
+                            Optional<ArrayRef<bool>> inBounds) {
+  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
+  auto inBoundsAttr = (inBounds && !inBounds.getValue().empty())
+                          ? builder.getBoolArrayAttr(inBounds.getValue())
+                          : ArrayAttr();
+  build(builder, result, vector, dest, indices, permutationMapAttr,
+        /*mask=*/Value(), inBoundsAttr);
 }
 
+/// 4. Builder with type inference that sets an empty mask and sets permutation
+///    map to 'getMinorIdentityMap'.
 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
-                            Value vector, Value source, ValueRange indices,
-                            AffineMap permutationMap, /*optional*/ Value mask,
-                            /*optional*/ ArrayAttr inBounds) {
-  Type resultType = source.getType().dyn_cast<RankedTensorType>();
-  build(builder, result, resultType, vector, source, indices, permutationMap,
-        mask, inBounds);
-}
-
-Operation *TransferWriteOp::createScalarOp(OpBuilder &builder, Location loc,
-                                           Value value, Value dest,
-                                           ValueRange indices,
-                                           ArrayRef<bool> inBounds) {
-  Value vectorOfAScalar = value;
-  if (!value.getType().isa<VectorType>())
-    vectorOfAScalar = builder.create<vector::BroadcastOp>(
-        loc, VectorType::get({1}, value.getType()), value);
-  AffineMap map = AffineMap::get(/*numDims=*/0, /*numSymbols=*/0,
-                                 getAffineConstantExpr(0, loc.getContext()));
-  return builder.create<vector::TransferWriteOp>(loc, vectorOfAScalar, dest,
-                                                 indices, map, inBounds);
+                            Value vector, Value dest, ValueRange indices,
+                            Optional<ArrayRef<bool>> inBounds) {
+  auto vectorType = vector.getType().cast<VectorType>();
+  AffineMap permutationMap = getTransferMinorIdentityMap(
+      dest.getType().cast<ShapedType>(), vectorType);
+  build(builder, result, vector, dest, indices, permutationMap, inBounds);
 }
 
 static ParseResult parseTransferWriteOp(OpAsmParser &parser,
@@ -3003,6 +2969,9 @@ static LogicalResult verify(TransferWriteOp op) {
 static LogicalResult foldReadInitWrite(TransferWriteOp write,
                                        ArrayRef<Attribute>,
                                        SmallVectorImpl<OpFoldResult> &results) {
+  // TODO: support 0-d corner case.
+  if (write.getTransferRank() == 0)
+    return failure();
   auto rankedTensorType = write.source().getType().dyn_cast<RankedTensorType>();
   // If not operating on tensors, bail.
   if (!rankedTensorType)
@@ -3011,6 +2980,9 @@ static LogicalResult foldReadInitWrite(TransferWriteOp write,
   auto read = write.vector().getDefiningOp<vector::TransferReadOp>();
   if (!read)
     return failure();
+  // TODO: support 0-d corner case.
+  if (read.getTransferRank() == 0)
+    return failure();
   // For now, only accept minor identity. Future: composition is minor identity.
   if (!read.permutation_map().isMinorIdentity() ||
       !write.permutation_map().isMinorIdentity())
@@ -3179,9 +3151,14 @@ struct FoldInsertSliceIntoTransferWrite
                                 PatternRewriter &rewriter) const override {
     if (!insertOp.hasUnitStride())
       return failure();
+
     auto xferOp = insertOp.source().getDefiningOp<TransferWriteOp>();
     if (!xferOp)
       return failure();
+    // TODO: support 0-d corner case.
+    if (xferOp.getTransferRank() == 0)
+      return failure();
+
     if (xferOp.hasOutOfBoundsDim())
       return failure();
     if (xferOp.getVectorType().getRank() != xferOp.getShapedType().getRank())
@@ -3200,8 +3177,9 @@ struct FoldInsertSliceIntoTransferWrite
     SmallVector<Value> indices = getValueOrCreateConstantIndexOp(
         rewriter, insertOp.getLoc(), insertOp.getMixedOffsets());
     SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
-    rewriter.replaceOpWithNewOp<TransferWriteOp>(
-        insertOp, xferOp.vector(), insertOp.dest(), indices, inBounds);
+    rewriter.replaceOpWithNewOp<TransferWriteOp>(insertOp, xferOp.vector(),
+                                                 insertOp.dest(), indices,
+                                                 ArrayRef<bool>{inBounds});
     return success();
   }
 };

diff  --git a/mlir/lib/Dialect/Vector/VectorTransferPermutationMapRewritePatterns.cpp b/mlir/lib/Dialect/Vector/VectorTransferPermutationMapRewritePatterns.cpp
index e8817e71ce2ac..36725e03ae09e 100644
--- a/mlir/lib/Dialect/Vector/VectorTransferPermutationMapRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransferPermutationMapRewritePatterns.cpp
@@ -31,6 +31,7 @@ transposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
         attr.getValue()[pos].cast<BoolAttr>().getValue());
   return builder.getBoolArrayAttr(newInBoundsValues);
 }
+
 /// Lower transfer_read op with permutation into a transfer_read with a
 /// permutation map composed of leading zeros followed by a minor identiy +
 /// vector.transpose op.
@@ -56,6 +57,10 @@ struct TransferReadPermutationLowering
 
   LogicalResult matchAndRewrite(vector::TransferReadOp op,
                                 PatternRewriter &rewriter) const override {
+    // TODO: support 0-d corner case.
+    if (op.getTransferRank() == 0)
+      return failure();
+
     SmallVector<unsigned> permutation;
     AffineMap map = op.permutation_map();
     if (map.getNumResults() == 0)
@@ -99,7 +104,7 @@ struct TransferReadPermutationLowering
     }
 
     // Transpose in_bounds attribute.
-    ArrayAttr newInBounds =
+    ArrayAttr newInBoundsAttr =
         op.in_bounds() ? transposeInBoundsAttr(
                              rewriter, op.in_bounds().getValue(), permutation)
                        : ArrayAttr();
@@ -108,8 +113,8 @@ struct TransferReadPermutationLowering
     VectorType newReadType =
         VectorType::get(newVectorShape, op.getVectorType().getElementType());
     Value newRead = rewriter.create<vector::TransferReadOp>(
-        op.getLoc(), newReadType, op.source(), op.indices(), newMap,
-        op.padding(), newMask, newInBounds);
+        op.getLoc(), newReadType, op.source(), op.indices(),
+        AffineMapAttr::get(newMap), op.padding(), newMask, newInBoundsAttr);
 
     // Transpose result of transfer_read.
     SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
@@ -141,7 +146,8 @@ struct TransferWritePermutationLowering
 
   LogicalResult matchAndRewrite(vector::TransferWriteOp op,
                                 PatternRewriter &rewriter) const override {
-    if (op.isZeroD())
+    // TODO: support 0-d corner case.
+    if (op.getTransferRank() == 0)
       return failure();
 
     SmallVector<unsigned> permutation;
@@ -168,7 +174,7 @@ struct TransferWritePermutationLowering
                               : Value();
 
     // Transpose in_bounds attribute.
-    ArrayAttr newInBounds =
+    ArrayAttr newInBoundsAttr =
         op.in_bounds() ? transposeInBoundsAttr(
                              rewriter, op.in_bounds().getValue(), permutation)
                        : ArrayAttr();
@@ -179,8 +185,8 @@ struct TransferWritePermutationLowering
     auto newMap = AffineMap::getMinorIdentityMap(
         map.getNumDims(), map.getNumResults(), rewriter.getContext());
     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-        op, Type(), newVec, op.source(), op.indices(), newMap, newMask,
-        newInBounds);
+        op, Type(), newVec, op.source(), op.indices(),
+        AffineMapAttr::get(newMap), newMask, newInBoundsAttr);
 
     return success();
   }
@@ -199,6 +205,10 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
 
   LogicalResult matchAndRewrite(vector::TransferReadOp op,
                                 PatternRewriter &rewriter) const override {
+    // TODO: support 0-d corner case.
+    if (op.getTransferRank() == 0)
+      return failure();
+
     AffineMap map = op.permutation_map();
     unsigned numLeadingBroadcast = 0;
     for (auto expr : map.getResults()) {
@@ -245,14 +255,14 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
       return failure();
     VectorType newReadType =
         VectorType::get(newShape, originalVecType.getElementType());
-    ArrayAttr newInBounds =
+    ArrayAttr newInBoundsAttr =
         op.in_bounds()
             ? rewriter.getArrayAttr(
                   op.in_boundsAttr().getValue().take_back(reducedShapeRank))
             : ArrayAttr();
     Value newRead = rewriter.create<vector::TransferReadOp>(
-        op.getLoc(), newReadType, op.source(), op.indices(), newMap,
-        op.padding(), op.mask(), newInBounds);
+        op.getLoc(), newReadType, op.source(), op.indices(),
+        AffineMapAttr::get(newMap), op.padding(), op.mask(), newInBoundsAttr);
     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
                                                      newRead);
     return success();

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 6bdbeb1a550b5..876f8aeb219cb 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -229,7 +229,9 @@ struct UnrollTransferReadPattern
         options(options) {}
   LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                 PatternRewriter &rewriter) const override {
-
+    // TODO: support 0-d corner case.
+    if (readOp.getTransferRank() == 0)
+      return failure();
     if (readOp.mask())
       return failure();
     auto targetShape = getTargetShape(options, readOp);
@@ -254,9 +256,9 @@ struct UnrollTransferReadPattern
           sliceTransferIndices(i, originalSize, *targetShape, originalIndices,
                                readOp.permutation_map(), loc, rewriter);
       auto slicedRead = rewriter.create<vector::TransferReadOp>(
-          loc, targetType, readOp.source(), indices, readOp.permutation_map(),
-          readOp.padding(),
-          readOp.in_bounds() ? *readOp.in_bounds() : ArrayAttr());
+          loc, targetType, readOp.source(), indices,
+          readOp.permutation_mapAttr(), readOp.padding(), readOp.mask(),
+          readOp.in_boundsAttr());
 
       SmallVector<int64_t, 4> elementOffsets =
           getVectorOffset(originalSize, *targetShape, i);
@@ -279,6 +281,10 @@ struct UnrollTransferWritePattern
         options(options) {}
   LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                 PatternRewriter &rewriter) const override {
+    // TODO: support 0-d corner case.
+    if (writeOp.getTransferRank() == 0)
+      return failure();
+
     if (writeOp.mask())
       return failure();
     auto targetShape = getTargetShape(options, writeOp);
@@ -305,8 +311,7 @@ struct UnrollTransferWritePattern
                                writeOp.permutation_map(), loc, rewriter);
       Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
           loc, slicedVector, resultTensor ? resultTensor : writeOp.source(),
-          indices, writeOp.permutation_map(),
-          writeOp.in_bounds() ? *writeOp.in_bounds() : ArrayAttr());
+          indices, writeOp.permutation_mapAttr(), writeOp.in_boundsAttr());
       // For the tensor case update the destination for the next transfer write.
       if (!slicedWrite->getResults().empty())
         resultTensor = slicedWrite->getResult(0);
@@ -2057,6 +2062,10 @@ static Value createInBoundsCond(OpBuilder &b,
 ///  rank-reducing subviews.
 static LogicalResult
 splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) {
+  // TODO: support 0-d corner case.
+  if (xferOp.getTransferRank() == 0)
+    return failure();
+
   // TODO: expand support to these 2 cases.
   if (!xferOp.permutation_map().isMinorIdentity())
     return failure();
@@ -2682,6 +2691,10 @@ struct TransferReadExtractPattern
       : OpRewritePattern<vector::TransferReadOp>(context) {}
   LogicalResult matchAndRewrite(vector::TransferReadOp read,
                                 PatternRewriter &rewriter) const override {
+    // TODO: support 0-d corner case.
+    if (read.getTransferRank() == 0)
+      return failure();
+
     if (!read.getResult().hasOneUse())
       return failure();
     auto extract =
@@ -2711,8 +2724,8 @@ struct TransferReadExtractPattern
           {indices[indexPos], extract.ids()[idCount++]});
     }
     Value newRead = lb.create<vector::TransferReadOp>(
-        extract.getType(), read.source(), indices, read.permutation_map(),
-        read.padding(), read.in_boundsAttr());
+        extract.getType(), read.source(), indices, read.permutation_mapAttr(),
+        read.padding(), read.mask(), read.in_boundsAttr());
     Value dest = lb.create<arith::ConstantOp>(
         read.getType(), rewriter.getZeroAttr(read.getType()));
     newRead = lb.create<vector::InsertMapOp>(newRead, dest, extract.ids());
@@ -2727,6 +2740,10 @@ struct TransferWriteInsertPattern
       : OpRewritePattern<vector::TransferWriteOp>(context) {}
   LogicalResult matchAndRewrite(vector::TransferWriteOp write,
                                 PatternRewriter &rewriter) const override {
+    // TODO: support 0-d corner case.
+    if (write.getTransferRank() == 0)
+      return failure();
+
     auto insert = write.vector().getDefiningOp<vector::InsertMapOp>();
     if (!insert)
       return failure();
@@ -2754,8 +2771,8 @@ struct TransferWriteInsertPattern
                                   {indices[indexPos], insert.ids()[idCount++]});
     }
     rewriter.create<vector::TransferWriteOp>(
-        loc, insert.vector(), write.source(), indices, write.permutation_map(),
-        write.in_boundsAttr());
+        loc, insert.vector(), write.source(), indices,
+        write.permutation_mapAttr(), write.in_boundsAttr());
     rewriter.eraseOp(write);
     return success();
   }
@@ -2780,15 +2797,19 @@ struct TransferReadToVectorLoadLowering
                                 PatternRewriter &rewriter) const override {
     if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank)
       return failure();
+
     SmallVector<unsigned, 4> broadcastedDims;
     // Permutations are handled by VectorToSCF or
     // populateVectorTransferPermutationMapLoweringPatterns.
+    // We let the 0-d corner case pass-through as it is supported.
     if (!read.permutation_map().isMinorIdentityWithBroadcasting(
             &broadcastedDims))
       return failure();
+
     auto memRefType = read.getShapedType().dyn_cast<MemRefType>();
     if (!memRefType)
       return failure();
+
     // Non-unit strides are handled by VectorToSCF.
     if (!vector::isLastMemrefDimUnitStride(memRefType))
       return failure();
@@ -2808,6 +2829,7 @@ struct TransferReadToVectorLoadLowering
     auto memrefElTy = memRefType.getElementType();
     if (memrefElTy.isa<VectorType>() && memrefElTy != unbroadcastedVectorType)
       return failure();
+
     // Otherwise, element types of the memref and the vector must match.
     if (!memrefElTy.isa<VectorType>() &&
         memrefElTy != read.getVectorType().getElementType())
@@ -2845,7 +2867,14 @@ struct TransferReadToVectorLoadLowering
   llvm::Optional<unsigned> maxTransferRank;
 };
 
-/// Replace a scalar vector.load with a memref.load.
+/// Replace a 0-d vector.load with a memref.load + vector.broadcast.
+// TODO: we shouldn't cross the vector/scalar domains just for this
+// but atm we lack the infra to avoid it. Possible solutions include:
+// - go directly to LLVM + bitcast
+// - introduce a bitcast op and likely a new pointer dialect
+// - let memref.load/store additionally support the 0-d vector case
+// There are still deeper data layout issues lingering even in this
+// trivial case (for architectures for which this matters).
 struct VectorLoadToMemrefLoadLowering
     : public OpRewritePattern<vector::LoadOp> {
   using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
@@ -2857,13 +2886,13 @@ struct VectorLoadToMemrefLoadLowering
       return failure();
     auto memrefLoad = rewriter.create<memref::LoadOp>(
         loadOp.getLoc(), loadOp.base(), loadOp.indices());
-    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
-        loadOp, VectorType::get({1}, vecType.getElementType()), memrefLoad);
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
+                                                     memrefLoad);
     return success();
   }
 };
 
-/// Replace a scalar vector.store with a memref.store.
+/// Replace a 0-d vector.store with a vector.extractelement + memref.store.
 struct VectorStoreToMemrefStoreLowering
     : public OpRewritePattern<vector::StoreOp> {
   using OpRewritePattern<vector::StoreOp>::OpRewritePattern;
@@ -2873,9 +2902,17 @@ struct VectorStoreToMemrefStoreLowering
     auto vecType = storeOp.getVectorType();
     if (vecType.getNumElements() != 1)
       return failure();
-    SmallVector<int64_t> indices(vecType.getRank(), 0);
-    Value extracted = rewriter.create<vector::ExtractOp>(
-        storeOp.getLoc(), storeOp.valueToStore(), indices);
+    Value extracted;
+    if (vecType.getRank() == 0) {
+      // TODO: Unify once ExtractOp supports 0-d vectors.
+      extracted = rewriter.create<vector::ExtractElementOp>(
+          storeOp.getLoc(), storeOp.valueToStore());
+    } else {
+      SmallVector<int64_t> indices(vecType.getRank(), 0);
+      extracted = rewriter.create<vector::ExtractOp>(
+          storeOp.getLoc(), storeOp.valueToStore(), indices);
+    }
+
     rewriter.replaceOpWithNewOp<memref::StoreOp>(
         storeOp, extracted, storeOp.base(), storeOp.indices());
     return success();
@@ -2901,25 +2938,32 @@ struct TransferWriteToVectorStoreLowering
                                 PatternRewriter &rewriter) const override {
     if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank)
       return failure();
+
     // Permutations are handled by VectorToSCF or
     // populateVectorTransferPermutationMapLoweringPatterns.
-    if (!write.isZeroD() && !write.permutation_map().isMinorIdentity())
+    if ( // pass-through for the 0-d corner case.
+        !write.permutation_map().isMinorIdentity())
       return failure();
+
     auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
     if (!memRefType)
       return failure();
+
     // Non-unit strides are handled by VectorToSCF.
     if (!vector::isLastMemrefDimUnitStride(memRefType))
       return failure();
+
     // `vector.store` supports vector types as memref's elements only when the
     // type of the vector value being written is the same as the element type.
     auto memrefElTy = memRefType.getElementType();
     if (memrefElTy.isa<VectorType>() && memrefElTy != write.getVectorType())
       return failure();
+
     // Otherwise, element types of the memref and the vector must match.
     if (!memrefElTy.isa<VectorType>() &&
         memrefElTy != write.getVectorType().getElementType())
       return failure();
+
     // Out-of-bounds dims are handled by MaterializeTransferMask.
     if (write.hasOutOfBoundsDim())
       return failure();
@@ -3319,6 +3363,14 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
 
   LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                 PatternRewriter &rewriter) const override {
+    // TODO: support 0-d corner case.
+    if (readOp.getTransferRank() == 0)
+      return failure();
+
+    // TODO: support mask.
+    if (readOp.mask())
+      return failure();
+
     auto srcType = readOp.source().getType().dyn_cast<MemRefType>();
     if (!srcType || !srcType.hasStaticShape())
       return failure();
@@ -3375,7 +3427,7 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
     SmallVector<int64_t> offsets(srcType.getRank(), 0);
     SmallVector<int64_t> strides(srcType.getRank(), 1);
 
-    ArrayAttr inBounds =
+    ArrayAttr inBoundsAttr =
         readOp.in_bounds()
             ? rewriter.getArrayAttr(
                   readOp.in_boundsAttr().getValue().drop_back(dimsToDrop))
@@ -3387,8 +3439,10 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
         rankedReducedView.getType().cast<ShapedType>(), resultTargetVecType);
     Value result = rewriter.create<vector::TransferReadOp>(
         loc, resultTargetVecType, rankedReducedView,
-        readOp.indices().drop_back(dimsToDrop), permMap, readOp.padding(),
-        inBounds);
+        readOp.indices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
+        readOp.padding(),
+        // TODO: support mask.
+        /*mask=*/Value(), inBoundsAttr);
     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
                                                      result);
     return success();

diff  --git a/mlir/lib/Interfaces/VectorInterfaces.cpp b/mlir/lib/Interfaces/VectorInterfaces.cpp
index 625ffa9852396..3115a5d983c45 100644
--- a/mlir/lib/Interfaces/VectorInterfaces.cpp
+++ b/mlir/lib/Interfaces/VectorInterfaces.cpp
@@ -20,7 +20,7 @@ VectorType mlir::vector::detail::transferMaskType(VectorType vecType,
       shape.push_back(vecType.getDimSize(i));
     }
   }
-  return shape.empty() ? VectorType() : VectorType::get(shape, i1Type);
+  return VectorType::get(shape, i1Type);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
index 08b3ffbdb688e..7cddb46f094e1 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -2,25 +2,20 @@
 // RUN: mlir-opt %s -convert-vector-to-scf=full-unroll=true -split-input-file -allow-unregistered-dialect | FileCheck %s --check-prefix=FULL-UNROLL
 
 // CHECK-LABEL: func @vector_transfer_ops_0d(
-//  CHECK-SAME:   %[[MEM:.*]]: memref<f32>) {
 func @vector_transfer_ops_0d(%M: memref<f32>) {
-    %f0 = arith.constant 0.0 : f32
-
-//  CHECK: %[[V0:.*]] = arith.constant dense<0{{.*}}> : vector<1xf32>
-//  CHECK: %[[R0:.*]] = scf.for %[[I:.*]] = {{.*}} iter_args(%[[V0_ITER:.*]] = %[[V0]]) -> (vector<1xf32>) {
-//  CHECK:   %[[S:.*]] = memref.load %[[MEM]][] : memref<f32>
-//  CHECK:   %[[R_ITER:.*]] = vector.insertelement %[[S]], %[[V0_ITER]][%[[I]] : index] : vector<1xf32>
-//  CHECK:   scf.yield %[[R_ITER]] : vector<1xf32>
-    %0 = vector.transfer_read %M[], %f0 {permutation_map = affine_map<()->(0)>} :
-      memref<f32>, vector<1xf32>
-
-//  CHECK: scf.for %[[J:.*]] = %{{.*}}
-//  CHECK:   %[[SS:.*]] = vector.extractelement %[[R0]][%[[J]] : index] : vector<1xf32>
-//  CHECK:   memref.store %[[SS]], %[[MEM]][] : memref<f32>
-    vector.transfer_write %0, %M[] {permutation_map = affine_map<()->(0)>} :
-      vector<1xf32>, memref<f32>
-
-    return
+  %f0 = arith.constant 0.0 : f32
+
+  // 0-d transfers are left untouched by vector-to-scf.
+  // They are independently lowered to the proper memref.load/store.
+  //  CHECK: vector.transfer_read {{.*}}: memref<f32>, vector<f32>
+  %0 = vector.transfer_read %M[], %f0 {permutation_map = affine_map<()->()>} :
+    memref<f32>, vector<f32>
+
+  //  CHECK: vector.transfer_write {{.*}}: vector<f32>, memref<f32>
+  vector.transfer_write %0, %M[] {permutation_map = affine_map<()->()>} :
+    vector<f32>, memref<f32>
+
+  return
 }
 
 // -----

diff  --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index c055ef47a36d0..b7ef524475487 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -200,8 +200,8 @@ func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) {
 // CHECK-LABEL: func @test_vectorize_fill
 func @test_vectorize_fill_scalar(%A : memref<f32>, %arg0 : f32) {
   // CHECK-SAME: (%[[M:.*]]: memref<f32>, %[[val:.*]]: f32)
-  //      CHECK:   %[[VEC:.*]] = vector.broadcast %[[val]] : f32 to vector<1xf32>
-  //      CHECK:   vector.transfer_write %[[VEC]], %[[M]][] {{.*}} : vector<1xf32>, memref<f32>
+  //      CHECK:   %[[VEC:.*]] = vector.broadcast %[[val]] : f32 to vector<f32>
+  //      CHECK:   vector.transfer_write %[[VEC]], %[[M]][] : vector<f32>, memref<f32>
   linalg.fill(%arg0, %A) : f32, memref<f32>
   return
 }
@@ -221,10 +221,10 @@ func @test_vectorize_copy(%A : memref<8x16xf32>, %B : memref<8x16xf32>) {
 // CHECK-LABEL: func @test_vectorize_copy_scalar
 func @test_vectorize_copy_scalar(%A : memref<f32>, %B : memref<f32>) {
   //  CHECK-SAME: (%[[A:.*]]: memref<f32>, %[[B:.*]]: memref<f32>)
-  //       CHECK:   %[[V:.*]] = vector.transfer_read %[[A]][]{{.*}} : memref<f32>, vector<1xf32>
-  //       CHECK:   %[[val:.*]] = vector.extract %[[V]][0] : vector<1xf32>
-  //       CHECK:   %[[VV:.*]] = vector.broadcast %[[val]] : f32 to vector<1xf32>
-  //       CHECK:   vector.transfer_write %[[VV]], %[[B]][] {{.*}} : vector<1xf32>, memref<f32>
+  //       CHECK:   %[[V:.*]] = vector.transfer_read %[[A]][]{{.*}} : memref<f32>, vector<f32>
+  //       CHECK:   %[[val:.*]] = vector.extractelement %[[V]][] : vector<f32>
+  //       CHECK:   %[[VV:.*]] = vector.broadcast %[[val]] : f32 to vector<f32>
+  //       CHECK:   vector.transfer_write %[[VV]], %[[B]][] : vector<f32>, memref<f32>
   linalg.copy(%A, %B) :  memref<f32>, memref<f32>
   return
 }
@@ -1005,7 +1005,7 @@ func @fused_broadcast_red_2d(%arg0: tensor<4x4xf32>, %arg1: tensor<4x1xf32>) ->
 //  CHECK-LABEL: func @reduce_1d(
 //   CHECK-SAME:   %[[A:.*]]: tensor<32xf32>
 func @reduce_1d(%arg0: tensor<32xf32>) -> tensor<f32> {
-  //  CHECK-DAG: %[[F0_v1:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
+  //  CHECK-DAG: %[[vF0:.*]] = arith.constant dense<0.000000e+00> : vector<f32>
   //  CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f32
   //  CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
   %f0 = arith.constant 0.000000e+00 : f32
@@ -1013,17 +1013,18 @@ func @reduce_1d(%arg0: tensor<32xf32>) -> tensor<f32> {
   //      CHECK: %[[init:.*]] = linalg.init_tensor [] : tensor<f32>
   %0 = linalg.init_tensor [] : tensor<f32>
 
-  //      CHECK: %[[f:.*]] = vector.transfer_write %[[F0_v1]], %[[init]][]
-  // CHECK-SAME:   : vector<1xf32>, tensor<f32>
+  //      CHECK: %[[f:.*]] = vector.transfer_write %[[vF0]], %[[init]][]
+  // CHECK-SAME:   : vector<f32>, tensor<f32>
   %1 = linalg.fill(%f0, %0) : f32, tensor<f32> -> tensor<f32>
   //      CHECK: %[[r:.*]] = vector.transfer_read %[[A]][%[[C0]]]
   // CHECK-SAME:   : tensor<32xf32>, vector<32xf32>
+  //      CHECK: %[[f0:.*]] = vector.extractelement %[[vF0]][] : vector<f32>
   //      CHECK: %[[red:.*]] = vector.multi_reduction #vector.kind<add>, %[[r]] [0]
   // CHECK-SAME:   : vector<32xf32> to f32
-  //      CHECK: %[[a:.*]] = arith.addf %[[red]], %[[F0]] : f32
-  //      CHECK: %[[red_v1:.*]] = vector.broadcast %[[a]] : f32 to vector<1xf32>
+  //      CHECK: %[[a:.*]] = arith.addf %[[red]], %[[f0]] : f32
+  //      CHECK: %[[red_v1:.*]] = vector.broadcast %[[a]] : f32 to vector<f32>
   //      CHECK: %[[res:.*]] = vector.transfer_write %[[red_v1]], %[[f]][]
-  // CHECK-SAME:   : vector<1xf32>, tensor<f32>
+  // CHECK-SAME:   : vector<f32>, tensor<f32>
   %2 = linalg.generic {
          indexing_maps = [affine_map<(d0) -> (d0)>,
                           affine_map<(d0) -> ()>],

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 593686a425a51..c550a0818809f 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1427,15 +1427,3 @@ func @insert_map_id(%v: vector<2x1xf32>, %v1: vector<4x32xf32>, %id : index) {
   %0 = vector.insert_map %v, %v1[%id] : vector<2x1xf32> into vector<4x32xf32>
 }
 
-// -----
-
-func @vector_transfer_ops_0d(%arg0: tensor<f32>)
-  -> tensor<f32> {
-    %f0 = arith.constant 0.0 : f32
-    // expected-error at +1 {{0-d transfer requires vector<1xt> shape and () -> (0) permutation_map}}
-    %0 = vector.transfer_read %arg0[], %f0 {permutation_map = affine_map<(d0)->(d0)>} :
-      tensor<f32>, vector<1xf32>
-    %1 = vector.transfer_write %0, %arg0[] {permutation_map = affine_map<()->(0)>} :
-      vector<1xf32>, tensor<f32>
-    return %1: tensor<f32>
-}

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 11b986fc9b87c..576924e1addff 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -4,17 +4,33 @@
 func @vector_transfer_ops_0d(%arg0: tensor<f32>, %arg1: memref<f32>)
   -> tensor<f32> {
     %f0 = arith.constant 0.0 : f32
-    %0 = vector.transfer_read %arg0[], %f0 {permutation_map = affine_map<()->(0)>} :
-      tensor<f32>, vector<1xf32>
-    %1 = vector.transfer_write %0, %arg0[] {permutation_map = affine_map<()->(0)>} :
-      vector<1xf32>, tensor<f32>
-    %2 = vector.transfer_read %arg1[], %f0 {permutation_map = affine_map<()->(0)>} :
-      memref<f32>, vector<1xf32>
-    vector.transfer_write %2, %arg1[] {permutation_map = affine_map<()->(0)>} :
-      vector<1xf32>, memref<f32>
+    %0 = vector.transfer_read %arg0[], %f0 {permutation_map = affine_map<()->()>} :
+      tensor<f32>, vector<f32>
+    %1 = vector.transfer_write %0, %arg0[] {permutation_map = affine_map<()->()>} :
+      vector<f32>, tensor<f32>
+    %2 = vector.transfer_read %arg1[], %f0 {permutation_map = affine_map<()->()>} :
+      memref<f32>, vector<f32>
+    vector.transfer_write %2, %arg1[] {permutation_map = affine_map<()->()>} :
+      vector<f32>, memref<f32>
     return %1: tensor<f32>
 }
 
+// CHECK-LABEL: func @vector_transfer_ops_0d_from_higher_d(
+func @vector_transfer_ops_0d_from_higher_d(%arg0: tensor<?xf32>, %arg1: memref<?x?xf32>)
+  -> tensor<?xf32> {
+    %c0 = arith.constant 0 : index
+    %f0 = arith.constant 0.0 : f32
+    %0 = vector.transfer_read %arg0[%c0], %f0 {permutation_map = affine_map<(d0)->()>} :
+      tensor<?xf32>, vector<f32>
+    %1 = vector.transfer_write %0, %arg0[%c0] {permutation_map = affine_map<(d0)->()>} :
+      vector<f32>, tensor<?xf32>
+    %2 = vector.transfer_read %arg1[%c0, %c0], %f0 {permutation_map = affine_map<(d0, d1)->()>} :
+      memref<?x?xf32>, vector<f32>
+    vector.transfer_write %2, %arg1[%c0, %c0] {permutation_map = affine_map<(d0, d1)->()>} :
+      vector<f32>, memref<?x?xf32>
+    return %1: tensor<?xf32>
+}
+
 // CHECK-LABEL: func @vector_transfer_ops(
 func @vector_transfer_ops(%arg0: memref<?x?xf32>,
                           %arg1 : memref<?x?xvector<4x3xf32>>,

diff  --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index a5c0cb584b11b..562870c4d9fe6 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -6,13 +6,13 @@
 func @vector_transfer_ops_0d_memref(%M: memref<f32>, %v: vector<1x1x1xf32>) {
     %f0 = arith.constant 0.0 : f32
 
-//  CHECK-NEXT:   %[[V:.*]] = memref.load %[[MEM]][] : memref<f32>
-    %0 = vector.transfer_read %M[], %f0 {permutation_map = affine_map<()->(0)>} :
-      memref<f32>, vector<1xf32>
+//  CHECK-NEXT:   %[[s:.*]] = memref.load %[[MEM]][] : memref<f32>
+//  CHECK-NEXT:   %[[V:.*]] = vector.broadcast %[[s]] : f32 to vector<f32>
+    %0 = vector.transfer_read %M[], %f0 : memref<f32>, vector<f32>
 
-//  CHECK-NEXT:   memref.store %[[V]], %[[MEM]][] : memref<f32>
-    vector.transfer_write %0, %M[] {permutation_map = affine_map<()->(0)>} :
-      vector<1xf32>, memref<f32>
+//  CHECK-NEXT:   %[[ss:.*]] = vector.extractelement %[[V]][] : vector<f32>
+//  CHECK-NEXT:   memref.store %[[ss]], %[[MEM]][] : memref<f32>
+    vector.transfer_write %0, %M[] : vector<f32>, memref<f32>
 
 //  CHECK-NEXT:   %[[VV:.*]] = vector.extract %arg1[0, 0, 0] : vector<1x1x1xf32>
 //  CHECK-NEXT:   memref.store %[[VV]], %[[MEM]][] : memref<f32>


        


More information about the Mlir-commits mailing list