[Mlir-commits] [mlir] [mlir][memref]-Add verification for MemRef::ViewOp bounds (PR #177778)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Jan 24 07:45:49 PST 2026
github-actions[bot] wrote:
<!--LLVM CODE FORMAT COMMENT: {clang-format}-->
:warning: C/C++ code formatter, clang-format found issues in your code. :warning:
<details>
<summary>
You can test this locally with the following command:
</summary>
``````````bash
git-clang-format --diff origin/main HEAD --extensions cpp -- mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --diff_from_common_commit
``````````
:warning:
The reproduction instructions above might return results for more than one PR
in a stack if you are using a stacked PR workflow. You can limit the results by
changing `origin/main` to the base branch/commit you want to compare against.
:warning:
</details>
<details>
<summary>
View the diff from clang-format here.
</summary>
``````````diff
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 17445a508..ae4474415 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3740,26 +3740,31 @@ void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "view");
}
-static LogicalResult hasIdentityLayoutAndZeroOffset(MemRefType memrefType,
- StringRef descr,
- llvm::function_ref<InFlightDiagnostic()> emitError) {
+static LogicalResult hasIdentityLayoutAndZeroOffset(
+ MemRefType memrefType, StringRef descr,
+ llvm::function_ref<InFlightDiagnostic()> emitError) {
if (!memrefType.getLayout().isIdentity())
return emitError() << "unsupported map for " << descr << " " << memrefType;
[[maybe_unused]] SmallVector<int64_t> strides;
int64_t offset;
if (failed(memrefType.getStridesAndOffset(strides, offset)))
- return emitError() << "failed to get strides and offset for " << descr << " " << memrefType;
+ return emitError() << "failed to get strides and offset for " << descr
+ << " " << memrefType;
if (offset != 0)
- return emitError() << "unsupported non-zero offset for " << descr << " " << memrefType;
+ return emitError() << "unsupported non-zero offset for " << descr << " "
+ << memrefType;
return success();
}
// Verifies that a view operation's result, plus the byte shift, fits within
// the source memref bounds. The check is only performed when both the base
-// and view memrefs have static shapes and the view element type is byte-aligned.
-static LogicalResult checkStaticViewBounds(MemRefType baseType,
- MemRefType viewType, Value shiftInBytes, llvm::function_ref<InFlightDiagnostic()> emitError){
+// and view memrefs have static shapes and the view element type is
+// byte-aligned.
+static LogicalResult
+checkStaticViewBounds(MemRefType baseType, MemRefType viewType,
+ Value shiftInBytes,
+ llvm::function_ref<InFlightDiagnostic()> emitError) {
// Skip if either the base or view has dynamic shape.
if (!baseType.hasStaticShape() || !viewType.hasStaticShape())
return success();
@@ -3769,19 +3774,25 @@ static LogicalResult checkStaticViewBounds(MemRefType baseType,
return success();
// Skip non byte-aligned view element types.
- int64_t viewElementBitWidth = viewType.getElementType().getIntOrFloatBitWidth();
+ int64_t viewElementBitWidth =
+ viewType.getElementType().getIntOrFloatBitWidth();
if (viewElementBitWidth % 8 != 0)
return success();
int64_t baseTotalElementsInBytes = baseType.getNumElements();
int64_t viewTotalElements = viewType.getNumElements();
- int64_t viewTotalElementsInBytes = viewTotalElements * (viewElementBitWidth / 8);
+ int64_t viewTotalElementsInBytes =
+ viewTotalElements * (viewElementBitWidth / 8);
// Shift in bytes may be a non static value, still we will
// check the sizes bounds.
- int64_t shiftInBytesInt = getConstantIntValue(getAsOpFoldResult(shiftInBytes)).value_or(0);
+ int64_t shiftInBytesInt =
+ getConstantIntValue(getAsOpFoldResult(shiftInBytes)).value_or(0);
if (viewTotalElementsInBytes + shiftInBytesInt > baseTotalElementsInBytes)
- return emitError()<< "view total elements in bytes with shift is greater than base total elements in bytes for base memref type " << baseType << " and view memref type " << viewType;
+ return emitError()
+ << "view total elements in bytes with shift is greater than base "
+ "total elements in bytes for base memref type "
+ << baseType << " and view memref type " << viewType;
return success();
}
@@ -3790,7 +3801,7 @@ LogicalResult ViewOp::verify() {
auto viewType = getType();
if (failed(hasIdentityLayoutAndZeroOffset(baseType, "base memref type",
- [&](){ return emitError(); })))
+ [&]() { return emitError(); })))
return failure();
if (failed(hasIdentityLayoutAndZeroOffset(viewType, "result memref type",
@@ -3807,7 +3818,8 @@ LogicalResult ViewOp::verify() {
if (failed(verifyDynamicDimensionCount(getOperation(), viewType, getSizes())))
return failure();
- if (failed(checkStaticViewBounds(baseType, viewType, getByteShift(), [&](){ return emitError(); })))
+ if (failed(checkStaticViewBounds(baseType, viewType, getByteShift(),
+ [&]() { return emitError(); })))
return failure();
return success();
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/177778
More information about the Mlir-commits mailing list