[Mlir-commits] [mlir] mlir/Presburger: contribute a free-standing parser (PR #94916)
Ramkumar Ramachandra
llvmlistbot at llvm.org
Thu Jun 27 04:41:03 PDT 2024
https://github.com/artagnon updated https://github.com/llvm/llvm-project/pull/94916
>From 6a5471faa35df627c9845e8bd0ddbc50ec679777 Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Sun, 16 Jun 2024 14:36:44 +0100
Subject: [PATCH 1/5] Presburger/test: increase coverage of parser
In preparation for writing a free-standing parser for Presburger, improve
the test coverage of the existing parser.
---
.../Analysis/Presburger/ParserTest.cpp | 58 +++++++++++++++++++
1 file changed, 58 insertions(+)
diff --git a/mlir/unittests/Analysis/Presburger/ParserTest.cpp b/mlir/unittests/Analysis/Presburger/ParserTest.cpp
index 4c9f54f97d246..06b728cd1a8fa 100644
--- a/mlir/unittests/Analysis/Presburger/ParserTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/ParserTest.cpp
@@ -45,6 +45,18 @@ static bool parseAndCompare(StringRef str, const IntegerPolyhedron &ex) {
}
TEST(ParseFACTest, ParseAndCompareTest) {
+ // constant-fold addition
+ EXPECT_TRUE(parseAndCompare("() : (4 + 3 >= 0)",
+ makeFACFromConstraints(0, 0, {}, {})));
+
+ // constant-fold addition + multiplication
+ EXPECT_TRUE(parseAndCompare("()[a] : (4 * 3 == 10 + 2)",
+ makeFACFromConstraints(0, 1, {}, {})));
+
+ // constant-fold ceildiv + floordiv
+ EXPECT_TRUE(parseAndCompare("(x) : (11 ceildiv 3 == 13 floordiv 3)",
+ makeFACFromConstraints(1, 0, {}, {})));
+
// simple ineq
EXPECT_TRUE(parseAndCompare("(x)[] : (x >= 0)",
makeFACFromConstraints(1, 0, {{1, 0}})));
@@ -57,6 +69,11 @@ TEST(ParseFACTest, ParseAndCompareTest) {
EXPECT_TRUE(parseAndCompare("(x)[] : (7 * x >= 0, -7 * x + 5 >= 0)",
makeFACFromConstraints(1, 0, {{7, 0}, {-7, 5}})));
+ // multiplication distribution
+ EXPECT_TRUE(
+ parseAndCompare("(x) : (2 * x >= 2, (-7 + x * 9) * 5 >= 0)",
+ makeFACFromConstraints(1, 0, {{2, -2}, {45, -35}})));
+
// multiple dimensions
EXPECT_TRUE(parseAndCompare("(x,y,z)[] : (x + y - z >= 0)",
makeFACFromConstraints(3, 0, {{1, 1, -1, 0}})));
@@ -70,20 +87,61 @@ TEST(ParseFACTest, ParseAndCompareTest) {
EXPECT_TRUE(parseAndCompare("()[a] : (2 * a - 4 == 0)",
makeFACFromConstraints(0, 1, {}, {{2, -4}})));
+ // no linear terms
+ EXPECT_TRUE(parseAndCompare(
+ "(x, y) : (26 * (x floordiv 6) == y floordiv 3)",
+ makeFACFromConstraints(2, 0, {}, {{0, 0, 26, -1, 0}},
+ {{{1, 0, 0}, 6}, {{0, 1, 0, 0}, 3}})));
+
// simple floordiv
EXPECT_TRUE(parseAndCompare(
"(x, y) : (y - 3 * ((x + y - 13) floordiv 3) - 42 == 0)",
makeFACFromConstraints(2, 0, {}, {{0, 1, -3, -42}}, {{{1, 1, -13}, 3}})));
+ // simple ceildiv
+ EXPECT_TRUE(parseAndCompare(
+ "(x, y) : (y - 3 * ((x + y - 13) ceildiv 3) - 42 == 0)",
+ makeFACFromConstraints(2, 0, {}, {{0, 1, -3, -42}}, {{{1, 1, -11}, 3}})));
+
+ // simple mod
+ EXPECT_TRUE(parseAndCompare(
+ "(x, y) : (y - 3 * ((x + y - 13) mod 3) - 42 == 0)",
+ makeFACFromConstraints(2, 0, {}, {{-3, -2, 9, -3}}, {{{1, 1, -13}, 3}})));
+
// multiple floordiv
EXPECT_TRUE(parseAndCompare(
"(x, y) : (y - x floordiv 3 - y floordiv 2 == 0)",
makeFACFromConstraints(2, 0, {}, {{0, 1, -1, -1, 0}},
{{{1, 0, 0}, 3}, {{0, 1, 0, 0}, 2}})));
+ // multiple ceildiv
+ EXPECT_TRUE(parseAndCompare(
+ "(x, y) : (y - x ceildiv 3 - y ceildiv 2 == 0)",
+ makeFACFromConstraints(2, 0, {}, {{0, 1, -1, -1, 0}},
+ {{{1, 0, 2}, 3}, {{0, 1, 0, 1}, 2}})));
+
+ // multiple mod
+ EXPECT_TRUE(parseAndCompare(
+ "(x, y) : (y - x mod 3 - y mod 2 == 0)",
+ makeFACFromConstraints(2, 0, {}, {{-1, 0, 3, 2, 0}},
+ {{{1, 0, 0}, 3}, {{0, 1, 0, 0}, 2}})));
+
// nested floordiv
EXPECT_TRUE(parseAndCompare(
"(x, y) : (y - (x + y floordiv 2) floordiv 3 == 0)",
makeFACFromConstraints(2, 0, {}, {{0, 1, 0, -1, 0}},
{{{0, 1, 0}, 2}, {{1, 0, 1, 0}, 3}})));
+
+ // nested mod
+ EXPECT_TRUE(parseAndCompare(
+ "(x, y) : (y - (x + y mod 2) mod 3 == 0)",
+ makeFACFromConstraints(2, 0, {}, {{-1, 0, 2, 3, 0}},
+ {{{0, 1, 0}, 2}, {{1, 1, -2, 0}, 3}})));
+
+ // nested floordiv + ceildiv + mod
+ EXPECT_TRUE(parseAndCompare(
+ "(x, y) : ((2 * x + 3 * (y floordiv 2) + x mod 7 + 1) ceildiv 3 == 42)",
+ makeFACFromConstraints(
+ 2, 0, {}, {{0, 0, 0, 0, 1, -42}},
+ {{{0, 1, 0}, 2}, {{1, 0, 0, 0}, 7}, {{3, 0, 3, -7, 3}, 3}})));
}
>From 4455468aaf8f9ba4cb78703296245e58042d65c2 Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Thu, 6 Jun 2024 18:40:28 +0100
Subject: [PATCH 2/5] mlir/Presburger: contribute a free-standing parser
The Presburger library is already largely independent of MLIR, depending
only on the MLIR Support library. There is, however, one other major
dependency: the test suite for the library depends on the core MLIR
parser. To free it of this dependency, extract the parts of the core
MLIR parser that are applicable to the Presburger test suite, author
custom parsing data structures, and adapt the new parser to parse into
these structures.
This patch is part of a project to move the Presburger library into
LLVM.
---
.../mlir}/Analysis/Presburger/Parser.h | 53 +-
mlir/lib/Analysis/Presburger/CMakeLists.txt | 2 +
.../Analysis/Presburger/Parser/CMakeLists.txt | 6 +
.../Analysis/Presburger/Parser/Flattener.cpp | 408 +++++++++
.../Analysis/Presburger/Parser/Flattener.h | 244 ++++++
mlir/lib/Analysis/Presburger/Parser/Lexer.cpp | 161 ++++
mlir/lib/Analysis/Presburger/Parser/Lexer.h | 58 ++
.../Presburger/Parser/ParseStructs.cpp | 268 ++++++
.../Analysis/Presburger/Parser/ParseStructs.h | 293 +++++++
.../Analysis/Presburger/Parser/ParserImpl.cpp | 822 ++++++++++++++++++
.../Analysis/Presburger/Parser/ParserImpl.h | 237 +++++
.../Analysis/Presburger/Parser/ParserState.h | 39 +
mlir/lib/Analysis/Presburger/Parser/Token.cpp | 64 ++
mlir/lib/Analysis/Presburger/Parser/Token.h | 91 ++
.../Analysis/Presburger/Parser/TokenKinds.def | 73 ++
.../Analysis/Presburger/BarvinokTest.cpp | 4 +-
.../Analysis/Presburger/CMakeLists.txt | 4 +-
.../Analysis/Presburger/FractionTest.cpp | 1 -
.../Presburger/GeneratingFunctionTest.cpp | 2 +-
.../Presburger/IntegerPolyhedronTest.cpp | 2 +-
.../Presburger/IntegerRelationTest.cpp | 2 +-
.../Analysis/Presburger/MatrixTest.cpp | 3 +-
.../Analysis/Presburger/PWMAFunctionTest.cpp | 5 +-
.../Analysis/Presburger/ParserTest.cpp | 2 +-
.../Presburger/PresburgerRelationTest.cpp | 2 +-
.../Analysis/Presburger/PresburgerSetTest.cpp | 3 +-
.../Presburger/QuasiPolynomialTest.cpp | 4 +-
.../Analysis/Presburger/SimplexTest.cpp | 5 +-
mlir/unittests/Analysis/Presburger/Utils.h | 5 -
29 files changed, 2795 insertions(+), 68 deletions(-)
rename mlir/{unittests => include/mlir}/Analysis/Presburger/Parser.h (65%)
create mode 100644 mlir/lib/Analysis/Presburger/Parser/CMakeLists.txt
create mode 100644 mlir/lib/Analysis/Presburger/Parser/Flattener.cpp
create mode 100644 mlir/lib/Analysis/Presburger/Parser/Flattener.h
create mode 100644 mlir/lib/Analysis/Presburger/Parser/Lexer.cpp
create mode 100644 mlir/lib/Analysis/Presburger/Parser/Lexer.h
create mode 100644 mlir/lib/Analysis/Presburger/Parser/ParseStructs.cpp
create mode 100644 mlir/lib/Analysis/Presburger/Parser/ParseStructs.h
create mode 100644 mlir/lib/Analysis/Presburger/Parser/ParserImpl.cpp
create mode 100644 mlir/lib/Analysis/Presburger/Parser/ParserImpl.h
create mode 100644 mlir/lib/Analysis/Presburger/Parser/ParserState.h
create mode 100644 mlir/lib/Analysis/Presburger/Parser/Token.cpp
create mode 100644 mlir/lib/Analysis/Presburger/Parser/Token.h
create mode 100644 mlir/lib/Analysis/Presburger/Parser/TokenKinds.def
diff --git a/mlir/unittests/Analysis/Presburger/Parser.h b/mlir/include/mlir/Analysis/Presburger/Parser.h
similarity index 65%
rename from mlir/unittests/Analysis/Presburger/Parser.h
rename to mlir/include/mlir/Analysis/Presburger/Parser.h
index 75842fb054e2b..dcc99de4e99fd 100644
--- a/mlir/unittests/Analysis/Presburger/Parser.h
+++ b/mlir/include/mlir/Analysis/Presburger/Parser.h
@@ -11,32 +11,24 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_UNITTESTS_ANALYSIS_PRESBURGER_PARSER_H
-#define MLIR_UNITTESTS_ANALYSIS_PRESBURGER_PARSER_H
+#ifndef MLIR_ANALYSIS_PRESBURGER_PARSER_H
+#define MLIR_ANALYSIS_PRESBURGER_PARSER_H
#include "mlir/Analysis/Presburger/IntegerRelation.h"
#include "mlir/Analysis/Presburger/PWMAFunction.h"
#include "mlir/Analysis/Presburger/PresburgerRelation.h"
-#include "mlir/AsmParser/AsmParser.h"
-#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/IntegerSet.h"
-
-namespace mlir {
-namespace presburger {
-
-/// Parses an IntegerPolyhedron from a StringRef. It is expected that the string
-/// represents a valid IntegerSet.
-inline IntegerPolyhedron parseIntegerPolyhedron(StringRef str) {
- MLIRContext context(MLIRContext::Threading::DISABLED);
- return affine::FlatAffineValueConstraints(parseIntegerSet(str, &context));
-}
+
+namespace mlir::presburger {
+/// Parses an IntegerPolyhedron from a StringRef.
+IntegerPolyhedron parseIntegerPolyhedron(StringRef str);
+
+/// Parses a MultiAffineFunction from a StringRef.
+MultiAffineFunction parseMultiAffineFunction(StringRef str);
/// Parse a list of StringRefs to IntegerRelation and combine them into a
-/// PresburgerSet by using the union operation. It is expected that the strings
-/// are all valid IntegerSet representation and that all of them have compatible
-/// spaces.
+/// PresburgerSet by using the union operation. It is expected that the
+/// strings are all valid IntegerSet representation and that all of them have
+/// compatible spaces.
inline PresburgerSet parsePresburgerSet(ArrayRef<StringRef> strs) {
assert(!strs.empty() && "strs should not be empty");
@@ -47,25 +39,10 @@ inline PresburgerSet parsePresburgerSet(ArrayRef<StringRef> strs) {
return result;
}
-inline MultiAffineFunction parseMultiAffineFunction(StringRef str) {
- MLIRContext context(MLIRContext::Threading::DISABLED);
-
- // TODO: Add default constructor for MultiAffineFunction.
- MultiAffineFunction multiAff(PresburgerSpace::getRelationSpace(),
- IntMatrix(0, 1));
- if (getMultiAffineFunctionFromMap(parseAffineMap(str, &context), multiAff)
- .failed())
- llvm_unreachable(
- "Failed to parse MultiAffineFunction because of semi-affinity");
- return multiAff;
-}
-
inline PWMAFunction
parsePWMAF(ArrayRef<std::pair<StringRef, StringRef>> pieces) {
assert(!pieces.empty() && "At least one piece should be present.");
- MLIRContext context(MLIRContext::Threading::DISABLED);
-
IntegerPolyhedron initDomain = parseIntegerPolyhedron(pieces[0].first);
MultiAffineFunction initMultiAff = parseMultiAffineFunction(pieces[0].second);
@@ -100,8 +77,6 @@ parsePresburgerRelationFromPresburgerSet(ArrayRef<StringRef> strs,
result.convertVarKind(VarKind::SetDim, 0, numDomain, VarKind::Domain, 0);
return result;
}
+} // namespace mlir::presburger
-} // namespace presburger
-} // namespace mlir
-
-#endif // MLIR_UNITTESTS_ANALYSIS_PRESBURGER_PARSER_H
+#endif // MLIR_ANALYSIS_PRESBURGER_PARSER_H
diff --git a/mlir/lib/Analysis/Presburger/CMakeLists.txt b/mlir/lib/Analysis/Presburger/CMakeLists.txt
index 1d30dd38ccd1b..aadd16305ac0b 100644
--- a/mlir/lib/Analysis/Presburger/CMakeLists.txt
+++ b/mlir/lib/Analysis/Presburger/CMakeLists.txt
@@ -1,3 +1,5 @@
+add_subdirectory(Parser)
+
add_mlir_library(MLIRPresburger
Barvinok.cpp
IntegerRelation.cpp
diff --git a/mlir/lib/Analysis/Presburger/Parser/CMakeLists.txt b/mlir/lib/Analysis/Presburger/Parser/CMakeLists.txt
new file mode 100644
index 0000000000000..f708a5c8db949
--- /dev/null
+++ b/mlir/lib/Analysis/Presburger/Parser/CMakeLists.txt
@@ -0,0 +1,6 @@
+add_mlir_library(MLIRPresburgerParser
+ Flattener.cpp
+ Lexer.cpp
+ ParserImpl.cpp
+ ParseStructs.cpp
+ Token.cpp)
diff --git a/mlir/lib/Analysis/Presburger/Parser/Flattener.cpp b/mlir/lib/Analysis/Presburger/Parser/Flattener.cpp
new file mode 100644
index 0000000000000..4ebcf6c676672
--- /dev/null
+++ b/mlir/lib/Analysis/Presburger/Parser/Flattener.cpp
@@ -0,0 +1,408 @@
+//===- Flattener.cpp - Presburger ParseStruct Flattener ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the Flattener class for flattening the parse tree
+// produced by the parser for the Presburger library.
+//
+//===----------------------------------------------------------------------===//
+
+#include "Flattener.h"
+#include "llvm/ADT/SmallVector.h"
+
+using namespace mlir;
+using namespace presburger;
+using llvm::SmallVector;
+
+// Rebuilds a tree-form AffineExpr from the flattened coefficient vector
+// `flatExprs`, whose entries are ordered as: dims, symbols, locals, and finally
+// the constant term. Zero coefficients are skipped; a dim/symbol with a
+// non-zero coefficient contributes `d_j * coeff` / `s_j * coeff`, and a local
+// contributes `localExprs[k] * coeff`.
+//
+// NOTE(review): local terms are produced by std::move'ing entries out of
+// `localExprs`, leaving them moved-from. This function is called from
+// visitMulExpr, visitModExpr and visitDivExpr, and `localExprs` is searched
+// afterwards by findLocalId, so the "no expr is used more than once"
+// assumption below looks fragile for nested div/mod expressions -- confirm,
+// or deep-copy the local instead of moving it.
+AffineExpr AffineExprFlattener::getAffineExprFromFlatForm(
+ ArrayRef<int64_t> flatExprs, unsigned numDims, unsigned numSymbols) {
+ assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
+ "unexpected number of local expressions");
+
+ // Dimensions and symbols.
+ // Start from the constant 0 and add one term per non-zero coefficient.
+ AffineExpr expr = std::make_unique<AffineConstantExpr>(0);
+ for (unsigned j = 0; j < getLocalVarStartIndex(); ++j) {
+ if (flatExprs[j] == 0)
+ continue;
+ // Columns [0, numDims) are dims; the remaining columns before the locals
+ // are symbols, re-based to start at 0.
+ if (j < numDims)
+ expr =
+ std::move(expr) + std::make_unique<AffineDimExpr>(j) * flatExprs[j];
+ else
+ expr = std::move(expr) +
+ std::make_unique<AffineSymbolExpr>(j - numDims) * flatExprs[j];
+ }
+
+ // Local identifiers.
+ // The last entry is the constant term, so locals stop one short of the end.
+ for (unsigned j = getLocalVarStartIndex(); j < flatExprs.size() - 1; ++j) {
+ if (flatExprs[j] == 0)
+ continue;
+ // It is safe to move out of the localExprs vector, since no expr is used
+ // more than once.
+ AffineExpr term =
+ std::move(localExprs[j - getLocalVarStartIndex()]) * flatExprs[j];
+ expr = std::move(expr) + std::move(term);
+ }
+
+ // Constant term.
+ int64_t constTerm = flatExprs[flatExprs.size() - 1];
+ if (constTerm != 0)
+ return std::move(expr) + constTerm;
+ return expr;
+}
+
+// Flattens `lhs * rhs`.
+//
+// Pure affine case (constant RHS): every coefficient of the flattened LHS is
+// scaled by that constant.
+//
+// Semi-affine case (non-constant RHS): the product is not linear, so a fresh
+// local variable p (= lhs * rhs) stands in for it and the tree-form product is
+// recorded in `localExprs`.
+LogicalResult AffineExprFlattener::visitMulExpr(const AffineBinOpExpr &expr) {
+  assert(operandExprStack.size() >= 2);
+  // Take a copy of the RHS before popping it; the LHS stays on the stack and
+  // becomes the result.
+  SmallVector<int64_t, 8> rhsFlat = operandExprStack.back();
+  operandExprStack.pop_back();
+  SmallVector<int64_t, 8> &lhsFlat = operandExprStack.back();
+
+  const bool rhsIsConstant = isa<AffineConstantExpr>(expr.getRHS());
+  if (rhsIsConstant) {
+    const int64_t factor = rhsFlat[getConstantIndex()];
+    for (int64_t &coeff : lhsFlat)
+      coeff *= factor;
+    return success();
+  }
+
+  AffineExpr lhsTree = getAffineExprFromFlatForm(lhsFlat, numDims, numSymbols);
+  AffineExpr rhsTree = getAffineExprFromFlatForm(rhsFlat, numDims, numSymbols);
+  addLocalVariableSemiAffine(std::move(lhsTree) * std::move(rhsTree), lhsFlat,
+                             lhsFlat.size());
+  return success();
+}
+
+// Flattens `lhs + rhs`: the RHS coefficients are accumulated into the LHS
+// entry, which stays on the stack as the result, and the RHS is discarded.
+LogicalResult AffineExprFlattener::visitAddExpr(const AffineBinOpExpr &expr) {
+  assert(operandExprStack.size() >= 2);
+  const size_t top = operandExprStack.size() - 1;
+  auto &sum = operandExprStack[top - 1];
+  const auto &addend = operandExprStack[top];
+  assert(sum.size() == addend.size());
+  for (size_t idx = 0, e = addend.size(); idx != e; ++idx)
+    sum[idx] += addend[idx];
+  operandExprStack.pop_back();
+  return success();
+}
+
+//
+// t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1
+//
+// A mod expression "expr mod c" is thus flattened by introducing a new local
+// variable q (= expr floordiv c), such that expr mod c is replaced with
+// 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
+//
+// In case of semi-affine modulo expressions, t = expr mod symbolic_expr,
+// introduce a local variable m (= expr mod symbolic_expr), and the affine
+// expression expr mod symbolic_expr is added to `localExprs`.
+//
+// Returns failure if the RHS is a constant <= 0.
+LogicalResult AffineExprFlattener::visitModExpr(const AffineBinOpExpr &expr) {
+ assert(operandExprStack.size() >= 2);
+
+ // Copy the RHS out before popping it; `lhs` refers to the new stack top,
+ // which is overwritten in place with the flattened result.
+ SmallVector<int64_t, 8> rhs = operandExprStack.back();
+ operandExprStack.pop_back();
+ SmallVector<int64_t, 8> &lhs = operandExprStack.back();
+
+ // Flatten semi affine modulo expressions by introducing a local
+ // variable in place of the modulo value, and the affine expression
+ // corresponding to the quantifier is added to `localExprs`.
+ if (!isa<AffineConstantExpr>(expr.getRHS())) {
+ AffineExpr dividendExpr =
+ getAffineExprFromFlatForm(lhs, numDims, numSymbols);
+ AffineExpr divisorExpr =
+ getAffineExprFromFlatForm(rhs, numDims, numSymbols);
+ AffineExpr modExpr = std::move(dividendExpr) % std::move(divisorExpr);
+ addLocalVariableSemiAffine(std::move(modExpr), lhs, lhs.size());
+ return success();
+ }
+
+ // A non-positive modulus cannot be flattened.
+ int64_t rhsConst = rhs[getConstantIndex()];
+ if (rhsConst <= 0)
+ return failure();
+
+ // Check if the LHS expression is a multiple of modulo factor.
+ unsigned i;
+ for (i = 0; i < lhs.size(); ++i)
+ if (lhs[i] % rhsConst != 0)
+ break;
+ // If yes, modulo expression here simplifies to zero.
+ if (i == lhs.size()) {
+ std::fill(lhs.begin(), lhs.end(), 0);
+ return success();
+ }
+
+ // Add a local variable for the quotient, i.e., expr % c is replaced by
+ // (expr - q * c) where q = expr floordiv c. Do this while canceling out
+ // the GCD of expr and c.
+ SmallVector<int64_t, 8> floorDividend(lhs);
+ uint64_t gcd = rhsConst;
+ // NOTE(review): std::abs(INT64_MIN) is undefined behaviour; this assumes
+ // coefficients stay within range.
+ for (int64_t lhsElt : lhs)
+ gcd = std::gcd(gcd, (uint64_t)std::abs(lhsElt));
+ // Simplify the numerator and the denominator.
+ if (gcd != 1) {
+ for (int64_t &floorDividendElt : floorDividend)
+ floorDividendElt = floorDividendElt / static_cast<int64_t>(gcd);
+ }
+ int64_t floorDivisor = rhsConst / static_cast<int64_t>(gcd);
+
+ // Construct the AffineExpr form of the floordiv to store in localExprs.
+
+ AffineExpr dividendExpr =
+ getAffineExprFromFlatForm(floorDividend, numDims, numSymbols);
+ AffineExpr divisorExpr = std::make_unique<AffineConstantExpr>(floorDivisor);
+ AffineExpr floorDivExpr =
+ floorDiv(std::move(dividendExpr), std::move(divisorExpr));
+ // Reuse an existing local when an equal floordiv is already tracked.
+ int loc;
+ if ((loc = findLocalId(floorDivExpr)) == -1) {
+ addLocalFloorDivId(floorDividend, floorDivisor, std::move(floorDivExpr));
+ // addLocalFloorDivId inserted a zero column for q into every stacked
+ // operand, lhs included, so q is now the last local column. lhs keeps its
+ // own coefficients and gains -rhsConst * q.
+ // Set result at top of stack to "lhs - rhsConst * q".
+ lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst;
+ } else {
+ // Reuse the existing local id.
+ lhs[getLocalVarStartIndex() + loc] = -rhsConst;
+ }
+ return success();
+}
+
+// Floor and ceiling division share one implementation (visitDivExpr); only
+// the rounding direction passed along differs.
+LogicalResult
+AffineExprFlattener::visitFloorDivExpr(const AffineBinOpExpr &expr) {
+  const bool roundUp = false;
+  return visitDivExpr(expr, roundUp);
+}
+LogicalResult
+AffineExprFlattener::visitCeilDivExpr(const AffineBinOpExpr &expr) {
+  const bool roundUp = true;
+  return visitDivExpr(expr, roundUp);
+}
+
+// A dimension d_i flattens to a vector that is zero except for a 1 in d_i's
+// column.
+LogicalResult AffineExprFlattener::visitDimExpr(const AffineDimExpr &expr) {
+  const auto pos = expr.getPosition();
+  assert(pos < numDims && "Inconsistent number of dims");
+  SmallVector<int64_t, 32> unit(getNumCols(), 0);
+  unit[getDimStartIndex() + pos] = 1;
+  operandExprStack.emplace_back(std::move(unit));
+  return success();
+}
+
+// A symbol s_i flattens to a vector that is zero except for a 1 in s_i's
+// column.
+LogicalResult
+AffineExprFlattener::visitSymbolExpr(const AffineSymbolExpr &expr) {
+  const auto pos = expr.getPosition();
+  assert(pos < numSymbols && "inconsistent number of symbols");
+  SmallVector<int64_t, 32> unit(getNumCols(), 0);
+  unit[getSymbolStartIndex() + pos] = 1;
+  operandExprStack.emplace_back(std::move(unit));
+  return success();
+}
+
+// A constant c flattens to a vector that is zero except for the constant
+// column, which holds c.
+LogicalResult
+AffineExprFlattener::visitConstantExpr(const AffineConstantExpr &expr) {
+  SmallVector<int64_t, 32> flat(getNumCols(), 0);
+  flat[getConstantIndex()] = expr.getValue();
+  operandExprStack.emplace_back(std::move(flat));
+  return success();
+}
+
+// Overwrites `result` so that it refers solely to the local variable that
+// stands for the semi-affine `expr`. A new local is created via
+// addLocalIdSemiAffine only when findLocalId does not already know `expr`.
+void AffineExprFlattener::addLocalVariableSemiAffine(
+    AffineExpr &&expr, SmallVectorImpl<int64_t> &result,
+    unsigned long resultSize) {
+  assert(result.size() == resultSize && "result vector size mismatch");
+  const int existing = findLocalId(expr);
+  unsigned column;
+  if (existing == -1) {
+    // The new local is appended after all existing locals. The insertion also
+    // widens `result` when it is one of the stacked operands.
+    addLocalIdSemiAffine(std::move(expr));
+    column = getLocalVarStartIndex() + numLocals - 1;
+  } else {
+    column = getLocalVarStartIndex() + existing;
+  }
+  std::fill(result.begin(), result.end(), 0);
+  result[column] = 1;
+}
+
+// t = expr floordiv c <=> t = q, c * q <= expr <= c * q + c - 1
+// A floordiv is thus flattened by introducing a new local variable q, and
+// replacing that expression with 'q' while adding the constraints
+// c * q <= expr <= c * q + c - 1 to localVarCst (done by
+// IntegerRelation::addLocalFloorDiv).
+//
+// A ceildiv is similarly flattened:
+// t = expr ceildiv c <=> t = (expr + c - 1) floordiv c
+//
+// In case of semi affine division expressions, t = expr floordiv symbolic_expr
+// or t = expr ceildiv symbolic_expr, introduce a local variable q (= expr
+// floordiv/ceildiv symbolic_expr), and the affine floordiv/ceildiv is added to
+// `localExprs`.
+//
+// Returns failure if the RHS is a constant <= 0. An existing local is reused
+// when findLocalId matches the (gcd-simplified) division expression.
+LogicalResult AffineExprFlattener::visitDivExpr(const AffineBinOpExpr &expr,
+ bool isCeil) {
+ assert(operandExprStack.size() >= 2);
+
+ // Copy the RHS out before popping it; `lhs` refers to the new stack top,
+ // which is overwritten in place with the flattened result.
+ SmallVector<int64_t, 8> rhs = operandExprStack.back();
+ operandExprStack.pop_back();
+ SmallVector<int64_t, 8> &lhs = operandExprStack.back();
+
+ // Flatten semi affine division expressions by introducing a local
+ // variable in place of the quotient, and the affine expression corresponding
+ // to the quantifier is added to `localExprs`.
+ if (!isa<AffineConstantExpr>(expr.getRHS())) {
+ AffineExpr a = getAffineExprFromFlatForm(lhs, numDims, numSymbols);
+ AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols);
+ AffineExpr divExpr = isCeil ? ceilDiv(std::move(a), std::move(b))
+ : floorDiv(std::move(a), std::move(b));
+ addLocalVariableSemiAffine(std::move(divExpr), lhs, lhs.size());
+ return success();
+ }
+
+ // This is a pure affine expr; the RHS is a positive constant.
+ // A non-positive constant divisor cannot be flattened.
+ int64_t rhsConst = rhs[getConstantIndex()];
+ if (rhsConst <= 0)
+ return failure();
+
+ // Simplify the floordiv, ceildiv if possible by canceling out the greatest
+ // common divisors of the numerator and denominator.
+ // NOTE(review): std::abs(INT64_MIN) is undefined behaviour; this assumes
+ // coefficients stay within range.
+ uint64_t gcd = std::abs(rhsConst);
+ for (int64_t lhsElt : lhs)
+ gcd = std::gcd(gcd, (uint64_t)std::abs(lhsElt));
+ // Simplify the numerator and the denominator.
+ if (gcd != 1) {
+ for (int64_t &lhsElt : lhs)
+ lhsElt = lhsElt / static_cast<int64_t>(gcd);
+ }
+ int64_t divisor = rhsConst / static_cast<int64_t>(gcd);
+ // If the divisor becomes 1, the updated LHS is the result. (The
+ // divisor can't be negative since rhsConst is positive).
+ if (divisor == 1)
+ return success();
+
+ // If the divisor cannot be simplified to one, we will have to retain
+ // the ceil/floor expr (simplified up until here). Add an existential
+ // quantifier to express its result, i.e., expr1 div expr2 is replaced
+ // by a new identifier, q.
+ AffineExpr a = getAffineExprFromFlatForm(lhs, numDims, numSymbols);
+ AffineExpr b = std::make_unique<AffineConstantExpr>(divisor);
+
+ int loc;
+ AffineExpr divExpr = isCeil ? ceilDiv(std::move(a), std::move(b))
+ : floorDiv(std::move(a), std::move(b));
+ // Only register a new local when an equal division is not already tracked;
+ // otherwise `loc` indexes the existing one.
+ if ((loc = findLocalId(divExpr)) == -1) {
+ if (!isCeil) {
+ SmallVector<int64_t, 8> dividend(lhs);
+ addLocalFloorDivId(dividend, divisor, std::move(divExpr));
+ } else {
+ // lhs ceildiv c <=> (lhs + c - 1) floordiv c
+ // The last entry of the flattened form is the constant term.
+ SmallVector<int64_t, 8> dividend(lhs);
+ dividend.back() += divisor - 1;
+ addLocalFloorDivId(dividend, divisor, std::move(divExpr));
+ }
+ }
+ // Set the expression on stack to the local var introduced to capture the
+ // result of the division (floor or ceil).
+ std::fill(lhs.begin(), lhs.end(), 0);
+ if (loc == -1)
+ lhs[getLocalVarStartIndex() + numLocals - 1] = 1;
+ else
+ lhs[getLocalVarStartIndex() + loc] = 1;
+ return success();
+}
+
+// Registers a new local q = dividend floordiv divisor. A zero column for q is
+// inserted into every pending flattened operand and `localExpr` is recorded as
+// q's tree form. The floor-division bounds on q are then added to localVarCst.
+void AffineExprFlattener::addLocalFloorDivId(ArrayRef<int64_t> dividend,
+                                             int64_t divisor,
+                                             AffineExpr &&localExpr) {
+  assert(divisor > 0 && "positive constant divisor expected");
+  const auto insertPos = getLocalVarStartIndex() + numLocals;
+  for (SmallVector<int64_t, 8> &operand : operandExprStack)
+    operand.insert(operand.begin() + insertPos, 0);
+  localExprs.emplace_back(std::move(localExpr));
+  ++numLocals;
+  localVarCst.addLocalFloorDiv(dividend, divisor);
+}
+
+// Registers a new semi-affine local. A zero column for it is inserted into
+// every pending flattened operand and `localExpr` is recorded as its tree
+// form. Unlike addLocalFloorDivId, no constraints are added to localVarCst.
+void AffineExprFlattener::addLocalIdSemiAffine(AffineExpr &&localExpr) {
+  const auto insertPos = getLocalVarStartIndex() + numLocals;
+  for (SmallVector<int64_t, 8> &operand : operandExprStack)
+    operand.insert(operand.begin() + insertPos, 0);
+  localExprs.emplace_back(std::move(localExpr));
+  ++numLocals;
+}
+
+// Returns the position of `localExpr` within `localExprs` (compared with
+// AffineExpr's equality), or -1 if no match has been registered yet.
+int AffineExprFlattener::findLocalId(const AffineExpr &localExpr) {
+  const auto match = llvm::find(localExprs, localExpr);
+  if (match != localExprs.end())
+    return match - localExprs.begin();
+  return -1;
+}
+
+// Creates a flattener for expressions over `numDims` dimensions and
+// `numSymbols` symbols. It starts with no locals, and localVarCst is
+// constructed over the set space with those dims and symbols.
+AffineExprFlattener::AffineExprFlattener(unsigned numDims, unsigned numSymbols)
+ : numDims(numDims), numSymbols(numSymbols), numLocals(0),
+ localVarCst(PresburgerSpace::getSetSpace(numDims, numSymbols)) {
+ // Pre-size the operand stack to avoid early reallocations.
+ operandExprStack.reserve(8);
+}
+
+// Flattens each expression in `exprs` (over `numDims` dims and `numSymbols`
+// symbols) into `flattenedExprs`, one coefficient vector per expression. A
+// single AffineExprFlattener is used for all of them, so local (div/mod)
+// variables are shared across expressions. The constraints defining those
+// locals are written to `localVarCst`.
+//
+// Returns failure if an expression cannot be flattened. This happens when it
+// is semi-affine (rejected up front, so the flattener's semi-affine paths are
+// not reached from here). It also happens when it contains a division or
+// modulo by a non-positive constant (e.g. a division by zero). On failure the
+// outputs are left unchanged.
+static LogicalResult
+getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims,
+                        unsigned numSymbols,
+                        std::vector<SmallVector<int64_t, 8>> &flattenedExprs,
+                        IntegerPolyhedron &localVarCst) {
+  if (exprs.empty()) {
+    // Nothing to flatten. Report zero expressions rather than leaving whatever
+    // the caller's vector held, and an empty constraint system.
+    flattenedExprs.clear();
+    localVarCst = IntegerPolyhedron(
+        0, 0, numDims + numSymbols + 1,
+        presburger::PresburgerSpace::getSetSpace(numDims, numSymbols, 0));
+    return success();
+  }
+
+  AffineExprFlattener flattener(numDims, numSymbols);
+  for (const AffineExpr &expr : exprs) {
+    if (!expr->isPureAffine())
+      return failure();
+    // The walk fails on expressions the flattener cannot represent, e.g. a
+    // division or modulo by a non-positive constant.
+    if (failed(flattener.walkPostOrder(*expr)))
+      return failure();
+  }
+
+  // Each successfully walked expression leaves exactly one flattened operand
+  // on the stack.
+  assert(flattener.operandExprStack.size() == exprs.size());
+  flattenedExprs.assign(flattener.operandExprStack.begin(),
+                        flattener.operandExprStack.end());
+  localVarCst.clearAndCopyFrom(flattener.localVarCst);
+  return success();
+}
+
+namespace mlir::presburger {
+/// Flattens every result expression of `map` into `flattenedExprs`, one
+/// coefficient vector per result. The constraints defining any local (div/mod)
+/// variables introduced along the way are written to `cst`. A map with no
+/// results yields no flattened expressions and an empty constraint system.
+LogicalResult
+getFlattenedAffineExprs(const AffineMap &map,
+                        std::vector<SmallVector<int64_t, 8>> &flattenedExprs,
+                        IntegerPolyhedron &cst) {
+  if (map.getNumExprs() == 0) {
+    // Do not leave stale entries from a previous use of the output vector.
+    flattenedExprs.clear();
+    cst = IntegerPolyhedron(0, 0, map.getNumDims() + map.getNumSymbols() + 1,
+                            presburger::PresburgerSpace::getSetSpace(
+                                map.getNumDims(), map.getNumSymbols(), 0));
+    return success();
+  }
+  return ::getFlattenedAffineExprs(map.getExprs(), map.getNumDims(),
+                                   map.getNumSymbols(), flattenedExprs, cst);
+}
+
+/// Flattens every constraint expression of `set` into `flattenedExprs`, one
+/// coefficient vector per constraint. The constraints defining any local
+/// (div/mod) variables introduced along the way are written to `cst`. A set
+/// with no constraints yields no flattened expressions and an empty
+/// constraint system.
+LogicalResult
+getFlattenedAffineExprs(const IntegerSet &set,
+                        std::vector<SmallVector<int64_t, 8>> &flattenedExprs,
+                        IntegerPolyhedron &cst) {
+  if (set.getNumConstraints() == 0) {
+    // Do not leave stale entries from a previous use of the output vector.
+    flattenedExprs.clear();
+    cst = IntegerPolyhedron(0, 0, set.getNumDims() + set.getNumSymbols() + 1,
+                            presburger::PresburgerSpace::getSetSpace(
+                                set.getNumDims(), set.getNumSymbols(), 0));
+    return success();
+  }
+  return ::getFlattenedAffineExprs(set.getConstraints(), set.getNumDims(),
+                                   set.getNumSymbols(), flattenedExprs, cst);
+}
+} // namespace mlir::presburger
diff --git a/mlir/lib/Analysis/Presburger/Parser/Flattener.h b/mlir/lib/Analysis/Presburger/Parser/Flattener.h
new file mode 100644
index 0000000000000..abbdc002c5545
--- /dev/null
+++ b/mlir/lib/Analysis/Presburger/Parser/Flattener.h
@@ -0,0 +1,244 @@
+//===- Flattener.h - Presburger ParseStruct Flattener -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_PRESBURGER_PARSER_FLATTENER_H
+#define MLIR_ANALYSIS_PRESBURGER_PARSER_FLATTENER_H
+
+#include "ParseStructs.h"
+#include "mlir/Analysis/Presburger/IntegerRelation.h"
+
+namespace mlir::presburger {
+// This class is used to flatten a pure affine expression (AffineExpr,
+// which is in a tree form) into a sum of products (w.r.t. constants) when
+// possible, and in that process simplifying the expression. For a modulo,
+// floordiv, or a ceildiv expression, an additional identifier, called a local
+// identifier, is introduced to rewrite the expression as a sum of product
+// affine expression. Each local identifier is always and by construction a
+// floordiv of a pure add/mul affine function of dimensional, symbolic, and
+// other local identifiers, in a non-mutually recursive way. Hence, every local
+// identifier can ultimately always be recovered as an affine function of
+// dimensional and symbolic identifiers (involving floordiv's); note however
+// that by AffineExpr construction, some floordiv combinations are converted to
+// mod's. The result of the flattening is a flattened expression and a set of
+// constraints involving just the local variables.
+//
+// d2 + (d0 + d1) floordiv 4 is flattened to d2 + q where 'q' is the local
+// variable introduced, with localVarCst containing 4*q <= d0 + d1 <= 4*q + 3.
+//
+// The simplification performed includes the accumulation of contributions for
+// each dimensional and symbolic identifier together, the simplification of
+// floordiv/ceildiv/mod expressions and other simplifications that in turn
+// happen as a result. A simplification that this flattening naturally performs
+// is of simplifying the numerator and denominator of floordiv/ceildiv, and
+// folding a modulo expression to a zero, if possible. Three examples are below:
+//
+// (((d0 + 3 * d1) + d0) - 2 * d1) - d0 simplified to d0 + d1
+// (d0 - d0 mod 4 + 4) mod 4 simplified to 0
+// (3*d0 + 2*d1 + d0) floordiv 2 + d1 simplified to 2*d0 + 2*d1
+//
+// The way the flattening works for the second example is as follows: d0 % 4 is
+// replaced by d0 - 4*q with q being introduced: the expression then simplifies
+// to: (d0 - (d0 - 4q) + 4) = 4q + 4, modulo of which w.r.t. 4 simplifies to
+// zero. Note that an affine expression may not always be expressible purely as
+// a sum of products involving just the original dimensional and symbolic
+// identifiers due to the presence of modulo/floordiv/ceildiv expressions that
+// may not be eliminated after simplification; in such cases, the final
+// expression can be reconstructed by replacing the local identifiers with their
+// corresponding explicit form stored in 'localExprs' (note that each of the
+// explicit forms itself would have been simplified).
+//
+// The expression walk method here performs a linear time post order walk that
+// performs the above simplifications through visit methods, with partial
+// results being stored in 'operandExprStack'. When a parent expr is visited,
+// the flattened expressions corresponding to its two operands would already be
+// on the stack - the parent expression looks at the two flattened expressions
+// and combines the two. It pops off the operand expressions and pushes the
+// combined result (although this is done in-place on its LHS operand expr).
+// When the walk is completed, the flattened form of the top-level expression
+// would be left on the stack.
+//
+// A flattener can be repeatedly used for multiple affine expressions that bind
+// to the same operands, for example, for all result expressions of an
+// AffineMap or all constraints of an IntegerSet. In such cases, using it for
+// multiple expressions is more efficient than creating a new flattener for
+// each expression since common identical div and mod expressions appearing
+// across different expressions are mapped to the same local identifier (same
+// column position in 'localVarCst').
+class AffineExprFlattener {
+public:
+  // Flattened expression layout: [dims, symbols, locals, constant]
+  // Stack that holds the LHS and RHS operands while visiting a binary op expr.
+  // In future, consider adding a prepass to determine how big the SmallVector's
+  // will be, and linearize this to std::vector<int64_t> to prevent
+  // SmallVector moves on re-allocation.
+  std::vector<SmallVector<int64_t, 8>> operandExprStack;
+
+  unsigned numDims;
+  unsigned numSymbols;
+
+  // Number of newly introduced identifiers to flatten mod/floordiv/ceildiv's.
+  unsigned numLocals;
+
+  // Constraints connecting newly introduced local variables (for mod's and
+  // div's) to existing (dimensional and symbolic) ones. These are always
+  // inequalities.
+  IntegerPolyhedron localVarCst;
+
+  // AffineExpr's corresponding to the floordiv/ceildiv/mod expressions for
+  // which new identifiers were introduced; if the latter do not get canceled
+  // out, these expressions can be readily used to reconstruct the AffineExpr
+  // (tree) form. Note that these expressions themselves would have been
+  // simplified (recursively) by this pass. E.g., d0 + (d0 + 2*d1 + d0) ceildiv
+  // 4 will be simplified to d0 + q, where q = (d0 + d1) ceildiv 2. (d0 + d1)
+  // ceildiv 2 would be the local expression stored for q.
+  SmallVector<AffineExpr, 4> localExprs;
+
+  AffineExprFlattener(unsigned numDims, unsigned numSymbols);
+
+  virtual ~AffineExprFlattener() = default;
+
+  // Visitor methods.
+  LogicalResult visitMulExpr(const AffineBinOpExpr &expr);
+  LogicalResult visitAddExpr(const AffineBinOpExpr &expr);
+  LogicalResult visitDimExpr(const AffineDimExpr &expr);
+  LogicalResult visitSymbolExpr(const AffineSymbolExpr &expr);
+  LogicalResult visitConstantExpr(const AffineConstantExpr &expr);
+  LogicalResult visitCeilDivExpr(const AffineBinOpExpr &expr);
+  LogicalResult visitFloorDivExpr(const AffineBinOpExpr &expr);
+
+  //
+  // t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1
+  //
+  // A mod expression "expr mod c" is thus flattened by introducing a new local
+  // variable q (= expr floordiv c), such that expr mod c is replaced with
+  // 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
+  LogicalResult visitModExpr(const AffineBinOpExpr &expr);
+
+  // Function to walk an AffineExpr (in post order). Binary expressions have
+  // both operands walked (LHS first) before the expression itself is visited;
+  // the walk stops at the first failure.
+  LogicalResult walkPostOrder(const AffineExprImpl &expr) {
+    switch (expr.getKind()) {
+    case AffineExprKind::Add: {
+      const auto &binOpExpr = cast<AffineBinOpExpr>(expr);
+      if (failed(walkOperandsPostOrder(binOpExpr)))
+        return failure();
+      return visitAddExpr(binOpExpr);
+    }
+    case AffineExprKind::Mul: {
+      const auto &binOpExpr = cast<AffineBinOpExpr>(expr);
+      if (failed(walkOperandsPostOrder(binOpExpr)))
+        return failure();
+      return visitMulExpr(binOpExpr);
+    }
+    case AffineExprKind::Mod: {
+      const auto &binOpExpr = cast<AffineBinOpExpr>(expr);
+      if (failed(walkOperandsPostOrder(binOpExpr)))
+        return failure();
+      return visitModExpr(binOpExpr);
+    }
+    case AffineExprKind::FloorDiv: {
+      const auto &binOpExpr = cast<AffineBinOpExpr>(expr);
+      if (failed(walkOperandsPostOrder(binOpExpr)))
+        return failure();
+      return visitFloorDivExpr(binOpExpr);
+    }
+    case AffineExprKind::CeilDiv: {
+      const auto &binOpExpr = cast<AffineBinOpExpr>(expr);
+      if (failed(walkOperandsPostOrder(binOpExpr)))
+        return failure();
+      return visitCeilDivExpr(binOpExpr);
+    }
+    case AffineExprKind::Constant:
+      return visitConstantExpr(cast<AffineConstantExpr>(expr));
+    case AffineExprKind::DimId:
+      return visitDimExpr(cast<AffineDimExpr>(expr));
+    case AffineExprKind::SymbolId:
+      return visitSymbolExpr(cast<AffineSymbolExpr>(expr));
+    }
+    llvm_unreachable("Unknown AffineExpr");
+  }
+
+private:
+  // Walk the operands - each operand is itself walked in post order.
+  LogicalResult walkOperandsPostOrder(const AffineBinOpExpr &expr) {
+    if (failed(walkPostOrder(*expr.getLHS())))
+      return failure();
+    if (failed(walkPostOrder(*expr.getRHS())))
+      return failure();
+    return success();
+  }
+
+  /// Constructs an affine expression from a flat ArrayRef. If there are local
+  /// identifiers (neither dimensional nor symbolic) that appear in the sum of
+  /// products expression, `localExprs` is expected to have the AffineExpr
+  /// for it, and is substituted into. The ArrayRef `flatExprs` is expected to
+  /// be in the format [dims, symbols, locals, constant term].
+  AffineExpr getAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
+                                       unsigned numDims, unsigned numSymbols);
+
+  // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
+  // The local identifier added is always a floordiv of a pure add/mul affine
+  // function of other identifiers, coefficients of which are specified in
+  // dividend and with respect to a positive constant divisor. localExpr is the
+  // simplified tree expression (AffineExpr) corresponding to the quantifier.
+  void addLocalFloorDivId(ArrayRef<int64_t> dividend, int64_t divisor,
+                          AffineExpr &&localExpr);
+
+  /// Add a local identifier (needed to flatten a mod, floordiv, ceildiv, mul
+  /// expr) when the rhs is a symbolic expression. The local identifier added
+  /// may be a floordiv, ceildiv, mul or mod of a pure affine/semi-affine
+  /// function of other identifiers, coefficients of which are specified in the
+  /// lhs of the mod, floordiv, ceildiv or mul expression and with respect to a
+  /// symbolic rhs expression. `localExpr` is the simplified tree expression
+  /// (AffineExpr) corresponding to the quantifier.
+  void addLocalIdSemiAffine(AffineExpr &&localExpr);
+
+  /// Adds `expr`, which may be a mod, ceildiv, floordiv or mul expression
+  /// representing the affine expression corresponding to the quantifier
+  /// introduced as the local variable corresponding to `expr`. If the
+  /// quantifier is already present, we put the coefficient in the proper index
+  /// of `result`, otherwise we add a new local variable and put the coefficient
+  /// there.
+  void addLocalVariableSemiAffine(AffineExpr &&expr,
+                                  SmallVectorImpl<int64_t> &result,
+                                  unsigned long resultSize);
+
+  // t = expr floordiv c <=> t = q, c * q <= expr <= c * q + c - 1
+  // A floordiv is thus flattened by introducing a new local variable q, and
+  // replacing that expression with 'q' while adding the constraints
+  // c * q <= expr <= c * q + c - 1 to localVarCst (done by
+  // IntegerRelation::addLocalFloorDiv).
+  //
+  // A ceildiv is similarly flattened:
+  // t = expr ceildiv c <=> t = (expr + c - 1) floordiv c
+  LogicalResult visitDivExpr(const AffineBinOpExpr &expr, bool isCeil);
+
+  int findLocalId(const AffineExpr &localExpr);
+
+  // Column layout helpers for flattened expressions:
+  // [dims, symbols, locals, constant].
+  inline unsigned getNumCols() const {
+    return numDims + numSymbols + numLocals + 1;
+  }
+  inline unsigned getConstantIndex() const { return getNumCols() - 1; }
+  inline unsigned getLocalVarStartIndex() const { return numDims + numSymbols; }
+  inline unsigned getSymbolStartIndex() const { return numDims; }
+  inline unsigned getDimStartIndex() const { return 0; }
+};
+
+/// Flattens each result expression of `map` into `flattenedExprs`, one row per
+/// expression in the [dims, symbols, locals, constant] column layout, and
+/// stores the constraints on the local variables introduced while flattening
+/// in `cst`. Returns failure if some expression is not pure affine.
+LogicalResult getFlattenedAffineExprs(
+    const AffineMap &map, std::vector<SmallVector<int64_t, 8>> &flattenedExprs,
+    IntegerPolyhedron &cst);
+
+/// Flattens each constraint expression of `set` into `flattenedExprs`, one row
+/// per constraint in the [dims, symbols, locals, constant] column layout, and
+/// stores the constraints on the local variables introduced while flattening
+/// in `cst`. Returns failure if some constraint is not pure affine.
+LogicalResult getFlattenedAffineExprs(
+    const IntegerSet &set, std::vector<SmallVector<int64_t, 8>> &flattenedExprs,
+    IntegerPolyhedron &cst);
+} // namespace mlir::presburger
+
+#endif // MLIR_ANALYSIS_PRESBURGER_PARSER_FLATTENER_H
diff --git a/mlir/lib/Analysis/Presburger/Parser/Lexer.cpp b/mlir/lib/Analysis/Presburger/Parser/Lexer.cpp
new file mode 100644
index 0000000000000..f89742bb0451b
--- /dev/null
+++ b/mlir/lib/Analysis/Presburger/Parser/Lexer.cpp
@@ -0,0 +1,161 @@
+//===- Lexer.cpp - Presburger Lexer Implementation ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the lexer for the Presburger textual form.
+//
+//===----------------------------------------------------------------------===//
+
+#include "Lexer.h"
+#include "Token.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringSwitch.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/SourceMgr.h"
+
+using namespace mlir::presburger;
+
+/// Position a new lexer at the first character of the main buffer registered
+/// with `sourceMgr`.
+Lexer::Lexer(const llvm::SourceMgr &sourceMgr) : sourceMgr(sourceMgr) {
+  const auto *mainBuffer =
+      sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID());
+  curBuffer = mainBuffer->getBuffer();
+  curPtr = curBuffer.data();
+}
+
+/// emitError - Emit an error message and return an Token::error token.
+Token Lexer::emitError(const char *loc, const llvm::Twine &message) {
+ sourceMgr.PrintMessage(SMLoc::getFromPointer(loc), llvm::SourceMgr::DK_Error,
+ message);
+ return formToken(Token::error, loc);
+}
+
+/// Lex and return the next token. Whitespace, and NUL bytes that are not the
+/// end-of-buffer marker, are skipped. An unrecognized character is reported
+/// via emitError and yields a Token::error token.
+Token Lexer::lexToken() {
+  while (true) {
+    const char *tokStart = curPtr;
+
+    // Lex the next token.
+    switch (*curPtr++) {
+    default:
+      // Handle bare identifiers. llvm::isAlpha is used instead of ::isalpha:
+      // passing a negative `char` (e.g. a byte of a UTF-8 sequence) to
+      // ::isalpha is undefined behavior, and ::isalpha is locale-dependent.
+      if (llvm::isAlpha(curPtr[-1]))
+        return lexBareIdentifierOrKeyword(tokStart);
+
+      // Unknown character, emit an error.
+      return emitError(tokStart, "unexpected character");
+
+    case ' ':
+    case '\t':
+    case '\n':
+    case '\r':
+      // Handle whitespace.
+      continue;
+
+    case '_':
+      // Handle bare identifiers.
+      return lexBareIdentifierOrKeyword(tokStart);
+
+    case 0:
+      // This may either be a nul character in the source file or may be the EOF
+      // marker that llvm::MemoryBuffer guarantees will be there.
+      if (curPtr - 1 == curBuffer.end())
+        return formToken(Token::eof, tokStart);
+      continue;
+
+    case ':':
+      return formToken(Token::colon, tokStart);
+    case ',':
+      return formToken(Token::comma, tokStart);
+    case '(':
+      return formToken(Token::l_paren, tokStart);
+    case ')':
+      return formToken(Token::r_paren, tokStart);
+    case '{':
+      return formToken(Token::l_brace, tokStart);
+    case '}':
+      return formToken(Token::r_brace, tokStart);
+    case '[':
+      return formToken(Token::l_square, tokStart);
+    case ']':
+      return formToken(Token::r_square, tokStart);
+    case '<':
+      return formToken(Token::less, tokStart);
+    case '>':
+      return formToken(Token::greater, tokStart);
+    case '=':
+      return formToken(Token::equal, tokStart);
+    case '+':
+      return formToken(Token::plus, tokStart);
+    case '*':
+      return formToken(Token::star, tokStart);
+    case '-':
+      // `->` is a single arrow token; a lone `-` is a minus.
+      if (*curPtr == '>') {
+        ++curPtr;
+        return formToken(Token::arrow, tokStart);
+      }
+      return formToken(Token::minus, tokStart);
+    case '0':
+    case '1':
+    case '2':
+    case '3':
+    case '4':
+    case '5':
+    case '6':
+    case '7':
+    case '8':
+    case '9':
+      return lexNumber(tokStart);
+    }
+  }
+}
+
+/// Lex a bare identifier, keyword, or integer type that starts with a letter or
+/// an underscore.
+///
+///   bare-id ::= (letter|[_]) (letter|digit|[_$.])*
+///   integer-type ::= `[su]?i[0-9]+`
+///
+/// Note: the integer-type check only requires every character after the
+/// `i`/`si`/`ui` prefix to be a digit, so spellings such as `i0` are accepted.
+Token Lexer::lexBareIdentifierOrKeyword(const char *tokStart) {
+  // Match the rest of the identifier regex: [0-9a-zA-Z_.$]*
+  // llvm::isAlnum is well-defined for negative `char` values, unlike
+  // ::isalpha/::isdigit. The buffer's trailing NUL terminates the loop.
+  while (llvm::isAlnum(*curPtr) || *curPtr == '_' || *curPtr == '$' ||
+         *curPtr == '.')
+    ++curPtr;
+
+  // Check to see if this identifier is a keyword.
+  StringRef spelling(tokStart, curPtr - tokStart);
+
+  auto isAllDigit = [](StringRef str) {
+    return llvm::all_of(str, llvm::isDigit);
+  };
+
+  // Check for i123, si456, ui789.
+  if ((spelling.size() > 1 && tokStart[0] == 'i' &&
+       isAllDigit(spelling.drop_front())) ||
+      ((spelling.size() > 2 && tokStart[1] == 'i' &&
+        (tokStart[0] == 's' || tokStart[0] == 'u')) &&
+       isAllDigit(spelling.drop_front(2))))
+    return Token(Token::inttype, spelling);
+
+  Token::Kind kind = llvm::StringSwitch<Token::Kind>(spelling)
+#define TOK_KEYWORD(SPELLING) .Case(#SPELLING, Token::kw_##SPELLING)
+#include "TokenKinds.def"
+                         .Default(Token::bare_identifier);
+
+  return Token(kind, spelling);
+}
+
+/// Lex a decimal integer literal.
+///
+///   integer-literal ::= digit+
+///
+/// Only decimal literals are supported; there is no hexadecimal or
+/// floating-point form. A leading `-` is not part of the literal: lexToken
+/// returns it as a separate Token::minus.
+Token Lexer::lexNumber(const char *tokStart) {
+  assert(llvm::isDigit(curPtr[-1]) && "number must start with a digit");
+
+  // Consume the remaining digits. llvm::isDigit is well-defined for negative
+  // `char` values, unlike ::isdigit; the buffer's trailing NUL stops the loop.
+  while (llvm::isDigit(*curPtr))
+    ++curPtr;
+
+  return formToken(Token::integer, tokStart);
+}
diff --git a/mlir/lib/Analysis/Presburger/Parser/Lexer.h b/mlir/lib/Analysis/Presburger/Parser/Lexer.h
new file mode 100644
index 0000000000000..081ded3a7fa7c
--- /dev/null
+++ b/mlir/lib/Analysis/Presburger/Parser/Lexer.h
@@ -0,0 +1,58 @@
+//===- Lexer.h - Presburger Lexer Interface ---------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the Presburger Lexer class.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_PRESBURGER_PARSER_LEXER_H
+#define MLIR_ANALYSIS_PRESBURGER_PARSER_LEXER_H
+
+#include "Token.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/SourceMgr.h"
+
+namespace mlir::presburger {
+/// This class breaks up the current file into a token stream.
+class Lexer {
+public:
+  /// Create a lexer positioned at the start of the main buffer of
+  /// `sourceMgr`. Only a reference to `sourceMgr` is kept (it is also used to
+  /// report errors), so it must outlive the lexer.
+  explicit Lexer(const llvm::SourceMgr &sourceMgr);
+
+  Lexer(const Lexer &) = delete;
+  Lexer &operator=(const Lexer &) = delete;
+
+  /// Lex and return the next token from the input.
+  Token lexToken();
+
+  /// Change the position of the lexer cursor. The next token we lex will start
+  /// at the designated point in the input.
+  void resetPointer(const char *newPointer) { curPtr = newPointer; }
+
+  /// Returns the start of the buffer.
+  const char *getBufferBegin() const { return curBuffer.data(); }
+
+private:
+  /// Build a token of `kind` spanning [tokStart, curPtr).
+  Token formToken(Token::Kind kind, const char *tokStart) {
+    return Token(kind, StringRef(tokStart, curPtr - tokStart));
+  }
+
+  /// Report `message` at `loc` and return a Token::error token.
+  Token emitError(const char *loc, const llvm::Twine &message);
+
+  // Lexer implementation methods.
+  Token lexBareIdentifierOrKeyword(const char *tokStart);
+  Token lexNumber(const char *tokStart);
+
+  /// Source manager that provides the buffer and reports diagnostics.
+  const llvm::SourceMgr &sourceMgr;
+
+  /// The buffer being lexed and the current position within it.
+  StringRef curBuffer;
+  const char *curPtr = nullptr;
+};
+} // namespace mlir::presburger
+
+#endif // MLIR_ANALYSIS_PRESBURGER_PARSER_LEXER_H
diff --git a/mlir/lib/Analysis/Presburger/Parser/ParseStructs.cpp b/mlir/lib/Analysis/Presburger/Parser/ParseStructs.cpp
new file mode 100644
index 0000000000000..681bb23db001d
--- /dev/null
+++ b/mlir/lib/Analysis/Presburger/Parser/ParseStructs.cpp
@@ -0,0 +1,268 @@
+//===- ParseStructs.cpp - Presburger Parse Structures -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the ParseStructs class that the parser for the
+// Presburger library parses into.
+//
+//===----------------------------------------------------------------------===//
+
+#include "ParseStructs.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir::presburger;
+using llvm::cast;
+using llvm::dbgs;
+using llvm::isa;
+
+/// A pure affine expression is built from dims, symbols and constants using
+/// only addition, multiplication in which at least one factor is a constant,
+/// and floordiv/ceildiv/mod by a constant.
+bool AffineExprImpl::isPureAffine() const {
+  const AffineExprKind exprKind = getKind();
+  // Dims, symbols and constants are the leaves of every affine expression.
+  if (exprKind == AffineExprKind::SymbolId ||
+      exprKind == AffineExprKind::DimId || exprKind == AffineExprKind::Constant)
+    return true;
+
+  const auto &binOp = cast<AffineBinOpExpr>(*this);
+  const AffineExprImpl &lhs = *binOp.getLHS();
+  const AffineExprImpl &rhs = *binOp.getRHS();
+  switch (exprKind) {
+  case AffineExprKind::Add:
+    return lhs.isPureAffine() && rhs.isPureAffine();
+  case AffineExprKind::Mul:
+    // A product stays affine only if one of its factors is a constant.
+    return lhs.isPureAffine() && rhs.isPureAffine() &&
+           (isa<AffineConstantExpr>(lhs) || isa<AffineConstantExpr>(rhs));
+  case AffineExprKind::FloorDiv:
+  case AffineExprKind::CeilDiv:
+  case AffineExprKind::Mod:
+    // Division and modulus must be by a constant.
+    return lhs.isPureAffine() && isa<AffineConstantExpr>(rhs);
+  default:
+    break;
+  }
+  llvm_unreachable("Unknown AffineExpr");
+}
+
+/// An expression is symbolic-or-constant when no dimensional identifier
+/// appears anywhere in its tree.
+bool AffineExprImpl::isSymbolicOrConstant() const {
+  switch (getKind()) {
+  case AffineExprKind::DimId:
+    return false;
+  case AffineExprKind::Constant:
+  case AffineExprKind::SymbolId:
+    return true;
+  case AffineExprKind::Add:
+  case AffineExprKind::Mul:
+  case AffineExprKind::FloorDiv:
+  case AffineExprKind::CeilDiv:
+  case AffineExprKind::Mod: {
+    const auto &binOp = cast<AffineBinOpExpr>(*this);
+    // Both operands must be free of dimensional identifiers.
+    if (!binOp.getLHS()->isSymbolicOrConstant())
+      return false;
+    return binOp.getRHS()->isSymbolicOrConstant();
+  }
+  }
+  llvm_unreachable("Unknown AffineExpr");
+}
+
+// Simplify `lhs * rhs` to the extent required by usage and the flattener.
+// Returns nullptr when no simplification applies; in that case neither `lhs`
+// nor `rhs` has been moved from, so the caller may still consume them. When a
+// non-null expression is returned it replaces the product, and the inputs may
+// have been moved from.
+static AffineExpr simplifyMul(AffineExpr &&lhs, AffineExpr &&rhs) {
+  // Fold a product of two constants.
+  if (isa<AffineConstantExpr>(*lhs) && isa<AffineConstantExpr>(*rhs)) {
+    // Bind by reference: expression nodes are not meant to be copied.
+    const auto &lhsConst = cast<AffineConstantExpr>(*lhs);
+    const auto &rhsConst = cast<AffineConstantExpr>(*rhs);
+    // NOTE(review): if getValue() is a fixed-width signed integer, this
+    // product can overflow for large constants; consider a checked multiply.
+    return std::make_unique<AffineConstantExpr>(lhsConst.getValue() *
+                                                rhsConst.getValue());
+  }
+
+  // A product of two non-symbolic terms cannot be simplified here.
+  if (!lhs->isSymbolicOrConstant() && !rhs->isSymbolicOrConstant())
+    return nullptr;
+
+  // Canonicalize the mul expression so that the constant/symbolic term is the
+  // RHS. If both the lhs and rhs are symbolic, swap them if the lhs is a
+  // constant. (Note that a constant is trivially symbolic).
+  if (!rhs->isSymbolicOrConstant() || isa<AffineConstantExpr>(*lhs)) {
+    // At least one of them has to be symbolic.
+    return std::move(rhs) * std::move(lhs);
+  }
+
+  // At this point, if there was a constant, it would be on the right.
+  if (isa<AffineConstantExpr>(*rhs)) {
+    const auto &rhsConst = cast<AffineConstantExpr>(*rhs);
+    // Multiplication with one is a noop: return the other input.
+    if (rhsConst.getValue() == 1)
+      return std::move(lhs);
+    // Multiplication with zero folds to the zero constant itself.
+    if (rhsConst.getValue() == 0)
+      return std::move(rhs);
+  }
+
+  return nullptr;
+}
+
+namespace mlir::presburger {
+/// Build the product `s * o`, folding/canonicalizing it through simplifyMul
+/// when possible and otherwise creating a plain Mul node.
+AffineExpr operator*(AffineExpr &&s, AffineExpr &&o) {
+  AffineExpr product = simplifyMul(std::move(s), std::move(o));
+  if (!product) {
+    // simplifyMul only consumes its operands when it returns a result, so `s`
+    // and `o` are still intact here.
+    product = std::make_unique<AffineBinOpExpr>(std::move(s), std::move(o),
+                                                AffineExprKind::Mul);
+  }
+  return product;
+}
+} // namespace mlir::presburger
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+/// How tightly the context enclosing a subexpression binds. When printing in a
+/// Strong context, binary-operator subexpressions are parenthesized.
+enum class BindingStrength {
+  /// Context of `+` and `-`.
+  Weak,
+  /// Context of every other binary operator.
+  Strong,
+};
+
+/// Print `expr` to llvm::dbgs(). Dimensions print as `dN` and symbols as `sN`.
+/// `enclosingTightness` describes the context the expression appears in; in a
+/// Strong context a binary expression is wrapped in parentheses. A product
+/// with -1 prints as a negation, and an addition whose RHS is a negative
+/// constant or a product by a negative constant prints as a subtraction.
+static void printAffineExpr(const AffineExprImpl &expr,
+ BindingStrength enclosingTightness) {
+ // Set below for binary operators; leaf kinds print themselves and return.
+ const char *binopSpelling = nullptr;
+ switch (expr.getKind()) {
+ case AffineExprKind::SymbolId: {
+ unsigned pos = cast<AffineSymbolExpr>(expr).getPosition();
+ dbgs() << 's' << pos;
+ return;
+ }
+ case AffineExprKind::DimId: {
+ unsigned pos = cast<AffineDimExpr>(expr).getPosition();
+ dbgs() << 'd' << pos;
+ return;
+ }
+ case AffineExprKind::Constant:
+ dbgs() << cast<AffineConstantExpr>(expr).getValue();
+ return;
+ case AffineExprKind::Add:
+ binopSpelling = " + ";
+ break;
+ case AffineExprKind::Mul:
+ binopSpelling = " * ";
+ break;
+ case AffineExprKind::FloorDiv:
+ binopSpelling = " floordiv ";
+ break;
+ case AffineExprKind::CeilDiv:
+ binopSpelling = " ceildiv ";
+ break;
+ case AffineExprKind::Mod:
+ binopSpelling = " mod ";
+ break;
+ }
+
+ // Only binary operators reach this point.
+ const auto &binOp = cast<AffineBinOpExpr>(expr);
+ const AffineExprImpl &lhsExpr = *binOp.getLHS();
+ const AffineExprImpl &rhsExpr = *binOp.getRHS();
+
+ // Handle tightly binding binary operators.
+ if (binOp.getKind() != AffineExprKind::Add) {
+ if (enclosingTightness == BindingStrength::Strong)
+ dbgs() << '(';
+
+ // Pretty print multiplication with -1.
+ if (isa<AffineConstantExpr>(rhsExpr)) {
+ const auto &rhsConst = cast<AffineConstantExpr>(rhsExpr);
+ if (binOp.getKind() == AffineExprKind::Mul && rhsConst.getValue() == -1) {
+ dbgs() << "-";
+ printAffineExpr(lhsExpr, BindingStrength::Strong);
+ if (enclosingTightness == BindingStrength::Strong)
+ dbgs() << ')';
+ return;
+ }
+ }
+ // Operands of a tightly binding operator are themselves printed in a
+ // Strong context, so nested binary operands get parenthesized.
+ printAffineExpr(lhsExpr, BindingStrength::Strong);
+
+ dbgs() << binopSpelling;
+ printAffineExpr(rhsExpr, BindingStrength::Strong);
+
+ if (enclosingTightness == BindingStrength::Strong)
+ dbgs() << ')';
+ return;
+ }
+
+ // Print out special "pretty" forms for add.
+ if (enclosingTightness == BindingStrength::Strong)
+ dbgs() << '(';
+
+ // Pretty print addition to a product that has a negative operand as a
+ // subtraction.
+ if (isa<AffineBinOpExpr>(rhsExpr)) {
+ const auto &rhs = cast<AffineBinOpExpr>(rhsExpr);
+ if (rhs.getKind() == AffineExprKind::Mul) {
+ const AffineExprImpl &rrhsExpr = *rhs.getRHS();
+ if (isa<AffineConstantExpr>(rrhsExpr)) {
+ const auto &rrhs = cast<AffineConstantExpr>(rrhsExpr);
+ if (rrhs.getValue() == -1) {
+ printAffineExpr(lhsExpr, BindingStrength::Weak);
+ dbgs() << " - ";
+ // Parenthesize an Add operand so "a - (b + c)" is not printed as
+ // "a - b + c".
+ if (rhs.getLHS()->getKind() == AffineExprKind::Add) {
+ printAffineExpr(*rhs.getLHS(), BindingStrength::Strong);
+ } else {
+ printAffineExpr(*rhs.getLHS(), BindingStrength::Weak);
+ }
+
+ if (enclosingTightness == BindingStrength::Strong)
+ dbgs() << ')';
+ return;
+ }
+
+ // `a + b * -k` (k > 1) is printed as `a - b * k`.
+ if (rrhs.getValue() < -1) {
+ printAffineExpr(lhsExpr, BindingStrength::Weak);
+ dbgs() << " - ";
+ printAffineExpr(*rhs.getLHS(), BindingStrength::Strong);
+ dbgs() << " * " << -rrhs.getValue();
+ if (enclosingTightness == BindingStrength::Strong)
+ dbgs() << ')';
+ return;
+ }
+ }
+ }
+ }
+
+ // Pretty print addition to a negative number as a subtraction.
+ if (isa<AffineConstantExpr>(rhsExpr)) {
+ const auto &rhsConst = cast<AffineConstantExpr>(rhsExpr);
+ if (rhsConst.getValue() < 0) {
+ printAffineExpr(lhsExpr, BindingStrength::Weak);
+ dbgs() << " - " << -rhsConst.getValue();
+ if (enclosingTightness == BindingStrength::Strong)
+ dbgs() << ')';
+ return;
+ }
+ }
+
+ // Plain addition: operands are printed in a Weak context.
+ printAffineExpr(lhsExpr, BindingStrength::Weak);
+
+ dbgs() << " + ";
+ printAffineExpr(rhsExpr, BindingStrength::Weak);
+
+ if (enclosingTightness == BindingStrength::Strong)
+ dbgs() << ')';
+}
+
+/// Print this expression to llvm::dbgs(), followed by a newline.
+LLVM_DUMP_METHOD void AffineExprImpl::dump() const {
+  // The top-level expression is printed in a Weak (unparenthesized) context.
+  constexpr BindingStrength topLevel = BindingStrength::Weak;
+  printAffineExpr(*this, topLevel);
+  dbgs() << '\n';
+}
+
+/// Print the dimension and symbol counts, then one result expression per line,
+/// to llvm::dbgs().
+LLVM_DUMP_METHOD void AffineMap::dump() const {
+  llvm::raw_ostream &os = dbgs();
+  os << "NumDims = " << numDims << '\n'
+     << "NumSymbols = " << numSymbols << '\n'
+     << "Expressions:\n";
+  for (const AffineExpr &resultExpr : getExprs())
+    resultExpr->dump();
+}
+
+/// Print the dimension and symbol counts, every constraint expression, and the
+/// per-constraint equality flags (one per line) to llvm::dbgs().
+LLVM_DUMP_METHOD void IntegerSet::dump() const {
+  llvm::raw_ostream &os = dbgs();
+  os << "NumDims = " << numDims << '\n'
+     << "NumSymbols = " << numSymbols << '\n'
+     << "Constraints:\n";
+  for (const AffineExpr &constraint : getConstraints())
+    constraint->dump();
+  os << "EqFlags:\n";
+  for (bool isEquality : getEqFlags())
+    os << isEquality << '\n';
+}
+#endif
diff --git a/mlir/lib/Analysis/Presburger/Parser/ParseStructs.h b/mlir/lib/Analysis/Presburger/Parser/ParseStructs.h
new file mode 100644
index 0000000000000..8ec7fe2ff840a
--- /dev/null
+++ b/mlir/lib/Analysis/Presburger/Parser/ParseStructs.h
@@ -0,0 +1,293 @@
+//===- ParseStructs.h - Presburger Parse Structures -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_PRESBURGER_PARSER_PARSESTRUCTS_H
+#define MLIR_ANALYSIS_PRESBURGER_PARSER_PARSESTRUCTS_H
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+#include <cassert>
+#include <cstdint>
+#include <memory>
+
+namespace mlir::presburger {
+using llvm::ArrayRef;
+using llvm::SmallVector;
+using llvm::SmallVectorImpl;
+
+/// Discriminator identifying the concrete node type of an AffineExprImpl.
+/// The binary operation kinds are listed first so that
+/// AffineBinOpExpr::classof can recognize all of them with a single
+/// comparison against LAST_BINOP; keep new binary kinds before that marker.
+enum class AffineExprKind {
+ /// Addition. Subtraction is built as addition of the RHS multiplied by -1.
+ Add,
+ /// RHS of mul is always a constant or a symbolic expression.
+ Mul,
+ /// RHS of mod is always a constant or a symbolic expression with a positive
+ /// value.
+ Mod,
+ /// RHS of floordiv is always a constant or a symbolic expression.
+ FloorDiv,
+ /// RHS of ceildiv is always a constant or a symbolic expression.
+ CeilDiv,
+ /// This is a marker for the last affine binary op. The range of binary
+ /// op's is expected to be this element and earlier.
+ LAST_BINOP = CeilDiv,
+ /// Constant integer.
+ Constant,
+ /// Dimensional identifier.
+ DimId,
+ /// Symbolic identifier.
+ SymbolId,
+};
+
+/// Base of every node in a parsed affine expression tree. The concrete node
+/// type is identified by `kind` and recovered with LLVM-style RTTI (each
+/// subclass provides a `classof`).
+struct AffineExprImpl {
+  explicit AffineExprImpl(AffineExprKind kind) : kind(kind) {}
+
+  /// Nodes are owned and destroyed through AffineExpr, which is a
+  /// std::unique_ptr to this base class. The destructor must therefore be
+  /// virtual. Otherwise, deleting a derived node through the base pointer is
+  /// undefined behavior, and in practice the operand subtrees owned by an
+  /// AffineBinOpExpr are never released.
+  virtual ~AffineExprImpl() = default;
+
+  // Delete all copy/move operators.
+  AffineExprImpl(const AffineExprImpl &o) = delete;
+  AffineExprImpl &operator=(const AffineExprImpl &o) = delete;
+  AffineExprImpl(AffineExprImpl &&o) = delete;
+  AffineExprImpl &operator=(AffineExprImpl &&o) = delete;
+
+  AffineExprKind getKind() const { return kind; }
+
+  /// Returns true if this expression is made out of only symbols and
+  /// constants, i.e., it does not involve dimensional identifiers.
+  bool isSymbolicOrConstant() const;
+
+  /// Returns true if this is a pure affine expression, i.e., multiplication,
+  /// floordiv, ceildiv, and mod is only allowed w.r.t constants.
+  bool isPureAffine() const;
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  LLVM_DUMP_METHOD void dump() const;
+#endif
+
+  AffineExprKind kind;
+};
+
+// AffineExpr is a unique_ptr because expression trees are recursive: an
+// AffineBinOpExpr holds AffineExprs for its operands, so each node
+// exclusively owns its children.
+using AffineExpr = std::unique_ptr<AffineExprImpl>;
+
+/// A binary operation node (add, mul, mod, floordiv or ceildiv) that owns the
+/// subtrees of its two operands.
+struct AffineBinOpExpr : public AffineExprImpl {
+  AffineBinOpExpr(AffineExpr &&l, AffineExpr &&r, AffineExprKind opKind)
+      : AffineExprImpl(opKind), lhs(std::move(l)), rhs(std::move(r)) {}
+
+  // Nodes are neither copyable nor movable.
+  AffineBinOpExpr(const AffineBinOpExpr &) = delete;
+  AffineBinOpExpr(AffineBinOpExpr &&) = delete;
+  AffineBinOpExpr &operator=(const AffineBinOpExpr &) = delete;
+  AffineBinOpExpr &operator=(AffineBinOpExpr &&) = delete;
+
+  /// LLVM-style RTTI: every kind up to and including LAST_BINOP is binary.
+  static bool classof(const AffineExprImpl *a) {
+    return !(a->getKind() > AffineExprKind::LAST_BINOP);
+  }
+
+  const AffineExpr &getLHS() const { return lhs; }
+  const AffineExpr &getRHS() const { return rhs; }
+
+  AffineExpr lhs;
+  AffineExpr rhs;
+};
+
+/// A dimensional identifier appearing in an affine expression, referred to by
+/// its position in the dimension list.
+struct AffineDimExpr : public AffineExprImpl {
+  AffineDimExpr(unsigned position)
+      : AffineExprImpl(AffineExprKind::DimId), position(position) {}
+
+  // Copy and move construction only duplicate the position; assignment is
+  // disallowed.
+  AffineDimExpr(const AffineDimExpr &other) : AffineDimExpr(other.position) {}
+  AffineDimExpr(AffineDimExpr &&other) : AffineDimExpr(other.position) {}
+  AffineDimExpr &operator=(const AffineDimExpr &) = delete;
+  AffineDimExpr &operator=(AffineDimExpr &&) = delete;
+
+  bool operator==(const AffineDimExpr &other) const {
+    return other.position == position;
+  }
+
+  static bool classof(const AffineExprImpl *a) {
+    return AffineExprKind::DimId == a->getKind();
+  }
+
+  unsigned getPosition() const { return position; }
+
+  /// Position of this identifier in the dimension list.
+  unsigned position;
+};
+
+/// A symbolic identifier appearing in an affine expression, referred to by
+/// its position in the symbol list.
+struct AffineSymbolExpr : public AffineExprImpl {
+  AffineSymbolExpr(unsigned position)
+      : AffineExprImpl(AffineExprKind::SymbolId), position(position) {}
+
+  // Copy and move construction only duplicate the position; assignment is
+  // disallowed.
+  AffineSymbolExpr(const AffineSymbolExpr &other)
+      : AffineSymbolExpr(other.position) {}
+  AffineSymbolExpr(AffineSymbolExpr &&other)
+      : AffineSymbolExpr(other.position) {}
+  AffineSymbolExpr &operator=(const AffineSymbolExpr &) = delete;
+  AffineSymbolExpr &operator=(AffineSymbolExpr &&) = delete;
+
+  bool operator==(const AffineSymbolExpr &other) const {
+    return other.position == position;
+  }
+
+  static bool classof(const AffineExprImpl *a) {
+    return AffineExprKind::SymbolId == a->getKind();
+  }
+
+  unsigned getPosition() const { return position; }
+
+  /// Position of this identifier in the symbol list.
+  unsigned position;
+};
+
+/// A signed 64-bit integer constant appearing in an affine expression.
+struct AffineConstantExpr : public AffineExprImpl {
+  AffineConstantExpr(int64_t constant)
+      : AffineExprImpl(AffineExprKind::Constant), constant(constant) {}
+
+  // Copy and move construction only duplicate the value; assignment is
+  // disallowed.
+  AffineConstantExpr(const AffineConstantExpr &other)
+      : AffineConstantExpr(other.constant) {}
+  AffineConstantExpr(AffineConstantExpr &&other)
+      : AffineConstantExpr(other.constant) {}
+  AffineConstantExpr &operator=(const AffineConstantExpr &) = delete;
+  AffineConstantExpr &operator=(AffineConstantExpr &&) = delete;
+
+  bool operator==(const AffineConstantExpr &other) const {
+    return other.constant == constant;
+  }
+
+  static bool classof(const AffineExprImpl *a) {
+    return AffineExprKind::Constant == a->getKind();
+  }
+
+  int64_t getValue() const { return constant; }
+
+  // The constant's value.
+  int64_t constant;
+};
+
+/// An ordered list of affine expressions over `numDims` dimensional and
+/// `numSymbols` symbolic identifiers. The map owns its expression trees.
+struct AffineMap {
+  unsigned numDims;
+  unsigned numSymbols;
+
+  // The affine expressions in the map.
+  SmallVector<AffineExpr, 4> exprs;
+
+  AffineMap(unsigned numDims, unsigned numSymbols,
+            SmallVectorImpl<AffineExpr> &&exprs)
+      : numDims(numDims), numSymbols(numSymbols), exprs(std::move(exprs)) {}
+
+  // Non-copyable and non-assignable; only move-constructible.
+  AffineMap(const AffineMap &) = delete;
+  AffineMap &operator=(const AffineMap &) = delete;
+  AffineMap(AffineMap &&o)
+      : numDims(o.numDims), numSymbols(o.numSymbols),
+        exprs(std::move(o.exprs)) {}
+  AffineMap &operator=(AffineMap &&o) = delete;
+
+  unsigned getNumDims() const { return numDims; }
+  unsigned getNumSymbols() const { return numSymbols; }
+  unsigned getNumInputs() const { return numDims + numSymbols; }
+  unsigned getNumExprs() const { return exprs.size(); }
+  ArrayRef<AffineExpr> getExprs() const { return exprs; }
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  LLVM_DUMP_METHOD void dump() const;
+#endif
+};
+
+/// A conjunction of affine constraints over `numDims` dimensional and
+/// `numSymbols` symbolic identifiers. constraints[i] is an equality when
+/// eqFlags[i] is true and an inequality otherwise.
+struct IntegerSet {
+  unsigned numDims;
+  unsigned numSymbols;
+
+  /// Array of affine constraints: a constraint is either an equality
+  /// (affine_expr == 0) or an inequality (affine_expr >= 0).
+  SmallVector<AffineExpr, 4> constraints;
+
+  // Bits to check whether a constraint is an equality or an inequality.
+  SmallVector<bool, 4> eqFlags;
+
+  IntegerSet(unsigned numDims, unsigned numSymbols,
+             SmallVectorImpl<AffineExpr> &&constraints,
+             SmallVectorImpl<bool> &&eqFlags)
+      : numDims(numDims), numSymbols(numSymbols),
+        constraints(std::move(constraints)), eqFlags(std::move(eqFlags)) {
+    // The same-named parameters shadow the members and have already been
+    // moved from here, so the members must be named explicitly.
+    assert(this->constraints.size() == this->eqFlags.size() &&
+           "expected exactly one equality flag per constraint");
+  }
+
+  // Non-copyable; only movable.
+  IntegerSet(const IntegerSet &o) = delete;
+  IntegerSet &operator=(const IntegerSet &o) = delete;
+  IntegerSet(IntegerSet &&o)
+      : numDims(o.numDims), numSymbols(o.numSymbols),
+        constraints(std::move(o.constraints)), eqFlags(std::move(o.eqFlags)) {}
+  IntegerSet &operator=(IntegerSet &&o) = delete;
+
+  /// Build a set holding the single constraint `constraint`, which is an
+  /// equality if `eqFlag` is true.
+  IntegerSet(unsigned dimCount, unsigned symbolCount, AffineExpr &&constraint,
+             bool eqFlag)
+      : numDims(dimCount), numSymbols(symbolCount) {
+    constraints.emplace_back(std::move(constraint));
+    eqFlags.emplace_back(eqFlag);
+  }
+
+  unsigned getNumDims() const { return numDims; }
+  unsigned getNumSymbols() const { return numSymbols; }
+  unsigned getNumInputs() const { return numDims + numSymbols; }
+  ArrayRef<AffineExpr> getConstraints() const { return constraints; }
+  unsigned getNumConstraints() const { return constraints.size(); }
+  ArrayRef<bool> getEqFlags() const { return eqFlags; }
+  bool isEq(unsigned idx) const { return eqFlags[idx]; }
+
+  unsigned getNumEqualities() const {
+    unsigned numEqualities = 0;
+    for (unsigned i = 0, e = getNumConstraints(); i < e; i++)
+      if (isEq(i))
+        ++numEqualities;
+    return numEqualities;
+  }
+
+  unsigned getNumInequalities() const {
+    return getNumConstraints() - getNumEqualities();
+  }
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  LLVM_DUMP_METHOD void dump() const;
+#endif
+};
+
+// Convenience operators for building expression trees. Each one consumes its
+// operands and returns the root of the newly built tree. Integer operands are
+// wrapped in AffineConstantExpr nodes. The product of two expressions is
+// defined out of line.
+AffineExpr operator*(AffineExpr &&s, AffineExpr &&o);
+inline AffineExpr operator*(AffineExpr &&s, int64_t o) {
+  AffineExpr factor = std::make_unique<AffineConstantExpr>(o);
+  return std::move(s) * std::move(factor);
+}
+inline AffineExpr operator*(int64_t s, AffineExpr &&o) {
+  // Place the constant factor on the right-hand side.
+  AffineExpr factor = std::make_unique<AffineConstantExpr>(s);
+  return std::move(o) * std::move(factor);
+}
+inline AffineExpr operator+(AffineExpr &&s, AffineExpr &&o) {
+  AffineExprKind kind = AffineExprKind::Add;
+  return std::make_unique<AffineBinOpExpr>(std::move(s), std::move(o), kind);
+}
+inline AffineExpr operator+(AffineExpr &&s, int64_t o) {
+  AffineExpr addend = std::make_unique<AffineConstantExpr>(o);
+  return std::move(s) + std::move(addend);
+}
+inline AffineExpr operator+(int64_t s, AffineExpr &&o) {
+  // Place the constant addend on the right-hand side.
+  AffineExpr addend = std::make_unique<AffineConstantExpr>(s);
+  return std::move(o) + std::move(addend);
+}
+// Subtraction is encoded as addition of the right-hand side times -1.
+inline AffineExpr operator-(AffineExpr &&s, AffineExpr &&o) {
+  AffineExpr negated = std::move(o) * -1;
+  return std::move(s) + std::move(negated);
+}
+inline AffineExpr operator%(AffineExpr &&s, AffineExpr &&o) {
+  AffineExprKind kind = AffineExprKind::Mod;
+  return std::make_unique<AffineBinOpExpr>(std::move(s), std::move(o), kind);
+}
+inline AffineExpr ceilDiv(AffineExpr &&s, AffineExpr &&o) {
+  AffineExprKind kind = AffineExprKind::CeilDiv;
+  return std::make_unique<AffineBinOpExpr>(std::move(s), std::move(o), kind);
+}
+inline AffineExpr floorDiv(AffineExpr &&s, AffineExpr &&o) {
+  AffineExprKind kind = AffineExprKind::FloorDiv;
+  return std::make_unique<AffineBinOpExpr>(std::move(s), std::move(o), kind);
+}
+} // namespace mlir::presburger
+
+#endif // MLIR_ANALYSIS_PRESBURGER_PARSER_PARSESTRUCTS_H
diff --git a/mlir/lib/Analysis/Presburger/Parser/ParserImpl.cpp b/mlir/lib/Analysis/Presburger/Parser/ParserImpl.cpp
new file mode 100644
index 0000000000000..43e383cba1a0a
--- /dev/null
+++ b/mlir/lib/Analysis/Presburger/Parser/ParserImpl.cpp
@@ -0,0 +1,822 @@
+//===- ParserImpl.cpp - Presburger Parser Implementation --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the ParserImpl class for the Presburger textual form.
+//
+//===----------------------------------------------------------------------===//
+
+#include "ParserImpl.h"
+#include "Flattener.h"
+#include "ParseStructs.h"
+#include "ParserState.h"
+#include "mlir/Analysis/Presburger/Parser.h"
+
+using namespace mlir;
+using namespace presburger;
+using llvm::MemoryBuffer;
+using llvm::SmallVector;
+using llvm::SourceMgr;
+
+//===----------------------------------------------------------------------===//
+// Parser core
+//===----------------------------------------------------------------------===//
+/// Consume the current token if it has kind `expectedToken` and return
+/// success. Otherwise, report `message` as a wrong-token error and return
+/// failure.
+ParseResult ParserImpl::parseToken(Token::Kind expectedToken,
+                                   const Twine &message) {
+  if (!consumeIf(expectedToken))
+    return emitWrongTokenError(message);
+  return success();
+}
+
+/// Parse a list of comma-separated items with an optional delimiter. If a
+/// delimiter is provided, then an empty list is allowed. If not, then at
+/// least one element will be parsed. For the Optional* delimiters, a missing
+/// opening token means the list is absent, and success is returned without
+/// consuming anything.
+///
+/// `parseElementFn` is called once per element. `contextMessage` is appended
+/// to every "expected ..." diagnostic. Returns failure if an element or a
+/// required delimiter fails to parse.
+ParseResult
+ParserImpl::parseCommaSeparatedList(Delimiter delimiter,
+                                    function_ref<ParseResult()> parseElementFn,
+                                    StringRef contextMessage) {
+  switch (delimiter) {
+  case Delimiter::None:
+    break;
+  case Delimiter::OptionalParen:
+    if (getToken().isNot(Token::l_paren))
+      return success();
+    [[fallthrough]];
+  case Delimiter::Paren:
+    if (parseToken(Token::l_paren, "expected '('" + contextMessage))
+      return failure();
+    // Check for empty list.
+    if (consumeIf(Token::r_paren))
+      return success();
+    break;
+  case Delimiter::OptionalLessGreater:
+    // Check for absent list.
+    if (getToken().isNot(Token::less))
+      return success();
+    [[fallthrough]];
+  case Delimiter::LessGreater:
+    // A missing '<' has already been diagnosed by parseToken; propagate the
+    // failure like every other delimiter does.
+    if (parseToken(Token::less, "expected '<'" + contextMessage))
+      return failure();
+    // Check for empty list.
+    if (consumeIf(Token::greater))
+      return success();
+    break;
+  case Delimiter::OptionalSquare:
+    if (getToken().isNot(Token::l_square))
+      return success();
+    [[fallthrough]];
+  case Delimiter::Square:
+    if (parseToken(Token::l_square, "expected '['" + contextMessage))
+      return failure();
+    // Check for empty list.
+    if (consumeIf(Token::r_square))
+      return success();
+    break;
+  case Delimiter::OptionalBraces:
+    if (getToken().isNot(Token::l_brace))
+      return success();
+    [[fallthrough]];
+  case Delimiter::Braces:
+    if (parseToken(Token::l_brace, "expected '{'" + contextMessage))
+      return failure();
+    // Check for empty list.
+    if (consumeIf(Token::r_brace))
+      return success();
+    break;
+  }
+
+  // Non-empty case starts with an element.
+  if (parseElementFn())
+    return failure();
+
+  // Otherwise we have a list of comma separated elements.
+  while (consumeIf(Token::comma)) {
+    if (parseElementFn())
+      return failure();
+  }
+
+  switch (delimiter) {
+  case Delimiter::None:
+    return success();
+  case Delimiter::OptionalParen:
+  case Delimiter::Paren:
+    return parseToken(Token::r_paren, "expected ')'" + contextMessage);
+  case Delimiter::OptionalLessGreater:
+  case Delimiter::LessGreater:
+    return parseToken(Token::greater, "expected '>'" + contextMessage);
+  case Delimiter::OptionalSquare:
+  case Delimiter::Square:
+    return parseToken(Token::r_square, "expected ']'" + contextMessage);
+  case Delimiter::OptionalBraces:
+  case Delimiter::Braces:
+    return parseToken(Token::r_brace, "expected '}'" + contextMessage);
+  }
+  llvm_unreachable("Unknown delimiter");
+}
+
+//===----------------------------------------------------------------------===//
+// Parse error emitters
+//===----------------------------------------------------------------------===//
+ParseResult ParserImpl::emitError(SMLoc loc, const Twine &message) {
+ // If we hit a parse error in response to a lexer error, then the lexer
+ // already reported the error.
+ if (!getToken().is(Token::error))
+ state.sourceMgr.PrintMessage(loc, SourceMgr::DK_Error, message);
+ return failure();
+}
+
+ParseResult ParserImpl::emitError(const Twine &message) {
+ SMLoc loc = state.curToken.getLoc();
+ if (state.curToken.isNot(Token::eof))
+ return emitError(loc, message);
+
+ // If the error is to be emitted at EOF, move it back one character.
+ return emitError(SMLoc::getFromPointer(loc.getPointer() - 1), message);
+}
+
+/// Emit an error about a "wrong token". If the current token is at the
+/// start of a source line, this will apply heuristics to back up and report
+/// the error at the end of the previous line, which is where the expected
+/// token is supposed to be. Always returns failure.
+ParseResult ParserImpl::emitWrongTokenError(const Twine &message) {
+  SMLoc loc = state.curToken.getLoc();
+  const char *bufferStart = state.lex.getBufferBegin();
+
+  // If the error is to be emitted at EOF, move it back one character. Skip
+  // this for an empty input: there is no character to move back to, and a
+  // pointer before `bufferStart` would make the length below negative.
+  if (state.curToken.is(Token::eof) && loc.getPointer() != bufferStart)
+    loc = SMLoc::getFromPointer(loc.getPointer() - 1);
+
+  // This is the location we were originally asked to report the error at.
+  SMLoc originalLoc = loc;
+
+  // Determine if the token is at the start of the current line.
+  const char *curPtr = loc.getPointer();
+
+  // Use this StringRef to keep track of what we are going to back up through,
+  // it provides nicer string search functions etc.
+  StringRef startOfBuffer(bufferStart, curPtr - bufferStart);
+
+  // Back up over entirely blank lines.
+  while (true) {
+    // Back up until we see a \n, but don't look past the buffer start.
+    startOfBuffer = startOfBuffer.rtrim(" \t");
+
+    // For tokens with no preceding source line, just emit at the original
+    // location.
+    if (startOfBuffer.empty())
+      return emitError(originalLoc, message);
+
+    // If we found something that isn't the end of line, then we're done.
+    if (startOfBuffer.back() != '\n' && startOfBuffer.back() != '\r')
+      return emitError(SMLoc::getFromPointer(startOfBuffer.end()), message);
+
+    // Drop the \n so we emit the diagnostic at the end of the line.
+    startOfBuffer = startOfBuffer.drop_back();
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// Affine Expression Parser
+//===----------------------------------------------------------------------===//
+/// Returns true if `token` can be used as an identifier. Any keyword
+/// qualifies. Among non-keyword tokens, only `bare_identifier` and `inttype`
+/// can represent an identifier.
+static bool isIdentifier(const Token &token) {
+  if (token.isKeyword())
+    return true;
+  return token.isAny(Token::bare_identifier, Token::inttype);
+}
+
+/// Parse a use of a previously declared dimension or symbol name.
+///
+///   affine-expr ::= bare-id
+///
+/// Returns a freshly allocated dim/symbol node, or null after emitting a
+/// diagnostic.
+AffineExpr ParserImpl::parseBareIdExpr() {
+  if (!isIdentifier(getToken())) {
+    std::ignore = emitWrongTokenError("expected bare identifier");
+    return nullptr;
+  }
+
+  StringRef spelling = getTokenSpelling();
+  for (const auto &entry : dimsAndSymbols) {
+    if (entry.first == spelling) {
+      consumeToken();
+      // A name may be used many times, so every use gets its own node
+      // copied from the declaration.
+      if (const auto *dim = std::get_if<AffineDimExpr>(&entry.second))
+        return std::make_unique<AffineDimExpr>(*dim);
+      return std::make_unique<AffineSymbolExpr>(
+          std::get<AffineSymbolExpr>(entry.second));
+    }
+  }
+
+  std::ignore = emitWrongTokenError("use of undeclared identifier");
+  return nullptr;
+}
+
+/// Parse a parenthesized affine expression.
+///
+///   affine-expr ::= `(` affine-expr `)`
+///
+/// Returns null after emitting a diagnostic on failure.
+AffineExpr ParserImpl::parseParentheticalExpr() {
+  if (parseToken(Token::l_paren, "expected '('"))
+    return nullptr;
+
+  // Diagnose `()` specifically rather than falling through to the generic
+  // "expected affine expression" error.
+  if (getToken().is(Token::r_paren)) {
+    std::ignore = emitError("no expression inside parentheses");
+    return nullptr;
+  }
+
+  AffineExpr inner = parseAffineExpr();
+  if (!inner)
+    return nullptr;
+  if (parseToken(Token::r_paren, "expected ')'"))
+    return nullptr;
+  return inner;
+}
+
+/// Parse a negated operand and build it as a multiplication by -1.
+///
+///   affine-expr ::= `-` affine-expr
+///
+/// `lhs` is forwarded to the operand parser for its diagnostics.
+AffineExpr ParserImpl::parseNegateExpression(const AffineExpr &lhs) {
+  if (parseToken(Token::minus, "expected '-'"))
+    return nullptr;
+
+  // Negation binds tighter than every binary operator (but looser than
+  // parentheses), so only a single operand is parsed, not a full expression.
+  if (AffineExpr operand = parseAffineOperandExpr(lhs))
+    return -1 * std::move(operand);
+
+  // parseAffineOperandExpr has already complained. This extra message makes
+  // the diagnostic clearer.
+  std::ignore = emitError("missing operand of negation");
+  return nullptr;
+}
+
+/// Parse a non-negative integer literal appearing in an affine expression.
+///
+///   affine-expr ::= integer-literal
+///
+/// Literals that do not fit in a non-negative int64_t are rejected.
+AffineExpr ParserImpl::parseIntegerExpr() {
+  std::optional<uint64_t> value = getToken().getUInt64IntegerValue();
+  bool fits = value.has_value() && static_cast<int64_t>(*value) >= 0;
+  if (!fits) {
+    std::ignore = emitError("constant too large for index");
+    return nullptr;
+  }
+
+  consumeToken(Token::integer);
+  return std::make_unique<AffineConstantExpr>(static_cast<int64_t>(*value));
+}
+
+/// Parse an operand of an affine expression: an integer literal, a
+/// parenthesized expression, a negation, or an identifier.
+///
+/// `lhs`, when non-null, is the already parsed left operand of the binary
+/// operator whose right operand is being parsed. It only selects which
+/// "missing operand" diagnostic is emitted.
+///
+/// Example: in `i + j + k + l` each identifier is an operand. In `i + j*k + l`,
+/// `j*k` is an op expression handled by parseAffineHighPrecOpExpr. In
+/// `i + (j*k) + -l`, both `(j*k)` and `-l` are operands parsed here.
+AffineExpr ParserImpl::parseAffineOperandExpr(const AffineExpr &lhs) {
+  switch (getToken().getKind()) {
+  case Token::integer:
+    return parseIntegerExpr();
+  case Token::l_paren:
+    return parseParentheticalExpr();
+  case Token::minus:
+    return parseNegateExpression(lhs);
+  case Token::kw_ceildiv:
+  case Token::kw_floordiv:
+  case Token::kw_mod:
+    // These keywords double as identifier names.
+    return parseBareIdExpr();
+  case Token::plus:
+  case Token::star:
+    // A binary operator appeared where an operand was expected.
+    std::ignore = emitError(lhs ? "missing right operand of binary operator"
+                                : "missing left operand of binary operator");
+    return nullptr;
+  default:
+    break;
+  }
+
+  // Any other token may still be usable as an identifier.
+  if (isIdentifier(getToken()))
+    return parseBareIdExpr();
+
+  std::ignore = emitError(lhs ? "missing right operand of binary operator"
+                              : "expected affine expression");
+  return nullptr;
+}
+
+/// Build `lhs op rhs` for a high-precedence operator (mul, floordiv, ceildiv,
+/// mod) after checking that the result is affine. For multiplication, at
+/// least one operand must be symbolic or constant. For the division-like
+/// operators, the right operand must be. On violation, an error is reported
+/// at `opLoc` (the operator's location) and null is returned.
+AffineExpr ParserImpl::getAffineBinaryOpExpr(AffineHighPrecOp op,
+                                             AffineExpr &&lhs, AffineExpr &&rhs,
+                                             SMLoc opLoc) {
+  switch (op) {
+  case Mul:
+    if (lhs->isSymbolicOrConstant() || rhs->isSymbolicOrConstant())
+      return std::move(lhs) * std::move(rhs);
+    std::ignore = emitError(
+        opLoc, "non-affine expression: at least one of the multiply "
+               "operands has to be either a constant or symbolic");
+    return nullptr;
+  case FloorDiv:
+    if (rhs->isSymbolicOrConstant())
+      return floorDiv(std::move(lhs), std::move(rhs));
+    std::ignore =
+        emitError(opLoc, "non-affine expression: right operand of floordiv "
+                         "has to be either a constant or symbolic");
+    return nullptr;
+  case CeilDiv:
+    if (rhs->isSymbolicOrConstant())
+      return ceilDiv(std::move(lhs), std::move(rhs));
+    std::ignore =
+        emitError(opLoc, "non-affine expression: right operand of ceildiv "
+                         "has to be either a constant or symbolic");
+    return nullptr;
+  case Mod:
+    if (rhs->isSymbolicOrConstant())
+      return std::move(lhs) % std::move(rhs);
+    std::ignore =
+        emitError(opLoc, "non-affine expression: right operand of mod "
+                         "has to be either a constant or symbolic");
+    return nullptr;
+  case HNoOp:
+    llvm_unreachable("can't create affine expression for null high prec op");
+  }
+  llvm_unreachable("Unknown AffineHighPrecOp");
+}
+
+/// Build `lhs op rhs` for a low-precedence operator (add or sub). Every sum
+/// or difference of affine expressions is affine, so no checks are needed.
+AffineExpr ParserImpl::getAffineBinaryOpExpr(AffineLowPrecOp op,
+                                             AffineExpr &&lhs,
+                                             AffineExpr &&rhs) {
+  switch (op) {
+  case AffineLowPrecOp::Sub:
+    return std::move(lhs) - std::move(rhs);
+  case AffineLowPrecOp::Add:
+    return std::move(lhs) + std::move(rhs);
+  case AffineLowPrecOp::LNoOp:
+    llvm_unreachable("can't create affine expression for null low prec op");
+  }
+  llvm_unreachable("Unknown AffineLowPrecOp");
+}
+
+/// If the current token is a low-precedence operator (`+` or `-`), consume it
+/// and return the operator. Otherwise consume nothing and return LNoOp.
+AffineLowPrecOp ParserImpl::consumeIfLowPrecOp() {
+  if (consumeIf(Token::plus))
+    return AffineLowPrecOp::Add;
+  if (consumeIf(Token::minus))
+    return AffineLowPrecOp::Sub;
+  return AffineLowPrecOp::LNoOp;
+}
+
+/// If the current token is a high-precedence operator (`*`, `floordiv`,
+/// `ceildiv` or `mod`), consume it and return the operator. Otherwise consume
+/// nothing and return HNoOp.
+AffineHighPrecOp ParserImpl::consumeIfHighPrecOp() {
+  if (consumeIf(Token::star))
+    return Mul;
+  if (consumeIf(Token::kw_floordiv))
+    return FloorDiv;
+  if (consumeIf(Token::kw_ceildiv))
+    return CeilDiv;
+  if (consumeIf(Token::kw_mod))
+    return Mod;
+  return HNoOp;
+}
+
+/// Parse a high precedence op expression list: mul, div, and mod are high
+/// precedence binary ops, i.e., parse a
+///   expr_1 op_1 expr_2 op_2 ... expr_n
+/// where op_1, op_2 are all a AffineHighPrecOp (mul, div, mod).
+/// All affine binary ops are left associative.
+/// Given llhs, returns (llhs llhsOp lhs) op rhs, or (lhs op rhs) if llhs is
+/// null. If no rhs can be found, returns (llhs llhsOp lhs) or lhs if llhs is
+/// null. llhsOpLoc is the location of the llhsOp token. It is used to report
+/// an error when (llhs llhsOp lhs) is not affine.
+AffineExpr ParserImpl::parseAffineHighPrecOpExpr(AffineExpr &&llhs,
+                                                 AffineHighPrecOp llhsOp,
+                                                 SMLoc llhsOpLoc) {
+  AffineExpr lhs = parseAffineOperandExpr(llhs);
+  if (!lhs)
+    return nullptr;
+
+  // Found an LHS. Parse the remaining expression.
+  SMLoc opLoc = getToken().getLoc();
+  if (AffineHighPrecOp op = consumeIfHighPrecOp()) {
+    if (llhs) {
+      // Fold (llhs llhsOp lhs) first for left associativity. A diagnostic
+      // about llhsOp must point at llhsOp itself, not at the operator that
+      // follows it.
+      AffineExpr expr = getAffineBinaryOpExpr(llhsOp, std::move(llhs),
+                                              std::move(lhs), llhsOpLoc);
+      if (!expr)
+        return nullptr;
+      return parseAffineHighPrecOpExpr(std::move(expr), op, opLoc);
+    }
+    // No LLHS, get RHS
+    return parseAffineHighPrecOpExpr(std::move(lhs), op, opLoc);
+  }
+
+  // This is the last operand in this expression.
+  if (llhs)
+    return getAffineBinaryOpExpr(llhsOp, std::move(llhs), std::move(lhs),
+                                 llhsOpLoc);
+
+  // No llhs, 'lhs' itself is the expression.
+  return lhs;
+}
+
+/// Parse affine expressions built from bare-ids, integer constants,
+/// parenthesized affine expressions, and binary operators combining them.
+///
+/// All binary operators associate left to right. Add and sub share the lower
+/// precedence level. Mul, floordiv, ceildiv and mod share the higher level.
+/// Negation binds tighter than any binary operator.
+///
+/// `llhs` is the expression already parsed to the left, joined by `llhsOp`.
+/// This returns ((llhs llhsOp lhs) op rhs) when llhs is non-null, and
+/// (lhs op rhs) otherwise. Without an rhs, it returns (llhs llhsOp lhs), or
+/// just lhs when llhs is null. Threading llhs through the recursion is what
+/// yields left associativity.
+///
+/// Example: for e1 + e2*e3 + e4 with e1 as llhs, the result is equivalent to
+/// (e1 + (e2*e3)) + e4, where (e2*e3) comes from parseAffineHighPrecOpExpr().
+AffineExpr ParserImpl::parseAffineLowPrecOpExpr(AffineExpr &&llhs,
+                                                AffineLowPrecOp llhsOp) {
+  // Attach `operand` to llhs with llhsOp, or pass it through if there is no
+  // llhs.
+  auto combineWithLLHS = [&](AffineExpr &&operand) -> AffineExpr {
+    if (!llhs)
+      return std::move(operand);
+    return getAffineBinaryOpExpr(llhsOp, std::move(llhs), std::move(operand));
+  };
+
+  AffineExpr lhs = parseAffineOperandExpr(llhs);
+  if (!lhs)
+    return nullptr;
+
+  // A low-precedence operator follows: fold what we have so far and continue
+  // with the next operand.
+  if (AffineLowPrecOp lOp = consumeIfLowPrecOp())
+    return parseAffineLowPrecOpExpr(combineWithLLHS(std::move(lhs)), lOp);
+
+  // A high-precedence operator follows. It claims `lhs` before any pending
+  // low-precedence operator is applied.
+  SMLoc opLoc = getToken().getLoc();
+  if (AffineHighPrecOp hOp = consumeIfHighPrecOp()) {
+    AffineExpr highRes = parseAffineHighPrecOpExpr(std::move(lhs), hOp, opLoc);
+    if (!highRes)
+      return nullptr;
+
+    AffineExpr expr = combineWithLLHS(std::move(highRes));
+
+    // Continue with any low-precedence operators after the product.
+    if (AffineLowPrecOp nextOp = consumeIfLowPrecOp())
+      return parseAffineLowPrecOpExpr(std::move(expr), nextOp);
+    return expr;
+  }
+
+  // Last operand in the expression list.
+  return combineWithLLHS(std::move(lhs));
+}
+
+/// Parse an affine expression.
+///   affine-expr ::= `(` affine-expr `)`
+///                 | `-` affine-expr
+///                 | affine-expr `+` affine-expr
+///                 | affine-expr `-` affine-expr
+///                 | affine-expr `*` affine-expr
+///                 | affine-expr `floordiv` affine-expr
+///                 | affine-expr `ceildiv` affine-expr
+///                 | affine-expr `mod` affine-expr
+///                 | bare-id
+///                 | integer-literal
+///
+/// Some productions impose extra affinity conditions. One operand of `*` must
+/// be constant or symbolic, and the right operand of floordiv, ceildiv and
+/// mod must be constant or symbolic. Returns null on failure.
+AffineExpr ParserImpl::parseAffineExpr() {
+  // Start at the lowest precedence level with nothing on the left.
+  AffineExpr noLHS = nullptr;
+  return parseAffineLowPrecOpExpr(std::move(noLHS), AffineLowPrecOp::LNoOp);
+}
+
+/// Parse the declaration of one dimension or symbol name from the identifier
+/// lists that precede an affine map or integer set, and record it together
+/// with `idExpr`. Dims and symbols share one namespace, so a name declared
+/// in either list cannot be declared again.
+ParseResult ParserImpl::parseIdentifierDefinition(
+    std::variant<AffineDimExpr, AffineSymbolExpr> idExpr) {
+  if (!isIdentifier(getToken()))
+    return emitWrongTokenError("expected bare identifier");
+
+  StringRef name = getTokenSpelling();
+  for (const auto &entry : dimsAndSymbols)
+    if (entry.first == name)
+      return emitError("redefinition of identifier '" + name + "'");
+
+  consumeToken();
+  dimsAndSymbols.emplace_back(name, std::move(idExpr));
+  return success();
+}
+
+/// Parse the parenthesized list of dimensional identifiers. `numDims` is
+/// incremented once per declared dimension, and each dimension is numbered
+/// by its position.
+ParseResult ParserImpl::parseDimIdList(unsigned &numDims) {
+  return parseCommaSeparatedList(
+      Delimiter::Paren,
+      [&]() -> ParseResult {
+        return parseIdentifierDefinition(AffineDimExpr(numDims++));
+      },
+      " in dimensional identifier list");
+}
+
+/// Parse the square-bracketed list of symbolic identifiers. `numSymbols` is
+/// incremented once per declared symbol, and each symbol is numbered by its
+/// position.
+ParseResult ParserImpl::parseSymbolIdList(unsigned &numSymbols) {
+  return parseCommaSeparatedList(
+      Delimiter::Square,
+      [&]() -> ParseResult {
+        return parseIdentifierDefinition(AffineSymbolExpr(numSymbols++));
+      },
+      " in symbol list");
+}
+
+/// Parse the parenthesized dimension list, followed by an optional
+/// square-bracketed symbol list. `numSymbols` is set to zero when the symbol
+/// list is absent.
+ParseResult ParserImpl::parseDimAndOptionalSymbolIdList(unsigned &numDims,
+                                                        unsigned &numSymbols) {
+  if (parseDimIdList(numDims))
+    return failure();
+  if (getToken().is(Token::l_square))
+    return parseSymbolIdList(numSymbols);
+  numSymbols = 0;
+  return success();
+}
+
+/// Parse the result list of an inline affine map definition.
+///
+///   affine-map ::= dim-and-symbol-id-lists `->` multi-dim-affine-expr
+///
+///   multi-dim-affine-expr ::= `(` `)`
+///                           | `(` affine-expr (`,` affine-expr)* `)`
+///
+/// Returns std::nullopt if any expression or delimiter fails to parse.
+std::optional<AffineMap> ParserImpl::parseAffineMapRange(unsigned numDims,
+                                                         unsigned numSymbols) {
+  SmallVector<AffineExpr, 4> exprs;
+  auto parseElt = [&]() -> ParseResult {
+    exprs.emplace_back(parseAffineExpr());
+    // A null entry signals a parse failure; the whole list is then dropped.
+    return exprs.back() ? success() : failure();
+  };
+
+  if (parseCommaSeparatedList(Delimiter::Paren, parseElt,
+                              " in affine map range"))
+    return std::nullopt;
+
+  return AffineMap(numDims, numSymbols, std::move(exprs));
+}
+
+/// Parse an affine constraint.
+///   affine-constraint ::= affine-expr `>=` affine-expr
+///                       | affine-expr `<=` affine-expr
+///                       | affine-expr `==` affine-expr
+///
+/// The constraint is normalized to
+///   affine-constraint ::= affine-expr `>=` `0`
+///                       | affine-expr `==` `0`
+/// before returning.
+///
+/// isEq is set to true if the parsed constraint is an equality, false if it
+/// is an inequality (greater than or equal). Returns null on failure.
+AffineExpr ParserImpl::parseAffineConstraint(bool *isEq) {
+  AffineExpr lhsExpr = parseAffineExpr();
+  if (!lhsExpr)
+    return nullptr;
+
+  // Each relational operator is matched as two tokens. Commit to an operator
+  // based on its first token, so a malformed operator such as `><=` is
+  // rejected instead of being retried as a different operator after its
+  // first token was already consumed.
+  enum class Relation { GreaterEq, LessEq, Equal, Invalid };
+  Relation rel = Relation::Invalid;
+  if (consumeIf(Token::greater))
+    rel = consumeIf(Token::equal) ? Relation::GreaterEq : Relation::Invalid;
+  else if (consumeIf(Token::less))
+    rel = consumeIf(Token::equal) ? Relation::LessEq : Relation::Invalid;
+  else if (consumeIf(Token::equal))
+    rel = consumeIf(Token::equal) ? Relation::Equal : Relation::Invalid;
+
+  if (rel == Relation::Invalid) {
+    std::ignore =
+        emitError("expected '== affine-expr' or '>= affine-expr' at end of "
+                  "affine constraint");
+    return nullptr;
+  }
+
+  AffineExpr rhsExpr = parseAffineExpr();
+  if (!rhsExpr)
+    return nullptr;
+
+  switch (rel) {
+  case Relation::GreaterEq: // lhs >= rhs  <=>  lhs - rhs >= 0
+    *isEq = false;
+    return std::move(lhsExpr) - std::move(rhsExpr);
+  case Relation::LessEq: // lhs <= rhs  <=>  rhs - lhs >= 0
+    *isEq = false;
+    return std::move(rhsExpr) - std::move(lhsExpr);
+  case Relation::Equal: // lhs == rhs  <=>  lhs - rhs == 0
+    *isEq = true;
+    return std::move(lhsExpr) - std::move(rhsExpr);
+  case Relation::Invalid:
+    break;
+  }
+  llvm_unreachable("invalid relation is diagnosed above");
+}
+
+/// Parse the parenthesized constraint list of an integer set, i.e. the part
+/// that follows `:`:
+///
+///   integer-set-inline ::= dim-and-symbol-id-lists `:`
+///                          `(` affine-constraint-conjunction? `)`
+///   affine-constraint-conjunction ::= affine-constraint
+///                                     (`,` affine-constraint)*
+///
+/// An empty list denotes the always-true set, encoded as the single equality
+/// `0 == 0`. Returns std::nullopt if any constraint fails to parse.
+std::optional<IntegerSet>
+ParserImpl::parseIntegerSetConstraints(unsigned numDims, unsigned numSymbols) {
+  SmallVector<AffineExpr, 4> exprs;
+  SmallVector<bool, 4> eqFlags;
+  auto parseConstraint = [&]() -> ParseResult {
+    bool isEq = false;
+    AffineExpr expr = parseAffineConstraint(&isEq);
+    if (!expr)
+      return failure();
+    exprs.emplace_back(std::move(expr));
+    eqFlags.push_back(isEq);
+    return success();
+  };
+
+  if (parseCommaSeparatedList(Delimiter::Paren, parseConstraint,
+                              " in integer set constraint list"))
+    return std::nullopt;
+
+  if (!exprs.empty())
+    return IntegerSet(numDims, numSymbols, std::move(exprs),
+                      std::move(eqFlags));
+
+  // No constraints at all: represent the always-true set as `0 == 0`.
+  return IntegerSet(numDims, numSymbols,
+                    std::make_unique<AffineConstantExpr>(0), true);
+}
+
+/// Parse either an affine map or an integer set. The two share a common
+/// prefix of identifier lists, and the token after it decides which is read:
+///
+///   affine-map  ::= dim-and-symbol-id-lists `->` multi-dim-affine-expr
+///   integer-set ::= dim-and-symbol-id-lists `:` `(` constraints? `)`
+///
+/// On any parse error the returned variant holds std::nullopt.
+std::variant<AffineMap, IntegerSet, std::nullopt_t>
+ParserImpl::parseAffineMapOrIntegerSet() {
+  unsigned numDims = 0;
+  unsigned numSymbols = 0;
+  if (parseDimAndOptionalSymbolIdList(numDims, numSymbols))
+    return std::nullopt;
+
+  // `->` introduces the result list of an affine map.
+  if (consumeIf(Token::arrow)) {
+    std::optional<AffineMap> map = parseAffineMapRange(numDims, numSymbols);
+    if (!map)
+      return std::nullopt;
+    return std::move(*map);
+  }
+
+  // Anything else must be an integer set, introduced by `:`.
+  if (parseToken(Token::colon, "expected '->' or ':'"))
+    return std::nullopt;
+
+  std::optional<IntegerSet> set =
+      parseIntegerSetConstraints(numDims, numSymbols);
+  if (!set)
+    return std::nullopt;
+  return std::move(*set);
+}
+
+/// Convert a parsed AffineMap into a MultiAffineFunction.
+///
+/// Each result expression of `map` is flattened into one row of a coefficient
+/// matrix. A row has one column per map input, one per division (local)
+/// variable introduced by flattening, and one for the constant term. The
+/// division representations are recovered from the constraint system that
+/// flattening fills in.
+static MultiAffineFunction getMultiAffineFunctionFromMap(const AffineMap &map) {
+  IntegerPolyhedron cst(presburger::PresburgerSpace::getSetSpace(0, 0, 0));
+  std::vector<SmallVector<int64_t, 8>> flattenedExprs;
+
+  // Flatten expressions and add them to the constraint system.
+  LogicalResult result = getFlattenedAffineExprs(map, flattenedExprs, cst);
+  // `result` is only read by the assert below; silence unused-variable
+  // warnings in builds where asserts are compiled out (NDEBUG).
+  (void)result;
+  assert(result.succeeded() && "Unable to get flattened affine exprs");
+
+  DivisionRepr divs = cst.getLocalReprs();
+  assert(divs.hasAllReprs() &&
+         "AffineMap cannot produce divs without local representation");
+
+  // TODO: We shouldn't have to do this conversion.
+  Matrix<DynamicAPInt> mat(map.getNumExprs(),
+                           map.getNumInputs() + divs.getNumDivs() + 1);
+  for (unsigned i = 0, e = flattenedExprs.size(); i < e; ++i)
+    for (unsigned j = 0, f = flattenedExprs[i].size(); j < f; ++j)
+      mat(i, j) = flattenedExprs[i][j];
+
+  return MultiAffineFunction(
+      PresburgerSpace::getRelationSpace(map.getNumDims(), map.getNumExprs(),
+                                        map.getNumSymbols(), divs.getNumDivs()),
+      mat, divs);
+}
+
+/// Convert a parsed IntegerSet into an IntegerPolyhedron over the set's
+/// dimensions and symbols.
+///
+/// Each constraint is flattened into a row of coefficients and added as an
+/// equality or inequality according to the set's eqFlags. Local variables
+/// introduced while flattening are inserted as locals of the polyhedron. The
+/// extra constraints flattening generated for them are appended at the end.
+static IntegerPolyhedron getPolyhedronFromSet(const IntegerSet &set) {
+  IntegerPolyhedron cst(presburger::PresburgerSpace::getSetSpace(0, 0, 0));
+  std::vector<SmallVector<int64_t, 8>> flattenedExprs;
+
+  // Flatten expressions and add them to the constraint system.
+  LogicalResult result = getFlattenedAffineExprs(set, flattenedExprs, cst);
+  // `result` is only read by the assert below; silence unused-variable
+  // warnings in builds where asserts are compiled out (NDEBUG).
+  (void)result;
+  assert(result.succeeded() && "Unable to get flattened affine exprs");
+  assert(flattenedExprs.size() == set.getNumConstraints());
+
+  unsigned numInequalities = set.getNumInequalities();
+  unsigned numEqualities = set.getNumEqualities();
+  unsigned numDims = set.getNumDims();
+  unsigned numSymbols = set.getNumSymbols();
+  unsigned numReservedCols = numDims + numSymbols + 1;
+  IntegerPolyhedron poly(
+      numInequalities, numEqualities, numReservedCols,
+      presburger::PresburgerSpace::getSetSpace(numDims, numSymbols, 0));
+  assert(numReservedCols >= poly.getSpace().getNumVars() + 1);
+
+  // Make room for the locals created by flattening so that each flattened row
+  // has exactly one column per variable plus the constant.
+  poly.insertVar(VarKind::Local, poly.getNumVarKind(VarKind::Local),
+                 /*num=*/cst.getNumLocalVars());
+
+  for (unsigned i = 0, e = flattenedExprs.size(); i < e; ++i) {
+    const auto &flatExpr = flattenedExprs[i];
+    assert(flatExpr.size() == poly.getSpace().getNumVars() + 1);
+    if (set.eqFlags[i])
+      poly.addEquality(flatExpr);
+    else
+      poly.addInequality(flatExpr);
+  }
+  // Add the other constraints involving local vars from flattening.
+  poly.append(cst);
+
+  return poly;
+}
+
+/// Parse `str` as either an affine map or an integer set.
+///
+/// The text is wrapped in a memory buffer without copying and registered with
+/// a SourceMgr local to this call, so `str` must stay alive while parsing.
+/// The variant holds std::nullopt if parsing fails.
+static std::variant<AffineMap, IntegerSet, std::nullopt_t>
+parseAffineMapOrIntegerSet(StringRef str) {
+  SourceMgr sourceMgr;
+  sourceMgr.AddNewSourceBuffer(
+      MemoryBuffer::getMemBuffer(str, "<mlir_parser_buffer>",
+                                 /*RequiresNullTerminator=*/false),
+      SMLoc());
+  ParserState state(sourceMgr);
+  return ParserImpl(state).parseAffineMapOrIntegerSet();
+}
+
+static AffineMap parseAffineMap(StringRef str) {
+ std::variant<AffineMap, IntegerSet, std::nullopt_t> v =
+ parseAffineMapOrIntegerSet(str);
+ if (std::holds_alternative<AffineMap>(v))
+ return std::move(std::get<AffineMap>(v));
+ llvm_unreachable("expected string to represent AffineMap");
+}
+
+static IntegerSet parseIntegerSet(StringRef str) {
+ std::variant<AffineMap, IntegerSet, std::nullopt_t> v =
+ parseAffineMapOrIntegerSet(str);
+ if (std::holds_alternative<IntegerSet>(v))
+ return std::move(std::get<IntegerSet>(v));
+ llvm_unreachable("expected string to represent IntegerSet");
+}
+
+namespace mlir::presburger {
+/// Parse `str`, which is expected to describe an integer set, and convert it
+/// into an IntegerPolyhedron.
+IntegerPolyhedron parseIntegerPolyhedron(StringRef str) {
+  const IntegerSet set = parseIntegerSet(str);
+  return getPolyhedronFromSet(set);
+}
+
+/// Parse `str`, which is expected to describe an affine map, and convert it
+/// into a MultiAffineFunction.
+MultiAffineFunction parseMultiAffineFunction(StringRef str) {
+  const AffineMap map = parseAffineMap(str);
+  return getMultiAffineFunctionFromMap(map);
+}
+} // namespace mlir::presburger
diff --git a/mlir/lib/Analysis/Presburger/Parser/ParserImpl.h b/mlir/lib/Analysis/Presburger/Parser/ParserImpl.h
new file mode 100644
index 0000000000000..c168d09826d51
--- /dev/null
+++ b/mlir/lib/Analysis/Presburger/Parser/ParserImpl.h
@@ -0,0 +1,237 @@
+//===- ParserImpl.h - Presburger Parser Implementation ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_PRESBURGER_PARSER_PARSERIMPL_H
+#define MLIR_ANALYSIS_PRESBURGER_PARSER_PARSERIMPL_H
+
+#include "ParseStructs.h"
+#include "ParserState.h"
+#include "mlir/Support/LogicalResult.h"
+#include <optional>
+#include <variant>
+
+namespace mlir::presburger {
+/// Shorthand for llvm::function_ref, used for the element-parsing callbacks
+/// taken by ParserImpl.
+template <typename T>
+using function_ref = llvm::function_ref<T>;
+
+/// Delimiters that may surround a comma-separated list, as accepted by
+/// ParserImpl::parseCommaSeparatedList.
+enum class Delimiter {
+  /// No delimiters; the list must contain at least one element.
+  None,
+  /// Parens surrounding zero or more elements.
+  Paren,
+  /// Square brackets surrounding zero or more elements.
+  Square,
+  /// <> brackets surrounding zero or more elements.
+  LessGreater,
+  /// {} brackets surrounding zero or more elements.
+  Braces,
+  /// Parens surrounding zero or more elements, or no list at all.
+  OptionalParen,
+  /// Square brackets surrounding zero or more elements, or no list at all.
+  OptionalSquare,
+  /// <> brackets surrounding zero or more elements, or no list at all.
+  OptionalLessGreater,
+  /// {} brackets surrounding zero or more elements, or no list at all.
+  OptionalBraces,
+};
+
+/// Additive binary operators. They all share one precedence level and bind
+/// more loosely than AffineHighPrecOp. LNoOp is explicitly zero so that "no
+/// operator" tests false when used as a boolean.
+enum AffineLowPrecOp {
+  /// No operator (false in a boolean context).
+  LNoOp = 0,
+  Add,
+  Sub
+};
+
+/// Multiplicative-style binary operators. They all share one precedence level
+/// and bind more tightly than AffineLowPrecOp. HNoOp is explicitly zero so
+/// that "no operator" tests false when used as a boolean.
+enum AffineHighPrecOp {
+  /// No operator (false in a boolean context).
+  HNoOp = 0,
+  Mul,
+  FloorDiv,
+  CeilDiv,
+  Mod
+};
+
+//===----------------------------------------------------------------------===//
+// Parser
+//===----------------------------------------------------------------------===//
+
+/// Parser for the textual form of affine maps and integer sets, such as
+/// `(x)[N] -> (x + N)` or `(x)[N] : (x - N >= 0)`.
+///
+/// Lexer and token state live in the ParserState passed to the constructor.
+/// Only a reference to it is kept, so the state must outlive this object.
+class ParserImpl {
+public:
+  explicit ParserImpl(ParserState &state) : state(state) {}
+
+  /// Access the shared parser state.
+  ParserState &getState() const { return state; }
+
+  /// Parse a comma-separated list of elements up until the specified end token.
+  ParseResult
+  parseCommaSeparatedListUntil(Token::Kind rightToken,
+                               function_ref<ParseResult()> parseElement,
+                               bool allowEmptyList = true);
+
+  /// Parse a list of comma-separated items with an optional delimiter. If a
+  /// delimiter is provided, then an empty list is allowed. If not, then at
+  /// least one element will be parsed.
+  ParseResult
+  parseCommaSeparatedList(Delimiter delimiter,
+                          function_ref<ParseResult()> parseElementFn,
+                          StringRef contextMessage = StringRef());
+
+  /// Parse a comma separated list of elements that must have at least one entry
+  /// in it.
+  ParseResult
+  parseCommaSeparatedList(function_ref<ParseResult()> parseElementFn) {
+    return parseCommaSeparatedList(Delimiter::None, parseElementFn);
+  }
+
+  // We have two forms of parsing methods - those that return a non-null
+  // pointer on success, and those that return a ParseResult to indicate whether
+  // they returned a failure. The second class fills in by-reference arguments
+  // as the results of their action.
+
+  //===--------------------------------------------------------------------===//
+  // Error Handling
+  //===--------------------------------------------------------------------===//
+
+  /// Emit an error and return failure.
+  ParseResult emitError(const Twine &message = {});
+  ParseResult emitError(SMLoc loc, const Twine &message = {});
+
+  /// Emit an error about a "wrong token". If the current token is at the
+  /// start of a source line, this will apply heuristics to back up and report
+  /// the error at the end of the previous line, which is where the expected
+  /// token is supposed to be.
+  ParseResult emitWrongTokenError(const Twine &message = {});
+
+  //===--------------------------------------------------------------------===//
+  // Token Parsing
+  //===--------------------------------------------------------------------===//
+
+  /// Return the current token the parser is inspecting.
+  const Token &getToken() const { return state.curToken; }
+  StringRef getTokenSpelling() const { return state.curToken.getSpelling(); }
+
+  /// Return the last parsed token.
+  const Token &getLastToken() const { return state.lastToken; }
+
+  /// If the current token has the specified kind, consume it and return true.
+  /// If not, return false.
+  bool consumeIf(Token::Kind kind) {
+    if (state.curToken.isNot(kind))
+      return false;
+    consumeToken(kind);
+    return true;
+  }
+
+  /// Advance the current lexer onto the next token.
+  void consumeToken() {
+    assert(state.curToken.isNot(Token::eof, Token::error) &&
+           "shouldn't advance past EOF or errors");
+    state.lastToken = state.curToken;
+    state.curToken = state.lex.lexToken();
+  }
+
+  /// Advance the current lexer onto the next token, asserting what the expected
+  /// current token is. This is preferred to the above method because it leads
+  /// to more self-documenting code with better checking.
+  void consumeToken(Token::Kind kind) {
+    assert(state.curToken.is(kind) && "consumed an unexpected token");
+    consumeToken();
+  }
+
+  /// Reset the parser to the given lexer position.
+  void resetToken(const char *tokPos) {
+    state.lex.resetPointer(tokPos);
+    state.lastToken = state.curToken;
+    state.curToken = state.lex.lexToken();
+  }
+
+  /// Consume the specified token if present and return success. On failure,
+  /// output a diagnostic and return failure.
+  ParseResult parseToken(Token::Kind expectedToken, const Twine &message);
+
+  /// Parse an optional integer value from the stream.
+  std::optional<ParseResult> parseOptionalInteger(APInt &result);
+
+  /// Returns true if the current token corresponds to a keyword.
+  bool isCurrentTokenAKeyword() const {
+    return getToken().isAny(Token::bare_identifier, Token::inttype) ||
+           getToken().isKeyword();
+  }
+
+  /// Parse a keyword, if present, into 'keyword'.
+  ParseResult parseOptionalKeyword(StringRef *keyword);
+
+  //===--------------------------------------------------------------------===//
+  // Affine Parsing
+  //===--------------------------------------------------------------------===//
+
+  ParseResult
+  parseAffineExprReference(ArrayRef<std::pair<StringRef, AffineExpr>> symbolSet,
+                           AffineExpr &expr);
+  ParseResult
+  parseAffineExprInline(ArrayRef<std::pair<StringRef, AffineExpr>> symbolSet,
+                        AffineExpr &expr);
+  std::optional<AffineMap> parseAffineMapRange(unsigned numDims,
+                                               unsigned numSymbols);
+  std::optional<IntegerSet> parseIntegerSetConstraints(unsigned numDims,
+                                                       unsigned numSymbols);
+  std::variant<AffineMap, IntegerSet, std::nullopt_t>
+  parseAffineMapOrIntegerSet();
+
+private:
+  // Binary affine op parsing.
+  AffineLowPrecOp consumeIfLowPrecOp();
+  AffineHighPrecOp consumeIfHighPrecOp();
+
+  // Identifier lists for polyhedral structures.
+  ParseResult parseDimIdList(unsigned &numDims);
+  ParseResult parseSymbolIdList(unsigned &numSymbols);
+  ParseResult parseDimAndOptionalSymbolIdList(unsigned &numDims,
+                                              unsigned &numSymbols);
+  ParseResult parseIdentifierDefinition(
+      std::variant<AffineDimExpr, AffineSymbolExpr> idExpr);
+
+  AffineExpr parseAffineExpr();
+  AffineExpr parseParentheticalExpr();
+  AffineExpr parseNegateExpression(const AffineExpr &lhs);
+  AffineExpr parseIntegerExpr();
+  AffineExpr parseBareIdExpr();
+
+  AffineExpr getAffineBinaryOpExpr(AffineHighPrecOp op, AffineExpr &&lhs,
+                                   AffineExpr &&rhs, SMLoc opLoc);
+  AffineExpr getAffineBinaryOpExpr(AffineLowPrecOp op, AffineExpr &&lhs,
+                                   AffineExpr &&rhs);
+  AffineExpr parseAffineOperandExpr(const AffineExpr &lhs);
+  AffineExpr parseAffineLowPrecOpExpr(AffineExpr &&llhs,
+                                      AffineLowPrecOp llhsOp);
+  AffineExpr parseAffineHighPrecOpExpr(AffineExpr &&llhs,
+                                       AffineHighPrecOp llhsOp,
+                                       SMLoc llhsOpLoc);
+  AffineExpr parseAffineConstraint(bool *isEq);
+
+  /// Shared lexer/token state; not owned.
+  ParserState &state;
+  // NOTE(review): parseElement, numDimOperands and numSymbolOperands are not
+  // referenced in the visible part of ParserImpl.cpp; confirm they are still
+  // needed.
+  function_ref<ParseResult(bool)> parseElement;
+  unsigned numDimOperands = 0;
+  unsigned numSymbolOperands = 0;
+  /// Identifier names paired with the dimension or symbol expression each one
+  /// denotes (presumably filled by parseIdentifierDefinition; verify).
+  SmallVector<
+      std::pair<StringRef, std::variant<AffineDimExpr, AffineSymbolExpr>>, 4>
+      dimsAndSymbols;
+};
+} // namespace mlir::presburger
+
+#endif // MLIR_ANALYSIS_PRESBURGER_PARSER_PARSERIMPL_H
diff --git a/mlir/lib/Analysis/Presburger/Parser/ParserState.h b/mlir/lib/Analysis/Presburger/Parser/ParserState.h
new file mode 100644
index 0000000000000..2fad3aa46fbb8
--- /dev/null
+++ b/mlir/lib/Analysis/Presburger/Parser/ParserState.h
@@ -0,0 +1,39 @@
+//===- ParserState.h - MLIR Presburger ParserState --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_PRESBURGER_PARSER_PARSERSTATE_H
+#define MLIR_ANALYSIS_PRESBURGER_PARSER_PARSERSTATE_H
+
+#include "Lexer.h"
+#include "llvm/Support/SourceMgr.h"
+
+namespace mlir::presburger {
+/// State shared by the whole parse: the source manager being read, the lexer
+/// over it, and a one-token window (the lookahead token plus the token most
+/// recently consumed).
+struct ParserState {
+  ParserState(const llvm::SourceMgr &sourceMgr)
+      : sourceMgr(sourceMgr), lex(sourceMgr), curToken(lex.lexToken()) {}
+  ParserState(const ParserState &) = delete;
+  void operator=(const ParserState &) = delete;
+
+  /// Source manager that owns the buffers being parsed.
+  const llvm::SourceMgr &sourceMgr;
+
+  /// Lexer reading from the source manager.
+  Lexer lex;
+
+  /// Lookahead: the next token that has not been consumed yet.
+  Token curToken;
+
+  /// The most recently consumed token; starts out as an empty error token.
+  Token lastToken{Token::error, ""};
+};
+} // namespace mlir::presburger
+
+#endif // MLIR_ANALYSIS_PRESBURGER_PARSER_PARSERSTATE_H
diff --git a/mlir/lib/Analysis/Presburger/Parser/Token.cpp b/mlir/lib/Analysis/Presburger/Parser/Token.cpp
new file mode 100644
index 0000000000000..d675e0662c88a
--- /dev/null
+++ b/mlir/lib/Analysis/Presburger/Parser/Token.cpp
@@ -0,0 +1,64 @@
+//===- Token.cpp - Presburger Token Implementation --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the Token class for the Presburger textual form.
+//
+//===----------------------------------------------------------------------===//
+
+#include "Token.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/SMLoc.h"
+#include <optional>
+
+using namespace mlir::presburger;
+
+/// Location of the first character of the token.
+SMLoc Token::getLoc() const {
+  return SMLoc::getFromPointer(spelling.begin());
+}
+
+/// Location one past the last character of the token.
+SMLoc Token::getEndLoc() const {
+  return SMLoc::getFromPointer(spelling.end());
+}
+
+SMRange Token::getLocRange() const { return {getLoc(), getEndLoc()}; }
+
+/// Decode `spelling` as a base-10 uint64_t. Returns std::nullopt if it is not
+/// a valid decimal number or does not fit in 64 bits.
+std::optional<uint64_t> Token::getUInt64IntegerValue(StringRef spelling) {
+  uint64_t value = 0;
+  // StringRef::getAsInteger reports failure by returning true.
+  bool invalid = spelling.getAsInteger(10, value);
+  if (invalid)
+    return std::nullopt;
+  return value;
+}
+
+/// Return the fixed spelling of a punctuation or keyword token kind. A
+/// keyword's spelling is its own name. Markers, identifiers and literals have
+/// no fixed spelling, and passing one of them is a programming error.
+StringRef Token::getTokenSpelling(Kind kind) {
+  switch (kind) {
+#define TOK_PUNCTUATION(NAME, SPELLING)                                        \
+  case NAME:                                                                   \
+    return SPELLING;
+#define TOK_KEYWORD(SPELLING)                                                  \
+  case kw_##SPELLING:                                                          \
+    return #SPELLING;
+#include "TokenKinds.def"
+  default:
+    break;
+  }
+  llvm_unreachable("This token kind has no fixed spelling");
+}
+
+/// Return true if this token's kind was declared with TOK_KEYWORD in
+/// TokenKinds.def (i.e. one of the `kw_*` kinds, such as kw_floordiv).
+bool Token::isKeyword() const {
+  switch (kind) {
+#define TOK_KEYWORD(SPELLING)                                                  \
+  case kw_##SPELLING:
+#include "TokenKinds.def"
+    return true;
+  default:
+    return false;
+  }
+}
diff --git a/mlir/lib/Analysis/Presburger/Parser/Token.h b/mlir/lib/Analysis/Presburger/Parser/Token.h
new file mode 100644
index 0000000000000..5c84c1ba820ee
--- /dev/null
+++ b/mlir/lib/Analysis/Presburger/Parser/Token.h
@@ -0,0 +1,91 @@
+//===- Token.h - Presburger Token Interface ---------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_PRESBURGER_PARSER_TOKEN_H
+#define MLIR_ANALYSIS_PRESBURGER_PARSER_TOKEN_H
+
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/SMLoc.h"
+#include <optional>
+
+namespace mlir::presburger {
+using llvm::SMLoc;
+using llvm::SMRange;
+using llvm::StringRef;
+
+/// A single lexed token: its kind plus the exact source text it covers.
+class Token {
+public:
+  /// All token kinds, generated from TokenKinds.def. Keyword kinds carry a
+  /// `kw_` prefix (e.g. kw_mod).
+  enum Kind {
+#define TOK_MARKER(NAME) NAME,
+#define TOK_IDENTIFIER(NAME) NAME,
+#define TOK_LITERAL(NAME) NAME,
+#define TOK_PUNCTUATION(NAME, SPELLING) NAME,
+#define TOK_KEYWORD(SPELLING) kw_##SPELLING,
+#include "TokenKinds.def"
+  };
+
+  Token(Kind kind, StringRef spelling) : kind(kind), spelling(spelling) {}
+
+  /// The raw source bytes of this token.
+  StringRef getSpelling() const { return spelling; }
+
+  //===--------------------------------------------------------------------===//
+  // Classification
+  //===--------------------------------------------------------------------===//
+
+  Kind getKind() const { return kind; }
+  bool is(Kind k) const { return kind == k; }
+  bool isNot(Kind k) const { return !is(k); }
+
+  /// True if this token has either of the two given kinds.
+  bool isAny(Kind k1, Kind k2) const { return is(k1) || is(k2); }
+
+  /// True if this token has any of the three or more given kinds.
+  template <typename... T>
+  bool isAny(Kind k1, Kind k2, Kind k3, T... others) const {
+    return is(k1) || isAny(k2, k3, others...);
+  }
+
+  /// True if this token has none of the given kinds.
+  template <typename... T>
+  bool isNot(Kind k1, Kind k2, T... others) const {
+    return !isAny(k1, k2, others...);
+  }
+
+  /// True for the keyword kinds (the `kw_*` enumerators).
+  bool isKeyword() const;
+
+  //===--------------------------------------------------------------------===//
+  // Value decoding
+  //===--------------------------------------------------------------------===//
+
+  /// Interpret an integer token's spelling as a uint64_t, or return
+  /// std::nullopt if it does not fit.
+  static std::optional<uint64_t> getUInt64IntegerValue(StringRef spelling);
+  std::optional<uint64_t> getUInt64IntegerValue() const {
+    return getUInt64IntegerValue(spelling);
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Source locations
+  //===--------------------------------------------------------------------===//
+
+  SMLoc getLoc() const;
+  SMLoc getEndLoc() const;
+  SMRange getLocRange() const;
+
+  /// The fixed spelling of a punctuation or keyword kind. Must not be called
+  /// with markers, identifiers or literals, which have no fixed spelling.
+  static StringRef getTokenSpelling(Kind kind);
+
+private:
+  /// Which sort of token this is.
+  Kind kind;
+
+  /// The full text of the token; always a pointer into a memory buffer owned
+  /// by the source manager.
+  StringRef spelling;
+};
+} // namespace mlir::presburger
+
+#endif // MLIR_ANALYSIS_PRESBURGER_PARSER_TOKEN_H
diff --git a/mlir/lib/Analysis/Presburger/Parser/TokenKinds.def b/mlir/lib/Analysis/Presburger/Parser/TokenKinds.def
new file mode 100644
index 0000000000000..e7010dfe11954
--- /dev/null
+++ b/mlir/lib/Analysis/Presburger/Parser/TokenKinds.def
@@ -0,0 +1,73 @@
+//===- TokenKinds.def - Presburger Parser Token Description -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file is intended to be #include'd multiple times to extract information
+// about tokens for various clients in the lexer.
+//
+//===----------------------------------------------------------------------===//
+
+// Including this file with none of the TOK_ category macros defined would
+// expand to nothing, so treat that as a usage error.
+#if !defined(TOK_MARKER) && !defined(TOK_IDENTIFIER) && \
+ !defined(TOK_LITERAL) && !defined(TOK_PUNCTUATION) && \
+ !defined(TOK_KEYWORD)
+#error Must define one of the TOK_ macros.
+#endif
+
+// Any category the includer did not define expands to nothing, so each client
+// only needs to define the macros it cares about.
+#ifndef TOK_MARKER
+#define TOK_MARKER(X)
+#endif
+#ifndef TOK_IDENTIFIER
+#define TOK_IDENTIFIER(NAME)
+#endif
+#ifndef TOK_LITERAL
+#define TOK_LITERAL(NAME)
+#endif
+#ifndef TOK_PUNCTUATION
+#define TOK_PUNCTUATION(NAME, SPELLING)
+#endif
+#ifndef TOK_KEYWORD
+#define TOK_KEYWORD(SPELLING)
+#endif
+
+// Markers
+TOK_MARKER(eof)
+TOK_MARKER(error)
+
+// Identifiers.
+TOK_IDENTIFIER(bare_identifier) // foo
+
+// Literals
+TOK_LITERAL(integer) // 42
+TOK_LITERAL(inttype) // i4, si8, ui16
+
+// Punctuation.
+// Comparison operators such as `>=`, `<=` and `==` have no token of their
+// own; the parser reads each as two consecutive single-character tokens.
+TOK_PUNCTUATION(arrow, "->")
+TOK_PUNCTUATION(colon, ":")
+TOK_PUNCTUATION(comma, ",")
+TOK_PUNCTUATION(equal, "=")
+TOK_PUNCTUATION(greater, ">")
+TOK_PUNCTUATION(l_brace, "{")
+TOK_PUNCTUATION(l_paren, "(")
+TOK_PUNCTUATION(l_square, "[")
+TOK_PUNCTUATION(less, "<")
+TOK_PUNCTUATION(minus, "-")
+TOK_PUNCTUATION(plus, "+")
+TOK_PUNCTUATION(r_brace, "}")
+TOK_PUNCTUATION(r_paren, ")")
+TOK_PUNCTUATION(r_square, "]")
+TOK_PUNCTUATION(star, "*")
+
+// Keywords. These turn "foo" into Token::kw_foo enums.
+TOK_KEYWORD(ceildiv)
+TOK_KEYWORD(floordiv)
+TOK_KEYWORD(mod)
+
+// Undefine every category macro so this file can be included again with
+// different definitions.
+#undef TOK_MARKER
+#undef TOK_IDENTIFIER
+#undef TOK_LITERAL
+#undef TOK_PUNCTUATION
+#undef TOK_KEYWORD
diff --git a/mlir/unittests/Analysis/Presburger/BarvinokTest.cpp b/mlir/unittests/Analysis/Presburger/BarvinokTest.cpp
index 5e279b542fdf9..135697ec2d6a5 100644
--- a/mlir/unittests/Analysis/Presburger/BarvinokTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/BarvinokTest.cpp
@@ -1,6 +1,6 @@
#include "mlir/Analysis/Presburger/Barvinok.h"
-#include "./Utils.h"
-#include "Parser.h"
+#include "Utils.h"
+#include "mlir/Analysis/Presburger/Parser.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
diff --git a/mlir/unittests/Analysis/Presburger/CMakeLists.txt b/mlir/unittests/Analysis/Presburger/CMakeLists.txt
index b69f514711337..d65e32917a35d 100644
--- a/mlir/unittests/Analysis/Presburger/CMakeLists.txt
+++ b/mlir/unittests/Analysis/Presburger/CMakeLists.txt
@@ -6,7 +6,6 @@ add_mlir_unittest(MLIRPresburgerTests
IntegerRelationTest.cpp
LinearTransformTest.cpp
MatrixTest.cpp
- Parser.h
ParserTest.cpp
PresburgerSetTest.cpp
PresburgerRelationTest.cpp
@@ -19,6 +18,5 @@ add_mlir_unittest(MLIRPresburgerTests
target_link_libraries(MLIRPresburgerTests
PRIVATE MLIRPresburger
- MLIRAffineAnalysis
- MLIRParser
+ MLIRPresburgerParser
)
diff --git a/mlir/unittests/Analysis/Presburger/FractionTest.cpp b/mlir/unittests/Analysis/Presburger/FractionTest.cpp
index c9fad953dacd5..92bdd57d6cf82 100644
--- a/mlir/unittests/Analysis/Presburger/FractionTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/FractionTest.cpp
@@ -1,5 +1,4 @@
#include "mlir/Analysis/Presburger/Fraction.h"
-#include "./Utils.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
diff --git a/mlir/unittests/Analysis/Presburger/GeneratingFunctionTest.cpp b/mlir/unittests/Analysis/Presburger/GeneratingFunctionTest.cpp
index 3fc68cddaad00..759f4d6cc5194 100644
--- a/mlir/unittests/Analysis/Presburger/GeneratingFunctionTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/GeneratingFunctionTest.cpp
@@ -7,7 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/Presburger/GeneratingFunction.h"
-#include "./Utils.h"
+#include "Utils.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
diff --git a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
index f64bb240b4ee4..bc0f7a0426d08 100644
--- a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
@@ -6,10 +6,10 @@
//
//===----------------------------------------------------------------------===//
-#include "Parser.h"
#include "Utils.h"
#include "mlir/Analysis/Presburger/IntegerRelation.h"
#include "mlir/Analysis/Presburger/PWMAFunction.h"
+#include "mlir/Analysis/Presburger/Parser.h"
#include "mlir/Analysis/Presburger/Simplex.h"
#include <gmock/gmock.h>
diff --git a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
index 7df500bc9568a..a6e02f3982936 100644
--- a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
@@ -7,7 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/Presburger/IntegerRelation.h"
-#include "Parser.h"
+#include "mlir/Analysis/Presburger/Parser.h"
#include "mlir/Analysis/Presburger/PresburgerSpace.h"
#include "mlir/Analysis/Presburger/Simplex.h"
diff --git a/mlir/unittests/Analysis/Presburger/MatrixTest.cpp b/mlir/unittests/Analysis/Presburger/MatrixTest.cpp
index cb8df8b346011..2da49d7793b72 100644
--- a/mlir/unittests/Analysis/Presburger/MatrixTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/MatrixTest.cpp
@@ -7,8 +7,9 @@
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/Presburger/Matrix.h"
-#include "./Utils.h"
+#include "Utils.h"
#include "mlir/Analysis/Presburger/Fraction.h"
+#include "mlir/Analysis/Presburger/PresburgerRelation.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
diff --git a/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp b/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp
index ee2931e78185c..22557414b8cb0 100644
--- a/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp
@@ -10,11 +10,8 @@
//
//===----------------------------------------------------------------------===//
-#include "Parser.h"
-
#include "mlir/Analysis/Presburger/PWMAFunction.h"
-#include "mlir/Analysis/Presburger/PresburgerRelation.h"
-#include "mlir/IR/MLIRContext.h"
+#include "mlir/Analysis/Presburger/Parser.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
diff --git a/mlir/unittests/Analysis/Presburger/ParserTest.cpp b/mlir/unittests/Analysis/Presburger/ParserTest.cpp
index 06b728cd1a8fa..81026cf4ee0a5 100644
--- a/mlir/unittests/Analysis/Presburger/ParserTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/ParserTest.cpp
@@ -13,7 +13,7 @@
//
//===----------------------------------------------------------------------===//
-#include "Parser.h"
+#include "mlir/Analysis/Presburger/Parser.h"
#include <gtest/gtest.h>
diff --git a/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp
index ad71bb32a0688..b8e21dcf76200 100644
--- a/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp
@@ -6,8 +6,8 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/Presburger/PresburgerRelation.h"
-#include "Parser.h"
#include "mlir/Analysis/Presburger/IntegerRelation.h"
+#include "mlir/Analysis/Presburger/Parser.h"
#include "mlir/Analysis/Presburger/Simplex.h"
#include <gmock/gmock.h>
diff --git a/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp
index 8e31a8bb2030b..ec265a71b38cd 100644
--- a/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp
@@ -14,10 +14,9 @@
//
//===----------------------------------------------------------------------===//
-#include "Parser.h"
#include "Utils.h"
+#include "mlir/Analysis/Presburger/Parser.h"
#include "mlir/Analysis/Presburger/PresburgerRelation.h"
-#include "mlir/IR/MLIRContext.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
diff --git a/mlir/unittests/Analysis/Presburger/QuasiPolynomialTest.cpp b/mlir/unittests/Analysis/Presburger/QuasiPolynomialTest.cpp
index a84f0234067ab..bc3d788aeb4c3 100644
--- a/mlir/unittests/Analysis/Presburger/QuasiPolynomialTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/QuasiPolynomialTest.cpp
@@ -7,7 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/Presburger/QuasiPolynomial.h"
-#include "./Utils.h"
+#include "Utils.h"
#include "mlir/Analysis/Presburger/Fraction.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
@@ -137,4 +137,4 @@ TEST(QuasiPolynomialTest, simplify) {
{{{Fraction(1, 1), Fraction(3, 4), Fraction(5, 3)},
{Fraction(2, 1), Fraction(0, 1), Fraction(0, 1)}},
{{Fraction(1, 1), Fraction(4, 5), Fraction(6, 5)}}}));
-}
\ No newline at end of file
+}
diff --git a/mlir/unittests/Analysis/Presburger/SimplexTest.cpp b/mlir/unittests/Analysis/Presburger/SimplexTest.cpp
index 63d0243808555..59fe453eb6a05 100644
--- a/mlir/unittests/Analysis/Presburger/SimplexTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/SimplexTest.cpp
@@ -6,11 +6,8 @@
//
//===----------------------------------------------------------------------===//
-#include "Parser.h"
-#include "Utils.h"
-
#include "mlir/Analysis/Presburger/Simplex.h"
-#include "mlir/IR/MLIRContext.h"
+#include "mlir/Analysis/Presburger/Parser.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
diff --git a/mlir/unittests/Analysis/Presburger/Utils.h b/mlir/unittests/Analysis/Presburger/Utils.h
index ef4429b5c6bc8..279dd66a8aee0 100644
--- a/mlir/unittests/Analysis/Presburger/Utils.h
+++ b/mlir/unittests/Analysis/Presburger/Utils.h
@@ -14,13 +14,8 @@
#define MLIR_UNITTESTS_ANALYSIS_PRESBURGER_UTILS_H
#include "mlir/Analysis/Presburger/GeneratingFunction.h"
-#include "mlir/Analysis/Presburger/IntegerRelation.h"
#include "mlir/Analysis/Presburger/Matrix.h"
-#include "mlir/Analysis/Presburger/PWMAFunction.h"
-#include "mlir/Analysis/Presburger/PresburgerRelation.h"
#include "mlir/Analysis/Presburger/QuasiPolynomial.h"
-#include "mlir/Analysis/Presburger/Simplex.h"
-#include "mlir/IR/MLIRContext.h"
#include "mlir/Support/LLVM.h"
#include <gtest/gtest.h>
>From 581fad781c8a86b245108b75a2f1f9c2cf4073bb Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Mon, 17 Jun 2024 13:08:16 +0100
Subject: [PATCH 3/5] Write a fresh parser
---
.../include/mlir/Analysis/Presburger/Matrix.h | 3 +
mlir/lib/Analysis/Presburger/Matrix.cpp | 8 +
.../Analysis/Presburger/Parser/Flattener.cpp | 428 ++-----------
.../Analysis/Presburger/Parser/Flattener.h | 251 ++------
.../Presburger/Parser/ParseStructs.cpp | 402 ++++++------
.../Analysis/Presburger/Parser/ParseStructs.h | 506 ++++++++-------
.../Analysis/Presburger/Parser/ParserImpl.cpp | 597 ++++++------------
.../Analysis/Presburger/Parser/ParserImpl.h | 185 ++----
.../Analysis/Presburger/Parser/ParserState.h | 39 --
9 files changed, 848 insertions(+), 1571 deletions(-)
delete mode 100644 mlir/lib/Analysis/Presburger/Parser/ParserState.h
diff --git a/mlir/include/mlir/Analysis/Presburger/Matrix.h b/mlir/include/mlir/Analysis/Presburger/Matrix.h
index e232ecd5e1509..aca4af4d43543 100644
--- a/mlir/include/mlir/Analysis/Presburger/Matrix.h
+++ b/mlir/include/mlir/Analysis/Presburger/Matrix.h
@@ -102,6 +102,9 @@ class Matrix {
/// Set the specified row to `elems`.
void setRow(unsigned row, ArrayRef<T> elems);
+ /// Add `elems` element-wise to row `row`; needs one entry per column.
+ void addToRow(unsigned row, ArrayRef<T> elems);
+
/// Insert columns having positions pos, pos + 1, ... pos + count - 1.
/// Columns that were at positions 0 to pos - 1 will stay where they are;
/// columns that were at positions pos to nColumns - 1 will be pushed to the
diff --git a/mlir/lib/Analysis/Presburger/Matrix.cpp b/mlir/lib/Analysis/Presburger/Matrix.cpp
index 134b805648d9f..f70288063a8ba 100644
--- a/mlir/lib/Analysis/Presburger/Matrix.cpp
+++ b/mlir/lib/Analysis/Presburger/Matrix.cpp
@@ -145,6 +145,14 @@ void Matrix<T>::setRow(unsigned row, ArrayRef<T> elems) {
at(row, i) = elems[i];
}
+template <typename T>
+void Matrix<T>::addToRow(unsigned row, ArrayRef<T> elems) {
+  assert(getNumColumns() == elems.size() &&
+         "elems size must match row length!");
+  for (unsigned col = 0, numCols = getNumColumns(); col != numCols; ++col)
+    at(row, col) += elems[col];
+}
+
template <typename T>
void Matrix<T>::insertColumn(unsigned pos) {
insertColumns(pos, 1);
diff --git a/mlir/lib/Analysis/Presburger/Parser/Flattener.cpp b/mlir/lib/Analysis/Presburger/Parser/Flattener.cpp
index 4ebcf6c676672..90f645a9f20bd 100644
--- a/mlir/lib/Analysis/Presburger/Parser/Flattener.cpp
+++ b/mlir/lib/Analysis/Presburger/Parser/Flattener.cpp
@@ -12,397 +12,77 @@
//===----------------------------------------------------------------------===//
#include "Flattener.h"
-#include "llvm/ADT/SmallVector.h"
+#include "ParseStructs.h"
using namespace mlir;
using namespace presburger;
-using llvm::SmallVector;
-AffineExpr AffineExprFlattener::getAffineExprFromFlatForm(
- ArrayRef<int64_t> flatExprs, unsigned numDims, unsigned numSymbols) {
- assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
- "unexpected number of local expressions");
-
- // Dimensions and symbols.
- AffineExpr expr = std::make_unique<AffineConstantExpr>(0);
- for (unsigned j = 0; j < getLocalVarStartIndex(); ++j) {
- if (flatExprs[j] == 0)
- continue;
- if (j < numDims)
- expr =
- std::move(expr) + std::make_unique<AffineDimExpr>(j) * flatExprs[j];
- else
- expr = std::move(expr) +
- std::make_unique<AffineSymbolExpr>(j - numDims) * flatExprs[j];
- }
-
- // Local identifiers.
- for (unsigned j = getLocalVarStartIndex(); j < flatExprs.size() - 1; ++j) {
- if (flatExprs[j] == 0)
- continue;
- // It is safe to move out of the localExprs vector, since no expr is used
- // more than once.
- AffineExpr term =
- std::move(localExprs[j - getLocalVarStartIndex()]) * flatExprs[j];
- expr = std::move(expr) + std::move(term);
- }
-
- // Constant term.
- int64_t constTerm = flatExprs[flatExprs.size() - 1];
- if (constTerm != 0)
- return std::move(expr) + constTerm;
- return expr;
-}
-
-// In pure affine t = expr * c, we multiply each coefficient of lhs with c.
-//
-// In case of semi affine multiplication expressions, t = expr * symbolic_expr,
-// introduce a local variable p (= expr * symbolic_expr), and the affine
-// expression expr * symbolic_expr is added to `localExprs`.
-LogicalResult AffineExprFlattener::visitMulExpr(const AffineBinOpExpr &expr) {
- assert(operandExprStack.size() >= 2);
- SmallVector<int64_t, 8> rhs = operandExprStack.back();
- operandExprStack.pop_back();
- SmallVector<int64_t, 8> &lhs = operandExprStack.back();
-
- // Flatten semi-affine multiplication expressions by introducing a local
- // variable in place of the product; the affine expression
- // corresponding to the quantifier is added to `localExprs`.
- if (!isa<AffineConstantExpr>(expr.getRHS())) {
- AffineExpr a = getAffineExprFromFlatForm(lhs, numDims, numSymbols);
- AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols);
- addLocalVariableSemiAffine(std::move(a) * std::move(b), lhs, lhs.size());
- return success();
+void Flattener::visitDiv(const PureAffineExprImpl &div) {
+ int64_t divisor = div.getDivisor();
+
+ // First construct the linear part of the dividend.
+ auto dividend = div.collectLinearTerms().getPadded(
+ info.getLocalVarStartIdx() + localExprs.size() + 1);
+
+ // Next, insert the non-linear coefficients.
+ for (const auto &[hash, adjustedMulFactor, adjustedLinearTerm] :
+ div.getNonLinearCoeffs()) {
+
+ // adjustedMulFactor will be mulFactor * -divisor in case of mod, and
+ // mulFactor in case of floordiv.
+ dividend[lookupLocal(hash)] = adjustedMulFactor;
+
+ // The expansion is either a modExpansion which we previously stored, or the
+ // adjustedLinearTerm, which is correct in the case when we're encountering
+ // the innermost mod for the first time.
+ CoefficientVector expansion =
+ lookupModExpansion(hash)
+ .value_or(adjustedLinearTerm)
+ .getPadded(info.getLocalVarStartIdx() + localExprs.size() + 1);
+ dividend += expansion;
+
+ // If this is a mod, record its expansion, dividend * mulFactor. NOTE(review):
+ // insert() keeps the first entry, so later iterations do not update it.
+ if (div.isMod())
+ localModExpansion.insert({div.hash(), dividend * div.getMulFactor()});
}
- // Get the RHS constant.
- int64_t rhsConst = rhs[getConstantIndex()];
- for (int64_t &lhsElt : lhs)
- lhsElt *= rhsConst;
-
- return success();
-}
-
-LogicalResult AffineExprFlattener::visitAddExpr(const AffineBinOpExpr &expr) {
- assert(operandExprStack.size() >= 2);
- const auto &rhs = operandExprStack.back();
- auto &lhs = operandExprStack[operandExprStack.size() - 2];
- assert(lhs.size() == rhs.size());
- // Update the LHS in place.
- for (unsigned i = 0; i < rhs.size(); ++i)
- lhs[i] += rhs[i];
- // Pop off the RHS.
- operandExprStack.pop_back();
- return success();
+ cst.addLocalFloorDiv(dividend, divisor);
+ localExprs.insert(div.hash());
}
-//
-// t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1
-//
-// A mod expression "expr mod c" is thus flattened by introducing a new local
-// variable q (= expr floordiv c), such that expr mod c is replaced with
-// 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
-//
-// In case of semi-affine modulo expressions, t = expr mod symbolic_expr,
-// introduce a local variable m (= expr mod symbolic_expr), and the affine
-// expression expr mod symbolic_expr is added to `localExprs`.
-LogicalResult AffineExprFlattener::visitModExpr(const AffineBinOpExpr &expr) {
- assert(operandExprStack.size() >= 2);
-
- SmallVector<int64_t, 8> rhs = operandExprStack.back();
- operandExprStack.pop_back();
- SmallVector<int64_t, 8> &lhs = operandExprStack.back();
+void Flattener::flatten(unsigned row, PureAffineExprImpl &div) {
+ // Visit divs inner to outer.
+ for (auto &nestedDiv : div.getNestedDivTerms())
+ flatten(row, *nestedDiv);
- // Flatten semi affine modulo expressions by introducing a local
- // variable in place of the modulo value, and the affine expression
- // corresponding to the quantifier is added to `localExprs`.
- if (!isa<AffineConstantExpr>(expr.getRHS())) {
- AffineExpr dividendExpr =
- getAffineExprFromFlatForm(lhs, numDims, numSymbols);
- AffineExpr divisorExpr =
- getAffineExprFromFlatForm(rhs, numDims, numSymbols);
- AffineExpr modExpr = std::move(dividendExpr) % std::move(divisorExpr);
- addLocalVariableSemiAffine(std::move(modExpr), lhs, lhs.size());
- return success();
+ if (div.hasDivisor()) {
+ visitDiv(div);
+ return;
}
- int64_t rhsConst = rhs[getConstantIndex()];
- if (rhsConst <= 0)
- return failure();
+ // Reached for every divisor-free sub-expression. setRow overwrites the row
+ // each time, so only the outermost expression, which is visited last,
+ // determines the final contents of the row.
- // Check if the LHS expression is a multiple of modulo factor.
- unsigned i;
- for (i = 0; i < lhs.size(); ++i)
- if (lhs[i] % rhsConst != 0)
- break;
- // If yes, modulo expression here simplifies to zero.
- if (i == lhs.size()) {
- std::fill(lhs.begin(), lhs.end(), 0);
- return success();
- }
-
- // Add a local variable for the quotient, i.e., expr % c is replaced by
- // (expr - q * c) where q = expr floordiv c. Do this while canceling out
- // the GCD of expr and c.
- SmallVector<int64_t, 8> floorDividend(lhs);
- uint64_t gcd = rhsConst;
- for (int64_t lhsElt : lhs)
- gcd = std::gcd(gcd, (uint64_t)std::abs(lhsElt));
- // Simplify the numerator and the denominator.
- if (gcd != 1) {
- for (int64_t &floorDividendElt : floorDividend)
- floorDividendElt = floorDividendElt / static_cast<int64_t>(gcd);
- }
- int64_t floorDivisor = rhsConst / static_cast<int64_t>(gcd);
-
- // Construct the AffineExpr form of the floordiv to store in localExprs.
+ // Set the linear part of the row.
+ setRow(row, div.getLinearDividend());
- AffineExpr dividendExpr =
- getAffineExprFromFlatForm(floorDividend, numDims, numSymbols);
- AffineExpr divisorExpr = std::make_unique<AffineConstantExpr>(floorDivisor);
- AffineExpr floorDivExpr =
- floorDiv(std::move(dividendExpr), std::move(divisorExpr));
- int loc;
- if ((loc = findLocalId(floorDivExpr)) == -1) {
- addLocalFloorDivId(floorDividend, floorDivisor, std::move(floorDivExpr));
- // Set result at top of stack to "lhs - rhsConst * q".
- lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst;
- } else {
- // Reuse the existing local id.
- lhs[getLocalVarStartIndex() + loc] = -rhsConst;
+ // Set the non-linear coefficients.
+ for (const auto &[hash, adjustedMulFactor, adjustedLinearTerm] :
+ div.getNonLinearCoeffs()) {
+ flatMatrix(row, lookupLocal(hash)) = adjustedMulFactor;
+ CoefficientVector expansion =
+ lookupModExpansion(hash).value_or(adjustedLinearTerm);
+ addToRow(row, expansion);
}
- return success();
-}
-
-LogicalResult
-AffineExprFlattener::visitCeilDivExpr(const AffineBinOpExpr &expr) {
- return visitDivExpr(expr, /*isCeil=*/true);
-}
-LogicalResult
-AffineExprFlattener::visitFloorDivExpr(const AffineBinOpExpr &expr) {
- return visitDivExpr(expr, /*isCeil=*/false);
-}
-
-LogicalResult AffineExprFlattener::visitDimExpr(const AffineDimExpr &expr) {
- operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
- auto &eq = operandExprStack.back();
- assert(expr.getPosition() < numDims && "Inconsistent number of dims");
- eq[getDimStartIndex() + expr.getPosition()] = 1;
- return success();
-}
-
-LogicalResult
-AffineExprFlattener::visitSymbolExpr(const AffineSymbolExpr &expr) {
- operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
- auto &eq = operandExprStack.back();
- assert(expr.getPosition() < numSymbols && "inconsistent number of symbols");
- eq[getSymbolStartIndex() + expr.getPosition()] = 1;
- return success();
-}
-
-LogicalResult
-AffineExprFlattener::visitConstantExpr(const AffineConstantExpr &expr) {
- operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
- auto &eq = operandExprStack.back();
- eq[getConstantIndex()] = expr.getValue();
- return success();
-}
-
-void AffineExprFlattener::addLocalVariableSemiAffine(
- AffineExpr &&expr, SmallVectorImpl<int64_t> &result,
- unsigned long resultSize) {
- assert(result.size() == resultSize && "result vector size mismatch");
- int loc;
- if ((loc = findLocalId(expr)) == -1)
- addLocalIdSemiAffine(std::move(expr));
- std::fill(result.begin(), result.end(), 0);
- if (loc == -1)
- result[getLocalVarStartIndex() + numLocals - 1] = 1;
- else
- result[getLocalVarStartIndex() + loc] = 1;
}
-// t = expr floordiv c <=> t = q, c * q <= expr <= c * q + c - 1
-// A floordiv is thus flattened by introducing a new local variable q, and
-// replacing that expression with 'q' while adding the constraints
-// c * q <= expr <= c * q + c - 1 to localVarCst (done by
-// IntegerRelation::addLocalFloorDiv).
-//
-// A ceildiv is similarly flattened:
-// t = expr ceildiv c <=> t = (expr + c - 1) floordiv c
-//
-// In case of semi affine division expressions, t = expr floordiv symbolic_expr
-// or t = expr ceildiv symbolic_expr, introduce a local variable q (= expr
-// floordiv/ceildiv symbolic_expr), and the affine floordiv/ceildiv is added to
-// `localExprs`.
-LogicalResult AffineExprFlattener::visitDivExpr(const AffineBinOpExpr &expr,
- bool isCeil) {
- assert(operandExprStack.size() >= 2);
-
- SmallVector<int64_t, 8> rhs = operandExprStack.back();
- operandExprStack.pop_back();
- SmallVector<int64_t, 8> &lhs = operandExprStack.back();
-
- // Flatten semi affine division expressions by introducing a local
- // variable in place of the quotient, and the affine expression corresponding
- // to the quantifier is added to `localExprs`.
- if (!isa<AffineConstantExpr>(expr.getRHS())) {
- AffineExpr a = getAffineExprFromFlatForm(lhs, numDims, numSymbols);
- AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols);
- AffineExpr divExpr = isCeil ? ceilDiv(std::move(a), std::move(b))
- : floorDiv(std::move(a), std::move(b));
- addLocalVariableSemiAffine(std::move(divExpr), lhs, lhs.size());
- return success();
- }
-
- // This is a pure affine expr; the RHS is a positive constant.
- int64_t rhsConst = rhs[getConstantIndex()];
- if (rhsConst <= 0)
- return failure();
-
- // Simplify the floordiv, ceildiv if possible by canceling out the greatest
- // common divisors of the numerator and denominator.
- uint64_t gcd = std::abs(rhsConst);
- for (int64_t lhsElt : lhs)
- gcd = std::gcd(gcd, (uint64_t)std::abs(lhsElt));
- // Simplify the numerator and the denominator.
- if (gcd != 1) {
- for (int64_t &lhsElt : lhs)
- lhsElt = lhsElt / static_cast<int64_t>(gcd);
- }
- int64_t divisor = rhsConst / static_cast<int64_t>(gcd);
- // If the divisor becomes 1, the updated LHS is the result. (The
- // divisor can't be negative since rhsConst is positive).
- if (divisor == 1)
- return success();
-
- // If the divisor cannot be simplified to one, we will have to retain
- // the ceil/floor expr (simplified up until here). Add an existential
- // quantifier to express its result, i.e., expr1 div expr2 is replaced
- // by a new identifier, q.
- AffineExpr a = getAffineExprFromFlatForm(lhs, numDims, numSymbols);
- AffineExpr b = std::make_unique<AffineConstantExpr>(divisor);
-
- int loc;
- AffineExpr divExpr = isCeil ? ceilDiv(std::move(a), std::move(b))
- : floorDiv(std::move(a), std::move(b));
- if ((loc = findLocalId(divExpr)) == -1) {
- if (!isCeil) {
- SmallVector<int64_t, 8> dividend(lhs);
- addLocalFloorDivId(dividend, divisor, std::move(divExpr));
- } else {
- // lhs ceildiv c <=> (lhs + c - 1) floordiv c
- SmallVector<int64_t, 8> dividend(lhs);
- dividend.back() += divisor - 1;
- addLocalFloorDivId(dividend, divisor, std::move(divExpr));
- }
- }
- // Set the expression on stack to the local var introduced to capture the
- // result of the division (floor or ceil).
- std::fill(lhs.begin(), lhs.end(), 0);
- if (loc == -1)
- lhs[getLocalVarStartIndex() + numLocals - 1] = 1;
- else
- lhs[getLocalVarStartIndex() + loc] = 1;
- return success();
-}
-
-void AffineExprFlattener::addLocalFloorDivId(ArrayRef<int64_t> dividend,
- int64_t divisor,
- AffineExpr &&localExpr) {
- assert(divisor > 0 && "positive constant divisor expected");
- for (SmallVector<int64_t, 8> &subExpr : operandExprStack)
- subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
- localExprs.emplace_back(std::move(localExpr));
- ++numLocals;
- // Update localVarCst.
- localVarCst.addLocalFloorDiv(dividend, divisor);
-}
-
-void AffineExprFlattener::addLocalIdSemiAffine(AffineExpr &&localExpr) {
- for (SmallVector<int64_t, 8> &subExpr : operandExprStack)
- subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
- localExprs.emplace_back(std::move(localExpr));
- ++numLocals;
-}
-
-int AffineExprFlattener::findLocalId(const AffineExpr &localExpr) {
- auto *it = llvm::find(localExprs, localExpr);
- if (it == localExprs.end())
- return -1;
- return it - localExprs.begin();
-}
-
-AffineExprFlattener::AffineExprFlattener(unsigned numDims, unsigned numSymbols)
- : numDims(numDims), numSymbols(numSymbols), numLocals(0),
- localVarCst(PresburgerSpace::getSetSpace(numDims, numSymbols)) {
- operandExprStack.reserve(8);
-}
-
-// Flattens the expressions in map. Returns failure if 'expr' was unable to be
-// flattened. For example two specific cases:
-// 1. semi-affine expressions not handled yet.
-// 2. has poison expression (i.e., division by zero).
-static LogicalResult
-getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims,
- unsigned numSymbols,
- std::vector<SmallVector<int64_t, 8>> &flattenedExprs,
- IntegerPolyhedron &localVarCst) {
- if (exprs.empty()) {
- localVarCst = IntegerPolyhedron(
- 0, 0, numDims + numSymbols + 1,
- presburger::PresburgerSpace::getSetSpace(numDims, numSymbols, 0));
- return success();
- }
-
- AffineExprFlattener flattener(numDims, numSymbols);
+std::pair<IntMatrix, IntegerPolyhedron> Flattener::flatten() {
// Use the same flattener to simplify each expression successively. This way
// local variables / expressions are shared.
- for (const AffineExpr &expr : exprs) {
- if (!expr->isPureAffine())
- return failure();
- // has poison expression
- LogicalResult flattenResult = flattener.walkPostOrder(*expr);
- if (failed(flattenResult))
- return failure();
- }
-
- assert(flattener.operandExprStack.size() == exprs.size());
- flattenedExprs.clear();
- flattenedExprs.assign(flattener.operandExprStack.begin(),
- flattener.operandExprStack.end());
+ for (const auto &[row, expr] : enumerate(exprs)) // Row i <- exprs[i].
+ flatten(row, *expr);
- localVarCst.clearAndCopyFrom(flattener.localVarCst);
-
- return success();
-}
-
-namespace mlir::presburger {
-LogicalResult
-getFlattenedAffineExprs(const AffineMap &map,
- std::vector<SmallVector<int64_t, 8>> &flattenedExprs,
- IntegerPolyhedron &cst) {
- if (map.getNumExprs() == 0) {
- cst = IntegerPolyhedron(0, 0, map.getNumDims() + map.getNumSymbols() + 1,
- presburger::PresburgerSpace::getSetSpace(
- map.getNumDims(), map.getNumSymbols(), 0));
- return success();
- }
- return ::getFlattenedAffineExprs(map.getExprs(), map.getNumDims(),
- map.getNumSymbols(), flattenedExprs, cst);
-}
-
-LogicalResult
-getFlattenedAffineExprs(const IntegerSet &set,
- std::vector<SmallVector<int64_t, 8>> &flattenedExprs,
- IntegerPolyhedron &cst) {
- if (set.getNumConstraints() == 0) {
- cst = IntegerPolyhedron(0, 0, set.getNumDims() + set.getNumSymbols() + 1,
- presburger::PresburgerSpace::getSetSpace(
- set.getNumDims(), set.getNumSymbols(), 0));
- return success();
- }
- return ::getFlattenedAffineExprs(set.getConstraints(), set.getNumDims(),
- set.getNumSymbols(), flattenedExprs, cst);
+ return {flatMatrix, cst}; // Copies both; the flattener's state is kept.
}
-} // namespace mlir::presburger
diff --git a/mlir/lib/Analysis/Presburger/Parser/Flattener.h b/mlir/lib/Analysis/Presburger/Parser/Flattener.h
index abbdc002c5545..4a5d60faddd53 100644
--- a/mlir/lib/Analysis/Presburger/Parser/Flattener.h
+++ b/mlir/lib/Analysis/Presburger/Parser/Flattener.h
@@ -11,234 +11,57 @@
#include "ParseStructs.h"
#include "mlir/Analysis/Presburger/IntegerRelation.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SetVector.h"
namespace mlir::presburger {
-// This class is used to flatten a pure affine expression (AffineExpr,
-// which is in a tree form) into a sum of products (w.r.t constants) when
-// possible, and in that process simplifying the expression. For a modulo,
-// floordiv, or a ceildiv expression, an additional identifier, called a local
-// identifier, is introduced to rewrite the expression as a sum of product
-// affine expression. Each local identifier is always and by construction a
-// floordiv of a pure add/mul affine function of dimensional, symbolic, and
-// other local identifiers, in a non-mutually recursive way. Hence, every local
-// identifier can ultimately always be recovered as an affine function of
-// dimensional and symbolic identifiers (involving floordiv's); note however
-// that by AffineExpr construction, some floordiv combinations are converted to
-// mod's. The result of the flattening is a flattened expression and a set of
-// constraints involving just the local variables.
-//
-// d2 + (d0 + d1) floordiv 4 is flattened to d2 + q where 'q' is the local
-// variable introduced, with localVarCst containing 4*q <= d0 + d1 <= 4*q + 3.
-//
-// The simplification performed includes the accumulation of contributions for
-// each dimensional and symbolic identifier together, the simplification of
-// floordiv/ceildiv/mod expressions and other simplifications that in turn
-// happen as a result. A simplification that this flattening naturally performs
-// is of simplifying the numerator and denominator of floordiv/ceildiv, and
-// folding a modulo expression to a zero, if possible. Three examples are below:
-//
-// (d0 + 3 * d1) + d0) - 2 * d1) - d0 simplified to d0 + d1
-// (d0 - d0 mod 4 + 4) mod 4 simplified to 0
-// (3*d0 + 2*d1 + d0) floordiv 2 + d1 simplified to 2*d0 + 2*d1
-//
-// The way the flattening works for the second example is as follows: d0 % 4 is
-// replaced by d0 - 4*q with q being introduced: the expression then simplifies
-// to: (d0 - (d0 - 4q) + 4) = 4q + 4, modulo of which w.r.t 4 simplifies to
-// zero. Note that an affine expression may not always be expressible purely as
-// a sum of products involving just the original dimensional and symbolic
-// identifiers due to the presence of modulo/floordiv/ceildiv expressions that
-// may not be eliminated after simplification; in such cases, the final
-// expression can be reconstructed by replacing the local identifiers with their
-// corresponding explicit form stored in 'localExprs' (note that each of the
-// explicit forms itself would have been simplified).
-//
-// The expression walk method here performs a linear time post order walk that
-// performs the above simplifications through visit methods, with partial
-// results being stored in 'operandExprStack'. When a parent expr is visited,
-// the flattened expressions corresponding to its two operands would already be
-// on the stack - the parent expression looks at the two flattened expressions
-// and combines the two. It pops off the operand expressions and pushes the
-// combined result (although this is done in-place on its LHS operand expr).
-// When the walk is completed, the flattened form of the top-level expression
-// would be left on the stack.
-//
-// A flattener can be repeatedly used for multiple affine expressions that bind
-// to the same operands, for example, for all result expressions of an
-// AffineMap or AffineValueMap. In such cases, using it for multiple expressions
-// is more efficient than creating a new flattener for each expression since
-// common identical div and mod expressions appearing across different
-// expressions are mapped to the same local identifier (same column position in
-// 'localVarCst').
-class AffineExprFlattener {
-public:
- // Flattend expression layout: [dims, symbols, locals, constant]
- // Stack that holds the LHS and RHS operands while visiting a binary op expr.
- // In future, consider adding a prepass to determine how big the SmallVector's
- // will be, and linearize this to std::vector<int64_t> to prevent
- // SmallVector moves on re-allocation.
- std::vector<SmallVector<int64_t, 8>> operandExprStack;
-
- unsigned numDims;
- unsigned numSymbols;
-
- // Number of newly introduced identifiers to flatten mod/floordiv/ceildiv's.
- unsigned numLocals;
+using llvm::SmallDenseMap;
+using llvm::SmallSetVector;
- // Constraints connecting newly introduced local variables (for mod's and
- // div's) to existing (dimensional and symbolic) ones. These are always
- // inequalities.
- IntegerPolyhedron localVarCst;
+class Flattener : public FinalParseResult {
+public:
+ // The final flattened result is stored here.
+ IntMatrix flatMatrix;
- // AffineExpr's corresponding to the floordiv/ceildiv/mod expressions for
- // which new identifiers were introduced; if the latter do not get canceled
- // out, these expressions can be readily used to reconstruct the AffineExpr
- // (tree) form. Note that these expressions themselves would have been
- // simplified (recursively) by this pass. Eg. d0 + (d0 + 2*d1 + d0) ceildiv 4
- // will be simplified to d0 + q, where q = (d0 + d1) ceildiv 2. (d0 + d1)
- // ceildiv 2 would be the local expression stored for q.
- SmallVector<AffineExpr, 4> localExprs;
+ // Hashes of the divs seen so far while flattening; a div's index in this set
+ // determines its local column. Its size is at most info.numDivs, which is
+ // reached at the end of flattening if the expressions contain all the divs.
+ SmallSetVector<size_t, 4> localExprs;
- AffineExprFlattener(unsigned numDims, unsigned numSymbols);
+ // Maps the hash of each local mod to its expansion: the mod's dividend
+ // scaled by its mulFactor, as recorded in visitDiv.
+ SmallDenseMap<size_t, CoefficientVector, 2> localModExpansion;
- virtual ~AffineExprFlattener() = default;
+ Flattener(FinalParseResult &&parseResult)
+ : FinalParseResult(std::move(parseResult)),
+ flatMatrix(info.numExprs, info.getNumCols()) {}
- // Visitor methods.
- LogicalResult visitMulExpr(const AffineBinOpExpr &expr);
- LogicalResult visitAddExpr(const AffineBinOpExpr &expr);
- LogicalResult visitDimExpr(const AffineDimExpr &expr);
- LogicalResult visitSymbolExpr(const AffineSymbolExpr &expr);
- LogicalResult visitConstantExpr(const AffineConstantExpr &expr);
- LogicalResult visitCeilDivExpr(const AffineBinOpExpr &expr);
- LogicalResult visitFloorDivExpr(const AffineBinOpExpr &expr);
+ std::pair<IntMatrix, IntegerPolyhedron> flatten();
- //
- // t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1
- //
- // A mod expression "expr mod c" is thus flattened by introducing a new local
- // variable q (= expr floordiv c), such that expr mod c is replaced with
- // 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
- LogicalResult visitModExpr(const AffineBinOpExpr &expr);
+private:
+ void flatten(unsigned row, PureAffineExprImpl &div);
+ void visitDiv(const PureAffineExprImpl &div);
- // Function to walk an AffineExpr (in post order).
- LogicalResult walkPostOrder(const AffineExprImpl &expr) {
- switch (expr.getKind()) {
- case AffineExprKind::Add: {
- const auto &binOpExpr = cast<AffineBinOpExpr>(expr);
- if (failed(walkOperandsPostOrder(binOpExpr)))
- return failure();
- return visitAddExpr(binOpExpr);
- }
- case AffineExprKind::Mul: {
- const auto &binOpExpr = cast<AffineBinOpExpr>(expr);
- if (failed(walkOperandsPostOrder(binOpExpr)))
- return failure();
- return visitMulExpr(binOpExpr);
- }
- case AffineExprKind::Mod: {
- const auto &binOpExpr = cast<AffineBinOpExpr>(expr);
- if (failed(walkOperandsPostOrder(binOpExpr)))
- return failure();
- return visitModExpr(binOpExpr);
- }
- case AffineExprKind::FloorDiv: {
- const auto &binOpExpr = cast<AffineBinOpExpr>(expr);
- if (failed(walkOperandsPostOrder(binOpExpr)))
- return failure();
- return visitFloorDivExpr(binOpExpr);
- }
- case AffineExprKind::CeilDiv: {
- const auto &binOpExpr = cast<AffineBinOpExpr>(expr);
- if (failed(walkOperandsPostOrder(binOpExpr)))
- return failure();
- return visitCeilDivExpr(binOpExpr);
- }
- case AffineExprKind::Constant:
- return visitConstantExpr(cast<AffineConstantExpr>(expr));
- case AffineExprKind::DimId:
- return visitDimExpr(cast<AffineDimExpr>(expr));
- case AffineExprKind::SymbolId:
- return visitSymbolExpr(cast<AffineSymbolExpr>(expr));
- }
- llvm_unreachable("Unknown AffineExpr");
+ void addToRow(unsigned row, const CoefficientVector &l) {
+ flatMatrix.addToRow(row,
+ getDynamicAPIntVec(l.getPadded(info.getNumCols())));
}
-
-private:
- // Walk the operands - each operand is itself walked in post order.
- LogicalResult walkOperandsPostOrder(const AffineBinOpExpr &expr) {
- if (failed(walkPostOrder(*expr.getLHS())))
- return failure();
- if (failed(walkPostOrder(*expr.getRHS())))
- return failure();
- return success();
+ void setRow(unsigned row, const CoefficientVector &l) {
+ flatMatrix.setRow(row, getDynamicAPIntVec(l.getPadded(info.getNumCols())));
}
- /// Constructs an affine expression from a flat ArrayRef. If there are local
- /// identifiers (neither dimensional nor symbolic) that appear in the sum of
- /// products expression, `localExprs` is expected to have the AffineExpr
- /// for it, and is substituted into. The ArrayRef `flatExprs` is expected to
- /// be in the format [dims, symbols, locals, constant term].
- AffineExpr getAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
- unsigned numDims, unsigned numSymbols);
-
- // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
- // The local identifier added is always a floordiv of a pure add/mul affine
- // function of other identifiers, coefficients of which are specified in
- // dividend and with respect to a positive constant divisor. localExpr is the
- // simplified tree expression (AffineExpr) corresponding to the quantifier.
- void addLocalFloorDivId(ArrayRef<int64_t> dividend, int64_t divisor,
- AffineExpr &&localExpr);
-
- /// Add a local identifier (needed to flatten a mod, floordiv, ceildiv, mul
- /// expr) when the rhs is a symbolic expression. The local identifier added
- /// may be a floordiv, ceildiv, mul or mod of a pure affine/semi-affine
- /// function of other identifiers, coefficients of which are specified in the
- /// lhs of the mod, floordiv, ceildiv or mul expression and with respect to a
- /// symbolic rhs expression. `localExpr` is the simplified tree expression
- /// (AffineExpr) corresponding to the quantifier.
- void addLocalIdSemiAffine(AffineExpr &&localExpr);
-
- /// Adds `expr`, which may be mod, ceildiv, floordiv or mod expression
- /// representing the affine expression corresponding to the quantifier
- /// introduced as the local variable corresponding to `expr`. If the
- /// quantifier is already present, we put the coefficient in the proper index
- /// of `result`, otherwise we add a new local variable and put the coefficient
- /// there.
- void addLocalVariableSemiAffine(AffineExpr &&expr,
- SmallVectorImpl<int64_t> &result,
- unsigned long resultSize);
-
- // t = expr floordiv c <=> t = q, c * q <= expr <= c * q + c - 1
- // A floordiv is thus flattened by introducing a new local variable q, and
- // replacing that expression with 'q' while adding the constraints
- // c * q <= expr <= c * q + c - 1 to localVarCst (done by
- // IntegerRelation::addLocalFloorDiv).
- //
- // A ceildiv is similarly flattened:
- // t = expr ceildiv c <=> t = (expr + c - 1) floordiv c
- LogicalResult visitDivExpr(const AffineBinOpExpr &expr, bool isCeil);
-
- int findLocalId(const AffineExpr &localExpr);
-
- inline unsigned getNumCols() const {
- return numDims + numSymbols + numLocals + 1;
+  unsigned lookupLocal(size_t hash) {
+    const auto *it = find(localExprs, hash);
+    assert(it != localExprs.end() &&
+           "Local expression not found; walking from inner to outer?");
+    return info.getLocalVarStartIdx() + (it - localExprs.begin());
+  }
+  std::optional<CoefficientVector> lookupModExpansion(size_t hash) {
+    if (auto it = localModExpansion.find(hash); it != localModExpansion.end())
+      return it->second;
+    return std::nullopt;
}
- inline unsigned getConstantIndex() const { return getNumCols() - 1; }
- inline unsigned getLocalVarStartIndex() const { return numDims + numSymbols; }
- inline unsigned getSymbolStartIndex() const { return numDims; }
- inline unsigned getDimStartIndex() const { return 0; }
};
-
-// Flattener for AffineMap.
-LogicalResult
-getFlattenedAffineExprs(const AffineMap &map,
- std::vector<SmallVector<int64_t, 8>> &flattenedExprs,
- IntegerPolyhedron &cst);
-
-// Flattener for IntegerSet.
-LogicalResult
-getFlattenedAffineExprs(const IntegerSet &set,
- std::vector<SmallVector<int64_t, 8>> &flattenedExprs,
- IntegerPolyhedron &cst);
} // namespace mlir::presburger
#endif // MLIR_ANALYSIS_PRESBURGER_PARSER_FLATTENER_H
diff --git a/mlir/lib/Analysis/Presburger/Parser/ParseStructs.cpp b/mlir/lib/Analysis/Presburger/Parser/ParseStructs.cpp
index 681bb23db001d..b937ae27a8fea 100644
--- a/mlir/lib/Analysis/Presburger/Parser/ParseStructs.cpp
+++ b/mlir/lib/Analysis/Presburger/Parser/ParseStructs.cpp
@@ -13,104 +13,151 @@
#include "ParseStructs.h"
#include "llvm/ADT/Twine.h"
-#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir::presburger;
-using llvm::cast;
using llvm::dbgs;
-using llvm::isa;
-
-bool AffineExprImpl::isPureAffine() const {
- switch (getKind()) {
- case AffineExprKind::SymbolId:
- case AffineExprKind::DimId:
- case AffineExprKind::Constant:
- return true;
- case AffineExprKind::Add: {
- const auto &op = cast<AffineBinOpExpr>(*this);
- return op.getLHS()->isPureAffine() && op.getRHS()->isPureAffine();
- }
- case AffineExprKind::Mul: {
- const auto &op = cast<AffineBinOpExpr>(*this);
- return op.getLHS()->isPureAffine() && op.getRHS()->isPureAffine() &&
- (isa<AffineConstantExpr>(op.getLHS()) ||
- isa<AffineConstantExpr>(op.getRHS()));
+using llvm::divideFloorSigned;
+using llvm::mod;
+
+CoefficientVector PureAffineExprImpl::collectLinearTerms() const {
+  // Sum the nested terms' linear dividends, then add this node's own.
+  CoefficientVector sum(info);
+  for (const PureAffineExpr &nested : nestedDivTerms)
+    sum += nested->getLinearDividend();
+  sum += linearDividend;
+  return sum;
+}
+
+SmallVector<std::tuple<size_t, int64_t, CoefficientVector>, 8>
+PureAffineExprImpl::getNonLinearCoeffs() const {
+ SmallVector<std::tuple<size_t, int64_t, CoefficientVector>, 8> ret;
+ // Each entry is (hash of a div node, coefficient of its floordiv id q,
+ // linear term to add), q being the floordiv id the flattener adds for it:
+ //   mulFactor * (dividend floordiv divisor) == mulFactor * q
+ //   mulFactor * (dividend mod divisor) == mulFactor * dividend
+ //                                         - mulFactor * divisor * q
+ // NOTE(review): for mod, adjustedLinearTerm scales the stored linearDividend
+ // in place (`*=`) although this method is const; repeated calls compound.
+ auto adjustedMulFactor = [](const PureAffineExprImpl &div) {
+ return div.mulFactor * (div.kind == DivKind::Mod ? -div.divisor : 1);
+ };
+ auto adjustedLinearTerm = [](PureAffineExprImpl &div) {
+ return div.kind == DivKind::Mod ? div.linearDividend *= div.mulFactor
+ : CoefficientVector(div.info);
+ };
+ if (hasDivisor())
+ for (const auto &toplevel : nestedDivTerms)
+ for (const auto &div : toplevel->getNestedDivTerms())
+ ret.emplace_back(div->hash(), adjustedMulFactor(*div),
+ adjustedLinearTerm(*div));
+ else
+ for (const auto &div : nestedDivTerms)
+ ret.emplace_back(div->hash(), adjustedMulFactor(*div),
+ adjustedLinearTerm(*div));
+ return ret;
+}
+
+unsigned PureAffineExprImpl::countNestedDivs() const {
+  // This node counts once if it divides, plus every div beneath it.
+  unsigned total = hasDivisor() ? 1 : 0;
+  for (const PureAffineExpr &nested : nestedDivTerms)
+    total += nested->countNestedDivs();
+  return total;
+}
+
+PureAffineExpr mlir::presburger::operator+(PureAffineExpr &&lhs,
+                                           PureAffineExpr &&rhs) {
+  const bool lhsLinear = lhs->isLinear(), rhsLinear = rhs->isLinear();
+  if (lhsLinear && rhsLinear)
+    return std::make_unique<PureAffineExprImpl>(lhs->getLinearDividend() +
+                                                rhs->getLinearDividend());
+  if (lhsLinear || rhsLinear) {
+    PureAffineExpr &other = lhsLinear ? rhs : lhs;
+    const PureAffineExpr &linear = lhsLinear ? lhs : rhs;
+    return std::make_unique<PureAffineExprImpl>(
+        std::move(other->addLinearTerm(linear->getLinearDividend())));
+  }
+  if (lhs->hasDivisor() == rhs->hasDivisor()) {
+    PureAffineExprImpl container(lhs->info);
+    container.addDivTerm(std::move(*lhs));
+    container.addDivTerm(std::move(*rhs));
+    return std::make_unique<PureAffineExprImpl>(std::move(container));
}
- case AffineExprKind::FloorDiv:
- case AffineExprKind::CeilDiv:
- case AffineExprKind::Mod: {
- const auto &op = cast<AffineBinOpExpr>(*this);
- return op.getLHS()->isPureAffine() && isa<AffineConstantExpr>(op.getRHS());
+  if (lhs->hasDivisor()) {
+    rhs->addDivTerm(std::move(*lhs));
+    return std::move(rhs);
}
+  if (rhs->hasDivisor()) {
+    lhs->addDivTerm(std::move(*rhs));
+    return std::move(lhs);
}
- llvm_unreachable("Unknown AffineExpr");
+  llvm_unreachable("Malformed AffineExpr");
}
-bool AffineExprImpl::isSymbolicOrConstant() const {
- switch (getKind()) {
- case AffineExprKind::Constant:
- case AffineExprKind::SymbolId:
- return true;
- case AffineExprKind::DimId:
- return false;
- case AffineExprKind::Add:
- case AffineExprKind::Mul:
- case AffineExprKind::FloorDiv:
- case AffineExprKind::CeilDiv:
- case AffineExprKind::Mod: {
- const auto &expr = cast<AffineBinOpExpr>(*this);
- return expr.getLHS()->isSymbolicOrConstant() &&
- expr.getRHS()->isSymbolicOrConstant();
- }
- }
- llvm_unreachable("Unknown AffineExpr");
+PureAffineExpr mlir::presburger::operator*(PureAffineExpr &&expr, int64_t c) {
+  // Linear expressions scale their coefficients; others go via mulConstant().
+  if (!expr->isLinear())
+    return std::make_unique<PureAffineExprImpl>(
+        std::move(expr->mulConstant(c)));
+  return std::make_unique<PureAffineExprImpl>(expr->getLinearDividend() * c);
}
-// Simplify the mul to the extent required by usage and the flattener.
-static AffineExpr simplifyMul(AffineExpr &&lhs, AffineExpr &&rhs) {
- if (isa<AffineConstantExpr>(*lhs) && isa<AffineConstantExpr>(*rhs)) {
- auto lhsConst = cast<AffineConstantExpr>(*lhs);
- auto rhsConst = cast<AffineConstantExpr>(*rhs);
- return std::make_unique<AffineConstantExpr>(lhsConst.getValue() *
- rhsConst.getValue());
- }
+PureAffineExpr mlir::presburger::operator+(PureAffineExpr &&expr, int64_t c) {
+  return std::move(expr) + std::make_unique<PureAffineExprImpl>(expr->info, c);
+}
- if (!lhs->isSymbolicOrConstant() && !rhs->isSymbolicOrConstant())
- return nullptr;
+PureAffineExpr mlir::presburger::div(PureAffineExpr &&dividend, int64_t divisor,
+                                     DivKind kind) {
+  assert(divisor > 0 && "floordiv or mod with a non-positive divisor");
- // Canonicalize the mul expression so that the constant/symbolic term is the
- // RHS. If both the lhs and rhs are symbolic, swap them if the lhs is a
- // constant. (Note that a constant is trivially symbolic).
- if (!rhs->isSymbolicOrConstant() || isa<AffineConstantExpr>(lhs)) {
- // At least one of them has to be symbolic.
- return std::move(rhs) * std::move(lhs);
+  // Constant fold.
+  if (dividend->isConstant()) {
+    int64_t c = kind == DivKind::FloorDiv
+                    ? divideFloorSigned(dividend->getConstant(), divisor)
+                    : mod(dividend->getConstant(), divisor);
+    return std::make_unique<PureAffineExprImpl>(dividend->info, c);
}
- // At this point, if there was a constant, it would be on the right.
+  // Factor out the multiple g common to the whole dividend: the gcd of a
+  // linear dividend's coefficients, or |mulFactor| of a non-linear one
+  // (mulFactor scales the whole node and may be negative). Rescaling only the
+  // linear part of a non-linear dividend would change its value.
+  uint64_t exprMultiple =
+      dividend->isLinear()
+          ? dividend->getLinearDividend().factorMulFromLinearTerm()
+          : static_cast<uint64_t>(std::abs(dividend->getMulFactor()));
- // Multiplication with a one is a noop, return the other input.
- if (isa<AffineConstantExpr>(*rhs)) {
- auto rhsConst = cast<AffineConstantExpr>(*rhs);
- if (rhsConst.getValue() == 1)
- return lhs;
- // Multiplication with zero.
- if (rhsConst.getValue() == 0)
- return std::make_unique<AffineConstantExpr>(rhsConst);
+  // (g * a) floordiv (g * b) == a floordiv b, whereas
+  // (g * a) mod (g * b) == g * (a mod b): a reduced mod is rescaled by g.
+  uint64_t gcd = std::gcd(static_cast<uint64_t>(divisor), exprMultiple);
+  int64_t modScale = 1;
+  if (gcd > 1) {
+    if (dividend->isLinear())
+      dividend->linearDividend /= static_cast<int64_t>(gcd);
+    else
+      dividend->mulFactor /= static_cast<int64_t>(gcd);
+    divisor /= static_cast<int64_t>(gcd);
+    if (kind == DivKind::Mod)
+      modScale = static_cast<int64_t>(gcd);
}
- return nullptr;
-}
+  // x floordiv 1 <=> x and x mod 1 <=> 0, for the requested `kind`.
+  if (divisor == 1)
+    return kind == DivKind::Mod
+               ? std::make_unique<PureAffineExprImpl>(dividend->info)
+               : std::move(dividend);
-namespace mlir::presburger {
-AffineExpr operator*(AffineExpr &&s, AffineExpr &&o) {
- if (AffineExpr simpl = simplifyMul(std::move(s), std::move(o)))
- return simpl;
- return std::make_unique<AffineBinOpExpr>(std::move(s), std::move(o),
- AffineExprKind::Mul);
+  PureAffineExpr result =
+      dividend->isLinear()
+          ? std::make_unique<PureAffineExprImpl>(dividend->getLinearDividend(),
+                                                 divisor, kind)
+          : std::make_unique<PureAffineExprImpl>(std::move(*dividend),
+                                                 divisor, kind);
+  return modScale == 1 ? std::move(result) : std::move(result) * modScale;
}
-} // namespace mlir::presburger
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
enum class BindingStrength {
@@ -118,151 +165,104 @@ enum class BindingStrength {
Strong, // All other binary operators.
};
-static void printAffineExpr(const AffineExprImpl &expr,
- BindingStrength enclosingTightness) {
- const char *binopSpelling = nullptr;
- switch (expr.getKind()) {
- case AffineExprKind::SymbolId: {
- unsigned pos = cast<AffineSymbolExpr>(expr).getPosition();
- dbgs() << 's' << pos;
+static void printCoefficient(int64_t c, bool &isExprBegin,
+                             const ParseInfo &info, int idx = -1) {
+  bool isConstant = idx != -1 && info.isConstantIdx(idx);
+  bool isVariable = idx != -1 && !isConstant;
+  if (!c)
return;
+  if (!isExprBegin) {
+    dbgs() << (c < 0 ? " - " : " + ");
+  } else if (c < 0) {
+    dbgs() << "- ";
}
- case AffineExprKind::DimId: {
- unsigned pos = cast<AffineDimExpr>(expr).getPosition();
- dbgs() << 'd' << pos;
- return;
+  bool printMagnitude = std::abs(c) != 1 || isConstant;
+  if (printMagnitude)
+    dbgs() << std::abs(c);
+  if (isVariable) {
+    if (printMagnitude)
+      dbgs() << " * ";
+    if (info.isDimIdx(idx))
+      dbgs() << 'd' << idx;
+    else if (info.isSymbolIdx(idx))
+      dbgs() << 's' << idx - info.getSymbolStartIdx();
+    else
+      dbgs() << 'q' << idx - info.getLocalVarStartIdx();
}
- case AffineExprKind::Constant:
- dbgs() << cast<AffineConstantExpr>(expr).getValue();
- return;
- case AffineExprKind::Add:
- binopSpelling = " + ";
- break;
- case AffineExprKind::Mul:
- binopSpelling = " * ";
- break;
- case AffineExprKind::FloorDiv:
- binopSpelling = " floordiv ";
- break;
- case AffineExprKind::CeilDiv:
- binopSpelling = " ceildiv ";
- break;
- case AffineExprKind::Mod:
- binopSpelling = " mod ";
- break;
+  // Anything printed above means the next term needs a separator.
+  isExprBegin = isExprBegin && !(c < 0 || printMagnitude || isVariable);
+}
+
+static bool printCoefficientVec(
+    const CoefficientVector &linear, bool isExprBegin = true,
+    BindingStrength enclosingTightness = BindingStrength::Weak) {
+  bool wrap = enclosingTightness == BindingStrength::Strong &&
+              linear.hasMultipleCoefficients();
+  if (wrap) { // Several terms under a tightly-binding operator.
+    dbgs() << '(';
+    isExprBegin = true;
}
+  for (const auto &en : enumerate(linear.getCoefficients()))
+    printCoefficient(en.value(), isExprBegin, linear.info, en.index());
- const auto &binOp = cast<AffineBinOpExpr>(expr);
- const AffineExprImpl &lhsExpr = *binOp.getLHS();
- const AffineExprImpl &rhsExpr = *binOp.getRHS();
-
- // Handle tightly binding binary operators.
- if (binOp.getKind() != AffineExprKind::Add) {
- if (enclosingTightness == BindingStrength::Strong)
- dbgs() << '(';
-
- // Pretty print multiplication with -1.
- if (isa<AffineConstantExpr>(rhsExpr)) {
- const auto &rhsConst = cast<AffineConstantExpr>(rhsExpr);
- if (binOp.getKind() == AffineExprKind::Mul && rhsConst.getValue() == -1) {
- dbgs() << "-";
- printAffineExpr(lhsExpr, BindingStrength::Strong);
- if (enclosingTightness == BindingStrength::Strong)
- dbgs() << ')';
- return;
- }
- }
- printAffineExpr(lhsExpr, BindingStrength::Strong);
+  if (wrap)
+    dbgs() << ')';
+  return isExprBegin;
+}
- dbgs() << binopSpelling;
- printAffineExpr(rhsExpr, BindingStrength::Strong);
+static bool
+printAffineExpr(const PureAffineExprImpl &expr, bool isExprBegin = true,
+                BindingStrength enclosingTightness = BindingStrength::Weak) {
+  if (expr.isLinear())
+    return printCoefficientVec(expr.getLinearDividend(), isExprBegin,
+                               enclosingTightness);
- if (enclosingTightness == BindingStrength::Strong)
- dbgs() << ')';
- return;
- }
+  // Prints mulFactor * (linear + nested terms) [floordiv|% divisor]. Any
+  // factor other than 1, or a divisor, needs parentheses to keep the grouping.
+  int64_t mulFactor = expr.getMulFactor();
+  bool parenthesize = expr.hasDivisor() || mulFactor != 1 ||
+                      enclosingTightness == BindingStrength::Strong;
- // Print out special "pretty" forms for add.
- if (enclosingTightness == BindingStrength::Strong)
- dbgs() << '(';
+  printCoefficient(mulFactor, isExprBegin, expr.info);
+  if (std::abs(mulFactor) != 1)
+    dbgs() << " * ";
- // Pretty print addition to a product that has a negative operand as a
- // subtraction.
- if (isa<AffineBinOpExpr>(rhsExpr)) {
- const auto &rhs = cast<AffineBinOpExpr>(rhsExpr);
- if (rhs.getKind() == AffineExprKind::Mul) {
- const AffineExprImpl &rrhsExpr = *rhs.getRHS();
- if (isa<AffineConstantExpr>(rrhsExpr)) {
- const auto &rrhs = cast<AffineConstantExpr>(rrhsExpr);
- if (rrhs.getValue() == -1) {
- printAffineExpr(lhsExpr, BindingStrength::Weak);
- dbgs() << " - ";
- if (rhs.getLHS()->getKind() == AffineExprKind::Add) {
- printAffineExpr(*rhs.getLHS(), BindingStrength::Strong);
- } else {
- printAffineExpr(*rhs.getLHS(), BindingStrength::Weak);
- }
-
- if (enclosingTightness == BindingStrength::Strong)
- dbgs() << ')';
- return;
- }
-
- if (rrhs.getValue() < -1) {
- printAffineExpr(lhsExpr, BindingStrength::Weak);
- dbgs() << " - ";
- printAffineExpr(*rhs.getLHS(), BindingStrength::Strong);
- dbgs() << " * " << -rrhs.getValue();
- if (enclosingTightness == BindingStrength::Strong)
- dbgs() << ')';
- return;
- }
- }
- }
+  if (parenthesize) {
+    dbgs() << '(';
+    isExprBegin = true;
}
- // Pretty print addition to a negative number as a subtraction.
- if (isa<AffineConstantExpr>(rhsExpr)) {
- const auto &rhsConst = cast<AffineConstantExpr>(rhsExpr);
- if (rhsConst.getValue() < 0) {
- printAffineExpr(lhsExpr, BindingStrength::Weak);
- dbgs() << " - " << -rhsConst.getValue();
- if (enclosingTightness == BindingStrength::Strong)
- dbgs() << ')';
- return;
- }
- }
+  isExprBegin = printCoefficientVec(expr.getLinearDividend(), isExprBegin);
- printAffineExpr(lhsExpr, BindingStrength::Weak);
+  for (const PureAffineExpr &nested : expr.getNestedDivTerms())
+    isExprBegin =
+        printAffineExpr(*nested, isExprBegin, BindingStrength::Strong);
- dbgs() << " + ";
- printAffineExpr(rhsExpr, BindingStrength::Weak);
+  if (expr.hasDivisor())
+    dbgs() << (expr.isMod() ? " % " : " floordiv ") << expr.getDivisor();
- if (enclosingTightness == BindingStrength::Strong)
+  if (parenthesize)
dbgs() << ')';
+  return isExprBegin;
}
-LLVM_DUMP_METHOD void AffineExprImpl::dump() const {
- printAffineExpr(*this, BindingStrength::Weak);
+LLVM_DUMP_METHOD void CoefficientVector::dump() const {
+  printCoefficientVec(*this, /*isExprBegin=*/true, BindingStrength::Weak);
dbgs() << '\n';
}
-LLVM_DUMP_METHOD void AffineMap::dump() const {
- dbgs() << "NumDims = " << numDims << '\n';
- dbgs() << "NumSymbols = " << numSymbols << '\n';
- dbgs() << "Expressions:\n";
- for (const AffineExpr &e : getExprs())
- e->dump();
+LLVM_DUMP_METHOD void PureAffineExprImpl::dump() const {
+  printAffineExpr(*this, /*isExprBegin=*/true, BindingStrength::Weak);
+  dbgs() << '\n';
}
-LLVM_DUMP_METHOD void IntegerSet::dump() const {
- dbgs() << "NumDims = " << numDims << '\n';
- dbgs() << "NumSymbols = " << numSymbols << '\n';
- dbgs() << "Constraints:\n";
- for (const AffineExpr &c : getConstraints())
- c->dump();
- dbgs() << "EqFlags:\n";
- for (bool e : getEqFlags())
- dbgs() << e << '\n';
+LLVM_DUMP_METHOD void FinalParseResult::dump() const {
+  dbgs() << "Exprs:\n";
+  for (const PureAffineExpr &e : exprs)
+    e->dump();
+  dbgs() << "EqFlags: ";
+  for (bool isEq : eqFlags)
+    dbgs() << isEq << ' ';
+  dbgs() << '\n';
}
#endif
diff --git a/mlir/lib/Analysis/Presburger/Parser/ParseStructs.h b/mlir/lib/Analysis/Presburger/Parser/ParseStructs.h
index 8ec7fe2ff840a..d0ad899af5974 100644
--- a/mlir/lib/Analysis/Presburger/Parser/ParseStructs.h
+++ b/mlir/lib/Analysis/Presburger/Parser/ParseStructs.h
@@ -9,245 +9,299 @@
#ifndef MLIR_ANALYSIS_PRESBURGER_PARSER_PARSESTRUCTS_H
#define MLIR_ANALYSIS_PRESBURGER_PARSER_PARSESTRUCTS_H
+#include "mlir/Analysis/Presburger/IntegerRelation.h"
+#include "mlir/Analysis/Presburger/PresburgerSpace.h"
#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
-#include <cassert>
#include <cstdint>
-#include <memory>
namespace mlir::presburger {
using llvm::ArrayRef;
using llvm::SmallVector;
using llvm::SmallVectorImpl;
-enum class AffineExprKind {
- Add,
- /// RHS of mul is always a constant or a symbolic expression.
- Mul,
- /// RHS of mod is always a constant or a symbolic expression with a positive
- /// value.
- Mod,
- /// RHS of floordiv is always a constant or a symbolic expression.
- FloorDiv,
- /// RHS of ceildiv is always a constant or a symbolic expression.
- CeilDiv,
- /// This is a marker for the last affine binary op. The range of binary
- /// op's is expected to be this element and earlier.
- LAST_BINOP = CeilDiv,
- /// Constant integer.
- Constant,
- /// Dimensional identifier.
- DimId,
- /// Symbolic identifier.
- SymbolId,
-};
-
-struct AffineExprImpl {
- explicit AffineExprImpl(AffineExprKind kind) : kind(kind) {}
-
- // Delete all copy/move operators.
- AffineExprImpl(const AffineExprImpl &o) = delete;
- AffineExprImpl &operator=(const AffineExprImpl &o) = delete;
- AffineExprImpl(AffineExprImpl &&o) = delete;
- AffineExprImpl &operator=(AffineExprImpl &&o) = delete;
-
- AffineExprKind getKind() const { return kind; }
-
- /// Returns true if this expression is made out of only symbols and
- /// constants, i.e., it does not involve dimensional identifiers.
- bool isSymbolicOrConstant() const;
-
- /// Returns true if this is a pure affine expression, i.e., multiplication,
- /// floordiv, ceildiv, and mod is only allowed w.r.t constants.
- bool isPureAffine() const;
-
-#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
- LLVM_DUMP_METHOD void dump() const;
-#endif
-
- AffineExprKind kind;
-};
-
-// AffineExpr is a unique ptr, since there is a cycle is AffineBinaryOp.
-using AffineExpr = std::unique_ptr<AffineExprImpl>;
-
-struct AffineBinOpExpr : public AffineExprImpl {
- AffineBinOpExpr(AffineExpr &&lhs, AffineExpr &&rhs, AffineExprKind kind)
- : AffineExprImpl(kind), lhs(std::move(lhs)), rhs(std::move(rhs)) {}
-
- // Delete all copy/move operators.
- AffineBinOpExpr(const AffineBinOpExpr &o) = delete;
- AffineBinOpExpr &operator=(const AffineBinOpExpr &o) = delete;
- AffineBinOpExpr(AffineBinOpExpr &&o) = delete;
- AffineBinOpExpr &operator=(AffineBinOpExpr &&o) = delete;
-
- const AffineExpr &getLHS() const { return lhs; }
- const AffineExpr &getRHS() const { return rhs; }
- static bool classof(const AffineExprImpl *a) {
- return a->getKind() <= AffineExprKind::LAST_BINOP;
+/// This structure is central to the parser and flattener. It records the
+/// number of dims, symbols, parsed expressions, and divs (local variables).
+struct ParseInfo {
+  unsigned numDims = 0;
+  unsigned numSymbols = 0;
+  unsigned numExprs = 0;
+  unsigned numDivs = 0; // One local column per div; see getNumCols().
+  // Column layout: [dims | symbols | locals (divs) | constant].
+  constexpr unsigned getDimStartIdx() const { return 0; }
+  constexpr unsigned getSymbolStartIdx() const { return numDims; }
+  constexpr unsigned getLocalVarStartIdx() const {
+    return numDims + numSymbols;
}
+  constexpr unsigned getNumCols() const {
+    return numDims + numSymbols + numDivs + 1;
+  }
+  constexpr unsigned getConstantIdx() const { return getNumCols() - 1; }
- AffineExpr lhs;
- AffineExpr rhs;
-};
-
-/// A dimensional or symbolic identifier appearing in an affine expression.
-struct AffineDimExpr : public AffineExprImpl {
- AffineDimExpr(unsigned position)
- : AffineExprImpl(AffineExprKind::DimId), position(position) {}
-
- // Enable copy/move constructors; trivial.
- AffineDimExpr(const AffineDimExpr &o)
- : AffineExprImpl(AffineExprKind::DimId), position(o.position) {}
- AffineDimExpr(AffineDimExpr &&o)
- : AffineExprImpl(AffineExprKind::DimId), position(o.position) {}
- AffineDimExpr &operator=(const AffineDimExpr &o) = delete;
- AffineDimExpr &operator=(AffineDimExpr &&o) = delete;
-
- unsigned getPosition() const { return position; }
- static bool classof(const AffineExprImpl *a) {
- return a->getKind() == AffineExprKind::DimId;
+  constexpr bool isDimIdx(unsigned i) const { return i < getSymbolStartIdx(); }
+  constexpr bool isSymbolIdx(unsigned i) const {
+    return i >= getSymbolStartIdx() && i < getLocalVarStartIdx();
}
- bool operator==(const AffineDimExpr &o) const {
- return position == o.position;
+  constexpr bool isLocalVarIdx(unsigned i) const {
+    return i >= getLocalVarStartIdx() && i < getConstantIdx();
+  }
+  constexpr bool isConstantIdx(unsigned i) const {
+    return i == getConstantIdx();
}
-
- /// Position of this identifier in the argument list.
- unsigned position;
};
-/// A symbolic identifier appearing in an affine expression.
-struct AffineSymbolExpr : public AffineExprImpl {
- AffineSymbolExpr(unsigned position)
- : AffineExprImpl(AffineExprKind::SymbolId), position(position) {}
+/// Helper for storing coefficients in canonical form: dims followed by symbols,
+/// followed by locals, and finally the constant term.
+///
+/// (x, y)[a, b]: y * 91 + x + 3 * a + 7
+/// coefficients: [1, 91, 3, 0, 7]
+struct CoefficientVector {
+  ParseInfo info;
+  SmallVector<int64_t, 8> coefficients;
+
+  CoefficientVector(const ParseInfo &info, int64_t c = 0) : info(info) {
+    coefficients.resize(info.getNumCols());
+    coefficients[info.getConstantIdx()] = c;
+  }
- // Enable copy/move constructors; trivial.
- AffineSymbolExpr(const AffineSymbolExpr &o)
- : AffineExprImpl(AffineExprKind::SymbolId), position(o.position) {}
- AffineSymbolExpr(AffineSymbolExpr &&o)
- : AffineExprImpl(AffineExprKind::SymbolId), position(o.position) {}
- AffineSymbolExpr &operator=(const AffineSymbolExpr &o) = delete;
- AffineSymbolExpr &operator=(AffineSymbolExpr &&o) = delete;
+  // Copyable and movable; a moved-from vector is left empty (size 0).
+  CoefficientVector(const CoefficientVector &o) = default;
+  CoefficientVector &operator=(const CoefficientVector &o) = default;
+  CoefficientVector(CoefficientVector &&o)
+      : info(o.info), coefficients(std::move(o.coefficients)) {
+    o.coefficients.clear();
+  }
- unsigned getPosition() const { return position; }
- static bool classof(const AffineExprImpl *a) {
- return a->getKind() == AffineExprKind::SymbolId;
+  ArrayRef<int64_t> getCoefficients() const { return coefficients; }
+  constexpr int64_t getConstant() const {
+    return coefficients[info.getConstantIdx()];
}
- bool operator==(const AffineSymbolExpr &o) const {
- return position == o.position;
+  constexpr size_t size() const { return coefficients.size(); }
+  operator ArrayRef<int64_t>() const { return coefficients; }
+  void resize(size_t size) { coefficients.resize(size); }
+  operator bool() const { // True if any coefficient is non-zero.
+    return any_of(coefficients, [](int64_t c) { return c; });
}
-
- /// Position of this identifier in the argument list.
- unsigned position;
-};
-
-/// An integer constant appearing in affine expression.
-struct AffineConstantExpr : public AffineExprImpl {
- AffineConstantExpr(int64_t constant)
- : AffineExprImpl(AffineExprKind::Constant), constant(constant) {}
-
- // Enable copy/move constructors; trivial.
- AffineConstantExpr(const AffineConstantExpr &o)
- : AffineExprImpl(AffineExprKind::Constant), constant(o.constant) {}
- AffineConstantExpr(AffineConstantExpr &&o)
- : AffineExprImpl(AffineExprKind::Constant), constant(o.constant) {}
- AffineConstantExpr &operator=(const AffineConstantExpr &o) = delete;
- AffineConstantExpr &operator=(AffineConstantExpr &&o) = delete;
-
- int64_t getValue() const { return constant; }
- static bool classof(const AffineExprImpl *a) {
- return a->getKind() == AffineExprKind::Constant;
+  int64_t &operator[](unsigned i) {
+    assert(i < coefficients.size());
+    return coefficients[i];
}
- bool operator==(const AffineConstantExpr &o) const {
- return constant == o.constant;
+  int64_t &back() { return coefficients.back(); }
+  int64_t back() const { return coefficients.back(); }
+  void clear() {
+    for_each(coefficients, [](auto &coeff) { coeff = 0; });
}
- // The constant.
- int64_t constant;
-};
-
-struct AffineMap {
- unsigned numDims;
- unsigned numSymbols;
-
- // The affine expressions in the map.
- SmallVector<AffineExpr, 4> exprs;
-
- AffineMap(unsigned numDims, unsigned numSymbols,
- SmallVectorImpl<AffineExpr> &&exprs)
- : numDims(numDims), numSymbols(numSymbols), exprs(std::move(exprs)) {}
+  CoefficientVector &operator+=(const CoefficientVector &l) {
+    coefficients.resize(std::max(size(), l.size())); // Grow, never truncate.
+    for (auto [idx, c] : enumerate(l.getCoefficients()))
+      coefficients[idx] += c;
+    return *this;
+  }
+  CoefficientVector &operator*=(int64_t c) {
+    for_each(coefficients, [c](auto &coeff) { coeff *= c; });
+    return *this;
+  }
+  CoefficientVector &operator/=(int64_t c) {
+    assert(c && "Division by zero");
+    for_each(coefficients, [c](auto &coeff) { coeff /= c; });
+    return *this;
+  }
- // Non-copyable; only movable.
- AffineMap(const AffineMap &) = delete;
- AffineMap operator=(const AffineMap &) = delete;
- AffineMap(AffineMap &&o)
- : numDims(o.numDims), numSymbols(o.numSymbols),
- exprs(std::move(o.exprs)) {}
- AffineMap &operator=(AffineMap &&o) = delete;
+  CoefficientVector operator+(const CoefficientVector &l) const {
+    CoefficientVector ret(*this);
+    return ret += l;
+  }
+  CoefficientVector operator*(int64_t c) const {
+    CoefficientVector ret(*this);
+    return ret *= c;
+  }
+  CoefficientVector operator/(int64_t c) const {
+    CoefficientVector ret(*this);
+    return ret /= c;
+  }
- unsigned getNumDims() const { return numDims; }
- unsigned getNumSymbols() const { return numSymbols; }
- unsigned getNumInputs() const { return numDims + numSymbols; }
- unsigned getNumExprs() const { return exprs.size(); }
- ArrayRef<AffineExpr> getExprs() const { return exprs; }
+  bool isConstant() const {
+    return all_of(drop_end(coefficients), [](int64_t c) { return !c; });
+  }
+  CoefficientVector getPadded(size_t newSize) const {
+    assert(newSize >= size() &&
+           "Padding size should be greater than expr size");
+    CoefficientVector ret(info);
+    ret.resize(newSize);
+
+    // Start constructing the result by copying every coefficient except the
+    // constant.
+    for (const auto &[col, coeff] : enumerate(drop_end(coefficients)))
+      ret[col] = coeff;
+
+    // Put the constant at the end.
+    ret.back() = back();
+    return ret;
+  }
+  uint64_t factorMulFromLinearTerm() const {
+    uint64_t gcd = 1; // NOTE(review): gcd(1, x) == 1, so this always returns 1.
+    for (int64_t val : coefficients)
+      gcd = std::gcd(gcd, std::abs(val));
+    return gcd;
+  }
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  constexpr bool hasMultipleCoefficients() const {
+    return count_if(coefficients, [](auto &coeff) { return coeff; }) > 1;
+  }
LLVM_DUMP_METHOD void dump() const;
#endif
};
-struct IntegerSet {
- unsigned numDims;
- unsigned numSymbols;
-
- /// Array of affine constraints: a constraint is either an equality
- /// (affine_expr == 0) or an inequality (affine_expr >= 0).
- SmallVector<AffineExpr, 4> constraints;
+enum class DimOrSymbolKind { // Which identifier list a DimOrSymbolExpr indexes.
+ DimId,
+ Symbol,
+};
- // Bits to check whether a constraint is an equality or an inequality.
- SmallVector<bool, 4> eqFlags;
+using DimOrSymbolExpr = std::pair<DimOrSymbolKind, unsigned>; // (kind, position)
+enum class DivKind { FloorDiv, Mod }; // ceildiv has no kind: it is pre-reduced.
+
+/// A pure affine expression tree. Each node denotes
+///   mulFactor * ((linearDividend + sum(nestedDivTerms)) div divisor),
+/// where div is floordiv or mod per `kind` (ceildiv is pre-reduced) and a
+/// node with divisor == 1 performs no division. A linear expression is a node
+/// with divisor == 1 and no nestedDivTerms. Divisor-less nodes are kept at
+/// mulFactor == 1 (see distributeMulFactor()). E.g. 3 + 2 * (x floordiv 5) is
+/// a divisor-less node with linearDividend 3 and one nested term with
+/// mulFactor 2, linearDividend x and divisor 5.
+struct PureAffineExprImpl {
+  ParseInfo info;
+  DivKind kind = DivKind::FloorDiv;
+  using PureAffineExpr = std::unique_ptr<PureAffineExprImpl>;
+
+  int64_t mulFactor = 1;
+  CoefficientVector linearDividend;
+  int64_t divisor = 1;
+  SmallVector<PureAffineExpr, 4> nestedDivTerms;
+
+  PureAffineExprImpl(const ParseInfo &info, int64_t c = 0)
+      : info(info), linearDividend(info, c) {}
+  PureAffineExprImpl(const ParseInfo &info, DimOrSymbolExpr idExpr)
+      : PureAffineExprImpl(info) {
+    auto [kind, pos] = idExpr;
+    unsigned startIdx = kind == DimOrSymbolKind::Symbol
+                            ? info.getSymbolStartIdx()
+                            : info.getDimStartIdx();
+    linearDividend.coefficients[startIdx + pos] = 1;
+  }
+  PureAffineExprImpl(const CoefficientVector &linearDividend,
+                     int64_t divisor = 1, DivKind kind = DivKind::FloorDiv)
+      : info(linearDividend.info), kind(kind), linearDividend(linearDividend),
+        divisor(divisor) {}
+  PureAffineExprImpl(PureAffineExprImpl &&div, int64_t divisor, DivKind kind)
+      : info(div.info), kind(kind), linearDividend(div.info), divisor(divisor) {
+    addDivTerm(std::move(div));
+  }
- IntegerSet(unsigned numDims, unsigned numSymbols,
- SmallVectorImpl<AffineExpr> &&constraints,
- SmallVectorImpl<bool> &&eqFlags)
- : numDims(numDims), numSymbols(numSymbols),
- constraints(std::move(constraints)), eqFlags(std::move(eqFlags)) {
- assert(constraints.size() == eqFlags.size());
+  // Non-copyable, only movable. The moved-from node is reset to a plain
+  // floordiv-kind node with no divisor; addLinearTerm() keeps using it.
+  PureAffineExprImpl(const PureAffineExprImpl &) = delete;
+  PureAffineExprImpl(PureAffineExprImpl &&o)
+      : info(o.info), kind(o.kind), mulFactor(o.mulFactor),
+        linearDividend(std::move(o.linearDividend)), divisor(o.divisor),
+        nestedDivTerms(std::move(o.nestedDivTerms)) {
+    o.nestedDivTerms.clear();
+    o.kind = DivKind::FloorDiv;
+    o.divisor = o.mulFactor = 1;
}
- // Non-copyable; only movable.
- IntegerSet(const IntegerSet &o) = delete;
- IntegerSet &operator=(const IntegerSet &o) = delete;
- IntegerSet(IntegerSet &&o)
- : numDims(o.numDims), numSymbols(o.numSymbols),
- constraints(std::move(o.constraints)), eqFlags(std::move(o.eqFlags)) {}
- IntegerSet &operator=(IntegerSet &&o) = delete;
+  const CoefficientVector &getLinearDividend() const { return linearDividend; }
+  CoefficientVector collectLinearTerms() const;
+  SmallVector<std::tuple<size_t, int64_t, CoefficientVector>, 8>
+  getNonLinearCoeffs() const;
- IntegerSet(unsigned dimCount, unsigned symbolCount, AffineExpr &&constraint,
- bool eqFlag)
- : numDims(dimCount), numSymbols(symbolCount) {
- constraints.emplace_back(std::move(constraint));
- eqFlags.emplace_back(eqFlag);
+  constexpr bool isMod() const { return kind == DivKind::Mod; }
+  constexpr bool hasDivisor() const { return divisor != 1; }
+  constexpr bool isLinear() const {
+    return nestedDivTerms.empty() && !hasDivisor();
+  }
+  constexpr int64_t getDivisor() const { return divisor; }
+  constexpr int64_t getMulFactor() const { return mulFactor; }
+  /// True only for linear expressions whose variable coefficients are all 0.
+  constexpr bool isConstant() const {
+    return isLinear() && getLinearDividend().isConstant();
+  }
+  constexpr int64_t getConstant() const {
+    return getLinearDividend().getConstant();
}
+  /// Identity hash (node address); moving contents to a new node changes it.
+  constexpr size_t hash() const {
+    return std::hash<const PureAffineExprImpl *>{}(this);
+  }
+  ArrayRef<PureAffineExpr> getNestedDivTerms() const { return nestedDivTerms; }
- unsigned getNumDims() const { return numDims; }
- unsigned getNumSymbols() const { return numSymbols; }
- unsigned getNumInputs() const { return numDims + numSymbols; }
- ArrayRef<AffineExpr> getConstraints() const { return constraints; }
- unsigned getNumConstraints() const { return constraints.size(); }
- ArrayRef<bool> getEqFlags() const { return eqFlags; }
- bool isEq(unsigned idx) const { return eqFlags[idx]; };
+  PureAffineExprImpl &mulConstant(int64_t c) {
+    mulFactor *= c;
+    distributeMulFactor();
+    return *this;
+  }
+  /// Adds `l`. A node with a divisor is first moved into a nested term so
+  /// that `l` is added outside of the division.
+  PureAffineExprImpl &addLinearTerm(const CoefficientVector &l) {
+    if (hasDivisor())
+      nestedDivTerms.emplace_back(
+          std::make_unique<PureAffineExprImpl>(std::move(*this)));
+    linearDividend += l;
+    return *this;
+  }
+  PureAffineExprImpl &addDivTerm(PureAffineExprImpl &&d) {
+    nestedDivTerms.emplace_back(
+        std::make_unique<PureAffineExprImpl>(std::move(d)));
+    return *this;
+  }
+  PureAffineExprImpl &divLinearDividend(int64_t c) {
+    linearDividend /= c;
+    return *this;
+  }
- unsigned getNumEqualities() const {
- unsigned numEqualities = 0;
- for (unsigned i = 0, e = getNumConstraints(); i < e; i++)
- if (isEq(i))
- ++numEqualities;
- return numEqualities;
+  /// Keeps divisor-less nodes at mulFactor == 1 by pushing the factor into the
+  /// linear part and nested terms; a divisor node's factor scales its quotient.
+  void distributeMulFactor() {
+    if (mulFactor == 1 || hasDivisor())
+      return;
+    linearDividend *= mulFactor;
+    for (const PureAffineExpr &nested : nestedDivTerms)
+      nested->mulConstant(mulFactor);
+    mulFactor = 1;
}
+  unsigned countNestedDivs() const;
- unsigned getNumInequalities() const {
- return getNumConstraints() - getNumEqualities();
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  LLVM_DUMP_METHOD void dump() const;
+#endif
+};
+
+using PureAffineExpr = std::unique_ptr<PureAffineExprImpl>;
+
+/// The final result of parsing: the expressions, their equality flags, and an
+/// initially empty constraint system `cst`. The constructor also totals the
+/// divs over all expressions into info.numDivs; that fixes the total number
+/// of columns (see ParseInfo::getNumCols()), which the Flattener uses to
+/// construct its IntMatrix.
+struct FinalParseResult {
+  ParseInfo info;
+  SmallVector<PureAffineExpr, 8> exprs;
+  SmallVector<bool, 8> eqFlags;
+  IntegerPolyhedron cst;
+
+  FinalParseResult(const ParseInfo &parseInfo,
+                   SmallVectorImpl<PureAffineExpr> &&exprStack,
+                   ArrayRef<bool> eqFlagStack)
+      : info(parseInfo), exprs(std::move(exprStack)), eqFlags(eqFlagStack),
+        cst(0, 0, info.numDims + info.numSymbols + 1,
+            PresburgerSpace::getSetSpace(info.numDims, info.numSymbols, 0)) {
+    info.numExprs = exprs.size();
+    unsigned divs = 0;
+    for (const PureAffineExpr &expr : exprs)
+      divs += expr->countNestedDivs();
+    info.numDivs = divs;
}
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
@@ -255,38 +309,44 @@ struct IntegerSet {
#endif
};
-// Convenience operators.
-AffineExpr operator*(AffineExpr &&s, AffineExpr &&o);
-inline AffineExpr operator*(AffineExpr &&s, int64_t o) {
- return std::move(s) * std::make_unique<AffineConstantExpr>(o);
+// The main interface to the parser is a set of operators, which are used to
+// successively build the final PureAffineExpr.
+PureAffineExpr operator+(PureAffineExpr &&lhs, PureAffineExpr &&rhs);
+PureAffineExpr operator*(PureAffineExpr &&expr, int64_t c);
+PureAffineExpr operator+(PureAffineExpr &&expr, int64_t c);
+/// Build `dividend` divided by the constant `divisor`, as a div of `kind`.
+PureAffineExpr div(PureAffineExpr &&dividend, int64_t divisor, DivKind kind);
+/// expr floordiv c.
+inline PureAffineExpr floordiv(PureAffineExpr &&expr, int64_t c) {
+ return div(std::move(expr), c, DivKind::FloorDiv);
}
-inline AffineExpr operator*(int64_t s, AffineExpr &&o) {
- return std::move(o) * s;
+/// expr mod c, built as a div of kind DivKind::Mod.
+inline PureAffineExpr operator%(PureAffineExpr &&expr, int64_t c) {
+ return div(std::move(expr), c, DivKind::Mod);
}
-inline AffineExpr operator+(AffineExpr &&s, AffineExpr &&o) {
- return std::make_unique<AffineBinOpExpr>(std::move(s), std::move(o),
- AffineExprKind::Add);
+/// c * expr, forwarded to expr * c.
+inline PureAffineExpr operator*(int64_t c, PureAffineExpr &&expr) {
+ return std::move(expr) * c;
}
-inline AffineExpr operator+(AffineExpr &&s, int64_t o) {
- return std::move(s) + std::make_unique<AffineConstantExpr>(o);
+/// lhs - rhs, expressed as lhs + (rhs * -1).
+inline PureAffineExpr operator-(PureAffineExpr &&lhs, PureAffineExpr &&rhs) {
+ return std::move(lhs) + std::move(rhs) * -1;
}
-inline AffineExpr operator+(int64_t s, AffineExpr &&o) {
- return std::move(o) + s;
+/// c + expr, forwarded to expr + c.
+inline PureAffineExpr operator+(int64_t c, PureAffineExpr &&expr) {
+ return std::move(expr) + c;
}
-inline AffineExpr operator-(AffineExpr &&s, AffineExpr &&o) {
- return std::move(s) + std::move(o) * -1;
+/// expr - c, expressed as expr + (-c).
+inline PureAffineExpr operator-(PureAffineExpr &&expr, int64_t c) {
+ return std::move(expr) + (-c);
}
-inline AffineExpr operator%(AffineExpr &&s, AffineExpr &&o) {
- return std::make_unique<AffineBinOpExpr>(std::move(s), std::move(o),
- AffineExprKind::Mod);
+/// c - expr, expressed as (-1 * expr) + c.
+inline PureAffineExpr operator-(int64_t c, PureAffineExpr &&expr) {
+ return -1 * std::move(expr) + c;
}
-inline AffineExpr ceilDiv(AffineExpr &&s, AffineExpr &&o) {
- return std::make_unique<AffineBinOpExpr>(std::move(s), std::move(o),
- AffineExprKind::CeilDiv);
+/// expr ceildiv c, rewritten in terms of floordiv.
+/// NOTE(review): the identity below only holds for c > 0; callers are
+/// expected to reject non-positive divisors first. Confirm.
+inline PureAffineExpr ceildiv(PureAffineExpr &&expr, int64_t c) {
+ // expr ceildiv c <=> (expr + c - 1) floordiv c
+ return floordiv(std::move(expr) + (c - 1), c);
}
-inline AffineExpr floorDiv(AffineExpr &&s, AffineExpr &&o) {
- return std::make_unique<AffineBinOpExpr>(std::move(s), std::move(o),
- AffineExprKind::FloorDiv);
+
+// Our final canonical expression, the outermost div, should have a divisor
+// of 1.
+//
+// A null expr (how the parser signals a failed parse) is passed through
+// unchanged, so callers may canonicalize a parse result and check it for null
+// afterwards.
+inline PureAffineExpr canonicalize(PureAffineExpr &&expr) {
+ if (expr && expr->hasDivisor())
+ expr->addLinearTerm(CoefficientVector(expr->info));
+ // `expr` is a named rvalue reference, i.e. an lvalue, and unique_ptr cannot
+ // be copied; C++17 does not guarantee an implicit move here, so move
+ // explicitly.
+ return std::move(expr);
}
} // namespace mlir::presburger
diff --git a/mlir/lib/Analysis/Presburger/Parser/ParserImpl.cpp b/mlir/lib/Analysis/Presburger/Parser/ParserImpl.cpp
index 43e383cba1a0a..776c228eba6d9 100644
--- a/mlir/lib/Analysis/Presburger/Parser/ParserImpl.cpp
+++ b/mlir/lib/Analysis/Presburger/Parser/ParserImpl.cpp
@@ -13,97 +13,82 @@
#include "ParserImpl.h"
#include "Flattener.h"
#include "ParseStructs.h"
-#include "ParserState.h"
#include "mlir/Analysis/Presburger/Parser.h"
using namespace mlir;
using namespace presburger;
using llvm::MemoryBuffer;
-using llvm::SmallVector;
using llvm::SourceMgr;
-//===----------------------------------------------------------------------===//
-// Parser core
-//===----------------------------------------------------------------------===//
-/// Consume the specified token if present and return success. On failure,
-/// output a diagnostic and return failure.
-ParseResult ParserImpl::parseToken(Token::Kind expectedToken,
- const Twine &message) {
+// Identifiers are `bare_identifier` and `inttype` tokens (the only non-keyword
+// tokens that can name an identifier) plus any keyword, so that keywords such
+// as `floordiv` can still be used as dim or symbol names.
+static bool isIdentifier(const Token &token) {
+ return token.isAny(Token::bare_identifier, Token::inttype) ||
+ token.isKeyword();
+}
+
+/// Consume the specified token if present and return true. On failure, emit
+/// `message` as a diagnostic and return false.
+bool ParserImpl::parseToken(Token::Kind expectedToken, const Twine &message) {
if (consumeIf(expectedToken))
- return success();
- return emitWrongTokenError(message);
+ return true;
+ return emitError(message);
}
/// Parse a list of comma-separated items with an optional delimiter. If a
/// delimiter is provided, then an empty list is allowed. If not, then at
/// least one element will be parsed.
-ParseResult
-ParserImpl::parseCommaSeparatedList(Delimiter delimiter,
- function_ref<ParseResult()> parseElementFn,
- StringRef contextMessage) {
+/// Returns true on success and false on failure.
+///
+/// NOTE(review): "Sepeated" is a misspelling of "Separated"; the rename has to
+/// be made together with the declaration in ParserImpl.h and all callers.
+bool ParserImpl::parseCommaSepeatedList(Delimiter delimiter,
+ function_ref<bool()> parseElementFn,
+ StringRef contextMessage) {
switch (delimiter) {
case Delimiter::None:
break;
case Delimiter::OptionalParen:
if (getToken().isNot(Token::l_paren))
- return success();
+ return true;
[[fallthrough]];
case Delimiter::Paren:
- if (parseToken(Token::l_paren, "expected '('" + contextMessage))
- return failure();
+ if (!parseToken(Token::l_paren, "expected '('" + contextMessage))
+ return false;
// Check for empty list.
if (consumeIf(Token::r_paren))
- return success();
+ return true;
break;
case Delimiter::OptionalLessGreater:
// Check for absent list.
if (getToken().isNot(Token::less))
- return success();
+ return true;
[[fallthrough]];
case Delimiter::LessGreater:
- if (parseToken(Token::less, "expected '<'" + contextMessage))
- return success();
+ // A missing '<' is a parse error, exactly as for the other delimiters.
+ if (!parseToken(Token::less, "expected '<'" + contextMessage))
+ return false;
// Check for empty list.
if (consumeIf(Token::greater))
- return success();
+ return true;
break;
case Delimiter::OptionalSquare:
if (getToken().isNot(Token::l_square))
- return success();
+ return true;
[[fallthrough]];
case Delimiter::Square:
- if (parseToken(Token::l_square, "expected '['" + contextMessage))
- return failure();
+ if (!parseToken(Token::l_square, "expected '['" + contextMessage))
+ return false;
// Check for empty list.
if (consumeIf(Token::r_square))
- return success();
+ return true;
break;
case Delimiter::OptionalBraces:
if (getToken().isNot(Token::l_brace))
- return success();
+ return true;
[[fallthrough]];
case Delimiter::Braces:
- if (parseToken(Token::l_brace, "expected '{'" + contextMessage))
- return failure();
+ if (!parseToken(Token::l_brace, "expected '{'" + contextMessage))
+ return false;
// Check for empty list.
if (consumeIf(Token::r_brace))
- return success();
+ return true;
break;
}
// Non-empty case starts with an element.
- if (parseElementFn())
- return failure();
+ if (!parseElementFn())
+ return false;
// Otherwise we have a list of comma separated elements.
- while (consumeIf(Token::comma)) {
- if (parseElementFn())
- return failure();
- }
+ while (consumeIf(Token::comma))
+ if (!parseElementFn())
+ return false;
switch (delimiter) {
case Delimiter::None:
- return success();
+ return true;
case Delimiter::OptionalParen:
case Delimiter::Paren:
return parseToken(Token::r_paren, "expected ')'" + contextMessage);
@@ -120,18 +105,15 @@ ParserImpl::parseCommaSeparatedList(Delimiter delimiter,
llvm_unreachable("Unknown delimiter");
}
-//===----------------------------------------------------------------------===//
-// Parse error emitters
-//===----------------------------------------------------------------------===//
-ParseResult ParserImpl::emitError(SMLoc loc, const Twine &message) {
+/// Report `message` as an error at `loc` (unless the current token is a lexer
+/// error, which the lexer has already reported) and return false, so failure
+/// paths can be written as `return emitError(...)`.
+bool ParserImpl::emitError(SMLoc loc, const Twine &message) {
// If we hit a parse error in response to a lexer error, then the lexer
// already reported the error.
- if (!getToken().is(Token::error))
- state.sourceMgr.PrintMessage(loc, SourceMgr::DK_Error, message);
- return failure();
+ if (getToken().isNot(Token::error))
+ sourceMgr.PrintMessage(loc, SourceMgr::DK_Error, message);
+ return false;
}
-ParseResult ParserImpl::emitError(const Twine &message) {
+/// Report `message` at the current token's location; when the current token
+/// is EOF the location is moved back one character. Returns false.
+bool ParserImpl::emitError(const Twine &message) {
SMLoc loc = state.curToken.getLoc();
if (state.curToken.isNot(Token::eof))
return emitError(loc, message);
@@ -140,97 +122,36 @@ ParseResult ParserImpl::emitError(const Twine &message) {
return emitError(SMLoc::getFromPointer(loc.getPointer() - 1), message);
}
-/// Emit an error about a "wrong token". If the current token is at the
-/// start of a source line, this will apply heuristics to back up and report
-/// the error at the end of the previous line, which is where the expected
-/// token is supposed to be.
-ParseResult ParserImpl::emitWrongTokenError(const Twine &message) {
- SMLoc loc = state.curToken.getLoc();
-
- // If the error is to be emitted at EOF, move it back one character.
- if (state.curToken.is(Token::eof))
- loc = SMLoc::getFromPointer(loc.getPointer() - 1);
-
- // This is the location we were originally asked to report the error at.
- SMLoc originalLoc = loc;
-
- // Determine if the token is at the start of the current line.
- const char *bufferStart = state.lex.getBufferBegin();
- const char *curPtr = loc.getPointer();
-
- // Use this StringRef to keep track of what we are going to back up through,
- // it provides nicer string search functions etc.
- StringRef startOfBuffer(bufferStart, curPtr - bufferStart);
-
- // Back up over entirely blank lines.
- while (true) {
- // Back up until we see a \n, but don't look past the buffer start.
- startOfBuffer = startOfBuffer.rtrim(" \t");
-
- // For tokens with no preceding source line, just emit at the original
- // location.
- if (startOfBuffer.empty())
- return emitError(originalLoc, message);
-
- // If we found something that isn't the end of line, then we're done.
- if (startOfBuffer.back() != '\n' && startOfBuffer.back() != '\r')
- return emitError(SMLoc::getFromPointer(startOfBuffer.end()), message);
-
- // Drop the \n so we emit the diagnostic at the end of the line.
- startOfBuffer = startOfBuffer.drop_back();
- }
-}
-
-//===----------------------------------------------------------------------===//
-// Affine Expression Parser
-//===----------------------------------------------------------------------===//
-static bool isIdentifier(const Token &token) {
- // We include only `inttype` and `bare_identifier` here since they are the
- // only non-keyword tokens that can be used to represent an identifier.
- return token.isAny(Token::bare_identifier, Token::inttype) ||
- token.isKeyword();
-}
-
/// Parse a bare id that may appear in an affine expression.
///
/// affine-expr ::= bare-id
+///
+/// Returns nullptr, after emitting a diagnostic, if the current token is not
+/// an identifier or does not name a declared dim or symbol.
-AffineExpr ParserImpl::parseBareIdExpr() {
- if (!isIdentifier(getToken())) {
- std::ignore = emitWrongTokenError("expected bare identifier");
- return nullptr;
- }
+PureAffineExpr ParserImpl::parseBareIdExpr() {
+ if (!isIdentifier(getToken()))
+ return emitError("expected bare identifier"), nullptr;
StringRef sRef = getTokenSpelling();
+ // Look the name up among those recorded by parseIdentifierDefinition.
for (const auto &entry : dimsAndSymbols) {
if (entry.first == sRef) {
consumeToken();
- // Since every DimExpr or SymbolExpr is used more than once, construct a
- // fresh unique_ptr every time we encounter it in the dimsAndSymbols list.
- if (std::holds_alternative<AffineDimExpr>(entry.second))
- return std::make_unique<AffineDimExpr>(
- std::get<AffineDimExpr>(entry.second));
- return std::make_unique<AffineSymbolExpr>(
- std::get<AffineSymbolExpr>(entry.second));
+ // An identifier may be used many times, so construct a fresh owning
+ // expression from the recorded dim/symbol on every use.
+ return std::make_unique<PureAffineExprImpl>(info, entry.second);
}
}
- std::ignore = emitWrongTokenError("use of undeclared identifier");
- return nullptr;
+ return emitError("use of undeclared identifier"), nullptr;
}
/// Parse an affine expression inside parentheses.
///
/// affine-expr ::= `(` affine-expr `)`
+///
+/// Returns nullptr on malformed input.
-AffineExpr ParserImpl::parseParentheticalExpr() {
- if (parseToken(Token::l_paren, "expected '('"))
+PureAffineExpr ParserImpl::parseParentheticalExpr() {
+ if (!parseToken(Token::l_paren, "expected '('"))
return nullptr;
+ // `()` is rejected: an empty pair of parentheses is not an expression.
if (getToken().is(Token::r_paren)) {
- std::ignore = emitError("no expression inside parentheses");
- return nullptr;
+ return emitError("no expression inside parentheses"), nullptr;
}
- AffineExpr expr = parseAffineExpr();
- if (!expr || parseToken(Token::r_paren, "expected ')'"))
+ PureAffineExpr expr = parseAffineExpr();
+ if (!expr || !parseToken(Token::r_paren, "expected ')'"))
return nullptr;
return expr;
@@ -239,19 +160,18 @@ AffineExpr ParserImpl::parseParentheticalExpr() {
/// Parse the negation expression.
///
/// affine-expr ::= `-` affine-expr
+///
+/// Returns `-1 * operand`, or nullptr after emitting a diagnostic. `lhs` is
+/// forwarded unchanged to parseAffineOperandExpr.
-AffineExpr ParserImpl::parseNegateExpression(const AffineExpr &lhs) {
- if (parseToken(Token::minus, "expected '-'"))
+PureAffineExpr ParserImpl::parseNegateExpression(const PureAffineExpr &lhs) {
+ if (!parseToken(Token::minus, "expected '-'"))
return nullptr;
- AffineExpr operand = parseAffineOperandExpr(lhs);
+ PureAffineExpr operand = parseAffineOperandExpr(lhs);
// Since negation has the highest precedence of all ops (including high
// precedence ops) but lower than parentheses, we are only going to use
// parseAffineOperandExpr instead of parseAffineExpr here.
if (!operand) {
// Extra error message although parseAffineOperandExpr would have
// complained. Leads to a better diagnostic.
- std::ignore = emitError("missing operand of negation");
- return nullptr;
+ return emitError("missing operand of negation"), nullptr;
}
return -1 * std::move(operand);
}
@@ -259,15 +179,16 @@ AffineExpr ParserImpl::parseNegateExpression(const AffineExpr &lhs) {
/// Parse a positive integral constant appearing in an affine expression.
///
/// affine-expr ::= integer-literal
+///
+/// Returns nullptr, after emitting a diagnostic, if the literal cannot be
+/// parsed or does not fit in int64_t.
-AffineExpr ParserImpl::parseIntegerExpr() {
+PureAffineExpr ParserImpl::parseIntegerExpr() {
std::optional<uint64_t> val = getToken().getUInt64IntegerValue();
- if (!val.has_value() || (int64_t)*val < 0) {
- std::ignore = emitError("constant too large for index");
- return nullptr;
- }
+ if (!val)
+ return emitError("failed to parse constant"), nullptr;
+ int64_t ret = static_cast<int64_t>(*val);
+ // A value above INT64_MAX wraps to a negative int64_t in the cast above;
+ // reject it.
+ if (ret < 0)
+ return emitError("constant too large"), nullptr;
consumeToken(Token::integer);
- return std::make_unique<AffineConstantExpr>((int64_t)*val);
+ return std::make_unique<PureAffineExprImpl>(info, ret);
}
/// Parses an expression that can be a valid operand of an affine expression.
@@ -279,7 +200,7 @@ AffineExpr ParserImpl::parseIntegerExpr() {
// operand expression, it's an op expression and will be parsed via
// parseAffineHighPrecOpExpression(). However, for i + (j*k) + -l, (j*k) and
// -l are valid operands that will be parsed by this function.
-AffineExpr ParserImpl::parseAffineOperandExpr(const AffineExpr &lhs) {
+PureAffineExpr ParserImpl::parseAffineOperandExpr(const PureAffineExpr &lhs) {
switch (getToken().getKind()) {
case Token::integer:
return parseIntegerExpr();
@@ -290,67 +211,63 @@ AffineExpr ParserImpl::parseAffineOperandExpr(const AffineExpr &lhs) {
case Token::kw_ceildiv:
case Token::kw_floordiv:
case Token::kw_mod:
// Try to treat these tokens as identifiers.
return parseBareIdExpr();
case Token::plus:
case Token::star:
+ // A binary operator cannot begin an operand; `lhs` tells us which side of
+ // the operator is missing.
if (lhs)
- std::ignore = emitError("missing right operand of binary operator");
+ emitError("missing right operand of binary operator");
else
- std::ignore = emitError("missing left operand of binary operator");
+ emitError("missing left operand of binary operator");
return nullptr;
default:
// If nothing matches, we try to treat this token as an identifier.
if (isIdentifier(getToken()))
return parseBareIdExpr();
if (lhs)
- std::ignore = emitError("missing right operand of binary operator");
+ emitError("missing right operand of binary operator");
else
- std::ignore = emitError("expected affine expression");
+ emitError("expected affine expression");
return nullptr;
}
}
/// Create an affine binary high precedence op expression (mul's, div's, mod).
/// opLoc is the location of the op token to be used to report errors
/// for non-conforming expressions.
+/// Returns nullptr, after emitting a diagnostic, if the expression would not be
+/// affine or a divisor is not a positive constant.
-AffineExpr ParserImpl::getAffineBinaryOpExpr(AffineHighPrecOp op,
- AffineExpr &&lhs, AffineExpr &&rhs,
- SMLoc opLoc) {
+PureAffineExpr ParserImpl::getAffineBinaryOpExpr(AffineHighPrecOp op,
+ PureAffineExpr &&lhs,
+ PureAffineExpr &&rhs,
+ SMLoc opLoc) {
+ // The grammar only admits positive integer divisors (see parseAffineExpr).
+ // Rejecting anything else here keeps a zero divisor from reaching the div
+ // construction below. Only called once rhs is known to be a constant.
+ auto checkPositiveDivisor = [&](StringRef opName) -> bool {
+ if (rhs->getConstant() > 0)
+ return true;
+ return emitError(opLoc, "right operand of " + opName +
+ " has to be a positive constant");
+ };
switch (op) {
case Mul:
- if (!lhs->isSymbolicOrConstant() && !rhs->isSymbolicOrConstant()) {
- std::ignore = emitError(
- opLoc, "non-affine expression: at least one of the multiply "
- "operands has to be either a constant or symbolic");
- return nullptr;
+ if (!lhs->isConstant() && !rhs->isConstant()) {
+ return emitError(opLoc,
+ "non-affine expression: at least one of the multiply "
+ "operands has to be a constant"),
+ nullptr;
}
- return std::move(lhs) * std::move(rhs);
+ if (rhs->isConstant())
+ return std::move(lhs) * rhs->getConstant();
+ return std::move(rhs) * lhs->getConstant();
case FloorDiv:
- if (!rhs->isSymbolicOrConstant()) {
- std::ignore =
- emitError(opLoc, "non-affine expression: right operand of floordiv "
- "has to be either a constant or symbolic");
- return nullptr;
+ if (!rhs->isConstant()) {
+ return emitError(opLoc,
+ "non-affine expression: right operand of floordiv "
+ "has to be a constant"),
+ nullptr;
}
+ if (!checkPositiveDivisor("floordiv"))
+ return nullptr;
- return floorDiv(std::move(lhs), std::move(rhs));
+ return floordiv(std::move(lhs), rhs->getConstant());
case CeilDiv:
- if (!rhs->isSymbolicOrConstant()) {
- std::ignore =
- emitError(opLoc, "non-affine expression: right operand of ceildiv "
- "has to be either a constant or symbolic");
- return nullptr;
+ if (!rhs->isConstant()) {
+ return emitError(opLoc, "non-affine expression: right operand of ceildiv "
+ "has to be a constant"),
+ nullptr;
}
+ if (!checkPositiveDivisor("ceildiv"))
+ return nullptr;
- return ceilDiv(std::move(lhs), std::move(rhs));
+ return ceildiv(std::move(lhs), rhs->getConstant());
case Mod:
- if (!rhs->isSymbolicOrConstant()) {
- std::ignore =
- emitError(opLoc, "non-affine expression: right operand of mod "
- "has to be either a constant or symbolic");
- return nullptr;
+ if (!rhs->isConstant()) {
+ return emitError(opLoc, "non-affine expression: right operand of mod "
+ "has to be a constant"),
+ nullptr;
}
+ if (!checkPositiveDivisor("mod"))
+ return nullptr;
- return std::move(lhs) % std::move(rhs);
+ return std::move(lhs) % rhs->getConstant();
case HNoOp:
llvm_unreachable("can't create affine expression for null high prec op");
return nullptr;
@@ -358,10 +275,9 @@ AffineExpr ParserImpl::getAffineBinaryOpExpr(AffineHighPrecOp op,
llvm_unreachable("Unknown AffineHighPrecOp");
}
/// Create an affine binary low precedence op expression (add, sub).
-AffineExpr ParserImpl::getAffineBinaryOpExpr(AffineLowPrecOp op,
- AffineExpr &&lhs,
- AffineExpr &&rhs) {
+PureAffineExpr ParserImpl::getAffineBinaryOpExpr(AffineLowPrecOp op,
+ PureAffineExpr &&lhs,
+ PureAffineExpr &&rhs) {
switch (op) {
case AffineLowPrecOp::Add:
return std::move(lhs) + std::move(rhs);
@@ -374,8 +290,6 @@ AffineExpr ParserImpl::getAffineBinaryOpExpr(AffineLowPrecOp op,
llvm_unreachable("Unknown AffineLowPrecOp");
}
/// Consume this token if it is a lower precedence affine op (there are only
/// two precedence levels).
AffineLowPrecOp ParserImpl::consumeIfLowPrecOp() {
switch (getToken().getKind()) {
case Token::plus:
@@ -419,10 +333,10 @@ AffineHighPrecOp ParserImpl::consumeIfHighPrecOp() {
/// null. If no rhs can be found, returns (llhs llhsOp lhs) or lhs if llhs is
/// null. llhsOpLoc is the location of the llhsOp token that will be used to
/// report an error for non-conforming expressions.
-AffineExpr ParserImpl::parseAffineHighPrecOpExpr(AffineExpr &&llhs,
- AffineHighPrecOp llhsOp,
- SMLoc llhsOpLoc) {
- AffineExpr lhs = parseAffineOperandExpr(llhs);
+PureAffineExpr ParserImpl::parseAffineHighPrecOpExpr(PureAffineExpr &&llhs,
+ AffineHighPrecOp llhsOp,
+ SMLoc llhsOpLoc) {
+ PureAffineExpr lhs = parseAffineOperandExpr(llhs);
if (!lhs)
return nullptr;
@@ -430,7 +344,7 @@ AffineExpr ParserImpl::parseAffineHighPrecOpExpr(AffineExpr &&llhs,
SMLoc opLoc = getToken().getLoc();
if (AffineHighPrecOp op = consumeIfHighPrecOp()) {
if (llhs) {
- AffineExpr expr =
+ PureAffineExpr expr =
getAffineBinaryOpExpr(llhsOp, std::move(llhs), std::move(lhs), opLoc);
if (!expr)
return nullptr;
@@ -445,7 +359,6 @@ AffineExpr ParserImpl::parseAffineHighPrecOpExpr(AffineExpr &&llhs,
return getAffineBinaryOpExpr(llhsOp, std::move(llhs), std::move(lhs),
llhsOpLoc);
// No llhs, 'lhs' itself is the expression.
return lhs;
}
@@ -470,35 +383,36 @@ AffineExpr ParserImpl::parseAffineHighPrecOpExpr(AffineExpr &&llhs,
/// Eg: when the expression is e1 + e2*e3 + e4, with e1 as llhs, this function
/// will return the affine expr equivalent of (e1 + (e2*e3)) + e4, where
/// (e2*e3) will be parsed using parseAffineHighPrecOpExpr().
-AffineExpr ParserImpl::parseAffineLowPrecOpExpr(AffineExpr &&llhs,
- AffineLowPrecOp llhsOp) {
- AffineExpr lhs = parseAffineOperandExpr(llhs);
+PureAffineExpr ParserImpl::parseAffineLowPrecOpExpr(PureAffineExpr &&llhs,
+ AffineLowPrecOp llhsOp) {
+ PureAffineExpr lhs = parseAffineOperandExpr(llhs);
if (!lhs)
return nullptr;
// Found an LHS. Deal with the ops.
if (AffineLowPrecOp lOp = consumeIfLowPrecOp()) {
if (llhs) {
- AffineExpr sum =
+ PureAffineExpr sum =
getAffineBinaryOpExpr(llhsOp, std::move(llhs), std::move(lhs));
return parseAffineLowPrecOpExpr(std::move(sum), lOp);
}
+
// No LLHS, get RHS and form the expression.
return parseAffineLowPrecOpExpr(std::move(lhs), lOp);
}
SMLoc opLoc = getToken().getLoc();
if (AffineHighPrecOp hOp = consumeIfHighPrecOp()) {
// We have a higher precedence op here. Get the rhs operand for the llhs
// through parseAffineHighPrecOpExpr.
- AffineExpr highRes = parseAffineHighPrecOpExpr(std::move(lhs), hOp, opLoc);
+ PureAffineExpr highRes =
+ parseAffineHighPrecOpExpr(std::move(lhs), hOp, opLoc);
if (!highRes)
return nullptr;
// If llhs is null, the product forms the first operand of the yet to be
// found expression. If non-null, the op to associate with llhs is llhsOp.
- AffineExpr expr = llhs ? getAffineBinaryOpExpr(llhsOp, std::move(llhs),
- std::move(highRes))
- : std::move(highRes);
+ PureAffineExpr expr = llhs ? getAffineBinaryOpExpr(llhsOp, std::move(llhs),
+ std::move(highRes))
+ : std::move(highRes);
// Recurse for subsequent low prec op's after the affine high prec op
// expression.
@@ -509,7 +423,7 @@ AffineExpr ParserImpl::parseAffineLowPrecOpExpr(AffineExpr &&llhs,
// Last operand in the expression list.
if (llhs)
return getAffineBinaryOpExpr(llhsOp, std::move(llhs), std::move(lhs));
+
// No llhs, 'lhs' itself is the expression.
return lhs;
}
@@ -528,17 +442,16 @@ AffineExpr ParserImpl::parseAffineLowPrecOpExpr(AffineExpr &&llhs,
-/// Additional conditions are checked depending on the production. For eg.,
-/// one of the operands for `*` has to be either constant/symbolic; the second
-/// operand for floordiv, ceildiv, and mod has to be a positive integer.
+/// Additional conditions are checked depending on the production. For eg.,
+/// one of the operands for `*` has to be a constant; the second operand for
+/// floordiv, ceildiv, and mod has to be a positive integer.
-AffineExpr ParserImpl::parseAffineExpr() {
+PureAffineExpr ParserImpl::parseAffineExpr() {
return parseAffineLowPrecOpExpr(nullptr, AffineLowPrecOp::LNoOp);
}
/// Parse a dim or symbol from the lists appearing before the actual
/// expressions of the affine map. Update our state to store the
/// dimensional/symbolic identifier.
+///
+/// Returns false, after emitting a diagnostic, on failure (e.g. when the
+/// current token is not an identifier).
-ParseResult ParserImpl::parseIdentifierDefinition(
- std::variant<AffineDimExpr, AffineSymbolExpr> idExpr) {
+bool ParserImpl::parseIdentifierDefinition(DimOrSymbolExpr idExpr) {
if (!isIdentifier(getToken()))
- return emitWrongTokenError("expected bare identifier");
+ return emitError("expected bare identifier");
StringRef name = getTokenSpelling();
for (const auto &entry : dimsAndSymbols) {
@@ -548,66 +461,36 @@ ParseResult ParserImpl::parseIdentifierDefinition(
consumeToken();
+ // Record the binding; parseBareIdExpr resolves later uses by searching
+ // dimsAndSymbols for this name.
dimsAndSymbols.emplace_back(name, idExpr);
- return success();
+ return true;
}
/// Parse the list of dimensional identifiers to an affine map.
+///
+/// Each parsed name takes the current value of info.numDims as its dim index
+/// and increments it.
-ParseResult ParserImpl::parseDimIdList(unsigned &numDims) {
- auto parseElt = [&]() -> ParseResult {
- return parseIdentifierDefinition(AffineDimExpr(numDims++));
+bool ParserImpl::parseDimIdList() {
+ auto parseElt = [&]() -> bool {
+ return parseIdentifierDefinition({DimOrSymbolKind::DimId, info.numDims++});
};
- return parseCommaSeparatedList(Delimiter::Paren, parseElt,
- " in dimensional identifier list");
+ return parseCommaSepeatedList(Delimiter::Paren, parseElt,
+ " in dimensional identifier list");
}
/// Parse the list of symbolic identifiers to an affine map.
+///
+/// Each parsed name takes the current value of info.numSymbols as its symbol
+/// index and increments it.
-ParseResult ParserImpl::parseSymbolIdList(unsigned &numSymbols) {
- auto parseElt = [&]() -> ParseResult {
- return parseIdentifierDefinition(AffineSymbolExpr(numSymbols++));
+bool ParserImpl::parseSymbolIdList() {
+ auto parseElt = [&]() -> bool {
+ return parseIdentifierDefinition(
+ {DimOrSymbolKind::Symbol, info.numSymbols++});
};
- return parseCommaSeparatedList(Delimiter::Square, parseElt,
- " in symbol list");
+ return parseCommaSepeatedList(Delimiter::Square, parseElt, " in symbol list");
}
-/// Parse the list of symbolic identifiers to an affine map.
+/// Parse the dimensional identifier list followed by an optional symbol list.
+/// If no `[` follows the dims, info.numSymbols is set to 0.
-ParseResult ParserImpl::parseDimAndOptionalSymbolIdList(unsigned &numDims,
- unsigned &numSymbols) {
- if (parseDimIdList(numDims)) {
- return failure();
+bool ParserImpl::parseDimAndOptionalSymbolIdList() {
+ if (!parseDimIdList())
+ return false;
+ if (getToken().isNot(Token::l_square)) {
+ info.numSymbols = 0;
+ return true;
}
- if (!getToken().is(Token::l_square)) {
- numSymbols = 0;
- return success();
- }
- return parseSymbolIdList(numSymbols);
-}
-
-/// Parse the range and sizes affine map definition inline.
-///
-/// affine-map ::= dim-and-symbol-id-lists `->` multi-dim-affine-expr
-///
-/// multi-dim-affine-expr ::= `(` `)`
-/// multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)`
-std::optional<AffineMap> ParserImpl::parseAffineMapRange(unsigned numDims,
- unsigned numSymbols) {
- SmallVector<AffineExpr, 4> exprs;
- auto parseElt = [&]() -> ParseResult {
- AffineExpr elt = parseAffineExpr();
- ParseResult res = elt ? success() : failure();
- exprs.emplace_back(std::move(elt));
- return res;
- };
-
- // Parse a multi-dimensional affine expression (a comma-separated list of
- // 1-d affine expressions). Grammar:
- // multi-dim-affine-expr ::= `(` `)`
- // | `(` affine-expr (`,` affine-expr)* `)`
- if (parseCommaSeparatedList(Delimiter::Paren, parseElt,
- " in affine map range"))
- return std::nullopt;
-
- // Parsed a valid affine map.
- return AffineMap(numDims, numSymbols, std::move(exprs));
+ return parseSymbolIdList();
}
/// Parse an affine constraint.
@@ -622,201 +505,137 @@ std::optional<AffineMap> ParserImpl::parseAffineMapRange(unsigned numDims,
///
/// isEq is set to true if the parsed constraint is an equality, false if it
/// is an inequality (greater than or equal).
+///
+/// The returned expression is normalized against zero: `lhs >= rhs` and
+/// `lhs == rhs` yield `lhs - rhs`, while `lhs <= rhs` yields `rhs - lhs`.
+/// Returns nullptr on malformed input.
-///
-AffineExpr ParserImpl::parseAffineConstraint(bool *isEq) {
- AffineExpr lhsExpr = parseAffineExpr();
+PureAffineExpr ParserImpl::parseAffineConstraint(bool *isEq) {
+ PureAffineExpr lhsExpr = parseAffineExpr();
if (!lhsExpr)
return nullptr;
// affine-constraint ::= `affine-expr` `>=` `affine-expr`
if (consumeIf(Token::greater) && consumeIf(Token::equal)) {
- AffineExpr rhsExpr = parseAffineExpr();
+ PureAffineExpr rhsExpr = parseAffineExpr();
if (!rhsExpr)
return nullptr;
*isEq = false;
return std::move(lhsExpr) - std::move(rhsExpr);
}
// affine-constraint ::= `affine-expr` `<=` `affine-expr`
if (consumeIf(Token::less) && consumeIf(Token::equal)) {
- AffineExpr rhsExpr = parseAffineExpr();
+ PureAffineExpr rhsExpr = parseAffineExpr();
if (!rhsExpr)
return nullptr;
*isEq = false;
return std::move(rhsExpr) - std::move(lhsExpr);
}
// affine-constraint ::= `affine-expr` `==` `affine-expr`
if (consumeIf(Token::equal) && consumeIf(Token::equal)) {
- AffineExpr rhsExpr = parseAffineExpr();
+ PureAffineExpr rhsExpr = parseAffineExpr();
if (!rhsExpr)
return nullptr;
*isEq = true;
return std::move(lhsExpr) - std::move(rhsExpr);
}
- std::ignore =
- emitError("expected '== affine-expr' or '>= affine-expr' at end of "
- "affine constraint");
- return nullptr;
+ return emitError("expected '==', '<=' or '>='"), nullptr;
}
-/// Parse the constraints that are part of an integer set definition.
-/// integer-set-inline
-/// ::= dim-and-symbol-id-lists `:`
-/// '(' affine-constraint-conjunction? ')'
-/// affine-constraint-conjunction ::= affine-constraint (`,`
-/// affine-constraint)*
-///
-std::optional<IntegerSet>
-ParserImpl::parseIntegerSetConstraints(unsigned numDims, unsigned numSymbols) {
- SmallVector<AffineExpr, 4> constraints;
- SmallVector<bool, 4> isEqs;
- auto parseElt = [&]() -> ParseResult {
- bool isEq;
- AffineExpr elt = parseAffineConstraint(&isEq);
- ParseResult res = elt ? success() : failure();
+/// Parse either an affine map, `(dims)[syms] -> (exprs)`, or an integer set,
+/// `(dims)[syms] : (constraints)`, collecting the canonicalized expressions
+/// (and, for a set, the equality flags) into a FinalParseResult.
+///
+/// FinalParseResult has no failure state, so malformed input is fatal: the
+/// failing sub-parser prints its diagnostic and llvm::report_fatal_error then
+/// aborts. llvm_unreachable is not used for this because user input can reach
+/// these paths, and in release builds it may expand to
+/// __builtin_unreachable().
+FinalParseResult ParserImpl::parseAffineMapOrIntegerSet() {
+ SmallVector<PureAffineExpr, 8> exprs;
+ SmallVector<bool, 8> eqFlags;
+
+ if (!parseDimAndOptionalSymbolIdList())
+ llvm::report_fatal_error("expected dim and symbol list",
+ /*gen_crash_diag=*/false);
+
+ // canonicalize() is only applied to successfully parsed (non-null)
+ // expressions.
+ auto parseExpr = [&]() -> bool {
+ PureAffineExpr elt = parseAffineExpr();
if (elt) {
- constraints.emplace_back(std::move(elt));
- isEqs.push_back(isEq);
+ exprs.emplace_back(canonicalize(std::move(elt)));
+ return true;
}
- return res;
+ return false;
};
- // Parse a list of affine constraints (comma-separated).
- if (parseCommaSeparatedList(Delimiter::Paren, parseElt,
- " in integer set constraint list"))
- return std::nullopt;
-
- // If no constraints were parsed, then treat this as a degenerate 'true' case.
- if (constraints.empty()) {
- /* 0 == 0 */
- return IntegerSet(numDims, numSymbols,
- std::make_unique<AffineConstantExpr>(0), true);
- }
-
- // Parsed a valid integer set.
- return IntegerSet(numDims, numSymbols, std::move(constraints),
- std::move(isEqs));
-}
-
-std::variant<AffineMap, IntegerSet, std::nullopt_t>
-ParserImpl::parseAffineMapOrIntegerSet() {
- unsigned numDims = 0, numSymbols = 0;
-
- // List of dimensional and optional symbol identifiers.
- if (parseDimAndOptionalSymbolIdList(numDims, numSymbols))
- return std::nullopt;
+ auto parseConstraint = [&]() -> bool {
+ bool isEq = false;
+ PureAffineExpr elt = parseAffineConstraint(&isEq);
+ if (elt) {
+ exprs.emplace_back(canonicalize(std::move(elt)));
+ eqFlags.push_back(isEq);
+ return true;
+ }
+ return false;
+ };
if (consumeIf(Token::arrow)) {
- if (std::optional<AffineMap> v = parseAffineMapRange(numDims, numSymbols))
- return std::move(*v);
- return std::nullopt;
+ // Parse the range and sizes affine map definition inline.
+ //
+ // affine-map ::= dim-and-symbol-id-lists `->` multi-dim-affine-expr
+ //
+ // multi-dim-affine-expr ::= `(` `)`
+ // multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)`
+ if (!parseCommaSepeatedList(Delimiter::Paren, parseExpr,
+ " in affine map range"))
+ llvm::report_fatal_error("expected affine map range",
+ /*gen_crash_diag=*/false);
+
+ return {info, std::move(exprs), eqFlags};
}
- if (parseToken(Token::colon, "expected '->' or ':'"))
- return std::nullopt;
+ if (!parseToken(Token::colon, "expected '->' or ':'"))
+ llvm::report_fatal_error("Unexpected token", /*gen_crash_diag=*/false);
+
+ // Parse the constraints that are part of an integer set definition.
+ // integer-set-inline
+ // ::= dim-and-symbol-id-lists `:`
+ // '(' affine-constraint-conjunction? ')'
+ // affine-constraint-conjunction ::= affine-constraint (`,`
+ // affine-constraint)*
+ if (!parseCommaSepeatedList(Delimiter::Paren, parseConstraint,
+ " in integer set constraint list"))
+ llvm::report_fatal_error("expected integer set", /*gen_crash_diag=*/false);
- if (std::optional<IntegerSet> v =
- parseIntegerSetConstraints(numDims, numSymbols))
- return std::move(*v);
- return std::nullopt;
+ return {info, std::move(exprs), eqFlags};
}
-static MultiAffineFunction getMultiAffineFunctionFromMap(const AffineMap &map) {
- IntegerPolyhedron cst(presburger::PresburgerSpace::getSetSpace(0, 0, 0));
- std::vector<SmallVector<int64_t, 8>> flattenedExprs;
-
- // Flatten expressions and add them to the constraint system.
- LogicalResult result = getFlattenedAffineExprs(map, flattenedExprs, cst);
- assert(result.succeeded() && "Unable to get flattened affine exprs");
+/// Flatten an affine-map parse result into a MultiAffineFunction: one output
+/// per parsed expression, over the parsed dims and symbols plus the divs
+/// produced while flattening.
+static MultiAffineFunction getMAF(FinalParseResult &&parseResult) {
+ // Copy the parse info first: parseResult is moved into the Flattener below.
+ auto info = parseResult.info;
+ auto [flatMatrix, cst] = Flattener(std::move(parseResult)).flatten();
DivisionRepr divs = cst.getLocalReprs();
assert(divs.hasAllReprs() &&
"AffineMap cannot produce divs without local representation");
- // TODO: We shouldn't have to do this conversion.
- Matrix<DynamicAPInt> mat(map.getNumExprs(),
- map.getNumInputs() + divs.getNumDivs() + 1);
- for (unsigned i = 0; i < flattenedExprs.size(); ++i)
- for (unsigned j = 0; j < flattenedExprs[i].size(); ++j)
- mat(i, j) = flattenedExprs[i][j];
-
return MultiAffineFunction(
- PresburgerSpace::getRelationSpace(map.getNumDims(), map.getNumExprs(),
- map.getNumSymbols(), divs.getNumDivs()),
- mat, divs);
+ PresburgerSpace::getRelationSpace(info.numDims, info.numExprs,
+ info.numSymbols, divs.getNumDivs()),
+ flatMatrix, divs);
}
-static IntegerPolyhedron getPolyhedronFromSet(const IntegerSet &set) {
- IntegerPolyhedron cst(presburger::PresburgerSpace::getSetSpace(0, 0, 0));
- std::vector<SmallVector<int64_t, 8>> flattenedExprs;
-
- // Flatten expressions and add them to the constraint system.
- LogicalResult result = getFlattenedAffineExprs(set, flattenedExprs, cst);
- assert(result.succeeded() && "Unable to get flattened affine exprs");
- assert(flattenedExprs.size() == set.getNumConstraints());
-
- unsigned numInequalities = set.getNumInequalities();
- unsigned numEqualities = set.getNumEqualities();
- unsigned numDims = set.getNumDims();
- unsigned numSymbols = set.getNumSymbols();
- unsigned numReservedCols = numDims + numSymbols + 1;
- IntegerPolyhedron poly(
- numInequalities, numEqualities, numReservedCols,
- presburger::PresburgerSpace::getSetSpace(numDims, numSymbols, 0));
- assert(numReservedCols >= poly.getSpace().getNumVars() + 1);
-
- poly.insertVar(VarKind::Local, poly.getNumVarKind(VarKind::Local),
- /*num=*/cst.getNumLocalVars());
-
- for (unsigned i = 0; i < flattenedExprs.size(); ++i) {
- const auto &flatExpr = flattenedExprs[i];
- assert(flatExpr.size() == poly.getSpace().getNumVars() + 1);
- if (set.eqFlags[i])
- poly.addEquality(flatExpr);
+static IntegerPolyhedron getPoly(FinalParseResult &&parseResult) {
+ auto eqFlags = parseResult.eqFlags;
+ auto [flatMatrix, cst] = Flattener(std::move(parseResult)).flatten();
+
+ for (unsigned i = 0; i < flatMatrix.getNumRows(); ++i) {
+ if (eqFlags[i])
+ cst.addEquality(flatMatrix.getRow(i));
else
- poly.addInequality(flatExpr);
+ cst.addInequality(flatMatrix.getRow(i));
}
- // Add the other constraints involving local vars from flattening.
- poly.append(cst);
- return poly;
+ return cst;
}
-static std::variant<AffineMap, IntegerSet, std::nullopt_t>
-parseAffineMapOrIntegerSet(StringRef str) {
+static FinalParseResult parseAffineMapOrIntegerSet(StringRef str) {
SourceMgr sourceMgr;
- auto memBuffer = MemoryBuffer::getMemBuffer(str, "<mlir_parser_buffer>",
- /*RequiresNullTerminator=*/false);
+ auto memBuffer =
+ MemoryBuffer::getMemBuffer(str, "<mlir_parser_buffer>", false);
sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
- ParserState state(sourceMgr);
- ParserImpl parser(state);
+ ParserImpl parser(sourceMgr);
return parser.parseAffineMapOrIntegerSet();
}
-static AffineMap parseAffineMap(StringRef str) {
- std::variant<AffineMap, IntegerSet, std::nullopt_t> v =
- parseAffineMapOrIntegerSet(str);
- if (std::holds_alternative<AffineMap>(v))
- return std::move(std::get<AffineMap>(v));
- llvm_unreachable("expected string to represent AffineMap");
-}
-
-static IntegerSet parseIntegerSet(StringRef str) {
- std::variant<AffineMap, IntegerSet, std::nullopt_t> v =
- parseAffineMapOrIntegerSet(str);
- if (std::holds_alternative<IntegerSet>(v))
- return std::move(std::get<IntegerSet>(v));
- llvm_unreachable("expected string to represent IntegerSet");
-}
-
-namespace mlir::presburger {
-IntegerPolyhedron parseIntegerPolyhedron(StringRef str) {
- return getPolyhedronFromSet(parseIntegerSet(str));
+IntegerPolyhedron mlir::presburger::parseIntegerPolyhedron(StringRef str) {
+ return getPoly(parseAffineMapOrIntegerSet(str));
}
-MultiAffineFunction parseMultiAffineFunction(StringRef str) {
- return getMultiAffineFunctionFromMap(parseAffineMap(str));
+MultiAffineFunction mlir::presburger::parseMultiAffineFunction(StringRef str) {
+ return getMAF(parseAffineMapOrIntegerSet(str));
}
-} // namespace mlir::presburger
diff --git a/mlir/lib/Analysis/Presburger/Parser/ParserImpl.h b/mlir/lib/Analysis/Presburger/Parser/ParserImpl.h
index c168d09826d51..7fef4ded21673 100644
--- a/mlir/lib/Analysis/Presburger/Parser/ParserImpl.h
+++ b/mlir/lib/Analysis/Presburger/Parser/ParserImpl.h
@@ -9,15 +9,25 @@
#ifndef MLIR_ANALYSIS_PRESBURGER_PARSER_PARSERIMPL_H
#define MLIR_ANALYSIS_PRESBURGER_PARSER_PARSERIMPL_H
+#include "Lexer.h"
#include "ParseStructs.h"
-#include "ParserState.h"
-#include "mlir/Support/LogicalResult.h"
-#include <optional>
-#include <variant>
namespace mlir::presburger {
template <typename T>
using function_ref = llvm::function_ref<T>;
+using llvm::SourceMgr;
+using llvm::Twine;
+
+/// Holds the parser's lexing state: the lexer and the current/last tokens.
+struct ParserState {
+ ParserState(const llvm::SourceMgr &sourceMgr)
+ : lex(sourceMgr), curToken(lex.lexToken()), lastToken(Token::error, "") {}
+ ParserState(const ParserState &) = delete;
+
+ Lexer lex;
+ Token curToken;
+ Token lastToken;
+};
/// These are the supported delimiters around operand lists and region
/// argument lists, used by parseOperandList.
@@ -45,8 +55,7 @@ enum class Delimiter {
/// Lower precedence ops (all at the same precedence level). LNoOp is false in
/// the boolean sense.
enum AffineLowPrecOp {
- /// Null value.
- LNoOp,
+ LNoOp, // Null value.
Add,
Sub
};
@@ -54,183 +63,97 @@ enum AffineLowPrecOp {
/// Higher precedence ops - all at the same precedence level. HNoOp is false
/// in the boolean sense.
enum AffineHighPrecOp {
- /// Null value.
- HNoOp,
+ HNoOp, // Null value.
Mul,
FloorDiv,
CeilDiv,
Mod
};
-//===----------------------------------------------------------------------===//
-// Parser
-//===----------------------------------------------------------------------===//
-
-/// This class implement support for parsing global entities like attributes and
-/// types. It is intended to be subclassed by specialized subparsers that
-/// include state.
class ParserImpl {
public:
- ParserImpl(ParserState &state) : state(state) {}
+ ParserImpl(const SourceMgr &sourceMgr)
+ : sourceMgr(sourceMgr), state(sourceMgr) {}
- // Helper methods to get stuff from the parser-global state.
- ParserState &getState() const { return state; }
-
- /// Parse a comma-separated list of elements up until the specified end token.
- ParseResult
- parseCommaSeparatedListUntil(Token::Kind rightToken,
- function_ref<ParseResult()> parseElement,
- bool allowEmptyList = true);
+ FinalParseResult parseAffineMapOrIntegerSet();
+private:
/// Parse a list of comma-separated items with an optional delimiter. If a
/// delimiter is provided, then an empty list is allowed. If not, then at
/// least one element will be parsed.
- ParseResult
- parseCommaSeparatedList(Delimiter delimiter,
- function_ref<ParseResult()> parseElementFn,
- StringRef contextMessage = StringRef());
-
- /// Parse a comma separated list of elements that must have at least one entry
- /// in it.
- ParseResult
- parseCommaSeparatedList(function_ref<ParseResult()> parseElementFn) {
- return parseCommaSeparatedList(Delimiter::None, parseElementFn);
- }
-
- // We have two forms of parsing methods - those that return a non-null
- // pointer on success, and those that return a ParseResult to indicate whether
- // they returned a failure. The second class fills in by-reference arguments
- // as the results of their action.
+ bool parseCommaSepeatedList(Delimiter delimiter,
+ function_ref<bool()> parseElementFn,
+ StringRef contextMessage);
//===--------------------------------------------------------------------===//
// Error Handling
//===--------------------------------------------------------------------===//
-
- /// Emit an error and return failure.
- ParseResult emitError(const Twine &message = {});
- ParseResult emitError(SMLoc loc, const Twine &message = {});
-
- /// Emit an error about a "wrong token". If the current token is at the
- /// start of a source line, this will apply heuristics to back up and report
- /// the error at the end of the previous line, which is where the expected
- /// token is supposed to be.
- ParseResult emitWrongTokenError(const Twine &message = {});
+ bool emitError(const Twine &message);
+ bool emitError(SMLoc loc, const Twine &message);
//===--------------------------------------------------------------------===//
// Token Parsing
//===--------------------------------------------------------------------===//
-
- /// Return the current token the parser is inspecting.
const Token &getToken() const { return state.curToken; }
StringRef getTokenSpelling() const { return state.curToken.getSpelling(); }
-
- /// Return the last parsed token.
const Token &getLastToken() const { return state.lastToken; }
-
- /// If the current token has the specified kind, consume it and return true.
- /// If not, return false.
bool consumeIf(Token::Kind kind) {
if (state.curToken.isNot(kind))
return false;
consumeToken(kind);
return true;
}
-
- /// Advance the current lexer onto the next token.
void consumeToken() {
assert(state.curToken.isNot(Token::eof, Token::error) &&
"shouldn't advance past EOF or errors");
state.lastToken = state.curToken;
state.curToken = state.lex.lexToken();
}
-
- /// Advance the current lexer onto the next token, asserting what the expected
- /// current token is. This is preferred to the above method because it leads
- /// to more self-documenting code with better checking.
void consumeToken(Token::Kind kind) {
assert(state.curToken.is(kind) && "consumed an unexpected token");
consumeToken();
}
-
- /// Reset the parser to the given lexer position.
void resetToken(const char *tokPos) {
state.lex.resetPointer(tokPos);
state.lastToken = state.curToken;
state.curToken = state.lex.lexToken();
}
-
- /// Consume the specified token if present and return success. On failure,
- /// output a diagnostic and return failure.
- ParseResult parseToken(Token::Kind expectedToken, const Twine &message);
-
- /// Parse an optional integer value from the stream.
- std::optional<ParseResult> parseOptionalInteger(APInt &result);
-
- /// Returns true if the current token corresponds to a keyword.
- bool isCurrentTokenAKeyword() const {
- return getToken().isAny(Token::bare_identifier, Token::inttype) ||
- getToken().isKeyword();
- }
-
- /// Parse a keyword, if present, into 'keyword'.
- ParseResult parseOptionalKeyword(StringRef *keyword);
+ bool parseToken(Token::Kind expectedToken, const Twine &message);
//===--------------------------------------------------------------------===//
// Affine Parsing
//===--------------------------------------------------------------------===//
-
- ParseResult
- parseAffineExprReference(ArrayRef<std::pair<StringRef, AffineExpr>> symbolSet,
- AffineExpr &expr);
- ParseResult
- parseAffineExprInline(ArrayRef<std::pair<StringRef, AffineExpr>> symbolSet,
- AffineExpr &expr);
- std::optional<AffineMap> parseAffineMapRange(unsigned numDims,
- unsigned numSymbols);
- std::optional<IntegerSet> parseIntegerSetConstraints(unsigned numDims,
- unsigned numSymbols);
- std::variant<AffineMap, IntegerSet, std::nullopt_t>
- parseAffineMapOrIntegerSet();
-
-private:
- // Binary affine op parsing.
AffineLowPrecOp consumeIfLowPrecOp();
AffineHighPrecOp consumeIfHighPrecOp();
- // Identifier lists for polyhedral structures.
- ParseResult parseDimIdList(unsigned &numDims);
- ParseResult parseSymbolIdList(unsigned &numSymbols);
- ParseResult parseDimAndOptionalSymbolIdList(unsigned &numDims,
- unsigned &numSymbols);
- ParseResult parseIdentifierDefinition(
- std::variant<AffineDimExpr, AffineSymbolExpr> idExpr);
-
- AffineExpr parseAffineExpr();
- AffineExpr parseParentheticalExpr();
- AffineExpr parseNegateExpression(const AffineExpr &lhs);
- AffineExpr parseIntegerExpr();
- AffineExpr parseBareIdExpr();
-
- AffineExpr getAffineBinaryOpExpr(AffineHighPrecOp op, AffineExpr &&lhs,
- AffineExpr &&rhs, SMLoc opLoc);
- AffineExpr getAffineBinaryOpExpr(AffineLowPrecOp op, AffineExpr &&lhs,
- AffineExpr &&rhs);
- AffineExpr parseAffineOperandExpr(const AffineExpr &lhs);
- AffineExpr parseAffineLowPrecOpExpr(AffineExpr &&llhs,
- AffineLowPrecOp llhsOp);
- AffineExpr parseAffineHighPrecOpExpr(AffineExpr &&llhs,
- AffineHighPrecOp llhsOp,
- SMLoc llhsOpLoc);
- AffineExpr parseAffineConstraint(bool *isEq);
-
-private:
- ParserState &state;
- function_ref<ParseResult(bool)> parseElement;
- unsigned numDimOperands = 0;
- unsigned numSymbolOperands = 0;
- SmallVector<
- std::pair<StringRef, std::variant<AffineDimExpr, AffineSymbolExpr>>, 4>
- dimsAndSymbols;
+ bool parseDimIdList();
+ bool parseSymbolIdList();
+ bool parseDimAndOptionalSymbolIdList();
+ bool parseIdentifierDefinition(DimOrSymbolExpr idExpr);
+
+ PureAffineExpr parseAffineExpr();
+ PureAffineExpr parseParentheticalExpr();
+ PureAffineExpr parseNegateExpression(const PureAffineExpr &lhs);
+ PureAffineExpr parseIntegerExpr();
+ PureAffineExpr parseBareIdExpr();
+
+ PureAffineExpr getAffineBinaryOpExpr(AffineHighPrecOp op,
+ PureAffineExpr &&lhs,
+ PureAffineExpr &&rhs, SMLoc opLoc);
+ PureAffineExpr getAffineBinaryOpExpr(AffineLowPrecOp op, PureAffineExpr &&lhs,
+ PureAffineExpr &&rhs);
+ PureAffineExpr parseAffineOperandExpr(const PureAffineExpr &lhs);
+ PureAffineExpr parseAffineLowPrecOpExpr(PureAffineExpr &&llhs,
+ AffineLowPrecOp llhsOp);
+ PureAffineExpr parseAffineHighPrecOpExpr(PureAffineExpr &&llhs,
+ AffineHighPrecOp llhsOp,
+ SMLoc llhsOpLoc);
+ PureAffineExpr parseAffineConstraint(bool *isEq);
+
+ const SourceMgr &sourceMgr;
+ ParserState state;
+ ParseInfo info;
+ SmallVector<std::pair<StringRef, DimOrSymbolExpr>, 4> dimsAndSymbols;
};
} // namespace mlir::presburger
diff --git a/mlir/lib/Analysis/Presburger/Parser/ParserState.h b/mlir/lib/Analysis/Presburger/Parser/ParserState.h
deleted file mode 100644
index 2fad3aa46fbb8..0000000000000
--- a/mlir/lib/Analysis/Presburger/Parser/ParserState.h
+++ /dev/null
@@ -1,39 +0,0 @@
-//===- ParserState.h - MLIR Presburger ParserState --------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_ANALYSIS_PRESBURGER_PARSER_PARSERSTATE_H
-#define MLIR_ANALYSIS_PRESBURGER_PARSER_PARSERSTATE_H
-
-#include "Lexer.h"
-#include "llvm/Support/SourceMgr.h"
-
-namespace mlir::presburger {
-/// This class refers to all of the state maintained globally by the parser,
-/// such as the current lexer position etc.
-struct ParserState {
- ParserState(const llvm::SourceMgr &sourceMgr)
- : sourceMgr(sourceMgr), lex(sourceMgr), curToken(lex.lexToken()),
- lastToken(Token::error, "") {}
- ParserState(const ParserState &) = delete;
- void operator=(const ParserState &) = delete;
-
- // The source manager for the parser.
- const llvm::SourceMgr &sourceMgr;
-
- /// The lexer for the source file we're parsing.
- Lexer lex;
-
- /// This is the next token that hasn't been consumed yet.
- Token curToken;
-
- /// This is the last token that has been consumed.
- Token lastToken;
-};
-} // namespace mlir::presburger
-
-#endif // MLIR_ANALYSIS_PRESBURGER_PARSER_PARSERSTATE_H
>From eed6786895d56b2f3cffe97c6211cba9f5988960 Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Sun, 23 Jun 2024 21:56:49 +0100
Subject: [PATCH 4/5] ParseStructs: fix build on Windows
---
mlir/lib/Analysis/Presburger/Parser/ParseStructs.h | 14 +++++++-------
1 file changed, 7 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Analysis/Presburger/Parser/ParseStructs.h b/mlir/lib/Analysis/Presburger/Parser/ParseStructs.h
index d0ad899af5974..ecf8428acd656 100644
--- a/mlir/lib/Analysis/Presburger/Parser/ParseStructs.h
+++ b/mlir/lib/Analysis/Presburger/Parser/ParseStructs.h
@@ -74,10 +74,10 @@ struct CoefficientVector {
}
ArrayRef<int64_t> getCoefficients() const { return coefficients; }
- constexpr int64_t getConstant() const {
+ int64_t getConstant() const {
return coefficients[info.getConstantIdx()];
}
- constexpr size_t size() const { return coefficients.size(); }
+ size_t size() const { return coefficients.size(); }
operator ArrayRef<int64_t>() const { return coefficients; }
void resize(size_t size) { coefficients.resize(size); }
operator bool() const {
@@ -148,7 +148,7 @@ struct CoefficientVector {
return gcd;
}
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
- constexpr bool hasMultipleCoefficients() const {
+ bool hasMultipleCoefficients() const {
return count_if(coefficients, [](auto &coeff) { return coeff; }) > 1;
}
LLVM_DUMP_METHOD void dump() const;
@@ -223,18 +223,18 @@ struct PureAffineExprImpl {
constexpr bool isMod() const { return kind == DivKind::Mod; }
constexpr bool hasDivisor() const { return divisor != 1; }
- constexpr bool isLinear() const {
+ bool isLinear() const {
return nestedDivTerms.empty() && !hasDivisor();
}
constexpr int64_t getDivisor() const { return divisor; }
constexpr int64_t getMulFactor() const { return mulFactor; }
- constexpr bool isConstant() const {
+ bool isConstant() const {
return nestedDivTerms.empty() && getLinearDividend().isConstant();
}
- constexpr int64_t getConstant() const {
+ int64_t getConstant() const {
return getLinearDividend().getConstant();
}
- constexpr size_t hash() const {
+ size_t hash() const {
return std::hash<const PureAffineExprImpl *>{}(this);
}
ArrayRef<PureAffineExpr> getNestedDivTerms() const { return nestedDivTerms; }
>From 52d54ac5329ea999941b7b2f97e059b7706607f3 Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Thu, 27 Jun 2024 12:39:26 +0100
Subject: [PATCH 5/5] Presburger/Parser: address review
---
.../Analysis/Presburger/Parser/Flattener.h | 2 +-
mlir/lib/Analysis/Presburger/Parser/Lexer.cpp | 33 +++----------------
mlir/lib/Analysis/Presburger/Parser/Lexer.h | 2 --
.../Analysis/Presburger/Parser/ParserImpl.cpp | 3 +-
.../Analysis/Presburger/Parser/TokenKinds.def | 1 -
5 files changed, 6 insertions(+), 35 deletions(-)
diff --git a/mlir/lib/Analysis/Presburger/Parser/Flattener.h b/mlir/lib/Analysis/Presburger/Parser/Flattener.h
index 4a5d60faddd53..d3043ed67f56d 100644
--- a/mlir/lib/Analysis/Presburger/Parser/Flattener.h
+++ b/mlir/lib/Analysis/Presburger/Parser/Flattener.h
@@ -24,7 +24,7 @@ class Flattener : public FinalParseResult {
IntMatrix flatMatrix;
// We maintain a set of divs that we have seen while flattening. The size of
- // this set we at most info.numDivs, hitting info.numDivs at the end of the
+ // this set is at most info.numDivs, hitting info.numDivs at the end of the
// flattening, if that expression contains all the possible divs.
SmallSetVector<size_t, 4> localExprs;
diff --git a/mlir/lib/Analysis/Presburger/Parser/Lexer.cpp b/mlir/lib/Analysis/Presburger/Parser/Lexer.cpp
index f89742bb0451b..8c037539094a6 100644
--- a/mlir/lib/Analysis/Presburger/Parser/Lexer.cpp
+++ b/mlir/lib/Analysis/Presburger/Parser/Lexer.cpp
@@ -12,10 +12,7 @@
#include "Lexer.h"
#include "Token.h"
-#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSwitch.h"
-#include "llvm/ADT/Twine.h"
-#include "llvm/Support/SourceMgr.h"
using namespace mlir::presburger;
@@ -25,7 +22,6 @@ Lexer::Lexer(const llvm::SourceMgr &sourceMgr) : sourceMgr(sourceMgr) {
curPtr = curBuffer.begin();
}
-/// emitError - Emit an error message and return an Token::error token.
Token Lexer::emitError(const char *loc, const llvm::Twine &message) {
sourceMgr.PrintMessage(SMLoc::getFromPointer(loc), llvm::SourceMgr::DK_Error,
message);
@@ -36,14 +32,11 @@ Token Lexer::lexToken() {
while (true) {
const char *tokStart = curPtr;
- // Lex the next token.
switch (*curPtr++) {
default:
// Handle bare identifiers.
if (isalpha(curPtr[-1]))
return lexBareIdentifierOrKeyword(tokStart);
-
- // Unknown character, emit an error.
return emitError(tokStart, "unexpected character");
case ' ':
@@ -53,16 +46,12 @@ Token Lexer::lexToken() {
// Handle whitespace.
continue;
- case '_':
- // Handle bare identifiers.
- return lexBareIdentifierOrKeyword(tokStart);
-
case 0:
- // This may either be a nul character in the source file or may be the EOF
- // marker that llvm::MemoryBuffer guarantees will be there.
+ // This should be the NUL terminator marking EOF; llvm::MemoryBuffer only
+ // guarantees it when created with RequiresNullTerminator=true.
if (curPtr - 1 == curBuffer.end())
return formToken(Token::eof, tokStart);
- continue;
+ return emitError(tokStart, "unexpected character");
case ':':
return formToken(Token::colon, tokStart);
@@ -114,7 +103,6 @@ Token Lexer::lexToken() {
/// Lex a bare identifier or keyword that starts with a letter.
///
/// bare-id ::= (letter|[_]) (letter|digit|[_$.])*
-/// integer-type ::= `[su]?i[1-9][0-9]*`
///
Token Lexer::lexBareIdentifierOrKeyword(const char *tokStart) {
// Match the rest of the identifier regex: [0-9a-zA-Z_.$]*
@@ -125,18 +113,6 @@ Token Lexer::lexBareIdentifierOrKeyword(const char *tokStart) {
// Check to see if this identifier is a keyword.
StringRef spelling(tokStart, curPtr - tokStart);
- auto isAllDigit = [](StringRef str) {
- return llvm::all_of(str, llvm::isDigit);
- };
-
- // Check for i123, si456, ui789.
- if ((spelling.size() > 1 && tokStart[0] == 'i' &&
- isAllDigit(spelling.drop_front())) ||
- ((spelling.size() > 2 && tokStart[1] == 'i' &&
- (tokStart[0] == 's' || tokStart[0] == 'u')) &&
- isAllDigit(spelling.drop_front(2))))
- return Token(Token::inttype, spelling);
-
Token::Kind kind = llvm::StringSwitch<Token::Kind>(spelling)
#define TOK_KEYWORD(SPELLING) .Case(#SPELLING, Token::kw_##SPELLING)
#include "TokenKinds.def"
@@ -147,8 +123,7 @@ Token Lexer::lexBareIdentifierOrKeyword(const char *tokStart) {
/// Lex a number literal.
///
-/// integer-literal ::= digit+ | `0x` hex_digit+
-/// float-literal ::= [-+]?[0-9]+[.][0-9]*([eE][-+]?[0-9]+)?
+/// integer-literal ::= digit+
///
Token Lexer::lexNumber(const char *tokStart) {
assert(isdigit(curPtr[-1]));
diff --git a/mlir/lib/Analysis/Presburger/Parser/Lexer.h b/mlir/lib/Analysis/Presburger/Parser/Lexer.h
index 081ded3a7fa7c..1dc458a358a94 100644
--- a/mlir/lib/Analysis/Presburger/Parser/Lexer.h
+++ b/mlir/lib/Analysis/Presburger/Parser/Lexer.h
@@ -33,7 +33,6 @@ class Lexer {
const char *getBufferBegin() { return curBuffer.data(); }
private:
- // Helpers.
Token formToken(Token::Kind kind, const char *tokStart) {
return Token(kind, StringRef(tokStart, curPtr - tokStart));
}
@@ -41,7 +40,6 @@ class Lexer {
Token emitError(const char *loc, const llvm::Twine &message);
// Lexer implementation methods.
- Token lexAtIdentifier(const char *tokStart);
Token lexBareIdentifierOrKeyword(const char *tokStart);
Token lexNumber(const char *tokStart);
diff --git a/mlir/lib/Analysis/Presburger/Parser/ParserImpl.cpp b/mlir/lib/Analysis/Presburger/Parser/ParserImpl.cpp
index 776c228eba6d9..20b2b6d08638e 100644
--- a/mlir/lib/Analysis/Presburger/Parser/ParserImpl.cpp
+++ b/mlir/lib/Analysis/Presburger/Parser/ParserImpl.cpp
@@ -21,8 +21,7 @@ using llvm::MemoryBuffer;
using llvm::SourceMgr;
static bool isIdentifier(const Token &token) {
- return token.isAny(Token::bare_identifier, Token::inttype) ||
- token.isKeyword();
+ return token.is(Token::bare_identifier) || token.isKeyword();
}
bool ParserImpl::parseToken(Token::Kind expectedToken, const Twine &message) {
diff --git a/mlir/lib/Analysis/Presburger/Parser/TokenKinds.def b/mlir/lib/Analysis/Presburger/Parser/TokenKinds.def
index e7010dfe11954..3285e77a40975 100644
--- a/mlir/lib/Analysis/Presburger/Parser/TokenKinds.def
+++ b/mlir/lib/Analysis/Presburger/Parser/TokenKinds.def
@@ -42,7 +42,6 @@ TOK_IDENTIFIER(bare_identifier) // foo
// Literals
TOK_LITERAL(integer) // 42
-TOK_LITERAL(inttype) // i4, si8, ui16
// Punctuation.
TOK_PUNCTUATION(arrow, "->")
More information about the Mlir-commits
mailing list