[Mlir-commits] [mlir] [mlir][affine] remove divide zero check when simplifer affineMap (#64622) (PR #68519)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Nov 2 07:24:22 PDT 2023
https://github.com/lipracer updated https://github.com/llvm/llvm-project/pull/68519
>From dad2421a98d78f40efd23efc8b9eb9ee7a97cb05 Mon Sep 17 00:00:00 2001
From: lipracer <lipracer at gmail.com>
Date: Sun, 8 Oct 2023 18:31:59 +0800
Subject: [PATCH 1/4] [mlir][affine] remove divide zero check when simplifer
affineMap (#64622)
When an affineApplyOp has poison semantics (e.g. a division or modulo by zero), we should not fold the op, but we should also not crash.
---
mlir/include/mlir/IR/AffineExprVisitor.h | 163 +++++++++++++-----
mlir/include/mlir/IR/AffineMap.h | 9 +-
.../Analysis/FlatLinearValueConstraints.cpp | 3 +-
mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 101 ++++++++++-
mlir/lib/Dialect/Affine/IR/CMakeLists.txt | 3 +-
mlir/lib/IR/AffineExpr.cpp | 59 ++++---
mlir/lib/IR/AffineMap.cpp | 54 ++++--
mlir/lib/Tools/mlir-opt/MlirOptMain.cpp | 17 +-
mlir/test/Dialect/Affine/constant-fold.mlir | 19 ++
9 files changed, 334 insertions(+), 94 deletions(-)
diff --git a/mlir/include/mlir/IR/AffineExprVisitor.h b/mlir/include/mlir/IR/AffineExprVisitor.h
index f6216614c2238e1..7a3616f024d0a9e 100644
--- a/mlir/include/mlir/IR/AffineExprVisitor.h
+++ b/mlir/include/mlir/IR/AffineExprVisitor.h
@@ -14,6 +14,7 @@
#define MLIR_IR_AFFINEEXPRVISITOR_H
#include "mlir/IR/AffineExpr.h"
+#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/ArrayRef.h"
namespace mlir {
@@ -65,8 +66,80 @@ namespace mlir {
/// just as efficient as having your own switch instruction over the instruction
/// opcode.
+template <typename SubClass, typename RetTy>
+class AffineExprVisitorBase {
+public:
+ // Function to visit an AffineExpr.
+ RetTy visit(AffineExpr expr) {
+ static_assert(std::is_base_of<AffineExprVisitorBase, SubClass>::value,
+ "Must instantiate with a derived type of AffineExprVisitor");
+ switch (expr.getKind()) {
+ case AffineExprKind::Add: {
+ auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+ return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
+ }
+ case AffineExprKind::Mul: {
+ auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+ return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
+ }
+ case AffineExprKind::Mod: {
+ auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+ return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
+ }
+ case AffineExprKind::FloorDiv: {
+ auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+ return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
+ }
+ case AffineExprKind::CeilDiv: {
+ auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+ return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
+ }
+ case AffineExprKind::Constant:
+ return static_cast<SubClass *>(this)->visitConstantExpr(
+ expr.cast<AffineConstantExpr>());
+ case AffineExprKind::DimId:
+ return static_cast<SubClass *>(this)->visitDimExpr(
+ expr.cast<AffineDimExpr>());
+ case AffineExprKind::SymbolId:
+ return static_cast<SubClass *>(this)->visitSymbolExpr(
+ expr.cast<AffineSymbolExpr>());
+ }
+ llvm_unreachable("Unknown AffineExpr");
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Visitation functions... these functions provide default fallbacks in case
+ // the user does not specify what to do for a particular instruction type.
+ // The default behavior is to generalize the instruction type to its subtype
+ // and try visiting the subtype. All of this should be inlined perfectly,
+ // because there are no virtual functions to get in the way.
+ //
+
+ // Default visit methods. Note that the default op-specific binary op visit
+ // methods call the general visitAffineBinaryOpExpr visit method.
+ RetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { return RetTy(); }
+ RetTy visitAddExpr(AffineBinaryOpExpr expr) {
+ return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+ }
+ RetTy visitMulExpr(AffineBinaryOpExpr expr) {
+ return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+ }
+ RetTy visitModExpr(AffineBinaryOpExpr expr) {
+ return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+ }
+ RetTy visitFloorDivExpr(AffineBinaryOpExpr expr) {
+ return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+ }
+ RetTy visitCeilDivExpr(AffineBinaryOpExpr expr) {
+ return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+ }
+ RetTy visitConstantExpr(AffineConstantExpr expr) { return RetTy(); }
+ RetTy visitDimExpr(AffineDimExpr expr) { return RetTy(); }
+ RetTy visitSymbolExpr(AffineSymbolExpr expr) { return RetTy(); }
+};
+
template <typename SubClass, typename RetTy = void>
-class AffineExprVisitor {
+class AffineExprVisitor : public AffineExprVisitorBase<SubClass, RetTy> {
//===--------------------------------------------------------------------===//
// Interface code - This is the public interface of the AffineExprVisitor
// that you use to visit affine expressions...
@@ -113,29 +186,54 @@ class AffineExprVisitor {
}
}
- // Function to visit an AffineExpr.
- RetTy visit(AffineExpr expr) {
+private:
+ // Walk the operands - each operand is itself walked in post order.
+ RetTy walkOperandsPostOrder(AffineBinaryOpExpr expr) {
+ walkPostOrder(expr.getLHS());
+ walkPostOrder(expr.getRHS());
+ }
+};
+
+template <typename SubClass>
+class AffineExprVisitor<SubClass, LogicalResult>
+ : public AffineExprVisitorBase<SubClass, LogicalResult> {
+ //===--------------------------------------------------------------------===//
+ // Interface code - This is the public interface of the AffineExprVisitor
+ // that you use to visit affine expressions...
+public:
+ // Function to walk an AffineExpr (in post order).
+ LogicalResult walkPostOrder(AffineExpr expr) {
static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
"Must instantiate with a derived type of AffineExprVisitor");
switch (expr.getKind()) {
case AffineExprKind::Add: {
auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+ if (failed(walkOperandsPostOrder(binOpExpr)))
+ return failure();
return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
}
case AffineExprKind::Mul: {
auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+ if (failed(walkOperandsPostOrder(binOpExpr)))
+ return failure();
return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
}
case AffineExprKind::Mod: {
auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+ if (failed(walkOperandsPostOrder(binOpExpr)))
+ return failure();
return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
}
case AffineExprKind::FloorDiv: {
auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+ if (failed(walkOperandsPostOrder(binOpExpr)))
+ return failure();
return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
}
case AffineExprKind::CeilDiv: {
auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+ if (failed(walkOperandsPostOrder(binOpExpr)))
+ return failure();
return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
}
case AffineExprKind::Constant:
@@ -151,41 +249,14 @@ class AffineExprVisitor {
llvm_unreachable("Unknown AffineExpr");
}
- //===--------------------------------------------------------------------===//
- // Visitation functions... these functions provide default fallbacks in case
- // the user does not specify what to do for a particular instruction type.
- // The default behavior is to generalize the instruction type to its subtype
- // and try visiting the subtype. All of this should be inlined perfectly,
- // because there are no virtual functions to get in the way.
- //
-
- // Default visit methods. Note that the default op-specific binary op visit
- // methods call the general visitAffineBinaryOpExpr visit method.
- RetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { return RetTy(); }
- RetTy visitAddExpr(AffineBinaryOpExpr expr) {
- return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
- }
- RetTy visitMulExpr(AffineBinaryOpExpr expr) {
- return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
- }
- RetTy visitModExpr(AffineBinaryOpExpr expr) {
- return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
- }
- RetTy visitFloorDivExpr(AffineBinaryOpExpr expr) {
- return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
- }
- RetTy visitCeilDivExpr(AffineBinaryOpExpr expr) {
- return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
- }
- RetTy visitConstantExpr(AffineConstantExpr expr) { return RetTy(); }
- RetTy visitDimExpr(AffineDimExpr expr) { return RetTy(); }
- RetTy visitSymbolExpr(AffineSymbolExpr expr) { return RetTy(); }
-
private:
// Walk the operands - each operand is itself walked in post order.
- RetTy walkOperandsPostOrder(AffineBinaryOpExpr expr) {
- walkPostOrder(expr.getLHS());
- walkPostOrder(expr.getRHS());
+ LogicalResult walkOperandsPostOrder(AffineBinaryOpExpr expr) {
+ if (failed(walkPostOrder(expr.getLHS())))
+ return failure();
+ if (failed(walkPostOrder(expr.getRHS())))
+ return failure();
+ return success();
}
};
@@ -246,7 +317,7 @@ class AffineExprVisitor {
// expressions are mapped to the same local identifier (same column position in
// 'localVarCst').
class SimpleAffineExprFlattener
- : public AffineExprVisitor<SimpleAffineExprFlattener> {
+ : public AffineExprVisitor<SimpleAffineExprFlattener, LogicalResult> {
public:
// Flattend expression layout: [dims, symbols, locals, constant]
// Stack that holds the LHS and RHS operands while visiting a binary op expr.
@@ -275,13 +346,13 @@ class SimpleAffineExprFlattener
virtual ~SimpleAffineExprFlattener() = default;
// Visitor method overrides.
- void visitMulExpr(AffineBinaryOpExpr expr);
- void visitAddExpr(AffineBinaryOpExpr expr);
- void visitDimExpr(AffineDimExpr expr);
- void visitSymbolExpr(AffineSymbolExpr expr);
- void visitConstantExpr(AffineConstantExpr expr);
- void visitCeilDivExpr(AffineBinaryOpExpr expr);
- void visitFloorDivExpr(AffineBinaryOpExpr expr);
+ LogicalResult visitMulExpr(AffineBinaryOpExpr expr);
+ LogicalResult visitAddExpr(AffineBinaryOpExpr expr);
+ LogicalResult visitDimExpr(AffineDimExpr expr);
+ LogicalResult visitSymbolExpr(AffineSymbolExpr expr);
+ LogicalResult visitConstantExpr(AffineConstantExpr expr);
+ LogicalResult visitCeilDivExpr(AffineBinaryOpExpr expr);
+ LogicalResult visitFloorDivExpr(AffineBinaryOpExpr expr);
//
// t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1
@@ -289,7 +360,7 @@ class SimpleAffineExprFlattener
// A mod expression "expr mod c" is thus flattened by introducing a new local
// variable q (= expr floordiv c), such that expr mod c is replaced with
// 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
- void visitModExpr(AffineBinaryOpExpr expr);
+ LogicalResult visitModExpr(AffineBinaryOpExpr expr);
protected:
// Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
@@ -328,7 +399,7 @@ class SimpleAffineExprFlattener
//
// A ceildiv is similarly flattened:
// t = expr ceildiv c <=> t = (expr + c - 1) floordiv c
- void visitDivExpr(AffineBinaryOpExpr expr, bool isCeil);
+ LogicalResult visitDivExpr(AffineBinaryOpExpr expr, bool isCeil);
int findLocalId(AffineExpr localExpr);
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index 5af7835258f6bd2..78a0ef57e15c6d8 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -298,7 +298,8 @@ class AffineMap {
/// Folds the results of the application of an affine map on the provided
/// operands to a constant if possible.
LogicalResult constantFold(ArrayRef<Attribute> operandConstants,
- SmallVectorImpl<Attribute> &results) const;
+ SmallVectorImpl<Attribute> &results,
+ bool *hasPoison = nullptr) const;
/// Propagates the constant operands into this affine map. Operands are
/// allowed to be null, at which point they are treated as non-constant. This
@@ -306,9 +307,9 @@ class AffineMap {
/// which may be equal to the old map if no folding happened. If `results` is
/// provided and if all expressions in the map were folded to constants,
/// `results` will contain the values of these constants.
- AffineMap
- partialConstantFold(ArrayRef<Attribute> operandConstants,
- SmallVectorImpl<int64_t> *results = nullptr) const;
+ AffineMap partialConstantFold(ArrayRef<Attribute> operandConstants,
+ SmallVectorImpl<int64_t> *results = nullptr,
+ bool *hasPoison = nullptr) const;
/// Returns the AffineMap resulting from composing `this` with `map`.
/// The resulting AffineMap has as many AffineDimExpr as `map` and as many
diff --git a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
index 382d05f3b2d4851..b8ff4d82be697b1 100644
--- a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
+++ b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
@@ -86,7 +86,8 @@ getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims,
if (!expr.isPureAffine())
return failure();
- flattener.walkPostOrder(expr);
+ auto flattenResult = flattener.walkPostOrder(expr);
+ assert(succeeded(flattenResult) && "affine expr containts poison expr");
}
assert(flattener.operandExprStack.size() == exprs.size());
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index ba4285bd52394f3..6970dc2180b49f2 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/IntegerSet.h"
@@ -226,6 +227,8 @@ void AffineDialect::initialize() {
Operation *AffineDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
+ if (auto poison = dyn_cast<ub::PoisonAttr>(value))
+ return builder.create<ub::PoisonOp>(loc, type, poison);
return arith::ConstantOp::materialize(builder, value, type, loc);
}
@@ -580,7 +583,12 @@ OpFoldResult AffineApplyOp::fold(FoldAdaptor adaptor) {
// Otherwise, default to folding the map.
SmallVector<Attribute, 1> result;
- if (failed(map.constantFold(adaptor.getMapOperands(), result)))
+ bool hasPoison = false;
+ auto foldResult =
+ map.constantFold(adaptor.getMapOperands(), result, &hasPoison);
+ if (hasPoison)
+ return ub::PoisonAttr::get(getContext());
+ if (failed(foldResult))
return {};
return result[0];
}
@@ -700,6 +708,94 @@ static std::optional<int64_t> getUpperBound(Value iv) {
return forOp.getConstantUpperBound() - 1;
}
+/// Get a lower or upper (depending on `isUpper`) bound for `expr` while using
+/// the constant lower and upper bounds for its inputs provided in
+/// `constLowerBounds` and `constUpperBounds`. Return std::nullopt if such a
+/// bound can't be computed. This method only handles simple sum of product
+/// expressions (w.r.t constant coefficients) so as to not depend on anything
+/// heavyweight in `Analysis`. Expressions of the form: c0*d0 + c1*d1 + c2*s0 +
+/// ... + c_n are handled. Expressions involving floordiv, ceildiv, mod or
+/// semi-affine ones will lead std::nullopt being returned.
+static std::optional<int64_t>
+getBoundForExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols,
+ ArrayRef<std::optional<int64_t>> constLowerBounds,
+ ArrayRef<std::optional<int64_t>> constUpperBounds,
+ bool isUpper) {
+ // Handle divs and mods.
+ if (auto binOpExpr = expr.dyn_cast<AffineBinaryOpExpr>()) {
+ // If the LHS of a floor or ceil is bounded and the RHS is a constant, we
+ // can compute an upper bound.
+ if (binOpExpr.getKind() == AffineExprKind::FloorDiv) {
+ auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
+ if (!rhsConst || rhsConst.getValue() < 1)
+ return std::nullopt;
+ auto bound = getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
+ constLowerBounds, constUpperBounds, isUpper);
+ if (!bound)
+ return std::nullopt;
+ return mlir::floorDiv(*bound, rhsConst.getValue());
+ }
+ if (binOpExpr.getKind() == AffineExprKind::CeilDiv) {
+ auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
+ if (rhsConst && rhsConst.getValue() >= 1) {
+ auto bound =
+ getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
+ constLowerBounds, constUpperBounds, isUpper);
+ if (!bound)
+ return std::nullopt;
+ return mlir::ceilDiv(*bound, rhsConst.getValue());
+ }
+ return std::nullopt;
+ }
+ if (binOpExpr.getKind() == AffineExprKind::Mod) {
+ // lhs mod c is always <= c - 1 and non-negative. In addition, if `lhs` is
+ // bounded such that lb <= lhs <= ub and lb floordiv c == ub floordiv c
+ // (same "interval"), then lb mod c <= lhs mod c <= ub mod c.
+ auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
+ if (rhsConst && rhsConst.getValue() >= 1) {
+ int64_t rhsConstVal = rhsConst.getValue();
+ auto lb = getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
+ constLowerBounds, constUpperBounds,
+ /*isUpper=*/false);
+ auto ub = getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
+ constLowerBounds, constUpperBounds, isUpper);
+ if (ub && lb &&
+ floorDiv(*lb, rhsConstVal) == floorDiv(*ub, rhsConstVal))
+ return isUpper ? mod(*ub, rhsConstVal) : mod(*lb, rhsConstVal);
+ return isUpper ? rhsConstVal - 1 : 0;
+ }
+ }
+ }
+ // Flatten the expression.
+ SimpleAffineExprFlattener flattener(numDims, numSymbols);
+ auto flattenResult = flattener.walkPostOrder(expr);
+ assert(succeeded(flattenResult) && "affine expr containts poison expr");
+ ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
+ // TODO: Handle local variables. We can get hold of flattener.localExprs and
+ // get bound on the local expr recursively.
+ if (flattener.numLocals > 0)
+ return std::nullopt;
+ int64_t bound = 0;
+ // Substitute the constant lower or upper bound for the dimensional or
+ // symbolic input depending on `isUpper` to determine the bound.
+ for (unsigned i = 0, e = numDims + numSymbols; i < e; ++i) {
+ if (flattenedExpr[i] > 0) {
+ auto &constBound = isUpper ? constUpperBounds[i] : constLowerBounds[i];
+ if (!constBound)
+ return std::nullopt;
+ bound += *constBound * flattenedExpr[i];
+ } else if (flattenedExpr[i] < 0) {
+ auto &constBound = isUpper ? constLowerBounds[i] : constUpperBounds[i];
+ if (!constBound)
+ return std::nullopt;
+ bound += *constBound * flattenedExpr[i];
+ }
+ }
+ // Constant term.
+ bound += flattenedExpr.back();
+ return bound;
+}
+
/// Determine a constant upper bound for `expr` if one exists while exploiting
/// values in `operands`. Note that the upper bound is an inclusive one. `expr`
/// is guaranteed to be less than or equal to it.
@@ -3379,7 +3475,8 @@ static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map) {
return failure();
SimpleAffineExprFlattener flattener(map.getNumDims(), map.getNumSymbols());
- flattener.walkPostOrder(resultExpr);
+ auto flattenResult = flattener.walkPostOrder(resultExpr);
+ assert(succeeded(flattenResult) && "affine expr containts poison expr");
// Fail if the flattened expression has local variables.
if (flattener.operandExprStack.back().size() !=
diff --git a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
index 89ea3128b0e743f..10df928da8233f0 100644
--- a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
@@ -16,8 +16,9 @@ add_mlir_dialect_library(MLIRAffineDialect
MLIRDialectUtils
MLIRIR
MLIRLoopLikeInterface
- MLIRMemRefDialect
MLIRShapedOpInterfaces
MLIRSideEffectInterfaces
MLIRValueBoundsOpInterface
+ MLIRMemRefDialect
+ MLIRUBDialect
)
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index 4b7ec89a842bd65..5aff41b48b724d1 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -511,7 +511,6 @@ unsigned AffineSymbolExpr::getPosition() const {
AffineExpr mlir::getAffineSymbolExpr(unsigned position, MLIRContext *context) {
return getAffineDimOrSymbol(AffineExprKind::SymbolId, position, context);
- ;
}
AffineConstantExpr::AffineConstantExpr(AffineExpr::ImplType *ptr)
@@ -1135,7 +1134,7 @@ SimpleAffineExprFlattener::SimpleAffineExprFlattener(unsigned numDims,
// In case of semi affine multiplication expressions, t = expr * symbolic_expr,
// introduce a local variable p (= expr * symbolic_expr), and the affine
// expression expr * symbolic_expr is added to `localExprs`.
-void SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
+LogicalResult SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
assert(operandExprStack.size() >= 2);
SmallVector<int64_t, 8> rhs = operandExprStack.back();
operandExprStack.pop_back();
@@ -1151,7 +1150,7 @@ void SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
localExprs, context);
addLocalVariableSemiAffine(a * b, lhs, lhs.size());
- return;
+ return success();
}
// Get the RHS constant.
@@ -1159,9 +1158,10 @@ void SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
for (unsigned i = 0, e = lhs.size(); i < e; i++) {
lhs[i] *= rhsConst;
}
+ return success();
}
-void SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) {
+LogicalResult SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) {
assert(operandExprStack.size() >= 2);
const auto &rhs = operandExprStack.back();
auto &lhs = operandExprStack[operandExprStack.size() - 2];
@@ -1172,6 +1172,7 @@ void SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) {
}
// Pop off the RHS.
operandExprStack.pop_back();
+ return success();
}
//
@@ -1184,7 +1185,7 @@ void SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) {
// In case of semi-affine modulo expressions, t = expr mod symbolic_expr,
// introduce a local variable m (= expr mod symbolic_expr), and the affine
// expression expr mod symbolic_expr is added to `localExprs`.
-void SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
+LogicalResult SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
assert(operandExprStack.size() >= 2);
SmallVector<int64_t, 8> rhs = operandExprStack.back();
@@ -1202,13 +1203,12 @@ void SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
localExprs, context);
AffineExpr modExpr = dividendExpr % divisorExpr;
addLocalVariableSemiAffine(modExpr, lhs, lhs.size());
- return;
+ return success();
}
int64_t rhsConst = rhs[getConstantIndex()];
- // TODO: handle modulo by zero case when this issue is fixed
- // at the other places in the IR.
- assert(rhsConst > 0 && "RHS constant has to be positive");
+ if (rhsConst <= 0)
+ return failure();
// Check if the LHS expression is a multiple of modulo factor.
unsigned i, e;
@@ -1218,7 +1218,7 @@ void SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
// If yes, modulo expression here simplifies to zero.
if (i == lhs.size()) {
std::fill(lhs.begin(), lhs.end(), 0);
- return;
+ return success();
}
// Add a local variable for the quotient, i.e., expr % c is replaced by
@@ -1250,33 +1250,41 @@ void SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
// Reuse the existing local id.
lhs[getLocalVarStartIndex() + loc] = -rhsConst;
}
+ return success();
}
-void SimpleAffineExprFlattener::visitCeilDivExpr(AffineBinaryOpExpr expr) {
- visitDivExpr(expr, /*isCeil=*/true);
+LogicalResult
+SimpleAffineExprFlattener::visitCeilDivExpr(AffineBinaryOpExpr expr) {
+ return visitDivExpr(expr, /*isCeil=*/true);
}
-void SimpleAffineExprFlattener::visitFloorDivExpr(AffineBinaryOpExpr expr) {
- visitDivExpr(expr, /*isCeil=*/false);
+LogicalResult
+SimpleAffineExprFlattener::visitFloorDivExpr(AffineBinaryOpExpr expr) {
+ return visitDivExpr(expr, /*isCeil=*/false);
}
-void SimpleAffineExprFlattener::visitDimExpr(AffineDimExpr expr) {
+LogicalResult SimpleAffineExprFlattener::visitDimExpr(AffineDimExpr expr) {
operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
auto &eq = operandExprStack.back();
assert(expr.getPosition() < numDims && "Inconsistent number of dims");
eq[getDimStartIndex() + expr.getPosition()] = 1;
+ return success();
}
-void SimpleAffineExprFlattener::visitSymbolExpr(AffineSymbolExpr expr) {
+LogicalResult
+SimpleAffineExprFlattener::visitSymbolExpr(AffineSymbolExpr expr) {
operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
auto &eq = operandExprStack.back();
assert(expr.getPosition() < numSymbols && "inconsistent number of symbols");
eq[getSymbolStartIndex() + expr.getPosition()] = 1;
+ return success();
}
-void SimpleAffineExprFlattener::visitConstantExpr(AffineConstantExpr expr) {
+LogicalResult
+SimpleAffineExprFlattener::visitConstantExpr(AffineConstantExpr expr) {
operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
auto &eq = operandExprStack.back();
eq[getConstantIndex()] = expr.getValue();
+ return success();
}
void SimpleAffineExprFlattener::addLocalVariableSemiAffine(
@@ -1307,8 +1315,8 @@ void SimpleAffineExprFlattener::addLocalVariableSemiAffine(
// or t = expr ceildiv symbolic_expr, introduce a local variable q (= expr
// floordiv/ceildiv symbolic_expr), and the affine floordiv/ceildiv is added to
// `localExprs`.
-void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
- bool isCeil) {
+LogicalResult SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
+ bool isCeil) {
assert(operandExprStack.size() >= 2);
MLIRContext *context = expr.getContext();
@@ -1326,14 +1334,13 @@ void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
localExprs, context);
AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
addLocalVariableSemiAffine(divExpr, lhs, lhs.size());
- return;
+ return success();
}
// This is a pure affine expr; the RHS is a positive constant.
int64_t rhsConst = rhs[getConstantIndex()];
- // TODO: handle division by zero at the same time the issue is
- // fixed at other places.
- assert(rhsConst > 0 && "RHS constant has to be positive");
+ if (rhsConst <= 0)
+ return failure();
// Simplify the floordiv, ceildiv if possible by canceling out the greatest
// common divisors of the numerator and denominator.
@@ -1349,7 +1356,7 @@ void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
// If the divisor becomes 1, the updated LHS is the result. (The
// divisor can't be negative since rhsConst is positive).
if (divisor == 1)
- return;
+ return success();
// If the divisor cannot be simplified to one, we will have to retain
// the ceil/floor expr (simplified up until here). Add an existential
@@ -1379,6 +1386,7 @@ void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
lhs[getLocalVarStartIndex() + numLocals - 1] = 1;
else
lhs[getLocalVarStartIndex() + loc] = 1;
+ return success();
}
// Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
@@ -1419,7 +1427,8 @@ AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims,
expr = simplifySemiAffine(expr);
SimpleAffineExprFlattener flattener(numDims, numSymbols);
- flattener.walkPostOrder(expr);
+ if (failed(flattener.walkPostOrder(expr)))
+ return expr;
ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
if (!expr.isPureAffine() &&
expr == getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 3bd1181b6c7bbd8..36ebf5effc58872 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -8,6 +8,7 @@
#include "mlir/IR/AffineMap.h"
#include "AffineMapDetail.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
@@ -46,6 +47,8 @@ class AffineExprConstantFolder {
return nullptr;
}
+ bool hasPoison() const { return hasPoison_; }
+
private:
std::optional<int64_t> constantFoldImpl(AffineExpr expr) {
switch (expr.getKind()) {
@@ -57,13 +60,34 @@ class AffineExprConstantFolder {
expr, [](int64_t lhs, int64_t rhs) { return lhs * rhs; });
case AffineExprKind::Mod:
return constantFoldBinExpr(
- expr, [](int64_t lhs, int64_t rhs) { return mod(lhs, rhs); });
+ expr,
+ [expr, this](int64_t lhs, int64_t rhs) -> std::optional<int64_t> {
+ if (rhs < 1) {
+ hasPoison_ = true;
+ return std::nullopt;
+ }
+ return mod(lhs, rhs);
+ });
case AffineExprKind::FloorDiv:
return constantFoldBinExpr(
- expr, [](int64_t lhs, int64_t rhs) { return floorDiv(lhs, rhs); });
+ expr,
+ [expr, this](int64_t lhs, int64_t rhs) -> std::optional<int64_t> {
+ if (0 == rhs) {
+ hasPoison_ = true;
+ return std::nullopt;
+ }
+ return floorDiv(lhs, rhs);
+ });
case AffineExprKind::CeilDiv:
return constantFoldBinExpr(
- expr, [](int64_t lhs, int64_t rhs) { return ceilDiv(lhs, rhs); });
+ expr,
+ [expr, this](int64_t lhs, int64_t rhs) -> std::optional<int64_t> {
+ if (0 == rhs) {
+ hasPoison_ = true;
+ return std::nullopt;
+ }
+ return ceilDiv(lhs, rhs);
+ });
case AffineExprKind::Constant:
return expr.cast<AffineConstantExpr>().getValue();
case AffineExprKind::DimId:
@@ -82,8 +106,9 @@ class AffineExprConstantFolder {
}
// TODO: Change these to operate on APInts too.
- std::optional<int64_t> constantFoldBinExpr(AffineExpr expr,
- int64_t (*op)(int64_t, int64_t)) {
+ std::optional<int64_t> constantFoldBinExpr(
+ AffineExpr expr,
+ llvm::function_ref<std::optional<int64_t>(int64_t, int64_t)> op) {
auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
if (auto lhs = constantFoldImpl(binOpExpr.getLHS()))
if (auto rhs = constantFoldImpl(binOpExpr.getRHS()))
@@ -95,6 +120,7 @@ class AffineExprConstantFolder {
unsigned numDims;
// The constant valued operands used to evaluate this AffineExpr.
ArrayRef<Attribute> operandConsts;
+ bool hasPoison_{false};
};
} // namespace
@@ -375,12 +401,12 @@ std::optional<unsigned> AffineMap::getResultPosition(AffineExpr input) const {
/// Folds the results of the application of an affine map on the provided
/// operands to a constant if possible. Returns false if the folding happens,
/// true otherwise.
-LogicalResult
-AffineMap::constantFold(ArrayRef<Attribute> operandConstants,
- SmallVectorImpl<Attribute> &results) const {
+LogicalResult AffineMap::constantFold(ArrayRef<Attribute> operandConstants,
+ SmallVectorImpl<Attribute> &results,
+ bool *hasPoison) const {
// Attempt partial folding.
SmallVector<int64_t, 2> integers;
- partialConstantFold(operandConstants, &integers);
+ partialConstantFold(operandConstants, &integers, hasPoison);
// If all expressions folded to a constant, populate results with attributes
// containing those constants.
@@ -394,9 +420,9 @@ AffineMap::constantFold(ArrayRef<Attribute> operandConstants,
return success();
}
-AffineMap
-AffineMap::partialConstantFold(ArrayRef<Attribute> operandConstants,
- SmallVectorImpl<int64_t> *results) const {
+AffineMap AffineMap::partialConstantFold(ArrayRef<Attribute> operandConstants,
+ SmallVectorImpl<int64_t> *results,
+ bool *hasPoison) const {
assert(getNumInputs() == operandConstants.size());
// Fold each of the result expressions.
@@ -406,6 +432,10 @@ AffineMap::partialConstantFold(ArrayRef<Attribute> operandConstants,
for (auto expr : getResults()) {
auto folded = exprFolder.constantFold(expr);
+ if (exprFolder.hasPoison() && hasPoison) {
+ *hasPoison = true;
+ return {};
+ }
// If did not fold to a constant, keep the original expression, and clear
// the integer results vector.
if (folded) {
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index 644113058bdc1cc..7a865333e78788f 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -268,7 +268,8 @@ static LogicalResult doVerifyRoundTrip(Operation *op,
llvm::raw_string_ostream ostream(buffer);
if (useBytecode) {
if (failed(writeBytecodeToFile(op, ostream))) {
- op->emitOpError() << "failed to write bytecode, cannot verify round-trip.\n";
+ op->emitOpError()
+ << "failed to write bytecode, cannot verify round-trip.\n";
return failure();
}
} else {
@@ -281,7 +282,8 @@ static LogicalResult doVerifyRoundTrip(Operation *op,
roundtripModule =
parseSourceString<Operation *>(ostream.str(), parseConfig);
if (!roundtripModule) {
- op->emitOpError() << "failed to parse bytecode back, cannot verify round-trip.\n";
+ op->emitOpError()
+ << "failed to parse bytecode back, cannot verify round-trip.\n";
return failure();
}
}
@@ -300,7 +302,8 @@ static LogicalResult doVerifyRoundTrip(Operation *op,
}
if (reference != roundtrip) {
// TODO implement a diff.
- return op->emitOpError() << "roundTrip testing roundtripped module differs from reference:\n<<<<<<Reference\n"
+ return op->emitOpError() << "roundTrip testing roundtripped module differs "
+ "from reference:\n<<<<<<Reference\n"
<< reference << "\n=====\n"
<< roundtrip << "\n>>>>>roundtripped\n";
}
@@ -409,6 +412,14 @@ static LogicalResult processBuffer(raw_ostream &os,
// Create a context just for the current buffer. Disable threading on creation
// since we'll inject the thread-pool separately.
MLIRContext context(registry, MLIRContext::Threading::DISABLED);
+
+ for (const auto allocFunc :
+ llvm::map_range(registry.getDialectNames(), [&registry](auto name) {
+ return registry.getDialectAllocator(name);
+ })) {
+ allocFunc(&context);
+ }
+
if (threadPool)
context.setThreadPool(*threadPool);
diff --git a/mlir/test/Dialect/Affine/constant-fold.mlir b/mlir/test/Dialect/Affine/constant-fold.mlir
index cdce39855acdff4..f99056ab39589e0 100644
--- a/mlir/test/Dialect/Affine/constant-fold.mlir
+++ b/mlir/test/Dialect/Affine/constant-fold.mlir
@@ -60,3 +60,22 @@ func.func @affine_min(%variable: index) -> (index, index) {
// CHECK: return %[[r]], %[[C44]]
return %0, %1 : index, index
}
+
+// -----
+
+func.func @affine_apply_poison_division_zero() {
+ %c0 = arith.constant 0 : index
+ %0 = affine.apply affine_map<(d0)[s0] -> (d0 mod s0)>(%c0)[%c0]
+ %1 = affine.apply affine_map<(d0)[s0] -> (d0 floordiv s0)>(%c0)[%c0]
+ %2 = affine.apply affine_map<(d0)[s0] -> (d0 ceildiv s0)>(%c0)[%c0]
+ %alloc = memref.alloc(%0, %1, %2) : memref<?x?x?xi1>
+ %3 = affine.load %alloc[%0, %1, %2] : memref<?x?x?xi1>
+ affine.store %3, %alloc[%0, %1, %2] : memref<?x?x?xi1>
+ return
+}
+
+// CHECK-NOT: affine.apply
+// CHECK: %[[poison:.*]] = ub.poison : index
+// CHECK-NEXT: %[[alloc:.*]] = memref.alloc(%[[poison]], %[[poison]], %[[poison]])
+// CHECK-NEXT: %[[load:.*]] = affine.load %[[alloc]][%[[poison]], %[[poison]], %[[poison]]] : memref<?x?x?xi1>
+// CHECK-NEXT: affine.store %[[load]], %alloc[%[[poison]], %[[poison]], %[[poison]]] : memref<?x?x?xi1>
>From bfee534d2cf0fbe9aa384d39f160c76938ea2aad Mon Sep 17 00:00:00 2001
From: lipracer <lipracer at gmail.com>
Date: Mon, 30 Oct 2023 23:48:35 +0800
Subject: [PATCH 2/4] refine
---
mlir/include/mlir/IR/AffineExpr.h | 16 ++++++++++
mlir/include/mlir/IR/AffineExprVisitor.h | 20 ++++++------
.../Analysis/FlatLinearValueConstraints.cpp | 5 +--
mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 31 ++++++++++++-------
mlir/lib/IR/AffineExpr.cpp | 6 +++-
mlir/lib/IR/AffineMap.cpp | 4 +--
mlir/test/Dialect/Affine/constant-fold.mlir | 11 ++++---
7 files changed, 60 insertions(+), 33 deletions(-)
diff --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h
index 69e02c94ef2708d..051a6df2a4bfd5e 100644
--- a/mlir/include/mlir/IR/AffineExpr.h
+++ b/mlir/include/mlir/IR/AffineExpr.h
@@ -210,6 +210,10 @@ class AffineBinaryOpExpr : public AffineExpr {
/* implicit */ AffineBinaryOpExpr(AffineExpr::ImplType *ptr);
AffineExpr getLHS() const;
AffineExpr getRHS() const;
+
+ constexpr static bool classof(const AffineExpr *expr) {
+ return expr->isa<AffineBinaryOpExpr>();
+ }
};
/// A dimensional identifier appearing in an affine expression.
@@ -218,6 +222,10 @@ class AffineDimExpr : public AffineExpr {
using ImplType = detail::AffineDimExprStorage;
/* implicit */ AffineDimExpr(AffineExpr::ImplType *ptr);
unsigned getPosition() const;
+
+ constexpr static bool classof(const AffineExpr *expr) {
+ return expr->isa<AffineDimExpr>();
+ }
};
/// A symbolic identifier appearing in an affine expression.
@@ -226,6 +234,10 @@ class AffineSymbolExpr : public AffineExpr {
using ImplType = detail::AffineDimExprStorage;
/* implicit */ AffineSymbolExpr(AffineExpr::ImplType *ptr);
unsigned getPosition() const;
+
+ constexpr static bool classof(const AffineExpr *expr) {
+ return expr->isa<AffineSymbolExpr>();
+ }
};
/// An integer constant appearing in affine expression.
@@ -234,6 +246,10 @@ class AffineConstantExpr : public AffineExpr {
using ImplType = detail::AffineConstantExprStorage;
/* implicit */ AffineConstantExpr(AffineExpr::ImplType *ptr = nullptr);
int64_t getValue() const;
+
+ constexpr static bool classof(const AffineExpr *expr) {
+ return expr->isa<AffineConstantExpr>();
+ }
};
/// Make AffineExpr hashable.
diff --git a/mlir/include/mlir/IR/AffineExprVisitor.h b/mlir/include/mlir/IR/AffineExprVisitor.h
index 7a3616f024d0a9e..6eed020ba3ee3a0 100644
--- a/mlir/include/mlir/IR/AffineExprVisitor.h
+++ b/mlir/include/mlir/IR/AffineExprVisitor.h
@@ -73,36 +73,34 @@ class AffineExprVisitorBase {
RetTy visit(AffineExpr expr) {
static_assert(std::is_base_of<AffineExprVisitorBase, SubClass>::value,
"Must instantiate with a derived type of AffineExprVisitor");
+ auto self = static_cast<SubClass *>(this);
switch (expr.getKind()) {
case AffineExprKind::Add: {
auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
- return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
+ return self->visitAddExpr(binOpExpr);
}
case AffineExprKind::Mul: {
auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
- return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
+ return self->visitMulExpr(binOpExpr);
}
case AffineExprKind::Mod: {
auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
- return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
+ return self->visitModExpr(binOpExpr);
}
case AffineExprKind::FloorDiv: {
auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
- return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
+ return self->visitFloorDivExpr(binOpExpr);
}
case AffineExprKind::CeilDiv: {
auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
- return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
+ return self->visitCeilDivExpr(binOpExpr);
}
case AffineExprKind::Constant:
- return static_cast<SubClass *>(this)->visitConstantExpr(
- expr.cast<AffineConstantExpr>());
+ return self->visitConstantExpr(expr.cast<AffineConstantExpr>());
case AffineExprKind::DimId:
- return static_cast<SubClass *>(this)->visitDimExpr(
- expr.cast<AffineDimExpr>());
+ return self->visitDimExpr(expr.cast<AffineDimExpr>());
case AffineExprKind::SymbolId:
- return static_cast<SubClass *>(this)->visitSymbolExpr(
- expr.cast<AffineSymbolExpr>());
+ return self->visitSymbolExpr(expr.cast<AffineSymbolExpr>());
}
llvm_unreachable("Unknown AffineExpr");
}
diff --git a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
index b8ff4d82be697b1..6c3d174f0f109be 100644
--- a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
+++ b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
@@ -85,9 +85,10 @@ getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims,
for (auto expr : exprs) {
if (!expr.isPureAffine())
return failure();
-
+ // has poison expression
auto flattenResult = flattener.walkPostOrder(expr);
- assert(succeeded(flattenResult) && "affine expr containts poison expr");
+ if (failed(flattenResult))
+ return failure();
}
assert(flattener.operandExprStack.size() == exprs.size());
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 6970dc2180b49f2..bb9906d0eea7921 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -729,16 +729,17 @@ getBoundForExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols,
auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
if (!rhsConst || rhsConst.getValue() < 1)
return std::nullopt;
- auto bound = getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
- constLowerBounds, constUpperBounds, isUpper);
+ std::optional<int64_t> bound =
+ getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
+ constLowerBounds, constUpperBounds, isUpper);
if (!bound)
return std::nullopt;
return mlir::floorDiv(*bound, rhsConst.getValue());
}
if (binOpExpr.getKind() == AffineExprKind::CeilDiv) {
- auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
+ auto rhsConst = dyn_cast<AffineConstantExpr>(binOpExpr.getRHS());
if (rhsConst && rhsConst.getValue() >= 1) {
- auto bound =
+ std::optional<int64_t> bound =
getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
constLowerBounds, constUpperBounds, isUpper);
if (!bound)
@@ -751,14 +752,16 @@ getBoundForExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols,
// lhs mod c is always <= c - 1 and non-negative. In addition, if `lhs` is
// bounded such that lb <= lhs <= ub and lb floordiv c == ub floordiv c
// (same "interval"), then lb mod c <= lhs mod c <= ub mod c.
- auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
+ auto rhsConst = dyn_cast<AffineConstantExpr>(binOpExpr.getRHS());
if (rhsConst && rhsConst.getValue() >= 1) {
int64_t rhsConstVal = rhsConst.getValue();
- auto lb = getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
- constLowerBounds, constUpperBounds,
- /*isUpper=*/false);
- auto ub = getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
- constLowerBounds, constUpperBounds, isUpper);
+ std::optional<int64_t> lb =
+ getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
+ constLowerBounds, constUpperBounds,
+ /*isUpper=*/false);
+ std::optional<int64_t> ub =
+ getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
+ constLowerBounds, constUpperBounds, isUpper);
if (ub && lb &&
floorDiv(*lb, rhsConstVal) == floorDiv(*ub, rhsConstVal))
return isUpper ? mod(*ub, rhsConstVal) : mod(*lb, rhsConstVal);
@@ -766,10 +769,13 @@ getBoundForExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols,
}
}
}
+
// Flatten the expression.
SimpleAffineExprFlattener flattener(numDims, numSymbols);
auto flattenResult = flattener.walkPostOrder(expr);
- assert(succeeded(flattenResult) && "affine expr containts poison expr");
+ // has poison expression
+ if (failed(flattenResult))
+ return std::nullopt;
ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
// TODO: Handle local variables. We can get hold of flattener.localExprs and
// get bound on the local expr recursively.
@@ -3476,7 +3482,8 @@ static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map) {
SimpleAffineExprFlattener flattener(map.getNumDims(), map.getNumSymbols());
auto flattenResult = flattener.walkPostOrder(resultExpr);
- assert(succeeded(flattenResult) && "affine expr containts poison expr");
+ if (failed(flattenResult))
+ return failure();
// Fail if the flattened expression has local variables.
if (flattener.operandExprStack.back().size() !=
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index 5aff41b48b724d1..863ce590921320b 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -1427,6 +1427,7 @@ AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims,
expr = simplifySemiAffine(expr);
SimpleAffineExprFlattener flattener(numDims, numSymbols);
+ // has poison expression
if (failed(flattener.walkPostOrder(expr)))
return expr;
ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
@@ -1501,7 +1502,10 @@ std::optional<int64_t> mlir::getBoundForAffineExpr(
}
// Flatten the expression.
SimpleAffineExprFlattener flattener(numDims, numSymbols);
- flattener.walkPostOrder(expr);
+ auto simpleResult = flattener.walkPostOrder(expr);
+ // has poison expression
+ if (failed(simpleResult))
+ return std::nullopt;
ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
// TODO: Handle local variables. We can get hold of flattener.localExprs and
// get bound on the local expr recursively.
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 36ebf5effc58872..0c3b6294186f265 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -72,7 +72,7 @@ class AffineExprConstantFolder {
return constantFoldBinExpr(
expr,
[expr, this](int64_t lhs, int64_t rhs) -> std::optional<int64_t> {
- if (0 == rhs) {
+ if (rhs == 0) {
hasPoison_ = true;
return std::nullopt;
}
@@ -82,7 +82,7 @@ class AffineExprConstantFolder {
return constantFoldBinExpr(
expr,
[expr, this](int64_t lhs, int64_t rhs) -> std::optional<int64_t> {
- if (0 == rhs) {
+ if (rhs == 0) {
hasPoison_ = true;
return std::nullopt;
}
diff --git a/mlir/test/Dialect/Affine/constant-fold.mlir b/mlir/test/Dialect/Affine/constant-fold.mlir
index f99056ab39589e0..289b3d9c49ea6cf 100644
--- a/mlir/test/Dialect/Affine/constant-fold.mlir
+++ b/mlir/test/Dialect/Affine/constant-fold.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-constant-fold -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -test-constant-fold -test-constant-fold -split-input-file %s | FileCheck %s
// CHECK-LABEL: func @affine_apply
func.func @affine_apply(%variable : index) -> (index, index, index) {
@@ -64,10 +64,11 @@ func.func @affine_min(%variable: index) -> (index, index) {
// -----
func.func @affine_apply_poison_division_zero() {
- %c0 = arith.constant 0 : index
- %0 = affine.apply affine_map<(d0)[s0] -> (d0 mod s0)>(%c0)[%c0]
- %1 = affine.apply affine_map<(d0)[s0] -> (d0 floordiv s0)>(%c0)[%c0]
- %2 = affine.apply affine_map<(d0)[s0] -> (d0 ceildiv s0)>(%c0)[%c0]
+ %c16 = arith.constant 16 : index
+ %zero = arith.subi %c16, %c16 : index
+ %0 = affine.apply affine_map<(d0)[s0] -> (d0 mod s0)>(%c16)[%zero]
+ %1 = affine.apply affine_map<(d0)[s0] -> (d0 floordiv s0)>(%c16)[%zero]
+ %2 = affine.apply affine_map<(d0)[s0] -> (d0 ceildiv s0)>(%c16)[%zero]
%alloc = memref.alloc(%0, %1, %2) : memref<?x?x?xi1>
%3 = affine.load %alloc[%0, %1, %2] : memref<?x?x?xi1>
affine.store %3, %alloc[%0, %1, %2] : memref<?x?x?xi1>
>From 3164e522319fdb1fa8446378984f63d182661faa Mon Sep 17 00:00:00 2001
From: lipracer <lipracer at gmail.com>
Date: Wed, 1 Nov 2023 14:08:10 +0800
Subject: [PATCH 3/4] refine: replace Affine*Expr.dyn_cast with
llvm::dyn_cast<Affine*Expr>
---
mlir/include/mlir/IR/AffineExpr.h | 40 +++++++++++++++++++-----
mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 18 +++++------
2 files changed, 41 insertions(+), 17 deletions(-)
diff --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h
index 051a6df2a4bfd5e..7f6f7f508d19cfb 100644
--- a/mlir/include/mlir/IR/AffineExpr.h
+++ b/mlir/include/mlir/IR/AffineExpr.h
@@ -194,6 +194,8 @@ class AffineExpr {
reinterpret_cast<ImplType *>(const_cast<void *>(pointer)));
}
+ ImplType *getImpl() const { return expr; }
+
protected:
ImplType *expr{nullptr};
};
@@ -211,8 +213,8 @@ class AffineBinaryOpExpr : public AffineExpr {
AffineExpr getLHS() const;
AffineExpr getRHS() const;
- constexpr static bool classof(const AffineExpr *expr) {
- return expr->isa<AffineBinaryOpExpr>();
+ constexpr static bool classof(const AffineExpr expr) {
+ return expr.isa<AffineBinaryOpExpr>();
}
};
@@ -223,8 +225,8 @@ class AffineDimExpr : public AffineExpr {
/* implicit */ AffineDimExpr(AffineExpr::ImplType *ptr);
unsigned getPosition() const;
- constexpr static bool classof(const AffineExpr *expr) {
- return expr->isa<AffineDimExpr>();
+ constexpr static bool classof(const AffineExpr expr) {
+ return expr.isa<AffineDimExpr>();
}
};
@@ -235,8 +237,8 @@ class AffineSymbolExpr : public AffineExpr {
/* implicit */ AffineSymbolExpr(AffineExpr::ImplType *ptr);
unsigned getPosition() const;
- constexpr static bool classof(const AffineExpr *expr) {
- return expr->isa<AffineSymbolExpr>();
+ constexpr static bool classof(const AffineExpr expr) {
+ return expr.isa<AffineSymbolExpr>();
}
};
@@ -247,8 +249,8 @@ class AffineConstantExpr : public AffineExpr {
/* implicit */ AffineConstantExpr(AffineExpr::ImplType *ptr = nullptr);
int64_t getValue() const;
- constexpr static bool classof(const AffineExpr *expr) {
- return expr->isa<AffineConstantExpr>();
+ constexpr static bool classof(const AffineExpr expr) {
+ return expr.isa<AffineConstantExpr>();
}
};
@@ -406,6 +408,28 @@ struct DenseMapInfo<mlir::AffineExpr> {
}
};
+/// Add support for llvm style casts. We provide a cast between To and From if
+/// From is mlir::AffineExpr or derives from it.
+template <typename To, typename From>
+struct CastInfo<To, From,
+ std::enable_if_t<std::is_same_v<mlir::AffineExpr,
+ std::remove_const_t<From>> ||
+ std::is_base_of_v<mlir::AffineExpr, From>>>
+ : NullableValueCastFailed<To>,
+ DefaultDoCastIfPossible<To, From, CastInfo<To, From>> {
+
+ static inline bool isPossible(mlir::AffineExpr expr) {
+ /// Return a constant true instead of a dynamic true when casting to self or
+ /// up the hierarchy.
+ if constexpr (std::is_base_of_v<To, From>) {
+ return true;
+ } else {
+ return To::classof(expr);
+ }
+ }
+ static inline To doCast(mlir::AffineExpr expr) { return To(expr.getImpl()); }
+};
+
} // namespace llvm
#endif // MLIR_IR_AFFINEEXPR_H
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index bb9906d0eea7921..c796a8b3dacf3d4 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -635,7 +635,7 @@ static int64_t getLargestKnownDivisor(AffineExpr e, ArrayRef<Value> operands) {
/// being an affine dim expression or a constant.
static bool isNonNegativeBoundedBy(AffineExpr e, ArrayRef<Value> operands,
int64_t k) {
- if (auto constExpr = e.dyn_cast<AffineConstantExpr>()) {
+ if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
int64_t constVal = constExpr.getValue();
return constVal >= 0 && constVal < k;
}
@@ -726,7 +726,7 @@ getBoundForExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols,
// If the LHS of a floor or ceil is bounded and the RHS is a constant, we
// can compute an upper bound.
if (binOpExpr.getKind() == AffineExprKind::FloorDiv) {
- auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
+ auto rhsConst = dyn_cast<AffineConstantExpr>(binOpExpr.getRHS());
if (!rhsConst || rhsConst.getValue() < 1)
return std::nullopt;
std::optional<int64_t> bound =
@@ -817,7 +817,7 @@ static std::optional<int64_t> getUpperBound(AffineExpr expr, unsigned numDims,
constUpperBounds.push_back(getUpperBound(operand));
}
- if (auto constExpr = expr.dyn_cast<AffineConstantExpr>())
+ if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
return constExpr.getValue();
return getBoundForAffineExpr(expr, numDims, numSymbols, constLowerBounds,
@@ -841,7 +841,7 @@ static std::optional<int64_t> getLowerBound(AffineExpr expr, unsigned numDims,
}
std::optional<int64_t> lowerBound;
- if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
+ if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
lowerBound = constExpr.getValue();
} else {
lowerBound = getBoundForAffineExpr(expr, numDims, numSymbols,
@@ -877,7 +877,7 @@ static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims,
// The `lhs` and `rhs` may be different post construction of simplified expr.
lhs = binExpr.getLHS();
rhs = binExpr.getRHS();
- auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
+ auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
if (!rhsConst)
return;
@@ -981,7 +981,7 @@ static void simplifyMinOrMaxExprWithOperands(AffineMap &map,
lowerBounds.reserve(map.getNumResults());
upperBounds.reserve(map.getNumResults());
for (AffineExpr e : map.getResults()) {
- if (auto constExpr = e.dyn_cast<AffineConstantExpr>()) {
+ if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
lowerBounds.push_back(constExpr.getValue());
upperBounds.push_back(constExpr.getValue());
} else {
@@ -2168,7 +2168,7 @@ static void printBound(AffineMapAttr boundMap,
// Print constant bound.
if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
- if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
+ if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
p << constExpr.getValue();
return;
}
@@ -3864,7 +3864,7 @@ std::optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
out.reserve(rangesValueMap.getNumResults());
for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) {
auto expr = rangesValueMap.getResult(i);
- auto cst = expr.dyn_cast<AffineConstantExpr>();
+ auto cst = dyn_cast<AffineConstantExpr>(expr);
if (!cst)
return std::nullopt;
out.push_back(cst.getValue());
@@ -4292,7 +4292,7 @@ ParseResult AffineParallelOp::parse(OpAsmParser &parser,
SmallVector<int64_t, 4> steps;
auto stepsMap = stepsMapAttr.getValue();
for (const auto &result : stepsMap.getResults()) {
- auto constExpr = result.dyn_cast<AffineConstantExpr>();
+ auto constExpr = dyn_cast<AffineConstantExpr>(result);
if (!constExpr)
return parser.emitError(parser.getNameLoc(),
"steps must be constant integers");
>From 725c63de1502dd533f8782fb11db50a56ff50845 Mon Sep 17 00:00:00 2001
From: lipracer <lipracer at gmail.com>
Date: Thu, 2 Nov 2023 22:14:33 +0800
Subject: [PATCH 4/4] refine
---
mlir/include/mlir/IR/AffineExprVisitor.h | 40 +++++++++++-------------
1 file changed, 18 insertions(+), 22 deletions(-)
diff --git a/mlir/include/mlir/IR/AffineExprVisitor.h b/mlir/include/mlir/IR/AffineExprVisitor.h
index 6eed020ba3ee3a0..1e776187a691daf 100644
--- a/mlir/include/mlir/IR/AffineExprVisitor.h
+++ b/mlir/include/mlir/IR/AffineExprVisitor.h
@@ -146,41 +146,39 @@ class AffineExprVisitor : public AffineExprVisitorBase<SubClass, RetTy> {
RetTy walkPostOrder(AffineExpr expr) {
static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
"Must instantiate with a derived type of AffineExprVisitor");
+ auto self = static_cast<SubClass *>(this);
switch (expr.getKind()) {
case AffineExprKind::Add: {
auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
walkOperandsPostOrder(binOpExpr);
- return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
+ return self->visitAddExpr(binOpExpr);
}
case AffineExprKind::Mul: {
auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
walkOperandsPostOrder(binOpExpr);
- return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
+ return self->visitMulExpr(binOpExpr);
}
case AffineExprKind::Mod: {
auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
walkOperandsPostOrder(binOpExpr);
- return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
+ return self->visitModExpr(binOpExpr);
}
case AffineExprKind::FloorDiv: {
auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
walkOperandsPostOrder(binOpExpr);
- return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
+ return self->visitFloorDivExpr(binOpExpr);
}
case AffineExprKind::CeilDiv: {
auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
walkOperandsPostOrder(binOpExpr);
- return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
+ return self->visitCeilDivExpr(binOpExpr);
}
case AffineExprKind::Constant:
- return static_cast<SubClass *>(this)->visitConstantExpr(
- expr.cast<AffineConstantExpr>());
+ return self->visitConstantExpr(expr.cast<AffineConstantExpr>());
case AffineExprKind::DimId:
- return static_cast<SubClass *>(this)->visitDimExpr(
- expr.cast<AffineDimExpr>());
+ return self->visitDimExpr(expr.cast<AffineDimExpr>());
case AffineExprKind::SymbolId:
- return static_cast<SubClass *>(this)->visitSymbolExpr(
- expr.cast<AffineSymbolExpr>());
+ return self->visitSymbolExpr(expr.cast<AffineSymbolExpr>());
}
}
@@ -203,46 +201,44 @@ class AffineExprVisitor<SubClass, LogicalResult>
LogicalResult walkPostOrder(AffineExpr expr) {
static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
"Must instantiate with a derived type of AffineExprVisitor");
+ auto self = static_cast<SubClass *>(this);
switch (expr.getKind()) {
case AffineExprKind::Add: {
auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
if (failed(walkOperandsPostOrder(binOpExpr)))
return failure();
- return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
+ return self->visitAddExpr(binOpExpr);
}
case AffineExprKind::Mul: {
auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
if (failed(walkOperandsPostOrder(binOpExpr)))
return failure();
- return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
+ return self->visitMulExpr(binOpExpr);
}
case AffineExprKind::Mod: {
auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
if (failed(walkOperandsPostOrder(binOpExpr)))
return failure();
- return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
+ return self->visitModExpr(binOpExpr);
}
case AffineExprKind::FloorDiv: {
auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
if (failed(walkOperandsPostOrder(binOpExpr)))
return failure();
- return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
+ return self->visitFloorDivExpr(binOpExpr);
}
case AffineExprKind::CeilDiv: {
auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
if (failed(walkOperandsPostOrder(binOpExpr)))
return failure();
- return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
+ return self->visitCeilDivExpr(binOpExpr);
}
case AffineExprKind::Constant:
- return static_cast<SubClass *>(this)->visitConstantExpr(
- expr.cast<AffineConstantExpr>());
+ return self->visitConstantExpr(expr.cast<AffineConstantExpr>());
case AffineExprKind::DimId:
- return static_cast<SubClass *>(this)->visitDimExpr(
- expr.cast<AffineDimExpr>());
+ return self->visitDimExpr(expr.cast<AffineDimExpr>());
case AffineExprKind::SymbolId:
- return static_cast<SubClass *>(this)->visitSymbolExpr(
- expr.cast<AffineSymbolExpr>());
+ return self->visitSymbolExpr(expr.cast<AffineSymbolExpr>());
}
llvm_unreachable("Unknown AffineExpr");
}
More information about the Mlir-commits
mailing list