[Mlir-commits] [mlir] 9f49509 - [mlir] Add ContractionOpInterface utility functions for vector matrix multiplication (#68945)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Oct 18 08:55:56 PDT 2023


Author: NatashaKnk
Date: 2023-10-18T08:55:51-07:00
New Revision: 9f4950983e2ae1e8cebf48ce33b84c23ac3d2dc2

URL: https://github.com/llvm/llvm-project/commit/9f4950983e2ae1e8cebf48ce33b84c23ac3d2dc2
DIFF: https://github.com/llvm/llvm-project/commit/9f4950983e2ae1e8cebf48ce33b84c23ac3d2dc2.diff

LOG: [mlir] Add ContractionOpInterface utility functions for vector matrix multiplication (#68945)

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
    mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
    mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
    mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 9ca029b489ad144..44e82f452b3cef1 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -86,6 +86,39 @@ def LinalgContractionOpInterface : OpInterface<"ContractionOpInterface"> {
     /*methodBody=*/[{
         return mlir::isRowMajorBatchMatmul($_op.getIndexingMaps());
     }]>,
+    InterfaceMethod<
+    /*desc=*/[{
+      Returns whether the given op's indexing maps alone correspond to a
+      vector-matrix multiplication: K * KxN -> N.
+    }],
+    /*retTy=*/"bool",
+    /*methodName=*/"isVecmat",
+    /*args=*/(ins),
+    /*methodBody=*/[{
+        return mlir::isVecmat($_op.getIndexingMaps());
+    }]>,
+    InterfaceMethod<
+    /*desc=*/[{
+      Returns whether the given op's indexing maps alone correspond to a
+      matrix-vector multiplication: NxK * K -> N.
+    }],
+    /*retTy=*/"bool",
+    /*methodName=*/"isMatvec",
+    /*args=*/(ins),
+    /*methodBody=*/[{
+        return mlir::isMatvec($_op.getIndexingMaps());
+    }]>,
+    InterfaceMethod<
+    /*desc=*/[{
+      Returns whether the given op's indexing maps alone correspond to a
+      batched matrix-vector multiplication: BxNxK * BxK -> BxN.
+    }],
+    /*retTy=*/"bool",
+    /*methodName=*/"isBatchMatvec",
+    /*args=*/(ins),
+    /*methodBody=*/[{
+        return mlir::isBatchMatvec($_op.getIndexingMaps());
+    }]>,
   ];
 }
 

diff  --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
index dab24bd93032692..225b9f287d340db 100644
--- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
@@ -49,6 +49,24 @@ bool isColumnMajorMatmul(ArrayAttr indexingMaps);
 /// the reduction.
 bool isRowMajorBatchMatmul(ArrayAttr indexingMaps);
 
+/// Tests whether the given maps describe a vector-matrix multiplication of the
+/// form K * KxN -> N, where any loop dim may play each role but the results
+/// of each map must appear in this order. Only the affine maps are checked;
+/// the math performed within the reduction is not inspected.
+bool isVecmat(ArrayAttr indexingMaps);
+
+/// Tests whether the given maps describe a matrix-vector multiplication of the
+/// form NxK * K -> N, where any loop dim may play each role but the results
+/// of each map must appear in this order. Only the affine maps are checked;
+/// the math performed within the reduction is not inspected.
+bool isMatvec(ArrayAttr indexingMaps);
+
+/// Tests whether the given maps describe a batch matrix-vector multiplication
+/// of the form BxNxK * BxK -> BxN, where any loop dim may play each role but
+/// the results of each map must appear in this order. Only the affine maps
+/// are checked; the math performed within the reduction is not inspected.
+bool isBatchMatvec(ArrayAttr indexingMaps);
+
 /// Return positions in `iteratorTypes` that match `iteratorTypeName`.
 inline void findPositionsOfType(ArrayRef<utils::IteratorType> iteratorTypes,
                                 utils::IteratorType iteratorTypeName,

diff  --git a/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp b/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
index a2977901f4751d4..641ddf3f91cb2d9 100644
--- a/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
@@ -21,9 +21,9 @@ bool mlir::isRowMajorMatmul(ArrayAttr indexingMaps) {
   if (indexingMaps.size() != 3)
     return false;
 
-  auto map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
-  auto map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
-  auto map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
+  AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
+  AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
+  AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
 
   if (map0.getNumResults() != 2 || map1.getNumResults() != 2 ||
       map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
@@ -47,9 +47,9 @@ bool mlir::isColumnMajorMatmul(ArrayAttr indexingMaps) {
   if (indexingMaps.size() != 3)
     return false;
 
-  auto map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
-  auto map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
-  auto map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
+  AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
+  AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
+  AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
 
   if (map0.getNumResults() != 2 || map1.getNumResults() != 2 ||
       map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
@@ -73,9 +73,9 @@ bool mlir::isRowMajorBatchMatmul(ArrayAttr indexingMaps) {
   if (indexingMaps.size() != 3)
     return false;
 
-  auto map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
-  auto map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
-  auto map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
+  AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
+  AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
+  AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
 
   if (map0.getNumResults() != 3 || map1.getNumResults() != 3 ||
       map2.getNumResults() != 3 || map0.getNumInputs() != 4 ||
@@ -96,6 +96,103 @@ bool mlir::isRowMajorBatchMatmul(ArrayAttr indexingMaps) {
   return indexingMaps == maps;
 }
 
+/// Returns true if every expression in `exprs` is a plain loop dimension
+/// (d_i) and no dimension occurs twice. Used to reject degenerate indexing
+/// maps (repeated or non-dim results) that would otherwise pass the
+/// structural comparisons performed by the predicates below.
+static bool areDistinctDims(llvm::ArrayRef<mlir::AffineExpr> exprs) {
+  for (unsigned i = 0, e = exprs.size(); i < e; ++i) {
+    if (exprs[i].getKind() != mlir::AffineExprKind::DimId)
+      return false;
+    for (unsigned j = 0; j < i; ++j)
+      if (exprs[j] == exprs[i])
+        return false;
+  }
+  return true;
+}
+
+bool mlir::isVecmat(ArrayAttr indexingMaps) {
+  if (indexingMaps.size() != 3)
+    return false;
+  AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
+  AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
+  AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
+
+  if (map0.getNumResults() != 1 || map1.getNumResults() != 2 ||
+      map2.getNumResults() != 1 || map0.getNumInputs() != 2 ||
+      map1.getNumInputs() != 2 || map2.getNumInputs() != 2) {
+    return false;
+  }
+
+  // Extract dimensions for K * KxN -> N. K and N must be distinct loop dims;
+  // otherwise e.g. (d0) * (d0, d0) -> (d0) would be accepted.
+  AffineExpr k = map0.getResult(0);
+  AffineExpr n = map2.getResult(0);
+  if (!areDistinctDims({k, n}))
+    return false;
+  auto *context = indexingMaps.getContext();
+  auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {k}, context));
+  auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k, n}, context));
+  auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, context));
+  auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
+  return indexingMaps == maps;
+}
+
+bool mlir::isMatvec(ArrayAttr indexingMaps) {
+  if (indexingMaps.size() != 3)
+    return false;
+  AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
+  AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
+  AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
+
+  if (map0.getNumResults() != 2 || map1.getNumResults() != 1 ||
+      map2.getNumResults() != 1 || map0.getNumInputs() != 2 ||
+      map1.getNumInputs() != 2 || map2.getNumInputs() != 2) {
+    return false;
+  }
+
+  // Extract dimensions for N*K * K -> N. N and K must be distinct loop dims;
+  // otherwise e.g. (d0, d0) * (d0) -> (d0) would be accepted.
+  AffineExpr k = map1.getResult(0);
+  AffineExpr n = map2.getResult(0);
+  if (!areDistinctDims({k, n}))
+    return false;
+  auto *context = indexingMaps.getContext();
+  auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {n, k}, context));
+  auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k}, context));
+  auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, context));
+  auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
+  return indexingMaps == maps;
+}
+
+bool mlir::isBatchMatvec(ArrayAttr indexingMaps) {
+  if (indexingMaps.size() != 3)
+    return false;
+  AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
+  AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
+  AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
+
+  if (map0.getNumResults() != 3 || map1.getNumResults() != 2 ||
+      map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
+      map1.getNumInputs() != 3 || map2.getNumInputs() != 3) {
+    return false;
+  }
+
+  // Extract dimensions for B*N*K * B*K -> B*N. B, N and K must be pairwise
+  // distinct loop dims; otherwise e.g. a repeated dim would be accepted.
+  AffineExpr b = map0.getResult(0);
+  AffineExpr k = map1.getResult(1);
+  AffineExpr n = map2.getResult(1);
+  if (!areDistinctDims({b, k, n}))
+    return false;
+  auto *context = indexingMaps.getContext();
+  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {b, n, k}, context));
+  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {b, k}, context));
+  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {b, n}, context));
+  auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
+  return indexingMaps == maps;
+}
+
 Operation *mlir::clone(OpBuilder &b, Operation *op, TypeRange newResultTypes,
                        ValueRange newOperands) {
   IRMapping bvm;

diff  --git a/mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp b/mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp
index 583dbd463b91159..3f576bacebf6aad 100644
--- a/mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp
+++ b/mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp
@@ -240,4 +240,134 @@ TEST(isRowMajorBatchMatmul, FirstInputSwapped) {
   EXPECT_THAT(maps, Not(Truly(isRowMajorBatchMatmul)));
 }
 
+TEST(isVecmat, Simple) {
+  MLIRContext context;
+
+  AffineExpr k, n;
+  bindDims(&context, k, n);
+  auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {k}, &context));
+  auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k, n}, &context));
+  auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context));
+  auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
+
+  EXPECT_THAT(maps, Truly(isVecmat));
+}
+
+TEST(isVecmat, BindingSwapped) {
+  MLIRContext context;
+
+  AffineExpr k, n;
+  bindDims(&context, n, k); // bind in different order
+  auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {k}, &context));
+  auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k, n}, &context));
+  auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context));
+  auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
+
+  EXPECT_THAT(maps, Truly(isVecmat));
+}
+
+TEST(isVecmat, WrongDimOrderMatrix) {
+  MLIRContext context;
+
+  AffineExpr k, n;
+  bindDims(&context, k, n);
+  auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {k}, &context));
+  auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {n, k}, &context));
+  auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context));
+  auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
+
+  EXPECT_THAT(maps, Not(Truly(isVecmat)));
+}
+
+TEST(isMatvec, Simple) {
+  MLIRContext context;
+
+  AffineExpr k, n;
+  bindDims(&context, k, n);
+  auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {n, k}, &context));
+  auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k}, &context));
+  auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context));
+  auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
+
+  EXPECT_THAT(maps, Truly(isMatvec));
+}
+
+TEST(isMatvec, BindingSwapped) {
+  MLIRContext context;
+
+  AffineExpr k, n;
+  bindDims(&context, n, k); // bind in different order
+  auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {n, k}, &context));
+  auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k}, &context));
+  auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context));
+  auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
+
+  EXPECT_THAT(maps, Truly(isMatvec));
+}
+
+TEST(isMatvec, WrongDimOrderMatrix) {
+  MLIRContext context;
+
+  AffineExpr k, n;
+  bindDims(&context, k, n);
+  auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {k, n}, &context));
+  auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k}, &context));
+  auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context));
+  auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
+
+  EXPECT_THAT(maps, Not(Truly(isMatvec)));
+}
+
+TEST(isBatchMatvec, Simple) {
+  MLIRContext context;
+
+  AffineExpr batch, k, n;
+  bindDims(&context, batch, k, n);
+  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n, k}, &context));
+  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k}, &context));
+  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n}, &context));
+  auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
+
+  EXPECT_THAT(maps, Truly(isBatchMatvec));
+}
+
+TEST(isBatchMatvec, BindingSwapped) {
+  MLIRContext context;
+
+  AffineExpr batch, k, n;
+  bindDims(&context, batch, n, k); // bind in different order
+  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n, k}, &context));
+  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k}, &context));
+  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n}, &context));
+  auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
+
+  EXPECT_THAT(maps, Truly(isBatchMatvec));
+}
+
+TEST(isBatchMatvec, Matmul) {
+  MLIRContext context;
+
+  AffineExpr m, n, k;
+  bindDims(&context, m, n, k);
+  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
+  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
+  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
+  auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
+
+  EXPECT_THAT(maps, Not(Truly(isBatchMatvec)));
+}
+
+TEST(isBatchMatvec, WrongDimOrderMatrix) {
+  MLIRContext context;
+
+  AffineExpr batch, k, n;
+  bindDims(&context, batch, k, n);
+  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k, n}, &context));
+  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k}, &context));
+  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n}, &context));
+  auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
+
+  EXPECT_THAT(maps, Not(Truly(isBatchMatvec)));
+}
+
 } // namespace


        


More information about the Mlir-commits mailing list