From 225bf6be9399be6dc05c15aaed8c62b229b60176 Mon Sep 17 00:00:00 2001 From: Jared Van Bortel Date: Mon, 25 Nov 2024 10:04:17 -0500 Subject: [PATCH] Remove binary state from high-level API and use Jinja templates (#3147) Signed-off-by: Jared Van Bortel Signed-off-by: Adam Treat Co-authored-by: Adam Treat --- .gitmodules | 3 + .../include/gpt4all-backend/llmodel.h | 68 +- .../include/gpt4all-backend/llmodel_c.h | 44 +- gpt4all-backend/src/llamamodel.cpp | 98 +- gpt4all-backend/src/llamamodel_impl.h | 19 +- gpt4all-backend/src/llmodel.cpp | 6 + gpt4all-backend/src/llmodel_c.cpp | 87 +- gpt4all-backend/src/llmodel_shared.cpp | 320 ++--- .../docs/gpt4all_desktop/chat_templates.md | 206 ++++ gpt4all-bindings/python/gpt4all/_pyllmodel.py | 183 ++- gpt4all-bindings/python/gpt4all/gpt4all.py | 307 ++--- gpt4all-bindings/python/mkdocs.yml | 1 + gpt4all-bindings/python/setup.py | 3 +- gpt4all-chat/CMakeLists.txt | 6 +- gpt4all-chat/deps/CMakeLists.txt | 2 + gpt4all-chat/deps/Jinja2Cpp | 1 + gpt4all-chat/icons/edit.svg | 4 +- gpt4all-chat/metadata/models3.json | 66 +- gpt4all-chat/qml/ApplicationSettings.qml | 19 +- gpt4all-chat/qml/ChatItemView.qml | 261 ++-- gpt4all-chat/qml/ChatMessageButton.qml | 20 + gpt4all-chat/qml/ChatView.qml | 146 ++- gpt4all-chat/qml/ConfirmationDialog.qml | 59 + gpt4all-chat/qml/LocalDocsSettings.qml | 2 +- gpt4all-chat/qml/ModelSettings.qml | 234 +++- gpt4all-chat/qml/MySettingsButton.qml | 2 + gpt4all-chat/qml/MySettingsLabel.qml | 43 +- gpt4all-chat/qml/MySettingsTab.qml | 14 +- gpt4all-chat/qml/MyTextArea.qml | 15 +- gpt4all-chat/qml/MyToolButton.qml | 2 + gpt4all-chat/qml/SwitchModelDialog.qml | 46 - gpt4all-chat/qml/Theme.qml | 49 +- gpt4all-chat/src/chat.cpp | 104 +- gpt4all-chat/src/chat.h | 24 +- gpt4all-chat/src/chatapi.cpp | 173 ++- gpt4all-chat/src/chatapi.h | 78 +- gpt4all-chat/src/chatlistmodel.cpp | 18 +- gpt4all-chat/src/chatllm.cpp | 1091 ++++++++--------- gpt4all-chat/src/chatllm.h | 98 +- gpt4all-chat/src/chatmodel.h | 514 ++++++-- gpt4all-chat/src/jinja_helpers.cpp | 111 ++ gpt4all-chat/src/jinja_helpers.h | 116 ++ gpt4all-chat/src/jinja_helpers.inl | 17 + gpt4all-chat/src/main.cpp | 21 +- gpt4all-chat/src/modellist.cpp | 256 ++-- gpt4all-chat/src/modellist.h | 86 +- gpt4all-chat/src/mysettings.cpp | 209 +++- gpt4all-chat/src/mysettings.h | 67 +- gpt4all-chat/src/network.cpp | 18 +- gpt4all-chat/src/server.cpp | 214 ++-- gpt4all-chat/src/server.h | 2 +- gpt4all-chat/src/utils.h | 40 +- gpt4all-chat/src/utils.inl | 9 + gpt4all-chat/tests/python/test_server_api.py | 23 +- 54 files changed, 3412 insertions(+), 2213 deletions(-) create mode 100644 gpt4all-bindings/python/docs/gpt4all_desktop/chat_templates.md create mode 160000 gpt4all-chat/deps/Jinja2Cpp create mode 100644 gpt4all-chat/qml/ChatMessageButton.qml create mode 100644 gpt4all-chat/qml/ConfirmationDialog.qml delete mode 100644 gpt4all-chat/qml/SwitchModelDialog.qml create mode 100644 gpt4all-chat/src/jinja_helpers.cpp create mode 100644 gpt4all-chat/src/jinja_helpers.h create mode 100644 gpt4all-chat/src/jinja_helpers.inl create mode 100644 gpt4all-chat/src/utils.inl diff --git a/.gitmodules b/.gitmodules index 752b837abd51..c177a0e12f19 100644 --- a/.gitmodules +++ b/.gitmodules @@ -17,3 +17,6 @@ [submodule "gpt4all-chat/deps/QXlsx"] path = gpt4all-chat/deps/QXlsx url = https://github.com/nomic-ai/QXlsx.git +[submodule "gpt4all-chat/deps/Jinja2Cpp"] + path = gpt4all-chat/deps/Jinja2Cpp + url = https://github.com/nomic-ai/jinja2cpp.git diff --git a/gpt4all-backend/include/gpt4all-backend/llmodel.h b/gpt4all-backend/include/gpt4all-backend/llmodel.h index ed5c4878fc6c..8695a5b5d832 100644 --- a/gpt4all-backend/include/gpt4all-backend/llmodel.h +++ b/gpt4all-backend/include/gpt4all-backend/llmodel.h @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -24,6 +25,10 @@ using namespace std::string_literals; class LLModel { public: using Token = int32_t; + using PromptCallback = std::function batch, bool cached)>; + using ResponseCallback = std::function; + using EmbedCancelCallback = bool(unsigned *batchSizes, unsigned nBatch, const char *backend); + using ProgressCallback = std::function; class BadArchError: public std::runtime_error { public: @@ -101,6 +106,7 @@ class LLModel { static int32_t maxContextLength(const std::string &modelPath); static int32_t layerCount(const std::string &modelPath); static bool isEmbeddingModel(const std::string &modelPath); + static auto chatTemplate(const char *modelPath) -> std::expected; static void setImplementationsSearchPath(const std::string &path); static const std::string &implementationsSearchPath(); static bool hasSupportedCPU(); @@ -124,7 +130,6 @@ class LLModel { }; struct PromptContext { - int32_t n_past = 0; // number of tokens in past conversation int32_t n_predict = 200; int32_t top_k = 40; float top_p = 0.9f; @@ -136,8 +141,6 @@ class LLModel { float contextErase = 0.5f; // percent of context to erase if we exceed the context window }; - using ProgressCallback = std::function; - explicit LLModel() {} virtual ~LLModel() {} @@ -154,16 +157,12 @@ class LLModel { // This method requires the model to return true from supportsCompletion otherwise it will throw // an error - virtual void prompt(const std::string &prompt, - const std::string &promptTemplate, - std::function promptCallback, - std::function responseCallback, - bool allowContextShift, - PromptContext &ctx, - bool special = false, - std::optional fakeReply = {}); + virtual void prompt(std::string_view prompt, + const PromptCallback &promptCallback, + const ResponseCallback &responseCallback, + const PromptContext &ctx); - using EmbedCancelCallback = bool(unsigned *batchSizes, unsigned nBatch, const char *backend); + virtual int32_t countPromptTokens(std::string_view prompt) const; virtual size_t embeddingSize() const { throw std::logic_error(std::string(implementation().modelType()) + " does not support embeddings"); @@ -209,23 +208,22 @@ class LLModel { void setProgressCallback(ProgressCallback callback) { m_progressCallback = callback; } virtual int32_t contextLength() const = 0; + virtual auto specialTokens() -> std::unordered_map const = 0; protected: // These are pure virtual because subclasses need to implement as the default implementation of // 'prompt' above calls these functions - virtual std::vector tokenize(std::string_view str, bool special = false) = 0; + virtual std::vector tokenize(std::string_view str) const = 0; virtual bool isSpecialToken(Token id) const = 0; virtual std::string tokenToString(Token id) const = 0; - virtual void initSampler(PromptContext &ctx) = 0; + virtual void initSampler(const PromptContext &ctx) = 0; virtual Token sampleToken() const = 0; - virtual bool evalTokens(PromptContext &ctx, std::span tokens) const = 0; - virtual void shiftContext(PromptContext &promptCtx) = 0; + virtual bool evalTokens(int32_t nPast, std::span tokens) const = 0; + virtual void shiftContext(const PromptContext &promptCtx, int32_t *nPast) = 0; virtual int32_t inputLength() const = 0; - virtual void setTokenizeInputPosition(int32_t pos) = 0; - virtual auto computeModelInputPosition(PromptContext &ctx, const std::vector &input) - -> std::vector::const_iterator = 0; - virtual void setModelInputPosition(PromptContext &ctx, int32_t pos) = 0; - virtual void appendInputToken(PromptContext &ctx, Token tok) = 0; + virtual int32_t computeModelInputPosition(std::span input) const = 0; + virtual void setModelInputPosition(int32_t pos) = 0; + virtual void appendInputToken(Token tok) = 0; virtual std::span inputTokens() const = 0; virtual const std::vector &endTokens() const = 0; virtual bool shouldAddBOS() const = 0; @@ -242,6 +240,12 @@ class LLModel { return -1; } + virtual auto chatTemplate(const char *modelPath) const -> std::expected + { + (void)modelPath; + return std::unexpected("not implemented"); + } + const Implementation *m_implementation = nullptr; ProgressCallback m_progressCallback; @@ -253,19 +257,15 @@ class LLModel { return true; } - bool decodePrompt(std::function promptCallback, - std::function responseCallback, - bool allowContextShift, - PromptContext &promptCtx, - std::vector embd_inp, - bool isResponse = false, - bool alwaysDecode = false); - void generateResponse(std::function responseCallback, - bool allowContextShift, - PromptContext &promptCtx); - -protected: - Token m_tokenize_last_token = -1; // not serialized + // prefill context with prompt + auto decodePrompt(const PromptCallback &promptCallback, + const PromptContext &promptCtx, + std::vector embd_inp) + -> std::optional; + // generate a response + void generateResponse(const ResponseCallback &responseCallback, + const PromptContext &promptCtx, + int32_t nPast); friend class LLMImplementation; }; diff --git a/gpt4all-backend/include/gpt4all-backend/llmodel_c.h b/gpt4all-backend/include/gpt4all-backend/llmodel_c.h index e9497d0fafd9..271475bae480 100644 --- a/gpt4all-backend/include/gpt4all-backend/llmodel_c.h +++ b/gpt4all-backend/include/gpt4all-backend/llmodel_c.h @@ -35,16 +35,15 @@ typedef int32_t token_t; * behavior. */ struct llmodel_prompt_context { - int32_t n_past; // number of tokens in past conversation int32_t n_predict; // number of tokens to predict int32_t top_k; // top k logits to sample from - float top_p; // nucleus sampling probability threshold - float min_p; // Min P sampling - float temp; // temperature to adjust model's output distribution + float top_p; // nucleus sampling probability threshold + float min_p; // Min P sampling + float temp; // temperature to adjust model's output distribution int32_t n_batch; // number of predictions to generate in parallel - float repeat_penalty; // penalty factor for repeated tokens + float repeat_penalty; // penalty factor for repeated tokens int32_t repeat_last_n; // last n tokens to penalize - float context_erase; // percent of context to erase if we exceed the context window + float context_erase; // percent of context to erase if we exceed the context window }; struct llmodel_gpu_device { @@ -63,10 +62,12 @@ typedef struct llmodel_gpu_device llmodel_gpu_device; /** * Callback type for prompt processing. - * @param token_id The token id of the prompt. + * @param token_ids An array of token ids of the prompt. + * @param n_token_ids The number of tokens in the array. + * @param cached Whether the tokens were already in cache. * @return a bool indicating whether the model should keep processing. */ -typedef bool (*llmodel_prompt_callback)(int32_t token_id); +typedef bool (*llmodel_prompt_callback)(const token_t *token_ids, size_t n_token_ids, bool cached); /** * Callback type for response. @@ -74,7 +75,7 @@ typedef bool (*llmodel_prompt_callback)(int32_t token_id); * @param response The response string. NOTE: a token_id of -1 indicates the string is an error string. * @return a bool indicating whether the model should keep generating. */ -typedef bool (*llmodel_response_callback)(int32_t token_id, const char *response); +typedef bool (*llmodel_response_callback)(token_t token_id, const char *response); /** * Embedding cancellation callback for use with llmodel_embed. @@ -85,6 +86,8 @@ typedef bool (*llmodel_response_callback)(int32_t token_id, const char *response */ typedef bool (*llmodel_emb_cancel_callback)(unsigned *batch_sizes, unsigned n_batch, const char *backend); +typedef void (*llmodel_special_token_callback)(const char *name, const char *token); + /** * Create a llmodel instance. * Recognises correct model type from file at model_path @@ -183,22 +186,17 @@ uint64_t llmodel_state_set_data(llmodel_model model, const uint8_t *state, uint6 * Generate a response using the model. * @param model A pointer to the llmodel_model instance. * @param prompt A string representing the input prompt. - * @param prompt_template A string representing the input prompt template. * @param prompt_callback A callback function for handling the processing of prompt. * @param response_callback A callback function for handling the generated response. - * @param allow_context_shift Whether to allow shifting of context to make room for more input. - * @param special True if special tokens in the prompt should be processed, false otherwise. - * @param fake_reply A string to insert into context as the model's reply, or NULL to generate one. * @param ctx A pointer to the llmodel_prompt_context structure. + * @param error A pointer to a string; will only be set on error. */ -void llmodel_prompt(llmodel_model model, const char *prompt, - const char *prompt_template, - llmodel_prompt_callback prompt_callback, - llmodel_response_callback response_callback, - bool allow_context_shift, - llmodel_prompt_context *ctx, - bool special, - const char *fake_reply); +bool llmodel_prompt(llmodel_model model, + const char *prompt, + llmodel_prompt_callback prompt_callback, + llmodel_response_callback response_callback, + llmodel_prompt_context *ctx, + const char **error); /** * Generate an embedding using the model. @@ -310,6 +308,10 @@ const char *llmodel_model_backend_name(llmodel_model model); */ const char *llmodel_model_gpu_device_name(llmodel_model model); +int32_t llmodel_count_prompt_tokens(llmodel_model model, const char *prompt, const char **error); + +void llmodel_model_foreach_special_token(llmodel_model model, llmodel_special_token_callback callback); + #ifdef __cplusplus } #endif diff --git a/gpt4all-backend/src/llamamodel.cpp b/gpt4all-backend/src/llamamodel.cpp index 453dbd972bd8..af03af81e26c 100644 --- a/gpt4all-backend/src/llamamodel.cpp +++ b/gpt4all-backend/src/llamamodel.cpp @@ -202,7 +202,7 @@ static int32_t get_arch_key_u32(std::string const &modelPath, std::string const if (keyidx != -1) { value = gguf_get_val_u32(ctx, keyidx); } else { - std::cerr << __func__ << ": " << key << "not found in " << modelPath << "\n"; + std::cerr << __func__ << ": " << key << " not found in " << modelPath << "\n"; } } @@ -518,18 +518,13 @@ size_t LLamaModel::restoreState(std::span state, std::span LLamaModel::tokenize(std::string_view str, bool special) +std::vector LLamaModel::tokenize(std::string_view str) const { - bool atStart = m_tokenize_last_token == -1; - bool insertSpace = atStart || isSpecialToken(m_tokenize_last_token); std::vector fres(str.length() + 4); - int32_t fres_len = llama_tokenize_gpt4all( - d_ptr->model, str.data(), str.length(), fres.data(), fres.size(), /*add_special*/ atStart, - /*parse_special*/ special, /*insert_space*/ insertSpace + int32_t fres_len = llama_tokenize( + d_ptr->model, str.data(), str.length(), fres.data(), fres.size(), /*add_special*/ true, /*parse_special*/ true ); fres.resize(fres_len); - if (fres_len) - m_tokenize_last_token = fres.back(); return fres; } @@ -555,7 +550,7 @@ std::string LLamaModel::tokenToString(Token id) const return std::string(result.data(), result.size()); } -void LLamaModel::initSampler(PromptContext &promptCtx) +void LLamaModel::initSampler(const PromptContext &promptCtx) { auto *model = d_ptr->model; auto *chain = d_ptr->sampler_chain; @@ -601,9 +596,11 @@ LLModel::Token LLamaModel::sampleToken() const return llama_sampler_sample(d_ptr->sampler_chain, d_ptr->ctx, -1); } -bool LLamaModel::evalTokens(PromptContext &ctx, std::span tokens) const +bool LLamaModel::evalTokens(int32_t nPast, std::span tokens) const { - llama_kv_cache_seq_rm(d_ptr->ctx, 0, ctx.n_past, -1); + assert(!tokens.empty()); + + llama_kv_cache_seq_rm(d_ptr->ctx, 0, nPast, -1); llama_batch batch = llama_batch_init(tokens.size(), 0, 1); @@ -611,7 +608,7 @@ bool LLamaModel::evalTokens(PromptContext &ctx, std::span tokens) c for (int32_t i = 0; i < batch.n_tokens; i++) { batch.token [i] = tokens[i]; - batch.pos [i] = ctx.n_past + i; + batch.pos [i] = nPast + i; batch.n_seq_id[i] = 1; batch.seq_id [i][0] = 0; batch.logits [i] = false; @@ -625,13 +622,13 @@ bool LLamaModel::evalTokens(PromptContext &ctx, std::span tokens) c return res == 0; } -void LLamaModel::shiftContext(PromptContext &promptCtx) +void LLamaModel::shiftContext(const PromptContext &promptCtx, int32_t *nPast) { // infinite text generation via context shifting // erase up to n_ctx*contextErase tokens int n_keep = shouldAddBOS(); - int n_past = promptCtx.n_past; + int n_past = *nPast; int n_discard = std::min(n_past - n_keep, int(contextLength() * promptCtx.contextErase)); assert(n_discard > 0); @@ -647,7 +644,7 @@ void LLamaModel::shiftContext(PromptContext &promptCtx) auto &inp = d_ptr->inputTokens; inp.erase(inp.begin() + n_keep, inp.begin() + n_keep + n_discard); - promptCtx.n_past = inp.size(); + *nPast = inp.size(); } int32_t LLamaModel::contextLength() const @@ -655,39 +652,37 @@ int32_t LLamaModel::contextLength() const return llama_n_ctx(d_ptr->ctx); } -int32_t LLamaModel::inputLength() const +auto LLamaModel::specialTokens() -> std::unordered_map const { - return d_ptr->inputTokens.size(); + if (!d_ptr->model) + throw std::logic_error("model not loaded"); + + std::unordered_map tokens; + if (auto id = llama_token_bos(d_ptr->model); id != LLAMA_TOKEN_NULL) + tokens.emplace("bos_token", tokenToString(id)); + if (auto id = llama_token_eos(d_ptr->model); id != LLAMA_TOKEN_NULL) + tokens.emplace("eos_token", tokenToString(id)); + return tokens; } -void LLamaModel::setTokenizeInputPosition(int32_t pos) +int32_t LLamaModel::inputLength() const { - assert(pos >= 0); - m_tokenize_last_token = pos ? d_ptr->inputTokens.at(size_t(pos) - 1) : -1; // not serialized + return d_ptr->inputTokens.size(); } -auto LLamaModel::computeModelInputPosition(PromptContext &ctx, const std::vector &input) - -> std::vector::const_iterator +int32_t LLamaModel::computeModelInputPosition(std::span input) const { - assert(ctx.n_past >= 0); - auto pos = size_t(ctx.n_past); - if (pos > d_ptr->inputTokens.size()) { - std::ostringstream ss; - ss << "n_past=" << pos << " is past end of token cache length=" << d_ptr->inputTokens.size(); - throw std::out_of_range(ss.str()); - } - // find common prefix auto cacheIt = d_ptr->inputTokens.begin(); auto inputIt = input.begin(); while (cacheIt < d_ptr->inputTokens.end() && inputIt < input.end() && *cacheIt == *inputIt) { - ++cacheIt; ++inputIt; ++pos; + ++cacheIt; ++inputIt; } // tell the caller to ignore the tokens between [begin, inputIt) - return inputIt; + return inputIt - input.begin(); } -void LLamaModel::setModelInputPosition(PromptContext &ctx, int32_t pos) +void LLamaModel::setModelInputPosition(int32_t pos) { auto &inp = d_ptr->inputTokens; assert(pos >= 0); @@ -695,13 +690,11 @@ void LLamaModel::setModelInputPosition(PromptContext &ctx, int32_t pos) // truncate token cache to end at the new n_past if (pos < inp.size()) inp.resize(pos); - ctx.n_past = pos; } -void LLamaModel::appendInputToken(PromptContext &ctx, Token tok) +void LLamaModel::appendInputToken(Token tok) { d_ptr->inputTokens.push_back(tok); - ctx.n_past += 1; } auto LLamaModel::inputTokens() const -> std::span @@ -729,6 +722,37 @@ int32_t LLamaModel::layerCount(std::string const &modelPath) const return get_arch_key_u32(modelPath, "block_count"); } +// TODO(jared): reduce redundant code and operations by combining all metadata getters for unloaded +// models into a class that keeps the model file open +auto LLamaModel::chatTemplate(const char *modelPath) const -> std::expected +{ + auto *ctx = load_gguf(modelPath); + if (!ctx) + return std::unexpected("failed to open model file"); + + std::expected result; + enum gguf_type ktype; + const int kid = gguf_find_key(ctx, "tokenizer.chat_template"); + if (kid == -1) { + result = std::unexpected("key not found"); + goto cleanup; + } + + ktype = gguf_get_kv_type(ctx, kid); + if (ktype != GGUF_TYPE_STRING) { + result = std::unexpected( + "expected key type STRING (" + std::to_string(GGUF_TYPE_STRING) + "), got " + std::to_string(ktype) + ); + goto cleanup; + } + + result = gguf_get_val_str(ctx, kid); + +cleanup: + gguf_free(ctx); + return result; +} + #ifdef GGML_USE_VULKAN static const char *getVulkanVendorName(uint32_t vendorID) { diff --git a/gpt4all-backend/src/llamamodel_impl.h b/gpt4all-backend/src/llamamodel_impl.h index d6290a061316..7d018ddb1083 100644 --- a/gpt4all-backend/src/llamamodel_impl.h +++ b/gpt4all-backend/src/llamamodel_impl.h @@ -11,6 +11,7 @@ #include #include #include +#include struct LLamaPrivate; struct EmbModelSpec; @@ -49,26 +50,26 @@ class LLamaModel : public LLModel { size_t *tokenCount = nullptr, bool doMean = true, bool atlas = false) override; int32_t contextLength() const override; + auto specialTokens() -> std::unordered_map const override; protected: - std::vector tokenize(std::string_view str, bool special) override; + std::vector tokenize(std::string_view str) const override; bool isSpecialToken(Token id) const override; std::string tokenToString(Token id) const override; - void initSampler(PromptContext &ctx) override; + void initSampler(const PromptContext &ctx) override; Token sampleToken() const override; - bool evalTokens(PromptContext &ctx, std::span tokens) const override; - void shiftContext(PromptContext &promptCtx) override; + bool evalTokens(int32_t nPast, std::span tokens) const override; + void shiftContext(const PromptContext &promptCtx, int32_t *nPast) override; int32_t inputLength() const override; - void setTokenizeInputPosition(int32_t pos) override; - auto computeModelInputPosition(PromptContext &ctx, const std::vector &input) - -> std::vector::const_iterator override; - void setModelInputPosition(PromptContext &ctx, int32_t pos) override; - void appendInputToken(PromptContext &ctx, Token tok) override; + int32_t computeModelInputPosition(std::span input) const override; + void setModelInputPosition(int32_t pos) override; + void appendInputToken(Token tok) override; std::span inputTokens() const override; const std::vector &endTokens() const override; bool shouldAddBOS() const override; int32_t maxContextLength(std::string const &modelPath) const override; int32_t layerCount(std::string const &modelPath) const override; + auto chatTemplate(const char *modelPath) const -> std::expected override; void embedInternal(const std::vector &texts, float *embeddings, std::string prefix, int dimensionality, size_t *tokenCount, bool doMean, bool atlas, EmbedCancelCallback *cancelCb, diff --git a/gpt4all-backend/src/llmodel.cpp b/gpt4all-backend/src/llmodel.cpp index 1acf0642ef2a..ee247f35c9d2 100644 --- a/gpt4all-backend/src/llmodel.cpp +++ b/gpt4all-backend/src/llmodel.cpp @@ -326,6 +326,12 @@ bool LLModel::Implementation::isEmbeddingModel(const std::string &modelPath) return llama && llama->isEmbeddingModel(modelPath); } +auto LLModel::Implementation::chatTemplate(const char *modelPath) -> std::expected +{ + auto *llama = constructGlobalLlama(); + return llama ? llama->chatTemplate(modelPath) : std::unexpected("backend not available"); +} + void LLModel::Implementation::setImplementationsSearchPath(const std::string& path) { s_implementations_search_path = path; diff --git a/gpt4all-backend/src/llmodel_c.cpp b/gpt4all-backend/src/llmodel_c.cpp index 068052665f39..a8c5554da0d3 100644 --- a/gpt4all-backend/src/llmodel_c.cpp +++ b/gpt4all-backend/src/llmodel_c.cpp @@ -7,7 +7,6 @@ #include #include #include -#include #include #include #include @@ -22,7 +21,6 @@ static_assert(sizeof(token_t) == sizeof(LLModel::Token)); struct LLModelWrapper { LLModel *llModel = nullptr; - LLModel::PromptContext promptContext; ~LLModelWrapper() { delete llModel; } }; @@ -126,49 +124,44 @@ uint64_t llmodel_state_set_data(llmodel_model model, const uint8_t *state, uint6 return wrapper->llModel->restoreState({state, size_t(state_size)}, {input_tokens, size_t(n_input_tokens)}); } -void llmodel_prompt(llmodel_model model, const char *prompt, - const char *prompt_template, - llmodel_prompt_callback prompt_callback, - llmodel_response_callback response_callback, - bool allow_context_shift, - llmodel_prompt_context *ctx, - bool special, - const char *fake_reply) +bool llmodel_prompt(llmodel_model model, + const char *prompt, + llmodel_prompt_callback prompt_callback, + llmodel_response_callback response_callback, + llmodel_prompt_context *ctx, + const char **error) { auto *wrapper = static_cast(model); - auto response_func = [response_callback](int32_t token_id, const std::string &response) { - return response_callback(token_id, response.c_str()); + // Copy the C prompt context + LLModel::PromptContext promptContext { + .n_predict = ctx->n_predict, + .top_k = ctx->top_k, + .top_p = ctx->top_p, + .min_p = ctx->min_p, + .temp = ctx->temp, + .n_batch = ctx->n_batch, + .repeat_penalty = ctx->repeat_penalty, + .repeat_last_n = ctx->repeat_last_n, + .contextErase = ctx->context_erase, }; - // Copy the C prompt context - wrapper->promptContext.n_past = ctx->n_past; - wrapper->promptContext.n_predict = ctx->n_predict; - wrapper->promptContext.top_k = ctx->top_k; - wrapper->promptContext.top_p = ctx->top_p; - wrapper->promptContext.min_p = ctx->min_p; - wrapper->promptContext.temp = ctx->temp; - wrapper->promptContext.n_batch = ctx->n_batch; - wrapper->promptContext.repeat_penalty = ctx->repeat_penalty; - wrapper->promptContext.repeat_last_n = ctx->repeat_last_n; - wrapper->promptContext.contextErase = ctx->context_erase; + auto prompt_func = [prompt_callback](std::span token_ids, bool cached) { + return prompt_callback(token_ids.data(), token_ids.size(), cached); + }; + auto response_func = [response_callback](LLModel::Token token_id, std::string_view piece) { + return response_callback(token_id, piece.data()); + }; // Call the C++ prompt method - wrapper->llModel->prompt(prompt, prompt_template, prompt_callback, response_func, allow_context_shift, - wrapper->promptContext, special, - fake_reply ? std::make_optional(fake_reply) : std::nullopt); - - // Update the rest of the C prompt context - ctx->n_past = wrapper->promptContext.n_past; - ctx->n_predict = wrapper->promptContext.n_predict; - ctx->top_k = wrapper->promptContext.top_k; - ctx->top_p = wrapper->promptContext.top_p; - ctx->min_p = wrapper->promptContext.min_p; - ctx->temp = wrapper->promptContext.temp; - ctx->n_batch = wrapper->promptContext.n_batch; - ctx->repeat_penalty = wrapper->promptContext.repeat_penalty; - ctx->repeat_last_n = wrapper->promptContext.repeat_last_n; - ctx->context_erase = wrapper->promptContext.contextErase; + try { + wrapper->llModel->prompt(prompt, prompt_func, response_func, promptContext); + } catch (std::exception const &e) { + llmodel_set_error(error, e.what()); + return false; + } + + return true; } float *llmodel_embed( @@ -307,3 +300,21 @@ const char *llmodel_model_gpu_device_name(llmodel_model model) const auto *wrapper = static_cast(model); return wrapper->llModel->gpuDeviceName(); } + +int32_t llmodel_count_prompt_tokens(llmodel_model model, const char *prompt, const char **error) +{ + auto *wrapper = static_cast(model); + try { + return wrapper->llModel->countPromptTokens(prompt); + } catch (const std::exception& e) { + llmodel_set_error(error, e.what()); + return -1; + } +} + +void llmodel_model_foreach_special_token(llmodel_model model, llmodel_special_token_callback callback) +{ + auto *wrapper = static_cast(model); + for (auto &[name, token] : wrapper->llModel->specialTokens()) + callback(name.c_str(), token.c_str()); +} diff --git a/gpt4all-backend/src/llmodel_shared.cpp b/gpt4all-backend/src/llmodel_shared.cpp index ef046433a217..99782f44ca55 100644 --- a/gpt4all-backend/src/llmodel_shared.cpp +++ b/gpt4all-backend/src/llmodel_shared.cpp @@ -4,232 +4,120 @@ #include #include #include -#include #include #include #include -#include -#include +#include #include #include #include #include namespace ranges = std::ranges; +namespace views = std::ranges::views; -static bool parsePromptTemplate(const std::string &tmpl, std::vector &placeholders, std::string &err) -{ - static const std::regex placeholderRegex(R"(%[1-2](?![0-9]))"); +void LLModel::prompt( + std::string_view prompt, + const PromptCallback &promptCallback, + const ResponseCallback &responseCallback, + const PromptContext &promptCtx +) { + if (!isModelLoaded()) + throw std::invalid_argument("Attempted to prompt an unloaded model."); + if (!supportsCompletion()) + throw std::invalid_argument("Not a text completion model."); + if (!promptCtx.n_batch) + throw std::invalid_argument("Batch size cannot be zero."); + if (!promptCtx.n_predict) + return; // nothing requested - auto it = std::sregex_iterator(tmpl.begin(), tmpl.end(), placeholderRegex); - placeholders.clear(); - placeholders.insert(placeholders.end(), it, std::sregex_iterator()); + auto embd_inp = tokenize(prompt); + if (embd_inp.empty()) + throw std::invalid_argument("Prompt tokenized to zero tokens."); - if (placeholders.size() > 2) { - err = "ERROR: expected at most two placeholders, got " + std::to_string(placeholders.size()); - return false; - } - if (placeholders.size() >= 1 && placeholders[0].str() != "%1") { - err = "ERROR: first placeholder must be %1, got " + placeholders[0].str(); - return false; - } - if (placeholders.size() >= 2 && placeholders[1].str() != "%2") { - err = "ERROR: second placeholder must be %2, got " + placeholders[1].str(); - return false; - } - return true; + if (auto res = decodePrompt(promptCallback, promptCtx, std::move(embd_inp))) + generateResponse(responseCallback, promptCtx, /*n_past*/ *res); } -void LLModel::prompt(const std::string &prompt, - const std::string &promptTemplate, - std::function promptCallback, - std::function responseCallback, - bool allowContextShift, - PromptContext &promptCtx, - bool special, - std::optional fakeReply) +int32_t LLModel::countPromptTokens(std::string_view prompt) const { - if (!isModelLoaded()) { - std::cerr << implementation().modelType() << " ERROR: prompt won't work with an unloaded model!\n"; - return; - } - - if (!supportsCompletion()) { - std::string errorMessage = "ERROR: this model does not support text completion or chat!"; - responseCallback(-1, errorMessage); - std::cerr << implementation().modelType() << " " << errorMessage << "\n"; - return; - } + if (!isModelLoaded()) + throw std::invalid_argument("Attempted to tokenize with an unloaded model."); + return int32_t(tokenize(prompt).size()); +} - // sanity checks - if (promptCtx.n_past > contextLength()) { - std::ostringstream ss; - ss << "n_past=" << promptCtx.n_past << " is past end of context length=" << contextLength(); - throw std::out_of_range(ss.str()); - } - if (promptCtx.n_past > inputLength()) { - std::ostringstream ss; - ss << "n_past=" << promptCtx.n_past << " is past end of token cache length=" << inputLength(); - throw std::out_of_range(ss.str()); - } +auto LLModel::decodePrompt( + const PromptCallback &promptCallback, + const PromptContext &promptCtx, + std::vector embd_inp +) -> std::optional +{ + assert(!embd_inp.empty()); - promptCtx.n_batch = std::min(promptCtx.n_batch, LLMODEL_MAX_PROMPT_BATCH); + int32_t nCtx = contextLength(); + int32_t n_batch = std::min(promptCtx.n_batch, LLMODEL_MAX_PROMPT_BATCH); - // parse the prompt template - std::vector placeholders; - { - std::string err; - if (!parsePromptTemplate(promptTemplate, placeholders, err)) { - responseCallback(-1, err); - std::cerr << err << "\n"; - return; - } - } + // Find the greatest n_past where the beginning of embd_inp matches the end of the token cache, starting at the + // requested n_past. + // This is used to skip unnecessary work when the prompt shares a common prefix with the previous result. + int32_t nPast = computeModelInputPosition(embd_inp); - setTokenizeInputPosition(promptCtx.n_past); - - // tokenize the user prompt - std::vector embd_inp; - if (placeholders.empty()) { - // this is unusual, but well-defined - std::cerr << __func__ << ": prompt template has no placeholder\n"; - embd_inp = tokenize(promptTemplate, true); - } else { - // template: beginning of user prompt - const auto &phUser = placeholders[0]; - std::string userPrefix(phUser.prefix()); - if (!userPrefix.empty()) - embd_inp = tokenize(userPrefix, true); - - // user input (shouldn't have special token processing) - auto tokens = tokenize(prompt, special); - embd_inp.insert(embd_inp.end(), tokens.begin(), tokens.end()); - - // template: end of user prompt + start of assistant prompt - size_t start = phUser.position() + phUser.length(); - size_t end = placeholders.size() >= 2 ? placeholders[1].position() : promptTemplate.length(); - auto userToAsst = promptTemplate.substr(start, end - start); - if (!userToAsst.empty()) { - tokens = tokenize(userToAsst, true); - embd_inp.insert(embd_inp.end(), tokens.begin(), tokens.end()); - } - } + // always decode up to a full batch before generating, even if cached + nPast -= std::min(n_batch, nPast); - // decode the user prompt - if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp, /*isResponse*/ false, - /*alwaysDecode*/ true)) - return; // error - - // decode the assistant's reply, either generated or spoofed - if (!fakeReply) { - generateResponse(responseCallback, allowContextShift, promptCtx); - } else { - embd_inp = tokenize(*fakeReply, false); - if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp, true)) - return; // error - } + // TODO(jared): generalize this to find the smallest new_embd_inp.size() - nPast given the cache + if (!nPast && int32_t(embd_inp.size()) > nCtx) { + // no cache hit -> shift the input before even processing - // decode the rest of the prompt template - // template: end of assistant prompt - std::string asstSuffix; - if (placeholders.size() >= 2) { - size_t start = placeholders[1].position() + placeholders[1].length(); - asstSuffix = promptTemplate.substr(start); - } else { - asstSuffix = "\n\n"; // default to a blank link, good for e.g. Alpaca - } - if (!asstSuffix.empty()) { - embd_inp = tokenize(asstSuffix, true); - decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp); - } -} + int32_t nKeep = shouldAddBOS(); + auto newLength = int32_t(nCtx * (1.f - promptCtx.contextErase)); + int32_t nDiscard = int32_t(embd_inp.size()) - std::max(1, std::min(nCtx, newLength)); -// returns false on error -bool LLModel::decodePrompt(std::function promptCallback, - std::function responseCallback, - bool allowContextShift, - PromptContext &promptCtx, - std::vector embd_inp, - bool isResponse, - bool alwaysDecode) { - if ((int) embd_inp.size() > contextLength() - 4) { - // FIXME: (Adam) We should find a way to bubble these strings to the UI level to allow for - // translation - responseCallback(-1, "Your message was too long and could not be processed. Please try again with something shorter."); - std::cerr << implementation().modelType() << " ERROR: The prompt is " << embd_inp.size() << - " tokens and the context window is " << contextLength() << "!\n"; - return false; - } + // execute the callback even for skipped tokens. this misrepresents the position of BOS but we don't care + auto discardedTokens = embd_inp | views::drop(nKeep) | views::take(nDiscard); + if (!promptCallback(discardedTokens, true)) + return std::nullopt; - // FIXME(jared): There are mitigations for this situation, such as making room before - // copying the prompt context, or restoring the KV cache when we restore the prompt - // context. - if (!allowContextShift && promptCtx.n_past + embd_inp.size() > contextLength()) { - std::cerr << "LLModel Warning: Not enough space, n_past=" << promptCtx.n_past << ", n_eval=" << embd_inp.size() - << ", n_ctx=" << contextLength() << "\n"; - return false; - } + // erase nDiscard tokens + embd_inp.erase(discardedTokens.begin(), discardedTokens.end()); + assert(int32_t(embd_inp.size()) <= nCtx); - // always decode something before generating, even if cached - if (alwaysDecode && embd_inp.empty()) { - auto cache = inputTokens(); - if (!promptCtx.n_past) - throw std::runtime_error("zero token prompt is not supported"); - assert(!cache.empty()); - embd_inp.push_back(cache.back()); - promptCtx.n_past--; + // check the cache again, just in case + nPast = computeModelInputPosition(embd_inp); + nPast -= std::min(n_batch, nPast); } - // Find the greatest n_past where the beginning of embd_inp matches the end of the token cache, starting at the - // requested n_past. - // This is used to skip unnecessary work when the prompt shares a common prefix with the previous result. - auto embd_inp_start = computeModelInputPosition(promptCtx, embd_inp); - size_t start_offset = embd_inp_start - embd_inp.begin(); - - // always decode up to a full batch before generating, even if cached - if (alwaysDecode) - start_offset -= std::min(promptCtx.n_batch, int32_t(start_offset)); - - setModelInputPosition(promptCtx, promptCtx.n_past + start_offset); + setModelInputPosition(nPast); // execute the callback even for skipped tokens - size_t i = 0; - for (; i < start_offset; i++) { - Token tok = embd_inp[i]; - bool res = isResponse ? responseCallback(tok, tokenToString(tok)) : promptCallback(tok); - if (!res) - return false; - } + if (!promptCallback(embd_inp | views::take(nPast), true)) + return std::nullopt; // process the prompt in batches - while (i < embd_inp.size()) { - size_t batch_end = std::min(i + promptCtx.n_batch, embd_inp.size()); - std::span batch(embd_inp.begin() + i, embd_inp.begin() + batch_end); + for (int32_t i = nPast; i < embd_inp.size();) { + auto batch_end = std::min(i + n_batch, int32_t(embd_inp.size())); + std::span batch(embd_inp.begin() + i, embd_inp.begin() + batch_end); // Check if the context has run out... - if (promptCtx.n_past + int32_t(batch.size()) > contextLength()) { - assert(allowContextShift); - shiftContext(promptCtx); - assert(promptCtx.n_past + int32_t(batch.size()) <= contextLength()); + if (nPast + int32_t(batch.size()) > nCtx) { + shiftContext(promptCtx, &nPast); + assert(nPast + int32_t(batch.size()) <= nCtx); } - if (!evalTokens(promptCtx, batch)) { - std::cerr << implementation().modelType() << " ERROR: Failed to process prompt\n"; - return false; - } + // FIXME(Adam): We should find a way to bubble these strings to the UI level to allow for translation + if (!evalTokens(nPast, batch)) + throw std::runtime_error("An internal error was encountered during prompt processing."); - size_t tokens = batch_end - i; - for (size_t t = 0; t < tokens; ++t) { - Token tok = batch[t]; - appendInputToken(promptCtx, tok); - bool res = isResponse ? responseCallback(tok, tokenToString(tok)) : promptCallback(tok); - if (!res) - return false; + for (auto &tok : batch) { + appendInputToken(tok); + nPast++; + if (!promptCallback({ &tok, 1 }, false)) + return std::nullopt; } i = batch_end; } - return true; + return nPast; } /* @@ -251,22 +139,16 @@ static std::string::size_type stringsOverlap(const std::string &s, const std::st return std::string::npos; } -void LLModel::generateResponse(std::function responseCallback, - bool allowContextShift, - PromptContext &promptCtx) { +void LLModel::generateResponse( + const ResponseCallback &responseCallback, + const PromptContext &promptCtx, + int32_t nPast +) { static const char *stopSequences[] { - "### Instruction", "### Prompt", "### Response", "### Human", "### Assistant", "### Context", + "### System", "### Instruction", "### Human", "### User", "### Response", "### Assistant", "### Context", + "<|im_start|>", "<|im_end|>", "<|endoftext|>", }; - // Don't even start if there is no room - if (!promptCtx.n_predict) - return; - if (!allowContextShift && promptCtx.n_past >= contextLength()) { - std::cerr << "LLModel Warning: Not enough space, n_past=" << promptCtx.n_past << ", n_ctx=" << contextLength() - << "\n"; - return; - } - initSampler(promptCtx); std::string cachedResponse; @@ -281,25 +163,20 @@ void LLModel::generateResponse(std::function cachedTokens.push_back(new_tok.value()); cachedResponse += new_piece; - auto accept = [this, &promptCtx, &new_tok, allowContextShift]() -> bool { + auto accept = [this, &promptCtx, &new_tok, &nPast] { // Shift context if out of space - if (promptCtx.n_past >= contextLength()) { - (void)allowContextShift; - assert(allowContextShift); - shiftContext(promptCtx); - assert(promptCtx.n_past < contextLength()); + if (nPast >= contextLength()) { + shiftContext(promptCtx, &nPast); + assert(nPast < contextLength()); } // Accept the token Token tok = std::exchange(new_tok, std::nullopt).value(); - if (!evalTokens(promptCtx, { &tok, 1 })) { - // TODO(jared): raise an exception - std::cerr << implementation().modelType() << " ERROR: Failed to predict next token\n"; - return false; - } + if (!evalTokens(nPast, { &tok, 1 })) + throw std::runtime_error("An internal error was encountered during response generation."); - appendInputToken(promptCtx, tok); - return true; + appendInputToken(tok); + nPast++; }; // Check for EOS @@ -336,13 +213,6 @@ void LLModel::generateResponse(std::function lengthLimit = cachedResponse.size() - new_piece.size(); } - // Optionally stop if the context will run out - if (!allowContextShift && promptCtx.n_past + cachedTokens.size() >= contextLength()) { - std::cerr << "LLModel Warning: Not enough space, n_past=" << promptCtx.n_past << ", n_ctx=" - << contextLength() << "\n"; - stop = true; - } - // Empty the cache, up to the length limit std::string::size_type responseLength = 0; while (!cachedTokens.empty()) { @@ -359,8 +229,8 @@ void LLModel::generateResponse(std::function cachedResponse.erase(cachedResponse.begin(), cachedResponse.begin() + piece.size()); // Accept the token, if needed (not cached) - if (cachedTokens.empty() && new_tok && !accept()) - return; + if (cachedTokens.empty() && new_tok) + accept(); // Send the token if (!responseCallback(tok, piece) || ++n_predicted >= promptCtx.n_predict) { @@ -379,8 +249,8 @@ void LLModel::generateResponse(std::function assert(!cachedTokens.empty() && cachedTokens.back() == new_tok); if (stop) { cachedTokens.pop_back(); - } else if (!accept()) { - return; + } else { + accept(); } } } @@ -396,8 +266,6 @@ void LLModel::generateResponse(std::function auto discard_start = inp.end() - cachedTokens.size(); assert(std::equal(discard_start, inp.end(), cachedTokens.begin())); #endif - - promptCtx.n_past -= cachedTokens.size(); } void LLModel::embed( diff --git a/gpt4all-bindings/python/docs/gpt4all_desktop/chat_templates.md b/gpt4all-bindings/python/docs/gpt4all_desktop/chat_templates.md new file mode 100644 index 000000000000..5c15cf620c85 --- /dev/null +++ b/gpt4all-bindings/python/docs/gpt4all_desktop/chat_templates.md @@ -0,0 +1,206 @@ +## What are chat templates? +Natively, large language models only know how to complete plain text and do not know the difference between their input and their output. In order to support a chat with a person, LLMs are designed to use a template to convert the conversation to plain text using a specific format. + +For a given model, it is important to use an appropriate chat template, as each model is designed to work best with a specific format. The chat templates included with the built-in models should be sufficient for most purposes. + +There are two reasons you would want to alter the chat template: + +- You are sideloading a model and there is no chat template available, +- You would like to have greater control over the input to the LLM than a system message provides. + + +## What is a system message? +A system message is a message that controls the responses from the LLM in a way that affects the entire conversation. System messages can be short, such as "Speak like a pirate.", or they can be long and contain a lot of context for the LLM to keep in mind. + +Not all models are designed to use a system message, so they work with some models better than others. + + +## How do I customize the chat template or system message? +To customize the chat template or system message, go to Settings > Model. Make sure to select the correct model at the top. If you clone a model, you can use a different chat template or system message from the base model, enabling you to use different settings for each conversation. + +These settings take effect immediately. After changing them, you can click "Redo last response" in the chat view, and the response will take the new settings into account. + + +## Do I need to write a chat template? +You typically do not need to write your own chat template. The exception is models that are not in the official model list and do not come with a chat template built-in. These will show a "Clear" option above the chat template field in the Model Settings page instead of a "Reset" option. See the section on [finding] or [creating] a chat template. + +[finding]: #how-do-i-find-a-chat-template +[creating]: #advanced-how-do-chat-templates-work + + +## What changed in GPT4All v3.5? +GPT4All v3.5 overhauled the chat template system. There are three crucial differences: + +- The chat template now formats an entire conversation instead of a single pair of messages, +- The chat template now uses Jinja syntax instead of `%1` and `%2` placeholders, +- And the system message should no longer contain control tokens or trailing whitespace. + +If you are using any chat templates or system messages that had been added or altered from the default before upgrading to GPT4All v3.5 or newer, these will no longer work. See below for how to solve common errors you may see after upgrading. + + +## Error/Warning: System message is not plain text. +This is easy to fix. Go to the model's settings and look at the system prompt. There are three things to look for: + +- Control tokens such as `<|im_start|>`, `<|start_header_id|>`, or `<|system|>` +- A prefix such as `### System` or `SYSTEM:` +- Trailing whitespace, such as a space character or blank line. + +If you see any of these things, remove them. For example, this legacy system prompt: +``` +<|start_header_id|>system<|end_header_id|> +You are a helpful assistant.<|eot_id|> +``` + +Should become this: +``` +You are a helpful assistant. +``` + +If you do not see anything that needs to be changed, you can dismiss the error by making a minor modification to the message and then changing it back. + +If you see a warning, your system message does not appear to be plain text. If you believe this warning is incorrect, it can be safely ignored. If in doubt, ask on the [Discord]. + +[Discord]: https://discord.gg/mGZE39AS3e + + +## Error: Legacy system prompt needs to be updated in Settings. +This is the same as [above][above-1], but appears on the chat page. + +[above-1]: #errorwarning-system-message-is-not-plain-text + + +## Error/Warning: Chat template is not in Jinja format. +This is the result of attempting to use an old-style template (possibly from a previous version) in GPT4All 3.5+. + +Go to the Model Settings page and select the affected model. If you see a "Reset" button, and you have not intentionally modified the prompt template, you can click "Reset". Otherwise, this is what you can do: + +1. Back up your chat template by copying it safely to a text file and saving it. In the next step, it will be removed from GPT4All. +2. Click "Reset" or "Clear". +3. If you clicked "Clear", the chat template is now gone. Follow the steps to [find][finding] or [create][creating] a basic chat template for your model. +4. Customize the chat template to suit your needs. For help, read the section about [creating] a chat template. + + +## Error: Legacy prompt template needs to be updated in Settings. +This is the same as [above][above-2], but appears on the chat page. + +[above-2]: #errorwarning-chat-template-is-not-in-jinja-format + + +## The chat template has a syntax error. +If there is a syntax error while editing the chat template, the details will be displayed in an error message above the input box. This could be because the chat template is not actually in Jinja format (see [above][above-2]). + +Otherwise, you have either typed something correctly, or the model comes with a template that is incompatible with GPT4All. See [the below section][creating] on creating chat templates and make sure that everything is correct. When in doubt, ask on the [Discord]. + + +## Error: No chat template configured. +This may appear for models that are not from the official model list and do not include a chat template. Older versions of GPT4All picked a poor default in this case. You will get much better results if you follow the steps to [find][finding] or [create][creating] a chat template for your model. + + +## Error: The chat template cannot be blank. +If the button above the chat template on the Model Settings page says "Clear", see [above][above-3]. If you see "Reset", click that button to restore a reasonable default. Also see the section on [syntax errors][chat-syntax-error]. + +[above-3]: #error-no-chat-template-configured +[chat-syntax-error]: #the-chat-template-has-a-syntax-error + + +## How do I find a chat template? +When in doubt, you can always ask the [Discord] community for help. Below are the instructions to find one on your own. + +The authoritative source for a model's chat template is the HuggingFace repo that the original (non-GGUF) model came from. First, you should find this page. If you just have a model file, you can try a google search for the model's name. If you know the page you downloaded the GGUF model from, its README usually links to the original non-GGUF model. + +Once you have located the original model, there are two methods you can use to extract its chat template. Pick whichever one you are most comfortable with. + +### Using the CLI (all models) +1. Install `jq` using your preferred package manager - e.g. Chocolatey (Windows), Homebrew (macOS), or apt (Ubuntu). +2. Download `tokenizer_config.json` from the model's "Files and versions" tab. +3. Open a command prompt in the directory which you have downloaded the model file. +4. Run `jq -r ".chat_template" tokenizer_config.json`. This shows the chat template in a human-readable form. You can copy this and paste it into the settings page. +5. (Optional) You can save the output to a text file like this: `jq -r ".chat_template" tokenizer_config.json >chat_template.txt` + +If the output is "null", the model does not provide a chat template. See the [below instructions][creating] on creating a chat template. + +### Python (open models) +1. Install `transformers` using your preferred python package manager, e.g. `pip install transformers`. Make sure it is at least version v4.43.0. +2. Copy the ID of the HuggingFace model, using the clipboard icon next to the name. For example, if the URL is `https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B`, the ID is `NousResearch/Hermes-2-Pro-Llama-3-8B`. +3. Open a python interpreter (`python`) and run the following commands. Change the model ID in the example to the one you copied. +``` +>>> from transformers import AutoTokenizer +>>> tokenizer = AutoTokenizer.from_pretrained('NousResearch/Hermes-2-Pro-Llama-3-8B') +>>> print(tokenizer.get_chat_template()) +``` +You can copy the output and paste it into the settings page. +4. (Optional) You can save the output to a text file like this: +``` +>>> open('chat_template.txt', 'w').write(tokenizer.get_chat_template()) +``` + +If you get a ValueError exception, this model does not provide a chat template. See the [below instructions][creating] on creating a chat template. + + +### Python (gated models) +Some models, such as Llama and Mistral, do not allow public access to their chat template. You must either use the CLI method above, or follow the following instructions to use Python: + +1. For these steps, you must have git and git-lfs installed. +2. You must have a HuggingFace account and be logged in. +3. You must already have access to the gated model. Otherwise, request access. +4. You must have an SSH key configured for git access to HuggingFace. +5. `git clone` the model's HuggingFace repo using the SSH clone URL. There is no need to download the entire model, which is very large. A good way to do this on Linux is: +```console +$ GIT_LFS_SKIP_SMUDGE=1 git clone hf.co:meta-llama/Llama-3.1-8B-Instruct.git +$ cd Llama-3.1-8B-Instruct +$ git lfs pull -I "tokenizer.*" +``` +6. Follow the above instructions for open models, but replace the model ID with the path to the directory containing `tokenizer\_config.json`: +``` +>>> tokenizer = AutoTokenizer.from_pretrained('.') +``` + + +## Advanced: How do chat templates work? +The chat template is applied to the entire conversation you see in the chat window. The template loops over the list of messages, each containing `role` and `content` fields. `role` is either `user`, `assistant`, or `system`. + +GPT4All also supports the special variables `bos_token`, `eos_token`, and `add_generation_prompt`. See the [HuggingFace docs] for what those do. + +[HuggingFace docs]: https://huggingface.co/docs/transformers/v4.46.3/en/chat_templating#special-variables + + +## Advanced: How do I make a chat template? +The best way to create a chat template is to start by using an existing one as a reference. Then, modify it to use the format documented for the given model. Its README page may explicitly give an example of its template. Or, it may mention the name of a well-known standard template, such as ChatML, Alpaca, Vicuna. GPT4All does not yet include presets for these templates, so they will have to be found in other models or taken from the community. + +For more information, see the very helpful [HuggingFace guide]. Some of this is not applicable, such as the information about tool calling and RAG - GPT4All implements those features differently. + +Some models use a prompt template that does not intuitively map to a multi-turn chat, because it is more intended for single instructions. The [FastChat] implementation of these templates is a useful reference for the correct way to extend them to multiple messages. + +[HuggingFace guide]: https://huggingface.co/docs/transformers/v4.46.3/en/chat_templating#advanced-template-writing-tips +[FastChat]: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py + + +# Advanced: What are GPT4All v1 templates? +GPT4All supports its own template syntax, which is nonstandard but provides complete control over the way LocalDocs sources and file attachments are inserted into the conversation. These templates begin with `{# gpt4all v1 #}` and look similar to the example below. + +For standard templates, GPT4All combines the user message, sources, and attachments into the `content` field. For GPT4All v1 templates, this is not done, so they must be used directly in the template for those features to work correctly. + +```jinja +{# gpt4all v1 #} +{%- for message in messages %} + {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' }} + {%- if message['role'] == 'user' %} + {%- for source in message['sources'] %} + {%- if loop.first %} + {{- '### Context:\n' }} + {%- endif %} + {{- 'Collection: ' + source['collection'] + '\n' + + 'Path: ' + source['path'] + '\n' + + 'Excerpt: ' + source['text'] + '\n\n' }} + {%- endfor %} + {%- endif %} + {%- for attachment in message['prompt_attachments'] %} + {{- attachment['processed_content'] + '\n\n' }} + {%- endfor %} + {{- message['content'] | trim }} + {{- '<|eot_id|>' }} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }} +{%- endif %} +``` diff --git a/gpt4all-bindings/python/gpt4all/_pyllmodel.py b/gpt4all-bindings/python/gpt4all/_pyllmodel.py index 136cf685aa2b..616ce80a3533 100644 --- a/gpt4all-bindings/python/gpt4all/_pyllmodel.py +++ b/gpt4all-bindings/python/gpt4all/_pyllmodel.py @@ -9,7 +9,7 @@ import threading from enum import Enum from queue import Queue -from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable, Literal, NoReturn, TypeVar, overload +from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable, Iterator, Literal, NoReturn, TypeVar, overload if sys.version_info >= (3, 9): import importlib.resources as importlib_resources @@ -23,7 +23,9 @@ from typing import TypedDict if TYPE_CHECKING: - from typing_extensions import TypeAlias + from typing_extensions import ParamSpec, TypeAlias + T = TypeVar("T") + P = ParamSpec("P") EmbeddingsType = TypeVar('EmbeddingsType', bound='list[Any]') @@ -31,7 +33,7 @@ # TODO(jared): use operator.call after we drop python 3.10 support -def _operator_call(obj, /, *args, **kwargs): +def _operator_call(obj: Callable[P, T], /, *args: P.args, **kwargs: P.kwargs) -> T: return obj(*args, **kwargs) @@ -116,16 +118,15 @@ def load_llmodel_library(): class LLModelPromptContext(ctypes.Structure): _fields_ = [ - ("n_past", ctypes.c_int32), - ("n_predict", ctypes.c_int32), - ("top_k", ctypes.c_int32), - ("top_p", ctypes.c_float), - ("min_p", ctypes.c_float), - ("temp", ctypes.c_float), - ("n_batch", ctypes.c_int32), + ("n_predict", ctypes.c_int32), + ("top_k", ctypes.c_int32), + ("top_p", ctypes.c_float), + ("min_p", ctypes.c_float), + ("temp", ctypes.c_float), + ("n_batch", ctypes.c_int32), ("repeat_penalty", ctypes.c_float), - ("repeat_last_n", ctypes.c_int32), - ("context_erase", ctypes.c_float), + ("repeat_last_n", ctypes.c_int32), + ("context_erase", ctypes.c_float), ] @@ -157,23 +158,21 @@ class LLModelGPUDevice(ctypes.Structure): llmodel.llmodel_isModelLoaded.argtypes = [ctypes.c_void_p] llmodel.llmodel_isModelLoaded.restype = ctypes.c_bool -PromptCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32) -ResponseCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32, ctypes.c_char_p) -EmbCancelCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.POINTER(ctypes.c_uint), ctypes.c_uint, ctypes.c_char_p) +PromptCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.POINTER(ctypes.c_int32), ctypes.c_size_t, ctypes.c_bool) +ResponseCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32, ctypes.c_char_p) +EmbCancelCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.POINTER(ctypes.c_uint), ctypes.c_uint, ctypes.c_char_p) +SpecialTokenCallback = ctypes.CFUNCTYPE(None, ctypes.c_char_p, ctypes.c_char_p) llmodel.llmodel_prompt.argtypes = [ ctypes.c_void_p, ctypes.c_char_p, - ctypes.c_char_p, PromptCallback, ResponseCallback, - ctypes.c_bool, ctypes.POINTER(LLModelPromptContext), - ctypes.c_bool, - ctypes.c_char_p, + ctypes.POINTER(ctypes.c_char_p), ] -llmodel.llmodel_prompt.restype = None +llmodel.llmodel_prompt.restype = ctypes.c_bool llmodel.llmodel_embed.argtypes = [ ctypes.c_void_p, @@ -222,6 +221,12 @@ class LLModelGPUDevice(ctypes.Structure): llmodel.llmodel_model_gpu_device_name.argtypes = [ctypes.c_void_p] llmodel.llmodel_model_gpu_device_name.restype = ctypes.c_char_p +llmodel.llmodel_count_prompt_tokens.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_char_p)] +llmodel.llmodel_count_prompt_tokens.restype = ctypes.c_int32 + +llmodel.llmodel_model_foreach_special_token.argtypes = [ctypes.c_void_p, SpecialTokenCallback] +llmodel.llmodel_model_foreach_special_token.restype = None + ResponseCallbackType = Callable[[int, str], bool] RawResponseCallbackType = Callable[[int, bytes], bool] EmbCancelCallbackType: TypeAlias = 'Callable[[list[int], str], bool]' @@ -266,7 +271,6 @@ def __init__(self, model_path: str, n_ctx: int, ngl: int, backend: str): self.model_path = model_path.encode() self.n_ctx = n_ctx self.ngl = ngl - self.context: LLModelPromptContext | None = None self.buffer = bytearray() self.buff_expecting_cont_bytes: int = 0 @@ -286,6 +290,10 @@ def __init__(self, model_path: str, n_ctx: int, ngl: int, backend: str): raise RuntimeError(f"Unable to instantiate model: {errmsg}") self.model: ctypes.c_void_p | None = model + self.special_tokens_map: dict[str, str] = {} + llmodel.llmodel_model_foreach_special_token( + self.model, lambda n, t: self.special_tokens_map.__setitem__(n.decode(), t.decode()), + ) def __del__(self, llmodel=llmodel): if hasattr(self, 'model'): @@ -312,6 +320,19 @@ def device(self) -> str | None: dev = llmodel.llmodel_model_gpu_device_name(self.model) return None if dev is None else dev.decode() + def count_prompt_tokens(self, prompt: str) -> int: + if self.model is None: + self._raise_closed() + err = ctypes.c_char_p() + n_tok = llmodel.llmodel_count_prompt_tokens(self.model, prompt, ctypes.byref(err)) + if n_tok < 0: + s = err.value + errmsg = 'null' if s is None else s.decode() + raise RuntimeError(f'Unable to count prompt tokens: {errmsg}') + return n_tok + + llmodel.llmodel_count_prompt_tokens.argtypes = [ctypes.c_void_p, ctypes.c_char_p] + @staticmethod def list_gpus(mem_required: int = 0) -> list[str]: """ @@ -375,48 +396,6 @@ def thread_count(self): raise Exception("Model not loaded") return llmodel.llmodel_threadCount(self.model) - def _set_context( - self, - n_predict: int = 4096, - top_k: int = 40, - top_p: float = 0.9, - min_p: float = 0.0, - temp: float = 0.1, - n_batch: int = 8, - repeat_penalty: float = 1.2, - repeat_last_n: int = 10, - context_erase: float = 0.75, - reset_context: bool = False, - ): - if self.context is None: - context = LLModelPromptContext( - n_past=0, - n_predict=n_predict, - top_k=top_k, - top_p=top_p, - min_p=min_p, - temp=temp, - n_batch=n_batch, - repeat_penalty=repeat_penalty, - repeat_last_n=repeat_last_n, - context_erase=context_erase, - ) - self.context = context - else: - context = self.context - if reset_context: - self.context.n_past = 0 - - self.context.n_predict = n_predict - self.context.top_k = top_k - self.context.top_p = top_p - self.context.min_p = min_p - self.context.temp = temp - self.context.n_batch = n_batch - self.context.repeat_penalty = repeat_penalty - self.context.repeat_last_n = repeat_last_n - self.context.context_erase = context_erase - @overload def generate_embeddings( self, text: str, prefix: str | None, dimensionality: int, do_mean: bool, atlas: bool, @@ -486,20 +465,18 @@ def wrap_cancel_cb(batch_sizes: Any, n_batch: int, backend: bytes) -> bool: def prompt_model( self, - prompt: str, - prompt_template: str, - callback: ResponseCallbackType, - n_predict: int = 4096, - top_k: int = 40, - top_p: float = 0.9, - min_p: float = 0.0, - temp: float = 0.1, - n_batch: int = 8, - repeat_penalty: float = 1.2, - repeat_last_n: int = 10, - context_erase: float = 0.75, - reset_context: bool = False, - special: bool = False, + prompt : str, + callback : ResponseCallbackType, + n_predict : int = 4096, + top_k : int = 40, + top_p : float = 0.9, + min_p : float = 0.0, + temp : float = 0.1, + n_batch : int = 8, + repeat_penalty : float = 1.2, + repeat_last_n : int = 10, + context_erase : float = 0.75, + reset_context : bool = False, ): """ Generate response from model from a prompt. @@ -522,34 +499,38 @@ def prompt_model( self.buffer.clear() self.buff_expecting_cont_bytes = 0 - self._set_context( - n_predict=n_predict, - top_k=top_k, - top_p=top_p, - min_p=min_p, - temp=temp, - n_batch=n_batch, - repeat_penalty=repeat_penalty, - repeat_last_n=repeat_last_n, - context_erase=context_erase, - reset_context=reset_context, + context = LLModelPromptContext( + n_predict = n_predict, + top_k = top_k, + top_p = top_p, + min_p = min_p, + temp = temp, + n_batch = n_batch, + repeat_penalty = repeat_penalty, + repeat_last_n = repeat_last_n, + context_erase = context_erase, ) - llmodel.llmodel_prompt( + error_msg: bytes | None = None + def error_callback(msg: bytes) -> None: + nonlocal error_msg + error_msg = msg + + err = ctypes.c_char_p() + if not llmodel.llmodel_prompt( self.model, ctypes.c_char_p(prompt.encode()), - ctypes.c_char_p(prompt_template.encode()), PromptCallback(self._prompt_callback), ResponseCallback(self._callback_decoder(callback)), - True, - self.context, - special, - ctypes.c_char_p(), - ) + context, + ctypes.byref(err), + ): + s = err.value + raise RuntimeError(f"prompt error: {'null' if s is None else s.decode()}") def prompt_model_streaming( - self, prompt: str, prompt_template: str, callback: ResponseCallbackType = empty_response_callback, **kwargs - ) -> Iterable[str]: + self, prompt: str, callback: ResponseCallbackType = empty_response_callback, **kwargs: Any, + ) -> Iterator[str]: if self.model is None: self._raise_closed() @@ -568,15 +549,15 @@ def _generator_callback(token_id: int, response: str): return _generator_callback - def run_llmodel_prompt(prompt: str, prompt_template: str, callback: ResponseCallbackType, **kwargs): - self.prompt_model(prompt, prompt_template, callback, **kwargs) + def run_llmodel_prompt(prompt: str, callback: ResponseCallbackType, **kwargs): + self.prompt_model(prompt, callback, **kwargs) output_queue.put(Sentinel.TERMINATING_SYMBOL) # Kick off llmodel_prompt in separate thread so we can return generator # immediately thread = threading.Thread( target=run_llmodel_prompt, - args=(prompt, prompt_template, _generator_callback_wrapper(callback)), + args=(prompt, _generator_callback_wrapper(callback)), kwargs=kwargs, ) thread.start() @@ -631,5 +612,5 @@ def _raw_callback(token_id: int, response: bytes) -> bool: # Empty prompt callback @staticmethod - def _prompt_callback(token_id: int) -> bool: + def _prompt_callback(token_ids: ctypes._Pointer[ctypes.c_int32], n_token_ids: int, cached: bool) -> bool: return True diff --git a/gpt4all-bindings/python/gpt4all/gpt4all.py b/gpt4all-bindings/python/gpt4all/gpt4all.py index c863817dfb07..84b236c996dc 100644 --- a/gpt4all-bindings/python/gpt4all/gpt4all.py +++ b/gpt4all-bindings/python/gpt4all/gpt4all.py @@ -4,37 +4,66 @@ from __future__ import annotations import hashlib +import json import os import platform import re import sys import warnings from contextlib import contextmanager +from datetime import datetime from pathlib import Path from types import TracebackType -from typing import TYPE_CHECKING, Any, Iterable, Literal, Protocol, overload +from typing import TYPE_CHECKING, Any, Iterable, Iterator, Literal, NamedTuple, NoReturn, Protocol, TypedDict, overload +import jinja2 import requests +from jinja2.sandbox import ImmutableSandboxedEnvironment from requests.exceptions import ChunkedEncodingError from tqdm import tqdm from urllib3.exceptions import IncompleteRead, ProtocolError from ._pyllmodel import (CancellationError as CancellationError, EmbCancelCallbackType, EmbedResult as EmbedResult, - LLModel, ResponseCallbackType, empty_response_callback) + LLModel, ResponseCallbackType, _operator_call, empty_response_callback) if TYPE_CHECKING: from typing_extensions import Self, TypeAlias -if sys.platform == 'darwin': +if sys.platform == "darwin": import fcntl # TODO: move to config DEFAULT_MODEL_DIRECTORY = Path.home() / ".cache" / "gpt4all" -DEFAULT_PROMPT_TEMPLATE = "### Human:\n{0}\n\n### Assistant:\n" +ConfigType: TypeAlias = "dict[str, Any]" -ConfigType: TypeAlias = 'dict[str, Any]' -MessageType: TypeAlias = 'dict[str, str]' +# Environment setup adapted from HF transformers +@_operator_call +def _jinja_env() -> ImmutableSandboxedEnvironment: + def raise_exception(message: str) -> NoReturn: + raise jinja2.exceptions.TemplateError(message) + + def tojson(obj: Any, indent: int | None = None) -> str: + return json.dumps(obj, ensure_ascii=False, indent=indent) + + def strftime_now(fmt: str) -> str: + return datetime.now().strftime(fmt) + + env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True) + env.filters["tojson" ] = tojson + env.globals["raise_exception"] = raise_exception + env.globals["strftime_now" ] = strftime_now + return env + + +class MessageType(TypedDict): + role: str + content: str + + +class ChatSession(NamedTuple): + template: jinja2.Template + history: list[MessageType] class Embed4All: @@ -54,7 +83,7 @@ def __init__(self, model_name: str | None = None, *, n_threads: int | None = Non kwargs: Remaining keyword arguments are passed to the `GPT4All` constructor. """ if model_name is None: - model_name = 'all-MiniLM-L6-v2.gguf2.f16.gguf' + model_name = "all-MiniLM-L6-v2.gguf2.f16.gguf" self.gpt4all = GPT4All(model_name, n_threads=n_threads, device=device, **kwargs) def __enter__(self) -> Self: @@ -145,18 +174,18 @@ def embed( dimensionality = -1 else: if dimensionality <= 0: - raise ValueError(f'Dimensionality must be None or a positive integer, got {dimensionality}') + raise ValueError(f"Dimensionality must be None or a positive integer, got {dimensionality}") if dimensionality < self.MIN_DIMENSIONALITY: warnings.warn( - f'Dimensionality {dimensionality} is less than the suggested minimum of {self.MIN_DIMENSIONALITY}.' - ' Performance may be degraded.' + f"Dimensionality {dimensionality} is less than the suggested minimum of {self.MIN_DIMENSIONALITY}." + " Performance may be degraded." ) try: do_mean = {"mean": True, "truncate": False}[long_text_mode] except KeyError: raise ValueError(f"Long text mode must be one of 'mean' or 'truncate', got {long_text_mode!r}") result = self.gpt4all.model.generate_embeddings(text, prefix, dimensionality, do_mean, atlas, cancel_cb) - return result if return_dict else result['embeddings'] + return result if return_dict else result["embeddings"] class GPT4All: @@ -204,8 +233,7 @@ def __init__( """ self.model_type = model_type - self._history: list[MessageType] | None = None - self._current_prompt_template: str = "{0}" + self._chat_session: ChatSession | None = None device_init = None if sys.platform == "darwin": @@ -264,7 +292,13 @@ def device(self) -> str | None: @property def current_chat_session(self) -> list[MessageType] | None: - return None if self._history is None else list(self._history) + return None if self._chat_session is None else self._chat_session.history + + @current_chat_session.setter + def current_chat_session(self, history: list[MessageType]) -> None: + if self._chat_session is None: + raise ValueError("current_chat_session may only be set when there is an active chat session") + self._chat_session.history[:] = history @staticmethod def list_models() -> list[ConfigType]: @@ -276,7 +310,7 @@ def list_models() -> list[ConfigType]: """ resp = requests.get("https://gpt4all.io/models/models3.json") if resp.status_code != 200: - raise ValueError(f'Request failed: HTTP {resp.status_code} {resp.reason}') + raise ValueError(f"Request failed: HTTP {resp.status_code} {resp.reason}") return resp.json() @classmethod @@ -306,15 +340,9 @@ def retrieve_model( # get the config for the model config: ConfigType = {} if allow_download: - available_models = cls.list_models() - - for m in available_models: - if model_filename == m["filename"]: - tmpl = m.get("promptTemplate", DEFAULT_PROMPT_TEMPLATE) - # change to Python-style formatting - m["promptTemplate"] = tmpl.replace("%1", "{0}", 1).replace("%2", "{1}", 1) - config.update(m) - break + models = cls.list_models() + if (model := next((m for m in models if m["filename"] == model_filename), None)) is not None: + config.update(model) # Validate download directory if model_path is None: @@ -378,13 +406,13 @@ def make_request(offset=None): headers = {} if offset: print(f"\nDownload interrupted, resuming from byte position {offset}", file=sys.stderr) - headers['Range'] = f'bytes={offset}-' # resume incomplete response + headers["Range"] = f"bytes={offset}-" # resume incomplete response headers["Accept-Encoding"] = "identity" # Content-Encoding changes meaning of ranges response = requests.get(url, stream=True, headers=headers) if response.status_code not in (200, 206): - raise ValueError(f'Request failed: HTTP {response.status_code} {response.reason}') - if offset and (response.status_code != 206 or str(offset) not in response.headers.get('Content-Range', '')): - raise ValueError('Connection was interrupted and server does not support range requests') + raise ValueError(f"Request failed: HTTP {response.status_code} {response.reason}") + if offset and (response.status_code != 206 or str(offset) not in response.headers.get("Content-Range", "")): + raise ValueError("Connection was interrupted and server does not support range requests") if (enc := response.headers.get("Content-Encoding")) is not None: raise ValueError(f"Expected identity Content-Encoding, got {enc}") return response @@ -483,19 +511,19 @@ def generate( def generate( self, - prompt: str, + prompt : str, *, - max_tokens: int = 200, - temp: float = 0.7, - top_k: int = 40, - top_p: float = 0.4, - min_p: float = 0.0, - repeat_penalty: float = 1.18, - repeat_last_n: int = 64, - n_batch: int = 8, - n_predict: int | None = None, - streaming: bool = False, - callback: ResponseCallbackType = empty_response_callback, + max_tokens : int = 200, + temp : float = 0.7, + top_k : int = 40, + top_p : float = 0.4, + min_p : float = 0.0, + repeat_penalty : float = 1.18, + repeat_last_n : int = 64, + n_batch : int = 8, + n_predict : int | None = None, + streaming : bool = False, + callback : ResponseCallbackType = empty_response_callback, ) -> Any: """ Generate outputs from any GPT4All model. @@ -520,122 +548,94 @@ def generate( # Preparing the model request generate_kwargs: dict[str, Any] = dict( - temp=temp, - top_k=top_k, - top_p=top_p, - min_p=min_p, - repeat_penalty=repeat_penalty, - repeat_last_n=repeat_last_n, - n_batch=n_batch, - n_predict=n_predict if n_predict is not None else max_tokens, + temp = temp, + top_k = top_k, + top_p = top_p, + min_p = min_p, + repeat_penalty = repeat_penalty, + repeat_last_n = repeat_last_n, + n_batch = n_batch, + n_predict = n_predict if n_predict is not None else max_tokens, ) - if self._history is not None: - # check if there is only one message, i.e. system prompt: - reset = len(self._history) == 1 - self._history.append({"role": "user", "content": prompt}) - - fct_func = self._format_chat_prompt_template.__func__ # type: ignore[attr-defined] - if fct_func is GPT4All._format_chat_prompt_template: - if reset: - # ingest system prompt - # use "%1%2" and not "%1" to avoid implicit whitespace - self.model.prompt_model(self._history[0]["content"], "%1%2", - empty_response_callback, - n_batch=n_batch, n_predict=0, reset_context=True, special=True) - prompt_template = self._current_prompt_template.format("%1", "%2") - else: - warnings.warn( - "_format_chat_prompt_template is deprecated. Please use a chat session with a prompt template.", - DeprecationWarning, - ) - # special tokens won't be processed - prompt = self._format_chat_prompt_template( - self._history[-1:], - self._history[0]["content"] if reset else "", - ) - prompt_template = "%1" - generate_kwargs["reset_context"] = reset - else: - prompt_template = "%1" - generate_kwargs["reset_context"] = True - # Prepare the callback, process the model response - output_collector: list[MessageType] - output_collector = [ - {"content": ""} - ] # placeholder for the self._history if chat session is not activated - - if self._history is not None: - self._history.append({"role": "assistant", "content": ""}) - output_collector = self._history - - def _callback_wrapper( - callback: ResponseCallbackType, - output_collector: list[MessageType], - ) -> ResponseCallbackType: - def _callback(token_id: int, response: str) -> bool: - nonlocal callback, output_collector - - output_collector[-1]["content"] += response - - return callback(token_id, response) + full_response = "" + + def _callback_wrapper(token_id: int, response: str) -> bool: + nonlocal full_response + full_response += response + return callback(token_id, response) + + last_msg_rendered = prompt + if self._chat_session is not None: + session = self._chat_session + def render(messages: list[MessageType]) -> str: + return session.template.render( + messages=messages, + add_generation_prompt=True, + **self.model.special_tokens_map, + ) + session.history.append(MessageType(role="user", content=prompt)) + prompt = render(session.history) + if len(session.history) > 1: + last_msg_rendered = render(session.history[-1:]) - return _callback + # Check request length + last_msg_len = self.model.count_prompt_tokens(last_msg_rendered) + if last_msg_len > (limit := self.model.n_ctx - 4): + raise ValueError(f"Your message was too long and could not be processed ({last_msg_len} > {limit}).") # Send the request to the model if streaming: - return self.model.prompt_model_streaming( - prompt, - prompt_template, - _callback_wrapper(callback, output_collector), - **generate_kwargs, - ) - - self.model.prompt_model( - prompt, - prompt_template, - _callback_wrapper(callback, output_collector), - **generate_kwargs, - ) + def stream() -> Iterator[str]: + yield from self.model.prompt_model_streaming(prompt, _callback_wrapper, **generate_kwargs) + if self._chat_session is not None: + self._chat_session.history.append(MessageType(role="assistant", content=full_response)) + return stream() - return output_collector[-1]["content"] + self.model.prompt_model(prompt, _callback_wrapper, **generate_kwargs) + if self._chat_session is not None: + self._chat_session.history.append(MessageType(role="assistant", content=full_response)) + return full_response @contextmanager def chat_session( self, - system_prompt: str | None = None, - prompt_template: str | None = None, + system_message: str | Literal[False] | None = None, + chat_template: str | None = None, ): """ Context manager to hold an inference optimized chat session with a GPT4All model. Args: - system_prompt: An initial instruction for the model. - prompt_template: Template for the prompts with {0} being replaced by the user message. + system_message: An initial instruction for the model, None to use the model default, or False to disable. Defaults to None. + chat_template: Jinja template for the conversation, or None to use the model default. Defaults to None. """ - if system_prompt is None: - system_prompt = self.config.get("systemPrompt", "") - - if prompt_template is None: - if (tmpl := self.config.get("promptTemplate")) is None: - warnings.warn("Use of a sideloaded model or allow_download=False without specifying a prompt template " - "is deprecated. Defaulting to Alpaca.", DeprecationWarning) - tmpl = DEFAULT_PROMPT_TEMPLATE - prompt_template = tmpl - - if re.search(r"%1(?![0-9])", prompt_template): - raise ValueError("Prompt template containing a literal '%1' is not supported. For a prompt " - "placeholder, please use '{0}' instead.") - - self._history = [{"role": "system", "content": system_prompt}] - self._current_prompt_template = prompt_template + if system_message is None: + system_message = self.config.get("systemMessage", False) + + if chat_template is None: + if "name" not in self.config: + raise ValueError("For sideloaded models or with allow_download=False, you must specify a chat template.") + if "chatTemplate" not in self.config: + raise NotImplementedError("This model appears to have a built-in chat template, but loading it is not " + "currently implemented. Please pass a template to chat_session() directly.") + if (tmpl := self.config["chatTemplate"]) is None: + raise ValueError(f"The model {self.config['name']!r} does not support chat.") + chat_template = tmpl + + history = [] + if system_message is not False: + history.append(MessageType(role="system", content=system_message)) + self._chat_session = ChatSession( + template=_jinja_env.from_string(chat_template), + history=history, + ) try: yield self finally: - self._history = None - self._current_prompt_template = "{0}" + self._chat_session = None @staticmethod def list_gpus() -> list[str]: @@ -647,43 +647,6 @@ def list_gpus() -> list[str]: """ return LLModel.list_gpus() - def _format_chat_prompt_template( - self, - messages: list[MessageType], - default_prompt_header: str = "", - default_prompt_footer: str = "", - ) -> str: - """ - Helper method for building a prompt from list of messages using the self._current_prompt_template as a template for each message. - - Warning: - This function was deprecated in version 2.3.0, and will be removed in a future release. - - Args: - messages: List of dictionaries. Each dictionary should have a "role" key - with value of "system", "assistant", or "user" and a "content" key with a - string value. Messages are organized such that "system" messages are at top of prompt, - and "user" and "assistant" messages are displayed in order. Assistant messages get formatted as - "Response: {content}". - - Returns: - Formatted prompt. - """ - - full_prompt = default_prompt_header + "\n\n" if default_prompt_header != "" else "" - - for message in messages: - if message["role"] == "user": - user_message = self._current_prompt_template.format(message["content"]) - full_prompt += user_message - if message["role"] == "assistant": - assistant_message = message["content"] + "\n" - full_prompt += assistant_message - - full_prompt += "\n\n" + default_prompt_footer if default_prompt_footer != "" else "" - - return full_prompt - def append_extension_if_missing(model_name): if not model_name.endswith((".bin", ".gguf")): @@ -696,7 +659,7 @@ def fileno(self) -> int: ... def _fsync(fd: int | _HasFileno) -> None: - if sys.platform == 'darwin': + if sys.platform == "darwin": # Apple's fsync does not flush the drive write cache try: fcntl.fcntl(fd, fcntl.F_FULLFSYNC) diff --git a/gpt4all-bindings/python/mkdocs.yml b/gpt4all-bindings/python/mkdocs.yml index a80ec9b82129..651366a32ff3 100644 --- a/gpt4all-bindings/python/mkdocs.yml +++ b/gpt4all-bindings/python/mkdocs.yml @@ -14,6 +14,7 @@ nav: - 'Models' : 'gpt4all_desktop/models.md' - 'LocalDocs' : 'gpt4all_desktop/localdocs.md' - 'Settings' : 'gpt4all_desktop/settings.md' + - 'Chat Templates' : 'gpt4all_desktop/chat_templates.md' - 'Cookbook': - 'Local AI Chat with Microsoft Excel': 'gpt4all_desktop/cookbook/use-local-ai-models-to-privately-chat-with-microsoft-excel.md' - 'Local AI Chat with your Google Drive': 'gpt4all_desktop/cookbook/use-local-ai-models-to-privately-chat-with-google-drive.md' diff --git a/gpt4all-bindings/python/setup.py b/gpt4all-bindings/python/setup.py index e96f58405fe3..b316adc0ec4b 100644 --- a/gpt4all-bindings/python/setup.py +++ b/gpt4all-bindings/python/setup.py @@ -88,9 +88,10 @@ def get_long_description(): python_requires='>=3.8', packages=find_packages(), install_requires=[ + 'importlib_resources; python_version < "3.9"', + 'jinja2~=3.1', 'requests', 'tqdm', - 'importlib_resources; python_version < "3.9"', 'typing-extensions>=4.3.0; python_version >= "3.9" and python_version < "3.11"', ], extras_require={ diff --git a/gpt4all-chat/CMakeLists.txt b/gpt4all-chat/CMakeLists.txt index 9ae28ce9aca6..257338b9510c 100644 --- a/gpt4all-chat/CMakeLists.txt +++ b/gpt4all-chat/CMakeLists.txt @@ -190,6 +190,7 @@ qt_add_executable(chat src/database.cpp src/database.h src/download.cpp src/download.h src/embllm.cpp src/embllm.h + src/jinja_helpers.cpp src/jinja_helpers.h src/llm.cpp src/llm.h src/localdocs.cpp src/localdocs.h src/localdocsmodel.cpp src/localdocsmodel.h @@ -215,6 +216,7 @@ qt_add_qml_module(chat qml/ApplicationSettings.qml qml/ChatDrawer.qml qml/ChatItemView.qml + qml/ChatMessageButton.qml qml/ChatView.qml qml/CollectionsDrawer.qml qml/HomeView.qml @@ -227,7 +229,7 @@ qt_add_qml_module(chat qml/PopupDialog.qml qml/SettingsView.qml qml/StartupDialog.qml - qml/SwitchModelDialog.qml + qml/ConfirmationDialog.qml qml/Theme.qml qml/ThumbsDownDialog.qml qml/Toast.qml @@ -386,7 +388,7 @@ target_include_directories(chat PRIVATE deps/usearch/include target_link_libraries(chat PRIVATE Qt6::Core Qt6::HttpServer Qt6::Pdf Qt6::Quick Qt6::Sql Qt6::Svg) target_link_libraries(chat - PRIVATE llmodel SingleApplication fmt::fmt duckx::duckx QXlsx) + PRIVATE llmodel SingleApplication fmt::fmt duckx::duckx QXlsx jinja2cpp) if (APPLE) target_link_libraries(chat PRIVATE ${COCOA_LIBRARY}) diff --git a/gpt4all-chat/deps/CMakeLists.txt b/gpt4all-chat/deps/CMakeLists.txt index a87a9a203a4b..495ff313ce6d 100644 --- a/gpt4all-chat/deps/CMakeLists.txt +++ b/gpt4all-chat/deps/CMakeLists.txt @@ -11,3 +11,5 @@ add_subdirectory(DuckX) set(QT_VERSION_MAJOR 6) add_subdirectory(QXlsx/QXlsx) + +add_subdirectory(Jinja2Cpp) diff --git a/gpt4all-chat/deps/Jinja2Cpp b/gpt4all-chat/deps/Jinja2Cpp new file mode 160000 index 000000000000..b2a716798bfa --- /dev/null +++ b/gpt4all-chat/deps/Jinja2Cpp @@ -0,0 +1 @@ +Subproject commit b2a716798bfa63c7dae303fc1e272964c4e1f9ee diff --git a/gpt4all-chat/icons/edit.svg b/gpt4all-chat/icons/edit.svg index 5a79a50e9cf4..ceb292bbf0de 100644 --- a/gpt4all-chat/icons/edit.svg +++ b/gpt4all-chat/icons/edit.svg @@ -1,3 +1 @@ - - - + \ No newline at end of file diff --git a/gpt4all-chat/metadata/models3.json b/gpt4all-chat/metadata/models3.json index e4a2162f3317..8c4adacbfdaf 100644 --- a/gpt4all-chat/metadata/models3.json +++ b/gpt4all-chat/metadata/models3.json @@ -29,7 +29,8 @@ "description": "", "url": "https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF/resolve/main/Llama-3.2-3B-Instruct-Q4_0.gguf", "promptTemplate": "<|start_header_id|>user<|end_header_id|>\n\n%1<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n%2", - "systemPrompt": "<|start_header_id|>system<|end_header_id|>\nCutting Knowledge Date: December 2023\n\nYou are a helpful assistant.<|eot_id|>" + "systemPrompt": "<|start_header_id|>system<|end_header_id|>\nCutting Knowledge Date: December 2023\n\nYou are a helpful assistant.<|eot_id|>", + "chatTemplate": "{{- bos_token }}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{%- for message in messages %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n" }, { "order": "c", @@ -45,7 +46,8 @@ "description": "", "url": "https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q4_0.gguf", "promptTemplate": "<|start_header_id|>user<|end_header_id|>\n\n%1<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n%2", - "systemPrompt": "<|start_header_id|>system<|end_header_id|>\nCutting Knowledge Date: December 2023\n\nYou are a helpful assistant.<|eot_id|>" + "systemPrompt": "<|start_header_id|>system<|end_header_id|>\nCutting Knowledge Date: December 2023\n\nYou are a helpful assistant.<|eot_id|>", + "chatTemplate": "{{- bos_token }}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{%- for message in messages %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n" }, { "order": "d", @@ -77,7 +79,8 @@ "systemPrompt": "", "description": "Strong overall fast instruction following model
  • Fast responses
  • Trained by Mistral AI
  • Uncensored
  • Licensed for commercial use
", "url": "https://gpt4all.io/models/gguf/mistral-7b-instruct-v0.1.Q4_0.gguf", - "promptTemplate": "[INST] %1 [/INST]" + "promptTemplate": "[INST] %1 [/INST]", + "chatTemplate": "{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content'] %}\n {%- set loop_start = 1 %}\n{%- else %}\n {%- set loop_start = 0 %}\n{%- endif %}\n{%- for message in messages %}\n {%- if loop.index0 >= loop_start %}\n {%- if (message['role'] == 'user') != ((loop.index0 - loop_start) % 2 == 0) %}\n {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}\n {%- endif %}\n {%- if message['role'] == 'user' %}\n {%- if loop.first and system_message is defined %}\n {{- ' [INST] ' + system_message + '\\n\\n' + message['content'] + ' [/INST]' }}\n {%- else %}\n {{- ' [INST] ' + message['content'] + ' [/INST]' }}\n {%- endif %}\n {%- elif message['role'] == 'assistant' %}\n {{- ' ' + message['content'] + eos_token }}\n {%- else %}\n {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}" }, { "order": "f", @@ -125,7 +128,8 @@ "systemPrompt": "", "description": "Very fast model with good quality
  • Fastest responses
  • Instruction based
  • Trained by TII
  • Finetuned by Nomic AI
  • Licensed for commercial use
", "url": "https://gpt4all.io/models/gguf/gpt4all-falcon-newbpe-q4_0.gguf", - "promptTemplate": "### Instruction:\n%1\n\n### Response:\n" + "promptTemplate": "### Instruction:\n%1\n\n### Response:\n", + "chatTemplate": "{%- if messages[0]['role'] == 'system' %}\n {%- set loop_start = 1 %}\n {{- messages[0]['content'] + '\\n\\n' }}\n{%- else %}\n {%- set loop_start = 0 %}\n{%- endif %}\n{%- for message in messages %}\n {%- if loop.index0 >= loop_start %}\n {%- if message['role'] == 'user' %}\n {{- '### User: ' + message['content'] + '\\n\\n' }}\n {%- elif message['role'] == 'assistant' %}\n {{- '### Assistant: ' + message['content'] + '\\n\\n' }}\n {%- else %}\n {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '### Assistant:' }}\n{%- endif %}" }, { "order": "i", @@ -140,7 +144,8 @@ "type": "LLaMA2", "systemPrompt": "", "description": "
  • Instruction based
  • Trained by Microsoft
  • Cannot be used commercially
", - "url": "https://gpt4all.io/models/gguf/orca-2-7b.Q4_0.gguf" + "url": "https://gpt4all.io/models/gguf/orca-2-7b.Q4_0.gguf", + "chatTemplate": "{%- for message in messages %}\n {{- '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>\\n' }}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}" }, { "order": "j", @@ -155,7 +160,8 @@ "type": "LLaMA2", "systemPrompt": "", "description": "
  • Instruction based
  • Trained by Microsoft
  • Cannot be used commercially
", - "url": "https://gpt4all.io/models/gguf/orca-2-13b.Q4_0.gguf" + "url": "https://gpt4all.io/models/gguf/orca-2-13b.Q4_0.gguf", + "chatTemplate": "{%- for message in messages %}\n {{- '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>\\n' }}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}" }, { "order": "k", @@ -170,7 +176,9 @@ "type": "LLaMA2", "systemPrompt": "", "description": "Strong overall larger model
  • Instruction based
  • Gives very long responses
  • Finetuned with only 1k of high-quality data
  • Trained by Microsoft and Peking University
  • Cannot be used commercially
", - "url": "https://gpt4all.io/models/gguf/wizardlm-13b-v1.2.Q4_0.gguf" + "url": "https://gpt4all.io/models/gguf/wizardlm-13b-v1.2.Q4_0.gguf", + "chatTemplate": "{%- if messages[0]['role'] == 'system' %}\n {%- set loop_start = 1 %}\n {{- messages[0]['content'] + ' ' }}\n{%- else %}\n {%- set loop_start = 0 %}\n{%- endif %}\n{%- for message in loop_messages %}\n {%- if loop.index0 >= loop_start %}\n {%- if message['role'] == 'user' %}\n {{- 'USER: ' + message['content'] }}\n {%- elif message['role'] == 'assistant' %}\n {{- 'ASSISTANT: ' + message['content'] }}\n {%- else %}\n {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n {%- endif %}\n {%- if (loop.index0 - loop_start) % 2 == 0 %}\n {{- ' ' }}\n {%- else %}\n {{- eos_token }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- 'ASSISTANT:' }}\n{%- endif %}", + "systemMessage": "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions." }, { "order": "l", @@ -186,7 +194,8 @@ "description": "Ghost 7B v0.9.1 fast, powerful and smooth for Vietnamese and English languages.", "url": "https://huggingface.co/lamhieu/ghost-7b-v0.9.1-gguf/resolve/main/ghost-7b-v0.9.1-Q4_0.gguf", "promptTemplate": "<|user|>\n%1\n<|assistant|>\n%2\n", - "systemPrompt": "<|system|>\nYou are Ghost created by Lam Hieu. You are a helpful and knowledgeable assistant. You like to help and always give honest information, in its original language. In communication, you are always respectful, equal and promote positive behavior.\n" + "systemPrompt": "<|system|>\nYou are Ghost created by Lam Hieu. You are a helpful and knowledgeable assistant. You like to help and always give honest information, in its original language. In communication, you are always respectful, equal and promote positive behavior.\n", + "systemMessage": "You are Ghost created by Lam Hieu. You are a helpful and knowledgeable assistant. You like to help and always give honest information, in its original language. In communication, you are always respectful, equal and promote positive behavior." }, { "order": "m", @@ -202,7 +211,8 @@ "systemPrompt": "", "description": "Extremely good model
  • Instruction based
  • Gives long responses
  • Curated with 300,000 uncensored instructions
  • Trained by Nous Research
  • Cannot be used commercially
", "url": "https://gpt4all.io/models/gguf/nous-hermes-llama2-13b.Q4_0.gguf", - "promptTemplate": "### Instruction:\n%1\n\n### Response:\n" + "promptTemplate": "### Instruction:\n%1\n\n### Response:\n", + "chatTemplate": "{%- if messages[0]['role'] == 'system' %}\n {%- set loop_start = 1 %}\n {{- messages[0]['content'] + '\\n\\n' }}\n{%- else %}\n {%- set loop_start = 0 %}\n{%- endif %}\n{%- for message in messages %}\n {%- if loop.index0 >= loop_start %}\n {%- if message['role'] == 'user' %}\n {{- '### Instruction:\\n' + message['content'] + '\\n\\n' }}\n {%- elif message['role'] == 'assistant' %}\n {{- '### Response:\\n' + message['content'] + '\\n\\n' }}\n {%- else %}\n {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '### Instruction:\\n' }}\n{%- endif %}" }, { "order": "n", @@ -217,7 +227,9 @@ "type": "LLaMA", "systemPrompt": "", "description": "Very good overall model
  • Instruction based
  • Based on the same dataset as Groovy
  • Slower than Groovy, with higher quality responses
  • Trained by Nomic AI
  • Cannot be used commercially
", - "url": "https://gpt4all.io/models/gguf/gpt4all-13b-snoozy-q4_0.gguf" + "url": "https://gpt4all.io/models/gguf/gpt4all-13b-snoozy-q4_0.gguf", + "chatTemplate": "{%- if messages[0]['role'] == 'system' %}\n {%- set loop_start = 1 %}\n {{- messages[0]['content'] + '\\n\\n' }}\n{%- else %}\n {%- set loop_start = 0 %}\n{%- endif %}\n{%- for message in messages %}\n {%- if loop.index0 >= loop_start %}\n {%- if message['role'] == 'user' %}\n {{- '### Instruction:\\n' + message['content'] + '\\n\\n' }}\n {%- elif message['role'] == 'assistant' %}\n {{- '### Response:\\n' + message['content'] + '\\n\\n' }}\n {%- else %}\n {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '### Response:\\n' }}\n{%- endif %}", + "systemMessage": "Below is an instruction that describes a task. Write a response that appropriately completes the request." }, { "order": "o", @@ -234,7 +246,8 @@ "description": "Good model with novel architecture
  • Fast responses
  • Chat based
  • Trained by Mosaic ML
  • Cannot be used commercially
", "url": "https://gpt4all.io/models/gguf/mpt-7b-chat-newbpe-q4_0.gguf", "promptTemplate": "<|im_start|>user\n%1<|im_end|>\n<|im_start|>assistant\n%2<|im_end|>\n", - "systemPrompt": "<|im_start|>system\n- You are a helpful assistant chatbot trained by MosaicML.\n- You answer questions.\n- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.\n- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.<|im_end|>\n" + "systemPrompt": "<|im_start|>system\n- You are a helpful assistant chatbot trained by MosaicML.\n- You answer questions.\n- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.\n- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.<|im_end|>\n", + "chatTemplate": "{%- for message in messages %}\n {{- '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>\\n' }}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}" }, { "order": "p", @@ -250,7 +263,8 @@ "description": "Good model with novel architecture
  • Fast responses
  • Chat based
  • Trained by Mosaic ML
  • Cannot be used commercially
", "url": "https://gpt4all.io/models/gguf/mpt-7b-chat.gguf4.Q4_0.gguf", "promptTemplate": "<|im_start|>user\n%1<|im_end|>\n<|im_start|>assistant\n%2<|im_end|>\n", - "systemPrompt": "<|im_start|>system\n- You are a helpful assistant chatbot trained by MosaicML.\n- You answer questions.\n- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.\n- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.<|im_end|>\n" + "systemPrompt": "<|im_start|>system\n- You are a helpful assistant chatbot trained by MosaicML.\n- You answer questions.\n- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.\n- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.<|im_end|>\n", + "chatTemplate": "{%- for message in messages %}\n {{- '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>\\n' }}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}" }, { "order": "q", @@ -282,7 +296,8 @@ "description": "Small version of new model with novel dataset
  • Very fast responses
  • Instruction based
  • Explain tuned datasets
  • Orca Research Paper dataset construction approaches
  • Cannot be used commercially
", "url": "https://gpt4all.io/models/gguf/orca-mini-3b-gguf2-q4_0.gguf", "promptTemplate": "### User:\n%1\n\n### Response:\n", - "systemPrompt": "### System:\nYou are an AI assistant that follows instruction extremely well. Help as much as you can.\n\n" + "systemPrompt": "### System:\nYou are an AI assistant that follows instruction extremely well. Help as much as you can.\n\n", + "chatTemplate": "{%- if messages[0]['role'] == 'system' %}\n {%- set loop_start = 1 %}\n {{- '### System:\\n' + messages[0]['content'] + '\\n\\n' }}\n{%- else %}\n {%- set loop_start = 0 %}\n{%- endif %}\n{%- for message in messages %}\n {%- if loop.index0 >= loop_start %}\n {%- if message['role'] == 'user' %}\n {{- '### User:\\n' + message['content'] + '\\n\\n' }}\n {%- elif message['role'] == 'assistant' %}\n {{- '### Response:\\n' + message['content'] + '\\n\\n' }}\n {%- else %}\n {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '### Response:\\n' }}\n{%- endif %}" }, { "order": "s", @@ -299,7 +314,8 @@ "systemPrompt": "", "promptTemplate": "%1", "description": "Trained on subset of the Stack
  • Code completion based
  • Licensed for commercial use
  • WARNING: Not available for chat GUI
", - "url": "https://gpt4all.io/models/gguf/replit-code-v1_5-3b-newbpe-q4_0.gguf" + "url": "https://gpt4all.io/models/gguf/replit-code-v1_5-3b-newbpe-q4_0.gguf", + "chatTemplate": null }, { "order": "t", @@ -316,7 +332,8 @@ "systemPrompt": "", "promptTemplate": "%1", "description": "Trained on subset of the Stack
  • Code completion based
  • WARNING: Not available for chat GUI
", - "url": "https://gpt4all.io/models/gguf/starcoder-newbpe-q4_0.gguf" + "url": "https://gpt4all.io/models/gguf/starcoder-newbpe-q4_0.gguf", + "chatTemplate": null }, { "order": "u", @@ -333,7 +350,8 @@ "systemPrompt": "", "promptTemplate": "%1", "description": "Trained on collection of Python and TypeScript
  • Code completion based
  • WARNING: Not available for chat GUI
  • ", - "url": "https://gpt4all.io/models/gguf/rift-coder-v0-7b-q4_0.gguf" + "url": "https://gpt4all.io/models/gguf/rift-coder-v0-7b-q4_0.gguf", + "chatTemplate": null }, { "order": "v", @@ -351,7 +369,8 @@ "embeddingModel": true, "systemPrompt": "", "description": "LocalDocs text embeddings model
    • For use with LocalDocs feature
    • Used for retrieval augmented generation (RAG)", - "url": "https://gpt4all.io/models/gguf/all-MiniLM-L6-v2-f16.gguf" + "url": "https://gpt4all.io/models/gguf/all-MiniLM-L6-v2-f16.gguf", + "chatTemplate": null }, { "order": "w", @@ -367,7 +386,8 @@ "type": "Bert", "embeddingModel": true, "description": "LocalDocs text embeddings model
      • For use with LocalDocs feature
      • Used for retrieval augmented generation (RAG)", - "url": "https://gpt4all.io/models/gguf/all-MiniLM-L6-v2.gguf2.f16.gguf" + "url": "https://gpt4all.io/models/gguf/all-MiniLM-L6-v2.gguf2.f16.gguf", + "chatTemplate": null }, { "order": "x", @@ -383,7 +403,9 @@ "description": "Mistral-based model for German-language applications
        • Fast responses
        • Chat based model
        • Trained by ellamind
        • Finetuned on German instruction and chat data
        • Licensed for commercial use
        ", "url": "https://huggingface.co/TheBloke/em_german_mistral_v01-GGUF/resolve/main/em_german_mistral_v01.Q4_0.gguf", "promptTemplate": "USER: %1 ASSISTANT: ", - "systemPrompt": "Du bist ein hilfreicher Assistent. " + "systemPrompt": "Du bist ein hilfreicher Assistent. ", + "chatTemplate": "{%- set system_message = false %}\n{%- if messages[0]['role'] == 'system' %}\n {%- set loop_start = 1 %}\n {%- set system_message = true %}\n {{- messages[0]['content'] }}\n{%- else %}\n {%- set loop_start = 0 %}\n{%- endif %}\n{%- for message in messages %}\n {%- if loop.index0 >= loop_start %}\n {%- if (not loop.first) or (system_message is not none) %}\n {{- ' ' }}\n {%- endif %}\n {%- if message['role'] == 'user' %}\n {{- 'USER: ' + message['content'] }}\n {%- elif message['role'] == 'assistant' %}\n {{- 'ASSISTANT: ' + message['content'] }}\n {%- else %}\n {{- raise_exception('After the optional system message, conversation roles must be either user or assistant.') }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {%- if messages %}\n {{- ' ' }}\n {%- endif %}\n {{- 'ASSISTANT:' }}\n{%- endif %}", + "systemMessage": "Du bist ein hilfreicher Assistent." }, { "order": "y", @@ -400,7 +422,8 @@ "embeddingModel": true, "systemPrompt": "", "description": "nomic-embed-text-v1", - "url": "https://gpt4all.io/models/gguf/nomic-embed-text-v1.f16.gguf" + "url": "https://gpt4all.io/models/gguf/nomic-embed-text-v1.f16.gguf", + "chatTemplate": null }, { "order": "z", @@ -417,7 +440,8 @@ "embeddingModel": true, "systemPrompt": "", "description": "nomic-embed-text-v1.5", - "url": "https://gpt4all.io/models/gguf/nomic-embed-text-v1.5.f16.gguf" + "url": "https://gpt4all.io/models/gguf/nomic-embed-text-v1.5.f16.gguf", + "chatTemplate": null }, { "order": "zzz", diff --git a/gpt4all-chat/qml/ApplicationSettings.qml b/gpt4all-chat/qml/ApplicationSettings.qml index e61fc274dad4..c002561563c9 100644 --- a/gpt4all-chat/qml/ApplicationSettings.qml +++ b/gpt4all-chat/qml/ApplicationSettings.qml @@ -10,7 +10,7 @@ import network import llm MySettingsTab { - onRestoreDefaultsClicked: { + onRestoreDefaults: { MySettings.restoreApplicationDefaults(); } title: qsTr("Application") @@ -486,23 +486,6 @@ MySettingsTab { Accessible.name: nThreadsLabel.text Accessible.description: ToolTip.text } - MySettingsLabel { - id: saveChatsContextLabel - text: qsTr("Save Chat Context") - helpText: qsTr("Save the chat model's state to disk for faster loading. WARNING: Uses ~2GB per chat.") - Layout.row: 12 - Layout.column: 0 - } - MyCheckBox { - id: saveChatsContextBox - Layout.row: 12 - Layout.column: 2 - Layout.alignment: Qt.AlignRight - checked: MySettings.saveChatsContext - onClicked: { - MySettings.saveChatsContext = !MySettings.saveChatsContext - } - } MySettingsLabel { id: trayLabel text: qsTr("Enable System Tray") diff --git a/gpt4all-chat/qml/ChatItemView.qml b/gpt4all-chat/qml/ChatItemView.qml index 8a0c04f8f256..ed7476149ddc 100644 --- a/gpt4all-chat/qml/ChatItemView.qml +++ b/gpt4all-chat/qml/ChatItemView.qml @@ -8,8 +8,23 @@ import QtQuick.Layouts import gpt4all import mysettings +ColumnLayout { + +property var inputBoxText: null +signal setInputBoxText(text: string) + +Item { + +Layout.fillWidth: true +Layout.maximumWidth: parent.width +Layout.preferredHeight: gridLayout.height + +HoverHandler { id: hoverArea } + GridLayout { - rows: 5 + id: gridLayout + anchors.left: parent.left + anchors.right: parent.right columns: 2 Item { @@ -40,7 +55,7 @@ GridLayout { to: 360 duration: 1000 loops: Animation.Infinite - running: currentResponse && (currentChat.responseInProgress || currentChat.restoringFromText) + running: isCurrentResponse && currentChat.responseInProgress } } } @@ -73,13 +88,11 @@ GridLayout { color: theme.mutedTextColor } RowLayout { - visible: currentResponse && ((value === "" && currentChat.responseInProgress) || currentChat.restoringFromText) + visible: isCurrentResponse && (value === "" && currentChat.responseInProgress) Text { color: theme.mutedTextColor font.pixelSize: theme.fontSizeLarger text: { - if (currentChat.restoringFromText) - return qsTr("restoring from text ..."); switch (currentChat.responseState) { case Chat.ResponseStopped: return qsTr("response stopped ..."); case Chat.LocalDocsRetrieval: return qsTr("retrieving localdocs: %1 ...").arg(currentChat.collectionList.join(", ")); @@ -99,10 +112,11 @@ GridLayout { Layout.row: 1 Layout.column: 1 Layout.fillWidth: true - spacing: 20 + spacing: 10 Flow { id: attachedUrlsFlow Layout.fillWidth: true + Layout.bottomMargin: 10 spacing: 10 visible: promptAttachments.length !== 0 Repeater { @@ -156,7 +170,7 @@ GridLayout { focus: false readOnly: true font.pixelSize: theme.fontSizeLarge - cursorVisible: currentResponse ? currentChat.responseInProgress : false + cursorVisible: isCurrentResponse ? currentChat.responseInProgress : false cursorPosition: text.length TapHandler { id: tapHandler @@ -183,12 +197,12 @@ GridLayout { } onLinkActivated: function(link) { - if (!currentResponse || !currentChat.responseInProgress) + if (!isCurrentResponse || !currentChat.responseInProgress) Qt.openUrlExternally(link) } onLinkHovered: function (link) { - if (!currentResponse || !currentChat.responseInProgress) + if (!isCurrentResponse || !currentChat.responseInProgress) statusBar.externalHoveredLink = link } @@ -239,13 +253,19 @@ GridLayout { textProcessor.setValue(value); } + property bool textProcessorReady: false + Component.onCompleted: { resetChatViewTextProcessor(); - chatModel.valueChanged.connect(function(i, value) { - if (index === i) + textProcessorReady = true; + } + + Connections { + target: chatModel + function onValueChanged(i, value) { + if (myTextArea.textProcessorReady && index === i) textProcessor.setValue(value); } - ); } Connections { @@ -282,67 +302,6 @@ GridLayout { Network.sendConversation(currentChat.id, getConversationJson()); } } - - Column { - Layout.alignment: Qt.AlignRight - Layout.rightMargin: 15 - visible: name === "Response: " && - (!currentResponse || !currentChat.responseInProgress) && MySettings.networkIsActive - spacing: 10 - - Item { - width: childrenRect.width - height: childrenRect.height - MyToolButton { - id: thumbsUp - width: 24 - height: 24 - imageWidth: width - imageHeight: height - opacity: thumbsUpState || thumbsUpState == thumbsDownState ? 1.0 : 0.2 - source: "qrc:/gpt4all/icons/thumbs_up.svg" - Accessible.name: qsTr("Thumbs up") - Accessible.description: qsTr("Gives a thumbs up to the response") - onClicked: { - if (thumbsUpState && !thumbsDownState) - return - - chatModel.updateNewResponse(index, "") - chatModel.updateThumbsUpState(index, true) - chatModel.updateThumbsDownState(index, false) - Network.sendConversation(currentChat.id, getConversationJson()); - } - } - - MyToolButton { - id: thumbsDown - anchors.top: thumbsUp.top - anchors.topMargin: 3 - anchors.left: thumbsUp.right - anchors.leftMargin: 3 - width: 24 - height: 24 - imageWidth: width - imageHeight: height - checked: thumbsDownState - opacity: thumbsDownState || thumbsUpState == thumbsDownState ? 1.0 : 0.2 - transform: [ - Matrix4x4 { - matrix: Qt.matrix4x4(-1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1) - }, - Translate { - x: thumbsDown.width - } - ] - source: "qrc:/gpt4all/icons/thumbs_down.svg" - Accessible.name: qsTr("Thumbs down") - Accessible.description: qsTr("Opens thumbs down dialog") - onClicked: { - thumbsDownDialog.open() - } - } - } - } } Item { @@ -353,11 +312,13 @@ GridLayout { Layout.preferredWidth: childrenRect.width Layout.preferredHeight: childrenRect.height visible: { + if (name !== "Response: ") + return false if (consolidatedSources.length === 0) return false if (!MySettings.localDocsShowReferences) return false - if (currentResponse && currentChat.responseInProgress + if (isCurrentResponse && currentChat.responseInProgress && currentChat.responseState !== Chat.GeneratingQuestions ) return false return true @@ -443,7 +404,7 @@ GridLayout { return false if (!MySettings.localDocsShowReferences) return false - if (currentResponse && currentChat.responseInProgress + if (isCurrentResponse && currentChat.responseInProgress && currentChat.responseState !== Chat.GeneratingQuestions ) return false return true @@ -566,8 +527,139 @@ GridLayout { } } + ConfirmationDialog { + id: editPromptDialog + dialogTitle: qsTr("Edit this prompt?") + description: qsTr("The existing response and all later messages will be permanently erased.") + onAccepted: { + const msg = currentChat.popPrompt(index); + if (msg !== null) + setInputBoxText(msg); + } + } + + ConfirmationDialog { + id: redoResponseDialog + dialogTitle: qsTr("Redo this response?") + description: qsTr("The existing response and all later messages will be permanently erased.") + onAccepted: currentChat.regenerateResponse(index) + } + + RowLayout { + id: buttonRow + Layout.row: 4 + Layout.column: 1 + Layout.maximumWidth: parent.width + Layout.fillWidth: false + Layout.alignment: Qt.AlignLeft | Qt.AlignTop + spacing: 3 + visible: !isCurrentResponse || !currentChat.responseInProgress + enabled: opacity > 0 + opacity: hoverArea.hovered + readonly property var canModify: !currentChat.isServer && currentChat.isModelLoaded && !currentChat.responseInProgress + + Behavior on opacity { + OpacityAnimator { duration: 30 } + } + + ChatMessageButton { + visible: parent.canModify && model.name === "Prompt: " + Layout.maximumWidth: 24 + Layout.maximumHeight: 24 + Layout.alignment: Qt.AlignVCenter + Layout.fillWidth: false + source: "qrc:/gpt4all/icons/edit.svg" + onClicked: { + if (inputBoxText === "") + editPromptDialog.open(); + } + name: qsTr("Edit") + } + + ChatMessageButton { + visible: parent.canModify && model.name === "Response: " + Layout.maximumWidth: 24 + Layout.maximumHeight: 24 + Layout.alignment: Qt.AlignVCenter + Layout.fillWidth: false + name: qsTr("Redo") + source: "qrc:/gpt4all/icons/regenerate.svg" + onClicked: redoResponseDialog.open() + } + + ChatMessageButton { + Layout.maximumWidth: 24 + Layout.maximumHeight: 24 + Layout.alignment: Qt.AlignVCenter + Layout.fillWidth: false + name: qsTr("Copy") + source: "qrc:/gpt4all/icons/copy.svg" + onClicked: { + myTextArea.selectAll(); + myTextArea.copy(); + myTextArea.deselect(); + } + } + + Item { + visible: name === "Response: " && MySettings.networkIsActive + Layout.alignment: Qt.AlignVCenter + Layout.preferredWidth: childrenRect.width + Layout.preferredHeight: childrenRect.height + Layout.fillWidth: false + + ChatMessageButton { + id: thumbsUp + anchors.left: parent.left + anchors.verticalCenter: parent.verticalCenter + opacity: thumbsUpState || thumbsUpState == thumbsDownState ? 1.0 : 0.2 + source: "qrc:/gpt4all/icons/thumbs_up.svg" + name: qsTr("Like response") + onClicked: { + if (thumbsUpState && !thumbsDownState) + return + + chatModel.updateNewResponse(index, "") + chatModel.updateThumbsUpState(index, true) + chatModel.updateThumbsDownState(index, false) + Network.sendConversation(currentChat.id, getConversationJson()); + } + } + + ChatMessageButton { + id: thumbsDown + anchors.top: thumbsUp.top + anchors.topMargin: buttonRow.spacing + anchors.left: thumbsUp.right + anchors.leftMargin: buttonRow.spacing + checked: thumbsDownState + opacity: thumbsDownState || thumbsUpState == thumbsDownState ? 1.0 : 0.2 + bgTransform: [ + Matrix4x4 { + matrix: Qt.matrix4x4(-1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1) + }, + Translate { + x: thumbsDown.width + } + ] + source: "qrc:/gpt4all/icons/thumbs_down.svg" + name: qsTr("Dislike response") + onClicked: { + thumbsDownDialog.open() + } + } + } + } +} // GridLayout + +} // Item + +GridLayout { + Layout.fillWidth: true + Layout.maximumWidth: parent.width + function shouldShowSuggestions() { - if (!currentResponse) + if (!isCurrentResponse) return false; if (MySettings.suggestionMode === 2) // Off return false; @@ -577,8 +669,8 @@ GridLayout { } Item { - visible: shouldShowSuggestions() - Layout.row: 4 + visible: parent.shouldShowSuggestions() + Layout.row: 5 Layout.column: 0 Layout.topMargin: 20 Layout.alignment: Qt.AlignVCenter | Qt.AlignRight @@ -601,8 +693,8 @@ GridLayout { } Item { - visible: shouldShowSuggestions() - Layout.row: 4 + visible: parent.shouldShowSuggestions() + Layout.row: 5 Layout.column: 1 Layout.topMargin: 20 Layout.fillWidth: true @@ -627,8 +719,8 @@ GridLayout { } ColumnLayout { - visible: shouldShowSuggestions() - Layout.row: 5 + visible: parent.shouldShowSuggestions() + Layout.row: 6 Layout.column: 1 Layout.fillWidth: true Layout.minimumHeight: 1 @@ -786,4 +878,7 @@ GridLayout { } } } -} + +} // GridLayout + +} // ColumnLayout diff --git a/gpt4all-chat/qml/ChatMessageButton.qml b/gpt4all-chat/qml/ChatMessageButton.qml new file mode 100644 index 000000000000..b3c0a31a7f1e --- /dev/null +++ b/gpt4all-chat/qml/ChatMessageButton.qml @@ -0,0 +1,20 @@ +import QtQuick +import QtQuick.Controls + +import gpt4all + +MyToolButton { + property string name + + width: 24 + height: 24 + imageWidth: width + imageHeight: height + ToolTip { + visible: parent.hovered + y: parent.height * 1.5 + text: name + delay: Qt.styleHints.mousePressAndHoldInterval + } + Accessible.name: name +} diff --git a/gpt4all-chat/qml/ChatView.qml b/gpt4all-chat/qml/ChatView.qml index 219a5139801e..b8a5b27b2d6c 100644 --- a/gpt4all-chat/qml/ChatView.qml +++ b/gpt4all-chat/qml/ChatView.qml @@ -24,6 +24,12 @@ Rectangle { property var currentChat: ChatListModel.currentChat property var chatModel: currentChat.chatModel + property var currentModelInfo: currentChat && currentChat.modelInfo + property var currentModelId: null + onCurrentModelInfoChanged: { + const newId = currentModelInfo && currentModelInfo.id; + if (currentModelId !== newId) { currentModelId = newId; } + } signal addCollectionViewRequested() signal addModelViewRequested() @@ -79,14 +85,11 @@ Rectangle { function open_(msg) { message = msg; open(); } } - SwitchModelDialog { + ConfirmationDialog { id: switchModelDialog - anchors.centerIn: parent - Item { - Accessible.role: Accessible.Dialog - Accessible.name: qsTr("Switch model dialog") - Accessible.description: qsTr("Warn the user if they switch models, then context will be erased") - } + property int index: -1 + dialogTitle: qsTr("Erase conversation?") + description: qsTr("Changing the model will erase the current conversation.") } PopupDialog { @@ -103,6 +106,16 @@ Rectangle { font.pixelSize: theme.fontSizeLarge } + ConfirmationDialog { + id: resetContextDialog + dialogTitle: qsTr("Erase conversation?") + description: qsTr("The entire chat will be erased.") + onAccepted: { + Network.trackChatEvent("reset_context", { "length": chatModel.count }); + currentChat.reset(); + } + } + function getConversation() { var conversation = ""; for (var i = 0; i < chatModel.count; i++) { @@ -703,7 +716,7 @@ Rectangle { if (i !== -1) { defaultModel = comboBox.valueAt(i); } else { - defaultModel = comboBox.valueAt(0); + defaultModel = comboBox.count ? comboBox.valueAt(0) : ""; } if (defaultModel !== "") { defaultModelName = ModelList.modelInfo(defaultModel).name; @@ -790,9 +803,9 @@ Rectangle { Layout.leftMargin: 50 Layout.rightMargin: 50 Layout.alignment: Qt.AlignHCenter - spacing: 25 + spacing: 10 model: chatModel - cacheBuffer: Math.max(0, listView.contentHeight) + cacheBuffer: 2147483647 ScrollBar.vertical: ScrollBar { policy: ScrollBar.AsNeeded @@ -804,6 +817,12 @@ Rectangle { delegate: ChatItemView { width: listView.contentItem.width - 15 + inputBoxText: textInput.text + onSetInputBoxText: text => { + textInput.text = text; + textInput.forceActiveFocus(); + textInput.cursorPosition = text.length; + } } function scrollToEnd() { @@ -832,11 +851,9 @@ Rectangle { clip: true z: 400 - property bool isHovered: { - return conversationTrayButton.isHovered || - resetContextButton.hovered || copyChatButton.hovered || - regenerateButton.hovered - } + property bool isHovered: ( + conversationTrayButton.isHovered || resetContextButton.hovered || copyChatButton.hovered + ) state: conversationTrayContent.isHovered ? "expanded" : "collapsed" states: [ @@ -892,11 +909,7 @@ Rectangle { source: "qrc:/gpt4all/icons/recycle.svg" imageWidth: 20 imageHeight: 20 - onClicked: { - Network.trackChatEvent("reset_context", { "length": chatModel.count }) - currentChat.reset(); - currentChat.processSystemPrompt(); - } + onClicked: resetContextDialog.open() ToolTip.visible: resetContextButton.hovered ToolTip.text: qsTr("Erase and reset chat session") } @@ -921,34 +934,6 @@ Rectangle { ToolTip.visible: copyChatButton.hovered ToolTip.text: qsTr("Copy chat session to clipboard") } - MyToolButton { - id: regenerateButton - Layout.preferredWidth: 40 - Layout.preferredHeight: 40 - source: "qrc:/gpt4all/icons/regenerate.svg" - imageWidth: 20 - imageHeight: 20 - visible: chatModel.count && !currentChat.isServer && currentChat.isModelLoaded && !currentChat.responseInProgress - onClicked: { - if (chatModel.count < 2) - return - var promptIndex = chatModel.count - 2 - var promptElement = chatModel.get(promptIndex) - var responseIndex = chatModel.count - 1 - var responseElement = chatModel.get(responseIndex) - if (promptElement.name !== "Prompt: " || responseElement.name !== "Response: ") - return - currentChat.regenerateResponse() - chatModel.updateCurrentResponse(responseIndex, true) - chatModel.updateStopped(responseIndex, false) - chatModel.updateThumbsUpState(responseIndex, false) - chatModel.updateThumbsDownState(responseIndex, false) - chatModel.updateNewResponse(responseIndex, "") - currentChat.prompt(promptElement.promptPlusAttachments) - } - ToolTip.visible: regenerateButton.hovered - ToolTip.text: qsTr("Redo last chat response") - } } } @@ -1026,13 +1011,15 @@ Rectangle { anchors.leftMargin: 30 horizontalAlignment: Qt.AlignRight verticalAlignment: Qt.AlignVCenter - color: theme.mutedTextColor - visible: currentChat.tokenSpeed !== "" || externalHoveredLink !== "" + color: textInputView.error !== null ? theme.textErrorColor : theme.mutedTextColor + visible: currentChat.tokenSpeed !== "" || externalHoveredLink !== "" || textInputView.error !== null elide: Text.ElideRight wrapMode: Text.WordWrap text: { if (externalHoveredLink !== "") return externalHoveredLink + if (textInputView.error !== null) + return textInputView.error; const segments = [currentChat.tokenSpeed]; const device = currentChat.device; @@ -1050,6 +1037,7 @@ Rectangle { } font.pixelSize: theme.fontSizeSmaller font.bold: true + onLinkActivated: function(link) { Qt.openUrlExternally(link) } } RectangularGlow { @@ -1079,8 +1067,8 @@ Rectangle { Rectangle { id: textInputView color: theme.controlBackground - border.width: 1 - border.color: theme.controlBorder + border.width: error === null ? 1 : 2 + border.color: error === null ? theme.controlBorder : theme.textErrorColor radius: 10 anchors.left: parent.left anchors.right: parent.right @@ -1091,6 +1079,41 @@ Rectangle { height: textInputViewLayout.implicitHeight visible: !currentChat.isServer && ModelList.selectableModels.count !== 0 + property var error: null + function checkError() { + const info = currentModelInfo; + if (info === null || !info.id) { + error = null; + } else if (info.chatTemplate.isLegacy) { + error = qsTr("Legacy prompt template needs to be " + + "updated" + + " in Settings."); + } else if (!info.chatTemplate.isSet) { + error = qsTr("No " + + "chat template configured."); + } else if (/^\s*$/.test(info.chatTemplate.value)) { + error = qsTr("The " + + "chat template cannot be blank."); + } else if (info.systemMessage.isLegacy) { + error = qsTr("Legacy system prompt needs to be " + + "updated" + + " in Settings."); + } else + error = null; + } + Component.onCompleted: checkError() + Connections { + target: window + function onCurrentModelIdChanged() { textInputView.checkError(); } + } + Connections { + target: MySettings + function onChatTemplateChanged(info) + { if (info.id === window.currentModelId) textInputView.checkError(); } + function onSystemMessageChanged(info) + { if (info.id === window.currentModelId) textInputView.checkError(); } + } + MouseArea { id: textInputViewMouseArea anchors.fill: parent @@ -1214,16 +1237,16 @@ Rectangle { Accessible.role: Accessible.EditableText Accessible.name: placeholderText Accessible.description: qsTr("Send messages/prompts to the model") - Keys.onReturnPressed: (event)=> { - if (event.modifiers & Qt.ControlModifier || event.modifiers & Qt.ShiftModifier) - event.accepted = false; - else { - editingFinished(); - sendMessage() - } - } + Keys.onReturnPressed: event => { + if (event.modifiers & Qt.ControlModifier || event.modifiers & Qt.ShiftModifier) { + event.accepted = false; + } else if (!chatModel.hasError && textInputView.error === null) { + editingFinished(); + sendMessage(); + } + } function sendMessage() { - if ((textInput.text === "" && attachmentModel.count === 0) || currentChat.responseInProgress || currentChat.restoringFromText) + if ((textInput.text === "" && attachmentModel.count === 0) || currentChat.responseInProgress) return currentChat.stopGenerating() @@ -1338,6 +1361,7 @@ Rectangle { imageWidth: theme.fontSizeLargest imageHeight: theme.fontSizeLargest visible: !currentChat.responseInProgress && !currentChat.isServer && ModelList.selectableModels.count !== 0 + enabled: !chatModel.hasError && textInputView.error === null source: "qrc:/gpt4all/icons/send_message.svg" Accessible.name: qsTr("Send message") Accessible.description: qsTr("Sends the message/prompt contained in textfield to the model") diff --git a/gpt4all-chat/qml/ConfirmationDialog.qml b/gpt4all-chat/qml/ConfirmationDialog.qml new file mode 100644 index 000000000000..4220245320c8 --- /dev/null +++ b/gpt4all-chat/qml/ConfirmationDialog.qml @@ -0,0 +1,59 @@ +import QtCore +import QtQuick +import QtQuick.Controls +import QtQuick.Controls.Basic +import QtQuick.Layouts + +MyDialog { + id: confirmationDialog + anchors.centerIn: parent + modal: true + padding: 20 + property alias dialogTitle: titleText.text + property alias description: descriptionText.text + + Theme { id: theme } + + contentItem: ColumnLayout { + Text { + id: titleText + Layout.alignment: Qt.AlignHCenter + textFormat: Text.StyledText + color: theme.textColor + font.pixelSize: theme.fontSizeLarger + font.bold: true + } + + Text { + id: descriptionText + Layout.alignment: Qt.AlignHCenter + textFormat: Text.StyledText + color: theme.textColor + font.pixelSize: theme.fontSizeMedium + } + } + + footer: DialogButtonBox { + id: dialogBox + padding: 20 + alignment: Qt.AlignRight + spacing: 10 + MySettingsButton { + text: qsTr("OK") + textColor: theme.mediumButtonText + backgroundColor: theme.mediumButtonBackground + backgroundColorHovered: theme.mediumButtonBackgroundHovered + DialogButtonBox.buttonRole: DialogButtonBox.AcceptRole + } + MySettingsButton { + text: qsTr("Cancel") + DialogButtonBox.buttonRole: DialogButtonBox.RejectRole + } + background: Rectangle { + color: "transparent" + } + Keys.onEnterPressed: confirmationDialog.accept() + Keys.onReturnPressed: confirmationDialog.accept() + } + Component.onCompleted: dialogBox.forceActiveFocus() +} diff --git a/gpt4all-chat/qml/LocalDocsSettings.qml b/gpt4all-chat/qml/LocalDocsSettings.qml index a7ea5b75eb41..95124c9c822d 100644 --- a/gpt4all-chat/qml/LocalDocsSettings.qml +++ b/gpt4all-chat/qml/LocalDocsSettings.qml @@ -10,7 +10,7 @@ import mysettings import network MySettingsTab { - onRestoreDefaultsClicked: { + onRestoreDefaults: { MySettings.restoreLocalDocsDefaults(); } diff --git a/gpt4all-chat/qml/ModelSettings.qml b/gpt4all-chat/qml/ModelSettings.qml index 2435e08f8b5d..62906440c936 100644 --- a/gpt4all-chat/qml/ModelSettings.qml +++ b/gpt4all-chat/qml/ModelSettings.qml @@ -8,10 +8,34 @@ import mysettings import chatlistmodel MySettingsTab { - onRestoreDefaultsClicked: { + onRestoreDefaults: { MySettings.restoreModelDefaults(root.currentModelInfo); } title: qsTr("Model") + + ConfirmationDialog { + id: resetSystemMessageDialog + property var index: null + property bool resetClears: false + dialogTitle: qsTr("%1 system message?").arg(resetClears ? qsTr("Clear") : qsTr("Reset")) + description: qsTr("The system message will be %1.").arg(resetClears ? qsTr("removed") : qsTr("reset to the default")) + onAccepted: MySettings.resetModelSystemMessage(ModelList.modelInfo(index)) + function show(index_, resetClears_) { index = index_; resetClears = resetClears_; open(); } + } + + ConfirmationDialog { + id: resetChatTemplateDialog + property bool resetClears: false + property var index: null + dialogTitle: qsTr("%1 chat template?").arg(resetClears ? qsTr("Clear") : qsTr("Reset")) + description: qsTr("The chat template will be %1.").arg(resetClears ? qsTr("erased") : qsTr("reset to the default")) + onAccepted: { + MySettings.resetModelChatTemplate(ModelList.modelInfo(index)); + templateTextArea.resetText(); + } + function show(index_, resetClears_) { index = index_; resetClears = resetClears_; open(); } + } + contentItem: GridLayout { id: root columns: 3 @@ -35,6 +59,7 @@ MySettingsTab { RowLayout { Layout.fillWidth: true + Layout.maximumWidth: parent.width Layout.row: 2 Layout.column: 0 Layout.columnSpan: 2 @@ -153,69 +178,154 @@ MySettingsTab { Layout.fillWidth: true } - MySettingsLabel { - visible: !root.currentModelInfo.isOnline - text: qsTr("System Prompt") - helpText: qsTr("Prefixed at the beginning of every conversation. Must contain the appropriate framing tokens.") + RowLayout { Layout.row: 7 - Layout.column: 0 + Layout.columnSpan: 2 Layout.topMargin: 15 + Layout.fillWidth: true + Layout.maximumWidth: parent.width + spacing: 10 + MySettingsLabel { + id: systemMessageLabel + text: qsTr("System Message") + helpText: qsTr("A message to set the context or guide the behavior of the model. Leave blank for " + + "none. NOTE: Since GPT4All 3.5, this should not contain control tokens.") + onReset: () => resetSystemMessageDialog.show(root.currentModelId, resetClears) + function updateResetButton() { + const info = root.currentModelInfo; + // NOTE: checks if the *override* is set, regardless of whether there is a default + canReset = !!info.id && MySettings.isModelSystemMessageSet(info); + resetClears = !info.defaultSystemMessage; + } + Component.onCompleted: updateResetButton() + Connections { + target: root + function onCurrentModelIdChanged() { systemMessageLabel.updateResetButton(); } + } + Connections { + target: MySettings + function onSystemMessageChanged(info) + { if (info.id === root.currentModelId) systemMessageLabel.updateResetButton(); } + } + } + Label { + id: systemMessageLabelHelp + visible: systemMessageArea.errState !== "ok" + Layout.alignment: Qt.AlignBottom + Layout.fillWidth: true + Layout.rightMargin: 5 + Layout.maximumHeight: systemMessageLabel.height + text: qsTr("System message is not " + + "plain text.") + color: systemMessageArea.errState === "error" ? theme.textErrorColor : theme.textWarningColor + font.pixelSize: theme.fontSizeLarger + font.bold: true + wrapMode: Text.Wrap + elide: Text.ElideRight + onLinkActivated: function(link) { Qt.openUrlExternally(link) } + } } Rectangle { - id: systemPrompt - visible: !root.currentModelInfo.isOnline + id: systemMessage Layout.row: 8 Layout.column: 0 Layout.columnSpan: 2 Layout.fillWidth: true color: "transparent" - Layout.minimumHeight: Math.max(100, systemPromptArea.contentHeight + 20) + Layout.minimumHeight: Math.max(100, systemMessageArea.contentHeight + 20) MyTextArea { - id: systemPromptArea + id: systemMessageArea anchors.fill: parent - text: root.currentModelInfo.systemPrompt + property bool isBeingReset: false + function resetText() { + const info = root.currentModelInfo; + isBeingReset = true; + text = (info.id ? info.systemMessage.value : null) ?? ""; + isBeingReset = false; + } + Component.onCompleted: resetText() Connections { target: MySettings - function onSystemPromptChanged() { - systemPromptArea.text = root.currentModelInfo.systemPrompt; - } + function onSystemMessageChanged(info) + { if (info.id === root.currentModelId) systemMessageArea.resetText(); } } Connections { target: root - function onCurrentModelInfoChanged() { - systemPromptArea.text = root.currentModelInfo.systemPrompt; - } + function onCurrentModelIdChanged() { systemMessageArea.resetText(); } } + // strict validation, because setModelSystemMessage clears isLegacy + readonly property var reLegacyCheck: ( + /(?:^|\s)(?:### *System\b|S(?:ystem|YSTEM):)|<\|(?:im_(?:start|end)|(?:start|end)_header_id|eot_id|SYSTEM_TOKEN)\|>|<>/m + ) onTextChanged: { - MySettings.setModelSystemPrompt(root.currentModelInfo, text) + const info = root.currentModelInfo; + if (!info.id) { + errState = "ok"; + } else if (info.systemMessage.isLegacy && (isBeingReset || reLegacyCheck.test(text))) { + errState = "error"; + } else + errState = reLegacyCheck.test(text) ? "warning" : "ok"; + if (info.id && errState !== "error" && !isBeingReset) + MySettings.setModelSystemMessage(info, text); + systemMessageLabel.updateResetButton(); } Accessible.role: Accessible.EditableText + Accessible.name: systemMessageLabel.text + Accessible.description: systemMessageLabelHelp.text } } RowLayout { Layout.row: 9 - Layout.column: 0 Layout.columnSpan: 2 Layout.topMargin: 15 + Layout.fillWidth: true + Layout.maximumWidth: parent.width spacing: 10 MySettingsLabel { - id: promptTemplateLabel - text: qsTr("Prompt Template") - helpText: qsTr("The template that wraps every prompt.") + id: chatTemplateLabel + text: qsTr("Chat Template") + helpText: qsTr("This Jinja template turns the chat into input for the model.") + onReset: () => resetChatTemplateDialog.show(root.currentModelId, resetClears) + function updateResetButton() { + const info = root.currentModelInfo; + canReset = !!info.id && ( + MySettings.isModelChatTemplateSet(info) + || templateTextArea.text !== (info.chatTemplate.value ?? "") + ); + resetClears = !info.defaultChatTemplate; + } + Component.onCompleted: updateResetButton() + Connections { + target: root + function onCurrentModelIdChanged() { chatTemplateLabel.updateResetButton(); } + } + Connections { + target: MySettings + function onChatTemplateChanged(info) + { if (info.id === root.currentModelId) chatTemplateLabel.updateResetButton(); } + } } - MySettingsLabel { - id: promptTemplateLabelHelp - text: qsTr("Must contain the string \"%1\" to be replaced with the user's input.") - color: theme.textErrorColor - visible: templateTextArea.text.indexOf("%1") === -1 - wrapMode: TextArea.Wrap + Label { + id: chatTemplateLabelHelp + visible: templateTextArea.errState !== "ok" + Layout.alignment: Qt.AlignBottom + Layout.fillWidth: true + Layout.rightMargin: 5 + Layout.maximumHeight: chatTemplateLabel.height + text: templateTextArea.errMsg + color: templateTextArea.errState === "error" ? theme.textErrorColor : theme.textWarningColor + font.pixelSize: theme.fontSizeLarger + font.bold: true + wrapMode: Text.Wrap + elide: Text.ElideRight + onLinkActivated: function(link) { Qt.openUrlExternally(link) } } } Rectangle { - id: promptTemplate + id: chatTemplate Layout.row: 10 Layout.column: 0 Layout.columnSpan: 2 @@ -226,27 +336,71 @@ MySettingsTab { MyTextArea { id: templateTextArea anchors.fill: parent - text: root.currentModelInfo.promptTemplate + font: fixedFont + property bool isBeingReset: false + property var errMsg: null + function resetText() { + const info = root.currentModelInfo; + isBeingReset = true; + text = (info.id ? info.chatTemplate.value : null) ?? ""; + isBeingReset = false; + } + Component.onCompleted: resetText() Connections { target: MySettings - function onPromptTemplateChanged() { - templateTextArea.text = root.currentModelInfo.promptTemplate; - } + function onChatTemplateChanged() { templateTextArea.resetText(); } } Connections { target: root - function onCurrentModelInfoChanged() { - templateTextArea.text = root.currentModelInfo.promptTemplate; - } + function onCurrentModelIdChanged() { templateTextArea.resetText(); } + } + function legacyCheck() { + return /%[12]\b/.test(text) || !/\{%.*%\}.*\{\{.*\}\}.*\{%.*%\}/.test(text.replace(/\n/g, '')) + || !/\bcontent\b/.test(text); } onTextChanged: { - if (templateTextArea.text.indexOf("%1") !== -1) { - MySettings.setModelPromptTemplate(root.currentModelInfo, text) + const info = root.currentModelInfo; + let jinjaError; + if (!info.id) { + errMsg = null; + errState = "ok"; + } else if (info.chatTemplate.isLegacy && (isBeingReset || legacyCheck())) { + errMsg = null; + errState = "error"; + } else if (text === "" && !info.chatTemplate.isSet) { + errMsg = qsTr("No " + + "chat template configured."); + errState = "error"; + } else if (/^\s*$/.test(text)) { + errMsg = qsTr("The " + + "chat template cannot be blank."); + errState = "error"; + } else if ((jinjaError = MySettings.checkJinjaTemplateError(text)) !== null) { + errMsg = qsTr("Syntax" + + " error: %1").arg(jinjaError); + errState = "error"; + } else if (legacyCheck()) { + errMsg = qsTr("Chat template is not in " + + "" + + "Jinja format.") + errState = "warning"; + } else { + errState = "ok"; + } + if (info.id && errState !== "error" && !isBeingReset) + MySettings.setModelChatTemplate(info, text); + chatTemplateLabel.updateResetButton(); + } + Keys.onPressed: event => { + if (event.key === Qt.Key_Tab) { + const a = templateTextArea; + event.accepted = true; // suppress tab + a.insert(a.cursorPosition, ' '); // four spaces } } Accessible.role: Accessible.EditableText - Accessible.name: promptTemplateLabel.text - Accessible.description: promptTemplateLabelHelp.text + Accessible.name: chatTemplateLabel.text + Accessible.description: chatTemplateLabelHelp.text } } diff --git a/gpt4all-chat/qml/MySettingsButton.qml b/gpt4all-chat/qml/MySettingsButton.qml index 18de21afbbad..218a329c2f57 100644 --- a/gpt4all-chat/qml/MySettingsButton.qml +++ b/gpt4all-chat/qml/MySettingsButton.qml @@ -17,6 +17,7 @@ Button { property color borderColor: "transparent" property real fontPixelSize: theme.fontSizeLarge property string toolTip + property alias backgroundRadius: background.radius contentItem: Text { text: myButton.text @@ -28,6 +29,7 @@ Button { Accessible.name: text } background: Rectangle { + id: background radius: 10 border.width: borderWidth border.color: borderColor diff --git a/gpt4all-chat/qml/MySettingsLabel.qml b/gpt4all-chat/qml/MySettingsLabel.qml index 282bdc7332d3..2f0ba3c606b4 100644 --- a/gpt4all-chat/qml/MySettingsLabel.qml +++ b/gpt4all-chat/qml/MySettingsLabel.qml @@ -17,13 +17,42 @@ ColumnLayout { property alias color: mainTextLabel.color property alias linkColor: mainTextLabel.linkColor - Label { - id: mainTextLabel - color: theme.settingsTitleTextColor - font.pixelSize: theme.fontSizeLarger - font.bold: true - onLinkActivated: function(link) { - root.linkActivated(link); + property var onReset: null + property alias canReset: resetButton.enabled + property bool resetClears: false + + Item { + anchors.margins: 5 + width: childrenRect.width + height: mainTextLabel.contentHeight + + Label { + id: mainTextLabel + anchors.left: parent.left + anchors.top: parent.top + anchors.bottom: parent.bottom + color: theme.settingsTitleTextColor + font.pixelSize: theme.fontSizeLarger + font.bold: true + verticalAlignment: Text.AlignVCenter + onLinkActivated: function(link) { + root.linkActivated(link); + } + } + + MySettingsButton { + id: resetButton + anchors.baseline: mainTextLabel.baseline + anchors.left: mainTextLabel.right + height: mainTextLabel.contentHeight + anchors.leftMargin: 10 + padding: 2 + leftPadding: 10 + rightPadding: 10 + backgroundRadius: 5 + text: resetClears ? qsTr("Clear") : qsTr("Reset") + visible: root.onReset !== null + onClicked: root.onReset() } } Label { diff --git a/gpt4all-chat/qml/MySettingsTab.qml b/gpt4all-chat/qml/MySettingsTab.qml index 98ed402ec666..41657f0b7c87 100644 --- a/gpt4all-chat/qml/MySettingsTab.qml +++ b/gpt4all-chat/qml/MySettingsTab.qml @@ -9,7 +9,7 @@ Item { property string title: "" property Item contentItem: null property bool showRestoreDefaultsButton: true - signal restoreDefaultsClicked + signal restoreDefaults onContentItemChanged: function() { if (contentItem) { @@ -19,6 +19,13 @@ Item { } } + ConfirmationDialog { + id: restoreDefaultsDialog + dialogTitle: qsTr("Restore defaults?") + description: qsTr("This page of settings will be reset to the defaults.") + onAccepted: root.restoreDefaults() + } + ScrollView { id: scrollView width: parent.width @@ -47,6 +54,7 @@ Item { Column { id: contentInner Layout.fillWidth: true + Layout.maximumWidth: parent.width } Item { @@ -63,9 +71,7 @@ Item { Accessible.role: Accessible.Button Accessible.name: text Accessible.description: qsTr("Restores settings dialog to a default state") - onClicked: { - root.restoreDefaultsClicked(); - } + onClicked: restoreDefaultsDialog.open() } } } diff --git a/gpt4all-chat/qml/MyTextArea.qml b/gpt4all-chat/qml/MyTextArea.qml index e0894e9fbb0f..bace1f26dc04 100644 --- a/gpt4all-chat/qml/MyTextArea.qml +++ b/gpt4all-chat/qml/MyTextArea.qml @@ -5,18 +5,27 @@ import QtQuick.Controls.Basic TextArea { id: myTextArea + + property string errState: "ok" // one of "ok", "error", "warning" + color: enabled ? theme.textColor : theme.mutedTextColor placeholderTextColor: theme.mutedTextColor font.pixelSize: theme.fontSizeLarge background: Rectangle { implicitWidth: 150 color: theme.controlBackground - border.width: 1 - border.color: theme.controlBorder + border.width: errState === "ok" ? 1 : 2 + border.color: { + switch (errState) { + case "ok": return theme.controlBorder; + case "warning": return theme.textWarningColor; + case "error": return theme.textErrorColor; + } + } radius: 10 } padding: 10 wrapMode: TextArea.Wrap ToolTip.delay: Qt.styleHints.mousePressAndHoldInterval -} \ No newline at end of file +} diff --git a/gpt4all-chat/qml/MyToolButton.qml b/gpt4all-chat/qml/MyToolButton.qml index b9d8c544aa1e..f62af5a0516d 100644 --- a/gpt4all-chat/qml/MyToolButton.qml +++ b/gpt4all-chat/qml/MyToolButton.qml @@ -16,6 +16,7 @@ Button { property alias fillMode: image.fillMode property alias imageWidth: image.sourceSize.width property alias imageHeight: image.sourceSize.height + property alias bgTransform: background.transform contentItem: Text { text: myButton.text horizontalAlignment: Text.AlignHCenter @@ -26,6 +27,7 @@ Button { } background: Item { + id: background anchors.fill: parent Rectangle { anchors.fill: parent diff --git a/gpt4all-chat/qml/SwitchModelDialog.qml b/gpt4all-chat/qml/SwitchModelDialog.qml deleted file mode 100644 index f0ca43abbc24..000000000000 --- a/gpt4all-chat/qml/SwitchModelDialog.qml +++ /dev/null @@ -1,46 +0,0 @@ -import QtCore -import QtQuick -import QtQuick.Controls -import QtQuick.Controls.Basic -import QtQuick.Layouts -import llm -import mysettings - -MyDialog { - id: switchModelDialog - anchors.centerIn: parent - modal: true - padding: 20 - property int index: -1 - - Theme { - id: theme - } - - contentItem: Text { - textFormat: Text.StyledText - text: qsTr("Warning: changing the model will erase the current conversation. Do you wish to continue?") - color: theme.textColor - font.pixelSize: theme.fontSizeLarge - } - - footer: DialogButtonBox { - id: dialogBox - padding: 20 - alignment: Qt.AlignRight - spacing: 10 - MySettingsButton { - text: qsTr("Continue") - Accessible.description: qsTr("Continue with model loading") - DialogButtonBox.buttonRole: DialogButtonBox.AcceptRole - } - MySettingsButton { - text: qsTr("Cancel") - Accessible.description: qsTr("Cancel") - DialogButtonBox.buttonRole: DialogButtonBox.RejectRole - } - background: Rectangle { - color: "transparent" - } - } -} diff --git a/gpt4all-chat/qml/Theme.qml b/gpt4all-chat/qml/Theme.qml index 245a4473b0e8..e2675820449a 100644 --- a/gpt4all-chat/qml/Theme.qml +++ b/gpt4all-chat/qml/Theme.qml @@ -64,6 +64,9 @@ QtObject { property color green800: Qt.hsla(123/360, 0.17, 0.24) property color green900: Qt.hsla(124/360, 0.17, 0.20) property color green950: Qt.hsla(125/360, 0.22, 0.10) + property color green300_sat: Qt.hsla(122/360, 0.24, 0.73) + property color green400_sat: Qt.hsla(122/360, 0.23, 0.58) + property color green450_sat: Qt.hsla(122/360, 0.23, 0.52) // yellow property color yellow0: Qt.hsla(47/360, 0.90, 0.99) @@ -99,6 +102,7 @@ QtObject { property color purple200: Qt.hsla(279/360, 1.0, 0.91) property color purple300: Qt.hsla(279/360, 1.0, 0.84) property color purple400: Qt.hsla(279/360, 1.0, 0.73) + property color purple450: Qt.hsla(279/360, 1.0, 0.68) property color purple500: Qt.hsla(279/360, 1.0, 0.63) property color purple600: Qt.hsla(279/360, 1.0, 0.53) property color purple700: Qt.hsla(279/360, 1.0, 0.47) @@ -408,6 +412,39 @@ QtObject { } } + property color mediumButtonBackground: { + switch (MySettings.chatTheme) { + case MySettingsEnums.ChatTheme.LegacyDark: + return purple400 + case MySettingsEnums.ChatTheme.Dark: + return green400_sat + default: + return green400_sat + } + } + + property color mediumButtonBackgroundHovered: { + switch (MySettings.chatTheme) { + case MySettingsEnums.ChatTheme.LegacyDark: + return purple450 + case MySettingsEnums.ChatTheme.Dark: + return green450_sat + default: + return green300_sat + } + } + + property color mediumButtonText: { + switch (MySettings.chatTheme) { + case MySettingsEnums.ChatTheme.LegacyDark: + return textColor + case MySettingsEnums.ChatTheme.Dark: + return textColor + default: + return white + } + } + property color darkButtonText: { switch (MySettings.chatTheme) { case MySettingsEnums.ChatTheme.LegacyDark: @@ -922,16 +959,8 @@ QtObject { } } - property color textErrorColor: { - switch (MySettings.chatTheme) { - case MySettingsEnums.ChatTheme.LegacyDark: - return red400 - case MySettingsEnums.ChatTheme.Dark: - return red400 - default: - return red400 - } - } + readonly property color textErrorColor: red400 + readonly property color textWarningColor: yellow400 property color settingsTitleTextColor: { switch (MySettings.chatTheme) { diff --git a/gpt4all-chat/src/chat.cpp b/gpt4all-chat/src/chat.cpp index dc6abd0621d9..c40bb96ed34b 100644 --- a/gpt4all-chat/src/chat.cpp +++ b/gpt4all-chat/src/chat.cpp @@ -1,7 +1,6 @@ #include "chat.h" #include "chatlistmodel.h" -#include "mysettings.h" #include "network.h" #include "server.h" @@ -11,7 +10,6 @@ #include #include #include -#include #include #include #include @@ -56,18 +54,18 @@ void Chat::connectLLM() // Should be in different threads connect(m_llmodel, &ChatLLM::modelLoadingPercentageChanged, this, &Chat::handleModelLoadingPercentageChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::responseChanged, this, &Chat::handleResponseChanged, Qt::QueuedConnection); + connect(m_llmodel, &ChatLLM::responseFailed, this, &Chat::handleResponseFailed, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::promptProcessing, this, &Chat::promptProcessing, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::generatingQuestions, this, &Chat::generatingQuestions, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::responseStopped, this, &Chat::responseStopped, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::modelLoadingError, this, &Chat::handleModelLoadingError, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::modelLoadingWarning, this, &Chat::modelLoadingWarning, Qt::QueuedConnection); - connect(m_llmodel, &ChatLLM::restoringFromTextChanged, this, &Chat::handleRestoringFromText, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::generatedNameChanged, this, &Chat::generatedNameChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::generatedQuestionFinished, this, &Chat::generatedQuestionFinished, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::reportSpeed, this, &Chat::handleTokenSpeedChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::loadedModelInfoChanged, this, &Chat::loadedModelInfoChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::databaseResultsChanged, this, &Chat::handleDatabaseResultsChanged, Qt::QueuedConnection); - connect(m_llmodel, &ChatLLM::modelInfoChanged, this, &Chat::handleModelInfoChanged, Qt::QueuedConnection); + connect(m_llmodel, &ChatLLM::modelInfoChanged, this, &Chat::handleModelChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::trySwitchContextOfLoadedModelCompleted, this, &Chat::handleTrySwitchContextOfLoadedModelCompleted, Qt::QueuedConnection); connect(this, &Chat::promptRequested, m_llmodel, &ChatLLM::prompt, Qt::QueuedConnection); @@ -75,11 +73,10 @@ void Chat::connectLLM() connect(this, &Chat::loadDefaultModelRequested, m_llmodel, &ChatLLM::loadDefaultModel, Qt::QueuedConnection); connect(this, &Chat::generateNameRequested, m_llmodel, &ChatLLM::generateName, Qt::QueuedConnection); connect(this, &Chat::regenerateResponseRequested, m_llmodel, &ChatLLM::regenerateResponse, Qt::QueuedConnection); - connect(this, &Chat::resetResponseRequested, m_llmodel, &ChatLLM::resetResponse, Qt::QueuedConnection); - connect(this, &Chat::resetContextRequested, m_llmodel, &ChatLLM::resetContext, Qt::QueuedConnection); - connect(this, &Chat::processSystemPromptRequested, m_llmodel, &ChatLLM::processSystemPrompt, Qt::QueuedConnection); connect(this, &Chat::collectionListChanged, m_collectionModel, &LocalDocsCollectionsModel::setCollections); + + connect(ModelList::globalInstance(), &ModelList::modelInfoChanged, this, &Chat::handleModelInfoChanged); } void Chat::reset() @@ -87,28 +84,17 @@ void Chat::reset() stopGenerating(); // Erase our current on disk representation as we're completely resetting the chat along with id ChatListModel::globalInstance()->removeChatFile(this); - emit resetContextRequested(); m_id = Network::globalInstance()->generateUniqueId(); emit idChanged(m_id); // NOTE: We deliberately do no reset the name or creation date to indicate that this was originally // an older chat that was reset for another purpose. Resetting this data will lead to the chat // name label changing back to 'New Chat' and showing up in the chat model list as a 'New Chat' // further down in the list. This might surprise the user. In the future, we might get rid of - // the "reset context" button in the UI. Right now, by changing the model in the combobox dropdown - // we effectively do a reset context. We *have* to do this right now when switching between different - // types of models. The only way to get rid of that would be a very long recalculate where we rebuild - // the context if we switch between different types of models. Probably the right way to fix this - // is to allow switching models but throwing up a dialog warning users if we switch between types - // of models that a long recalculation will ensue. + // the "reset context" button in the UI. m_chatModel->clear(); m_needsSave = true; } -void Chat::processSystemPrompt() -{ - emit processSystemPromptRequested(); -} - void Chat::resetResponseState() { if (m_responseInProgress && m_responseState == Chat::LocalDocsRetrieval) @@ -160,25 +146,30 @@ void Chat::newPromptResponsePair(const QString &prompt, const QList &attac if (!attachedContexts.isEmpty()) promptPlusAttached = attachedContexts.join("\n\n") + "\n\n" + prompt; - newPromptResponsePairInternal(prompt, attachments); - emit resetResponseRequested(); + resetResponseState(); + qsizetype prevMsgIndex = m_chatModel->count() - 1; + if (prevMsgIndex >= 0) + m_chatModel->updateCurrentResponse(prevMsgIndex, false); + m_chatModel->appendPrompt(prompt, attachments); + m_chatModel->appendResponse(prevMsgIndex + 1); - this->prompt(promptPlusAttached); + emit promptRequested(m_collections); + m_needsSave = true; } -void Chat::prompt(const QString &prompt) +void Chat::regenerateResponse(int index) { resetResponseState(); - emit promptRequested(m_collections, prompt); + emit regenerateResponseRequested(index); m_needsSave = true; } -void Chat::regenerateResponse() +QVariant Chat::popPrompt(int index) { - const int index = m_chatModel->count() - 1; - m_chatModel->updateSources(index, QList()); - emit regenerateResponseRequested(); + auto content = m_llmodel->popPrompt(index); m_needsSave = true; + if (content) return *content; + return QVariant::fromValue(nullptr); } void Chat::stopGenerating() @@ -202,6 +193,14 @@ void Chat::handleResponseChanged(const QString &response) m_chatModel->updateValue(index, response); } +void Chat::handleResponseFailed(const QString &error) +{ + const int index = m_chatModel->count() - 1; + m_chatModel->updateValue(index, error); + m_chatModel->setError(); + responseStopped(0); +} + void Chat::handleModelLoadingPercentageChanged(float loadingPercentage) { if (m_shouldDeleteLater) @@ -272,25 +271,6 @@ void Chat::setModelInfo(const ModelInfo &modelInfo) emit modelChangeRequested(modelInfo); } -// the server needs to block until response is reset, so it calls resetResponse on its own m_llmThread -void Chat::serverNewPromptResponsePair(const QString &prompt, const QList &attachments) -{ - newPromptResponsePairInternal(prompt, attachments); -} - -void Chat::newPromptResponsePairInternal(const QString &prompt, const QList &attachments) -{ - resetResponseState(); - m_chatModel->updateCurrentResponse(m_chatModel->count() - 1, false); - m_chatModel->appendPrompt("Prompt: ", prompt, attachments); - m_chatModel->appendResponse("Response: "); -} - -bool Chat::restoringFromText() const -{ - return m_llmodel->restoringFromText(); -} - void Chat::unloadAndDeleteLater() { if (!isModelLoaded()) { @@ -356,12 +336,6 @@ void Chat::generatedQuestionFinished(const QString &question) m_needsSave = true; } -void Chat::handleRestoringFromText() -{ - Network::globalInstance()->trackChatEvent("recalc_context", { {"length", m_chatModel->count()} }); - emit restoringFromTextChanged(); -} - void Chat::handleModelLoadingError(const QString &error) { if (!error.isEmpty()) { @@ -396,12 +370,19 @@ QString Chat::fallbackReason() const void Chat::handleDatabaseResultsChanged(const QList &results) { m_databaseResults = results; - const int index = m_chatModel->count() - 1; - m_chatModel->updateSources(index, m_databaseResults); m_needsSave = true; } +// we need to notify listeners of the modelInfo property when its properties are updated, +// since it's a gadget and can't do that on its own void Chat::handleModelInfoChanged(const ModelInfo &modelInfo) +{ + if (!m_modelInfo.id().isNull() && modelInfo.id() == m_modelInfo.id()) + emit modelInfoChanged(); +} + +// react if a new model is loaded +void Chat::handleModelChanged(const ModelInfo &modelInfo) { if (m_modelInfo == modelInfo) return; @@ -430,10 +411,7 @@ bool Chat::serialize(QDataStream &stream, int version) const if (version >= 3) stream << m_collections; - const bool serializeKV = MySettings::globalInstance()->saveChatsContext(); - if (version >= 6) - stream << serializeKV; - if (!m_llmodel->serialize(stream, version, serializeKV)) + if (!m_llmodel->serialize(stream, version)) return false; if (!m_chatModel->serialize(stream, version)) return false; @@ -462,19 +440,13 @@ bool Chat::deserialize(QDataStream &stream, int version) if (!m_modelInfo.id().isEmpty()) emit modelInfoChanged(); - bool discardKV = m_modelInfo.id().isEmpty(); - if (version >= 3) { stream >> m_collections; emit collectionListChanged(m_collections); } - bool deserializeKV = true; - if (version >= 6) - stream >> deserializeKV; - m_llmodel->setModelInfo(m_modelInfo); - if (!m_llmodel->deserialize(stream, version, deserializeKV, discardKV)) + if (!m_llmodel->deserialize(stream, version)) return false; if (!m_chatModel->deserialize(stream, version)) return false; diff --git a/gpt4all-chat/src/chat.h b/gpt4all-chat/src/chat.h index 245bc018d40c..57e413e5873d 100644 --- a/gpt4all-chat/src/chat.h +++ b/gpt4all-chat/src/chat.h @@ -12,6 +12,8 @@ #include #include #include +#include // IWYU pragma: keep +#include #include class QDataStream; @@ -27,7 +29,6 @@ class Chat : public QObject Q_PROPERTY(float modelLoadingPercentage READ modelLoadingPercentage NOTIFY modelLoadingPercentageChanged) Q_PROPERTY(ModelInfo modelInfo READ modelInfo WRITE setModelInfo NOTIFY modelInfoChanged) Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged) - Q_PROPERTY(bool restoringFromText READ restoringFromText NOTIFY restoringFromTextChanged) Q_PROPERTY(bool isServer READ isServer NOTIFY isServerChanged) Q_PROPERTY(ResponseState responseState READ responseState NOTIFY responseStateChanged) Q_PROPERTY(QList collectionList READ collectionList NOTIFY collectionListChanged) @@ -77,13 +78,12 @@ class Chat : public QObject bool isNewChat() const { return m_name == tr("New Chat") && !m_chatModel->count(); } Q_INVOKABLE void reset(); - Q_INVOKABLE void processSystemPrompt(); bool isModelLoaded() const { return m_modelLoadingPercentage == 1.0f; } bool isCurrentlyLoading() const { return m_modelLoadingPercentage > 0.0f && m_modelLoadingPercentage < 1.0f; } float modelLoadingPercentage() const { return m_modelLoadingPercentage; } Q_INVOKABLE void newPromptResponsePair(const QString &prompt, const QList &attachedUrls = {}); - Q_INVOKABLE void prompt(const QString &prompt); - Q_INVOKABLE void regenerateResponse(); + Q_INVOKABLE void regenerateResponse(int index); + Q_INVOKABLE QVariant popPrompt(int index); Q_INVOKABLE void stopGenerating(); QList databaseResults() const { return m_databaseResults; } @@ -92,7 +92,6 @@ class Chat : public QObject ResponseState responseState() const; ModelInfo modelInfo() const; void setModelInfo(const ModelInfo &modelInfo); - bool restoringFromText() const; Q_INVOKABLE void unloadModel(); Q_INVOKABLE void reloadModel(); @@ -113,7 +112,6 @@ class Chat : public QObject Q_INVOKABLE bool hasCollection(const QString &collection) const; Q_INVOKABLE void addCollection(const QString &collection); Q_INVOKABLE void removeCollection(const QString &collection); - void resetResponseState(); QString modelLoadingError() const { return m_modelLoadingError; } @@ -131,7 +129,7 @@ class Chat : public QObject void setNeedsSave(bool n) { m_needsSave = n; } public Q_SLOTS: - void serverNewPromptResponsePair(const QString &prompt, const QList &attachments = {}); + void resetResponseState(); Q_SIGNALS: void idChanged(const QString &id); @@ -143,14 +141,12 @@ public Q_SLOTS: void modelLoadingWarning(const QString &warning); void responseInProgressChanged(); void responseStateChanged(); - void promptRequested(const QList &collectionList, const QString &prompt); - void regenerateResponseRequested(); + void promptRequested(const QStringList &enabledCollections); + void regenerateResponseRequested(int index); void resetResponseRequested(); void resetContextRequested(); - void processSystemPromptRequested(); void modelChangeRequested(const ModelInfo &modelInfo); void modelInfoChanged(); - void restoringFromTextChanged(); void loadDefaultModelRequested(); void generateNameRequested(); void modelLoadingErrorChanged(); @@ -166,22 +162,20 @@ public Q_SLOTS: private Q_SLOTS: void handleResponseChanged(const QString &response); + void handleResponseFailed(const QString &error); void handleModelLoadingPercentageChanged(float); void promptProcessing(); void generatingQuestions(); void responseStopped(qint64 promptResponseMs); void generatedNameChanged(const QString &name); void generatedQuestionFinished(const QString &question); - void handleRestoringFromText(); void handleModelLoadingError(const QString &error); void handleTokenSpeedChanged(const QString &tokenSpeed); void handleDatabaseResultsChanged(const QList &results); void handleModelInfoChanged(const ModelInfo &modelInfo); + void handleModelChanged(const ModelInfo &modelInfo); void handleTrySwitchContextOfLoadedModelCompleted(int value); -private: - void newPromptResponsePairInternal(const QString &prompt, const QList &attachments); - private: QString m_id; QString m_name; diff --git a/gpt4all-chat/src/chatapi.cpp b/gpt4all-chat/src/chatapi.cpp index 27f64f0d6730..5164cac32169 100644 --- a/gpt4all-chat/src/chatapi.cpp +++ b/gpt4all-chat/src/chatapi.cpp @@ -1,10 +1,10 @@ #include "chatapi.h" -#include +#include "utils.h" #include -#include #include +#include #include #include #include @@ -13,12 +13,17 @@ #include #include #include +#include #include +#include #include #include #include +#include +#include #include +#include using namespace Qt::Literals::StringLiterals; @@ -67,71 +72,119 @@ bool ChatAPI::isModelLoaded() const return true; } -void ChatAPI::prompt(const std::string &prompt, - const std::string &promptTemplate, - std::function promptCallback, - std::function responseCallback, - bool allowContextShift, - PromptContext &promptCtx, - bool special, - std::optional fakeReply) { - - Q_UNUSED(promptCallback); - Q_UNUSED(allowContextShift); - Q_UNUSED(special); - - if (!isModelLoaded()) { - std::cerr << "ChatAPI ERROR: prompt won't work with an unloaded model!\n"; - return; - } - - if (!promptCtx.n_past) { m_queuedPrompts.clear(); } - Q_ASSERT(promptCtx.n_past <= m_context.size()); - m_context.resize(promptCtx.n_past); - - // FIXME(cebtenzzre): We're assuming people don't try to use %2 with ChatGPT. What would that even mean? - m_queuedPrompts << QString::fromStdString(promptTemplate).arg(QString::fromStdString(prompt)); +static auto parsePrompt(QXmlStreamReader &xml) -> std::expected +{ + QJsonArray messages; - if (!promptCtx.n_predict && !fakeReply) { - return; // response explicitly suppressed, queue prompt for later + auto xmlError = [&xml] { + return std::unexpected(u"%1:%2: %3"_s.arg(xml.lineNumber()).arg(xml.columnNumber()).arg(xml.errorString())); + }; + + if (xml.hasError()) + return xmlError(); + if (xml.atEnd()) + return messages; + + // skip header + bool foundElement = false; + do { + switch (xml.readNext()) { + using enum QXmlStreamReader::TokenType; + case Invalid: + return xmlError(); + case EndDocument: + return messages; + default: + foundElement = true; + case StartDocument: + case Comment: + case DTD: + case ProcessingInstruction: + ; + } + } while (!foundElement); + + // document body loop + bool foundRoot = false; + for (;;) { + switch (xml.tokenType()) { + using enum QXmlStreamReader::TokenType; + case StartElement: + { + auto name = xml.name(); + if (!foundRoot) { + if (name != "chat"_L1) + return std::unexpected(u"unexpected tag: %1"_s.arg(name)); + foundRoot = true; + } else { + if (name != "user"_L1 && name != "assistant"_L1 && name != "system"_L1) + return std::unexpected(u"unknown role: %1"_s.arg(name)); + auto content = xml.readElementText(); + if (xml.tokenType() != EndElement) + return xmlError(); + messages << makeJsonObject({ + { "role"_L1, name.toString().trimmed() }, + { "content"_L1, content }, + }); + } + break; + } + case Characters: + if (!xml.isWhitespace()) + return std::unexpected(u"unexpected text: %1"_s.arg(xml.text())); + case Comment: + case ProcessingInstruction: + case EndElement: + break; + case EndDocument: + return messages; + case Invalid: + return xmlError(); + default: + return std::unexpected(u"unexpected token: %1"_s.arg(xml.tokenString())); + } + xml.readNext(); } +} - QString formattedPrompt = m_queuedPrompts.join(""); - m_queuedPrompts.clear(); +void ChatAPI::prompt( + std::string_view prompt, + const PromptCallback &promptCallback, + const ResponseCallback &responseCallback, + const PromptContext &promptCtx +) { + Q_UNUSED(promptCallback) - if (fakeReply) { - promptCtx.n_past += 1; - m_context.append(formattedPrompt); - m_context.append(QString::fromUtf8(fakeReply->data(), fakeReply->size())); - return; - } + if (!isModelLoaded()) + throw std::invalid_argument("Attempted to prompt an unloaded model."); + if (!promptCtx.n_predict) + return; // nothing requested // FIXME: We don't set the max_tokens on purpose because in order to do so safely without encountering // an error we need to be able to count the tokens in our prompt. The only way to do this is to use - // the OpenAI tiktokken library or to implement our own tokenization function that matches precisely + // the OpenAI tiktoken library or to implement our own tokenization function that matches precisely // the tokenization used by the OpenAI model we're calling. OpenAI has not introduced any means of // using the REST API to count tokens in a prompt. - QJsonObject root; - root.insert("model", m_modelName); - root.insert("stream", true); - root.insert("temperature", promptCtx.temp); - root.insert("top_p", promptCtx.top_p); + auto root = makeJsonObject({ + { "model"_L1, m_modelName }, + { "stream"_L1, true }, + { "temperature"_L1, promptCtx.temp }, + { "top_p"_L1, promptCtx.top_p }, + }); // conversation history - QJsonArray messages; - for (int i = 0; i < m_context.count(); ++i) { - QJsonObject message; - message.insert("role", i % 2 == 0 ? "user" : "assistant"); - message.insert("content", m_context.at(i)); - messages.append(message); + { + QUtf8StringView promptUtf8(prompt); + QXmlStreamReader xml(promptUtf8); + auto messages = parsePrompt(xml); + if (!messages) { + auto error = fmt::format("Failed to parse API model prompt: {}", messages.error()); + qDebug().noquote() << "ChatAPI ERROR:" << error << "Prompt:\n\n" << promptUtf8 << '\n'; + throw std::invalid_argument(error); + } + root.insert("messages"_L1, *messages); } - QJsonObject promptObject; - promptObject.insert("role", "user"); - promptObject.insert("content", formattedPrompt); - messages.append(promptObject); - root.insert("messages", messages); - QJsonDocument doc(root); #if defined(DEBUG) @@ -148,12 +201,9 @@ void ChatAPI::prompt(const std::string &prompt, connect(&worker, &ChatAPIWorker::finished, &workerThread, &QThread::quit, Qt::DirectConnection); connect(this, &ChatAPI::request, &worker, &ChatAPIWorker::request, Qt::QueuedConnection); workerThread.start(); - emit request(m_apiKey, &promptCtx, doc.toJson(QJsonDocument::Compact)); + emit request(m_apiKey, doc.toJson(QJsonDocument::Compact)); workerThread.wait(); - promptCtx.n_past += 1; - m_context.append(formattedPrompt); - m_context.append(worker.currentResponse()); m_responseCallback = nullptr; #if defined(DEBUG) @@ -171,12 +221,8 @@ bool ChatAPI::callResponse(int32_t token, const std::string& string) return m_responseCallback(token, string); } -void ChatAPIWorker::request(const QString &apiKey, - LLModel::PromptContext *promptCtx, - const QByteArray &array) +void ChatAPIWorker::request(const QString &apiKey, const QByteArray &array) { - m_ctx = promptCtx; - QUrl apiUrl(m_chat->url()); const QString authorization = u"Bearer %1"_s.arg(apiKey).trimmed(); QNetworkRequest request(apiUrl); @@ -283,7 +329,6 @@ void ChatAPIWorker::handleReadyRead() const QJsonObject choice = choices.first().toObject(); const QJsonObject delta = choice.value("delta").toObject(); const QString content = delta.value("content").toString(); - Q_ASSERT(m_ctx); m_currentResponse += content; if (!m_chat->callResponse(0, content.toStdString())) { reply->abort(); diff --git a/gpt4all-chat/src/chatapi.h b/gpt4all-chat/src/chatapi.h index f37a105d29f1..b763c32524b2 100644 --- a/gpt4all-chat/src/chatapi.h +++ b/gpt4all-chat/src/chatapi.h @@ -7,17 +7,14 @@ #include #include #include -#include -#include #include #include -#include -#include #include #include #include #include +#include #include class QNetworkAccessManager; @@ -28,16 +25,13 @@ class ChatAPIWorker : public QObject { public: ChatAPIWorker(ChatAPI *chatAPI) : QObject(nullptr) - , m_ctx(nullptr) , m_networkManager(nullptr) , m_chat(chatAPI) {} virtual ~ChatAPIWorker() {} QString currentResponse() const { return m_currentResponse; } - void request(const QString &apiKey, - LLModel::PromptContext *promptCtx, - const QByteArray &array); + void request(const QString &apiKey, const QByteArray &array); Q_SIGNALS: void finished(); @@ -49,7 +43,6 @@ private Q_SLOTS: private: ChatAPI *m_chat; - LLModel::PromptContext *m_ctx; QNetworkAccessManager *m_networkManager; QString m_currentResponse; }; @@ -74,14 +67,14 @@ class ChatAPI : public QObject, public LLModel { size_t restoreState(std::span state, std::span inputTokens) override { Q_UNUSED(state); Q_UNUSED(inputTokens); throwNotImplemented(); } - void prompt(const std::string &prompt, - const std::string &promptTemplate, - std::function promptCallback, - std::function responseCallback, - bool allowContextShift, - PromptContext &ctx, - bool special, - std::optional fakeReply) override; + void prompt(std::string_view prompt, + const PromptCallback &promptCallback, + const ResponseCallback &responseCallback, + const PromptContext &ctx) override; + + [[noreturn]] + int32_t countPromptTokens(std::string_view prompt) const override + { Q_UNUSED(prompt); throwNotImplemented(); } void setThreadCount(int32_t n_threads) override; int32_t threadCount() const override; @@ -91,19 +84,17 @@ class ChatAPI : public QObject, public LLModel { void setRequestURL(const QString &requestURL) { m_requestURL = requestURL; } QString url() const { return m_requestURL; } - QList context() const { return m_context; } - void setContext(const QList &context) { m_context = context; } - bool callResponse(int32_t token, const std::string &string); [[noreturn]] int32_t contextLength() const override { throwNotImplemented(); } + auto specialTokens() -> std::unordered_map const override + { return {}; } + Q_SIGNALS: - void request(const QString &apiKey, - LLModel::PromptContext *ctx, - const QByteArray &array); + void request(const QString &apiKey, const QByteArray &array); protected: // We have to implement these as they are pure virtual in base class, but we don't actually use @@ -114,8 +105,8 @@ class ChatAPI : public QObject, public LLModel { static void throwNotImplemented() { throw std::logic_error("not implemented"); } [[noreturn]] - std::vector tokenize(std::string_view str, bool special) override - { Q_UNUSED(str); Q_UNUSED(special); throwNotImplemented(); } + std::vector tokenize(std::string_view str) const override + { Q_UNUSED(str); throwNotImplemented(); } [[noreturn]] bool isSpecialToken(Token id) const override @@ -126,7 +117,7 @@ class ChatAPI : public QObject, public LLModel { { Q_UNUSED(id); throwNotImplemented(); } [[noreturn]] - void initSampler(PromptContext &ctx) override + void initSampler(const PromptContext &ctx) override { Q_UNUSED(ctx); throwNotImplemented(); } [[noreturn]] @@ -134,33 +125,28 @@ class ChatAPI : public QObject, public LLModel { { throwNotImplemented(); } [[noreturn]] - bool evalTokens(PromptContext &ctx, std::span tokens) const override - { Q_UNUSED(ctx); Q_UNUSED(tokens); throwNotImplemented(); } + bool evalTokens(int32_t nPast, std::span tokens) const override + { Q_UNUSED(nPast); Q_UNUSED(tokens); throwNotImplemented(); } [[noreturn]] - void shiftContext(PromptContext &promptCtx) override - { Q_UNUSED(promptCtx); throwNotImplemented(); } + void shiftContext(const PromptContext &promptCtx, int32_t *nPast) override + { Q_UNUSED(promptCtx); Q_UNUSED(nPast); throwNotImplemented(); } [[noreturn]] int32_t inputLength() const override { throwNotImplemented(); } [[noreturn]] - void setTokenizeInputPosition(int32_t pos) override - { Q_UNUSED(pos); throwNotImplemented(); } - - [[noreturn]] - auto computeModelInputPosition(PromptContext &ctx, const std::vector &input) - -> std::vector::const_iterator override - { Q_UNUSED(ctx); Q_UNUSED(input); throwNotImplemented(); } + int32_t computeModelInputPosition(std::span input) const override + { Q_UNUSED(input); throwNotImplemented(); } [[noreturn]] - void setModelInputPosition(PromptContext &ctx, int32_t pos) override - { Q_UNUSED(ctx); Q_UNUSED(pos); throwNotImplemented(); } + void setModelInputPosition(int32_t pos) override + { Q_UNUSED(pos); throwNotImplemented(); } [[noreturn]] - void appendInputToken(PromptContext &ctx, Token tok) override - { Q_UNUSED(ctx); Q_UNUSED(tok); throwNotImplemented(); } + void appendInputToken(Token tok) override + { Q_UNUSED(tok); throwNotImplemented(); } [[noreturn]] const std::vector &endTokens() const override @@ -175,12 +161,10 @@ class ChatAPI : public QObject, public LLModel { { throwNotImplemented(); } private: - std::function m_responseCallback; - QString m_modelName; - QString m_apiKey; - QString m_requestURL; - QList m_context; - QStringList m_queuedPrompts; + ResponseCallback m_responseCallback; + QString m_modelName; + QString m_apiKey; + QString m_requestURL; }; #endif // CHATAPI_H diff --git a/gpt4all-chat/src/chatlistmodel.cpp b/gpt4all-chat/src/chatlistmodel.cpp index 207a2b3b7e8b..bf76ce4449ae 100644 --- a/gpt4all-chat/src/chatlistmodel.cpp +++ b/gpt4all-chat/src/chatlistmodel.cpp @@ -17,9 +17,10 @@ #include #include +#include -#define CHAT_FORMAT_MAGIC 0xF5D553CC -#define CHAT_FORMAT_VERSION 10 +static constexpr quint32 CHAT_FORMAT_MAGIC = 0xF5D553CC; +static constexpr qint32 CHAT_FORMAT_VERSION = 11; class MyChatListModel: public ChatListModel { }; Q_GLOBAL_STATIC(MyChatListModel, chatListModelInstance) @@ -118,8 +119,8 @@ void ChatSaver::saveChats(const QVector &chats) } QDataStream out(&tempFile); - out << (quint32)CHAT_FORMAT_MAGIC; - out << (qint32)CHAT_FORMAT_VERSION; + out << CHAT_FORMAT_MAGIC; + out << CHAT_FORMAT_VERSION; out.setVersion(QDataStream::Qt_6_2); qDebug() << "serializing chat" << fileName; @@ -257,12 +258,15 @@ void ChatsRestoreThread::run() qDebug() << "deserializing chat" << f.file; - Chat *chat = new Chat; + auto chat = std::make_unique(); chat->moveToThread(qGuiApp->thread()); - if (!chat->deserialize(in, version)) { + bool ok = chat->deserialize(in, version); + if (!ok) { qWarning() << "ERROR: Couldn't deserialize chat from file:" << file.fileName(); + } else if (!in.atEnd()) { + qWarning().nospace() << "error loading chat from " << file.fileName() << ": extra data at end of file"; } else { - emit chatRestored(chat); + emit chatRestored(chat.release()); } if (f.oldFile) file.remove(); // No longer storing in this directory diff --git a/gpt4all-chat/src/chatllm.cpp b/gpt4all-chat/src/chatllm.cpp index e1cae8c8c1c9..7841b9460e99 100644 --- a/gpt4all-chat/src/chatllm.cpp +++ b/gpt4all-chat/src/chatllm.cpp @@ -3,10 +3,18 @@ #include "chat.h" #include "chatapi.h" #include "chatmodel.h" +#include "jinja_helpers.h" #include "localdocs.h" #include "mysettings.h" #include "network.h" +#include + +#include +#include +#include +#include + #include #include #include @@ -17,34 +25,71 @@ #include #include #include -#include +#include // IWYU pragma: keep #include +#include #include -#include #include #include #include #include #include -#include +#include #include #include -#include +#include +#include +#include #include #include +#include +#include #include +#include +#include #include +#include #include #include using namespace Qt::Literals::StringLiterals; +namespace ranges = std::ranges; //#define DEBUG //#define DEBUG_MODEL_LOADING -static constexpr int LLAMA_INTERNAL_STATE_VERSION = 0; -static constexpr int API_INTERNAL_STATE_VERSION = 0; +// NOTE: not threadsafe +static jinja2::TemplateEnv *jinjaEnv() +{ + static std::optional environment; + if (!environment) { + auto &env = environment.emplace(); + auto &settings = env.GetSettings(); + settings.trimBlocks = true; + settings.lstripBlocks = true; + env.AddGlobal("raise_exception", jinja2::UserCallable( + /*callable*/ [](auto ¶ms) -> jinja2::Value { + auto &message = params.args.at("message").asString(); + throw std::runtime_error(fmt::format("Jinja template error: {}", message)); + }, + /*argsInfo*/ { jinja2::ArgInfo("message", /*isMandatory*/ true) } + )); + env.AddGlobal("strftime_now", jinja2::UserCallable( + /*callable*/ [](auto ¶ms) -> jinja2::Value { + using Clock = std::chrono::system_clock; + auto &format = params.args.at("format").asString(); + time_t nowUnix = Clock::to_time_t(Clock::now()); + auto localDate = *std::localtime(&nowUnix); + std::ostringstream ss; + ss << std::put_time(&localDate, format.c_str()); + return ss.str(); + }, + /*argsInfo*/ { jinja2::ArgInfo("format", /*isMandatory*/ true) } + )); + } + return &*environment; +} class LLModelStore { public: @@ -107,9 +152,6 @@ void LLModelInfo::resetModel(ChatLLM *cllm, LLModel *model) { ChatLLM::ChatLLM(Chat *parent, bool isServer) : QObject{nullptr} , m_chat(parent) - , m_promptResponseTokens(0) - , m_promptTokens(0) - , m_restoringFromText(false) , m_shouldBeLoaded(false) , m_forceUnloadModel(false) , m_markedForDeletion(false) @@ -118,8 +160,6 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer) , m_isServer(isServer) , m_forceMetal(MySettings::globalInstance()->forceMetal()) , m_reloadingToChangeVariant(false) - , m_processedSystemPrompt(false) - , m_restoreStateFromText(false) , m_chatModel(parent->chatModel()) { moveToThread(&m_llmThread); @@ -241,12 +281,8 @@ void ChatLLM::trySwitchContextOfLoadedModel(const ModelInfo &modelInfo) #endif emit trySwitchContextOfLoadedModelCompleted(2); - - // Restore, signal and process - restoreState(); emit modelLoadingPercentageChanged(1.0f); emit trySwitchContextOfLoadedModelCompleted(0); - processSystemPrompt(); } bool ChatLLM::loadModel(const ModelInfo &modelInfo) @@ -260,15 +296,13 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo) // to provide an overview of what we're doing here. if (isModelLoaded() && this->modelInfo() == modelInfo) { - // already acquired -> keep it and reset - resetContext(); + // already acquired -> keep it return true; // already loaded } // reset status emit modelLoadingPercentageChanged(std::numeric_limits::min()); // small non-zero positive value emit modelLoadingError(""); - m_pristineLoadedState = false; QString filePath = modelInfo.dirpath + modelInfo.filename(); QFileInfo fileInfo(filePath); @@ -276,7 +310,6 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo) // We have a live model, but it isn't the one we want bool alreadyAcquired = isModelLoaded(); if (alreadyAcquired) { - resetContext(); #if defined(DEBUG_MODEL_LOADING) qDebug() << "already acquired model deleted" << m_llmThread.objectName() << m_llModelInfo.model.get(); #endif @@ -306,14 +339,11 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo) #if defined(DEBUG_MODEL_LOADING) qDebug() << "store had our model" << m_llmThread.objectName() << m_llModelInfo.model.get(); #endif - restoreState(); emit modelLoadingPercentageChanged(1.0f); setModelInfo(modelInfo); Q_ASSERT(!m_modelInfo.filename().isEmpty()); if (m_modelInfo.filename().isEmpty()) emit modelLoadingError(u"Modelinfo is left null for %1"_s.arg(modelInfo.filename())); - else - processSystemPrompt(); return true; } else { // Release the memory since we have to switch to a different model. @@ -371,7 +401,6 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo) #if defined(DEBUG_MODEL_LOADING) qDebug() << "new model" << m_llmThread.objectName() << m_llModelInfo.model.get(); #endif - restoreState(); #if defined(DEBUG) qDebug() << "modelLoadedChanged" << m_llmThread.objectName(); fflush(stdout); @@ -389,10 +418,8 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo) emit modelLoadingError(u"Could not find file for model %1"_s.arg(modelInfo.filename())); } - if (m_llModelInfo.model) { + if (m_llModelInfo.model) setModelInfo(modelInfo); - processSystemPrompt(); - } return bool(m_llModelInfo.model); } @@ -594,71 +621,57 @@ bool ChatLLM::isModelLoaded() const return m_llModelInfo.model && m_llModelInfo.model->isModelLoaded(); } -std::string remove_leading_whitespace(const std::string& input) +static QString &removeLeadingWhitespace(QString &s) { - auto first_non_whitespace = std::find_if(input.begin(), input.end(), [](unsigned char c) { - return !std::isspace(c); - }); - - if (first_non_whitespace == input.end()) - return std::string(); - - return std::string(first_non_whitespace, input.end()); + auto firstNonSpace = ranges::find_if_not(s, [](auto c) { return c.isSpace(); }); + s.remove(0, firstNonSpace - s.begin()); + return s; } -std::string trim_whitespace(const std::string& input) +template + requires std::convertible_to, QChar> +bool isAllSpace(R &&r) { - auto first_non_whitespace = std::find_if(input.begin(), input.end(), [](unsigned char c) { - return !std::isspace(c); - }); - - if (first_non_whitespace == input.end()) - return std::string(); - - auto last_non_whitespace = std::find_if(input.rbegin(), input.rend(), [](unsigned char c) { - return !std::isspace(c); - }).base(); - - return std::string(first_non_whitespace, last_non_whitespace); -} - -// FIXME(jared): we don't actually have to re-decode the prompt to generate a new response -void ChatLLM::regenerateResponse() -{ - // ChatGPT uses a different semantic meaning for n_past than local models. For ChatGPT, the meaning - // of n_past is of the number of prompt/response pairs, rather than for total tokens. - if (m_llModelType == LLModelTypeV1::API) - m_ctx.n_past -= 1; - else - m_ctx.n_past -= m_promptResponseTokens; - m_ctx.n_past = std::max(0, m_ctx.n_past); - m_promptResponseTokens = 0; - m_promptTokens = 0; - m_response = m_trimmedResponse = std::string(); - emit responseChanged(QString::fromStdString(m_trimmedResponse)); + return ranges::all_of(std::forward(r), [](QChar c) { return c.isSpace(); }); } -void ChatLLM::resetResponse() +void ChatLLM::regenerateResponse(int index) { - m_promptTokens = 0; - m_promptResponseTokens = 0; - m_response = m_trimmedResponse = std::string(); - emit responseChanged(QString::fromStdString(m_trimmedResponse)); -} + Q_ASSERT(m_chatModel); + int promptIdx; + { + auto items = m_chatModel->chatItems(); // holds lock + if (index < 1 || index >= items.size() || items[index].type() != ChatItem::Type::Response) + return; + promptIdx = m_chatModel->getPeerUnlocked(index).value_or(-1); + } -void ChatLLM::resetContext() -{ - resetResponse(); - m_processedSystemPrompt = false; - m_ctx = LLModel::PromptContext(); + emit responseChanged({}); + m_chatModel->truncate(index + 1); + m_chatModel->updateCurrentResponse(index, true ); + m_chatModel->updateNewResponse (index, {} ); + m_chatModel->updateStopped (index, false); + m_chatModel->updateThumbsUpState (index, false); + m_chatModel->updateThumbsDownState(index, false); + m_chatModel->setError(false); + if (promptIdx >= 0) + m_chatModel->updateSources(promptIdx, {}); + + prompt(m_chat->collectionList()); } -QString ChatLLM::response(bool trim) const +std::optional ChatLLM::popPrompt(int index) { - std::string resp = m_response; - if (trim) - resp = remove_leading_whitespace(resp); - return QString::fromStdString(resp); + Q_ASSERT(m_chatModel); + QString content; + { + auto items = m_chatModel->chatItems(); // holds lock + if (index < 0 || index >= items.size() || items[index].type() != ChatItem::Type::Prompt) + return std::nullopt; + content = items[index].value; + } + m_chatModel->truncate(index); + return content; } ModelInfo ChatLLM::modelInfo() const @@ -693,148 +706,283 @@ void ChatLLM::modelChangeRequested(const ModelInfo &modelInfo) } } -bool ChatLLM::handlePrompt(int32_t token) +static LLModel::PromptContext promptContextFromSettings(const ModelInfo &modelInfo) { - // m_promptResponseTokens is related to last prompt/response not - // the entire context window which we can reset on regenerate prompt -#if defined(DEBUG) - qDebug() << "prompt process" << m_llmThread.objectName() << token; -#endif - ++m_promptTokens; - ++m_promptResponseTokens; - m_timer->start(); - return !m_stopGenerating; + auto *mySettings = MySettings::globalInstance(); + return { + .n_predict = mySettings->modelMaxLength (modelInfo), + .top_k = mySettings->modelTopK (modelInfo), + .top_p = float(mySettings->modelTopP (modelInfo)), + .min_p = float(mySettings->modelMinP (modelInfo)), + .temp = float(mySettings->modelTemperature (modelInfo)), + .n_batch = mySettings->modelPromptBatchSize (modelInfo), + .repeat_penalty = float(mySettings->modelRepeatPenalty(modelInfo)), + .repeat_last_n = mySettings->modelRepeatPenaltyTokens(modelInfo), + }; } -bool ChatLLM::handleResponse(int32_t token, const std::string &response) +void ChatLLM::prompt(const QStringList &enabledCollections) { -#if defined(DEBUG) - printf("%s", response.c_str()); - fflush(stdout); -#endif + if (!isModelLoaded()) { + emit responseStopped(0); + return; + } - // check for error - // FIXME (Adam) The error messages should not be treated as a model response or part of the - // normal conversation. They should be serialized along with the conversation, but the strings - // are separate and we should preserve info that these are error messages and not actual model responses. - if (token < 0) { - m_response.append(response); - m_trimmedResponse = remove_leading_whitespace(m_response); - emit responseChanged(QString::fromStdString(m_trimmedResponse)); - return false; + try { + promptInternalChat(enabledCollections, promptContextFromSettings(m_modelInfo)); + } catch (const std::exception &e) { + // FIXME(jared): this is neither translated nor serialized + emit responseFailed(u"Error: %1"_s.arg(QString::fromUtf8(e.what()))); + emit responseStopped(0); } +} + +// FIXME(jared): We can avoid this potentially expensive copy if we use ChatItem pointers, but this is only safe if we +// hold the lock while generating. We can't do that now because Chat is actually in charge of updating the response, not +// ChatLLM. +std::vector ChatLLM::forkConversation(const QString &prompt) const +{ + Q_ASSERT(m_chatModel); + if (m_chatModel->hasError()) + throw std::logic_error("cannot continue conversation with an error"); - // m_promptResponseTokens is related to last prompt/response not - // the entire context window which we can reset on regenerate prompt - ++m_promptResponseTokens; - m_timer->inc(); - Q_ASSERT(!response.empty()); - m_response.append(response); - m_trimmedResponse = remove_leading_whitespace(m_response); - emit responseChanged(QString::fromStdString(m_trimmedResponse)); - return !m_stopGenerating; + std::vector conversation; + { + auto items = m_chatModel->chatItems(); // holds lock + Q_ASSERT(items.size() >= 2); // should be prompt/response pairs + conversation.reserve(items.size() + 1); + conversation.assign(items.begin(), items.end()); + } + conversation.emplace_back(ChatItem::prompt_tag, prompt); + return conversation; } -bool ChatLLM::prompt(const QList &collectionList, const QString &prompt) +// version 0 (default): HF compatible +// version 1: explicit LocalDocs formatting +static uint parseJinjaTemplateVersion(QStringView tmpl) { - if (m_restoreStateFromText) { - Q_ASSERT(m_state.isEmpty()); - processRestoreStateFromText(); + static uint MAX_VERSION = 1; + static QRegularExpression reVersion(uR"(\A{#-?\s+gpt4all v(\d+)-?#}\s*$)"_s, QRegularExpression::MultilineOption); + if (auto match = reVersion.matchView(tmpl); match.hasMatch()) { + uint ver = match.captured(1).toUInt(); + if (ver > MAX_VERSION) + throw std::out_of_range(fmt::format("Unknown template version: {}", ver)); + return ver; } + return 0; +} - const QString promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo); - const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo); - const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo); - const float top_p = MySettings::globalInstance()->modelTopP(m_modelInfo); - const float min_p = MySettings::globalInstance()->modelMinP(m_modelInfo); - const float temp = MySettings::globalInstance()->modelTemperature(m_modelInfo); - const int32_t n_batch = MySettings::globalInstance()->modelPromptBatchSize(m_modelInfo); - const float repeat_penalty = MySettings::globalInstance()->modelRepeatPenalty(m_modelInfo); - const int32_t repeat_penalty_tokens = MySettings::globalInstance()->modelRepeatPenaltyTokens(m_modelInfo); - return promptInternal(collectionList, prompt, promptTemplate, n_predict, top_k, top_p, min_p, temp, n_batch, - repeat_penalty, repeat_penalty_tokens); +static auto loadJinjaTemplate( + std::optional &tmpl /*out*/, const std::string &source +) -> jinja2::Result +{ + tmpl.emplace(jinjaEnv()); + return tmpl->Load(source); } -bool ChatLLM::promptInternal(const QList &collectionList, const QString &prompt, const QString &promptTemplate, - int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty, - int32_t repeat_penalty_tokens, std::optional fakeReply) +std::optional ChatLLM::checkJinjaTemplateError(const std::string &source) { - if (!isModelLoaded()) - return false; + std::optional tmpl; + if (auto res = loadJinjaTemplate(tmpl, source); !res) + return res.error().ToString(); + return std::nullopt; +} + +std::string ChatLLM::applyJinjaTemplate(std::span items) const +{ + Q_ASSERT(items.size() >= 1); - if (!m_processedSystemPrompt) - processSystemPrompt(); + auto *mySettings = MySettings::globalInstance(); + auto &model = m_llModelInfo.model; + + QString chatTemplate, systemMessage; + auto chatTemplateSetting = mySettings->modelChatTemplate(m_modelInfo); + if (auto tmpl = chatTemplateSetting.asModern()) { + chatTemplate = *tmpl; + } else if (chatTemplateSetting.isLegacy()) { + throw std::logic_error("cannot apply Jinja to a legacy prompt template"); + } else { + throw std::logic_error("cannot apply Jinja without setting a chat template first"); + } + if (isAllSpace(chatTemplate)) { + throw std::logic_error("cannot apply Jinja with a blank chat template"); + } + if (auto tmpl = mySettings->modelSystemMessage(m_modelInfo).asModern()) { + systemMessage = *tmpl; + } else { + throw std::logic_error("cannot apply Jinja with a legacy system message"); + } + + uint version = parseJinjaTemplateVersion(chatTemplate); + + auto makeMap = [version](const ChatItem &item) { + return jinja2::GenericMap([msg = std::make_shared(version, item)] { return msg.get(); }); + }; + + std::unique_ptr systemItem; + bool useSystem = !isAllSpace(systemMessage); + + jinja2::ValuesList messages; + messages.reserve(useSystem + items.size()); + if (useSystem) { + systemItem = std::make_unique(ChatItem::system_tag, systemMessage); + messages.emplace_back(makeMap(*systemItem)); + } + for (auto &item : items) + messages.emplace_back(makeMap(item)); + + jinja2::ValuesMap params { + { "messages", std::move(messages) }, + { "add_generation_prompt", true }, + }; + for (auto &[name, token] : model->specialTokens()) + params.emplace(std::move(name), std::move(token)); + + std::optional tmpl; + auto maybeRendered = loadJinjaTemplate(tmpl, chatTemplate.toStdString()) + .and_then([&] { return tmpl->RenderAsString(params); }); + if (!maybeRendered) + throw std::runtime_error(fmt::format("Failed to parse chat template: {}", maybeRendered.error().ToString())); + return *maybeRendered; +} + +auto ChatLLM::promptInternalChat(const QStringList &enabledCollections, const LLModel::PromptContext &ctx) + -> ChatPromptResult +{ + Q_ASSERT(isModelLoaded()); + Q_ASSERT(m_chatModel); QList databaseResults; const int retrievalSize = MySettings::globalInstance()->localDocsRetrievalSize(); - if (!fakeReply && !collectionList.isEmpty()) { - emit requestRetrieveFromDB(collectionList, prompt, retrievalSize, &databaseResults); // blocks - emit databaseResultsChanged(databaseResults); + if (!enabledCollections.isEmpty()) { + std::optional> query; + { + // Find the prompt that represents the query. Server chats are flexible and may not have one. + auto items = m_chatModel->chatItems(); // holds lock + Q_ASSERT(items); + auto response = items.end() - 1; + if (auto peer = m_chatModel->getPeerUnlocked(response)) + query = {*peer - items.begin(), (*peer)->value}; + } + if (query) { + auto &[promptIndex, queryStr] = *query; + emit requestRetrieveFromDB(enabledCollections, queryStr, retrievalSize, &databaseResults); // blocks + m_chatModel->updateSources(promptIndex, databaseResults); + emit databaseResultsChanged(databaseResults); + } } - // Augment the prompt template with the results if any - QString docsContext; - if (!databaseResults.isEmpty()) { - QStringList results; - for (const ResultInfo &info : databaseResults) - results << u"Collection: %1\nPath: %2\nExcerpt: %3"_s.arg(info.collection, info.path, info.text); + // copy messages for safety (since we can't hold the lock the whole time) + std::vector chatItems; + { + auto items = m_chatModel->chatItems(); // holds lock + Q_ASSERT(items.size() >= 2); // should be prompt/response pairs + chatItems.assign(items.begin(), items.end() - 1); // exclude last + } + auto result = promptInternal(chatItems, ctx, !databaseResults.isEmpty()); + return { + /*PromptResult*/ { + .response = std::move(result.response), + .promptTokens = result.promptTokens, + .responseTokens = result.responseTokens, + }, + /*databaseResults*/ std::move(databaseResults), + }; +} + +auto ChatLLM::promptInternal( + const std::variant, std::string_view> &prompt, + const LLModel::PromptContext &ctx, + bool usedLocalDocs +) -> PromptResult +{ + Q_ASSERT(isModelLoaded()); - // FIXME(jared): use a Jinja prompt template instead of hardcoded Alpaca-style localdocs template - docsContext = u"### Context:\n%1\n\n"_s.arg(results.join("\n\n")); + auto *mySettings = MySettings::globalInstance(); + + // unpack prompt argument + const std::span *chatItems = nullptr; + std::string jinjaBuffer; + std::string_view conversation; + if (auto *nonChat = std::get_if(&prompt)) { + conversation = *nonChat; // complete the string without a template + } else { + chatItems = &std::get>(prompt); + jinjaBuffer = applyJinjaTemplate(*chatItems); + conversation = jinjaBuffer; } - int n_threads = MySettings::globalInstance()->threadCount(); - - m_stopGenerating = false; - auto promptFunc = std::bind(&ChatLLM::handlePrompt, this, std::placeholders::_1); - auto responseFunc = std::bind(&ChatLLM::handleResponse, this, std::placeholders::_1, - std::placeholders::_2); - emit promptProcessing(); - m_ctx.n_predict = n_predict; - m_ctx.top_k = top_k; - m_ctx.top_p = top_p; - m_ctx.min_p = min_p; - m_ctx.temp = temp; - m_ctx.n_batch = n_batch; - m_ctx.repeat_penalty = repeat_penalty; - m_ctx.repeat_last_n = repeat_penalty_tokens; - m_llModelInfo.model->setThreadCount(n_threads); -#if defined(DEBUG) - printf("%s", qPrintable(prompt)); - fflush(stdout); -#endif + // check for overlength last message + if (!dynamic_cast(m_llModelInfo.model.get())) { + auto nCtx = m_llModelInfo.model->contextLength(); + std::string jinjaBuffer2; + auto lastMessageRendered = (chatItems && chatItems->size() > 1) + ? std::string_view(jinjaBuffer2 = applyJinjaTemplate({ &chatItems->back(), 1 })) + : conversation; + int32_t lastMessageLength = m_llModelInfo.model->countPromptTokens(lastMessageRendered); + if (auto limit = nCtx - 4; lastMessageLength > limit) { + throw std::invalid_argument( + tr("Your message was too long and could not be processed (%1 > %2). " + "Please try again with something shorter.").arg(lastMessageLength, limit).toUtf8().constData() + ); + } + } + + PromptResult result; + + auto handlePrompt = [this, &result](std::span batch, bool cached) -> bool { + Q_UNUSED(cached) + result.promptTokens += batch.size(); + m_timer->start(); + return !m_stopGenerating; + }; + + auto handleResponse = [this, &result](LLModel::Token token, std::string_view piece) -> bool { + Q_UNUSED(token) + result.responseTokens++; + m_timer->inc(); + result.response.append(piece.data(), piece.size()); + auto respStr = QString::fromUtf8(result.response); + emit responseChanged(removeLeadingWhitespace(respStr)); + return !m_stopGenerating; + }; + QElapsedTimer totalTime; totalTime.start(); m_timer->start(); - if (!docsContext.isEmpty()) { - auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode localdocs context without a response - m_llModelInfo.model->prompt(docsContext.toStdString(), "%1", promptFunc, responseFunc, - /*allowContextShift*/ true, m_ctx); - m_ctx.n_predict = old_n_predict; // now we are ready for a response + + try { + emit promptProcessing(); + m_llModelInfo.model->setThreadCount(mySettings->threadCount()); + m_stopGenerating = false; + m_llModelInfo.model->prompt(conversation, handlePrompt, handleResponse, ctx); + } catch (...) { + m_timer->stop(); + throw; } - m_llModelInfo.model->prompt(prompt.toStdString(), promptTemplate.toStdString(), promptFunc, responseFunc, - /*allowContextShift*/ true, m_ctx, false, - fakeReply.transform(std::mem_fn(&QString::toStdString))); -#if defined(DEBUG) - printf("\n"); - fflush(stdout); -#endif + m_timer->stop(); qint64 elapsed = totalTime.elapsed(); - std::string trimmed = trim_whitespace(m_response); - if (trimmed != m_trimmedResponse) { - m_trimmedResponse = trimmed; - emit responseChanged(QString::fromStdString(m_trimmedResponse)); - } - SuggestionMode mode = MySettings::globalInstance()->suggestionMode(); - if (mode == SuggestionMode::On || (!databaseResults.isEmpty() && mode == SuggestionMode::LocalDocsOnly)) + // trim trailing whitespace + auto respStr = QString::fromUtf8(result.response); + if (!respStr.isEmpty() && std::as_const(respStr).back().isSpace()) + emit responseChanged(respStr.trimmed()); + + bool doQuestions = false; + if (!m_isServer && chatItems) { + switch (mySettings->suggestionMode()) { + case SuggestionMode::On: doQuestions = true; break; + case SuggestionMode::LocalDocsOnly: doQuestions = usedLocalDocs; break; + case SuggestionMode::Off: ; + } + } + if (doQuestions) generateQuestions(elapsed); else emit responseStopped(elapsed); - m_pristineLoadedState = false; - return true; + return result; } void ChatLLM::setShouldBeLoaded(bool b) @@ -870,9 +1018,6 @@ void ChatLLM::unloadModel() else emit modelLoadingPercentageChanged(std::numeric_limits::min()); // small non-zero positive value - if (!m_markedForDeletion) - saveState(); - #if defined(DEBUG_MODEL_LOADING) qDebug() << "unloadModel" << m_llmThread.objectName() << m_llModelInfo.model.get(); #endif @@ -883,7 +1028,6 @@ void ChatLLM::unloadModel() } LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo)); - m_pristineLoadedState = false; } void ChatLLM::reloadModel() @@ -907,478 +1051,201 @@ void ChatLLM::reloadModel() void ChatLLM::generateName() { Q_ASSERT(isModelLoaded()); - if (!isModelLoaded()) + if (!isModelLoaded() || m_isServer) return; - const QString chatNamePrompt = MySettings::globalInstance()->modelChatNamePrompt(m_modelInfo); - if (chatNamePrompt.trimmed().isEmpty()) { + Q_ASSERT(m_chatModel); + + auto *mySettings = MySettings::globalInstance(); + + const QString chatNamePrompt = mySettings->modelChatNamePrompt(m_modelInfo); + if (isAllSpace(chatNamePrompt)) { qWarning() << "ChatLLM: not generating chat name because prompt is empty"; return; } - auto promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo); - auto promptFunc = std::bind(&ChatLLM::handleNamePrompt, this, std::placeholders::_1); - auto responseFunc = std::bind(&ChatLLM::handleNameResponse, this, std::placeholders::_1, std::placeholders::_2); - LLModel::PromptContext ctx = m_ctx; - m_llModelInfo.model->prompt(chatNamePrompt.toStdString(), promptTemplate.toStdString(), - promptFunc, responseFunc, /*allowContextShift*/ false, ctx); - std::string trimmed = trim_whitespace(m_nameResponse); - if (trimmed != m_nameResponse) { - m_nameResponse = trimmed; - emit generatedNameChanged(QString::fromStdString(m_nameResponse)); - } - m_pristineLoadedState = false; -} + QByteArray response; // raw UTF-8 -void ChatLLM::handleChatIdChanged(const QString &id) -{ - m_llmThread.setObjectName(id); -} + auto handleResponse = [this, &response](LLModel::Token token, std::string_view piece) -> bool { + Q_UNUSED(token) -bool ChatLLM::handleNamePrompt(int32_t token) -{ -#if defined(DEBUG) - qDebug() << "name prompt" << m_llmThread.objectName() << token; -#endif - Q_UNUSED(token); - return !m_stopGenerating; -} - -bool ChatLLM::handleNameResponse(int32_t token, const std::string &response) -{ -#if defined(DEBUG) - qDebug() << "name response" << m_llmThread.objectName() << token << response; -#endif - Q_UNUSED(token); + response.append(piece.data(), piece.size()); + QStringList words = QString::fromUtf8(response).simplified().split(u' ', Qt::SkipEmptyParts); + emit generatedNameChanged(words.join(u' ')); + return words.size() <= 3; + }; - m_nameResponse.append(response); - emit generatedNameChanged(QString::fromStdString(m_nameResponse)); - QString gen = QString::fromStdString(m_nameResponse).simplified(); - QStringList words = gen.split(' ', Qt::SkipEmptyParts); - return words.size() <= 3; + try { + m_llModelInfo.model->prompt( + applyJinjaTemplate(forkConversation(chatNamePrompt)), + [this](auto &&...) { return !m_stopGenerating; }, + handleResponse, + promptContextFromSettings(m_modelInfo) + ); + } catch (const std::exception &e) { + qWarning() << "ChatLLM failed to generate name:" << e.what(); + } } -bool ChatLLM::handleQuestionPrompt(int32_t token) +void ChatLLM::handleChatIdChanged(const QString &id) { -#if defined(DEBUG) - qDebug() << "question prompt" << m_llmThread.objectName() << token; -#endif - Q_UNUSED(token); - return !m_stopGenerating; + m_llmThread.setObjectName(id); } -bool ChatLLM::handleQuestionResponse(int32_t token, const std::string &response) +void ChatLLM::generateQuestions(qint64 elapsed) { -#if defined(DEBUG) - qDebug() << "question response" << m_llmThread.objectName() << token << response; -#endif - Q_UNUSED(token); - - // add token to buffer - m_questionResponse.append(response); - - // match whole question sentences // FIXME: This only works with response by the model in english which is not ideal for a multi-language // model. - static const QRegularExpression reQuestion(R"(\b(What|Where|How|Why|When|Who|Which|Whose|Whom)\b[^?]*\?)"); - - // extract all questions from response - int lastMatchEnd = -1; - for (const auto &match : reQuestion.globalMatch(m_questionResponse)) { - lastMatchEnd = match.capturedEnd(); - emit generatedQuestionFinished(match.captured()); - } - - // remove processed input from buffer - if (lastMatchEnd != -1) - m_questionResponse.erase(m_questionResponse.cbegin(), m_questionResponse.cbegin() + lastMatchEnd); - - return true; -} + // match whole question sentences + static const std::regex reQuestion(R"(\b(?:What|Where|How|Why|When|Who|Which|Whose|Whom)\b[^?]*\?)"); -void ChatLLM::generateQuestions(qint64 elapsed) -{ Q_ASSERT(isModelLoaded()); if (!isModelLoaded()) { emit responseStopped(elapsed); return; } - const std::string suggestedFollowUpPrompt = MySettings::globalInstance()->modelSuggestedFollowUpPrompt(m_modelInfo).toStdString(); - if (QString::fromStdString(suggestedFollowUpPrompt).trimmed().isEmpty()) { + auto *mySettings = MySettings::globalInstance(); + + QString suggestedFollowUpPrompt = mySettings->modelSuggestedFollowUpPrompt(m_modelInfo); + if (isAllSpace(suggestedFollowUpPrompt)) { + qWarning() << "ChatLLM: not generating follow-up questions because prompt is empty"; emit responseStopped(elapsed); return; } emit generatingQuestions(); - m_questionResponse.clear(); - auto promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo); - auto promptFunc = std::bind(&ChatLLM::handleQuestionPrompt, this, std::placeholders::_1); - auto responseFunc = std::bind(&ChatLLM::handleQuestionResponse, this, std::placeholders::_1, std::placeholders::_2); - LLModel::PromptContext ctx = m_ctx; - QElapsedTimer totalTime; - totalTime.start(); - m_llModelInfo.model->prompt(suggestedFollowUpPrompt, promptTemplate.toStdString(), promptFunc, responseFunc, - /*allowContextShift*/ false, ctx); - elapsed += totalTime.elapsed(); - emit responseStopped(elapsed); -} + std::string response; // raw UTF-8 -bool ChatLLM::handleSystemPrompt(int32_t token) -{ -#if defined(DEBUG) - qDebug() << "system prompt" << m_llmThread.objectName() << token << m_stopGenerating; -#endif - Q_UNUSED(token); - return !m_stopGenerating; -} + auto handleResponse = [this, &response](LLModel::Token token, std::string_view piece) -> bool { + Q_UNUSED(token) -bool ChatLLM::handleRestoreStateFromTextPrompt(int32_t token) -{ -#if defined(DEBUG) - qDebug() << "restore state from text prompt" << m_llmThread.objectName() << token << m_stopGenerating; -#endif - Q_UNUSED(token); - return !m_stopGenerating; -} - -// this function serialized the cached model state to disk. -// we want to also serialize n_ctx, and read it at load time. -bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV) -{ - if (version >= 2) { - if (m_llModelType == LLModelTypeV1::NONE) { - qWarning() << "ChatLLM ERROR: attempted to serialize a null model for chat id" << m_chat->id() - << "name" << m_chat->name(); - return false; - } + // add token to buffer + response.append(piece); - stream << m_llModelType; - switch (m_llModelType) { - case LLModelTypeV1::LLAMA: stream << LLAMA_INTERNAL_STATE_VERSION; break; - case LLModelTypeV1::API: stream << API_INTERNAL_STATE_VERSION; break; - default: stream << 0; // models removed in v2.5.0 + // extract all questions from response + ptrdiff_t lastMatchEnd = -1; + auto it = std::sregex_iterator(response.begin(), response.end(), reQuestion); + auto end = std::sregex_iterator(); + for (; it != end; ++it) { + auto pos = it->position(); + auto len = it->length(); + lastMatchEnd = pos + len; + emit generatedQuestionFinished(QString::fromUtf8(&response[pos], len)); } - } - stream << response(); - stream << generatedName(); - stream << m_promptResponseTokens; - if (!serializeKV) { -#if defined(DEBUG) - qDebug() << "serialize" << m_llmThread.objectName() << m_state.size(); -#endif - return stream.status() == QDataStream::Ok; - } + // remove processed input from buffer + if (lastMatchEnd != -1) + response.erase(0, lastMatchEnd); + return true; + }; - if (version < 4) { - int responseLogits = 0; - stream << responseLogits; - } - stream << m_ctx.n_past; - saveState(); - if (version >= 7) { - stream << m_stateContextLength; + QElapsedTimer totalTime; + totalTime.start(); + try { + m_llModelInfo.model->prompt( + applyJinjaTemplate(forkConversation(suggestedFollowUpPrompt)), + [this](auto &&...) { return !m_stopGenerating; }, + handleResponse, + promptContextFromSettings(m_modelInfo) + ); + } catch (const std::exception &e) { + qWarning() << "ChatLLM failed to generate follow-up questions:" << e.what(); } - stream << quint64(m_stateInputTokens.size()); - stream.writeRawData(reinterpret_cast(m_stateInputTokens.data()), - m_stateInputTokens.size() * sizeof(m_stateInputTokens[0])); - QByteArray compressed = qCompress(m_state); - stream << compressed; -#if defined(DEBUG) - qDebug() << "serialize" << m_llmThread.objectName() << m_state.size(); -#endif - return stream.status() == QDataStream::Ok; + elapsed += totalTime.elapsed(); + emit responseStopped(elapsed); } -bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV) +// this function serialized the cached model state to disk. +// we want to also serialize n_ctx, and read it at load time. +bool ChatLLM::serialize(QDataStream &stream, int version) { - if (version >= 2) { - int llModelType; - stream >> llModelType; - m_llModelType = (version >= 6 ? parseLLModelTypeV1 : parseLLModelTypeV0)(llModelType); - if (m_llModelType == LLModelTypeV1::NONE) { - qWarning().nospace() << "error loading chat id " << m_chat->id() << ": unrecognized model type: " - << llModelType; - return false; + if (version < 11) { + if (version >= 6) { + stream << false; // serializeKV } + if (version >= 2) { + if (m_llModelType == LLModelTypeV1::NONE) { + qWarning() << "ChatLLM ERROR: attempted to serialize a null model for chat id" << m_chat->id() + << "name" << m_chat->name(); + return false; + } + stream << m_llModelType; + stream << 0; // state version + } + { + QString dummy; + stream << dummy; // response + stream << dummy; // generated name + } + stream << quint32(0); // prompt + response tokens - /* note: prior to chat version 10, API models and chats with models removed in v2.5.0 only wrote this because of - * undefined behavior in Release builds */ - int internalStateVersion; // for future use - stream >> internalStateVersion; - } - QString response; - stream >> response; - m_response = response.toStdString(); - m_trimmedResponse = trim_whitespace(m_response); - QString nameResponse; - stream >> nameResponse; - m_nameResponse = nameResponse.toStdString(); - stream >> m_promptResponseTokens; - - // If we do not deserialize the KV or it is discarded, then we need to restore the state from the - // text only. This will be a costly operation, but the chat has to be restored from the text archive - // alone. - if (!deserializeKV || discardKV) { - m_restoreStateFromText = true; - m_pristineLoadedState = true; - } - - if (!deserializeKV) { -#if defined(DEBUG) - qDebug() << "deserialize" << m_llmThread.objectName(); -#endif - return stream.status() == QDataStream::Ok; - } - - if (version < 4) { - int responseLogits; - stream >> responseLogits; - } - - int32_t n_past; - stream >> n_past; - if (!discardKV) m_ctx.n_past = n_past; - - if (version >= 7) { - uint32_t n_ctx; - stream >> n_ctx; - if (!discardKV) m_stateContextLength = n_ctx; - } - - if (version < 9) { - quint64 logitsSize; - stream >> logitsSize; - stream.skipRawData(logitsSize * sizeof(float)); - } - - quint64 tokensSize; - stream >> tokensSize; - if (!discardKV) { - m_stateInputTokens.resize(tokensSize); - stream.readRawData(reinterpret_cast(m_stateInputTokens.data()), tokensSize * sizeof(m_stateInputTokens[0])); - } else { - stream.skipRawData(tokensSize * sizeof(m_stateInputTokens[0])); - } - - if (version >= 1) { - QByteArray compressed; - stream >> compressed; - if (!discardKV) - m_state = qUncompress(compressed); - } else { - if (!discardKV) { - stream >> m_state; - } else { - QByteArray state; - stream >> state; + if (version < 6) { // serialize binary state + if (version < 4) { + stream << 0; // responseLogits + } + stream << int32_t(0); // n_past + stream << quint64(0); // input token count + stream << QByteArray(); // KV cache state } } - -#if defined(DEBUG) - qDebug() << "deserialize" << m_llmThread.objectName(); -#endif return stream.status() == QDataStream::Ok; } -void ChatLLM::saveState() -{ - if (!isModelLoaded() || m_pristineLoadedState) - return; - - if (m_llModelType == LLModelTypeV1::API) { - m_state.clear(); - QDataStream stream(&m_state, QIODeviceBase::WriteOnly); - stream.setVersion(QDataStream::Qt_6_4); - ChatAPI *chatAPI = static_cast(m_llModelInfo.model.get()); - stream << chatAPI->context(); - return; - } - - const size_t stateSize = m_llModelInfo.model->stateSize(); - m_state.resize(stateSize); -#if defined(DEBUG) - qDebug() << "saveState" << m_llmThread.objectName() << "size:" << m_state.size(); -#endif - bool ok = m_llModelInfo.model->saveState({reinterpret_cast(m_state.data()), size_t(m_state.size())}, - m_stateInputTokens); - if (!ok) { - // FIXME(jared): how badly does this situation break GPT4All? - qWarning() << "ChatLLM failed to save LLModel state"; - m_state.clear(); - m_state.squeeze(); - m_stateContextLength = -1; - } - m_stateContextLength = m_llModelInfo.model->contextLength(); -} - -void ChatLLM::restoreState() +bool ChatLLM::deserialize(QDataStream &stream, int version) { - if (!isModelLoaded()) - return; - - if (m_llModelType == LLModelTypeV1::API) { - QDataStream stream(m_state); - stream.setVersion(QDataStream::Qt_6_4); - ChatAPI *chatAPI = static_cast(m_llModelInfo.model.get()); - QList context; - stream >> context; - chatAPI->setContext(context); - m_state.clear(); - m_state.squeeze(); - return; - } - -#if defined(DEBUG) - qDebug() << "restoreState" << m_llmThread.objectName() << "size:" << m_state.size(); -#endif - - if (m_state.isEmpty()) - return; + // discard all state since we are initialized from the ChatModel as of v11 + if (version < 11) { + union { int intval; quint32 u32; quint64 u64; }; + + bool deserializeKV = true; + if (version >= 6) + stream >> deserializeKV; + + if (version >= 2) { + stream >> intval; // model type + auto llModelType = (version >= 6 ? parseLLModelTypeV1 : parseLLModelTypeV0)(intval); + if (llModelType == LLModelTypeV1::NONE) { + qWarning().nospace() << "error loading chat id " << m_chat->id() << ": unrecognized model type: " + << intval; + return false; + } - if (m_llModelInfo.model->contextLength() != m_stateContextLength) { - qWarning() << "restoring state from text because of n_ctx mismatch (state" - << m_stateContextLength << "model" << m_llModelInfo.model->contextLength() << ")"; - m_restoreStateFromText = true; - } else { - size_t bytesRead = m_llModelInfo.model->restoreState( - {reinterpret_cast(m_state.data()), size_t(m_state.size())}, - m_stateInputTokens - ); - if (!bytesRead) { - qWarning() << "restoring state from text because of error reading state (mismatch or corrupt data)"; - m_restoreStateFromText = true; - } else { - m_processedSystemPrompt = true; - m_pristineLoadedState = true; + /* note: prior to chat version 10, API models and chats with models removed in v2.5.0 only wrote this because of + * undefined behavior in Release builds */ + stream >> intval; // state version + if (intval) { + qWarning().nospace() << "error loading chat id " << m_chat->id() << ": unrecognized internal state version"; + return false; + } } - } - - // free local state copy unless unload is pending - if (m_shouldBeLoaded) { - m_state.clear(); - m_state.squeeze(); - m_pristineLoadedState = false; - } -} -void ChatLLM::processSystemPrompt() -{ - Q_ASSERT(isModelLoaded()); - if (!isModelLoaded() || m_processedSystemPrompt) - return; - - const std::string systemPrompt = MySettings::globalInstance()->modelSystemPrompt(m_modelInfo).toStdString(); - - // Start with a whole new context - m_stopGenerating = false; - m_ctx = LLModel::PromptContext(); - - if (!QString::fromStdString(systemPrompt).trimmed().isEmpty()) { - auto promptFunc = std::bind(&ChatLLM::handleSystemPrompt, this, std::placeholders::_1); - - const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo); - const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo); - const float top_p = MySettings::globalInstance()->modelTopP(m_modelInfo); - const float min_p = MySettings::globalInstance()->modelMinP(m_modelInfo); - const float temp = MySettings::globalInstance()->modelTemperature(m_modelInfo); - const int32_t n_batch = MySettings::globalInstance()->modelPromptBatchSize(m_modelInfo); - const float repeat_penalty = MySettings::globalInstance()->modelRepeatPenalty(m_modelInfo); - const int32_t repeat_penalty_tokens = MySettings::globalInstance()->modelRepeatPenaltyTokens(m_modelInfo); - int n_threads = MySettings::globalInstance()->threadCount(); - m_ctx.n_predict = n_predict; - m_ctx.top_k = top_k; - m_ctx.top_p = top_p; - m_ctx.min_p = min_p; - m_ctx.temp = temp; - m_ctx.n_batch = n_batch; - m_ctx.repeat_penalty = repeat_penalty; - m_ctx.repeat_last_n = repeat_penalty_tokens; - m_llModelInfo.model->setThreadCount(n_threads); -#if defined(DEBUG) - printf("%s", qPrintable(QString::fromStdString(systemPrompt))); - fflush(stdout); -#endif - auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode system prompt without a response - // use "%1%2" and not "%1" to avoid implicit whitespace - m_llModelInfo.model->prompt(systemPrompt, "%1%2", promptFunc, nullptr, /*allowContextShift*/ true, m_ctx, true); - m_ctx.n_predict = old_n_predict; -#if defined(DEBUG) - printf("\n"); - fflush(stdout); -#endif - } - - m_processedSystemPrompt = m_stopGenerating == false; - m_pristineLoadedState = false; -} - -void ChatLLM::processRestoreStateFromText() -{ - Q_ASSERT(isModelLoaded()); - if (!isModelLoaded() || !m_restoreStateFromText || m_isServer) - return; - - processSystemPrompt(); - - m_restoringFromText = true; - emit restoringFromTextChanged(); - - m_stopGenerating = false; - - auto promptFunc = std::bind(&ChatLLM::handleRestoreStateFromTextPrompt, this, std::placeholders::_1); - - const QString promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo); - const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo); - const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo); - const float top_p = MySettings::globalInstance()->modelTopP(m_modelInfo); - const float min_p = MySettings::globalInstance()->modelMinP(m_modelInfo); - const float temp = MySettings::globalInstance()->modelTemperature(m_modelInfo); - const int32_t n_batch = MySettings::globalInstance()->modelPromptBatchSize(m_modelInfo); - const float repeat_penalty = MySettings::globalInstance()->modelRepeatPenalty(m_modelInfo); - const int32_t repeat_penalty_tokens = MySettings::globalInstance()->modelRepeatPenaltyTokens(m_modelInfo); - int n_threads = MySettings::globalInstance()->threadCount(); - m_ctx.n_predict = n_predict; - m_ctx.top_k = top_k; - m_ctx.top_p = top_p; - m_ctx.min_p = min_p; - m_ctx.temp = temp; - m_ctx.n_batch = n_batch; - m_ctx.repeat_penalty = repeat_penalty; - m_ctx.repeat_last_n = repeat_penalty_tokens; - m_llModelInfo.model->setThreadCount(n_threads); + { + QString dummy; + stream >> dummy; // response + stream >> dummy; // name response + } + stream >> u32; // prompt + response token count - Q_ASSERT(m_chatModel); - m_chatModel->lock(); - auto it = m_chatModel->begin(); - while (it < m_chatModel->end()) { - auto &prompt = *it++; - Q_ASSERT(prompt.name == "Prompt: "); - Q_ASSERT(it < m_chatModel->end()); - - auto &response = *it++; - Q_ASSERT(response.name == "Response: "); - - // FIXME(jared): this doesn't work well with the "regenerate" button since we are not incrementing - // m_promptTokens or m_promptResponseTokens - m_llModelInfo.model->prompt( - prompt.promptPlusAttachments().toStdString(), promptTemplate.toStdString(), - promptFunc, /*responseFunc*/ [](auto &&...) { return true; }, - /*allowContextShift*/ true, - m_ctx, - /*special*/ false, - response.value.toUtf8().constData() - ); + // We don't use the raw model state anymore. + if (deserializeKV) { + if (version < 4) { + stream >> u32; // response logits + } + stream >> u32; // n_past + if (version >= 7) { + stream >> u32; // n_ctx + } + if (version < 9) { + stream >> u64; // logits size + stream.skipRawData(u64 * sizeof(float)); // logits + } + stream >> u64; // token cache size + stream.skipRawData(u64 * sizeof(int)); // token cache + QByteArray dummy; + stream >> dummy; // state + } } - m_chatModel->unlock(); - - if (!m_stopGenerating) - m_restoreStateFromText = false; - - m_restoringFromText = false; - emit restoringFromTextChanged(); - - m_pristineLoadedState = false; + return stream.status() == QDataStream::Ok; } diff --git a/gpt4all-chat/src/chatllm.h b/gpt4all-chat/src/chatllm.h index 4b9936cb038c..c79ca0bd517e 100644 --- a/gpt4all-chat/src/chatllm.h +++ b/gpt4all-chat/src/chatllm.h @@ -13,20 +13,24 @@ #include #include #include +#include // IWYU pragma: keep +#include #include -#include +#include // IWYU pragma: keep #include #include #include #include #include +#include #include -#include +#include using namespace Qt::Literals::StringLiterals; class QDataStream; +struct ChatItem; // NOTE: values serialized to disk, do not change or reuse enum class LLModelTypeV0 { // chat versions 2-5 @@ -142,7 +146,6 @@ class Chat; class ChatLLM : public QObject { Q_OBJECT - Q_PROPERTY(bool restoringFromText READ restoringFromText NOTIFY restoringFromTextChanged) Q_PROPERTY(QString deviceBackend READ deviceBackend NOTIFY loadedModelInfoChanged) Q_PROPERTY(QString device READ device NOTIFY loadedModelInfoChanged) Q_PROPERTY(QString fallbackReason READ fallbackReason NOTIFY loadedModelInfoChanged) @@ -150,12 +153,14 @@ class ChatLLM : public QObject ChatLLM(Chat *parent, bool isServer = false); virtual ~ChatLLM(); - void destroy(); static void destroyStore(); + static std::optional checkJinjaTemplateError(const std::string &source); + + void destroy(); bool isModelLoaded() const; - void regenerateResponse(); - void resetResponse(); - void resetContext(); + void regenerateResponse(int index); + // used to implement edit functionality + std::optional popPrompt(int index); void stopGenerating() { m_stopGenerating = true; } @@ -165,13 +170,9 @@ class ChatLLM : public QObject void setForceUnloadModel(bool b) { m_forceUnloadModel = b; } void setMarkedForDeletion(bool b) { m_markedForDeletion = b; } - QString response(bool trim = true) const; - ModelInfo modelInfo() const; void setModelInfo(const ModelInfo &info); - bool restoringFromText() const { return m_restoringFromText; } - void acquireModel(); void resetModel(); @@ -196,13 +197,11 @@ class ChatLLM : public QObject return m_llModelInfo.fallbackReason.value_or(u""_s); } - QString generatedName() const { return QString::fromStdString(m_nameResponse); } - - bool serialize(QDataStream &stream, int version, bool serializeKV); - bool deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV); + bool serialize(QDataStream &stream, int version); + bool deserialize(QDataStream &stream, int version); public Q_SLOTS: - bool prompt(const QList &collectionList, const QString &prompt); + void prompt(const QStringList &enabledCollections); bool loadDefaultModel(); void trySwitchContextOfLoadedModel(const ModelInfo &modelInfo); bool loadModel(const ModelInfo &modelInfo); @@ -210,22 +209,19 @@ public Q_SLOTS: void unloadModel(); void reloadModel(); void generateName(); - void generateQuestions(qint64 elapsed); void handleChatIdChanged(const QString &id); void handleShouldBeLoadedChanged(); void handleThreadStarted(); void handleForceMetalChanged(bool forceMetal); void handleDeviceChanged(); - void processSystemPrompt(); - void processRestoreStateFromText(); Q_SIGNALS: - void restoringFromTextChanged(); void loadedModelInfoChanged(); void modelLoadingPercentageChanged(float); void modelLoadingError(const QString &error); void modelLoadingWarning(const QString &warning); void responseChanged(const QString &response); + void responseFailed(const QString &error); void promptProcessing(); void generatingQuestions(); void responseStopped(qint64 promptResponseMs); @@ -244,58 +240,50 @@ public Q_SLOTS: void modelInfoChanged(const ModelInfo &modelInfo); protected: - bool promptInternal(const QList &collectionList, const QString &prompt, const QString &promptTemplate, - int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty, - int32_t repeat_penalty_tokens, std::optional fakeReply = {}); - bool handlePrompt(int32_t token); - bool handleResponse(int32_t token, const std::string &response); - bool handleNamePrompt(int32_t token); - bool handleNameResponse(int32_t token, const std::string &response); - bool handleSystemPrompt(int32_t token); - bool handleSystemResponse(int32_t token, const std::string &response); - bool handleRestoreStateFromTextPrompt(int32_t token); - bool handleRestoreStateFromTextResponse(int32_t token, const std::string &response); - bool handleQuestionPrompt(int32_t token); - bool handleQuestionResponse(int32_t token, const std::string &response); - void saveState(); - void restoreState(); - -protected: - LLModel::PromptContext m_ctx; - quint32 m_promptTokens; - quint32 m_promptResponseTokens; + struct PromptResult { + QByteArray response; // raw UTF-8 + int promptTokens; // note: counts *entire* history, even if cached + int responseTokens; + }; + + struct ChatPromptResult : PromptResult { + QList databaseResults; + }; + + ChatPromptResult promptInternalChat(const QStringList &enabledCollections, const LLModel::PromptContext &ctx); + // passing a string_view directly skips templating and uses the raw string + PromptResult promptInternal(const std::variant, std::string_view> &prompt, + const LLModel::PromptContext &ctx, + bool usedLocalDocs); private: bool loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps); + std::vector forkConversation(const QString &prompt) const; + + // Applies the Jinja template. Query mode returns only the last message without special tokens. + // Returns a (# of messages, rendered prompt) pair. + std::string applyJinjaTemplate(std::span items) const; + + void generateQuestions(qint64 elapsed); + +protected: + QPointer m_chatModel; + +private: const Chat *m_chat; - std::string m_response; - std::string m_trimmedResponse; - std::string m_nameResponse; - QString m_questionResponse; LLModelInfo m_llModelInfo; LLModelTypeV1 m_llModelType = LLModelTypeV1::NONE; ModelInfo m_modelInfo; TokenTimer *m_timer; - QByteArray m_state; - std::vector m_stateInputTokens; - int32_t m_stateContextLength = -1; QThread m_llmThread; std::atomic m_stopGenerating; std::atomic m_shouldBeLoaded; - std::atomic m_restoringFromText; // status indication std::atomic m_forceUnloadModel; std::atomic m_markedForDeletion; bool m_isServer; bool m_forceMetal; bool m_reloadingToChangeVariant; - bool m_processedSystemPrompt; - bool m_restoreStateFromText; - // m_pristineLoadedState is set if saveSate is unnecessary, either because: - // - an unload was queued during LLModel::restoreState() - // - the chat will be restored from text and hasn't been interacted with yet - bool m_pristineLoadedState = false; - QPointer m_chatModel; }; #endif // CHATLLM_H diff --git a/gpt4all-chat/src/chatmodel.h b/gpt4all-chat/src/chatmodel.h index 5a5c63b2d6ee..7ce6b0e884ad 100644 --- a/gpt4all-chat/src/chatmodel.h +++ b/gpt4all-chat/src/chatmodel.h @@ -2,8 +2,11 @@ #define CHATMODEL_H #include "database.h" +#include "utils.h" #include "xlsxtomd.h" +#include + #include #include #include @@ -18,6 +21,15 @@ #include #include +#include +#include +#include +#include + +using namespace Qt::Literals::StringLiterals; +namespace ranges = std::ranges; + + struct PromptAttachment { Q_GADGET Q_PROPERTY(QUrl url MEMBER url) @@ -60,66 +72,145 @@ Q_DECLARE_METATYPE(PromptAttachment) struct ChatItem { Q_GADGET - Q_PROPERTY(QString name MEMBER name) + Q_PROPERTY(QString name MEMBER name ) Q_PROPERTY(QString value MEMBER value) - Q_PROPERTY(QString newResponse MEMBER newResponse) - Q_PROPERTY(bool currentResponse MEMBER currentResponse) - Q_PROPERTY(bool stopped MEMBER stopped) - Q_PROPERTY(bool thumbsUpState MEMBER thumbsUpState) - Q_PROPERTY(bool thumbsDownState MEMBER thumbsDownState) - Q_PROPERTY(QList sources MEMBER sources) - Q_PROPERTY(QList consolidatedSources MEMBER consolidatedSources) + + // prompts Q_PROPERTY(QList promptAttachments MEMBER promptAttachments) - Q_PROPERTY(QString promptPlusAttachments READ promptPlusAttachments) + Q_PROPERTY(QString bakedPrompt READ bakedPrompt ) + + // responses + Q_PROPERTY(bool isCurrentResponse MEMBER isCurrentResponse) + Q_PROPERTY(bool isError MEMBER isError ) + + // responses (DataLake) + Q_PROPERTY(QString newResponse MEMBER newResponse ) + Q_PROPERTY(bool stopped MEMBER stopped ) + Q_PROPERTY(bool thumbsUpState MEMBER thumbsUpState ) + Q_PROPERTY(bool thumbsDownState MEMBER thumbsDownState) public: - QString promptPlusAttachments() const + enum class Type { System, Prompt, Response }; + + // tags for constructing ChatItems + struct prompt_tag_t { explicit prompt_tag_t() = default; }; + static inline constexpr prompt_tag_t prompt_tag = prompt_tag_t(); + struct response_tag_t { explicit response_tag_t() = default; }; + static inline constexpr response_tag_t response_tag = response_tag_t(); + struct system_tag_t { explicit system_tag_t() = default; }; + static inline constexpr system_tag_t system_tag = system_tag_t(); + + // FIXME(jared): This should not be necessary. QML should see null or undefined if it + // tries to access something invalid. + ChatItem() = default; + + // NOTE: system messages are currently never stored in the model or serialized + ChatItem(system_tag_t, const QString &value) + : name(u"System: "_s), value(value) {} + + ChatItem(prompt_tag_t, const QString &value, const QList &attachments = {}) + : name(u"Prompt: "_s), value(value), promptAttachments(attachments) {} + + ChatItem(response_tag_t, bool isCurrentResponse = true) + : name(u"Response: "_s), isCurrentResponse(isCurrentResponse) {} + + Type type() const { - QStringList attachedContexts; - for (auto attached : promptAttachments) - attachedContexts << attached.processedContent(); - - QString promptPlus = value; - if (!attachedContexts.isEmpty()) - promptPlus = attachedContexts.join("\n\n") + "\n\n" + value; - return promptPlus; + if (name == u"System: "_s) + return Type::System; + if (name == u"Prompt: "_s) + return Type::Prompt; + if (name == u"Response: "_s) + return Type::Response; + throw std::invalid_argument(fmt::format("Chat item has unknown label: {:?}", name)); + } + + // used with version 0 Jinja templates + QString bakedPrompt() const + { + if (type() != Type::Prompt) + throw std::logic_error("bakedPrompt() called on non-prompt item"); + QStringList parts; + if (!sources.isEmpty()) { + parts << u"### Context:\n"_s; + for (auto &source : std::as_const(sources)) + parts << u"Collection: "_s << source.collection + << u"\nPath: "_s << source.path + << u"\nExcerpt: "_s << source.text << u"\n\n"_s; + } + for (auto &attached : std::as_const(promptAttachments)) + parts << attached.processedContent() << u"\n\n"_s; + parts << value; + return parts.join(QString()); } // TODO: Maybe we should include the model name here as well as timestamp? QString name; QString value; - QString newResponse; - QList sources; - QList consolidatedSources; + + // prompts + QList sources; + QList consolidatedSources; QList promptAttachments; - bool currentResponse = false; - bool stopped = false; - bool thumbsUpState = false; - bool thumbsDownState = false; + + // responses + bool isCurrentResponse = false; + bool isError = false; + + // responses (DataLake) + QString newResponse; + bool stopped = false; + bool thumbsUpState = false; + bool thumbsDownState = false; }; Q_DECLARE_METATYPE(ChatItem) -using ChatModelIterator = QList::const_iterator; +class ChatModelAccessor : public ranges::subrange::const_iterator> { +private: + using Super = ranges::subrange::const_iterator>; + +public: + template + ChatModelAccessor(QMutex &mutex, T &&...args) + : Super(std::forward(args)...), m_lock(&mutex) {} + +private: + QMutexLocker m_lock; +}; class ChatModel : public QAbstractListModel { Q_OBJECT Q_PROPERTY(int count READ count NOTIFY countChanged) + Q_PROPERTY(bool hasError READ hasError NOTIFY hasErrorChanged) public: - explicit ChatModel(QObject *parent = nullptr) : QAbstractListModel(parent) {} + explicit ChatModel(QObject *parent = nullptr) + : QAbstractListModel(parent) {} + // FIXME(jared): can't this start at Qt::UserRole (no +1)? enum Roles { NameRole = Qt::UserRole + 1, ValueRole, + + // prompts and responses + PeerRole, + + // prompts + PromptAttachmentsRole, + + // responses + // NOTE: sources are stored on the *prompts*, but in the model, they are only on the *responses*! + SourcesRole, + ConsolidatedSourcesRole, + IsCurrentResponseRole, + IsErrorRole, + + // responses (DataLake) NewResponseRole, - CurrentResponseRole, StoppedRole, ThumbsUpStateRole, ThumbsDownStateRole, - SourcesRole, - ConsolidatedSourcesRole, - PromptAttachmentsRole }; int rowCount(const QModelIndex &parent = QModelIndex()) const override @@ -129,34 +220,96 @@ class ChatModel : public QAbstractListModel return m_chatItems.size(); } + /* a "peer" is a bidirectional 1:1 link between a prompt and the response that would cite its LocalDocs + * sources. Return std::nullopt if there is none, which is possible for e.g. server chats. */ + auto getPeerUnlocked(QList::const_iterator item) const + -> std::optional::const_iterator> + { + switch (item->type()) { + using enum ChatItem::Type; + case Prompt: + { + auto peer = std::next(item); + if (peer < m_chatItems.cend() && peer->type() == Response) + return peer; + break; + } + case Response: + { + if (item > m_chatItems.cbegin()) { + if (auto peer = std::prev(item); peer->type() == Prompt) + return peer; + } + break; + } + default: + throw std::invalid_argument("getPeer() called on item that is not a prompt or response"); + } + return std::nullopt; + } + + auto getPeerUnlocked(int index) const -> std::optional + { + return getPeerUnlocked(m_chatItems.cbegin() + index) + .transform([&](auto &&i) { return i - m_chatItems.cbegin(); } ); + } + QVariant data(const QModelIndex &index, int role = Qt::DisplayRole) const override { QMutexLocker locker(&m_mutex); if (!index.isValid() || index.row() < 0 || index.row() >= m_chatItems.size()) return QVariant(); - const ChatItem &item = m_chatItems.at(index.row()); + auto item = m_chatItems.cbegin() + index.row(); switch (role) { case NameRole: - return item.name; + return item->name; case ValueRole: - return item.value; + return item->value; + case PeerRole: + switch (item->type()) { + using enum ChatItem::Type; + case Prompt: + case Response: + { + auto peer = getPeerUnlocked(item); + return peer ? QVariant::fromValue(**peer) : QVariant::fromValue(nullptr); + } + default: + return QVariant(); + } + case PromptAttachmentsRole: + return QVariant::fromValue(item->promptAttachments); + case SourcesRole: + { + QList data; + if (item->type() == ChatItem::Type::Response) { + if (auto prompt = getPeerUnlocked(item)) + data = (*prompt)->consolidatedSources; + } + return QVariant::fromValue(data); + } + case ConsolidatedSourcesRole: + { + QList data; + if (item->type() == ChatItem::Type::Response) { + if (auto prompt = getPeerUnlocked(item)) + data = (*prompt)->sources; + } + return QVariant::fromValue(data); + } + case IsCurrentResponseRole: + return item->isCurrentResponse; case NewResponseRole: - return item.newResponse; - case CurrentResponseRole: - return item.currentResponse; + return item->newResponse; case StoppedRole: - return item.stopped; + return item->stopped; case ThumbsUpStateRole: - return item.thumbsUpState; + return item->thumbsUpState; case ThumbsDownStateRole: - return item.thumbsDownState; - case SourcesRole: - return QVariant::fromValue(item.sources); - case ConsolidatedSourcesRole: - return QVariant::fromValue(item.consolidatedSources); - case PromptAttachmentsRole: - return QVariant::fromValue(item.promptAttachments); + return item->thumbsDownState; + case IsErrorRole: + return item->type() == ChatItem::Type::Response && item->isError; } return QVariant(); @@ -164,54 +317,126 @@ class ChatModel : public QAbstractListModel QHash roleNames() const override { - QHash roles; - roles[NameRole] = "name"; - roles[ValueRole] = "value"; - roles[NewResponseRole] = "newResponse"; - roles[CurrentResponseRole] = "currentResponse"; - roles[StoppedRole] = "stopped"; - roles[ThumbsUpStateRole] = "thumbsUpState"; - roles[ThumbsDownStateRole] = "thumbsDownState"; - roles[SourcesRole] = "sources"; - roles[ConsolidatedSourcesRole] = "consolidatedSources"; - roles[PromptAttachmentsRole] = "promptAttachments"; - return roles; + return { + { NameRole, "name" }, + { ValueRole, "value" }, + { PeerRole, "peer" }, + { PromptAttachmentsRole, "promptAttachments" }, + { SourcesRole, "sources" }, + { ConsolidatedSourcesRole, "consolidatedSources" }, + { IsCurrentResponseRole, "isCurrentResponse" }, + { IsErrorRole, "isError" }, + { NewResponseRole, "newResponse" }, + { StoppedRole, "stopped" }, + { ThumbsUpStateRole, "thumbsUpState" }, + { ThumbsDownStateRole, "thumbsDownState" }, + }; + } + + void appendPrompt(const QString &value, const QList &attachments = {}) + { + qsizetype count; + { + QMutexLocker locker(&m_mutex); + if (hasErrorUnlocked()) + throw std::logic_error("cannot append to a failed chat"); + count = m_chatItems.count(); + } + + beginInsertRows(QModelIndex(), count, count); + { + QMutexLocker locker(&m_mutex); + m_chatItems.emplace_back(ChatItem::prompt_tag, value, attachments); + } + endInsertRows(); + emit countChanged(); } - void appendPrompt(const QString &name, const QString &value, const QList &attachments) + void appendResponse(int promptIndex) { - ChatItem item; - item.name = name; - item.value = value; - item.promptAttachments << attachments; + qsizetype count; + { + QMutexLocker locker(&m_mutex); + if (hasErrorUnlocked()) + throw std::logic_error("cannot append to a failed chat"); + count = m_chatItems.count(); + } - m_mutex.lock(); - const int count = m_chatItems.count(); - m_mutex.unlock(); beginInsertRows(QModelIndex(), count, count); { QMutexLocker locker(&m_mutex); - m_chatItems.append(item); + if (promptIndex >= 0) { + if (promptIndex >= m_chatItems.size()) + throw std::out_of_range(fmt::format("index {} is out of range", promptIndex)); + auto &promptItem = m_chatItems[promptIndex]; + if (promptItem.type() != ChatItem::Type::Prompt) + throw std::invalid_argument(fmt::format("item at index {} is not a prompt", promptIndex)); + } + m_chatItems.emplace_back(ChatItem::response_tag, promptIndex); } endInsertRows(); emit countChanged(); + if (promptIndex >= 0) + emit dataChanged(createIndex(promptIndex, 0), createIndex(promptIndex, 0), {PeerRole}); } - void appendResponse(const QString &name) + // Used by Server to append a new conversation to the chat log. + void appendResponseWithHistory(std::span history) { + if (history.empty()) + throw std::invalid_argument("at least one message is required"); + m_mutex.lock(); - const int count = m_chatItems.count(); + qsizetype startIndex = m_chatItems.count(); m_mutex.unlock(); - ChatItem item; - item.name = name; - item.currentResponse = true; - beginInsertRows(QModelIndex(), count, count); + + qsizetype nNewItems = history.size() + 1; + qsizetype endIndex = startIndex + nNewItems; + beginInsertRows(QModelIndex(), startIndex, endIndex - 1 /*inclusive*/); + bool hadError; + int promptIndex; { QMutexLocker locker(&m_mutex); - m_chatItems.append(item); + hadError = hasErrorUnlocked(); + m_chatItems.reserve(m_chatItems.count() + nNewItems); + for (auto &item : history) + m_chatItems << item; + m_chatItems.emplace_back(ChatItem::response_tag); } endInsertRows(); emit countChanged(); + // Server can add messages when there is an error because each call is a new conversation + if (hadError) + emit hasErrorChanged(false); + if (promptIndex >= 0) + emit dataChanged(createIndex(promptIndex, 0), createIndex(promptIndex, 0), {PeerRole}); + } + + void truncate(qsizetype size) + { + qsizetype oldSize; + { + QMutexLocker locker(&m_mutex); + if (size >= (oldSize = m_chatItems.size())) + return; + if (size && m_chatItems.at(size - 1).type() != ChatItem::Type::Response) + throw std::invalid_argument( + fmt::format("chat model truncated to {} items would not end in a response", size) + ); + } + + bool oldHasError; + beginRemoveRows(QModelIndex(), size, oldSize - 1 /*inclusive*/); + { + QMutexLocker locker(&m_mutex); + oldHasError = hasErrorUnlocked(); + Q_ASSERT(size < m_chatItems.size()); + m_chatItems.resize(size); + } + endRemoveRows(); + emit countChanged(); + if (oldHasError) + emit hasErrorChanged(false); } Q_INVOKABLE void clear() @@ -221,13 +446,17 @@ class ChatModel : public QAbstractListModel if (m_chatItems.isEmpty()) return; } + bool oldHasError; beginResetModel(); { QMutexLocker locker(&m_mutex); + oldHasError = hasErrorUnlocked(); m_chatItems.clear(); } endResetModel(); emit countChanged(); + if (oldHasError) + emit hasErrorChanged(false); } Q_INVOKABLE ChatItem get(int index) @@ -245,13 +474,13 @@ class ChatModel : public QAbstractListModel if (index < 0 || index >= m_chatItems.size()) return; ChatItem &item = m_chatItems[index]; - if (item.currentResponse != b) { - item.currentResponse = b; + if (item.isCurrentResponse != b) { + item.isCurrentResponse = b; changed = true; } } - if (changed) emit dataChanged(createIndex(index, 0), createIndex(index, 0), {CurrentResponseRole}); + if (changed) emit dataChanged(createIndex(index, 0), createIndex(index, 0), {IsCurrentResponseRole}); } Q_INVOKABLE void updateStopped(int index, bool b) @@ -304,16 +533,23 @@ class ChatModel : public QAbstractListModel Q_INVOKABLE void updateSources(int index, const QList &sources) { + int responseIndex = -1; { QMutexLocker locker(&m_mutex); if (index < 0 || index >= m_chatItems.size()) return; - ChatItem &item = m_chatItems[index]; - item.sources = sources; - item.consolidatedSources = consolidateSources(sources); + auto promptItem = m_chatItems.begin() + index; + if (promptItem->type() != ChatItem::Type::Prompt) + throw std::invalid_argument(fmt::format("item at index {} is not a prompt", index)); + if (auto peer = getPeerUnlocked(promptItem)) + responseIndex = *peer - m_chatItems.cbegin(); + promptItem->sources = sources; + promptItem->consolidatedSources = consolidateSources(sources); + } + if (responseIndex >= 0) { + emit dataChanged(createIndex(responseIndex, 0), createIndex(responseIndex, 0), {SourcesRole}); + emit dataChanged(createIndex(responseIndex, 0), createIndex(responseIndex, 0), {ConsolidatedSourcesRole}); } - emit dataChanged(createIndex(index, 0), createIndex(index, 0), {SourcesRole}); - emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ConsolidatedSourcesRole}); } Q_INVOKABLE void updateThumbsUpState(int index, bool b) @@ -364,18 +600,56 @@ class ChatModel : public QAbstractListModel if (changed) emit dataChanged(createIndex(index, 0), createIndex(index, 0), {NewResponseRole}); } - int count() const { QMutexLocker locker(&m_mutex); return m_chatItems.size(); } + Q_INVOKABLE void setError(bool value = true) + { + qsizetype index; + { + QMutexLocker locker(&m_mutex); + + if (m_chatItems.isEmpty() || m_chatItems.cend()[-1].type() != ChatItem::Type::Response) + throw std::logic_error("can only set error on a chat that ends with a response"); + + index = m_chatItems.count() - 1; + auto &last = m_chatItems.back(); + if (last.isError == value) + return; // already set + last.isError = value; + } + emit dataChanged(createIndex(index, 0), createIndex(index, 0), {IsErrorRole}); + emit hasErrorChanged(value); + } + + qsizetype count() const { QMutexLocker locker(&m_mutex); return m_chatItems.size(); } + + ChatModelAccessor chatItems() const { return {m_mutex, std::as_const(m_chatItems)}; } - ChatModelIterator begin() const { return m_chatItems.begin(); } - ChatModelIterator end() const { return m_chatItems.end(); } - void lock() { m_mutex.lock(); } - void unlock() { m_mutex.unlock(); } + bool hasError() const { QMutexLocker locker(&m_mutex); return hasErrorUnlocked(); } bool serialize(QDataStream &stream, int version) const { QMutexLocker locker(&m_mutex); stream << int(m_chatItems.size()); - for (const auto &c : m_chatItems) { + for (auto itemIt = m_chatItems.cbegin(); itemIt < m_chatItems.cend(); ++itemIt) { + auto c = *itemIt; // NB: copies + if (version < 11) { + // move sources from their prompt to the next response + switch (c.type()) { + using enum ChatItem::Type; + case Prompt: + c.sources.clear(); + c.consolidatedSources.clear(); + break; + case Response: + // note: we drop sources for responseless prompts + if (auto peer = getPeerUnlocked(itemIt)) { + c.sources = (*peer)->sources; + c.consolidatedSources = (*peer)->consolidatedSources; + } + default: + ; + } + } + // FIXME: This 'id' should be eliminated the next time we bump serialization version. // (Jared) This was apparently never used. int id = 0; @@ -383,10 +657,12 @@ class ChatModel : public QAbstractListModel stream << c.name; stream << c.value; stream << c.newResponse; - stream << c.currentResponse; + stream << c.isCurrentResponse; stream << c.stopped; stream << c.thumbsUpState; stream << c.thumbsDownState; + if (version >= 11 && c.type() == ChatItem::Type::Response) + stream << c.isError; if (version >= 8) { stream << c.sources.size(); for (const ResultInfo &info : c.sources) { @@ -452,14 +728,24 @@ class ChatModel : public QAbstractListModel bool deserialize(QDataStream &stream, int version) { + clear(); // reset to known state + int size; stream >> size; + int lastPromptIndex = -1; + QList chatItems; for (int i = 0; i < size; ++i) { ChatItem c; // FIXME: see comment in serialization about id int id; stream >> id; stream >> c.name; + try { + c.type(); // check name + } catch (const std::exception &e) { + qWarning() << "ChatModel ERROR:" << e.what(); + return false; + } stream >> c.value; if (version < 10) { // This is deprecated and no longer used @@ -467,10 +753,12 @@ class ChatModel : public QAbstractListModel stream >> prompt; } stream >> c.newResponse; - stream >> c.currentResponse; + stream >> c.isCurrentResponse; stream >> c.stopped; stream >> c.thumbsUpState; stream >> c.thumbsDownState; + if (version >= 11 && c.type() == ChatItem::Type::Response) + stream >> c.isError; if (version >= 8) { qsizetype count; stream >> count; @@ -587,23 +875,53 @@ class ChatModel : public QAbstractListModel } c.promptAttachments = attachments; } - m_mutex.lock(); - const int count = m_chatItems.size(); - m_mutex.unlock(); - beginInsertRows(QModelIndex(), count, count); - { - QMutexLocker locker(&m_mutex); - m_chatItems.append(c); + + if (version < 11 && c.type() == ChatItem::Type::Response) { + // move sources from the response to their last prompt + if (lastPromptIndex >= 0) { + auto &prompt = chatItems[lastPromptIndex]; + prompt.sources = std::move(c.sources ); + prompt.consolidatedSources = std::move(c.consolidatedSources); + lastPromptIndex = -1; + } else { + // drop sources for promptless responses + c.sources.clear(); + c.consolidatedSources.clear(); + } } - endInsertRows(); + + chatItems << c; + if (c.type() == ChatItem::Type::Prompt) + lastPromptIndex = chatItems.size() - 1; } + + bool hasError; + beginInsertRows(QModelIndex(), 0, chatItems.size() - 1 /*inclusive*/); + { + QMutexLocker locker(&m_mutex); + m_chatItems = chatItems; + hasError = hasErrorUnlocked(); + } + endInsertRows(); emit countChanged(); + if (hasError) + emit hasErrorChanged(true); return stream.status() == QDataStream::Ok; } Q_SIGNALS: void countChanged(); void valueChanged(int index, const QString &value); + void hasErrorChanged(bool value); + +private: + bool hasErrorUnlocked() const + { + if (m_chatItems.isEmpty()) + return false; + auto &last = m_chatItems.back(); + return last.type() == ChatItem::Type::Response && last.isError; + } private: mutable QMutex m_mutex; diff --git a/gpt4all-chat/src/jinja_helpers.cpp b/gpt4all-chat/src/jinja_helpers.cpp new file mode 100644 index 000000000000..826dfb01e812 --- /dev/null +++ b/gpt4all-chat/src/jinja_helpers.cpp @@ -0,0 +1,111 @@ +#include "jinja_helpers.h" + +#include "utils.h" + +#include + +#include +#include + +#include +#include + +using namespace std::literals::string_view_literals; + + +JinjaResultInfo::~JinjaResultInfo() = default; + +const JinjaFieldMap JinjaResultInfo::s_fields = { + { "collection", [](auto &s) { return s.collection.toStdString(); } }, + { "path", [](auto &s) { return s.path .toStdString(); } }, + { "file", [](auto &s) { return s.file .toStdString(); } }, + { "title", [](auto &s) { return s.title .toStdString(); } }, + { "author", [](auto &s) { return s.author .toStdString(); } }, + { "date", [](auto &s) { return s.date .toStdString(); } }, + { "text", [](auto &s) { return s.text .toStdString(); } }, + { "page", [](auto &s) { return s.page; } }, + { "file_uri", [](auto &s) { return s.fileUri() .toStdString(); } }, +}; + +JinjaPromptAttachment::~JinjaPromptAttachment() = default; + +const JinjaFieldMap JinjaPromptAttachment::s_fields = { + { "url", [](auto &s) { return s.url.toString() .toStdString(); } }, + { "file", [](auto &s) { return s.file() .toStdString(); } }, + { "processed_content", [](auto &s) { return s.processedContent().toStdString(); } }, +}; + +std::vector JinjaMessage::GetKeys() const +{ + std::vector result; + auto &keys = this->keys(); + result.reserve(keys.size()); + result.assign(keys.begin(), keys.end()); + return result; +} + +auto JinjaMessage::keys() const -> const std::unordered_set & +{ + static const std::unordered_set baseKeys + { "role", "content" }; + static const std::unordered_set userKeys + { "role", "content", "sources", "prompt_attachments" }; + switch (m_item->type()) { + using enum ChatItem::Type; + case System: + case Response: + return baseKeys; + case Prompt: + return userKeys; + } + Q_UNREACHABLE(); +} + +bool operator==(const JinjaMessage &a, const JinjaMessage &b) +{ + if (a.m_item == b.m_item) + return true; + const auto &[ia, ib] = std::tie(*a.m_item, *b.m_item); + auto type = ia.type(); + if (type != ib.type() || ia.value != ib.value) + return false; + + switch (type) { + using enum ChatItem::Type; + case System: + case Response: + return true; + case Prompt: + return ia.sources == ib.sources && ia.promptAttachments == ib.promptAttachments; + } + Q_UNREACHABLE(); +} + +const JinjaFieldMap JinjaMessage::s_fields = { + { "role", [](auto &m) { + switch (m.item().type()) { + using enum ChatItem::Type; + case System: return "system"sv; + case Prompt: return "user"sv; + case Response: return "assistant"sv; + } + Q_UNREACHABLE(); + } }, + { "content", [](auto &m) { + if (m.version() == 0 && m.item().type() == ChatItem::Type::Prompt) + return m.item().bakedPrompt().toStdString(); + return m.item().value.toStdString(); + } }, + { "sources", [](auto &m) { + auto sources = m.item().sources | views::transform([](auto &r) { + return jinja2::GenericMap([map = std::make_shared(r)] { return map.get(); }); + }); + return jinja2::ValuesList(sources.begin(), sources.end()); + } }, + { "prompt_attachments", [](auto &m) { + auto attachments = m.item().promptAttachments | views::transform([](auto &pa) { + return jinja2::GenericMap([map = std::make_shared(pa)] { return map.get(); }); + }); + return jinja2::ValuesList(attachments.begin(), attachments.end()); + } }, +}; diff --git a/gpt4all-chat/src/jinja_helpers.h b/gpt4all-chat/src/jinja_helpers.h new file mode 100644 index 000000000000..a196b47f8fdf --- /dev/null +++ b/gpt4all-chat/src/jinja_helpers.h @@ -0,0 +1,116 @@ +#pragma once + +#include "chatmodel.h" +#include "database.h" + +#include + +#include +#include +#include +#include +#include +#include + +#include + +namespace views = std::views; + + +template +using JinjaFieldMap = std::unordered_map>; + +template +class JinjaComparable : public jinja2::IMapItemAccessor { +public: + JinjaComparable() = default; + + bool IsEqual(const jinja2::IComparable &other) const override; + +private: + Q_DISABLE_COPY_MOVE(JinjaComparable) +}; + +template +class JinjaHelper : public JinjaComparable { +public: + size_t GetSize() const override + { return Derived::s_fields.size(); } + + bool HasValue(const std::string &name) const override + { return Derived::s_fields.contains(name); } + + jinja2::Value GetValueByName(const std::string &name) const override; + + std::vector GetKeys() const override + { auto keys = views::elements<0>(Derived::s_fields); return { keys.begin(), keys.end() }; } +}; + +class JinjaResultInfo : public JinjaHelper { +public: + explicit JinjaResultInfo(const ResultInfo &source) noexcept + : m_source(&source) {} + + ~JinjaResultInfo() override; + + const ResultInfo &value() const { return *m_source; } + + friend bool operator==(const JinjaResultInfo &a, const JinjaResultInfo &b) + { return a.m_source == b.m_source || *a.m_source == *b.m_source; } + +private: + static const JinjaFieldMap s_fields; + const ResultInfo *m_source; + + friend class JinjaHelper; +}; + +class JinjaPromptAttachment : public JinjaHelper { +public: + explicit JinjaPromptAttachment(const PromptAttachment &attachment) noexcept + : m_attachment(&attachment) {} + + ~JinjaPromptAttachment() override; + + const PromptAttachment &value() const { return *m_attachment; } + + friend bool operator==(const JinjaPromptAttachment &a, const JinjaPromptAttachment &b) + { return a.m_attachment == b.m_attachment || *a.m_attachment == *b.m_attachment; } + +private: + static const JinjaFieldMap s_fields; + const PromptAttachment *m_attachment; + + friend class JinjaHelper; +}; + +class JinjaMessage : public JinjaHelper { +public: + explicit JinjaMessage(uint version, const ChatItem &item) noexcept + : m_version(version), m_item(&item) {} + + const JinjaMessage &value () const { return *this; } + uint version() const { return m_version; } + const ChatItem &item () const { return *m_item; } + + size_t GetSize() const override { return keys().size(); } + bool HasValue(const std::string &name) const override { return keys().contains(name); } + + jinja2::Value GetValueByName(const std::string &name) const override + { return HasValue(name) ? JinjaHelper::GetValueByName(name) : jinja2::EmptyValue(); } + + std::vector GetKeys() const override; + +private: + auto keys() const -> const std::unordered_set &; + +private: + static const JinjaFieldMap s_fields; + uint m_version; + const ChatItem *m_item; + + friend class JinjaHelper; + friend bool operator==(const JinjaMessage &a, const JinjaMessage &b); +}; + +#include "jinja_helpers.inl" diff --git a/gpt4all-chat/src/jinja_helpers.inl b/gpt4all-chat/src/jinja_helpers.inl new file mode 100644 index 000000000000..a8ae938b2e76 --- /dev/null +++ b/gpt4all-chat/src/jinja_helpers.inl @@ -0,0 +1,17 @@ +template +bool JinjaComparable::IsEqual(const jinja2::IComparable &other) const +{ + if (auto *omsg = dynamic_cast(&other)) + return *static_cast(this) == *omsg; + return false; +} + +template +jinja2::Value JinjaHelper::GetValueByName(const std::string &name) const +{ + if (auto it = D::s_fields.find(name); it != D::s_fields.end()) { + auto [_, func] = *it; + return func(static_cast(this)->value()); + } + return jinja2::EmptyValue(); +} diff --git a/gpt4all-chat/src/main.cpp b/gpt4all-chat/src/main.cpp index b07867c6f093..0fc23be3c961 100644 --- a/gpt4all-chat/src/main.cpp +++ b/gpt4all-chat/src/main.cpp @@ -12,12 +12,16 @@ #include #include +#include +#include #include #include +#include #include #include #include #include +#include #include #ifdef Q_OS_LINUX @@ -91,18 +95,22 @@ int main(int argc, char *argv[]) // Set the local and language translation before the qml engine has even been started. This will // use the default system locale unless the user has explicitly set it to use a different one. - MySettings::globalInstance()->setLanguageAndLocale(); + auto *mySettings = MySettings::globalInstance(); + mySettings->setLanguageAndLocale(); QQmlApplicationEngine engine; // Add a connection here from MySettings::languageAndLocaleChanged signal to a lambda slot where I can call // engine.uiLanguage property - QObject::connect(MySettings::globalInstance(), &MySettings::languageAndLocaleChanged, [&engine]() { + QObject::connect(mySettings, &MySettings::languageAndLocaleChanged, [&engine]() { engine.setUiLanguage(MySettings::globalInstance()->languageAndLocale()); }); - qmlRegisterSingletonInstance("mysettings", 1, 0, "MySettings", MySettings::globalInstance()); - qmlRegisterSingletonInstance("modellist", 1, 0, "ModelList", ModelList::globalInstance()); + auto *modelList = ModelList::globalInstance(); + QObject::connect(modelList, &ModelList::dataChanged, mySettings, &MySettings::onModelInfoChanged); + + qmlRegisterSingletonInstance("mysettings", 1, 0, "MySettings", mySettings); + qmlRegisterSingletonInstance("modellist", 1, 0, "ModelList", modelList); qmlRegisterSingletonInstance("chatlistmodel", 1, 0, "ChatListModel", ChatListModel::globalInstance()); qmlRegisterSingletonInstance("llm", 1, 0, "LLM", LLM::globalInstance()); qmlRegisterSingletonInstance("download", 1, 0, "Download", Download::globalInstance()); @@ -110,6 +118,11 @@ int main(int argc, char *argv[]) qmlRegisterSingletonInstance("localdocs", 1, 0, "LocalDocs", LocalDocs::globalInstance()); qmlRegisterUncreatableMetaObject(MySettingsEnums::staticMetaObject, "mysettingsenums", 1, 0, "MySettingsEnums", "Error: only enums"); + { + auto fixedFont = QFontDatabase::systemFont(QFontDatabase::FixedFont); + engine.rootContext()->setContextProperty("fixedFont", fixedFont); + } + const QUrl url(u"qrc:/gpt4all/main.qml"_s); QObject::connect(&engine, &QQmlApplicationEngine::objectCreated, diff --git a/gpt4all-chat/src/modellist.cpp b/gpt4all-chat/src/modellist.cpp index 7cd36094a3e7..4798c69bdd8e 100644 --- a/gpt4all-chat/src/modellist.cpp +++ b/gpt4all-chat/src/modellist.cpp @@ -316,26 +316,44 @@ void ModelInfo::setRepeatPenaltyTokens(int t) m_repeatPenaltyTokens = t; } -QString ModelInfo::promptTemplate() const -{ - return MySettings::globalInstance()->modelPromptTemplate(*this); +QVariant ModelInfo::defaultChatTemplate() const +{ + auto res = m_chatTemplate.or_else([this] -> std::optional { + if (!installed || isOnline) + return std::nullopt; + if (!m_modelChatTemplate) { + auto path = (dirpath + filename()).toUtf8(); + auto res = LLModel::Implementation::chatTemplate(path.constData()); + if (res) { + m_modelChatTemplate = QString::fromStdString(*res); + } else { + qWarning().nospace() << "failed to get chat template for " << filename() << ": " << res.error().c_str(); + m_modelChatTemplate = QString(); // do not retry + } + } + if (m_modelChatTemplate->isNull()) + return std::nullopt; + return m_modelChatTemplate; + }); + + if (res) + return std::move(*res); + return QVariant::fromValue(nullptr); } -void ModelInfo::setPromptTemplate(const QString &t) +auto ModelInfo::chatTemplate() const -> UpgradeableSetting { - if (shouldSaveMetadata()) MySettings::globalInstance()->setModelPromptTemplate(*this, t, true /*force*/); - m_promptTemplate = t; + return MySettings::globalInstance()->modelChatTemplate(*this); } -QString ModelInfo::systemPrompt() const +QString ModelInfo::defaultSystemMessage() const { - return MySettings::globalInstance()->modelSystemPrompt(*this); + return m_systemMessage; } -void ModelInfo::setSystemPrompt(const QString &p) +auto ModelInfo::systemMessage() const -> UpgradeableSetting { - if (shouldSaveMetadata()) MySettings::globalInstance()->setModelSystemPrompt(*this, p, true /*force*/); - m_systemPrompt = p; + return MySettings::globalInstance()->modelSystemMessage(*this); } QString ModelInfo::chatNamePrompt() const @@ -360,39 +378,41 @@ void ModelInfo::setSuggestedFollowUpPrompt(const QString &p) m_suggestedFollowUpPrompt = p; } +// FIXME(jared): this should not be used for model settings that have meaningful defaults, such as temperature bool ModelInfo::shouldSaveMetadata() const { return installed && (isClone() || isDiscovered() || description() == "" /*indicates sideloaded*/); } -QVariantMap ModelInfo::getFields() const -{ - return { - { "filename", m_filename }, - { "description", m_description }, - { "url", m_url }, - { "quant", m_quant }, - { "type", m_type }, - { "isClone", m_isClone }, - { "isDiscovered", m_isDiscovered }, - { "likes", m_likes }, - { "downloads", m_downloads }, - { "recency", m_recency }, - { "temperature", m_temperature }, - { "topP", m_topP }, - { "minP", m_minP }, - { "topK", m_topK }, - { "maxLength", m_maxLength }, - { "promptBatchSize", m_promptBatchSize }, - { "contextLength", m_contextLength }, - { "gpuLayers", m_gpuLayers }, - { "repeatPenalty", m_repeatPenalty }, - { "repeatPenaltyTokens", m_repeatPenaltyTokens }, - { "promptTemplate", m_promptTemplate }, - { "systemPrompt", m_systemPrompt }, - { "chatNamePrompt", m_chatNamePrompt }, - { "suggestedFollowUpPrompt", m_suggestedFollowUpPrompt }, +QVariant ModelInfo::getField(QLatin1StringView name) const +{ + static const std::unordered_map s_fields = { + { "filename"_L1, [](auto &i) -> QVariant { return i.m_filename; } }, + { "description"_L1, [](auto &i) -> QVariant { return i.m_description; } }, + { "url"_L1, [](auto &i) -> QVariant { return i.m_url; } }, + { "quant"_L1, [](auto &i) -> QVariant { return i.m_quant; } }, + { "type"_L1, [](auto &i) -> QVariant { return i.m_type; } }, + { "isClone"_L1, [](auto &i) -> QVariant { return i.m_isClone; } }, + { "isDiscovered"_L1, [](auto &i) -> QVariant { return i.m_isDiscovered; } }, + { "likes"_L1, [](auto &i) -> QVariant { return i.m_likes; } }, + { "downloads"_L1, [](auto &i) -> QVariant { return i.m_downloads; } }, + { "recency"_L1, [](auto &i) -> QVariant { return i.m_recency; } }, + { "temperature"_L1, [](auto &i) -> QVariant { return i.m_temperature; } }, + { "topP"_L1, [](auto &i) -> QVariant { return i.m_topP; } }, + { "minP"_L1, [](auto &i) -> QVariant { return i.m_minP; } }, + { "topK"_L1, [](auto &i) -> QVariant { return i.m_topK; } }, + { "maxLength"_L1, [](auto &i) -> QVariant { return i.m_maxLength; } }, + { "promptBatchSize"_L1, [](auto &i) -> QVariant { return i.m_promptBatchSize; } }, + { "contextLength"_L1, [](auto &i) -> QVariant { return i.m_contextLength; } }, + { "gpuLayers"_L1, [](auto &i) -> QVariant { return i.m_gpuLayers; } }, + { "repeatPenalty"_L1, [](auto &i) -> QVariant { return i.m_repeatPenalty; } }, + { "repeatPenaltyTokens"_L1, [](auto &i) -> QVariant { return i.m_repeatPenaltyTokens; } }, + { "chatTemplate"_L1, [](auto &i) -> QVariant { return i.defaultChatTemplate(); } }, + { "systemMessage"_L1, [](auto &i) -> QVariant { return i.m_systemMessage; } }, + { "chatNamePrompt"_L1, [](auto &i) -> QVariant { return i.m_chatNamePrompt; } }, + { "suggestedFollowUpPrompt"_L1, [](auto &i) -> QVariant { return i.m_suggestedFollowUpPrompt; } }, }; + return s_fields.at(name)(*this); } InstalledModels::InstalledModels(QObject *parent, bool selectable) @@ -491,31 +511,48 @@ ModelList::ModelList() m_selectableModels->setSourceModel(this); m_downloadableModels->setSourceModel(this); - connect(MySettings::globalInstance(), &MySettings::modelPathChanged, this, &ModelList::updateModelsFromDirectory); - connect(MySettings::globalInstance(), &MySettings::modelPathChanged, this, &ModelList::updateModelsFromJson); - connect(MySettings::globalInstance(), &MySettings::modelPathChanged, this, &ModelList::updateModelsFromSettings); - connect(MySettings::globalInstance(), &MySettings::nameChanged, this, &ModelList::updateDataForSettings); - connect(MySettings::globalInstance(), &MySettings::temperatureChanged, this, &ModelList::updateDataForSettings); - connect(MySettings::globalInstance(), &MySettings::topPChanged, this, &ModelList::updateDataForSettings); - connect(MySettings::globalInstance(), &MySettings::minPChanged, this, &ModelList::updateDataForSettings); - connect(MySettings::globalInstance(), &MySettings::topKChanged, this, &ModelList::updateDataForSettings); - connect(MySettings::globalInstance(), &MySettings::maxLengthChanged, this, &ModelList::updateDataForSettings); - connect(MySettings::globalInstance(), &MySettings::promptBatchSizeChanged, this, &ModelList::updateDataForSettings); - connect(MySettings::globalInstance(), &MySettings::contextLengthChanged, this, &ModelList::updateDataForSettings); - connect(MySettings::globalInstance(), &MySettings::gpuLayersChanged, this, &ModelList::updateDataForSettings); - connect(MySettings::globalInstance(), &MySettings::repeatPenaltyChanged, this, &ModelList::updateDataForSettings); - connect(MySettings::globalInstance(), &MySettings::repeatPenaltyTokensChanged, this, &ModelList::updateDataForSettings); - connect(MySettings::globalInstance(), &MySettings::promptTemplateChanged, this, &ModelList::updateDataForSettings); - connect(MySettings::globalInstance(), &MySettings::systemPromptChanged, this, &ModelList::updateDataForSettings); + auto *mySettings = MySettings::globalInstance(); + connect(mySettings, &MySettings::nameChanged, this, &ModelList::updateDataForSettings ); + connect(mySettings, &MySettings::temperatureChanged, this, &ModelList::updateDataForSettings ); + connect(mySettings, &MySettings::topPChanged, this, &ModelList::updateDataForSettings ); + connect(mySettings, &MySettings::minPChanged, this, &ModelList::updateDataForSettings ); + connect(mySettings, &MySettings::topKChanged, this, &ModelList::updateDataForSettings ); + connect(mySettings, &MySettings::maxLengthChanged, this, &ModelList::updateDataForSettings ); + connect(mySettings, &MySettings::promptBatchSizeChanged, this, &ModelList::updateDataForSettings ); + connect(mySettings, &MySettings::contextLengthChanged, this, &ModelList::updateDataForSettings ); + connect(mySettings, &MySettings::gpuLayersChanged, this, &ModelList::updateDataForSettings ); + connect(mySettings, &MySettings::repeatPenaltyChanged, this, &ModelList::updateDataForSettings ); + connect(mySettings, &MySettings::repeatPenaltyTokensChanged, this, &ModelList::updateDataForSettings ); + connect(mySettings, &MySettings::chatTemplateChanged, this, &ModelList::maybeUpdateDataForSettings); + connect(mySettings, &MySettings::systemMessageChanged, this, &ModelList::maybeUpdateDataForSettings); + + connect(this, &ModelList::dataChanged, this, &ModelList::onDataChanged); + connect(&m_networkManager, &QNetworkAccessManager::sslErrors, this, &ModelList::handleSslErrors); updateModelsFromJson(); updateModelsFromSettings(); updateModelsFromDirectory(); + connect(mySettings, &MySettings::modelPathChanged, this, &ModelList::updateModelsFromDirectory); + connect(mySettings, &MySettings::modelPathChanged, this, &ModelList::updateModelsFromJson ); + connect(mySettings, &MySettings::modelPathChanged, this, &ModelList::updateModelsFromSettings ); + QCoreApplication::instance()->installEventFilter(this); } +// an easier way to listen for model info and setting changes +void ModelList::onDataChanged(const QModelIndex &topLeft, const QModelIndex &bottomRight, const QList &roles) +{ + Q_UNUSED(roles) + for (int row = topLeft.row(); row <= bottomRight.row(); row++) { + auto index = topLeft.siblingAtRow(row); + auto id = index.data(ModelList::IdRole).toString(); + if (auto info = modelInfo(id); !info.id().isNull()) + emit modelInfoChanged(info); + } +} + QString ModelList::compatibleModelNameHash(QUrl baseUrl, QString modelName) { QCryptographicHash sha256(QCryptographicHash::Sha256); sha256.addData((baseUrl.toString() + "_" + modelName).toUtf8()); @@ -776,10 +813,10 @@ QVariant ModelList::dataInternal(const ModelInfo *info, int role) const return info->repeatPenalty(); case RepeatPenaltyTokensRole: return info->repeatPenaltyTokens(); - case PromptTemplateRole: - return info->promptTemplate(); - case SystemPromptRole: - return info->systemPrompt(); + case ChatTemplateRole: + return QVariant::fromValue(info->chatTemplate()); + case SystemMessageRole: + return QVariant::fromValue(info->systemMessage()); case ChatNamePromptRole: return info->chatNamePrompt(); case SuggestedFollowUpPromptRole: @@ -952,10 +989,10 @@ void ModelList::updateData(const QString &id, const QVector info->setRepeatPenalty(value.toDouble()); break; case RepeatPenaltyTokensRole: info->setRepeatPenaltyTokens(value.toInt()); break; - case PromptTemplateRole: - info->setPromptTemplate(value.toString()); break; - case SystemPromptRole: - info->setSystemPrompt(value.toString()); break; + case ChatTemplateRole: + info->m_chatTemplate = value.toString(); break; + case SystemMessageRole: + info->m_systemMessage = value.toString(); break; case ChatNamePromptRole: info->setChatNamePrompt(value.toString()); break; case SuggestedFollowUpPromptRole: @@ -1056,11 +1093,11 @@ ModelInfo ModelList::modelInfo(const QString &id) const return *m_modelMap.value(id); } -ModelInfo ModelList::modelInfoByFilename(const QString &filename) const +ModelInfo ModelList::modelInfoByFilename(const QString &filename, bool allowClone) const { QMutexLocker locker(&m_mutex); for (ModelInfo *info : m_models) - if (info->filename() == filename) + if (info->filename() == filename && (allowClone || !info->isClone())) return *info; return ModelInfo(); } @@ -1080,6 +1117,20 @@ QString ModelList::clone(const ModelInfo &model) const QString id = Network::globalInstance()->generateUniqueId(); addModel(id); + QString chatTemplate, systemMessage; + if (auto tmpl = model.chatTemplate().asModern()) { + chatTemplate = *tmpl; + } else { + qWarning("ModelList Warning: attempted to clone model with legacy chat template"); + return {}; + } + if (auto msg = model.systemMessage().asModern()) { + systemMessage = *msg; + } else { + qWarning("ModelList Warning: attempted to clone model with legacy system message"); + return {}; + } + QVector> data { { ModelList::InstalledRole, model.installed }, { ModelList::IsCloneRole, true }, @@ -1099,8 +1150,8 @@ QString ModelList::clone(const ModelInfo &model) { ModelList::GpuLayersRole, model.gpuLayers() }, { ModelList::RepeatPenaltyRole, model.repeatPenalty() }, { ModelList::RepeatPenaltyTokensRole, model.repeatPenaltyTokens() }, - { ModelList::PromptTemplateRole, model.promptTemplate() }, - { ModelList::SystemPromptRole, model.systemPrompt() }, + { ModelList::ChatTemplateRole, chatTemplate }, + { ModelList::SystemMessageRole, systemMessage }, { ModelList::ChatNamePromptRole, model.chatNamePrompt() }, { ModelList::SuggestedFollowUpPromptRole, model.suggestedFollowUpPrompt() }, }; @@ -1125,21 +1176,23 @@ void ModelList::removeInstalled(const ModelInfo &model) removeInternal(model); } +int ModelList::indexByModelId(const QString &id) const +{ + QMutexLocker locker(&m_mutex); + if (auto it = m_modelMap.find(id); it != m_modelMap.cend()) + return m_models.indexOf(*it); + return -1; +} + void ModelList::removeInternal(const ModelInfo &model) { - const bool hasModel = contains(model.id()); - Q_ASSERT(hasModel); - if (!hasModel) { + int indexOfModel = indexByModelId(model.id()); + Q_ASSERT(indexOfModel != -1); + if (indexOfModel == -1) { qWarning() << "ERROR: model list does not contain" << model.id(); return; } - int indexOfModel = 0; - { - QMutexLocker locker(&m_mutex); - ModelInfo *info = m_modelMap.value(model.id()); - indexOfModel = m_models.indexOf(info); - } beginRemoveRows(QModelIndex(), indexOfModel, indexOfModel); { QMutexLocker locker(&m_mutex); @@ -1314,8 +1367,6 @@ void ModelList::processModelDirectory(const QString &path) // The description is hard-coded into "GPT4All.ini" due to performance issue. // If the description goes to be dynamic from its .rmodel file, it will get high I/O usage while using the ModelList. data.append({ DescriptionRole, description }); - // Prompt template should be clear while using ChatML format which is using in most of OpenAI-Compatible API server. - data.append({ PromptTemplateRole, "%1" }); } updateData(id, data); } @@ -1451,9 +1502,20 @@ void ModelList::handleSslErrors(QNetworkReply *reply, const QList &er qWarning() << "ERROR: Received ssl error:" << e.errorString() << "for" << url; } +void ModelList::maybeUpdateDataForSettings(const ModelInfo &info, bool fromInfo) +{ + // ignore updates that were *because* of a dataChanged - would cause a circular dependency + int idx; + if (!fromInfo && (idx = indexByModelId(info.id())) != -1) { + emit dataChanged(index(idx, 0), index(idx, 0)); + emit selectableModelListChanged(); + } +} + void ModelList::updateDataForSettings() { emit dataChanged(index(0, 0), index(m_models.size() - 1, 0)); + emit selectableModelListChanged(); } void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save) @@ -1560,10 +1622,10 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save) data.append({ ModelList::RepeatPenaltyRole, obj["repeatPenalty"].toDouble() }); if (obj.contains("repeatPenaltyTokens")) data.append({ ModelList::RepeatPenaltyTokensRole, obj["repeatPenaltyTokens"].toInt() }); - if (obj.contains("promptTemplate")) - data.append({ ModelList::PromptTemplateRole, obj["promptTemplate"].toString() }); - if (obj.contains("systemPrompt")) - data.append({ ModelList::SystemPromptRole, obj["systemPrompt"].toString() }); + if (auto it = obj.find("chatTemplate"_L1); it != obj.end()) + data.append({ ModelList::ChatTemplateRole, it->toString() }); + if (auto it = obj.find("systemMessage"_L1); it != obj.end()) + data.append({ ModelList::SystemMessageRole, it->toString() }); updateData(id, data); } @@ -1755,6 +1817,9 @@ void ModelList::updateDiscoveredInstalled(const ModelInfo &info) updateData(info.id(), data); } +// FIXME(jared): This should only contain fields without reasonable defaults such as name, description, and URL. +// For other settings, there is no authoritative value and we should load the setting lazily like we do +// for any other override. void ModelList::updateModelsFromSettings() { QSettings settings; @@ -1769,12 +1834,27 @@ void ModelList::updateModelsFromSettings() // If we can't find the corresponding file, then ignore it as this reflects a stale model. // The file could have been deleted manually by the user for instance or temporarily renamed. - if (!settings.contains(g + "/filename") || !modelExists(settings.value(g + "/filename").toString())) - continue; + QString filename; + { + auto value = settings.value(u"%1/filename"_s.arg(g)); + if (!value.isValid() || !modelExists(filename = value.toString())) + continue; + } + + QVector> data; + + // load data from base model + // FIXME(jared): how does "Restore Defaults" work for other settings of clones which we don't do this for? + if (auto base = modelInfoByFilename(filename, /*allowClone*/ false); !base.id().isNull()) { + if (auto tmpl = base.m_chatTemplate) + data.append({ ModelList::ChatTemplateRole, *tmpl }); + if (auto msg = base.m_systemMessage; !msg.isNull()) + data.append({ ModelList::SystemMessageRole, msg }); + } addModel(id); - QVector> data; + // load data from settings if (settings.contains(g + "/name")) { const QString name = settings.value(g + "/name").toString(); data.append({ ModelList::NameRole, name }); @@ -1859,14 +1939,6 @@ void ModelList::updateModelsFromSettings() const int repeatPenaltyTokens = settings.value(g + "/repeatPenaltyTokens").toInt(); data.append({ ModelList::RepeatPenaltyTokensRole, repeatPenaltyTokens }); } - if (settings.contains(g + "/promptTemplate")) { - const QString promptTemplate = settings.value(g + "/promptTemplate").toString(); - data.append({ ModelList::PromptTemplateRole, promptTemplate }); - } - if (settings.contains(g + "/systemPrompt")) { - const QString systemPrompt = settings.value(g + "/systemPrompt").toString(); - data.append({ ModelList::SystemPromptRole, systemPrompt }); - } if (settings.contains(g + "/chatNamePrompt")) { const QString chatNamePrompt = settings.value(g + "/chatNamePrompt").toString(); data.append({ ModelList::ChatNamePromptRole, chatNamePrompt }); diff --git a/gpt4all-chat/src/modellist.h b/gpt4all-chat/src/modellist.h index 6123dde81b6c..121a8433cf62 100644 --- a/gpt4all-chat/src/modellist.h +++ b/gpt4all-chat/src/modellist.h @@ -5,12 +5,14 @@ #include #include #include +#include #include #include #include #include #include #include +#include #include #include #include @@ -19,11 +21,53 @@ #include #include +#include #include using namespace Qt::Literals::StringLiterals; +class UpgradeableSetting { + Q_GADGET + QML_ANONYMOUS + + // NOTE: Unset implies there is neither a value nor a default + enum class State { Unset, Legacy, Modern }; + + Q_PROPERTY(bool isSet READ isSet ) + Q_PROPERTY(bool isLegacy READ isLegacy) + Q_PROPERTY(bool isModern READ isModern) + Q_PROPERTY(QVariant value READ value) // string or null + +public: + struct legacy_tag_t { explicit legacy_tag_t() = default; }; + static inline constexpr legacy_tag_t legacy_tag = legacy_tag_t(); + + UpgradeableSetting() : m_state(State::Unset ) {} + UpgradeableSetting(legacy_tag_t, QString value): m_state(State::Legacy), m_value(std::move(value)) {} + UpgradeableSetting( QString value): m_state(State::Modern), m_value(std::move(value)) {} + + bool isSet () const { return m_state != State::Unset; } + bool isLegacy() const { return m_state == State::Legacy; } + bool isModern() const { return m_state == State::Modern; } + QVariant value () const { return m_state == State::Unset ? QVariant::fromValue(nullptr) : m_value; } + + friend bool operator==(const UpgradeableSetting &a, const UpgradeableSetting &b) + { return a.m_state == b.m_state && (a.m_state == State::Unset || a.m_value == b.m_value); } + + // returns std::nullopt if there is a legacy template or it is not set + std::optional asModern() const + { + if (m_state == State::Modern) + return m_value; + return std::nullopt; + } + +private: + State m_state; + QString m_value; +}; + struct ModelInfo { Q_GADGET Q_PROPERTY(QString id READ id WRITE setId) @@ -69,8 +113,11 @@ struct ModelInfo { Q_PROPERTY(int maxGpuLayers READ maxGpuLayers) Q_PROPERTY(double repeatPenalty READ repeatPenalty WRITE setRepeatPenalty) Q_PROPERTY(int repeatPenaltyTokens READ repeatPenaltyTokens WRITE setRepeatPenaltyTokens) - Q_PROPERTY(QString promptTemplate READ promptTemplate WRITE setPromptTemplate) - Q_PROPERTY(QString systemPrompt READ systemPrompt WRITE setSystemPrompt) + // user-defined chat template and system message must be written through settings because of their legacy compat + Q_PROPERTY(QVariant defaultChatTemplate READ defaultChatTemplate ) + Q_PROPERTY(UpgradeableSetting chatTemplate READ chatTemplate ) + Q_PROPERTY(QString defaultSystemMessage READ defaultSystemMessage) + Q_PROPERTY(UpgradeableSetting systemMessage READ systemMessage ) Q_PROPERTY(QString chatNamePrompt READ chatNamePrompt WRITE setChatNamePrompt) Q_PROPERTY(QString suggestedFollowUpPrompt READ suggestedFollowUpPrompt WRITE setSuggestedFollowUpPrompt) Q_PROPERTY(int likes READ likes WRITE setLikes) @@ -178,19 +225,22 @@ struct ModelInfo { void setRepeatPenalty(double p); int repeatPenaltyTokens() const; void setRepeatPenaltyTokens(int t); - QString promptTemplate() const; - void setPromptTemplate(const QString &t); - QString systemPrompt() const; - void setSystemPrompt(const QString &p); + QVariant defaultChatTemplate() const; + UpgradeableSetting chatTemplate() const; + QString defaultSystemMessage() const; + UpgradeableSetting systemMessage() const; QString chatNamePrompt() const; void setChatNamePrompt(const QString &p); QString suggestedFollowUpPrompt() const; void setSuggestedFollowUpPrompt(const QString &p); + // Some metadata must be saved to settings because it does not have a meaningful default from some other source. + // This is useful for fields such as name, description, and URL. + // It is true for any models that have not been installed from models.json. bool shouldSaveMetadata() const; private: - QVariantMap getFields() const; + QVariant getField(QLatin1StringView name) const; QString m_id; QString m_name; @@ -216,11 +266,13 @@ struct ModelInfo { mutable int m_maxGpuLayers = -1; double m_repeatPenalty = 1.18; int m_repeatPenaltyTokens = 64; - QString m_promptTemplate = "### Human:\n%1\n\n### Assistant:\n"; - QString m_systemPrompt = "### System:\nYou are an AI assistant who gives a quality response to whatever humans ask of you.\n\n"; + std::optional m_chatTemplate; + mutable std::optional m_modelChatTemplate; + QString m_systemMessage; QString m_chatNamePrompt = "Describe the above conversation in seven words or less."; QString m_suggestedFollowUpPrompt = "Suggest three very short factual follow-up questions that have not been answered yet or cannot be found inspired by the previous conversation and excerpts."; friend class MySettings; + friend class ModelList; }; Q_DECLARE_METATYPE(ModelInfo) @@ -340,8 +392,8 @@ class ModelList : public QAbstractListModel GpuLayersRole, RepeatPenaltyRole, RepeatPenaltyTokensRole, - PromptTemplateRole, - SystemPromptRole, + ChatTemplateRole, + SystemMessageRole, ChatNamePromptRole, SuggestedFollowUpPromptRole, MinPRole, @@ -394,8 +446,8 @@ class ModelList : public QAbstractListModel roles[GpuLayersRole] = "gpuLayers"; roles[RepeatPenaltyRole] = "repeatPenalty"; roles[RepeatPenaltyTokensRole] = "repeatPenaltyTokens"; - roles[PromptTemplateRole] = "promptTemplate"; - roles[SystemPromptRole] = "systemPrompt"; + roles[ChatTemplateRole] = "chatTemplate"; + roles[SystemMessageRole] = "systemMessage"; roles[ChatNamePromptRole] = "chatNamePrompt"; roles[SuggestedFollowUpPromptRole] = "suggestedFollowUpPrompt"; roles[LikesRole] = "likes"; @@ -416,7 +468,7 @@ class ModelList : public QAbstractListModel bool contains(const QString &id) const; bool containsByFilename(const QString &filename) const; Q_INVOKABLE ModelInfo modelInfo(const QString &id) const; - Q_INVOKABLE ModelInfo modelInfoByFilename(const QString &filename) const; + Q_INVOKABLE ModelInfo modelInfoByFilename(const QString &filename, bool allowClone = true) const; Q_INVOKABLE bool isUniqueName(const QString &name) const; Q_INVOKABLE QString clone(const ModelInfo &model); Q_INVOKABLE void removeClone(const ModelInfo &model); @@ -476,15 +528,18 @@ class ModelList : public QAbstractListModel void discoverSortChanged(); void discoverProgressChanged(); void discoverInProgressChanged(); + void modelInfoChanged(const ModelInfo &info); protected: bool eventFilter(QObject *obj, QEvent *ev) override; private Q_SLOTS: + void onDataChanged(const QModelIndex &topLeft, const QModelIndex &bottomRight, const QList &roles); void resortModel(); void updateModelsFromJson(); void updateModelsFromJsonAsync(); void updateModelsFromSettings(); + void maybeUpdateDataForSettings(const ModelInfo &info, bool fromInfo); void updateDataForSettings(); void handleModelsJsonDownloadFinished(); void handleModelsJsonDownloadErrorOccurred(QNetworkReply::NetworkError code); @@ -495,6 +550,9 @@ private Q_SLOTS: void handleSslErrors(QNetworkReply *reply, const QList &errors); private: + // Return the index of the model with the given id, or -1 if not found. + int indexByModelId(const QString &id) const; + void removeInternal(const ModelInfo &model); void clearDiscoveredModels(); bool modelExists(const QString &fileName) const; diff --git a/gpt4all-chat/src/mysettings.cpp b/gpt4all-chat/src/mysettings.cpp index 87abcc4601ba..ffccc912dade 100644 --- a/gpt4all-chat/src/mysettings.cpp +++ b/gpt4all-chat/src/mysettings.cpp @@ -1,5 +1,8 @@ #include "mysettings.h" +#include "chatllm.h" +#include "modellist.h" + #include #include @@ -29,8 +32,13 @@ static const QStringList suggestionModeNames { "LocalDocsOnly", "On", "Off" }; static const QStringList chatThemeNames { "Light", "Dark", "LegacyDark" }; static const QStringList fontSizeNames { "Small", "Medium", "Large" }; -// FIXME: All of these default strings that are shown in the UI for settings need to be marked as -// translatable +// psuedo-enum +namespace ModelSettingsKey { namespace { + auto ChatTemplate = "chatTemplate"_L1; + auto PromptTemplate = "promptTemplate"_L1; // legacy + auto SystemMessage = "systemMessage"_L1; + auto SystemPrompt = "systemPrompt"_L1; // legacy +} } // namespace ModelSettingsKey::(anonymous) namespace defaults { @@ -48,7 +56,6 @@ static const QVariantMap basicDefaults { { "fontSize", QVariant::fromValue(FontSize::Small) }, { "lastVersionStarted", "" }, { "networkPort", 4891, }, - { "saveChatsContext", false }, { "systemTray", false }, { "serverChat", false }, { "userDefaultModel", "Application default" }, @@ -147,6 +154,11 @@ static QStringList getUiLanguages(const QString &modelPath) return languageList; } +static QString modelSettingName(const ModelInfo &info, auto &&name) +{ + return u"model-%1/%2"_s.arg(info.id(), name); +} + class MyPrivateSettings: public MySettings { }; Q_GLOBAL_STATIC(MyPrivateSettings, settingsInstance) MySettings *MySettings::globalInstance() @@ -162,6 +174,34 @@ MySettings::MySettings() { } +QVariant MySettings::checkJinjaTemplateError(const QString &tmpl) +{ + if (auto err = ChatLLM::checkJinjaTemplateError(tmpl.toStdString())) + return QString::fromStdString(*err); + return QVariant::fromValue(nullptr); +} + +// Unset settings come from ModelInfo. Listen for changes so we can emit our own setting-specific signals. +void MySettings::onModelInfoChanged(const QModelIndex &topLeft, const QModelIndex &bottomRight, const QList &roles) +{ + auto settingChanged = [&](const auto &info, auto role, const auto &name) { + return (roles.isEmpty() || roles.contains(role)) && !m_settings.contains(modelSettingName(info, name)); + }; + + auto &modelList = dynamic_cast(*QObject::sender()); + for (int row = topLeft.row(); row <= bottomRight.row(); row++) { + using enum ModelList::Roles; + using namespace ModelSettingsKey; + auto index = topLeft.siblingAtRow(row); + if (auto info = modelList.modelInfo(index.data(IdRole).toString()); !info.id().isNull()) { + if (settingChanged(info, ChatTemplateRole, ChatTemplate)) + emit chatTemplateChanged(info, /*fromInfo*/ true); + if (settingChanged(info, SystemMessageRole, SystemMessage)) + emit systemMessageChanged(info, /*fromInfo*/ true); + } + } +} + QVariant MySettings::getBasicSetting(const QString &name) const { return m_settings.value(name, basicDefaults.value(name)); @@ -194,8 +234,8 @@ void MySettings::restoreModelDefaults(const ModelInfo &info) setModelGpuLayers(info, info.m_gpuLayers); setModelRepeatPenalty(info, info.m_repeatPenalty); setModelRepeatPenaltyTokens(info, info.m_repeatPenaltyTokens); - setModelPromptTemplate(info, info.m_promptTemplate); - setModelSystemPrompt(info, info.m_systemPrompt); + resetModelChatTemplate (info); + resetModelSystemMessage(info); setModelChatNamePrompt(info, info.m_chatNamePrompt); setModelSuggestedFollowUpPrompt(info, info.m_suggestedFollowUpPrompt); } @@ -206,7 +246,6 @@ void MySettings::restoreApplicationDefaults() setFontSize(basicDefaults.value("fontSize").value()); setDevice(defaults::device); setThreadCount(defaults::threadCount); - setSaveChatsContext(basicDefaults.value("saveChatsContext").toBool()); setSystemTray(basicDefaults.value("systemTray").toBool()); setServerChat(basicDefaults.value("serverChat").toBool()); setNetworkPort(basicDefaults.value("networkPort").toInt()); @@ -252,29 +291,37 @@ void MySettings::setModelName(const ModelInfo &info, const QString &value, bool emit nameChanged(info); } -static QString modelSettingName(const ModelInfo &info, const QString &name) +QVariant MySettings::getModelSetting(QLatin1StringView name, const ModelInfo &info) const { - return u"model-%1/%2"_s.arg(info.id(), name); + QLatin1StringView nameL1(name); + return m_settings.value(modelSettingName(info, nameL1), info.getField(nameL1)); } -QVariant MySettings::getModelSetting(const QString &name, const ModelInfo &info) const +QVariant MySettings::getModelSetting(const char *name, const ModelInfo &info) const { - return m_settings.value(modelSettingName(info, name), info.getFields().value(name)); + return getModelSetting(QLatin1StringView(name), info); } -void MySettings::setModelSetting(const QString &name, const ModelInfo &info, const QVariant &value, bool force, +void MySettings::setModelSetting(QLatin1StringView name, const ModelInfo &info, const QVariant &value, bool force, bool signal) { if (!force && (info.id().isEmpty() || getModelSetting(name, info) == value)) return; - QString settingName = modelSettingName(info, name); - if (info.getFields().value(name) == value && !info.shouldSaveMetadata()) + QLatin1StringView nameL1(name); + QString settingName = modelSettingName(info, nameL1); + if (info.getField(nameL1) == value && !info.shouldSaveMetadata()) m_settings.remove(settingName); else m_settings.setValue(settingName, value); if (signal && !force) - QMetaObject::invokeMethod(this, u"%1Changed"_s.arg(name).toLatin1().constData(), Q_ARG(ModelInfo, info)); + QMetaObject::invokeMethod(this, u"%1Changed"_s.arg(nameL1).toLatin1().constData(), Q_ARG(ModelInfo, info)); +} + +void MySettings::setModelSetting(const char *name, const ModelInfo &info, const QVariant &value, bool force, + bool signal) +{ + setModelSetting(QLatin1StringView(name), info, value, force, signal); } QString MySettings::modelFilename (const ModelInfo &info) const { return getModelSetting("filename", info).toString(); } @@ -297,11 +344,68 @@ int MySettings::modelContextLength (const ModelInfo &info) const int MySettings::modelGpuLayers (const ModelInfo &info) const { return getModelSetting("gpuLayers", info).toInt(); } double MySettings::modelRepeatPenalty (const ModelInfo &info) const { return getModelSetting("repeatPenalty", info).toDouble(); } int MySettings::modelRepeatPenaltyTokens (const ModelInfo &info) const { return getModelSetting("repeatPenaltyTokens", info).toInt(); } -QString MySettings::modelPromptTemplate (const ModelInfo &info) const { return getModelSetting("promptTemplate", info).toString(); } -QString MySettings::modelSystemPrompt (const ModelInfo &info) const { return getModelSetting("systemPrompt", info).toString(); } QString MySettings::modelChatNamePrompt (const ModelInfo &info) const { return getModelSetting("chatNamePrompt", info).toString(); } QString MySettings::modelSuggestedFollowUpPrompt(const ModelInfo &info) const { return getModelSetting("suggestedFollowUpPrompt", info).toString(); } +auto MySettings::getUpgradeableModelSetting( + const ModelInfo &info, QLatin1StringView legacyKey, QLatin1StringView newKey +) const -> UpgradeableSetting +{ + if (info.id().isEmpty()) { + qWarning("%s: got null model", Q_FUNC_INFO); + return {}; + } + + auto value = m_settings.value(modelSettingName(info, legacyKey)); + if (value.isValid()) + return { UpgradeableSetting::legacy_tag, value.toString() }; + + value = getModelSetting(newKey, info); + if (!value.isNull()) + return value.toString(); + return {}; // neither a default nor an override +} + +bool MySettings::isUpgradeableModelSettingSet( + const ModelInfo &info, QLatin1StringView legacyKey, QLatin1StringView newKey +) const +{ + if (info.id().isEmpty()) { + qWarning("%s: got null model", Q_FUNC_INFO); + return false; + } + + if (m_settings.contains(modelSettingName(info, legacyKey))) + return true; + + // NOTE: unlike getUpgradeableSetting(), this ignores the default + return m_settings.contains(modelSettingName(info, newKey)); +} + +auto MySettings::modelChatTemplate(const ModelInfo &info) const -> UpgradeableSetting +{ + using namespace ModelSettingsKey; + return getUpgradeableModelSetting(info, PromptTemplate, ChatTemplate); +} + +bool MySettings::isModelChatTemplateSet(const ModelInfo &info) const +{ + using namespace ModelSettingsKey; + return isUpgradeableModelSettingSet(info, PromptTemplate, ChatTemplate); +} + +auto MySettings::modelSystemMessage(const ModelInfo &info) const -> UpgradeableSetting +{ + using namespace ModelSettingsKey; + return getUpgradeableModelSetting(info, SystemPrompt, SystemMessage); +} + +bool MySettings::isModelSystemMessageSet(const ModelInfo &info) const +{ + using namespace ModelSettingsKey; + return isUpgradeableModelSettingSet(info, SystemPrompt, SystemMessage); +} + void MySettings::setModelFilename(const ModelInfo &info, const QString &value, bool force) { setModelSetting("filename", info, value, force, true); @@ -402,14 +506,77 @@ void MySettings::setModelRepeatPenaltyTokens(const ModelInfo &info, int value, b setModelSetting("repeatPenaltyTokens", info, value, force, true); } -void MySettings::setModelPromptTemplate(const ModelInfo &info, const QString &value, bool force) +bool MySettings::setUpgradeableModelSetting( + const ModelInfo &info, const QString &value, QLatin1StringView legacyKey, QLatin1StringView newKey +) { + if (info.id().isEmpty()) { + qWarning("%s: got null model", Q_FUNC_INFO); + return false; + } + + auto legacyModelKey = modelSettingName(info, legacyKey); + auto newModelKey = modelSettingName(info, newKey ); + bool changed = false; + if (m_settings.contains(legacyModelKey)) { + m_settings.remove(legacyModelKey); + changed = true; + } + auto oldValue = m_settings.value(newModelKey); + if (!oldValue.isValid() || oldValue.toString() != value) { + m_settings.setValue(newModelKey, value); + changed = true; + } + return changed; +} + +bool MySettings::resetUpgradeableModelSetting( + const ModelInfo &info, QLatin1StringView legacyKey, QLatin1StringView newKey +) { + if (info.id().isEmpty()) { + qWarning("%s: got null model", Q_FUNC_INFO); + return false; + } + + auto legacyModelKey = modelSettingName(info, legacyKey); + auto newModelKey = modelSettingName(info, newKey ); + bool changed = false; + if (m_settings.contains(legacyModelKey)) { + m_settings.remove(legacyModelKey); + changed = true; + } + if (m_settings.contains(newModelKey)) { + m_settings.remove(newModelKey); + changed = true; + } + return changed; +} + +void MySettings::setModelChatTemplate(const ModelInfo &info, const QString &value) +{ + using namespace ModelSettingsKey; + if (setUpgradeableModelSetting(info, value, PromptTemplate, ChatTemplate)) + emit chatTemplateChanged(info); +} + +void MySettings::resetModelChatTemplate(const ModelInfo &info) +{ + using namespace ModelSettingsKey; + if (resetUpgradeableModelSetting(info, PromptTemplate, ChatTemplate)) + emit chatTemplateChanged(info); +} + +void MySettings::setModelSystemMessage(const ModelInfo &info, const QString &value) { - setModelSetting("promptTemplate", info, value, force, true); + using namespace ModelSettingsKey; + if (setUpgradeableModelSetting(info, value, SystemPrompt, SystemMessage)) + emit systemMessageChanged(info); } -void MySettings::setModelSystemPrompt(const ModelInfo &info, const QString &value, bool force) +void MySettings::resetModelSystemMessage(const ModelInfo &info) { - setModelSetting("systemPrompt", info, value, force, true); + using namespace ModelSettingsKey; + if (resetUpgradeableModelSetting(info, SystemPrompt, SystemMessage)) + emit systemMessageChanged(info); } void MySettings::setModelChatNamePrompt(const ModelInfo &info, const QString &value, bool force) @@ -445,7 +612,6 @@ void MySettings::setThreadCount(int value) emit threadCountChanged(); } -bool MySettings::saveChatsContext() const { return getBasicSetting("saveChatsContext" ).toBool(); } bool MySettings::systemTray() const { return getBasicSetting("systemTray" ).toBool(); } bool MySettings::serverChat() const { return getBasicSetting("serverChat" ).toBool(); } int MySettings::networkPort() const { return getBasicSetting("networkPort" ).toInt(); } @@ -464,7 +630,6 @@ ChatTheme MySettings::chatTheme() const { return ChatTheme (getEnu FontSize MySettings::fontSize() const { return FontSize (getEnumSetting("fontSize", fontSizeNames)); } SuggestionMode MySettings::suggestionMode() const { return SuggestionMode(getEnumSetting("suggestionMode", suggestionModeNames)); } -void MySettings::setSaveChatsContext(bool value) { setBasicSetting("saveChatsContext", value); } void MySettings::setSystemTray(bool value) { setBasicSetting("systemTray", value); } void MySettings::setServerChat(bool value) { setBasicSetting("serverChat", value); } void MySettings::setNetworkPort(int value) { setBasicSetting("networkPort", value); } diff --git a/gpt4all-chat/src/mysettings.h b/gpt4all-chat/src/mysettings.h index f3d5e5b058b5..a1a61e0618e0 100644 --- a/gpt4all-chat/src/mysettings.h +++ b/gpt4all-chat/src/mysettings.h @@ -4,6 +4,9 @@ #include "modellist.h" // IWYU pragma: keep #include +#include +#include +#include #include #include #include @@ -48,7 +51,6 @@ class MySettings : public QObject { Q_OBJECT Q_PROPERTY(int threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged) - Q_PROPERTY(bool saveChatsContext READ saveChatsContext WRITE setSaveChatsContext NOTIFY saveChatsContextChanged) Q_PROPERTY(bool systemTray READ systemTray WRITE setSystemTray NOTIFY systemTrayChanged) Q_PROPERTY(bool serverChat READ serverChat WRITE setServerChat NOTIFY serverChatChanged) Q_PROPERTY(QString modelPath READ modelPath WRITE setModelPath NOTIFY modelPathChanged) @@ -75,9 +77,18 @@ class MySettings : public QObject Q_PROPERTY(SuggestionMode suggestionMode READ suggestionMode WRITE setSuggestionMode NOTIFY suggestionModeChanged) Q_PROPERTY(QStringList uiLanguages MEMBER m_uiLanguages CONSTANT) +private: + explicit MySettings(); + ~MySettings() override = default; + +public Q_SLOTS: + void onModelInfoChanged(const QModelIndex &topLeft, const QModelIndex &bottomRight, const QList &roles = {}); + public: static MySettings *globalInstance(); + Q_INVOKABLE static QVariant checkJinjaTemplateError(const QString &tmpl); + // Restore methods Q_INVOKABLE void restoreModelDefaults(const ModelInfo &info); Q_INVOKABLE void restoreApplicationDefaults(); @@ -125,10 +136,14 @@ class MySettings : public QObject Q_INVOKABLE void setModelRepeatPenalty(const ModelInfo &info, double value, bool force = false); int modelRepeatPenaltyTokens(const ModelInfo &info) const; Q_INVOKABLE void setModelRepeatPenaltyTokens(const ModelInfo &info, int value, bool force = false); - QString modelPromptTemplate(const ModelInfo &info) const; - Q_INVOKABLE void setModelPromptTemplate(const ModelInfo &info, const QString &value, bool force = false); - QString modelSystemPrompt(const ModelInfo &info) const; - Q_INVOKABLE void setModelSystemPrompt(const ModelInfo &info, const QString &value, bool force = false); + auto modelChatTemplate(const ModelInfo &info) const -> UpgradeableSetting; + Q_INVOKABLE bool isModelChatTemplateSet(const ModelInfo &info) const; + Q_INVOKABLE void setModelChatTemplate(const ModelInfo &info, const QString &value); + Q_INVOKABLE void resetModelChatTemplate(const ModelInfo &info); + auto modelSystemMessage(const ModelInfo &info) const -> UpgradeableSetting; + Q_INVOKABLE bool isModelSystemMessageSet(const ModelInfo &info) const; + Q_INVOKABLE void setModelSystemMessage(const ModelInfo &info, const QString &value); + Q_INVOKABLE void resetModelSystemMessage(const ModelInfo &info); int modelContextLength(const ModelInfo &info) const; Q_INVOKABLE void setModelContextLength(const ModelInfo &info, int value, bool force = false); int modelGpuLayers(const ModelInfo &info) const; @@ -141,8 +156,6 @@ class MySettings : public QObject // Application settings int threadCount() const; void setThreadCount(int value); - bool saveChatsContext() const; - void setSaveChatsContext(bool value); bool systemTray() const; void setSystemTray(bool value); bool serverChat() const; @@ -215,12 +228,11 @@ class MySettings : public QObject void gpuLayersChanged(const ModelInfo &info); void repeatPenaltyChanged(const ModelInfo &info); void repeatPenaltyTokensChanged(const ModelInfo &info); - void promptTemplateChanged(const ModelInfo &info); - void systemPromptChanged(const ModelInfo &info); + void chatTemplateChanged(const ModelInfo &info, bool fromInfo = false); + void systemMessageChanged(const ModelInfo &info, bool fromInfo = false); void chatNamePromptChanged(const ModelInfo &info); void suggestedFollowUpPromptChanged(const ModelInfo &info); void threadCountChanged(); - void saveChatsContextChanged(); void systemTrayChanged(); void serverChatChanged(); void modelPathChanged(); @@ -245,6 +257,30 @@ class MySettings : public QObject void suggestionModeChanged(); void languageAndLocaleChanged(); +private: + QVariant getBasicSetting(const QString &name) const; + void setBasicSetting(const QString &name, const QVariant &value, std::optional signal = std::nullopt); + int getEnumSetting(const QString &setting, const QStringList &valueNames) const; + QVariant getModelSetting(QLatin1StringView name, const ModelInfo &info) const; + QVariant getModelSetting(const char *name, const ModelInfo &info) const; + void setModelSetting(QLatin1StringView name, const ModelInfo &info, const QVariant &value, bool force, + bool signal = false); + void setModelSetting(const char *name, const ModelInfo &info, const QVariant &value, bool force, + bool signal = false); + auto getUpgradeableModelSetting( + const ModelInfo &info, QLatin1StringView legacyKey, QLatin1StringView newKey + ) const -> UpgradeableSetting; + bool isUpgradeableModelSettingSet( + const ModelInfo &info, QLatin1StringView legacyKey, QLatin1StringView newKey + ) const; + bool setUpgradeableModelSetting( + const ModelInfo &info, const QString &value, QLatin1StringView legacyKey, QLatin1StringView newKey + ); + bool resetUpgradeableModelSetting( + const ModelInfo &info, QLatin1StringView legacyKey, QLatin1StringView newKey + ); + QString filePathForLocale(const QLocale &locale); + private: QSettings m_settings; bool m_forceMetal; @@ -253,18 +289,7 @@ class MySettings : public QObject const QStringList m_uiLanguages; std::unique_ptr m_translator; -private: - explicit MySettings(); - ~MySettings() {} friend class MyPrivateSettings; - - QVariant getBasicSetting(const QString &name) const; - void setBasicSetting(const QString &name, const QVariant &value, std::optional signal = std::nullopt); - int getEnumSetting(const QString &setting, const QStringList &valueNames) const; - QVariant getModelSetting(const QString &name, const ModelInfo &info) const; - void setModelSetting(const QString &name, const ModelInfo &info, const QVariant &value, bool force, - bool signal = false); - QString filePathForLocale(const QLocale &locale); }; #endif // MYSETTINGS_H diff --git a/gpt4all-chat/src/network.cpp b/gpt4all-chat/src/network.cpp index 2e794d1ab241..bff380102ba2 100644 --- a/gpt4all-chat/src/network.cpp +++ b/gpt4all-chat/src/network.cpp @@ -8,6 +8,7 @@ #include "localdocsmodel.h" #include "modellist.h" #include "mysettings.h" +#include "utils.h" #include @@ -192,11 +193,14 @@ bool Network::packageAndSendJson(const QString &ingestId, const QString &json) return false; } + auto *currentChat = ChatListModel::globalInstance()->currentChat(); + Q_ASSERT(currentChat); + auto modelInfo = currentChat->modelInfo(); + Q_ASSERT(doc.isObject()); - Q_ASSERT(ChatListModel::globalInstance()->currentChat()); QJsonObject object = doc.object(); object.insert("source", "gpt4all-chat"); - object.insert("agent_id", ChatListModel::globalInstance()->currentChat()->modelInfo().filename()); + object.insert("agent_id", modelInfo.filename()); object.insert("submitter_id", m_uniqueId); object.insert("ingest_id", ingestId); @@ -204,8 +208,9 @@ bool Network::packageAndSendJson(const QString &ingestId, const QString &json) if (!attribution.isEmpty()) object.insert("network/attribution", attribution); - QString promptTemplate = ChatListModel::globalInstance()->currentChat()->modelInfo().promptTemplate(); - object.insert("prompt_template", promptTemplate); + if (!modelInfo.id().isNull()) + if (auto tmpl = modelInfo.chatTemplate().asModern()) + object.insert("chat_template"_L1, *tmpl); QJsonDocument newDoc; newDoc.setObject(object); @@ -358,7 +363,8 @@ void Network::sendStartup() void Network::trackChatEvent(const QString &ev, QVariantMap props) { - const auto &curChat = ChatListModel::globalInstance()->currentChat(); + auto *curChat = ChatListModel::globalInstance()->currentChat(); + Q_ASSERT(curChat); if (!props.contains("model")) props.insert("model", curChat->modelInfo().filename()); props.insert("device_backend", curChat->deviceBackend()); @@ -366,7 +372,7 @@ void Network::trackChatEvent(const QString &ev, QVariantMap props) props.insert("doc_collections_enabled", curChat->collectionList().count()); props.insert("doc_collections_total", LocalDocs::globalInstance()->localDocsModel()->rowCount()); props.insert("datalake_active", MySettings::globalInstance()->networkIsActive()); - props.insert("using_server", ChatListModel::globalInstance()->currentChat()->isServer()); + props.insert("using_server", curChat->isServer()); trackEvent(ev, props); } diff --git a/gpt4all-chat/src/server.cpp b/gpt4all-chat/src/server.cpp index 577859a3266b..2435d43b3c9e 100644 --- a/gpt4all-chat/src/server.cpp +++ b/gpt4all-chat/src/server.cpp @@ -313,11 +313,8 @@ const std::unordered_map BaseCompleti class ChatRequest : public BaseCompletionRequest { public: struct Message { - enum class Role : uint8_t { - User, - Assistant, - }; - Role role; + enum class Role { System, User, Assistant }; + Role role; QString content; }; @@ -349,7 +346,6 @@ class ChatRequest : public BaseCompletionRequest { this->messages.clear(); { QCborArray arr = value.toArray(); - Message::Role nextRole = Message::Role::User; for (qsizetype i = 0; i < arr.size(); i++) { const auto &elem = arr[i]; if (!elem.isMap()) @@ -360,9 +356,9 @@ class ChatRequest : public BaseCompletionRequest { QCborMap msg = elem.toMap(); Message res; QString role = takeValue(msg, "role", String, /*required*/ true).toString(); - if (role == u"system"_s) - continue; // FIXME(jared): don't ignore these - if (role == u"user"_s) { + if (role == u"system"_s) { + res.role = Message::Role::System; + } else if (role == u"user"_s) { res.role = Message::Role::User; } else if (role == u"assistant"_s) { res.role = Message::Role::Assistant; @@ -374,13 +370,7 @@ class ChatRequest : public BaseCompletionRequest { )); } res.content = takeValue(msg, "content", String, /*required*/ true).toString(); - if (res.role != nextRole) - throw InvalidRequestError(fmt::format( - "Invalid 'messages[{}].role': did not expect '{}' here", i, role - )); this->messages.append(res); - nextRole = res.role == Message::Role::User ? Message::Role::Assistant - : Message::Role::User; if (!msg.isEmpty()) throw InvalidRequestError(fmt::format( @@ -630,8 +620,7 @@ void Server::start() }); #endif - connect(this, &Server::requestServerNewPromptResponsePair, m_chat, - &Chat::serverNewPromptResponsePair, Qt::BlockingQueuedConnection); + connect(this, &Server::requestResetResponseState, m_chat, &Chat::resetResponseState, Qt::BlockingQueuedConnection); } static auto makeError(auto &&...args) -> std::pair> @@ -642,6 +631,10 @@ static auto makeError(auto &&...args) -> std::pair std::pair> { + Q_ASSERT(m_chatModel); + + auto *mySettings = MySettings::globalInstance(); + ModelInfo modelInfo = ModelList::globalInstance()->defaultModelInfo(); const QList modelList = ModelList::globalInstance()->selectableModelList(); for (const ModelInfo &info : modelList) { @@ -654,10 +647,6 @@ auto Server::handleCompletionRequest(const CompletionRequest &request) } } - // adds prompt/response items to GUI - emit requestServerNewPromptResponsePair(request.prompt); // blocks - resetResponse(); - // load the new model if necessary setShouldBeLoaded(true); @@ -666,47 +655,55 @@ auto Server::handleCompletionRequest(const CompletionRequest &request) return makeError(QHttpServerResponder::StatusCode::InternalServerError); } + emit requestResetResponseState(); // blocks + qsizetype prevMsgIndex = m_chatModel->count() - 1; + if (prevMsgIndex >= 0) + m_chatModel->updateCurrentResponse(prevMsgIndex, false); + // NB: this resets the context, regardless of whether this model is already loaded if (!loadModel(modelInfo)) { std::cerr << "ERROR: couldn't load model " << modelInfo.name().toStdString() << std::endl; return makeError(QHttpServerResponder::StatusCode::InternalServerError); } + // add prompt/response items to GUI + m_chatModel->appendPrompt(request.prompt); + m_chatModel->appendResponse(prevMsgIndex + 1); + // FIXME(jared): taking parameters from the UI inhibits reproducibility of results - const int top_k = modelInfo.topK(); - const int n_batch = modelInfo.promptBatchSize(); - const auto repeat_penalty = float(modelInfo.repeatPenalty()); - const int repeat_last_n = modelInfo.repeatPenaltyTokens(); + LLModel::PromptContext promptCtx { + .n_predict = request.max_tokens, + .top_k = mySettings->modelTopK(modelInfo), + .top_p = request.top_p, + .min_p = request.min_p, + .temp = request.temperature, + .n_batch = mySettings->modelPromptBatchSize(modelInfo), + .repeat_penalty = float(mySettings->modelRepeatPenalty(modelInfo)), + .repeat_last_n = mySettings->modelRepeatPenaltyTokens(modelInfo), + }; + auto promptUtf8 = request.prompt.toUtf8(); int promptTokens = 0; int responseTokens = 0; - QList>> responses; + QStringList responses; for (int i = 0; i < request.n; ++i) { - if (!promptInternal( - m_collections, - request.prompt, - /*promptTemplate*/ u"%1"_s, - request.max_tokens, - top_k, - request.top_p, - request.min_p, - request.temperature, - n_batch, - repeat_penalty, - repeat_last_n)) { - - std::cerr << "ERROR: couldn't prompt model " << modelInfo.name().toStdString() << std::endl; + PromptResult result; + try { + result = promptInternal(std::string_view(promptUtf8.cbegin(), promptUtf8.cend()), + promptCtx, + /*usedLocalDocs*/ false); + } catch (const std::exception &e) { + emit responseChanged(e.what()); + emit responseStopped(0); return makeError(QHttpServerResponder::StatusCode::InternalServerError); } - QString resp = response(/*trim*/ false); + QString resp = QString::fromUtf8(result.response); if (request.echo) resp = request.prompt + resp; - responses.append({resp, m_databaseResults}); - if (!promptTokens) - promptTokens = m_promptTokens; - responseTokens += m_promptResponseTokens - m_promptTokens; - if (i < request.n - 1) - resetResponse(); + responses << resp; + if (i == 0) + promptTokens = result.promptTokens; + responseTokens += result.responseTokens; } QJsonObject responseObject { @@ -717,25 +714,13 @@ auto Server::handleCompletionRequest(const CompletionRequest &request) }; QJsonArray choices; - { - int index = 0; - for (const auto &r : responses) { - QString result = r.first; - QList infos = r.second; - QJsonObject choice { - { "text", result }, - { "index", index++ }, - { "logprobs", QJsonValue::Null }, - { "finish_reason", responseTokens == request.max_tokens ? "length" : "stop" }, - }; - if (MySettings::globalInstance()->localDocsShowReferences()) { - QJsonArray references; - for (const auto &ref : infos) - references.append(resultToJson(ref)); - choice.insert("references", references.isEmpty() ? QJsonValue::Null : QJsonValue(references)); - } - choices.append(choice); - } + for (qsizetype i = 0; auto &resp : std::as_const(responses)) { + choices << QJsonObject { + { "text", resp }, + { "index", i++ }, + { "logprobs", QJsonValue::Null }, + { "finish_reason", responseTokens == request.max_tokens ? "length" : "stop" }, + }; } responseObject.insert("choices", choices); @@ -751,6 +736,8 @@ auto Server::handleCompletionRequest(const CompletionRequest &request) auto Server::handleChatRequest(const ChatRequest &request) -> std::pair> { + auto *mySettings = MySettings::globalInstance(); + ModelInfo modelInfo = ModelList::globalInstance()->defaultModelInfo(); const QList modelList = ModelList::globalInstance()->selectableModelList(); for (const ModelInfo &info : modelList) { @@ -771,83 +758,58 @@ auto Server::handleChatRequest(const ChatRequest &request) return makeError(QHttpServerResponder::StatusCode::InternalServerError); } + emit requestResetResponseState(); // blocks + // NB: this resets the context, regardless of whether this model is already loaded if (!loadModel(modelInfo)) { std::cerr << "ERROR: couldn't load model " << modelInfo.name().toStdString() << std::endl; return makeError(QHttpServerResponder::StatusCode::InternalServerError); } - const QString promptTemplate = modelInfo.promptTemplate(); - const int top_k = modelInfo.topK(); - const int n_batch = modelInfo.promptBatchSize(); - const auto repeat_penalty = float(modelInfo.repeatPenalty()); - const int repeat_last_n = modelInfo.repeatPenaltyTokens(); + m_chatModel->updateCurrentResponse(m_chatModel->count() - 1, false); - int promptTokens = 0; - int responseTokens = 0; - QList>> responses; Q_ASSERT(!request.messages.isEmpty()); - Q_ASSERT(request.messages.size() % 2 == 1); - for (int i = 0; i < request.messages.size() - 2; i += 2) { + + // adds prompt/response items to GUI + std::vector chatItems; + for (auto &message : request.messages) { using enum ChatRequest::Message::Role; - auto &user = request.messages[i]; - auto &assistant = request.messages[i + 1]; - Q_ASSERT(user.role == User); - Q_ASSERT(assistant.role == Assistant); - - // adds prompt/response items to GUI - emit requestServerNewPromptResponsePair(user.content); // blocks - resetResponse(); - - if (!promptInternal( - {}, - user.content, - promptTemplate, - request.max_tokens, - top_k, - request.top_p, - request.min_p, - request.temperature, - n_batch, - repeat_penalty, - repeat_last_n, - assistant.content) - ) { - std::cerr << "ERROR: couldn't prompt model " << modelInfo.name().toStdString() << std::endl; - return makeError(QHttpServerResponder::StatusCode::InternalServerError); + switch (message.role) { + case System: chatItems.emplace_back(ChatItem::system_tag, message.content); break; + case User: chatItems.emplace_back(ChatItem::prompt_tag, message.content); break; + case Assistant: chatItems.emplace_back(ChatItem::response_tag, /*currentResponse*/ false); break; } - promptTokens += m_promptResponseTokens; // previous responses are part of current prompt } + m_chatModel->appendResponseWithHistory(chatItems); - QString lastMessage = request.messages.last().content; - // adds prompt/response items to GUI - emit requestServerNewPromptResponsePair(lastMessage); // blocks - resetResponse(); + // FIXME(jared): taking parameters from the UI inhibits reproducibility of results + LLModel::PromptContext promptCtx { + .n_predict = request.max_tokens, + .top_k = mySettings->modelTopK(modelInfo), + .top_p = request.top_p, + .min_p = request.min_p, + .temp = request.temperature, + .n_batch = mySettings->modelPromptBatchSize(modelInfo), + .repeat_penalty = float(mySettings->modelRepeatPenalty(modelInfo)), + .repeat_last_n = mySettings->modelRepeatPenaltyTokens(modelInfo), + }; + int promptTokens = 0; + int responseTokens = 0; + QList>> responses; for (int i = 0; i < request.n; ++i) { - if (!promptInternal( - m_collections, - lastMessage, - promptTemplate, - request.max_tokens, - top_k, - request.top_p, - request.min_p, - request.temperature, - n_batch, - repeat_penalty, - repeat_last_n) - ) { - std::cerr << "ERROR: couldn't prompt model " << modelInfo.name().toStdString() << std::endl; + ChatPromptResult result; + try { + result = promptInternalChat(m_collections, promptCtx); + } catch (const std::exception &e) { + emit responseChanged(e.what()); + emit responseStopped(0); return makeError(QHttpServerResponder::StatusCode::InternalServerError); } - responses.append({response(), m_databaseResults}); - // FIXME(jared): these are UI counts and do not include framing tokens, which they should + responses.emplace_back(result.response, result.databaseResults); if (i == 0) - promptTokens += m_promptTokens; - responseTokens += m_promptResponseTokens - m_promptTokens; - if (i != request.n - 1) - resetResponse(); + promptTokens = result.promptTokens; + responseTokens += result.responseTokens; } QJsonObject responseObject { diff --git a/gpt4all-chat/src/server.h b/gpt4all-chat/src/server.h index a5447d865921..fce61e5c8097 100644 --- a/gpt4all-chat/src/server.h +++ b/gpt4all-chat/src/server.h @@ -33,7 +33,7 @@ public Q_SLOTS: void start(); Q_SIGNALS: - void requestServerNewPromptResponsePair(const QString &prompt, const QList &attachments = {}); + void requestResetResponseState(); private: auto handleCompletionRequest(const CompletionRequest &request) -> std::pair>; diff --git a/gpt4all-chat/src/utils.h b/gpt4all-chat/src/utils.h index 0eacfe8bceb6..ac67e892cb40 100644 --- a/gpt4all-chat/src/utils.h +++ b/gpt4all-chat/src/utils.h @@ -3,23 +3,41 @@ #include #include +#include +#include +#include #include +#include +#include #include -#include +#include +#include +#include + +class QJsonObject; // fmtlib formatters for QString and QVariant -#define MAKE_FORMATTER(type, conversion) \ - template <> \ - struct fmt::formatter: fmt::formatter { \ - template \ - FmtContext::iterator format(const type &value, FmtContext &ctx) const \ - { \ - return formatter::format(conversion, ctx); \ - } \ +#define MAKE_FORMATTER(type, conversion) \ + template <> \ + struct fmt::formatter: fmt::formatter { \ + template \ + FmtContext::iterator format(const type &value, FmtContext &ctx) const \ + { \ + auto valueUtf8 = (conversion); \ + std::string_view view(valueUtf8.cbegin(), valueUtf8.cend()); \ + return formatter::format(view, ctx); \ + } \ } -MAKE_FORMATTER(QString, value.toStdString() ); -MAKE_FORMATTER(QVariant, value.toString().toStdString()); +MAKE_FORMATTER(QUtf8StringView, value ); +MAKE_FORMATTER(QStringView, value.toUtf8() ); +MAKE_FORMATTER(QString, value.toUtf8() ); +MAKE_FORMATTER(QVariant, value.toString().toUtf8()); + +// alternative to QJsonObject's initializer_list constructor that accepts Latin-1 strings +QJsonObject makeJsonObject(std::initializer_list> args); + +#include "utils.inl" diff --git a/gpt4all-chat/src/utils.inl b/gpt4all-chat/src/utils.inl new file mode 100644 index 000000000000..8aeb1f88c504 --- /dev/null +++ b/gpt4all-chat/src/utils.inl @@ -0,0 +1,9 @@ +#include + +inline QJsonObject makeJsonObject(std::initializer_list> args) +{ + QJsonObject obj; + for (auto &arg : args) + obj.insert(arg.first, arg.second); + return obj; +} diff --git a/gpt4all-chat/tests/python/test_server_api.py b/gpt4all-chat/tests/python/test_server_api.py index e1b0e476f4e9..26a6aff70dc0 100644 --- a/gpt4all-chat/tests/python/test_server_api.py +++ b/gpt4all-chat/tests/python/test_server_api.py @@ -203,11 +203,10 @@ def test_with_models_empty(chat_server: None) -> None: EXPECTED_COMPLETIONS_RESPONSE = { 'choices': [ { - 'finish_reason': 'stop', + 'finish_reason': 'length', 'index': 0, 'logprobs': None, - 'references': None, - 'text': ' jumps over the lazy dog.', + 'text': ' jumps over the lazy dog.\n', }, ], 'id': 'placeholder', @@ -242,18 +241,14 @@ def test_with_models(chat_server_with_model: None) -> None: 'type': 'invalid_request_error', }} - data = { - 'model': 'Llama 3.2 1B Instruct', - 'prompt': 'The quick brown fox', - 'temperature': 0, - } - + data = dict( + model = 'Llama 3.2 1B Instruct', + prompt = 'The quick brown fox', + temperature = 0, + max_tokens = 6, + ) response = request.post('completions', data=data) - assert len(response['choices']) == 1 - assert response['choices'][0].keys() == {'text', 'index', 'logprobs', 'references', 'finish_reason'} - assert response['choices'][0]['text'] == ' jumps over the lazy dog.' - assert 'created' in response - response.pop('created') # Remove the dynamic field for comparison + del response['created'] # Remove the dynamic field for comparison assert response == EXPECTED_COMPLETIONS_RESPONSE