LightGCN推荐模型代码解读

2021 年 12 月 23 日 机器学习与推荐算法

本文对LightGCN模型部分的代码进行了解读,对相应部分进行了简单的注释帮助大家理解。笔者第一次尝试代码阅读分享,有什么不足之处或者建议可以给我留言哦,感谢。




Dropout


在图上实施dropout,以一定概率忽略一部分边


def __dropout_x(self, x, keep_prob):
        # 获取self.Graph中的大小,下标和值,Graph采用稀疏矩阵的表示方法SparseTensor
        size = x.size()
        index = x.indices().t()
        values = x.values()
        # 通过rand得到len(values)数量的随机数,加上keep_prob
        random_index = torch.rand(len(values)) + keep_prob
        # 通过对这些数字取int使得小于1的为0,在通过bool()将0->false,大于等于1的取True
        random_index = random_index.int().bool()
        # 利用上面得到的True,False数组选取下标,从而dropout了为False的下标
        index = index[random_index]
        # 由于dropout在训练和测试过程中的不一致,所以需要除以p
        values = values[random_index]/keep_prob
        # 得到新的graph
        g = torch.sparse.FloatTensor(index.t(), values, size)
        return g
    
    def __dropout(self, keep_prob):
        if self.A_split:
            graph = []
            for g in self.Graph:
                graph.append(self.__dropout_x(g, keep_prob))
        else:
            graph = self.__dropout_x(self.Graph, keep_prob)
        return graph






消息传播


computer函数是LightGCN类中用于进行图信息传播的实现方法,整体上通过在整个图上进行矩阵计算得到所有用户和商品的embedding。


def computer(self):
        """
        propagate methods for lightGCN
        """
       
        # 得到所有用户和所有商品的embedding
        users_emb = self.embedding_user.weight
        items_emb = self.embedding_item.weight
        all_emb = torch.cat([users_emb, items_emb])
        # torch.split(all_emb , [self.num_users, self.num_items])
        embs = [all_emb]
        # 判断是否需要dropout
        if self.config['dropout']:
            if self.training:
                print("droping")
                g_droped = self.__dropout(self.keep_prob)
            else:
                g_droped = self.Graph
        else:
            g_droped = self.Graph
        # 根据层数对图进行信息传播和聚合考虑n-hop
        # 通过稀疏矩阵乘法对Graph进行n_layers次的计算
        for layer in range(self.n_layers):
            if self.A_split:
                temp_emb = []
                for f in range(len(g_droped)):
                    temp_emb.append(torch.sparse.mm(g_droped[f], all_emb))
                side_emb = torch.cat(temp_emb, dim=0)
                all_emb = side_emb
            else:
                all_emb = torch.sparse.mm(g_droped, all_emb)
            embs.append(all_emb)
        embs = torch.stack(embs, dim=1)
        #print(embs.size())
        # 对每一层得到的输出求均值,以此将不同层的信息进行融合
        light_out = torch.mean(embs, dim=1)
        users, items = torch.split(light_out, [self.num_users, self.num_items])
        return users, items




损失构建


在computer函数计算得到所有用户和商品经过消息传播后的embedding之后,getEmbedding根据当前用户和商品查询出需要用到的embedding以及当前用户和商品的原始embedding,即未经GCN的embedding。

传播后的embedding用于计算bpr损失,原始embedding用于计算L2正则项。


def getEmbedding(self, users, pos_items, neg_items):
        # 得到需要计算相似度的用户和商品的embedding
        all_users, all_items = self.computer()
        users_emb = all_users[users]
        pos_emb = all_items[pos_items]
        neg_emb = all_items[neg_items]
        # 没经过传播的embedding,用于后续正则项计算
        users_emb_ego = self.embedding_user(users)
        pos_emb_ego = self.embedding_item(pos_items)
        neg_emb_ego = self.embedding_item(neg_items)
        return users_emb, pos_emb, neg_emb, users_emb_ego, pos_emb_ego, neg_emb_ego
    
    def bpr_loss(self, users, pos, neg):
        (users_emb, pos_emb, neg_emb,
        userEmb0, posEmb0, negEmb0) = self.getEmbedding(users.long(), pos.long(), neg.long())
        # 这个损失计算的是LightGCN论文中损失函数中的正则项,即做了一个L2正则
        reg_loss = (1/2)*(userEmb0.norm(2).pow(2) +
                         posEmb0.norm(2).pow(2) +
                         negEmb0.norm(2).pow(2))/float(len(users))
        # 通过乘法计算用户和商品的相似度
        pos_scores = torch.mul(users_emb, pos_emb)
        pos_scores = torch.sum(pos_scores, dim=1)
        neg_scores = torch.mul(users_emb, neg_emb)
        neg_scores = torch.sum(neg_scores, dim=1)
        # pair-wise的排序损失
        loss = torch.mean(torch.nn.functional.softplus(neg_scores - pos_scores))


欢迎干货投稿 \ 论文宣传 \ 合作交流

推荐阅读

强化学习推荐系统的模型结构与特点总结

一文理解PyTorch:附代码实例
推荐系统之FM算法原理及实现(附代码)

由于公众号试行乱序推送,您可能不再准时收到机器学习与推荐算法的推送。为了第一时间收到本号的干货内容, 请将本号设为星标,以及常点文末右下角的“在看”。

喜欢的话点个在看吧👇
登录查看更多
23

相关内容

WSDM'22「京东」个性化会话推荐:异构全局图神经网络
专知会员服务
22+阅读 · 2022年1月7日
专知会员服务
32+阅读 · 2021年10月4日
专知会员服务
55+阅读 · 2021年6月30日
专知会员服务
32+阅读 · 2021年2月12日
【WWW2021】基于双侧深度上下文调制的社会化推荐系统
专知会员服务
27+阅读 · 2021年1月28日
【SIGIR2020】LightGCN: 简化和增强图卷积网络推荐
专知会员服务
72+阅读 · 2020年6月1日
Transformer文本分类代码
专知会员服务
116+阅读 · 2020年2月3日
专知会员服务
87+阅读 · 2020年1月20日
一文梳理推荐系统中的特征交互排序模型
RUC AI Box
1+阅读 · 2022年4月8日
WSDM2022 | 考虑行为多样性与对比元学习的推荐系统
机器学习与推荐算法
2+阅读 · 2022年2月24日
CIKM21 | 图+推荐系统: 比LightGCN更有效的UltraGCN
机器学习与推荐算法
2+阅读 · 2021年11月30日
唯快不破! 比LightGCN还要快10倍的UltraGCN
图与推荐
1+阅读 · 2021年11月22日
【论文解读】“推荐系统”加上“图神经网络”
深度学习自然语言处理
16+阅读 · 2020年3月31日
国家自然科学基金
2+阅读 · 2016年12月31日
国家自然科学基金
1+阅读 · 2015年12月31日
国家自然科学基金
1+阅读 · 2015年12月31日
国家自然科学基金
1+阅读 · 2013年12月31日
国家自然科学基金
1+阅读 · 2013年12月31日
国家自然科学基金
0+阅读 · 2013年12月31日
国家自然科学基金
3+阅读 · 2013年12月31日
国家自然科学基金
6+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2009年12月31日
Interest-aware Message-Passing GCN for Recommendation
Arxiv
12+阅读 · 2021年2月19日
Arxiv
20+阅读 · 2019年11月23日
VIP会员
相关VIP内容
WSDM'22「京东」个性化会话推荐:异构全局图神经网络
专知会员服务
22+阅读 · 2022年1月7日
专知会员服务
32+阅读 · 2021年10月4日
专知会员服务
55+阅读 · 2021年6月30日
专知会员服务
32+阅读 · 2021年2月12日
【WWW2021】基于双侧深度上下文调制的社会化推荐系统
专知会员服务
27+阅读 · 2021年1月28日
【SIGIR2020】LightGCN: 简化和增强图卷积网络推荐
专知会员服务
72+阅读 · 2020年6月1日
Transformer文本分类代码
专知会员服务
116+阅读 · 2020年2月3日
专知会员服务
87+阅读 · 2020年1月20日
相关资讯
一文梳理推荐系统中的特征交互排序模型
RUC AI Box
1+阅读 · 2022年4月8日
WSDM2022 | 考虑行为多样性与对比元学习的推荐系统
机器学习与推荐算法
2+阅读 · 2022年2月24日
CIKM21 | 图+推荐系统: 比LightGCN更有效的UltraGCN
机器学习与推荐算法
2+阅读 · 2021年11月30日
唯快不破! 比LightGCN还要快10倍的UltraGCN
图与推荐
1+阅读 · 2021年11月22日
【论文解读】“推荐系统”加上“图神经网络”
深度学习自然语言处理
16+阅读 · 2020年3月31日
相关基金
国家自然科学基金
2+阅读 · 2016年12月31日
国家自然科学基金
1+阅读 · 2015年12月31日
国家自然科学基金
1+阅读 · 2015年12月31日
国家自然科学基金
1+阅读 · 2013年12月31日
国家自然科学基金
1+阅读 · 2013年12月31日
国家自然科学基金
0+阅读 · 2013年12月31日
国家自然科学基金
3+阅读 · 2013年12月31日
国家自然科学基金
6+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2009年12月31日
Top
微信扫码咨询专知VIP会员