[Mlir-commits] [mlir] 8b58711 - [mlir][memref] fix overflow in realloc

Peiming Liu llvmlistbot at llvm.org
Thu Sep 22 20:07:33 PDT 2022


Author: Peiming Liu
Date: 2022-09-23T03:07:23Z
New Revision: 8b587113b746f31b63fd6473083df78cef30a72e

URL: https://github.com/llvm/llvm-project/commit/8b587113b746f31b63fd6473083df78cef30a72e
DIFF: https://github.com/llvm/llvm-project/commit/8b587113b746f31b63fd6473083df78cef30a72e.diff

LOG: [mlir][memref] fix overflow in realloc

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D134511
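
The underlying issue, briefly: the lowering of memref.realloc copies the old
contents into a freshly allocated buffer, and (per the llvm.icmp "ugt" guard
visible in the tests below) this copy path is only taken when the reallocation
grows the memref. The copy previously used the destination byte size, so the
memcpy read past the end of the smaller source buffer. The fix computes the
source byte size and copies that many bytes instead. A minimal sketch of an
input that exercises the growing path (function name hypothetical, not taken
from the patch):

    // Growing a 2-element buffer to 4 elements: before this fix, the
    // lowered llvm.intr.memcpy copied 4 * sizeof(i32) = 16 bytes out of
    // an 8-byte source allocation.
    func.func @grow(%src: memref<2xi32>) -> memref<4xi32> {
      %dst = memref.realloc %src : memref<2xi32> to memref<4xi32>
      return %dst : memref<4xi32>
    }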

Added: 
    

Modified: 
    mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
    mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
    mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_push_back.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 241f62eb1d59a..8fe631b25bad4 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -191,6 +191,11 @@ struct ReallocOpLoweringBase : public AllocationOpLLVMLowering {
     // Compute total byte size.
     auto dstByteSize =
         rewriter.create<LLVM::MulOp>(loc, dstNumElements, sizeInBytes);
+    // Since the src and dst memrefs are guaranteed to have the same
+    // element type by the verifier, it is safe here to reuse the
+    // element size computed from the dst memref.
+    auto srcByteSize =
+        rewriter.create<LLVM::MulOp>(loc, srcNumElements, sizeInBytes);
     // Allocate a new buffer.
     auto [dstRawPtr, dstAlignedPtr] =
         allocateBuffer(rewriter, loc, dstByteSize, op);
@@ -202,7 +207,7 @@ struct ReallocOpLoweringBase : public AllocationOpLLVMLowering {
       return rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr);
     };
     rewriter.create<LLVM::MemcpyOp>(loc, toVoidPtr(dstAlignedPtr),
-                                    toVoidPtr(srcAlignedPtr), dstByteSize,
+                                    toVoidPtr(srcAlignedPtr), srcByteSize,
                                     isVolatile);
     // Deallocate the old buffer.
     LLVM::LLVMFuncOp freeFunc =
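
In effect, the lowering now carries two byte counts: dstByteSize sizes the
new allocation, while srcByteSize bounds the memcpy; both reuse the same
sizeInBytes because the verifier guarantees matching element types. As a
worked example with hypothetical shapes: growing memref<2xi32> to
memref<4xi32> yields srcByteSize = 2 * 4 = 8 and dstByteSize = 4 * 4 = 16,
so 16 bytes are allocated but only the 8 valid source bytes are copied,
whereas the old code copied all 16 and read 8 bytes past the source.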

diff --git a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
index 3cd8fb3334640..821f298ba4ce5 100644
--- a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
@@ -633,22 +633,23 @@ func.func @ranked_unranked() {
 // CHECK-SAME:      %[[arg1:.*]]: index) -> memref<?xf32> {
 func.func @realloc_dynamic(%in: memref<?xf32>, %d: index) -> memref<?xf32>{
 // CHECK:           %[[descriptor:.*]] = builtin.unrealized_conversion_cast %[[arg0]]
-// CHECK:           %[[drc_dim:.*]] = llvm.extractvalue %[[descriptor]][3, 0]
+// CHECK:           %[[src_dim:.*]] = llvm.extractvalue %[[descriptor]][3, 0]
 // CHECK:           %[[dst_dim:.*]] = builtin.unrealized_conversion_cast %[[arg1]] : index to i64
-// CHECK:           %[[cond:.*]] = llvm.icmp "ugt" %[[dst_dim]], %[[drc_dim]] : i64
+// CHECK:           %[[cond:.*]] = llvm.icmp "ugt" %[[dst_dim]], %[[src_dim]] : i64
 // CHECK:           llvm.cond_br %[[cond]], ^bb1, ^bb2(%[[descriptor]]
 // CHECK:           ^bb1:
 // CHECK:           %[[dst_null:.*]] = llvm.mlir.null : !llvm.ptr<f32>
 // CHECK:           %[[dst_gep:.*]] = llvm.getelementptr %[[dst_null]][1]
 // CHECK:           %[[dst_es:.*]] = llvm.ptrtoint %[[dst_gep]] : !llvm.ptr<f32> to i64
 // CHECK:           %[[dst_size:.*]] = llvm.mul %[[dst_dim]], %[[dst_es]]
+// CHECK:           %[[src_size:.*]] = llvm.mul %[[src_dim]], %[[dst_es]]
 // CHECK:           %[[new_buffer_raw:.*]] = llvm.call @malloc(%[[dst_size]])
 // CHECK:           %[[new_buffer:.*]] = llvm.bitcast %[[new_buffer_raw]] : !llvm.ptr<i8> to !llvm.ptr<f32>
 // CHECK:           %[[old_buffer_aligned:.*]] = llvm.extractvalue %[[descriptor]][1]
 // CHECK:           %[[volatile:.*]] = llvm.mlir.constant(false) : i1
 // CHECK-DAG:       %[[new_buffer_void:.*]] = llvm.bitcast %[[new_buffer]] : !llvm.ptr<f32> to !llvm.ptr<i8>
 // CHECK-DAG:       %[[old_buffer_void:.*]] = llvm.bitcast %[[old_buffer_aligned]] : !llvm.ptr<f32> to !llvm.ptr<i8>
-// CHECK:           "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[dst_size]], %[[volatile]])
+// CHECK:           "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[src_size]], %[[volatile]])
 // CHECK:           %[[old_buffer_unaligned:.*]] = llvm.extractvalue %[[descriptor]][0]
 // CHECK:           %[[old_buffer_unaligned_void:.*]] = llvm.bitcast %[[old_buffer_unaligned]] : !llvm.ptr<f32> to !llvm.ptr<i8>
 // CHECK:           llvm.call @free(%[[old_buffer_unaligned_void]])
@@ -683,6 +684,7 @@ func.func @realloc_dynamic_alignment(%in: memref<?xf32>, %d: index) -> memref<?x
 // CHECK:           %[[dst_gep:.*]] = llvm.getelementptr %[[dst_null]][1]
 // CHECK:           %[[dst_es:.*]] = llvm.ptrtoint %[[dst_gep]] : !llvm.ptr<f32> to i64
 // CHECK:           %[[dst_size:.*]] = llvm.mul %[[dst_dim]], %[[dst_es]]
+// CHECK:           %[[src_size:.*]] = llvm.mul %[[drc_dim]], %[[dst_es]]
 // CHECK:           %[[alignment:.*]] = llvm.mlir.constant(8 : index) : i64
 // CHECK:           %[[adjust_dst_size:.*]] = llvm.add %[[dst_size]], %[[alignment]]
 // CHECK:           %[[new_buffer_raw:.*]] = llvm.call @malloc(%[[adjust_dst_size]])
@@ -698,7 +700,7 @@ func.func @realloc_dynamic_alignment(%in: memref<?xf32>, %d: index) -> memref<?x
 // CHECK:           %[[volatile:.*]] = llvm.mlir.constant(false) : i1
 // CHECK-DAG:       %[[new_buffer_void:.*]] = llvm.bitcast %[[new_buffer_aligned]] : !llvm.ptr<f32> to !llvm.ptr<i8>
 // CHECK-DAG:       %[[old_buffer_void:.*]] = llvm.bitcast %[[old_buffer_aligned]] : !llvm.ptr<f32> to !llvm.ptr<i8>
-// CHECK:           "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[dst_size]], %[[volatile]])
+// CHECK:           "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[src_size]], %[[volatile]])
 // CHECK:           %[[old_buffer_unaligned:.*]] = llvm.extractvalue %[[descriptor]][0]
 // CHECK:           %[[old_buffer_unaligned_void:.*]] = llvm.bitcast %[[old_buffer_unaligned]] : !llvm.ptr<f32> to !llvm.ptr<i8>
 // CHECK:           llvm.call @free(%[[old_buffer_unaligned_void]])
@@ -720,6 +722,7 @@ func.func @realloc_dynamic_alignment(%in: memref<?xf32>, %d: index) -> memref<?x
 // ALIGNED-ALLOC:           %[[dst_gep:.*]] = llvm.getelementptr %[[dst_null]][1]
 // ALIGNED-ALLOC:           %[[dst_es:.*]] = llvm.ptrtoint %[[dst_gep]] : !llvm.ptr<f32> to i64
 // ALIGNED-ALLOC:           %[[dst_size:.*]] = llvm.mul %[[dst_dim]], %[[dst_es]]
+// ALIGNED-ALLOC:           %[[src_size:.*]] = llvm.mul %[[drc_dim]], %[[dst_es]]
 // ALIGNED-ALLOC-DAG:       %[[alignment:.*]] = llvm.mlir.constant(8 : index) : i64
 // ALIGNED-ALLOC-DAG:       %[[const_1:.*]] = llvm.mlir.constant(1 : index) : i64
 // ALIGNED-ALLOC:           %[[alignment_m1:.*]] = llvm.sub %[[alignment]], %[[const_1]]
@@ -732,7 +735,7 @@ func.func @realloc_dynamic_alignment(%in: memref<?xf32>, %d: index) -> memref<?x
 // ALIGNED-ALLOC:           %[[volatile:.*]] = llvm.mlir.constant(false) : i1
 // ALIGNED-ALLOC-DAG:       %[[new_buffer_void:.*]] = llvm.bitcast %[[new_buffer_aligned]] : !llvm.ptr<f32> to !llvm.ptr<i8>
 // ALIGNED-ALLOC-DAG:       %[[old_buffer_void:.*]] = llvm.bitcast %[[old_buffer_aligned]] : !llvm.ptr<f32> to !llvm.ptr<i8>
-// ALIGNED-ALLOC:           "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[dst_size]], %[[volatile]])
+// ALIGNED-ALLOC:           "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[src_size]], %[[volatile]])
 // ALIGNED-ALLOC:           %[[old_buffer_unaligned:.*]] = llvm.extractvalue %[[descriptor]][0]
 // ALIGNED-ALLOC:           %[[old_buffer_unaligned_void:.*]] = llvm.bitcast %[[old_buffer_unaligned]] : !llvm.ptr<f32> to !llvm.ptr<i8>
 // ALIGNED-ALLOC:           llvm.call @free(%[[old_buffer_unaligned_void]])

diff --git a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
index cabc84f57847b..023955987fb5c 100644
--- a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
@@ -354,13 +354,14 @@ func.func @realloc_static(%in: memref<2xi32>) -> memref<4xi32>{
 // CHECK:           %[[dst_gep:.*]] = llvm.getelementptr %[[dst_null]][1]
 // CHECK:           %[[dst_es:.*]] = llvm.ptrtoint %[[dst_gep]] : !llvm.ptr<i32> to i64
 // CHECK:           %[[dst_size:.*]] = llvm.mul %[[dst_dim]], %[[dst_es]]
+// CHECK:           %[[src_size:.*]] = llvm.mul %[[src_dim]], %[[dst_es]]
 // CHECK:           %[[new_buffer_raw:.*]] = llvm.call @malloc(%[[dst_size]])
 // CHECK:           %[[new_buffer:.*]] = llvm.bitcast %[[new_buffer_raw]] : !llvm.ptr<i8> to !llvm.ptr<i32>
 // CHECK:           %[[old_buffer_aligned:.*]] = llvm.extractvalue %[[descriptor]][1]
 // CHECK:           %[[volatile:.*]] = llvm.mlir.constant(false) : i1
 // CHECK-DAG:       %[[new_buffer_void:.*]] = llvm.bitcast %[[new_buffer]] : !llvm.ptr<i32> to !llvm.ptr<i8>
 // CHECK-DAG:       %[[old_buffer_void:.*]] = llvm.bitcast %[[old_buffer_aligned]] : !llvm.ptr<i32> to !llvm.ptr<i8>
-// CHECK:           "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[dst_size]], %[[volatile]])
+// CHECK:           "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[src_size]], %[[volatile]])
 // CHECK:           %[[old_buffer_unaligned:.*]] = llvm.extractvalue %[[descriptor]][0]
 // CHECK:           %[[old_buffer_unaligned_void:.*]] = llvm.bitcast %[[old_buffer_unaligned]] : !llvm.ptr<i32> to !llvm.ptr<i8>
 // CHECK:           llvm.call @free(%[[old_buffer_unaligned_void]])
@@ -391,6 +392,7 @@ func.func @realloc_static_alignment(%in: memref<2xf32>) -> memref<4xf32>{
 // CHECK:           %[[dst_gep:.*]] = llvm.getelementptr %[[dst_null]][1]
 // CHECK:           %[[dst_es:.*]] = llvm.ptrtoint %[[dst_gep]] : !llvm.ptr<f32> to i64
 // CHECK:           %[[dst_size:.*]] = llvm.mul %[[dst_dim]], %[[dst_es]]
+// CHECK:           %[[src_size:.*]] = llvm.mul %[[src_dim]], %[[dst_es]]
 // CHECK:           %[[alignment:.*]] = llvm.mlir.constant(8 : index) : i64
 // CHECK:           %[[adjust_dst_size:.*]] = llvm.add %[[dst_size]], %[[alignment]]
 // CHECK:           %[[new_buffer_raw:.*]] = llvm.call @malloc(%[[adjust_dst_size]])
@@ -406,7 +408,7 @@ func.func @realloc_static_alignment(%in: memref<2xf32>) -> memref<4xf32>{
 // CHECK:           %[[volatile:.*]] = llvm.mlir.constant(false) : i1
 // CHECK-DAG:       %[[new_buffer_void:.*]] = llvm.bitcast %[[new_buffer_aligned]] : !llvm.ptr<f32> to !llvm.ptr<i8>
 // CHECK-DAG:       %[[old_buffer_void:.*]] = llvm.bitcast %[[old_buffer_aligned]] : !llvm.ptr<f32> to !llvm.ptr<i8>
-// CHECK:           "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[dst_size]], %[[volatile]])
+// CHECK:           "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[src_size]], %[[volatile]])
 // CHECK:           %[[old_buffer_unaligned:.*]] = llvm.extractvalue %[[descriptor]][0]
 // CHECK:           %[[old_buffer_unaligned_void:.*]] = llvm.bitcast %[[old_buffer_unaligned]] : !llvm.ptr<f32> to !llvm.ptr<i8>
 // CHECK:           llvm.call @free(%[[old_buffer_unaligned_void]])

diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_push_back.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_push_back.mlir
index 035c90c8e4be4..ff57bfee527d2 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_push_back.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_push_back.mlir
@@ -1,6 +1,3 @@
-// FIXME: re-enable when sanitizer issue is resolved
-// UNSUPPORTED: asan
-//
 // RUN: mlir-opt %s --sparse-compiler=enable-runtime-library=false | \
 // RUN: mlir-cpu-runner \
 // RUN:  -e entry -entry-point-result=void  \
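
(Note: the removed "UNSUPPORTED: asan" marker had disabled this test under
AddressSanitizer; the out-of-bounds memcpy fixed by this patch is presumably
what the sanitizer was flagging, so the test can be re-enabled.)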


        

