
Question about training multiple models in PyTorch


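I defined a training function: def train(train_loader1, train_loader2, model1, model2, criterion, optimizer1, optimizer2, epoch, evaluation, logger):
The two models are trained on different data and have different output dimensions, one 120 and the other 180, so I created a separate optimizer for each: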
optimizer1 = torch.optim.Adam(model1.parameters(), lr=args.lr1)
optimizer2 = torch.optim.Adam(model2.parameters(), lr=args.lr2)
Then I trained as usual:
output1 = model1(g1, h1, e1)
output2 = model2(g2, h2, e2)
output_all = torch.cat((output1, output2), dim=1)
P_tensor, Q_tensor, Pij_tensor, Qij_tensor = get_pf(output_all)
train_loss1 = criterion(output1, target1) + criterion(P_tensor, P_in) + criterion(Q_tensor, Q_in) + \
    criterion(Pij_tensor, edge_f_Pij) + criterion(Qij_tensor, edge_f_Qij)
train_loss2 = criterion(output2, target2) + criterion(P_tensor, P_in) + criterion(Q_tensor, Q_in) + \
    criterion(Pij_tensor, edge_f_Pij) + criterion(Qij_tensor, edge_f_Qij)
total_loss = train_loss1 + train_loss2
total_loss.backward()
optimizer1.step()
model1.zero_grad()
model2.zero_grad()
optimizer2.step()
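However, during training only model 1's loss decreases normally, while model 2's loss stays constant. Also, if I comment out model1.zero_grad() and model2.zero_grad(), the code fails with the following error: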
File "D:\PhcharmProjectsPytorch\MPNN\main_3.py", line 458, in train
optimizer2.step()
File "D:\anaconda\envs\pytorch\lib\site-packages\torch\optim\optimizer.py", line 280, in wrapper
out = func(*args, **kwargs)
File "D:\anaconda\envs\pytorch\lib\site-packages\torch\optim\optimizer.py", line 33, in _use_grad
ret = func(self, *args, **kwargs)
File "D:\anaconda\envs\pytorch\lib\site-packages\torch\optim\adam.py", line 141, in step
adam(
File "D:\anaconda\envs\pytorch\lib\site-packages\torch\optim\adam.py", line 281, in adam
func(params,
File "D:\anaconda\envs\pytorch\lib\site-packages\torch\optim\adam.py", line 344, in _single_tensor_adam
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
RuntimeError: The size of tensor a (120) must match the size of tensor b (180) at non-singleton dimension 0
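It reports a dimension mismatch, but I have checked the dimensions of the data and of the outputs and they all match. Does anyone know how to solve this? It has been bothering me for a long time.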
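One thing that stands out in the posted step is that both zero_grad() calls run between optimizer1.step() and optimizer2.step(), so model2's gradients are already cleared when optimizer2.step() is called. In recent PyTorch releases zero_grad() sets gradients to None by default and the optimizer skips such parameters; in older releases it fills them with zeros, which also yields a zero Adam update. The sketch below reproduces that ordering on two small stand-in models and compares it with clearing gradients before backward(). The nn.Linear models, the input size of 8, the helper names joint_loss, step_as_posted and step_reordered, and the all-ones target are assumptions for illustration only, not the poster's actual code.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the poster's models (output sizes 120 and 180).
model1 = nn.Linear(8, 120)
model2 = nn.Linear(8, 180)
optimizer1 = torch.optim.Adam(model1.parameters(), lr=1e-2)
optimizer2 = torch.optim.Adam(model2.parameters(), lr=1e-2)
criterion = nn.MSELoss()


def joint_loss(x):
    # A single loss that depends on both models, like total_loss in the post.
    out = torch.cat((model1(x), model2(x)), dim=1)
    return criterion(out, torch.ones_like(out))


def step_as_posted(x):
    # Same call order as in the post.
    total_loss = joint_loss(x)
    total_loss.backward()
    optimizer1.step()
    model1.zero_grad()
    model2.zero_grad()  # model2's gradients are cleared here ...
    optimizer2.step()   # ... so this step has nothing to apply


def step_reordered(x):
    # Clear old gradients first, then backward, then step both optimizers.
    optimizer1.zero_grad()
    optimizer2.zero_grad()
    total_loss = joint_loss(x)
    total_loss.backward()
    optimizer1.step()
    optimizer2.step()


x = torch.randn(4, 8)

before = model2.weight.detach().clone()
step_as_posted(x)
changed_as_posted = not torch.equal(before, model2.weight)

before = model2.weight.detach().clone()
step_reordered(x)
changed_reordered = not torch.equal(before, model2.weight)

print("model2 updated with posted order:   ", changed_as_posted)
print("model2 updated with reordered steps:", changed_reordered)
```

If the same pattern holds in the real training loop, it would explain why only model 1's loss decreases. It does not by itself explain the 120 vs. 180 size error, which only appears once model2's gradients actually reach optimizer2.step().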


IP location: Jiangsu, post 1, 2024-04-17 16:13
    If the loss is not changing, the main thing to check is whether something is wrong around the loss and the backward() call.
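    A rough way to follow up on this suggestion is to look at the gradients right after backward(), and to compare the shapes stored in each optimizer's state with the shapes of its parameters. The traceback in the first post fails inside exp_avg.mul_(beta1).add_(grad, ...), which is where Adam's stored running average meets the incoming gradient. One possible cause, which the post does not confirm, is that optimizer2 holds state created for model1, for example loaded from the wrong checkpoint. The helpers grad_norms and mismatched_state below are hypothetical names, and the small nn.Linear models are stand-ins.

```python
import torch
import torch.nn as nn


def grad_norms(model):
    """Map each parameter name to its gradient norm (None if no gradient)."""
    return {
        name: (None if p.grad is None else p.grad.norm().item())
        for name, p in model.named_parameters()
    }


def mismatched_state(optimizer):
    """List (param_shape, state_key, state_shape) where state and parameter shapes differ."""
    bad = []
    for group in optimizer.param_groups:
        for p in group["params"]:
            for key, value in optimizer.state.get(p, {}).items():
                # Skip scalars such as the step counter.
                if torch.is_tensor(value) and value.dim() > 0 and value.shape != p.shape:
                    bad.append((tuple(p.shape), key, tuple(value.shape)))
    return bad


torch.manual_seed(0)

# Hypothetical stand-ins with the output sizes from the post.
model1 = nn.Linear(8, 120)
model2 = nn.Linear(8, 180)
optimizer1 = torch.optim.Adam(model1.parameters(), lr=1e-3)
optimizer2 = torch.optim.Adam(model2.parameters(), lr=1e-3)

x = torch.randn(4, 8)
loss = model1(x).pow(2).mean() + model2(x).pow(2).mean()
loss.backward()

norms1 = grad_norms(model1)
norms2 = grad_norms(model2)
print(norms1, norms2)

optimizer1.step()
optimizer2.step()
ok = mismatched_state(optimizer2)  # healthy optimizer: no mismatches

# Simulate the suspected fault: optimizer2 receives state built for model1.
optimizer2.load_state_dict(optimizer1.state_dict())
bad = mismatched_state(optimizer2)
print(bad)
```

    If grad_norms reports None or zeros for model2 right after backward(), the loss is not reaching model2. If mismatched_state returns entries, the optimizer state and the model do not belong together, and stepping with real gradients would raise a size error like the one posted.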


    IP location: Beijing, post 2, 2024-04-21 22:56