发布人: 技术推广工程师 魏巍
简单回顾一下游戏规则: 我们基于强化学习的 agent 需要根据真人玩家的棋盘位置预测击打位置,以便能早于真人玩家完成游戏。如需进一步了解游戏规则,请参阅我们之前发布的文章。
△ "Plane Strike" 游戏演示
背景: JAX 和 TensorFlow
而 Flax 则是在 JAX 基础上构建的一款热门神经网络库。研究人员一直在使用 JAX/Flax 来训练包含数亿万个参数的超大模型 (如用于语言理解和生成的 PaLM,或者用于图像生成的 Imagen),以便充分利用现代硬件。
如果您不熟悉 JAX 和 Flax,可以先从 JAX 101 教程和 Flax 入门示例开始。
视频 "使用 TensorFlow Serving 为 JAX 模型提供服务",展示了如何使用 TensorFlow Serving 部署 JAX 模型:
https://youtu.be/I4dx7OI9FJQ?t=36
https://blog.tensorflow.org/2022/08/jax-on-web-with-tensorflowjs.html
使用 Flax/JAX 实现游戏 agent
T: 每段的时步数,各段的时步数可能有所不同
st: 时步上的状态 t
at: 时步上的所选操作 t 指定状态 s
πθ: 参数为 θ 的策略
R(*): 在指定策略下,收集到的奖励
class PolicyGradient(nn.Module):
"""Neural network to predict the next strike position."""
@nn.compact
def __call__(self, x):
dtype = jnp.float32
x = x.reshape((x.shape[0], -1))
x = nn.Dense(
features=2 * common.BOARD_SIZE**2, name='hidden1', dtype=dtype)(
x)
x = nn.relu(x)
x = nn.Dense(features=common.BOARD_SIZE**2, name='hidden2', dtype=dtype)(x)
x = nn.relu(x)
x = nn.Dense(features=common.BOARD_SIZE**2, name='logits', dtype=dtype)(x)
policy_probabilities = nn.softmax(x)
return policy_probabilities
for i in tqdm(range(iterations)):
predict_fn = functools.partial(run_inference, params)
board_log, action_log, result_log = common.play_game(predict_fn)
rewards = common.compute_rewards(result_log)
optimizer, params, opt_state = train_step(optimizer, params, opt_state,
board_log, action_log, rewards)
def compute_loss(logits, labels, rewards):
one_hot_labels = jax.nn.one_hot(labels, num_classes=common.BOARD_SIZE**2)
loss = -jnp.mean(
jnp.sum(one_hot_labels * jnp.log(logits), axis=-1) * jnp.asarray(rewards))
return loss
def train_step(model_optimizer, params, opt_state, game_board_log,
predicted_action_log, action_result_log):
"""Run one training step."""
def loss_fn(model_params):
logits = run_inference(model_params, game_board_log)
loss = compute_loss(logits, predicted_action_log, action_result_log)
return loss
def compute_grads(params):
return jax.grad(loss_fn)(params)
grads = compute_grads(params)
updates, opt_state = model_optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)
return model_optimizer, params, opt_state
@jax.jit
def run_inference(model_params, board):
logits = PolicyGradient().apply({'params': model_params}, board)
return logits
将 Flax/JAX 模型转换为 TensorFlow Lite 并与 Android 应用集成
完成模型训练后,我们使用 jax2tf (一款 TensorFlow-JAX 互操作工具),将 JAX 模型转换为 TensorFlow concrete function。最后一步是调用 TensorFlow Lite 转换器来将 concrete function 转换为 TFLite 模型。
# Convert to tflite model
model = PolicyGradient()
jax_predict_fn = lambda input: model.apply({'params': params}, input)
tf_predict = tf.function(
jax2tf.convert(jax_predict_fn, enable_xla=False),
input_signature=[
tf.TensorSpec(
shape=[1, common.BOARD_SIZE, common.BOARD_SIZE],
dtype=tf.float32,
name='input')
],
autograph=False,
)
converter = tf.lite.TFLiteConverter.from_concrete_functions(
[tf_predict.get_concrete_function()], tf_predict)
tflite_model = converter.convert()
# Save the model
with open(os.path.join(modeldir, 'planestrike.tflite'), 'wb') as f:
f.write(tflite_model)
convertBoardStateToByteBuffer(board);
tflite.run(boardData, outputProbArrays);
float[] probArray = outputProbArrays[0];
int agentStrikePosition = -1;
float maxProb = 0;
for (int i = 0; i < probArray.length; i++) {
int x = i / Constants.BOARD_SIZE;
int y = i % Constants.BOARD_SIZE;
if (board[x][y] == BoardCellStatus.UNTRIED && probArray[i] > maxProb) {
agentStrikePosition = i;
maxProb = probArray[i];
}
}
总结