[Mlir-commits] [mlir] [MLIR][Presburger] Add LLL basis reduction (PR #75565)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Dec 16 11:45:48 PST 2023


https://github.com/Abhinav271828 updated https://github.com/llvm/llvm-project/pull/75565

>From 4ddb05a281bacc8c066f7d1ee76063a8a6fd0d26 Mon Sep 17 00:00:00 2001
From: Abhinav271828 <abhinav.m at research.iiit.ac.in>
Date: Fri, 15 Dec 2023 12:31:09 +0530
Subject: [PATCH 01/14] Add abs method

---
 mlir/include/mlir/Analysis/Presburger/Fraction.h | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/mlir/include/mlir/Analysis/Presburger/Fraction.h b/mlir/include/mlir/Analysis/Presburger/Fraction.h
index afcbed84c66bc3..5c440425e429e1 100644
--- a/mlir/include/mlir/Analysis/Presburger/Fraction.h
+++ b/mlir/include/mlir/Analysis/Presburger/Fraction.h
@@ -101,6 +101,10 @@ inline bool operator>=(const Fraction &x, const Fraction &y) {
   return compare(x, y) >= 0;
 }
 
+inline Fraction abs(const Fraction &f) {
+  return Fraction(abs(f.num), f.den);
+}
+
 inline Fraction reduce(const Fraction &f) {
   if (f == Fraction(0))
     return Fraction(0, 1);

>From 8b18de26e1dfccfa2e031bb4da498a485f71452d Mon Sep 17 00:00:00 2001
From: Abhinav271828 <abhinav.m at research.iiit.ac.in>
Date: Fri, 15 Dec 2023 12:31:18 +0530
Subject: [PATCH 02/14] Add lll

---
 .../include/mlir/Analysis/Presburger/Matrix.h |  5 +++
 mlir/lib/Analysis/Presburger/Matrix.cpp       | 35 +++++++++++++++-
 .../Analysis/Presburger/MatrixTest.cpp        | 40 ++++++++++++++++++-
 3 files changed, 78 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Analysis/Presburger/Matrix.h b/mlir/include/mlir/Analysis/Presburger/Matrix.h
index 89fad85c0c3374..c9d5c5be180c56 100644
--- a/mlir/include/mlir/Analysis/Presburger/Matrix.h
+++ b/mlir/include/mlir/Analysis/Presburger/Matrix.h
@@ -270,6 +270,11 @@ class FracMatrix : public Matrix<Fraction> {
   // of the rows of matrix (cubic time).
   // The rows of the matrix must be linearly independent.
   FracMatrix gramSchmidt() const;
+
+  // Run LLL basis reduction on the matrix, modifying it in-place.
+  // The parameter is delta.
+  void LLL(Fraction delta);
+
 };
 
 } // namespace presburger
diff --git a/mlir/lib/Analysis/Presburger/Matrix.cpp b/mlir/lib/Analysis/Presburger/Matrix.cpp
index 1fcc6d072b44b7..c2fed286d37bad 100644
--- a/mlir/lib/Analysis/Presburger/Matrix.cpp
+++ b/mlir/lib/Analysis/Presburger/Matrix.cpp
@@ -576,4 +576,37 @@ FracMatrix FracMatrix::gramSchmidt() const {
     }
   }
   return orth;
-}
\ No newline at end of file
+}
+
+void FracMatrix::LLL(Fraction delta)
+{
+    MPInt nearest;
+    Fraction mu;
+
+    FracMatrix bStar = gramSchmidt();
+
+    unsigned k = 1;
+    while (k < getNumRows())
+    {
+        for (unsigned j = k-1; j < k; j--)
+        {
+            mu = dotProduct(getRow(k), bStar.getRow(j)) / dotProduct(bStar.getRow(j), bStar.getRow(j));
+            if (abs(mu) > Fraction(1, 2))
+            {
+                nearest = floor(mu + Fraction(1, 2));
+                addToRow(k, getRow(j), -Fraction(nearest, 1));
+                bStar = gramSchmidt();
+            }
+        }
+        mu = dotProduct(getRow(k), bStar.getRow(k-1)) / dotProduct(bStar.getRow(k-1), bStar.getRow(k-1));
+        if (dotProduct(bStar.getRow(k), bStar.getRow(k)) > (delta - mu*mu) * dotProduct(bStar.getRow(k-1), bStar.getRow(k-1)))
+            k += 1;
+        else
+        {
+            swapRows(k, k-1);
+            bStar = gramSchmidt();
+            k = k > 1 ? k-1 : 1;
+        }
+    }
+    return;
+}
diff --git a/mlir/unittests/Analysis/Presburger/MatrixTest.cpp b/mlir/unittests/Analysis/Presburger/MatrixTest.cpp
index 508d4fa369c14c..9c5cfb643119e2 100644
--- a/mlir/unittests/Analysis/Presburger/MatrixTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/MatrixTest.cpp
@@ -377,4 +377,42 @@ TEST(MatrixTest, gramSchmidt) {
   gs = mat.gramSchmidt();
 
   EXPECT_EQ_FRAC_MATRIX(gs, FracMatrix::identity(10));
-}
\ No newline at end of file
+}
+
+TEST(MatrixTest, LLL) {
+    FracMatrix mat = makeFracMatrix(3, 3, {{Fraction(1, 1), Fraction(1, 1), Fraction(1, 1)},
+                                                 {Fraction(-1, 1), Fraction(0, 1), Fraction(2, 1)},
+                                                 {Fraction(3, 1), Fraction(5, 1), Fraction(6, 1)}});
+    mat.LLL(Fraction(3, 4));
+    
+    FracMatrix LLL = makeFracMatrix(3, 3, {{Fraction(0, 1), Fraction(1, 1), Fraction(0, 1)},
+                                                 {Fraction(1, 1), Fraction(0, 1), Fraction(1, 1)},
+                                                 {Fraction(-1, 1), Fraction(0, 1), Fraction(2, 1)}});
+
+    for (unsigned row = 0; row < 3; row++)
+      for (unsigned col = 0; col < 3; col++)
+        EXPECT_EQ(mat(row, col), LLL(row, col));
+
+
+    mat = makeFracMatrix(2, 2, {{Fraction(12, 1), Fraction(2, 1)}, {Fraction(13, 1), Fraction(4, 1)}});
+    LLL = makeFracMatrix(2, 2, {{Fraction(1, 1),  Fraction(2, 1)}, {Fraction(9, 1),  Fraction(-4, 1)}});
+
+    mat.LLL(Fraction(3, 4));
+
+    for (unsigned row = 0; row < 2; row++)
+      for (unsigned col = 0; col < 2; col++)
+        EXPECT_EQ(mat(row, col), LLL(row, col));
+
+    mat = makeFracMatrix(3, 3, {{Fraction(1, 1), Fraction(0, 1), Fraction(2, 1)},
+                                {Fraction(0, 1), Fraction(1, 3), -Fraction(5, 3)},
+                                {Fraction(0, 1), Fraction(0, 1), Fraction(1, 1)}});
+    LLL = makeFracMatrix(3, 3, {{Fraction(0, 1), Fraction(1, 3), Fraction(1, 3)},
+                                {Fraction(0, 1), Fraction(1, 3), -Fraction(2, 3)},
+                                {Fraction(1, 1), Fraction(0, 1), Fraction(0, 1)}});
+
+    mat.LLL(Fraction(3, 4));
+
+    for (unsigned row = 0; row < 3; row++)
+      for (unsigned col = 0; col < 3; col++)
+        EXPECT_EQ(mat(row, col), LLL(row, col));
+}

>From e1760faf1aea972223664b050448dfdfad440372 Mon Sep 17 00:00:00 2001
From: Abhinav271828 <abhinav.m at research.iiit.ac.in>
Date: Fri, 15 Dec 2023 12:32:24 +0530
Subject: [PATCH 03/14] Formatting

---
 .../mlir/Analysis/Presburger/Fraction.h       |  4 +-
 .../include/mlir/Analysis/Presburger/Matrix.h |  1 -
 mlir/lib/Analysis/Presburger/Matrix.cpp       | 59 ++++++++-------
 .../Analysis/Presburger/MatrixTest.cpp        | 71 +++++++++++--------
 4 files changed, 70 insertions(+), 65 deletions(-)

diff --git a/mlir/include/mlir/Analysis/Presburger/Fraction.h b/mlir/include/mlir/Analysis/Presburger/Fraction.h
index 5c440425e429e1..e95056ae5fc961 100644
--- a/mlir/include/mlir/Analysis/Presburger/Fraction.h
+++ b/mlir/include/mlir/Analysis/Presburger/Fraction.h
@@ -101,9 +101,7 @@ inline bool operator>=(const Fraction &x, const Fraction &y) {
   return compare(x, y) >= 0;
 }
 
-inline Fraction abs(const Fraction &f) {
-  return Fraction(abs(f.num), f.den);
-}
+inline Fraction abs(const Fraction &f) { return Fraction(abs(f.num), f.den); }
 
 inline Fraction reduce(const Fraction &f) {
   if (f == Fraction(0))
diff --git a/mlir/include/mlir/Analysis/Presburger/Matrix.h b/mlir/include/mlir/Analysis/Presburger/Matrix.h
index c9d5c5be180c56..fca3164bda6278 100644
--- a/mlir/include/mlir/Analysis/Presburger/Matrix.h
+++ b/mlir/include/mlir/Analysis/Presburger/Matrix.h
@@ -274,7 +274,6 @@ class FracMatrix : public Matrix<Fraction> {
   // Run LLL basis reduction on the matrix, modifying it in-place.
   // The parameter is delta.
   void LLL(Fraction delta);
-
 };
 
 } // namespace presburger
diff --git a/mlir/lib/Analysis/Presburger/Matrix.cpp b/mlir/lib/Analysis/Presburger/Matrix.cpp
index c2fed286d37bad..e07bcc6de8ab5c 100644
--- a/mlir/lib/Analysis/Presburger/Matrix.cpp
+++ b/mlir/lib/Analysis/Presburger/Matrix.cpp
@@ -578,35 +578,34 @@ FracMatrix FracMatrix::gramSchmidt() const {
   return orth;
 }
 
-void FracMatrix::LLL(Fraction delta)
-{
-    MPInt nearest;
-    Fraction mu;
-
-    FracMatrix bStar = gramSchmidt();
-
-    unsigned k = 1;
-    while (k < getNumRows())
-    {
-        for (unsigned j = k-1; j < k; j--)
-        {
-            mu = dotProduct(getRow(k), bStar.getRow(j)) / dotProduct(bStar.getRow(j), bStar.getRow(j));
-            if (abs(mu) > Fraction(1, 2))
-            {
-                nearest = floor(mu + Fraction(1, 2));
-                addToRow(k, getRow(j), -Fraction(nearest, 1));
-                bStar = gramSchmidt();
-            }
-        }
-        mu = dotProduct(getRow(k), bStar.getRow(k-1)) / dotProduct(bStar.getRow(k-1), bStar.getRow(k-1));
-        if (dotProduct(bStar.getRow(k), bStar.getRow(k)) > (delta - mu*mu) * dotProduct(bStar.getRow(k-1), bStar.getRow(k-1)))
-            k += 1;
-        else
-        {
-            swapRows(k, k-1);
-            bStar = gramSchmidt();
-            k = k > 1 ? k-1 : 1;
-        }
+void FracMatrix::LLL(Fraction delta) {
+  MPInt nearest;
+  Fraction mu;
+
+  FracMatrix bStar = gramSchmidt();
+
+  unsigned k = 1;
+  while (k < getNumRows()) {
+    for (unsigned j = k - 1; j < k; j--) {
+      mu = dotProduct(getRow(k), bStar.getRow(j)) /
+           dotProduct(bStar.getRow(j), bStar.getRow(j));
+      if (abs(mu) > Fraction(1, 2)) {
+        nearest = floor(mu + Fraction(1, 2));
+        addToRow(k, getRow(j), -Fraction(nearest, 1));
+        bStar = gramSchmidt();
+      }
     }
-    return;
+    mu = dotProduct(getRow(k), bStar.getRow(k - 1)) /
+         dotProduct(bStar.getRow(k - 1), bStar.getRow(k - 1));
+    if (dotProduct(bStar.getRow(k), bStar.getRow(k)) >
+        (delta - mu * mu) *
+            dotProduct(bStar.getRow(k - 1), bStar.getRow(k - 1)))
+      k += 1;
+    else {
+      swapRows(k, k - 1);
+      bStar = gramSchmidt();
+      k = k > 1 ? k - 1 : 1;
+    }
+  }
+  return;
 }
diff --git a/mlir/unittests/Analysis/Presburger/MatrixTest.cpp b/mlir/unittests/Analysis/Presburger/MatrixTest.cpp
index 9c5cfb643119e2..4d7d0531e5ee84 100644
--- a/mlir/unittests/Analysis/Presburger/MatrixTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/MatrixTest.cpp
@@ -380,39 +380,48 @@ TEST(MatrixTest, gramSchmidt) {
 }
 
 TEST(MatrixTest, LLL) {
-    FracMatrix mat = makeFracMatrix(3, 3, {{Fraction(1, 1), Fraction(1, 1), Fraction(1, 1)},
-                                                 {Fraction(-1, 1), Fraction(0, 1), Fraction(2, 1)},
-                                                 {Fraction(3, 1), Fraction(5, 1), Fraction(6, 1)}});
-    mat.LLL(Fraction(3, 4));
-    
-    FracMatrix LLL = makeFracMatrix(3, 3, {{Fraction(0, 1), Fraction(1, 1), Fraction(0, 1)},
-                                                 {Fraction(1, 1), Fraction(0, 1), Fraction(1, 1)},
-                                                 {Fraction(-1, 1), Fraction(0, 1), Fraction(2, 1)}});
-
-    for (unsigned row = 0; row < 3; row++)
-      for (unsigned col = 0; col < 3; col++)
-        EXPECT_EQ(mat(row, col), LLL(row, col));
-
-
-    mat = makeFracMatrix(2, 2, {{Fraction(12, 1), Fraction(2, 1)}, {Fraction(13, 1), Fraction(4, 1)}});
-    LLL = makeFracMatrix(2, 2, {{Fraction(1, 1),  Fraction(2, 1)}, {Fraction(9, 1),  Fraction(-4, 1)}});
-
-    mat.LLL(Fraction(3, 4));
+  FracMatrix mat =
+      makeFracMatrix(3, 3,
+                     {{Fraction(1, 1), Fraction(1, 1), Fraction(1, 1)},
+                      {Fraction(-1, 1), Fraction(0, 1), Fraction(2, 1)},
+                      {Fraction(3, 1), Fraction(5, 1), Fraction(6, 1)}});
+  mat.LLL(Fraction(3, 4));
+
+  FracMatrix LLL =
+      makeFracMatrix(3, 3,
+                     {{Fraction(0, 1), Fraction(1, 1), Fraction(0, 1)},
+                      {Fraction(1, 1), Fraction(0, 1), Fraction(1, 1)},
+                      {Fraction(-1, 1), Fraction(0, 1), Fraction(2, 1)}});
+
+  for (unsigned row = 0; row < 3; row++)
+    for (unsigned col = 0; col < 3; col++)
+      EXPECT_EQ(mat(row, col), LLL(row, col));
 
-    for (unsigned row = 0; row < 2; row++)
-      for (unsigned col = 0; col < 2; col++)
-        EXPECT_EQ(mat(row, col), LLL(row, col));
+  mat = makeFracMatrix(
+      2, 2,
+      {{Fraction(12, 1), Fraction(2, 1)}, {Fraction(13, 1), Fraction(4, 1)}});
+  LLL = makeFracMatrix(
+      2, 2,
+      {{Fraction(1, 1), Fraction(2, 1)}, {Fraction(9, 1), Fraction(-4, 1)}});
 
-    mat = makeFracMatrix(3, 3, {{Fraction(1, 1), Fraction(0, 1), Fraction(2, 1)},
-                                {Fraction(0, 1), Fraction(1, 3), -Fraction(5, 3)},
-                                {Fraction(0, 1), Fraction(0, 1), Fraction(1, 1)}});
-    LLL = makeFracMatrix(3, 3, {{Fraction(0, 1), Fraction(1, 3), Fraction(1, 3)},
-                                {Fraction(0, 1), Fraction(1, 3), -Fraction(2, 3)},
-                                {Fraction(1, 1), Fraction(0, 1), Fraction(0, 1)}});
+  mat.LLL(Fraction(3, 4));
 
-    mat.LLL(Fraction(3, 4));
+  for (unsigned row = 0; row < 2; row++)
+    for (unsigned col = 0; col < 2; col++)
+      EXPECT_EQ(mat(row, col), LLL(row, col));
 
-    for (unsigned row = 0; row < 3; row++)
-      for (unsigned col = 0; col < 3; col++)
-        EXPECT_EQ(mat(row, col), LLL(row, col));
+  mat = makeFracMatrix(3, 3,
+                       {{Fraction(1, 1), Fraction(0, 1), Fraction(2, 1)},
+                        {Fraction(0, 1), Fraction(1, 3), -Fraction(5, 3)},
+                        {Fraction(0, 1), Fraction(0, 1), Fraction(1, 1)}});
+  LLL = makeFracMatrix(3, 3,
+                       {{Fraction(0, 1), Fraction(1, 3), Fraction(1, 3)},
+                        {Fraction(0, 1), Fraction(1, 3), -Fraction(2, 3)},
+                        {Fraction(1, 1), Fraction(0, 1), Fraction(0, 1)}});
+
+  mat.LLL(Fraction(3, 4));
+
+  for (unsigned row = 0; row < 3; row++)
+    for (unsigned col = 0; col < 3; col++)
+      EXPECT_EQ(mat(row, col), LLL(row, col));
 }

>From 5556d5ea06b05650305ce4a2b272d7c736296384 Mon Sep 17 00:00:00 2001
From: Abhinav271828 <abhinav.m at research.iiit.ac.in>
Date: Fri, 15 Dec 2023 12:35:33 +0530
Subject: [PATCH 04/14] Comment

---
 mlir/lib/Analysis/Presburger/Matrix.cpp | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/mlir/lib/Analysis/Presburger/Matrix.cpp b/mlir/lib/Analysis/Presburger/Matrix.cpp
index e07bcc6de8ab5c..ae04c9f9149a5a 100644
--- a/mlir/lib/Analysis/Presburger/Matrix.cpp
+++ b/mlir/lib/Analysis/Presburger/Matrix.cpp
@@ -582,6 +582,9 @@ void FracMatrix::LLL(Fraction delta) {
   MPInt nearest;
   Fraction mu;
 
+  // `bStar` holds the Gram-Schmidt orthogonalisation
+  // of the matrix at all times. It is recomputed every
+  // time the matrix is modified during the algorithm.
   FracMatrix bStar = gramSchmidt();
 
   unsigned k = 1;

>From d2171a516a00e5343a00e576aacbabd17d87ea29 Mon Sep 17 00:00:00 2001
From: Abhinav271828 <abhinav.m at research.iiit.ac.in>
Date: Sat, 16 Dec 2023 16:04:26 +0530
Subject: [PATCH 05/14] Add assert in abs

---
 mlir/include/mlir/Analysis/Presburger/Fraction.h | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Analysis/Presburger/Fraction.h b/mlir/include/mlir/Analysis/Presburger/Fraction.h
index e95056ae5fc961..4769784bdc2b44 100644
--- a/mlir/include/mlir/Analysis/Presburger/Fraction.h
+++ b/mlir/include/mlir/Analysis/Presburger/Fraction.h
@@ -101,7 +101,10 @@ inline bool operator>=(const Fraction &x, const Fraction &y) {
   return compare(x, y) >= 0;
 }
 
-inline Fraction abs(const Fraction &f) { return Fraction(abs(f.num), f.den); }
+inline Fraction abs(const Fraction &f) {
+  assert(f.den > 0 && "denominator of fraction must be positive!");
+  return Fraction(abs(f.num), f.den);
+}
 
 inline Fraction reduce(const Fraction &f) {
   if (f == Fraction(0))

>From d73ada8fe8297dab203cb64570958e29af7384fc Mon Sep 17 00:00:00 2001
From: Abhinav271828 <abhinav.m at research.iiit.ac.in>
Date: Sat, 16 Dec 2023 16:05:26 +0530
Subject: [PATCH 06/14] Comment for delta

---
 mlir/include/mlir/Analysis/Presburger/Matrix.h | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Analysis/Presburger/Matrix.h b/mlir/include/mlir/Analysis/Presburger/Matrix.h
index fca3164bda6278..347e2e0489786f 100644
--- a/mlir/include/mlir/Analysis/Presburger/Matrix.h
+++ b/mlir/include/mlir/Analysis/Presburger/Matrix.h
@@ -272,7 +272,9 @@ class FracMatrix : public Matrix<Fraction> {
   FracMatrix gramSchmidt() const;
 
   // Run LLL basis reduction on the matrix, modifying it in-place.
-  // The parameter is delta.
+  // `delta` is the constant in the Lovasz condition (called `y` in
+  // https://www.cs.cmu.edu/~avrim/451f11/lectures/lect1129_LLL.pdf);
+  // it should lie in (1/4, 1) and is usually 3/4.
   void LLL(Fraction delta);
 };
 

>From c210e28fe9d85efb3f89bb4730385bda05776d89 Mon Sep 17 00:00:00 2001
From: Abhinav271828 <abhinav.m at research.iiit.ac.in>
Date: Sat, 16 Dec 2023 16:12:00 +0530
Subject: [PATCH 07/14] Use EXPECT_EQ_FRAC_MATRIX for tests

---
 mlir/unittests/Analysis/Presburger/MatrixTest.cpp | 12 +++---------
 1 file changed, 3 insertions(+), 9 deletions(-)

diff --git a/mlir/unittests/Analysis/Presburger/MatrixTest.cpp b/mlir/unittests/Analysis/Presburger/MatrixTest.cpp
index 4d7d0531e5ee84..33bebbb30baac2 100644
--- a/mlir/unittests/Analysis/Presburger/MatrixTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/MatrixTest.cpp
@@ -393,9 +393,7 @@ TEST(MatrixTest, LLL) {
                       {Fraction(1, 1), Fraction(0, 1), Fraction(1, 1)},
                       {Fraction(-1, 1), Fraction(0, 1), Fraction(2, 1)}});
 
-  for (unsigned row = 0; row < 3; row++)
-    for (unsigned col = 0; col < 3; col++)
-      EXPECT_EQ(mat(row, col), LLL(row, col));
+  EXPECT_EQ_FRAC_MATRIX(mat, LLL);
 
   mat = makeFracMatrix(
       2, 2,
@@ -406,9 +404,7 @@ TEST(MatrixTest, LLL) {
 
   mat.LLL(Fraction(3, 4));
 
-  for (unsigned row = 0; row < 2; row++)
-    for (unsigned col = 0; col < 2; col++)
-      EXPECT_EQ(mat(row, col), LLL(row, col));
+  EXPECT_EQ_FRAC_MATRIX(mat, LLL);
 
   mat = makeFracMatrix(3, 3,
                        {{Fraction(1, 1), Fraction(0, 1), Fraction(2, 1)},
@@ -421,7 +417,5 @@ TEST(MatrixTest, LLL) {
 
   mat.LLL(Fraction(3, 4));
 
-  for (unsigned row = 0; row < 3; row++)
-    for (unsigned col = 0; col < 3; col++)
-      EXPECT_EQ(mat(row, col), LLL(row, col));
+  EXPECT_EQ_FRAC_MATRIX(mat, LLL);
 }

>From 7475643af69f3660a2c601eb825c4c79060752d8 Mon Sep 17 00:00:00 2001
From: Abhinav271828 <abhinav.m at research.iiit.ac.in>
Date: Sat, 16 Dec 2023 16:13:38 +0530
Subject: [PATCH 08/14] Rename bStar to gsOrth

---
 mlir/lib/Analysis/Presburger/Matrix.cpp | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Analysis/Presburger/Matrix.cpp b/mlir/lib/Analysis/Presburger/Matrix.cpp
index ae04c9f9149a5a..d720462a1ec0e0 100644
--- a/mlir/lib/Analysis/Presburger/Matrix.cpp
+++ b/mlir/lib/Analysis/Presburger/Matrix.cpp
@@ -582,31 +582,31 @@ void FracMatrix::LLL(Fraction delta) {
   MPInt nearest;
   Fraction mu;
 
-  // `bStar` holds the Gram-Schmidt orthogonalisation
+  // `gsOrth` holds the Gram-Schmidt orthogonalisation
   // of the matrix at all times. It is recomputed every
   // time the matrix is modified during the algorithm.
-  FracMatrix bStar = gramSchmidt();
+  FracMatrix gsOrth = gramSchmidt();
 
   unsigned k = 1;
   while (k < getNumRows()) {
     for (unsigned j = k - 1; j < k; j--) {
-      mu = dotProduct(getRow(k), bStar.getRow(j)) /
-           dotProduct(bStar.getRow(j), bStar.getRow(j));
+      mu = dotProduct(getRow(k), gsOrth.getRow(j)) /
+           dotProduct(gsOrth.getRow(j), gsOrth.getRow(j));
       if (abs(mu) > Fraction(1, 2)) {
         nearest = floor(mu + Fraction(1, 2));
         addToRow(k, getRow(j), -Fraction(nearest, 1));
-        bStar = gramSchmidt();
+        gsOrth = gramSchmidt();
       }
     }
-    mu = dotProduct(getRow(k), bStar.getRow(k - 1)) /
-         dotProduct(bStar.getRow(k - 1), bStar.getRow(k - 1));
-    if (dotProduct(bStar.getRow(k), bStar.getRow(k)) >
+    mu = dotProduct(getRow(k), gsOrth.getRow(k - 1)) /
+         dotProduct(gsOrth.getRow(k - 1), gsOrth.getRow(k - 1));
+    if (dotProduct(gsOrth.getRow(k), gsOrth.getRow(k)) >
         (delta - mu * mu) *
-            dotProduct(bStar.getRow(k - 1), bStar.getRow(k - 1)))
+            dotProduct(gsOrth.getRow(k - 1), gsOrth.getRow(k - 1)))
       k += 1;
     else {
       swapRows(k, k - 1);
-      bStar = gramSchmidt();
+      gsOrth = gramSchmidt();
       k = k > 1 ? k - 1 : 1;
     }
   }

>From aa8453dd73b4a67c8d4afcac70801c9816eb34ee Mon Sep 17 00:00:00 2001
From: Abhinav271828 <abhinav.m at research.iiit.ac.in>
Date: Sat, 16 Dec 2023 18:25:52 +0530
Subject: [PATCH 09/14] Add documentation for LLL

---
 mlir/lib/Analysis/Presburger/Matrix.cpp | 38 +++++++++++++++++++++++--
 1 file changed, 36 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Analysis/Presburger/Matrix.cpp b/mlir/lib/Analysis/Presburger/Matrix.cpp
index d720462a1ec0e0..e3babb862b7596 100644
--- a/mlir/lib/Analysis/Presburger/Matrix.cpp
+++ b/mlir/lib/Analysis/Presburger/Matrix.cpp
@@ -578,6 +578,31 @@ FracMatrix FracMatrix::gramSchmidt() const {
   return orth;
 }
 
+// Convert the matrix, interpreted (row-wise) as a lattice basis,
+// into an LLL-reduced basis of the same lattice.
+//
+// This is an implementation of the algorithm described in
+// "Factoring polynomials with rational coefficients" by
+// A. K. Lenstra, H. W. Lenstra Jr., L. Lovasz.
+//
+// Let {b_1,  ..., b_n}  be the current basis and
+//     {b_1*, ..., b_n*} be the Gram-Schmidt orthogonalised
+//                          basis (unnormalized).
+// Define the Gram-Schmidt coefficients μ_ij as
+// (b_i • b_j*) / (b_j* • b_j*), where (•) represents the inner product.
+//
+// We iterate starting from the second row to the last row.
+//
+// For the kth row, we first check μ_kj for all rows j < k.
+// If it is more than 1/2, we subtract b_j (scaled by μ_kj)
+// from b_k.
+//
+// Now, we update k.
+// If b_k and b_{k-1} satisfy the Lovasz condition, we are done
+// and we increment k.
+// Otherwise, we swap b_k and b_{k-1} and decrement k.
+//
+// We repeat this until k = n and return.
 void FracMatrix::LLL(Fraction delta) {
   MPInt nearest;
   Fraction mu;
@@ -585,28 +610,37 @@ void FracMatrix::LLL(Fraction delta) {
   // `gsOrth` holds the Gram-Schmidt orthogonalisation
   // of the matrix at all times. It is recomputed every
   // time the matrix is modified during the algorithm.
+  // This is naive and can be optimised.
   FracMatrix gsOrth = gramSchmidt();
 
+  // We start from the second row.
   unsigned k = 1;
+
   while (k < getNumRows()) {
     for (unsigned j = k - 1; j < k; j--) {
+      // Compute the Gram-Schmidt coefficient μ_jk.
       mu = dotProduct(getRow(k), gsOrth.getRow(j)) /
            dotProduct(gsOrth.getRow(j), gsOrth.getRow(j));
       if (abs(mu) > Fraction(1, 2)) {
         nearest = floor(mu + Fraction(1, 2));
+        // Subtract b_j scaled by μ_jk from b_k.
         addToRow(k, getRow(j), -Fraction(nearest, 1));
-        gsOrth = gramSchmidt();
+        gsOrth = gramSchmidt(); // recomputation
       }
     }
     mu = dotProduct(getRow(k), gsOrth.getRow(k - 1)) /
          dotProduct(gsOrth.getRow(k - 1), gsOrth.getRow(k - 1));
+    // Check the Lovasz condition for b_k and b_{k-1}.
     if (dotProduct(gsOrth.getRow(k), gsOrth.getRow(k)) >
         (delta - mu * mu) *
             dotProduct(gsOrth.getRow(k - 1), gsOrth.getRow(k - 1)))
+      // If it is satisfied, proceed to the next k.
       k += 1;
     else {
+      // If it is unsatisfied, decrement k (without
+      // going beyond the second row).
       swapRows(k, k - 1);
-      gsOrth = gramSchmidt();
+      gsOrth = gramSchmidt(); // recomputation
       k = k > 1 ? k - 1 : 1;
     }
   }

>From 61d9685012df6a708ff054d7aeaeb9db1832ca14 Mon Sep 17 00:00:00 2001
From: Abhinav271828 <abhinav.m at research.iiit.ac.in>
Date: Sat, 16 Dec 2023 19:53:55 +0530
Subject: [PATCH 10/14] Add round method to Fraction

---
 mlir/include/mlir/Analysis/Presburger/Fraction.h | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/mlir/include/mlir/Analysis/Presburger/Fraction.h b/mlir/include/mlir/Analysis/Presburger/Fraction.h
index 4769784bdc2b44..773b9926f02397 100644
--- a/mlir/include/mlir/Analysis/Presburger/Fraction.h
+++ b/mlir/include/mlir/Analysis/Presburger/Fraction.h
@@ -129,6 +129,9 @@ inline Fraction operator-(const Fraction &x, const Fraction &y) {
   return reduce(Fraction(x.num * y.den - x.den * y.num, x.den * y.den));
 }
 
+// Find the integer nearest to a given fraction.
+inline MPInt round(const Fraction &f) { return floor(f + Fraction(1, 2)); }
+
 inline Fraction &operator+=(Fraction &x, const Fraction &y) {
   x = x + y;
   return x;

>From c0ac5e6088bad450d040370929bd9f4451641eea Mon Sep 17 00:00:00 2001
From: Abhinav271828 <abhinav.m at research.iiit.ac.in>
Date: Sat, 16 Dec 2023 19:54:54 +0530
Subject: [PATCH 11/14] Update LLL documentation

---
 mlir/lib/Analysis/Presburger/Matrix.cpp | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Analysis/Presburger/Matrix.cpp b/mlir/lib/Analysis/Presburger/Matrix.cpp
index e3babb862b7596..1ccb4969345b2d 100644
--- a/mlir/lib/Analysis/Presburger/Matrix.cpp
+++ b/mlir/lib/Analysis/Presburger/Matrix.cpp
@@ -594,12 +594,13 @@ FracMatrix FracMatrix::gramSchmidt() const {
 // We iterate starting from the second row to the last row.
 //
 // For the kth row, we first check μ_kj for all rows j < k.
-// If it is more than 1/2, we subtract b_j (scaled by μ_kj)
+// We subtract b_j (scaled by the integer nearest to μ_kj)
 // from b_k.
 //
 // Now, we update k.
-// If b_k and b_{k-1} satisfy the Lovasz condition, we are done
-// and we increment k.
+// If b_k and b_{k-1} satisfy the Lovasz condition
+//    |b_k*|^2 > (δ - μ_{k,k-1}^2) |b_{k-1}*|^2,
+// we are done and we increment k.
 // Otherwise, we swap b_k and b_{k-1} and decrement k.
 //
 // We repeat this until k = n and return.
@@ -623,9 +624,9 @@ void FracMatrix::LLL(Fraction delta) {
            dotProduct(gsOrth.getRow(j), gsOrth.getRow(j));
       if (abs(mu) > Fraction(1, 2)) {
         nearest = floor(mu + Fraction(1, 2));
-        // Subtract b_j scaled by μ_jk from b_k.
+        // Subtract b_j scaled by the integer nearest to μ_jk from b_k.
         addToRow(k, getRow(j), -Fraction(nearest, 1));
-        gsOrth = gramSchmidt(); // recomputation
+        gsOrth = gramSchmidt(); // Update orthogonalization.
       }
     }
     mu = dotProduct(getRow(k), gsOrth.getRow(k - 1)) /
@@ -637,10 +638,10 @@ void FracMatrix::LLL(Fraction delta) {
       // If it is satisfied, proceed to the next k.
       k += 1;
     else {
-      // If it is unsatisfied, decrement k (without
+      // If it is not satisfied, decrement k (without
       // going beyond the second row).
       swapRows(k, k - 1);
-      gsOrth = gramSchmidt(); // recomputation
+      gsOrth = gramSchmidt(); // Update orthogonalization.
       k = k > 1 ? k - 1 : 1;
     }
   }

>From 3983928432b73a4ab7bd70e9c66a0ce39f258026 Mon Sep 17 00:00:00 2001
From: Abhinav271828 <abhinav.m at research.iiit.ac.in>
Date: Sat, 16 Dec 2023 20:14:58 +0530
Subject: [PATCH 12/14] Check conditions instead of hardcoded answers

---
 .../Analysis/Presburger/MatrixTest.cpp        | 72 ++++++++++++++-----
 1 file changed, 54 insertions(+), 18 deletions(-)

diff --git a/mlir/unittests/Analysis/Presburger/MatrixTest.cpp b/mlir/unittests/Analysis/Presburger/MatrixTest.cpp
index 33bebbb30baac2..29857962f120b7 100644
--- a/mlir/unittests/Analysis/Presburger/MatrixTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/MatrixTest.cpp
@@ -387,35 +387,71 @@ TEST(MatrixTest, LLL) {
                       {Fraction(3, 1), Fraction(5, 1), Fraction(6, 1)}});
   mat.LLL(Fraction(3, 4));
 
-  FracMatrix LLL =
-      makeFracMatrix(3, 3,
-                     {{Fraction(0, 1), Fraction(1, 1), Fraction(0, 1)},
-                      {Fraction(1, 1), Fraction(0, 1), Fraction(1, 1)},
-                      {Fraction(-1, 1), Fraction(0, 1), Fraction(2, 1)}});
+  FracMatrix gsOrth = mat.gramSchmidt();
+
+  // Size-reduced check.
+  for (unsigned i = 0; i < 3; i++)
+    for (unsigned j = 0; j < i; j++) {
+      Fraction mu = dotProduct(mat.getRow(i), gsOrth.getRow(j)) /
+                    dotProduct(gsOrth.getRow(j), gsOrth.getRow(j));
+      ASSERT_TRUE(abs(mu) <= Fraction(1, 2));
+    }
 
-  EXPECT_EQ_FRAC_MATRIX(mat, LLL);
+  // Lovasz condition check (on the Gram-Schmidt vectors b_i*).
+  for (unsigned i = 1; i < 3; i++) {
+    Fraction mu = dotProduct(mat.getRow(i), gsOrth.getRow(i - 1)) /
+                  dotProduct(gsOrth.getRow(i - 1), gsOrth.getRow(i - 1));
+    ASSERT_TRUE(dotProduct(gsOrth.getRow(i), gsOrth.getRow(i)) >
+                (Fraction(3, 4) - mu * mu) *
+                    dotProduct(gsOrth.getRow(i - 1), gsOrth.getRow(i - 1)));
+  }
 
   mat = makeFracMatrix(
       2, 2,
       {{Fraction(12, 1), Fraction(2, 1)}, {Fraction(13, 1), Fraction(4, 1)}});
-  LLL = makeFracMatrix(
-      2, 2,
-      {{Fraction(1, 1), Fraction(2, 1)}, {Fraction(9, 1), Fraction(-4, 1)}});
-
   mat.LLL(Fraction(3, 4));
 
-  EXPECT_EQ_FRAC_MATRIX(mat, LLL);
+  gsOrth = mat.gramSchmidt();
+
+  // Size-reduced check.
+  for (unsigned i = 0; i < 2; i++)
+    for (unsigned j = 0; j < i; j++) {
+      Fraction mu = dotProduct(mat.getRow(i), gsOrth.getRow(j)) /
+                    dotProduct(gsOrth.getRow(j), gsOrth.getRow(j));
+      ASSERT_TRUE(abs(mu) <= Fraction(1, 2));
+    }
+
+  // Lovasz condition check (on the Gram-Schmidt vectors b_i*).
+  for (unsigned i = 1; i < 2; i++) {
+    Fraction mu = dotProduct(mat.getRow(i), gsOrth.getRow(i - 1)) /
+                  dotProduct(gsOrth.getRow(i - 1), gsOrth.getRow(i - 1));
+    ASSERT_TRUE(dotProduct(gsOrth.getRow(i), gsOrth.getRow(i)) >
+                (Fraction(3, 4) - mu * mu) *
+                    dotProduct(gsOrth.getRow(i - 1), gsOrth.getRow(i - 1)));
+  }
 
   mat = makeFracMatrix(3, 3,
                        {{Fraction(1, 1), Fraction(0, 1), Fraction(2, 1)},
                         {Fraction(0, 1), Fraction(1, 3), -Fraction(5, 3)},
                         {Fraction(0, 1), Fraction(0, 1), Fraction(1, 1)}});
-  LLL = makeFracMatrix(3, 3,
-                       {{Fraction(0, 1), Fraction(1, 3), Fraction(1, 3)},
-                        {Fraction(0, 1), Fraction(1, 3), -Fraction(2, 3)},
-                        {Fraction(1, 1), Fraction(0, 1), Fraction(0, 1)}});
-
   mat.LLL(Fraction(3, 4));
 
-  EXPECT_EQ_FRAC_MATRIX(mat, LLL);
-}
+  gsOrth = mat.gramSchmidt();
+
+  // Size-reduced check.
+  for (unsigned i = 0; i < 3; i++)
+    for (unsigned j = 0; j < i; j++) {
+      Fraction mu = dotProduct(mat.getRow(i), gsOrth.getRow(j)) /
+                    dotProduct(gsOrth.getRow(j), gsOrth.getRow(j));
+      ASSERT_TRUE(abs(mu) <= Fraction(1, 2));
+    }
+
+  // Lovasz condition check (on the Gram-Schmidt vectors b_i*).
+  for (unsigned i = 1; i < 3; i++) {
+    Fraction mu = dotProduct(mat.getRow(i), gsOrth.getRow(i - 1)) /
+                  dotProduct(gsOrth.getRow(i - 1), gsOrth.getRow(i - 1));
+    ASSERT_TRUE(dotProduct(gsOrth.getRow(i), gsOrth.getRow(i)) >
+                (Fraction(3, 4) - mu * mu) *
+                    dotProduct(gsOrth.getRow(i - 1), gsOrth.getRow(i - 1)));
+  }
+}
\ No newline at end of file

>From baf2ec0e7f16805c701348690cf00cab99b4b365 Mon Sep 17 00:00:00 2001
From: Abhinav271828 <abhinav.m at research.iiit.ac.in>
Date: Sat, 16 Dec 2023 20:16:13 +0530
Subject: [PATCH 13/14] Remove superfluous check

---
 mlir/lib/Analysis/Presburger/Matrix.cpp | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Analysis/Presburger/Matrix.cpp b/mlir/lib/Analysis/Presburger/Matrix.cpp
index 1ccb4969345b2d..3561007ea4840d 100644
--- a/mlir/lib/Analysis/Presburger/Matrix.cpp
+++ b/mlir/lib/Analysis/Presburger/Matrix.cpp
@@ -622,12 +622,10 @@ void FracMatrix::LLL(Fraction delta) {
       // Compute the Gram-Schmidt coefficient μ_jk.
       mu = dotProduct(getRow(k), gsOrth.getRow(j)) /
            dotProduct(gsOrth.getRow(j), gsOrth.getRow(j));
-      if (abs(mu) > Fraction(1, 2)) {
-        nearest = floor(mu + Fraction(1, 2));
-        // Subtract b_j scaled by the integer nearest to μ_jk from b_k.
-        addToRow(k, getRow(j), -Fraction(nearest, 1));
-        gsOrth = gramSchmidt(); // Update orthogonalization.
-      }
+      nearest = round(mu);
+      // Subtract b_j scaled by the integer nearest to μ_kj from b_k.
+      addToRow(k, getRow(j), -Fraction(nearest, 1));
+      gsOrth = gramSchmidt(); // Update orthogonalization.
     }
     mu = dotProduct(getRow(k), gsOrth.getRow(k - 1)) /
          dotProduct(gsOrth.getRow(k - 1), gsOrth.getRow(k - 1));

>From f2168db7a8e1b675194d3df6e6f44cfd6029201f Mon Sep 17 00:00:00 2001
From: Abhinav271828 <abhinav.m at research.iiit.ac.in>
Date: Sun, 17 Dec 2023 01:15:31 +0530
Subject: [PATCH 14/14] Reformat Fraction::round()

---
 .../mlir/Analysis/Presburger/Fraction.h       |  7 +++++-
 .../Analysis/Presburger/MatrixTest.cpp        | 22 +++++++++++++++++++
 2 files changed, 28 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Analysis/Presburger/Fraction.h b/mlir/include/mlir/Analysis/Presburger/Fraction.h
index 773b9926f02397..23bbcba4f74797 100644
--- a/mlir/include/mlir/Analysis/Presburger/Fraction.h
+++ b/mlir/include/mlir/Analysis/Presburger/Fraction.h
@@ -130,7 +130,12 @@ inline Fraction operator-(const Fraction &x, const Fraction &y) {
 }
 
 // Find the integer nearest to a given fraction.
-inline MPInt round(const Fraction &f) { return floor(f + Fraction(1, 2)); }
+inline MPInt round(const Fraction &f) {
+  MPInt rem = mod(f.num, f.den); // In [0, den) even when f.num < 0.
+  if (rem < Fraction(f.den, 2))
+    return (f.num - rem) / f.den;
+  return (f.num + f.den - rem) / f.den;
+}
 
 inline Fraction &operator+=(Fraction &x, const Fraction &y) {
   x = x + y;
diff --git a/mlir/unittests/Analysis/Presburger/MatrixTest.cpp b/mlir/unittests/Analysis/Presburger/MatrixTest.cpp
index 29857962f120b7..fb6a1a3779cdd1 100644
--- a/mlir/unittests/Analysis/Presburger/MatrixTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/MatrixTest.cpp
@@ -380,6 +380,28 @@ TEST(MatrixTest, gramSchmidt) {
 }
 
 TEST(MatrixTest, LLL) {
+
+  //void checkReducedBasis(FracMatrix mat, Fraction delta) {
+  //FracMatrix gsOrth = mat.gramSchmidt();
+
+  //// Size-reduced check.
+  //for (unsigned i = 0, e = mat.getNumRows(); i < e; i++)
+  //  for (unsigned j = 0; j < i; j++) {
+  //    Fraction mu = dotProduct(mat.getRow(i), gsOrth.getRow(j)) /
+  //                  dotProduct(gsOrth.getRow(j), gsOrth.getRow(j));
+  //    ASSERT_TRUE(abs(mu) <= Fraction(1, 2));
+  //  }
+
+  //// Lovasz condition check.
+  //for (unsigned i = 1, e = mat.getNumRows(); i < e; i++) {
+  //  Fraction mu = dotProduct(mat.getRow(i), gsOrth.getRow(i - 1)) /
+  //                dotProduct(gsOrth.getRow(i - 1), gsOrth.getRow(i - 1));
+  //  ASSERT_TRUE(dotProduct(gsOrth.getRow(i), gsOrth.getRow(i)) >
+  //              (delta - mu * mu) *
+  //                  dotProduct(gsOrth.getRow(i - 1), gsOrth.getRow(i - 1)));
+  //}
+  //}
+
   FracMatrix mat =
       makeFracMatrix(3, 3,
                      {{Fraction(1, 1), Fraction(1, 1), Fraction(1, 1)},



More information about the Mlir-commits mailing list