# [Mlir-commits] [mlir] 121aab8 - [MLIR][Affine] Simplify nested modulo operations when able

Krzysztof Drewniak llvmlistbot at llvm.org
Fri Sep 17 12:06:05 PDT 2021

```
Author: Krzysztof Drewniak
Date: 2021-09-17T19:06:00Z
New Revision: 121aab84d16f659cea539becff2cc2fef82ec152

URL: https://github.com/llvm/llvm-project/commit/121aab84d16f659cea539becff2cc2fef82ec152
DIFF: https://github.com/llvm/llvm-project/commit/121aab84d16f659cea539becff2cc2fef82ec152.diff

LOG: [MLIR][Affine] Simplify nested modulo operations when able

It is the case that, for all positive a and b such that b divides a,
(e mod a) mod b = e mod b. For example, ((d0 mod 35) mod 5) can
be simplified to (d0 mod 5), but ((d0 mod 35) mod 4) cannot be simplified
further (d0 = 36 is a counterexample: (36 mod 35) mod 4 = 1, but 36 mod 4 = 0).

This change enables more complex simplifications. For example,
((d0 * 72 + d1) mod 144) mod 9 can now simplify to (d0 * 72 + d1) mod 9
and thus to d1 mod 9. Expressions with chained modulus operators are
reasonably common in tensor applications, and this change _should_
improve code generation for such expressions.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D109930

Modified:
mlir/lib/IR/AffineExpr.cpp
mlir/test/IR/affine-map.mlir
mlir/test/Transforms/loop-fusion-2.mlir
mlir/test/Transforms/loop-fusion.mlir

Removed:

################################################################################
diff  --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index 0364b3b2b96ed..2e13d6ae62c47 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -829,6 +829,15 @@ static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) {
return lBin.getLHS() % rhsConst.getValue();
}

+  // Simplify (e % a) % b to e % b when b evenly divides a
+  if (lBin && lBin.getKind() == AffineExprKind::Mod) {
+    auto intermediate = lBin.getRHS().dyn_cast<AffineConstantExpr>();
+    if (intermediate && intermediate.getValue() >= 1 &&
+        mod(intermediate.getValue(), rhsConst.getValue()) == 0) {
+      return lBin.getLHS() % rhsConst.getValue();
+    }
+  }
+
return nullptr;
}

diff  --git a/mlir/test/IR/affine-map.mlir b/mlir/test/IR/affine-map.mlir
index 3e3e2c3fe6f6f..414741dab38fe 100644
--- a/mlir/test/IR/affine-map.mlir
+++ b/mlir/test/IR/affine-map.mlir
@@ -189,6 +189,9 @@
// CHECK: #map{{[0-9]+}} = affine_map<(d0, d1) -> (d0 * 3, (d0 + d1) * 2, d0 mod 2)>
#map58 = affine_map<(d0, d1) -> (4*d0 - 2*d0 + d0, (d0 + d1) + (d0 + d1), 2 * (d0 mod 2) - d0 mod 2)>

+// CHECK: #map{{[0-9]+}} = affine_map<(d0, d1) -> (d0 mod 5, (d1 mod 35) mod 4)>
+#map59 = affine_map<(d0, d1) -> ((d0 mod 35) mod 5, (d1 mod 35) mod 4)>
+
// Single identity maps are removed.
// CHECK: @f0(memref<2x4xi8, 1>)
func private @f0(memref<2x4xi8, #map0, 1>)
@@ -373,3 +376,6 @@ func private @f56(memref<1x1xi8, #map56>)

// CHECK: "f58"() {map = #map{{[0-9]+}}} : () -> ()
"f58"() {map = #map58} : () -> ()
+
+// CHECK: "f59"() {map = #map{{[0-9]+}}} : () -> ()
+"f59"() {map = #map59} : () -> ()

diff  --git a/mlir/test/Transforms/loop-fusion-2.mlir b/mlir/test/Transforms/loop-fusion-2.mlir
index c214e44296df8..ccd701b3dc986 100644
--- a/mlir/test/Transforms/loop-fusion-2.mlir
+++ b/mlir/test/Transforms/loop-fusion-2.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt -allow-unregistered-dialect %s -affine-loop-fusion -split-input-file | FileCheck %s
// RUN: mlir-opt -allow-unregistered-dialect %s -affine-loop-fusion="fusion-maximal" -split-input-file | FileCheck %s --check-prefix=MAXIMAL

-// Part I of fusion tests in  mlir/test/Transforms/loop-fusion.mlir.
+// Part I of fusion tests in  mlir/test/Transforms/loop-fusion.mlir.
// Part III of fusion tests in mlir/test/Transforms/loop-fusion-3.mlir
// Part IV of fusion tests in mlir/test/Transforms/loop-fusion-4.mlir

@@ -576,9 +576,9 @@ func @fuse_across_varying_dims_complex(%arg0: f32) {
}
// MAXIMAL-DAG: [[\$MAP0:#map[0-9]+]] = affine_map<(d0, d1) -> ((d0 * 72 + d1) floordiv 2304)>
// MAXIMAL-DAG: [[\$MAP1:#map[0-9]+]] = affine_map<(d0, d1) -> (((d0 * 72 + d1) mod 2304) floordiv 1152)>
-// MAXIMAL-DAG: [[\$MAP2:#map[0-9]+]] = affine_map<(d0, d1) -> (((((d0 * 72 + d1) mod 2304) mod 1152) floordiv 9) floordiv 8)>
-// MAXIMAL-DAG: [[\$MAP3:#map[0-9]+]] = affine_map<(d0, d1) -> (((((d0 * 72 + d1) mod 2304) mod 1152) mod 9) floordiv 3)>
-// MAXIMAL-DAG: [[\$MAP4:#map[0-9]+]] = affine_map<(d0, d1) -> (((((d0 * 72 + d1) mod 2304) mod 1152) mod 9) mod 3)>
+// MAXIMAL-DAG: [[\$MAP2:#map[0-9]+]] = affine_map<(d0, d1) -> ((((d0 * 72 + d1) mod 1152) floordiv 9) floordiv 8)>
+// MAXIMAL-DAG: [[\$MAP3:#map[0-9]+]] = affine_map<(d0, d1) -> ((d1 mod 9) floordiv 3)>
+// MAXIMAL-DAG: [[\$MAP4:#map[0-9]+]] = affine_map<(d0, d1) -> (d1 mod 3)>
// MAXIMAL-DAG: [[\$MAP7:#map[0-9]+]] = affine_map<(d0, d1) -> (d0 * 16 + d1)>
// MAXIMAL-DAG: [[\$MAP8:#map[0-9]+]] = affine_map<(d0, d1) -> (d0 * 16 - d1 + 15)>
// MAXIMAL-LABEL: func @fuse_across_varying_dims_complex

diff  --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir
index 3086c682a8c89..5bef80ef07ba6 100644
--- a/mlir/test/Transforms/loop-fusion.mlir
+++ b/mlir/test/Transforms/loop-fusion.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-opt -allow-unregistered-dialect %s -affine-loop-fusion -split-input-file | FileCheck %s

-// Part II of fusion tests in  mlir/test/Transforms/loop-fusion=2.mlir.
+// Part II of fusion tests in  mlir/test/Transforms/loop-fusion=2.mlir.
// Part III of fusion tests in mlir/test/Transforms/loop-fusion-3.mlir
// Part IV of fusion tests in mlir/test/Transforms/loop-fusion-4.mlir

@@ -737,15 +737,15 @@ func @R6_to_R2_reshape_square() -> memref<64x9xi32> {
//
// CHECK-DAG: [[\$MAP0:#map[0-9]+]] = affine_map<(d0, d1) -> ((d0 * 9 + d1) floordiv 288)>
// CHECK-DAG: [[\$MAP1:#map[0-9]+]] = affine_map<(d0, d1) -> (((d0 * 9 + d1) mod 288) floordiv 144)>
-// CHECK-DAG: [[\$MAP2:#map[0-9]+]] = affine_map<(d0, d1) -> ((((d0 * 9 + d1) mod 288) mod 144) floordiv 48)>
-// CHECK-DAG: [[\$MAP3:#map[0-9]+]] = affine_map<(d0, d1) -> (((((d0 * 9 + d1) mod 288) mod 144) mod 48) floordiv 16)>
-// CHECK-DAG: [[\$MAP4:#map[0-9]+]] = affine_map<(d0, d1) -> (((((d0 * 9 + d1) mod 288) mod 144) mod 48) mod 16)>
+// CHECK-DAG: [[\$MAP2:#map[0-9]+]] = affine_map<(d0, d1) -> (((d0 * 9 + d1) mod 144) floordiv 48)>
+// CHECK-DAG: [[\$MAP3:#map[0-9]+]] = affine_map<(d0, d1) -> (((d0 * 9 + d1) mod 48) floordiv 16)>
+// CHECK-DAG: [[\$MAP4:#map[0-9]+]] = affine_map<(d0, d1) -> ((d0 * 9 + d1) mod 16)>
// CHECK-DAG: [[\$MAP11:#map[0-9]+]] = affine_map<(d0, d1) -> (d0 * 9 + d1)>
// CHECK-DAG: [[\$MAP12:#map[0-9]+]] = affine_map<(d0) -> (d0 floordiv 288)>
// CHECK-DAG: [[\$MAP13:#map[0-9]+]] = affine_map<(d0) -> ((d0 mod 288) floordiv 144)>
-// CHECK-DAG: [[\$MAP14:#map[0-9]+]] = affine_map<(d0) -> (((d0 mod 288) mod 144) floordiv 48)>
-// CHECK-DAG: [[\$MAP15:#map[0-9]+]] = affine_map<(d0) -> ((((d0 mod 288) mod 144) mod 48) floordiv 16)>
-// CHECK-DAG: [[\$MAP16:#map[0-9]+]] = affine_map<(d0) -> ((((d0 mod 288) mod 144) mod 48) mod 16)>
+// CHECK-DAG: [[\$MAP14:#map[0-9]+]] = affine_map<(d0) -> ((d0 mod 144) floordiv 48)>
+// CHECK-DAG: [[\$MAP15:#map[0-9]+]] = affine_map<(d0) -> ((d0 mod 48) floordiv 16)>
+// CHECK-DAG: [[\$MAP16:#map[0-9]+]] = affine_map<(d0) -> (d0 mod 16)>
// CHECK-DAG: [[\$MAP17:#map[0-9]+]] = affine_map<(d0) -> (0)>

//
@@ -761,7 +761,7 @@ func @R6_to_R2_reshape_square() -> memref<64x9xi32> {
// CHECK-NEXT:      affine.apply [[\$MAP3]](%{{.*}}, %{{.*}})
// CHECK-NEXT:      affine.apply [[\$MAP4]](%{{.*}}, %{{.*}})
// CHECK-NEXT:      "foo"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (index, index, index, index, index, index) -> i32
-// CHECK-NEXT:      affine.store %{{.*}}, %{{.*}}[0, ((%{{.*}} * 9 + %{{.*}}) mod 288) floordiv 144, (((%{{.*}} * 9 + %{{.*}}) mod 288) mod 144) floordiv 48, ((((%{{.*}} * 9 + %{{.*}}) mod 288) mod 144) mod 48) floordiv 16, ((((%{{.*}} * 9 + %{{.*}}) mod 288) mod 144) mod 48) mod 16, 0] : memref<1x2x3x3x16x1xi32>
+// CHECK-NEXT:      affine.store %{{.*}}, %{{.*}}[0, ((%{{.*}} * 9 + %{{.*}}) mod 288) floordiv 144, ((%{{.*}} * 9 + %{{.*}}) mod 144) floordiv 48, ((%{{.*}} * 9 + %{{.*}}) mod 48) floordiv 16, (%{{.*}} * 9 + %{{.*}}) mod 16, 0] : memref<1x2x3x3x16x1xi32>
// CHECK-NEXT:      affine.apply [[\$MAP11]](%{{.*}}, %{{.*}})
// CHECK-NEXT:      affine.apply [[\$MAP12]](%{{.*}})
// CHECK-NEXT:      affine.apply [[\$MAP13]](%{{.*}})
@@ -769,7 +769,7 @@ func @R6_to_R2_reshape_square() -> memref<64x9xi32> {
// CHECK-NEXT:      affine.apply [[\$MAP15]](%{{.*}})
// CHECK-NEXT:      affine.apply [[\$MAP16]](%{{.*}})
// CHECK-NEXT:      affine.apply [[\$MAP17]](%{{.*}})
-// CHECK-NEXT:      affine.load %{{.*}}[0, ((%{{.*}} * 9 + %{{.*}}) mod 288) floordiv 144, (((%{{.*}} * 9 + %{{.*}}) mod 288) mod 144) floordiv 48, ((((%{{.*}} * 9 + %{{.*}}) mod 288) mod 144) mod 48) floordiv 16, ((((%{{.*}} * 9 + %{{.*}}) mod 288) mod 144) mod 48) mod 16, 0] : memref<1x2x3x3x16x1xi32>
+// CHECK-NEXT:      affine.load %{{.*}}[0, ((%{{.*}} * 9 + %{{.*}}) mod 288) floordiv 144, ((%{{.*}} * 9 + %{{.*}}) mod 144) floordiv 48, ((%{{.*}} * 9 + %{{.*}}) mod 48) floordiv 16, (%{{.*}} * 9 + %{{.*}}) mod 16, 0] : memref<1x2x3x3x16x1xi32>
// CHECK-NEXT:      affine.store %{{.*}}, %{{.*}}[0, 0] : memref<1x1xi32>
// CHECK-NEXT:      affine.load %{{.*}}[0, 0] : memref<1x1xi32>
// CHECK-NEXT:      muli %{{.*}}, %{{.*}} : i32

```