[Mlir-commits] [mlir] [mlir][linalg] Introduce transpose semantic to 'linalg.matmul' ops. (PR #104783)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Oct 3 11:23:15 PDT 2024
================
@@ -3382,3 +3464,150 @@ Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
Location loc) {
return arith::ConstantOp::materialize(builder, value, type, loc);
}
+
+namespace mlir {
+namespace linalg {
+//===----------------------------------------------------------------------===//
+// MatMulOp
+//===----------------------------------------------------------------------===//
+/// Returns one iterator type per loop of the op: two parallel loops followed
+/// by a single reduction loop.
+SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
+  SmallVector<utils::IteratorType> iteratorTypes(2,
+                                                 utils::IteratorType::parallel);
+  iteratorTypes.push_back(utils::IteratorType::reduction);
+  return iteratorTypes;
+}
+
+/// The region block takes three arguments: the two input elements and the
+/// output (accumulator) element that `regionBuilder` adds into.
+unsigned MatmulOp::getNumRegionArgs() {
+  constexpr unsigned kNumRegionArgs = 3;
+  return kNumRegionArgs;
+}
+
+/// Returns the library call name for this op, delegating to the shared
+/// `generateLibraryCallName` helper.
+std::string MatmulOp::getLibraryCallName() {
+  Operation *op = getOperation();
+  return generateLibraryCallName(op);
+}
+
+/// Reports that this op's indexing maps are dynamic rather than fixed. The maps
+/// in use may differ from `getDefaultIndexingMaps()` (see `isBroadcasted`).
+bool MatmulOp::hasDynamicIndexingMaps() {
+  // Unconditional: the answer does not depend on the maps actually attached.
+  return true;
+}
+
+/// Implements the block region builder for the MatmulOp. This is called by
+/// 'fillStructuredOpRegion'.
+///
+/// The region does the following:
+///   - casts block arguments 0 and 1 to the type of block argument 2;
+///   - multiplies the cast values;
+///   - adds the product to argument 2 and yields the sum.
+///
+/// The cast function defaults to `TypeFn::cast_signed`. It can be overridden
+/// by a `TypeFnAttr` named "cast" in \p attrs. An attribute with that name but
+/// a different kind is ignored.
+void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                             ArrayRef<NamedAttribute> attrs) {
+  assert(block.getNumArguments() == 3 &&
+         "MatmulOp regionBuilder expects 3 args");
+  RegionBuilderHelper helper(b, block);
+
+  TypeFn castVal = TypeFn::cast_signed;
+  auto castIter = llvm::find_if(attrs, [](const NamedAttribute &attr) {
+    return attr.getName() == "cast";
+  });
+  if (castIter != attrs.end()) {
+    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
+      castVal = attr.getValue();
+  }
+
+  // Both inputs are converted to the accumulator's type before the
+  // multiply-accumulate.
+  Value acc = block.getArgument(2);
+  Type accType = acc.getType();
+  Value lhs = helper.buildTypeFn(castVal, accType, block.getArgument(0));
+  Value rhs = helper.buildTypeFn(castVal, accType, block.getArgument(1));
+  Value product = helper.buildBinaryFn(BinaryFn::mul, lhs, rhs);
+  Value sum = helper.buildBinaryFn(BinaryFn::add, acc, product);
+
+  SmallVector<Value> yields;
+  yields.push_back(sum);
+  helper.yieldOutputs(yields);
+}
+
+/// Returns the list of AffineMaps with the typical matmul indexing
+/// characteristic, as produced by `inferDefaultIndexingMaps` for this op's
+/// MLIRContext. Explicit maps are compared against these (see `isBroadcasted`).
+SmallVector<AffineMap, 3> MatmulOp::getDefaultIndexingMaps() {
+  return inferDefaultIndexingMaps(getContext());
+}
+
+/// Returns true if the \p explictMap is broadcasted with respect to the
+/// \p defaultMap, i.e. \p explictMap has fewer results than \p defaultMap.
+///
+/// NOTE(review): only the result counts are compared. Which dimensions appear
+/// in \p explictMap is not checked here; `isValidBroadcastMap` covers that.
+bool MatmulOp::isBroadcasted(AffineMap explictMap, AffineMap defaultMap) {
+  return explictMap.getNumResults() < defaultMap.getNumResults();
+}
+
+/// Returns true if the given broadcast map \p bcastMap is valid for this op.
+///
+/// A broadcast map is valid when every result expression is a function of the
+/// last dimension of the map's domain. For matmul's three loops that is `d2`,
+/// the reduction loop. A map with no results is trivially valid.
+bool MatmulOp::isValidBroadcastMap(AffineMap bcastMap) {
+  unsigned numDims = bcastMap.getNumDims();
+  // Avoid unsigned wraparound in `numDims - 1`: with no dimensions there is no
+  // last dim, so only a map without results passes.
+  if (numDims == 0)
+    return bcastMap.getNumResults() == 0;
+
+  // Invalid map if any result expression does not use dim expr 'd2'.
+  unsigned lastDim = numDims - 1;
+  return llvm::all_of(bcastMap.getResults(), [&](AffineExpr expr) {
+    return expr.isFunctionOfDim(lastDim);
+  });
+}
+
+/// Returns true if the \p explictMap is transposed with respect to the
+/// \p defaultMap. \p 'isLHS = true' flag indicates that the check is being
+/// performed on LHS operand, otherwise RHS operand.
+bool MatmulOp::isTransposed(AffineMap explictMap, AffineMap defaultMap,
----------------
MaheshRavishankar wrote:
This method is ill-defined as well. What if I send an indexing map related to the output? Also might be better to just say `isLhsTransposed` or `isRhsTransposed` and reduce the confusion about how these methods are to be used.
https://github.com/llvm/llvm-project/pull/104783
More information about the Mlir-commits
mailing list