[llvm-branch-commits] [mlir] a2abbc2 - test composition
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Jun 29 10:22:14 PDT 2022
Author: Mogball
Date: 2022-06-29T09:58:35-07:00
New Revision: a2abbc2ec1b00a45e446a9a19ac65868ac9ea8d1
URL: https://github.com/llvm/llvm-project/commit/a2abbc2ec1b00a45e446a9a19ac65868ac9ea8d1
DIFF: https://github.com/llvm/llvm-project/commit/a2abbc2ec1b00a45e446a9a19ac65868ac9ea8d1.diff
LOG: test composition
Added:
Modified:
mlir/test/lib/Analysis/TestDataFlowFramework.cpp
Removed:
################################################################################
diff --git a/mlir/test/lib/Analysis/TestDataFlowFramework.cpp b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
index 87b81b533dd6..6f4a1ceb8065 100644
--- a/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
+++ b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/DataFlowFramework.h"
+#include "mlir/Analysis/SparseDataFlowAnalysis.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
@@ -182,8 +183,80 @@ void TestFooAnalysisPass::runOnOperation() {
});
}
+namespace {
+struct AugmentSCP : public DataFlowAnalysis {
+ using DataFlowAnalysis::DataFlowAnalysis;
+
+ LogicalResult initialize(Operation *top) override {
+ top->walk([&](Operation *op) {
+ if (op->getName().getStringRef() == "test.scp_region")
+ (void)visit(op);
+ });
+ return success();
+ }
+
+ LogicalResult visit(ProgramPoint point) override {
+ auto *op = point.get<Operation *>();
+ assert(op->getName().getStringRef() == "test.scp_region");
+
+ auto *rhs = getOrCreateFor<ConstantValueState>(op, op->getOperand(0));
+ if (rhs->isUninitialized()) return success();
+
+ for (Region &region : op->getRegions()) {
+ for (Value value : region.getArguments()) {
+ assert(staticallyProvides(TypeID::get<ConstantValueState>(), value));
+ update<ConstantValueState>(
+ value, [rhs](ConstantValueState *lhs) { return lhs->join(*rhs); });
+ }
+ }
+ return success();
+ }
+
+ bool staticallyProvides(TypeID stateID, ProgramPoint point) const override {
+ if (stateID != TypeID::get<ConstantValueState>())
+ return false;
+
+ auto value = point.dyn_cast<Value>();
+ if (!value || !value.isa<BlockArgument>() ||
+ value.getParentBlock() != &value.getParentRegion()->front())
+ return false;
+
+ return value.getParentRegion()->getParentOp()->getName().getStringRef() ==
+ "test.scp_region";
+ }
+};
+
+struct AugmentSCPPass : public PassWrapper<AugmentSCPPass, OperationPass<>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AugmentSCPPass)
+
+ StringRef getArgument() const override { return "test-augment-scp"; }
+
+ void runOnOperation() override {
+ DataFlowSolver solver;
+ solver.load<DeadCodeAnalysis>();
+ solver.load<SparseConstantPropagation>();
+ solver.load<AugmentSCP>();
+ if (failed(solver.initializeAndRun(getOperation())))
+ return signalPassFailure();
+
+ getOperation()->walk([&](Operation *op) {
+ for (auto &result : llvm::enumerate(op->getResults())) {
+ auto *cv = solver.lookup<ConstantValueState>(result.value());
+ if (!cv || cv->isUninitialized() || !cv->getValue().getConstantValue())
+ continue;
+ llvm::errs() << "op " << op->getName() << " result #" << result.index()
+ << " -> " << cv->getValue().getConstantValue() << "\n";
+ }
+ });
+ }
+};
+} // end anonymous namespace
+
namespace mlir {
namespace test {
-void registerTestFooAnalysisPass() { PassRegistration<TestFooAnalysisPass>(); }
-} // namespace test
-} // namespace mlir
+void registerTestFooAnalysisPass() {
+ PassRegistration<TestFooAnalysisPass>();
+ PassRegistration<AugmentSCPPass>();
+}
+} // end namespace test
+} // end namespace mlir
More information about the llvm-branch-commits
mailing list