[Mlir-commits] [mlir] a89021b - [mlir][spirv] Enable dot operation for bfloat16 (#145409)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jul 11 07:16:04 PDT 2025


Author: Darren Wihandi
Date: 2025-07-11T10:16:00-04:00
New Revision: a89021bc83705172b4a4cdac0a95ff50f4f868b1

URL: https://github.com/llvm/llvm-project/commit/a89021bc83705172b4a4cdac0a95ff50f4f868b1
DIFF: https://github.com/llvm/llvm-project/commit/a89021bc83705172b4a4cdac0a95ff50f4f868b1.diff

LOG: [mlir][spirv] Enable dot operation for bfloat16 (#145409)

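Allows the `spirv.Dot` operation to take vectors of bfloat16 elements and produce a bf16 result; this form requires the SPV_KHR_bfloat16 extension and the BFloat16DotProductKHR capability.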

Added: 
    mlir/lib/Dialect/SPIRV/IR/DotProductOps.cpp

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
    mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
    mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
    mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
    mlir/test/Dialect/SPIRV/IR/availability.mlir
    mlir/test/Target/SPIRV/arithmetic-ops.mlir

Removed: 
    mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp


################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index 46a705eefc262..65771b602e0d0 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -462,16 +462,19 @@ def SPIRV_DotOp : SPIRV_Op<"Dot",
   }];
 
   let arguments = (ins
-    SPIRV_VectorOf<SPIRV_Float>:$vector1,
-    SPIRV_VectorOf<SPIRV_Float>:$vector2
+    SPIRV_VectorOf<SPIRV_AnyFloat>:$vector1,
+    SPIRV_VectorOf<SPIRV_AnyFloat>:$vector2
   );
 
   let results = (outs
-    SPIRV_Float:$result
+    SPIRV_AnyFloat:$result
   );
 
   let assemblyFormat = "operands attr-dict `:` type($vector1) `->` type($result)";
 
+  // Require dynamic availability specification based on operand/result type.
+  bit autogenAvailability = 0;
+
   let hasVerifier = 0;
 }
 

diff --git a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
index 1a8f30dd39871..b9aa7b7491abf 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
@@ -7,9 +7,9 @@ add_mlir_dialect_library(MLIRSPIRVDialect
   CastOps.cpp
   ControlFlowOps.cpp
   CooperativeMatrixOps.cpp
+  DotProductOps.cpp
   GroupOps.cpp
   ImageOps.cpp
-  IntegerDotProductOps.cpp
   MemoryOps.cpp
   MeshOps.cpp
   SPIRVAttributes.cpp

diff --git a/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp b/mlir/lib/Dialect/SPIRV/IR/DotProductOps.cpp
similarity index 83%
rename from mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
rename to mlir/lib/Dialect/SPIRV/IR/DotProductOps.cpp
index f5676f36a0f5f..01ef1bdc42515 100644
--- a/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/DotProductOps.cpp
@@ -1,4 +1,4 @@
-//===- IntegerDotProductOps.cpp - MLIR SPIR-V Integer Dot Product Ops  ----===//
+//===- DotProductOps.cpp - MLIR SPIR-V Dot Product Ops  -------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// Defines the Integer Dot Product operations in the SPIR-V dialect.
+// Defines the Dot Product operations in the SPIR-V dialect.
 //
 //===----------------------------------------------------------------------===//
 
@@ -21,6 +21,44 @@ using namespace mlir::spirv::AttrNames;
 
 namespace mlir::spirv {
 
+//===----------------------------------------------------------------------===//
+// Dot Product ops
+//===----------------------------------------------------------------------===//
+
+static std::optional<spirv::Version> getDotProductMinVersion() {
+  return spirv::Version::V_1_0; // Available in SPIR-V >= 1.0.
+}
+
+static std::optional<spirv::Version> getDotProductMaxVersion() {
+  return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6.
+}
+
+SmallVector<ArrayRef<spirv::Extension>, 1> DotOp::getExtensions() {
+  if (isa<BFloat16Type>(getType())) {
+    static const auto extension = spirv::Extension::SPV_KHR_bfloat16;
+    return {extension};
+  }
+
+  return {};
+}
+
+SmallVector<ArrayRef<spirv::Capability>, 1> DotOp::getCapabilities() {
+  if (isa<BFloat16Type>(getType())) {
+    static const auto capability = spirv::Capability::BFloat16DotProductKHR;
+    return {capability};
+  }
+
+  return {};
+}
+
+std::optional<spirv::Version> DotOp::getMinVersion() {
+  return getDotProductMinVersion();
+}
+
+std::optional<spirv::Version> DotOp::getMaxVersion() {
+  return getDotProductMaxVersion();
+}
+
 //===----------------------------------------------------------------------===//
 // Integer Dot Product ops
 //===----------------------------------------------------------------------===//
@@ -71,14 +109,6 @@ static LogicalResult verifyIntegerDotProduct(Operation *op) {
   return success();
 }
 
-static std::optional<spirv::Version> getIntegerDotProductMinVersion() {
-  return spirv::Version::V_1_0; // Available in SPIR-V >= 1.0.
-}
-
-static std::optional<spirv::Version> getIntegerDotProductMaxVersion() {
-  return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6.
-}
-
 static SmallVector<ArrayRef<spirv::Extension>, 1>
 getIntegerDotProductExtensions() {
   // Requires the SPV_KHR_integer_dot_product extension, specified either
@@ -136,10 +166,10 @@ getIntegerDotProductCapabilities(Operation *op) {
     return getIntegerDotProductCapabilities<OpName>(*this);                    \
   }                                                                            \
   std::optional<spirv::Version> OpName::getMinVersion() {                      \
-    return getIntegerDotProductMinVersion();                                   \
+    return getDotProductMinVersion();                                          \
   }                                                                            \
   std::optional<spirv::Version> OpName::getMaxVersion() {                      \
-    return getIntegerDotProductMaxVersion();                                   \
+    return getDotProductMaxVersion();                                          \
   }
 
 SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(SDotOp)

diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 99ab0e1dc4eef..27fd74e12d36e 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -967,6 +967,22 @@ func.func @reduction_minui(%v : vector<3xi32>, %s: i32) -> i32 {
 
 // -----
 
+module attributes { spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [BFloat16DotProductKHR], [SPV_KHR_bfloat16]>, #spirv.resource_limits<>> } {
+
+// CHECK-LABEL: func @reduction_bf16_addf_mulf
+//  CHECK-SAME:  (%[[ARG0:.+]]: vector<4xbf16>, %[[ARG1:.+]]: vector<4xbf16>)
+//  CHECK:       %[[DOT:.+]] = spirv.Dot %[[ARG0]], %[[ARG1]] : vector<4xbf16> -> bf16
+//  CHECK:       return %[[DOT]] : bf16
+func.func @reduction_bf16_addf_mulf(%arg0: vector<4xbf16>, %arg1: vector<4xbf16>) -> bf16 {
+  %mul = arith.mulf %arg0, %arg1 : vector<4xbf16>
+  %red = vector.reduction <add>, %mul : vector<4xbf16> into bf16
+  return %red : bf16
+}
+
+} // end module
+
+// -----
+
 // CHECK-LABEL: @shape_cast_same_type
 //  CHECK-SAME: (%[[ARG0:.*]]: vector<2xf32>)
 //       CHECK:   return %[[ARG0]]

diff --git a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
index d58c27598f2b8..3adafc15c79f6 100644
--- a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
@@ -321,6 +321,15 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
 
 // -----
 
+// CHECK-LABEL: @dot_bf16
+func.func @dot_bf16(%arg0: vector<4xbf16>, %arg1: vector<4xbf16>) -> bf16 {
+  // CHECK: spirv.Dot %{{.+}}, %{{.+}} : vector<4xbf16> -> bf16
+  %0 = spirv.Dot %arg0, %arg1 : vector<4xbf16> -> bf16
+  return %0 : bf16
+}
+
+// -----
+
 // expected-note @+1 {{prior use here}}
 func.func @dot(%arg0: vector<4xf32>, %arg1: vector<3xf32>) -> f32 {
   // expected-error @+1 {{use of value '%arg1' expects different type than prior uses}}
@@ -339,7 +348,7 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f16 {
 // -----
 
 func.func @dot(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> i32 {
-  // expected-error @+1 {{'spirv.Dot' op operand #0 must be vector of 16/32/64-bit float values of length 2/3/4/8/16}}
+  // expected-error @+1 {{'spirv.Dot' op operand #0 must be vector of 16/32/64-bit float or BFloat16 values of length 2/3/4/8/16}}
   %0 = spirv.Dot %arg0, %arg1 : vector<4xi32> -> i32
   return %0 : i32
 }

diff --git a/mlir/test/Dialect/SPIRV/IR/availability.mlir b/mlir/test/Dialect/SPIRV/IR/availability.mlir
index 64ba8e3fc249e..9c8665b1e4bbe 100644
--- a/mlir/test/Dialect/SPIRV/IR/availability.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/availability.mlir
@@ -234,6 +234,20 @@ func.func @udot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
   return %r: i64
 }
 
+//===----------------------------------------------------------------------===//
+// Dot Product op with bfloat16
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: dot_vector_4xbf16_bf16
+func.func @dot_vector_4xbf16_bf16(%a: vector<4xbf16>, %b: vector<4xbf16>) -> bf16 {
+  // CHECK: min version: v1.0
+  // CHECK: max version: v1.6
+  // CHECK: extensions: [ [SPV_KHR_bfloat16] ]
+  // CHECK: capabilities: [ [BFloat16DotProductKHR] ]
+  %r = spirv.Dot %a, %a: vector<4xbf16> -> bf16
+  return %r: bf16
+}
+
 //===----------------------------------------------------------------------===//
 // Primitive ops
 //===----------------------------------------------------------------------===//

diff --git a/mlir/test/Target/SPIRV/arithmetic-ops.mlir b/mlir/test/Target/SPIRV/arithmetic-ops.mlir
index b1ea13c6854fd..b80e17f979daa 100644
--- a/mlir/test/Target/SPIRV/arithmetic-ops.mlir
+++ b/mlir/test/Target/SPIRV/arithmetic-ops.mlir
@@ -86,4 +86,9 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
     %0 = spirv.VectorTimesScalar %arg0, %arg1 : (vector<4xf32>, f32) -> vector<4xf32>
     spirv.Return
   }
+  spirv.func @dot_bf16(%arg0: vector<4xbf16>, %arg1: vector<4xbf16>) "None" {
+    // CHECK: spirv.Dot %{{.+}}, %{{.+}} : vector<4xbf16> -> bf16
+    %0 = spirv.Dot %arg0, %arg1 : vector<4xbf16> -> bf16
+    spirv.Return
+  }
 }


        


More information about the Mlir-commits mailing list