[Mlir-commits] [mlir] c96a85a - [mlir][VectorToSPIRV] Add conversion for vector.extract with dynamic indices (#114137)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Nov 5 23:46:08 PST 2024
Author: Kunwar Grover
Date: 2024-11-06T07:46:05Z
New Revision: c96a85abfde822f2eda9076eb40078389b21f23e
URL: https://github.com/llvm/llvm-project/commit/c96a85abfde822f2eda9076eb40078389b21f23e
DIFF: https://github.com/llvm/llvm-project/commit/c96a85abfde822f2eda9076eb40078389b21f23e.diff
LOG: [mlir][VectorToSPIRV] Add conversion for vector.extract with dynamic indices (#114137)
Added:
Modified:
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 6184225cb6285d..656b1cb3e99a1d 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
@@ -40,22 +41,9 @@ using namespace mlir;
/// Returns the integer value from the first valid input element, assuming Value
/// inputs are defined by a constant index ops and Attribute inputs are integer
/// attributes.
-static uint64_t getFirstIntValue(ValueRange values) {
- return values[0].getDefiningOp<arith::ConstantIndexOp>().value();
-}
-static uint64_t getFirstIntValue(ArrayRef<Attribute> attr) {
- return cast<IntegerAttr>(attr[0]).getInt();
-}
static uint64_t getFirstIntValue(ArrayAttr attr) {
return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
}
-static uint64_t getFirstIntValue(ArrayRef<OpFoldResult> foldResults) {
- auto attr = foldResults[0].dyn_cast<Attribute>();
- if (attr)
- return getFirstIntValue(attr);
-
- return getFirstIntValue(ValueRange{foldResults[0].get<Value>()});
-}
/// Returns the number of bits for the given scalar/vector type.
static int getNumBits(Type type) {
@@ -157,9 +145,6 @@ struct VectorExtractOpConvert final
LogicalResult
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- if (extractOp.hasDynamicPosition())
- return failure();
-
Type dstType = getTypeConverter()->convertType(extractOp.getType());
if (!dstType)
return failure();
@@ -169,9 +154,15 @@ struct VectorExtractOpConvert final
return success();
}
- int32_t id = getFirstIntValue(extractOp.getMixedPosition());
- rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
- extractOp, adaptor.getVector(), id);
+ if (std::optional<int64_t> id =
+ getConstantIntValue(extractOp.getMixedPosition()[0]))
+ rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
+ extractOp, dstType, adaptor.getVector(),
+ rewriter.getI32ArrayAttr(id.value()));
+ else
+ rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
+ extractOp, dstType, adaptor.getVector(),
+ adaptor.getDynamicPosition()[0]);
return success();
}
};
@@ -249,9 +240,16 @@ struct VectorInsertOpConvert final
return success();
}
- int32_t id = getFirstIntValue(insertOp.getMixedPosition());
- rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
- insertOp, adaptor.getSource(), adaptor.getDest(), id);
+ if (std::optional<int64_t> id =
+ getConstantIntValue(insertOp.getMixedPosition()[0]))
+ rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
+ insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
+ else
+ // Use the type-converted destination from the adaptor (as the constant
+ // branch does); the original operand may still carry an illegal type.
+ rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
+ insertOp, adaptor.getDest(), adaptor.getSource(),
+ adaptor.getDynamicPosition()[0]);
return success();
}
};
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 25ec5d0159bd5d..8796f153c4911b 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -186,6 +186,37 @@ func.func @extract_size1_vector(%arg0 : vector<1xf32>) -> f32 {
// -----
+// CHECK-LABEL: @extract_size1_vector_dynamic
+// CHECK-SAME: %[[ARG0:.+]]: vector<1xf32>
+// CHECK: %[[R:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
+// CHECK: return %[[R]]
+func.func @extract_size1_vector_dynamic(%arg0 : vector<1xf32>, %id : index) -> f32 {
+ %0 = vector.extract %arg0[%id] : f32 from vector<1xf32>
+ return %0: f32
+}
+
+// -----
+
+// CHECK-LABEL: @extract_dynamic
+// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ARG1:.*]]: index
+// CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32
+// CHECK: spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32
+func.func @extract_dynamic(%arg0 : vector<4xf32>, %id : index) -> f32 {
+ %0 = vector.extract %arg0[%id] : f32 from vector<4xf32>
+ return %0: f32
+}
+
+// CHECK-LABEL: @extract_dynamic_cst
+// CHECK-SAME: %[[V:.*]]: vector<4xf32>
+// CHECK: spirv.CompositeExtract %[[V]][1 : i32] : vector<4xf32>
+func.func @extract_dynamic_cst(%arg0 : vector<4xf32>) -> f32 {
+ %idx = arith.constant 1 : index
+ %0 = vector.extract %arg0[%idx] : f32 from vector<4xf32>
+ return %0: f32
+}
+
+// -----
+
// CHECK-LABEL: @insert
// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[S:.*]]: f32
// CHECK: spirv.CompositeInsert %[[S]], %[[V]][2 : i32] : f32 into vector<4xf32>
@@ -216,6 +247,39 @@ func.func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: f32) -> vector<1xf3
// -----
+// CHECK-LABEL: @insert_size1_vector_dynamic
+// CHECK-SAME: %[[V:.*]]: vector<1xf32>, %[[S:.*]]: f32
+// CHECK: %[[R:.+]] = builtin.unrealized_conversion_cast %[[S]]
+// CHECK: return %[[R]]
+func.func @insert_size1_vector_dynamic(%arg0 : vector<1xf32>, %arg1: f32, %id : index) -> vector<1xf32> {
+ %1 = vector.insert %arg1, %arg0[%id] : f32 into vector<1xf32>
+ return %1 : vector<1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @insert_dynamic
+// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ARG2:.*]]: index
+// CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : index to i32
+// CHECK: spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32
+func.func @insert_dynamic(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vector<4xf32> {
+ %0 = vector.insert %val, %arg0[%id] : f32 into vector<4xf32>
+ return %0: vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @insert_dynamic_cst
+// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>
+// CHECK: spirv.CompositeInsert %[[VAL]], %[[V]][2 : i32] : f32 into vector<4xf32>
+func.func @insert_dynamic_cst(%val: f32, %arg0 : vector<4xf32>) -> vector<4xf32> {
+ %idx = arith.constant 2 : index
+ %0 = vector.insert %val, %arg0[%idx] : f32 into vector<4xf32>
+ return %0: vector<4xf32>
+}
+
+// -----
+
// CHECK-LABEL: @extract_element
// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32
// CHECK: spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32
More information about the Mlir-commits mailing list