[Mlir-commits] [mlir] 486a2ca - Add tensor.bitcast op to Tensor dialect
Rob Suderman
llvmlistbot at llvm.org
Tue May 2 13:28:01 PDT 2023
Author: Trevor Morris
Date: 2023-05-02T13:20:13-07:00
New Revision: 486a2ca57a27451177c389b485ecb56f88100649
URL: https://github.com/llvm/llvm-project/commit/486a2ca57a27451177c389b485ecb56f88100649
DIFF: https://github.com/llvm/llvm-project/commit/486a2ca57a27451177c389b485ecb56f88100649.diff
LOG: Add tensor.bitcast op to Tensor dialect
Add a tensor.bitcast operator that bitcasts between two tensors of compatible
shape and the same element bit width. This can be used to reinterpret an
unsigned integer as a signed integer, or vice versa; a short usage sketch
follows below.
Reviewed By: rsuderman
Differential Revision: https://reviews.llvm.org/D149608
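As a usage sketch (not part of the commit; reconstructed from the op description and the example in TensorOps.td below), the new op reinterprets the element bits of a tensor without changing its shape:

```mlir
// Reinterpret a 4-element unsigned tensor as signless i32. Both element
// types are 32 bits wide, so the bitcast is bit-width compatible.
func.func @reinterpret_ui32(%arg0: tensor<4xui32>) -> tensor<4xi32> {
  %0 = tensor.bitcast %arg0 : tensor<4xui32> to tensor<4xi32>
  return %0 : tensor<4xi32>
}
```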
Added:
Modified:
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index db979c07722ae..e0e97e7efdcc5 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -53,6 +53,35 @@ class Tensor_OpWithOffsetSizesAndStrides<string mnemonic,
}];
}
+//===----------------------------------------------------------------------===//
+// BitcastOp
+//===----------------------------------------------------------------------===//
+
+def Tensor_BitcastOp : Tensor_Op<"bitcast", [
+    DeclareOpInterfaceMethods<CastOpInterface>,
+    Pure
+  ]> {
+  let summary = "tensor bitcast operation";
+  let description = [{
+    Bitcast a tensor from one type to another type of equivalent element width.
+    If both are ranked, then the rank should be the same and static dimensions
+    should match.
+
+    Example:
+
+    ```mlir
+    // Bitcast from unsigned to signed or signless integer.
+    %2 = tensor.bitcast %1 : tensor<4xui32> to tensor<4xi32>
+    ```
+  }];
+
+  let arguments = (ins AnyTensor:$source);
+  let results = (outs AnyTensor:$dest);
+  let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
+
+  let hasCanonicalizer = 1;
+}
+
//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//
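The declaration above only requires matching element bit widths and compatible shapes, so ranked tensors with dynamic dimensions can also be bitcast as long as their static dimensions agree. A sketch of that case (illustrative, not from the commit; it assumes the documented rule that only static dimensions need to match):

```mlir
// Hypothetical example: the leading dimension is dynamic, the static
// dimension (8) matches, and both element types are 32 bits wide.
func.func @bitcast_dynamic(%arg0: tensor<?x8xui32>) -> tensor<?x8xi32> {
  %0 = tensor.bitcast %arg0 : tensor<?x8xui32> to tensor<?x8xi32>
  return %0 : tensor<?x8xi32>
}
```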
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 8e2461f499b2a..74a21771d034e 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -162,6 +162,53 @@ static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
return droppedDims;
}
+//===----------------------------------------------------------------------===//
+// BitcastOp
+//===----------------------------------------------------------------------===//
+
+bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
+  if (inputs.size() != 1 || outputs.size() != 1)
+    return false;
+  Type a = inputs.front(), b = outputs.front();
+  auto aT = dyn_cast<TensorType>(a);
+  auto bT = dyn_cast<TensorType>(b);
+  if (!aT || !bT)
+    return false;
+
+  if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
+    return false;
+
+  return succeeded(verifyCompatibleShape(aT, bT));
+}
+
+namespace {
+
+/// Replaces chains of two tensor.bitcast operations by a single tensor.bitcast
+/// operation.
+struct ChainedTensorBitcast : public OpRewritePattern<BitcastOp> {
+  using OpRewritePattern<BitcastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
+                                PatternRewriter &rewriter) const final {
+    auto tensorBitcastOperand =
+        tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
+    if (!tensorBitcastOperand)
+      return failure();
+
+    auto resultType = cast<TensorType>(tensorBitcast.getType());
+    rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
+                                           tensorBitcastOperand.getOperand());
+    return success();
+  }
+};
+
+} // namespace
+
+void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                            MLIRContext *context) {
+  results.add<ChainedTensorBitcast>(context);
+}
+
//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//
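areCastCompatible above accepts exactly the pairs whose element bit widths match and whose shapes are compatible. An illustrative sketch of accepted and rejected casts (not from the commit; the rejected forms are left as comments since they should fail verification, assuming the CastOpInterface verifier calls this hook):

```mlir
func.func @compatibility(%arg0: tensor<4xi32>) -> tensor<4xf32> {
  // Accepted: same static shape, both element types are 32 bits wide.
  %ok = tensor.bitcast %arg0 : tensor<4xi32> to tensor<4xf32>
  // Rejected, element bit widths differ (32 vs. 16):
  //   %bad0 = tensor.bitcast %arg0 : tensor<4xi32> to tensor<4xi16>
  // Rejected, static dimensions differ (4 vs. 8):
  //   %bad1 = tensor.bitcast %arg0 : tensor<4xi32> to tensor<8xi32>
  return %ok : tensor<4xf32>
}
```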
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 0a42e2bb3a5c9..4a757500920d5 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1,5 +1,28 @@
// RUN: mlir-opt %s -split-input-file -canonicalize="test-convergence" | FileCheck %s
+// CHECK-LABEL: @tensor_bitcast_chain_ok
+// CHECK-SAME: %[[IN:.*]]: tensor<2xi32>
+func.func @tensor_bitcast_chain_ok(%input: tensor<2xi32>) -> tensor<2xf32> {
+  // CHECK-NEXT: %[[RES:.*]] = tensor.bitcast %[[IN]] : tensor<2xi32> to tensor<2xf32>
+  %0 = tensor.bitcast %input : tensor<2xi32> to tensor<2xui32>
+  %1 = tensor.bitcast %0 : tensor<2xui32> to tensor<2xf32>
+  // CHECK-NEXT: return %[[RES]]
+  return %1 : tensor<2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @tensor_bitcast_chain_nop
+// CHECK-SAME: %[[IN:.*]]: tensor<4xi32>
+func.func @tensor_bitcast_chain_nop(%input: tensor<4xi32>) -> tensor<4xi32> {
+  %0 = tensor.bitcast %input : tensor<4xi32> to tensor<4xui32>
+  %1 = tensor.bitcast %0 : tensor<4xui32> to tensor<4xi32>
+  // CHECK-NEXT: return %[[IN]]
+  return %1 : tensor<4xi32>
+}
+
+// -----
+
// Checks that NOP casts are removed.
// CHECK-LABEL: cast_values
func.func @cast_values(%arg0: tensor<*xi32>) -> tensor<2xi32> {
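For reference, the output that the first new test above expects from `-canonicalize`, reconstructed from its CHECK lines: ChainedTensorBitcast collapses the two casts into a single tensor.bitcast from the original operand.

```mlir
// Expected canonicalized form of @tensor_bitcast_chain_ok (reconstructed).
func.func @tensor_bitcast_chain_ok(%input: tensor<2xi32>) -> tensor<2xf32> {
  %0 = tensor.bitcast %input : tensor<2xi32> to tensor<2xf32>
  return %0 : tensor<2xf32>
}
```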