From 2b9521275a1bacc394e62d45a3528bf21d4d64cf Mon Sep 17 00:00:00 2001 From: Andrew Kent Date: Wed, 17 Dec 2025 11:41:10 -0700 Subject: [PATCH 1/7] initial support for remote evals --- examples/build.gradle | 15 + .../examples/RemoteEvalExample.java | 76 ++ .../java/dev/braintrust/BraintrustUtils.java | 17 + .../braintrust/config/BraintrustConfig.java | 15 +- .../dev/braintrust/devserver/Devserver.java | 1149 +++++++++++++++++ .../dev/braintrust/devserver/EvalRequest.java | 116 ++ .../braintrust/devserver/EvalResponse.java | 62 + .../dev/braintrust/devserver/LRUCache.java | 74 ++ .../dev/braintrust/devserver/RemoteEval.java | 123 ++ .../braintrust/devserver/RequestContext.java | 24 + .../braintrust/trace/BraintrustContext.java | 4 +- src/test/java/dev/braintrust/TestUtils.java | 15 + .../dev/braintrust/devserver/CorsTest.java | 176 +++ .../braintrust/devserver/DevserverTest.java | 427 ++++++ .../devserver/EvalEndpointTest.java | 178 +++ .../devserver/ListEndpointTest.java | 166 +++ 16 files changed, 2633 insertions(+), 4 deletions(-) create mode 100644 examples/src/main/java/dev/braintrust/examples/RemoteEvalExample.java create mode 100644 src/main/java/dev/braintrust/devserver/Devserver.java create mode 100644 src/main/java/dev/braintrust/devserver/EvalRequest.java create mode 100644 src/main/java/dev/braintrust/devserver/EvalResponse.java create mode 100644 src/main/java/dev/braintrust/devserver/LRUCache.java create mode 100644 src/main/java/dev/braintrust/devserver/RemoteEval.java create mode 100644 src/main/java/dev/braintrust/devserver/RequestContext.java create mode 100644 src/test/java/dev/braintrust/TestUtils.java create mode 100644 src/test/java/dev/braintrust/devserver/CorsTest.java create mode 100644 src/test/java/dev/braintrust/devserver/DevserverTest.java create mode 100644 src/test/java/dev/braintrust/devserver/EvalEndpointTest.java create mode 100644 src/test/java/dev/braintrust/devserver/ListEndpointTest.java diff --git a/examples/build.gradle b/examples/build.gradle index 2244681..d1e747c 100644 --- a/examples/build.gradle +++ b/examples/build.gradle @@ -141,3 +141,18 @@ task runSpringAI(type: JavaExec) { suspend = false } } + + +task runRemoteEval(type: JavaExec) { + group = 'Braintrust SDK Examples' + description = 'Run the remote eval example' + classpath = sourceSets.main.runtimeClasspath + mainClass = 'dev.braintrust.examples.RemoteEvalExample' + systemProperty 'org.slf4j.simpleLogger.log.dev.braintrust', braintrustLogLevel + debugOptions { + enabled = true + port = 5566 + server = true + suspend = false + } +} diff --git a/examples/src/main/java/dev/braintrust/examples/RemoteEvalExample.java b/examples/src/main/java/dev/braintrust/examples/RemoteEvalExample.java new file mode 100644 index 0000000..ad83ae3 --- /dev/null +++ b/examples/src/main/java/dev/braintrust/examples/RemoteEvalExample.java @@ -0,0 +1,76 @@ +package dev.braintrust.examples; + +import com.openai.client.okhttp.OpenAIOkHttpClient; +import com.openai.models.ChatModel; +import com.openai.models.chat.completions.ChatCompletionCreateParams; +import dev.braintrust.Braintrust; +import dev.braintrust.devserver.Devserver; +import dev.braintrust.devserver.RemoteEval; +import dev.braintrust.eval.Scorer; +import dev.braintrust.instrumentation.openai.BraintrustOpenAI; +import java.util.List; + +/** Simple Dev Server for Remote Evals */ +public class RemoteEvalExample { + public static void main(String[] args) throws Exception { + var braintrust = Braintrust.get(); + var openTelemetry = braintrust.openTelemetryCreate(); + var openAIClient = BraintrustOpenAI.wrapOpenAI(openTelemetry, OpenAIOkHttpClient.fromEnv()); + + RemoteEval foodTypeEval = + RemoteEval.builder() + .name("food-type-classifier") + .taskFunction( + food -> { + var request = + ChatCompletionCreateParams.builder() + .model(ChatModel.GPT_4O_MINI) + .addSystemMessage("Return a one word answer") + .addUserMessage( + "What kind of food is " + food + "?") + .maxTokens(50L) + .temperature(0.0) + .build(); + var response = + openAIClient.chat().completions().create(request); + return response.choices() + .get(0) + .message() + .content() + .orElse("") + .toLowerCase(); + }) + .scorers( + List.of( + Scorer.of("static_scorer", (expected, result) -> 0.7), + Scorer.of( + "close_enough_match", + (expected, result) -> + expected.trim() + .equalsIgnoreCase( + result.trim()) + ? 1.0 + : 0.0))) + .build(); + + Devserver devserver = + Devserver.builder() + .config(braintrust.config()) + .registerEval(foodTypeEval) + .host("localhost") // set to 0.0.0.0 to bind all interfaces + .port(8301) + .build(); + + Runtime.getRuntime() + .addShutdownHook( + new Thread( + () -> { + System.out.println("Shutting down..."); + devserver.stop(); + System.out.flush(); + System.err.flush(); + })); + System.out.println("Starting Braintrust dev server"); + devserver.start(); + } +} diff --git a/src/main/java/dev/braintrust/BraintrustUtils.java b/src/main/java/dev/braintrust/BraintrustUtils.java index b85ff42..b56eeed 100644 --- a/src/main/java/dev/braintrust/BraintrustUtils.java +++ b/src/main/java/dev/braintrust/BraintrustUtils.java @@ -3,6 +3,9 @@ import dev.braintrust.api.BraintrustApiClient; import java.net.URI; import java.net.URISyntaxException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; import javax.annotation.Nonnull; public class BraintrustUtils { @@ -42,4 +45,18 @@ public String toParentValue() { return type + ":" + id; } } + + public static List parseCsv(String csv) { + if (csv == null || csv.isBlank()) { + return List.of(); + } + + return Arrays.stream(csv.split("\\s*,\\s*")).toList(); + } + + public static List append(List list, T value) { + List result = new ArrayList<>(list); + result.add(value); + return result; + } } diff --git a/src/main/java/dev/braintrust/config/BraintrustConfig.java b/src/main/java/dev/braintrust/config/BraintrustConfig.java index 928abe2..b7602ad 100644 --- a/src/main/java/dev/braintrust/config/BraintrustConfig.java +++ b/src/main/java/dev/braintrust/config/BraintrustConfig.java @@ -51,6 +51,12 @@ public final class BraintrustConfig extends BaseConfig { private final boolean exportSpansInMemoryForUnitTest = getConfig("BRAINTRUST_JAVA_EXPORT_SPANS_IN_MEMORY_FOR_UNIT_TEST", false); + /** CORS origins to allow when running remote eval devserver */ + private final String devserverCorsOriginWhitelistCsv = + getConfig( + "BRAINTRUST_DEVSERVER_CORS_ORIGIN_WHITELIST_CSV", + "https://www.braintrust.dev,https://www.braintrustdata.com,http://localhost:3000"); + public static BraintrustConfig fromEnvironment() { return of(); } @@ -192,8 +198,8 @@ Builder experimentalOtelLogs(boolean value) { return this; } - // hiding visibility. only used for testing - Builder exportSpansInMemoryForUnitTest(boolean value) { + // only used for testing + public Builder exportSpansInMemoryForUnitTest(boolean value) { envOverrides.put( "BRAINTRUST_JAVA_EXPORT_SPANS_IN_MEMORY_FOR_UNIT_TEST", String.valueOf(value)); return this; @@ -209,6 +215,11 @@ public Builder x509TrustManager(X509TrustManager value) { return this; } + public Builder devserverCorsOriginWhitelistCsv(String csv) { + envOverrides.put("BRAINTRUST_DEVSERVER_CORS_ORIGIN_WHITELIST_CSV", csv); + return this; + } + public BraintrustConfig build() { return new BraintrustConfig(envOverrides, sslContext, x509TrustManager); } diff --git a/src/main/java/dev/braintrust/devserver/Devserver.java b/src/main/java/dev/braintrust/devserver/Devserver.java new file mode 100644 index 0000000..680ec7c --- /dev/null +++ b/src/main/java/dev/braintrust/devserver/Devserver.java @@ -0,0 +1,1149 @@ +package dev.braintrust.devserver; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.sun.net.httpserver.HttpExchange; +import com.sun.net.httpserver.HttpHandler; +import com.sun.net.httpserver.HttpServer; +import dev.braintrust.Braintrust; +import dev.braintrust.BraintrustUtils; +import dev.braintrust.Origin; +import dev.braintrust.api.BraintrustApiClient; +import dev.braintrust.config.BraintrustConfig; +import dev.braintrust.eval.Dataset; +import dev.braintrust.eval.DatasetCase; +import dev.braintrust.eval.Score; +import dev.braintrust.trace.BraintrustTracing; +import io.opentelemetry.api.OpenTelemetry; +import io.opentelemetry.api.baggage.propagation.W3CBaggagePropagator; +import io.opentelemetry.api.common.AttributeKey; +import io.opentelemetry.api.trace.SpanKind; +import io.opentelemetry.api.trace.propagation.W3CTraceContextPropagator; +import io.opentelemetry.context.Context; +import io.opentelemetry.context.propagation.ContextPropagators; +import io.opentelemetry.context.propagation.TextMapPropagator; +import io.opentelemetry.sdk.OpenTelemetrySdk; +import io.opentelemetry.sdk.logs.SdkLoggerProvider; +import io.opentelemetry.sdk.metrics.SdkMeterProvider; +import io.opentelemetry.sdk.trace.SdkTracerProvider; +import io.opentelemetry.sdk.trace.SdkTracerProviderBuilder; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.InetSocketAddress; +import java.nio.charset.StandardCharsets; +import java.util.*; +import java.util.concurrent.Executors; +import java.util.function.Consumer; +import java.util.regex.Pattern; +import javax.annotation.Nullable; +import lombok.Getter; +import lombok.experimental.Accessors; +import lombok.extern.slf4j.Slf4j; + +/** Remote Eval Dev Server */ +@Slf4j +public class Devserver { + private static final Pattern PREVIEW_DOMAIN_PATTERN = + Pattern.compile("^https://[^/]+\\.preview\\.braintrust\\.dev$"); + + // Allowed headers for CORS + private static final String ALLOWED_HEADERS = + String.join( + ", ", + "Content-Type", + "X-Amz-Date", + "Authorization", + "X-Api-Key", + "X-Amz-Security-Token", + "X-Bt-Auth-Token", + "X-Bt-Parent", + "X-Bt-Org-Name", + "X-Bt-Project-Id", + "X-Bt-Stream-Fmt", + "X-Bt-Use-Cache", + "X-Stainless-Os", + "X-Stainless-Lang", + "X-Stainless-Package-Version", + "X-Stainless-Runtime", + "X-Stainless-Runtime-Version", + "X-Stainless-Arch"); + + private static final String EXPOSED_HEADERS = + "x-bt-cursor, x-bt-found-existing-experiment, x-bt-span-id, x-bt-span-export"; + + private static final AttributeKey PARENT = + AttributeKey.stringKey(BraintrustTracing.PARENT_KEY); + + private final List corsOriginWhitelist; + private final BraintrustConfig config; + + @Getter + @Accessors(fluent = true) + private final String host; + + @Getter + @Accessors(fluent = true) + private final int port; + + private final @Nullable String orgName; + private final Map> evals; + private @Nullable HttpServer server; + private final @Nullable Consumer + traceBuilderHook; + private final @Nullable Consumer configBuilderHook; + private static final ObjectMapper JSON_MAPPER = + new ObjectMapper() + .enable( + com.fasterxml.jackson.core.JsonParser.Feature + .INCLUDE_SOURCE_IN_LOCATION); + + // LRU cache for token -> Braintrust mappings (max 32 entries as per api.md) + private final LRUCache authCache = new LRUCache<>(32); + private final LRUCache otelCache = new LRUCache<>(32); + + private Devserver(Builder builder) { + this.config = Objects.requireNonNull(builder.config); + this.host = builder.host; + this.port = builder.port; + this.orgName = builder.orgName; + this.traceBuilderHook = builder.traceBuilderHook; + this.configBuilderHook = builder.configBuilderHook; + Map> evalMap = new HashMap<>(); + for (RemoteEval eval : builder.evals) { + if (evalMap.containsKey(eval.getName())) { + throw new IllegalArgumentException("Duplicate evaluator name: " + eval.getName()); + } + evalMap.put(eval.getName(), eval); + } + this.evals = Collections.unmodifiableMap(evalMap); + if (orgName != null) { + throw new NotSupportedYetException("org name filtering"); + } + this.corsOriginWhitelist = + List.copyOf( + BraintrustUtils.append( + BraintrustUtils.parseCsv(config.devserverCorsOriginWhitelistCsv()), + config.appUrl())); + } + + public static Builder builder() { + return new Builder(); + } + + /** Start the dev server. This method blocks until the server is stopped. */ + public synchronized void start() throws IOException { + if (server != null) { + throw new IllegalStateException("Server is already running"); + } + + server = HttpServer.create(new InetSocketAddress(host, port), 0); + server.setExecutor(Executors.newCachedThreadPool()); + + server.createContext("/", withCors(this::handleHealthCheck)); + server.createContext("/list", withCors(this::handleList)); + server.createContext("/eval", withCors(this::handleEval)); + + server.start(); + log.info("Braintrust dev server started on http://{}:{}", host, port); + log.info("Registered {} evaluator(s): {}", evals.size(), evals.keySet()); + } + + /** Stop the dev server. */ + public synchronized void stop() { + if (server != null) { + server.stop(0); + server = null; + log.info("Braintrust dev server stopped"); + } + } + + private void handleHealthCheck(HttpExchange exchange) throws IOException { + if (!"GET".equals(exchange.getRequestMethod())) { + sendResponse(exchange, 405, "text/plain", "Method Not Allowed"); + return; + } + sendResponse(exchange, 200, "text/plain", "Hello, world!"); + } + + private void handleList(HttpExchange exchange) throws IOException { + if (!"GET".equals(exchange.getRequestMethod())) { + sendResponse(exchange, 405, "text/plain", "Method Not Allowed"); + return; + } + + // Check API key is present + RequestContext context = createRequestContext(exchange); + String apiKey = extractApiKey(exchange, context); + if (apiKey == null) { + sendErrorResponse(exchange, 401, "Missing authentication token"); + return; + } + + try { + // Build the response: Map + Map> response = new LinkedHashMap<>(); + + for (Map.Entry> entry : evals.entrySet()) { + String evalName = entry.getKey(); + RemoteEval eval = entry.getValue(); + + Map metadata = new LinkedHashMap<>(); + + Map> parametersMap = new LinkedHashMap<>(); + for (Map.Entry paramEntry : + eval.getParameters().entrySet()) { + String paramName = paramEntry.getKey(); + RemoteEval.Parameter param = paramEntry.getValue(); + + Map paramMetadata = new LinkedHashMap<>(); + paramMetadata.put("type", param.getType().getValue()); + + if (param.getDescription() != null) { + paramMetadata.put("description", param.getDescription()); + } + + if (param.getDefaultValue() != null) { + paramMetadata.put("default", param.getDefaultValue()); + } + + // Only include schema for data type parameters + if (param.getType() == RemoteEval.ParameterType.DATA + && param.getSchema() != null) { + paramMetadata.put("schema", param.getSchema()); + } + + parametersMap.put(paramName, paramMetadata); + } + metadata.put("parameters", parametersMap); + + // Add scores (list of scorer names) + List> scores = new ArrayList<>(); + for (var scorer : eval.getScorers()) { + Map scoreInfo = new LinkedHashMap<>(); + scoreInfo.put("name", scorer.getName()); + scores.add(scoreInfo); + } + metadata.put("scores", scores); + + response.put(evalName, metadata); + } + + String jsonResponse = JSON_MAPPER.writeValueAsString(response); + sendResponse(exchange, 200, "application/json", jsonResponse); + } catch (Exception e) { + log.error("Error generating /list response", e); + sendResponse(exchange, 500, "text/plain", "Internal Server Error"); + } + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + private void handleEval(HttpExchange exchange) throws IOException { + if (!"POST".equals(exchange.getRequestMethod())) { + sendResponse(exchange, 405, "text/plain", "Method Not Allowed"); + return; + } + + // Check authorization and get Braintrust state + RequestContext context = createRequestContext(exchange); + context = getBraintrust(exchange, context); + if (context == null) { + sendErrorResponse(exchange, 401, "Missing required authentication headers"); + return; + } + + try { + InputStream requestBody = exchange.getRequestBody(); + var requestBodyString = new String(requestBody.readAllBytes(), StandardCharsets.UTF_8); + EvalRequest request = JSON_MAPPER.readValue(requestBodyString, EvalRequest.class); + + // Validate evaluator exists + RemoteEval eval = evals.get(request.getName()); + if (eval == null) { + sendResponse( + exchange, 404, "text/plain", "Evaluator not found: " + request.getName()); + return; + } + + // Validate dataset specification + if (request.getData() == null) { + sendResponse(exchange, 400, "text/plain", "Missing 'data' field in request body"); + return; + } + + EvalRequest.DataSpec dataSpec = request.getData(); + boolean hasInlineData = dataSpec.getData() != null && !dataSpec.getData().isEmpty(); + boolean hasByName = + dataSpec.getProjectName() != null && dataSpec.getDatasetName() != null; + boolean hasById = dataSpec.getDatasetId() != null; + + // Ensure exactly one dataset specification method is provided + int specCount = (hasInlineData ? 1 : 0) + (hasByName ? 1 : 0) + (hasById ? 1 : 0); + if (specCount == 0) { + sendResponse( + exchange, + 400, + "text/plain", + "Dataset must be specified using one of: inline data (data.data), by name" + + " (data.project_name + data.dataset_name), or by ID" + + " (data.dataset_id)"); + return; + } + if (specCount > 1) { + sendResponse( + exchange, + 400, + "text/plain", + "Only one dataset specification method should be provided"); + return; + } + + // TODO: support remote scorers + + String datasetDescription = + hasInlineData + ? dataSpec.getData().size() + " inline cases" + : (hasByName + ? "dataset '" + + dataSpec.getProjectName() + + "/" + + dataSpec.getDatasetName() + + "'" + : "dataset ID '" + dataSpec.getDatasetId() + "'"); + log.debug("Executing evaluator '{}' with {}", request.getName(), datasetDescription); + + // Check if streaming is requested + boolean isStreaming = request.getStream() != null && request.getStream(); + + if (isStreaming) { + // SSE streaming response - errors handled inside + log.debug("Starting streaming evaluation for '{}'", request.getName()); + handleStreamingEval(exchange, eval, request, context); + } else { + throw new NotSupportedYetException("non-streaming responses"); + } + } catch (NotSupportedYetException e) { + sendResponse( + exchange, 400, "text/plain", "TODO: feature not supported: " + e.description); + } catch (Exception e) { + log.error("Error executing eval", e); + // Only send error response if we haven't started streaming + // (streaming errors are handled within handleStreamingEval) + sendResponse(exchange, 500, "text/plain", "Internal Server Error: " + e.getMessage()); + } + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + private void handleStreamingEval( + HttpExchange exchange, RemoteEval eval, EvalRequest request, RequestContext context) + throws Exception { + // TODO: refactor some of these steps into utility methods (e.g. dataset extraction, span + // attribute setting) + + // Set SSE headers + exchange.getResponseHeaders().set("Content-Type", "text/event-stream"); + exchange.getResponseHeaders().set("Cache-Control", "no-cache"); + exchange.getResponseHeaders().set("Connection", "keep-alive"); + exchange.sendResponseHeaders(200, 0); // 0 = chunked encoding + + try (OutputStream os = exchange.getResponseBody()) { + try { + // Get Braintrust instance from authenticated context + Braintrust braintrust = context.getBraintrust(); + BraintrustApiClient apiClient = braintrust.apiClient(); + + // Determine project name and ID from the authenticated Braintrust instance + var orgAndProject = apiClient.getOrCreateProjectAndOrgInfo(braintrust.config()); + String projectName = orgAndProject.project().name(); + String projectId = orgAndProject.project().id(); + + // Generate experiment name (same logic as non-streaming) + String experimentName = + request.getExperimentName() != null + ? request.getExperimentName() + : eval.getName(); + + String parentSpec = null; + String generation = null; + + // Extract parent spec and generation from request + if (request.getParent() != null && request.getParent() instanceof Map) { + @SuppressWarnings("unchecked") + Map parentMap = (Map) request.getParent(); + String objectType = (String) parentMap.get("object_type"); + String objectId = (String) parentMap.get("object_id"); + + // Extract generation from propagated_event.span_attributes.generation + Object propEventObj = parentMap.get("propagated_event"); + if (propEventObj instanceof Map) { + @SuppressWarnings("unchecked") + Map propEvent = (Map) propEventObj; + Object spanAttrsObj = propEvent.get("span_attributes"); + if (spanAttrsObj instanceof Map) { + @SuppressWarnings("unchecked") + Map spanAttrs = (Map) spanAttrsObj; + generation = (String) spanAttrs.get("generation"); + } + } + + if (objectType != null && objectId != null) { + parentSpec = "playground_id:" + objectId; + } + } + + // Build URLs + String experimentUrl = + BraintrustUtils.createProjectURI( + braintrust.config().appUrl(), orgAndProject) + .toASCIIString() + + "/experiments/" + + experimentName; + String projectUrl = + BraintrustUtils.createProjectURI( + braintrust.config().appUrl(), orgAndProject) + .toASCIIString(); + + // Send start event + // TODO: the browser doesn't understand this event. Should probably just remove it + sendStartEvent( + os, + projectName, + projectId, + experimentName, + null, // experimentId - not created yet + projectUrl, + experimentUrl); + + // Load dataset using one of three methods (same logic as executeEval) + Dataset dataset; + EvalRequest.DataSpec dataSpec = request.getData(); + + if (dataSpec.getData() != null && !dataSpec.getData().isEmpty()) { + // Method 1: Inline data + List cases = new ArrayList<>(); + for (EvalRequest.EvalCaseData caseData : dataSpec.getData()) { + DatasetCase datasetCase = + DatasetCase.of( + caseData.getInput(), + caseData.getExpected(), + caseData.getTags() != null ? caseData.getTags() : List.of(), + caseData.getMetadata() != null + ? caseData.getMetadata() + : Map.of()); + cases.add(datasetCase); + } + dataset = Dataset.of(cases.toArray(new DatasetCase[0])); + } else if (dataSpec.getProjectName() != null && dataSpec.getDatasetName() != null) { + // Method 2: Fetch by project name and dataset name + log.debug( + "Fetching dataset from Braintrust: project={}, dataset={}", + dataSpec.getProjectName(), + dataSpec.getDatasetName()); + dataset = + Dataset.fetchFromBraintrust( + apiClient, + dataSpec.getProjectName(), + dataSpec.getDatasetName(), + null); + } else if (dataSpec.getDatasetId() != null) { + // Method 3: Fetch by dataset ID + log.debug( + "Fetching dataset from Braintrust by ID: {}", dataSpec.getDatasetId()); + var datasetMetadata = apiClient.getDataset(dataSpec.getDatasetId()); + if (datasetMetadata.isEmpty()) { + throw new IllegalArgumentException( + "Dataset not found: " + dataSpec.getDatasetId()); + } + + var project = apiClient.getProject(datasetMetadata.get().projectId()); + if (project.isEmpty()) { + throw new IllegalArgumentException( + "Project not found: " + datasetMetadata.get().projectId()); + } + + String fetchedProjectName = project.get().name(); + String fetchedDatasetName = datasetMetadata.get().name(); + log.debug( + "Resolved dataset ID to project={}, dataset={}", + fetchedProjectName, + fetchedDatasetName); + + dataset = + Dataset.fetchFromBraintrust( + apiClient, fetchedProjectName, fetchedDatasetName, null); + } else { + throw new IllegalStateException("No dataset specification provided"); + } + + // TODO: flush otel upon cache eviction + var otel = + otelCache.getOrCompute(braintrust, () -> createOpenTelemetry(braintrust)); + var tracer = BraintrustTracing.getTracer(otel); + + // Execute task and scorers for each case + Map> scoresByName = new LinkedHashMap<>(); + int[] caseCount = {0}; // Use array for mutability in lambda + final String finalParentSpec = parentSpec; // Make effectively final for lambda + final String finalGeneration = generation; // Make effectively final for lambda + if (finalParentSpec == null) { + throw new RuntimeException("parent required"); + } + + dataset.forEach( + datasetCase -> { + caseCount[0]++; + log.debug("Processing dataset case #{}", caseCount[0]); + + // Build span attributes with exec_counter and generation (eval span) + Map evalSpanAttrs = new LinkedHashMap<>(); + evalSpanAttrs.put("type", "eval"); + evalSpanAttrs.put("name", "eval"); + if (finalGeneration != null) { + evalSpanAttrs.put("generation", finalGeneration); + } + + // Create eval span for this dataset case (matches Eval.java pattern) + // TODO: take another pass through python playground and make sure we're + // setting the same attributes + var evalSpan = + tracer.spanBuilder("eval") + .setNoParent() // each eval case is its own trace + .setSpanKind(SpanKind.CLIENT) + .setAttribute(PARENT, finalParentSpec) + .setAttribute( + "braintrust.span_attributes", + json(evalSpanAttrs)) + .setAttribute( + "braintrust.input_json", + json(Map.of("input", datasetCase.input()))) + .setAttribute( + "braintrust.expected_json", + json(datasetCase.expected())) + .startSpan(); + + // Set parent in baggage for distributed tracing + // Parse parent format "type:id" (e.g., "playground_id:abc123") + io.opentelemetry.context.Context evalContext = + io.opentelemetry.context.Context.current().with(evalSpan); + String[] parentParts = finalParentSpec.split(":", 2); + if (parentParts.length == 2) { + evalContext = + dev.braintrust.trace.BraintrustContext.setParentInBaggage( + evalContext, parentParts[0], parentParts[1]); + } + + if (datasetCase.origin().isPresent()) { + evalSpan.setAttribute( + "braintrust.origin", json(datasetCase.origin().get())); + } + if (!datasetCase.tags().isEmpty()) { + evalSpan.setAttribute( + AttributeKey.stringArrayKey("braintrust.tags"), + datasetCase.tags()); + } + if (!datasetCase.metadata().isEmpty()) { + evalSpan.setAttribute( + "braintrust.metadata", json(datasetCase.metadata())); + } + + // Make the eval context (with span and baggage) current + try (var rootScope = evalContext.makeCurrent()) { + final dev.braintrust.eval.TaskResult taskResult; + { // run task + // Build task span attributes with exec_counter and generation + Map taskSpanAttrs = new LinkedHashMap<>(); + taskSpanAttrs.put("type", "task"); + taskSpanAttrs.put("name", "task"); + if (finalGeneration != null) { + taskSpanAttrs.put("generation", finalGeneration); + } + + var taskSpan = + tracer.spanBuilder("task") + .setAttribute(PARENT, finalParentSpec) + .setAttribute( + "braintrust.span_attributes", + json(taskSpanAttrs)) + .startSpan(); + taskSpan.setAttribute( + "braintrust.input_json", + json(Map.of("input", datasetCase.input()))); + try (var unused = + Context.current().with(taskSpan).makeCurrent()) { + var task = eval.getTask(); + taskResult = task.apply(datasetCase); + // Send progress event for task completion + sendProgressEvent( + os, + evalSpan.getSpanContext().getSpanId(), + datasetCase.origin(), + eval.getName(), + taskResult.result()); + } finally { + taskSpan.end(); + } + taskSpan.setAttribute( + "braintrust.output_json", + json(Map.of("output", taskResult.result()))); + evalSpan.setAttribute( + "braintrust.output_json", + json(Map.of("output", taskResult.result()))); + } + { // run scorers - one score span per scorer + var scorers = eval.getScorers(); + log.debug("Running {} scorers", scorers.size()); + + for (Object scorerObj : scorers) { + dev.braintrust.eval.Scorer scorer = + (dev.braintrust.eval.Scorer) scorerObj; + + // Build score span attributes with scorer name and + // generation + Map scoreSpanAttrs = new LinkedHashMap<>(); + scoreSpanAttrs.put("type", "score"); + scoreSpanAttrs.put("name", scorer.getName()); + if (finalGeneration != null) { + scoreSpanAttrs.put("generation", finalGeneration); + } + + var scoreSpan = + tracer.spanBuilder("score") + .setAttribute(PARENT, finalParentSpec) + .setAttribute( + "braintrust.span_attributes", + json(scoreSpanAttrs)) + .startSpan(); + try (var unused = + Context.current().with(scoreSpan).makeCurrent()) { + List scores = scorer.score(taskResult); + log.debug( + "Scorer '{}' produced {} scores", + scorer.getName(), + scores.size()); + + Map scorerScores = + new LinkedHashMap<>(); + for (Score score : scores) { + scoresByName + .computeIfAbsent( + score.name(), + k -> new ArrayList<>()) + .add(score.value()); + scorerScores.put(score.name(), score.value()); + } + scoreSpan.setAttribute( + "braintrust.output_json", json(scorerScores)); + } finally { + scoreSpan.end(); + } + } + } + } catch (IOException e) { + throw new RuntimeException("Failed to send progress event", e); + } finally { + evalSpan.end(); + } + }); + + // Aggregate scores + Map scoreSummaries = new LinkedHashMap<>(); + for (Map.Entry> entry : scoresByName.entrySet()) { + String scoreName = entry.getKey(); + List values = entry.getValue(); + + double avgScore = + values.stream().mapToDouble(Double::doubleValue).average().orElse(0.0); + + scoreSummaries.put( + scoreName, + EvalResponse.ScoreSummary.builder() + .name(scoreName) + .score(avgScore) + .improvements(0) + .regressions(0) + .build()); + } + + // Send summary event + sendSummaryEvent( + os, + projectName, + projectId, + experimentName, + projectUrl, + experimentUrl, + scoreSummaries); + + // Send done event + sendDoneEvent(os); + + } catch (Exception e) { + // Send error event via SSE + log.error("Error during streaming evaluation", e); + try { + sendSSEEvent( + os, "error", e.getMessage() != null ? e.getMessage() : "Unknown error"); + } catch (IOException ioException) { + log.error("Failed to send error event", ioException); + } + } finally { + try { + os.flush(); + os.close(); + } catch (IOException e) { + log.error("Failed to close output stream", e); + } + } + } + } + + private void sendSSEEvent(OutputStream os, String eventType, String data) throws IOException { + String event = "event: " + eventType + "\n" + "data: " + data + "\n\n"; + os.write(event.getBytes(StandardCharsets.UTF_8)); + } + + private void sendProgressEvent( + OutputStream os, + String spanId, + Optional origin, + String evalName, + Object taskResult) + throws IOException { + Map progressData = new LinkedHashMap<>(); + progressData.put("id", spanId); + progressData.put("object_type", "task"); + + if (origin.isPresent()) { + progressData.put("origin", origin.get()); + } + progressData.put("name", evalName); + progressData.put("format", "code"); + progressData.put("output_type", "completion"); + progressData.put("event", "json_delta"); + progressData.put("data", JSON_MAPPER.writeValueAsString(taskResult)); + + String progressJson = JSON_MAPPER.writeValueAsString(progressData); + sendSSEEvent(os, "progress", progressJson); + } + + private void sendSummaryEvent( + OutputStream os, + String projectName, + String projectId, + String experimentName, + String projectUrl, + String experimentUrl, + Map scoreSummaries) + throws IOException { + Map summary = new LinkedHashMap<>(); + summary.put("projectName", projectName); + summary.put("projectId", projectId); + summary.put("experimentId", null); + summary.put("experimentName", experimentName); + summary.put("projectUrl", projectUrl); + summary.put("experimentUrl", null); + summary.put("comparisonExperimentName", null); + + // Add scores with additional Python-specific fields + Map scoresWithMeta = new LinkedHashMap<>(); + for (Map.Entry entry : scoreSummaries.entrySet()) { + Map scoreData = new LinkedHashMap<>(); + scoreData.put("name", entry.getValue().getName()); + scoreData.put("_longest_score_name", entry.getKey().length()); + scoreData.put("score", entry.getValue().getScore()); + scoreData.put("improvements", entry.getValue().getImprovements()); + scoreData.put("regressions", entry.getValue().getRegressions()); + scoreData.put("diff", null); + scoresWithMeta.put(entry.getKey(), scoreData); + } + summary.put("scores", scoresWithMeta); + summary.put("metrics", Map.of()); + + sendSSEEvent(os, "summary", JSON_MAPPER.writeValueAsString(summary)); + } + + private void sendDoneEvent(OutputStream os) throws IOException { + sendSSEEvent(os, "done", ""); + } + + private void sendStartEvent( + OutputStream os, + String projectName, + String projectId, + String experimentName, + String experimentId, + String projectUrl, + String experimentUrl) + throws IOException { + Map startData = new LinkedHashMap<>(); + startData.put("experimentName", experimentName); + startData.put("projectName", projectName); + startData.put("projectId", projectId); + startData.put("experimentId", experimentId); + startData.put("experimentUrl", experimentUrl); + startData.put("projectUrl", projectUrl); + startData.put("comparisonExperimentName", null); + startData.put("scores", Map.of()); + + sendSSEEvent(os, "start", JSON_MAPPER.writeValueAsString(startData)); + } + + private String json(Object o) { + try { + return JSON_MAPPER.writeValueAsString(o); + } catch (Exception e) { + throw new RuntimeException("Failed to serialize to JSON", e); + } + } + + private void sendResponse( + HttpExchange exchange, int statusCode, String contentType, String body) + throws IOException { + byte[] responseBytes = body.getBytes(StandardCharsets.UTF_8); + exchange.getResponseHeaders().set("Content-Type", contentType); + exchange.sendResponseHeaders(statusCode, responseBytes.length); + try (OutputStream os = exchange.getResponseBody()) { + os.write(responseBytes); + } + } + + /** + * Check if the origin is whitelisted for CORS. + * + * @param origin The Origin header value + * @return true if the origin is allowed, false otherwise + */ + private boolean isOriginAllowed(@Nullable String origin) { + if (origin == null || origin.isEmpty()) { + return true; // Allow requests without origin (e.g., same-origin) + } + // Check against whitelisted origins + for (String allowedOrigin : corsOriginWhitelist) { + if (allowedOrigin != null && allowedOrigin.equals(origin)) { + return true; + } + } + // Check against preview domain pattern + return PREVIEW_DOMAIN_PATTERN.matcher(origin).matches(); + } + + /** + * Apply CORS headers to the response. + * + * @param exchange The HTTP exchange + */ + private void applyCorsHeaders(HttpExchange exchange) { + String origin = exchange.getRequestHeaders().getFirst("Origin"); + + if (isOriginAllowed(origin)) { + var headers = exchange.getResponseHeaders(); + if (origin != null && !origin.isEmpty()) { + headers.set("Access-Control-Allow-Origin", origin); + } + headers.set("Access-Control-Allow-Credentials", "true"); + headers.set("Access-Control-Expose-Headers", EXPOSED_HEADERS); + } + } + + /** + * Handle CORS preflight requests. + * + * @param exchange The HTTP exchange + */ + private void handlePreflightRequest(HttpExchange exchange) throws IOException { + String origin = exchange.getRequestHeaders().getFirst("Origin"); + + if (!isOriginAllowed(origin)) { + exchange.sendResponseHeaders(403, -1); + return; + } + + var headers = exchange.getResponseHeaders(); + if (origin != null && !origin.isEmpty()) { + headers.set("Access-Control-Allow-Origin", origin); + } + headers.set("Access-Control-Allow-Methods", "GET, PATCH, POST, PUT, DELETE, OPTIONS"); + headers.set("Access-Control-Allow-Headers", ALLOWED_HEADERS); + headers.set("Access-Control-Allow-Credentials", "true"); + headers.set("Access-Control-Max-Age", "86400"); + + // Support for Chrome's Private Network Access + String requestPrivateNetwork = + exchange.getRequestHeaders().getFirst("Access-Control-Request-Private-Network"); + if ("true".equals(requestPrivateNetwork)) { + headers.set("Access-Control-Allow-Private-Network", "true"); + } + + exchange.sendResponseHeaders(204, -1); + } + + /** + * Wrap a handler with CORS support. + * + * @param handler The handler to wrap + * @return A handler that applies CORS headers + */ + private HttpHandler withCors(HttpHandler handler) { + return exchange -> { + // Handle OPTIONS preflight requests + if ("OPTIONS".equals(exchange.getRequestMethod())) { + handlePreflightRequest(exchange); + return; + } + + // Apply CORS headers to all responses + applyCorsHeaders(exchange); + + // Delegate to the actual handler + handler.handle(exchange); + }; + } + + /** + * Extract API key from request headers. + * + *

Checks headers in order of precedence: + * + *

    + *
  1. x-bt-auth-token (preferred) + *
  2. Authorization: Bearer <token> + *
  3. Authorization: <token> + *
+ * + * @param exchange The HTTP exchange + * @param context The request context (unused but for consistency) + * @return The API key, or null if not present + */ + @Nullable + private String extractApiKey(HttpExchange exchange, RequestContext context) { + var headers = exchange.getRequestHeaders(); + + // 1. Check x-bt-auth-token header (preferred) + String token = headers.getFirst("x-bt-auth-token"); + if (token != null && !token.isEmpty()) { + return token; + } + + // 2. Check Authorization header + String authHeader = headers.getFirst("Authorization"); + if (authHeader != null && !authHeader.isEmpty()) { + // Try Bearer format + if (authHeader.startsWith("Bearer ")) { + return authHeader.substring(7).trim(); + } + // Try direct token + return authHeader.trim(); + } + + return null; + } + + /** + * Create a request context with origin. + * + * @param exchange The HTTP exchange + * @return RequestContext with appOrigin + */ + private RequestContext createRequestContext(HttpExchange exchange) { + String origin = exchange.getRequestHeaders().getFirst("Origin"); + if (origin == null) { + origin = ""; + } + + return RequestContext.builder().appOrigin(origin).build(); + } + + /** + * Get Braintrust state for authenticated requests. + * + *

Validates that required headers are present and returns a RequestContext with populated + * Braintrust from cache. + * + *

Required headers: + * + *

    + *
  • API key (x-bt-auth-token or Authorization) + *
  • x-bt-org-name + *
  • x-bt-project-id + *
+ * + *

Cache key format: orgName:projectId:apiKey + * + * @param exchange The HTTP exchange + * @param context The request context + * @return RequestContext with populated state, or null if required headers are missing + */ + @Nullable + private RequestContext getBraintrust(HttpExchange exchange, RequestContext context) { + // Extract API key + String apiKey = extractApiKey(exchange, context); + if (apiKey == null || apiKey.isEmpty()) { + return null; + } + + // Get x-bt-org-name header + String orgName = exchange.getRequestHeaders().getFirst("x-bt-org-name"); + if (orgName == null || orgName.isEmpty()) { + return null; + } + + // Get x-bt-project-id header + String projectId = exchange.getRequestHeaders().getFirst("x-bt-project-id"); + if (projectId == null || projectId.isEmpty()) { + return null; + } + + // Create composite cache key: orgName:projectId:apiKey + String cacheKey = orgName + ":" + projectId + ":" + apiKey; + + // Get from cache or compute if not present + Braintrust braintrust = + authCache.getOrCompute( + cacheKey, + () -> { + // Cache miss - would validate token with Braintrust API here + // TODO: Implement actual token validation with + // loginToState(token, orgName) + log.debug( + "Cached login state for org='{}', projectId='{}' (cache" + + " size={})", + orgName, + projectId, + authCache.size()); + + // Build config with hook if present + var configBuilder = + BraintrustConfig.builder() + .apiKey(apiKey) + .defaultProjectId(projectId) + .apiUrl(config.apiUrl()) + .appUrl(config.appUrl()); + + // Invoke hook if present to allow customization (e.g., enabling + // in-memory span export) + if (configBuilderHook != null) { + configBuilderHook.accept(configBuilder); + } + + return Braintrust.of(configBuilder.build()); + }); + + log.debug( + "Retrieved login state for org='{}', projectId='{}' (cache size={})", + orgName, + projectId, + authCache.size()); + + // Return context with state populated + return RequestContext.builder() + .appOrigin(context.getAppOrigin()) + .token(apiKey) + .braintrust(braintrust) + .build(); + } + + private OpenTelemetry createOpenTelemetry(Braintrust braintrust) { + var tracerBuilder = SdkTracerProvider.builder(); + var loggerBuilder = SdkLoggerProvider.builder(); + var meterBuilder = SdkMeterProvider.builder(); + var contextPropagator = + ContextPropagators.create( + TextMapPropagator.composite( + W3CTraceContextPropagator.getInstance(), + W3CBaggagePropagator.getInstance())); + braintrust.openTelemetryEnable(tracerBuilder, loggerBuilder, meterBuilder); + + // Invoke hook if present to allow customization (e.g., adding span processors) + if (traceBuilderHook != null) { + traceBuilderHook.accept(tracerBuilder); + } + + return OpenTelemetrySdk.builder() + .setTracerProvider(tracerBuilder.build()) + .setLoggerProvider(loggerBuilder.build()) + .setMeterProvider(meterBuilder.build()) + .setPropagators(contextPropagator) + .build(); + } + + /** + * Send an error response with JSON body. + * + * @param exchange The HTTP exchange + * @param statusCode The HTTP status code + * @param message The error message + * @throws IOException if response sending fails + */ + private void sendErrorResponse(HttpExchange exchange, int statusCode, String message) + throws IOException { + Map error = Map.of("error", message); + String json = JSON_MAPPER.writeValueAsString(error); + sendResponse(exchange, statusCode, "application/json", json); + } + + public static class Builder { + private @Nullable BraintrustConfig config = null; + private String host = "localhost"; + private int port = 8300; + private @Nullable String orgName = null; + private List> evals = new ArrayList<>(); + private @Nullable Consumer + traceBuilderHook = null; + private @Nullable Consumer configBuilderHook = null; + + public Devserver build() { + if (evals.isEmpty()) { + throw new IllegalStateException("At least one evaluator must be registered"); + } + if (config == null) { + throw new IllegalStateException("config is required"); + } + return new Devserver(this); + } + + public Builder config(BraintrustConfig config) { + this.config = config; + return this; + } + + public Builder registerEval(RemoteEval eval) { + this.evals.add(eval); + return this; + } + + public Builder host(String host) { + this.host = host; + return this; + } + + public Builder port(int port) { + this.port = port; + return this; + } + + /** hook to run for each open telemetry instance created by the devserver */ + public Builder traceBuilderHook(Consumer traceBuilderHook) { + this.traceBuilderHook = traceBuilderHook; + return this; + } + + /** + * hook to run for each braintrust instance's config created by the devserver. The hook + * receives the BraintrustConfig.Builder before it's built, allowing customization such as + * enabling in-memory span export for testing. + */ + public Builder braintrustConfigBuilderHook( + Consumer configBuilderHook) { + this.configBuilderHook = configBuilderHook; + return this; + } + } + + private static class NotSupportedYetException extends RuntimeException { + private final String description; + + public NotSupportedYetException(String description) { + super("feature not supported yet: " + description); + this.description = description; + } + } +} diff --git a/src/main/java/dev/braintrust/devserver/EvalRequest.java b/src/main/java/dev/braintrust/devserver/EvalRequest.java new file mode 100644 index 0000000..e3a7103 --- /dev/null +++ b/src/main/java/dev/braintrust/devserver/EvalRequest.java @@ -0,0 +1,116 @@ +package dev.braintrust.devserver; + +import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.List; +import java.util.Map; +import javax.annotation.Nullable; +import lombok.Data; + +/** Request body for POST /eval endpoint */ +@Data +public class EvalRequest { + /** Name of the evaluator to run */ + private String name; + + /** Optional parameter overrides */ + @Nullable private Map parameters; + + /** Dataset specification */ + private DataSpec data; + + /** Optional experiment name override */ + @JsonProperty("experiment_name") + @Nullable + private String experimentName; + + /** Optional project ID override */ + @JsonProperty("project_id") + @Nullable + private String projectId; + + /** Optional additional remote scorers */ + @Nullable private List scores; + + /** Optional parent span for tracing (can be string or object) */ + @Nullable private Object parent; + + /** Enable SSE streaming (default: false) */ + @Nullable private Boolean stream; + + /** Dataset specification - supports inline data, by name, or by ID */ + @Data + public static class DataSpec { + /** Inline data array */ + @Nullable private List data; + + /** Project name (for loading by name) */ + @JsonProperty("project_name") + @Nullable + private String projectName; + + /** Dataset name (for loading by name) */ + @JsonProperty("dataset_name") + @Nullable + private String datasetName; + + /** Dataset ID (for loading by ID) */ + @JsonProperty("dataset_id") + @Nullable + private String datasetId; + + /** Optional BTQL filter (can be string or structured query object) */ + @JsonProperty("_internal_btql") + @Nullable + private Object btql; + } + + /** Individual evaluation case data */ + @Data + public static class EvalCaseData { + /** Input for the task */ + private Object input; + + /** Expected output (optional) */ + @Nullable private Object expected; + + /** Metadata (optional) */ + @Nullable private Map metadata; + + /** Tags (optional) */ + @Nullable private List tags; + } + + /** Remote scorer specification */ + @Data + public static class RemoteScorer { + /** Scorer name */ + private String name; + + /** Function ID specification */ + @JsonProperty("function_id") + private FunctionId functionId; + } + + /** Function ID specification (multiple formats supported) */ + @Data + public static class FunctionId { + @JsonProperty("function_id") + @Nullable + private String functionId; + + @Nullable private String version; + @Nullable private String name; + + @JsonProperty("prompt_session_id") + @Nullable + private String promptSessionId; + + @JsonProperty("inline_code") + @Nullable + private String inlineCode; + + @JsonProperty("global_function") + @Nullable + private String globalFunction; + } +} diff --git a/src/main/java/dev/braintrust/devserver/EvalResponse.java b/src/main/java/dev/braintrust/devserver/EvalResponse.java new file mode 100644 index 0000000..b13183c --- /dev/null +++ b/src/main/java/dev/braintrust/devserver/EvalResponse.java @@ -0,0 +1,62 @@ +package dev.braintrust.devserver; + +import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.Map; +import javax.annotation.Nullable; +import lombok.Builder; +import lombok.Data; + +/** Response body for POST /eval endpoint */ +@Data +@Builder +public class EvalResponse { + /** Experiment name */ + @JsonProperty("experimentName") + private String experimentName; + + /** Project name */ + @JsonProperty("projectName") + private String projectName; + + /** Project ID */ + @JsonProperty("projectId") + private String projectId; + + /** Experiment ID */ + @JsonProperty("experimentId") + private String experimentId; + + /** Experiment URL */ + @JsonProperty("experimentUrl") + private String experimentUrl; + + /** Project URL */ + @JsonProperty("projectUrl") + private String projectUrl; + + /** Comparison experiment name (optional) */ + @JsonProperty("comparisonExperimentName") + @Nullable + private String comparisonExperimentName; + + /** Score summaries by scorer name */ + @JsonProperty("scores") + private Map scores; + + /** Summary statistics for a scorer */ + @Data + @Builder + public static class ScoreSummary { + /** Scorer name */ + private String name; + + /** Average score across all cases */ + private double score; + + /** Number of improvements vs baseline */ + private int improvements; + + /** Number of regressions vs baseline */ + private int regressions; + } +} diff --git a/src/main/java/dev/braintrust/devserver/LRUCache.java b/src/main/java/dev/braintrust/devserver/LRUCache.java new file mode 100644 index 0000000..9a5a6d2 --- /dev/null +++ b/src/main/java/dev/braintrust/devserver/LRUCache.java @@ -0,0 +1,74 @@ +package dev.braintrust.devserver; + +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.function.Supplier; +import javax.annotation.Nullable; +import javax.annotation.concurrent.ThreadSafe; + +/** + * Simple LRU (Least Recently Used) cache implementation. + * + *

Thread-safe cache with a maximum size. When the cache exceeds its capacity, the least recently + * used entry is evicted. + * + * @param Key type + * @param Value type + */ +@ThreadSafe +public class LRUCache { + private final int maxSize; + private final Map cache; + + public LRUCache(int maxSize) { + this.maxSize = maxSize; + this.cache = + new LinkedHashMap(maxSize, 0.75f, true) { + @Override + protected boolean removeEldestEntry(Map.Entry eldest) { + return size() > LRUCache.this.maxSize; + } + }; + } + + public synchronized void put(K key, V value) { + cache.put(key, value); + } + + @Nullable + public synchronized V get(K key) { + return cache.get(key); + } + + /** + * Get a value from the cache, or compute and cache it if not present. + * + *

This operation is atomic - the supplier function is only called once per key even under + * concurrent access. + * + * @param key The cache key + * @param supplier Function to compute the value if not in cache (takes no args, returns value) + * @return The cached or newly computed value + */ + public synchronized V getOrCompute(K key, Supplier supplier) { + V value = cache.get(key); + if (value == null) { + // Cache miss - compute the value + value = supplier.get(); + cache.put(key, value); + } + return value; + } + + public synchronized boolean containsKey(K key) { + return cache.containsKey(key); + } + + public synchronized void clear() { + cache.clear(); + } + + public synchronized int size() { + return cache.size(); + } +} diff --git a/src/main/java/dev/braintrust/devserver/RemoteEval.java b/src/main/java/dev/braintrust/devserver/RemoteEval.java new file mode 100644 index 0000000..dc8eef1 --- /dev/null +++ b/src/main/java/dev/braintrust/devserver/RemoteEval.java @@ -0,0 +1,123 @@ +package dev.braintrust.devserver; + +import dev.braintrust.eval.Scorer; +import dev.braintrust.eval.Task; +import java.util.*; +import java.util.function.Function; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import lombok.Builder; +import lombok.Getter; +import lombok.Singular; + +/** + * Represents a remote evaluator that can be exposed via the dev server. + * + * @param The type of input data for the evaluation + * @param The type of output produced by the task + */ +@Getter +@Builder(builderClassName = "Builder", buildMethodName = "internalBuild") +public class RemoteEval { + /** The name of this evaluator (used as identifier) */ + @Nonnull private final String name; + + /** The task function that performs the evaluation */ + @Nonnull private final Task task; + + /** List of scorers for this evaluator */ + @Singular @Nonnull private final List> scorers; + + /** Optional parameters that can be configured from the UI */ + @Singular @Nonnull private final Map parameters; + + public static class Builder { + /** + * Convenience builder method to create a RemoteEval with a simple task function. + * + * @param taskFn Function that takes input and returns output + * @return this builder + */ + public Builder taskFunction(Function taskFn) { + return task( + datasetCase -> { + var result = taskFn.apply(datasetCase.input()); + return new dev.braintrust.eval.TaskResult<>(result, datasetCase); + }); + } + + /** Build the RemoteEval */ + public RemoteEval build() { + // can add build hooks here later if desired + return internalBuild(); + } + } + + /** Represents a configurable parameter for the evaluator */ + @Getter + @lombok.Builder(builderClassName = "Builder") + public static class Parameter { + /** Type of parameter: "prompt" or "data" */ + @Nonnull private final ParameterType type; + + /** Optional description of the parameter */ + @Nullable private final String description; + + /** Optional default value for the parameter */ + @Nullable private final Object defaultValue; + + /** + * JSON Schema for data type parameters. Only applicable when type is DATA. Should be a Map + * representing a JSON Schema object. + */ + @Nullable private final Map schema; + + public static Parameter promptParameter(String description, Object defaultValue) { + return Parameter.builder() + .type(ParameterType.PROMPT) + .description(description) + .defaultValue(defaultValue) + .build(); + } + + public static Parameter promptParameter(Object defaultValue) { + return promptParameter(null, defaultValue); + } + + public static Parameter dataParameter( + String description, Map schema, Object defaultValue) { + return Parameter.builder() + .type(ParameterType.DATA) + .description(description) + .schema(schema) + .defaultValue(defaultValue) + .build(); + } + + public static Parameter dataParameter(Map schema, Object defaultValue) { + return dataParameter(null, schema, defaultValue); + } + + public static Parameter dataParameter(Map schema) { + return dataParameter(null, schema, null); + } + } + + /** Parameter type enumeration */ + public enum ParameterType { + /** Prompt parameter (for LLM prompts) */ + PROMPT("prompt"), + /** Data parameter (for other configuration data) */ + DATA("data"); + + private final String value; + + ParameterType(String value) { + this.value = value; + } + + public String getValue() { + return value; + } + } +} diff --git a/src/main/java/dev/braintrust/devserver/RequestContext.java b/src/main/java/dev/braintrust/devserver/RequestContext.java new file mode 100644 index 0000000..4213c6e --- /dev/null +++ b/src/main/java/dev/braintrust/devserver/RequestContext.java @@ -0,0 +1,24 @@ +package dev.braintrust.devserver; + +import dev.braintrust.Braintrust; +import javax.annotation.Nullable; +import lombok.Builder; +import lombok.Getter; + +/** + * Context object attached to each authenticated request. + * + *

Contains the validated origin, extracted authentication token, and cached login state. + */ +@Getter +@Builder +public class RequestContext { + /** Validated origin from CORS */ + private final String appOrigin; + + /** Extracted auth token (if present) */ + @Nullable private final String token; + + /** Cached login state */ + @Nullable private final Braintrust braintrust; +} diff --git a/src/main/java/dev/braintrust/trace/BraintrustContext.java b/src/main/java/dev/braintrust/trace/BraintrustContext.java index 5911796..db3bf9c 100644 --- a/src/main/java/dev/braintrust/trace/BraintrustContext.java +++ b/src/main/java/dev/braintrust/trace/BraintrustContext.java @@ -45,11 +45,11 @@ public static Context ofExperiment(@Nonnull String experimentId, @Nonnull Span s * parent context to flow across process boundaries. * * @param ctx the context to update - * @param parentType the type of parent (e.g., "experiment_id", "project_name") + * @param parentType the type of parent (e.g., "experiment_id", "project_name", "playground_id") * @param parentId the ID of the parent * @return updated context with baggage set */ - static Context setParentInBaggage( + public static Context setParentInBaggage( @Nonnull Context ctx, @Nonnull String parentType, @Nonnull String parentId) { try { String parentValue = (new BraintrustUtils.Parent(parentType, parentId)).toParentValue(); diff --git a/src/test/java/dev/braintrust/TestUtils.java b/src/test/java/dev/braintrust/TestUtils.java new file mode 100644 index 0000000..82124ce --- /dev/null +++ b/src/test/java/dev/braintrust/TestUtils.java @@ -0,0 +1,15 @@ +package dev.braintrust; + +import java.io.IOException; +import java.net.ServerSocket; + +public class TestUtils { + public static int getRandomOpenPort() { + try (ServerSocket socket = new ServerSocket(0)) { + socket.setReuseAddress(true); + return socket.getLocalPort(); + } catch (IOException e) { + throw new RuntimeException("Failed to find an available port", e); + } + } +} diff --git a/src/test/java/dev/braintrust/devserver/CorsTest.java b/src/test/java/dev/braintrust/devserver/CorsTest.java new file mode 100644 index 0000000..b6300d0 --- /dev/null +++ b/src/test/java/dev/braintrust/devserver/CorsTest.java @@ -0,0 +1,176 @@ +package dev.braintrust.devserver; + +import static org.junit.jupiter.api.Assertions.*; + +import dev.braintrust.TestUtils; +import dev.braintrust.config.BraintrustConfig; +import dev.braintrust.eval.Scorer; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +class CorsTest { + private static Devserver server; + private static Thread serverThread; + private static final int TEST_PORT = TestUtils.getRandomOpenPort(); + + @BeforeAll + static void setup() throws Exception { + RemoteEval testEval = + RemoteEval.builder() + .name("test-eval") + .taskFunction(String::toUpperCase) + .scorer( + Scorer.of( + "length", + (expected, result) -> (double) result.length() / 10.0)) + .build(); + + server = + Devserver.builder() + .config(BraintrustConfig.of("BRAINTRUST_API_KEY", "bogus")) + .registerEval(testEval) + .host("localhost") + .port(TEST_PORT) + .build(); + + serverThread = + new Thread( + () -> { + try { + server.start(); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + serverThread.start(); + Thread.sleep(1000); // Give server time to start + } + + @AfterAll + static void teardown() { + server.stop(); + serverThread.interrupt(); + } + + @Test + void testCorsPreflightRequest() throws Exception { + HttpClient client = HttpClient.newHttpClient(); + HttpRequest request = + HttpRequest.newBuilder() + .uri(URI.create("http://localhost:" + TEST_PORT + "/")) + .method("OPTIONS", HttpRequest.BodyPublishers.noBody()) + .header("Origin", "https://www.braintrust.dev") + .header("Access-Control-Request-Method", "POST") + .build(); + + HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); + + assertEquals(204, response.statusCode()); + assertEquals( + "https://www.braintrust.dev", + response.headers().firstValue("Access-Control-Allow-Origin").orElse(null)); + assertEquals( + "true", + response.headers() + .firstValue("Access-Control-Allow-Credentials") + .orElse("") + .toLowerCase()); + assertTrue( + response.headers() + .firstValue("Access-Control-Allow-Methods") + .orElse("") + .toLowerCase() + .contains("post")); + assertTrue( + response.headers() + .firstValue("Access-Control-Allow-Headers") + .orElse("") + .toLowerCase() + .contains("x-bt-auth-token")); + } + + @Test + void testCorsActualRequest() throws Exception { + HttpClient client = HttpClient.newHttpClient(); + HttpRequest request = + HttpRequest.newBuilder() + .uri(URI.create("http://localhost:" + TEST_PORT + "/")) + .GET() + .header("Origin", "https://www.braintrust.dev") + .build(); + + HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); + + assertEquals(200, response.statusCode()); + assertEquals("Hello, world!", response.body()); + assertEquals( + "https://www.braintrust.dev", + response.headers().firstValue("Access-Control-Allow-Origin").orElse(null)); + assertEquals( + "true", + response.headers() + .firstValue("Access-Control-Allow-Credentials") + .orElse("") + .toLowerCase()); + } + + @Test + void testCorsPreviewDomain() throws Exception { + HttpClient client = HttpClient.newHttpClient(); + HttpRequest request = + HttpRequest.newBuilder() + .uri(URI.create("http://localhost:" + TEST_PORT + "/")) + .GET() + .header("Origin", "https://pr-123.preview.braintrust.dev") + .build(); + + HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); + + assertEquals(200, response.statusCode()); + assertEquals( + "https://pr-123.preview.braintrust.dev", + response.headers().firstValue("Access-Control-Allow-Origin").orElse(null)); + } + + @Test + void testCorsUnauthorizedOrigin() throws Exception { + HttpClient client = HttpClient.newHttpClient(); + HttpRequest request = + HttpRequest.newBuilder() + .uri(URI.create("http://localhost:" + TEST_PORT + "/")) + .method("OPTIONS", HttpRequest.BodyPublishers.noBody()) + .header("Origin", "https://evil.com") + .build(); + + HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); + + assertEquals(403, response.statusCode()); + } + + @Test + void testPrivateNetworkAccess() throws Exception { + HttpClient client = HttpClient.newHttpClient(); + HttpRequest request = + HttpRequest.newBuilder() + .uri(URI.create("http://localhost:" + TEST_PORT + "/")) + .method("OPTIONS", HttpRequest.BodyPublishers.noBody()) + .header("Origin", "https://www.braintrust.dev") + .header("Access-Control-Request-Private-Network", "true") + .build(); + + HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); + + assertEquals(204, response.statusCode()); + assertEquals( + "true", + response.headers() + .firstValue("Access-Control-Allow-Private-Network") + .orElse("") + .toLowerCase()); + } +} diff --git a/src/test/java/dev/braintrust/devserver/DevserverTest.java b/src/test/java/dev/braintrust/devserver/DevserverTest.java new file mode 100644 index 0000000..2c8a4da --- /dev/null +++ b/src/test/java/dev/braintrust/devserver/DevserverTest.java @@ -0,0 +1,427 @@ +package dev.braintrust.devserver; + +import static org.junit.jupiter.api.Assertions.*; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.sun.net.httpserver.HttpServer; +import dev.braintrust.TestHarness; +import dev.braintrust.config.BraintrustConfig; +import dev.braintrust.eval.Scorer; +import io.opentelemetry.sdk.trace.data.SpanData; +import java.io.BufferedReader; +import java.io.InputStreamReader; +import java.io.OutputStream; +import java.net.HttpURLConnection; +import java.net.InetSocketAddress; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.*; + +class DevserverTest { + private static Devserver server; + private static Thread serverThread; + private static TestHarness testHarness; + private static HttpServer mockApiServer; + private static io.opentelemetry.sdk.testing.exporter.InMemorySpanExporter devserverSpanExporter; + private static final int TEST_PORT = 18300; + private static final int MOCK_API_PORT = 18301; + private static final String TEST_URL = "http://localhost:" + TEST_PORT; + private static final ObjectMapper JSON_MAPPER = new ObjectMapper(); + + @BeforeAll + static void setUp() throws Exception { + // Set up mock Braintrust API server + mockApiServer = HttpServer.create(new InetSocketAddress("localhost", MOCK_API_PORT), 0); + + // Mock /v1/project endpoint + mockApiServer.createContext( + "/v1/project", + exchange -> { + String response = + JSON_MAPPER.writeValueAsString( + Map.of( + "id", "test-project-id", + "name", "test-project", + "org_id", "test-org-id", + "created", "2023-01-01T00:00:00Z", + "updated", "2023-01-01T00:00:00Z")); + exchange.getResponseHeaders().set("Content-Type", "application/json"); + exchange.sendResponseHeaders( + 200, response.getBytes(StandardCharsets.UTF_8).length); + try (OutputStream os = exchange.getResponseBody()) { + os.write(response.getBytes(StandardCharsets.UTF_8)); + } + }); + + // Mock /v1/org endpoint + mockApiServer.createContext( + "/v1/org", + exchange -> { + String response = + JSON_MAPPER.writeValueAsString( + Map.of( + "results", + List.of( + Map.of( + "id", "test-org-id", + "name", "test-org")))); + exchange.getResponseHeaders().set("Content-Type", "application/json"); + exchange.sendResponseHeaders( + 200, response.getBytes(StandardCharsets.UTF_8).length); + try (OutputStream os = exchange.getResponseBody()) { + os.write(response.getBytes(StandardCharsets.UTF_8)); + } + }); + + // Mock /api/apikey/login endpoint (using snake_case as per API client naming strategy) + mockApiServer.createContext( + "/api/apikey/login", + exchange -> { + String response = + JSON_MAPPER.writeValueAsString( + Map.of( + "org_info", + List.of( + Map.of( + "id", "test-org-id", + "name", "test-org")))); + exchange.getResponseHeaders().set("Content-Type", "application/json"); + exchange.sendResponseHeaders( + 200, response.getBytes(StandardCharsets.UTF_8).length); + try (OutputStream os = exchange.getResponseBody()) { + os.write(response.getBytes(StandardCharsets.UTF_8)); + } + }); + + mockApiServer.start(); + + // Set up test harness with config pointing to mock API + BraintrustConfig testConfig = + BraintrustConfig.of( + "BRAINTRUST_API_KEY", "test-key", + "BRAINTRUST_API_URL", "http://localhost:" + MOCK_API_PORT, + "BRAINTRUST_APP_URL", "http://localhost:3000", + "BRAINTRUST_DEFAULT_PROJECT_NAME", "test-project", + "BRAINTRUST_JAVA_EXPORT_SPANS_IN_MEMORY_FOR_UNIT_TEST", "true"); + testHarness = TestHarness.setup(testConfig); + + // Create a shared eval for all tests + RemoteEval testEval = + RemoteEval.builder() + .name("food-type-classifier") + .taskFunction( + input -> { + // Create a span inside the task to test baggage propagation + var tracer = dev.braintrust.trace.BraintrustTracing.getTracer(); + var span = tracer.spanBuilder("custom-task-span").startSpan(); + try (var scope = + io.opentelemetry.context.Context.current() + .with(span) + .makeCurrent()) { + // Do some work + return "java-fruit"; + } finally { + span.end(); + } + }) + .scorer(Scorer.of("simple_scorer", (expected, result) -> 0.7)) + .build(); + + // Create in-memory span exporter for devserver-created spans + devserverSpanExporter = io.opentelemetry.sdk.testing.exporter.InMemorySpanExporter.create(); + + server = + Devserver.builder() + .config(testHarness.braintrust().config()) + .registerEval(testEval) + .host("localhost") + .port(TEST_PORT) + .traceBuilderHook( + tracerBuilder -> { + // Add in-memory exporter to capture spans from devserver + tracerBuilder.addSpanProcessor( + io.opentelemetry.sdk.trace.export.SimpleSpanProcessor + .create(devserverSpanExporter)); + }) + .braintrustConfigBuilderHook( + configBuilder -> { + // Enable in-memory span export for testing + configBuilder.exportSpansInMemoryForUnitTest(true); + }) + .build(); + + // Start server in background thread + serverThread = + new Thread( + () -> { + try { + server.start(); + } catch (Exception e) { + e.printStackTrace(); + } + }); + serverThread.start(); + + // Give server time to start + Thread.sleep(1000); + } + + @AfterAll + static void tearDown() { + if (server != null) { + server.stop(); + } + if (serverThread != null) { + serverThread.interrupt(); + } + if (mockApiServer != null) { + mockApiServer.stop(0); + } + } + + @Test + void testHealthCheck() throws Exception { + // Test health check endpoint using the shared devserver + HttpClient client = HttpClient.newHttpClient(); + HttpRequest request = + HttpRequest.newBuilder().uri(URI.create(TEST_URL + "/")).GET().build(); + + HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); + + assertEquals(200, response.statusCode()); + assertEquals("Hello, world!", response.body()); + assertEquals("text/plain", response.headers().firstValue("Content-Type").orElse("")); + } + + @Test + void testStreamingEval() throws Exception { + // Create eval request with inline data using EvalRequest types + EvalRequest evalRequest = new EvalRequest(); + evalRequest.setName("food-type-classifier"); + evalRequest.setStream(true); + + // Create inline data + EvalRequest.DataSpec dataSpec = new EvalRequest.DataSpec(); + + EvalRequest.EvalCaseData case1 = new EvalRequest.EvalCaseData(); + case1.setInput("apple"); + case1.setExpected("fruit"); + + EvalRequest.EvalCaseData case2 = new EvalRequest.EvalCaseData(); + case2.setInput("carrot"); + case2.setExpected("vegetable"); + + dataSpec.setData(List.of(case1, case2)); + evalRequest.setData(dataSpec); + + // Set parent with playground_id and generation + Map parentSpec = + Map.of( + "object_type", "experiment", + "object_id", "test-playground-id-123", + "propagated_event", + Map.of("span_attributes", Map.of("generation", "test-gen-1"))); + evalRequest.setParent(parentSpec); + + String requestBody = JSON_MAPPER.writeValueAsString(evalRequest); + + // Make POST request to /eval with auth headers + HttpURLConnection conn = + (HttpURLConnection) new URI(TEST_URL + "/eval").toURL().openConnection(); + conn.setRequestMethod("POST"); + conn.setRequestProperty("Content-Type", "application/json"); + conn.setRequestProperty("x-bt-auth-token", "test-token-123"); + conn.setRequestProperty("x-bt-project-id", "test-project-id"); + conn.setRequestProperty("x-bt-org-name", "test-org"); + conn.setDoOutput(true); + + // Write request body + conn.getOutputStream().write(requestBody.getBytes(StandardCharsets.UTF_8)); + conn.getOutputStream().flush(); + + // Read SSE response + assertEquals(200, conn.getResponseCode()); + assertEquals("text/event-stream", conn.getHeaderField("Content-Type")); + + BufferedReader reader = + new BufferedReader( + new InputStreamReader(conn.getInputStream(), StandardCharsets.UTF_8)); + + List> events = new ArrayList<>(); + String line; + String currentEvent = null; + StringBuilder currentData = new StringBuilder(); + + while ((line = reader.readLine()) != null) { + if (line.startsWith("event: ")) { + currentEvent = line.substring(7); + } else if (line.startsWith("data: ")) { + currentData.append(line.substring(6)); + } else if (line.isEmpty() && currentEvent != null) { + // End of event + events.add(Map.of("event", currentEvent, "data", currentData.toString())); + currentEvent = null; + currentData = new StringBuilder(); + } + } + reader.close(); + + // Assert event structure + assertFalse(events.isEmpty(), "Should have received events"); + + // Count events by type + long startCount = events.stream().filter(e -> "start".equals(e.get("event"))).count(); + long progressCount = events.stream().filter(e -> "progress".equals(e.get("event"))).count(); + long summaryCount = events.stream().filter(e -> "summary".equals(e.get("event"))).count(); + long doneCount = events.stream().filter(e -> "done".equals(e.get("event"))).count(); + + // Should have 1 start event, 2 progress events (one per dataset case), 1 summary, 1 done + assertEquals(1, startCount, "Should have 1 start event"); + assertEquals(2, progressCount, "Should have 2 progress events"); + assertEquals(1, summaryCount, "Should have 1 summary event"); + assertEquals(1, doneCount, "Should have 1 done event"); + + // Verify start event is first and has expected structure + assertEquals("start", events.get(0).get("event"), "First event should be start"); + Map startEvent = events.get(0); + JsonNode startData = JSON_MAPPER.readTree(startEvent.get("data")); + + assertEquals("test-project", startData.get("projectName").asText()); + assertTrue(startData.has("projectId")); + assertEquals("food-type-classifier", startData.get("experimentName").asText()); + assertTrue(startData.has("experimentUrl")); + assertTrue(startData.has("projectUrl")); + assertTrue(startData.has("scores")); + assertTrue(startData.get("experimentId").isNull()); + + // Verify progress events match expected structure + List> progressEvents = + events.stream().filter(e -> "progress".equals(e.get("event"))).toList(); + + for (Map progressEvent : progressEvents) { + String dataJson = progressEvent.get("data"); + JsonNode progressData = JSON_MAPPER.readTree(dataJson); + + // Assert expected fields in progress event + assertTrue(progressData.has("id"), "Progress event should have id"); + assertEquals("task", progressData.get("object_type").asText()); + assertEquals("food-type-classifier", progressData.get("name").asText()); + assertEquals("code", progressData.get("format").asText()); + assertEquals("completion", progressData.get("output_type").asText()); + assertEquals("json_delta", progressData.get("event").asText()); + + // Assert data field contains the task result + String taskResultJson = progressData.get("data").asText(); + assertEquals("\"java-fruit\"", taskResultJson); + } + + // Verify summary event + Map summaryEvent = + events.stream() + .filter(e -> "summary".equals(e.get("event"))) + .findFirst() + .orElseThrow(); + JsonNode summaryData = JSON_MAPPER.readTree(summaryEvent.get("data")); + + assertEquals("test-project", summaryData.get("projectName").asText()); + assertTrue(summaryData.has("projectId")); + assertEquals("food-type-classifier", summaryData.get("experimentName").asText()); + + // Verify scores in summary + assertTrue(summaryData.has("scores")); + JsonNode scores = summaryData.get("scores"); + assertTrue(scores.has("simple_scorer")); + JsonNode simpleScorer = scores.get("simple_scorer"); + assertEquals("simple_scorer", simpleScorer.get("name").asText()); + assertEquals(0.7, simpleScorer.get("score").asDouble(), 0.001); + + // Get exported spans from test harness (since devserver uses global tracer) + List exportedSpans = testHarness.awaitExportedSpans(); + assertFalse(exportedSpans.isEmpty(), "Should have exported spans"); + + // We should have 2 eval traces (one per dataset case), each with task, score, and custom + // spans + // Each trace has: 1 eval span, 1 task span, 1 score span, 1 custom-task-span = 4 spans + // per case + // Total: 2 cases * 4 spans = 8 spans + assertEquals(8, exportedSpans.size(), "Should have 8 spans (4 per dataset case)"); + + // Verify span types + long evalSpans = exportedSpans.stream().filter(s -> s.getName().equals("eval")).count(); + long taskSpans = exportedSpans.stream().filter(s -> s.getName().equals("task")).count(); + long scoreSpans = exportedSpans.stream().filter(s -> s.getName().equals("score")).count(); + long customSpans = + exportedSpans.stream().filter(s -> s.getName().equals("custom-task-span")).count(); + + assertEquals(2, evalSpans, "Should have 2 eval spans"); + assertEquals(2, taskSpans, "Should have 2 task spans"); + assertEquals(2, scoreSpans, "Should have 2 score spans"); + assertEquals(2, customSpans, "Should have 2 custom-task-span spans"); + + // Verify each eval span has playground_id parent + for (SpanData span : exportedSpans) { + if (span.getName().equals("eval")) { + String parent = + span.getAttributes() + .get( + io.opentelemetry.api.common.AttributeKey.stringKey( + "braintrust.parent")); + assertNotNull(parent, "Eval span should have parent attribute"); + assertTrue( + parent.contains("playground_id:test-playground-id-123"), + "Parent should contain playground_id"); + } + } + + // Verify span attributes contain generation + for (SpanData span : exportedSpans) { + String spanAttrsJson = + span.getAttributes() + .get( + io.opentelemetry.api.common.AttributeKey.stringKey( + "braintrust.span_attributes")); + if (spanAttrsJson != null) { + JsonNode spanAttrs = JSON_MAPPER.readTree(spanAttrsJson); + if (spanAttrs.has("generation")) { + assertEquals("test-gen-1", spanAttrs.get("generation").asText()); + } + } + } + + // Verify custom span created in task function has proper parent propagation + List customSpansList = + exportedSpans.stream().filter(s -> s.getName().equals("custom-task-span")).toList(); + assertEquals(2, customSpansList.size(), "Should have 2 custom spans"); + + for (SpanData customSpan : customSpansList) { + // Verify it has a parent span (is not a root span) + assertTrue( + customSpan.getParentSpanContext().isValid(), + "Custom span should have a valid parent span context"); + assertNotEquals( + io.opentelemetry.api.trace.SpanId.getInvalid(), + customSpan.getParentSpanId(), + "Custom span should not be a root span (should have parent span ID)"); + + // Verify it has braintrust.parent attribute from baggage propagation + String parent = + customSpan + .getAttributes() + .get( + io.opentelemetry.api.common.AttributeKey.stringKey( + "braintrust.parent")); + assertNotNull( + parent, "Custom span should have braintrust.parent attribute from baggage"); + assertTrue( + parent.contains("playground_id:test-playground-id-123"), + "Custom span parent should contain playground_id from baggage propagation"); + } + } +} diff --git a/src/test/java/dev/braintrust/devserver/EvalEndpointTest.java b/src/test/java/dev/braintrust/devserver/EvalEndpointTest.java new file mode 100644 index 0000000..6004dc5 --- /dev/null +++ b/src/test/java/dev/braintrust/devserver/EvalEndpointTest.java @@ -0,0 +1,178 @@ +package dev.braintrust.devserver; + +import static org.junit.jupiter.api.Assertions.*; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import dev.braintrust.TestUtils; +import dev.braintrust.config.BraintrustConfig; +import dev.braintrust.eval.Scorer; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.util.List; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; + +class EvalEndpointTest { + private static Devserver server; + private static Thread serverThread; + private static final int TEST_PORT = TestUtils.getRandomOpenPort(); + private static final ObjectMapper mapper = new ObjectMapper(); + + @BeforeAll + static void setup() throws Exception { + // Create a test eval + RemoteEval testEval = + RemoteEval.builder() + .name("uppercase-eval") + .taskFunction(String::toUpperCase) + .scorer( + Scorer.of( + "length", + (expected, result) -> + Math.min((double) result.length() / 10.0, 1.0))) + .scorer( + Scorer.of( + "has_hello", + (expected, result) -> result.contains("HELLO") ? 1.0 : 0.0)) + .build(); + + server = + Devserver.builder() + .config(BraintrustConfig.of("BRAINTRUST_API_KEY", "bogus")) + .registerEval(testEval) + .host("localhost") + .port(TEST_PORT) + .build(); + + serverThread = + new Thread( + () -> { + try { + server.start(); + } catch (Exception e) { + e.printStackTrace(); + } + }); + serverThread.start(); + Thread.sleep(1000); // Give server time to start + } + + @AfterAll + static void teardown() { + server.stop(); + serverThread.interrupt(); + } + + @Test + @Disabled + void testEvalEndpointWithInlineData() throws Exception { + // Build request + EvalRequest request = new EvalRequest(); + request.setName("uppercase-eval"); + + EvalRequest.DataSpec dataSpec = new EvalRequest.DataSpec(); + + EvalRequest.EvalCaseData case1 = new EvalRequest.EvalCaseData(); + case1.setInput("hello"); + case1.setExpected("HELLO"); + + EvalRequest.EvalCaseData case2 = new EvalRequest.EvalCaseData(); + case2.setInput("world"); + case2.setExpected("WORLD"); + + dataSpec.setData(List.of(case1, case2)); + request.setData(dataSpec); + + String requestJson = mapper.writeValueAsString(request); + + HttpClient client = HttpClient.newHttpClient(); + HttpRequest httpRequest = + HttpRequest.newBuilder() + .uri(URI.create("http://localhost:" + TEST_PORT + "/eval")) + .POST(HttpRequest.BodyPublishers.ofString(requestJson)) + .header("Content-Type", "application/json") + .build(); + + HttpResponse response = + client.send(httpRequest, HttpResponse.BodyHandlers.ofString()); + + assertEquals(200, response.statusCode()); + assertEquals( + "application/json", response.headers().firstValue("Content-Type").orElse(null)); + + // Parse and validate JSON response + JsonNode root = mapper.readTree(response.body()); + + assertTrue(root.has("experimentName")); + assertTrue(root.has("projectName")); + assertTrue(root.has("projectId")); + assertTrue(root.has("experimentId")); + assertTrue(root.has("experimentUrl")); + assertTrue(root.has("projectUrl")); + assertTrue(root.has("scores")); + + JsonNode scores = root.get("scores"); + assertTrue(scores.has("length")); + assertTrue(scores.has("has_hello")); + + // Check length scorer (average of 5/10 and 5/10 = 0.5) + JsonNode lengthScore = scores.get("length"); + assertEquals("length", lengthScore.get("name").asText()); + assertEquals(0.5, lengthScore.get("score").asDouble(), 0.01); + + // Check has_hello scorer (1.0 for "hello", 0.0 for "world" = 0.5 average) + JsonNode helloScore = scores.get("has_hello"); + assertEquals("has_hello", helloScore.get("name").asText()); + assertEquals(0.5, helloScore.get("score").asDouble(), 0.01); + } + + @Test + void testEvalEndpointEvaluatorNotFound() throws Exception { + EvalRequest request = new EvalRequest(); + request.setName("non-existent-eval"); + + EvalRequest.DataSpec dataSpec = new EvalRequest.DataSpec(); + EvalRequest.EvalCaseData case1 = new EvalRequest.EvalCaseData(); + case1.setInput("test"); + dataSpec.setData(List.of(case1)); + request.setData(dataSpec); + + String requestJson = mapper.writeValueAsString(request); + + HttpClient client = HttpClient.newHttpClient(); + HttpRequest httpRequest = + HttpRequest.newBuilder() + .uri(URI.create("http://localhost:" + TEST_PORT + "/eval")) + .POST(HttpRequest.BodyPublishers.ofString(requestJson)) + .header("Content-Type", "application/json") + .header("x-bt-auth-token", "test-token") + .header("x-bt-project-id", "test-project-id") + .header("x-bt-org-name", "test-org") + .build(); + + HttpResponse response = + client.send(httpRequest, HttpResponse.BodyHandlers.ofString()); + + assertEquals(404, response.statusCode()); + assertTrue(response.body().contains("Evaluator not found")); + } + + @Test + void testEvalEndpointMethodNotAllowed() throws Exception { + HttpClient client = HttpClient.newHttpClient(); + HttpRequest request = + HttpRequest.newBuilder() + .uri(URI.create("http://localhost:" + TEST_PORT + "/eval")) + .GET() + .build(); + + HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); + + assertEquals(405, response.statusCode()); + } +} diff --git a/src/test/java/dev/braintrust/devserver/ListEndpointTest.java b/src/test/java/dev/braintrust/devserver/ListEndpointTest.java new file mode 100644 index 0000000..71150b6 --- /dev/null +++ b/src/test/java/dev/braintrust/devserver/ListEndpointTest.java @@ -0,0 +1,166 @@ +package dev.braintrust.devserver; + +import static org.junit.jupiter.api.Assertions.*; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import dev.braintrust.config.BraintrustConfig; +import dev.braintrust.eval.Scorer; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.util.Map; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +class ListEndpointTest { + private static Devserver server; + private static Thread serverThread; + private static final int TEST_PORT = 18302; + private static final ObjectMapper mapper = new ObjectMapper(); + + @BeforeAll + static void setup() throws Exception { + // Create a test eval with parameters + RemoteEval testEval = + RemoteEval.builder() + .name("test-classifier") + .taskFunction(input -> input.toUpperCase()) + .scorer(Scorer.of("accuracy", result -> 0.95)) + .scorer( + Scorer.of( + "length", + (expected, result) -> (double) result.length() / 10.0)) + .parameter( + "model", + RemoteEval.Parameter.dataParameter( + "The model to use", + Map.of( + "type", + "string", + "enum", + new String[] {"gpt-4", "gpt-3.5"}), + "gpt-4")) + .parameter( + "temperature", + RemoteEval.Parameter.dataParameter( + "Temperature for sampling", + Map.of("type", "number", "minimum", 0.0, "maximum", 2.0), + 0.7)) + .build(); + + server = + Devserver.builder() + .config(BraintrustConfig.of("BRAINTRUST_API_KEY", "bogus")) + .registerEval(testEval) + .host("localhost") + .port(TEST_PORT) + .build(); + + serverThread = + new Thread( + () -> { + try { + server.start(); + } catch (Exception e) { + e.printStackTrace(); + } + }); + serverThread.start(); + Thread.sleep(1000); // Give server time to start + } + + @AfterAll + static void teardown() { + server.stop(); + serverThread.interrupt(); + } + + @Test + void testListEndpoint() throws Exception { + HttpClient client = HttpClient.newHttpClient(); + HttpRequest request = + HttpRequest.newBuilder() + .uri(URI.create("http://localhost:" + TEST_PORT + "/list")) + .GET() + .header("x-bt-auth-token", "test-token") + .header("x-bt-project-id", "test-project-id") + .header("x-bt-org-name", "test-org") + .build(); + + HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); + + assertEquals(200, response.statusCode()); + assertEquals( + "application/json", response.headers().firstValue("Content-Type").orElse(null)); + + // Parse and validate JSON response + JsonNode root = mapper.readTree(response.body()); + + // Should have one evaluator + assertTrue(root.has("test-classifier")); + + JsonNode eval = root.get("test-classifier"); + + // Check parameters + assertTrue(eval.has("parameters")); + JsonNode parameters = eval.get("parameters"); + + assertTrue(parameters.has("model")); + JsonNode modelParam = parameters.get("model"); + assertEquals("data", modelParam.get("type").asText()); + assertEquals("The model to use", modelParam.get("description").asText()); + assertEquals("gpt-4", modelParam.get("default").asText()); + assertTrue(modelParam.has("schema")); + + assertTrue(parameters.has("temperature")); + JsonNode tempParam = parameters.get("temperature"); + assertEquals("data", tempParam.get("type").asText()); + assertEquals(0.7, tempParam.get("default").asDouble()); + + // Check scores + assertTrue(eval.has("scores")); + JsonNode scores = eval.get("scores"); + assertEquals(2, scores.size()); + + assertEquals("accuracy", scores.get(0).get("name").asText()); + assertEquals("length", scores.get(1).get("name").asText()); + } + + @Test + void testListEndpointMethodNotAllowed() throws Exception { + HttpClient client = HttpClient.newHttpClient(); + HttpRequest request = + HttpRequest.newBuilder() + .uri(URI.create("http://localhost:" + TEST_PORT + "/list")) + .POST(HttpRequest.BodyPublishers.noBody()) + .build(); + + HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); + + assertEquals(405, response.statusCode()); + } + + @Test + void testListEndpointWithCors() throws Exception { + HttpClient client = HttpClient.newHttpClient(); + HttpRequest request = + HttpRequest.newBuilder() + .uri(URI.create("http://localhost:" + TEST_PORT + "/list")) + .GET() + .header("Origin", "https://www.braintrust.dev") + .header("x-bt-auth-token", "test-token") + .header("x-bt-project-id", "test-project-id") + .header("x-bt-org-name", "test-org") + .build(); + + HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); + + assertEquals(200, response.statusCode()); + assertEquals( + "https://www.braintrust.dev", + response.headers().firstValue("Access-Control-Allow-Origin").orElse(null)); + } +} From 7fce7e787d0b82e20104eb3d2d0d01f1b9048ad3 Mon Sep 17 00:00:00 2001 From: Andrew Kent Date: Fri, 19 Dec 2025 10:05:16 -0700 Subject: [PATCH 2/7] send all devserver traces through global otel --- .../java/dev/braintrust/devserver/Devserver.java | 13 +------------ .../dev/braintrust/devserver/DevserverTest.java | 12 ------------ 2 files changed, 1 insertion(+), 24 deletions(-) diff --git a/src/main/java/dev/braintrust/devserver/Devserver.java b/src/main/java/dev/braintrust/devserver/Devserver.java index 680ec7c..8538731 100644 --- a/src/main/java/dev/braintrust/devserver/Devserver.java +++ b/src/main/java/dev/braintrust/devserver/Devserver.java @@ -25,7 +25,6 @@ import io.opentelemetry.sdk.logs.SdkLoggerProvider; import io.opentelemetry.sdk.metrics.SdkMeterProvider; import io.opentelemetry.sdk.trace.SdkTracerProvider; -import io.opentelemetry.sdk.trace.SdkTracerProviderBuilder; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; @@ -99,7 +98,6 @@ public class Devserver { // LRU cache for token -> Braintrust mappings (max 32 entries as per api.md) private final LRUCache authCache = new LRUCache<>(32); - private final LRUCache otelCache = new LRUCache<>(32); private Devserver(Builder builder) { this.config = Objects.requireNonNull(builder.config); @@ -474,10 +472,7 @@ private void handleStreamingEval( throw new IllegalStateException("No dataset specification provided"); } - // TODO: flush otel upon cache eviction - var otel = - otelCache.getOrCompute(braintrust, () -> createOpenTelemetry(braintrust)); - var tracer = BraintrustTracing.getTracer(otel); + var tracer = BraintrustTracing.getTracer(); // Execute task and scorers for each case Map> scoresByName = new LinkedHashMap<>(); @@ -1120,12 +1115,6 @@ public Builder port(int port) { return this; } - /** hook to run for each open telemetry instance created by the devserver */ - public Builder traceBuilderHook(Consumer traceBuilderHook) { - this.traceBuilderHook = traceBuilderHook; - return this; - } - /** * hook to run for each braintrust instance's config created by the devserver. The hook * receives the BraintrustConfig.Builder before it's built, allowing customization such as diff --git a/src/test/java/dev/braintrust/devserver/DevserverTest.java b/src/test/java/dev/braintrust/devserver/DevserverTest.java index 2c8a4da..11a2ab7 100644 --- a/src/test/java/dev/braintrust/devserver/DevserverTest.java +++ b/src/test/java/dev/braintrust/devserver/DevserverTest.java @@ -29,7 +29,6 @@ class DevserverTest { private static Thread serverThread; private static TestHarness testHarness; private static HttpServer mockApiServer; - private static io.opentelemetry.sdk.testing.exporter.InMemorySpanExporter devserverSpanExporter; private static final int TEST_PORT = 18300; private static final int MOCK_API_PORT = 18301; private static final String TEST_URL = "http://localhost:" + TEST_PORT; @@ -134,25 +133,14 @@ static void setUp() throws Exception { .scorer(Scorer.of("simple_scorer", (expected, result) -> 0.7)) .build(); - // Create in-memory span exporter for devserver-created spans - devserverSpanExporter = io.opentelemetry.sdk.testing.exporter.InMemorySpanExporter.create(); - server = Devserver.builder() .config(testHarness.braintrust().config()) .registerEval(testEval) .host("localhost") .port(TEST_PORT) - .traceBuilderHook( - tracerBuilder -> { - // Add in-memory exporter to capture spans from devserver - tracerBuilder.addSpanProcessor( - io.opentelemetry.sdk.trace.export.SimpleSpanProcessor - .create(devserverSpanExporter)); - }) .braintrustConfigBuilderHook( configBuilder -> { - // Enable in-memory span export for testing configBuilder.exportSpansInMemoryForUnitTest(true); }) .build(); From 257834c68304388774cf355c9bd017cbfd076405 Mon Sep 17 00:00:00 2001 From: Andrew Kent Date: Fri, 19 Dec 2025 11:32:56 -0700 Subject: [PATCH 3/7] refactor devserver streaming eval into helper methods --- .../java/dev/braintrust/BraintrustUtils.java | 2 +- .../dev/braintrust/devserver/Devserver.java | 597 +++++++++--------- .../braintrust/devserver/DevserverTest.java | 347 +++++++--- .../devserver/EvalEndpointTest.java | 178 ------ .../devserver/ListEndpointTest.java | 166 ----- 5 files changed, 561 insertions(+), 729 deletions(-) delete mode 100644 src/test/java/dev/braintrust/devserver/EvalEndpointTest.java delete mode 100644 src/test/java/dev/braintrust/devserver/ListEndpointTest.java diff --git a/src/main/java/dev/braintrust/BraintrustUtils.java b/src/main/java/dev/braintrust/BraintrustUtils.java index b56eeed..4e2fa4a 100644 --- a/src/main/java/dev/braintrust/BraintrustUtils.java +++ b/src/main/java/dev/braintrust/BraintrustUtils.java @@ -31,7 +31,7 @@ public static URI createProjectURI( } } - static Parent parseParent(@Nonnull String parentStr) { + public static Parent parseParent(@Nonnull String parentStr) { String[] parts = parentStr.split(":"); if (parts.length != 2) { throw new IllegalArgumentException("Invalid parent format: " + parentStr); diff --git a/src/main/java/dev/braintrust/devserver/Devserver.java b/src/main/java/dev/braintrust/devserver/Devserver.java index 8538731..13d128f 100644 --- a/src/main/java/dev/braintrust/devserver/Devserver.java +++ b/src/main/java/dev/braintrust/devserver/Devserver.java @@ -9,13 +9,13 @@ import dev.braintrust.Origin; import dev.braintrust.api.BraintrustApiClient; import dev.braintrust.config.BraintrustConfig; -import dev.braintrust.eval.Dataset; -import dev.braintrust.eval.DatasetCase; -import dev.braintrust.eval.Score; +import dev.braintrust.eval.*; +import dev.braintrust.trace.BraintrustContext; import dev.braintrust.trace.BraintrustTracing; import io.opentelemetry.api.OpenTelemetry; import io.opentelemetry.api.baggage.propagation.W3CBaggagePropagator; import io.opentelemetry.api.common.AttributeKey; +import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.SpanKind; import io.opentelemetry.api.trace.propagation.W3CTraceContextPropagator; import io.opentelemetry.context.Context; @@ -31,9 +31,11 @@ import java.net.InetSocketAddress; import java.nio.charset.StandardCharsets; import java.util.*; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.Executors; import java.util.function.Consumer; import java.util.regex.Pattern; +import javax.annotation.Nonnull; import javax.annotation.Nullable; import lombok.Getter; import lombok.experimental.Accessors; @@ -334,9 +336,6 @@ private void handleEval(HttpExchange exchange) throws IOException { private void handleStreamingEval( HttpExchange exchange, RemoteEval eval, EvalRequest request, RequestContext context) throws Exception { - // TODO: refactor some of these steps into utility methods (e.g. dataset extraction, span - // attribute setting) - // Set SSE headers exchange.getResponseHeaders().set("Content-Type", "text/event-stream"); exchange.getResponseHeaders().set("Cache-Control", "no-cache"); @@ -350,294 +349,123 @@ private void handleStreamingEval( BraintrustApiClient apiClient = braintrust.apiClient(); // Determine project name and ID from the authenticated Braintrust instance - var orgAndProject = apiClient.getOrCreateProjectAndOrgInfo(braintrust.config()); - String projectName = orgAndProject.project().name(); - String projectId = orgAndProject.project().id(); - - // Generate experiment name (same logic as non-streaming) - String experimentName = + final var orgAndProject = + apiClient.getOrCreateProjectAndOrgInfo(braintrust.config()); + final var projectName = orgAndProject.project().name(); + final var projectId = orgAndProject.project().id(); + final var experimentName = request.getExperimentName() != null ? request.getExperimentName() : eval.getName(); - - String parentSpec = null; - String generation = null; - - // Extract parent spec and generation from request - if (request.getParent() != null && request.getParent() instanceof Map) { - @SuppressWarnings("unchecked") - Map parentMap = (Map) request.getParent(); - String objectType = (String) parentMap.get("object_type"); - String objectId = (String) parentMap.get("object_id"); - - // Extract generation from propagated_event.span_attributes.generation - Object propEventObj = parentMap.get("propagated_event"); - if (propEventObj instanceof Map) { - @SuppressWarnings("unchecked") - Map propEvent = (Map) propEventObj; - Object spanAttrsObj = propEvent.get("span_attributes"); - if (spanAttrsObj instanceof Map) { - @SuppressWarnings("unchecked") - Map spanAttrs = (Map) spanAttrsObj; - generation = (String) spanAttrs.get("generation"); - } - } - - if (objectType != null && objectId != null) { - parentSpec = "playground_id:" + objectId; - } - } - - // Build URLs - String experimentUrl = + final var experimentUrl = BraintrustUtils.createProjectURI( braintrust.config().appUrl(), orgAndProject) .toASCIIString() + "/experiments/" + experimentName; - String projectUrl = + final var projectUrl = BraintrustUtils.createProjectURI( braintrust.config().appUrl(), orgAndProject) .toASCIIString(); - // Send start event - // TODO: the browser doesn't understand this event. Should probably just remove it - sendStartEvent( - os, - projectName, - projectId, - experimentName, - null, // experimentId - not created yet - projectUrl, - experimentUrl); - - // Load dataset using one of three methods (same logic as executeEval) - Dataset dataset; - EvalRequest.DataSpec dataSpec = request.getData(); - - if (dataSpec.getData() != null && !dataSpec.getData().isEmpty()) { - // Method 1: Inline data - List cases = new ArrayList<>(); - for (EvalRequest.EvalCaseData caseData : dataSpec.getData()) { - DatasetCase datasetCase = - DatasetCase.of( - caseData.getInput(), - caseData.getExpected(), - caseData.getTags() != null ? caseData.getTags() : List.of(), - caseData.getMetadata() != null - ? caseData.getMetadata() - : Map.of()); - cases.add(datasetCase); - } - dataset = Dataset.of(cases.toArray(new DatasetCase[0])); - } else if (dataSpec.getProjectName() != null && dataSpec.getDatasetName() != null) { - // Method 2: Fetch by project name and dataset name - log.debug( - "Fetching dataset from Braintrust: project={}, dataset={}", - dataSpec.getProjectName(), - dataSpec.getDatasetName()); - dataset = - Dataset.fetchFromBraintrust( - apiClient, - dataSpec.getProjectName(), - dataSpec.getDatasetName(), - null); - } else if (dataSpec.getDatasetId() != null) { - // Method 3: Fetch by dataset ID - log.debug( - "Fetching dataset from Braintrust by ID: {}", dataSpec.getDatasetId()); - var datasetMetadata = apiClient.getDataset(dataSpec.getDatasetId()); - if (datasetMetadata.isEmpty()) { - throw new IllegalArgumentException( - "Dataset not found: " + dataSpec.getDatasetId()); - } - - var project = apiClient.getProject(datasetMetadata.get().projectId()); - if (project.isEmpty()) { - throw new IllegalArgumentException( - "Project not found: " + datasetMetadata.get().projectId()); - } - - String fetchedProjectName = project.get().name(); - String fetchedDatasetName = datasetMetadata.get().name(); - log.debug( - "Resolved dataset ID to project={}, dataset={}", - fetchedProjectName, - fetchedDatasetName); - - dataset = - Dataset.fetchFromBraintrust( - apiClient, fetchedProjectName, fetchedDatasetName, null); - } else { - throw new IllegalStateException("No dataset specification provided"); - } - var tracer = BraintrustTracing.getTracer(); // Execute task and scorers for each case - Map> scoresByName = new LinkedHashMap<>(); - int[] caseCount = {0}; // Use array for mutability in lambda - final String finalParentSpec = parentSpec; // Make effectively final for lambda - final String finalGeneration = generation; // Make effectively final for lambda - if (finalParentSpec == null) { - throw new RuntimeException("parent required"); - } - - dataset.forEach( - datasetCase -> { - caseCount[0]++; - log.debug("Processing dataset case #{}", caseCount[0]); - - // Build span attributes with exec_counter and generation (eval span) - Map evalSpanAttrs = new LinkedHashMap<>(); - evalSpanAttrs.put("type", "eval"); - evalSpanAttrs.put("name", "eval"); - if (finalGeneration != null) { - evalSpanAttrs.put("generation", finalGeneration); - } - - // Create eval span for this dataset case (matches Eval.java pattern) - // TODO: take another pass through python playground and make sure we're - // setting the same attributes - var evalSpan = - tracer.spanBuilder("eval") - .setNoParent() // each eval case is its own trace - .setSpanKind(SpanKind.CLIENT) - .setAttribute(PARENT, finalParentSpec) - .setAttribute( - "braintrust.span_attributes", - json(evalSpanAttrs)) - .setAttribute( - "braintrust.input_json", - json(Map.of("input", datasetCase.input()))) - .setAttribute( - "braintrust.expected_json", - json(datasetCase.expected())) - .startSpan(); - - // Set parent in baggage for distributed tracing - // Parse parent format "type:id" (e.g., "playground_id:abc123") - io.opentelemetry.context.Context evalContext = - io.opentelemetry.context.Context.current().with(evalSpan); - String[] parentParts = finalParentSpec.split(":", 2); - if (parentParts.length == 2) { - evalContext = - dev.braintrust.trace.BraintrustContext.setParentInBaggage( - evalContext, parentParts[0], parentParts[1]); - } - - if (datasetCase.origin().isPresent()) { - evalSpan.setAttribute( - "braintrust.origin", json(datasetCase.origin().get())); - } - if (!datasetCase.tags().isEmpty()) { - evalSpan.setAttribute( - AttributeKey.stringArrayKey("braintrust.tags"), - datasetCase.tags()); - } - if (!datasetCase.metadata().isEmpty()) { - evalSpan.setAttribute( - "braintrust.metadata", json(datasetCase.metadata())); - } - - // Make the eval context (with span and baggage) current - try (var rootScope = evalContext.makeCurrent()) { - final dev.braintrust.eval.TaskResult taskResult; - { // run task - // Build task span attributes with exec_counter and generation - Map taskSpanAttrs = new LinkedHashMap<>(); - taskSpanAttrs.put("type", "task"); - taskSpanAttrs.put("name", "task"); - if (finalGeneration != null) { - taskSpanAttrs.put("generation", finalGeneration); - } - - var taskSpan = - tracer.spanBuilder("task") - .setAttribute(PARENT, finalParentSpec) + final Map> scoresByName = new ConcurrentHashMap<>(); + final var parentInfo = extractParentInfo(request); + final var braintrustParent = parentInfo.braintrustParent(); + final var braintrustGeneration = parentInfo.generation(); + + extractDataset(request, apiClient) + .forEach( + datasetCase -> { + var evalSpan = + tracer.spanBuilder("eval") + .setNoParent() + .setSpanKind(SpanKind.CLIENT) .setAttribute( - "braintrust.span_attributes", - json(taskSpanAttrs)) + PARENT, + braintrustParent.toParentValue()) .startSpan(); - taskSpan.setAttribute( - "braintrust.input_json", - json(Map.of("input", datasetCase.input()))); - try (var unused = - Context.current().with(taskSpan).makeCurrent()) { - var task = eval.getTask(); - taskResult = task.apply(datasetCase); - // Send progress event for task completion - sendProgressEvent( - os, - evalSpan.getSpanContext().getSpanId(), - datasetCase.origin(), - eval.getName(), - taskResult.result()); - } finally { - taskSpan.end(); - } - taskSpan.setAttribute( - "braintrust.output_json", - json(Map.of("output", taskResult.result()))); - evalSpan.setAttribute( - "braintrust.output_json", - json(Map.of("output", taskResult.result()))); - } - { // run scorers - one score span per scorer - var scorers = eval.getScorers(); - log.debug("Running {} scorers", scorers.size()); - - for (Object scorerObj : scorers) { - dev.braintrust.eval.Scorer scorer = - (dev.braintrust.eval.Scorer) scorerObj; - - // Build score span attributes with scorer name and - // generation - Map scoreSpanAttrs = new LinkedHashMap<>(); - scoreSpanAttrs.put("type", "score"); - scoreSpanAttrs.put("name", scorer.getName()); - if (finalGeneration != null) { - scoreSpanAttrs.put("generation", finalGeneration); + Context evalContext = Context.current().with(evalSpan); + evalContext = + BraintrustContext.setParentInBaggage( + evalContext, + braintrustParent.type(), + braintrustParent.id()); + // Make the eval context (with span and baggage) current + try (var rootScope = evalContext.makeCurrent()) { + final dev.braintrust.eval.TaskResult taskResult; + { // run task + var taskSpan = tracer.spanBuilder("task").startSpan(); + try (var unused = + Context.current() + .with(taskSpan) + .makeCurrent()) { + var task = eval.getTask(); + taskResult = task.apply(datasetCase); + // Send progress event for task completion + sendProgressEvent( + os, + evalSpan.getSpanContext().getSpanId(), + datasetCase.origin(), + eval.getName(), + taskResult.result()); + setTaskSpanAttributes( + taskSpan, + braintrustParent, + braintrustGeneration, + datasetCase, + taskResult); + } finally { + taskSpan.end(); + } + // setting eval span attributes here because we need the + // task output + setEvalSpanAttributes( + evalSpan, + braintrustParent, + braintrustGeneration, + datasetCase, + taskResult); } - - var scoreSpan = - tracer.spanBuilder("score") - .setAttribute(PARENT, finalParentSpec) - .setAttribute( - "braintrust.span_attributes", - json(scoreSpanAttrs)) - .startSpan(); - try (var unused = - Context.current().with(scoreSpan).makeCurrent()) { - List scores = scorer.score(taskResult); - log.debug( - "Scorer '{}' produced {} scores", - scorer.getName(), - scores.size()); - - Map scorerScores = - new LinkedHashMap<>(); - for (Score score : scores) { - scoresByName - .computeIfAbsent( - score.name(), - k -> new ArrayList<>()) - .add(score.value()); - scorerScores.put(score.name(), score.value()); + // run scorers - one score span per scorer + for (var scorer : (List>) eval.getScorers()) { + var scoreSpan = tracer.spanBuilder("score").startSpan(); + try (var unused = + Context.current() + .with(scoreSpan) + .makeCurrent()) { + List scores = scorer.score(taskResult); + + Map scorerScores = + new LinkedHashMap<>(); + for (Score score : scores) { + scoresByName + .computeIfAbsent( + score.name(), + k -> new ArrayList<>()) + .add(score.value()); + scorerScores.put(score.name(), score.value()); + } + // Set score span attributes before ending span + setScoreSpanAttributes( + scoreSpan, + braintrustParent, + braintrustGeneration, + scorer.getName(), + scorerScores); + } finally { + scoreSpan.end(); } - scoreSpan.setAttribute( - "braintrust.output_json", json(scorerScores)); - } finally { - scoreSpan.end(); } + } catch (IOException e) { + throw new RuntimeException( + "Failed to send progress event", e); + } finally { + evalSpan.end(); } - } - } catch (IOException e) { - throw new RuntimeException("Failed to send progress event", e); - } finally { - evalSpan.end(); - } - }); + }); // Aggregate scores Map scoreSummaries = new LinkedHashMap<>(); @@ -658,7 +486,6 @@ private void handleStreamingEval( .build()); } - // Send summary event sendSummaryEvent( os, projectName, @@ -667,10 +494,7 @@ private void handleStreamingEval( projectUrl, experimentUrl, scoreSummaries); - - // Send done event sendDoneEvent(os); - } catch (Exception e) { // Send error event via SSE log.error("Error during streaming evaluation", e); @@ -680,6 +504,7 @@ private void handleStreamingEval( } catch (IOException ioException) { log.error("Failed to send error event", ioException); } + throw e; } finally { try { os.flush(); @@ -691,6 +516,76 @@ private void handleStreamingEval( } } + private void setEvalSpanAttributes( + Span evalSpan, + BraintrustUtils.Parent braintrustParent, + String braintrustGeneration, + DatasetCase datasetCase, + TaskResult taskResult) { + var spanAttrs = new LinkedHashMap<>(); + spanAttrs.put("type", "eval"); + spanAttrs.put("name", "eval"); + if (braintrustGeneration != null) { + spanAttrs.put("generation", braintrustGeneration); + } + evalSpan.setAttribute(PARENT, braintrustParent.toParentValue()) + .setAttribute("braintrust.span_attributes", json(spanAttrs)) + .setAttribute("braintrust.input_json", json(Map.of("input", datasetCase.input()))) + .setAttribute("braintrust.expected_json", json(datasetCase.expected())); + + if (datasetCase.origin().isPresent()) { + evalSpan.setAttribute("braintrust.origin", json(datasetCase.origin().get())); + } + if (!datasetCase.tags().isEmpty()) { + evalSpan.setAttribute( + AttributeKey.stringArrayKey("braintrust.tags"), datasetCase.tags()); + } + if (!datasetCase.metadata().isEmpty()) { + evalSpan.setAttribute("braintrust.metadata", json(datasetCase.metadata())); + } + evalSpan.setAttribute( + "braintrust.output_json", json(Map.of("output", taskResult.result()))); + } + + private void setTaskSpanAttributes( + Span taskSpan, + BraintrustUtils.Parent braintrustParent, + String braintrustGeneration, + DatasetCase datasetCase, + TaskResult taskResult) { + Map taskSpanAttrs = new LinkedHashMap<>(); + taskSpanAttrs.put("type", "task"); + taskSpanAttrs.put("name", "task"); + if (braintrustGeneration != null) { + taskSpanAttrs.put("generation", braintrustGeneration); + } + + taskSpan.setAttribute(PARENT, braintrustParent.toParentValue()) + .setAttribute("braintrust.span_attributes", json(taskSpanAttrs)) + .setAttribute("braintrust.input_json", json(Map.of("input", datasetCase.input()))) + .setAttribute( + "braintrust.output_json", json(Map.of("output", taskResult.result()))); + } + + private void setScoreSpanAttributes( + Span scoreSpan, + BraintrustUtils.Parent braintrustParent, + String braintrustGeneration, + String scorerName, + Map scorerScores) { + Map scoreSpanAttrs = new LinkedHashMap<>(); + scoreSpanAttrs.put("type", "score"); + scoreSpanAttrs.put("name", scorerName); + if (braintrustGeneration != null) { + scoreSpanAttrs.put("generation", braintrustGeneration); + } + + scoreSpan + .setAttribute(PARENT, braintrustParent.toParentValue()) + .setAttribute("braintrust.span_attributes", json(scoreSpanAttrs)) + .setAttribute("braintrust.output_json", json(scorerScores)); + } + private void sendSSEEvent(OutputStream os, String eventType, String data) throws IOException { String event = "event: " + eventType + "\n" + "data: " + data + "\n\n"; os.write(event.getBytes(StandardCharsets.UTF_8)); @@ -707,9 +602,7 @@ private void sendProgressEvent( progressData.put("id", spanId); progressData.put("object_type", "task"); - if (origin.isPresent()) { - progressData.put("origin", origin.get()); - } + origin.ifPresent(value -> progressData.put("origin", value)); progressData.put("name", evalName); progressData.put("format", "code"); progressData.put("output_type", "completion"); @@ -738,7 +631,6 @@ private void sendSummaryEvent( summary.put("experimentUrl", null); summary.put("comparisonExperimentName", null); - // Add scores with additional Python-specific fields Map scoresWithMeta = new LinkedHashMap<>(); for (Map.Entry entry : scoreSummaries.entrySet()) { Map scoreData = new LinkedHashMap<>(); @@ -760,28 +652,6 @@ private void sendDoneEvent(OutputStream os) throws IOException { sendSSEEvent(os, "done", ""); } - private void sendStartEvent( - OutputStream os, - String projectName, - String projectId, - String experimentName, - String experimentId, - String projectUrl, - String experimentUrl) - throws IOException { - Map startData = new LinkedHashMap<>(); - startData.put("experimentName", experimentName); - startData.put("projectName", projectName); - startData.put("projectId", projectId); - startData.put("experimentId", experimentId); - startData.put("experimentUrl", experimentUrl); - startData.put("projectUrl", projectUrl); - startData.put("comparisonExperimentName", null); - startData.put("scores", Map.of()); - - sendSSEEvent(os, "start", JSON_MAPPER.writeValueAsString(startData)); - } - private String json(Object o) { try { return JSON_MAPPER.writeValueAsString(o); @@ -1075,6 +945,127 @@ private void sendErrorResponse(HttpExchange exchange, int statusCode, String mes sendResponse(exchange, statusCode, "application/json", json); } + /** + * Container for parent information extracted from eval request. + * + * @param braintrustParent The parent specification in "type:id" format (e.g., + * "playground_id:abc123") + * @param generation The generation identifier from the request + */ + private record ParentInfo( + @Nonnull BraintrustUtils.Parent braintrustParent, @Nullable String generation) {} + + /** + * Extracts parent information from the eval request. + * + * @param request The eval request + * @return ParentInfo containing braintrustParent and generation + */ + private static ParentInfo extractParentInfo(EvalRequest request) { + String parentSpec = null; + String generation = null; + + // Extract parent spec and generation from request + if (request.getParent() != null && request.getParent() instanceof Map) { + @SuppressWarnings("unchecked") + Map parentMap = (Map) request.getParent(); + String objectType = (String) parentMap.get("object_type"); + String objectId = (String) parentMap.get("object_id"); + + // Extract generation from propagated_event.span_attributes.generation + Object propEventObj = parentMap.get("propagated_event"); + if (propEventObj instanceof Map) { + @SuppressWarnings("unchecked") + Map propEvent = (Map) propEventObj; + Object spanAttrsObj = propEvent.get("span_attributes"); + if (spanAttrsObj instanceof Map) { + @SuppressWarnings("unchecked") + Map spanAttrs = (Map) spanAttrsObj; + generation = (String) spanAttrs.get("generation"); + } + } + + if (objectType != null && objectId != null) { + parentSpec = "playground_id:" + objectId; + } + } + + if (parentSpec == null) { + throw new IllegalArgumentException("braintrust parent (playground_id) not found"); + } + return new ParentInfo(BraintrustUtils.parseParent(parentSpec), generation); + } + + /** + * Extracts and loads the dataset from the eval request. + * + *

Supports three methods of loading data: + * + *

    + *
  1. Inline data provided in the request + *
  2. Fetch by project name and dataset name + *
  3. Fetch by dataset ID + *
+ * + * @param request The eval request containing dataset specification + * @param apiClient The Braintrust API client for fetching datasets + * @return The loaded dataset + * @throws IllegalStateException if no dataset specification is provided + * @throws IllegalArgumentException if dataset or project is not found + */ + private static Dataset extractDataset( + EvalRequest request, BraintrustApiClient apiClient) { + EvalRequest.DataSpec dataSpec = request.getData(); + + if (dataSpec.getData() != null && !dataSpec.getData().isEmpty()) { + // Method 1: Inline data + List cases = new ArrayList<>(); + for (EvalRequest.EvalCaseData caseData : dataSpec.getData()) { + DatasetCase datasetCase = + DatasetCase.of( + caseData.getInput(), + caseData.getExpected(), + caseData.getTags() != null ? caseData.getTags() : List.of(), + caseData.getMetadata() != null ? caseData.getMetadata() : Map.of()); + cases.add(datasetCase); + } + return Dataset.of(cases.toArray(new DatasetCase[0])); + } else if (dataSpec.getProjectName() != null && dataSpec.getDatasetName() != null) { + // Method 2: Fetch by project name and dataset name + log.debug( + "Fetching dataset from Braintrust: project={}, dataset={}", + dataSpec.getProjectName(), + dataSpec.getDatasetName()); + return Dataset.fetchFromBraintrust( + apiClient, dataSpec.getProjectName(), dataSpec.getDatasetName(), null); + } else if (dataSpec.getDatasetId() != null) { + // Method 3: Fetch by dataset ID + log.debug("Fetching dataset from Braintrust by ID: {}", dataSpec.getDatasetId()); + var datasetMetadata = apiClient.getDataset(dataSpec.getDatasetId()); + if (datasetMetadata.isEmpty()) { + throw new IllegalArgumentException("Dataset not found: " + dataSpec.getDatasetId()); + } + + var project = apiClient.getProject(datasetMetadata.get().projectId()); + if (project.isEmpty()) { + throw new IllegalArgumentException( + "Project not found: " + datasetMetadata.get().projectId()); + } + + String fetchedProjectName = project.get().name(); + String fetchedDatasetName = datasetMetadata.get().name(); + log.debug( + "Resolved dataset ID to project={}, dataset={}", + fetchedProjectName, + fetchedDatasetName); + + return Dataset.fetchFromBraintrust( + apiClient, fetchedProjectName, fetchedDatasetName, null); + } else { + throw new IllegalStateException("No dataset specification provided"); + } + } + public static class Builder { private @Nullable BraintrustConfig config = null; private String host = "localhost"; diff --git a/src/test/java/dev/braintrust/devserver/DevserverTest.java b/src/test/java/dev/braintrust/devserver/DevserverTest.java index 11a2ab7..b1a4065 100644 --- a/src/test/java/dev/braintrust/devserver/DevserverTest.java +++ b/src/test/java/dev/braintrust/devserver/DevserverTest.java @@ -265,34 +265,19 @@ void testStreamingEval() throws Exception { assertFalse(events.isEmpty(), "Should have received events"); // Count events by type - long startCount = events.stream().filter(e -> "start".equals(e.get("event"))).count(); - long progressCount = events.stream().filter(e -> "progress".equals(e.get("event"))).count(); - long summaryCount = events.stream().filter(e -> "summary".equals(e.get("event"))).count(); - long doneCount = events.stream().filter(e -> "done".equals(e.get("event"))).count(); + List> progressEvents = + events.stream().filter(e -> "progress".equals(e.get("event"))).toList(); + List> summaryEvents = + events.stream().filter(e -> "summary".equals(e.get("event"))).toList(); + List> doneEvents = + events.stream().filter(e -> "done".equals(e.get("event"))).toList(); // Should have 1 start event, 2 progress events (one per dataset case), 1 summary, 1 done - assertEquals(1, startCount, "Should have 1 start event"); - assertEquals(2, progressCount, "Should have 2 progress events"); - assertEquals(1, summaryCount, "Should have 1 summary event"); - assertEquals(1, doneCount, "Should have 1 done event"); - - // Verify start event is first and has expected structure - assertEquals("start", events.get(0).get("event"), "First event should be start"); - Map startEvent = events.get(0); - JsonNode startData = JSON_MAPPER.readTree(startEvent.get("data")); - - assertEquals("test-project", startData.get("projectName").asText()); - assertTrue(startData.has("projectId")); - assertEquals("food-type-classifier", startData.get("experimentName").asText()); - assertTrue(startData.has("experimentUrl")); - assertTrue(startData.has("projectUrl")); - assertTrue(startData.has("scores")); - assertTrue(startData.get("experimentId").isNull()); + assertEquals(2, progressEvents.size(), "Should have 2 progress events"); + assertEquals(1, summaryEvents.size(), "Should have 1 summary event"); + assertEquals(1, doneEvents.size(), "Should have 1 done event"); // Verify progress events match expected structure - List> progressEvents = - events.stream().filter(e -> "progress".equals(e.get("event"))).toList(); - for (Map progressEvent : progressEvents) { String dataJson = progressEvent.get("data"); JsonNode progressData = JSON_MAPPER.readTree(dataJson); @@ -310,25 +295,22 @@ void testStreamingEval() throws Exception { assertEquals("\"java-fruit\"", taskResultJson); } - // Verify summary event - Map summaryEvent = - events.stream() - .filter(e -> "summary".equals(e.get("event"))) - .findFirst() - .orElseThrow(); - JsonNode summaryData = JSON_MAPPER.readTree(summaryEvent.get("data")); - - assertEquals("test-project", summaryData.get("projectName").asText()); - assertTrue(summaryData.has("projectId")); - assertEquals("food-type-classifier", summaryData.get("experimentName").asText()); - - // Verify scores in summary - assertTrue(summaryData.has("scores")); - JsonNode scores = summaryData.get("scores"); - assertTrue(scores.has("simple_scorer")); - JsonNode simpleScorer = scores.get("simple_scorer"); - assertEquals("simple_scorer", simpleScorer.get("name").asText()); - assertEquals(0.7, simpleScorer.get("score").asDouble(), 0.001); + { // Verify summary event + Map summaryEvent = summaryEvents.get(0); + JsonNode summaryData = JSON_MAPPER.readTree(summaryEvent.get("data")); + + assertEquals("test-project", summaryData.get("projectName").asText()); + assertTrue(summaryData.has("projectId")); + assertEquals("food-type-classifier", summaryData.get("experimentName").asText()); + + // Verify scores in summary + assertTrue(summaryData.has("scores")); + JsonNode scores = summaryData.get("scores"); + assertTrue(scores.has("simple_scorer")); + JsonNode simpleScorer = scores.get("simple_scorer"); + assertEquals("simple_scorer", simpleScorer.get("name").asText()); + assertEquals(0.7, simpleScorer.get("score").asDouble(), 0.001); + } // Get exported spans from test harness (since devserver uses global tracer) List exportedSpans = testHarness.awaitExportedSpans(); @@ -342,53 +324,145 @@ void testStreamingEval() throws Exception { assertEquals(8, exportedSpans.size(), "Should have 8 spans (4 per dataset case)"); // Verify span types - long evalSpans = exportedSpans.stream().filter(s -> s.getName().equals("eval")).count(); - long taskSpans = exportedSpans.stream().filter(s -> s.getName().equals("task")).count(); - long scoreSpans = exportedSpans.stream().filter(s -> s.getName().equals("score")).count(); - long customSpans = - exportedSpans.stream().filter(s -> s.getName().equals("custom-task-span")).count(); - - assertEquals(2, evalSpans, "Should have 2 eval spans"); - assertEquals(2, taskSpans, "Should have 2 task spans"); - assertEquals(2, scoreSpans, "Should have 2 score spans"); - assertEquals(2, customSpans, "Should have 2 custom-task-span spans"); - - // Verify each eval span has playground_id parent - for (SpanData span : exportedSpans) { - if (span.getName().equals("eval")) { - String parent = - span.getAttributes() - .get( - io.opentelemetry.api.common.AttributeKey.stringKey( - "braintrust.parent")); - assertNotNull(parent, "Eval span should have parent attribute"); - assertTrue( - parent.contains("playground_id:test-playground-id-123"), - "Parent should contain playground_id"); - } + var evalSpans = exportedSpans.stream().filter(s -> s.getName().equals("eval")).toList(); + var taskSpans = exportedSpans.stream().filter(s -> s.getName().equals("task")).toList(); + var scoreSpans = exportedSpans.stream().filter(s -> s.getName().equals("score")).toList(); + var customSpans = + exportedSpans.stream().filter(s -> s.getName().equals("custom-task-span")).toList(); + + assertEquals(2, evalSpans.size(), "Should have 2 eval spans"); + assertEquals(2, taskSpans.size(), "Should have 2 task spans"); + assertEquals(2, scoreSpans.size(), "Should have 2 score spans"); + assertEquals(2, customSpans.size(), "Should have 2 custom-task-span spans"); + + // Verify eval spans have all required attributes + for (SpanData evalSpan : evalSpans) { + // Verify braintrust.parent + String parent = + evalSpan.getAttributes() + .get( + io.opentelemetry.api.common.AttributeKey.stringKey( + "braintrust.parent")); + assertNotNull(parent, "Eval span should have parent attribute"); + assertTrue( + parent.contains("playground_id:test-playground-id-123"), + "Parent should contain playground_id"); + + String spanAttrsJson = + evalSpan.getAttributes() + .get( + io.opentelemetry.api.common.AttributeKey.stringKey( + "braintrust.span_attributes")); + assertNotNull(spanAttrsJson, "Eval span should have span_attributes"); + JsonNode spanAttrs = JSON_MAPPER.readTree(spanAttrsJson); + assertEquals("eval", spanAttrs.get("type").asText()); + assertEquals("eval", spanAttrs.get("name").asText()); + assertEquals("test-gen-1", spanAttrs.get("generation").asText()); + + String inputJson = + evalSpan.getAttributes() + .get( + io.opentelemetry.api.common.AttributeKey.stringKey( + "braintrust.input_json")); + assertNotNull(inputJson, "Eval span should have input_json"); + + String expectedJson = + evalSpan.getAttributes() + .get( + io.opentelemetry.api.common.AttributeKey.stringKey( + "braintrust.expected_json")); + assertNotNull(expectedJson, "Eval span should have expected_json"); + + String outputJson = + evalSpan.getAttributes() + .get( + io.opentelemetry.api.common.AttributeKey.stringKey( + "braintrust.output_json")); + assertNotNull(outputJson, "Eval span should have output_json"); + JsonNode output = JSON_MAPPER.readTree(outputJson); + assertEquals("java-fruit", output.get("output").asText()); } - // Verify span attributes contain generation - for (SpanData span : exportedSpans) { + for (SpanData taskSpan : taskSpans) { + // Verify braintrust.parent + String parent = + taskSpan.getAttributes() + .get( + io.opentelemetry.api.common.AttributeKey.stringKey( + "braintrust.parent")); + assertNotNull(parent, "Task span should have parent attribute"); + assertTrue( + parent.contains("playground_id:test-playground-id-123"), + "Parent should contain playground_id"); + String spanAttrsJson = - span.getAttributes() + taskSpan.getAttributes() .get( io.opentelemetry.api.common.AttributeKey.stringKey( "braintrust.span_attributes")); - if (spanAttrsJson != null) { - JsonNode spanAttrs = JSON_MAPPER.readTree(spanAttrsJson); - if (spanAttrs.has("generation")) { - assertEquals("test-gen-1", spanAttrs.get("generation").asText()); - } - } + assertNotNull(spanAttrsJson, "Task span should have span_attributes"); + JsonNode spanAttrs = JSON_MAPPER.readTree(spanAttrsJson); + assertEquals("task", spanAttrs.get("type").asText()); + assertEquals("task", spanAttrs.get("name").asText()); + assertEquals("test-gen-1", spanAttrs.get("generation").asText()); + + String inputJson = + taskSpan.getAttributes() + .get( + io.opentelemetry.api.common.AttributeKey.stringKey( + "braintrust.input_json")); + assertNotNull(inputJson, "Task span should have input_json"); + + String outputJson = + taskSpan.getAttributes() + .get( + io.opentelemetry.api.common.AttributeKey.stringKey( + "braintrust.output_json")); + assertNotNull(outputJson, "Task span should have output_json"); + JsonNode output = JSON_MAPPER.readTree(outputJson); + assertEquals("java-fruit", output.get("output").asText()); } - // Verify custom span created in task function has proper parent propagation - List customSpansList = - exportedSpans.stream().filter(s -> s.getName().equals("custom-task-span")).toList(); - assertEquals(2, customSpansList.size(), "Should have 2 custom spans"); + for (SpanData scoreSpan : scoreSpans) { + // Verify braintrust.parent + String parent = + scoreSpan + .getAttributes() + .get( + io.opentelemetry.api.common.AttributeKey.stringKey( + "braintrust.parent")); + assertNotNull(parent, "Score span should have parent attribute"); + assertTrue( + parent.contains("playground_id:test-playground-id-123"), + "Parent should contain playground_id"); + + // Verify braintrust.span_attributes + String spanAttrsJson = + scoreSpan + .getAttributes() + .get( + io.opentelemetry.api.common.AttributeKey.stringKey( + "braintrust.span_attributes")); + assertNotNull(spanAttrsJson, "Score span should have span_attributes"); + JsonNode spanAttrs = JSON_MAPPER.readTree(spanAttrsJson); + assertEquals("score", spanAttrs.get("type").asText()); + assertEquals("simple_scorer", spanAttrs.get("name").asText()); + assertEquals("test-gen-1", spanAttrs.get("generation").asText()); + + // Verify braintrust.output_json contains scores + String outputJson = + scoreSpan + .getAttributes() + .get( + io.opentelemetry.api.common.AttributeKey.stringKey( + "braintrust.output_json")); + assertNotNull(outputJson, "Score span should have output_json"); + JsonNode output = JSON_MAPPER.readTree(outputJson); + assertTrue(output.has("simple_scorer"), "Output should contain scorer results"); + assertEquals(0.7, output.get("simple_scorer").asDouble(), 0.001); + } - for (SpanData customSpan : customSpansList) { + for (SpanData customSpan : customSpans) { // Verify it has a parent span (is not a root span) assertTrue( customSpan.getParentSpanContext().isValid(), @@ -412,4 +486,115 @@ void testStreamingEval() throws Exception { "Custom span parent should contain playground_id from baggage propagation"); } } + + @Test + void testEvaluatorNotFound() throws Exception { + EvalRequest request = new EvalRequest(); + request.setName("non-existent-eval"); + + EvalRequest.DataSpec dataSpec = new EvalRequest.DataSpec(); + EvalRequest.EvalCaseData case1 = new EvalRequest.EvalCaseData(); + case1.setInput("test"); + dataSpec.setData(List.of(case1)); + request.setData(dataSpec); + + String requestJson = JSON_MAPPER.writeValueAsString(request); + + HttpClient client = HttpClient.newHttpClient(); + HttpRequest httpRequest = + HttpRequest.newBuilder() + .uri(URI.create(TEST_URL + "/eval")) + .POST(HttpRequest.BodyPublishers.ofString(requestJson)) + .header("Content-Type", "application/json") + .header("x-bt-auth-token", "test-token") + .header("x-bt-project-id", "test-project-id") + .header("x-bt-org-name", "test-org") + .build(); + + HttpResponse response = + client.send(httpRequest, HttpResponse.BodyHandlers.ofString()); + + assertEquals(404, response.statusCode()); + assertTrue(response.body().contains("Evaluator not found")); + } + + @Test + void testEvalMethodNotAllowed() throws Exception { + HttpClient client = HttpClient.newHttpClient(); + HttpRequest request = + HttpRequest.newBuilder().uri(URI.create(TEST_URL + "/eval")).GET().build(); + + HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); + + assertEquals(405, response.statusCode()); + } + + @Test + void testListEndpoint() throws Exception { + HttpClient client = HttpClient.newHttpClient(); + HttpRequest request = + HttpRequest.newBuilder() + .uri(URI.create(TEST_URL + "/list")) + .GET() + .header("x-bt-auth-token", "test-token") + .header("x-bt-project-id", "test-project-id") + .header("x-bt-org-name", "test-org") + .build(); + + HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); + + assertEquals(200, response.statusCode()); + assertEquals( + "application/json", response.headers().firstValue("Content-Type").orElse(null)); + + // Parse and validate JSON response + JsonNode root = JSON_MAPPER.readTree(response.body()); + + // Should have one evaluator + assertTrue(root.has("food-type-classifier")); + + JsonNode eval = root.get("food-type-classifier"); + + // Check scores + assertTrue(eval.has("scores")); + JsonNode scores = eval.get("scores"); + assertEquals(1, scores.size()); + + assertEquals("simple_scorer", scores.get(0).get("name").asText()); + } + + @Test + void testListMethodNotAllowed() throws Exception { + HttpClient client = HttpClient.newHttpClient(); + HttpRequest request = + HttpRequest.newBuilder() + .uri(URI.create(TEST_URL + "/list")) + .POST(HttpRequest.BodyPublishers.noBody()) + .build(); + + HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); + + assertEquals(405, response.statusCode()); + } + + @Test + void testListEndpointWithCors() throws Exception { + HttpClient client = HttpClient.newHttpClient(); + HttpRequest request = + HttpRequest.newBuilder() + .uri(URI.create(TEST_URL + "/list")) + .GET() + .header("Origin", "https://www.braintrust.dev") + .header("x-bt-auth-token", "test-token") + .header("x-bt-project-id", "test-project-id") + .header("x-bt-org-name", "test-org") + .build(); + + HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); + + assertEquals(200, response.statusCode()); + assertEquals( + "https://www.braintrust.dev", + response.headers().firstValue("Access-Control-Allow-Origin").orElse(null)); + } } diff --git a/src/test/java/dev/braintrust/devserver/EvalEndpointTest.java b/src/test/java/dev/braintrust/devserver/EvalEndpointTest.java deleted file mode 100644 index 6004dc5..0000000 --- a/src/test/java/dev/braintrust/devserver/EvalEndpointTest.java +++ /dev/null @@ -1,178 +0,0 @@ -package dev.braintrust.devserver; - -import static org.junit.jupiter.api.Assertions.*; - -import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.ObjectMapper; -import dev.braintrust.TestUtils; -import dev.braintrust.config.BraintrustConfig; -import dev.braintrust.eval.Scorer; -import java.net.URI; -import java.net.http.HttpClient; -import java.net.http.HttpRequest; -import java.net.http.HttpResponse; -import java.util.List; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Disabled; -import org.junit.jupiter.api.Test; - -class EvalEndpointTest { - private static Devserver server; - private static Thread serverThread; - private static final int TEST_PORT = TestUtils.getRandomOpenPort(); - private static final ObjectMapper mapper = new ObjectMapper(); - - @BeforeAll - static void setup() throws Exception { - // Create a test eval - RemoteEval testEval = - RemoteEval.builder() - .name("uppercase-eval") - .taskFunction(String::toUpperCase) - .scorer( - Scorer.of( - "length", - (expected, result) -> - Math.min((double) result.length() / 10.0, 1.0))) - .scorer( - Scorer.of( - "has_hello", - (expected, result) -> result.contains("HELLO") ? 1.0 : 0.0)) - .build(); - - server = - Devserver.builder() - .config(BraintrustConfig.of("BRAINTRUST_API_KEY", "bogus")) - .registerEval(testEval) - .host("localhost") - .port(TEST_PORT) - .build(); - - serverThread = - new Thread( - () -> { - try { - server.start(); - } catch (Exception e) { - e.printStackTrace(); - } - }); - serverThread.start(); - Thread.sleep(1000); // Give server time to start - } - - @AfterAll - static void teardown() { - server.stop(); - serverThread.interrupt(); - } - - @Test - @Disabled - void testEvalEndpointWithInlineData() throws Exception { - // Build request - EvalRequest request = new EvalRequest(); - request.setName("uppercase-eval"); - - EvalRequest.DataSpec dataSpec = new EvalRequest.DataSpec(); - - EvalRequest.EvalCaseData case1 = new EvalRequest.EvalCaseData(); - case1.setInput("hello"); - case1.setExpected("HELLO"); - - EvalRequest.EvalCaseData case2 = new EvalRequest.EvalCaseData(); - case2.setInput("world"); - case2.setExpected("WORLD"); - - dataSpec.setData(List.of(case1, case2)); - request.setData(dataSpec); - - String requestJson = mapper.writeValueAsString(request); - - HttpClient client = HttpClient.newHttpClient(); - HttpRequest httpRequest = - HttpRequest.newBuilder() - .uri(URI.create("http://localhost:" + TEST_PORT + "/eval")) - .POST(HttpRequest.BodyPublishers.ofString(requestJson)) - .header("Content-Type", "application/json") - .build(); - - HttpResponse response = - client.send(httpRequest, HttpResponse.BodyHandlers.ofString()); - - assertEquals(200, response.statusCode()); - assertEquals( - "application/json", response.headers().firstValue("Content-Type").orElse(null)); - - // Parse and validate JSON response - JsonNode root = mapper.readTree(response.body()); - - assertTrue(root.has("experimentName")); - assertTrue(root.has("projectName")); - assertTrue(root.has("projectId")); - assertTrue(root.has("experimentId")); - assertTrue(root.has("experimentUrl")); - assertTrue(root.has("projectUrl")); - assertTrue(root.has("scores")); - - JsonNode scores = root.get("scores"); - assertTrue(scores.has("length")); - assertTrue(scores.has("has_hello")); - - // Check length scorer (average of 5/10 and 5/10 = 0.5) - JsonNode lengthScore = scores.get("length"); - assertEquals("length", lengthScore.get("name").asText()); - assertEquals(0.5, lengthScore.get("score").asDouble(), 0.01); - - // Check has_hello scorer (1.0 for "hello", 0.0 for "world" = 0.5 average) - JsonNode helloScore = scores.get("has_hello"); - assertEquals("has_hello", helloScore.get("name").asText()); - assertEquals(0.5, helloScore.get("score").asDouble(), 0.01); - } - - @Test - void testEvalEndpointEvaluatorNotFound() throws Exception { - EvalRequest request = new EvalRequest(); - request.setName("non-existent-eval"); - - EvalRequest.DataSpec dataSpec = new EvalRequest.DataSpec(); - EvalRequest.EvalCaseData case1 = new EvalRequest.EvalCaseData(); - case1.setInput("test"); - dataSpec.setData(List.of(case1)); - request.setData(dataSpec); - - String requestJson = mapper.writeValueAsString(request); - - HttpClient client = HttpClient.newHttpClient(); - HttpRequest httpRequest = - HttpRequest.newBuilder() - .uri(URI.create("http://localhost:" + TEST_PORT + "/eval")) - .POST(HttpRequest.BodyPublishers.ofString(requestJson)) - .header("Content-Type", "application/json") - .header("x-bt-auth-token", "test-token") - .header("x-bt-project-id", "test-project-id") - .header("x-bt-org-name", "test-org") - .build(); - - HttpResponse response = - client.send(httpRequest, HttpResponse.BodyHandlers.ofString()); - - assertEquals(404, response.statusCode()); - assertTrue(response.body().contains("Evaluator not found")); - } - - @Test - void testEvalEndpointMethodNotAllowed() throws Exception { - HttpClient client = HttpClient.newHttpClient(); - HttpRequest request = - HttpRequest.newBuilder() - .uri(URI.create("http://localhost:" + TEST_PORT + "/eval")) - .GET() - .build(); - - HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); - - assertEquals(405, response.statusCode()); - } -} diff --git a/src/test/java/dev/braintrust/devserver/ListEndpointTest.java b/src/test/java/dev/braintrust/devserver/ListEndpointTest.java deleted file mode 100644 index 71150b6..0000000 --- a/src/test/java/dev/braintrust/devserver/ListEndpointTest.java +++ /dev/null @@ -1,166 +0,0 @@ -package dev.braintrust.devserver; - -import static org.junit.jupiter.api.Assertions.*; - -import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.ObjectMapper; -import dev.braintrust.config.BraintrustConfig; -import dev.braintrust.eval.Scorer; -import java.net.URI; -import java.net.http.HttpClient; -import java.net.http.HttpRequest; -import java.net.http.HttpResponse; -import java.util.Map; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; - -class ListEndpointTest { - private static Devserver server; - private static Thread serverThread; - private static final int TEST_PORT = 18302; - private static final ObjectMapper mapper = new ObjectMapper(); - - @BeforeAll - static void setup() throws Exception { - // Create a test eval with parameters - RemoteEval testEval = - RemoteEval.builder() - .name("test-classifier") - .taskFunction(input -> input.toUpperCase()) - .scorer(Scorer.of("accuracy", result -> 0.95)) - .scorer( - Scorer.of( - "length", - (expected, result) -> (double) result.length() / 10.0)) - .parameter( - "model", - RemoteEval.Parameter.dataParameter( - "The model to use", - Map.of( - "type", - "string", - "enum", - new String[] {"gpt-4", "gpt-3.5"}), - "gpt-4")) - .parameter( - "temperature", - RemoteEval.Parameter.dataParameter( - "Temperature for sampling", - Map.of("type", "number", "minimum", 0.0, "maximum", 2.0), - 0.7)) - .build(); - - server = - Devserver.builder() - .config(BraintrustConfig.of("BRAINTRUST_API_KEY", "bogus")) - .registerEval(testEval) - .host("localhost") - .port(TEST_PORT) - .build(); - - serverThread = - new Thread( - () -> { - try { - server.start(); - } catch (Exception e) { - e.printStackTrace(); - } - }); - serverThread.start(); - Thread.sleep(1000); // Give server time to start - } - - @AfterAll - static void teardown() { - server.stop(); - serverThread.interrupt(); - } - - @Test - void testListEndpoint() throws Exception { - HttpClient client = HttpClient.newHttpClient(); - HttpRequest request = - HttpRequest.newBuilder() - .uri(URI.create("http://localhost:" + TEST_PORT + "/list")) - .GET() - .header("x-bt-auth-token", "test-token") - .header("x-bt-project-id", "test-project-id") - .header("x-bt-org-name", "test-org") - .build(); - - HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); - - assertEquals(200, response.statusCode()); - assertEquals( - "application/json", response.headers().firstValue("Content-Type").orElse(null)); - - // Parse and validate JSON response - JsonNode root = mapper.readTree(response.body()); - - // Should have one evaluator - assertTrue(root.has("test-classifier")); - - JsonNode eval = root.get("test-classifier"); - - // Check parameters - assertTrue(eval.has("parameters")); - JsonNode parameters = eval.get("parameters"); - - assertTrue(parameters.has("model")); - JsonNode modelParam = parameters.get("model"); - assertEquals("data", modelParam.get("type").asText()); - assertEquals("The model to use", modelParam.get("description").asText()); - assertEquals("gpt-4", modelParam.get("default").asText()); - assertTrue(modelParam.has("schema")); - - assertTrue(parameters.has("temperature")); - JsonNode tempParam = parameters.get("temperature"); - assertEquals("data", tempParam.get("type").asText()); - assertEquals(0.7, tempParam.get("default").asDouble()); - - // Check scores - assertTrue(eval.has("scores")); - JsonNode scores = eval.get("scores"); - assertEquals(2, scores.size()); - - assertEquals("accuracy", scores.get(0).get("name").asText()); - assertEquals("length", scores.get(1).get("name").asText()); - } - - @Test - void testListEndpointMethodNotAllowed() throws Exception { - HttpClient client = HttpClient.newHttpClient(); - HttpRequest request = - HttpRequest.newBuilder() - .uri(URI.create("http://localhost:" + TEST_PORT + "/list")) - .POST(HttpRequest.BodyPublishers.noBody()) - .build(); - - HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); - - assertEquals(405, response.statusCode()); - } - - @Test - void testListEndpointWithCors() throws Exception { - HttpClient client = HttpClient.newHttpClient(); - HttpRequest request = - HttpRequest.newBuilder() - .uri(URI.create("http://localhost:" + TEST_PORT + "/list")) - .GET() - .header("Origin", "https://www.braintrust.dev") - .header("x-bt-auth-token", "test-token") - .header("x-bt-project-id", "test-project-id") - .header("x-bt-org-name", "test-org") - .build(); - - HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); - - assertEquals(200, response.statusCode()); - assertEquals( - "https://www.braintrust.dev", - response.headers().firstValue("Access-Control-Allow-Origin").orElse(null)); - } -} From 29e59d32cb3c673ed27e498376d3345a3c5b9642 Mon Sep 17 00:00:00 2001 From: Andrew Kent Date: Fri, 19 Dec 2025 14:23:59 -0700 Subject: [PATCH 4/7] LRU cache tests --- .../dev/braintrust/devserver/LRUCache.java | 2 +- .../braintrust/devserver/LRUCacheTest.java | 162 ++++++++++++++++++ 2 files changed, 163 insertions(+), 1 deletion(-) create mode 100644 src/test/java/dev/braintrust/devserver/LRUCacheTest.java diff --git a/src/main/java/dev/braintrust/devserver/LRUCache.java b/src/main/java/dev/braintrust/devserver/LRUCache.java index 9a5a6d2..f0c9362 100644 --- a/src/main/java/dev/braintrust/devserver/LRUCache.java +++ b/src/main/java/dev/braintrust/devserver/LRUCache.java @@ -16,7 +16,7 @@ * @param Value type */ @ThreadSafe -public class LRUCache { +class LRUCache { private final int maxSize; private final Map cache; diff --git a/src/test/java/dev/braintrust/devserver/LRUCacheTest.java b/src/test/java/dev/braintrust/devserver/LRUCacheTest.java new file mode 100644 index 0000000..56b41d0 --- /dev/null +++ b/src/test/java/dev/braintrust/devserver/LRUCacheTest.java @@ -0,0 +1,162 @@ +package dev.braintrust.devserver; + +import static org.junit.jupiter.api.Assertions.*; + +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.jupiter.api.Test; + +class LRUCacheTest { + + @Test + void testBasicPutAndGet() { + LRUCache cache = new LRUCache<>(3); + + cache.put("key1", "value1"); + cache.put("key2", "value2"); + + assertEquals("value1", cache.get("key1")); + assertEquals("value2", cache.get("key2")); + assertNull(cache.get("key3")); + } + + @Test + void testLruEviction() { + LRUCache cache = new LRUCache<>(2); + + cache.put("key1", "value1"); + cache.put("key2", "value2"); + assertEquals(2, cache.size()); + + // Adding third item should evict the least recently used (key1) + cache.put("key3", "value3"); + assertEquals(2, cache.size()); + assertNull(cache.get("key1"), "key1 should have been evicted"); + assertEquals("value2", cache.get("key2")); + assertEquals("value3", cache.get("key3")); + } + + @Test + void testLruEvictionWithAccess() { + LRUCache cache = new LRUCache<>(2); + + cache.put("key1", "value1"); + cache.put("key2", "value2"); + + // Access key1 to make it more recently used + cache.get("key1"); + + // Adding third item should now evict key2 (least recently used) + cache.put("key3", "value3"); + assertEquals(2, cache.size()); + assertEquals("value1", cache.get("key1"), "key1 should still be present"); + assertNull(cache.get("key2"), "key2 should have been evicted"); + assertEquals("value3", cache.get("key3")); + } + + @Test + void testGetOrComputeCacheHit() { + LRUCache cache = new LRUCache<>(3); + AtomicInteger computeCount = new AtomicInteger(0); + + // First call should compute and cache + String result1 = + cache.getOrCompute( + "key1", + () -> { + computeCount.incrementAndGet(); + return "computed-value"; + }); + assertEquals("computed-value", result1); + assertEquals(1, computeCount.get()); + + // Second call should return cached value without computing + String result2 = + cache.getOrCompute( + "key1", + () -> { + computeCount.incrementAndGet(); + return "should-not-be-called"; + }); + assertEquals("computed-value", result2); + assertEquals(1, computeCount.get(), "Supplier should not have been called on cache hit"); + } + + @Test + void testGetOrComputeCacheMiss() { + LRUCache cache = new LRUCache<>(3); + AtomicInteger computeCount = new AtomicInteger(0); + + String result = + cache.getOrCompute( + "key1", + () -> { + computeCount.incrementAndGet(); + return "value-" + computeCount.get(); + }); + + assertEquals("value-1", result); + assertEquals(1, computeCount.get()); + assertTrue(cache.containsKey("key1")); + assertEquals("value-1", cache.get("key1")); + } + + @Test + void testContainsKey() { + LRUCache cache = new LRUCache<>(3); + + assertFalse(cache.containsKey("key1")); + + cache.put("key1", "value1"); + assertTrue(cache.containsKey("key1")); + assertFalse(cache.containsKey("key2")); + } + + @Test + void testClear() { + LRUCache cache = new LRUCache<>(3); + + cache.put("key1", "value1"); + cache.put("key2", "value2"); + assertEquals(2, cache.size()); + + cache.clear(); + assertEquals(0, cache.size()); + assertFalse(cache.containsKey("key1")); + assertFalse(cache.containsKey("key2")); + } + + @Test + void testSize() { + LRUCache cache = new LRUCache<>(5); + + assertEquals(0, cache.size()); + + cache.put("key1", "value1"); + assertEquals(1, cache.size()); + + cache.put("key2", "value2"); + cache.put("key3", "value3"); + assertEquals(3, cache.size()); + + cache.clear(); + assertEquals(0, cache.size()); + } + + @Test + void testMaxSizeEnforcement() { + LRUCache cache = new LRUCache<>(3); + + cache.put(1, "one"); + cache.put(2, "two"); + cache.put(3, "three"); + assertEquals(3, cache.size()); + + // Adding more items should maintain max size + cache.put(4, "four"); + assertEquals(3, cache.size(), "Cache should not exceed max size"); + + cache.put(5, "five"); + cache.put(6, "six"); + assertEquals(3, cache.size(), "Cache should not exceed max size"); + } +} From be41f03799aba677a06a25e58bfdaf99f4c5e007 Mon Sep 17 00:00:00 2001 From: Andrew Kent Date: Mon, 22 Dec 2025 13:08:47 -0700 Subject: [PATCH 5/7] ensure login works before populating devserver auth cache --- .../braintrust/api/BraintrustApiClient.java | 19 +++++++++++++-- .../dev/braintrust/api/LoginException.java | 23 +++++++++++++++++++ .../dev/braintrust/devserver/Devserver.java | 12 +++++----- 3 files changed, 46 insertions(+), 8 deletions(-) create mode 100644 src/main/java/dev/braintrust/api/LoginException.java diff --git a/src/main/java/dev/braintrust/api/BraintrustApiClient.java b/src/main/java/dev/braintrust/api/BraintrustApiClient.java index a8f8965..20dc278 100644 --- a/src/main/java/dev/braintrust/api/BraintrustApiClient.java +++ b/src/main/java/dev/braintrust/api/BraintrustApiClient.java @@ -26,6 +26,14 @@ * {@link dev.braintrust.eval.Eval} or {@link dev.braintrust.trace.BraintrustTracing} */ public interface BraintrustApiClient { + /** + * Attempt Braintrust login + * + * @return LoginResponse containing organization info + * @throws LoginException if login fails due to invalid credentials or network errors + */ + LoginResponse login() throws LoginException; + /** Creates or gets a project by name. */ Project getOrCreateProject(String projectName); @@ -117,7 +125,8 @@ public Experiment getOrCreateExperiment(CreateExperimentRequest request) { } } - private LoginResponse login() { + @Override + public LoginResponse login() throws LoginException { try { return postAsync( "/api/apikey/login", @@ -125,7 +134,7 @@ private LoginResponse login() { LoginResponse.class) .get(); } catch (InterruptedException | ExecutionException e) { - throw new RuntimeException(e); + throw new LoginException("Failed to login to Braintrust", e); } } @@ -403,6 +412,12 @@ public InMemoryImpl( this.prompts.addAll(prompts); } + @Override + public LoginResponse login() { + return new LoginResponse( + organizationAndProjectInfos.stream().map(o -> o.orgInfo).toList()); + } + @Override public Project getOrCreateProject(String projectName) { // Find existing project by name diff --git a/src/main/java/dev/braintrust/api/LoginException.java b/src/main/java/dev/braintrust/api/LoginException.java new file mode 100644 index 0000000..2d730b6 --- /dev/null +++ b/src/main/java/dev/braintrust/api/LoginException.java @@ -0,0 +1,23 @@ +package dev.braintrust.api; + +import javax.annotation.Nullable; + +/** + * Exception thrown when login to Braintrust fails. + * + *

This is a RuntimeException so it doesn't require explicit handling, but callers can catch it + * specifically if they want to handle login failures differently from other errors. + */ +public class LoginException extends RuntimeException { + public LoginException(String message) { + super(message); + } + + public LoginException(String message, @Nullable Throwable cause) { + super(message, cause); + } + + public LoginException(Throwable cause) { + super(cause); + } +} diff --git a/src/main/java/dev/braintrust/devserver/Devserver.java b/src/main/java/dev/braintrust/devserver/Devserver.java index 13d128f..e4b4e55 100644 --- a/src/main/java/dev/braintrust/devserver/Devserver.java +++ b/src/main/java/dev/braintrust/devserver/Devserver.java @@ -98,7 +98,7 @@ public class Devserver { com.fasterxml.jackson.core.JsonParser.Feature .INCLUDE_SOURCE_IN_LOCATION); - // LRU cache for token -> Braintrust mappings (max 32 entries as per api.md) + // LRU cache for token -> Braintrust mappings private final LRUCache authCache = new LRUCache<>(32); private Devserver(Builder builder) { @@ -504,7 +504,8 @@ private void handleStreamingEval( } catch (IOException ioException) { log.error("Failed to send error event", ioException); } - throw e; + // no need to re-throw. We've already sent 200 because we're streaming and the + // client will see the error event } finally { try { os.flush(); @@ -865,9 +866,6 @@ private RequestContext getBraintrust(HttpExchange exchange, RequestContext conte authCache.getOrCompute( cacheKey, () -> { - // Cache miss - would validate token with Braintrust API here - // TODO: Implement actual token validation with - // loginToState(token, orgName) log.debug( "Cached login state for org='{}', projectId='{}' (cache" + " size={})", @@ -889,7 +887,9 @@ private RequestContext getBraintrust(HttpExchange exchange, RequestContext conte configBuilderHook.accept(configBuilder); } - return Braintrust.of(configBuilder.build()); + var bt = Braintrust.of(configBuilder.build()); + bt.apiClient().login(); + return bt; }); log.debug( From 5d440dd99030806f7b51b98872cd7ad63383c64c Mon Sep 17 00:00:00 2001 From: Andrew Kent Date: Mon, 22 Dec 2025 14:29:31 -0700 Subject: [PATCH 6/7] thread-safe dataset processing --- .../java/dev/braintrust/devserver/Devserver.java | 6 +++++- .../java/dev/braintrust/devserver/RemoteEval.java | 12 ++++++++++-- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/src/main/java/dev/braintrust/devserver/Devserver.java b/src/main/java/dev/braintrust/devserver/Devserver.java index e4b4e55..29d022f 100644 --- a/src/main/java/dev/braintrust/devserver/Devserver.java +++ b/src/main/java/dev/braintrust/devserver/Devserver.java @@ -376,6 +376,8 @@ private void handleStreamingEval( final var braintrustParent = parentInfo.braintrustParent(); final var braintrustGeneration = parentInfo.generation(); + // NOTE: this code is serial but written in a thread-safe manner to support + // concurrent dataset fetching and eval execution extractDataset(request, apiClient) .forEach( datasetCase -> { @@ -589,7 +591,9 @@ private void setScoreSpanAttributes( private void sendSSEEvent(OutputStream os, String eventType, String data) throws IOException { String event = "event: " + eventType + "\n" + "data: " + data + "\n\n"; - os.write(event.getBytes(StandardCharsets.UTF_8)); + synchronized (this) { + os.write(event.getBytes(StandardCharsets.UTF_8)); + } } private void sendProgressEvent( diff --git a/src/main/java/dev/braintrust/devserver/RemoteEval.java b/src/main/java/dev/braintrust/devserver/RemoteEval.java index dc8eef1..75d569b 100644 --- a/src/main/java/dev/braintrust/devserver/RemoteEval.java +++ b/src/main/java/dev/braintrust/devserver/RemoteEval.java @@ -22,10 +22,18 @@ public class RemoteEval { /** The name of this evaluator (used as identifier) */ @Nonnull private final String name; - /** The task function that performs the evaluation */ + /** + * The task function that performs the evaluation + * + *

The task function must be thread safe. + */ @Nonnull private final Task task; - /** List of scorers for this evaluator */ + /** + * List of scorers for this evaluator + * + *

The score function must be thread safe. + */ @Singular @Nonnull private final List> scorers; /** Optional parameters that can be configured from the UI */ From c515ef006481e8e1d850f7f47391e7e8722da8bb Mon Sep 17 00:00:00 2001 From: Andrew Kent Date: Mon, 22 Dec 2025 14:33:12 -0700 Subject: [PATCH 7/7] rm dead code --- .../dev/braintrust/devserver/Devserver.java | 33 ------------------- 1 file changed, 33 deletions(-) diff --git a/src/main/java/dev/braintrust/devserver/Devserver.java b/src/main/java/dev/braintrust/devserver/Devserver.java index 29d022f..c3b2071 100644 --- a/src/main/java/dev/braintrust/devserver/Devserver.java +++ b/src/main/java/dev/braintrust/devserver/Devserver.java @@ -12,19 +12,10 @@ import dev.braintrust.eval.*; import dev.braintrust.trace.BraintrustContext; import dev.braintrust.trace.BraintrustTracing; -import io.opentelemetry.api.OpenTelemetry; -import io.opentelemetry.api.baggage.propagation.W3CBaggagePropagator; import io.opentelemetry.api.common.AttributeKey; import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.SpanKind; -import io.opentelemetry.api.trace.propagation.W3CTraceContextPropagator; import io.opentelemetry.context.Context; -import io.opentelemetry.context.propagation.ContextPropagators; -import io.opentelemetry.context.propagation.TextMapPropagator; -import io.opentelemetry.sdk.OpenTelemetrySdk; -import io.opentelemetry.sdk.logs.SdkLoggerProvider; -import io.opentelemetry.sdk.metrics.SdkMeterProvider; -import io.opentelemetry.sdk.trace.SdkTracerProvider; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; @@ -910,30 +901,6 @@ private RequestContext getBraintrust(HttpExchange exchange, RequestContext conte .build(); } - private OpenTelemetry createOpenTelemetry(Braintrust braintrust) { - var tracerBuilder = SdkTracerProvider.builder(); - var loggerBuilder = SdkLoggerProvider.builder(); - var meterBuilder = SdkMeterProvider.builder(); - var contextPropagator = - ContextPropagators.create( - TextMapPropagator.composite( - W3CTraceContextPropagator.getInstance(), - W3CBaggagePropagator.getInstance())); - braintrust.openTelemetryEnable(tracerBuilder, loggerBuilder, meterBuilder); - - // Invoke hook if present to allow customization (e.g., adding span processors) - if (traceBuilderHook != null) { - traceBuilderHook.accept(tracerBuilder); - } - - return OpenTelemetrySdk.builder() - .setTracerProvider(tracerBuilder.build()) - .setLoggerProvider(loggerBuilder.build()) - .setMeterProvider(meterBuilder.build()) - .setPropagators(contextPropagator) - .build(); - } - /** * Send an error response with JSON body. *