[Mlir-commits] [mlir] 18ad80e - [mlir][spirv] Improve integer cast during type conversion

Lei Zhang llvmlistbot at llvm.org
Wed Jul 12 14:38:18 PDT 2023


Author: Lei Zhang
Date: 2023-07-12T14:38:11-07:00
New Revision: 18ad80ea6f69879b7ada4089de259553917e5166

URL: https://github.com/llvm/llvm-project/commit/18ad80ea6f69879b7ada4089de259553917e5166
DIFF: https://github.com/llvm/llvm-project/commit/18ad80ea6f69879b7ada4089de259553917e5166.diff

LOG: [mlir][spirv] Improve integer cast during type conversion

In SPIR-V, the capabilities for storage and compute are separate.
We have good handling of the storage side in general via MemRef
type conversion and various `memref` dialect ops.

Once the value has been loaded properly, if the compute capability is
supported directly, we don't need to emulate it with int32 as we do on
the storage side. However, we do need to make sure casting ops are
properly inserted to convert the value back to the original
bitwidth.

Right now that is done in each individual pattern directly,
which puts a lot of pressure on the patterns that shouldn't be there
and causes duplication and trickiness w.r.t. capability checks and such.

Instead, we should handle such casting within the SPIR-V conversion
framework using `addSourceMaterialization`, where we can check
the target environment to make sure the corresponding compute
capability is allowed and then materialize a SPIR-V casting
op.

Along the way, we can drop all the duplicated cast materialization
registration in various places.

Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D155118

Added: 
    

Modified: 
    mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
    mlir/lib/Conversion/ComplexToSPIRV/ComplexToSPIRVPass.cpp
    mlir/lib/Conversion/MathToSPIRV/MathToSPIRVPass.cpp
    mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
    mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp
    mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
    mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
    mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 5d2c1f35432149..f74c7e3490cd80 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -1155,13 +1155,6 @@ struct ConvertArithToSPIRVPass
 
     // Use UnrealizedConversionCast as the bridge so that we don't need to pull
     // in patterns for other dialects.
-    auto addUnrealizedCast = [](OpBuilder &builder, Type type,
-                                ValueRange inputs, Location loc) {
-      auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
-      return std::optional<Value>(cast.getResult(0));
-    };
-    typeConverter.addSourceMaterialization(addUnrealizedCast);
-    typeConverter.addTargetMaterialization(addUnrealizedCast);
     target->addLegalOp<UnrealizedConversionCastOp>();
 
     // Fail hard when there are any remaining 'arith' ops.

diff  --git a/mlir/lib/Conversion/ComplexToSPIRV/ComplexToSPIRVPass.cpp b/mlir/lib/Conversion/ComplexToSPIRV/ComplexToSPIRVPass.cpp
index d57fa4a7c8d266..519a90e7f306a0 100644
--- a/mlir/lib/Conversion/ComplexToSPIRV/ComplexToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ComplexToSPIRV/ComplexToSPIRVPass.cpp
@@ -40,13 +40,6 @@ class ConvertComplexToSPIRVPass
 
     // Use UnrealizedConversionCast as the bridge so that we don't need to pull
     // in patterns for other dialects.
-    auto addUnrealizedCast = [](OpBuilder &builder, Type type,
-                                ValueRange inputs, Location loc) {
-      auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
-      return std::optional<Value>(cast.getResult(0));
-    };
-    typeConverter.addSourceMaterialization(addUnrealizedCast);
-    typeConverter.addTargetMaterialization(addUnrealizedCast);
     target->addLegalOp<UnrealizedConversionCastOp>();
 
     RewritePatternSet patterns(context);

diff  --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRVPass.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRVPass.cpp
index ad455c297e1c4a..fac9c7e3a4f7a6 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRVPass.cpp
@@ -44,13 +44,6 @@ void ConvertMathToSPIRVPass::runOnOperation() {
 
   // Use UnrealizedConversionCast as the bridge so that we don't need to pull
   // in patterns for other dialects.
-  auto addUnrealizedCast = [](OpBuilder &builder, Type type, ValueRange inputs,
-                              Location loc) {
-    auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
-    return std::optional<Value>(cast.getResult(0));
-  };
-  typeConverter.addSourceMaterialization(addUnrealizedCast);
-  typeConverter.addTargetMaterialization(addUnrealizedCast);
   target->addLegalOp<UnrealizedConversionCastOp>();
 
   RewritePatternSet patterns(context);

diff  --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index efd541b46d8fe2..28da42966e7337 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -490,14 +490,6 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   result = rewriter.create<spirv::ShiftRightArithmeticOp>(loc, dstType, result,
                                                           shiftValue);
 
-  if (isBool) {
-    dstType = typeConverter.convertType(loadOp.getType());
-    mask = spirv::ConstantOp::getOne(result.getType(), loc, rewriter);
-    result = rewriter.create<spirv::IEqualOp>(loc, result, mask);
-  } else if (result.getType().getIntOrFloatBitWidth() !=
-             static_cast<unsigned>(dstBits)) {
-    result = rewriter.create<spirv::SConvertOp>(loc, dstType, result);
-  }
   rewriter.replaceOp(loadOp, result);
 
   assert(accessChainOp.use_empty());

diff  --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp
index 9effeabc78d37a..e2ce927cc8fdfd 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp
@@ -45,13 +45,6 @@ void ConvertMemRefToSPIRVPass::runOnOperation() {
 
   // Use UnrealizedConversionCast as the bridge so that we don't need to pull in
   // patterns for other dialects.
-  auto addUnrealizedCast = [](OpBuilder &builder, Type type, ValueRange inputs,
-                              Location loc) {
-    auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
-    return std::optional<Value>(cast.getResult(0));
-  };
-  typeConverter.addSourceMaterialization(addUnrealizedCast);
-  typeConverter.addTargetMaterialization(addUnrealizedCast);
   target->addLegalOp<UnrealizedConversionCastOp>();
 
   RewritePatternSet patterns(context);

diff  --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
index 57646c86b66203..1932de1be603b6 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
@@ -44,13 +44,6 @@ void ConvertVectorToSPIRVPass::runOnOperation() {
 
   // Use UnrealizedConversionCast as the bridge so that we don't need to pull in
   // patterns for other dialects.
-  auto addUnrealizedCast = [](OpBuilder &builder, Type type, ValueRange inputs,
-                              Location loc) {
-    auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
-    return std::optional<Value>(cast.getResult(0));
-  };
-  typeConverter.addSourceMaterialization(addUnrealizedCast);
-  typeConverter.addTargetMaterialization(addUnrealizedCast);
   target->addLegalOp<UnrealizedConversionCastOp>();
 
   RewritePatternSet patterns(context);

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 9fe2f8b35d7a4f..c8d7aef8964201 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -565,6 +565,84 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
   return wrapInStructAndGetPointer(arrayType, storageClass);
 }
 
+//===----------------------------------------------------------------------===//
+// Type casting materialization
+//===----------------------------------------------------------------------===//
+
+/// Converts the given `inputs` to the original source `type` considering the
+/// `targetEnv`'s capabilities.
+///
+/// This function is meant to be used for source materialization in type
+/// converters. When the type converter needs to materialize a cast op back
+/// to some original source type, we need to check whether the original source
+/// type is supported in the target environment. If so, we can insert legal
+/// SPIR-V cast ops accordingly.
+///
+/// Only a single integer input being truncated to a valid SPIR-V integer type
+/// gets a real SPIR-V op: `spirv.IEqual` against one for `i1`, otherwise
+/// `spirv.SConvert`/`spirv.UConvert` when the target env supports `type`.
+/// Everything else falls back to an `UnrealizedConversionCastOp` bridge.
+///
+/// Note that in SPIR-V the capabilities for storage and compute are separate.
+/// This function is meant to handle the **compute** side; so it does not
+/// involve storage classes in its logic. The storage side is expected to be
+/// handled by MemRef conversion logic.
+static std::optional<Value> castToSourceType(const spirv::TargetEnv &targetEnv,
+                                             OpBuilder &builder, Type type,
+                                             ValueRange inputs, Location loc) {
+  // Fallback: bridge through an `UnrealizedConversionCastOp` whenever no
+  // legal SPIR-V cast op can be emitted.
+  auto bridgeCast = [&]() -> std::optional<Value> {
+    auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
+    return castOp.getResult(0);
+  };
+
+  // We can only cast one value in SPIR-V.
+  if (inputs.size() != 1)
+    return bridgeCast();
+  Value input = inputs.front();
+
+  // Only support integer types for now. Floating point types to be implemented.
+  // dyn_cast the input too so that a non-integer input bridges, not asserts.
+  auto inputType = dyn_cast<IntegerType>(input.getType());
+  auto scalarType = dyn_cast<spirv::ScalarType>(type);
+  if (!isa<IntegerType>(type) || !inputType || !scalarType)
+    return bridgeCast();
+
+  // Only support source types with a strictly smaller bitwidth. This means we
+  // are truncating to go back so we don't need to worry about the signedness.
+  // For extension, we cannot have enough signal here to decide which op to use.
+  // Equal bitwidths are rejected as well: SConvert/UConvert require differing
+  // widths, and IEqual cannot compare two i1 values.
+  if (inputType.getIntOrFloatBitWidth() <= scalarType.getIntOrFloatBitWidth())
+    return bridgeCast();
+
+  // Boolean values would need to use different ops than normal integer values.
+  if (type.isInteger(1)) {
+    Value one = spirv::ConstantOp::getOne(inputType, loc, builder);
+    return builder.create<spirv::IEqualOp>(loc, input, one);
+  }
+
+  // Check that the source integer type is supported by the environment.
+  SmallVector<ArrayRef<spirv::Extension>, 1> exts;
+  SmallVector<ArrayRef<spirv::Capability>, 2> caps;
+  scalarType.getExtensions(exts);
+  scalarType.getCapabilities(caps);
+  if (failed(checkCapabilityRequirements(type, targetEnv, caps)) ||
+      failed(checkExtensionRequirements(type, targetEnv, exts)))
+    return bridgeCast();
+
+  // We've already made sure this is truncating, so signedness does not matter
+  // here. Still use the corresponding op for better consistency though.
+  if (type.isSignedInteger())
+    return builder.create<spirv::SConvertOp>(loc, type, input);
+  return builder.create<spirv::UConvertOp>(loc, type, input);
+}
+
+//===----------------------------------------------------------------------===//
+// SPIRVTypeConverter
+//===----------------------------------------------------------------------===//
+
 SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
                                        const SPIRVConversionOptions &options)
     : targetEnv(targetAttr), options(options) {
@@ -611,6 +689,17 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
   addConversion([this](MemRefType memRefType) {
     return convertMemrefType(this->targetEnv, this->options, memRefType);
   });
+
+  // Register some last line of defense casting logic.
+  addSourceMaterialization(
+      [this](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
+        return castToSourceType(this->targetEnv, builder, type, inputs, loc);
+      });
+  addTargetMaterialization([](OpBuilder &builder, Type type, ValueRange inputs,
+                              Location loc) {
+    auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
+    return std::optional<Value>(cast.getResult(0));
+  });
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index 97d1a3add9c186..ef77dc9e75933e 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -297,7 +297,8 @@ func.func @load_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>) -> i8
   //     CHECK: %[[T1:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
   //     CHECK: %[[T2:.+]] = spirv.Constant 24 : i32
   //     CHECK: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
-  //     CHECK: spirv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
+  //     CHECK: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
+  //     CHECK: builtin.unrealized_conversion_cast %[[SR]]
   %0 = memref.load %arg0[] : memref<i8, #spirv.storage_class<StorageBuffer>>
   return %0 : i8
 }
@@ -321,7 +322,8 @@ func.func @load_i16(%arg0: memref<10xi16, #spirv.storage_class<StorageBuffer>>,
   //     CHECK: %[[MASK:.+]] = spirv.Constant 65535 : i32
   //     CHECK: %[[T1:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
   //     CHECK: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[SIXTEEN]] : i32, i32
-  //     CHECK: spirv.ShiftRightArithmetic %[[T3]], %[[SIXTEEN]] : i32, i32
+  //     CHECK: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[SIXTEEN]] : i32, i32
+  //     CHECK: builtin.unrealized_conversion_cast %[[SR]]
   %0 = memref.load %arg0[%index] : memref<10xi16, #spirv.storage_class<StorageBuffer>>
   return %0: i16
 }
@@ -448,7 +450,8 @@ func.func @load_i4(%arg0: memref<?xi4, #spirv.storage_class<StorageBuffer>>, %i:
   // CHECK: %[[AND:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
   // CHECK: %[[C28:.+]] = spirv.Constant 28 : i32
   // CHECK: %[[SL:.+]] = spirv.ShiftLeftLogical %[[AND]], %[[C28]] : i32, i32
-  // CHECK: spirv.ShiftRightArithmetic %[[SL]], %[[C28]] : i32, i32
+  // CHECK: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[SL]], %[[C28]] : i32, i32
+  // CHECK: builtin.unrealized_conversion_cast %[[SR]]
   %0 = memref.load %arg0[%i] : memref<?xi4, #spirv.storage_class<StorageBuffer>>
   return %0 : i4
 }
@@ -479,3 +482,41 @@ func.func @store_i4(%arg0: memref<?xi4, #spirv.storage_class<StorageBuffer>>, %v
 }
 
 } // end module
+
+// -----
+
+// Check that SPIR-V casts (IEqual for i1, UConvert for i8/i16) are inserted
+// when the corresponding **compute** capability (Int8/Int16) is allowed.
+module attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0, [Shader, Int8, Int16], [
+      SPV_KHR_8bit_storage, SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class
+      ]>, #spirv.resource_limits<>>
+} {
+
+// CHECK-LABEL: @load_i1
+func.func @load_i1(%arg0: memref<i1, #spirv.storage_class<StorageBuffer>>) -> i1 {
+  //     CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
+  //     CHECK: %[[RES:.+]]  = spirv.IEqual %{{.+}}, %[[ONE]] : i32
+  //     CHECK: return %[[RES]]
+  %0 = memref.load %arg0[] : memref<i1, #spirv.storage_class<StorageBuffer>>
+  return %0 : i1
+}
+
+// CHECK-LABEL: @load_i8
+func.func @load_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>) -> i8 {
+  //     CHECK: %[[RES:.+]] = spirv.UConvert %{{.+}} : i32 to i8
+  //     CHECK: return %[[RES]]
+  %0 = memref.load %arg0[] : memref<i8, #spirv.storage_class<StorageBuffer>>
+  return %0 : i8
+}
+
+// CHECK-LABEL: @load_i16
+func.func @load_i16(%arg0: memref<10xi16, #spirv.storage_class<StorageBuffer>>, %index : index) -> i16 {
+  //     CHECK: %[[RES:.+]] = spirv.UConvert %{{.+}} : i32 to i16
+  //     CHECK: return %[[RES]]
+  %0 = memref.load %arg0[%index] : memref<10xi16, #spirv.storage_class<StorageBuffer>>
+  return %0: i16
+}
+
+} // end module


        


More information about the Mlir-commits mailing list