Neo4j OGM:获取匹配特定查询的节点数

时间:2018-07-26 04:35:34

标签: java neo4j spring-data-neo4j neo4j-ogm

我的图数据库(graphDb)中有 3 种类型的节点:Skill、SkillSubCluster 和 SkillCluster。一个 Skill 节点连接到一个或多个 SkillSubCluster 节点(一对多关系),一个 SkillSubCluster 节点连接到单个 SkillCluster 节点(一对一关系)。

我想根据 SkillCluster 的名称查找属于它的所有技能(Skill)节点。我写出了如下 Cypher 查询:

match(n:SkillCluster {CleanedText: "arts"}) match(n)<-[parent]-(x)<-[belongsTo]-(y) return y

节点数可能很大,因此我考虑返回分页结果,这可以通过 skip 和 limit 子句轻松实现。另外,我还想返回给定 SkillCluster 节点下的技能总数。

对应的 Cypher 查询如下:

match(n:SkillCluster {CleanedText: "arts"}) match(n)<-[parent]-(x)<-[belongsTo]-(y) return count(y)

我正在尝试使用针对Java的neo4j-ogm做同样的事情。

我的 Skill 类如下:

/**
 * OGM entity for a Skill node.
 *
 * NOTE(review): the capitalised field names (Name, CleanedText) presumably
 * mirror the node property keys; renaming them would change the OGM mapping
 * unless an explicit @Property(name = ...) is added. Confirm before refactoring.
 */
public class Skill {
    // Identifier marked @GeneratedValue, i.e. not assigned by application code.
    @Id @GeneratedValue
    private Long id;
    private String Name;
    private String CleanedText;

    // Outgoing "BelongsTo" relationships; a Set because one skill may belong
    // to several SkillSubCluster nodes.
    @Relationship(type = "BelongsTo", direction = Relationship.OUTGOING)
    private Set<SkillSubCluster> belongsTo = new HashSet<>();
}

其对应的 DAO 类:

public class SkillDAO extends GenericDAO<Skill>{
    public SkillDAO(Session session) {
        super(session);
    }

    protected Class<Skill> getEntityType() {
        return Skill.class;
    }   
}

以及我的通用 DAO 类:

/**
 * Base class for Neo4j OGM data-access objects of entity type {@code T}.
 * Subclasses supply the entity class through {@link #getEntityType()}.
 */
public abstract class GenericDAO<T> {
    // Load depth intended for list queries; not used by the methods below.
    private static final int DEPTH_LIST = 0;
    // Load/save depth for single entities (depth 1: the entity plus relationships one hop away).
    private static final int DEPTH_ENTITY = 1;
    private final Session session;

    /**
     * Counts entities of type T matching the given filters.
     *
     * @param filters OGM property filters to apply
     * @return number of matching entities
     */
    public long filterCount(Iterable<Filter> filters) {
        return session.count(getEntityType(), filters);
    }

    /**
     * Loads the entity with the given id at depth {@code DEPTH_ENTITY}.
     *
     * @param id entity id
     * @return the entity, or {@code null} if none has that id
     */
    public T find(Long id) {
        return session.load(getEntityType(), id, DEPTH_ENTITY);
    }

    /**
     * Loads an entity whose id is the given String, at depth {@code DEPTH_ENTITY}.
     *
     * NOTE(review): Session.load looks up by id, not by a "name" property, so
     * this only works for entities whose @Id is a String. To look up by a
     * property such as Name, use session.loadAll with a Filter instead.
     *
     * @param name id value to look up
     * @return the entity, or {@code null} if none has that id
     */
    public T find(String name) {
        return session.load(getEntityType(), name, DEPTH_ENTITY);
    }

    /**
     * Deletes the entity with the given id. Does nothing when no such entity
     * exists (Session.load returns null for an unknown id, and passing null
     * to Session.delete would fail).
     *
     * @param id id of the entity to delete
     */
    public void delete(Long id) {
        T entity = session.load(getEntityType(), id);
        if (entity != null) {
            session.delete(entity);
        }
    }

    /**
     * Inserts or updates the entity, persisting relationships up to depth
     * {@code DEPTH_ENTITY}.
     *
     * @param entity entity to save
     */
    public void createOrUpdate(T entity) {
        session.save(entity, DEPTH_ENTITY);
    }

    /** @return the entity class this DAO operates on */
    protected abstract Class<T> getEntityType();

    /**
     * @param session the OGM session used for every query issued by this DAO
     */
    public GenericDAO(Session session) {
        this.session = session;
    }
}

是否可以返回 Skill 类以外的对象,或者像在 Cypher 中那样获取复杂的 group by(聚合)查询的结果?

1 个答案:

答案 0 :(得分:1)

研究了一段时间后,我找到了正确的做法:在 GenericDAO 抽象类中添加一个执行任意 Cypher 查询的方法。(注:紧随其后的 Python 文件复制脚本与本问题无关,疑为误贴;需要添加的方法是下文的 runComplexQuery。)

    import os
    import time
    import shutil
    import multiprocessing
    from threading import *
    from multiprocessing import Process

    # Benchmark fixture: the same large file is copied from ``src`` to ``dest``
    # by every strategy below.
    src = 'D:\\05 软件\\winx64_12201_database.zip'
    dest = 'D:\\temp\\empty\\winx64_12201_database.zip'
    # Worker counts: one process per CPU for the multi-process copy,
    # a fixed ten threads for the multi-threaded copy.
    process_size = multiprocessing.cpu_count()
    thread_size = 10


    def copy(source=None, destination=None):
        """Copy a file in a single call to :func:`shutil.copy`.

        This is the baseline the streamed, threaded and multi-process copies
        are compared against.

        Args:
            source: Path of the file to copy; defaults to the module-level ``src``.
            destination: Target path; defaults to the module-level ``dest``.
        """
        shutil.copy(src if source is None else source,
                    dest if destination is None else destination)


    def copyStream(source=None, destination=None, chunk_size=1024):
        """Copy a file by streaming fixed-size chunks through Python.

        Any existing regular file at the destination is removed first. Both
        files are managed by ``with`` so they are closed even if a read or
        write fails part-way; the previous version leaked both handles on error.

        Args:
            source: Path of the file to copy; defaults to the module-level ``src``.
            destination: Target path; defaults to the module-level ``dest``.
            chunk_size: Bytes read per iteration (1024 keeps the original
                benchmark's behaviour).

        Raises:
            ValueError: If ``chunk_size`` is not positive (``read(0)`` would
                return nothing and silently produce an empty copy).
        """
        if chunk_size <= 0:
            raise ValueError("chunk_size must be positive")
        if source is None:
            source = src
        if destination is None:
            destination = dest
        if os.path.isfile(destination):
            os.remove(destination)
        with open(source, "rb") as reader, open(destination, "wb") as writer:
            while True:
                chunk = reader.read(chunk_size)
                if not chunk:
                    break
                writer.write(chunk)


    def multiCopy(source=None, destination=None, workers=None):
        """Copy a file with several threads, each copying one contiguous byte range.

        The destination is created (or truncated) once and resized to the
        source size *before* any worker starts. Each :class:`CopyWorker` then
        opens it in ``"r+b"`` mode and writes only its own range. Previously,
        every worker opened the destination with ``"wb"``, which truncated the
        data other workers had already written.

        Args:
            source: Path of the file to copy; defaults to the module-level ``src``.
            destination: Target path; defaults to the module-level ``dest``.
            workers: Number of threads; defaults to the module-level ``thread_size``.

        Raises:
            ValueError: If ``workers`` is less than 1.
        """
        if source is None:
            source = src
        if destination is None:
            destination = dest
        if workers is None:
            workers = thread_size
        if workers < 1:
            raise ValueError("workers must be at least 1")
        source_size = os.stat(source).st_size
        # Ceiling division so the ranges cover every byte; ``round(size / n)``
        # could round down and silently drop the tail of the file.
        block_size = -(-source_size // workers)
        # Truncate and pre-size exactly once, here in the coordinating thread.
        with open(destination, "wb") as target:
            target.truncate(source_size)
        threads = [
            CopyWorker(source_size, block_size, worker_id,
                       source=source, destination=destination)
            for worker_id in range(workers)
        ]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()


    class CopyWorker(Thread):
        """Thread that copies one contiguous byte range from source to destination.

        Worker ``tid`` copies bytes
        ``[tid * block_size, min((tid + 1) * block_size, source_size))``.
        The destination must already exist (``multiCopy`` creates and sizes
        it). It is opened in ``"r+b"`` mode so that the other workers' ranges
        are not truncated.
        """

        def __init__(self, source_size, block_size, tid, source=None, destination=None):
            """
            Args:
                source_size: Total size of the source file in bytes.
                block_size: Bytes per worker range.
                tid: Zero-based index of this worker's range.
                source: Source path; defaults to the module-level ``src``.
                destination: Destination path; defaults to the module-level ``dest``.
            """
            super().__init__()
            self.source_size = source_size
            self.block_size = block_size
            self.tid = tid
            self.source = src if source is None else source
            self.destination = dest if destination is None else destination

        def run(self):
            """Copy this worker's byte range in chunks of at most 1024 bytes."""
            # int() keeps accepting the float block sizes older callers passed.
            start = int(self.block_size * self.tid)
            end = min(int(self.block_size * (self.tid + 1)), self.source_size)
            remaining = end - start
            if remaining <= 0:
                return  # empty range: nothing to copy, and no need to open files
            with open(self.source, "rb") as reader, open(self.destination, "r+b") as writer:
                reader.seek(start)
                writer.seek(start)
                while remaining > 0:
                    chunk = reader.read(min(1024, remaining))
                    if not chunk:
                        break  # source shrank underneath us
                    writer.write(chunk)
                    remaining -= len(chunk)


    def copyMulti(source=None, destination=None, workers=None):
        """Copy a file with several processes, each copying one contiguous byte range.

        Fixes over the previous version:

        * the block size is derived from the *process* count (it used
          ``thread_size`` in the non-divisible case, so the ranges did not
          match the number of processes started);
        * ceiling division ensures the tail bytes are covered;
        * the destination is truncated and pre-sized once here, and workers
          open it ``"r+b"`` instead of ``"wb"`` (which wiped other workers' data);
        * a failed worker process is reported instead of silently ignored.

        Args:
            source: Path of the file to copy; defaults to the module-level ``src``.
            destination: Target path; defaults to the module-level ``dest``.
            workers: Number of processes; defaults to the module-level ``process_size``.

        Raises:
            ValueError: If ``workers`` is less than 1.
            RuntimeError: If any worker process exits with a non-zero code.
        """
        if source is None:
            source = src
        if destination is None:
            destination = dest
        if workers is None:
            workers = process_size
        if workers < 1:
            raise ValueError("workers must be at least 1")
        source_size = os.stat(source).st_size
        block_size = -(-source_size // workers)  # ceiling division
        with open(destination, "wb") as target:
            target.truncate(source_size)
        processes = []
        for process_id in range(workers):
            # Paths are passed explicitly so children do not rely on module globals.
            process = Process(target=copyWorker,
                              args=(source_size, block_size, process_id, source, destination))
            process.start()
            print(process.pid)
            processes.append(process)
        for process in processes:
            process.join()
        failed = [p.pid for p in processes if p.exitcode != 0]
        if failed:
            raise RuntimeError("copy worker process(es) failed: %s" % failed)


    def copyWorker(source_size, block_size, pid, source=None, destination=None):
        """Copy one contiguous byte range; the process target used by ``copyMulti``.

        Copies bytes ``[pid * block_size, min((pid + 1) * block_size, source_size))``
        in chunks of at most 1024 bytes. The destination must already exist.
        It is opened ``"r+b"`` so that ranges written by other workers are not
        truncated.

        Args:
            source_size: Total size of the source file in bytes.
            block_size: Bytes per worker range.
            pid: Zero-based index of this worker's range (not an OS process id).
            source: Source path; defaults to the module-level ``src``.
            destination: Destination path; defaults to the module-level ``dest``.
        """
        if source is None:
            source = src
        if destination is None:
            destination = dest
        # int() keeps accepting the float block sizes older callers passed.
        start = int(block_size * pid)
        end = min(int(block_size * (pid + 1)), source_size)
        remaining = end - start
        if remaining <= 0:
            return
        with open(source, "rb") as reader, open(destination, "r+b") as writer:
            reader.seek(start)
            writer.seek(start)
            while remaining > 0:
                chunk = reader.read(min(1024, remaining))
                if not chunk:
                    break  # source shrank underneath us
                writer.write(chunk)
                remaining -= len(chunk)



    if __name__ == '__main__':
        # Run each copy strategy in turn and report its elapsed time. This guard
        # also keeps multiprocessing children (which re-import this module
        # under the spawn start method) from re-running the benchmarks.
        # time.perf_counter() is the clock intended for measuring intervals
        # (monotonic, high resolution). time.time() is wall-clock time and can
        # jump if the system clock is adjusted.
        # NOTE(review): later runs may benefit from the OS file cache warmed by
        # earlier ones, which skews the comparison.
        benchmarks = (
            ("========== Single thread Copy ==========", copy),
            ("========== Single thread stream Copy ==========", copyStream),
            ("========== Multi threads stream Copy ==========", multiCopy),
            ("========== Multi processes stream Copy ==========", copyMulti),
        )
        for banner, benchmark in benchmarks:
            print(banner)
            started = time.perf_counter()
            benchmark()
            print("End:", time.perf_counter() - started, "\n")

然后通过以下代码获取计数:

public abstract class GenericDAO<T> {
    // Rest of the implementation from above 

    /**
     * Runs an arbitrary Cypher query without parameters and returns the raw
     * OGM result. This is useful for queries whose rows are not entities of
     * type T, such as a count or a grouped aggregation.
     *
     * @param query Cypher query text
     * @return the OGM query result
     */
    public Result runComplexQuery(String query) {
        return runComplexQuery(query, Collections.emptyMap());
    }

    /**
     * Runs a Cypher query with named parameters and returns the raw OGM result.
     * Prefer this overload to concatenating values (for example a cluster name
     * taken from a request) into the query text. Parameters are not
     * interpreted as Cypher, which avoids query injection.
     *
     * @param query      Cypher query text referencing named parameters
     * @param parameters parameter values keyed by name
     * @return the OGM query result
     */
    public Result runComplexQuery(String query, java.util.Map<String, ?> parameters) {
        return session.query(query, parameters);
    }

    // ..................
}