[flang-commits] [flang] 7814b55 - [GreedyPatternRewriter] Avoid reversing constant order

River Riddle via flang-commits flang-commits at lists.llvm.org
Wed May 18 00:56:14 PDT 2022


Author: rkayaith
Date: 2022-05-18T00:55:59-07:00
New Revision: 7814b559bd5e1dbb3c016b393068698bc5781cc5

URL: https://github.com/llvm/llvm-project/commit/7814b559bd5e1dbb3c016b393068698bc5781cc5
DIFF: https://github.com/llvm/llvm-project/commit/7814b559bd5e1dbb3c016b393068698bc5781cc5.diff

LOG: [GreedyPatternRewriter] Avoid reversing constant order

The previous fix from af371f9f98da only applied when using a bottom-up
traversal. This change applies the same constant preprocessing logic
(registering existing constants with the folder while seeding the worklist)
to the top-down case as well. This resolves the issue of the canonicalizer
pass still reordering constants, since the canonicalizer uses a top-down
traversal by default.

Fixes #51892

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D125623

Added: 
    

Modified: 
    flang/test/Lower/Intrinsics/achar.f90
    mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
    mlir/test/Dialect/Arithmetic/canonicalize.mlir
    mlir/test/Dialect/SCF/canonicalize.mlir
    mlir/test/Transforms/test-operation-folder.mlir
    mlir/test/lib/Dialect/Test/TestPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/flang/test/Lower/Intrinsics/achar.f90 b/flang/test/Lower/Intrinsics/achar.f90
index 9cf394893b9ae..924de45b2e8ac 100644
--- a/flang/test/Lower/Intrinsics/achar.f90
+++ b/flang/test/Lower/Intrinsics/achar.f90
@@ -3,8 +3,8 @@
 
 ! CHECK-LABEL: test1
 ! CHECK-SAME: (%[[XREF:.*]]: !fir.ref<i32> {{.*}}, %[[CBOX:.*]]: !fir.boxchar<1> {{.*}})
-! CHECK: %[[C1:.*]] = arith.constant 1 : index
-! CHECK: %[[FALSE:.*]] = arith.constant false
+! CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+! CHECK-DAG: %[[FALSE:.*]] = arith.constant false
 ! CHECK: %[[TEMP:.*]] = fir.alloca !fir.char<1> {adapt.valuebyref}
 ! CHECK: %[[C:.*]]:2 = fir.unboxchar %[[CBOX]] : (!fir.boxchar<1>) -> (!fir.ref<!fir.char<1,?>>, index)
 ! CHECK: %[[X:.*]] = fir.load %[[XREF]] : !fir.ref<i32>

diff  --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 0b80cd66459bd..b2e9659451018 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -133,6 +133,16 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
   };
 #endif
 
+  auto insertKnownConstant = [&](Operation *op) {
+    // Check for existing constants when populating the worklist. This avoids
+    // accidentally reversing the constant order during processing.
+    Attribute constValue;
+    if (matchPattern(op, m_Constant(&constValue)))
+      if (!folder.insertKnownConstant(op, constValue))
+        return true;
+    return false;
+  };
+
   bool changed = false;
   unsigned iteration = 0;
   do {
@@ -142,22 +152,18 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
     if (!config.useTopDownTraversal) {
       // Add operations to the worklist in postorder.
       for (auto &region : regions) {
-        region.walk([this](Operation *op) {
-          // If we aren't processing top-down, check for existing constants when
-          // populating the worklist. This avoids accidentally reversing the
-          // constant order during processing.
-          Attribute constValue;
-          if (matchPattern(op, m_Constant(&constValue)))
-            if (!folder.insertKnownConstant(op, constValue))
-              return;
-          addToWorklist(op);
+        region.walk([&](Operation *op) {
+          if (!insertKnownConstant(op))
+            addToWorklist(op);
         });
       }
     } else {
       // Add all nested operations to the worklist in preorder.
       for (auto &region : regions)
-        region.walk<WalkOrder::PreOrder>(
-            [this](Operation *op) { worklist.push_back(op); });
+        region.walk<WalkOrder::PreOrder>([&](Operation *op) {
+          if (!insertKnownConstant(op))
+            worklist.push_back(op);
+        });
 
       // Reverse the list so our pop-back loop processes them in-order.
       std::reverse(worklist.begin(), worklist.end());

diff  --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
index c560222b18eca..11e20458bf60e 100644
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -733,8 +733,8 @@ func.func @bitcastOfBitcast(%arg : i16) -> i16 {
 // -----
 
 // CHECK-LABEL: test_maxsi
-// CHECK: %[[C0:.+]] = arith.constant 42
-// CHECK: %[[MAX_INT_CST:.+]] = arith.constant 127
+// CHECK-DAG: %[[C0:.+]] = arith.constant 42
+// CHECK-DAG: %[[MAX_INT_CST:.+]] = arith.constant 127
 // CHECK: %[[X:.+]] = arith.maxsi %arg0, %[[C0]]
 // CHECK: return %arg0, %[[MAX_INT_CST]], %arg0, %[[X]]
 func.func @test_maxsi(%arg0 : i8) -> (i8, i8, i8, i8) {
@@ -749,8 +749,8 @@ func.func @test_maxsi(%arg0 : i8) -> (i8, i8, i8, i8) {
 }
 
 // CHECK-LABEL: test_maxsi2
-// CHECK: %[[C0:.+]] = arith.constant 42
-// CHECK: %[[MAX_INT_CST:.+]] = arith.constant 127
+// CHECK-DAG: %[[C0:.+]] = arith.constant 42
+// CHECK-DAG: %[[MAX_INT_CST:.+]] = arith.constant 127
 // CHECK: %[[X:.+]] = arith.maxsi %arg0, %[[C0]]
 // CHECK: return %arg0, %[[MAX_INT_CST]], %arg0, %[[X]]
 func.func @test_maxsi2(%arg0 : i8) -> (i8, i8, i8, i8) {
@@ -767,8 +767,8 @@ func.func @test_maxsi2(%arg0 : i8) -> (i8, i8, i8, i8) {
 // -----
 
 // CHECK-LABEL: test_maxui
-// CHECK: %[[C0:.+]] = arith.constant 42
-// CHECK: %[[MAX_INT_CST:.+]] = arith.constant -1
+// CHECK-DAG: %[[C0:.+]] = arith.constant 42
+// CHECK-DAG: %[[MAX_INT_CST:.+]] = arith.constant -1
 // CHECK: %[[X:.+]] = arith.maxui %arg0, %[[C0]]
 // CHECK: return %arg0, %[[MAX_INT_CST]], %arg0, %[[X]]
 func.func @test_maxui(%arg0 : i8) -> (i8, i8, i8, i8) {
@@ -783,8 +783,8 @@ func.func @test_maxui(%arg0 : i8) -> (i8, i8, i8, i8) {
 }
 
 // CHECK-LABEL: test_maxui
-// CHECK: %[[C0:.+]] = arith.constant 42
-// CHECK: %[[MAX_INT_CST:.+]] = arith.constant -1
+// CHECK-DAG: %[[C0:.+]] = arith.constant 42
+// CHECK-DAG: %[[MAX_INT_CST:.+]] = arith.constant -1
 // CHECK: %[[X:.+]] = arith.maxui %arg0, %[[C0]]
 // CHECK: return %arg0, %[[MAX_INT_CST]], %arg0, %[[X]]
 func.func @test_maxui2(%arg0 : i8) -> (i8, i8, i8, i8) {
@@ -801,8 +801,8 @@ func.func @test_maxui2(%arg0 : i8) -> (i8, i8, i8, i8) {
 // -----
 
 // CHECK-LABEL: test_minsi
-// CHECK: %[[C0:.+]] = arith.constant 42
-// CHECK: %[[MIN_INT_CST:.+]] = arith.constant -128
+// CHECK-DAG: %[[C0:.+]] = arith.constant 42
+// CHECK-DAG: %[[MIN_INT_CST:.+]] = arith.constant -128
 // CHECK: %[[X:.+]] = arith.minsi %arg0, %[[C0]]
 // CHECK: return %arg0, %arg0, %[[MIN_INT_CST]], %[[X]]
 func.func @test_minsi(%arg0 : i8) -> (i8, i8, i8, i8) {
@@ -817,8 +817,8 @@ func.func @test_minsi(%arg0 : i8) -> (i8, i8, i8, i8) {
 }
 
 // CHECK-LABEL: test_minsi
-// CHECK: %[[C0:.+]] = arith.constant 42
-// CHECK: %[[MIN_INT_CST:.+]] = arith.constant -128
+// CHECK-DAG: %[[C0:.+]] = arith.constant 42
+// CHECK-DAG: %[[MIN_INT_CST:.+]] = arith.constant -128
 // CHECK: %[[X:.+]] = arith.minsi %arg0, %[[C0]]
 // CHECK: return %arg0, %arg0, %[[MIN_INT_CST]], %[[X]]
 func.func @test_minsi2(%arg0 : i8) -> (i8, i8, i8, i8) {
@@ -835,8 +835,8 @@ func.func @test_minsi2(%arg0 : i8) -> (i8, i8, i8, i8) {
 // -----
 
 // CHECK-LABEL: test_minui
-// CHECK: %[[C0:.+]] = arith.constant 42
-// CHECK: %[[MIN_INT_CST:.+]] = arith.constant 0
+// CHECK-DAG: %[[C0:.+]] = arith.constant 42
+// CHECK-DAG: %[[MIN_INT_CST:.+]] = arith.constant 0
 // CHECK: %[[X:.+]] = arith.minui %arg0, %[[C0]]
 // CHECK: return %arg0, %arg0, %[[MIN_INT_CST]], %[[X]]
 func.func @test_minui(%arg0 : i8) -> (i8, i8, i8, i8) {
@@ -851,8 +851,8 @@ func.func @test_minui(%arg0 : i8) -> (i8, i8, i8, i8) {
 }
 
 // CHECK-LABEL: test_minui
-// CHECK: %[[C0:.+]] = arith.constant 42
-// CHECK: %[[MIN_INT_CST:.+]] = arith.constant 0
+// CHECK-DAG: %[[C0:.+]] = arith.constant 42
+// CHECK-DAG: %[[MIN_INT_CST:.+]] = arith.constant 0
 // CHECK: %[[X:.+]] = arith.minui %arg0, %[[C0]]
 // CHECK: return %arg0, %arg0, %[[MIN_INT_CST]], %[[X]]
 func.func @test_minui2(%arg0 : i8) -> (i8, i8, i8, i8) {

diff  --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 986819e7b5691..cca2f439bd70f 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1036,9 +1036,9 @@ func.func @invariant_loop_args_in_same_order(%f_arg0: tensor<i32>) -> (tensor<i3
   }
   return %0#0, %0#1, %0#2, %0#3, %0#4 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
 }
-// CHECK:    %[[CST42:.*]] = arith.constant dense<42>
-// CHECK:    %[[ONE:.*]] = arith.constant dense<1>
 // CHECK:    %[[ZERO:.*]] = arith.constant dense<0>
+// CHECK:    %[[ONE:.*]] = arith.constant dense<1>
+// CHECK:    %[[CST42:.*]] = arith.constant dense<42>
 // CHECK:    %[[WHILE:.*]]:3 = scf.while (%[[ARG0:.*]] = %[[ZERO]], %[[ARG2:.*]] = %[[ONE]], %[[ARG3:.*]] = %[[ONE]])
 // CHECK:       arith.cmpi slt, %[[ARG0]], %{{.*}}
 // CHECK:       tensor.extract %{{.*}}[]
@@ -1069,9 +1069,9 @@ func.func @while_loop_invariant_argument_different_order() -> (tensor<i32>, tens
   }
   return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
 }
-// CHECK:    %[[CST42:.*]] = arith.constant dense<42>
-// CHECK:    %[[ONE:.*]] = arith.constant dense<1>
 // CHECK:    %[[ZERO:.*]] = arith.constant dense<0>
+// CHECK:    %[[ONE:.*]] = arith.constant dense<1>
+// CHECK:    %[[CST42:.*]] = arith.constant dense<42>
 // CHECK:    %[[WHILE:.*]]:2 = scf.while (%[[ARG1:.*]] = %[[ONE]], %[[ARG4:.*]] = %[[ZERO]])
 // CHECK:       arith.cmpi slt, %[[ZERO]], %[[CST42]]
 // CHECK:       tensor.extract %{{.*}}[]

diff  --git a/mlir/test/Transforms/test-operation-folder.mlir b/mlir/test/Transforms/test-operation-folder.mlir
index 488231a226ca3..670ec232a3922 100644
--- a/mlir/test/Transforms/test-operation-folder.mlir
+++ b/mlir/test/Transforms/test-operation-folder.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt -test-patterns -test-patterns %s | FileCheck %s
+// RUN: mlir-opt -test-patterns='top-down=false' %s | FileCheck %s
+// RUN: mlir-opt -test-patterns='top-down=true' %s | FileCheck %s
 
 func.func @foo() -> i32 {
   %c42 = arith.constant 42 : i32

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 09c7a12d96dfd..264b118c8956d 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -151,6 +151,9 @@ struct TestPatternDriver
     : public PassWrapper<TestPatternDriver, OperationPass<func::FuncOp>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
 
+  TestPatternDriver() = default;
+  TestPatternDriver(const TestPatternDriver &other) : PassWrapper(other) {}
+
   StringRef getArgument() const final { return "test-patterns"; }
   StringRef getDescription() const final { return "Run test dialect patterns"; }
   void runOnOperation() override {
@@ -162,8 +165,16 @@ struct TestPatternDriver
                  FolderInsertBeforePreviouslyFoldedConstantPattern,
                  FolderCommutativeOp2WithConstant>(&getContext());
 
-    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+    GreedyRewriteConfig config;
+    config.useTopDownTraversal = this->useTopDownTraversal;
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
+                                       config);
   }
+
+  Option<bool> useTopDownTraversal{
+      *this, "top-down",
+      llvm::cl::desc("Seed the worklist in general top-down order"),
+      llvm::cl::init(GreedyRewriteConfig().useTopDownTraversal)};
 };
 } // namespace
 


        


More information about the flang-commits mailing list