[Mlir-commits] [mlir] d1a9e9a - [mlir][vector] Remove vector.transfer_read/write to LLVM lowering

Matthias Springer llvmlistbot at llvm.org
Fri Jul 16 22:07:44 PDT 2021


Author: Matthias Springer
Date: 2021-07-17T14:07:27+09:00
New Revision: d1a9e9a7cbad4044ccc8e08d0217c23aca417714

URL: https://github.com/llvm/llvm-project/commit/d1a9e9a7cbad4044ccc8e08d0217c23aca417714
DIFF: https://github.com/llvm/llvm-project/commit/d1a9e9a7cbad4044ccc8e08d0217c23aca417714.diff

LOG: [mlir][vector] Remove vector.transfer_read/write to LLVM lowering

This simplifies the vector to LLVM lowering. Previously, both vector.load/store and vector.transfer_read/write lowered directly to LLVM. With this commit, there is a single path to LLVM vector load/store instructions and vector.transfer_read/write ops must first be lowered to vector.load/store ops.

* Remove vector.transfer_read/write to LLVM lowering.
* Allow non-unit memref strides on all but the most minor dimension for vector.load/store ops.
* Add maxTransferRank option to populateVectorTransferLoweringPatterns.
* vector.transfer_read ops that change the element type (e.g., reading a vector<12xi8> from a memref<?xi32>) can no longer be lowered to LLVM. (This functionality is needed only for SPIRV.)

Differential Revision: https://reviews.llvm.org/D106118

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.h
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
    mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index b7d2d0c0eaec6..cf53e8fcff97c 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -62,9 +62,12 @@ void populateBubbleVectorBitCastOpPatterns(RewritePatternSet &patterns);
 /// Collect a set of transfer read/write lowering patterns.
 ///
 /// These patterns lower transfer ops to simpler ops like `vector.load`,
-/// `vector.store` and `vector.broadcast`. Includes all patterns of
-/// populateVectorTransferPermutationMapLoweringPatterns.
-void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns);
+/// `vector.store` and `vector.broadcast`. Only transfers with a transfer rank
+/// of at most `maxTransferRank` are lowered. This is useful when combined with
+/// VectorToSCF, which reduces the rank of vector transfer ops.
+void populateVectorTransferLoweringPatterns(
+    RewritePatternSet &patterns,
+    llvm::Optional<unsigned> maxTransferRank = llvm::None);
 
 /// Collect a set of transfer read/write lowering patterns that simplify the
 /// permutation map (e.g., converting it to a minor identity map) by inserting
@@ -185,6 +188,10 @@ ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef<int64_t> values);
 Value getVectorReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc,
                            Value vector);
 
+/// Return true if the last dimension of the MemRefType has unit stride. Also
+/// return true for memrefs with no strides.
+bool isLastMemrefDimUnitStride(MemRefType type);
+
 namespace impl {
 /// Build the default minor identity map suitable for a vector transfer. This
 /// also handles the case memref<... x vector<...>> -> vector<...> in which the

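The header change above gives `populateVectorTransferLoweringPatterns` an optional `maxTransferRank` bound. As a minimal sketch of the intended semantics (not MLIR code), the standalone C++ below models the rank gate with `std::optional` standing in for `llvm::Optional`; `shouldLowerTransfer` is a hypothetical name. An unset bound lets transfers of any rank through, while a bound of 1, as passed by the VectorToLLVM conversion in this commit, leaves higher-rank transfers for VectorToSCF to handle first.

```cpp
#include <cassert>
#include <optional>

// Hypothetical model of the rank check at the top of
// TransferReadToVectorLoadLowering / TransferWriteToVectorStoreLowering.
// std::optional stands in for llvm::Optional; an empty value means "no limit".
bool shouldLowerTransfer(unsigned vectorRank,
                         std::optional<unsigned> maxTransferRank) {
  // Mirrors: if (maxTransferRank && rank > *maxTransferRank) return failure();
  if (maxTransferRank && vectorRank > *maxTransferRank)
    return false;
  return true;
}
```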
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 9fc20647c81d5..911a9c60c1451 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1409,9 +1409,9 @@ def Vector_LoadOp : Vector_Op<"load"> {
     based on the element type of the memref. The shape of the result vector
     type determines the shape of the slice read from the start memory address.
     The elements along each dimension of the slice are strided by the memref
-    strides. Only memref with default strides are allowed. These constraints
-    guarantee that elements read along the first dimension of the slice are
-    contiguous in memory.
+    strides. Only unit strides are allowed along the most minor memref
+    dimension. These constraints guarantee that elements read along the first
+    dimension of the slice are contiguous in memory.
 
     The memref element type can be a scalar or a vector type. If the memref
     element type is a scalar, it should match the element type of the result
@@ -1470,6 +1470,8 @@ def Vector_LoadOp : Vector_Op<"load"> {
     }
   }];
 
+  let hasFolder = 1;
+
   let assemblyFormat =
       "$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)";
 }
@@ -1484,9 +1486,9 @@ def Vector_StoreOp : Vector_Op<"store"> {
     memref dimension based on the element type of the memref. The shape of the
     vector value to store determines the shape of the slice written from the
     start memory address. The elements along each dimension of the slice are
-    strided by the memref strides. Only memref with default strides are allowed.
-    These constraints guarantee that elements written along the first dimension
-    of the slice are contiguous in memory.
+    strided by the memref strides. Only unit strides are allowed along the most
+    minor memref dimension. These constraints guarantee that elements written
+    along the first dimension of the slice are contiguous in memory.
 
     The memref element type can be a scalar or a vector type. If the memref
     element type is a scalar, it should match the element type of the value
@@ -1544,6 +1546,8 @@ def Vector_StoreOp : Vector_Op<"store"> {
     }
   }];
 
+  let hasFolder = 1;
+
   let assemblyFormat = "$valueToStore `,` $base `[` $indices `]` attr-dict "
                        "`:` type($base) `,` type($valueToStore)";
 }
@@ -1601,6 +1605,7 @@ def Vector_MaskedLoadOp :
   let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $pass_thru attr-dict `:` "
     "type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)";
   let hasCanonicalizer = 1;
+  let hasFolder = 1;
 }
 
 def Vector_MaskedStoreOp :
@@ -1653,6 +1658,7 @@ def Vector_MaskedStoreOp :
       "$base `[` $indices `]` `,` $mask `,` $valueToStore "
       "attr-dict `:` type($base) `,` type($mask) `,` type($valueToStore)";
   let hasCanonicalizer = 1;
+  let hasFolder = 1;
 }
 
 def Vector_GatherOp :

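With this commit, the `vector.load`/`vector.store` verifiers accept any strided memref whose most minor dimension has unit stride, such as a 4x8 `f32` view with strides `[16, 1]`, but still reject strides such as `[8, 2]` with "most minor memref dim must have unit stride". The sketch below restates `isLastMemrefDimUnitStride` in plain C++, assuming the strides have already been extracted; `lastDimHasUnitStride` is a hypothetical name, and a failed `getStridesAndOffset` call is modeled as an empty `std::optional`.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for vector::isLastMemrefDimUnitStride. In MLIR the
// strides come from getStridesAndOffset(), which can fail for layouts that are
// not expressible as strides; that failure is modeled as std::nullopt here.
bool lastDimHasUnitStride(const std::optional<std::vector<int64_t>> &strides) {
  if (!strides)
    return false;  // Layout could not be converted to strides.
  // Memrefs with no strides (rank 0) are accepted, as in the original helper.
  return strides->empty() || strides->back() == 1;
}
```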
diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index b1256cb4f6133..53ce5ca3d452e 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -130,18 +130,6 @@ static unsigned getAssumedAlignment(Value value) {
   }
   return align;
 }
-// Helper that returns data layout alignment of a memref associated with a
-// transfer op, including additional information from assume_alignment calls
-// on the source of the transfer
-LogicalResult getTransferOpAlignment(LLVMTypeConverter &typeConverter,
-                                     VectorTransferOpInterface xfer,
-                                     unsigned &align) {
-  if (failed(getMemRefAlignment(
-          typeConverter, xfer.getShapedType().cast<MemRefType>(), align)))
-    return failure();
-  align = std::max(align, getAssumedAlignment(xfer.source()));
-  return success();
-}
 
 // Helper that returns data layout alignment of a memref associated with a
 // load, store, scatter, or gather op, including additional information from
@@ -181,79 +169,6 @@ static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
   return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
 }
 
-static LogicalResult
-replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
-                                 LLVMTypeConverter &typeConverter, Location loc,
-                                 TransferReadOp xferOp,
-                                 ArrayRef<Value> operands, Value dataPtr) {
-  unsigned align;
-  if (failed(getTransferOpAlignment(typeConverter, xferOp, align)))
-    return failure();
-  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
-  return success();
-}
-
-static LogicalResult
-replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
-                            LLVMTypeConverter &typeConverter, Location loc,
-                            TransferReadOp xferOp, ArrayRef<Value> operands,
-                            Value dataPtr, Value mask) {
-  Type vecTy = typeConverter.convertType(xferOp.getVectorType());
-  if (!vecTy)
-    return failure();
-
-  auto adaptor = TransferReadOpAdaptor(operands, xferOp->getAttrDictionary());
-  Value fill = rewriter.create<SplatOp>(loc, vecTy, adaptor.padding());
-
-  unsigned align;
-  if (failed(getTransferOpAlignment(typeConverter, xferOp, align)))
-    return failure();
-  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
-      xferOp, vecTy, dataPtr, mask, ValueRange{fill},
-      rewriter.getI32IntegerAttr(align));
-  return success();
-}
-
-static LogicalResult
-replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
-                                 LLVMTypeConverter &typeConverter, Location loc,
-                                 TransferWriteOp xferOp,
-                                 ArrayRef<Value> operands, Value dataPtr) {
-  unsigned align;
-  if (failed(getTransferOpAlignment(typeConverter, xferOp, align)))
-    return failure();
-  auto adaptor = TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary());
-  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
-                                             align);
-  return success();
-}
-
-static LogicalResult
-replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
-                            LLVMTypeConverter &typeConverter, Location loc,
-                            TransferWriteOp xferOp, ArrayRef<Value> operands,
-                            Value dataPtr, Value mask) {
-  unsigned align;
-  if (failed(getTransferOpAlignment(typeConverter, xferOp, align)))
-    return failure();
-
-  auto adaptor = TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary());
-  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
-      xferOp, adaptor.vector(), dataPtr, mask,
-      rewriter.getI32IntegerAttr(align));
-  return success();
-}
-
-static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp,
-                                                  ArrayRef<Value> operands) {
-  return TransferReadOpAdaptor(operands, xferOp->getAttrDictionary());
-}
-
-static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp,
-                                                   ArrayRef<Value> operands) {
-  return TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary());
-}
-
 namespace {
 
 /// Conversion pattern for a vector.bitcast.
@@ -1026,15 +941,6 @@ class VectorInsertStridedSliceOpSameRankRewritePattern
   }
 };
 
-/// Return true if the last dimension of the MemRefType has unit stride. Also
-/// return true for memrefs with no strides.
-static bool isLastMemrefDimUnitStride(MemRefType type) {
-  int64_t offset;
-  SmallVector<int64_t> strides;
-  auto successStrides = getStridesAndOffset(type, strides, offset);
-  return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
-}
-
 /// Returns the strides if the memory underlying `memRefType` has a contiguous
 /// static layout.
 static llvm::Optional<SmallVector<int64_t, 4>>
@@ -1145,83 +1051,6 @@ class VectorTypeCastOpConversion
   }
 };
 
-/// Conversion pattern that converts a 1-D vector transfer read/write op into a
-/// a masked or unmasked read/write.
-template <typename ConcreteOp>
-class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
-public:
-  using ConvertOpToLLVMPattern<ConcreteOp>::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(ConcreteOp xferOp, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto adaptor = getTransferOpAdapter(xferOp, operands);
-
-    if (xferOp.getVectorType().getRank() > 1 || xferOp.indices().empty())
-      return failure();
-    if (xferOp.permutation_map() !=
-        AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(),
-                                       xferOp.getVectorType().getRank(),
-                                       xferOp->getContext()))
-      return failure();
-    auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
-    if (!memRefType)
-      return failure();
-    // Last dimension must be contiguous. (Otherwise: Use VectorToSCF.)
-    if (!isLastMemrefDimUnitStride(memRefType))
-      return failure();
-    // Out-of-bounds dims are handled by MaterializeTransferMask.
-    if (xferOp.hasOutOfBoundsDim())
-      return failure();
-
-    auto toLLVMTy = [&](Type t) {
-      return this->getTypeConverter()->convertType(t);
-    };
-
-    Location loc = xferOp->getLoc();
-
-    if (auto memrefVectorElementType =
-            memRefType.getElementType().template dyn_cast<VectorType>()) {
-      // Memref has vector element type.
-      if (memrefVectorElementType.getElementType() !=
-          xferOp.getVectorType().getElementType())
-        return failure();
-#ifndef NDEBUG
-      // Check that memref vector type is a suffix of 'vectorType.
-      unsigned memrefVecEltRank = memrefVectorElementType.getRank();
-      unsigned resultVecRank = xferOp.getVectorType().getRank();
-      assert(memrefVecEltRank <= resultVecRank);
-      // TODO: Move this to isSuffix in Vector/Utils.h.
-      unsigned rankOffset = resultVecRank - memrefVecEltRank;
-      auto memrefVecEltShape = memrefVectorElementType.getShape();
-      auto resultVecShape = xferOp.getVectorType().getShape();
-      for (unsigned i = 0; i < memrefVecEltRank; ++i)
-        assert(memrefVecEltShape[i] != resultVecShape[rankOffset + i] &&
-               "memref vector element shape should match suffix of vector "
-               "result shape.");
-#endif // ifndef NDEBUG
-    }
-
-    // Get the source/dst address as an LLVM vector pointer.
-    VectorType vtp = xferOp.getVectorType();
-    Value dataPtr = this->getStridedElementPtr(
-        loc, memRefType, adaptor.source(), adaptor.indices(), rewriter);
-    Value vectorDataPtr =
-        castDataPtr(rewriter, loc, dataPtr, memRefType, toLLVMTy(vtp));
-
-    // Rewrite as an unmasked masked read / write.
-    if (!xferOp.mask())
-      return replaceTransferOpWithLoadOrStore(rewriter,
-                                              *this->getTypeConverter(), loc,
-                                              xferOp, operands, vectorDataPtr);
-
-    // Rewrite as a masked read / write.
-    return replaceTransferOpWithMasked(rewriter, *this->getTypeConverter(), loc,
-                                       xferOp, operands, vectorDataPtr,
-                                       xferOp.mask());
-  }
-};
-
 class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
 public:
   using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;
@@ -1450,9 +1279,10 @@ void mlir::populateVectorToLLVMConversionPatterns(
            VectorLoadStoreConversion<vector::MaskedStoreOp,
                                      vector::MaskedStoreOpAdaptor>,
            VectorGatherOpConversion, VectorScatterOpConversion,
-           VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
-           VectorTransferConversion<TransferReadOp>,
-           VectorTransferConversion<TransferWriteOp>>(converter);
+           VectorExpandLoadOpConversion, VectorCompressStoreOpConversion>(
+          converter);
+  // Transfer ops with rank > 1 are handled by VectorToSCF.
+  populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
 }
 
 void mlir::populateVectorToLLVMMatrixConversionPatterns(

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 34359ef863e70..1a708dc4da6cc 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -64,6 +64,8 @@ void LowerVectorToLLVMPass::runOnOperation() {
     populateVectorToVectorCanonicalizationPatterns(patterns);
     populateVectorContractLoweringPatterns(patterns);
     populateVectorTransposeLoweringPatterns(patterns);
+    // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
+    populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 
@@ -71,6 +73,7 @@ void LowerVectorToLLVMPass::runOnOperation() {
   LLVMTypeConverter converter(&getContext());
   RewritePatternSet patterns(&getContext());
   populateVectorMaskMaterializationPatterns(patterns, enableIndexOptimizations);
+  populateVectorTransferLoweringPatterns(patterns);
   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
   populateVectorToLLVMConversionPatterns(converter, patterns,
                                          reassociateFPReductions);

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
index 93a8b475a94ca..cd4d525d6a905 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
@@ -89,7 +89,7 @@ void mlir::linalg::CodegenStrategy::transform(FuncOp func) const {
         .add<ContractionOpToOuterProductOpLowering,
              ContractionOpToMatmulOpLowering, ContractionOpLowering>(
             vectorTransformsOptions, context);
-    vector::populateVectorTransferLoweringPatterns(
+    vector::populateVectorTransferPermutationMapLoweringPatterns(
         vectorContractLoweringPatterns);
     (void)applyPatternsAndFoldGreedily(
         func, std::move(vectorContractLoweringPatterns));

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 9fbc6c3711d8b..045fbab987b69 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -102,6 +102,15 @@ static bool isSupportedCombiningKind(CombiningKind combiningKind,
   return false;
 }
 
+/// Return true if the last dimension of the MemRefType has unit stride. Also
+/// return true for memrefs with no strides.
+bool mlir::vector::isLastMemrefDimUnitStride(MemRefType type) {
+  int64_t offset;
+  SmallVector<int64_t> strides;
+  auto successStrides = getStridesAndOffset(type, strides, offset);
+  return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
+}
+
 //===----------------------------------------------------------------------===//
 // CombiningKindAttr
 //===----------------------------------------------------------------------===//
@@ -2953,9 +2962,8 @@ void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
 
 static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
                                                  MemRefType memRefTy) {
-  auto affineMaps = memRefTy.getAffineMaps();
-  if (!affineMaps.empty())
-    return op->emitOpError("base memref should have a default identity layout");
+  if (!isLastMemrefDimUnitStride(memRefTy))
+    return op->emitOpError("most minor memref dim must have unit stride");
   return success();
 }
 
@@ -2981,6 +2989,12 @@ static LogicalResult verify(vector::LoadOp op) {
   return success();
 }
 
+OpFoldResult LoadOp::fold(ArrayRef<Attribute>) {
+  if (succeeded(foldMemRefCast(*this)))
+    return getResult();
+  return OpFoldResult();
+}
+
 //===----------------------------------------------------------------------===//
 // StoreOp
 //===----------------------------------------------------------------------===//
@@ -3008,6 +3022,11 @@ static LogicalResult verify(vector::StoreOp op) {
   return success();
 }
 
+LogicalResult StoreOp::fold(ArrayRef<Attribute> operands,
+                            SmallVectorImpl<OpFoldResult> &results) {
+  return foldMemRefCast(*this);
+}
+
 //===----------------------------------------------------------------------===//
 // MaskedLoadOp
 //===----------------------------------------------------------------------===//
@@ -3056,6 +3075,12 @@ void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<MaskedLoadFolder>(context);
 }
 
+OpFoldResult MaskedLoadOp::fold(ArrayRef<Attribute>) {
+  if (succeeded(foldMemRefCast(*this)))
+    return getResult();
+  return OpFoldResult();
+}
+
 //===----------------------------------------------------------------------===//
 // MaskedStoreOp
 //===----------------------------------------------------------------------===//
@@ -3101,6 +3126,11 @@ void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<MaskedStoreFolder>(context);
 }
 
+LogicalResult MaskedStoreOp::fold(ArrayRef<Attribute> operands,
+                                  SmallVectorImpl<OpFoldResult> &results) {
+  return foldMemRefCast(*this);
+}
+
 //===----------------------------------------------------------------------===//
 // GatherOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 4bd1ee15dece7..2a99eb6e7063b 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -2464,26 +2464,34 @@ struct TransferWriteInsertPattern
 /// Progressive lowering of transfer_read. This pattern supports lowering of
 /// `vector.transfer_read` to a combination of `vector.load` and
 /// `vector.broadcast` if all of the following hold:
-/// - The op reads from a memref with the default layout.
+/// - Stride of most minor memref dimension must be 1.
 /// - Out-of-bounds masking is not required.
 /// - If the memref's element type is a vector type then it coincides with the
 ///   result type.
 /// - The permutation map doesn't perform permutation (broadcasting is allowed).
-/// - The op has no mask.
 struct TransferReadToVectorLoadLowering
     : public OpRewritePattern<vector::TransferReadOp> {
-  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+  TransferReadToVectorLoadLowering(MLIRContext *context,
+                                   llvm::Optional<unsigned> maxRank)
+      : OpRewritePattern<vector::TransferReadOp>(context),
+        maxTransferRank(maxRank) {}
 
   LogicalResult matchAndRewrite(vector::TransferReadOp read,
                                 PatternRewriter &rewriter) const override {
+    if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank)
+      return failure();
     SmallVector<unsigned, 4> broadcastedDims;
-    // TODO: Support permutations.
+    // Permutations are handled by VectorToSCF or
+    // populateVectorTransferPermutationMapLoweringPatterns.
     if (!read.permutation_map().isMinorIdentityWithBroadcasting(
             &broadcastedDims))
       return failure();
     auto memRefType = read.getShapedType().dyn_cast<MemRefType>();
     if (!memRefType)
       return failure();
+    // Non-unit strides are handled by VectorToSCF.
+    if (!vector::isLastMemrefDimUnitStride(memRefType))
+      return failure();
 
     // If there is broadcasting involved then we first load the unbroadcasted
     // vector, and then broadcast it with `vector.broadcast`.
@@ -2497,32 +2505,44 @@ struct TransferReadToVectorLoadLowering
 
     // `vector.load` supports vector types as memref's elements only when the
     // resulting vector type is the same as the element type.
-    if (memRefType.getElementType().isa<VectorType>() &&
-        memRefType.getElementType() != unbroadcastedVectorType)
+    auto memrefElTy = memRefType.getElementType();
+    if (memrefElTy.isa<VectorType>() && memrefElTy != unbroadcastedVectorType)
       return failure();
-    // Only the default layout is supported by `vector.load`.
-    // TODO: Support non-default layouts.
-    if (!memRefType.getAffineMaps().empty())
+    // Otherwise, element types of the memref and the vector must match.
+    if (!memrefElTy.isa<VectorType>() &&
+        memrefElTy != read.getVectorType().getElementType())
       return failure();
-    // TODO: When out-of-bounds masking is required, we can create a
-    //       MaskedLoadOp.
+
+    // Out-of-bounds dims are handled by MaterializeTransferMask.
     if (read.hasOutOfBoundsDim())
       return failure();
-    if (read.mask())
-      return failure();
 
-    auto loadOp = rewriter.create<vector::LoadOp>(
-        read.getLoc(), unbroadcastedVectorType, read.source(), read.indices());
+    // Create vector load op.
+    Operation *loadOp;
+    if (read.mask()) {
+      Value fill = rewriter.create<SplatOp>(
+          read.getLoc(), unbroadcastedVectorType, read.padding());
+      loadOp = rewriter.create<vector::MaskedLoadOp>(
+          read.getLoc(), unbroadcastedVectorType, read.source(), read.indices(),
+          read.mask(), fill);
+    } else {
+      loadOp = rewriter.create<vector::LoadOp>(read.getLoc(),
+                                               unbroadcastedVectorType,
+                                               read.source(), read.indices());
+    }
+
     // Insert a broadcasting op if required.
     if (!broadcastedDims.empty()) {
       rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
-          read, read.getVectorType(), loadOp.result());
+          read, read.getVectorType(), loadOp->getResult(0));
     } else {
-      rewriter.replaceOp(read, loadOp.result());
+      rewriter.replaceOp(read, loadOp->getResult(0));
     }
 
     return success();
   }
+
+  llvm::Optional<unsigned> maxTransferRank;
 };
 
 /// Replace a scalar vector.load with a memref.load.
@@ -2545,44 +2565,56 @@ struct VectorLoadToMemrefLoadLowering
 
 /// Progressive lowering of transfer_write. This pattern supports lowering of
 /// `vector.transfer_write` to `vector.store` if all of the following hold:
-/// - The op writes to a memref with the default layout.
+/// - Stride of most minor memref dimension must be 1.
 /// - Out-of-bounds masking is not required.
 /// - If the memref's element type is a vector type then it coincides with the
 ///   type of the written value.
 /// - The permutation map is the minor identity map (neither permutation nor
 ///   broadcasting is allowed).
-/// - The op has no mask.
 struct TransferWriteToVectorStoreLowering
     : public OpRewritePattern<vector::TransferWriteOp> {
-  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
+  TransferWriteToVectorStoreLowering(MLIRContext *context,
+                                     llvm::Optional<unsigned> maxRank)
+      : OpRewritePattern<vector::TransferWriteOp>(context),
+        maxTransferRank(maxRank) {}
 
   LogicalResult matchAndRewrite(vector::TransferWriteOp write,
                                 PatternRewriter &rewriter) const override {
-    // TODO: Support non-minor-identity maps
+    if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank)
+      return failure();
+    // Permutations are handled by VectorToSCF or
+    // populateVectorTransferPermutationMapLoweringPatterns.
     if (!write.permutation_map().isMinorIdentity())
       return failure();
     auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
     if (!memRefType)
       return failure();
+    // Non-unit strides are handled by VectorToSCF.
+    if (!vector::isLastMemrefDimUnitStride(memRefType))
+      return failure();
     // `vector.store` supports vector types as memref's elements only when the
     // type of the vector value being written is the same as the element type.
-    if (memRefType.getElementType().isa<VectorType>() &&
-        memRefType.getElementType() != write.getVectorType())
+    auto memrefElTy = memRefType.getElementType();
+    if (memrefElTy.isa<VectorType>() && memrefElTy != write.getVectorType())
       return failure();
-    // Only the default layout is supported by `vector.store`.
-    // TODO: Support non-default layouts.
-    if (!memRefType.getAffineMaps().empty())
+    // Otherwise, element types of the memref and the vector must match.
+    if (!memrefElTy.isa<VectorType>() &&
+        memrefElTy != write.getVectorType().getElementType())
       return failure();
-    // TODO: When out-of-bounds masking is required, we can create a
-    //       MaskedStoreOp.
+    // Out-of-bounds dims are handled by MaterializeTransferMask.
     if (write.hasOutOfBoundsDim())
       return failure();
-    if (write.mask())
-      return failure();
-    rewriter.replaceOpWithNewOp<vector::StoreOp>(
-        write, write.vector(), write.source(), write.indices());
+    if (write.mask()) {
+      rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
+          write, write.source(), write.indices(), write.mask(), write.vector());
+    } else {
+      rewriter.replaceOpWithNewOp<vector::StoreOp>(
+          write, write.vector(), write.source(), write.indices());
+    }
     return success();
   }
+
+  llvm::Optional<unsigned> maxTransferRank;
 };
 
 /// Transpose a vector transfer op's `in_bounds` attribute according to given
@@ -2624,6 +2656,8 @@ struct TransferReadPermutationLowering
                                 PatternRewriter &rewriter) const override {
     SmallVector<unsigned> permutation;
     AffineMap map = op.permutation_map();
+    if (map.getNumResults() == 0)
+      return failure();
     if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
       return failure();
     AffineMap permutationMap =
@@ -3680,11 +3714,11 @@ void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
 }
 
 void mlir::vector::populateVectorTransferLoweringPatterns(
-    RewritePatternSet &patterns) {
-  patterns
-      .add<TransferReadToVectorLoadLowering, TransferWriteToVectorStoreLowering,
-           VectorLoadToMemrefLoadLowering>(patterns.getContext());
-  populateVectorTransferPermutationMapLoweringPatterns(patterns);
+    RewritePatternSet &patterns, llvm::Optional<unsigned> maxTransferRank) {
+  patterns.add<TransferReadToVectorLoadLowering,
+               TransferWriteToVectorStoreLowering>(patterns.getContext(),
+                                                   maxTransferRank);
+  patterns.add<VectorLoadToMemrefLoadLowering>(patterns.getContext());
 }
 
 void mlir::vector::populateVectorMultiReductionLoweringPatterns(

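The rewritten `TransferReadToVectorLoadLowering` checks its preconditions in a fixed order and, instead of bailing out on masked reads, now emits a `vector.maskedload` whose pass-through is a splat of the padding value. The standalone sketch below condenses that decision procedure; `TransferReadInfo`, `planTransferRead`, and the other names are hypothetical, and each boolean field stands in for a check that MLIR performs on the real op (permutation map, memref layout, element types).

```cpp
#include <cassert>
#include <optional>

// Hypothetical summary of the properties of a vector.transfer_read that the
// lowering pattern inspects in this commit.
struct TransferReadInfo {
  unsigned vectorRank;
  bool minorIdentityWithBroadcasting;  // permutation map check
  bool broadcasts;                     // some result dims are broadcast
  bool sourceIsMemRef;                 // tensors are not handled here
  bool lastDimUnitStride;              // isLastMemrefDimUnitStride
  bool elementTypesCompatible;         // memref/vector element types agree
  bool hasOutOfBoundsDim;
  bool hasMask;
};

enum class LoadKind { None, Load, MaskedLoad };

struct ReadLoweringPlan {
  LoadKind load = LoadKind::None;  // None: the pattern returns failure()
  bool broadcast = false;          // wrap the loaded value in vector.broadcast
};

ReadLoweringPlan planTransferRead(const TransferReadInfo &r,
                                  std::optional<unsigned> maxTransferRank) {
  ReadLoweringPlan plan;
  if (maxTransferRank && r.vectorRank > *maxTransferRank)
    return plan;  // Higher-rank transfers are left for VectorToSCF.
  if (!r.minorIdentityWithBroadcasting || !r.sourceIsMemRef ||
      !r.lastDimUnitStride || !r.elementTypesCompatible)
    return plan;
  if (r.hasOutOfBoundsDim)
    return plan;  // Handled by MaterializeTransferMask instead.
  // Masked reads become splat(padding) + vector.maskedload.
  plan.load = r.hasMask ? LoadKind::MaskedLoad : LoadKind::Load;
  plan.broadcast = r.broadcasts;
  return plan;
}
```

`TransferWriteToVectorStoreLowering` follows the same order, except that broadcasting is not allowed and the masked case maps to `vector.maskedstore`.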
diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 329b79a195082..afb007c9e6b3e 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1212,18 +1212,19 @@ func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
 //       CHECK: %[[dimVec:.*]] = splat %[[dtrunc]] : vector<17xi32>
 //       CHECK: %[[mask:.*]] = cmpi slt, %[[offsetVec2]], %[[dimVec]] : vector<17xi32>
 //
-// 4. Bitcast to vector form.
+// 4. Create pass-through vector.
+//       CHECK: %[[PASS_THROUGH:.*]] = splat %[[c7]] : vector<17xf32>
+//
+// 5. Bitcast to vector form.
 //       CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
 //  CHECK-SAME: (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
 //       CHECK: %[[vecPtr:.*]] = llvm.bitcast %[[gep]] :
 //  CHECK-SAME: !llvm.ptr<f32> to !llvm.ptr<vector<17xf32>>
 //
-// 5. Rewrite as a masked read.
-//       CHECK: %[[PASS_THROUGH:.*]] = splat %[[c7]] : vector<17xf32>
+// 6. Rewrite as a masked read.
 //       CHECK: %[[loaded:.*]] = llvm.intr.masked.load %[[vecPtr]], %[[mask]],
 //  CHECK-SAME: %[[PASS_THROUGH]] {alignment = 4 : i32} :
 //  CHECK-SAME: (!llvm.ptr<vector<17xf32>>, vector<17xi1>, vector<17xf32>) -> vector<17xf32>
-
 //
 // 1. Create a vector with linear indices [ 0 .. vector_length - 1 ].
 //       CHECK: %[[linearIndex_b:.*]] = constant dense
@@ -1264,8 +1265,9 @@ func @transfer_read_index_1d(%A : memref<?xindex>, %base: index) -> vector<17xin
 }
 // CHECK-LABEL: func @transfer_read_index_1d
 //  CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<17xindex>
-//       CHECK: %[[C7:.*]] = constant 7
-//       CHECK: %{{.*}} = unrealized_conversion_cast %[[C7]] : index to i64
+//       CHECK: %[[C7:.*]] = constant 7 : index
+//       CHECK: %[[SPLAT:.*]] = splat %[[C7]] : vector<17xindex>
+//       CHECK: %{{.*}} = unrealized_conversion_cast %[[SPLAT]] : vector<17xindex> to vector<17xi64>
 
 //       CHECK: %[[loaded:.*]] = llvm.intr.masked.load %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} :
 //  CHECK-SAME: (!llvm.ptr<vector<17xi64>>, vector<17xi1>, vector<17xi64>) -> vector<17xi64>
@@ -1384,26 +1386,6 @@ func @transfer_read_1d_mask(%A : memref<?xf32>, %base : index) -> vector<5xf32>
 
 // -----
 
-func @transfer_read_1d_cast(%A : memref<?xi32>, %base: index) -> vector<12xi8> {
-  %c0 = constant 0: i32
-  %v = vector.transfer_read %A[%base], %c0 {in_bounds = [true]} :
-    memref<?xi32>, vector<12xi8>
-  return %v: vector<12xi8>
-}
-// CHECK-LABEL: func @transfer_read_1d_cast
-//  CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<12xi8>
-//
-// 1. Bitcast to vector form.
-//       CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
-//  CHECK-SAME: (!llvm.ptr<i32>, i64) -> !llvm.ptr<i32>
-//       CHECK: %[[vecPtr:.*]] = llvm.bitcast %[[gep]] :
-//  CHECK-SAME: !llvm.ptr<i32> to !llvm.ptr<vector<12xi8>>
-//
-// 2. Rewrite as a load.
-//       CHECK: %[[loaded:.*]] = llvm.load %[[vecPtr]] {alignment = 4 : i64} : !llvm.ptr<vector<12xi8>>
-
-// -----
-
 func @genbool_1d() -> vector<8xi1> {
   %0 = vector.constant_mask [4] : vector<8xi1>
   return %0 : vector<8xi1>

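The reordered CHECK lines in `@transfer_read_1d` describe how a 1-D `vector.transfer_read` ends up as an `llvm.intr.masked.load`. A vector of linear indices is offset by the base index and compared (`cmpi slt`) against the memref size to form the mask. Lanes that are masked off take their value from a pass-through vector splatted from the padding value (`%c7`). The scalar C++ model below only illustrates those semantics, assuming a contiguous 1-D buffer. `maskedRead1D` is a hypothetical name, and the model uses 64-bit indices where the lowering compares `i32` vectors.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Scalar model (illustrative only) of the masked 1-D read checked above.
// Lane i reads memory[base + i] when base + i < dim; otherwise it keeps the
// padding value from the pass-through vector.
std::vector<float> maskedRead1D(const std::vector<float> &memory,
                                std::int64_t base, std::int64_t vectorLength,
                                float padding) {
  assert(base >= 0 && vectorLength >= 0 && "model assumes non-negative inputs");
  const auto dim = static_cast<std::int64_t>(memory.size());
  // Pass-through vector: a splat of the padding value.
  std::vector<float> result(static_cast<std::size_t>(vectorLength), padding);
  for (std::int64_t i = 0; i < vectorLength; ++i) {
    // Mask bit for lane i: (base + i) < dim, like the `cmpi slt` above.
    if (base + i < dim)
      result[static_cast<std::size_t>(i)] =
          memory[static_cast<std::size_t>(base + i)];
  }
  return result;
}
```

For example, reading 4 lanes at base 2 from the 4-element buffer `{1, 2, 3, 4}` with padding 7 gives `{3, 4, 7, 7}`. The last two lanes fall outside the memref, so they come from the pass-through vector.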
diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index e8ad86f5fffb0..7fb4ecb5b0d3a 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1094,11 +1094,12 @@ func @type_cast_layout(%arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] ->
 
 // -----
 
-func @store_unsupported_layout(%memref : memref<200x100xf32, affine_map<(d0, d1) -> (d1, d0)>>,
+func @store_unsupported_layout(%memref : memref<200x100xf32, affine_map<(d0, d1) -> (200*d0 + 2*d1)>>,
                                %i : index, %j : index, %value : vector<8xf32>) {
-  // expected-error@+1 {{'vector.store' op base memref should have a default identity layout}}
-  vector.store %value, %memref[%i, %j] : memref<200x100xf32, affine_map<(d0, d1) -> (d1, d0)>>,
+  // expected-error@+1 {{'vector.store' op most minor memref dim must have unit stride}}
+  vector.store %value, %memref[%i, %j] : memref<200x100xf32, affine_map<(d0, d1) -> (200*d0 + 2*d1)>>,
                                          vector<8xf32>
+  return
 }
 
 // -----

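The `invalid.mlir` change reflects the relaxed verifier. According to the commit message, `vector.load`/`vector.store` now accept non-unit strides on every dimension except the most minor one. The new test layout `(d0, d1) -> (200*d0 + 2*d1)` has strides `[200, 2]`. Its innermost stride is 2, so the op is rejected. Below is a minimal C++ sketch of that rule. It assumes the strides have already been extracted from the layout map, and both helpers are hypothetical, not MLIR APIs.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: row-major strides of the identity layout for a static
// shape, e.g. {200, 100} -> {100, 1}.
std::vector<std::int64_t> identityStrides(const std::vector<std::int64_t> &shape) {
  std::vector<std::int64_t> strides(shape.size(), 1);
  for (std::size_t i = shape.size(); i-- > 1;)
    strides[i - 1] = strides[i] * shape[i];
  return strides;
}

// Hypothetical helper: the relaxed vector.load/store rule. Outer dimensions
// may have any stride; only the most minor dimension must have stride 1.
bool hasUnitMinorStride(const std::vector<std::int64_t> &strides) {
  if (strides.empty())
    return false; // Assumption of this sketch: no minor dimension to check.
  return strides.back() == 1;
}
```

Under this rule, the identity layout of `memref<200x100xf32>` (strides `[100, 1]`) passes and the `[200, 2]` layout from the test fails. A layout with strides `[200, 1]` would now be accepted even though it is not the identity layout.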
diff  --git a/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
index 931c3ba91774a..910100d61af4f 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
@@ -114,14 +114,11 @@ func @transfer_not_inbounds(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32>
 
 // -----
 
-// TODO: transfer_read/write cannot be lowered to vector.load/store because the
-// memref has a non-default layout.
 // CHECK-LABEL:   func @transfer_nondefault_layout(
 // CHECK-SAME:                                          %[[MEM:.*]]: memref<8x8xf32, #{{.*}}>,
 // CHECK-SAME:                                          %[[IDX:.*]]: index) -> vector<4xf32> {
-// CHECK-NEXT:      %[[CF0:.*]] = constant 0.000000e+00 : f32
-// CHECK-NEXT:      %[[RES:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[CF0]] {in_bounds = [true]} : memref<8x8xf32, #{{.*}}>, vector<4xf32>
-// CHECK-NEXT:      vector.transfer_write %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] {in_bounds = [true]} : vector<4xf32>, memref<8x8xf32, #{{.*}}>
+// CHECK-NEXT:      %[[RES:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32, #{{.*}}>, vector<4xf32>
+// CHECK-NEXT:      vector.store %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32, #{{.*}}>,  vector<4xf32>
 // CHECK-NEXT:      return %[[RES]] : vector<4xf32>
 // CHECK-NEXT:    }
 

diff  --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index fbc80116b9554..11b56a583cc83 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -436,6 +436,7 @@ struct TestVectorTransferLoweringPatterns
   void runOnFunction() override {
     RewritePatternSet patterns(&getContext());
     populateVectorTransferLoweringPatterns(patterns);
+    populateVectorTransferPermutationMapLoweringPatterns(patterns);
     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
   }
 };


        


More information about the Mlir-commits mailing list