[Mlir-commits] [mlir] [MLIR][Tensor] Fix out-of-bounds FoldEmptyTensorWithDimOp crash #111270 (PR #112196)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Oct 16 21:39:31 PDT 2024


https://github.com/brod4910 updated https://github.com/llvm/llvm-project/pull/112196

>From 524a82187158ce63e4a1d355fe026e2d8ba7e8f0 Mon Sep 17 00:00:00 2001
From: brod4910 <brod4910 at gmail.com>
Date: Mon, 14 Oct 2024 07:08:51 -0600
Subject: [PATCH 1/3] Fix out-of-bounds FoldEmptyTensorWithDimOp crash #111270

---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 4d6c5965c4fcc3..7af869c2183892 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -979,7 +979,9 @@ struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
     auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
     if (!emptyTensorOp || !maybeConstantIndex)
       return failure();
-    if (!emptyTensorOp.getType().isDynamicDim(*maybeConstantIndex))
+    auto emptyTensorType = emptyTensorOp.getType();
+    if (*maybeConstantIndex >= emptyTensorType.getRank() ||
+        !emptyTensorType.isDynamicDim(*maybeConstantIndex))
       return failure();
     rewriter.replaceOp(dimOp,
                        emptyTensorOp.getDynamicSize(*maybeConstantIndex));

>From e1af9585578ff2386040a887c2ce7d9759ab4a37 Mon Sep 17 00:00:00 2001
From: brod4910 <brod4910 at gmail.com>
Date: Tue, 15 Oct 2024 20:54:13 -0600
Subject: [PATCH 2/3] Add tensor.dim verifier and missing negative-index
 check

---
 .../mlir/Dialect/Tensor/IR/TensorOps.td       |  1 +
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp      | 19 ++++++++++++++++++-
 2 files changed, 19 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 3170115883e2be..4973621e588aa1 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -246,6 +246,7 @@ def Tensor_DimOp : Tensor_Op<"dim", [
 
   let hasCanonicalizer = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 7af869c2183892..323055236327be 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -28,6 +28,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/MathExtras.h"
 #include <algorithm>
 #include <optional>
@@ -714,6 +715,21 @@ void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
   build(builder, result, source, indexValue);
 }
 
+LogicalResult DimOp::verify() {
+  auto maybeConstantIndex = getConstantIndex();
+  auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType());
+
+  if (tensorType && maybeConstantIndex &&
+      (*maybeConstantIndex < 0 ||
+       *maybeConstantIndex >= tensorType.getRank())) {
+    return emitOpError("out-of-range access, attempted to access index ")
+           << *maybeConstantIndex << " but valid range is [0, "
+           << tensorType.getRank() - 1 << "].";
+  }
+
+  return success();
+}
+
 std::optional<int64_t> DimOp::getConstantIndex() {
   return getConstantIntValue(getIndex());
 }
@@ -980,7 +996,8 @@ struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
     if (!emptyTensorOp || !maybeConstantIndex)
       return failure();
     auto emptyTensorType = emptyTensorOp.getType();
-    if (*maybeConstantIndex >= emptyTensorType.getRank() ||
+    if (*maybeConstantIndex < 0 ||
+        *maybeConstantIndex >= emptyTensorType.getRank() ||
         !emptyTensorType.isDynamicDim(*maybeConstantIndex))
       return failure();
     rewriter.replaceOp(dimOp,

>From 44e507b9cdb3c5cfcb34b3a0b9cc0c20d80e32f4 Mon Sep 17 00:00:00 2001
From: brod4910 <brod4910 at gmail.com>
Date: Wed, 16 Oct 2024 22:39:17 -0600
Subject: [PATCH 3/3] Remove tensor.dim verifier

---
 mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td |  1 -
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp         | 15 ---------------
 2 files changed, 16 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 4973621e588aa1..3170115883e2be 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -246,7 +246,6 @@ def Tensor_DimOp : Tensor_Op<"dim", [
 
   let hasCanonicalizer = 1;
   let hasFolder = 1;
-  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 323055236327be..e545eeb3efed56 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -715,21 +715,6 @@ void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
   build(builder, result, source, indexValue);
 }
 
-LogicalResult DimOp::verify() {
-  auto maybeConstantIndex = getConstantIndex();
-  auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType());
-
-  if (tensorType && maybeConstantIndex &&
-      (*maybeConstantIndex < 0 ||
-       *maybeConstantIndex >= tensorType.getRank())) {
-    return emitOpError("out-of-range access, attempted to access index ")
-           << *maybeConstantIndex << " but valid range is [0, "
-           << tensorType.getRank() - 1 << "].";
-  }
-
-  return success();
-}
-
 std::optional<int64_t> DimOp::getConstantIndex() {
   return getConstantIntValue(getIndex());
 }



More information about the Mlir-commits mailing list