[Mlir-commits] [llvm] [mlir] Poly canonicalization (PR #91410)
Jeremy Kun
llvmlistbot at llvm.org
Tue May 14 11:28:42 PDT 2024
https://github.com/j2kun updated https://github.com/llvm/llvm-project/pull/91410
>From 7097c13a5e8e6db32df66af0b37c892b735c7185 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <jkun at google.com>
Date: Tue, 7 May 2024 15:07:38 -0700
Subject: [PATCH 1/3] add basic polynomial canonicalization patterns
---
.../mlir/Dialect/Polynomial/IR/Polynomial.td | 3 ++
mlir/lib/Dialect/Polynomial/IR/CMakeLists.txt | 4 ++
.../IR/PolynomialCanonicalization.td | 37 ++++++++++++++
.../Dialect/Polynomial/IR/PolynomialOps.cpp | 25 ++++++++++
.../Dialect/Polynomial/canonicalization.mlir | 49 +++++++++++++++++++
.../llvm-project-overlay/mlir/BUILD.bazel | 18 +++++++
6 files changed, 136 insertions(+)
create mode 100644 mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
create mode 100644 mlir/test/Dialect/Polynomial/canonicalization.mlir
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index ae8484501a50d..537be4832e8f8 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -245,6 +245,7 @@ def Polynomial_SubOp : Polynomial_BinaryOp<"sub"> {
%2 = polynomial.sub %0, %1 : !polynomial.polynomial<#ring>
```
}];
+ let hasCanonicalizer = 1;
}
def Polynomial_MulOp : Polynomial_BinaryOp<"mul", [Commutative]> {
@@ -480,6 +481,7 @@ def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> {
let arguments = (ins Polynomial_PolynomialType:$input);
let results = (outs RankedTensorOf<[AnyInteger]>:$output);
let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
+ let hasCanonicalizer = 1;
let hasVerifier = 1;
}
@@ -498,6 +500,7 @@ def Polynomial_INTTOp : Polynomial_Op<"intt", [Pure]> {
let arguments = (ins RankedTensorOf<[AnyInteger]>:$input);
let results = (outs Polynomial_PolynomialType:$output);
let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
+ let hasCanonicalizer = 1;
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/Polynomial/IR/CMakeLists.txt b/mlir/lib/Dialect/Polynomial/IR/CMakeLists.txt
index d6e703b8b3591..6dcdcb257674f 100644
--- a/mlir/lib/Dialect/Polynomial/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Polynomial/IR/CMakeLists.txt
@@ -1,3 +1,7 @@
+set(LLVM_TARGET_DEFINITIONS PolynomialCanonicalization.td)
+mlir_tablegen(PolynomialCanonicalization.inc -gen-rewriters)
+add_public_tablegen_target(MLIRPolynomialCanonicalizationIncGen)
+
add_mlir_dialect_library(MLIRPolynomialDialect
Polynomial.cpp
PolynomialAttributes.cpp
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
new file mode 100644
index 0000000000000..1292ececa2309
--- /dev/null
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
@@ -0,0 +1,37 @@
+//===- PolynomialCanonicalization.td - Polynomial patterns -*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef POLYNOMIAL_CANONICALIZATION
+#define POLYNOMIAL_CANONICALIZATION
+
+include "mlir/Dialect/Polynomial/IR/Polynomial.td"
+include "mlir/Dialect/Arith/IR/ArithOps.td"
+include "mlir/IR/OpBase.td"
+include "mlir/IR/PatternBase.td"
+
+// TODO: get the proper scalar type from the operand polynomial ring attribute
+def SubAsAdd : Pat<
+ (Polynomial_SubOp $f, $g),
+ (Polynomial_AddOp $f,
+ (Polynomial_MulScalarOp $g,
+ (Arith_ConstantOp
+ ConstantAttr<I32Attr, "-1">)))>;
+
+def INTTAfterNTT : Pat<
+ (Polynomial_INTTOp (Polynomial_NTTOp $poly)),
+ (replaceWithValue $poly),
+ []
+>;
+
+def NTTAfterINTT : Pat<
+ (Polynomial_NTTOp (Polynomial_INTTOp $tensor)),
+ (replaceWithValue $tensor),
+ []
+>;
+
+#endif // POLYNOMIAL_CANONICALIZATION
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 12010de348237..329d5fec9b7c6 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -7,12 +7,14 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Polynomial/IR/PolynomialOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Polynomial/IR/Polynomial.h"
#include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h"
#include "mlir/Dialect/Polynomial/IR/PolynomialTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/APInt.h"
@@ -183,3 +185,26 @@ LogicalResult INTTOp::verify() {
auto ring = getOutput().getType().getRing();
return verifyNTTOp(this->getOperation(), ring, tensorType);
}
+
+//===----------------------------------------------------------------------===//
+// TableGen'd canonicalization patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+#include "PolynomialCanonicalization.inc"
+} // namespace
+
+void SubOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                        MLIRContext *context) {
+  results.add<SubAsAdd>(context); // sub(f, g) -> add(f, mul_scalar(g, -1))
+}
+
+void NTTOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                        MLIRContext *context) {
+  results.add<NTTAfterINTT>(context); // ntt(intt(t)) -> t, rooted at ntt
+}
+
+void INTTOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                         MLIRContext *context) {
+  results.add<INTTAfterNTT>(context); // intt(ntt(p)) -> p, rooted at intt
+}
diff --git a/mlir/test/Dialect/Polynomial/canonicalization.mlir b/mlir/test/Dialect/Polynomial/canonicalization.mlir
new file mode 100644
index 0000000000000..54759ff00c966
--- /dev/null
+++ b/mlir/test/Dialect/Polynomial/canonicalization.mlir
@@ -0,0 +1,49 @@
+// RUN: mlir-opt -canonicalize %s | FileCheck %s
+#ntt_poly = #polynomial.int_polynomial<-1 + x**8>
+#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=31>
+!ntt_poly_ty = !polynomial.polynomial<#ntt_ring>
+!tensor_ty = tensor<8xi32, #ntt_ring>
+
+// CHECK-LABEL: @test_canonicalize_intt_after_ntt
+// CHECK-SAME: (%[[P:.*]]: [[T:.*]]) -> [[T]]
+func.func @test_canonicalize_intt_after_ntt(%p0 : !ntt_poly_ty) -> !ntt_poly_ty {
+  // CHECK-NOT: polynomial.ntt
+  // CHECK-NOT: polynomial.intt
+  // CHECK: %[[RESULT:.+]] = polynomial.add %[[P]], %[[P]] : [[T]]
+  %t0 = polynomial.ntt %p0 : !ntt_poly_ty -> !tensor_ty
+  %p1 = polynomial.intt %t0: !tensor_ty -> !ntt_poly_ty
+  %p2 = polynomial.add %p1, %p1 : !ntt_poly_ty
+  // CHECK: return %[[RESULT]] : [[T]]
+  return %p2 : !ntt_poly_ty
+}
+
+// CHECK-LABEL: @test_canonicalize_ntt_after_intt
+// CHECK-SAME: (%[[X:.*]]: [[T:.*]]) -> [[T]]
+func.func @test_canonicalize_ntt_after_intt(%t0 : !tensor_ty) -> !tensor_ty {
+  // CHECK-NOT: polynomial.intt
+  // CHECK-NOT: polynomial.ntt
+  // CHECK: %[[RESULT:.+]] = arith.addi %[[X]], %[[X]] : [[T]]
+  %p0 = polynomial.intt %t0 : !tensor_ty -> !ntt_poly_ty
+  %t1 = polynomial.ntt %p0 : !ntt_poly_ty -> !tensor_ty
+  %t2 = arith.addi %t1, %t1 : !tensor_ty
+  // CHECK: return %[[RESULT]] : [[T]]
+  return %t2 : !tensor_ty
+}
+
+#cycl_2048 = #polynomial.int_polynomial<1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256:i32, polynomialModulus=#cycl_2048>
+#one_plus_x_squared = #polynomial.int_polynomial<1 + x**2>
+#one_minus_x_squared = #polynomial.int_polynomial<1 + -1x**2>
+
+// CHECK-LABEL: test_canonicalize_sub_power_of_two_cmod
+func.func @test_canonicalize_sub_power_of_two_cmod() -> !polynomial.polynomial<#ring> {
+ %poly0 = polynomial.constant {value=#one_plus_x_squared} : !polynomial.polynomial<#ring>
+ %poly1 = polynomial.constant {value=#one_minus_x_squared} : !polynomial.polynomial<#ring>
+ %0 = polynomial.sub %poly0, %poly1 : !polynomial.polynomial<#ring>
+ // CHECK: %[[minus_one:.+]] = arith.constant -1 : i32
+ // CHECK: %[[p1:.+]] = polynomial.constant
+ // CHECK: %[[p2:.+]] = polynomial.constant
+ // CHECK: %[[p2neg:.+]] = polynomial.mul_scalar %[[p2]], %[[minus_one]]
+ // CHECK: [[ADD:%.+]] = polynomial.add %[[p1]], %[[p2neg]]
+ return %0 : !polynomial.polynomial<#ring>
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 751cd94d5ff10..8860bec64bf14 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6727,6 +6727,7 @@ cc_library(
":IR",
":InferTypeOpInterface",
":PolynomialAttributesIncGen",
+ ":PolynomialCanonicalizationIncGen",
":PolynomialIncGen",
":Support",
"//llvm:Support",
@@ -6817,6 +6818,23 @@ gentbl_cc_library(
deps = [":PolynomialTdFiles"],
)
+gentbl_cc_library(
+ name = "PolynomialCanonicalizationIncGen",
+ strip_include_prefix = "include/mlir/Dialect/Polynomial/IR",
+ tbl_outs = [
+ (
+ ["-gen-rewriters"],
+ "include/mlir/Dialect/Polynomial/IR/PolynomialCanonicalization.inc",
+ ),
+ ],
+ tblgen = ":mlir-tblgen",
+ td_file = "lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td",
+ deps = [
+ ":ArithOpsTdFiles",
+ ":PolynomialTdFiles",
+ ],
+)
+
td_library(
name = "SPIRVOpsTdFiles",
srcs = glob(["include/mlir/Dialect/SPIRV/IR/*.td"]),
>From 02e31c5c8a370e66a6dd7e283f8faac39eb1f307 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <jkun at google.com>
Date: Wed, 8 May 2024 10:04:20 -0700
Subject: [PATCH 2/3] use struct for type
---
mlir/test/Dialect/Polynomial/canonicalization.mlir | 13 +++++++------
1 file changed, 7 insertions(+), 6 deletions(-)
diff --git a/mlir/test/Dialect/Polynomial/canonicalization.mlir b/mlir/test/Dialect/Polynomial/canonicalization.mlir
index 54759ff00c966..2b4cf6aa8997f 100644
--- a/mlir/test/Dialect/Polynomial/canonicalization.mlir
+++ b/mlir/test/Dialect/Polynomial/canonicalization.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt -canonicalize %s | FileCheck %s
#ntt_poly = #polynomial.int_polynomial<-1 + x**8>
#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=31>
-!ntt_poly_ty = !polynomial.polynomial<#ntt_ring>
+!ntt_poly_ty = !polynomial.polynomial<ring=#ntt_ring>
!tensor_ty = tensor<8xi32, #ntt_ring>
// CHECK-LABEL: @test_canonicalize_intt_after_ntt
@@ -34,16 +34,17 @@ func.func @test_canonicalize_ntt_after_intt(%t0 : !tensor_ty) -> !tensor_ty {
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256:i32, polynomialModulus=#cycl_2048>
#one_plus_x_squared = #polynomial.int_polynomial<1 + x**2>
#one_minus_x_squared = #polynomial.int_polynomial<1 + -1x**2>
+!sub_ty = !polynomial.polynomial<ring=#ring>
// CHECK-LABEL: test_canonicalize_sub_power_of_two_cmod
-func.func @test_canonicalize_sub_power_of_two_cmod() -> !polynomial.polynomial<#ring> {
- %poly0 = polynomial.constant {value=#one_plus_x_squared} : !polynomial.polynomial<#ring>
- %poly1 = polynomial.constant {value=#one_minus_x_squared} : !polynomial.polynomial<#ring>
- %0 = polynomial.sub %poly0, %poly1 : !polynomial.polynomial<#ring>
+func.func @test_canonicalize_sub_power_of_two_cmod() -> !sub_ty {
+ %poly0 = polynomial.constant {value=#one_plus_x_squared} : !sub_ty
+ %poly1 = polynomial.constant {value=#one_minus_x_squared} : !sub_ty
+ %0 = polynomial.sub %poly0, %poly1 : !sub_ty
// CHECK: %[[minus_one:.+]] = arith.constant -1 : i32
// CHECK: %[[p1:.+]] = polynomial.constant
// CHECK: %[[p2:.+]] = polynomial.constant
// CHECK: %[[p2neg:.+]] = polynomial.mul_scalar %[[p2]], %[[minus_one]]
// CHECK: [[ADD:%.+]] = polynomial.add %[[p1]], %[[p2neg]]
- return %0 : !polynomial.polynomial<#ring>
+ return %0 : !sub_ty
}
>From 3cac88caca9e30a82e6980126e05cb1cb5fabcdf Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Mon, 13 May 2024 10:18:03 -0700
Subject: [PATCH 3/3] finish TODO and simplify test
---
.../Polynomial/IR/PolynomialCanonicalization.td | 11 ++++++++---
.../test/Dialect/Polynomial/canonicalization.mlir | 15 +++++----------
2 files changed, 13 insertions(+), 13 deletions(-)
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
index 1292ececa2309..9d09799c1763a 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
@@ -14,13 +14,18 @@ include "mlir/Dialect/Arith/IR/ArithOps.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/PatternBase.td"
-// TODO: get the proper scalar type from the operand polynomial ring attribute
+// Get a -1 integer attribute of the same type as the polynomial SSA value's
+// ring coefficient type.
+def getMinusOne
+ : NativeCodeCall<
+ "$_builder.getIntegerAttr("
+ "cast<PolynomialType>($0.getType()).getRing().getCoefficientType(), -1)">;
+
def SubAsAdd : Pat<
(Polynomial_SubOp $f, $g),
(Polynomial_AddOp $f,
(Polynomial_MulScalarOp $g,
- (Arith_ConstantOp
- ConstantAttr<I32Attr, "-1">)))>;
+ (Arith_ConstantOp (getMinusOne $g))))>;
def INTTAfterNTT : Pat<
(Polynomial_INTTOp (Polynomial_NTTOp $poly)),
diff --git a/mlir/test/Dialect/Polynomial/canonicalization.mlir b/mlir/test/Dialect/Polynomial/canonicalization.mlir
index 2b4cf6aa8997f..dbfbf2d93f111 100644
--- a/mlir/test/Dialect/Polynomial/canonicalization.mlir
+++ b/mlir/test/Dialect/Polynomial/canonicalization.mlir
@@ -32,19 +32,14 @@ func.func @test_canonicalize_ntt_after_intt(%t0 : !tensor_ty) -> !tensor_ty {
#cycl_2048 = #polynomial.int_polynomial<1 + x**1024>
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256:i32, polynomialModulus=#cycl_2048>
-#one_plus_x_squared = #polynomial.int_polynomial<1 + x**2>
-#one_minus_x_squared = #polynomial.int_polynomial<1 + -1x**2>
!sub_ty = !polynomial.polynomial<ring=#ring>
-// CHECK-LABEL: test_canonicalize_sub_power_of_two_cmod
-func.func @test_canonicalize_sub_power_of_two_cmod() -> !sub_ty {
- %poly0 = polynomial.constant {value=#one_plus_x_squared} : !sub_ty
- %poly1 = polynomial.constant {value=#one_minus_x_squared} : !sub_ty
+// CHECK-LABEL: test_canonicalize_sub
+// CHECK-SAME: (%[[p0:.*]]: [[T:.*]], %[[p1:.*]]: [[T]]) -> [[T]] {
+func.func @test_canonicalize_sub(%poly0 : !sub_ty, %poly1 : !sub_ty) -> !sub_ty {
%0 = polynomial.sub %poly0, %poly1 : !sub_ty
// CHECK: %[[minus_one:.+]] = arith.constant -1 : i32
- // CHECK: %[[p1:.+]] = polynomial.constant
- // CHECK: %[[p2:.+]] = polynomial.constant
- // CHECK: %[[p2neg:.+]] = polynomial.mul_scalar %[[p2]], %[[minus_one]]
- // CHECK: [[ADD:%.+]] = polynomial.add %[[p1]], %[[p2neg]]
+ // CHECK: %[[p1neg:.+]] = polynomial.mul_scalar %[[p1]], %[[minus_one]]
+ // CHECK: [[ADD:%.+]] = polynomial.add %[[p0]], %[[p1neg]]
return %0 : !sub_ty
}
More information about the Mlir-commits
mailing list