[Mlir-commits] [mlir] [polynomial] distribute add/sub through ntt to reduce ntts (PR #93132)

Jeremy Kun llvmlistbot at llvm.org
Wed May 22 21:17:04 PDT 2024


https://github.com/j2kun created https://github.com/llvm/llvm-project/pull/93132

Addresses https://github.com/google/heir/issues/542#issuecomment-2126175775

>From 8ae8d047efac316cc5edf9dc93428071fc80fa74 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Wed, 22 May 2024 21:16:12 -0700
Subject: [PATCH] [polynomial] distribute add/sub through ntt to reduce ntts

---
 .../IR/PolynomialCanonicalization.td          | 41 ++++++++++++-
 .../Dialect/Polynomial/IR/PolynomialOps.cpp   |  4 +-
 .../Dialect/Polynomial/canonicalization.mlir  | 57 +++++++++++++++++++
 3 files changed, 99 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
index 9d09799c1763a..e37bcf76a20f2 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
@@ -9,11 +9,14 @@
 #ifndef POLYNOMIAL_CANONICALIZATION
 #define POLYNOMIAL_CANONICALIZATION
 
-include "mlir/Dialect/Polynomial/IR/Polynomial.td"
 include "mlir/Dialect/Arith/IR/ArithOps.td"
+include "mlir/Dialect/Polynomial/IR/Polynomial.td"
+include "mlir/IR/EnumAttr.td"
 include "mlir/IR/OpBase.td"
 include "mlir/IR/PatternBase.td"
 
+defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;
+
 // Get a -1 integer attribute of the same type as the polynomial SSA value's
 // ring coefficient type.
 def getMinusOne
@@ -39,4 +42,40 @@ def NTTAfterINTT : Pat<
   []
 >;
 
+// NTTs are expensive, while addition costs about the same in the coefficient
+// domain as in the NTT domain, so rewrites that reduce the NTT count are a win.
+// ntt(a) + ntt(b) -> ntt(a + b)
+def NTTOfAdd : Pat<
+  (Arith_AddIOp
+    (Polynomial_NTTOp $p1),
+    (Polynomial_NTTOp $p2),
+    $overflow),
+  (Polynomial_NTTOp (Polynomial_AddOp $p1, $p2)),
+  []
+>;
+// intt(a) + intt(b) -> intt(a + b)
+def INTTOfAdd : Pat<
+  (Polynomial_AddOp
+    (Polynomial_INTTOp $t1),
+    (Polynomial_INTTOp $t2)),
+  (Polynomial_INTTOp (Arith_AddIOp $t1, $t2, DefOverflow)),
+  []
+>;
+// Same two rewrites for sub: ntt(a) - ntt(b) -> ntt(a - b), likewise for intt.
+def NTTOfSub : Pat<
+  (Arith_SubIOp
+    (Polynomial_NTTOp $p1),
+    (Polynomial_NTTOp $p2),
+    $overflow),
+  (Polynomial_NTTOp (Polynomial_SubOp $p1, $p2)),
+  []
+>;
+def INTTOfSub : Pat<
+  (Polynomial_SubOp
+    (Polynomial_INTTOp $t1),
+    (Polynomial_INTTOp $t2)),
+  (Polynomial_INTTOp (Arith_SubIOp $t1, $t2, DefOverflow)),
+  []
+>;
+
 #endif  // POLYNOMIAL_CANONICALIZATION
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 1a2439fe810b5..98263732da8a9 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -201,10 +201,10 @@ void SubOp::getCanonicalizationPatterns(RewritePatternSet &results,
 
 void NTTOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
-  results.add<NTTAfterINTT>(context);
+  results.add<NTTAfterINTT, NTTOfAdd, NTTOfSub>(context);
 }
 
 void INTTOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
-  results.add<INTTAfterNTT>(context);
+  results.add<INTTAfterNTT, INTTOfAdd, INTTOfSub>(context);
 }
diff --git a/mlir/test/Dialect/Polynomial/canonicalization.mlir b/mlir/test/Dialect/Polynomial/canonicalization.mlir
index dbfbf2d93f111..489d9ec2720d6 100644
--- a/mlir/test/Dialect/Polynomial/canonicalization.mlir
+++ b/mlir/test/Dialect/Polynomial/canonicalization.mlir
@@ -43,3 +43,60 @@ func.func @test_canonicalize_sub(%poly0 : !sub_ty, %poly1 : !sub_ty) -> !sub_ty
   // CHECK: [[ADD:%.+]] = polynomial.add %[[p0]], %[[p1neg]]
   return %0 : !sub_ty
 }
+
+// CHECK-LABEL: test_canonicalize_fold_add_through_ntt
+// CHECK: polynomial.add
+// CHECK-NOT: polynomial.ntt
+// CHECK-NOT: polynomial.intt
+func.func @test_canonicalize_fold_add_through_ntt(
+    %poly0 : !ntt_poly_ty,
+    %poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
+  %0 = polynomial.ntt %poly0 : !ntt_poly_ty -> !tensor_ty
+  %1 = polynomial.ntt %poly1 : !ntt_poly_ty -> !tensor_ty
+  %a_plus_b = arith.addi %0, %1 : !tensor_ty
+  %out = polynomial.intt %a_plus_b : !tensor_ty -> !ntt_poly_ty
+  return %out : !ntt_poly_ty
+}
+
+// CHECK-LABEL: test_canonicalize_fold_add_through_intt
+// CHECK: arith.addi
+// CHECK-NOT: polynomial.intt
+// CHECK-NOT: polynomial.ntt
+func.func @test_canonicalize_fold_add_through_intt(
+    %tensor0 : !tensor_ty,
+    %tensor1 : !tensor_ty) -> !tensor_ty {
+  %0 = polynomial.intt %tensor0 : !tensor_ty -> !ntt_poly_ty
+  %1 = polynomial.intt %tensor1 : !tensor_ty -> !ntt_poly_ty
+  %a_plus_b = polynomial.add %0, %1 : !ntt_poly_ty
+  %out = polynomial.ntt %a_plus_b : !ntt_poly_ty -> !tensor_ty
+  return %out : !tensor_ty
+}
+
+// CHECK-LABEL: test_canonicalize_fold_sub_through_ntt
+// CHECK: polynomial.mul_scalar
+// CHECK: polynomial.add
+// CHECK-NOT: polynomial.ntt
+// CHECK-NOT: polynomial.intt
+func.func @test_canonicalize_fold_sub_through_ntt(
+    %poly0 : !ntt_poly_ty,
+    %poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
+  %0 = polynomial.ntt %poly0 : !ntt_poly_ty -> !tensor_ty
+  %1 = polynomial.ntt %poly1 : !ntt_poly_ty -> !tensor_ty
+  %a_minus_b = arith.subi %0, %1 : !tensor_ty
+  %out = polynomial.intt %a_minus_b : !tensor_ty -> !ntt_poly_ty
+  return %out : !ntt_poly_ty
+}
+
+// CHECK-LABEL: test_canonicalize_fold_sub_through_intt
+// CHECK: arith.subi
+// CHECK-NOT: polynomial.intt
+// CHECK-NOT: polynomial.ntt
+func.func @test_canonicalize_fold_sub_through_intt(
+    %tensor0 : !tensor_ty,
+    %tensor1 : !tensor_ty) -> !tensor_ty {
+  %0 = polynomial.intt %tensor0 : !tensor_ty -> !ntt_poly_ty
+  %1 = polynomial.intt %tensor1 : !tensor_ty -> !ntt_poly_ty
+  %a_minus_b = polynomial.sub %0, %1 : !ntt_poly_ty
+  %out = polynomial.ntt %a_minus_b : !ntt_poly_ty -> !tensor_ty
+  return %out : !tensor_ty
+}



More information about the Mlir-commits mailing list