[Mlir-commits] [mlir] [MLIR][NVVM] Update dot.accumulate NVVM Ops (PR #140518)
Srinivasa Ravi
llvmlistbot at llvm.org
Thu May 22 21:49:16 PDT 2025
================
@@ -851,18 +851,56 @@ llvm.func @nvvm_dot_accumulate_4way(%a: vector<4xi8>, %b: vector<4xi8>, %c: i32)
// CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: call i32 @llvm.nvvm.idp4a.u.u(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
- %0 = nvvm.dot.accumulate.4way %a <u8>, %b <u8>, %c: vector<4xi8>, vector<4xi8>
+ %0 = nvvm.dot.accumulate.4way %a <u>, %b <u>, %c: vector<4xi8>, vector<4xi8>
// CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: call i32 @llvm.nvvm.idp4a.s.u(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
- %1 = nvvm.dot.accumulate.4way %a <s8>, %b <u8>, %c: vector<4xi8>, vector<4xi8>
+ %1 = nvvm.dot.accumulate.4way %a <s>, %b <u>, %c: vector<4xi8>, vector<4xi8>
// CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: call i32 @llvm.nvvm.idp4a.u.s(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
- %2 = nvvm.dot.accumulate.4way %a <u8>, %b <s8>, %c: vector<4xi8>, vector<4xi8>
+ %2 = nvvm.dot.accumulate.4way %a <u>, %b <s>, %c: vector<4xi8>, vector<4xi8>
// CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: call i32 @llvm.nvvm.idp4a.s.s(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
- %3 = nvvm.dot.accumulate.4way %a <s8>, %b <s8>, %c: vector<4xi8>, vector<4xi8>
+ %3 = nvvm.dot.accumulate.4way %a <s>, %b <s>, %c: vector<4xi8>, vector<4xi8>
+ llvm.return
+}
+
+// -----
+// CHECK-LABEL: @nvvm_dot_accumulate_2way
+llvm.func @nvvm_dot_accumulate_2way(%a: vector<2xi16>, %b: vector<4xi8>, %c: i32) {
+ // CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
+ // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+ // CHECK: call i32 @llvm.nvvm.idp2a.u.u(i32 %[[a_cast]], i32 %[[b_cast]], i1 false, i32 %{{.*}})
+ %0 = nvvm.dot.accumulate.2way %a <u>, %b <u>, %c: vector<2xi16>, vector<4xi8>
+ // CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
+ // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+ // CHECK: call i32 @llvm.nvvm.idp2a.u.u(i32 %[[a_cast]], i32 %[[b_cast]], i1 true, i32 %{{.*}})
+ %1 = nvvm.dot.accumulate.2way %a <u>, %b <u>, %c {hi = true}: vector<2xi16>, vector<4xi8>
----------------
Wolfram70 wrote:
Renamed the attribute and made it mandatory in the latest revision, so it must now be specified explicitly. Thanks!
https://github.com/llvm/llvm-project/pull/140518