[Mlir-commits] [mlir] Add a polynomial dialect shell, attributes, and types (PR #72081)

Jeremy Kun llvmlistbot at llvm.org
Wed Nov 15 12:02:45 PST 2023


================
@@ -0,0 +1,217 @@
+//===- PolynomialAttributes.cpp - Polynomial dialect attributes -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h"
+
+#include "mlir/Dialect/Polynomial/IR/Polynomial.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringRef.h"
+
+namespace mlir {
+namespace polynomial {
+
+/// Print this attribute's polynomial enclosed in angle brackets, which are the
+/// same delimiters PolynomialAttr::parse consumes with parseLess/parseGreater.
+void PolynomialAttr::print(AsmPrinter &p) const {
+  p << '<' << getPolynomial() << '>';
+}
+
+/// Resize `value` to `apintBitWidth` bits so that every parsed coefficient and
+/// exponent shares one bit width (APInt equality and ordering require equal
+/// widths). Emits an error mentioning `what` and fails if `value` does not fit
+/// in `apintBitWidth` bits as a signed integer.
+static ParseResult normalizeBitWidth(AsmParser &parser, APInt &value,
+                                     llvm::StringRef what) {
+  if (value.getSignificantBits() > apintBitWidth) {
+    parser.emitError(parser.getCurrentLocation(),
+                     "Integer " + what + " is too large.");
+    return failure();
+  }
+  value = value.sextOrTrunc(apintBitWidth);
+  return success();
+}
+
+/// Try to parse a monomial: an optional integer coefficient, optionally
+/// followed by a variable name and an optional `**` exponent. At least one of
+/// the coefficient or the variable must be present.
+///
+/// On success, populates the outparam `monomial` (coefficient and exponent,
+/// both `apintBitWidth` bits wide) and the `variable` outparam with the parsed
+/// variable name. When no variable is present, the monomial is a constant term:
+/// its exponent is 0 and `*isConstantTerm` is set to true.
+ParseResult parseMonomial(AsmParser &parser, Monomial &monomial,
+                          llvm::StringRef *variable, bool *isConstantTerm) {
+  APInt parsedCoeff(apintBitWidth, 1);
+  auto result = parser.parseOptionalInteger(parsedCoeff);
+  const bool hasCoefficient = result.has_value();
+  if (hasCoefficient) {
+    if (failed(*result)) {
+      parser.emitError(parser.getCurrentLocation(),
+                       "Invalid integer coefficient.");
+      return failure();
+    }
+    // The parser may size the APInt to fit the literal; normalize it so all
+    // coefficients are comparable.
+    if (failed(normalizeBitWidth(parser, parsedCoeff, "coefficient")))
+      return failure();
+  }
+
+  // Variable name.
+  result = parser.parseOptionalKeyword(variable);
+  if (!result.has_value() || failed(*result)) {
+    // We allow "failed" because it triggers when the next token is a +,
+    // which is allowed when the input is the constant term.
+    if (!hasCoefficient) {
+      // Neither a coefficient nor a variable: there is no monomial here
+      // (e.g. `<>` or `<+ x>`). Without this check the empty input would be
+      // silently accepted as the constant 1.
+      parser.emitError(parser.getCurrentLocation(),
+                       "Expected a monomial: an integer coefficient and/or a "
+                       "variable.");
+      return failure();
+    }
+    monomial.coefficient = parsedCoeff;
+    monomial.exponent = APInt(apintBitWidth, 0);
+    *isConstantTerm = true;
+    return success();
+  }
+
+  // Parse exponentiation symbol as **.
+  // We can't use caret because it's reserved for basic block identifiers.
+  // If no star is present, the monomial has exponent 1.
+  if (failed(parser.parseOptionalStar())) {
+    monomial.coefficient = parsedCoeff;
+    monomial.exponent = APInt(apintBitWidth, 1);
+    return success();
+  }
+
+  // If there's one * there must be two. The optional form is used so that this
+  // function alone decides which diagnostic is reported.
+  if (failed(parser.parseOptionalStar())) {
+    parser.emitError(parser.getCurrentLocation(),
+                     "Exponents must be specified as a double-asterisk `**`.");
+    return failure();
+  }
+
+  // If there's a **, then the integer exponent is required.
+  APInt parsedExponent(apintBitWidth, 0);
+  auto exponentResult = parser.parseOptionalInteger(parsedExponent);
+  if (!exponentResult.has_value() || failed(*exponentResult)) {
+    parser.emitError(parser.getCurrentLocation(),
+                     "Found invalid integer exponent.");
+    return failure();
+  }
+
+  // Polynomials only have non-negative powers of the indeterminate.
+  if (parsedExponent.isNegative()) {
+    parser.emitError(parser.getCurrentLocation(),
+                     "Exponents must be non-negative.");
+    return failure();
+  }
+  if (failed(normalizeBitWidth(parser, parsedExponent, "exponent")))
+    return failure();
+
+  monomial.coefficient = parsedCoeff;
+  monomial.exponent = parsedExponent;
+  return success();
+}
+
+/// Parse a polynomial attribute body of the form `<m_1 + m_2 + ... + m_k>`,
+/// where each `m_i` is a monomial accepted by parseMonomial.
+///
+/// Fails (returning a null Attribute after emitting a diagnostic) if:
+///   - a monomial cannot be parsed,
+///   - two monomials share the same exponent,
+///   - the monomial list is not terminated by `>` or has a dangling `+`,
+///   - more than one distinct variable name is used.
+/// The `type` argument is unused.
+Attribute PolynomialAttr::parse(AsmParser &parser, Type type) {
+  if (failed(parser.parseLess()))
+    return {};
+
+  std::vector<Monomial> monomials;
+  llvm::SmallSet<std::string, 2> variables;
+  llvm::DenseSet<APInt> exponents;
+
+  while (true) {
+    // Remember where this monomial starts so diagnostics about it point at
+    // the monomial rather than past it.
+    llvm::SMLoc monomialLoc = parser.getCurrentLocation();
+    Monomial parsedMonomial;
+    llvm::StringRef parsedVariableRef;
+    bool isConstantTerm = false;
+    if (failed(parseMonomial(parser, parsedMonomial, &parsedVariableRef,
+                             &isConstantTerm)))
+      return {};
+
+    if (!isConstantTerm)
+      variables.insert(parsedVariableRef.str());
+
+    // Each exponent may appear at most once; insert() reports whether the
+    // exponent was new, avoiding a separate lookup.
+    if (!exponents.insert(parsedMonomial.exponent).second) {
+      llvm::SmallString<16> exponentString;
+      parsedMonomial.exponent.toStringSigned(exponentString);
+      parser.emitError(monomialLoc, "At most one monomial may have exponent " +
+                                        exponentString +
+                                        ", but found multiple.");
+      return {};
+    }
+    monomials.push_back(parsedMonomial);
+
+    // Parse optional +. If a + is absent, require > and break, otherwise forbid
+    // > and continue with the next monomial.
+    // ParseOptional{Plus, Greater} does not return an OptionalParseResult, so
+    // failed means that the token was not found.
+    if (failed(parser.parseOptionalPlus())) {
+      if (succeeded(parser.parseOptionalGreater()))
+        break;
+      parser.emitError(
+          parser.getCurrentLocation(),
+          "Expected + and more monomials, or > to end polynomial attribute.");
+      return {};
+    }
+    if (succeeded(parser.parseOptionalGreater())) {
+      parser.emitError(
+          parser.getCurrentLocation(),
+          "Expected another monomial after +, but found > ending attribute.");
+      return {};
+    }
+  }
+
+  // Only univariate polynomials are supported: every non-constant monomial
+  // must use the same indeterminate.
+  if (variables.size() > 1) {
+    std::string vars = llvm::join(variables.begin(), variables.end(), ", ");
+    parser.emitError(
+        parser.getCurrentLocation(),
+        "Polynomials must have one indeterminate, but there were multiple: " +
+            vars);
+    return {};
+  }
+
+  Polynomial poly = Polynomial::fromMonomials(monomials, parser.getContext());
+  return PolynomialAttr::get(poly);
+}
+
+/// Print the ring parameters as `<ctype=..., cmod=..., ideal=...>`.
+///
+/// Output must begin with `<` because RingAttr::parse starts with parseLess()
+/// followed by the `ctype` keyword. Emitting a literal `#polynomial.ring`
+/// prefix here would produce text that the parser cannot read back. The
+/// dialect prefix and mnemonic are presumably printed by the dialect's
+/// attribute printer; TODO(review): confirm in the dialect definition.
+void RingAttr::print(AsmPrinter &p) const {
+  p << "<ctype=" << getCoefficientType()
+    << ", cmod=" << getCoefficientModulus()
+    << ", ideal=" << getPolynomialModulus() << '>';
+}
+
+mlir::Attribute mlir::polynomial::RingAttr::parse(AsmParser &parser,
+                                                  Type type) {
+  if (failed(parser.parseLess()))
+    return {};
+
+  if (failed(parser.parseKeyword("ctype")))
+    return {};
+
+  if (failed(parser.parseEqual()))
+    return {};
+
+  TypeAttr typeAttr;
+  if (failed(parser.parseAttribute<TypeAttr>(typeAttr)))
+    return {};
+
+  if (failed(parser.parseComma()))
+    return {};
+
+  std::optional<IntegerAttr> cmodAttr = std::nullopt;
+  if (succeeded(parser.parseKeyword("cmod"))) {
+    if (failed(parser.parseEqual()))
+      return {};
+
+    IntegerType iType = llvm::dyn_cast<IntegerType>(typeAttr.getValue());
----------------
j2kun wrote:

That import is already there, but it looks like the static version of dyn_cast was replaced by a member function. Fixed in 6c153e66cd0b0f13b7f5b2566f59ec98eff917cb

https://github.com/llvm/llvm-project/pull/72081


More information about the Mlir-commits mailing list