|
2 | 2 |
|
3 | 3 | import org.beehive.gpullama3.aot.AOT; |
4 | 4 | import org.beehive.gpullama3.auxiliary.LastRunMetrics; |
5 | | -import org.beehive.gpullama3.core.model.tensor.FloatTensor; |
6 | | -import org.beehive.gpullama3.inference.sampler.CategoricalSampler; |
7 | 5 | import org.beehive.gpullama3.inference.sampler.Sampler; |
8 | | -import org.beehive.gpullama3.inference.sampler.ToppSampler; |
9 | 6 | import org.beehive.gpullama3.model.Model; |
10 | 7 | import org.beehive.gpullama3.model.loader.ModelLoader; |
11 | | -import org.beehive.gpullama3.tornadovm.FloatArrayUtils; |
12 | | -import uk.ac.manchester.tornado.api.types.arrays.FloatArray; |
13 | 8 |
|
14 | 9 | import java.io.IOException; |
15 | | -import java.util.random.RandomGenerator; |
16 | | -import java.util.random.RandomGeneratorFactory; |
17 | 10 |
|
18 | 11 | import static org.beehive.gpullama3.inference.sampler.Sampler.createSampler; |
| 12 | +import static org.beehive.gpullama3.model.loader.ModelLoader.loadModel; |
| 13 | + |
19 | 14 | public class LlamaApp { |
20 | 15 | // Configuration flags for hardware acceleration and optimizations |
21 | 16 | public static final boolean USE_VECTOR_API = Boolean.parseBoolean(System.getProperty("llama.VectorAPI", "true")); // Enable Java Vector API for CPU acceleration |
22 | | - public static final boolean USE_AOT = Boolean.parseBoolean(System.getProperty("llama.AOT", "false")); // Use Ahead-of-Time compilation |
23 | 17 | public static final boolean SHOW_PERF_INTERACTIVE = Boolean.parseBoolean(System.getProperty("llama.ShowPerfInteractive", "true")); // Show performance metrics in interactive mode |
24 | 18 |
|
25 | | - |
26 | | - /** |
27 | | - * Loads the language model based on the given options. |
28 | | - * <p> |
29 | | - * If Ahead-of-Time (AOT) mode is enabled, attempts to use a pre-loaded compiled model. Otherwise, loads the model from the specified path using the model loader. |
30 | | - * </p> |
31 | | - * |
32 | | - * @param options |
33 | | - * the parsed CLI options containing model path and max token limit |
34 | | - * @return the loaded {@link Model} instance |
35 | | - * @throws IOException |
36 | | - * if the model fails to load |
37 | | - * @throws IllegalStateException |
38 | | - * if AOT loading is enabled but the preloaded model is unavailable |
39 | | - */ |
40 | | - private static Model loadModel(Options options) throws IOException { |
41 | | - if (USE_AOT) { |
42 | | - Model model = AOT.tryUsePreLoaded(options.modelPath(), options.maxTokens()); |
43 | | - if (model == null) { |
44 | | - throw new IllegalStateException("Failed to load precompiled AOT model."); |
45 | | - } |
46 | | - return model; |
47 | | - } |
48 | | - return ModelLoader.loadModel(options.modelPath(), options.maxTokens(), true, options.useTornadovm()); |
49 | | - } |
50 | | - |
51 | | - private static Sampler createSampler(Model model, Options options) { |
52 | | - return selectSampler(model.configuration().vocabularySize(), options.temperature(), options.topp(), options.seed()); |
53 | | - } |
54 | | - |
55 | 19 | private static void runSingleInstruction(Model model, Sampler sampler, Options options) { |
56 | 20 | String response = model.runInstructOnce(sampler, options); |
57 | 21 | System.out.println(response); |
|
0 commit comments