[Mlir-commits] [mlir] [mlir][affine] remove divide zero check when simplifer affineMap (#64622) Draft (PR #68519)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Oct 13 00:50:56 PDT 2023
https://github.com/lipracer updated https://github.com/llvm/llvm-project/pull/68519
>From 20e3fbb17364f3fb5c822e59f69368ccf7ddcfa6 Mon Sep 17 00:00:00 2001
From: lipracer <lipracer at gmail.com>
Date: Sun, 8 Oct 2023 18:31:59 +0800
Subject: [PATCH] [mlir][affine] remove divide-by-zero assertion when simplifying
affineMap (#64622)
When an affine.apply op has poison semantics (e.g. a division or modulo by zero), we should not fold the op, but we must not crash either.
---
mlir/include/mlir/IR/AffineExprVisitor.h | 172 +++++++++++++++++------
mlir/lib/IR/AffineExpr.cpp | 59 ++++----
mlir/lib/IR/AffineMap.cpp | 27 +++-
3 files changed, 182 insertions(+), 76 deletions(-)
diff --git a/mlir/include/mlir/IR/AffineExprVisitor.h b/mlir/include/mlir/IR/AffineExprVisitor.h
index f6216614c2238e1..7d38bbfb8a506ab 100644
--- a/mlir/include/mlir/IR/AffineExprVisitor.h
+++ b/mlir/include/mlir/IR/AffineExprVisitor.h
@@ -14,6 +14,7 @@
#define MLIR_IR_AFFINEEXPRVISITOR_H
#include "mlir/IR/AffineExpr.h"
+#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/ArrayRef.h"
namespace mlir {
@@ -65,8 +66,80 @@ namespace mlir {
/// just as efficient as having your own switch instruction over the instruction
/// opcode.
+
+/// Base class shared by the AffineExprVisitor specializations, used as a CRTP
+/// base: `SubClass` must derive from it (checked by the static_assert in
+/// `visit`). `visit` dispatches on the expression kind to the matching
+/// `visit*Expr` method of `SubClass`. The default binary-op `visit*Expr`
+/// methods forward to `visitAffineBinaryOpExpr`, and the remaining defaults
+/// return a value-initialized `RetTy`.
+// NOTE(review): the doc comment ending just above presumably documents
+// AffineExprVisitor but now precedes this class; consider moving this class
+// above that comment.
+template <typename SubClass, typename RetTy>
+class AffineExprVisitorBase {
+public:
+ // Function to visit an AffineExpr.
+ RetTy visit(AffineExpr expr) {
+ static_assert(std::is_base_of<AffineExprVisitorBase, SubClass>::value,
+ "Must instantiate with a derived type of AffineExprVisitor");
+ switch (expr.getKind()) {
+ case AffineExprKind::Add: {
+ auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+ return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
+ }
+ case AffineExprKind::Mul: {
+ auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+ return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
+ }
+ case AffineExprKind::Mod: {
+ auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+ return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
+ }
+ case AffineExprKind::FloorDiv: {
+ auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+ return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
+ }
+ case AffineExprKind::CeilDiv: {
+ auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+ return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
+ }
+ case AffineExprKind::Constant:
+ return static_cast<SubClass *>(this)->visitConstantExpr(
+ expr.cast<AffineConstantExpr>());
+ case AffineExprKind::DimId:
+ return static_cast<SubClass *>(this)->visitDimExpr(
+ expr.cast<AffineDimExpr>());
+ case AffineExprKind::SymbolId:
+ return static_cast<SubClass *>(this)->visitSymbolExpr(
+ expr.cast<AffineSymbolExpr>());
+ }
+ llvm_unreachable("Unknown AffineExpr");
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Visitation functions... these functions provide default fallbacks in case
+ // the user does not specify what to do for a particular instruction type.
+ // The default behavior is to generalize the instruction type to its subtype
+ // and try visiting the subtype. All of this should be inlined perfectly,
+ // because there are no virtual functions to get in the way.
+ //
+
+ // Default visit methods. Note that the default op-specific binary op visit
+ // methods call the general visitAffineBinaryOpExpr visit method.
+ // NOTE(review): `return RetTy();` requires RetTy to be default-constructible.
+ // If LogicalResult is not (confirm), a LogicalResult visitor must override
+ // every visit method it can reach, or these defaults fail to compile once
+ // instantiated.
+ RetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { return RetTy(); }
+ RetTy visitAddExpr(AffineBinaryOpExpr expr) {
+ return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+ }
+ RetTy visitMulExpr(AffineBinaryOpExpr expr) {
+ return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+ }
+ RetTy visitModExpr(AffineBinaryOpExpr expr) {
+ return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+ }
+ RetTy visitFloorDivExpr(AffineBinaryOpExpr expr) {
+ return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+ }
+ RetTy visitCeilDivExpr(AffineBinaryOpExpr expr) {
+ return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+ }
+ RetTy visitConstantExpr(AffineConstantExpr expr) { return RetTy(); }
+ RetTy visitDimExpr(AffineDimExpr expr) { return RetTy(); }
+ RetTy visitSymbolExpr(AffineSymbolExpr expr) { return RetTy(); }
+};
+
template <typename SubClass, typename RetTy = void>
-class AffineExprVisitor {
+class AffineExprVisitor : public AffineExprVisitorBase<SubClass, RetTy> {
//===--------------------------------------------------------------------===//
// Interface code - This is the public interface of the AffineExprVisitor
// that you use to visit affine expressions...
@@ -113,29 +186,59 @@ class AffineExprVisitor {
}
}
- // Function to visit an AffineExpr.
- RetTy visit(AffineExpr expr) {
+private:
+ // Walk the operands - each operand is itself walked in post order.
+ // Returns void rather than RetTy: the operands' walk results are discarded,
+ // and a non-void function that reaches its end without a return statement
+ // has undefined behavior.
+ void walkOperandsPostOrder(AffineBinaryOpExpr expr) {
+ walkPostOrder(expr.getLHS());
+ walkPostOrder(expr.getRHS());
+ }
+};
+
+/// Specialization of AffineExprVisitor for visitors whose visit methods return
+/// LogicalResult. `walkPostOrder` walks both operands of a binary expression
+/// before visiting it, and returns the first failure it encounters without
+/// walking or visiting anything further.
+template <typename SubClass>
+class AffineExprVisitor<SubClass, LogicalResult>
+ : public AffineExprVisitorBase<SubClass, LogicalResult> {
+ //===--------------------------------------------------------------------===//
+ // Interface code - This is the public interface of the AffineExprVisitor
+ // that you use to visit affine expressions...
+public:
+ // Function to walk an AffineExpr (in post order).
+ LogicalResult walkPostOrder(AffineExpr expr) {
static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
"Must instantiate with a derived type of AffineExprVisitor");
switch (expr.getKind()) {
case AffineExprKind::Add: {
auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+ if (failed(walkOperandsPostOrder(binOpExpr)))
+ return failure();
return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
}
case AffineExprKind::Mul: {
auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+ if (failed(walkOperandsPostOrder(binOpExpr)))
+ return failure();
return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
}
case AffineExprKind::Mod: {
auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+ if (failed(walkOperandsPostOrder(binOpExpr)))
+ return failure();
return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
}
case AffineExprKind::FloorDiv: {
auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+ if (failed(walkOperandsPostOrder(binOpExpr)))
+ return failure();
return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
}
case AffineExprKind::CeilDiv: {
auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+ if (failed(walkOperandsPostOrder(binOpExpr)))
+ return failure();
return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
}
case AffineExprKind::Constant:
@@ -151,41 +254,13 @@ class AffineExprVisitor {
llvm_unreachable("Unknown AffineExpr");
}
- //===--------------------------------------------------------------------===//
- // Visitation functions... these functions provide default fallbacks in case
- // the user does not specify what to do for a particular instruction type.
- // The default behavior is to generalize the instruction type to its subtype
- // and try visiting the subtype. All of this should be inlined perfectly,
- // because there are no virtual functions to get in the way.
- //
-
- // Default visit methods. Note that the default op-specific binary op visit
- // methods call the general visitAffineBinaryOpExpr visit method.
- RetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { return RetTy(); }
- RetTy visitAddExpr(AffineBinaryOpExpr expr) {
- return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
- }
- RetTy visitMulExpr(AffineBinaryOpExpr expr) {
- return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
- }
- RetTy visitModExpr(AffineBinaryOpExpr expr) {
- return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
- }
- RetTy visitFloorDivExpr(AffineBinaryOpExpr expr) {
- return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
- }
- RetTy visitCeilDivExpr(AffineBinaryOpExpr expr) {
- return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
- }
- RetTy visitConstantExpr(AffineConstantExpr expr) { return RetTy(); }
- RetTy visitDimExpr(AffineDimExpr expr) { return RetTy(); }
- RetTy visitSymbolExpr(AffineSymbolExpr expr) { return RetTy(); }
-
private:
// Walk the operands - each operand is itself walked in post order.
- RetTy walkOperandsPostOrder(AffineBinaryOpExpr expr) {
- walkPostOrder(expr.getLHS());
- walkPostOrder(expr.getRHS());
+ LogicalResult walkOperandsPostOrder(AffineBinaryOpExpr expr) {
+ // The RHS is not walked when the LHS fails.
+ if (failed(walkPostOrder(expr.getLHS())))
+ return failure();
+ return walkPostOrder(expr.getRHS());
}
};
@@ -246,7 +324,7 @@ class AffineExprVisitor {
// expressions are mapped to the same local identifier (same column position in
// 'localVarCst').
class SimpleAffineExprFlattener
- : public AffineExprVisitor<SimpleAffineExprFlattener> {
+ : public AffineExprVisitor<SimpleAffineExprFlattener, LogicalResult> {
public:
// Flattend expression layout: [dims, symbols, locals, constant]
// Stack that holds the LHS and RHS operands while visiting a binary op expr.
@@ -275,13 +353,16 @@ class SimpleAffineExprFlattener
virtual ~SimpleAffineExprFlattener() = default;
// Visitor method overrides.
- void visitMulExpr(AffineBinaryOpExpr expr);
- void visitAddExpr(AffineBinaryOpExpr expr);
- void visitDimExpr(AffineDimExpr expr);
- void visitSymbolExpr(AffineSymbolExpr expr);
- void visitConstantExpr(AffineConstantExpr expr);
- void visitCeilDivExpr(AffineBinaryOpExpr expr);
- void visitFloorDivExpr(AffineBinaryOpExpr expr);
+ // Each visit method returns failure() when its expression cannot be
+ // flattened (a mod, floordiv or ceildiv by a constant <= 0). Callers must
+ // then discard the flattened form, as simplifyAffineExpr does.
+ LogicalResult visitMulExpr(AffineBinaryOpExpr expr);
+ LogicalResult visitAddExpr(AffineBinaryOpExpr expr);
+ LogicalResult visitDimExpr(AffineDimExpr expr);
+ LogicalResult visitSymbolExpr(AffineSymbolExpr expr);
+ LogicalResult visitConstantExpr(AffineConstantExpr expr);
+ LogicalResult visitCeilDivExpr(AffineBinaryOpExpr expr);
+ LogicalResult visitFloorDivExpr(AffineBinaryOpExpr expr);
//
// t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1
@@ -289,7 +367,9 @@ class SimpleAffineExprFlattener
// A mod expression "expr mod c" is thus flattened by introducing a new local
// variable q (= expr floordiv c), such that expr mod c is replaced with
// 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
- void visitModExpr(AffineBinaryOpExpr expr);
+ // Returns failure() if `c` is a constant <= 0 (e.g. `expr mod 0`); such an
+ // expression cannot be flattened.
+ LogicalResult visitModExpr(AffineBinaryOpExpr expr);
protected:
// Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
@@ -328,7 +408,8 @@ class SimpleAffineExprFlattener
//
// A ceildiv is similarly flattened:
// t = expr ceildiv c <=> t = (expr + c - 1) floordiv c
- void visitDivExpr(AffineBinaryOpExpr expr, bool isCeil);
+ // Returns failure() if `c` is a constant <= 0.
+ LogicalResult visitDivExpr(AffineBinaryOpExpr expr, bool isCeil);
int findLocalId(AffineExpr localExpr);
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index 7eccbca4e6e7a1a..563fea3d958d953 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -511,7 +511,6 @@ unsigned AffineSymbolExpr::getPosition() const {
+// Returns the symbol expression at `position`, built by getAffineDimOrSymbol.
AffineExpr mlir::getAffineSymbolExpr(unsigned position, MLIRContext *context) {
return getAffineDimOrSymbol(AffineExprKind::SymbolId, position, context);
- ;
}
AffineConstantExpr::AffineConstantExpr(AffineExpr::ImplType *ptr)
@@ -1135,7 +1134,7 @@ SimpleAffineExprFlattener::SimpleAffineExprFlattener(unsigned numDims,
// In case of semi affine multiplication expressions, t = expr * symbolic_expr,
// introduce a local variable p (= expr * symbolic_expr), and the affine
// expression expr * symbolic_expr is added to `localExprs`.
-void SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
+// Multiplication introduces no division, so every exit returns success().
+LogicalResult SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
assert(operandExprStack.size() >= 2);
SmallVector<int64_t, 8> rhs = operandExprStack.back();
operandExprStack.pop_back();
@@ -1151,7 +1150,7 @@ void SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
localExprs, context);
addLocalVariableSemiAffine(a * b, lhs, lhs.size());
- return;
+ return success();
}
// Get the RHS constant.
@@ -1159,9 +1158,10 @@ void SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
for (unsigned i = 0, e = lhs.size(); i < e; i++) {
lhs[i] *= rhsConst;
}
+ return success();
}
-void SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) {
+// Addition never fails to flatten; the only exit returns success().
+LogicalResult SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) {
assert(operandExprStack.size() >= 2);
const auto &rhs = operandExprStack.back();
auto &lhs = operandExprStack[operandExprStack.size() - 2];
@@ -1172,6 +1172,7 @@ void SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) {
}
// Pop off the RHS.
operandExprStack.pop_back();
+ return success();
}
//
@@ -1184,7 +1185,7 @@ void SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) {
// In case of semi-affine modulo expressions, t = expr mod symbolic_expr,
// introduce a local variable m (= expr mod symbolic_expr), and the affine
// expression expr mod symbolic_expr is added to `localExprs`.
-void SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
+LogicalResult SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
assert(operandExprStack.size() >= 2);
SmallVector<int64_t, 8> rhs = operandExprStack.back();
@@ -1202,13 +1203,12 @@ void SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
localExprs, context);
AffineExpr modExpr = dividendExpr % divisorExpr;
addLocalVariableSemiAffine(modExpr, lhs, lhs.size());
- return;
+ return success();
}
int64_t rhsConst = rhs[getConstantIndex()];
- // TODO: handle modulo by zero case when this issue is fixed
- // at the other places in the IR.
- assert(rhsConst > 0 && "RHS constant has to be positive");
+ // Flattening `expr mod c` needs a positive constant `c`. Report failure
+ // instead of asserting so that callers can leave the expression as is.
+ if (rhsConst <= 0)
+ return failure();
// Check if the LHS expression is a multiple of modulo factor.
unsigned i, e;
@@ -1218,7 +1218,7 @@ void SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
// If yes, modulo expression here simplifies to zero.
if (i == lhs.size()) {
std::fill(lhs.begin(), lhs.end(), 0);
- return;
+ return success();
}
// Add a local variable for the quotient, i.e., expr % c is replaced by
@@ -1250,33 +1250,41 @@ void SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
// Reuse the existing local id.
lhs[getLocalVarStartIndex() + loc] = -rhsConst;
}
+ return success();
}
-void SimpleAffineExprFlattener::visitCeilDivExpr(AffineBinaryOpExpr expr) {
- visitDivExpr(expr, /*isCeil=*/true);
+// Forwards visitDivExpr's result: failure() for a constant divisor <= 0.
+LogicalResult
+SimpleAffineExprFlattener::visitCeilDivExpr(AffineBinaryOpExpr expr) {
+ return visitDivExpr(expr, /*isCeil=*/true);
}
-void SimpleAffineExprFlattener::visitFloorDivExpr(AffineBinaryOpExpr expr) {
- visitDivExpr(expr, /*isCeil=*/false);
+// Forwards visitDivExpr's result: failure() for a constant divisor <= 0.
+LogicalResult
+SimpleAffineExprFlattener::visitFloorDivExpr(AffineBinaryOpExpr expr) {
+ return visitDivExpr(expr, /*isCeil=*/false);
}
-void SimpleAffineExprFlattener::visitDimExpr(AffineDimExpr expr) {
+// Leaf visitors: push the flattened form of the leaf (a unit coefficient for a
+// dim or symbol, the value for a constant) onto operandExprStack. They never
+// fail.
+LogicalResult SimpleAffineExprFlattener::visitDimExpr(AffineDimExpr expr) {
operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
auto &eq = operandExprStack.back();
assert(expr.getPosition() < numDims && "Inconsistent number of dims");
eq[getDimStartIndex() + expr.getPosition()] = 1;
+ return success();
}
-void SimpleAffineExprFlattener::visitSymbolExpr(AffineSymbolExpr expr) {
+LogicalResult
+SimpleAffineExprFlattener::visitSymbolExpr(AffineSymbolExpr expr) {
operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
auto &eq = operandExprStack.back();
assert(expr.getPosition() < numSymbols && "inconsistent number of symbols");
eq[getSymbolStartIndex() + expr.getPosition()] = 1;
+ return success();
}
-void SimpleAffineExprFlattener::visitConstantExpr(AffineConstantExpr expr) {
+LogicalResult
+SimpleAffineExprFlattener::visitConstantExpr(AffineConstantExpr expr) {
operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
auto &eq = operandExprStack.back();
eq[getConstantIndex()] = expr.getValue();
+ return success();
}
void SimpleAffineExprFlattener::addLocalVariableSemiAffine(
@@ -1307,8 +1315,8 @@ void SimpleAffineExprFlattener::addLocalVariableSemiAffine(
// or t = expr ceildiv symbolic_expr, introduce a local variable q (= expr
// floordiv/ceildiv symbolic_expr), and the affine floordiv/ceildiv is added to
// `localExprs`.
-void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
- bool isCeil) {
+LogicalResult SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
+ bool isCeil) {
assert(operandExprStack.size() >= 2);
MLIRContext *context = expr.getContext();
@@ -1326,14 +1334,13 @@ void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
localExprs, context);
AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
addLocalVariableSemiAffine(divExpr, lhs, lhs.size());
- return;
+ return success();
}
// This is a pure affine expr; the RHS is a positive constant.
int64_t rhsConst = rhs[getConstantIndex()];
- // TODO: handle division by zero at the same time the issue is
- // fixed at other places.
- assert(rhsConst > 0 && "RHS constant has to be positive");
+ // The RHS is a constant here but is only known to be positive after this
+ // check. A divisor <= 0 cannot be flattened; report failure instead of
+ // asserting so that callers can leave the expression as is.
+ if (rhsConst <= 0)
+ return failure();
// Simplify the floordiv, ceildiv if possible by canceling out the greatest
// common divisors of the numerator and denominator.
@@ -1349,7 +1356,7 @@ void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
// If the divisor becomes 1, the updated LHS is the result. (The
// divisor can't be negative since rhsConst is positive).
if (divisor == 1)
- return;
+ return success();
// If the divisor cannot be simplified to one, we will have to retain
// the ceil/floor expr (simplified up until here). Add an existential
@@ -1379,6 +1386,7 @@ void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
lhs[getLocalVarStartIndex() + numLocals - 1] = 1;
else
lhs[getLocalVarStartIndex() + loc] = 1;
+ return success();
}
// Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
@@ -1419,7 +1427,8 @@ AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims,
expr = simplifySemiAffine(expr);
SimpleAffineExprFlattener flattener(numDims, numSymbols);
- flattener.walkPostOrder(expr);
+ // Flattening fails on a mod/floordiv/ceildiv by a constant <= 0; return the
+ // expression as produced by simplifySemiAffine, without flattening it.
+ if (failed(flattener.walkPostOrder(expr)))
+ return expr;
ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
if (!expr.isPureAffine() &&
expr == getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 9cdac964710ca86..deef146c233d074 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -11,6 +11,7 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/STLExtras.h"
@@ -56,13 +57,28 @@ class AffineExprConstantFolder {
expr, [](int64_t lhs, int64_t rhs) { return lhs * rhs; });
case AffineExprKind::Mod:
return constantFoldBinExpr(
- expr, [](int64_t lhs, int64_t rhs) { return mod(lhs, rhs); });
+ expr, [](int64_t lhs, int64_t rhs) -> std::optional<int64_t> {
+ // Leave the expression unfolded for a non-positive divisor.
+ if (rhs < 1)
+ return std::nullopt;
+ return mod(lhs, rhs);
+ });
case AffineExprKind::FloorDiv:
return constantFoldBinExpr(
- expr, [](int64_t lhs, int64_t rhs) { return floorDiv(lhs, rhs); });
+ expr, [](int64_t lhs, int64_t rhs) -> std::optional<int64_t> {
+ // Leave a division by zero unfolded instead of calling floorDiv.
+ if (rhs == 0)
+ return std::nullopt;
+ return floorDiv(lhs, rhs);
+ });
case AffineExprKind::CeilDiv:
return constantFoldBinExpr(
- expr, [](int64_t lhs, int64_t rhs) { return ceilDiv(lhs, rhs); });
+ expr, [](int64_t lhs, int64_t rhs) -> std::optional<int64_t> {
+ // Leave a division by zero unfolded instead of calling ceilDiv.
+ if (rhs == 0)
+ return std::nullopt;
+ return ceilDiv(lhs, rhs);
+ });
case AffineExprKind::Constant:
return expr.cast<AffineConstantExpr>().getValue();
case AffineExprKind::DimId:
@@ -81,8 +97,9 @@ class AffineExprConstantFolder {
}
// TODO: Change these to operate on APInts too.
- std::optional<int64_t> constantFoldBinExpr(AffineExpr expr,
- int64_t (*op)(int64_t, int64_t)) {
+ // `op` returns std::nullopt to signal that it cannot fold the given constant
+ // operands (e.g. a division by zero).
+ std::optional<int64_t> constantFoldBinExpr(
+ AffineExpr expr,
+ llvm::function_ref<std::optional<int64_t>(int64_t, int64_t)> op) {
auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
if (auto lhs = constantFoldImpl(binOpExpr.getLHS()))
if (auto rhs = constantFoldImpl(binOpExpr.getRHS()))
More information about the Mlir-commits
mailing list