[Mlir-commits] [mlir] 0c5e4a2 - [mlir] Prevent segfault in Tensor canonicalization

Tres Popp llvmlistbot at llvm.org
Fri Jan 29 01:58:13 PST 2021


Author: Tres Popp
Date: 2021-01-29T10:57:58+01:00
New Revision: 0c5e4a25ee232afd0ab21294dfe9ce290957aab6

URL: https://github.com/llvm/llvm-project/commit/0c5e4a25ee232afd0ab21294dfe9ce290957aab6
DIFF: https://github.com/llvm/llvm-project/commit/0c5e4a25ee232afd0ab21294dfe9ce290957aab6.diff

LOG: [mlir] Prevent segfault in Tensor canonicalization

This segfault could occur due to out-of-bounds accesses when simplifying
a tensor.extract that has a constant index and whose tensor is created by
tensor.from_elements.

Such IR is not necessarily invalid, because it may be guarded by a
condition and therefore never executed.

Differential Revision: https://reviews.llvm.org/D95535

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/Tensor/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 92115d51476e..9a42224b1158 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -248,6 +248,11 @@ struct ExtractElementFromTensorFromElements
     APInt index;
     if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index)))
       return failure();
+    // Prevent out of bounds accesses. This can happen in invalid code that will
+    // never execute.
+    if (tensorFromElements->getNumOperands() <= index.getZExtValue() ||
+        index.getSExtValue() < 0)
+      return failure();
     rewriter.replaceOp(extract,
                        tensorFromElements.getOperand(index.getZExtValue()));
     return success();

diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index ae145934ef4d..f975a51280e5 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -122,6 +122,51 @@ func @extract_from_tensor.from_elements(%element : index) -> index {
 
 // -----
 
+// Ensure a negative constant index (-1) is neither folded nor crashes.
+// CHECK-LABEL: func @extract_negative_from_tensor.from_elements
+func @extract_negative_from_tensor.from_elements(%element : index) -> index {
+  // CHECK-SAME: ([[ARG:%.*]]: index)
+  %c-1 = constant -1 : index
+  %tensor = tensor.from_elements %element : tensor<1xindex>
+  %extracted_element = tensor.extract %tensor[%c-1] : tensor<1xindex>
+  // CHECK: tensor.from_elements
+  // CHECK: %[[RESULT:.*]] = tensor.extract
+  // CHECK: return %[[RESULT]]
+  return %extracted_element : index
+}
+
+// -----
+
+// Ensure an index equal to the element count (1 of 1) is not folded.
+// CHECK-LABEL: func @extract_oob_from_tensor.from_elements
+func @extract_oob_from_tensor.from_elements(%element : index) -> index {
+  // CHECK-SAME: ([[ARG:%.*]]: index)
+  %c1 = constant 1 : index
+  %tensor = tensor.from_elements %element : tensor<1xindex>
+  %extracted_element = tensor.extract %tensor[%c1] : tensor<1xindex>
+  // CHECK: tensor.from_elements
+  // CHECK: %[[RESULT:.*]] = tensor.extract
+  // CHECK: return %[[RESULT]]
+  return %extracted_element : index
+}
+
+// -----
+
+// Ensure an index well past the end (2 of 1) is neither folded nor crashes.
+// CHECK-LABEL: func @extract_oob2_from_tensor.from_elements
+func @extract_oob2_from_tensor.from_elements(%element : index) -> index {
+  // CHECK-SAME: ([[ARG:%.*]]: index)
+  %c2 = constant 2 : index
+  %tensor = tensor.from_elements %element : tensor<1xindex>
+  %extracted_element = tensor.extract %tensor[%c2] : tensor<1xindex>
+  // CHECK: tensor.from_elements
+  // CHECK: %[[RESULT:.*]] = tensor.extract
+  // CHECK: return %[[RESULT]]
+  return %extracted_element : index
+}
+
+// -----
+
 // CHECK-LABEL: func @extract_from_tensor.generate
 // CHECK-SAME: %[[IDX:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32>
 func @extract_from_tensor.generate(%idx: index, %tensor: tensor<*xf32>) -> index {


        


More information about the Mlir-commits mailing list