[Mlir-commits] [mlir] 56c0698 - [mlir][spirv] Fail vector.bitcast conversion with different bitwidth

Lei Zhang llvmlistbot at llvm.org
Thu Dec 29 15:44:14 PST 2022


Author: Lei Zhang
Date: 2022-12-29T15:43:55-08:00
New Revision: 56c069887b955bfa9f87f5fe8e87d219afbf4063

URL: https://github.com/llvm/llvm-project/commit/56c069887b955bfa9f87f5fe8e87d219afbf4063
DIFF: https://github.com/llvm/llvm-project/commit/56c069887b955bfa9f87f5fe8e87d219afbf4063.diff

LOG: [mlir][spirv] Fail vector.bitcast conversion with different bitwidth

Depending on the target environment, we may need to emulate certain
types, which can cause issues with bitcast.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D140437

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
    mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 11608db409f71..da505884554b2 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/FormatVariadic.h"
 #include <numeric>
 
 using namespace mlir;
@@ -33,6 +34,13 @@ static uint64_t getFirstIntValue(ArrayAttr attr) {
   return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
 }
 
+/// Returns the total (64-bit) number of bits for a scalar/vector type.
+static int64_t getNumBits(Type type) {
+  if (auto vectorType = type.dyn_cast<VectorType>())
+    return vectorType.cast<ShapedType>().getSizeInBits();
+  return type.getIntOrFloatBitWidth();
+}
+
 namespace {
 
 struct VectorBitcastConvert final
@@ -46,12 +54,24 @@ struct VectorBitcastConvert final
     if (!dstType)
       return failure();
 
-    if (dstType == adaptor.getSource().getType())
+    if (dstType == adaptor.getSource().getType()) {
       rewriter.replaceOp(bitcastOp, adaptor.getSource());
-    else
-      rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType,
-                                                    adaptor.getSource());
+      return success();
+    }
+
+    // Check that the source and destination type have the same bitwidth.
+    // Depending on the target environment, we may need to emulate certain
+    // types, which can cause issue with bitcast.
+    Type srcType = adaptor.getSource().getType();
+    if (getNumBits(dstType) != getNumBits(srcType)) {
+      return rewriter.notifyMatchFailure(
+          bitcastOp,
+          llvm::formatv("different source ({0}) and target ({1}) bitwidth",
+                        srcType, dstType));
+    }
 
+    rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType,
+                                                  adaptor.getSource());
     return success();
   }
 };

diff  --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 4bb1835e1d928..26a2ab6d62436 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -16,6 +16,23 @@ func.func @bitcast(%arg0 : vector<2xf32>, %arg1: vector<2xf16>) -> (vector<4xf16
 
 // -----
 
+// Check that without the proper capability we fail the pattern application
+// to avoid generating invalid ops.
+
+module attributes { spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>> } {
+// No capabilities enabled; f16 presumably gets emulated (TODO confirm), so both bitcasts must stay unconverted.
+// CHECK-LABEL: @bitcast
+func.func @bitcast(%arg0 : vector<2xf32>, %arg1: vector<2xf16>) -> (vector<4xf16>, vector<1xf32>) {
+  // CHECK-COUNT-2: vector.bitcast
+  %0 = vector.bitcast %arg0 : vector<2xf32> to vector<4xf16>
+  %1 = vector.bitcast %arg1 : vector<2xf16> to vector<1xf32>
+  return %0, %1: vector<4xf16>, vector<1xf32>
+}
+
+} // end module
+
+// -----
+
 module attributes { spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Kernel], []>, #spirv.resource_limits<>> } {
 
 // CHECK-LABEL: @cl_fma


        


More information about the Mlir-commits mailing list