[Mlir-commits] [mlir] bc5565f - [mlir][Affine] Introduce affine.vector_load and affine.vector_store

Diego Caballero llvmlistbot at llvm.org
Thu May 14 13:25:32 PDT 2020


Author: Diego Caballero
Date: 2020-05-14T13:17:58-07:00
New Revision: bc5565f9ea7aa7d3815a3554a0c937c7b48c7dcd

URL: https://github.com/llvm/llvm-project/commit/bc5565f9ea7aa7d3815a3554a0c937c7b48c7dcd
DIFF: https://github.com/llvm/llvm-project/commit/bc5565f9ea7aa7d3815a3554a0c937c7b48c7dcd.diff

LOG: [mlir][Affine] Introduce affine.vector_load and affine.vector_store

This patch adds `affine.vector_load` and `affine.vector_store` ops to
the Affine dialect and lowers them to `vector.transfer_read` and
`vector.transfer_write`, respectively, in the Vector dialect.
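
For illustration, the lowering takes roughly the following shape (a sketch
adapted from the new `lower-affine-to-vector.mlir` test; the zero padding
value and the minor identity `permutation_map` come from the new
`vector.transfer_read`/`vector.transfer_write` builder defaults added here):

```mlir
// Before: affine vector load with an affine subscript.
%v = affine.vector_load %buf[%i0 + 7] : memref<100xf32>, vector<8xf32>

// After -lower-affine: the subscript is expanded to standard arithmetic
// and the access becomes a vector.transfer_read with zero padding.
%c7  = constant 7 : index
%idx = addi %i0, %c7 : index
%pad = constant 0.000000e+00 : f32
%v = vector.transfer_read %buf[%idx], %pad
       {permutation_map = affine_map<(d0) -> (d0)>}
     : memref<100xf32>, vector<8xf32>
```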

Reviewed By: bondhugula, nicolasvasilache

Differential Revision: https://reviews.llvm.org/D79658

Added: 
    mlir/test/Conversion/AffineToStandard/lower-affine-to-vector.mlir

Modified: 
    mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h
    mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
    mlir/lib/Dialect/Affine/IR/AffineOps.cpp
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/test/Dialect/Affine/invalid.mlir
    mlir/test/Dialect/Affine/load-store.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h b/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h
index 967c74c5c23f..5d04f157b8ce 100644
--- a/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h
+++ b/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h
@@ -44,6 +44,11 @@ Optional<SmallVector<Value, 8>> expandAffineMap(OpBuilder &builder,
 void populateAffineToStdConversionPatterns(OwningRewritePatternList &patterns,
                                            MLIRContext *ctx);
 
+/// Collect a set of patterns to convert vector-related Affine ops to the Vector
+/// dialect.
+void populateAffineToVectorConversionPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
+
 /// Emit code that computes the lower bound of the given affine loop using
 /// standard arithmetic operations.
 Value lowerAffineLowerBound(AffineForOp op, OpBuilder &builder);

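The new entry point mirrors `populateAffineToStdConversionPatterns` and is
wired into `LowerAffinePass` further down in this patch. As a sketch, a
downstream pass that only wants the vector-related lowerings might use it as
follows (the helper name and conversion target here are illustrative, not
part of this patch):

```cpp
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Hypothetical helper: rewrite only affine.vector_load/affine.vector_store
// into vector.transfer_read/vector.transfer_write, leaving the rest of the
// Affine dialect untouched.
static LogicalResult lowerAffineVectorOps(Operation *op) {
  OwningRewritePatternList patterns;
  populateAffineToVectorConversionPatterns(patterns, op->getContext());

  ConversionTarget target(*op->getContext());
  target.addLegalDialect<vector::VectorDialect>();
  return applyPartialConversion(op, target, patterns);
}
```
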
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index bb02f8b18c5d..8286d8f315bd 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -370,7 +370,44 @@ def AffineIfOp : Affine_Op<"if",
   let hasFolder = 1;
 }
 
-def AffineLoadOp : Affine_Op<"load", []> {
+class AffineLoadOpBase<string mnemonic, list<OpTrait> traits = []> :
+    Affine_Op<mnemonic, traits> {
+  let arguments = (ins Arg<AnyMemRef, "the reference to load from",
+      [MemRead]>:$memref,
+      Variadic<Index>:$indices);
+
+  code extraClassDeclarationBase = [{
+    /// Returns the operand index of the memref.
+    unsigned getMemRefOperandIndex() { return 0; }
+
+    /// Get memref operand.
+    Value getMemRef() { return getOperand(getMemRefOperandIndex()); }
+    void setMemRef(Value value) { setOperand(getMemRefOperandIndex(), value); }
+    MemRefType getMemRefType() {
+      return getMemRef().getType().cast<MemRefType>();
+    }
+
+    /// Get affine map operands.
+    operand_range getMapOperands() { return llvm::drop_begin(getOperands(), 1); }
+
+    /// Returns the affine map used to index the memref for this operation.
+    AffineMap getAffineMap() { return getAffineMapAttr().getValue(); }
+    AffineMapAttr getAffineMapAttr() {
+      return getAttr(getMapAttrName()).cast<AffineMapAttr>();
+    }
+
+    /// Returns the AffineMapAttr associated with 'memref'.
+    NamedAttribute getAffineMapAttrForMemRef(Value memref) {
+      assert(memref == getMemRef());
+      return {Identifier::get(getMapAttrName(), getContext()),
+              getAffineMapAttr()};
+    }
+
+    static StringRef getMapAttrName() { return "map"; }
+  }];
+}
+
+def AffineLoadOp : AffineLoadOpBase<"load", []> {
   let summary = "affine load operation";
   let description = [{
     The "affine.load" op reads an element from a memref, where the index
@@ -393,9 +430,6 @@ def AffineLoadOp : Affine_Op<"load", []> {
     ```
   }];
 
-  let arguments = (ins Arg<AnyMemRef, "the reference to load from",
-      [MemRead]>:$memref,
-      Variadic<Index>:$indices);
   let results = (outs AnyType:$result);
 
   let builders = [
@@ -410,35 +444,7 @@ def AffineLoadOp : Affine_Op<"load", []> {
                       "AffineMap map, ValueRange mapOperands">
   ];
 
-  let extraClassDeclaration = [{
-    /// Returns the operand index of the memref.
-    unsigned getMemRefOperandIndex() { return 0; }
-
-    /// Get memref operand.
-    Value getMemRef() { return getOperand(getMemRefOperandIndex()); }
-    void setMemRef(Value value) { setOperand(getMemRefOperandIndex(), value); }
-    MemRefType getMemRefType() {
-      return getMemRef().getType().cast<MemRefType>();
-    }
-
-    /// Get affine map operands.
-    operand_range getMapOperands() { return llvm::drop_begin(getOperands(), 1); }
-
-    /// Returns the affine map used to index the memref for this operation.
-    AffineMap getAffineMap() { return getAffineMapAttr().getValue(); }
-    AffineMapAttr getAffineMapAttr() {
-      return getAttr(getMapAttrName()).cast<AffineMapAttr>();
-    }
-
-    /// Returns the AffineMapAttr associated with 'memref'.
-    NamedAttribute getAffineMapAttrForMemRef(Value memref) {
-      assert(memref == getMemRef());
-      return {Identifier::get(getMapAttrName(), getContext()),
-              getAffineMapAttr()};
-    }
-
-    static StringRef getMapAttrName() { return "map"; }
-  }];
+  let extraClassDeclaration = extraClassDeclarationBase;
 
   let hasCanonicalizer = 1;
   let hasFolder = 1;
@@ -659,7 +665,45 @@ def AffinePrefetchOp : Affine_Op<"prefetch"> {
   let hasFolder = 1;
 }
 
-def AffineStoreOp : Affine_Op<"store", []> {
+class AffineStoreOpBase<string mnemonic, list<OpTrait> traits = []> :
+    Affine_Op<mnemonic, traits> {
+
+  code extraClassDeclarationBase = [{
+    /// Get value to be stored by store operation.
+    Value getValueToStore() { return getOperand(0); }
+
+    /// Returns the operand index of the memref.
+    unsigned getMemRefOperandIndex() { return 1; }
+
+    /// Get memref operand.
+    Value getMemRef() { return getOperand(getMemRefOperandIndex()); }
+    void setMemRef(Value value) { setOperand(getMemRefOperandIndex(), value); }
+
+    MemRefType getMemRefType() {
+      return getMemRef().getType().cast<MemRefType>();
+    }
+
+    /// Get affine map operands.
+    operand_range getMapOperands() { return llvm::drop_begin(getOperands(), 2); }
+
+    /// Returns the affine map used to index the memref for this operation.
+    AffineMap getAffineMap() { return getAffineMapAttr().getValue(); }
+    AffineMapAttr getAffineMapAttr() {
+      return getAttr(getMapAttrName()).cast<AffineMapAttr>();
+    }
+
+    /// Returns the AffineMapAttr associated with 'memref'.
+    NamedAttribute getAffineMapAttrForMemRef(Value memref) {
+      assert(memref == getMemRef());
+      return {Identifier::get(getMapAttrName(), getContext()),
+              getAffineMapAttr()};
+    }
+
+    static StringRef getMapAttrName() { return "map"; }
+  }];
+}
+
+def AffineStoreOp : AffineStoreOpBase<"store", []> {
   let summary = "affine store operation";
   let description = [{
     The "affine.store" op writes an element to a memref, where the index
@@ -686,7 +730,6 @@ def AffineStoreOp : Affine_Op<"store", []> {
       [MemWrite]>:$memref,
       Variadic<Index>:$indices);
 
-
   let skipDefaultBuilders = 1;
   let builders = [
     OpBuilder<"OpBuilder &builder, OperationState &result, "
@@ -696,39 +739,7 @@ def AffineStoreOp : Affine_Op<"store", []> {
                       "ValueRange mapOperands">
   ];
 
-  let extraClassDeclaration = [{
-    /// Get value to be stored by store operation.
-    Value getValueToStore() { return getOperand(0); }
-
-    /// Returns the operand index of the memref.
-    unsigned getMemRefOperandIndex() { return 1; }
-
-    /// Get memref operand.
-    Value getMemRef() { return getOperand(getMemRefOperandIndex()); }
-    void setMemRef(Value value) { setOperand(getMemRefOperandIndex(), value); }
-
-    MemRefType getMemRefType() {
-      return getMemRef().getType().cast<MemRefType>();
-    }
-
-    /// Get affine map operands.
-    operand_range getMapOperands() { return llvm::drop_begin(getOperands(), 2); }
-
-    /// Returns the affine map used to index the memref for this operation.
-    AffineMap getAffineMap() { return getAffineMapAttr().getValue(); }
-    AffineMapAttr getAffineMapAttr() {
-      return getAttr(getMapAttrName()).cast<AffineMapAttr>();
-    }
-
-    /// Returns the AffineMapAttr associated with 'memref'.
-    NamedAttribute getAffineMapAttrForMemRef(Value memref) {
-      assert(memref == getMemRef());
-      return {Identifier::get(getMapAttrName(), getContext()),
-              getAffineMapAttr()};
-    }
-
-    static StringRef getMapAttrName() { return "map"; }
-  }];
+  let extraClassDeclaration = extraClassDeclarationBase;
 
   let hasCanonicalizer = 1;
   let hasFolder = 1;
@@ -765,4 +776,107 @@ def AffineTerminatorOp :
   let verifier = ?;
 }
 
+def AffineVectorLoadOp : AffineLoadOpBase<"vector_load", []> {
+  let summary = "affine vector load operation";
+  let description = [{
+    The "affine.vector_load" is the vector counterpart of
+    [affine.load](#affineload-operation). It reads a slice from a
+    [MemRef](../LangRef.md#memref-type), supplied as its first operand,
+    into a [vector](../LangRef.md#vector-type) of the same base elemental type.
+    The index for each memref dimension is an affine expression of loop induction
+    variables and symbols. These indices determine the start position of the read
+    within the memref. The shape of the return vector type determines the shape of
+    the slice read from the memref. This slice is contiguous along the respective
+    dimensions of the shape. Strided vector loads will be supported in the future.
+    An affine expression of loop IVs and symbols must be specified for each
+    dimension of the memref. The keyword 'symbol' can be used to indicate SSA
+    identifiers which are symbolic.
+
+    Example 1: 8-wide f32 vector load.
+
+    ```mlir
+    %1 = affine.vector_load %0[%i0 + 3, %i1 + 7] : memref<100x100xf32>, vector<8xf32>
+    ```
+
+    Example 2: 4-wide f32 vector load. Uses 'symbol' keyword for symbols '%n' and '%m'.
+
+    ```mlir
+    %1 = affine.vector_load %0[%i0 + symbol(%n), %i1 + symbol(%m)] : memref<100x100xf32>, vector<4xf32>
+    ```
+
+    Example 3: 2-dim f32 vector load.
+
+    ```mlir
+    %1 = affine.vector_load %0[%i0, %i1] : memref<100x100xf32>, vector<2x8xf32>
+    ```
+
+    TODOs:
+    * Add support for strided vector loads.
+    * Consider adding a permutation map to permute the slice that is read from memory
+    (see [vector.transfer_read](../Vector/#vectortransfer_read-vectortransferreadop)).
+  }];
+
+  let results = (outs AnyVector:$result);
+
+  let extraClassDeclaration = extraClassDeclarationBase # [{
+    VectorType getVectorType() {
+      return result().getType().cast<VectorType>();
+    }
+  }];
+}
+
+def AffineVectorStoreOp : AffineStoreOpBase<"vector_store", []> {
+  let summary = "affine vector store operation";
+  let description = [{
+    The "affine.vector_store" is the vector counterpart of
+    [affine.store](#affinestore-affinestoreop). It writes a
+    [vector](../LangRef.md#vector-type), supplied as its first operand,
+    into a slice within a [MemRef](../LangRef.md#memref-type) of the same base
+    elemental type, supplied as its second operand.
+    The index for each memref dimension is an affine expression of loop
+    induction variables and symbols. These indices determine the start position
+    of the write within the memref. The shape of the input vector determines the
+    shape of the slice written to the memref. This slice is contiguous along the
+    respective dimensions of the shape. Strided vector stores will be supported
+    in the future.
+    An affine expression of loop IVs and symbols must be specified for each
+    dimension of the memref. The keyword 'symbol' can be used to indicate SSA
+    identifiers which are symbolic.
+
+    Example 1: 8-wide f32 vector store.
+
+    ```mlir
+    affine.vector_store %v0, %0[%i0 + 3, %i1 + 7] : memref<100x100xf32>, vector<8xf32>
+    ```
+
+    Example 2: 4-wide f32 vector store. Uses 'symbol' keyword for symbols '%n' and '%m'.
+
+    ```mlir
+    affine.vector_store %v0, %0[%i0 + symbol(%n), %i1 + symbol(%m)] : memref<100x100xf32>, vector<4xf32>
+    ```
+
+    Example 3: 2-dim f32 vector store.
+
+    ```mlir
+    affine.vector_store %v0, %0[%i0, %i1] : memref<100x100xf32>, vector<2x8xf32>
+    ```
+
+    TODOs:
+    * Add support for strided vector stores.
+    * Consider adding a permutation map to permute the slice that is written to memory
+    (see [vector.transfer_write](../Vector/#vectortransfer_write-vectortransferwriteop)).
+  }];
+
+  let arguments = (ins AnyVector:$value,
+      Arg<AnyMemRef, "the reference to store to",
+      [MemWrite]>:$memref,
+      Variadic<Index>:$indices);
+
+  let extraClassDeclaration = extraClassDeclarationBase # [{
+    VectorType getVectorType() {
+      return value().getType().cast<VectorType>();
+    }
+  }];
+}
+
 #endif // AFFINE_OPS

diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index c45c61998446..4c71a168dae7 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -994,6 +994,13 @@ def Vector_TransferReadOp :
     ```
   }];
 
+  let builders = [
+    // Builder that sets permutation map and padding to 'getMinorIdentityMap'
+    // and zero, respectively, by default.
+    OpBuilder<"OpBuilder &builder, OperationState &result, VectorType vector, "
+              "Value memref, ValueRange indices">
+  ];
+
   let extraClassDeclaration = [{
     MemRefType getMemRefType() {
       return memref().getType().cast<MemRefType>();
@@ -1058,6 +1065,13 @@ def Vector_TransferWriteOp :
     ```
   }];
 
+  let builders = [
+    // Builder that sets permutation map and padding to 'getMinorIdentityMap'
+    // by default.
+    OpBuilder<"OpBuilder &builder, OperationState &result, Value vector, "
+              "Value memref, ValueRange indices">
+  ];
+
   let extraClassDeclaration = [{
     VectorType getVectorType() {
       return vector().getType().cast<VectorType>();

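On the C++ side, these convenience builders spare lowering code from spelling
out the permutation map and padding value by hand. A minimal usage sketch,
assuming `builder`, `loc`, `memref`, `indices`, and `valueToStore` are in
scope (all names illustrative):

```cpp
// Read a vector<8xf32> slice; the new builder derives the minor identity
// permutation map from the memref/vector ranks and creates a zero-padding
// constant of the element type.
auto vecTy = VectorType::get({8}, builder.getF32Type());
Value v = builder.create<vector::TransferReadOp>(loc, vecTy, memref, indices);

// Write a vector back; only the permutation map is defaulted here, since
// vector.transfer_write takes no padding value.
builder.create<vector::TransferWriteOp>(loc, valueToStore, memref, indices);
```
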
diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
index 862d76c32cc4..96b62969c8cd 100644
--- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
+++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
@@ -27,6 +28,7 @@
 #include "mlir/Transforms/Passes.h"
 
 using namespace mlir;
+using namespace mlir::vector;
 
 namespace {
 /// Visit affine expressions recursively and build the sequence of operations
@@ -556,6 +558,51 @@ class AffineDmaWaitLowering : public OpRewritePattern<AffineDmaWaitOp> {
   }
 };
 
+/// Apply the affine map from an 'affine.vector_load' operation to its operands,
+/// and feed the results to a newly created 'vector.transfer_read' operation
+/// (which replaces the original 'affine.vector_load').
+class AffineVectorLoadLowering : public OpRewritePattern<AffineVectorLoadOp> {
+public:
+  using OpRewritePattern<AffineVectorLoadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AffineVectorLoadOp op,
+                                PatternRewriter &rewriter) const override {
+    // Expand the affine map of the 'affine.vector_load' operation.
+    SmallVector<Value, 8> indices(op.getMapOperands());
+    auto resultOperands =
+        expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
+    if (!resultOperands)
+      return failure();
+
+    // Build vector.transfer_read memref[expandedMap.results].
+    rewriter.replaceOpWithNewOp<TransferReadOp>(
+        op, op.getVectorType(), op.getMemRef(), *resultOperands);
+    return success();
+  }
+};
+
+/// Apply the affine map from an 'affine.vector_store' operation to its
+/// operands, and feed the results to a newly created 'vector.transfer_write'
+/// operation (which replaces the original 'affine.vector_store').
+class AffineVectorStoreLowering : public OpRewritePattern<AffineVectorStoreOp> {
+public:
+  using OpRewritePattern<AffineVectorStoreOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AffineVectorStoreOp op,
+                                PatternRewriter &rewriter) const override {
+    // Expand the affine map of the 'affine.vector_store' operation.
+    SmallVector<Value, 8> indices(op.getMapOperands());
+    auto maybeExpandedMap =
+        expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
+    if (!maybeExpandedMap)
+      return failure();
+
+    rewriter.replaceOpWithNewOp<TransferWriteOp>(
+        op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap);
+    return success();
+  }
+};
+
 } // end namespace
 
 void mlir::populateAffineToStdConversionPatterns(
@@ -576,13 +623,24 @@ void mlir::populateAffineToStdConversionPatterns(
   // clang-format on
 }
 
+void mlir::populateAffineToVectorConversionPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *ctx) {
+  // clang-format off
+  patterns.insert<
+      AffineVectorLoadLowering,
+      AffineVectorStoreLowering>(ctx);
+  // clang-format on
+}
+
 namespace {
 class LowerAffinePass : public ConvertAffineToStandardBase<LowerAffinePass> {
   void runOnFunction() override {
     OwningRewritePatternList patterns;
     populateAffineToStdConversionPatterns(patterns, &getContext());
+    populateAffineToVectorConversionPatterns(patterns, &getContext());
     ConversionTarget target(getContext());
-    target.addLegalDialect<scf::SCFDialect, StandardOpsDialect>();
+    target
+        .addLegalDialect<scf::SCFDialect, StandardOpsDialect, VectorDialect>();
     if (failed(applyPartialConversion(getFunction(), target, patterns)))
       signalPassFailure();
   }

diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 16f4a3c6068e..27f4450924b6 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -1912,32 +1912,47 @@ void print(OpAsmPrinter &p, AffineLoadOp op) {
   p << " : " << op.getMemRefType();
 }
 
-LogicalResult verify(AffineLoadOp op) {
-  if (op.getType() != op.getMemRefType().getElementType())
-    return op.emitOpError("result type must match element type of memref");
-
-  auto mapAttr = op.getAttrOfType<AffineMapAttr>(op.getMapAttrName());
+/// Verify common indexing invariants of affine.load, affine.store,
+/// affine.vector_load and affine.vector_store.
+static LogicalResult
+verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr,
+                       Operation::operand_range mapOperands,
+                       MemRefType memrefType, unsigned numIndexOperands) {
   if (mapAttr) {
-    AffineMap map =
-        op.getAttrOfType<AffineMapAttr>(op.getMapAttrName()).getValue();
-    if (map.getNumResults() != op.getMemRefType().getRank())
-      return op.emitOpError("affine.load affine map num results must equal"
-                            " memref rank");
-    if (map.getNumInputs() != op.getNumOperands() - 1)
-      return op.emitOpError("expects as many subscripts as affine map inputs");
+    AffineMap map = mapAttr.getValue();
+    if (map.getNumResults() != memrefType.getRank())
+      return op->emitOpError("affine map num results must equal memref rank");
+    if (map.getNumInputs() != numIndexOperands)
+      return op->emitOpError("expects as many subscripts as affine map inputs");
   } else {
-    if (op.getMemRefType().getRank() != op.getNumOperands() - 1)
-      return op.emitOpError(
+    if (memrefType.getRank() != numIndexOperands)
+      return op->emitOpError(
           "expects the number of subscripts to be equal to memref rank");
   }
 
   Region *scope = getAffineScope(op);
-  for (auto idx : op.getMapOperands()) {
+  for (auto idx : mapOperands) {
     if (!idx.getType().isIndex())
-      return op.emitOpError("index to load must have 'index' type");
+      return op->emitOpError("index to load must have 'index' type");
     if (!isValidAffineIndexOperand(idx, scope))
-      return op.emitOpError("index must be a dimension or symbol identifier");
+      return op->emitOpError("index must be a dimension or symbol identifier");
   }
+
+  return success();
+}
+
+LogicalResult verify(AffineLoadOp op) {
+  auto memrefType = op.getMemRefType();
+  if (op.getType() != memrefType.getElementType())
+    return op.emitOpError("result type must match element type of memref");
+
+  if (failed(verifyMemoryOpIndexing(
+          op.getOperation(),
+          op.getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
+          op.getMapOperands(), memrefType,
+          /*numIndexOperands=*/op.getNumOperands() - 1)))
+    return failure();
+
   return success();
 }
 
@@ -2014,31 +2029,18 @@ void print(OpAsmPrinter &p, AffineStoreOp op) {
 
 LogicalResult verify(AffineStoreOp op) {
   // First operand must have same type as memref element type.
-  if (op.getValueToStore().getType() != op.getMemRefType().getElementType())
+  auto memrefType = op.getMemRefType();
+  if (op.getValueToStore().getType() != memrefType.getElementType())
     return op.emitOpError(
         "first operand must have same type memref element type");
 
-  auto mapAttr = op.getAttrOfType<AffineMapAttr>(op.getMapAttrName());
-  if (mapAttr) {
-    AffineMap map = mapAttr.getValue();
-    if (map.getNumResults() != op.getMemRefType().getRank())
-      return op.emitOpError("affine.store affine map num results must equal"
-                            " memref rank");
-    if (map.getNumInputs() != op.getNumOperands() - 2)
-      return op.emitOpError("expects as many subscripts as affine map inputs");
-  } else {
-    if (op.getMemRefType().getRank() != op.getNumOperands() - 2)
-      return op.emitOpError(
-          "expects the number of subscripts to be equal to memref rank");
-  }
+  if (failed(verifyMemoryOpIndexing(
+          op.getOperation(),
+          op.getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
+          op.getMapOperands(), memrefType,
+          /*numIndexOperands=*/op.getNumOperands() - 2)))
+    return failure();
 
-  Region *scope = getAffineScope(op);
-  for (auto idx : op.getMapOperands()) {
-    if (!idx.getType().isIndex())
-      return op.emitOpError("index to store must have 'index' type");
-    if (!isValidAffineIndexOperand(idx, scope))
-      return op.emitOpError("index must be a dimension or symbol identifier");
-  }
   return success();
 }
 
@@ -2493,6 +2495,125 @@ static ParseResult parseAffineParallelOp(OpAsmParser &parser,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// AffineVectorLoadOp
+//===----------------------------------------------------------------------===//
+
+ParseResult parseAffineVectorLoadOp(OpAsmParser &parser,
+                                    OperationState &result) {
+  auto &builder = parser.getBuilder();
+  auto indexTy = builder.getIndexType();
+
+  MemRefType memrefType;
+  VectorType resultType;
+  OpAsmParser::OperandType memrefInfo;
+  AffineMapAttr mapAttr;
+  SmallVector<OpAsmParser::OperandType, 1> mapOperands;
+  return failure(
+      parser.parseOperand(memrefInfo) ||
+      parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
+                                    AffineVectorLoadOp::getMapAttrName(),
+                                    result.attributes) ||
+      parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseColonType(memrefType) || parser.parseComma() ||
+      parser.parseType(resultType) ||
+      parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
+      parser.resolveOperands(mapOperands, indexTy, result.operands) ||
+      parser.addTypeToList(resultType, result.types));
+}
+
+void print(OpAsmPrinter &p, AffineVectorLoadOp op) {
+  p << "affine.vector_load " << op.getMemRef() << '[';
+  if (AffineMapAttr mapAttr =
+          op.getAttrOfType<AffineMapAttr>(op.getMapAttrName()))
+    p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands());
+  p << ']';
+  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{op.getMapAttrName()});
+  p << " : " << op.getMemRefType() << ", " << op.getType();
+}
+
+/// Verify common invariants of affine.vector_load and affine.vector_store.
+static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType,
+                                          VectorType vectorType) {
+  // Check that memref and vector element types match.
+  if (memrefType.getElementType() != vectorType.getElementType())
+    return op->emitOpError(
+        "requires memref and vector types of the same elemental type");
+
+  return success();
+}
+
+static LogicalResult verify(AffineVectorLoadOp op) {
+  MemRefType memrefType = op.getMemRefType();
+  if (failed(verifyMemoryOpIndexing(
+          op.getOperation(),
+          op.getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
+          op.getMapOperands(), memrefType,
+          /*numIndexOperands=*/op.getNumOperands() - 1)))
+    return failure();
+
+  if (failed(verifyVectorMemoryOp(op.getOperation(), memrefType,
+                                  op.getVectorType())))
+    return failure();
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// AffineVectorStoreOp
+//===----------------------------------------------------------------------===//
+
+ParseResult parseAffineVectorStoreOp(OpAsmParser &parser,
+                                     OperationState &result) {
+  auto indexTy = parser.getBuilder().getIndexType();
+
+  MemRefType memrefType;
+  VectorType resultType;
+  OpAsmParser::OperandType storeValueInfo;
+  OpAsmParser::OperandType memrefInfo;
+  AffineMapAttr mapAttr;
+  SmallVector<OpAsmParser::OperandType, 1> mapOperands;
+  return failure(
+      parser.parseOperand(storeValueInfo) || parser.parseComma() ||
+      parser.parseOperand(memrefInfo) ||
+      parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
+                                    AffineVectorStoreOp::getMapAttrName(),
+                                    result.attributes) ||
+      parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseColonType(memrefType) || parser.parseComma() ||
+      parser.parseType(resultType) ||
+      parser.resolveOperand(storeValueInfo, resultType, result.operands) ||
+      parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
+      parser.resolveOperands(mapOperands, indexTy, result.operands));
+}
+
+void print(OpAsmPrinter &p, AffineVectorStoreOp op) {
+  p << "affine.vector_store " << op.getValueToStore();
+  p << ", " << op.getMemRef() << '[';
+  if (AffineMapAttr mapAttr =
+          op.getAttrOfType<AffineMapAttr>(op.getMapAttrName()))
+    p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands());
+  p << ']';
+  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{op.getMapAttrName()});
+  p << " : " << op.getMemRefType() << ", " << op.getValueToStore().getType();
+}
+
+static LogicalResult verify(AffineVectorStoreOp op) {
+  MemRefType memrefType = op.getMemRefType();
+  if (failed(verifyMemoryOpIndexing(
+          op.getOperation(),
+          op.getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
+          op.getMapOperands(), memrefType,
+          /*numIndexOperands=*/op.getNumOperands() - 2)))
+    return failure();
+
+  if (failed(verifyVectorMemoryOp(op.getOperation(), memrefType,
+                                  op.getVectorType())))
+    return failure();
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index ec4ef0f50361..8385b253e9fb 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1284,6 +1284,21 @@ static LogicalResult verifyTransferOp(Operation *op, MemRefType memrefType,
   return success();
 }
 
+/// Builder that sets permutation map and padding to 'getMinorIdentityMap' and
+/// zero, respectively, by default.
+void TransferReadOp::build(OpBuilder &builder, OperationState &result,
+                           VectorType vector, Value memref,
+                           ValueRange indices) {
+  auto permMap = AffineMap::getMinorIdentityMap(
+      memref.getType().cast<MemRefType>().getRank(), vector.getRank(),
+      builder.getContext());
+  Type elemType = vector.cast<VectorType>().getElementType();
+  Value padding = builder.create<ConstantOp>(result.location, elemType,
+                                             builder.getZeroAttr(elemType));
+
+  build(builder, result, vector, memref, indices, permMap, padding);
+}
+
 static void print(OpAsmPrinter &p, TransferReadOp op) {
   p << op.getOperationName() << " " << op.memref() << "[" << op.indices()
     << "], " << op.padding() << " ";
@@ -1361,6 +1376,17 @@ static LogicalResult verify(TransferReadOp op) {
 // TransferWriteOp
 //===----------------------------------------------------------------------===//
 
+/// Builder that sets permutation map and padding to 'getMinorIdentityMap' by
+/// default.
+void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
+                            Value vector, Value memref, ValueRange indices) {
+  auto vectorType = vector.getType().cast<VectorType>();
+  auto permMap = AffineMap::getMinorIdentityMap(
+      memref.getType().cast<MemRefType>().getRank(), vectorType.getRank(),
+      builder.getContext());
+  build(builder, result, vector, memref, indices, permMap);
+}
+
 static LogicalResult verify(TransferWriteOp op) {
   // Consistency of elemental types in memref and vector.
   MemRefType memrefType = op.getMemRefType();

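For reference, `AffineMap::getMinorIdentityMap(numDims, numResults, ctx)`
returns the identity on the trailing `numResults` dimensions, so the
defaulted permutation maps address the innermost memref dimensions:

```mlir
// getMinorIdentityMap(/*dims=*/2, /*results=*/1)
affine_map<(d0, d1) -> (d1)>

// getMinorIdentityMap(/*dims=*/2, /*results=*/2), as in the 2-D tests below
affine_map<(d0, d1) -> (d0, d1)>
```
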
diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine-to-vector.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine-to-vector.mlir
new file mode 100644
index 000000000000..f9a78aa495a5
--- /dev/null
+++ b/mlir/test/Conversion/AffineToStandard/lower-affine-to-vector.mlir
@@ -0,0 +1,117 @@
+// RUN: mlir-opt -lower-affine --split-input-file %s | FileCheck %s
+
+// CHECK: #[[perm_map:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func @affine_vector_load
+func @affine_vector_load(%arg0 : index) {
+  %0 = alloc() : memref<100xf32>
+  affine.for %i0 = 0 to 16 {
+    %1 = affine.vector_load %0[%i0 + symbol(%arg0) + 7] : memref<100xf32>, vector<8xf32>
+  }
+// CHECK:       %[[buf:.*]] = alloc
+// CHECK:       %[[a:.*]] = addi %{{.*}}, %{{.*}} : index
+// CHECK-NEXT:  %[[c7:.*]] = constant 7 : index
+// CHECK-NEXT:  %[[b:.*]] = addi %[[a]], %[[c7]] : index
+// CHECK-NEXT:  %[[pad:.*]] = constant 0.0
+// CHECK-NEXT:  vector.transfer_read %[[buf]][%[[b]]], %[[pad]] {permutation_map = #[[perm_map]]} : memref<100xf32>, vector<8xf32>
+  return
+}
+
+// -----
+
+// CHECK: #[[perm_map:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func @affine_vector_store
+func @affine_vector_store(%arg0 : index) {
+  %0 = alloc() : memref<100xf32>
+  %1 = constant dense<11.0> : vector<4xf32>
+  affine.for %i0 = 0 to 16 {
+    affine.vector_store %1, %0[%i0 - symbol(%arg0) + 7] : memref<100xf32>, vector<4xf32>
+  }
+// CHECK:       %[[buf:.*]] = alloc
+// CHECK:       %[[val:.*]] = constant dense
+// CHECK:       %[[c_1:.*]] = constant -1 : index
+// CHECK-NEXT:  %[[a:.*]] = muli %arg0, %[[c_1]] : index
+// CHECK-NEXT:  %[[b:.*]] = addi %{{.*}}, %[[a]] : index
+// CHECK-NEXT:  %[[c7:.*]] = constant 7 : index
+// CHECK-NEXT:  %[[c:.*]] = addi %[[b]], %[[c7]] : index
+// CHECK-NEXT:  vector.transfer_write  %[[val]], %[[buf]][%[[c]]] {permutation_map = #[[perm_map]]} : vector<4xf32>, memref<100xf32>
+  return
+}
+
+// -----
+
+// CHECK: #[[perm_map:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func @affine_vector_load
+func @affine_vector_load(%arg0 : index) {
+  %0 = alloc() : memref<100xf32>
+  affine.for %i0 = 0 to 16 {
+    %1 = affine.vector_load %0[%i0 + symbol(%arg0) + 7] : memref<100xf32>, vector<8xf32>
+  }
+// CHECK:       %[[buf:.*]] = alloc
+// CHECK:       %[[a:.*]] = addi %{{.*}}, %{{.*}} : index
+// CHECK-NEXT:  %[[c7:.*]] = constant 7 : index
+// CHECK-NEXT:  %[[b:.*]] = addi %[[a]], %[[c7]] : index
+// CHECK-NEXT:  %[[pad:.*]] = constant 0.0
+// CHECK-NEXT:  vector.transfer_read %[[buf]][%[[b]]], %[[pad]] {permutation_map = #[[perm_map]]} : memref<100xf32>, vector<8xf32>
+  return
+}
+
+// -----
+
+// CHECK: #[[perm_map:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func @affine_vector_store
+func @affine_vector_store(%arg0 : index) {
+  %0 = alloc() : memref<100xf32>
+  %1 = constant dense<11.0> : vector<4xf32>
+  affine.for %i0 = 0 to 16 {
+    affine.vector_store %1, %0[%i0 - symbol(%arg0) + 7] : memref<100xf32>, vector<4xf32>
+  }
+// CHECK:       %[[buf:.*]] = alloc
+// CHECK:       %[[val:.*]] = constant dense
+// CHECK:       %[[c_1:.*]] = constant -1 : index
+// CHECK-NEXT:  %[[a:.*]] = muli %arg0, %[[c_1]] : index
+// CHECK-NEXT:  %[[b:.*]] = addi %{{.*}}, %[[a]] : index
+// CHECK-NEXT:  %[[c7:.*]] = constant 7 : index
+// CHECK-NEXT:  %[[c:.*]] = addi %[[b]], %[[c7]] : index
+// CHECK-NEXT:  vector.transfer_write  %[[val]], %[[buf]][%[[c]]] {permutation_map = #[[perm_map]]} : vector<4xf32>, memref<100xf32>
+  return
+}
+
+// -----
+
+// CHECK: #[[perm_map:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @vector_load_2d
+func @vector_load_2d() {
+  %0 = alloc() : memref<100x100xf32>
+  affine.for %i0 = 0 to 16 step 2 {
+    affine.for %i1 = 0 to 16 step 8 {
+      %1 = affine.vector_load %0[%i0, %i1] : memref<100x100xf32>, vector<2x8xf32>
+// CHECK:      %[[buf:.*]] = alloc
+// CHECK:      scf.for %[[i0:.*]] =
+// CHECK:        scf.for %[[i1:.*]] =
+// CHECK-NEXT:     %[[pad:.*]] = constant 0.0
+// CHECK-NEXT:     vector.transfer_read %[[buf]][%[[i0]], %[[i1]]], %[[pad]] {permutation_map = #[[perm_map]]} : memref<100x100xf32>, vector<2x8xf32>
+    }
+  }
+  return
+}
+
+// -----
+
+// CHECK: #[[perm_map:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @vector_store_2d
+func @vector_store_2d() {
+  %0 = alloc() : memref<100x100xf32>
+  %1 = constant dense<11.0> : vector<2x8xf32>
+  affine.for %i0 = 0 to 16 step 2 {
+    affine.for %i1 = 0 to 16 step 8 {
+      affine.vector_store %1, %0[%i0, %i1] : memref<100x100xf32>, vector<2x8xf32>
+// CHECK:      %[[buf:.*]] = alloc
+// CHECK:      %[[val:.*]] = constant dense
+// CHECK:      scf.for %[[i0:.*]] =
+// CHECK:        scf.for %[[i1:.*]] =
+// CHECK-NEXT:     vector.transfer_write  %[[val]], %[[buf]][%[[i0]], %[[i1]]] {permutation_map = #[[perm_map]]} : vector<2x8xf32>, memref<100x100xf32>
+    }
+  }
+  return
+}
+

diff --git a/mlir/test/Dialect/Affine/invalid.mlir b/mlir/test/Dialect/Affine/invalid.mlir
index c0855987ac32..102dd394f93c 100644
--- a/mlir/test/Dialect/Affine/invalid.mlir
+++ b/mlir/test/Dialect/Affine/invalid.mlir
@@ -262,3 +262,49 @@ func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) {
   }
   return
 }
+
+// -----
+
+func @vector_load_invalid_vector_type() {
+  %0 = alloc() : memref<100xf32>
+  affine.for %i0 = 0 to 16 step 8 {
+    // expected-error @+1 {{requires memref and vector types of the same elemental type}}
+    %1 = affine.vector_load %0[%i0] : memref<100xf32>, vector<8xf64>
+  }
+  return
+}
+
+// -----
+
+func @vector_store_invalid_vector_type() {
+  %0 = alloc() : memref<100xf32>
+  %1 = constant dense<7.0> : vector<8xf64>
+  affine.for %i0 = 0 to 16 step 8 {
+    // expected-error @+1 {{requires memref and vector types of the same elemental type}}
+    affine.vector_store %1, %0[%i0] : memref<100xf32>, vector<8xf64>
+  }
+  return
+}
+
+// -----
+
+func @vector_load_vector_memref() {
+  %0 = alloc() : memref<100xvector<8xf32>>
+  affine.for %i0 = 0 to 4 {
+    // expected-error @+1 {{requires memref and vector types of the same elemental type}}
+    %1 = affine.vector_load %0[%i0] : memref<100xvector<8xf32>>, vector<8xf32>
+  }
+  return
+}
+
+// -----
+
+func @vector_store_vector_memref() {
+  %0 = alloc() : memref<100xvector<8xf32>>
+  %1 = constant dense<7.0> : vector<8xf32>
+  affine.for %i0 = 0 to 4 {
+    // expected-error @+1 {{requires memref and vector types of the same elemental type}}
+    affine.vector_store %1, %0[%i0] : memref<100xvector<8xf32>>, vector<8xf32>
+  }
+  return
+}

diff --git a/mlir/test/Dialect/Affine/load-store.mlir b/mlir/test/Dialect/Affine/load-store.mlir
index 54e753d17fef..06a93d41fd64 100644
--- a/mlir/test/Dialect/Affine/load-store.mlir
+++ b/mlir/test/Dialect/Affine/load-store.mlir
@@ -214,3 +214,65 @@ func @test_prefetch(%arg0 : index, %arg1 : index) {
   }
   return
 }
+
+// -----
+
+// CHECK: [[MAP_ID:#map[0-9]+]] = affine_map<(d0, d1) -> (d0, d1)>
+
+// Test with just loop IVs.
+func @vector_load_vector_store_iv() {
+  %0 = alloc() : memref<100x100xf32>
+  affine.for %i0 = 0 to 16 {
+    affine.for %i1 = 0 to 16 step 8 {
+      %1 = affine.vector_load %0[%i0, %i1] : memref<100x100xf32>, vector<8xf32>
+      affine.vector_store %1, %0[%i0, %i1] : memref<100x100xf32>, vector<8xf32>
+// CHECK:      %[[buf:.*]] = alloc
+// CHECK-NEXT: affine.for %[[i0:.*]] = 0
+// CHECK-NEXT:   affine.for %[[i1:.*]] = 0
+// CHECK-NEXT:     %[[val:.*]] = affine.vector_load %[[buf]][%[[i0]], %[[i1]]] : memref<100x100xf32>, vector<8xf32>
+// CHECK-NEXT:     affine.vector_store %[[val]], %[[buf]][%[[i0]], %[[i1]]] : memref<100x100xf32>, vector<8xf32>
+    }
+  }
+  return
+}
+
+// -----
+
+// CHECK: [[MAP0:#map[0-9]+]] = affine_map<(d0, d1) -> (d0 + 3, d1 + 7)>
+
+// Test with loop IVs and constants.
+func @vector_load_vector_store_iv_constant() {
+  %0 = alloc() : memref<100x100xf32>
+  affine.for %i0 = 0 to 10 {
+    affine.for %i1 = 0 to 16 step 4 {
+      %1 = affine.vector_load %0[%i0 + 3, %i1 + 7] : memref<100x100xf32>, vector<4xf32>
+      affine.vector_store %1, %0[%i0 + 3, %i1 + 7] : memref<100x100xf32>, vector<4xf32>
+// CHECK:      %[[buf:.*]] = alloc
+// CHECK-NEXT: affine.for %[[i0:.*]] = 0
+// CHECK-NEXT:   affine.for %[[i1:.*]] = 0
+// CHECK-NEXT:     %[[val:.*]] = affine.vector_load %{{.*}}[%{{.*}} + 3, %{{.*}} + 7] : memref<100x100xf32>, vector<4xf32>
+// CHECK-NEXT:     affine.vector_store %[[val]], %[[buf]][%[[i0]] + 3, %[[i1]] + 7] : memref<100x100xf32>, vector<4xf32>
+    }
+  }
+  return
+}
+
+// -----
+
+// CHECK: [[MAP0:#map[0-9]+]] = affine_map<(d0, d1) -> (d0, d1)>
+
+func @vector_load_vector_store_2d() {
+  %0 = alloc() : memref<100x100xf32>
+  affine.for %i0 = 0 to 16 step 2 {
+    affine.for %i1 = 0 to 16 step 8 {
+      %1 = affine.vector_load %0[%i0, %i1] : memref<100x100xf32>, vector<2x8xf32>
+      affine.vector_store %1, %0[%i0, %i1] : memref<100x100xf32>, vector<2x8xf32>
+// CHECK:      %[[buf:.*]] = alloc
+// CHECK-NEXT: affine.for %[[i0:.*]] = 0
+// CHECK-NEXT:   affine.for %[[i1:.*]] = 0
+// CHECK-NEXT:     %[[val:.*]] = affine.vector_load %[[buf]][%[[i0]], %[[i1]]] : memref<100x100xf32>, vector<2x8xf32>
+// CHECK-NEXT:     affine.vector_store %[[val]], %[[buf]][%[[i0]], %[[i1]]] : memref<100x100xf32>, vector<2x8xf32>
+    }
+  }
+  return
+}

