[Mlir-commits] [mlir] [mlir][polynomial] fix polynomial.constant syntax in docstrings (PR #92818)
Jeremy Kun
llvmlistbot at llvm.org
Mon May 20 17:01:51 PDT 2024
https://github.com/j2kun updated https://github.com/llvm/llvm-project/pull/92818
>From 1203b90c4ba7bfa79ab2fefe81ae7f05e96bde1c Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Mon, 20 May 2024 13:48:08 -0700
Subject: [PATCH 1/3] [mlir][polynomial] fix polynomial.constant syntax in
docstrings
---
.../mlir/Dialect/Polynomial/IR/Polynomial.td | 20 +++++++++----------
.../Polynomial/IR/PolynomialDialect.td | 6 +++---
2 files changed, 13 insertions(+), 13 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index 3ef899d3376b1..e03d2ec81e9c8 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -52,8 +52,8 @@ def Polynomial_AddOp : Polynomial_BinaryOp<"add", [Commutative]> {
// add two polynomials modulo x^1024 - 1
#poly = #polynomial.int_polynomial<x**1024 - 1>
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
- %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
- %1 = polynomial.constant #polynomial.int_polynomial<x**5 - x + 1> : !polynomial.polynomial<#ring>
+ %0 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial<#ring>
+ %1 = polynomial.constant {value=#polynomial.int_polynomial<x**5 - x + 1>} : !polynomial.polynomial<#ring>
%2 = polynomial.add %0, %1 : !polynomial.polynomial<#ring>
```
}];
@@ -76,8 +76,8 @@ def Polynomial_SubOp : Polynomial_BinaryOp<"sub"> {
// subtract two polynomials modulo x^1024 - 1
#poly = #polynomial.int_polynomial<x**1024 - 1>
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
- %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
- %1 = polynomial.constant #polynomial.int_polynomial<x**5 - x + 1> : !polynomial.polynomial<#ring>
+ %0 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial<#ring>
+ %1 = polynomial.constant {value=#polynomial.int_polynomial<x**5 - x + 1>} : !polynomial.polynomial<#ring>
%2 = polynomial.sub %0, %1 : !polynomial.polynomial<#ring>
```
}];
@@ -101,8 +101,8 @@ def Polynomial_MulOp : Polynomial_BinaryOp<"mul", [Commutative]> {
// multiply two polynomials modulo x^1024 - 1
#poly = #polynomial.int_polynomial<x**1024 - 1>
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
- %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
- %1 = polynomial.constant #polynomial.int_polynomial<x**5 - x + 1> : !polynomial.polynomial<#ring>
+ %0 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial<#ring>
+ %1 = polynomial.constant {value=#polynomial.int_polynomial<x**5 - x + 1>} : !polynomial.polynomial<#ring>
%2 = polynomial.mul %0, %1 : !polynomial.polynomial<#ring>
```
}];
@@ -126,7 +126,7 @@ def Polynomial_MulScalarOp : Polynomial_Op<"mul_scalar", [
// multiply two polynomials modulo x^1024 - 1
#poly = #polynomial.int_polynomial<x**1024 - 1>
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
- %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
+ %0 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial<#ring>
%1 = arith.constant 3 : i32
%2 = polynomial.mul_scalar %0, %1 : !polynomial.polynomial<#ring>, i32
```
@@ -157,7 +157,7 @@ def Polynomial_LeadingTermOp: Polynomial_Op<"leading_term"> {
```mlir
#poly = #polynomial.int_polynomial<x**1024 - 1>
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
- %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
+ %0 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial<#ring>
%1, %2 = polynomial.leading_term %0 : !polynomial.polynomial<#ring> -> (index, i32)
```
}];
@@ -286,10 +286,10 @@ def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant", [Pure]> {
```mlir
#poly = #polynomial.int_polynomial<x**1024 - 1>
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
- %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
+ %0 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial<#ring>
#float_ring = #polynomial.ring<coefficientType=f32>
- %0 = polynomial.constant #polynomial.float_polynomial<0.5 + 1.3e06 x**2> : !polynomial.polynomial<#float_ring>
+ %0 = polynomial.constant {value=#polynomial.float_polynomial<0.5 + 1.3e06 x**2>} : !polynomial.polynomial<#float_ring>
```
}];
let arguments = (ins Polynomial_AnyPolynomialAttr:$value);
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialDialect.td b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialDialect.td
index b0573b3715f78..73783815781cf 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialDialect.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialDialect.td
@@ -33,18 +33,18 @@ def Polynomial_Dialect : Dialect {
```mlir
// A constant polynomial in a ring with i32 coefficients and no polynomial modulus
#ring = #polynomial.ring<coefficientType=i32>
- %a = polynomial.constant <1 + x**2 - 3x**3> : polynomial.polynomial<#ring>
+ %a = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2 - 3x**3>} : polynomial.polynomial<#ring>
// A constant polynomial in a ring with i32 coefficients, modulo (x^1024 + 1)
#modulus = #polynomial.int_polynomial<1 + x**1024>
#ring = #polynomial.ring<coefficientType=i32, polynomialModulus=#modulus>
- %a = polynomial.constant <1 + x**2 - 3x**3> : polynomial.polynomial<#ring>
+ %a = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2 - 3x**3>} : polynomial.polynomial<#ring>
// A constant polynomial in a ring with i32 coefficients, with a polynomial
// modulus of (x^1024 + 1) and a coefficient modulus of 17.
#modulus = #polynomial.int_polynomial<1 + x**1024>
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=17:i32, polynomialModulus=#modulus>
- %a = polynomial.constant <1 + x**2 - 3x**3> : polynomial.polynomial<#ring>
+ %a = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2 - 3x**3>} : polynomial.polynomial<#ring>
```
}];
>From 29b42317e7b3800ffe4a3b4b0f89f499ce4e51e7 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Mon, 20 May 2024 15:53:26 -0700
Subject: [PATCH 2/3] Revert "[mlir][polynomial] fix polynomial.constant syntax
in docstrings"
This reverts commit 1203b90c4ba7bfa79ab2fefe81ae7f05e96bde1c.
---
.../mlir/Dialect/Polynomial/IR/Polynomial.td | 20 +++++++++----------
.../Polynomial/IR/PolynomialDialect.td | 6 +++---
2 files changed, 13 insertions(+), 13 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index e03d2ec81e9c8..3ef899d3376b1 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -52,8 +52,8 @@ def Polynomial_AddOp : Polynomial_BinaryOp<"add", [Commutative]> {
// add two polynomials modulo x^1024 - 1
#poly = #polynomial.int_polynomial<x**1024 - 1>
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
- %0 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial<#ring>
- %1 = polynomial.constant {value=#polynomial.int_polynomial<x**5 - x + 1>} : !polynomial.polynomial<#ring>
+ %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
+ %1 = polynomial.constant #polynomial.int_polynomial<x**5 - x + 1> : !polynomial.polynomial<#ring>
%2 = polynomial.add %0, %1 : !polynomial.polynomial<#ring>
```
}];
@@ -76,8 +76,8 @@ def Polynomial_SubOp : Polynomial_BinaryOp<"sub"> {
// subtract two polynomials modulo x^1024 - 1
#poly = #polynomial.int_polynomial<x**1024 - 1>
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
- %0 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial<#ring>
- %1 = polynomial.constant {value=#polynomial.int_polynomial<x**5 - x + 1>} : !polynomial.polynomial<#ring>
+ %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
+ %1 = polynomial.constant #polynomial.int_polynomial<x**5 - x + 1> : !polynomial.polynomial<#ring>
%2 = polynomial.sub %0, %1 : !polynomial.polynomial<#ring>
```
}];
@@ -101,8 +101,8 @@ def Polynomial_MulOp : Polynomial_BinaryOp<"mul", [Commutative]> {
// multiply two polynomials modulo x^1024 - 1
#poly = #polynomial.int_polynomial<x**1024 - 1>
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
- %0 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial<#ring>
- %1 = polynomial.constant {value=#polynomial.int_polynomial<x**5 - x + 1>} : !polynomial.polynomial<#ring>
+ %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
+ %1 = polynomial.constant #polynomial.int_polynomial<x**5 - x + 1> : !polynomial.polynomial<#ring>
%2 = polynomial.mul %0, %1 : !polynomial.polynomial<#ring>
```
}];
@@ -126,7 +126,7 @@ def Polynomial_MulScalarOp : Polynomial_Op<"mul_scalar", [
// multiply two polynomials modulo x^1024 - 1
#poly = #polynomial.int_polynomial<x**1024 - 1>
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
- %0 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial<#ring>
+ %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
%1 = arith.constant 3 : i32
%2 = polynomial.mul_scalar %0, %1 : !polynomial.polynomial<#ring>, i32
```
@@ -157,7 +157,7 @@ def Polynomial_LeadingTermOp: Polynomial_Op<"leading_term"> {
```mlir
#poly = #polynomial.int_polynomial<x**1024 - 1>
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
- %0 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial<#ring>
+ %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
%1, %2 = polynomial.leading_term %0 : !polynomial.polynomial<#ring> -> (index, i32)
```
}];
@@ -286,10 +286,10 @@ def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant", [Pure]> {
```mlir
#poly = #polynomial.int_polynomial<x**1024 - 1>
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
- %0 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial<#ring>
+ %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
#float_ring = #polynomial.ring<coefficientType=f32>
- %0 = polynomial.constant {value=#polynomial.float_polynomial<0.5 + 1.3e06 x**2>} : !polynomial.polynomial<#float_ring>
+ %0 = polynomial.constant #polynomial.float_polynomial<0.5 + 1.3e06 x**2> : !polynomial.polynomial<#float_ring>
```
}];
let arguments = (ins Polynomial_AnyPolynomialAttr:$value);
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialDialect.td b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialDialect.td
index 73783815781cf..b0573b3715f78 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialDialect.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialDialect.td
@@ -33,18 +33,18 @@ def Polynomial_Dialect : Dialect {
```mlir
// A constant polynomial in a ring with i32 coefficients and no polynomial modulus
#ring = #polynomial.ring<coefficientType=i32>
- %a = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2 - 3x**3>} : polynomial.polynomial<#ring>
+ %a = polynomial.constant <1 + x**2 - 3x**3> : polynomial.polynomial<#ring>
// A constant polynomial in a ring with i32 coefficients, modulo (x^1024 + 1)
#modulus = #polynomial.int_polynomial<1 + x**1024>
#ring = #polynomial.ring<coefficientType=i32, polynomialModulus=#modulus>
- %a = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2 - 3x**3>} : polynomial.polynomial<#ring>
+ %a = polynomial.constant <1 + x**2 - 3x**3> : polynomial.polynomial<#ring>
// A constant polynomial in a ring with i32 coefficients, with a polynomial
// modulus of (x^1024 + 1) and a coefficient modulus of 17.
#modulus = #polynomial.int_polynomial<1 + x**1024>
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=17:i32, polynomialModulus=#modulus>
- %a = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2 - 3x**3>} : polynomial.polynomial<#ring>
+ %a = polynomial.constant <1 + x**2 - 3x**3> : polynomial.polynomial<#ring>
```
}];
>From f6276fe2d81a883676cfc24de152c0f16afdc16d Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Mon, 20 May 2024 17:01:42 -0700
Subject: [PATCH 3/3] add typed variants for polynomial.constant op
---
.../mlir/Dialect/Polynomial/IR/Polynomial.td | 13 ++---
.../Polynomial/IR/PolynomialAttributes.td | 54 +++++++++++++++++--
.../Dialect/Polynomial/IR/PolynomialOps.cpp | 15 ++++++
mlir/test/Dialect/Polynomial/ops.mlir | 8 +--
4 files changed, 77 insertions(+), 13 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index 3ef899d3376b1..85a9dd6b935d2 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -272,13 +272,14 @@ def Polynomial_ToTensorOp : Polynomial_Op<"to_tensor", [Pure]> {
let hasVerifier = 1;
}
-def Polynomial_AnyPolynomialAttr : AnyAttrOf<[
- Polynomial_FloatPolynomialAttr,
- Polynomial_IntPolynomialAttr
+def Polynomial_AnyTypedPolynomialAttr : AnyAttrOf<[
+ Polynomial_TypedFloatPolynomialAttr,
+ Polynomial_TypedIntPolynomialAttr
]>;
// Not deriving from Polynomial_Op due to need for custom assembly format
-def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant", [Pure]> {
+def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant",
+ [Pure, InferTypeOpAdaptor]> {
let summary = "Define a constant polynomial via an attribute.";
let description = [{
Example:
@@ -292,9 +293,9 @@ def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant", [Pure]> {
%0 = polynomial.constant #polynomial.float_polynomial<0.5 + 1.3e06 x**2> : !polynomial.polynomial<#float_ring>
```
}];
- let arguments = (ins Polynomial_AnyPolynomialAttr:$value);
+ let arguments = (ins Polynomial_AnyTypedPolynomialAttr:$value);
let results = (outs Polynomial_PolynomialType:$output);
- let assemblyFormat = "attr-dict `:` type($output)";
+ let assemblyFormat = "attr-dict $value";
}
def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> {
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td
index e5dbfa7fa21ee..1ea07e21e0076 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td
@@ -18,7 +18,7 @@ class Polynomial_Attr<string name, string attrMnemonic, list<Trait> traits = []>
}
def Polynomial_IntPolynomialAttr : Polynomial_Attr<"IntPolynomial", "int_polynomial"> {
- let summary = "An attribute containing a single-variable polynomial with integer coefficients.";
+ let summary = "an attribute containing a single-variable polynomial with integer coefficients";
let description = [{
A polynomial attribute represents a single-variable polynomial with integer
coefficients, which is used to define the modulus of a `RingAttr`, as well
@@ -41,7 +41,7 @@ def Polynomial_IntPolynomialAttr : Polynomial_Attr<"IntPolynomial", "int_polynom
}
def Polynomial_FloatPolynomialAttr : Polynomial_Attr<"FloatPolynomial", "float_polynomial"> {
- let summary = "An attribute containing a single-variable polynomial with double precision floating point coefficients.";
+ let summary = "an attribute containing a single-variable polynomial with double precision floating point coefficients";
let description = [{
A polynomial attribute represents a single-variable polynomial with double
precision floating point coefficients.
@@ -62,8 +62,56 @@ def Polynomial_FloatPolynomialAttr : Polynomial_Attr<"FloatPolynomial", "float_p
let hasCustomAssemblyFormat = 1;
}
+def Polynomial_TypedIntPolynomialAttr : Polynomial_Attr<
+ "TypedIntPolynomial", "typed_int_polynomial", [TypedAttrInterface]> {
+ let summary = "a typed int_polynomial";
+ let parameters = (ins "::mlir::Type":$type, "::mlir::polynomial::IntPolynomialAttr":$value);
+ let assemblyFormat = "$value `:` $type";
+ let builders = [
+ AttrBuilderWithInferredContext<(ins "Type":$type,
+ "const IntPolynomial &":$value), [{
+ return $_get(
+ type.getContext(),
+ type,
+ IntPolynomialAttr::get(type.getContext(), value));
+ }]>,
+ AttrBuilderWithInferredContext<(ins "Type":$type,
+ "const Attribute &":$value), [{
+ return $_get(type.getContext(), type, ::llvm::cast<IntPolynomialAttr>(value));
+ }]>
+ ];
+ let extraClassDeclaration = [{
+ // used for constFoldBinaryOp
+ using ValueType = ::mlir::Attribute;
+ }];
+}
+
+def Polynomial_TypedFloatPolynomialAttr : Polynomial_Attr<
+ "TypedFloatPolynomial", "typed_float_polynomial", [TypedAttrInterface]> {
+ let summary = "a typed float_polynomial";
+ let parameters = (ins "::mlir::Type":$type, "::mlir::polynomial::FloatPolynomialAttr":$value);
+ let assemblyFormat = "$value `:` $type";
+ let builders = [
+ AttrBuilderWithInferredContext<(ins "Type":$type,
+ "const FloatPolynomial &":$value), [{
+ return $_get(
+ type.getContext(),
+ type,
+ FloatPolynomialAttr::get(type.getContext(), value));
+ }]>,
+ AttrBuilderWithInferredContext<(ins "Type":$type,
+ "const Attribute &":$value), [{
+ return $_get(type.getContext(), type, ::llvm::cast<FloatPolynomialAttr>(value));
+ }]>
+ ];
+ let extraClassDeclaration = [{
+ // used for constFoldBinaryOp
+ using ValueType = ::mlir::Attribute;
+ }];
+}
+
def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
- let summary = "An attribute specifying a polynomial ring.";
+ let summary = "an attribute specifying a polynomial ring";
let description = [{
A ring describes the domain in which polynomial arithmetic occurs. The ring
attribute in `polynomial` represents the more specific case of polynomials
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 1a2439fe810b5..4c2fed6bab312 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -186,6 +186,21 @@ LogicalResult INTTOp::verify() {
return verifyNTTOp(this->getOperation(), ring, tensorType);
}
+LogicalResult ConstantOp::inferReturnTypes(
+ MLIRContext *context, std::optional<mlir::Location> location,
+ ConstantOp::Adaptor adaptor,
+ llvm::SmallVectorImpl<mlir::Type> &inferredReturnTypes) {
+ Attribute operand = adaptor.getValue();
+ if (auto intPoly = dyn_cast<TypedIntPolynomialAttr>(operand)) {
+ inferredReturnTypes.push_back(intPoly.getType());
+ } else if (auto floatPoly = dyn_cast<TypedFloatPolynomialAttr>(operand)) {
+ inferredReturnTypes.push_back(floatPoly.getType());
+ } else {
+ return failure();
+ }
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd canonicalization patterns
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Polynomial/ops.mlir b/mlir/test/Dialect/Polynomial/ops.mlir
index ff709960c50e9..695b1acf18bd7 100644
--- a/mlir/test/Dialect/Polynomial/ops.mlir
+++ b/mlir/test/Dialect/Polynomial/ops.mlir
@@ -74,15 +74,15 @@ module {
func.func @test_monic_monomial_mul() {
%five = arith.constant 5 : index
- %0 = polynomial.constant {value=#one_plus_x_squared} : !polynomial.polynomial<ring=#ring1>
+ %0 = polynomial.constant #one_plus_x_squared : !polynomial.polynomial<ring=#ring1>
%1 = polynomial.monic_monomial_mul %0, %five : (!polynomial.polynomial<ring=#ring1>, index) -> !polynomial.polynomial<ring=#ring1>
return
}
func.func @test_constant() {
- %0 = polynomial.constant {value=#one_plus_x_squared} : !polynomial.polynomial<ring=#ring1>
- %1 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial<ring=#ring1>
- %2 = polynomial.constant {value=#polynomial.float_polynomial<1.5 + 0.5 x**2>} : !polynomial.polynomial<ring=#ring2>
+ %0 = polynomial.constant #one_plus_x_squared : !polynomial.polynomial<ring=#ring1>
+ %1 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<ring=#ring1>
+ %2 = polynomial.constant #polynomial.float_polynomial<1.5 + 0.5 x**2> : !polynomial.polynomial<ring=#ring2>
return
}
More information about the Mlir-commits
mailing list