[Mlir-commits] [mlir] af371f9 - Reland [GreedPatternRewriter] Preprocess constants while building worklist when not processing top down

River Riddle llvmlistbot at llvm.org
Thu Apr 7 11:32:08 PDT 2022


Author: River Riddle
Date: 2022-04-07T11:31:42-07:00
New Revision: af371f9f98dab48b6374c182f3bc6a18e4faa5fb

URL: https://github.com/llvm/llvm-project/commit/af371f9f98dab48b6374c182f3bc6a18e4faa5fb
DIFF: https://github.com/llvm/llvm-project/commit/af371f9f98dab48b6374c182f3bc6a18e4faa5fb.diff

LOG: Reland [GreedPatternRewriter] Preprocess constants while building worklist when not processing top down

Reland Note: Adds a fix to properly mark a commutative operation as folded (in place) when we change the order
             of its operands. This bug was uncovered because constants are no longer re-processed.

This avoids accidentally reversing the order of constants across successive
applications of the rewrite driver (e.g. when running the canonicalizer), which
reduces the number of iterations and avoids unnecessary changes to the input IR.

Fixes #51892

Differential Revision: https://reviews.llvm.org/D122692

Added: 
    

Modified: 
    flang/test/Lower/host-associated.f90
    mlir/include/mlir/Transforms/FoldUtils.h
    mlir/lib/Transforms/Utils/FoldUtils.cpp
    mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
    mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
    mlir/test/Dialect/Async/async-parallel-for-num-worker-threads.mlir
    mlir/test/Dialect/Linalg/detensorize_if.mlir
    mlir/test/Dialect/Linalg/transform-patterns.mlir
    mlir/test/Dialect/SparseTensor/dense.mlir
    mlir/test/Dialect/SparseTensor/sparse_scalars.mlir
    mlir/test/Dialect/Tensor/bufferize.mlir
    mlir/test/Dialect/Tensor/split-padding.mlir
    mlir/test/Transforms/test-operation-folder.mlir

Removed: 
    


################################################################################
diff  --git a/flang/test/Lower/host-associated.f90 b/flang/test/Lower/host-associated.f90
index ea8c21dcfa6d0..a2c7ef10ed58a 100644
--- a/flang/test/Lower/host-associated.f90
+++ b/flang/test/Lower/host-associated.f90
@@ -569,12 +569,12 @@ end subroutine test_proc_dummy_other
 ! CHECK-SAME:                       %[[VAL_0:.*]]: !fir.ref<!fir.char<1,40>>,
 ! CHECK-SAME:                       %[[VAL_1:.*]]: index,
 ! CHECK-SAME:                       %[[VAL_2:.*]]: tuple<!fir.boxproc<() -> ()>, i64> {fir.char_proc}) -> !fir.boxchar<1> {
-! CHECK:         %[[VAL_3:.*]] = arith.constant 40 : index
-! CHECK:         %[[VAL_4:.*]] = arith.constant 12 : index
-! CHECK:         %[[VAL_5:.*]] = arith.constant false
-! CHECK:         %[[VAL_6:.*]] = arith.constant 1 : index
-! CHECK:         %[[VAL_7:.*]] = arith.constant 32 : i8
-! CHECK:         %[[VAL_8:.*]] = arith.constant 0 : index
+! CHECK-DAG:     %[[VAL_3:.*]] = arith.constant 40 : index
+! CHECK-DAG:     %[[VAL_4:.*]] = arith.constant 12 : index
+! CHECK-DAG:     %[[VAL_5:.*]] = arith.constant false
+! CHECK-DAG:     %[[VAL_6:.*]] = arith.constant 1 : index
+! CHECK-DAG:     %[[VAL_7:.*]] = arith.constant 32 : i8
+! CHECK-DAG:     %[[VAL_8:.*]] = arith.constant 0 : index
 ! CHECK:         %[[VAL_9:.*]] = fir.convert %[[VAL_0]] : (!fir.ref<!fir.char<1,40>>) -> !fir.ref<!fir.char<1,?>>
 ! CHECK:         %[[VAL_10:.*]] = fir.address_of(@_QQcl.{{.*}}) : !fir.ref<!fir.char<1,12>>
 ! CHECK:         %[[VAL_11:.*]] = fir.extract_value %[[VAL_2]], [0 : index] : (tuple<!fir.boxproc<() -> ()>, i64>) -> !fir.boxproc<() -> ()>

diff  --git a/mlir/include/mlir/Transforms/FoldUtils.h b/mlir/include/mlir/Transforms/FoldUtils.h
index 49af0f45e1a22..e10955cfb324f 100644
--- a/mlir/include/mlir/Transforms/FoldUtils.h
+++ b/mlir/include/mlir/Transforms/FoldUtils.h
@@ -45,6 +45,16 @@ class OperationFolder {
             function_ref<void(Operation *)> preReplaceAction = nullptr,
             bool *inPlaceUpdate = nullptr);
 
+  /// Tries to fold a pre-existing constant operation. `constValue` represents
+  /// the value of the constant, and can be optionally passed if the value is
+  /// already known (e.g. if the constant was discovered by m_Constant). This is
+  /// purely an optimization opportunity for callers that already know the value
+  /// of the constant. Returns false if an existing constant for `op` already
+  /// exists in the folder, in which case `op` is replaced and erased.
+  /// Otherwise, returns true and `op` is inserted into the folder (and
+  /// hoisted if necessary).
+  bool insertKnownConstant(Operation *op, Attribute constValue = {});
+
   /// Notifies that the given constant `op` should be remove from this
   /// OperationFolder's internal bookkeeping.
   ///
@@ -114,12 +124,24 @@ class OperationFolder {
   using ConstantMap =
       DenseMap<std::tuple<Dialect *, Attribute, Type>, Operation *>;
 
+  /// Returns true if the given operation is an already folded constant that is
+  /// owned by this folder.
+  bool isFolderOwnedConstant(Operation *op) const;
+
   /// Tries to perform folding on the given `op`. If successful, populates
   /// `results` with the results of the folding.
   LogicalResult tryToFold(
       OpBuilder &builder, Operation *op, SmallVectorImpl<Value> &results,
       function_ref<void(Operation *)> processGeneratedConstants = nullptr);
 
+  /// Try to process a set of fold results, generating constants as necessary.
+  /// Populates `results` on success, otherwise leaves it unchanged.
+  LogicalResult
+  processFoldResults(OpBuilder &builder, Operation *op,
+                     SmallVectorImpl<Value> &results,
+                     ArrayRef<OpFoldResult> foldResults,
+                     function_ref<void(Operation *)> processGeneratedConstants);
+
   /// Try to get or create a new constant entry. On success this returns the
   /// constant operation, nullptr otherwise.
   Operation *tryGetOrCreateConstant(ConstantMap &uniquedConstants,

diff  --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
index 36ebdbd4b5858..fcfdbe4afab5f 100644
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -75,8 +75,14 @@ LogicalResult OperationFolder::tryToFold(
 
   // If this is a unique'd constant, return failure as we know that it has
   // already been folded.
-  if (referencedDialects.count(op))
+  if (isFolderOwnedConstant(op)) {
+    // Check to see if we should rehoist, i.e. if a non-constant operation was
+    // inserted before this one.
+    Block *opBlock = op->getBlock();
+    if (&opBlock->front() != op && !isFolderOwnedConstant(op->getPrevNode()))
+      op->moveBefore(&opBlock->front());
     return failure();
+  }
 
   // Try to fold the operation.
   SmallVector<Value, 8> results;
@@ -104,6 +110,59 @@ LogicalResult OperationFolder::tryToFold(
   return success();
 }
 
+bool OperationFolder::insertKnownConstant(Operation *op, Attribute constValue) {
+  Block *opBlock = op->getBlock();
+
+  // If this is a constant we unique'd, we don't need to insert, but we can
+  // check to see if we should rehoist it.
+  if (isFolderOwnedConstant(op)) {
+    if (&opBlock->front() != op && !isFolderOwnedConstant(op->getPrevNode()))
+      op->moveBefore(&opBlock->front());
+    return true;
+  }
+
+  // Get the constant value of the op if necessary.
+  if (!constValue) {
+    matchPattern(op, m_Constant(&constValue));
+    assert(constValue && "expected `op` to be a constant");
+  } else {
+    // Ensure that the provided constant was actually correct.
+#ifndef NDEBUG
+    Attribute expectedValue;
+    matchPattern(op, m_Constant(&expectedValue));
+    assert(
+        expectedValue == constValue &&
+        "provided constant value was not the expected value of the constant");
+#endif
+  }
+
+  // Check for an existing constant operation for the attribute value.
+  Region *insertRegion = getInsertionRegion(interfaces, opBlock);
+  auto &uniquedConstants = foldScopes[insertRegion];
+  Operation *&folderConstOp = uniquedConstants[std::make_tuple(
+      op->getDialect(), constValue, *op->result_type_begin())];
+
+  // If there is an existing constant, replace `op`.
+  if (folderConstOp) {
+    op->replaceAllUsesWith(folderConstOp);
+    op->erase();
+    return false;
+  }
+
+  // Otherwise, we insert `op`. If `op` is in the insertion block and is either
+  // already at the front of the block, or the previous operation is already a
+  // constant we unique'd (i.e. one we inserted), then we don't need to do
+  // anything. Otherwise, we move the constant to the insertion block.
+  Block *insertBlock = &insertRegion->front();
+  if (opBlock != insertBlock || (&insertBlock->front() != op &&
+                                 !isFolderOwnedConstant(op->getPrevNode())))
+    op->moveBefore(&insertBlock->front());
+
+  folderConstOp = op;
+  referencedDialects[op].push_back(op->getDialect());
+  return true;
+}
+
 /// Notifies that the given constant `op` should be remove from this
 /// OperationFolder's internal bookkeeping.
 void OperationFolder::notifyRemoval(Operation *op) {
@@ -156,19 +215,30 @@ Value OperationFolder::getOrCreateConstant(OpBuilder &builder, Dialect *dialect,
   return constOp ? constOp->getResult(0) : Value();
 }
 
+bool OperationFolder::isFolderOwnedConstant(Operation *op) const {
+  return referencedDialects.count(op);
+}
+
 /// Tries to perform folding on the given `op`. If successful, populates
 /// `results` with the results of the folding.
 LogicalResult OperationFolder::tryToFold(
     OpBuilder &builder, Operation *op, SmallVectorImpl<Value> &results,
     function_ref<void(Operation *)> processGeneratedConstants) {
   SmallVector<Attribute, 8> operandConstants;
-  SmallVector<OpFoldResult, 8> foldResults;
 
   // If this is a commutative operation, move constants to be trailing operands.
+  bool updatedOpOperands = false;
   if (op->getNumOperands() >= 2 && op->hasTrait<OpTrait::IsCommutative>()) {
-    std::stable_partition(
-        op->getOpOperands().begin(), op->getOpOperands().end(),
-        [&](OpOperand &o) { return !matchPattern(o.get(), m_Constant()); });
+    auto isNonConstant = [&](OpOperand &o) {
+      return !matchPattern(o.get(), m_Constant());
+    };
+    auto *firstConstantIt =
+        llvm::find_if_not(op->getOpOperands(), isNonConstant);
+    auto *newConstantIt = std::stable_partition(
+        firstConstantIt, op->getOpOperands().end(), isNonConstant);
+
+    // Remember if we actually moved anything.
+    updatedOpOperands = firstConstantIt != newConstantIt;
   }
 
   // Check to see if any operands to the operation is constant and whether
@@ -177,10 +247,21 @@ LogicalResult OperationFolder::tryToFold(
   for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
     matchPattern(op->getOperand(i), m_Constant(&operandConstants[i]));
 
-  // Attempt to constant fold the operation.
-  if (failed(op->fold(operandConstants, foldResults)))
-    return failure();
+  // Attempt to constant fold the operation. If we failed, check to see if we at
+  // least updated the operands of the operation. We treat this as an in-place
+  // fold.
+  SmallVector<OpFoldResult, 8> foldResults;
+  if (failed(op->fold(operandConstants, foldResults)) ||
+      failed(processFoldResults(builder, op, results, foldResults,
+                                processGeneratedConstants)))
+    return success(updatedOpOperands);
+  return success();
+}
 
+LogicalResult OperationFolder::processFoldResults(
+    OpBuilder &builder, Operation *op, SmallVectorImpl<Value> &results,
+    ArrayRef<OpFoldResult> foldResults,
+    function_ref<void(Operation *)> processGeneratedConstants) {
   // Check to see if the operation was just updated in place.
   if (foldResults.empty())
     return success();
@@ -204,8 +285,10 @@ LogicalResult OperationFolder::tryToFold(
 
     // Check if the result was an SSA value.
     if (auto repl = foldResults[i].dyn_cast<Value>()) {
-      if (repl.getType() != op->getResult(i).getType())
+      if (repl.getType() != op->getResult(i).getType()) {
+        results.clear();
         return failure();
+      }
       results.emplace_back(repl);
       continue;
     }
@@ -233,6 +316,7 @@ LogicalResult OperationFolder::tryToFold(
       notifyRemoval(&op);
       op.erase();
     }
+    results.clear();
     return failure();
   }
 

diff  --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 81b57c420a726..8043f604f9b4a 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Rewrite/PatternApplicator.h"
 #include "mlir/Transforms/FoldUtils.h"
@@ -140,8 +141,18 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
 
     if (!config.useTopDownTraversal) {
       // Add operations to the worklist in postorder.
-      for (auto &region : regions)
-        region.walk([this](Operation *op) { addToWorklist(op); });
+      for (auto &region : regions) {
+        region.walk([this](Operation *op) {
+          // If we aren't processing top-down, check for existing constants when
+          // populating the worklist. This avoids accidentally reversing the
+          // constant order during processing.
+          Attribute constValue;
+          if (matchPattern(op, m_Constant(&constValue)))
+            if (!folder.insertKnownConstant(op, constValue))
+              return;
+          addToWorklist(op);
+        });
+      }
     } else {
       // Add all nested operations to the worklist in preorder.
       for (auto &region : regions)

diff  --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
index 471d3992cf5ae..ddf6128c63a5a 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -244,9 +244,9 @@ func @transfer_read_progressive(%A : memref<?x?xf32>, %base: index) -> vector<3x
   // CHECK:     }
   // CHECK:     %[[cst:.*]] = memref.load %[[alloc]][] : memref<vector<3x15xf32>>
 
-  // FULL-UNROLL: %[[C7:.*]] = arith.constant 7.000000e+00 : f32
-  // FULL-UNROLL: %[[VEC0:.*]] = arith.constant dense<7.000000e+00> : vector<3x15xf32>
-  // FULL-UNROLL: %[[C0:.*]] = arith.constant 0 : index
+  // FULL-UNROLL-DAG: %[[C7:.*]] = arith.constant 7.000000e+00 : f32
+  // FULL-UNROLL-DAG: %[[VEC0:.*]] = arith.constant dense<7.000000e+00> : vector<3x15xf32>
+  // FULL-UNROLL-DAG: %[[C0:.*]] = arith.constant 0 : index
   // FULL-UNROLL: %[[DIM:.*]] = memref.dim %[[A]], %[[C0]] : memref<?x?xf32>
   // FULL-UNROLL: cmpi sgt, %[[DIM]], %[[base]] : index
   // FULL-UNROLL: %[[VEC1:.*]] = scf.if %{{.*}} -> (vector<3x15xf32>) {

diff  --git a/mlir/test/Dialect/Async/async-parallel-for-num-worker-threads.mlir b/mlir/test/Dialect/Async/async-parallel-for-num-worker-threads.mlir
index 8272316e0998e..5b660d5300b55 100644
--- a/mlir/test/Dialect/Async/async-parallel-for-num-worker-threads.mlir
+++ b/mlir/test/Dialect/Async/async-parallel-for-num-worker-threads.mlir
@@ -5,33 +5,33 @@
 // CHECK:       %[[MEMREF:.*]]: memref<?xf32>
 func @num_worker_threads(%arg0: memref<?xf32>) {
 
-  // CHECK: %[[scalingCstInit:.*]] = arith.constant 8.000000e+00 : f32
-  // CHECK: %[[bracketLowerBound4:.*]] = arith.constant 4 : index
-  // CHECK: %[[scalingCst4:.*]] = arith.constant 4.000000e+00 : f32
-  // CHECK: %[[bracketLowerBound8:.*]] = arith.constant 8 : index
-  // CHECK: %[[scalingCst8:.*]] = arith.constant 2.000000e+00 : f32
-  // CHECK: %[[bracketLowerBound16:.*]] = arith.constant 16 : index
-  // CHECK: %[[scalingCst16:.*]] = arith.constant 1.000000e+00 : f32
-  // CHECK: %[[bracketLowerBound32:.*]] = arith.constant 32 : index
-  // CHECK: %[[scalingCst32:.*]] = arith.constant 8.000000e-01 : f32
-  // CHECK: %[[bracketLowerBound64:.*]] = arith.constant 64 : index
-  // CHECK: %[[scalingCst64:.*]] = arith.constant 6.000000e-01 : f32
-  // CHECK:   %[[workersIndex:.*]] = async.runtime.num_worker_threads : index
-  // CHECK:   %[[inBracket4:.*]] = arith.cmpi sgt, %[[workersIndex]], %[[bracketLowerBound4]] : index
-  // CHECK:   %[[scalingFactor4:.*]] = arith.select %[[inBracket4]], %[[scalingCst4]], %[[scalingCstInit]] : f32
-  // CHECK:   %[[inBracket8:.*]] = arith.cmpi sgt, %[[workersIndex]], %[[bracketLowerBound8]] : index
-  // CHECK:   %[[scalingFactor8:.*]] = arith.select %[[inBracket8]], %[[scalingCst8]], %[[scalingFactor4]] : f32
-  // CHECK:   %[[inBracket16:.*]] = arith.cmpi sgt, %[[workersIndex]], %[[bracketLowerBound16]] : index
-  // CHECK:   %[[scalingFactor16:.*]] = arith.select %[[inBracket16]], %[[scalingCst16]], %[[scalingFactor8]] : f32
-  // CHECK:   %[[inBracket32:.*]] = arith.cmpi sgt, %[[workersIndex]], %[[bracketLowerBound32]] : index
-  // CHECK:   %[[scalingFactor32:.*]] = arith.select %[[inBracket32]], %[[scalingCst32]], %[[scalingFactor16]] : f32
-  // CHECK:   %[[inBracket64:.*]] = arith.cmpi sgt, %[[workersIndex]], %[[bracketLowerBound64]] : index
-  // CHECK:   %[[scalingFactor64:.*]] = arith.select %[[inBracket64]], %[[scalingCst64]], %[[scalingFactor32]] : f32
-  // CHECK:   %[[workersInt:.*]] = arith.index_cast %[[workersIndex]] : index to i32
-  // CHECK:   %[[workersFloat:.*]] = arith.sitofp %[[workersInt]] : i32 to f32
-  // CHECK:   %[[scaledFloat:.*]] = arith.mulf %[[scalingFactor64]], %[[workersFloat]] : f32
-  // CHECK:   %[[scaledInt:.*]] = arith.fptosi %[[scaledFloat]] : f32 to i32
-  // CHECK:   %[[scaledIndex:.*]] = arith.index_cast %[[scaledInt]] : i32 to index
+  // CHECK-DAG: %[[scalingCstInit:.*]] = arith.constant 8.000000e+00 : f32
+  // CHECK-DAG: %[[bracketLowerBound4:.*]] = arith.constant 4 : index
+  // CHECK-DAG: %[[scalingCst4:.*]] = arith.constant 4.000000e+00 : f32
+  // CHECK-DAG: %[[bracketLowerBound8:.*]] = arith.constant 8 : index
+  // CHECK-DAG: %[[scalingCst8:.*]] = arith.constant 2.000000e+00 : f32
+  // CHECK-DAG: %[[bracketLowerBound16:.*]] = arith.constant 16 : index
+  // CHECK-DAG: %[[scalingCst16:.*]] = arith.constant 1.000000e+00 : f32
+  // CHECK-DAG: %[[bracketLowerBound32:.*]] = arith.constant 32 : index
+  // CHECK-DAG: %[[scalingCst32:.*]] = arith.constant 8.000000e-01 : f32
+  // CHECK-DAG: %[[bracketLowerBound64:.*]] = arith.constant 64 : index
+  // CHECK-DAG: %[[scalingCst64:.*]] = arith.constant 6.000000e-01 : f32
+  // CHECK:     %[[workersIndex:.*]] = async.runtime.num_worker_threads : index
+  // CHECK:     %[[inBracket4:.*]] = arith.cmpi sgt, %[[workersIndex]], %[[bracketLowerBound4]] : index
+  // CHECK:     %[[scalingFactor4:.*]] = arith.select %[[inBracket4]], %[[scalingCst4]], %[[scalingCstInit]] : f32
+  // CHECK:     %[[inBracket8:.*]] = arith.cmpi sgt, %[[workersIndex]], %[[bracketLowerBound8]] : index
+  // CHECK:     %[[scalingFactor8:.*]] = arith.select %[[inBracket8]], %[[scalingCst8]], %[[scalingFactor4]] : f32
+  // CHECK:     %[[inBracket16:.*]] = arith.cmpi sgt, %[[workersIndex]], %[[bracketLowerBound16]] : index
+  // CHECK:     %[[scalingFactor16:.*]] = arith.select %[[inBracket16]], %[[scalingCst16]], %[[scalingFactor8]] : f32
+  // CHECK:     %[[inBracket32:.*]] = arith.cmpi sgt, %[[workersIndex]], %[[bracketLowerBound32]] : index
+  // CHECK:     %[[scalingFactor32:.*]] = arith.select %[[inBracket32]], %[[scalingCst32]], %[[scalingFactor16]] : f32
+  // CHECK:     %[[inBracket64:.*]] = arith.cmpi sgt, %[[workersIndex]], %[[bracketLowerBound64]] : index
+  // CHECK:     %[[scalingFactor64:.*]] = arith.select %[[inBracket64]], %[[scalingCst64]], %[[scalingFactor32]] : f32
+  // CHECK:     %[[workersInt:.*]] = arith.index_cast %[[workersIndex]] : index to i32
+  // CHECK:     %[[workersFloat:.*]] = arith.sitofp %[[workersInt]] : i32 to f32
+  // CHECK:     %[[scaledFloat:.*]] = arith.mulf %[[scalingFactor64]], %[[workersFloat]] : f32
+  // CHECK:     %[[scaledInt:.*]] = arith.fptosi %[[scaledFloat]] : f32 to i32
+  // CHECK:     %[[scaledIndex:.*]] = arith.index_cast %[[scaledInt]] : i32 to index
 
   %lb = arith.constant 0 : index
   %ub = arith.constant 100 : index

diff  --git a/mlir/test/Dialect/Linalg/detensorize_if.mlir b/mlir/test/Dialect/Linalg/detensorize_if.mlir
index 1fcd7f6c7a8a6..2cc282e422654 100644
--- a/mlir/test/Dialect/Linalg/detensorize_if.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_if.mlir
@@ -42,9 +42,9 @@ func @main() -> (tensor<i32>) attributes {} {
 }
 
 // CHECK-LABEL:  func @main()
-// CHECK-NEXT:     arith.constant 0
-// CHECK-NEXT:     arith.constant 10
-// CHECK-NEXT:     cf.br ^[[bb1:.*]](%{{.*}}: i32)
+// CHECK-DAG:     arith.constant 0
+// CHECK-DAG:     arith.constant 10
+// CHECK:         cf.br ^[[bb1:.*]](%{{.*}}: i32)
 // CHECK-NEXT:   ^[[bb1]](%{{.*}}: i32):
 // CHECK-NEXT:     arith.cmpi slt, %{{.*}}, %{{.*}}
 // CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb3(%{{.*}} : i32)
@@ -106,9 +106,9 @@ func @main() -> (tensor<i32>) attributes {} {
 }
 
 // CHECK-LABEL:  func @main()
-// CHECK-NEXT:     arith.constant 0
-// CHECK-NEXT:     arith.constant 10
-// CHECK-NEXT:     cf.br ^[[bb1:.*]](%{{.*}}: i32)
+// CHECK-DAG:     arith.constant 0
+// CHECK-DAG:     arith.constant 10
+// CHECK:         cf.br ^[[bb1:.*]](%{{.*}}: i32)
 // CHECK-NEXT:   ^[[bb1]](%{{.*}}: i32):
 // CHECK-NEXT:     arith.cmpi slt, %{{.*}}, %{{.*}}
 // CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb3(%{{.*}} : i32)
@@ -171,9 +171,9 @@ func @main() -> (tensor<i32>) attributes {} {
 }
 
 // CHECK-LABEL:  func @main()
-// CHECK-NEXT:     arith.constant 0
-// CHECK-NEXT:     arith.constant 10
-// CHECK-NEXT:     cf.br ^[[bb1:.*]](%{{.*}}: i32)
+// CHECK-DAG:     arith.constant 0
+// CHECK-DAG:     arith.constant 10
+// CHECK:         cf.br ^[[bb1:.*]](%{{.*}}: i32)
 // CHECK-NEXT:   ^[[bb1]](%{{.*}}: i32):
 // CHECK-NEXT:     arith.cmpi slt, %{{.*}}, %{{.*}}
 // CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb2(%{{.*}} : i32)

diff  --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
index 522422526906f..4cbc3d52486c1 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -301,7 +301,7 @@ func @aligned_promote_fill(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
   return
 }
 // CHECK-LABEL: func @aligned_promote_fill
-// CHECK:	  %[[cf:.*]] = arith.constant {{.*}} : f32
+// CHECK:	  %[[cf:.*]] = arith.constant 1.{{.*}} : f32
 // CHECK:         %[[s0:.*]] = memref.subview {{.*}}: memref<?x?xf32, #map{{.*}}> to memref<?x?xf32, #map{{.*}}>
 // CHECK:         %[[a0:.*]] = memref.alloc() {alignment = 32 : i64} : memref<32000000xi8>
 // CHECK:         %[[v0:.*]] = memref.view %[[a0]]{{.*}} : memref<32000000xi8> to memref<?x?xf32>

diff  --git a/mlir/test/Dialect/SparseTensor/dense.mlir b/mlir/test/Dialect/SparseTensor/dense.mlir
index 25c60724f5bc7..5e9b587b39b28 100644
--- a/mlir/test/Dialect/SparseTensor/dense.mlir
+++ b/mlir/test/Dialect/SparseTensor/dense.mlir
@@ -78,11 +78,11 @@ func @dense1(%arga: tensor<32x16xf32, #DenseMatrix>,
 // CHECK-LABEL:   func @dense2(
 // CHECK-SAME:                 %[[VAL_0:.*]]: tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>>,
 // CHECK-SAME:                 %[[VAL_1:.*]]: tensor<32x16xf32> {linalg.inplaceable = true}) -> tensor<32x16xf32> {
-// CHECK:           %[[VAL_2:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK:           %[[VAL_3:.*]] = arith.constant 32 : index
-// CHECK:           %[[VAL_4:.*]] = arith.constant 16 : index
-// CHECK:           %[[VAL_5:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_6:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 32 : index
+// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 16 : index
+// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 1 : index
 // CHECK:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xf32>
 // CHECK:           %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16xf32>
 // CHECK:           scf.for %[[VAL_9:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] {

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_scalars.mlir b/mlir/test/Dialect/SparseTensor/sparse_scalars.mlir
index 9bbc16a1c0e10..ea654ce8cb72b 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_scalars.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_scalars.mlir
@@ -24,9 +24,9 @@
 // CHECK-SAME:              %[[VAL_2:.*2]]: f32,
 // CHECK-SAME:              %[[VAL_3:.*3]]: f32,
 // CHECK-SAME:              %[[VAL_4:.*4]]: tensor<32x16xf32> {linalg.inplaceable = true}) -> tensor<32x16xf32> {
-// CHECK:           %[[VAL_5:.*]] = arith.constant 2.200000e+00 : f32
-// CHECK:           %[[VAL_6:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_7:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 2.200000e+00 : f32
+// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 1 : index
 // CHECK:           %[[VAL_8:.*]] = arith.addf %[[VAL_2]], %[[VAL_3]] : f32
 // CHECK:           %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_6]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
 // CHECK:           %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_6]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>

diff  --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index b29ec0201f09f..da27b9c80b6e5 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -183,9 +183,9 @@ func @tensor.from_elements_3d(%f0 : f32) -> tensor<3x2x2xf32> {
 // CHECK-LABEL:   func @tensor.generate(
 // CHECK-SAME:                                       %[[ARG:.*]]: tensor<*xf32>,
 // CHECK-SAME:                                       %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<?xindex> {
-// CHECK:           %[[CASTED:.*]] = bufferization.to_memref %[[ARG]] : memref<*xf32>
 // CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[CASTED:.*]] = bufferization.to_memref %[[ARG]] : memref<*xf32>
 // CHECK:           %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) {{.*}} : memref<?xindex>
 // CHECK:           scf.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[DYNAMIC_EXTENT]]) step (%[[C1]]) {
 // CHECK:             %[[ELEM:.*]] = memref.dim %[[CASTED]], %[[I]] : memref<*xf32>

diff  --git a/mlir/test/Dialect/Tensor/split-padding.mlir b/mlir/test/Dialect/Tensor/split-padding.mlir
index 40d186c678c4d..730cd63ed14cf 100644
--- a/mlir/test/Dialect/Tensor/split-padding.mlir
+++ b/mlir/test/Dialect/Tensor/split-padding.mlir
@@ -27,8 +27,8 @@ func @pad_non_zero_sizes(%input: tensor<?x?x8xf32>, %low0: index, %high1: index)
   return %0 : tensor<?x?x8xf32>
 }
 
-// CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
 // CHECK: %[[EQ0:.+]] = arith.cmpi eq, %[[LOW0]], %[[C0]] : index
 // CHECK: %[[EQ1:.+]] = arith.cmpi eq, %[[HIGH1]], %[[C0]] : index
 // CHECK: %[[AND:.+]] = arith.andi %[[EQ0]], %[[EQ1]] : i1

diff  --git a/mlir/test/Transforms/test-operation-folder.mlir b/mlir/test/Transforms/test-operation-folder.mlir
index 76b529dd35b61..23b80c4e95aa5 100644
--- a/mlir/test/Transforms/test-operation-folder.mlir
+++ b/mlir/test/Transforms/test-operation-folder.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-patterns %s | FileCheck %s
+// RUN: mlir-opt -test-patterns -test-patterns %s | FileCheck %s
 
 func @foo() -> i32 {
   %c42 = arith.constant 42 : i32
@@ -22,3 +22,14 @@ func @test_fold_before_previously_folded_op() -> (i32, i32) {
   %1 = "test.cast"() {test_fold_before_previously_folded_op} : () -> (i32)
   return %0, %1 : i32, i32
 }
+
+func @test_dont_reorder_constants() -> (i32, i32, i32) {
+  // Test that we don't reorder existing constants during folding if it isn't necessary.
+  // CHECK: %[[CST:.+]] = arith.constant 1
+  // CHECK-NEXT: %[[CST:.+]] = arith.constant 2
+  // CHECK-NEXT: %[[CST:.+]] = arith.constant 3
+  %0 = arith.constant 1 : i32
+  %1 = arith.constant 2 : i32
+  %2 = arith.constant 3 : i32
+  return %0, %1, %2 : i32, i32, i32
+}


        


More information about the Mlir-commits mailing list