[Mlir-commits] [mlir] 0c5e4a2 - [mlir] Prevent segfault in Tensor canonicalization
Tres Popp
llvmlistbot at llvm.org
Fri Jan 29 01:58:13 PST 2021
Author: Tres Popp
Date: 2021-01-29T10:57:58+01:00
New Revision: 0c5e4a25ee232afd0ab21294dfe9ce290957aab6
URL: https://github.com/llvm/llvm-project/commit/0c5e4a25ee232afd0ab21294dfe9ce290957aab6
DIFF: https://github.com/llvm/llvm-project/commit/0c5e4a25ee232afd0ab21294dfe9ce290957aab6.diff
LOG: [mlir] Prevent segfault in Tensor canonicalization
This segfault could occur from out of bounds accesses when simplifying
tensor.extract with a constant index and a tensor created by
tensor.from_elements.
This IR is not necessarily invalid, as it might be guarded by a condition
and never executed.
Differential Revision: https://reviews.llvm.org/D95535
Added:
Modified:
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 92115d51476e..9a42224b1158 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -248,6 +248,11 @@ struct ExtractElementFromTensorFromElements
APInt index;
if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index)))
return failure();
+ // Prevent out of bounds accesses. This can happen in invalid code that will
+ // never execute.
+ if (tensorFromElements->getNumOperands() <= index.getZExtValue() ||
+ index.getSExtValue() < 0)
+ return failure();
rewriter.replaceOp(extract,
tensorFromElements.getOperand(index.getZExtValue()));
return success();
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index ae145934ef4d..f975a51280e5 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -122,6 +122,51 @@ func @extract_from_tensor.from_elements(%element : index) -> index {
// -----
+// Ensure the optimization doesn't segfault from bad constants
+// CHECK-LABEL: func @extract_negative_from_tensor.from_elements
+func @extract_negative_from_tensor.from_elements(%element : index) -> index {
+ // CHECK-SAME: ([[ARG:%.*]]: index)
+ %c-1 = constant -1 : index
+ %tensor = tensor.from_elements %element : tensor<1xindex>
+ %extracted_element = tensor.extract %tensor[%c-1] : tensor<1xindex>
+ // CHECK: tensor.from_elements
+ // CHECK: %[[RESULT:.*]] = tensor.extract
+ // CHECK: return %[[RESULT]]
+ return %extracted_element : index
+}
+
+// -----
+
+// Ensure the optimization doesn't segfault from bad constants
+// CHECK-LABEL: func @extract_oob_from_tensor.from_elements
+func @extract_oob_from_tensor.from_elements(%element : index) -> index {
+ // CHECK-SAME: ([[ARG:%.*]]: index)
+ %c1 = constant 1 : index
+ %tensor = tensor.from_elements %element : tensor<1xindex>
+ %extracted_element = tensor.extract %tensor[%c1] : tensor<1xindex>
+ // CHECK: tensor.from_elements
+ // CHECK: %[[RESULT:.*]] = tensor.extract
+ // CHECK: return %[[RESULT]]
+ return %extracted_element : index
+}
+
+// -----
+
+// Ensure the optimization doesn't segfault from bad constants
+// CHECK-LABEL: func @extract_oob_from_tensor.from_elements
+func @extract_oob_from_tensor.from_elements(%element : index) -> index {
+ // CHECK-SAME: ([[ARG:%.*]]: index)
+ %c2 = constant 2 : index
+ %tensor = tensor.from_elements %element : tensor<1xindex>
+ %extracted_element = tensor.extract %tensor[%c2] : tensor<1xindex>
+ // CHECK: tensor.from_elements
+ // CHECK: %[[RESULT:.*]] = tensor.extract
+ // CHECK: return %[[RESULT]]
+ return %extracted_element : index
+}
+
+// -----
+
// CHECK-LABEL: func @extract_from_tensor.generate
// CHECK-SAME: %[[IDX:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32>
func @extract_from_tensor.generate(%idx: index, %tensor: tensor<*xf32>) -> index {
More information about the Mlir-commits
mailing list