[Mlir-commits] [mlir] a0e0d30 - [mlir][Linalg] Print both types for linalg.transpose

Benjamin Kramer llvmlistbot at llvm.org
Fri Sep 11 02:18:38 PDT 2020


Author: Benjamin Kramer
Date: 2020-09-11T11:16:51+02:00
New Revision: a0e0d30a29841fe6cc854f3949f12bb523814d7a

URL: https://github.com/llvm/llvm-project/commit/a0e0d30a29841fe6cc854f3949f12bb523814d7a
DIFF: https://github.com/llvm/llvm-project/commit/a0e0d30a29841fe6cc854f3949f12bb523814d7a.diff

LOG: [mlir][Linalg] Print both types for linalg.transpose

Previously only the input type was printed, and the parser applied it to
both the input and the output, creating an invalid transpose. Print and
parse both types, and verify that the output type matches the type
inferred from the input type and the permutation map.

Differential Revision: https://reviews.llvm.org/D87462

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/invalid.mlir
    mlir/test/Dialect/Linalg/llvm.mlir
    mlir/test/Dialect/Linalg/roundtrip.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 1366e920039b..a7855e6327b2 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -300,7 +300,7 @@ def Linalg_TransposeOp : Linalg_Op<"transpose", [NoSideEffect]>,
     Example:
 
     ```mlir
-    %1 = linalg.transpose %0 (i, j) -> (j, i) : memref<?x?xf32, stride_spec>
+    %1 = linalg.transpose %0 (i, j) -> (j, i) : memref<?x?xf32> to memref<?x?xf32, stride_spec>
     ```
   }];
 
@@ -308,13 +308,7 @@ def Linalg_TransposeOp : Linalg_Op<"transpose", [NoSideEffect]>,
     "OpBuilder &b, OperationState &result, Value view, "
     "AffineMapAttr permutation, ArrayRef<NamedAttribute> attrs = {}">];
 
-  let verifier = [{
-    if (!permutation().isPermutation())
-      return emitOpError("expected a permutation map");
-    if (permutation().getNumDims() != getShapedType().getRank())
-      return emitOpError("expected a permutation map of same rank as the view");
-    return success();
-  }];
+  let verifier = [{ return ::verify(*this); }];
 
   let extraClassDeclaration = [{
     static StringRef getPermutationAttrName() { return "permutation"; }

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index fcead984dfe5..77eb64489477 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -846,13 +846,9 @@ Value SliceOp::getViewSource() { return view(); }
 //===----------------------------------------------------------------------===//
 // TransposeOp
 //===----------------------------------------------------------------------===//
-void mlir::linalg::TransposeOp::build(OpBuilder &b, OperationState &result,
-                                      Value view, AffineMapAttr permutation,
-                                      ArrayRef<NamedAttribute> attrs) {
-  auto permutationMap = permutation.getValue();
-  assert(permutationMap);
 
-  auto memRefType = view.getType().cast<MemRefType>();
+static MemRefType inferTransposeResultType(MemRefType memRefType,
+                                           AffineMap permutationMap) {
   auto rank = memRefType.getRank();
   auto originalSizes = memRefType.getShape();
   // Compute permuted sizes.
@@ -867,11 +863,21 @@ void mlir::linalg::TransposeOp::build(OpBuilder &b, OperationState &result,
   auto res = getStridesAndOffset(memRefType, strides, offset);
   assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
   (void)res;
-  auto map = makeStridedLinearLayoutMap(strides, offset, b.getContext());
+  auto map =
+      makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
   map = permutationMap ? map.compose(permutationMap) : map;
+  return MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map);
+}
+
+void mlir::linalg::TransposeOp::build(OpBuilder &b, OperationState &result,
+                                      Value view, AffineMapAttr permutation,
+                                      ArrayRef<NamedAttribute> attrs) {
+  auto permutationMap = permutation.getValue();
+  assert(permutationMap);
+
+  auto memRefType = view.getType().cast<MemRefType>();
   // Compute result type.
-  MemRefType resultType =
-      MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map);
+  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
 
   build(b, result, resultType, view, attrs);
   result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
@@ -881,19 +887,20 @@ static void print(OpAsmPrinter &p, TransposeOp op) {
   p << op.getOperationName() << " " << op.view() << " " << op.permutation();
   p.printOptionalAttrDict(op.getAttrs(),
                           {TransposeOp::getPermutationAttrName()});
-  p << " : " << op.view().getType();
+  p << " : " << op.view().getType() << " to " << op.getType();
 }
 
 static ParseResult parseTransposeOp(OpAsmParser &parser,
                                     OperationState &result) {
   OpAsmParser::OperandType view;
   AffineMap permutation;
-  MemRefType type;
+  MemRefType srcType, dstType;
   if (parser.parseOperand(view) || parser.parseAffineMap(permutation) ||
       parser.parseOptionalAttrDict(result.attributes) ||
-      parser.parseColonType(type) ||
-      parser.resolveOperand(view, type, result.operands) ||
-      parser.addTypeToList(type, result.types))
+      parser.parseColonType(srcType) ||
+      parser.resolveOperand(view, srcType, result.operands) ||
+      parser.parseKeywordType("to", dstType) ||
+      parser.addTypeToList(dstType, result.types))
     return failure();
 
   result.addAttribute(TransposeOp::getPermutationAttrName(),
@@ -901,6 +908,21 @@ static ParseResult parseTransposeOp(OpAsmParser &parser,
   return success();
 }
 
+static LogicalResult verify(TransposeOp op) {
+  if (!op.permutation().isPermutation())
+    return op.emitOpError("expected a permutation map");
+  if (op.permutation().getNumDims() != op.getShapedType().getRank())
+    return op.emitOpError(
+        "expected a permutation map of same rank as the view");
+
+  auto srcType = op.view().getType().cast<MemRefType>();
+  auto dstType = op.getType().cast<MemRefType>();
+  if (dstType != inferTransposeResultType(srcType, op.permutation()))
+    return op.emitOpError("output type ")
+           << dstType << " does not match transposed input type " << srcType;
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // YieldOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index ca59ecd387ec..c631c47099b0 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -35,14 +35,21 @@ func @store_number_of_indices(%v : memref<f32>) {
 
 func @transpose_not_permutation(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
   // expected-error @+1 {{expected a permutation map}}
-  linalg.transpose %v (i, j) -> (i, i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
+  linalg.transpose %v (i, j) -> (i, i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
 }
 
 // -----
 
 func @transpose_bad_rank(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
   // expected-error @+1 {{expected a permutation map of same rank as the view}}
-  linalg.transpose %v (i) -> (i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
+  linalg.transpose %v (i) -> (i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
+}
+
+// -----
+
+func @transpose_wrong_type(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
+  // expected-error @+1 {{output type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' does not match transposed input type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>'}}
+  linalg.transpose %v (i, j) -> (j, i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
 }
 
 // -----

diff  --git a/mlir/test/Dialect/Linalg/llvm.mlir b/mlir/test/Dialect/Linalg/llvm.mlir
index 02693e5d1be4..c8031824d630 100644
--- a/mlir/test/Dialect/Linalg/llvm.mlir
+++ b/mlir/test/Dialect/Linalg/llvm.mlir
@@ -70,7 +70,7 @@ func @slice_with_range_and_index(%arg0: memref<?x?xf64, offset: ?, strides: [?,
 //       CHECK:   llvm.insertvalue %{{.*}}[4, 0] : !llvm.struct<(ptr<double>, ptr<double>, i64, array<1 x i64>, array<1 x i64>)>
 
 func @transpose(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
-  %0 = linalg.transpose %arg0 (i, j, k) -> (k, i, j) : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
+  %0 = linalg.transpose %arg0 (i, j, k) -> (k, i, j) : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]> to memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2 * s1 + s0 + d0 * s2 + d1)>>
   return
 }
 // CHECK-LABEL: func @transpose

diff  --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 269664324697..404c978fa61b 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -123,14 +123,15 @@ func @fill_view(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: f32) {
 // -----
 
 // CHECK-DAG: #[[$strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>
+// CHECK-DAG: #[[$strided3DT:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2 * s1 + s0 + d1 * s2 + d0)>
 
 func @transpose(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
-  %0 = linalg.transpose %arg0 (i, j, k) -> (k, j, i) : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
+  %0 = linalg.transpose %arg0 (i, j, k) -> (k, j, i) : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]> to memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2 * s1 + s0 + d1 * s2 + d0)>>
   return
 }
 // CHECK-LABEL: func @transpose
 //       CHECK:   linalg.transpose %{{.*}} ([[i:.*]], [[j:.*]], [[k:.*]]) -> ([[k]], [[j]], [[i]]) :
-//  CHECK-SAME:      memref<?x?x?xf32, #[[$strided3D]]>
+//  CHECK-SAME:      memref<?x?x?xf32, #[[$strided3D]]> to memref<?x?x?xf32, #[[$strided3DT]]>
 
 // -----
 


        


More information about the Mlir-commits mailing list