[Mlir-commits] [mlir] [mlir][vector]advance support extract insert under dynamic case. (PR #121631)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Jan 4 00:41:17 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: lonely eagle (linuxlonelyeagle)
<details>
<summary>Changes</summary>
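Improve support for `vector.extract` and `vector.insert` when their positions are given as dynamic (SSA) values. If such a position is defined by a constant, it is now treated as a static position during lowering to LLVM.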
See the tests for the specific changes. The duplicated code should be factored out into a helper function, but I'm not sure where that helper should live. Suggestions are welcome, thank you.
---
Full diff: https://github.com/llvm/llvm-project/pull/121631.diff
2 Files Affected:
- (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+45-3)
- (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+90)
``````````diff
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 9657f583c375bb..4af03126fa1edd 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1096,6 +1096,26 @@ class VectorExtractOpConversion
SmallVector<OpFoldResult> positionVec = getMixedValues(
adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
+ for (unsigned idx = 0; idx < positionVec.size(); ++idx) {
+ if (auto position = llvm::dyn_cast<Value>(positionVec[idx])) {
+ auto defOp = position.getDefiningOp();
+ while (defOp) {
+ if (llvm::isa<arith::ConstantOp, LLVM::ConstantOp>(defOp)) {
+ Attribute value =
+ defOp->getAttr(arith::ConstantOp::getAttributeNames()[0]);
+ positionVec[idx] = OpFoldResult{
+ rewriter.getI64IntegerAttr(cast<IntegerAttr>(value).getInt())};
+ break;
+ } else if (auto unrealizedCastOp =
+ llvm::dyn_cast<UnrealizedConversionCastOp>(defOp)) {
+ defOp = unrealizedCastOp.getOperand(0).getDefiningOp();
+ } else {
+ break;
+ }
+ }
+ }
+ }
+
// The Vector -> LLVM lowering models N-D vectors as nested aggregates of
// 1-d vectors. This nesting is modeled using arrays. We do this conversion
// from a N-d vector extract to a nested aggregate vector extract in two
@@ -1231,6 +1251,25 @@ class VectorInsertOpConversion
SmallVector<OpFoldResult> positionVec = getMixedValues(
adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
+ for (unsigned idx = 0; idx < positionVec.size(); ++idx) {
+ if (auto position = llvm::dyn_cast<Value>(positionVec[idx])) {
+ auto defOp = position.getDefiningOp();
+ while (defOp) {
+ if (llvm::isa<arith::ConstantOp, LLVM::ConstantOp>(defOp)) {
+ Attribute value =
+ defOp->getAttr(arith::ConstantOp::getAttributeNames()[0]);
+ positionVec[idx] = OpFoldResult{
+ rewriter.getI64IntegerAttr(cast<IntegerAttr>(value).getInt())};
+ break;
+ } else if (auto unrealizedCastOp =
+ llvm::dyn_cast<UnrealizedConversionCastOp>(defOp)) {
+ defOp = unrealizedCastOp.getOperand(0).getDefiningOp();
+ } else {
+ break;
+ }
+ }
+ }
+ }
// Overwrite entire vector with value. Should be handled by folder, but
// just to be safe.
@@ -1242,8 +1281,9 @@ class VectorInsertOpConversion
// One-shot insertion of a vector into an array (only requires insertvalue).
if (isa<VectorType>(sourceType)) {
- if (insertOp.hasDynamicPosition())
+ if (!llvm::all_of(position, llvm::IsaPred<Attribute>)) {
return failure();
+ }
Value inserted = rewriter.create<LLVM::InsertValueOp>(
loc, adaptor.getDest(), adaptor.getSource(), getAsIntegers(position));
@@ -1255,8 +1295,9 @@ class VectorInsertOpConversion
Value extracted = adaptor.getDest();
auto oneDVectorType = destVectorType;
if (position.size() > 1) {
- if (insertOp.hasDynamicPosition())
+ if (!llvm::all_of(position, llvm::IsaPred<Attribute>)) {
return failure();
+ }
oneDVectorType = reducedVectorTypeBack(destVectorType);
extracted = rewriter.create<LLVM::ExtractValueOp>(
@@ -1270,8 +1311,9 @@ class VectorInsertOpConversion
// Potential insertion of resulting 1-D vector into array.
if (position.size() > 1) {
- if (insertOp.hasDynamicPosition())
+ if (!llvm::all_of(position, llvm::IsaPred<Attribute>)) {
return failure();
+ }
inserted = rewriter.create<LLVM::InsertValueOp>(
loc, adaptor.getDest(), inserted,
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index f95e943250bd44..d16d78556da106 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -4094,3 +4094,93 @@ func.func @step_scalable() -> vector<[4]xindex> {
%0 = vector.step : vector<[4]xindex>
return %0 : vector<[4]xindex>
}
+
+// -----
+
+// CHECK-LABEL: @extract_arith_constnt
+func.func @extract_arith_constnt() -> i32 {
+ %v = arith.constant dense<0> : vector<32x1xi32>
+ %c_0 = arith.constant 0 : index
+ %elem = vector.extract %v[%c_0, %c_0] : i32 from vector<32x1xi32>
+ return %elem : i32
+}
+
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<0> : vector<32x1xi32>
+// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<32x1xi32> to !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_3:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[VAL_5:.*]] = llvm.extractelement %[[VAL_3]]{{\[}}%[[VAL_4]] : i64] : vector<1xi32>
+// CHECK: return %[[VAL_5]] : i32
+
+// -----
+
+// CHECK-LABEL: @extract_llvm_constnt()
+
+module {
+ func.func @extract_llvm_constnt() -> i32 {
+ %0 = llvm.mlir.constant(dense<0> : vector<32x1xi32>) : !llvm.array<32 x vector<1xi32>>
+ %1 = builtin.unrealized_conversion_cast %0 : !llvm.array<32 x vector<1xi32>> to vector<32x1xi32>
+ %2 = llvm.mlir.constant(0 : index) : i64
+ %3 = builtin.unrealized_conversion_cast %2 : i64 to index
+ %4 = vector.extract %1[%3, %3] : i32 from vector<32x1xi32>
+ return %4 : i32
+ }
+}
+
+// CHECK: %[[VAL_0:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(dense<0> : vector<32x1xi32>) : !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[VAL_4:.*]] = llvm.extractelement %[[VAL_2]]{{\[}}%[[VAL_3]] : i64] : vector<1xi32>
+// CHECK: return %[[VAL_4]] : i32
+
+// -----
+
+// CHECK-LABEL: @insert_arith_constnt()
+
+func.func @insert_arith_constnt() -> vector<32x1xi32> {
+ %v = arith.constant dense<0> : vector<32x1xi32>
+ %c_0 = arith.constant 0 : index
+ %c_1 = arith.constant 1 : i32
+ %v_1 = vector.insert %c_1, %v[%c_0, %c_0] : i32 into vector<32x1xi32>
+ return %v_1 : vector<32x1xi32>
+}
+
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<0> : vector<32x1xi32>
+// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<32x1xi32> to !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_3:.*]] = arith.constant 1 : i32
+// CHECK: %[[VAL_4:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_5:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[VAL_6:.*]] = llvm.insertelement %[[VAL_3]], %[[VAL_4]]{{\[}}%[[VAL_5]] : i64] : vector<1xi32>
+// CHECK: %[[VAL_7:.*]] = llvm.insertvalue %[[VAL_6]], %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_8:.*]] = builtin.unrealized_conversion_cast %[[VAL_7]] : !llvm.array<32 x vector<1xi32>> to vector<32x1xi32>
+// CHECK: return %[[VAL_8]] : vector<32x1xi32>
+
+// -----
+
+// CHECK-LABEL: @insert_llvm_constnt()
+
+module {
+ func.func @insert_llvm_constnt() -> vector<32x1xi32> {
+ %0 = llvm.mlir.constant(dense<0> : vector<32x1xi32>) : !llvm.array<32 x vector<1xi32>>
+ %1 = builtin.unrealized_conversion_cast %0 : !llvm.array<32 x vector<1xi32>> to vector<32x1xi32>
+ %2 = llvm.mlir.constant(0 : index) : i64
+ %3 = builtin.unrealized_conversion_cast %2 : i64 to index
+ %4 = llvm.mlir.constant(1 : i32) : i32
+ %5 = vector.insert %4, %1 [%3, %3] : i32 into vector<32x1xi32>
+ return %5 : vector<32x1xi32>
+ }
+}
+
+// CHECK: %[[VAL_0:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(dense<0> : vector<32x1xi32>) : !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_3:.*]] = llvm.extractvalue %[[VAL_2]][0] : !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[VAL_5:.*]] = llvm.insertelement %[[VAL_0]], %[[VAL_3]]{{\[}}%[[VAL_4]] : i64] : vector<1xi32>
+// CHECK: %[[VAL_6:.*]] = llvm.insertvalue %[[VAL_5]], %[[VAL_2]][0] : !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_7:.*]] = builtin.unrealized_conversion_cast %[[VAL_6]] : !llvm.array<32 x vector<1xi32>> to vector<32x1xi32>
+// CHECK: return %[[VAL_7]] : vector<32x1xi32>
+// CHECK: }
``````````
</details>
https://github.com/llvm/llvm-project/pull/121631
More information about the Mlir-commits mailing list