[Mlir-commits] [mlir] [mlir][linalg] Introduce transpose semantic to 'linalg.matmul' ops. (PR #104783)

Nicolas Vasilache llvmlistbot at llvm.org
Tue Sep 24 07:46:16 PDT 2024


================
@@ -3382,3 +3453,229 @@ Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
                                               Location loc) {
   return arith::ConstantOp::materialize(builder, value, type, loc);
 }
+
+namespace mlir {
+namespace linalg {
+//===----------------------------------------------------------------------===//
+// MatmulOp
+//===----------------------------------------------------------------------===//
+/// Matmul iterates the two output dimensions (m, n) in parallel and reduces
+/// over the contraction dimension (k).
+SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
+  return {utils::IteratorType::parallel, utils::IteratorType::parallel,
+          utils::IteratorType::reduction};
+}
+
+/// The matmul region takes two scalar inputs and one accumulator argument.
+unsigned MatmulOp::getNumRegionArgs() { return 3; }
+
+/// The external library call name is derived from the operation itself.
+std::string MatmulOp::getLibraryCallName() {
+  Operation *op = getOperation();
+  return generateLibraryCallName(op);
+}
+
+/// Indexing maps can be supplied explicitly on the op (e.g. to express
+/// transpose/broadcast semantics), so they are not fixed at tablegen time.
+bool MatmulOp::hasDynamicIndexingMaps() { return true; }
+
+/// Implements the block region builder for the MatmulOp. This is called by
+/// 'fillStructuredOpRegion'.
+void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                             ArrayRef<NamedAttribute> attrs) {
+  // The original generated-template assert ('3 > 0 && ...') was a tautology;
+  // only the argument-count check carries information.
+  assert(block.getNumArguments() == 3 &&
+         "MatmulOp regionBuilder expects 3 args");
+  RegionBuilderHelper helper(b, block);
+  SmallVector<Value> yields;
+
+  // Both inputs are cast to the accumulator element type; the cast kind
+  // defaults to a signed cast unless a 'cast' attribute overrides it.
+  TypeFn castVal = TypeFn::cast_signed;
+  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
+    return attr.getName() == "cast";
+  });
+  if (castIter != attrs.end()) {
+    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
+      castVal = attr.getValue();
+  }
+
+  // acc = acc + cast(lhs) * cast(rhs)
+  Value castedLhs = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
+                                       block.getArgument(0));
+  Value castedRhs = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
+                                       block.getArgument(1));
+  Value product = helper.buildBinaryFn(BinaryFn::mul, castedLhs, castedRhs);
+  Value accumulated =
+      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), product);
+  yields.push_back(accumulated);
+  helper.yieldOutputs(yields);
+}
+
+/// Populates the output parameter \p indexingMaps with the typical matmul
+/// indexing maps: A: (m, k), B: (k, n), C: (m, n).
+void MatmulOp::getDefaultIndexingMaps(
+    SmallVectorImpl<AffineMap> &indexingMaps) {
+  MLIRContext *context = this->getContext();
+  AffineExpr m, n, k;
+  bindDims(context, m, n, k);
+  indexingMaps.push_back(AffineMap::get(3, 0, {m, k}, context)); // LHS
+  indexingMaps.push_back(AffineMap::get(3, 0, {k, n}, context)); // RHS
+  indexingMaps.push_back(AffineMap::get(3, 0, {m, n}, context)); // output
+}
+
+/// Returns true if the input operand identified by \p opIndex needs
+/// broadcasting, i.e. its explicitly supplied indexing map has fewer results
+/// than the default matmul map for that operand.
+bool MatmulOp::hasBroadcastSemantic(unsigned opIndex) {
+  // Only the two inputs (LHS = 0, RHS = 1) can be broadcast.
+  if (opIndex > 1)
+    return false;
+  SmallVector<AffineMap, 3> defaultMaps;
+  getDefaultIndexingMaps(defaultMaps);
+  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
+  return explicitMaps[opIndex].getNumResults() <
+         defaultMaps[opIndex].getNumResults();
+}
+
+/// Infers the dimension sizes and populates \p dimSizeMap (keyed by the
+/// AffineDimExpr position: 0 = m, 1 = n, 2 = k).
+/// Input parameter \p allTypes is a map of shapes for each operand.
+void MatmulOp::inferDimensionSizes(
+    llvm::DenseMap<unsigned, ShapedType> &allTypes,
+    llvm::DenseMap<unsigned, unsigned> &dimSizeMap) {
+  assert(!allTypes.empty() && "Expected non empty types");
+  assert(allTypes[0].getRank() > 0 && allTypes[1].getRank() > 0 &&
+         "Input rank must be positive");
+  assert(allTypes[2].getRank() == 2 && "Output rank must be 2");
+
+  // 'd0' (m) and 'd1' (n) are fully determined by the output shape.
+  // NOTE(review): getDimSize() returns int64_t; dynamic sizes would wrap when
+  // stored into this unsigned map — confirm callers only pass static shapes.
+  dimSizeMap[0] = allTypes[2].getDimSize(0);
+  dimSizeMap[1] = allTypes[2].getDimSize(1);
+
+  // Get dimension size for 'd2' (k) from the input that needs broadcast: a
+  // broadcast input has rank lower than the output, and its sole dimension is
+  // the reduction dimension. Iterate the two inputs explicitly; the original
+  // bound 'outputRank' equalled 2 only because of the rank assert above.
+  unsigned outputRank = allTypes[2].getRank();
+  for (unsigned i = 0; i < 2; i++) {
+    if (allTypes[i].getRank() < outputRank) {
+      dimSizeMap[2] = allTypes[i].getDimSize(0);
+      return;
+    }
+  }
+}
+
+/// Infers the broadcasted dimension and populates \p broadcastDims which is
+/// a map of dimensions to a pair of Boolean and AffineDimExpr position,
+/// indicating broadcast and the corresponding AffineDimExpr position.
+/// It uses input parameters \p explicitMap parsed from the op and \p defaultMap
+/// corresponding to an input operand.
+void MatmulOp::inferBroadcastDimensions(
+    AffineMap explicitMap, AffineMap defaultMap,
+    DenseMap<unsigned, std::pair<bool, unsigned>> &broadcastDims) {
+  assert(!explicitMap.isEmpty() && "Expected non empty map");
+  assert(!defaultMap.isEmpty() && "Expected non empty map");
+
+  llvm::SetVector<unsigned> typicalDims, providedDims;
+  // Build set of dimensions using default matmul indexing map.
+  for (auto expr : defaultMap.getResults()) {
+    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
+      typicalDims.insert(dimExpr.getPosition());
+    }
+  }
+
+  // Build set of dimensions from explicitly provided indexing map.
+  for (auto expr : explicitMap.getResults()) {
+    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
+      providedDims.insert(dimExpr.getPosition());
+    }
+  }
+
+  // A dimension present in the default map but absent from the explicit map
+  // is broadcast. (The original 'llvm::set_difference' call computed this set
+  // into a variable that was never read; the membership test below is the
+  // only consumer, so the dead computation is dropped.)
+  for (unsigned i = 0; i < typicalDims.size(); i++) {
+    unsigned dim = typicalDims[i];
+    broadcastDims[i] = providedDims.count(dim)
+                           ? std::make_pair(false, 0u)
+                           : std::make_pair(true, dim);
+  }
+}
+
+/// Construct and returns a broadcasted type for the input operand identified by
+/// input parameter \p opIndex.
+ShapedType MatmulOp::constructBroadcastedType(unsigned opIndex) {
+  assert(opIndex < 2 && "Operand index out of range");
+  // Shaped types for all three operands: 0 = LHS, 1 = RHS, 2 = output.
+  DenseMap<unsigned, ShapedType> allTypes;
+  allTypes[0] = cast<ShapedType>(this->getOperand(0).getType());
+  allTypes[1] = cast<ShapedType>(this->getOperand(1).getType());
+  allTypes[2] = cast<ShapedType>(this->getOperand(2).getType());
+
+  ShapedType outputType = allTypes[2];
+  ShapedType inputType = allTypes[opIndex];
+  SmallVector<AffineMap, 3> defaultIndexingMaps;
+  // broadcastDims maps result position -> (isBroadcast, AffineDimExpr pos).
+  DenseMap<unsigned, std::pair<bool, unsigned>> broadcastDims;
+  SmallVector<AffineMap, 3> indexingMaps = this->getIndexingMapsArray();
+
+  getDefaultIndexingMaps(defaultIndexingMaps);
+  inferBroadcastDimensions(indexingMaps[opIndex], defaultIndexingMaps[opIndex],
+                           broadcastDims);
+
+  AffineMap defaultInputMap = defaultIndexingMaps[opIndex];
+  // Initialize new shape for input operand requiring broadcast.
+  // Start fully dynamic; every position is overwritten by one of the two
+  // loops below (either a known dimension or the broadcast dimension).
+  unsigned numDims = outputType.getRank();
+  SmallVector<int64_t, 4> newShape(numDims, ShapedType::kDynamic);
+
+  // dimSizeMap: AffineDimExpr position -> inferred size (m, n, k).
+  DenseMap<unsigned, unsigned> dimSizeMap;
+  inferDimensionSizes(allTypes, dimSizeMap);
+
+  // Fill in the known dimensions using defaultIndexingMap
+  for (unsigned i = 0; i < defaultInputMap.getNumResults(); ++i) {
+    if (!broadcastDims[i].first) {
+      if (auto dimExpr = dyn_cast<AffineDimExpr>(defaultInputMap.getResult(i)))
+        newShape[i] = dimSizeMap[dimExpr.getPosition()];
+    }
+  }
+
+  // Fill in the broadcast dimension.
+  // '.second' is the AffineDimExpr position of the broadcast dim, used to
+  // look up its inferred size.
+  for (unsigned i = 0; i < broadcastDims.size(); i++) {
+    if (broadcastDims[i].first) {
+      newShape[i] = dimSizeMap[broadcastDims[i].second];
+    }
+  }
+
+  // Create the new ShapedType, preserving element type (and, for memrefs,
+  // the memory space; the layout is reset to the default).
+  if (auto tensorType = dyn_cast<RankedTensorType>(inputType)) {
+    return RankedTensorType::get(newShape, tensorType.getElementType());
+  } else if (auto memrefType = dyn_cast<MemRefType>(inputType)) {
+    return MemRefType::get(newShape, memrefType.getElementType(),
+                           MemRefLayoutAttrInterface(),
+                           memrefType.getMemorySpace());
+  } else {
+    // NOTE(review): printing to llvm::errs() and returning a null ShapedType
+    // bypasses MLIR diagnostics — consider emitOpError() instead; confirm
+    // callers handle the null return.
+    llvm::errs() << "Error: Unsupported ShapedType\n";
+    return ShapedType();
+  }
+}
+
+/// Custom parser: delegates to the shared named-structured-op parser, which
+/// also materializes the region via getRegionBuilder().
+ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
+  unsigned numRegionArgs = MatmulOp::getNumRegionArgs();
+  return parseNamedStructuredOp(parser, result, numRegionArgs,
+                                MatmulOp::getRegionBuilder());
+}
+void MatmulOp::print(OpAsmPrinter &p) {
+  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs());
+
+  // Elide 'indexing_maps' when they match the defaults so the common
+  // (non-transposed, non-broadcast) form stays terse.
+  // NOTE(review): 'getDefaultIndexingMap' (singular) vs. the member
+  // 'getDefaultIndexingMaps' used elsewhere in this file — confirm this
+  // resolves to the intended helper; the near-identical names are easy to
+  // confuse.
+  SmallVector<Attribute, 3> indexingMaps;
+  getDefaultIndexingMap(getContext(), indexingMaps);
+  if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
+    p << " indexing_maps = [";
+    llvm::interleaveComma(getIndexingMaps(), p,
+                          [&](Attribute attr) { p.printAttribute(attr); });
+    p << "]";
+  }
+}
+
+/// Folds away memref.cast operations feeding this op's operands.
+LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
+  return memref::foldMemRefCast(*this);
+}
+/// Reports memory effects: none under pure tensor semantics, otherwise the
+/// generic structured-op effects (read inputs, write outputs).
+void MatmulOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  if (!hasPureTensorSemantics())
+    getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+}
+
+/// Speculatability follows the generic structured-op rule.
+Speculation::Speculatability MatmulOp::getSpeculatability() {
+  auto linalgOp = cast<LinalgOp>(getOperation());
+  return getGenericSpeculatabilityImpl(linalgOp);
+}
+
+} // namespace linalg
+} // namespace mlir
----------------
nicolasvasilache wrote:

nit: nl

https://github.com/llvm/llvm-project/pull/104783


More information about the Mlir-commits mailing list