[Mlir-commits] [mlir] 041bc48 - [mlir][Interfaces] ValueBoundsOpInterface: Support LB and UB bounds

Matthias Springer llvmlistbot at llvm.org
Thu Apr 6 18:53:20 PDT 2023


Author: Matthias Springer
Date: 2023-04-07T10:48:19+09:00
New Revision: 041bc485bf2122b238eb1a336d3a38168feb8eaa

URL: https://github.com/llvm/llvm-project/commit/041bc485bf2122b238eb1a336d3a38168feb8eaa
DIFF: https://github.com/llvm/llvm-project/commit/041bc485bf2122b238eb1a336d3a38168feb8eaa.diff

LOG: [mlir][Interfaces] ValueBoundsOpInterface: Support LB and UB bounds

This change also adds support for `affine.min` and `affine.max` ops.

Differential Revision: https://reviews.llvm.org/D145787

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
    mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
    mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
    mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
    mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
index 2abd329b51620..e957fb1cd0dbd 100644
--- a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
@@ -51,14 +51,15 @@ FailureOr<AffineApplyOp> decompose(RewriterBase &rewriter, AffineApplyOp op);
 
 /// Reify a bound for the given index-typed value or shape dimension size in
 /// terms of the owning op's operands. `dim` must be `nullopt` if and only if
-/// `value` is index-typed.
+/// `value` is index-typed. LB and EQ bounds are closed, UB bounds are open.
 FailureOr<OpFoldResult> reifyValueBound(OpBuilder &b, Location loc,
                                         presburger::BoundType type, Value value,
                                         std::optional<int64_t> dim);
 
 /// Reify a bound for the given index-typed value or shape dimension size in
 /// terms of SSA values for which `stopCondition` is met. `dim` must be
-/// `nullopt` if and only if `value` is index-typed.
+/// `nullopt` if and only if `value` is index-typed. LB and EQ bounds are
+/// closed, UB bounds are open.
 ///
 /// Example:
 /// %0 = arith.addi %a, %b : index

diff  --git a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
index ed8b9a74e2e77..0036023a8a015 100644
--- a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -38,6 +38,48 @@ struct AffineApplyOpInterface
   }
 };
 
+struct AffineMinOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<AffineMinOpInterface,
+                                                   AffineMinOp> {
+  /// affine.min: the result is <= each result expression of the affine map.
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    auto minOp = cast<AffineMinOp>(op);
+    assert(value == minOp.getResult() && "invalid value");
+    // Operand exprs are loop-invariant; compute them once, not per result.
+    auto toExpr = [&](Value v) { return cstr.getExpr(v); };
+    SmallVector<AffineExpr> dimReplacements =
+        llvm::to_vector(llvm::map_range(minOp.getDimOperands(), toExpr));
+    SmallVector<AffineExpr> symReplacements =
+        llvm::to_vector(llvm::map_range(minOp.getSymbolOperands(), toExpr));
+    // Align each map result with the dims/symbols of the constraint set.
+    for (AffineExpr expr : minOp.getAffineMap().getResults())
+      cstr.bound(value) <=
+          expr.replaceDimsAndSymbols(dimReplacements, symReplacements);
+  }
+};
+
+struct AffineMaxOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<AffineMaxOpInterface,
+                                                   AffineMaxOp> {
+  /// affine.max: the result is >= each result expression of the affine map.
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    auto maxOp = cast<AffineMaxOp>(op);
+    assert(value == maxOp.getResult() && "invalid value");
+    // Operand exprs are loop-invariant; compute them once, not per result.
+    auto toExpr = [&](Value v) { return cstr.getExpr(v); };
+    SmallVector<AffineExpr> dimReplacements =
+        llvm::to_vector(llvm::map_range(maxOp.getDimOperands(), toExpr));
+    SmallVector<AffineExpr> symReplacements =
+        llvm::to_vector(llvm::map_range(maxOp.getSymbolOperands(), toExpr));
+    // Align each map result with the dims/symbols of the constraint set.
+    for (AffineExpr expr : maxOp.getAffineMap().getResults())
+      cstr.bound(value) >=
+          expr.replaceDimsAndSymbols(dimReplacements, symReplacements);
+  }
+};
+
 } // namespace
 } // namespace mlir
 
@@ -45,5 +87,7 @@ void mlir::affine::registerValueBoundsOpInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, AffineDialect *dialect) {
     AffineApplyOp::attachInterface<AffineApplyOpInterface>(*ctx);
+    AffineMaxOp::attachInterface<AffineMaxOpInterface>(*ctx);
+    AffineMinOp::attachInterface<AffineMinOpInterface>(*ctx);
   });
 }

diff  --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index 73757b611520a..8db5e6865646b 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -214,9 +214,6 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
   assertValidValueDim(value, dim);
 #endif // NDEBUG
 
-  // Only EQ bounds are supported at the moment.
-  assert(type == BoundType::EQ && "unsupported bound type");
-
   Builder b(value.getContext());
   mapOperands.clear();
 
@@ -249,16 +246,39 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
   SmallVector<AffineMap> lb(1), ub(1);
   cstr.cstr.getSliceBounds(pos, 1, value.getContext(), &lb, &ub,
                            /*getClosedUB=*/true);
+
   // Note: There are TODOs in the implementation of `getSliceBounds`. In such a
   // case, no lower/upper bound can be computed at the moment.
-  if (lb.empty() || !lb[0] || ub.empty() || !ub[0] ||
-      lb[0].getNumResults() != 1 || ub[0].getNumResults() != 1)
+  // EQ, UB bounds: upper bound is needed.
+  if ((type != BoundType::LB) &&
+      (ub.empty() || !ub[0] || ub[0].getNumResults() == 0))
     return failure();
+  // EQ, LB bounds: lower bound is needed.
+  if ((type != BoundType::UB) &&
+      (lb.empty() || !lb[0] || lb[0].getNumResults() == 0))
+    return failure();
+
+  // TODO: Generate an affine map with multiple results.
+  if (type != BoundType::LB)
+    assert(ub.size() == 1 && ub[0].getNumResults() == 1 &&
+           "multiple bounds not supported");
+  if (type != BoundType::UB)
+    assert(lb.size() == 1 && lb[0].getNumResults() == 1 &&
+           "multiple bounds not supported");
 
-  // Look for same lower and upper bound: EQ bound.
-  if (ub[0] != lb[0])
+  // EQ bound: lower and upper bound must match.
+  if (type == BoundType::EQ && ub[0] != lb[0])
     return failure();
 
+  AffineMap bound;
+  if (type == BoundType::EQ || type == BoundType::LB) {
+    bound = lb[0];
+  } else {
+    // Computed UB is a closed bound. Turn into an open bound.
+    bound = AffineMap::get(ub[0].getNumDims(), ub[0].getNumSymbols(),
+                           ub[0].getResult(0) + 1);
+  }
+
   // Gather all SSA values that are used in the computed bound.
   assert(cstr.cstr.getNumDimAndSymbolVars() == cstr.positionToValueDim.size() &&
          "inconsistent mapping state");
@@ -273,10 +293,10 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
     bool used = false;
     bool isDim = i < cstr.cstr.getNumDimVars();
     if (isDim) {
-      if (lb[0].isFunctionOfDim(i))
+      if (bound.isFunctionOfDim(i))
         used = true;
     } else {
-      if (lb[0].isFunctionOfSymbol(i - cstr.cstr.getNumDimVars()))
+      if (bound.isFunctionOfSymbol(i - cstr.cstr.getNumDimVars()))
         used = true;
     }
 
@@ -312,7 +332,7 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
     mapOperands.push_back(std::make_pair(value, dim));
   }
 
-  resultMap = lb[0].replaceDimsAndSymbols(replacementDims, replacementSymbols,
+  resultMap = bound.replaceDimsAndSymbols(replacementDims, replacementSymbols,
                                           numDims, numSymbols);
   return success();
 }

diff  --git a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
index 436ec4a8326cb..338c48c5b210b 100644
--- a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
@@ -12,3 +12,49 @@ func.func @affine_apply(%a: index, %b: index) -> index {
   %1 = "test.reify_bound"(%0) : (index) -> (index)
   return %1 : index
 }
+
+// -----
+
+// CHECK-LABEL: func @affine_max_lb(
+//  CHECK-SAME:     %[[a:.*]]: index
+//       CHECK:   %[[c2:.*]] = arith.constant 2 : index
+//       CHECK:   return %[[c2]]
+func.func @affine_max_lb(%a: index) -> (index) {
+  // Note: Both s0 and 2 are valid (closed) LBs. FlatAffineValueConstraints
+  // currently always returns the constant one, so the reified LB is 2.
+  %1 = affine.max affine_map<()[s0] -> (s0, 2)>()[%a]
+  %2 = "test.reify_bound"(%1) {type = "LB"}: (index) -> (index)
+  return %2 : index
+}
+
+// -----
+
+func.func @affine_max_ub(%a: index) -> (index) {
+  %1 = affine.max affine_map<()[s0] -> (s0, 2)>()[%a]
+  // expected-error @below{{could not reify bound}}
+  %2 = "test.reify_bound"(%1) {type = "UB"}: (index) -> (index)
+  return %2 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @affine_min_ub(
+//  CHECK-SAME:     %[[a:.*]]: index
+//       CHECK:   %[[c3:.*]] = arith.constant 3 : index
+//       CHECK:   return %[[c3]]
+func.func @affine_min_ub(%a: index) -> (index) {
+  // Note: There are two open UBs: s0 + 1 and 3. FlatAffineValueConstraints
+  // currently always returns the constant one, so the reified UB is 3.
+  %1 = affine.min affine_map<()[s0] -> (s0, 2)>()[%a]
+  %2 = "test.reify_bound"(%1) {type = "UB"}: (index) -> (index)
+  return %2 : index
+}
+
+// -----
+
+func.func @affine_min_lb(%a: index) -> (index) {
+  %1 = affine.min affine_map<()[s0] -> (s0, 2)>()[%a]
+  // expected-error @below{{could not reify bound}}
+  %2 = "test.reify_bound"(%1) {type = "LB"}: (index) -> (index)
+  return %2 : index
+}

diff  --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
index 582831565aa73..e2a06a51f4636 100644
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -16,6 +16,7 @@
 #define PASS_NAME "test-affine-reify-value-bounds"
 
 using namespace mlir;
+using mlir::presburger::BoundType;
 
 namespace {
 
@@ -45,6 +46,16 @@ struct TestReifyValueBounds
 
 } // namespace
 
+/// Parses "EQ", "LB" or "UB" into a BoundType; fails on any other string.
+static FailureOr<BoundType> parseBoundType(StringRef type) {
+  if (type == "EQ")
+    return BoundType::EQ;
+  if (type == "LB")
+    return BoundType::LB;
+  if (type == "UB")
+    return BoundType::UB;
+  return failure();
+}
 /// Look for "test.reify_bound" ops in the input and replace their results with
 /// the reified values.
 static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
@@ -67,6 +78,17 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
         return WalkResult::skip();
       }
 
+      // Get bound type.
+      std::string boundTypeStr = "EQ";
+      if (auto boundTypeAttr = op->getAttrOfType<StringAttr>("type"))
+        boundTypeStr = boundTypeAttr.str();
+      auto boundType = parseBoundType(boundTypeStr);
+      if (failed(boundType)) {
+        op->emitOpError("invalid op");
+        return WalkResult::interrupt();
+      }
+
+      // Get shape dimension (if any).
       auto dim = value.getType().isIndex()
                      ? std::nullopt
                      : std::make_optional<int64_t>(
@@ -77,8 +99,8 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
       FailureOr<OpFoldResult> reified;
       if (!reifyToFuncArgs) {
         // Reify in terms of the op's operands.
-        reified = reifyValueBound(rewriter, op->getLoc(),
-                                  presburger::BoundType::EQ, value, dim);
+        reified =
+            reifyValueBound(rewriter, op->getLoc(), *boundType, value, dim);
       } else {
         // Reify in terms of function block arguments.
         auto stopCondition = [](Value v) {
@@ -88,9 +110,8 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
           return isa<FunctionOpInterface>(
               bbArg.getParentBlock()->getParentOp());
         };
-        reified =
-            reifyValueBound(rewriter, op->getLoc(), presburger::BoundType::EQ,
-                            value, dim, stopCondition);
+        reified = reifyValueBound(rewriter, op->getLoc(), *boundType, value,
+                                  dim, stopCondition);
       }
       if (failed(reified)) {
         op->emitOpError("could not reify bound");


        


More information about the Mlir-commits mailing list