[Mlir-commits] [mlir] ebf8d74 - [mlir][linalg][bufferize] Fix bufferize bug where non-tensor ops are not skipped

Matthias Springer llvmlistbot at llvm.org
Wed Nov 17 23:26:45 PST 2021


Author: Matthias Springer
Date: 2021-11-18T16:20:22+09:00
New Revision: ebf8d74e929d908829eda4ad8548ec21e2dbc6ae

URL: https://github.com/llvm/llvm-project/commit/ebf8d74e929d908829eda4ad8548ec21e2dbc6ae
DIFF: https://github.com/llvm/llvm-project/commit/ebf8d74e929d908829eda4ad8548ec21e2dbc6ae.diff

LOG: [mlir][linalg][bufferize] Fix bufferize bug where non-tensor ops are not skipped

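`BufferizableOpInterface::bufferize` is now called only on ops that have
tensor operands and/or results; ops without any tensor operands or results
are skipped. Interface implementations can therefore assume tensor types:
the TransferReadOp/TransferWriteOp implementations now assert this instead
of returning failure on memref-typed ops.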

Differential Revision: https://reviews.llvm.org/D113962

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index 96fc066e7553e..fdea306a2cd17 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -1114,16 +1114,19 @@ LogicalResult mlir::linalg::comprehensive_bufferize::bufferizeOp(
   if (isa<memref::BufferCastOp, memref::TensorLoadOp>(op))
     return success();
 
+  // Check if op has tensor results or operands.
+  auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
+  bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
+  bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
+  if (!hasTensorResult && !hasTensorOperand)
+    return success();
+
   // Bufferize using `BufferizableOpInterface`.
   if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
     return bufferizableOp.bufferize(b, state);
 
   // Other op with tensors. No bufferization method specified.
-  auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
-  if (any_of(op->getOperandTypes(), isaTensor) ||
-      any_of(op->getResultTypes(), isaTensor))
-    return op->emitError() << "unsupported op with tensors";
-  return success();
+  return op->emitError() << "unsupported op with tensors";
 }
 
 static LogicalResult bufferizeFuncOpInternals(
@@ -2482,10 +2485,9 @@ struct TransferReadOpInterface
     OpBuilder::InsertionGuard g(b);
     b.setInsertionPoint(op);
 
-    if (transferReadOp.getShapedType().isa<MemRefType>())
-      return failure();
-
     // TransferReadOp always reads from the bufferized op.source().
+    assert(transferReadOp.getShapedType().isa<TensorType>() &&
+           "only tensor types expected");
     Value v = state.lookupBuffer(transferReadOp.source());
     transferReadOp.sourceMutable().assign(v);
     return success();
@@ -2530,12 +2532,11 @@ struct TransferWriteOpInterface
     OpBuilder::InsertionGuard g(b);
     b.setInsertionPoint(op);
 
-    if (writeOp.getShapedType().isa<MemRefType>())
-      return failure();
-
     // Create a new transfer_write on buffer that doesn't have a return value.
     // Leave the previous transfer_write to dead code as it still has uses at
     // this point.
+    assert(writeOp.getShapedType().isa<TensorType>() &&
+           "only tensor types expected");
     Value resultBuffer = getResultBuffer(b, op->getResult(0), state);
     if (!resultBuffer)
       return failure();

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
index 6edc7d1090c36..a3e799bf1faaf 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
@@ -167,16 +167,3 @@ func @main() -> tensor<4xi32> {
   }
   return %r: tensor<4xi32>
 }
-
-// -----
-
-func @main() -> i32 {
-  %c0 = arith.constant 0: index
-  // expected-error @+1 {{expected result-less scf.execute_region containing op}}
-  %r = scf.execute_region -> i32 {
-    %A = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
-    %e = tensor.extract %A[%c0]: tensor<4xi32>
-    scf.yield %e: i32
-  }
-  return %r: i32
-}


        


More information about the Mlir-commits mailing list