diff --git a/lib/Conversion/ArithToNeura/ArithToNeuraPass.cpp b/lib/Conversion/ArithToNeura/ArithToNeuraPass.cpp index dc6f4532..cec3f673 100644 --- a/lib/Conversion/ArithToNeura/ArithToNeuraPass.cpp +++ b/lib/Conversion/ArithToNeura/ArithToNeuraPass.cpp @@ -57,6 +57,28 @@ struct ArithAddIToNeuraAdd : public OpRewritePattern { } }; +struct ArithCmpFToNeuraFCmp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::CmpFOp op, + PatternRewriter &rewriter) const override { + + Value lhs = op.getLhs(); + Value rhs = op.getRhs(); + + mlir::arith::CmpFPredicate predicate_enum = op.getPredicate(); + + StringRef predicate_str = arith::stringifyCmpFPredicate(predicate_enum); + + StringAttr predicate_attr = rewriter.getStringAttr(predicate_str); +// Converts arith CmpFOp to Neura FCmpOp. + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), lhs, rhs, predicate_attr); + + return success(); + } +}; + struct ArithFAddToNeuraFAdd : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -338,7 +360,7 @@ struct LowerArithToNeuraPass mlir::neura::arith2neura::populateWithGenerated(patterns); patterns.add< ArithFAddToNeuraFAdd, ArithConstantToNeuraConstant, - ArithAddIToNeuraAdd, ArithCmpiToNeuraICmp, ArithSelectToNeuraSel, + ArithAddIToNeuraAdd, ArithCmpiToNeuraICmp, ArithCmpFToNeuraFCmp, ArithSelectToNeuraSel, ArithExtUIToNeuraCast, ArithIndexCastToNeuraCast, ArithFDivToNeuraFDiv, ArithExtfToNeuraCast, ArithMulFToNeuraFMul, ArithSubIToNeuraSub, ArithSubFToNeuraFSub, ArithMulIToNeuraMul, diff --git a/test/arith2neura/cmpf.mlir b/test/arith2neura/cmpf.mlir new file mode 100644 index 00000000..27b25eca --- /dev/null +++ b/test/arith2neura/cmpf.mlir @@ -0,0 +1,12 @@ +module { + func.func @test_cmpf(%arg0: f32, %arg1: f32) -> i1 { + %0 = arith.cmpf ogt, %arg0, %arg1 : f32 + return %0 : i1 + } +} + +// RUN: mlir-neura-opt --assign-accelerator --lower-arith-to-neura %s | FileCheck %s --check-prefix=OPT + +// CHECK-LABEL: func.func @test_cmpf( +// OPT: %{{.*}} = "neura.fcmp" +// OPT: cmpType = "ogt"