[Mlir-commits] [mlir] [polynomial] distribute add/sub through ntt to reduce ntts (PR #93132)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed May 22 21:17:35 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Jeremy Kun (j2kun)
<details>
<summary>Changes</summary>
Addresses https://github.com/google/heir/issues/542#issuecomment-2126175775
---
Full diff: https://github.com/llvm/llvm-project/pull/93132.diff
3 Files Affected:
- (modified) mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td (+40-1)
- (modified) mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp (+2-2)
- (modified) mlir/test/Dialect/Polynomial/canonicalization.mlir (+57)
``````````diff
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
index 9d09799c1763a..e37bcf76a20f2 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
@@ -9,11 +9,14 @@
#ifndef POLYNOMIAL_CANONICALIZATION
#define POLYNOMIAL_CANONICALIZATION
-include "mlir/Dialect/Polynomial/IR/Polynomial.td"
include "mlir/Dialect/Arith/IR/ArithOps.td"
+include "mlir/Dialect/Polynomial/IR/Polynomial.td"
+include "mlir/IR/EnumAttr.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/PatternBase.td"
+defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;
+
// Get a -1 integer attribute of the same type as the polynomial SSA value's
// ring coefficient type.
def getMinusOne
@@ -39,4 +42,40 @@ def NTTAfterINTT : Pat<
[]
>;
+// NTTs are expensive, and addition in coefficient or NTT domain should be
+// equivalently expensive, so reducing the number of NTTs is optimal.
+// ntt(a) + ntt(b) -> ntt(a + b)
+def NTTOfAdd : Pat<
+ (Arith_AddIOp
+ (Polynomial_NTTOp $p1),
+ (Polynomial_NTTOp $p2),
+ $overflow),
+ (Polynomial_NTTOp (Polynomial_AddOp $p1, $p2)),
+ []
+>;
+// intt(a) + intt(b) -> intt(a + b)
+def INTTOfAdd : Pat<
+ (Polynomial_AddOp
+ (Polynomial_INTTOp $t1),
+ (Polynomial_INTTOp $t2)),
+ (Polynomial_INTTOp (Arith_AddIOp $t1, $t2, DefOverflow)),
+ []
+>;
+// Same for sub: ntt(a) - ntt(b) -> ntt(a - b); intt(a) - intt(b) -> intt(a - b)
+def NTTOfSub : Pat<
+ (Arith_SubIOp
+ (Polynomial_NTTOp $p1),
+ (Polynomial_NTTOp $p2),
+ $overflow),
+ (Polynomial_NTTOp (Polynomial_SubOp $p1, $p2)),
+ []
+>;
+def INTTOfSub : Pat<
+ (Polynomial_SubOp
+ (Polynomial_INTTOp $t1),
+ (Polynomial_INTTOp $t2)),
+ (Polynomial_INTTOp (Arith_SubIOp $t1, $t2, DefOverflow)),
+ []
+>;
+
#endif // POLYNOMIAL_CANONICALIZATION
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 1a2439fe810b5..98263732da8a9 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -201,10 +201,10 @@ void SubOp::getCanonicalizationPatterns(RewritePatternSet &results,
void NTTOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<NTTAfterINTT>(context);
+ results.add<NTTAfterINTT, NTTOfAdd, NTTOfSub>(context);
}
void INTTOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<INTTAfterNTT>(context);
+ results.add<INTTAfterNTT, INTTOfAdd, INTTOfSub>(context);
}
diff --git a/mlir/test/Dialect/Polynomial/canonicalization.mlir b/mlir/test/Dialect/Polynomial/canonicalization.mlir
index dbfbf2d93f111..489d9ec2720d6 100644
--- a/mlir/test/Dialect/Polynomial/canonicalization.mlir
+++ b/mlir/test/Dialect/Polynomial/canonicalization.mlir
@@ -43,3 +43,60 @@ func.func @test_canonicalize_sub(%poly0 : !sub_ty, %poly1 : !sub_ty) -> !sub_ty
// CHECK: [[ADD:%.+]] = polynomial.add %[[p0]], %[[p1neg]]
return %0 : !sub_ty
}
+
+// CHECK-LABEL: test_canonicalize_fold_add_through_ntt
+// CHECK: polynomial.add
+// CHECK-NOT: polynomial.ntt
+// CHECK-NOT: polynomial.intt
+func.func @test_canonicalize_fold_add_through_ntt(
+ %poly0 : !ntt_poly_ty,
+ %poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
+ %0 = polynomial.ntt %poly0 : !ntt_poly_ty -> !tensor_ty
+ %1 = polynomial.ntt %poly1 : !ntt_poly_ty -> !tensor_ty
+ %a_plus_b = arith.addi %0, %1 : !tensor_ty
+ %out = polynomial.intt %a_plus_b : !tensor_ty -> !ntt_poly_ty
+ return %out : !ntt_poly_ty
+}
+
+// CHECK-LABEL: test_canonicalize_fold_add_through_intt
+// CHECK: arith.addi
+// CHECK-NOT: polynomial.intt
+// CHECK-NOT: polynomial.ntt
+func.func @test_canonicalize_fold_add_through_intt(
+ %tensor0 : !tensor_ty,
+ %tensor1 : !tensor_ty) -> !tensor_ty {
+ %0 = polynomial.intt %tensor0 : !tensor_ty -> !ntt_poly_ty
+ %1 = polynomial.intt %tensor1 : !tensor_ty -> !ntt_poly_ty
+ %a_plus_b = polynomial.add %0, %1 : !ntt_poly_ty
+ %out = polynomial.ntt %a_plus_b : !ntt_poly_ty -> !tensor_ty
+ return %out : !tensor_ty
+}
+
+// CHECK-LABEL: test_canonicalize_fold_sub_through_ntt
+// CHECK: polynomial.mul_scalar
+// CHECK: polynomial.add
+// CHECK-NOT: polynomial.ntt
+// CHECK-NOT: polynomial.intt
+func.func @test_canonicalize_fold_sub_through_ntt(
+ %poly0 : !ntt_poly_ty,
+ %poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
+ %0 = polynomial.ntt %poly0 : !ntt_poly_ty -> !tensor_ty
+ %1 = polynomial.ntt %poly1 : !ntt_poly_ty -> !tensor_ty
+ %a_plus_b = arith.subi %0, %1 : !tensor_ty
+ %out = polynomial.intt %a_plus_b : !tensor_ty -> !ntt_poly_ty
+ return %out : !ntt_poly_ty
+}
+
+// CHECK-LABEL: test_canonicalize_fold_sub_through_intt
+// CHECK: arith.subi
+// CHECK-NOT: polynomial.intt
+// CHECK-NOT: polynomial.ntt
+func.func @test_canonicalize_fold_sub_through_intt(
+ %tensor0 : !tensor_ty,
+ %tensor1 : !tensor_ty) -> !tensor_ty {
+ %0 = polynomial.intt %tensor0 : !tensor_ty -> !ntt_poly_ty
+ %1 = polynomial.intt %tensor1 : !tensor_ty -> !ntt_poly_ty
+ %a_plus_b = polynomial.sub %0, %1 : !ntt_poly_ty
+ %out = polynomial.ntt %a_plus_b : !ntt_poly_ty -> !tensor_ty
+ return %out : !tensor_ty
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/93132
More information about the Mlir-commits
mailing list