[Mlir-commits] [mlir] [mlir][linalg] Introduce transpose semantic to 'linalg.matmul' ops. (PR #104783)
Nicolas Vasilache
llvmlistbot at llvm.org
Tue Sep 24 07:46:16 PDT 2024
================
@@ -3382,3 +3453,229 @@ Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
Location loc) {
return arith::ConstantOp::materialize(builder, value, type, loc);
}
+
+namespace mlir {
+namespace linalg {
+//===----------------------------------------------------------------------===//
+// MatmulOp
+//===----------------------------------------------------------------------===//
+SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
+ return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
+ utils::IteratorType::parallel,
+ utils::IteratorType::reduction};
+}
+
+unsigned MatmulOp::getNumRegionArgs() { return 3; }
+
+std::string MatmulOp::getLibraryCallName() {
+ return generateLibraryCallName(getOperation());
+}
+
+bool MatmulOp::hasDynamicIndexingMaps() { return true; }
+
+/// Implements the block region builder for MatmulOp. This is called by
+/// 'fillStructuredOpRegion'.
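+/// Illustration (pseudocode of the generated body):
+///   out = out + cast(out_type, in0) * cast(out_type, in1)
+/// where 'cast' defaults to TypeFn::cast_signed unless a 'cast' attribute
+/// overrides it.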
+void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+ ArrayRef<NamedAttribute> attrs) {
+ assert(block.getNumArguments() == 3 &&
+ "MatmulOp regionBuilder expects 3 args");
+ RegionBuilderHelper helper(b, block);
+ SmallVector<Value> yields;
+
+ TypeFn castVal = TypeFn::cast_signed;
+ auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
+ return attr.getName() == "cast";
+ });
+ if (castIter != attrs.end()) {
+ if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
+ castVal = attr.getValue();
+ }
+
+ Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
+ block.getArgument(0));
+ Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
+ block.getArgument(1));
+ Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
+ Value value4 =
+ helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
+ yields.push_back(value4);
+ helper.yieldOutputs(yields);
+}
+
+/// Populates the output parameter \p indexingMaps with the typical matmul
+/// indexing maps.
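+/// With dimensions bound as (d0, d1, d2) = (M, N, K), the maps built below
+/// are, illustratively:
+///   LHS: (d0, d1, d2) -> (d0, d2)
+///   RHS: (d0, d1, d2) -> (d2, d1)
+///   OUT: (d0, d1, d2) -> (d0, d1)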
+void MatmulOp::getDefaultIndexingMaps(
+ SmallVectorImpl<AffineMap> &indexingMaps) {
+ MLIRContext *context = this->getContext();
+ AffineExpr d0, d1, d2;
+ bindDims(context, d0, d1, d2);
+ indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
+ indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
+ indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
+}
+
+/// Returns true if the input operand identified by \p opIndex needs
+/// broadcasting.
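+/// For example (illustrative shapes), an LHS declared with the map
+///   (d0, d1, d2) -> (d2)
+/// has fewer results than the default LHS map (d0, d1, d2) -> (d0, d2) and is
+/// therefore treated as broadcast along d0.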
+bool MatmulOp::hasBroadcastSemantic(unsigned opIndex) {
+ if (opIndex > 1)
+ return false;
+ SmallVector<AffineMap, 3> defaultMaps;
+ SmallVector<AffineMap, 3> explicitMaps;
+ getDefaultIndexingMaps(defaultMaps);
+ explicitMaps = getIndexingMapsArray();
+ return explicitMaps[opIndex].getNumResults() <
+ defaultMaps[opIndex].getNumResults();
+}
+
+/// Infers the dimension sizes and populates them into \p dimSizeMap.
+/// The input parameter \p allTypes maps each operand index to its shaped type.
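+/// Example (hypothetical shapes): with A: tensor<5xf32> (broadcast LHS with
+/// map (d2)), B: tensor<5x7xf32> and C: tensor<3x7xf32>, this produces
+/// dimSizeMap = {0: 3, 1: 7, 2: 5}.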
+void MatmulOp::inferDimensionSizes(
+ llvm::DenseMap<unsigned, ShapedType> &allTypes,
+ llvm::DenseMap<unsigned, unsigned> &dimSizeMap) {
+ assert(!allTypes.empty() && "Expected non empty types");
+ assert(allTypes[0].getRank() > 0 && allTypes[1].getRank() > 0 &&
+ "Input rank must be positive");
+ assert(allTypes[2].getRank() == 2 && "Output rank must be 2");
+
+ dimSizeMap[0] = allTypes[2].getDimSize(0);
+ dimSizeMap[1] = allTypes[2].getDimSize(1);
+
+ // Get the size of 'd2' from the input operand that needs broadcasting,
+ // i.e. the one whose rank is smaller than the output rank.
+ unsigned outputRank = allTypes[2].getRank();
+ for (unsigned i = 0; i < 2; i++) {
+ if (allTypes[i].getRank() < outputRank) {
+ dimSizeMap[2] = allTypes[i].getDimSize(0);
+ return;
+ }
+ }
+}
+
+/// Infers which dimensions of an input operand are broadcast and populates
+/// \p broadcastDims, a map from the position of each result of \p defaultMap
+/// to a pair (isBroadcast, dimension position). It compares the explicitly
+/// provided indexing map \p explicitMap (parsed from the op) against the
+/// default map \p defaultMap of the corresponding input operand.
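+/// Example (illustrative): for an LHS with explicit map (d0, d1, d2) -> (d2)
+/// and default map (d0, d1, d2) -> (d0, d2), the result is
+///   broadcastDims = {0: (true, 0), 1: (false, 0)},
+/// i.e. the first result of the default map (d0) is broadcast.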
+void MatmulOp::inferBroadcastDimensions(
+ AffineMap explicitMap, AffineMap defaultMap,
+ DenseMap<unsigned, std::pair<bool, unsigned>> &broadcastDims) {
+ assert(!explicitMap.isEmpty() && "Expected non empty map");
+ assert(!defaultMap.isEmpty() && "Expected non empty map");
+
+ llvm::SetVector<unsigned> typicalDims, providedDims, broadcastedDims;
+ // Build the set of dimensions used by the default matmul indexing map.
+ for (auto expr : defaultMap.getResults()) {
+ if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
+ typicalDims.insert(dimExpr.getPosition());
+ }
+ }
+
+ // Build the set of dimensions from the explicitly provided indexing map.
+ for (auto expr : explicitMap.getResults()) {
+ if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
+ providedDims.insert(dimExpr.getPosition());
+ }
+ }
+
+ // Compute the set difference to get the broadcast dimensions.
+ broadcastedDims = llvm::set_difference(typicalDims, providedDims);
+
+ // Populate the broadcastDims map: entry i corresponds to the i-th result of
+ // the default map.
+ for (unsigned i = 0; i < typicalDims.size(); i++) {
+ broadcastDims[i] = {false, 0};
+ if (broadcastedDims.count(typicalDims[i]))
+ broadcastDims[i] = {true, typicalDims[i]};
+ }
+}
+
+/// Constructs and returns a broadcasted type for the input operand identified
+/// by \p opIndex.
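+/// Example (hypothetical shapes): for an LHS with map (d0, d1, d2) -> (d2) of
+/// type tensor<5xf32> and output type tensor<3x7xf32>, the broadcasted LHS
+/// type is tensor<3x5xf32>.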
+ShapedType MatmulOp::constructBroadcastedType(unsigned opIndex) {
+ assert(opIndex < 2 && "Operand index out of range");
+ DenseMap<unsigned, ShapedType> allTypes;
+ allTypes[0] = cast<ShapedType>(this->getOperand(0).getType());
+ allTypes[1] = cast<ShapedType>(this->getOperand(1).getType());
+ allTypes[2] = cast<ShapedType>(this->getOperand(2).getType());
+
+ ShapedType outputType = allTypes[2];
+ ShapedType inputType = allTypes[opIndex];
+ SmallVector<AffineMap, 3> defaultIndexingMaps;
+ DenseMap<unsigned, std::pair<bool, unsigned>> broadcastDims;
+ SmallVector<AffineMap, 3> indexingMaps = this->getIndexingMapsArray();
+
+ getDefaultIndexingMaps(defaultIndexingMaps);
+ inferBroadcastDimensions(indexingMaps[opIndex], defaultIndexingMaps[opIndex],
+ broadcastDims);
+
+ AffineMap defaultInputMap = defaultIndexingMaps[opIndex];
+ // Initialize new shape for input operand requiring broadcast.
+ unsigned numDims = outputType.getRank();
+ SmallVector<int64_t, 4> newShape(numDims, ShapedType::kDynamic);
+
+ DenseMap<unsigned, unsigned> dimSizeMap;
+ inferDimensionSizes(allTypes, dimSizeMap);
+
+ // Fill in the non-broadcast dimensions using the default indexing map.
+ for (unsigned i = 0; i < defaultInputMap.getNumResults(); ++i) {
+ if (!broadcastDims[i].first) {
+ if (auto dimExpr = dyn_cast<AffineDimExpr>(defaultInputMap.getResult(i)))
+ newShape[i] = dimSizeMap[dimExpr.getPosition()];
+ }
+ }
+
+ // Fill in the broadcast dimension.
+ for (unsigned i = 0; i < broadcastDims.size(); i++) {
+ if (broadcastDims[i].first) {
+ newShape[i] = dimSizeMap[broadcastDims[i].second];
+ }
+ }
+
+ // Create the new ShapedType.
+ if (auto tensorType = dyn_cast<RankedTensorType>(inputType))
+ return RankedTensorType::get(newShape, tensorType.getElementType());
+ if (auto memrefType = dyn_cast<MemRefType>(inputType))
+ return MemRefType::get(newShape, memrefType.getElementType(),
+ MemRefLayoutAttrInterface(),
+ memrefType.getMemorySpace());
+ llvm_unreachable("unsupported shaped type for matmul operand");
+}
+
+ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
+ return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
+ MatmulOp::getRegionBuilder());
+}
+
+void MatmulOp::print(OpAsmPrinter &p) {
+ printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs());
+
+ SmallVector<AffineMap, 3> defaultMaps;
+ getDefaultIndexingMaps(defaultMaps);
+ SmallVector<Attribute, 3> defaultMapAttrs;
+ for (AffineMap map : defaultMaps)
+ defaultMapAttrs.push_back(AffineMapAttr::get(map));
+ // Only print 'indexing_maps' when it deviates from the default matmul maps.
+ if (!llvm::equal(getIndexingMaps(), defaultMapAttrs)) {
+ p << " indexing_maps = [";
+ llvm::interleaveComma(getIndexingMaps(), p,
+ [&](Attribute attr) { p.printAttribute(attr); });
+ p << "]";
+ }
+}
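+// Sketch of the resulting assembly (illustrative, not verbatim syntax): with a
+// transposed LHS, the op would round-trip roughly as
+//   linalg.matmul ins(%A, %B : ...) outs(%C : ...)
+//       indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
+//                        affine_map<(d0, d1, d2) -> (d2, d1)>,
+//                        affine_map<(d0, d1, d2) -> (d0, d1)>]
+// while an op using the default maps prints without the attribute.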
+
+LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
+ return memref::foldMemRefCast(*this);
+}
+
+void MatmulOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ if (hasPureTensorSemantics())
+ return;
+ getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+}
+
+Speculation::Speculatability MatmulOp::getSpeculatability() {
+ return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
+} // namespace linalg
+} // namespace mlir
----------------
nicolasvasilache wrote:
nit: nl
https://github.com/llvm/llvm-project/pull/104783