[Mlir-commits] [mlir] 2e5fe72 - [MLIR][Linalg] Use `DenseI64ArrayAttr` in `InterchangeOp` (NFC)
Lorenzo Chelini
llvmlistbot at llvm.org
Fri Dec 16 07:37:38 PST 2022
Author: Lorenzo Chelini
Date: 2022-12-16T16:37:33+01:00
New Revision: 2e5fe721724446265d1ea48267b6a34d33fca14b
URL: https://github.com/llvm/llvm-project/commit/2e5fe721724446265d1ea48267b6a34d33fca14b
DIFF: https://github.com/llvm/llvm-project/commit/2e5fe721724446265d1ea48267b6a34d33fca14b.diff
LOG: [MLIR][Linalg] Use `DenseI64ArrayAttr` in `InterchangeOp` (NFC)
Use op separator to improve code navigation.
Reviewed By: hanchung
Differential Revision: https://reviews.llvm.org/D139917
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/test/Dialect/Linalg/transform-op-interchange.mlir
mlir/test/Dialect/Linalg/transform-ops-invalid.mlir
mlir/test/Dialect/Linalg/transform-patterns.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 1cac6b83a1e9..1cb321d76d03 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -19,6 +19,10 @@ include "mlir/IR/EnumAttr.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/RegionKindInterface.td"
+//===----------------------------------------------------------------------===//
+// DecomposeOp
+//===----------------------------------------------------------------------===//
+
def DecomposeOp : Op<Transform_Dialect, "structured.decompose",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
TransformOpInterface, TransformEachOpTrait]> {
@@ -48,6 +52,10 @@ def DecomposeOp : Op<Transform_Dialect, "structured.decompose",
}];
}
+//===----------------------------------------------------------------------===//
+// FuseOp
+//===----------------------------------------------------------------------===//
+
def FuseOp : Op<Transform_Dialect, "structured.fuse",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
DeclareOpInterfaceMethods<TransformOpInterface>]> {
@@ -67,6 +75,10 @@ def FuseOp : Op<Transform_Dialect, "structured.fuse",
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// FuseIntoContainingOp
+//===----------------------------------------------------------------------===//
+
def FuseIntoContainingOp :
Op<Transform_Dialect, "structured.fuse_into_containing_op",
[DeclareOpInterfaceMethods<TransformOpInterface>]> {
@@ -120,6 +132,10 @@ def FuseIntoContainingOp :
];
}
+//===----------------------------------------------------------------------===//
+// GeneralizeOp
+//===----------------------------------------------------------------------===//
+
def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
TransformOpInterface, TransformEachOpTrait]> {
@@ -149,6 +165,10 @@ def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",
}];
}
+//===----------------------------------------------------------------------===//
+// InterchangeOp
+//===----------------------------------------------------------------------===//
+
def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
TransformOpInterface, TransformEachOpTrait]> {
@@ -169,10 +189,14 @@ def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
let arguments =
(ins PDL_Operation:$target,
- DefaultValuedAttr<I64ArrayAttr, "{}">:$iterator_interchange);
+ ConfinedAttr<DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">,
+ [DenseArrayNonNegative<DenseI64ArrayAttr>]>:$iterator_interchange);
let results = (outs PDL_Operation:$transformed);
- let assemblyFormat = "$target attr-dict";
+ let assemblyFormat = [{
+ $target
+ (`iterator_interchange` `=` $iterator_interchange^)? attr-dict
+ }];
let hasVerifier = 1;
let extraClassDeclaration = [{
@@ -183,6 +207,10 @@ def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
}];
}
+//===----------------------------------------------------------------------===//
+// MatchOp
+//===----------------------------------------------------------------------===//
+
def MatchInterfaceEnum : I32EnumAttr<"MatchInterfaceEnum", "An interface to match",
[
I32EnumAttrCase<"LinalgOp", 0>,
@@ -245,6 +273,10 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
}];
}
+//===----------------------------------------------------------------------===//
+// MultiTileSizesOp
+//===----------------------------------------------------------------------===//
+
def MultiTileSizesOp : Op<Transform_Dialect, "structured.multitile_sizes",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
TransformOpInterface, TransformEachOpTrait]> {
@@ -309,6 +341,10 @@ def MultiTileSizesOp : Op<Transform_Dialect, "structured.multitile_sizes",
}];
}
+//===----------------------------------------------------------------------===//
+// PadOp
+//===----------------------------------------------------------------------===//
+
def PadOp : Op<Transform_Dialect, "structured.pad",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
TransformOpInterface, TransformEachOpTrait]> {
@@ -349,6 +385,10 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
}];
}
+//===----------------------------------------------------------------------===//
+// PromoteOp
+//===----------------------------------------------------------------------===//
+
def PromoteOp : Op<Transform_Dialect, "structured.promote",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
TransformOpInterface, TransformEachOpTrait]> {
@@ -388,6 +428,10 @@ def PromoteOp : Op<Transform_Dialect, "structured.promote",
}];
}
+//===----------------------------------------------------------------------===//
+// ReplaceOp
+//===----------------------------------------------------------------------===//
+
def ReplaceOp : Op<Transform_Dialect, "structured.replace",
[IsolatedFromAbove, DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>] # GraphRegionNoTerminator.traits> {
@@ -410,6 +454,10 @@ def ReplaceOp : Op<Transform_Dialect, "structured.replace",
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// ScalarizeOp
+//===----------------------------------------------------------------------===//
+
def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
TransformOpInterface, TransformEachOpTrait]> {
@@ -449,6 +497,10 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
}];
}
+//===----------------------------------------------------------------------===//
+// SplitOp
+//===----------------------------------------------------------------------===//
+
def SplitOp : Op<Transform_Dialect, "structured.split",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<TransformOpInterface>]> {
@@ -481,6 +533,10 @@ def SplitOp : Op<Transform_Dialect, "structured.split",
let hasCustomAssemblyFormat = 1;
}
+//===----------------------------------------------------------------------===//
+// SplitReductionOp
+//===----------------------------------------------------------------------===//
+
def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
TransformEachOpTrait, TransformOpInterface]> {
@@ -649,6 +705,10 @@ def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
}];
}
+//===----------------------------------------------------------------------===//
+// TileReductionUsingScfOp
+//===----------------------------------------------------------------------===//
+
def TileReductionUsingScfOp : Op<Transform_Dialect, "structured.tile_reduction_using_scf",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
TransformEachOpTrait, TransformOpInterface]> {
@@ -748,6 +808,10 @@ def TileReductionUsingScfOp : Op<Transform_Dialect, "structured.tile_reduction_u
}];
}
+//===----------------------------------------------------------------------===//
+// TileReductionUsingForeachThreadOp
+//===----------------------------------------------------------------------===//
+
def TileReductionUsingForeachThreadOp :
Op<Transform_Dialect, "structured.tile_reduction_using_foreach_thread",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
@@ -853,6 +917,10 @@ def TileReductionUsingForeachThreadOp :
}
+//===----------------------------------------------------------------------===//
+// TileOp
+//===----------------------------------------------------------------------===//
+
def TileOp : Op<Transform_Dialect, "structured.tile",
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
@@ -910,6 +978,10 @@ def TileOp : Op<Transform_Dialect, "structured.tile",
}];
}
+//===----------------------------------------------------------------------===//
+// TileToForeachThreadOp
+//===----------------------------------------------------------------------===//
+
def TileToForeachThreadOp :
Op<Transform_Dialect, "structured.tile_to_foreach_thread_op",
[AttrSizedOperandSegments,
@@ -1023,6 +1095,10 @@ def TileToForeachThreadOp :
}];
}
+//===----------------------------------------------------------------------===//
+// TileToScfForOp
+//===----------------------------------------------------------------------===//
+
def TileToScfForOp : Op<Transform_Dialect, "structured.tile_to_scf_for",
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
@@ -1080,6 +1156,10 @@ def TileToScfForOp : Op<Transform_Dialect, "structured.tile_to_scf_for",
}];
}
+//===----------------------------------------------------------------------===//
+// VectorizeOp
+//===----------------------------------------------------------------------===//
+
def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
TransformEachOpTrait, TransformOpInterface]> {
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index c8995e609e77..3138268e57b1 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -34,16 +34,6 @@ using namespace mlir::transform;
#define DEBUG_TYPE "linalg-transforms"
-/// Extracts a vector of unsigned from an array attribute. Asserts if the
-/// attribute contains values other than intergers. May truncate.
-static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) {
- SmallVector<unsigned> result;
- result.reserve(attr.size());
- for (APInt value : attr.getAsValueRange<IntegerAttr>())
- result.push_back(value.getZExtValue());
- return result;
-}
-
/// Attempts to apply the pattern specified as template argument to the given
/// operation. The pattern is expected to have a `returningMatchAndRewrite`
/// function that returns the "main" result or failure. Returns failure if the
@@ -604,8 +594,7 @@ DiagnosedSilenceableFailure
transform::InterchangeOp::applyToOne(linalg::GenericOp target,
SmallVectorImpl<Operation *> &results,
transform::TransformState &state) {
- SmallVector<unsigned> interchangeVector =
- extractUIntArray(getIteratorInterchange());
+ ArrayRef<int64_t> interchangeVector = getIteratorInterchange();
// Exit early if no transformation is needed.
if (interchangeVector.empty()) {
results.push_back(target);
@@ -613,7 +602,9 @@ transform::InterchangeOp::applyToOne(linalg::GenericOp target,
}
TrivialPatternRewriter rewriter(target->getContext());
FailureOr<GenericOp> res =
- interchangeGenericOp(rewriter, target, interchangeVector);
+ interchangeGenericOp(rewriter, target,
+ SmallVector<unsigned>(interchangeVector.begin(),
+ interchangeVector.end()));
if (failed(res))
return DiagnosedSilenceableFailure::definiteFailure();
results.push_back(res->getOperation());
@@ -621,9 +612,8 @@ transform::InterchangeOp::applyToOne(linalg::GenericOp target,
}
LogicalResult transform::InterchangeOp::verify() {
- SmallVector<unsigned> permutation =
- extractUIntArray(getIteratorInterchange());
- auto sequence = llvm::to_vector(llvm::seq<unsigned>(0, permutation.size()));
+ ArrayRef<int64_t> permutation = getIteratorInterchange();
+ auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
if (!std::is_permutation(sequence.begin(), sequence.end(),
permutation.begin(), permutation.end())) {
return emitOpError()
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index e3607b2f0b96..402e80b0af6b 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -257,7 +257,7 @@ LogicalResult transform::AlternativesOp::verify() {
}
//===----------------------------------------------------------------------===//
-// ForeachOp
+// CastOp
//===----------------------------------------------------------------------===//
DiagnosedSilenceableFailure
diff --git a/mlir/test/Dialect/Linalg/transform-op-interchange.mlir b/mlir/test/Dialect/Linalg/transform-op-interchange.mlir
index 0f3a9fc0d2a3..3b480d70f096 100644
--- a/mlir/test/Dialect/Linalg/transform-op-interchange.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-interchange.mlir
@@ -21,7 +21,7 @@ func.func @interchange_generic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -
transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
- transform.structured.interchange %0 { iterator_interchange = [1, 0]}
+ transform.structured.interchange %0 iterator_interchange = [1, 0]
}
// -----
@@ -36,5 +36,5 @@ transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
// expected-error @below {{transform applied to the wrong op kind}}
- transform.structured.interchange %0 { iterator_interchange = [1, 0]}
+ transform.structured.interchange %0 iterator_interchange = [1, 0]
}
diff --git a/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir b/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir
index 01bb8e8dcd4d..e21b21a8fa40 100644
--- a/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir
@@ -2,8 +2,8 @@
transform.sequence failures(propagate) {
^bb0(%arg0: !pdl.operation):
- // expected-error@below {{expects iterator_interchange to be a permutation, found [1, 1]}}
- transform.structured.interchange %arg0 {iterator_interchange = [1, 1]}
+ // expected-error@below {{'transform.structured.interchange' op expects iterator_interchange to be a permutation, found 1, 1}}
+ transform.structured.interchange %arg0 iterator_interchange = [1, 1]
}
// -----
@@ -37,3 +37,11 @@ transform.sequence failures(propagate) {
// expected-error@below {{expects transpose_paddings to be a permutation, found [1, 1]}}
transform.structured.pad %arg0 {transpose_paddings=[[1, 1]]}
}
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !pdl.operation):
+ // expected-error@below {{'transform.structured.interchange' op attribute 'iterator_interchange' failed to satisfy constraint: i64 dense array attribute whose value is non-negative}}
+ transform.structured.interchange %arg0 iterator_interchange = [-3, 1]
+}
diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
index 482cbc786d48..65ff4d62809a 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -138,7 +138,7 @@ func.func @permute_generic(%A: memref<?x?xf32, strided<[?, 1], offset: ?>>,
transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
- transform.structured.interchange %0 {iterator_interchange = [1, 2, 0]}
+ transform.structured.interchange %0 iterator_interchange = [1, 2, 0]
}
// CHECK-LABEL: func @permute_generic
More information about the Mlir-commits
mailing list