[Mlir-commits] [mlir] [mlir][memref]-Add verification for MemRef::ViewOp bounds (PR #177778)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Jan 24 07:45:49 PST 2026


github-actions[bot] wrote:

<!--LLVM CODE FORMAT COMMENT: {clang-format}-->


:warning: The C/C++ code formatter (clang-format) found issues in your code. :warning:

<details>
<summary>
You can test this locally with the following command:
</summary>

``````````bash
git-clang-format --diff origin/main HEAD --extensions cpp -- mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --diff_from_common_commit
``````````

:warning:
The reproduction instructions above might return results for more than one PR
in a stack if you are using a stacked-PR workflow. You can limit the results by
changing `origin/main` to the base branch or commit you want to compare against.
:warning:

</details>

<details>
<summary>
View the diff from clang-format here.
</summary>

``````````diff
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 17445a508..ae4474415 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3740,26 +3740,31 @@ void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
   setNameFn(getResult(), "view");
 }
 
-static LogicalResult hasIdentityLayoutAndZeroOffset(MemRefType memrefType,
-                                                     StringRef descr,
-                                                     llvm::function_ref<InFlightDiagnostic()> emitError) {
+static LogicalResult hasIdentityLayoutAndZeroOffset(
+    MemRefType memrefType, StringRef descr,
+    llvm::function_ref<InFlightDiagnostic()> emitError) {
   if (!memrefType.getLayout().isIdentity())
     return emitError() << "unsupported map for " << descr << " " << memrefType;
 
   [[maybe_unused]] SmallVector<int64_t> strides;
   int64_t offset;
   if (failed(memrefType.getStridesAndOffset(strides, offset)))
-    return emitError() << "failed to get strides and offset for " << descr << " " << memrefType;
+    return emitError() << "failed to get strides and offset for " << descr
+                       << " " << memrefType;
   if (offset != 0)
-    return emitError() << "unsupported non-zero offset for " << descr << " " << memrefType;
+    return emitError() << "unsupported non-zero offset for " << descr << " "
+                       << memrefType;
   return success();
 }
 
 // Verifies that a view operation's result, plus the byte shift, fits within
 // the source memref bounds. The check is only performed when both the base
-// and view memrefs have static shapes and the view element type is byte-aligned.
-static LogicalResult checkStaticViewBounds(MemRefType baseType,
-  MemRefType viewType, Value shiftInBytes, llvm::function_ref<InFlightDiagnostic()> emitError){
+// and view memrefs have static shapes and the view element type is
+// byte-aligned.
+static LogicalResult
+checkStaticViewBounds(MemRefType baseType, MemRefType viewType,
+                      Value shiftInBytes,
+                      llvm::function_ref<InFlightDiagnostic()> emitError) {
   // Skip if either the base or view has dynamic shape.
   if (!baseType.hasStaticShape() || !viewType.hasStaticShape())
     return success();
@@ -3769,19 +3774,25 @@ static LogicalResult checkStaticViewBounds(MemRefType baseType,
     return success();
 
   // Skip non byte-aligned view element types.
-  int64_t viewElementBitWidth = viewType.getElementType().getIntOrFloatBitWidth();
+  int64_t viewElementBitWidth =
+      viewType.getElementType().getIntOrFloatBitWidth();
   if (viewElementBitWidth % 8 != 0)
     return success();
 
   int64_t baseTotalElementsInBytes = baseType.getNumElements();
   int64_t viewTotalElements = viewType.getNumElements();
-  int64_t viewTotalElementsInBytes = viewTotalElements * (viewElementBitWidth / 8);
+  int64_t viewTotalElementsInBytes =
+      viewTotalElements * (viewElementBitWidth / 8);
   // Shift in bytes may be a non static value, still we will
   // check the sizes bounds.
-  int64_t shiftInBytesInt = getConstantIntValue(getAsOpFoldResult(shiftInBytes)).value_or(0);
+  int64_t shiftInBytesInt =
+      getConstantIntValue(getAsOpFoldResult(shiftInBytes)).value_or(0);
 
   if (viewTotalElementsInBytes + shiftInBytesInt > baseTotalElementsInBytes)
-    return emitError()<< "view total elements in bytes with shift is greater than base total elements in bytes for base memref type " << baseType << " and view memref type " << viewType;
+    return emitError()
+           << "view total elements in bytes with shift is greater than base "
+              "total elements in bytes for base memref type "
+           << baseType << " and view memref type " << viewType;
   return success();
 }
 
@@ -3790,7 +3801,7 @@ LogicalResult ViewOp::verify() {
   auto viewType = getType();
 
   if (failed(hasIdentityLayoutAndZeroOffset(baseType, "base memref type",
-                                            [&](){ return emitError(); })))
+                                            [&]() { return emitError(); })))
     return failure();
 
   if (failed(hasIdentityLayoutAndZeroOffset(viewType, "result memref type",
@@ -3807,7 +3818,8 @@ LogicalResult ViewOp::verify() {
   if (failed(verifyDynamicDimensionCount(getOperation(), viewType, getSizes())))
     return failure();
 
-  if (failed(checkStaticViewBounds(baseType, viewType, getByteShift(), [&](){ return emitError(); })))
+  if (failed(checkStaticViewBounds(baseType, viewType, getByteShift(),
+                                   [&]() { return emitError(); })))
     return failure();
   return success();
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/177778


More information about the Mlir-commits mailing list