[Mlir-commits] [mlir] [MLIR][Linalg] Introduce broadcast/transpose semantic to 'linalg.batc… (PR #122275)
Md Asghar Ahmad Shahid
llvmlistbot at llvm.org
Thu Feb 6 05:52:09 PST 2025
================
@@ -3611,5 +3712,166 @@ Speculation::Speculatability MatmulOp::getSpeculatability() {
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
+//===----------------------------------------------------------------------===//
+// Implementation of BatchMatmulOp
+//===----------------------------------------------------------------------===//
+
+/// Returns the default (non-broadcast, non-transposed) indexing maps over the
+/// 4-D iteration space (batch, m, n, k) = (d0, d1, d2, d3):
+///   LHS    -> (batch, m, k)
+///   RHS    -> (batch, k, n)
+///   output -> (batch, m, n)
+SmallVector<AffineMap>
+BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
+  AffineExpr batch, m, n, k;
+  bindDims(context, batch, m, n, k);
+  return {AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {batch, m, k},
+                         context),
+          AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {batch, k, n},
+                         context),
+          AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {batch, m, n},
+                         context)};
+}
+
+/// Iterator types for the (batch, m, n, k) iteration space: the three leading
+/// loops are parallel and the trailing k loop is a reduction.
+SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
+  SmallVector<utils::IteratorType> iteratorTypes(
+      /*Size=*/3, utils::IteratorType::parallel);
+  iteratorTypes.push_back(utils::IteratorType::reduction);
+  return iteratorTypes;
+}
+
+/// The body region takes three block arguments: the LHS element, the RHS
+/// element and the accumulator (output) element.
+unsigned BatchMatmulOp::getNumRegionArgs() {
+  constexpr unsigned numRegionArgs = 3;
+  return numRegionArgs;
+}
+
+/// Returns the external library call name, derived generically from this
+/// operation by generateLibraryCallName.
+std::string BatchMatmulOp::getLibraryCallName() {
+  Operation *op = getOperation();
+  return generateLibraryCallName(op);
+}
+
+/// Checks whether the op carries broadcast and/or transpose semantics.
+/// Returns true if the op's indexing maps differ from the maps returned by
+/// getDefaultIndexingMaps().
+bool BatchMatmulOp::hasUserDefinedMaps() {
+  SmallVector<AffineMap> defaultMaps = getDefaultIndexingMaps(getContext());
+  SmallVector<AffineMap> explicitMaps = getIndexingMapsArray();
+  return explicitMaps != defaultMaps;
+}
+
+/// Returns true if \p bcastMap is a supported broadcast indexing map for the
+/// LHS operand (\p isLHS == true) or the RHS operand (\p isLHS == false).
+///
+/// The iteration space is (batch, m, n, k) = (d0, d1, d2, d3). Accepted forms:
+///   - one result:  the result must involve k;
+///   - two results: LHS results must involve (m, k) in that order and RHS
+///     results must involve (k, n) in that order.
+/// Any other result count yields false. The caller must pass a map with fewer
+/// than three results (asserted).
+///
+/// NOTE(review): AffineExpr::isFunctionOfDim only checks that a dimension
+/// appears in the expression, so compound results such as (d1 + d3) are not
+/// rejected here. Confirm that a caller enforces plain dim results.
+bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
+  assert(bcastMap.getNumResults() < 3 &&
+         "Expected less than 3 result dim expr.");
+  // Positions of the iteration-space dimensions; batch is position 0.
+  constexpr unsigned mPos = 1;
+  constexpr unsigned nPos = 2;
+  constexpr unsigned kPos = 3;
+  switch (bcastMap.getNumResults()) {
+  case 1:
+    return bcastMap.getResult(0).isFunctionOfDim(kPos);
+  case 2: {
+    // The batch dimension is expected to be dropped: LHS keeps (m, k), RHS
+    // keeps (k, n).
+    unsigned leadingPos = isLHS ? mPos : kPos;
+    unsigned trailingPos = isLHS ? kPos : nPos;
+    return bcastMap.getResult(0).isFunctionOfDim(leadingPos) &&
+           bcastMap.getResult(1).isFunctionOfDim(trailingPos);
+  }
+  default:
+    return false;
+  }
+}
+
+void BatchMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+ ArrayRef<NamedAttribute> attrs) {
+ assert(block.getNumArguments() == 3 &&
+ "BatchMatmulOp regionBuilder expects 3 (>=0) args");
+ RegionBuilderHelper helper(b, block);
+ SmallVector<Value> yields;
+
+ Value value1 =
+ helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
+ block.getArgument(0));
+ Value value2 =
+ helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
+ block.getArgument(1));
+ Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
+ Value value4 =
+ helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
----------------
shahidact wrote:
Done.
https://github.com/llvm/llvm-project/pull/122275
More information about the Mlir-commits
mailing list