[Mlir-commits] [mlir] [mlir][spirv] Enable dot operation for bfloat16 (PR #145409)

Darren Wihandi llvmlistbot at llvm.org
Tue Jul 8 19:46:34 PDT 2025


https://github.com/fairywreath updated https://github.com/llvm/llvm-project/pull/145409

>From 9a4e40ae626ea6ae26e6937901ecefbc004ea6b2 Mon Sep 17 00:00:00 2001
From: fairywreath <nerradfour at gmail.com>
Date: Mon, 23 Jun 2025 03:55:05 -0400
Subject: [PATCH 1/4] [mlir][spirv] Enable dot operation for bfloat16

---
 .../mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td       |  6 +++---
 mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir        | 11 ++++++++++-
 mlir/test/Target/SPIRV/arithmetic-ops.mlir            |  5 +++++
 3 files changed, 18 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index 309079e549846..33af979a45bc5 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -445,12 +445,12 @@ def SPIRV_DotOp : SPIRV_Op<"Dot",
   }];
 
   let arguments = (ins
-    SPIRV_VectorOf<SPIRV_Float>:$vector1,
-    SPIRV_VectorOf<SPIRV_Float>:$vector2
+    SPIRV_VectorOf<SPIRV_AnyFloat>:$vector1,
+    SPIRV_VectorOf<SPIRV_AnyFloat>:$vector2
   );
 
   let results = (outs
-    SPIRV_Float:$result
+    SPIRV_AnyFloat:$result
   );
 
   let assemblyFormat = "operands attr-dict `:` type($vector1) `->` type($result)";
diff --git a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
index d58c27598f2b8..3adafc15c79f6 100644
--- a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
@@ -321,6 +321,15 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
 
 // -----
 
+// CHECK-LABEL: @dot_bf16
+func.func @dot_bf16(%arg0: vector<4xbf16>, %arg1: vector<4xbf16>) -> bf16 {
+  // CHECK: spirv.Dot %{{.+}}, %{{.+}} : vector<4xbf16> -> bf16
+  %0 = spirv.Dot %arg0, %arg1 : vector<4xbf16> -> bf16
+  return %0 : bf16
+}
+
+// -----
+
 // expected-note @+1 {{prior use here}}
 func.func @dot(%arg0: vector<4xf32>, %arg1: vector<3xf32>) -> f32 {
   // expected-error @+1 {{use of value '%arg1' expects different type than prior uses}}
@@ -339,7 +348,7 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f16 {
 // -----
 
 func.func @dot(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> i32 {
-  // expected-error @+1 {{'spirv.Dot' op operand #0 must be vector of 16/32/64-bit float values of length 2/3/4/8/16}}
+  // expected-error @+1 {{'spirv.Dot' op operand #0 must be vector of 16/32/64-bit float or BFloat16 values of length 2/3/4/8/16}}
   %0 = spirv.Dot %arg0, %arg1 : vector<4xi32> -> i32
   return %0 : i32
 }
diff --git a/mlir/test/Target/SPIRV/arithmetic-ops.mlir b/mlir/test/Target/SPIRV/arithmetic-ops.mlir
index b1ea13c6854fd..84d301c608d7d 100644
--- a/mlir/test/Target/SPIRV/arithmetic-ops.mlir
+++ b/mlir/test/Target/SPIRV/arithmetic-ops.mlir
@@ -86,4 +86,9 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
     %0 = spirv.VectorTimesScalar %arg0, %arg1 : (vector<4xf32>, f32) -> vector<4xf32>
     spirv.Return
   }
+  spirv.func @dot_bf16(%arg0: vector<4xbf16>, %arg1: vector<4xbf16>) "None" {
+    // CHECK: spirv.Dot %{{.+}}, %{{.+}} : vector<4xbf16> -> f16
+    %0 = spirv.Dot %arg0, %arg1 : vector<4xbf16> -> bf16
+    spirv.Return
+  }
 }

>From 71755ac3c413b67b7b6c6eb52f9f62b2bf01482a Mon Sep 17 00:00:00 2001
From: fairywreath <nerradfour at gmail.com>
Date: Wed, 25 Jun 2025 15:36:47 -0600
Subject: [PATCH 2/4] Fix test typo

---
 mlir/test/Target/SPIRV/arithmetic-ops.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Target/SPIRV/arithmetic-ops.mlir b/mlir/test/Target/SPIRV/arithmetic-ops.mlir
index 84d301c608d7d..b80e17f979daa 100644
--- a/mlir/test/Target/SPIRV/arithmetic-ops.mlir
+++ b/mlir/test/Target/SPIRV/arithmetic-ops.mlir
@@ -87,7 +87,7 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
     spirv.Return
   }
   spirv.func @dot_bf16(%arg0: vector<4xbf16>, %arg1: vector<4xbf16>) "None" {
-    // CHECK: spirv.Dot %{{.+}}, %{{.+}} : vector<4xbf16> -> f16
+    // CHECK: spirv.Dot %{{.+}}, %{{.+}} : vector<4xbf16> -> bf16
     %0 = spirv.Dot %arg0, %arg1 : vector<4xbf16> -> bf16
     spirv.Return
   }

>From e51870f94d3a0aeac4bf643e11f3dc6a8c626911 Mon Sep 17 00:00:00 2001
From: fairywreath <nerradfour at gmail.com>
Date: Sat, 5 Jul 2025 00:25:21 -0400
Subject: [PATCH 3/4] Add availability generation and test

---
 .../Dialect/SPIRV/IR/SPIRVArithmeticOps.td    |  3 ++
 mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt      |  2 +-
 ...gerDotProductOps.cpp => DotProductOps.cpp} | 54 ++++++++++++++-----
 mlir/test/Dialect/SPIRV/IR/availability.mlir  | 14 +++++
 4 files changed, 60 insertions(+), 13 deletions(-)
 rename mlir/lib/Dialect/SPIRV/IR/{IntegerDotProductOps.cpp => DotProductOps.cpp} (83%)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index 33af979a45bc5..2260ee85493c7 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -455,6 +455,9 @@ def SPIRV_DotOp : SPIRV_Op<"Dot",
 
   let assemblyFormat = "operands attr-dict `:` type($vector1) `->` type($result)";
 
+  // Require dynamic availability specification based on operand/result type.
+  bit autogenAvailability = 0;
+
   let hasVerifier = 0;
 }
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
index 1a8f30dd39871..b9aa7b7491abf 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
@@ -7,9 +7,9 @@ add_mlir_dialect_library(MLIRSPIRVDialect
   CastOps.cpp
   ControlFlowOps.cpp
   CooperativeMatrixOps.cpp
+  DotProductOps.cpp
   GroupOps.cpp
   ImageOps.cpp
-  IntegerDotProductOps.cpp
   MemoryOps.cpp
   MeshOps.cpp
   SPIRVAttributes.cpp
diff --git a/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp b/mlir/lib/Dialect/SPIRV/IR/DotProductOps.cpp
similarity index 83%
rename from mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
rename to mlir/lib/Dialect/SPIRV/IR/DotProductOps.cpp
index f5676f36a0f5f..2f8da5f58a793 100644
--- a/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/DotProductOps.cpp
@@ -1,4 +1,4 @@
-//===- IntegerDotProductOps.cpp - MLIR SPIR-V Integer Dot Product Ops  ----===//
+//===- DotProductOps.cpp - MLIR SPIR-V Dot Product Ops  ----===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// Defines the Integer Dot Product operations in the SPIR-V dialect.
+// Defines the Dot Product operations in the SPIR-V dialect.
 //
 //===----------------------------------------------------------------------===//
 
@@ -21,6 +21,44 @@ using namespace mlir::spirv::AttrNames;
 
 namespace mlir::spirv {
 
+//===----------------------------------------------------------------------===//
+// Dot Product ops
+//===----------------------------------------------------------------------===//
+
+static std::optional<spirv::Version> getDotProductMinVersion() {
+  return spirv::Version::V_1_0; // Available in SPIR-V >= 1.0.
+}
+
+static std::optional<spirv::Version> getDotProductMaxVersion() {
+  return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6.
+}
+
+SmallVector<ArrayRef<spirv::Extension>, 1> DotOp::getExtensions() {
+  if (getResult().getType().isBF16()) {
+    static const auto extension = spirv::Extension::SPV_KHR_bfloat16;
+    return {extension};
+  }
+
+  return {};
+}
+
+SmallVector<ArrayRef<spirv::Capability>, 1> DotOp::getCapabilities() {
+  if (getResult().getType().isBF16()) {
+    static const auto capability = spirv::Capability::BFloat16DotProductKHR;
+    return {capability};
+  }
+
+  return {};
+}
+
+std::optional<spirv::Version> DotOp::getMinVersion() {
+  return getDotProductMinVersion();
+}
+
+std::optional<spirv::Version> DotOp::getMaxVersion() {
+  return getDotProductMaxVersion();
+}
+
 //===----------------------------------------------------------------------===//
 // Integer Dot Product ops
 //===----------------------------------------------------------------------===//
@@ -71,14 +109,6 @@ static LogicalResult verifyIntegerDotProduct(Operation *op) {
   return success();
 }
 
-static std::optional<spirv::Version> getIntegerDotProductMinVersion() {
-  return spirv::Version::V_1_0; // Available in SPIR-V >= 1.0.
-}
-
-static std::optional<spirv::Version> getIntegerDotProductMaxVersion() {
-  return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6.
-}
-
 static SmallVector<ArrayRef<spirv::Extension>, 1>
 getIntegerDotProductExtensions() {
   // Requires the SPV_KHR_integer_dot_product extension, specified either
@@ -136,10 +166,10 @@ getIntegerDotProductCapabilities(Operation *op) {
     return getIntegerDotProductCapabilities<OpName>(*this);                    \
   }                                                                            \
   std::optional<spirv::Version> OpName::getMinVersion() {                      \
-    return getIntegerDotProductMinVersion();                                   \
+    return getDotProductMinVersion();                                          \
   }                                                                            \
   std::optional<spirv::Version> OpName::getMaxVersion() {                      \
-    return getIntegerDotProductMaxVersion();                                   \
+    return getDotProductMaxVersion();                                          \
   }
 
 SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(SDotOp)
diff --git a/mlir/test/Dialect/SPIRV/IR/availability.mlir b/mlir/test/Dialect/SPIRV/IR/availability.mlir
index 64ba8e3fc249e..9c8665b1e4bbe 100644
--- a/mlir/test/Dialect/SPIRV/IR/availability.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/availability.mlir
@@ -234,6 +234,20 @@ func.func @udot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
   return %r: i64
 }
 
+//===----------------------------------------------------------------------===//
+// Dot Product op with bfloat16
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: dot_vector_4xbf16_bf16
+func.func @dot_vector_4xbf16_bf16(%a: vector<4xbf16>, %b: vector<4xbf16>) -> bf16 {
+  // CHECK: min version: v1.0
+  // CHECK: max version: v1.6
+  // CHECK: extensions: [ [SPV_KHR_bfloat16] ]
+  // CHECK: capabilities: [ [BFloat16DotProductKHR] ]
+  %r = spirv.Dot %a, %a: vector<4xbf16> -> bf16
+  return %r: bf16
+}
+
 //===----------------------------------------------------------------------===//
 // Primitive ops
 //===----------------------------------------------------------------------===//

>From 41b0195691e2cef0b0fb92a682cdbd19bc0645e8 Mon Sep 17 00:00:00 2001
From: fairywreath <nerradfour at gmail.com>
Date: Tue, 8 Jul 2025 22:45:10 -0400
Subject: [PATCH 4/4] Address review comments

---
 mlir/lib/Dialect/SPIRV/IR/DotProductOps.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/IR/DotProductOps.cpp b/mlir/lib/Dialect/SPIRV/IR/DotProductOps.cpp
index 2f8da5f58a793..01ef1bdc42515 100644
--- a/mlir/lib/Dialect/SPIRV/IR/DotProductOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/DotProductOps.cpp
@@ -1,4 +1,4 @@
-//===- DotProductOps.cpp - MLIR SPIR-V Dot Product Ops  ----===//
+//===- DotProductOps.cpp - MLIR SPIR-V Dot Product Ops  -------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -34,7 +34,7 @@ static std::optional<spirv::Version> getDotProductMaxVersion() {
 }
 
 SmallVector<ArrayRef<spirv::Extension>, 1> DotOp::getExtensions() {
-  if (getResult().getType().isBF16()) {
+  if (isa<BFloat16Type>(getType())) {
     static const auto extension = spirv::Extension::SPV_KHR_bfloat16;
     return {extension};
   }
@@ -43,7 +43,7 @@ SmallVector<ArrayRef<spirv::Extension>, 1> DotOp::getExtensions() {
 }
 
 SmallVector<ArrayRef<spirv::Capability>, 1> DotOp::getCapabilities() {
-  if (getResult().getType().isBF16()) {
+  if (isa<BFloat16Type>(getType())) {
     static const auto capability = spirv::Capability::BFloat16DotProductKHR;
     return {capability};
   }



More information about the Mlir-commits mailing list