Dropout
在图上实施dropout,以一定概率忽略一部分边
def __dropout_x(self, x, keep_prob):
# 获取self.Graph中的大小,下标和值,Graph采用稀疏矩阵的表示方法SparseTensor
size = x.size()
index = x.indices().t()
values = x.values()
# 通过rand得到len(values)数量的随机数,加上keep_prob
random_index = torch.rand(len(values)) + keep_prob
# 通过对这些数字取int使得小于1的为0,在通过bool()将0->false,大于等于1的取True
random_index = random_index.int().bool()
# 利用上面得到的True,False数组选取下标,从而dropout了为False的下标
index = index[random_index]
# 由于dropout在训练和测试过程中的不一致,所以需要除以p
values = values[random_index]/keep_prob
# 得到新的graph
g = torch.sparse.FloatTensor(index.t(), values, size)
return g
def __dropout(self, keep_prob):
if self.A_split:
graph = []
for g in self.Graph:
graph.append(self.__dropout_x(g, keep_prob))
else:
graph = self.__dropout_x(self.Graph, keep_prob)
return graph
消息传播
computer函数是LightGCN类中用于进行图信息传播的实现方法,整体上通过在整个图上进行矩阵计算得到所有用户和商品的embedding。
def computer(self):
"""
propagate methods for lightGCN
"""
# 得到所有用户和所有商品的embedding
users_emb = self.embedding_user.weight
items_emb = self.embedding_item.weight
all_emb = torch.cat([users_emb, items_emb])
# torch.split(all_emb , [self.num_users, self.num_items])
embs = [all_emb]
# 判断是否需要dropout
if self.config['dropout']:
if self.training:
print("droping")
g_droped = self.__dropout(self.keep_prob)
else:
g_droped = self.Graph
else:
g_droped = self.Graph
# 根据层数对图进行信息传播和聚合考虑n-hop
# 通过稀疏矩阵乘法对Graph进行n_layers次的计算
for layer in range(self.n_layers):
if self.A_split:
temp_emb = []
for f in range(len(g_droped)):
temp_emb.append(torch.sparse.mm(g_droped[f], all_emb))
side_emb = torch.cat(temp_emb, dim=0)
all_emb = side_emb
else:
all_emb = torch.sparse.mm(g_droped, all_emb)
embs.append(all_emb)
embs = torch.stack(embs, dim=1)
#print(embs.size())
# 对每一层得到的输出求均值,以此将不同层的信息进行融合
light_out = torch.mean(embs, dim=1)
users, items = torch.split(light_out, [self.num_users, self.num_items])
return users, items
损失构建
在computer函数计算得到所有用户和商品经过消息传播后的embedding之后,getEmbedding根据当前用户和商品查询出需要用到的embedding以及当前用户和商品的原始embedding,即未经GCN的embedding。
传播后的embedding用于计算bpr损失,原始embedding用于计算L2正则项。
def getEmbedding(self, users, pos_items, neg_items):
# 得到需要计算相似度的用户和商品的embedding
all_users, all_items = self.computer()
users_emb = all_users[users]
pos_emb = all_items[pos_items]
neg_emb = all_items[neg_items]
# 没经过传播的embedding,用于后续正则项计算
users_emb_ego = self.embedding_user(users)
pos_emb_ego = self.embedding_item(pos_items)
neg_emb_ego = self.embedding_item(neg_items)
return users_emb, pos_emb, neg_emb, users_emb_ego, pos_emb_ego, neg_emb_ego
def bpr_loss(self, users, pos, neg):
(users_emb, pos_emb, neg_emb,
userEmb0, posEmb0, negEmb0) = self.getEmbedding(users.long(), pos.long(), neg.long())
# 这个损失计算的是LightGCN论文中损失函数中的正则项,即做了一个L2正则
reg_loss = (1/2)*(userEmb0.norm(2).pow(2) +
posEmb0.norm(2).pow(2) +
negEmb0.norm(2).pow(2))/float(len(users))
# 通过乘法计算用户和商品的相似度
pos_scores = torch.mul(users_emb, pos_emb)
pos_scores = torch.sum(pos_scores, dim=1)
neg_scores = torch.mul(users_emb, neg_emb)
neg_scores = torch.sum(neg_scores, dim=1)
# pair-wise的排序损失
loss = torch.mean(torch.nn.functional.softplus(neg_scores - pos_scores))
由于公众号试行乱序推送,您可能不再准时收到机器学习与推荐算法的推送。为了第一时间收到本号的干货内容, 请将本号设为星标,以及常点文末右下角的“在看”。