[Mlir-commits] [mlir] [mlir] Add sub-byte type emulation support for `memref.collapse_shape` (PR #89962)
Diego Caballero
llvmlistbot at llvm.org
Wed Apr 24 10:53:33 PDT 2024
https://github.com/dcaballe created https://github.com/llvm/llvm-project/pull/89962
This PR adds support for `memref.collapse_shape` to sub-byte type emulation. The `memref.collapse_shape` becomes a no-op given that we are flattening the memref as part of the emulation (i.e., we are collapsing all the dimensions).
>From c15f9d4ad6cd7626e8cba071c6cad1c935c5c875 Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Wed, 24 Apr 2024 16:59:52 +0000
Subject: [PATCH] [mlir] Add sub-byte type emulation support for
`memref.collapse_shape`
This PR adds support for `memref.collapse_shape` to sub-byte type
emulation. The `memref.collapse_shape` becomes a no-op given that we
are flattening the memref as part of the emulation (i.e., we are
collapsing all the dimensions).
---
.../MemRef/Transforms/EmulateNarrowType.cpp | 32 +++++++++++++++++--
.../Dialect/MemRef/emulate-narrow-type.mlir | 20 ++++++++++++
2 files changed, 49 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 4449733f0daf06..77c108aab48070 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -13,7 +13,6 @@
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -24,7 +23,6 @@
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"
-#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <type_traits>
@@ -430,6 +428,33 @@ struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
}
};
+//===----------------------------------------------------------------------===//
+// ConvertMemRefCollapseShape
+//===----------------------------------------------------------------------===//
+
+/// Converts a `memref.collapse_shape` on an emulated (sub-byte) memref into a
+/// no-op.
+///
+/// Narrow-type emulation flattens converted memrefs to a single dimension, so
+/// once the source has been converted there is no dimension left to collapse
+/// and the op can simply be replaced by its converted source value.
+///
+/// The pattern does not match (leaving the op illegal) when:
+///   - the converted source is not a `MemRefType`;
+///   - the converted source is not rank 1 (i.e. it was not flattened);
+///   - the op's result does not convert to exactly the converted source type.
+///     This guards against e.g. collapsing into a rank-0 memref
+///     (`memref<1x1xi4>` into `memref<i4>`): rank-0 memrefs keep rank 0 after
+///     conversion while the source becomes rank 1, so forwarding the source
+///     would replace the result with a value of the wrong type.
+struct ConvertMemRefCollapseShape final
+    : OpConversionPattern<memref::CollapseShapeOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::CollapseShapeOp collapseShapeOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value srcVal = adaptor.getSrc();
+    auto newTy = dyn_cast<MemRefType>(srcVal.getType());
+    if (!newTy)
+      return rewriter.notifyMatchFailure(
+          collapseShapeOp, "expected the converted source to be a memref");
+
+    if (newTy.getRank() != 1)
+      return rewriter.notifyMatchFailure(
+          collapseShapeOp, "expected the converted source to be flattened");
+
+    // Replacing the op with `srcVal` is only sound if the collapsed result
+    // converts to the very same flattened type; otherwise users of the result
+    // would observe a type change.
+    Type newResultTy =
+        getTypeConverter()->convertType(collapseShapeOp.getType());
+    if (newResultTy != newTy)
+      return rewriter.notifyMatchFailure(
+          collapseShapeOp,
+          "converted result type does not match the converted source type");
+
+    rewriter.replaceOp(collapseShapeOp, srcVal);
+    return success();
+  }
+};
+
} // end anonymous namespace
//===----------------------------------------------------------------------===//
@@ -442,7 +467,8 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
// Populate `memref.*` conversion patterns.
patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
- ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefLoad,
+ ConvertMemRefAllocation<memref::AllocaOp>,
+ ConvertMemRefCollapseShape, ConvertMemRefLoad,
ConvertMemrefStore, ConvertMemRefAssumeAlignment,
ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
typeConverter, patterns.getContext());
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index fd37b7ff0a2713..435dcc944778db 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -430,3 +430,23 @@ func.func @rank_zero_memref_store(%arg0: i4) -> () {
// CHECK32: %[[EXTUI:.+]] = arith.extui %[[ARG0]] : i4 to i32
// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw assign %[[EXTUI]], %[[ALLOC]][] : (i32, memref<i32>) -> i32
// CHECK32: return
+
+// -----
+
+// Collapsing dimensions of a sub-byte memref must fold away under emulation:
+// no `memref.collapse_shape` should remain, and the `memref.load` is expected
+// to index the flattened (i8 / i32) allocation directly.
+func.func @memref_collapse_shape_i4(%idx0 : index, %idx1 : index) -> i4 {
+ %arr = memref.alloc() : memref<32x8x128xi4>
+ %collapse = memref.collapse_shape %arr[[0, 1], [2]] : memref<32x8x128xi4> into memref<256x128xi4>
+ %1 = memref.load %collapse[%idx0, %idx1] : memref<256x128xi4>
+ return %1 : i4
+}
+
+// CHECK-LABEL: func.func @memref_collapse_shape_i4(
+// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<16384xi8>
+// CHECK-NOT: memref.collapse_shape
+// CHECK: memref.load %[[ALLOC]][%{{.*}}] : memref<16384xi8>
+
+// CHECK32-LABEL: func.func @memref_collapse_shape_i4(
+// CHECK32: %[[ALLOC:.*]] = memref.alloc() : memref<4096xi32>
+// CHECK32-NOT: memref.collapse_shape
+// CHECK32: memref.load %[[ALLOC]][%{{.*}}] : memref<4096xi32>
+
More information about the Mlir-commits
mailing list