我的任务实际上很简单,但是我不知道如何实现它。我打算在我的ML算法中使用它,但让我们简化示例。假设有一个类似如下的生成器:
nums = ((i+1) for i in range(4))
以上内容将产生1
,2
,3
和4
。
假设上述生成器返回单个“样本”。我想写一个生成器方法来批量处理它们。假设批次大小为2
。因此,如果调用此新方法:
def batch_generator(batch_size):
do something on nums
yield batches of size batch_size
然后,此批处理生成器的输出将是:1
和2
,然后是3
和4
。元组/列表无关紧要。重要的是如何退回这些批次。我找到了Python 3.3中引入的yield from
关键字,但是对于我来说似乎没有用。
显然,如果我们用5
而不是4
来计算,并且batch_size
是2
,我们将忽略第一个生成器的最后一个产生的值。
答案 0 :(得分:3)
我自己的解决方案可能是
<?xml version="1.0" encoding="UTF-8"?>
<beans xmlns="http://www.springframework.org/schema/beans"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xmlns:jms="http://www.springframework.org/schema/jms"
xsi:schemaLocation="http://www.springframework.org/schema/beans
http://www.springframework.org/schema/beans/spring-beans-3.0.xsd
http://www.springframework.org/schema/jms
http://www.springframework.org/schema/jms/spring-jms-3.0.xsd
http://www.springframework.org/schema/beans/spring-beans-3.1.xsd
http://www.springframework.org/schema/aop
http://www.springframework.org/schema/aop/spring-aop-3.1.xsd
http://www.springframework.org/schema/context
http://www.springframework.org/schema/tx">
<context:component-scan base-package="demo" />
<bean id="connectionFactory" class="org.springframework.jndi.JndiObjectFactoryBean">
<property name="jndiTemplate" ref="jndiTemplate" />
<property name="jndiName" value="spring_jms_p2s_con" />
</bean>
<bean id="destination" class="org.springframework.jndi.JndiObjectFactoryBean">
<property name="jndiTemplate" ref="jndiTemplate" />
<property name="jndiName" value="spring_jms_p2p_des" />
</bean>
<bean id="chatMessageListener" class="demo.ChatMessageListener"></bean>
<jms:listener-container>
<listener>
<listener-class>org.springframework.web.context.ContextLoaderListener</listener-class>
</listener>
</jms:listener-container>
<jms:listener-container
connection-factory="spring_jms_p2s_con"
acknowledge="auto"
destination-type="topic">
<jms:listener destination="spring_jms_p2s_des"
ref="chatMessageListener" method="onMessage"/>
</jms:listener-container>
<bean id="jndiTemplate" class="org.springframework.jndi.JndiTemplate">
<property name="environment">
<props>
<prop key="java.naming.factory.initial">com.sun.enterprise.naming.SerialInitContextFactory
</prop>
<prop key="java.naming.factory.url.pkgs">com.sun.enterprise.naming</prop>
<prop key="java.naming.factory.state">com.sun.corba.ee.impl.presentation.rmi.JNDIStateFactoryImpl
</prop>
<prop key="org.omg.CORBA.ORBInitialHost">localhost</prop>
<prop key="org.omg.CORBA.ORBInitialPort">3700</prop>
</props>
</property>
</bean>
<bean id="jmsTemplate" class="org.springframework.jms.core.JmsTemplate">
<property name="connectionFactory" ref="connectionFactory" />
<property name="defaultDestination" ref="destination" />
<!-- Topic setting -->
<property name="pubSubDomain" value="true"/>
</bean>
</beans>
另一种解决方案是使用nums = (i+1 for i in range(4))
def giveBatch(gen, numOfItems):
try:
return [next(gen) for i in range(numOfItems)]
except StopIteration:
pass
giveBatch(nums, 2)
# [1, 2]
giveBatch(nums, 2)
# [3, 4]
,如@Bharel所述。我已经比较了运行这两种解决方案所需的时间。没有什么区别。我想可以忽略不计。
grouper
答案 1 :(得分:1)
在itertools下,您有一个代码段可以做到这一点:
from itertools import zip_longest
def grouper(iterable, n, fillvalue=None):
"Collect data into fixed-length chunks or blocks"
# grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx"
args = [iter(iterable)] * n
return zip_longest(*args, fillvalue=fillvalue)
您拥有一个迭代器,而不是每次都调用方法,它可以更高效,更快速地返回批处理,并且可以处理一些极端情况,例如过早用完数据而不会丢失数据。
答案 2 :(得分:0)
这正是我所需要的:
def giveBatch(numOfItems):
nums = (i+1 for i in range(7))
while True:
yield [next(nums) for i in range(numOfItems)]