diff --git a/Sources/StitchEngine/Actions/GraphCalculate.swift b/Sources/StitchEngine/Actions/GraphCalculate.swift index 765a717..ec4ff4d 100644 --- a/Sources/StitchEngine/Actions/GraphCalculate.swift +++ b/Sources/StitchEngine/Actions/GraphCalculate.swift @@ -16,46 +16,23 @@ extension GraphCalculatable { Set>, // portsToUpdate Bool // shouldResortPreviewLayers ) { - let graphState = self - - var visitedNodes = Set() - var queue = nodeIds - - // Reset state on scheduled cycle nodes - self.topologicalData.nodesForNextGraphStep = .init() - - var portsToUpdate = Set>() + let redEdgeCycles = self.topologicalData.cycleNodesForNextGraphStep - while let nodeId = topologicalData.getNextNodeToCalculate(for: queue, - visitedNodes: visitedNodes) { - queue.remove(nodeId) - - guard !visitedNodes.contains(nodeId) else { - assertInDebug(self.topologicalData.cycleContains(nodeId)) - -#if DEV_DEBUG - // print("GraphState.calculate: scheduling cycle node \(nodeId)") -#endif - - self.topologicalData.nodesForNextGraphStep.insert(nodeId) - continue - } - - visitedNodes.insert(nodeId) - + return self.traverseGraph(from: nodeIds, + redEdgeCycles: redEdgeCycles) { queuedNodeResult, queue in // Retrieve the node afresh everytime, // since an upstream node's changes may have changed its inputs - guard let node = self.getNode(id: nodeId) else { + guard let node = self.getNode(id: queuedNodeResult.nodeId) else { // Not necessarily bad -- can happen if we deleted nodes but those nodes were still scheduled to run. // fatalErrorIfDebug() - continue + return } let existingOutputValues = node.outputsValuesList let outputCoordinates = node.outputCoordinates guard let evalResult = self.calculateNode(node) else { - continue + return } // Update queue with NodeIds for nodes which need re-evaluation @@ -71,9 +48,67 @@ extension GraphCalculatable { // Update queue set with changed downstream nodes queue = changedNodeIds.union(Set(queue)) + } + } + + @MainActor + /// The main graph calculator. Shouldn't be called directly unless you know what you're doing. + public func traverseGraph( + from nodeIds: Set, + redEdgeCycles: Set>, + callback: @escaping @MainActor (QueuedNodeResultType, inout Set) -> () + ) -> ( + Set>, // portsToUpdate + Bool // shouldResortPreviewLayers + ) { + let graphState = self + + var visitedNodes = Set() + var queue = nodeIds + + // Reset state on scheduled cycle nodes + self.topologicalData.nodesForNextGraphStep = .init() + self.topologicalData.cycleNodesForNextGraphStep = .init() + + var portsToUpdate = Set>() + + // First handle red edge cycle nodes from last graph step by update inputs of nodes + for cycleNodeData in redEdgeCycles { + let didChangeInputs = self.updateInputs(inputCoordinate: cycleNodeData.portId, + upstreamOutputValues: cycleNodeData.values, + mediaList: cycleNodeData.mediaList, + upstreamOutputChanged: true, + // empty cycle starting points ensures we update inputs + cycleStartingPoints: .init()) + + // Run red edge cycle node's eval if inputs changed from this update + if didChangeInputs { + queue.insert(cycleNodeData.nodeId) + } + } + + while let queuedNodeResult = topologicalData.getNextNodeToCalculate(for: queue, + visitedNodes: visitedNodes) { + let nodeId = queuedNodeResult.nodeId + queue.remove(nodeId) + + guard !visitedNodes.contains(nodeId) else { + assertInDebug(self.topologicalData.cycleContains(nodeId)) + +#if DEV_DEBUG + // print("GraphState.calculate: scheduling cycle node \(nodeId)") +#endif + + self.topologicalData.nodesForNextGraphStep.insert(nodeId) + continue + } + + visitedNodes.insert(nodeId) + + callback(queuedNodeResult, &queue) // Track changed outputs here, inputs in didInputsUpdate - portsToUpdate.insert(NodePortType.allOutputs(node.id)) + portsToUpdate.insert(NodePortType.allOutputs(nodeId)) } // while let ... diff --git a/Sources/StitchEngine/Actions/HoseFlow.swift b/Sources/StitchEngine/Actions/HoseFlow.swift index d4f2539..4df5181 100644 --- a/Sources/StitchEngine/Actions/HoseFlow.swift +++ b/Sources/StitchEngine/Actions/HoseFlow.swift @@ -68,48 +68,85 @@ extension GraphCalculatable { var changedIds = Set() for inputCoordinate in inputs { - // Get kind of downstream node - if let nodeViewModel = self.getNode(id: inputCoordinate.nodeId), - let inputObserver = nodeViewModel.getInputRowObserver(for: inputCoordinate.portType) { - let inputOldValues = inputObserver.values - - if inputObserver.isPulseNodeType && !upstreamOutputChanged { - // If this is a pulse type input and the upstream output did not change, - // do not set the flowing value into the input. - // (Truthy values are coerced to current graph time, i.e. a pulse; we can only pulse when values actually change.) - continue - // } else if upstreamOutputChanged { - } else { - - guard let existingInputValue = inputOldValues.first else { - continue - } - - // `updateInputs(incomingValues: PortValues, graphTime) -> Bool` - // if we true, then - - // Note: if the input supports directly copying, then these values will not actually be coerced - let flowValuesCoercedToThisInputType = inputObserver.coerce( - theseValues: upstreamOutputValues, - toThisType: existingInputValue, - currentGraphTime: self.currentGraphTime) - - if inputOldValues != flowValuesCoercedToThisInputType { - - inputObserver.setValuesInInput(flowValuesCoercedToThisInputType) - changedIds.insert(inputObserver.id) - - // Update downstream observers - if let mediaList = mediaList, - let mediaObservers = nodeViewModel.getMediaObservers(port: inputCoordinate) { - nodeViewModel.updateInputMedia(inputCoordinate: inputCoordinate, - mediaList: mediaList) - } - } - } + let didChangeInputs = self.updateInputs(inputCoordinate: inputCoordinate, + upstreamOutputValues: upstreamOutputValues, + mediaList: mediaList, + upstreamOutputChanged: upstreamOutputChanged, + cycleStartingPoints: self.topologicalData.cycleStartingPoints) + + if didChangeInputs { + changedIds.insert(inputCoordinate) } } // for inputCoordinate in ... return changedIds } + + @MainActor + /// Returns: true if inputs had changed. + func updateInputs(inputCoordinate: Self.Node.InputRow.RowID, + upstreamOutputValues: [Self.Node.PortData], + mediaList: [Self.Node.EvalResult.MediaType?]?, + upstreamOutputChanged: Bool, + cycleStartingPoints: Set) -> Bool { + // Get kind of downstream node + guard let nodeViewModel = self.getNode(id: inputCoordinate.nodeId), + let inputObserver = nodeViewModel.getInputRowObserver(for: inputCoordinate.portType) else { + return .init() + } + + let inputOldValues = inputObserver.values + var didChange = false + + if inputObserver.isPulseNodeType && !upstreamOutputChanged { + // If this is a pulse type input and the upstream output did not change, + // do not set the flowing value into the input. + // (Truthy values are coerced to current graph time, i.e. a pulse; we can only pulse when values actually change.) + return .init() + // } else if upstreamOutputChanged { + } else { + let isCycleRedEdge = cycleStartingPoints.contains(inputCoordinate.nodeId) + + guard let existingInputValue = inputOldValues.first else { + return .init() + } + + // `updateInputs(incomingValues: PortValues, graphTime) -> Bool` + // if we true, then + + // Note: if the input supports directly copying, then these values will not actually be coerced + let flowValuesCoercedToThisInputType = inputObserver.coerce( + theseValues: upstreamOutputValues, + toThisType: existingInputValue, + currentGraphTime: self.currentGraphTime, + isCycleRedEdge: isCycleRedEdge) + + if inputOldValues != flowValuesCoercedToThisInputType { + // Catch cycle red edge case if not yet tracked + if isCycleRedEdge { + self.topologicalData.cycleNodesForNextGraphStep.insert( + .init(portId: inputCoordinate, + values: flowValuesCoercedToThisInputType, + mediaList: mediaList) + ) + + // Skip updating this node on this cycle + return false + } + + + inputObserver.setValuesInInput(flowValuesCoercedToThisInputType) + didChange = true + + // Update downstream observers + if let mediaList = mediaList, + let mediaObservers = nodeViewModel.getMediaObservers(port: inputCoordinate) { + nodeViewModel.updateInputMedia(inputCoordinate: inputCoordinate, + mediaList: mediaList) + } + } + } + + return didChange + } } diff --git a/Sources/StitchEngine/Actions/TopologicalDataRefresh.swift b/Sources/StitchEngine/Actions/TopologicalDataRefresh.swift index a42034e..341af2e 100644 --- a/Sources/StitchEngine/Actions/TopologicalDataRefresh.swift +++ b/Sources/StitchEngine/Actions/TopologicalDataRefresh.swift @@ -17,6 +17,7 @@ extension GraphCalculatable { // All nodes excludes group nodes let allNodes = Array(self.nodes.values.filter { !$0.isGroupNode }) + let allNodeIds: Set = Set(allNodes.map(\.id)) let connections = self.createConnections() // Maps connections by node instead of by coordinate @@ -35,8 +36,9 @@ extension GraphCalculatable { self.nodesForNextGraphStep = prevIdsToCalculate // Gets a set of all node cycles - self.topologicalData.nodeCycles = TopologicalData + let nodeCycles = TopologicalData .findAllCycles(downstreamNodesMap: downstream) + self.topologicalData.nodeCycles = nodeCycles self.topologicalData.shallowDownstreamNodes = downstream @@ -47,5 +49,31 @@ extension GraphCalculatable { self.topologicalData._allMustRunNodes = self.topologicalData.nodesToAlwaysRun .union(self.topologicalData.nodesScheduledToRun) .union(self.topologicalData.keyboardNodes) + + // Traverses full graph to identify cycle nodes that need to establish a red edge + var cycleStartingPoints = Set() + let _ = self.traverseGraph(from: allNodeIds, + // can be empty as we're just trying to traverse + redEdgeCycles: .init()) { queuedNodeResult, queue in + switch queuedNodeResult { + case .cycle(let nextCycleNodeResult): + guard let allNodesInCycle = nodeCycles.first(where: { + $0.contains(nextCycleNodeResult.nodeId) + }) else { + fatalErrorIfDebug() + return + } + + // Only add red edge if this cycle not yet tracked + if cycleStartingPoints.intersection(allNodesInCycle).isEmpty { + cycleStartingPoints.insert(nextCycleNodeResult.nodeId) + } + + default: + return + } + } + + self.topologicalData.cycleStartingPoints = cycleStartingPoints } } diff --git a/Sources/StitchEngine/Data/GraphCalculatable.swift b/Sources/StitchEngine/Data/GraphCalculatable.swift index e2954e9..739fe3b 100644 --- a/Sources/StitchEngine/Data/GraphCalculatable.swift +++ b/Sources/StitchEngine/Data/GraphCalculatable.swift @@ -75,9 +75,11 @@ extension GraphCalculatable { } @MainActor public var nodesToRunOnGraphStep: Set { - get { - self.topologicalData.nodesToRunOnGraphStep - } + self.topologicalData.nodesToRunOnGraphStep + } + + @MainActor public var hasUnprocessedCycleNodes: Bool { + !self.topologicalData.cycleNodesForNextGraphStep.isEmpty } @MainActor public func setNodesForNextGraphStep(_ nodeIds: Set) { diff --git a/Sources/StitchEngine/Data/Node/NodeRowCalculatable.swift b/Sources/StitchEngine/Data/Node/NodeRowCalculatable.swift index 9f96237..7866be4 100644 --- a/Sources/StitchEngine/Data/Node/NodeRowCalculatable.swift +++ b/Sources/StitchEngine/Data/Node/NodeRowCalculatable.swift @@ -32,7 +32,8 @@ public protocol InputNodeRowCalculatable: NodeRowCalculatable { @MainActor func coerce(theseValues: [PortData], toThisType: PortData, - currentGraphTime: TimeInterval) -> [PortData] + currentGraphTime: TimeInterval, + isCycleRedEdge: Bool) -> [PortData] @MainActor func didInputsUpdate(newValues: [PortData], oldValues: [PortData]) diff --git a/Sources/StitchEngine/GraphQueueLogic.swift b/Sources/StitchEngine/GraphQueueLogic.swift index f26f9e3..bcab093 100644 --- a/Sources/StitchEngine/GraphQueueLogic.swift +++ b/Sources/StitchEngine/GraphQueueLogic.swift @@ -24,7 +24,7 @@ extension GraphTopologicalData { } @MainActor func getNextNodeToCalculate(for nodeIds: Set, - visitedNodes: Set) -> Node.ID? { + visitedNodes: Set) -> QueuedNodeResultType? { if let nextNodeId = self.memoizedQueues[nodeIds] { return nextNodeId } @@ -44,7 +44,7 @@ extension GraphTopologicalData { /// from the other nodes. Therefore we can pick any node matching that criteria. /// The only exception is if the remaining nodes contain a cycle at its root. @MainActor private func findNextNodeInGraphCalc(queuedNodeIds: Set, - visitedNodes: Set) -> Node.ID? { + visitedNodes: Set) -> QueuedNodeResultType? { if queuedNodeIds.isEmpty { // print("TopologicalData.findNextNodeInGraphCalc: none") return nil @@ -54,15 +54,17 @@ extension GraphTopologicalData { // preventing us from finding a natural root node. guard let nextNode = self.findNextNodeInDAG(queuedNodeIds: queuedNodeIds, visitedNodes: visitedNodes) else { - let cycleNode = self.findNextNodeInCycle(queuedNodeIds: queuedNodeIds, - visitedNodes: visitedNodes) + guard let cycleNode = self.findNextNodeInCycle(queuedNodeIds: queuedNodeIds, + visitedNodes: visitedNodes) else { + return nil + } // print("TopologicalData.findNextNodeInGraphCalc: cycle node \(cycleNode)") - return cycleNode + return .cycle(cycleNode) } // print("TopologicalData.findNextNodeInGraphCalc: non-cycle node \(nextNode)") - return nextNode + return .noncycle(nextNode) } @MainActor private func findNextNodeInDAG(queuedNodeIds: Set, @@ -87,7 +89,7 @@ extension GraphTopologicalData { /// 2. Tiebreakers resort to leveraging nodes in a cycle containing a direct connection to a node that was already visited /// in a graph. This ensures cycle calculation starts at the "red edge". @MainActor private func findNextNodeInCycle(queuedNodeIds: Set, - visitedNodes: Set) -> Node.ID? { + visitedNodes: Set) -> NextCycleNodeResult? { // There are cycle nodes involved if the above condition isn't hit. We'll prioritize // the cycle node with the fewest matching upstream parents to other nodes in the set. // Doing so guarantees the upstream-most cycle will be computed first. @@ -108,6 +110,16 @@ extension GraphTopologicalData { // Another sorting to guarantee consistent results .sorted() + let upstreamNodeCounts = sortedCycleNodes.map { + self.getAllUpstreamNodes(from: $0).intersection(queuedNodeIds).count + } + + guard let smallestCount = upstreamNodeCounts.first else { + return nil + } + + let hasConflict = upstreamNodeCounts.filter { $0 == smallestCount }.count > 1 + // Prioritize nodes in cycle which contain direct upstream parent to visited node // in this cycle (aka the "red edge") let cycleNode = sortedCycleNodes @@ -121,7 +133,32 @@ extension GraphTopologicalData { // print("TopologicalData.findNextNodeInGraphCalc: cycle node detected at \(cycleNode)") - return cycleNode + guard let cycleNode = cycleNode else { + return nil + } + + return .init(nodeId: cycleNode, + hasConflict: hasConflict) } - +} + +public enum QueuedNodeResultType { + case noncycle(Node.ID) + case cycle(NextCycleNodeResult) +} + +extension QueuedNodeResultType { + var nodeId: Node.ID { + switch self { + case .noncycle(let id): + return id + case .cycle(let nextCycleNodeResult): + return nextCycleNodeResult.nodeId + } + } +} + +public struct NextCycleNodeResult { + let nodeId: Node.ID + let hasConflict: Bool } diff --git a/Sources/StitchEngine/TopologicalData.swift b/Sources/StitchEngine/TopologicalData.swift index bd33370..dff1a68 100644 --- a/Sources/StitchEngine/TopologicalData.swift +++ b/Sources/StitchEngine/TopologicalData.swift @@ -7,6 +7,23 @@ import Foundation +public struct CycleNodeUpdateData { + let portId: Node.InputRow.RowID + let values: [Node.PortData] + let mediaList: [Node.EvalResult.MediaType?]? +} + +extension CycleNodeUpdateData: Hashable { + public func hash(into hasher: inout Hasher) { + hasher.combine(self.nodeId) + hasher.combine(self.portId) + } + + var nodeId: Node.ID { + self.portId.nodeId + } +} + public final class GraphTopologicalData: Sendable { public typealias InputRowId = Node.InputRow.RowID public typealias OutputRowId = Node.OutputRow.RowID @@ -23,6 +40,8 @@ public final class GraphTopologicalData: Sendable { @MainActor var nodesForNextGraphStep = Set() + @MainActor var cycleNodesForNextGraphStep = Set>() + // Maps NodeID to set of connected input coordinates @MainActor var shallowDownstreamNodes = ShallowDownstreamNodesDict() @@ -33,10 +52,12 @@ public final class GraphTopologicalData: Sendable { @MainActor var connections = Connections() + @MainActor var cycleStartingPoints = Set() + @MainActor var _allMustRunNodes = Set() // Memoizes a sorted list of nodes to eval given some set of node IDs - @MainActor var memoizedQueues = [Set: Node.ID]() + @MainActor var memoizedQueues = [Set: QueuedNodeResultType]() // Nodes that scheduled themselves via their node eval, // e.g. Classic Animation node