[Mlir-commits] [mlir] a45fb19 - [mlir][Affine] Introduce affine memory interfaces

Diego Caballero llvmlistbot at llvm.org
Tue May 19 17:44:40 PDT 2020


Author: Diego Caballero
Date: 2020-05-19T17:32:50-07:00
New Revision: a45fb1942fc5d21dfcdc37b99ab98778d3b16b79

URL: https://github.com/llvm/llvm-project/commit/a45fb1942fc5d21dfcdc37b99ab98778d3b16b79
DIFF: https://github.com/llvm/llvm-project/commit/a45fb1942fc5d21dfcdc37b99ab98778d3b16b79.diff

LOG: [mlir][Affine] Introduce affine memory interfaces

This patch introduces interfaces for read and write ops with affine
restrictions. I used `read`/`write` instead of `load`/`store` for the
interfaces so that they can also be implemented by dma ops.
For now, they are only implemented by affine.load, affine.store,
affine.vector_load and affine.vector_store.

For testing purposes, this patch also migrates affine loop fusion and
required analysis to use the new interfaces. No other changes are made
beyond that.

Co-authored-by: Alex Zinenko <zinenko at google.com>

Reviewed By: bondhugula, ftynse

Differential Revision: https://reviews.llvm.org/D79829

Added: 
    mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h
    mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td
    mlir/lib/Dialect/Affine/IR/AffineMemoryOpInterfaces.cpp

Modified: 
    mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
    mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
    mlir/include/mlir/Dialect/Affine/IR/CMakeLists.txt
    mlir/lib/Analysis/AffineAnalysis.cpp
    mlir/lib/Analysis/Utils.cpp
    mlir/lib/Dialect/Affine/IR/CMakeLists.txt
    mlir/lib/Transforms/LoopFusion.cpp
    mlir/test/lib/Transforms/TestMemRefBoundCheck.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h b/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h
new file mode 100644
index 000000000000..f42fc256befa
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h
@@ -0,0 +1,24 @@
+//===- AffineMemoryOpInterfaces.h -------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Declares AffineReadOpInterface and AffineWriteOpInterface (TableGen'd).
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_AFFINE_IR_AFFINEMEMORYOPINTERFACES_H
+#define MLIR_DIALECT_AFFINE_IR_AFFINEMEMORYOPINTERFACES_H
+
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/StandardTypes.h"
+
+namespace mlir {
+#include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h.inc"
+} // namespace mlir
+
+#endif // MLIR_DIALECT_AFFINE_IR_AFFINEMEMORYOPINTERFACES_H

diff  --git a/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td b/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td
new file mode 100644
index 000000000000..8738000d8d5f
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td
@@ -0,0 +1,128 @@
+//===- AffineMemoryOpInterfaces.td -------------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains a set of interfaces for affine memory ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_AFFINEMEMORYOPINTERFACES
+#define MLIR_AFFINEMEMORYOPINTERFACES
+
+include "mlir/IR/OpBase.td"
+
+def AffineReadOpInterface : OpInterface<"AffineReadOpInterface"> {
+  let description = [{
+      Interface to query characteristics of read-like ops with affine
+      restrictions.
+  }];
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{ Returns the memref operand to read from. }],
+      /*retTy=*/"Value",
+      /*methodName=*/"getMemRef",
+      /*args=*/(ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/ [{
+        ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+        return op.getOperand(op.getMemRefOperandIndex());
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{ Returns the type of the memref operand. }],
+      /*retTy=*/"MemRefType",
+      /*methodName=*/"getMemRefType",
+      /*args=*/(ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+        ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+        return op.getMemRef().getType().template cast<MemRefType>();
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{ Returns affine map operands. }],
+      /*retTy=*/"Operation::operand_range",
+      /*methodName=*/"getMapOperands",
+      /*args=*/(ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+        ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+        return llvm::drop_begin(op.getOperands(), 1);
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{ Returns the affine map used to index the memref for this
+                  operation. }],
+      /*retTy=*/"AffineMap",
+      /*methodName=*/"getAffineMap",
+      /*args=*/(ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+        ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+        return op.getAffineMapAttr().getValue();
+      }]
+    >,
+  ];
+}
+
+def AffineWriteOpInterface : OpInterface<"AffineWriteOpInterface"> {
+  let description = [{
+      Interface to query characteristics of write-like ops with affine
+      restrictions.
+  }];
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{ Returns the memref operand to write to. }],
+      /*retTy=*/"Value",
+      /*methodName=*/"getMemRef",
+      /*args=*/(ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+        ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+        return op.getOperand(op.getMemRefOperandIndex());
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{ Returns the type of the memref operand. }],
+      /*retTy=*/"MemRefType",
+      /*methodName=*/"getMemRefType",
+      /*args=*/(ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+        ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+        return op.getMemRef().getType().template cast<MemRefType>();
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{ Returns affine map operands. }],
+      /*retTy=*/"Operation::operand_range",
+      /*methodName=*/"getMapOperands",
+      /*args=*/(ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+        ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+        return llvm::drop_begin(op.getOperands(), 2);
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{ Returns the affine map used to index the memref for this
+                  operation. }],
+      /*retTy=*/"AffineMap",
+      /*methodName=*/"getAffineMap",
+      /*args=*/(ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+        ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+        return op.getAffineMapAttr().getValue();
+      }]
+    >,
+  ];
+}
+
+#endif // MLIR_AFFINEMEMORYOPINTERFACES

diff  --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
index 4910fc92ea09..808fe892b4c1 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
@@ -14,6 +14,7 @@
 #ifndef MLIR_DIALECT_AFFINE_IR_AFFINEOPS_H
 #define MLIR_DIALECT_AFFINE_IR_AFFINEOPS_H
 
+#include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Dialect.h"

diff  --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 8286d8f315bd..3366ef47a8be 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -14,6 +14,7 @@
 #define AFFINE_OPS
 
 include "mlir/Dialect/Affine/IR/AffineOpsBase.td"
+include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td"
 include "mlir/Interfaces/LoopLikeInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
@@ -371,7 +372,8 @@ def AffineIfOp : Affine_Op<"if",
 }
 
 class AffineLoadOpBase<string mnemonic, list<OpTrait> traits = []> :
-    Affine_Op<mnemonic, traits> {
+    Affine_Op<mnemonic, !listconcat(traits,
+        [DeclareOpInterfaceMethods<AffineReadOpInterface>])> {
   let arguments = (ins Arg<AnyMemRef, "the reference to load from",
       [MemRead]>:$memref,
       Variadic<Index>:$indices);
@@ -380,18 +382,9 @@ class AffineLoadOpBase<string mnemonic, list<OpTrait> traits = []> :
     /// Returns the operand index of the memref.
     unsigned getMemRefOperandIndex() { return 0; }
 
-    /// Get memref operand.
-    Value getMemRef() { return getOperand(getMemRefOperandIndex()); }
     void setMemRef(Value value) { setOperand(getMemRefOperandIndex(), value); }
-    MemRefType getMemRefType() {
-      return getMemRef().getType().cast<MemRefType>();
-    }
-
-    /// Get affine map operands.
-    operand_range getMapOperands() { return llvm::drop_begin(getOperands(), 1); }
 
     /// Returns the affine map used to index the memref for this operation.
-    AffineMap getAffineMap() { return getAffineMapAttr().getValue(); }
     AffineMapAttr getAffineMapAttr() {
       return getAttr(getMapAttrName()).cast<AffineMapAttr>();
     }
@@ -407,7 +400,7 @@ class AffineLoadOpBase<string mnemonic, list<OpTrait> traits = []> :
   }];
 }
 
-def AffineLoadOp : AffineLoadOpBase<"load", []> {
+def AffineLoadOp : AffineLoadOpBase<"load"> {
   let summary = "affine load operation";
   let description = [{
     The "affine.load" op reads an element from a memref, where the index
@@ -666,8 +659,8 @@ def AffinePrefetchOp : Affine_Op<"prefetch"> {
 }
 
 class AffineStoreOpBase<string mnemonic, list<OpTrait> traits = []> :
-    Affine_Op<mnemonic, traits> {
-
+    Affine_Op<mnemonic, !listconcat(traits,
+        [DeclareOpInterfaceMethods<AffineWriteOpInterface>])> {
   code extraClassDeclarationBase = [{
     /// Get value to be stored by store operation.
     Value getValueToStore() { return getOperand(0); }
@@ -675,19 +668,9 @@ class AffineStoreOpBase<string mnemonic, list<OpTrait> traits = []> :
     /// Returns the operand index of the memref.
     unsigned getMemRefOperandIndex() { return 1; }
 
-    /// Get memref operand.
-    Value getMemRef() { return getOperand(getMemRefOperandIndex()); }
     void setMemRef(Value value) { setOperand(getMemRefOperandIndex(), value); }
 
-    MemRefType getMemRefType() {
-      return getMemRef().getType().cast<MemRefType>();
-    }
-
-    /// Get affine map operands.
-    operand_range getMapOperands() { return llvm::drop_begin(getOperands(), 2); }
-
     /// Returns the affine map used to index the memref for this operation.
-    AffineMap getAffineMap() { return getAffineMapAttr().getValue(); }
     AffineMapAttr getAffineMapAttr() {
       return getAttr(getMapAttrName()).cast<AffineMapAttr>();
     }
@@ -703,7 +686,7 @@ class AffineStoreOpBase<string mnemonic, list<OpTrait> traits = []> :
   }];
 }
 
-def AffineStoreOp : AffineStoreOpBase<"store", []> {
+def AffineStoreOp : AffineStoreOpBase<"store"> {
   let summary = "affine store operation";
   let description = [{
     The "affine.store" op writes an element to a memref, where the index
@@ -776,7 +759,7 @@ def AffineTerminatorOp :
   let verifier = ?;
 }
 
-def AffineVectorLoadOp : AffineLoadOpBase<"vector_load", []> {
+def AffineVectorLoadOp : AffineLoadOpBase<"vector_load"> {
   let summary = "affine vector load operation";
   let description = [{
     The "affine.vector_load" is the vector counterpart of
@@ -825,7 +808,7 @@ def AffineVectorLoadOp : AffineLoadOpBase<"vector_load", []> {
   }];
 }
 
-def AffineVectorStoreOp : AffineStoreOpBase<"vector_store", []> {
+def AffineVectorStoreOp : AffineStoreOpBase<"vector_store"> {
   let summary = "affine vector store operation";
   let description = [{
     The "affine.vector_store" is the vector counterpart of

diff  --git a/mlir/include/mlir/Dialect/Affine/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Affine/IR/CMakeLists.txt
index 1fd2a505d850..77806274f14c 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Affine/IR/CMakeLists.txt
@@ -1,2 +1,10 @@
 add_mlir_dialect(AffineOps affine)
 add_mlir_doc(AffineOps -gen-op-doc AffineOps Dialects/)
+
+set(LLVM_TARGET_DEFINITIONS AffineMemoryOpInterfaces.td)
+mlir_tablegen(AffineMemoryOpInterfaces.h.inc -gen-op-interface-decls)
+mlir_tablegen(AffineMemoryOpInterfaces.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRAffineMemoryOpInterfacesIncGen)
+add_dependencies(mlir-generic-headers MLIRAffineMemoryOpInterfacesIncGen)
+
+add_dependencies(MLIRAffineOpsIncGen MLIRAffineMemoryOpInterfacesIncGen)

diff  --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp
index 5a395937101f..8c4828805882 100644
--- a/mlir/lib/Analysis/AffineAnalysis.cpp
+++ b/mlir/lib/Analysis/AffineAnalysis.cpp
@@ -660,10 +660,12 @@ static void computeDirectionVector(
 void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const {
   // Get affine map from AffineLoad/Store.
   AffineMap map;
-  if (auto loadOp = dyn_cast<AffineLoadOp>(opInst))
+  if (auto loadOp = dyn_cast<AffineReadOpInterface>(opInst)) {
     map = loadOp.getAffineMap();
-  else if (auto storeOp = dyn_cast<AffineStoreOp>(opInst))
+  } else {
+    auto storeOp = cast<AffineWriteOpInterface>(opInst);
     map = storeOp.getAffineMap();
+  }
   SmallVector<Value, 8> operands(indices.begin(), indices.end());
   fullyComposeAffineMapAndOperands(&map, &operands);
   map = simplifyAffineMap(map);
@@ -771,9 +773,10 @@ DependenceResult mlir::checkMemrefAccessDependence(
   if (srcAccess.memref != dstAccess.memref)
     return DependenceResult::NoDependence;
 
-  // Return 'NoDependence' if one of these accesses is not an AffineStoreOp.
-  if (!allowRAR && !isa<AffineStoreOp>(srcAccess.opInst) &&
-      !isa<AffineStoreOp>(dstAccess.opInst))
+  // Return 'NoDependence' if one of these accesses is not an
+  // AffineWriteOpInterface.
+  if (!allowRAR && !isa<AffineWriteOpInterface>(srcAccess.opInst) &&
+      !isa<AffineWriteOpInterface>(dstAccess.opInst))
     return DependenceResult::NoDependence;
 
   // Get composed access function for 'srcAccess'.
@@ -857,7 +860,8 @@ void mlir::getDependenceComponents(
   // Collect all load and store ops in loop nest rooted at 'forOp'.
   SmallVector<Operation *, 8> loadAndStoreOpInsts;
   forOp.getOperation()->walk([&](Operation *opInst) {
-    if (isa<AffineLoadOp>(opInst) || isa<AffineStoreOp>(opInst))
+    if (isa<AffineReadOpInterface>(opInst) ||
+        isa<AffineWriteOpInterface>(opInst))
       loadAndStoreOpInsts.push_back(opInst);
   });
 

diff  --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
index e6d7127762d5..b43d0bedec56 100644
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -196,8 +196,8 @@ LogicalResult MemRefRegion::unionBoundingBox(const MemRefRegion &other) {
 LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
                                     ComputationSliceState *sliceState,
                                     bool addMemRefDimBounds) {
-  assert((isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op)) &&
-         "affine load/store op expected");
+  assert((isa<AffineReadOpInterface>(op) || isa<AffineWriteOpInterface>(op)) &&
+         "affine read/write op expected");
 
   MemRefAccess access(op);
   memref = access.memref;
@@ -404,9 +404,10 @@ Optional<uint64_t> mlir::getMemRefSizeInBytes(MemRefType memRefType) {
 template <typename LoadOrStoreOp>
 LogicalResult mlir::boundCheckLoadOrStoreOp(LoadOrStoreOp loadOrStoreOp,
                                             bool emitError) {
-  static_assert(
-      llvm::is_one_of<LoadOrStoreOp, AffineLoadOp, AffineStoreOp>::value,
-      "argument should be either a AffineLoadOp or a AffineStoreOp");
+  static_assert(llvm::is_one_of<LoadOrStoreOp, AffineReadOpInterface,
+                                AffineWriteOpInterface>::value,
+                "argument should be either a AffineReadOpInterface or a "
+                "AffineWriteOpInterface");
 
   Operation *op = loadOrStoreOp.getOperation();
   MemRefRegion region(op->getLoc());
@@ -456,10 +457,10 @@ LogicalResult mlir::boundCheckLoadOrStoreOp(LoadOrStoreOp loadOrStoreOp,
 }
 
 // Explicitly instantiate the template so that the compiler knows we need them!
-template LogicalResult mlir::boundCheckLoadOrStoreOp(AffineLoadOp loadOp,
-                                                     bool emitError);
-template LogicalResult mlir::boundCheckLoadOrStoreOp(AffineStoreOp storeOp,
-                                                     bool emitError);
+template LogicalResult
+mlir::boundCheckLoadOrStoreOp(AffineReadOpInterface loadOp, bool emitError);
+template LogicalResult
+mlir::boundCheckLoadOrStoreOp(AffineWriteOpInterface storeOp, bool emitError);
 
 // Returns in 'positions' the Block positions of 'op' in each ancestor
 // Block from the Block containing operation, stopping at 'limitBlock'.
@@ -575,8 +576,8 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
         return failure();
       }
 
-      bool readReadAccesses = isa<AffineLoadOp>(srcAccess.opInst) &&
-                              isa<AffineLoadOp>(dstAccess.opInst);
+      bool readReadAccesses = isa<AffineReadOpInterface>(srcAccess.opInst) &&
+                              isa<AffineReadOpInterface>(dstAccess.opInst);
       FlatAffineConstraints dependenceConstraints;
       // Check dependence between 'srcAccess' and 'dstAccess'.
       DependenceResult result = checkMemrefAccessDependence(
@@ -768,7 +769,8 @@ void mlir::getComputationSliceState(
                       : std::prev(srcLoopIVs[loopDepth - 1].getBody()->end());
 
   llvm::SmallDenseSet<Value, 8> sequentialLoops;
-  if (isa<AffineLoadOp>(depSourceOp) && isa<AffineLoadOp>(depSinkOp)) {
+  if (isa<AffineReadOpInterface>(depSourceOp) &&
+      isa<AffineReadOpInterface>(depSinkOp)) {
     // For read-read access pairs, clear any slice bounds on sequential loops.
     // Get sequential loops in loop nest rooted at 'srcLoopIVs[0]'.
     getSequentialLoops(isBackwardSlice ? srcLoopIVs[0] : dstLoopIVs[0],
@@ -865,7 +867,7 @@ mlir::insertBackwardComputationSlice(Operation *srcOpInst, Operation *dstOpInst,
 // Constructs  MemRefAccess populating it with the memref, its indices and
 // opinst from 'loadOrStoreOpInst'.
 MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) {
-  if (auto loadOp = dyn_cast<AffineLoadOp>(loadOrStoreOpInst)) {
+  if (auto loadOp = dyn_cast<AffineReadOpInterface>(loadOrStoreOpInst)) {
     memref = loadOp.getMemRef();
     opInst = loadOrStoreOpInst;
     auto loadMemrefType = loadOp.getMemRefType();
@@ -874,8 +876,9 @@ MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) {
       indices.push_back(index);
     }
   } else {
-    assert(isa<AffineStoreOp>(loadOrStoreOpInst) && "load/store op expected");
-    auto storeOp = dyn_cast<AffineStoreOp>(loadOrStoreOpInst);
+    assert(isa<AffineWriteOpInterface>(loadOrStoreOpInst) &&
+           "Affine read/write op expected");
+    auto storeOp = cast<AffineWriteOpInterface>(loadOrStoreOpInst);
     opInst = loadOrStoreOpInst;
     memref = storeOp.getMemRef();
     auto storeMemrefType = storeOp.getMemRefType();
@@ -890,7 +893,9 @@ unsigned MemRefAccess::getRank() const {
   return memref.getType().cast<MemRefType>().getRank();
 }
 
-bool MemRefAccess::isStore() const { return isa<AffineStoreOp>(opInst); }
+bool MemRefAccess::isStore() const {
+  return isa<AffineWriteOpInterface>(opInst);
+}
 
 /// Returns the nesting depth of this statement, i.e., the number of loops
 /// surrounding this statement.
@@ -947,7 +952,8 @@ static Optional<int64_t> getMemoryFootprintBytes(Block &block,
 
   // Walk this 'affine.for' operation to gather all memory regions.
   auto result = block.walk(start, end, [&](Operation *opInst) -> WalkResult {
-    if (!isa<AffineLoadOp>(opInst) && !isa<AffineStoreOp>(opInst)) {
+    if (!isa<AffineReadOpInterface>(opInst) &&
+        !isa<AffineWriteOpInterface>(opInst)) {
       // Neither load nor a store op.
       return WalkResult::advance();
     }
@@ -1007,7 +1013,8 @@ bool mlir::isLoopParallel(AffineForOp forOp) {
   // Collect all load and store ops in loop nest rooted at 'forOp'.
   SmallVector<Operation *, 8> loadAndStoreOpInsts;
   auto walkResult = forOp.walk([&](Operation *opInst) -> WalkResult {
-    if (isa<AffineLoadOp>(opInst) || isa<AffineStoreOp>(opInst))
+    if (isa<AffineReadOpInterface>(opInst) ||
+        isa<AffineWriteOpInterface>(opInst))
       loadAndStoreOpInsts.push_back(opInst);
     else if (!isa<AffineForOp>(opInst) && !isa<AffineTerminatorOp>(opInst) &&
              !isa<AffineIfOp>(opInst) &&

diff  --git a/mlir/lib/Dialect/Affine/IR/AffineMemoryOpInterfaces.cpp b/mlir/lib/Dialect/Affine/IR/AffineMemoryOpInterfaces.cpp
new file mode 100644
index 000000000000..6f5861efa956
--- /dev/null
+++ b/mlir/lib/Dialect/Affine/IR/AffineMemoryOpInterfaces.cpp
@@ -0,0 +1,18 @@
+//===- AffineMemoryOpInterfaces.cpp - Affine memory op interfaces ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Affine Memory Op Interfaces
+//===----------------------------------------------------------------------===//
+
+/// Include the TableGen'd definitions of the affine memory op interfaces.
+#include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.cpp.inc"

diff  --git a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
index 12105b40571b..20bc86366668 100644
--- a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRAffineOps
+  AffineMemoryOpInterfaces.cpp
   AffineOps.cpp
   AffineValueMap.cpp
 
@@ -6,6 +7,7 @@ add_mlir_dialect_library(MLIRAffineOps
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Affine
 
   DEPENDS
+  MLIRAffineMemoryOpInterfacesIncGen
   MLIRAffineOpsIncGen
 
   LINK_LIBS PUBLIC

diff  --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 72dfc1d62faf..bb219fa07711 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -70,7 +70,7 @@ mlir::createLoopFusionPass(unsigned fastMemorySpace,
 
 // TODO(b/117228571) Replace when this is modeled through side-effects/op traits
 static bool isMemRefDereferencingOp(Operation &op) {
-  if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op) ||
+  if (isa<AffineReadOpInterface>(op) || isa<AffineWriteOpInterface>(op) ||
       isa<AffineDmaStartOp>(op) || isa<AffineDmaWaitOp>(op))
     return true;
   return false;
@@ -92,9 +92,9 @@ struct LoopNestStateCollector {
         forOps.push_back(cast<AffineForOp>(op));
       else if (op->getNumRegions() != 0)
         hasNonForRegion = true;
-      else if (isa<AffineLoadOp>(op))
+      else if (isa<AffineReadOpInterface>(op))
         loadOpInsts.push_back(op);
-      else if (isa<AffineStoreOp>(op))
+      else if (isa<AffineWriteOpInterface>(op))
         storeOpInsts.push_back(op);
     });
   }
@@ -125,7 +125,7 @@ struct MemRefDependenceGraph {
     unsigned getLoadOpCount(Value memref) {
       unsigned loadOpCount = 0;
       for (auto *loadOpInst : loads) {
-        if (memref == cast<AffineLoadOp>(loadOpInst).getMemRef())
+        if (memref == cast<AffineReadOpInterface>(loadOpInst).getMemRef())
           ++loadOpCount;
       }
       return loadOpCount;
@@ -135,7 +135,7 @@ struct MemRefDependenceGraph {
     unsigned getStoreOpCount(Value memref) {
       unsigned storeOpCount = 0;
       for (auto *storeOpInst : stores) {
-        if (memref == cast<AffineStoreOp>(storeOpInst).getMemRef())
+        if (memref == cast<AffineWriteOpInterface>(storeOpInst).getMemRef())
           ++storeOpCount;
       }
       return storeOpCount;
@@ -145,7 +145,7 @@ struct MemRefDependenceGraph {
     void getStoreOpsForMemref(Value memref,
                               SmallVectorImpl<Operation *> *storeOps) {
       for (auto *storeOpInst : stores) {
-        if (memref == cast<AffineStoreOp>(storeOpInst).getMemRef())
+        if (memref == cast<AffineWriteOpInterface>(storeOpInst).getMemRef())
           storeOps->push_back(storeOpInst);
       }
     }
@@ -154,7 +154,7 @@ struct MemRefDependenceGraph {
     void getLoadOpsForMemref(Value memref,
                              SmallVectorImpl<Operation *> *loadOps) {
       for (auto *loadOpInst : loads) {
-        if (memref == cast<AffineLoadOp>(loadOpInst).getMemRef())
+        if (memref == cast<AffineReadOpInterface>(loadOpInst).getMemRef())
           loadOps->push_back(loadOpInst);
       }
     }
@@ -164,10 +164,10 @@ struct MemRefDependenceGraph {
     void getLoadAndStoreMemrefSet(DenseSet<Value> *loadAndStoreMemrefSet) {
       llvm::SmallDenseSet<Value, 2> loadMemrefs;
       for (auto *loadOpInst : loads) {
-        loadMemrefs.insert(cast<AffineLoadOp>(loadOpInst).getMemRef());
+        loadMemrefs.insert(cast<AffineReadOpInterface>(loadOpInst).getMemRef());
       }
       for (auto *storeOpInst : stores) {
-        auto memref = cast<AffineStoreOp>(storeOpInst).getMemRef();
+        auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
         if (loadMemrefs.count(memref) > 0)
           loadAndStoreMemrefSet->insert(memref);
       }
@@ -259,7 +259,7 @@ struct MemRefDependenceGraph {
   bool writesToLiveInOrEscapingMemrefs(unsigned id) {
     Node *node = getNode(id);
     for (auto *storeOpInst : node->stores) {
-      auto memref = cast<AffineStoreOp>(storeOpInst).getMemRef();
+      auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
       auto *op = memref.getDefiningOp();
       // Return true if 'memref' is a block argument.
       if (!op)
@@ -272,13 +272,14 @@ struct MemRefDependenceGraph {
     return false;
   }
 
-  // Returns the unique AffineStoreOp in `node` that meets all the following:
+  // Returns the unique AffineWriteOpInterface in `node` that meets all the
+  // following:
   //   *) store is the only one that writes to a function-local memref live out
   //      of `node`,
   //   *) store is not the source of a self-dependence on `node`.
-  // Otherwise, returns a null AffineStoreOp.
-  AffineStoreOp getUniqueOutgoingStore(Node *node) {
-    AffineStoreOp uniqueStore;
+  // Otherwise, returns a null AffineWriteOpInterface.
+  AffineWriteOpInterface getUniqueOutgoingStore(Node *node) {
+    AffineWriteOpInterface uniqueStore;
 
     // Return null if `node` doesn't have any outgoing edges.
     auto outEdgeIt = outEdges.find(node->id);
@@ -287,7 +288,7 @@ struct MemRefDependenceGraph {
 
     const auto &nodeOutEdges = outEdgeIt->second;
     for (auto *op : node->stores) {
-      auto storeOp = cast<AffineStoreOp>(op);
+      auto storeOp = cast<AffineWriteOpInterface>(op);
       auto memref = storeOp.getMemRef();
       // Skip this store if there are no dependences on its memref. This means
       // that store either:
@@ -322,7 +323,8 @@ struct MemRefDependenceGraph {
     Node *node = getNode(id);
     for (auto *storeOpInst : node->stores) {
       // Return false if there exist out edges from 'id' on 'memref'.
-      if (getOutEdgeCount(id, cast<AffineStoreOp>(storeOpInst).getMemRef()) > 0)
+      auto storeMemref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
+      if (getOutEdgeCount(id, storeMemref) > 0)
         return false;
     }
     return true;
@@ -651,28 +653,28 @@ bool MemRefDependenceGraph::init(FuncOp f) {
       Node node(nextNodeId++, &op);
       for (auto *opInst : collector.loadOpInsts) {
         node.loads.push_back(opInst);
-        auto memref = cast<AffineLoadOp>(opInst).getMemRef();
+        auto memref = cast<AffineReadOpInterface>(opInst).getMemRef();
         memrefAccesses[memref].insert(node.id);
       }
       for (auto *opInst : collector.storeOpInsts) {
         node.stores.push_back(opInst);
-        auto memref = cast<AffineStoreOp>(opInst).getMemRef();
+        auto memref = cast<AffineWriteOpInterface>(opInst).getMemRef();
         memrefAccesses[memref].insert(node.id);
       }
       forToNodeMap[&op] = node.id;
       nodes.insert({node.id, node});
-    } else if (auto loadOp = dyn_cast<AffineLoadOp>(op)) {
+    } else if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) {
       // Create graph node for top-level load op.
       Node node(nextNodeId++, &op);
       node.loads.push_back(&op);
-      auto memref = cast<AffineLoadOp>(op).getMemRef();
+      auto memref = cast<AffineReadOpInterface>(op).getMemRef();
       memrefAccesses[memref].insert(node.id);
       nodes.insert({node.id, node});
-    } else if (auto storeOp = dyn_cast<AffineStoreOp>(op)) {
+    } else if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) {
       // Create graph node for top-level store op.
       Node node(nextNodeId++, &op);
       node.stores.push_back(&op);
-      auto memref = cast<AffineStoreOp>(op).getMemRef();
+      auto memref = cast<AffineWriteOpInterface>(op).getMemRef();
       memrefAccesses[memref].insert(node.id);
       nodes.insert({node.id, node});
     } else if (op.getNumRegions() != 0) {
@@ -733,7 +735,7 @@ static void moveLoadsAccessingMemrefTo(Value memref,
   dstLoads->clear();
   SmallVector<Operation *, 4> srcLoadsToKeep;
   for (auto *load : *srcLoads) {
-    if (cast<AffineLoadOp>(load).getMemRef() == memref)
+    if (cast<AffineReadOpInterface>(load).getMemRef() == memref)
       dstLoads->push_back(load);
     else
       srcLoadsToKeep.push_back(load);
@@ -854,7 +856,7 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
   // Builder to create constants at the top level.
   OpBuilder top(forInst->getParentOfType<FuncOp>().getBody());
   // Create new memref type based on slice bounds.
-  auto oldMemRef = cast<AffineStoreOp>(srcStoreOpInst).getMemRef();
+  auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOpInst).getMemRef();
   auto oldMemRefType = oldMemRef.getType().cast<MemRefType>();
   unsigned rank = oldMemRefType.getRank();
 
@@ -962,9 +964,10 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
 // Returns true if 'dstNode's read/write region to 'memref' is a super set of
 // 'srcNode's write region to 'memref' and 'srcId' has only one output edge.
 // TODO(andydavis) Generalize this to handle more live in/out cases.
-static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId,
-                                           AffineStoreOp srcLiveOutStoreOp,
-                                           MemRefDependenceGraph *mdg) {
+static bool
+canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId,
+                               AffineWriteOpInterface srcLiveOutStoreOp,
+                               MemRefDependenceGraph *mdg) {
   assert(srcLiveOutStoreOp && "Expected a valid store op");
   auto *dstNode = mdg->getNode(dstId);
   Value memref = srcLiveOutStoreOp.getMemRef();
@@ -1450,7 +1453,7 @@ struct GreedyFusion {
       DenseSet<Value> visitedMemrefs;
       while (!loads.empty()) {
         // Get memref of load on top of the stack.
-        auto memref = cast<AffineLoadOp>(loads.back()).getMemRef();
+        auto memref = cast<AffineReadOpInterface>(loads.back()).getMemRef();
         if (visitedMemrefs.count(memref) > 0)
           continue;
         visitedMemrefs.insert(memref);
@@ -1488,7 +1491,7 @@ struct GreedyFusion {
             // feasibility for loops with multiple stores.
             unsigned maxLoopDepth = 0;
             for (auto *op : srcNode->stores) {
-              auto storeOp = cast<AffineStoreOp>(op);
+              auto storeOp = cast<AffineWriteOpInterface>(op);
               if (storeOp.getMemRef() != memref) {
                 srcStoreOp = nullptr;
                 break;
@@ -1563,7 +1566,7 @@ struct GreedyFusion {
           // Gather 'dstNode' store ops to 'memref'.
           SmallVector<Operation *, 2> dstStoreOpInsts;
           for (auto *storeOpInst : dstNode->stores)
-            if (cast<AffineStoreOp>(storeOpInst).getMemRef() == memref)
+            if (cast<AffineWriteOpInterface>(storeOpInst).getMemRef() == memref)
               dstStoreOpInsts.push_back(storeOpInst);
 
           unsigned bestDstLoopDepth;
@@ -1601,7 +1604,8 @@ struct GreedyFusion {
               // Create private memref for 'memref' in 'dstAffineForOp'.
               SmallVector<Operation *, 4> storesForMemref;
               for (auto *storeOpInst : sliceCollector.storeOpInsts) {
-                if (cast<AffineStoreOp>(storeOpInst).getMemRef() == memref)
+                if (cast<AffineWriteOpInterface>(storeOpInst).getMemRef() ==
+                    memref)
                   storesForMemref.push_back(storeOpInst);
               }
               // TODO(andydavis) Use union of memref write regions to compute
@@ -1624,7 +1628,8 @@ struct GreedyFusion {
             // Add new load ops to current Node load op list 'loads' to
             // continue fusing based on new operands.
             for (auto *loadOpInst : dstLoopCollector.loadOpInsts) {
-              auto loadMemRef = cast<AffineLoadOp>(loadOpInst).getMemRef();
+              auto loadMemRef =
+                  cast<AffineReadOpInterface>(loadOpInst).getMemRef();
               // NOTE: Change 'loads' to a hash set in case efficiency is an
               // issue. We still use a vector since it's expected to be small.
               if (visitedMemrefs.count(loadMemRef) == 0 &&
@@ -1785,7 +1790,8 @@ struct GreedyFusion {
       // Check that all stores are to the same memref.
       DenseSet<Value> storeMemrefs;
       for (auto *storeOpInst : sibNode->stores) {
-        storeMemrefs.insert(cast<AffineStoreOp>(storeOpInst).getMemRef());
+        storeMemrefs.insert(
+            cast<AffineWriteOpInterface>(storeOpInst).getMemRef());
       }
       if (storeMemrefs.size() != 1)
         return false;
@@ -1796,7 +1802,7 @@ struct GreedyFusion {
     auto fn = dstNode->op->getParentOfType<FuncOp>();
     for (unsigned i = 0, e = fn.getNumArguments(); i != e; ++i) {
       for (auto *user : fn.getArgument(i).getUsers()) {
-        if (auto loadOp = dyn_cast<AffineLoadOp>(user)) {
+        if (auto loadOp = dyn_cast<AffineReadOpInterface>(user)) {
           // Gather loops surrounding 'use'.
           SmallVector<AffineForOp, 4> loops;
           getLoopIVs(*user, &loops);

diff  --git a/mlir/test/lib/Transforms/TestMemRefBoundCheck.cpp b/mlir/test/lib/Transforms/TestMemRefBoundCheck.cpp
index da140b3486f0..087ea4fdde94 100644
--- a/mlir/test/lib/Transforms/TestMemRefBoundCheck.cpp
+++ b/mlir/test/lib/Transforms/TestMemRefBoundCheck.cpp
@@ -37,8 +37,9 @@ struct TestMemRefBoundCheck
 
 void TestMemRefBoundCheck::runOnFunction() {
   getFunction().walk([](Operation *opInst) {
-    TypeSwitch<Operation *>(opInst).Case<AffineLoadOp, AffineStoreOp>(
-        [](auto op) { boundCheckLoadOrStoreOp(op); });
+    TypeSwitch<Operation *>(opInst)
+        .Case<AffineReadOpInterface, AffineWriteOpInterface>(
+            [](auto op) { boundCheckLoadOrStoreOp(op); });
 
     // TODO(bondhugula): do this for DMA ops as well.
   });


        


More information about the Mlir-commits mailing list