[Mlir-commits] [mlir] ee66e43 - [mlir][Vector] Introduce 'vector.load' and 'vector.store' ops

Diego Caballero llvmlistbot at llvm.org
Fri Feb 12 10:53:28 PST 2021


Author: Diego Caballero
Date: 2021-02-12T20:48:37+02:00
New Revision: ee66e43a96e138cc0ed5c37897576d05fa897c27

URL: https://github.com/llvm/llvm-project/commit/ee66e43a96e138cc0ed5c37897576d05fa897c27
DIFF: https://github.com/llvm/llvm-project/commit/ee66e43a96e138cc0ed5c37897576d05fa897c27.diff

LOG: [mlir][Vector] Introduce 'vector.load' and 'vector.store' ops

This patch adds the 'vector.load' and 'vector.store' ops to the Vector
dialect [1]. These operations model *contiguous* vector loads and stores
from/to memory. Their semantics are similar to those of the
'affine.vector_load' and 'affine.vector_store' counterparts, but without the
affine constraints. The most relevant feature is that these new vector
operations may perform a vector load/store on memrefs with a non-vector element
type, unlike 'std.load' and 'std.store' ops. This opens the representation to
model more generic vector load/store scenarios: unaligned vector loads/stores,
scalar and vector memory accesses on the same memref, decoupling memory
allocation constraints from memory accesses, etc. [1]. These operations will
also facilitate the progressive lowering of both Affine vector loads/stores and
Vector transfer reads/writes that read/write contiguous slices from/to memory.
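
As an illustration of the semantics described above, here is a minimal C++
sketch of a 1-D contiguous vector load from a buffer whose element type is
scalar. It is not MLIR code: 'ScalarMemRef2D' and 'vectorLoad' are hypothetical
names introduced only for this example, and the sketch assumes a row-major
default layout and in-bounds accesses.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a memref<RxCxf32> with the default row-major
// layout. The element type is scalar (float), not a vector.
struct ScalarMemRef2D {
  std::size_t rows, cols;
  std::vector<float> data;
  ScalarMemRef2D(std::size_t r, std::size_t c)
      : rows(r), cols(c), data(r * c, 0.0f) {}
  // Scalar access on the same buffer that vectorLoad reads from.
  float &at(std::size_t i, std::size_t j) { return data[i * cols + j]; }
};

// Models a 1-D contiguous load of N elements starting at position [i, j].
// The start offset is linearized as i * cols + j; it does not need to be a
// multiple of N, so the access may be unaligned with respect to N.
template <std::size_t N>
std::array<float, N> vectorLoad(const ScalarMemRef2D &m, std::size_t i,
                                std::size_t j) {
  const std::size_t start = i * m.cols + j;
  // Assumption of this sketch: the slice stays inside the buffer.
  assert(start + N <= m.data.size());
  std::array<float, N> result{};
  std::copy_n(m.data.begin() + start, N, result.begin());
  return result;
}
```

With a 200x100 buffer, 'vectorLoad<8>(m, 2, 3)' starts at linear offset
2 * 100 + 3 = 203, and 'm.at(2, 3)' reads the same element as a scalar.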

In particular, this patch adds the 'vector.load' and 'vector.store' ops to the
Vector dialect, implements their lowering to the LLVM dialect, and changes the
lowering of 'affine.vector_load' and 'affine.vector_store' ops to the new vector
ops. The lowering of Vector transfer reads/writes will be implemented in the
future, probably as an independent pass. The API of 'vector.maskedload' and
'vector.maskedstore' has also been changed slightly to align it with the
transfer read/write ops and the new vector ops. This will improve reusability
among all these operations. For example, the lowering of 'vector.load',
'vector.store', 'vector.maskedload' and 'vector.maskedstore' to the LLVM dialect
is implemented with a single template conversion pattern.
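
The single-template approach can be sketched in plain C++ as follows. This is a
simplified illustration, not the code from the patch: the op structs are empty
hypothetical stand-ins, the helper returns the name of the target op instead of
rewriting IR, and an empty string stands in for a failed match.

```cpp
#include <string>

// Hypothetical stand-ins for the four Vector dialect ops (not MLIR classes).
struct LoadOp {};
struct StoreOp {};
struct MaskedLoadOp {};
struct MaskedStoreOp {};

// One overload per op. Only this op-specific step differs between the four
// lowerings; here it just reports which LLVM dialect op would be created.
static std::string replaceLoadOrStoreOp(LoadOp) { return "LLVM::LoadOp"; }
static std::string replaceLoadOrStoreOp(StoreOp) { return "LLVM::StoreOp"; }
static std::string replaceLoadOrStoreOp(MaskedLoadOp) {
  return "LLVM::MaskedLoadOp";
}
static std::string replaceLoadOrStoreOp(MaskedStoreOp) {
  return "LLVM::MaskedStoreOp";
}

// A single template holds the shared logic and lets overload resolution pick
// the op-specific replacement.
template <class LoadOrStoreOp>
struct LoadStoreConversion {
  std::string matchAndRewrite(LoadOrStoreOp op, unsigned vectorRank) const {
    // Shared check: only 1-D vectors are handled.
    if (vectorRank > 1)
      return "";
    // Shared steps (alignment and address resolution) would go here.
    return replaceLoadOrStoreOp(op);
  }
};
```

Instantiating 'LoadStoreConversion' once per op type keeps the common steps in
one place, so supporting another load/store-like op only requires one more
overload of the helper.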

[1] https://llvm.discourse.group/t/memref-type-and-data-layout/

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D96185

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/test/Conversion/AffineToStandard/lower-affine-to-vector.mlir
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 1aeb92a2faf8..557446bddb00 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1320,6 +1320,156 @@ def Vector_TransferWriteOp :
   let hasFolder = 1;
 }
 
+def Vector_LoadOp : Vector_Op<"load"> {
+  let summary = "reads an n-D slice of memory into an n-D vector";
+  let description = [{
+    The 'vector.load' operation reads an n-D slice of memory into an n-D
+    vector. It takes a 'base' memref, an index for each memref dimension and a
+    result vector type as arguments. It returns a value of the result vector
+    type. The 'base' memref and indices determine the start memory address from
+    which to read. Each index provides an offset for each memref dimension
+    based on the element type of the memref. The shape of the result vector
+    type determines the shape of the slice read from the start memory address.
+    The elements along each dimension of the slice are strided by the memref
+    strides. Only memref with default strides are allowed. These constraints
+    guarantee that elements read along the first dimension of the slice are
+    contiguous in memory.
+
+    The memref element type can be a scalar or a vector type. If the memref
+    element type is a scalar, it should match the element type of the result
+    vector. If the memref element type is vector, it should match the result
+    vector type.
+
+    Example 1: 1-D vector load on a scalar memref.
+    ```mlir
+    %result = vector.load %base[%i, %j] : memref<100x100xf32>, vector<8xf32>
+    ```
+
+    Example 2: 1-D vector load on a vector memref.
+    ```mlir
+    %result = vector.load %memref[%i, %j] : memref<200x100xvector<8xf32>>, vector<8xf32>
+    ```
+
+    Example 3:  2-D vector load on a scalar memref.
+    ```mlir
+    %result = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<4x8xf32>
+    ```
+
+    Example 4:  2-D vector load on a vector memref.
+    ```mlir
+    %result = vector.load %memref[%i, %j] : memref<200x100xvector<4x8xf32>>, vector<4x8xf32>
+    ```
+
+    Representation-wise, the 'vector.load' operation permits out-of-bounds
+    reads. Support and implementation of out-of-bounds vector loads is
+    target-specific. No assumptions should be made on the value of elements
+    loaded out of bounds. Not all targets may support out-of-bounds vector
+    loads.
+
+    Example 5:  Potential out-of-bound vector load.
+    ```mlir
+    %result = vector.load %memref[%index] : memref<?xf32>, vector<8xf32>
+    ```
+
+    Example 6:  Explicit out-of-bound vector load.
+    ```mlir
+    %result = vector.load %memref[%c0] : memref<7xf32>, vector<8xf32>
+    ```
+  }];
+
+  let arguments = (ins Arg<AnyMemRef, "the reference to load from",
+      [MemRead]>:$base,
+      Variadic<Index>:$indices);
+  let results = (outs AnyVector:$result);
+
+  let extraClassDeclaration = [{
+    MemRefType getMemRefType() {
+      return base().getType().cast<MemRefType>();
+    }
+
+    VectorType getVectorType() {
+      return result().getType().cast<VectorType>();
+    }
+  }];
+
+  let assemblyFormat =
+      "$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)";
+}
+
+def Vector_StoreOp : Vector_Op<"store"> {
+  let summary = "writes an n-D vector to an n-D slice of memory";
+  let description = [{
+    The 'vector.store' operation writes an n-D vector to an n-D slice of memory.
+    It takes the vector value to be stored, a 'base' memref and an index for
+    each memref dimension. The 'base' memref and indices determine the start
+    memory address from which to write. Each index provides an offset for each
+    memref dimension based on the element type of the memref. The shape of the
+    vector value to store determines the shape of the slice written from the
+    start memory address. The elements along each dimension of the slice are
+    strided by the memref strides. Only memref with default strides are allowed.
+    These constraints guarantee that elements written along the first dimension
+    of the slice are contiguous in memory.
+
+    The memref element type can be a scalar or a vector type. If the memref
+    element type is a scalar, it should match the element type of the value
+    to store. If the memref element type is vector, it should match the type
+    of the value to store.
+
+    Example 1: 1-D vector store on a scalar memref.
+    ```mlir
+    vector.store %valueToStore, %memref[%i, %j] : memref<200x100xf32>, vector<8xf32>
+    ```
+
+    Example 2: 1-D vector store on a vector memref.
+    ```mlir
+    vector.store %valueToStore, %memref[%i, %j] : memref<200x100xvector<8xf32>>, vector<8xf32>
+    ```
+
+    Example 3:  2-D vector store on a scalar memref.
+    ```mlir
+    vector.store %valueToStore, %memref[%i, %j] : memref<200x100xf32>, vector<4x8xf32>
+    ```
+
+    Example 4:  2-D vector store on a vector memref.
+    ```mlir
+    vector.store %valueToStore, %memref[%i, %j] : memref<200x100xvector<4x8xf32>>, vector<4x8xf32>
+    ```
+
+    Representation-wise, the 'vector.store' operation permits out-of-bounds
+    writes. Support and implementation of out-of-bounds vector stores are
+    target-specific. No assumptions should be made on the memory written out of
+    bounds. Not all targets may support out-of-bounds vector stores.
+
+    Example 5:  Potential out-of-bounds vector store.
+    ```mlir
+    vector.store %valueToStore, %memref[%index] : memref<?xf32>, vector<8xf32>
+    ```
+
+    Example 6:  Explicit out-of-bounds vector store.
+    ```mlir
+    vector.store %valueToStore, %memref[%c0] : memref<7xf32>, vector<8xf32>
+    ```
+  }];
+
+  let arguments = (ins AnyVector:$valueToStore,
+      Arg<AnyMemRef, "the reference to store to",
+      [MemWrite]>:$base,
+      Variadic<Index>:$indices);
+
+  let extraClassDeclaration = [{
+    MemRefType getMemRefType() {
+      return base().getType().cast<MemRefType>();
+    }
+
+    VectorType getVectorType() {
+      return valueToStore().getType().cast<VectorType>();
+    }
+  }];
+
+  let assemblyFormat = "$valueToStore `,` $base `[` $indices `]` attr-dict "
+                       "`:` type($base) `,` type($valueToStore)";
+}
+
 def Vector_MaskedLoadOp :
   Vector_Op<"maskedload">,
     Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
@@ -1363,7 +1513,7 @@ def Vector_MaskedLoadOp :
     VectorType getPassThruVectorType() {
       return pass_thru().getType().cast<VectorType>();
     }
-    VectorType getResultVectorType() {
+    VectorType getVectorType() {
       return result().getType().cast<VectorType>();
     }
   }];
@@ -1377,7 +1527,7 @@ def Vector_MaskedStoreOp :
     Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
                Variadic<Index>:$indices,
                VectorOfRankAndType<[1], [I1]>:$mask,
-               VectorOfRank<[1]>:$value)> {
+               VectorOfRank<[1]>:$valueToStore)> {
 
   let summary = "stores elements from a vector into memory as defined by a mask vector";
 
@@ -1411,12 +1561,13 @@ def Vector_MaskedStoreOp :
     VectorType getMaskVectorType() {
       return mask().getType().cast<VectorType>();
     }
-    VectorType getValueVectorType() {
-      return value().getType().cast<VectorType>();
+    VectorType getVectorType() {
+      return valueToStore().getType().cast<VectorType>();
     }
   }];
-  let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $value attr-dict `:` "
-    "type($base) `,` type($mask) `,` type($value)";
+  let assemblyFormat =
+      "$base `[` $indices `]` `,` $mask `,` $valueToStore "
+      "attr-dict `:` type($base) `,` type($mask) `,` type($valueToStore)";
   let hasCanonicalizer = 1;
 }
 

diff  --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
index 5a2fe919382c..ef04f688ba39 100644
--- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
+++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
@@ -578,8 +578,9 @@ class AffineLoadLowering : public OpRewritePattern<AffineLoadOp> {
     if (!resultOperands)
       return failure();
 
-    // Build std.load memref[expandedMap.results].
-    rewriter.replaceOpWithNewOp<LoadOp>(op, op.getMemRef(), *resultOperands);
+    // Build vector.load memref[expandedMap.results].
+    rewriter.replaceOpWithNewOp<mlir::LoadOp>(op, op.getMemRef(),
+                                              *resultOperands);
     return success();
   }
 };
@@ -625,8 +626,8 @@ class AffineStoreLowering : public OpRewritePattern<AffineStoreOp> {
       return failure();
 
     // Build std.store valueToStore, memref[expandedMap.results].
-    rewriter.replaceOpWithNewOp<StoreOp>(op, op.getValueToStore(),
-                                         op.getMemRef(), *maybeExpandedMap);
+    rewriter.replaceOpWithNewOp<mlir::StoreOp>(
+        op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap);
     return success();
   }
 };
@@ -695,8 +696,8 @@ class AffineDmaWaitLowering : public OpRewritePattern<AffineDmaWaitOp> {
 };
 
 /// Apply the affine map from an 'affine.vector_load' operation to its operands,
-/// and feed the results to a newly created 'vector.transfer_read' operation
-/// (which replaces the original 'affine.vector_load').
+/// and feed the results to a newly created 'vector.load' operation (which
+/// replaces the original 'affine.vector_load').
 class AffineVectorLoadLowering : public OpRewritePattern<AffineVectorLoadOp> {
 public:
   using OpRewritePattern<AffineVectorLoadOp>::OpRewritePattern;
@@ -710,16 +711,16 @@ class AffineVectorLoadLowering : public OpRewritePattern<AffineVectorLoadOp> {
     if (!resultOperands)
       return failure();
 
-    // Build vector.transfer_read memref[expandedMap.results].
-    rewriter.replaceOpWithNewOp<TransferReadOp>(
+    // Build vector.load memref[expandedMap.results].
+    rewriter.replaceOpWithNewOp<vector::LoadOp>(
         op, op.getVectorType(), op.getMemRef(), *resultOperands);
     return success();
   }
 };
 
 /// Apply the affine map from an 'affine.vector_store' operation to its
-/// operands, and feed the results to a newly created 'vector.transfer_write'
-/// operation (which replaces the original 'affine.vector_store').
+/// operands, and feed the results to a newly created 'vector.store' operation
+/// (which replaces the original 'affine.vector_store').
 class AffineVectorStoreLowering : public OpRewritePattern<AffineVectorStoreOp> {
 public:
   using OpRewritePattern<AffineVectorStoreOp>::OpRewritePattern;
@@ -733,7 +734,7 @@ class AffineVectorStoreLowering : public OpRewritePattern<AffineVectorStoreOp> {
     if (!maybeExpandedMap)
       return failure();
 
-    rewriter.replaceOpWithNewOp<TransferWriteOp>(
+    rewriter.replaceOpWithNewOp<vector::StoreOp>(
         op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap);
     return success();
   }

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 54cdd9cfde60..3393bb702a78 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -357,64 +357,72 @@ class VectorFlatTransposeOpConversion
   }
 };
 
-/// Conversion pattern for a vector.maskedload.
-class VectorMaskedLoadOpConversion
-    : public ConvertOpToLLVMPattern<vector::MaskedLoadOp> {
-public:
-  using ConvertOpToLLVMPattern<vector::MaskedLoadOp>::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(vector::MaskedLoadOp load, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto loc = load->getLoc();
-    auto adaptor = vector::MaskedLoadOpAdaptor(operands);
-    MemRefType memRefType = load.getMemRefType();
+/// Overloaded utility that replaces a vector.load, vector.store,
+/// vector.maskedload and vector.maskedstore with their respective LLVM
+/// counterparts.
+static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
+                                 vector::LoadOpAdaptor adaptor,
+                                 VectorType vectorTy, Value ptr, unsigned align,
+                                 ConversionPatternRewriter &rewriter) {
+  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, ptr, align);
+}
 
-    // Resolve alignment.
-    unsigned align;
-    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
-      return failure();
+static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
+                                 vector::MaskedLoadOpAdaptor adaptor,
+                                 VectorType vectorTy, Value ptr, unsigned align,
+                                 ConversionPatternRewriter &rewriter) {
+  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
+      loadOp, vectorTy, ptr, adaptor.mask(), adaptor.pass_thru(), align);
+}
 
-    // Resolve address.
-    auto vtype = typeConverter->convertType(load.getResultVectorType());
-    Value dataPtr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
-                                               adaptor.indices(), rewriter);
-    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefType, vtype);
+static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
+                                 vector::StoreOpAdaptor adaptor,
+                                 VectorType vectorTy, Value ptr, unsigned align,
+                                 ConversionPatternRewriter &rewriter) {
+  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.valueToStore(),
+                                             ptr, align);
+}
 
-    rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
-        load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(),
-        rewriter.getI32IntegerAttr(align));
-    return success();
-  }
-};
+static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
+                                 vector::MaskedStoreOpAdaptor adaptor,
+                                 VectorType vectorTy, Value ptr, unsigned align,
+                                 ConversionPatternRewriter &rewriter) {
+  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
+      storeOp, adaptor.valueToStore(), ptr, adaptor.mask(), align);
+}
 
-/// Conversion pattern for a vector.maskedstore.
-class VectorMaskedStoreOpConversion
-    : public ConvertOpToLLVMPattern<vector::MaskedStoreOp> {
+/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
+/// vector.maskedstore.
+template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor>
+class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
 public:
-  using ConvertOpToLLVMPattern<vector::MaskedStoreOp>::ConvertOpToLLVMPattern;
+  using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(vector::MaskedStoreOp store, ArrayRef<Value> operands,
+  matchAndRewrite(LoadOrStoreOp loadOrStoreOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    auto loc = store->getLoc();
-    auto adaptor = vector::MaskedStoreOpAdaptor(operands);
-    MemRefType memRefType = store.getMemRefType();
+    // Only 1-D vectors can be lowered to LLVM.
+    VectorType vectorTy = loadOrStoreOp.getVectorType();
+    if (vectorTy.getRank() > 1)
+      return failure();
+
+    auto loc = loadOrStoreOp->getLoc();
+    auto adaptor = LoadOrStoreOpAdaptor(operands);
+    MemRefType memRefTy = loadOrStoreOp.getMemRefType();
 
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
+    if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
       return failure();
 
     // Resolve address.
-    auto vtype = typeConverter->convertType(store.getValueVectorType());
-    Value dataPtr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
+    auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType())
+                     .template cast<VectorType>();
+    Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.base(),
                                                adaptor.indices(), rewriter);
-    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefType, vtype);
+    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype);
 
-    rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
-        store, adaptor.value(), ptr, adaptor.mask(),
-        rewriter.getI32IntegerAttr(align));
+    replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
     return success();
   }
 };
@@ -1511,8 +1519,14 @@ void mlir::populateVectorToLLVMConversionPatterns(
               VectorInsertOpConversion,
               VectorPrintOpConversion,
               VectorTypeCastOpConversion,
-              VectorMaskedLoadOpConversion,
-              VectorMaskedStoreOpConversion,
+              VectorLoadStoreConversion<vector::LoadOp,
+                                        vector::LoadOpAdaptor>,
+              VectorLoadStoreConversion<vector::MaskedLoadOp,
+                                        vector::MaskedLoadOpAdaptor>,
+              VectorLoadStoreConversion<vector::StoreOp,
+                                        vector::StoreOpAdaptor>,
+              VectorLoadStoreConversion<vector::MaskedStoreOp,
+                                        vector::MaskedStoreOpAdaptor>,
               VectorGatherOpConversion,
               VectorScatterOpConversion,
               VectorExpandLoadOpConversion,

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 99b978895c7e..a56b49a315d8 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -2373,6 +2373,67 @@ void TransferWriteOp::getEffects(
                          SideEffects::DefaultResource::get());
 }
 
+//===----------------------------------------------------------------------===//
+// LoadOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
+                                                 MemRefType memRefTy) {
+  auto affineMaps = memRefTy.getAffineMaps();
+  if (!affineMaps.empty())
+    return op->emitOpError("base memref should have a default identity layout");
+  return success();
+}
+
+static LogicalResult verify(vector::LoadOp op) {
+  VectorType resVecTy = op.getVectorType();
+  MemRefType memRefTy = op.getMemRefType();
+
+  if (failed(verifyLoadStoreMemRefLayout(op, memRefTy)))
+    return failure();
+
+  // Checks for vector memrefs.
+  Type memElemTy = memRefTy.getElementType();
+  if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) {
+    if (memVecTy != resVecTy)
+      return op.emitOpError("base memref and result vector types should match");
+    memElemTy = memVecTy.getElementType();
+  }
+
+  if (resVecTy.getElementType() != memElemTy)
+    return op.emitOpError("base and result element types should match");
+  if (llvm::size(op.indices()) != memRefTy.getRank())
+    return op.emitOpError("requires ") << memRefTy.getRank() << " indices";
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// StoreOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(vector::StoreOp op) {
+  VectorType valueVecTy = op.getVectorType();
+  MemRefType memRefTy = op.getMemRefType();
+
+  if (failed(verifyLoadStoreMemRefLayout(op, memRefTy)))
+    return failure();
+
+  // Checks for vector memrefs.
+  Type memElemTy = memRefTy.getElementType();
+  if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) {
+    if (memVecTy != valueVecTy)
+      return op.emitOpError(
+          "base memref and valueToStore vector types should match");
+    memElemTy = memVecTy.getElementType();
+  }
+
+  if (valueVecTy.getElementType() != memElemTy)
+    return op.emitOpError("base and valueToStore element type should match");
+  if (llvm::size(op.indices()) != memRefTy.getRank())
+    return op.emitOpError("requires ") << memRefTy.getRank() << " indices";
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // MaskedLoadOp
 //===----------------------------------------------------------------------===//
@@ -2380,7 +2441,7 @@ void TransferWriteOp::getEffects(
 static LogicalResult verify(MaskedLoadOp op) {
   VectorType maskVType = op.getMaskVectorType();
   VectorType passVType = op.getPassThruVectorType();
-  VectorType resVType = op.getResultVectorType();
+  VectorType resVType = op.getVectorType();
   MemRefType memType = op.getMemRefType();
 
   if (resVType.getElementType() != memType.getElementType())
@@ -2427,15 +2488,15 @@ void MaskedLoadOp::getCanonicalizationPatterns(
 
 static LogicalResult verify(MaskedStoreOp op) {
   VectorType maskVType = op.getMaskVectorType();
-  VectorType valueVType = op.getValueVectorType();
+  VectorType valueVType = op.getVectorType();
   MemRefType memType = op.getMemRefType();
 
   if (valueVType.getElementType() != memType.getElementType())
-    return op.emitOpError("base and value element type should match");
+    return op.emitOpError("base and valueToStore element type should match");
   if (llvm::size(op.indices()) != memType.getRank())
     return op.emitOpError("requires ") << memType.getRank() << " indices";
   if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
-    return op.emitOpError("expected value dim to match mask dim");
+    return op.emitOpError("expected valueToStore dim to match mask dim");
   return success();
 }
 
@@ -2448,7 +2509,7 @@ class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
     switch (get1DMaskFormat(store.mask())) {
     case MaskFormat::AllTrue:
       rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-          store, store.value(), store.base(), store.indices(), false);
+          store, store.valueToStore(), store.base(), store.indices(), false);
       return success();
     case MaskFormat::AllFalse:
       rewriter.eraseOp(store);

diff  --git a/mlir/test/Conversion/AffineToStandard/lower-affine-to-vector.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine-to-vector.mlir
index 7fba0996d8f5..3df9bb33ece2 100644
--- a/mlir/test/Conversion/AffineToStandard/lower-affine-to-vector.mlir
+++ b/mlir/test/Conversion/AffineToStandard/lower-affine-to-vector.mlir
@@ -1,41 +1,5 @@
 // RUN: mlir-opt -lower-affine --split-input-file %s | FileCheck %s
 
-// CHECK-LABEL: func @affine_vector_load
-func @affine_vector_load(%arg0 : index) {
-  %0 = alloc() : memref<100xf32>
-  affine.for %i0 = 0 to 16 {
-    %1 = affine.vector_load %0[%i0 + symbol(%arg0) + 7] : memref<100xf32>, vector<8xf32>
-  }
-// CHECK:       %[[buf:.*]] = alloc
-// CHECK:       %[[a:.*]] = addi %{{.*}}, %{{.*}} : index
-// CHECK-NEXT:  %[[c7:.*]] = constant 7 : index
-// CHECK-NEXT:  %[[b:.*]] = addi %[[a]], %[[c7]] : index
-// CHECK-NEXT:  %[[pad:.*]] = constant 0.0
-// CHECK-NEXT:  vector.transfer_read %[[buf]][%[[b]]], %[[pad]] : memref<100xf32>, vector<8xf32>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: func @affine_vector_store
-func @affine_vector_store(%arg0 : index) {
-  %0 = alloc() : memref<100xf32>
-  %1 = constant dense<11.0> : vector<4xf32>
-  affine.for %i0 = 0 to 16 {
-    affine.vector_store %1, %0[%i0 - symbol(%arg0) + 7] : memref<100xf32>, vector<4xf32>
-}
-// CHECK:       %[[buf:.*]] = alloc
-// CHECK:       %[[val:.*]] = constant dense
-// CHECK:       %[[c_1:.*]] = constant -1 : index
-// CHECK-NEXT:  %[[a:.*]] = muli %arg0, %[[c_1]] : index
-// CHECK-NEXT:  %[[b:.*]] = addi %{{.*}}, %[[a]] : index
-// CHECK-NEXT:  %[[c7:.*]] = constant 7 : index
-// CHECK-NEXT:  %[[c:.*]] = addi %[[b]], %[[c7]] : index
-// CHECK-NEXT:  vector.transfer_write  %[[val]], %[[buf]][%[[c]]] : vector<4xf32>, memref<100xf32>
-  return
-}
-
-// -----
 
 // CHECK-LABEL: func @affine_vector_load
 func @affine_vector_load(%arg0 : index) {
@@ -47,8 +11,7 @@ func @affine_vector_load(%arg0 : index) {
 // CHECK:       %[[a:.*]] = addi %{{.*}}, %{{.*}} : index
 // CHECK-NEXT:  %[[c7:.*]] = constant 7 : index
 // CHECK-NEXT:  %[[b:.*]] = addi %[[a]], %[[c7]] : index
-// CHECK-NEXT:  %[[pad:.*]] = constant 0.0
-// CHECK-NEXT:  vector.transfer_read %[[buf]][%[[b]]], %[[pad]] : memref<100xf32>, vector<8xf32>
+// CHECK-NEXT:  vector.load %[[buf]][%[[b]]] : memref<100xf32>, vector<8xf32>
   return
 }
 
@@ -68,7 +31,7 @@ func @affine_vector_store(%arg0 : index) {
 // CHECK-NEXT:  %[[b:.*]] = addi %{{.*}}, %[[a]] : index
 // CHECK-NEXT:  %[[c7:.*]] = constant 7 : index
 // CHECK-NEXT:  %[[c:.*]] = addi %[[b]], %[[c7]] : index
-// CHECK-NEXT:  vector.transfer_write  %[[val]], %[[buf]][%[[c]]] : vector<4xf32>, memref<100xf32>
+// CHECK-NEXT:  vector.store %[[val]], %[[buf]][%[[c]]] : memref<100xf32>, vector<4xf32>
   return
 }
 
@@ -83,8 +46,7 @@ func @vector_load_2d() {
 // CHECK:      %[[buf:.*]] = alloc
 // CHECK:      scf.for %[[i0:.*]] =
 // CHECK:        scf.for %[[i1:.*]] =
-// CHECK-NEXT:     %[[pad:.*]] = constant 0.0
-// CHECK-NEXT:     vector.transfer_read %[[buf]][%[[i0]], %[[i1]]], %[[pad]] : memref<100x100xf32>, vector<2x8xf32>
+// CHECK-NEXT:     vector.load %[[buf]][%[[i0]], %[[i1]]] : memref<100x100xf32>, vector<2x8xf32>
     }
   }
   return
@@ -103,9 +65,8 @@ func @vector_store_2d() {
 // CHECK:      %[[val:.*]] = constant dense
 // CHECK:      scf.for %[[i0:.*]] =
 // CHECK:        scf.for %[[i1:.*]] =
-// CHECK-NEXT:     vector.transfer_write  %[[val]], %[[buf]][%[[i0]], %[[i1]]] : vector<2x8xf32>, memref<100x100xf32>
+// CHECK-NEXT:     vector.store %[[val]], %[[buf]][%[[i0]], %[[i1]]] : memref<100x100xf32>, vector<2x8xf32>
     }
   }
   return
 }
-

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index facc91cf03d0..3d1294318a6e 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -23,6 +23,7 @@ func @bitcast_i8_to_f32_vector(%input: vector<64xi8>) -> vector<16xf32> {
 
 // -----
 
+
 func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> {
   %0 = vector.broadcast %arg0 : f32 to vector<2xf32>
   return %0 : vector<2xf32>
@@ -1242,6 +1243,33 @@ func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
 
 // -----
 
+func @vector_load_op(%memref : memref<200x100xf32>, %i : index, %j : index) -> vector<8xf32> {
+  %0 = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @vector_load_op
+// CHECK: %[[c100:.*]] = llvm.mlir.constant(100 : index) : i64
+// CHECK: %[[mul:.*]] = llvm.mul %{{.*}}, %[[c100]]  : i64
+// CHECK: %[[add:.*]] = llvm.add %[[mul]], %{{.*}}  : i64
+// CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}}[%[[add]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+// CHECK: %[[bcast:.*]] = llvm.bitcast %[[gep]] : !llvm.ptr<f32> to !llvm.ptr<vector<8xf32>>
+// CHECK: llvm.load %[[bcast]] {alignment = 4 : i64} : !llvm.ptr<vector<8xf32>>
+
+func @vector_store_op(%memref : memref<200x100xf32>, %i : index, %j : index) {
+  %val = constant dense<11.0> : vector<4xf32>
+  vector.store %val, %memref[%i, %j] : memref<200x100xf32>, vector<4xf32>
+  return
+}
+
+// CHECK-LABEL: func @vector_store_op
+// CHECK: %[[c100:.*]] = llvm.mlir.constant(100 : index) : i64
+// CHECK: %[[mul:.*]] = llvm.mul %{{.*}}, %[[c100]]  : i64
+// CHECK: %[[add:.*]] = llvm.add %[[mul]], %{{.*}}  : i64
+// CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}}[%[[add]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+// CHECK: %[[bcast:.*]] = llvm.bitcast %[[gep]] : !llvm.ptr<f32> to !llvm.ptr<vector<4xf32>>
+// CHECK: llvm.store %{{.*}}, %[[bcast]] {alignment = 4 : i64} : !llvm.ptr<vector<4xf32>>
+
 func @masked_load_op(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) -> vector<16xf32> {
   %c0 = constant 0: index
   %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 099dad7eada4..ab58fdc37ccf 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1198,6 +1198,38 @@ func @type_cast_layout(%arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] ->
 
 // -----
 
+func @store_unsupported_layout(%memref : memref<200x100xf32, affine_map<(d0, d1) -> (d1, d0)>>,
+                               %i : index, %j : index, %value : vector<8xf32>) {
+  // expected-error @+1 {{'vector.store' op base memref should have a default identity layout}}
+  vector.store %value, %memref[%i, %j] : memref<200x100xf32, affine_map<(d0, d1) -> (d1, d0)>>,
+                                         vector<8xf32>
+}
+
+// -----
+
+func @vector_memref_mismatch(%memref : memref<200x100xvector<4xf32>>, %i : index,
+                             %j : index, %value : vector<8xf32>) {
+  // expected-error @+1 {{'vector.store' op base memref and valueToStore vector types should match}}
+  vector.store %value, %memref[%i, %j] : memref<200x100xvector<4xf32>>, vector<8xf32>
+}
+
+// -----
+
+func @store_base_type_mismatch(%base : memref<?xf64>, %value : vector<16xf32>) {
+  %c0 = constant 0 : index
+  // expected-error @+1 {{'vector.store' op base and valueToStore element type should match}}
+  vector.store %value, %base[%c0] : memref<?xf64>, vector<16xf32>
+}
+
+// -----
+
+func @store_memref_index_mismatch(%base : memref<?xf32>, %value : vector<16xf32>) {
+  // expected-error @+1 {{'vector.store' op requires 1 indices}}
+  vector.store %value, %base[] : memref<?xf32>, vector<16xf32>
+}
+
+// -----
+
 func @maskedload_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %pass: vector<16xf32>) {
   %c0 = constant 0 : index
   // expected-error @+1 {{'vector.maskedload' op base and result element type should match}}
@@ -1231,7 +1263,7 @@ func @maskedload_memref_mismatch(%base: memref<?xf32>, %mask: vector<16xi1>, %pa
 
 func @maskedstore_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %value: vector<16xf32>) {
   %c0 = constant 0 : index
-  // expected-error @+1 {{'vector.maskedstore' op base and value element type should match}}
+  // expected-error @+1 {{'vector.maskedstore' op base and valueToStore element type should match}}
   vector.maskedstore %base[%c0], %mask, %value : memref<?xf64>, vector<16xi1>, vector<16xf32>
 }
 
@@ -1239,7 +1271,7 @@ func @maskedstore_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>,
 
 func @maskedstore_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<15xi1>, %value: vector<16xf32>) {
   %c0 = constant 0 : index
-  // expected-error @+1 {{'vector.maskedstore' op expected value dim to match mask dim}}
+  // expected-error @+1 {{'vector.maskedstore' op expected valueToStore dim to match mask dim}}
   vector.maskedstore %base[%c0], %mask, %value : memref<?xf32>, vector<15xi1>, vector<16xf32>
 }
 

diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 7284cab523a7..11197f1e0bee 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -450,6 +450,56 @@ func @flat_transpose_int(%arg0: vector<16xi32>) -> vector<16xi32> {
   return %0 : vector<16xi32>
 }
 
+// CHECK-LABEL: @vector_load_and_store_1d_scalar_memref
+func @vector_load_and_store_1d_scalar_memref(%memref : memref<200x100xf32>,
+                                             %i : index, %j : index) {
+  // CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<200x100xf32>, vector<8xf32>
+  %0 = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<8xf32>
+  // CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<200x100xf32>, vector<8xf32>
+  vector.store %0, %memref[%i, %j] : memref<200x100xf32>, vector<8xf32>
+  return
+}
+
+// CHECK-LABEL: @vector_load_and_store_1d_vector_memref
+func @vector_load_and_store_1d_vector_memref(%memref : memref<200x100xvector<8xf32>>,
+                                             %i : index, %j : index) {
+  // CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<200x100xvector<8xf32>>, vector<8xf32>
+  %0 = vector.load %memref[%i, %j] : memref<200x100xvector<8xf32>>, vector<8xf32>
+  // CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<200x100xvector<8xf32>>, vector<8xf32>
+  vector.store %0, %memref[%i, %j] : memref<200x100xvector<8xf32>>, vector<8xf32>
+  return
+}
+
+// CHECK-LABEL: @vector_load_and_store_out_of_bounds
+func @vector_load_and_store_out_of_bounds(%memref : memref<7xf32>) {
+  %c0 = constant 0 : index
+  // CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<7xf32>, vector<8xf32>
+  %0 = vector.load %memref[%c0] : memref<7xf32>, vector<8xf32>
+  // CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<7xf32>, vector<8xf32>
+  vector.store %0, %memref[%c0] : memref<7xf32>, vector<8xf32>
+  return
+}
+
+// CHECK-LABEL: @vector_load_and_store_2d_scalar_memref
+func @vector_load_and_store_2d_scalar_memref(%memref : memref<200x100xf32>,
+                                             %i : index, %j : index) {
+  // CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<200x100xf32>, vector<4x8xf32>
+  %0 = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<4x8xf32>
+  // CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<200x100xf32>, vector<4x8xf32>
+  vector.store %0, %memref[%i, %j] : memref<200x100xf32>, vector<4x8xf32>
+  return
+}
+
+// CHECK-LABEL: @vector_load_and_store_2d_vector_memref
+func @vector_load_and_store_2d_vector_memref(%memref : memref<200x100xvector<4x8xf32>>,
+                                             %i : index, %j : index) {
+  // CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<200x100xvector<4x8xf32>>, vector<4x8xf32>
+  %0 = vector.load %memref[%i, %j] : memref<200x100xvector<4x8xf32>>, vector<4x8xf32>
+  // CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<200x100xvector<4x8xf32>>, vector<4x8xf32>
+  vector.store %0, %memref[%i, %j] : memref<200x100xvector<4x8xf32>>, vector<4x8xf32>
+  return
+}
+
 // CHECK-LABEL: @masked_load_and_store
 func @masked_load_and_store(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) {
   %c0 = constant 0 : index


        

