[Mlir-commits] [mlir] [mlir][MemRef] Add runtime bounds checking (PR #75817)
Matthias Springer
llvmlistbot at llvm.org
Tue Dec 19 22:19:44 PST 2023
================
@@ -133,6 +144,179 @@ struct CastOpInterface
}
};
+template <typename LoadStoreOp>
+struct LoadStoreOpInterface
+    : public RuntimeVerifiableOpInterface::ExternalModel<
+          LoadStoreOpInterface<LoadStoreOp>, LoadStoreOp> {
+  // Emits a single cf.assert checking that, for every dimension `d` of the
+  // accessed memref, the access index satisfies 0 <= index < dim(d).
+  // Nothing is emitted for rank-0 memrefs, which take no indices.
+  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
+                                   Location loc) const {
+    auto accessOp = cast<LoadStoreOp>(op);
+    auto buffer = accessOp.getMemref();
+    const int64_t rank = buffer.getType().getRank();
+    if (rank == 0)
+      return;
+
+    auto indices = accessOp.getIndices();
+    auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+
+    // Running conjunction of the per-dimension in-bounds conditions.
+    Value allInBounds;
+    for (int64_t dim = 0; dim < rank; ++dim) {
+      Value idx = indices[dim];
+      Value extent = builder.createOrFold<memref::DimOp>(loc, buffer, dim);
+      Value notNegative = builder.createOrFold<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::sge, idx, zero);
+      Value belowExtent = builder.createOrFold<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::slt, idx, extent);
+      Value dimInBounds =
+          builder.createOrFold<arith::AndIOp>(loc, notNegative, belowExtent);
+      if (dim == 0)
+        allInBounds = dimInBounds;
+      else
+        allInBounds =
+            builder.createOrFold<arith::AndIOp>(loc, allInBounds, dimInBounds);
+    }
+
+    builder.create<cf::AssertOp>(
+        loc, allInBounds, generateErrorMessage(op, "out-of-bounds access"));
+  }
+};
+
+// Materializes, as an index-typed Value, the linear index of `indices` within
+// the strided layout described by `offset` and `strides`. The affine.apply is
+// built through the composing/folding helper, so inputs that are known
+// constants may fold away before the result Value is created.
+Value computeLinearIndex(OpBuilder &builder, Location loc, OpFoldResult offset,
+                         ArrayRef<OpFoldResult> strides,
+                         ArrayRef<OpFoldResult> indices) {
+  // Symbolic linearization expression together with the operands it uses.
+  auto linearization = computeLinearIndex(offset, strides, indices);
+  OpFoldResult linearIndex = affine::makeComposedFoldedAffineApply(
+      builder, loc, linearization.first, linearization.second);
+  return getValueOrCreateConstantIndexOp(builder, loc, linearIndex);
+}
+
+// Returns two Values representing the bounds of the provided strided layout
+// metadata. The bounds are returned as a half open interval -- [low, high):
+//   low  = linear index of element (0, ..., 0)
+//   high = linear index of element (sizes[0] - 1, ..., sizes[n-1] - 1) + 1
+//
+// NOTE: This assumes non-negative strides and non-zero sizes. For layouts with
+// negative strides or empty dimensions the interval is not a tight description
+// of the elements that can actually be accessed.
+std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
+                                            OpFoldResult offset,
+                                            ArrayRef<OpFoldResult> strides,
+                                            ArrayRef<OpFoldResult> sizes) {
+  AffineExpr d0 = builder.getAffineDimExpr(0);
+
+  // Lower bound: linear index of the first element (all indices zero).
+  auto zeros = SmallVector<int64_t>(sizes.size(), 0);
+  auto firstIndices = getAsIndexOpFoldResult(builder.getContext(), zeros);
+  Value lowerBound =
+      computeLinearIndex(builder, loc, offset, strides, firstIndices);
+
+  // Upper bound: one past the linear index of the last element. Plugging the
+  // sizes themselves in as indices (offset + sum(size_i * stride_i)) would
+  // overshoot for rank > 1, e.g. 110 instead of 100 for a contiguous 10x10
+  // memref. That would accept out-of-bounds layouts and reject valid ones.
+  SmallVector<OpFoldResult> lastIndices;
+  lastIndices.reserve(sizes.size());
+  for (OpFoldResult size : sizes)
+    lastIndices.push_back(
+        affine::makeComposedFoldedAffineApply(builder, loc, d0 - 1, size));
+  // Adding 1 to the offset is the same as adding 1 to the final linear index,
+  // since the linearization is offset + sum(index_i * stride_i).
+  OpFoldResult offsetPlusOne =
+      affine::makeComposedFoldedAffineApply(builder, loc, d0 + 1, offset);
+  Value upperBound =
+      computeLinearIndex(builder, loc, offsetPlusOne, strides, lastIndices);
+
+  return {lowerBound, upperBound};
+}
+
+// Returns the half-open interval [low, high) of linear indices spanned by
+// `memref`. The interval is derived from the runtime strided metadata of
+// `memref` itself, i.e. from its offset, sizes and strides as reported by
+// memref.extract_strided_metadata.
+std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
+                                            TypedValue<BaseMemRefType> memref) {
+  auto metadataOp = builder.create<ExtractStridedMetadataOp>(loc, memref);
+  // Mixed static/dynamic views of the layout; the "constified" accessors are
+  // used so that statically known entries can take part in folding.
+  auto layoutOffset = metadataOp.getConstifiedMixedOffset();
+  auto layoutStrides = metadataOp.getConstifiedMixedStrides();
+  auto layoutSizes = metadataOp.getConstifiedMixedSizes();
+  return computeLinearBounds(builder, loc, layoutOffset, layoutStrides,
+                             layoutSizes);
+}
+
+struct ReinterpretCastOpInterface
+    : public RuntimeVerifiableOpInterface::ExternalModel<
+          ReinterpretCastOpInterface, ReinterpretCastOp> {
+  // Emits a cf.assert checking that the linear index range covered by the
+  // result of a memref.reinterpret_cast lies inside the linear index range
+  // covered by its source memref. The source range comes from the source's
+  // own strided metadata (the source view), not from the underlying
+  // allocation.
+  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
+                                   Location loc) const {
+    auto castOp = cast<ReinterpretCastOp>(op);
+
+    // Layout requested by the cast (static/dynamic mix).
+    OpFoldResult resultOffset = castOp.getMixedOffsets().front();
+    auto resultStrides = castOp.getMixedStrides();
+    auto resultSizes = castOp.getMixedSizes();
+
+    // [srcLow, srcHigh): range covered by the source memref.
+    auto [srcLow, srcHigh] =
+        computeLinearBounds(builder, loc, castOp.getSource());
+    // [dstLow, dstHigh): range covered by the cast result.
+    auto [dstLow, dstHigh] = computeLinearBounds(
+        builder, loc, resultOffset, resultStrides, resultSizes);
+
+    Value startsInside = builder.createOrFold<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::sge, dstLow, srcLow);
+    Value endsInside = builder.createOrFold<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::sle, dstHigh, srcHigh);
+    Value castInBounds =
+        builder.createOrFold<arith::AndIOp>(loc, startsInside, endsInside);
+
+    builder.create<cf::AssertOp>(
+        loc, castInBounds,
+        generateErrorMessage(
+            op,
+            "result of reinterpret_cast is out-of-bounds of the base memref"));
+  }
+};
+
+struct SubViewOpInterface
----------------
matthias-springer wrote:
I think this implementation does not reject IR like the following:
```
memref.subview %m[%c0, %c0][%c20, %c2][%c1, %c1]
: memref<?x?xf32> to memref<?x?xf32>
// source runtime dims: 10x10
```
Because `20x2 < 10x10` (40 vs. 100 elements), the subview is considered in-bounds even though the first dim is out-of-bounds. Is that correct?
If so, can we verify each dim separately? That may also simplify the index computations here.
https://github.com/llvm/llvm-project/pull/75817
More information about the Mlir-commits
mailing list