[llvm-branch-commits] [mlir] a2abbc2 - test composition

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Jun 29 10:22:14 PDT 2022


Author: Mogball
Date: 2022-06-29T09:58:35-07:00
New Revision: a2abbc2ec1b00a45e446a9a19ac65868ac9ea8d1

URL: https://github.com/llvm/llvm-project/commit/a2abbc2ec1b00a45e446a9a19ac65868ac9ea8d1
DIFF: https://github.com/llvm/llvm-project/commit/a2abbc2ec1b00a45e446a9a19ac65868ac9ea8d1.diff

LOG: test composition

Added: 
    

Modified: 
    mlir/test/lib/Analysis/TestDataFlowFramework.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/test/lib/Analysis/TestDataFlowFramework.cpp b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
index 87b81b533dd6..6f4a1ceb8065 100644
--- a/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
+++ b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Analysis/DataFlowFramework.h"
+#include "mlir/Analysis/SparseDataFlowAnalysis.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Pass/Pass.h"
 
@@ -182,8 +183,117 @@ void TestFooAnalysisPass::runOnOperation() {
   });
 }
 
+/// Returns true if `op` is the `test.scp_region` test op handled by
+/// AugmentSCP.
+static bool isSCPRegionOp(Operation *op) {
+  return op->getName().getStringRef() == "test.scp_region";
+}
+
+namespace {
+/// A test analysis that composes with SparseConstantPropagation. For every
+/// `test.scp_region` op it joins the ConstantValueState of the op's first
+/// operand into the ConstantValueState of the arguments of each of the op's
+/// regions.
+struct AugmentSCP : public DataFlowAnalysis {
+  using DataFlowAnalysis::DataFlowAnalysis;
+
+  /// Visits every `test.scp_region` op nested under `top`. Stops at the first
+  /// visit that fails and returns failure instead of discarding the result.
+  LogicalResult initialize(Operation *top) override {
+    WalkResult result = top->walk([&](Operation *op) {
+      if (!isSCPRegionOp(op))
+        return WalkResult::advance();
+      return failed(visit(op)) ? WalkResult::interrupt()
+                               : WalkResult::advance();
+    });
+    return failure(result.wasInterrupted());
+  }
+
+  /// Joins the state of the first operand of the `test.scp_region` op at
+  /// `point` into the state of each of its region arguments. Does nothing
+  /// while that operand's state is still uninitialized.
+  LogicalResult visit(ProgramPoint point) override {
+    auto *op = point.get<Operation *>();
+    assert(isSCPRegionOp(op) && "expected a test.scp_region op");
+
+    // Guard getOperand(0): an operand-less op would index out of bounds.
+    if (op->getNumOperands() == 0)
+      return op->emitError("expected at least one operand");
+
+    // `op` is passed as the dependent point; presumably the solver then
+    // revisits `op` whenever the operand's state changes -- confirm in
+    // DataFlowFramework.h.
+    auto *rhs = getOrCreateFor<ConstantValueState>(op, op->getOperand(0));
+    if (rhs->isUninitialized())
+      return success();
+
+    for (Region &region : op->getRegions()) {
+      for (Value arg : region.getArguments()) {
+        assert(staticallyProvides(TypeID::get<ConstantValueState>(), arg));
+        update<ConstantValueState>(
+            arg, [rhs](ConstantValueState *lhs) { return lhs->join(*rhs); });
+      }
+    }
+    return success();
+  }
+
+  /// Returns true only for ConstantValueState on an entry-block argument of
+  /// a region owned by a `test.scp_region` op. NOTE(review): presumably this
+  /// tells the solver that AugmentSCP, rather than SparseConstantPropagation,
+  /// computes those states -- confirm in DataFlowFramework.h.
+  bool staticallyProvides(TypeID stateID, ProgramPoint point) const override {
+    if (stateID != TypeID::get<ConstantValueState>())
+      return false;
+
+    auto value = point.dyn_cast<Value>();
+    if (!value || !value.isa<BlockArgument>() ||
+        value.getParentBlock() != &value.getParentRegion()->front())
+      return false;
+
+    return isSCPRegionOp(value.getParentRegion()->getParentOp());
+  }
+};
+
+/// Test pass that runs DeadCodeAnalysis, SparseConstantPropagation and
+/// AugmentSCP in one DataFlowSolver, then prints
+/// `op <name> result #<index> -> <constant>` to stderr for every op result
+/// whose ConstantValueState holds a constant attribute.
+struct AugmentSCPPass : public PassWrapper<AugmentSCPPass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AugmentSCPPass)
+
+  StringRef getArgument() const override { return "test-augment-scp"; }
+
+  void runOnOperation() override {
+    DataFlowSolver solver;
+    solver.load<DeadCodeAnalysis>();
+    solver.load<SparseConstantPropagation>();
+    solver.load<AugmentSCP>();
+    if (failed(solver.initializeAndRun(getOperation())))
+      return signalPassFailure();
+
+    getOperation()->walk([&](Operation *op) {
+      // Take each enumerate element by value: `auto &` needs the iterator to
+      // yield an lvalue reference, which newer llvm::enumerate does not.
+      for (auto en : llvm::enumerate(op->getResults())) {
+        auto *cv = solver.lookup<ConstantValueState>(en.value());
+        if (!cv || cv->isUninitialized())
+          continue;
+        Attribute constant = cv->getValue().getConstantValue();
+        if (!constant)
+          continue;
+        llvm::errs() << "op " << op->getName() << " result #" << en.index()
+                     << " -> " << constant << "\n";
+      }
+    });
+  }
+};
+} // namespace
+
 namespace mlir {
 namespace test {
-void registerTestFooAnalysisPass() { PassRegistration<TestFooAnalysisPass>(); }
+void registerTestFooAnalysisPass() {
+  PassRegistration<TestFooAnalysisPass>();
+  PassRegistration<AugmentSCPPass>();
+}
 } // namespace test
 } // namespace mlir


        


More information about the llvm-branch-commits mailing list