[Mlir-commits] [mlir] 72e8b28 - [mlir][spirv] Allow vectors of index types in elementwise conversions

Quinn Dawkins llvmlistbot at llvm.org
Fri Mar 17 10:37:08 PDT 2023


Author: Quinn Dawkins
Date: 2023-03-17T13:33:56-04:00
New Revision: 72e8b286f03c7f6bacbec10ca9883f77d482284c

URL: https://github.com/llvm/llvm-project/commit/72e8b286f03c7f6bacbec10ca9883f77d482284c
DIFF: https://github.com/llvm/llvm-project/commit/72e8b286f03c7f6bacbec10ca9883f77d482284c.diff

LOG: [mlir][spirv] Allow vectors of index types in elementwise conversions

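Currently, the conversion of elementwise ops checks only for scalar index
types when deciding whether bitwidth emulation is needed, so unsigned ops
on vectors of index types are rejected with a "bitwidth emulation is not
implemented yet" error. Check the element type instead, so that vectors of
index types are handled the same way as scalar index types.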

Differential Revision: https://reviews.llvm.org/D146307

Added: 
    

Modified: 
    mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
    mlir/lib/Conversion/SPIRVCommon/Pattern.h
    mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index be46a06ba53ab..b6ed24490bb0e 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -845,7 +845,8 @@ CmpIOpPattern::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
 #define DISPATCH(cmpPredicate, spirvOp)                                        \
   case cmpPredicate:                                                           \
     if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&            \
-        srcType != dstType && !hasSameBitwidth(srcType, dstType)) {            \
+        !getElementTypeOrSelf(srcType).isIndex() && srcType != dstType &&      \
+        !hasSameBitwidth(srcType, dstType)) {                                  \
       return op.emitError(                                                     \
           "bitwidth emulation is not implemented yet on unsigned op");         \
     }                                                                          \

diff  --git a/mlir/lib/Conversion/SPIRVCommon/Pattern.h b/mlir/lib/Conversion/SPIRVCommon/Pattern.h
index 4da3e197ca3ed..7425f4b5311ce 100644
--- a/mlir/lib/Conversion/SPIRVCommon/Pattern.h
+++ b/mlir/lib/Conversion/SPIRVCommon/Pattern.h
@@ -10,6 +10,7 @@
 #define MLIR_CONVERSION_SPIRVCOMMON_PATTERN_H
 
 #include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/Support/FormatVariadic.h"
 
@@ -34,9 +35,11 @@ struct ElementwiseOpPattern : public OpConversionPattern<Op> {
     }
 
     if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
-        !op.getType().isIndex() && dstType != op.getType()) {
-      return op.emitError(
-          "bitwidth emulation is not implemented yet on unsigned op");
+        !getElementTypeOrSelf(op.getType()).isIndex() &&
+        dstType != op.getType()) {
+      op.dump();
+      return op.emitError("bitwidth emulation is not implemented yet on "
+                          "unsigned op pattern version");
     }
     rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType,
                                                   adaptor.getOperands());

diff  --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 4c32a8648bd1c..d70df982c366a 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -179,6 +179,13 @@ func.func @int_vector234(%arg0: vector<2xi8>, %arg1: vector<4xi64>) {
   return
 }
 
+// CHECK-LABEL: @index_vector
+func.func @index_vector(%arg0: vector<4xindex>) {
+  // CHECK: spirv.UMod %{{.*}}, %{{.*}}: vector<4xi32>
+  %0 = arith.remui %arg0, %arg0: vector<4xindex>
+  return
+}
+
 // CHECK-LABEL: @vector_srem
 // CHECK-SAME: (%[[LHS:.+]]: vector<3xi16>, %[[RHS:.+]]: vector<3xi16>)
 func.func @vector_srem(%arg0: vector<3xi16>, %arg1: vector<3xi16>) {
@@ -417,6 +424,31 @@ func.func @cmpi(%arg0 : i32, %arg1 : i32) {
   return
 }
 
+// CHECK-LABEL: @indexcmpi
+func.func @indexcmpi(%arg0 : index, %arg1 : index) {
+  // CHECK: spirv.IEqual
+  %0 = arith.cmpi eq, %arg0, %arg1 : index
+  // CHECK: spirv.INotEqual
+  %1 = arith.cmpi ne, %arg0, %arg1 : index
+  // CHECK: spirv.SLessThan
+  %2 = arith.cmpi slt, %arg0, %arg1 : index
+  // CHECK: spirv.SLessThanEqual
+  %3 = arith.cmpi sle, %arg0, %arg1 : index
+  // CHECK: spirv.SGreaterThan
+  %4 = arith.cmpi sgt, %arg0, %arg1 : index
+  // CHECK: spirv.SGreaterThanEqual
+  %5 = arith.cmpi sge, %arg0, %arg1 : index
+  // CHECK: spirv.ULessThan
+  %6 = arith.cmpi ult, %arg0, %arg1 : index
+  // CHECK: spirv.ULessThanEqual
+  %7 = arith.cmpi ule, %arg0, %arg1 : index
+  // CHECK: spirv.UGreaterThan
+  %8 = arith.cmpi ugt, %arg0, %arg1 : index
+  // CHECK: spirv.UGreaterThanEqual
+  %9 = arith.cmpi uge, %arg0, %arg1 : index
+  return
+}
+
 // CHECK-LABEL: @vec1cmpi
 func.func @vec1cmpi(%arg0 : vector<1xi32>, %arg1 : vector<1xi32>) {
   // CHECK: spirv.ULessThan


        


More information about the Mlir-commits mailing list