[Mlir-commits] [mlir] c96a85a - [mlir][VectorToSPIRV] Add conversion for vector.extract with dynamic indices (#114137)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Nov 5 23:46:08 PST 2024


Author: Kunwar Grover
Date: 2024-11-06T07:46:05Z
New Revision: c96a85abfde822f2eda9076eb40078389b21f23e

URL: https://github.com/llvm/llvm-project/commit/c96a85abfde822f2eda9076eb40078389b21f23e
DIFF: https://github.com/llvm/llvm-project/commit/c96a85abfde822f2eda9076eb40078389b21f23e.diff

LOG: [mlir][VectorToSPIRV] Add conversion for vector.extract with dynamic indices (#114137)

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
    mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 6184225cb6285d..656b1cb3e99a1d 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
@@ -40,22 +41,9 @@ using namespace mlir;
 /// Returns the integer value from the first valid input element, assuming Value
 /// inputs are defined by a constant index ops and Attribute inputs are integer
 /// attributes.
-static uint64_t getFirstIntValue(ValueRange values) {
-  return values[0].getDefiningOp<arith::ConstantIndexOp>().value();
-}
-static uint64_t getFirstIntValue(ArrayRef<Attribute> attr) {
-  return cast<IntegerAttr>(attr[0]).getInt();
-}
 static uint64_t getFirstIntValue(ArrayAttr attr) {
   return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
 }
-static uint64_t getFirstIntValue(ArrayRef<OpFoldResult> foldResults) {
-  auto attr = foldResults[0].dyn_cast<Attribute>();
-  if (attr)
-    return getFirstIntValue(attr);
-
-  return getFirstIntValue(ValueRange{foldResults[0].get<Value>()});
-}
 
 /// Returns the number of bits for the given scalar/vector type.
 static int getNumBits(Type type) {
@@ -157,9 +145,6 @@ struct VectorExtractOpConvert final
   LogicalResult
   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    if (extractOp.hasDynamicPosition())
-      return failure();
-
     Type dstType = getTypeConverter()->convertType(extractOp.getType());
     if (!dstType)
       return failure();
@@ -169,9 +154,15 @@ struct VectorExtractOpConvert final
       return success();
     }
 
-    int32_t id = getFirstIntValue(extractOp.getMixedPosition());
-    rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
-        extractOp, adaptor.getVector(), id);
+    if (std::optional<int64_t> id =
+            getConstantIntValue(extractOp.getMixedPosition()[0]))
+      rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
+          extractOp, dstType, adaptor.getVector(),
+          rewriter.getI32ArrayAttr(id.value()));
+    else
+      rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
+          extractOp, dstType, adaptor.getVector(),
+          adaptor.getDynamicPosition()[0]);
     return success();
   }
 };
@@ -249,9 +240,14 @@ struct VectorInsertOpConvert final
       return success();
     }
 
-    int32_t id = getFirstIntValue(insertOp.getMixedPosition());
-    rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
-        insertOp, adaptor.getSource(), adaptor.getDest(), id);
+    if (std::optional<int64_t> id =
+            getConstantIntValue(insertOp.getMixedPosition()[0]))
+      rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
+          insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
+    else
+      rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
+          insertOp, insertOp.getDest(), adaptor.getSource(),
+          adaptor.getDynamicPosition()[0]);
     return success();
   }
 };

diff  --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 25ec5d0159bd5d..8796f153c4911b 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -186,6 +186,37 @@ func.func @extract_size1_vector(%arg0 : vector<1xf32>) -> f32 {
 
 // -----
 
+// CHECK-LABEL: @extract_size1_vector_dynamic
+//  CHECK-SAME: %[[ARG0:.+]]: vector<1xf32>
+//       CHECK:   %[[R:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
+//       CHECK:   return %[[R]]
+func.func @extract_size1_vector_dynamic(%arg0 : vector<1xf32>, %id : index) -> f32 {
+  %0 = vector.extract %arg0[%id] : f32 from vector<1xf32>
+  return %0: f32
+}
+
+// -----
+
+// CHECK-LABEL: @extract_dynamic
+//  CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ARG1:.*]]: index
+//       CHECK:   %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32
+//       CHECK:   spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32
+func.func @extract_dynamic(%arg0 : vector<4xf32>, %id : index) -> f32 {
+  %0 = vector.extract %arg0[%id] : f32 from vector<4xf32>
+  return %0: f32
+}
+
+// CHECK-LABEL: @extract_dynamic_cst
+//  CHECK-SAME: %[[V:.*]]: vector<4xf32>
+//       CHECK:   spirv.CompositeExtract %[[V]][1 : i32] : vector<4xf32>
+func.func @extract_dynamic_cst(%arg0 : vector<4xf32>) -> f32 {
+  %idx = arith.constant 1 : index
+  %0 = vector.extract %arg0[%idx] : f32 from vector<4xf32>
+  return %0: f32
+}
+
+// -----
+
 // CHECK-LABEL: @insert
 //  CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[S:.*]]: f32
 //       CHECK:   spirv.CompositeInsert %[[S]], %[[V]][2 : i32] : f32 into vector<4xf32>
@@ -216,6 +247,39 @@ func.func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: f32) -> vector<1xf3
 
 // -----
 
+// CHECK-LABEL: @insert_size1_vector_dynamic
+//  CHECK-SAME: %[[V:.*]]: vector<1xf32>, %[[S:.*]]: f32
+//       CHECK:   %[[R:.+]] = builtin.unrealized_conversion_cast %[[S]]
+//       CHECK:   return %[[R]]
+func.func @insert_size1_vector_dynamic(%arg0 : vector<1xf32>, %arg1: f32, %id : index) -> vector<1xf32> {
+  %1 = vector.insert %arg1, %arg0[%id] : f32 into vector<1xf32>
+  return %1 : vector<1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @insert_dynamic
+//  CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ARG2:.*]]: index
+//       CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : index to i32
+//       CHECK:   spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32
+func.func @insert_dynamic(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vector<4xf32> {
+  %0 = vector.insert %val, %arg0[%id] : f32 into vector<4xf32>
+  return %0: vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @insert_dynamic_cst
+//  CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>
+//       CHECK:   spirv.CompositeInsert %[[VAL]], %[[V]][2 : i32] : f32 into vector<4xf32>
+func.func @insert_dynamic_cst(%val: f32, %arg0 : vector<4xf32>) -> vector<4xf32> {
+  %idx = arith.constant 2 : index
+  %0 = vector.insert %val, %arg0[%idx] : f32 into vector<4xf32>
+  return %0: vector<4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @extract_element
 //  CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32
 //       CHECK:   spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32


        


More information about the Mlir-commits mailing list