[Mlir-commits] [mlir] [mlir][VectorOps] Don't fold extract chains that include dynamic indices (PR #68333)
Benjamin Maxwell
llvmlistbot at llvm.org
Thu Oct 5 09:32:16 PDT 2023
https://github.com/MacDue created https://github.com/llvm/llvm-project/pull/68333
Folding extract chains that include dynamic indices is not yet supported, and previously led to a confusing crash in which an extract op was created with a kDynamic marker but no corresponding dynamic position. The verifier has also been updated to check for this mismatch and to hint at where the problem likely originates.
From 4c6cd9c3f9c49633ec69651c1f777c8e0e24a913 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 5 Oct 2023 16:23:36 +0000
Subject: [PATCH] [mlir][VectorOps] Don't fold extract chains that include
dynamic indices
Folding such chains is not yet supported, and previously led to a
confusing crash in which an extract op was created with a kDynamic
marker but no corresponding dynamic position. The verifier has also
been updated to check for this mismatch and to hint at where the
problem likely originates.
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 11 +++++++++++
mlir/test/Dialect/Vector/canonicalize.mlir | 12 ++++++++++++
2 files changed, 23 insertions(+)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 027ef3605aeba46..f84a574c4634fc3 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1244,6 +1244,14 @@ bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
}
LogicalResult vector::ExtractOp::verify() {
+ // Note: This check must come before getMixedPosition() to prevent a crash.
+ auto dynamicMarkersCount =
+ llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
+ if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
+ return emitOpError(
+ "mismatch between dynamic and static positions (kDynamic marker but no "
+ "corresponding dynamic position) -- this can only happen due to an "
+ "incorrect/fold rewrite");
auto position = getMixedPosition();
if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
return emitOpError(
@@ -1285,6 +1293,9 @@ static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
globalPosition.append(extrPos.rbegin(), extrPos.rend());
while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
currentOp = nextOp;
+ // TODO: Canonicalization for dynamic position not implemented yet.
+ if (currentOp.hasDynamicPosition())
+ return failure();
ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
globalPosition.append(extrPos.rbegin(), extrPos.rend());
}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 05615b96ae6d69f..924886c50030967 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1693,6 +1693,18 @@ func.func @extract_insert_chain(%a: vector<2x16xf32>, %b: vector<12x8x16xf32>, %
// -----
+// CHECK-LABEL: extract_from_extract_chain_should_not_fold_dynamic_extracts
+// CHECK-SAME: (%[[VEC:.*]]: vector<2x4xf32>, %[[IDX:.*]]: index)
+// CHECK: %[[A:.*]] = vector.extract %[[VEC]][%[[IDX]]] : vector<4xf32> from vector<2x4xf32>
+// CHECK: %[[B:.*]] = vector.extract %[[A]][1] : f32 from vector<4xf32>
+func.func @extract_from_extract_chain_should_not_fold_dynamic_extracts(%v: vector<2x4xf32>, %index: index) -> f32 {
+ %0 = vector.extract %v[%index] : vector<4xf32> from vector<2x4xf32>
+ %1 = vector.extract %0[1] : f32 from vector<4xf32>
+ return %1 : f32
+}
+
+// -----
+
// CHECK-LABEL: extract_extract_strided2
// CHECK-SAME: %[[A:.*]]: vector<2x4xf32>
// CHECK: %[[V:.*]] = vector.extract %[[A]][1] : vector<4xf32> from vector<2x4xf32>
More information about the Mlir-commits
mailing list