[Mlir-commits] [mlir] [mlir][tensor] Add runtime verification for `cast`/`dim`/`extract`/`insert`/`extract_slice` (PR #141332)

lorenzo chelini llvmlistbot at llvm.org
Mon May 26 10:04:54 PDT 2025


================
@@ -0,0 +1,208 @@
+//===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tensor/Transforms/RuntimeOpVerification.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
+
+using namespace mlir;
+
+namespace mlir {
+namespace tensor {
+namespace {
+/// Generate a runtime check for lb <= value < ub.
+/// Emit IR at `loc` that evaluates the half-open range check
+/// `lb <= value && value < ub` and return the resulting condition value.
+/// Both comparisons use signed predicates (sge / slt). Every op is built via
+/// `createOrFold`, so it may be folded instead of materialized.
+Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
+                            Value lb, Value ub) {
+  // Lower bound is inclusive.
+  Value atLeastLower = builder.createOrFold<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::sge, value, lb);
+  // Upper bound is exclusive.
+  Value belowUpper = builder.createOrFold<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::slt, value, ub);
+  return builder.createOrFold<arith::AndIOp>(loc, atLeastLower, belowUpper);
+}
+
+struct CastOpInterface
+    : public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
+                                                         CastOp> {
+  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
+                                   Location loc) const {
+    auto castOp = cast<CastOp>(op);
+    auto srcType = cast<TensorType>(castOp.getSource().getType());
+
+    // Nothing to check if the result is an unranked tensor.
+    auto resultType = dyn_cast<RankedTensorType>(castOp.getType());
----------------
chelini wrote:

Why don't we check for `UnrankedTensorType` directly?

https://github.com/llvm/llvm-project/pull/141332


More information about the Mlir-commits mailing list