[Mlir-commits] [mlir] [mlir][vector] Advance support for extract/insert in the dynamic-position case. (PR #121631)
lonely eagle
llvmlistbot at llvm.org
Sat Jan 4 00:40:44 PST 2025
https://github.com/linuxlonelyeagle created https://github.com/llvm/llvm-project/pull/121631
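Extend the lowering of `vector.extract` and `vector.insert` ops that use dynamic positions. When a dynamic position is produced by a constant op (`arith.constant` or `llvm.mlir.constant`), it is now folded into a static position. This also works when the constant is reached through a chain of `builtin.unrealized_conversion_cast` ops.
See the tests for the specific changes. The code added to the extract and insert patterns is duplicated and should be factored into a shared helper function. I'm not sure where that helper should live or what to name it, so suggestions are welcome. Thank you.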
>From fa488d57dba9917f80a431d65e7e24a9b2b4bde5 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Fri, 3 Jan 2025 21:20:54 +0800
Subject: [PATCH 1/2] support extract.
---
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 23 +++++++++++
.../VectorToLLVM/vector-to-llvm.mlir | 40 +++++++++++++++++++
2 files changed, 63 insertions(+)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 9657f583c375bb..4a96c689c0e403 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1096,6 +1096,29 @@ class VectorExtractOpConversion
SmallVector<OpFoldResult> positionVec = getMixedValues(
adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
+ for (unsigned idx = 0; idx < positionVec.size(); ++idx) {
+ if (auto position = llvm::dyn_cast<Value>(positionVec[idx])) {
+ auto defOp = position.getDefiningOp();
+ while (true) {
+ if (!defOp) {
+ break;
+ }
+ if (llvm::isa<arith::ConstantOp, LLVM::ConstantOp>(defOp)) {
+ Attribute value =
+ defOp->getAttr(arith::ConstantOp::getAttributeNames()[0]);
+ positionVec[idx] = OpFoldResult{
+ rewriter.getI64IntegerAttr(cast<IntegerAttr>(value).getInt())};
+ break;
+ } else if (auto unrealizedCastOp =
+ llvm::dyn_cast<UnrealizedConversionCastOp>(defOp)) {
+ defOp = unrealizedCastOp.getOperand(0).getDefiningOp();
+ } else {
+ break;
+ }
+ }
+ }
+ }
+
// The Vector -> LLVM lowering models N-D vectors as nested aggregates of
// 1-d vectors. This nesting is modeled using arrays. We do this conversion
// from a N-d vector extract to a nested aggregate vector extract in two
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index f95e943250bd44..51a55643bafa87 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -4094,3 +4094,43 @@ func.func @step_scalable() -> vector<[4]xindex> {
%0 = vector.step : vector<[4]xindex>
return %0 : vector<[4]xindex>
}
+
+// -----
+
+// CHECK-LABEL: @extract_arith_constnt
+func.func @extract_arith_constnt() -> i32 {
+ %v = arith.constant dense<0> : vector<32x1xi32>
+ %c_0 = arith.constant 0 : index
+ %elem = vector.extract %v[%c_0, %c_0] : i32 from vector<32x1xi32>
+ return %elem : i32
+}
+
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<0> : vector<32x1xi32>
+// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<32x1xi32> to !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_3:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[VAL_5:.*]] = llvm.extractelement %[[VAL_3]]{{\[}}%[[VAL_4]] : i64] : vector<1xi32>
+// CHECK: return %[[VAL_5]] : i32
+
+// -----
+
+// CHECK-LABEL: @extract_llvm_constnt()
+
+module {
+ func.func @extract_llvm_constnt() -> i32 {
+ %0 = llvm.mlir.constant(dense<0> : vector<32x1xi32>) : !llvm.array<32 x vector<1xi32>>
+ %1 = builtin.unrealized_conversion_cast %0 : !llvm.array<32 x vector<1xi32>> to vector<32x1xi32>
+ %2 = llvm.mlir.constant(0 : index) : i64
+ %3 = builtin.unrealized_conversion_cast %2 : i64 to index
+ %4 = vector.extract %1[%3, %3] : i32 from vector<32x1xi32>
+ return %4 : i32
+ }
+}
+
+// CHECK: %[[VAL_0:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(dense<0> : vector<32x1xi32>) : !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[VAL_4:.*]] = llvm.extractelement %[[VAL_2]]{{\[}}%[[VAL_3]] : i64] : vector<1xi32>
+// CHECK: return %[[VAL_4]] : i32
>From 567412bc28754ba90d5363b8c2038b9c5589c499 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Sat, 4 Jan 2025 16:35:27 +0800
Subject: [PATCH 2/2] support insert.
---
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 33 +++++++++---
.../VectorToLLVM/vector-to-llvm.mlir | 50 +++++++++++++++++++
2 files changed, 76 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 4a96c689c0e403..4af03126fa1edd 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1099,10 +1099,7 @@ class VectorExtractOpConversion
for (unsigned idx = 0; idx < positionVec.size(); ++idx) {
if (auto position = llvm::dyn_cast<Value>(positionVec[idx])) {
auto defOp = position.getDefiningOp();
- while (true) {
- if (!defOp) {
- break;
- }
+ while (defOp) {
if (llvm::isa<arith::ConstantOp, LLVM::ConstantOp>(defOp)) {
Attribute value =
defOp->getAttr(arith::ConstantOp::getAttributeNames()[0]);
@@ -1254,6 +1251,25 @@ class VectorInsertOpConversion
SmallVector<OpFoldResult> positionVec = getMixedValues(
adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
+ for (unsigned idx = 0; idx < positionVec.size(); ++idx) {
+ if (auto position = llvm::dyn_cast<Value>(positionVec[idx])) {
+ auto defOp = position.getDefiningOp();
+ while (defOp) {
+ if (llvm::isa<arith::ConstantOp, LLVM::ConstantOp>(defOp)) {
+ Attribute value =
+ defOp->getAttr(arith::ConstantOp::getAttributeNames()[0]);
+ positionVec[idx] = OpFoldResult{
+ rewriter.getI64IntegerAttr(cast<IntegerAttr>(value).getInt())};
+ break;
+ } else if (auto unrealizedCastOp =
+ llvm::dyn_cast<UnrealizedConversionCastOp>(defOp)) {
+ defOp = unrealizedCastOp.getOperand(0).getDefiningOp();
+ } else {
+ break;
+ }
+ }
+ }
+ }
// Overwrite entire vector with value. Should be handled by folder, but
// just to be safe.
@@ -1265,8 +1281,9 @@ class VectorInsertOpConversion
// One-shot insertion of a vector into an array (only requires insertvalue).
if (isa<VectorType>(sourceType)) {
- if (insertOp.hasDynamicPosition())
+ if (!llvm::all_of(position, llvm::IsaPred<Attribute>)) {
return failure();
+ }
Value inserted = rewriter.create<LLVM::InsertValueOp>(
loc, adaptor.getDest(), adaptor.getSource(), getAsIntegers(position));
@@ -1278,8 +1295,9 @@ class VectorInsertOpConversion
Value extracted = adaptor.getDest();
auto oneDVectorType = destVectorType;
if (position.size() > 1) {
- if (insertOp.hasDynamicPosition())
+ if (!llvm::all_of(position, llvm::IsaPred<Attribute>)) {
return failure();
+ }
oneDVectorType = reducedVectorTypeBack(destVectorType);
extracted = rewriter.create<LLVM::ExtractValueOp>(
@@ -1293,8 +1311,9 @@ class VectorInsertOpConversion
// Potential insertion of resulting 1-D vector into array.
if (position.size() > 1) {
- if (insertOp.hasDynamicPosition())
+ if (!llvm::all_of(position, llvm::IsaPred<Attribute>)) {
return failure();
+ }
inserted = rewriter.create<LLVM::InsertValueOp>(
loc, adaptor.getDest(), inserted,
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 51a55643bafa87..d16d78556da106 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -4134,3 +4134,53 @@ module {
// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[VAL_4:.*]] = llvm.extractelement %[[VAL_2]]{{\[}}%[[VAL_3]] : i64] : vector<1xi32>
// CHECK: return %[[VAL_4]] : i32
+
+// -----
+
+// CHECK-LABEL: @insert_arith_constnt()
+
+func.func @insert_arith_constnt() -> vector<32x1xi32> {
+ %v = arith.constant dense<0> : vector<32x1xi32>
+ %c_0 = arith.constant 0 : index
+ %c_1 = arith.constant 1 : i32
+ %v_1 = vector.insert %c_1, %v[%c_0, %c_0] : i32 into vector<32x1xi32>
+ return %v_1 : vector<32x1xi32>
+}
+
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<0> : vector<32x1xi32>
+// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<32x1xi32> to !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_3:.*]] = arith.constant 1 : i32
+// CHECK: %[[VAL_4:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_5:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[VAL_6:.*]] = llvm.insertelement %[[VAL_3]], %[[VAL_4]]{{\[}}%[[VAL_5]] : i64] : vector<1xi32>
+// CHECK: %[[VAL_7:.*]] = llvm.insertvalue %[[VAL_6]], %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_8:.*]] = builtin.unrealized_conversion_cast %[[VAL_7]] : !llvm.array<32 x vector<1xi32>> to vector<32x1xi32>
+// CHECK: return %[[VAL_8]] : vector<32x1xi32>
+
+// -----
+
+// CHECK-LABEL: @insert_llvm_constnt()
+
+module {
+ func.func @insert_llvm_constnt() -> vector<32x1xi32> {
+ %0 = llvm.mlir.constant(dense<0> : vector<32x1xi32>) : !llvm.array<32 x vector<1xi32>>
+ %1 = builtin.unrealized_conversion_cast %0 : !llvm.array<32 x vector<1xi32>> to vector<32x1xi32>
+ %2 = llvm.mlir.constant(0 : index) : i64
+ %3 = builtin.unrealized_conversion_cast %2 : i64 to index
+ %4 = llvm.mlir.constant(1 : i32) : i32
+ %5 = vector.insert %4, %1 [%3, %3] : i32 into vector<32x1xi32>
+ return %5 : vector<32x1xi32>
+ }
+}
+
+// CHECK: %[[VAL_0:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(dense<0> : vector<32x1xi32>) : !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_3:.*]] = llvm.extractvalue %[[VAL_2]][0] : !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[VAL_5:.*]] = llvm.insertelement %[[VAL_0]], %[[VAL_3]]{{\[}}%[[VAL_4]] : i64] : vector<1xi32>
+// CHECK: %[[VAL_6:.*]] = llvm.insertvalue %[[VAL_5]], %[[VAL_2]][0] : !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_7:.*]] = builtin.unrealized_conversion_cast %[[VAL_6]] : !llvm.array<32 x vector<1xi32>> to vector<32x1xi32>
+// CHECK: return %[[VAL_7]] : vector<32x1xi32>
+// CHECK: }
More information about the Mlir-commits
mailing list