[Mlir-commits] [mlir] [mlir][affine] remove divide-by-zero check when simplifying affineMap (#64622) Draft (PR #68519)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Oct 13 00:50:56 PDT 2023


https://github.com/lipracer updated https://github.com/llvm/llvm-project/pull/68519

>From 20e3fbb17364f3fb5c822e59f69368ccf7ddcfa6 Mon Sep 17 00:00:00 2001
From: lipracer <lipracer at gmail.com>
Date: Sun, 8 Oct 2023 18:31:59 +0800
Subject: [PATCH] [mlir][affine] remove divide-by-zero check when simplifying
 affineMap (#64622)

When an affineApplyOp has poison semantics, we should not fold the op, but we should also not crash.
---
 mlir/include/mlir/IR/AffineExprVisitor.h | 172 +++++++++++++++++------
 mlir/lib/IR/AffineExpr.cpp               |  59 ++++----
 mlir/lib/IR/AffineMap.cpp                |  27 +++-
 3 files changed, 182 insertions(+), 76 deletions(-)

diff --git a/mlir/include/mlir/IR/AffineExprVisitor.h b/mlir/include/mlir/IR/AffineExprVisitor.h
index f6216614c2238e1..7d38bbfb8a506ab 100644
--- a/mlir/include/mlir/IR/AffineExprVisitor.h
+++ b/mlir/include/mlir/IR/AffineExprVisitor.h
@@ -14,6 +14,7 @@
 #define MLIR_IR_AFFINEEXPRVISITOR_H
 
 #include "mlir/IR/AffineExpr.h"
+#include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/ArrayRef.h"
 
 namespace mlir {
@@ -65,8 +66,80 @@ namespace mlir {
 /// just as efficient as having your own switch instruction over the instruction
 /// opcode.
 
+template <typename SubClass, typename RetTy>
+class AffineExprVisitorBase {
+public:
+  // Function to visit an AffineExpr.
+  RetTy visit(AffineExpr expr) {
+    static_assert(std::is_base_of<AffineExprVisitorBase, SubClass>::value,
+                  "Must instantiate with a derived type of AffineExprVisitor");
+    switch (expr.getKind()) {
+    case AffineExprKind::Add: {
+      auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+      return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
+    }
+    case AffineExprKind::Mul: {
+      auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+      return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
+    }
+    case AffineExprKind::Mod: {
+      auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+      return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
+    }
+    case AffineExprKind::FloorDiv: {
+      auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+      return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
+    }
+    case AffineExprKind::CeilDiv: {
+      auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+      return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
+    }
+    case AffineExprKind::Constant:
+      return static_cast<SubClass *>(this)->visitConstantExpr(
+          expr.cast<AffineConstantExpr>());
+    case AffineExprKind::DimId:
+      return static_cast<SubClass *>(this)->visitDimExpr(
+          expr.cast<AffineDimExpr>());
+    case AffineExprKind::SymbolId:
+      return static_cast<SubClass *>(this)->visitSymbolExpr(
+          expr.cast<AffineSymbolExpr>());
+    }
+    llvm_unreachable("Unknown AffineExpr");
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Visitation functions... these functions provide default fallbacks in case
+  // the user does not specify what to do for a particular instruction type.
+  // The default behavior is to generalize the instruction type to its subtype
+  // and try visiting the subtype.  All of this should be inlined perfectly,
+  // because there are no virtual functions to get in the way.
+  //
+
+  // Default visit methods. Note that the default op-specific binary op visit
+  // methods call the general visitAffineBinaryOpExpr visit method.
+  RetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { return RetTy(); }
+  RetTy visitAddExpr(AffineBinaryOpExpr expr) {
+    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+  }
+  RetTy visitMulExpr(AffineBinaryOpExpr expr) {
+    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+  }
+  RetTy visitModExpr(AffineBinaryOpExpr expr) {
+    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+  }
+  RetTy visitFloorDivExpr(AffineBinaryOpExpr expr) {
+    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+  }
+  RetTy visitCeilDivExpr(AffineBinaryOpExpr expr) {
+    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+  }
+  RetTy visitConstantExpr(AffineConstantExpr expr) { return RetTy(); }
+  RetTy visitDimExpr(AffineDimExpr expr) { return RetTy(); }
+  RetTy visitSymbolExpr(AffineSymbolExpr expr) { return RetTy(); }
+};
+
 template <typename SubClass, typename RetTy = void>
-class AffineExprVisitor {
+class AffineExprVisitor : public AffineExprVisitorBase<SubClass, RetTy> {
   //===--------------------------------------------------------------------===//
   // Interface code - This is the public interface of the AffineExprVisitor
   // that you use to visit affine expressions...
@@ -113,29 +186,59 @@ class AffineExprVisitor {
     }
   }
 
-  // Function to visit an AffineExpr.
-  RetTy visit(AffineExpr expr) {
+private:
+  // Walk the operands - each operand is itself walked in post order.
+  RetTy walkOperandsPostOrder(AffineBinaryOpExpr expr) {
+    walkPostOrder(expr.getLHS());
+    walkPostOrder(expr.getRHS());
+  }
+};
+
+template <typename SubClass>
+class AffineExprVisitor<SubClass, LogicalResult>
+    : public AffineExprVisitorBase<SubClass, LogicalResult> {
+  //===--------------------------------------------------------------------===//
+  // Interface code - This is the public interface of the AffineExprVisitor
+  // that you use to visit affine expressions...
+public:
+  // Function to walk an AffineExpr (in post order).
+  LogicalResult walkPostOrder(AffineExpr expr) {
     static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
                   "Must instantiate with a derived type of AffineExprVisitor");
     switch (expr.getKind()) {
     case AffineExprKind::Add: {
       auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+      if (failed(walkOperandsPostOrder(binOpExpr))) {
+        return failure();
+      }
       return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
     }
     case AffineExprKind::Mul: {
       auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+      if (failed(walkOperandsPostOrder(binOpExpr))) {
+        return failure();
+      }
       return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
     }
     case AffineExprKind::Mod: {
       auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+      if (failed(walkOperandsPostOrder(binOpExpr))) {
+        return failure();
+      }
       return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
     }
     case AffineExprKind::FloorDiv: {
       auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+      if (failed(walkOperandsPostOrder(binOpExpr))) {
+        return failure();
+      }
       return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
     }
     case AffineExprKind::CeilDiv: {
       auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+      if (failed(walkOperandsPostOrder(binOpExpr))) {
+        return failure();
+      }
       return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
     }
     case AffineExprKind::Constant:
@@ -151,41 +254,16 @@ class AffineExprVisitor {
     llvm_unreachable("Unknown AffineExpr");
   }
 
-  //===--------------------------------------------------------------------===//
-  // Visitation functions... these functions provide default fallbacks in case
-  // the user does not specify what to do for a particular instruction type.
-  // The default behavior is to generalize the instruction type to its subtype
-  // and try visiting the subtype.  All of this should be inlined perfectly,
-  // because there are no virtual functions to get in the way.
-  //
-
-  // Default visit methods. Note that the default op-specific binary op visit
-  // methods call the general visitAffineBinaryOpExpr visit method.
-  RetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { return RetTy(); }
-  RetTy visitAddExpr(AffineBinaryOpExpr expr) {
-    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
-  }
-  RetTy visitMulExpr(AffineBinaryOpExpr expr) {
-    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
-  }
-  RetTy visitModExpr(AffineBinaryOpExpr expr) {
-    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
-  }
-  RetTy visitFloorDivExpr(AffineBinaryOpExpr expr) {
-    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
-  }
-  RetTy visitCeilDivExpr(AffineBinaryOpExpr expr) {
-    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
-  }
-  RetTy visitConstantExpr(AffineConstantExpr expr) { return RetTy(); }
-  RetTy visitDimExpr(AffineDimExpr expr) { return RetTy(); }
-  RetTy visitSymbolExpr(AffineSymbolExpr expr) { return RetTy(); }
-
 private:
   // Walk the operands - each operand is itself walked in post order.
-  RetTy walkOperandsPostOrder(AffineBinaryOpExpr expr) {
-    walkPostOrder(expr.getLHS());
-    walkPostOrder(expr.getRHS());
+  LogicalResult walkOperandsPostOrder(AffineBinaryOpExpr expr) {
+    if (failed(walkPostOrder(expr.getLHS()))) {
+      return failure();
+    }
+    if (failed(walkPostOrder(expr.getRHS()))) {
+      return failure();
+    }
+    return success();
   }
 };
 
@@ -246,7 +324,7 @@ class AffineExprVisitor {
 // expressions are mapped to the same local identifier (same column position in
 // 'localVarCst').
 class SimpleAffineExprFlattener
-    : public AffineExprVisitor<SimpleAffineExprFlattener> {
+    : public AffineExprVisitor<SimpleAffineExprFlattener, LogicalResult> {
 public:
   // Flattend expression layout: [dims, symbols, locals, constant]
   // Stack that holds the LHS and RHS operands while visiting a binary op expr.
@@ -275,13 +353,13 @@ class SimpleAffineExprFlattener
   virtual ~SimpleAffineExprFlattener() = default;
 
   // Visitor method overrides.
-  void visitMulExpr(AffineBinaryOpExpr expr);
-  void visitAddExpr(AffineBinaryOpExpr expr);
-  void visitDimExpr(AffineDimExpr expr);
-  void visitSymbolExpr(AffineSymbolExpr expr);
-  void visitConstantExpr(AffineConstantExpr expr);
-  void visitCeilDivExpr(AffineBinaryOpExpr expr);
-  void visitFloorDivExpr(AffineBinaryOpExpr expr);
+  LogicalResult visitMulExpr(AffineBinaryOpExpr expr);
+  LogicalResult visitAddExpr(AffineBinaryOpExpr expr);
+  LogicalResult visitDimExpr(AffineDimExpr expr);
+  LogicalResult visitSymbolExpr(AffineSymbolExpr expr);
+  LogicalResult visitConstantExpr(AffineConstantExpr expr);
+  LogicalResult visitCeilDivExpr(AffineBinaryOpExpr expr);
+  LogicalResult visitFloorDivExpr(AffineBinaryOpExpr expr);
 
   //
   // t = expr mod c   <=>  t = expr - c*q and c*q <= expr <= c*q + c - 1
@@ -289,7 +367,9 @@ class SimpleAffineExprFlattener
   // A mod expression "expr mod c" is thus flattened by introducing a new local
   // variable q (= expr floordiv c), such that expr mod c is replaced with
   // 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
-  void visitModExpr(AffineBinaryOpExpr expr);
+  // Returns failure() when the RHS is a constant that is not positive (e.g.
+  // modulo by zero) instead of asserting.
+  LogicalResult visitModExpr(AffineBinaryOpExpr expr);
 
 protected:
   // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
@@ -328,7 +408,7 @@ class SimpleAffineExprFlattener
   //
   // A ceildiv is similarly flattened:
   // t = expr ceildiv c   <=> t =  (expr + c - 1) floordiv c
-  void visitDivExpr(AffineBinaryOpExpr expr, bool isCeil);
+  LogicalResult visitDivExpr(AffineBinaryOpExpr expr, bool isCeil);
 
   int findLocalId(AffineExpr localExpr);
 
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index 7eccbca4e6e7a1a..563fea3d958d953 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -511,7 +511,6 @@ unsigned AffineSymbolExpr::getPosition() const {
 
 AffineExpr mlir::getAffineSymbolExpr(unsigned position, MLIRContext *context) {
   return getAffineDimOrSymbol(AffineExprKind::SymbolId, position, context);
-  ;
 }
 
 AffineConstantExpr::AffineConstantExpr(AffineExpr::ImplType *ptr)
@@ -1135,7 +1134,7 @@ SimpleAffineExprFlattener::SimpleAffineExprFlattener(unsigned numDims,
 // In case of semi affine multiplication expressions, t = expr * symbolic_expr,
 // introduce a local variable p (= expr * symbolic_expr), and the affine
 // expression expr * symbolic_expr is added to `localExprs`.
-void SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
+LogicalResult SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
   assert(operandExprStack.size() >= 2);
   SmallVector<int64_t, 8> rhs = operandExprStack.back();
   operandExprStack.pop_back();
@@ -1151,7 +1150,7 @@ void SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
     AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
                                              localExprs, context);
     addLocalVariableSemiAffine(a * b, lhs, lhs.size());
-    return;
+    return success();
   }
 
   // Get the RHS constant.
@@ -1159,9 +1158,10 @@ void SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
   for (unsigned i = 0, e = lhs.size(); i < e; i++) {
     lhs[i] *= rhsConst;
   }
+  return success();
 }
 
-void SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) {
+LogicalResult SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) {
   assert(operandExprStack.size() >= 2);
   const auto &rhs = operandExprStack.back();
   auto &lhs = operandExprStack[operandExprStack.size() - 2];
@@ -1172,6 +1172,7 @@ void SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) {
   }
   // Pop off the RHS.
   operandExprStack.pop_back();
+  return success();
 }
 
 //
@@ -1184,7 +1185,7 @@ void SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) {
 // In case of semi-affine modulo expressions, t = expr mod symbolic_expr,
 // introduce a local variable m (= expr mod symbolic_expr), and the affine
 // expression expr mod symbolic_expr is added to `localExprs`.
-void SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
+LogicalResult SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
   assert(operandExprStack.size() >= 2);
 
   SmallVector<int64_t, 8> rhs = operandExprStack.back();
@@ -1202,13 +1203,12 @@ void SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
                                                        localExprs, context);
     AffineExpr modExpr = dividendExpr % divisorExpr;
     addLocalVariableSemiAffine(modExpr, lhs, lhs.size());
-    return;
+    return success();
   }
 
   int64_t rhsConst = rhs[getConstantIndex()];
-  // TODO: handle modulo by zero case when this issue is fixed
-  // at the other places in the IR.
-  assert(rhsConst > 0 && "RHS constant has to be positive");
+  if (rhsConst <= 0)
+    return failure();
 
   // Check if the LHS expression is a multiple of modulo factor.
   unsigned i, e;
@@ -1218,7 +1218,7 @@ void SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
   // If yes, modulo expression here simplifies to zero.
   if (i == lhs.size()) {
     std::fill(lhs.begin(), lhs.end(), 0);
-    return;
+    return success();
   }
 
   // Add a local variable for the quotient, i.e., expr % c is replaced by
@@ -1250,33 +1250,41 @@ void SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
     // Reuse the existing local id.
     lhs[getLocalVarStartIndex() + loc] = -rhsConst;
   }
+  return success();
 }
 
-void SimpleAffineExprFlattener::visitCeilDivExpr(AffineBinaryOpExpr expr) {
-  visitDivExpr(expr, /*isCeil=*/true);
+LogicalResult
+SimpleAffineExprFlattener::visitCeilDivExpr(AffineBinaryOpExpr expr) {
+  return visitDivExpr(expr, /*isCeil=*/true);
 }
-void SimpleAffineExprFlattener::visitFloorDivExpr(AffineBinaryOpExpr expr) {
-  visitDivExpr(expr, /*isCeil=*/false);
+LogicalResult
+SimpleAffineExprFlattener::visitFloorDivExpr(AffineBinaryOpExpr expr) {
+  return visitDivExpr(expr, /*isCeil=*/false);
 }
 
-void SimpleAffineExprFlattener::visitDimExpr(AffineDimExpr expr) {
+LogicalResult SimpleAffineExprFlattener::visitDimExpr(AffineDimExpr expr) {
   operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
   auto &eq = operandExprStack.back();
   assert(expr.getPosition() < numDims && "Inconsistent number of dims");
   eq[getDimStartIndex() + expr.getPosition()] = 1;
+  return success();
 }
 
-void SimpleAffineExprFlattener::visitSymbolExpr(AffineSymbolExpr expr) {
+LogicalResult
+SimpleAffineExprFlattener::visitSymbolExpr(AffineSymbolExpr expr) {
   operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
   auto &eq = operandExprStack.back();
   assert(expr.getPosition() < numSymbols && "inconsistent number of symbols");
   eq[getSymbolStartIndex() + expr.getPosition()] = 1;
+  return success();
 }
 
-void SimpleAffineExprFlattener::visitConstantExpr(AffineConstantExpr expr) {
+LogicalResult
+SimpleAffineExprFlattener::visitConstantExpr(AffineConstantExpr expr) {
   operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
   auto &eq = operandExprStack.back();
   eq[getConstantIndex()] = expr.getValue();
+  return success();
 }
 
 void SimpleAffineExprFlattener::addLocalVariableSemiAffine(
@@ -1307,8 +1315,8 @@ void SimpleAffineExprFlattener::addLocalVariableSemiAffine(
 // or t = expr ceildiv symbolic_expr, introduce a local variable q (= expr
 // floordiv/ceildiv symbolic_expr), and the affine floordiv/ceildiv is added to
 // `localExprs`.
-void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
-                                             bool isCeil) {
+LogicalResult SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
+                                                      bool isCeil) {
   assert(operandExprStack.size() >= 2);
 
   MLIRContext *context = expr.getContext();
@@ -1326,14 +1334,13 @@ void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
                                              localExprs, context);
     AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
     addLocalVariableSemiAffine(divExpr, lhs, lhs.size());
-    return;
+    return success();
   }
 
   // This is a pure affine expr; the RHS is a positive constant.
   int64_t rhsConst = rhs[getConstantIndex()];
-  // TODO: handle division by zero at the same time the issue is
-  // fixed at other places.
-  assert(rhsConst > 0 && "RHS constant has to be positive");
+  if (rhsConst <= 0)
+    return failure();
 
   // Simplify the floordiv, ceildiv if possible by canceling out the greatest
   // common divisors of the numerator and denominator.
@@ -1349,7 +1356,7 @@ void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
   // If the divisor becomes 1, the updated LHS is the result. (The
   // divisor can't be negative since rhsConst is positive).
   if (divisor == 1)
-    return;
+    return success();
 
   // If the divisor cannot be simplified to one, we will have to retain
   // the ceil/floor expr (simplified up until here). Add an existential
@@ -1379,6 +1386,7 @@ void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
     lhs[getLocalVarStartIndex() + numLocals - 1] = 1;
   else
     lhs[getLocalVarStartIndex() + loc] = 1;
+  return success();
 }
 
 // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
@@ -1419,7 +1427,8 @@ AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims,
     expr = simplifySemiAffine(expr);
 
   SimpleAffineExprFlattener flattener(numDims, numSymbols);
-  flattener.walkPostOrder(expr);
+  if (failed(flattener.walkPostOrder(expr)))
+    return expr;
   ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
   if (!expr.isPureAffine() &&
       expr == getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 9cdac964710ca86..deef146c233d074 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -11,6 +11,7 @@
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Support/MathExtras.h"
 #include "llvm/ADT/STLExtras.h"
@@ -56,13 +57,28 @@ class AffineExprConstantFolder {
           expr, [](int64_t lhs, int64_t rhs) { return lhs * rhs; });
     case AffineExprKind::Mod:
       return constantFoldBinExpr(
-          expr, [](int64_t lhs, int64_t rhs) { return mod(lhs, rhs); });
+          expr, [expr](int64_t lhs, int64_t rhs) -> std::optional<int64_t> {
+            if (rhs < 1) {
+              return std::nullopt;
+            }
+            return mod(lhs, rhs);
+          });
     case AffineExprKind::FloorDiv:
       return constantFoldBinExpr(
-          expr, [](int64_t lhs, int64_t rhs) { return floorDiv(lhs, rhs); });
+          expr, [expr](int64_t lhs, int64_t rhs) -> std::optional<int64_t> {
+            if (0 == rhs) {
+              return std::nullopt;
+            }
+            return floorDiv(lhs, rhs);
+          });
     case AffineExprKind::CeilDiv:
       return constantFoldBinExpr(
-          expr, [](int64_t lhs, int64_t rhs) { return ceilDiv(lhs, rhs); });
+          expr, [expr](int64_t lhs, int64_t rhs) -> std::optional<int64_t> {
+            if (0 == rhs) {
+              return std::nullopt;
+            }
+            return ceilDiv(lhs, rhs);
+          });
     case AffineExprKind::Constant:
       return expr.cast<AffineConstantExpr>().getValue();
     case AffineExprKind::DimId:
@@ -81,8 +97,9 @@ class AffineExprConstantFolder {
   }
 
   // TODO: Change these to operate on APInts too.
-  std::optional<int64_t> constantFoldBinExpr(AffineExpr expr,
-                                             int64_t (*op)(int64_t, int64_t)) {
+  std::optional<int64_t> constantFoldBinExpr(
+      AffineExpr expr,
+      llvm::function_ref<std::optional<int64_t>(int64_t, int64_t)> op) {
     auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
     if (auto lhs = constantFoldImpl(binOpExpr.getLHS()))
       if (auto rhs = constantFoldImpl(binOpExpr.getRHS()))



More information about the Mlir-commits mailing list