{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Introduction to Janus"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1 Introduction\n",
"\n",
"Janus is a simple, unified, and extensible model for multimodal understanding and generation. It decouples the visual encoding used for multimodal understanding from the visual encoding used for generation, which relieves the potential conflict between the two tasks, and the design can be extended in the future to take in additional input modalities. Janus-Pro builds on this base by refining the training strategy (more training steps, adjusted data ratios, etc.), adding data (including synthetic data), and scaling the model up to 7B parameters, which improves both multimodal understanding and text-to-image instruction following.\n",
"\n",
"Janus therefore contains two independent visual encoding paths, one for multimodal understanding and one for visual generation, which brings two benefits: 1) it relieves the conflict that stems from the different feature granularities required by understanding and generation; 2) it is flexible and extensible: once decoupled, each task can adopt the encoding technique that is state of the art for its own domain, and in the future inputs such as point clouds, EEG signals, or audio could be encoded separately and processed by the same unified Transformer.\n",
"<div style=\"display: flex; justify-content: center; align-items: center;\">\n",
"  <img src=\"https://ai-studio-static-online.cdn.bcebos.com/ea0703505b3b40ad923981dbddda20973c81da7a36194e3abc75ad1d9b870ab4\" alt=\"Janus architecture\" width=\"50%\" >\n",
"</div>\n",
"<center>Figure 1: The Janus architecture</center>"
]
},
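{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the decoupling concrete, the next cell is a minimal, self-contained sketch (not the actual Janus implementation): two independent stand-in encoders produce features for understanding and for generation, each is projected by its own adapter into a shared embedding space, and the concatenated sequence is what a single autoregressive backbone would consume. All module names and sizes below are illustrative assumptions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import paddle\n",
"import paddle.nn as nn\n",
"\n",
"# Toy sketch of Janus-style decoupled visual encoding (illustrative only).\n",
"# Real Janus uses SigLIP features + an understanding adapter for comprehension,\n",
"# and VQ codebook embeddings + a generation adapter for image synthesis.\n",
"hidden = 64  # assumed LLM hidden size for this toy example\n",
"\n",
"und_encoder = nn.Linear(32, 48)       # stand-in for the SigLIP semantic encoder\n",
"und_adapter = nn.Linear(48, hidden)   # maps understanding features into the LLM space\n",
"\n",
"gen_codebook = nn.Embedding(512, 16)  # stand-in for the VQ tokenizer codebook\n",
"gen_adapter = nn.Linear(16, hidden)   # maps codebook embeddings into the LLM space\n",
"\n",
"text_embed = nn.Embedding(1000, hidden)  # stand-in for the LLM token embedding table\n",
"\n",
"# Toy inputs: 5 text token ids, a 3-patch image feature map for understanding,\n",
"# and 4 discrete image ids for generation.\n",
"text_ids = paddle.randint(0, 1000, [1, 5])\n",
"image_patches = paddle.randn([1, 3, 32])\n",
"image_ids = paddle.randint(0, 512, [1, 4])\n",
"\n",
"seq = paddle.concat(\n",
"    [\n",
"        text_embed(text_ids),                     # text branch\n",
"        und_adapter(und_encoder(image_patches)),  # understanding branch\n",
"        gen_adapter(gen_codebook(image_ids)),     # generation branch\n",
"    ],\n",
"    axis=1,\n",
")\n",
"print(seq.shape)  # [1, 12, 64]: one sequence for one shared autoregressive backbone"
]
},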
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2 Method\n",
"\n",
"### 2.1 Model Architecture\n",
"The Janus architecture is shown in Figure 1. For pure-text understanding, multimodal understanding, and visual generation, independent encoding methods convert the raw inputs into features, which are then processed by a unified autoregressive Transformer. Specifically:\n",
"- Text understanding: the built-in tokenizer of the large language model (LLM) converts the text into discrete ids, and the embedding corresponding to each id is looked up.\n",
"- Multimodal understanding: a SigLIP encoder extracts high-dimensional semantic features from the image. The features are flattened from a 2D grid into a 1D sequence, and an understanding adaptor maps them into the LLM input space.\n",
"- Visual generation: a VQ tokenizer converts the image into discrete ids. After the id sequence is flattened to 1D, a generation adaptor maps the codebook embedding of each id into the LLM input space.\n",
"\n",
"These feature sequences are then concatenated into a single multimodal feature sequence and fed into the LLM. For pure-text and multimodal understanding, the LLM's built-in prediction head produces the text predictions; for visual generation, a randomly initialized prediction head produces the image predictions. The whole model follows an autoregressive framework and needs no specially designed attention mask.\n",
"\n",
"### 2.2 Janus Training\n",
"Janus is trained in three stages:\n",
"- Stage 1: train the adaptors and the image head to build connections between linguistic and visual elements in the embedding space, so that the LLM can understand entities in images and acquires a preliminary visual generation ability.\n",
"  - For multimodal understanding, 1.25 million image-text caption pairs from ShareGPT4V are used, in the format <image><text>.\n",
"  - For visual generation, 1.2 million samples from ImageNet-1k are used, in the format <category name><image>.\n",
"\n",
"- Stage 2: unified pretraining on a multimodal corpus to learn multimodal understanding and generation jointly.\n",
"  - This stage uses pure-text data, multimodal understanding data, and visual generation data.\n",
"  - Visual generation is first trained on the simple ImageNet-1k data, then switched to general text-to-image data to improve open-domain generation.\n",
"  - Pure-text data: the DeepSeek-LLM pretraining corpus.\n",
"  - Interleaved image-text data: the WikiHow and WIT datasets.\n",
"  - Image caption data: images from multiple sources, with part of them re-captioned by open-source multimodal models; the data is formatted as question-answer pairs, e.g. $\\texttt{<image>}$ Describe the image in detail. $\\texttt{<caption>}$.\n",
"  - Table and chart data: the corresponding table and chart data from DeepSeek-VL, in the format <question><answer>.\n",
"  - Visual generation data: image-caption pairs from multiple datasets plus 2 million in-house samples; during training, with 25% probability only the first sentence of the caption is used; ImageNet samples appear only in the first 120K training steps, while images from the other datasets appear in the following 60K steps.\n",
"- Stage 3: supervised fine-tuning with instruction-tuning data to strengthen instruction following and dialogue. All parameters except the generation encoder are fine-tuned. The system and user prompts are masked while the answers are supervised. To keep Janus proficient at both multimodal understanding and generation, the model is not fine-tuned separately per task; instead, a mixture of pure-text dialogue data, multimodal understanding data, and visual generation data is used to ensure versatility across scenarios.\n",
"  - Text understanding: data from specific sources.\n",
"  - Multimodal understanding: instruction-tuning data from multiple sources.\n",
"  - Visual generation: a subset of the image-text pairs from part of the Stage 2 datasets plus 4 million in-house samples.\n",
"  - Data format: User: $\\texttt{<Input Message>}$ \\n Assistant: $\\texttt{<Response>}$.\n",
"\n",
"<div style=\"display: flex; justify-content: center; align-items: center;\">\n",
"  <img src=\"https://github.com/user-attachments/assets/0035318f-3348-4e5d-9256-a9a3410fa625\" alt=\"Janus three-stage training\" width=\"50%\" >\n",
"</div>\n",
"<center>Figure 2: The three training stages of Janus</center>\n",
"\n",
"### 2.3 Janus Inference\n",
"At inference time, Janus performs next-token prediction. For pure-text and multimodal understanding, tokens are sampled sequentially from the predicted distribution in the standard way. For image generation, classifier-free guidance (CFG) is used: during training, the text condition in text-to-image data is replaced with padding tokens with 10% probability, which equips the model with an unconditional visual generation capability. At inference, the logit for the next token is computed as $l_g = l_u + s(l_c - l_u)$, where $l_c$ is the conditional logit, $l_u$ is the unconditional logit, and $s$ is the CFG scale, 5 by default. A minimal numeric sketch of this CFG step is shown in the code cell after this section.\n",
"\n",
"### 2.4 Janus-Pro\n",
"- Training strategy\n",
"  - Stage 1: increase the number of training steps for fuller training on ImageNet;\n",
"  - Stage 2: drop ImageNet and train directly on regular text-to-image data;\n",
"  - Stage 3: change the fine-tuning data ratio of multimodal data, pure-text data, and text-to-image data from 7:3:10 to 5:1:4;\n",
"- Data scale\n",
"  - Multimodal understanding\n",
"    - Stage 2: about 90 million additional samples, including image caption data such as YFCC and table/chart/document understanding data such as Docmatix;\n",
"    - Stage 3: additional datasets from DeepSeek-VL2, e.g. meme understanding;\n",
"  - Visual generation: real-world data can be of low quality, which makes text-to-image generation unstable and the outputs aesthetically poor; Janus-Pro therefore adds 72 million synthetic aesthetic samples and uses a 1:1 ratio of real to synthetic data in the unified pretraining stage (Stage 2);\n",
"- Model scale\n",
"  - Scale the model up to 7B parameters."
]
},
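{
"cell_type": "markdown",
"metadata": {},
"source": [
"The next cell is a minimal numeric sketch of the classifier-free guidance step from Section 2.3: conditional and unconditional logits are combined as $l_g = l_u + s(l_c - l_u)$ and the next image token is sampled from the temperature-scaled softmax. The logit tensors here are random stand-ins, not real model outputs, and the vocabulary size is an arbitrary assumption."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import paddle\n",
"import paddle.nn.functional as F\n",
"\n",
"# Toy classifier-free guidance (CFG) step; the logits are random stand-ins.\n",
"vocab = 16          # assumed image-token codebook size for this toy example\n",
"cfg_scale = 5.0     # s in l_g = l_u + s * (l_c - l_u); Janus uses s = 5 by default\n",
"temperature = 1.0\n",
"\n",
"logit_cond = paddle.randn([1, vocab])    # logits computed with the text condition\n",
"logit_uncond = paddle.randn([1, vocab])  # logits with the condition replaced by padding\n",
"\n",
"logits = logit_uncond + cfg_scale * (logit_cond - logit_uncond)\n",
"probs = F.softmax(logits / temperature, axis=-1)\n",
"next_token = paddle.multinomial(probs, num_samples=1)\n",
"print(next_token.shape)  # [1, 1]: one sampled image-token id"
]
},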
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3 Code Walkthrough\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3.1 The Janus Model Definition"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- Class: JanusMultiModalityCausalLM\n",
"- Purpose: implements a multimodal causal language model that can process both image and text data.\n",
"- Main steps:\n",
"  - Initialization:\n",
"    - Extract the configuration of each component from the config object.\n",
"    - Instantiate the vision model, understanding aligner, generation vision model, generation aligner, generation head, and language model via model_name_to_cls and the corresponding config parameters.\n",
"    - Create an Embedding layer that maps image token ids to embedding vectors.\n",
"  - Preparing the input embeddings (prepare_inputs_embeds):\n",
"    - Rearrange the image tensor pixel_values to match the input layout expected by the vision model.\n",
"    - Run the vision model on the images and map the resulting features into the LLM space with the aligner to obtain the image embeddings.\n",
"    - Rearrange the image embeddings and masks so they line up with the text sequence.\n",
"    - Convert the text input_ids into embeddings that the language model can process.\n",
"    - Scatter the image embeddings into the text embeddings according to the masks, producing the final input_embeds.\n",
"  - Preparing generation image embeddings (prepare_gen_img_embeds):\n",
"    - Map the image token ids image_ids to embedding vectors with the embedding layer.\n",
"    - Pass these embeddings through the generation aligner to obtain the final image embeddings.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Note: paddle, an einops-style rearrange, and the project-level Janus components\n",
"# (JanusMultiModalityPreTrainedModel, MultiModalityConfig, model_name_to_cls,\n",
"# LlamaForCausalLM) are assumed to be imported from the surrounding codebase.\n",
"\n",
"class JanusMultiModalityCausalLM(JanusMultiModalityPreTrainedModel):\n",
"    config_class = MultiModalityConfig\n",
"\n",
"    def __init__(self, config: MultiModalityConfig):\n",
"        super().__init__(config)\n",
"        vision_config = config.vision_config\n",
"        vision_cls = model_name_to_cls(vision_config.cls)\n",
"        self.vision_model = vision_cls(**vision_config.params)\n",
"        aligner_config = config.aligner_config\n",
"        aligner_cls = model_name_to_cls(aligner_config.cls)\n",
"        self.aligner = aligner_cls(aligner_config.params)\n",
"        gen_vision_config = config.gen_vision_config\n",
"        gen_vision_cls = model_name_to_cls(gen_vision_config.cls)\n",
"        self.gen_vision_model = gen_vision_cls()\n",
"        gen_aligner_config = config.gen_aligner_config\n",
"        gen_aligner_cls = model_name_to_cls(gen_aligner_config.cls)\n",
"        self.gen_aligner = gen_aligner_cls(gen_aligner_config.params)\n",
"        gen_head_config = config.gen_head_config\n",
"        gen_head_cls = model_name_to_cls(gen_head_config.cls)\n",
"        self.gen_head = gen_head_cls(gen_head_config.params)\n",
"        self.gen_embed = paddle.nn.Embedding(\n",
"            num_embeddings=gen_vision_config.params[\"image_token_size\"],\n",
"            embedding_dim=gen_vision_config.params[\"n_embed\"],\n",
"        )\n",
"        language_config = config.language_config\n",
"        self.language_model = LlamaForCausalLM(language_config)\n",
"\n",
"    def prepare_inputs_embeds(\n",
"        self,\n",
"        input_ids: paddle.Tensor,\n",
"        pixel_values: paddle.Tensor,\n",
"        images_seq_mask: paddle.Tensor,\n",
"        images_emb_mask: paddle.Tensor,\n",
"        **kwargs\n",
"    ):\n",
"        \"\"\"\n",
"\n",
"        Args:\n",
"            input_ids (paddle.Tensor): [b, T]\n",
"            pixel_values (paddle.Tensor): [b, n_images, 3, h, w]\n",
"            images_seq_mask (paddle.Tensor): [b, T]\n",
"            images_emb_mask (paddle.Tensor): [b, n_images, n_image_tokens]\n",
"\n",
"        assert paddle.sum(images_seq_mask) == paddle.sum(images_emb_mask)\n",
"\n",
"        Returns:\n",
"            input_embeds (paddle.Tensor): [b, T, D]\n",
"        \"\"\"\n",
"        bs, n = tuple(pixel_values.shape)[0:2]\n",
"        images = rearrange(pixel_values, \"b n c h w -> (b n) c h w\")\n",
"        images_embeds = self.aligner(self.vision_model(images))\n",
"        images_embeds = rearrange(images_embeds, \"(b n) t d -> b (n t) d\", b=bs, n=n)\n",
"        images_emb_mask = rearrange(images_emb_mask, \"b n t -> b (n t)\")\n",
"        input_ids[input_ids < 0] = 0\n",
"        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)\n",
"        inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]\n",
"\n",
"        return inputs_embeds\n",
"\n",
"    def prepare_gen_img_embeds(self, image_ids: paddle.Tensor):\n",
"        return self.gen_aligner(self.gen_embed(image_ids))"
]
},
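{
"cell_type": "markdown",
"metadata": {},
"source": [
"The key step in prepare_inputs_embeds is the boolean-mask scatter inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask], which overwrites the placeholder image-token positions in the text embedding sequence with the aligned image features. The next cell is a small stand-alone sketch of that operation with toy tensors; all shapes and values are made up for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import paddle\n",
"\n",
"# Toy illustration of the mask scatter used in prepare_inputs_embeds.\n",
"# 6 sequence positions, embedding dim 4; positions 1 and 2 are image placeholders.\n",
"inputs_embeds = paddle.zeros([1, 6, 4])   # stand-in for the text embeddings\n",
"images_embeds = paddle.ones([1, 2, 4])    # stand-in for the aligned image features\n",
"images_seq_mask = paddle.to_tensor([[False, True, True, False, False, False]])\n",
"images_emb_mask = paddle.to_tensor([[True, True]])\n",
"\n",
"# Overwrite the placeholder positions with the image embeddings.\n",
"inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]\n",
"print(inputs_embeds[0, :, 0])  # positions 1 and 2 are now 1.0, the rest remain 0.0"
]
},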
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3.2 The Janus Multimodal Generation Code"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- Function: generate\n",
"- Parameters:\n",
"  - mmgpt: a JanusMultiModalityCausalLM instance that performs the actual image and text generation.\n",
"  - vl_chat_processor: the processor for vision-language (VL) chat data, providing the tokenizer and image-related helpers.\n",
"  - prompt: the text prompt given to the model.\n",
"  - temperature: a float controlling the randomness of sampling; lower values make the output more deterministic, higher values increase diversity.\n",
"  - parallel_size: the number of images to generate in parallel.\n",
"  - cfg_weight: a float weighting the conditional against the unconditional logits during classifier-free guidance.\n",
"  - image_token_num_per_image: the number of image tokens generated per image.\n",
"  - img_size: the side length of the generated (square) image.\n",
"  - patch_size: the side length of the patches the image is split into.\n",
"- Steps:\n",
"  - Text processing: encode the prompt with vl_chat_processor's tokenizer and convert the ids into a Paddle tensor.\n",
"  - Token initialization: build a tensor holding the input tokens for all parallel samples; the prompt is copied into every row, and in the odd-indexed rows every token except the first and last is replaced with the padding token, so these rows form the unconditional branch for CFG.\n",
"  - Input embeddings: convert the tokens into embeddings the model can process.\n",
"  - Image token generation: generate the image tokens one by one in a loop; at each step:\n",
"    - Update the position ids to reflect the position of the token currently being generated.\n",
"    - Run the language model to obtain the logits for the next token.\n",
"    - Combine the conditional and unconditional logits with cfg_weight and apply the temperature.\n",
"    - Sample the next token with paddle.multinomial from the adjusted distribution.\n",
"    - Map the sampled token to an image embedding with prepare_gen_img_embeds and use it as the input embedding for the next iteration.\n",
"  - Image decoding: decode the generated image tokens into image data with the generation vision model.\n",
"  - Post-processing and saving: rescale the decoded images to integers in the range 0-255 and save them as JPEG files."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Note: paddle, numpy (as np), os, PIL, and tqdm are assumed to be imported in the\n",
"# surrounding script.\n",
"\n",
"def generate(\n",
"    mmgpt,\n",
"    vl_chat_processor,\n",
"    prompt: str,\n",
"    temperature: float = 1,\n",
"    parallel_size: int = 2,\n",
"    cfg_weight: float = 5,\n",
"    image_token_num_per_image: int = 576,\n",
"    img_size: int = 384,\n",
"    patch_size: int = 16,\n",
"):\n",
"    input_ids = vl_chat_processor.tokenizer.encode(prompt)\n",
"    input_ids = paddle.to_tensor(data=input_ids.input_ids, dtype=\"int64\")\n",
"    tokens = paddle.zeros(shape=(parallel_size * 2, len(input_ids)), dtype=\"int32\")\n",
"    # Even rows keep the prompt (conditional); odd rows are padded (unconditional) for CFG.\n",
"    for i in range(parallel_size * 2):\n",
"        tokens[i, :] = input_ids\n",
"        if i % 2 != 0:\n",
"            tokens[i, 1:-1] = vl_chat_processor.pad_id\n",
"    inputs_embeds = mmgpt.language_model.get_input_embeddings()(tokens)  # [4, 50, 2048]\n",
"    generated_tokens = paddle.zeros(shape=(parallel_size, image_token_num_per_image), dtype=\"int32\")\n",
"    batch_size, seq_length = inputs_embeds.shape[:2]\n",
"    for i in tqdm(range(image_token_num_per_image)):\n",
"        batch_size, seq_length = inputs_embeds.shape[:2]\n",
"\n",
"        past_key_values_length = outputs.past_key_values[0][0].shape[1] if i != 0 else 0\n",
"        position_ids = paddle.arange(past_key_values_length, seq_length + past_key_values_length).expand(\n",
"            (batch_size, seq_length)\n",
"        )\n",
"\n",
"        outputs = mmgpt.language_model.llama(\n",
"            position_ids=position_ids,\n",
"            inputs_embeds=inputs_embeds,  # [4, 1, 2048]\n",
"            use_cache=True,\n",
"            past_key_values=outputs.past_key_values if i != 0 else None,\n",
"            return_dict=True,\n",
"        )\n",
"\n",
"        hidden_states = outputs.last_hidden_state\n",
"        logits = mmgpt.gen_head(hidden_states[:, -1, :])\n",
"        logit_cond = logits[0::2, :]\n",
"        logit_uncond = logits[1::2, :]\n",
"\n",
"        # Classifier-free guidance: l_g = l_u + s * (l_c - l_u)\n",
"        logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)\n",
"        probs = paddle.nn.functional.softmax(x=logits / temperature, axis=-1)\n",
"        next_token = paddle.multinomial(x=probs, num_samples=1)\n",
"\n",
"        generated_tokens[:, i] = next_token.squeeze(axis=-1)\n",
"        next_token = paddle.concat(x=[next_token.unsqueeze(axis=1), next_token.unsqueeze(axis=1)], axis=1).reshape(\n",
"            [-1]\n",
"        )\n",
"        img_embeds = mmgpt.prepare_gen_img_embeds(next_token)\n",
"        inputs_embeds = img_embeds.unsqueeze(axis=1)\n",
"\n",
"    dec = mmgpt.gen_vision_model.decode_code(\n",
"        generated_tokens.to(dtype=\"int32\"), shape=[parallel_size, 8, img_size // patch_size, img_size // patch_size]\n",
"    )\n",
"    dec = dec.to(\"float32\").cpu().numpy().transpose(0, 2, 3, 1)\n",
"    dec = np.clip((dec + 1) / 2 * 255, 0, 255)\n",
"    visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)\n",
"    visual_img[:, :, :] = dec\n",
"    os.makedirs(\"janus_generated_samples\", exist_ok=True)\n",
"    for i in range(parallel_size):\n",
"        save_path = os.path.join(\"janus_generated_samples\", \"img_{}.jpg\".format(i))\n",
"        PIL.Image.fromarray(visual_img[i]).save(save_path)"
]
},
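{
"cell_type": "markdown",
"metadata": {},
"source": [
"A note on the defaults above: cfg_weight=5 matches the CFG scale $s=5$ from Section 2.3, and temperature=1 samples directly from the guided distribution; raising cfg_weight typically makes the images follow the text condition more closely at the cost of diversity. With img_size=384 and patch_size=16, the latent grid is 24×24, which is why image_token_num_per_image defaults to 576."
]
}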
],
"metadata": { | ||
"language_info": { | ||
"name": "python" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |