[Mlir-commits] [mlir] [mlir][Vector] Fix vector.extract lowering to llvm for 0-d vectors (PR #117731)

Kunwar Grover llvmlistbot at llvm.org
Fri Nov 29 07:33:40 PST 2024


https://github.com/Groverkss updated https://github.com/llvm/llvm-project/pull/117731

>From dc1a9a5091bac79cfc6ffa45bf8def7987007b30 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Tue, 26 Nov 2024 15:55:54 +0000
Subject: [PATCH 1/4] [mlir][Vector] Fix vector.extract lowering to llvm for
 0-d vectors

---
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      | 69 ++++++++++---------
 .../VectorToLLVM/vector-to-llvm.mlir          | 53 ++++++++++++--
 2 files changed, 84 insertions(+), 38 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 58ca84c8d7bca6..3f47b20cdb577b 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1096,43 +1096,50 @@ class VectorExtractOpConversion
     SmallVector<OpFoldResult> positionVec = getMixedValues(
         adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
 
-    // Extract entire vector. Should be handled by folder, but just to be safe.
-    ArrayRef<OpFoldResult> position(positionVec);
-    if (position.empty()) {
-      rewriter.replaceOp(extractOp, adaptor.getVector());
-      return success();
-    }
-
-    // One-shot extraction of vector from array (only requires extractvalue).
-    // Except for extracting 1-element vectors.
-    if (isa<VectorType>(resultType) &&
-        position.size() !=
-            static_cast<size_t>(extractOp.getSourceVectorType().getRank())) {
-      if (extractOp.hasDynamicPosition())
-        return failure();
-
-      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
-          loc, adaptor.getVector(), getAsIntegers(position));
-      rewriter.replaceOp(extractOp, extracted);
-      return success();
-    }
+    // Determine if we need to extract a scalar as the result. We extract
+    // a scalar if the extract is full rank i.e. the number of indices is equal
+    // to source vector rank.
+    bool isScalarExtract =
+        positionVec.size() == extractOp.getSourceVectorType().getRank();
+    // Determine if we need to extract a slice out of the original vector. We
+    // always need to extract a slice if the input rank >= 2.
+    bool isSlicingExtract = extractOp.getSourceVectorType().getRank() >= 2;
 
-    // Potential extraction of 1-D vector from array.
     Value extracted = adaptor.getVector();
-    if (position.size() > 1) {
-      if (extractOp.hasDynamicPosition())
+    if (isSlicingExtract) {
+      ArrayRef<OpFoldResult> position(positionVec);
+      if (isScalarExtract) {
+        // If we are extracting a scalar from the returned slice, we need to
+        // extract a N-1 D slice.
+        position = position.drop_back();
+      }
+      // llvm.extractvalue does not support dynamic dimensions.
+      if (!llvm::all_of(position,
+                        [](OpFoldResult x) { return isa<Attribute>(x); })) {
         return failure();
+      }
+      extracted = rewriter.create<LLVM::ExtractValueOp>(
+          loc, extracted, getAsIntegers(position));
+    }
 
-      SmallVector<int64_t> nMinusOnePosition =
-          getAsIntegers(position.drop_back());
-      extracted = rewriter.create<LLVM::ExtractValueOp>(loc, extracted,
-                                                        nMinusOnePosition);
+    if (isScalarExtract) {
+      Value position;
+      if (positionVec.empty()) {
+        // A scalar extract with no position is a 0-D vector extract. The LLVM
+        // type converter converts 0-D vectors to 1-D vectors, so we need to add
+        // a constant position.
+        auto idxType = rewriter.getIndexType();
+        position = rewriter.create<LLVM::ConstantOp>(
+            loc, typeConverter->convertType(idxType),
+            rewriter.getIntegerAttr(idxType, 0));
+      } else {
+        position = getAsLLVMValue(rewriter, loc, positionVec.back());
+      }
+      extracted =
+          rewriter.create<LLVM::ExtractElementOp>(loc, extracted, position);
     }
 
-    Value lastPosition = getAsLLVMValue(rewriter, loc, position.back());
-    // Remaining extraction of element from 1-D LLVM vector.
-    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(extractOp, extracted,
-                                                        lastPosition);
+    rewriter.replaceOp(extractOp, extracted);
     return success();
   }
 };
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 1c42538cf85912..58757bfd7951c6 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1290,26 +1290,65 @@ func.func @extract_scalar_from_vec_1d_f32_dynamic_idx_scalable(%arg0: vector<[16
 
 // -----
 
-func.func @extract_scalar_from_vec_2d_f32_dynamic_idx(%arg0: vector<1x16xf32>, %arg1: index) -> f32 {
+func.func @extract_scalar_from_vec_2d_f32_inner_dynamic_idx(%arg0: vector<1x16xf32>, %arg1: index) -> f32 {
   %0 = vector.extract %arg0[0, %arg1]: f32 from vector<1x16xf32>
   return %0 : f32
 }
 
-// Multi-dim vectors are not supported but this test shouldn't crash.
+// Multi-dim vectors are supported if the inner most dimension is dynamic.
 
-// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_dynamic_idx(
-//       CHECK:   vector.extract
+// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_inner_dynamic_idx(
+//       CHECK:   llvm.extractvalue
+//       CHECK:   llvm.extractelement
 
-func.func @extract_scalar_from_vec_2d_f32_dynamic_idx_scalable(%arg0: vector<1x[16]xf32>, %arg1: index) -> f32 {
+func.func @extract_scalar_from_vec_2d_f32_inner_dynamic_idx_scalable(%arg0: vector<1x[16]xf32>, %arg1: index) -> f32 {
   %0 = vector.extract %arg0[0, %arg1]: f32 from vector<1x[16]xf32>
   return %0 : f32
 }
 
-// Multi-dim vectors are not supported but this test shouldn't crash.
+// Multi-dim vectors are supported if the inner most dimension is dynamic.
+
+// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_inner_dynamic_idx_scalable(
+//       CHECK:   llvm.extractvalue
+//       CHECK:   llvm.extractelement
+
+// -----
 
-// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_dynamic_idx_scalable(
+func.func @extract_scalar_from_vec_2d_f32_outer_dynamic_idx(%arg0: vector<1x16xf32>, %arg1: index) -> f32 {
+  %0 = vector.extract %arg0[%arg1, 0]: f32 from vector<1x16xf32>
+  return %0 : f32
+}
+
+// A dynamic outer index is not supported; the op must be left unlowered.
+
+// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_outer_dynamic_idx(
 //       CHECK:   vector.extract
 
+func.func @extract_scalar_from_vec_2d_f32_outer_dynamic_idx_scalable(%arg0: vector<1x[16]xf32>, %arg1: index) -> f32 {
+  %0 = vector.extract %arg0[%arg1, 0]: f32 from vector<1x[16]xf32>
+  return %0 : f32
+}
+
+// Multi-dim vectors with outer dimension as dynamic are not supported, but it
+// shouldn't crash.
+
+// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_outer_dynamic_idx_scalable(
+//       CHECK:   vector.extract
+
+// -----
+
+func.func @extract_scalar_from_vec_0d_index(%arg0: vector<index>) -> index {
+  %0 = vector.extract %arg0[]: index from vector<index>
+  return %0 : index
+}
+// CHECK-LABEL: @extract_scalar_from_vec_0d_index(
+//  CHECK-SAME:   %[[A:.*]]: vector<index>)
+//       CHECK:   %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<index> to vector<1xi64>
+//       CHECK:   %[[T1:.*]] = llvm.mlir.constant(0 : index) : i64
+//       CHECK:   %[[T2:.*]] = llvm.extractelement %[[T0]][%[[T1]] : i64] : vector<1xi64>
+//       CHECK:   %[[T3:.*]] = builtin.unrealized_conversion_cast %[[T2]] : i64 to index
+//       CHECK:   return %[[T3]] : index
+
 // -----
 
 func.func @insertelement_into_vec_0d_f32(%arg0: f32, %arg1: vector<f32>) -> vector<f32> {

>From 824be432089b8bae4e91e9f80bf303b14a5b75ab Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Tue, 26 Nov 2024 16:04:13 +0000
Subject: [PATCH 2/4] more docs

---
 .../VectorToLLVM/ConvertVectorToLLVM.cpp         | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 3f47b20cdb577b..6f83c5e1bd761d 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1096,14 +1096,20 @@ class VectorExtractOpConversion
     SmallVector<OpFoldResult> positionVec = getMixedValues(
         adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
 
-    // Determine if we need to extract a scalar as the result. We extract
-    // a scalar if the extract is full rank i.e. the number of indices is equal
-    // to source vector rank.
-    bool isScalarExtract =
-        positionVec.size() == extractOp.getSourceVectorType().getRank();
+    // The LLVM lowering models multi dimension vectors as stacked 1-d vectors.
+    // The stacking is modeled using arrays. We do this conversion from a
+    // N-d vector extract to stacked 1-d vector extract in two steps:
+    //  - Extract a 1-d vector or a stack of 1-d vectors (llvm.extractvalue)
+    //  - Extract a scalar out of the 1-d vector if needed (llvm.extractelement)
+
     // Determine if we need to extract a slice out of the original vector. We
     // always need to extract a slice if the input rank >= 2.
     bool isSlicingExtract = extractOp.getSourceVectorType().getRank() >= 2;
+    // Determine if we need to extract a scalar as the result. We extract
+    // a scalar if the extract is full rank i.e. the number of indices is equal
+    // to source vector rank.
+    bool isScalarExtract = static_cast<int64_t>(positionVec.size()) ==
+                           extractOp.getSourceVectorType().getRank();
 
     Value extracted = adaptor.getVector();
     if (isSlicingExtract) {

>From f4a3fb41d438722740140ef1acbe07a9b69e4138 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Fri, 29 Nov 2024 14:58:16 +0000
Subject: [PATCH 3/4] Address comments

---
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      | 70 ++++++++++---------
 .../VectorToLLVM/vector-to-llvm.mlir          |  6 +-
 2 files changed, 39 insertions(+), 37 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 6f83c5e1bd761d..6b3f7ca54cf75c 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1096,53 +1096,55 @@ class VectorExtractOpConversion
     SmallVector<OpFoldResult> positionVec = getMixedValues(
         adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
 
-    // The LLVM lowering models multi dimension vectors as stacked 1-d vectors.
-    // The stacking is modeled using arrays. We do this conversion from a
-    // N-d vector extract to stacked 1-d vector extract in two steps:
-    //  - Extract a 1-d vector or a stack of 1-d vectors (llvm.extractvalue)
-    //  - Extract a scalar out of the 1-d vector if needed (llvm.extractelement)
-
-    // Determine if we need to extract a slice out of the original vector. We
-    // always need to extract a slice if the input rank >= 2.
-    bool isSlicingExtract = extractOp.getSourceVectorType().getRank() >= 2;
+    // The Vector -> LLVM lowering models N-D vectors as nested aggregates of
+    // 1-d vectors. This nesting is modeled using arrays. We do this conversion
+    // from a N-d vector extract to a nested aggregate vector extract in two
+    // steps:
+    //  - Extract a member from the nested aggregate. The result can be
+    //    a lower rank nested aggregate or a vector (1-D). This is done using
+    //    `llvm.extractvalue`.
+    //  - Extract a scalar out of the vector if needed. This is done using
+    //   `llvm.extractelement`.
+
+    // Determine if we need to extract a member out of the aggregate. We
+    // always need to extract a member if the input rank >= 2.
+    bool extractsAggregate = extractOp.getSourceVectorType().getRank() >= 2;
     // Determine if we need to extract a scalar as the result. We extract
-    // a scalar if the extract is full rank i.e. the number of indices is equal
-    // to source vector rank.
-    bool isScalarExtract = static_cast<int64_t>(positionVec.size()) ==
-                           extractOp.getSourceVectorType().getRank();
+    // a scalar if the extract is full rank, i.e., the number of indices is
+    // equal to source vector rank.
+    bool extractsScalar = static_cast<int64_t>(positionVec.size()) ==
+                          extractOp.getSourceVectorType().getRank();
+
+    // Since the LLVM type converter converts 0-d vectors to 1-d vectors, we
+    // need to add a position for this change.
+    if (extractOp.getSourceVectorType().getRank() == 0) {
+      auto idxType = rewriter.getIndexType();
+      Value position = rewriter.create<LLVM::ConstantOp>(
+          loc, typeConverter->convertType(idxType),
+          rewriter.getIntegerAttr(idxType, 0));
+      positionVec.push_back(position);
+    }
 
     Value extracted = adaptor.getVector();
-    if (isSlicingExtract) {
+    if (extractsScalar) {
       ArrayRef<OpFoldResult> position(positionVec);
-      if (isScalarExtract) {
-        // If we are extracting a scalar from the returned slice, we need to
-        // extract a N-1 D slice.
+      if (extractsAggregate) {
+        // If we are extracting a scalar from the extracted member, we drop
+        // the last index, which will be used to extract the scalar out of the
+        // vector.
         position = position.drop_back();
       }
       // llvm.extractvalue does not support dynamic dimensions.
-      if (!llvm::all_of(position,
-                        [](OpFoldResult x) { return isa<Attribute>(x); })) {
+      if (!llvm::all_of(position, llvm::IsaPred<Attribute>)) {
         return failure();
       }
       extracted = rewriter.create<LLVM::ExtractValueOp>(
           loc, extracted, getAsIntegers(position));
     }
 
-    if (isScalarExtract) {
-      Value position;
-      if (positionVec.empty()) {
-        // A scalar extract with no position is a 0-D vector extract. The LLVM
-        // type converter converts 0-D vectors to 1-D vectors, so we need to add
-        // a constant position.
-        auto idxType = rewriter.getIndexType();
-        position = rewriter.create<LLVM::ConstantOp>(
-            loc, typeConverter->convertType(idxType),
-            rewriter.getIntegerAttr(idxType, 0));
-      } else {
-        position = getAsLLVMValue(rewriter, loc, positionVec.back());
-      }
-      extracted =
-          rewriter.create<LLVM::ExtractElementOp>(loc, extracted, position);
+    if (extractsScalar) {
+      extracted = rewriter.create<LLVM::ExtractElementOp>(
+          loc, extracted, getAsLLVMValue(rewriter, loc, positionVec.back()));
     }
 
     rewriter.replaceOp(extractOp, extracted);
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 58757bfd7951c6..e3727d6412849a 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1295,7 +1295,7 @@ func.func @extract_scalar_from_vec_2d_f32_inner_dynamic_idx(%arg0: vector<1x16xf
   return %0 : f32
 }
 
-// Multi-dim vectors are supported if the inner most dimension is dynamic.
+// Multi-dim vectors are supported if the innermost index is dynamic.
 
 // CHECK-LABEL: @extract_scalar_from_vec_2d_f32_inner_dynamic_idx(
 //       CHECK:   llvm.extractvalue
@@ -1306,7 +1306,7 @@ func.func @extract_scalar_from_vec_2d_f32_inner_dynamic_idx_scalable(%arg0: vect
   return %0 : f32
 }
 
-// Multi-dim vectors are supported if the inner most dimension is dynamic.
+// Multi-dim vectors are supported if the innermost index is dynamic.
 
 // CHECK-LABEL: @extract_scalar_from_vec_2d_f32_inner_dynamic_idx_scalable(
 //       CHECK:   llvm.extractvalue
@@ -1329,7 +1329,7 @@ func.func @extract_scalar_from_vec_2d_f32_outer_dynamic_idx_scalable(%arg0: vect
   return %0 : f32
 }
 
-// Multi-dim vectors with outer dimension as dynamic are not supported, but it
+// Multi-dim vectors with outer indices as dynamic are not supported, but it
 // shouldn't crash.
 
 // CHECK-LABEL: @extract_scalar_from_vec_2d_f32_outer_dynamic_idx_scalable(

>From 9ba9d44bf981e00a9a6cfc7608548863ca9311b6 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Fri, 29 Nov 2024 15:33:06 +0000
Subject: [PATCH 4/4] Fix bugs

---
 .../Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp   | 11 ++++-------
 mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir |  2 +-
 2 files changed, 5 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 6b3f7ca54cf75c..a9a07c323c7358 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1118,17 +1118,14 @@ class VectorExtractOpConversion
     // Since the LLVM type converter converts 0-d vectors to 1-d vectors, we
     // need to add a position for this change.
     if (extractOp.getSourceVectorType().getRank() == 0) {
-      auto idxType = rewriter.getIndexType();
-      Value position = rewriter.create<LLVM::ConstantOp>(
-          loc, typeConverter->convertType(idxType),
-          rewriter.getIntegerAttr(idxType, 0));
-      positionVec.push_back(position);
+      Type idxType = typeConverter->convertType(rewriter.getIndexType());
+      positionVec.push_back(rewriter.getZeroAttr(idxType));
     }
 
     Value extracted = adaptor.getVector();
-    if (extractsScalar) {
+    if (extractsAggregate) {
       ArrayRef<OpFoldResult> position(positionVec);
-      if (extractsAggregate) {
+      if (extractsScalar) {
         // If we are extracting a scalar from the extracted member, we drop
         // the last index, which will be used to extract the scalar out of the
         // vector.
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index e3727d6412849a..197f0cb2a568e6 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1344,7 +1344,7 @@ func.func @extract_scalar_from_vec_0d_index(%arg0: vector<index>) -> index {
 // CHECK-LABEL: @extract_scalar_from_vec_0d_index(
 //  CHECK-SAME:   %[[A:.*]]: vector<index>)
 //       CHECK:   %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<index> to vector<1xi64>
-//       CHECK:   %[[T1:.*]] = llvm.mlir.constant(0 : index) : i64
+//       CHECK:   %[[T1:.*]] = llvm.mlir.constant(0 : i64) : i64
 //       CHECK:   %[[T2:.*]] = llvm.extractelement %[[T0]][%[[T1]] : i64] : vector<1xi64>
 //       CHECK:   %[[T3:.*]] = builtin.unrealized_conversion_cast %[[T2]] : i64 to index
 //       CHECK:   return %[[T3]] : index



More information about the Mlir-commits mailing list