我正在使用 Talos 和 Google Colab TPU 对 Keras 模型进行超参数调优。请注意,我使用的是 TensorFlow 1.15.0 和 Keras 2.2.4-tf。
import os
import tensorflow as tf
import talos as ta
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam
def iris_model(x_train, y_train, x_val, y_val, params):
    """Build, compile and fit a small dense classifier under a TPU strategy.

    Args:
        x_train: Training features; each sample must match the first
            layer's ``input_shape=(4,)``.
        y_train: Training targets; the last layer emits 3 softmax outputs.
        x_val: Validation features.
        y_val: Validation targets.
        params: Mapping that supplies ``'losses'`` (loss passed to
            ``compile``), ``'batch_size'`` and ``'epochs'`` (passed to
            ``fit``).

    Returns:
        A ``(out, model)`` pair: the value returned by ``model.fit`` and the
        fitted model itself.
    """
    import numpy as np

    # The TPU rejects float64 tensors ("Unsupported data type for TPU:
    # double"), so make sure everything fed to fit() is float32.
    x_train = np.asarray(x_train, dtype=np.float32)
    y_train = np.asarray(y_train, dtype=np.float32)
    x_val = np.asarray(x_val, dtype=np.float32)
    y_val = np.asarray(y_val, dtype=np.float32)

    # Specify a distributed strategy to use TPU.
    # NOTE(review): this connects to and initializes the TPU on every call;
    # if the function is invoked once per hyperparameter permutation it may
    # be worth hoisting -- confirm against the caller.
    resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        tpu='grpc://' + os.environ['COLAB_TPU_ADDR'])
    tf.contrib.distribute.initialize_tpu_system(resolver)
    strategy = tf.contrib.distribute.TPUStrategy(resolver)

    # Use the strategy to create and compile a Keras model.
    with strategy.scope():
        model = Sequential()
        model.add(Dense(32, input_shape=(4,), activation=tf.nn.relu,
                        name="relu"))
        model.add(Dense(3, activation=tf.nn.softmax, name="softmax"))
        model.compile(optimizer=Adam(learning_rate=0.1),
                      loss=params['losses'])

    # Fit the Keras model on the dataset. validation_data is passed as the
    # documented (x_val, y_val) tuple.
    out = model.fit(x_train, y_train,
                    batch_size=params['batch_size'],
                    epochs=params['epochs'],
                    validation_data=(x_val, y_val),
                    verbose=0,
                    steps_per_epoch=2)
    return out, model
# Load the iris dataset bundled with the Talos templates.
x, y = ta.templates.datasets.iris()

# Hyperparameter search space handed to Talos.
# NOTE(review): the batch_size tuple is presumably a Talos
# (min, max, steps) range rather than three literal values -- confirm.
p = dict(
    losses=['logcosh'],
    batch_size=(20, 50, 5),
    epochs=[10, 20],
)

# Run the Talos scan over the search space using iris_model.
scan_object = ta.Scan(
    x,
    y,
    model=iris_model,
    params=p,
    fraction_limit=0.1,
    experiment_name='first_test',
)
使用 out = model.fit 拟合模型时,出现以下错误:
InvalidArgumentError Traceback (most recent call last)
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/client/session.py in _do_call(self, fn, *args)
1382 '\nsession_config.graph_options.rewrite_options.'
1383 'disable_meta_optimizer = True')
-> 1384 raise type(e)(node_def, op, message)
1385
1386 def _extend_graph(self):
InvalidArgumentError: Unsupported data type for TPU: double, caused by output cond_8/Merge:0
答案 0 :(得分:1)
最近 TPU 中增加了对 double 类型的支持。您现在可以参考 https://github.com/tensorflow/tensorflow/blob/d0a48afee650b12dde805fadca868d6b113c3c5d/tensorflow/core/tpu/tpu_defs.h#L52 了解所有受支持的类型。