[Mlir-commits] [mlir] 3f914d8 - [mlir][bufferize] Better error handling: Fail if ToMemrefOps are found

Matthias Springer llvmlistbot at llvm.org
Thu Aug 18 02:38:10 PDT 2022


Author: Matthias Springer
Date: 2022-08-18T11:37:57+02:00
New Revision: 3f914d84c3849c60068f8182b02c7ad06ab21e72

URL: https://github.com/llvm/llvm-project/commit/3f914d84c3849c60068f8182b02c7ad06ab21e72
DIFF: https://github.com/llvm/llvm-project/commit/3f914d84c3849c60068f8182b02c7ad06ab21e72.diff

LOG: [mlir][bufferize] Better error handling: Fail if ToMemrefOps are found

bufferization.to_memref ops are not supported by One-Shot Analysis. Previously, they often triggered a failed assertion that was confusing to users. Instead, scan for to_memref ops before running the analysis and immediately abort with a proper error message.

Differential Revision: https://reviews.llvm.org/D132027

Added: 
    

Modified: 
    mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
    mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
    mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/concatenate.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index 72e9600e44f47..375330c65351e 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -832,27 +832,37 @@ checkAliasInfoConsistency(Operation *op, const DominanceInfo &domInfo,
                           AnalysisState &state,
                           const BufferizationAliasInfo &aliasInfo) {
   const BufferizationOptions &options = state.getOptions();
-  Operation *inconsistentOp = nullptr;
-  WalkResult walkResult = op->walk([&](Operation *op) {
-    if (auto bufferizableOp = options.dynCastBufferizableOp(op))
-      for (OpOperand &opOperand : op->getOpOperands())
-        if (opOperand.get().getType().isa<TensorType>()) {
-          if (wouldCreateReadAfterWriteInterference(
-                  opOperand, domInfo, state, aliasInfo,
-                  /*checkConsistencyOnly=*/true)) {
-            // This error can happen if certain "mustBufferizeInPlace" interface
-            // methods are implemented incorrectly, such that the IR already has
-            // a RaW conflict before making any bufferization decisions.
-            inconsistentOp = op;
-            return WalkResult::interrupt();
-          }
+
+  WalkResult walkResult = op->walk([&](BufferizableOpInterface op) {
+    // Skip ops that are not in the filter.
+    if (!options.isOpAllowed(op.getOperation()))
+      return WalkResult::advance();
+
+    // Input IR may not contain any ToMemrefOps. These are not supported because
+    // the analysis cannot follow the data flow through memrefs.
+    if (isa<ToMemrefOp>(op.getOperation())) {
+      op->emitError("to_memref ops not supported during One-Shot Analysis");
+      return WalkResult::interrupt();
+    }
+
+    for (OpOperand &opOperand : op->getOpOperands()) {
+      if (opOperand.get().getType().isa<TensorType>()) {
+        if (wouldCreateReadAfterWriteInterference(
+                opOperand, domInfo, state, aliasInfo,
+                /*checkConsistencyOnly=*/true)) {
+          // This error can happen if certain "mustBufferizeInPlace" interface
+          // methods are implemented incorrectly, such that the IR already has
+          // a RaW conflict before making any bufferization decisions.
+          op->emitError("input IR has RaW conflict");
+          return WalkResult::interrupt();
         }
+      }
+    }
+
     return WalkResult::advance();
   });
 
-  if (walkResult.wasInterrupted())
-    return inconsistentOp->emitError("input IR has RaW conflict");
-  return success();
+  return success(!walkResult.wasInterrupted());
 }
 
 /// Annotate the IR with the result of the analysis. For testing/debugging only.

diff  --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
index 6794c9ac0067a..4a637a64e2ad0 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
@@ -1074,28 +1074,6 @@ func.func @to_tensor_op_not_writable(%m: memref<?xf32>, %v:  vector<5xf32>,
 
 // -----
 
-// CHECK-LABEL: func @to_memref_op_is_reading
-func.func @to_memref_op_is_reading(%t1: tensor<?xf32> {bufferization.writable = true},
-                                   %idx1: index, %idx2: index, %idx3: index,
-                                   %v1: vector<5xf32>)
-    -> (vector<5xf32>, vector<5xf32>) {
-  // Write + read to/from tensor.
-  //      CHECK: vector.transfer_write
-  // CHECK-SAME: {__inplace_operands_attr__ = ["none", "false", "none"]
-  %1 = vector.transfer_write %v1, %t1[%idx2] : vector<5xf32>, tensor<?xf32>
-  %cst = arith.constant 0.0 : f32
-  %r1 = vector.transfer_read %1[%idx3], %cst : tensor<?xf32>, vector<5xf32>
-
-  // Write + read to/from same memref.
-  %0 = bufferization.to_memref %t1 : memref<?xf32>
-  vector.transfer_write %v1, %0[%idx1] : vector<5xf32>, memref<?xf32>
-  %r2 = vector.transfer_read %0[%idx3], %cst : memref<?xf32>, vector<5xf32>
-
-  return %r1, %r2 : vector<5xf32>, vector<5xf32>
-}
-
-// -----
-
 // CHECK-LABEL: func @inner_func
 func.func @inner_func(%t: tensor<?xf32>) -> tensor<?xf32> {
   //      CHECK: return

diff  --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
index 140f67b7c3024..27d1f52d56e1e 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
@@ -249,7 +249,7 @@ func.func @to_memref_op_is_writing(
   // read further down. This will likely have to change with partial
   // bufferization.
 
-  // expected-error @+1 {{input IR has RaW conflict}}
+  // expected-error @+1 {{to_memref ops not supported during One-Shot Analysis}}
   %0 = bufferization.to_memref %t1 : memref<?xf32>
 
   // Read from both.
@@ -289,7 +289,7 @@ func.func @call_to_func_returning_non_equiv_tensor(%t : tensor<5xf32>) {
 // -----
 
 func.func @destination_passing_style_dominance_test_1(%cst : f32, %idx : index,
-                                                 %idx2 : index) -> f32 {
+                                                      %idx2 : index) -> f32 {
   %0 = scf.execute_region -> tensor<?xf32> {
     %1 = bufferization.alloc_tensor(%idx) : tensor<?xf32>
     // expected-error @+1 {{operand #0 of ReturnLike op does not satisfy destination passing style}}
@@ -303,7 +303,7 @@ func.func @destination_passing_style_dominance_test_1(%cst : f32, %idx : index,
 // -----
 
 func.func @destination_passing_style_dominance_test_2(%cst : f32, %idx : index,
-                                                 %idx2 : index) -> f32 {
+                                                      %idx2 : index) -> f32 {
   %1 = bufferization.alloc_tensor(%idx) : tensor<?xf32>
 
   %0 = scf.execute_region -> tensor<?xf32> {

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/concatenate.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/concatenate.mlir
index 37f6f749d4dff..d518702ef8fdc 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/concatenate.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/concatenate.mlir
@@ -166,8 +166,7 @@ module {
     %du = arith.constant -1.0 : f64
 
     %c = sparse_tensor.convert %A : tensor<9x4xf64, #MAT_C_C> to tensor<9x4xf64>
-    %m = bufferization.to_memref %c : memref<9x4xf64>
-    %v = vector.transfer_read %m[%c0, %c0], %du: memref<9x4xf64>, vector<9x4xf64>
+    %v = vector.transfer_read %c[%c0, %c0], %du: tensor<9x4xf64>, vector<9x4xf64>
     vector.print %v : vector<9x4xf64>
 
     %1 = sparse_tensor.values %A : tensor<9x4xf64, #MAT_C_C> to memref<?xf64>
@@ -182,8 +181,7 @@ module {
     %du = arith.constant -1.0 : f64
 
     %c = sparse_tensor.convert %A : tensor<9x4xf64, #MAT_C_C_P> to tensor<9x4xf64>
-    %m = bufferization.to_memref %c : memref<9x4xf64>
-    %v = vector.transfer_read %m[%c0, %c0], %du: memref<9x4xf64>, vector<9x4xf64>
+    %v = vector.transfer_read %c[%c0, %c0], %du: tensor<9x4xf64>, vector<9x4xf64>
     vector.print %v : vector<9x4xf64>
 
     %1 = sparse_tensor.values %A : tensor<9x4xf64, #MAT_C_C_P> to memref<?xf64>
@@ -197,8 +195,7 @@ module {
     %c0 = arith.constant 0 : index
     %du = arith.constant -1.0 : f64
 
-    %m = bufferization.to_memref %A : memref<9x4xf64>
-    %v = vector.transfer_read %m[%c0, %c0], %du: memref<9x4xf64>, vector<9x4xf64>
+    %v = vector.transfer_read %A[%c0, %c0], %du: tensor<9x4xf64>, vector<9x4xf64>
     vector.print %v : vector<9x4xf64>
 
     return
@@ -209,8 +206,7 @@ module {
     %du = arith.constant -1.0 : f64
 
     %c = sparse_tensor.convert %A : tensor<4x9xf64, #MAT_C_C> to tensor<4x9xf64>
-    %m = bufferization.to_memref %c : memref<4x9xf64>
-    %v = vector.transfer_read %m[%c0, %c0], %du: memref<4x9xf64>, vector<4x9xf64>
+    %v = vector.transfer_read %c[%c0, %c0], %du: tensor<4x9xf64>, vector<4x9xf64>
     vector.print %v : vector<4x9xf64>
 
     %1 = sparse_tensor.values %A : tensor<4x9xf64, #MAT_C_C> to memref<?xf64>
@@ -225,8 +221,7 @@ module {
     %du = arith.constant -1.0 : f64
 
     %c = sparse_tensor.convert %A : tensor<?x?xf64, #MAT_C_C> to tensor<?x?xf64>
-    %m = bufferization.to_memref %c : memref<?x?xf64>
-    %v = vector.transfer_read %m[%c0, %c0], %du: memref<?x?xf64>, vector<4x9xf64>
+    %v = vector.transfer_read %c[%c0, %c0], %du: tensor<?x?xf64>, vector<4x9xf64>
     vector.print %v : vector<4x9xf64>
 
     %1 = sparse_tensor.values %A : tensor<?x?xf64, #MAT_C_C> to memref<?xf64>
@@ -241,8 +236,7 @@ module {
     %du = arith.constant -1.0 : f64
 
     %c = sparse_tensor.convert %A : tensor<4x9xf64, #MAT_C_C_P> to tensor<4x9xf64>
-    %m = bufferization.to_memref %c : memref<4x9xf64>
-    %v = vector.transfer_read %m[%c0, %c0], %du: memref<4x9xf64>, vector<4x9xf64>
+    %v = vector.transfer_read %c[%c0, %c0], %du: tensor<4x9xf64>, vector<4x9xf64>
     vector.print %v : vector<4x9xf64>
 
     %1 = sparse_tensor.values %A : tensor<4x9xf64, #MAT_C_C_P> to memref<?xf64>
@@ -256,8 +250,7 @@ module {
     %c0 = arith.constant 0 : index
     %du = arith.constant -1.0 : f64
 
-    %m = bufferization.to_memref %A : memref<4x9xf64>
-    %v = vector.transfer_read %m[%c0, %c0], %du: memref<4x9xf64>, vector<4x9xf64>
+    %v = vector.transfer_read %A[%c0, %c0], %du: tensor<4x9xf64>, vector<4x9xf64>
     vector.print %v : vector<4x9xf64>
 
     return


        


More information about the Mlir-commits mailing list