[Mlir-commits] [mlir] 99c0458 - separate AffineMapAccessInterface from AffineRead/WriteOpInterface

Jeremy Bruestle llvmlistbot at llvm.org
Tue Feb 16 13:06:44 PST 2021


Author: Adam Straw
Date: 2021-02-16T13:05:27-08:00
New Revision: 99c0458f2f53ec26d663e5b9ad750e54ecf51d4b

URL: https://github.com/llvm/llvm-project/commit/99c0458f2f53ec26d663e5b9ad750e54ecf51d4b
DIFF: https://github.com/llvm/llvm-project/commit/99c0458f2f53ec26d663e5b9ad750e54ecf51d4b.diff

LOG: separate AffineMapAccessInterface from AffineRead/WriteOpInterface

Separate the AffineMapAccessInterface from the AffineRead/WriteOpInterface so that dialects which extend Affine capabilities (e.g. PlaidML PXA, the parallel extensions for Affine) can use the relevant passes (e.g. MemRef normalization).

Reviewed By: bondhugula

Differential Revision: https://reviews.llvm.org/D96284

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td
    mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
    mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
    mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp
    mlir/lib/Transforms/LoopFusion.cpp
    mlir/lib/Transforms/Utils/Utils.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td b/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td
index 1f25073f07e3..29a379ca8961 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td
@@ -68,19 +68,6 @@ def AffineReadOpInterface : OpInterface<"AffineReadOpInterface"> {
         return op.getAffineMapAttr().getValue();
       }]
     >,
-    InterfaceMethod<
-      /*desc=*/"Returns the AffineMapAttr associated with 'memref'.",
-      /*retTy=*/"NamedAttribute",
-      /*methodName=*/"getAffineMapAttrForMemRef",
-      /*args=*/(ins "Value":$memref),
-      /*methodBody=*/[{}],
-      /*defaultImplementation=*/[{
-        ConcreteOp op = cast<ConcreteOp>(this->getOperation());
-        assert(memref == getMemRef());
-        return {Identifier::get(op.getMapAttrName(), op.getContext()),
-                op.getAffineMapAttr()};
-      }]
-    >,
     InterfaceMethod<
       /*desc=*/"Returns the value read by this operation.",
       /*retTy=*/"Value",
@@ -148,27 +135,40 @@ def AffineWriteOpInterface : OpInterface<"AffineWriteOpInterface"> {
       }]
     >,
     InterfaceMethod<
-      /*desc=*/"Returns the AffineMapAttr associated with 'memref'.",
-      /*retTy=*/"NamedAttribute",
-      /*methodName=*/"getAffineMapAttrForMemRef",
-      /*args=*/(ins "Value":$memref),
+      /*desc=*/"Returns the value to store.",
+      /*retTy=*/"Value",
+      /*methodName=*/"getValueToStore",
+      /*args=*/(ins),
       /*methodBody=*/[{}],
       /*defaultImplementation=*/[{
         ConcreteOp op = cast<ConcreteOp>(this->getOperation());
-        assert(memref == getMemRef());
-        return {Identifier::get(op.getMapAttrName(), op.getContext()),
-                op.getAffineMapAttr()};
+        return op.getOperand(op.getStoredValOperandIndex());
       }]
     >,
+  ];
+}
+
+def AffineMapAccessInterface : OpInterface<"AffineMapAccessInterface"> {
+  let description = [{
+      Interface to query the AffineMap used to dereference and access a given
+      memref. Implementers of this interface must operate on at least one
+      memref operand.  The memref argument given to this interface must match
+      one of those memref operands.
+  }];
+
+  let methods = [
     InterfaceMethod<
-      /*desc=*/"Returns the value to store.",
-      /*retTy=*/"Value",
-      /*methodName=*/"getValueToStore",
-      /*args=*/(ins),
+      /*desc=*/"Returns the AffineMapAttr associated with 'memref'.",
+      /*retTy=*/"NamedAttribute",
+      /*methodName=*/"getAffineMapAttrForMemRef",
+      /*args=*/(ins "Value":$memref),
       /*methodBody=*/[{}],
       /*defaultImplementation=*/[{
         ConcreteOp op = cast<ConcreteOp>(this->getOperation());
-        return op.getOperand(op.getStoredValOperandIndex());
+        assert(memref == op.getMemRef() &&
+               "Expected memref argument to match memref operand");
+        return {Identifier::get(op.getMapAttrName(), op.getContext()),
+                op.getAffineMapAttr()};
       }]
     >,
   ];

diff  --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
index 9b30c9b160b7..29fc305b682a 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
@@ -82,7 +82,8 @@ bool isTopLevelValue(Value value);
 // TODO: Consider replacing src/dst memref indices with view memrefs.
 class AffineDmaStartOp
     : public Op<AffineDmaStartOp, OpTrait::MemRefsNormalizable,
-                OpTrait::VariadicOperands, OpTrait::ZeroResult> {
+                OpTrait::VariadicOperands, OpTrait::ZeroResult,
+                AffineMapAccessInterface::Trait> {
 public:
   using Op::Op;
 
@@ -191,6 +192,7 @@ class AffineDmaStartOp
                       getTagMap().getNumInputs());
   }
 
+  /// Implements the AffineMapAccessInterface.
   /// Returns the AffineMapAttr associated with 'memref'.
   NamedAttribute getAffineMapAttrForMemRef(Value memref) {
     if (memref == getSrcMemRef())
@@ -271,7 +273,8 @@ class AffineDmaStartOp
 //
 class AffineDmaWaitOp
     : public Op<AffineDmaWaitOp, OpTrait::MemRefsNormalizable,
-                OpTrait::VariadicOperands, OpTrait::ZeroResult> {
+                OpTrait::VariadicOperands, OpTrait::ZeroResult,
+                AffineMapAccessInterface::Trait> {
 public:
   using Op::Op;
 
@@ -303,6 +306,7 @@ class AffineDmaWaitOp
     return getTagMemRef().getType().cast<MemRefType>().getRank();
   }
 
+  /// Implements the AffineMapAccessInterface.
   /// Returns the AffineMapAttr associated with 'memref'.
   NamedAttribute getAffineMapAttrForMemRef(Value memref) {
     assert(memref == getTagMemRef());

diff  --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index cc63603f6c90..95be63d21991 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -458,6 +458,7 @@ def AffineIfOp : Affine_Op<"if",
 class AffineLoadOpBase<string mnemonic, list<OpTrait> traits = []> :
     Affine_Op<mnemonic, !listconcat(traits,
         [DeclareOpInterfaceMethods<AffineReadOpInterface>,
+        DeclareOpInterfaceMethods<AffineMapAccessInterface>,
         MemRefsNormalizable])> {
   let arguments = (ins Arg<AnyMemRef, "the reference to load from",
       [MemRead]>:$memref,
@@ -698,7 +699,8 @@ def AffineParallelOp : Affine_Op<"parallel",
   let hasFolder = 1;
 }
 
-def AffinePrefetchOp : Affine_Op<"prefetch"> {
+def AffinePrefetchOp : Affine_Op<"prefetch",
+  [DeclareOpInterfaceMethods<AffineMapAccessInterface>]> {
   let summary = "affine prefetch operation";
   let description = [{
     The "affine.prefetch" op prefetches data from a memref location described
@@ -752,9 +754,11 @@ def AffinePrefetchOp : Affine_Op<"prefetch"> {
       return (*this)->getAttr(getMapAttrName()).cast<AffineMapAttr>();
     }
 
+    /// Implements the AffineMapAccessInterface.
     /// Returns the AffineMapAttr associated with 'memref'.
     NamedAttribute getAffineMapAttrForMemRef(Value mref) {
-      assert(mref == memref());
+      assert(mref == memref() &&
+             "Expected mref argument to match memref operand");
       return {Identifier::get(getMapAttrName(), getContext()),
         getAffineMapAttr()};
     }
@@ -777,6 +781,7 @@ def AffinePrefetchOp : Affine_Op<"prefetch"> {
 class AffineStoreOpBase<string mnemonic, list<OpTrait> traits = []> :
     Affine_Op<mnemonic, !listconcat(traits,
     [DeclareOpInterfaceMethods<AffineWriteOpInterface>,
+    DeclareOpInterfaceMethods<AffineMapAccessInterface>,
     MemRefsNormalizable])> {
   code extraClassDeclarationBase = [{
     /// Returns the operand index of the value to be stored.
@@ -792,13 +797,6 @@ class AffineStoreOpBase<string mnemonic, list<OpTrait> traits = []> :
       return (*this)->getAttr(getMapAttrName()).cast<AffineMapAttr>();
     }
 
-    /// Returns the AffineMapAttr associated with 'memref'.
-    NamedAttribute getAffineMapAttrForMemRef(Value memref) {
-      assert(memref == getMemRef());
-      return {Identifier::get(getMapAttrName(), getContext()),
-              getAffineMapAttr()};
-    }
-
     static StringRef getMapAttrName() { return "map"; }
   }];
 }

diff  --git a/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp
index 199747342987..f52a7a2c5bf8 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp
@@ -61,11 +61,6 @@ areAllOpsInTheBlockListInvariant(Region &blockList, Value indVar,
                                  SmallPtrSetImpl<Operation *> &definedOps,
                                  SmallPtrSetImpl<Operation *> &opsToHoist);
 
-static bool isMemRefDereferencingOp(Operation &op) {
-  // TODO: Support DMA Ops.
-  return isa<AffineReadOpInterface, AffineWriteOpInterface>(op);
-}
-
 // Returns true if the individual op is loop invariant.
 bool isOpLoopInvariant(Operation &op, Value indVar,
                        SmallPtrSetImpl<Operation *> &definedOps,
@@ -89,7 +84,7 @@ bool isOpLoopInvariant(Operation &op, Value indVar,
     // which are themselves not being hoisted.
     definedOps.insert(&op);
 
-    if (isMemRefDereferencingOp(op)) {
+    if (isa<AffineMapAccessInterface>(op)) {
       Value memref = isa<AffineReadOpInterface>(op)
                          ? cast<AffineReadOpInterface>(op).getMemRef()
                          : cast<AffineWriteOpInterface>(op).getMemRef();

diff  --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index c3ffb17b1be7..e13616e1d71c 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -68,12 +68,6 @@ mlir::createLoopFusionPass(unsigned fastMemorySpace,
                                       maximalFusion);
 }
 
-// TODO: Replace when this is modeled through side-effects/op traits
-static bool isMemRefDereferencingOp(Operation &op) {
-  return isa<AffineReadOpInterface, AffineWriteOpInterface, AffineDmaStartOp,
-             AffineDmaWaitOp>(op);
-}
-
 namespace {
 
 // LoopNestStateCollector walks loop nests and collects load and store
@@ -264,7 +258,7 @@ struct MemRefDependenceGraph {
         return true;
       // Return true if any use of 'memref' escapes the function.
       for (auto *user : memref.getUsers())
-        if (!isMemRefDereferencingOp(*user))
+        if (!isa<AffineMapAccessInterface>(*user))
           return true;
     }
     return false;
@@ -703,7 +697,7 @@ void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg,
     // Check if 'memref' escapes through a non-affine op (e.g., std load/store,
     // call op, etc.).
     for (Operation *user : memref.getUsers())
-      if (!isMemRefDereferencingOp(*user))
+      if (!isa<AffineMapAccessInterface>(*user))
         escapingMemRefs.insert(memref);
   }
 }
@@ -979,7 +973,7 @@ static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
       // Interrupt the walk if found.
       auto walkResult = op->walk([&](Operation *user) {
         // Skip affine ops.
-        if (isMemRefDereferencingOp(*user))
+        if (isa<AffineMapAccessInterface>(*user))
           return WalkResult::advance();
         // Find a non-affine op that uses the memref.
         if (llvm::is_contained(users, user))

diff  --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp
index f88b7b12d5f2..f99159b25054 100644
--- a/mlir/lib/Transforms/Utils/Utils.cpp
+++ b/mlir/lib/Transforms/Utils/Utils.cpp
@@ -25,22 +25,6 @@
 #include "llvm/ADT/TypeSwitch.h"
 using namespace mlir;
 
-/// Return true if this operation dereferences one or more memref's.
-// Temporary utility: will be replaced when this is modeled through
-// side-effects/op traits. TODO
-static bool isMemRefDereferencingOp(Operation &op) {
-  return isa<AffineReadOpInterface, AffineWriteOpInterface, AffineDmaStartOp,
-             AffineDmaWaitOp>(op);
-}
-
-/// Return the AffineMapAttr associated with memory 'op' on 'memref'.
-static NamedAttribute getAffineMapAttrForMemRef(Operation *op, Value memref) {
-  return TypeSwitch<Operation *, NamedAttribute>(op)
-      .Case<AffineDmaStartOp, AffineReadOpInterface, AffinePrefetchOp,
-            AffineWriteOpInterface, AffineDmaWaitOp>(
-          [=](auto op) { return op.getAffineMapAttrForMemRef(memref); });
-}
-
 // Perform the replacement in `op`.
 LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
                                              Operation *op,
@@ -88,17 +72,20 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
   OpBuilder builder(op);
   // The following checks if op is dereferencing memref and performs the access
   // index rewrites.
-  if (!isMemRefDereferencingOp(*op)) {
-    if (!allowNonDereferencingOps)
+  auto affMapAccInterface = dyn_cast<AffineMapAccessInterface>(op);
+  if (!affMapAccInterface) {
+    if (!allowNonDereferencingOps) {
       // Failure: memref used in a non-dereferencing context (potentially
       // escapes); no replacement in these cases unless allowNonDereferencingOps
       // is set.
       return failure();
+    }
     op->setOperand(memRefOperandPos, newMemRef);
     return success();
   }
   // Perform index rewrites for the dereferencing op and then replace the op
-  NamedAttribute oldMapAttrPair = getAffineMapAttrForMemRef(op, oldMemRef);
+  NamedAttribute oldMapAttrPair =
+      affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef);
   AffineMap oldMap = oldMapAttrPair.second.cast<AffineMapAttr>().getValue();
   unsigned oldMapNumInputs = oldMap.getNumInputs();
   SmallVector<Value, 4> oldMapOperands(
@@ -272,7 +259,7 @@ LogicalResult mlir::replaceAllMemRefUsesWith(
     // Check if the memref was used in a non-dereferencing context. It is fine
     // for the memref to be used in a non-dereferencing way outside of the
     // region where this replacement is happening.
-    if (!isMemRefDereferencingOp(*op)) {
+    if (!isa<AffineMapAccessInterface>(*op)) {
       if (!allowNonDereferencingOps)
         return failure();
       // Currently we support the following non-dereferencing ops to be a


        


More information about the Mlir-commits mailing list