[Mlir-commits] [mlir] [mlir] Add ContractionOpInterface utility functions for vector matrix multiplication (PR #68945)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Oct 16 13:04:03 PDT 2023
https://github.com/NatashaKnk updated https://github.com/llvm/llvm-project/pull/68945
>From e431cc466e5609e19757481604dad47558219166 Mon Sep 17 00:00:00 2001
From: Natasha Kononenko <natashaknk at google.com>
Date: Fri, 13 Oct 2023 00:48:17 +0000
Subject: [PATCH 1/3] Add ContractionOpInterface utility functions for vector
matrix multiplication
---
.../Dialect/Linalg/IR/LinalgInterfaces.td | 33 +++++
.../mlir/Dialect/Utils/StructuredOpsUtils.h | 18 +++
mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp | 73 ++++++++++
.../Dialect/Utils/StructuredOpsUtilsTest.cpp | 130 ++++++++++++++++++
4 files changed, 254 insertions(+)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 9ca029b489ad144..e8e9b273cbcf234 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -86,6 +86,39 @@ def LinalgContractionOpInterface : OpInterface<"ContractionOpInterface"> {
/*methodBody=*/[{
return mlir::isRowMajorBatchMatmul($_op.getIndexingMaps());
}]>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns whether the given op has indexing maps that correspond to a
+ vector-matrix multiplication.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"isVecMat",
+ /*args=*/(ins),
+ /*methodBody=*/[{
+ return mlir::isVecMat($_op.getIndexingMaps());
+ }]>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns whether the given op has indexing maps that correspond to a
+ matrix-vector multiplication.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"isMatVec",
+ /*args=*/(ins),
+ /*methodBody=*/[{
+ return mlir::isMatVec($_op.getIndexingMaps());
+ }]>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns whether the given op has indexing maps that correspond to a
+ batched matrix-vector multiplication.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"isBatchMatVec",
+ /*args=*/(ins),
+ /*methodBody=*/[{
+ return mlir::isBatchMatVec($_op.getIndexingMaps());
+ }]>,
];
}
diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
index dab24bd93032692..7bdf260e83ce0be 100644
--- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
@@ -49,6 +49,24 @@ bool isColumnMajorMatmul(ArrayAttr indexingMaps);
/// the reduction.
bool isRowMajorBatchMatmul(ArrayAttr indexingMaps);
+/// Tests whether the given maps describe a vector matrix multiplication. The
+/// test is permutation-invariant. Note that this only checks the affine maps
+/// from an operation, so does not perform any checks on the math being
+/// performed within the reduction.
+bool isVecMat(ArrayAttr indexingMaps);
+
+/// Tests whether the given maps describe a matrix vector multiplication. The
+/// test is permutation-invariant. Note that this only checks the affine maps
+/// from an operation, so does not perform any checks on the math being
+/// performed within the reduction.
+bool isMatVec(ArrayAttr indexingMaps);
+
+/// Tests whether the given maps describe a batch matrix vector multiplication.
+/// The test is permutation-invariant. Note that this only checks the affine
+/// maps from an operation, so does not perform any checks on the math being
+/// performed within the reduction.
+bool isBatchMatVec(ArrayAttr indexingMaps);
+
/// Return positions in `iteratorTypes` that match `iteratorTypeName`.
inline void findPositionsOfType(ArrayRef<utils::IteratorType> iteratorTypes,
utils::IteratorType iteratorTypeName,
diff --git a/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp b/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
index a2977901f4751d4..e56be5a062ec6b7 100644
--- a/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
@@ -96,6 +96,79 @@ bool mlir::isRowMajorBatchMatmul(ArrayAttr indexingMaps) {
return indexingMaps == maps;
}
+bool mlir::isVecMat(ArrayAttr indexingMaps) {
+ if (indexingMaps.size() != 3)
+ return false;
+ auto map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
+ auto map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
+ auto map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
+
+ if (map0.getNumResults() != 1 || map1.getNumResults() != 2 ||
+ map2.getNumResults() != 1 || map0.getNumInputs() != 2 ||
+ map1.getNumInputs() != 2 || map2.getNumInputs() != 2) {
+ return false;
+ }
+
+ // Extract dimensions for K * KxN -> N
+ AffineExpr k = map0.getResult(0);
+ AffineExpr n = map2.getResult(0);
+ auto *context = indexingMaps.getContext();
+ auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {k}, context));
+ auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k, n}, context));
+ auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, context));
+ auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
+ return indexingMaps == maps;
+}
+
+bool mlir::isMatVec(ArrayAttr indexingMaps) {
+ if (indexingMaps.size() != 3)
+ return false;
+ auto map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
+ auto map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
+ auto map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
+
+ if (map0.getNumResults() != 2 || map1.getNumResults() != 1 ||
+ map2.getNumResults() != 1 || map0.getNumInputs() != 2 ||
+ map1.getNumInputs() != 2 || map2.getNumInputs() != 2) {
+ return false;
+ }
+
+ // Extract dimensions for N*K * K -> N
+ AffineExpr k = map1.getResult(0);
+ AffineExpr n = map2.getResult(0);
+ auto *context = indexingMaps.getContext();
+ auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {n, k}, context));
+ auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k}, context));
+ auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, context));
+ auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
+ return indexingMaps == maps;
+}
+
+bool mlir::isBatchMatVec(ArrayAttr indexingMaps) {
+ if (indexingMaps.size() != 3)
+ return false;
+ auto map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
+ auto map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
+ auto map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
+
+ if (map0.getNumResults() != 3 || map1.getNumResults() != 2 ||
+ map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
+ map1.getNumInputs() != 3 || map2.getNumInputs() != 3) {
+ return false;
+ }
+
+ // Extract dimensions for B*N*K * B*K -> B*N
+ AffineExpr b = map0.getResult(0);
+ AffineExpr k = map1.getResult(1);
+ AffineExpr n = map2.getResult(1);
+ auto *context = indexingMaps.getContext();
+ auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {b, n, k}, context));
+ auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {b, k}, context));
+ auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {b, n}, context));
+ auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
+ return indexingMaps == maps;
+}
+
Operation *mlir::clone(OpBuilder &b, Operation *op, TypeRange newResultTypes,
ValueRange newOperands) {
IRMapping bvm;
diff --git a/mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp b/mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp
index 583dbd463b91159..c0b4fc285232c46 100644
--- a/mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp
+++ b/mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp
@@ -240,4 +240,134 @@ TEST(isRowMajorBatchMatmul, FirstInputSwapped) {
EXPECT_THAT(maps, Not(Truly(isRowMajorBatchMatmul)));
}
+TEST(isVecMat, Simple) {
+ MLIRContext context;
+
+ AffineExpr k, n;
+ bindDims(&context, k, n);
+ auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {k}, &context));
+ auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k, n}, &context));
+ auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context));
+ auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
+
+ EXPECT_THAT(maps, Truly(isVecMat));
+}
+
+TEST(isVecMat, BindingSwapped) {
+ MLIRContext context;
+
+ AffineExpr k, n;
+ bindDims(&context, k, n); // bind in different order
+ auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {k}, &context));
+ auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k, n}, &context));
+ auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context));
+ auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
+
+ EXPECT_THAT(maps, Truly(isVecMat));
+}
+
+TEST(isVecMat, WrongDimOrderMatrix) {
+ MLIRContext context;
+
+ AffineExpr k, n;
+ bindDims(&context, k, n);
+ auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {k}, &context));
+ auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {n, k}, &context));
+ auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context));
+ auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
+
+ EXPECT_THAT(maps, Not(Truly(isVecMat)));
+}
+
+TEST(isMatVec, Simple) {
+ MLIRContext context;
+
+ AffineExpr k, n;
+ bindDims(&context, k, n);
+ auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {n, k}, &context));
+ auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k}, &context));
+ auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context));
+ auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
+
+ EXPECT_THAT(maps, Truly(isMatVec));
+}
+
+TEST(isMatVec, BindingSwapped) {
+ MLIRContext context;
+
+ AffineExpr k, n;
+ bindDims(&context, k, n); // bind in different order
+ auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {n, k}, &context));
+ auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k}, &context));
+ auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context));
+ auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
+
+ EXPECT_THAT(maps, Truly(isMatVec));
+}
+
+TEST(isMatVec, WrongDimOrderMatrix) {
+ MLIRContext context;
+
+ AffineExpr k, n;
+ bindDims(&context, k, n);
+ auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {k, n}, &context));
+ auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k}, &context));
+ auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context));
+ auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
+
+ EXPECT_THAT(maps, Not(Truly(isMatVec)));
+}
+
+TEST(isBatchMatVec, Simple) {
+ MLIRContext context;
+
+ AffineExpr batch, k, n;
+ bindDims(&context, batch, k, n);
+ auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n, k}, &context));
+ auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k}, &context));
+ auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n}, &context));
+ auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
+
+ EXPECT_THAT(maps, Truly(isBatchMatVec));
+}
+
+TEST(isBatchMatVec, BindingSwapped) {
+ MLIRContext context;
+
+ AffineExpr batch, k, n;
+ bindDims(&context, batch, k, n); // bind in different order
+ auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n, k}, &context));
+ auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k}, &context));
+ auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n}, &context));
+ auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
+
+ EXPECT_THAT(maps, Truly(isBatchMatVec));
+}
+
+TEST(isBatchMatVec, Matmul) {
+ MLIRContext context;
+
+ AffineExpr m, n, k;
+ bindDims(&context, m, n, k);
+ auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
+ auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
+ auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
+ auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
+
+ EXPECT_THAT(maps, Not(Truly(isBatchMatVec)));
+}
+
+TEST(isBatchMatVec, WrongDimOrderMatrix) {
+ MLIRContext context;
+
+ AffineExpr batch, k, n;
+ bindDims(&context, batch, k, n);
+ auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k, n}, &context));
+ auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k}, &context));
+ auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n}, &context));
+ auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
+
+ EXPECT_THAT(maps, Not(Truly(isBatchMatVec)));
+}
+
} // namespace
>From f2c6334db795021fd8a6f7d38c53842e9a2a1e4d Mon Sep 17 00:00:00 2001
From: Natasha Kononenko <natashaknk at google.com>
Date: Fri, 13 Oct 2023 01:01:08 +0000
Subject: [PATCH 2/3] Fix capitalization
---
.../Dialect/Linalg/IR/LinalgInterfaces.td | 12 +++---
.../mlir/Dialect/Utils/StructuredOpsUtils.h | 6 +--
mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp | 6 +--
.../Dialect/Utils/StructuredOpsUtilsTest.cpp | 40 +++++++++----------
4 files changed, 32 insertions(+), 32 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index e8e9b273cbcf234..44e82f452b3cef1 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -92,10 +92,10 @@ def LinalgContractionOpInterface : OpInterface<"ContractionOpInterface"> {
vector-matrix multiplication.
}],
/*retTy=*/"bool",
- /*methodName=*/"isVecMat",
+ /*methodName=*/"isVecmat",
/*args=*/(ins),
/*methodBody=*/[{
- return mlir::isVecMat($_op.getIndexingMaps());
+ return mlir::isVecmat($_op.getIndexingMaps());
}]>,
InterfaceMethod<
/*desc=*/[{
@@ -103,10 +103,10 @@ def LinalgContractionOpInterface : OpInterface<"ContractionOpInterface"> {
matrix-vector multiplication.
}],
/*retTy=*/"bool",
- /*methodName=*/"isMatVec",
+ /*methodName=*/"isMatvec",
/*args=*/(ins),
/*methodBody=*/[{
- return mlir::isMatVec($_op.getIndexingMaps());
+ return mlir::isMatvec($_op.getIndexingMaps());
}]>,
InterfaceMethod<
/*desc=*/[{
@@ -114,10 +114,10 @@ def LinalgContractionOpInterface : OpInterface<"ContractionOpInterface"> {
batched matrix-vector multiplication.
}],
/*retTy=*/"bool",
- /*methodName=*/"isBatchMatVec",
+ /*methodName=*/"isBatchMatvec",
/*args=*/(ins),
/*methodBody=*/[{
- return mlir::isBatchMatVec($_op.getIndexingMaps());
+ return mlir::isBatchMatvec($_op.getIndexingMaps());
}]>,
];
}
diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
index 7bdf260e83ce0be..225b9f287d340db 100644
--- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
@@ -53,19 +53,19 @@ bool isRowMajorBatchMatmul(ArrayAttr indexingMaps);
/// test is permutation-invariant. Note that this only checks the affine maps
/// from an operation, so does not perform any checks on the math being
/// performed within the reduction.
-bool isVecMat(ArrayAttr indexingMaps);
+bool isVecmat(ArrayAttr indexingMaps);
/// Tests whether the given maps describe a matrix vector multiplication. The
/// test is permutation-invariant. Note that this only checks the affine maps
/// from an operation, so does not perform any checks on the math being
/// performed within the reduction.
-bool isMatVec(ArrayAttr indexingMaps);
+bool isMatvec(ArrayAttr indexingMaps);
/// Tests whether the given maps describe a batch matrix vector multiplication.
/// The test is permutation-invariant. Note that this only checks the affine
/// maps from an operation, so does not perform any checks on the math being
/// performed within the reduction.
-bool isBatchMatVec(ArrayAttr indexingMaps);
+bool isBatchMatvec(ArrayAttr indexingMaps);
/// Return positions in `iteratorTypes` that match `iteratorTypeName`.
inline void findPositionsOfType(ArrayRef<utils::IteratorType> iteratorTypes,
diff --git a/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp b/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
index e56be5a062ec6b7..50579011a4f4fc1 100644
--- a/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
@@ -96,7 +96,7 @@ bool mlir::isRowMajorBatchMatmul(ArrayAttr indexingMaps) {
return indexingMaps == maps;
}
-bool mlir::isVecMat(ArrayAttr indexingMaps) {
+bool mlir::isVecmat(ArrayAttr indexingMaps) {
if (indexingMaps.size() != 3)
return false;
auto map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
@@ -120,7 +120,7 @@ bool mlir::isVecMat(ArrayAttr indexingMaps) {
return indexingMaps == maps;
}
-bool mlir::isMatVec(ArrayAttr indexingMaps) {
+bool mlir::isMatvec(ArrayAttr indexingMaps) {
if (indexingMaps.size() != 3)
return false;
auto map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
@@ -144,7 +144,7 @@ bool mlir::isMatVec(ArrayAttr indexingMaps) {
return indexingMaps == maps;
}
-bool mlir::isBatchMatVec(ArrayAttr indexingMaps) {
+bool mlir::isBatchMatvec(ArrayAttr indexingMaps) {
if (indexingMaps.size() != 3)
return false;
auto map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
diff --git a/mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp b/mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp
index c0b4fc285232c46..458e33da0697b4a 100644
--- a/mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp
+++ b/mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp
@@ -240,7 +240,7 @@ TEST(isRowMajorBatchMatmul, FirstInputSwapped) {
EXPECT_THAT(maps, Not(Truly(isRowMajorBatchMatmul)));
}
-TEST(isVecMat, Simple) {
+TEST(isVecmat, Simple) {
MLIRContext context;
AffineExpr k, n;
@@ -250,10 +250,10 @@ TEST(isVecMat, Simple) {
auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context));
auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
- EXPECT_THAT(maps, Truly(isVecMat));
+ EXPECT_THAT(maps, Truly(isVecmat));
}
-TEST(isVecMat, BindingSwapped) {
+TEST(isVecmat, BindingSwapped) {
MLIRContext context;
AffineExpr k, n;
@@ -263,10 +263,10 @@ TEST(isVecMat, BindingSwapped) {
auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context));
auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
- EXPECT_THAT(maps, Truly(isVecMat));
+ EXPECT_THAT(maps, Truly(isVecmat));
}
-TEST(isVecMat, WrongDimOrderMatrix) {
+TEST(isVecmat, WrongDimOrderMatrix) {
MLIRContext context;
AffineExpr k, n;
@@ -276,10 +276,10 @@ TEST(isVecMat, WrongDimOrderMatrix) {
auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context));
auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
- EXPECT_THAT(maps, Not(Truly(isVecMat)));
+ EXPECT_THAT(maps, Not(Truly(isVecmat)));
}
-TEST(isMatVec, Simple) {
+TEST(isMatvec, Simple) {
MLIRContext context;
AffineExpr k, n;
@@ -289,10 +289,10 @@ TEST(isMatVec, Simple) {
auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context));
auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
- EXPECT_THAT(maps, Truly(isMatVec));
+ EXPECT_THAT(maps, Truly(isMatvec));
}
-TEST(isMatVec, BindingSwapped) {
+TEST(isMatvec, BindingSwapped) {
MLIRContext context;
AffineExpr k, n;
@@ -302,10 +302,10 @@ TEST(isMatVec, BindingSwapped) {
auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context));
auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
- EXPECT_THAT(maps, Truly(isMatVec));
+ EXPECT_THAT(maps, Truly(isMatvec));
}
-TEST(isMatVec, WrongDimOrderMatrix) {
+TEST(isMatvec, WrongDimOrderMatrix) {
MLIRContext context;
AffineExpr k, n;
@@ -315,10 +315,10 @@ TEST(isMatVec, WrongDimOrderMatrix) {
auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context));
auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
- EXPECT_THAT(maps, Not(Truly(isMatVec)));
+ EXPECT_THAT(maps, Not(Truly(isMatvec)));
}
-TEST(isBatchMatVec, Simple) {
+TEST(isBatchMatvec, Simple) {
MLIRContext context;
AffineExpr batch, k, n;
@@ -328,10 +328,10 @@ TEST(isBatchMatVec, Simple) {
auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n}, &context));
auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
- EXPECT_THAT(maps, Truly(isBatchMatVec));
+ EXPECT_THAT(maps, Truly(isBatchMatvec));
}
-TEST(isBatchMatVec, BindingSwapped) {
+TEST(isBatchMatvec, BindingSwapped) {
MLIRContext context;
AffineExpr batch, k, n;
@@ -341,10 +341,10 @@ TEST(isBatchMatVec, BindingSwapped) {
auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n}, &context));
auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
- EXPECT_THAT(maps, Truly(isBatchMatVec));
+ EXPECT_THAT(maps, Truly(isBatchMatvec));
}
-TEST(isBatchMatVec, Matmul) {
+TEST(isBatchMatvec, Matmul) {
MLIRContext context;
AffineExpr m, n, k;
@@ -354,10 +354,10 @@ TEST(isBatchMatVec, Matmul) {
auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
- EXPECT_THAT(maps, Not(Truly(isBatchMatVec)));
+ EXPECT_THAT(maps, Not(Truly(isBatchMatvec)));
}
-TEST(isBatchMatVec, WrongDimOrderMatrix) {
+TEST(isBatchMatvec, WrongDimOrderMatrix) {
MLIRContext context;
AffineExpr batch, k, n;
@@ -367,7 +367,7 @@ TEST(isBatchMatVec, WrongDimOrderMatrix) {
auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n}, &context));
auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
- EXPECT_THAT(maps, Not(Truly(isBatchMatVec)));
+ EXPECT_THAT(maps, Not(Truly(isBatchMatvec)));
}
} // namespace
>From 58cda1dd44c100db217b5fdd2b7749ee13c0e808 Mon Sep 17 00:00:00 2001
From: Natasha Kononenko <natashaknk at google.com>
Date: Mon, 16 Oct 2023 19:57:20 +0000
Subject: [PATCH 3/3] Fix BindingSwapped tests and use explicit AffineMap types in util functions
---
mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp | 36 +++++++++----------
.../Dialect/Utils/StructuredOpsUtilsTest.cpp | 6 ++--
2 files changed, 21 insertions(+), 21 deletions(-)
diff --git a/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp b/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
index 50579011a4f4fc1..641ddf3f91cb2d9 100644
--- a/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
@@ -21,9 +21,9 @@ bool mlir::isRowMajorMatmul(ArrayAttr indexingMaps) {
if (indexingMaps.size() != 3)
return false;
- auto map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
- auto map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
- auto map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
+ AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
+ AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
+ AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
if (map0.getNumResults() != 2 || map1.getNumResults() != 2 ||
map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
@@ -47,9 +47,9 @@ bool mlir::isColumnMajorMatmul(ArrayAttr indexingMaps) {
if (indexingMaps.size() != 3)
return false;
- auto map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
- auto map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
- auto map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
+ AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
+ AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
+ AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
if (map0.getNumResults() != 2 || map1.getNumResults() != 2 ||
map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
@@ -73,9 +73,9 @@ bool mlir::isRowMajorBatchMatmul(ArrayAttr indexingMaps) {
if (indexingMaps.size() != 3)
return false;
- auto map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
- auto map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
- auto map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
+ AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
+ AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
+ AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
if (map0.getNumResults() != 3 || map1.getNumResults() != 3 ||
map2.getNumResults() != 3 || map0.getNumInputs() != 4 ||
@@ -99,9 +99,9 @@ bool mlir::isRowMajorBatchMatmul(ArrayAttr indexingMaps) {
bool mlir::isVecmat(ArrayAttr indexingMaps) {
if (indexingMaps.size() != 3)
return false;
- auto map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
- auto map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
- auto map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
+ AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
+ AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
+ AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
if (map0.getNumResults() != 1 || map1.getNumResults() != 2 ||
map2.getNumResults() != 1 || map0.getNumInputs() != 2 ||
@@ -123,9 +123,9 @@ bool mlir::isVecmat(ArrayAttr indexingMaps) {
bool mlir::isMatvec(ArrayAttr indexingMaps) {
if (indexingMaps.size() != 3)
return false;
- auto map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
- auto map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
- auto map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
+ AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
+ AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
+ AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
if (map0.getNumResults() != 2 || map1.getNumResults() != 1 ||
map2.getNumResults() != 1 || map0.getNumInputs() != 2 ||
@@ -147,9 +147,9 @@ bool mlir::isMatvec(ArrayAttr indexingMaps) {
bool mlir::isBatchMatvec(ArrayAttr indexingMaps) {
if (indexingMaps.size() != 3)
return false;
- auto map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
- auto map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
- auto map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
+ AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
+ AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
+ AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
if (map0.getNumResults() != 3 || map1.getNumResults() != 2 ||
map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
diff --git a/mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp b/mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp
index 458e33da0697b4a..3f576bacebf6aad 100644
--- a/mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp
+++ b/mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp
@@ -257,7 +257,7 @@ TEST(isVecmat, BindingSwapped) {
MLIRContext context;
AffineExpr k, n;
- bindDims(&context, k, n); // bind in different order
+ bindDims(&context, n, k); // bind in different order
auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {k}, &context));
auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k, n}, &context));
auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context));
@@ -296,7 +296,7 @@ TEST(isMatvec, BindingSwapped) {
MLIRContext context;
AffineExpr k, n;
- bindDims(&context, k, n); // bind in different order
+ bindDims(&context, n, k); // bind in different order
auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {n, k}, &context));
auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k}, &context));
auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context));
@@ -335,7 +335,7 @@ TEST(isBatchMatvec, BindingSwapped) {
MLIRContext context;
AffineExpr batch, k, n;
- bindDims(&context, batch, k, n); // bind in different order
+ bindDims(&context, batch, n, k); // bind in different order
auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n, k}, &context));
auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k}, &context));
auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n}, &context));
More information about the Mlir-commits
mailing list