[Mlir-commits] [mlir] 9ba25ec - [mlir][Bufferize] NFC - Introduce areCastCompatible assertions to catch misformed CastOp early

Nicolas Vasilache llvmlistbot at llvm.org
Sun Jan 9 11:13:18 PST 2022


Author: Nicolas Vasilache
Date: 2022-01-09T14:13:08-05:00
New Revision: 9ba25ec92d88639561797674296b81fb3b67eed5

URL: https://github.com/llvm/llvm-project/commit/9ba25ec92d88639561797674296b81fb3b67eed5
DIFF: https://github.com/llvm/llvm-project/commit/9ba25ec92d88639561797674296b81fb3b67eed5.diff

LOG: [mlir][Bufferize] NFC - Introduce areCastCompatible assertions to catch misformed CastOp early

Differential Revision: https://reviews.llvm.org/D116893

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index d2d726312d6ad..e64d5ae3dda61 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -549,6 +549,9 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
     return failure();
   Value casted = allocated.getValue();
   if (memRefType && memRefType != allocMemRefType) {
+    assert(memref::CastOp::areCastCompatible(allocated.getValue().getType(),
+                                             memRefType) &&
+           "createAlloc: cast incompatible");
     casted = b.create<memref::CastOp>(loc, memRefType, allocated.getValue());
   }
 

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
index 17719244c5c3f..fd3632fb56d07 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
@@ -77,9 +77,13 @@ struct ToMemrefOpInterface
 
       // Insert cast in case to_memref(to_tensor(x))'s type is different from
       // x's type.
-      if (toTensorOp.memref().getType() != toMemrefOp.getType())
+      if (toTensorOp.memref().getType() != toMemrefOp.getType()) {
+        assert(memref::CastOp::areCastCompatible(buffer.getType(),
+                                                 toMemrefOp.getType()) &&
+               "ToMemrefOp::bufferize : cast incompatible");
         buffer = rewriter.create<memref::CastOp>(toMemrefOp.getLoc(), buffer,
                                                  toMemrefOp.getType());
+      }
       replaceOpWithBufferizedValues(rewriter, toMemrefOp, buffer);
       return success();
     }

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index 8c6b32d733174..8138ab2952eac 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -386,7 +386,10 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
     // Replace all uses of bbArg through a ToMemRefOp by a memref::CastOp.
     for (auto &use : llvm::make_early_inc_range(bbArg.getUses())) {
       if (auto toMemrefOp =
-          dyn_cast<bufferization::ToMemrefOp>(use.getOwner())) {
+              dyn_cast<bufferization::ToMemrefOp>(use.getOwner())) {
+        assert(memref::CastOp::areCastCompatible(
+                   memref.getType(), toMemrefOp.memref().getType()) &&
+               "bufferizeFuncOpBoundary: cast incompatible");
         auto castOp = b.create<memref::CastOp>(
             funcOp.getLoc(), toMemrefOp.memref().getType(), memref);
         toMemrefOp.memref().replaceAllUsesWith(castOp);
@@ -525,6 +528,8 @@ static void layoutPostProcessing(ModuleOp moduleOp) {
         bbArg.setType(desiredMemrefType);
         OpBuilder b(bbArg.getContext());
         b.setInsertionPointToStart(bbArg.getOwner());
+        assert(memref::CastOp::areCastCompatible(bbArg.getType(), memrefType) &&
+               "layoutPostProcessing: cast incompatible");
         // Cast back to the original memrefType and let it canonicalize.
         Value cast =
             b.create<memref::CastOp>(funcOp.getLoc(), memrefType, bbArg);
@@ -537,6 +542,10 @@ static void layoutPostProcessing(ModuleOp moduleOp) {
       // such cases.
       auto castArg = [&](Operation *caller) {
         OpBuilder b(caller);
+        assert(
+            memref::CastOp::areCastCompatible(
+                caller->getOperand(argNumber).getType(), desiredMemrefType) &&
+            "layoutPostProcessing.2: cast incompatible");
         Value newOperand = b.create<memref::CastOp>(
             funcOp.getLoc(), desiredMemrefType, caller->getOperand(argNumber));
         operandsPerCaller.find(caller)->getSecond().push_back(newOperand);
@@ -703,6 +712,9 @@ struct CallOpInterface
       // that will either canonicalize away or fail compilation until we can do
       // something better.
       if (buffer.getType() != memRefType) {
+        assert(
+            memref::CastOp::areCastCompatible(buffer.getType(), memRefType) &&
+            "CallOp::bufferize: cast incompatible");
         Value castBuffer = rewriter.create<memref::CastOp>(callOp.getLoc(),
                                                            memRefType, buffer);
         buffer = castBuffer;

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
index 7c9114b284b28..f0f20b433937e 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -77,6 +77,9 @@ struct CastOpInterface
     }
 
     // Replace the op with a memref.cast.
+    assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
+                                             resultMemRefType) &&
+           "CallOp::bufferize: cast incompatible");
     replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
                                                  *resultBuffer);
 


        


More information about the Mlir-commits mailing list