[Mlir-commits] [mlir] 64bd5bb - [mlir] Avoid tensor canonicalizer crash on negative dimensions

Alex Zinenko llvmlistbot at llvm.org
Wed May 31 05:14:37 PDT 2023


Author: rikhuijzer
Date: 2023-05-31T14:14:30+02:00
New Revision: 64bd5bbb9bbb72de5f59755c74dae4b4881d93d5

URL: https://github.com/llvm/llvm-project/commit/64bd5bbb9bbb72de5f59755c74dae4b4881d93d5
DIFF: https://github.com/llvm/llvm-project/commit/64bd5bbb9bbb72de5f59755c74dae4b4881d93d5.diff

LOG: [mlir] Avoid tensor canonicalizer crash on negative dimensions

Fixes #59703.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D151611

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/Tensor/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 1adb9c7f262fe..283f1be6aa793 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1111,14 +1111,43 @@ LogicalResult GenerateOp::reifyResultShapes(
   return success();
 }
 
+/// Extract operands and shape from a tensor with dynamic extents.
+static void operandsAndShape(TensorType resultType,
+                             Operation::operand_range dynamicExtents,
+                             SmallVectorImpl<Value> &newOperands,
+                             SmallVectorImpl<int64_t> &newShape) {
+  auto operandsIt = dynamicExtents.begin();
+  for (int64_t dim : resultType.getShape()) {
+    if (!ShapedType::isDynamic(dim)) {
+      newShape.push_back(dim);
+      continue;
+    }
+    APInt index;
+    if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
+      newShape.push_back(ShapedType::kDynamic);
+      newOperands.push_back(*operandsIt++);
+      continue;
+    }
+    newShape.push_back(index.getSExtValue());
+    operandsIt++;
+  }
+}
+
 LogicalResult GenerateOp::verify() {
   // Ensure that the tensor type has as many dynamic dimensions as are
   // specified by the operands.
-  RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());
-  if (getNumOperands() != resultTy.getNumDynamicDims())
+  RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
+  if (getNumOperands() != resultType.getNumDynamicDims())
     return emitError("must have as many index operands as dynamic extents "
                      "in the result type");
-
+  // Ensure operands are non-negative.
+  SmallVector<Value> newOperands;
+  SmallVector<int64_t> newShape;
+  operandsAndShape(resultType, getDynamicExtents(), newOperands, newShape);
+  for (int64_t newdim : newShape) {
+    if (newdim < 0 && !ShapedType::isDynamic(newdim))
+      return emitError("tensor dimensions must be non-negative");
+  }
   return success();
 }
 
@@ -1176,24 +1205,11 @@ struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
     if (resultType.hasStaticShape())
       return failure();
 
-    SmallVector<Value, 4> newOperands;
-    SmallVector<int64_t, 4> newShape;
-    auto operandsIt = tensorFromElements.getDynamicExtents().begin();
-
-    for (int64_t dim : resultType.getShape()) {
-      if (!ShapedType::isDynamic(dim)) {
-        newShape.push_back(dim);
-        continue;
-      }
-      APInt index;
-      if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
-        newShape.push_back(ShapedType::kDynamic);
-        newOperands.push_back(*operandsIt++);
-        continue;
-      }
-      newShape.push_back(index.getSExtValue());
-      operandsIt++;
-    }
+    Operation::operand_range dynamicExtents =
+        tensorFromElements.getDynamicExtents();
+    SmallVector<Value> newOperands;
+    SmallVector<int64_t> newShape;
+    operandsAndShape(resultType, dynamicExtents, newOperands, newShape);
 
     if (newOperands.size() == tensorFromElements.getDynamicExtents().size())
       return failure();

diff  --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 61f03f19de33b..389e7e675c0ee 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -112,6 +112,20 @@ func.func @tensor.generate(%m : index, %n : index)
   } : tensor<?x3x?xf32>
   return %tnsr : tensor<?x3x?xf32>
 }
+
+// -----
+
+func.func @generate_negative_size() -> tensor<?x8xi32> {
+  %cst = arith.constant 0 : i32
+  %size = index.constant -128
+  // expected-error @+1 {{tensor dimensions must be non-negative}}
+  %tensor = tensor.generate %size {
+  ^bb0(%arg0: index, %arg1: index):
+    tensor.yield %cst : i32
+  } : tensor<?x8xi32>
+  return %tensor : tensor<?x8xi32>
+}
+
 // -----
 
 func.func @tensor.reshape_element_type_mismatch(


        


More information about the Mlir-commits mailing list