[Mlir-commits] [mlir] [mlir][spirv] Add a generic `convert-to-spirv` pass (PR #95942)
Jakub Kuderski
llvmlistbot at llvm.org
Wed Jun 19 09:43:12 PDT 2024
================
@@ -0,0 +1,276 @@
+// RUN: mlir-opt -convert-to-spirv -split-input-file %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// arithmetic ops
+//===----------------------------------------------------------------------===//
+
+// Two-operand i32 arithmetic; each arith op is expected to lower to a single
+// SPIR-V op, matched in order below.
+// CHECK-LABEL: @int32_scalar
+func.func @int32_scalar(%a: i32, %b: i32) {
+ // CHECK: spirv.IAdd %{{.*}}, %{{.*}}: i32
+ %sum = arith.addi %a, %b : i32
+ // CHECK: spirv.ISub %{{.*}}, %{{.*}}: i32
+ %diff = arith.subi %a, %b : i32
+ // CHECK: spirv.IMul %{{.*}}, %{{.*}}: i32
+ %prod = arith.muli %a, %b : i32
+ // CHECK: spirv.SDiv %{{.*}}, %{{.*}}: i32
+ %squot = arith.divsi %a, %b : i32
+ // CHECK: spirv.UDiv %{{.*}}, %{{.*}}: i32
+ %uquot = arith.divui %a, %b : i32
+ // CHECK: spirv.UMod %{{.*}}, %{{.*}}: i32
+ %urem = arith.remui %a, %b : i32
+ return
+}
+
+// arith.remsi is expanded rather than mapped 1:1: take the unsigned remainder
+// of both absolute values, then negate it when the dividend is negative, so the
+// result carries the sign of the left-hand operand.
+// CHECK-LABEL: @int32_scalar_srem
+// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32)
+func.func @int32_scalar_srem(%dividend: i32, %divisor: i32) {
+ // CHECK: %[[LABS:.+]] = spirv.GL.SAbs %[[LHS]] : i32
+ // CHECK: %[[RABS:.+]] = spirv.GL.SAbs %[[RHS]] : i32
+ // CHECK: %[[ABS:.+]] = spirv.UMod %[[LABS]], %[[RABS]] : i32
+ // CHECK: %[[POS:.+]] = spirv.IEqual %[[LHS]], %[[LABS]] : i32
+ // CHECK: %[[NEG:.+]] = spirv.SNegate %[[ABS]] : i32
+ // CHECK: %{{.+}} = spirv.Select %[[POS]], %[[ABS]], %[[NEG]] : i1, i32
+ %rem = arith.remsi %dividend, %divisor : i32
+ return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// arith bitwise, logical, and shift ops
+//===----------------------------------------------------------------------===//
+
+// Integer (non-i1) and/or/xor lower to the bitwise SPIR-V ops.
+// CHECK-LABEL: @bitwise_scalar
+func.func @bitwise_scalar(%x : i32, %y : i32) {
+ // CHECK: spirv.BitwiseAnd
+ %and = arith.andi %x, %y : i32
+ // CHECK: spirv.BitwiseOr
+ %or = arith.ori %x, %y : i32
+ // CHECK: spirv.BitwiseXor
+ %xor = arith.xori %x, %y : i32
+ return
+}
+
+// Same as the scalar case, but element-wise on a 4-wide i32 vector.
+// CHECK-LABEL: @bitwise_vector
+func.func @bitwise_vector(%x : vector<4xi32>, %y : vector<4xi32>) {
+ // CHECK: spirv.BitwiseAnd
+ %and = arith.andi %x, %y : vector<4xi32>
+ // CHECK: spirv.BitwiseOr
+ %or = arith.ori %x, %y : vector<4xi32>
+ // CHECK: spirv.BitwiseXor
+ %xor = arith.xori %x, %y : vector<4xi32>
+ return
+}
+
+// On i1 operands the same arith ops lower to logical SPIR-V ops instead;
+// xori becomes LogicalNotEqual.
+// CHECK-LABEL: @logical_scalar
+func.func @logical_scalar(%p : i1, %q : i1) {
+ // CHECK: spirv.LogicalAnd
+ %and = arith.andi %p, %q : i1
+ // CHECK: spirv.LogicalOr
+ %or = arith.ori %p, %q : i1
+ // CHECK: spirv.LogicalNotEqual
+ %xor = arith.xori %p, %q : i1
+ return
+}
+
+// Vector-of-i1 variant of the logical lowering.
+// CHECK-LABEL: @logical_vector
+func.func @logical_vector(%p : vector<4xi1>, %q : vector<4xi1>) {
+ // CHECK: spirv.LogicalAnd
+ %and = arith.andi %p, %q : vector<4xi1>
+ // CHECK: spirv.LogicalOr
+ %or = arith.ori %p, %q : vector<4xi1>
+ // CHECK: spirv.LogicalNotEqual
+ %xor = arith.xori %p, %q : vector<4xi1>
+ return
+}
+
+// shli / shrsi / shrui lower to left-logical, right-arithmetic and
+// right-logical shifts respectively.
+// CHECK-LABEL: @shift_scalar
+func.func @shift_scalar(%value : i32, %amount : i32) {
+ // CHECK: spirv.ShiftLeftLogical
+ %shl = arith.shli %value, %amount : i32
+ // CHECK: spirv.ShiftRightArithmetic
+ %ashr = arith.shrsi %value, %amount : i32
+ // CHECK: spirv.ShiftRightLogical
+ %lshr = arith.shrui %value, %amount : i32
+ return
+}
+
+// Element-wise shifts on a 4-wide i32 vector.
+// CHECK-LABEL: @shift_vector
+func.func @shift_vector(%value : vector<4xi32>, %amount : vector<4xi32>) {
+ // CHECK: spirv.ShiftLeftLogical
+ %shl = arith.shli %value, %amount : vector<4xi32>
+ // CHECK: spirv.ShiftRightArithmetic
+ %ashr = arith.shrsi %value, %amount : vector<4xi32>
+ // CHECK: spirv.ShiftRightLogical
+ %lshr = arith.shrui %value, %amount : vector<4xi32>
+ return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// arith.cmpf
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @cmpf
+func.func @cmpf(%arg0 : f32, %arg1 : f32) {
+ // CHECK: spirv.FOrdEqual
+ %1 = arith.cmpf oeq, %arg0, %arg1 : f32
+ // CHECK: spirv.FOrdGreaterThan
+ %2 = arith.cmpf ogt, %arg0, %arg1 : f32
+ // CHECK: spirv.FOrdGreaterThanEqual
+ %3 = arith.cmpf oge, %arg0, %arg1 : f32
+ // CHECK: spirv.FOrdLessThan
+ %4 = arith.cmpf olt, %arg0, %arg1 : f32
+ // CHECK: spirv.FOrdLessThanEqual
+ %5 = arith.cmpf ole, %arg0, %arg1 : f32
+ // CHECK: spirv.FOrdNotEqual
+ %6 = arith.cmpf one, %arg0, %arg1 : f32
+ // CHECK: spirv.FUnordEqual
+ %7 = arith.cmpf ueq, %arg0, %arg1 : f32
+ // CHECK: spirv.FUnordGreaterThan
+ %8 = arith.cmpf ugt, %arg0, %arg1 : f32
+ // CHECK: spirv.FUnordGreaterThanEqual
+ %9 = arith.cmpf uge, %arg0, %arg1 : f32
+ // CHECK: spirv.FUnordLessThan
+ %10 = arith.cmpf ult, %arg0, %arg1 : f32
+ // CHECK: FUnordLessThanEqual
+ %11 = arith.cmpf ule, %arg0, %arg1 : f32
+ // CHECK: spirv.FUnordNotEqual
+ %12 = arith.cmpf une, %arg0, %arg1 : f32
----------------
kuhar wrote:
I don't think we need to test all of these here -- one should be enough
https://github.com/llvm/llvm-project/pull/95942
More information about the Mlir-commits
mailing list