Skip to content

Commit

Permalink
Add Janus description (#1044)
Browse files Browse the repository at this point in the history
Co-authored-by: nifeng <[email protected]>
  • Loading branch information
cheng221 and nemonameless authored Feb 13, 2025
1 parent aebbbac commit 9133822
Showing 1 changed file with 301 additions and 0 deletions.
301 changes: 301 additions & 0 deletions paddlemix/examples/janus/Janus_desciption.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,301 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Janus 介绍"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1 引言\n",
"\n",
"Janus是一个简单、统一且可扩展的多模态理解与生成模型,其将多模态理解与生成的视觉编码进行解耦,缓解了两个任务潜在存在的冲突。可在未来通过拓展,纳入更多的输入模态。Janus-Pro在此基础上,优化训练策略(包括增加训练步数、调整数据配比等)、增加数据(包括使用合成数据等)、扩大模型规模(扩大到70亿参数),使得模型多模态理解和文本到图像指令遵循能力方面取得了进步。\n",
"\n",
"Janus包含2个独立的视觉编码路径,分别用于多模态理解、生成,并带来两个收益:1)缓解了源自多模态理解和生成不同粒度需求的冲突,2)具有灵活性和可扩展性,解耦后,理解和生成任务都可以采用针对其领域最先进的编码技术,未来可输入点云、脑电信号或音频数据,使用统一的Transformer进行处理。\n",
"<div style=\"display: flex; justify-content: center; align-items: center; height: [desired-container-height]px;\">\n",
" <img src=\"https://ai-studio-static-online.cdn.bcebos.com/ea0703505b3b40ad923981dbddda20973c81da7a36194e3abc75ad1d9b870ab4\" alt=\"Description\" width=\"50%\" >\n",
"</div>\n",
"<center>图1: Janus 架构</center>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2 方法\n",
"\n",
"### 2.1 模型架构\n",
"Janus 的架构如图 1 所示。对于纯文本理解、多模态理解和视觉生成任务,采用独立的编码方法将原始输入转换为特征,然后通过统一的自回归 Transformer 进行处理。具体来说:\n",
"- 文本理解:我们使用大语言模型(LLM)内置的分词器将文本转换为离散的 ID,并获取每个 ID 对应的特征表示。\n",
"- 多模态理解:我们使用 SigLIP 编码器从图像中提取高维语义特征。这些特征从 2D 网格展平为 1D 序列,并通过一个理解适配器将这些图像特征映射到 LLM 的输入空间。\n",
"- 视觉生成:我们使用 VQ 分词器将图像转换为离散的 ID。将 ID 序列展平为 1D 后,使用一个生成适配器将每个 ID 对应的码本嵌入映射到 LLM 的输入空间。\n",
"然后,我们将这些特征序列连接起来,形成一个多模态特征序列,随后输入到 LLM 中进行处理。在纯文本理解和多模态理解任务中,使用 LLM 内置的预测头进行文本预测;而在视觉生成任务中,使用随机初始化的预测头进行图像预测。整个模型遵循自回归框架,无需特别设计的注意力掩码。\n",
"\n",
"### 2.2 Janus 训练\n",
"Janus的训练分为3个阶段:\n",
"- 第一阶段:训练Adaptor与Image Head,在嵌入空间创建语言元素与视觉元素之间的联系,使得LLM能够理解图像中的实体,并具备初步视觉生成能力;\n",
"对于多模态理解,使用来自SHareGPT4V125万个图像-文本配对字幕数据,格式:<图像><文本>;\n",
"对于视觉生成,使用来自ImageNet1k的120万个样本,格式:<类别名><图像>;\n",
"\n",
"- 第二阶段:统一预训练,使用多模态语料库进行统一预训练,学习多模态理解和生成。\n",
" - 在该阶段使用纯文本数据、多模态理解数据和视觉生成数据\n",
" - 使用ImageNet-1k进行简单的视觉生成训练,随后使用通用文本到图像数据提升模型开放领域的视觉生成能力\n",
" - 纯文本数据:DeepSeek-LLM预训练语料库\n",
" - 交错的图像-文本数据:WikiHow 和 WIT 数据集;\n",
" - 图像Caption数据:来自多个来源的图像,并采用开源多模态模型重新为部分图像添加字幕,数据格式为问答对,如$\\texttt{<caption>}$ Describe the image in detail.$\\texttt{<caption>}$;\n",
" - 表格和图表数据:来自 DeepSeek-VL的相应表格和图表数据,数据格式为<question><answer>;\n",
" - 视觉生成数据:来自多个数据集的image-caption对以及 200 万个内部数据;在训练过程中,以25%的概率随机仅使用caption的第一句话;ImageNet 样本仅在最初的 120K 训练步骤中出现,其他数据集的图像在后续 60K 步骤中出现;\n",
"- 第三阶段:监督微调,使用指令微调数据对预训练模型进行微调,以增强其遵循指令和对话的能力。微调除生成编码器之外的所有参数。在监督答案的同时,对系统和用户提示进行遮盖。为了确保Janus在多模态理解和生成方面都具备熟练度,不会针对特定任务分别微调模型。相反,我们使用纯文本对话数据、多模态理解数据和视觉生成数据的混合数据,以确保在各种场景下的多功能性;\n",
" - 文本理解:使用来自特定来源的数据;\n",
" - 多模态理解:使用来自多个来源的指令调整数据;\n",
" - 视觉生成:使用来自部分第二阶段数据集的图像-文本对子集以及 400 万个内部数据;\n",
" - 数据格式为:User:$\\texttt{<Input Message>}$ \\n Assistant: $\\texttt{<Response>}$;\n",
"<div style=\"display: flex; justify-content: center; align-items: center; height: [desired-container-height]px;\">\n",
" <img src=\"https://github.com/user-attachments/assets/0035318f-3348-4e5d-9256-a9a3410fa625\" alt=\"Description\" width=\"50%\" >\n",
"</div>\n",
"<center>图2: Janus 三阶段训练步骤</center>\n",
"\n",
"### 2.3 Janus 推理\n",
"在推理过程中,Janus 模型采用了一种 Next-token预测的方法。对于纯文本理解和多模态理解,我们遵循从预测分布中顺序采样token的标准做法。对于图像生成,我们利用了无分类器引导(CFG)在训练过程中,我们以10%的概率将文本到图像数据中的文本条件替换为填充token,使模型具备无条件视觉生成能力。对于生成下一个token的概率分布 $l_g$ 的计算公式为 $l_g = l_u + s(l_c - l_u)$ ,$l_c$是条件概率分布,$l_u$ 是条件概率分布,$s$ 是CFG系数,默认情况下 $s$ 为 5.\n",
"\n",
"### 2.4 Janus-Pro\n",
"- 训练策略\n",
" - Stage 1: 增加训练步数,在 ImageNet 上充分训练;\n",
" - Stage 2: 不再使用 ImageNet,直接使用常规文本到图像数据的训练数据;\n",
" - Stage 3: 修改微调过程中的数据集配比,将多模态数据、纯文本数据和文本到图像的比例从 7:3:10 改为 5:1:4;\n",
"- 数据规模\n",
" - 多模态理解\n",
" - Stage 2: 增加 9000 万个样本,包括图像字幕数据 YFCC、表格图表文档理解数据 Doc-matrix;\n",
" - Stage 3: 加入 DeepSeek-VL2 额外数据集,如 MEME 理解等;\n",
" - 视觉生成:真实世界数据可能包含质量不高,导致文本到图像的生成不稳定,产生美学效果不佳的输出,Janus-Pro 使用 7200 万份合成美学数据样本,统一预训练阶段(Stage 2)真实数据与合成数据比例 1:1;\n",
"- 模型规模\n",
" - 将模型参数扩展到 70 亿参数规模;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3 代码解读\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3.1 Janus 组网代码介绍"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- 类名: JanusMultiModalityCausalLM\n",
"- 功能: 该类实现了一个多模态因果语言模型,它能够处理图像和文本数据\n",
"- 实现步骤: \n",
" - 初始化:\n",
" - 从配置对象中提取各个组件的配置。\n",
" - 使用model_name_to_cls函数和配置参数来实例化视觉模型、对齐器、生成视觉模型、生成对齐器、生成头和语言模型。\n",
" - 创建一个Embedding层用于图像标识符到Embedding向量的映射。\n",
" - 准备输入Embed:\n",
" - 重新排列图像数据pixel_values的形状以适应视觉模型的输入要求。\n",
" - 使用视觉模型处理图像数据,并通过对齐器生成图像Embed。\n",
" - 重新排列图像Embed和掩码的形状以匹配文本输入的形状。\n",
" - 处理文本输入input_ids,将其转换为语言模型可以处理的Embed形式。\n",
" - 根据掩码将图像嵌入插入到文本嵌入中,生成最终的输入Embed input_embeds。\n",
" - 准备生成图像嵌入:\n",
" - 使用嵌入层将图像标识符image_ids映射到 Embedding 向量。\n",
" - 通过对齐器处理这些 Embedding 向量,生成最终的图像Embedding。\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class JanusMultiModalityCausalLM(JanusMultiModalityPreTrainedModel):\n",
" config_class = MultiModalityConfig\n",
"\n",
" def __init__(self, config: MultiModalityConfig):\n",
" super().__init__(config)\n",
" vision_config = config.vision_config\n",
" vision_cls = model_name_to_cls(vision_config.cls)\n",
" self.vision_model = vision_cls(**vision_config.params)\n",
" aligner_config = config.aligner_config\n",
" aligner_cls = model_name_to_cls(aligner_config.cls)\n",
" self.aligner = aligner_cls(aligner_config.params)\n",
" gen_vision_config = config.gen_vision_config\n",
" gen_vision_cls = model_name_to_cls(gen_vision_config.cls)\n",
" self.gen_vision_model = gen_vision_cls()\n",
" gen_aligner_config = config.gen_aligner_config\n",
" gen_aligner_cls = model_name_to_cls(gen_aligner_config.cls)\n",
" self.gen_aligner = gen_aligner_cls(gen_aligner_config.params)\n",
" gen_head_config = config.gen_head_config\n",
" gen_head_cls = model_name_to_cls(gen_head_config.cls)\n",
" self.gen_head = gen_head_cls(gen_head_config.params)\n",
" self.gen_embed = paddle.nn.Embedding(\n",
" num_embeddings=gen_vision_config.params[\"image_token_size\"],\n",
" embedding_dim=gen_vision_config.params[\"n_embed\"],\n",
" )\n",
" language_config = config.language_config\n",
" self.language_model = LlamaForCausalLM(language_config)\n",
"\n",
" def prepare_inputs_embeds(\n",
" self,\n",
" input_ids: paddle.Tensor,\n",
" pixel_values: paddle.Tensor,\n",
" images_seq_mask: paddle.Tensor,\n",
" images_emb_mask: paddle.Tensor,\n",
" **kwargs\n",
" ):\n",
" \"\"\"\n",
"\n",
" Args:\n",
" input_ids (paddle.Tensor): [b, T]\n",
" pixel_values (paddle.Tensor): [b, n_images, 3, h, w]\n",
" images_seq_mask (paddle.Tensor): [b, T]\n",
" images_emb_mask (paddle.Tensor): [b, n_images, n_image_tokens]\n",
"\n",
" assert paddle.sum(images_seq_mask) == paddle.sum(images_emb_mask)\n",
"\n",
" Returns:\n",
" input_embeds (paddle.Tensor): [b, T, D]\n",
" \"\"\"\n",
" bs, n = tuple(pixel_values.shape)[0:2]\n",
" images = rearrange(pixel_values, \"b n c h w -> (b n) c h w\")\n",
" images_embeds = self.aligner(self.vision_model(images))\n",
" images_embeds = rearrange(images_embeds, \"(b n) t d -> b (n t) d\", b=bs, n=n)\n",
" images_emb_mask = rearrange(images_emb_mask, \"b n t -> b (n t)\")\n",
" input_ids[input_ids < 0] = 0\n",
" inputs_embeds = self.language_model.get_input_embeddings()(input_ids)\n",
" inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]\n",
"\n",
" return inputs_embeds\n",
"\n",
" def prepare_gen_img_embeds(self, image_ids: paddle.Tensor):\n",
" return self.gen_aligner(self.gen_embed(image_ids))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3.2 Janus 多模态生成代码介绍"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- 方法: generate\n",
"- 参数:\n",
" - mmgpt:JanusMultiModalityCausalLM的对象,负责生成图像和文本。\n",
" - vl_chat_processor:一个处理器对象,用于处理视觉-语言(VL)聊天数据,包括分词和图像编码等。\n",
" - prompt:一个字符串,代表输入给模型的文本提示。\n",
" - temperature:一个浮点数,用于调整生成结果的随机性(或称为“温度”)。较低的值会使生成结果更加确定,而较高的值会增加多样性。\n",
" - parallel_size:一个整数,表示并行生成图像的数量。\n",
" - cfg_weight:一个浮点数,用于在生成过程中调整条件和无条件生成的概率分布(logits)之间的权重。\n",
" - image_token_num_per_image:一个整数,表示每张图像生成的token数量。\n",
" - img_size:一个整数,表示生成图像的尺寸(假设图像是正方形)。\n",
" - patch_size:一个整数,表示图像被分割成的小块(patch)的尺寸\n",
"- 步骤:\n",
" - 文本处理:使用vl_chat_processor的分词器将文本提示编码为输入ID,然后转换为Paddle张量。\n",
" - 初始化token:创建一个用于存储输入token和生成图像token的张量。对于并行生成的每个样本,都复制输入token,并在奇数索引的样本中插入填充token。\n",
" - 输入Embedding:将token转换为模型可以理解的Embedding形式。\n",
" - 生成图像token:通过一个循环,逐步生成图像的每个令牌。在每个步骤中:\n",
" - 更新position id 以反映当前token生成的位置序号。\n",
" - 使用模型的语言模型部分生成下一个token的概率分布。\n",
" - 根据条件和无条件生成的 logits 以及温度调整概率分布。\n",
" - 使用paddle.multinomial根据调整后的概率分布采样下一个令牌。\n",
" - 使用生成的token生成图像Embedding,并更新输入Embedding以用于下一次迭代。\n",
" - 解码图像:将生成的图像token解码为图像数据。\n",
" - 后处理和保存:将解码后的图像数据标准化为0-255之间的整数,并保存为JPEG文件。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def generate(\n",
" mmgpt,\n",
" vl_chat_processor,\n",
" prompt: str,\n",
" temperature: float = 1,\n",
" parallel_size: int = 2,\n",
" cfg_weight: float = 5,\n",
" image_token_num_per_image: int = 576,\n",
" img_size: int = 384,\n",
" patch_size: int = 16,\n",
"):\n",
" input_ids = vl_chat_processor.tokenizer.encode(prompt)\n",
" input_ids = paddle.to_tensor(data=input_ids.input_ids, dtype=\"int64\")\n",
" tokens = paddle.zeros(shape=(parallel_size * 2, len(input_ids)), dtype=\"int32\")\n",
" for i in range(parallel_size * 2):\n",
" tokens[i, :] = input_ids\n",
" if i % 2 != 0:\n",
" tokens[i, 1:-1] = vl_chat_processor.pad_id\n",
" inputs_embeds = mmgpt.language_model.get_input_embeddings()(tokens) # [4, 50, 2048]\n",
" generated_tokens = paddle.zeros(shape=(parallel_size, image_token_num_per_image), dtype=\"int32\")\n",
" batch_size, seq_length = inputs_embeds.shape[:2]\n",
" for i in tqdm(range(image_token_num_per_image)):\n",
" batch_size, seq_length = inputs_embeds.shape[:2]\n",
"\n",
" past_key_values_length = outputs.past_key_values[0][0].shape[1] if i != 0 else 0\n",
" position_ids = paddle.arange(past_key_values_length, seq_length + past_key_values_length).expand(\n",
" (batch_size, seq_length)\n",
" )\n",
"\n",
" outputs = mmgpt.language_model.llama(\n",
" position_ids=position_ids,\n",
" inputs_embeds=inputs_embeds, # [4, 1, 2048]\n",
" use_cache=True,\n",
" past_key_values=outputs.past_key_values if i != 0 else None,\n",
" return_dict=True,\n",
" )\n",
"\n",
" hidden_states = outputs.last_hidden_state\n",
" logits = mmgpt.gen_head(hidden_states[:, -1, :])\n",
" logit_cond = logits[0::2, :]\n",
" logit_uncond = logits[1::2, :]\n",
"\n",
" logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)\n",
" probs = paddle.nn.functional.softmax(x=logits / temperature, axis=-1)\n",
" next_token = paddle.multinomial(x=probs, num_samples=1)\n",
"\n",
" generated_tokens[:, i] = next_token.squeeze(axis=-1)\n",
" next_token = paddle.concat(x=[next_token.unsqueeze(axis=1), next_token.unsqueeze(axis=1)], axis=1).reshape(\n",
" [-1]\n",
" )\n",
" img_embeds = mmgpt.prepare_gen_img_embeds(next_token)\n",
" inputs_embeds = img_embeds.unsqueeze(axis=1)\n",
"\n",
" dec = mmgpt.gen_vision_model.decode_code(\n",
" generated_tokens.to(dtype=\"int32\"), shape=[parallel_size, 8, img_size // patch_size, img_size // patch_size]\n",
" )\n",
" dec = dec.to(\"float32\").cpu().numpy().transpose(0, 2, 3, 1)\n",
" dec = np.clip((dec + 1) / 2 * 255, 0, 255)\n",
" visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)\n",
" visual_img[:, :, :] = dec\n",
" os.makedirs(\"janus_generated_samples\", exist_ok=True)\n",
" for i in range(parallel_size):\n",
" save_path = os.path.join(\"janus_generated_samples\", \"img_{}.jpg\".format(i))\n",
" PIL.Image.fromarray(visual_img[i]).save(save_path)"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit 9133822

Please sign in to comment.