paper: Social GAN: Socially Acceptable Trajectories with Generative Adversarial Networks
code: https://github.com/agrimgupta92/sgan
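np.around(): returns the value rounded to a given number of decimals (exact halves are rounded to the nearest even value).
np.transpose(): transposes an array.
np.polyfit(): fits a polynomial to data (least squares).
np.cumsum(): cumulative sum.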
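A quick sketch of the four NumPy calls listed above. The arrays and variable names are made up for illustration and do not come from the sgan code.

```python
import numpy as np

# Toy 2 x 3 array (illustrative values only).
xy = np.array([[1.234, 2.567, 3.891],
               [4.449, 5.678, 6.012]])

# np.around: round to one decimal place.
rounded = np.around(xy, decimals=1)

# np.transpose: (2, 3) -> (3, 2).
transposed = np.transpose(xy)

# np.polyfit: fit a degree-1 polynomial to points lying on y = 2t + 1.
t = np.arange(5)
y = 2.0 * t + 1.0
coeffs = np.polyfit(t, y, 1)  # highest-degree coefficient comes first

# np.cumsum: running total, handy for turning lengths into end offsets.
lengths = [3, 2, 4]
offsets = np.cumsum(lengths)

print(rounded)
print(transposed.shape)   # (3, 2)
print(offsets.tolist())   # [3, 5, 9]
```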
Return value of the dataset:
def __getitem__(self, index):
    # Each index selects one sequence; start/end delimit its pedestrians.
    start, end = self.seq_start_end[index]
    out = [
        self.obs_traj[start:end, :], self.pred_traj[start:end, :],
        self.obs_traj_rel[start:end, :], self.pred_traj_rel[start:end, :],
        self.non_linear_ped[start:end], self.loss_mask[start:end, :]
    ]
    return out
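"""
obs_traj: observed (past) trajectory, (peds_num, 2, obs_len)
pred_traj: trajectory to predict, (peds_num, 2, pred_len)
obs_traj_rel: observed (past) relative trajectory, (peds_num, 2, obs_len)
pred_traj_rel: relative trajectory to predict, (peds_num, 2, pred_len)
non_linear_ped: non-linearity value for each pedestrian
loss_mask: mask
"""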
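The slicing in `__getitem__` can be checked on toy data. This is only a sketch: NumPy arrays stand in for the tensors held by the dataset, the sizes (`obs_len = 8`, `pred_len = 12`, five pedestrians split 3 + 2) are assumptions, and `get_item` is a hypothetical stand-in for the method, reduced to the first two outputs.

```python
import numpy as np

obs_len, pred_len = 8, 12          # assumed lengths, for illustration only
total_peds = 5                     # 3 pedestrians in sequence 0, 2 in sequence 1
seq_start_end = [(0, 3), (3, 5)]   # pedestrian index range of each sequence

rng = np.random.default_rng(0)
obs_traj = rng.normal(size=(total_peds, 2, obs_len))
pred_traj = rng.normal(size=(total_peds, 2, pred_len))


def get_item(index):
    # Same slicing pattern as __getitem__, for two of the six outputs.
    start, end = seq_start_end[index]
    return obs_traj[start:end, :], pred_traj[start:end, :]


obs0, pred0 = get_item(0)
obs1, pred1 = get_item(1)
print(obs0.shape, pred0.shape)  # (3, 2, 8) (3, 2, 12)
print(obs1.shape, pred1.shape)  # (2, 2, 8) (2, 2, 12)
```

Because peds_num differs between sequences, the items cannot be stacked into one regular batch array without extra handling.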
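This relies on the DataLoader's collate_fn argument. The input to collate_fn is a list whose length is the batch size, and each element of the list is the result of one __getitem__ call. It seems to be used most often in object detection, where a batch contains the boxes of several images: each image has a different number of boxes, so the samples cannot be loaded as a regular batch, and you have to add a per-image index and then concat them.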
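The detection case described above can be sketched as follows. `ToyBoxDataset` and `box_collate` are hypothetical names invented for this example, and the box counts are arbitrary; only `Dataset`, `DataLoader` and its `collate_fn` argument are real PyTorch API.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyBoxDataset(Dataset):
    """Hypothetical dataset: each image has a different number of boxes."""

    def __init__(self):
        # One (num_boxes, 4) tensor per image; each box is (x1, y1, x2, y2).
        self.boxes = [torch.rand(n, 4) for n in (2, 5, 1, 3)]

    def __len__(self):
        return len(self.boxes)

    def __getitem__(self, index):
        return self.boxes[index]


def box_collate(batch):
    # `batch` is a list holding one __getitem__ result per sample.
    rows = []
    for image_idx, boxes in enumerate(batch):
        # Prepend a column that records which image each box belongs to.
        idx_col = torch.full((boxes.shape[0], 1), float(image_idx))
        rows.append(torch.cat([idx_col, boxes], dim=1))
    return torch.cat(rows, dim=0)  # (total_boxes_in_batch, 5)


loader = DataLoader(ToyBoxDataset(), batch_size=2, shuffle=False,
                    collate_fn=box_collate)
batches = list(loader)
print([tuple(b.shape) for b in batches])  # [(7, 5), (4, 5)]
```

The seq_collate function below follows the same idea, except that it records start/end offsets per sequence instead of adding an index column.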
import numpy as np
import torch

def seq_collate(data):
    # data: list of __getitem__ results, one per sequence in the batch
    (obs_seq_list, pred_seq_list, obs_seq_rel_list, pred_seq_rel_list,
     non_linear_ped_list, loss_mask_list) = zip(*data)

    # Number of pedestrians per sequence, turned into [start, end) offsets
    _len = [len(seq) for seq in obs_seq_list]
    cum_start_idx = [0] + np.cumsum(_len).tolist()
    seq_start_end = [[start, end]
                     for start, end in zip(cum_start_idx, cum_start_idx[1:])]

    # Data format: batch, input_size, seq_len
    # LSTM input format: seq_len, batch, input_size
    obs_traj = torch.cat(obs_seq_list, dim=0).permute(2, 0, 1)
    pred_traj = torch.cat(pred_seq_list, dim=0).permute(2, 0, 1)
    obs_traj_rel = torch.cat(obs_seq_rel_list, dim=0).permute(2, 0, 1)
    pred_traj_rel = torch.cat(pred_seq_rel_list, dim=0).permute(2, 0, 1)
    non_linear_ped = torch.cat(non_linear_ped_list)
    loss_mask = torch.cat(loss_mask_list, dim=0)
    seq_start_end = torch.LongTensor(seq_start_end)
    out = [
        obs_traj, pred_traj, obs_traj_rel, pred_traj_rel, non_linear_ped,
        loss_mask, seq_start_end
    ]

    return tuple(out)
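With collate_fn, what the DataLoader returns changes from (batch, peds_num, 2, seq_len) to (seq_len, batch*peds_num, 2). Since peds_num can differ from sequence to sequence, batch*peds_num here means the total number of pedestrians in the batch, and seq_start_end records which rows belong to which sequence.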
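The shape change can be checked with a reduced version of the collate step that handles only obs_traj. The pedestrian counts and `obs_len = 8` are assumptions chosen for the example, and the random tensors stand in for real dataset items.

```python
import numpy as np
import torch

obs_len = 8                    # assumed observation length
peds_per_scene = [3, 2, 4]     # hypothetical pedestrian counts, batch of 3

# Each element mimics one __getitem__ result, reduced to obs_traj only.
data = [torch.randn(n, 2, obs_len) for n in peds_per_scene]

# Same offset bookkeeping as in seq_collate.
_len = [len(seq) for seq in data]
cum_start_idx = [0] + np.cumsum(_len).tolist()
seq_start_end = torch.LongTensor(
    [[s, e] for s, e in zip(cum_start_idx, cum_start_idx[1:])])

# (total_peds, 2, seq_len) -> (seq_len, total_peds, 2)
obs_traj = torch.cat(data, dim=0).permute(2, 0, 1)

print(tuple(obs_traj.shape))   # (8, 9, 2)
print(seq_start_end.tolist())  # [[0, 3], [3, 5], [5, 9]]

# Recover the pedestrians of the second sequence from the flat batch.
start, end = seq_start_end[1].tolist()
scene1 = obs_traj[:, start:end, :]  # (8, 2, 2)
```

Slicing the middle dimension with seq_start_end gives back exactly the pedestrians of one sequence.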