[Mlir-commits] [mlir] 4078b11 - [MLIR][Affine] Fix fusion crash for non-int/fp memref elt types (#126829)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Feb 12 15:57:51 PST 2025


Author: Uday Bondhugula
Date: 2025-02-13T05:27:48+05:30
New Revision: 4078b11daa5b4902f59fa79c1647a20532b16c55

URL: https://github.com/llvm/llvm-project/commit/4078b11daa5b4902f59fa79c1647a20532b16c55
DIFF: https://github.com/llvm/llvm-project/commit/4078b11daa5b4902f59fa79c1647a20532b16c55.diff

LOG: [MLIR][Affine] Fix fusion crash for non-int/fp memref elt types (#126829)

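Fix an incorrect assumption that memref element types are always int or
float during private memref creation in affine fusion. A private memref
is no longer created when the element size in bytes cannot be computed.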

Fixes: https://github.com/llvm/llvm-project/issues/121020

Added: 
    

Modified: 
    mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
    mlir/test/Dialect/Affine/loop-fusion-4.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index b38dd8effe669..7763831141c6b 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -759,6 +759,9 @@ struct GreedyFusion {
                               const DenseSet<Value> &srcEscapingMemRefs,
                               unsigned producerId, unsigned consumerId,
                               bool removeSrcNode) {
+    // We can't generate private memrefs if their size can't be computed.
+    if (!getMemRefIntOrFloatEltSizeInBytes(cast<MemRefType>(memref.getType())))
+      return false;
     const Node *consumerNode = mdg->getNode(consumerId);
     // If `memref` is an escaping one, do not create a private memref
     // for the below scenarios, since doing so will leave the escaping

diff  --git a/mlir/test/Dialect/Affine/loop-fusion-4.mlir b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
index 2830235431c76..07d2d06f1451d 100644
--- a/mlir/test/Dialect/Affine/loop-fusion-4.mlir
+++ b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{mode=producer}))' -split-input-file | FileCheck %s --check-prefix=PRODUCER-CONSUMER
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{mode=producer fusion-maximal}))' -split-input-file | FileCheck %s --check-prefix=PRODUCER-CONSUMER-MAXIMAL
 // RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{fusion-maximal mode=sibling}))' -split-input-file | FileCheck %s --check-prefix=SIBLING-MAXIMAL
 // RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(spirv.func(affine-loop-fusion{mode=producer}))' -split-input-file | FileCheck %s --check-prefix=SPIRV
 
@@ -345,3 +346,37 @@ func.func @same_memref_load_multiple_stores(%producer : memref<32xf32>, %produce
   // PRODUCER-CONSUMER-NEXT: }
   return
 }
+
+#map = affine_map<()[s0] -> (s0 + 5)>
+#map1 = affine_map<()[s0] -> (s0 + 17)>
+
+// Test fusion with a memref whose element type is neither int nor float (index).
+
+// PRODUCER-CONSUMER-MAXIMAL-LABEL: func @memref_index_type
+func.func @memref_index_type() {
+  %0 = llvm.mlir.constant(2 : index) : i64
+  %2 = llvm.mlir.constant(0 : index) : i64
+  %3 = builtin.unrealized_conversion_cast %2 : i64 to index
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x18xf32>
+  %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<3xf32>
+  %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<3xindex>
+  affine.for %arg3 = 0 to 3 {
+    %4 = affine.load %alloc_2[%arg3] : memref<3xindex>
+    %5 = builtin.unrealized_conversion_cast %4 : index to i64
+    %6 = llvm.sub %0, %5 : i64
+    %7 = builtin.unrealized_conversion_cast %6 : i64 to index
+    affine.store %7, %alloc_2[%arg3] : memref<3xindex>
+  }
+  affine.for %arg3 = 0 to 3 {
+    %4 = affine.load %alloc_2[%arg3] : memref<3xindex>
+    %5 = affine.apply #map()[%4]
+    %6 = affine.apply #map1()[%3]
+    %7 = memref.load %alloc[%5, %6] : memref<8x18xf32>
+    affine.store %7, %alloc_1[%arg3] : memref<3xf32>
+  }
+  // Expect fusion.
+  // PRODUCER-CONSUMER-MAXIMAL: affine.for
+  // PRODUCER-CONSUMER-MAXIMAL-NOT: affine.for
+  // PRODUCER-CONSUMER-MAXIMAL: return
+  return
+}


        


More information about the Mlir-commits mailing list