[Mlir-commits] [mlir] 5dc1ed3 - [mlir] Update dstNode after DenseMap insertion in loop fusion pass.
Sergei Grechanik
llvmlistbot at llvm.org
Thu May 6 15:34:21 PDT 2021
Author: Amy Zhuang
Date: 2021-05-06T15:23:59-07:00
New Revision: 5dc1ed3f627ecfad119ada84d3534cc21b80f810
URL: https://github.com/llvm/llvm-project/commit/5dc1ed3f627ecfad119ada84d3534cc21b80f810
DIFF: https://github.com/llvm/llvm-project/commit/5dc1ed3f627ecfad119ada84d3534cc21b80f810.diff
LOG: [mlir] Update dstNode after DenseMap insertion in loop fusion pass.
Reviewed By: vinayaka-polymage
Differential Revision: https://reviews.llvm.org/D101794
Added:
Modified:
mlir/lib/Transforms/LoopFusion.cpp
mlir/test/Transforms/loop-fusion.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 50adc972ac9d9..aea8d98712a37 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -1645,6 +1645,10 @@ struct GreedyFusion {
// Add edge from 'newMemRef' node to dstNode.
mdg->addEdge(newMemRefNodeId, dstId, newMemRef);
}
+ // One or more entries for 'newMemRef' alloc op are inserted into
+ // the DenseMap mdg->nodes. Since an insertion may cause DenseMap to
+ // reallocate, update dstNode.
+ dstNode = mdg->getNode(dstId);
}
// Collect dst loop stats after memref privatization transformation.
diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir
index f3fcae25b0fa2..2cad8613b1e3f 100644
--- a/mlir/test/Transforms/loop-fusion.mlir
+++ b/mlir/test/Transforms/loop-fusion.mlir
@@ -3115,3 +3115,186 @@ func @no_fusion_cannot_compute_valid_slice() {
// CHECK-NEXT: affine.load
// CHECK-NEXT: mulf
// CHECK-NEXT: affine.store
+
+// -----
+// Regression test for refreshing dstNode in GreedyFusion: fusing this long producer/consumer chain presumably inserts new memref nodes into the DenseMap mdg->nodes (which may reallocate), so the pass must not crash; the expected result is a single 2-d loop nest.
+// CHECK-LABEL: func @fuse_large_number_of_loops
+func @fuse_large_number_of_loops(%arg0: memref<20x10xf32, 1>, %arg1: memref<20x10xf32, 1>, %arg2: memref<20x10xf32, 1>, %arg3: memref<20x10xf32, 1>, %arg4: memref<20x10xf32, 1>, %arg5: memref<f32, 1>, %arg6: memref<f32, 1>, %arg7: memref<f32, 1>, %arg8: memref<f32, 1>, %arg9: memref<20x10xf32, 1>, %arg10: memref<20x10xf32, 1>, %arg11: memref<20x10xf32, 1>, %arg12: memref<20x10xf32, 1>) {
+ %cst = constant 1.000000e+00 : f32
+ %0 = memref.alloc() : memref<f32, 1>
+ affine.store %cst, %0[] : memref<f32, 1>
+ %1 = memref.alloc() : memref<20x10xf32, 1>
+ affine.for %arg13 = 0 to 20 {
+ affine.for %arg14 = 0 to 10 {
+ %21 = affine.load %arg6[] : memref<f32, 1>
+ affine.store %21, %1[%arg13, %arg14] : memref<20x10xf32, 1>
+ }
+ }
+ %2 = memref.alloc() : memref<20x10xf32, 1>
+ affine.for %arg13 = 0 to 20 {
+ affine.for %arg14 = 0 to 10 {
+ %21 = affine.load %1[%arg13, %arg14] : memref<20x10xf32, 1>
+ %22 = affine.load %arg3[%arg13, %arg14] : memref<20x10xf32, 1>
+ %23 = mulf %22, %21 : f32
+ affine.store %23, %2[%arg13, %arg14] : memref<20x10xf32, 1>
+ }
+ }
+ %3 = memref.alloc() : memref<f32, 1>
+ %4 = affine.load %arg6[] : memref<f32, 1>
+ %5 = affine.load %0[] : memref<f32, 1>
+ %6 = subf %5, %4 : f32
+ affine.store %6, %3[] : memref<f32, 1>
+ %7 = memref.alloc() : memref<20x10xf32, 1>
+ affine.for %arg13 = 0 to 20 {
+ affine.for %arg14 = 0 to 10 {
+ %21 = affine.load %3[] : memref<f32, 1>
+ affine.store %21, %7[%arg13, %arg14] : memref<20x10xf32, 1>
+ }
+ }
+ %8 = memref.alloc() : memref<20x10xf32, 1>
+ affine.for %arg13 = 0 to 20 {
+ affine.for %arg14 = 0 to 10 {
+ %21 = affine.load %arg1[%arg13, %arg14] : memref<20x10xf32, 1>
+ %22 = affine.load %7[%arg13, %arg14] : memref<20x10xf32, 1>
+ %23 = mulf %22, %21 : f32
+ affine.store %23, %8[%arg13, %arg14] : memref<20x10xf32, 1>
+ }
+ }
+ %9 = memref.alloc() : memref<20x10xf32, 1>
+ affine.for %arg13 = 0 to 20 {
+ affine.for %arg14 = 0 to 10 {
+ %21 = affine.load %arg1[%arg13, %arg14] : memref<20x10xf32, 1>
+ %22 = affine.load %8[%arg13, %arg14] : memref<20x10xf32, 1>
+ %23 = mulf %22, %21 : f32
+ affine.store %23, %9[%arg13, %arg14] : memref<20x10xf32, 1>
+ }
+ }
+ affine.for %arg13 = 0 to 20 {
+ affine.for %arg14 = 0 to 10 {
+ %21 = affine.load %9[%arg13, %arg14] : memref<20x10xf32, 1>
+ %22 = affine.load %2[%arg13, %arg14] : memref<20x10xf32, 1>
+ %23 = addf %22, %21 : f32
+ affine.store %23, %arg11[%arg13, %arg14] : memref<20x10xf32, 1>
+ }
+ }
+ %10 = memref.alloc() : memref<20x10xf32, 1>
+ affine.for %arg13 = 0 to 20 {
+ affine.for %arg14 = 0 to 10 {
+ %21 = affine.load %1[%arg13, %arg14] : memref<20x10xf32, 1>
+ %22 = affine.load %arg2[%arg13, %arg14] : memref<20x10xf32, 1>
+ %23 = mulf %22, %21 : f32
+ affine.store %23, %10[%arg13, %arg14] : memref<20x10xf32, 1>
+ }
+ }
+ affine.for %arg13 = 0 to 20 {
+ affine.for %arg14 = 0 to 10 {
+ %21 = affine.load %8[%arg13, %arg14] : memref<20x10xf32, 1>
+ %22 = affine.load %10[%arg13, %arg14] : memref<20x10xf32, 1>
+ %23 = addf %22, %21 : f32
+ affine.store %23, %arg10[%arg13, %arg14] : memref<20x10xf32, 1>
+ }
+ }
+ %11 = memref.alloc() : memref<20x10xf32, 1>
+ affine.for %arg13 = 0 to 20 {
+ affine.for %arg14 = 0 to 10 {
+ %21 = affine.load %arg10[%arg13, %arg14] : memref<20x10xf32, 1>
+ %22 = affine.load %arg10[%arg13, %arg14] : memref<20x10xf32, 1>
+ %23 = mulf %22, %21 : f32
+ affine.store %23, %11[%arg13, %arg14] : memref<20x10xf32, 1>
+ }
+ }
+ %12 = memref.alloc() : memref<20x10xf32, 1>
+ affine.for %arg13 = 0 to 20 {
+ affine.for %arg14 = 0 to 10 {
+ %21 = affine.load %11[%arg13, %arg14] : memref<20x10xf32, 1>
+ %22 = affine.load %arg11[%arg13, %arg14] : memref<20x10xf32, 1>
+ %23 = subf %22, %21 : f32
+ affine.store %23, %12[%arg13, %arg14] : memref<20x10xf32, 1>
+ }
+ }
+ %13 = memref.alloc() : memref<20x10xf32, 1>
+ affine.for %arg13 = 0 to 20 {
+ affine.for %arg14 = 0 to 10 {
+ %21 = affine.load %arg7[] : memref<f32, 1>
+ affine.store %21, %13[%arg13, %arg14] : memref<20x10xf32, 1>
+ }
+ }
+ %14 = memref.alloc() : memref<20x10xf32, 1>
+ affine.for %arg13 = 0 to 20 {
+ affine.for %arg14 = 0 to 10 {
+ %21 = affine.load %arg4[%arg13, %arg14] : memref<20x10xf32, 1>
+ %22 = affine.load %13[%arg13, %arg14] : memref<20x10xf32, 1>
+ %23 = mulf %22, %21 : f32
+ affine.store %23, %14[%arg13, %arg14] : memref<20x10xf32, 1>
+ }
+ }
+ %15 = memref.alloc() : memref<20x10xf32, 1>
+ affine.for %arg13 = 0 to 20 {
+ affine.for %arg14 = 0 to 10 {
+ %21 = affine.load %arg8[] : memref<f32, 1>
+ affine.store %21, %15[%arg13, %arg14] : memref<20x10xf32, 1>
+ }
+ }
+ %16 = memref.alloc() : memref<20x10xf32, 1>
+ affine.for %arg13 = 0 to 20 {
+ affine.for %arg14 = 0 to 10 {
+ %21 = affine.load %15[%arg13, %arg14] : memref<20x10xf32, 1>
+ %22 = affine.load %12[%arg13, %arg14] : memref<20x10xf32, 1>
+ %23 = addf %22, %21 : f32
+ affine.store %23, %16[%arg13, %arg14] : memref<20x10xf32, 1>
+ }
+ }
+ %17 = memref.alloc() : memref<20x10xf32, 1>
+ affine.for %arg13 = 0 to 20 {
+ affine.for %arg14 = 0 to 10 {
+ %21 = affine.load %16[%arg13, %arg14] : memref<20x10xf32, 1>
+ %22 = math.sqrt %21 : f32
+ affine.store %22, %17[%arg13, %arg14] : memref<20x10xf32, 1>
+ }
+ }
+ %18 = memref.alloc() : memref<20x10xf32, 1>
+ affine.for %arg13 = 0 to 20 {
+ affine.for %arg14 = 0 to 10 {
+ %21 = affine.load %arg5[] : memref<f32, 1>
+ affine.store %21, %18[%arg13, %arg14] : memref<20x10xf32, 1>
+ }
+ }
+ %19 = memref.alloc() : memref<20x10xf32, 1>
+ affine.for %arg13 = 0 to 20 {
+ affine.for %arg14 = 0 to 10 {
+ %21 = affine.load %arg1[%arg13, %arg14] : memref<20x10xf32, 1>
+ %22 = affine.load %18[%arg13, %arg14] : memref<20x10xf32, 1>
+ %23 = mulf %22, %21 : f32
+ affine.store %23, %19[%arg13, %arg14] : memref<20x10xf32, 1>
+ }
+ }
+ %20 = memref.alloc() : memref<20x10xf32, 1>
+ affine.for %arg13 = 0 to 20 {
+ affine.for %arg14 = 0 to 10 {
+ %21 = affine.load %17[%arg13, %arg14] : memref<20x10xf32, 1>
+ %22 = affine.load %19[%arg13, %arg14] : memref<20x10xf32, 1>
+ %23 = divf %22, %21 : f32
+ affine.store %23, %20[%arg13, %arg14] : memref<20x10xf32, 1>
+ }
+ }
+ affine.for %arg13 = 0 to 20 {
+ affine.for %arg14 = 0 to 10 {
+ %21 = affine.load %20[%arg13, %arg14] : memref<20x10xf32, 1>
+ %22 = affine.load %14[%arg13, %arg14] : memref<20x10xf32, 1>
+ %23 = addf %22, %21 : f32
+ affine.store %23, %arg12[%arg13, %arg14] : memref<20x10xf32, 1>
+ }
+ }
+ affine.for %arg13 = 0 to 20 {
+ affine.for %arg14 = 0 to 10 {
+ %21 = affine.load %arg12[%arg13, %arg14] : memref<20x10xf32, 1>
+ %22 = affine.load %arg0[%arg13, %arg14] : memref<20x10xf32, 1>
+ %23 = subf %22, %21 : f32
+ affine.store %23, %arg9[%arg13, %arg14] : memref<20x10xf32, 1>
+ }
+ }
+ return
+}
+// CHECK: affine.for
+// CHECK: affine.for
+// CHECK-NOT: affine.for
More information about the Mlir-commits mailing list