[Mlir-commits] [mlir] [mlir][VectorToSPIRV] Add conversion for vector.extract with dynamic indices (PR #114137)

Kunwar Grover llvmlistbot at llvm.org
Tue Nov 5 01:05:40 PST 2024


https://github.com/Groverkss updated https://github.com/llvm/llvm-project/pull/114137

>From 9270c2b05e0b95c6b8ab78eefcf95468cea8bff2 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Tue, 29 Oct 2024 22:20:07 +0000
Subject: [PATCH 1/2] [mlir][VectorToSPIRV] Add conversion for vector.extract
 with dynamic indices

---
 .../VectorToSPIRV/VectorToSPIRV.cpp           | 48 ++++++++++---------
 .../VectorToSPIRV/vector-to-spirv.mlir        | 42 ++++++++++++++++
 2 files changed, 68 insertions(+), 22 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 6184225cb6285d..ee8dccf025a0c6 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
@@ -40,22 +41,9 @@ using namespace mlir;
 /// Returns the integer value from the first valid input element, assuming Value
 /// inputs are defined by a constant index ops and Attribute inputs are integer
 /// attributes.
-static uint64_t getFirstIntValue(ValueRange values) {
-  return values[0].getDefiningOp<arith::ConstantIndexOp>().value();
-}
-static uint64_t getFirstIntValue(ArrayRef<Attribute> attr) {
-  return cast<IntegerAttr>(attr[0]).getInt();
-}
 static uint64_t getFirstIntValue(ArrayAttr attr) {
   return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
 }
-static uint64_t getFirstIntValue(ArrayRef<OpFoldResult> foldResults) {
-  auto attr = foldResults[0].dyn_cast<Attribute>();
-  if (attr)
-    return getFirstIntValue(attr);
-
-  return getFirstIntValue(ValueRange{foldResults[0].get<Value>()});
-}
 
 /// Returns the number of bits for the given scalar/vector type.
 static int getNumBits(Type type) {
@@ -157,9 +145,6 @@ struct VectorExtractOpConvert final
   LogicalResult
   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    if (extractOp.hasDynamicPosition())
-      return failure();
-
     Type dstType = getTypeConverter()->convertType(extractOp.getType());
     if (!dstType)
       return failure();
@@ -169,9 +154,17 @@ struct VectorExtractOpConvert final
       return success();
     }
 
-    int32_t id = getFirstIntValue(extractOp.getMixedPosition());
-    rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
-        extractOp, adaptor.getVector(), id);
+    std::optional<int64_t> id =
+        getConstantIntValue(extractOp.getMixedPosition()[0]);
+
+    if (id.has_value())
+      rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
+          extractOp, dstType, adaptor.getVector(),
+          rewriter.getI32ArrayAttr(id.value()));
+    else
+      rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
+          extractOp, dstType, adaptor.getVector(),
+          adaptor.getDynamicPosition()[0]);
     return success();
   }
 };
@@ -249,9 +242,20 @@ struct VectorInsertOpConvert final
       return success();
     }
 
-    int32_t id = getFirstIntValue(insertOp.getMixedPosition());
-    rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
-        insertOp, adaptor.getSource(), adaptor.getDest(), id);
+    std::optional<int64_t> id =
+        getConstantIntValue(insertOp.getMixedPosition()[0]);
+
+    //    rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
+    //        insertOp, adaptor.getSource(), adaptor.getDest(), id);
+    //    return success();
+
+    if (id.has_value())
+      rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
+          insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
+    else
+      rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
+          insertOp, insertOp.getDest(), adaptor.getSource(),
+          adaptor.getDynamicPosition()[0]);
     return success();
   }
 };
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 25ec5d0159bd5d..62210108aa73cf 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -186,6 +186,26 @@ func.func @extract_size1_vector(%arg0 : vector<1xf32>) -> f32 {
 
 // -----
 
+// CHECK-LABEL: @extract_dynamic
+//  CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ARG1:.*]]: index
+//       CHECK:   %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32
+//       CHECK:   spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32
+func.func @extract_dynamic(%arg0 : vector<4xf32>, %id : index) -> f32 {
+  %0 = vector.extract %arg0[%id] : f32 from vector<4xf32>
+  return %0: f32
+}
+
+// CHECK-LABEL: @extract_dynamic_cst
+//  CHECK-SAME: %[[V:.*]]: vector<4xf32>
+//       CHECK:   spirv.CompositeExtract %[[V]][1 : i32] : vector<4xf32>
+func.func @extract_dynamic_cst(%arg0 : vector<4xf32>) -> f32 {
+  %idx = arith.constant 1 : index
+  %0 = vector.extract %arg0[%idx] : f32 from vector<4xf32>
+  return %0: f32
+}
+
+// -----
+
 // CHECK-LABEL: @insert
 //  CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[S:.*]]: f32
 //       CHECK:   spirv.CompositeInsert %[[S]], %[[V]][2 : i32] : f32 into vector<4xf32>
@@ -216,6 +236,28 @@ func.func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: f32) -> vector<1xf3
 
 // -----
 
+// CHECK-LABEL: @insert_dynamic
+//  CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ARG2:.*]]: index
+//       CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : index to i32
+//       CHECK:   spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32
+func.func @insert_dynamic(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vector<4xf32> {
+  %0 = vector.insert %val, %arg0[%id] : f32 into vector<4xf32>
+  return %0: vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @insert_dynamic_cst
+//  CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>
+//       CHECK:   spirv.CompositeInsert %[[VAL]], %[[V]][2 : i32] : f32 into vector<4xf32>
+func.func @insert_dynamic_cst(%val: f32, %arg0 : vector<4xf32>) -> vector<4xf32> {
+  %idx = arith.constant 2 : index
+  %0 = vector.insert %val, %arg0[%idx] : f32 into vector<4xf32>
+  return %0: vector<4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @extract_element
 //  CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32
 //       CHECK:   spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32

>From edb8da02fcd46d9398efcb517cddbff77375985f Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Tue, 5 Nov 2024 09:01:28 +0000
Subject: [PATCH 2/2] Address comments

---
 .../VectorToSPIRV/VectorToSPIRV.cpp           |  4 ----
 .../VectorToSPIRV/vector-to-spirv.mlir        | 22 +++++++++++++++++++
 2 files changed, 22 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index ee8dccf025a0c6..b6b5a1cf939e49 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -245,10 +245,6 @@ struct VectorInsertOpConvert final
     std::optional<int64_t> id =
         getConstantIntValue(insertOp.getMixedPosition()[0]);
 
-    //    rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
-    //        insertOp, adaptor.getSource(), adaptor.getDest(), id);
-    //    return success();
-
     if (id.has_value())
       rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
           insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 62210108aa73cf..8796f153c4911b 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -186,6 +186,17 @@ func.func @extract_size1_vector(%arg0 : vector<1xf32>) -> f32 {
 
 // -----
 
+// CHECK-LABEL: @extract_size1_vector_dynamic
+//  CHECK-SAME: %[[ARG0:.+]]: vector<1xf32>
+//       CHECK:   %[[R:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
+//       CHECK:   return %[[R]]
+func.func @extract_size1_vector_dynamic(%arg0 : vector<1xf32>, %id : index) -> f32 {
+  %0 = vector.extract %arg0[%id] : f32 from vector<1xf32>
+  return %0: f32
+}
+
+// -----
+
 // CHECK-LABEL: @extract_dynamic
 //  CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ARG1:.*]]: index
 //       CHECK:   %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32
@@ -236,6 +247,17 @@ func.func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: f32) -> vector<1xf3
 
 // -----
 
+// CHECK-LABEL: @insert_size1_vector_dynamic
+//  CHECK-SAME: %[[V:.*]]: vector<1xf32>, %[[S:.*]]: f32
+//       CHECK:   %[[R:.+]] = builtin.unrealized_conversion_cast %[[S]]
+//       CHECK:   return %[[R]]
+func.func @insert_size1_vector_dynamic(%arg0 : vector<1xf32>, %arg1: f32, %id : index) -> vector<1xf32> {
+  %1 = vector.insert %arg1, %arg0[%id] : f32 into vector<1xf32>
+  return %1 : vector<1xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @insert_dynamic
 //  CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ARG2:.*]]: index
 //       CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : index to i32



More information about the Mlir-commits mailing list