[Mlir-commits] [mlir] [mlir][Vector] Fix vector.extract lowering to llvm for 0-d vectors (PR #117731)

Kunwar Grover llvmlistbot at llvm.org
Tue Nov 26 07:59:57 PST 2024


https://github.com/Groverkss created https://github.com/llvm/llvm-project/pull/117731

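The current LLVM lowering of vector.extract incorrectly assumes that an extract with zero indices can always be folded away. That is not true for 0-D vectors: `vector.extract %arg0[] : index from vector<index>` yields a scalar, not the source vector. This PR removes that early exit and relies on the folder to remove genuine no-op extracts instead.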

This PR also unifies the logic for scalar extracts and slice extracts. As a side effect, this enables lowering of n-D vector.extract ops whose innermost index is dynamic (previously this was blocked only by a conservative check in the old implementation).

>From 76c9b90366797748a473fef437648a44523ff147 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Tue, 26 Nov 2024 15:55:54 +0000
Subject: [PATCH] [mlir][Vector] Fix vector.extract lowering to llvm for 0-d
 vectors

---
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      | 69 ++++++++++---------
 .../VectorToLLVM/vector-to-llvm.mlir          | 53 ++++++++++++--
 2 files changed, 84 insertions(+), 38 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 58ca84c8d7bca6..3f47b20cdb577b 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1096,43 +1096,50 @@ class VectorExtractOpConversion
     SmallVector<OpFoldResult> positionVec = getMixedValues(
         adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
 
-    // Extract entire vector. Should be handled by folder, but just to be safe.
-    ArrayRef<OpFoldResult> position(positionVec);
-    if (position.empty()) {
-      rewriter.replaceOp(extractOp, adaptor.getVector());
-      return success();
-    }
-
-    // One-shot extraction of vector from array (only requires extractvalue).
-    // Except for extracting 1-element vectors.
-    if (isa<VectorType>(resultType) &&
-        position.size() !=
-            static_cast<size_t>(extractOp.getSourceVectorType().getRank())) {
-      if (extractOp.hasDynamicPosition())
-        return failure();
-
-      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
-          loc, adaptor.getVector(), getAsIntegers(position));
-      rewriter.replaceOp(extractOp, extracted);
-      return success();
-    }
+    // Determine if we need to extract a scalar as the result. We extract
+    // a scalar if the extract is full rank i.e. the number of indices is equal
+    // to source vector rank.
+    bool isScalarExtract =
+        positionVec.size() == extractOp.getSourceVectorType().getRank();
+    // Determine if we need to extract a slice out of the original vector. We
+    // always need to extract a slice if the input rank >= 2.
+    bool isSlicingExtract = extractOp.getSourceVectorType().getRank() >= 2;
 
-    // Potential extraction of 1-D vector from array.
     Value extracted = adaptor.getVector();
-    if (position.size() > 1) {
-      if (extractOp.hasDynamicPosition())
+    if (isSlicingExtract) {
+      ArrayRef<OpFoldResult> position(positionVec);
+      if (isScalarExtract) {
+        // If we are extracting a scalar from the returned slice, we need to
+        // extract a N-1 D slice.
+        position = position.drop_back();
+      }
+      // llvm.extractvalue does not support dynamic dimensions.
+      if (!llvm::all_of(position,
+                        [](OpFoldResult x) { return isa<Attribute>(x); })) {
         return failure();
+      }
+      extracted = rewriter.create<LLVM::ExtractValueOp>(
+          loc, extracted, getAsIntegers(position));
+    }
 
-      SmallVector<int64_t> nMinusOnePosition =
-          getAsIntegers(position.drop_back());
-      extracted = rewriter.create<LLVM::ExtractValueOp>(loc, extracted,
-                                                        nMinusOnePosition);
+    if (isScalarExtract) {
+      Value position;
+      if (positionVec.empty()) {
+        // A scalar extract with no position is a 0-D vector extract. The LLVM
+        // type converter converts 0-D vectors to 1-D vectors, so we need to add
+        // a constant position.
+        auto idxType = rewriter.getIndexType();
+        position = rewriter.create<LLVM::ConstantOp>(
+            loc, typeConverter->convertType(idxType),
+            rewriter.getIntegerAttr(idxType, 0));
+      } else {
+        position = getAsLLVMValue(rewriter, loc, positionVec.back());
+      }
+      extracted =
+          rewriter.create<LLVM::ExtractElementOp>(loc, extracted, position);
     }
 
-    Value lastPosition = getAsLLVMValue(rewriter, loc, position.back());
-    // Remaining extraction of element from 1-D LLVM vector.
-    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(extractOp, extracted,
-                                                        lastPosition);
+    rewriter.replaceOp(extractOp, extracted);
     return success();
   }
 };
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index da0222bc942376..cd687becb82b82 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1258,26 +1258,65 @@ func.func @extract_scalar_from_vec_1d_f32_dynamic_idx_scalable(%arg0: vector<[16
 
 // -----
 
-func.func @extract_scalar_from_vec_2d_f32_dynamic_idx(%arg0: vector<1x16xf32>, %arg1: index) -> f32 {
+func.func @extract_scalar_from_vec_2d_f32_inner_dynamic_idx(%arg0: vector<1x16xf32>, %arg1: index) -> f32 {
   %0 = vector.extract %arg0[0, %arg1]: f32 from vector<1x16xf32>
   return %0 : f32
 }
 
-// Multi-dim vectors are not supported but this test shouldn't crash.
+// Multi-dim vectors are supported if only the innermost index is dynamic.
 
-// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_dynamic_idx(
-//       CHECK:   vector.extract
+// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_inner_dynamic_idx(
+//       CHECK:   llvm.extractvalue
+//       CHECK:   llvm.extractelement
 
-func.func @extract_scalar_from_vec_2d_f32_dynamic_idx_scalable(%arg0: vector<1x[16]xf32>, %arg1: index) -> f32 {
+func.func @extract_scalar_from_vec_2d_f32_inner_dynamic_idx_scalable(%arg0: vector<1x[16]xf32>, %arg1: index) -> f32 {
   %0 = vector.extract %arg0[0, %arg1]: f32 from vector<1x[16]xf32>
   return %0 : f32
 }
 
-// Multi-dim vectors are not supported but this test shouldn't crash.
+// Multi-dim vectors are supported if only the innermost index is dynamic.
+
+// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_inner_dynamic_idx_scalable(
+//       CHECK:   llvm.extractvalue
+//       CHECK:   llvm.extractelement
+
+// -----
 
-// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_dynamic_idx_scalable(
+func.func @extract_scalar_from_vec_2d_f32_outer_dynamic_idx(%arg0: vector<1x16xf32>, %arg1: index) -> f32 {
+  %0 = vector.extract %arg0[%arg1, 0]: f32 from vector<1x16xf32>
+  return %0 : f32
+}
+
+// Dynamic outer indices are not supported, but this test shouldn't crash.
+
+// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_outer_dynamic_idx(
 //       CHECK:   vector.extract
 
+func.func @extract_scalar_from_vec_2d_f32_outer_dynamic_idx_scalable(%arg0: vector<1x[16]xf32>, %arg1: index) -> f32 {
+  %0 = vector.extract %arg0[%arg1, 0]: f32 from vector<1x[16]xf32>
+  return %0 : f32
+}
+
+// Multi-dim vectors with outer dimension as dynamic are not supported, but it
+// shouldn't crash.
+
+// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_outer_dynamic_idx_scalable(
+//       CHECK:   vector.extract
+
+// -----
+
+func.func @extract_scalar_from_vec_0d_index(%arg0: vector<index>) -> index {
+  %0 = vector.extract %arg0[]: index from vector<index>
+  return %0 : index
+}
+// CHECK-LABEL: @extract_scalar_from_vec_0d_index(
+//  CHECK-SAME:   %[[A:.*]]: vector<index>)
+//       CHECK:   %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<index> to vector<1xi64>
+//       CHECK:   %[[T1:.*]] = llvm.mlir.constant(0 : index) : i64
+//       CHECK:   %[[T2:.*]] = llvm.extractelement %[[T0]][%[[T1]] : i64] : vector<1xi64>
+//       CHECK:   %[[T3:.*]] = builtin.unrealized_conversion_cast %[[T2]] : i64 to index
+//       CHECK:   return %[[T3]] : index
+
 // -----
 
 func.func @insertelement_into_vec_0d_f32(%arg0: f32, %arg1: vector<f32>) -> vector<f32> {



More information about the Mlir-commits mailing list