
Commit

Remove binary state from high-level API and use Jinja templates (nomic-ai#3147)

Signed-off-by: Jared Van Bortel <[email protected]>
Signed-off-by: Adam Treat <[email protected]>
Co-authored-by: Adam Treat <[email protected]>
cebtenzzre and manyoso authored Nov 25, 2024
1 parent 3320094 commit 225bf6b
Showing 54 changed files with 3,412 additions and 2,213 deletions.
.gitmodules (3 changes: 3 additions & 0 deletions)
@@ -17,3 +17,6 @@
[submodule "gpt4all-chat/deps/QXlsx"]
path = gpt4all-chat/deps/QXlsx
url = https://github.com/nomic-ai/QXlsx.git
[submodule "gpt4all-chat/deps/Jinja2Cpp"]
path = gpt4all-chat/deps/Jinja2Cpp
url = https://github.com/nomic-ai/jinja2cpp.git
gpt4all-backend/include/gpt4all-backend/llmodel.h (68 changes: 34 additions & 34 deletions)
@@ -5,6 +5,7 @@
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <expected>
#include <functional>
#include <optional>
#include <span>
@@ -24,6 +25,10 @@ using namespace std::string_literals;
class LLModel {
public:
using Token = int32_t;
using PromptCallback = std::function<bool(std::span<const Token> batch, bool cached)>;
using ResponseCallback = std::function<bool(Token token, std::string_view piece)>;
using EmbedCancelCallback = bool(unsigned *batchSizes, unsigned nBatch, const char *backend);
using ProgressCallback = std::function<bool(float progress)>;

class BadArchError: public std::runtime_error {
public:
@@ -101,6 +106,7 @@ class LLModel {
static int32_t maxContextLength(const std::string &modelPath);
static int32_t layerCount(const std::string &modelPath);
static bool isEmbeddingModel(const std::string &modelPath);
static auto chatTemplate(const char *modelPath) -> std::expected<std::string, std::string>;
static void setImplementationsSearchPath(const std::string &path);
static const std::string &implementationsSearchPath();
static bool hasSupportedCPU();
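
For context, a minimal sketch of reading the chat template that the new static accessor extracts from a model file. It assumes the declaration sits on the nested LLModel::Implementation class alongside the neighboring statics shown above; the include path and the function name printChatTemplate are illustrative. The std::expected result carries either the template text or an error message.

// Sketch only: the enclosing Implementation class and the include path are assumptions.
#include <gpt4all-backend/llmodel.h>

#include <cstdio>

void printChatTemplate(const char *modelPath)
{
    auto tmpl = LLModel::Implementation::chatTemplate(modelPath);
    if (tmpl)
        std::printf("%s\n", tmpl->c_str());   // the extracted Jinja chat template
    else
        std::fprintf(stderr, "no chat template: %s\n", tmpl.error().c_str());
}
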
@@ -124,7 +130,6 @@
};

struct PromptContext {
int32_t n_past = 0; // number of tokens in past conversation
int32_t n_predict = 200;
int32_t top_k = 40;
float top_p = 0.9f;
@@ -136,8 +141,6 @@
float contextErase = 0.5f; // percent of context to erase if we exceed the context window
};

using ProgressCallback = std::function<bool(float progress)>;

explicit LLModel() {}
virtual ~LLModel() {}

@@ -154,16 +157,12 @@

// This method requires the model to return true from supportsCompletion otherwise it will throw
// an error
virtual void prompt(const std::string &prompt,
const std::string &promptTemplate,
std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
bool allowContextShift,
PromptContext &ctx,
bool special = false,
std::optional<std::string_view> fakeReply = {});
virtual void prompt(std::string_view prompt,
const PromptCallback &promptCallback,
const ResponseCallback &responseCallback,
const PromptContext &ctx);

using EmbedCancelCallback = bool(unsigned *batchSizes, unsigned nBatch, const char *backend);
virtual int32_t countPromptTokens(std::string_view prompt) const;

virtual size_t embeddingSize() const {
throw std::logic_error(std::string(implementation().modelType()) + " does not support embeddings");
@@ -209,23 +208,22 @@ class LLModel {
void setProgressCallback(ProgressCallback callback) { m_progressCallback = callback; }

virtual int32_t contextLength() const = 0;
virtual auto specialTokens() -> std::unordered_map<std::string, std::string> const = 0;

protected:
// These are pure virtual because subclasses need to implement as the default implementation of
// 'prompt' above calls these functions
virtual std::vector<Token> tokenize(std::string_view str, bool special = false) = 0;
virtual std::vector<Token> tokenize(std::string_view str) const = 0;
virtual bool isSpecialToken(Token id) const = 0;
virtual std::string tokenToString(Token id) const = 0;
virtual void initSampler(PromptContext &ctx) = 0;
virtual void initSampler(const PromptContext &ctx) = 0;
virtual Token sampleToken() const = 0;
virtual bool evalTokens(PromptContext &ctx, std::span<const Token> tokens) const = 0;
virtual void shiftContext(PromptContext &promptCtx) = 0;
virtual bool evalTokens(int32_t nPast, std::span<const Token> tokens) const = 0;
virtual void shiftContext(const PromptContext &promptCtx, int32_t *nPast) = 0;
virtual int32_t inputLength() const = 0;
virtual void setTokenizeInputPosition(int32_t pos) = 0;
virtual auto computeModelInputPosition(PromptContext &ctx, const std::vector<Token> &input)
-> std::vector<Token>::const_iterator = 0;
virtual void setModelInputPosition(PromptContext &ctx, int32_t pos) = 0;
virtual void appendInputToken(PromptContext &ctx, Token tok) = 0;
virtual int32_t computeModelInputPosition(std::span<const Token> input) const = 0;
virtual void setModelInputPosition(int32_t pos) = 0;
virtual void appendInputToken(Token tok) = 0;
virtual std::span<const Token> inputTokens() const = 0;
virtual const std::vector<Token> &endTokens() const = 0;
virtual bool shouldAddBOS() const = 0;
@@ -242,6 +240,12 @@ class LLModel {
return -1;
}

virtual auto chatTemplate(const char *modelPath) const -> std::expected<std::string, std::string>
{
(void)modelPath;
return std::unexpected("not implemented");
}

const Implementation *m_implementation = nullptr;

ProgressCallback m_progressCallback;
@@ -253,19 +257,15 @@
return true;
}

bool decodePrompt(std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
bool allowContextShift,
PromptContext &promptCtx,
std::vector<Token> embd_inp,
bool isResponse = false,
bool alwaysDecode = false);
void generateResponse(std::function<bool(int32_t, const std::string&)> responseCallback,
bool allowContextShift,
PromptContext &promptCtx);

protected:
Token m_tokenize_last_token = -1; // not serialized
// prefill context with prompt
auto decodePrompt(const PromptCallback &promptCallback,
const PromptContext &promptCtx,
std::vector<Token> embd_inp)
-> std::optional<int32_t>;
// generate a response
void generateResponse(const ResponseCallback &responseCallback,
const PromptContext &promptCtx,
int32_t nPast);

friend class LLMImplementation;
};
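
To illustrate the reworked high-level API, here is a minimal caller sketch based only on the signatures above. The include path, the function name generateOnce, and the sampling values are assumptions; model loading is elided, and the prompt string is assumed to be already rendered with the model's Jinja chat template, since the promptTemplate parameter, fakeReply, and the n_past field are gone.

// Sketch only: assumes a loaded LLModel and a prompt already rendered from the chat template.
#include <gpt4all-backend/llmodel.h>

#include <iostream>
#include <span>
#include <string_view>

void generateOnce(LLModel &model, std::string_view renderedPrompt)
{
    LLModel::PromptContext ctx;   // no n_past: the cache position is now tracked by the model
    ctx.n_predict = 128;          // illustrative values
    ctx.top_k     = 40;
    ctx.top_p     = 0.9f;

    LLModel::PromptCallback onPrompt = [](std::span<const LLModel::Token> batch, bool cached) {
        (void)batch; (void)cached; // one batch of prompt tokens; cached marks tokens reused from the cache
        return true;               // return false to abort prompt processing
    };
    LLModel::ResponseCallback onResponse = [](LLModel::Token token, std::string_view piece) {
        (void)token;
        std::cout << piece << std::flush; // stream each generated piece
        return true;                      // return false to stop generation
    };

    model.prompt(renderedPrompt, onPrompt, onResponse, ctx);
}

Because prompt() now takes already-templated conversation text, the Jinja rendering (hence the new Jinja2Cpp submodule under gpt4all-chat/deps) happens in the caller rather than inside the backend.
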
gpt4all-backend/include/gpt4all-backend/llmodel_c.h (44 changes: 23 additions & 21 deletions)
@@ -35,16 +35,15 @@ typedef int32_t token_t;
* behavior.
*/
struct llmodel_prompt_context {
int32_t n_past; // number of tokens in past conversation
int32_t n_predict; // number of tokens to predict
int32_t top_k; // top k logits to sample from
float top_p; // nucleus sampling probability threshold
float min_p; // Min P sampling
float temp; // temperature to adjust model's output distribution
int32_t n_batch; // number of predictions to generate in parallel
float repeat_penalty; // penalty factor for repeated tokens
int32_t repeat_last_n; // last n tokens to penalize
float context_erase; // percent of context to erase if we exceed the context window
};

struct llmodel_gpu_device {
@@ -63,18 +62,20 @@ typedef struct llmodel_gpu_device llmodel_gpu_device;

/**
* Callback type for prompt processing.
* @param token_id The token id of the prompt.
* @param token_ids An array of token ids of the prompt.
* @param n_token_ids The number of tokens in the array.
* @param cached Whether the tokens were already in cache.
* @return a bool indicating whether the model should keep processing.
*/
typedef bool (*llmodel_prompt_callback)(int32_t token_id);
typedef bool (*llmodel_prompt_callback)(const token_t *token_ids, size_t n_token_ids, bool cached);

/**
* Callback type for response.
* @param token_id The token id of the response.
* @param response The response string. NOTE: a token_id of -1 indicates the string is an error string.
* @return a bool indicating whether the model should keep generating.
*/
typedef bool (*llmodel_response_callback)(int32_t token_id, const char *response);
typedef bool (*llmodel_response_callback)(token_t token_id, const char *response);

/**
* Embedding cancellation callback for use with llmodel_embed.
@@ -85,6 +86,8 @@ typedef bool (*llmodel_response_callback)(int32_t token_id, const char *response);
*/
typedef bool (*llmodel_emb_cancel_callback)(unsigned *batch_sizes, unsigned n_batch, const char *backend);

typedef void (*llmodel_special_token_callback)(const char *name, const char *token);

/**
* Create a llmodel instance.
* Recognises correct model type from file at model_path
@@ -183,22 +186,17 @@ uint64_t llmodel_state_set_data(llmodel_model model, const uint8_t *state, uint6
* Generate a response using the model.
* @param model A pointer to the llmodel_model instance.
* @param prompt A string representing the input prompt.
* @param prompt_template A string representing the input prompt template.
* @param prompt_callback A callback function for handling the processing of prompt.
* @param response_callback A callback function for handling the generated response.
* @param allow_context_shift Whether to allow shifting of context to make room for more input.
* @param special True if special tokens in the prompt should be processed, false otherwise.
* @param fake_reply A string to insert into context as the model's reply, or NULL to generate one.
* @param ctx A pointer to the llmodel_prompt_context structure.
* @param error A pointer to a string; will only be set on error.
*/
void llmodel_prompt(llmodel_model model, const char *prompt,
const char *prompt_template,
llmodel_prompt_callback prompt_callback,
llmodel_response_callback response_callback,
bool allow_context_shift,
llmodel_prompt_context *ctx,
bool special,
const char *fake_reply);
bool llmodel_prompt(llmodel_model model,
const char *prompt,
llmodel_prompt_callback prompt_callback,
llmodel_response_callback response_callback,
llmodel_prompt_context *ctx,
const char **error);

/**
* Generate an embedding using the model.
@@ -310,6 +308,10 @@ const char *llmodel_model_backend_name(llmodel_model model);
*/
const char *llmodel_model_gpu_device_name(llmodel_model model);

int32_t llmodel_count_prompt_tokens(llmodel_model model, const char *prompt, const char **error);

void llmodel_model_foreach_special_token(llmodel_model model, llmodel_special_token_callback callback);

#ifdef __cplusplus
}
#endif
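
To tie the C API changes together, a rough usage sketch follows, written in C++ like the sketch after llmodel.h above. Model creation and loading are elided; the include path, the function names generate and print_special_token, and the sampling values are assumptions, and it is assumed here that a negative return from llmodel_count_prompt_tokens signals an error with *error set.

// Sketch only: demonstrates the new llmodel_prompt signature with its error out-parameter,
// plus the new token-counting and special-token helpers. Model loading is elided.
#include <gpt4all-backend/llmodel_c.h>

#include <cstddef>
#include <cstdint>
#include <cstdio>

static bool on_prompt(const token_t *token_ids, size_t n_token_ids, bool cached)
{
    (void)token_ids; (void)n_token_ids; (void)cached;
    return true; // keep processing the prompt
}

static bool on_response(token_t token_id, const char *piece)
{
    (void)token_id;
    std::fputs(piece, stdout); // stream each generated piece
    return true;               // keep generating
}

static void print_special_token(const char *name, const char *token)
{
    std::printf("%s -> %s\n", name, token);
}

bool generate(llmodel_model model, const char *prompt)
{
    llmodel_model_foreach_special_token(model, print_special_token);

    const char *error = nullptr;
    int32_t n_prompt = llmodel_count_prompt_tokens(model, prompt, &error);
    if (n_prompt < 0) // assumption: a negative return means failure, with *error set
        std::fprintf(stderr, "could not count prompt tokens: %s\n", error ? error : "(unknown)");

    llmodel_prompt_context ctx {}; // note: no n_past field anymore
    ctx.n_predict      = 128;      // illustrative sampling settings
    ctx.top_k          = 40;
    ctx.top_p          = 0.9f;
    ctx.min_p          = 0.0f;
    ctx.temp           = 0.7f;
    ctx.n_batch        = 8;
    ctx.repeat_penalty = 1.18f;
    ctx.repeat_last_n  = 64;
    ctx.context_erase  = 0.5f;

    error = nullptr;
    if (!llmodel_prompt(model, prompt, on_prompt, on_response, &ctx, &error)) {
        std::fprintf(stderr, "llmodel_prompt failed: %s\n", error ? error : "(unknown)");
        return false;
    }
    return true;
}
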
