[Mlir-commits] [mlir] ff81cc8 - [mlir][spirv] Improve vector extract/insert element conversion
Lei Zhang
llvmlistbot at llvm.org
Wed Nov 30 16:41:06 PST 2022
Author: Lei Zhang
Date: 2022-12-01T00:35:41Z
New Revision: ff81cc824f2867e8b9f8504549fb2592faea4567
URL: https://github.com/llvm/llvm-project/commit/ff81cc824f2867e8b9f8504549fb2592faea4567
DIFF: https://github.com/llvm/llvm-project/commit/ff81cc824f2867e8b9f8504549fb2592faea4567.diff
LOG: [mlir][spirv] Improve vector extract/insert element conversion
* Fix type conversions around positions--we need to use the
converted value from the adaptor.
* Convert constant position cases to composite extract/insert.
Reviewed By: kuhar
Differential Revision: https://reviews.llvm.org/D139057
Added:
Modified:
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 38696d062fbe5..11608db409f71 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -19,6 +19,7 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
@@ -191,19 +192,23 @@ struct VectorExtractElementOpConvert final
LogicalResult
matchAndRewrite(vector::ExtractElementOp extractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Type vectorType =
- getTypeConverter()->convertType(adaptor.getVector().getType());
- if (!vectorType)
+ Type dstType = getTypeConverter()->convertType(extractOp.getType());
+ if (!dstType)
return failure();
- if (vectorType.isa<spirv::ScalarType>()) {
+ if (adaptor.getVector().getType().isa<spirv::ScalarType>()) { // already a scalar
rewriter.replaceOp(extractOp, adaptor.getVector());
return success();
}
- rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
- extractOp, extractOp.getType(), adaptor.getVector(),
- extractOp.getPosition());
+ APInt constIndex; // Bound by m_ConstantInt when the converted position is constant.
+ if (!matchPattern(adaptor.getPosition(), m_ConstantInt(&constIndex))) {
+ rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
+ extractOp, dstType, adaptor.getVector(), adaptor.getPosition());
+ return success();
+ }
+ rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp, dstType,
+ adaptor.getVector(), rewriter.getI32ArrayAttr({static_cast<int32_t>(constIndex.getSExtValue())}));
return success();
}
};
@@ -224,9 +229,15 @@ struct VectorInsertElementOpConvert final
return success();
}
- rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
- insertOp, vectorType, insertOp.getDest(), adaptor.getSource(),
- insertOp.getPosition());
+ APInt cstPos; // Bound by m_ConstantInt when the converted position is constant.
+ if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos)))
+ rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
+ insertOp, adaptor.getSource(), adaptor.getDest(),
+ static_cast<int32_t>(cstPos.getSExtValue()));
+ else // Dynamic path must also consume the adaptor's (converted) dest.
+ rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
+ insertOp, vectorType, adaptor.getDest(), adaptor.getSource(),
+ adaptor.getPosition());
return success();
}
};
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index a1f05e9a620c9..4bb1835e1d928 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -187,9 +187,20 @@ func.func @extract_element(%arg0 : vector<4xf32>, %id : i32) -> f32 {
// -----
+// CHECK-LABEL: @extract_element_cst
+// CHECK-SAME: %[[V:.*]]: vector<4xf32>
+// CHECK: spirv.CompositeExtract %[[V]][1 : i32] : vector<4xf32>
+func.func @extract_element_cst(%arg0 : vector<4xf32>) -> f32 {
+ %c1 = arith.constant 1 : i32
+ %elem = vector.extractelement %arg0[%c1 : i32] : vector<4xf32>
+ return %elem: f32
+}
+
+// -----
+
// CHECK-LABEL: @extract_element_index
func.func @extract_element_index(%arg0 : vector<4xf32>, %id : index) -> f32 {
- // CHECK: vector.extractelement
+ // CHECK: spirv.VectorExtractDynamic
%0 = vector.extractelement %arg0[%id : index] : vector<4xf32>
return %0: f32
}
@@ -249,9 +260,20 @@ func.func @insert_element(%val: f32, %arg0 : vector<4xf32>, %id : i32) -> vector
// -----
+// CHECK-LABEL: @insert_element_cst
+// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>
+// CHECK: spirv.CompositeInsert %[[VAL]], %[[V]][2 : i32] : f32 into vector<4xf32>
+func.func @insert_element_cst(%val: f32, %arg0 : vector<4xf32>) -> vector<4xf32> {
+ %c2 = arith.constant 2 : i32
+ %updated = vector.insertelement %val, %arg0[%c2 : i32] : vector<4xf32>
+ return %updated: vector<4xf32>
+}
+
+// -----
+
// CHECK-LABEL: @insert_element_index
func.func @insert_element_index(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vector<4xf32> {
- // CHECK: vector.insertelement
+ // CHECK: spirv.VectorInsertDynamic
%0 = vector.insertelement %val, %arg0[%id : index] : vector<4xf32>
return %0: vector<4xf32>
}
More information about the Mlir-commits
mailing list