[Mlir-commits] [mlir] [mlir][affine|ValueBounds] Add transform to simplify affine min max ops with ValueBoundsOpInterface (PR #145068)
Nicolas Vasilache
llvmlistbot at llvm.org
Sat Jun 21 01:12:29 PDT 2025
================
@@ -0,0 +1,153 @@
+//===- SimplifyAffineMinMax.cpp - Simplify affine min/max ops -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a transform to simplify min/max affine operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
+#include "llvm/ADT/IntEqClasses.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "affine-min-max"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+using namespace mlir;
+using namespace mlir::affine;
+
+/// Simplifies an affine min/max operation by proving there's a lower or upper
+/// bound.
+template <typename AffineOp>
+static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
+ using Variable = ValueBoundsConstraintSet::Variable;
+ using ComparisonOperator = ValueBoundsConstraintSet::ComparisonOperator;
+
+ AffineMap affineMap = affineOp.getMap();
+ ValueRange operands = affineOp.getOperands();
+ static constexpr bool isMin = std::is_same_v<AffineOp, AffineMinOp>;
+
+ LLVM_DEBUG({ DBGS() << "analyzing value: `" << affineOp << "`\n"; });
+
+ // Create a `Variable` list with values corresponding to each of the results
+ // in the affine affineMap.
+ SmallVector<Variable> variables = llvm::map_to_vector(
+ llvm::iota_range<unsigned>(0u, affineMap.getNumResults(), false),
+ [&](unsigned i) {
+ return Variable(affineMap.getSliceMap(i, 1), operands);
+ });
+
+ // Get the comparison operation.
+ ComparisonOperator cmpOp =
+ isMin ? ComparisonOperator::LT : ComparisonOperator::GT;
+
+ // Find disjoint sets bounded by a common value.
+ llvm::IntEqClasses boundedClasses(variables.size());
+ DenseMap<unsigned, Variable *> bounds;
+ for (auto &&[i, v] : llvm::enumerate(variables)) {
+ unsigned eqClass = boundedClasses.findLeader(i);
+
+ // If the class already has a bound continue.
+ if (bounds.contains(eqClass))
+ continue;
+
+ // Initialize the bound.
+ Variable *bound = &v;
+
+ LLVM_DEBUG({
+ DBGS() << "- inspecting variable: #" << i << ", with map: `" << v.getMap()
+ << "`\n";
+ });
+
+ // Check against the other variables.
+ for (size_t j = i + 1; j < variables.size(); ++j) {
+ unsigned jEqClass = boundedClasses.findLeader(j);
+ // Get the bound of the equivalence class or itself.
+ Variable *nv = bounds.lookup_or(jEqClass, &variables[j]);
+
+ LLVM_DEBUG({
+ DBGS() << "- comparing with variable: #" << jEqClass
+ << ", with map: " << nv->getMap() << "\n";
+ });
+
+ // Compare the variables.
+ FailureOr<bool> cmpResult =
+ ValueBoundsConstraintSet::strongCompare(*bound, cmpOp, *nv);
----------------
nicolasvasilache wrote:
Can you elaborate why you need a strict comparison here?
It is not clear to me: redundant constraints can be dropped with regular (non-strict) comparisons, so if you had to do this, I am wondering whether something is off.
In particular we need: `forall x f(x) <= g(x)` (and not e.g. `exists x such that f(x) < g(x)` but I don't see evidence that this is what is being computed).
https://github.com/llvm/llvm-project/pull/145068
More information about the Mlir-commits
mailing list