[Mlir-commits] [mlir] 4bc2357 - [mlir][MemRef|Tensor] Fix the handling of DimOp
Quentin Colombet
llvmlistbot at llvm.org
Thu Feb 16 02:54:21 PST 2023
Author: Quentin Colombet
Date: 2023-02-16T11:38:19+01:00
New Revision: 4bc2357c3de268b2b50ad0ff9c2c040329b75375
URL: https://github.com/llvm/llvm-project/commit/4bc2357c3de268b2b50ad0ff9c2c040329b75375
DIFF: https://github.com/llvm/llvm-project/commit/4bc2357c3de268b2b50ad0ff9c2c040329b75375.diff
LOG: [mlir][MemRef|Tensor] Fix the handling of DimOp
Although specifying an index that is out of bounds for both `memref.dim`
and `tensor.dim` produces undefined behavior, this is still valid IR.
In particular, optimizations could expose an out-of-bounds index, for
instance as demonstrated with
https://github.com/llvm/llvm-project/issues/60295, and this shouldn't
cause the compiler to abort.
This patch removes the overzealous verifier checks and properly handles
out-of-bounds indices (as in it doesn't crash the compiler, but still
produces UB).
This fixes https://github.com/llvm/llvm-project/issues/60295.
Note that `shape.dim` has a similar problem, but we're not supposed to
produce UB in this case. Instead, we're supposed to propagate an error in
the resulting value, and I don't know how to do that at the moment. Hence I
left this part out of the patch.
Differential Revision: https://reviews.llvm.org/D143999
Added:
Modified:
mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
mlir/test/Dialect/Tensor/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index fdc070fdd068e..45b5c9f21b072 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -604,7 +604,6 @@ def MemRef_DimOp : MemRef_Op<"dim", [
let hasCanonicalizer = 1;
let hasFolder = 1;
- let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 9e1c8bc3abf81..77d183e7d6ec3 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -143,7 +143,6 @@ def Tensor_DimOp : Tensor_Op<"dim", [
let hasCanonicalizer = 1;
let hasFolder = 1;
- let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 700304d56df86..35091a3be8023 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -495,14 +495,16 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
MemRefType memRefType = operandType.cast<MemRefType>();
if (std::optional<int64_t> index = getConstantDimIndex(dimOp)) {
int64_t i = *index;
- if (memRefType.isDynamicDim(i)) {
- // extract dynamic size from the memref descriptor.
- MemRefDescriptor descriptor(adaptor.getSource());
- return descriptor.size(rewriter, loc, i);
+ if (i >= 0 && i < memRefType.getRank()) {
+ if (memRefType.isDynamicDim(i)) {
+ // extract dynamic size from the memref descriptor.
+ MemRefDescriptor descriptor(adaptor.getSource());
+ return descriptor.size(rewriter, loc, i);
+ }
+ // Use constant for static size.
+ int64_t dimSize = memRefType.getDimSize(i);
+ return createIndexConstant(rewriter, loc, dimSize);
}
- // Use constant for static size.
- int64_t dimSize = memRefType.getDimSize(i);
- return createIndexConstant(rewriter, loc, dimSize);
}
Value index = adaptor.getIndex();
int64_t rank = memRefType.getRank();
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 02f8019996cdd..6814aa5c971b5 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -941,25 +941,6 @@ Speculation::Speculatability DimOp::getSpeculatability() {
return Speculation::Speculatable;
}
-LogicalResult DimOp::verify() {
- // Assume unknown index to be in range.
- std::optional<int64_t> index = getConstantIndex();
- if (!index)
- return success();
-
- // Check that constant index is not knowingly out of range.
- auto type = getSource().getType();
- if (auto memrefType = type.dyn_cast<MemRefType>()) {
- if (*index >= memrefType.getRank())
- return emitOpError("index is out of range");
- } else if (type.isa<UnrankedMemRefType>()) {
- // Assume index to be in range.
- } else {
- llvm_unreachable("expected operand with memref type");
- }
- return success();
-}
-
/// Return a map with key being elements in `vals` and data being number of
/// occurences of it. Use std::map, since the `vals` here are strides and the
/// dynamic stride value is the same as the tombstone value for
@@ -1067,6 +1048,12 @@ OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
if (!memrefType)
return {};
+ // Out of bound indices produce undefined behavior but are still valid IR.
+ // Don't choke on them.
+ int64_t indexVal = index.getInt();
+ if (indexVal < 0 || indexVal >= memrefType.getRank())
+ return {};
+
// Fold if the shape extent along the given index is known.
if (!memrefType.isDynamicDim(index.getInt())) {
Builder builder(getContext());
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 74b3f9338aa75..d5679495a3bce 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -399,25 +399,6 @@ Speculation::Speculatability DimOp::getSpeculatability() {
return Speculation::Speculatable;
}
-LogicalResult DimOp::verify() {
- // Assume unknown index to be in range.
- std::optional<int64_t> index = getConstantIndex();
- if (!index)
- return success();
-
- // Check that constant index is not knowingly out of range.
- auto type = getSource().getType();
- if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
- if (*index >= tensorType.getRank())
- return emitOpError("index is out of range");
- } else if (type.isa<UnrankedTensorType>()) {
- // Assume index to be in range.
- } else {
- llvm_unreachable("expected operand with tensor type");
- }
- return success();
-}
-
OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
// All forms of folding require a known index.
auto index = adaptor.getIndex().dyn_cast_or_null<IntegerAttr>();
@@ -429,6 +410,12 @@ OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
if (!tensorType)
return {};
+ // Out of bound indices produce undefined behavior but are still valid IR.
+ // Don't choke on them.
+ int64_t indexVal = index.getInt();
+ if (indexVal < 0 || indexVal >= tensorType.getRank())
+ return {};
+
// Fold if the shape extent along the given index is known.
if (!tensorType.isDynamicDim(index.getInt())) {
Builder builder(getContext());
diff --git a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
index 24877cc299a14..0f3f3e4c2ff2d 100644
--- a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
@@ -188,6 +188,20 @@ func.func @static_memref_dim(%static : memref<42x32x15x13x27xf32>) {
// -----
+// CHECK-LABEL: func @static_out_of_bound_memref_dim
+func.func @static_out_of_bound_memref_dim(%static : memref<42x32x15x13x27xf32>) -> index {
+// CHECK: %[[C_MINUS_7:.*]] = arith.constant -7 : index
+// CHECK: %[[C_MINUS_7_I64:.*]] = builtin.unrealized_conversion_cast %[[C_MINUS_7]] : index to i64
+// CHECK: %[[UB_IDX:.*]] = llvm.getelementptr %{{.*}}[0, %[[C_MINUS_7_I64]]] : (!llvm.ptr, i64) -> !llvm.ptr
+// CHECK: %[[UB_DIM_I64:.*]] = llvm.load %[[UB_IDX]] : !llvm.ptr
+// CHECK: %[[UB_DIM:.*]] = builtin.unrealized_conversion_cast %[[UB_DIM_I64]] : i64 to index
+// CHECK: return %[[UB_DIM]] : index
+ %c-7 = arith.constant -7 : index
+ %1 = memref.dim %static, %c-7 : memref<42x32x15x13x27xf32>
+ return %1 : index
+}
+// -----
+
// Check that consistent types are emitted in address arithemic in presence of
// a data layout specification.
module attributes { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<index, 32>> } {
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index d15819f497a71..f74bd9456b666 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -1,13 +1,5 @@
// RUN: mlir-opt <%s -split-input-file -verify-diagnostics
-func.func @dim(%arg : tensor<1x?xf32>) {
- %c2 = arith.constant 2 : index
- tensor.dim %arg, %c2 : tensor<1x?xf32> // expected-error {{'tensor.dim' op index is out of range}}
- return
-}
-
-// -----
-
// Asking the dimension of a 0-D shape doesn't make sense.
func.func @dim_0_ranked(%arg : tensor<f32>, %arg1 : index) {
tensor.dim %arg, %arg1 : tensor<f32> // expected-error {{'tensor.dim' op operand #0 must be unranked.tensor of any type values or non-0-ranked.tensor of any type values, but got 'tensor<f32>'}}
More information about the Mlir-commits
mailing list