[Mlir-commits] [mlir] 93f9922 - [mlir][linalg] adding operation to access the iteration index of enclosing linalg ops.
Tobias Gysi
llvmlistbot at llvm.org
Mon Apr 12 06:39:46 PDT 2021
Author: Tobias Gysi
Date: 2021-04-12T13:37:17Z
New Revision: 93f9922d65f8a89b8a9299eeab61511ce8baa3bc
URL: https://github.com/llvm/llvm-project/commit/93f9922d65f8a89b8a9299eeab61511ce8baa3bc
DIFF: https://github.com/llvm/llvm-project/commit/93f9922d65f8a89b8a9299eeab61511ce8baa3bc.diff
LOG: [mlir][linalg] adding operation to access the iteration index of enclosing linalg ops.
The `linalg.index` operation provides access to the iteration indices of the immediately enclosing linalg operation. It takes a dimension `dim` attribute and returns the iteration index in the given dimension. Having `linalg.index` allows us to unify `linalg.generic` and `linalg.indexed_generic` and also enables index access in named operations.
Differential Revision: https://reviews.llvm.org/D100292
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
mlir/test/Dialect/Linalg/fusion-indexed-generic.mlir
mlir/test/Dialect/Linalg/fusion-tensor.mlir
mlir/test/Dialect/Linalg/invalid.mlir
mlir/test/Dialect/Linalg/loop-order.mlir
mlir/test/Dialect/Linalg/roundtrip.mlir
mlir/test/Dialect/Linalg/tile.mlir
mlir/test/Dialect/Linalg/vectorization.mlir
mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 95e008aacc451..6c3e86c6d2f14 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -1026,7 +1026,16 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
return $_op.getLibraryCallName();
}]
>,
-
+ InterfaceMethod<
+ /*desc=*/[{
+ Return whether the op accesses the iteration indices.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"hasIndexSemantics",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/""
+ >,
//===------------------------------------------------------------------===//
// Linalg generalization hooks.
//===------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index b0a93f36ab758..931eada61ac4a 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -618,5 +618,51 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
let hasFolder = 1;
}
+def Linalg_IndexOp : Linalg_Op<"index", [NoSideEffect]>,
+ Arguments<(ins Confined<I64Attr, [IntMinValue<0>]>:$dim)>,
+ Results<(outs Index:$result)> {
+ let summary = "linalg index operation";
+ let description = [{
+ The `linalg.index` operation returns the iteration index of the immediately
+ enclosing linalg structured operation for the iteration dimension `dim`. The
+ `dim` attribute specifies the position of the accessed dimension in the
+ indexing map domain.
+
+ Example:
+
+ ```mlir
+ #map = affine_map<(i, j) -> (i, j)>
+ linalg.generic {indexing_maps = [#map, #map],
+ iterator_types = ["parallel", "parallel"]}
+ outs(%I, %J : memref<?x?xindex>, memref<?x?xindex>) {
+ ^bb0(%arg0 : index, %arg1 : index):
+ // Access the outer iteration dimension i
+ %i = linalg.index 0 : index
+ // Access the inner iteration dimension j
+ %j = linalg.index 1 : index
+ linalg.yield %i, %j : index, index
+ }
+ ```
+
+ This may lower to IR resembling:
+
+ ```mlir
+ %0 = dim %I, %c0 : memref<?x?xindex>
+ %1 = dim %I, %c1 : memref<?x?xindex>
+ scf.for %i = %c0 to %0 step %c1 {
+ scf.for %j = %c0 to %1 step %c1 {
+ store %i, %I[%i, %j] : memref<?x?xindex>
+ store %j, %J[%i, %j] : memref<?x?xindex>
+ }
+ }
+ ```
+ }];
+ let builders = [
+ OpBuilder<(ins "int64_t":$dim),
+ [{ build($_builder, $_state, $_builder.getIndexType(), dim); }]>
+ ];
+
+ let assemblyFormat = [{ $dim attr-dict `:` type($result) }];
+}
#endif // LINALG_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index f80d6759172c1..b158cec529d99 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -35,6 +35,14 @@ class LinalgStructuredBase_Op<string mnemonic, list<OpTrait> props>
return isa<IndexedGenericOp>(this->getOperation()) ? getNumLoops() : 0;
}
+ // Return whether the op accesses the iteration indices.
+ bool hasIndexSemantics() {
+ Operation *op = this->getOperation();
+ if(op->getNumRegions() == 0 || op->getRegion(0).empty())
+ return false;
+ return !op->getRegion(0).front().getOps<IndexOp>().empty();
+ }
+
LogicalResult reifyReturnTypeShapesPerResultDim(OpBuilder &b,
SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
return cast<LinalgOp>(getOperation()).reifyReturnTypeShapesPerResultDim(b,
diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
index 5260c503cce92..d385b46848df9 100644
--- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
+++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
@@ -106,6 +106,10 @@ LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite(
if (isa<CopyOp>(op) || isa<IndexedGenericOp>(op))
return failure();
+ // TODO: remove once index ops are supported.
+ if (op.hasIndexSemantics())
+ return failure();
+
auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
if (!libraryCallName)
return failure();
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index d8b512cdeea06..e93e13cb7e192 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2047,6 +2047,21 @@ LogicalResult TiledLoopOp::fold(ArrayRef<Attribute>,
return foldMemRefCast(*this);
}
+//===----------------------------------------------------------------------===//
+// IndexOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(IndexOp op) {
+ auto linalgOp = dyn_cast<LinalgOp>(op->getParentOp());
+ if (!linalgOp)
+ return op.emitOpError("expected parent op with LinalgOp interface");
+ if (linalgOp.getNumLoops() <= op.dim())
+ return op.emitOpError("expected dim (")
+ << op.dim() << ") to be lower than the number of loops ("
+ << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
+ return success();
+}
+
/////// Operations corresponding to library calls defined with Tablegen ////////
template <typename LinalgPoolingOp>
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index b3af82c80ceea..0059178d30e5f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -177,6 +177,10 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOpTy> {
using OpRewritePattern<GenericOpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(GenericOpTy op,
PatternRewriter &rewriter) const override {
+ // TODO: remove once index ops are supported.
+ if (op.hasIndexSemantics())
+ return failure();
+
SmallVector<AffineMap, 4> indexingMaps = op.getIndexingMaps();
if (indexingMaps.empty())
return failure();
@@ -321,6 +325,10 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOpTy> {
using OpRewritePattern<GenericOpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(GenericOpTy op,
PatternRewriter &rewriter) const override {
+ // TODO: remove once index ops are supported.
+ if (op.hasIndexSemantics())
+ return failure();
+
if (!op.hasTensorSemantics())
return failure();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 34eac4bdfcaaf..713de7b22c4b7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -28,6 +28,10 @@ using namespace mlir::linalg;
/// Implementation of fusion of generic ops and indexed_generic ops.
static bool areElementwiseOpsFusable(LinalgOp producer, LinalgOp consumer,
unsigned consumerIdx) {
+ // TODO: remove once index ops are supported.
+ if (producer.hasIndexSemantics() || consumer.hasIndexSemantics())
+ return false;
+
// Producer and consumer must have tensor semantics.
if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
return false;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index a6e296b0ea112..8b2b0cde8f9a5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -527,7 +527,9 @@ class LinalgRewritePattern : public RewritePattern {
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
- if (!isa<LinalgOp>(op))
+ auto linalgOp = dyn_cast<LinalgOp>(op);
+ // TODO: remove hasIndexSemantics check once index ops are supported.
+ if (!linalgOp || linalgOp.hasIndexSemantics())
return failure();
if (!linalgOpToLoopsImpl<LoopType>(op, rewriter, interchangeVector))
return failure();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index e7095a9f0b34e..00d6a3a1b95f4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -246,7 +246,8 @@ mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
- if (!linalgOp)
+ // TODO: remove hasIndexSemantics check once index ops are supported.
+ if (!linalgOp || linalgOp.hasIndexSemantics())
return failure();
if (failed(filter.checkAndNotify(rewriter, linalgOp)))
return failure();
@@ -314,7 +315,8 @@ mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
Operation *op, PatternRewriter &rewriter) const {
LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
- if (!linalgOp)
+ // TODO: remove hasIndexSemantics check once index ops are supported.
+ if (!linalgOp || linalgOp.hasIndexSemantics())
return failure();
if (failed(filter.checkAndNotify(rewriter, linalgOp)))
return failure();
@@ -407,7 +409,8 @@ mlir::linalg::LinalgBaseInterchangePattern::LinalgBaseInterchangePattern(
LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite(
Operation *op, PatternRewriter &rewriter) const {
LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
- if (!linalgOp)
+ // TODO: remove hasIndexSemantics check once index ops are supported.
+ if (!linalgOp || linalgOp.hasIndexSemantics())
return failure();
if (failed(filter.checkAndNotify(rewriter, linalgOp)))
return failure();
@@ -465,7 +468,8 @@ mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
Operation *op, PatternRewriter &rewriter) const {
LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
- if (!linalgOp)
+ // TODO: remove hasIndexSemantics check once index ops are supported.
+ if (!linalgOp || linalgOp.hasIndexSemantics())
return failure();
if (failed(filter.checkAndNotify(rewriter, linalgOp)))
return failure();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 10562d68a9e0a..e2c069ea7f477 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -402,6 +402,9 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
for (Type outputTensorType : linalgOp.getOutputTensorTypes())
if (!outputTensorType.cast<ShapedType>().hasStaticShape())
return failure();
+ // TODO: remove once index ops are supported.
+ if (linalgOp.hasIndexSemantics())
+ return failure();
if (isElementwise(op))
return success();
return success(isaContractionOpInterface(linalgOp));
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 5c8866662d2ba..e65a408a0a0e2 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -566,3 +566,19 @@ func @unit_dim_for_reduction_inner(%arg0: tensor<?x1x?x1xf32>) -> tensor<?x1xf32
// CHECK-SAME: outs(%[[FILL]] : tensor<?xf32>)
// CHECK: %[[RESULT_RESHAPE:.+]] = linalg.tensor_reshape %[[RESULT]] [#[[MAP2]]]
// CHECK: return %[[RESULT_RESHAPE]]
+
+// -----
+
+// CHECK: #{{.+}} = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: @index_op
+func @index_op(%arg0: memref<1x8xindex>) {
+ linalg.generic {
+ indexing_maps = [affine_map<(i, j) -> (i, j)>],
+ iterator_types = ["parallel", "parallel"]}
+ outs(%arg0 : memref<1x8xindex>) {
+ ^bb0(%arg1: index): // no predecessors
+ %0 = linalg.index 1 : index
+ linalg.yield %0 : index
+ }
+ return
+}
diff --git a/mlir/test/Dialect/Linalg/fusion-indexed-generic.mlir b/mlir/test/Dialect/Linalg/fusion-indexed-generic.mlir
index 76e9148cdff0c..2cb7ef1274b8d 100644
--- a/mlir/test/Dialect/Linalg/fusion-indexed-generic.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-indexed-generic.mlir
@@ -188,3 +188,53 @@ func @fuse_indexed_generic_producer_tile_second_dim_only(%A: memref<?x?xf32>,
// CHECK: {{.*}} = index_cast [[j_new]] : index to i32
// CHECK: linalg.generic
// CHECK: addf
+
+// -----
+
+#map = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+#id_2d = affine_map<(d0, d1) -> (d0, d1)>
+#pointwise_2d_trait = {
+ indexing_maps = [#id_2d],
+ iterator_types = ["parallel", "parallel"]
+}
+func @index_op(%A: memref<?x?xindex>,
+ %B: memref<?x?xindex>) {
+ linalg.generic #pointwise_2d_trait
+ outs(%B : memref<?x?xindex>) {
+ ^bb0(%arg6: index): // no predecessors
+ %2 = constant 0 : index
+ linalg.yield %2 : index
+ }
+ %c1 = constant 1 : index
+ %c0 = constant 0 : index
+ %c25 = constant 25 : index
+ %c10 = constant 10 : index
+ %0 = memref.dim %A, %c0 : memref<?x?xindex>
+ %1 = memref.dim %A, %c1 : memref<?x?xindex>
+ %2 = memref.dim %B, %c0 : memref<?x?xindex>
+ %3 = memref.dim %B, %c1 : memref<?x?xindex>
+ scf.for %arg2 = %c0 to %0 step %c10 {
+ scf.for %arg3 = %c0 to %1 step %c25 {
+ %4 = memref.subview %A[%arg2, %arg3][%c10, %c25][%c1, %c1] :
+ memref<?x?xindex> to memref<?x?xindex, #map>
+ %5 = memref.subview %B[%arg2, %arg3][%c10, %c25][%c1, %c1] :
+ memref<?x?xindex> to memref<?x?xindex, #map>
+ linalg.generic {
+ indexing_maps = [#id_2d, #id_2d],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%4 : memref<?x?xindex, #map>)
+ outs(%5 : memref<?x?xindex, #map>) {
+ ^bb0(%arg6: index, %arg7: index):
+ %6 = linalg.index 0 : index
+ linalg.yield %6 : index
+ }
+ }
+ }
+ return
+}
+// CHECK-LABEL: func @index_op
+// CHECK: linalg.generic
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK-NOT: scf.for
+// CHECK: linalg.generic
diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
index b0a006398c991..7983fe19a95a3 100644
--- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
@@ -652,3 +652,29 @@ func @generic_index_op2(%arg0: tensor<1x8xf64>, %arg1: tensor<1x8xi32>) -> tenso
// CHECK-NEXT: return %[[R]] : tensor<1x8xi32>
return %1 : tensor<1x8xi32>
}
+
+// -----
+
+// CHECK-LABEL: func @index_op(
+// CHECK-COUNT-2: linalg.generic
+func @index_op(%arg0: tensor<1x8xindex>, %arg1: tensor<1x8xindex>) -> tensor<1x8xindex> {
+ %0 = linalg.generic {
+ indexing_maps = [affine_map<(i, j) -> (i, j)>],
+ iterator_types = ["parallel", "parallel"]}
+ outs(%arg0 : tensor<1x8xindex>) {
+ ^bb0(%a: index): // no predecessors
+ %2 = linalg.index 1 : index
+ linalg.yield %2 : index
+ } -> tensor<1x8xindex>
+ %1 = linalg.generic {
+ indexing_maps = [affine_map<(i, j) -> (i, j)>, affine_map<(i, j) -> (i, j)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%0 : tensor<1x8xindex>)
+ outs(%arg1 : tensor<1x8xindex>) {
+ ^bb0(%a: index, %b: index): // no predecessors
+ %2 = linalg.index 0 : index
+ %3 = addi %2, %a : index
+ linalg.yield %3 : index
+ } -> tensor<1x8xindex>
+ return %1 : tensor<1x8xindex>
+}
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index bdaf0ea351aa9..bb17c9bf1399c 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -24,6 +24,41 @@ func @yield_parent(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
// -----
+func @index_parent() {
+ // expected-error @+1 {{op expected parent op with LinalgOp interface}}
+ linalg.index 0 : index
+}
+
+// -----
+
+func @index_dim_lower_than_number_of_loops(%arg0: memref<f32>) {
+ // expected-error @+6 {{op expected dim (2) to be lower than the number of loops (0) of the enclosing LinalgOp}}
+ linalg.generic {
+ indexing_maps = [ affine_map<() -> ()> ],
+ iterator_types = []}
+ outs(%arg0 : memref<f32>) {
+ ^bb(%0: f32):
+ linalg.index 2 : index
+ linalg.yield %0 : f32
+ }
+}
+
+// -----
+
+func @index_dim_negative(%arg0: memref<f32>) {
+ // expected-error @+6 {{op attribute 'dim' failed to satisfy constraint: 64-bit signless integer attribute whose minimum value is 0}}
+ linalg.generic {
+ indexing_maps = [ affine_map<() -> ()> ],
+ iterator_types = []}
+ outs(%arg0 : memref<f32>) {
+ ^bb(%0: f32):
+ linalg.index -1 : index
+ linalg.yield %0 : f32
+ }
+}
+
+// -----
+
func @generic_no_region(%arg0: memref<f32>) {
// expected-error @+5 {{expected '{' to begin a region}}
linalg.generic {
diff --git a/mlir/test/Dialect/Linalg/loop-order.mlir b/mlir/test/Dialect/Linalg/loop-order.mlir
index d1ff47977c351..968ffdc5e7478 100644
--- a/mlir/test/Dialect/Linalg/loop-order.mlir
+++ b/mlir/test/Dialect/Linalg/loop-order.mlir
@@ -1,6 +1,6 @@
-// RUN: mlir-opt %s -convert-linalg-to-loops="interchange-vector=4,0,3,1,2" | FileCheck --check-prefix=LOOP %s
-// RUN: mlir-opt %s -convert-linalg-to-parallel-loops="interchange-vector=4,0,3,1,2" | FileCheck --check-prefix=PARALLEL %s
-// RUN: mlir-opt %s -convert-linalg-to-affine-loops="interchange-vector=4,0,3,1,2" | FileCheck --check-prefix=AFFINE %s
+// RUN: mlir-opt %s -convert-linalg-to-loops="interchange-vector=4,0,3,1,2" -split-input-file | FileCheck --check-prefix=LOOP %s
+// RUN: mlir-opt %s -convert-linalg-to-parallel-loops="interchange-vector=4,0,3,1,2" -split-input-file | FileCheck --check-prefix=PARALLEL %s
+// RUN: mlir-opt %s -convert-linalg-to-affine-loops="interchange-vector=4,0,3,1,2" -split-input-file | FileCheck --check-prefix=AFFINE %s
func @copy(%input: memref<1x2x3x4x5xf32>, %output: memref<1x2x3x4x5xf32>) {
linalg.copy(%input, %output): memref<1x2x3x4x5xf32>, memref<1x2x3x4x5xf32>
@@ -22,3 +22,24 @@ func @copy(%input: memref<1x2x3x4x5xf32>, %output: memref<1x2x3x4x5xf32>) {
// AFFINE: affine.for %{{.*}} = 0 to 2
// AFFINE: affine.for %{{.*}} = 0 to 3
+// -----
+
+func @index_op(%arg0: memref<4x8xindex>) {
+ linalg.generic {
+ indexing_maps = [affine_map<(i, j) -> (i, j)>],
+ iterator_types = ["parallel", "parallel"]}
+ outs(%arg0 : memref<4x8xindex>) {
+ ^bb0(%arg1: index): // no predecessors
+ %0 = linalg.index 1 : index
+ linalg.yield %0 : index
+ }
+ return
+}
+// LOOP-LABEL: @index_op
+// LOOP: linalg.generic
+
+// PARALLEL-LABEL: @index_op
+// PARALLEL: linalg.generic
+
+// AFFINE-LABEL: @index_op
+// AFFINE: linalg.generic
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 084b8a339c0de..d03161541e2f6 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -525,6 +525,9 @@ func @generic_region(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1
outs(%arg1 : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>)
attrs = {foo = 1} {
^bb(%a: vector<3x4xi4>, %b: f32) :
+ %0 = linalg.index 0 : index
+ %1 = linalg.index 1 : index
+ %2 = linalg.index 2 : index
linalg.yield %b : f32
}
return
@@ -538,6 +541,9 @@ func @generic_region(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1
// CHECK-SAME: outs({{.*}} : memref<?x?x?xf32, #[[$strided3D]]>)
// CHECK-SAME: attrs = {foo = 1 : i64} {
// CHECK: ^{{.*}}(%{{.*}}: vector<3x4xi4>, %{{.*}}: f32):
+// CHECK: %{{.*}} = linalg.index 0 : index
+// CHECK: %{{.*}} = linalg.index 1 : index
+// CHECK: %{{.*}} = linalg.index 2 : index
// CHECK: linalg.yield %{{.*}} : f32
func @indexed_generic(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
diff --git a/mlir/test/Dialect/Linalg/tile.mlir b/mlir/test/Dialect/Linalg/tile.mlir
index c761bd6cd57e0..d8b20904f751a 100644
--- a/mlir/test/Dialect/Linalg/tile.mlir
+++ b/mlir/test/Dialect/Linalg/tile.mlir
@@ -377,3 +377,18 @@ func @pointwise(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: memre
// TILE-234: for
// TILE-234-NOT: for
// TILE-234: linalg.generic
+
+// TILE-2-LABEL: func @index_op
+// TILE-2-NOT: for
+// TILE-2: linalg.generic
+func @index_op(%arg0: memref<?x?xindex>) {
+ linalg.generic {
+ indexing_maps = [affine_map<(i, j) -> (i, j)>],
+ iterator_types = ["parallel", "parallel"]}
+ outs(%arg0 : memref<?x?xindex>) {
+ ^bb0(%arg1: index): // no predecessors
+ %0 = linalg.index 1 : index
+ linalg.yield %0 : index
+ }
+ return
+}
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 3daab904d0323..faaadcf94a5c0 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -469,3 +469,19 @@ func @pad_dynamic(%arg0: tensor<1x2x2x?xf32>, %low: index, %high: index,
} : tensor<1x2x2x?xf32> to tensor<6x?x?x?xf32>
return %0 : tensor<6x?x?x?xf32>
}
+
+// -----
+
+// CHECK-LABEL: @index_op
+// CHECK: linalg.generic
+func @index_op(%arg0: memref<4x8xindex>) {
+ linalg.generic {
+ indexing_maps = [affine_map<(i, j) -> (i, j)>],
+ iterator_types = ["parallel", "parallel"]}
+ outs(%arg0 : memref<4x8xindex>) {
+ ^bb0(%arg1: index): // no predecessors
+ %0 = linalg.index 1 : index
+ linalg.yield %0 : index
+ }
+ return
+}
diff --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
index 3ef6ed5e4b4ba..6931f07a8ab4c 100644
--- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
@@ -126,6 +126,10 @@ static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
// Save original Linalg ops, we only want to make a pass over those.
SmallVector<LinalgOp, 8> linalgOps;
f.walk([&](LinalgOp op) {
+ // TODO: remove hasIndexSemantics check once index ops are supported.
+ if (op.hasIndexSemantics())
+ return;
+
// TODO: support multi-results.
if (op->getNumResults() <= 1)
linalgOps.push_back(op);
More information about the Mlir-commits mailing list