[Mlir-commits] [mlir] 1a28f26 - [polynomial] distribute add/sub through ntt to reduce ntts (#93132)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 23 10:41:53 PDT 2024


Author: Jeremy Kun
Date: 2024-05-23T10:41:49-07:00
New Revision: 1a28f26b16a3eaefb26acaa410712f337f1cda2c

URL: https://github.com/llvm/llvm-project/commit/1a28f26b16a3eaefb26acaa410712f337f1cda2c
DIFF: https://github.com/llvm/llvm-project/commit/1a28f26b16a3eaefb26acaa410712f337f1cda2c.diff

LOG: [polynomial] distribute add/sub through ntt to reduce ntts (#93132)

Addresses
https://github.com/google/heir/issues/542#issuecomment-2126175775

Co-authored-by: Jeremy Kun <j2kun at users.noreply.github.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
    mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
    mlir/test/Dialect/Polynomial/canonicalization.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
index 9d09799c1763a..e37bcf76a20f2 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
@@ -9,11 +9,14 @@
 #ifndef POLYNOMIAL_CANONICALIZATION
 #define POLYNOMIAL_CANONICALIZATION
 
-include "mlir/Dialect/Polynomial/IR/Polynomial.td"
 include "mlir/Dialect/Arith/IR/ArithOps.td"
+include "mlir/Dialect/Polynomial/IR/Polynomial.td"
+include "mlir/IR/EnumAttr.td"
 include "mlir/IR/OpBase.td"
 include "mlir/IR/PatternBase.td"
 
+// Overflow flags ("none") for arith ops created by the patterns below.
+defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;
 // Get a -1 integer attribute of the same type as the polynomial SSA value's
 // ring coefficient type.
 def getMinusOne
@@ -39,4 +42,40 @@ def NTTAfterINTT : Pat<
   []
 >;
 
+// NTTs are expensive, and addition in coefficient or NTT domain should be
+// equally expensive, so these rewrites aim to reduce the number of NTTs.
+// ntt(a) + ntt(b) -> ntt(a + b)
+def NTTOfAdd : Pat<
+  (Arith_AddIOp
+    (Polynomial_NTTOp $p1),
+    (Polynomial_NTTOp $p2),
+    $overflow),
+  (Polynomial_NTTOp (Polynomial_AddOp $p1, $p2)),
+  []
+>;
+// intt(a) + intt(b) -> intt(a + b)
+def INTTOfAdd : Pat<
+  (Polynomial_AddOp
+    (Polynomial_INTTOp $t1),
+    (Polynomial_INTTOp $t2)),
+  (Polynomial_INTTOp (Arith_AddIOp $t1, $t2, DefOverflow)),
+  []
+>;
+// For sub: ntt(a) - ntt(b) -> ntt(a - b), intt(a) - intt(b) -> intt(a - b)
+def NTTOfSub : Pat<
+  (Arith_SubIOp
+    (Polynomial_NTTOp $p1),
+    (Polynomial_NTTOp $p2),
+    $overflow),
+  (Polynomial_NTTOp (Polynomial_SubOp $p1, $p2)),
+  []
+>;
+def INTTOfSub : Pat<
+  (Polynomial_SubOp
+    (Polynomial_INTTOp $t1),
+    (Polynomial_INTTOp $t2)),
+  (Polynomial_INTTOp (Arith_SubIOp $t1, $t2, DefOverflow)),
+  []
+>;
+
 #endif  // POLYNOMIAL_CANONICALIZATION

diff  --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index d0a25fd9288b9..3d302797ce513 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -283,10 +283,12 @@ void SubOp::getCanonicalizationPatterns(RewritePatternSet &results,
 
 void NTTOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
-  results.add<NTTAfterINTT>(context);
+  // NTTOfAdd/NTTOfSub are rooted at arith.addi/arith.subi of two NTT results.
+  results.add<NTTAfterINTT, NTTOfAdd, NTTOfSub>(context);
 }
 
 void INTTOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
-  results.add<INTTAfterNTT>(context);
+  // INTTOfAdd/INTTOfSub are rooted at polynomial.add/sub of two INTT results.
+  results.add<INTTAfterNTT, INTTOfAdd, INTTOfSub>(context);
 }

diff  --git a/mlir/test/Dialect/Polynomial/canonicalization.mlir b/mlir/test/Dialect/Polynomial/canonicalization.mlir
index dbfbf2d93f111..489d9ec2720d6 100644
--- a/mlir/test/Dialect/Polynomial/canonicalization.mlir
+++ b/mlir/test/Dialect/Polynomial/canonicalization.mlir
@@ -43,3 +43,60 @@ func.func @test_canonicalize_sub(%poly0 : !sub_ty, %poly1 : !sub_ty) -> !sub_ty
   // CHECK: [[ADD:%.+]] = polynomial.add %[[p0]], %[[p1neg]]
   return %0 : !sub_ty
 }
+
+// CHECK-LABEL: test_canonicalize_fold_add_through_ntt
+// CHECK: polynomial.add
+// CHECK-NOT: polynomial.ntt
+// CHECK-NOT: polynomial.intt
+func.func @test_canonicalize_fold_add_through_ntt(
+    %poly0 : !ntt_poly_ty,
+    %poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
+  %ntt0 = polynomial.ntt %poly0 : !ntt_poly_ty -> !tensor_ty
+  %ntt1 = polynomial.ntt %poly1 : !ntt_poly_ty -> !tensor_ty
+  %sum = arith.addi %ntt0, %ntt1 : !tensor_ty
+  %result = polynomial.intt %sum : !tensor_ty -> !ntt_poly_ty
+  return %result : !ntt_poly_ty
+}
+
+// CHECK-LABEL: test_canonicalize_fold_add_through_intt
+// CHECK: arith.addi
+// CHECK-NOT: polynomial.intt
+// CHECK-NOT: polynomial.ntt
+func.func @test_canonicalize_fold_add_through_intt(
+    %tensor0 : !tensor_ty,
+    %tensor1 : !tensor_ty) -> !tensor_ty {
+  %0 = polynomial.intt %tensor0 : !tensor_ty -> !ntt_poly_ty
+  %1 = polynomial.intt %tensor1 : !tensor_ty -> !ntt_poly_ty
+  %a_plus_b = polynomial.add %0, %1 : !ntt_poly_ty
+  %out = polynomial.ntt %a_plus_b : !ntt_poly_ty -> !tensor_ty
+  return %out : !tensor_ty
+}
+
+// CHECK-LABEL: test_canonicalize_fold_sub_through_ntt
+// CHECK: polynomial.mul_scalar
+// CHECK: polynomial.add
+// CHECK-NOT: polynomial.ntt
+// CHECK-NOT: polynomial.intt
+func.func @test_canonicalize_fold_sub_through_ntt(
+    %poly0 : !ntt_poly_ty,
+    %poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
+  %0 = polynomial.ntt %poly0 : !ntt_poly_ty -> !tensor_ty
+  %1 = polynomial.ntt %poly1 : !ntt_poly_ty -> !tensor_ty
+  %a_minus_b = arith.subi %0, %1 : !tensor_ty
+  %out = polynomial.intt %a_minus_b : !tensor_ty -> !ntt_poly_ty
+  return %out : !ntt_poly_ty
+}
+
+// CHECK-LABEL: test_canonicalize_fold_sub_through_intt
+// CHECK: arith.subi
+// CHECK-NOT: polynomial.intt
+// CHECK-NOT: polynomial.ntt
+func.func @test_canonicalize_fold_sub_through_intt(
+    %tensor0 : !tensor_ty,
+    %tensor1 : !tensor_ty) -> !tensor_ty {
+  %0 = polynomial.intt %tensor0 : !tensor_ty -> !ntt_poly_ty
+  %1 = polynomial.intt %tensor1 : !tensor_ty -> !ntt_poly_ty
+  %a_minus_b = polynomial.sub %0, %1 : !ntt_poly_ty
+  %out = polynomial.ntt %a_minus_b : !ntt_poly_ty -> !tensor_ty
+  return %out : !tensor_ty
+}


        


More information about the Mlir-commits mailing list