[Mlir-commits] [mlir] [mlir] add tensor_static.extract/insert to take only static indices. (PR #110550)

donald chen llvmlistbot at llvm.org
Wed Oct 2 07:22:41 PDT 2024


================
@@ -39,6 +39,59 @@ using llvm::divideCeilSigned;
 using llvm::divideFloorSigned;
 using llvm::mod;
 
+namespace {
+template <typename ExtractOpTy>
+OpFoldResult foldExtractFromElementsHelper(ExtractOpTy op,
+                                           FromElementsOp fromElementsOp,
+                                           ArrayRef<uint64_t> indices) {
+  // Fold extract(from_elements(...)) by returning the from_elements operand
+  // stored at the row-major linearization of `indices`. Returns a null
+  // OpFoldResult (no fold) when any index is out of bounds.
+  auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
+  int64_t rank = tensorType.getRank();
+  assert(static_cast<int64_t>(indices.size()) == rank && "rank mismatch");
+  // Check every index against its own dimension rather than only checking
+  // the final flat index: an out-of-bounds index in an inner dimension would
+  // otherwise alias a different (in-range) element, and very large indices
+  // could wrap when narrowed to a 32-bit accumulator. Out-of-bounds indices
+  // can only appear in invalid code that will never execute, so we simply
+  // refuse to fold instead of picking an arbitrary element.
+  uint64_t flatIndex = 0;
+  uint64_t stride = 1;
+  for (int64_t i = rank - 1; i >= 0; --i) {
+    int64_t dimSize = tensorType.getDimSize(i);
+    if (dimSize < 0 || indices[i] >= static_cast<uint64_t>(dimSize))
+      return {};
+    flatIndex += indices[i] * stride;
+    stride *= static_cast<uint64_t>(dimSize);
+  }
+  // Defensive backstop: the operand count should equal the static element
+  // count, but never index past the operand list.
+  auto elements = fromElementsOp.getElements();
+  if (flatIndex >= static_cast<uint64_t>(elements.size()))
+    return {};
+  return elements[flatIndex];
+}
+
+LogicalResult verifyStaticIndicesInBound(RankedTensorType type,
+                                         ArrayRef<int64_t> indices) {
----------------
cxy-1993 wrote:

This function should be marked as a static function.

https://github.com/llvm/llvm-project/pull/110550


More information about the Mlir-commits mailing list