[Mlir-commits] [mlir] 864236d - [mlir][arith] Support wide integer constant emulation
Jakub Kuderski
llvmlistbot at llvm.org
Thu Sep 8 21:07:48 PDT 2022
Author: Jakub Kuderski
Date: 2022-09-09T00:04:06-04:00
New Revision: 864236d1c1fa61d1cfa95e0d59effc59a96cb06d
URL: https://github.com/llvm/llvm-project/commit/864236d1c1fa61d1cfa95e0d59effc59a96cb06d
DIFF: https://github.com/llvm/llvm-project/commit/864236d1c1fa61d1cfa95e0d59effc59a96cb06d.diff
LOG: [mlir][arith] Support wide integer constant emulation
Reviewed By: antiagainst, Mogball
Differential Revision: https://reviews.llvm.org/D133136
Added:
Modified:
mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
index 94e321ba1ad7..2d24513e2734 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
@@ -8,6 +8,7 @@
#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Transforms/WideIntEmulationConverter.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
@@ -23,7 +24,80 @@ namespace mlir::arith {
using namespace mlir;
+/// Splits `value` into two `newBitWidth`-bit halves.
+///
+/// With N = `newBitWidth`, `value` must be at least 2*N bits wide. The result
+/// is {low, high}: `first` holds bits [0, N) and `second` holds bits [N, 2N).
+static std::pair<APInt, APInt> getHalves(const APInt &value,
+                                         unsigned newBitWidth) {
+  return std::make_pair(
+      value.extractBits(newBitWidth, /*bitPosition=*/0),
+      value.extractBits(newBitWidth, /*bitPosition=*/newBitWidth));
+}
+
namespace {
+//===----------------------------------------------------------------------===//
+// ConvertConstant
+//===----------------------------------------------------------------------===//
+
+/// Lowers an `arith.constant` whose type the converter widens into a vector
+/// (e.g. i64 -> vector<2xi32>, vector<3xi64> -> vector<3x2xi32>). Each
+/// original integer element is split with `getHalves` and emitted as an
+/// adjacent {low, high} pair along the new innermost dimension.
+struct ConvertConstant final : OpConversionPattern<arith::ConstantOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::ConstantOp op, OpAdaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type oldType = op.getType();
+    // `convertType` yields a null type on failure and returns types that need
+    // no emulation unchanged. Use a checked cast so that such constants make
+    // the pattern fail to match instead of asserting, e.g. when these
+    // patterns are reused with a conversion target other than this pass's.
+    auto newType =
+        getTypeConverter()->convertType(oldType).dyn_cast_or_null<VectorType>();
+    if (!newType)
+      return rewriter.notifyMatchFailure(op.getLoc(),
+                                         "unsupported constant type");
+    unsigned newBitWidth = newType.getElementTypeBitWidth();
+    Attribute oldValue = op.getValueAttr();
+
+    // Scalar constant: emit a single {low, high} pair.
+    if (auto intAttr = oldValue.dyn_cast<IntegerAttr>()) {
+      auto [low, high] = getHalves(intAttr.getValue(), newBitWidth);
+      auto newAttr = DenseElementsAttr::get(newType, {low, high});
+      rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
+      return success();
+    }
+
+    // Vector constant: interleave the halves of every element in row-major
+    // order. Splat attributes take this path too; iterating them with
+    // `getValues` yields the splat value once per element.
+    if (auto elemsAttr = oldValue.dyn_cast<DenseElementsAttr>()) {
+      SmallVector<APInt> values;
+      values.reserve(elemsAttr.getNumElements() * 2);
+      for (const APInt &origVal : elemsAttr.getValues<APInt>()) {
+        auto [low, high] = getHalves(origVal, newBitWidth);
+        values.push_back(std::move(low));
+        values.push_back(std::move(high));
+      }
+
+      auto attr = DenseElementsAttr::get(newType, values);
+      rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, attr);
+      return success();
+    }
+
+    return rewriter.notifyMatchFailure(op.getLoc(),
+                                       "unhandled constant attribute");
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// Pass Definition
+//===----------------------------------------------------------------------===//
+
struct EmulateWideIntPass final
: arith::impl::ArithmeticEmulateWideIntBase<EmulateWideIntPass> {
using ArithmeticEmulateWideIntBase::ArithmeticEmulateWideIntBase;
@@ -42,7 +116,11 @@ struct EmulateWideIntPass final
target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](Operation *op) {
return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType());
});
- target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(
+ target.addDynamicallyLegalOp<
+ // `func.*` ops
+ func::CallOp, func::ReturnOp,
+ // `arith.*` ops
+ arith::ConstantOp>(
[&typeConverter](Operation *op) { return typeConverter.isLegal(op); });
RewritePatternSet patterns(ctx);
@@ -54,6 +132,10 @@ struct EmulateWideIntPass final
};
} // end anonymous namespace
+//===----------------------------------------------------------------------===//
+// Public Interface Definition
+//===----------------------------------------------------------------------===//
+
arith::WideIntEmulationConverter::WideIntEmulationConverter(
unsigned widestIntSupportedByTarget)
: maxIntWidth(widestIntSupportedByTarget) {
@@ -117,4 +199,7 @@ void arith::populateWideIntEmulationPatterns(
typeConverter);
populateCallOpTypeConversionPattern(patterns, typeConverter);
populateReturnOpTypeConversionPattern(patterns, typeConverter);
+
+ // Populate `arith.*` conversion patterns.
+ patterns.add<ConvertConstant>(typeConverter, patterns.getContext());
}
diff --git a/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir b/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
index aafeb5b3a1cd..3a1a6c7b2d1f 100644
--- a/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
+++ b/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
@@ -49,3 +49,29 @@ func.func @call(%a : vector<4xi64>) -> vector<4xi64> {
%res = func.call @identity_vector(%a) : (vector<4xi64>) -> vector<4xi64>
return %res : vector<4xi64>
}
+
+// Scalar i64 constants become vector<2xi32> constants holding [low, high]:
+// 0 -> [0, 0] (printed as a splat), 4294967296 = 2^32 -> [0, 1], and
+// -7 = 0xFFFFFFFFFFFFFFF9 -> [0xFFFFFFF9, 0xFFFFFFFF], printed as [-7, -1].
+// CHECK-LABEL: func @constant_scalar
+// CHECK-SAME: () -> vector<2xi32>
+// CHECK-NEXT: [[C0:%.+]] = arith.constant dense<0> : vector<2xi32>
+// CHECK-NEXT: [[C1:%.+]] = arith.constant dense<[0, 1]> : vector<2xi32>
+// CHECK-NEXT: [[C2:%.+]] = arith.constant dense<[-7, -1]> : vector<2xi32>
+// CHECK-NEXT: return [[C0]] : vector<2xi32>
+func.func @constant_scalar() -> i64 {
+ %c0 = arith.constant 0 : i64
+ %c1 = arith.constant 4294967296 : i64
+ %c2 = arith.constant -7 : i64
+ return %c0 : i64
+}
+
+// vector<3xi64> constants gain a trailing dimension of size 2 holding the
+// [low, high] halves of each element: the splat 2^32 expands to three [0, 1]
+// pairs, and [0, 1, -2] becomes [[0, 0], [1, 0], [-2, -1]].
+// The `{LITERAL}` modifier stops FileCheck from treating `[[...]]` in the
+// expected output as variable syntax.
+// CHECK-LABEL: func @constant_vector
+// CHECK-SAME: () -> vector<3x2xi32>
+// CHECK-NEXT: [[C0:%.+]] = arith.constant dense
+// CHECK-SAME{LITERAL}: <[[0, 1], [0, 1], [0, 1]]> : vector<3x2xi32>
+// CHECK-NEXT: [[C1:%.+]] = arith.constant dense
+// CHECK-SAME{LITERAL}: <[[0, 0], [1, 0], [-2, -1]]> : vector<3x2xi32>
+// CHECK-NEXT: return [[C0]] : vector<3x2xi32>
+func.func @constant_vector() -> vector<3xi64> {
+ %c0 = arith.constant dense<4294967296> : vector<3xi64>
+ %c1 = arith.constant dense<[0, 1, -2]> : vector<3xi64>
+ return %c0 : vector<3xi64>
+}
More information about the Mlir-commits
mailing list