AssertionError:在使用预测功能时,batch_size必须被使用中的TPU内核数(1对8)除以

时间:2019-04-24 18:07:45

标签: tensorflow keras google-colaboratory google-cloud-tpu tpu

有关上下文的一些详细信息:

  1. 使用TPU处理Google Colab。
  2. 模型拟合成功,没有任何问题
  3. 尝试使用预测功能时遇到问题

这是我用来训练的代码:

<script src="https://cdnjs.cloudflare.com/ajax/libs/vue/2.6.10/vue.min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/axios/0.18.0/axios.min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/lodash.js/4.17.11/lodash.min.js"></script>

<div id="app">
  <input v-model="query" @keyup.stop="handleSearch" type="text" class="form-control" placeholder="Search">
  <button @click.stop="clear">Clear</button>
  <div v-if="isLoading">Loading...</div>
  <ul v-if="results !== ''">
    <li v-for="(r, index) in results" :key="index">
      {{ r.name }}
    </li>
  </ul>
</div>

这是我用来预测的代码:

tpu_model.fit(x, y,
          batch_size=128,
          epochs=60)

这是错误(在上面添加了一个箭头,所以您知道错误的位置:

def generate_output():
    generated = ''
    #sentence = text[start_index: start_index + Tx]
    #sentence = '0'*Tx
    usr_input = input("Write the beginning of your poem, the Shakespeare machine will complete it. Your input is: ")
    # zero pad the sentence to Tx characters.
    sentence = ('{0:0>' + str(maxlen) + '}').format(usr_input).lower()
    generated += usr_input 

    sys.stdout.write("\n\nHere is your poem: \n\n") 
    sys.stdout.write(usr_input)
    for i in range(400):

        x_pred = np.zeros((1, maxlen, len(chars)))

        for t, char in enumerate(sentence):
            if char != '0':
                x_pred[0, t, char_indices[char]] = 1.

        --> preds = tpu_model.predict(x_pred, batch_size = 128 ,workers = 8,verbose=0)[0]
        next_index = sample(preds, temperature = 1.0)
        next_char = indices_char[next_index]

        generated += next_char
        sentence = sentence[1:] + next_char

        sys.stdout.write(next_char)
        sys.stdout.flush()

        if next_char == '\n':
            continue

这对我来说毫无意义,因为我在训练时使用的批处理大小可以除以8,而我在预测函数中传递的批处理大小可以除以8。

我不确定问题是什么以及如何解决。任何帮助将非常感激。

1 个答案:

答案 0 :(得分:0)

由于错误:

function recordChanges() {
  var url = 'LINK TO MY SHEET';   
  var ss = SpreadsheetApp.openByUrl(url);
  var sh = ss.getSheetByName('Sheet1'); 
  var range = sh.getRange("A1"); 
  var firstRow = range.getDataRegion(SpreadsheetApp.Dimension.COLUMNS);
  var values = firstRow.getValues();
  values[0].push(values[0][0]);
  sh.getRange(1, 1, 1, firstRow.getLastColumn() + 1).setValues(values);
}

您似乎使用的batch_size为1,这可以从输入数据的第一个维度推断出来:

AssertionError: batch_size must be divisible by the number of TPU cores in use (1 vs 8)

我认为您可能希望将其更改为:

x_pred = np.zeros((1, maxlen, len(chars)))

因此批次大小将变为8,与使用中的TPU内核数匹配。

或者您也可以保持当前的batch_size为1,但使用1个TPU内核。