[Mlir-commits] [mlir] a89021b - [mlir][spirv] Enable dot operation for bfloat16 (#145409)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jul 11 07:16:04 PDT 2025
Author: Darren Wihandi
Date: 2025-07-11T10:16:00-04:00
New Revision: a89021bc83705172b4a4cdac0a95ff50f4f868b1
URL: https://github.com/llvm/llvm-project/commit/a89021bc83705172b4a4cdac0a95ff50f4f868b1
DIFF: https://github.com/llvm/llvm-project/commit/a89021bc83705172b4a4cdac0a95ff50f4f868b1.diff
LOG: [mlir][spirv] Enable dot operation for bfloat16 (#145409)
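Allows the spirv.Dot operation to use vectors of bfloat16 (bf16) type. Using bf16 requires the SPV_KHR_bfloat16 extension and the BFloat16DotProductKHR capability.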
Added:
mlir/lib/Dialect/SPIRV/IR/DotProductOps.cpp
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
mlir/test/Dialect/SPIRV/IR/availability.mlir
mlir/test/Target/SPIRV/arithmetic-ops.mlir
Removed:
mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index 46a705eefc262..65771b602e0d0 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -462,16 +462,19 @@ def SPIRV_DotOp : SPIRV_Op<"Dot",
}];
let arguments = (ins
- SPIRV_VectorOf<SPIRV_Float>:$vector1,
- SPIRV_VectorOf<SPIRV_Float>:$vector2
+ SPIRV_VectorOf<SPIRV_AnyFloat>:$vector1,
+ SPIRV_VectorOf<SPIRV_AnyFloat>:$vector2
);
let results = (outs
- SPIRV_Float:$result
+ SPIRV_AnyFloat:$result
);
let assemblyFormat = "operands attr-dict `:` type($vector1) `->` type($result)";
+ // Require dynamic availability specification based on operand/result type.
+ bit autogenAvailability = 0;
+
let hasVerifier = 0;
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
index 1a8f30dd39871..b9aa7b7491abf 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
@@ -7,9 +7,9 @@ add_mlir_dialect_library(MLIRSPIRVDialect
CastOps.cpp
ControlFlowOps.cpp
CooperativeMatrixOps.cpp
+ DotProductOps.cpp
GroupOps.cpp
ImageOps.cpp
- IntegerDotProductOps.cpp
MemoryOps.cpp
MeshOps.cpp
SPIRVAttributes.cpp
diff --git a/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp b/mlir/lib/Dialect/SPIRV/IR/DotProductOps.cpp
similarity index 83%
rename from mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
rename to mlir/lib/Dialect/SPIRV/IR/DotProductOps.cpp
index f5676f36a0f5f..01ef1bdc42515 100644
--- a/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/DotProductOps.cpp
@@ -1,4 +1,4 @@
-//===- IntegerDotProductOps.cpp - MLIR SPIR-V Integer Dot Product Ops ----===//
+//===- DotProductOps.cpp - MLIR SPIR-V Dot Product Ops -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
//
-// Defines the Integer Dot Product operations in the SPIR-V dialect.
+// Defines the Dot Product operations in the SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
@@ -21,6 +21,44 @@ using namespace mlir::spirv::AttrNames;
namespace mlir::spirv {
+//===----------------------------------------------------------------------===//
+// Dot Product ops (spirv.Dot; version range shared with integer dot ops)
+//===----------------------------------------------------------------------===//
+
+static std::optional<spirv::Version> getDotProductMinVersion() {
+ return spirv::Version::V_1_0; // Available in SPIR-V >= 1.0.
+}
+
+static std::optional<spirv::Version> getDotProductMaxVersion() {
+ return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6.
+}
+
+SmallVector<ArrayRef<spirv::Extension>, 1> DotOp::getExtensions() { // Only bf16 needs an extension.
+ if (isa<BFloat16Type>(getType())) { // getType() is the op's single result type.
+ static const auto extension = spirv::Extension::SPV_KHR_bfloat16; // static: the ArrayRef returned below must not dangle.
+ return {extension};
+ }
+
+ return {}; // Non-bf16 float types need no extension.
+}
+
+SmallVector<ArrayRef<spirv::Capability>, 1> DotOp::getCapabilities() { // Only bf16 needs a capability.
+ if (isa<BFloat16Type>(getType())) {
+ static const auto capability = spirv::Capability::BFloat16DotProductKHR; // static: the returned ArrayRef must not dangle.
+ return {capability};
+ }
+
+ return {}; // Non-bf16 float types need no extra capability.
+}
+
+std::optional<spirv::Version> DotOp::getMinVersion() { // Same lower bound as the integer dot ops.
+ return getDotProductMinVersion();
+}
+
+std::optional<spirv::Version> DotOp::getMaxVersion() { // Same upper bound as the integer dot ops.
+ return getDotProductMaxVersion();
+}
+
//===----------------------------------------------------------------------===//
// Integer Dot Product ops
//===----------------------------------------------------------------------===//
@@ -71,14 +109,6 @@ static LogicalResult verifyIntegerDotProduct(Operation *op) {
return success();
}
-static std::optional<spirv::Version> getIntegerDotProductMinVersion() {
- return spirv::Version::V_1_0; // Available in SPIR-V >= 1.0.
-}
-
-static std::optional<spirv::Version> getIntegerDotProductMaxVersion() {
- return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6.
-}
-
static SmallVector<ArrayRef<spirv::Extension>, 1>
getIntegerDotProductExtensions() {
// Requires the SPV_KHR_integer_dot_product extension, specified either
@@ -136,10 +166,10 @@ getIntegerDotProductCapabilities(Operation *op) {
return getIntegerDotProductCapabilities<OpName>(*this); \
} \
std::optional<spirv::Version> OpName::getMinVersion() { \
- return getIntegerDotProductMinVersion(); \
+ return getDotProductMinVersion(); \
} \
std::optional<spirv::Version> OpName::getMaxVersion() { \
- return getIntegerDotProductMaxVersion(); \
+ return getDotProductMaxVersion(); \
}
SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(SDotOp)
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 99ab0e1dc4eef..27fd74e12d36e 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -967,6 +967,22 @@ func.func @reduction_minui(%v : vector<3xi32>, %s: i32) -> i32 {
// -----
+module attributes { spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [BFloat16DotProductKHR], [SPV_KHR_bfloat16]>, #spirv.resource_limits<>> } { // target env enables the bf16 dot capability + extension
+
+// CHECK-LABEL: func @reduction_bf16_addf_mulf
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xbf16>, %[[ARG1:.+]]: vector<4xbf16>)
+// CHECK: %[[DOT:.+]] = spirv.Dot %[[ARG0]], %[[ARG1]] : vector<4xbf16> -> bf16
+// CHECK: return %[[DOT]] : bf16
+func.func @reduction_bf16_addf_mulf(%arg0: vector<4xbf16>, %arg1: vector<4xbf16>) -> bf16 {
+ %mul = arith.mulf %arg0, %arg1 : vector<4xbf16>
+ %red = vector.reduction <add>, %mul : vector<4xbf16> into bf16 // add-reduction of a mulf: expected to lower to spirv.Dot
+ return %red : bf16
+}
+
+} // end module
+
+// -----
+
// CHECK-LABEL: @shape_cast_same_type
// CHECK-SAME: (%[[ARG0:.*]]: vector<2xf32>)
// CHECK: return %[[ARG0]]
diff --git a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
index d58c27598f2b8..3adafc15c79f6 100644
--- a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
@@ -321,6 +321,15 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
// -----
+// CHECK-LABEL: @dot_bf16
+func.func @dot_bf16(%arg0: vector<4xbf16>, %arg1: vector<4xbf16>) -> bf16 { // spirv.Dot accepts bf16 vector operands
+ // CHECK: spirv.Dot %{{.+}}, %{{.+}} : vector<4xbf16> -> bf16
+ %0 = spirv.Dot %arg0, %arg1 : vector<4xbf16> -> bf16
+ return %0 : bf16
+}
+
+// -----
+
// expected-note @+1 {{prior use here}}
func.func @dot(%arg0: vector<4xf32>, %arg1: vector<3xf32>) -> f32 {
// expected-error @+1 {{use of value '%arg1' expects different type than prior uses}}
@@ -339,7 +348,7 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f16 {
// -----
func.func @dot(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> i32 {
- // expected-error @+1 {{'spirv.Dot' op operand #0 must be vector of 16/32/64-bit float values of length 2/3/4/8/16}}
+ // expected-error @+1 {{'spirv.Dot' op operand #0 must be vector of 16/32/64-bit float or BFloat16 values of length 2/3/4/8/16}}
%0 = spirv.Dot %arg0, %arg1 : vector<4xi32> -> i32
return %0 : i32
}
diff --git a/mlir/test/Dialect/SPIRV/IR/availability.mlir b/mlir/test/Dialect/SPIRV/IR/availability.mlir
index 64ba8e3fc249e..9c8665b1e4bbe 100644
--- a/mlir/test/Dialect/SPIRV/IR/availability.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/availability.mlir
@@ -234,6 +234,20 @@ func.func @udot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
return %r: i64
}
+//===----------------------------------------------------------------------===//
+// Dot Product op with bfloat16
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: dot_vector_4xbf16_bf16
+func.func @dot_vector_4xbf16_bf16(%a: vector<4xbf16>, %b: vector<4xbf16>) -> bf16 { // availability is driven by the bf16 type
+ // CHECK: min version: v1.0
+ // CHECK: max version: v1.6
+ // CHECK: extensions: [ [SPV_KHR_bfloat16] ]
+ // CHECK: capabilities: [ [BFloat16DotProductKHR] ]
+ %r = spirv.Dot %a, %b: vector<4xbf16> -> bf16 // use both declared operands
+ return %r: bf16
+}
+
//===----------------------------------------------------------------------===//
// Primitive ops
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/arithmetic-ops.mlir b/mlir/test/Target/SPIRV/arithmetic-ops.mlir
index b1ea13c6854fd..b80e17f979daa 100644
--- a/mlir/test/Target/SPIRV/arithmetic-ops.mlir
+++ b/mlir/test/Target/SPIRV/arithmetic-ops.mlir
@@ -86,4 +86,9 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
%0 = spirv.VectorTimesScalar %arg0, %arg1 : (vector<4xf32>, f32) -> vector<4xf32>
spirv.Return
}
+ spirv.func @dot_bf16(%arg0: vector<4xbf16>, %arg1: vector<4xbf16>) "None" { // NOTE(review): enclosing module vce appears to list only [Shader]; bf16 presumably also needs BFloat16 caps + SPV_KHR_bfloat16 for valid SPIR-V — confirm
+ // CHECK: spirv.Dot %{{.+}}, %{{.+}} : vector<4xbf16> -> bf16
+ %0 = spirv.Dot %arg0, %arg1 : vector<4xbf16> -> bf16
+ spirv.Return
+ }
}
More information about the Mlir-commits
mailing list