[Mlir-commits] [mlir] 2e5fe72 - [MLIR][Linalg] Use `DenseI64ArrayAttr` in `InterchangeOp` (NFC)

Lorenzo Chelini llvmlistbot at llvm.org
Fri Dec 16 07:37:38 PST 2022


Author: Lorenzo Chelini
Date: 2022-12-16T16:37:33+01:00
New Revision: 2e5fe721724446265d1ea48267b6a34d33fca14b

URL: https://github.com/llvm/llvm-project/commit/2e5fe721724446265d1ea48267b6a34d33fca14b
DIFF: https://github.com/llvm/llvm-project/commit/2e5fe721724446265d1ea48267b6a34d33fca14b.diff

LOG: [MLIR][Linalg] Use `DenseI64ArrayAttr` in `InterchangeOp` (NFC)

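Also add op separators to improve code navigation. Note that, despite the NFC tag, the textual syntax of `transform.structured.interchange` changes: the interchange is now written as `iterator_interchange = [...]` after the target operand instead of inside the attribute dictionary, and negative entries are rejected by an attribute constraint.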

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D139917

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/test/Dialect/Linalg/transform-op-interchange.mlir
    mlir/test/Dialect/Linalg/transform-ops-invalid.mlir
    mlir/test/Dialect/Linalg/transform-patterns.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 1cac6b83a1e9..1cb321d76d03 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -19,6 +19,10 @@ include "mlir/IR/EnumAttr.td"
 include "mlir/IR/OpBase.td"
 include "mlir/IR/RegionKindInterface.td"
 
+//===----------------------------------------------------------------------===//
+// DecomposeOp
+//===----------------------------------------------------------------------===//
+
 def DecomposeOp : Op<Transform_Dialect, "structured.decompose",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
      TransformOpInterface, TransformEachOpTrait]> {
@@ -48,6 +52,10 @@ def DecomposeOp : Op<Transform_Dialect, "structured.decompose",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// FuseOp
+//===----------------------------------------------------------------------===//
+
 def FuseOp : Op<Transform_Dialect, "structured.fuse",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
      DeclareOpInterfaceMethods<TransformOpInterface>]> {
@@ -67,6 +75,10 @@ def FuseOp : Op<Transform_Dialect, "structured.fuse",
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// FuseIntoContainingOp
+//===----------------------------------------------------------------------===//
+
 def FuseIntoContainingOp :
     Op<Transform_Dialect, "structured.fuse_into_containing_op",
       [DeclareOpInterfaceMethods<TransformOpInterface>]> {
@@ -120,6 +132,10 @@ def FuseIntoContainingOp :
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// GeneralizeOp
+//===----------------------------------------------------------------------===//
+
 def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
      TransformOpInterface, TransformEachOpTrait]> {
@@ -149,6 +165,10 @@ def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// InterchangeOp
+//===----------------------------------------------------------------------===//
+
 def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
     TransformOpInterface, TransformEachOpTrait]> {
@@ -169,10 +189,14 @@ def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
 
   let arguments =
     (ins PDL_Operation:$target,
-         DefaultValuedAttr<I64ArrayAttr, "{}">:$iterator_interchange);
+         ConfinedAttr<DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">,
+                      [DenseArrayNonNegative<DenseI64ArrayAttr>]>:$iterator_interchange);
   let results = (outs PDL_Operation:$transformed);
 
-  let assemblyFormat = "$target attr-dict";
+  let assemblyFormat = [{ 
+    $target 
+    (`iterator_interchange` `=` $iterator_interchange^)? attr-dict
+  }];
   let hasVerifier = 1;
 
   let extraClassDeclaration = [{
@@ -183,6 +207,10 @@ def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// MatchOp
+//===----------------------------------------------------------------------===//
+
 def MatchInterfaceEnum : I32EnumAttr<"MatchInterfaceEnum", "An interface to match",
     [
       I32EnumAttrCase<"LinalgOp", 0>,
@@ -245,6 +273,10 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// MultiTileSizesOp
+//===----------------------------------------------------------------------===//
+
 def MultiTileSizesOp : Op<Transform_Dialect, "structured.multitile_sizes",
     [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
      TransformOpInterface, TransformEachOpTrait]> {
@@ -309,6 +341,10 @@ def MultiTileSizesOp : Op<Transform_Dialect, "structured.multitile_sizes",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// PadOp
+//===----------------------------------------------------------------------===//
+
 def PadOp : Op<Transform_Dialect, "structured.pad",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
      TransformOpInterface, TransformEachOpTrait]> {
@@ -349,6 +385,10 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// PromoteOp
+//===----------------------------------------------------------------------===//
+
 def PromoteOp : Op<Transform_Dialect, "structured.promote",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
     TransformOpInterface, TransformEachOpTrait]> {
@@ -388,6 +428,10 @@ def PromoteOp : Op<Transform_Dialect, "structured.promote",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// ReplaceOp
+//===----------------------------------------------------------------------===//
+
 def ReplaceOp : Op<Transform_Dialect, "structured.replace",
     [IsolatedFromAbove, DeclareOpInterfaceMethods<TransformOpInterface>,
      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>] # GraphRegionNoTerminator.traits> {
@@ -410,6 +454,10 @@ def ReplaceOp : Op<Transform_Dialect, "structured.replace",
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// ScalarizeOp
+//===----------------------------------------------------------------------===//
+
 def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
      TransformOpInterface, TransformEachOpTrait]> {
@@ -449,6 +497,10 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// SplitOp
+//===----------------------------------------------------------------------===//
+
 def SplitOp : Op<Transform_Dialect, "structured.split",
     [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
      DeclareOpInterfaceMethods<TransformOpInterface>]> {
@@ -481,6 +533,10 @@ def SplitOp : Op<Transform_Dialect, "structured.split",
   let hasCustomAssemblyFormat = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// SplitReductionOp
+//===----------------------------------------------------------------------===//
+
 def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
        [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
         TransformEachOpTrait, TransformOpInterface]> {
@@ -649,6 +705,10 @@ def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// TileReductionUsingScfOp
+//===----------------------------------------------------------------------===//
+
 def TileReductionUsingScfOp : Op<Transform_Dialect, "structured.tile_reduction_using_scf",
        [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
         TransformEachOpTrait, TransformOpInterface]> {
@@ -748,6 +808,10 @@ def TileReductionUsingScfOp : Op<Transform_Dialect, "structured.tile_reduction_u
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// TileReductionUsingForeachThreadOp
+//===----------------------------------------------------------------------===//
+
 def TileReductionUsingForeachThreadOp :
   Op<Transform_Dialect, "structured.tile_reduction_using_foreach_thread",
        [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
@@ -853,6 +917,10 @@ def TileReductionUsingForeachThreadOp :
 
 }
 
+//===----------------------------------------------------------------------===//
+// TileOp
+//===----------------------------------------------------------------------===//
+
 def TileOp : Op<Transform_Dialect, "structured.tile",
        [DeclareOpInterfaceMethods<TransformOpInterface>,
         DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
@@ -910,6 +978,10 @@ def TileOp : Op<Transform_Dialect, "structured.tile",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// TileToForeachThreadOp
+//===----------------------------------------------------------------------===//
+
 def TileToForeachThreadOp :
     Op<Transform_Dialect, "structured.tile_to_foreach_thread_op",
       [AttrSizedOperandSegments,
@@ -1023,6 +1095,10 @@ def TileToForeachThreadOp :
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// TileToScfForOp
+//===----------------------------------------------------------------------===//
+
 def TileToScfForOp : Op<Transform_Dialect, "structured.tile_to_scf_for",
        [DeclareOpInterfaceMethods<TransformOpInterface>,
         DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
@@ -1080,6 +1156,10 @@ def TileToScfForOp : Op<Transform_Dialect, "structured.tile_to_scf_for",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// VectorizeOp
+//===----------------------------------------------------------------------===//
+
 def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
      TransformEachOpTrait, TransformOpInterface]> {

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index c8995e609e77..3138268e57b1 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -34,16 +34,6 @@ using namespace mlir::transform;
 
 #define DEBUG_TYPE "linalg-transforms"
 
-/// Extracts a vector of unsigned from an array attribute. Asserts if the
-/// attribute contains values other than intergers. May truncate.
-static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) {
-  SmallVector<unsigned> result;
-  result.reserve(attr.size());
-  for (APInt value : attr.getAsValueRange<IntegerAttr>())
-    result.push_back(value.getZExtValue());
-  return result;
-}
-
 /// Attempts to apply the pattern specified as template argument to the given
 /// operation. The pattern is expected to have a `returningMatchAndRewrite`
 /// function that returns the "main" result or failure. Returns failure if the
@@ -604,8 +594,7 @@ DiagnosedSilenceableFailure
 transform::InterchangeOp::applyToOne(linalg::GenericOp target,
                                      SmallVectorImpl<Operation *> &results,
                                      transform::TransformState &state) {
-  SmallVector<unsigned> interchangeVector =
-      extractUIntArray(getIteratorInterchange());
+  ArrayRef<int64_t> interchangeVector = getIteratorInterchange();
   // Exit early if no transformation is needed.
   if (interchangeVector.empty()) {
     results.push_back(target);
@@ -613,7 +602,9 @@ transform::InterchangeOp::applyToOne(linalg::GenericOp target,
   }
   TrivialPatternRewriter rewriter(target->getContext());
   FailureOr<GenericOp> res =
-      interchangeGenericOp(rewriter, target, interchangeVector);
+      interchangeGenericOp(rewriter, target,
+                           SmallVector<unsigned>(interchangeVector.begin(),
+                                                 interchangeVector.end()));
   if (failed(res))
     return DiagnosedSilenceableFailure::definiteFailure();
   results.push_back(res->getOperation());
@@ -621,9 +612,8 @@ transform::InterchangeOp::applyToOne(linalg::GenericOp target,
 }
 
 LogicalResult transform::InterchangeOp::verify() {
-  SmallVector<unsigned> permutation =
-      extractUIntArray(getIteratorInterchange());
-  auto sequence = llvm::to_vector(llvm::seq<unsigned>(0, permutation.size()));
+  ArrayRef<int64_t> permutation = getIteratorInterchange();
+  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
   if (!std::is_permutation(sequence.begin(), sequence.end(),
                            permutation.begin(), permutation.end())) {
     return emitOpError()

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index e3607b2f0b96..402e80b0af6b 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -257,7 +257,7 @@ LogicalResult transform::AlternativesOp::verify() {
 }
 
 //===----------------------------------------------------------------------===//
-// ForeachOp
+// CastOp
 //===----------------------------------------------------------------------===//
 
 DiagnosedSilenceableFailure

diff  --git a/mlir/test/Dialect/Linalg/transform-op-interchange.mlir b/mlir/test/Dialect/Linalg/transform-op-interchange.mlir
index 0f3a9fc0d2a3..3b480d70f096 100644
--- a/mlir/test/Dialect/Linalg/transform-op-interchange.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-interchange.mlir
@@ -21,7 +21,7 @@ func.func @interchange_generic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
-  transform.structured.interchange %0 { iterator_interchange = [1, 0]}
+  transform.structured.interchange %0 iterator_interchange = [1, 0] 
 }
 
 // -----
@@ -36,5 +36,5 @@ transform.sequence failures(propagate) {
 ^bb1(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
   // expected-error @below {{transform applied to the wrong op kind}}
-  transform.structured.interchange %0 { iterator_interchange = [1, 0]}
+  transform.structured.interchange %0 iterator_interchange = [1, 0]
 }

diff  --git a/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir b/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir
index 01bb8e8dcd4d..e21b21a8fa40 100644
--- a/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir
@@ -2,8 +2,8 @@
 
 transform.sequence failures(propagate) {
 ^bb0(%arg0: !pdl.operation):
-  // expected-error@below {{expects iterator_interchange to be a permutation, found [1, 1]}}
-  transform.structured.interchange %arg0 {iterator_interchange = [1, 1]}
+  // expected-error@below {{'transform.structured.interchange' op expects iterator_interchange to be a permutation, found 1, 1}}
+  transform.structured.interchange %arg0 iterator_interchange = [1, 1] 
 }
 
 // -----
@@ -37,3 +37,11 @@ transform.sequence failures(propagate) {
   // expected-error@below {{expects transpose_paddings to be a permutation, found [1, 1]}}
   transform.structured.pad %arg0 {transpose_paddings=[[1, 1]]}
 }
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !pdl.operation):
+  // expected-error@below {{'transform.structured.interchange' op attribute 'iterator_interchange' failed to satisfy constraint: i64 dense array attribute whose value is non-negative}}
+  transform.structured.interchange %arg0 iterator_interchange = [-3, 1]
+}

diff  --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
index 482cbc786d48..65ff4d62809a 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -138,7 +138,7 @@ func.func @permute_generic(%A: memref<?x?xf32, strided<[?, 1], offset: ?>>,
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
-  transform.structured.interchange %0 {iterator_interchange = [1, 2, 0]}
+  transform.structured.interchange %0 iterator_interchange = [1, 2, 0]
 }
 
 // CHECK-LABEL:  func @permute_generic


        


More information about the Mlir-commits mailing list