[Mlir-commits] [mlir] 44a8897 - [mlir][linalg] Fold ExtractSliceOps during tiling.

Tobias Gysi llvmlistbot at llvm.org
Tue Sep 14 04:51:05 PDT 2021


Author: Tobias Gysi
Date: 2021-09-14T11:43:52Z
New Revision: 44a889778ceeb6bcb11702f5c940306905a3821e

URL: https://github.com/llvm/llvm-project/commit/44a889778ceeb6bcb11702f5c940306905a3821e
DIFF: https://github.com/llvm/llvm-project/commit/44a889778ceeb6bcb11702f5c940306905a3821e.diff

LOG: [mlir][linalg] Fold ExtractSliceOps during tiling.

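Add the makeComposedExtractSliceOp method that creates an ExtractSliceOp and folds chains of ExtractSliceOps by computing the sum of their offsets. Strides are not combined yet: folding only happens when all strides are one and the producer slice is not rank reducing; otherwise a plain ExtractSliceOp is created.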

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D109601

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
    mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h
    mlir/lib/Dialect/Linalg/Utils/Utils.cpp
    mlir/lib/Dialect/StandardOps/Utils/Utils.cpp
    mlir/test/Dialect/Linalg/tile-tensors.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 937330c90bf29..fd0a0befcbc29 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -56,6 +56,25 @@ SmallVector<Value, 4> getDynOperands(Location loc, Value val, OpBuilder &b);
 /// Otherwise return nullptr.
 IntegerAttr getSmallestBoundingIndex(Value size);
 
+/// Create an ExtractSliceOp and, if `source` is defined by a non-rank-reducing
+/// ExtractSliceOp and all strides are one, fold it by adding the offsets.
+///
+/// Example:
+/// ```
+///   %0 = tensor.extract_slice %arg0[3, 4][3, 32][1, 1] : tensor<64x64xf32> to
+///                                                        tensor<3x32xf32>
+///   %1 = tensor.extract_slice %0[0, 5][3, 4][1, 1] : tensor<3x32xf32> to
+///                                                    tensor<3x4xf32>
+/// ```
+/// folds into:
+/// ```
+///   %1 = tensor.extract_slice %arg0[3, 9][3, 4][1, 1] : tensor<64x64xf32> to
+///                                                         tensor<3x4xf32>
+/// ```
+tensor::ExtractSliceOp makeComposedExtractSliceOp(
+    OpBuilder &b, Location loc, Value source, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides);
+
 //===----------------------------------------------------------------------===//
 // Fusion utilities
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h b/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h
index fd668ba5d87a6..1d2adc62d2714 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h
@@ -72,6 +72,12 @@ class OpWithOffsetSizesAndStridesConstantArgumentFolder final
   }
 };
 
+/// Converts an OpFoldResult to a Value. Returns the fold result if it casts to
+/// a Value or creates a ConstantIndexOp if it casts to an IntegerAttr.
+/// Other attribute types are not supported.
+Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
+                                      OpFoldResult ofr);
+
 /// Helper struct to build simple arithmetic quantities with minimal type
 /// inference support.
 struct ArithBuilder {

diff  --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 1305cfdf464a4..296956d7425eb 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -28,6 +28,7 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/LoopUtils.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 
 #define DEBUG_TYPE "linalg-utils"
@@ -194,6 +195,48 @@ IntegerAttr getSmallestBoundingIndex(Value size) {
   return nullptr;
 }
 
+tensor::ExtractSliceOp makeComposedExtractSliceOp(
+    OpBuilder &b, Location loc, Value source, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
+  assert(source && "expect source to be nonzero");
+
+  // Do not fold if the producer is not an ExtractSliceOp.
+  auto producerOp = source.getDefiningOp<tensor::ExtractSliceOp>();
+  if (!producerOp)
+    return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
+                                            strides);
+
+  // Do not fold if the producer is rank reducing or if there are any non-unit
+  // strides. Supporting non-unit strides complicates the offset computation
+  // since the consumer offsets need to be multiplied by the producer strides.
+  // TODO: support non-unit strides once there are use cases.
+  SmallVector<OpFoldResult> allStrides = producerOp.getMixedStrides();
+  allStrides.append(strides.begin(), strides.end());
+  bool hasNonUnitStride = any_of(allStrides, [](OpFoldResult ofr) {
+    return getConstantIntValue(ofr) != static_cast<int64_t>(1);
+  });
+  if (hasNonUnitStride ||
+      producerOp.getSourceType().getRank() !=
+          producerOp.getResult().getType().cast<ShapedType>().getRank())
+    return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
+                                            strides);
+
+  // Fold the producer by adding the offsets and extracting the slice directly
+  // from the producer source tensor.
+  SmallVector<OpFoldResult> foldedOffsets(offsets.begin(), offsets.end());
+  AffineExpr dim1, dim2;
+  bindDims(b.getContext(), dim1, dim2);
+  for (auto en : enumerate(producerOp.getMixedOffsets())) {
+    SmallVector<Value> offsetValues = {
+        getValueOrCreateConstantIndexOp(b, loc, foldedOffsets[en.index()]),
+        getValueOrCreateConstantIndexOp(b, loc, en.value())};
+    foldedOffsets[en.index()] =
+        makeComposedAffineApply(b, loc, dim1 + dim2, offsetValues).getResult();
+  }
+  return b.create<tensor::ExtractSliceOp>(loc, producerOp.source(),
+                                          foldedOffsets, sizes, strides);
+}
+
 /// Specialization to build an scf "for" nest.
 template <>
 void GenerateLoopNest<scf::ForOp>::doit(
@@ -603,15 +646,18 @@ Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
     strides.push_back(builder.getIndexAttr(1));
   }
 
-  Operation *sliceOp = shapedType.isa<MemRefType>()
-                           ? builder
-                                 .create<memref::SubViewOp>(
-                                     loc, valueToTile, offsets, sizes, strides)
-                                 .getOperation()
-                           : builder
-                                 .create<tensor::ExtractSliceOp>(
-                                     loc, valueToTile, offsets, sizes, strides)
-                                 .getOperation();
+  auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
+                      .Case([&](MemRefType) {
+                        return builder.create<memref::SubViewOp>(
+                            loc, valueToTile, offsets, sizes, strides);
+                      })
+                      .Case([&](RankedTensorType) {
+                        return makeComposedExtractSliceOp(
+                            builder, loc, valueToTile, offsets, sizes, strides);
+                      })
+                      .Default([](ShapedType) -> Operation * {
+                        llvm_unreachable("Unexpected shaped type");
+                      });
   return sliceOp->getResult(0);
 }
 

diff  --git a/mlir/lib/Dialect/StandardOps/Utils/Utils.cpp b/mlir/lib/Dialect/StandardOps/Utils/Utils.cpp
index 30ef0937cda73..3f66738cc78f6 100644
--- a/mlir/lib/Dialect/StandardOps/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/StandardOps/Utils/Utils.cpp
@@ -49,6 +49,15 @@ void mlir::getPositionsOfShapeOne(
   }
 }
 
+Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
+                                            OpFoldResult ofr) {
+  if (auto value = ofr.dyn_cast<Value>())
+    return value;
+  auto attr = ofr.dyn_cast<Attribute>().dyn_cast<IntegerAttr>();
+  assert(attr && "expect the op fold result casts to an integer attribute");
+  return b.create<ConstantIndexOp>(loc, attr.getValue().getSExtValue());
+}
+
 Value ArithBuilder::_and(Value lhs, Value rhs) {
   return b.create<AndOp>(loc, lhs, rhs);
 }

diff  --git a/mlir/test/Dialect/Linalg/tile-tensors.mlir b/mlir/test/Dialect/Linalg/tile-tensors.mlir
index 2cd04668a475c..c0619ec0627ab 100644
--- a/mlir/test/Dialect/Linalg/tile-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-tensors.mlir
@@ -130,3 +130,49 @@ func @generic_op_tensors(
 // TLOOP-SAME: ins (%{{.*}} = %[[ARG_0]]: [[TY]], %{{.*}} = %[[ARG_1]]: [[TY]])
 // TLOOP-SAME: outs (%{{.*}} = %[[INIT]]: [[TY]])
 // TLOOP-SAME: distribution["block_x", "block_y", "none"] {
+
+// -----
+
+//  CHECK-DAG:  #[[MAP0:.*]] = affine_map<(d0)[s0] -> (2, -d0 + s0)>
+//  CHECK-DAG:  #[[MAP1:.*]] = affine_map<(d0) -> (d0 + 3)>
+//  CHECK-DAG:  #[[MAP2:.*]] = affine_map<(d0) -> (d0 + 4)>
+
+//      CHECK:  fold_extract_slice
+// CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]*]]: tensor<?x128xf32>
+// CHECK-SAME:    %[[ARG1:[0-9a-zA-Z]*]]: tensor<?x42xf32>
+func @fold_extract_slice(
+  %arg0 : tensor<?x128xf32>, %arg1 : tensor<?x42xf32>, %arg2 : tensor<?x42x?xf32>) -> tensor<?x42xf32> {
+
+  //      CHECK:    %[[C0:.*]] = constant 0
+  %c0 = constant 0 : index
+
+  //      CHECK:    %[[DIM:.*]] = tensor.dim %[[ARG1]], %[[C0]]
+  %0 = tensor.dim %arg1, %c0 : tensor<?x42xf32>
+  %1 = tensor.extract_slice %arg0[3, 4] [%0, 42] [1, 1] : tensor<?x128xf32> to tensor<?x42xf32>
+
+  //      CHECK:    scf.for %[[IV0:[0-9a-zA-Z]*]] =
+  //      CHECK:      scf.for %[[IV1:[0-9a-zA-Z]*]] =
+
+  // Fold the existing extract slice op into the one created by the tiling.
+  //      CHECK:        %[[SIZE0:.*]] = affine.min #[[MAP0]](%[[IV0]])[%[[DIM]]
+  //      CHECK:        %[[OFF0:.*]] = affine.apply #[[MAP1]](%[[IV0]]
+  //      CHECK:        %[[OFF1:.*]] = affine.apply #[[MAP2]](%[[IV1]]
+  //      CHECK:        %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
+  // CHECK-SAME:                                          %[[OFF0]], %[[OFF1]]
+  // CHECK-SAME:                                          %[[SIZE0]], 3
+  // CHECK-SAME:                                          1, 1
+  //      CHECK:        {{.*}} = linalg.generic {{.*}} ins(%[[T0]]
+  %2 = linalg.generic
+    {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>,
+                      affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                      affine_map<(d0, d1, d2) -> (d0, d1)>],
+     iterator_types = ["parallel", "parallel", "parallel"]}
+    ins(%1, %arg2 : tensor<?x42xf32>, tensor<?x42x?xf32>)
+    outs(%arg1 : tensor<?x42xf32>) {
+    ^bb0(%arg3 : f32, %arg4: f32, %arg5: f32):
+      %5 = addf %arg3, %arg5 : f32
+      linalg.yield %5 : f32
+    } -> tensor<?x42xf32>
+  return %2 : tensor<?x42xf32>
+}
+


        


More information about the Mlir-commits mailing list