[Mlir-commits] [mlir] [mlir][affine] remove divide-by-zero check when simplifying affineMap (#64622) (PR #68519)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Nov 18 08:50:19 PST 2023


https://github.com/lipracer updated https://github.com/llvm/llvm-project/pull/68519

>From 7e74df5e5b1c866e910d72b97db96cb97de937e3 Mon Sep 17 00:00:00 2001
From: lipracer <lipracer at gmail.com>
Date: Fri, 3 Nov 2023 23:54:06 +0800
Subject: [PATCH 1/6] [mlir][affine] remove divide-by-zero check when simplifying
 affineMap (#64622)

---
 mlir/include/mlir/IR/AffineExprVisitor.h      | 212 ++++++++++++------
 mlir/include/mlir/IR/AffineMap.h              |   9 +-
 .../Analysis/FlatLinearValueConstraints.cpp   |   6 +-
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp      | 108 ++++++++-
 mlir/lib/Dialect/Affine/IR/CMakeLists.txt     |   3 +-
 mlir/lib/IR/AffineExpr.cpp                    |  64 +++---
 mlir/lib/IR/AffineMap.cpp                     |  46 +++-
 mlir/test/Dialect/Affine/constant-fold.mlir   |  21 ++
 8 files changed, 352 insertions(+), 117 deletions(-)

diff --git a/mlir/include/mlir/IR/AffineExprVisitor.h b/mlir/include/mlir/IR/AffineExprVisitor.h
index 382db22dce463e5..d603cd90b0cf949 100644
--- a/mlir/include/mlir/IR/AffineExprVisitor.h
+++ b/mlir/include/mlir/IR/AffineExprVisitor.h
@@ -14,6 +14,7 @@
 #define MLIR_IR_AFFINEEXPRVISITOR_H
 
 #include "mlir/IR/AffineExpr.h"
+#include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/ArrayRef.h"
 
 namespace mlir {
@@ -65,8 +66,78 @@ namespace mlir {
 /// just as efficient as having your own switch instruction over the instruction
 /// opcode.
 
+template <typename SubClass, typename RetTy>
+class AffineExprVisitorBase {
+public:
+  // Function to visit an AffineExpr.
+  RetTy visit(AffineExpr expr) {
+    static_assert(std::is_base_of<AffineExprVisitorBase, SubClass>::value,
+                  "Must instantiate with a derived type of AffineExprVisitor");
+    auto self = static_cast<SubClass *>(this);
+    switch (expr.getKind()) {
+    case AffineExprKind::Add: {
+      auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
+      return self->visitAddExpr(binOpExpr);
+    }
+    case AffineExprKind::Mul: {
+      auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
+      return self->visitMulExpr(binOpExpr);
+    }
+    case AffineExprKind::Mod: {
+      auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
+      return self->visitModExpr(binOpExpr);
+    }
+    case AffineExprKind::FloorDiv: {
+      auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
+      return self->visitFloorDivExpr(binOpExpr);
+    }
+    case AffineExprKind::CeilDiv: {
+      auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
+      return self->visitCeilDivExpr(binOpExpr);
+    }
+    case AffineExprKind::Constant:
+      return self->visitConstantExpr(cast<AffineConstantExpr>(expr));
+    case AffineExprKind::DimId:
+      return self->visitDimExpr(cast<AffineDimExpr>(expr));
+    case AffineExprKind::SymbolId:
+      return self->visitSymbolExpr(cast<AffineSymbolExpr>(expr));
+    }
+    llvm_unreachable("Unknown AffineExpr");
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Visitation functions... these functions provide default fallbacks in case
+  // the user does not specify what to do for a particular expression kind.
+  // The binary-op visitors forward to visitAffineBinaryOpExpr and the other
+  // defaults return RetTy(), so RetTy must be default-constructible for any
+  // default that actually gets instantiated. Dispatch is static (CRTP), so
+  // there are no virtual functions to get in the way of inlining.
+
+  // Default visit methods. Note that the default op-specific binary op visit
+  // methods call the general visitAffineBinaryOpExpr visit method.
+  RetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { return RetTy(); }
+  RetTy visitAddExpr(AffineBinaryOpExpr expr) {
+    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+  }
+  RetTy visitMulExpr(AffineBinaryOpExpr expr) {
+    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+  }
+  RetTy visitModExpr(AffineBinaryOpExpr expr) {
+    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+  }
+  RetTy visitFloorDivExpr(AffineBinaryOpExpr expr) {
+    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+  }
+  RetTy visitCeilDivExpr(AffineBinaryOpExpr expr) {
+    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+  }
+  RetTy visitConstantExpr(AffineConstantExpr expr) { return RetTy(); }
+  RetTy visitDimExpr(AffineDimExpr expr) { return RetTy(); }
+  RetTy visitSymbolExpr(AffineSymbolExpr expr) { return RetTy(); }
+};
+
 template <typename SubClass, typename RetTy = void>
-class AffineExprVisitor {
+class AffineExprVisitor : public AffineExprVisitorBase<SubClass, RetTy> {
   //===--------------------------------------------------------------------===//
   // Interface code - This is the public interface of the AffineExprVisitor
   // that you use to visit affine expressions...
@@ -75,117 +146,112 @@ class AffineExprVisitor {
   RetTy walkPostOrder(AffineExpr expr) {
     static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
                   "Must instantiate with a derived type of AffineExprVisitor");
+    auto self = static_cast<SubClass *>(this);
     switch (expr.getKind()) {
     case AffineExprKind::Add: {
       auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
       walkOperandsPostOrder(binOpExpr);
-      return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
+      return self->visitAddExpr(binOpExpr);
     }
     case AffineExprKind::Mul: {
       auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
       walkOperandsPostOrder(binOpExpr);
-      return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
+      return self->visitMulExpr(binOpExpr);
     }
     case AffineExprKind::Mod: {
       auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
       walkOperandsPostOrder(binOpExpr);
-      return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
+      return self->visitModExpr(binOpExpr);
     }
     case AffineExprKind::FloorDiv: {
       auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
       walkOperandsPostOrder(binOpExpr);
-      return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
+      return self->visitFloorDivExpr(binOpExpr);
     }
     case AffineExprKind::CeilDiv: {
       auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
       walkOperandsPostOrder(binOpExpr);
-      return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
+      return self->visitCeilDivExpr(binOpExpr);
     }
     case AffineExprKind::Constant:
-      return static_cast<SubClass *>(this)->visitConstantExpr(
-          cast<AffineConstantExpr>(expr));
+      return self->visitConstantExpr(cast<AffineConstantExpr>(expr));
     case AffineExprKind::DimId:
-      return static_cast<SubClass *>(this)->visitDimExpr(
-          cast<AffineDimExpr>(expr));
+      return self->visitDimExpr(cast<AffineDimExpr>(expr));
     case AffineExprKind::SymbolId:
-      return static_cast<SubClass *>(this)->visitSymbolExpr(
-          cast<AffineSymbolExpr>(expr));
+      return self->visitSymbolExpr(cast<AffineSymbolExpr>(expr));
     }
+    llvm_unreachable("Unknown AffineExpr");
   }
 
-  // Function to visit an AffineExpr.
-  RetTy visit(AffineExpr expr) {
+private:
+  // Walk each operand in post order; the operands' results are discarded.
+  void walkOperandsPostOrder(AffineBinaryOpExpr expr) {
+    walkPostOrder(expr.getLHS());
+    walkPostOrder(expr.getRHS());
+  }
+};
+
+template <typename SubClass>
+class AffineExprVisitor<SubClass, LogicalResult>
+    : public AffineExprVisitorBase<SubClass, LogicalResult> {
+  //===--------------------------------------------------------------------===//
+  // Interface code - This is the public interface of the AffineExprVisitor
+  // that you use to visit affine expressions...
+public:
+  // Function to walk an AffineExpr (in post order).
+  LogicalResult walkPostOrder(AffineExpr expr) {
     static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
                   "Must instantiate with a derived type of AffineExprVisitor");
+    auto self = static_cast<SubClass *>(this);
     switch (expr.getKind()) {
     case AffineExprKind::Add: {
-      auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
-      return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
+      auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
+      if (failed(walkOperandsPostOrder(binOpExpr)))
+        return failure();
+      return self->visitAddExpr(binOpExpr);
     }
     case AffineExprKind::Mul: {
-      auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
-      return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
+      auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
+      if (failed(walkOperandsPostOrder(binOpExpr)))
+        return failure();
+      return self->visitMulExpr(binOpExpr);
     }
     case AffineExprKind::Mod: {
-      auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
-      return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
+      auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
+      if (failed(walkOperandsPostOrder(binOpExpr)))
+        return failure();
+      return self->visitModExpr(binOpExpr);
     }
     case AffineExprKind::FloorDiv: {
-      auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
-      return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
+      auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
+      if (failed(walkOperandsPostOrder(binOpExpr)))
+        return failure();
+      return self->visitFloorDivExpr(binOpExpr);
     }
     case AffineExprKind::CeilDiv: {
-      auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
-      return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
+      auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
+      if (failed(walkOperandsPostOrder(binOpExpr)))
+        return failure();
+      return self->visitCeilDivExpr(binOpExpr);
     }
     case AffineExprKind::Constant:
-      return static_cast<SubClass *>(this)->visitConstantExpr(
-          cast<AffineConstantExpr>(expr));
+      return self->visitConstantExpr(cast<AffineConstantExpr>(expr));
     case AffineExprKind::DimId:
-      return static_cast<SubClass *>(this)->visitDimExpr(
-          cast<AffineDimExpr>(expr));
+      return self->visitDimExpr(cast<AffineDimExpr>(expr));
     case AffineExprKind::SymbolId:
-      return static_cast<SubClass *>(this)->visitSymbolExpr(
-          cast<AffineSymbolExpr>(expr));
+      return self->visitSymbolExpr(cast<AffineSymbolExpr>(expr));
     }
     llvm_unreachable("Unknown AffineExpr");
   }
 
-  //===--------------------------------------------------------------------===//
-  // Visitation functions... these functions provide default fallbacks in case
-  // the user does not specify what to do for a particular instruction type.
-  // The default behavior is to generalize the instruction type to its subtype
-  // and try visiting the subtype.  All of this should be inlined perfectly,
-  // because there are no virtual functions to get in the way.
-  //
-
-  // Default visit methods. Note that the default op-specific binary op visit
-  // methods call the general visitAffineBinaryOpExpr visit method.
-  RetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { return RetTy(); }
-  RetTy visitAddExpr(AffineBinaryOpExpr expr) {
-    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
-  }
-  RetTy visitMulExpr(AffineBinaryOpExpr expr) {
-    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
-  }
-  RetTy visitModExpr(AffineBinaryOpExpr expr) {
-    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
-  }
-  RetTy visitFloorDivExpr(AffineBinaryOpExpr expr) {
-    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
-  }
-  RetTy visitCeilDivExpr(AffineBinaryOpExpr expr) {
-    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
-  }
-  RetTy visitConstantExpr(AffineConstantExpr expr) { return RetTy(); }
-  RetTy visitDimExpr(AffineDimExpr expr) { return RetTy(); }
-  RetTy visitSymbolExpr(AffineSymbolExpr expr) { return RetTy(); }
-
 private:
   // Walk the operands - each operand is itself walked in post order.
-  RetTy walkOperandsPostOrder(AffineBinaryOpExpr expr) {
-    walkPostOrder(expr.getLHS());
-    walkPostOrder(expr.getRHS());
+  LogicalResult walkOperandsPostOrder(AffineBinaryOpExpr expr) {
+    if (failed(walkPostOrder(expr.getLHS())))
+      return failure();
+    if (failed(walkPostOrder(expr.getRHS())))
+      return failure();
+    return success();
   }
 };
 
@@ -246,7 +312,7 @@ class AffineExprVisitor {
 // expressions are mapped to the same local identifier (same column position in
 // 'localVarCst').
 class SimpleAffineExprFlattener
-    : public AffineExprVisitor<SimpleAffineExprFlattener> {
+    : public AffineExprVisitor<SimpleAffineExprFlattener, LogicalResult> {
 public:
   // Flattend expression layout: [dims, symbols, locals, constant]
   // Stack that holds the LHS and RHS operands while visiting a binary op expr.
@@ -275,13 +341,13 @@ class SimpleAffineExprFlattener
   virtual ~SimpleAffineExprFlattener() = default;
 
   // Visitor method overrides.
-  void visitMulExpr(AffineBinaryOpExpr expr);
-  void visitAddExpr(AffineBinaryOpExpr expr);
-  void visitDimExpr(AffineDimExpr expr);
-  void visitSymbolExpr(AffineSymbolExpr expr);
-  void visitConstantExpr(AffineConstantExpr expr);
-  void visitCeilDivExpr(AffineBinaryOpExpr expr);
-  void visitFloorDivExpr(AffineBinaryOpExpr expr);
+  LogicalResult visitMulExpr(AffineBinaryOpExpr expr);
+  LogicalResult visitAddExpr(AffineBinaryOpExpr expr);
+  LogicalResult visitDimExpr(AffineDimExpr expr);
+  LogicalResult visitSymbolExpr(AffineSymbolExpr expr);
+  LogicalResult visitConstantExpr(AffineConstantExpr expr);
+  LogicalResult visitCeilDivExpr(AffineBinaryOpExpr expr);
+  LogicalResult visitFloorDivExpr(AffineBinaryOpExpr expr);
 
   //
   // t = expr mod c   <=>  t = expr - c*q and c*q <= expr <= c*q + c - 1
@@ -289,7 +355,7 @@ class SimpleAffineExprFlattener
   // A mod expression "expr mod c" is thus flattened by introducing a new local
   // variable q (= expr floordiv c), such that expr mod c is replaced with
   // 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
-  void visitModExpr(AffineBinaryOpExpr expr);
+  LogicalResult visitModExpr(AffineBinaryOpExpr expr);
 
 protected:
   // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
@@ -328,7 +394,7 @@ class SimpleAffineExprFlattener
   //
   // A ceildiv is similarly flattened:
   // t = expr ceildiv c   <=> t =  (expr + c - 1) floordiv c
-  void visitDivExpr(AffineBinaryOpExpr expr, bool isCeil);
+  LogicalResult visitDivExpr(AffineBinaryOpExpr expr, bool isCeil);
 
   int findLocalId(AffineExpr localExpr);
 
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index 713aef767edf669..0e4a8d363946432 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -310,7 +310,8 @@ class AffineMap {
   /// Folds the results of the application of an affine map on the provided
   /// operands to a constant if possible.
   LogicalResult constantFold(ArrayRef<Attribute> operandConstants,
-                             SmallVectorImpl<Attribute> &results) const;
+                             SmallVectorImpl<Attribute> &results,
+                             bool *hasPoison = nullptr) const;
 
   /// Propagates the constant operands into this affine map. Operands are
   /// allowed to be null, at which point they are treated as non-constant. This
@@ -318,9 +319,9 @@ class AffineMap {
   /// which may be equal to the old map if no folding happened. If `results` is
   /// provided and if all expressions in the map were folded to constants,
   /// `results` will contain the values of these constants.
-  AffineMap
-  partialConstantFold(ArrayRef<Attribute> operandConstants,
-                      SmallVectorImpl<int64_t> *results = nullptr) const;
+  AffineMap partialConstantFold(ArrayRef<Attribute> operandConstants,
+                                SmallVectorImpl<int64_t> *results = nullptr,
+                                bool *hasPoison = nullptr) const;
 
   /// Returns the AffineMap resulting from composing `this` with `map`.
   /// The resulting AffineMap has as many AffineDimExpr as `map` and as many
diff --git a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
index ea123ea56025b2f..72e8cebb5c312ad 100644
--- a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
+++ b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
@@ -85,8 +85,10 @@ getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims,
   for (auto expr : exprs) {
     if (!expr.isPureAffine())
       return failure();
-
-    flattener.walkPostOrder(expr);
+    // Flattening fails on a div/mod by a non-positive constant; give up.
+    auto flattenResult = flattener.walkPostOrder(expr);
+    if (failed(flattenResult))
+      return failure();
   }
 
   assert(flattener.operandExprStack.size() == exprs.size());
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 05496e70716a2a1..79547a800135fca 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/IntegerSet.h"
@@ -226,6 +227,8 @@ void AffineDialect::initialize() {
 Operation *AffineDialect::materializeConstant(OpBuilder &builder,
                                               Attribute value, Type type,
                                               Location loc) {
+  if (auto poison = dyn_cast<ub::PoisonAttr>(value))
+    return builder.create<ub::PoisonOp>(loc, type, poison);
   return arith::ConstantOp::materialize(builder, value, type, loc);
 }
 
@@ -580,7 +583,12 @@ OpFoldResult AffineApplyOp::fold(FoldAdaptor adaptor) {
 
   // Otherwise, default to folding the map.
   SmallVector<Attribute, 1> result;
-  if (failed(map.constantFold(adaptor.getMapOperands(), result)))
+  bool hasPoison = false;
+  auto foldResult =
+      map.constantFold(adaptor.getMapOperands(), result, &hasPoison);
+  if (hasPoison)
+    return ub::PoisonAttr::get(getContext());
+  if (failed(foldResult))
     return {};
   return result[0];
 }
@@ -700,6 +708,100 @@ static std::optional<int64_t> getUpperBound(Value iv) {
   return forOp.getConstantUpperBound() - 1;
 }
 
+/// Get a lower or upper (depending on `isUpper`) bound for `expr` given the
+/// constant bounds of its inputs in `constLowerBounds`/`constUpperBounds`.
+/// Return std::nullopt if no bound can be computed. Sums of products with
+/// constant coefficients (c0*d0 + c1*d1 + c2*s0 + ... + c_n) are handled, as
+/// is a top-level floordiv/ceildiv/mod by a positive constant; anything else
+/// (semi-affine exprs, div/mod below a +/*, non-positive divisors) fails.
+/// NOTE(review): looks like a copy of mlir::getBoundForAffineExpr in
+/// AffineExpr.cpp (rebase leftover?); drop it if nothing here calls it.
+static std::optional<int64_t>
+getBoundForExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols,
+                ArrayRef<std::optional<int64_t>> constLowerBounds,
+                ArrayRef<std::optional<int64_t>> constUpperBounds,
+                bool isUpper) {
+  // Handle divs and mods.
+  if (auto binOpExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
+    // If the LHS of a floordiv/ceildiv is bounded and the RHS is a positive
+    // constant, we can compute a bound of the same kind (lower or upper).
+    if (binOpExpr.getKind() == AffineExprKind::FloorDiv) {
+      auto rhsConst = dyn_cast<AffineConstantExpr>(binOpExpr.getRHS());
+      if (!rhsConst || rhsConst.getValue() < 1)
+        return std::nullopt;
+      std::optional<int64_t> bound =
+          getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
+                          constLowerBounds, constUpperBounds, isUpper);
+      if (!bound)
+        return std::nullopt;
+      return mlir::floorDiv(*bound, rhsConst.getValue());
+    }
+    if (binOpExpr.getKind() == AffineExprKind::CeilDiv) {
+      auto rhsConst = dyn_cast<AffineConstantExpr>(binOpExpr.getRHS());
+      if (rhsConst && rhsConst.getValue() >= 1) {
+        std::optional<int64_t> bound =
+            getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
+                            constLowerBounds, constUpperBounds, isUpper);
+        if (!bound)
+          return std::nullopt;
+        return mlir::ceilDiv(*bound, rhsConst.getValue());
+      }
+      return std::nullopt;
+    }
+    if (binOpExpr.getKind() == AffineExprKind::Mod) {
+      // lhs mod c is always <= c - 1 and non-negative. In addition, if `lhs` is
+      // bounded such that lb <= lhs <= ub and lb floordiv c == ub floordiv c
+      // (same "interval"), then lb mod c <= lhs mod c <= ub mod c.
+      auto rhsConst = dyn_cast<AffineConstantExpr>(binOpExpr.getRHS());
+      if (rhsConst && rhsConst.getValue() >= 1) {
+        int64_t rhsConstVal = rhsConst.getValue();
+        // Both LHS bounds are needed whatever `isUpper` is.
+        std::optional<int64_t> lb = getBoundForExpr(
+            binOpExpr.getLHS(), numDims, numSymbols, constLowerBounds,
+            constUpperBounds, /*isUpper=*/false);
+        std::optional<int64_t> ub = getBoundForExpr(
+            binOpExpr.getLHS(), numDims, numSymbols, constLowerBounds,
+            constUpperBounds, /*isUpper=*/true);
+        if (ub && lb &&
+            floorDiv(*lb, rhsConstVal) == floorDiv(*ub, rhsConstVal))
+          return isUpper ? mod(*ub, rhsConstVal) : mod(*lb, rhsConstVal);
+        return isUpper ? rhsConstVal - 1 : 0;
+      }
+    }
+  }
+
+  // Flatten the expression.
+  SimpleAffineExprFlattener flattener(numDims, numSymbols);
+  auto flattenResult = flattener.walkPostOrder(expr);
+  // Flattening fails on a div/mod by a non-positive constant.
+  if (failed(flattenResult))
+    return std::nullopt;
+  ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
+  // TODO: Handle local variables. We can get hold of flattener.localExprs and
+  // get bound on the local expr recursively.
+  if (flattener.numLocals > 0)
+    return std::nullopt;
+  int64_t bound = 0;
+  // Substitute the constant lower or upper bound for the dimensional or
+  // symbolic input depending on `isUpper` to determine the bound.
+  for (unsigned i = 0, e = numDims + numSymbols; i < e; ++i) {
+    if (flattenedExpr[i] > 0) {
+      auto &constBound = isUpper ? constUpperBounds[i] : constLowerBounds[i];
+      if (!constBound)
+        return std::nullopt;
+      bound += *constBound * flattenedExpr[i];
+    } else if (flattenedExpr[i] < 0) {
+      auto &constBound = isUpper ? constLowerBounds[i] : constUpperBounds[i];
+      if (!constBound)
+        return std::nullopt;
+      bound += *constBound * flattenedExpr[i];
+    }
+  }
+  // Constant term.
+  bound += flattenedExpr.back();
+  return bound;
+}
+
 /// Determine a constant upper bound for `expr` if one exists while exploiting
 /// values in `operands`. Note that the upper bound is an inclusive one. `expr`
 /// is guaranteed to be less than or equal to it.
@@ -3379,7 +3481,9 @@ static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map) {
       return failure();
 
     SimpleAffineExprFlattener flattener(map.getNumDims(), map.getNumSymbols());
-    flattener.walkPostOrder(resultExpr);
+    auto flattenResult = flattener.walkPostOrder(resultExpr);
+    if (failed(flattenResult))
+      return failure();
 
     // Fail if the flattened expression has local variables.
     if (flattener.operandExprStack.back().size() !=
diff --git a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
index 89ea3128b0e743f..10df928da8233f0 100644
--- a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
@@ -16,8 +16,9 @@ add_mlir_dialect_library(MLIRAffineDialect
   MLIRDialectUtils
   MLIRIR
   MLIRLoopLikeInterface
-  MLIRMemRefDialect
   MLIRShapedOpInterfaces
   MLIRSideEffectInterfaces
   MLIRValueBoundsOpInterface
+  MLIRMemRefDialect
+  MLIRUBDialect
   )
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index cdceaac11069815..038ceea286a363f 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -1216,7 +1216,7 @@ SimpleAffineExprFlattener::SimpleAffineExprFlattener(unsigned numDims,
 // In case of semi affine multiplication expressions, t = expr * symbolic_expr,
 // introduce a local variable p (= expr * symbolic_expr), and the affine
 // expression expr * symbolic_expr is added to `localExprs`.
-void SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
+LogicalResult SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
   assert(operandExprStack.size() >= 2);
   SmallVector<int64_t, 8> rhs = operandExprStack.back();
   operandExprStack.pop_back();
@@ -1232,7 +1232,7 @@ void SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
     AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
                                              localExprs, context);
     addLocalVariableSemiAffine(a * b, lhs, lhs.size());
-    return;
+    return success();
   }
 
   // Get the RHS constant.
@@ -1240,9 +1240,10 @@ void SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
   for (unsigned i = 0, e = lhs.size(); i < e; i++) {
     lhs[i] *= rhsConst;
   }
+  return success();
 }
 
-void SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) {
+LogicalResult SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) {
   assert(operandExprStack.size() >= 2);
   const auto &rhs = operandExprStack.back();
   auto &lhs = operandExprStack[operandExprStack.size() - 2];
@@ -1253,6 +1254,7 @@ void SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) {
   }
   // Pop off the RHS.
   operandExprStack.pop_back();
+  return success();
 }
 
 //
@@ -1265,7 +1267,7 @@ void SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) {
 // In case of semi-affine modulo expressions, t = expr mod symbolic_expr,
 // introduce a local variable m (= expr mod symbolic_expr), and the affine
 // expression expr mod symbolic_expr is added to `localExprs`.
-void SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
+LogicalResult SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
   assert(operandExprStack.size() >= 2);
 
   SmallVector<int64_t, 8> rhs = operandExprStack.back();
@@ -1283,13 +1285,12 @@ void SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
                                                        localExprs, context);
     AffineExpr modExpr = dividendExpr % divisorExpr;
     addLocalVariableSemiAffine(modExpr, lhs, lhs.size());
-    return;
+    return success();
   }
 
   int64_t rhsConst = rhs[getConstantIndex()];
-  // TODO: handle modulo by zero case when this issue is fixed
-  // at the other places in the IR.
-  assert(rhsConst > 0 && "RHS constant has to be positive");
+  if (rhsConst <= 0)
+    return failure();
 
   // Check if the LHS expression is a multiple of modulo factor.
   unsigned i, e;
@@ -1299,7 +1300,7 @@ void SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
   // If yes, modulo expression here simplifies to zero.
   if (i == lhs.size()) {
     std::fill(lhs.begin(), lhs.end(), 0);
-    return;
+    return success();
   }
 
   // Add a local variable for the quotient, i.e., expr % c is replaced by
@@ -1331,33 +1332,41 @@ void SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
     // Reuse the existing local id.
     lhs[getLocalVarStartIndex() + loc] = -rhsConst;
   }
+  return success();
 }
 
-void SimpleAffineExprFlattener::visitCeilDivExpr(AffineBinaryOpExpr expr) {
-  visitDivExpr(expr, /*isCeil=*/true);
+LogicalResult
+SimpleAffineExprFlattener::visitCeilDivExpr(AffineBinaryOpExpr expr) {
+  return visitDivExpr(expr, /*isCeil=*/true);
 }
-void SimpleAffineExprFlattener::visitFloorDivExpr(AffineBinaryOpExpr expr) {
-  visitDivExpr(expr, /*isCeil=*/false);
+LogicalResult
+SimpleAffineExprFlattener::visitFloorDivExpr(AffineBinaryOpExpr expr) {
+  return visitDivExpr(expr, /*isCeil=*/false);
 }
 
-void SimpleAffineExprFlattener::visitDimExpr(AffineDimExpr expr) {
+LogicalResult SimpleAffineExprFlattener::visitDimExpr(AffineDimExpr expr) {
   operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
   auto &eq = operandExprStack.back();
   assert(expr.getPosition() < numDims && "Inconsistent number of dims");
   eq[getDimStartIndex() + expr.getPosition()] = 1;
+  return success();
 }
 
-void SimpleAffineExprFlattener::visitSymbolExpr(AffineSymbolExpr expr) {
+LogicalResult
+SimpleAffineExprFlattener::visitSymbolExpr(AffineSymbolExpr expr) {
   operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
   auto &eq = operandExprStack.back();
   assert(expr.getPosition() < numSymbols && "inconsistent number of symbols");
   eq[getSymbolStartIndex() + expr.getPosition()] = 1;
+  return success();
 }
 
-void SimpleAffineExprFlattener::visitConstantExpr(AffineConstantExpr expr) {
+LogicalResult
+SimpleAffineExprFlattener::visitConstantExpr(AffineConstantExpr expr) {
   operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
   auto &eq = operandExprStack.back();
   eq[getConstantIndex()] = expr.getValue();
+  return success();
 }
 
 void SimpleAffineExprFlattener::addLocalVariableSemiAffine(
@@ -1388,8 +1397,8 @@ void SimpleAffineExprFlattener::addLocalVariableSemiAffine(
 // or t = expr ceildiv symbolic_expr, introduce a local variable q (= expr
 // floordiv/ceildiv symbolic_expr), and the affine floordiv/ceildiv is added to
 // `localExprs`.
-void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
-                                             bool isCeil) {
+LogicalResult SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
+                                                      bool isCeil) {
   assert(operandExprStack.size() >= 2);
 
   MLIRContext *context = expr.getContext();
@@ -1407,14 +1416,13 @@ void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
                                              localExprs, context);
     AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
     addLocalVariableSemiAffine(divExpr, lhs, lhs.size());
-    return;
+    return success();
   }
 
   // This is a pure affine expr; the RHS is a positive constant.
   int64_t rhsConst = rhs[getConstantIndex()];
-  // TODO: handle division by zero at the same time the issue is
-  // fixed at other places.
-  assert(rhsConst > 0 && "RHS constant has to be positive");
+  if (rhsConst <= 0)
+    return failure();
 
   // Simplify the floordiv, ceildiv if possible by canceling out the greatest
   // common divisors of the numerator and denominator.
@@ -1430,7 +1438,7 @@ void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
   // If the divisor becomes 1, the updated LHS is the result. (The
   // divisor can't be negative since rhsConst is positive).
   if (divisor == 1)
-    return;
+    return success();
 
   // If the divisor cannot be simplified to one, we will have to retain
   // the ceil/floor expr (simplified up until here). Add an existential
@@ -1460,6 +1468,7 @@ void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
     lhs[getLocalVarStartIndex() + numLocals - 1] = 1;
   else
     lhs[getLocalVarStartIndex() + loc] = 1;
+  return success();
 }
 
 // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
@@ -1500,7 +1509,9 @@ AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims,
     expr = simplifySemiAffine(expr, numDims, numSymbols);
 
   SimpleAffineExprFlattener flattener(numDims, numSymbols);
-  flattener.walkPostOrder(expr);
+  // Flattening fails on poison (e.g. division by zero); keep expr as is.
+  if (failed(flattener.walkPostOrder(expr)))
+    return expr;
   ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
   if (!expr.isPureAffine() &&
       expr == getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
@@ -1573,7 +1584,10 @@ std::optional<int64_t> mlir::getBoundForAffineExpr(
   }
   // Flatten the expression.
   SimpleAffineExprFlattener flattener(numDims, numSymbols);
-  flattener.walkPostOrder(expr);
+  auto simpleResult = flattener.walkPostOrder(expr);
+  // Flattening fails on poison (e.g. division by zero); report no bound.
+  if (failed(simpleResult))
+    return std::nullopt;
   ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
   // TODO: Handle local variables. We can get hold of flattener.localExprs and
   // get bound on the local expr recursively.
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 93a8d048e0a61d5..e0293812277a276 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/IR/AffineMap.h"
 #include "AffineMapDetail.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
@@ -59,13 +60,34 @@ class AffineExprConstantFolder {
           expr, [](int64_t lhs, int64_t rhs) { return lhs * rhs; });
     case AffineExprKind::Mod:
       return constantFoldBinExpr(
-          expr, [](int64_t lhs, int64_t rhs) { return mod(lhs, rhs); });
+          expr,
+          [expr, this](int64_t lhs, int64_t rhs) -> std::optional<int64_t> {
+            if (rhs < 1) {
+              hasPoison_ = true;
+              return std::nullopt;
+            }
+            return mod(lhs, rhs);
+          });
     case AffineExprKind::FloorDiv:
       return constantFoldBinExpr(
-          expr, [](int64_t lhs, int64_t rhs) { return floorDiv(lhs, rhs); });
+          expr,
+          [expr, this](int64_t lhs, int64_t rhs) -> std::optional<int64_t> {
+            if (rhs == 0) {
+              hasPoison_ = true;
+              return std::nullopt;
+            }
+            return floorDiv(lhs, rhs);
+          });
     case AffineExprKind::CeilDiv:
       return constantFoldBinExpr(
-          expr, [](int64_t lhs, int64_t rhs) { return ceilDiv(lhs, rhs); });
+          expr,
+          [expr, this](int64_t lhs, int64_t rhs) -> std::optional<int64_t> {
+            if (rhs == 0) {
+              hasPoison_ = true;
+              return std::nullopt;
+            }
+            return ceilDiv(lhs, rhs);
+          });
     case AffineExprKind::Constant:
       return cast<AffineConstantExpr>(expr).getValue();
     case AffineExprKind::DimId:
@@ -387,12 +409,12 @@ std::optional<unsigned> AffineMap::getResultPosition(AffineExpr input) const {
 /// Folds the results of the application of an affine map on the provided
 /// operands to a constant if possible. Returns false if the folding happens,
 /// true otherwise.
-LogicalResult
-AffineMap::constantFold(ArrayRef<Attribute> operandConstants,
-                        SmallVectorImpl<Attribute> &results) const {
+LogicalResult AffineMap::constantFold(ArrayRef<Attribute> operandConstants,
+                                      SmallVectorImpl<Attribute> &results,
+                                      bool *hasPoison) const {
   // Attempt partial folding.
   SmallVector<int64_t, 2> integers;
-  partialConstantFold(operandConstants, &integers);
+  partialConstantFold(operandConstants, &integers, hasPoison);
 
   // If all expressions folded to a constant, populate results with attributes
   // containing those constants.
@@ -406,9 +428,9 @@ AffineMap::constantFold(ArrayRef<Attribute> operandConstants,
   return success();
 }
 
-AffineMap
-AffineMap::partialConstantFold(ArrayRef<Attribute> operandConstants,
-                               SmallVectorImpl<int64_t> *results) const {
+AffineMap AffineMap::partialConstantFold(ArrayRef<Attribute> operandConstants,
+                                         SmallVectorImpl<int64_t> *results,
+                                         bool *hasPoison) const {
   assert(getNumInputs() == operandConstants.size());
 
   // Fold each of the result expressions.
@@ -418,6 +440,10 @@ AffineMap::partialConstantFold(ArrayRef<Attribute> operandConstants,
 
   for (auto expr : getResults()) {
     auto folded = exprFolder.constantFold(expr);
+    if (exprFolder.hasPoison() && hasPoison) {
+      *hasPoison = true;
+      return {};
+    }
     // If did not fold to a constant, keep the original expression, and clear
     // the integer results vector.
     if (folded) {
diff --git a/mlir/test/Dialect/Affine/constant-fold.mlir b/mlir/test/Dialect/Affine/constant-fold.mlir
index cdce39855acdff4..5236b44ddfed967 100644
--- a/mlir/test/Dialect/Affine/constant-fold.mlir
+++ b/mlir/test/Dialect/Affine/constant-fold.mlir
@@ -60,3 +60,24 @@ func.func @affine_min(%variable: index) -> (index, index) {
   // CHECK: return %[[r]], %[[C44]]
   return %0, %1 : index, index
 }
+
+// -----
+
+func.func @affine_apply_poison_division_zero() {
+  // Only here so that the MLIRContext loads the ub dialect.
+  %ub = ub.poison : index
+  %c16 = arith.constant 16 : index
+  %0 = affine.apply affine_map<(d0)[s0] -> (d0 mod (s0 - s0))>(%c16)[%c16]
+  %1 = affine.apply affine_map<(d0)[s0] -> (d0 floordiv (s0 - s0))>(%c16)[%c16]
+  %2 = affine.apply affine_map<(d0)[s0] -> (d0 ceildiv (s0 - s0))>(%c16)[%c16]
+  %alloc = memref.alloc(%0, %1, %2) : memref<?x?x?xi1>
+  %3 = affine.load %alloc[%0, %1, %2] : memref<?x?x?xi1>
+  affine.store %3, %alloc[%0, %1, %2] : memref<?x?x?xi1>
+  return
+}
+
+// CHECK-NOT: affine.apply
+// CHECK: %[[poison:.*]] = ub.poison : index
+// CHECK-NEXT: %[[alloc:.*]] = memref.alloc(%[[poison]], %[[poison]], %[[poison]])
+// CHECK-NEXT: %[[load:.*]] = affine.load %[[alloc]][%[[poison]], %[[poison]], %[[poison]]] : memref<?x?x?xi1>
+// CHECK-NEXT: affine.store %[[load]], %alloc[%[[poison]], %[[poison]], %[[poison]]] : memref<?x?x?xi1>

>From 3ee25c8f7ca867979c2a7fce63763114075bb4f1 Mon Sep 17 00:00:00 2001
From: lipracer <lipracer at gmail.com>
Date: Sat, 4 Nov 2023 12:58:56 +0800
Subject: [PATCH 2/6] refine

---
 mlir/lib/Analysis/FlatLinearValueConstraints.cpp | 4 +++-
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp         | 6 ++++--
 mlir/lib/Dialect/Affine/IR/CMakeLists.txt        | 2 +-
 3 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
index 72e8cebb5c312ad..db32921b041a471 100644
--- a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
+++ b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
@@ -67,7 +67,9 @@ struct AffineExprFlattener : public SimpleAffineExprFlattener {
 } // namespace
 
 // Flattens the expressions in map. Returns failure if 'expr' was unable to be
-// flattened (i.e., semi-affine expressions not handled yet).
+// flattened.For example two specific cases:
+// 1. semi-affine expressions not handled yet.
+// 2. has poison expression (i.e., division by zero).
 static LogicalResult
 getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims,
                         unsigned numSymbols,
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 79547a800135fca..c134407d2c50091 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -786,12 +786,14 @@ getBoundForExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols,
   // symbolic input depending on `isUpper` to determine the bound.
   for (unsigned i = 0, e = numDims + numSymbols; i < e; ++i) {
     if (flattenedExpr[i] > 0) {
-      auto &constBound = isUpper ? constUpperBounds[i] : constLowerBounds[i];
+      const std::optional<int64_t> &constBound =
+          isUpper ? constUpperBounds[i] : constLowerBounds[i];
       if (!constBound)
         return std::nullopt;
       bound += *constBound * flattenedExpr[i];
     } else if (flattenedExpr[i] < 0) {
-      auto &constBound = isUpper ? constLowerBounds[i] : constUpperBounds[i];
+      const std::optional<int64_t> &constBound =
+          isUpper ? constLowerBounds[i] : constUpperBounds[i];
       if (!constBound)
         return std::nullopt;
       bound += *constBound * flattenedExpr[i];
diff --git a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
index 10df928da8233f0..f1d98ab30f92355 100644
--- a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
@@ -16,9 +16,9 @@ add_mlir_dialect_library(MLIRAffineDialect
   MLIRDialectUtils
   MLIRIR
   MLIRLoopLikeInterface
+  MLIRMemRefDialect
   MLIRShapedOpInterfaces
   MLIRSideEffectInterfaces
   MLIRValueBoundsOpInterface
-  MLIRMemRefDialect
   MLIRUBDialect
   )

>From eeef95d7b8ce98b7d0d25a5b1514dffe72c6b444 Mon Sep 17 00:00:00 2001
From: Javier Setoain <jsetoain at users.noreply.github.com>
Date: Mon, 6 Nov 2023 15:30:12 +0000
Subject: [PATCH 3/6] Typo

---
 mlir/lib/Analysis/FlatLinearValueConstraints.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
index db32921b041a471..69846a356e0cc42 100644
--- a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
+++ b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
@@ -67,7 +67,7 @@ struct AffineExprFlattener : public SimpleAffineExprFlattener {
 } // namespace
 
 // Flattens the expressions in map. Returns failure if 'expr' was unable to be
-// flattened.For example two specific cases:
+// flattened. For example two specific cases:
 // 1. semi-affine expressions not handled yet.
 // 2. has poison expression (i.e., division by zero).
 static LogicalResult

>From 55a54f99f200e5c4ae2efd1a61a055a0e50b6c59 Mon Sep 17 00:00:00 2001
From: Javier Setoain <jsetoain at users.noreply.github.com>
Date: Mon, 6 Nov 2023 15:33:41 +0000
Subject: [PATCH 4/6] Preserve alphabetic order in dependency list

---
 mlir/lib/Dialect/Affine/IR/CMakeLists.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
index f1d98ab30f92355..9e3c1161fd92a09 100644
--- a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
@@ -19,6 +19,6 @@ add_mlir_dialect_library(MLIRAffineDialect
   MLIRMemRefDialect
   MLIRShapedOpInterfaces
   MLIRSideEffectInterfaces
-  MLIRValueBoundsOpInterface
   MLIRUBDialect
+  MLIRValueBoundsOpInterface
   )

>From 7b9149c2b5747befd6db5c3fe626a4d167ed0cce Mon Sep 17 00:00:00 2001
From: lipracer <lipracer at gmail.com>
Date: Thu, 16 Nov 2023 17:23:44 +0800
Subject: [PATCH 5/6] rebase

---
 mlir/include/mlir/IR/AffineExprVisitor.h | 38 ++++++++++++------------
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp |  8 ++---
 2 files changed, 23 insertions(+), 23 deletions(-)

diff --git a/mlir/include/mlir/IR/AffineExprVisitor.h b/mlir/include/mlir/IR/AffineExprVisitor.h
index d603cd90b0cf949..2860e73c8f42839 100644
--- a/mlir/include/mlir/IR/AffineExprVisitor.h
+++ b/mlir/include/mlir/IR/AffineExprVisitor.h
@@ -76,31 +76,31 @@ class AffineExprVisitorBase {
     auto self = static_cast<SubClass *>(this);
     switch (expr.getKind()) {
     case AffineExprKind::Add: {
-      auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+      auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
       return self->visitAddExpr(binOpExpr);
     }
     case AffineExprKind::Mul: {
-      auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+      auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
       return self->visitMulExpr(binOpExpr);
     }
     case AffineExprKind::Mod: {
-      auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+      auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
       return self->visitModExpr(binOpExpr);
     }
     case AffineExprKind::FloorDiv: {
-      auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+      auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
       return self->visitFloorDivExpr(binOpExpr);
     }
     case AffineExprKind::CeilDiv: {
-      auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+      auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
       return self->visitCeilDivExpr(binOpExpr);
     }
     case AffineExprKind::Constant:
-      return self->visitConstantExpr(expr.cast<AffineConstantExpr>());
+      return self->visitConstantExpr(cast<AffineConstantExpr>(expr));
     case AffineExprKind::DimId:
-      return self->visitDimExpr(expr.cast<AffineDimExpr>());
+      return self->visitDimExpr(cast<AffineDimExpr>(expr));
     case AffineExprKind::SymbolId:
-      return self->visitSymbolExpr(expr.cast<AffineSymbolExpr>());
+      return self->visitSymbolExpr(cast<AffineSymbolExpr>(expr));
     }
     llvm_unreachable("Unknown AffineExpr");
   }
@@ -174,11 +174,11 @@ class AffineExprVisitor : public AffineExprVisitorBase<SubClass, RetTy> {
       return self->visitCeilDivExpr(binOpExpr);
     }
     case AffineExprKind::Constant:
-      return self->visitConstantExpr(expr.cast<AffineConstantExpr>());
+      return self->visitConstantExpr(cast<AffineConstantExpr>(expr));
     case AffineExprKind::DimId:
-      return self->visitDimExpr(expr.cast<AffineDimExpr>());
+      return self->visitDimExpr(cast<AffineDimExpr>(expr));
     case AffineExprKind::SymbolId:
-      return self->visitSymbolExpr(expr.cast<AffineSymbolExpr>());
+      return self->visitSymbolExpr(cast<AffineSymbolExpr>(expr));
     }
     llvm_unreachable("Unknown AffineExpr");
   }
@@ -205,41 +205,41 @@ class AffineExprVisitor<SubClass, LogicalResult>
     auto self = static_cast<SubClass *>(this);
     switch (expr.getKind()) {
     case AffineExprKind::Add: {
-      auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+      auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
       if (failed(walkOperandsPostOrder(binOpExpr)))
         return failure();
       return self->visitAddExpr(binOpExpr);
     }
     case AffineExprKind::Mul: {
-      auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+      auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
       if (failed(walkOperandsPostOrder(binOpExpr)))
         return failure();
       return self->visitMulExpr(binOpExpr);
     }
     case AffineExprKind::Mod: {
-      auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+      auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
       if (failed(walkOperandsPostOrder(binOpExpr)))
         return failure();
       return self->visitModExpr(binOpExpr);
     }
     case AffineExprKind::FloorDiv: {
-      auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+      auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
       if (failed(walkOperandsPostOrder(binOpExpr)))
         return failure();
       return self->visitFloorDivExpr(binOpExpr);
     }
     case AffineExprKind::CeilDiv: {
-      auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+      auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
       if (failed(walkOperandsPostOrder(binOpExpr)))
         return failure();
       return self->visitCeilDivExpr(binOpExpr);
     }
     case AffineExprKind::Constant:
-      return self->visitConstantExpr(expr.cast<AffineConstantExpr>());
+      return self->visitConstantExpr(cast<AffineConstantExpr>(expr));
     case AffineExprKind::DimId:
-      return self->visitDimExpr(expr.cast<AffineDimExpr>());
+      return self->visitDimExpr(cast<AffineDimExpr>(expr));
     case AffineExprKind::SymbolId:
-      return self->visitSymbolExpr(expr.cast<AffineSymbolExpr>());
+      return self->visitSymbolExpr(cast<AffineSymbolExpr>(expr));
     }
     llvm_unreachable("Unknown AffineExpr");
   }
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index c134407d2c50091..6c9ff98d210d073 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -722,11 +722,11 @@ getBoundForExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols,
                 ArrayRef<std::optional<int64_t>> constUpperBounds,
                 bool isUpper) {
   // Handle divs and mods.
-  if (auto binOpExpr = expr.dyn_cast<AffineBinaryOpExpr>()) {
+  if (auto binOpExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
     // If the LHS of a floor or ceil is bounded and the RHS is a constant, we
     // can compute an upper bound.
     if (binOpExpr.getKind() == AffineExprKind::FloorDiv) {
-      auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
+      auto rhsConst = dyn_cast<AffineConstantExpr>(binOpExpr.getRHS());
       if (!rhsConst || rhsConst.getValue() < 1)
         return std::nullopt;
       std::optional<int64_t> bound =
@@ -737,7 +737,7 @@ getBoundForExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols,
       return mlir::floorDiv(*bound, rhsConst.getValue());
     }
     if (binOpExpr.getKind() == AffineExprKind::CeilDiv) {
-      auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
+      auto rhsConst = dyn_cast<AffineConstantExpr>(binOpExpr.getRHS());
       if (rhsConst && rhsConst.getValue() >= 1) {
         std::optional<int64_t> bound =
             getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
@@ -752,7 +752,7 @@ getBoundForExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols,
       // lhs mod c is always <= c - 1 and non-negative. In addition, if `lhs` is
       // bounded such that lb <= lhs <= ub and lb floordiv c == ub floordiv c
       // (same "interval"), then lb mod c <= lhs mod c <= ub mod c.
-      auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
+      auto rhsConst = dyn_cast<AffineConstantExpr>(binOpExpr.getRHS());
       if (rhsConst && rhsConst.getValue() >= 1) {
         int64_t rhsConstVal = rhsConst.getValue();
         std::optional<int64_t> lb =

>From 0309e94917b9e170bdc617838c4067e931247ccc Mon Sep 17 00:00:00 2001
From: lipracer <lipracer at gmail.com>
Date: Sun, 19 Nov 2023 00:45:25 +0800
Subject: [PATCH 6/6] rebase

---
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 96 ------------------------
 1 file changed, 96 deletions(-)

diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 6c9ff98d210d073..d22a7539fb75018 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -708,102 +708,6 @@ static std::optional<int64_t> getUpperBound(Value iv) {
   return forOp.getConstantUpperBound() - 1;
 }
 
-/// Get a lower or upper (depending on `isUpper`) bound for `expr` while using
-/// the constant lower and upper bounds for its inputs provided in
-/// `constLowerBounds` and `constUpperBounds`. Return std::nullopt if such a
-/// bound can't be computed. This method only handles simple sum of product
-/// expressions (w.r.t constant coefficients) so as to not depend on anything
-/// heavyweight in `Analysis`. Expressions of the form: c0*d0 + c1*d1 + c2*s0 +
-/// ... + c_n are handled. Expressions involving floordiv, ceildiv, mod or
-/// semi-affine ones will lead std::nullopt being returned.
-static std::optional<int64_t>
-getBoundForExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols,
-                ArrayRef<std::optional<int64_t>> constLowerBounds,
-                ArrayRef<std::optional<int64_t>> constUpperBounds,
-                bool isUpper) {
-  // Handle divs and mods.
-  if (auto binOpExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
-    // If the LHS of a floor or ceil is bounded and the RHS is a constant, we
-    // can compute an upper bound.
-    if (binOpExpr.getKind() == AffineExprKind::FloorDiv) {
-      auto rhsConst = dyn_cast<AffineConstantExpr>(binOpExpr.getRHS());
-      if (!rhsConst || rhsConst.getValue() < 1)
-        return std::nullopt;
-      std::optional<int64_t> bound =
-          getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
-                          constLowerBounds, constUpperBounds, isUpper);
-      if (!bound)
-        return std::nullopt;
-      return mlir::floorDiv(*bound, rhsConst.getValue());
-    }
-    if (binOpExpr.getKind() == AffineExprKind::CeilDiv) {
-      auto rhsConst = dyn_cast<AffineConstantExpr>(binOpExpr.getRHS());
-      if (rhsConst && rhsConst.getValue() >= 1) {
-        std::optional<int64_t> bound =
-            getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
-                            constLowerBounds, constUpperBounds, isUpper);
-        if (!bound)
-          return std::nullopt;
-        return mlir::ceilDiv(*bound, rhsConst.getValue());
-      }
-      return std::nullopt;
-    }
-    if (binOpExpr.getKind() == AffineExprKind::Mod) {
-      // lhs mod c is always <= c - 1 and non-negative. In addition, if `lhs` is
-      // bounded such that lb <= lhs <= ub and lb floordiv c == ub floordiv c
-      // (same "interval"), then lb mod c <= lhs mod c <= ub mod c.
-      auto rhsConst = dyn_cast<AffineConstantExpr>(binOpExpr.getRHS());
-      if (rhsConst && rhsConst.getValue() >= 1) {
-        int64_t rhsConstVal = rhsConst.getValue();
-        std::optional<int64_t> lb =
-            getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
-                            constLowerBounds, constUpperBounds,
-                            /*isUpper=*/false);
-        std::optional<int64_t> ub =
-            getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
-                            constLowerBounds, constUpperBounds, isUpper);
-        if (ub && lb &&
-            floorDiv(*lb, rhsConstVal) == floorDiv(*ub, rhsConstVal))
-          return isUpper ? mod(*ub, rhsConstVal) : mod(*lb, rhsConstVal);
-        return isUpper ? rhsConstVal - 1 : 0;
-      }
-    }
-  }
-
-  // Flatten the expression.
-  SimpleAffineExprFlattener flattener(numDims, numSymbols);
-  auto flattenResult = flattener.walkPostOrder(expr);
-  // has poison expression
-  if (failed(flattenResult))
-    return std::nullopt;
-  ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
-  // TODO: Handle local variables. We can get hold of flattener.localExprs and
-  // get bound on the local expr recursively.
-  if (flattener.numLocals > 0)
-    return std::nullopt;
-  int64_t bound = 0;
-  // Substitute the constant lower or upper bound for the dimensional or
-  // symbolic input depending on `isUpper` to determine the bound.
-  for (unsigned i = 0, e = numDims + numSymbols; i < e; ++i) {
-    if (flattenedExpr[i] > 0) {
-      const std::optional<int64_t> &constBound =
-          isUpper ? constUpperBounds[i] : constLowerBounds[i];
-      if (!constBound)
-        return std::nullopt;
-      bound += *constBound * flattenedExpr[i];
-    } else if (flattenedExpr[i] < 0) {
-      const std::optional<int64_t> &constBound =
-          isUpper ? constLowerBounds[i] : constUpperBounds[i];
-      if (!constBound)
-        return std::nullopt;
-      bound += *constBound * flattenedExpr[i];
-    }
-  }
-  // Constant term.
-  bound += flattenedExpr.back();
-  return bound;
-}
-
 /// Determine a constant upper bound for `expr` if one exists while exploiting
 /// values in `operands`. Note that the upper bound is an inclusive one. `expr`
 /// is guaranteed to be less than or equal to it.



More information about the Mlir-commits mailing list