[Mlir-commits] [mlir] 1a28f26 - [polynomial] distribute add/sub through ntt to reduce ntts (#93132)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu May 23 10:41:53 PDT 2024
Author: Jeremy Kun
Date: 2024-05-23T10:41:49-07:00
New Revision: 1a28f26b16a3eaefb26acaa410712f337f1cda2c
URL: https://github.com/llvm/llvm-project/commit/1a28f26b16a3eaefb26acaa410712f337f1cda2c
DIFF: https://github.com/llvm/llvm-project/commit/1a28f26b16a3eaefb26acaa410712f337f1cda2c.diff
LOG: [polynomial] distribute add/sub through ntt to reduce ntts (#93132)
Addresses
https://github.com/google/heir/issues/542#issuecomment-2126175775
Co-authored-by: Jeremy Kun <j2kun at users.noreply.github.com>
Added:
Modified:
mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
mlir/test/Dialect/Polynomial/canonicalization.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
index 9d09799c1763a..e37bcf76a20f2 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
@@ -9,11 +9,14 @@
#ifndef POLYNOMIAL_CANONICALIZATION
#define POLYNOMIAL_CANONICALIZATION
-include "mlir/Dialect/Polynomial/IR/Polynomial.td"
include "mlir/Dialect/Arith/IR/ArithOps.td"
+include "mlir/Dialect/Polynomial/IR/Polynomial.td"
+include "mlir/IR/EnumAttr.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/PatternBase.td"
+// Overflow-flags attribute with the "none" case, attached to the arith.addi /
+// arith.subi ops that INTTOfAdd and INTTOfSub create in their replacements.
+defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;
+
// Get a -1 integer attribute of the same type as the polynomial SSA value's
// ring coefficient type.
def getMinusOne
@@ -39,4 +42,40 @@ def NTTAfterINTT : Pat<
[]
>;
+// NTTs are expensive, while addition in the coefficient domain and in the NTT
+// domain should cost about the same, so trading two NTTs for one is a win.
+// ntt(a) + ntt(b) -> ntt(a + b)
+// Matches an arith.addi whose two operands are both polynomial.ntt results and
+// rewrites it to a single polynomial.ntt of a polynomial.add of the inputs.
+// The addi's overflow flags ($overflow) are matched but not forwarded: the
+// replacement polynomial.add is built from $p1 and $p2 only.
+def NTTOfAdd : Pat<
+ (Arith_AddIOp
+ (Polynomial_NTTOp $p1),
+ (Polynomial_NTTOp $p2),
+ $overflow),
+ (Polynomial_NTTOp (Polynomial_AddOp $p1, $p2)),
+ // No extra constraints. NOTE(review): presumably relies on the addi operand
+ // types forcing $p1 and $p2 into the same ring — confirm.
+ []
+>;
+// intt(a) + intt(b) -> intt(a + b)
+// Matches a polynomial.add of two polynomial.intt results and rewrites it to a
+// single polynomial.intt of an arith.addi of the NTT-domain inputs. The new
+// addi carries the DefOverflow ("none") overflow flags.
+def INTTOfAdd : Pat<
+ (Polynomial_AddOp
+ (Polynomial_INTTOp $t1),
+ (Polynomial_INTTOp $t2)),
+ (Polynomial_INTTOp (Arith_AddIOp $t1, $t2, DefOverflow)),
+ []
+>;
+// The same distribution, applied to subtraction.
+// ntt(a) - ntt(b) -> ntt(a - b)
+// Matches an arith.subi of two polynomial.ntt results and rewrites it to a
+// single polynomial.ntt of a polynomial.sub. As in NTTOfAdd, the subi's
+// $overflow flags are matched but not forwarded to the replacement.
+def NTTOfSub : Pat<
+ (Arith_SubIOp
+ (Polynomial_NTTOp $p1),
+ (Polynomial_NTTOp $p2),
+ $overflow),
+ (Polynomial_NTTOp (Polynomial_SubOp $p1, $p2)),
+ []
+>;
+// intt(a) - intt(b) -> intt(a - b)
+// Matches a polynomial.sub of two polynomial.intt results and rewrites it to a
+// single polynomial.intt of an arith.subi carrying DefOverflow ("none") flags.
+def INTTOfSub : Pat<
+ (Polynomial_SubOp
+ (Polynomial_INTTOp $t1),
+ (Polynomial_INTTOp $t2)),
+ (Polynomial_INTTOp (Arith_SubIOp $t1, $t2, DefOverflow)),
+ []
+>;
+
#endif // POLYNOMIAL_CANONICALIZATION
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index d0a25fd9288b9..3d302797ce513 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -283,10 +283,10 @@ void SubOp::getCanonicalizationPatterns(RewritePatternSet &results,
void NTTOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<NTTAfterINTT>(context);
+ // NTTOfAdd / NTTOfSub are rooted at arith.addi / arith.subi (whose operands
+ // are polynomial.ntt results), not at polynomial.ntt itself.
+ // NOTE(review): registering them here presumably relies on canonicalization
+ // gathering every op's patterns regardless of root op — confirm.
+ results.add<NTTAfterINTT, NTTOfAdd, NTTOfSub>(context);
}
void INTTOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<INTTAfterNTT>(context);
+ // INTTOfAdd / INTTOfSub are rooted at polynomial.add / polynomial.sub (whose
+ // operands are polynomial.intt results), not at polynomial.intt itself.
+ // NOTE(review): registering them here presumably relies on canonicalization
+ // gathering every op's patterns regardless of root op — confirm.
+ results.add<INTTAfterNTT, INTTOfAdd, INTTOfSub>(context);
}
diff --git a/mlir/test/Dialect/Polynomial/canonicalization.mlir b/mlir/test/Dialect/Polynomial/canonicalization.mlir
index dbfbf2d93f111..489d9ec2720d6 100644
--- a/mlir/test/Dialect/Polynomial/canonicalization.mlir
+++ b/mlir/test/Dialect/Polynomial/canonicalization.mlir
@@ -43,3 +43,60 @@ func.func @test_canonicalize_sub(%poly0 : !sub_ty, %poly1 : !sub_ty) -> !sub_ty
// CHECK: [[ADD:%.+]] = polynomial.add %[[p0]], %[[p1neg]]
return %0 : !sub_ty
}
+
+// Add distributed through ntt: intt(ntt(a) + ntt(b)) must canonicalize to a
+// single polynomial.add with no ntt/intt left before or after it.
+// CHECK-LABEL: test_canonicalize_fold_add_through_ntt
+// CHECK-NOT: polynomial.ntt
+// CHECK-NOT: polynomial.intt
+// CHECK: polynomial.add
+// CHECK-NOT: polynomial.ntt
+// CHECK-NOT: polynomial.intt
+func.func @test_canonicalize_fold_add_through_ntt(
+ %poly0 : !ntt_poly_ty,
+ %poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
+ %0 = polynomial.ntt %poly0 : !ntt_poly_ty -> !tensor_ty
+ %1 = polynomial.ntt %poly1 : !ntt_poly_ty -> !tensor_ty
+ %a_plus_b = arith.addi %0, %1 : !tensor_ty
+ %out = polynomial.intt %a_plus_b : !tensor_ty -> !ntt_poly_ty
+ return %out : !ntt_poly_ty
+}
+
+// Add distributed through intt: ntt(intt(a) + intt(b)) must canonicalize to a
+// single arith.addi with no ntt/intt left before or after it.
+// CHECK-LABEL: test_canonicalize_fold_add_through_intt
+// CHECK-NOT: polynomial.intt
+// CHECK-NOT: polynomial.ntt
+// CHECK: arith.addi
+// CHECK-NOT: polynomial.intt
+// CHECK-NOT: polynomial.ntt
+func.func @test_canonicalize_fold_add_through_intt(
+ %tensor0 : !tensor_ty,
+ %tensor1 : !tensor_ty) -> !tensor_ty {
+ %0 = polynomial.intt %tensor0 : !tensor_ty -> !ntt_poly_ty
+ %1 = polynomial.intt %tensor1 : !tensor_ty -> !ntt_poly_ty
+ %a_plus_b = polynomial.add %0, %1 : !ntt_poly_ty
+ %out = polynomial.ntt %a_plus_b : !ntt_poly_ty -> !tensor_ty
+ return %out : !tensor_ty
+}
+
+// Sub distributed through ntt: intt(ntt(a) - ntt(b)) must canonicalize to a
+// coefficient-domain subtraction, which the sub canonicalization then rewrites
+// as mul_scalar + add; no ntt/intt may remain before or after it.
+// CHECK-LABEL: test_canonicalize_fold_sub_through_ntt
+// CHECK-NOT: polynomial.ntt
+// CHECK-NOT: polynomial.intt
+// CHECK: polynomial.mul_scalar
+// CHECK: polynomial.add
+// CHECK-NOT: polynomial.ntt
+// CHECK-NOT: polynomial.intt
+func.func @test_canonicalize_fold_sub_through_ntt(
+ %poly0 : !ntt_poly_ty,
+ %poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
+ %0 = polynomial.ntt %poly0 : !ntt_poly_ty -> !tensor_ty
+ %1 = polynomial.ntt %poly1 : !ntt_poly_ty -> !tensor_ty
+ %a_minus_b = arith.subi %0, %1 : !tensor_ty
+ %out = polynomial.intt %a_minus_b : !tensor_ty -> !ntt_poly_ty
+ return %out : !ntt_poly_ty
+}
+
+// Sub distributed through intt: ntt(intt(a) - intt(b)) must canonicalize to a
+// single arith.subi with no ntt/intt left before or after it.
+// CHECK-LABEL: test_canonicalize_fold_sub_through_intt
+// CHECK-NOT: polynomial.intt
+// CHECK-NOT: polynomial.ntt
+// CHECK: arith.subi
+// CHECK-NOT: polynomial.intt
+// CHECK-NOT: polynomial.ntt
+func.func @test_canonicalize_fold_sub_through_intt(
+ %tensor0 : !tensor_ty,
+ %tensor1 : !tensor_ty) -> !tensor_ty {
+ %0 = polynomial.intt %tensor0 : !tensor_ty -> !ntt_poly_ty
+ %1 = polynomial.intt %tensor1 : !tensor_ty -> !ntt_poly_ty
+ %a_minus_b = polynomial.sub %0, %1 : !ntt_poly_ty
+ %out = polynomial.ntt %a_minus_b : !ntt_poly_ty -> !tensor_ty
+ return %out : !tensor_ty
+}
More information about the Mlir-commits
mailing list