[Mlir-commits] [mlir] [mlir][vector]advance support extract insert under dynamic case. (PR #121631)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Jan 4 00:41:17 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: lonely eagle (linuxlonelyeagle)

<details>
<summary>Changes</summary>

Improve support for `vector.extract` and `vector.insert` when their positions are given as dynamic operands (SSA values) that are in fact defined by constant ops.
You can see the tests for the specific changes. The duplicated code should be factored out into a function, but I don't know where it would best be placed. Feel free to give me suggestions, thank you.

---
Full diff: https://github.com/llvm/llvm-project/pull/121631.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+45-3) 
- (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+90) 


``````````diff
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 9657f583c375bb..4af03126fa1edd 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1096,6 +1096,26 @@ class VectorExtractOpConversion
     SmallVector<OpFoldResult> positionVec = getMixedValues(
         adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
 
+    for (unsigned idx = 0; idx < positionVec.size(); ++idx) {
+      if (auto position = llvm::dyn_cast<Value>(positionVec[idx])) {
+        auto defOp = position.getDefiningOp();
+        while (defOp) {
+          if (llvm::isa<arith::ConstantOp, LLVM::ConstantOp>(defOp)) {
+            Attribute value =
+                defOp->getAttr(arith::ConstantOp::getAttributeNames()[0]);
+            positionVec[idx] = OpFoldResult{
+                rewriter.getI64IntegerAttr(cast<IntegerAttr>(value).getInt())};
+            break;
+          } else if (auto unrealizedCastOp =
+                         llvm::dyn_cast<UnrealizedConversionCastOp>(defOp)) {
+            defOp = unrealizedCastOp.getOperand(0).getDefiningOp();
+          } else {
+            break;
+          }
+        }
+      }
+    }
+
     // The Vector -> LLVM lowering models N-D vectors as nested aggregates of
     // 1-d vectors. This nesting is modeled using arrays. We do this conversion
     // from a N-d vector extract to a nested aggregate vector extract in two
@@ -1231,6 +1251,25 @@ class VectorInsertOpConversion
 
     SmallVector<OpFoldResult> positionVec = getMixedValues(
         adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
+    for (unsigned idx = 0; idx < positionVec.size(); ++idx) {
+      if (auto position = llvm::dyn_cast<Value>(positionVec[idx])) {
+        auto defOp = position.getDefiningOp();
+        while (defOp) {
+          if (llvm::isa<arith::ConstantOp, LLVM::ConstantOp>(defOp)) {
+            Attribute value =
+                defOp->getAttr(arith::ConstantOp::getAttributeNames()[0]);
+            positionVec[idx] = OpFoldResult{
+                rewriter.getI64IntegerAttr(cast<IntegerAttr>(value).getInt())};
+            break;
+          } else if (auto unrealizedCastOp =
+                         llvm::dyn_cast<UnrealizedConversionCastOp>(defOp)) {
+            defOp = unrealizedCastOp.getOperand(0).getDefiningOp();
+          } else {
+            break;
+          }
+        }
+      }
+    }
 
     // Overwrite entire vector with value. Should be handled by folder, but
     // just to be safe.
@@ -1242,8 +1281,9 @@ class VectorInsertOpConversion
 
     // One-shot insertion of a vector into an array (only requires insertvalue).
     if (isa<VectorType>(sourceType)) {
-      if (insertOp.hasDynamicPosition())
+      if (!llvm::all_of(position, llvm::IsaPred<Attribute>)) {
         return failure();
+      }
 
       Value inserted = rewriter.create<LLVM::InsertValueOp>(
           loc, adaptor.getDest(), adaptor.getSource(), getAsIntegers(position));
@@ -1255,8 +1295,9 @@ class VectorInsertOpConversion
     Value extracted = adaptor.getDest();
     auto oneDVectorType = destVectorType;
     if (position.size() > 1) {
-      if (insertOp.hasDynamicPosition())
+      if (!llvm::all_of(position, llvm::IsaPred<Attribute>)) {
         return failure();
+      }
 
       oneDVectorType = reducedVectorTypeBack(destVectorType);
       extracted = rewriter.create<LLVM::ExtractValueOp>(
@@ -1270,8 +1311,9 @@ class VectorInsertOpConversion
 
     // Potential insertion of resulting 1-D vector into array.
     if (position.size() > 1) {
-      if (insertOp.hasDynamicPosition())
+      if (!llvm::all_of(position, llvm::IsaPred<Attribute>)) {
         return failure();
+      }
 
       inserted = rewriter.create<LLVM::InsertValueOp>(
           loc, adaptor.getDest(), inserted,
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index f95e943250bd44..d16d78556da106 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -4094,3 +4094,93 @@ func.func @step_scalable() -> vector<[4]xindex> {
   %0 = vector.step : vector<[4]xindex>
   return %0 : vector<[4]xindex>
 }
+
+// -----
+
+// CHECK-LABEL: @extract_arith_constnt
+func.func @extract_arith_constnt() -> i32 {
+  %v = arith.constant dense<0> : vector<32x1xi32>
+  %c_0 = arith.constant 0 : index
+  %elem = vector.extract %v[%c_0, %c_0] : i32 from vector<32x1xi32>
+  return %elem : i32
+}
+
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<0> : vector<32x1xi32>
+// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<32x1xi32> to !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_3:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[VAL_5:.*]] = llvm.extractelement %[[VAL_3]]{{\[}}%[[VAL_4]] : i64] : vector<1xi32>
+// CHECK: return %[[VAL_5]] : i32
+
+// -----
+
+// CHECK-LABEL: @extract_llvm_constnt()
+
+module {
+  func.func @extract_llvm_constnt() -> i32 {
+    %0 = llvm.mlir.constant(dense<0> : vector<32x1xi32>) : !llvm.array<32 x vector<1xi32>>
+    %1 = builtin.unrealized_conversion_cast %0 : !llvm.array<32 x vector<1xi32>> to vector<32x1xi32>
+    %2 = llvm.mlir.constant(0 : index) : i64
+    %3 = builtin.unrealized_conversion_cast %2 : i64 to index
+    %4 = vector.extract %1[%3, %3] : i32 from vector<32x1xi32>
+    return %4 : i32
+  }
+}
+
+// CHECK:      %[[VAL_0:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK:           %[[VAL_1:.*]] = llvm.mlir.constant(dense<0> : vector<32x1xi32>) : !llvm.array<32 x vector<1xi32>>
+// CHECK:           %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
+// CHECK:           %[[VAL_3:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK:           %[[VAL_4:.*]] = llvm.extractelement %[[VAL_2]]{{\[}}%[[VAL_3]] : i64] : vector<1xi32>
+// CHECK:           return %[[VAL_4]] : i32
+
+// -----
+
+// CHECK-LABEL: @insert_arith_constnt()
+
+func.func @insert_arith_constnt() -> vector<32x1xi32> {
+  %v = arith.constant dense<0> : vector<32x1xi32>
+  %c_0 = arith.constant 0 : index
+  %c_1 = arith.constant 1 : i32
+  %v_1 = vector.insert %c_1, %v[%c_0, %c_0] : i32 into vector<32x1xi32>
+  return %v_1 : vector<32x1xi32>
+}
+
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<0> : vector<32x1xi32>
+// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<32x1xi32> to !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_3:.*]] = arith.constant 1 : i32
+// CHECK: %[[VAL_4:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_5:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[VAL_6:.*]] = llvm.insertelement %[[VAL_3]], %[[VAL_4]]{{\[}}%[[VAL_5]] : i64] : vector<1xi32>
+// CHECK: %[[VAL_7:.*]] = llvm.insertvalue %[[VAL_6]], %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_8:.*]] = builtin.unrealized_conversion_cast %[[VAL_7]] : !llvm.array<32 x vector<1xi32>> to vector<32x1xi32>
+// CHECK: return %[[VAL_8]] : vector<32x1xi32>
+
+// -----
+
+// CHECK-LABEL: @insert_llvm_constnt()
+
+module {
+  func.func @insert_llvm_constnt() -> vector<32x1xi32> {
+    %0 = llvm.mlir.constant(dense<0> : vector<32x1xi32>) : !llvm.array<32 x vector<1xi32>>
+    %1 = builtin.unrealized_conversion_cast %0 : !llvm.array<32 x vector<1xi32>> to vector<32x1xi32>
+    %2 = llvm.mlir.constant(0 : index) : i64
+    %3 = builtin.unrealized_conversion_cast %2 : i64 to index
+    %4 = llvm.mlir.constant(1 : i32) : i32
+    %5 = vector.insert %4, %1 [%3, %3] : i32 into vector<32x1xi32>
+    return %5 : vector<32x1xi32>
+  }
+}
+
+// CHECK:           %[[VAL_0:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK:           %[[VAL_1:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(dense<0> : vector<32x1xi32>) : !llvm.array<32 x vector<1xi32>>
+// CHECK:           %[[VAL_3:.*]] = llvm.extractvalue %[[VAL_2]][0] : !llvm.array<32 x vector<1xi32>>
+// CHECK:           %[[VAL_4:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK:           %[[VAL_5:.*]] = llvm.insertelement %[[VAL_0]], %[[VAL_3]]{{\[}}%[[VAL_4]] : i64] : vector<1xi32>
+// CHECK:           %[[VAL_6:.*]] = llvm.insertvalue %[[VAL_5]], %[[VAL_2]][0] : !llvm.array<32 x vector<1xi32>>
+// CHECK:           %[[VAL_7:.*]] = builtin.unrealized_conversion_cast %[[VAL_6]] : !llvm.array<32 x vector<1xi32>> to vector<32x1xi32>
+// CHECK:           return %[[VAL_7]] : vector<32x1xi32>
+// CHECK:         }

``````````

</details>


https://github.com/llvm/llvm-project/pull/121631


More information about the Mlir-commits mailing list