[Mlir-commits] [mlir] [MLIR][Linalg] Introduce transpose/broadcast semantic to linalg.batch… (PR #130944)
Md Asghar Ahmad Shahid
llvmlistbot at llvm.org
Thu May 1 10:30:48 PDT 2025
================
@@ -5366,6 +5414,169 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
}
};
+//===----------------------------------------------------------------------===//
+// BatchReduceMatmulOp
+//===----------------------------------------------------------------------===//
+/// Returns the iterator types of the four loop dimensions.
+/// The first and last dimensions are reductions; the middle two are parallel
+/// (the default output map below keeps only d1 and d2).
+SmallVector<utils::IteratorType> BatchReduceMatmulOp::getIteratorTypesArray() {
+  const utils::IteratorType par = utils::IteratorType::parallel;
+  const utils::IteratorType red = utils::IteratorType::reduction;
+  return {red, par, par, red};
+}
+
+/// Builds the default indexing maps over the 4-D iteration space
+/// (d0, d1, d2, d3):
+///   operand 0: (d0, d1, d3)
+///   operand 1: (d0, d3, d2)
+///   operand 2: (d1, d2)  -- presumably the accumulator/output; d0 and d3 are
+///                           dropped, matching the reduction iterator types.
+SmallVector<AffineMap>
+BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
+  const unsigned numLoops = 4;
+  AffineExpr batchDim, mDim, nDim, kDim;
+  bindDims(context, batchDim, mDim, nDim, kDim);
+  return {AffineMap::get(numLoops, /*symbolCount=*/0, {batchDim, mDim, kDim},
+                         context),
+          AffineMap::get(numLoops, /*symbolCount=*/0, {batchDim, kDim, nDim},
+                         context),
+          AffineMap::get(numLoops, /*symbolCount=*/0, {mDim, nDim}, context)};
+}
+
+/// Number of block arguments of the op's body region. This is presumably one
+/// per operand, matching the three default indexing maps.
+unsigned BatchReduceMatmulOp::getNumRegionArgs() {
+  const unsigned numArgs = 3;
+  return numArgs;
+}
+
+/// Returns the external library call name for this op, delegating to the
+/// shared generateLibraryCallName helper.
+std::string BatchReduceMatmulOp::getLibraryCallName() {
+  Operation *self = getOperation();
+  return generateLibraryCallName(self);
+}
+
+/// Reports whether the op carries broadcast and/or transpose semantics, i.e.
+/// whether its indexing maps differ from those returned by
+/// getDefaultIndexingMaps.
+bool BatchReduceMatmulOp::hasUserDefinedMaps() {
+  SmallVector<AffineMap, 3> expected = getDefaultIndexingMaps(getContext());
+  SmallVector<AffineMap, 3> actual = getIndexingMapsArray();
+  const bool matchesDefault = (expected == actual);
+  return !matchesDefault;
+}
+
+/// Returns true if the given broadcast map bcastMap is valid for this op.
+bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
+ bool isLHS) {
+ assert(bcastMap.getNumResults() < 3 &&
+ "Expected less than 3 result dim expr.");
+ bool isValid = false;
+ enum Indices { batchPos, mPos, nPos, kPos };
+ if (bcastMap.getNumResults() == 1) {
----------------
shahidact wrote:
Done.
https://github.com/llvm/llvm-project/pull/130944
More information about the Mlir-commits
mailing list