# Copyright (c) The University of Edinburgh 2014-2015
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MPI mapping for dispel4py workflows.

Every rank of ``MPI.COMM_WORLD`` calls :func:`process`. Rank 0 computes the
assignment of processing elements (PEs) to ranks and broadcasts it; each rank
then runs the PEs assigned to it inside an :class:`MPIWrapper`. Data items
travel between ranks as messages tagged ``STATUS_ACTIVE``; a message tagged
``STATUS_TERMINATED`` tells a consumer that one of its sources has finished.
"""

import types
from threading import Thread

try:
    import Queue  # Python 2
except ImportError:
    import queue as Queue  # Python 3: keep the historical module-level name

from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

from dispel4py.new.processor \
    import GenericWrapper, simpleLogger, STATUS_TERMINATED, STATUS_ACTIVE
from dispel4py.new import processor


def process(workflow, inputs, args):
    """Execute *workflow* across all ranks of ``MPI.COMM_WORLD``.

    Must be called collectively by every rank: it performs several
    ``comm.bcast`` calls rooted at rank 0.

    Rank 0 first tries to map the workflow graph directly onto ``size``
    processes with ``processor.assign_and_connect``. If that fails, or if
    ``args.simple`` is set, the partitioned graph produced by
    ``processor.create_partitioned`` is mapped instead. The resulting
    mappings and inputs are broadcast, and each rank then runs every PE
    whose assigned process list contains this rank.

    :param workflow: the workflow graph to execute.
    :param inputs: initial inputs keyed by PE; rank 0 converts the keys to
        PE ids before broadcasting (only rank 0's copy is used).
    :param args: parsed options; ``args.simple`` selects the partitioned
        mapping.
    :returns: ``None``. Returns early on every rank if no mapping could be
        found.
    """
    processes = {}
    inputmappings = {}
    outputmappings = {}
    success = True
    nodes = [node.getContainedObject() for node in workflow.graph.nodes()]

    # First attempt (rank 0 only): map the workflow graph as-is.
    if rank == 0 and not args.simple:
        try:
            processes, inputmappings, outputmappings = \
                processor.assign_and_connect(workflow, size)
        except Exception:
            success = False
    # All ranks must agree on whether to fall back to partitioning.
    success = comm.bcast(success, root=0)

    # Fallback (or explicit simple mode): run the partitioned graph instead.
    if args.simple or not success:
        ubergraph = processor.create_partitioned(workflow)
        nodes = [node.getContainedObject()
                 for node in ubergraph.graph.nodes()]
        if rank == 0:
            print('Partitions: %s' % ', '.join(
                ('[%s]' % ', '.join((pe.id for pe in part))
                 for part in workflow.partitions)))
            for node in ubergraph.graph.nodes():
                wrapperPE = node.getContainedObject()
                ns = [n.getContainedObject().id
                      for n in wrapperPE.workflow.graph.nodes()]
                print('%s contains %s' % (wrapperPE.id, ns))
            try:
                processes, inputmappings, outputmappings = \
                    processor.assign_and_connect(ubergraph, size)
                inputs = processor.map_inputs_to_partitions(ubergraph, inputs)
                success = True
            except Exception:
                print('dispel4py.mpi_process: '
                      'Not enough processes for execution of graph')
                success = False
        success = comm.bcast(success, root=0)

    if not success:
        return

    # Only rank 0's inputs survive the broadcast below, so convert them there
    # (other ranks' copies are overwritten and need not have PE keys).
    if rank == 0:
        inputs = {pe.id: v for pe, v in inputs.items()}
    processes = comm.bcast(processes, root=0)
    inputmappings = comm.bcast(inputmappings, root=0)
    outputmappings = comm.bcast(outputmappings, root=0)
    inputs = comm.bcast(inputs, root=0)

    if rank == 0:
        print('Processes: %s' % processes)

    # Run every PE that was assigned to this rank.
    for pe in nodes:
        if rank in processes[pe.id]:
            provided_inputs = processor.get_inputs(pe, inputs)
            wrapper = MPIWrapper(pe, provided_inputs)
            wrapper.targets = outputmappings[rank]
            wrapper.sources = inputmappings[rank]
            wrapper.process()


def receive(wrapper):
    """Reader-thread loop: move incoming MPI messages into *wrapper*'s queue.

    Receives from any source with any tag until ``wrapper.terminated``
    reaches ``wrapper._num_sources`` (presumably the number of upstream
    senders, set by ``GenericWrapper`` -- confirm). A ``STATUS_TERMINATED``
    message only increments the counter; any other message is queued as
    ``(msg, tag)``. Finally one ``(None, STATUS_TERMINATED)`` entry is queued
    to mark the end of the input stream for the consumer.
    """
    while wrapper.terminated < wrapper._num_sources:
        status = MPI.Status()
        msg = comm.recv(source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG, status=status)
        tag = status.Get_tag()
        if tag == STATUS_TERMINATED:
            wrapper.terminated += 1
        else:
            wrapper.input_data.put((msg, tag))
    # Put the final terminate marker into the queue.
    wrapper.input_data.put((None, STATUS_TERMINATED))


class MPIWrapper(GenericWrapper):
    """Runs one PE on this rank, exchanging data with other ranks over MPI.

    ``targets`` and ``sources`` are assigned by :func:`process` after
    construction. Incoming messages are collected by a background reader
    thread (:func:`receive`) into ``input_data``.
    """

    def __init__(self, pe, provided_inputs=None):
        """Wrap *pe*, attaching a logger and this rank to it.

        :param pe: the processing element to run.
        :param provided_inputs: initial inputs for *pe*, if any.
        """
        GenericWrapper.__init__(self, pe)
        self.pe.log = types.MethodType(simpleLogger, pe)
        self.pe.rank = rank
        self.provided_inputs = provided_inputs
        # Number of STATUS_TERMINATED messages received so far.
        self.terminated = 0
        # Filled by the reader thread, drained by _read().
        self.input_data = Queue.Queue()

    def process(self):
        """Start the reader thread, then run the generic processing loop."""
        self.reader = Thread(target=receive, args=(self,))
        self.reader.start()
        super(MPIWrapper, self).process()

    def _read(self):
        """Return the next input.

        Uses the base-class result when it is not ``None``; otherwise blocks
        on the queue filled by the reader thread.
        """
        result = super(MPIWrapper, self)._read()
        if result is not None:
            return result
        return self.input_data.get()

    def _write(self, name, data):
        """Send *data* produced on output *name* to every destination rank.

        Outputs with no entry in ``self.targets`` are silently dropped. Each
        send is waited on before the next one is started.
        """
        try:
            targets = self.targets[name]
        except KeyError:
            # No consumers are connected to this output.
            return
        for (inputName, communication) in targets:
            output = {inputName: data}
            dest = communication.getDestination(output)
            for i in dest:
                request = comm.isend(output, tag=STATUS_ACTIVE, dest=i)
                request.Wait()

    def _terminate(self):
        """Finish this PE's participation in the stream.

        Waits for the reader thread to exit, then sends a
        ``STATUS_TERMINATED`` marker to every destination rank of every
        output and waits for those sends to complete.
        """
        self.reader.join()
        requests = []
        for output, targets in self.targets.items():
            for (inputName, communication) in targets:
                for i in communication.destinations:
                    requests.append(
                        comm.isend(None, tag=STATUS_TERMINATED, dest=i))
        # Every non-blocking request must be completed before MPI is
        # finalized; previously these requests were discarded unfinished.
        MPI.Request.Waitall(requests)