[Mlir-commits] [mlir] [MLIR][Linalg] Introduce broadcast/transpose semantic to 'linalg.batc… (PR #122275)
Adam Siemieniuk
llvmlistbot at llvm.org
Thu Jan 9 08:17:17 PST 2025
================
@@ -3611,5 +3668,165 @@ Speculation::Speculatability MatmulOp::getSpeculatability() {
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
+//===----------------------------------------------------------------------===//
+// Implementation of BatchMatmulOp
+//===----------------------------------------------------------------------===//
+
+/// Builds the default indexing maps of the op over the four-dimensional
+/// iteration space (d0, d1, d2, d3), in operand order:
+///   first input  -> (d0, d1, d3)
+///   second input -> (d0, d3, d2)
+///   output       -> (d0, d1, d2)
+SmallVector<AffineMap>
+BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
+  AffineExpr batch, m, n, k;
+  bindDims(context, batch, m, n, k);
+  return SmallVector<AffineMap>{
+      AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {batch, m, k}, context),
+      AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {batch, k, n}, context),
+      AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {batch, m, n}, context)};
+}
+
+/// Returns the iterator types of the op: three parallel iterators followed by
+/// a single reduction iterator.
+SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
+  using utils::IteratorType;
+  SmallVector<IteratorType> iterators(3, IteratorType::parallel);
+  iterators.push_back(IteratorType::reduction);
+  return iterators;
+}
+
+/// Returns the number of region arguments reported by this op, which is 3.
+unsigned BatchMatmulOp::getNumRegionArgs() {
+  constexpr unsigned kNumRegionArgs = 3;
+  return kNumRegionArgs;
+}
+
+/// Returns the library call name that generateLibraryCallName() produces for
+/// this operation.
+std::string BatchMatmulOp::getLibraryCallName() {
+  Operation *op = getOperation();
+  return generateLibraryCallName(op);
+}
+
+/// Reports whether the op carries broadcast and/or transpose semantics, i.e.
+/// whether the indexing maps attached to it differ from the default maps.
+bool BatchMatmulOp::hasUserDefinedMaps() {
+  SmallVector<AffineMap, 3> attachedMaps = getIndexingMapsArray();
+  SmallVector<AffineMap, 3> canonicalMaps =
+      getDefaultIndexingMaps(this->getContext());
+  return canonicalMaps != attachedMaps;
+}
+
+/// Returns true if the given broadcast map \p bcastMap is valid for this op.
+///
+/// Dimension positions follow the (batch, m, n, k) ordering. A one-result map
+/// is accepted when its result is a function of the k dimension. A two-result
+/// map is accepted when its results are functions of (m, k) for the LHS
+/// (\p isLHS is true) or of (k, n) for the RHS. Any other result count is
+/// rejected.
+bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
+  // A broadcast map drops at least one of the three operand dimensions, so it
+  // may have at most two results; the message now matches the condition.
+  assert(bcastMap.getNumResults() < 3 &&
+         "Expected less than 3 result dim expr.");
+  enum Indices { batchPos, mPos, nPos, kPos };
+
+  const unsigned numResults = bcastMap.getNumResults();
+  if (numResults == 1)
+    return bcastMap.getResult(0).isFunctionOfDim(kPos);
+
+  if (numResults == 2) {
+    AffineExpr first = bcastMap.getResult(0);
+    AffineExpr second = bcastMap.getResult(1);
+    if (isLHS)
+      return first.isFunctionOfDim(mPos) && second.isFunctionOfDim(kPos);
+    return first.isFunctionOfDim(kPos) && second.isFunctionOfDim(nPos);
+  }
+
+  // Zero results (or, in builds without assertions, three or more).
+  return false;
+}
+
+/// Populates \p block with the scalar body of the op: both inputs (block
+/// arguments 0 and 1) are cast with TypeFn::cast_signed to the type of block
+/// argument 2, multiplied, added to block argument 2, and the sum is yielded.
+/// \p attrs is not consulted.
+void BatchMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                                  ArrayRef<NamedAttribute> attrs) {
+  assert(block.getNumArguments() == 3 &&
+         "BatchMatmulOp regionBuilder expects 3 (>=0) args");
+  RegionBuilderHelper helper(b, block);
+
+  Value lhsArg = block.getArgument(0);
+  Value rhsArg = block.getArgument(1);
+  Value accArg = block.getArgument(2);
+  Type accType = accArg.getType();
+
+  // Operations are created in this exact order: cast, cast, mul, add.
+  Value lhsCast = helper.buildTypeFn(TypeFn::cast_signed, accType, lhsArg);
+  Value rhsCast = helper.buildTypeFn(TypeFn::cast_signed, accType, rhsArg);
+  Value product = helper.buildBinaryFn(BinaryFn::mul, lhsCast, rhsCast);
+  Value sum = helper.buildBinaryFn(BinaryFn::add, accArg, product);
+
+  SmallVector<Value> yields;
+  yields.push_back(sum);
+  helper.yieldOutputs(yields);
+}
+
+/// Custom parser. Accepts an optional leading
+/// `indexing_maps = [ <affine-map-attr>, ... ]` clause and then defers to
+/// parseNamedStructuredOp() for the remainder of the op. When no maps are
+/// given, the default indexing maps are attached as the `indexing_maps`
+/// attribute instead.
+ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &result) {
+  SmallVector<Attribute, 3> mapAttrs;
+
+  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
+    if (parser.parseEqual() || parser.parseLSquare())
+      return failure();
+
+    // Read one attribute per iteration; stop once no comma follows.
+    for (;;) {
+      Attribute parsed;
+      if (parser.parseAttribute(parsed))
+        return failure();
+      if (!isa<AffineMapAttr>(parsed))
+        return parser.emitError(parser.getCurrentLocation(),
+                                "expected affine map attribute");
+      mapAttrs.push_back(parsed);
+
+      if (failed(parser.parseOptionalComma()))
+        break;
+    }
+
+    if (parser.parseRSquare())
+      return failure();
+  }
+
+  // Fall back to the default maps when none were supplied explicitly.
+  if (mapAttrs.empty()) {
+    for (AffineMap map :
+         BatchMatmulOp::getDefaultIndexingMaps(parser.getContext()))
+      mapAttrs.push_back(AffineMapAttr::get(map));
+  }
+
+  ArrayAttr mapsArray = parser.getBuilder().getArrayAttr(mapAttrs);
+  result.addAttribute("indexing_maps", mapsArray);
+
+  return ::parseNamedStructuredOp(parser, result,
+                                  BatchMatmulOp::getNumRegionArgs(),
+                                  BatchMatmulOp::getRegionBuilder());
+}
+
+/// Custom printer. Emits the `indexing_maps = [...]` clause only when the maps
+/// differ from the default ones, and then prints the rest of the op through
+/// printNamedStructuredOp() with the internally managed attributes elided.
+///
+/// The clause is printed *before* the rest of the op because
+/// BatchMatmulOp::parse() only accepts `indexing_maps` as a leading clause;
+/// printing it afterwards yields text the parser cannot read back.
+void BatchMatmulOp::print(OpAsmPrinter &p) {
+  SmallVector<Attribute, 3> defaultMaps = llvm::map_to_vector(
+      BatchMatmulOp::getDefaultIndexingMaps(getContext()),
+      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+  if (!llvm::equal(getIndexingMaps(), defaultMaps)) {
+    p << " indexing_maps = [";
+    llvm::interleaveComma(getIndexingMaps(), p,
+                          [&](Attribute attr) { p.printAttribute(attr); });
+    p << "]";
+  }
+
+  // `indexing_maps` is elided from the attribute dictionary: it is either the
+  // default or was already printed by the clause above.
+  SmallVector<StringRef, 3> elidedAttrs = {
+      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
+  ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+                           elidedAttrs);
+}
+
+/// Verify the user defined indexing maps.
+LogicalResult BatchMatmulOp::verify() {
+ // Verification of pure batch_matmul is handled by
+ // verifyStructuredOpInterface().
+ if (!hasUserDefinedMaps())
+ return success();
+
+ for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
----------------
adam-smnk wrote:
Why not validate output maps as well?
Right now, something like this is possible:
```mlir
linalg.batch_matmul indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3) -> (d1)>
]
ins(%arg0, %arg1 : memref<5x5x5xf32>, memref<5x5x5xf32>)
outs(%arg2: memref<5xf32>)
```
Same goes for `linalg.matmul` with `indexing_maps`.
https://github.com/llvm/llvm-project/pull/122275
More information about the Mlir-commits
mailing list