[Mlir-commits] [mlir] 223be5f - [mlir][spirv] Perform partial conversion in VectorToSPIRVPass

Lei Zhang llvmlistbot at llvm.org
Thu Dec 16 06:41:02 PST 2021


Author: Lei Zhang
Date: 2021-12-16T09:35:56-05:00
New Revision: 223be5f630c0fa71a69a591a9b0317953093035c

URL: https://github.com/llvm/llvm-project/commit/223be5f630c0fa71a69a591a9b0317953093035c
DIFF: https://github.com/llvm/llvm-project/commit/223be5f630c0fa71a69a591a9b0317953093035c.diff

LOG: [mlir][spirv] Perform partial conversion in VectorToSPIRVPass

This allows the pass to participate in progressive lowering,
and it also lets us write better tests.

Along the way, cleaned up the tests.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D115756

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
    mlir/test/Conversion/VectorToSPIRV/simple.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
index 9ffd1e595cd59..7391defe78589 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
@@ -22,13 +22,13 @@
 using namespace mlir;
 
 namespace {
-struct LowerVectorToSPIRVPass
-    : public ConvertVectorToSPIRVBase<LowerVectorToSPIRVPass> {
+struct ConvertVectorToSPIRVPass
+    : public ConvertVectorToSPIRVBase<ConvertVectorToSPIRVPass> {
   void runOnOperation() override;
 };
 } // namespace
 
-void LowerVectorToSPIRVPass::runOnOperation() {
+void ConvertVectorToSPIRVPass::runOnOperation() {
   MLIRContext *context = &getContext();
   ModuleOp module = getOperation();
 
@@ -37,17 +37,26 @@ void LowerVectorToSPIRVPass::runOnOperation() {
       SPIRVConversionTarget::get(targetAttr);
 
   SPIRVTypeConverter typeConverter(targetAttr);
+
+  // Use UnrealizedConversionCast as the bridge so that we don't need to pull in
+  // patterns for other dialects.
+  auto addUnrealizedCast = [](OpBuilder &builder, Type type, ValueRange inputs,
+                              Location loc) {
+    auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
+    return Optional<Value>(cast.getResult(0));
+  };
+  typeConverter.addSourceMaterialization(addUnrealizedCast);
+  typeConverter.addTargetMaterialization(addUnrealizedCast);
+  target->addLegalOp<UnrealizedConversionCastOp>();
+
   RewritePatternSet patterns(context);
   populateVectorToSPIRVPatterns(typeConverter, patterns);
 
-  target->addLegalOp<ModuleOp>();
-  target->addLegalOp<FuncOp>();
-
-  if (failed(applyFullConversion(module, *target, std::move(patterns))))
+  if (failed(applyPartialConversion(module, *target, std::move(patterns))))
     return signalPassFailure();
 }
 
 std::unique_ptr<OperationPass<ModuleOp>>
 mlir::createConvertVectorToSPIRVPass() {
-  return std::make_unique<LowerVectorToSPIRVPass>();
+  return std::make_unique<ConvertVectorToSPIRVPass>();
 }

diff  --git a/mlir/test/Conversion/VectorToSPIRV/simple.mlir b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
index a253fc7fbbcbe..8f5cf197713d2 100644
--- a/mlir/test/Conversion/VectorToSPIRV/simple.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
@@ -2,152 +2,147 @@
 
 module attributes { spv.target_env = #spv.target_env<#spv.vce<v1.0, [Float16], []>, {}> } {
 
-// CHECK-LABEL: func @bitcast
+// CHECK-LABEL: @bitcast
 //  CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>, %[[ARG1:.+]]: vector<2xf16>
-//       CHECK:   %{{.+}} = spv.Bitcast %[[ARG0]] : vector<2xf32> to vector<4xf16>
-//       CHECK:   %{{.+}} = spv.Bitcast %[[ARG1]] : vector<2xf16> to f32
-func @bitcast(%arg0 : vector<2xf32>, %arg1: vector<2xf16>) {
+//       CHECK:   spv.Bitcast %[[ARG0]] : vector<2xf32> to vector<4xf16>
+//       CHECK:   spv.Bitcast %[[ARG1]] : vector<2xf16> to f32
+func @bitcast(%arg0 : vector<2xf32>, %arg1: vector<2xf16>) -> (vector<4xf16>, vector<1xf32>) {
   %0 = vector.bitcast %arg0 : vector<2xf32> to vector<4xf16>
   %1 = vector.bitcast %arg1 : vector<2xf16> to vector<1xf32>
-  spv.Return
+  return %0, %1: vector<4xf16>, vector<1xf32>
 }
 
 } // end module
 
 // -----
 
-// CHECK-LABEL: broadcast
+// CHECK-LABEL: @broadcast
 //  CHECK-SAME: %[[A:.*]]: f32
 //       CHECK:   spv.CompositeConstruct %[[A]], %[[A]], %[[A]], %[[A]] : vector<4xf32>
 //       CHECK:   spv.CompositeConstruct %[[A]], %[[A]] : vector<2xf32>
-func @broadcast(%arg0 : f32) {
+func @broadcast(%arg0 : f32) -> (vector<4xf32>, vector<2xf32>) {
   %0 = vector.broadcast %arg0 : f32 to vector<4xf32>
   %1 = vector.broadcast %arg0 : f32 to vector<2xf32>
-  spv.Return
+  return %0, %1: vector<4xf32>, vector<2xf32>
 }
 
 // -----
 
-// CHECK-LABEL: func @extract
+// CHECK-LABEL: @extract
 //  CHECK-SAME: %[[ARG:.+]]: vector<2xf32>
-//       CHECK:   %{{.+}} = spv.CompositeExtract %[[ARG]][0 : i32] : vector<2xf32>
-//       CHECK:   %{{.+}} = spv.CompositeExtract %[[ARG]][1 : i32] : vector<2xf32>
-func @extract(%arg0 : vector<2xf32>) {
+//       CHECK:   spv.CompositeExtract %[[ARG]][0 : i32] : vector<2xf32>
+//       CHECK:   spv.CompositeExtract %[[ARG]][1 : i32] : vector<2xf32>
+func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) {
   %0 = "vector.extract"(%arg0) {position = [0]} : (vector<2xf32>) -> vector<1xf32>
   %1 = "vector.extract"(%arg0) {position = [1]} : (vector<2xf32>) -> f32
-  spv.Return
+  return %0, %1: vector<1xf32>, f32
 }
 
 // -----
 
-module attributes { spv.target_env = #spv.target_env<#spv.vce<v1.0, [Float16], []>, {}> } {
-
-// CHECK-LABEL: func @extract_scalar
-//  CHECK-SAME: %[[ARG0:.+]]: vector<2xf16>
-//  CHECK-SAME: %[[ARG1:.+]]: vector<4xf32>
-//       CHECK:   %[[S:.+]] = spv.Bitcast %[[ARG0]] : vector<2xf16> to f32
-//       CHECK:   spv.CompositeInsert %[[S]], %[[ARG1]][0 : i32] : f32 into vector<4xf32>
-func @extract_scalar(%arg0 : vector<2xf16>, %arg1 : vector<4xf32>) {
-  %0 = vector.bitcast %arg0 : vector<2xf16> to vector<1xf32>
-  %1 = vector.extract %0[0] : vector<1xf32>
-  %2 = vector.insert %1, %arg1[0] : f32 into vector<4xf32>
-  spv.Return
+// CHECK-LABEL: @extract_size1_vector
+//  CHECK-SAME: %[[ARG0:.+]]: vector<1xf32>
+//       CHECK:   %[[R:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
+//       CHECK:   return %[[R]]
+func @extract_size1_vector(%arg0 : vector<1xf32>) -> f32 {
+  %0 = vector.extract %arg0[0] : vector<1xf32>
+	return %0: f32
 }
 
-} // end module
-
 // -----
 
-// CHECK-LABEL: extract_insert
-//  CHECK-SAME: %[[V:.*]]: vector<4xf32>
-//       CHECK:   %[[S:.*]] = spv.CompositeExtract %[[V]][1 : i32] : vector<4xf32>
-//       CHECK:   spv.CompositeInsert %[[S]], %[[V]][0 : i32] : f32 into vector<4xf32>
-func @extract_insert(%arg0 : vector<4xf32>) {
-  %0 = vector.extract %arg0[1] : vector<4xf32>
-  %1 = vector.insert %0, %arg0[0] : f32 into vector<4xf32>
-  spv.Return
+// CHECK-LABEL: @insert
+//  CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[S:.*]]: f32
+//       CHECK:   spv.CompositeInsert %[[S]], %[[V]][2 : i32] : f32 into vector<4xf32>
+func @insert(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
+  %1 = vector.insert %arg1, %arg0[2] : f32 into vector<4xf32>
+	return %1: vector<4xf32>
 }
 
 // -----
 
-// CHECK-LABEL: extract_element
+// CHECK-LABEL: @extract_element
 //  CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32
 //       CHECK:   spv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32
-func @extract_element(%arg0 : vector<4xf32>, %id : i32) {
+func @extract_element(%arg0 : vector<4xf32>, %id : i32) -> f32 {
   %0 = vector.extractelement %arg0[%id : i32] : vector<4xf32>
-  spv.ReturnValue %0: f32
+  return %0: f32
 }
 
 // -----
 
-func @extract_element_index(%arg0 : vector<4xf32>, %id : index) {
-// expected-error @+1 {{failed to legalize operation 'vector.extractelement'}}
+// CHECK-LABEL: @extract_element_index
+func @extract_element_index(%arg0 : vector<4xf32>, %id : index) -> f32 {
+	// CHECK: vector.extractelement
   %0 = vector.extractelement %arg0[%id : index] : vector<4xf32>
-  spv.ReturnValue %0: f32
+  return %0: f32
 }
 
 // -----
 
-func @extract_element_negative(%arg0 : vector<5xf32>, %id : i32) {
-// expected-error @+1 {{failed to legalize operation 'vector.extractelement'}}
+// CHECK-LABEL: @extract_element_size5_vector
+func @extract_element_size5_vector(%arg0 : vector<5xf32>, %id : i32) -> f32 {
+	// CHECK: vector.extractelement
   %0 = vector.extractelement %arg0[%id : i32] : vector<5xf32>
-  spv.ReturnValue %0: f32
+  return %0: f32
 }
 
 // -----
 
-// CHECK-LABEL: func @extract_strided_slice
+// CHECK-LABEL: @extract_strided_slice
 //  CHECK-SAME: %[[ARG:.+]]: vector<4xf32>
-//       CHECK: %{{.+}} = spv.VectorShuffle [1 : i32, 2 : i32] %[[ARG]] : vector<4xf32>, %[[ARG]] : vector<4xf32> -> vector<2xf32>
-//       CHECK: %{{.+}} = spv.CompositeExtract %[[ARG]][1 : i32] : vector<4xf32>
-func @extract_strided_slice(%arg0: vector<4xf32>) {
+//       CHECK:   spv.VectorShuffle [1 : i32, 2 : i32] %[[ARG]] : vector<4xf32>, %[[ARG]] : vector<4xf32> -> vector<2xf32>
+//       CHECK:   spv.CompositeExtract %[[ARG]][1 : i32] : vector<4xf32>
+func @extract_strided_slice(%arg0: vector<4xf32>) -> (vector<2xf32>, vector<1xf32>) {
   %0 = vector.extract_strided_slice %arg0 {offsets = [1], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
   %1 = vector.extract_strided_slice %arg0 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
-  spv.Return
+  return %0, %1 : vector<2xf32>, vector<1xf32>
 }
 
 // -----
 
-// CHECK-LABEL: insert_element
+// CHECK-LABEL: @insert_element
 //  CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32
 //       CHECK:   spv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32
-func @insert_element(%val: f32, %arg0 : vector<4xf32>, %id : i32) {
+func @insert_element(%val: f32, %arg0 : vector<4xf32>, %id : i32) -> vector<4xf32> {
   %0 = vector.insertelement %val, %arg0[%id : i32] : vector<4xf32>
-  spv.ReturnValue %0: vector<4xf32>
+  return %0: vector<4xf32>
 }
 
 // -----
 
-func @insert_element_index(%val: f32, %arg0 : vector<4xf32>, %id : index) {
-// expected-error @+1 {{failed to legalize operation 'vector.insertelement'}}
+// CHECK-LABEL: @insert_element_index
+func @insert_element_index(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vector<4xf32> {
+	// CHECK: vector.insertelement
   %0 = vector.insertelement %val, %arg0[%id : index] : vector<4xf32>
-  spv.ReturnValue %0: vector<4xf32>
+  return %0: vector<4xf32>
 }
 
 // -----
 
-func @insert_element_negative(%val: f32, %arg0 : vector<5xf32>, %id : i32) {
-// expected-error @+1 {{failed to legalize operation 'vector.insertelement'}}
+// CHECK-LABEL: @insert_element_size5_vector
+func @insert_element_size5_vector(%val: f32, %arg0 : vector<5xf32>, %id : i32) -> vector<5xf32> {
+	// CHECK: vector.insertelement
   %0 = vector.insertelement %val, %arg0[%id : i32] : vector<5xf32>
-  spv.Return
+	return %0 : vector<5xf32>
 }
 
 // -----
 
-// CHECK-LABEL: func @insert_strided_slice
+// CHECK-LABEL: @insert_strided_slice
 //  CHECK-SAME: %[[PART:.+]]: vector<2xf32>, %[[ALL:.+]]: vector<4xf32>
-//       CHECK: %{{.+}} = spv.VectorShuffle [0 : i32, 4 : i32, 5 : i32, 3 : i32] %[[ALL]] : vector<4xf32>, %[[PART]] : vector<2xf32> -> vector<4xf32>
-func @insert_strided_slice(%arg0: vector<2xf32>, %arg1: vector<4xf32>) {
+//       CHECK: 	spv.VectorShuffle [0 : i32, 4 : i32, 5 : i32, 3 : i32] %[[ALL]] : vector<4xf32>, %[[PART]] : vector<2xf32> -> vector<4xf32>
+func @insert_strided_slice(%arg0: vector<2xf32>, %arg1: vector<4xf32>) -> vector<4xf32> {
   %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [1], strides = [1]} : vector<2xf32> into vector<4xf32>
-  spv.Return
+	return %0 : vector<4xf32>
 }
 
 // -----
 
-// CHECK-LABEL: func @fma
+// CHECK-LABEL: @fma
 //  CHECK-SAME: %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: vector<4xf32>
 //       CHECK:   spv.GLSL.Fma %[[A]], %[[B]], %[[C]] : vector<4xf32>
-func @fma(%a: vector<4xf32>, %b: vector<4xf32>, %c: vector<4xf32>) {
+func @fma(%a: vector<4xf32>, %b: vector<4xf32>, %c: vector<4xf32>) -> vector<4xf32> {
   %0 = vector.fma %a, %b, %c: vector<4xf32>
-  spv.Return
+	return %0 : vector<4xf32>
 }


        


More information about the Mlir-commits mailing list