[Mlir-commits] [mlir] [mlir][spirv] Handle non-innerprod float vector add reductions (PR #73476)
Jakub Kuderski
llvmlistbot at llvm.org
Mon Nov 27 08:06:17 PST 2023
https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/73476
>From 49041e079eadcff4c8de7cebb490f7492cd72f4d Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sun, 26 Nov 2023 23:08:19 -0500
Subject: [PATCH 1/5] [mlir][spirv] Handle non-innerprod float vector add
reductions
Instead of extracting all individual vector components and performing a
scalar summation, use `spirv.Dot` with the original reduction operand and
a vector constant of all ones.
---
.../VectorToSPIRV/VectorToSPIRV.cpp | 32 +++++++++++---
.../VectorToSPIRV/vector-to-spirv.mlir | 44 +++++++++++++++++--
2 files changed, 66 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index ade41b0372c82f1..1db6713d8b85694 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -18,6 +18,7 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
@@ -755,14 +756,33 @@ struct VectorReductionToFPDotProd final
if (!resultType)
return rewriter.notifyMatchFailure(op, "result is not a float");
- auto mul = adaptor.getVector().getDefiningOp<arith::MulFOp>();
- if (!mul)
- return rewriter.notifyMatchFailure(
- op, "reduction operand is not 'arith.mulf'");
+ auto vectorType = dyn_cast<VectorType>(adaptor.getVector().getType());
+ if (!vectorType) {
+ assert(isa<FloatType>(adaptor.getVector().getType()) &&
+ "Expected the vector to be scalarized");
+ rewriter.replaceOp(op, adaptor.getVector());
+ return success();
+ }
Location loc = op.getLoc();
- Value res = rewriter.create<spirv::DotOp>(loc, resultType, mul.getLhs(),
- mul.getRhs());
+ Value lhs;
+ Value rhs;
+ if (auto mul = adaptor.getVector().getDefiningOp<arith::MulFOp>()) {
+ lhs = mul.getLhs();
+ rhs = mul.getRhs();
+ } else {
+ // If the operand is not a mul, use a vector of ones for the dot operand
+ // to just sum up all values.
+ lhs = adaptor.getVector();
+ Attribute oneAttr =
+ rewriter.getFloatAttr(vectorType.getElementType(), 1.0);
+ oneAttr = SplatElementsAttr::get(vectorType, oneAttr);
+ rhs = rewriter.create<spirv::ConstantOp>(loc, vectorType, oneAttr);
+ }
+ assert(lhs);
+ assert(rhs);
+
+ Value res = rewriter.create<spirv::DotOp>(loc, resultType, lhs, rhs);
if (op.getAcc())
res = rewriter.create<spirv::FAddOp>(loc, adaptor.getAcc(), res);
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index d8585d59770bfdc..022bc0114bc523b 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -500,11 +500,11 @@ func.func @reduction_add(%v : vector<4xi32>) -> i32 {
// -----
-// CHECK-LABEL: func @reduction_addf
+// CHECK-LABEL: func @reduction_addf_mulf
// CHECK-SAME: (%[[ARG0:.+]]: vector<4xf32>, %[[ARG1:.+]]: vector<4xf32>)
// CHECK: %[[DOT:.+]] = spirv.Dot %[[ARG0]], %[[ARG1]] : vector<4xf32> -> f32
// CHECK: return %[[DOT]] : f32
-func.func @reduction_addf(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
+func.func @reduction_addf_mulf(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
%mul = arith.mulf %arg0, %arg1 : vector<4xf32>
%red = vector.reduction <add>, %mul : vector<4xf32> into f32
return %red : f32
@@ -512,12 +512,12 @@ func.func @reduction_addf(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
// -----
-// CHECK-LABEL: func @reduction_addf_acc
+// CHECK-LABEL: func @reduction_addf_acc_mulf
// CHECK-SAME: (%[[ARG0:.+]]: vector<4xf32>, %[[ARG1:.+]]: vector<4xf32>, %[[ACC:.+]]: f32)
// CHECK: %[[DOT:.+]] = spirv.Dot %[[ARG0]], %[[ARG1]] : vector<4xf32> -> f32
// CHECK: %[[RES:.+]] = spirv.FAdd %[[ACC]], %[[DOT]] : f32
// CHECK: return %[[RES]] : f32
-func.func @reduction_addf_acc(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %acc: f32) -> f32 {
+func.func @reduction_addf_acc_mulf(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %acc: f32) -> f32 {
%mul = arith.mulf %arg0, %arg1 : vector<4xf32>
%red = vector.reduction <add>, %mul, %acc : vector<4xf32> into f32
return %red : f32
@@ -525,6 +525,42 @@ func.func @reduction_addf_acc(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %acc:
// -----
+// CHECK-LABEL: func @reduction_addf
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xf32>)
+// CHECK: %[[ONE:.+]] = spirv.Constant dense<1.0{{.+}}> : vector<4xf32>
+// CHECK: %[[DOT:.+]] = spirv.Dot %[[ARG0]], %[[ONE]] : vector<4xf32> -> f32
+// CHECK: return %[[DOT]] : f32
+func.func @reduction_addf(%arg0: vector<4xf32>) -> f32 {
+ %red = vector.reduction <add>, %arg0 : vector<4xf32> into f32
+ return %red : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_addf_acc
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xf32>, %[[ACC:.+]]: f32)
+// CHECK: %[[ONE:.+]] = spirv.Constant dense<1.0{{.*}}> : vector<4xf32>
+// CHECK: %[[DOT:.+]] = spirv.Dot %[[ARG0]], %[[ONE]] : vector<4xf32> -> f32
+// CHECK: %[[RES:.+]] = spirv.FAdd %[[ACC]], %[[DOT]] : f32
+// CHECK: return %[[RES]] : f32
+func.func @reduction_addf_acc(%arg0: vector<4xf32>, %acc: f32) -> f32 {
+ %red = vector.reduction <add>, %arg0, %acc : vector<4xf32> into f32
+ return %red : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_addf_one_elem
+// CHECK-SAME: (%[[ARG0:.+]]: vector<1xf32>)
+// CHECK: %[[RES:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<1xf32> to f32
+// CHECK: return %[[RES]] : f32
+func.func @reduction_addf_one_elem(%arg0: vector<1xf32>) -> f32 {
+ %red = vector.reduction <add>, %arg0 : vector<1xf32> into f32
+ return %red : f32
+}
+
+// -----
+
// CHECK-LABEL: func @reduction_mul
// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
>From 07fa057b14ef502c19f5a1d33f92e37fc5f1430c Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sun, 26 Nov 2023 23:14:00 -0500
Subject: [PATCH 2/5] Fix typo
---
mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 022bc0114bc523b..e13e4356be6dd04 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -552,7 +552,7 @@ func.func @reduction_addf_acc(%arg0: vector<4xf32>, %acc: f32) -> f32 {
// CHECK-LABEL: func @reduction_addf_one_elem
// CHECK-SAME: (%[[ARG0:.+]]: vector<1xf32>)
-// CHECK: %[[RES:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<1xf32> to f32
+// CHECK: %[[RES:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<1xf32> to f32
// CHECK: return %[[RES]] : f32
func.func @reduction_addf_one_elem(%arg0: vector<1xf32>) -> f32 {
%red = vector.reduction <add>, %arg0 : vector<1xf32> into f32
>From dce4cdba703d660172acb80ac76d75dff9f2b3fd Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sun, 26 Nov 2023 23:28:24 -0500
Subject: [PATCH 3/5] Handle one element and accumulator
---
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 6 ++++++
.../Conversion/VectorToSPIRV/vector-to-spirv.mlir | 12 ++++++++++++
2 files changed, 18 insertions(+)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 1db6713d8b85694..86e80a49fe0d74b 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -760,6 +760,12 @@ struct VectorReductionToFPDotProd final
if (!vectorType) {
assert(isa<FloatType>(adaptor.getVector().getType()) &&
"Expected the vector to be scalarized");
+ if (op.getAcc()) {
+ rewriter.replaceOpWithNewOp<spirv::FAddOp>(op, adaptor.getAcc(),
+ adaptor.getVector());
+ return success();
+ }
+
rewriter.replaceOp(op, adaptor.getVector());
return success();
}
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index e13e4356be6dd04..c9984091d5acc6a 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -561,6 +561,18 @@ func.func @reduction_addf_one_elem(%arg0: vector<1xf32>) -> f32 {
// -----
+// CHECK-LABEL: func @reduction_addf_one_elem_acc
+// CHECK-SAME: (%[[ARG0:.+]]: vector<1xf32>, %[[ACC:.+]]: f32)
+// CHECK: %[[RHS:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<1xf32> to f32
+// CHECK: %[[RES:.+]] = spirv.FAdd %[[ACC]], %[[RHS]] : f32
+// CHECK: return %[[RES]] : f32
+func.func @reduction_addf_one_elem_acc(%arg0: vector<1xf32>, %acc: f32) -> f32 {
+ %red = vector.reduction <add>, %arg0, %acc : vector<1xf32> into f32
+ return %red : f32
+}
+
+// -----
+
// CHECK-LABEL: func @reduction_mul
// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
>From 882aa1704fa43f942ecce6f72e5521cce66d6be8 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Mon, 27 Nov 2023 10:53:50 -0500
Subject: [PATCH 4/5] Address comments
---
.../Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 18 ++++++++++--------
1 file changed, 10 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 86e80a49fe0d74b..df1b4334c1f2e6f 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -756,30 +756,32 @@ struct VectorReductionToFPDotProd final
if (!resultType)
return rewriter.notifyMatchFailure(op, "result is not a float");
- auto vectorType = dyn_cast<VectorType>(adaptor.getVector().getType());
+ Value vec = adaptor.getVector();
+ Value acc = adaptor.getAcc();
+
+ auto vectorType = dyn_cast<VectorType>(vec.getType());
if (!vectorType) {
- assert(isa<FloatType>(adaptor.getVector().getType()) &&
+ assert(isa<FloatType>(vec.getType()) &&
"Expected the vector to be scalarized");
if (op.getAcc()) {
- rewriter.replaceOpWithNewOp<spirv::FAddOp>(op, adaptor.getAcc(),
- adaptor.getVector());
+ rewriter.replaceOpWithNewOp<spirv::FAddOp>(op, acc, vec);
return success();
}
- rewriter.replaceOp(op, adaptor.getVector());
+ rewriter.replaceOp(op, vec);
return success();
}
Location loc = op.getLoc();
Value lhs;
Value rhs;
- if (auto mul = adaptor.getVector().getDefiningOp<arith::MulFOp>()) {
+ if (auto mul = vec.getDefiningOp<arith::MulFOp>()) {
lhs = mul.getLhs();
rhs = mul.getRhs();
} else {
// If the operand is not a mul, use a vector of ones for the dot operand
// to just sum up all values.
- lhs = adaptor.getVector();
+ lhs = vec;
Attribute oneAttr =
rewriter.getFloatAttr(vectorType.getElementType(), 1.0);
oneAttr = SplatElementsAttr::get(vectorType, oneAttr);
@@ -790,7 +792,7 @@ struct VectorReductionToFPDotProd final
Value res = rewriter.create<spirv::DotOp>(loc, resultType, lhs, rhs);
if (op.getAcc())
- res = rewriter.create<spirv::FAddOp>(loc, adaptor.getAcc(), res);
+ res = rewriter.create<spirv::FAddOp>(loc, acc, res);
rewriter.replaceOp(op, res);
return success();
>From ac69293644c018a592321f6ca15d6a4801054822 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Mon, 27 Nov 2023 11:06:03 -0500
Subject: [PATCH 5/5] Simplify
---
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index df1b4334c1f2e6f..e48f29a4f170290 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -763,7 +763,7 @@ struct VectorReductionToFPDotProd final
if (!vectorType) {
assert(isa<FloatType>(vec.getType()) &&
"Expected the vector to be scalarized");
- if (op.getAcc()) {
+ if (acc) {
rewriter.replaceOpWithNewOp<spirv::FAddOp>(op, acc, vec);
return success();
}
@@ -791,7 +791,7 @@ struct VectorReductionToFPDotProd final
assert(rhs);
Value res = rewriter.create<spirv::DotOp>(loc, resultType, lhs, rhs);
- if (op.getAcc())
+ if (acc)
res = rewriter.create<spirv::FAddOp>(loc, acc, res);
rewriter.replaceOp(op, res);
More information about the Mlir-commits
mailing list