diff --git a/examples/build.gradle b/examples/build.gradle index 2244681..d1e747c 100644 --- a/examples/build.gradle +++ b/examples/build.gradle @@ -141,3 +141,18 @@ task runSpringAI(type: JavaExec) { suspend = false } } + + +task runRemoteEval(type: JavaExec) { + group = 'Braintrust SDK Examples' + description = 'Run the remote eval example' + classpath = sourceSets.main.runtimeClasspath + mainClass = 'dev.braintrust.examples.RemoteEvalExample' + systemProperty 'org.slf4j.simpleLogger.log.dev.braintrust', braintrustLogLevel + debugOptions { + enabled = true + port = 5566 + server = true + suspend = false + } +} diff --git a/examples/src/main/java/dev/braintrust/examples/RemoteEvalExample.java b/examples/src/main/java/dev/braintrust/examples/RemoteEvalExample.java new file mode 100644 index 0000000..ad83ae3 --- /dev/null +++ b/examples/src/main/java/dev/braintrust/examples/RemoteEvalExample.java @@ -0,0 +1,76 @@ +package dev.braintrust.examples; + +import com.openai.client.okhttp.OpenAIOkHttpClient; +import com.openai.models.ChatModel; +import com.openai.models.chat.completions.ChatCompletionCreateParams; +import dev.braintrust.Braintrust; +import dev.braintrust.devserver.Devserver; +import dev.braintrust.devserver.RemoteEval; +import dev.braintrust.eval.Scorer; +import dev.braintrust.instrumentation.openai.BraintrustOpenAI; +import java.util.List; + +/** Simple Dev Server for Remote Evals */ +public class RemoteEvalExample { + public static void main(String[] args) throws Exception { + var braintrust = Braintrust.get(); + var openTelemetry = braintrust.openTelemetryCreate(); + var openAIClient = BraintrustOpenAI.wrapOpenAI(openTelemetry, OpenAIOkHttpClient.fromEnv()); + + RemoteEval foodTypeEval = + RemoteEval.builder() + .name("food-type-classifier") + .taskFunction( + food -> { + var request = + ChatCompletionCreateParams.builder() + .model(ChatModel.GPT_4O_MINI) + .addSystemMessage("Return a one word answer") + .addUserMessage( + "What kind of food is " + food 
+ "?") + .maxTokens(50L) + .temperature(0.0) + .build(); + var response = + openAIClient.chat().completions().create(request); + return response.choices() + .get(0) + .message() + .content() + .orElse("") + .toLowerCase(); + }) + .scorers( + List.of( + Scorer.of("static_scorer", (expected, result) -> 0.7), + Scorer.of( + "close_enough_match", + (expected, result) -> + expected.trim() + .equalsIgnoreCase( + result.trim()) + ? 1.0 + : 0.0))) + .build(); + + Devserver devserver = + Devserver.builder() + .config(braintrust.config()) + .registerEval(foodTypeEval) + .host("localhost") // set to 0.0.0.0 to bind all interfaces + .port(8301) + .build(); + + Runtime.getRuntime() + .addShutdownHook( + new Thread( + () -> { + System.out.println("Shutting down..."); + devserver.stop(); + System.out.flush(); + System.err.flush(); + })); + System.out.println("Starting Braintrust dev server"); + devserver.start(); + } +} diff --git a/src/main/java/dev/braintrust/BraintrustUtils.java b/src/main/java/dev/braintrust/BraintrustUtils.java index b85ff42..4e2fa4a 100644 --- a/src/main/java/dev/braintrust/BraintrustUtils.java +++ b/src/main/java/dev/braintrust/BraintrustUtils.java @@ -3,6 +3,9 @@ import dev.braintrust.api.BraintrustApiClient; import java.net.URI; import java.net.URISyntaxException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; import javax.annotation.Nonnull; public class BraintrustUtils { @@ -28,7 +31,7 @@ public static URI createProjectURI( } } - static Parent parseParent(@Nonnull String parentStr) { + public static Parent parseParent(@Nonnull String parentStr) { String[] parts = parentStr.split(":"); if (parts.length != 2) { throw new IllegalArgumentException("Invalid parent format: " + parentStr); @@ -42,4 +45,28 @@ public String toParentValue() { return type + ":" + id; } } + + /** + * Splits a comma-separated string into its trimmed, non-empty values. + * + * <p>The input is stripped before splitting because the separator pattern only consumes + * whitespace adjacent to a comma, not at either end of the string. + * + * @param csv comma-separated values; may be {@code null} or blank + * @return an unmodifiable list of values, empty when {@code csv} is {@code null} or blank + */ + public static List<String> parseCsv(String csv) { + if (csv == null || csv.isBlank()) { + return List.of(); + } + + return
Arrays.stream(csv.strip().split("\\s*,\\s*")).filter(v -> !v.isEmpty()).toList(); + } + + /** Returns a new list with the elements of {@code list} followed by {@code value}. */ + public static <T> List<T> append(List<T> list, T value) { + List<T> result = new ArrayList<>(list); + result.add(value); + return result; + } } diff --git a/src/main/java/dev/braintrust/api/BraintrustApiClient.java b/src/main/java/dev/braintrust/api/BraintrustApiClient.java index a8f8965..20dc278 100644 --- a/src/main/java/dev/braintrust/api/BraintrustApiClient.java +++ b/src/main/java/dev/braintrust/api/BraintrustApiClient.java @@ -26,6 +26,14 @@ * {@link dev.braintrust.eval.Eval} or {@link dev.braintrust.trace.BraintrustTracing} */ public interface BraintrustApiClient { + /** + * Attempt Braintrust login + * + * @return LoginResponse containing organization info + * @throws LoginException if login fails due to invalid credentials or network errors + */ + LoginResponse login() throws LoginException; + /** Creates or gets a project by name. */ Project getOrCreateProject(String projectName); @@ -117,7 +125,8 @@ public Experiment getOrCreateExperiment(CreateExperimentRequest request) { } } - private LoginResponse login() { + @Override + public LoginResponse login() throws LoginException { try { return postAsync( "/api/apikey/login", @@ -125,7 +134,7 @@ private LoginResponse login() { LoginResponse.class) .get(); } catch (InterruptedException | ExecutionException e) { - throw new RuntimeException(e); + throw new LoginException("Failed to login to Braintrust", e); } } @@ -403,6 +412,12 @@ public InMemoryImpl( this.prompts.addAll(prompts); } + @Override + public LoginResponse login() { + return new LoginResponse( + organizationAndProjectInfos.stream().map(o -> o.orgInfo).toList()); + } + @Override public Project getOrCreateProject(String projectName) { // Find existing project by name diff --git a/src/main/java/dev/braintrust/api/LoginException.java b/src/main/java/dev/braintrust/api/LoginException.java new file mode 100644 index 0000000..2d730b6 --- /dev/null +++ b/src/main/java/dev/braintrust/api/LoginException.java @@ -0,0
+1,23 @@ +package dev.braintrust.api; + +import javax.annotation.Nullable; + +/** + * Exception thrown when login to Braintrust fails. + * + * <p>
This is a RuntimeException so it doesn't require explicit handling, but callers can catch it + * specifically if they want to handle login failures differently from other errors. + */ +public class LoginException extends RuntimeException { + public LoginException(String message) { + super(message); + } + + public LoginException(String message, @Nullable Throwable cause) { + super(message, cause); + } + + public LoginException(Throwable cause) { + super(cause); + } +} diff --git a/src/main/java/dev/braintrust/config/BraintrustConfig.java b/src/main/java/dev/braintrust/config/BraintrustConfig.java index 928abe2..b7602ad 100644 --- a/src/main/java/dev/braintrust/config/BraintrustConfig.java +++ b/src/main/java/dev/braintrust/config/BraintrustConfig.java @@ -51,6 +51,12 @@ public final class BraintrustConfig extends BaseConfig { private final boolean exportSpansInMemoryForUnitTest = getConfig("BRAINTRUST_JAVA_EXPORT_SPANS_IN_MEMORY_FOR_UNIT_TEST", false); + /** CORS origins to allow when running remote eval devserver */ + private final String devserverCorsOriginWhitelistCsv = + getConfig( + "BRAINTRUST_DEVSERVER_CORS_ORIGIN_WHITELIST_CSV", + "https://www.braintrust.dev,https://www.braintrustdata.com,http://localhost:3000"); + public static BraintrustConfig fromEnvironment() { return of(); } @@ -192,8 +198,8 @@ Builder experimentalOtelLogs(boolean value) { return this; } - // hiding visibility. 
only used for testing - Builder exportSpansInMemoryForUnitTest(boolean value) { + // only used for testing + public Builder exportSpansInMemoryForUnitTest(boolean value) { envOverrides.put( "BRAINTRUST_JAVA_EXPORT_SPANS_IN_MEMORY_FOR_UNIT_TEST", String.valueOf(value)); return this; @@ -209,6 +215,11 @@ public Builder x509TrustManager(X509TrustManager value) { return this; } + public Builder devserverCorsOriginWhitelistCsv(String csv) { + envOverrides.put("BRAINTRUST_DEVSERVER_CORS_ORIGIN_WHITELIST_CSV", csv); + return this; + } + public BraintrustConfig build() { return new BraintrustConfig(envOverrides, sslContext, x509TrustManager); } diff --git a/src/main/java/dev/braintrust/devserver/Devserver.java b/src/main/java/dev/braintrust/devserver/Devserver.java new file mode 100644 index 0000000..c3b2071 --- /dev/null +++ b/src/main/java/dev/braintrust/devserver/Devserver.java @@ -0,0 +1,1100 @@ +package dev.braintrust.devserver; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.sun.net.httpserver.HttpExchange; +import com.sun.net.httpserver.HttpHandler; +import com.sun.net.httpserver.HttpServer; +import dev.braintrust.Braintrust; +import dev.braintrust.BraintrustUtils; +import dev.braintrust.Origin; +import dev.braintrust.api.BraintrustApiClient; +import dev.braintrust.config.BraintrustConfig; +import dev.braintrust.eval.*; +import dev.braintrust.trace.BraintrustContext; +import dev.braintrust.trace.BraintrustTracing; +import io.opentelemetry.api.common.AttributeKey; +import io.opentelemetry.api.trace.Span; +import io.opentelemetry.api.trace.SpanKind; +import io.opentelemetry.context.Context; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.InetSocketAddress; +import java.nio.charset.StandardCharsets; +import java.util.*; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.Executors; +import java.util.function.Consumer; +import java.util.regex.Pattern; +import 
javax.annotation.Nonnull; +import javax.annotation.Nullable; +import lombok.Getter; +import lombok.experimental.Accessors; +import lombok.extern.slf4j.Slf4j; + +/** Remote Eval Dev Server */ +@Slf4j +public class Devserver { + private static final Pattern PREVIEW_DOMAIN_PATTERN = + Pattern.compile("^https://[^/]+\\.preview\\.braintrust\\.dev$"); + + // Allowed headers for CORS + private static final String ALLOWED_HEADERS = + String.join( + ", ", + "Content-Type", + "X-Amz-Date", + "Authorization", + "X-Api-Key", + "X-Amz-Security-Token", + "X-Bt-Auth-Token", + "X-Bt-Parent", + "X-Bt-Org-Name", + "X-Bt-Project-Id", + "X-Bt-Stream-Fmt", + "X-Bt-Use-Cache", + "X-Stainless-Os", + "X-Stainless-Lang", + "X-Stainless-Package-Version", + "X-Stainless-Runtime", + "X-Stainless-Runtime-Version", + "X-Stainless-Arch"); + + private static final String EXPOSED_HEADERS = + "x-bt-cursor, x-bt-found-existing-experiment, x-bt-span-id, x-bt-span-export"; + + private static final AttributeKey PARENT = + AttributeKey.stringKey(BraintrustTracing.PARENT_KEY); + + private final List corsOriginWhitelist; + private final BraintrustConfig config; + + @Getter + @Accessors(fluent = true) + private final String host; + + @Getter + @Accessors(fluent = true) + private final int port; + + private final @Nullable String orgName; + private final Map> evals; + private @Nullable HttpServer server; + private final @Nullable Consumer + traceBuilderHook; + private final @Nullable Consumer configBuilderHook; + private static final ObjectMapper JSON_MAPPER = + new ObjectMapper() + .enable( + com.fasterxml.jackson.core.JsonParser.Feature + .INCLUDE_SOURCE_IN_LOCATION); + + // LRU cache for token -> Braintrust mappings + private final LRUCache authCache = new LRUCache<>(32); + + private Devserver(Builder builder) { + this.config = Objects.requireNonNull(builder.config); + this.host = builder.host; + this.port = builder.port; + this.orgName = builder.orgName; + this.traceBuilderHook = 
builder.traceBuilderHook; + this.configBuilderHook = builder.configBuilderHook; + Map> evalMap = new HashMap<>(); + for (RemoteEval eval : builder.evals) { + if (evalMap.containsKey(eval.getName())) { + throw new IllegalArgumentException("Duplicate evaluator name: " + eval.getName()); + } + evalMap.put(eval.getName(), eval); + } + this.evals = Collections.unmodifiableMap(evalMap); + if (orgName != null) { + throw new NotSupportedYetException("org name filtering"); + } + this.corsOriginWhitelist = + List.copyOf( + BraintrustUtils.append( + BraintrustUtils.parseCsv(config.devserverCorsOriginWhitelistCsv()), + config.appUrl())); + } + + public static Builder builder() { + return new Builder(); + } + + /** Start the dev server. Returns once the server is listening; it does not block. */ + public synchronized void start() throws IOException { + if (server != null) { + throw new IllegalStateException("Server is already running"); + } + + server = HttpServer.create(new InetSocketAddress(host, port), 0); + server.setExecutor(Executors.newCachedThreadPool()); + + server.createContext("/", withCors(this::handleHealthCheck)); + server.createContext("/list", withCors(this::handleList)); + server.createContext("/eval", withCors(this::handleEval)); + + server.start(); + log.info("Braintrust dev server started on http://{}:{}", host, port); + log.info("Registered {} evaluator(s): {}", evals.size(), evals.keySet()); + } + + /** Stop the dev server.
*/ + public synchronized void stop() { + if (server != null) { + server.stop(0); + server = null; + log.info("Braintrust dev server stopped"); + } + } + + private void handleHealthCheck(HttpExchange exchange) throws IOException { + if (!"GET".equals(exchange.getRequestMethod())) { + sendResponse(exchange, 405, "text/plain", "Method Not Allowed"); + return; + } + sendResponse(exchange, 200, "text/plain", "Hello, world!"); + } + + private void handleList(HttpExchange exchange) throws IOException { + if (!"GET".equals(exchange.getRequestMethod())) { + sendResponse(exchange, 405, "text/plain", "Method Not Allowed"); + return; + } + + // Check API key is present + RequestContext context = createRequestContext(exchange); + String apiKey = extractApiKey(exchange, context); + if (apiKey == null) { + sendErrorResponse(exchange, 401, "Missing authentication token"); + return; + } + + try { + // Build the response: Map + Map> response = new LinkedHashMap<>(); + + for (Map.Entry> entry : evals.entrySet()) { + String evalName = entry.getKey(); + RemoteEval eval = entry.getValue(); + + Map metadata = new LinkedHashMap<>(); + + Map> parametersMap = new LinkedHashMap<>(); + for (Map.Entry paramEntry : + eval.getParameters().entrySet()) { + String paramName = paramEntry.getKey(); + RemoteEval.Parameter param = paramEntry.getValue(); + + Map paramMetadata = new LinkedHashMap<>(); + paramMetadata.put("type", param.getType().getValue()); + + if (param.getDescription() != null) { + paramMetadata.put("description", param.getDescription()); + } + + if (param.getDefaultValue() != null) { + paramMetadata.put("default", param.getDefaultValue()); + } + + // Only include schema for data type parameters + if (param.getType() == RemoteEval.ParameterType.DATA + && param.getSchema() != null) { + paramMetadata.put("schema", param.getSchema()); + } + + parametersMap.put(paramName, paramMetadata); + } + metadata.put("parameters", parametersMap); + + // Add scores (list of scorer names) + List> 
scores = new ArrayList<>(); + for (var scorer : eval.getScorers()) { + Map scoreInfo = new LinkedHashMap<>(); + scoreInfo.put("name", scorer.getName()); + scores.add(scoreInfo); + } + metadata.put("scores", scores); + + response.put(evalName, metadata); + } + + String jsonResponse = JSON_MAPPER.writeValueAsString(response); + sendResponse(exchange, 200, "application/json", jsonResponse); + } catch (Exception e) { + log.error("Error generating /list response", e); + sendResponse(exchange, 500, "text/plain", "Internal Server Error"); + } + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + private void handleEval(HttpExchange exchange) throws IOException { + if (!"POST".equals(exchange.getRequestMethod())) { + sendResponse(exchange, 405, "text/plain", "Method Not Allowed"); + return; + } + + // Check authorization and get Braintrust state + RequestContext context = createRequestContext(exchange); + context = getBraintrust(exchange, context); + if (context == null) { + sendErrorResponse(exchange, 401, "Missing required authentication headers"); + return; + } + + try { + InputStream requestBody = exchange.getRequestBody(); + var requestBodyString = new String(requestBody.readAllBytes(), StandardCharsets.UTF_8); + EvalRequest request = JSON_MAPPER.readValue(requestBodyString, EvalRequest.class); + + // Validate evaluator exists + RemoteEval eval = evals.get(request.getName()); + if (eval == null) { + sendResponse( + exchange, 404, "text/plain", "Evaluator not found: " + request.getName()); + return; + } + + // Validate dataset specification + if (request.getData() == null) { + sendResponse(exchange, 400, "text/plain", "Missing 'data' field in request body"); + return; + } + + EvalRequest.DataSpec dataSpec = request.getData(); + boolean hasInlineData = dataSpec.getData() != null && !dataSpec.getData().isEmpty(); + boolean hasByName = + dataSpec.getProjectName() != null && dataSpec.getDatasetName() != null; + boolean hasById = dataSpec.getDatasetId() != null; + + // 
Ensure exactly one dataset specification method is provided + int specCount = (hasInlineData ? 1 : 0) + (hasByName ? 1 : 0) + (hasById ? 1 : 0); + if (specCount == 0) { + sendResponse( + exchange, + 400, + "text/plain", + "Dataset must be specified using one of: inline data (data.data), by name" + + " (data.project_name + data.dataset_name), or by ID" + + " (data.dataset_id)"); + return; + } + if (specCount > 1) { + sendResponse( + exchange, + 400, + "text/plain", + "Only one dataset specification method should be provided"); + return; + } + + // TODO: support remote scorers + + String datasetDescription = + hasInlineData + ? dataSpec.getData().size() + " inline cases" + : (hasByName + ? "dataset '" + + dataSpec.getProjectName() + + "/" + + dataSpec.getDatasetName() + + "'" + : "dataset ID '" + dataSpec.getDatasetId() + "'"); + log.debug("Executing evaluator '{}' with {}", request.getName(), datasetDescription); + + // Check if streaming is requested + boolean isStreaming = request.getStream() != null && request.getStream(); + + if (isStreaming) { + // SSE streaming response - errors handled inside + log.debug("Starting streaming evaluation for '{}'", request.getName()); + handleStreamingEval(exchange, eval, request, context); + } else { + throw new NotSupportedYetException("non-streaming responses"); + } + } catch (NotSupportedYetException e) { + sendResponse( + exchange, 400, "text/plain", "TODO: feature not supported: " + e.description); + } catch (Exception e) { + log.error("Error executing eval", e); + // Only send error response if we haven't started streaming + // (streaming errors are handled within handleStreamingEval) + sendResponse(exchange, 500, "text/plain", "Internal Server Error: " + e.getMessage()); + } + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + private void handleStreamingEval( + HttpExchange exchange, RemoteEval eval, EvalRequest request, RequestContext context) + throws Exception { + // Set SSE headers + 
exchange.getResponseHeaders().set("Content-Type", "text/event-stream"); + exchange.getResponseHeaders().set("Cache-Control", "no-cache"); + exchange.getResponseHeaders().set("Connection", "keep-alive"); + exchange.sendResponseHeaders(200, 0); // 0 = chunked encoding + + try (OutputStream os = exchange.getResponseBody()) { + try { + // Get Braintrust instance from authenticated context + Braintrust braintrust = context.getBraintrust(); + BraintrustApiClient apiClient = braintrust.apiClient(); + + // Determine project name and ID from the authenticated Braintrust instance + final var orgAndProject = + apiClient.getOrCreateProjectAndOrgInfo(braintrust.config()); + final var projectName = orgAndProject.project().name(); + final var projectId = orgAndProject.project().id(); + final var experimentName = + request.getExperimentName() != null + ? request.getExperimentName() + : eval.getName(); + final var experimentUrl = + BraintrustUtils.createProjectURI( + braintrust.config().appUrl(), orgAndProject) + .toASCIIString() + + "/experiments/" + + experimentName; + final var projectUrl = + BraintrustUtils.createProjectURI( + braintrust.config().appUrl(), orgAndProject) + .toASCIIString(); + + var tracer = BraintrustTracing.getTracer(); + + // Execute task and scorers for each case + final Map> scoresByName = new ConcurrentHashMap<>(); + final var parentInfo = extractParentInfo(request); + final var braintrustParent = parentInfo.braintrustParent(); + final var braintrustGeneration = parentInfo.generation(); + + // NOTE: this code is serial but written in a thread-safe manner to support + // concurrent dataset fetching and eval execution + extractDataset(request, apiClient) + .forEach( + datasetCase -> { + var evalSpan = + tracer.spanBuilder("eval") + .setNoParent() + .setSpanKind(SpanKind.CLIENT) + .setAttribute( + PARENT, + braintrustParent.toParentValue()) + .startSpan(); + Context evalContext = Context.current().with(evalSpan); + evalContext = + 
BraintrustContext.setParentInBaggage( + evalContext, + braintrustParent.type(), + braintrustParent.id()); + // Make the eval context (with span and baggage) current + try (var rootScope = evalContext.makeCurrent()) { + final dev.braintrust.eval.TaskResult taskResult; + { // run task + var taskSpan = tracer.spanBuilder("task").startSpan(); + try (var unused = + Context.current() + .with(taskSpan) + .makeCurrent()) { + var task = eval.getTask(); + taskResult = task.apply(datasetCase); + // Send progress event for task completion + sendProgressEvent( + os, + evalSpan.getSpanContext().getSpanId(), + datasetCase.origin(), + eval.getName(), + taskResult.result()); + setTaskSpanAttributes( + taskSpan, + braintrustParent, + braintrustGeneration, + datasetCase, + taskResult); + } finally { + taskSpan.end(); + } + // setting eval span attributes here because we need the + // task output + setEvalSpanAttributes( + evalSpan, + braintrustParent, + braintrustGeneration, + datasetCase, + taskResult); + } + // run scorers - one score span per scorer + for (var scorer : (List>) eval.getScorers()) { + var scoreSpan = tracer.spanBuilder("score").startSpan(); + try (var unused = + Context.current() + .with(scoreSpan) + .makeCurrent()) { + List scores = scorer.score(taskResult); + + Map scorerScores = + new LinkedHashMap<>(); + for (Score score : scores) { + scoresByName + .computeIfAbsent( + score.name(), + k -> new ArrayList<>()) + .add(score.value()); + scorerScores.put(score.name(), score.value()); + } + // Set score span attributes before ending span + setScoreSpanAttributes( + scoreSpan, + braintrustParent, + braintrustGeneration, + scorer.getName(), + scorerScores); + } finally { + scoreSpan.end(); + } + } + } catch (IOException e) { + throw new RuntimeException( + "Failed to send progress event", e); + } finally { + evalSpan.end(); + } + }); + + // Aggregate scores + Map scoreSummaries = new LinkedHashMap<>(); + for (Map.Entry> entry : scoresByName.entrySet()) { + String 
scoreName = entry.getKey(); + List values = entry.getValue(); + + double avgScore = + values.stream().mapToDouble(Double::doubleValue).average().orElse(0.0); + + scoreSummaries.put( + scoreName, + EvalResponse.ScoreSummary.builder() + .name(scoreName) + .score(avgScore) + .improvements(0) + .regressions(0) + .build()); + } + + sendSummaryEvent( + os, + projectName, + projectId, + experimentName, + projectUrl, + experimentUrl, + scoreSummaries); + sendDoneEvent(os); + } catch (Exception e) { + // Send error event via SSE + log.error("Error during streaming evaluation", e); + try { + sendSSEEvent( + os, "error", e.getMessage() != null ? e.getMessage() : "Unknown error"); + } catch (IOException ioException) { + log.error("Failed to send error event", ioException); + } + // no need to re-throw. We've already sent 200 because we're streaming and the + // client will see the error event + } finally { + try { + os.flush(); + os.close(); + } catch (IOException e) { + log.error("Failed to close output stream", e); + } + } + } + } + + private void setEvalSpanAttributes( + Span evalSpan, + BraintrustUtils.Parent braintrustParent, + String braintrustGeneration, + DatasetCase datasetCase, + TaskResult taskResult) { + var spanAttrs = new LinkedHashMap<>(); + spanAttrs.put("type", "eval"); + spanAttrs.put("name", "eval"); + if (braintrustGeneration != null) { + spanAttrs.put("generation", braintrustGeneration); + } + evalSpan.setAttribute(PARENT, braintrustParent.toParentValue()) + .setAttribute("braintrust.span_attributes", json(spanAttrs)) + .setAttribute("braintrust.input_json", json(Map.of("input", datasetCase.input()))) + .setAttribute("braintrust.expected_json", json(datasetCase.expected())); + + if (datasetCase.origin().isPresent()) { + evalSpan.setAttribute("braintrust.origin", json(datasetCase.origin().get())); + } + if (!datasetCase.tags().isEmpty()) { + evalSpan.setAttribute( + AttributeKey.stringArrayKey("braintrust.tags"), datasetCase.tags()); + } + if 
(!datasetCase.metadata().isEmpty()) { + evalSpan.setAttribute("braintrust.metadata", json(datasetCase.metadata())); + } + evalSpan.setAttribute( + "braintrust.output_json", json(Map.of("output", taskResult.result()))); + } + + private void setTaskSpanAttributes( + Span taskSpan, + BraintrustUtils.Parent braintrustParent, + String braintrustGeneration, + DatasetCase datasetCase, + TaskResult taskResult) { + Map taskSpanAttrs = new LinkedHashMap<>(); + taskSpanAttrs.put("type", "task"); + taskSpanAttrs.put("name", "task"); + if (braintrustGeneration != null) { + taskSpanAttrs.put("generation", braintrustGeneration); + } + + taskSpan.setAttribute(PARENT, braintrustParent.toParentValue()) + .setAttribute("braintrust.span_attributes", json(taskSpanAttrs)) + .setAttribute("braintrust.input_json", json(Map.of("input", datasetCase.input()))) + .setAttribute( + "braintrust.output_json", json(Map.of("output", taskResult.result()))); + } + + private void setScoreSpanAttributes( + Span scoreSpan, + BraintrustUtils.Parent braintrustParent, + String braintrustGeneration, + String scorerName, + Map scorerScores) { + Map scoreSpanAttrs = new LinkedHashMap<>(); + scoreSpanAttrs.put("type", "score"); + scoreSpanAttrs.put("name", scorerName); + if (braintrustGeneration != null) { + scoreSpanAttrs.put("generation", braintrustGeneration); + } + + scoreSpan + .setAttribute(PARENT, braintrustParent.toParentValue()) + .setAttribute("braintrust.span_attributes", json(scoreSpanAttrs)) + .setAttribute("braintrust.output_json", json(scorerScores)); + } + + /** + * Writes one Server-Sent Event to the stream and flushes it. + * + * <p>Each line of {@code data} is emitted as its own {@code data:} field, because a raw line + * break inside a field would terminate the event early. The event is flushed right away so + * the client receives it as soon as it is produced rather than when the stream is closed. + * + * @param os response body stream of the SSE exchange + * @param eventType value of the {@code event:} field + * @param data event payload; may be empty or span multiple lines + * @throws IOException if writing to or flushing the stream fails + */ + private void sendSSEEvent(OutputStream os, String eventType, String data) throws IOException { + String payload = data.replace("\r\n", "\n").replace('\r', '\n').replace("\n", "\ndata: "); + String event = "event: " + eventType + "\n" + "data: " + payload + "\n\n"; + synchronized (this) { + os.write(event.getBytes(StandardCharsets.UTF_8)); + os.flush(); + } + } + + private void sendProgressEvent( + OutputStream os, + String spanId, + Optional origin, + String evalName, + Object taskResult) + throws IOException { + Map progressData = new
LinkedHashMap<>(); + progressData.put("id", spanId); + progressData.put("object_type", "task"); + + origin.ifPresent(value -> progressData.put("origin", value)); + progressData.put("name", evalName); + progressData.put("format", "code"); + progressData.put("output_type", "completion"); + progressData.put("event", "json_delta"); + progressData.put("data", JSON_MAPPER.writeValueAsString(taskResult)); + + String progressJson = JSON_MAPPER.writeValueAsString(progressData); + sendSSEEvent(os, "progress", progressJson); + } + + private void sendSummaryEvent( + OutputStream os, + String projectName, + String projectId, + String experimentName, + String projectUrl, + String experimentUrl, + Map scoreSummaries) + throws IOException { + Map summary = new LinkedHashMap<>(); + summary.put("projectName", projectName); + summary.put("projectId", projectId); + summary.put("experimentId", null); + summary.put("experimentName", experimentName); + summary.put("projectUrl", projectUrl); + summary.put("experimentUrl", null); + summary.put("comparisonExperimentName", null); + + Map scoresWithMeta = new LinkedHashMap<>(); + for (Map.Entry entry : scoreSummaries.entrySet()) { + Map scoreData = new LinkedHashMap<>(); + scoreData.put("name", entry.getValue().getName()); + scoreData.put("_longest_score_name", entry.getKey().length()); + scoreData.put("score", entry.getValue().getScore()); + scoreData.put("improvements", entry.getValue().getImprovements()); + scoreData.put("regressions", entry.getValue().getRegressions()); + scoreData.put("diff", null); + scoresWithMeta.put(entry.getKey(), scoreData); + } + summary.put("scores", scoresWithMeta); + summary.put("metrics", Map.of()); + + sendSSEEvent(os, "summary", JSON_MAPPER.writeValueAsString(summary)); + } + + private void sendDoneEvent(OutputStream os) throws IOException { + sendSSEEvent(os, "done", ""); + } + + private String json(Object o) { + try { + return JSON_MAPPER.writeValueAsString(o); + } catch (Exception e) { + throw new 
RuntimeException("Failed to serialize to JSON", e); + } + } + + private void sendResponse( + HttpExchange exchange, int statusCode, String contentType, String body) + throws IOException { + byte[] responseBytes = body.getBytes(StandardCharsets.UTF_8); + exchange.getResponseHeaders().set("Content-Type", contentType); + exchange.sendResponseHeaders(statusCode, responseBytes.length); + try (OutputStream os = exchange.getResponseBody()) { + os.write(responseBytes); + } + } + + /** + * Check if the origin is whitelisted for CORS. + * + * @param origin The Origin header value + * @return true if the origin is allowed, false otherwise + */ + private boolean isOriginAllowed(@Nullable String origin) { + if (origin == null || origin.isEmpty()) { + return true; // Allow requests without origin (e.g., same-origin) + } + // Check against whitelisted origins + for (String allowedOrigin : corsOriginWhitelist) { + if (allowedOrigin != null && allowedOrigin.equals(origin)) { + return true; + } + } + // Check against preview domain pattern + return PREVIEW_DOMAIN_PATTERN.matcher(origin).matches(); + } + + /** + * Apply CORS headers to the response. + * + * @param exchange The HTTP exchange + */ + private void applyCorsHeaders(HttpExchange exchange) { + String origin = exchange.getRequestHeaders().getFirst("Origin"); + + if (isOriginAllowed(origin)) { + var headers = exchange.getResponseHeaders(); + if (origin != null && !origin.isEmpty()) { + headers.set("Access-Control-Allow-Origin", origin); + } + headers.set("Access-Control-Allow-Credentials", "true"); + headers.set("Access-Control-Expose-Headers", EXPOSED_HEADERS); + } + } + + /** + * Handle CORS preflight requests. 
+ * + * @param exchange The HTTP exchange + */ + private void handlePreflightRequest(HttpExchange exchange) throws IOException { + String origin = exchange.getRequestHeaders().getFirst("Origin"); + + if (!isOriginAllowed(origin)) { + exchange.sendResponseHeaders(403, -1); + return; + } + + var headers = exchange.getResponseHeaders(); + if (origin != null && !origin.isEmpty()) { + headers.set("Access-Control-Allow-Origin", origin); + } + headers.set("Access-Control-Allow-Methods", "GET, PATCH, POST, PUT, DELETE, OPTIONS"); + headers.set("Access-Control-Allow-Headers", ALLOWED_HEADERS); + headers.set("Access-Control-Allow-Credentials", "true"); + headers.set("Access-Control-Max-Age", "86400"); + + // Support for Chrome's Private Network Access + String requestPrivateNetwork = + exchange.getRequestHeaders().getFirst("Access-Control-Request-Private-Network"); + if ("true".equals(requestPrivateNetwork)) { + headers.set("Access-Control-Allow-Private-Network", "true"); + } + + exchange.sendResponseHeaders(204, -1); + } + + /** + * Wrap a handler with CORS support. + * + * @param handler The handler to wrap + * @return A handler that applies CORS headers + */ + private HttpHandler withCors(HttpHandler handler) { + return exchange -> { + // Handle OPTIONS preflight requests + if ("OPTIONS".equals(exchange.getRequestMethod())) { + handlePreflightRequest(exchange); + return; + } + + // Apply CORS headers to all responses + applyCorsHeaders(exchange); + + // Delegate to the actual handler + handler.handle(exchange); + }; + } + + /** + * Extract API key from request headers. + * + *

<p>Checks headers in order of precedence: + * + *
<ol> + *
<li>x-bt-auth-token (preferred) + *
<li>Authorization: Bearer &lt;token&gt; + *
<li>Authorization: &lt;token&gt; + * </ol>
+ * + * @param exchange The HTTP exchange + * @param context The request context (unused but for consistency) + * @return The API key, or null if not present + */ + @Nullable + private String extractApiKey(HttpExchange exchange, RequestContext context) { + var headers = exchange.getRequestHeaders(); + + // 1. Check x-bt-auth-token header (preferred) + String token = headers.getFirst("x-bt-auth-token"); + if (token != null && !token.isEmpty()) { + return token; + } + + // 2. Check Authorization header + String authHeader = headers.getFirst("Authorization"); + if (authHeader != null && !authHeader.isEmpty()) { + // Try Bearer format + if (authHeader.startsWith("Bearer ")) { + return authHeader.substring(7).trim(); + } + // Try direct token + return authHeader.trim(); + } + + return null; + } + + /** + * Create a request context with origin. + * + * @param exchange The HTTP exchange + * @return RequestContext with appOrigin + */ + private RequestContext createRequestContext(HttpExchange exchange) { + String origin = exchange.getRequestHeaders().getFirst("Origin"); + if (origin == null) { + origin = ""; + } + + return RequestContext.builder().appOrigin(origin).build(); + } + + /** + * Get Braintrust state for authenticated requests. + * + *

<p>Validates that required headers are present and returns a RequestContext with populated + * Braintrust from cache. + * + *
<p>Required headers: + * + *
<ul> + *
<li>API key (x-bt-auth-token or Authorization) + *
<li>x-bt-org-name + *
<li>x-bt-project-id + * </ul>
+ * + * <p>
Cache key format: orgName:projectId:apiKey + * + * @param exchange The HTTP exchange + * @param context The request context + * @return RequestContext with populated state, or null if required headers are missing + */ + @Nullable + private RequestContext getBraintrust(HttpExchange exchange, RequestContext context) { + // Extract API key + String apiKey = extractApiKey(exchange, context); + if (apiKey == null || apiKey.isEmpty()) { + return null; + } + + // Get x-bt-org-name header + String orgName = exchange.getRequestHeaders().getFirst("x-bt-org-name"); + if (orgName == null || orgName.isEmpty()) { + return null; + } + + // Get x-bt-project-id header + String projectId = exchange.getRequestHeaders().getFirst("x-bt-project-id"); + if (projectId == null || projectId.isEmpty()) { + return null; + } + + // Create composite cache key: orgName:projectId:apiKey + String cacheKey = orgName + ":" + projectId + ":" + apiKey; + + // Get from cache or compute if not present + Braintrust braintrust = + authCache.getOrCompute( + cacheKey, + () -> { + log.debug( + "Cached login state for org='{}', projectId='{}' (cache" + + " size={})", + orgName, + projectId, + authCache.size()); + + // Build config with hook if present + var configBuilder = + BraintrustConfig.builder() + .apiKey(apiKey) + .defaultProjectId(projectId) + .apiUrl(config.apiUrl()) + .appUrl(config.appUrl()); + + // Invoke hook if present to allow customization (e.g., enabling + // in-memory span export) + if (configBuilderHook != null) { + configBuilderHook.accept(configBuilder); + } + + var bt = Braintrust.of(configBuilder.build()); + bt.apiClient().login(); + return bt; + }); + + log.debug( + "Retrieved login state for org='{}', projectId='{}' (cache size={})", + orgName, + projectId, + authCache.size()); + + // Return context with state populated + return RequestContext.builder() + .appOrigin(context.getAppOrigin()) + .token(apiKey) + .braintrust(braintrust) + .build(); + } + + /** + * Send an error response 
with JSON body. + * + * @param exchange The HTTP exchange + * @param statusCode The HTTP status code + * @param message The error message + * @throws IOException if response sending fails + */ + private void sendErrorResponse(HttpExchange exchange, int statusCode, String message) + throws IOException { + Map error = Map.of("error", message); + String json = JSON_MAPPER.writeValueAsString(error); + sendResponse(exchange, statusCode, "application/json", json); + } + + /** + * Container for parent information extracted from eval request. + * + * @param braintrustParent The parent specification in "type:id" format (e.g., + * "playground_id:abc123") + * @param generation The generation identifier from the request + */ + private record ParentInfo( + @Nonnull BraintrustUtils.Parent braintrustParent, @Nullable String generation) {} + + /** + * Extracts parent information from the eval request. + * + * @param request The eval request + * @return ParentInfo containing braintrustParent and generation + */ + private static ParentInfo extractParentInfo(EvalRequest request) { + String parentSpec = null; + String generation = null; + + // Extract parent spec and generation from request + if (request.getParent() != null && request.getParent() instanceof Map) { + @SuppressWarnings("unchecked") + Map parentMap = (Map) request.getParent(); + String objectType = (String) parentMap.get("object_type"); + String objectId = (String) parentMap.get("object_id"); + + // Extract generation from propagated_event.span_attributes.generation + Object propEventObj = parentMap.get("propagated_event"); + if (propEventObj instanceof Map) { + @SuppressWarnings("unchecked") + Map propEvent = (Map) propEventObj; + Object spanAttrsObj = propEvent.get("span_attributes"); + if (spanAttrsObj instanceof Map) { + @SuppressWarnings("unchecked") + Map spanAttrs = (Map) spanAttrsObj; + generation = (String) spanAttrs.get("generation"); + } + } + + if (objectType != null && objectId != null) { + parentSpec = 
"playground_id:" + objectId; + } + } + + if (parentSpec == null) { + throw new IllegalArgumentException("braintrust parent (playground_id) not found"); + } + return new ParentInfo(BraintrustUtils.parseParent(parentSpec), generation); + } + + /** + * Extracts and loads the dataset from the eval request. + * + *

<p>Supports three methods of loading data: + * + *
<ol> + *
<li>Inline data provided in the request + *
<li>Fetch by project name and dataset name + *
<li>Fetch by dataset ID + * </ol>
+ * + * @param request The eval request containing dataset specification + * @param apiClient The Braintrust API client for fetching datasets + * @return The loaded dataset + * @throws IllegalStateException if no dataset specification is provided + * @throws IllegalArgumentException if dataset or project is not found + */ + private static Dataset extractDataset( + EvalRequest request, BraintrustApiClient apiClient) { + EvalRequest.DataSpec dataSpec = request.getData(); + + if (dataSpec.getData() != null && !dataSpec.getData().isEmpty()) { + // Method 1: Inline data + List cases = new ArrayList<>(); + for (EvalRequest.EvalCaseData caseData : dataSpec.getData()) { + DatasetCase datasetCase = + DatasetCase.of( + caseData.getInput(), + caseData.getExpected(), + caseData.getTags() != null ? caseData.getTags() : List.of(), + caseData.getMetadata() != null ? caseData.getMetadata() : Map.of()); + cases.add(datasetCase); + } + return Dataset.of(cases.toArray(new DatasetCase[0])); + } else if (dataSpec.getProjectName() != null && dataSpec.getDatasetName() != null) { + // Method 2: Fetch by project name and dataset name + log.debug( + "Fetching dataset from Braintrust: project={}, dataset={}", + dataSpec.getProjectName(), + dataSpec.getDatasetName()); + return Dataset.fetchFromBraintrust( + apiClient, dataSpec.getProjectName(), dataSpec.getDatasetName(), null); + } else if (dataSpec.getDatasetId() != null) { + // Method 3: Fetch by dataset ID + log.debug("Fetching dataset from Braintrust by ID: {}", dataSpec.getDatasetId()); + var datasetMetadata = apiClient.getDataset(dataSpec.getDatasetId()); + if (datasetMetadata.isEmpty()) { + throw new IllegalArgumentException("Dataset not found: " + dataSpec.getDatasetId()); + } + + var project = apiClient.getProject(datasetMetadata.get().projectId()); + if (project.isEmpty()) { + throw new IllegalArgumentException( + "Project not found: " + datasetMetadata.get().projectId()); + } + + String fetchedProjectName = project.get().name(); + 
String fetchedDatasetName = datasetMetadata.get().name(); + log.debug( + "Resolved dataset ID to project={}, dataset={}", + fetchedProjectName, + fetchedDatasetName); + + return Dataset.fetchFromBraintrust( + apiClient, fetchedProjectName, fetchedDatasetName, null); + } else { + throw new IllegalStateException("No dataset specification provided"); + } + } + + public static class Builder { + private @Nullable BraintrustConfig config = null; + private String host = "localhost"; + private int port = 8300; + private @Nullable String orgName = null; + private List> evals = new ArrayList<>(); + private @Nullable Consumer + traceBuilderHook = null; + private @Nullable Consumer configBuilderHook = null; + + public Devserver build() { + if (evals.isEmpty()) { + throw new IllegalStateException("At least one evaluator must be registered"); + } + if (config == null) { + throw new IllegalStateException("config is required"); + } + return new Devserver(this); + } + + public Builder config(BraintrustConfig config) { + this.config = config; + return this; + } + + public Builder registerEval(RemoteEval eval) { + this.evals.add(eval); + return this; + } + + public Builder host(String host) { + this.host = host; + return this; + } + + public Builder port(int port) { + this.port = port; + return this; + } + + /** + * hook to run for each braintrust instance's config created by the devserver. The hook + * receives the BraintrustConfig.Builder before it's built, allowing customization such as + * enabling in-memory span export for testing. 
+ */ + public Builder braintrustConfigBuilderHook( + Consumer configBuilderHook) { + this.configBuilderHook = configBuilderHook; + return this; + } + } + + private static class NotSupportedYetException extends RuntimeException { + private final String description; + + public NotSupportedYetException(String description) { + super("feature not supported yet: " + description); + this.description = description; + } + } +} diff --git a/src/main/java/dev/braintrust/devserver/EvalRequest.java b/src/main/java/dev/braintrust/devserver/EvalRequest.java new file mode 100644 index 0000000..e3a7103 --- /dev/null +++ b/src/main/java/dev/braintrust/devserver/EvalRequest.java @@ -0,0 +1,116 @@ +package dev.braintrust.devserver; + +import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.List; +import java.util.Map; +import javax.annotation.Nullable; +import lombok.Data; + +/** Request body for POST /eval endpoint */ +@Data +public class EvalRequest { + /** Name of the evaluator to run */ + private String name; + + /** Optional parameter overrides */ + @Nullable private Map parameters; + + /** Dataset specification */ + private DataSpec data; + + /** Optional experiment name override */ + @JsonProperty("experiment_name") + @Nullable + private String experimentName; + + /** Optional project ID override */ + @JsonProperty("project_id") + @Nullable + private String projectId; + + /** Optional additional remote scorers */ + @Nullable private List scores; + + /** Optional parent span for tracing (can be string or object) */ + @Nullable private Object parent; + + /** Enable SSE streaming (default: false) */ + @Nullable private Boolean stream; + + /** Dataset specification - supports inline data, by name, or by ID */ + @Data + public static class DataSpec { + /** Inline data array */ + @Nullable private List data; + + /** Project name (for loading by name) */ + @JsonProperty("project_name") + @Nullable + private String projectName; + + /** Dataset name (for loading by name) 
*/ + @JsonProperty("dataset_name") + @Nullable + private String datasetName; + + /** Dataset ID (for loading by ID) */ + @JsonProperty("dataset_id") + @Nullable + private String datasetId; + + /** Optional BTQL filter (can be string or structured query object) */ + @JsonProperty("_internal_btql") + @Nullable + private Object btql; + } + + /** Individual evaluation case data */ + @Data + public static class EvalCaseData { + /** Input for the task */ + private Object input; + + /** Expected output (optional) */ + @Nullable private Object expected; + + /** Metadata (optional) */ + @Nullable private Map metadata; + + /** Tags (optional) */ + @Nullable private List tags; + } + + /** Remote scorer specification */ + @Data + public static class RemoteScorer { + /** Scorer name */ + private String name; + + /** Function ID specification */ + @JsonProperty("function_id") + private FunctionId functionId; + } + + /** Function ID specification (multiple formats supported) */ + @Data + public static class FunctionId { + @JsonProperty("function_id") + @Nullable + private String functionId; + + @Nullable private String version; + @Nullable private String name; + + @JsonProperty("prompt_session_id") + @Nullable + private String promptSessionId; + + @JsonProperty("inline_code") + @Nullable + private String inlineCode; + + @JsonProperty("global_function") + @Nullable + private String globalFunction; + } +} diff --git a/src/main/java/dev/braintrust/devserver/EvalResponse.java b/src/main/java/dev/braintrust/devserver/EvalResponse.java new file mode 100644 index 0000000..b13183c --- /dev/null +++ b/src/main/java/dev/braintrust/devserver/EvalResponse.java @@ -0,0 +1,62 @@ +package dev.braintrust.devserver; + +import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.Map; +import javax.annotation.Nullable; +import lombok.Builder; +import lombok.Data; + +/** Response body for POST /eval endpoint */ +@Data +@Builder +public class EvalResponse { + /** Experiment name */ + 
@JsonProperty("experimentName") + private String experimentName; + + /** Project name */ + @JsonProperty("projectName") + private String projectName; + + /** Project ID */ + @JsonProperty("projectId") + private String projectId; + + /** Experiment ID */ + @JsonProperty("experimentId") + private String experimentId; + + /** Experiment URL */ + @JsonProperty("experimentUrl") + private String experimentUrl; + + /** Project URL */ + @JsonProperty("projectUrl") + private String projectUrl; + + /** Comparison experiment name (optional) */ + @JsonProperty("comparisonExperimentName") + @Nullable + private String comparisonExperimentName; + + /** Score summaries by scorer name */ + @JsonProperty("scores") + private Map scores; + + /** Summary statistics for a scorer */ + @Data + @Builder + public static class ScoreSummary { + /** Scorer name */ + private String name; + + /** Average score across all cases */ + private double score; + + /** Number of improvements vs baseline */ + private int improvements; + + /** Number of regressions vs baseline */ + private int regressions; + } +} diff --git a/src/main/java/dev/braintrust/devserver/LRUCache.java b/src/main/java/dev/braintrust/devserver/LRUCache.java new file mode 100644 index 0000000..f0c9362 --- /dev/null +++ b/src/main/java/dev/braintrust/devserver/LRUCache.java @@ -0,0 +1,74 @@ +package dev.braintrust.devserver; + +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.function.Supplier; +import javax.annotation.Nullable; +import javax.annotation.concurrent.ThreadSafe; + +/** + * Simple LRU (Least Recently Used) cache implementation. + * + *

Thread-safe cache with a maximum size. When the cache exceeds its capacity, the least recently + * used entry is evicted. + * + * @param Key type + * @param Value type + */ +@ThreadSafe +class LRUCache { + private final int maxSize; + private final Map cache; + + public LRUCache(int maxSize) { + this.maxSize = maxSize; + this.cache = + new LinkedHashMap(maxSize, 0.75f, true) { + @Override + protected boolean removeEldestEntry(Map.Entry eldest) { + return size() > LRUCache.this.maxSize; + } + }; + } + + public synchronized void put(K key, V value) { + cache.put(key, value); + } + + @Nullable + public synchronized V get(K key) { + return cache.get(key); + } + + /** + * Get a value from the cache, or compute and cache it if not present. + * + *

This operation is atomic - the supplier function is only called once per key even under + * concurrent access. + * + * @param key The cache key + * @param supplier Function to compute the value if not in cache (takes no args, returns value) + * @return The cached or newly computed value + */ + public synchronized V getOrCompute(K key, Supplier supplier) { + V value = cache.get(key); + if (value == null) { + // Cache miss - compute the value + value = supplier.get(); + cache.put(key, value); + } + return value; + } + + public synchronized boolean containsKey(K key) { + return cache.containsKey(key); + } + + public synchronized void clear() { + cache.clear(); + } + + public synchronized int size() { + return cache.size(); + } +} diff --git a/src/main/java/dev/braintrust/devserver/RemoteEval.java b/src/main/java/dev/braintrust/devserver/RemoteEval.java new file mode 100644 index 0000000..75d569b --- /dev/null +++ b/src/main/java/dev/braintrust/devserver/RemoteEval.java @@ -0,0 +1,131 @@ +package dev.braintrust.devserver; + +import dev.braintrust.eval.Scorer; +import dev.braintrust.eval.Task; +import java.util.*; +import java.util.function.Function; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import lombok.Builder; +import lombok.Getter; +import lombok.Singular; + +/** + * Represents a remote evaluator that can be exposed via the dev server. + * + * @param The type of input data for the evaluation + * @param The type of output produced by the task + */ +@Getter +@Builder(builderClassName = "Builder", buildMethodName = "internalBuild") +public class RemoteEval { + /** The name of this evaluator (used as identifier) */ + @Nonnull private final String name; + + /** + * The task function that performs the evaluation + * + *

The task function must be thread safe. + */ + @Nonnull private final Task task; + + /** + * List of scorers for this evaluator + * + *

The score function must be thread safe. + */ + @Singular @Nonnull private final List> scorers; + + /** Optional parameters that can be configured from the UI */ + @Singular @Nonnull private final Map parameters; + + public static class Builder { + /** + * Convenience builder method to create a RemoteEval with a simple task function. + * + * @param taskFn Function that takes input and returns output + * @return this builder + */ + public Builder taskFunction(Function taskFn) { + return task( + datasetCase -> { + var result = taskFn.apply(datasetCase.input()); + return new dev.braintrust.eval.TaskResult<>(result, datasetCase); + }); + } + + /** Build the RemoteEval */ + public RemoteEval build() { + // can add build hooks here later if desired + return internalBuild(); + } + } + + /** Represents a configurable parameter for the evaluator */ + @Getter + @lombok.Builder(builderClassName = "Builder") + public static class Parameter { + /** Type of parameter: "prompt" or "data" */ + @Nonnull private final ParameterType type; + + /** Optional description of the parameter */ + @Nullable private final String description; + + /** Optional default value for the parameter */ + @Nullable private final Object defaultValue; + + /** + * JSON Schema for data type parameters. Only applicable when type is DATA. Should be a Map + * representing a JSON Schema object. 
+ */ + @Nullable private final Map schema; + + public static Parameter promptParameter(String description, Object defaultValue) { + return Parameter.builder() + .type(ParameterType.PROMPT) + .description(description) + .defaultValue(defaultValue) + .build(); + } + + public static Parameter promptParameter(Object defaultValue) { + return promptParameter(null, defaultValue); + } + + public static Parameter dataParameter( + String description, Map schema, Object defaultValue) { + return Parameter.builder() + .type(ParameterType.DATA) + .description(description) + .schema(schema) + .defaultValue(defaultValue) + .build(); + } + + public static Parameter dataParameter(Map schema, Object defaultValue) { + return dataParameter(null, schema, defaultValue); + } + + public static Parameter dataParameter(Map schema) { + return dataParameter(null, schema, null); + } + } + + /** Parameter type enumeration */ + public enum ParameterType { + /** Prompt parameter (for LLM prompts) */ + PROMPT("prompt"), + /** Data parameter (for other configuration data) */ + DATA("data"); + + private final String value; + + ParameterType(String value) { + this.value = value; + } + + public String getValue() { + return value; + } + } +} diff --git a/src/main/java/dev/braintrust/devserver/RequestContext.java b/src/main/java/dev/braintrust/devserver/RequestContext.java new file mode 100644 index 0000000..4213c6e --- /dev/null +++ b/src/main/java/dev/braintrust/devserver/RequestContext.java @@ -0,0 +1,24 @@ +package dev.braintrust.devserver; + +import dev.braintrust.Braintrust; +import javax.annotation.Nullable; +import lombok.Builder; +import lombok.Getter; + +/** + * Context object attached to each authenticated request. + * + *

Contains the validated origin, extracted authentication token, and cached login state. + */ +@Getter +@Builder +public class RequestContext { + /** Validated origin from CORS */ + private final String appOrigin; + + /** Extracted auth token (if present) */ + @Nullable private final String token; + + /** Cached login state */ + @Nullable private final Braintrust braintrust; +} diff --git a/src/main/java/dev/braintrust/trace/BraintrustContext.java b/src/main/java/dev/braintrust/trace/BraintrustContext.java index 5911796..db3bf9c 100644 --- a/src/main/java/dev/braintrust/trace/BraintrustContext.java +++ b/src/main/java/dev/braintrust/trace/BraintrustContext.java @@ -45,11 +45,11 @@ public static Context ofExperiment(@Nonnull String experimentId, @Nonnull Span s * parent context to flow across process boundaries. * * @param ctx the context to update - * @param parentType the type of parent (e.g., "experiment_id", "project_name") + * @param parentType the type of parent (e.g., "experiment_id", "project_name", "playground_id") * @param parentId the ID of the parent * @return updated context with baggage set */ - static Context setParentInBaggage( + public static Context setParentInBaggage( @Nonnull Context ctx, @Nonnull String parentType, @Nonnull String parentId) { try { String parentValue = (new BraintrustUtils.Parent(parentType, parentId)).toParentValue(); diff --git a/src/test/java/dev/braintrust/TestUtils.java b/src/test/java/dev/braintrust/TestUtils.java new file mode 100644 index 0000000..82124ce --- /dev/null +++ b/src/test/java/dev/braintrust/TestUtils.java @@ -0,0 +1,15 @@ +package dev.braintrust; + +import java.io.IOException; +import java.net.ServerSocket; + +public class TestUtils { + public static int getRandomOpenPort() { + try (ServerSocket socket = new ServerSocket(0)) { + socket.setReuseAddress(true); + return socket.getLocalPort(); + } catch (IOException e) { + throw new RuntimeException("Failed to find an available port", e); + } + } +} diff --git 
a/src/test/java/dev/braintrust/devserver/CorsTest.java b/src/test/java/dev/braintrust/devserver/CorsTest.java
new file mode 100644
index 0000000..b6300d0
--- /dev/null
+++ b/src/test/java/dev/braintrust/devserver/CorsTest.java
@@ -0,0 +1,176 @@
+package dev.braintrust.devserver;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+import dev.braintrust.TestUtils;
+import dev.braintrust.config.BraintrustConfig;
+import dev.braintrust.eval.Scorer;
+import java.net.URI;
+import java.net.http.HttpClient;
+import java.net.http.HttpRequest;
+import java.net.http.HttpResponse;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+
+class CorsTest {
+    private static Devserver server;
+    private static Thread serverThread;
+    private static final int TEST_PORT = TestUtils.getRandomOpenPort();
+
+    @BeforeAll
+    static void setup() throws Exception {
+        RemoteEval testEval =
+                RemoteEval.builder()
+                        .name("test-eval")
+                        .taskFunction(String::toUpperCase)
+                        .scorer(
+                                Scorer.of(
+                                        "length",
+                                        (expected, result) -> (double) result.length() / 10.0))
+                        .build();
+
+        server =
+                Devserver.builder()
+                        .config(BraintrustConfig.of("BRAINTRUST_API_KEY", "bogus"))
+                        .registerEval(testEval)
+                        .host("localhost")
+                        .port(TEST_PORT)
+                        .build();
+
+        serverThread =
+                new Thread(
+                        () -> {
+                            try {
+                                server.start();
+                            } catch (Exception e) {
+                                throw new RuntimeException(e);
+                            }
+                        });
+        serverThread.start();
+        Thread.sleep(1000); // Give server time to start
+    }
+
+    @AfterAll
+    static void teardown() {
+        server.stop();
+        serverThread.interrupt();
+    }
+
+    @Test
+    void testCorsPreflightRequest() throws Exception {
+        HttpClient client = HttpClient.newHttpClient();
+        HttpRequest request =
+                HttpRequest.newBuilder()
+                        .uri(URI.create("http://localhost:" + TEST_PORT + "/"))
+                        .method("OPTIONS", HttpRequest.BodyPublishers.noBody())
+                        .header("Origin", "https://www.braintrust.dev")
+                        .header("Access-Control-Request-Method", "POST")
+                        .build();
+
+        HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
+
+        assertEquals(204, response.statusCode());
+        assertEquals(
+                "https://www.braintrust.dev",
+                response.headers().firstValue("Access-Control-Allow-Origin").orElse(null));
+        assertEquals(
+                "true",
+                response.headers()
+                        .firstValue("Access-Control-Allow-Credentials")
+                        .orElse("")
+                        .toLowerCase());
+        assertTrue(
+                response.headers()
+                        .firstValue("Access-Control-Allow-Methods")
+                        .orElse("")
+                        .toLowerCase()
+                        .contains("post"));
+        assertTrue(
+                response.headers()
+                        .firstValue("Access-Control-Allow-Headers")
+                        .orElse("")
+                        .toLowerCase()
+                        .contains("x-bt-auth-token"));
+    }
+
+    @Test
+    void testCorsActualRequest() throws Exception {
+        HttpClient client = HttpClient.newHttpClient();
+        HttpRequest request =
+                HttpRequest.newBuilder()
+                        .uri(URI.create("http://localhost:" + TEST_PORT + "/"))
+                        .GET()
+                        .header("Origin", "https://www.braintrust.dev")
+                        .build();
+
+        HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
+
+        assertEquals(200, response.statusCode());
+        assertEquals("Hello, world!", response.body());
+        assertEquals(
+                "https://www.braintrust.dev",
+                response.headers().firstValue("Access-Control-Allow-Origin").orElse(null));
+        assertEquals(
+                "true",
+                response.headers()
+                        .firstValue("Access-Control-Allow-Credentials")
+                        .orElse("")
+                        .toLowerCase());
+    }
+
+    @Test
+    void testCorsPreviewDomain() throws Exception {
+        HttpClient client = HttpClient.newHttpClient();
+        HttpRequest request =
+                HttpRequest.newBuilder()
+                        .uri(URI.create("http://localhost:" + TEST_PORT + "/"))
+                        .GET()
+                        .header("Origin", "https://pr-123.preview.braintrust.dev")
+                        .build();
+
+        HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
+
+        assertEquals(200, response.statusCode());
+        assertEquals(
+                "https://pr-123.preview.braintrust.dev",
+                response.headers().firstValue("Access-Control-Allow-Origin").orElse(null));
+    }
+
+    @Test
+    void testCorsUnauthorizedOrigin() throws Exception {
+        HttpClient client = HttpClient.newHttpClient();
+        HttpRequest request =
+                HttpRequest.newBuilder()
+                        .uri(URI.create("http://localhost:" + TEST_PORT + "/"))
+                        .method("OPTIONS", HttpRequest.BodyPublishers.noBody())
+                        .header("Origin", "https://evil.com")
+                        .build();
+
+        HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
+
+        assertEquals(403, response.statusCode());
+    }
+
+    @Test
+    void testPrivateNetworkAccess() throws Exception {
+        HttpClient client = HttpClient.newHttpClient();
+        HttpRequest request =
+                HttpRequest.newBuilder()
+                        .uri(URI.create("http://localhost:" + TEST_PORT + "/"))
+                        .method("OPTIONS", HttpRequest.BodyPublishers.noBody())
+                        .header("Origin", "https://www.braintrust.dev")
+                        .header("Access-Control-Request-Private-Network", "true")
+                        .build();
+
+        HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
+
+        assertEquals(204, response.statusCode());
+        assertEquals(
+                "true",
+                response.headers()
+                        .firstValue("Access-Control-Allow-Private-Network")
+                        .orElse("")
+                        .toLowerCase());
+    }
+}
diff --git a/src/test/java/dev/braintrust/devserver/DevserverTest.java b/src/test/java/dev/braintrust/devserver/DevserverTest.java
new file mode 100644
index 0000000..b1a4065
--- /dev/null
+++ b/src/test/java/dev/braintrust/devserver/DevserverTest.java
@@ -0,0 +1,600 @@
+package dev.braintrust.devserver;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.sun.net.httpserver.HttpServer;
+import dev.braintrust.TestHarness;
+import dev.braintrust.config.BraintrustConfig;
+import dev.braintrust.eval.Scorer;
+import io.opentelemetry.sdk.trace.data.SpanData;
+import java.io.BufferedReader;
+import java.io.InputStreamReader;
+import java.io.OutputStream;
+import java.net.HttpURLConnection;
+import java.net.InetSocketAddress;
+import
java.net.URI;
+import java.net.http.HttpClient;
+import java.net.http.HttpRequest;
+import java.net.http.HttpResponse;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import org.junit.jupiter.api.*;
+
+class DevserverTest {
+    private static Devserver server;
+    private static Thread serverThread;
+    private static TestHarness testHarness;
+    private static HttpServer mockApiServer;
+    private static final int TEST_PORT = 18300;
+    private static final int MOCK_API_PORT = 18301;
+    private static final String TEST_URL = "http://localhost:" + TEST_PORT;
+    private static final ObjectMapper JSON_MAPPER = new ObjectMapper();
+
+    @BeforeAll
+    static void setUp() throws Exception {
+        // Set up mock Braintrust API server
+        mockApiServer = HttpServer.create(new InetSocketAddress("localhost", MOCK_API_PORT), 0);
+
+        // Mock /v1/project endpoint
+        mockApiServer.createContext(
+                "/v1/project",
+                exchange -> {
+                    String response =
+                            JSON_MAPPER.writeValueAsString(
+                                    Map.of(
+                                            "id", "test-project-id",
+                                            "name", "test-project",
+                                            "org_id", "test-org-id",
+                                            "created", "2023-01-01T00:00:00Z",
+                                            "updated", "2023-01-01T00:00:00Z"));
+                    exchange.getResponseHeaders().set("Content-Type", "application/json");
+                    exchange.sendResponseHeaders(
+                            200, response.getBytes(StandardCharsets.UTF_8).length);
+                    try (OutputStream os = exchange.getResponseBody()) {
+                        os.write(response.getBytes(StandardCharsets.UTF_8));
+                    }
+                });
+
+        // Mock /v1/org endpoint
+        mockApiServer.createContext(
+                "/v1/org",
+                exchange -> {
+                    String response =
+                            JSON_MAPPER.writeValueAsString(
+                                    Map.of(
+                                            "results",
+                                            List.of(
+                                                    Map.of(
+                                                            "id", "test-org-id",
+                                                            "name", "test-org"))));
+                    exchange.getResponseHeaders().set("Content-Type", "application/json");
+                    exchange.sendResponseHeaders(
+                            200, response.getBytes(StandardCharsets.UTF_8).length);
+                    try (OutputStream os = exchange.getResponseBody()) {
+                        os.write(response.getBytes(StandardCharsets.UTF_8));
+                    }
+                });
+
+        // Mock /api/apikey/login endpoint (using snake_case as per API client naming strategy)
+        mockApiServer.createContext(
+                "/api/apikey/login",
+                exchange -> {
+                    String response =
+                            JSON_MAPPER.writeValueAsString(
+                                    Map.of(
+                                            "org_info",
+                                            List.of(
+                                                    Map.of(
+                                                            "id", "test-org-id",
+                                                            "name", "test-org"))));
+                    exchange.getResponseHeaders().set("Content-Type", "application/json");
+                    exchange.sendResponseHeaders(
+                            200, response.getBytes(StandardCharsets.UTF_8).length);
+                    try (OutputStream os = exchange.getResponseBody()) {
+                        os.write(response.getBytes(StandardCharsets.UTF_8));
+                    }
+                });
+
+        mockApiServer.start();
+
+        // Set up test harness with config pointing to mock API
+        BraintrustConfig testConfig =
+                BraintrustConfig.of(
+                        "BRAINTRUST_API_KEY", "test-key",
+                        "BRAINTRUST_API_URL", "http://localhost:" + MOCK_API_PORT,
+                        "BRAINTRUST_APP_URL", "http://localhost:3000",
+                        "BRAINTRUST_DEFAULT_PROJECT_NAME", "test-project",
+                        "BRAINTRUST_JAVA_EXPORT_SPANS_IN_MEMORY_FOR_UNIT_TEST", "true");
+        testHarness = TestHarness.setup(testConfig);
+
+        // Create a shared eval for all tests
+        RemoteEval testEval =
+                RemoteEval.builder()
+                        .name("food-type-classifier")
+                        .taskFunction(
+                                input -> {
+                                    // Create a span inside the task to test baggage propagation
+                                    var tracer = dev.braintrust.trace.BraintrustTracing.getTracer();
+                                    var span = tracer.spanBuilder("custom-task-span").startSpan();
+                                    try (var scope =
+                                            io.opentelemetry.context.Context.current()
+                                                    .with(span)
+                                                    .makeCurrent()) {
+                                        // Do some work
+                                        return "java-fruit";
+                                    } finally {
+                                        span.end();
+                                    }
+                                })
+                        .scorer(Scorer.of("simple_scorer", (expected, result) -> 0.7))
+                        .build();
+
+        server =
+                Devserver.builder()
+                        .config(testHarness.braintrust().config())
+                        .registerEval(testEval)
+                        .host("localhost")
+                        .port(TEST_PORT)
+                        .braintrustConfigBuilderHook(
+                                configBuilder -> {
+                                    configBuilder.exportSpansInMemoryForUnitTest(true);
+                                })
+                        .build();
+
+        // Start server in background thread
+        serverThread =
+                new Thread(
+                        () -> {
+                            try {
+                                server.start();
+                            } catch (Exception e) {
+                                e.printStackTrace();
+                            }
+                        });
+        serverThread.start();
+
+        // Give server time to start
+        Thread.sleep(1000);
+    }
+
+    @AfterAll
+    static void tearDown() {
+        if (server != null) {
+            server.stop();
+        }
+        if (serverThread != null) {
+            serverThread.interrupt();
+        }
+        if (mockApiServer != null) {
+            mockApiServer.stop(0);
+        }
+    }
+
+    @Test
+    void testHealthCheck() throws Exception {
+        // Test health check endpoint using the shared devserver
+        HttpClient client = HttpClient.newHttpClient();
+        HttpRequest request =
+                HttpRequest.newBuilder().uri(URI.create(TEST_URL + "/")).GET().build();
+
+        HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
+
+        assertEquals(200, response.statusCode());
+        assertEquals("Hello, world!", response.body());
+        assertEquals("text/plain", response.headers().firstValue("Content-Type").orElse(""));
+    }
+
+    @Test
+    void testStreamingEval() throws Exception {
+        // Create eval request with inline data using EvalRequest types
+        EvalRequest evalRequest = new EvalRequest();
+        evalRequest.setName("food-type-classifier");
+        evalRequest.setStream(true);
+
+        // Create inline data
+        EvalRequest.DataSpec dataSpec = new EvalRequest.DataSpec();
+
+        EvalRequest.EvalCaseData case1 = new EvalRequest.EvalCaseData();
+        case1.setInput("apple");
+        case1.setExpected("fruit");
+
+        EvalRequest.EvalCaseData case2 = new EvalRequest.EvalCaseData();
+        case2.setInput("carrot");
+        case2.setExpected("vegetable");
+
+        dataSpec.setData(List.of(case1, case2));
+        evalRequest.setData(dataSpec);
+
+        // Set parent with playground_id and generation
+        Map<String, Object> parentSpec =
+                Map.of(
+                        "object_type", "experiment",
+                        "object_id", "test-playground-id-123",
+                        "propagated_event",
+                                Map.of("span_attributes", Map.of("generation", "test-gen-1")));
+        evalRequest.setParent(parentSpec);
+
+        String requestBody = JSON_MAPPER.writeValueAsString(evalRequest);
+
+        // Make POST request to /eval with auth headers
+        HttpURLConnection conn =
+                (HttpURLConnection) new URI(TEST_URL + "/eval").toURL().openConnection();
+        conn.setRequestMethod("POST");
+        conn.setRequestProperty("Content-Type", "application/json");
+        conn.setRequestProperty("x-bt-auth-token", "test-token-123");
+        conn.setRequestProperty("x-bt-project-id", "test-project-id");
+        conn.setRequestProperty("x-bt-org-name", "test-org");
+        conn.setDoOutput(true);
+
+        // Write request body
+        conn.getOutputStream().write(requestBody.getBytes(StandardCharsets.UTF_8));
+        conn.getOutputStream().flush();
+
+        // Read SSE response
+        assertEquals(200, conn.getResponseCode());
+        assertEquals("text/event-stream", conn.getHeaderField("Content-Type"));
+
+        BufferedReader reader =
+                new BufferedReader(
+                        new InputStreamReader(conn.getInputStream(), StandardCharsets.UTF_8));
+
+        List<Map<String, String>> events = new ArrayList<>();
+        String line;
+        String currentEvent = null;
+        StringBuilder currentData = new StringBuilder();
+
+        while ((line = reader.readLine()) != null) {
+            if (line.startsWith("event: ")) {
+                currentEvent = line.substring(7);
+            } else if (line.startsWith("data: ")) {
+                currentData.append(line.substring(6));
+            } else if (line.isEmpty() && currentEvent != null) {
+                // End of event
+                events.add(Map.of("event", currentEvent, "data", currentData.toString()));
+                currentEvent = null;
+                currentData = new StringBuilder();
+            }
+        }
+        reader.close();
+
+        // Assert event structure
+        assertFalse(events.isEmpty(), "Should have received events");
+
+        // Count events by type
+        List<Map<String, String>> progressEvents =
+                events.stream().filter(e -> "progress".equals(e.get("event"))).toList();
+        List<Map<String, String>> summaryEvents =
+                events.stream().filter(e -> "summary".equals(e.get("event"))).toList();
+        List<Map<String, String>> doneEvents =
+                events.stream().filter(e -> "done".equals(e.get("event"))).toList();
+
+        // Should have 1 start event, 2 progress events (one per dataset case), 1 summary, 1 done
+        assertEquals(2, progressEvents.size(), "Should have 2 progress events");
+        assertEquals(1, summaryEvents.size(), "Should have 1 summary event");
+        assertEquals(1, doneEvents.size(), "Should have 1 done event");
+
+        // Verify progress events match expected structure
+        for (Map<String, String> progressEvent : progressEvents) {
+            String dataJson = progressEvent.get("data");
+            JsonNode progressData = JSON_MAPPER.readTree(dataJson);
+
+            // Assert expected fields in progress event
+            assertTrue(progressData.has("id"), "Progress event should have id");
+            assertEquals("task", progressData.get("object_type").asText());
+            assertEquals("food-type-classifier", progressData.get("name").asText());
+            assertEquals("code", progressData.get("format").asText());
+            assertEquals("completion", progressData.get("output_type").asText());
+            assertEquals("json_delta", progressData.get("event").asText());
+
+            // Assert data field contains the task result
+            String taskResultJson = progressData.get("data").asText();
+            assertEquals("\"java-fruit\"", taskResultJson);
+        }
+
+        { // Verify summary event
+            Map<String, String> summaryEvent = summaryEvents.get(0);
+            JsonNode summaryData = JSON_MAPPER.readTree(summaryEvent.get("data"));
+
+            assertEquals("test-project", summaryData.get("projectName").asText());
+            assertTrue(summaryData.has("projectId"));
+            assertEquals("food-type-classifier", summaryData.get("experimentName").asText());
+
+            // Verify scores in summary
+            assertTrue(summaryData.has("scores"));
+            JsonNode scores = summaryData.get("scores");
+            assertTrue(scores.has("simple_scorer"));
+            JsonNode simpleScorer = scores.get("simple_scorer");
+            assertEquals("simple_scorer", simpleScorer.get("name").asText());
+            assertEquals(0.7, simpleScorer.get("score").asDouble(), 0.001);
+        }
+
+        // Get exported spans from test harness (since devserver uses global tracer)
+        List<SpanData> exportedSpans = testHarness.awaitExportedSpans();
+        assertFalse(exportedSpans.isEmpty(), "Should have exported spans");
+
+        // We should have 2 eval traces (one per dataset case), each with task, score, and custom
+        // spans
+        // Each trace has: 1 eval span, 1 task span, 1 score span, 1 custom-task-span = 4 spans
+        // per case
+        // Total: 2 cases * 4 spans = 8 spans
+        assertEquals(8, exportedSpans.size(), "Should have 8 spans (4 per dataset case)");
+
+        // Verify span types
+        var evalSpans = exportedSpans.stream().filter(s -> s.getName().equals("eval")).toList();
+        var taskSpans = exportedSpans.stream().filter(s -> s.getName().equals("task")).toList();
+        var scoreSpans = exportedSpans.stream().filter(s -> s.getName().equals("score")).toList();
+        var customSpans =
+                exportedSpans.stream().filter(s -> s.getName().equals("custom-task-span")).toList();
+
+        assertEquals(2, evalSpans.size(), "Should have 2 eval spans");
+        assertEquals(2, taskSpans.size(), "Should have 2 task spans");
+        assertEquals(2, scoreSpans.size(), "Should have 2 score spans");
+        assertEquals(2, customSpans.size(), "Should have 2 custom-task-span spans");
+
+        // Verify eval spans have all required attributes
+        for (SpanData evalSpan : evalSpans) {
+            // Verify braintrust.parent
+            String parent =
+                    evalSpan.getAttributes()
+                            .get(
+                                    io.opentelemetry.api.common.AttributeKey.stringKey(
+                                            "braintrust.parent"));
+            assertNotNull(parent, "Eval span should have parent attribute");
+            assertTrue(
+                    parent.contains("playground_id:test-playground-id-123"),
+                    "Parent should contain playground_id");
+
+            String spanAttrsJson =
+                    evalSpan.getAttributes()
+                            .get(
+                                    io.opentelemetry.api.common.AttributeKey.stringKey(
+                                            "braintrust.span_attributes"));
+            assertNotNull(spanAttrsJson, "Eval span should have span_attributes");
+            JsonNode spanAttrs = JSON_MAPPER.readTree(spanAttrsJson);
+            assertEquals("eval", spanAttrs.get("type").asText());
+            assertEquals("eval", spanAttrs.get("name").asText());
+            assertEquals("test-gen-1", spanAttrs.get("generation").asText());
+
+            String inputJson =
+                    evalSpan.getAttributes()
+                            .get(
+                                    io.opentelemetry.api.common.AttributeKey.stringKey(
+                                            "braintrust.input_json"));
+            assertNotNull(inputJson, "Eval span should have input_json");
+
+            String expectedJson =
+                    evalSpan.getAttributes()
+                            .get(
+                                    io.opentelemetry.api.common.AttributeKey.stringKey(
+                                            "braintrust.expected_json"));
+            assertNotNull(expectedJson, "Eval span should have expected_json");
+
+            String outputJson =
+                    evalSpan.getAttributes()
+                            .get(
+                                    io.opentelemetry.api.common.AttributeKey.stringKey(
+                                            "braintrust.output_json"));
+            assertNotNull(outputJson, "Eval span should have output_json");
+            JsonNode output = JSON_MAPPER.readTree(outputJson);
+            assertEquals("java-fruit", output.get("output").asText());
+        }
+
+        for (SpanData taskSpan : taskSpans) {
+            // Verify braintrust.parent
+            String parent =
+                    taskSpan.getAttributes()
+                            .get(
+                                    io.opentelemetry.api.common.AttributeKey.stringKey(
+                                            "braintrust.parent"));
+            assertNotNull(parent, "Task span should have parent attribute");
+            assertTrue(
+                    parent.contains("playground_id:test-playground-id-123"),
+                    "Parent should contain playground_id");
+
+            String spanAttrsJson =
+                    taskSpan.getAttributes()
+                            .get(
+                                    io.opentelemetry.api.common.AttributeKey.stringKey(
+                                            "braintrust.span_attributes"));
+            assertNotNull(spanAttrsJson, "Task span should have span_attributes");
+            JsonNode spanAttrs = JSON_MAPPER.readTree(spanAttrsJson);
+            assertEquals("task", spanAttrs.get("type").asText());
+            assertEquals("task", spanAttrs.get("name").asText());
+            assertEquals("test-gen-1", spanAttrs.get("generation").asText());
+
+            String inputJson =
+                    taskSpan.getAttributes()
+                            .get(
+                                    io.opentelemetry.api.common.AttributeKey.stringKey(
+                                            "braintrust.input_json"));
+            assertNotNull(inputJson, "Task span should have input_json");
+
+            String outputJson =
+                    taskSpan.getAttributes()
+                            .get(
+                                    io.opentelemetry.api.common.AttributeKey.stringKey(
+                                            "braintrust.output_json"));
+            assertNotNull(outputJson, "Task span should have output_json");
+            JsonNode output = JSON_MAPPER.readTree(outputJson);
+            assertEquals("java-fruit", output.get("output").asText());
+        }
+
+        for (SpanData scoreSpan : scoreSpans) {
+            // Verify braintrust.parent
+            String parent =
+                    scoreSpan
+                            .getAttributes()
+                            .get(
+                                    io.opentelemetry.api.common.AttributeKey.stringKey(
+                                            "braintrust.parent"));
+            assertNotNull(parent, "Score span should have parent attribute");
+            assertTrue(
+                    parent.contains("playground_id:test-playground-id-123"),
+                    "Parent should contain playground_id");
+
+            // Verify braintrust.span_attributes
+            String spanAttrsJson =
+                    scoreSpan
+                            .getAttributes()
+                            .get(
+                                    io.opentelemetry.api.common.AttributeKey.stringKey(
+                                            "braintrust.span_attributes"));
+            assertNotNull(spanAttrsJson, "Score span should have span_attributes");
+            JsonNode spanAttrs = JSON_MAPPER.readTree(spanAttrsJson);
+            assertEquals("score", spanAttrs.get("type").asText());
+            assertEquals("simple_scorer", spanAttrs.get("name").asText());
+            assertEquals("test-gen-1", spanAttrs.get("generation").asText());
+
+            // Verify braintrust.output_json contains scores
+            String outputJson =
+                    scoreSpan
+                            .getAttributes()
+                            .get(
+                                    io.opentelemetry.api.common.AttributeKey.stringKey(
+                                            "braintrust.output_json"));
+            assertNotNull(outputJson, "Score span should have output_json");
+            JsonNode output = JSON_MAPPER.readTree(outputJson);
+            assertTrue(output.has("simple_scorer"), "Output should contain scorer results");
+            assertEquals(0.7, output.get("simple_scorer").asDouble(), 0.001);
+        }
+
+        for (SpanData customSpan : customSpans) {
+            // Verify it has a parent span (is not a root span)
+            assertTrue(
+                    customSpan.getParentSpanContext().isValid(),
+                    "Custom span should have a valid parent span context");
+            assertNotEquals(
+                    io.opentelemetry.api.trace.SpanId.getInvalid(),
+                    customSpan.getParentSpanId(),
+                    "Custom span should not be a root span (should have parent span ID)");
+
+            // Verify it has braintrust.parent attribute from baggage propagation
+            String parent =
+                    customSpan
+                            .getAttributes()
+                            .get(
+                                    io.opentelemetry.api.common.AttributeKey.stringKey(
+                                            "braintrust.parent"));
+            assertNotNull(
+                    parent, "Custom span should have braintrust.parent attribute from baggage");
+            assertTrue(
+                    parent.contains("playground_id:test-playground-id-123"),
+                    "Custom span parent should contain playground_id from baggage propagation");
+        }
+    }
+
+    @Test
+    void testEvaluatorNotFound() throws Exception {
+        EvalRequest request = new EvalRequest();
+        request.setName("non-existent-eval");
+
+        EvalRequest.DataSpec dataSpec = new EvalRequest.DataSpec();
+        EvalRequest.EvalCaseData case1 = new EvalRequest.EvalCaseData();
+        case1.setInput("test");
+        dataSpec.setData(List.of(case1));
+        request.setData(dataSpec);
+
+        String requestJson = JSON_MAPPER.writeValueAsString(request);
+
+        HttpClient client = HttpClient.newHttpClient();
+        HttpRequest httpRequest =
+                HttpRequest.newBuilder()
+                        .uri(URI.create(TEST_URL + "/eval"))
+                        .POST(HttpRequest.BodyPublishers.ofString(requestJson))
+                        .header("Content-Type", "application/json")
+                        .header("x-bt-auth-token", "test-token")
+                        .header("x-bt-project-id", "test-project-id")
+                        .header("x-bt-org-name", "test-org")
+                        .build();
+
+        HttpResponse<String> response =
+                client.send(httpRequest, HttpResponse.BodyHandlers.ofString());
+
+        assertEquals(404, response.statusCode());
+        assertTrue(response.body().contains("Evaluator not found"));
+    }
+
+    @Test
+    void testEvalMethodNotAllowed() throws Exception {
+        HttpClient client = HttpClient.newHttpClient();
+        HttpRequest request =
+                HttpRequest.newBuilder().uri(URI.create(TEST_URL + "/eval")).GET().build();
+
+        HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
+
+        assertEquals(405, response.statusCode());
+    }
+
+    @Test
+    void testListEndpoint() throws Exception {
+        HttpClient client = HttpClient.newHttpClient();
+        HttpRequest request =
+                HttpRequest.newBuilder()
+                        .uri(URI.create(TEST_URL + "/list"))
+                        .GET()
+                        .header("x-bt-auth-token", "test-token")
+                        .header("x-bt-project-id", "test-project-id")
+                        .header("x-bt-org-name", "test-org")
+                        .build();
+
+        HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
+
+        assertEquals(200, response.statusCode());
+        assertEquals(
+                "application/json", response.headers().firstValue("Content-Type").orElse(null));
+
+        // Parse and validate JSON response
+        JsonNode root = JSON_MAPPER.readTree(response.body());
+
+        // Should have one evaluator
+        assertTrue(root.has("food-type-classifier"));
+
+        JsonNode eval = root.get("food-type-classifier");
+
+        // Check scores
+        assertTrue(eval.has("scores"));
+        JsonNode scores = eval.get("scores");
+        assertEquals(1, scores.size());
+
+        assertEquals("simple_scorer", scores.get(0).get("name").asText());
+    }
+
+    @Test
+    void testListMethodNotAllowed() throws Exception {
+        HttpClient client = HttpClient.newHttpClient();
+        HttpRequest request =
+                HttpRequest.newBuilder()
+                        .uri(URI.create(TEST_URL + "/list"))
+                        .POST(HttpRequest.BodyPublishers.noBody())
+                        .build();
+
+        HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
+
+        assertEquals(405, response.statusCode());
+    }
+
+    @Test
+    void testListEndpointWithCors() throws Exception {
+        HttpClient client = HttpClient.newHttpClient();
+        HttpRequest request =
+                HttpRequest.newBuilder()
+                        .uri(URI.create(TEST_URL + "/list"))
+                        .GET()
+                        .header("Origin", "https://www.braintrust.dev")
+                        .header("x-bt-auth-token", "test-token")
+                        .header("x-bt-project-id", "test-project-id")
+                        .header("x-bt-org-name", "test-org")
+                        .build();
+
+        HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
+
+        assertEquals(200, response.statusCode());
+        assertEquals(
+                "https://www.braintrust.dev",
+                response.headers().firstValue("Access-Control-Allow-Origin").orElse(null));
+    }
+}
diff --git a/src/test/java/dev/braintrust/devserver/LRUCacheTest.java b/src/test/java/dev/braintrust/devserver/LRUCacheTest.java
new file mode 100644
index 0000000..56b41d0
--- /dev/null
+++ b/src/test/java/dev/braintrust/devserver/LRUCacheTest.java
@@ -0,0 +1,162 @@
+package dev.braintrust.devserver;
+
+import static
org.junit.jupiter.api.Assertions.*;
+
+import java.util.concurrent.atomic.AtomicInteger;
+import org.junit.jupiter.api.Test;
+
+class LRUCacheTest {
+
+    @Test
+    void testBasicPutAndGet() {
+        LRUCache<String, String> cache = new LRUCache<>(3);
+
+        cache.put("key1", "value1");
+        cache.put("key2", "value2");
+
+        assertEquals("value1", cache.get("key1"));
+        assertEquals("value2", cache.get("key2"));
+        assertNull(cache.get("key3"));
+    }
+
+    @Test
+    void testLruEviction() {
+        LRUCache<String, String> cache = new LRUCache<>(2);
+
+        cache.put("key1", "value1");
+        cache.put("key2", "value2");
+        assertEquals(2, cache.size());
+
+        // Adding third item should evict the least recently used (key1)
+        cache.put("key3", "value3");
+        assertEquals(2, cache.size());
+        assertNull(cache.get("key1"), "key1 should have been evicted");
+        assertEquals("value2", cache.get("key2"));
+        assertEquals("value3", cache.get("key3"));
+    }
+
+    @Test
+    void testLruEvictionWithAccess() {
+        LRUCache<String, String> cache = new LRUCache<>(2);
+
+        cache.put("key1", "value1");
+        cache.put("key2", "value2");
+
+        // Access key1 to make it more recently used
+        cache.get("key1");
+
+        // Adding third item should now evict key2 (least recently used)
+        cache.put("key3", "value3");
+        assertEquals(2, cache.size());
+        assertEquals("value1", cache.get("key1"), "key1 should still be present");
+        assertNull(cache.get("key2"), "key2 should have been evicted");
+        assertEquals("value3", cache.get("key3"));
+    }
+
+    @Test
+    void testGetOrComputeCacheHit() {
+        LRUCache<String, String> cache = new LRUCache<>(3);
+        AtomicInteger computeCount = new AtomicInteger(0);
+
+        // First call should compute and cache
+        String result1 =
+                cache.getOrCompute(
+                        "key1",
+                        () -> {
+                            computeCount.incrementAndGet();
+                            return "computed-value";
+                        });
+        assertEquals("computed-value", result1);
+        assertEquals(1, computeCount.get());
+
+        // Second call should return cached value without computing
+        String result2 =
+                cache.getOrCompute(
+                        "key1",
+                        () -> {
+                            computeCount.incrementAndGet();
+                            return "should-not-be-called";
+                        });
+        assertEquals("computed-value", result2);
+        assertEquals(1, computeCount.get(), "Supplier should not have been called on cache hit");
+    }
+
+    @Test
+    void testGetOrComputeCacheMiss() {
+        LRUCache<String, String> cache = new LRUCache<>(3);
+        AtomicInteger computeCount = new AtomicInteger(0);
+
+        String result =
+                cache.getOrCompute(
+                        "key1",
+                        () -> {
+                            computeCount.incrementAndGet();
+                            return "value-" + computeCount.get();
+                        });
+
+        assertEquals("value-1", result);
+        assertEquals(1, computeCount.get());
+        assertTrue(cache.containsKey("key1"));
+        assertEquals("value-1", cache.get("key1"));
+    }
+
+    @Test
+    void testContainsKey() {
+        LRUCache<String, String> cache = new LRUCache<>(3);
+
+        assertFalse(cache.containsKey("key1"));
+
+        cache.put("key1", "value1");
+        assertTrue(cache.containsKey("key1"));
+        assertFalse(cache.containsKey("key2"));
+    }
+
+    @Test
+    void testClear() {
+        LRUCache<String, String> cache = new LRUCache<>(3);
+
+        cache.put("key1", "value1");
+        cache.put("key2", "value2");
+        assertEquals(2, cache.size());
+
+        cache.clear();
+        assertEquals(0, cache.size());
+        assertFalse(cache.containsKey("key1"));
+        assertFalse(cache.containsKey("key2"));
+    }
+
+    @Test
+    void testSize() {
+        LRUCache<String, String> cache = new LRUCache<>(5);
+
+        assertEquals(0, cache.size());
+
+        cache.put("key1", "value1");
+        assertEquals(1, cache.size());
+
+        cache.put("key2", "value2");
+        cache.put("key3", "value3");
+        assertEquals(3, cache.size());
+
+        cache.clear();
+        assertEquals(0, cache.size());
+    }
+
+    @Test
+    void testMaxSizeEnforcement() {
+        LRUCache<Integer, String> cache = new LRUCache<>(3);
+
+        cache.put(1, "one");
+        cache.put(2, "two");
+        cache.put(3, "three");
+        assertEquals(3, cache.size());
+
+        // Adding more items should maintain max size
+        cache.put(4, "four");
+        assertEquals(3, cache.size(), "Cache should not exceed max size");
+
+        cache.put(5, "five");
+        cache.put(6, "six");
+        assertEquals(3, cache.size(), "Cache should not exceed max size");
+    }
+}