[Mlir-commits] [mlir] [mlir][spirv] Add a generic `convert-to-spirv` pass (PR #95942)

Jakub Kuderski llvmlistbot at llvm.org
Wed Jun 19 09:43:14 PDT 2024


================
@@ -0,0 +1,276 @@
+// RUN: mlir-opt -convert-to-spirv -split-input-file %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// arithmetic ops
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @int32_scalar
+func.func @int32_scalar(%lhs: i32, %rhs: i32) {
+  // CHECK: spirv.IAdd %{{.*}}, %{{.*}}: i32
+  %0 = arith.addi %lhs, %rhs: i32
+  // CHECK: spirv.ISub %{{.*}}, %{{.*}}: i32
+  %1 = arith.subi %lhs, %rhs: i32
+  // CHECK: spirv.IMul %{{.*}}, %{{.*}}: i32
+  %2 = arith.muli %lhs, %rhs: i32
+  // CHECK: spirv.SDiv %{{.*}}, %{{.*}}: i32
+  %3 = arith.divsi %lhs, %rhs: i32
+  // CHECK: spirv.UDiv %{{.*}}, %{{.*}}: i32
+  %4 = arith.divui %lhs, %rhs: i32
+  // CHECK: spirv.UMod %{{.*}}, %{{.*}}: i32
+  %5 = arith.remui %lhs, %rhs: i32
+  return
+}
+
+// CHECK-LABEL: @int32_scalar_srem
+// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32)
+func.func @int32_scalar_srem(%lhs: i32, %rhs: i32) {
+  // CHECK: %[[LABS:.+]] = spirv.GL.SAbs %[[LHS]] : i32
+  // CHECK: %[[RABS:.+]] = spirv.GL.SAbs %[[RHS]] : i32
+  // CHECK:  %[[ABS:.+]] = spirv.UMod %[[LABS]], %[[RABS]] : i32
+  // CHECK:  %[[POS:.+]] = spirv.IEqual %[[LHS]], %[[LABS]] : i32
+  // CHECK:  %[[NEG:.+]] = spirv.SNegate %[[ABS]] : i32
+  // CHECK:      %{{.+}} = spirv.Select %[[POS]], %[[ABS]], %[[NEG]] : i1, i32
+  %0 = arith.remsi %lhs, %rhs: i32
+  return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// std bit ops
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @bitwise_scalar
+func.func @bitwise_scalar(%arg0 : i32, %arg1 : i32) {
+  // CHECK: spirv.BitwiseAnd
+  %0 = arith.andi %arg0, %arg1 : i32
+  // CHECK: spirv.BitwiseOr
+  %1 = arith.ori %arg0, %arg1 : i32
+  // CHECK: spirv.BitwiseXor
+  %2 = arith.xori %arg0, %arg1 : i32
+  return
+}
+
+// CHECK-LABEL: @bitwise_vector
+func.func @bitwise_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) {
+  // CHECK: spirv.BitwiseAnd
+  %0 = arith.andi %arg0, %arg1 : vector<4xi32>
+  // CHECK: spirv.BitwiseOr
+  %1 = arith.ori %arg0, %arg1 : vector<4xi32>
+  // CHECK: spirv.BitwiseXor
+  %2 = arith.xori %arg0, %arg1 : vector<4xi32>
+  return
+}
+
+// CHECK-LABEL: @logical_scalar
+func.func @logical_scalar(%arg0 : i1, %arg1 : i1) {
+  // CHECK: spirv.LogicalAnd
+  %0 = arith.andi %arg0, %arg1 : i1
+  // CHECK: spirv.LogicalOr
+  %1 = arith.ori %arg0, %arg1 : i1
+  // CHECK: spirv.LogicalNotEqual
+  %2 = arith.xori %arg0, %arg1 : i1
+  return
+}
+
+// CHECK-LABEL: @logical_vector
+func.func @logical_vector(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) {
+  // CHECK: spirv.LogicalAnd
+  %0 = arith.andi %arg0, %arg1 : vector<4xi1>
+  // CHECK: spirv.LogicalOr
+  %1 = arith.ori %arg0, %arg1 : vector<4xi1>
+  // CHECK: spirv.LogicalNotEqual
+  %2 = arith.xori %arg0, %arg1 : vector<4xi1>
+  return
+}
+
+// CHECK-LABEL: @shift_scalar
+func.func @shift_scalar(%arg0 : i32, %arg1 : i32) {
+  // CHECK: spirv.ShiftLeftLogical
+  %0 = arith.shli %arg0, %arg1 : i32
+  // CHECK: spirv.ShiftRightArithmetic
+  %1 = arith.shrsi %arg0, %arg1 : i32
+  // CHECK: spirv.ShiftRightLogical
+  %2 = arith.shrui %arg0, %arg1 : i32
+  return
+}
+
+// CHECK-LABEL: @shift_vector
+func.func @shift_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) {
+  // CHECK: spirv.ShiftLeftLogical
+  %0 = arith.shli %arg0, %arg1 : vector<4xi32>
+  // CHECK: spirv.ShiftRightArithmetic
+  %1 = arith.shrsi %arg0, %arg1 : vector<4xi32>
+  // CHECK: spirv.ShiftRightLogical
+  %2 = arith.shrui %arg0, %arg1 : vector<4xi32>
+  return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// arith.cmpf
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @cmpf
+func.func @cmpf(%arg0 : f32, %arg1 : f32) {
+  // CHECK: spirv.FOrdEqual
+  %1 = arith.cmpf oeq, %arg0, %arg1 : f32
+  // CHECK: spirv.FOrdGreaterThan
+  %2 = arith.cmpf ogt, %arg0, %arg1 : f32
+  // CHECK: spirv.FOrdGreaterThanEqual
+  %3 = arith.cmpf oge, %arg0, %arg1 : f32
+  // CHECK: spirv.FOrdLessThan
+  %4 = arith.cmpf olt, %arg0, %arg1 : f32
+  // CHECK: spirv.FOrdLessThanEqual
+  %5 = arith.cmpf ole, %arg0, %arg1 : f32
+  // CHECK: spirv.FOrdNotEqual
+  %6 = arith.cmpf one, %arg0, %arg1 : f32
+  // CHECK: spirv.FUnordEqual
+  %7 = arith.cmpf ueq, %arg0, %arg1 : f32
+  // CHECK: spirv.FUnordGreaterThan
+  %8 = arith.cmpf ugt, %arg0, %arg1 : f32
+  // CHECK: spirv.FUnordGreaterThanEqual
+  %9 = arith.cmpf uge, %arg0, %arg1 : f32
+  // CHECK: spirv.FUnordLessThan
+  %10 = arith.cmpf ult, %arg0, %arg1 : f32
+  // CHECK: FUnordLessThanEqual
+  %11 = arith.cmpf ule, %arg0, %arg1 : f32
+  // CHECK: spirv.FUnordNotEqual
+  %12 = arith.cmpf une, %arg0, %arg1 : f32
+  return
+}
+
+// CHECK-LABEL: @vec1cmpf
+func.func @vec1cmpf(%arg0 : vector<1xf32>, %arg1 : vector<1xf32>) {
+  // CHECK: spirv.FOrdGreaterThan
+  %0 = arith.cmpf ogt, %arg0, %arg1 : vector<1xf32>
+  // CHECK: spirv.FUnordLessThan
+  %1 = arith.cmpf ult, %arg0, %arg1 : vector<1xf32>
+  return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// arith.cmpi
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @cmpi
+func.func @cmpi(%arg0 : i32, %arg1 : i32) {
+  // CHECK: spirv.IEqual
+  %0 = arith.cmpi eq, %arg0, %arg1 : i32
+  // CHECK: spirv.INotEqual
+  %1 = arith.cmpi ne, %arg0, %arg1 : i32
+  // CHECK: spirv.SLessThan
+  %2 = arith.cmpi slt, %arg0, %arg1 : i32
+  // CHECK: spirv.SLessThanEqual
+  %3 = arith.cmpi sle, %arg0, %arg1 : i32
+  // CHECK: spirv.SGreaterThan
+  %4 = arith.cmpi sgt, %arg0, %arg1 : i32
+  // CHECK: spirv.SGreaterThanEqual
+  %5 = arith.cmpi sge, %arg0, %arg1 : i32
+  // CHECK: spirv.ULessThan
+  %6 = arith.cmpi ult, %arg0, %arg1 : i32
+  // CHECK: spirv.ULessThanEqual
+  %7 = arith.cmpi ule, %arg0, %arg1 : i32
+  // CHECK: spirv.UGreaterThan
+  %8 = arith.cmpi ugt, %arg0, %arg1 : i32
+  // CHECK: spirv.UGreaterThanEqual
+  %9 = arith.cmpi uge, %arg0, %arg1 : i32
+  return
+}
+
+// CHECK-LABEL: @indexcmpi
+func.func @indexcmpi(%arg0 : index, %arg1 : index) {
+  // CHECK: spirv.IEqual
+  %0 = arith.cmpi eq, %arg0, %arg1 : index
+  // CHECK: spirv.INotEqual
+  %1 = arith.cmpi ne, %arg0, %arg1 : index
+  // CHECK: spirv.SLessThan
+  %2 = arith.cmpi slt, %arg0, %arg1 : index
+  // CHECK: spirv.SLessThanEqual
+  %3 = arith.cmpi sle, %arg0, %arg1 : index
+  // CHECK: spirv.SGreaterThan
+  %4 = arith.cmpi sgt, %arg0, %arg1 : index
+  // CHECK: spirv.SGreaterThanEqual
+  %5 = arith.cmpi sge, %arg0, %arg1 : index
+  // CHECK: spirv.ULessThan
+  %6 = arith.cmpi ult, %arg0, %arg1 : index
+  // CHECK: spirv.ULessThanEqual
+  %7 = arith.cmpi ule, %arg0, %arg1 : index
+  // CHECK: spirv.UGreaterThan
+  %8 = arith.cmpi ugt, %arg0, %arg1 : index
+  // CHECK: spirv.UGreaterThanEqual
+  %9 = arith.cmpi uge, %arg0, %arg1 : index
----------------
kuhar wrote:

same here

https://github.com/llvm/llvm-project/pull/95942


More information about the Mlir-commits mailing list