[Mlir-commits] [mlir] [mlir][VectorToSPIRV] Add conversion for vector.extract with dynamic indices (PR #114137)

Kunwar Grover llvmlistbot at llvm.org
Tue Oct 29 15:22:07 PDT 2024


https://github.com/Groverkss created https://github.com/llvm/llvm-project/pull/114137

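None

[Archive note: no description was provided with this pull request. Per the patch below, `vector.extract` and `vector.insert` whose position is a dynamic index are lowered to `spirv.VectorExtractDynamic` / `spirv.VectorInsertDynamic`, while a position that resolves to a constant still lowers to `spirv.CompositeExtract` / `spirv.CompositeInsert`.]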

>From 9270c2b05e0b95c6b8ab78eefcf95468cea8bff2 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Tue, 29 Oct 2024 22:20:07 +0000
Subject: [PATCH] [mlir][VectorToSPIRV] Add conversion for vector.extract with
 dynamic indices

---
 .../VectorToSPIRV/VectorToSPIRV.cpp           | 48 ++++++++++---------
 .../VectorToSPIRV/vector-to-spirv.mlir        | 42 ++++++++++++++++
 2 files changed, 68 insertions(+), 22 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 6184225cb6285d..ee8dccf025a0c6 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
@@ -40,22 +41,9 @@ using namespace mlir;
 /// Returns the integer value from the first valid input element, assuming Value
 /// inputs are defined by a constant index ops and Attribute inputs are integer
 /// attributes.
-static uint64_t getFirstIntValue(ValueRange values) {
-  return values[0].getDefiningOp<arith::ConstantIndexOp>().value();
-}
-static uint64_t getFirstIntValue(ArrayRef<Attribute> attr) {
-  return cast<IntegerAttr>(attr[0]).getInt();
-}
 static uint64_t getFirstIntValue(ArrayAttr attr) {
   return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
 }
-static uint64_t getFirstIntValue(ArrayRef<OpFoldResult> foldResults) {
-  auto attr = foldResults[0].dyn_cast<Attribute>();
-  if (attr)
-    return getFirstIntValue(attr);
-
-  return getFirstIntValue(ValueRange{foldResults[0].get<Value>()});
-}
 
 /// Returns the number of bits for the given scalar/vector type.
 static int getNumBits(Type type) {
@@ -157,9 +145,6 @@ struct VectorExtractOpConvert final
   LogicalResult
   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    if (extractOp.hasDynamicPosition())
-      return failure();
-
     Type dstType = getTypeConverter()->convertType(extractOp.getType());
     if (!dstType)
       return failure();
@@ -169,9 +154,17 @@ struct VectorExtractOpConvert final
       return success();
     }
 
-    int32_t id = getFirstIntValue(extractOp.getMixedPosition());
-    rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
-        extractOp, adaptor.getVector(), id);
+    std::optional<int64_t> id =
+        getConstantIntValue(extractOp.getMixedPosition()[0]);
+
+    if (id.has_value())
+      rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
+          extractOp, dstType, adaptor.getVector(),
+          rewriter.getI32ArrayAttr(id.value()));
+    else
+      rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
+          extractOp, dstType, adaptor.getVector(),
+          adaptor.getDynamicPosition()[0]);
     return success();
   }
 };
@@ -249,9 +242,20 @@ struct VectorInsertOpConvert final
       return success();
     }
 
-    int32_t id = getFirstIntValue(insertOp.getMixedPosition());
-    rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
-        insertOp, adaptor.getSource(), adaptor.getDest(), id);
+    std::optional<int64_t> id =
+        getConstantIntValue(insertOp.getMixedPosition()[0]);
+
+    //    rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
+    //        insertOp, adaptor.getSource(), adaptor.getDest(), id);
+    //    return success();
+
+    if (id.has_value())
+      rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
+          insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
+    else
+      rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
+          insertOp, insertOp.getDest(), adaptor.getSource(),
+          adaptor.getDynamicPosition()[0]);
     return success();
   }
 };
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 25ec5d0159bd5d..62210108aa73cf 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -186,6 +186,26 @@ func.func @extract_size1_vector(%arg0 : vector<1xf32>) -> f32 {
 
 // -----
 
+// CHECK-LABEL: @extract_dynamic
+//  CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ARG1:.*]]: index
+//       CHECK:   %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32
+//       CHECK:   spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32
+func.func @extract_dynamic(%arg0 : vector<4xf32>, %id : index) -> f32 {
+  %0 = vector.extract %arg0[%id] : f32 from vector<4xf32>
+  return %0: f32
+}
+
+// CHECK-LABEL: @extract_dynamic_cst
+//  CHECK-SAME: %[[V:.*]]: vector<4xf32>
+//       CHECK:   spirv.CompositeExtract %[[V]][1 : i32] : vector<4xf32>
+func.func @extract_dynamic_cst(%arg0 : vector<4xf32>) -> f32 {
+  %idx = arith.constant 1 : index
+  %0 = vector.extract %arg0[%idx] : f32 from vector<4xf32>
+  return %0: f32
+}
+
+// -----
+
 // CHECK-LABEL: @insert
 //  CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[S:.*]]: f32
 //       CHECK:   spirv.CompositeInsert %[[S]], %[[V]][2 : i32] : f32 into vector<4xf32>
@@ -216,6 +236,28 @@ func.func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: f32) -> vector<1xf3
 
 // -----
 
+// CHECK-LABEL: @insert_dynamic
+//  CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ARG2:.*]]: index
+//       CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : index to i32
+//       CHECK:   spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32
+func.func @insert_dynamic(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vector<4xf32> {
+  %0 = vector.insert %val, %arg0[%id] : f32 into vector<4xf32>
+  return %0: vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @insert_dynamic_cst
+//  CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>
+//       CHECK:   spirv.CompositeInsert %[[VAL]], %[[V]][2 : i32] : f32 into vector<4xf32>
+func.func @insert_dynamic_cst(%val: f32, %arg0 : vector<4xf32>) -> vector<4xf32> {
+  %idx = arith.constant 2 : index
+  %0 = vector.insert %val, %arg0[%idx] : f32 into vector<4xf32>
+  return %0: vector<4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @extract_element
 //  CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32
 //       CHECK:   spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32



More information about the Mlir-commits mailing list