#!/usr/bin/env python

import os
import sys
import argparse
import re

class HelpOnErrorParser(argparse.ArgumentParser):
    """Argument parser that prints the full help text on a parsing error."""

    def error(self, msg):
        """Write the error to stderr, print the help text and exit with -1.

        The message is prefixed with the base name of the running script
        (taken from sys.argv[0]).
        """
        script_name = os.path.basename(sys.argv[0])
        sys.stderr.write("{}: error: {}\n\n".format(script_name, msg))
        self.print_help()
        sys.exit(-1)

def my_assert(bool, msg):
    """Terminate the program when a required condition does not hold.

    If `bool` is truthy nothing happens. Otherwise `msg` followed by a
    newline is written to stderr and the process exits with status -1.
    """
    if bool:
        return
    sys.stderr.write(msg + "\n")
    sys.exit(-1)

def findTag(tag, attributes):
    """Return the value of attribute `tag` in a GFF3 attributes column.

    `attributes` is the semicolon-separated list of key=value pairs from the
    ninth GFF3 column. The key has to match `tag` exactly, so looking up
    "ID" is not fooled by a key such as "geneID" or by the text "ID="
    occurring inside another attribute's value.

    Returns the raw value string (possibly empty) of the first matching
    attribute, or None when no attribute with that key is present.
    """
    for field in attributes.split(";"):
        key, sep, value = field.partition("=")
        # Tolerate whitespace after the ';' separator, which some files emit.
        if sep and key.lstrip() == tag:
            return value
    return None

# Command-line interface: input and output paths are positional; the RNA
# feature types to extract are optional and default to "mRNA".
parser = HelpOnErrorParser(
    description="Convert GFF3 files to GTF files.",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
    "input_GFF3_file",
    help="Input GFF3 file.",
)
parser.add_argument(
    "output_GTF_file",
    help="Output GTF file.",
)
parser.add_argument(
    "--RNA-patterns",
    default="mRNA",
    help="Types of RNAs to be extracted.",
)

args = parser.parse_args()

# Lookup tables filled in while the GFF3 file is scanned.
trans2gene = dict()  # transcript ID -> parent gene ID
gid2name = dict()  # gene ID -> gene name
tid2name = dict()  # transcript ID -> transcript name

# Field separator: a run of one or more tab characters.
rgx = re.compile("[\t]+")
# Anchored alternation of the comma-separated --RNA-patterns entries; a
# feature type must match one of them in full to count as a transcript.
rgx2 = re.compile("^(" + args.RNA_patterns.replace(",", "|") + ")$")

# Single pass over the GFF3 input. Gene and transcript records fill the
# lookup tables; each exon record is written out once per parent transcript
# that has already been seen. Exons whose parent is unknown at that point
# are skipped. The with-statement closes both files on every exit path,
# including the sys.exit raised by my_assert.
with open(args.input_GFF3_file, "r") as fin, open(args.output_GTF_file, "w") as fout:
    for line_no, line in enumerate(fin, 1):
        line = line.rstrip("\r\n")  # remove return characters

        # Everything after the ##FASTA directive is sequence data.
        if line.startswith("##FASTA"):
            break
        if line.startswith("#"):
            if line.startswith("###"):
                # clear all dictionaries
                trans2gene = {}
                gid2name = {}
                tid2name = {}
            continue
        # The GFF3 specification tells parsers to ignore blank lines, so do
        # not let them trip the 9-field check below.
        if not line.strip():
            continue

        arr = rgx.split(line)
        my_assert(len(arr) == 9, "Line {} does not have 9 fields:\n{}".format(line_no, line))

        feature_type = arr[2]
        attributes = arr[8]

        if feature_type == "exon":
            parent = findTag("Parent", attributes)
            my_assert(parent is not None, "Line {} does not have a Parent tag:\n{}".format(line_no, line))

            # The first eight columns are copied unchanged into the GTF line.
            columns = "\t".join(arr[:-1])
            # An exon may list several parents separated by commas.
            for tid in parent.split(","):
                if tid not in trans2gene:
                    continue
                gid = trans2gene[tid]
                record = ["{}\tgene_id \"{}\"; transcript_id \"{}\";".format(columns, gid, tid)]
                gene_name = gid2name.get(gid)
                if gene_name is not None:
                    record.append(" gene_name \"{}\";".format(gene_name))
                transcript_name = tid2name.get(tid)
                if transcript_name is not None:
                    record.append(" transcript_name \"{}\";".format(transcript_name))
                record.append("\n")
                fout.write("".join(record))

        elif feature_type == "gene":
            gid = findTag("ID", attributes)
            name = findTag("Name", attributes)

            my_assert(gid is not None, "Line {} does not have a ID tag:\n{}".format(line_no, line))

            if name is not None:
                gid2name[gid] = name

        elif rgx2.search(feature_type) is not None:
            tid = findTag("ID", attributes)
            gid = findTag("Parent", attributes)
            name = findTag("Name", attributes)

            my_assert(tid is not None, "Line {} does not have a ID tag:\n{}".format(line_no, line))
            my_assert(gid is not None, "Line {} does not have a Parent tag:\n{}".format(line_no, line))

            trans2gene[tid] = gid
            if name is not None:
                tid2name[tid] = name
