[Mlir-commits] [mlir] [mlir][spirv] Add a generic `convert-to-spirv` pass (PR #95942)

Jakub Kuderski llvmlistbot at llvm.org
Wed Jun 19 09:43:12 PDT 2024


================
@@ -0,0 +1,276 @@
+// RUN: mlir-opt -convert-to-spirv -split-input-file %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// arithmetic ops
+//===----------------------------------------------------------------------===//
+
+// Exercises scalar i32 integer arithmetic. Each arith op in the body is
+// expected to show up in the converted output as the SPIR-V op named in the
+// FileCheck directive directly above it, with the i32 type preserved.
+// CHECK-LABEL: @int32_scalar
+func.func @int32_scalar(%lhs: i32, %rhs: i32) {
+  // CHECK: spirv.IAdd %{{.*}}, %{{.*}}: i32
+  %0 = arith.addi %lhs, %rhs: i32
+  // CHECK: spirv.ISub %{{.*}}, %{{.*}}: i32
+  %1 = arith.subi %lhs, %rhs: i32
+  // CHECK: spirv.IMul %{{.*}}, %{{.*}}: i32
+  %2 = arith.muli %lhs, %rhs: i32
+  // CHECK: spirv.SDiv %{{.*}}, %{{.*}}: i32
+  %3 = arith.divsi %lhs, %rhs: i32
+  // CHECK: spirv.UDiv %{{.*}}, %{{.*}}: i32
+  %4 = arith.divui %lhs, %rhs: i32
+  // CHECK: spirv.UMod %{{.*}}, %{{.*}}: i32
+  %5 = arith.remui %lhs, %rhs: i32
+  return
+}
+
+// Exercises signed remainder on scalar i32. The directives below expect a
+// multi-op expansion rather than a single SPIR-V op: take the absolute value
+// of both operands, compute their unsigned modulo, then select between that
+// result and its negation depending on whether the lhs equals its own
+// absolute value.
+// CHECK-LABEL: @int32_scalar_srem
+// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32)
+func.func @int32_scalar_srem(%lhs: i32, %rhs: i32) {
+  // CHECK: %[[LABS:.+]] = spirv.GL.SAbs %[[LHS]] : i32
+  // CHECK: %[[RABS:.+]] = spirv.GL.SAbs %[[RHS]] : i32
+  // CHECK:  %[[ABS:.+]] = spirv.UMod %[[LABS]], %[[RABS]] : i32
+  // CHECK:  %[[POS:.+]] = spirv.IEqual %[[LHS]], %[[LABS]] : i32
+  // CHECK:  %[[NEG:.+]] = spirv.SNegate %[[ABS]] : i32
+  // CHECK:      %{{.+}} = spirv.Select %[[POS]], %[[ABS]], %[[NEG]] : i1, i32
+  %0 = arith.remsi %lhs, %rhs: i32
+  return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// arith bit ops
+//===----------------------------------------------------------------------===//
+
+// Exercises and/or/xor on scalar i32 operands. Each is expected to convert to
+// the corresponding SPIR-V Bitwise* op; only the op name is matched, not the
+// operands or result type.
+// CHECK-LABEL: @bitwise_scalar
+func.func @bitwise_scalar(%arg0 : i32, %arg1 : i32) {
+  // CHECK: spirv.BitwiseAnd
+  %0 = arith.andi %arg0, %arg1 : i32
+  // CHECK: spirv.BitwiseOr
+  %1 = arith.ori %arg0, %arg1 : i32
+  // CHECK: spirv.BitwiseXor
+  %2 = arith.xori %arg0, %arg1 : i32
+  return
+}
+
+// Exercises and/or/xor on vector<4xi32> operands. The expected SPIR-V ops are
+// the same Bitwise* ops that the directives in this function name; only the
+// op name is matched.
+// CHECK-LABEL: @bitwise_vector
+func.func @bitwise_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) {
+  // CHECK: spirv.BitwiseAnd
+  %0 = arith.andi %arg0, %arg1 : vector<4xi32>
+  // CHECK: spirv.BitwiseOr
+  %1 = arith.ori %arg0, %arg1 : vector<4xi32>
+  // CHECK: spirv.BitwiseXor
+  %2 = arith.xori %arg0, %arg1 : vector<4xi32>
+  return
+}
+
+// Exercises and/or/xor on scalar i1 operands. For i1 the directives expect
+// the SPIR-V Logical* ops instead of Bitwise* ops; in particular xori is
+// expected to become LogicalNotEqual.
+// CHECK-LABEL: @logical_scalar
+func.func @logical_scalar(%arg0 : i1, %arg1 : i1) {
+  // CHECK: spirv.LogicalAnd
+  %0 = arith.andi %arg0, %arg1 : i1
+  // CHECK: spirv.LogicalOr
+  %1 = arith.ori %arg0, %arg1 : i1
+  // CHECK: spirv.LogicalNotEqual
+  %2 = arith.xori %arg0, %arg1 : i1
+  return
+}
+
+// Exercises and/or/xor on vector<4xi1> operands. As with scalar i1, the
+// directives expect SPIR-V Logical* ops, with xori expected to become
+// LogicalNotEqual.
+// CHECK-LABEL: @logical_vector
+func.func @logical_vector(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) {
+  // CHECK: spirv.LogicalAnd
+  %0 = arith.andi %arg0, %arg1 : vector<4xi1>
+  // CHECK: spirv.LogicalOr
+  %1 = arith.ori %arg0, %arg1 : vector<4xi1>
+  // CHECK: spirv.LogicalNotEqual
+  %2 = arith.xori %arg0, %arg1 : vector<4xi1>
+  return
+}
+
+// Exercises the three shift ops on scalar i32 operands. Expected mapping per
+// the directives: shli to ShiftLeftLogical, shrsi (signed) to
+// ShiftRightArithmetic, and shrui (unsigned) to ShiftRightLogical.
+// CHECK-LABEL: @shift_scalar
+func.func @shift_scalar(%arg0 : i32, %arg1 : i32) {
+  // CHECK: spirv.ShiftLeftLogical
+  %0 = arith.shli %arg0, %arg1 : i32
+  // CHECK: spirv.ShiftRightArithmetic
+  %1 = arith.shrsi %arg0, %arg1 : i32
+  // CHECK: spirv.ShiftRightLogical
+  %2 = arith.shrui %arg0, %arg1 : i32
+  return
+}
+
+// Exercises the three shift ops on vector<4xi32> operands. Expected mapping
+// per the directives: shli to ShiftLeftLogical, shrsi (signed) to
+// ShiftRightArithmetic, and shrui (unsigned) to ShiftRightLogical.
+// CHECK-LABEL: @shift_vector
+func.func @shift_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) {
+  // CHECK: spirv.ShiftLeftLogical
+  %0 = arith.shli %arg0, %arg1 : vector<4xi32>
+  // CHECK: spirv.ShiftRightArithmetic
+  %1 = arith.shrsi %arg0, %arg1 : vector<4xi32>
+  // CHECK: spirv.ShiftRightLogical
+  %2 = arith.shrui %arg0, %arg1 : vector<4xi32>
+  return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// arith.cmpf
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @cmpf
+func.func @cmpf(%arg0 : f32, %arg1 : f32) {
+  // CHECK: spirv.FOrdEqual
+  %1 = arith.cmpf oeq, %arg0, %arg1 : f32
+  // CHECK: spirv.FOrdGreaterThan
+  %2 = arith.cmpf ogt, %arg0, %arg1 : f32
+  // CHECK: spirv.FOrdGreaterThanEqual
+  %3 = arith.cmpf oge, %arg0, %arg1 : f32
+  // CHECK: spirv.FOrdLessThan
+  %4 = arith.cmpf olt, %arg0, %arg1 : f32
+  // CHECK: spirv.FOrdLessThanEqual
+  %5 = arith.cmpf ole, %arg0, %arg1 : f32
+  // CHECK: spirv.FOrdNotEqual
+  %6 = arith.cmpf one, %arg0, %arg1 : f32
+  // CHECK: spirv.FUnordEqual
+  %7 = arith.cmpf ueq, %arg0, %arg1 : f32
+  // CHECK: spirv.FUnordGreaterThan
+  %8 = arith.cmpf ugt, %arg0, %arg1 : f32
+  // CHECK: spirv.FUnordGreaterThanEqual
+  %9 = arith.cmpf uge, %arg0, %arg1 : f32
+  // CHECK: spirv.FUnordLessThan
+  %10 = arith.cmpf ult, %arg0, %arg1 : f32
+  // CHECK: FUnordLessThanEqual
+  %11 = arith.cmpf ule, %arg0, %arg1 : f32
+  // CHECK: spirv.FUnordNotEqual
+  %12 = arith.cmpf une, %arg0, %arg1 : f32
----------------
kuhar wrote:

I don't think we need to test all of these here -- one should be enough

https://github.com/llvm/llvm-project/pull/95942


More information about the Mlir-commits mailing list