[Mlir-commits] [mlir] [mlir][Vector] Fix vector.extract lowering to llvm for 0-d vectors (PR #117731)

Kunwar Grover llvmlistbot at llvm.org
Tue Nov 26 08:04:30 PST 2024


https://github.com/Groverkss updated https://github.com/llvm/llvm-project/pull/117731

>From 76c9b90366797748a473fef437648a44523ff147 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Tue, 26 Nov 2024 15:55:54 +0000
Subject: [PATCH 1/2] [mlir][Vector] Fix vector.extract lowering to llvm for
 0-d vectors

---
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      | 69 ++++++++++---------
 .../VectorToLLVM/vector-to-llvm.mlir          | 53 ++++++++++++--
 2 files changed, 84 insertions(+), 38 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 58ca84c8d7bca6..3f47b20cdb577b 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1096,43 +1096,50 @@ class VectorExtractOpConversion
     SmallVector<OpFoldResult> positionVec = getMixedValues(
         adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
 
-    // Extract entire vector. Should be handled by folder, but just to be safe.
-    ArrayRef<OpFoldResult> position(positionVec);
-    if (position.empty()) {
-      rewriter.replaceOp(extractOp, adaptor.getVector());
-      return success();
-    }
-
-    // One-shot extraction of vector from array (only requires extractvalue).
-    // Except for extracting 1-element vectors.
-    if (isa<VectorType>(resultType) &&
-        position.size() !=
-            static_cast<size_t>(extractOp.getSourceVectorType().getRank())) {
-      if (extractOp.hasDynamicPosition())
-        return failure();
-
-      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
-          loc, adaptor.getVector(), getAsIntegers(position));
-      rewriter.replaceOp(extractOp, extracted);
-      return success();
-    }
+    // Determine if we need to extract a scalar as the result. We extract
+    // a scalar if the extract is full rank, i.e., the number of indices is
+    // equal to the source vector rank.
+    bool isScalarExtract =
+        positionVec.size() == extractOp.getSourceVectorType().getRank();
+    // Determine if we need to extract a slice out of the original vector. We
+    // always need to extract a slice if the input rank >= 2.
+    bool isSlicingExtract = extractOp.getSourceVectorType().getRank() >= 2;
 
-    // Potential extraction of 1-D vector from array.
     Value extracted = adaptor.getVector();
-    if (position.size() > 1) {
-      if (extractOp.hasDynamicPosition())
+    if (isSlicingExtract) {
+      ArrayRef<OpFoldResult> position(positionVec);
+      if (isScalarExtract) {
+        // When extracting a scalar, the last index selects the element, so
+        // use only the leading N-1 indices here; this yields a 1-D vector.
+        position = position.drop_back();
+      }
+      // llvm.extractvalue does not support dynamic dimensions.
+      if (!llvm::all_of(position,
+                        [](OpFoldResult x) { return isa<Attribute>(x); })) {
         return failure();
+      }
+      extracted = rewriter.create<LLVM::ExtractValueOp>(
+          loc, extracted, getAsIntegers(position));
+    }
 
-      SmallVector<int64_t> nMinusOnePosition =
-          getAsIntegers(position.drop_back());
-      extracted = rewriter.create<LLVM::ExtractValueOp>(loc, extracted,
-                                                        nMinusOnePosition);
+    if (isScalarExtract) {
+      Value position;
+      if (positionVec.empty()) {
+        // A scalar extract with no position is a 0-D vector extract. The LLVM
+        // type converter converts 0-D vectors to 1-D vectors, so we need to add
+        // a constant position.
+        auto idxType = rewriter.getIndexType();
+        position = rewriter.create<LLVM::ConstantOp>(
+            loc, typeConverter->convertType(idxType),
+            rewriter.getIntegerAttr(idxType, 0));
+      } else {
+        position = getAsLLVMValue(rewriter, loc, positionVec.back());
+      }
+      extracted =
+          rewriter.create<LLVM::ExtractElementOp>(loc, extracted, position);
     }
 
-    Value lastPosition = getAsLLVMValue(rewriter, loc, position.back());
-    // Remaining extraction of element from 1-D LLVM vector.
-    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(extractOp, extracted,
-                                                        lastPosition);
+    rewriter.replaceOp(extractOp, extracted);
     return success();
   }
 };
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index da0222bc942376..cd687becb82b82 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1258,26 +1258,65 @@ func.func @extract_scalar_from_vec_1d_f32_dynamic_idx_scalable(%arg0: vector<[16
 
 // -----
 
-func.func @extract_scalar_from_vec_2d_f32_dynamic_idx(%arg0: vector<1x16xf32>, %arg1: index) -> f32 {
+func.func @extract_scalar_from_vec_2d_f32_inner_dynamic_idx(%arg0: vector<1x16xf32>, %arg1: index) -> f32 {
   %0 = vector.extract %arg0[0, %arg1]: f32 from vector<1x16xf32>
   return %0 : f32
 }
 
-// Multi-dim vectors are not supported but this test shouldn't crash.
+// Multi-dim vectors are supported if only the innermost index is dynamic.
 
-// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_dynamic_idx(
-//       CHECK:   vector.extract
+// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_inner_dynamic_idx(
+//       CHECK:   llvm.extractvalue
+//       CHECK:   llvm.extractelement
 
-func.func @extract_scalar_from_vec_2d_f32_dynamic_idx_scalable(%arg0: vector<1x[16]xf32>, %arg1: index) -> f32 {
+func.func @extract_scalar_from_vec_2d_f32_inner_dynamic_idx_scalable(%arg0: vector<1x[16]xf32>, %arg1: index) -> f32 {
   %0 = vector.extract %arg0[0, %arg1]: f32 from vector<1x[16]xf32>
   return %0 : f32
 }
 
-// Multi-dim vectors are not supported but this test shouldn't crash.
+// Multi-dim vectors are supported if only the innermost index is dynamic.
+
+// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_inner_dynamic_idx_scalable(
+//       CHECK:   llvm.extractvalue
+//       CHECK:   llvm.extractelement
+
+// -----
 
-// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_dynamic_idx_scalable(
+func.func @extract_scalar_from_vec_2d_f32_outer_dynamic_idx(%arg0: vector<1x16xf32>, %arg1: index) -> f32 {
+  %0 = vector.extract %arg0[%arg1, 0]: f32 from vector<1x16xf32>
+  return %0 : f32
+}
+
+// Dynamic outer indices are not supported, but this test shouldn't crash.
+
+// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_outer_dynamic_idx(
 //       CHECK:   vector.extract
 
+func.func @extract_scalar_from_vec_2d_f32_outer_dynamic_idx_scalable(%arg0: vector<1x[16]xf32>, %arg1: index) -> f32 {
+  %0 = vector.extract %arg0[%arg1, 0]: f32 from vector<1x[16]xf32>
+  return %0 : f32
+}
+
+// Multi-dim vectors with a dynamic outer index are not supported, but this
+// test shouldn't crash.
+
+// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_outer_dynamic_idx_scalable(
+//       CHECK:   vector.extract
+
+// -----
+
+func.func @extract_scalar_from_vec_0d_index(%arg0: vector<index>) -> index {
+  %0 = vector.extract %arg0[]: index from vector<index>
+  return %0 : index
+}
+// CHECK-LABEL: @extract_scalar_from_vec_0d_index(
+//  CHECK-SAME:   %[[A:.*]]: vector<index>)
+//       CHECK:   %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<index> to vector<1xi64>
+//       CHECK:   %[[T1:.*]] = llvm.mlir.constant(0 : index) : i64
+//       CHECK:   %[[T2:.*]] = llvm.extractelement %[[T0]][%[[T1]] : i64] : vector<1xi64>
+//       CHECK:   %[[T3:.*]] = builtin.unrealized_conversion_cast %[[T2]] : i64 to index
+//       CHECK:   return %[[T3]] : index
+
 // -----
 
 func.func @insertelement_into_vec_0d_f32(%arg0: f32, %arg1: vector<f32>) -> vector<f32> {

>From c513cb17f31dfe434e8ff333af0306f0d5a3aa31 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Tue, 26 Nov 2024 16:04:13 +0000
Subject: [PATCH 2/2] more docs

---
 .../VectorToLLVM/ConvertVectorToLLVM.cpp         | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 3f47b20cdb577b..6f83c5e1bd761d 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1096,14 +1096,20 @@ class VectorExtractOpConversion
     SmallVector<OpFoldResult> positionVec = getMixedValues(
         adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
 
-    // Determine if we need to extract a scalar as the result. We extract
-    // a scalar if the extract is full rank i.e. the number of indices is equal
-    // to source vector rank.
-    bool isScalarExtract =
-        positionVec.size() == extractOp.getSourceVectorType().getRank();
+    // The LLVM lowering models multi-dimensional vectors as stacked 1-D
+    // vectors, where the stacking is modeled using arrays. We convert an N-D
+    // vector extract into extracts on the stacked 1-D vectors in two steps:
+    //  - Extract a 1-D vector or a stack of 1-D vectors (llvm.extractvalue)
+    //  - Extract a scalar out of the 1-D vector if needed (llvm.extractelement)
+
     // Determine if we need to extract a slice out of the original vector. We
     // always need to extract a slice if the input rank >= 2.
     bool isSlicingExtract = extractOp.getSourceVectorType().getRank() >= 2;
+    // Determine if we need to extract a scalar as the result. We extract
+    // a scalar if the extract is full rank, i.e., the number of indices is
+    // equal to the source vector rank.
+    bool isScalarExtract = static_cast<int64_t>(positionVec.size()) ==
+                           extractOp.getSourceVectorType().getRank();
 
     Value extracted = adaptor.getVector();
     if (isSlicingExtract) {



More information about the Mlir-commits mailing list