[Mlir-commits] [mlir] [mlir] fix affine-loop-fusion crash (PR #76351)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Dec 28 18:43:15 PST 2023


https://github.com/lipracer updated https://github.com/llvm/llvm-project/pull/76351

>From 05dd64917d730fee21d1426c6927c38c784c0487 Mon Sep 17 00:00:00 2001
From: lipracer <lipracer at gmail.com>
Date: Mon, 25 Dec 2023 11:47:07 +0800
Subject: [PATCH 1/2] [mlir] fix a crash in affine-loop-fusion (#76281)

fixes #76281
---
 .../Dialect/Affine/Transforms/LoopFusion.cpp  |  5 ++-
 mlir/test/Dialect/Affine/loop-fusion.mlir     | 32 +++++++++++++++++++
 2 files changed, 36 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index 66d921b4889f59..77f283725dc597 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -205,7 +205,10 @@ static bool isEscapingMemref(Value memref, Block *block) {
   // (e.g., call ops, alias creating ops, etc.).
   return llvm::any_of(memref.getUsers(), [&](Operation *user) {
     // Ignore users outside of `block`.
-    if (block->getParent()->findAncestorOpInRegion(*user)->getBlock() != block)
+    auto ancestorOp = block->getParent()->findAncestorOpInRegion(*user);
+    if (!ancestorOp)
+      return true;
+    if (ancestorOp->getBlock() != block)
       return false;
     return !isa<AffineMapAccessInterface>(*user);
   });
diff --git a/mlir/test/Dialect/Affine/loop-fusion.mlir b/mlir/test/Dialect/Affine/loop-fusion.mlir
index 8c536e631a86c9..045b1bec272e1e 100644
--- a/mlir/test/Dialect/Affine/loop-fusion.mlir
+++ b/mlir/test/Dialect/Affine/loop-fusion.mlir
@@ -1541,5 +1541,37 @@ func.func @should_fuse_and_preserve_dep_on_constant() {
   return
 }
 
+// -----
+
+// CHECK-LABEL: @producer_consumer_with_outmost_user
+func.func @producer_consumer_with_outmost_user(%arg0 : f16) {
+  %c0 = arith.constant 0 : index
+  %src = memref.alloc() : memref<f16, 1>
+  %dst = memref.alloc() : memref<f16>
+  %tag = memref.alloc() : memref<1xi32>
+  affine.for %arg1 = 4 to 6 {
+    affine.for %arg2 = 0 to 1 {
+      %0 = arith.addf %arg0, %arg0 : f16
+      affine.store %0, %src[] : memref<f16, 1>
+    }
+    affine.for %arg3 = 0 to 1 {
+      %0 = affine.load %src[] : memref<f16, 1>
+    }
+  }
+  affine.dma_start %src[], %dst[], %tag[%c0], %c0 : memref<f16, 1>, memref<f16>, memref<1xi32>
+  // CHECK:       %[[CST_INDEX:.*]] = arith.constant 0 : index
+  // CHECK:       %[[DMA_SRC:.*]] = memref.alloc() : memref<f16, 1>
+  // CHECK:       %[[DMA_DST:.*]] = memref.alloc() : memref<f16>
+  // CHECK:       %[[DMA_TAG:.*]] = memref.alloc() : memref<1xi32>
+  // CHECK:       affine.for %arg1 = 4 to 6
+  // CHECK-NEXT:  affine.for %arg2 = 0 to 1
+  // CHECK-NEXT:  %[[RESULT_ADD:.*]] = arith.addf %arg0, %arg0 : f16
+  // CHECK-NEXT:  affine.store %[[RESULT_ADD]], %[[DMA_SRC]][] : memref<f16, 1>
+  // CHECK-NEXT:  affine.load %[[DMA_SRC]][] : memref<f16, 1>
+  // CHECK:       affine.dma_start %[[DMA_SRC]][], %[[DMA_DST]][], %[[DMA_TAG]][%[[CST_INDEX]]], %[[CST_INDEX]] : memref<f16, 1>, memref<f16>, memref<1xi32>
+  // CHECK-NEXT:  return
+  return
+}
+
 // Add further tests in mlir/test/Transforms/loop-fusion-4.mlir
 

>From 0fc71ce74a243d57c3085ce80790243fc5eb7775 Mon Sep 17 00:00:00 2001
From: lipracer <lipracer at gmail.com>
Date: Thu, 28 Dec 2023 14:56:09 +0800
Subject: [PATCH 2/2] refine: spell out `Operation *` instead of `auto` for ancestorOp

---
 mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index 77f283725dc597..bb319208f58a85 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -205,7 +205,7 @@ static bool isEscapingMemref(Value memref, Block *block) {
   // (e.g., call ops, alias creating ops, etc.).
   return llvm::any_of(memref.getUsers(), [&](Operation *user) {
     // Ignore users outside of `block`.
-    auto ancestorOp = block->getParent()->findAncestorOpInRegion(*user);
+    Operation *ancestorOp = block->getParent()->findAncestorOpInRegion(*user);
     if (!ancestorOp)
       return true;
     if (ancestorOp->getBlock() != block)



More information about the Mlir-commits mailing list