[Mlir-commits] [mlir] b5192cb - [mlir][spirv] Fix result type for arith.cmpi/cmpf conversion
Lei Zhang
llvmlistbot at llvm.org
Mon Jun 13 10:16:09 PDT 2022
Author: Lei Zhang
Date: 2022-06-13T13:15:23-04:00
New Revision: b5192cbe506c13c13fc6a3cefda3a8ecef72bd40
URL: https://github.com/llvm/llvm-project/commit/b5192cbe506c13c13fc6a3cefda3a8ecef72bd40
DIFF: https://github.com/llvm/llvm-project/commit/b5192cbe506c13c13fc6a3cefda3a8ecef72bd40.diff
LOG: [mlir][spirv] Fix result type for arith.cmpi/cmpf conversion
We cannot directly use the original result type; instead we need
to deduce it from the converted operand type. This addresses
invalid ops generated from converting single element vectors.
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D127574
Added:
Modified:
mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
index c4b6382a42ba..31d9023e4346 100644
--- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
@@ -275,6 +275,22 @@ static bool isBoolScalarOrVector(Type type) {
return false;
}
+/// Returns true if scalar/vector types `a` and `b` have the same nonzero total
+/// bitwidth (for a vector: element bitwidth times number of elements).
+static bool hasSameBitwidth(Type a, Type b) {
+ auto getNumBitwidth = [](Type type) {
+ unsigned bw = 0;
+ if (type.isIntOrFloat())
+ bw = type.getIntOrFloatBitWidth();
+ else if (auto vecType = type.dyn_cast<VectorType>())
+ bw = vecType.getElementTypeBitWidth() * vecType.getNumElements();
+ return bw;
+ };
+ unsigned aBW = getNumBitwidth(a);
+ unsigned bBW = getNumBitwidth(b);
+ return aBW != 0 && bBW != 0 && aBW == bBW;
+}
+
//===----------------------------------------------------------------------===//
// ConstantOp with composite type
//===----------------------------------------------------------------------===//
@@ -655,10 +671,11 @@ LogicalResult CmpIOpBooleanPattern::matchAndRewrite(
switch (op.getPredicate()) {
#define DISPATCH(cmpPredicate, spirvOp) \
- case cmpPredicate: \
- rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(), \
- adaptor.getLhs(), adaptor.getRhs()); \
- return success();
+ case cmpPredicate: { \
+ rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
+ adaptor.getRhs()); \
+ return success(); \
+ }
DISPATCH(arith::CmpIPredicate::eq, spirv::LogicalEqualOp);
DISPATCH(arith::CmpIPredicate::ne, spirv::LogicalNotEqualOp);
@@ -676,20 +693,23 @@ LogicalResult CmpIOpBooleanPattern::matchAndRewrite(
LogicalResult
CmpIOpPattern::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
- Type operandType = op.getLhs().getType();
- if (isBoolScalarOrVector(operandType))
+ Type srcType = op.getLhs().getType();
+ if (isBoolScalarOrVector(srcType))
+ return failure();
+ Type dstType = getTypeConverter()->convertType(srcType);
+ if (!dstType)
return failure();
switch (op.getPredicate()) {
#define DISPATCH(cmpPredicate, spirvOp) \
case cmpPredicate: \
if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \
- operandType != this->getTypeConverter()->convertType(operandType)) { \
+ srcType != dstType && !hasSameBitwidth(srcType, dstType)) { \
return op.emitError( \
"bitwidth emulation is not implemented yet on unsigned op"); \
} \
- rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(), \
- adaptor.getLhs(), adaptor.getRhs()); \
+ rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
+ adaptor.getRhs()); \
return success();
DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
@@ -718,8 +738,8 @@ CmpFOpPattern::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
switch (op.getPredicate()) {
#define DISPATCH(cmpPredicate, spirvOp) \
case cmpPredicate: \
- rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(), \
- adaptor.getLhs(), adaptor.getRhs()); \
+ rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
+ adaptor.getRhs()); \
return success();
// Ordered.
diff --git a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
index 7d17359030d4..892519751287 100644
--- a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
@@ -392,6 +392,15 @@ func.func @cmpi(%arg0 : i32, %arg1 : i32) {
return
}
+// CHECK-LABEL: @vec1cmpi
+func.func @vec1cmpi(%arg0 : vector<1xi32>, %arg1 : vector<1xi32>) {
+ // CHECK: spv.ULessThan
+ %0 = arith.cmpi ult, %arg0, %arg1 : vector<1xi32>
+ // CHECK: spv.SGreaterThan
+ %1 = arith.cmpi sgt, %arg0, %arg1 : vector<1xi32>
+ return
+}
+
// CHECK-LABEL: @boolcmpi
func.func @boolcmpi(%arg0 : i1, %arg1 : i1) {
// CHECK: spv.LogicalEqual
@@ -401,6 +410,15 @@ func.func @boolcmpi(%arg0 : i1, %arg1 : i1) {
return
}
+// CHECK-LABEL: @vec1boolcmpi
+func.func @vec1boolcmpi(%arg0 : vector<1xi1>, %arg1 : vector<1xi1>) {
+ // CHECK: spv.LogicalEqual
+ %0 = arith.cmpi eq, %arg0, %arg1 : vector<1xi1>
+ // CHECK: spv.LogicalNotEqual
+ %1 = arith.cmpi ne, %arg0, %arg1 : vector<1xi1>
+ return
+}
+
// CHECK-LABEL: @vecboolcmpi
func.func @vecboolcmpi(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) {
// CHECK: spv.LogicalEqual
@@ -1237,6 +1255,15 @@ func.func @cmpf(%arg0 : f32, %arg1 : f32) {
return
}
+// CHECK-LABEL: @vec1cmpf
+func.func @vec1cmpf(%arg0 : vector<1xf32>, %arg1 : vector<1xf32>) {
+ // CHECK: spv.FOrdGreaterThan
+ %0 = arith.cmpf ogt, %arg0, %arg1 : vector<1xf32>
+ // CHECK: spv.FUnordLessThan
+ %1 = arith.cmpf ult, %arg0, %arg1 : vector<1xf32>
+ return
+}
+
} // end module
// -----
More information about the Mlir-commits mailing list