我想用 PyTorch 代替 Keras，但我自己做不到。
<script src="https://cdn.jsdelivr.net/npm/vega@5"></script>
<script src="https://cdn.jsdelivr.net/npm/vega-lite@3"></script>
<script src="https://cdn.jsdelivr.net/npm/vega-embed@4"></script>
<script type="module">
// Vega (v5 schema) specification for a bar chart over an inline data table.
// Hovering a bar turns it red and shows its amount in a text label.
const barChart = {
  $schema: 'https://vega.github.io/schema/vega/v5.json',
  width: 400,
  height: 200,
  padding: 5,

  // Inline dataset: one row per bar.
  data: [
    {
      name: 'table',
      values: [
        { category: 'A', amount: 28 },
        { category: 'B', amount: 55 },
        { category: 'C', amount: 43 },
        { category: 'D', amount: 91 },
        { category: 'E', amount: 81 },
        { category: 'F', amount: 53 },
        { category: 'G', amount: 19 },
        { category: 'H', amount: 87 },
      ],
    },
  ],

  // `tooltip` is set to the hovered rect's datum and reset to {} on mouseout.
  signals: [
    {
      name: 'tooltip',
      value: {},
      on: [
        { events: 'rect:mouseover', update: 'datum' },
        { events: 'rect:mouseout', update: '{}' },
      ],
    },
  ],

  scales: [
    {
      name: 'xscale',
      type: 'band',
      domain: { data: 'table', field: 'category' },
      range: 'width',
      padding: 0.05,
      round: true,
    },
    {
      name: 'yscale',
      domain: { data: 'table', field: 'amount' },
      nice: true,
      range: 'height',
    },
  ],

  axes: [
    { orient: 'bottom', scale: 'xscale' },
    { orient: 'left', scale: 'yscale' },
  ],

  marks: [
    // One rect per row of `table`.
    {
      type: 'rect',
      from: { data: 'table' },
      encode: {
        enter: {
          x: { scale: 'xscale', field: 'category' },
          width: { scale: 'xscale', band: 1 },
          y: { scale: 'yscale', field: 'amount' },
          y2: { scale: 'yscale', value: 0 },
        },
        update: {
          fill: { value: 'steelblue' },
        },
        hover: {
          fill: { value: 'red' },
        },
      },
    },
    // Text label whose position and content are driven by the `tooltip` signal.
    {
      type: 'text',
      encode: {
        enter: {
          align: { value: 'center' },
          baseline: { value: 'bottom' },
          fill: { value: '#333' },
        },
        update: {
          x: { scale: 'xscale', signal: 'tooltip.category', band: 0.5 },
          y: { scale: 'yscale', signal: 'tooltip.amount', offset: -2 },
          text: { signal: 'tooltip.amount' },
          // Opacity is 0 when the test expression holds, 1 otherwise.
          fillOpacity: [
            { test: 'datum === tooltip', value: 0 },
            { value: 1 },
          ],
        },
      },
    },
  ],
};
/**
 * `<vega-element>` custom element that renders the `barChart` Vega spec
 * into a container `<div>` created inside the element.
 */
class VegaElement extends HTMLElement {
  /**
   * Runs when the element is attached to the document: (re)creates the
   * container and embeds the chart into it.
   */
  connectedCallback() {
    // The container is created here rather than in the constructor: a custom
    // element constructor must not add children (doing so makes
    // document.createElement('vega-element') fail).
    this.innerHTML = "<div id='view'></div>";

    // Look the container up inside this element and hand the node itself to
    // vegaEmbed; a document-wide '#view' selector would always resolve to the
    // first instance on the page.
    const container = this.querySelector('#view');

    // vegaEmbed returns a promise; report failures instead of leaving the
    // rejection unhandled.
    vegaEmbed(container, barChart).catch((err) => {
      console.error('vega-element: failed to render chart', err);
    });
  }
}
customElements.define('vega-element', VegaElement);
</script>
<vega-element></vega-element>
def _model(self):
    """Build and compile a small fully connected Keras network.

    The network takes ``self.state_size`` inputs, passes them through three
    ReLU hidden layers (64, 32 and 8 units) and ends in a linear layer with
    ``self.action_size`` outputs. It is compiled with mean-squared-error
    loss and the Adam optimizer (``lr=0.001``).

    NOTE(review): presumably used as a Q-value approximator in an RL agent,
    judging by ``state_size``/``action_size`` -- confirm against the caller.

    Returns:
        The compiled ``Sequential`` model.
    """
    # Layers in stacking order: hidden ReLU layers, then the linear head.
    layers = [
        Dense(units=64, input_dim=self.state_size, activation="relu"),
        Dense(units=32, activation="relu"),
        Dense(units=8, activation="relu"),
        Dense(self.action_size, activation="linear"),
    ]

    network = Sequential()
    for layer in layers:
        network.add(layer)

    network.compile(loss="mse", optimizer=Adam(lr=0.001))
    return network
答案 0 :(得分:1)
model = Model()
调用 `Model()` 时必须传入一个参数，因为 `__init__(self, input_dims)` 需要一个参数。
应该是这样:
model = Model(<integer dimensions for your network>)