import logging
import luigi
import sciluigi as sl
import os
import six.moves as s
import time
import unittest

# Single-file output path constant; not referenced by the tests visible in this chunk.
TESTFILE_PATH = '/tmp/test.out'

# Presumably the logger sciluigi uses for its interface layer (name-based guess — verify);
# raised to WARNING so INFO/DEBUG chatter does not clutter test output.
log = logging.getLogger('sciluigi-interface')
log.setLevel(logging.WARNING)

class MultiInOutWf(sl.WorkflowTask):
    '''Two-step workflow: MultiOutTask writes ten files, MultiInTask rewrites each one.'''

    def workflow(self):
        '''Connect the producer's list output to the consumer and return the consumer.'''
        producer = self.new_task('mout', MultiOutTask, an_id='x')
        consumer = self.new_task('min', MultiInTask)
        # Hand over the bound method itself (not its result); MultiInTask calls it.
        consumer.in_multi = producer.out_multi
        return consumer

class MultiOutTask(sl.Task):
    '''Write ten small text files, each containing the text 'hej'.

    File names embed the ``an_id`` parameter, so instances with different
    ids write to different paths under /tmp.
    '''
    an_id = luigi.Parameter()

    def out_multi(self):
        '''Return one TargetInfo per output file: /tmp/out_<an_id>_<0..9>.txt.'''
        targets = []
        for idx in s.range(10):
            path = '/tmp/out_%s_%d.txt' % (self.an_id, idx)
            targets.append(sl.TargetInfo(self, path))
        return targets

    def run(self):
        '''Write 'hej' into every target returned by out_multi().'''
        for target in self.out_multi():
            with target.open('w') as handle:
                handle.write('hej')

class MultiInTask(sl.Task):
    '''Copy each input file to "<input path>.daa.txt", appending ' daa'.

    ``in_multi`` must be wired to a callable returning a list of targets
    (out_multi() and run() both call it), e.g. ``MultiOutTask.out_multi``.
    '''
    # Filled in by the enclosing workflow before the task is scheduled.
    in_multi = None

    def out_multi(self):
        '''Return one output TargetInfo per input target, at "<input path>.daa.txt".'''
        sources = self.in_multi()
        return [sl.TargetInfo(self, src.path + '.daa.txt') for src in sources]

    def run(self):
        '''Read each input, append ' daa', and write it to the matching output.'''
        for source, dest in zip(self.in_multi(), self.out_multi()):
            with source.open() as reader, dest.open('w') as writer:
                writer.write(reader.read() + ' daa')

class TestMultiInOutWorkflow(unittest.TestCase):
    '''Tests for sciluigi tasks whose inputs and outputs are lists of targets.'''

    # an_id values MultiInOutWf passes to MultiOutTask, and the number of files
    # each MultiOutTask writes (mirrors s.range(10) in MultiOutTask.out_multi).
    _WF_IDS = ['x']
    _N_FILES = 10

    def _out_paths(self):
        '''Return the paths MultiOutTask writes when run via MultiInOutWf.'''
        return ['/tmp/out_%s_%d.txt' % (aid, i)
                for i in s.range(self._N_FILES) for aid in self._WF_IDS]

    def _daa_paths(self):
        '''Return the paths MultiInTask derives from the MultiOutTask outputs.'''
        return [p + '.daa.txt' for p in self._out_paths()]

    def _remove_outputs(self):
        '''Delete any workflow output files on disk; missing files are ignored.'''
        for p in self._out_paths() + self._daa_paths():
            if os.path.exists(p):
                os.remove(p)

    def setUp(self):
        self.w = luigi.worker.Worker()
        # Luigi skips tasks whose outputs already exist, so leftovers from an
        # earlier aborted run would let test_workflow pass without running
        # anything. Start from a clean slate.
        self._remove_outputs()

    def test_methods(self):
        wf = sl.WorkflowTask()
        # Each task gets its own instance name (previously all were 'tout').
        touta = wf.new_task('touta', MultiOutTask, an_id='a')
        toutb = wf.new_task('toutb', MultiOutTask, an_id='b')
        toutc = wf.new_task('toutc', MultiOutTask, an_id='c')
        tin = wf.new_task('tin', MultiInTask)

        # Mixes bound methods, a dict, and an already-evaluated target list;
        # requires() must still yield only luigi Tasks for all of them.
        tin.in_multi = [touta.out_multi, {'a': toutb.out_multi, 'b': toutc.out_multi()}]

        # Assert outputs returns luigi targets, or list of luigi targets
        outs = touta.output()
        self.assertIsInstance(outs, list)
        for out in outs:
            self.assertIsInstance(out, luigi.Target)

        reqs = tin.requires()
        self.assertIsInstance(reqs, list)
        for req in reqs:
            self.assertIsInstance(req, luigi.Task)

    def test_workflow(self):
        wf = MultiInOutWf()
        self.w.add(wf)
        self.w.run()

        # Both stages must have produced every expected file.
        for p in self._out_paths() + self._daa_paths():
            self.assertTrue(os.path.exists(p), 'missing output file: %s' % p)

    def tearDown(self):
        # Always clean up, including when an assertion above failed.
        self._remove_outputs()