Skip to content

Commit 84bf3e8

Browse files
committed
Add stricter checks for Java platform SAM compatibility
1 parent ff7e01c commit 84bf3e8

File tree

5 files changed

+325
-38
lines changed

5 files changed

+325
-38
lines changed

compiler/src/dotty/tools/dotc/config/JavaPlatform.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,10 @@ class JavaPlatform extends Platform {
5050
cls.superClass == defn.ObjectClass &&
5151
cls.directlyInheritedTraits.forall(_.is(NoInits)) &&
5252
!ExplicitOuter.needsOuterIfReferenced(cls) &&
53-
cls.typeRef.fields.isEmpty // Superaccessors already show up as abstract methods here, so no test necessary
53+
// Superaccessors already show up as abstract methods here, so no test necessary
54+
cls.typeRef.fields.isEmpty &&
55+
// Check if the SAM can be implemented via LambdaMetaFactory
56+
TypeErasure.samNotNeededExpansion(cls)
5457

5558
/** We could get away with excluding BoxedBooleanClass for the
5659
* purpose of equality testing since it need not compare equal

compiler/src/dotty/tools/dotc/core/TypeErasure.scala

Lines changed: 97 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ end SourceLanguage
7474
* only for isInstanceOf, asInstanceOf: PolyType, TypeParamRef, TypeBounds
7575
*
7676
*/
77-
object TypeErasure {
77+
object TypeErasure:
7878

7979
private def erasureDependsOnArgs(sym: Symbol)(using Context) =
8080
sym == defn.ArrayClass || sym == defn.PairClass || sym.isDerivedValueClass
@@ -586,7 +586,102 @@ object TypeErasure {
586586
defn.FunctionType(n = info.nonErasedParamCount)
587587
}
588588
erasure(functionType(applyInfo))
589-
}
589+
590+
/** Check if LambdaMetaFactory can handle signature adaptation between two method types.
591+
*
592+
* LMF has limitations on what type adaptations it can perform automatically.
593+
* This method checks whether manual bridging is needed for params and/or result.
594+
*
595+
* The adaptation rules are:
596+
* - For parameters: primitives and value classes cannot be auto-adapted by LMF
597+
* because the Scala spec requires null to be "unboxed" to the default value,
598+
* but LMF throws `NullPointerException` instead.
599+
* - For results: value classes and Unit cannot be auto-adapted by LMF.
600+
* Non-Unit primitives can be auto-adapted since LMF only needs to box (not unbox).
601+
* - LMF cannot auto-adapt between Object and Array types.
602+
*
603+
* @param implParamTypes Parameter types of the implementation method
604+
* @param implResultType Result type of the implementation method
605+
* @param samParamTypes Parameter types of the SAM method
606+
* @param samResultType Result type of the SAM method
607+
*
608+
* @return (paramNeeded, resultNeeded) indicating what needs bridging
609+
*/
610+
def additionalAdaptationNeeded(
611+
implParamTypes: List[Type],
612+
implResultType: Type,
613+
samParamTypes: List[Type],
614+
samResultType: Type
615+
)(using Context): (paramNeeded: Boolean, resultNeeded: Boolean) =
616+
def sameClass(tp1: Type, tp2: Type) = tp1.classSymbol == tp2.classSymbol
617+
618+
/** Can the implementation parameter type `tp` be auto-adapted to a different
619+
* parameter type in the SAM?
620+
*
621+
* For derived value classes, we always need to do the bridging manually.
622+
* For primitives, we cannot rely on auto-adaptation on the JVM because
623+
* the Scala spec requires null to be "unboxed" to the default value of
624+
* the value class, but the adaptation performed by LambdaMetaFactory
625+
* will throw a `NullPointerException` instead.
626+
*/
627+
def autoAdaptedParam(tp: Type) = !tp.isErasedValueType && !tp.isPrimitiveValueType
628+
629+
/** Can the implementation result type be auto-adapted to a different result
630+
* type in the SAM?
631+
*
632+
* For derived value classes, it's the same story as for parameters.
633+
* For non-Unit primitives, we can actually rely on the `LambdaMetaFactory`
634+
* adaptation, because it only needs to box, not unbox, so no special
635+
* handling of null is required.
636+
*/
637+
def autoAdaptedResult(tp: Type) =
638+
!tp.isErasedValueType && !(tp.classSymbol eq defn.UnitClass)
639+
640+
val paramAdaptationNeeded =
641+
implParamTypes.lazyZip(samParamTypes).exists((implType, samType) =>
642+
!sameClass(implType, samType) && (!autoAdaptedParam(implType)
643+
// LambdaMetaFactory cannot auto-adapt between Object and Array types
644+
|| samType.isInstanceOf[JavaArrayType]))
645+
646+
val resultAdaptationNeeded =
647+
!sameClass(implResultType, samResultType) && !autoAdaptedResult(implResultType)
648+
649+
(paramAdaptationNeeded, resultAdaptationNeeded)
650+
end additionalAdaptationNeeded
651+
652+
/** Check if LambdaMetaFactory can handle the SAM method's required signature adaptation.
653+
*
654+
* When a SAM method overrides other methods, the erased signatures must be compatible
655+
* to be qualifies as a valid functional interface on JVM.
656+
* This method returns true if all overridden methods have compatible erased signatures
657+
* that LMF can auto-adapt (or don't need adaptation).
658+
*
659+
* When this returns true, the SAM class does not need to be expanded.
660+
*
661+
* @param cls The SAM class to check
662+
* @return true if LMF can handle the required adaptation
663+
*/
664+
def samNotNeededExpansion(cls: ClassSymbol)(using Context): Boolean = cls.typeRef.possibleSamMethods match
665+
case Seq(samMeth) =>
666+
val samMethSym = samMeth.symbol
667+
val erasedSamInfo = transformInfo(samMethSym, samMeth.info)
668+
669+
val (erasedSamParamTypes, erasedSamResultType) = erasedSamInfo match
670+
case mt: MethodType => (mt.paramInfos, mt.resultType)
671+
case _ => return false
672+
673+
samMethSym.allOverriddenSymbols.forall { overridden =>
674+
val erasedOverriddenInfo = transformInfo(overridden, overridden.info)
675+
erasedOverriddenInfo match
676+
case mt: MethodType =>
677+
val (paramNeeded, resultNeeded) =
678+
additionalAdaptationNeeded(erasedSamParamTypes, erasedSamResultType, mt.paramInfos, mt.resultType)
679+
!(paramNeeded || resultNeeded)
680+
case _ => true
681+
}
682+
case _ => false
683+
end samNotNeededExpansion
684+
end TypeErasure
590685

591686
import TypeErasure.*
592687

compiler/src/dotty/tools/dotc/transform/Erasure.scala

Lines changed: 3 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -453,41 +453,9 @@ object Erasure {
453453
val samParamTypes = sam.paramInfos
454454
val samResultType = sam.resultType
455455

456-
/** Can the implementation parameter type `tp` be auto-adapted to a different
457-
* parameter type in the SAM?
458-
*
459-
* For derived value classes, we always need to do the bridging manually.
460-
* For primitives, we cannot rely on auto-adaptation on the JVM because
461-
* the Scala spec requires null to be "unboxed" to the default value of
462-
* the value class, but the adaptation performed by LambdaMetaFactory
463-
* will throw a `NullPointerException` instead. See `lambda-null.scala`
464-
* for test cases.
465-
*
466-
* @see [LambdaMetaFactory](https://docs.oracle.com/en/java/javase/17/docs/api/java.base/java/lang/invoke/LambdaMetafactory.html)
467-
*/
468-
def autoAdaptedParam(tp: Type) =
469-
!tp.isErasedValueType && !tp.isPrimitiveValueType
470-
471-
/** Can the implementation result type be auto-adapted to a different result
472-
* type in the SAM?
473-
*
474-
* For derived value classes, it's the same story as for parameters.
475-
* For non-Unit primitives, we can actually rely on the `LambdaMetaFactory`
476-
* adaptation, because it only needs to box, not unbox, so no special
477-
* handling of null is required.
478-
*/
479-
def autoAdaptedResult =
480-
!implResultType.isErasedValueType && !implReturnsUnit
481-
482-
def sameClass(tp1: Type, tp2: Type) = tp1.classSymbol == tp2.classSymbol
483-
484-
val paramAdaptationNeeded =
485-
implParamTypes.lazyZip(samParamTypes).exists((implType, samType) =>
486-
!sameClass(implType, samType) && (!autoAdaptedParam(implType)
487-
// LambdaMetaFactory cannot auto-adapt between Object and Array types
488-
|| samType.isInstanceOf[JavaArrayType]))
489-
val resultAdaptationNeeded =
490-
!sameClass(implResultType, samResultType) && !autoAdaptedResult
456+
// Check if bridging is needed using the common function from TypeErasure
457+
val (paramAdaptationNeeded, resultAdaptationNeeded) =
458+
additionalAdaptationNeeded(implParamTypes, implResultType, samParamTypes, samResultType)
491459

492460
if paramAdaptationNeeded || resultAdaptationNeeded then
493461
// Instead of instantiating `scala.FunctionN`, see if we can instantiate

tests/run/i24573.check

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
1
2+
2
3+
3
4+
11
5+
12
6+
13
7+
14
8+
15
9+
16
10+
17
11+
18
12+
19
13+
20
14+
21
15+
22
16+
23
17+
24
18+
31
19+
32
20+
33
21+
34
22+
41
23+
42
24+
43
25+
44
26+
45
27+
46
28+
51
29+
52
30+
53
31+
55
32+
56
33+
57
34+
61
35+
62
36+
63
37+
64
38+
71
39+
72
40+
75
41+
76
42+
81
43+
82

tests/run/i24573.scala

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
trait ConTU[-T] extends (T => Unit):
2+
def apply(t: T): Unit
3+
4+
trait ConTI[-T] extends (T => Int):
5+
def apply(t: T): Int
6+
7+
trait ConTS[-T] extends (T => String):
8+
def apply(t: T): String
9+
10+
trait ConIR[+R] extends (Int => R):
11+
def apply(t: Int): R
12+
13+
trait ConSR[+R] extends (String => R):
14+
def apply(t: String): R
15+
16+
trait ConUR[+R] extends (() => R):
17+
def apply(): R
18+
19+
trait ConII extends (Int => Int):
20+
def apply(t: Int): Int
21+
22+
trait ConSI extends (String => Int):
23+
def apply(t: String): Int
24+
25+
trait ConIS extends (Int => String):
26+
def apply(t: Int): String
27+
28+
trait ConUU extends (() => Unit):
29+
def apply(): Unit
30+
31+
trait F1[-T, +R]:
32+
def apply(t: T): R
33+
34+
trait SFTU[-T] extends F1[T, Unit]:
35+
def apply(t: T): Unit
36+
37+
trait SFTI[-T] extends F1[T, Int]:
38+
def apply(t: T): Int
39+
40+
trait SFTS[-T] extends F1[T, String]:
41+
def apply(t: T): String
42+
43+
trait SFIR [+R] extends F1[Int, R]:
44+
def apply(t: Int): R
45+
46+
trait SFSR [+R] extends F1[String, R]:
47+
def apply(t: String): R
48+
49+
trait SFII extends F1[Int, Int]:
50+
def apply(t: Int): Int
51+
52+
trait SFSI extends F1[String, Int]:
53+
def apply(t: String): Int
54+
55+
trait SFIS extends F1[Int, String]:
56+
def apply(t: Int): String
57+
58+
trait SFIU extends F1[Int, Unit]:
59+
def apply(t: Int): Unit
60+
61+
trait F1U[-T]:
62+
def apply(t: T): Unit
63+
64+
trait SF2T[-T] extends F1U[T]:
65+
def apply(t: T): Unit
66+
67+
trait SF2I extends F1U[Int]:
68+
def apply(t: Int): Unit
69+
70+
trait SF2S extends F1U[String]:
71+
def apply(t: String): Unit
72+
73+
object Test:
74+
def main(args: Array[String]): Unit =
75+
val fIU: (Int => Unit) = (x: Int) => println(x) // closure by JFunction1
76+
fIU(1)
77+
78+
val fIS: (Int => String) = (x: Int) => x.toString // closure
79+
println(fIS(2))
80+
81+
val fUI: (() => Int) = () => 3 // closure
82+
println(fUI())
83+
84+
val conITU: ConTU[Int] = (x: Int) => println(x) // expanded
85+
conITU(11)
86+
val conITI: ConTI[Int] = (x: Int) => x // closure
87+
println(conITI(12))
88+
val conITS: ConTS[Int] = (x: Int) => x.toString // closure
89+
println(conITS(13))
90+
val conSTS: ConTS[String] = (x: String) => x // closure
91+
println(conSTS("14"))
92+
93+
val conIRS: ConIR[String] = (x: Int) => x.toString // expanded
94+
println(conIRS(15))
95+
val conIRI: ConIR[Int] = (x: Int) => x // expanded
96+
println(conIRI(16))
97+
val conIRU: ConIR[Unit] = (x: Int) => println(x) // expanded
98+
conIRU(17)
99+
100+
val conSRI: ConSR[Int] = (x: String) => x.toInt // closure
101+
println(conSRI("18"))
102+
val conURI: ConUR[Int] = () => 19 // closure
103+
println(conURI())
104+
val conURU: ConUR[Unit] = () => println("20") // closure
105+
conURU()
106+
107+
val conII: ConII = (x: Int) => x // expanded
108+
println(conII(21))
109+
val conSI: ConSI = (x: String) => x.toInt // closure
110+
println(conSI("22"))
111+
val conIS: ConIS = (x: Int) => x.toString // expanded
112+
println(conIS(23))
113+
val conUU: ConUU = () => println("24") // expanded
114+
conUU()
115+
116+
val ffIU: F1[Int, Unit] = (x: Int) => println(x) // closure
117+
ffIU(31)
118+
val ffIS: F1[Int, String] = (x: Int) => x.toString // closure
119+
println(ffIS(32))
120+
val ffSU: F1[String, Unit] = (x: String) => println(x) // closure
121+
ffSU("33")
122+
val ffSI: F1[String, Int] = (x: String) => x.toInt // closure
123+
println(ffSI("34"))
124+
125+
val sfITU: SFTU[Int] = (x: Int) => println(x) // expanded
126+
sfITU(41)
127+
val sfSTU: SFTU[String] = (x: String) => println(x) // expanded
128+
sfSTU("42")
129+
130+
val sfITI: SFTI[Int] = (x: Int) => x // closure
131+
println(sfITI(43))
132+
val sfSTI: SFTI[String] = (x: String) => x.toInt // closure
133+
println(sfSTI("44"))
134+
135+
val sfITS: SFTS[Int] = (x: Int) => x.toString // closure
136+
println(sfITS(45))
137+
val sfSTS: SFTS[String] = (x: String) => x // closure
138+
println(sfSTS("46"))
139+
140+
val sfIRI: SFIR[Int] = (x: Int) => x // expanded
141+
println(sfIRI(51))
142+
val sfIRS: SFIR[String] = (x: Int) => x.toString // expanded
143+
println(sfIRS(52))
144+
val sfIRU: SFIR[Unit] = (x: Int) => println(x) // expanded
145+
sfIRU(53)
146+
147+
val sfSRI: SFSR[Int] = (x: String) => x.toInt // closure
148+
println(sfSRI("55"))
149+
val sfSRS: SFSR[String] = (x: String) => x // closure
150+
println(sfSRS("56"))
151+
val sfSRU: SFSR[Unit] = (x: String) => println(x) // closure
152+
sfSRU("57")
153+
154+
val sfII: SFII = (x: Int) => x // expanded
155+
println(sfII(61))
156+
val sfSI: SFSI = (x: String) => x.toInt // closure
157+
println(sfSI("62"))
158+
val sfIS: SFIS = (x: Int) => x.toString // expanded
159+
println(sfIS(63))
160+
val sfIU: SFIU = (x: Int) => println(x) // expanded
161+
sfIU(64)
162+
163+
val f2ITU: F1U[Int] = (x: Int) => println(x) // closure
164+
f2ITU(71)
165+
val f2STU: F1U[String] = (x: String) => println(x) // closure
166+
f2STU("72")
167+
168+
val sf2IT: SF2T[Int] = (x: Int) => println(x) // closure
169+
sf2IT(75)
170+
val sf2ST: SF2T[String] = (x: String) => println(x) // closure
171+
sf2ST("76")
172+
173+
val sf2I: SF2I = (x: Int) => println(x) // expanded
174+
sf2I(81)
175+
val sf2S: SF2S = (x: String) => println(x) // closure
176+
sf2S("82")
177+
178+
end Test

0 commit comments

Comments
 (0)