[Mlir-commits] [mlir] 3552163 - [mlir] Remove SDBM
Alex Zinenko
llvmlistbot at llvm.org
Tue Jun 29 05:46:34 PDT 2021
Author: Alex Zinenko
Date: 2021-06-29T14:46:26+02:00
New Revision: 355216380b9c11e5d7a16ac20619cf16b1c0151c
URL: https://github.com/llvm/llvm-project/commit/355216380b9c11e5d7a16ac20619cf16b1c0151c
DIFF: https://github.com/llvm/llvm-project/commit/355216380b9c11e5d7a16ac20619cf16b1c0151c.diff
LOG: [mlir] Remove SDBM
This data structure and algorithm collection is no longer in use.
Reviewed By: bondhugula
Differential Revision: https://reviews.llvm.org/D105102
Added:
Modified:
mlir/include/mlir/InitAllDialects.h
mlir/lib/Dialect/CMakeLists.txt
mlir/test/CMakeLists.txt
mlir/test/lit.cfg.py
mlir/test/mlir-opt/commandline.mlir
mlir/unittests/CMakeLists.txt
Removed:
mlir/include/mlir/Dialect/SDBM/SDBM.h
mlir/include/mlir/Dialect/SDBM/SDBMDialect.h
mlir/include/mlir/Dialect/SDBM/SDBMExpr.h
mlir/lib/Dialect/SDBM/CMakeLists.txt
mlir/lib/Dialect/SDBM/SDBM.cpp
mlir/lib/Dialect/SDBM/SDBMDialect.cpp
mlir/lib/Dialect/SDBM/SDBMExpr.cpp
mlir/lib/Dialect/SDBM/SDBMExprDetail.h
mlir/test/SDBM/CMakeLists.txt
mlir/test/SDBM/lit.local.cfg
mlir/test/SDBM/sdbm-api-test.cpp
mlir/unittests/SDBM/CMakeLists.txt
mlir/unittests/SDBM/SDBMTest.cpp
################################################################################
diff --git a/mlir/include/mlir/Dialect/SDBM/SDBM.h b/mlir/include/mlir/Dialect/SDBM/SDBM.h
deleted file mode 100644
index f4e6be874fdc..000000000000
--- a/mlir/include/mlir/Dialect/SDBM/SDBM.h
+++ /dev/null
@@ -1,197 +0,0 @@
-//===- SDBM.h - MLIR SDBM declaration ---------------------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// A striped difference-bound matrix (SDBM) is a set in Z^N (or R^N) defined
-// as {(x_1, ... x_n) | f(x_1, ... x_n) >= 0} where f is an SDBM expression.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_SDBM_SDBM_H
-#define MLIR_DIALECT_SDBM_SDBM_H
-
-#include "mlir/Support/LLVM.h"
-#include "llvm/ADT/DenseMap.h"
-
-namespace mlir {
-
-class MLIRContext;
-class SDBMDialect;
-class SDBMExpr;
-class SDBMTermExpr;
-
-/// A utility class for SDBM to represent an integer with potentially infinite
-/// positive value. This uses the largest value of int64_t to represent infinity
-/// and redefines the arithmetic operators so that the infinity "saturates":
-/// inf + x = inf,
-/// inf - x = inf.
-/// If a sum of two finite values reaches the largest value of int64_t, the
-/// behavior of IntInfty is undefined (in practice, it asserts), similarly to
-/// regular signed integer overflow.
-class IntInfty {
-public:
- constexpr static int64_t infty = std::numeric_limits<int64_t>::max();
-
- /*implicit*/ IntInfty(int64_t v) : value(v) {}
-
- IntInfty &operator=(int64_t v) {
- value = v;
- return *this;
- }
-
- static IntInfty infinity() { return IntInfty(infty); }
-
- int64_t getValue() const { return value; }
- explicit operator int64_t() const { return value; }
-
- bool isFinite() { return value != infty; }
-
-private:
- int64_t value;
-};
-
-inline IntInfty operator+(IntInfty lhs, IntInfty rhs) {
- if (!lhs.isFinite() || !rhs.isFinite())
- return IntInfty::infty;
-
- // Check for overflows, treating the sum of two values adding up to INT_MAX as
- // overflow. Convert values to unsigned to get an extra bit and avoid the
- // undefined behavior of signed integer overflows.
- assert((lhs.getValue() <= 0 || rhs.getValue() <= 0 ||
- static_cast<uint64_t>(lhs.getValue()) +
- static_cast<uint64_t>(rhs.getValue()) <
- static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) &&
- "IntInfty overflow");
- // Check for underflows by converting values to unsigned to avoid undefined
- // behavior of signed integers perform the addition (bitwise result is same
- // because numbers are required to be two's complement in C++) and check if
- // the sign bit remains negative.
- assert((lhs.getValue() >= 0 || rhs.getValue() >= 0 ||
- ((static_cast<uint64_t>(lhs.getValue()) +
- static_cast<uint64_t>(rhs.getValue())) >>
- 63) == 1) &&
- "IntInfty underflow");
-
- return lhs.getValue() + rhs.getValue();
-}
-
-inline bool operator<(IntInfty lhs, IntInfty rhs) {
- return lhs.getValue() < rhs.getValue();
-}
-
-inline bool operator<=(IntInfty lhs, IntInfty rhs) {
- return lhs.getValue() <= rhs.getValue();
-}
-
-inline bool operator==(IntInfty lhs, IntInfty rhs) {
- return lhs.getValue() == rhs.getValue();
-}
-
-inline bool operator!=(IntInfty lhs, IntInfty rhs) { return !(lhs == rhs); }
-
-/// Striped difference-bound matrix is a representation of an integer set bound
-/// by a system of SDBMExprs interpreted as inequalities "expr <= 0".
-class SDBM {
-public:
- /// Obtain an SDBM from a list of SDBM expressions treated as inequalities and
- /// equalities with zero.
- static SDBM get(ArrayRef<SDBMExpr> inequalities,
- ArrayRef<SDBMExpr> equalities);
-
- void getSDBMExpressions(SDBMDialect *dialect,
- SmallVectorImpl<SDBMExpr> &inequalities,
- SmallVectorImpl<SDBMExpr> &equalities);
-
- void print(raw_ostream &os);
- void dump();
-
- IntInfty operator()(int i, int j) { return at(i, j); }
-
-private:
- /// Get the given element of the difference bounds matrix. First index
- /// corresponds to the negative term of the difference, second index
- /// corresponds to the positive term of the difference.
- IntInfty &at(int i, int j) { return matrix[i * getNumVariables() + j]; }
-
- /// Populate `inequalities` and `equalities` based on the values at(row,col)
- /// and at(col,row) of the DBM. Depending on the values being finite and
- /// being subsumed by stripe expressions, this may or may not add elements to
- /// the lists of equalities and inequalities.
- void convertDBMElement(unsigned row, unsigned col, SDBMTermExpr rowExpr,
- SDBMTermExpr colExpr,
- SmallVectorImpl<SDBMExpr> &inequalities,
- SmallVectorImpl<SDBMExpr> &equalities);
-
- /// Populate `inequalities` based on the value at(pos,pos) of the DBM. Only
- /// adds new inequalities if the inequality is not trivially true.
- void convertDBMDiagonalElement(unsigned pos, SDBMTermExpr expr,
- SmallVectorImpl<SDBMExpr> &inequalities);
-
- /// Get the total number of elements in the matrix.
- unsigned getNumVariables() const {
- return 1 + numDims + numSymbols + numTemporaries;
- }
-
- /// Get the position in the matrix that corresponds to the given dimension.
- unsigned getDimPosition(unsigned position) const { return 1 + position; }
-
- /// Get the position in the matrix that corresponds to the given symbol.
- unsigned getSymbolPosition(unsigned position) const {
- return 1 + numDims + position;
- }
-
- /// Get the position in the matrix that corresponds to the given temporary.
- unsigned getTemporaryPosition(unsigned position) const {
- return 1 + numDims + numSymbols + position;
- }
-
- /// Number of dimensions in the system,
- unsigned numDims;
- /// Number of symbols in the system.
- unsigned numSymbols;
- /// Number of temporary variables in the system.
- unsigned numTemporaries;
-
- /// Difference bounds matrix, stored as a linearized row-major vector.
- /// Each value in this matrix corresponds to an inequality
- ///
- /// v@col - v@row <= at(row, col)
- ///
- /// where v@col and v@row are the variables that correspond to the linearized
- /// position in the matrix. The positions correspond to
- ///
- /// - constant 0 (producing constraints v@col <= X and -v@row <= Y);
- /// - SDBM expression dimensions (d0, d1, ...);
- /// - SDBM expression symbols (s0, s1, ...);
- /// - temporary variables (t0, t1, ...).
- ///
- /// Temporary variables are introduced to represent expressions that are not
- /// trivially a difference between two variables. For example, if one side of
- /// a difference expression is itself a stripe expression, it will be replaced
- /// with a temporary variable assigned equal to this expression.
- ///
- /// Infinite entries in the matrix correspond correspond to an absence of a
- /// constraint:
- ///
- /// v@col - v@row <= infinity
- ///
- /// is trivially true. Negated values at symmetric positions in the matrix
- /// allow one to couple two inequalities into a single equality.
- std::vector<IntInfty> matrix;
-
- /// The mapping between the indices of variables in the DBM and the stripe
- /// expressions they are equal to. These expressions are stored as they
- /// appeared when constructing an SDBM from a SDBMExprs, in particular no
- /// temporaries can appear in these expressions. This removes the need to
- /// iteratively substitute definitions of the temporaries in the reverse
- /// conversion.
- DenseMap<unsigned, SDBMExpr> stripeToPoint;
-};
-
-} // namespace mlir
-
-#endif // MLIR_DIALECT_SDBM_SDBM_H
diff --git a/mlir/include/mlir/Dialect/SDBM/SDBMDialect.h b/mlir/include/mlir/Dialect/SDBM/SDBMDialect.h
deleted file mode 100644
index 85cfe91d2c9b..000000000000
--- a/mlir/include/mlir/Dialect/SDBM/SDBMDialect.h
+++ /dev/null
@@ -1,37 +0,0 @@
-//===- SDBMDialect.h - Dialect for striped DBMs -----------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_SDBM_SDBMDIALECT_H
-#define MLIR_DIALECT_SDBM_SDBMDIALECT_H
-
-#include "mlir/IR/Dialect.h"
-#include "mlir/Support/StorageUniquer.h"
-
-namespace mlir {
-class MLIRContext;
-
-class SDBMDialect : public Dialect {
-public:
- SDBMDialect(MLIRContext *context);
-
- /// Since there are no other virtual methods in this derived class, override
- /// the destructor so that key methods get defined in the corresponding
- /// module.
- ~SDBMDialect() override;
-
- static StringRef getDialectNamespace() { return "sdbm"; }
-
- /// Get the uniquer for SDBM expressions. This should not be used directly.
- StorageUniquer &getUniquer() { return uniquer; }
-
-private:
- StorageUniquer uniquer;
-};
-} // namespace mlir
-
-#endif // MLIR_DIALECT_SDBM_SDBMDIALECT_H
diff --git a/mlir/include/mlir/Dialect/SDBM/SDBMExpr.h b/mlir/include/mlir/Dialect/SDBM/SDBMExpr.h
deleted file mode 100644
index 7b51b892384e..000000000000
--- a/mlir/include/mlir/Dialect/SDBM/SDBMExpr.h
+++ /dev/null
@@ -1,576 +0,0 @@
-//===- SDBMExpr.h - MLIR SDBM Expression ------------------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// A striped difference-bound matrix (SDBM) expression is a constant expression,
-// an identifier, a binary expression with constant RHS and +, stripe operators
-// or a difference expression between two identifiers.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_SDBM_SDBMEXPR_H
-#define MLIR_DIALECT_SDBM_SDBMEXPR_H
-
-#include "mlir/Support/LLVM.h"
-#include "llvm/ADT/DenseMapInfo.h"
-
-namespace mlir {
-
-class AffineExpr;
-class MLIRContext;
-
-enum class SDBMExprKind { Add, Stripe, Diff, Constant, DimId, SymbolId, Neg };
-
-namespace detail {
-struct SDBMExprStorage;
-struct SDBMBinaryExprStorage;
-struct SDBMDiffExprStorage;
-struct SDBMTermExprStorage;
-struct SDBMConstantExprStorage;
-struct SDBMNegExprStorage;
-} // namespace detail
-
-class SDBMConstantExpr;
-class SDBMDialect;
-class SDBMDimExpr;
-class SDBMSymbolExpr;
-class SDBMTermExpr;
-
-/// Striped Difference-Bounded Matrix (SDBM) expression is a base left-hand side
-/// expression for the SDBM framework. SDBM expressions are a subset of affine
-/// expressions supporting low-complexity algorithms for the operations used in
-/// loop transformations. In particular, are supported:
-/// - constant expressions;
-/// - single variables (dimensions and symbols) with +1 or -1 coefficient;
-/// - stripe expressions: "x # C", where "x" is a single variable or another
-/// stripe expression, "#" is the stripe operator, and "C" is a constant
-/// expression; "#" is defined as x - x mod C.
-/// - sum expressions between single variable/stripe expressions and constant
-/// expressions;
-/// - difference expressions between single variable/stripe expressions.
-/// `SDBMExpr` class hierarchy provides a type-safe interface to constructing
-/// and operating on SDBM expressions. For example, it requires the LHS of a
-/// sum expression to be a single variable or a stripe expression. These
-/// restrictions are intended to force the caller to perform the necessary
-/// simplifications to stay within the SDBM domain, because SDBM expressions do
-/// not combine in more cases than they do. This choice may be reconsidered in
-/// the future.
-///
-/// SDBM expressions are grouped into the following structure
-/// - expression
-/// - varying
-/// - direct
-/// - sum <- (term, constant)
-/// - term
-/// - symbol
-/// - dimension
-/// - stripe <- (direct, constant)
-/// - negation <- (direct)
-/// - difference <- (direct, term)
-/// - constant
-/// The notation <- (...) denotes the types of subexpressions a compound
-/// expression can combine. The tree of subexpressions essentially imposes the
-/// following canonicalization rules:
-/// - constants are always folded;
-/// - constants can only appear on the RHS of an expression;
-/// - double negation must be elided;
-/// - an additive constant term is only allowed in a sum expression, and
-/// should be sunk into the nearest such expression in the tree;
-/// - zero constant expression can only appear at the top level.
-///
-/// `SDBMExpr` and derived classes are thin wrappers around a pointer owned by
-/// an MLIRContext, and should be used by-value. They are uniqued in the
-/// MLIRContext and immortal.
-class SDBMExpr {
-public:
- using ImplType = detail::SDBMExprStorage;
- SDBMExpr() : impl(nullptr) {}
- /* implicit */ SDBMExpr(ImplType *expr) : impl(expr) {}
-
- /// SDBM expressions are thin wrappers around a unique'ed immutable pointer,
- /// which makes them trivially assignable and trivially copyable.
- SDBMExpr(const SDBMExpr &) = default;
- SDBMExpr &operator=(const SDBMExpr &) = default;
-
- /// SDBM expressions can be compared straight-forwardly.
- bool operator==(const SDBMExpr &other) const { return impl == other.impl; }
- bool operator!=(const SDBMExpr &other) const { return !(*this == other); }
-
- /// SDBM expressions are convertible to `bool`: null expressions are converted
- /// to false, non-null expressions are converted to true.
- explicit operator bool() const { return impl != nullptr; }
- bool operator!() const { return !static_cast<bool>(*this); }
-
- /// Negate the given SDBM expression.
- SDBMExpr operator-();
-
- /// Prints the SDBM expression.
- void print(raw_ostream &os) const;
- void dump() const;
-
- /// LLVM-style casts.
- template <typename U> bool isa() const { return U::isClassFor(*this); }
- template <typename U> U dyn_cast() const {
- if (!isa<U>())
- return {};
- return U(const_cast<SDBMExpr *>(this)->impl);
- }
- template <typename U> U cast() const {
- assert(isa<U>() && "cast to incorrect subtype");
- return U(const_cast<SDBMExpr *>(this)->impl);
- }
-
- /// Support for LLVM hashing.
- ::llvm::hash_code hash_value() const { return ::llvm::hash_value(impl); }
-
- /// Returns the kind of the SDBM expression.
- SDBMExprKind getKind() const;
-
- /// Returns the MLIR context in which this expression lives.
- MLIRContext *getContext() const;
-
- /// Returns the SDBM dialect instance.
- SDBMDialect *getDialect() const;
-
- /// Convert the SDBM expression into an Affine expression. This always
- /// succeeds because SDBM are a subset of affine.
- AffineExpr getAsAffineExpr() const;
-
- /// Try constructing an SDBM expression from the given affine expression.
- /// This may fail if the affine expression is not representable as SDBM, in
- /// which case llvm::None is returned. The conversion procedure recognizes
- /// (nested) multiplicative ((x floordiv B) * B) and additive (x - x mod B)
- /// patterns for the stripe expression.
- static Optional<SDBMExpr> tryConvertAffineExpr(AffineExpr affine);
-
-protected:
- ImplType *impl;
-};
-
-/// SDBM constant expression, wraps a 64-bit integer.
-class SDBMConstantExpr : public SDBMExpr {
-public:
- using ImplType = detail::SDBMConstantExprStorage;
-
- using SDBMExpr::SDBMExpr;
-
- /// Obtain or create a constant expression unique'ed in the given dialect
- /// (which belongs to a context).
- static SDBMConstantExpr get(SDBMDialect *dialect, int64_t value);
-
- static bool isClassFor(const SDBMExpr &expr) {
- return expr.getKind() == SDBMExprKind::Constant;
- }
-
- int64_t getValue() const;
-};
-
-/// SDBM varying expression can be one of:
-/// - input variable expression;
-/// - stripe expression;
-/// - negation (product with -1) of either of the above.
-/// - sum of a varying and a constant expression
-/// - difference between varying expressions
-class SDBMVaryingExpr : public SDBMExpr {
-public:
- using ImplType = detail::SDBMExprStorage;
- using SDBMExpr::SDBMExpr;
-
- static bool isClassFor(const SDBMExpr &expr) {
- return expr.getKind() == SDBMExprKind::DimId ||
- expr.getKind() == SDBMExprKind::SymbolId ||
- expr.getKind() == SDBMExprKind::Neg ||
- expr.getKind() == SDBMExprKind::Stripe ||
- expr.getKind() == SDBMExprKind::Add ||
- expr.getKind() == SDBMExprKind::Diff;
- }
-};
-
-/// SDBM direct expression includes exactly one variable (symbol or dimension),
-/// which is not negated in the expression. It can be one of:
-/// - term expression;
-/// - sum expression.
-class SDBMDirectExpr : public SDBMVaryingExpr {
-public:
- using SDBMVaryingExpr::SDBMVaryingExpr;
-
- /// If this is a sum expression, return its variable part, otherwise return
- /// self.
- SDBMTermExpr getTerm();
-
- /// If this is a sum expression, return its constant part, otherwise return 0.
- int64_t getConstant();
-
- static bool isClassFor(const SDBMExpr &expr) {
- return expr.getKind() == SDBMExprKind::DimId ||
- expr.getKind() == SDBMExprKind::SymbolId ||
- expr.getKind() == SDBMExprKind::Stripe ||
- expr.getKind() == SDBMExprKind::Add;
- }
-};
-
-/// SDBM term expression can be one of:
-/// - single variable expression;
-/// - stripe expression.
-/// Stripe expressions are treated as terms since, in the SDBM domain, they are
-/// attached to temporary variables and can appear anywhere a variable can.
-class SDBMTermExpr : public SDBMDirectExpr {
-public:
- using SDBMDirectExpr::SDBMDirectExpr;
-
- static bool isClassFor(const SDBMExpr &expr) {
- return expr.getKind() == SDBMExprKind::DimId ||
- expr.getKind() == SDBMExprKind::SymbolId ||
- expr.getKind() == SDBMExprKind::Stripe;
- }
-};
-
-/// SDBM sum expression. LHS is a term expression and RHS is a constant.
-class SDBMSumExpr : public SDBMDirectExpr {
-public:
- using ImplType = detail::SDBMBinaryExprStorage;
- using SDBMDirectExpr::SDBMDirectExpr;
-
- /// Obtain or create a sum expression unique'ed in the given context.
- static SDBMSumExpr get(SDBMTermExpr lhs, SDBMConstantExpr rhs);
-
- static bool isClassFor(const SDBMExpr &expr) {
- SDBMExprKind kind = expr.getKind();
- return kind == SDBMExprKind::Add;
- }
-
- SDBMTermExpr getLHS() const;
- SDBMConstantExpr getRHS() const;
-};
-
-/// SDBM difference expression. LHS is a direct expression, i.e. it may be a
-/// sum of a term and a constant. RHS is a term expression. Thus the
-/// expression (t1 - t2 + C) with term expressions t1,t2 is represented as
-/// diff(sum(t1, C), t2)
-/// and it is possible to extract the constant factor without negating it.
-class SDBMDiffExpr : public SDBMVaryingExpr {
-public:
- using ImplType = detail::SDBMDiffExprStorage;
- using SDBMVaryingExpr::SDBMVaryingExpr;
-
- /// Obtain or create a difference expression unique'ed in the given context.
- static SDBMDiffExpr get(SDBMDirectExpr lhs, SDBMTermExpr rhs);
-
- static bool isClassFor(const SDBMExpr &expr) {
- return expr.getKind() == SDBMExprKind::Diff;
- }
-
- SDBMDirectExpr getLHS() const;
- SDBMTermExpr getRHS() const;
-};
-
-/// SDBM stripe expression "x # C" where "x" is a term expression, "C" is a
-/// constant expression and "#" is the stripe operator defined as:
-/// x # C = x - x mod C.
-class SDBMStripeExpr : public SDBMTermExpr {
-public:
- using ImplType = detail::SDBMBinaryExprStorage;
- using SDBMTermExpr::SDBMTermExpr;
-
- static bool isClassFor(const SDBMExpr &expr) {
- return expr.getKind() == SDBMExprKind::Stripe;
- }
-
- static SDBMStripeExpr get(SDBMDirectExpr var, SDBMConstantExpr stripeFactor);
-
- SDBMDirectExpr getLHS() const;
- SDBMConstantExpr getStripeFactor() const;
-};
-
-/// SDBM "input" variable expression can be either a dimension identifier or
-/// a symbol identifier. When used to define SDBM functions, dimensions are
-/// interpreted as function arguments while symbols are treated as unknown but
-/// constant values, hence the name.
-class SDBMInputExpr : public SDBMTermExpr {
-public:
- using ImplType = detail::SDBMTermExprStorage;
- using SDBMTermExpr::SDBMTermExpr;
-
- static bool isClassFor(const SDBMExpr &expr) {
- return expr.getKind() == SDBMExprKind::DimId ||
- expr.getKind() == SDBMExprKind::SymbolId;
- }
-
- unsigned getPosition() const;
-};
-
-/// SDBM dimension expression. Dimensions correspond to function arguments
-/// when defining functions using SDBM expressions.
-class SDBMDimExpr : public SDBMInputExpr {
-public:
- using ImplType = detail::SDBMTermExprStorage;
- using SDBMInputExpr::SDBMInputExpr;
-
- /// Obtain or create a dimension expression unique'ed in the given dialect
- /// (which belongs to a context).
- static SDBMDimExpr get(SDBMDialect *dialect, unsigned position);
-
- static bool isClassFor(const SDBMExpr &expr) {
- return expr.getKind() == SDBMExprKind::DimId;
- }
-};
-
-/// SDBM symbol expression. Symbols correspond to symbolic constants when
-/// defining functions using SDBM expressions.
-class SDBMSymbolExpr : public SDBMInputExpr {
-public:
- using ImplType = detail::SDBMTermExprStorage;
- using SDBMInputExpr::SDBMInputExpr;
-
- /// Obtain or create a symbol expression unique'ed in the given dialect (which
- /// belongs to a context).
- static SDBMSymbolExpr get(SDBMDialect *dialect, unsigned position);
-
- static bool isClassFor(const SDBMExpr &expr) {
- return expr.getKind() == SDBMExprKind::SymbolId;
- }
-};
-
-/// Negation of an SDBM variable expression. Equivalent to multiplying the
-/// expression with -1 (SDBM does not support other coefficients that 1 and -1).
-class SDBMNegExpr : public SDBMVaryingExpr {
-public:
- using ImplType = detail::SDBMNegExprStorage;
- using SDBMVaryingExpr::SDBMVaryingExpr;
-
- /// Obtain or create a negation expression unique'ed in the given context.
- static SDBMNegExpr get(SDBMDirectExpr var);
-
- static bool isClassFor(const SDBMExpr &expr) {
- return expr.getKind() == SDBMExprKind::Neg;
- }
-
- SDBMDirectExpr getVar() const;
-};
-
-/// A visitor class for SDBM expressions. Calls the kind-specific function
-/// depending on the kind of expression it visits.
-template <typename Derived, typename Result = void> class SDBMVisitor {
-public:
- /// Visit the given SDBM expression, dispatching to kind-specific functions.
- Result visit(SDBMExpr expr) {
- auto *derived = static_cast<Derived *>(this);
- switch (expr.getKind()) {
- case SDBMExprKind::Add:
- case SDBMExprKind::Diff:
- case SDBMExprKind::DimId:
- case SDBMExprKind::SymbolId:
- case SDBMExprKind::Neg:
- case SDBMExprKind::Stripe:
- return derived->visitVarying(expr.cast<SDBMVaryingExpr>());
- case SDBMExprKind::Constant:
- return derived->visitConstant(expr.cast<SDBMConstantExpr>());
- }
-
- llvm_unreachable("unsupported SDBM expression kind");
- }
-
- /// Traverse the SDBM expression tree calling `visit` on each node
- /// in depth-first preorder.
- void walkPreorder(SDBMExpr expr) { return walk</*isPreorder=*/true>(expr); }
-
- /// Traverse the SDBM expression tree calling `visit` on each node in
- /// depth-first postorder.
- void walkPostorder(SDBMExpr expr) { return walk</*isPreorder=*/false>(expr); }
-
-protected:
- /// Default visitors do nothing.
- void visitSum(SDBMSumExpr) {}
- void visitDiff(SDBMDiffExpr) {}
- void visitStripe(SDBMStripeExpr) {}
- void visitDim(SDBMDimExpr) {}
- void visitSymbol(SDBMSymbolExpr) {}
- void visitNeg(SDBMNegExpr) {}
- void visitConstant(SDBMConstantExpr) {}
-
- /// Default implementation of visitDirect dispatches to the dedicated for sums
- /// or delegates to visitTerm for the other expression kinds. Concrete
- /// visitors can overload it.
- Result visitDirect(SDBMDirectExpr expr) {
- auto *derived = static_cast<Derived *>(this);
- if (auto sum = expr.dyn_cast<SDBMSumExpr>())
- return derived->visitSum(sum);
- else
- return derived->visitTerm(expr.cast<SDBMTermExpr>());
- }
-
- /// Default implementation of visitTerm dispatches to the special functions
- /// for stripes and other variables. Concrete visitors can override it.
- Result visitTerm(SDBMTermExpr expr) {
- auto *derived = static_cast<Derived *>(this);
- if (expr.getKind() == SDBMExprKind::Stripe)
- return derived->visitStripe(expr.cast<SDBMStripeExpr>());
- else
- return derived->visitInput(expr.cast<SDBMInputExpr>());
- }
-
- /// Default implementation of visitInput dispatches to the special
- /// functions for dimensions or symbols. Concrete visitors can override it to
- /// visit all variables instead.
- Result visitInput(SDBMInputExpr expr) {
- auto *derived = static_cast<Derived *>(this);
- if (expr.getKind() == SDBMExprKind::DimId)
- return derived->visitDim(expr.cast<SDBMDimExpr>());
- else
- return derived->visitSymbol(expr.cast<SDBMSymbolExpr>());
- }
-
- /// Default implementation of visitVarying dispatches to the special
- /// functions for variables and negations thereof. Concrete visitors can
- /// override it to visit all variables and negations instead.
- Result visitVarying(SDBMVaryingExpr expr) {
- auto *derived = static_cast<Derived *>(this);
- if (auto var = expr.dyn_cast<SDBMDirectExpr>())
- return derived->visitDirect(var);
- else if (auto neg = expr.dyn_cast<SDBMNegExpr>())
- return derived->visitNeg(neg);
- else if (auto diff = expr.dyn_cast<SDBMDiffExpr>())
- return derived->visitDiff(diff);
-
- llvm_unreachable("unhandled subtype of varying SDBM expression");
- }
-
- template <bool isPreorder> void walk(SDBMExpr expr) {
- if (isPreorder)
- visit(expr);
- if (auto sumExpr = expr.dyn_cast<SDBMSumExpr>()) {
- walk<isPreorder>(sumExpr.getLHS());
- walk<isPreorder>(sumExpr.getRHS());
- } else if (auto diffExpr = expr.dyn_cast<SDBMDiffExpr>()) {
- walk<isPreorder>(diffExpr.getLHS());
- walk<isPreorder>(diffExpr.getRHS());
- } else if (auto stripeExpr = expr.dyn_cast<SDBMStripeExpr>()) {
- walk<isPreorder>(stripeExpr.getLHS());
- walk<isPreorder>(stripeExpr.getStripeFactor());
- } else if (auto negExpr = expr.dyn_cast<SDBMNegExpr>()) {
- walk<isPreorder>(negExpr.getVar());
- }
- if (!isPreorder)
- visit(expr);
- }
-};
-
-/// Overloaded arithmetic operators for SDBM expressions asserting that their
-/// arguments have the proper SDBM expression subtype. Perform canonicalization
-/// and constant folding on these expressions.
-namespace ops_assertions {
-
-/// Add two SDBM expressions. At least one of the expressions must be a
-/// constant or a negation, but both expressions cannot be negations
-/// simultaneously.
-SDBMExpr operator+(SDBMExpr lhs, SDBMExpr rhs);
-inline SDBMExpr operator+(SDBMExpr lhs, int64_t rhs) {
- return lhs + SDBMConstantExpr::get(lhs.getDialect(), rhs);
-}
-inline SDBMExpr operator+(int64_t lhs, SDBMExpr rhs) {
- return SDBMConstantExpr::get(rhs.getDialect(), lhs) + rhs;
-}
-
-/// Subtract an SDBM expression from another SDBM expression. Both expressions
-/// must not be difference expressions.
-SDBMExpr operator-(SDBMExpr lhs, SDBMExpr rhs);
-inline SDBMExpr operator-(SDBMExpr lhs, int64_t rhs) {
- return lhs - SDBMConstantExpr::get(lhs.getDialect(), rhs);
-}
-inline SDBMExpr operator-(int64_t lhs, SDBMExpr rhs) {
- return SDBMConstantExpr::get(rhs.getDialect(), lhs) - rhs;
-}
-
-/// Construct a stripe expression from a positive expression and a positive
-/// constant stripe factor.
-SDBMExpr stripe(SDBMExpr expr, SDBMExpr factor);
-inline SDBMExpr stripe(SDBMExpr expr, int64_t factor) {
- return stripe(expr, SDBMConstantExpr::get(expr.getDialect(), factor));
-}
-} // namespace ops_assertions
-
-} // end namespace mlir
-
-namespace llvm {
-// SDBMExpr hash just like pointers.
-template <> struct DenseMapInfo<mlir::SDBMExpr> {
- static mlir::SDBMExpr getEmptyKey() {
- auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
- return mlir::SDBMExpr(static_cast<mlir::SDBMExpr::ImplType *>(pointer));
- }
- static mlir::SDBMExpr getTombstoneKey() {
- auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
- return mlir::SDBMExpr(static_cast<mlir::SDBMExpr::ImplType *>(pointer));
- }
- static unsigned getHashValue(mlir::SDBMExpr expr) {
- return expr.hash_value();
- }
- static bool isEqual(mlir::SDBMExpr lhs, mlir::SDBMExpr rhs) {
- return lhs == rhs;
- }
-};
-
-// SDBMDirectExpr hash just like pointers.
-template <> struct DenseMapInfo<mlir::SDBMDirectExpr> {
- static mlir::SDBMDirectExpr getEmptyKey() {
- auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
- return mlir::SDBMDirectExpr(
- static_cast<mlir::SDBMExpr::ImplType *>(pointer));
- }
- static mlir::SDBMDirectExpr getTombstoneKey() {
- auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
- return mlir::SDBMDirectExpr(
- static_cast<mlir::SDBMExpr::ImplType *>(pointer));
- }
- static unsigned getHashValue(mlir::SDBMDirectExpr expr) {
- return expr.hash_value();
- }
- static bool isEqual(mlir::SDBMDirectExpr lhs, mlir::SDBMDirectExpr rhs) {
- return lhs == rhs;
- }
-};
-
-// SDBMTermExpr hash just like pointers.
-template <> struct DenseMapInfo<mlir::SDBMTermExpr> {
- static mlir::SDBMTermExpr getEmptyKey() {
- auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
- return mlir::SDBMTermExpr(static_cast<mlir::SDBMExpr::ImplType *>(pointer));
- }
- static mlir::SDBMTermExpr getTombstoneKey() {
- auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
- return mlir::SDBMTermExpr(static_cast<mlir::SDBMExpr::ImplType *>(pointer));
- }
- static unsigned getHashValue(mlir::SDBMTermExpr expr) {
- return expr.hash_value();
- }
- static bool isEqual(mlir::SDBMTermExpr lhs, mlir::SDBMTermExpr rhs) {
- return lhs == rhs;
- }
-};
-
-// SDBMConstantExpr hash just like pointers.
-template <> struct DenseMapInfo<mlir::SDBMConstantExpr> {
- static mlir::SDBMConstantExpr getEmptyKey() {
- auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
- return mlir::SDBMConstantExpr(
- static_cast<mlir::SDBMExpr::ImplType *>(pointer));
- }
- static mlir::SDBMConstantExpr getTombstoneKey() {
- auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
- return mlir::SDBMConstantExpr(
- static_cast<mlir::SDBMExpr::ImplType *>(pointer));
- }
- static unsigned getHashValue(mlir::SDBMConstantExpr expr) {
- return expr.hash_value();
- }
- static bool isEqual(mlir::SDBMConstantExpr lhs, mlir::SDBMConstantExpr rhs) {
- return lhs == rhs;
- }
-};
-} // namespace llvm
-
-#endif // MLIR_DIALECT_SDBM_SDBMEXPR_H
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index c52dae3fd1b5..5cf0429942ca 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -35,7 +35,6 @@
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/Dialect/SCF/SCF.h"
-#include "mlir/Dialect/SDBM/SDBMDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
@@ -75,7 +74,6 @@ inline void registerAllDialects(DialectRegistry &registry) {
vector::VectorDialect,
NVVM::NVVMDialect,
ROCDL::ROCDLDialect,
- SDBMDialect,
shape::ShapeDialect,
sparse_tensor::SparseTensorDialect,
tensor::TensorDialect,
diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
index de946beef0d9..8a6f08ab3b83 100644
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -17,7 +17,6 @@ add_subdirectory(PDL)
add_subdirectory(PDLInterp)
add_subdirectory(Quant)
add_subdirectory(SCF)
-add_subdirectory(SDBM)
add_subdirectory(Shape)
add_subdirectory(SparseTensor)
add_subdirectory(SPIRV)
diff --git a/mlir/lib/Dialect/SDBM/CMakeLists.txt b/mlir/lib/Dialect/SDBM/CMakeLists.txt
deleted file mode 100644
index db2b9ac85472..000000000000
--- a/mlir/lib/Dialect/SDBM/CMakeLists.txt
+++ /dev/null
@@ -1,11 +0,0 @@
-add_mlir_dialect_library(MLIRSDBM
- SDBM.cpp
- SDBMDialect.cpp
- SDBMExpr.cpp
-
- ADDITIONAL_HEADER_DIRS
- ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SDBM
-
- LINK_LIBS PUBLIC
- MLIRIR
- )
diff --git a/mlir/lib/Dialect/SDBM/SDBM.cpp b/mlir/lib/Dialect/SDBM/SDBM.cpp
deleted file mode 100644
index df24e77bc4f2..000000000000
--- a/mlir/lib/Dialect/SDBM/SDBM.cpp
+++ /dev/null
@@ -1,551 +0,0 @@
-//===- SDBM.cpp - MLIR SDBM implementation --------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// A striped difference-bound matrix (SDBM) is a set in Z^N (or R^N) defined
-// as {(x_1, ... x_n) | f(x_1, ... x_n) >= 0} where f is an SDBM expression.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/SDBM/SDBM.h"
-#include "mlir/Dialect/SDBM/SDBMExpr.h"
-
-#include "llvm/ADT/DenseMap.h"
-#include "llvm/ADT/SetVector.h"
-#include "llvm/Support/FormatVariadic.h"
-#include "llvm/Support/raw_ostream.h"
-
-using namespace mlir;
-
-// Helper function for SDBM construction that collects information necessary to
-// start building an SDBM in one sweep. In particular, it records the largest
-// position of a dimension in `dim`, that of a symbol in `symbol` as well as
-// collects all unique stripe expressions in `stripes`. Uses SetVector to
-// ensure these expressions always have the same order.
-static void collectSDBMBuildInfo(SDBMExpr expr, int &dim, int &symbol,
- llvm::SmallSetVector<SDBMExpr, 8> &stripes) {
- struct Visitor : public SDBMVisitor<Visitor> {
- void visitDim(SDBMDimExpr dimExpr) {
- int p = dimExpr.getPosition();
- if (p > maxDimPosition)
- maxDimPosition = p;
- }
- void visitSymbol(SDBMSymbolExpr symbExpr) {
- int p = symbExpr.getPosition();
- if (p > maxSymbPosition)
- maxSymbPosition = p;
- }
- void visitStripe(SDBMStripeExpr stripeExpr) { stripes.insert(stripeExpr); }
-
- Visitor(llvm::SmallSetVector<SDBMExpr, 8> &stripes) : stripes(stripes) {}
-
- int maxDimPosition = -1;
- int maxSymbPosition = -1;
- llvm::SmallSetVector<SDBMExpr, 8> &stripes;
- };
-
- Visitor visitor(stripes);
- visitor.walkPostorder(expr);
- dim = std::max(dim, visitor.maxDimPosition);
- symbol = std::max(symbol, visitor.maxSymbPosition);
-}
-
-namespace {
-// Utility class for SDBMBuilder. Represents a value that can be inserted in
-// the SDB matrix that corresponds to "v0 - v1 + C <= 0", where v0 and v1 is
-// any combination of the positive and negative positions. Since multiple
-// variables can be declared equal to the same stripe expression, the
-// constraints on this expression must be reflected to all these variables. For
-// example, if
-// d0 = s0 # 42
-// d1 = s0 # 42
-// d2 = s1 # 2
-// d3 = s1 # 2
-// the constraint
-// s0 # 42 - s1 # 2 <= C
-// should be reflected in the DB matrix as
-// d0 - d2 <= C
-// d1 - d2 <= C
-// d0 - d3 <= C
-// d1 - d3 <= C
-// since the DB matrix has no knowledge of the transitive equality between d0,
-// d1 and s0 # 42 as well as between d2, d3 and s1 # 2. This knowledge can be
-// obtained by computing a transitive closure, which is impossible until the
-// DBM is actually built.
-struct SDBMBuilderResult {
- // Positions in the matrix of the variables taken with the "+" sign in the
- // difference expression, 0 if it is a constant rather than a variable.
- SmallVector<unsigned, 2> positivePos;
-
- // Positions in the matrix of the variables taken with the "-" sign in the
- // difference expression, 0 if it is a constant rather than a variable.
- SmallVector<unsigned, 2> negativePos;
-
- // Constant value in the difference expression.
- int64_t value = 0;
-};
-
-// Visitor for building an SDBM from SDBM expressions. After traversing an SDBM
-// expression, produces an update to the SDB matrix specifying the positions in
-// the matrix and the negated value that should be stored. Both the positive
-// and the negative positions may be lists of indices in cases where multiple
-// variables are equal to the same stripe expression. In such cases, the update
-// applies to the cross product of positions because elements involved in the
-// update are (transitively) equal and should have the same constraints, but we
-// may not have an explicit equality for them.
-struct SDBMBuilder : public SDBMVisitor<SDBMBuilder, SDBMBuilderResult> {
-public:
- // A difference expression produces both the positive and the negative
- // coordinate in the matrix, recursively traversing the LHS and the RHS. The
- // value is the difference between values obtained from LHS and RHS.
- SDBMBuilderResult visitDiff(SDBMDiffExpr diffExpr) {
- auto lhs = visit(diffExpr.getLHS());
- auto rhs = visit(diffExpr.getRHS());
- assert(lhs.negativePos.size() == 1 && lhs.negativePos[0] == 0 &&
- "unexpected negative expression in a difference expression");
- assert(rhs.negativePos.size() == 1 && lhs.negativePos[0] == 0 &&
- "unexpected negative expression in a difference expression");
-
- SDBMBuilderResult result;
- result.positivePos = lhs.positivePos;
- result.negativePos = rhs.positivePos;
- result.value = lhs.value - rhs.value;
- return result;
- }
-
- // An input expression is always taken with the "+" sign and therefore
- // produces a positive coordinate keeping the negative coordinate zero for an
- // eventual constant.
- SDBMBuilderResult visitInput(SDBMInputExpr expr) {
- SDBMBuilderResult r;
- r.positivePos.push_back(linearPosition(expr));
- r.negativePos.push_back(0);
- return r;
- }
-
- // A stripe expression is always equal to one or more variables, which may be
- // temporaries, and appears with a "+" sign in the SDBM expression tree. Take
- // the positions of the corresponding variables as positive coordinates.
- SDBMBuilderResult visitStripe(SDBMStripeExpr expr) {
- SDBMBuilderResult r;
- assert(pointExprToStripe.count(expr));
- r.positivePos = pointExprToStripe[expr];
- r.negativePos.push_back(0);
- return r;
- }
-
- // A constant expression has both coordinates at zero.
- SDBMBuilderResult visitConstant(SDBMConstantExpr expr) {
- SDBMBuilderResult r;
- r.positivePos.push_back(0);
- r.negativePos.push_back(0);
- r.value = expr.getValue();
- return r;
- }
-
- // A negation expression swaps the positive and the negative coordinates
- // and also negates the constant value.
- SDBMBuilderResult visitNeg(SDBMNegExpr expr) {
- SDBMBuilderResult result = visit(expr.getVar());
- std::swap(result.positivePos, result.negativePos);
- result.value = -result.value;
- return result;
- }
-
- // The RHS of a sum expression must be a constant and therefore must have both
- // positive and negative coordinates at zero. Take the sum of the values
- // between LHS and RHS and keep LHS coordinates.
- SDBMBuilderResult visitSum(SDBMSumExpr expr) {
- auto lhs = visit(expr.getLHS());
- auto rhs = visit(expr.getRHS());
- for (auto pos : rhs.negativePos) {
- (void)pos;
- assert(pos == 0 && "unexpected variable on the RHS of SDBM sum");
- }
- for (auto pos : rhs.positivePos) {
- (void)pos;
- assert(pos == 0 && "unexpected variable on the RHS of SDBM sum");
- }
-
- lhs.value += rhs.value;
- return lhs;
- }
-
- SDBMBuilder(DenseMap<SDBMExpr, SmallVector<unsigned, 2>> &pointExprToStripe,
- function_ref<unsigned(SDBMInputExpr)> callback)
- : pointExprToStripe(pointExprToStripe), linearPosition(callback) {}
-
- DenseMap<SDBMExpr, SmallVector<unsigned, 2>> &pointExprToStripe;
- function_ref<unsigned(SDBMInputExpr)> linearPosition;
-};
-} // namespace
-
-SDBM SDBM::get(ArrayRef<SDBMExpr> inequalities, ArrayRef<SDBMExpr> equalities) {
- SDBM result;
-
- // TODO: consider detecting equalities in the list of inequalities.
- // This is potentially expensive and requires to
- // - create a list of negated inequalities (may allocate under lock);
- // - perform a pairwise comparison of direct and negated inequalities;
- // - copy the lists of equalities and inequalities, and move entries between
- // them;
- // only for the purpose of sparing a temporary variable in cases where an
- // implicit equality between a variable and a stripe expression is present in
- // the input.
-
- // Do the first sweep over (in)equalities to collect the information necessary
- // to allocate the SDB matrix (number of dimensions, symbol and temporary
- // variables required for stripe expressions).
- llvm::SmallSetVector<SDBMExpr, 8> stripes;
- int maxDim = -1;
- int maxSymbol = -1;
- for (auto expr : inequalities)
- collectSDBMBuildInfo(expr, maxDim, maxSymbol, stripes);
- for (auto expr : equalities)
- collectSDBMBuildInfo(expr, maxDim, maxSymbol, stripes);
- // Indexing of dimensions starts with 0, obtain the number of dimensions by
- // incrementing the maximal position of the dimension seen in expressions.
- result.numDims = maxDim + 1;
- result.numSymbols = maxSymbol + 1;
- result.numTemporaries = 0;
-
- // Helper function that returns the position of the variable represented by
- // an SDBM input expression.
- auto linearPosition = [result](SDBMInputExpr expr) {
- if (expr.isa<SDBMDimExpr>())
- return result.getDimPosition(expr.getPosition());
- return result.getSymbolPosition(expr.getPosition());
- };
-
- // Check if some stripe expressions are equal to another variable. In
- // particular, look for the equalities of the form
- // d0 - stripe-expression = 0, or
- // stripe-expression - d0 = 0.
- // There may be multiple variables that are equal to the same stripe
- // expression. Keep track of those in pointExprToStripe.
- // There may also be multiple stripe expressions equal to the same variable.
- // Introduce a temporary variable for each of those.
- DenseMap<SDBMExpr, SmallVector<unsigned, 2>> pointExprToStripe;
- unsigned numTemporaries = 0;
-
- auto updateStripePointMaps = [&numTemporaries, &result, &pointExprToStripe,
- linearPosition](SDBMInputExpr input,
- SDBMExpr expr) {
- unsigned position = linearPosition(input);
- if (result.stripeToPoint.count(position) &&
- result.stripeToPoint[position] != expr) {
- position = result.getNumVariables() + numTemporaries++;
- }
- pointExprToStripe[expr].push_back(position);
- result.stripeToPoint.insert(std::make_pair(position, expr));
- };
-
- for (auto eq : equalities) {
- auto diffExpr = eq.dyn_cast<SDBMDiffExpr>();
- if (!diffExpr)
- continue;
-
- auto lhs = diffExpr.getLHS();
- auto rhs = diffExpr.getRHS();
- auto lhsInput = lhs.dyn_cast<SDBMInputExpr>();
- auto rhsInput = rhs.dyn_cast<SDBMInputExpr>();
-
- if (lhsInput && stripes.count(rhs))
- updateStripePointMaps(lhsInput, rhs);
- if (rhsInput && stripes.count(lhs))
- updateStripePointMaps(rhsInput, lhs);
- }
-
- // Assign the remaining stripe expressions to temporary variables. These
- // expressions are the ones that could not be associated with an existing
- // variable in the previous step.
- for (auto expr : stripes) {
- if (pointExprToStripe.count(expr))
- continue;
- unsigned position = result.getNumVariables() + numTemporaries++;
- pointExprToStripe[expr].push_back(position);
- result.stripeToPoint.insert(std::make_pair(position, expr));
- }
-
- // Create the DBM matrix, initialized to infinity values for the least tight
- // possible bound (x - y <= infinity is always true).
- result.numTemporaries = numTemporaries;
- result.matrix.resize(result.getNumVariables() * result.getNumVariables(),
- IntInfty::infinity());
-
- SDBMBuilder builder(pointExprToStripe, linearPosition);
-
- // Only keep the tightest constraint. Since we transform everything into
- // less-than-or-equals-to inequalities, keep the smallest constant. For
- // example, if we have d0 - d1 <= 42 and d0 - d1 <= 2, we keep the latter.
- // Note that the input expressions are in the shape of d0 - d1 + -42 <= 0
- // so we negate the value before storing it.
- // In case where the positive and the negative positions are equal, the
- // corresponding expression has the form d0 - d0 + -42 <= 0. If the constant
- // value is positive, the set defined by SDBM is trivially empty. We store
- // this value anyway and continue processing to maintain the correspondence
- // between the matrix form and the list-of-SDBMExpr form.
- // TODO: we may want to reconsider this once we have canonicalization
- // or simplification in place
- auto updateMatrix = [](SDBM &sdbm, const SDBMBuilderResult &r) {
- for (auto positivePos : r.positivePos) {
- for (auto negativePos : r.negativePos) {
- auto &m = sdbm.at(negativePos, positivePos);
- m = m < -r.value ? m : -r.value;
- }
- }
- };
-
- // Do the second sweep on (in)equalities, updating the SDB matrix to reflect
- // the constraints.
- for (auto ineq : inequalities)
- updateMatrix(result, builder.visit(ineq));
-
- // An equality f(x) = 0 is represented as a pair of inequalities {f(x) >= 0;
- // f(x) <= 0} or, alternatively, {-f(x) <= 0 and f(x) <= 0}.
- for (auto eq : equalities) {
- updateMatrix(result, builder.visit(eq));
- updateMatrix(result, builder.visit(-eq));
- }
-
- // Add the inequalities induced by stripe equalities.
- // t = x # C => t <= x <= t + C - 1
- // which is equivalent to
- // {t - x <= 0;
- // x - t - (C - 1) <= 0}.
- for (const auto &pair : result.stripeToPoint) {
- auto stripe = pair.second.cast<SDBMStripeExpr>();
- SDBMBuilderResult update = builder.visit(stripe.getLHS());
- assert(update.negativePos.size() == 1 && update.negativePos[0] == 0 &&
- "unexpected negated variable in stripe expression");
- assert(update.value == 0 &&
- "unexpected non-zero value in stripe expression");
- update.negativePos.clear();
- update.negativePos.push_back(pair.first);
- update.value = -(stripe.getStripeFactor().getValue() - 1);
- updateMatrix(result, update);
-
- std::swap(update.negativePos, update.positivePos);
- update.value = 0;
- updateMatrix(result, update);
- }
-
- return result;
-}
-
-// Given a row and a column position in the square DBM, insert one equality
-// or up to two inequalities that correspond the entries (col, row) and (row,
-// col) in the DBM. `rowExpr` and `colExpr` contain the expressions such that
-// colExpr - rowExpr <= V where V is the value at (row, col) in the DBM.
-// If one of the expressions is derived from another using a stripe operation,
-// check if the inequalities induced by the stripe operation subsume the
-// inequalities defined in the DBM and if so, elide these inequalities.
-void SDBM::convertDBMElement(unsigned row, unsigned col, SDBMTermExpr rowExpr,
- SDBMTermExpr colExpr,
- SmallVectorImpl<SDBMExpr> &inequalities,
- SmallVectorImpl<SDBMExpr> &equalities) {
- using ops_assertions::operator+;
- using ops_assertions::operator-;
-
- auto diffIJValue = at(col, row);
- auto diffJIValue = at(row, col);
-
- // If symmetric entries are opposite, the corresponding expressions are equal.
- if (diffIJValue.isFinite() &&
- diffIJValue.getValue() == -diffJIValue.getValue()) {
- equalities.push_back(rowExpr - colExpr - diffIJValue.getValue());
- return;
- }
-
- // Given an inequality x0 - x1 <= A, check if x0 is a stripe variable derived
- // from x1: x0 = x1 # B. If so, it would imply the constraints
- // x0 <= x1 <= x0 + (B - 1) <=> x0 - x1 <= 0 and x1 - x0 <= (B - 1).
- // Therefore, if A >= 0, this inequality is subsumed by that implied
- // by the stripe equality and thus can be elided.
- // Similarly, check if x1 is a stripe variable derived from x0: x1 = x0 # C.
- // If so, it would imply the constraints x1 <= x0 <= x1 + (C - 1) <=>
- // <=> x1 - x0 <= 0 and x0 - x1 <= (C - 1). Therefore, if A >= (C - 1), this
- // inequality can be elided.
- //
- // Note: x0 and x1 may be a stripe expressions themselves, we rely on stripe
- // expressions being stored without temporaries on the RHS and being passed
- // into this function as is.
- auto canElide = [this](unsigned x0, unsigned x1, SDBMExpr x0Expr,
- SDBMExpr x1Expr, int64_t value) {
- if (stripeToPoint.count(x0)) {
- auto stripe = stripeToPoint[x0].cast<SDBMStripeExpr>();
- SDBMDirectExpr var = stripe.getLHS();
- if (x1Expr == var && value >= 0)
- return true;
- }
- if (stripeToPoint.count(x1)) {
- auto stripe = stripeToPoint[x1].cast<SDBMStripeExpr>();
- SDBMDirectExpr var = stripe.getLHS();
- if (x0Expr == var && value >= stripe.getStripeFactor().getValue() - 1)
- return true;
- }
- return false;
- };
-
- // Check row - col.
- if (diffIJValue.isFinite() &&
- !canElide(row, col, rowExpr, colExpr, diffIJValue.getValue())) {
- inequalities.push_back(rowExpr - colExpr - diffIJValue.getValue());
- }
- // Check col - row.
- if (diffJIValue.isFinite() &&
- !canElide(col, row, colExpr, rowExpr, diffJIValue.getValue())) {
- inequalities.push_back(colExpr - rowExpr - diffJIValue.getValue());
- }
-}
-
-// The values on the main diagonal correspond to the upper bound on the
-// difference between a variable and itself: d0 - d0 <= C, or alternatively
-// to -C <= 0. Only construct the inequalities when C is negative, which
-// are trivially false but necessary for the returned system of inequalities
-// to indicate that the set it defines is empty.
-void SDBM::convertDBMDiagonalElement(unsigned pos, SDBMTermExpr expr,
- SmallVectorImpl<SDBMExpr> &inequalities) {
- auto selfDifference = at(pos, pos);
- if (selfDifference.isFinite() && selfDifference < 0) {
- auto selfDifferenceValueExpr =
- SDBMConstantExpr::get(expr.getDialect(), -selfDifference.getValue());
- inequalities.push_back(selfDifferenceValueExpr);
- }
-}
-
-void SDBM::getSDBMExpressions(SDBMDialect *dialect,
- SmallVectorImpl<SDBMExpr> &inequalities,
- SmallVectorImpl<SDBMExpr> &equalities) {
- using ops_assertions::operator-;
- using ops_assertions::operator+;
-
- // Helper function that creates an SDBMInputExpr given the linearized position
- // of variable in the DBM.
- auto getInput = [dialect, this](unsigned matrixPos) -> SDBMInputExpr {
- if (matrixPos < numDims)
- return SDBMDimExpr::get(dialect, matrixPos);
- return SDBMSymbolExpr::get(dialect, matrixPos - numDims);
- };
-
- // The top-left value corresponds to inequality 0 <= C. If C is negative, the
- // set defined by SDBM is trivially empty and we add the constraint -C <= 0 to
- // the list of inequalities. Otherwise, the constraint is trivially true and
- // we ignore it.
- auto difference = at(0, 0);
- if (difference.isFinite() && difference < 0) {
- inequalities.push_back(
- SDBMConstantExpr::get(dialect, -difference.getValue()));
- }
-
- // Traverse the segment of the matrix that involves non-temporary variables.
- unsigned numTrueVariables = numDims + numSymbols;
- for (unsigned i = 0; i < numTrueVariables; ++i) {
- // The first row and column represent numerical upper and lower bound on
- // each variable. Transform them into inequalities if they are finite.
- auto upperBound = at(0, 1 + i);
- auto lowerBound = at(1 + i, 0);
- auto inputExpr = getInput(i);
- if (upperBound.isFinite() &&
- upperBound.getValue() == -lowerBound.getValue()) {
- equalities.push_back(inputExpr - upperBound.getValue());
- } else if (upperBound.isFinite()) {
- inequalities.push_back(inputExpr - upperBound.getValue());
- } else if (lowerBound.isFinite()) {
- inequalities.push_back(-inputExpr - lowerBound.getValue());
- }
-
- // Introduce trivially false inequalities if required by diagonal elements.
- convertDBMDiagonalElement(1 + i, inputExpr, inequalities);
-
- // Introduce equalities or inequalities between non-temporary variables.
- for (unsigned j = 0; j < i; ++j) {
- convertDBMElement(1 + i, 1 + j, getInput(i), getInput(j), inequalities,
- equalities);
- }
- }
-
- // Add equalities for stripe expressions that define non-temporary
- // variables. Temporary variables will be substituted into their uses and
- // should not appear in the resulting equalities.
- for (const auto &stripePair : stripeToPoint) {
- unsigned position = stripePair.first;
- if (position < 1 + numTrueVariables) {
- equalities.push_back(getInput(position - 1) - stripePair.second);
- }
- }
-
- // Add equalities / inequalities involving temporaries by replacing the
- // temporaries with stripe expressions that define them.
- for (unsigned i = 1 + numTrueVariables, e = getNumVariables(); i < e; ++i) {
- // Mixed constraints involving one temporary (j) and one non-temporary (i)
- // variable.
- for (unsigned j = 0; j < numTrueVariables; ++j) {
- convertDBMElement(i, 1 + j, stripeToPoint[i].cast<SDBMStripeExpr>(),
- getInput(j), inequalities, equalities);
- }
-
- // Constraints involving only temporary variables.
- for (unsigned j = 1 + numTrueVariables; j < i; ++j) {
- convertDBMElement(i, j, stripeToPoint[i].cast<SDBMStripeExpr>(),
- stripeToPoint[j].cast<SDBMStripeExpr>(), inequalities,
- equalities);
- }
-
- // Introduce trivially false inequalities if required by diagonal elements.
- convertDBMDiagonalElement(i, stripeToPoint[i].cast<SDBMStripeExpr>(),
- inequalities);
- }
-}
-
-void SDBM::print(raw_ostream &os) {
- unsigned numVariables = getNumVariables();
-
- // Helper function that prints the name of the variable given its linearized
- // position in the DBM.
- auto getVarName = [this](unsigned matrixPos) -> std::string {
- if (matrixPos == 0)
- return "cst";
- matrixPos -= 1;
- if (matrixPos < numDims)
- return std::string(llvm::formatv("d{0}", matrixPos));
- matrixPos -= numDims;
- if (matrixPos < numSymbols)
- return std::string(llvm::formatv("s{0}", matrixPos));
- matrixPos -= numSymbols;
- return std::string(llvm::formatv("t{0}", matrixPos));
- };
-
- // Header row.
- os << " cst";
- for (unsigned i = 1; i < numVariables; ++i) {
- os << llvm::formatv(" {0,4}", getVarName(i));
- }
- os << '\n';
-
- // Data rows.
- for (unsigned i = 0; i < numVariables; ++i) {
- os << llvm::formatv("{0,-4}", getVarName(i));
- for (unsigned j = 0; j < numVariables; ++j) {
- IntInfty value = operator()(i, j);
- if (!value.isFinite())
- os << " inf";
- else
- os << llvm::formatv(" {0,4}", value.getValue());
- }
- os << '\n';
- }
-
- // Explanation of temporaries.
- for (const auto &pair : stripeToPoint) {
- os << getVarName(pair.first) << " = ";
- pair.second.print(os);
- os << '\n';
- }
-}
-
-void SDBM::dump() { print(llvm::errs()); }
diff --git a/mlir/lib/Dialect/SDBM/SDBMDialect.cpp b/mlir/lib/Dialect/SDBM/SDBMDialect.cpp
deleted file mode 100644
index 4e3e050b4a4f..000000000000
--- a/mlir/lib/Dialect/SDBM/SDBMDialect.cpp
+++ /dev/null
@@ -1,23 +0,0 @@
-//===- SDBMDialect.cpp - MLIR SDBM Dialect --------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/SDBM/SDBMDialect.h"
-#include "SDBMExprDetail.h"
-
-using namespace mlir;
-
-SDBMDialect::SDBMDialect(MLIRContext *context)
- : Dialect(getDialectNamespace(), context, TypeID::get<SDBMDialect>()) {
- uniquer.registerParametricStorageType<detail::SDBMBinaryExprStorage>();
- uniquer.registerParametricStorageType<detail::SDBMConstantExprStorage>();
- uniquer.registerParametricStorageType<detail::SDBMDiffExprStorage>();
- uniquer.registerParametricStorageType<detail::SDBMNegExprStorage>();
- uniquer.registerParametricStorageType<detail::SDBMTermExprStorage>();
-}
-
-SDBMDialect::~SDBMDialect() = default;
diff --git a/mlir/lib/Dialect/SDBM/SDBMExpr.cpp b/mlir/lib/Dialect/SDBM/SDBMExpr.cpp
deleted file mode 100644
index 5adcbcc78d52..000000000000
--- a/mlir/lib/Dialect/SDBM/SDBMExpr.cpp
+++ /dev/null
@@ -1,732 +0,0 @@
-//===- SDBMExpr.cpp - MLIR SDBM Expression implementation -----------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// A striped difference-bound matrix (SDBM) expression is a constant expression,
-// an identifier, a binary expression with constant RHS and +, stripe operators
-// or a difference expression between two identifiers.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/SDBM/SDBMExpr.h"
-#include "SDBMExprDetail.h"
-#include "mlir/Dialect/SDBM/SDBMDialect.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/AffineExprVisitor.h"
-
-#include "llvm/Support/raw_ostream.h"
-
-using namespace mlir;
-
-namespace {
-/// A simple compositional matcher for AffineExpr
-///
-/// Example usage:
-///
-/// ```c++
-/// AffineExprMatcher x, C, m;
-/// AffineExprMatcher pattern1 = ((x % C) * m) + x;
-/// AffineExprMatcher pattern2 = x + ((x % C) * m);
-/// if (pattern1.match(expr) || pattern2.match(expr)) {
-/// ...
-/// }
-/// ```
-class AffineExprMatcherStorage;
-class AffineExprMatcher {
-public:
- AffineExprMatcher();
- AffineExprMatcher(const AffineExprMatcher &other);
-
- AffineExprMatcher operator+(AffineExprMatcher other) {
- return AffineExprMatcher(AffineExprKind::Add, *this, other);
- }
- AffineExprMatcher operator*(AffineExprMatcher other) {
- return AffineExprMatcher(AffineExprKind::Mul, *this, other);
- }
- AffineExprMatcher floorDiv(AffineExprMatcher other) {
- return AffineExprMatcher(AffineExprKind::FloorDiv, *this, other);
- }
- AffineExprMatcher ceilDiv(AffineExprMatcher other) {
- return AffineExprMatcher(AffineExprKind::CeilDiv, *this, other);
- }
- AffineExprMatcher operator%(AffineExprMatcher other) {
- return AffineExprMatcher(AffineExprKind::Mod, *this, other);
- }
-
- AffineExpr match(AffineExpr expr);
- AffineExpr matched();
- Optional<int> getMatchedConstantValue();
-
-private:
- AffineExprMatcher(AffineExprKind k, AffineExprMatcher a, AffineExprMatcher b);
- AffineExprKind kind; // only used to match in binary op cases.
- // A shared_ptr allows multiple references to same matcher storage without
- // worrying about ownership or dealing with an arena. To be cleaned up if we
- // go with this.
- std::shared_ptr<AffineExprMatcherStorage> storage;
-};
-
-class AffineExprMatcherStorage {
-public:
- AffineExprMatcherStorage() {}
- AffineExprMatcherStorage(const AffineExprMatcherStorage &other)
- : subExprs(other.subExprs.begin(), other.subExprs.end()),
- matched(other.matched) {}
- AffineExprMatcherStorage(ArrayRef<AffineExprMatcher> exprs)
- : subExprs(exprs.begin(), exprs.end()) {}
- AffineExprMatcherStorage(AffineExprMatcher &a, AffineExprMatcher &b)
- : subExprs({a, b}) {}
- SmallVector<AffineExprMatcher, 0> subExprs;
- AffineExpr matched;
-};
-} // namespace
-
-AffineExprMatcher::AffineExprMatcher()
- : kind(AffineExprKind::Constant), storage(new AffineExprMatcherStorage()) {}
-
-AffineExprMatcher::AffineExprMatcher(const AffineExprMatcher &other)
- : kind(other.kind), storage(other.storage) {}
-
-Optional<int> AffineExprMatcher::getMatchedConstantValue() {
- if (auto cst = storage->matched.dyn_cast<AffineConstantExpr>())
- return cst.getValue();
- return None;
-}
-
-AffineExpr AffineExprMatcher::match(AffineExpr expr) {
- if (kind > AffineExprKind::LAST_AFFINE_BINARY_OP) {
- if (storage->matched)
- if (storage->matched != expr)
- return AffineExpr();
- storage->matched = expr;
- return storage->matched;
- }
- if (kind != expr.getKind()) {
- return AffineExpr();
- }
- if (auto bin = expr.dyn_cast<AffineBinaryOpExpr>()) {
- if (!storage->subExprs.empty() &&
- !storage->subExprs[0].match(bin.getLHS())) {
- return AffineExpr();
- }
- if (!storage->subExprs.empty() &&
- !storage->subExprs[1].match(bin.getRHS())) {
- return AffineExpr();
- }
- if (storage->matched)
- if (storage->matched != expr)
- return AffineExpr();
- storage->matched = expr;
- return storage->matched;
- }
- llvm_unreachable("binary expected");
-}
-
-AffineExpr AffineExprMatcher::matched() { return storage->matched; }
-
-AffineExprMatcher::AffineExprMatcher(AffineExprKind k, AffineExprMatcher a,
- AffineExprMatcher b)
- : kind(k), storage(new AffineExprMatcherStorage(a, b)) {
- storage->subExprs.push_back(a);
- storage->subExprs.push_back(b);
-}
-
-//===----------------------------------------------------------------------===//
-// SDBMExpr
-//===----------------------------------------------------------------------===//
-
-SDBMExprKind SDBMExpr::getKind() const { return impl->getKind(); }
-
-MLIRContext *SDBMExpr::getContext() const {
- return impl->dialect->getContext();
-}
-
-SDBMDialect *SDBMExpr::getDialect() const { return impl->dialect; }
-
-void SDBMExpr::print(raw_ostream &os) const {
- struct Printer : public SDBMVisitor<Printer> {
- Printer(raw_ostream &ostream) : prn(ostream) {}
-
- void visitSum(SDBMSumExpr expr) {
- visit(expr.getLHS());
- prn << " + ";
- visit(expr.getRHS());
- }
- void visitDiff(SDBMDiffExpr expr) {
- visit(expr.getLHS());
- prn << " - ";
- visit(expr.getRHS());
- }
- void visitDim(SDBMDimExpr expr) { prn << 'd' << expr.getPosition(); }
- void visitSymbol(SDBMSymbolExpr expr) { prn << 's' << expr.getPosition(); }
- void visitStripe(SDBMStripeExpr expr) {
- SDBMDirectExpr lhs = expr.getLHS();
- bool isTerm = lhs.isa<SDBMTermExpr>();
- if (!isTerm)
- prn << '(';
- visit(lhs);
- if (!isTerm)
- prn << ')';
- prn << " # ";
- visitConstant(expr.getStripeFactor());
- }
- void visitNeg(SDBMNegExpr expr) {
- bool isSum = expr.getVar().isa<SDBMSumExpr>();
- prn << '-';
- if (isSum)
- prn << '(';
- visit(expr.getVar());
- if (isSum)
- prn << ')';
- }
- void visitConstant(SDBMConstantExpr expr) { prn << expr.getValue(); }
-
- raw_ostream &prn;
- };
- Printer printer(os);
- printer.visit(*this);
-}
-
-void SDBMExpr::dump() const {
- print(llvm::errs());
- llvm::errs() << '\n';
-}
-
-namespace {
-// Helper class to perform negation of an SDBM expression.
-struct SDBMNegator : public SDBMVisitor<SDBMNegator, SDBMExpr> {
- // Any term expression is wrapped into a negation expression.
- // -(x) = -x
- SDBMExpr visitDirect(SDBMDirectExpr expr) { return SDBMNegExpr::get(expr); }
- // A negation expression is unwrapped.
- // -(-x) = x
- SDBMExpr visitNeg(SDBMNegExpr expr) { return expr.getVar(); }
- // The value of the constant is negated.
- SDBMExpr visitConstant(SDBMConstantExpr expr) {
- return SDBMConstantExpr::get(expr.getDialect(), -expr.getValue());
- }
-
- // Terms of a difference are interchanged. Since only the LHS of a diff
- // expression is allowed to be a sum with a constant, we need to recreate the
- // sum with the negated value:
- // -((x + C) - y) = (y - C) - x.
- SDBMExpr visitDiff(SDBMDiffExpr expr) {
- // If the LHS is just a term, we can do straightforward interchange.
- if (auto term = expr.getLHS().dyn_cast<SDBMTermExpr>())
- return SDBMDiffExpr::get(expr.getRHS(), term);
-
- auto sum = expr.getLHS().cast<SDBMSumExpr>();
- auto cst = visitConstant(sum.getRHS()).cast<SDBMConstantExpr>();
- return SDBMDiffExpr::get(SDBMSumExpr::get(expr.getRHS(), cst),
- sum.getLHS());
- }
-};
-} // namespace
-
-SDBMExpr SDBMExpr::operator-() { return SDBMNegator().visit(*this); }
-
-//===----------------------------------------------------------------------===//
-// SDBMSumExpr
-//===----------------------------------------------------------------------===//
-
-SDBMSumExpr SDBMSumExpr::get(SDBMTermExpr lhs, SDBMConstantExpr rhs) {
- assert(lhs && "expected SDBM variable expression");
- assert(rhs && "expected SDBM constant");
-
- // If LHS of a sum is another sum, fold the constant RHS parts.
- if (auto lhsSum = lhs.dyn_cast<SDBMSumExpr>()) {
- lhs = lhsSum.getLHS();
- rhs = SDBMConstantExpr::get(rhs.getDialect(),
- rhs.getValue() + lhsSum.getRHS().getValue());
- }
-
- StorageUniquer &uniquer = lhs.getDialect()->getUniquer();
- return uniquer.get<detail::SDBMBinaryExprStorage>(
- /*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Add), lhs, rhs);
-}
-
-SDBMTermExpr SDBMSumExpr::getLHS() const {
- return static_cast<ImplType *>(impl)->lhs.cast<SDBMTermExpr>();
-}
-
-SDBMConstantExpr SDBMSumExpr::getRHS() const {
- return static_cast<ImplType *>(impl)->rhs;
-}
-
-AffineExpr SDBMExpr::getAsAffineExpr() const {
- struct Converter : public SDBMVisitor<Converter, AffineExpr> {
- AffineExpr visitSum(SDBMSumExpr expr) {
- AffineExpr lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
- return lhs + rhs;
- }
-
- AffineExpr visitStripe(SDBMStripeExpr expr) {
- AffineExpr lhs = visit(expr.getLHS()),
- rhs = visit(expr.getStripeFactor());
- return lhs - (lhs % rhs);
- }
-
- AffineExpr visitDiff(SDBMDiffExpr expr) {
- AffineExpr lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
- return lhs - rhs;
- }
-
- AffineExpr visitDim(SDBMDimExpr expr) {
- return getAffineDimExpr(expr.getPosition(), expr.getContext());
- }
-
- AffineExpr visitSymbol(SDBMSymbolExpr expr) {
- return getAffineSymbolExpr(expr.getPosition(), expr.getContext());
- }
-
- AffineExpr visitNeg(SDBMNegExpr expr) {
- return getAffineBinaryOpExpr(AffineExprKind::Mul,
- getAffineConstantExpr(-1, expr.getContext()),
- visit(expr.getVar()));
- }
-
- AffineExpr visitConstant(SDBMConstantExpr expr) {
- return getAffineConstantExpr(expr.getValue(), expr.getContext());
- }
- } converter;
- return converter.visit(*this);
-}
-
-// Given a direct expression `expr`, add the given constant to it and pass the
-// resulting expression to `builder` before returning its result. If the
-// expression is already a sum expression, update its constant and extract the
-// LHS if the constant becomes zero. Otherwise, construct a sum expression.
-template <typename Result>
-static Result addConstantAndSink(SDBMDirectExpr expr, int64_t constant,
- bool negated,
- function_ref<Result(SDBMDirectExpr)> builder) {
- SDBMDialect *dialect = expr.getDialect();
- if (auto sumExpr = expr.dyn_cast<SDBMSumExpr>()) {
- if (negated)
- constant = sumExpr.getRHS().getValue() - constant;
- else
- constant += sumExpr.getRHS().getValue();
-
- if (constant != 0) {
- auto sum = SDBMSumExpr::get(sumExpr.getLHS(),
- SDBMConstantExpr::get(dialect, constant));
- return builder(sum);
- } else {
- return builder(sumExpr.getLHS());
- }
- }
- if (constant != 0)
- return builder(SDBMSumExpr::get(
- expr.cast<SDBMTermExpr>(),
- SDBMConstantExpr::get(dialect, negated ? -constant : constant)));
- return expr;
-}
-
-// Construct an expression lhs + constant while maintaining the canonical form
-// of the SDBM expressions, in particular sink the constant expression to the
-// nearest sum expression in the left subtree of the expression tree.
-static SDBMExpr addConstant(SDBMVaryingExpr lhs, int64_t constant) {
- if (auto lhsDiff = lhs.dyn_cast<SDBMDiffExpr>())
- return addConstantAndSink<SDBMExpr>(
- lhsDiff.getLHS(), constant, /*negated=*/false,
- [lhsDiff](SDBMDirectExpr e) {
- return SDBMDiffExpr::get(e, lhsDiff.getRHS());
- });
- if (auto lhsNeg = lhs.dyn_cast<SDBMNegExpr>())
- return addConstantAndSink<SDBMExpr>(
- lhsNeg.getVar(), constant, /*negated=*/true,
- [](SDBMDirectExpr e) { return SDBMNegExpr::get(e); });
- if (auto lhsSum = lhs.dyn_cast<SDBMSumExpr>())
- return addConstantAndSink<SDBMExpr>(lhsSum, constant, /*negated=*/false,
- [](SDBMDirectExpr e) { return e; });
- if (constant != 0)
- return SDBMSumExpr::get(lhs.cast<SDBMTermExpr>(),
- SDBMConstantExpr::get(lhs.getDialect(), constant));
- return lhs;
-}
-
-// Build a difference expression given a direct expression and a negation
-// expression.
-static SDBMExpr buildDiffExpr(SDBMDirectExpr lhs, SDBMNegExpr rhs) {
- // Fold (x + C) - (x + D) = C - D.
- if (lhs.getTerm() == rhs.getVar().getTerm())
- return SDBMConstantExpr::get(
- lhs.getDialect(), lhs.getConstant() - rhs.getVar().getConstant());
-
- return SDBMDiffExpr::get(
- addConstantAndSink<SDBMDirectExpr>(lhs, -rhs.getVar().getConstant(),
- /*negated=*/false,
- [](SDBMDirectExpr e) { return e; }),
- rhs.getVar().getTerm());
-}
-
-// Try folding an expression (lhs + rhs) where at least one of the operands
-// contains a negated variable, i.e. is a negation or a difference expression.
-static SDBMExpr foldSumDiff(SDBMExpr lhs, SDBMExpr rhs) {
- // If exactly one of LHS, RHS is a negation expression, we can construct
- // a difference expression, which is a special kind in SDBM.
- auto lhsDirect = lhs.dyn_cast<SDBMDirectExpr>();
- auto rhsDirect = rhs.dyn_cast<SDBMDirectExpr>();
- auto lhsNeg = lhs.dyn_cast<SDBMNegExpr>();
- auto rhsNeg = rhs.dyn_cast<SDBMNegExpr>();
-
- if (lhsDirect && rhsNeg)
- return buildDiffExpr(lhsDirect, rhsNeg);
- if (lhsNeg && rhsDirect)
- return buildDiffExpr(rhsDirect, lhsNeg);
-
- // If a subexpression appears in a diff expression on the LHS(RHS) of a
- // sum expression where it also appears on the RHS(LHS) with the opposite
- // sign, we can simplify it away and obtain the SDBM form.
- auto lhsDiff = lhs.dyn_cast<SDBMDiffExpr>();
- auto rhsDiff = rhs.dyn_cast<SDBMDiffExpr>();
-
- // -(x + A) + ((x + B) - y) = -(y + (A - B))
- if (lhsNeg && rhsDiff &&
- lhsNeg.getVar().getTerm() == rhsDiff.getLHS().getTerm()) {
- int64_t constant =
- lhsNeg.getVar().getConstant() - rhsDiff.getLHS().getConstant();
- // RHS of the diff is a term expression, its sum with a constant is a direct
- // expression.
- return SDBMNegExpr::get(
- addConstant(rhsDiff.getRHS(), constant).cast<SDBMDirectExpr>());
- }
-
- // (x + A) + ((y + B) - x) = (y + B) + A.
- if (lhsDirect && rhsDiff && lhsDirect.getTerm() == rhsDiff.getRHS())
- return addConstant(rhsDiff.getLHS(), lhsDirect.getConstant());
-
- // ((x + A) - y) + (-(x + B)) = -(y + (B - A)).
- if (lhsDiff && rhsNeg &&
- lhsDiff.getLHS().getTerm() == rhsNeg.getVar().getTerm()) {
- int64_t constant =
- rhsNeg.getVar().getConstant() - lhsDiff.getLHS().getConstant();
- // RHS of the diff is a term expression, its sum with a constant is a direct
- // expression.
- return SDBMNegExpr::get(
- addConstant(lhsDiff.getRHS(), constant).cast<SDBMDirectExpr>());
- }
-
- // ((x + A) - y) + (y + B) = (x + A) + B.
- if (rhsDirect && lhsDiff && rhsDirect.getTerm() == lhsDiff.getRHS())
- return addConstant(lhsDiff.getLHS(), rhsDirect.getConstant());
-
- return {};
-}
-
-Optional<SDBMExpr> SDBMExpr::tryConvertAffineExpr(AffineExpr affine) {
- struct Converter : public AffineExprVisitor<Converter, SDBMExpr> {
- SDBMExpr visitAddExpr(AffineBinaryOpExpr expr) {
- auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
- if (!lhs || !rhs)
- return {};
-
- // In a "add" AffineExpr, the constant always appears on the right. If
- // there were two constants, they would have been folded away.
- assert(!lhs.isa<SDBMConstantExpr>() && "non-canonical affine expression");
-
- // If RHS is a constant, we can always extend the SDBM expression to
- // include it by sinking the constant into the nearest sum expression.
- if (auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>()) {
- int64_t constant = rhsConstant.getValue();
- auto varying = lhs.dyn_cast<SDBMVaryingExpr>();
- assert(varying && "unexpected uncanonicalized sum of constants");
- return addConstant(varying, constant);
- }
-
- // Try building a difference expression if one of the values is negated,
- // or check if a difference on either hand side cancels out the outer term
- // so as to remain correct within SDBM. Return null otherwise.
- return foldSumDiff(lhs, rhs);
- }
-
- SDBMExpr visitMulExpr(AffineBinaryOpExpr expr) {
- // Attempt to recover a stripe expression "x # C = (x floordiv C) * C".
- AffineExprMatcher x, C;
- AffineExprMatcher pattern = (x.floorDiv(C)) * C;
- if (pattern.match(expr)) {
- if (SDBMExpr converted = visit(x.matched())) {
- if (auto varConverted = converted.dyn_cast<SDBMTermExpr>())
- // TODO: return varConverted.stripe(C.getConstantValue());
- return SDBMStripeExpr::get(
- varConverted,
- SDBMConstantExpr::get(dialect,
- C.getMatchedConstantValue().getValue()));
- }
- }
-
- auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
- if (!lhs || !rhs)
- return {};
-
- // In a "mul" AffineExpr, the constant always appears on the right. If
- // there were two constants, they would have been folded away.
- assert(!lhs.isa<SDBMConstantExpr>() && "non-canonical affine expression");
- auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
- if (!rhsConstant)
- return {};
-
- // The only supported "multiplication" expression is an SDBM is dimension
- // negation, that is a product of dimension and constant -1.
- if (rhsConstant.getValue() != -1)
- return {};
-
- if (auto lhsVar = lhs.dyn_cast<SDBMTermExpr>())
- return SDBMNegExpr::get(lhsVar);
- if (auto lhsDiff = lhs.dyn_cast<SDBMDiffExpr>())
- return SDBMNegator().visitDiff(lhsDiff);
-
- // Other multiplications are not allowed in SDBM.
- return {};
- }
-
- SDBMExpr visitModExpr(AffineBinaryOpExpr expr) {
- auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
- if (!lhs || !rhs)
- return {};
-
- // 'mod' can only be converted to SDBM if its LHS is a direct expression
- // and its RHS is a constant. Then it `x mod c = x - x stripe c`.
- auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
- auto lhsVar = lhs.dyn_cast<SDBMDirectExpr>();
- if (!lhsVar || !rhsConstant)
- return {};
- return SDBMDiffExpr::get(lhsVar,
- SDBMStripeExpr::get(lhsVar, rhsConstant));
- }
-
- // `a floordiv b = (a stripe b) / b`, but we have no division in SDBM
- SDBMExpr visitFloorDivExpr(AffineBinaryOpExpr expr) { return {}; }
- SDBMExpr visitCeilDivExpr(AffineBinaryOpExpr expr) { return {}; }
-
- // Dimensions, symbols and constants are converted trivially.
- SDBMExpr visitConstantExpr(AffineConstantExpr expr) {
- return SDBMConstantExpr::get(dialect, expr.getValue());
- }
- SDBMExpr visitDimExpr(AffineDimExpr expr) {
- return SDBMDimExpr::get(dialect, expr.getPosition());
- }
- SDBMExpr visitSymbolExpr(AffineSymbolExpr expr) {
- return SDBMSymbolExpr::get(dialect, expr.getPosition());
- }
-
- SDBMDialect *dialect;
- } converter;
- converter.dialect = affine.getContext()->getOrLoadDialect<SDBMDialect>();
-
- if (auto result = converter.visit(affine))
- return result;
- return None;
-}
-
-//===----------------------------------------------------------------------===//
-// SDBMDiffExpr
-//===----------------------------------------------------------------------===//
-
-SDBMDiffExpr SDBMDiffExpr::get(SDBMDirectExpr lhs, SDBMTermExpr rhs) {
- assert(lhs && "expected SDBM dimension");
- assert(rhs && "expected SDBM dimension");
-
- StorageUniquer &uniquer = lhs.getDialect()->getUniquer();
- return uniquer.get<detail::SDBMDiffExprStorage>(/*initFn=*/{}, lhs, rhs);
-}
-
-SDBMDirectExpr SDBMDiffExpr::getLHS() const {
- return static_cast<ImplType *>(impl)->lhs;
-}
-
-SDBMTermExpr SDBMDiffExpr::getRHS() const {
- return static_cast<ImplType *>(impl)->rhs;
-}
-
-//===----------------------------------------------------------------------===//
-// SDBMDirectExpr
-//===----------------------------------------------------------------------===//
-
-SDBMTermExpr SDBMDirectExpr::getTerm() {
- if (auto sum = dyn_cast<SDBMSumExpr>())
- return sum.getLHS();
- return cast<SDBMTermExpr>();
-}
-
-int64_t SDBMDirectExpr::getConstant() {
- if (auto sum = dyn_cast<SDBMSumExpr>())
- return sum.getRHS().getValue();
- return 0;
-}
-
-//===----------------------------------------------------------------------===//
-// SDBMStripeExpr
-//===----------------------------------------------------------------------===//
-
-SDBMStripeExpr SDBMStripeExpr::get(SDBMDirectExpr var,
- SDBMConstantExpr stripeFactor) {
- assert(var && "expected SDBM variable expression");
- assert(stripeFactor && "expected non-null stripe factor");
- if (stripeFactor.getValue() <= 0)
- llvm::report_fatal_error("non-positive stripe factor");
-
- StorageUniquer &uniquer = var.getDialect()->getUniquer();
- return uniquer.get<detail::SDBMBinaryExprStorage>(
- /*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Stripe), var,
- stripeFactor);
-}
-
-SDBMDirectExpr SDBMStripeExpr::getLHS() const {
- if (SDBMVaryingExpr lhs = static_cast<ImplType *>(impl)->lhs)
- return lhs.cast<SDBMDirectExpr>();
- return {};
-}
-
-SDBMConstantExpr SDBMStripeExpr::getStripeFactor() const {
- return static_cast<ImplType *>(impl)->rhs;
-}
-
-//===----------------------------------------------------------------------===//
-// SDBMInputExpr
-//===----------------------------------------------------------------------===//
-
-unsigned SDBMInputExpr::getPosition() const {
- return static_cast<ImplType *>(impl)->position;
-}
-
-//===----------------------------------------------------------------------===//
-// SDBMDimExpr
-//===----------------------------------------------------------------------===//
-
-SDBMDimExpr SDBMDimExpr::get(SDBMDialect *dialect, unsigned position) {
- assert(dialect && "expected non-null dialect");
-
- auto assignDialect = [dialect](detail::SDBMTermExprStorage *storage) {
- storage->dialect = dialect;
- };
-
- StorageUniquer &uniquer = dialect->getUniquer();
- return uniquer.get<detail::SDBMTermExprStorage>(
- assignDialect, static_cast<unsigned>(SDBMExprKind::DimId), position);
-}
-
-//===----------------------------------------------------------------------===//
-// SDBMSymbolExpr
-//===----------------------------------------------------------------------===//
-
-SDBMSymbolExpr SDBMSymbolExpr::get(SDBMDialect *dialect, unsigned position) {
- assert(dialect && "expected non-null dialect");
-
- auto assignDialect = [dialect](detail::SDBMTermExprStorage *storage) {
- storage->dialect = dialect;
- };
-
- StorageUniquer &uniquer = dialect->getUniquer();
- return uniquer.get<detail::SDBMTermExprStorage>(
- assignDialect, static_cast<unsigned>(SDBMExprKind::SymbolId), position);
-}
-
-//===----------------------------------------------------------------------===//
-// SDBMConstantExpr
-//===----------------------------------------------------------------------===//
-
-SDBMConstantExpr SDBMConstantExpr::get(SDBMDialect *dialect, int64_t value) {
- assert(dialect && "expected non-null dialect");
-
- auto assignCtx = [dialect](detail::SDBMConstantExprStorage *storage) {
- storage->dialect = dialect;
- };
-
- StorageUniquer &uniquer = dialect->getUniquer();
- return uniquer.get<detail::SDBMConstantExprStorage>(assignCtx, value);
-}
-
-int64_t SDBMConstantExpr::getValue() const {
- return static_cast<ImplType *>(impl)->constant;
-}
-
-//===----------------------------------------------------------------------===//
-// SDBMNegExpr
-//===----------------------------------------------------------------------===//
-
-SDBMNegExpr SDBMNegExpr::get(SDBMDirectExpr var) {
- assert(var && "expected non-null SDBM direct expression");
-
- StorageUniquer &uniquer = var.getDialect()->getUniquer();
- return uniquer.get<detail::SDBMNegExprStorage>(/*initFn=*/{}, var);
-}
-
-SDBMDirectExpr SDBMNegExpr::getVar() const {
- return static_cast<ImplType *>(impl)->expr;
-}
-
-SDBMExpr mlir::ops_assertions::operator+(SDBMExpr lhs, SDBMExpr rhs) {
- if (auto folded = foldSumDiff(lhs, rhs))
- return folded;
- assert(!(lhs.isa<SDBMNegExpr>() && rhs.isa<SDBMNegExpr>()) &&
- "a sum of negated expressions is a negation of a sum of variables and "
- "not a correct SDBM");
-
- // Fold (x - y) + (y - x) = 0.
- auto lhsDiff = lhs.dyn_cast<SDBMDiffExpr>();
- auto rhsDiff = rhs.dyn_cast<SDBMDiffExpr>();
- if (lhsDiff && rhsDiff) {
- if (lhsDiff.getLHS() == rhsDiff.getRHS() &&
- lhsDiff.getRHS() == rhsDiff.getLHS())
- return SDBMConstantExpr::get(lhs.getDialect(), 0);
- }
-
- // If LHS is a constant and RHS is not, swap the order to get into a supported
- // sum case. From now on, RHS must be a constant.
- auto lhsConstant = lhs.dyn_cast<SDBMConstantExpr>();
- auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
- if (!rhsConstant && lhsConstant) {
- std::swap(lhs, rhs);
- std::swap(lhsConstant, rhsConstant);
- }
- assert(rhsConstant && "at least one operand must be a constant");
-
- // Constant-fold if LHS is also a constant.
- if (lhsConstant)
- return SDBMConstantExpr::get(lhs.getDialect(), lhsConstant.getValue() +
- rhsConstant.getValue());
- return addConstant(lhs.cast<SDBMVaryingExpr>(), rhsConstant.getValue());
-}
-
-SDBMExpr mlir::ops_assertions::operator-(SDBMExpr lhs, SDBMExpr rhs) {
- // Fold x - x == 0.
- if (lhs == rhs)
- return SDBMConstantExpr::get(lhs.getDialect(), 0);
-
- // LHS and RHS may be constants.
- auto lhsConstant = lhs.dyn_cast<SDBMConstantExpr>();
- auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
-
- // Constant fold if both LHS and RHS are constants.
- if (lhsConstant && rhsConstant)
- return SDBMConstantExpr::get(lhs.getDialect(), lhsConstant.getValue() -
- rhsConstant.getValue());
-
- // Replace a difference with a sum with a negated value if one of LHS and RHS
- // is a constant:
- // x - C == x + (-C);
- // C - x == -x + C.
- // This calls into operator+ for further simplification.
- if (rhsConstant)
- return lhs + (-rhsConstant);
- if (lhsConstant)
- return -rhs + lhsConstant;
-
- return buildDiffExpr(lhs.cast<SDBMDirectExpr>(), (-rhs).cast<SDBMNegExpr>());
-}
-
-SDBMExpr mlir::ops_assertions::stripe(SDBMExpr expr, SDBMExpr factor) {
- auto constantFactor = factor.cast<SDBMConstantExpr>();
- assert(constantFactor.getValue() > 0 && "non-positive stripe");
-
- // Fold x # 1 = x.
- if (constantFactor.getValue() == 1)
- return expr;
-
- return SDBMStripeExpr::get(expr.cast<SDBMDirectExpr>(), constantFactor);
-}
diff --git a/mlir/lib/Dialect/SDBM/SDBMExprDetail.h b/mlir/lib/Dialect/SDBM/SDBMExprDetail.h
deleted file mode 100644
index 8d91334c807e..000000000000
--- a/mlir/lib/Dialect/SDBM/SDBMExprDetail.h
+++ /dev/null
@@ -1,137 +0,0 @@
-//===- SDBMExprDetail.h - MLIR SDBM Expression storage details --*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This holds implementation details of SDBMExpr, in particular underlying
-// storage types.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_IR_SDBMEXPRDETAIL_H
-#define MLIR_IR_SDBMEXPRDETAIL_H
-
-#include "mlir/Dialect/SDBM/SDBMExpr.h"
-#include "mlir/Support/StorageUniquer.h"
-
-namespace mlir {
-
-class SDBMDialect;
-
-namespace detail {
-
-// Base storage class for SDBMExpr.
-struct SDBMExprStorage : public StorageUniquer::BaseStorage {
- SDBMExprKind getKind() { return kind; }
-
- SDBMDialect *dialect;
- SDBMExprKind kind;
-};
-
-// Storage class for SDBM sum and stripe expressions.
-struct SDBMBinaryExprStorage : public SDBMExprStorage {
- using KeyTy = std::tuple<unsigned, SDBMDirectExpr, SDBMConstantExpr>;
-
- bool operator==(const KeyTy &key) const {
- return static_cast<SDBMExprKind>(std::get<0>(key)) == kind &&
- std::get<1>(key) == lhs && std::get<2>(key) == rhs;
- }
-
- static SDBMBinaryExprStorage *
- construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
- auto *result = allocator.allocate<SDBMBinaryExprStorage>();
- result->lhs = std::get<1>(key);
- result->rhs = std::get<2>(key);
- result->dialect = result->lhs.getDialect();
- result->kind = static_cast<SDBMExprKind>(std::get<0>(key));
- return result;
- }
-
- SDBMDirectExpr lhs;
- SDBMConstantExpr rhs;
-};
-
-// Storage class for SDBM difference expressions.
-struct SDBMDiffExprStorage : public SDBMExprStorage {
- using KeyTy = std::pair<SDBMDirectExpr, SDBMTermExpr>;
-
- bool operator==(const KeyTy &key) const {
- return std::get<0>(key) == lhs && std::get<1>(key) == rhs;
- }
-
- static SDBMDiffExprStorage *
- construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
- auto *result = allocator.allocate<SDBMDiffExprStorage>();
- result->lhs = std::get<0>(key);
- result->rhs = std::get<1>(key);
- result->dialect = result->lhs.getDialect();
- result->kind = SDBMExprKind::Diff;
- return result;
- }
-
- SDBMDirectExpr lhs;
- SDBMTermExpr rhs;
-};
-
-// Storage class for SDBM constant expressions.
-struct SDBMConstantExprStorage : public SDBMExprStorage {
- using KeyTy = int64_t;
-
- bool operator==(const KeyTy &key) const { return constant == key; }
-
- static SDBMConstantExprStorage *
- construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
- auto *result = allocator.allocate<SDBMConstantExprStorage>();
- result->constant = key;
- result->kind = SDBMExprKind::Constant;
- return result;
- }
-
- int64_t constant;
-};
-
-// Storage class for SDBM dimension and symbol expressions.
-struct SDBMTermExprStorage : public SDBMExprStorage {
- using KeyTy = std::pair<unsigned, unsigned>;
-
- bool operator==(const KeyTy &key) const {
- return kind == static_cast<SDBMExprKind>(key.first) &&
- position == key.second;
- }
-
- static SDBMTermExprStorage *
- construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
- auto *result = allocator.allocate<SDBMTermExprStorage>();
- result->kind = static_cast<SDBMExprKind>(key.first);
- result->position = key.second;
- return result;
- }
-
- unsigned position;
-};
-
-// Storage class for SDBM negation expressions.
-struct SDBMNegExprStorage : public SDBMExprStorage {
- using KeyTy = SDBMDirectExpr;
-
- bool operator==(const KeyTy &key) const { return key == expr; }
-
- static SDBMNegExprStorage *
- construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
- auto *result = allocator.allocate<SDBMNegExprStorage>();
- result->expr = key;
- result->dialect = key.getDialect();
- result->kind = SDBMExprKind::Neg;
- return result;
- }
-
- SDBMDirectExpr expr;
-};
-
-} // end namespace detail
-} // end namespace mlir
-
-#endif // MLIR_IR_SDBMEXPRDETAIL_H
diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
index 5ce620c6c2b6..416cfee7efad 100644
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -1,5 +1,4 @@
add_subdirectory(CAPI)
-add_subdirectory(SDBM)
add_subdirectory(lib)
if(MLIR_ENABLE_BINDINGS_PYTHON)
@@ -75,7 +74,6 @@ set(MLIR_TEST_DEPENDS
mlir-lsp-server
mlir-opt
mlir-reduce
- mlir-sdbm-api-test
mlir-tblgen
mlir-translate
mlir_runner_utils
diff --git a/mlir/test/SDBM/CMakeLists.txt b/mlir/test/SDBM/CMakeLists.txt
deleted file mode 100644
index 633fae707c85..000000000000
--- a/mlir/test/SDBM/CMakeLists.txt
+++ /dev/null
@@ -1,19 +0,0 @@
-set(LLVM_LINK_COMPONENTS
- Core
- Support
- )
-
-add_llvm_executable(mlir-sdbm-api-test
- sdbm-api-test.cpp
-)
-
-llvm_update_compile_flags(mlir-sdbm-api-test)
-
-target_link_libraries(mlir-sdbm-api-test
- PRIVATE
- MLIRIR
- MLIRSDBM
- MLIRSupport
-)
-
-target_include_directories(mlir-sdbm-api-test PRIVATE ..)
diff --git a/mlir/test/SDBM/lit.local.cfg b/mlir/test/SDBM/lit.local.cfg
deleted file mode 100644
index 81261555b424..000000000000
--- a/mlir/test/SDBM/lit.local.cfg
+++ /dev/null
@@ -1 +0,0 @@
-config.suffixes.add('.cpp')
diff --git a/mlir/test/SDBM/sdbm-api-test.cpp b/mlir/test/SDBM/sdbm-api-test.cpp
deleted file mode 100644
index 027c584c7409..000000000000
--- a/mlir/test/SDBM/sdbm-api-test.cpp
+++ /dev/null
@@ -1,201 +0,0 @@
-//===- sdbm-api-test.cpp - Tests for SDBM expression APIs -----------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-// RUN: mlir-sdbm-api-test | FileCheck %s
-
-#include "mlir/Dialect/SDBM/SDBM.h"
-#include "mlir/Dialect/SDBM/SDBMDialect.h"
-#include "mlir/Dialect/SDBM/SDBMExpr.h"
-#include "mlir/IR/MLIRContext.h"
-
-#include "llvm/Support/raw_ostream.h"
-
-#include "APITest.h"
-
-using namespace mlir;
-
-
-static MLIRContext *ctx() {
- static thread_local MLIRContext context;
- static thread_local bool once =
- (context.getOrLoadDialect<SDBMDialect>(), true);
- (void)once;
- return &context;
-}
-
-static SDBMDialect *dialect() {
- static thread_local SDBMDialect *d = nullptr;
- if (!d) {
- d = ctx()->getOrLoadDialect<SDBMDialect>();
- }
- return d;
-}
-
-static SDBMExpr dim(unsigned pos) { return SDBMDimExpr::get(dialect(), pos); }
-
-static SDBMExpr symb(unsigned pos) {
- return SDBMSymbolExpr::get(dialect(), pos);
-}
-
-namespace {
-
-using namespace mlir::ops_assertions;
-
-TEST_FUNC(SDBM_SingleConstraint) {
- // Build an SDBM defined by
- // d0 - 3 <= 0 <=> d0 <= 3.
- auto sdbm = SDBM::get(dim(0) - 3, llvm::None);
-
- // CHECK: cst d0
- // CHECK-NEXT: cst inf 3
- // CHECK-NEXT: d0 inf inf
- sdbm.print(llvm::outs());
-}
-
-TEST_FUNC(SDBM_Equality) {
- // Build an SDBM defined by
- //
- // d0 - d1 - 3 = 0
- // <=> {d0 - d1 - 3 <= 0 and d0 - d1 - 3 >= 0}
- // <=> {d0 - d1 <= 3 and d1 - d0 <= -3}.
- auto sdbm = SDBM::get(llvm::None, dim(0) - dim(1) - 3);
-
- // CHECK: cst d0 d1
- // CHECK-NEXT: cst inf inf inf
- // CHECK-NEXT: d0 inf inf -3
- // CHECK-NEXT: d1 inf 3 inf
- sdbm.print(llvm::outs());
-}
-
-TEST_FUNC(SDBM_TrivialSimplification) {
- // Build an SDBM defined by
- //
- // d0 - 3 <= 0 <=> d0 <= 3
- // d0 - 5 <= 0 <=> d0 <= 5
- //
- // which should get simplified on construction to only the former.
- auto sdbm = SDBM::get({dim(0) - 3, dim(0) - 5}, llvm::None);
-
- // CHECK: cst d0
- // CHECK-NEXT: cst inf 3
- // CHECK-NEXT: d0 inf inf
- sdbm.print(llvm::outs());
-}
-
-TEST_FUNC(SDBM_StripeInducedIneqs) {
- // Build an SDBM defined by d1 = d0 # 3, which induces the constraints
- //
- // d1 - d0 <= 0
- // d0 - d1 <= 3 - 1 = 2
- auto sdbm = SDBM::get(llvm::None, dim(1) - stripe(dim(0), 3));
-
- // CHECK: cst d0 d1
- // CHECK-NEXT: cst inf inf inf
- // CHECK-NEXT: d0 inf inf 0
- // CHECK-NEXT: d1 inf 2 0
- // CHECK-NEXT: d1 = d0 # 3
- sdbm.print(llvm::outs());
-}
-
-TEST_FUNC(SDBM_StripeTemporaries) {
- // Build an SDBM defined by d0 # 3 <= 0, which creates a temporary
- // t0 = d0 # 3 leading to a constraint t0 <= 0 and the stripe-induced
- // constraints
- //
- // t0 - d0 <= 0
- // d0 - t0 <= 3 - 1 = 2
- auto sdbm = SDBM::get(stripe(dim(0), 3), llvm::None);
-
- // CHECK: cst d0 t0
- // CHECK-NEXT: cst inf inf 0
- // CHECK-NEXT: d0 inf inf 0
- // CHECK-NEXT: t0 inf 2 inf
- // CHECK-NEXT: t0 = d0 # 3
- sdbm.print(llvm::outs());
-}
-
-TEST_FUNC(SDBM_ElideInducedInequalities) {
- // Build an SDBM defined by a single stripe equality d0 = s0 # 3 and make sure
- // the induced inequalities are not present after converting the SDBM back
- // into lists of expressions.
- auto sdbm = SDBM::get(llvm::None, {dim(0) - stripe(symb(0), 3)});
-
- SmallVector<SDBMExpr, 4> eqs, ineqs;
- sdbm.getSDBMExpressions(dialect(), ineqs, eqs);
- // CHECK-EMPTY:
- for (auto ineq : ineqs)
- ineq.print(llvm::outs() << '\n');
- llvm::outs() << "\n";
-
- // CHECK: d0 - s0 # 3
- // CHECK-EMPTY:
- for (auto eq : eqs)
- eq.print(llvm::outs() << '\n');
- llvm::outs() << "\n\n";
-}
-
-TEST_FUNC(SDBM_StripeTightening) {
- // Build an SDBM defined by
- //
- // d0 = s0 # 3 # 5
- // s0 # 3 # 5 - d1 + 42 = 0
- // s0 # 3 - d0 <= 2
- //
- // where the last inequality is tighter than that induced by the first stripe
- // equality (s0 # 3 - d0 <= 5 - 1 = 4). Check that the conversion from SDBM
- // back to the lists of constraints conserves both the stripe equality and the
- // tighter inequality.
- auto s = stripe(stripe(symb(0), 3), 5);
- auto tight = stripe(symb(0), 3) - dim(0) - 2;
- auto sdbm = SDBM::get({tight}, {s - dim(0), s - dim(1) + 42});
-
- SmallVector<SDBMExpr, 4> eqs, ineqs;
- sdbm.getSDBMExpressions(dialect(), ineqs, eqs);
- // CHECK: s0 # 3 + -2 - d0
- // CHECK-EMPTY:
- for (auto ineq : ineqs)
- ineq.print(llvm::outs() << '\n');
- llvm::outs() << "\n";
-
- // CHECK-DAG: d1 + -42 - d0
- // CHECK-DAG: d0 - s0 # 3 # 5
- for (auto eq : eqs)
- eq.print(llvm::outs() << '\n');
- llvm::outs() << "\n\n";
-}
-
-TEST_FUNC(SDBM_StripeTransitive) {
- // Build an SDBM defined by
- //
- // d0 = d1 # 3
- // d0 = d2 # 7
- //
- // where the same dimension is declared equal to two stripe expressions over
- // different variables. This is practically handled by introducing a
- // temporary variable for the second stripe expression and adding an equality
- // constraint between this variable and the original dimension variable.
- auto sdbm = SDBM::get(
- llvm::None, {stripe(dim(1), 3) - dim(0), stripe(dim(2), 7) - dim(0)});
-
- // CHECK: cst d0 d1 d2 t0
- // CHECK-NEXT: cst inf inf inf inf inf
- // CHECK-NEXT: d0 inf 0 2 inf 0
- // CHECK-NEXT: d1 inf 0 inf inf inf
- // CHECK-NEXT: d2 inf inf inf inf 0
- // CHECK-NEXT: t0 inf 0 inf 6 inf
- // CHECK-NEXT: t0 = d2 # 7
- // CHECK-NEXT: d0 = d1 # 3
- sdbm.print(llvm::outs());
-}
-
-} // end namespace
-
-int main() {
- RUN_TESTS();
- return 0;
-}
diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py
index 83048dd99603..dd38b8fec864 100644
--- a/mlir/test/lit.cfg.py
+++ b/mlir/test/lit.cfg.py
@@ -65,7 +65,6 @@
'mlir-linalg-ods-gen',
'mlir-linalg-ods-yaml-gen',
'mlir-reduce',
- 'mlir-sdbm-api-test',
]
# The following tools are optional
diff --git a/mlir/test/mlir-opt/commandline.mlir b/mlir/test/mlir-opt/commandline.mlir
index 95c476a84163..e42118d86b5d 100644
--- a/mlir/test/mlir-opt/commandline.mlir
+++ b/mlir/test/mlir-opt/commandline.mlir
@@ -21,7 +21,6 @@
// CHECK-NEXT: quant
// CHECK-NEXT: rocdl
// CHECK-NEXT: scf
-// CHECK-NEXT: sdbm
// CHECK-NEXT: shape
// CHECK-NEXT: sparse_tensor
// CHECK-NEXT: spv
diff --git a/mlir/unittests/CMakeLists.txt b/mlir/unittests/CMakeLists.txt
index a8e9212ee255..45558b6d3dce 100644
--- a/mlir/unittests/CMakeLists.txt
+++ b/mlir/unittests/CMakeLists.txt
@@ -11,5 +11,4 @@ add_subdirectory(Interfaces)
add_subdirectory(IR)
add_subdirectory(Pass)
add_subdirectory(Rewrite)
-add_subdirectory(SDBM)
add_subdirectory(TableGen)
diff --git a/mlir/unittests/SDBM/CMakeLists.txt b/mlir/unittests/SDBM/CMakeLists.txt
deleted file mode 100644
index d86f9dda3802..000000000000
--- a/mlir/unittests/SDBM/CMakeLists.txt
+++ /dev/null
@@ -1,7 +0,0 @@
-add_mlir_unittest(MLIRSDBMTests
- SDBMTest.cpp
-)
-target_link_libraries(MLIRSDBMTests
- PRIVATE
- MLIRSDBM
-)
diff --git a/mlir/unittests/SDBM/SDBMTest.cpp b/mlir/unittests/SDBM/SDBMTest.cpp
deleted file mode 100644
index c907aed6258a..000000000000
--- a/mlir/unittests/SDBM/SDBMTest.cpp
+++ /dev/null
@@ -1,449 +0,0 @@
-//===- SDBMTest.cpp - SDBM expression unit tests --------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/SDBM/SDBM.h"
-#include "mlir/Dialect/SDBM/SDBMDialect.h"
-#include "mlir/Dialect/SDBM/SDBMExpr.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/MLIRContext.h"
-#include "gtest/gtest.h"
-
-#include "llvm/ADT/DenseSet.h"
-
-using namespace mlir;
-
-
-static MLIRContext *ctx() {
- static thread_local MLIRContext context;
- context.getOrLoadDialect<SDBMDialect>();
- return &context;
-}
-
-static SDBMDialect *dialect() {
- static thread_local SDBMDialect *d = nullptr;
- if (!d) {
- d = ctx()->getOrLoadDialect<SDBMDialect>();
- }
- return d;
-}
-
-static SDBMExpr dim(unsigned pos) { return SDBMDimExpr::get(dialect(), pos); }
-
-static SDBMExpr symb(unsigned pos) {
- return SDBMSymbolExpr::get(dialect(), pos);
-}
-
-namespace {
-
-using namespace mlir::ops_assertions;
-
-TEST(SDBMOperators, Add) {
- auto expr = dim(0) + 42;
- auto sumExpr = expr.dyn_cast<SDBMSumExpr>();
- ASSERT_TRUE(sumExpr);
- EXPECT_EQ(sumExpr.getLHS(), dim(0));
- EXPECT_EQ(sumExpr.getRHS().getValue(), 42);
-}
-
-TEST(SDBMOperators, AddFolding) {
- auto constant = SDBMConstantExpr::get(dialect(), 2) + 42;
- auto constantExpr = constant.dyn_cast<SDBMConstantExpr>();
- ASSERT_TRUE(constantExpr);
- EXPECT_EQ(constantExpr.getValue(), 44);
-
- auto expr = (dim(0) + 10) + 32;
- auto sumExpr = expr.dyn_cast<SDBMSumExpr>();
- ASSERT_TRUE(sumExpr);
- EXPECT_EQ(sumExpr.getRHS().getValue(), 42);
-
- expr = dim(0) + SDBMNegExpr::get(SDBMDimExpr::get(dialect(), 1));
- auto diffExpr = expr.dyn_cast<SDBMDiffExpr>();
- ASSERT_TRUE(diffExpr);
- EXPECT_EQ(diffExpr.getLHS(), dim(0));
- EXPECT_EQ(diffExpr.getRHS(), dim(1));
-
- auto inverted = SDBMNegExpr::get(SDBMDimExpr::get(dialect(), 1)) + dim(0);
- EXPECT_EQ(inverted, expr);
-
- // Check that opposite values cancel each other, and that we elide the zero
- // constant.
- expr = dim(0) + 42;
- auto onlyDim = expr - 42;
- EXPECT_EQ(onlyDim, dim(0));
-
- // Check that we can sink a constant under a negation.
- expr = -(dim(0) + 2);
- auto negatedSum = (expr + 10).dyn_cast<SDBMNegExpr>();
- ASSERT_TRUE(negatedSum);
- auto sum = negatedSum.getVar().dyn_cast<SDBMSumExpr>();
- ASSERT_TRUE(sum);
- EXPECT_EQ(sum.getRHS().getValue(), -8);
-
- // Sum with zero is the same as the original expression.
- EXPECT_EQ(dim(0) + 0, dim(0));
-
- // Sum of opposite differences is zero.
- auto diffOfDiffs =
- ((dim(0) - dim(1)) + (dim(1) - dim(0))).dyn_cast<SDBMConstantExpr>();
- EXPECT_EQ(diffOfDiffs.getValue(), 0);
-}
-
-TEST(SDBMOperators, AddNegativeTerms) {
- const int64_t A = 7;
- const int64_t B = -5;
- auto x = SDBMDimExpr::get(dialect(), 0);
- auto y = SDBMDimExpr::get(dialect(), 1);
-
- // Check the simplification patterns in addition where one of the variables is
- // cancelled out and the result remains an SDBM.
- EXPECT_EQ(-(x + A) + ((x + B) - y), -(y + (A - B)));
- EXPECT_EQ((x + A) + ((y + B) - x), (y + B) + A);
- EXPECT_EQ(((x + A) - y) + (-(x + B)), -(y + (B - A)));
- EXPECT_EQ(((x + A) - y) + (y + B), (x + A) + B);
-}
-
-TEST(SDBMOperators, Diff) {
- auto expr = dim(0) - dim(1);
- auto diffExpr = expr.dyn_cast<SDBMDiffExpr>();
- ASSERT_TRUE(diffExpr);
- EXPECT_EQ(diffExpr.getLHS(), dim(0));
- EXPECT_EQ(diffExpr.getRHS(), dim(1));
-}
-
-TEST(SDBMOperators, DiffFolding) {
- auto constant = SDBMConstantExpr::get(dialect(), 10) - 3;
- auto constantExpr = constant.dyn_cast<SDBMConstantExpr>();
- ASSERT_TRUE(constantExpr);
- EXPECT_EQ(constantExpr.getValue(), 7);
-
- auto expr = dim(0) - 3;
- auto sumExpr = expr.dyn_cast<SDBMSumExpr>();
- ASSERT_TRUE(sumExpr);
- EXPECT_EQ(sumExpr.getRHS().getValue(), -3);
-
- auto zero = dim(0) - dim(0);
- constantExpr = zero.dyn_cast<SDBMConstantExpr>();
- ASSERT_TRUE(constantExpr);
- EXPECT_EQ(constantExpr.getValue(), 0);
-
- // Check that the constant terms in difference-of-sums are folded.
- // (d0 - 3) - (d1 - 5) = (d0 + 2) - d1
- auto diffOfSums = ((dim(0) - 3) - (dim(1) - 5)).dyn_cast<SDBMDiffExpr>();
- ASSERT_TRUE(diffOfSums);
- auto lhs = diffOfSums.getLHS().dyn_cast<SDBMSumExpr>();
- ASSERT_TRUE(lhs);
- EXPECT_EQ(lhs.getLHS(), dim(0));
- EXPECT_EQ(lhs.getRHS().getValue(), 2);
- EXPECT_EQ(diffOfSums.getRHS(), dim(1));
-
- // Check that identical dimensions with opposite signs cancel each other.
- auto cstOnly = ((dim(0) + 42) - dim(0)).dyn_cast<SDBMConstantExpr>();
- ASSERT_TRUE(cstOnly);
- EXPECT_EQ(cstOnly.getValue(), 42);
-
- // Check that identical terms in sum of diffs cancel out.
- auto dimOnly = (-dim(0) + (dim(0) - dim(1)));
- EXPECT_EQ(dimOnly, -dim(1));
- dimOnly = (dim(0) - dim(1)) + (-dim(0));
- EXPECT_EQ(dimOnly, -dim(1));
- dimOnly = (dim(0) - dim(1)) + dim(1);
- EXPECT_EQ(dimOnly, dim(0));
- dimOnly = dim(0) + (dim(1) - dim(0));
- EXPECT_EQ(dimOnly, dim(1));
-
- // Top-level zero constant is fine.
- cstOnly = (-symb(1) + symb(1)).dyn_cast<SDBMConstantExpr>();
- ASSERT_TRUE(cstOnly);
- EXPECT_EQ(cstOnly.getValue(), 0);
-}
-
-TEST(SDBMOperators, Negate) {
- auto sum = dim(0) + 3;
- auto negated = (-sum).dyn_cast<SDBMNegExpr>();
- ASSERT_TRUE(negated);
- EXPECT_EQ(negated.getVar(), sum);
-}
-
-TEST(SDBMOperators, Stripe) {
- auto expr = stripe(dim(0), 3);
- auto stripeExpr = expr.dyn_cast<SDBMStripeExpr>();
- ASSERT_TRUE(stripeExpr);
- EXPECT_EQ(stripeExpr.getLHS(), dim(0));
- EXPECT_EQ(stripeExpr.getStripeFactor().getValue(), 3);
-}
-
-TEST(SDBM, RoundTripEqs) {
- // Build an SDBM defined by
- //
- // d0 = s0 # 3 # 5
- // s0 # 3 # 5 - d1 + 42 = 0
- //
- // and perform a double round-trip between the "list of equalities" and SDBM
- // representation. After the first round-trip, the equalities may be
- // different due to simplification or equivalent substitutions (e.g., the
- // second equality may become d0 - d1 + 42 = 0). However, there should not
- // be any further simplification after the second round-trip,
-
- // Build the SDBM from a pair of equalities and extract back the lists of
- // inequalities and equalities. Check that all equalities are properly
- // detected and none of them decayed into inequalities.
- auto s = stripe(stripe(symb(0), 3), 5);
- auto sdbm = SDBM::get(llvm::None, {s - dim(0), s - dim(1) + 42});
- SmallVector<SDBMExpr, 4> eqs, ineqs;
- sdbm.getSDBMExpressions(dialect(), ineqs, eqs);
- ASSERT_TRUE(ineqs.empty());
-
- // Do the second round-trip.
- auto sdbm2 = SDBM::get(llvm::None, eqs);
- SmallVector<SDBMExpr, 4> eqs2, ineqs2;
- sdbm2.getSDBMExpressions(dialect(), ineqs2, eqs2);
- ASSERT_EQ(eqs.size(), eqs2.size());
-
- // Check that the sets of equalities are equal, their order is not relevant.
- llvm::DenseSet<SDBMExpr> eqSet, eq2Set;
- eqSet.insert(eqs.begin(), eqs.end());
- eq2Set.insert(eqs2.begin(), eqs2.end());
- EXPECT_EQ(eqSet, eq2Set);
-}
-
-TEST(SDBMExpr, Constant) {
- // We can create constants and query them.
- auto expr = SDBMConstantExpr::get(dialect(), 42);
- EXPECT_EQ(expr.getValue(), 42);
-
- // Two separately created constants with identical values are trivially equal.
- auto expr2 = SDBMConstantExpr::get(dialect(), 42);
- EXPECT_EQ(expr, expr2);
-
- // Hierarchy is okay.
- auto generic = static_cast<SDBMExpr>(expr);
- EXPECT_TRUE(generic.isa<SDBMConstantExpr>());
-}
-
-TEST(SDBMExpr, Dim) {
- // We can create dimension expressions and query them.
- auto expr = SDBMDimExpr::get(dialect(), 0);
- EXPECT_EQ(expr.getPosition(), 0u);
-
- // Two separately created dimensions with the same position are trivially
- // equal.
- auto expr2 = SDBMDimExpr::get(dialect(), 0);
- EXPECT_EQ(expr, expr2);
-
- // Hierarchy is okay.
- auto generic = static_cast<SDBMExpr>(expr);
- EXPECT_TRUE(generic.isa<SDBMDimExpr>());
- EXPECT_TRUE(generic.isa<SDBMInputExpr>());
- EXPECT_TRUE(generic.isa<SDBMTermExpr>());
- EXPECT_TRUE(generic.isa<SDBMDirectExpr>());
- EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
-
- // Dimensions are not Symbols.
- auto symbol = SDBMSymbolExpr::get(dialect(), 0);
- EXPECT_NE(expr, symbol);
- EXPECT_FALSE(expr.isa<SDBMSymbolExpr>());
-}
-
-TEST(SDBMExpr, Symbol) {
- // We can create symbol expressions and query them.
- auto expr = SDBMSymbolExpr::get(dialect(), 0);
- EXPECT_EQ(expr.getPosition(), 0u);
-
- // Two separately created symbols with the same position are trivially equal.
- auto expr2 = SDBMSymbolExpr::get(dialect(), 0);
- EXPECT_EQ(expr, expr2);
-
- // Hierarchy is okay.
- auto generic = static_cast<SDBMExpr>(expr);
- EXPECT_TRUE(generic.isa<SDBMSymbolExpr>());
- EXPECT_TRUE(generic.isa<SDBMInputExpr>());
- EXPECT_TRUE(generic.isa<SDBMTermExpr>());
- EXPECT_TRUE(generic.isa<SDBMDirectExpr>());
- EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
-
- // Dimensions are not Symbols.
- auto symbol = SDBMDimExpr::get(dialect(), 0);
- EXPECT_NE(expr, symbol);
- EXPECT_FALSE(expr.isa<SDBMDimExpr>());
-}
-
-TEST(SDBMExpr, Stripe) {
- auto cst2 = SDBMConstantExpr::get(dialect(), 2);
- auto cst0 = SDBMConstantExpr::get(dialect(), 0);
- auto var = SDBMSymbolExpr::get(dialect(), 0);
-
- // We can create stripe expressions and query them.
- auto expr = SDBMStripeExpr::get(var, cst2);
- EXPECT_EQ(expr.getLHS(), var);
- EXPECT_EQ(expr.getStripeFactor(), cst2);
-
- // Two separately created stripe expressions with the same LHS and RHS are
- // trivially equal.
- auto expr2 = SDBMStripeExpr::get(SDBMSymbolExpr::get(dialect(), 0), cst2);
- EXPECT_EQ(expr, expr2);
-
- // Stripes can be nested.
- SDBMStripeExpr::get(expr, SDBMConstantExpr::get(dialect(), 4));
-
- // Non-positive stripe factors are not allowed.
- EXPECT_DEATH(SDBMStripeExpr::get(var, cst0), "non-positive");
-
- // Stripes can have sums on the LHS.
- SDBMStripeExpr::get(SDBMSumExpr::get(var, cst2), cst2);
-
- // Hierarchy is okay.
- auto generic = static_cast<SDBMExpr>(expr);
- EXPECT_TRUE(generic.isa<SDBMStripeExpr>());
- EXPECT_TRUE(generic.isa<SDBMTermExpr>());
- EXPECT_TRUE(generic.isa<SDBMDirectExpr>());
- EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
-}
-
-TEST(SDBMExpr, Neg) {
- auto cst2 = SDBMConstantExpr::get(dialect(), 2);
- auto var = SDBMSymbolExpr::get(dialect(), 0);
- auto stripe = SDBMStripeExpr::get(var, cst2);
-
- // We can create negation expressions and query them.
- auto expr = SDBMNegExpr::get(var);
- EXPECT_EQ(expr.getVar(), var);
- auto expr2 = SDBMNegExpr::get(stripe);
- EXPECT_EQ(expr2.getVar(), stripe);
-
- // Neg expressions are trivially comparable.
- EXPECT_EQ(expr, SDBMNegExpr::get(var));
-
- // Hierarchy is okay.
- auto generic = static_cast<SDBMExpr>(expr);
- EXPECT_TRUE(generic.isa<SDBMNegExpr>());
- EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
-}
-
-TEST(SDBMExpr, Sum) {
- auto cst2 = SDBMConstantExpr::get(dialect(), 2);
- auto var = SDBMSymbolExpr::get(dialect(), 0);
- auto stripe = SDBMStripeExpr::get(var, cst2);
-
- // We can create sum expressions and query them.
- auto expr = SDBMSumExpr::get(var, cst2);
- EXPECT_EQ(expr.getLHS(), var);
- EXPECT_EQ(expr.getRHS(), cst2);
- auto expr2 = SDBMSumExpr::get(stripe, cst2);
- EXPECT_EQ(expr2.getLHS(), stripe);
- EXPECT_EQ(expr2.getRHS(), cst2);
-
- // Sum expressions are trivially comparable.
- EXPECT_EQ(expr, SDBMSumExpr::get(var, cst2));
-
- // Hierarchy is okay.
- auto generic = static_cast<SDBMExpr>(expr);
- EXPECT_TRUE(generic.isa<SDBMSumExpr>());
- EXPECT_TRUE(generic.isa<SDBMDirectExpr>());
- EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
-}
-
-TEST(SDBMExpr, Diff) {
- auto cst2 = SDBMConstantExpr::get(dialect(), 2);
- auto var = SDBMSymbolExpr::get(dialect(), 0);
- auto stripe = SDBMStripeExpr::get(var, cst2);
-
- // We can create sum expressions and query them.
- auto expr = SDBMDiffExpr::get(var, stripe);
- EXPECT_EQ(expr.getLHS(), var);
- EXPECT_EQ(expr.getRHS(), stripe);
- auto expr2 = SDBMDiffExpr::get(stripe, var);
- EXPECT_EQ(expr2.getLHS(), stripe);
- EXPECT_EQ(expr2.getRHS(), var);
-
- // Sum expressions are trivially comparable.
- EXPECT_EQ(expr, SDBMDiffExpr::get(var, stripe));
-
- // Hierarchy is okay.
- auto generic = static_cast<SDBMExpr>(expr);
- EXPECT_TRUE(generic.isa<SDBMDiffExpr>());
- EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
-}
-
-TEST(SDBMExpr, AffineRoundTrip) {
- // Build an expression (s0 - s0 # 2)
- auto cst2 = SDBMConstantExpr::get(dialect(), 2);
- auto var = SDBMSymbolExpr::get(dialect(), 0);
- auto stripe = SDBMStripeExpr::get(var, cst2);
- auto expr = SDBMDiffExpr::get(var, stripe);
-
- // Check that it can be converted to AffineExpr and back, i.e. stripe
- // detection works correctly.
- Optional<SDBMExpr> roundtripped =
- SDBMExpr::tryConvertAffineExpr(expr.getAsAffineExpr());
- ASSERT_TRUE(roundtripped.hasValue());
- EXPECT_EQ(roundtripped, static_cast<SDBMExpr>(expr));
-
- // Check that (s0 # 2 # 5) can be converted to AffineExpr, i.e. stripe
- // detection supports nested expressions.
- auto cst5 = SDBMConstantExpr::get(dialect(), 5);
- auto outerStripe = SDBMStripeExpr::get(stripe, cst5);
- roundtripped = SDBMExpr::tryConvertAffineExpr(outerStripe.getAsAffineExpr());
- ASSERT_TRUE(roundtripped.hasValue());
- EXPECT_EQ(roundtripped, static_cast<SDBMExpr>(outerStripe));
-
- // Check that ((s0 + 2) # 5) can be round-tripped through AffineExpr, i.e.
- // stripe detection supports sum expressions.
- auto inner = SDBMSumExpr::get(var, cst2);
- auto stripeSum = SDBMStripeExpr::get(inner, cst5);
- roundtripped = SDBMExpr::tryConvertAffineExpr(stripeSum.getAsAffineExpr());
- ASSERT_TRUE(roundtripped.hasValue());
- EXPECT_EQ(roundtripped, static_cast<SDBMExpr>(stripeSum));
-
- // Check that (s0 # 2 # 5 - s0 # 2) + 2 can be converted as an example of a
- // deeper expression tree.
- auto sum = SDBMSumExpr::get(outerStripe, cst2);
- auto diff = SDBMDiffExpr::get(sum, stripe);
- roundtripped = SDBMExpr::tryConvertAffineExpr(diff.getAsAffineExpr());
- ASSERT_TRUE(roundtripped.hasValue());
- EXPECT_EQ(roundtripped, static_cast<SDBMExpr>(diff));
-
- // Check a nested stripe-sum combination.
- auto cst7 = SDBMConstantExpr::get(dialect(), 7);
- auto nestedStripe =
- SDBMStripeExpr::get(SDBMSumExpr::get(stripeSum, cst2), cst7);
- diff = SDBMDiffExpr::get(nestedStripe, stripe);
- roundtripped = SDBMExpr::tryConvertAffineExpr(diff.getAsAffineExpr());
- ASSERT_TRUE(roundtripped.hasValue());
- EXPECT_EQ(roundtripped, static_cast<SDBMExpr>(diff));
-}
-
-TEST(SDBMExpr, MatchStripeMulPattern) {
- // Make sure conversion from AffineExpr recognizes multiplicative stripe
- // pattern (x floordiv B) * B == x # B.
- auto cst = getAffineConstantExpr(42, ctx());
- auto dim = getAffineDimExpr(0, ctx());
- auto floor = dim.floorDiv(cst);
- auto mul = cst * floor;
- Optional<SDBMExpr> converted = SDBMStripeExpr::tryConvertAffineExpr(mul);
- ASSERT_TRUE(converted.hasValue());
- EXPECT_TRUE(converted->isa<SDBMStripeExpr>());
-}
-
-TEST(SDBMExpr, NonSDBM) {
- auto d0 = getAffineDimExpr(0, ctx());
- auto d1 = getAffineDimExpr(1, ctx());
- auto sum = d0 + d1;
- auto c2 = getAffineConstantExpr(2, ctx());
- auto prod = d0 * c2;
- auto ceildiv = d1.ceilDiv(c2);
-
- // The following are not valid SDBM expressions:
- // - a sum of two variables
- EXPECT_FALSE(SDBMExpr::tryConvertAffineExpr(sum).hasValue());
- // - a variable with coefficient other than 1 or -1
- EXPECT_FALSE(SDBMExpr::tryConvertAffineExpr(prod).hasValue());
- // - a ceildiv expression
- EXPECT_FALSE(SDBMExpr::tryConvertAffineExpr(ceildiv).hasValue());
-}
-
-} // end namespace
More information about the Mlir-commits
mailing list