[Mlir-commits] [mlir] 273600c - [mlir][Linalg] Implement the tiling interface for softmax

Quentin Colombet llvmlistbot at llvm.org
Mon Jul 3 01:49:40 PDT 2023


Author: Quentin Colombet
Date: 2023-07-03T10:49:16+02:00
New Revision: 273600ccea914b1178df70dfbe8aa28d4ddea835

URL: https://github.com/llvm/llvm-project/commit/273600ccea914b1178df70dfbe8aa28d4ddea835
DIFF: https://github.com/llvm/llvm-project/commit/273600ccea914b1178df70dfbe8aa28d4ddea835.diff

LOG: [mlir][Linalg] Implement the tiling interface for softmax

Original implementation from Harsh Menon <harsh at nod-labs.com>.

Differential Revision: https://reviews.llvm.org/D153421
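
As a quick illustration of what the new interface enables, tiling the softmax
in the new tile-softmax.mlir test with transform.structured.tile and sizes
[2, 3] rewrites it into an scf.for nest that extracts a slice of the input and
of the init operand, runs linalg.softmax on the slice, and inserts the result
back. A condensed sketch of the resulting IR (value names simplified, taken
from the test's CHECK lines):

  %result = scf.for %i = %c0 to %c16 step %c2 iter_args(%acc0 = %init) -> (tensor<16x64x256xf32>) {
    %inner = scf.for %j = %c0 to %c64 step %c3 iter_args(%acc = %acc0) -> (tensor<16x64x256xf32>) {
      // 3 does not divide 64, so the size along that dimension is dynamic.
      %sz = affine.min affine_map<(d0) -> (-d0 + 64, 3)>(%j)
      %in = tensor.extract_slice %arg0[%i, %j, 0] [2, %sz, 256] [1, 1, 1]
              : tensor<16x64x256xf32> to tensor<2x?x256xf32>
      %out = tensor.extract_slice %acc[%i, %j, 0] [2, %sz, 256] [1, 1, 1]
               : tensor<16x64x256xf32> to tensor<2x?x256xf32>
      %sm = linalg.softmax dimension(1)
              ins(%in : tensor<2x?x256xf32>)
              outs(%out : tensor<2x?x256xf32>) -> tensor<2x?x256xf32>
      %upd = tensor.insert_slice %sm into %acc[%i, %j, 0] [2, %sz, 256] [1, 1, 1]
               : tensor<2x?x256xf32> into tensor<16x64x256xf32>
      scf.yield %upd : tensor<16x64x256xf32>
    }
    scf.yield %inner : tensor<16x64x256xf32>
  }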

Added: 
    mlir/test/Dialect/Linalg/tile-softmax.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 43b86cda281e75..eb68890c8487dc 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -19,6 +19,7 @@ include "mlir/Interfaces/DestinationStyleOpInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/LoopLikeInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/TilingInterface.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
 
 // Base class for Linalg dialect ops that do not correspond to library calls.
@@ -91,7 +92,12 @@ def Linalg_IndexOp : Linalg_Op<"index", [Pure]>,
 def Linalg_SoftmaxOp : Linalg_Op<"softmax",
     [DestinationStyleOpInterface,
      PredOpTrait<"input and output have same element type", TCopVTEtIsSameAs<0, 1>>,
-     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+     DeclareOpInterfaceMethods<TilingInterface,
+      ["getIterationDomain",
+       "getLoopIteratorTypes",
+       "getResultTilePosition",
+       "getTiledImplementation"]>]> {
   let summary = "Softmax operator";
   let description = [{
     linalg.softmax computes a numerically stable version of softmax.
@@ -104,6 +110,12 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
 
     This is an aggregate linalg operation that further reduces to a small DAG of
     structured operations.
+
+    Warning: Regarding the tiling capabilities, the implementation doesn't
+    check that the provided dimensions make sense. It is the responsibility
+    of the transformation calling the tiling to ensure that the provided
+    sizes for each dimension make sense with respect to the semantics of
+    softmax.
   }];
 
   let arguments = (ins AnyShaped:$input,

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 1c9a78224902a0..0bda7dbac1bd8d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
@@ -46,6 +47,41 @@
 using namespace mlir;
 using namespace mlir::linalg;
 
+/// Return a `memref.dim` or `tensor.dim` for the shape of `v` at `dim`.
+static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v,
+                                int64_t dim) {
+  auto type = cast<ShapedType>(v.getType());
+  if (!type.isDynamicDim(dim))
+    return builder.getIndexAttr(type.getDimSize(dim));
+
+  return getAsOpFoldResult(
+      TypeSwitch<Type, Value>(v.getType())
+          .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
+            return builder.create<tensor::DimOp>(loc, v, dim);
+          })
+          .Case<MemRefType>([&](MemRefType t) -> Value {
+            return builder.create<memref::DimOp>(loc, v, dim);
+          }));
+}
+
+/// Returns a memref.subview or a tensor.extract_slice based on the type of the
+/// `source`.
+static Value getSlice(OpBuilder &b, Location loc, Value source,
+                      ArrayRef<OpFoldResult> offsets,
+                      ArrayRef<OpFoldResult> sizes,
+                      ArrayRef<OpFoldResult> strides) {
+  return TypeSwitch<Type, Value>(source.getType())
+      .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
+        return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
+                                                strides);
+      })
+      .Case<MemRefType>([&](MemRefType type) -> Value {
+        return b.create<memref::SubViewOp>(loc, source, offsets, sizes,
+                                           strides);
+      })
+      .Default([&](Type t) { return nullptr; });
+}
+
 //===----------------------------------------------------------------------===//
 // Helper functions
 //===----------------------------------------------------------------------===//
@@ -2182,6 +2218,62 @@ LogicalResult SoftmaxOp::verify() {
   return success();
 }
 
+SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
+  int64_t operandRank = getInputOperandRank();
+  SmallVector<Range> loopBounds(operandRank);
+  Location loc = getLoc();
+  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
+  Value source = getInput();
+  for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
+    loopBounds[dim].offset = zero;
+    loopBounds[dim].size = getDimValue(builder, loc, source, dim);
+    loopBounds[dim].stride = one;
+  }
+  return loopBounds;
+}
+
+SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
+  SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
+                                                 utils::IteratorType::parallel);
+  iteratorTypes[getDimension()] = utils::IteratorType::reduction;
+  return iteratorTypes;
+}
+
+FailureOr<TilingResult>
+SoftmaxOp::getTiledImplementation(OpBuilder &builder,
+                                  ArrayRef<OpFoldResult> offsets,
+                                  ArrayRef<OpFoldResult> sizes) {
+  int64_t rank = getInputOperandRank();
+  auto oneAttr = builder.getI64IntegerAttr(1);
+  SmallVector<OpFoldResult> strides(rank, oneAttr);
+  SmallVector<Value> tiledOperands;
+  tiledOperands.emplace_back(
+      getSlice(builder, getLoc(), getInput(), offsets, sizes, strides));
+  tiledOperands.emplace_back(
+      getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides));
+
+  SmallVector<Type, 4> resultTypes;
+  if (hasTensorSemantics())
+    resultTypes.push_back(tiledOperands[1].getType());
+  Operation *tiledOp =
+      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
+
+  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+}
+
+LogicalResult SoftmaxOp::getResultTilePosition(
+    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+    SmallVector<OpFoldResult> &resultSizes) {
+  if (resultNumber == 0) {
+    resultOffsets.assign(offsets.begin(), offsets.end());
+    resultSizes.assign(sizes.begin(), sizes.end());
+    return success();
+  }
+  return failure();
+}
+
 // cast(dynamic) -> static.
 LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
   return memref::foldMemRefCast(*this);

diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 2f383473a7a192..781e48a9824a1c 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -555,7 +555,13 @@ tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
   auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
       rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
       cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
-  assert(succeeded(maybeRankReduced) && "unexpected shape");
+  if (failed(maybeRankReduced)) {
+    diag.attachNote(producerOp->getLoc())
+        << "shape types don't match (missing canonicalization?):\nTiledOp: "
+        << tileAndFuseResult->tiledValues[0]
+        << "\nSliceOp: " << sliceOpToTile.getOperation() << '\n';
+    return {};
+  }
   rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
 
   // Add new outputs to containing op, if required

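When the tiling interface composes with transform.structured.tile_to_forall_op
and transform.structured.fuse_into_containing_op (the path exercised by the
softmax_tile_n_fuse test below), the softmax and its fused producer end up
inside an scf.forall. A condensed sketch of the test's expected IR (value
names simplified, the fused linalg.generic elided):

  %result = scf.forall (%i, %j) in (8, 16) shared_outs(%out = %init) -> (tensor<16x64x256xf32>) {
    %off0 = affine.apply affine_map<(d0) -> (d0 * 2)>(%i)
    %off1 = affine.apply affine_map<(d0) -> (d0 * 4)>(%j)
    %in_slice = tensor.extract_slice %arg0[%off0, %off1, 0] [2, 4, 256] [1, 1, 1]
                  : tensor<16x64x256xf32> to tensor<2x4x256xf32>
    // ... the fused producer (linalg.generic) computes %fused from %in_slice ...
    %out_slice = tensor.extract_slice %out[%off0, %off1, 0] [2, 4, 256] [1, 1, 1]
                   : tensor<16x64x256xf32> to tensor<2x4x256xf32>
    %sm = linalg.softmax dimension(1)
            ins(%fused : tensor<2x4x256xf32>)
            outs(%out_slice : tensor<2x4x256xf32>) -> tensor<2x4x256xf32>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %sm into %out[%off0, %off1, 0] [2, 4, 256] [1, 1, 1]
        : tensor<2x4x256xf32> into tensor<16x64x256xf32>
    }
  }
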
diff --git a/mlir/test/Dialect/Linalg/tile-softmax.mlir b/mlir/test/Dialect/Linalg/tile-softmax.mlir
new file mode 100644
index 00000000000000..f7f15f008553a0
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/tile-softmax.mlir
@@ -0,0 +1,149 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -canonicalize --split-input-file | FileCheck %s
+
+// Check that we can tile softmax on tensors.
+// The tiling here is 2x3.
+// So the shape used in the inner loop should be 2x3x256; however, since 3
+// doesn't divide the second dimension (64), we should see a '?' in the shape.
+// The actual size, used through extract_slice/insert_slice, should come from
+// `min(64 - current iteration index, 3)`.
+
+// CHECK: #[[$MIN_MAP:.*]] = affine_map<(d0) -> (-d0 + 64, 3)>
+// CHECK-LABEL:   func.func @softmax(
+// CHECK-SAME:                       %[[VAL_0:.*]]: tensor<16x64x256xf32>) -> tensor<16x64x256xf32> {
+// CHECK-DAG:       %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG:       %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:       %[[C64:.*]] = arith.constant 64 : index
+// CHECK-DAG:       %[[C16:.*]] = arith.constant 16 : index
+// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[TENSOR_EMPTY:.*]] = tensor.empty() : tensor<16x64x256xf32>
+// CHECK:           %[[VAL_7:.*]] = scf.for %[[VAL_8:.*]] = %[[C0]] to %[[C16]] step %[[C2]] iter_args(%[[VAL_9:.*]] = %[[TENSOR_EMPTY]]) -> (tensor<16x64x256xf32>) {
+// CHECK:             %[[VAL_10:.*]] = scf.for %[[VAL_11:.*]] = %[[C0]] to %[[C64]] step %[[C3]] iter_args(%[[VAL_12:.*]] = %[[VAL_9]]) -> (tensor<16x64x256xf32>) {
+// CHECK:               %[[VAL_13:.*]] = affine.min #[[$MIN_MAP]](%[[VAL_11]])
+// CHECK:               %[[VAL_14:.*]] = tensor.extract_slice %[[VAL_0]]{{\[}}%[[VAL_8]], %[[VAL_11]], 0] [2, %[[VAL_13]], 256] [1, 1, 1] : tensor<16x64x256xf32> to tensor<2x?x256xf32>
+// CHECK:               %[[VAL_15:.*]] = tensor.extract_slice %[[VAL_12]]{{\[}}%[[VAL_8]], %[[VAL_11]], 0] [2, %[[VAL_13]], 256] [1, 1, 1] : tensor<16x64x256xf32> to tensor<2x?x256xf32>
+// CHECK:               %[[VAL_16:.*]] = linalg.softmax dimension(1) ins(%[[VAL_14]] : tensor<2x?x256xf32>) outs(%[[VAL_15]] : tensor<2x?x256xf32>) -> tensor<2x?x256xf32>
+// CHECK:               %[[VAL_17:.*]] = tensor.insert_slice %[[VAL_16]] into %[[VAL_12]]{{\[}}%[[VAL_8]], %[[VAL_11]], 0] [2, %[[VAL_13]], 256] [1, 1, 1] : tensor<2x?x256xf32> into tensor<16x64x256xf32>
+// CHECK:               scf.yield %[[VAL_17]] : tensor<16x64x256xf32>
+// CHECK:             }
+// CHECK:             scf.yield %[[VAL_18:.*]] : tensor<16x64x256xf32>
+// CHECK:           }
+// CHECK:           return %[[VAL_19:.*]] : tensor<16x64x256xf32>
+// CHECK:         }
+func.func @softmax(%arg0: tensor<16x64x256xf32>) -> tensor<16x64x256xf32> {
+  %0 = tensor.empty() : tensor<16x64x256xf32>
+  %1 = linalg.softmax
+         dimension(1) ins(%arg0 : tensor<16x64x256xf32>) outs(%0 : tensor<16x64x256xf32>) -> tensor<16x64x256xf32>
+  return %1 : tensor<16x64x256xf32>
+}
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !transform.any_op):
+    %0 = transform.structured.match ops{["linalg.softmax"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loop:2 = transform.structured.tile %0 [2, 3] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+}
+
+// -----
+
+// Test the softmax tiling interface with the tile_to_forall_op transform and
+// check that it composes properly with the fuse transform.
+// This should sink the linalg.generic inside the scf.forall and run that
+// generic on 2x4x256 tensors (2==16/8, 4==64/16).
+
+// CHECK: #[[$TIMES2_MAP:.*]] = affine_map<(d0) -> (d0 * 2)>
+// CHECK: #[[$TIMES4_MAP:.*]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK-LABEL:   func.func @softmax_tile_n_fuse(
+// CHECK-SAME:                       %[[VAL_0:.*]]: tensor<16x64x256xf32>) -> tensor<16x64x256xf32> {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK:           %[[VAL_2:.*]] = tensor.empty() : tensor<16x64x256xf32>
+// CHECK:           %[[VAL_3:.*]] = tensor.empty() : tensor<16x64x256xf32>
+// CHECK:           %[[VAL_4:.*]] = scf.forall (%[[VAL_5:.*]], %[[VAL_6:.*]]) in (8, 16) shared_outs(%[[VAL_7:.*]] = %[[VAL_3]]) -> (tensor<16x64x256xf32>) {
+// CHECK:             %[[VAL_8:.*]] = affine.apply #[[$TIMES2_MAP]](%[[VAL_5]])
+// CHECK:             %[[VAL_9:.*]] = affine.apply #[[$TIMES4_MAP]](%[[VAL_6]])
+// CHECK:             %[[VAL_10:.*]] = affine.apply #[[$TIMES2_MAP]](%[[VAL_5]])
+// CHECK:             %[[VAL_11:.*]] = affine.apply #[[$TIMES4_MAP]](%[[VAL_6]])
+// CHECK:             %[[VAL_12:.*]] = affine.apply #[[$TIMES2_MAP]](%[[VAL_5]])
+// CHECK:             %[[VAL_13:.*]] = affine.apply #[[$TIMES4_MAP]](%[[VAL_6]])
+// CHECK:             %[[VAL_14:.*]] = tensor.extract_slice %[[VAL_0]]{{\[}}%[[VAL_10]], %[[VAL_11]], 0] [2, 4, 256] [1, 1, 1] : tensor<16x64x256xf32> to tensor<2x4x256xf32>
+// CHECK:             %[[VAL_15:.*]] = tensor.extract_slice %[[VAL_2]]{{\[}}%[[VAL_12]], %[[VAL_13]], 0] [2, 4, 256] [1, 1, 1] : tensor<16x64x256xf32> to tensor<2x4x256xf32>
+// CHECK:             %[[VAL_16:.*]] = linalg.generic {indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_14]] : tensor<2x4x256xf32>) outs(%[[VAL_15]] : tensor<2x4x256xf32>) {
+// CHECK:             ^bb0(%[[VAL_17:.*]]: f32, %[[VAL_18:.*]]: f32):
+// CHECK:               %[[VAL_19:.*]] = arith.addf %[[VAL_18]], %[[VAL_1]] : f32
+// CHECK:               linalg.yield %[[VAL_19]] : f32
+// CHECK:             } -> tensor<2x4x256xf32>
+// CHECK:             %[[VAL_20:.*]] = tensor.extract_slice %[[VAL_7]]{{\[}}%[[VAL_8]], %[[VAL_9]], 0] [2, 4, 256] [1, 1, 1] : tensor<16x64x256xf32> to tensor<2x4x256xf32>
+// CHECK:             %[[VAL_21:.*]] = linalg.softmax dimension(1) ins(%[[VAL_22:.*]] : tensor<2x4x256xf32>) outs(%[[VAL_20]] : tensor<2x4x256xf32>) -> tensor<2x4x256xf32>
+// CHECK:             scf.forall.in_parallel {
+// CHECK:               tensor.parallel_insert_slice %[[VAL_21]] into %[[VAL_7]]{{\[}}%[[VAL_8]], %[[VAL_9]], 0] [2, 4, 256] [1, 1, 1] : tensor<2x4x256xf32> into tensor<16x64x256xf32>
+// CHECK:             }
+// CHECK:           }
+// CHECK:           return %[[VAL_23:.*]] : tensor<16x64x256xf32>
+// CHECK:         }
+
+func.func @softmax_tile_n_fuse(%arg0: tensor<16x64x256xf32>) -> tensor<16x64x256xf32> {
+  %empty = tensor.empty() : tensor<16x64x256xf32>
+  %cst = arith.constant 1.000000e+00 : f32
+  %eltwise = linalg.generic
+      {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                        affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+       iterator_types = ["parallel", "parallel", "parallel"]
+      }
+      ins(%arg0 : tensor<16x64x256xf32>)
+      outs(%empty : tensor<16x64x256xf32>) {
+    ^bb0(%arg2: f32, %arg3: f32):
+      %arg3Plus1 = arith.addf %arg3, %cst : f32
+      linalg.yield %arg3Plus1 : f32
+    } -> tensor<16x64x256xf32>
+
+  %0 = tensor.empty() : tensor<16x64x256xf32>
+  %1 = linalg.softmax
+         dimension(1) ins(%eltwise : tensor<16x64x256xf32>) outs(%0 : tensor<16x64x256xf32>) -> tensor<16x64x256xf32>
+  return %1 : tensor<16x64x256xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.softmax"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+
+  // Tile the root.
+  %forall_op, %tiled_op = transform.structured.tile_to_forall_op %0 num_threads [8, 16]
+       : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+  // Fuse all producers.
+  %1 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  transform.structured.fuse_into_containing_op %1 into %forall_op
+    : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+}
+// -----
+
+// Same as the first test (tiling by [2, 3]), but on memrefs.
+
+// CHECK: #[[$MIN_MAP:.*]] = affine_map<(d0) -> (-d0 + 64, 3)>
+// CHECK-LABEL:   func.func @softmax_memref(
+// CHECK-SAME:                              %[[VAL_0:.*]]: memref<16x64x256xf32>,
+// CHECK-SAME:                              %[[VAL_1:.*]]: memref<16x64x256xf32>) {
+// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[C16:.*]] = arith.constant 16 : index
+// CHECK-DAG:       %[[C64:.*]] = arith.constant 64 : index
+// CHECK-DAG:       %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:       %[[C3:.*]] = arith.constant 3 : index
+// CHECK:           scf.for %[[VAL_7:.*]] = %[[C0]] to %[[C16]] step %[[C2]] {
+// CHECK:             scf.for %[[VAL_8:.*]] = %[[C0]] to %[[C64]] step %[[C3]] {
+// CHECK:               %[[VAL_9:.*]] = affine.min #[[$MIN_MAP]](%[[VAL_8]])
+// CHECK:               %[[VAL_10:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_7]], %[[VAL_8]], 0] [2, %[[VAL_9]], 256] [1, 1, 1] : memref<16x64x256xf32> to memref<2x?x256xf32, strided<[16384, 256, 1], offset: ?>>
+// CHECK:               %[[VAL_11:.*]] = memref.subview %[[VAL_1]]{{\[}}%[[VAL_7]], %[[VAL_8]], 0] [2, %[[VAL_9]], 256] [1, 1, 1] : memref<16x64x256xf32> to memref<2x?x256xf32, strided<[16384, 256, 1], offset: ?>>
+// CHECK:               linalg.softmax dimension(1) ins(%[[VAL_10]] : memref<2x?x256xf32, strided<[16384, 256, 1], offset: ?>>) outs(%[[VAL_11]] : memref<2x?x256xf32, strided<[16384, 256, 1], offset: ?>>)
+// CHECK:             }
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+func.func @softmax_memref(%arg0: memref<16x64x256xf32>, %arg1: memref<16x64x256xf32>) {
+  linalg.softmax
+    dimension(1) ins(%arg0 : memref<16x64x256xf32>) outs(%arg1 : memref<16x64x256xf32>)
+  return
+}
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !transform.any_op):
+    %0 = transform.structured.match ops{["linalg.softmax"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loop:2 = transform.structured.tile %0 [2, 3] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+}
