[Mlir-commits] [mlir] [mlir][VectorToSPIRV] Add conversion for vector.extract with dynamic indices (PR #114137)
Kunwar Grover
llvmlistbot at llvm.org
Tue Oct 29 15:22:07 PDT 2024
https://github.com/Groverkss created https://github.com/llvm/llvm-project/pull/114137
None
>From 9270c2b05e0b95c6b8ab78eefcf95468cea8bff2 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Tue, 29 Oct 2024 22:20:07 +0000
Subject: [PATCH] [mlir][VectorToSPIRV] Add conversion for vector.extract with
dynamic indices
---
.../VectorToSPIRV/VectorToSPIRV.cpp | 48 ++++++++++---------
.../VectorToSPIRV/vector-to-spirv.mlir | 42 ++++++++++++++++
2 files changed, 68 insertions(+), 22 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 6184225cb6285d..ee8dccf025a0c6 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
@@ -40,22 +41,9 @@ using namespace mlir;
/// Returns the integer value from the first valid input element, assuming Value
/// inputs are defined by a constant index ops and Attribute inputs are integer
/// attributes.
-static uint64_t getFirstIntValue(ValueRange values) {
- return values[0].getDefiningOp<arith::ConstantIndexOp>().value();
-}
-static uint64_t getFirstIntValue(ArrayRef<Attribute> attr) {
- return cast<IntegerAttr>(attr[0]).getInt();
-}
static uint64_t getFirstIntValue(ArrayAttr attr) {
return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
}
-static uint64_t getFirstIntValue(ArrayRef<OpFoldResult> foldResults) {
- auto attr = foldResults[0].dyn_cast<Attribute>();
- if (attr)
- return getFirstIntValue(attr);
-
- return getFirstIntValue(ValueRange{foldResults[0].get<Value>()});
-}
/// Returns the number of bits for the given scalar/vector type.
static int getNumBits(Type type) {
@@ -157,9 +145,6 @@ struct VectorExtractOpConvert final
LogicalResult
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- if (extractOp.hasDynamicPosition())
- return failure();
-
Type dstType = getTypeConverter()->convertType(extractOp.getType());
if (!dstType)
return failure();
@@ -169,9 +154,17 @@ struct VectorExtractOpConvert final
return success();
}
- int32_t id = getFirstIntValue(extractOp.getMixedPosition());
- rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
- extractOp, adaptor.getVector(), id);
+ std::optional<int64_t> id =
+ getConstantIntValue(extractOp.getMixedPosition()[0]);
+
+ if (id.has_value())
+ rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
+ extractOp, dstType, adaptor.getVector(),
+ rewriter.getI32ArrayAttr(id.value()));
+ else
+ rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
+ extractOp, dstType, adaptor.getVector(),
+ adaptor.getDynamicPosition()[0]);
return success();
}
};
@@ -249,9 +242,20 @@ struct VectorInsertOpConvert final
return success();
}
- int32_t id = getFirstIntValue(insertOp.getMixedPosition());
- rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
- insertOp, adaptor.getSource(), adaptor.getDest(), id);
+ std::optional<int64_t> id =
+ getConstantIntValue(insertOp.getMixedPosition()[0]);
+
+ // rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
+ // insertOp, adaptor.getSource(), adaptor.getDest(), id);
+ // return success();
+
+ if (id.has_value())
+ rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
+ insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
+ else
+ rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
+ insertOp, insertOp.getDest(), adaptor.getSource(),
+ adaptor.getDynamicPosition()[0]);
return success();
}
};
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 25ec5d0159bd5d..62210108aa73cf 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -186,6 +186,26 @@ func.func @extract_size1_vector(%arg0 : vector<1xf32>) -> f32 {
// -----
+// CHECK-LABEL: @extract_dynamic
+// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ARG1:.*]]: index
+// CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32
+// CHECK: spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32
+func.func @extract_dynamic(%arg0 : vector<4xf32>, %id : index) -> f32 {
+ %0 = vector.extract %arg0[%id] : f32 from vector<4xf32>
+ return %0: f32
+}
+
+// CHECK-LABEL: @extract_dynamic_cst
+// CHECK-SAME: %[[V:.*]]: vector<4xf32>
+// CHECK: spirv.CompositeExtract %[[V]][1 : i32] : vector<4xf32>
+func.func @extract_dynamic_cst(%arg0 : vector<4xf32>) -> f32 {
+ %idx = arith.constant 1 : index
+ %0 = vector.extract %arg0[%idx] : f32 from vector<4xf32>
+ return %0: f32
+}
+
+// -----
+
// CHECK-LABEL: @insert
// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[S:.*]]: f32
// CHECK: spirv.CompositeInsert %[[S]], %[[V]][2 : i32] : f32 into vector<4xf32>
@@ -216,6 +236,28 @@ func.func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: f32) -> vector<1xf3
// -----
+// CHECK-LABEL: @insert_dynamic
+// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ARG2:.*]]: index
+// CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : index to i32
+// CHECK: spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32
+func.func @insert_dynamic(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vector<4xf32> {
+ %0 = vector.insert %val, %arg0[%id] : f32 into vector<4xf32>
+ return %0: vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @insert_dynamic_cst
+// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>
+// CHECK: spirv.CompositeInsert %[[VAL]], %[[V]][2 : i32] : f32 into vector<4xf32>
+func.func @insert_dynamic_cst(%val: f32, %arg0 : vector<4xf32>) -> vector<4xf32> {
+ %idx = arith.constant 2 : index
+ %0 = vector.insert %val, %arg0[%idx] : f32 into vector<4xf32>
+ return %0: vector<4xf32>
+}
+
+// -----
+
// CHECK-LABEL: @extract_element
// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32
// CHECK: spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32
More information about the Mlir-commits
mailing list