|
1 | | -//===--- AutodiffClosureSpecialization.swift ---------------------------===// |
| 1 | +//===--- ClosureSpecialization.swift ---------------------------===// |
2 | 2 | // |
3 | 3 | // This source file is part of the Swift.org open source project |
4 | 4 | // |
|
10 | 10 | // |
11 | 11 | //===-----------------------------------------------------------------------===// |
12 | 12 |
|
| 13 | +/// This file contains the closure-specialization optimizations for general and differentiable Swift. |
| 14 | + |
| 15 | +/// General Closure Specialization |
| 16 | +/// ------------------------------------ |
| 17 | +/// TODO: Add description when the functionality is added. |
| 18 | + |
13 | 19 | /// AutoDiff Closure Specialization |
14 | | -/// ---------------------- |
| 20 | +/// ------------------------------- |
15 | 21 | /// This optimization performs closure specialization tailored for the patterns seen in Swift Autodiff. In principle, |
16 | 22 | /// the optimization does the same thing as the existing closure specialization pass. However, it is tailored to the |
17 | 23 | /// patterns of Swift Autodiff. |
@@ -100,12 +106,32 @@ private func log(_ message: @autoclosure () -> String) { |
100 | 106 | } |
101 | 107 |
|
102 | 108 | // =========== Entry point =========== // |
103 | | -let autodiffClosureSpecialization = FunctionPass(name: "autodiff-closure-specialize") { |
| 109 | +let generalClosureSpecialization = FunctionPass(name: "experimental-swift-based-closure-specialization") { |
| 110 | + (function: Function, context: FunctionPassContext) in |
| 111 | + // TODO: Implement general closure specialization optimization |
| 112 | + print("NOT IMPLEMENTED") |
| 113 | +} |
| 114 | + |
| 115 | +let autodiffClosureSpecialization = FunctionPass(name: "autodiff-closure-specialization") { |
104 | 116 | (function: Function, context: FunctionPassContext) in |
105 | 117 | // TODO: Pass is a WIP and current implementation is incomplete |
106 | 118 | if !function.isAutodiffVJP { |
107 | 119 | return |
108 | 120 | } |
| 121 | + |
| 122 | + print("Specializing closures in function: \(function.name)") |
| 123 | + print("===============================================") |
| 124 | + var callSites = gatherCallSites(in: function, context) |
| 125 | + |
| 126 | + callSites.forEach { callSite in |
| 127 | + print("PartialApply call site: \(callSite.applySite)") |
| 128 | + print("Passed in closures: ") |
| 129 | + for index in callSite.closureArgDescriptors.indices { |
| 130 | + var closureArgDescriptor = callSite.closureArgDescriptors[index] |
| 131 | + print("\(index+1). \(closureArgDescriptor.closureInfo.closure)") |
| 132 | + } |
| 133 | + } |
| 134 | + print("\n") |
109 | 135 | } |
110 | 136 |
|
111 | 137 | // =========== Top-level functions ========== // |
@@ -276,8 +302,8 @@ private func handleNonApplies(for rootClosure: SingleValueInstruction, |
276 | 302 | if !pai.isPullbackInResultOfAutodiffVJP, |
277 | 303 | pai.isPartialApplyOfReabstractionThunk, |
278 | 304 | pai.isSupportedClosure, |
279 | | - pai.callee.type.isNoEscapeFunction, |
280 | | - pai.callee.type.isThickFunction |
| 305 | + pai.arguments[0].type.isNoEscapeFunction, |
| 306 | + pai.arguments[0].type.isThickFunction |
281 | 307 | { |
282 | 308 | rootClosureConversionsAndReabstractions.pushIfNotVisited(contentsOf: pai.uses) |
283 | 309 | possibleMarkDependenceBases.insert(pai) |
@@ -338,6 +364,7 @@ private func handleApplies(for rootClosure: SingleValueInstruction, callSiteMap: |
338 | 364 | continue |
339 | 365 | } |
340 | 366 |
|
| 367 | + // TODO: Handling generic closures may be possible but is not yet implemented |
341 | 368 | if pai.hasSubstitutions || !pai.calleeIsDynamicFunctionRef || !pai.isPullbackInResultOfAutodiffVJP { |
342 | 369 | continue |
343 | 370 | } |
@@ -507,7 +534,7 @@ private func markConvertedAndReabstractedClosuresAsUsed(rootClosure: Value, conv |
507 | 534 | convertedAndReabstractedClosures.insert(pai) |
508 | 535 | return |
509 | 536 | markConvertedAndReabstractedClosuresAsUsed(rootClosure: rootClosure, |
510 | | - convertedAndReabstractedClosure: pai.callee, |
| 537 | + convertedAndReabstractedClosure: pai.arguments[0], |
511 | 538 | convertedAndReabstractedClosures: &convertedAndReabstractedClosures) |
512 | 539 | case let cvt as ConvertFunctionInst: |
513 | 540 | convertedAndReabstractedClosures.insert(cvt) |
|
0 commit comments