[Mlir-commits] [mlir] 5851671 - [MLIR] Constant fold multiplies in deriveStaticUpperBound.

Tres Popp llvmlistbot at llvm.org
Tue Apr 14 04:04:39 PDT 2020


Author: Tres Popp
Date: 2020-04-14T13:04:16+02:00
New Revision: 58516718fc66de01f1221c7d30b06c297296a264

URL: https://github.com/llvm/llvm-project/commit/58516718fc66de01f1221c7d30b06c297296a264
DIFF: https://github.com/llvm/llvm-project/commit/58516718fc66de01f1221c7d30b06c297296a264.diff

LOG: [MLIR] Constant fold multiplies in deriveStaticUpperBound.

Summary:
Multiply operations occur during collapseParallelLoops, so we also constant
fold them. This allows a loop-invariant upper bound to be determined in more
situations when lowering from the Loop dialect to the GPU dialect.

Differential Revision: https://reviews.llvm.org/D77723

Added: 
    

Modified: 
    mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp
    mlir/test/Conversion/LoopsToGPU/parallel_loop.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp
index b08cfb696953..04d5381009cc 100644
--- a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp
+++ b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp
@@ -506,16 +506,36 @@ struct ParallelToGpuLaunchLowering : public OpRewritePattern<ParallelOp> {
 /// `upperBound`.
 static Value deriveStaticUpperBound(Value upperBound,
                                     PatternRewriter &rewriter) {
-  if (AffineMinOp minOp =
-          dyn_cast_or_null<AffineMinOp>(upperBound.getDefiningOp())) {
+  if (auto op = dyn_cast_or_null<ConstantIndexOp>(upperBound.getDefiningOp())) {
+    return op;
+  }
+
+  if (auto minOp = dyn_cast_or_null<AffineMinOp>(upperBound.getDefiningOp())) {
     for (const AffineExpr &result : minOp.map().getResults()) {
-      if (AffineConstantExpr constExpr =
-              result.dyn_cast<AffineConstantExpr>()) {
+      if (auto constExpr = result.dyn_cast<AffineConstantExpr>()) {
         return rewriter.create<ConstantIndexOp>(minOp.getLoc(),
                                                 constExpr.getValue());
       }
     }
   }
+
+  if (auto multiplyOp = dyn_cast_or_null<MulIOp>(upperBound.getDefiningOp())) {
+    if (auto lhs = dyn_cast_or_null<ConstantIndexOp>(
+            deriveStaticUpperBound(multiplyOp.getOperand(0), rewriter)
+                .getDefiningOp()))
+      if (auto rhs = dyn_cast_or_null<ConstantIndexOp>(
+              deriveStaticUpperBound(multiplyOp.getOperand(1), rewriter)
+                  .getDefiningOp())) {
+        // Assumptions about the upper bound of minimum computations no longer
+        // work if multiplied by a negative value, so abort in this case.
+        if (lhs.getValue() < 0 || rhs.getValue() < 0)
+          return {};
+
+        return rewriter.create<ConstantIndexOp>(
+            multiplyOp.getLoc(), lhs.getValue() * rhs.getValue());
+      }
+  }
+
   return {};
 }
 

diff  --git a/mlir/test/Conversion/LoopsToGPU/parallel_loop.mlir b/mlir/test/Conversion/LoopsToGPU/parallel_loop.mlir
index ab195936d83a..ca5fb36091d7 100644
--- a/mlir/test/Conversion/LoopsToGPU/parallel_loop.mlir
+++ b/mlir/test/Conversion/LoopsToGPU/parallel_loop.mlir
@@ -213,9 +213,10 @@ module {
     loop.parallel (%arg3, %arg4) = (%c0, %c0) to (%0, %1) step (%c2, %c3) {
       %2 = dim %arg0, 0 : memref<?x?xf32, #map0>
       %3 = affine.min #map1(%arg3)[%2]
+      %squared_min = muli %3, %3 : index
       %4 = dim %arg0, 1 : memref<?x?xf32, #map0>
       %5 = affine.min #map2(%arg4)[%4]
-      %6 = std.subview %arg0[%arg3, %arg4][%3, %5][%c1, %c1] : memref<?x?xf32, #map0> to memref<?x?xf32, #map3>
+      %6 = std.subview %arg0[%arg3, %arg4][%squared_min, %5][%c1, %c1] : memref<?x?xf32, #map0> to memref<?x?xf32, #map3>
       %7 = dim %arg1, 0 : memref<?x?xf32, #map0>
       %8 = affine.min #map1(%arg3)[%7]
       %9 = dim %arg1, 1 : memref<?x?xf32, #map0>
@@ -226,7 +227,7 @@ module {
       %14 = dim %arg2, 1 : memref<?x?xf32, #map0>
       %15 = affine.min #map2(%arg4)[%14]
       %16 = std.subview %arg2[%arg3, %arg4][%13, %15][%c1, %c1] : memref<?x?xf32, #map0> to memref<?x?xf32, #map3>
-      loop.parallel (%arg5, %arg6) = (%c0, %c0) to (%3, %5) step (%c1, %c1) {
+      loop.parallel (%arg5, %arg6) = (%c0, %c0) to (%squared_min, %5) step (%c1, %c1) {
         %17 = load %6[%arg5, %arg6] : memref<?x?xf32, #map3>
         %18 = load %11[%arg5, %arg6] : memref<?x?xf32, #map3>
         %19 = load %16[%arg5, %arg6] : memref<?x?xf32, #map3>
@@ -259,7 +260,7 @@ module {
 // CHECK:           [[VAL_9:%.*]] = constant 1 : index
 // CHECK:           [[VAL_10:%.*]] = affine.apply #[[MAP1]](){{\[}}[[VAL_7]], [[VAL_4]], [[VAL_6]]]
 // CHECK:           [[VAL_11:%.*]] = affine.apply #[[MAP1]](){{\[}}[[VAL_8]], [[VAL_4]], [[VAL_5]]]
-// CHECK:           [[VAL_12:%.*]] = constant 2 : index
+// CHECK:           [[VAL_12:%.*]] = constant 4 : index
 // CHECK:           [[VAL_13:%.*]] = affine.apply #[[MAP1]](){{\[}}[[VAL_12]], [[VAL_4]], [[VAL_3]]]
 // CHECK:           [[VAL_14:%.*]] = constant 3 : index
 // CHECK:           [[VAL_15:%.*]] = affine.apply #[[MAP1]](){{\[}}[[VAL_14]], [[VAL_4]], [[VAL_3]]]
@@ -268,9 +269,10 @@ module {
 // CHECK:             [[VAL_29:%.*]] = affine.apply #[[MAP2]]([[VAL_17]]){{\[}}[[VAL_5]], [[VAL_4]]]
 // CHECK:             [[VAL_30:%.*]] = dim [[VAL_0]], 0 : memref<?x?xf32, #[[MAP0]]>
 // CHECK:             [[VAL_31:%.*]] = affine.min #[[MAP3]]([[VAL_28]]){{\[}}[[VAL_30]]]
+// CHECK:             [[VAL_31_SQUARED:%.*]] = muli [[VAL_31]], [[VAL_31]] : index
 // CHECK:             [[VAL_32:%.*]] = dim [[VAL_0]], 1 : memref<?x?xf32, #[[MAP0]]>
 // CHECK:             [[VAL_33:%.*]] = affine.min #[[MAP4]]([[VAL_29]]){{\[}}[[VAL_32]]]
-// CHECK:             [[VAL_34:%.*]] = subview [[VAL_0]]{{\[}}[[VAL_28]], [[VAL_29]]] {{\[}}[[VAL_31]], [[VAL_33]]] {{\[}}[[VAL_3]], [[VAL_3]]] : memref<?x?xf32, #[[MAP0]]> to memref<?x?xf32, #[[MAP5]]>
+// CHECK:             [[VAL_34:%.*]] = subview [[VAL_0]]{{\[}}[[VAL_28]], [[VAL_29]]] {{\[}}[[VAL_31_SQUARED]], [[VAL_33]]] {{\[}}[[VAL_3]], [[VAL_3]]] : memref<?x?xf32, #[[MAP0]]> to memref<?x?xf32, #[[MAP5]]>
 // CHECK:             [[VAL_35:%.*]] = dim [[VAL_1]], 0 : memref<?x?xf32, #[[MAP0]]>
 // CHECK:             [[VAL_36:%.*]] = affine.min #[[MAP3]]([[VAL_28]]){{\[}}[[VAL_35]]]
 // CHECK:             [[VAL_37:%.*]] = dim [[VAL_1]], 1 : memref<?x?xf32, #[[MAP0]]>
@@ -282,7 +284,7 @@ module {
 // CHECK:             [[VAL_43:%.*]] = affine.min #[[MAP4]]([[VAL_29]]){{\[}}[[VAL_42]]]
 // CHECK:             [[VAL_44:%.*]] = subview [[VAL_2]]{{\[}}[[VAL_28]], [[VAL_29]]] {{\[}}[[VAL_41]], [[VAL_43]]] {{\[}}[[VAL_3]], [[VAL_3]]] : memref<?x?xf32, #[[MAP0]]> to memref<?x?xf32, #[[MAP5]]>
 // CHECK:             [[VAL_45:%.*]] = affine.apply #[[MAP2]]([[VAL_22]]){{\[}}[[VAL_3]], [[VAL_4]]]
-// CHECK:             [[VAL_46:%.*]] = cmpi "slt", [[VAL_45]], [[VAL_31]] : index
+// CHECK:             [[VAL_46:%.*]] = cmpi "slt", [[VAL_45]], [[VAL_31_SQUARED]] : index
 // CHECK:             loop.if [[VAL_46]] {
 // CHECK:               [[VAL_47:%.*]] = affine.apply #[[MAP2]]([[VAL_23]]){{\[}}[[VAL_3]], [[VAL_4]]]
 // CHECK:               [[VAL_48:%.*]] = cmpi "slt", [[VAL_47]], [[VAL_33]] : index


        


More information about the Mlir-commits mailing list