diff --git a/build.gradle b/build.gradle index f722964..ad1b078 100644 --- a/build.gradle +++ b/build.gradle @@ -46,6 +46,7 @@ ext { jacksonVersion = '2.16.1' junitVersion = '5.11.4' slf4jVersion = '2.0.17' + langchainVersion = '1.9.1' } dependencies { @@ -86,6 +87,14 @@ dependencies { // Google GenAI Instrumentation compileOnly "com.google.genai:google-genai:1.20.0" testImplementation "com.google.genai:google-genai:1.20.0" + + // LangChain4j Instrumentation + compileOnly "dev.langchain4j:langchain4j:${langchainVersion}" + compileOnly "dev.langchain4j:langchain4j-http-client:${langchainVersion}" + compileOnly "dev.langchain4j:langchain4j-open-ai:${langchainVersion}" + testImplementation "dev.langchain4j:langchain4j:${langchainVersion}" + testImplementation "dev.langchain4j:langchain4j-http-client:${langchainVersion}" + testImplementation "dev.langchain4j:langchain4j-open-ai:${langchainVersion}" } /** diff --git a/examples/build.gradle b/examples/build.gradle index d1e747c..f4e9c20 100644 --- a/examples/build.gradle +++ b/examples/build.gradle @@ -35,6 +35,9 @@ dependencies { implementation('org.springframework.boot:spring-boot-starter:3.4.1') { exclude group: 'org.springframework.boot', module: 'spring-boot-starter-logging' } + // to run langchain4j examples + implementation 'dev.langchain4j:langchain4j:1.9.1' + implementation 'dev.langchain4j:langchain4j-open-ai:1.9.1' } application { @@ -142,7 +145,6 @@ task runSpringAI(type: JavaExec) { } } - task runRemoteEval(type: JavaExec) { group = 'Braintrust SDK Examples' description = 'Run the remote eval example' @@ -156,3 +158,17 @@ task runRemoteEval(type: JavaExec) { suspend = false } } + +task runLangchain(type: JavaExec) { + group = 'Braintrust SDK Examples' + description = 'Run the LangChain4j instrumentation example. NOTE: this requires OPENAI_API_KEY to be exported and will make a small call to openai, using your tokens' + classpath = sourceSets.main.runtimeClasspath + mainClass = 'dev.braintrust.examples.LangchainExample' + systemProperty 'org.slf4j.simpleLogger.log.dev.braintrust', braintrustLogLevel + debugOptions { + enabled = true + port = 5566 + server = true + suspend = false + } +} diff --git a/examples/src/main/java/dev/braintrust/examples/LangchainExample.java b/examples/src/main/java/dev/braintrust/examples/LangchainExample.java new file mode 100644 index 0000000..117272f --- /dev/null +++ b/examples/src/main/java/dev/braintrust/examples/LangchainExample.java @@ -0,0 +1,54 @@ +package dev.braintrust.examples; + +import dev.braintrust.Braintrust; +import dev.braintrust.instrumentation.langchain.BraintrustLangchain; +import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.model.chat.ChatModel; +import dev.langchain4j.model.openai.OpenAiChatModel; + +/** Basic OTel + LangChain4j instrumentation example */ +public class LangchainExample { + + public static void main(String[] args) throws Exception { + if (null == System.getenv("OPENAI_API_KEY")) { + System.err.println( + "\nWARNING envar OPENAI_API_KEY not found. 
This example will likely fail.\n"); + } + var braintrust = Braintrust.get(); + var openTelemetry = braintrust.openTelemetryCreate(); + + ChatModel model = + BraintrustLangchain.wrap( + openTelemetry, + OpenAiChatModel.builder() + .apiKey(System.getenv("OPENAI_API_KEY")) + .modelName("gpt-4o-mini") + .temperature(0.0)); + + var rootSpan = + openTelemetry + .getTracer("my-instrumentation") + .spanBuilder("langchain4j-instrumentation-example") + .startSpan(); + try (var ignored = rootSpan.makeCurrent()) { + chatExample(model); + } finally { + rootSpan.end(); + } + var url = + braintrust.projectUri() + + "/logs?r=%s&s=%s" + .formatted( + rootSpan.getSpanContext().getTraceId(), + rootSpan.getSpanContext().getSpanId()); + System.out.println( + "\n\n Example complete! View your data in Braintrust: %s\n".formatted(url)); + } + + private static void chatExample(ChatModel model) { + var message = UserMessage.from("What is the capital of France?"); + var response = model.chat(message); + System.out.println( + "\n~~~ LANGCHAIN4J CHAT RESPONSE: %s\n".formatted(response.aiMessage().text())); + } +} diff --git a/src/main/java/dev/braintrust/instrumentation/langchain/BraintrustLangchain.java b/src/main/java/dev/braintrust/instrumentation/langchain/BraintrustLangchain.java new file mode 100644 index 0000000..a3aafa8 --- /dev/null +++ b/src/main/java/dev/braintrust/instrumentation/langchain/BraintrustLangchain.java @@ -0,0 +1,66 @@ +package dev.braintrust.instrumentation.langchain; + +import dev.langchain4j.http.client.HttpClientBuilder; +import dev.langchain4j.http.client.HttpClientBuilderLoader; +import dev.langchain4j.model.openai.OpenAiChatModel; +import dev.langchain4j.model.openai.OpenAiStreamingChatModel; +import io.opentelemetry.api.OpenTelemetry; +import lombok.extern.slf4j.Slf4j; + +/** Braintrust LangChain4j client instrumentation. 
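+ *
+ * <p>Minimal usage sketch, assuming an {@code OpenTelemetry} instance has already been created
+ * (see {@code LangchainExample} in the examples module for a runnable version):
+ *
+ * <pre>{@code
+ * ChatModel model =
+ *         BraintrustLangchain.wrap(
+ *                 openTelemetry,
+ *                 OpenAiChatModel.builder()
+ *                         .apiKey(System.getenv("OPENAI_API_KEY"))
+ *                         .modelName("gpt-4o-mini"));
+ * }</pre>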
*/ +@Slf4j +public final class BraintrustLangchain { + /** Instrument langchain openai chat model with braintrust traces */ + public static OpenAiChatModel wrap( + OpenTelemetry otel, OpenAiChatModel.OpenAiChatModelBuilder builder) { + try { + HttpClientBuilder underlyingHttpClient = getPrivateField(builder, "httpClientBuilder"); + if (underlyingHttpClient == null) { + underlyingHttpClient = HttpClientBuilderLoader.loadHttpClientBuilder(); + } + HttpClientBuilder wrappedHttpClient = + wrap(otel, underlyingHttpClient, new Options("openai")); + return builder.httpClientBuilder(wrappedHttpClient).build(); + } catch (Exception e) { + log.warn( + "Braintrust instrumentation could not be applied to OpenAiChatModel builder", + e); + return builder.build(); + } + } + + /** Instrument langchain openai chat model with braintrust traces */ + public static OpenAiStreamingChatModel wrap( + OpenTelemetry otel, OpenAiStreamingChatModel.OpenAiStreamingChatModelBuilder builder) { + try { + HttpClientBuilder underlyingHttpClient = getPrivateField(builder, "httpClientBuilder"); + if (underlyingHttpClient == null) { + underlyingHttpClient = HttpClientBuilderLoader.loadHttpClientBuilder(); + } + HttpClientBuilder wrappedHttpClient = + wrap(otel, underlyingHttpClient, new Options("openai")); + return builder.httpClientBuilder(wrappedHttpClient).build(); + } catch (Exception e) { + log.warn( + "Braintrust instrumentation could not be applied to OpenAiStreamingChatModel" + + " builder", + e); + return builder.build(); + } + } + + private static HttpClientBuilder wrap( + OpenTelemetry otel, HttpClientBuilder builder, Options options) { + return new WrappedHttpClientBuilder(otel, builder, options); + } + + public record Options(String providerName) {} + + @SuppressWarnings("unchecked") + private static T getPrivateField(Object obj, String fieldName) + throws ReflectiveOperationException { + java.lang.reflect.Field field = obj.getClass().getDeclaredField(fieldName); + field.setAccessible(true); + return (T) field.get(obj); + } +} diff --git a/src/main/java/dev/braintrust/instrumentation/langchain/WrappedHttpClient.java b/src/main/java/dev/braintrust/instrumentation/langchain/WrappedHttpClient.java new file mode 100644 index 0000000..d9215f7 --- /dev/null +++ b/src/main/java/dev/braintrust/instrumentation/langchain/WrappedHttpClient.java @@ -0,0 +1,387 @@ +package dev.braintrust.instrumentation.langchain; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import dev.braintrust.trace.BraintrustTracing; +import dev.langchain4j.exception.HttpException; +import dev.langchain4j.http.client.HttpClient; +import dev.langchain4j.http.client.HttpRequest; +import dev.langchain4j.http.client.SuccessfulHttpResponse; +import dev.langchain4j.http.client.sse.ServerSentEvent; +import dev.langchain4j.http.client.sse.ServerSentEventContext; +import dev.langchain4j.http.client.sse.ServerSentEventListener; +import dev.langchain4j.http.client.sse.ServerSentEventParser; +import io.opentelemetry.api.OpenTelemetry; +import io.opentelemetry.api.trace.Span; +import io.opentelemetry.api.trace.SpanKind; +import io.opentelemetry.api.trace.StatusCode; +import io.opentelemetry.api.trace.Tracer; +import io.opentelemetry.context.Scope; +import java.util.HashMap; +import java.util.Map; +import lombok.SneakyThrows; +import lombok.extern.slf4j.Slf4j; + +@Slf4j +class WrappedHttpClient implements HttpClient { + private static final ObjectMapper JSON_MAPPER = new ObjectMapper(); + + private final 
Tracer tracer; + private final HttpClient underlying; + private final BraintrustLangchain.Options options; + + public WrappedHttpClient( + OpenTelemetry openTelemetry, + HttpClient underlying, + BraintrustLangchain.Options options) { + this.tracer = BraintrustTracing.getTracer(openTelemetry); + this.underlying = underlying; + this.options = options; + } + + @Override + public SuccessfulHttpResponse execute(HttpRequest request) + throws HttpException, RuntimeException { + ProviderInfo providerInfo = + new ProviderInfo(options.providerName(), extractEndpoint(request)); + Span span = startNewSpan(getSpanName(providerInfo)); + try (Scope scope = span.makeCurrent()) { + tagSpan(span, request, providerInfo); + final long startTime = System.nanoTime(); + var response = underlying.execute(request); + final long endTime = System.nanoTime(); + double timeToFirstToken = (endTime - startTime) / 1_000_000_000.0; + tagSpan(span, response, providerInfo, timeToFirstToken); + return response; + } catch (Throwable t) { + tagSpan(span, t); + throw t; + } finally { + span.end(); + } + } + + @Override + public void execute(HttpRequest request, ServerSentEventListener listener) { + if (listener instanceof WrappedServerSentEventListener) { + // we've already applied instrumentation + underlying.execute(request, listener); + return; + } + ProviderInfo providerInfo = + new ProviderInfo(options.providerName(), extractEndpoint(request)); + Span span = startNewSpan(getSpanName(providerInfo)); + try (Scope scope = span.makeCurrent()) { + tagSpan(span, request, providerInfo); + underlying.execute( + request, new WrappedServerSentEventListener(listener, span, providerInfo)); + } catch (Throwable t) { + // unlikely to happen, but just in case + tagSpan(span, t); + span.end(); + throw t; + } + } + + @Override + public void execute( + HttpRequest request, ServerSentEventParser parser, ServerSentEventListener listener) { + if (listener instanceof WrappedServerSentEventListener) { + // we've already applied instrumentation + underlying.execute(request, parser, listener); + return; + } + ProviderInfo providerInfo = + new ProviderInfo(options.providerName(), extractEndpoint(request)); + Span span = startNewSpan(getSpanName(providerInfo)); + try { + tagSpan(span, request, providerInfo); + underlying.execute( + request, + parser, + new WrappedServerSentEventListener(listener, span, providerInfo)); + } catch (Throwable t) { + // unlikely to happen, but just in case + tagSpan(span, t); + span.end(); + throw t; + } + } + + /** Extract endpoint path from the request URL. */ + private static String extractEndpoint(HttpRequest request) { + try { + java.net.URI uri = new java.net.URI(request.url()); + return uri.getPath(); + } catch (Exception e) { + log.debug("Failed to parse URL: {}", request.url(), e); + return ""; + } + } + + /** Get span name based on the provider and endpoint. */ + private static String getSpanName(ProviderInfo info) { + if (info.endpoint.contains("/chat/completions") + || info.endpoint.contains("/v1/completions")) { + return "Chat Completion"; + } else if (info.endpoint.contains("/embeddings")) { + return "Embeddings"; + } else if (info.endpoint.contains("/messages")) { + return "Messages"; + } + return info.endpoint(); + } + + private Span startNewSpan(String spanName) { + return tracer.spanBuilder(spanName).setSpanKind(SpanKind.CLIENT).startSpan(); + } + + /** Tag span with request data: input messages, model, provider. 
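+     *
+     * <p>For a chat/completions request the resulting span carries attributes along these lines
+     * (attribute names come from this class; the example values are illustrative only):
+     *
+     * <pre>
+     * braintrust.span_attributes = {"type":"llm"}
+     * braintrust.input_json      = [{"role":"user","content":"What is the capital of France?"}]
+     * braintrust.metadata        = {"provider":"openai","model":"gpt-4o-mini"}
+     * </pre>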
*/ + private static void tagSpan(Span span, HttpRequest request, ProviderInfo providerInfo) { + try { + span.setAttribute("braintrust.span_attributes", json(Map.of("type", "llm"))); + + // Build metadata map + Map metadata = new HashMap<>(); + metadata.put("provider", providerInfo.provider); + + // Parse request body to extract model and messages + String body = request.body(); + if (body != null && !body.isEmpty()) { + JsonNode requestJson = JSON_MAPPER.readTree(body); + + // Extract model + if (requestJson.has("model")) { + String model = requestJson.get("model").asText(); + metadata.put("model", model); + } + + // Extract messages array for input + if (requestJson.has("messages")) { + String messagesJson = json(requestJson.get("messages")); + span.setAttribute("braintrust.input_json", messagesJson); + } + } + + // Serialize metadata as JSON + span.setAttribute("braintrust.metadata", json(metadata)); + } catch (Exception e) { + log.debug("Failed to parse request for span tagging", e); + } + } + + /** Tag span with response data: output messages, usage metrics. */ + private static void tagSpan( + Span span, + SuccessfulHttpResponse response, + ProviderInfo providerInfo, + double timeToFirstToken) { + try { + // Build metrics map + Map metrics = new HashMap<>(); + metrics.put("time_to_first_token", timeToFirstToken); + + String body = response.body(); + if (body != null && !body.isEmpty()) { + JsonNode responseJson = JSON_MAPPER.readTree(body); + + // Extract choices array for output + if (responseJson.has("choices")) { + String choicesJson = json(responseJson.get("choices")); + span.setAttribute("braintrust.output_json", choicesJson); + } + + // Extract usage metrics if present + if (responseJson.has("usage")) { + JsonNode usage = responseJson.get("usage"); + if (usage.has("prompt_tokens")) { + metrics.put("prompt_tokens", usage.get("prompt_tokens").asLong()); + } + if (usage.has("completion_tokens")) { + metrics.put("completion_tokens", usage.get("completion_tokens").asLong()); + } + if (usage.has("total_tokens")) { + metrics.put("tokens", usage.get("total_tokens").asLong()); + } + } + } + + span.setAttribute("braintrust.metrics", json(metrics)); + } catch (Exception e) { + log.debug("Failed to parse response for span tagging", e); + } + } + + /** Tag span with error information. */ + private static void tagSpan(Span span, Throwable t) { + span.setStatus(StatusCode.ERROR, t.getMessage()); + span.recordException(t); + } + + @SneakyThrows + private static String json(Object o) { + return JSON_MAPPER.writeValueAsString(o); + } + + /** + * Wraps a ServerSentEventListener to properly end the span when streaming completes or errors. + * Also buffers streaming chunks to extract usage data. 
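+     *
+     * <p>Assumes OpenAI-style SSE chunks, roughly of this shape (values illustrative; usage, when
+     * reported, typically arrives in the final data chunk):
+     *
+     * <pre>
+     * data: {"choices":[{"index":0,"delta":{"content":"Paris"},"finish_reason":null}]}
+     * data: {"choices":[],"usage":{"prompt_tokens":20,"completion_tokens":8,"total_tokens":28}}
+     * data: [DONE]
+     * </pre>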
+ */ + private static class WrappedServerSentEventListener implements ServerSentEventListener { + private final ServerSentEventListener delegate; + private final Span span; + private final ProviderInfo providerInfo; + private final StringBuilder outputBuffer = new StringBuilder(); + private long firstTokenTime = 0; + private final long startTime; + private JsonNode usageData = null; + + WrappedServerSentEventListener( + ServerSentEventListener delegate, Span span, ProviderInfo providerInfo) { + this.delegate = delegate; + this.span = span; + this.providerInfo = providerInfo; + this.startTime = System.nanoTime(); + } + + @Override + public void onOpen(SuccessfulHttpResponse response) { + delegate.onOpen(response); + } + + @Override + public void onEvent(ServerSentEvent event, ServerSentEventContext context) { + instrumentEvent(event); + delegate.onEvent(event, context); + } + + @Override + public void onEvent(ServerSentEvent event) { + instrumentEvent(event); + delegate.onEvent(event); + } + + private void instrumentEvent(ServerSentEvent event) { + String data = event.data(); + if (data == null || data.isEmpty() || "[DONE]".equals(data)) { + return; + } + + // Track time to first token + if (firstTokenTime == 0) { + firstTokenTime = System.nanoTime(); + } + + // Buffer the data for final processing + try { + JsonNode chunk = JSON_MAPPER.readTree(data); + + // For streaming, we accumulate deltas into the complete message + // Just track if we have any content + if (chunk.has("choices") && chunk.get("choices").size() > 0) { + JsonNode choice = chunk.get("choices").get(0); + if (choice.has("delta")) { + JsonNode delta = choice.get("delta"); + if (delta.has("content")) { + String content = delta.get("content").asText(); + outputBuffer.append(content); + } + } + } + + // Extract usage data if present (usually in the last chunk) + if (chunk.has("usage")) { + usageData = chunk.get("usage"); + } + } catch (Exception e) { + log.debug("Failed to parse streaming event: {}", data, e); + } + } + + @Override + public void onError(Throwable error) { + try { + delegate.onError(error); + } finally { + tagSpan(span, error); + finalizeSpan(); + span.end(); + } + } + + @Override + public void onClose() { + try { + delegate.onClose(); + } finally { + finalizeSpan(); + span.end(); + } + } + + private void finalizeSpan() { + // Build metrics map for streaming + Map metrics = new HashMap<>(); + + // Add time to first token if we have it + if (firstTokenTime > 0) { + double timeToFirstToken = (firstTokenTime - startTime) / 1_000_000_000.0; + metrics.put("time_to_first_token", timeToFirstToken); + } + + // Reconstruct output as a choices array for streaming + // Format: [{"index": 0, "finish_reason": "stop", "message": {"role": "assistant", + // "content": "..."}}] + if (outputBuffer.length() > 0) { + try { + // Create a proper choice object matching OpenAI API format + var choiceBuilder = JSON_MAPPER.createObjectNode(); + choiceBuilder.put("index", 0); + choiceBuilder.put("finish_reason", "stop"); + + var messageNode = JSON_MAPPER.createObjectNode(); + messageNode.put("role", "assistant"); + messageNode.put("content", outputBuffer.toString()); + + choiceBuilder.set("message", messageNode); + + var choicesArray = JSON_MAPPER.createArrayNode(); + choicesArray.add(choiceBuilder); + + span.setAttribute("braintrust.output_json", choicesArray.toString()); + } catch (Exception e) { + log.debug("Failed to reconstruct streaming output", e); + } + } + + // Set usage metrics if we collected them + if (usageData != null) { + 
try { + if (usageData.has("prompt_tokens")) { + metrics.put("prompt_tokens", usageData.get("prompt_tokens").asLong()); + } + if (usageData.has("completion_tokens")) { + metrics.put( + "completion_tokens", usageData.get("completion_tokens").asLong()); + } + if (usageData.has("total_tokens")) { + metrics.put("tokens", usageData.get("total_tokens").asLong()); + } + } catch (Exception e) { + log.debug("Failed to extract usage metrics from streaming data", e); + } + } + + // Serialize metrics as JSON + try { + if (!metrics.isEmpty()) { + span.setAttribute("braintrust.metrics", json(metrics)); + } + } catch (Exception e) { + log.debug("Failed to serialize metrics", e); + } + } + } + + private record ProviderInfo(String provider, String endpoint) {} +} diff --git a/src/main/java/dev/braintrust/instrumentation/langchain/WrappedHttpClientBuilder.java b/src/main/java/dev/braintrust/instrumentation/langchain/WrappedHttpClientBuilder.java new file mode 100644 index 0000000..a78aa39 --- /dev/null +++ b/src/main/java/dev/braintrust/instrumentation/langchain/WrappedHttpClientBuilder.java @@ -0,0 +1,48 @@ +package dev.braintrust.instrumentation.langchain; + +import dev.langchain4j.http.client.HttpClient; +import dev.langchain4j.http.client.HttpClientBuilder; +import io.opentelemetry.api.OpenTelemetry; +import java.time.Duration; + +class WrappedHttpClientBuilder implements HttpClientBuilder { + private final OpenTelemetry openTelemetry; + private final HttpClientBuilder underlying; + private final BraintrustLangchain.Options options; + + public WrappedHttpClientBuilder( + OpenTelemetry openTelemetry, + HttpClientBuilder underlying, + BraintrustLangchain.Options options) { + this.openTelemetry = openTelemetry; + this.underlying = underlying; + this.options = options; + } + + @Override + public Duration connectTimeout() { + return underlying.connectTimeout(); + } + + @Override + public HttpClientBuilder connectTimeout(Duration timeout) { + underlying.connectTimeout(timeout); + return this; + } + + @Override + public Duration readTimeout() { + return underlying.readTimeout(); + } + + @Override + public HttpClientBuilder readTimeout(Duration timeout) { + underlying.readTimeout(timeout); + return this; + } + + @Override + public HttpClient build() { + return new WrappedHttpClient(openTelemetry, underlying.build(), options); + } +} diff --git a/src/test/java/dev/braintrust/TestHarness.java b/src/test/java/dev/braintrust/TestHarness.java index 3597e12..aa139b7 100644 --- a/src/test/java/dev/braintrust/TestHarness.java +++ b/src/test/java/dev/braintrust/TestHarness.java @@ -1,6 +1,7 @@ package dev.braintrust; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; import dev.braintrust.api.BraintrustApiClient; import dev.braintrust.config.BraintrustConfig; @@ -22,6 +23,7 @@ import java.util.concurrent.atomic.AtomicReference; import javax.annotation.Nonnull; import lombok.Getter; +import lombok.SneakyThrows; import lombok.experimental.Accessors; public class TestHarness { @@ -111,6 +113,30 @@ public List awaitExportedSpans() { return spanExporter.getFinishedSpanItems(); } + /** + * flush all pending spans and return all spans which have been exported so far + * + *
if fewer spans have been exported, sleep briefly and
repeat the process until the number of exported spans equals or exceeds `minSpanCount` + */ + @SneakyThrows + public List awaitExportedSpans(int minSpanCount) { + var spans = awaitExportedSpans(); + int attempts = 0; + while (spans.size() < minSpanCount) { + attempts++; + if (attempts > 30) { + fail( + String.format( + "Timeout waiting for spans: expected at least %d spans, but got %d" + + " after %d attempts", + minSpanCount, spans.size(), attempts)); + } + Thread.sleep(1000); + spans = awaitExportedSpans(); + } + return spans; + } + private static BraintrustApiClient.InMemoryImpl createApiClient() { var orgInfo = new dev.braintrust.api.BraintrustApiClient.OrganizationInfo( diff --git a/src/test/java/dev/braintrust/instrumentation/langchain/BraintrustLangchainTest.java b/src/test/java/dev/braintrust/instrumentation/langchain/BraintrustLangchainTest.java new file mode 100644 index 0000000..9d9c983 --- /dev/null +++ b/src/test/java/dev/braintrust/instrumentation/langchain/BraintrustLangchainTest.java @@ -0,0 +1,305 @@ +package dev.braintrust.instrumentation.langchain; + +import static com.github.tomakehurst.wiremock.client.WireMock.*; +import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig; +import static org.junit.jupiter.api.Assertions.*; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.github.tomakehurst.wiremock.junit5.WireMockExtension; +import dev.braintrust.TestHarness; +import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.model.chat.ChatModel; +import dev.langchain4j.model.chat.StreamingChatModel; +import dev.langchain4j.model.chat.response.ChatResponse; +import dev.langchain4j.model.chat.response.StreamingChatResponseHandler; +import dev.langchain4j.model.openai.OpenAiChatModel; +import dev.langchain4j.model.openai.OpenAiStreamingChatModel; +import io.opentelemetry.api.common.AttributeKey; +import java.util.concurrent.CompletableFuture; +import lombok.SneakyThrows; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +public class BraintrustLangchainTest { + + @RegisterExtension + static WireMockExtension wireMock = + WireMockExtension.newInstance().options(wireMockConfig().dynamicPort()).build(); + + private static final ObjectMapper JSON_MAPPER = new ObjectMapper(); + + private TestHarness testHarness; + + @BeforeEach + void beforeEach() { + testHarness = TestHarness.setup(); + wireMock.resetAll(); + } + + @Test + @SneakyThrows + void testSyncChatCompletion() { + // Mock the OpenAI API response + wireMock.stubFor( + post(urlEqualTo("/v1/chat/completions")) + .willReturn( + aResponse() + .withStatus(200) + .withHeader("Content-Type", "application/json") + .withBody( + """ + { + "id": "chatcmpl-test123", + "object": "chat.completion", + "created": 1677652288, + "model": "gpt-4o-mini", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "The capital of France is Paris." 
+ }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 20, + "completion_tokens": 8, + "total_tokens": 28 + } + } + """))); + + // Create LangChain4j client with Braintrust instrumentation + ChatModel model = + BraintrustLangchain.wrap( + testHarness.openTelemetry(), + OpenAiChatModel.builder() + .apiKey("test-api-key") + .baseUrl("http://localhost:" + wireMock.getPort() + "/v1") + .modelName("gpt-4o-mini") + .temperature(0.0)); + + // Execute chat request + var message = UserMessage.from("What is the capital of France?"); + var response = model.chat(message); + + // Verify the response + assertNotNull(response); + assertEquals("The capital of France is Paris.", response.aiMessage().text()); + wireMock.verify(1, postRequestedFor(urlEqualTo("/v1/chat/completions"))); + + // Verify spans were exported + var spans = testHarness.awaitExportedSpans(); + assertEquals(1, spans.size(), "Expected one span for sync chat completion"); + var span = spans.get(0); + + // Verify span name + assertEquals("Chat Completion", span.getName(), "Span name should be 'Chat Completion'"); + + // Verify span attributes + var attributes = span.getAttributes(); + var braintrustSpanAttributesJson = + attributes.get(AttributeKey.stringKey("braintrust.span_attributes")); + + // Verify span type + JsonNode spanAttributes = JSON_MAPPER.readTree(braintrustSpanAttributesJson); + assertEquals("llm", spanAttributes.get("type").asText(), "Span type should be 'llm'"); + + // Verify metadata + String metadataJson = attributes.get(AttributeKey.stringKey("braintrust.metadata")); + assertNotNull(metadataJson, "Metadata should be present"); + JsonNode metadata = JSON_MAPPER.readTree(metadataJson); + assertEquals("openai", metadata.get("provider").asText(), "Provider should be 'openai'"); + assertEquals( + "gpt-4o-mini", metadata.get("model").asText(), "Model should be 'gpt-4o-mini'"); + + // Verify metrics + String metricsJson = attributes.get(AttributeKey.stringKey("braintrust.metrics")); + assertNotNull(metricsJson, "Metrics should be present"); + JsonNode metrics = JSON_MAPPER.readTree(metricsJson); + assertEquals(28, metrics.get("tokens").asLong(), "Total tokens should be 28"); + assertEquals(20, metrics.get("prompt_tokens").asLong(), "Prompt tokens should be 20"); + assertEquals(8, metrics.get("completion_tokens").asLong(), "Completion tokens should be 8"); + assertTrue( + metrics.has("time_to_first_token"), "Metrics should contain time_to_first_token"); + assertTrue( + metrics.get("time_to_first_token").isNumber(), + "time_to_first_token should be a number"); + + // Verify input + String inputJson = attributes.get(AttributeKey.stringKey("braintrust.input_json")); + assertNotNull(inputJson, "Input should be present"); + JsonNode input = JSON_MAPPER.readTree(inputJson); + assertTrue(input.isArray(), "Input should be an array"); + assertTrue(input.size() > 0, "Input array should not be empty"); + assertTrue( + input.get(0).get("content").asText().contains("What is the capital of France"), + "Input should contain the user message"); + + // Verify output + String outputJson = attributes.get(AttributeKey.stringKey("braintrust.output_json")); + assertNotNull(outputJson, "Output should be present"); + JsonNode output = JSON_MAPPER.readTree(outputJson); + assertTrue(output.isArray(), "Output should be an array"); + assertTrue(output.size() > 0, "Output array should not be empty"); + assertTrue( + output.get(0) + .get("message") + .get("content") + .asText() + .contains("The capital of France is Paris"), + "Output 
should contain the assistant response"); + } + + @Test + @SneakyThrows + void testStreamingChatCompletion() { + // Mock the OpenAI API streaming response + String streamingResponse = + """ + data: {"id":"chatcmpl-test123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-4o-mini","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]} + + data: {"id":"chatcmpl-test123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-4o-mini","choices":[{"index":0,"delta":{"content":"The"},"finish_reason":null}]} + + data: {"id":"chatcmpl-test123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-4o-mini","choices":[{"index":0,"delta":{"content":" capital"},"finish_reason":null}]} + + data: {"id":"chatcmpl-test123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-4o-mini","choices":[{"index":0,"delta":{"content":" of"},"finish_reason":null}]} + + data: {"id":"chatcmpl-test123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-4o-mini","choices":[{"index":0,"delta":{"content":" France"},"finish_reason":null}]} + + data: {"id":"chatcmpl-test123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-4o-mini","choices":[{"index":0,"delta":{"content":" is"},"finish_reason":null}]} + + data: {"id":"chatcmpl-test123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-4o-mini","choices":[{"index":0,"delta":{"content":" Paris"},"finish_reason":null}]} + + data: {"id":"chatcmpl-test123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-4o-mini","choices":[{"index":0,"delta":{"content":"."},"finish_reason":null}]} + + data: {"id":"chatcmpl-test123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-4o-mini","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]} + + data: {"id":"chatcmpl-test123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-4o-mini","choices":[],"usage":{"prompt_tokens":20,"completion_tokens":8,"total_tokens":28}} + + data: [DONE] + + """; + + wireMock.stubFor( + post(urlEqualTo("/v1/chat/completions")) + .willReturn( + aResponse() + .withStatus(200) + .withHeader("Content-Type", "text/event-stream") + .withBody(streamingResponse))); + + // Create LangChain4j streaming client with Braintrust instrumentation + StreamingChatModel model = + BraintrustLangchain.wrap( + testHarness.openTelemetry(), + OpenAiStreamingChatModel.builder() + .apiKey("test-api-key") + .baseUrl("http://localhost:" + wireMock.getPort() + "/v1") + .modelName("gpt-4o-mini") + .temperature(0.0)); + + // Execute streaming chat request + var future = new CompletableFuture(); + var responseBuilder = new StringBuilder(); + + model.chat( + "What is the capital of France?", + new StreamingChatResponseHandler() { + @Override + public void onPartialResponse(String token) { + responseBuilder.append(token); + } + + @Override + public void onCompleteResponse(ChatResponse response) { + future.complete(response); + } + + @Override + public void onError(Throwable error) { + future.completeExceptionally(error); + } + }); + + // Wait for completion + var response = future.get(); + + // Verify the response + assertNotNull(response); + assertEquals("The capital of France is Paris.", responseBuilder.toString()); + wireMock.verify(1, postRequestedFor(urlEqualTo("/v1/chat/completions"))); + + // Verify spans were exported + var spans = testHarness.awaitExportedSpans(1); + assertEquals(1, spans.size(), "Expected one span for streaming chat 
completion"); + var span = spans.get(0); + + // Verify span name + assertEquals("Chat Completion", span.getName(), "Span name should be 'Chat Completion'"); + + // Verify span attributes + var attributes = span.getAttributes(); + + var braintrustSpanAttributesJson = + attributes.get(AttributeKey.stringKey("braintrust.span_attributes")); + + // Verify span type + JsonNode spanAttributes = JSON_MAPPER.readTree(braintrustSpanAttributesJson); + assertEquals("llm", spanAttributes.get("type").asText(), "Span type should be 'llm'"); + + // Verify metadata + String metadataJson = attributes.get(AttributeKey.stringKey("braintrust.metadata")); + assertNotNull(metadataJson, "Metadata should be present"); + JsonNode metadata = JSON_MAPPER.readTree(metadataJson); + assertEquals("openai", metadata.get("provider").asText(), "Provider should be 'openai'"); + assertEquals( + "gpt-4o-mini", metadata.get("model").asText(), "Model should be 'gpt-4o-mini'"); + + // Verify metrics for streaming + String metricsJson = attributes.get(AttributeKey.stringKey("braintrust.metrics")); + assertNotNull(metricsJson, "Metrics should be present"); + JsonNode metrics = JSON_MAPPER.readTree(metricsJson); + assertEquals(28, metrics.get("tokens").asLong(), "Total tokens should be 28"); + assertEquals(20, metrics.get("prompt_tokens").asLong(), "Prompt tokens should be 20"); + assertEquals(8, metrics.get("completion_tokens").asLong(), "Completion tokens should be 8"); + assertTrue( + metrics.has("time_to_first_token"), + "Metrics should contain time_to_first_token for streaming"); + assertTrue( + metrics.get("time_to_first_token").isNumber(), + "time_to_first_token should be a number"); + + // Verify input + String inputJson = attributes.get(AttributeKey.stringKey("braintrust.input_json")); + assertNotNull(inputJson, "Input should be present"); + JsonNode input = JSON_MAPPER.readTree(inputJson); + assertTrue(input.isArray(), "Input should be an array"); + assertTrue(input.size() > 0, "Input array should not be empty"); + assertTrue( + input.get(0).get("content").asText().contains("What is the capital of France"), + "Input should contain the user message"); + + // Verify output (streaming reconstructs the output) + String outputJson = attributes.get(AttributeKey.stringKey("braintrust.output_json")); + assertNotNull(outputJson, "Output should be present"); + JsonNode output = JSON_MAPPER.readTree(outputJson); + assertTrue(output.isArray(), "Output should be an array"); + assertTrue(output.size() > 0, "Output array should not be empty"); + JsonNode choice = output.get(0); + assertTrue( + choice.get("message") + .get("content") + .asText() + .contains("The capital of France is Paris"), + "Output should contain the complete streamed response"); + assertEquals( + "stop", + choice.get("finish_reason").asText(), + "Output should have finish_reason 'stop'"); + } +}