[Mlir-commits] [mlir] [mlir][memref] Fix type conversion in emulate-wide-int and emulate-narrow-type (PR #112214)
Longsheng Mou
llvmlistbot at llvm.org
Mon Oct 14 16:00:21 PDT 2024
https://github.com/CoTinker updated https://github.com/llvm/llvm-project/pull/112214
>From 7337741acf9d88fe3eec28069d9f51a8e95e6ca6 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Mon, 14 Oct 2024 22:13:52 +0800
Subject: [PATCH] [mlir][memref] Fix type conversion in emulate-wide-int and
emulate-narrow-type
This PR follows up on #112104, using `nullptr` to indicate that a type conversion
failed and that no fallback conversion should be attempted.
---
.../Arith/Transforms/EmulateNarrowType.cpp | 4 +--
.../MemRef/Transforms/EmulateNarrowType.cpp | 17 ++++++-----
.../MemRef/Transforms/EmulateWideInt.cpp | 2 +-
.../emulate-narrow-type-unsupported.mlir | 25 ++++++++++++++++
.../Dialect/MemRef/emulate-narrow-type.mlir | 29 ++-----------------
.../MemRef/emulate-wide-int-unsupported.mlir | 28 ++++++++++++++++++
6 files changed, 67 insertions(+), 38 deletions(-)
create mode 100644 mlir/test/Dialect/MemRef/emulate-narrow-type-unsupported.mlir
create mode 100644 mlir/test/Dialect/MemRef/emulate-wide-int-unsupported.mlir
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp
index 4be0e06fe2a5e5..fddd7c51bfbc87 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp
@@ -40,11 +40,11 @@ arith::NarrowTypeEmulationConverter::NarrowTypeEmulationConverter(
addConversion([this](FunctionType ty) -> std::optional<Type> {
SmallVector<Type> inputs;
if (failed(convertTypes(ty.getInputs(), inputs)))
- return std::nullopt;
+ return nullptr;
SmallVector<Type> results;
if (failed(convertTypes(ty.getResults(), results)))
- return std::nullopt;
+ return nullptr;
return FunctionType::get(ty.getContext(), inputs, results);
});
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 9efea066a03c85..28f9061d9873b7 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -169,8 +169,9 @@ struct ConvertMemRefAllocation final : OpConversionPattern<OpTy> {
std::is_same<OpTy, memref::AllocaOp>(),
"expected only memref::AllocOp or memref::AllocaOp");
auto currentType = cast<MemRefType>(op.getMemref().getType());
- auto newResultType = dyn_cast<MemRefType>(
- this->getTypeConverter()->convertType(op.getType()));
+ auto newResultType =
+ this->getTypeConverter()->template convertType<MemRefType>(
+ op.getType());
if (!newResultType) {
return rewriter.notifyMatchFailure(
op->getLoc(),
@@ -378,7 +379,7 @@ struct ConvertMemRefReinterpretCast final
matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MemRefType newTy =
- dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
+ getTypeConverter()->convertType<MemRefType>(op.getType());
if (!newTy) {
return rewriter.notifyMatchFailure(
op->getLoc(),
@@ -466,8 +467,8 @@ struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
LogicalResult
matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- MemRefType newTy = dyn_cast<MemRefType>(
- getTypeConverter()->convertType(subViewOp.getType()));
+ MemRefType newTy =
+ getTypeConverter()->convertType<MemRefType>(subViewOp.getType());
if (!newTy) {
return rewriter.notifyMatchFailure(
subViewOp->getLoc(),
@@ -632,14 +633,14 @@ void memref::populateMemRefNarrowTypeEmulationConversions(
SmallVector<int64_t> strides;
int64_t offset;
if (failed(getStridesAndOffset(ty, strides, offset)))
- return std::nullopt;
+ return nullptr;
if (!strides.empty() && strides.back() != 1)
- return std::nullopt;
+ return nullptr;
auto newElemTy = IntegerType::get(ty.getContext(), loadStoreWidth,
intTy.getSignedness());
if (!newElemTy)
- return std::nullopt;
+ return nullptr;
StridedLayoutAttr layoutAttr;
// If the offset is 0, we do not need a strided layout as the stride is
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
index bc4535f97acf04..49b71625291db9 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
@@ -159,7 +159,7 @@ void memref::populateMemRefWideIntEmulationConversions(
Type newElemTy = typeConverter.convertType(intTy);
if (!newElemTy)
- return std::nullopt;
+ return nullptr;
return ty.cloneWith(std::nullopt, newElemTy);
});
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type-unsupported.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type-unsupported.mlir
new file mode 100644
index 00000000000000..024144337a31fb
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type-unsupported.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=8" --verify-diagnostics --split-input-file %s
+
+func.func @negative_memref_subview_non_contiguous(%idx : index) -> i4 {
+ %c0 = arith.constant 0 : index
+ %arr = memref.alloc() : memref<40x40xi4>
+ // expected-error @+1 {{failed to legalize operation 'memref.subview' that was explicitly marked illegal}}
+ %subview = memref.subview %arr[%idx, 0] [4, 8] [1, 1] : memref<40x40xi4> to memref<4x8xi4, strided<[40, 1], offset:?>>
+ %ld = memref.load %subview[%c0, %c0] : memref<4x8xi4, strided<[40, 1], offset:?>>
+ return %ld : i4
+}
+
+// -----
+
+func.func @alloc_non_contiguous() {
+ // expected-error @+1 {{failed to legalize operation 'memref.alloc' that was explicitly marked illegal}}
+ %arr = memref.alloc() : memref<8x8xi4, strided<[1, 8]>>
+ return
+}
+
+// -----
+
+// expected-error @+1 {{failed to legalize operation 'func.func' that was explicitly marked illegal}}
+func.func @argument_non_contiguous(%arg0 : memref<8x8xi4, strided<[1, 8]>>) {
+ return
+}
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index 540da239fced08..498f5d768e7358 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=8" --cse --verify-diagnostics --split-input-file %s | FileCheck %s
-// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=32" --cse --verify-diagnostics --split-input-file %s | FileCheck %s --check-prefix=CHECK32
+// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=8" --cse --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=32" --cse --split-input-file %s | FileCheck %s --check-prefix=CHECK32
// Expect no conversions.
func.func @memref_i8() -> i8 {
@@ -203,18 +203,6 @@ func.func @memref_subview_dynamic_offset_i4(%idx : index) -> i4 {
// -----
-
-func.func @negative_memref_subview_non_contiguous(%idx : index) -> i4 {
- %c0 = arith.constant 0 : index
- %arr = memref.alloc() : memref<40x40xi4>
- // expected-error @+1 {{failed to legalize operation 'memref.subview' that was explicitly marked illegal}}
- %subview = memref.subview %arr[%idx, 0] [4, 8] [1, 1] : memref<40x40xi4> to memref<4x8xi4, strided<[40, 1], offset:?>>
- %ld = memref.load %subview[%c0, %c0] : memref<4x8xi4, strided<[40, 1], offset:?>>
- return %ld : i4
-}
-
-// -----
-
func.func @reinterpret_cast_memref_load_0D() -> i4 {
%0 = memref.alloc() : memref<5xi4>
%reinterpret_cast_0 = memref.reinterpret_cast %0 to offset: [0], sizes: [], strides: [] : memref<5xi4> to memref<i4>
@@ -540,16 +528,3 @@ func.func @memref_copy_i4(%arg0: memref<32x128xi4, 1>, %arg1: memref<32x128xi4>)
// CHECK32-SAME: %[[ARG0:.*]]: memref<512xi32, 1>, %[[ARG1:.*]]: memref<512xi32>
// CHECK32: memref.copy %[[ARG0]], %[[ARG1]]
// CHECK32: return
-
-// -----
-
-!colMajor = memref<8x8xi4, strided<[1, 8]>>
-func.func @copy_distinct_layouts(%idx : index) -> i4 {
- %c0 = arith.constant 0 : index
- %arr = memref.alloc() : memref<8x8xi4>
- %arr2 = memref.alloc() : !colMajor
- // expected-error @+1 {{failed to legalize operation 'memref.copy' that was explicitly marked illegal}}
- memref.copy %arr, %arr2 : memref<8x8xi4> to !colMajor
- %ld = memref.load %arr2[%c0, %c0] : !colMajor
- return %ld : i4
-}
diff --git a/mlir/test/Dialect/MemRef/emulate-wide-int-unsupported.mlir b/mlir/test/Dialect/MemRef/emulate-wide-int-unsupported.mlir
new file mode 100644
index 00000000000000..228e9a0bff7bcf
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/emulate-wide-int-unsupported.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt --memref-emulate-wide-int="widest-int-supported=32" \
+// RUN: --split-input-file --verify-diagnostics %s
+
+// Make sure we do not crash on unsupported types.
+
+func.func @alloc_i128() {
+ // expected-error@+1 {{failed to legalize operation 'memref.alloc' that was explicitly marked illegal}}
+ %m = memref.alloc() : memref<4xi128, 1>
+ return
+}
+
+// -----
+
+func.func @load_i128(%m: memref<4xi128, 1>) {
+ %c0 = arith.constant 0 : index
+ // expected-error@+1 {{failed to legalize operation 'memref.load' that was explicitly marked illegal}}
+ %v = memref.load %m[%c0] : memref<4xi128, 1>
+ return
+}
+
+// -----
+
+func.func @store_i128(%c1: i128, %m: memref<4xi128, 1>) {
+ %c0 = arith.constant 0 : index
+ // expected-error@+1 {{failed to legalize operation 'memref.store' that was explicitly marked illegal}}
+ memref.store %c1, %m[%c0] : memref<4xi128, 1>
+ return
+}
More information about the Mlir-commits
mailing list