跨类

时间:2018-05-26 03:10:07

标签: python tensorflow graph namespaces

此处提供了问题的要点:https://gist.github.com/sankhaMukherjee/97c212072385ee36ebd8bfab9a801794

基本上,我正在尝试生成一个通用类,可以用作我的张量流项目的样板。提供的要点是我能够提出的最微不足道的例子。我知道这与名称空间有关,但我不确定如何解决这个问题。

我有一个类Test,其中包含以下成员函数:

  1. __init__(self, inpShape, outShape, layers, activations)
  2. saveModel(self, sess)
  3. restoreModel(self, sess, restorePoint)
  4. fit(self, X, y, Niter=101, restorePoint=None)
  5. predict(self, X, restorePoint=None)
  6. 这些功能本身很简单,可以在提供的要点中查找。现在,给定这个类,我们可以尝试测试它,看看它是如何做的:

    X = np.random.random((10000, 2))
    y = (2*X[:, 0] + 3*X[:, 1]).reshape(-1, 1)
    
    inpShape    = (None, 2)
    outShape    = (None, 1)
    layers      = [7, 1]
    activations = [tf.sigmoid, None]
    
    t = Test(inpShape, outShape, layers, activations)
    t.fit(X, y, 10000)
    yHat = t.predict(X, restorePoint=t.restorePoints[-1])
    
    plt.plot(yHat, y, '.', label='original')
    

    这一切都很好!

    现在我们要创建同一个类的另一个实例,并恢复从此处保存的模型。所有的地狱都在这里松动。让我们更新以上内容:

    X = np.random.random((10000, 2))
    y = (2*X[:, 0] + 3*X[:, 1]).reshape(-1, 1)
    
    inpShape    = (None, 2)
    outShape    = (None, 1)
    layers      = [7, 1]
    activations = [tf.sigmoid, None]
    
    t = Test(inpShape, outShape, layers, activations)
    t.fit(X, y, 10000)
    yHat = t.predict(X, restorePoint=t.restorePoints[-1])
    
    plt.plot(yHat, y, '.', label='original')
    
    if True: # In the gist, turn this to True for seeing the problem
        t1 = Test(inpShape, outShape, layers, activations)
        yHat1 = t1.predict(X, restorePoint=t.restorePoints[-1])
        plt.plot(yHat1, y, '.', label='copied')
    
    事实证明,我们不能再这样做了。它将完全搞乱所有东西,并带有一个全新的图形。现在是否可以创建复制旧图的类的新实例,而不是创建旧图的全新实例?

0 个答案:

没有答案