[Mlir-commits] [mlir] 91de20c - [mlir][spirv] Use UnrealizedConversionCast in ArithmeticToSPIRV

Lei Zhang llvmlistbot at llvm.org
Mon Jun 13 10:14:04 PDT 2022


Author: Lei Zhang
Date: 2022-06-13T13:13:57-04:00
New Revision: 91de20c36d585eed3abc82f4f15907c6dcd2067c

URL: https://github.com/llvm/llvm-project/commit/91de20c36d585eed3abc82f4f15907c6dcd2067c
DIFF: https://github.com/llvm/llvm-project/commit/91de20c36d585eed3abc82f4f15907c6dcd2067c.diff

LOG: [mlir][spirv] Use UnrealizedConversionCast in ArithmeticToSPIRV

This avoids pulling in function conversion patterns, which are not
part of what we want to test in ArithmeticToSPIRV. It also allows
using ConvertArithmeticToSPIRVPass as a standalone step.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D127573

Added: 
    

Modified: 
    mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
    mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
index 0170edac742da..c4b6382a42bac 100644
--- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
@@ -878,10 +878,19 @@ struct ConvertArithmeticToSPIRVPass
     options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
     SPIRVTypeConverter typeConverter(targetAttr, options);
 
+    // Use UnrealizedConversionCast as the bridge so that we don't need to pull
+    // in patterns for other dialects.
+    auto addUnrealizedCast = [](OpBuilder &builder, Type type,
+                                ValueRange inputs, Location loc) {
+      auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
+      return Optional<Value>(cast.getResult(0));
+    };
+    typeConverter.addSourceMaterialization(addUnrealizedCast);
+    typeConverter.addTargetMaterialization(addUnrealizedCast);
+    target->addLegalOp<UnrealizedConversionCastOp>();
+
     RewritePatternSet patterns(&getContext());
     arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns);
-    populateFuncToSPIRVPatterns(typeConverter, patterns);
-    populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
 
     if (failed(applyPartialConversion(module, *target, std::move(patterns))))
       signalPassFailure();

diff  --git a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
index f5793142ed252..7d17359030d46 100644
--- a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
@@ -58,8 +58,10 @@ func.func @index_scalar(%lhs: index, %rhs: index) {
 }
 
 // CHECK-LABEL: @index_scalar_srem
-// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32)
+// CHECK-SAME: (%[[A:.+]]: index, %[[B:.+]]: index)
 func.func @index_scalar_srem(%lhs: index, %rhs: index) {
+  // CHECK: %[[LHS:.+]] = builtin.unrealized_conversion_cast %[[A]] : index to i32
+  // CHECK: %[[RHS:.+]] = builtin.unrealized_conversion_cast %[[B]] : index to i32
   // CHECK: %[[LABS:.+]] = spv.GLSL.SAbs %[[LHS]] : i32
   // CHECK: %[[RABS:.+]] = spv.GLSL.SAbs %[[RHS]] : i32
   // CHECK:  %[[ABS:.+]] = spv.UMod %[[LABS]], %[[RABS]] : i32
@@ -185,10 +187,8 @@ module attributes {
   spv.target_env = #spv.target_env<#spv.vce<v1.0, [], []>, #spv.resource_limits<>>
 } {
 
-// expected-error @+1 {{failed to materialize conversion for block argument #0 that remained live after conversion, type was 'vector<4xi64>', with target type 'vector<4xi32>'}}
 func.func @int_vector4_invalid(%arg0: vector<4xi64>) {
-  // expected-error @+2 {{bitwidth emulation is not implemented yet on unsigned op}}
-  // expected-note @+1 {{see existing live user here}}
+  // expected-error @+1 {{bitwidth emulation is not implemented yet on unsigned op}}
   %0 = arith.divui %arg0, %arg0: vector<4xi64>
   return
 }
@@ -837,8 +837,9 @@ module attributes {
 } {
 
 // CHECK-LABEL: @fpext1
-// CHECK-SAME: %[[ARG:.*]]: f32
+// CHECK-SAME: %[[A:.*]]: f16
 func.func @fpext1(%arg0: f16) -> f64 {
+  // CHECK: %[[ARG:.+]] = builtin.unrealized_conversion_cast %[[A]] : f16 to f32
   // CHECK-NEXT: spv.FConvert %[[ARG]] : f32 to f64
   %0 = arith.extf %arg0 : f16 to f64
   return %0: f64
@@ -863,8 +864,9 @@ module attributes {
 } {
 
 // CHECK-LABEL: @fptrunc1
-// CHECK-SAME: %[[ARG:.*]]: f32
+// CHECK-SAME: %[[A:.*]]: f64
 func.func @fptrunc1(%arg0 : f64) -> f16 {
+  // CHECK: %[[ARG:.+]] = builtin.unrealized_conversion_cast %[[A]] : f64 to f32
   // CHECK-NEXT: spv.FConvert %[[ARG]] : f32 to f16
   %0 = arith.truncf %arg0 : f64 to f16
   return %0: f16
@@ -1110,10 +1112,8 @@ module attributes {
   spv.target_env = #spv.target_env<#spv.vce<v1.0, [], []>, #spv.resource_limits<>>
 } {
 
-// expected-error at below {{failed to materialize conversion for block argument #0 that remained live after conversion}}
 func.func @int_vector4_invalid(%arg0: vector<4xi64>) {
-  // expected-error at below {{bitwidth emulation is not implemented yet on unsigned op}}
-  // expected-note at below {{see existing live user here}}
+  // expected-error at +1 {{bitwidth emulation is not implemented yet on unsigned op}}
   %0 = arith.divui %arg0, %arg0: vector<4xi64>
   return
 }
@@ -1733,8 +1733,9 @@ module attributes {
 } {
 
 // CHECK-LABEL: @fpext1
-// CHECK-SAME: %[[ARG:.*]]: f32
+// CHECK-SAME: %[[A:.*]]: f16
 func.func @fpext1(%arg0: f16) -> f64 {
+  // CHECK: %[[ARG:.+]] = builtin.unrealized_conversion_cast %[[A]] : f16 to f32
   // CHECK-NEXT: spv.FConvert %[[ARG]] : f32 to f64
   %0 = arith.extf %arg0 : f16 to f64
   return %0: f64
@@ -1759,8 +1760,9 @@ module attributes {
 } {
 
 // CHECK-LABEL: @fptrunc1
-// CHECK-SAME: %[[ARG:.*]]: f32
+// CHECK-SAME: %[[A:.*]]: f64
 func.func @fptrunc1(%arg0 : f64) -> f16 {
+  // CHECK: %[[ARG:.+]] = builtin.unrealized_conversion_cast %[[A]] : f64 to f32
   // CHECK-NEXT: spv.FConvert %[[ARG]] : f32 to f16
   %0 = arith.truncf %arg0 : f64 to f16
   return %0: f16


        


More information about the Mlir-commits mailing list