[Mlir-commits] [mlir] 56c0698 - [mlir][spirv] Fail vector.bitcast conversion with different bitwidth
Lei Zhang
llvmlistbot at llvm.org
Thu Dec 29 15:44:14 PST 2022
Author: Lei Zhang
Date: 2022-12-29T15:43:55-08:00
New Revision: 56c069887b955bfa9f87f5fe8e87d219afbf4063
URL: https://github.com/llvm/llvm-project/commit/56c069887b955bfa9f87f5fe8e87d219afbf4063
DIFF: https://github.com/llvm/llvm-project/commit/56c069887b955bfa9f87f5fe8e87d219afbf4063.diff
LOG: [mlir][spirv] Fail vector.bitcast conversion with different bitwidth
Depending on the target environment, we may need to emulate certain
types, which can cause issues with bitcast.
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D140437
Added:
Modified:
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 11608db409f71..da505884554b2 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -23,6 +23,7 @@
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/FormatVariadic.h"
#include <numeric>
using namespace mlir;
@@ -33,6 +34,13 @@ static uint64_t getFirstIntValue(ArrayAttr attr) {
return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
}
+/// Returns the total number of bits of the given scalar/vector type (for vectors, the size of all elements combined).
+static int getNumBits(Type type) {
+ if (auto vectorType = type.dyn_cast<VectorType>())
+ return vectorType.cast<ShapedType>().getSizeInBits();
+ return type.getIntOrFloatBitWidth();
+}
+
namespace {
struct VectorBitcastConvert final
@@ -46,12 +54,24 @@ struct VectorBitcastConvert final
if (!dstType)
return failure();
- if (dstType == adaptor.getSource().getType())
+ if (dstType == adaptor.getSource().getType()) {
rewriter.replaceOp(bitcastOp, adaptor.getSource());
- else
- rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType,
- adaptor.getSource());
+ return success();
+ }
+
+ // Check that the source and destination type have the same bitwidth.
+ // Depending on the target environment, we may need to emulate certain
+ // types, which can cause issue with bitcast.
+ Type srcType = adaptor.getSource().getType();
+ if (getNumBits(dstType) != getNumBits(srcType)) {
+ return rewriter.notifyMatchFailure(
+ bitcastOp,
+ llvm::formatv("different source ({0}) and target ({1}) bitwidth",
+ srcType, dstType));
+ }
+ rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType,
+ adaptor.getSource());
return success();
}
};
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 4bb1835e1d928..26a2ab6d62436 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -16,6 +16,23 @@ func.func @bitcast(%arg0 : vector<2xf32>, %arg1: vector<2xf16>) -> (vector<4xf16
// -----
+// Check that without the proper capability we fail the pattern application
+// to avoid generating invalid ops.
+
+module attributes { spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>> } {
+
+// CHECK-LABEL: @bitcast
+func.func @bitcast(%arg0 : vector<2xf32>, %arg1: vector<2xf16>) -> (vector<4xf16>, vector<1xf32>) {
+ // CHECK-COUNT-2: vector.bitcast
+ %0 = vector.bitcast %arg0 : vector<2xf32> to vector<4xf16>
+ %1 = vector.bitcast %arg1 : vector<2xf16> to vector<1xf32>
+ return %0, %1: vector<4xf16>, vector<1xf32>
+}
+
+} // end module
+
+// -----
+
module attributes { spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Kernel], []>, #spirv.resource_limits<>> } {
// CHECK-LABEL: @cl_fma
More information about the Mlir-commits
mailing list