[Mlir-commits] [mlir] [MLIR][Linalg] Introduce transpose/broadcast semantic to linalg.batch… (PR #130944)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Mar 12 04:08:04 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Md Asghar Ahmad Shahid (shahidact)
<details>
<summary>Changes</summary>
…_reduce_matmul.
This patch exposes broadcast and transpose semantics on 'batch_reduce_matmul'. It is the last in a series of patches that added the same semantics to the other matmul variants ('matmul' and 'batch_matmul').
The broadcast and transpose semantics are as follows:
Broadcast and transpose semantics can be applied by specifying the explicit attribute 'indexing_maps', as shown below. This is a list attribute, so if it is specified it must include maps for all arguments.
Example Transpose:
```
linalg.batch_reduce_matmul indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3) -> (d1, d2)>
]
ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>)
outs(%arg2: memref<3x7xf32>)
```
Example Broadcast:
```
linalg.batch_reduce_matmul indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d3)>, // broadcast
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3) -> (d1, d2)>
]
ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
outs(%arg2: memref<3x7xf32>)
```
Example Broadcast and Transpose:
```
linalg.batch_reduce_matmul indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d1, d3)>, // broadcast
affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
affine_map<(d0, d1, d2, d3) -> (d1, d2)>
]
ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
outs(%arg2: memref<3x7xf32>)
```
RFCs and related PRs:
https://discourse.llvm.org/t/rfc-linalg-opdsl-constant-list-attribute-definition/80149 https://discourse.llvm.org/t/rfc-op-explosion-in-linalg/82863 https://discourse.llvm.org/t/rfc-mlir-linalg-operation-tree/83586 https://github.com/llvm/llvm-project/pull/115319
https://github.com/llvm/llvm-project/pull/122275
---
Patch is 45.47 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/130944.diff
6 Files Affected:
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml (-70)
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td (+131)
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+234-28)
- (modified) mlir/test/Dialect/Linalg/generalize-named-ops.mlir (+30)
- (modified) mlir/test/Dialect/Linalg/invalid.mlir (+202)
- (modified) mlir/test/Dialect/Linalg/named-ops.mlir (+165)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index b44af2defc3e4..6344861c53ac5 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1717,76 +1717,6 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: BZp
--- !LinalgOpConfig
-metadata: !LinalgOpMetadata
- name: batch_reduce_matmul
- cpp_class_name: BatchReduceMatmulOp
- doc: |-
- Performs a batch-reduce matrix multiplication of two 3D inputs.
- The partial multiplication results are reduced into a 2D output.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- implements:
- - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
- args:
- - !LinalgOperandDefConfig
- name: A
- kind: input_tensor
- type_var: T1
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
- - !LinalgOperandDefConfig
- name: B
- kind: input_tensor
- type_var: T2
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)>
- - !LinalgOperandDefConfig
- name: C
- kind: output_tensor
- type_var: U
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s1, s3)>
- indexing_maps: !LinalgIndexingMapsConfig
- static_indexing_maps:
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d1, d2)>
- iterator_types:
- - reduction
- - parallel
- - parallel
- - reduction
- assignments:
- - !ScalarAssign
- arg: C
- value: !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: add
- operands:
- - !ScalarExpression
- scalar_arg: C
- - !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: mul
- operands:
- - !ScalarExpression
- scalar_fn:
- kind: type
- fn_name: cast_signed
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: A
- - !ScalarExpression
- scalar_fn:
- kind: type
- fn_name: cast_signed
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: B
---- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: matvec
cpp_class_name: MatvecOp
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index e4dd458eaff84..5191a658bbf26 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -1054,6 +1054,137 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
}
+//===----------------------------------------------------------------------===//
+// Op definition for BatchReduceMatmulOp
+//===----------------------------------------------------------------------===//
+
+def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
+ AttrSizedOperandSegments,
+ LinalgContractionOpInterface]> {
+
+ let summary = [{Performs a batch-reduce matrix multiplication of two 3D inputs.
+The partial multiplication results are reduced into a 2D output.}];
+ let description = [{
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+
+ Broadcast and Transpose semantics can be appiled by specifying the explicit attribute
+ 'indexing_maps' as shown below. This is a list attribute, so must include maps for all
+ arguments if specified.
+
+ Example Transpose:
+ ```
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>)
+ outs(%arg2: memref<3x7xf32>)
+ ```
+
+ Example Broadcast:
+ ```
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d3)>, // broadcast
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
+ outs(%arg2: memref<3x7xf32>)
+ ```
+
+ Example Broadcast and Transpose:
+ ```
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d1, d3)>, // broadcast
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
+ outs(%arg2: memref<3x7xf32>)
+ ```
+ }];
+
+ let arguments = (ins
+ Variadic<AnyType>:$inputs,
+ Variadic<AnyShaped>:$outputs,
+ DefaultValuedOptionalAttr<
+ AffineMapArrayAttr,
+ "BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext())"
+ >:$indexing_maps,
+ DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
+ );
+ let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+ let regions = (region AnyRegion:$region);
+
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<
+ (ins "ValueRange":$inputs, "ValueRange":$outputs,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ buildBatchReduceMatmulOp($_builder, $_state, std::nullopt, inputs, outputs,
+ attributes, BatchReduceMatmulOp::getRegionBuilder(),
+ BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+ }]>,
+ OpBuilder<
+ (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+ "ValueRange":$outputs,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ buildBatchReduceMatmulOp($_builder, $_state, resultTensorTypes,
+ inputs, outputs, attributes, BatchReduceMatmulOp::getRegionBuilder(),
+ BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+ }]>,
+ OpBuilder<
+ (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+ "ValueRange":$outputs,
+ "Attribute":$cast, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ $_state.addAttribute("cast", cast);
+ buildBatchReduceMatmulOp($_builder, $_state, resultTensorTypes, inputs, outputs,
+ attributes, BatchReduceMatmulOp::getRegionBuilder(),
+ BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+ }]>
+
+ ];
+ let hasCustomAssemblyFormat = 1;
+ let hasFolder = 1;
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = structuredOpsBaseDecls # [{
+ SmallVector<utils::IteratorType> getIteratorTypesArray();
+
+ /// Implements the block region builder.
+ static void regionBuilder(ImplicitLocOpBuilder &b,
+ Block &block, ArrayRef<NamedAttribute> attrs);
+
+ /// Returns a list of AffineMap with the typical batch_reducematmul indexing charactristic.
+ static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
+
+ /// Returns true if the given broadcast map \p bcastMap is valid for this op.
+ bool isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS = true);
+
+ static std::function<void(ImplicitLocOpBuilder &,
+ Block &, ArrayRef<NamedAttribute>)>
+ getRegionBuilder() {
+ return regionBuilder;
+ }
+
+ ::mlir::MutableOperandRange getDpsInitsMutable() {
+ return getOutputsMutable();
+ }
+
+ // Generic methods.
+ static unsigned getNumRegionArgs();
+ std::string getLibraryCallName();
+ bool hasDynamicIndexingMaps() { return true; };
+ /// Returns true if the user defined indexing maps are not equal to default maps.
+ bool hasUserDefinedMaps();
+ }];
+}
+
//===----------------------------------------------------------------------===//
// Named Linalg ops, implemented as a declarative configurations of generic ops.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 07b19e5cb1a89..d46fbf988d762 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -218,6 +218,23 @@ static void buildBatchMatmulOp(OpBuilder &b, OperationState &state,
attributes, regionBuilder);
}
+static void buildBatchReduceMatmulOp(OpBuilder &b, OperationState &state,
+ std::optional<TypeRange> resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes,
+ RegionBuilderFn regionBuilder,
+ ArrayRef<AffineMap> indexingMaps) {
+ // Initialize indexingMaps attribute, for BatchReduceMatmulOp.
+ SmallVector<Attribute, 4> indexingMapsAttrVal;
+ indexingMapsAttrVal =
+ llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
+ return AffineMapAttr::get(map);
+ });
+ state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
+ return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
+ attributes, regionBuilder);
+}
+
/// Common parsing used for both named structured ops created by ods-gen and by
/// manually defined C++ ops. Does not handle regions.
static ParseResult
@@ -3464,19 +3481,24 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
return success();
}
-// Check general validity of input indexing map.
-static LogicalResult verifyInputMaps(BatchMatmulOp batchMatmulOp,
+// Check general validity of input indexing map of
+// BatchMatmulOp/BatchReduceMatmulOp.
+template <typename OpTy>
+static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp,
AffineMap opIndexingMap,
AffineMap defaultIndexingMap, bool isLHS) {
+ assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
+ isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
+ "Expected BatchMatmulOp or BatchReduceMatmulOp");
// Check the result dims are valid.
if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
- return batchMatmulOp->emitOpError()
+ return batchVariantMatmulOp->emitOpError()
<< "Unexpected result dim expression (outside the set of default "
"result dims).";
// Check for valid number of result dims of input maps.
if (opIndexingMap.getNumResults() > 3)
- return batchMatmulOp->emitOpError()
+ return batchVariantMatmulOp->emitOpError()
<< "no. of result dim expressions exceeds 3.";
auto hasValidBatchDim = [](AffineMap map) {
@@ -3486,60 +3508,83 @@ static LogicalResult verifyInputMaps(BatchMatmulOp batchMatmulOp,
// Check if the requested broadcast is valid.
if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
- if (!batchMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
- return batchMatmulOp->emitOpError() << "Invalid broadcast requested.";
+ if (!batchVariantMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
+ return batchVariantMatmulOp->emitOpError()
+ << "Invalid broadcast requested.";
} else if (!hasValidBatchDim(opIndexingMap)) {
- return batchMatmulOp->emitOpError()
+ return batchVariantMatmulOp->emitOpError()
<< "Invalid batch dimension expression.";
}
return success();
}
/// This function checks if the given AffineMap for the output of a
-/// BatchMatmulOp has exactly 3 result dimensions and if the output map result
-/// dimensions are valid.
-static LogicalResult verifyOutputMap(BatchMatmulOp batchMatmulOp,
+/// BatchMatmulOp/BatchReduceMatmulOp has exactly the desired number of result
+/// dimensions and if the output map result dimensions are valid.
+template <typename OpTy>
+static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp,
AffineMap opIndexingMap) {
- if (opIndexingMap.getNumResults() != 3)
- return batchMatmulOp->emitOpError()
+ assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
+ isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
+ "Expected BatchMatmulOp or BatchReduceMatmulOp");
+ if (isa<BatchMatmulOp>(batchVariantMatmulOp) &&
+ opIndexingMap.getNumResults() != 3) {
+
+ return batchVariantMatmulOp->emitOpError()
<< "expects 3 dims, but got (" << opIndexingMap.getNumResults()
<< ").";
+ }
+ if (isa<BatchReduceMatmulOp>(batchVariantMatmulOp) &&
+ opIndexingMap.getNumResults() != 2) {
+ return batchVariantMatmulOp->emitOpError()
+ << "expects 2 dims, but got (" << opIndexingMap.getNumResults()
+ << ").";
+ }
- auto areValidOutputResultDim = [](AffineMap outputMap) {
- return outputMap.getResult(0).isFunctionOfDim(0) &&
- outputMap.getResult(1).isFunctionOfDim(1) &&
- outputMap.getResult(2).isFunctionOfDim(2);
+ auto areValidOutputResultDim = [&](AffineMap outputMap) {
+ return isa<BatchMatmulOp>(batchVariantMatmulOp)
+ ? outputMap.getResult(0).isFunctionOfDim(0) &&
+ outputMap.getResult(1).isFunctionOfDim(1) &&
+ outputMap.getResult(2).isFunctionOfDim(2)
+ : outputMap.getResult(0).isFunctionOfDim(1) &&
+ outputMap.getResult(1).isFunctionOfDim(2);
};
- if (!areValidOutputResultDim(opIndexingMap))
- return batchMatmulOp->emitOpError()
+ if (!areValidOutputResultDim(opIndexingMap)) {
+ return batchVariantMatmulOp->emitOpError()
<< "Invalid output map result dimension.";
+ }
return success();
}
/// Verifies the broadcast and transpose semantic specified by the explicit
-/// indexing map for the BatchMatmulOp op for each operand specified by opIndex.
+/// indexing map for the BatchMatmulOp/BatchReduceMatmulOp op for each operand
+/// specified by opIndex.
+template <typename OpTy>
static LogicalResult
-verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp,
- unsigned opIndex) {
+verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp,
+ unsigned opIndex) {
SmallVector<AffineMap, 3> opIndexingMaps =
- batchMatmulOp.getIndexingMapsArray();
+ batchVariantMatmulOp.getIndexingMapsArray();
SmallVector<AffineMap, 3> defaultIndexingMaps =
- batchMatmulOp.getDefaultIndexingMaps(batchMatmulOp->getContext());
+ batchVariantMatmulOp.getDefaultIndexingMaps(
+ batchVariantMatmulOp->getContext());
if (opIndexingMaps.size() != 3)
- return batchMatmulOp->emitOpError()
+ return batchVariantMatmulOp->emitOpError()
<< "Indexing_map attribute must have 3 affine maps.";
auto opIndexingMap = opIndexingMaps[opIndex];
auto defaultIndexingMap = defaultIndexingMaps[opIndex];
- if (opIndex == 2 && failed(verifyOutputMap(batchMatmulOp, opIndexingMap)))
+ if (opIndex == 2 &&
+ failed(verifyOutputMap(batchVariantMatmulOp, opIndexingMap)))
return failure();
- if (failed(verifyInputMaps(batchMatmulOp, opIndexingMap, defaultIndexingMap,
- opIndex == 0)))
+ if (opIndex != 2 &&
+ failed(verifyInputMaps(batchVariantMatmulOp, opIndexingMap,
+ defaultIndexingMap, opIndex == 0)))
return failure();
return success();
@@ -4035,7 +4080,7 @@ LogicalResult BatchMatmulOp::verify() {
return success();
for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
- if (failed(verifyExtendedBatchMatmulSemantic(*this, opIndex)))
+ if (failed(verifyExtendedBatchVariantMatmulSemantic(*this, opIndex)))
return failure();
}
return success();
@@ -5340,6 +5385,167 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
}
};
+//===----------------------------------------------------------------------===//
+// BatchReduceMatmulOp
+//===----------------------------------------------------------------------===//
+SmallVector<utils::IteratorType> BatchReduceMatmulOp::getIteratorTypesArray() {
+ return SmallVector<utils::IteratorType>{
+ utils::IteratorType::reduction, utils::IteratorType::parallel,
+ utils::IteratorType::parallel, utils::IteratorType::reduction};
+}
+
+SmallVector<AffineMap>
+BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
+ AffineExpr d0, d1, d2, d3;
+ SmallVector<AffineMap> indexingMaps;
+ bindDims(context, d0, d1, d2, d3);
+ indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
+ indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
+ indexingMaps.push_back(AffineMap::get(4, 0, {d1, d2}, context));
+ return indexingMaps;
+}
+
+unsigned BatchReduceMatmulOp::getNumRegionArgs() { return 3; }
+
+std::string BatchReduceMatmulOp::getLibraryCallName() {
+ return generateLibraryCallName(getOperation());
+}
+
+/// Check if the op has broadcast and/or transpose semantic. Returns true if
+/// the user defined indexing maps are not equal to default map.
+bool BatchReduceMatmulOp::hasUserDefinedMaps() {
+ SmallVector<AffineMap, 3> defaultMaps =
+ getDefaultIndexingMaps(this->getContext());
+ SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
+ return defaultMaps != explicitMaps;
+}
+
+/// Returns true if the given broadcast map bcastMap is valid for this op.
+bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
+ bool isLHS) {
+ assert(bcastMap.getNumResults() < 3 &&
+ "Expected less than 3 result dim expr.");
+ bool isValid = false;
+ enum Indices { batchPos, mPos, nPos, kPos };
+ if (bcastMap.getNumResults() == 1) {
+ AffineExpr exp = bcastMap.getResult(0);
+ isValid = exp.isFunctionOfDim(kPos);
+ } else if (bcastMap.getNumResults() == 2) {
+ AffineExpr exp0 = bcastMap.getResult(0);
+ AffineExpr exp1 = bcastMap.getResult(1);
+ isValid = isLHS
+ ? (exp0.isFunctionOfDim(mPos) && exp1.isFunctionOfDim(kPos))
+ : (exp0.isFunctionOfDim(kPos) && exp1.isFunctionOfDim(nPos));
+ }
+ return isValid;
+}
+
+void BatchReduceMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+ ArrayRef<NamedAttribute> attrs) {
+ assert(block.getNumArguments() == 3 &&
+ "BatchReduceMatmulOp regionBuilder expects 3 (>=0) args");
+ RegionBuilderHelper helper(b, block);
+ SmallVector<Value> yields;
+
+ auto toType = block.getArgument(2).getType();
+ Value castValA =
+ helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
+ Value castValB =
+ helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
+ Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
+ Value addVal =
+ helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
+ yields.push_back(addVal);
+ helper.yieldOutputs(yields);
+}
+
+ParseResult BatchReduceMatmulOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ SmallVector<Attribute, 3> indexingMapsAttr;
+ Attribute mapAttr;
+ if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
+ if (parser.parseEqual())
+ return failure();
+ if (parser.parseLSquare())
+ ...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/130944
More information about the Mlir-commits
mailing list