[Mlir-commits] [mlir] 72e8b28 - [mlir][spirv] Allow vectors of index types in elementwise conversions
Quinn Dawkins
llvmlistbot at llvm.org
Fri Mar 17 10:37:08 PDT 2023
Author: Quinn Dawkins
Date: 2023-03-17T13:33:56-04:00
New Revision: 72e8b286f03c7f6bacbec10ca9883f77d482284c
URL: https://github.com/llvm/llvm-project/commit/72e8b286f03c7f6bacbec10ca9883f77d482284c
DIFF: https://github.com/llvm/llvm-project/commit/72e8b286f03c7f6bacbec10ca9883f77d482284c.diff
LOG: [mlir][spirv] Allow vectors of index types in elementwise conversions
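Currently the conversion of elementwise ops only checks for scalar index
types when checking for bitwidth emulation, so unsigned elementwise ops on
vectors of index (e.g. `arith.remui` on `vector<4xindex>`) are rejected.
This change checks the element type instead, so vectors of index are
accepted too. The `arith.cmpi` conversion gets the same element-type index
check, so unsigned comparisons on index operands are no longer rejected.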
Differential Revision: https://reviews.llvm.org/D146307
Added:
Modified:
mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
mlir/lib/Conversion/SPIRVCommon/Pattern.h
mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index be46a06ba53ab..b6ed24490bb0e 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -845,7 +845,8 @@ CmpIOpPattern::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
#define DISPATCH(cmpPredicate, spirvOp) \
case cmpPredicate: \
if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \
- srcType != dstType && !hasSameBitwidth(srcType, dstType)) { \
+ !getElementTypeOrSelf(srcType).isIndex() && srcType != dstType && \
+ !hasSameBitwidth(srcType, dstType)) { \
return op.emitError( \
"bitwidth emulation is not implemented yet on unsigned op"); \
} \
diff --git a/mlir/lib/Conversion/SPIRVCommon/Pattern.h b/mlir/lib/Conversion/SPIRVCommon/Pattern.h
index 4da3e197ca3ed..7425f4b5311ce 100644
--- a/mlir/lib/Conversion/SPIRVCommon/Pattern.h
+++ b/mlir/lib/Conversion/SPIRVCommon/Pattern.h
@@ -10,6 +10,7 @@
#define MLIR_CONVERSION_SPIRVCOMMON_PATTERN_H
#include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h"
+#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"
@@ -34,9 +35,10 @@ struct ElementwiseOpPattern : public OpConversionPattern<Op> {
}
if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
- !op.getType().isIndex() && dstType != op.getType()) {
- return op.emitError(
- "bitwidth emulation is not implemented yet on unsigned op");
+ !getElementTypeOrSelf(op.getType()).isIndex() &&
+ dstType != op.getType()) {
+ return op.emitError(
+ "bitwidth emulation is not implemented yet on unsigned op");
}
rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType,
adaptor.getOperands());
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 4c32a8648bd1c..d70df982c366a 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -179,6 +179,13 @@ func.func @int_vector234(%arg0: vector<2xi8>, %arg1: vector<4xi64>) {
return
}
+// CHECK-LABEL: @index_vector
+func.func @index_vector(%arg0: vector<4xindex>) {
+ // CHECK: spirv.UMod %{{.*}}, %{{.*}}: vector<4xi32>
+ %0 = arith.remui %arg0, %arg0: vector<4xindex>
+ return
+}
+
// CHECK-LABEL: @vector_srem
// CHECK-SAME: (%[[LHS:.+]]: vector<3xi16>, %[[RHS:.+]]: vector<3xi16>)
func.func @vector_srem(%arg0: vector<3xi16>, %arg1: vector<3xi16>) {
@@ -417,6 +424,31 @@ func.func @cmpi(%arg0 : i32, %arg1 : i32) {
return
}
+// CHECK-LABEL: @indexcmpi
+func.func @indexcmpi(%arg0 : index, %arg1 : index) {
+ // CHECK: spirv.IEqual
+ %0 = arith.cmpi eq, %arg0, %arg1 : index
+ // CHECK: spirv.INotEqual
+ %1 = arith.cmpi ne, %arg0, %arg1 : index
+ // CHECK: spirv.SLessThan
+ %2 = arith.cmpi slt, %arg0, %arg1 : index
+ // CHECK: spirv.SLessThanEqual
+ %3 = arith.cmpi sle, %arg0, %arg1 : index
+ // CHECK: spirv.SGreaterThan
+ %4 = arith.cmpi sgt, %arg0, %arg1 : index
+ // CHECK: spirv.SGreaterThanEqual
+ %5 = arith.cmpi sge, %arg0, %arg1 : index
+ // CHECK: spirv.ULessThan
+ %6 = arith.cmpi ult, %arg0, %arg1 : index
+ // CHECK: spirv.ULessThanEqual
+ %7 = arith.cmpi ule, %arg0, %arg1 : index
+ // CHECK: spirv.UGreaterThan
+ %8 = arith.cmpi ugt, %arg0, %arg1 : index
+ // CHECK: spirv.UGreaterThanEqual
+ %9 = arith.cmpi uge, %arg0, %arg1 : index
+ return
+}
+
// CHECK-LABEL: @vec1cmpi
func.func @vec1cmpi(%arg0 : vector<1xi32>, %arg1 : vector<1xi32>) {
// CHECK: spirv.ULessThan
More information about the Mlir-commits
mailing list