[Mlir-commits] [mlir] d48777e - [mlir][polynomial] remove incorrect canonicalization rule (#110318)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Sep 28 08:38:00 PDT 2024


Author: Hongren Zheng
Date: 2024-09-28T08:37:57-07:00
New Revision: d48777ece50c39df553ed779d0771bc9ef6747cf

URL: https://github.com/llvm/llvm-project/commit/d48777ece50c39df553ed779d0771bc9ef6747cf
DIFF: https://github.com/llvm/llvm-project/commit/d48777ece50c39df553ed779d0771bc9ef6747cf.diff

LOG: [mlir][polynomial] remove incorrect canonicalization rule (#110318)

arith.addi on tensors does not reduce modulo coefficientModulus, so it may
overflow and produce an incorrect result.

These patterns should instead be expressed with modular arithmetic rather than plain arith ops.

Revert https://github.com/llvm/llvm-project/pull/93132
Addresses https://github.com/google/heir/issues/749

Cc @j2kun

Added: 
    

Modified: 
    mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
    mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
    mlir/test/Dialect/Polynomial/canonicalization.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
index 93ea6e4e43698d..28c45e6846380c 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
@@ -11,12 +11,9 @@
 
 include "mlir/Dialect/Arith/IR/ArithOps.td"
 include "mlir/Dialect/Polynomial/IR/Polynomial.td"
-include "mlir/IR/EnumAttr.td"
 include "mlir/IR/OpBase.td"
 include "mlir/IR/PatternBase.td"
 
-defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;
-
 def Equal : Constraint<CPred<"$0 == $1">>;
 
 // Get a -1 integer attribute of the same type as the polynomial SSA value's
@@ -44,40 +41,4 @@ def NTTAfterINTT : Pat<
   [(Equal $r1, $r2)]
 >;
 
-// NTTs are expensive, and addition in coefficient or NTT domain should be
-// equivalently expensive, so reducing the number of NTTs is optimal.
-// ntt(a) + ntt(b) -> ntt(a + b)
-def NTTOfAdd : Pat<
-  (Arith_AddIOp
-    (Polynomial_NTTOp $p1, $r1),
-    (Polynomial_NTTOp $p2, $r2),
-    $overflow),
-  (Polynomial_NTTOp (Polynomial_AddOp $p1, $p2), $r1),
-  [(Equal $r1, $r2)]
->;
-// intt(a) + intt(b) -> intt(a + b)
-def INTTOfAdd : Pat<
-  (Polynomial_AddOp
-    (Polynomial_INTTOp $t1, $r1),
-    (Polynomial_INTTOp $t2, $r2)),
-  (Polynomial_INTTOp (Arith_AddIOp $t1, $t2, DefOverflow), $r1),
-  [(Equal $r1, $r2)]
->;
-// repeated for sub
-def NTTOfSub : Pat<
-  (Arith_SubIOp
-    (Polynomial_NTTOp $p1, $r1),
-    (Polynomial_NTTOp $p2, $r2),
-    $overflow),
-  (Polynomial_NTTOp (Polynomial_SubOp $p1, $p2), $r1),
-  [(Equal $r1, $r2)]
->;
-def INTTOfSub : Pat<
-  (Polynomial_SubOp
-    (Polynomial_INTTOp $t1, $r1),
-    (Polynomial_INTTOp $t2, $r2)),
-  (Polynomial_INTTOp (Arith_SubIOp $t1, $t2, DefOverflow), $r1),
-  [(Equal $r1, $r2)]
->;
-
 #endif  // POLYNOMIAL_CANONICALIZATION

diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 2ba13bb7dab569..460ef17167e801 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -289,10 +289,10 @@ void SubOp::getCanonicalizationPatterns(RewritePatternSet &results,
 
 void NTTOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
-  results.add<NTTAfterINTT, NTTOfAdd, NTTOfSub>(context);
+  results.add<NTTAfterINTT>(context);
 }
 
 void INTTOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
-  results.add<INTTAfterNTT, INTTOfAdd, INTTOfSub>(context);
+  results.add<INTTAfterNTT>(context);
 }

diff --git a/mlir/test/Dialect/Polynomial/canonicalization.mlir b/mlir/test/Dialect/Polynomial/canonicalization.mlir
index b79938627e4154..c0ee514daab645 100644
--- a/mlir/test/Dialect/Polynomial/canonicalization.mlir
+++ b/mlir/test/Dialect/Polynomial/canonicalization.mlir
@@ -45,73 +45,3 @@ func.func @test_canonicalize_sub(%poly0 : !sub_ty, %poly1 : !sub_ty) -> !sub_ty
   return %0 : !sub_ty
 }
 
-// CHECK-LABEL: test_canonicalize_fold_add_through_ntt
-// CHECK: polynomial.add
-// CHECK-NOT: polynomial.ntt
-// CHECK-NOT: polynomial.intt
-func.func @test_canonicalize_fold_add_through_ntt(
-    %poly0 : !ntt_poly_ty,
-    %poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
-  %0 = polynomial.ntt %poly0 {root=#root} : !ntt_poly_ty -> !tensor_ty
-  %1 = polynomial.ntt %poly1 {root=#root} : !ntt_poly_ty -> !tensor_ty
-  %a_plus_b = arith.addi %0, %1 : !tensor_ty
-  %out = polynomial.intt %a_plus_b {root=#root} : !tensor_ty -> !ntt_poly_ty
-  return %out : !ntt_poly_ty
-}
-
-// CHECK-LABEL: test_canonicalize_fold_add_through_intt
-// CHECK: arith.addi
-// CHECK-NOT: polynomial.intt
-// CHECK-NOT: polynomial.iintt
-func.func @test_canonicalize_fold_add_through_intt(
-    %tensor0 : !tensor_ty,
-    %tensor1 : !tensor_ty) -> !tensor_ty {
-  %0 = polynomial.intt %tensor0 {root=#root} : !tensor_ty -> !ntt_poly_ty
-  %1 = polynomial.intt %tensor1 {root=#root} : !tensor_ty -> !ntt_poly_ty
-  %a_plus_b = polynomial.add %0, %1 : !ntt_poly_ty
-  %out = polynomial.ntt %a_plus_b {root=#root} : !ntt_poly_ty -> !tensor_ty
-  return %out : !tensor_ty
-}
-
-// CHECK-LABEL: test_canonicalize_fold_sub_through_ntt
-// CHECK: polynomial.mul_scalar
-// CHECK: polynomial.add
-// CHECK-NOT: polynomial.ntt
-// CHECK-NOT: polynomial.intt
-func.func @test_canonicalize_fold_sub_through_ntt(
-    %poly0 : !ntt_poly_ty,
-    %poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
-  %0 = polynomial.ntt %poly0 {root=#root} : !ntt_poly_ty -> !tensor_ty
-  %1 = polynomial.ntt %poly1 {root=#root} : !ntt_poly_ty -> !tensor_ty
-  %a_plus_b = arith.subi %0, %1 : !tensor_ty
-  %out = polynomial.intt %a_plus_b {root=#root} : !tensor_ty -> !ntt_poly_ty
-  return %out : !ntt_poly_ty
-}
-
-// CHECK-LABEL: test_canonicalize_fold_sub_through_intt
-// CHECK: arith.subi
-// CHECK-NOT: polynomial.intt
-// CHECK-NOT: polynomial.iintt
-func.func @test_canonicalize_fold_sub_through_intt(
-    %tensor0 : !tensor_ty,
-    %tensor1 : !tensor_ty) -> !tensor_ty {
-  %0 = polynomial.intt %tensor0 {root=#root} : !tensor_ty -> !ntt_poly_ty
-  %1 = polynomial.intt %tensor1 {root=#root} : !tensor_ty -> !ntt_poly_ty
-  %a_plus_b = polynomial.sub %0, %1 : !ntt_poly_ty
-  %out = polynomial.ntt %a_plus_b {root=#root} : !ntt_poly_ty -> !tensor_ty
-  return %out : !tensor_ty
-}
-
-
-// CHECK-LABEL: test_canonicalize_do_not_fold_different_roots
-// CHECK: arith.addi
-func.func @test_canonicalize_do_not_fold_different_roots(
-    %poly0 : !ntt_poly_ty,
-    %poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
-  %0 = polynomial.ntt %poly0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !ntt_poly_ty -> !tensor_ty
-  %1 = polynomial.ntt %poly1 {root=#polynomial.primitive_root<value=33:i32, degree=8:index>} : !ntt_poly_ty -> !tensor_ty
-  %a_plus_b = arith.addi %0, %1 : !tensor_ty
-  %out = polynomial.intt %a_plus_b {root=#root} : !tensor_ty -> !ntt_poly_ty
-  return %out : !ntt_poly_ty
-}
-


        


More information about the Mlir-commits mailing list