[Mlir-commits] [mlir] [mlir] Add sub-byte type emulation support for `memref.collapse_shape` (PR #89962)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Apr 24 10:54:02 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-memref
@llvm/pr-subscribers-mlir
Author: Diego Caballero (dcaballe)
<details>
<summary>Changes</summary>
This PR adds support for `memref.collapse_shape` to sub-byte type emulation. The `memref.collapse_shape` becomes a no-op because the emulation already flattens the memref (i.e., it collapses all the dimensions).
---
Full diff: https://github.com/llvm/llvm-project/pull/89962.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp (+29-3)
- (modified) mlir/test/Dialect/MemRef/emulate-narrow-type.mlir (+20)
``````````diff
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 4449733f0daf06..77c108aab48070 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -13,7 +13,6 @@
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -24,7 +23,6 @@
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"
-#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <type_traits>
@@ -430,6 +428,33 @@ struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
}
};
+//===----------------------------------------------------------------------===//
+// ConvertMemRefCollapseShape
+//===----------------------------------------------------------------------===//
+
+/// Emulates `memref.collapse_shape` as a no-op: narrow-type emulation already
+/// flattens memrefs to one dimension, so there is nothing left to collapse.
+/// The source is forwarded only if it converted to a rank-1 memref and the
+/// result type converts to that same type (rejecting e.g. a rank-0 result).
+struct ConvertMemRefCollapseShape final
+    : OpConversionPattern<memref::CollapseShapeOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::CollapseShapeOp collapseShapeOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value srcVal = adaptor.getSrc();
+    auto newTy = dyn_cast<MemRefType>(srcVal.getType());
+    if (!newTy || newTy.getRank() != 1)
+      return rewriter.notifyMatchFailure(collapseShapeOp, "source not 1-D");
+    // Forwarding srcVal is only sound if the result converts to its type.
+    if (getTypeConverter()->convertType(collapseShapeOp.getType()) != newTy)
+      return rewriter.notifyMatchFailure(collapseShapeOp, "type mismatch");
+    rewriter.replaceOp(collapseShapeOp, srcVal);
+    return success();
+  }
+};
+
} // end anonymous namespace
//===----------------------------------------------------------------------===//
@@ -442,7 +467,8 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
// Populate `memref.*` conversion patterns.
patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
- ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefLoad,
+ ConvertMemRefAllocation<memref::AllocaOp>,
+ ConvertMemRefCollapseShape, ConvertMemRefLoad,
ConvertMemrefStore, ConvertMemRefAssumeAlignment,
ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
typeConverter, patterns.getContext());
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index fd37b7ff0a2713..435dcc944778db 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -430,3 +430,23 @@ func.func @rank_zero_memref_store(%arg0: i4) -> () {
// CHECK32: %[[EXTUI:.+]] = arith.extui %[[ARG0]] : i4 to i32
// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw assign %[[EXTUI]], %[[ALLOC]][] : (i32, memref<i32>) -> i32
// CHECK32: return
+
+// -----
+
+func.func @memref_collapse_shape_i4(%idx0 : index, %idx1 : index) -> i4 {
+  %alloc = memref.alloc() : memref<32x8x128xi4>
+  %flat = memref.collapse_shape %alloc[[0, 1], [2]] : memref<32x8x128xi4> into memref<256x128xi4>
+  %elem = memref.load %flat[%idx0, %idx1] : memref<256x128xi4>
+  return %elem : i4
+}
+
+// CHECK-LABEL: func.func @memref_collapse_shape_i4(
+// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<16384xi8>
+// CHECK-NOT: memref.collapse_shape
+// CHECK: memref.load %[[ALLOC]][%{{.*}}] : memref<16384xi8>
+
+// CHECK32-LABEL: func.func @memref_collapse_shape_i4(
+// CHECK32: %[[ALLOC:.*]] = memref.alloc() : memref<4096xi32>
+// CHECK32-NOT: memref.collapse_shape
+// CHECK32: memref.load %[[ALLOC]][%{{.*}}] : memref<4096xi32>
+
``````````
</details>
https://github.com/llvm/llvm-project/pull/89962
More information about the Mlir-commits
mailing list