[Mlir-commits] [mlir] d48777e - [mlir][polynomial] remove incorrect canonicalization rule (#110318)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Sep 28 08:38:00 PDT 2024
Author: Hongren Zheng
Date: 2024-09-28T08:37:57-07:00
New Revision: d48777ece50c39df553ed779d0771bc9ef6747cf
URL: https://github.com/llvm/llvm-project/commit/d48777ece50c39df553ed779d0771bc9ef6747cf
DIFF: https://github.com/llvm/llvm-project/commit/d48777ece50c39df553ed779d0771bc9ef6747cf.diff
LOG: [mlir][polynomial] remove incorrect canonicalization rule (#110318)
`arith.addi`/`arith.subi` on tensors do not reduce their results modulo
`coefficientModulus`, and they may overflow, so the rewritten result could be
incorrect. These rewrites should instead be expressed with modular arithmetic
rather than `arith` ops.
Revert https://github.com/llvm/llvm-project/pull/93132
Addresses https://github.com/google/heir/issues/749
Cc @j2kun
Added:
Modified:
mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
mlir/test/Dialect/Polynomial/canonicalization.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
index 93ea6e4e43698d..28c45e6846380c 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
@@ -11,12 +11,9 @@
include "mlir/Dialect/Arith/IR/ArithOps.td"
include "mlir/Dialect/Polynomial/IR/Polynomial.td"
-include "mlir/IR/EnumAttr.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/PatternBase.td"
-defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;
-
def Equal : Constraint<CPred<"$0 == $1">>;
// Get a -1 integer attribute of the same type as the polynomial SSA value's
@@ -44,40 +41,4 @@ def NTTAfterINTT : Pat<
[(Equal $r1, $r2)]
>;
-// NTTs are expensive, and addition in coefficient or NTT domain should be
-// equivalently expensive, so reducing the number of NTTs is optimal.
-// ntt(a) + ntt(b) -> ntt(a + b)
-def NTTOfAdd : Pat<
- (Arith_AddIOp
- (Polynomial_NTTOp $p1, $r1),
- (Polynomial_NTTOp $p2, $r2),
- $overflow),
- (Polynomial_NTTOp (Polynomial_AddOp $p1, $p2), $r1),
- [(Equal $r1, $r2)]
->;
-// intt(a) + intt(b) -> intt(a + b)
-def INTTOfAdd : Pat<
- (Polynomial_AddOp
- (Polynomial_INTTOp $t1, $r1),
- (Polynomial_INTTOp $t2, $r2)),
- (Polynomial_INTTOp (Arith_AddIOp $t1, $t2, DefOverflow), $r1),
- [(Equal $r1, $r2)]
->;
-// repeated for sub
-def NTTOfSub : Pat<
- (Arith_SubIOp
- (Polynomial_NTTOp $p1, $r1),
- (Polynomial_NTTOp $p2, $r2),
- $overflow),
- (Polynomial_NTTOp (Polynomial_SubOp $p1, $p2), $r1),
- [(Equal $r1, $r2)]
->;
-def INTTOfSub : Pat<
- (Polynomial_SubOp
- (Polynomial_INTTOp $t1, $r1),
- (Polynomial_INTTOp $t2, $r2)),
- (Polynomial_INTTOp (Arith_SubIOp $t1, $t2, DefOverflow), $r1),
- [(Equal $r1, $r2)]
->;
-
#endif // POLYNOMIAL_CANONICALIZATION
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 2ba13bb7dab569..460ef17167e801 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -289,10 +289,10 @@ void SubOp::getCanonicalizationPatterns(RewritePatternSet &results,
void NTTOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<NTTAfterINTT, NTTOfAdd, NTTOfSub>(context);
+ results.add<NTTAfterINTT>(context);
}
void INTTOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<INTTAfterNTT, INTTOfAdd, INTTOfSub>(context);
+ results.add<INTTAfterNTT>(context);
}
diff --git a/mlir/test/Dialect/Polynomial/canonicalization.mlir b/mlir/test/Dialect/Polynomial/canonicalization.mlir
index b79938627e4154..c0ee514daab645 100644
--- a/mlir/test/Dialect/Polynomial/canonicalization.mlir
+++ b/mlir/test/Dialect/Polynomial/canonicalization.mlir
@@ -45,73 +45,3 @@ func.func @test_canonicalize_sub(%poly0 : !sub_ty, %poly1 : !sub_ty) -> !sub_ty
return %0 : !sub_ty
}
-// CHECK-LABEL: test_canonicalize_fold_add_through_ntt
-// CHECK: polynomial.add
-// CHECK-NOT: polynomial.ntt
-// CHECK-NOT: polynomial.intt
-func.func @test_canonicalize_fold_add_through_ntt(
- %poly0 : !ntt_poly_ty,
- %poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
- %0 = polynomial.ntt %poly0 {root=#root} : !ntt_poly_ty -> !tensor_ty
- %1 = polynomial.ntt %poly1 {root=#root} : !ntt_poly_ty -> !tensor_ty
- %a_plus_b = arith.addi %0, %1 : !tensor_ty
- %out = polynomial.intt %a_plus_b {root=#root} : !tensor_ty -> !ntt_poly_ty
- return %out : !ntt_poly_ty
-}
-
-// CHECK-LABEL: test_canonicalize_fold_add_through_intt
-// CHECK: arith.addi
-// CHECK-NOT: polynomial.intt
-// CHECK-NOT: polynomial.iintt
-func.func @test_canonicalize_fold_add_through_intt(
- %tensor0 : !tensor_ty,
- %tensor1 : !tensor_ty) -> !tensor_ty {
- %0 = polynomial.intt %tensor0 {root=#root} : !tensor_ty -> !ntt_poly_ty
- %1 = polynomial.intt %tensor1 {root=#root} : !tensor_ty -> !ntt_poly_ty
- %a_plus_b = polynomial.add %0, %1 : !ntt_poly_ty
- %out = polynomial.ntt %a_plus_b {root=#root} : !ntt_poly_ty -> !tensor_ty
- return %out : !tensor_ty
-}
-
-// CHECK-LABEL: test_canonicalize_fold_sub_through_ntt
-// CHECK: polynomial.mul_scalar
-// CHECK: polynomial.add
-// CHECK-NOT: polynomial.ntt
-// CHECK-NOT: polynomial.intt
-func.func @test_canonicalize_fold_sub_through_ntt(
- %poly0 : !ntt_poly_ty,
- %poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
- %0 = polynomial.ntt %poly0 {root=#root} : !ntt_poly_ty -> !tensor_ty
- %1 = polynomial.ntt %poly1 {root=#root} : !ntt_poly_ty -> !tensor_ty
- %a_plus_b = arith.subi %0, %1 : !tensor_ty
- %out = polynomial.intt %a_plus_b {root=#root} : !tensor_ty -> !ntt_poly_ty
- return %out : !ntt_poly_ty
-}
-
-// CHECK-LABEL: test_canonicalize_fold_sub_through_intt
-// CHECK: arith.subi
-// CHECK-NOT: polynomial.intt
-// CHECK-NOT: polynomial.iintt
-func.func @test_canonicalize_fold_sub_through_intt(
- %tensor0 : !tensor_ty,
- %tensor1 : !tensor_ty) -> !tensor_ty {
- %0 = polynomial.intt %tensor0 {root=#root} : !tensor_ty -> !ntt_poly_ty
- %1 = polynomial.intt %tensor1 {root=#root} : !tensor_ty -> !ntt_poly_ty
- %a_plus_b = polynomial.sub %0, %1 : !ntt_poly_ty
- %out = polynomial.ntt %a_plus_b {root=#root} : !ntt_poly_ty -> !tensor_ty
- return %out : !tensor_ty
-}
-
-
-// CHECK-LABEL: test_canonicalize_do_not_fold_different_roots
-// CHECK: arith.addi
-func.func @test_canonicalize_do_not_fold_different_roots(
- %poly0 : !ntt_poly_ty,
- %poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
- %0 = polynomial.ntt %poly0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !ntt_poly_ty -> !tensor_ty
- %1 = polynomial.ntt %poly1 {root=#polynomial.primitive_root<value=33:i32, degree=8:index>} : !ntt_poly_ty -> !tensor_ty
- %a_plus_b = arith.addi %0, %1 : !tensor_ty
- %out = polynomial.intt %a_plus_b {root=#root} : !tensor_ty -> !ntt_poly_ty
- return %out : !ntt_poly_ty
-}
-
More information about the Mlir-commits
mailing list