Skip to content

Commit 533da39

Browse files
committed
Server: accept json inputs + return bas64 image
1 parent 8529431 commit 533da39

File tree

2 files changed

+204
-46
lines changed

2 files changed

+204
-46
lines changed

examples/server/b64.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
2+
//FROM
3+
//https://stackoverflow.com/a/34571089/5155484
4+
5+
static const std::string b = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";//=
6+
static std::string base64_encode(const std::string &in) {
7+
std::string out;
8+
9+
int val=0, valb=-6;
10+
for (uint8_t c : in) {
11+
val = (val<<8) + c;
12+
valb += 8;
13+
while (valb>=0) {
14+
out.push_back(b[(val>>valb)&0x3F]);
15+
valb-=6;
16+
}
17+
}
18+
if (valb>-6) out.push_back(b[((val<<8)>>(valb+8))&0x3F]);
19+
while (out.size()%4) out.push_back('=');
20+
return out;
21+
}
22+
23+
24+
static std::string base64_decode(const std::string &in) {
25+
26+
std::string out;
27+
28+
std::vector<int> T(256,-1);
29+
for (int i=0; i<64; i++) T[b[i]] = i;
30+
31+
int val=0, valb=-8;
32+
for (uint8_t c : in) {
33+
if (T[c] == -1) break;
34+
val = (val<<6) + T[c];
35+
valb += 6;
36+
if (valb>=0) {
37+
out.push_back(char((val>>valb)&0xFF));
38+
valb-=8;
39+
}
40+
}
41+
return out;
42+
}

examples/server/main.cpp

Lines changed: 162 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
#include <vector>
88

99
// #include "preprocessing.hpp"
10+
#include "b64.cpp"
1011
#include "flux.hpp"
12+
#include "json.hpp"
1113
#include "stable-diffusion.h"
1214

1315
#define STB_IMAGE_IMPLEMENTATION
@@ -49,7 +51,6 @@ const char* schedule_str[] = {
4951
"ays",
5052
};
5153

52-
5354
enum SDMode {
5455
TXT2IMG,
5556
IMG2IMG,
@@ -86,7 +87,6 @@ struct SDParams {
8687
int height = 512;
8788
int batch_count = 1;
8889

89-
9090
sample_method_t sample_method = EULER_A;
9191
schedule_t schedule = DEFAULT;
9292
int sample_steps = 20;
@@ -100,9 +100,9 @@ struct SDParams {
100100
bool vae_on_cpu = false;
101101
bool color = false;
102102

103-
//server things
104-
int port = 8080;
105-
std::string host = "127.0.0.1";
103+
// server things
104+
int port = 8080;
105+
std::string host = "127.0.0.1";
106106
};
107107

108108
void print_params(SDParams params) {
@@ -227,7 +227,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
227227
break;
228228
}
229229
params.vae_path = argv[i];
230-
// TODO Tiny AE
230+
// TODO Tiny AE
231231
} else if (arg == "--type") {
232232
if (++i >= argc) {
233233
invalid_arg = true;
@@ -565,27 +565,113 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
565565
fflush(out_stream);
566566
}
567567

568-
static void log_server_request(const httplib::Request & req, const httplib::Response & res) {
568+
static void log_server_request(const httplib::Request& req, const httplib::Response& res) {
569569
printf("request: %s %s (%s)\n", req.method.c_str(), req.path.c_str(), req.body.c_str());
570570
}
571571

572+
void parseJsonPrompt(std::string json_str, SDParams* params) {
573+
using namespace nlohmann;
574+
json payload = json::parse(json_str);
575+
// if no exception, the request is a json object
576+
// now we try to get the new param values from the payload object
577+
// const char *prompt, const char *negative_prompt, int clip_skip, float cfg_scale, float guidance, int width, int height, sample_method_t sample_method, int sample_steps, int64_t seed, int batch_count, const sd_image_t *control_cond, float control_strength, float style_strength, bool normalize_input, const char *input_id_images_path
578+
try {
579+
std::string prompt = payload["prompt"];
580+
params->prompt = prompt;
581+
} catch (...) {
582+
}
583+
try {
584+
std::string negative_prompt = payload["negative_prompt"];
585+
params->negative_prompt = negative_prompt;
586+
} catch (...) {
587+
}
588+
try {
589+
int clip_skip = payload["clip_skip"];
590+
params->clip_skip = clip_skip;
591+
} catch (...) {
592+
}
593+
try {
594+
float cfg_scale = payload["cfg_scale"];
595+
params->cfg_scale = cfg_scale;
596+
} catch (...) {
597+
}
598+
try {
599+
float guidance = payload["guidance"];
600+
params->guidance = guidance;
601+
} catch (...) {
602+
}
603+
try {
604+
int width = payload["width"];
605+
params->width = width;
606+
} catch (...) {
607+
}
608+
try {
609+
int height = payload["height"];
610+
params->height = height;
611+
} catch (...) {
612+
}
613+
try {
614+
std::string sample_method = payload["sample_method"];
615+
// TODO map to enum value
616+
LOG_WARN("sample_method is not supported yet\n");
617+
} catch (...) {
618+
}
619+
try {
620+
int sample_steps = payload["sample_steps"];
621+
params->sample_steps = sample_steps;
622+
} catch (...) {
623+
}
624+
try {
625+
int64_t seed = payload["seed"];
626+
params->seed = seed;
627+
} catch (...) {
628+
}
629+
try {
630+
int batch_count = payload["batch_count"];
631+
params->batch_count = batch_count;
632+
} catch (...) {
633+
}
634+
635+
try {
636+
std::string control_cond = payload["control_cond"];
637+
// TODO map to enum value
638+
LOG_WARN("control_cond is not supported yet\n");
639+
} catch (...) {
640+
}
641+
try {
642+
float control_strength = payload["control_strength"];
643+
} catch (...) {
644+
}
645+
try {
646+
float style_strength = payload["style_strength"];
647+
} catch (...) {
648+
}
649+
try {
650+
bool normalize_input = payload["normalize_input"];
651+
params->normalize_input = normalize_input;
652+
} catch (...) {
653+
}
654+
try {
655+
std::string input_id_images_path = payload["input_id_images_path"];
656+
// TODO replace with b64 image maybe?
657+
} catch (...) {
658+
}
659+
}
572660

573661
int main(int argc, const char* argv[]) {
574662
SDParams params;
575663

576664
parse_args(argc, argv, params);
577665

578-
579666
sd_set_log_callback(sd_log_cb, (void*)&params);
580667

581668
if (params.verbose) {
582669
print_params(params);
583670
printf("%s", sd_get_system_info());
584671
}
585672

673+
bool vae_decode_only = true;
586674

587-
bool vae_decode_only = true;
588-
589675
sd_ctx_t* sd_ctx = new_sd_ctx(params.model_path.c_str(),
590676
params.clip_l_path.c_str(),
591677
params.t5xxl_path.c_str(),
@@ -614,33 +700,48 @@ int main(int argc, const char* argv[]) {
614700

615701
int n_prompts = 0;
616702

617-
const auto txt2imgRequest = [&sd_ctx, &params, &n_prompts](const httplib::Request & req, httplib::Response & res) {
618-
//TODO: proper payloads
619-
std::string prompt = req.body;
620-
if(!prompt.empty()){
621-
params.prompt = prompt;
622-
}else{
623-
params.seed+=1;
703+
const auto txt2imgRequest = [&sd_ctx, &params, &n_prompts](const httplib::Request& req, httplib::Response& res) {
704+
LOG_INFO("raw body is: %s\n", req.body.c_str());
705+
// parse req.body as json using jsoncpp
706+
using json = nlohmann::json;
707+
708+
try {
709+
std::string json_str = req.body;
710+
parseJsonPrompt(json_str, &params);
711+
} catch (json::parse_error& e) {
712+
// assume the request is just a prompt
713+
LOG_WARN("Failed to parse json: %s\n Assuming it's just a prompt...\n", e.what());
714+
std::string prompt = req.body;
715+
if (!prompt.empty()) {
716+
params.prompt = prompt;
717+
} else {
718+
params.seed += 1;
719+
}
720+
} catch (...) {
721+
// Handle any other type of exception
722+
LOG_ERROR("An unexpected error occurred\n");
624723
}
724+
LOG_INFO("prompt is: %s\n", params.prompt.c_str());
725+
625726
{
626727
sd_image_t* results;
627728
results = txt2img(sd_ctx,
628-
params.prompt.c_str(),
629-
params.negative_prompt.c_str(),
630-
params.clip_skip,
631-
params.cfg_scale,
632-
params.guidance,
633-
params.width,
634-
params.height,
635-
params.sample_method,
636-
params.sample_steps,
637-
params.seed,
638-
params.batch_count,
639-
NULL,
640-
1,
641-
params.style_ratio,
642-
params.normalize_input,
643-
"");
729+
params.prompt.c_str(),
730+
params.negative_prompt.c_str(),
731+
params.clip_skip,
732+
params.cfg_scale,
733+
params.guidance,
734+
params.width,
735+
params.height,
736+
params.sample_method,
737+
params.sample_steps,
738+
params.seed,
739+
params.batch_count,
740+
NULL,
741+
1,
742+
params.style_ratio,
743+
params.normalize_input,
744+
"");
644745

645746
if (results == NULL) {
646747
printf("generate failed\n");
@@ -650,52 +751,67 @@ int main(int argc, const char* argv[]) {
650751

651752
size_t last = params.output_path.find_last_of(".");
652753
std::string dummy_name = last != std::string::npos ? params.output_path.substr(0, last) : params.output_path;
754+
json images_json = json::array();
653755
for (int i = 0; i < params.batch_count; i++) {
654756
if (results[i].data == NULL) {
655757
continue;
656758
}
657-
std::string final_image_path = i > 0 ? dummy_name + "_" + std::to_string(i + 1 + n_prompts*params.batch_count) + ".png" : dummy_name + ".png";
759+
// TODO allow disable save to disk
760+
std::string final_image_path = i > 0 ? dummy_name + "_" + std::to_string(i + 1 + n_prompts * params.batch_count) + ".png" : dummy_name + ".png";
658761
stbi_write_png(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel,
659-
results[i].data, 0, get_image_params(params, params.seed + i).c_str());
762+
results[i].data, 0, get_image_params(params, params.seed + i).c_str());
660763
printf("save result image to '%s'\n", final_image_path.c_str());
661-
// Todo: return base64 encoded image via websocket?
764+
// Todo: return base64 encoded image via httplib::Response& res
765+
766+
int len;
767+
unsigned char* png = stbi_write_png_to_mem((const unsigned char*)results[i].data, 0, results[i].width, results[i].height, results[i].channel, &len, NULL);
768+
769+
std::string data_str(png, png + len);
770+
std::string encoded_img = base64_encode(data_str);
771+
772+
images_json.push_back({{"width", results[i].width},
773+
{"height", results[i].height},
774+
{"channel", results[i].channel},
775+
{"data", encoded_img},
776+
{"encoding", "png"}});
777+
662778
free(results[i].data);
663779
results[i].data = NULL;
664780
}
665781
free(results);
666782
n_prompts++;
783+
res.set_content(images_json.dump(), "application/json");
667784
}
668785
return 0;
669786
};
670787

671-
672788
std::unique_ptr<httplib::Server> svr;
673789
svr.reset(new httplib::Server());
674790
svr->set_default_headers({{"Server", "sd.cpp"}});
675791
// CORS preflight
676-
svr->Options(R"(.*)", [](const httplib::Request &, httplib::Response & res) {
792+
svr->Options(R"(.*)", [](const httplib::Request&, httplib::Response& res) {
677793
// Access-Control-Allow-Origin is already set by middleware
678794
res.set_header("Access-Control-Allow-Credentials", "true");
679-
res.set_header("Access-Control-Allow-Methods", "POST");
680-
res.set_header("Access-Control-Allow-Headers", "*");
681-
return res.set_content("", "text/html"); // blank response, no data
795+
res.set_header("Access-Control-Allow-Methods", "POST");
796+
res.set_header("Access-Control-Allow-Headers", "*");
797+
return res.set_content("", "text/html"); // blank response, no data
682798
});
683799
svr->set_logger(log_server_request);
684800

685801
svr->Post("/txt2img", txt2imgRequest);
686802

687-
688803
// bind HTTP listen port, run the HTTP server in a thread
689804
if (!svr->bind_to_port(params.host, params.port)) {
690-
//TODO: Error message
805+
// TODO: Error message
691806
return 1;
692-
}
807+
}
693808
std::thread t([&]() { svr->listen_after_bind(); });
694809
svr->wait_until_ready();
695810

696-
printf("Server listening at %s:%d\n",params.host.c_str(),params.port);
811+
printf("Server listening at %s:%d\n", params.host.c_str(), params.port);
697812

698-
while(1);
813+
while (1)
814+
;
699815

700816
free_sd_ctx(sd_ctx);
701817

0 commit comments

Comments
 (0)