[Mlir-commits] [mlir] 95630b9 - [mlir][Linalg] Add MapOp to Linalg structured ops.

Adrian Kuegel llvmlistbot at llvm.org
Wed Oct 12 04:56:38 PDT 2022

Author: Adrian Kuegel
Date: 2022-10-12T13:56:21+02:00
New Revision: 95630b93b00302d3a9b0ea3219235259e5de2109

URL: https://github.com/llvm/llvm-project/commit/95630b93b00302d3a9b0ea3219235259e5de2109
DIFF: https://github.com/llvm/llvm-project/commit/95630b93b00302d3a9b0ea3219235259e5de2109.diff

LOG: [mlir][Linalg] Add MapOp to Linalg structured ops.

This allows modeling elementwise ops with this special op instead of using
linalg.generic.
Also allow MapOp and ReduceOp to have no result if the output type is not a tensor.
This is needed for buffer semantics.

Differential Revision: https://reviews.llvm.org/D135754




diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index c37d97834e838..6bcf509844cb8 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -225,12 +225,77 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [
-// Reduce op.
+// Map op.
 def TensorOrMemref :
   AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;
+def MapOp : LinalgStructuredBase_Op<"map", [
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>,
+    SingleBlockImplicitTerminator<"YieldOp">]> {
+  let summary = "Elementwise operations";
+  let description = [{
+    Models elementwise operations on tensors in terms of arithmetic operations
+    on the corresponding elements.
+    Example:
+    ```
+      %add = linalg.map
+          ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
+          outs(%init: tensor<64xf32>)
+          (%lhs_elem: f32, %rhs_elem: f32) {
+            %0 = arith.addf %lhs_elem, %rhs_elem: f32
+            linalg.yield %0: f32
+          }
+    ```
+  }];
+  let arguments = (ins
+    // Input args
+    Variadic<TensorOrMemref>:$inputs,
+    // Output arg
+    TensorOrMemref:$init
+  );
+  let results = (outs Variadic<AnyTensor>:$result);
+  let regions = (region SizedRegion<1>:$mapper);
+  let extraClassDeclaration = structuredOpsBaseDecls # [{
+    // Implement functions necessary for LinalgStructuredInterface.
+    ArrayAttr getIteratorTypes();
+    ArrayAttr getIndexingMaps();
+    std::string getLibraryCallName() {
+      return "op_has_no_registered_library_name";
+    }
+    // Implement functions necessary for DestinationStyleOpInterface.
+    unsigned getNumInputs() {
+      return this->getOperation()->getNumOperands() - getNumOutputs();
+    };
+    unsigned getNumOutputs() { return 1; };
+    mlir::ValueRange getOutputs() { return getOperands().take_back(1); }
+    linalg::OpOperandVector getOpOperandsMatchingBBargs() {
+      return getInputOperands();
+    }
+    static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
+                              mlir::ArrayRef<mlir::NamedAttribute>)>
+    getRegionBuilder() {
+      return nullptr;
+    }
+  }];
+  let hasCustomAssemblyFormat = 1;
+  let hasVerifier = 1;
+// Reduce op.
 def ReduceOp : LinalgStructuredBase_Op<"reduce", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>,
@@ -264,7 +329,7 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
-  let results = (outs Variadic<TensorOrMemref>);
+  let results = (outs Variadic<AnyTensor>);
   let regions = (region SizedRegion<1>:$combiner);
   let extraClassDeclaration = structuredOpsBaseDecls # [{

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 5dbec1ea80d5d..61b938601e103 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1288,6 +1288,135 @@ LogicalResult GenericOp::fold(ArrayRef<Attribute>,
   return foldMemRefCast(*this);
+// MapOp
+static ParseResult parseDstStyleOp(
+    OpAsmParser &parser, OperationState &result,
+    function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
+        nullptr) {
+  // Parse `ins` and `outs`.
+  SmallVector<Type, 4> inputTypes, outputTypes;
+  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
+                                   /*addOperandSegmentSizes=*/false))
+    return failure();
+  // Add result types.
+  for (Type outputType : outputTypes) {
+    if (outputType.isa<RankedTensorType>())
+      result.addTypes(outputType);
+  }
+  // Parse required attributes.
+  if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
+    return failure();
+  // Parse optional attributes.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+  return success();
+void MapOp::getAsmBlockArgumentNames(Region &region,
+                                     OpAsmSetValueNameFn setNameFn) {
+  for (Value v : getRegionInputArgs())
+    setNameFn(v, "in");
+void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
+  if (!getResults().empty())
+    setNameFn(getResults().front(), "mapped");
+ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
+  if (parseDstStyleOp(parser, result))
+    return failure();
+  SmallVector<OpAsmParser::Argument> regionArgs;
+  if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
+                               /*allowType=*/true, /*allowAttrs=*/true)) {
+    return failure();
+  }
+  Region *body = result.addRegion();
+  if (parser.parseRegion(*body, regionArgs))
+    return failure();
+  return success();
+void MapOp::print(OpAsmPrinter &p) {
+  printCommonStructuredOpParts(p, getInputs(), getOutputs());
+  p.printOptionalAttrDict((*this)->getAttrs());
+  p << "(";
+  llvm::interleaveComma(getMapper().getArguments(), p,
+                        [&](auto arg) { p.printRegionArgument(arg); });
+  p << ") ";
+  p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
+LogicalResult MapOp::verify() {
+  auto *bodyBlock = getBody();
+  auto blockArgs = bodyBlock->getArguments();
+  // Checks if the number of `inputs` match the arity of the `mapper` region.
+  if (getInputs().size() != blockArgs.size())
+    return emitOpError() << "expects number of operands to match the arity of "
+                            "mapper, but got: "
+                         << getInputs().size() << " and " << blockArgs.size();
+  // The parameters of mapper should all match the element type of inputs.
+  for (const auto &[bbArgType, inputArg] :
+       llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
+    auto inputElemType = inputArg.getType().cast<ShapedType>().getElementType();
+    if (bbArgType != inputElemType) {
+      return emitOpError() << "expected element type of input " << inputElemType
+                           << " to match bbArg type " << bbArgType;
+    }
+  }
+  // The shape of each input must match the shape of the output.
+  auto outputShape =
+      getOutputs().front().getType().cast<ShapedType>().getShape();
+  for (Type inputArgType : TypeRange{getInputs()}) {
+    auto inputElemShape = inputArgType.cast<ShapedType>().getShape();
+    if (inputElemShape != outputShape) {
+      return emitOpError() << "expected shape of input (" << inputElemShape
+                           << ") to match shape of output (" << outputShape
+                           << ")";
+    }
+  }
+  return success();
+ArrayAttr MapOp::getIteratorTypes() {
+  int64_t rank = getInit().getType().getRank();
+  return Builder(getContext())
+      .getStrArrayAttr(
+          SmallVector<StringRef>(rank, getParallelIteratorTypeName()));
+ArrayAttr MapOp::getIndexingMaps() {
+  Builder builder(getContext());
+  int64_t rank = getInit().getType().getRank();
+  int64_t numIndexingMaps = getOperands().size();
+  return builder.getAffineMapArrayAttr(SmallVector<AffineMap>(
+      numIndexingMaps, builder.getMultiDimIdentityMap(rank)));
+void MapOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  SmallVector<Value> inputBuffers = getInputBufferOperands();
+  SmallVector<Value> outputBuffers = getOutputBufferOperands();
+  getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers,
+                        outputBuffers);
 // ReduceOp
@@ -1302,7 +1431,8 @@ void ReduceOp::getAsmBlockArgumentNames(Region &region,
 void ReduceOp::getAsmResultNames(
     function_ref<void(Value, StringRef)> setNameFn) {
-  setNameFn(getResults().front(), "reduced");
+  if (!getResults().empty())
+    setNameFn(getResults().front(), "reduced");
 ArrayAttr ReduceOp::getIteratorTypes() {
@@ -1336,33 +1466,6 @@ void ReduceOp::getEffects(
-static ParseResult parseDstStyleOp(
-    OpAsmParser &parser, OperationState &result,
-    function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
-        nullptr) {
-  // Parse `ins` and `outs`.
-  SmallVector<Type, 4> inputTypes, outputTypes;
-  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
-                                   /*addOperandSegmentSizes=*/false))
-    return failure();
-  // Add result types.
-  for (Type outputType : outputTypes) {
-    if (!outputType.isa<RankedTensorType>())
-      return failure();
-    result.addTypes(outputType);
-  }
-  // Parse required attributes.
-  if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
-    return failure();
-  // Parse optional attributes.
-  if (parser.parseOptionalAttrDict(result.attributes))
-    return failure();
-  return success();
 static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
                                           NamedAttrList &attributes,
                                           StringRef attributeName) {

diff  --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 9ae761e639e41..00352c43bb07e 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -391,6 +391,70 @@ func.func @invalid_reverse(%A: memref<5xf32>, %B: memref<5xf32>) {
 // -----
+func.func @map_binary_wrong_yield_operands(
+    %lhs: tensor<64xf32>, %rhs: tensor<64xf32>, %init: tensor<64xf32>)
+    -> tensor<64xf32> {
+   %add = linalg.map
+          ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
+          outs(%init:tensor<64xf32>)
+          (%lhs_elem: f32, %rhs_elem: f32) {
+            %0 = arith.addf %lhs_elem, %rhs_elem: f32
+            // expected-error @+1{{'linalg.yield' op expected number of yield values (1) to match the number of operands of the enclosing LinalgOp (2)}}
+            linalg.yield %0, %0: f32, f32
+          }
+  func.return %add : tensor<64xf32>
+// -----
+func.func @map_input_mapper_arity_mismatch(
+    %lhs: tensor<64xf32>, %rhs: tensor<64xf32>, %init: tensor<64xf32>)
+    -> tensor<64xf32> {
+  // expected-error@+1{{'linalg.map' op expects number of operands to match the arity of mapper, but got: 2 and 3}}
+  %add = linalg.map
+      ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
+      outs(%init:tensor<64xf32>)
+      (%lhs_elem: f32, %rhs_elem: f32, %extra_elem: f32) {
+        %0 = arith.addf %lhs_elem, %rhs_elem: f32
+        linalg.yield %0: f32
+      }
+  func.return %add : tensor<64xf32>
+// -----
+func.func @map_input_mapper_type_mismatch(
+    %lhs: tensor<64xf32>, %rhs: tensor<64xf32>, %init: tensor<64xf32>)
+    -> tensor<64xf32> {
+    // expected-error@+1{{'linalg.map' op expected element type of input 'f32' to match bbArg type 'f64'}}
+  %add = linalg.map
+      ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
+      outs(%init:tensor<64xf32>)
+      (%lhs_elem: f64, %rhs_elem: f64) {
+        %0 = arith.addf %lhs_elem, %rhs_elem: f64
+        linalg.yield %0: f64
+      }
+  func.return %add : tensor<64xf32>
+// -----
+func.func @map_input_output_shape_mismatch(
+    %lhs: tensor<64x64xf32>, %rhs: tensor<64x64xf32>, %init: tensor<32xf32>)
+    -> tensor<32xf32> {
+    // expected-error@+1{{'linalg.map' op expected shape of input (64, 64) to match shape of output (32)}}
+  %add = linalg.map
+      ins(%lhs, %rhs : tensor<64x64xf32>, tensor<64x64xf32>)
+      outs(%init:tensor<32xf32>)
+      (%lhs_elem: f32, %rhs_elem: f32) {
+        %0 = arith.addf %lhs_elem, %rhs_elem: f32
+        linalg.yield %0: f32
+      }
+  func.return %add : tensor<32xf32>
+// -----
 func.func @reduce_input_vs_init_dimension_mismatch(
     %input: tensor<16x32x64xf32>,
     %init: tensor<16x64xf32>)  -> tensor<16x64xf32> {

diff  --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 3fb6c3d74e8eb..02471b17a2fbe 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -354,8 +354,70 @@ func.func @mixed_parallel_reduced_results(%arg0 : tensor<?x?x?xf32>,
 // -----
+func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
+                      %init: tensor<64xf32>) -> tensor<64xf32> {
+   %add = linalg.map
+          ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
+          outs(%init:tensor<64xf32>)
+          (%lhs_elem: f32, %rhs_elem: f32) {
+            %0 = arith.addf %lhs_elem, %rhs_elem: f32
+            linalg.yield %0: f32
+          }
+  func.return %add : tensor<64xf32>
+// CHECK-LABEL: func @map_binary
+//       CHECK:     linalg.map
+// -----
+func.func @map_binary_memref(%lhs: memref<64xf32>, %rhs: memref<64xf32>,
+                      %init: memref<64xf32>) {
+   linalg.map
+      ins(%lhs, %rhs: memref<64xf32>, memref<64xf32>)
+      outs(%init:memref<64xf32>)
+      (%lhs_elem: f32, %rhs_elem: f32) {
+        %0 = arith.addf %lhs_elem, %rhs_elem: f32
+        linalg.yield %0: f32
+      }
+  func.return
+// CHECK-LABEL: func @map_binary_memref
+//       CHECK:     linalg.map
+// -----
+func.func @map_unary(%input: tensor<64xf32>, %init: tensor<64xf32>) -> tensor<64xf32> {
+   %abs = linalg.map
+          ins(%input:tensor<64xf32>)
+          outs(%init:tensor<64xf32>)
+          (%input_elem: f32) {
+            %0 = math.absf %input_elem: f32
+            linalg.yield %0: f32
+          }
+  func.return %abs : tensor<64xf32>
+// CHECK-LABEL: func @map_unary
+//       CHECK:     linalg.map
+// -----
+func.func @map_unary_memref(%input: memref<64xf32>, %init: memref<64xf32>) {
+   linalg.map
+      ins(%input:memref<64xf32>)
+      outs(%init:memref<64xf32>)
+      (%input_elem: f32) {
+        %0 = math.absf %input_elem: f32
+        linalg.yield %0: f32
+      }
+  func.return
+// CHECK-LABEL: func @map_unary_memref
+//       CHECK:     linalg.map
+// -----
 func.func @reduce(%input: tensor<16x32x64xf32>,
-                     %init: tensor<16x64xf32>) -> tensor<16x64xf32> {
+                  %init: tensor<16x64xf32>) -> tensor<16x64xf32> {
   %reduce = linalg.reduce
@@ -371,6 +433,23 @@ func.func @reduce(%input: tensor<16x32x64xf32>,
 // -----
+func.func @reduce_memref(%input: memref<16x32x64xf32>,
+                         %init: memref<16x64xf32>) {
+  linalg.reduce
+      ins(%input:memref<16x32x64xf32>)
+      outs(%init:memref<16x64xf32>)
+      dimensions = [1]
+      (%in: f32, %out: f32) {
+        %0 = arith.addf %in, %out: f32
+        linalg.yield %0: f32
+      }
+  func.return
+// CHECK-LABEL: func @reduce_memref
+//       CHECK:     linalg.reduce
+// -----
 func.func @variadic_reduce(%input1: tensor<16x32x64xf32>,
     %init1: tensor<16x64xf32>, %input2: tensor<16x32x64xi64>,
     %init2: tensor<16x64xi64>)  -> (tensor<16x64xf32>, tensor<16x64xi64>) {
@@ -387,3 +466,22 @@ func.func @variadic_reduce(%input1: tensor<16x32x64xf32>,
 // CHECK-LABEL: func @variadic_reduce
 //       CHECK:     linalg.reduce
+// -----
+func.func @variadic_reduce_memref(%input1: memref<16x32x64xf32>,
+    %init1: memref<16x64xf32>, %input2: memref<16x32x64xi64>,
+    %init2: memref<16x64xi64>) {
+  linalg.reduce
+      ins(%input1, %input2 : memref<16x32x64xf32>, memref<16x32x64xi64>)
+      outs(%init1, %init2 : memref<16x64xf32>, memref<16x64xi64>)
+      dimensions = [1]
+      (%in1: f32, %in2: i64, %out1: f32, %out2: i64) {
+        %0 = arith.addf %in1, %out1: f32
+        %1 = arith.addi %in2, %out2: i64
+        linalg.yield %0, %1: f32, i64
+      }
+  func.return
+// CHECK-LABEL: func @variadic_reduce_memref
+//       CHECK:     linalg.reduce


More information about the Mlir-commits mailing list