[Mlir-commits] [mlir] [mlir][spirv] Enable dot operation for bfloat16 (PR #145409)
Darren Wihandi
llvmlistbot at llvm.org
Tue Jul 8 19:46:34 PDT 2025
https://github.com/fairywreath updated https://github.com/llvm/llvm-project/pull/145409
>From 9a4e40ae626ea6ae26e6937901ecefbc004ea6b2 Mon Sep 17 00:00:00 2001
From: fairywreath <nerradfour at gmail.com>
Date: Mon, 23 Jun 2025 03:55:05 -0400
Subject: [PATCH 1/4] [mlir][spirv] Enable dot operation for bfloat16
---
.../mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td | 6 +++---
mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir | 11 ++++++++++-
mlir/test/Target/SPIRV/arithmetic-ops.mlir | 5 +++++
3 files changed, 18 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index 309079e549846..33af979a45bc5 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -445,12 +445,12 @@ def SPIRV_DotOp : SPIRV_Op<"Dot",
}];
let arguments = (ins
- SPIRV_VectorOf<SPIRV_Float>:$vector1,
- SPIRV_VectorOf<SPIRV_Float>:$vector2
+ SPIRV_VectorOf<SPIRV_AnyFloat>:$vector1,
+ SPIRV_VectorOf<SPIRV_AnyFloat>:$vector2
);
let results = (outs
- SPIRV_Float:$result
+ SPIRV_AnyFloat:$result
);
let assemblyFormat = "operands attr-dict `:` type($vector1) `->` type($result)";
diff --git a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
index d58c27598f2b8..3adafc15c79f6 100644
--- a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
@@ -321,6 +321,15 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
// -----
+// CHECK-LABEL: @dot_bf16
+func.func @dot_bf16(%arg0: vector<4xbf16>, %arg1: vector<4xbf16>) -> bf16 {
+ // CHECK: spirv.Dot %{{.+}}, %{{.+}} : vector<4xbf16> -> bf16
+ %0 = spirv.Dot %arg0, %arg1 : vector<4xbf16> -> bf16
+ return %0 : bf16
+}
+
+// -----
+
// expected-note @+1 {{prior use here}}
func.func @dot(%arg0: vector<4xf32>, %arg1: vector<3xf32>) -> f32 {
// expected-error @+1 {{use of value '%arg1' expects different type than prior uses}}
@@ -339,7 +348,7 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f16 {
// -----
func.func @dot(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> i32 {
- // expected-error @+1 {{'spirv.Dot' op operand #0 must be vector of 16/32/64-bit float values of length 2/3/4/8/16}}
+ // expected-error @+1 {{'spirv.Dot' op operand #0 must be vector of 16/32/64-bit float or BFloat16 values of length 2/3/4/8/16}}
%0 = spirv.Dot %arg0, %arg1 : vector<4xi32> -> i32
return %0 : i32
}
diff --git a/mlir/test/Target/SPIRV/arithmetic-ops.mlir b/mlir/test/Target/SPIRV/arithmetic-ops.mlir
index b1ea13c6854fd..84d301c608d7d 100644
--- a/mlir/test/Target/SPIRV/arithmetic-ops.mlir
+++ b/mlir/test/Target/SPIRV/arithmetic-ops.mlir
@@ -86,4 +86,9 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
%0 = spirv.VectorTimesScalar %arg0, %arg1 : (vector<4xf32>, f32) -> vector<4xf32>
spirv.Return
}
+ spirv.func @dot_bf16(%arg0: vector<4xbf16>, %arg1: vector<4xbf16>) "None" {
+ // CHECK: spirv.Dot %{{.+}}, %{{.+}} : vector<4xbf16> -> f16
+ %0 = spirv.Dot %arg0, %arg1 : vector<4xbf16> -> bf16
+ spirv.Return
+ }
}
>From 71755ac3c413b67b7b6c6eb52f9f62b2bf01482a Mon Sep 17 00:00:00 2001
From: fairywreath <nerradfour at gmail.com>
Date: Wed, 25 Jun 2025 15:36:47 -0600
Subject: [PATCH 2/4] Fix test typo
---
mlir/test/Target/SPIRV/arithmetic-ops.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Target/SPIRV/arithmetic-ops.mlir b/mlir/test/Target/SPIRV/arithmetic-ops.mlir
index 84d301c608d7d..b80e17f979daa 100644
--- a/mlir/test/Target/SPIRV/arithmetic-ops.mlir
+++ b/mlir/test/Target/SPIRV/arithmetic-ops.mlir
@@ -87,7 +87,7 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
spirv.Return
}
spirv.func @dot_bf16(%arg0: vector<4xbf16>, %arg1: vector<4xbf16>) "None" {
- // CHECK: spirv.Dot %{{.+}}, %{{.+}} : vector<4xbf16> -> f16
+ // CHECK: spirv.Dot %{{.+}}, %{{.+}} : vector<4xbf16> -> bf16
%0 = spirv.Dot %arg0, %arg1 : vector<4xbf16> -> bf16
spirv.Return
}
>From e51870f94d3a0aeac4bf643e11f3dc6a8c626911 Mon Sep 17 00:00:00 2001
From: fairywreath <nerradfour at gmail.com>
Date: Sat, 5 Jul 2025 00:25:21 -0400
Subject: [PATCH 3/4] Add availability generation and test
---
.../Dialect/SPIRV/IR/SPIRVArithmeticOps.td | 3 ++
mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt | 2 +-
...gerDotProductOps.cpp => DotProductOps.cpp} | 54 ++++++++++++++-----
mlir/test/Dialect/SPIRV/IR/availability.mlir | 14 +++++
4 files changed, 60 insertions(+), 13 deletions(-)
rename mlir/lib/Dialect/SPIRV/IR/{IntegerDotProductOps.cpp => DotProductOps.cpp} (83%)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index 33af979a45bc5..2260ee85493c7 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -455,6 +455,9 @@ def SPIRV_DotOp : SPIRV_Op<"Dot",
let assemblyFormat = "operands attr-dict `:` type($vector1) `->` type($result)";
+ // Require dynamic availability specification based on operand/result type.
+ bit autogenAvailability = 0;
+
let hasVerifier = 0;
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
index 1a8f30dd39871..b9aa7b7491abf 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
@@ -7,9 +7,9 @@ add_mlir_dialect_library(MLIRSPIRVDialect
CastOps.cpp
ControlFlowOps.cpp
CooperativeMatrixOps.cpp
+ DotProductOps.cpp
GroupOps.cpp
ImageOps.cpp
- IntegerDotProductOps.cpp
MemoryOps.cpp
MeshOps.cpp
SPIRVAttributes.cpp
diff --git a/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp b/mlir/lib/Dialect/SPIRV/IR/DotProductOps.cpp
similarity index 83%
rename from mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
rename to mlir/lib/Dialect/SPIRV/IR/DotProductOps.cpp
index f5676f36a0f5f..2f8da5f58a793 100644
--- a/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/DotProductOps.cpp
@@ -1,4 +1,4 @@
-//===- IntegerDotProductOps.cpp - MLIR SPIR-V Integer Dot Product Ops ----===//
+//===- DotProductOps.cpp - MLIR SPIR-V Dot Product Ops ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
//
-// Defines the Integer Dot Product operations in the SPIR-V dialect.
+// Defines the Dot Product operations in the SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
@@ -21,6 +21,44 @@ using namespace mlir::spirv::AttrNames;
namespace mlir::spirv {
+//===----------------------------------------------------------------------===//
+// Dot Product ops
+//===----------------------------------------------------------------------===//
+
+static std::optional<spirv::Version> getDotProductMinVersion() {
+ return spirv::Version::V_1_0; // Available in SPIR-V >= 1.0.
+}
+
+static std::optional<spirv::Version> getDotProductMaxVersion() {
+ return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6.
+}
+
+SmallVector<ArrayRef<spirv::Extension>, 1> DotOp::getExtensions() {
+ if (getResult().getType().isBF16()) {
+ static const auto extension = spirv::Extension::SPV_KHR_bfloat16;
+ return {extension};
+ }
+
+ return {};
+}
+
+SmallVector<ArrayRef<spirv::Capability>, 1> DotOp::getCapabilities() {
+ if (getResult().getType().isBF16()) {
+ static const auto capability = spirv::Capability::BFloat16DotProductKHR;
+ return {capability};
+ }
+
+ return {};
+}
+
+std::optional<spirv::Version> DotOp::getMinVersion() {
+ return getDotProductMinVersion();
+}
+
+std::optional<spirv::Version> DotOp::getMaxVersion() {
+ return getDotProductMaxVersion();
+}
+
//===----------------------------------------------------------------------===//
// Integer Dot Product ops
//===----------------------------------------------------------------------===//
@@ -71,14 +109,6 @@ static LogicalResult verifyIntegerDotProduct(Operation *op) {
return success();
}
-static std::optional<spirv::Version> getIntegerDotProductMinVersion() {
- return spirv::Version::V_1_0; // Available in SPIR-V >= 1.0.
-}
-
-static std::optional<spirv::Version> getIntegerDotProductMaxVersion() {
- return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6.
-}
-
static SmallVector<ArrayRef<spirv::Extension>, 1>
getIntegerDotProductExtensions() {
// Requires the SPV_KHR_integer_dot_product extension, specified either
@@ -136,10 +166,10 @@ getIntegerDotProductCapabilities(Operation *op) {
return getIntegerDotProductCapabilities<OpName>(*this); \
} \
std::optional<spirv::Version> OpName::getMinVersion() { \
- return getIntegerDotProductMinVersion(); \
+ return getDotProductMinVersion(); \
} \
std::optional<spirv::Version> OpName::getMaxVersion() { \
- return getIntegerDotProductMaxVersion(); \
+ return getDotProductMaxVersion(); \
}
SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(SDotOp)
diff --git a/mlir/test/Dialect/SPIRV/IR/availability.mlir b/mlir/test/Dialect/SPIRV/IR/availability.mlir
index 64ba8e3fc249e..9c8665b1e4bbe 100644
--- a/mlir/test/Dialect/SPIRV/IR/availability.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/availability.mlir
@@ -234,6 +234,20 @@ func.func @udot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
return %r: i64
}
+//===----------------------------------------------------------------------===//
+// Dot Product op with bfloat16
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: dot_vector_4xbf16_bf16
+func.func @dot_vector_4xbf16_bf16(%a: vector<4xbf16>, %b: vector<4xbf16>) -> bf16 {
+ // CHECK: min version: v1.0
+ // CHECK: max version: v1.6
+ // CHECK: extensions: [ [SPV_KHR_bfloat16] ]
+ // CHECK: capabilities: [ [BFloat16DotProductKHR] ]
+ %r = spirv.Dot %a, %a: vector<4xbf16> -> bf16
+ return %r: bf16
+}
+
//===----------------------------------------------------------------------===//
// Primitive ops
//===----------------------------------------------------------------------===//
>From 41b0195691e2cef0b0fb92a682cdbd19bc0645e8 Mon Sep 17 00:00:00 2001
From: fairywreath <nerradfour at gmail.com>
Date: Tue, 8 Jul 2025 22:45:10 -0400
Subject: [PATCH 4/4] Address review comments
---
mlir/lib/Dialect/SPIRV/IR/DotProductOps.cpp | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/SPIRV/IR/DotProductOps.cpp b/mlir/lib/Dialect/SPIRV/IR/DotProductOps.cpp
index 2f8da5f58a793..01ef1bdc42515 100644
--- a/mlir/lib/Dialect/SPIRV/IR/DotProductOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/DotProductOps.cpp
@@ -1,4 +1,4 @@
-//===- DotProductOps.cpp - MLIR SPIR-V Dot Product Ops ----===//
+//===- DotProductOps.cpp - MLIR SPIR-V Dot Product Ops -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -34,7 +34,7 @@ static std::optional<spirv::Version> getDotProductMaxVersion() {
}
SmallVector<ArrayRef<spirv::Extension>, 1> DotOp::getExtensions() {
- if (getResult().getType().isBF16()) {
+ if (isa<BFloat16Type>(getType())) {
static const auto extension = spirv::Extension::SPV_KHR_bfloat16;
return {extension};
}
@@ -43,7 +43,7 @@ SmallVector<ArrayRef<spirv::Extension>, 1> DotOp::getExtensions() {
}
SmallVector<ArrayRef<spirv::Capability>, 1> DotOp::getCapabilities() {
- if (getResult().getType().isBF16()) {
+ if (isa<BFloat16Type>(getType())) {
static const auto capability = spirv::Capability::BFloat16DotProductKHR;
return {capability};
}
More information about the Mlir-commits mailing list