[Mlir-commits] [mlir] 0dc9087 - [mlir][Interfaces] ValueBoundsOpInterface: Compute constant bounds

Matthias Springer llvmlistbot at llvm.org
Thu Apr 6 19:39:22 PDT 2023


Author: Matthias Springer
Date: 2023-04-07T11:35:02+09:00
New Revision: 0dc9087ac752659d29fdea6b9fabdd1b7987c996

URL: https://github.com/llvm/llvm-project/commit/0dc9087ac752659d29fdea6b9fabdd1b7987c996
DIFF: https://github.com/llvm/llvm-project/commit/0dc9087ac752659d29fdea6b9fabdd1b7987c996.diff

LOG: [mlir][Interfaces] ValueBoundsOpInterface: Compute constant bounds

Add a helper function that computes a constant (`int64_t`) bound. The `stopCondition` is optional: If none is provided, the traversal continues until a constant bound could be computed.

Differential Revision: https://reviews.llvm.org/D146296

Added: 
    

Modified: 
    mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
    mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
    mlir/test/Dialect/Affine/value-bounds-reification.mlir
    mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
    mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index 97d27a04df893..a4a7c98ae3e01 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -16,6 +16,8 @@
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "llvm/ADT/SetVector.h"
 
+#include <queue>
+
 namespace mlir {
 
 using ValueDimList = SmallVector<std::pair<Value, std::optional<int64_t>>>;
@@ -100,6 +102,24 @@ class ValueBoundsConstraintSet {
                                     std::optional<int64_t> dim,
                                     StopConditionFn stopCondition);
 
+  /// Compute a constant bound for the given index-typed value or shape
+  /// dimension size.
+  ///
+  /// `dim` must be `nullopt` if and only if `value` is index-typed. This
+  /// function traverses the backward slice of the given value in a
+  /// worklist-driven manner until `stopCondition` evaluates to "true". The
+  /// constraint set is populated according to `ValueBoundsOpInterface` for each
+  /// visited value. (No constraints are added for values for which the stop
+  /// condition evaluates to "true".)
+  ///
+  /// The stop condition is optional: If none is specified, the backward slice
+  /// is traversed in a breadth-first manner until a constant bound could be
+  /// computed.
+  static FailureOr<int64_t>
+  computeConstantBound(presburger::BoundType type, Value value,
+                       std::optional<int64_t> dim = std::nullopt,
+                       StopConditionFn stopCondition = nullptr);
+
   /// Add a bound for the given index-typed value or shaped value. This function
   /// returns a builder that adds the bound.
   BoundBuilder bound(Value value) { return BoundBuilder(*this, value); }
@@ -162,7 +182,7 @@ class ValueBoundsConstraintSet {
   DenseMap<ValueDim, int64_t> valueDimToPosition;
 
   /// Worklist of values/shape dimensions that have not been processed yet.
-  SetVector<int64_t> worklist;
+  std::queue<int64_t> worklist;
 
   /// Constraint system of equalities and inequalities.
   FlatLinearConstraints cstr;

diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index 8db5e6865646b..a2885e22d01bb 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -121,7 +121,7 @@ int64_t ValueBoundsConstraintSet::insert(Value value,
   for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i)
     valueDimToPosition[positionToValueDim[i]] = i;
 
-  worklist.insert(pos);
+  worklist.push(pos);
   return pos;
 }
 
@@ -148,7 +148,8 @@ static Operation *getOwnerOfValue(Value value) {
 
 void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
   while (!worklist.empty()) {
-    int64_t pos = worklist.pop_back_val();
+    int64_t pos = worklist.front();
+    worklist.pop();
     ValueDim valueDim = positionToValueDim[pos];
     Value value = valueDim.first;
     int64_t dim = valueDim.second;
@@ -337,6 +338,33 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
   return success();
 }
 
+FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
+    presburger::BoundType type, Value value, std::optional<int64_t> dim,
+    StopConditionFn stopCondition) {
+#ifndef NDEBUG
+  assertValidValueDim(value, dim);
+#endif // NDEBUG
+
+  // Process the backward slice of `value` (i.e., reverse use-def chain) until
+  // `stopCondition` is met.
+  ValueBoundsConstraintSet cstr(value, dim);
+  int64_t pos = cstr.getPos(value, dim);
+  if (stopCondition) {
+    cstr.processWorklist(stopCondition);
+  } else {
+    // No stop condition specified: Keep adding constraints until a bound could
+    // be computed.
+    cstr.processWorklist(/*stopCondition=*/[&](Value v) {
+      return cstr.cstr.getConstantBound64(type, pos).has_value();
+    });
+  }
+
+  // Compute constant bound for `valueDim`.
+  if (auto bound = cstr.cstr.getConstantBound64(type, pos))
+    return type == BoundType::UB ? *bound + 1 : *bound;
+  return failure();
+}
+
 ValueBoundsConstraintSet::BoundBuilder &
 ValueBoundsConstraintSet::BoundBuilder::operator[](int64_t dim) {
   assert(!this->dim.has_value() && "dim was already set");

diff --git a/mlir/test/Dialect/Affine/value-bounds-reification.mlir b/mlir/test/Dialect/Affine/value-bounds-reification.mlir
index e5ee497e8e205..5b4d1f2f42c2f 100644
--- a/mlir/test/Dialect/Affine/value-bounds-reification.mlir
+++ b/mlir/test/Dialect/Affine/value-bounds-reification.mlir
@@ -26,6 +26,8 @@ func.func @reify_through_chain(%sz0: index, %sz2: index) -> (index, index, index
 // CHECK-LABEL: func @reify_slice_bound(
 //       CHECK:   %[[c5:.*]] = arith.constant 5 : index
 //       CHECK:   "test.some_use"(%[[c5]])
+//       CHECK:   %[[c5:.*]] = arith.constant 5 : index
+//       CHECK:   "test.some_use"(%[[c5]])
 func.func @reify_slice_bound(%t: tensor<?x?xi32>, %idx: index, %ub: index, %f: f32) {
   %c0 = arith.constant 0 : index
   %c4 = arith.constant 4 : index
@@ -33,8 +35,12 @@ func.func @reify_slice_bound(%t: tensor<?x?xi32>, %idx: index, %ub: index, %f: f
     %sz = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 4)>(%iv)[%ub]
     %slice = tensor.extract_slice %t[%idx, %iv] [1, %sz] [1, 1] : tensor<?x?xi32> to tensor<1x?xi32>
     %filled = linalg.fill ins(%f : f32) outs(%slice : tensor<1x?xi32>) -> tensor<1x?xi32>
+
     %bound = "test.reify_bound"(%filled) {dim = 1, type = "UB"} : (tensor<1x?xi32>) -> (index)
     "test.some_use"(%bound) : (index) -> ()
+
+    %bound_const = "test.reify_constant_bound"(%filled) {dim = 1, type = "UB"} : (tensor<1x?xi32>) -> (index)
+    "test.some_use"(%bound_const) : (index) -> ()
   }
   return
 }
@@ -77,6 +83,11 @@ func.func @reify_slice_bound2(%lb0: index, %ub0: index, %step0: index,
     %lb1_ub = "test.reify_bound"(%lb1) {type = "UB"} : (index) -> (index)
     "test.some_use"(%lb1_ub) : (index) -> ()
 
+    // CHECK: %[[c129:.*]] = arith.constant 129 : index
+    // CHECK: "test.some_use"(%[[c129]])
+    %lb1_ub_const = "test.reify_constant_bound"(%lb1) {type = "UB"} : (index) -> (index)
+    "test.some_use"(%lb1_ub_const) : (index) -> ()
+
     scf.for %iv1 = %lb1 to %ub1 step %c32 {
       // CHECK: %[[c32:.*]] = arith.constant 32 : index
       // CHECK: "test.some_use"(%[[c32]])
@@ -94,6 +105,11 @@ func.func @reify_slice_bound2(%lb0: index, %ub0: index, %step0: index,
         // CHECK: "test.some_use"(%[[c32]])
         %matmul_ub = "test.reify_bound"(%matmul) {dim = 1, type = "UB"} : (tensor<1x?xi32>) -> (index)
         "test.some_use"(%matmul_ub) : (index) -> ()
+
+        // CHECK: %[[c32:.*]] = arith.constant 32 : index
+        // CHECK: "test.some_use"(%[[c32]])
+        %matmul_ub_const = "test.reify_constant_bound"(%matmul) {dim = 1, type = "UB"} : (tensor<1x?xi32>) -> (index)
+        "test.some_use"(%matmul_ub_const) : (index) -> ()
       }
     }
   }

diff --git a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
index 576759e4f21ca..614c6014fec98 100644
--- a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
@@ -80,6 +80,27 @@ func.func @extract_slice_static(%t: tensor<?xf32>) -> index {
 
 // -----
 
+func.func @extract_slice_dynamic_constant(%t: tensor<?xf32>, %sz: index) -> index {
+  %0 = tensor.extract_slice %t[2][%sz][1] : tensor<?xf32> to tensor<?xf32>
+  // expected-error @below{{could not reify bound}}
+  %1 = "test.reify_constant_bound"(%0) {dim = 0} : (tensor<?xf32>) -> (index)
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @extract_slice_static_constant(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?xf32>
+//       CHECK:   %[[c5:.*]] = arith.constant 5 : index
+//       CHECK:   return %[[c5]]
+func.func @extract_slice_static_constant(%t: tensor<?xf32>) -> index {
+  %0 = tensor.extract_slice %t[2][5][1] : tensor<?xf32> to tensor<5xf32>
+  %1 = "test.reify_constant_bound"(%0) {dim = 0} : (tensor<5xf32>) -> (index)
+  return %1 : index
+}
+
+// -----
+
 // CHECK-LABEL: func @extract_slice_rank_reduce(
 //  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>, %[[sz:.*]]: index
 //       CHECK:   return %[[sz]]

diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
index e2a06a51f4636..7f66db3b39993 100644
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -63,7 +63,8 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
   IRRewriter rewriter(funcOp.getContext());
   WalkResult result = funcOp.walk([&](Operation *op) {
     // Look for test.reify_bound ops.
-    if (op->getName().getStringRef() == "test.reify_bound") {
+    if (op->getName().getStringRef() == "test.reify_bound" ||
+        op->getName().getStringRef() == "test.reify_constant_bound") {
       if (op->getNumOperands() != 1 || op->getNumResults() != 1 ||
           !op->getResultTypes()[0].isIndex()) {
         op->emitOpError("invalid op");
@@ -94,22 +95,37 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
                      : std::make_optional<int64_t>(
                            op->getAttrOfType<IntegerAttr>("dim").getInt());
 
-      // Reify value bound.
-      rewriter.setInsertionPointAfter(op);
-      FailureOr<OpFoldResult> reified;
-      if (!reifyToFuncArgs) {
-        // Reify in terms of the op's operands.
-        reified =
-            reifyValueBound(rewriter, op->getLoc(), *boundType, value, dim);
-      } else {
+      // Check if a constant was requested.
+      bool constant =
+          op->getName().getStringRef() == "test.reify_constant_bound";
+
+      // Prepare stop condition. By default, reify in terms of the op's
+      // operands. No stop condition is used when a constant was requested.
+      std::function<bool(Value)> stopCondition = [&](Value v) {
+        // Reify in terms of SSA values that are different from `value`.
+        return v != value;
+      };
+      if (reifyToFuncArgs) {
         // Reify in terms of function block arguments.
-        auto stopCondition = [](Value v) {
+        stopCondition = stopCondition = [](Value v) {
           auto bbArg = v.dyn_cast<BlockArgument>();
           if (!bbArg)
             return false;
           return isa<FunctionOpInterface>(
               bbArg.getParentBlock()->getParentOp());
         };
+      }
+
+      // Reify value bound
+      rewriter.setInsertionPointAfter(op);
+      FailureOr<OpFoldResult> reified = failure();
+      if (constant) {
+        auto reifiedConst = ValueBoundsConstraintSet::computeConstantBound(
+            *boundType, value, dim, /*stopCondition=*/nullptr);
+        if (succeeded(reifiedConst))
+          reified =
+              FailureOr<OpFoldResult>(rewriter.getIndexAttr(*reifiedConst));
+      } else {
         reified = reifyValueBound(rewriter, op->getLoc(), *boundType, value,
                                   dim, stopCondition);
       }


        


More information about the Mlir-commits mailing list