[Mlir-commits] [mlir] d47dd11 - [mlir] Add support for querying the ModRef behavior from the AliasAnalysis class

River Riddle llvmlistbot at llvm.org
Thu May 27 13:57:50 PDT 2021


Author: River Riddle
Date: 2021-05-27T13:57:29-07:00
New Revision: d47dd11071322ad7be6ec7e35a89d0d8f26534b9

URL: https://github.com/llvm/llvm-project/commit/d47dd11071322ad7be6ec7e35a89d0d8f26534b9
DIFF: https://github.com/llvm/llvm-project/commit/d47dd11071322ad7be6ec7e35a89d0d8f26534b9.diff

LOG: [mlir] Add support for querying the ModRef behavior from the AliasAnalysis class

This allows for checking whether a given operation may modify, reference, or both modify and reference a given value. Right now this API is limited to Value-based memory locations, but we should expand it to include attribute-based values at some point. This is left for future work because the rest of the AliasAnalysis API also has this restriction.

Differential Revision: https://reviews.llvm.org/D101673

Added: 
    mlir/test/Analysis/test-alias-analysis-modref.mlir

Modified: 
    mlir/include/mlir/Analysis/AliasAnalysis.h
    mlir/include/mlir/Analysis/AliasAnalysis/LocalAliasAnalysis.h
    mlir/lib/Analysis/AliasAnalysis.cpp
    mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
    mlir/test/lib/Analysis/TestAliasAnalysis.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/AliasAnalysis.h b/mlir/include/mlir/Analysis/AliasAnalysis.h
index f3fce42097a9..925af24ddd69 100644
--- a/mlir/include/mlir/Analysis/AliasAnalysis.h
+++ b/mlir/include/mlir/Analysis/AliasAnalysis.h
@@ -67,14 +67,106 @@ class AliasResult {
   /// Returns if this result is a partial alias.
   bool isPartial() const { return kind == PartialAlias; }
 
-  /// Return the internal kind of this alias result.
-  Kind getKind() const { return kind; }
+  /// Print this alias result to the provided output stream.
+  void print(raw_ostream &os) const;
 
 private:
   /// The internal kind of the result.
   Kind kind;
 };
 
+inline raw_ostream &operator<<(raw_ostream &os, const AliasResult &result) {
+  result.print(os);
+  return os;
+}
+
+//===----------------------------------------------------------------------===//
+// ModRefResult
+//===----------------------------------------------------------------------===//
+
+/// The possible results of whether a memory access modifies or references
+/// a memory location. The possible results are: no access at all, a
+/// modification, a reference, or both a modification and a reference.
+class LLVM_NODISCARD ModRefResult {
+  /// Note: This is a simplified version of the ModRefResult in
+  /// `llvm/Analysis/AliasAnalysis.h`, and namely removes the `Must` concept. If
+  /// this becomes useful/necessary we should add it here.
+  enum class Kind {
+    /// The access neither references nor modifies the value stored in memory.
+    NoModRef = 0,
+    /// The access may reference the value stored in memory.
+    Ref = 1,
+    /// The access may modify the value stored in memory.
+    Mod = 2,
+    /// The access may reference and may modify the value stored in memory.
+    ModRef = Ref | Mod,
+  };
+
+public:
+  bool operator==(const ModRefResult &rhs) const { return kind == rhs.kind; }
+  bool operator!=(const ModRefResult &rhs) const { return !(*this == rhs); }
+
+  /// Return a new result that indicates that the memory access neither
+  /// references nor modifies the value stored in memory.
+  static ModRefResult getNoModRef() { return Kind::NoModRef; }
+
+  /// Return a new result that indicates that the memory access may reference
+  /// the value stored in memory.
+  static ModRefResult getRef() { return Kind::Ref; }
+
+  /// Return a new result that indicates that the memory access may modify the
+  /// value stored in memory.
+  static ModRefResult getMod() { return Kind::Mod; }
+
+  /// Return a new result that indicates that the memory access may reference
+  /// and may modify the value stored in memory.
+  static ModRefResult getModAndRef() { return Kind::ModRef; }
+
+  /// Returns if this result does not modify or reference memory.
+  LLVM_NODISCARD bool isNoModRef() const { return kind == Kind::NoModRef; }
+
+  /// Returns if this result modifies memory.
+  LLVM_NODISCARD bool isMod() const {
+    return static_cast<int>(kind) & static_cast<int>(Kind::Mod);
+  }
+
+  /// Returns if this result references memory.
+  LLVM_NODISCARD bool isRef() const {
+    return static_cast<int>(kind) & static_cast<int>(Kind::Ref);
+  }
+
+  /// Returns if this result modifies *or* references memory.
+  LLVM_NODISCARD bool isModOrRef() const { return kind != Kind::NoModRef; }
+
+  /// Returns if this result modifies *and* references memory.
+  LLVM_NODISCARD bool isModAndRef() const { return kind == Kind::ModRef; }
+
+  /// Merge this ModRef result with `other` and return the result.
+  ModRefResult merge(const ModRefResult &other) {
+    return ModRefResult(static_cast<Kind>(static_cast<int>(kind) |
+                                          static_cast<int>(other.kind)));
+  }
+  /// Intersect this ModRef result with `other` and return the result.
+  ModRefResult intersect(const ModRefResult &other) {
+    return ModRefResult(static_cast<Kind>(static_cast<int>(kind) &
+                                          static_cast<int>(other.kind)));
+  }
+
+  /// Print this ModRef result to the provided output stream.
+  void print(raw_ostream &os) const;
+
+private:
+  ModRefResult(Kind kind) : kind(kind) {}
+
+  /// The internal kind of the result.
+  Kind kind;
+};
+
+inline raw_ostream &operator<<(raw_ostream &os, const ModRefResult &result) {
+  result.print(os);
+  return os;
+}
+
 //===----------------------------------------------------------------------===//
 // AliasAnalysisTraits
 //===----------------------------------------------------------------------===//
@@ -92,6 +184,9 @@ struct AliasAnalysisTraits {
 
     /// Given two values, return their aliasing behavior.
     virtual AliasResult alias(Value lhs, Value rhs) = 0;
+
+    /// Return the modify-reference behavior of `op` on `location`.
+    virtual ModRefResult getModRef(Operation *op, Value location) = 0;
   };
 
   /// This class represents the `Model` of an alias analysis implementation
@@ -108,6 +203,11 @@ struct AliasAnalysisTraits {
       return impl.alias(lhs, rhs);
     }
 
+    /// Return the modify-reference behavior of `op` on `location`.
+    ModRefResult getModRef(Operation *op, Value location) final {
+      return impl.getModRef(op, location);
+    }
+
   private:
     ImplT impl;
   };
@@ -147,7 +247,12 @@ class AliasAnalysis {
   ///   * AnalysisT(AnalysisT &&)
   ///   * AliasResult alias(Value lhs, Value rhs)
   ///     - This method returns an `AliasResult` that corresponds to the
-  ///       aliasing behavior between `lhs` and `rhs`.
+  ///       aliasing behavior between `lhs` and `rhs`. The conservative "I don't
+  ///       know" result of this method should be MayAlias.
+  ///   * ModRefResult getModRef(Operation *op, Value location)
+  ///     - This method returns a `ModRefResult` that corresponds to the
+  ///       modify-reference behavior of `op` on the given `location`. The
+  ///       conservative "I don't know" result of this method should be ModRef.
   template <typename AnalysisT>
   void addAnalysisImplementation(AnalysisT &&analysis) {
     aliasImpls.push_back(
@@ -161,6 +266,13 @@ class AliasAnalysis {
   /// Given two values, return their aliasing behavior.
   AliasResult alias(Value lhs, Value rhs);
 
+  //===--------------------------------------------------------------------===//
+  // ModRef Queries
+  //===--------------------------------------------------------------------===//
+
+  /// Return the modify-reference behavior of `op` on `location`.
+  ModRefResult getModRef(Operation *op, Value location);
+
 private:
   /// A set of internal alias analysis implementations.
   SmallVector<std::unique_ptr<Concept>, 4> aliasImpls;

diff  --git a/mlir/include/mlir/Analysis/AliasAnalysis/LocalAliasAnalysis.h b/mlir/include/mlir/Analysis/AliasAnalysis/LocalAliasAnalysis.h
index 45edd2088cdd..afed185e6c29 100644
--- a/mlir/include/mlir/Analysis/AliasAnalysis/LocalAliasAnalysis.h
+++ b/mlir/include/mlir/Analysis/AliasAnalysis/LocalAliasAnalysis.h
@@ -25,6 +25,9 @@ class LocalAliasAnalysis {
 public:
   /// Given two values, return their aliasing behavior.
   AliasResult alias(Value lhs, Value rhs);
+
+  /// Return the modify-reference behavior of `op` on `location`.
+  ModRefResult getModRef(Operation *op, Value location);
 };
 } // end namespace mlir
 

diff  --git a/mlir/lib/Analysis/AliasAnalysis.cpp b/mlir/lib/Analysis/AliasAnalysis.cpp
index 946825e156af..2f2b78249528 100644
--- a/mlir/lib/Analysis/AliasAnalysis.cpp
+++ b/mlir/lib/Analysis/AliasAnalysis.cpp
@@ -27,6 +27,44 @@ AliasResult AliasResult::merge(AliasResult other) const {
   return MayAlias;
 }
 
+void AliasResult::print(raw_ostream &os) const {
+  switch (kind) {
+  case Kind::NoAlias:
+    os << "NoAlias";
+    break;
+  case Kind::MayAlias:
+    os << "MayAlias";
+    break;
+  case Kind::PartialAlias:
+    os << "PartialAlias";
+    break;
+  case Kind::MustAlias:
+    os << "MustAlias";
+    break;
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// ModRefResult
+//===----------------------------------------------------------------------===//
+
+void ModRefResult::print(raw_ostream &os) const {
+  switch (kind) {
+  case Kind::NoModRef:
+    os << "NoModRef";
+    break;
+  case Kind::Ref:
+    os << "Ref";
+    break;
+  case Kind::Mod:
+    os << "Mod";
+    break;
+  case Kind::ModRef:
+    os << "ModRef";
+    break;
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // AliasAnalysis
 //===----------------------------------------------------------------------===//
@@ -35,7 +73,6 @@ AliasAnalysis::AliasAnalysis(Operation *op) {
   addAnalysisImplementation(LocalAliasAnalysis());
 }
 
-/// Given the two values, return their aliasing behavior.
 AliasResult AliasAnalysis::alias(Value lhs, Value rhs) {
   // Check each of the alias analysis implemenations for an alias result.
   for (const std::unique_ptr<Concept> &aliasImpl : aliasImpls) {
@@ -45,3 +82,16 @@ AliasResult AliasAnalysis::alias(Value lhs, Value rhs) {
   }
   return AliasResult::MayAlias;
 }
+
+ModRefResult AliasAnalysis::getModRef(Operation *op, Value location) {
+  // Compute the mod-ref behavior by refining a top `ModRef` result with each of
+  // the alias analysis implementations. We early exit at the point where we
+  // refine down to a `NoModRef`.
+  ModRefResult result = ModRefResult::getModAndRef();
+  for (const std::unique_ptr<Concept> &aliasImpl : aliasImpls) {
+    result = result.intersect(aliasImpl->getModRef(op, location));
+    if (result.isNoModRef())
+      return result;
+  }
+  return result;
+}

diff  --git a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
index 17a9ded86917..062443a39619 100644
--- a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
+++ b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
@@ -195,7 +195,7 @@ static void collectUnderlyingAddressValues(Value value,
 }
 
 //===----------------------------------------------------------------------===//
-// LocalAliasAnalysis
+// LocalAliasAnalysis: alias
 //===----------------------------------------------------------------------===//
 
 /// Given a value, try to get an allocation effect attached to it. If
@@ -336,3 +336,56 @@ AliasResult LocalAliasAnalysis::alias(Value lhs, Value rhs) {
   // We should always have a valid result here.
   return *result;
 }
+
+//===----------------------------------------------------------------------===//
+// LocalAliasAnalysis: getModRef
+//===----------------------------------------------------------------------===//
+
+ModRefResult LocalAliasAnalysis::getModRef(Operation *op, Value location) {
+  // Check to see if this operation relies on nested side effects.
+  if (op->hasTrait<OpTrait::HasRecursiveSideEffects>()) {
+    // TODO: To check recursive operations we need to check all of the nested
+    // operations, which can result in a quadratic number of queries. We should
+    // introduce some caching of some kind to help alleviate this, especially as
+    // this caching could be used in other areas of the codebase (e.g. when
+    // checking `wouldOpBeTriviallyDead`).
+    return ModRefResult::getModAndRef();
+  }
+
+  // Otherwise, check to see if this operation has a memory effect interface.
+  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
+  if (!interface)
+    return ModRefResult::getModAndRef();
+
+  // Build a ModRefResult by merging the behavior of the effects of this
+  // operation.
+  SmallVector<MemoryEffects::EffectInstance> effects;
+  interface.getEffects(effects);
+
+  ModRefResult result = ModRefResult::getNoModRef();
+  for (const MemoryEffects::EffectInstance &effect : effects) {
+    if (isa<MemoryEffects::Allocate, MemoryEffects::Free>(effect.getEffect()))
+      continue;
+
+    // Check for an alias between the effect and our memory location.
+    // TODO: Add support for checking an alias with a symbol reference.
+    AliasResult aliasResult = AliasResult::MayAlias;
+    if (Value effectValue = effect.getValue())
+      aliasResult = alias(effectValue, location);
+
+    // If we don't alias, ignore this effect.
+    if (aliasResult.isNo())
+      continue;
+
+    // Merge in the corresponding mod or ref for this effect.
+    if (isa<MemoryEffects::Read>(effect.getEffect())) {
+      result = result.merge(ModRefResult::getRef());
+    } else {
+      assert(isa<MemoryEffects::Write>(effect.getEffect()));
+      result = result.merge(ModRefResult::getMod());
+    }
+    if (result.isModAndRef())
+      break;
+  }
+  return result;
+}

diff  --git a/mlir/test/Analysis/test-alias-analysis-modref.mlir b/mlir/test/Analysis/test-alias-analysis-modref.mlir
new file mode 100644
index 000000000000..46ac7fbdf5d9
--- /dev/null
+++ b/mlir/test/Analysis/test-alias-analysis-modref.mlir
@@ -0,0 +1,67 @@
+// RUN: mlir-opt %s -pass-pipeline='func(test-alias-analysis-modref)' -split-input-file -allow-unregistered-dialect 2>&1 | FileCheck %s
+
+// CHECK-LABEL: Testing : "no_side_effects"
+// CHECK: alloc -> func.region0#0: NoModRef
+// CHECK: dealloc -> func.region0#0: NoModRef
+// CHECK: return -> func.region0#0: NoModRef
+func @no_side_effects(%arg: memref<2xf32>) attributes {test.ptr = "func"} {
+  %1 = memref.alloc() {test.ptr = "alloc"} : memref<8x64xf32>
+  memref.dealloc %1 {test.ptr = "dealloc"} : memref<8x64xf32>
+  return {test.ptr = "return"}
+}
+
+// -----
+
+// CHECK-LABEL: Testing : "simple"
+// CHECK-DAG: store -> alloc#0: Mod
+// CHECK-DAG: load -> alloc#0: Ref
+
+// CHECK-DAG: store -> func.region0#0: NoModRef
+// CHECK-DAG: load -> func.region0#0: NoModRef
+func @simple(%arg: memref<i32>, %value: i32) attributes {test.ptr = "func"} {
+  %1 = memref.alloca() {test.ptr = "alloc"} : memref<i32>
+  memref.store %value, %1[] {test.ptr = "store"} : memref<i32>
+  %2 = memref.load %1[] {test.ptr = "load"} : memref<i32>
+  return {test.ptr = "return"}
+}
+
+// -----
+
+// CHECK-LABEL: Testing : "mayalias"
+// CHECK-DAG: store -> func.region0#0: Mod
+// CHECK-DAG: load -> func.region0#0: Ref
+
+// CHECK-DAG: store -> func.region0#1: Mod
+// CHECK-DAG: load -> func.region0#1: Ref
+func @mayalias(%arg0: memref<i32>, %arg1: memref<i32>, %value: i32) attributes {test.ptr = "func"} {
+  memref.store %value, %arg1[] {test.ptr = "store"} : memref<i32>
+  %1 = memref.load %arg1[] {test.ptr = "load"} : memref<i32>
+  return {test.ptr = "return"}
+}
+
+// -----
+
+// CHECK-LABEL: Testing : "recursive"
+// CHECK-DAG: if -> func.region0#0: ModRef
+// CHECK-DAG: if -> func.region0#1: ModRef
+
+// TODO: This is provably NoModRef, but requires handling recursive side
+// effects.
+// CHECK-DAG: if -> alloc#0: ModRef
+func @recursive(%arg0: memref<i32>, %arg1: memref<i32>, %cond: i1, %value: i32) attributes {test.ptr = "func"} {
+  %0 = memref.alloca() {test.ptr = "alloc"} : memref<i32>
+  scf.if %cond {
+    memref.store %value, %arg0[] : memref<i32>
+    %1 = memref.load %arg0[] : memref<i32>
+  } {test.ptr = "if"}
+  return {test.ptr = "return"}
+}
+
+// -----
+
+// CHECK-LABEL: Testing : "unknown"
+// CHECK-DAG: unknown -> func.region0#0: ModRef
+func @unknown(%arg0: memref<i32>) attributes {test.ptr = "func"} {
+  "foo.op"() {test.ptr = "unknown"} : () -> ()
+  return
+}

diff  --git a/mlir/test/lib/Analysis/TestAliasAnalysis.cpp b/mlir/test/lib/Analysis/TestAliasAnalysis.cpp
index d17a1c1b360a..c54e5d8ba582 100644
--- a/mlir/test/lib/Analysis/TestAliasAnalysis.cpp
+++ b/mlir/test/lib/Analysis/TestAliasAnalysis.cpp
@@ -16,15 +16,38 @@
 
 using namespace mlir;
 
+/// Print a value that is used as an operand of an alias query.
+static void printAliasOperand(Operation *op) {
+  llvm::errs() << op->getAttrOfType<StringAttr>("test.ptr").getValue();
+}
+static void printAliasOperand(Value value) {
+  if (BlockArgument arg = value.dyn_cast<BlockArgument>()) {
+    Region *region = arg.getParentRegion();
+    unsigned parentBlockNumber =
+        std::distance(region->begin(), arg.getOwner()->getIterator());
+    llvm::errs() << region->getParentOp()
+                        ->getAttrOfType<StringAttr>("test.ptr")
+                        .getValue()
+                 << ".region" << region->getRegionNumber();
+    if (parentBlockNumber != 0)
+      llvm::errs() << ".block" << parentBlockNumber;
+    llvm::errs() << "#" << arg.getArgNumber();
+    return;
+  }
+  OpResult result = value.cast<OpResult>();
+  printAliasOperand(result.getOwner());
+  llvm::errs() << "#" << result.getResultNumber();
+}
+
+//===----------------------------------------------------------------------===//
+// Testing AliasResult
+//===----------------------------------------------------------------------===//
+
 namespace {
 struct TestAliasAnalysisPass
     : public PassWrapper<TestAliasAnalysisPass, OperationPass<>> {
   void runOnOperation() override {
-    llvm::errs() << "Testing : ";
-    if (Attribute testName = getOperation()->getAttr("test.name"))
-      llvm::errs() << testName << "\n";
-    else
-      llvm::errs() << getOperation()->getAttr("sym_name") << "\n";
+    llvm::errs() << "Testing : " << getOperation()->getAttr("sym_name") << "\n";
 
     // Collect all of the values to check for aliasing behavior.
     AliasAnalysis &aliasAnalysis = getAnalysis<AliasAnalysis>();
@@ -49,52 +72,64 @@ struct TestAliasAnalysisPass
     printAliasOperand(lhs);
     llvm::errs() << " <-> ";
     printAliasOperand(rhs);
-    llvm::errs() << ": ";
+    llvm::errs() << ": " << result << "\n";
+  }
+};
+} // end anonymous namespace
 
-    switch (result.getKind()) {
-    case AliasResult::NoAlias:
-      llvm::errs() << "NoAlias";
-      break;
-    case AliasResult::MayAlias:
-      llvm::errs() << "MayAlias";
-      break;
-    case AliasResult::PartialAlias:
-      llvm::errs() << "PartialAlias";
-      break;
-    case AliasResult::MustAlias:
-      llvm::errs() << "MustAlias";
-      break;
+//===----------------------------------------------------------------------===//
+// Testing ModRefResult
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct TestAliasAnalysisModRefPass
+    : public PassWrapper<TestAliasAnalysisModRefPass, OperationPass<>> {
+  void runOnOperation() override {
+    llvm::errs() << "Testing : " << getOperation()->getAttr("sym_name") << "\n";
+
+    // Collect all of the values to check for aliasing behavior.
+    AliasAnalysis &aliasAnalysis = getAnalysis<AliasAnalysis>();
+    SmallVector<Value, 32> valsToCheck;
+    getOperation()->walk([&](Operation *op) {
+      if (!op->getAttr("test.ptr"))
+        return;
+      valsToCheck.append(op->result_begin(), op->result_end());
+      for (Region &region : op->getRegions())
+        for (Block &block : region)
+          valsToCheck.append(block.args_begin(), block.args_end());
+    });
+
+    // Check the mod-ref behavior of each tagged operation on each value.
+    for (auto it = valsToCheck.begin(), e = valsToCheck.end(); it != e; ++it) {
+      getOperation()->walk([&](Operation *op) {
+        if (!op->getAttr("test.ptr"))
+          return;
+        printModRefResult(aliasAnalysis.getModRef(op, *it), op, *it);
+      });
     }
-    llvm::errs() << "\n";
   }
-  /// Print a value that is used as an operand of an alias query.
-  void printAliasOperand(Value value) {
-    if (BlockArgument arg = value.dyn_cast<BlockArgument>()) {
-      Region *region = arg.getParentRegion();
-      unsigned parentBlockNumber =
-          std::distance(region->begin(), arg.getOwner()->getIterator());
-      llvm::errs() << region->getParentOp()
-                          ->getAttrOfType<StringAttr>("test.ptr")
-                          .getValue()
-                   << ".region" << region->getRegionNumber();
-      if (parentBlockNumber != 0)
-        llvm::errs() << ".block" << parentBlockNumber;
-      llvm::errs() << "#" << arg.getArgNumber();
-      return;
-    }
-    OpResult result = value.cast<OpResult>();
-    llvm::errs()
-        << result.getOwner()->getAttrOfType<StringAttr>("test.ptr").getValue()
-        << "#" << result.getResultNumber();
+
+  /// Print the result of a ModRef query.
+  void printModRefResult(ModRefResult result, Operation *op, Value location) {
+    printAliasOperand(op);
+    llvm::errs() << " -> ";
+    printAliasOperand(location);
+    llvm::errs() << ": " << result << "\n";
   }
 };
 } // end anonymous namespace
 
+//===----------------------------------------------------------------------===//
+// Pass Registration
+//===----------------------------------------------------------------------===//
+
 namespace mlir {
 namespace test {
 void registerTestAliasAnalysisPass() {
-  PassRegistration<TestAliasAnalysisPass> pass("test-alias-analysis",
-                                               "Test alias analysis results.");
+  PassRegistration<TestAliasAnalysisPass> aliasPass(
+      "test-alias-analysis", "Test alias analysis results.");
+  PassRegistration<TestAliasAnalysisModRefPass> modRefPass(
+      "test-alias-analysis-modref", "Test alias analysis ModRef results.");
 }
 } // namespace test
 } // namespace mlir


        


More information about the Mlir-commits mailing list