[Mlir-commits] [mlir] [mlir][VectorOps] Don't fold extract chains that include dynamic indices (PR #68333)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Oct 5 09:33:29 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
<details>
<summary>Changes</summary>
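Folding extract chains that include dynamic indices is not yet supported. Previously, attempting it led to a confusing crash: the fold created an extract op that had a kDynamic marker in its static position but no corresponding dynamic position operand. The verifier has also been updated to check for this mismatch and to hint at where the problem most likely originates (an incorrect fold or rewrite).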
---
Full diff: https://github.com/llvm/llvm-project/pull/68333.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+11)
- (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+12)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 027ef3605aeba46..f84a574c4634fc3 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1244,6 +1244,14 @@ bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
}
LogicalResult vector::ExtractOp::verify() {
+ // Note: This check must come before getMixedPosition() to prevent a crash.
+ auto dynamicMarkersCount =
+ llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
+ if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
+ return emitOpError(
+ "mismatch between dynamic and static positions (kDynamic marker but no "
+ "corresponding dynamic position) -- this can only happen due to an "
+ "incorrect/fold rewrite");
auto position = getMixedPosition();
if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
return emitOpError(
@@ -1285,6 +1293,9 @@ static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
globalPosition.append(extrPos.rbegin(), extrPos.rend());
while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
currentOp = nextOp;
+ // TODO: Canonicalization for dynamic position not implemented yet.
+ if (currentOp.hasDynamicPosition())
+ return failure();
ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
globalPosition.append(extrPos.rbegin(), extrPos.rend());
}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 05615b96ae6d69f..924886c50030967 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1693,6 +1693,18 @@ func.func @extract_insert_chain(%a: vector<2x16xf32>, %b: vector<12x8x16xf32>, %
// -----
+// CHECK-LABEL: extract_from_extract_chain_should_not_fold_dynamic_extracts
+// CHECK-SAME: (%[[VEC:.*]]: vector<2x4xf32>, %[[IDX:.*]]: index)
+// CHECK: %[[A:.*]] = vector.extract %[[VEC]][%[[IDX]]] : vector<4xf32> from vector<2x4xf32>
+// CHECK: %[[B:.*]] = vector.extract %[[A]][1] : f32 from vector<4xf32>
+func.func @extract_from_extract_chain_should_not_fold_dynamic_extracts(%v: vector<2x4xf32>, %index: index) -> f32 {
+ %0 = vector.extract %v[%index] : vector<4xf32> from vector<2x4xf32>
+ %1 = vector.extract %0[1] : f32 from vector<4xf32>
+ return %1 : f32
+}
+
+// -----
+
// CHECK-LABEL: extract_extract_strided2
// CHECK-SAME: %[[A:.*]]: vector<2x4xf32>
// CHECK: %[[V:.*]] = vector.extract %[[A]][1] : vector<4xf32> from vector<2x4xf32>
``````````
</details>
https://github.com/llvm/llvm-project/pull/68333
More information about the Mlir-commits mailing list