diff --git a/src/strands/strands_api.js b/src/strands/strands_api.js index fb88413749..825f4a1728 100644 --- a/src/strands/strands_api.js +++ b/src/strands/strands_api.js @@ -227,7 +227,9 @@ function createHookArguments(strandsContext, parameters){ if(isStructType(param.type.typeName)) { const structTypeInfo = structType(param); const { id, dimension } = build.structInstanceNode(strandsContext, structTypeInfo, param.name, []); - const structNode = createStrandsNode(id, dimension, strandsContext); + const structNode = createStrandsNode(id, dimension, strandsContext).withStructProperties( + structTypeInfo.properties.map(prop => prop.name) + ); for (let i = 0; i < structTypeInfo.properties.length; i++) { const propertyType = structTypeInfo.properties[i]; Object.defineProperty(structNode, propertyType.name, { @@ -327,12 +329,43 @@ export function createShaderHooksFunctions(strandsContext, fn, shader) { const { cfg, dag } = strandsContext; for (const hookType of hookTypes) { - const hookImplementation = function(hookUserCallback) { - const entryBlockID = CFG.createBasicBlock(cfg, BlockType.FUNCTION); + const hook = function(hookUserCallback) { + const args = setupHook(); + hook.result = hookUserCallback(...args); + finishHook(); + } + + let entryBlockID; + function setupHook() { + entryBlockID = CFG.createBasicBlock(cfg, BlockType.FUNCTION); CFG.addEdge(cfg, cfg.currentBlock, entryBlockID); CFG.pushBlock(cfg, entryBlockID); const args = createHookArguments(strandsContext, hookType.parameters); - const userReturned = hookUserCallback(...args); + if (args.length === 1 && hookType.parameters[0].type.properties) { + for (const key of args[0].structProperties || []) { + Object.defineProperty(hook, key, { + get() { + return args[0][key]; + }, + set(val) { + args[0][key] = val; + }, + enumerable: true, + }); + } + if (hookType.returnType?.typeName === hookType.parameters[0].type.typeName) { + hook.result = args[0]; + } + } else { + for (let i = 0; i < args.length; i++) { + hook[hookType.parameters[i].name] = args[i]; + } + } + return args; + }; + + function finishHook() { + const userReturned = hook.result; const expectedReturnType = hookType.returnType; let rootNodeID = null; if(isStructType(expectedReturnType.typeName)) { @@ -385,10 +418,12 @@ export function createShaderHooksFunctions(strandsContext, fn, shader) { shaderContext: hookInfo?.shaderContext, // 'vertex' or 'fragment' }); CFG.popBlock(cfg); - } + }; + hook.begin = setupHook; + hook.end = finishHook; strandsContext.windowOverrides[hookType.name] = window[hookType.name]; strandsContext.fnOverrides[hookType.name] = fn[hookType.name]; - window[hookType.name] = hookImplementation; - fn[hookType.name] = hookImplementation; + window[hookType.name] = hook; + fn[hookType.name] = hook; } } diff --git a/src/strands/strands_node.js b/src/strands/strands_node.js index 0901355aff..f8ff752eca 100644 --- a/src/strands/strands_node.js +++ b/src/strands/strands_node.js @@ -7,6 +7,7 @@ export class StrandsNode { this.id = id; this.strandsContext = strandsContext; this.dimension = dimension; + this.structProperties = null; // Store original identifier for varying variables const dag = this.strandsContext.dag; @@ -17,6 +18,10 @@ export class StrandsNode { this._originalDimension = nodeData.dimension; } } + withStructProperties(properties) { + this.structProperties = properties; + return this; + } copy() { return createStrandsNode(this.id, this.dimension, this.strandsContext); } @@ -30,8 +35,8 @@ export class StrandsNode { newValueID = value.id; } else { const newVal = primitiveConstructorNode( - this.strandsContext, - { baseType, dimension: this.dimension }, + this.strandsContext, + { baseType, dimension: this.dimension }, value ); newValueID = newVal.id; @@ -85,8 +90,8 @@ export class StrandsNode { newValueID = value.id; } else { const newVal = primitiveConstructorNode( - this.strandsContext, - { baseType, dimension: this.dimension }, + this.strandsContext, + { baseType, dimension: this.dimension }, value ); newValueID = newVal.id; @@ -159,4 +164,4 @@ export function createStrandsNode(id, dimension, strandsContext, onRebind) { new StrandsNode(id, dimension, strandsContext), swizzleTrap(id, dimension, strandsContext, onRebind) ); -} \ No newline at end of file +} diff --git a/test/unit/webgl/p5.Shader.js b/test/unit/webgl/p5.Shader.js index c69893019a..b9d7807c1b 100644 --- a/test/unit/webgl/p5.Shader.js +++ b/test/unit/webgl/p5.Shader.js @@ -1398,5 +1398,53 @@ suite('p5.Shader', function() { }); } }); + + test('Can use begin/end API for hooks with result', () => { + myp5.createCanvas(50, 50, myp5.WEBGL); + + const testShader = myp5.baseFilterShader().modify(() => { + myp5.getColor.begin(); + myp5.getColor.result = [1.0, 0.5, 0.0, 1.0]; + myp5.getColor.end(); + }, { myp5 }); + + // Create a simple scene to filter + myp5.background(0, 0, 255); // Blue background + + // Apply the filter + myp5.filter(testShader); + + // Check that the filter was applied (should be orange) + const pixelColor = myp5.get(25, 25); + assert.approximately(pixelColor[0], 255, 5); // Red channel should be 255 + assert.approximately(pixelColor[1], 127, 5); // Green channel should be ~127 + assert.approximately(pixelColor[2], 0, 5); // Blue channel should be 0 + }); + + test('Can use begin/end API for hooks modifying inputs', () => { + myp5.createCanvas(50, 50, myp5.WEBGL); + + const testShader = myp5.baseMaterialShader().modify(() => { + myp5.getPixelInputs.begin(); + debugger + myp5.getPixelInputs.color = [1.0, 0.5, 0.0, 1.0]; + myp5.getPixelInputs.end(); + }, { myp5 }); + + // Create a simple scene to filter + myp5.background(0, 0, 255); // Blue background + + // Draw a fullscreen rectangle + myp5.noStroke(); + myp5.fill('red') + myp5.shader(testShader); + myp5.plane(myp5.width, myp5.height); + + // Check that the filter was applied (should be orange) + const pixelColor = myp5.get(25, 25); + assert.approximately(pixelColor[0], 255, 5); // Red channel should be 255 + assert.approximately(pixelColor[1], 127, 5); // Green channel should be ~127 + assert.approximately(pixelColor[2], 0, 5); // Blue channel should be 0 + }); }); });