[Mlir-commits] [mlir] 223be5f - [mlir][spirv] Perform partial conversion in VectorToSPIRVPass
Lei Zhang
llvmlistbot at llvm.org
Thu Dec 16 06:41:02 PST 2021
Author: Lei Zhang
Date: 2021-12-16T09:35:56-05:00
New Revision: 223be5f630c0fa71a69a591a9b0317953093035c
URL: https://github.com/llvm/llvm-project/commit/223be5f630c0fa71a69a591a9b0317953093035c
DIFF: https://github.com/llvm/llvm-project/commit/223be5f630c0fa71a69a591a9b0317953093035c.diff
LOG: [mlir][spirv] Perform partial conversion in VectorToSPIRVPass
This allows the pass to participate in progressive lowering,
and it also makes the conversion easier to test.
Along the way, the tests were cleaned up.
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D115756
Added:
Modified:
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
mlir/test/Conversion/VectorToSPIRV/simple.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
index 9ffd1e595cd59..7391defe78589 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
@@ -22,13 +22,13 @@
using namespace mlir;
namespace {
-struct LowerVectorToSPIRVPass
- : public ConvertVectorToSPIRVBase<LowerVectorToSPIRVPass> {
+struct ConvertVectorToSPIRVPass
+ : public ConvertVectorToSPIRVBase<ConvertVectorToSPIRVPass> {
void runOnOperation() override;
};
} // namespace
-void LowerVectorToSPIRVPass::runOnOperation() {
+void ConvertVectorToSPIRVPass::runOnOperation() {
MLIRContext *context = &getContext();
ModuleOp module = getOperation();
@@ -37,17 +37,26 @@ void LowerVectorToSPIRVPass::runOnOperation() {
SPIRVConversionTarget::get(targetAttr);
SPIRVTypeConverter typeConverter(targetAttr);
+
+ // Use UnrealizedConversionCast as the bridge so that we don't need to pull in
+ // patterns for other dialects.
+ auto addUnrealizedCast = [](OpBuilder &builder, Type type, ValueRange inputs,
+ Location loc) {
+ auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
+ return Optional<Value>(cast.getResult(0));
+ };
+ typeConverter.addSourceMaterialization(addUnrealizedCast);
+ typeConverter.addTargetMaterialization(addUnrealizedCast);
+ target->addLegalOp<UnrealizedConversionCastOp>();
+
RewritePatternSet patterns(context);
populateVectorToSPIRVPatterns(typeConverter, patterns);
- target->addLegalOp<ModuleOp>();
- target->addLegalOp<FuncOp>();
-
- if (failed(applyFullConversion(module, *target, std::move(patterns))))
+ if (failed(applyPartialConversion(module, *target, std::move(patterns))))
return signalPassFailure();
}
std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertVectorToSPIRVPass() {
- return std::make_unique<LowerVectorToSPIRVPass>();
+ return std::make_unique<ConvertVectorToSPIRVPass>();
}
diff --git a/mlir/test/Conversion/VectorToSPIRV/simple.mlir b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
index a253fc7fbbcbe..8f5cf197713d2 100644
--- a/mlir/test/Conversion/VectorToSPIRV/simple.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
@@ -2,152 +2,147 @@
module attributes { spv.target_env = #spv.target_env<#spv.vce<v1.0, [Float16], []>, {}> } {
-// CHECK-LABEL: func @bitcast
+// CHECK-LABEL: @bitcast
// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>, %[[ARG1:.+]]: vector<2xf16>
-// CHECK: %{{.+}} = spv.Bitcast %[[ARG0]] : vector<2xf32> to vector<4xf16>
-// CHECK: %{{.+}} = spv.Bitcast %[[ARG1]] : vector<2xf16> to f32
-func @bitcast(%arg0 : vector<2xf32>, %arg1: vector<2xf16>) {
+// CHECK: spv.Bitcast %[[ARG0]] : vector<2xf32> to vector<4xf16>
+// CHECK: spv.Bitcast %[[ARG1]] : vector<2xf16> to f32
+func @bitcast(%arg0 : vector<2xf32>, %arg1: vector<2xf16>) -> (vector<4xf16>, vector<1xf32>) {
%0 = vector.bitcast %arg0 : vector<2xf32> to vector<4xf16>
%1 = vector.bitcast %arg1 : vector<2xf16> to vector<1xf32>
- spv.Return
+ return %0, %1: vector<4xf16>, vector<1xf32>
}
} // end module
// -----
-// CHECK-LABEL: broadcast
+// CHECK-LABEL: @broadcast
// CHECK-SAME: %[[A:.*]]: f32
// CHECK: spv.CompositeConstruct %[[A]], %[[A]], %[[A]], %[[A]] : vector<4xf32>
// CHECK: spv.CompositeConstruct %[[A]], %[[A]] : vector<2xf32>
-func @broadcast(%arg0 : f32) {
+func @broadcast(%arg0 : f32) -> (vector<4xf32>, vector<2xf32>) {
%0 = vector.broadcast %arg0 : f32 to vector<4xf32>
%1 = vector.broadcast %arg0 : f32 to vector<2xf32>
- spv.Return
+ return %0, %1: vector<4xf32>, vector<2xf32>
}
// -----
-// CHECK-LABEL: func @extract
+// CHECK-LABEL: @extract
// CHECK-SAME: %[[ARG:.+]]: vector<2xf32>
-// CHECK: %{{.+}} = spv.CompositeExtract %[[ARG]][0 : i32] : vector<2xf32>
-// CHECK: %{{.+}} = spv.CompositeExtract %[[ARG]][1 : i32] : vector<2xf32>
-func @extract(%arg0 : vector<2xf32>) {
+// CHECK: spv.CompositeExtract %[[ARG]][0 : i32] : vector<2xf32>
+// CHECK: spv.CompositeExtract %[[ARG]][1 : i32] : vector<2xf32>
+func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) {
%0 = "vector.extract"(%arg0) {position = [0]} : (vector<2xf32>) -> vector<1xf32>
%1 = "vector.extract"(%arg0) {position = [1]} : (vector<2xf32>) -> f32
- spv.Return
+ return %0, %1: vector<1xf32>, f32
}
// -----
-module attributes { spv.target_env = #spv.target_env<#spv.vce<v1.0, [Float16], []>, {}> } {
-
-// CHECK-LABEL: func @extract_scalar
-// CHECK-SAME: %[[ARG0:.+]]: vector<2xf16>
-// CHECK-SAME: %[[ARG1:.+]]: vector<4xf32>
-// CHECK: %[[S:.+]] = spv.Bitcast %[[ARG0]] : vector<2xf16> to f32
-// CHECK: spv.CompositeInsert %[[S]], %[[ARG1]][0 : i32] : f32 into vector<4xf32>
-func @extract_scalar(%arg0 : vector<2xf16>, %arg1 : vector<4xf32>) {
- %0 = vector.bitcast %arg0 : vector<2xf16> to vector<1xf32>
- %1 = vector.extract %0[0] : vector<1xf32>
- %2 = vector.insert %1, %arg1[0] : f32 into vector<4xf32>
- spv.Return
+// CHECK-LABEL: @extract_size1_vector
+// CHECK-SAME: %[[ARG0:.+]]: vector<1xf32>
+// CHECK: %[[R:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
+// CHECK: return %[[R]]
+func @extract_size1_vector(%arg0 : vector<1xf32>) -> f32 {
+ %0 = vector.extract %arg0[0] : vector<1xf32>
+ return %0: f32
}
-} // end module
-
// -----
-// CHECK-LABEL: extract_insert
-// CHECK-SAME: %[[V:.*]]: vector<4xf32>
-// CHECK: %[[S:.*]] = spv.CompositeExtract %[[V]][1 : i32] : vector<4xf32>
-// CHECK: spv.CompositeInsert %[[S]], %[[V]][0 : i32] : f32 into vector<4xf32>
-func @extract_insert(%arg0 : vector<4xf32>) {
- %0 = vector.extract %arg0[1] : vector<4xf32>
- %1 = vector.insert %0, %arg0[0] : f32 into vector<4xf32>
- spv.Return
+// CHECK-LABEL: @insert
+// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[S:.*]]: f32
+// CHECK: spv.CompositeInsert %[[S]], %[[V]][2 : i32] : f32 into vector<4xf32>
+func @insert(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
+ %1 = vector.insert %arg1, %arg0[2] : f32 into vector<4xf32>
+ return %1: vector<4xf32>
}
// -----
-// CHECK-LABEL: extract_element
+// CHECK-LABEL: @extract_element
// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32
// CHECK: spv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32
-func @extract_element(%arg0 : vector<4xf32>, %id : i32) {
+func @extract_element(%arg0 : vector<4xf32>, %id : i32) -> f32 {
%0 = vector.extractelement %arg0[%id : i32] : vector<4xf32>
- spv.ReturnValue %0: f32
+ return %0: f32
}
// -----
-func @extract_element_index(%arg0 : vector<4xf32>, %id : index) {
-// expected-error @+1 {{failed to legalize operation 'vector.extractelement'}}
+// CHECK-LABEL: @extract_element_index
+func @extract_element_index(%arg0 : vector<4xf32>, %id : index) -> f32 {
+ // CHECK: vector.extractelement
%0 = vector.extractelement %arg0[%id : index] : vector<4xf32>
- spv.ReturnValue %0: f32
+ return %0: f32
}
// -----
-func @extract_element_negative(%arg0 : vector<5xf32>, %id : i32) {
-// expected-error @+1 {{failed to legalize operation 'vector.extractelement'}}
+// CHECK-LABEL: @extract_element_size5_vector
+func @extract_element_size5_vector(%arg0 : vector<5xf32>, %id : i32) -> f32 {
+ // CHECK: vector.extractelement
%0 = vector.extractelement %arg0[%id : i32] : vector<5xf32>
- spv.ReturnValue %0: f32
+ return %0: f32
}
// -----
-// CHECK-LABEL: func @extract_strided_slice
+// CHECK-LABEL: @extract_strided_slice
// CHECK-SAME: %[[ARG:.+]]: vector<4xf32>
-// CHECK: %{{.+}} = spv.VectorShuffle [1 : i32, 2 : i32] %[[ARG]] : vector<4xf32>, %[[ARG]] : vector<4xf32> -> vector<2xf32>
-// CHECK: %{{.+}} = spv.CompositeExtract %[[ARG]][1 : i32] : vector<4xf32>
-func @extract_strided_slice(%arg0: vector<4xf32>) {
+// CHECK: spv.VectorShuffle [1 : i32, 2 : i32] %[[ARG]] : vector<4xf32>, %[[ARG]] : vector<4xf32> -> vector<2xf32>
+// CHECK: spv.CompositeExtract %[[ARG]][1 : i32] : vector<4xf32>
+func @extract_strided_slice(%arg0: vector<4xf32>) -> (vector<2xf32>, vector<1xf32>) {
%0 = vector.extract_strided_slice %arg0 {offsets = [1], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%1 = vector.extract_strided_slice %arg0 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
- spv.Return
+ return %0, %1 : vector<2xf32>, vector<1xf32>
}
// -----
-// CHECK-LABEL: insert_element
+// CHECK-LABEL: @insert_element
// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32
// CHECK: spv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32
-func @insert_element(%val: f32, %arg0 : vector<4xf32>, %id : i32) {
+func @insert_element(%val: f32, %arg0 : vector<4xf32>, %id : i32) -> vector<4xf32> {
%0 = vector.insertelement %val, %arg0[%id : i32] : vector<4xf32>
- spv.ReturnValue %0: vector<4xf32>
+ return %0: vector<4xf32>
}
// -----
-func @insert_element_index(%val: f32, %arg0 : vector<4xf32>, %id : index) {
-// expected-error @+1 {{failed to legalize operation 'vector.insertelement'}}
+// CHECK-LABEL: @insert_element_index
+func @insert_element_index(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vector<4xf32> {
+ // CHECK: vector.insertelement
%0 = vector.insertelement %val, %arg0[%id : index] : vector<4xf32>
- spv.ReturnValue %0: vector<4xf32>
+ return %0: vector<4xf32>
}
// -----
-func @insert_element_negative(%val: f32, %arg0 : vector<5xf32>, %id : i32) {
-// expected-error @+1 {{failed to legalize operation 'vector.insertelement'}}
+// CHECK-LABEL: @insert_element_size5_vector
+func @insert_element_size5_vector(%val: f32, %arg0 : vector<5xf32>, %id : i32) -> vector<5xf32> {
+ // CHECK: vector.insertelement
%0 = vector.insertelement %val, %arg0[%id : i32] : vector<5xf32>
- spv.Return
+ return %0 : vector<5xf32>
}
// -----
-// CHECK-LABEL: func @insert_strided_slice
+// CHECK-LABEL: @insert_strided_slice
// CHECK-SAME: %[[PART:.+]]: vector<2xf32>, %[[ALL:.+]]: vector<4xf32>
-// CHECK: %{{.+}} = spv.VectorShuffle [0 : i32, 4 : i32, 5 : i32, 3 : i32] %[[ALL]] : vector<4xf32>, %[[PART]] : vector<2xf32> -> vector<4xf32>
-func @insert_strided_slice(%arg0: vector<2xf32>, %arg1: vector<4xf32>) {
+// CHECK: spv.VectorShuffle [0 : i32, 4 : i32, 5 : i32, 3 : i32] %[[ALL]] : vector<4xf32>, %[[PART]] : vector<2xf32> -> vector<4xf32>
+func @insert_strided_slice(%arg0: vector<2xf32>, %arg1: vector<4xf32>) -> vector<4xf32> {
%0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [1], strides = [1]} : vector<2xf32> into vector<4xf32>
- spv.Return
+ return %0 : vector<4xf32>
}
// -----
-// CHECK-LABEL: func @fma
+// CHECK-LABEL: @fma
// CHECK-SAME: %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: vector<4xf32>
// CHECK: spv.GLSL.Fma %[[A]], %[[B]], %[[C]] : vector<4xf32>
-func @fma(%a: vector<4xf32>, %b: vector<4xf32>, %c: vector<4xf32>) {
+func @fma(%a: vector<4xf32>, %b: vector<4xf32>, %c: vector<4xf32>) -> vector<4xf32> {
%0 = vector.fma %a, %b, %c: vector<4xf32>
- spv.Return
+ return %0 : vector<4xf32>
}
More information about the Mlir-commits mailing list