[Mlir-commits] [mlir] a45fb19 - [mlir][Affine] Introduce affine memory interfaces
Diego Caballero
llvmlistbot at llvm.org
Tue May 19 17:44:40 PDT 2020
Author: Diego Caballero
Date: 2020-05-19T17:32:50-07:00
New Revision: a45fb1942fc5d21dfcdc37b99ab98778d3b16b79
URL: https://github.com/llvm/llvm-project/commit/a45fb1942fc5d21dfcdc37b99ab98778d3b16b79
DIFF: https://github.com/llvm/llvm-project/commit/a45fb1942fc5d21dfcdc37b99ab98778d3b16b79.diff
LOG: [mlir][Affine] Introduce affine memory interfaces
This patch introduces interfaces for read and write ops with affine
restrictions. I used `read`/`write` instead of `load`/`store` for the
interfaces so that they can also be implemented by DMA ops.
For now, they are only implemented by affine.load, affine.store,
affine.vector_load and affine.vector_store.
For testing purposes, this patch also migrates affine loop fusion and the
required analyses to use the new interfaces. No other changes are made
beyond that.
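As an illustration (not part of this patch), passes can now handle all four
ops through a single check. A minimal C++ sketch, assuming the usual FuncOp
walk API; the helper name is hypothetical:

  // Sketch: collect affine reads/writes generically, covering affine.load,
  // affine.store, affine.vector_load and affine.vector_store.
  static void collectAffineAccesses(FuncOp f,
                                    SmallVectorImpl<Operation *> &accesses) {
    f.walk([&](Operation *op) {
      if (isa<AffineReadOpInterface>(op) || isa<AffineWriteOpInterface>(op))
        accesses.push_back(op);
    });
  }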
Co-authored-by: Alex Zinenko <zinenko at google.com>
Reviewed By: bondhugula, ftynse
Differential Revision: https://reviews.llvm.org/D79829
Added:
mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h
mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td
mlir/lib/Dialect/Affine/IR/AffineMemoryOpInterfaces.cpp
Modified:
mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
mlir/include/mlir/Dialect/Affine/IR/CMakeLists.txt
mlir/lib/Analysis/AffineAnalysis.cpp
mlir/lib/Analysis/Utils.cpp
mlir/lib/Dialect/Affine/IR/CMakeLists.txt
mlir/lib/Transforms/LoopFusion.cpp
mlir/test/lib/Transforms/TestMemRefBoundCheck.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h b/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h
new file mode 100644
index 000000000000..f42fc256befa
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h
@@ -0,0 +1,24 @@
+//===- AffineMemoryOpInterfaces.h -------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains a set of interfaces for affine memory ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_AFFINE_IR_AFFINEMEMORYOPINTERFACES_H
+#define MLIR_DIALECT_AFFINE_IR_AFFINEMEMORYOPINTERFACES_H
+
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/StandardTypes.h"
+
+namespace mlir {
+#include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h.inc"
+} // namespace mlir
+
+#endif // MLIR_DIALECT_AFFINE_IR_AFFINEMEMORYOPINTERFACES_H
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td b/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td
new file mode 100644
index 000000000000..8738000d8d5f
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td
@@ -0,0 +1,128 @@
+//===- AffineMemoryOpInterfaces.td -------------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains a set of interfaces for affine memory ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_AFFINEMEMORYOPINTERFACES
+#define MLIR_AFFINEMEMORYOPINTERFACES
+
+include "mlir/IR/OpBase.td"
+
+def AffineReadOpInterface : OpInterface<"AffineReadOpInterface"> {
+ let description = [{
+ Interface to query characteristics of read-like ops with affine
+ restrictions.
+ }];
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{ Returns the memref operand to read from. }],
+ /*retTy=*/"Value",
+ /*methodName=*/"getMemRef",
+ /*args=*/(ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+ ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+ return op.getOperand(op.getMemRefOperandIndex());
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{ Returns the type of the memref operand. }],
+ /*retTy=*/"MemRefType",
+ /*methodName=*/"getMemRefType",
+ /*args=*/(ins),
+ /*methodBody=*/[{}],
+ /*defaultImplementation=*/[{
+ ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+ return op.getMemRef().getType().template cast<MemRefType>();
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{ Returns affine map operands. }],
+ /*retTy=*/"Operation::operand_range",
+ /*methodName=*/"getMapOperands",
+ /*args=*/(ins),
+ /*methodBody=*/[{}],
+ /*defaultImplementation=*/[{
+ ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+ return llvm::drop_begin(op.getOperands(), 1);
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{ Returns the affine map used to index the memref for this
+ operation. }],
+ /*retTy=*/"AffineMap",
+ /*methodName=*/"getAffineMap",
+ /*args=*/(ins),
+ /*methodBody=*/[{}],
+ /*defaultImplementation=*/[{
+ ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+ return op.getAffineMapAttr().getValue();
+ }]
+ >,
+ ];
+}
+
+def AffineWriteOpInterface : OpInterface<"AffineWriteOpInterface"> {
+ let description = [{
+ Interface to query characteristics of write-like ops with affine
+ restrictions.
+ }];
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{ Returns the memref operand to write to. }],
+ /*retTy=*/"Value",
+ /*methodName=*/"getMemRef",
+ /*args=*/(ins),
+ /*methodBody=*/[{}],
+ /*defaultImplementation=*/[{
+ ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+ return op.getOperand(op.getMemRefOperandIndex());
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{ Returns the type of the memref operand. }],
+ /*retTy=*/"MemRefType",
+ /*methodName=*/"getMemRefType",
+ /*args=*/(ins),
+ /*methodBody=*/[{}],
+ /*defaultImplementation=*/[{
+ ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+ return op.getMemRef().getType().template cast<MemRefType>();
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{ Returns affine map operands. }],
+ /*retTy=*/"Operation::operand_range",
+ /*methodName=*/"getMapOperands",
+ /*args=*/(ins),
+ /*methodBody=*/[{}],
+ /*defaultImplementation=*/[{
+ ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+ return llvm::drop_begin(op.getOperands(), 2);
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{ Returns the affine map used to index the memref for this
+ operation. }],
+ /*retTy=*/"AffineMap",
+ /*methodName=*/"getAffineMap",
+ /*args=*/(ins),
+      /*methodBody=*/[{}],
+ /*defaultImplementation=*/[{
+ ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+ return op.getAffineMapAttr().getValue();
+ }]
+ >,
+ ];
+}
+
+#endif // MLIR_AFFINEMEMORYOPINTERFACES
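The default implementations above assume the implementing op exposes
getMemRefOperandIndex() and getAffineMapAttr(), as affine.load/store do
below. Under that assumption, client code can query any access uniformly;
a hedged C++ sketch:

  // Sketch: uniform queries through the read interface; 'op' is any
  // Operation* that may implement AffineReadOpInterface.
  if (auto read = dyn_cast<AffineReadOpInterface>(op)) {
    Value memref = read.getMemRef();        // operand at getMemRefOperandIndex()
    MemRefType type = read.getMemRefType(); // that operand's MemRefType
    AffineMap map = read.getAffineMap();    // indexing map
    auto mapOperands = read.getMapOperands(); // all operands after the memref
  }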
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
index 4910fc92ea09..808fe892b4c1 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
@@ -14,6 +14,7 @@
#ifndef MLIR_DIALECT_AFFINE_IR_AFFINEOPS_H
#define MLIR_DIALECT_AFFINE_IR_AFFINEOPS_H
+#include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 8286d8f315bd..3366ef47a8be 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -14,6 +14,7 @@
#define AFFINE_OPS
include "mlir/Dialect/Affine/IR/AffineOpsBase.td"
+include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -371,7 +372,8 @@ def AffineIfOp : Affine_Op<"if",
}
class AffineLoadOpBase<string mnemonic, list<OpTrait> traits = []> :
- Affine_Op<mnemonic, traits> {
+ Affine_Op<mnemonic, !listconcat(traits,
+ [DeclareOpInterfaceMethods<AffineReadOpInterface>])> {
let arguments = (ins Arg<AnyMemRef, "the reference to load from",
[MemRead]>:$memref,
Variadic<Index>:$indices);
@@ -380,18 +382,9 @@ class AffineLoadOpBase<string mnemonic, list<OpTrait> traits = []> :
/// Returns the operand index of the memref.
unsigned getMemRefOperandIndex() { return 0; }
- /// Get memref operand.
- Value getMemRef() { return getOperand(getMemRefOperandIndex()); }
void setMemRef(Value value) { setOperand(getMemRefOperandIndex(), value); }
- MemRefType getMemRefType() {
- return getMemRef().getType().cast<MemRefType>();
- }
-
- /// Get affine map operands.
- operand_range getMapOperands() { return llvm::drop_begin(getOperands(), 1); }
/// Returns the affine map used to index the memref for this operation.
- AffineMap getAffineMap() { return getAffineMapAttr().getValue(); }
AffineMapAttr getAffineMapAttr() {
return getAttr(getMapAttrName()).cast<AffineMapAttr>();
}
@@ -407,7 +400,7 @@ class AffineLoadOpBase<string mnemonic, list<OpTrait> traits = []> :
}];
}
-def AffineLoadOp : AffineLoadOpBase<"load", []> {
+def AffineLoadOp : AffineLoadOpBase<"load"> {
let summary = "affine load operation";
let description = [{
The "affine.load" op reads an element from a memref, where the index
@@ -666,8 +659,8 @@ def AffinePrefetchOp : Affine_Op<"prefetch"> {
}
class AffineStoreOpBase<string mnemonic, list<OpTrait> traits = []> :
- Affine_Op<mnemonic, traits> {
-
+ Affine_Op<mnemonic, !listconcat(traits,
+ [DeclareOpInterfaceMethods<AffineWriteOpInterface>])> {
code extraClassDeclarationBase = [{
/// Get value to be stored by store operation.
Value getValueToStore() { return getOperand(0); }
@@ -675,19 +668,9 @@ class AffineStoreOpBase<string mnemonic, list<OpTrait> traits = []> :
/// Returns the operand index of the memref.
unsigned getMemRefOperandIndex() { return 1; }
- /// Get memref operand.
- Value getMemRef() { return getOperand(getMemRefOperandIndex()); }
void setMemRef(Value value) { setOperand(getMemRefOperandIndex(), value); }
- MemRefType getMemRefType() {
- return getMemRef().getType().cast<MemRefType>();
- }
-
- /// Get affine map operands.
- operand_range getMapOperands() { return llvm::drop_begin(getOperands(), 2); }
-
/// Returns the affine map used to index the memref for this operation.
- AffineMap getAffineMap() { return getAffineMapAttr().getValue(); }
AffineMapAttr getAffineMapAttr() {
return getAttr(getMapAttrName()).cast<AffineMapAttr>();
}
@@ -703,7 +686,7 @@ class AffineStoreOpBase<string mnemonic, list<OpTrait> traits = []> :
}];
}
-def AffineStoreOp : AffineStoreOpBase<"store", []> {
+def AffineStoreOp : AffineStoreOpBase<"store"> {
let summary = "affine store operation";
let description = [{
The "affine.store" op writes an element to a memref, where the index
@@ -776,7 +759,7 @@ def AffineTerminatorOp :
let verifier = ?;
}
-def AffineVectorLoadOp : AffineLoadOpBase<"vector_load", []> {
+def AffineVectorLoadOp : AffineLoadOpBase<"vector_load"> {
let summary = "affine vector load operation";
let description = [{
The "affine.vector_load" is the vector counterpart of
@@ -825,7 +808,7 @@ def AffineVectorLoadOp : AffineLoadOpBase<"vector_load", []> {
}];
}
-def AffineVectorStoreOp : AffineStoreOpBase<"vector_store", []> {
+def AffineVectorStoreOp : AffineStoreOpBase<"vector_store"> {
let summary = "affine vector store operation";
let description = [{
The "affine.vector_store" is the vector counterpart of
diff --git a/mlir/include/mlir/Dialect/Affine/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Affine/IR/CMakeLists.txt
index 1fd2a505d850..77806274f14c 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Affine/IR/CMakeLists.txt
@@ -1,2 +1,10 @@
add_mlir_dialect(AffineOps affine)
add_mlir_doc(AffineOps -gen-op-doc AffineOps Dialects/)
+
+set(LLVM_TARGET_DEFINITIONS AffineMemoryOpInterfaces.td)
+mlir_tablegen(AffineMemoryOpInterfaces.h.inc -gen-op-interface-decls)
+mlir_tablegen(AffineMemoryOpInterfaces.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRAffineMemoryOpInterfacesIncGen)
+add_dependencies(mlir-generic-headers MLIRAffineMemoryOpInterfacesIncGen)
+
+add_dependencies(MLIRAffineOpsIncGen MLIRAffineMemoryOpInterfacesIncGen)
diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp
index 5a395937101f..8c4828805882 100644
--- a/mlir/lib/Analysis/AffineAnalysis.cpp
+++ b/mlir/lib/Analysis/AffineAnalysis.cpp
@@ -660,10 +660,12 @@ static void computeDirectionVector(
void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const {
// Get affine map from AffineLoad/Store.
AffineMap map;
- if (auto loadOp = dyn_cast<AffineLoadOp>(opInst))
+ if (auto loadOp = dyn_cast<AffineReadOpInterface>(opInst)) {
map = loadOp.getAffineMap();
- else if (auto storeOp = dyn_cast<AffineStoreOp>(opInst))
+ } else {
+ auto storeOp = cast<AffineWriteOpInterface>(opInst);
map = storeOp.getAffineMap();
+ }
SmallVector<Value, 8> operands(indices.begin(), indices.end());
fullyComposeAffineMapAndOperands(&map, &operands);
map = simplifyAffineMap(map);
@@ -771,9 +773,10 @@ DependenceResult mlir::checkMemrefAccessDependence(
if (srcAccess.memref != dstAccess.memref)
return DependenceResult::NoDependence;
- // Return 'NoDependence' if one of these accesses is not an AffineStoreOp.
- if (!allowRAR && !isa<AffineStoreOp>(srcAccess.opInst) &&
- !isa<AffineStoreOp>(dstAccess.opInst))
+  // Return 'NoDependence' if neither of these accesses is an
+  // AffineWriteOpInterface.
+ if (!allowRAR && !isa<AffineWriteOpInterface>(srcAccess.opInst) &&
+ !isa<AffineWriteOpInterface>(dstAccess.opInst))
return DependenceResult::NoDependence;
// Get composed access function for 'srcAccess'.
@@ -857,7 +860,8 @@ void mlir::getDependenceComponents(
// Collect all load and store ops in loop nest rooted at 'forOp'.
SmallVector<Operation *, 8> loadAndStoreOpInsts;
forOp.getOperation()->walk([&](Operation *opInst) {
- if (isa<AffineLoadOp>(opInst) || isa<AffineStoreOp>(opInst))
+ if (isa<AffineReadOpInterface>(opInst) ||
+ isa<AffineWriteOpInterface>(opInst))
loadAndStoreOpInsts.push_back(opInst);
});
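Since MemRefAccess and the dependence check are now keyed on the interfaces,
the same queries apply to the vector ops without further changes. A hedged
sketch using only entry points visible in this diff:

  // Sketch: any pair of affine read/write ops can be tested for a dependence.
  MemRefAccess srcAccess(srcOpInst); // ops implementing the memory interfaces
  MemRefAccess dstAccess(dstOpInst);
  if (srcAccess.memref == dstAccess.memref) {
    // checkMemrefAccessDependence(...) decides; read-after-read pairs yield
    // DependenceResult::NoDependence unless 'allowRAR' is set.
  }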
diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
index e6d7127762d5..b43d0bedec56 100644
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -196,8 +196,8 @@ LogicalResult MemRefRegion::unionBoundingBox(const MemRefRegion &other) {
LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
ComputationSliceState *sliceState,
bool addMemRefDimBounds) {
- assert((isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op)) &&
- "affine load/store op expected");
+ assert((isa<AffineReadOpInterface>(op) || isa<AffineWriteOpInterface>(op)) &&
+ "affine read/write op expected");
MemRefAccess access(op);
memref = access.memref;
@@ -404,9 +404,10 @@ Optional<uint64_t> mlir::getMemRefSizeInBytes(MemRefType memRefType) {
template <typename LoadOrStoreOp>
LogicalResult mlir::boundCheckLoadOrStoreOp(LoadOrStoreOp loadOrStoreOp,
bool emitError) {
- static_assert(
- llvm::is_one_of<LoadOrStoreOp, AffineLoadOp, AffineStoreOp>::value,
- "argument should be either a AffineLoadOp or a AffineStoreOp");
+ static_assert(llvm::is_one_of<LoadOrStoreOp, AffineReadOpInterface,
+ AffineWriteOpInterface>::value,
+ "argument should be either a AffineReadOpInterface or a "
+ "AffineWriteOpInterface");
Operation *op = loadOrStoreOp.getOperation();
MemRefRegion region(op->getLoc());
@@ -456,10 +457,10 @@ LogicalResult mlir::boundCheckLoadOrStoreOp(LoadOrStoreOp loadOrStoreOp,
}
// Explicitly instantiate the template so that the compiler knows we need them!
-template LogicalResult mlir::boundCheckLoadOrStoreOp(AffineLoadOp loadOp,
- bool emitError);
-template LogicalResult mlir::boundCheckLoadOrStoreOp(AffineStoreOp storeOp,
- bool emitError);
+template LogicalResult
+mlir::boundCheckLoadOrStoreOp(AffineReadOpInterface loadOp, bool emitError);
+template LogicalResult
+mlir::boundCheckLoadOrStoreOp(AffineWriteOpInterface storeOp, bool emitError);
// Returns in 'positions' the Block positions of 'op' in each ancestor
// Block from the Block containing operation, stopping at 'limitBlock'.
@@ -575,8 +576,8 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
return failure();
}
- bool readReadAccesses = isa<AffineLoadOp>(srcAccess.opInst) &&
- isa<AffineLoadOp>(dstAccess.opInst);
+ bool readReadAccesses = isa<AffineReadOpInterface>(srcAccess.opInst) &&
+ isa<AffineReadOpInterface>(dstAccess.opInst);
FlatAffineConstraints dependenceConstraints;
// Check dependence between 'srcAccess' and 'dstAccess'.
DependenceResult result = checkMemrefAccessDependence(
@@ -768,7 +769,8 @@ void mlir::getComputationSliceState(
: std::prev(srcLoopIVs[loopDepth - 1].getBody()->end());
llvm::SmallDenseSet<Value, 8> sequentialLoops;
- if (isa<AffineLoadOp>(depSourceOp) && isa<AffineLoadOp>(depSinkOp)) {
+ if (isa<AffineReadOpInterface>(depSourceOp) &&
+ isa<AffineReadOpInterface>(depSinkOp)) {
// For read-read access pairs, clear any slice bounds on sequential loops.
// Get sequential loops in loop nest rooted at 'srcLoopIVs[0]'.
getSequentialLoops(isBackwardSlice ? srcLoopIVs[0] : dstLoopIVs[0],
@@ -865,7 +867,7 @@ mlir::insertBackwardComputationSlice(Operation *srcOpInst, Operation *dstOpInst,
// Constructs MemRefAccess populating it with the memref, its indices and
// opinst from 'loadOrStoreOpInst'.
MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) {
- if (auto loadOp = dyn_cast<AffineLoadOp>(loadOrStoreOpInst)) {
+ if (auto loadOp = dyn_cast<AffineReadOpInterface>(loadOrStoreOpInst)) {
memref = loadOp.getMemRef();
opInst = loadOrStoreOpInst;
auto loadMemrefType = loadOp.getMemRefType();
@@ -874,8 +876,9 @@ MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) {
indices.push_back(index);
}
} else {
- assert(isa<AffineStoreOp>(loadOrStoreOpInst) && "load/store op expected");
- auto storeOp = dyn_cast<AffineStoreOp>(loadOrStoreOpInst);
+ assert(isa<AffineWriteOpInterface>(loadOrStoreOpInst) &&
+ "Affine read/write op expected");
+ auto storeOp = cast<AffineWriteOpInterface>(loadOrStoreOpInst);
opInst = loadOrStoreOpInst;
memref = storeOp.getMemRef();
auto storeMemrefType = storeOp.getMemRefType();
@@ -890,7 +893,9 @@ unsigned MemRefAccess::getRank() const {
return memref.getType().cast<MemRefType>().getRank();
}
-bool MemRefAccess::isStore() const { return isa<AffineStoreOp>(opInst); }
+bool MemRefAccess::isStore() const {
+ return isa<AffineWriteOpInterface>(opInst);
+}
/// Returns the nesting depth of this statement, i.e., the number of loops
/// surrounding this statement.
@@ -947,7 +952,8 @@ static Optional<int64_t> getMemoryFootprintBytes(Block &block,
// Walk this 'affine.for' operation to gather all memory regions.
auto result = block.walk(start, end, [&](Operation *opInst) -> WalkResult {
- if (!isa<AffineLoadOp>(opInst) && !isa<AffineStoreOp>(opInst)) {
+ if (!isa<AffineReadOpInterface>(opInst) &&
+ !isa<AffineWriteOpInterface>(opInst)) {
// Neither load nor a store op.
return WalkResult::advance();
}
@@ -1007,7 +1013,8 @@ bool mlir::isLoopParallel(AffineForOp forOp) {
// Collect all load and store ops in loop nest rooted at 'forOp'.
SmallVector<Operation *, 8> loadAndStoreOpInsts;
auto walkResult = forOp.walk([&](Operation *opInst) -> WalkResult {
- if (isa<AffineLoadOp>(opInst) || isa<AffineStoreOp>(opInst))
+ if (isa<AffineReadOpInterface>(opInst) ||
+ isa<AffineWriteOpInterface>(opInst))
loadAndStoreOpInsts.push_back(opInst);
else if (!isa<AffineForOp>(opInst) && !isa<AffineTerminatorOp>(opInst) &&
!isa<AffineIfOp>(opInst) &&
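Note the explicit instantiations above: boundCheckLoadOrStoreOp is defined in
this .cpp file, so instantiations for both interface types must be emitted
here. The matching header declaration (not shown in this diff) presumably
looks like the following sketch; the default argument is an assumption:

  // Sketch: assumed declaration shape in the corresponding header.
  template <typename LoadOrStoreOp>
  LogicalResult boundCheckLoadOrStoreOp(LoadOrStoreOp loadOrStoreOp,
                                        bool emitError = true);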
diff --git a/mlir/lib/Dialect/Affine/IR/AffineMemoryOpInterfaces.cpp b/mlir/lib/Dialect/Affine/IR/AffineMemoryOpInterfaces.cpp
new file mode 100644
index 000000000000..6f5861efa956
--- /dev/null
+++ b/mlir/lib/Dialect/Affine/IR/AffineMemoryOpInterfaces.cpp
@@ -0,0 +1,18 @@
+//===- AffineMemoryOpInterfaces.cpp - Affine memory op interfaces ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Affine Memory Op Interfaces
+//===----------------------------------------------------------------------===//
+
+/// Include the definitions of the affine memory op interfaces.
+#include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.cpp.inc"
diff --git a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
index 12105b40571b..20bc86366668 100644
--- a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRAffineOps
+ AffineMemoryOpInterfaces.cpp
AffineOps.cpp
AffineValueMap.cpp
@@ -6,6 +7,7 @@ add_mlir_dialect_library(MLIRAffineOps
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Affine
DEPENDS
+ MLIRAffineMemoryOpInterfacesIncGen
MLIRAffineOpsIncGen
LINK_LIBS PUBLIC
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 72dfc1d62faf..bb219fa07711 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -70,7 +70,7 @@ mlir::createLoopFusionPass(unsigned fastMemorySpace,
// TODO(b/117228571) Replace when this is modeled through side-effects/op traits
static bool isMemRefDereferencingOp(Operation &op) {
- if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op) ||
+ if (isa<AffineReadOpInterface>(op) || isa<AffineWriteOpInterface>(op) ||
isa<AffineDmaStartOp>(op) || isa<AffineDmaWaitOp>(op))
return true;
return false;
@@ -92,9 +92,9 @@ struct LoopNestStateCollector {
forOps.push_back(cast<AffineForOp>(op));
else if (op->getNumRegions() != 0)
hasNonForRegion = true;
- else if (isa<AffineLoadOp>(op))
+ else if (isa<AffineReadOpInterface>(op))
loadOpInsts.push_back(op);
- else if (isa<AffineStoreOp>(op))
+ else if (isa<AffineWriteOpInterface>(op))
storeOpInsts.push_back(op);
});
}
@@ -125,7 +125,7 @@ struct MemRefDependenceGraph {
unsigned getLoadOpCount(Value memref) {
unsigned loadOpCount = 0;
for (auto *loadOpInst : loads) {
- if (memref == cast<AffineLoadOp>(loadOpInst).getMemRef())
+ if (memref == cast<AffineReadOpInterface>(loadOpInst).getMemRef())
++loadOpCount;
}
return loadOpCount;
@@ -135,7 +135,7 @@ struct MemRefDependenceGraph {
unsigned getStoreOpCount(Value memref) {
unsigned storeOpCount = 0;
for (auto *storeOpInst : stores) {
- if (memref == cast<AffineStoreOp>(storeOpInst).getMemRef())
+ if (memref == cast<AffineWriteOpInterface>(storeOpInst).getMemRef())
++storeOpCount;
}
return storeOpCount;
@@ -145,7 +145,7 @@ struct MemRefDependenceGraph {
void getStoreOpsForMemref(Value memref,
SmallVectorImpl<Operation *> *storeOps) {
for (auto *storeOpInst : stores) {
- if (memref == cast<AffineStoreOp>(storeOpInst).getMemRef())
+ if (memref == cast<AffineWriteOpInterface>(storeOpInst).getMemRef())
storeOps->push_back(storeOpInst);
}
}
@@ -154,7 +154,7 @@ struct MemRefDependenceGraph {
void getLoadOpsForMemref(Value memref,
SmallVectorImpl<Operation *> *loadOps) {
for (auto *loadOpInst : loads) {
- if (memref == cast<AffineLoadOp>(loadOpInst).getMemRef())
+ if (memref == cast<AffineReadOpInterface>(loadOpInst).getMemRef())
loadOps->push_back(loadOpInst);
}
}
@@ -164,10 +164,10 @@ struct MemRefDependenceGraph {
void getLoadAndStoreMemrefSet(DenseSet<Value> *loadAndStoreMemrefSet) {
llvm::SmallDenseSet<Value, 2> loadMemrefs;
for (auto *loadOpInst : loads) {
- loadMemrefs.insert(cast<AffineLoadOp>(loadOpInst).getMemRef());
+ loadMemrefs.insert(cast<AffineReadOpInterface>(loadOpInst).getMemRef());
}
for (auto *storeOpInst : stores) {
- auto memref = cast<AffineStoreOp>(storeOpInst).getMemRef();
+ auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
if (loadMemrefs.count(memref) > 0)
loadAndStoreMemrefSet->insert(memref);
}
@@ -259,7 +259,7 @@ struct MemRefDependenceGraph {
bool writesToLiveInOrEscapingMemrefs(unsigned id) {
Node *node = getNode(id);
for (auto *storeOpInst : node->stores) {
- auto memref = cast<AffineStoreOp>(storeOpInst).getMemRef();
+ auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
auto *op = memref.getDefiningOp();
// Return true if 'memref' is a block argument.
if (!op)
@@ -272,13 +272,14 @@ struct MemRefDependenceGraph {
return false;
}
- // Returns the unique AffineStoreOp in `node` that meets all the following:
+ // Returns the unique AffineWriteOpInterface in `node` that meets all the
+ // following:
// *) store is the only one that writes to a function-local memref live out
// of `node`,
// *) store is not the source of a self-dependence on `node`.
- // Otherwise, returns a null AffineStoreOp.
- AffineStoreOp getUniqueOutgoingStore(Node *node) {
- AffineStoreOp uniqueStore;
+ // Otherwise, returns a null AffineWriteOpInterface.
+ AffineWriteOpInterface getUniqueOutgoingStore(Node *node) {
+ AffineWriteOpInterface uniqueStore;
// Return null if `node` doesn't have any outgoing edges.
auto outEdgeIt = outEdges.find(node->id);
@@ -287,7 +288,7 @@ struct MemRefDependenceGraph {
const auto &nodeOutEdges = outEdgeIt->second;
for (auto *op : node->stores) {
- auto storeOp = cast<AffineStoreOp>(op);
+ auto storeOp = cast<AffineWriteOpInterface>(op);
auto memref = storeOp.getMemRef();
// Skip this store if there are no dependences on its memref. This means
// that store either:
@@ -322,7 +323,8 @@ struct MemRefDependenceGraph {
Node *node = getNode(id);
for (auto *storeOpInst : node->stores) {
// Return false if there exist out edges from 'id' on 'memref'.
- if (getOutEdgeCount(id, cast<AffineStoreOp>(storeOpInst).getMemRef()) > 0)
+ auto storeMemref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
+ if (getOutEdgeCount(id, storeMemref) > 0)
return false;
}
return true;
@@ -651,28 +653,28 @@ bool MemRefDependenceGraph::init(FuncOp f) {
Node node(nextNodeId++, &op);
for (auto *opInst : collector.loadOpInsts) {
node.loads.push_back(opInst);
- auto memref = cast<AffineLoadOp>(opInst).getMemRef();
+ auto memref = cast<AffineReadOpInterface>(opInst).getMemRef();
memrefAccesses[memref].insert(node.id);
}
for (auto *opInst : collector.storeOpInsts) {
node.stores.push_back(opInst);
- auto memref = cast<AffineStoreOp>(opInst).getMemRef();
+ auto memref = cast<AffineWriteOpInterface>(opInst).getMemRef();
memrefAccesses[memref].insert(node.id);
}
forToNodeMap[&op] = node.id;
nodes.insert({node.id, node});
- } else if (auto loadOp = dyn_cast<AffineLoadOp>(op)) {
+ } else if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) {
// Create graph node for top-level load op.
Node node(nextNodeId++, &op);
node.loads.push_back(&op);
- auto memref = cast<AffineLoadOp>(op).getMemRef();
+ auto memref = cast<AffineReadOpInterface>(op).getMemRef();
memrefAccesses[memref].insert(node.id);
nodes.insert({node.id, node});
- } else if (auto storeOp = dyn_cast<AffineStoreOp>(op)) {
+ } else if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) {
// Create graph node for top-level store op.
Node node(nextNodeId++, &op);
node.stores.push_back(&op);
- auto memref = cast<AffineStoreOp>(op).getMemRef();
+ auto memref = cast<AffineWriteOpInterface>(op).getMemRef();
memrefAccesses[memref].insert(node.id);
nodes.insert({node.id, node});
} else if (op.getNumRegions() != 0) {
@@ -733,7 +735,7 @@ static void moveLoadsAccessingMemrefTo(Value memref,
dstLoads->clear();
SmallVector<Operation *, 4> srcLoadsToKeep;
for (auto *load : *srcLoads) {
- if (cast<AffineLoadOp>(load).getMemRef() == memref)
+ if (cast<AffineReadOpInterface>(load).getMemRef() == memref)
dstLoads->push_back(load);
else
srcLoadsToKeep.push_back(load);
@@ -854,7 +856,7 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
// Builder to create constants at the top level.
OpBuilder top(forInst->getParentOfType<FuncOp>().getBody());
// Create new memref type based on slice bounds.
- auto oldMemRef = cast<AffineStoreOp>(srcStoreOpInst).getMemRef();
+ auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOpInst).getMemRef();
auto oldMemRefType = oldMemRef.getType().cast<MemRefType>();
unsigned rank = oldMemRefType.getRank();
@@ -962,9 +964,10 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
// Returns true if 'dstNode's read/write region to 'memref' is a super set of
// 'srcNode's write region to 'memref' and 'srcId' has only one output edge.
// TODO(andydavis) Generalize this to handle more live in/out cases.
-static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId,
- AffineStoreOp srcLiveOutStoreOp,
- MemRefDependenceGraph *mdg) {
+static bool
+canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId,
+ AffineWriteOpInterface srcLiveOutStoreOp,
+ MemRefDependenceGraph *mdg) {
assert(srcLiveOutStoreOp && "Expected a valid store op");
auto *dstNode = mdg->getNode(dstId);
Value memref = srcLiveOutStoreOp.getMemRef();
@@ -1450,7 +1453,7 @@ struct GreedyFusion {
DenseSet<Value> visitedMemrefs;
while (!loads.empty()) {
// Get memref of load on top of the stack.
- auto memref = cast<AffineLoadOp>(loads.back()).getMemRef();
+ auto memref = cast<AffineReadOpInterface>(loads.back()).getMemRef();
if (visitedMemrefs.count(memref) > 0)
continue;
visitedMemrefs.insert(memref);
@@ -1488,7 +1491,7 @@ struct GreedyFusion {
// feasibility for loops with multiple stores.
unsigned maxLoopDepth = 0;
for (auto *op : srcNode->stores) {
- auto storeOp = cast<AffineStoreOp>(op);
+ auto storeOp = cast<AffineWriteOpInterface>(op);
if (storeOp.getMemRef() != memref) {
srcStoreOp = nullptr;
break;
@@ -1563,7 +1566,7 @@ struct GreedyFusion {
// Gather 'dstNode' store ops to 'memref'.
SmallVector<Operation *, 2> dstStoreOpInsts;
for (auto *storeOpInst : dstNode->stores)
- if (cast<AffineStoreOp>(storeOpInst).getMemRef() == memref)
+ if (cast<AffineWriteOpInterface>(storeOpInst).getMemRef() == memref)
dstStoreOpInsts.push_back(storeOpInst);
unsigned bestDstLoopDepth;
@@ -1601,7 +1604,8 @@ struct GreedyFusion {
// Create private memref for 'memref' in 'dstAffineForOp'.
SmallVector<Operation *, 4> storesForMemref;
for (auto *storeOpInst : sliceCollector.storeOpInsts) {
- if (cast<AffineStoreOp>(storeOpInst).getMemRef() == memref)
+ if (cast<AffineWriteOpInterface>(storeOpInst).getMemRef() ==
+ memref)
storesForMemref.push_back(storeOpInst);
}
// TODO(andydavis) Use union of memref write regions to compute
@@ -1624,7 +1628,8 @@ struct GreedyFusion {
// Add new load ops to current Node load op list 'loads' to
// continue fusing based on new operands.
for (auto *loadOpInst : dstLoopCollector.loadOpInsts) {
- auto loadMemRef = cast<AffineLoadOp>(loadOpInst).getMemRef();
+ auto loadMemRef =
+ cast<AffineReadOpInterface>(loadOpInst).getMemRef();
// NOTE: Change 'loads' to a hash set in case efficiency is an
// issue. We still use a vector since it's expected to be small.
if (visitedMemrefs.count(loadMemRef) == 0 &&
@@ -1785,7 +1790,8 @@ struct GreedyFusion {
// Check that all stores are to the same memref.
DenseSet<Value> storeMemrefs;
for (auto *storeOpInst : sibNode->stores) {
- storeMemrefs.insert(cast<AffineStoreOp>(storeOpInst).getMemRef());
+ storeMemrefs.insert(
+ cast<AffineWriteOpInterface>(storeOpInst).getMemRef());
}
if (storeMemrefs.size() != 1)
return false;
@@ -1796,7 +1802,7 @@ struct GreedyFusion {
auto fn = dstNode->op->getParentOfType<FuncOp>();
for (unsigned i = 0, e = fn.getNumArguments(); i != e; ++i) {
for (auto *user : fn.getArgument(i).getUsers()) {
- if (auto loadOp = dyn_cast<AffineLoadOp>(user)) {
+ if (auto loadOp = dyn_cast<AffineReadOpInterface>(user)) {
// Gather loops surrounding 'use'.
SmallVector<AffineForOp, 4> loops;
getLoopIVs(*user, &loops);
diff --git a/mlir/test/lib/Transforms/TestMemRefBoundCheck.cpp b/mlir/test/lib/Transforms/TestMemRefBoundCheck.cpp
index da140b3486f0..087ea4fdde94 100644
--- a/mlir/test/lib/Transforms/TestMemRefBoundCheck.cpp
+++ b/mlir/test/lib/Transforms/TestMemRefBoundCheck.cpp
@@ -37,8 +37,9 @@ struct TestMemRefBoundCheck
void TestMemRefBoundCheck::runOnFunction() {
getFunction().walk([](Operation *opInst) {
- TypeSwitch<Operation *>(opInst).Case<AffineLoadOp, AffineStoreOp>(
- [](auto op) { boundCheckLoadOrStoreOp(op); });
+ TypeSwitch<Operation *>(opInst)
+ .Case<AffineReadOpInterface, AffineWriteOpInterface>(
+ [](auto op) { boundCheckLoadOrStoreOp(op); });
// TODO(bondhugula): do this for DMA ops as well.
});
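The TypeSwitch works because the generic lambda resolves to the matching
boundCheckLoadOrStoreOp instantiation for each interface. An equivalent
hedged sketch with explicit dispatch:

  // Sketch: explicit dispatch equivalent to the TypeSwitch above.
  if (auto read = dyn_cast<AffineReadOpInterface>(opInst))
    (void)boundCheckLoadOrStoreOp(read, /*emitError=*/true);
  else if (auto write = dyn_cast<AffineWriteOpInterface>(opInst))
    (void)boundCheckLoadOrStoreOp(write, /*emitError=*/true);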