一行TensorFlow/Keras代码解决真实场景中数据不平衡(imbalanced)问题

2019 年 5 月 31 日 专知

【导读】数据不平衡是最令数据挖掘工程师头疼的问题之一。例如在某些真实场景中,一个二分类问题的训练集可能仅包含1%的正样本,导致难以训练高性能分类器。本文介绍TensorFlow/Keras中的一种简单的应对数据不平衡的方法,仅需添加一行代码。


在真实场景和比赛中,数据不平衡的情况是普遍存在的。如果不使用一定的策略,而直接用不平衡的数据来训练分类器,可能会导致分类器性能大幅度降低。Keras作者François Chollet在Twitter上推荐了一个示例,展示了TensorFlow/Keras的一个解决数据不平衡问题的机制:


该示例在Kaggle信用卡欺诈数据集上训练分类模型,数据集极度不平衡(99.82%负样本 vs 0.18%正样本)。


依赖安装


代码依赖于TensorFlow 2.0 Preview,可以用下面命令安装:

!pip install tf-nightly-gpu-2.0-preview


核心代码


首先我们需要统计每种类别包含的样本数量,并基于此计算类别权重:

# 计算每种类别数据的数量
counts = np.bincount(train_targets[:, 0])
# 基于数量计算类别权重
weight_for_0 = 1. / counts[0]
weight_for_1 = 1. / counts[1]
class_weight = {0: weight_for_0, 1: weight_for_1}


在训练时,加上一行代码设置类别权重即可:

model.fit(train_features, train_targets,
batch_size=2048,
epochs=50,
verbose=2,
callbacks=callbacks,
validation_data=(val_features, val_targets),
# 设置类别权重
class_weight=class_weight)


完整代码


代码地址:

https://colab.research.google.com/drive/1xL2jSdY-MGlN60gGuSH_L30P7kxxwUfM

代码内容:

# -*- coding: utf-8 -*-
"""## First, vectorize the CSV data"""
import csv
import numpy as np

# Get the real data from https://www.kaggle.com/mlg-ulb/creditcardfraud/downloads/creditcardfraud.zip/
fname = '/Users/fchollet/Downloads/creditcard.csv'
all_features = []
all_targets = []
with open(fname) as f:
for i, line in enumerate(f):
if i == 0:
print('HEADER:', line.strip())
continue # Skip header
fields = line.strip().split(',')
all_features.append([float(v.replace('"', '')) for v in fields[:-1]])
all_targets.append([int(fields[-1].replace('"', ''))])
if i == 1:
print('EXAMPLE FEATURES:', all_features[-1])

features = np.array(all_features, dtype='float32')
targets = np.array(all_targets, dtype='uint8')
print('features.shape:', features.shape)
print('targets.shape:', targets.shape)

"""## Prepare a validation set"""
num_val_samples = int(len(features) * 0.2)
train_features = features[:-num_val_samples]
train_targets = targets[:-num_val_samples]
val_features = features[-num_val_samples:]
val_targets = targets[-num_val_samples:]

print('Number of training samples:', len(train_features))
print('Number of validation samples:', len(val_features))

"""## Analyze class imbalance in the targets"""
counts = np.bincount(train_targets[:, 0])
print('Number of positive samples in training data: {} ({:.2f}% of total)'.format(counts[1],
100 * float(counts[1]) / len(
train_targets)))

weight_for_0 = 1. / counts[0]
weight_for_1 = 1. / counts[1]

"""## Normalize the data using training set statistics"""
mean = np.mean(train_features, axis=0)
train_features -= mean
val_features -= mean
std = np.std(train_features, axis=0)
train_features /= std
val_features /= std

from tensorflow import keras

model = keras.Sequential([
keras.layers.Dense(256, activation='relu',
input_shape=(train_features.shape[-1],)),
keras.layers.Dense(256, activation='relu'),
keras.layers.Dropout(0.3),
keras.layers.Dense(256, activation='relu'),
keras.layers.Dropout(0.3),
keras.layers.Dense(1, activation='sigmoid'),
])
model.summary()

metrics = [keras.metrics.FalseNegatives(name='fn'),
keras.metrics.FalsePositives(name='fp'),
keras.metrics.TrueNegatives(name='tn'),
keras.metrics.TruePositives(name='tp'),
keras.metrics.Precision(name='precision'),
keras.metrics.Recall(name='recall')]

model.compile(optimizer=keras.optimizers.Adam(1e-2),
loss='binary_crossentropy',
metrics=metrics)

callbacks = [keras.callbacks.ModelCheckpoint('fraud_model_at_epoch_{epoch}.h5')]
class_weight = {0: weight_for_0, 1: weight_for_1}

model.fit(train_features, train_targets,
batch_size=2048,
epochs=50,
verbose=2,
callbacks=callbacks,
validation_data=(val_features, val_targets),
class_weight=class_weight)


参考链接:

  • https://colab.research.google.com/drive/1xL2jSdY-MGlN60gGuSH_L30P7kxxwUfM


-END-

专 · 知

专知,专业可信的人工智能知识分发,让认知协作更快更好!欢迎登录www.zhuanzhi.ai,注册登录专知,获取更多AI知识资料!

欢迎微信扫一扫加入专知人工智能知识星球群,获取最新AI专业干货知识教程视频资料和与专家交流咨询

请加专知小助手微信(扫一扫如下二维码添加),加入专知人工智能主题群,咨询技术商务合作~

专知《深度学习:算法到实战》课程全部完成!550+位同学在学习,现在报名,限时优惠!网易云课堂人工智能畅销榜首位!

点击“阅读原文”,了解报名专知《深度学习:算法到实战》课程

登录查看更多
78

相关内容

Sklearn 与 TensorFlow 机器学习实用指南,385页pdf
专知会员服务
129+阅读 · 2020年3月15日
抢鲜看!13篇CVPR2020论文链接/开源代码/解读
专知会员服务
49+阅读 · 2020年2月26日
Keras作者François Chollet推荐的开源图像搜索引擎项目Sis
专知会员服务
29+阅读 · 2019年10月17日
Keras François Chollet 《Deep Learning with Python 》, 386页pdf
专知会员服务
151+阅读 · 2019年10月12日
一文教你如何处理不平衡数据集(附代码)
大数据文摘
11+阅读 · 2019年6月2日
Keras实现基于MSCNN的人群计数
AI科技评论
8+阅读 · 2019年2月11日
基于 Keras 用深度学习预测时间序列
R语言中文社区
23+阅读 · 2018年7月27日
深度学习训练数据不平衡问题,怎么解决?
AI研习社
7+阅读 · 2018年7月3日
基于Keras进行迁移学习
论智
12+阅读 · 2018年5月6日
tensorflow项目学习路径
数据挖掘入门与实战
22+阅读 · 2017年11月19日
用深度学习Keras判断出一句文本是积极还是消极
北京思腾合力科技有限公司
6+阅读 · 2017年11月16日
Arxiv
7+阅读 · 2020年3月1日
Simplifying Graph Convolutional Networks
Arxiv
12+阅读 · 2019年2月19日
Implicit Maximum Likelihood Estimation
Arxiv
7+阅读 · 2018年9月24日
Learning to Importance Sample in Primary Sample Space
Feature Selection Library (MATLAB Toolbox)
Arxiv
7+阅读 · 2018年8月6日
Arxiv
3+阅读 · 2018年3月2日
Arxiv
3+阅读 · 2018年1月31日
VIP会员
相关资讯
一文教你如何处理不平衡数据集(附代码)
大数据文摘
11+阅读 · 2019年6月2日
Keras实现基于MSCNN的人群计数
AI科技评论
8+阅读 · 2019年2月11日
基于 Keras 用深度学习预测时间序列
R语言中文社区
23+阅读 · 2018年7月27日
深度学习训练数据不平衡问题,怎么解决?
AI研习社
7+阅读 · 2018年7月3日
基于Keras进行迁移学习
论智
12+阅读 · 2018年5月6日
tensorflow项目学习路径
数据挖掘入门与实战
22+阅读 · 2017年11月19日
用深度学习Keras判断出一句文本是积极还是消极
北京思腾合力科技有限公司
6+阅读 · 2017年11月16日
相关论文
Arxiv
7+阅读 · 2020年3月1日
Simplifying Graph Convolutional Networks
Arxiv
12+阅读 · 2019年2月19日
Implicit Maximum Likelihood Estimation
Arxiv
7+阅读 · 2018年9月24日
Learning to Importance Sample in Primary Sample Space
Feature Selection Library (MATLAB Toolbox)
Arxiv
7+阅读 · 2018年8月6日
Arxiv
3+阅读 · 2018年3月2日
Arxiv
3+阅读 · 2018年1月31日
Top
微信扫码咨询专知VIP会员