[llvm] 0e8717f - [Matrix] Add shape verification.
Florian Hahn via llvm-commits
llvm-commits at lists.llvm.org
Sat May 13 01:42:26 PDT 2023
Author: Florian Hahn
Date: 2023-05-13T09:41:27+01:00
New Revision: 0e8717f71198365fd4e13a01d62a44e22ea6526c
URL: https://github.com/llvm/llvm-project/commit/0e8717f71198365fd4e13a01d62a44e22ea6526c
DIFF: https://github.com/llvm/llvm-project/commit/0e8717f71198365fd4e13a01d62a44e22ea6526c.diff
LOG: [Matrix] Add shape verification.
At the moment, lower-matrix-intrinsics accepts mismatched shapes between
an operation and its operands. See shape-verification.ll for an example
where @llvm.matrix.column.major.load specifies a 6x1 shape and then its
use (@llvm.matrix.multiply) specifies the same operand as 1x6.
This patch adds optional shape verification (-verify-matrix-shapes, off
by default) that reports conflicting shapes for a value and aborts
compilation.
Reviewed By: thegameg
Differential Revision: https://reviews.llvm.org/D147438
Added:
llvm/test/Transforms/LowerMatrixIntrinsics/shape-verification.ll
Modified:
llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index f9a149f3616aa..594556a0b13df 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -72,6 +72,11 @@ static cl::opt<bool> AllowContractEnabled(
cl::desc("Allow the use of FMAs if available and profitable. This may "
"result in different results, due to less rounding error."));
+static cl::opt<bool> // Off by default; when set, conflicting shapes are fatal.
+ VerifyShapeInfo("verify-matrix-shapes", cl::Hidden,
+ cl::desc("Enable/disable matrix shape verification."),
+ cl::init(false));
+
enum class MatrixLayoutTy { ColumnMajor, RowMajor };
static cl::opt<MatrixLayoutTy> MatrixLayout(
@@ -535,6 +540,15 @@ class LowerMatrixIntrinsics {
auto SIter = ShapeMap.find(V);
if (SIter != ShapeMap.end()) {
+ if (VerifyShapeInfo && (SIter->second.NumRows != Shape.NumRows ||
+ SIter->second.NumColumns != Shape.NumColumns)) {
+ errs() << "Conflicting shapes (" << SIter->second.NumRows << "x"
+ << SIter->second.NumColumns << " vs " << Shape.NumRows << "x"
+ << Shape.NumColumns << ") for " << *V << "\n";
+ report_fatal_error(
+ "Matrix shape verification failed, compilation aborted!");
+ }
+
LLVM_DEBUG(dbgs() << " not overriding existing shape: "
<< SIter->second.NumRows << " "
<< SIter->second.NumColumns << " for " << *V << "\n");
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/shape-verification.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/shape-verification.ll
new file mode 100644
index 0000000000000..999fb62e59033
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/shape-verification.ll
@@ -0,0 +1,16 @@
+; RUN: not --crash opt -passes='lower-matrix-intrinsics' -verify-matrix-shapes=true -S %s 2>&1 | FileCheck --check-prefix=VERIFY %s
+; RUN: opt -passes='lower-matrix-intrinsics' -verify-matrix-shapes=false -S %s 2>&1 | FileCheck --check-prefix=NOVERIFY %s
+
+; VERIFY: Conflicting shapes (6x1 vs 1x6)
+; NOVERIFY-NOT: Conflicting shapes
+
+define <1 x float> @intrinsic_column_major_load_dot_product_float_v6(ptr %lhs_address, ptr %rhs_address) { ; operand shapes deliberately disagree
+entry:
+ %lhs = tail call fast <6 x float> @llvm.matrix.column.major.load.v6f32.i64(ptr nonnull align 4 %lhs_address, i64 6, i1 false, i32 6, i32 1) ; loads %lhs as 6x1
+ %rhs = tail call fast <6 x float> @llvm.matrix.column.major.load.v6f32.i64(ptr nonnull align 4 %rhs_address, i64 1, i1 false, i32 1, i32 6) ; loads %rhs as 1x6
+ %result = tail call fast <1 x float> @llvm.matrix.multiply.v1f32.v6f32.v6f32(<6 x float> %lhs, <6 x float> %rhs, i32 1, i32 6, i32 1) ; uses %lhs as 1x6 and %rhs as 6x1, conflicting with the loads
+ ret <1 x float> %result
+}
+
+declare <6 x float> @llvm.matrix.column.major.load.v6f32.i64(ptr nonnull align 4, i64, i1, i32, i32)
+declare <1 x float> @llvm.matrix.multiply.v1f32.v6f32.v6f32(<6 x float>, <6 x float>, i32, i32, i32)
More information about the llvm-commits mailing list