[llvm-branch-commits] [mlir] a706885 - [mlir][Affine] Introduce affine.vector_load and affine.vector_store
Diego Caballero via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu May 14 13:14:19 PDT 2020
Author: Diego Caballero
Date: 2020-05-14T13:09:39-07:00
New Revision: a706885bb2605d74d62476241cce4f5729db2da2
URL: https://github.com/llvm/llvm-project/commit/a706885bb2605d74d62476241cce4f5729db2da2
DIFF: https://github.com/llvm/llvm-project/commit/a706885bb2605d74d62476241cce4f5729db2da2.diff
LOG: [mlir][Affine] Introduce affine.vector_load and affine.vector_store
This patch adds `affine.vector_load` and `affine.vector_store` ops to
the Affine dialect and lowers them to `vector.transfer_read` and
`vector.transfer_write`, respectively, in the Vector dialect.
Reviewed By: bondhugula, nicolasvasilache
Differential Revision: https://reviews.llvm.org/D79658
Added:
mlir/test/Conversion/AffineToStandard/lower-affine-to-vector.mlir
Modified:
mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h
mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
mlir/lib/Dialect/Affine/IR/AffineOps.cpp
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/test/Dialect/Affine/invalid.mlir
mlir/test/Dialect/Affine/load-store.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h b/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h
index 967c74c5c23f..5d04f157b8ce 100644
--- a/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h
+++ b/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h
@@ -44,6 +44,11 @@ Optional<SmallVector<Value, 8>> expandAffineMap(OpBuilder &builder,
void populateAffineToStdConversionPatterns(OwningRewritePatternList &patterns,
MLIRContext *ctx);
+/// Collect a set of patterns to convert vector-related Affine ops to the Vector
+/// dialect.
+void populateAffineToVectorConversionPatterns(
+ OwningRewritePatternList &patterns, MLIRContext *ctx);
+
/// Emit code that computes the lower bound of the given affine loop using
/// standard arithmetic operations.
Value lowerAffineLowerBound(AffineForOp op, OpBuilder &builder);
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index bb02f8b18c5d..8286d8f315bd 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -370,7 +370,44 @@ def AffineIfOp : Affine_Op<"if",
let hasFolder = 1;
}
-def AffineLoadOp : Affine_Op<"load", []> {
+class AffineLoadOpBase<string mnemonic, list<OpTrait> traits = []> :
+ Affine_Op<mnemonic, traits> {
+ let arguments = (ins Arg<AnyMemRef, "the reference to load from",
+ [MemRead]>:$memref,
+ Variadic<Index>:$indices);
+
+ code extraClassDeclarationBase = [{
+ /// Returns the operand index of the memref.
+ unsigned getMemRefOperandIndex() { return 0; }
+
+ /// Get memref operand.
+ Value getMemRef() { return getOperand(getMemRefOperandIndex()); }
+ void setMemRef(Value value) { setOperand(getMemRefOperandIndex(), value); }
+ MemRefType getMemRefType() {
+ return getMemRef().getType().cast<MemRefType>();
+ }
+
+ /// Get affine map operands.
+ operand_range getMapOperands() { return llvm::drop_begin(getOperands(), 1); }
+
+ /// Returns the affine map used to index the memref for this operation.
+ AffineMap getAffineMap() { return getAffineMapAttr().getValue(); }
+ AffineMapAttr getAffineMapAttr() {
+ return getAttr(getMapAttrName()).cast<AffineMapAttr>();
+ }
+
+ /// Returns the AffineMapAttr associated with 'memref'.
+ NamedAttribute getAffineMapAttrForMemRef(Value memref) {
+ assert(memref == getMemRef());
+ return {Identifier::get(getMapAttrName(), getContext()),
+ getAffineMapAttr()};
+ }
+
+ static StringRef getMapAttrName() { return "map"; }
+ }];
+}
+
+def AffineLoadOp : AffineLoadOpBase<"load", []> {
let summary = "affine load operation";
let description = [{
The "affine.load" op reads an element from a memref, where the index
@@ -393,9 +430,6 @@ def AffineLoadOp : Affine_Op<"load", []> {
```
}];
- let arguments = (ins Arg<AnyMemRef, "the reference to load from",
- [MemRead]>:$memref,
- Variadic<Index>:$indices);
let results = (outs AnyType:$result);
let builders = [
@@ -410,35 +444,7 @@ def AffineLoadOp : Affine_Op<"load", []> {
"AffineMap map, ValueRange mapOperands">
];
- let extraClassDeclaration = [{
- /// Returns the operand index of the memref.
- unsigned getMemRefOperandIndex() { return 0; }
-
- /// Get memref operand.
- Value getMemRef() { return getOperand(getMemRefOperandIndex()); }
- void setMemRef(Value value) { setOperand(getMemRefOperandIndex(), value); }
- MemRefType getMemRefType() {
- return getMemRef().getType().cast<MemRefType>();
- }
-
- /// Get affine map operands.
- operand_range getMapOperands() { return llvm::drop_begin(getOperands(), 1); }
-
- /// Returns the affine map used to index the memref for this operation.
- AffineMap getAffineMap() { return getAffineMapAttr().getValue(); }
- AffineMapAttr getAffineMapAttr() {
- return getAttr(getMapAttrName()).cast<AffineMapAttr>();
- }
-
- /// Returns the AffineMapAttr associated with 'memref'.
- NamedAttribute getAffineMapAttrForMemRef(Value memref) {
- assert(memref == getMemRef());
- return {Identifier::get(getMapAttrName(), getContext()),
- getAffineMapAttr()};
- }
-
- static StringRef getMapAttrName() { return "map"; }
- }];
+ let extraClassDeclaration = extraClassDeclarationBase;
let hasCanonicalizer = 1;
let hasFolder = 1;
@@ -659,7 +665,45 @@ def AffinePrefetchOp : Affine_Op<"prefetch"> {
let hasFolder = 1;
}
-def AffineStoreOp : Affine_Op<"store", []> {
+class AffineStoreOpBase<string mnemonic, list<OpTrait> traits = []> :
+ Affine_Op<mnemonic, traits> {
+
+ code extraClassDeclarationBase = [{
+ /// Get value to be stored by store operation.
+ Value getValueToStore() { return getOperand(0); }
+
+ /// Returns the operand index of the memref.
+ unsigned getMemRefOperandIndex() { return 1; }
+
+ /// Get memref operand.
+ Value getMemRef() { return getOperand(getMemRefOperandIndex()); }
+ void setMemRef(Value value) { setOperand(getMemRefOperandIndex(), value); }
+
+ MemRefType getMemRefType() {
+ return getMemRef().getType().cast<MemRefType>();
+ }
+
+ /// Get affine map operands.
+ operand_range getMapOperands() { return llvm::drop_begin(getOperands(), 2); }
+
+ /// Returns the affine map used to index the memref for this operation.
+ AffineMap getAffineMap() { return getAffineMapAttr().getValue(); }
+ AffineMapAttr getAffineMapAttr() {
+ return getAttr(getMapAttrName()).cast<AffineMapAttr>();
+ }
+
+ /// Returns the AffineMapAttr associated with 'memref'.
+ NamedAttribute getAffineMapAttrForMemRef(Value memref) {
+ assert(memref == getMemRef());
+ return {Identifier::get(getMapAttrName(), getContext()),
+ getAffineMapAttr()};
+ }
+
+ static StringRef getMapAttrName() { return "map"; }
+ }];
+}
+
+def AffineStoreOp : AffineStoreOpBase<"store", []> {
let summary = "affine store operation";
let description = [{
The "affine.store" op writes an element to a memref, where the index
@@ -686,7 +730,6 @@ def AffineStoreOp : Affine_Op<"store", []> {
[MemWrite]>:$memref,
Variadic<Index>:$indices);
-
let skipDefaultBuilders = 1;
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &result, "
@@ -696,39 +739,7 @@ def AffineStoreOp : Affine_Op<"store", []> {
"ValueRange mapOperands">
];
- let extraClassDeclaration = [{
- /// Get value to be stored by store operation.
- Value getValueToStore() { return getOperand(0); }
-
- /// Returns the operand index of the memref.
- unsigned getMemRefOperandIndex() { return 1; }
-
- /// Get memref operand.
- Value getMemRef() { return getOperand(getMemRefOperandIndex()); }
- void setMemRef(Value value) { setOperand(getMemRefOperandIndex(), value); }
-
- MemRefType getMemRefType() {
- return getMemRef().getType().cast<MemRefType>();
- }
-
- /// Get affine map operands.
- operand_range getMapOperands() { return llvm::drop_begin(getOperands(), 2); }
-
- /// Returns the affine map used to index the memref for this operation.
- AffineMap getAffineMap() { return getAffineMapAttr().getValue(); }
- AffineMapAttr getAffineMapAttr() {
- return getAttr(getMapAttrName()).cast<AffineMapAttr>();
- }
-
- /// Returns the AffineMapAttr associated with 'memref'.
- NamedAttribute getAffineMapAttrForMemRef(Value memref) {
- assert(memref == getMemRef());
- return {Identifier::get(getMapAttrName(), getContext()),
- getAffineMapAttr()};
- }
-
- static StringRef getMapAttrName() { return "map"; }
- }];
+ let extraClassDeclaration = extraClassDeclarationBase;
let hasCanonicalizer = 1;
let hasFolder = 1;
@@ -765,4 +776,107 @@ def AffineTerminatorOp :
let verifier = ?;
}
+def AffineVectorLoadOp : AffineLoadOpBase<"vector_load", []> {
+ let summary = "affine vector load operation";
+ let description = [{
+ The "affine.vector_load" is the vector counterpart of
+ [affine.load](#affineload-operation). It reads a slice from a
+ [MemRef](../LangRef.md#memref-type), supplied as its first operand,
+ into a [vector](../LangRef.md#vector-type) of the same base elemental type.
+ The index for each memref dimension is an affine expression of loop induction
+ variables and symbols. These indices determine the start position of the read
+ within the memref. The shape of the return vector type determines the shape of
+ the slice read from the memref. This slice is contiguous along the respective
+ dimensions of the shape. Strided vector loads will be supported in the future.
+ An affine expression of loop IVs and symbols must be specified for each
+ dimension of the memref. The keyword 'symbol' can be used to indicate SSA
+ identifiers which are symbolic.
+
+ Example 1: 8-wide f32 vector load.
+
+ ```mlir
+ %1 = affine.vector_load %0[%i0 + 3, %i1 + 7] : memref<100x100xf32>, vector<8xf32>
+ ```
+
+ Example 2: 4-wide f32 vector load. Uses 'symbol' keyword for symbols '%n' and '%m'.
+
+ ```mlir
+ %1 = affine.vector_load %0[%i0 + symbol(%n), %i1 + symbol(%m)] : memref<100x100xf32>, vector<4xf32>
+ ```
+
+ Example 3: 2-dim f32 vector load.
+
+ ```mlir
+ %1 = affine.vector_load %0[%i0, %i1] : memref<100x100xf32>, vector<2x8xf32>
+ ```
+
+ TODOs:
+ * Add support for strided vector loads.
+ * Consider adding a permutation map to permute the slice that is read from memory
+ (see [vector.transfer_read](../Vector/#vectortransfer_read-vectortransferreadop)).
+ }];
+
+ let results = (outs AnyVector:$result);
+
+ let extraClassDeclaration = extraClassDeclarationBase # [{
+ VectorType getVectorType() {
+ return result().getType().cast<VectorType>();
+ }
+ }];
+}
+
+def AffineVectorStoreOp : AffineStoreOpBase<"vector_store", []> {
+ let summary = "affine vector store operation";
+ let description = [{
+ The "affine.vector_store" is the vector counterpart of
+ [affine.store](#affinestore-affinestoreop). It writes a
+ [vector](../LangRef.md#vector-type), supplied as its first operand,
+ into a slice within a [MemRef](../LangRef.md#memref-type) of the same base
+ elemental type, supplied as its second operand.
+ The index for each memref dimension is an affine expression of loop
+ induction variables and symbols. These indices determine the start position
+ of the write within the memref. The shape of the input vector determines the
+ shape of the slice written to the memref. This slice is contiguous along the
+ respective dimensions of the shape. Strided vector stores will be supported
+ in the future.
+ An affine expression of loop IVs and symbols must be specified for each
+ dimension of the memref. The keyword 'symbol' can be used to indicate SSA
+ identifiers which are symbolic.
+
+ Example 1: 8-wide f32 vector store.
+
+ ```mlir
+ affine.vector_store %v0, %0[%i0 + 3, %i1 + 7] : memref<100x100xf32>, vector<8xf32>
+ ```
+
+ Example 2: 4-wide f32 vector store. Uses 'symbol' keyword for symbols '%n' and '%m'.
+
+ ```mlir
+ affine.vector_store %v0, %0[%i0 + symbol(%n), %i1 + symbol(%m)] : memref<100x100xf32>, vector<4xf32>
+ ```
+
+ Example 3: 2-dim f32 vector store.
+
+ ```mlir
+ affine.vector_store %v0, %0[%i0, %i1] : memref<100x100xf32>, vector<2x8xf32>
+ ```
+
+ TODOs:
+ * Add support for strided vector stores.
+ * Consider adding a permutation map to permute the slice that is written to memory
+ (see [vector.transfer_write](../Vector/#vectortransfer_write-vectortransferwriteop)).
+ }];
+
+ let arguments = (ins AnyVector:$value,
+ Arg<AnyMemRef, "the reference to store to",
+ [MemWrite]>:$memref,
+ Variadic<Index>:$indices);
+
+ let extraClassDeclaration = extraClassDeclarationBase # [{
+ VectorType getVectorType() {
+ return value().getType().cast<VectorType>();
+ }
+ }];
+}
+
#endif // AFFINE_OPS
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index c45c61998446..4c71a168dae7 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -994,6 +994,13 @@ def Vector_TransferReadOp :
```
}];
+ let builders = [
+ // Builder that sets permutation map and padding to 'getMinorIdentityMap'
+ // and zero, respectively, by default.
+ OpBuilder<"OpBuilder &builder, OperationState &result, VectorType vector, "
+ "Value memref, ValueRange indices">
+ ];
+
let extraClassDeclaration = [{
MemRefType getMemRefType() {
return memref().getType().cast<MemRefType>();
@@ -1058,6 +1065,13 @@ def Vector_TransferWriteOp :
```
}];
+ let builders = [
+ // Builder that sets the permutation map to 'getMinorIdentityMap' by
+ // default; transfer_write takes no padding value.
+ OpBuilder<"OpBuilder &builder, OperationState &result, Value vector, "
+ "Value memref, ValueRange indices">
+ ];
+
let extraClassDeclaration = [{
VectorType getVectorType() {
return vector().getType().cast<VectorType>();
diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
index 862d76c32cc4..96b62969c8cd 100644
--- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
+++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
@@ -27,6 +28,7 @@
#include "mlir/Transforms/Passes.h"
using namespace mlir;
+using namespace mlir::vector;
namespace {
/// Visit affine expressions recursively and build the sequence of operations
@@ -556,6 +558,51 @@ class AffineDmaWaitLowering : public OpRewritePattern<AffineDmaWaitOp> {
}
};
+/// Apply the affine map from an 'affine.vector_load' operation to its operands,
+/// and feed the results to a newly created 'vector.transfer_read' operation
+/// (which replaces the original 'affine.vector_load').
+class AffineVectorLoadLowering : public OpRewritePattern<AffineVectorLoadOp> {
+public:
+ using OpRewritePattern<AffineVectorLoadOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(AffineVectorLoadOp op,
+ PatternRewriter &rewriter) const override {
+ // Expand affine map from 'affineVectorLoadOp'.
+ SmallVector<Value, 8> indices(op.getMapOperands());
+ auto resultOperands =
+ expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
+ if (!resultOperands)
+ return failure();
+
+ // Build vector.transfer_read memref[expandedMap.results].
+ rewriter.replaceOpWithNewOp<TransferReadOp>(
+ op, op.getVectorType(), op.getMemRef(), *resultOperands);
+ return success();
+ }
+};
+
+/// Apply the affine map from an 'affine.vector_store' operation to its
+/// operands, and feed the results to a newly created 'vector.transfer_write'
+/// operation (which replaces the original 'affine.vector_store').
+class AffineVectorStoreLowering : public OpRewritePattern<AffineVectorStoreOp> {
+public:
+ using OpRewritePattern<AffineVectorStoreOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(AffineVectorStoreOp op,
+ PatternRewriter &rewriter) const override {
+ // Expand affine map from 'affineVectorStoreOp'.
+ SmallVector<Value, 8> indices(op.getMapOperands());
+ auto maybeExpandedMap =
+ expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
+ if (!maybeExpandedMap)
+ return failure();
+
+ rewriter.replaceOpWithNewOp<TransferWriteOp>(
+ op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap);
+ return success();
+ }
+};
+
} // end namespace
void mlir::populateAffineToStdConversionPatterns(
@@ -576,13 +623,24 @@ void mlir::populateAffineToStdConversionPatterns(
// clang-format on
}
+void mlir::populateAffineToVectorConversionPatterns(
+ OwningRewritePatternList &patterns, MLIRContext *ctx) {
+ // clang-format off
+ patterns.insert<
+ AffineVectorLoadLowering,
+ AffineVectorStoreLowering>(ctx);
+ // clang-format on
+}
+
namespace {
class LowerAffinePass : public ConvertAffineToStandardBase<LowerAffinePass> {
void runOnFunction() override {
OwningRewritePatternList patterns;
populateAffineToStdConversionPatterns(patterns, &getContext());
+ populateAffineToVectorConversionPatterns(patterns, &getContext());
ConversionTarget target(getContext());
- target.addLegalDialect<scf::SCFDialect, StandardOpsDialect>();
+ target
+ .addLegalDialect<scf::SCFDialect, StandardOpsDialect, VectorDialect>();
if (failed(applyPartialConversion(getFunction(), target, patterns)))
signalPassFailure();
}
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 16f4a3c6068e..27f4450924b6 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -1912,32 +1912,47 @@ void print(OpAsmPrinter &p, AffineLoadOp op) {
p << " : " << op.getMemRefType();
}
-LogicalResult verify(AffineLoadOp op) {
- if (op.getType() != op.getMemRefType().getElementType())
- return op.emitOpError("result type must match element type of memref");
-
- auto mapAttr = op.getAttrOfType<AffineMapAttr>(op.getMapAttrName());
+/// Verify common indexing invariants of affine.load, affine.store,
+/// affine.vector_load and affine.vector_store.
+static LogicalResult
+verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr,
+ Operation::operand_range mapOperands,
+ MemRefType memrefType, unsigned numIndexOperands) {
if (mapAttr) {
- AffineMap map =
- op.getAttrOfType<AffineMapAttr>(op.getMapAttrName()).getValue();
- if (map.getNumResults() != op.getMemRefType().getRank())
- return op.emitOpError("affine.load affine map num results must equal"
- " memref rank");
- if (map.getNumInputs() != op.getNumOperands() - 1)
- return op.emitOpError("expects as many subscripts as affine map inputs");
+ AffineMap map = mapAttr.getValue();
+ if (map.getNumResults() != memrefType.getRank())
+ return op->emitOpError("affine map num results must equal memref rank");
+ if (map.getNumInputs() != numIndexOperands)
+ return op->emitOpError("expects as many subscripts as affine map inputs");
} else {
- if (op.getMemRefType().getRank() != op.getNumOperands() - 1)
- return op.emitOpError(
+ if (memrefType.getRank() != numIndexOperands)
+ return op->emitOpError(
"expects the number of subscripts to be equal to memref rank");
}
Region *scope = getAffineScope(op);
- for (auto idx : op.getMapOperands()) {
+ for (auto idx : mapOperands) {
if (!idx.getType().isIndex())
- return op.emitOpError("index to load must have 'index' type");
+ return op->emitOpError("index to load/store must have 'index' type");
if (!isValidAffineIndexOperand(idx, scope))
- return op.emitOpError("index must be a dimension or symbol identifier");
+ return op->emitOpError("index must be a dimension or symbol identifier");
}
+
+ return success();
+}
+
+LogicalResult verify(AffineLoadOp op) {
+ auto memrefType = op.getMemRefType();
+ if (op.getType() != memrefType.getElementType())
+ return op.emitOpError("result type must match element type of memref");
+
+ if (failed(verifyMemoryOpIndexing(
+ op.getOperation(),
+ op.getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
+ op.getMapOperands(), memrefType,
+ /*numIndexOperands=*/op.getNumOperands() - 1)))
+ return failure();
+
return success();
}
@@ -2014,31 +2029,18 @@ void print(OpAsmPrinter &p, AffineStoreOp op) {
LogicalResult verify(AffineStoreOp op) {
// First operand must have same type as memref element type.
- if (op.getValueToStore().getType() != op.getMemRefType().getElementType())
+ auto memrefType = op.getMemRefType();
+ if (op.getValueToStore().getType() != memrefType.getElementType())
return op.emitOpError(
"first operand must have same type memref element type");
- auto mapAttr = op.getAttrOfType<AffineMapAttr>(op.getMapAttrName());
- if (mapAttr) {
- AffineMap map = mapAttr.getValue();
- if (map.getNumResults() != op.getMemRefType().getRank())
- return op.emitOpError("affine.store affine map num results must equal"
- " memref rank");
- if (map.getNumInputs() != op.getNumOperands() - 2)
- return op.emitOpError("expects as many subscripts as affine map inputs");
- } else {
- if (op.getMemRefType().getRank() != op.getNumOperands() - 2)
- return op.emitOpError(
- "expects the number of subscripts to be equal to memref rank");
- }
+ if (failed(verifyMemoryOpIndexing(
+ op.getOperation(),
+ op.getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
+ op.getMapOperands(), memrefType,
+ /*numIndexOperands=*/op.getNumOperands() - 2)))
+ return failure();
- Region *scope = getAffineScope(op);
- for (auto idx : op.getMapOperands()) {
- if (!idx.getType().isIndex())
- return op.emitOpError("index to store must have 'index' type");
- if (!isValidAffineIndexOperand(idx, scope))
- return op.emitOpError("index must be a dimension or symbol identifier");
- }
return success();
}
@@ -2493,6 +2495,130 @@ static ParseResult parseAffineParallelOp(OpAsmParser &parser,
return success();
}
+//===----------------------------------------------------------------------===//
+// AffineVectorLoadOp
+//===----------------------------------------------------------------------===//
+
+ParseResult parseAffineVectorLoadOp(OpAsmParser &parser,
+ OperationState &result) {
+ auto &builder = parser.getBuilder();
+ auto indexTy = builder.getIndexType();
+
+ MemRefType memrefType;
+ VectorType resultType;
+ OpAsmParser::OperandType memrefInfo;
+ AffineMapAttr mapAttr;
+ SmallVector<OpAsmParser::OperandType, 1> mapOperands;
+ return failure(
+ parser.parseOperand(memrefInfo) ||
+ parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
+ AffineVectorLoadOp::getMapAttrName(),
+ result.attributes) ||
+ parser.parseOptionalAttrDict(result.attributes) ||
+ parser.parseColonType(memrefType) || parser.parseComma() ||
+ parser.parseType(resultType) ||
+ parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
+ parser.resolveOperands(mapOperands, indexTy, result.operands) ||
+ parser.addTypeToList(resultType, result.types));
+}
+
+void print(OpAsmPrinter &p, AffineVectorLoadOp op) {
+ p << "affine.vector_load " << op.getMemRef() << '[';
+ if (AffineMapAttr mapAttr =
+ op.getAttrOfType<AffineMapAttr>(op.getMapAttrName()))
+ p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands());
+ p << ']';
+ p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{op.getMapAttrName()});
+ p << " : " << op.getMemRefType() << ", " << op.getType();
+}
+
+/// Verify common invariants of affine.vector_load and affine.vector_store.
+static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType,
+ VectorType vectorType) {
+ // Check that memref and vector element types match.
+ if (memrefType.getElementType() != vectorType.getElementType())
+ return op->emitOpError(
+ "requires memref and vector types of the same elemental type");
+
+ // The lowering to vector transfers builds a minor identity permutation map,
+ // which is only defined when the vector rank does not exceed the memref rank.
+ if (vectorType.getRank() > memrefType.getRank())
+ return op->emitOpError("requires vector rank not to exceed memref rank");
+
+ return success();
+}
+
+static LogicalResult verify(AffineVectorLoadOp op) {
+ MemRefType memrefType = op.getMemRefType();
+ if (failed(verifyMemoryOpIndexing(
+ op.getOperation(),
+ op.getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
+ op.getMapOperands(), memrefType,
+ /*numIndexOperands=*/op.getNumOperands() - 1)))
+ return failure();
+
+ if (failed(verifyVectorMemoryOp(op.getOperation(), memrefType,
+ op.getVectorType())))
+ return failure();
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// AffineVectorStoreOp
+//===----------------------------------------------------------------------===//
+
+ParseResult parseAffineVectorStoreOp(OpAsmParser &parser,
+ OperationState &result) {
+ auto indexTy = parser.getBuilder().getIndexType();
+
+ MemRefType memrefType;
+ VectorType resultType;
+ OpAsmParser::OperandType storeValueInfo;
+ OpAsmParser::OperandType memrefInfo;
+ AffineMapAttr mapAttr;
+ SmallVector<OpAsmParser::OperandType, 1> mapOperands;
+ return failure(
+ parser.parseOperand(storeValueInfo) || parser.parseComma() ||
+ parser.parseOperand(memrefInfo) ||
+ parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
+ AffineVectorStoreOp::getMapAttrName(),
+ result.attributes) ||
+ parser.parseOptionalAttrDict(result.attributes) ||
+ parser.parseColonType(memrefType) || parser.parseComma() ||
+ parser.parseType(resultType) ||
+ parser.resolveOperand(storeValueInfo, resultType, result.operands) ||
+ parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
+ parser.resolveOperands(mapOperands, indexTy, result.operands));
+}
+
+void print(OpAsmPrinter &p, AffineVectorStoreOp op) {
+ p << "affine.vector_store " << op.getValueToStore();
+ p << ", " << op.getMemRef() << '[';
+ if (AffineMapAttr mapAttr =
+ op.getAttrOfType<AffineMapAttr>(op.getMapAttrName()))
+ p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands());
+ p << ']';
+ p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{op.getMapAttrName()});
+ p << " : " << op.getMemRefType() << ", " << op.getValueToStore().getType();
+}
+
+static LogicalResult verify(AffineVectorStoreOp op) {
+ MemRefType memrefType = op.getMemRefType();
+ if (failed(verifyMemoryOpIndexing(
+ op.getOperation(),
+ op.getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
+ op.getMapOperands(), memrefType,
+ /*numIndexOperands=*/op.getNumOperands() - 2)))
+ return failure();
+
+ if (failed(verifyVectorMemoryOp(op.getOperation(), memrefType,
+ op.getVectorType())))
+ return failure();
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index ec4ef0f50361..8385b253e9fb 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1284,6 +1284,21 @@ static LogicalResult verifyTransferOp(Operation *op, MemRefType memrefType,
return success();
}
+/// Builder that defaults the permutation map to a minor identity map and the
+/// padding to a zero constant of the vector's element type.
+void TransferReadOp::build(OpBuilder &builder, OperationState &result,
+ VectorType vector, Value memref,
+ ValueRange indices) {
+ auto permMap = AffineMap::getMinorIdentityMap(
+ memref.getType().cast<MemRefType>().getRank(), vector.getRank(),
+ builder.getContext());
+ Type elemType = vector.getElementType();
+ Value padding = builder.create<ConstantOp>(result.location, elemType,
+ builder.getZeroAttr(elemType));
+
+ build(builder, result, vector, memref, indices, permMap, padding);
+}
+
static void print(OpAsmPrinter &p, TransferReadOp op) {
p << op.getOperationName() << " " << op.memref() << "[" << op.indices()
<< "], " << op.padding() << " ";
@@ -1361,6 +1376,17 @@ static LogicalResult verify(TransferReadOp op) {
// TransferWriteOp
//===----------------------------------------------------------------------===//
+/// Builder that sets the permutation map to 'getMinorIdentityMap' by default;
+/// unlike TransferReadOp, no padding value is involved.
+void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
+ Value vector, Value memref, ValueRange indices) {
+ auto vectorType = vector.getType().cast<VectorType>();
+ auto permMap = AffineMap::getMinorIdentityMap(
+ memref.getType().cast<MemRefType>().getRank(), vectorType.getRank(),
+ builder.getContext());
+ build(builder, result, vector, memref, indices, permMap);
+}
+
static LogicalResult verify(TransferWriteOp op) {
// Consistency of elemental types in memref and vector.
MemRefType memrefType = op.getMemRefType();
diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine-to-vector.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine-to-vector.mlir
new file mode 100644
index 000000000000..f9a78aa495a5
--- /dev/null
+++ b/mlir/test/Conversion/AffineToStandard/lower-affine-to-vector.mlir
@@ -0,0 +1,78 @@
+// RUN: mlir-opt -lower-affine --split-input-file %s | FileCheck %s
+
+// CHECK: #[[perm_map:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func @affine_vector_load
+func @affine_vector_load(%arg0 : index) {
+ %0 = alloc() : memref<100xf32>
+ affine.for %i0 = 0 to 16 {
+ %1 = affine.vector_load %0[%i0 + symbol(%arg0) + 7] : memref<100xf32>, vector<8xf32>
+ }
+// CHECK: %[[buf:.*]] = alloc
+// CHECK: %[[a:.*]] = addi %{{.*}}, %{{.*}} : index
+// CHECK-NEXT: %[[c7:.*]] = constant 7 : index
+// CHECK-NEXT: %[[b:.*]] = addi %[[a]], %[[c7]] : index
+// CHECK-NEXT: %[[pad:.*]] = constant 0.0
+// CHECK-NEXT: vector.transfer_read %[[buf]][%[[b]]], %[[pad]] {permutation_map = #[[perm_map]]} : memref<100xf32>, vector<8xf32>
+ return
+}
+
+// -----
+
+// CHECK: #[[perm_map:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func @affine_vector_store
+func @affine_vector_store(%arg0 : index) {
+ %0 = alloc() : memref<100xf32>
+ %1 = constant dense<11.0> : vector<4xf32>
+ affine.for %i0 = 0 to 16 {
+ affine.vector_store %1, %0[%i0 - symbol(%arg0) + 7] : memref<100xf32>, vector<4xf32>
+ }
+// CHECK: %[[buf:.*]] = alloc
+// CHECK: %[[val:.*]] = constant dense
+// CHECK: %[[c_1:.*]] = constant -1 : index
+// CHECK-NEXT: %[[a:.*]] = muli %arg0, %[[c_1]] : index
+// CHECK-NEXT: %[[b:.*]] = addi %{{.*}}, %[[a]] : index
+// CHECK-NEXT: %[[c7:.*]] = constant 7 : index
+// CHECK-NEXT: %[[c:.*]] = addi %[[b]], %[[c7]] : index
+// CHECK-NEXT: vector.transfer_write %[[val]], %[[buf]][%[[c]]] {permutation_map = #[[perm_map]]} : vector<4xf32>, memref<100xf32>
+ return
+}
+
+// -----
+
+// CHECK: #[[perm_map:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @vector_load_2d
+func @vector_load_2d() {
+ %0 = alloc() : memref<100x100xf32>
+ affine.for %i0 = 0 to 16 step 2 {
+ affine.for %i1 = 0 to 16 step 8 {
+ %1 = affine.vector_load %0[%i0, %i1] : memref<100x100xf32>, vector<2x8xf32>
+// CHECK: %[[buf:.*]] = alloc
+// CHECK: scf.for %[[i0:.*]] =
+// CHECK: scf.for %[[i1:.*]] =
+// CHECK-NEXT: %[[pad:.*]] = constant 0.0
+// CHECK-NEXT: vector.transfer_read %[[buf]][%[[i0]], %[[i1]]], %[[pad]] {permutation_map = #[[perm_map]]} : memref<100x100xf32>, vector<2x8xf32>
+ }
+ }
+ return
+}
+
+// -----
+
+// CHECK: #[[perm_map:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @vector_store_2d
+func @vector_store_2d() {
+ %0 = alloc() : memref<100x100xf32>
+ %1 = constant dense<11.0> : vector<2x8xf32>
+ affine.for %i0 = 0 to 16 step 2 {
+ affine.for %i1 = 0 to 16 step 8 {
+ affine.vector_store %1, %0[%i0, %i1] : memref<100x100xf32>, vector<2x8xf32>
+// CHECK: %[[buf:.*]] = alloc
+// CHECK: %[[val:.*]] = constant dense
+// CHECK: scf.for %[[i0:.*]] =
+// CHECK: scf.for %[[i1:.*]] =
+// CHECK-NEXT: vector.transfer_write %[[val]], %[[buf]][%[[i0]], %[[i1]]] {permutation_map = #[[perm_map]]} : vector<2x8xf32>, memref<100x100xf32>
+ }
+ }
+ return
+}
+
diff --git a/mlir/test/Dialect/Affine/invalid.mlir b/mlir/test/Dialect/Affine/invalid.mlir
index c0855987ac32..102dd394f93c 100644
--- a/mlir/test/Dialect/Affine/invalid.mlir
+++ b/mlir/test/Dialect/Affine/invalid.mlir
@@ -262,3 +262,49 @@ func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) {
}
return
}
+
+// -----
+
+func @vector_load_invalid_vector_type() {  // f32 memref loaded as an f64 vector: must be rejected.
+ %0 = alloc() : memref<100xf32>
+ affine.for %i0 = 0 to 16 step 8 {
+ // expected-error@+1 {{requires memref and vector types of the same elemental type}}
+ %1 = affine.vector_load %0[%i0] : memref<100xf32>, vector<8xf64>
+ }
+ return
+}
+
+// -----
+
+func @vector_store_invalid_vector_type() {  // f64 vector stored into an f32 memref: must be rejected.
+ %0 = alloc() : memref<100xf32>
+ %1 = constant dense<7.0> : vector<8xf64>
+ affine.for %i0 = 0 to 16 step 8 {
+ // expected-error@+1 {{requires memref and vector types of the same elemental type}}
+ affine.vector_store %1, %0[%i0] : memref<100xf32>, vector<8xf64>
+ }
+ return
+}
+
+// -----
+
+func @vector_load_vector_memref() {  // memref element type is vector<8xf32>, not f32: must be rejected.
+ %0 = alloc() : memref<100xvector<8xf32>>
+ affine.for %i0 = 0 to 4 {
+ // expected-error@+1 {{requires memref and vector types of the same elemental type}}
+ %1 = affine.vector_load %0[%i0] : memref<100xvector<8xf32>>, vector<8xf32>
+ }
+ return
+}
+
+// -----
+
+func @vector_store_vector_memref() {  // memref element type is vector<8xf32>, not f32: must be rejected.
+ %0 = alloc() : memref<100xvector<8xf32>>
+ %1 = constant dense<7.0> : vector<8xf32>
+ affine.for %i0 = 0 to 4 {
+ // expected-error@+1 {{requires memref and vector types of the same elemental type}}
+ affine.vector_store %1, %0[%i0] : memref<100xvector<8xf32>>, vector<8xf32>
+ }
+ return
+}
diff --git a/mlir/test/Dialect/Affine/load-store.mlir b/mlir/test/Dialect/Affine/load-store.mlir
index 54e753d17fef..06a93d41fd64 100644
--- a/mlir/test/Dialect/Affine/load-store.mlir
+++ b/mlir/test/Dialect/Affine/load-store.mlir
@@ -214,3 +214,65 @@ func @test_prefetch(%arg0 : index, %arg1 : index) {
}
return
}
+
+// -----
+
+// CHECK: [[MAP_ID:#map[0-9]+]] = affine_map<(d0, d1) -> (d0, d1)>
+
+// Test with plain loop IVs as the only subscripts.
+func @vector_load_vector_store_iv() {
+ %mem = alloc() : memref<100x100xf32>
+ affine.for %i = 0 to 16 {
+ affine.for %j = 0 to 16 step 8 {
+ %vec = affine.vector_load %mem[%i, %j] : memref<100x100xf32>, vector<8xf32>
+ affine.vector_store %vec, %mem[%i, %j] : memref<100x100xf32>, vector<8xf32>
+// CHECK: %[[buf:.*]] = alloc
+// CHECK-NEXT: affine.for %[[i0:.*]] = 0
+// CHECK-NEXT: affine.for %[[i1:.*]] = 0
+// CHECK-NEXT: %[[val:.*]] = affine.vector_load %[[buf]][%[[i0]], %[[i1]]] : memref<100x100xf32>, vector<8xf32>
+// CHECK-NEXT: affine.vector_store %[[val]], %[[buf]][%[[i0]], %[[i1]]] : memref<100x100xf32>, vector<8xf32>
+ }
+ }
+ return
+}
+
+// -----
+
+// CHECK: [[MAP0:#map[0-9]+]] = affine_map<(d0, d1) -> (d0 + 3, d1 + 7)>
+
+// Test with loop IVs and constants; the load must use the same buffer and IVs as the store.
+func @vector_load_vector_store_iv_constant() {
+ %0 = alloc() : memref<100x100xf32>
+ affine.for %i0 = 0 to 10 {
+ affine.for %i1 = 0 to 16 step 4 {
+ %1 = affine.vector_load %0[%i0 + 3, %i1 + 7] : memref<100x100xf32>, vector<4xf32>
+ affine.vector_store %1, %0[%i0 + 3, %i1 + 7] : memref<100x100xf32>, vector<4xf32>
+// CHECK: %[[buf:.*]] = alloc
+// CHECK-NEXT: affine.for %[[i0:.*]] = 0
+// CHECK-NEXT: affine.for %[[i1:.*]] = 0
+// CHECK-NEXT: %[[val:.*]] = affine.vector_load %[[buf]][%[[i0]] + 3, %[[i1]] + 7] : memref<100x100xf32>, vector<4xf32>
+// CHECK-NEXT: affine.vector_store %[[val]], %[[buf]][%[[i0]] + 3, %[[i1]] + 7] : memref<100x100xf32>, vector<4xf32>
+ }
+ }
+ return
+}
+
+// -----
+
+// CHECK: [[MAP0:#map[0-9]+]] = affine_map<(d0, d1) -> (d0, d1)>
+// Test with a 2-D vector type and loop IVs as subscripts.
+func @vector_load_vector_store_2d() {
+ %mem = alloc() : memref<100x100xf32>
+ affine.for %i = 0 to 16 step 2 {
+ affine.for %j = 0 to 16 step 8 {
+ %vec = affine.vector_load %mem[%i, %j] : memref<100x100xf32>, vector<2x8xf32>
+ affine.vector_store %vec, %mem[%i, %j] : memref<100x100xf32>, vector<2x8xf32>
+// CHECK: %[[buf:.*]] = alloc
+// CHECK-NEXT: affine.for %[[i0:.*]] = 0
+// CHECK-NEXT: affine.for %[[i1:.*]] = 0
+// CHECK-NEXT: %[[val:.*]] = affine.vector_load %[[buf]][%[[i0]], %[[i1]]] : memref<100x100xf32>, vector<2x8xf32>
+// CHECK-NEXT: affine.vector_store %[[val]], %[[buf]][%[[i0]], %[[i1]]] : memref<100x100xf32>, vector<2x8xf32>
+ }
+ }
+ return
+}
More information about the llvm-branch-commits
mailing list