99use Doctrine \DBAL \Driver \PDO \SQLite \Driver as PdoSQLiteDriver ;
1010use Doctrine \DBAL \Driver \PgSQL \Driver as PgSQLDriver ;
1111use Doctrine \DBAL \Driver \SQLite3 \Driver as SQLite3Driver ;
12+ use Doctrine \DBAL \Types \Type as DbalType ;
1213use Doctrine \ORM \EntityManagerInterface ;
1314use Doctrine \ORM \Mapping \ClassMetadata ;
1415use Doctrine \ORM \Query ;
3839use PHPStan \Type \IntersectionType ;
3940use PHPStan \Type \MixedType ;
4041use PHPStan \Type \NeverType ;
41- use PHPStan \Type \NullType ;
4242use PHPStan \Type \ObjectType ;
4343use PHPStan \Type \StringType ;
4444use PHPStan \Type \Type ;
4545use PHPStan \Type \TypeCombinator ;
4646use PHPStan \Type \TypeTraverser ;
47- use PHPStan \Type \TypeUtils ;
4847use PHPStan \Type \UnionType ;
4948use function array_key_exists ;
5049use function array_map ;
50+ use function array_values ;
5151use function assert ;
5252use function class_exists ;
5353use function count ;
@@ -414,7 +414,7 @@ public function walkFunction($function): string
414414 return $ this ->marshalType ($ this ->inferSumFunction ($ function ));
415415
416416 case $ function instanceof AST \Functions \CountFunction:
417- return $ this ->marshalType (new IntegerType ( )); // TypedExpression condition will overwrite this anyway
417+ return $ this ->marshalType (IntegerRangeType:: fromInterval ( 0 , null ));
418418
419419 case $ function instanceof AST \Functions \AbsFunction:
420420 // mysql sqlite pdo_pgsql pgsql
@@ -431,10 +431,25 @@ public function walkFunction($function): string
431431
432432 $ exprType = $ this ->unmarshalType ($ this ->walkSimpleArithmeticExpression ($ function ->simpleArithmeticExpression ));
433433 $ exprType = $ this ->generalizeLiteralType ($ exprType , false );
434+ $ exprTypeNoNull = TypeCombinator::removeNull ($ exprType );
435+ $ nullable = TypeCombinator::containsNull ($ exprType );
436+
437+ if ($ exprTypeNoNull ->isInteger ()->yes ()) {
438+ $ positiveInt = TypeCombinator::containsNull ($ exprType )
439+ ? TypeCombinator::addNull (IntegerRangeType::fromInterval (0 , null ))
440+ : IntegerRangeType::fromInterval (0 , null );
441+ return $ this ->marshalType ($ positiveInt );
442+ }
434443
435- // TODO invalid usages
444+ if ($ exprTypeNoNull ->isFloat ()->yes () || $ exprTypeNoNull ->isNumericString ()->yes ()) {
445+ return $ this ->marshalType ($ exprType ); // retains underlying type
446+ }
436447
437- return $ this ->marshalType ($ exprType ); // retains underlying type
448+ if ($ exprTypeNoNull ->isString ()->yes ()) {
449+ return $ this ->marshalType ($ this ->createFloat ($ nullable ));
450+ }
451+
452+ return $ this ->marshalType (new MixedType ());
438453
439454 case $ function instanceof AST \Functions \BitAndFunction:
440455 case $ function instanceof AST \Functions \BitOrFunction:
@@ -549,6 +564,16 @@ public function walkFunction($function): string
549564 $ secondExprType = $ this ->unmarshalType ($ this ->walkSimpleArithmeticExpression ($ function ->secondSimpleArithmeticExpression ));
550565
551566 $ type = $ firstExprType ;
567+ $ typeNoNull = TypeCombinator::removeNull ($ type );
568+
569+ // TODO simplify?
570+
571+ if ($ typeNoNull ->isInteger ()->yes ()) {
572+ $ type = TypeCombinator::containsNull ($ type )
573+ ? TypeCombinator::addNull (IntegerRangeType::fromInterval (0 , null ))
574+ : IntegerRangeType::fromInterval (0 , null );
575+ }
576+
552577 if (TypeCombinator::containsNull ($ firstExprType ) || TypeCombinator::containsNull ($ secondExprType )) {
553578 $ type = TypeCombinator::addNull ($ type );
554579 }
@@ -738,7 +763,7 @@ private function inferAvgFunction(AST\Functions\AvgFunction $function): Type
738763 private function inferSumFunction (AST \Functions \SumFunction $ function ): Type
739764 {
740765 // mysql sqlite pdo_pgsql pgsql
741- // col_float => float float string float
766+ // col_float => float float string float
742767 // col_decimal => string float string string
743768 // col_int => int int int int
744769 // col_bigint => int int int int
@@ -802,6 +827,18 @@ private function createNumericString(bool $nullable): Type
802827 return $ nullable ? TypeCombinator::addNull ($ numericString ) : $ numericString ;
803828 }
804829
830+ /**
831+ * @param list<Type> $allowedTypes
832+ */
833+ private function containsOnlyTypes (
834+ Type $ checkedType ,
835+ array $ allowedTypes
836+ ): bool
837+ {
838+ $ allowedType = TypeCombinator::union (...$ allowedTypes );
839+ return $ allowedType ->isSuperTypeOf ($ checkedType )->yes ();
840+ }
841+
805842 /**
806843 * E.g. to ensure SUM(1) is inferred as int, not 1
807844 */
@@ -1074,7 +1111,10 @@ public function walkSelectExpression($selectExpression): string
10741111 $ type = $ this ->unmarshalType ($ expr ->dispatch ($ this ));
10751112
10761113 if ($ expr instanceof TypedExpression) {
1077- $ type = $ this ->resolveDoctrineType ($ expr ->getReturnType ()->getName (), null , TypeCombinator::containsNull ($ type )); // TODO test nullability
1114+ $ type = TypeCombinator::intersect ( // e.g. count is typed as int, but we infer int<0, max>
1115+ $ type ,
1116+ $ this ->resolveDoctrineType (DbalType::lookupName ($ expr ->getReturnType ()), null , TypeCombinator::containsNull ($ type ))
1117+ );
10781118 } else {
10791119 // Expressions default to Doctrine's StringType, whose
10801120 // convertToPHPValue() is a no-op. So the actual type depends on
@@ -1467,14 +1507,15 @@ public function walkSimpleArithmeticExpression($simpleArithmeticExpr): string
14671507 // Skip '+' or '-'
14681508 continue ;
14691509 }
1510+
14701511 $ type = $ this ->unmarshalType ($ this ->walkArithmeticPrimary ($ term ));
1471- $ types [] = TypeUtils::generalizeType ($ type , GeneralizePrecision::lessSpecific ());
1512+ if ($ term instanceof AST \Literal) {
1513+ $ type = $ type ->generalize (GeneralizePrecision::lessSpecific ()); // make '1' string, not numeric-string
1514+ }
1515+ $ types [] = $ type ;
14721516 }
14731517
1474- $ type = TypeCombinator::union (...$ types );
1475- $ type = $ this ->toNumericOrNull ($ type );
1476-
1477- return $ this ->marshalType ($ type );
1518+ return $ this ->marshalType ($ this ->inferPlusMinusTimesType ($ types ));
14781519 }
14791520
14801521 /**
@@ -1487,20 +1528,177 @@ public function walkArithmeticTerm($term): string
14871528 }
14881529
14891530 $ types = [];
1531+ $ operators = [];
14901532
14911533 foreach ($ term ->arithmeticFactors as $ factor ) {
14921534 if (!$ factor instanceof AST \Node) {
1493- // Skip '*' or '/'
1494- continue ;
1535+ assert (is_string ($ factor ));
1536+ $ operators [$ factor ] = $ factor ;
1537+ continue ; // Skip '*' or '/'
14951538 }
1496- $ type = $ this -> unmarshalType ( $ this -> walkArithmeticPrimary ( $ factor ));
1497- $ types [] = TypeUtils:: generalizeType ( $ type , GeneralizePrecision:: lessSpecific ( ));
1539+
1540+ $ types [] = $ this -> unmarshalType ( $ this -> walkArithmeticPrimary ( $ factor ));
14981541 }
14991542
1500- $ type = TypeCombinator::union (...$ types );
1501- $ type = $ this ->toNumericOrNull ($ type );
1543+ if (array_values ($ operators ) === ['* ' ]) {
1544+ return $ this ->marshalType ($ this ->inferPlusMinusTimesType ($ types ));
1545+ }
15021546
1503- return $ this ->marshalType ($ type );
1547+ return $ this ->marshalType ($ this ->inferDivisionType ($ types ));
1548+ }
1549+
1550+ /**
1551+ * @param list<Type> $termTypes
1552+ */
1553+ private function inferPlusMinusTimesType (array $ termTypes ): Type
1554+ {
1555+ // mysql sqlite pdo_pgsql pgsql
1556+ // col_float float float string float
1557+ // col_decimal string float string string
1558+ // col_int int int int int
1559+ // col_bigint int int int int
1560+ // col_bool int int bool bool
1561+ //
1562+ // col_int + col_int int int int int
1563+ // col_int + col_float float float string float
1564+ // col_float + col_float float float string float
1565+ // col_float + col_decimal float float string float
1566+ // col_int + col_decimal string float string string
1567+ // col_decimal + col_decimal string float string string
1568+ // col_string + col_string float int x x
1569+ // col_int + col_string float int x x
1570+ // col_bool + col_bool int int x x
1571+ // col_int + col_bool int int x x
1572+ // col_float + col_string float float x x
1573+ // col_decimal + col_string float float x x
1574+ // col_float + col_bool float float x x
1575+ // col_decimal + col_bool string float x x
1576+
1577+ $ driver = $ this ->em ->getConnection ()->getDriver ();
1578+ $ types = [];
1579+
1580+ foreach ($ termTypes as $ termType ) {
1581+ $ types [] = $ this ->generalizeLiteralType ($ termType , false );
1582+ }
1583+
1584+ $ union = TypeCombinator::union (...$ types );
1585+ $ nullable = TypeCombinator::containsNull ($ union );
1586+ $ unionWithoutNull = TypeCombinator::removeNull ($ union );
1587+
1588+ if ($ unionWithoutNull ->isInteger ()->yes ()) {
1589+ return $ this ->createInteger ($ nullable );
1590+ }
1591+
1592+ if ($ driver instanceof PdoPgSQLDriver) {
1593+ return $ this ->createNumericString ($ nullable );
1594+ }
1595+
1596+ if ($ driver instanceof SQLite3Driver || $ driver instanceof PdoSqliteDriver) {
1597+ if ($ this ->containsOnlyTypes ($ unionWithoutNull , [new IntegerType (), new FloatType ()])) {
1598+ return $ this ->createFloat ($ nullable );
1599+ }
1600+ if ($ this ->containsOnlyTypes ($ unionWithoutNull , [new IntegerType (), new StringType ()])) {
1601+ return $ this ->createInteger ($ nullable );
1602+ }
1603+ if ($ this ->containsOnlyTypes ($ unionWithoutNull , [new FloatType (), new StringType ()])) {
1604+ return $ this ->createFloat ($ nullable );
1605+ }
1606+ }
1607+
1608+ if ($ driver instanceof MysqliDriver || $ driver instanceof PdoMysqlDriver || $ driver instanceof PgSQLDriver) {
1609+ if ($ this ->containsOnlyTypes ($ unionWithoutNull , [new IntegerType (), new FloatType ()])) {
1610+ return $ this ->createFloat ($ nullable );
1611+ }
1612+
1613+ if ($ this ->containsOnlyTypes ($ unionWithoutNull , [new IntegerType (), $ this ->createNumericString (false )])) {
1614+ return $ this ->createNumericString ($ nullable );
1615+ }
1616+
1617+ if ($ this ->containsOnlyTypes ($ unionWithoutNull , [new IntegerType (), new StringType ()])) {
1618+ return $ this ->createFloat ($ nullable );
1619+ }
1620+
1621+ if ($ this ->containsOnlyTypes ($ unionWithoutNull , [new FloatType (), new StringType ()])) {
1622+ return $ this ->createFloat ($ nullable );
1623+ }
1624+ }
1625+
1626+ // TODO all 3?
1627+ // TODO string
1628+ // TODO string with number in it?
1629+
1630+ // postgre fails and other drivers are unknown
1631+ return new MixedType ();
1632+ }
1633+
1634+ /**
1635+ * @param list<Type> $termTypes
1636+ */
1637+ private function inferDivisionType (array $ termTypes ): Type
1638+ {
1639+ // mysql sqlite pdo_pgsql pgsql
1640+ // col_float => float float string float
1641+ // col_decimal => string float string string
1642+ // col_int => int int int int
1643+ // col_bigint => int int int int
1644+ //
1645+ // col_int / col_int string int int int
1646+ // col_int / col_float float float string float
1647+ // col_float / col_float float float string float
1648+ // col_float / col_decimal float float string float
1649+ // col_int / col_decimal string float string string
1650+ // col_decimal / col_decimal string float string string
1651+ // col_string / col_string null null x x
1652+ // col_int / col_string null null x x
1653+ // col_bool / col_bool string int x x
1654+ // col_int / col_bool string int x x
1655+ // col_float / col_string null null x x
1656+ // col_decimal / col_string null null x x
1657+ // col_float / col_bool float float x x
1658+ // col_decimal / col_bool string float x x
1659+
1660+ $ driver = $ this ->em ->getConnection ()->getDriver ();
1661+ $ types = [];
1662+
1663+ foreach ($ termTypes as $ termType ) {
1664+ $ types [] = $ this ->generalizeLiteralType ($ termType , false );
1665+ }
1666+
1667+ $ union = TypeCombinator::union (...$ types );
1668+ $ nullable = TypeCombinator::containsNull ($ union );
1669+ $ unionWithoutNull = TypeCombinator::removeNull ($ union );
1670+
1671+ if ($ unionWithoutNull ->isInteger ()->yes ()) {
1672+ if ($ driver instanceof MysqliDriver || $ driver instanceof PdoMysqlDriver) {
1673+ return $ this ->createNumericString ($ nullable );
1674+ } elseif ($ driver instanceof PdoPgSQLDriver || $ driver instanceof PgSQLDriver || $ driver instanceof SQLite3Driver || $ driver instanceof PdoSqliteDriver) {
1675+ return $ this ->createInteger ($ nullable );
1676+ }
1677+
1678+ return new MixedType ();
1679+ }
1680+
1681+ if ($ this ->containsOnlyTypes ($ unionWithoutNull , [new IntegerType (), new FloatType (), $ this ->createNumericString (false )])) {
1682+ if ($ driver instanceof PdoPgSQLDriver) {
1683+ return $ this ->createNumericString ($ nullable );
1684+ }
1685+ if ($ driver instanceof SQLite3Driver || $ driver instanceof PdoSqliteDriver) {
1686+ return $ this ->createFloat ($ nullable );
1687+ }
1688+ if ($ driver instanceof MysqliDriver || $ driver instanceof PdoMysqlDriver || $ driver instanceof PgSQLDriver) {
1689+ return TypeCombinator::union ( // float vs decimal
1690+ $ this ->createNumericString ($ nullable ),
1691+ $ this ->createFloat ($ nullable )
1692+ );
1693+ }
1694+ }
1695+
1696+ // incompatible types, not trying to be precise here, very chaotic behaviour + postgre fails
1697+ return TypeCombinator::union (
1698+ $ this ->createNumericString (true ),
1699+ $ this ->createFloat (true ),
1700+ $ this ->createInteger (true )
1701+ );
15041702 }
15051703
15061704 /**
@@ -1659,25 +1857,6 @@ private function resolveDatabaseInternalType(string $typeName, ?string $enumType
16591857 return $ type ;
16601858 }
16611859
1662- private function toNumericOrNull (Type $ type ): Type
1663- {
1664- return TypeTraverser::map ($ type , static function (Type $ type , callable $ traverse ): Type {
1665- if ($ type instanceof UnionType || $ type instanceof IntersectionType) {
1666- return $ traverse ($ type );
1667- }
1668- if ($ type instanceof NullType || $ type instanceof IntegerType) {
1669- return $ type ;
1670- }
1671- if ($ type instanceof BooleanType) {
1672- return $ type ->toInteger ();
1673- }
1674- return TypeCombinator::union (
1675- $ type ->toFloat (),
1676- $ type ->toInteger ()
1677- );
1678- });
1679- }
1680-
16811860 /**
16821861 * Returns whether the query has aggregate function and no group by clause
16831862 *
0 commit comments