[Mlir-commits] [mlir] [mlir][linalg] Introduce transpose semantic to 'linalg.matmul' ops. (PR #104783)

Nicolas Vasilache llvmlistbot at llvm.org
Tue Sep 24 07:44:27 PDT 2024


================
@@ -3382,3 +3453,229 @@ Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
                                               Location loc) {
   return arith::ConstantOp::materialize(builder, value, type, loc);
 }
+
+namespace mlir {
+namespace linalg {
+//===----------------------------------------------------------------------===//
+// MatmulOp
+//===----------------------------------------------------------------------===//
+SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
+  return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
+                                          utils::IteratorType::parallel,
+                                          utils::IteratorType::reduction};
+}
+
+unsigned MatmulOp::getNumRegionArgs() { return 3; }
+
+std::string MatmulOp::getLibraryCallName() {
+  return generateLibraryCallName(getOperation());
+}
+
+bool MatmulOp::hasDynamicIndexingMaps() { return true; }
+
+/// Implements the block region builder for the MatmulOp. This is called by
+/// 'fillStructuredOpRegion'.
+void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                             ArrayRef<NamedAttribute> attrs) {
+  assert(3 > 0 && block.getNumArguments() == 3 &&
+         "MatmulOp regionBuilder expects 3 (>=0) args");
+  RegionBuilderHelper helper(b, block);
+  SmallVector<Value> yields;
+
+  TypeFn castVal = TypeFn::cast_signed;
+  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
+    return attr.getName() == "cast";
+  });
+  if (castIter != attrs.end()) {
+    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
+      castVal = attr.getValue();
+  }
+
+  Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
+                                    block.getArgument(0));
+  Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
+                                    block.getArgument(1));
+  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
+  Value value4 =
+      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
+  yields.push_back(value4);
+  helper.yieldOutputs(yields);
+}
+
+/// Populates the output parameter \p indexingMaps with the typical matmul
+/// indexing maps.
+void MatmulOp::getDefaultIndexingMaps(
+    SmallVectorImpl<AffineMap> &indexingMaps) {
+  MLIRContext *context = this->getContext();
+  AffineExpr d0, d1, d2;
+  bindDims(context, d0, d1, d2);
+  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
+  indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
+  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
+}
+
+/// Returns true if the input operand identified by \p opIndex needs
+/// broadcasting.
+bool MatmulOp::hasBroadcastSemantic(unsigned opIndex) {
+  if (opIndex > 1)
+    return false;
+  SmallVector<AffineMap, 3> defaultMaps;
+  SmallVector<AffineMap, 3> explicitMaps;
+  getDefaultIndexingMaps(defaultMaps);
----------------
nicolasvasilache wrote:

Can we use a single API form and not fight against RVO?
```
SmallVector<AffineMap, 3> defaultMaps = getDefaultIndexingMaps();
SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
```



https://github.com/llvm/llvm-project/pull/104783


More information about the Mlir-commits mailing list