[Mlir-commits] [mlir] b5192cb - [mlir][spirv] Fix result type for arith.cmpi/cmpf conversion

Lei Zhang llvmlistbot at llvm.org
Mon Jun 13 10:16:09 PDT 2022


Author: Lei Zhang
Date: 2022-06-13T13:15:23-04:00
New Revision: b5192cbe506c13c13fc6a3cefda3a8ecef72bd40

URL: https://github.com/llvm/llvm-project/commit/b5192cbe506c13c13fc6a3cefda3a8ecef72bd40
DIFF: https://github.com/llvm/llvm-project/commit/b5192cbe506c13c13fc6a3cefda3a8ecef72bd40.diff

LOG: [mlir][spirv] Fix result type for arith.cmpi/cmpf conversion

We cannot directly use the original result type; instead we need
to deduce it from the converted operand type. This addresses
invalid ops generated from converting single-element vectors.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D127574

Added: 
    

Modified: 
    mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
    mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
index c4b6382a42ba..31d9023e4346 100644
--- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
@@ -275,6 +275,22 @@ static bool isBoolScalarOrVector(Type type) {
   return false;
 }
 
+/// Returns true if scalar/vector types `a` and `b` have the same total
+/// bitwidth. Returns false if either is not an int/float scalar or a vector.
+static bool hasSameBitwidth(Type a, Type b) {
+  auto getNumBitwidth = [](Type type) {
+    unsigned bw = 0;
+    if (type.isIntOrFloat())
+      bw = type.getIntOrFloatBitWidth();
+    else if (auto vecType = type.dyn_cast<VectorType>())
+      bw = vecType.getElementTypeBitWidth() * vecType.getNumElements();
+    return bw;
+  };
+  unsigned aBW = getNumBitwidth(a);
+  unsigned bBW = getNumBitwidth(b);
+  return aBW != 0 && bBW != 0 && aBW == bBW;
+}
+
 //===----------------------------------------------------------------------===//
 // ConstantOp with composite type
 //===----------------------------------------------------------------------===//
@@ -655,10 +671,11 @@ LogicalResult CmpIOpBooleanPattern::matchAndRewrite(
 
   switch (op.getPredicate()) {
 #define DISPATCH(cmpPredicate, spirvOp)                                        \
-  case cmpPredicate:                                                           \
-    rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(),         \
-                                         adaptor.getLhs(), adaptor.getRhs());  \
-    return success();
+  case cmpPredicate: {                                                         \
+    rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(),                 \
+                                         adaptor.getRhs());                    \
+    return success();                                                          \
+  }
 
     DISPATCH(arith::CmpIPredicate::eq, spirv::LogicalEqualOp);
     DISPATCH(arith::CmpIPredicate::ne, spirv::LogicalNotEqualOp);
@@ -676,20 +693,23 @@ LogicalResult CmpIOpBooleanPattern::matchAndRewrite(
 LogicalResult
 CmpIOpPattern::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
-  Type operandType = op.getLhs().getType();
-  if (isBoolScalarOrVector(operandType))
+  Type srcType = op.getLhs().getType();
+  if (isBoolScalarOrVector(srcType))
+    return failure();
+  Type dstType = getTypeConverter()->convertType(srcType);
+  if (!dstType)
     return failure();
 
   switch (op.getPredicate()) {
 #define DISPATCH(cmpPredicate, spirvOp)                                        \
   case cmpPredicate:                                                           \
     if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&            \
-        operandType != this->getTypeConverter()->convertType(operandType)) {   \
+        srcType != dstType && !hasSameBitwidth(srcType, dstType)) {            \
       return op.emitError(                                                     \
           "bitwidth emulation is not implemented yet on unsigned op");         \
     }                                                                          \
-    rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(),         \
-                                         adaptor.getLhs(), adaptor.getRhs());  \
+    rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(),                 \
+                                         adaptor.getRhs());                    \
     return success();
 
     DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
@@ -718,8 +738,8 @@ CmpFOpPattern::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
   switch (op.getPredicate()) {
 #define DISPATCH(cmpPredicate, spirvOp)                                        \
   case cmpPredicate:                                                           \
-    rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(),         \
-                                         adaptor.getLhs(), adaptor.getRhs());  \
+    rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(),                 \
+                                         adaptor.getRhs());                    \
     return success();
 
     // Ordered.

diff  --git a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
index 7d17359030d4..892519751287 100644
--- a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
@@ -392,6 +392,15 @@ func.func @cmpi(%arg0 : i32, %arg1 : i32) {
   return
 }
 
+// CHECK-LABEL: @vec1cmpi
+func.func @vec1cmpi(%arg0 : vector<1xi32>, %arg1 : vector<1xi32>) {
+  // CHECK: spv.ULessThan
+  %0 = arith.cmpi ult, %arg0, %arg1 : vector<1xi32>
+  // CHECK: spv.SGreaterThan
+  %1 = arith.cmpi sgt, %arg0, %arg1 : vector<1xi32>
+  return
+}
+
 // CHECK-LABEL: @boolcmpi
 func.func @boolcmpi(%arg0 : i1, %arg1 : i1) {
   // CHECK: spv.LogicalEqual
@@ -401,6 +410,15 @@ func.func @boolcmpi(%arg0 : i1, %arg1 : i1) {
   return
 }
 
+// CHECK-LABEL: @vec1boolcmpi
+func.func @vec1boolcmpi(%arg0 : vector<1xi1>, %arg1 : vector<1xi1>) {
+  // CHECK: spv.LogicalEqual
+  %0 = arith.cmpi eq, %arg0, %arg1 : vector<1xi1>
+  // CHECK: spv.LogicalNotEqual
+  %1 = arith.cmpi ne, %arg0, %arg1 : vector<1xi1>
+  return
+}
+
 // CHECK-LABEL: @vecboolcmpi
 func.func @vecboolcmpi(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) {
   // CHECK: spv.LogicalEqual
@@ -1237,6 +1255,15 @@ func.func @cmpf(%arg0 : f32, %arg1 : f32) {
   return
 }
 
+// CHECK-LABEL: @vec1cmpf
+func.func @vec1cmpf(%arg0 : vector<1xf32>, %arg1 : vector<1xf32>) {
+  // CHECK: spv.FOrdGreaterThan
+  %0 = arith.cmpf ogt, %arg0, %arg1 : vector<1xf32>
+  // CHECK: spv.FUnordLessThan
+  %1 = arith.cmpf ult, %arg0, %arg1 : vector<1xf32>
+  return
+}
+
 } // end module
 
 // -----


        


More information about the Mlir-commits mailing list