通过模型对象可以在运行时访问已移植到TensorFlow.js的预训练Tensorflow模型的权重吗?

时间:2019-03-19 15:50:54

标签: javascript tensorflow tensorflow.js

在调试过程中如何访问模型的权重?

当我在调试器中执行过程中检查model.model.weights['dense_3/bias'][0]时,实际的权重不存在。但是,当我console.log表达式被打印时,权重被打印出来。似乎正在执行某种延迟执行?

我在下面基于有毒分类器medium article创建了一个片段,该片段显示了如何访问特定图层的权重对象。

const threshold = 0.9;

// Which toxicity labels to return.
const labelsToInclude = ['identity_attack', 'insult', 'threat'];

toxicity.load(threshold, labelsToInclude).then(model => {
    // Now you can use the `model` object to label sentences. 
    model.classify(['you suck']).then(predictions => {
    console.log("Specific weights: "+ model.model.weights['dense_3/bias'][0])
      document.getElementById("predictions").innerHTML =  JSON.stringify(predictions, null, 2);
    });
});
<!DOCTYPE html>
<html lang="en-us">
<head>
  <meta charset="UTF-8">
  <title>Activity 1: Basic HTML Bio</title>
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.0"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/toxicity"></script>
</head>

<body>
<div id="predictions">
Will be populated by prebuilt toxicity model
</div>
</body>

</html>

2 个答案:

答案 0 :(得分:1)

张量数组中的每一层。可以通过遍历数组来访问图层的权重。

const t = model.model.weights['dense_3/bias'][0] // t is a tensor
t.print() // will display the tensor in the console
// to add value to the weight
t.add(tf.scalar(0.5))

console.log(model.model.weights ['dense_3 / bias'] [0])将显示一个对象,而不显示张量的值。原因是张量是TypeScript中的类,该类在js中转换为Function类型的对象。这就是为什么console.log(model.model.weights['dense_3/bias'][0])将打印键为类张量的属性的对象的原因。需要调用print方法以查看张量的基础值

const threshold = 0.9;

// Which toxicity labels to return.
const labelsToInclude = ['identity_attack', 'insult', 'threat'];

toxicity.load(threshold, labelsToInclude).then(model => {
    // print weights
    model.model.weights['dense_3/bias'][0].print()
    // continue processing
});
<!DOCTYPE html>
<html lang="en-us">
<head>
  <meta charset="UTF-8">
  <title>Activity 1: Basic HTML Bio</title>
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.0"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/toxicity"></script>
</head>

<body>
</body>

</html>

如果要获取cpu上的张量值并使用dom元素的innerHTML显示它,可以考虑使用datadataSync

答案 1 :(得分:0)

从另一个post中我可以得出如何访问权重。

对于每一层都有一个data承诺将允许访问权重。

const threshold = 0.9;

// Which toxicity labels to return.
const labelsToInclude = ['identity_attack', 'insult', 'threat'];

toxicity.load(threshold, labelsToInclude).then(model => {
    // Now you can use the `model` object to label sentences. 
    model.classify(['you suck']).then(predictions => {
      model.model.weights['dense_3/bias'][0].data().then(
        function(value) {
        document.getElementById("specific_weights").innerHTML = JSON.stringify(value);
        });
      
      document.getElementById("predictions").innerHTML =  JSON.stringify(predictions, null, 2);
    });
});
<!DOCTYPE html>
<html lang="en-us">
<head>
  <meta charset="UTF-8">
  <title>Activity 1: Basic HTML Bio</title>
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.0"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/toxicity"></script>
</head>

<body>
<div id="predictions">
Will be populated by prebuilt toxicity model
</div>
<div id="specific_weights">
Will contain weights for specific layer
</div>

</body>

</html>