[Mlir-commits] [mlir] 6fe77b1 - [mlir][Linalg] Fail comprehensive bufferization if a memref is returned.

Nicolas Vasilache llvmlistbot at llvm.org
Wed Sep 15 08:15:14 PDT 2021


Author: Nicolas Vasilache
Date: 2021-09-15T15:11:17Z
New Revision: 6fe77b1051cc6e01dc1ac2d47f51802ab938e076

URL: https://github.com/llvm/llvm-project/commit/6fe77b1051cc6e01dc1ac2d47f51802ab938e076
DIFF: https://github.com/llvm/llvm-project/commit/6fe77b1051cc6e01dc1ac2d47f51802ab938e076.diff

LOG: [mlir][Linalg] Fail comprehensive bufferization if a memref is returned.

Summary:

Reviewers:

Subscribers:

Differential revision: https://reviews.llvm.org/D109824

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Passes.td
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index cf4e5f8d218c7..ecde91ff51205 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -37,7 +37,10 @@ def LinalgComprehensiveModuleBufferize :
   let options = [
     Option<"testAnalysisOnly", "test-analysis-only", "bool",
             /*default=*/"false",
-           "Only runs inplaceability analysis (for testing purposes only)">
+           "Only runs inplaceability analysis (for testing purposes only)">,
+    Option<"allowReturnMemref", "allow-return-memref", "bool",
+            /*default=*/"false",
+           "Allows the return of memrefs (for testing purposes only)">
   ];
   let constructor = "mlir::createLinalgComprehensiveModuleBufferizePass()";
 }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index 1ca179ce8ec3c..d62d0057ad6b8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -2914,6 +2914,14 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
       signalPassFailure();
       return;
     }
+    if (!allowReturnMemref &&
+        llvm::any_of(funcOp.getType().getResults(), [](Type t) {
+          return t.isa<MemRefType, UnrankedMemRefType>();
+        })) {
+      funcOp->emitError("memref return type is unsupported");
+      signalPassFailure();
+      return;
+    }
   }
 
   // Perform a post-processing pass of layout modification at function boundary

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
index cdf35c035ef14..60533c677911a 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
@@ -130,3 +130,13 @@ func @unknown_op(%A : tensor<4xf32>) -> tensor<4xf32>
   %r = "marklar"(%A) : (tensor<4xf32>) -> (tensor<4xf32>)
   return %r: tensor<4xf32>
 }
+
+// -----
+
+// expected-error @+1 {{memref return type is unsupported}}
+func @mini_test_case1() -> tensor<10x20xf32> {  // Tensor result bufferizes to a memref result; rejected unless allow-return-memref is set.
+  %f0 = constant 0.0 : f32
+  %t = linalg.init_tensor [10, 20] : tensor<10x20xf32>
+  %r = linalg.fill(%f0, %t) : f32, tensor<10x20xf32> -> tensor<10x20xf32>
+  return %r : tensor<10x20xf32>
+}

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index dfa2f927ffb5c..88a209a0be660 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize=allow-return-memref -split-input-file | FileCheck %s
 
 // CHECK-LABEL: func @transfer_read(%{{.*}}: memref<?xf32, #map>) -> vector<4xf32> {
 func @transfer_read(%A : tensor<?xf32>) -> (vector<4xf32>) {


        


More information about the Mlir-commits mailing list