Keras实例:PointNet点云分类

2020 年 5 月 30 日 专知

【导读】点云的分类,检测和分割是计算机视觉中的核心问题。本示例实现了点云深度学习论文PointNet。

原文链接:

https://keras.io/examples/vision/pointnet/


准备工作

首先使用下列命令安装trimesh库,这个包用于可视化数据:

pip install trimesh

然后安装引入相应的库

import osimport globimport trimeshimport numpy as npimport tensorflow as tffrom tensorflow import kerasfrom tensorflow.keras import layersfrom matplotlib import pyplot as plt
tf.random.set_seed(1234)

加载数据集

我们使用ModelNet10 数据集,它是ModelNet40数据集中的一部分,首先下载数据:

DATA_DIR = tf.keras.utils.get_file("modelnet.zip","http://3dvision.princeton.edu/projects/2014/3DShapeNets/ModelNet10.zip",    extract=True,)DATA_DIR = os.path.join(os.path.dirname(DATA_DIR), "ModelNet10")

然后我们可以用trimesh工具可视化数据(3D模型):

mesh = trimesh.load(os.path.join(DATA_DIR, "chair/train/chair_0001.off"))mesh.show()

我们可以将从3D点云采样,并用matplotlib可视化:

points = mesh.sample(2048)
fig = plt.figure(figsize=(5, 5))ax = fig.add_subplot(111, projection="3d")ax.scatter(points[:, 0], points[:, 1], points[:, 2])ax.set_axis_off()plt.show()

然后我们解析数据,并将之转换成TensorFlow能使用的数据:

def parse_dataset(num_points=2048):
train_points = [] train_labels = [] test_points = [] test_labels = [] class_map = {} folders = glob.glob(os.path.join(DATA_DIR, "[!README]*"))
for i, folder in enumerate(folders): print("processing class: {}".format(os.path.basename(folder)))# store folder name with ID so we can retrieve later class_map[i] = folder.split("/")[-1]# gather all files train_files = glob.glob(os.path.join(folder, "train/*")) test_files = glob.glob(os.path.join(folder, "test/*"))
for f in train_files: train_points.append(trimesh.load(f).sample(num_points)) train_labels.append(i)
for f in test_files: test_points.append(trimesh.load(f).sample(num_points)) test_labels.append(i)
return ( np.array(train_points), np.array(test_points), np.array(train_labels), np.array(test_labels), class_map, )

然后设置采样的点与batch大小:

NUM_POINTS = 2048NUM_CLASSES = 10BATCH_SIZE = 32
train_points, test_points, train_labels, test_labels, CLASS_MAP = parse_dataset( NUM_POINTS)

使用 tf.data.Dataset() 构建数据集:

def augment(points, label):# jitter points    points += tf.random.uniform(points.shape, -0.005, 0.005, dtype=tf.float64)# shuffle points    points = tf.random.shuffle(points)    return points, label

train_dataset = tf.data.Dataset.from_tensor_slices((train_points, train_labels))test_dataset = tf.data.Dataset.from_tensor_slices((test_points, test_labels))
train_dataset = train_dataset.shuffle(len(train_points)).map(augment).batch(BATCH_SIZE)test_dataset = test_dataset.shuffle(len(test_points)).batch(BATCH_SIZE)

构建模型

定义卷积层与全连接层

def conv_bn(x, filters):    x = layers.Conv1D(filters, kernel_size=1, padding="valid")(x)    x = layers.BatchNormalization(momentum=0.0)(x)return layers.Activation("relu")(x)

def dense_bn(x, filters): x = layers.Dense(filters)(x) x = layers.BatchNormalization(momentum=0.0)(x)return layers.Activation("relu")(x)

PointNet有两个核心元素:MLP层和一个Transformer(T-net)。

class OrthogonalRegularizer(keras.regularizers.Regularizer):def __init__(self, num_features, l2reg=0.001):self.num_features = num_featuresself.l2reg = l2regself.eye = tf.eye(num_features)
def __call__(self, x): x = tf.reshape(x, (-1, self.num_features, self.num_features)) xxt = tf.tensordot(x, x, axes=(2, 2)) xxt = tf.reshape(xxt, (-1, self.num_features, self.num_features))return tf.reduce_sum(self.l2reg * tf.square(xxt - self.eye))
def tnet(inputs, num_features):
# Initalise bias as the indentity matrixbias = keras.initializers.Constant(np.eye(num_features).flatten())reg = OrthogonalRegularizer(num_features)
x = conv_bn(inputs, 32)x = conv_bn(x, 64)x = conv_bn(x, 512)x = layers.GlobalMaxPooling1D()(x)x = dense_bn(x, 256)x = dense_bn(x, 128)x = layers.Dense(num_features * num_features,kernel_initializer="zeros",bias_initializer=bias,activity_regularizer=reg,)(x)feat_T = layers.Reshape((num_features, num_features))(x) # Apply affine transformation to input featuresreturn layers.Dot(axes=(2, 1))([inputs, feat_T])
inputs = keras.Input(shape=(NUM_POINTS, 3))
x = tnet(inputs, 3)x = conv_bn(x, 32)x = conv_bn(x, 32)x = tnet(x, 32)x = conv_bn(x, 32)x = conv_bn(x, 64)x = conv_bn(x, 512)x = layers.GlobalMaxPooling1D()(x)x = dense_bn(x, 256)x = layers.Dropout(0.3)(x)x = dense_bn(x, 128)x = layers.Dropout(0.3)(x)
outputs = layers.Dense(NUM_CLASSES, activation="softmax")(x)
model = keras.Model(inputs=inputs, outputs=outputs, name="pointnet")model.summary()

训练模型

使用 .compile() 与.fit() 训练模型

model.compile(loss="sparse_categorical_crossentropy",optimizer=keras.optimizers.Adam(learning_rate=0.001),metrics=["sparse_categorical_accuracy"],)
model.fit(train_dataset, epochs=20, validation_data=test_dataset)

可视化结果

data = test_dataset.take(1)
points, labels = list(data)[0]points = points[:8, ...]labels = labels[:8, ...]
# run test data through modelpreds = model.predict(points)preds = tf.math.argmax(preds, -1)
points = points.numpy()
# plot points with predicted class and labelfig = plt.figure(figsize=(15, 10))for i in range(8): ax = fig.add_subplot(2, 4, i + 1, projection="3d") ax.scatter(points[i, :, 0], points[i, :, 1], points[i, :, 2]) ax.set_title("pred: {:}, label: {:}".format( CLASS_MAP[preds[i].numpy()], CLASS_MAP[labels.numpy()[i]] ) ) ax.set_axis_off()plt.show()

专知,专业可信的人工智能知识分发,让认知协作更快更好!欢迎注册登录专知www.zhuanzhi.ai,获取5000+AI主题干货知识资料!
欢迎微信扫一扫加入专知人工智能知识星球群,获取最新AI专业干货知识教程资料和与专家交流咨询
点击“ 阅读原文 ”,了解使用 专知 ,查看获取5000+AI主题知识资源
登录查看更多
6

相关内容

干净的数据:数据清洗入门与实践,204页pdf
专知会员服务
162+阅读 · 2020年5月14日
专知会员服务
55+阅读 · 2020年3月16日
【强化学习资源集合】Awesome Reinforcement Learning
专知会员服务
95+阅读 · 2019年12月23日
Keras作者François Chollet推荐的开源图像搜索引擎项目Sis
专知会员服务
30+阅读 · 2019年10月17日
Stabilizing Transformers for Reinforcement Learning
专知会员服务
60+阅读 · 2019年10月17日
Keras François Chollet 《Deep Learning with Python 》, 386页pdf
专知会员服务
154+阅读 · 2019年10月12日
PointNet系列论文解读
人工智能前沿讲习班
17+阅读 · 2019年5月3日
[机器学习] 用KNN识别MNIST手写字符实战
机器学习和数学
4+阅读 · 2018年5月13日
用深度学习keras的cnn做图像识别分类,准确率达97%
数据挖掘入门与实战
4+阅读 · 2017年12月17日
用 Scikit-Learn 和 Pandas 学习线性回归
Python开发者
9+阅读 · 2017年9月26日
【推荐】一步一步带你用TensorFlow玩转LSTM
机器学习研究会
9+阅读 · 2017年9月12日
【推荐】用Tensorflow理解LSTM
机器学习研究会
36+阅读 · 2017年9月11日
3D-LaneNet: end-to-end 3D multiple lane detection
Arxiv
7+阅读 · 2018年11月26日
VIP会员
相关VIP内容
相关资讯
PointNet系列论文解读
人工智能前沿讲习班
17+阅读 · 2019年5月3日
[机器学习] 用KNN识别MNIST手写字符实战
机器学习和数学
4+阅读 · 2018年5月13日
用深度学习keras的cnn做图像识别分类,准确率达97%
数据挖掘入门与实战
4+阅读 · 2017年12月17日
用 Scikit-Learn 和 Pandas 学习线性回归
Python开发者
9+阅读 · 2017年9月26日
【推荐】一步一步带你用TensorFlow玩转LSTM
机器学习研究会
9+阅读 · 2017年9月12日
【推荐】用Tensorflow理解LSTM
机器学习研究会
36+阅读 · 2017年9月11日
Top
微信扫码咨询专知VIP会员