[Mlir-commits] [mlir] [mlir][memref]-Add verification for MemRef::ViewOp bounds (PR #177778)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Jan 24 07:44:42 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-memref

@llvm/pr-subscribers-mlir

Author: Amir Bishara (amirBish)

<details>
<summary>Changes</summary>

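- Added static bounds checking for `ViewOp`: the verifier now checks that the view result size (in bytes) plus the byte shift does not exceed the source buffer size when both memrefs have static shapes. A byte shift that is not a constant is treated as 0, and the check is skipped for view element types that are not byte-aligned integers or floats.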

- Added new invalid lit tests and updated the existing tests (`loadstore_matrix.mlir`, `canonicalize.mlir`) that fail under the new verification.

---
Full diff: https://github.com/llvm/llvm-project/pull/177778.diff


4 Files Affected:

- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+53-6) 
- (modified) mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir (+2-2) 
- (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+4-4) 
- (modified) mlir/test/Dialect/MemRef/invalid.mlir (+29) 


``````````diff
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 1ca0cea0f6f2f..17445a50890bd 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3740,17 +3740,62 @@ void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
   setNameFn(getResult(), "view");
 }
 
+static LogicalResult hasIdentityLayoutAndZeroOffset(MemRefType memrefType,
+                                                     StringRef descr,
+                                                     llvm::function_ref<InFlightDiagnostic()> emitError) {
+  if (!memrefType.getLayout().isIdentity())
+    return emitError() << "unsupported map for " << descr << " " << memrefType;
+
+  [[maybe_unused]] SmallVector<int64_t> strides;
+  int64_t offset;
+  if (failed(memrefType.getStridesAndOffset(strides, offset)))
+    return emitError() << "failed to get strides and offset for " << descr << " " << memrefType;
+  if (offset != 0)
+    return emitError() << "unsupported non-zero offset for " << descr << " " << memrefType;
+  return success();
+}
+
+// Verifies that a view operation's result, plus the byte shift, fits within
+// the source memref bounds. The check is only performed when both the base
+// and view memrefs have static shapes and the view element type is byte-aligned.
+static LogicalResult checkStaticViewBounds(MemRefType baseType,
+  MemRefType viewType, Value shiftInBytes, llvm::function_ref<InFlightDiagnostic()> emitError){
+  // Skip if either the base or view has dynamic shape.
+  if (!baseType.hasStaticShape() || !viewType.hasStaticShape())
+    return success();
+
+  // Skip if the view element type is not int or float.
+  if (!viewType.getElementType().isIntOrFloat())
+    return success();
+
+  // Skip non byte-aligned view element types.
+  int64_t viewElementBitWidth = viewType.getElementType().getIntOrFloatBitWidth();
+  if (viewElementBitWidth % 8 != 0)
+    return success();
+
+  int64_t baseTotalElementsInBytes = baseType.getNumElements();
+  int64_t viewTotalElements = viewType.getNumElements();
+  int64_t viewTotalElementsInBytes = viewTotalElements * (viewElementBitWidth / 8);
+  // Shift in bytes may be a non static value, still we will
+  // check the sizes bounds.
+  int64_t shiftInBytesInt = getConstantIntValue(getAsOpFoldResult(shiftInBytes)).value_or(0);
+
+  if (viewTotalElementsInBytes + shiftInBytesInt > baseTotalElementsInBytes)
+    return emitError()<< "view total elements in bytes with shift is greater than base total elements in bytes for base memref type " << baseType << " and view memref type " << viewType;
+  return success();
+}
+
 LogicalResult ViewOp::verify() {
   auto baseType = llvm::cast<MemRefType>(getOperand(0).getType());
   auto viewType = getType();
 
-  // The base memref should have identity layout map (or none).
-  if (!baseType.getLayout().isIdentity())
-    return emitError("unsupported map for base memref type ") << baseType;
+  if (failed(hasIdentityLayoutAndZeroOffset(baseType, "base memref type",
+                                            [&](){ return emitError(); })))
+    return failure();
 
-  // The result memref should have identity layout map (or none).
-  if (!viewType.getLayout().isIdentity())
-    return emitError("unsupported map for result memref type ") << viewType;
+  if (failed(hasIdentityLayoutAndZeroOffset(viewType, "result memref type",
+                                            [&]() { return emitError(); })))
+    return failure();
 
   // The base memref and the view memref should be in the same memory space.
   if (baseType.getMemorySpace() != viewType.getMemorySpace())
@@ -3762,6 +3807,8 @@ LogicalResult ViewOp::verify() {
   if (failed(verifyDynamicDimensionCount(getOperation(), viewType, getSizes())))
     return failure();
 
+  if (failed(checkStaticViewBounds(baseType, viewType, getByteShift(), [&](){ return emitError(); })))
+    return failure();
   return success();
 }
 
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
index fa683175693be..a94172431b474 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
@@ -35,9 +35,9 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
   }
 
   //CHECK-LABEL: load_store_matrix_plain_2d_input
-  gpu.func @load_store_matrix_plain_2d_input(%arg0: memref<1024xi8, 3>) -> f32 {
+  gpu.func @load_store_matrix_plain_2d_input(%arg0: memref<8192xi8, 3>) -> f32 {
     %c0 = arith.constant 0 : index
-    %view = memref.view %arg0[%c0][]: memref<1024xi8, 3> to memref<64x32xf32, 3>
+    %view = memref.view %arg0[%c0][]: memref<8192xi8, 3> to memref<64x32xf32, 3>
 
     %subview = memref.subview %view[32, 0] [32, 32] [1, 1] : memref<64x32xf32, 3> to memref<32x32xf32, strided<[32, 1], offset: 1024>, 3>
 
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 17afd9a15b60d..85125f051b26e 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1574,13 +1574,13 @@ func.func @fold_view_same_source_result_types(%0: memref<128xi8>) -> memref<128x
 // -----
 
 // CHECK-LABEL: func @non_fold_view_non_zero_offset
-//  CHECK-SAME:   (%[[ARG:.*]]: memref<128xi8>)
-func.func @non_fold_view_non_zero_offset(%0: memref<128xi8>) -> memref<128xi8> {
+//  CHECK-SAME:   (%[[ARG:.*]]: memref<129xi8>)
+func.func @non_fold_view_non_zero_offset(%0: memref<129xi8>) -> memref<128xi8> {
   %c1 = arith.constant 1 : index
   // CHECK: %[[C1:.*]] = arith.constant 1 : index
-  // CHECK: %[[RES:.*]] = memref.view %[[ARG]][%[[C1]]][] : memref<128xi8> to memref<128xi8>
+  // CHECK: %[[RES:.*]] = memref.view %[[ARG]][%[[C1]]][] : memref<129xi8> to memref<128xi8>
   // CHECK: return %[[RES]]
-  %res = memref.view %0[%c1][] : memref<128xi8> to memref<128xi8>
+  %res = memref.view %0[%c1][] : memref<129xi8> to memref<128xi8>
   return %res : memref<128xi8>
 }
 
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 46e010fc878fe..8a661fc5626fa 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -632,6 +632,35 @@ func.func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) {
 
 // -----
 
+func.func @invalid_view_out_of_bounds() {
+  %0 = memref.alloc() : memref<64xi8>
+  %c0 = arith.constant 0 : index
+  // expected-error@+1 {{view total elements in bytes with shift is greater than base total elements in bytes}}
+  %1 = memref.view %0[%c0][] : memref<64xi8> to memref<32xf32>
+  return
+}
+
+// -----
+
+func.func @invalid_view_out_of_bounds_with_dynamic_shift_in_bytes(%shift: index) {
+  %0 = memref.alloc() : memref<64xi8>
+  // expected-error@+1 {{view total elements in bytes with shift is greater than base total elements in bytes}}
+  %1 = memref.view %0[%shift][] : memref<64xi8> to memref<32xf32>
+  return
+}
+
+// -----
+
+func.func @invalid_view_out_of_bounds_with_shift() {
+  %0 = memref.alloc() : memref<128xi8>
+  %c8 = arith.constant 8 : index
+  // expected-error@+1 {{view total elements in bytes with shift is greater than base total elements in bytes}}
+  %1 = memref.view %0[%c8][] : memref<128xi8> to memref<32xf32>
+  return
+}
+
+// -----
+
 func.func @invalid_subview(%input: memref<4x1024xf32>) -> memref<2x256xf32, strided<[1024, 1], offset: 2304>> {
  // expected-error@+1 {{expected offsets to be non-negative, but got -1}}
   %0 = memref.subview %input[-1, 256] [2, 256] [1, 1] : memref<4x1024xf32> to memref<2x256xf32, strided<[1024, 1], offset: 2304>>

``````````

</details>


https://github.com/llvm/llvm-project/pull/177778


More information about the Mlir-commits mailing list