[Mlir-commits] [mlir] 5883c4b - [MLIR] Fix standard -> LLVM conversion to fail for unsupported memref element type.

Rahul Joshi llvmlistbot at llvm.org
Thu Nov 12 17:06:36 PST 2020


Author: Rahul Joshi
Date: 2020-11-12T17:06:05-08:00
New Revision: 5883c4b4705e7f93e71d58c893f4bcfa4b52e0ad

URL: https://github.com/llvm/llvm-project/commit/5883c4b4705e7f93e71d58c893f4bcfa4b52e0ad
DIFF: https://github.com/llvm/llvm-project/commit/5883c4b4705e7f93e71d58c893f4bcfa4b52e0ad.diff

LOG: [MLIR] Fix standard -> LLVM conversion to fail for unsupported memref element type.

- Move isSupportedMemRefType() into ConvertToLLVMPattern and additionally check
  there that the memref element type is supported.

Differential Revision: https://reviews.llvm.org/D91374

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index e7aa9d5ae516..04f884987fb7 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -522,6 +522,9 @@ class ConvertToLLVMPattern : public ConversionPattern {
                              ArrayRef<int64_t> strides, int64_t offset,
                              ConversionPatternRewriter &rewriter) const;
 
+  /// Returns true if the given memref type is supported by the lowering.
+  bool isSupportedMemRefType(MemRefType type) const;
+
   Value getDataPtr(Location loc, MemRefType type, Value memRefDesc,
                    ValueRange indices,
                    ConversionPatternRewriter &rewriter) const;

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 57c26c4e83c9..6807f8311e7c 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -1094,11 +1094,20 @@ Value ConvertToLLVMPattern::getDataPtr(
                               offset, rewriter);
 }
 
+// Check if the MemRefType `type` is supported by the lowering: its element type
+// must convert via `typeConverter`, and every affine map must be an identity.
+bool ConvertToLLVMPattern::isSupportedMemRefType(MemRefType type) const {
+  if (!typeConverter.convertType(type.getElementType()))
+    return false;
+  return type.getAffineMaps().empty() ||
+         llvm::all_of(type.getAffineMaps(),
+                      [](AffineMap map) { return map.isIdentity(); });
+}
+
 Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const {
   auto elementType = type.getElementType();
-  auto structElementType = typeConverter.convertType(elementType);
-  return structElementType.cast<LLVM::LLVMType>().getPointerTo(
-      type.getMemorySpace());
+  auto structElementType = unwrap(typeConverter.convertType(elementType));
+  return structElementType.getPointerTo(type.getMemorySpace());
 }
 
 void ConvertToLLVMPattern::getMemRefDescriptorSizes(
@@ -1912,14 +1921,6 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern<ConstantOp> {
   }
 };
 
-// Check if the MemRefType `type` is supported by the lowering. We currently
-// only support memrefs with identity maps.
-static bool isSupportedMemRefType(MemRefType type) {
-  return type.getAffineMaps().empty() ||
-         llvm::all_of(type.getAffineMaps(),
-                      [](AffineMap map) { return map.isIdentity(); });
-}
-
 /// Lowering for AllocOp and AllocaOp.
 struct AllocLikeOpLowering : public ConvertToLLVMPattern {
   using ConvertToLLVMPattern::createIndexConstant;
@@ -3070,6 +3071,7 @@ struct RankOpLowering : public ConvertOpToLLVMPattern<RankOp> {
 template <typename Derived>
 struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
   using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
+  using ConvertOpToLLVMPattern<Derived>::isSupportedMemRefType;
   using Base = LoadStoreOpLowering<Derived>;
 
   LogicalResult match(Operation *op) const override {


        


More information about the Mlir-commits mailing list