[Mlir-commits] [mlir] [mlir][tensor] fix out-of-bound index in tensor.dim (PR #85901)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Mar 20 00:56:36 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-bufferization

Author: Jianbang Yang (dynamicheart)

<details>
<summary>Changes</summary>

Fold `tensor.dim` with an out-of-bounds constant index to `ub.poison`.

Fixes: https://github.com/llvm/llvm-project/issues/70183

---
Full diff: https://github.com/llvm/llvm-project/pull/85901.diff


4 Files Affected:

- (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+3) 
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+6-3) 
- (modified) mlir/test/Dialect/MemRef/resolve-dim-ops.mlir (+1-2) 
- (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+22) 


``````````diff
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 2b226c7a1207cf..a656c812a59feb 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -333,6 +333,9 @@ struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> {
     auto allocTensorOp = dimOp.getSource().getDefiningOp<AllocTensorOp>();
     if (!allocTensorOp || !maybeConstantIndex)
       return failure();
+    if (*maybeConstantIndex < 0 ||
+        *maybeConstantIndex >= allocTensorOp.getType().getRank())
+      return failure();
     if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex))
       return failure();
     rewriter.replaceOp(
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index dc8843aa4e1e13..2183839d8f4290 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
@@ -45,6 +46,8 @@ Operation *TensorDialect::materializeConstant(OpBuilder &builder,
   if (complex::ConstantOp::isBuildableWith(value, type))
     return builder.create<complex::ConstantOp>(loc, type,
                                                llvm::cast<ArrayAttr>(value));
+  if (auto poison = dyn_cast<ub::PoisonAttr>(value))
+    return builder.create<ub::PoisonOp>(loc, type, poison);
   return nullptr;
 }
 
@@ -738,11 +741,11 @@ OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
   if (!tensorType)
     return {};
 
-  // Out of bound indices produce undefined behavior but are still valid IR.
-  // Don't choke on them.
+  // Fold dim to poison if the index is out of bounds. Poison represents
+  // undefined behavior.
   int64_t indexVal = index.getInt();
   if (indexVal < 0 || indexVal >= tensorType.getRank())
-    return {};
+    return ub::PoisonAttr::get(getContext());
 
   // Fold if the shape extent along the given index is known.
   if (!tensorType.isDynamicDim(index.getInt())) {
diff --git a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
index 18e9a9d02e1081..72ff1ca08c1faa 100644
--- a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
+++ b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
@@ -13,10 +13,9 @@ func.func @dim_out_of_bounds(%m : memref<7x8xf32>) -> index {
 // -----
 
 // CHECK-LABEL: func @dim_out_of_bounds_2(
-//  CHECK-NEXT:   arith.constant
+//  CHECK-NEXT:   ub.poison
 //  CHECK-NEXT:   arith.constant
 //  CHECK-NEXT:   bufferization.alloc_tensor
-//  CHECK-NEXT:   tensor.dim
 //  CHECK-NEXT:   return
 func.func @dim_out_of_bounds_2(%idx1 : index, %idx2 : index) -> index {
   %idx = arith.constant 7 : index
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index e5374f031be553..2914c6873f5e47 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2367,3 +2367,25 @@ func.func @dim_of_reshape_undominated(%arg0: tensor<*xf32>, %arg1: tensor<?xinde
     %dim = tensor.dim %reshape, %0 : tensor<*xf32>
     return %dim : index
   }
+
+// -----
+
+// Test case: tensor.dim(%tensor, %out_of_bound_index) is folded into ub.poison
+// CHECK-LABEL: func @dim_out_of_bounds(
+//       CHECK: ub.poison
+//  CHECK-NEXT: bufferization.alloc_tensor
+//  CHECK-NEXT: memref.alloc
+//  CHECK-NEXT: memref.cast
+//  CHECK-NEXT: affine.vector_load
+//  CHECK-NEXT: return
+func.func @dim_out_of_bounds() -> vector<7xi32> {
+    %c1 = arith.constant 1 : index
+    %idx28 = index.constant 28
+    %c29 = arith.constant 29 : index
+    %3 = bufferization.alloc_tensor(%c29) : tensor<?xi16>
+    %dim = tensor.dim %3, %idx28 : tensor<?xi16>
+    %alloc_21 = memref.alloc(%c29) : memref<?x26x2xi32>
+    %16 = affine.vector_load %alloc_21[%c1, %c1, %dim] : memref<?x26x2xi32>, vector<7xi32>
+    return %16 : vector<7xi32>
+}
+

``````````

</details>


https://github.com/llvm/llvm-project/pull/85901


More information about the Mlir-commits mailing list