[Mlir-commits] [mlir] 99c0458 - separate AffineMapAccessInterface from AffineRead/WriteOpInterface
Jeremy Bruestle
llvmlistbot at llvm.org
Tue Feb 16 13:06:44 PST 2021
Author: Adam Straw
Date: 2021-02-16T13:05:27-08:00
New Revision: 99c0458f2f53ec26d663e5b9ad750e54ecf51d4b
URL: https://github.com/llvm/llvm-project/commit/99c0458f2f53ec26d663e5b9ad750e54ecf51d4b
DIFF: https://github.com/llvm/llvm-project/commit/99c0458f2f53ec26d663e5b9ad750e54ecf51d4b.diff
LOG: separate AffineMapAccessInterface from AffineRead/WriteOpInterface
Separating the AffineMapAccessInterface from AffineRead/WriteOp interface so that dialects which extend Affine capabilities (e.g. PlaidML PXA = parallel extensions for Affine) can utilize relevant passes (e.g. MemRef normalization).
Reviewed By: bondhugula
Differential Revision: https://reviews.llvm.org/D96284
Added:
Modified:
mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td
mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp
mlir/lib/Transforms/LoopFusion.cpp
mlir/lib/Transforms/Utils/Utils.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td b/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td
index 1f25073f07e3..29a379ca8961 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td
@@ -68,19 +68,6 @@ def AffineReadOpInterface : OpInterface<"AffineReadOpInterface"> {
return op.getAffineMapAttr().getValue();
}]
>,
- InterfaceMethod<
- /*desc=*/"Returns the AffineMapAttr associated with 'memref'.",
- /*retTy=*/"NamedAttribute",
- /*methodName=*/"getAffineMapAttrForMemRef",
- /*args=*/(ins "Value":$memref),
- /*methodBody=*/[{}],
- /*defaultImplementation=*/[{
- ConcreteOp op = cast<ConcreteOp>(this->getOperation());
- assert(memref == getMemRef());
- return {Identifier::get(op.getMapAttrName(), op.getContext()),
- op.getAffineMapAttr()};
- }]
- >,
InterfaceMethod<
/*desc=*/"Returns the value read by this operation.",
/*retTy=*/"Value",
@@ -148,27 +135,40 @@ def AffineWriteOpInterface : OpInterface<"AffineWriteOpInterface"> {
}]
>,
InterfaceMethod<
- /*desc=*/"Returns the AffineMapAttr associated with 'memref'.",
- /*retTy=*/"NamedAttribute",
- /*methodName=*/"getAffineMapAttrForMemRef",
- /*args=*/(ins "Value":$memref),
+ /*desc=*/"Returns the value to store.",
+ /*retTy=*/"Value",
+ /*methodName=*/"getValueToStore",
+ /*args=*/(ins),
/*methodBody=*/[{}],
/*defaultImplementation=*/[{
ConcreteOp op = cast<ConcreteOp>(this->getOperation());
- assert(memref == getMemRef());
- return {Identifier::get(op.getMapAttrName(), op.getContext()),
- op.getAffineMapAttr()};
+ return op.getOperand(op.getStoredValOperandIndex());
}]
>,
+ ];
+}
+
+def AffineMapAccessInterface : OpInterface<"AffineMapAccessInterface"> {
+ let description = [{
+ Interface to query the AffineMap used to dereference and access a given
+ memref. Implementers of this interface must operate on at least one
+ memref operand. The memref argument given to this interface must match
+ one of those memref operands.
+ }];
+
+ let methods = [
InterfaceMethod<
- /*desc=*/"Returns the value to store.",
- /*retTy=*/"Value",
- /*methodName=*/"getValueToStore",
- /*args=*/(ins),
+ /*desc=*/"Returns the AffineMapAttr associated with 'memref'.",
+ /*retTy=*/"NamedAttribute",
+ /*methodName=*/"getAffineMapAttrForMemRef",
+ /*args=*/(ins "Value":$memref),
/*methodBody=*/[{}],
/*defaultImplementation=*/[{
ConcreteOp op = cast<ConcreteOp>(this->getOperation());
- return op.getOperand(op.getStoredValOperandIndex());
+ assert(memref == op.getMemRef() &&
+ "Expected memref argument to match memref operand");
+ return {Identifier::get(op.getMapAttrName(), op.getContext()),
+ op.getAffineMapAttr()};
}]
>,
];
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
index 9b30c9b160b7..29fc305b682a 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
@@ -82,7 +82,8 @@ bool isTopLevelValue(Value value);
// TODO: Consider replacing src/dst memref indices with view memrefs.
class AffineDmaStartOp
: public Op<AffineDmaStartOp, OpTrait::MemRefsNormalizable,
- OpTrait::VariadicOperands, OpTrait::ZeroResult> {
+ OpTrait::VariadicOperands, OpTrait::ZeroResult,
+ AffineMapAccessInterface::Trait> {
public:
using Op::Op;
@@ -191,6 +192,7 @@ class AffineDmaStartOp
getTagMap().getNumInputs());
}
+ /// Implements the AffineMapAccessInterface.
/// Returns the AffineMapAttr associated with 'memref'.
NamedAttribute getAffineMapAttrForMemRef(Value memref) {
if (memref == getSrcMemRef())
@@ -271,7 +273,8 @@ class AffineDmaStartOp
//
class AffineDmaWaitOp
: public Op<AffineDmaWaitOp, OpTrait::MemRefsNormalizable,
- OpTrait::VariadicOperands, OpTrait::ZeroResult> {
+ OpTrait::VariadicOperands, OpTrait::ZeroResult,
+ AffineMapAccessInterface::Trait> {
public:
using Op::Op;
@@ -303,6 +306,7 @@ class AffineDmaWaitOp
return getTagMemRef().getType().cast<MemRefType>().getRank();
}
+ /// Implements the AffineMapAccessInterface.
/// Returns the AffineMapAttr associated with 'memref'.
NamedAttribute getAffineMapAttrForMemRef(Value memref) {
assert(memref == getTagMemRef());
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index cc63603f6c90..95be63d21991 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -458,6 +458,7 @@ def AffineIfOp : Affine_Op<"if",
class AffineLoadOpBase<string mnemonic, list<OpTrait> traits = []> :
Affine_Op<mnemonic, !listconcat(traits,
[DeclareOpInterfaceMethods<AffineReadOpInterface>,
+ DeclareOpInterfaceMethods<AffineMapAccessInterface>,
MemRefsNormalizable])> {
let arguments = (ins Arg<AnyMemRef, "the reference to load from",
[MemRead]>:$memref,
@@ -698,7 +699,8 @@ def AffineParallelOp : Affine_Op<"parallel",
let hasFolder = 1;
}
-def AffinePrefetchOp : Affine_Op<"prefetch"> {
+def AffinePrefetchOp : Affine_Op<"prefetch",
+ [DeclareOpInterfaceMethods<AffineMapAccessInterface>]> {
let summary = "affine prefetch operation";
let description = [{
The "affine.prefetch" op prefetches data from a memref location described
@@ -752,9 +754,11 @@ def AffinePrefetchOp : Affine_Op<"prefetch"> {
return (*this)->getAttr(getMapAttrName()).cast<AffineMapAttr>();
}
+ /// Implements the AffineMapAccessInterface.
/// Returns the AffineMapAttr associated with 'memref'.
NamedAttribute getAffineMapAttrForMemRef(Value mref) {
- assert(mref == memref());
+ assert(mref == memref() &&
+ "Expected mref argument to match memref operand");
return {Identifier::get(getMapAttrName(), getContext()),
getAffineMapAttr()};
}
@@ -777,6 +781,7 @@ def AffinePrefetchOp : Affine_Op<"prefetch"> {
class AffineStoreOpBase<string mnemonic, list<OpTrait> traits = []> :
Affine_Op<mnemonic, !listconcat(traits,
[DeclareOpInterfaceMethods<AffineWriteOpInterface>,
+ DeclareOpInterfaceMethods<AffineMapAccessInterface>,
MemRefsNormalizable])> {
code extraClassDeclarationBase = [{
/// Returns the operand index of the value to be stored.
@@ -792,13 +797,6 @@ class AffineStoreOpBase<string mnemonic, list<OpTrait> traits = []> :
return (*this)->getAttr(getMapAttrName()).cast<AffineMapAttr>();
}
- /// Returns the AffineMapAttr associated with 'memref'.
- NamedAttribute getAffineMapAttrForMemRef(Value memref) {
- assert(memref == getMemRef());
- return {Identifier::get(getMapAttrName(), getContext()),
- getAffineMapAttr()};
- }
-
static StringRef getMapAttrName() { return "map"; }
}];
}
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp
index 199747342987..f52a7a2c5bf8 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp
@@ -61,11 +61,6 @@ areAllOpsInTheBlockListInvariant(Region &blockList, Value indVar,
SmallPtrSetImpl<Operation *> &definedOps,
SmallPtrSetImpl<Operation *> &opsToHoist);
-static bool isMemRefDereferencingOp(Operation &op) {
- // TODO: Support DMA Ops.
- return isa<AffineReadOpInterface, AffineWriteOpInterface>(op);
-}
-
// Returns true if the individual op is loop invariant.
bool isOpLoopInvariant(Operation &op, Value indVar,
SmallPtrSetImpl<Operation *> &definedOps,
@@ -89,7 +84,7 @@ bool isOpLoopInvariant(Operation &op, Value indVar,
// which are themselves not being hoisted.
definedOps.insert(&op);
- if (isMemRefDereferencingOp(op)) {
+ if (isa<AffineMapAccessInterface>(op)) {
Value memref = isa<AffineReadOpInterface>(op)
? cast<AffineReadOpInterface>(op).getMemRef()
: cast<AffineWriteOpInterface>(op).getMemRef();
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index c3ffb17b1be7..e13616e1d71c 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -68,12 +68,6 @@ mlir::createLoopFusionPass(unsigned fastMemorySpace,
maximalFusion);
}
-// TODO: Replace when this is modeled through side-effects/op traits
-static bool isMemRefDereferencingOp(Operation &op) {
- return isa<AffineReadOpInterface, AffineWriteOpInterface, AffineDmaStartOp,
- AffineDmaWaitOp>(op);
-}
-
namespace {
// LoopNestStateCollector walks loop nests and collects load and store
@@ -264,7 +258,7 @@ struct MemRefDependenceGraph {
return true;
// Return true if any use of 'memref' escapes the function.
for (auto *user : memref.getUsers())
- if (!isMemRefDereferencingOp(*user))
+ if (!isa<AffineMapAccessInterface>(*user))
return true;
}
return false;
@@ -703,7 +697,7 @@ void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg,
// Check if 'memref' escapes through a non-affine op (e.g., std load/store,
// call op, etc.).
for (Operation *user : memref.getUsers())
- if (!isMemRefDereferencingOp(*user))
+ if (!isa<AffineMapAccessInterface>(*user))
escapingMemRefs.insert(memref);
}
}
@@ -979,7 +973,7 @@ static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
// Interrupt the walk if found.
auto walkResult = op->walk([&](Operation *user) {
// Skip affine ops.
- if (isMemRefDereferencingOp(*user))
+ if (isa<AffineMapAccessInterface>(*user))
return WalkResult::advance();
// Find a non-affine op that uses the memref.
if (llvm::is_contained(users, user))
diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp
index f88b7b12d5f2..f99159b25054 100644
--- a/mlir/lib/Transforms/Utils/Utils.cpp
+++ b/mlir/lib/Transforms/Utils/Utils.cpp
@@ -25,22 +25,6 @@
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
-/// Return true if this operation dereferences one or more memref's.
-// Temporary utility: will be replaced when this is modeled through
-// side-effects/op traits. TODO
-static bool isMemRefDereferencingOp(Operation &op) {
- return isa<AffineReadOpInterface, AffineWriteOpInterface, AffineDmaStartOp,
- AffineDmaWaitOp>(op);
-}
-
-/// Return the AffineMapAttr associated with memory 'op' on 'memref'.
-static NamedAttribute getAffineMapAttrForMemRef(Operation *op, Value memref) {
- return TypeSwitch<Operation *, NamedAttribute>(op)
- .Case<AffineDmaStartOp, AffineReadOpInterface, AffinePrefetchOp,
- AffineWriteOpInterface, AffineDmaWaitOp>(
- [=](auto op) { return op.getAffineMapAttrForMemRef(memref); });
-}
-
// Perform the replacement in `op`.
LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
Operation *op,
@@ -88,17 +72,20 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
OpBuilder builder(op);
// The following checks if op is dereferencing memref and performs the access
// index rewrites.
- if (!isMemRefDereferencingOp(*op)) {
- if (!allowNonDereferencingOps)
+ auto affMapAccInterface = dyn_cast<AffineMapAccessInterface>(op);
+ if (!affMapAccInterface) {
+ if (!allowNonDereferencingOps) {
// Failure: memref used in a non-dereferencing context (potentially
// escapes); no replacement in these cases unless allowNonDereferencingOps
// is set.
return failure();
+ }
op->setOperand(memRefOperandPos, newMemRef);
return success();
}
// Perform index rewrites for the dereferencing op and then replace the op
- NamedAttribute oldMapAttrPair = getAffineMapAttrForMemRef(op, oldMemRef);
+ NamedAttribute oldMapAttrPair =
+ affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef);
AffineMap oldMap = oldMapAttrPair.second.cast<AffineMapAttr>().getValue();
unsigned oldMapNumInputs = oldMap.getNumInputs();
SmallVector<Value, 4> oldMapOperands(
@@ -272,7 +259,7 @@ LogicalResult mlir::replaceAllMemRefUsesWith(
// Check if the memref was used in a non-dereferencing context. It is fine
// for the memref to be used in a non-dereferencing way outside of the
// region where this replacement is happening.
- if (!isMemRefDereferencingOp(*op)) {
+ if (!isa<AffineMapAccessInterface>(*op)) {
if (!allowNonDereferencingOps)
return failure();
// Currently we support the following non-dereferencing ops to be a
More information about the Mlir-commits
mailing list