[Mlir-commits] [mlir] Add `isBatchVecmat` utilities for `linalg.batch_vecmat` (PR #70284)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Oct 25 19:55:09 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: None (bjacob)

<details>
<summary>Changes</summary>

`linalg.batch_vecmat` was just added in https://github.com/llvm/llvm-project/pull/70218, but at the time I forgot to add the standard `isBatchVecmat` utilities.

---
Full diff: https://github.com/llvm/llvm-project/pull/70284.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td (+11) 
- (modified) mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h (+6) 
- (modified) mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp (+25) 
- (modified) mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp (+52) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 44e82f452b3cef1..69ca888a8acdbe0 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -98,6 +98,17 @@ def LinalgContractionOpInterface : OpInterface<"ContractionOpInterface"> {
         return mlir::isVecmat($_op.getIndexingMaps());
     }]>,
     InterfaceMethod<
+    /*desc=*/[{
+      Returns whether the given op has indexing maps that correspond to a
+      batched vector-matrix multiplication, (b, k) * (b, k, n) -> (b, n).
+    }],
+    /*retTy=*/"bool",
+    /*methodName=*/"isBatchVecmat",
+    /*args=*/(ins),
+    /*methodBody=*/[{
+        return mlir::isBatchVecmat($_op.getIndexingMaps());
+    }]>,
+    InterfaceMethod<
     /*desc=*/[{
       Returns whether the given op has indexing maps that correspond to a
       matrix-vector multiplication.
diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
index 225b9f287d340db..134c5569fbb2f3e 100644
--- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
@@ -55,6 +55,12 @@ bool isRowMajorBatchMatmul(ArrayAttr indexingMaps);
 /// performed within the reduction.
 bool isVecmat(ArrayAttr indexingMaps);
 
+/// Tests whether the given maps describe a batch vector matrix multiplication,
+/// i.e. (b, k) * (b, k, n) -> (b, n). The test is permutation-invariant. Note
+/// that this only checks the affine maps from an operation, so does not perform
+/// any checks on the math being performed within the reduction.
+bool isBatchVecmat(ArrayAttr indexingMaps);
+
 /// Tests whether the given maps describe a matrix vector multiplication. The
 /// test is permutation-invariant. Note that this only checks the affine maps
 /// from an operation, so does not perform any checks on the math being
diff --git a/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp b/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
index 641ddf3f91cb2d9..383ef1cea53fd30 100644
--- a/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
@@ -120,6 +120,31 @@ bool mlir::isVecmat(ArrayAttr indexingMaps) {
   return indexingMaps == maps;
 }
 
+bool mlir::isBatchVecmat(ArrayAttr indexingMaps) {
+  if (indexingMaps.size() != 3)
+    return false;
+  AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
+  AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
+  AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
+  // map1 must be a permutation so that b, k and n below are distinct loops;
+  // otherwise e.g. (d0, d0) * (d0, d0, d2) -> (d0, d2) would be accepted.
+  if (map0.getNumResults() != 2 || map2.getNumResults() != 2 ||
+      map0.getNumInputs() != 3 || map2.getNumInputs() != 3 ||
+      map1.getNumResults() != 3 || !map1.isPermutation()) {
+    return false;
+  }
+  // Extract dimensions for B*K * B*K*N -> B*N
+  AffineExpr b = map0.getResult(0);
+  AffineExpr k = map0.getResult(1);
+  AffineExpr n = map2.getResult(1);
+  auto *context = indexingMaps.getContext();
+  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {b, k}, context));
+  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {b, k, n}, context));
+  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {b, n}, context));
+  auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
+  return indexingMaps == maps;
+}
+
 bool mlir::isMatvec(ArrayAttr indexingMaps) {
   if (indexingMaps.size() != 3)
     return false;
diff --git a/mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp b/mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp
index 3f576bacebf6aad..d257fc5d6e041d1 100644
--- a/mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp
+++ b/mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp
@@ -370,4 +370,56 @@ TEST(isBatchMatvec, WrongDimOrderMatrix) {
   EXPECT_THAT(maps, Not(Truly(isBatchMatvec)));
 }
 
+TEST(isBatchVecmat, Simple) {
+  // Canonical batch_vecmat maps: (b, k) * (b, k, n) -> (b, n).
+  MLIRContext ctx;
+  AffineExpr b, k, n;
+  bindDims(&ctx, b, k, n);
+  auto lhs = AffineMapAttr::get(AffineMap::get(3, 0, {b, k}, &ctx));
+  auto rhs = AffineMapAttr::get(AffineMap::get(3, 0, {b, k, n}, &ctx));
+  auto out = AffineMapAttr::get(AffineMap::get(3, 0, {b, n}, &ctx));
+  ArrayAttr maps = ArrayAttr::get(&ctx, {lhs, rhs, out});
+
+  EXPECT_THAT(maps, Truly(isBatchVecmat));
+}
+
+TEST(isBatchVecmat, BindingSwapped) {
+  // k and n are bound to d2 and d1 respectively, i.e. the maps are
+  // (d0, d2) * (d0, d2, d1) -> (d0, d1); the matcher must still accept them.
+  MLIRContext ctx;
+  AffineExpr b, k, n;
+  bindDims(&ctx, b, n, k);
+  auto lhs = AffineMapAttr::get(AffineMap::get(3, 0, {b, k}, &ctx));
+  auto rhs = AffineMapAttr::get(AffineMap::get(3, 0, {b, k, n}, &ctx));
+  auto out = AffineMapAttr::get(AffineMap::get(3, 0, {b, n}, &ctx));
+  ArrayAttr maps = ArrayAttr::get(&ctx, {lhs, rhs, out});
+  EXPECT_THAT(maps, Truly(isBatchVecmat));
+}
+
+TEST(isBatchVecmat, Matmul) {
+  // A plain (unbatched) matmul: (m, k) * (k, n) -> (m, n) must be rejected.
+  MLIRContext ctx;
+  AffineExpr m, n, k;
+  bindDims(&ctx, m, n, k);
+  auto lhs = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &ctx));
+  auto rhs = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &ctx));
+  auto out = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &ctx));
+  ArrayAttr maps = ArrayAttr::get(&ctx, {lhs, rhs, out});
+
+  EXPECT_THAT(maps, Not(Truly(isBatchVecmat)));
+}
+
+TEST(isBatchVecmat, WrongDimOrderMatrix) {
+  // The matrix is indexed (b, n, k) instead of (b, k, n): must be rejected.
+  MLIRContext ctx;
+  AffineExpr b, k, n;
+  bindDims(&ctx, b, k, n);
+  auto lhs = AffineMapAttr::get(AffineMap::get(3, 0, {b, k}, &ctx));
+  auto rhs = AffineMapAttr::get(AffineMap::get(3, 0, {b, n, k}, &ctx));
+  auto out = AffineMapAttr::get(AffineMap::get(3, 0, {b, n}, &ctx));
+  ArrayAttr maps = ArrayAttr::get(&ctx, {lhs, rhs, out});
+
+  EXPECT_THAT(maps, Not(Truly(isBatchVecmat)));
+}
+
 } // namespace

``````````

</details>


https://github.com/llvm/llvm-project/pull/70284


More information about the Mlir-commits mailing list