[Mlir-commits] [mlir] 6e2b267 - Promote transpose from linalg to standard dialect

Benjamin Kramer llvmlistbot at llvm.org
Mon Oct 5 02:13:31 PDT 2020


Author: Benjamin Kramer
Date: 2020-10-05T10:58:20+02:00
New Revision: 6e2b267d1c85ce0de0e91eb446831607896a0f2b

URL: https://github.com/llvm/llvm-project/commit/6e2b267d1c85ce0de0e91eb446831607896a0f2b
DIFF: https://github.com/llvm/llvm-project/commit/6e2b267d1c85ce0de0e91eb446831607896a0f2b.diff

LOG: Promote transpose from linalg to standard dialect

While affine maps are part of the builtin memref type, there is very
limited support for manipulating them in the standard dialect. Add
transpose to the set of ops to complement the existing view/subview ops.
This is a metadata transformation that encodes the transpose into the
strides of a memref.

I'm planning to use this when lowering operations on strided memrefs,
using the transpose to remove the stride without adding a dependency on
the linalg dialect.

Differential Revision: https://reviews.llvm.org/D88651

Added: 
    

Modified: 
    mlir/docs/Dialects/Linalg.md
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
    mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
    mlir/test/Dialect/Linalg/invalid.mlir
    mlir/test/Dialect/Linalg/llvm.mlir
    mlir/test/Dialect/Linalg/roundtrip.mlir
    mlir/test/Dialect/Linalg/standard.mlir
    mlir/test/Dialect/Standard/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Dialects/Linalg.md b/mlir/docs/Dialects/Linalg.md
index 140197b16815..c6681a93e53e 100644
--- a/mlir/docs/Dialects/Linalg.md
+++ b/mlir/docs/Dialects/Linalg.md
@@ -554,9 +554,9 @@ are:
 
     * `std.view`,
     * `std.subview`,
+    * `std.transpose`,
     * `linalg.range`,
     * `linalg.slice`,
-    * `linalg.transpose`.
     * `linalg.reshape`,
 
 Future ops are added on a per-need basis but should include:

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index d74e59145705..5b29154e0a03 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -287,36 +287,6 @@ def Linalg_SliceOp : Linalg_Op<"slice", [
   let hasFolder = 1;
 }
 
-def Linalg_TransposeOp : Linalg_Op<"transpose", [NoSideEffect]>,
-    Arguments<(ins AnyStridedMemRef:$view, AffineMapAttr:$permutation)>,
-    Results<(outs AnyStridedMemRef)> {
-  let summary = "`transpose` produces a new strided memref (metadata-only)";
-  let description = [{
-    The `linalg.transpose` op produces a strided memref whose sizes and strides
-    are a permutation of the original `view`. This is a pure metadata
-    transformation.
-
-    Example:
-
-    ```mlir
-    %1 = linalg.transpose %0 (i, j) -> (j, i) : memref<?x?xf32> to memref<?x?xf32, stride_spec>
-    ```
-  }];
-
-  let builders = [OpBuilder<
-    "OpBuilder &b, OperationState &result, Value view, "
-    "AffineMapAttr permutation, ArrayRef<NamedAttribute> attrs = {}">];
-
-  let verifier = [{ return ::verify(*this); }];
-
-  let extraClassDeclaration = [{
-    static StringRef getPermutationAttrName() { return "permutation"; }
-    ShapedType getShapedType() { return view().getType().cast<ShapedType>(); }
-  }];
-
-  let hasFolder = 1;
-}
-
 def Linalg_YieldOp : Linalg_Op<"yield", [NoSideEffect, Terminator]>,
     Arguments<(ins Variadic<AnyType>:$values)> {
   let summary = "Linalg yield operation";

diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index c62be7571aad..4a014cb7060c 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -3428,6 +3428,38 @@ def TensorStoreOp : Std_Op<"tensor_store",
   let assemblyFormat = "$tensor `,` $memref attr-dict `:` type($memref)";
 }
 
+//===----------------------------------------------------------------------===//
+// TransposeOp
+//===----------------------------------------------------------------------===//
+
+def TransposeOp : Std_Op<"transpose", [NoSideEffect]>,
+    Arguments<(ins AnyStridedMemRef:$in, AffineMapAttr:$permutation)>,
+    Results<(outs AnyStridedMemRef)> {
+  let summary = "`transpose` produces a new strided memref (metadata-only)";
+  let description = [{
+    The `transpose` op produces a strided memref whose sizes and strides
+    are a permutation of the original `in` memref. This is purely a metadata
+    transformation.
+
+    Example:
+
+    ```mlir
+    %1 = transpose %0 (i, j) -> (j, i) : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>
+    ```
+  }];
+
+  let builders = [OpBuilder<
+    "OpBuilder &b, OperationState &result, Value in, "
+    "AffineMapAttr permutation, ArrayRef<NamedAttribute> attrs = {}">];
+
+  let extraClassDeclaration = [{
+    static StringRef getPermutationAttrName() { return "permutation"; }
+    ShapedType getShapedType() { return in().getType().cast<ShapedType>(); }
+  }];
+
+  let hasFolder = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // TruncateIOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
index f38eabb9465d..4f83297ee031 100644
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -284,57 +284,6 @@ class SliceOpConversion : public ConvertToLLVMPattern {
   }
 };
 
-/// Conversion pattern that transforms a linalg.transpose op into:
-///   1. A function entry `alloca` operation to allocate a ViewDescriptor.
-///   2. A load of the ViewDescriptor from the pointer allocated in 1.
-///   3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
-///      and stride. Size and stride are permutations of the original values.
-///   4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
-/// The linalg.transpose op is replaced by the alloca'ed pointer.
-class TransposeOpConversion : public ConvertToLLVMPattern {
-public:
-  explicit TransposeOpConversion(MLIRContext *context,
-                                 LLVMTypeConverter &lowering_)
-      : ConvertToLLVMPattern(TransposeOp::getOperationName(), context,
-                             lowering_) {}
-
-  LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    // Initialize the common boilerplate and alloca at the top of the FuncOp.
-    edsc::ScopedContext context(rewriter, op->getLoc());
-    TransposeOpAdaptor adaptor(operands);
-    BaseViewConversionHelper baseDesc(adaptor.view());
-
-    auto transposeOp = cast<TransposeOp>(op);
-    // No permutation, early exit.
-    if (transposeOp.permutation().isIdentity())
-      return rewriter.replaceOp(op, {baseDesc}), success();
-
-    BaseViewConversionHelper desc(
-        typeConverter.convertType(transposeOp.getShapedType()));
-
-    // Copy the base and aligned pointers from the old descriptor to the new
-    // one.
-    desc.setAllocatedPtr(baseDesc.allocatedPtr());
-    desc.setAlignedPtr(baseDesc.alignedPtr());
-
-    // Copy the offset pointer from the old descriptor to the new one.
-    desc.setOffset(baseDesc.offset());
-
-    // Iterate over the dimensions and apply size/stride permutation.
-    for (auto en : llvm::enumerate(transposeOp.permutation().getResults())) {
-      int sourcePos = en.index();
-      int targetPos = en.value().cast<AffineDimExpr>().getPosition();
-      desc.setSize(targetPos, baseDesc.size(sourcePos));
-      desc.setStride(targetPos, baseDesc.stride(sourcePos));
-    }
-
-    rewriter.replaceOp(op, {desc});
-    return success();
-  }
-};
-
 // YieldOp produces and LLVM::ReturnOp.
 class YieldOpConversion : public ConvertToLLVMPattern {
 public:
@@ -356,7 +305,7 @@ void mlir::populateLinalgToLLVMConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
     MLIRContext *ctx) {
   patterns.insert<RangeOpConversion, ReshapeOpConversion, SliceOpConversion,
-                  TransposeOpConversion, YieldOpConversion>(ctx, converter);
+                  YieldOpConversion>(ctx, converter);
 
   // Populate the type conversions for the linalg types.
   converter.addConversion(

diff  --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
index 29b5f9cc996e..ffb56138a795 100644
--- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
+++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
@@ -206,12 +206,12 @@ class CopyTransposeConversion : public OpRewritePattern<CopyOp> {
     // If either inputPerm or outputPerm are non-identities, insert transposes.
     auto inputPerm = op.inputPermutation();
     if (inputPerm.hasValue() && !inputPerm->isIdentity())
-      in = rewriter.create<linalg::TransposeOp>(op.getLoc(), in,
-                                                AffineMapAttr::get(*inputPerm));
+      in = rewriter.create<TransposeOp>(op.getLoc(), in,
+                                        AffineMapAttr::get(*inputPerm));
     auto outputPerm = op.outputPermutation();
     if (outputPerm.hasValue() && !outputPerm->isIdentity())
-      out = rewriter.create<linalg::TransposeOp>(
-          op.getLoc(), out, AffineMapAttr::get(*outputPerm));
+      out = rewriter.create<TransposeOp>(op.getLoc(), out,
+                                         AffineMapAttr::get(*outputPerm));
 
     // If nothing was transposed, fail and let the conversion kick in.
     if (in == op.input() && out == op.output())
@@ -270,7 +270,7 @@ void ConvertLinalgToStandardPass::runOnOperation() {
   ConversionTarget target(getContext());
   target.addLegalDialect<AffineDialect, scf::SCFDialect, StandardOpsDialect>();
   target.addLegalOp<ModuleOp, FuncOp, ModuleTerminatorOp, ReturnOp>();
-  target.addLegalOp<linalg::TransposeOp, linalg::ReshapeOp, linalg::RangeOp>();
+  target.addLegalOp<linalg::ReshapeOp, linalg::RangeOp>();
   OwningRewritePatternList patterns;
   populateLinalgToStandardConversionPatterns(patterns, &getContext());
   if (failed(applyFullConversion(module, target, patterns)))

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 37d0c940aa26..731eab0c28df 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -3011,6 +3011,57 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
   }
 };
 
+/// Conversion pattern that lowers a transpose op. An identity permutation is
+/// replaced by the source descriptor; otherwise the pattern emits:
+///   1. An `llvm.mlir.undef` operation to create a new memref descriptor.
+///   2. Copies of the allocated/aligned pointers and the offset from the
+///      source descriptor into the new one.
+///   3. Sizes and strides that are permutations of the source values.
+/// The transpose op is replaced by the new descriptor.
+class TransposeOpLowering : public ConvertOpToLLVMPattern<TransposeOp> {
+public:
+  using ConvertOpToLLVMPattern<TransposeOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+    TransposeOpAdaptor adaptor(operands);
+    MemRefDescriptor viewMemRef(adaptor.in());
+
+    auto transposeOp = cast<TransposeOp>(op);
+    // No permutation, early exit.
+    if (transposeOp.permutation().isIdentity())
+      return rewriter.replaceOp(op, {viewMemRef}), success();
+
+    auto targetMemRef = MemRefDescriptor::undef(
+        rewriter, loc, typeConverter.convertType(transposeOp.getShapedType()));
+
+    // Copy the base and aligned pointers from the old descriptor to the new
+    // one.
+    targetMemRef.setAllocatedPtr(rewriter, loc,
+                                 viewMemRef.allocatedPtr(rewriter, loc));
+    targetMemRef.setAlignedPtr(rewriter, loc,
+                               viewMemRef.alignedPtr(rewriter, loc));
+
+    // Copy the offset pointer from the old descriptor to the new one.
+    targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
+
+    // Iterate over the dimensions and apply size/stride permutation.
+    for (auto en : llvm::enumerate(transposeOp.permutation().getResults())) {
+      int sourcePos = en.index();
+      int targetPos = en.value().cast<AffineDimExpr>().getPosition();
+      targetMemRef.setSize(rewriter, loc, targetPos,
+                           viewMemRef.size(rewriter, loc, sourcePos));
+      targetMemRef.setStride(rewriter, loc, targetPos,
+                             viewMemRef.stride(rewriter, loc, sourcePos));
+    }
+
+    rewriter.replaceOp(op, {targetMemRef});
+    return success();
+  }
+};
+
 /// Conversion pattern that transforms an op into:
 ///   1. An `llvm.mlir.undef` operation to create a memref descriptor
 ///   2. Updates to the descriptor to introduce the data ptr, offset, size
@@ -3425,6 +3476,7 @@ void mlir::populateStdToLLVMMemoryConversionPatterns(
       RankOpLowering,
       StoreOpLowering,
       SubViewOpLowering,
+      TransposeOpLowering,
       ViewOpLowering,
       AllocOpLowering>(converter);
   // clang-format on

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index ca2260836d9f..e9cdb3391f4a 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -973,86 +973,6 @@ static LogicalResult verify(SliceOp op) {
 
 Value SliceOp::getViewSource() { return view(); }
 
-//===----------------------------------------------------------------------===//
-// TransposeOp
-//===----------------------------------------------------------------------===//
-
-static MemRefType inferTransposeResultType(MemRefType memRefType,
-                                           AffineMap permutationMap) {
-  auto rank = memRefType.getRank();
-  auto originalSizes = memRefType.getShape();
-  // Compute permuted sizes.
-  SmallVector<int64_t, 4> sizes(rank, 0);
-  for (auto en : llvm::enumerate(permutationMap.getResults()))
-    sizes[en.index()] =
-        originalSizes[en.value().cast<AffineDimExpr>().getPosition()];
-
-  // Compute permuted strides.
-  int64_t offset;
-  SmallVector<int64_t, 4> strides;
-  auto res = getStridesAndOffset(memRefType, strides, offset);
-  assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
-  (void)res;
-  auto map =
-      makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
-  map = permutationMap ? map.compose(permutationMap) : map;
-  return MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map);
-}
-
-void mlir::linalg::TransposeOp::build(OpBuilder &b, OperationState &result,
-                                      Value view, AffineMapAttr permutation,
-                                      ArrayRef<NamedAttribute> attrs) {
-  auto permutationMap = permutation.getValue();
-  assert(permutationMap);
-
-  auto memRefType = view.getType().cast<MemRefType>();
-  // Compute result type.
-  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
-
-  build(b, result, resultType, view, attrs);
-  result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
-}
-
-static void print(OpAsmPrinter &p, TransposeOp op) {
-  p << op.getOperationName() << " " << op.view() << " " << op.permutation();
-  p.printOptionalAttrDict(op.getAttrs(),
-                          {TransposeOp::getPermutationAttrName()});
-  p << " : " << op.view().getType() << " to " << op.getType();
-}
-
-static ParseResult parseTransposeOp(OpAsmParser &parser,
-                                    OperationState &result) {
-  OpAsmParser::OperandType view;
-  AffineMap permutation;
-  MemRefType srcType, dstType;
-  if (parser.parseOperand(view) || parser.parseAffineMap(permutation) ||
-      parser.parseOptionalAttrDict(result.attributes) ||
-      parser.parseColonType(srcType) ||
-      parser.resolveOperand(view, srcType, result.operands) ||
-      parser.parseKeywordType("to", dstType) ||
-      parser.addTypeToList(dstType, result.types))
-    return failure();
-
-  result.addAttribute(TransposeOp::getPermutationAttrName(),
-                      AffineMapAttr::get(permutation));
-  return success();
-}
-
-static LogicalResult verify(TransposeOp op) {
-  if (!op.permutation().isPermutation())
-    return op.emitOpError("expected a permutation map");
-  if (op.permutation().getNumDims() != op.getShapedType().getRank())
-    return op.emitOpError(
-        "expected a permutation map of same rank as the view");
-
-  auto srcType = op.view().getType().cast<MemRefType>();
-  auto dstType = op.getType().cast<MemRefType>();
-  if (dstType != inferTransposeResultType(srcType, op.permutation()))
-    return op.emitOpError("output type ")
-           << dstType << " does not match transposed input type " << srcType;
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // YieldOp
 //===----------------------------------------------------------------------===//
@@ -1359,11 +1279,6 @@ OpFoldResult SliceOp::fold(ArrayRef<Attribute>) {
 OpFoldResult TensorReshapeOp::fold(ArrayRef<Attribute> operands) {
   return foldReshapeOp(*this, operands);
 }
-OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
-  if (succeeded(foldMemRefCast(*this)))
-    return getResult();
-  return {};
-}
 
 //===----------------------------------------------------------------------===//
 // Auto-generated Linalg named ops.

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 09600963be0e..a4d739135aea 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -3491,6 +3491,96 @@ static Type getTensorTypeFromMemRefType(Type type) {
   return NoneType::get(type.getContext());
 }
 
+//===----------------------------------------------------------------------===//
+// TransposeOp
+//===----------------------------------------------------------------------===//
+
+/// Build a strided memref type by applying `permutationMap` to `memRefType`.
+static MemRefType inferTransposeResultType(MemRefType memRefType,
+                                           AffineMap permutationMap) {
+  auto rank = memRefType.getRank();
+  auto originalSizes = memRefType.getShape();
+  // Compute permuted sizes.
+  SmallVector<int64_t, 4> sizes(rank, 0);
+  for (auto en : llvm::enumerate(permutationMap.getResults()))
+    sizes[en.index()] =
+        originalSizes[en.value().cast<AffineDimExpr>().getPosition()];
+
+  // Compute permuted strides.
+  int64_t offset;
+  SmallVector<int64_t, 4> strides;
+  auto res = getStridesAndOffset(memRefType, strides, offset);
+  assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
+  (void)res;
+  auto map =
+      makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
+  map = permutationMap ? map.compose(permutationMap) : map;
+  return MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map);
+}
+
+void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
+                        AffineMapAttr permutation,
+                        ArrayRef<NamedAttribute> attrs) {
+  auto permutationMap = permutation.getValue();
+  assert(permutationMap);
+
+  auto memRefType = in.getType().cast<MemRefType>();
+  // Compute result type.
+  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
+
+  build(b, result, resultType, in, attrs);
+  result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
+}
+
+// transpose $in $permutation attr-dict : type($in) `to` type(results)
+static void print(OpAsmPrinter &p, TransposeOp op) {
+  p << "transpose " << op.in() << " " << op.permutation();
+  p.printOptionalAttrDict(op.getAttrs(),
+                          {TransposeOp::getPermutationAttrName()});
+  p << " : " << op.in().getType() << " to " << op.getType();
+}
+
+static ParseResult parseTransposeOp(OpAsmParser &parser,
+                                    OperationState &result) {
+  OpAsmParser::OperandType in;
+  AffineMap permutation;
+  MemRefType srcType, dstType;
+  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
+      parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseColonType(srcType) ||
+      parser.resolveOperand(in, srcType, result.operands) ||
+      parser.parseKeywordType("to", dstType) ||
+      parser.addTypeToList(dstType, result.types))
+    return failure();
+
+  result.addAttribute(TransposeOp::getPermutationAttrName(),
+                      AffineMapAttr::get(permutation));
+  return success();
+}
+
+static LogicalResult verify(TransposeOp op) {
+  if (!op.permutation().isPermutation())
+    return op.emitOpError("expected a permutation map");
+  if (op.permutation().getNumDims() != op.getShapedType().getRank())
+    return op.emitOpError(
+        "expected a permutation map of same rank as the input");
+
+  auto srcType = op.in().getType().cast<MemRefType>();
+  auto dstType = op.getType().cast<MemRefType>();
+  auto transposedType = inferTransposeResultType(srcType, op.permutation());
+  if (dstType != transposedType)
+    return op.emitOpError("output type ")
+           << dstType << " does not match transposed input type " << srcType
+           << ", " << transposedType;
+  return success();
+}
+
+OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
+  if (succeeded(foldMemRefCast(*this)))
+    return getResult();
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // TruncateIOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 663595ce161c..672ad4058309 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -673,7 +673,7 @@ static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
 
 /// Fold the result of an ExtractOp in place when it comes from a TransposeOp.
 static LogicalResult foldExtractOpFromTranspose(ExtractOp extractOp) {
-  auto transposeOp = extractOp.vector().getDefiningOp<TransposeOp>();
+  auto transposeOp = extractOp.vector().getDefiningOp<vector::TransposeOp>();
   if (!transposeOp)
     return failure();
 
@@ -2521,7 +2521,7 @@ void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
 // Eliminates transpose operations, which produce values identical to their
 // input values. This happens when the dimensions of the input vector remain in
 // their original order after the transpose operation.
-OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult vector::TransposeOp::fold(ArrayRef<Attribute> operands) {
   SmallVector<int64_t, 4> transp;
   getTransp(transp);
 
@@ -2535,7 +2535,7 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
   return vector();
 }
 
-static LogicalResult verify(TransposeOp op) {
+static LogicalResult verify(vector::TransposeOp op) {
   VectorType vectorType = op.getVectorType();
   VectorType resultType = op.getResultType();
   int64_t rank = resultType.getRank();
@@ -2563,14 +2563,14 @@ static LogicalResult verify(TransposeOp op) {
 namespace {
 
 // Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
-class TransposeFolder final : public OpRewritePattern<TransposeOp> {
+class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
 public:
-  using OpRewritePattern<TransposeOp>::OpRewritePattern;
+  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(TransposeOp transposeOp,
+  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                 PatternRewriter &rewriter) const override {
-    // Wrapper around TransposeOp::getTransp() for cleaner code.
-    auto getPermutation = [](TransposeOp transpose) {
+    // Wrapper around vector::TransposeOp::getTransp() for cleaner code.
+    auto getPermutation = [](vector::TransposeOp transpose) {
       SmallVector<int64_t, 4> permutation;
       transpose.getTransp(permutation);
       return permutation;
@@ -2586,15 +2586,15 @@ class TransposeFolder final : public OpRewritePattern<TransposeOp> {
     };
 
     // Return if the input of 'transposeOp' is not defined by another transpose.
-    TransposeOp parentTransposeOp =
-        transposeOp.vector().getDefiningOp<TransposeOp>();
+    vector::TransposeOp parentTransposeOp =
+        transposeOp.vector().getDefiningOp<vector::TransposeOp>();
     if (!parentTransposeOp)
       return failure();
 
     SmallVector<int64_t, 4> permutation = composePermutations(
         getPermutation(parentTransposeOp), getPermutation(transposeOp));
     // Replace 'transposeOp' with a new transpose operation.
-    rewriter.replaceOpWithNewOp<TransposeOp>(
+    rewriter.replaceOpWithNewOp<vector::TransposeOp>(
         transposeOp, transposeOp.getResult().getType(),
         parentTransposeOp.vector(),
         vector::getVectorSubscriptAttr(rewriter, permutation));
@@ -2604,12 +2604,12 @@ class TransposeFolder final : public OpRewritePattern<TransposeOp> {
 
 } // end anonymous namespace
 
-void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
-                                              MLIRContext *context) {
+void vector::TransposeOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
   results.insert<TransposeFolder>(context);
 }
 
-void TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
+void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
   populateFromInt64AttrArray(transp(), results);
 }
 

diff  --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
index c7363085817e..71a35f6ccf0a 100644
--- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
@@ -114,3 +114,20 @@ func @assert_test_function(%arg : i1) {
   return
 }
 
+// -----
+
+// CHECK-LABEL: func @transpose
+//       CHECK:   llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64, array<3 x i64>, array<3 x i64>)>
+//       CHECK:   llvm.insertvalue {{.*}}[0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<3 x i64>, array<3 x i64>)>
+//       CHECK:    llvm.insertvalue {{.*}}[1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<3 x i64>, array<3 x i64>)>
+//       CHECK:    llvm.insertvalue {{.*}}[2] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<3 x i64>, array<3 x i64>)>
+//       CHECK:   llvm.extractvalue {{.*}}[3, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<3 x i64>, array<3 x i64>)>
+//       CHECK:    llvm.insertvalue {{.*}}[3, 2] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<3 x i64>, array<3 x i64>)>
+//       CHECK:   llvm.extractvalue {{.*}}[3, 1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<3 x i64>, array<3 x i64>)>
+//       CHECK:    llvm.insertvalue {{.*}}[3, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<3 x i64>, array<3 x i64>)>
+//       CHECK:   llvm.extractvalue {{.*}}[3, 2] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<3 x i64>, array<3 x i64>)>
+//       CHECK:    llvm.insertvalue {{.*}}[3, 1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<3 x i64>, array<3 x i64>)>
+func @transpose(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
+  %0 = transpose %arg0 (i, j, k) -> (k, i, j) : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]> to memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2 * s1 + s0 + d0 * s2 + d1)>>
+  return
+}

diff  --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 004bf9260a82..dcfafdc4d27a 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -33,27 +33,6 @@ func @store_number_of_indices(%v : memref<f32>) {
 
 // -----
 
-func @transpose_not_permutation(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
-  // expected-error @+1 {{expected a permutation map}}
-  linalg.transpose %v (i, j) -> (i, i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
-}
-
-// -----
-
-func @transpose_bad_rank(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
-  // expected-error @+1 {{expected a permutation map of same rank as the view}}
-  linalg.transpose %v (i) -> (i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
-}
-
-// -----
-
-func @transpose_wrong_type(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
-  // expected-error @+1 {{output type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' does not match transposed input type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>'}}
-  linalg.transpose %v (i, j) -> (j, i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
-}
-
-// -----
-
 func @yield_parent(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
   // expected-error @+1 {{op expected parent op with LinalgOp interface}}
   linalg.yield %arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>

diff  --git a/mlir/test/Dialect/Linalg/llvm.mlir b/mlir/test/Dialect/Linalg/llvm.mlir
index c8031824d630..9303a7aa6b31 100644
--- a/mlir/test/Dialect/Linalg/llvm.mlir
+++ b/mlir/test/Dialect/Linalg/llvm.mlir
@@ -69,22 +69,6 @@ func @slice_with_range_and_index(%arg0: memref<?x?xf64, offset: ?, strides: [?,
 //       CHECK:   llvm.insertvalue %{{.*}}[3, 0] : !llvm.struct<(ptr<double>, ptr<double>, i64, array<1 x i64>, array<1 x i64>)>
 //       CHECK:   llvm.insertvalue %{{.*}}[4, 0] : !llvm.struct<(ptr<double>, ptr<double>, i64, array<1 x i64>, array<1 x i64>)>
 
-func @transpose(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
-  %0 = linalg.transpose %arg0 (i, j, k) -> (k, i, j) : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]> to memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2 * s1 + s0 + d0 * s2 + d1)>>
-  return
-}
-// CHECK-LABEL: func @transpose
-//       CHECK:   llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:   llvm.insertvalue {{.*}}[0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:    llvm.insertvalue {{.*}}[1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:    llvm.insertvalue {{.*}}[2] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:   llvm.extractvalue {{.*}}[3, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:    llvm.insertvalue {{.*}}[3, 2] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:   llvm.extractvalue {{.*}}[3, 1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:    llvm.insertvalue {{.*}}[3, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:   llvm.extractvalue {{.*}}[3, 2] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:    llvm.insertvalue {{.*}}[3, 1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<3 x i64>, array<3 x i64>)>
-
 func @reshape_static_expand(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> {
   // Reshapes that expand a contiguous tensor with some 1's.
   %0 = linalg.reshape %arg0 [affine_map<(i, j, k, l, m) -> (i, j)>,

diff  --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 5960d5525f44..868cabb5eff3 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -126,11 +126,11 @@ func @fill_view(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: f32) {
 // CHECK-DAG: #[[$strided3DT:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2 * s1 + s0 + d1 * s2 + d0)>
 
 func @transpose(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
-  %0 = linalg.transpose %arg0 (i, j, k) -> (k, j, i) : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]> to memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2 * s1 + s0 + d1 * s2 + d0)>>
+  %0 = transpose %arg0 (i, j, k) -> (k, j, i) : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]> to memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2 * s1 + s0 + d1 * s2 + d0)>>
   return
 }
 // CHECK-LABEL: func @transpose
-//       CHECK:   linalg.transpose %{{.*}} ([[i:.*]], [[j:.*]], [[k:.*]]) -> ([[k]], [[j]], [[i]]) :
+//       CHECK:   transpose %{{.*}} ([[i:.*]], [[j:.*]], [[k:.*]]) -> ([[k]], [[j]], [[i]]) :
 //  CHECK-SAME:      memref<?x?x?xf32, #[[$strided3D]]> to memref<?x?x?xf32, #[[$strided3DT]]>
 
 // -----

diff  --git a/mlir/test/Dialect/Linalg/standard.mlir b/mlir/test/Dialect/Linalg/standard.mlir
index 14b4e2a01c30..eee2ca1d1a1c 100644
--- a/mlir/test/Dialect/Linalg/standard.mlir
+++ b/mlir/test/Dialect/Linalg/standard.mlir
@@ -55,9 +55,9 @@ func @copy_transpose(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, %a
 // CHECK-LABEL: func @copy_transpose(
 //  CHECK-SAME: %[[arg0:[a-zA-z0-9]*]]: memref<?x?x?xf32, #[[$map1]]>,
 //  CHECK-SAME: %[[arg1:[a-zA-z0-9]*]]: memref<?x?x?xf32, #[[$map1]]>) {
-//       CHECK:   %[[t0:.*]] = linalg.transpose %[[arg0]]
+//       CHECK:   %[[t0:.*]] = transpose %[[arg0]]
 //  CHECK-SAME:     (d0, d1, d2) -> (d0, d2, d1) : memref<?x?x?xf32, #[[$map1]]>
-//       CHECK:   %[[t1:.*]] = linalg.transpose %[[arg1]]
+//       CHECK:   %[[t1:.*]] = transpose %[[arg1]]
 //  CHECK-SAME:     (d0, d1, d2) -> (d2, d1, d0) : memref<?x?x?xf32, #[[$map1]]>
 //       CHECK:   %[[o0:.*]] = memref_cast %[[t0]] :
 //  CHECK-SAME:     memref<?x?x?xf32, #[[$map2]]> to memref<?x?x?xf32, #[[$map8]]>

diff  --git a/mlir/test/Dialect/Standard/invalid.mlir b/mlir/test/Dialect/Standard/invalid.mlir
index 7f9c564e74f3..72fe5c227578 100644
--- a/mlir/test/Dialect/Standard/invalid.mlir
+++ b/mlir/test/Dialect/Standard/invalid.mlir
@@ -81,3 +81,24 @@ func @dynamic_tensor_from_elements(%m : index, %n : index)
   } : tensor<?x3x?xf32>
   return %tnsr : tensor<?x3x?xf32>
 }
+
+// -----
+
+func @transpose_not_permutation(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
+  // expected-error @+1 {{expected a permutation map}}
+  transpose %v (i, j) -> (i, i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
+}
+
+// -----
+
+func @transpose_bad_rank(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
+  // expected-error @+1 {{expected a permutation map of same rank as the input}}
+  transpose %v (i) -> (i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
+}
+
+// -----
+
+func @transpose_wrong_type(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
+  // expected-error @+1 {{output type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' does not match transposed input type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>'}}
+  transpose %v (i, j) -> (j, i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
+}


        


More information about the Mlir-commits mailing list