[Mlir-commits] [mlir] [mlir][MemRef] Add runtime bounds checking (PR #75817)

Ryan Holt llvmlistbot at llvm.org
Thu Dec 21 10:22:41 PST 2023


================
@@ -133,6 +144,179 @@ struct CastOpInterface
   }
 };
 
+template <typename LoadStoreOp>
+struct LoadStoreOpInterface
+    : public RuntimeVerifiableOpInterface::ExternalModel<
+          LoadStoreOpInterface<LoadStoreOp>, LoadStoreOp> {
+  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
+                                   Location loc) const {
+    auto loadStoreOp = cast<LoadStoreOp>(op);
+
+    // Verify that the indices on the load/store are in-bounds of the memref's
+    // index space.
+
+    auto memref = loadStoreOp.getMemref();
+    auto rank = memref.getType().getRank();
+    if (rank == 0) {
+      return;
+    }
+    auto indices = loadStoreOp.getIndices();
+
+    auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+    Value assertCond;
+    for (auto i : llvm::seq<int64_t>(0, rank)) {
+      auto index = indices[i];
+
+      auto dimOp = builder.createOrFold<memref::DimOp>(loc, memref, i);
+
+      auto geLow = builder.createOrFold<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::sge, index, zero);
+      auto ltHigh = builder.createOrFold<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::slt, index, dimOp);
+      auto andOp = builder.createOrFold<arith::AndIOp>(loc, geLow, ltHigh);
+
+      assertCond =
+          i > 0 ? builder.createOrFold<arith::AndIOp>(loc, assertCond, andOp)
+                : andOp;
+    }
+    builder.create<cf::AssertOp>(
+        loc, assertCond, generateErrorMessage(op, "out-of-bounds access"));
+  }
+};
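
For illustration, here is a minimal sketch of the IR this interface should produce for a 2-D load on a fully dynamic memref. The function and SSA names are invented for the example, createOrFold may fold some of these ops, and the real assert message is built by generateErrorMessage:

    func.func @verified_load(%m: memref<?x?xf32>, %i: index, %j: index) -> f32 {
      %c0 = arith.constant 0 : index
      %c1 = arith.constant 1 : index
      // 0 <= %i < dim(%m, 0)
      %dim0 = memref.dim %m, %c0 : memref<?x?xf32>
      %ge0 = arith.cmpi sge, %i, %c0 : index
      %lt0 = arith.cmpi slt, %i, %dim0 : index
      %ok0 = arith.andi %ge0, %lt0 : i1
      // 0 <= %j < dim(%m, 1)
      %dim1 = memref.dim %m, %c1 : memref<?x?xf32>
      %ge1 = arith.cmpi sge, %j, %c0 : index
      %lt1 = arith.cmpi slt, %j, %dim1 : index
      %ok1 = arith.andi %ge1, %lt1 : i1
      // The conjunction of all per-dimension checks guards the access.
      %cond = arith.andi %ok0, %ok1 : i1
      cf.assert %cond, "out-of-bounds access"
      %v = memref.load %m[%i, %j] : memref<?x?xf32>
      return %v : f32
    }
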
+
+// Compute the linear index for the provided strided layout and indices.
+Value computeLinearIndex(OpBuilder &builder, Location loc, OpFoldResult offset,
+                         ArrayRef<OpFoldResult> strides,
+                         ArrayRef<OpFoldResult> indices) {
+  auto [expr, values] = computeLinearIndex(offset, strides, indices);
+  auto index =
+      affine::makeComposedFoldedAffineApply(builder, loc, expr, values);
+  return getValueOrCreateConstantIndexOp(builder, loc, index);
+}
+
+// Returns two Values representing the bounds of the provided strided layout
+// metadata. The bounds are returned as a half-open interval -- [low, high).
+std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
+                                            OpFoldResult offset,
+                                            ArrayRef<OpFoldResult> strides,
+                                            ArrayRef<OpFoldResult> sizes) {
+  auto zeros = SmallVector<int64_t>(sizes.size(), 0);
+  auto indices = getAsIndexOpFoldResult(builder.getContext(), zeros);
+  auto lowerBound = computeLinearIndex(builder, loc, offset, strides, indices);
+  auto upperBound = computeLinearIndex(builder, loc, offset, strides, sizes);
+  return {lowerBound, upperBound};
+}
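
As a worked example of the arithmetic above (numbers invented for illustration): with offset = 2, strides = [1] and sizes = [10], the lower bound should evaluate to 2 + 1 * 0 = 2 and the upper bound to 2 + 1 * 10 = 12, i.e. the half-open interval [2, 12).
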
+
+// Returns two Values representing the bounds of the address space of the
+// memref. The bounds are returned as a half-open interval -- [low, high).
+std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
+                                            TypedValue<BaseMemRefType> memref) {
+  auto runtimeMetadata = builder.create<ExtractStridedMetadataOp>(loc, memref);
+  auto offset = runtimeMetadata.getConstifiedMixedOffset();
+  auto strides = runtimeMetadata.getConstifiedMixedStrides();
+  auto sizes = runtimeMetadata.getConstifiedMixedSizes();
+  return computeLinearBounds(builder, loc, offset, strides, sizes);
+}
+
+struct ReinterpretCastOpInterface
+    : public RuntimeVerifiableOpInterface::ExternalModel<
+          ReinterpretCastOpInterface, ReinterpretCastOp> {
+  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
+                                   Location loc) const {
+    auto reinterpretCast = cast<ReinterpretCastOp>(op);
+
+    // Verify that the resulting address space is in-bounds of the base memref's
+    // address space.
+
+    auto baseMemref = reinterpretCast.getSource();
+
+    auto castOffset = reinterpretCast.getMixedOffsets().front();
+    auto castStrides = reinterpretCast.getMixedStrides();
+    auto castSizes = reinterpretCast.getMixedSizes();
+
+    // Compute the bounds of the base memref's address space.
+    auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref);
+
+    // Compute the bounds of the resulting memref's address space.
+    auto [low, high] =
+        computeLinearBounds(builder, loc, castOffset, castStrides, castSizes);
+
+    auto geLow = builder.createOrFold<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::sge, low, baseLow);
+
+    auto leHigh = builder.createOrFold<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::sle, high, baseHigh);
+
+    auto assertCond = builder.createOrFold<arith::AndIOp>(loc, geLow, leHigh);
+
+    builder.create<cf::AssertOp>(
+        loc, assertCond,
+        generateErrorMessage(
+            op,
+            "result of reinterpret_cast is out-of-bounds of the base memref"));
+  }
+};
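
To make the reinterpret_cast check concrete, here is a rough, hand-folded sketch of what the verification amounts to for a 1-D cast. The function and SSA names are invented, the assert message is abbreviated, and the real pass goes through extract_strided_metadata plus composed affine ops rather than the pre-folded constants shown here:

    func.func @verified_cast(%m: memref<16xf32>, %offset: index) {
      // Base bounds: [0, 16). Result bounds: [%offset, %offset + 8).
      %c0 = arith.constant 0 : index
      %c8 = arith.constant 8 : index
      %c16 = arith.constant 16 : index
      %high = arith.addi %offset, %c8 : index
      %ge = arith.cmpi sge, %offset, %c0 : index
      %le = arith.cmpi sle, %high, %c16 : index
      %cond = arith.andi %ge, %le : i1
      cf.assert %cond, "result of reinterpret_cast is out-of-bounds of the base memref"
      %r = memref.reinterpret_cast %m to offset: [%offset], sizes: [8], strides: [1]
          : memref<16xf32> to memref<8xf32, strided<[1], offset: ?>>
      return
    }
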
+
+struct SubViewOpInterface
+    : public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
+                                                         SubViewOp> {
+  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
+                                   Location loc) const {
+    auto subView = cast<SubViewOp>(op);
----------------
ryan-holt-1 wrote:

Subview with rank reduction works fine because I'm only checking the linearized bounds of the memrefs. Rank reduction could become more of an issue if we were to check on a per-dimension basis. I added a test for rank reduction.
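
To illustrate the point about linearized bounds, consider a rank-reducing subview like the one below (an invented example, not taken from the patch). It still has a well-defined linearized window even though the result has fewer dimensions than the source, so the check does not have to pair up dimensions:

    func.func @rank_reducing(%m: memref<8x4xf32>, %i: index) {
      // The unit dimension is dropped; the subview covers the linearized
      // window [4 * %i, 4 * %i + 4) of the 32-element base buffer, and that
      // interval is what the generated check compares against the base
      // memref's linearized bounds.
      %sv = memref.subview %m[%i, 0] [1, 4] [1, 1]
          : memref<8x4xf32> to memref<4xf32, strided<[1], offset: ?>>
      return
    }
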

https://github.com/llvm/llvm-project/pull/75817

