使用 Python 从头开始构建 AI 文本生成视频模型 OpenAI 的 Sora、Stability AI 的 Stable Video Diffusion 以及许多其他已经问世或未来将出现的文本转视频模型,都是继大语言模型之后,2024 年最流行的 AI 趋势之一(LLMs)。在本博客中,我们将从头开始构建一个小型文本生成视频模型。我们将输入一个文本提示,我们训练的模型将根据该提示生成一个视频。该博客将涵盖从理解理论概念到编码整个架构并生成最终结果的所有内容。
由于我没有高级的 GPU,所以我编写了小型架构。以下是在不同处理器上训练模型所需时间的比较:
Training Videos
Epochs
CPU
GPU A10
GPU T4
10K
30
more than 3 hr
1 hr
1 hr 42m
30K
30
more than 6 hr
1 hr 30
2 hr 30
100K
30
-
3-4 hr
5-6 hr
在 CPU 上运行显然需要更长的时间来训练模型。如果您需要快速测试代码中的更改并查看结果,CPU 并不是最佳选择。我建议使用 Colab 或 Kaggle 的 T4 GPU 来实现更高效、更快的训练。
为了避免复制和粘贴此博客中的代码,以下是包含笔记本文件以及所有代码和信息的 GitHub 存储库:AI-text-to-video-model-from-scratch
以下是博客链接,指导您如何从头开始创建Stable Diffusion:从头开始编码Stable Diffusion
目录
我们正在建设什么
先决条件
了解 GAN 架构
搭建舞台
对训练数据进行编码
预处理我们的训练数据
实现文本嵌入层
实现生成器层
实施鉴别器层
编码训练参数
训练循环编码
保存训练后的模型
生成人工智能视频
少了什么东西?
我们正在建设什么
我们将遵循与传统机器学习或深度学习模型类似的方法,在数据集上进行训练,然后在未见过的数据上进行测试。在文本到视频的背景下,假设我们有一个包含 10 万个狗捡球和猫追老鼠视频的训练数据集。我们将训练我们的模型来生成猫捡球或狗追老鼠的视频。
尽管此类训练数据集很容易在互联网上获得,但所需的计算能力非常高。因此,我们将使用由 Python 代码生成的移动对象的视频数据集。
我们将使用 GAN(生成对抗网络)架构来创建我们的模型,而不是 OpenAI Sora 使用的扩散模型。我尝试使用扩散模型,但由于内存要求而崩溃,这超出了我的能力。另一方面,GAN 的训练和测试更容易、更快捷。
先决条件
我们将使用 OOP(面向对象编程),因此您必须对它和神经网络有基本的了解。 GAN(生成对抗网络)的知识不是强制性的,因为我们将在这里介绍它们的架构。
了解 GAN 架构
理解 GAN 很重要,因为我们的大部分架构都依赖于它。让我们来探讨一下它是什么、它的组件等等。
什么是GAN?
生成对抗网络 (GAN) 是一种深度学习模型,其中两个神经网络相互竞争:一个从给定的数据集中创建新数据(例如图像或音乐),另一个尝试判断数据是真实的还是虚假的。这个过程一直持续到生成的数据与原始数据无法区分为止。
实际应用
生成图像:GAN 根据文本提示创建逼真的图像或修改现有图像,例如增强分辨率或为黑白照片添加颜色。
数据增强:它们生成合成数据来训练其他机器学习模型,例如为欺诈检测系统创建欺诈交易数据。
补全缺失的信息:GAN 可以填充缺失的数据,例如从地形图生成地下图像以用于能源应用。
生成 3D 模型:它们将 2D 图像转换为 3D 模型,这在医疗保健等领域非常有用,可以为手术规划创建逼真的器官图像。
GAN 是如何工作的?
它由两个深度神经网络组成:生成器和鉴别器。这些网络在对抗性设置中一起训练,其中一个网络生成新数据,另一个网络评估数据是真实的还是虚假的。
以下是 GAN 工作原理的简单概述:
训练集分析:生成器分析训练集以识别数据属性,而判别器独立分析相同的数据以学习其属性。
数据修改:生成器向数据的某些属性添加噪声(随机变化)。
数据传递:修改后的数据然后被传递到鉴别器。
概率计算:判别器计算生成的数据来自原始数据集的概率。
反馈循环:鉴别器向生成器提供反馈,指导生成器减少下一个周期的随机噪声。
对抗性训练:生成器试图最大化判别器的错误,而判别器则试图最小化自己的错误。通过多次训练迭代,两个网络都得到改进和发展。
平衡状态:训练继续,直到判别器无法再区分真实数据和合成数据,表明生成器已成功学会生成真实数据。至此,训练过程就完成了。
GAN 训练示例
让我们以图像到图像转换的示例来解释 GAN 模型,重点是修改人脸。
输入图像:输入是真实的人脸图像。
属性修改:生成器修改脸部的属性,例如为眼睛添加墨镜。
生成的图像:生成器创建一组添加了太阳镜的图像。
鉴别器的任务:鉴别器接收真实图像(戴太阳镜的人)和生成图像(添加太阳镜的人脸)的混合。
评估:鉴别器试图区分真实图像和生成图像。
反馈循环:如果鉴别器正确识别出假图像,则生成器会调整其参数以产生更令人信服的图像。如果生成器成功欺骗了鉴别器,鉴别器就会更新其参数以改进其检测。
通过这个对抗过程,两个网络都在不断改进。生成器在创建真实图像方面变得更好,鉴别器在识别赝品方面也变得更好,直到达到平衡,鉴别器不再能够区分真实图像和生成图像之间的差异。至此,GAN 已成功学会产生现实的修改。
搭建舞台
我们将使用一系列 Python 库,让我们导入它们。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 import osimport randomimport numpy as npimport cv2from PIL import Image, ImageDraw, ImageFontimport torchfrom torch.utils.data import Datasetimport torchvision.transforms as transformsimport torch.nn as nnimport torch.optim as optimfrom torch.nn.utils.rnn import pad_sequencefrom torchvision.utils import save_imageimport matplotlib.pyplot as pltfrom IPython.display import clear_output, display, HTMLimport base64
现在我们已经导入了所有库,下一步是定义我们将用来训练 GAN 架构的训练数据。
对训练数据进行编码
我们需要至少 10,000 个视频作为训练数据。为什么?嗯,因为我用较小的数字进行了测试,结果很差,几乎没有什么可看的。下一个大问题是:这些视频是关于什么的?我们的训练视频数据集由一个以不同运动向不同方向移动的圆圈组成。那么,让我们对其进行编码并生成 10,000 个视频来看看它是什么样子。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 os.makedirs('training_dataset' , exist_ok=True ) num_videos = 10000 frames_per_video = 10 img_size = (64 , 64 ) shape_size = 10
设置一些基本参数后,接下来我们需要定义训练数据集的文本提示,根据该文本提示将生成训练视频。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 prompts_and_movements = [ ("circle moving down" , "circle" , "down" ), ("circle moving left" , "circle" , "left" ), ("circle moving right" , "circle" , "right" ), ("circle moving diagonally up-right" , "circle" , "diagonal_up_right" ), ("circle moving diagonally down-left" , "circle" , "diagonal_down_left" ), ("circle moving diagonally up-left" , "circle" , "diagonal_up_left" ), ("circle moving diagonally down-right" , "circle" , "diagonal_down_right" ), ("circle rotating clockwise" , "circle" , "rotate_clockwise" ), ("circle rotating counter-clockwise" , "circle" , "rotate_counter_clockwise" ), ("circle shrinking" , "circle" , "shrink" ), ("circle expanding" , "circle" , "expand" ), ("circle bouncing vertically" , "circle" , "bounce_vertical" ), ("circle bouncing horizontally" , "circle" , "bounce_horizontal" ), ("circle zigzagging vertically" , "circle" , "zigzag_vertical" ), ("circle zigzagging horizontally" , "circle" , "zigzag_horizontal" ), ("circle moving up-left" , "circle" , "up_left" ), ("circle moving down-right" , "circle" , "down_right" ), ("circle moving down-left" , "circle" , "down_left" ), ]
我们使用这些提示定义了圆圈的几种运动。现在,我们需要编写一些数学方程来根据提示移动该圆圈。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 def create_image_with_moving_shape (size, frame_num, shape, direction ): img = Image.new('RGB' , size, color=(255 , 255 , 255 )) draw = ImageDraw.Draw(img) center_x, center_y = size[0 ] // 2 , size[1 ] // 2 position = (center_x, center_y) direction_map = { "down" : (0 , frame_num * 5 % size[1 ]), "left" : (-frame_num * 5 % size[0 ], 0 ), "right" : (frame_num * 5 % size[0 ], 0 ), "diagonal_up_right" : (frame_num * 5 % size[0 ], -frame_num * 5 % size[1 ]), "diagonal_down_left" : (-frame_num * 5 % size[0 ], frame_num * 5 % size[1 ]), "diagonal_up_left" : (-frame_num * 5 % size[0 ], -frame_num * 5 % size[1 ]), "diagonal_down_right" : (frame_num * 5 % size[0 ], frame_num * 5 % size[1 ]), "rotate_clockwise" : img.rotate(frame_num * 10 % 360 , center=(center_x, center_y), fillcolor=(255 , 255 , 255 )), "rotate_counter_clockwise" : img.rotate(-frame_num * 10 % 360 , center=(center_x, center_y), fillcolor=(255 , 255 , 255 )), "bounce_vertical" : (0 , center_y - abs (frame_num * 5 % size[1 ] - center_y)), "bounce_horizontal" : (center_x - abs (frame_num * 5 % size[0 ] - center_x), 0 ), "zigzag_vertical" : (0 , center_y - frame_num * 5 % size[1 ]) if frame_num % 2 == 0 else (0 , center_y + frame_num * 5 % size[1 ]), "zigzag_horizontal" : (center_x - frame_num * 5 % size[0 ], center_y) if frame_num % 2 == 0 else (center_x + frame_num * 5 % size[0 ], center_y), "up_right" : (frame_num * 5 % size[0 ], -frame_num * 5 % size[1 ]), "up_left" : (-frame_num * 5 % size[0 ], -frame_num * 5 % size[1 ]), "down_right" : (frame_num * 5 % size[0 ], frame_num * 5 % size[1 ]), "down_left" : (-frame_num * 5 % size[0 ], frame_num * 5 % size[1 ]) } if direction in direction_map: if isinstance (direction_map[direction], tuple ): position = tuple (np.add(position, direction_map[direction])) else : img = direction_map[direction] return np.array(img)
上面的函数用于根据所选方向移动每一帧的圆。我们只需要在其上运行一个循环,直到达到视频数量的次数即可生成所有视频。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 for i in range (num_videos): prompt, shape, direction = random.choice(prompts_and_movements) video_dir = f'training_dataset/video_{i} ' os.makedirs(video_dir, exist_ok=True ) with open (f'{video_dir} /prompt.txt' , 'w' ) as f: f.write(prompt) for frame_num in range (frames_per_video): img = create_image_with_moving_shape(img_size, frame_num, shape, direction) cv2.imwrite(f'{video_dir} /frame_{frame_num} .png' , img)
运行上述代码后,它将生成我们的整个训练数据集。这是我们的训练数据集文件的结构。
每个训练视频文件夹都包含其帧及其文本提示。让我们看一下训练数据集的样本。
在我们的训练数据集中,我们没有包含圆圈向上移动然后向右移动的运动。我们将使用它作为我们的测试提示来评估我们在未见过的数据上训练的模型。
需要注意的更重要的一点是,我们的训练数据确实包含许多样本,其中物体远离场景或部分出现在相机前面,类似于我们在 OpenAI Sora 演示视频中观察到的情况。
在我们的训练数据中包含此类样本的原因是为了测试当圆圈从最角落进入场景而不破坏其形状时,我们的模型是否能够保持一致性。
现在我们的训练数据已经生成,我们需要将训练视频转换为张量,这是 PyTorch 等深度学习框架中使用的主要数据类型。此外,执行归一化等转换有助于通过将数据扩展到更小的范围来提高训练架构的收敛性和稳定性。
预处理我们的训练数据
我们必须为文本到视频任务编写一个数据集类,它可以从训练数据集目录中读取视频帧及其相应的文本提示,使其可在 PyTorch 中使用。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 class TextToVideoDataset (Dataset ): def __init__ (self, root_dir, transform=None ): self.root_dir = root_dir self.transform = transform self.video_dirs = [os.path.join(root_dir, d) for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))] self.frame_paths = [] self.prompts = [] for video_dir in self.video_dirs: frames = [os.path.join(video_dir, f) for f in os.listdir(video_dir) if f.endswith('.png' )] self.frame_paths.extend(frames) with open (os.path.join(video_dir, 'prompt.txt' ), 'r' ) as f: prompt = f.read().strip() self.prompts.extend([prompt] * len (frames)) def __len__ (self ): return len (self.frame_paths) def __getitem__ (self, idx ): frame_path = self.frame_paths[idx] image = Image.open (frame_path) prompt = self.prompts[idx] if self.transform: image = self.transform(image) return image, prompt
在继续对架构进行编码之前,我们需要规范化我们的训练数据。我们将使用 16 的批量大小并对数据进行打乱以引入更多随机性。
1 2 3 4 5 6 7 8 9 10 11 12 transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5 ,), (0.5 ,)) ]) dataset = TextToVideoDataset(root_dir='training_dataset' , transform=transform) dataloader = torch.utils.data.DataLoader(dataset, batch_size=16 , shuffle=True )
实现文本嵌入层
您可能已经在 Transformer 架构中看到过,其中的起点是将我们的文本输入转换为嵌入,以便在多头注意力中进一步处理,与这里类似,我们必须编写一个文本嵌入层,基于该层,GAN 架构训练将在我们的嵌入数据上进行和图像张量。
1 2 3 4 5 6 7 8 9 10 11 12 13 class TextEmbedding (nn.Module): def __init__ (self, vocab_size, embed_size ): super (TextEmbedding, self).__init__() self.embedding = nn.Embedding(vocab_size, embed_size) def forward (self, x ): return self.embedding(x)
词汇量大小将基于我们的训练数据,稍后我们将计算这些数据。嵌入大小将为 10。如果使用更大的数据集,您还可以使用自己选择的 Hugging Face 上提供的嵌入模型。
实现生成器层
现在我们已经知道生成器在 GAN 中的作用,让我们对这一层进行编码,然后了解其内容。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 class Generator (nn.Module): def __init__ (self, text_embed_size ): super (Generator, self).__init__() self.fc1 = nn.Linear(100 + text_embed_size, 256 * 8 * 8 ) self.deconv1 = nn.ConvTranspose2d(256 , 128 , 4 , 2 , 1 ) self.deconv2 = nn.ConvTranspose2d(128 , 64 , 4 , 2 , 1 ) self.deconv3 = nn.ConvTranspose2d(64 , 3 , 4 , 2 , 1 ) self.relu = nn.ReLU(True ) self.tanh = nn.Tanh() def forward (self, noise, text_embed ): x = torch.cat((noise, text_embed), dim=1 ) x = self.fc1(x).view(-1 , 256 , 8 , 8 ) x = self.relu(self.deconv1(x)) x = self.relu(self.deconv2(x)) x = self.tanh(self.deconv3(x)) return x
这个 Generator 类负责根据随机噪声和文本嵌入的组合创建视频帧。它的目的是根据给定的文本描述生成逼真的视频帧。该网络从全连接层 ( nn.Linear ) 开始,它将噪声向量和文本嵌入组合成单个特征向量。然后,该向量被重塑并通过一系列转置卷积层 ( nn.ConvTranspose2d ),这些层逐渐将特征图上采样到所需的视频帧大小。
这些层使用 ReLU 激活 ( nn.ReLU ) 实现非线性,最后一层使用 Tanh 激活 ( nn.Tanh ) 将输出缩放到范围 [-1, 1]。因此,生成器将抽象的高维输入转换为连贯的视频帧,直观地表示输入文本。
实施鉴别器层
对生成器层进行编码后,我们需要实现另一半,即鉴别器部分。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 class Discriminator (nn.Module): def __init__ (self ): super (Discriminator, self).__init__() self.conv1 = nn.Conv2d(3 , 64 , 4 , 2 , 1 ) self.conv2 = nn.Conv2d(64 , 128 , 4 , 2 , 1 ) self.conv3 = nn.Conv2d(128 , 256 , 4 , 2 , 1 ) self.fc1 = nn.Linear(256 * 8 * 8 , 1 ) self.leaky_relu = nn.LeakyReLU(0.2 , inplace=True ) self.sigmoid = nn.Sigmoid() def forward (self, input ): x = self.leaky_relu(self.conv1(input )) x = self.leaky_relu(self.conv2(x)) x = self.leaky_relu(self.conv3(x)) x = x.view(-1 , 256 * 8 * 8 ) x = self.sigmoid(self.fc1(x)) return x
Discriminator 类充当二元分类器,区分真实视频帧和生成的视频帧。其目的是评估视频帧的真实性,从而指导生成器产生更真实的输出。该网络由卷积层 ( nn.Conv2d ) 组成,这些卷积层从输入视频帧中提取分层特征,并使用 Leaky ReLU 激活 ( nn.LeakyReLU ) 添加非线性,同时允许负梯度较小价值观。然后特征图被展平并通过全连接层 ( nn.Linear ),最终形成 sigmoid 激活 ( nn.Sigmoid ),输出一个概率分数,指示帧是真实的还是假的。
通过训练鉴别器对帧进行准确分类,同时训练生成器以创建更有说服力的视频帧,因为它的目的是欺骗鉴别器。
编码训练参数
我们必须设置训练 GAN 的基本组件,例如损失函数、优化器等。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 device = torch.device("cuda" if torch.cuda.is_available() else "cpu" ) all_prompts = [prompt for prompt, _, _ in prompts_and_movements] vocab = {word: idx for idx, word in enumerate (set (" " .join(all_prompts).split()))} vocab_size = len (vocab) embed_size = 10 def encode_text (prompt ): return torch.tensor([vocab[word] for word in prompt.split()]) text_embedding = TextEmbedding(vocab_size, embed_size).to(device) netG = Generator(embed_size).to(device) netD = Discriminator().to(device) criterion = nn.BCELoss().to(device) optimizerD = optim.Adam(netD.parameters(), lr=0.0002 , betas=(0.5 , 0.999 )) optimizerG = optim.Adam(netG.parameters(), lr=0.0002 , betas=(0.5 , 0.999 ))
这是我们必须将代码转换为在 GPU 上运行(如果可用)的部分。我们已编写代码来查找 vocab_size,并且我们对生成器和判别器使用 ADAM 优化器。如果您愿意,可以选择自己的优化器。在这里,我们将学习率设置为较小的值 0.0002,嵌入大小为 10,这与其他可供公众使用的 Hugging Face 模型相比要小得多。
训练循环编码
就像所有其他神经网络一样,我们将以类似的方式编码 GAN 架构训练。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 num_epochs = 13 for epoch in range (num_epochs): for i, (data, prompts) in enumerate (dataloader): real_data = data.to(device) prompts = [prompt for prompt in prompts] netD.zero_grad() batch_size = real_data.size(0 ) labels = torch.ones(batch_size, 1 ).to(device) output = netD(real_data) lossD_real = criterion(output, labels) lossD_real.backward() noise = torch.randn(batch_size, 100 ).to(device) text_embeds = torch.stack([text_embedding(encode_text(prompt).to(device)).mean(dim=0 ) for prompt in prompts]) fake_data = netG(noise, text_embeds) labels = torch.zeros(batch_size, 1 ).to(device) output = netD(fake_data.detach()) lossD_fake = criterion(output, labels) lossD_fake.backward() optimizerD.step() netG.zero_grad() labels = torch.ones(batch_size, 1 ).to(device) output = netD(fake_data) lossG = criterion(output, labels) lossG.backward() optimizerG.step() print (f"Epoch [{epoch + 1 } /{num_epochs} ] Loss D: {lossD_real + lossD_fake} , Loss G: {lossG} " )
通过反向传播,我们的损失将针对生成器和鉴别器进行调整。我们使用了 13 个 epoch 进行训练循环。我测试了不同的值,但如果纪元高于此值,结果不会显示出太大差异。而且,遇到过拟合的风险很高。如果我们有一个更多样化的数据集,有更多的运动和形状,我们可以考虑使用更高的纪元,但在这种情况下不行。
当我们运行此代码时,它会开始训练并在每个时期后打印生成器和鉴别器的损失。
保存训练后的模型
训练完成后,我们需要保存训练好的 GAN 架构的判别器和生成器,这只需两行代码即可实现。
1 2 3 4 5 torch.save(netG.state_dict(), 'generator.pth' ) torch.save(netD.state_dict(), 'discriminator.pth' )
生成人工智能视频
正如我们所讨论的,我们在未见过的数据上测试模型的方法与我们的训练数据涉及狗捡球和猫追老鼠的示例相当。因此,我们的测试提示可能涉及诸如猫取球或狗追老鼠之类的场景。
在我们的具体情况下,圆圈向上然后向右移动的运动不存在于我们的训练数据中,因此模型不熟悉这种特定运动。然而,它已经接受了其他动作的训练。我们可以使用这个动作作为测试我们训练的模型并观察其性能的提示。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 def generate_video (text_prompt, num_frames=10 ): os.makedirs(f'generated_video_{text_prompt.replace(" " , "_" )} ' , exist_ok=True ) text_embed = text_embedding(encode_text(text_prompt).to(device)).mean(dim=0 ).unsqueeze(0 ) for frame_num in range (num_frames): noise = torch.randn(1 , 100 ).to(device) with torch.no_grad(): fake_frame = netG(noise, text_embed) save_image(fake_frame, f'generated_video_{text_prompt.replace(" " , "_" )} /frame_{frame_num} .png' ) generate_video('circle moving up-right' )
当我们运行上面的代码时,它将生成一个目录,其中包含我们生成的视频的所有帧。我们需要使用一些代码将所有这些帧合并成一个短视频。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 folder_path = 'generated_video_circle_moving_up-right' image_files = [f for f in os.listdir(folder_path) if f.endswith('.png' )] image_files.sort() frames = [] for image_file in image_files: image_path = os.path.join(folder_path, image_file) frame = cv2.imread(image_path) frames.append(frame) frames = np.array(frames) fps = 10 fourcc = cv2.VideoWriter_fourcc(*'XVID' ) out = cv2.VideoWriter('generated_video.avi' , fourcc, fps, (frames[0 ].shape[1 ], frames[0 ].shape[0 ])) for frame in frames: out.write(frame) out.release()
确保文件夹路径指向新生成的视频所在的位置。运行此代码后,您的AI视频将已成功创建。让我们看看它是什么样子的。
我以相同的时期数进行了多次训练。在这两种情况下,圆圈都是从底部出现的一半开始。好的部分是我们的模型尝试在这两种情况下执行直立运动。例如,在尝试 1 中,圆圈沿对角线向上移动,然后执行向上运动,而在尝试 2 中,圆圈沿对角线移动,同时缩小尺寸。在这两种情况下,圆圈都没有向左移动或完全消失,这是一个好兆头。
少了什么东西?
我测试了该架构的各个方面,发现训练数据是关键。通过在数据集中包含更多运动和形状,您可以增加可变性并提高模型的性能。由于数据是通过代码生成的,因此生成更多样的数据不会花费太多时间;相反,您可以专注于完善逻辑。
此外,本博客中讨论的 GAN 架构相对简单。您可以通过集成先进技术或使用语言模型嵌入 (LLM) 而不是基本的神经网络嵌入来使其变得更加复杂。此外,调整嵌入大小等参数可以显着影响模型的有效性。
文章转载于:使用 Python 从头开始构建 AI 文本到视频模型