[Mlir-commits] [mlir] [mlir][polynomial] remove incorrect canonicalization rule (PR #110318)

Hongren Zheng llvmlistbot at llvm.org
Fri Sep 27 12:17:04 PDT 2024


https://github.com/ZenithalHourlyRate created https://github.com/llvm/llvm-project/pull/110318

arith.addi on tensors does not reduce its result modulo coefficientModulus and may overflow, so the rewritten result could be incorrect

The canonicalization rule should be rewritten to use modular arithmetic instead of plain arith ops

Revert https://github.com/llvm/llvm-project/pull/93132
Addresses https://github.com/google/heir/issues/749

Cc @j2kun

>From 4976cf8c86ecee924b8e9664015e8c07da4676d7 Mon Sep 17 00:00:00 2001
From: Zenithal <i at zenithal.me>
Date: Fri, 27 Sep 2024 18:42:17 +0000
Subject: [PATCH] [mlir][polynomial] remove incorrect canonicalization rule

arith.addi on tensors does not reduce modulo coefficientModulus and may
overflow, so the rewritten result could be incorrect

The rule should be rewritten using modular arithmetic, not plain arith ops

Revert https://github.com/llvm/llvm-project/pull/93132
Addresses https://github.com/google/heir/issues/749
---
 .../IR/PolynomialCanonicalization.td          | 39 -----------
 .../Dialect/Polynomial/IR/PolynomialOps.cpp   |  4 +-
 .../Dialect/Polynomial/canonicalization.mlir  | 70 -------------------
 3 files changed, 2 insertions(+), 111 deletions(-)

diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
index 93ea6e4e43698d..28c45e6846380c 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
@@ -11,12 +11,9 @@
 
 include "mlir/Dialect/Arith/IR/ArithOps.td"
 include "mlir/Dialect/Polynomial/IR/Polynomial.td"
-include "mlir/IR/EnumAttr.td"
 include "mlir/IR/OpBase.td"
 include "mlir/IR/PatternBase.td"
 
-defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;
-
 def Equal : Constraint<CPred<"$0 == $1">>;
 
 // Get a -1 integer attribute of the same type as the polynomial SSA value's
@@ -44,40 +41,4 @@ def NTTAfterINTT : Pat<
   [(Equal $r1, $r2)]
 >;
 
-// NTTs are expensive, and addition in coefficient or NTT domain should be
-// equivalently expensive, so reducing the number of NTTs is optimal.
-// ntt(a) + ntt(b) -> ntt(a + b)
-def NTTOfAdd : Pat<
-  (Arith_AddIOp
-    (Polynomial_NTTOp $p1, $r1),
-    (Polynomial_NTTOp $p2, $r2),
-    $overflow),
-  (Polynomial_NTTOp (Polynomial_AddOp $p1, $p2), $r1),
-  [(Equal $r1, $r2)]
->;
-// intt(a) + intt(b) -> intt(a + b)
-def INTTOfAdd : Pat<
-  (Polynomial_AddOp
-    (Polynomial_INTTOp $t1, $r1),
-    (Polynomial_INTTOp $t2, $r2)),
-  (Polynomial_INTTOp (Arith_AddIOp $t1, $t2, DefOverflow), $r1),
-  [(Equal $r1, $r2)]
->;
-// repeated for sub
-def NTTOfSub : Pat<
-  (Arith_SubIOp
-    (Polynomial_NTTOp $p1, $r1),
-    (Polynomial_NTTOp $p2, $r2),
-    $overflow),
-  (Polynomial_NTTOp (Polynomial_SubOp $p1, $p2), $r1),
-  [(Equal $r1, $r2)]
->;
-def INTTOfSub : Pat<
-  (Polynomial_SubOp
-    (Polynomial_INTTOp $t1, $r1),
-    (Polynomial_INTTOp $t2, $r2)),
-  (Polynomial_INTTOp (Arith_SubIOp $t1, $t2, DefOverflow), $r1),
-  [(Equal $r1, $r2)]
->;
-
 #endif  // POLYNOMIAL_CANONICALIZATION
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 2ba13bb7dab569..460ef17167e801 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -289,10 +289,10 @@ void SubOp::getCanonicalizationPatterns(RewritePatternSet &results,
 
 void NTTOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
-  results.add<NTTAfterINTT, NTTOfAdd, NTTOfSub>(context);
+  results.add<NTTAfterINTT>(context);
 }
 
 void INTTOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
-  results.add<INTTAfterNTT, INTTOfAdd, INTTOfSub>(context);
+  results.add<INTTAfterNTT>(context);
 }
diff --git a/mlir/test/Dialect/Polynomial/canonicalization.mlir b/mlir/test/Dialect/Polynomial/canonicalization.mlir
index b79938627e4154..c0ee514daab645 100644
--- a/mlir/test/Dialect/Polynomial/canonicalization.mlir
+++ b/mlir/test/Dialect/Polynomial/canonicalization.mlir
@@ -45,73 +45,3 @@ func.func @test_canonicalize_sub(%poly0 : !sub_ty, %poly1 : !sub_ty) -> !sub_ty
   return %0 : !sub_ty
 }
 
-// CHECK-LABEL: test_canonicalize_fold_add_through_ntt
-// CHECK: polynomial.add
-// CHECK-NOT: polynomial.ntt
-// CHECK-NOT: polynomial.intt
-func.func @test_canonicalize_fold_add_through_ntt(
-    %poly0 : !ntt_poly_ty,
-    %poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
-  %0 = polynomial.ntt %poly0 {root=#root} : !ntt_poly_ty -> !tensor_ty
-  %1 = polynomial.ntt %poly1 {root=#root} : !ntt_poly_ty -> !tensor_ty
-  %a_plus_b = arith.addi %0, %1 : !tensor_ty
-  %out = polynomial.intt %a_plus_b {root=#root} : !tensor_ty -> !ntt_poly_ty
-  return %out : !ntt_poly_ty
-}
-
-// CHECK-LABEL: test_canonicalize_fold_add_through_intt
-// CHECK: arith.addi
-// CHECK-NOT: polynomial.intt
-// CHECK-NOT: polynomial.iintt
-func.func @test_canonicalize_fold_add_through_intt(
-    %tensor0 : !tensor_ty,
-    %tensor1 : !tensor_ty) -> !tensor_ty {
-  %0 = polynomial.intt %tensor0 {root=#root} : !tensor_ty -> !ntt_poly_ty
-  %1 = polynomial.intt %tensor1 {root=#root} : !tensor_ty -> !ntt_poly_ty
-  %a_plus_b = polynomial.add %0, %1 : !ntt_poly_ty
-  %out = polynomial.ntt %a_plus_b {root=#root} : !ntt_poly_ty -> !tensor_ty
-  return %out : !tensor_ty
-}
-
-// CHECK-LABEL: test_canonicalize_fold_sub_through_ntt
-// CHECK: polynomial.mul_scalar
-// CHECK: polynomial.add
-// CHECK-NOT: polynomial.ntt
-// CHECK-NOT: polynomial.intt
-func.func @test_canonicalize_fold_sub_through_ntt(
-    %poly0 : !ntt_poly_ty,
-    %poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
-  %0 = polynomial.ntt %poly0 {root=#root} : !ntt_poly_ty -> !tensor_ty
-  %1 = polynomial.ntt %poly1 {root=#root} : !ntt_poly_ty -> !tensor_ty
-  %a_plus_b = arith.subi %0, %1 : !tensor_ty
-  %out = polynomial.intt %a_plus_b {root=#root} : !tensor_ty -> !ntt_poly_ty
-  return %out : !ntt_poly_ty
-}
-
-// CHECK-LABEL: test_canonicalize_fold_sub_through_intt
-// CHECK: arith.subi
-// CHECK-NOT: polynomial.intt
-// CHECK-NOT: polynomial.iintt
-func.func @test_canonicalize_fold_sub_through_intt(
-    %tensor0 : !tensor_ty,
-    %tensor1 : !tensor_ty) -> !tensor_ty {
-  %0 = polynomial.intt %tensor0 {root=#root} : !tensor_ty -> !ntt_poly_ty
-  %1 = polynomial.intt %tensor1 {root=#root} : !tensor_ty -> !ntt_poly_ty
-  %a_plus_b = polynomial.sub %0, %1 : !ntt_poly_ty
-  %out = polynomial.ntt %a_plus_b {root=#root} : !ntt_poly_ty -> !tensor_ty
-  return %out : !tensor_ty
-}
-
-
-// CHECK-LABEL: test_canonicalize_do_not_fold_different_roots
-// CHECK: arith.addi
-func.func @test_canonicalize_do_not_fold_different_roots(
-    %poly0 : !ntt_poly_ty,
-    %poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
-  %0 = polynomial.ntt %poly0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !ntt_poly_ty -> !tensor_ty
-  %1 = polynomial.ntt %poly1 {root=#polynomial.primitive_root<value=33:i32, degree=8:index>} : !ntt_poly_ty -> !tensor_ty
-  %a_plus_b = arith.addi %0, %1 : !tensor_ty
-  %out = polynomial.intt %a_plus_b {root=#root} : !tensor_ty -> !ntt_poly_ty
-  return %out : !ntt_poly_ty
-}
-



More information about the Mlir-commits mailing list