[Mlir-commits] [mlir] 0d65000 - [MLIR] Add llvm.mlir.cast op for semantic preserving cast between dialect types.

Tim Shen llvmlistbot at llvm.org
Fri Feb 28 12:21:15 PST 2020


Author: Tim Shen
Date: 2020-02-28T12:20:23-08:00
New Revision: 0d65000e11777b8d2d6aa9f135753209593f2f00

URL: https://github.com/llvm/llvm-project/commit/0d65000e11777b8d2d6aa9f135753209593f2f00
DIFF: https://github.com/llvm/llvm-project/commit/0d65000e11777b8d2d6aa9f135753209593f2f00.diff

LOG: [MLIR] Add llvm.mlir.cast op for semantic preserving cast between dialect types.

Summary: See discussion here: https://llvm.discourse.group/t/rfc-dialect-type-cast-op/538/11

Reviewers: ftynse

Subscribers: bixia, sanjoy.google, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, Joonsoo, llvm-commits

Differential Revision: https://reviews.llvm.org/D75141

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
    mlir/test/Conversion/StandardToLLVM/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index a4d74bf70cbe..601bfbf68926 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -686,6 +686,25 @@ def LLVM_ConstantOp
   let assemblyFormat = "`(` $value `)` attr-dict `:` type($res)";
 }
 
+def LLVM_DialectCastOp : LLVM_Op<"mlir.cast", [NoSideEffect]>,
+                         Results<(outs AnyType:$res)>,
+                         Arguments<(ins AnyType:$in)> {
+  let summary = "Type cast between LLVM dialect and Standard.";
+  let description = [{
+    llvm.mlir.cast op casts between Standard and LLVM dialects. It only changes
+    the dialect, but does not change compile-time or runtime semantics.
+
+    Notice that index type is not supported, as it's Standard-specific.
+
+    Example:
+      llvm.mlir.cast %v : f16 to !llvm.half
+      llvm.mlir.cast %v : !llvm.float to f32
+      llvm.mlir.cast %v : !llvm<"<2 x float>"> to vector<2xf32>
+  }];
+  let assemblyFormat = "$in attr-dict `:` type($in) `to` type($res)";
+  let verifier = "return ::verify(*this);";
+}
+
 // Operations that correspond to LLVM intrinsics. With MLIR operation set being
 // extendable, there is no reason to introduce a hard boundary between "core"
 // operations and intrinsics. However, we systematically prefix them with

diff  --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index 72985ffc639b..96d7e82d6a32 100644
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -1807,6 +1807,24 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
   }
 };
 
+struct DialectCastOpLowering
+    : public LLVMLegalizationPattern<LLVM::DialectCastOp> {
+  using LLVMLegalizationPattern<LLVM::DialectCastOp>::LLVMLegalizationPattern;
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto castOp = cast<LLVM::DialectCastOp>(op);
+    OperandAdaptor<LLVM::DialectCastOp> transformed(operands);
+    if (transformed.in().getType() !=
+        typeConverter.convertType(castOp.getType())) {
+      return matchFailure();
+    }
+    rewriter.replaceOp(op, transformed.in());
+    return matchSuccess();
+  }
+};
+
 // A `dim` is converted to a constant for static sizes and to an access to the
 // size stored in the memref descriptor for dynamic sizes.
 struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
@@ -2772,6 +2790,7 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
       CopySignOpLowering,
       CosOpLowering,
       ConstLLVMOpLowering,
+      DialectCastOpLowering,
       DivFOpLowering,
       ExpOpLowering,
       LogOpLowering,
@@ -2988,6 +3007,7 @@ struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
 mlir::LLVMConversionTarget::LLVMConversionTarget(MLIRContext &ctx)
     : ConversionTarget(ctx) {
   this->addLegalDialect<LLVM::LLVMDialect>();
+  this->addIllegalOp<LLVM::DialectCastOp>();
 }
 
 std::unique_ptr<OpPassBase<ModuleOp>>

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index c1773bbd8120..a8f1dd56e02e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -890,6 +890,45 @@ static void printGlobalOp(OpAsmPrinter &p, GlobalOp op) {
     p.printRegion(initializer, /*printEntryBlockArgs=*/false);
 }
 
+//===----------------------------------------------------------------------===//
+// Verifier for LLVM::DialectCastOp.
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(DialectCastOp op) {
+  auto verifyMLIRCastType = [&op](Type type) -> LogicalResult {
+    if (auto llvmType = type.dyn_cast<LLVM::LLVMType>()) {
+      if (llvmType.isVectorTy())
+        llvmType = llvmType.getVectorElementType();
+      if (llvmType.isIntegerTy() || llvmType.isHalfTy() ||
+          llvmType.isFloatTy() || llvmType.isDoubleTy()) {
+        return success();
+      }
+      return op.emitOpError("type must be non-index integer types, float "
+                            "types, or vector of mentioned types.");
+    }
+    if (auto vectorType = type.dyn_cast<VectorType>()) {
+      if (vectorType.getShape().size() > 1)
+        return op.emitOpError("only 1-d vector is allowed");
+      type = vectorType.getElementType();
+    }
+    if (type.isSignlessIntOrFloat())
+      return success();
+    // Note that memrefs are not supported. We currently don't have a use case
+    // for it, but even if we do, there are challenges:
+    // * if we allow memrefs to cast from/to memref descriptors, then the
+    // semantics of the cast op depends on the implementation detail of the
+    // descriptor.
+    // * if we allow memrefs to cast from/to bare pointers, some users might
+    // alternatively want metadata that is only present in the descriptor.
+    //
+    // TODO(timshen): re-evaluate the memref cast design when it's needed.
+    return op.emitOpError("type must be non-index integer types, float types, "
+                          "or vector of mentioned types.");
+  };
+  return failure(failed(verifyMLIRCastType(op.in().getType())) ||
+                 failed(verifyMLIRCastType(op.getType())));
+}
+
 // Parses one of the keywords provided in the list `keywords` and returns the
 // position of the parsed keyword in the list. If none of the keywords from the
 // list is parsed, returns -1.

diff  --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
index 27c249372b15..68aeef8a2e1f 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -910,3 +910,39 @@ func @assume_alignment(%0 : memref<4x4xf16>) {
   assume_alignment %0, 16 : memref<4x4xf16>
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @mlir_cast_to_llvm
+// CHECK-SAME: %[[ARG:.*]]:
+func @mlir_cast_to_llvm(%0 : vector<2xf16>) -> !llvm<"<2 x half>"> {
+  %1 = llvm.mlir.cast %0 : vector<2xf16> to !llvm<"<2 x half>">
+  // CHECK-NEXT: llvm.return %[[ARG]]
+  return %1 : !llvm<"<2 x half>">
+}
+
+// CHECK-LABEL: func @mlir_cast_from_llvm
+// CHECK-SAME: %[[ARG:.*]]:
+func @mlir_cast_from_llvm(%0 : !llvm<"<2 x half>">) -> vector<2xf16> {
+  %1 = llvm.mlir.cast %0 : !llvm<"<2 x half>"> to vector<2xf16>
+  // CHECK-NEXT: llvm.return %[[ARG]]
+  return %1 : vector<2xf16>
+}
+
+// -----
+
+// CHECK-LABEL: func @mlir_cast_to_llvm
+// CHECK-SAME: %[[ARG:.*]]:
+func @mlir_cast_to_llvm(%0 : f16) -> !llvm.half {
+  %1 = llvm.mlir.cast %0 : f16 to !llvm.half
+  // CHECK-NEXT: llvm.return %[[ARG]]
+  return %1 : !llvm.half
+}
+
+// CHECK-LABEL: func @mlir_cast_from_llvm
+// CHECK-SAME: %[[ARG:.*]]:
+func @mlir_cast_from_llvm(%0 : !llvm.half) -> f16 {
+  %1 = llvm.mlir.cast %0 : !llvm.half to f16
+  // CHECK-NEXT: llvm.return %[[ARG]]
+  return %1 : f16
+}

diff  --git a/mlir/test/Conversion/StandardToLLVM/invalid.mlir b/mlir/test/Conversion/StandardToLLVM/invalid.mlir
index e0b1889e95c8..bb9c2728dcb8 100644
--- a/mlir/test/Conversion/StandardToLLVM/invalid.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/invalid.mlir
@@ -1,13 +1,44 @@
-// RUN: mlir-opt %s -verify-diagnostics -split-input-file
+// RUN: mlir-opt %s -convert-std-to-llvm -verify-diagnostics -split-input-file
 
 #map1 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
 
 func @invalid_memref_cast(%arg0: memref<?x?xf64>) {
   %c1 = constant 1 : index
   %c0 = constant 0 : index
-  // expected-error@+1: 'std.memref_cast' op operand #0 must be unranked.memref of any type values or memref of any type values,
+  // expected-error@+1 {{'std.memref_cast' op operand #0 must be unranked.memref of any type values or memref of any type values, but got '!llvm<"{ double*, double*, i64, [2 x i64], [2 x i64] }">'}}
   %5 = memref_cast %arg0 : memref<?x?xf64> to memref<?x?xf64, #map1>
   %25 = std.subview %5[%c0, %c0][%c1, %c1][] : memref<?x?xf64, #map1> to memref<?x?xf64, #map1>
   return
 }
 
+// -----
+
+func @mlir_cast_to_llvm(%0 : index) -> !llvm.i64 {
+  // expected-error@+1 {{'llvm.mlir.cast' op type must be non-index integer types, float types, or vector of mentioned types}}
+  %1 = llvm.mlir.cast %0 : index to !llvm.i64
+  return %1 : !llvm.i64
+}
+
+// -----
+
+func @mlir_cast_from_llvm(%0 : !llvm.i64) -> index {
+  // expected-error@+1 {{'llvm.mlir.cast' op type must be non-index integer types, float types, or vector of mentioned types}}
+  %1 = llvm.mlir.cast %0 : !llvm.i64 to index
+  return %1 : index
+}
+
+// -----
+
+func @mlir_cast_to_llvm_int(%0 : i32) -> !llvm.i64 {
+  // expected-error@+1 {{failed to legalize operation 'llvm.mlir.cast' that was explicitly marked illegal}}
+  %1 = llvm.mlir.cast %0 : i32 to !llvm.i64
+  return %1 : !llvm.i64
+}
+
+// -----
+
+func @mlir_cast_to_llvm_vec(%0 : vector<1x1xf32>) -> !llvm<"<1 x float>"> {
+  // expected-error@+1 {{'llvm.mlir.cast' op only 1-d vector is allowed}}
+  %1 = llvm.mlir.cast %0 : vector<1x1xf32> to !llvm<"<1 x float>">
+  return %1 : !llvm<"<1 x float>">
+}


        


More information about the Mlir-commits mailing list