[Mlir-commits] [mlir] [MLIR][NVVM] Update dot.accumulate NVVM Ops (PR #140518)

Srinivasa Ravi llvmlistbot at llvm.org
Thu May 22 21:49:16 PDT 2025


================
@@ -851,18 +851,56 @@ llvm.func @nvvm_dot_accumulate_4way(%a: vector<4xi8>, %b: vector<4xi8>, %c: i32)
   // CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
   // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
   // CHECK: call i32 @llvm.nvvm.idp4a.u.u(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
-  %0 = nvvm.dot.accumulate.4way %a <u8>, %b <u8>, %c: vector<4xi8>, vector<4xi8>
+  %0 = nvvm.dot.accumulate.4way %a <u>, %b <u>, %c: vector<4xi8>, vector<4xi8>
   // CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
   // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
   // CHECK: call i32 @llvm.nvvm.idp4a.s.u(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
-  %1 = nvvm.dot.accumulate.4way %a <s8>, %b <u8>, %c: vector<4xi8>, vector<4xi8>
+  %1 = nvvm.dot.accumulate.4way %a <s>, %b <u>, %c: vector<4xi8>, vector<4xi8>
   // CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
   // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
   // CHECK: call i32 @llvm.nvvm.idp4a.u.s(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
-  %2 = nvvm.dot.accumulate.4way %a <u8>, %b <s8>, %c: vector<4xi8>, vector<4xi8>
+  %2 = nvvm.dot.accumulate.4way %a <u>, %b <s>, %c: vector<4xi8>, vector<4xi8>
   // CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
   // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
   // CHECK: call i32 @llvm.nvvm.idp4a.s.s(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
-  %3 = nvvm.dot.accumulate.4way %a <s8>, %b <s8>, %c: vector<4xi8>, vector<4xi8>
+  %3 = nvvm.dot.accumulate.4way %a <s>, %b <s>, %c: vector<4xi8>, vector<4xi8>
+  llvm.return
+}
+
+// -----
+// CHECK-LABEL: @nvvm_dot_accumulate_2way
+llvm.func @nvvm_dot_accumulate_2way(%a: vector<2xi16>, %b: vector<4xi8>, %c: i32) {
+  // CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
+  // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: call i32 @llvm.nvvm.idp2a.u.u(i32 %[[a_cast]], i32 %[[b_cast]], i1 false, i32 %{{.*}})
+  %0 = nvvm.dot.accumulate.2way %a <u>, %b <u>, %c: vector<2xi16>, vector<4xi8>
+  // CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
+  // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: call i32 @llvm.nvvm.idp2a.u.u(i32 %[[a_cast]], i32 %[[b_cast]], i1 true, i32 %{{.*}})
+  %1 = nvvm.dot.accumulate.2way %a <u>, %b <u>, %c {hi = true}: vector<2xi16>, vector<4xi8>
----------------
Wolfram70 wrote:

In the latest revision, I renamed the attribute and made it required, so it must now be specified explicitly. Thanks!
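
For context, here is a rough reference sketch of what the flag selects. It assumes the op follows PTX `dp2a` semantics: the two 16-bit lanes of `a` are multiplied with either the lower or the upper pair of bytes of `b`, and the products are added to `c`. In the test above this choice is what reaches the intrinsic as the `i1` argument. The function and parameter names (`dot_accumulate_2way`, `b_hi`, `a_signed`, `b_signed`) are hypothetical and are not the attribute names used in the PR.

```python
def _extend(value, bits, signed):
    """Interpret the low `bits` bits of `value` as a signed or unsigned integer."""
    value &= (1 << bits) - 1
    if signed and value >= 1 << (bits - 1):
        value -= 1 << bits
    return value


def dot_accumulate_2way(a, b, c, a_signed=False, b_signed=False, b_hi=False):
    """Reference model of a 2-way dot-accumulate (assumed PTX dp2a behaviour).

    a: 32-bit word holding two packed 16-bit lanes
    b: 32-bit word holding four packed 8-bit lanes
    c: 32-bit accumulator
    b_hi: hypothetical name; False uses bytes 0..1 of b, True uses bytes 2..3
    """
    a_lanes = [_extend(a >> (16 * i), 16, a_signed) for i in range(2)]
    b_lanes = [_extend(b >> (8 * i), 8, b_signed) for i in range(4)]
    select = 2 if b_hi else 0
    acc = c + sum(a_lanes[i] * b_lanes[select + i] for i in range(2))
    # Return the raw 32-bit pattern; interpreting it as signed or unsigned
    # is left to the caller.
    return acc & 0xFFFFFFFF
```

With `a = 0x00020003` and `b = 0x04030201`, the lower half pairs the lanes of `a` with bytes 1 and 2 of `b`'s value list `[1, 2, 3, 4]`, while the upper half pairs them with 3 and 4, so the two modes give different results for the same inputs. That difference is why having to spell the choice out in the op is useful.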

https://github.com/llvm/llvm-project/pull/140518

