[Mlir-commits] [mlir] [mlir][Vector] Fix vector.extract lowering to llvm for 0-d vectors (PR #117731)
Kunwar Grover
llvmlistbot at llvm.org
Fri Nov 29 07:33:40 PST 2024
https://github.com/Groverkss updated https://github.com/llvm/llvm-project/pull/117731
>From dc1a9a5091bac79cfc6ffa45bf8def7987007b30 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Tue, 26 Nov 2024 15:55:54 +0000
Subject: [PATCH 1/4] [mlir][Vector] Fix vector.extract lowering to llvm for
0-d vectors
---
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 69 ++++++++++---------
.../VectorToLLVM/vector-to-llvm.mlir | 53 ++++++++++++--
2 files changed, 84 insertions(+), 38 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 58ca84c8d7bca6..3f47b20cdb577b 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1096,43 +1096,50 @@ class VectorExtractOpConversion
SmallVector<OpFoldResult> positionVec = getMixedValues(
adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
- // Extract entire vector. Should be handled by folder, but just to be safe.
- ArrayRef<OpFoldResult> position(positionVec);
- if (position.empty()) {
- rewriter.replaceOp(extractOp, adaptor.getVector());
- return success();
- }
-
- // One-shot extraction of vector from array (only requires extractvalue).
- // Except for extracting 1-element vectors.
- if (isa<VectorType>(resultType) &&
- position.size() !=
- static_cast<size_t>(extractOp.getSourceVectorType().getRank())) {
- if (extractOp.hasDynamicPosition())
- return failure();
-
- Value extracted = rewriter.create<LLVM::ExtractValueOp>(
- loc, adaptor.getVector(), getAsIntegers(position));
- rewriter.replaceOp(extractOp, extracted);
- return success();
- }
+ // Determine if we need to extract a scalar as the result. We extract
+ // a scalar if the extract is full rank i.e. the number of indices is equal
+ // to source vector rank.
+ bool isScalarExtract =
+ positionVec.size() == extractOp.getSourceVectorType().getRank();
+ // Determine if we need to extract a slice out of the original vector. We
+ // always need to extract a slice if the input rank >= 2.
+ bool isSlicingExtract = extractOp.getSourceVectorType().getRank() >= 2;
- // Potential extraction of 1-D vector from array.
Value extracted = adaptor.getVector();
- if (position.size() > 1) {
- if (extractOp.hasDynamicPosition())
+ if (isSlicingExtract) {
+ ArrayRef<OpFoldResult> position(positionVec);
+ if (isScalarExtract) {
+ // If we are extracting a scalar from the returned slice, we need to
+ // extract a N-1 D slice.
+ position = position.drop_back();
+ }
+ // llvm.extractvalue does not support dynamic dimensions.
+ if (!llvm::all_of(position,
+ [](OpFoldResult x) { return isa<Attribute>(x); })) {
return failure();
+ }
+ extracted = rewriter.create<LLVM::ExtractValueOp>(
+ loc, extracted, getAsIntegers(position));
+ }
- SmallVector<int64_t> nMinusOnePosition =
- getAsIntegers(position.drop_back());
- extracted = rewriter.create<LLVM::ExtractValueOp>(loc, extracted,
- nMinusOnePosition);
+ if (isScalarExtract) {
+ Value position;
+ if (positionVec.empty()) {
+ // A scalar extract with no position is a 0-D vector extract. The LLVM
+ // type converter converts 0-D vectors to 1-D vectors, so we need to add
+ // a constant position.
+ auto idxType = rewriter.getIndexType();
+ position = rewriter.create<LLVM::ConstantOp>(
+ loc, typeConverter->convertType(idxType),
+ rewriter.getIntegerAttr(idxType, 0));
+ } else {
+ position = getAsLLVMValue(rewriter, loc, positionVec.back());
+ }
+ extracted =
+ rewriter.create<LLVM::ExtractElementOp>(loc, extracted, position);
}
- Value lastPosition = getAsLLVMValue(rewriter, loc, position.back());
- // Remaining extraction of element from 1-D LLVM vector.
- rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(extractOp, extracted,
- lastPosition);
+ rewriter.replaceOp(extractOp, extracted);
return success();
}
};
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 1c42538cf85912..58757bfd7951c6 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1290,26 +1290,65 @@ func.func @extract_scalar_from_vec_1d_f32_dynamic_idx_scalable(%arg0: vector<[16
// -----
-func.func @extract_scalar_from_vec_2d_f32_dynamic_idx(%arg0: vector<1x16xf32>, %arg1: index) -> f32 {
+func.func @extract_scalar_from_vec_2d_f32_inner_dynamic_idx(%arg0: vector<1x16xf32>, %arg1: index) -> f32 {
%0 = vector.extract %arg0[0, %arg1]: f32 from vector<1x16xf32>
return %0 : f32
}
-// Multi-dim vectors are not supported but this test shouldn't crash.
+// Multi-dim vectors are supported if the inner most dimension is dynamic.
-// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_dynamic_idx(
-// CHECK: vector.extract
+// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_inner_dynamic_idx(
+// CHECK: llvm.extractvalue
+// CHECK: llvm.extractelement
-func.func @extract_scalar_from_vec_2d_f32_dynamic_idx_scalable(%arg0: vector<1x[16]xf32>, %arg1: index) -> f32 {
+func.func @extract_scalar_from_vec_2d_f32_inner_dynamic_idx_scalable(%arg0: vector<1x[16]xf32>, %arg1: index) -> f32 {
%0 = vector.extract %arg0[0, %arg1]: f32 from vector<1x[16]xf32>
return %0 : f32
}
-// Multi-dim vectors are not supported but this test shouldn't crash.
+// Multi-dim vectors are supported if the inner most dimension is dynamic.
+
+// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_inner_dynamic_idx_scalable(
+// CHECK: llvm.extractvalue
+// CHECK: llvm.extractelement
+
+// -----
-// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_dynamic_idx_scalable(
+func.func @extract_scalar_from_vec_2d_f32_outer_dynamic_idx(%arg0: vector<1x16xf32>, %arg1: index) -> f32 {
+ %0 = vector.extract %arg0[%arg1, 0]: f32 from vector<1x16xf32>
+ return %0 : f32
+}
+
+// Multi-dim vectors with outer indices as dynamic are not supported, but it shouldn't crash.
+
+// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_outer_dynamic_idx(
// CHECK: vector.extract
+func.func @extract_scalar_from_vec_2d_f32_outer_dynamic_idx_scalable(%arg0: vector<1x[16]xf32>, %arg1: index) -> f32 {
+ %0 = vector.extract %arg0[%arg1, 0]: f32 from vector<1x[16]xf32>
+ return %0 : f32
+}
+
+// Multi-dim vectors with outer dimension as dynamic are not supported, but it
+// shouldn't crash.
+
+// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_outer_dynamic_idx_scalable(
+// CHECK: vector.extract
+
+// -----
+
+func.func @extract_scalar_from_vec_0d_index(%arg0: vector<index>) -> index {
+ %0 = vector.extract %arg0[]: index from vector<index>
+ return %0 : index
+}
+// CHECK-LABEL: @extract_scalar_from_vec_0d_index(
+// CHECK-SAME: %[[A:.*]]: vector<index>)
+// CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<index> to vector<1xi64>
+// CHECK: %[[T1:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[T2:.*]] = llvm.extractelement %[[T0]][%[[T1]] : i64] : vector<1xi64>
+// CHECK: %[[T3:.*]] = builtin.unrealized_conversion_cast %[[T2]] : i64 to index
+// CHECK: return %[[T3]] : index
+
// -----
func.func @insertelement_into_vec_0d_f32(%arg0: f32, %arg1: vector<f32>) -> vector<f32> {
>From 824be432089b8bae4e91e9f80bf303b14a5b75ab Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Tue, 26 Nov 2024 16:04:13 +0000
Subject: [PATCH 2/4] Document the two-step extract lowering; fix signed/unsigned comparison
---
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 16 +++++++++++-----
1 file changed, 11 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 3f47b20cdb577b..6f83c5e1bd761d 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1096,14 +1096,20 @@ class VectorExtractOpConversion
SmallVector<OpFoldResult> positionVec = getMixedValues(
adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
- // Determine if we need to extract a scalar as the result. We extract
- // a scalar if the extract is full rank i.e. the number of indices is equal
- // to source vector rank.
- bool isScalarExtract =
- positionVec.size() == extractOp.getSourceVectorType().getRank();
+ // The LLVM lowering models multi dimension vectors as stacked 1-d vectors.
+ // The stacking is modeled using arrays. We do this conversion from a
+ // N-d vector extract to stacked 1-d vector extract in two steps:
+ // - Extract a 1-d vector or a stack of 1-d vectors (llvm.extractvalue)
+ // - Extract a scalar out of the 1-d vector if needed (llvm.extractelement)
+
// Determine if we need to extract a slice out of the original vector. We
// always need to extract a slice if the input rank >= 2.
bool isSlicingExtract = extractOp.getSourceVectorType().getRank() >= 2;
+ // Determine if we need to extract a scalar as the result. We extract
+ // a scalar if the extract is full rank i.e. the number of indices is equal
+ // to source vector rank.
+ bool isScalarExtract = static_cast<int64_t>(positionVec.size()) ==
+ extractOp.getSourceVectorType().getRank();
Value extracted = adaptor.getVector();
if (isSlicingExtract) {
>From f4a3fb41d438722740140ef1acbe07a9b69e4138 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Fri, 29 Nov 2024 14:58:16 +0000
Subject: [PATCH 3/4] Address review comments: rename flags, document aggregate lowering, pad 0-d positions
---
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 70 ++++++++++---------
.../VectorToLLVM/vector-to-llvm.mlir | 6 +-
2 files changed, 39 insertions(+), 37 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 6f83c5e1bd761d..6b3f7ca54cf75c 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1096,53 +1096,55 @@ class VectorExtractOpConversion
SmallVector<OpFoldResult> positionVec = getMixedValues(
adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
- // The LLVM lowering models multi dimension vectors as stacked 1-d vectors.
- // The stacking is modeled using arrays. We do this conversion from a
- // N-d vector extract to stacked 1-d vector extract in two steps:
- // - Extract a 1-d vector or a stack of 1-d vectors (llvm.extractvalue)
- // - Extract a scalar out of the 1-d vector if needed (llvm.extractelement)
-
- // Determine if we need to extract a slice out of the original vector. We
- // always need to extract a slice if the input rank >= 2.
- bool isSlicingExtract = extractOp.getSourceVectorType().getRank() >= 2;
+ // The Vector -> LLVM lowering models N-D vectors as nested aggregates of
+ // 1-d vectors. This nesting is modeled using arrays. We do this conversion
+ // from an N-d vector extract to a nested aggregate vector extract in two
+ // steps:
+ // - Extract a member from the nested aggregate. The result can be
+ // a lower rank nested aggregate or a vector (1-D). This is done using
+ // `llvm.extractvalue`.
+ // - Extract a scalar out of the vector if needed. This is done using
+ // `llvm.extractelement`.
+
+ // Determine if we need to extract a member out of the aggregate. We
+ // always need to extract a member if the input rank >= 2.
+ bool extractsAggregate = extractOp.getSourceVectorType().getRank() >= 2;
// Determine if we need to extract a scalar as the result. We extract
- // a scalar if the extract is full rank i.e. the number of indices is equal
- // to source vector rank.
- bool isScalarExtract = static_cast<int64_t>(positionVec.size()) ==
- extractOp.getSourceVectorType().getRank();
+ // a scalar if the extract is full rank, i.e., the number of indices is
+ // equal to source vector rank.
+ bool extractsScalar = static_cast<int64_t>(positionVec.size()) ==
+ extractOp.getSourceVectorType().getRank();
+
+ // Since the LLVM type converter converts 0-d vectors to 1-d vectors, we
+ // need to add a position for this change.
+ if (extractOp.getSourceVectorType().getRank() == 0) {
+ auto idxType = rewriter.getIndexType();
+ Value position = rewriter.create<LLVM::ConstantOp>(
+ loc, typeConverter->convertType(idxType),
+ rewriter.getIntegerAttr(idxType, 0));
+ positionVec.push_back(position);
+ }
Value extracted = adaptor.getVector();
- if (isSlicingExtract) {
+ if (extractsScalar) {
ArrayRef<OpFoldResult> position(positionVec);
- if (isScalarExtract) {
- // If we are extracting a scalar from the returned slice, we need to
- // extract a N-1 D slice.
+ if (extractsAggregate) {
+ // If we are extracting a scalar from the extracted member, we drop
+ // the last index, which will be used to extract the scalar out of the
+ // vector.
position = position.drop_back();
}
// llvm.extractvalue does not support dynamic dimensions.
- if (!llvm::all_of(position,
- [](OpFoldResult x) { return isa<Attribute>(x); })) {
+ if (!llvm::all_of(position, llvm::IsaPred<Attribute>)) {
return failure();
}
extracted = rewriter.create<LLVM::ExtractValueOp>(
loc, extracted, getAsIntegers(position));
}
- if (isScalarExtract) {
- Value position;
- if (positionVec.empty()) {
- // A scalar extract with no position is a 0-D vector extract. The LLVM
- // type converter converts 0-D vectors to 1-D vectors, so we need to add
- // a constant position.
- auto idxType = rewriter.getIndexType();
- position = rewriter.create<LLVM::ConstantOp>(
- loc, typeConverter->convertType(idxType),
- rewriter.getIntegerAttr(idxType, 0));
- } else {
- position = getAsLLVMValue(rewriter, loc, positionVec.back());
- }
- extracted =
- rewriter.create<LLVM::ExtractElementOp>(loc, extracted, position);
+ if (extractsScalar) {
+ extracted = rewriter.create<LLVM::ExtractElementOp>(
+ loc, extracted, getAsLLVMValue(rewriter, loc, positionVec.back()));
}
rewriter.replaceOp(extractOp, extracted);
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 58757bfd7951c6..e3727d6412849a 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1295,7 +1295,7 @@ func.func @extract_scalar_from_vec_2d_f32_inner_dynamic_idx(%arg0: vector<1x16xf
return %0 : f32
}
-// Multi-dim vectors are supported if the inner most dimension is dynamic.
+// Multi-dim vectors are supported if the innermost index is dynamic.
// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_inner_dynamic_idx(
// CHECK: llvm.extractvalue
@@ -1306,7 +1306,7 @@ func.func @extract_scalar_from_vec_2d_f32_inner_dynamic_idx_scalable(%arg0: vect
return %0 : f32
}
-// Multi-dim vectors are supported if the inner most dimension is dynamic.
+// Multi-dim vectors are supported if the innermost index is dynamic.
// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_inner_dynamic_idx_scalable(
// CHECK: llvm.extractvalue
@@ -1329,7 +1329,7 @@ func.func @extract_scalar_from_vec_2d_f32_outer_dynamic_idx_scalable(%arg0: vect
return %0 : f32
}
-// Multi-dim vectors with outer dimension as dynamic are not supported, but it
+// Multi-dim vectors with outer indices as dynamic are not supported, but it
// shouldn't crash.
// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_outer_dynamic_idx_scalable(
>From 9ba9d44bf981e00a9a6cfc7608548863ca9311b6 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Fri, 29 Nov 2024 15:33:06 +0000
Subject: [PATCH 4/4] Fix bugs: restore extractsAggregate/extractsScalar nesting and use a zero attribute for the 0-d position
---
.../Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 11 ++++-------
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir | 2 +-
2 files changed, 5 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 6b3f7ca54cf75c..a9a07c323c7358 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1118,17 +1118,14 @@ class VectorExtractOpConversion
// Since the LLVM type converter converts 0-d vectors to 1-d vectors, we
// need to add a position for this change.
if (extractOp.getSourceVectorType().getRank() == 0) {
- auto idxType = rewriter.getIndexType();
- Value position = rewriter.create<LLVM::ConstantOp>(
- loc, typeConverter->convertType(idxType),
- rewriter.getIntegerAttr(idxType, 0));
- positionVec.push_back(position);
+ Type idxType = typeConverter->convertType(rewriter.getIndexType());
+ positionVec.push_back(rewriter.getZeroAttr(idxType));
}
Value extracted = adaptor.getVector();
- if (extractsScalar) {
+ if (extractsAggregate) {
ArrayRef<OpFoldResult> position(positionVec);
- if (extractsAggregate) {
+ if (extractsScalar) {
// If we are extracting a scalar from the extracted member, we drop
// the last index, which will be used to extract the scalar out of the
// vector.
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index e3727d6412849a..197f0cb2a568e6 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1344,7 +1344,7 @@ func.func @extract_scalar_from_vec_0d_index(%arg0: vector<index>) -> index {
// CHECK-LABEL: @extract_scalar_from_vec_0d_index(
// CHECK-SAME: %[[A:.*]]: vector<index>)
// CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<index> to vector<1xi64>
-// CHECK: %[[T1:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[T1:.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[T2:.*]] = llvm.extractelement %[[T0]][%[[T1]] : i64] : vector<1xi64>
// CHECK: %[[T3:.*]] = builtin.unrealized_conversion_cast %[[T2]] : i64 to index
// CHECK: return %[[T3]] : index
More information about the Mlir-commits
mailing list