[Mlir-commits] [llvm] [mlir] Poly canonicalization (PR #91410)

Jeremy Kun llvmlistbot at llvm.org
Thu May 16 10:26:58 PDT 2024


https://github.com/j2kun updated https://github.com/llvm/llvm-project/pull/91410

>From 7097c13a5e8e6db32df66af0b37c892b735c7185 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <jkun at google.com>
Date: Tue, 7 May 2024 15:07:38 -0700
Subject: [PATCH 1/4] add basic polynomial canonicalization patterns

---
 .../mlir/Dialect/Polynomial/IR/Polynomial.td  |  3 ++
 mlir/lib/Dialect/Polynomial/IR/CMakeLists.txt |  4 ++
 .../IR/PolynomialCanonicalization.td          | 37 ++++++++++++++
 .../Dialect/Polynomial/IR/PolynomialOps.cpp   | 25 ++++++++++
 .../Dialect/Polynomial/canonicalization.mlir  | 49 +++++++++++++++++++
 .../llvm-project-overlay/mlir/BUILD.bazel     | 18 +++++++
 6 files changed, 136 insertions(+)
 create mode 100644 mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
 create mode 100644 mlir/test/Dialect/Polynomial/canonicalization.mlir

diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index ae8484501a50d..537be4832e8f8 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -245,6 +245,7 @@ def Polynomial_SubOp : Polynomial_BinaryOp<"sub"> {
     %2 = polynomial.sub %0, %1 : !polynomial.polynomial<#ring>
     ```
   }];
+  let hasCanonicalizer = 1;
 }
 
 def Polynomial_MulOp : Polynomial_BinaryOp<"mul", [Commutative]> {
@@ -480,6 +481,7 @@ def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> {
   let arguments = (ins Polynomial_PolynomialType:$input);
   let results = (outs RankedTensorOf<[AnyInteger]>:$output);
   let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
+  let hasCanonicalizer = 1;
   let hasVerifier = 1;
 }
 
@@ -498,6 +500,7 @@ def Polynomial_INTTOp : Polynomial_Op<"intt", [Pure]> {
   let arguments = (ins RankedTensorOf<[AnyInteger]>:$input);
   let results = (outs Polynomial_PolynomialType:$output);
   let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
+  let hasCanonicalizer = 1;
   let hasVerifier = 1;
 }
 
diff --git a/mlir/lib/Dialect/Polynomial/IR/CMakeLists.txt b/mlir/lib/Dialect/Polynomial/IR/CMakeLists.txt
index d6e703b8b3591..6dcdcb257674f 100644
--- a/mlir/lib/Dialect/Polynomial/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Polynomial/IR/CMakeLists.txt
@@ -1,3 +1,7 @@
+set(LLVM_TARGET_DEFINITIONS PolynomialCanonicalization.td)
+mlir_tablegen(PolynomialCanonicalization.inc -gen-rewriters)
+add_public_tablegen_target(MLIRPolynomialCanonicalizationIncGen)
+
 add_mlir_dialect_library(MLIRPolynomialDialect
   Polynomial.cpp
   PolynomialAttributes.cpp
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
new file mode 100644
index 0000000000000..1292ececa2309
--- /dev/null
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
@@ -0,0 +1,37 @@
+//===- PolynomialCanonicalization.td - Polynomial patterns -*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef POLYNOMIAL_CANONICALIZATION
+#define POLYNOMIAL_CANONICALIZATION
+
+include "mlir/Dialect/Polynomial/IR/Polynomial.td"
+include "mlir/Dialect/Arith/IR/ArithOps.td"
+include "mlir/IR/OpBase.td"
+include "mlir/IR/PatternBase.td"
+
+// TODO: get the proper scalar type from the operand polynomial ring attribute
+def SubAsAdd : Pat<
+  (Polynomial_SubOp $f, $g),
+  (Polynomial_AddOp $f,
+    (Polynomial_MulScalarOp $g,
+      (Arith_ConstantOp
+        ConstantAttr<I32Attr, "-1">)))>;
+
+def INTTAfterNTT : Pat<
+  (Polynomial_INTTOp (Polynomial_NTTOp $poly)),
+  (replaceWithValue $poly),
+  []
+>;
+
+def NTTAfterINTT : Pat<
+  (Polynomial_NTTOp (Polynomial_INTTOp $tensor)),
+  (replaceWithValue $tensor),
+  []
+>;
+
+#endif  // POLYNOMIAL_CANONICALIZATION
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 12010de348237..329d5fec9b7c6 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -7,12 +7,14 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Polynomial/IR/PolynomialOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Polynomial/IR/Polynomial.h"
 #include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h"
 #include "mlir/Dialect/Polynomial/IR/PolynomialTypes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/APInt.h"
 
@@ -183,3 +185,26 @@ LogicalResult INTTOp::verify() {
   auto ring = getOutput().getType().getRing();
   return verifyNTTOp(this->getOperation(), ring, tensorType);
 }
+
+//===----------------------------------------------------------------------===//
+// TableGen'd canonicalization patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+#include "PolynomialCanonicalization.inc"
+} // namespace
+
+void SubOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                        MLIRContext *context) {
+  populateWithGenerated(results);
+}
+
+void NTTOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                        MLIRContext *context) {
+  populateWithGenerated(results);
+}
+
+void INTTOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                         MLIRContext *context) {
+  populateWithGenerated(results);
+}
diff --git a/mlir/test/Dialect/Polynomial/canonicalization.mlir b/mlir/test/Dialect/Polynomial/canonicalization.mlir
new file mode 100644
index 0000000000000..54759ff00c966
--- /dev/null
+++ b/mlir/test/Dialect/Polynomial/canonicalization.mlir
@@ -0,0 +1,49 @@
+// RUN: mlir-opt -canonicalize %s | FileCheck %s
+#ntt_poly = #polynomial.int_polynomial<-1 + x**8>
+#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=31>
+!ntt_poly_ty = !polynomial.polynomial<#ntt_ring>
+!tensor_ty = tensor<8xi32, #ntt_ring>
+
+// CHECK-LABEL: @test_canonicalize_intt_after_ntt
+// CHECK: (%[[P:.*]]: [[T:.*]]) -> [[T]]
+func.func @test_canonicalize_intt_after_ntt(%p0 : !ntt_poly_ty) -> !ntt_poly_ty {
+  // CHECK-NOT: polynomial.ntt
+  // CHECK-NOT: polynomial.intt
+  // CHECK: %[[RESULT:.+]] = polynomial.add %[[P]], %[[P]]  : [[T]]
+  %t0 = polynomial.ntt %p0 : !ntt_poly_ty -> !tensor_ty
+  %p1 = polynomial.intt %t0: !tensor_ty -> !ntt_poly_ty
+  %p2 = polynomial.add %p1, %p1 : !ntt_poly_ty
+  // CHECK: return %[[RESULT]] : [[T]]
+  return %p2 : !ntt_poly_ty
+}
+
+// CHECK-LABEL: @test_canonicalize_ntt_after_intt
+// CHECK: (%[[X:.*]]: [[T:.*]]) -> [[T]]
+func.func @test_canonicalize_ntt_after_intt(%t0 : !tensor_ty) -> !tensor_ty {
+  // CHECK-NOT: polynomial.intt
+  // CHECK-NOT: polynomial.ntt
+  // CHECK: %[[RESULT:.+]] = arith.addi %[[X]], %[[X]] : [[T]]
+  %p0 = polynomial.intt %t0 : !tensor_ty -> !ntt_poly_ty
+  %t1 = polynomial.ntt %p0 : !ntt_poly_ty -> !tensor_ty
+  %t2 = arith.addi %t1, %t1 : !tensor_ty
+  // CHECK: return %[[RESULT]] : [[T]]
+  return %t2 : !tensor_ty
+}
+
+#cycl_2048 = #polynomial.int_polynomial<1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256:i32, polynomialModulus=#cycl_2048>
+#one_plus_x_squared = #polynomial.int_polynomial<1 + x**2>
+#one_minus_x_squared = #polynomial.int_polynomial<1 + -1x**2>
+
+// CHECK-LABEL: test_canonicalize_sub_power_of_two_cmod
+func.func @test_canonicalize_sub_power_of_two_cmod() -> !polynomial.polynomial<#ring> {
+  %poly0 = polynomial.constant {value=#one_plus_x_squared} : !polynomial.polynomial<#ring>
+  %poly1 = polynomial.constant {value=#one_minus_x_squared} : !polynomial.polynomial<#ring>
+  %0 = polynomial.sub %poly0, %poly1  : !polynomial.polynomial<#ring>
+  // CHECK: %[[minus_one:.+]] = arith.constant -1 : i32
+  // CHECK: %[[p1:.+]] = polynomial.constant
+  // CHECK: %[[p2:.+]] = polynomial.constant
+  // CHECK: %[[p2neg:.+]] = polynomial.mul_scalar %[[p2]], %[[minus_one]]
+  // CHECK: [[ADD:%.+]] = polynomial.add %[[p1]], %[[p2neg]]
+  return %0 : !polynomial.polynomial<#ring>
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 751cd94d5ff10..8860bec64bf14 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6727,6 +6727,7 @@ cc_library(
         ":IR",
         ":InferTypeOpInterface",
         ":PolynomialAttributesIncGen",
+        ":PolynomialCanonicalizationIncGen",
         ":PolynomialIncGen",
         ":Support",
         "//llvm:Support",
@@ -6817,6 +6818,23 @@ gentbl_cc_library(
     deps = [":PolynomialTdFiles"],
 )
 
+gentbl_cc_library(
+    name = "PolynomialCanonicalizationIncGen",
+    strip_include_prefix = "include/mlir/Dialect/Polynomial/IR",
+    tbl_outs = [
+        (
+            ["-gen-rewriters"],
+            "include/mlir/Dialect/Polynomial/IR/PolynomialCanonicalization.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td",
+    deps = [
+        ":ArithOpsTdFiles",
+        ":PolynomialTdFiles",
+    ],
+)
+
 td_library(
     name = "SPIRVOpsTdFiles",
     srcs = glob(["include/mlir/Dialect/SPIRV/IR/*.td"]),

>From 02e31c5c8a370e66a6dd7e283f8faac39eb1f307 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <jkun at google.com>
Date: Wed, 8 May 2024 10:04:20 -0700
Subject: [PATCH 2/4] use struct for type

---
 mlir/test/Dialect/Polynomial/canonicalization.mlir | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/mlir/test/Dialect/Polynomial/canonicalization.mlir b/mlir/test/Dialect/Polynomial/canonicalization.mlir
index 54759ff00c966..2b4cf6aa8997f 100644
--- a/mlir/test/Dialect/Polynomial/canonicalization.mlir
+++ b/mlir/test/Dialect/Polynomial/canonicalization.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt -canonicalize %s | FileCheck %s
 #ntt_poly = #polynomial.int_polynomial<-1 + x**8>
 #ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=31>
-!ntt_poly_ty = !polynomial.polynomial<#ntt_ring>
+!ntt_poly_ty = !polynomial.polynomial<ring=#ntt_ring>
 !tensor_ty = tensor<8xi32, #ntt_ring>
 
 // CHECK-LABEL: @test_canonicalize_intt_after_ntt
@@ -34,16 +34,17 @@ func.func @test_canonicalize_ntt_after_intt(%t0 : !tensor_ty) -> !tensor_ty {
 #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256:i32, polynomialModulus=#cycl_2048>
 #one_plus_x_squared = #polynomial.int_polynomial<1 + x**2>
 #one_minus_x_squared = #polynomial.int_polynomial<1 + -1x**2>
+!sub_ty = !polynomial.polynomial<ring=#ring>
 
 // CHECK-LABEL: test_canonicalize_sub_power_of_two_cmod
-func.func @test_canonicalize_sub_power_of_two_cmod() -> !polynomial.polynomial<#ring> {
-  %poly0 = polynomial.constant {value=#one_plus_x_squared} : !polynomial.polynomial<#ring>
-  %poly1 = polynomial.constant {value=#one_minus_x_squared} : !polynomial.polynomial<#ring>
-  %0 = polynomial.sub %poly0, %poly1  : !polynomial.polynomial<#ring>
+func.func @test_canonicalize_sub_power_of_two_cmod() -> !sub_ty {
+  %poly0 = polynomial.constant {value=#one_plus_x_squared} : !sub_ty
+  %poly1 = polynomial.constant {value=#one_minus_x_squared} : !sub_ty
+  %0 = polynomial.sub %poly0, %poly1  : !sub_ty
   // CHECK: %[[minus_one:.+]] = arith.constant -1 : i32
   // CHECK: %[[p1:.+]] = polynomial.constant
   // CHECK: %[[p2:.+]] = polynomial.constant
   // CHECK: %[[p2neg:.+]] = polynomial.mul_scalar %[[p2]], %[[minus_one]]
   // CHECK: [[ADD:%.+]] = polynomial.add %[[p1]], %[[p2neg]]
-  return %0 : !polynomial.polynomial<#ring>
+  return %0 : !sub_ty
 }

>From 3cac88caca9e30a82e6980126e05cb1cb5fabcdf Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Mon, 13 May 2024 10:18:03 -0700
Subject: [PATCH 3/4] finish TODO and simplify test

---
 .../Polynomial/IR/PolynomialCanonicalization.td   | 11 ++++++++---
 .../test/Dialect/Polynomial/canonicalization.mlir | 15 +++++----------
 2 files changed, 13 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
index 1292ececa2309..9d09799c1763a 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
@@ -14,13 +14,18 @@ include "mlir/Dialect/Arith/IR/ArithOps.td"
 include "mlir/IR/OpBase.td"
 include "mlir/IR/PatternBase.td"
 
-// TODO: get the proper scalar type from the operand polynomial ring attribute
+// Get a -1 integer attribute of the same type as the polynomial SSA value's
+// ring coefficient type.
+def getMinusOne
+  : NativeCodeCall<
+      "$_builder.getIntegerAttr("
+        "cast<PolynomialType>($0.getType()).getRing().getCoefficientType(), -1)">;
+
 def SubAsAdd : Pat<
   (Polynomial_SubOp $f, $g),
   (Polynomial_AddOp $f,
     (Polynomial_MulScalarOp $g,
-      (Arith_ConstantOp
-        ConstantAttr<I32Attr, "-1">)))>;
+      (Arith_ConstantOp (getMinusOne $g))))>;
 
 def INTTAfterNTT : Pat<
   (Polynomial_INTTOp (Polynomial_NTTOp $poly)),
diff --git a/mlir/test/Dialect/Polynomial/canonicalization.mlir b/mlir/test/Dialect/Polynomial/canonicalization.mlir
index 2b4cf6aa8997f..dbfbf2d93f111 100644
--- a/mlir/test/Dialect/Polynomial/canonicalization.mlir
+++ b/mlir/test/Dialect/Polynomial/canonicalization.mlir
@@ -32,19 +32,14 @@ func.func @test_canonicalize_ntt_after_intt(%t0 : !tensor_ty) -> !tensor_ty {
 
 #cycl_2048 = #polynomial.int_polynomial<1 + x**1024>
 #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256:i32, polynomialModulus=#cycl_2048>
-#one_plus_x_squared = #polynomial.int_polynomial<1 + x**2>
-#one_minus_x_squared = #polynomial.int_polynomial<1 + -1x**2>
 !sub_ty = !polynomial.polynomial<ring=#ring>
 
-// CHECK-LABEL: test_canonicalize_sub_power_of_two_cmod
-func.func @test_canonicalize_sub_power_of_two_cmod() -> !sub_ty {
-  %poly0 = polynomial.constant {value=#one_plus_x_squared} : !sub_ty
-  %poly1 = polynomial.constant {value=#one_minus_x_squared} : !sub_ty
+// CHECK-LABEL: test_canonicalize_sub
+// CHECK-SAME: (%[[p0:.*]]: [[T:.*]], %[[p1:.*]]: [[T]]) -> [[T]] {
+func.func @test_canonicalize_sub(%poly0 : !sub_ty, %poly1 : !sub_ty) -> !sub_ty {
   %0 = polynomial.sub %poly0, %poly1  : !sub_ty
   // CHECK: %[[minus_one:.+]] = arith.constant -1 : i32
-  // CHECK: %[[p1:.+]] = polynomial.constant
-  // CHECK: %[[p2:.+]] = polynomial.constant
-  // CHECK: %[[p2neg:.+]] = polynomial.mul_scalar %[[p2]], %[[minus_one]]
-  // CHECK: [[ADD:%.+]] = polynomial.add %[[p1]], %[[p2neg]]
+  // CHECK: %[[p1neg:.+]] = polynomial.mul_scalar %[[p1]], %[[minus_one]]
+  // CHECK: [[ADD:%.+]] = polynomial.add %[[p0]], %[[p1neg]]
   return %0 : !sub_ty
 }

>From ab63790e45f7bdcd2b2a8b0a4db05bfe39686679 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <jkun at google.com>
Date: Thu, 16 May 2024 10:25:04 -0700
Subject: [PATCH 4/4] populate with specific patterns

---
 mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 329d5fec9b7c6..1a2439fe810b5 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -196,15 +196,15 @@ namespace {
 
 void SubOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
-  populateWithGenerated(results);
+  results.add<SubAsAdd>(context);
 }
 
 void NTTOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
-  populateWithGenerated(results);
+  results.add<NTTAfterINTT>(context);
 }
 
 void INTTOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
-  populateWithGenerated(results);
+  results.add<INTTAfterNTT>(context);
 }



More information about the Mlir-commits mailing list