[Mlir-commits] [mlir] 332f0b3 - Affine expr simplification for add of const multiple of same expression

Uday Bondhugula llvmlistbot at llvm.org
Mon Mar 16 19:52:49 PDT 2020


Author: Uday Bondhugula
Date: 2020-03-17T08:22:17+05:30
New Revision: 332f0b3cd4848a6c8aeaa663e0cd536b10aefc98

URL: https://github.com/llvm/llvm-project/commit/332f0b3cd4848a6c8aeaa663e0cd536b10aefc98
DIFF: https://github.com/llvm/llvm-project/commit/332f0b3cd4848a6c8aeaa663e0cd536b10aefc98.diff

LOG: Affine expr simplification for add of const multiple of same expression

- Detect "c_1 * expr + c_2 * expr" and simplify it to "(c_1 + c_2) * expr".
- This subsumes simplifying expressions like "expr - expr" and "expr * -1 + expr" to 0.
- Change the AffineConstantExpr constructor to allow default (null) initialization.

Signed-off-by: Uday Bondhugula <uday at polymagelabs.com>

Differential Revision: https://reviews.llvm.org/D76233

Added: 
    

Modified: 
    mlir/include/mlir/IR/AffineExpr.h
    mlir/lib/IR/AffineExpr.cpp
    mlir/test/Dialect/AffineOps/canonicalize.mlir
    mlir/test/IR/affine-map.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h
index 21318114d4ca..5d3e86bc9ba1 100644
--- a/mlir/include/mlir/IR/AffineExpr.h
+++ b/mlir/include/mlir/IR/AffineExpr.h
@@ -187,7 +187,7 @@ class AffineSymbolExpr : public AffineExpr {
 class AffineConstantExpr : public AffineExpr {
 public:
   using ImplType = detail::AffineConstantExprStorage;
-  /* implicit */ AffineConstantExpr(AffineExpr::ImplType *ptr);
+  /* implicit */ AffineConstantExpr(AffineExpr::ImplType *ptr = nullptr);
   int64_t getValue() const;
 };
 

diff  --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index 921538b4edc3..295b5155c29b 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -314,6 +314,39 @@ static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) {
       return lBin.getLHS() + (lrhs.getValue() + rhsConst.getValue());
   }
 
+  // Detect "c1 * expr + c_2 * expr" as "(c1 + c2) * expr".
+  // c1 is rRhsConst, c2 is rLhsConst; firstExpr, secondExpr are their
+  // respective multiplicands.
+  Optional<int64_t> rLhsConst, rRhsConst;
+  AffineExpr firstExpr, secondExpr;
+  AffineConstantExpr rLhsConstExpr;
+  auto lBinOpExpr = lhs.dyn_cast<AffineBinaryOpExpr>();
+  if (lBinOpExpr && lBinOpExpr.getKind() == AffineExprKind::Mul &&
+      (rLhsConstExpr = lBinOpExpr.getRHS().dyn_cast<AffineConstantExpr>())) {
+    rLhsConst = rLhsConstExpr.getValue();
+    firstExpr = lBinOpExpr.getLHS();
+  } else {
+    rLhsConst = 1;
+    firstExpr = lhs;
+  }
+
+  auto rBinOpExpr = rhs.dyn_cast<AffineBinaryOpExpr>();
+  AffineConstantExpr rRhsConstExpr;
+  if (rBinOpExpr && rBinOpExpr.getKind() == AffineExprKind::Mul &&
+      (rRhsConstExpr = rBinOpExpr.getRHS().dyn_cast<AffineConstantExpr>())) {
+    rRhsConst = rRhsConstExpr.getValue();
+    secondExpr = rBinOpExpr.getLHS();
+  } else {
+    rRhsConst = 1;
+    secondExpr = rhs;
+  }
+
+  if (rLhsConst && rRhsConst && firstExpr == secondExpr)
+    return getAffineBinaryOpExpr(
+        AffineExprKind::Mul, firstExpr,
+        getAffineConstantExpr(rLhsConst.getValue() + rRhsConst.getValue(),
+                              lhs.getContext()));
+
   // When doing successive additions, bring constant to the right: turn (d0 + 2)
   // + d1 into (d0 + d1) + 2.
   if (lBin && lBin.getKind() == AffineExprKind::Add) {
@@ -327,7 +360,6 @@ static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) {
   // general a more compact and readable form.
 
   // Process '(expr floordiv c) * (-c)'.
-  AffineBinaryOpExpr rBinOpExpr = rhs.dyn_cast<AffineBinaryOpExpr>();
   if (!rBinOpExpr)
     return nullptr;
 

diff  --git a/mlir/test/Dialect/AffineOps/canonicalize.mlir b/mlir/test/Dialect/AffineOps/canonicalize.mlir
index 26220925bfb3..ede438eb0bd4 100644
--- a/mlir/test/Dialect/AffineOps/canonicalize.mlir
+++ b/mlir/test/Dialect/AffineOps/canonicalize.mlir
@@ -448,7 +448,7 @@ func @canonicalize_affine_if(%M : index, %N : index) {
 // -----
 
 // CHECK-DAG: [[LBMAP:#map[0-9]+]] = affine_map<()[s0] -> (0, s0)>
-// CHECK-DAG: [[UBMAP:#map[0-9]+]] = affine_map<()[s0] -> (1024, s0 + s0)>
+// CHECK-DAG: [[UBMAP:#map[0-9]+]] = affine_map<()[s0] -> (1024, s0 * 2)>
 
 // CHECK-LABEL: func @canonicalize_bounds
 // CHECK-SAME: [[M:%.*]]: index,

diff  --git a/mlir/test/IR/affine-map.mlir b/mlir/test/IR/affine-map.mlir
index 9ce4d8c0bfc7..d4c80049b742 100644
--- a/mlir/test/IR/affine-map.mlir
+++ b/mlir/test/IR/affine-map.mlir
@@ -33,7 +33,7 @@
 // The following reduction should be unique'd out too but such expression
 // simplification is not performed for IR parsing, but only through analyses
 // and transforms.
-// CHECK: #map{{[0-9]+}} = affine_map<(d0, d1) -> (d1 - d0 + (d0 - d1 + 1) * 2 + d1 - 1, d1 + d1 + d1 + d1 + 2)>
+// CHECK: #map{{[0-9]+}} = affine_map<(d0, d1) -> (d1 - d0 + (d0 - d1 + 1) * 2 + d1 - 1, d1 * 4 + 2)>
 #map3l = affine_map<(i, j) -> ((j - i) + 2*(i - j + 1) + j - 1 + 0, j + j + 1 + j + j + 1)>
 
 // CHECK: #map{{[0-9]+}} = affine_map<(d0, d1) -> (d0 + 2, d1)>
@@ -183,6 +183,12 @@
 // CHECK: #map{{[0-9]+}} = affine_map<(d0, d1) -> (d0, d0 * 2 + d1 * 4 + 2, 1, 2, (d0 * 4) mod 8)>
 #map56 = affine_map<(d0, d1) -> ((4*d0 + 2) floordiv 4, (4*d0 + 8*d1 + 5) floordiv 2, (2*d0 + 4*d1 + 3) mod 2, (3*d0 - 4) mod 3, (4*d0 + 8*d1) mod 8)>
 
+// CHECK: #map{{[0-9]+}} = affine_map<(d0, d1) -> (d1, d0, 0)>
+#map57 = affine_map<(d0, d1) -> (d0 - d0 + d1, -d0 + d0 + d0, (1 + d0 + d1 floordiv 4) - (d0 + d1 floordiv 4 + 1))>
+
+// CHECK: #map{{[0-9]+}} = affine_map<(d0, d1) -> (d0 * 3, (d0 + d1) * 2, d0 mod 2)>
+#map58 = affine_map<(d0, d1) -> (4*d0 - 2*d0 + d0, (d0 + d1) + (d0 + d1), 2 * (d0 mod 2) - d0 mod 2)>
+
 // Single identity maps are removed.
 // CHECK: func @f0(memref<2x4xi8, 1>)
 func @f0(memref<2x4xi8, #map0, 1>)
@@ -361,3 +367,9 @@ func @f54(memref<10xi32, #map54>)
 
 // CHECK: func @f56(memref<1x1xi8, #map{{[0-9]+}}>)
 func @f56(memref<1x1xi8, #map56>)
+
+// CHECK: "f57"() {map = #map{{[0-9]+}}} : () -> ()
+"f57"() {map = #map57} : () -> ()
+
+// CHECK: "f58"() {map = #map{{[0-9]+}}} : () -> ()
+"f58"() {map = #map58} : () -> ()


        


More information about the Mlir-commits mailing list