[Mlir-commits] [mlir] fc253e6 - Fixed bug in buffer deallocation pass using unranked memref types.

Julian Gross llvmlistbot at llvm.org
Mon May 10 01:51:06 PDT 2021


Author: Julian Gross
Date: 2021-05-10T10:50:29+02:00
New Revision: fc253e69f9b988e8b2d4c940946146696b2acf5a

URL: https://github.com/llvm/llvm-project/commit/fc253e69f9b988e8b2d4c940946146696b2acf5a
DIFF: https://github.com/llvm/llvm-project/commit/fc253e69f9b988e8b2d4c940946146696b2acf5a.diff

LOG: Fixed bug in buffer deallocation pass using unranked memref types.

In the buffer deallocation pass, unranked memref types are not properly supported.
After investigating this issue, it turns out that the Clone and Dealloc operations
do not support unranked memref types in the current implementation.
This patch adds the missing feature and enables the transformation of any memref
type.

This patch solves this bug: https://bugs.llvm.org/show_bug.cgi?id=48385

Differential Revision: https://reviews.llvm.org/D101760

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/test/Conversion/StandardToSPIRV/alloc.mlir
    mlir/test/Dialect/MemRef/ops.mlir
    mlir/test/Transforms/buffer-deallocation.mlir
    mlir/test/lib/Dialect/Test/TestOps.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 7b341b1940cf9..74afcd09d1a9b 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -342,8 +342,8 @@ def CloneOp : MemRef_Op<"clone", [
     undefined behavior.
   }];
 
-  let arguments = (ins Arg<AnyMemRef, "", []>:$input);
-  let results = (outs Arg<AnyMemRef, "", []>:$output);
+  let arguments = (ins Arg<AnyRankedOrUnrankedMemRef, "", []>:$input);
+  let results = (outs Arg<AnyRankedOrUnrankedMemRef, "", []>:$output);
 
   let extraClassDeclaration = [{
     Value getSource() { return input();}
@@ -353,6 +353,7 @@ def CloneOp : MemRef_Op<"clone", [
   let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
 
   let hasFolder = 1;
+  let verifier = ?;
   let hasCanonicalizer = 1;
 }
 
@@ -376,9 +377,10 @@ def MemRef_DeallocOp : MemRef_Op<"dealloc", [MemRefsNormalizable]> {
     ```
   }];
 
-  let arguments = (ins Arg<AnyMemRef, "", [MemFree]>:$memref);
+  let arguments = (ins Arg<AnyRankedOrUnrankedMemRef, "", [MemFree]>:$memref);
 
   let hasFolder = 1;
+  let verifier = ?;
   let assemblyFormat = "$memref attr-dict `:` type($memref)";
 }
 

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 57c1b1581a232..9a502780134a6 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -474,8 +474,6 @@ OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
 // CloneOp
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verify(CloneOp op) { return success(); }
-
 void CloneOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
@@ -544,12 +542,6 @@ OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
 // DeallocOp
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verify(DeallocOp op) {
-  if (!op.memref().getType().isa<MemRefType>())
-    return op.emitOpError("operand must be a memref");
-  return success();
-}
-
 LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
                               SmallVectorImpl<OpFoldResult> &results) {
   /// dealloc(memrefcast) -> dealloc

diff  --git a/mlir/test/Conversion/StandardToSPIRV/alloc.mlir b/mlir/test/Conversion/StandardToSPIRV/alloc.mlir
index 2d9dcf472fcf0..2d8e84ac25138 100644
--- a/mlir/test/Conversion/StandardToSPIRV/alloc.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/alloc.mlir
@@ -139,7 +139,7 @@ module attributes {
 {
   func @alloc_dealloc_dynamic_workgroup_mem(%arg0 : memref<4x?xf32, 3>) {
     // expected-error @+2 {{unhandled deallocation type}}
-    // expected-error @+1 {{'memref.dealloc' op operand #0 must be memref of any type values}}
+    // expected-error @+1 {{'memref.dealloc' op operand #0 must be unranked.memref of any type values or memref of any type values}}
     memref.dealloc %arg0 : memref<4x?xf32, 3>
     return
   }
@@ -154,7 +154,7 @@ module attributes {
 {
   func @alloc_dealloc_mem(%arg0 : memref<4x5xf32>) {
     // expected-error @+2 {{unhandled deallocation type}}
-    // expected-error @+1 {{op operand #0 must be memref of any type values}}
+    // expected-error @+1 {{op operand #0 must be unranked.memref of any type values or memref of any type values}}
     memref.dealloc %arg0 : memref<4x5xf32>
     return
   }

diff  --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index 9b6a05d7dc32d..1b5728486a367 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -60,3 +60,19 @@ func @read_global_memref() {
   %1 = memref.tensor_load %0 : memref<2xf32>
   return
 }
+
+// CHECK-LABEL: func @memref_clone
+func @memref_clone() {
+  %0 = memref.alloc() : memref<2xf32>
+  %1 = memref.cast %0 : memref<2xf32> to memref<*xf32>
+  %2 = memref.clone %1 : memref<*xf32> to memref<*xf32>
+  return
+}
+
+// CHECK-LABEL: func @memref_dealloc
+func @memref_dealloc() {
+  %0 = memref.alloc() : memref<2xf32>
+  %1 = memref.cast %0 : memref<2xf32> to memref<*xf32>
+  memref.dealloc %1 : memref<*xf32>
+  return
+}

diff  --git a/mlir/test/Transforms/buffer-deallocation.mlir b/mlir/test/Transforms/buffer-deallocation.mlir
index 35f7bbf79c8f5..77945113e1647 100644
--- a/mlir/test/Transforms/buffer-deallocation.mlir
+++ b/mlir/test/Transforms/buffer-deallocation.mlir
@@ -90,6 +90,43 @@ func @condBranchDynamicType(
 
 // -----
 
+// Test case: See above.
+
+// CHECK-LABEL: func @condBranchUnrankedType
+func @condBranchUnrankedType(
+  %arg0: i1,
+  %arg1: memref<*xf32>,
+  %arg2: memref<*xf32>,
+  %arg3: index) {
+  cond_br %arg0, ^bb1, ^bb2(%arg3: index)
+^bb1:
+  br ^bb3(%arg1 : memref<*xf32>)
+^bb2(%0: index):
+  %1 = memref.alloc(%0) : memref<?xf32>
+  %2 = memref.cast %1 : memref<?xf32> to memref<*xf32>
+  test.buffer_based in(%arg1: memref<*xf32>) out(%2: memref<*xf32>)
+  br ^bb3(%2 : memref<*xf32>)
+^bb3(%3: memref<*xf32>):
+  test.copy(%3, %arg2) : (memref<*xf32>, memref<*xf32>)
+  return
+}
+
+// CHECK-NEXT: cond_br
+//      CHECK: %[[ALLOC0:.*]] = memref.clone
+// CHECK-NEXT: br ^bb3(%[[ALLOC0]]
+//      CHECK: ^bb2(%[[IDX:.*]]:{{.*}})
+// CHECK-NEXT: %[[ALLOC1:.*]] = memref.alloc(%[[IDX]])
+//      CHECK: test.buffer_based
+// CHECK-NEXT: %[[ALLOC2:.*]] = memref.clone
+// CHECK-NEXT: memref.dealloc %[[ALLOC1]]
+// CHECK-NEXT: br ^bb3
+// CHECK-NEXT: ^bb3(%[[ALLOC3:.*]]:{{.*}})
+//      CHECK: test.copy(%[[ALLOC3]],
+// CHECK-NEXT: memref.dealloc %[[ALLOC3]]
+// CHECK-NEXT: return
+
+// -----
+
 // Test Case:
 //      bb0
 //     /    \

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index b0c2fe45ed681..795a5af35babe 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1879,8 +1879,8 @@ def CopyOp : TEST_Op<"copy", [CopyOpInterface]> {
   let description = [{
     Represents a copy operation.
   }];
-  let arguments = (ins Res<AnyMemRef, "", [MemRead]>:$source,
-                   Res<AnyMemRef, "", [MemWrite]>:$target);
+  let arguments = (ins Res<AnyRankedOrUnrankedMemRef, "", [MemRead]>:$source,
+                   Res<AnyRankedOrUnrankedMemRef, "", [MemWrite]>:$target);
   let assemblyFormat = [{
     `(` $source `,` $target `)` `:` `(` type($source) `,` type($target) `)`
      attr-dict
@@ -1915,7 +1915,8 @@ class BufferBasedOpBase<string mnemonic, list<OpTrait> traits>
   let description = [{
     A buffer based operation, that uses memRefs as input and output.
   }];
-  let arguments = (ins AnyMemRef:$input, AnyMemRef:$output);
+  let arguments = (ins AnyRankedOrUnrankedMemRef:$input,
+                       AnyRankedOrUnrankedMemRef:$output);
 }
 
 def BufferBasedOp : BufferBasedOpBase<"buffer_based", []>{


        


More information about the Mlir-commits mailing list