[Mlir-commits] [mlir] 9be8219 - [mlir][Linalg] Add an interface to decompose complex ops

Quentin Colombet llvmlistbot at llvm.org
Tue Jul 18 10:09:06 PDT 2023


Author: Quentin Colombet
Date: 2023-07-18T19:06:36+02:00
New Revision: 9be8219f60e1bbddaebdb271a55ecfb867078899

URL: https://github.com/llvm/llvm-project/commit/9be8219f60e1bbddaebdb271a55ecfb867078899
DIFF: https://github.com/llvm/llvm-project/commit/9be8219f60e1bbddaebdb271a55ecfb867078899.diff

LOG: [mlir][Linalg] Add an interface to decompose complex ops

This patch adds an interface, named AggregatedOpInterface, that decomposes
complex operations into simpler ones.

For now, make the interface specific to Linalg because although the concept
is general, the way to materialize it needs some maturing.

Use that interface with the softmax operator.

Differential Revision: https://reviews.llvm.org/D154363

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/test/Dialect/Linalg/transform-op-decompose.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 1aba722773511a..edc393b92c4c35 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -897,4 +897,34 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
   let verifyWithRegions = 1;
 }
 
+def AggregatedOpInterface : OpInterface<"AggregatedOpInterface"> {
+  let description = [{
+    Interface for decomposing aggregated operations into a sequence of simpler
+    ops.
+  }];
+  let cppNamespace = "::mlir";
+  let methods = [
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to decompose the operation into simpler operations.
+
+          On success, this method returns one `Value` per result in the
+          original operation.
+          The order of the returned values must match the order of the
+          original values.
+          In other words, the returned vector can be used directly with
+          `RewriterBase::replaceOp(this, returnedValues)`.
+
+          The default implementation does not decompose anything and
+          returns failure.
+        }],
+        /*retType=*/"FailureOr<SmallVector<Value>>",
+        /*methodName=*/"decomposeOperation",
+        /*args=*/(ins
+            "OpBuilder &":$b),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return failure();
+        }]
+      >
+  ];
+}
+
 #endif // LINALG_IR_LINALGINTERFACES

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index eb68890c8487dc..fdb043000ae6c5 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -14,6 +14,7 @@
 #define LINALG_OPS
 
 include "mlir/Dialect/Linalg/IR/LinalgBase.td"
+include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/DestinationStyleOpInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -93,6 +94,7 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
     [DestinationStyleOpInterface,
      PredOpTrait<"input and output have same element type", TCopVTEtIsSameAs<0, 1>>,
      DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+     DeclareOpInterfaceMethods<AggregatedOpInterface, ["decomposeOperation"]>,
      DeclareOpInterfaceMethods<TilingInterface,
       ["getIterationDomain",
        "getLoopIteratorTypes",

diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index f1510f63abdbcd..5d1f2d8f283012 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1199,6 +1199,33 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// DecomposeInterfaceOp
+//===----------------------------------------------------------------------===//
+
+def DecomposeInterfaceOp : Op<Transform_Dialect, "structured.decompose_interface",
+    [FunctionalStyleTransformOpTrait,
+     MemoryEffectsOpInterface,
+     TransformOpInterface,
+     TransformEachOpTrait,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Decomposes the target operation into a sequence of simpler operations
+    by calling `AggregatedOpInterface::decomposeOperation` on it, and
+    replaces the target with the values that the decomposition returns.
+
+    #### Return modes
+
+    This operation produces a silenceable failure if the target does not
+    implement `AggregatedOpInterface` or if its decomposition fails.
+    Otherwise, the returned handle points to the operations defining the
+    values that replace the results of the target.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+  let assemblyFormat =
+      "$target attr-dict `:` functional-type(operands, results)";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
 //===----------------------------------------------------------------------===//
 // RewriteInDestinationPassingStyleOp.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index a40618fd4ca6d7..464a30a85f7105 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2323,6 +2323,176 @@ SoftmaxOp::reifyResultShapes(OpBuilder &b,
       .reifyResultShapes(b, reifiedReturnShapes);
 }
 
+// Helper functions for softmax decomposition.
+// @{
+
+// Helper function to produce the iterator types (reduction or parallel) and
+// affine maps for the iterators used in the decomposition of softmax.
+// This method creates:
+// If allParallel == true:
+// - iterator type: {parallel, ..., parallel}
+// - affine maps:
+// -- identity with inputRank dimensions.
+// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
+//    where N == inputRank - 1.
+//
+// If allParallel == false:
+// - iterator type at dim(i) == parallel for i != \p dim and
+//   dim(dim) == reduction.
+// - affine maps:
+// -- identity with inputRank dimensions.
+// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
+//    where N == inputRank - 1.
+static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
+computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank,
+                                    int64_t dim, bool allParallel = false) {
+  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
+                                                 utils::IteratorType::parallel);
+  if (!allParallel)
+    iteratorTypes[dim] = utils::IteratorType::reduction;
+  MLIRContext *ctxt = builder.getContext();
+  auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt);
+  SmallVector<AffineExpr, 2> affineExprs;
+  for (int i = 0; i < inputRank; i++) {
+    if (i != dim)
+      affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
+  }
+  auto reductionMap =
+      AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt);
+  SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
+  return std::make_tuple(iteratorTypes, indexingMaps);
+}
+
+// Helper function to produce a linalg.generic that reduces \p input along
+// dimension \p dim with op type \p T, using \p output as the init operand.
+template <typename T>
+static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
+                    int64_t dim) {
+  auto inputType = cast<ShapedType>(input.getType());
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+  int64_t inputRank = inputShape.size();
+  auto [iteratorTypes, indexingMaps] =
+      computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
+  assert(indexingMaps.size() == 2 &&
+         "We should have two maps: 1 for the input, 1 for the output");
+  assert(indexingMaps[0].isIdentity() && "input map should be identity");
+
+  auto genericOp = builder.create<linalg::GenericOp>(
+      loc, output.getType(), input, output, indexingMaps, iteratorTypes,
+      [&](OpBuilder &b, Location loc, ValueRange args) {
+        Value result = b.create<T>(loc, args[0], args[1]);
+        b.create<linalg::YieldOp>(loc, result);
+      });
+  return genericOp.getResult(0);
+}
+
+/// Produce a linalg generic that computes the second step of the softmax
+/// decomposition: res = exp(input - max), where \p max is the max of \p input
+/// on dimension \p dim.
+static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
+                              Value max, Value output, int64_t dim) {
+  auto inputType = cast<ShapedType>(input.getType());
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+  int64_t inputRank = inputShape.size();
+  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
+      builder, inputRank, dim, /*allParallel=*/true);
+  assert(indexingMaps.size() == 2 && "We should have one map for each input");
+  assert(indexingMaps[0].isIdentity() && "input map should be identity");
+  // The output is indexed like the input: reuse the identity map for it.
+  indexingMaps.push_back(indexingMaps[0]);
+  auto genericOp = builder.create<linalg::GenericOp>(
+      loc, input.getType(), ValueRange{input, max}, output, indexingMaps,
+      iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
+        Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
+        Value result = b.create<math::ExpOp>(loc, diff);
+        b.create<linalg::YieldOp>(loc, result);
+      });
+  return genericOp.getResult(0);
+}
+
+/// Produce a linalg generic that computes the final step of the softmax
+/// decomposition: the element-wise division of \p numerator by \p denominator.
+/// \returns  linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
+///   yield  n / d
+/// }
+static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
+                        Value denominator, Value output, int64_t dim) {
+  auto inputType = cast<ShapedType>(numerator.getType());
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+  int64_t inputRank = inputShape.size();
+  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
+      builder, inputRank, dim, /*allParallel=*/true);
+  assert(indexingMaps.size() == 2 &&
+         "We should have one map for each input (2)");
+  assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
+  // Add the affine map for the output tensor.
+  indexingMaps.push_back(indexingMaps[0]);
+  auto genericOp = builder.create<linalg::GenericOp>(
+      loc, numerator.getType(), ValueRange{numerator, denominator}, output,
+      indexingMaps, iteratorTypes,
+      [&](OpBuilder &b, Location loc, ValueRange args) {
+        Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);
+        b.create<linalg::YieldOp>(loc, result);
+      });
+  return genericOp.getResult(0);
+}
+// @} End helper functions for softmax decomposition.
+
+/// Given an N-dimensional tensor x, this method converts
+/// softmax(x) to the following sequence of operations:
+///
+/// 1. Compute the max of x along dimension d. This results
+///    in an (N-1)-dimensional tensor m.
+///    m = max(x, dim = d)
+///
+/// 2. Subtract a broadcasted m from x and exponentiate. This results in
+///    an N-dimensional tensor z.
+///    z = exp(x - m)
+///
+/// 3. Compute the sum of z along dimension d. This results in
+///    an (N-1)-dimensional tensor l.
+///    l = sum(z, dim = d)
+///
+/// 4. Divide z by a broadcasted l. This gives the N-dimensional softmax.
+///    softmax = z / l
+///
+FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
+  OpBuilder::InsertionGuard guard(b);
+  b.setInsertionPoint(*this);
+  Location loc = getLoc();
+  Value input = getInput();
+  ShapedType inputType = getInputOperandType();
+  Type elementType = inputType.getElementType();
+  int64_t reductionDim = getDimension();
+  SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
+  Value outputNd = b.create<tensor::EmptyOp>(loc, dims, elementType);
+  dims.erase(dims.begin() + reductionDim);
+  // Step 1: Compute max along dim.
+  Value output = b.create<tensor::EmptyOp>(loc, dims, elementType);
+  Value neutralForMaxF =
+      arith::getIdentityValue(arith::AtomicRMWKind::maxf, elementType, b, loc);
+  Value neutralForMaxFInit =
+      b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, output).result();
+  Value max =
+      reduce<arith::MaxFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
+
+  // Step 2: Subtract max from input and exponentiate.
+  Value numerator =
+      buildSubAndExpOp(b, loc, input, max, outputNd, reductionDim);
+
+  // Step 3: Compute sum along dim.
+  Value zero =
+      arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType, b, loc);
+  Value zeroInit = b.create<linalg::FillOp>(loc, Value{zero}, output).result();
+  Value denominator =
+      reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
+
+  // Step 4: Compute softmax.
+  Value result =
+      buildDivOp(b, loc, numerator, denominator, outputNd, reductionDim);
+  return SmallVector<Value>{result};
+}
+
 //===----------------------------------------------------------------------===//
 // LinalgDialect
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index a51050b742bfe0..f6e0f27548a20c 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -335,6 +335,38 @@ transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter,
   return emitDefaultSilenceableFailure(target);
 }
 
+//===----------------------------------------------------------------------===//
+// DecomposeInterfaceOp
+//===----------------------------------------------------------------------===//
+
+// Decompose the target operation if it implements the AggregatedOpInterface.
+// Push the decomposed operations (the ones that replace the values produced by
+// \p target) in the `results`.
+DiagnosedSilenceableFailure transform::DecomposeInterfaceOp::applyToOne(
+    transform::TransformRewriter &rewriter, Operation *target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
+  if (!decomposableOp) {
+    failed(rewriter.notifyMatchFailure(target,
+                                       "payload is not a decomposable op"));
+    return emitDefaultSilenceableFailure(target);
+  }
+
+  FailureOr<SmallVector<Value>> maybeNewResults =
+      decomposableOp.decomposeOperation(rewriter);
+  if (failed(maybeNewResults))
+    return emitDefaultSilenceableFailure(target);
+
+  rewriter.replaceOp(decomposableOp, *maybeNewResults);
+  for (Value val : *maybeNewResults) {
+    Operation *definition = val.getDefiningOp();
+    if (definition)
+      results.push_back(definition);
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // EliminateLinalgOpAnchoredEmptyTensorsOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
index 052992fa8711b8..30c2ab80546ce7 100644
--- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
@@ -1,5 +1,8 @@
 // RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file %s | FileCheck %s
 
+// CHECK-DAG:  #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG:  #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
 // CHECK-LABEL: @conv_2d_nhwc_hwcf
 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
 // CHECK-SAME: %[[ARG1:.+]]: tensor<1x?x?x?xf32>
@@ -199,8 +202,54 @@ func.func @pooling_nchw_max(%input: tensor<?x?x1x?xf32>, %filter: tensor<1x?xf32
   return %0 : tensor<?x?x1x?xf32>
 }
 
+func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+  %1 = linalg.softmax dimension(2) ins(%arg0 : tensor<2x16x32xf32>) outs(%dst: tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
+  return %1 : tensor<2x16x32xf32>
+}
+
+// CHECK-LABEL:      func.func @softmax(
+//CHECK-SAME:           %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>, %[[DST:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+// CHECK-DAG:        %[[D0:.+]] = tensor.empty() : tensor<2x16x32xf32>
+// CHECK-DAG:        %[[D1:.+]] = tensor.empty() : tensor<2x16xf32>
+// CHECK-DAG:        %[[CST:.+]] = arith.constant 0xFF800000 : f32
+// CHECK:        %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<2x16xf32>) -> tensor<2x16xf32>
+// CHECK:        %[[D3:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel",
+// CHECK-SAME:     "parallel", "reduction"]} ins(%[[ARG0]] : tensor<2x16x32xf32>) outs(%[[D2]] : tensor<2x16xf32>) {
+// CHECK:        ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK:          %[[D8:.+]] = arith.maxf %[[IN]], %[[OUT]] : f32
+// CHECK:          linalg.yield %[[D8]] : f32
+// CHECK:        } -> tensor<2x16xf32>
+// CHECK:        %[[D4:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP]]], iterator_types =
+// CHECK-SAME:     ["parallel", "parallel", "parallel"]} ins(%[[ARG0]], %[[D3]] : tensor<2x16x32xf32>, tensor<2x16xf32>)
+// CHECK-SAME:     outs(%[[D0]] : tensor<2x16x32xf32>) {
+// CHECK:        ^bb0(%[[IN:.+]]: f32, %[[IN_1:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK:          %[[D8]] = arith.subf %[[IN]], %[[IN_1]] : f32
+// CHECK:          %[[D9:.+]] = math.exp %[[D8]] : f32
+// CHECK:          linalg.yield %[[D9]] : f32
+// CHECK:        } -> tensor<2x16x32xf32>
+// CHECK:        %[[CST_0:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK:        %[[D5:.+]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[D1]] : tensor<2x16xf32>) -> tensor<2x16xf32>
+// CHECK:        %[[D6:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel",
+// CHECK-SAME:     "parallel", "reduction"]} ins(%[[D4]] : tensor<2x16x32xf32>) outs(%[[D5]] : tensor<2x16xf32>) {
+// CHECK:        ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK:          %[[D8]] = arith.addf %[[IN]], %[[OUT]] : f32
+// CHECK:          linalg.yield %[[D8]] : f32
+// CHECK:        } -> tensor<2x16xf32>
+// CHECK:        %[[D7:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP]]], iterator_types =
+// CHECK-SAME:     ["parallel", "parallel", "parallel"]} ins(%[[D4]], %[[D6]] : tensor<2x16x32xf32>, tensor<2x16xf32>)
+// CHECK-SAME:     outs(%[[D0]] : tensor<2x16x32xf32>) {
+// CHECK:        ^bb0(%[[IN:.+]]: f32, %[[IN_1:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK:          %[[D8]] = arith.divf %[[IN]], %[[IN_1]] : f32
+// CHECK:          linalg.yield %[[D8]] : f32
+// CHECK:        } -> tensor<2x16x32xf32>
+// CHECK:        return %[[D7]] : tensor<2x16x32xf32>
+// CHECK:      }
+
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
   %1 = transform.structured.decompose %0 : (!transform.any_op) -> !transform.any_op
+
+  %2 = transform.structured.match ops{["linalg.softmax"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %3 = transform.structured.decompose_interface %2 : (!transform.any_op) -> !transform.any_op
 }


        


More information about the Mlir-commits mailing list