[Mlir-commits] [mlir] [mlir][memref]-Add verification for MemRef::ViewOp bounds (PR #177778)
Amir Bishara
llvmlistbot at llvm.org
Sat Jan 24 07:46:50 PST 2026
https://github.com/amirBish updated https://github.com/llvm/llvm-project/pull/177778
>From cee697528aa7cd27be1df2e5e9356d93b836278a Mon Sep 17 00:00:00 2001
From: Amir Bishara <amir.bishara at mobileye.com>
Date: Thu, 18 Dec 2025 20:54:25 +0200
Subject: [PATCH] [mlir][memref]-Add verification for MemRef::ViewOp bounds
- Added static bounds checking for ViewOp. It verifies that the
view result size (in bytes) plus the byte shift does not exceed
the source buffer size when both memrefs have static shapes.
- Added new invalid lit tests and fixed the existing tests that failed
the new verification.
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 71 +++++++++++++++++--
.../XeGPUToXeVM/loadstore_matrix.mlir | 4 +-
mlir/test/Dialect/MemRef/canonicalize.mlir | 8 +--
mlir/test/Dialect/MemRef/invalid.mlir | 29 ++++++++
4 files changed, 100 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 1ca0cea0f6f2f..ae44744153152 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3740,17 +3740,73 @@ void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "view");
}
+// Verifies that `memrefType` has an identity layout map (or no layout),
+// emitting "unsupported map for <descr> <type>" through `emitError` otherwise.
+// `descr` names the checked type in the diagnostic, e.g. "base memref type".
+//
+// An identity layout describes a contiguous row-major buffer whose affine map
+// has no constant term, i.e. its offset is always 0. A separate
+// strides/offset query would therefore never fail or report a non-zero offset
+// here, so the layout check alone covers both properties named by this helper.
+static LogicalResult hasIdentityLayoutAndZeroOffset(
+    MemRefType memrefType, StringRef descr,
+    llvm::function_ref<InFlightDiagnostic()> emitError) {
+  if (!memrefType.getLayout().isIdentity())
+    return emitError() << "unsupported map for " << descr << " " << memrefType;
+  return success();
+}
+
+// Verifies that a view operation's result, placed at its byte shift, fits
+// within the source buffer.
+//
+// `baseType` is the source buffer. Its element count is used as its size in
+// bytes, which relies on the view source being an i8 memref (ViewOp's operand
+// constraint). `viewType` is the result type and `shiftInBytes` is the
+// byte-shift operand.
+//
+// The check is only performed when both memrefs have static shapes and the
+// view element type is a byte-aligned integer or float. A byte shift that is
+// not a compile-time constant is treated as 0, so a view that is larger than
+// the whole buffer is still diagnosed. A negative constant shift is rejected.
+// The comparison is arranged so that no intermediate value can overflow
+// int64_t.
+static LogicalResult
+checkStaticViewBounds(MemRefType baseType, MemRefType viewType,
+                      Value shiftInBytes,
+                      llvm::function_ref<InFlightDiagnostic()> emitError) {
+  // Skip if either the base or view has dynamic shape.
+  if (!baseType.hasStaticShape() || !viewType.hasStaticShape())
+    return success();
+
+  // Skip if the view element type is not int or float: its size in bytes is
+  // not derivable from the type alone.
+  Type viewElementType = viewType.getElementType();
+  if (!viewElementType.isIntOrFloat())
+    return success();
+
+  // Skip non byte-aligned view element types.
+  int64_t viewElementBitWidth = viewElementType.getIntOrFloatBitWidth();
+  if (viewElementBitWidth % 8 != 0)
+    return success();
+  int64_t viewElementSizeInBytes = viewElementBitWidth / 8;
+
+  // The shift may be a non-constant value; assume 0 in that case so that the
+  // view size alone is still checked against the buffer size.
+  int64_t shift =
+      getConstantIntValue(getAsOpFoldResult(shiftInBytes)).value_or(0);
+  // A negative shift would start the view before the buffer and would also
+  // mask an oversized view in the comparison below.
+  if (shift < 0)
+    return emitError() << "expected non-negative byte shift, but got "
+                       << shift;
+
+  // Bytes left in the base buffer after the shift. Both operands are
+  // non-negative, so the subtraction cannot overflow. The result is negative
+  // when the shift points past the end of the buffer.
+  int64_t availableBytes = baseType.getNumElements() - shift;
+  int64_t viewNumElements = viewType.getNumElements();
+
+  // Equivalent to `viewNumElements * viewElementSizeInBytes > availableBytes`
+  // for non-negative values, but written with a division so that the product
+  // cannot overflow. Zero-width elements (e.g. i0) occupy no bytes.
+  bool outOfBounds =
+      availableBytes < 0 ||
+      (viewElementSizeInBytes != 0 &&
+       viewNumElements > availableBytes / viewElementSizeInBytes);
+  if (outOfBounds)
+    return emitError()
+           << "view total elements in bytes with shift is greater than base "
+              "total elements in bytes for base memref type "
+           << baseType << " and view memref type " << viewType;
+  return success();
+}
+
LogicalResult ViewOp::verify() {
auto baseType = llvm::cast<MemRefType>(getOperand(0).getType());
auto viewType = getType();
- // The base memref should have identity layout map (or none).
- if (!baseType.getLayout().isIdentity())
- return emitError("unsupported map for base memref type ") << baseType;
+ if (failed(hasIdentityLayoutAndZeroOffset(baseType, "base memref type",
+ [&]() { return emitError(); })))
+ return failure();
- // The result memref should have identity layout map (or none).
- if (!viewType.getLayout().isIdentity())
- return emitError("unsupported map for result memref type ") << viewType;
+ if (failed(hasIdentityLayoutAndZeroOffset(viewType, "result memref type",
+ [&]() { return emitError(); })))
+ return failure();
// The base memref and the view memref should be in the same memory space.
if (baseType.getMemorySpace() != viewType.getMemorySpace())
@@ -3762,6 +3818,9 @@ LogicalResult ViewOp::verify() {
if (failed(verifyDynamicDimensionCount(getOperation(), viewType, getSizes())))
return failure();
+ if (failed(checkStaticViewBounds(baseType, viewType, getByteShift(),
+ [&]() { return emitError(); })))
+ return failure();
return success();
}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
index fa683175693be..a94172431b474 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
@@ -35,9 +35,9 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
}
//CHECK-LABEL: load_store_matrix_plain_2d_input
- gpu.func @load_store_matrix_plain_2d_input(%arg0: memref<1024xi8, 3>) -> f32 {
+ gpu.func @load_store_matrix_plain_2d_input(%arg0: memref<8192xi8, 3>) -> f32 {
%c0 = arith.constant 0 : index
- %view = memref.view %arg0[%c0][]: memref<1024xi8, 3> to memref<64x32xf32, 3>
+ %view = memref.view %arg0[%c0][]: memref<8192xi8, 3> to memref<64x32xf32, 3>
%subview = memref.subview %view[32, 0] [32, 32] [1, 1] : memref<64x32xf32, 3> to memref<32x32xf32, strided<[32, 1], offset: 1024>, 3>
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 17afd9a15b60d..85125f051b26e 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1574,13 +1574,13 @@ func.func @fold_view_same_source_result_types(%0: memref<128xi8>) -> memref<128x
// -----
+// The source holds 129 bytes so that the 128-byte view at byte shift 1 stays
+// within bounds under the ViewOp bounds verifier. As the CHECK lines show, the
+// view with a non-zero byte shift must be kept rather than folded away.
// CHECK-LABEL: func @non_fold_view_non_zero_offset
-// CHECK-SAME: (%[[ARG:.*]]: memref<128xi8>)
-func.func @non_fold_view_non_zero_offset(%0: memref<128xi8>) -> memref<128xi8> {
+// CHECK-SAME: (%[[ARG:.*]]: memref<129xi8>)
+func.func @non_fold_view_non_zero_offset(%0: memref<129xi8>) -> memref<128xi8> {
%c1 = arith.constant 1 : index
// CHECK: %[[C1:.*]] = arith.constant 1 : index
- // CHECK: %[[RES:.*]] = memref.view %[[ARG]][%[[C1]]][] : memref<128xi8> to memref<128xi8>
+ // CHECK: %[[RES:.*]] = memref.view %[[ARG]][%[[C1]]][] : memref<129xi8> to memref<128xi8>
// CHECK: return %[[RES]]
- %res = memref.view %0[%c1][] : memref<128xi8> to memref<128xi8>
+ %res = memref.view %0[%c1][] : memref<129xi8> to memref<128xi8>
return %res : memref<128xi8>
}
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 46e010fc878fe..8a661fc5626fa 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -632,6 +632,35 @@ func.func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) {
// -----
+// A 32xf32 view needs 128 bytes but the i8 source only provides 64.
+func.func @invalid_view_out_of_bounds() {
+  %0 = memref.alloc() : memref<64xi8>
+  %c0 = arith.constant 0 : index
+  // expected-error@+1 {{view total elements in bytes with shift is greater than base total elements in bytes}}
+  %1 = memref.view %0[%c0][] : memref<64xi8> to memref<32xf32>
+  return
+}
+
+// -----
+
+// A non-constant byte shift is treated as 0 by the bounds check; the 128-byte
+// view still does not fit in the 64-byte source.
+func.func @invalid_view_out_of_bounds_with_dynamic_shift_in_bytes(%shift: index) {
+  %0 = memref.alloc() : memref<64xi8>
+  // expected-error@+1 {{view total elements in bytes with shift is greater than base total elements in bytes}}
+  %1 = memref.view %0[%shift][] : memref<64xi8> to memref<32xf32>
+  return
+}
+
+// -----
+
+// The 128-byte view would exactly fill the 128-byte source at shift 0; the
+// constant 8-byte shift pushes its end past the buffer.
+func.func @invalid_view_out_of_bounds_with_shift() {
+  %0 = memref.alloc() : memref<128xi8>
+  %c8 = arith.constant 8 : index
+  // expected-error@+1 {{view total elements in bytes with shift is greater than base total elements in bytes}}
+  %1 = memref.view %0[%c8][] : memref<128xi8> to memref<32xf32>
+  return
+}
+
+// -----
+
func.func @invalid_subview(%input: memref<4x1024xf32>) -> memref<2x256xf32, strided<[1024, 1], offset: 2304>> {
// expected-error at +1 {{expected offsets to be non-negative, but got -1}}
%0 = memref.subview %input[-1, 256] [2, 256] [1, 1] : memref<4x1024xf32> to memref<2x256xf32, strided<[1024, 1], offset: 2304>>
More information about the Mlir-commits
mailing list