#!/usr/bin/python
# CPSSP -- Compare Protein Secondary Structure Predictions
# v1.0                                            20090606
# Copyright (C) 2009 by Wolfgang Skala
#
# This work may be distributed and/or modified under the
# conditions of the LaTeX Project Public License, either version 1.3
# of this license or (at your option) any later version.
# The latest version of this license is in
#   http://www.latex-project.org/lppl.txt
# and version 1.3 or later is part of all distributions of LaTeX
# version 2005/12/01 or later.

import getopt, sys



### 1. FUNCTIONS

def readFasta(filename, alphabet):
	# reads a FASTA file containing one or several sequences/alignments/predictions
	# input: * (filename) the name of the FASTA file
	#        * (alphabet) the allowed characters in the file, e.g. amino acids;
	#                     every character is upper-cased before the check and anything
	#                     that is not in the alphabet (line breaks included) is skipped
	# output: * (names) a list of strings corresponding to the sequence names
	#         * (allseqs) a list containing the sequences;
	#                     each sequence is a list of characters, where each character represents a
	#                     residue, gap or secondary structure element
	# note: a header that is not followed by any allowed character adds an entry to
	#       (names) but none to (allseqs)
	# raises: IOError if the file cannot be opened or read
	allseqs = []
	names = []
	seq = []
	# the with-statement closes the file even if reading fails half-way
	with open(filename, "r") as f:
		for line in f:
			if line.startswith(">"):
				if seq:
					allseqs.append(seq)
					seq = []
				# strip only the line terminator, so that a header on the last line
				# (no trailing newline) keeps its last character and Windows line
				# endings do not leave a "\r" in the name
				names.append(line[1:].rstrip("\r\n"))
			else:
				for c in line:
					c = c.upper()
					if c in alphabet:
						seq.append(c)
	if seq:
		allseqs.append(seq)
	return names, allseqs


def removeGaps(allSeqs):
	# drops every alignment column that holds a gap symbol in all sequences
	# input: (allSeqs) a list of sequences as returned by readFasta();
	#                  the sequence lists are modified in place
	# output: the same list object with the all-gap columns removed
	# note: the columns are taken from the first sequence; a sequence that is shorter
	#       than the first one raises an IndexError
	nColumns = len(allSeqs[0])
	# walk from right to left so that deleting a column does not shift the ones still to visit
	for col in reversed(range(nColumns)):
		# build the complete list (no short-circuit) so that every sequence is indexed
		gapFlags = [seq[col] in GAP_SYMBOLS for seq in allSeqs]
		if all(gapFlags):
			for seq in allSeqs:
				del seq[col]
	return allSeqs


def commonSyntax(allStrucs):
	# unify the coil notation: "-" and " " are both rewritten to "C"
	# input: (allStrucs) list of structures as returned by readFasta;
	#                    the structure lists are modified in place
	# output: the same list object (same format as input)
	coilAliases = ("-", " ")
	for struc in allStrucs:
		for pos, element in enumerate(struc):
			if element in coilAliases:
				struc[pos] = "C"
	return allStrucs


def addGaps(allSeqs, allStrucs):
	# add gaps to the secondary structures so that they correspond to the gapped sequences
	# input: * (allSeqs) list of sequences as returned by removeGaps() OR None,
	#                    indicating that no gaps should be added
	#        * (allStrucs) list of structures as returned by readFasta/commonSyntax
	# output: list of gapped structures; each structure is a string
	# raises: IndexError if a structure is missing or has fewer elements than its
	#         sequence has residues
	# note: structure elements beyond the last residue of the sequence are ignored
	if allSeqs is None:
		# nothing to align against: just turn the character lists into strings
		return ["".join(struc) for struc in allStrucs]

	resultStrucs = []
	for i, seq in enumerate(allSeqs):
		k = 0 # index of the next unused element of structure i
		pieces = []
		for res in seq:
			if res in GAP_SYMBOLS:
				pieces.append("-")
			else:
				pieces.append(allStrucs[i][k])
				k += 1
		resultStrucs.append("".join(pieces))
	return resultStrucs


def breakLines(allStrucs, n):
	# break the structures into lines of at most (n) characters
	# if a line ends with a sheet (E) and the next line starts with a sheet, the end letter is
	# changed to "e" which indicates that no arrowhead should be drawn in the graphical
	# representation
	# input: (allStrucs) list of gapped structures (strings) as returned by addGaps()
	#        (n) residues per line; must be at least 1
	# output: list of structures; each structure is a list of 3-tuples containing (1) a string
	#         which specifies the residues on the line, (2) the number of residues (characters
	#         other than "-") that precede the line and (3) the number of residues up to and
	#         including the line
	# raises: ValueError if (n) is smaller than 1
	if n < 1:
		raise ValueError("the number of residues per line must be at least 1")
	resultStrucs = []
	for struc in allStrucs:
		curStruc = []
		startRes = 0
		# stepping through the offsets avoids the version-dependent "/" division
		for offset in range(0, len(struc), n):
			nextOffset = offset + n
			s = struc[offset:nextOffset]
			if s[-1] == "E" and nextOffset < len(struc) and struc[nextOffset] == "E":
				# the sheet continues on the next line
				s = s[:-1] + "e"
			endRes = startRes + len(s) - s.count("-")
			curStruc.append((s, startRes, endRes))
			startRes = endRes
		resultStrucs.append(curStruc)
	return resultStrucs


def makeTikzDraw(ssType, block, line, start, end):
	# build the macro call that draws one secondary structure element
	# input: * (ssType) secondary structure type; B, C, E, e, G, H, I, S and T each select
	#                   their own macro, every other value (e.g. "-") selects \cpsspGap
	#        * (block) the current sequence block
	#        * (line) the current line
	#        * (start) the start position
	#        * (end) the end position
	# output: a string of the form \macro{-y}{xStart}{xEnd}, preceded by two tabs and
	#         followed by a line break
	# note: reads the module-level layout variables blockDistance, nStruc, lineDistance,
	#       lineIndent and resWidth
	macros = {
		"B": "cpsspBridge",
		"C": "cpsspCoil",
		"E": "cpsspSheet",
		"e": "cpsspSheetT",
		"G": "cpsspThreeTenHelix",
		"H": "cpsspAlphaHelix",
		"I": "cpsspPiHelix",
		"S": "cpsspBend",
		"T": "cpsspTurn",
	}
	macro = macros.get(ssType, "cpsspGap")
	yPos = block * (blockDistance + nStruc * lineDistance) + line * lineDistance
	xStart = lineIndent + resWidth * start
	xEnd = lineIndent + resWidth * end
	return "\t\t\\" + macro + "{-" + str(yPos) + "}{" + str(xStart) + "}{" + str(xEnd) + "}\n"


def makeTikzLabel(text, block, line):
	# build the macro call that draws the label of a line
	# input: * (text) the label text
	#        * (block) the current sequence block
	#        * (line) the current line
	# output: a string of the form \cpsspLabel{-y}{text}, preceded by one tab and
	#         followed by a line break
	# note: reads the module-level layout variables blockDistance, nStruc and lineDistance
	yPos = block * (blockDistance + nStruc * lineDistance) + line * lineDistance
	return "\t\\cpsspLabel{-" + str(yPos) + "}{" + text + "}\n"


def makeTikzRes(number, block, line, isStart, pos=None):
	# build the macro call that draws a residue number at the start or the end of a line
	# input: * (number) the residue number
	#        * (block) the current sequence block
	#        * (line) the current line
	#        * (isStart) True selects \cpsspStartRes placed at the line indent,
	#                    False selects \cpsspEndRes placed at position (pos)
	#        * (pos) unused for the start residue; for the end residue, it indicates the x position
	# output: a string of the form \macro{-y}{x}{number}, preceded by two tabs and
	#         followed by a line break
	# note: reads the module-level layout variables blockDistance, nStruc, lineDistance,
	#       lineIndent and resWidth
	if isStart:
		macro = "\\cpsspStartRes"
	else:
		macro = "\\cpsspEndRes"
	yPos = block * (blockDistance + nStruc * lineDistance) + line * lineDistance
	if isStart:
		xPos = lineIndent
	else:
		xPos = lineIndent + resWidth * pos
	return "\t\t" + macro + "{-" + str(yPos) + "}{" + str(xPos) + "}{" + str(number) + "}\n"


def usage():
	# print usage of the program to standard output
	# the list covers every option accepted by the getopt call in the main part,
	# including -v/--version
	print("""CPSSP -- Compare Protein Secondary Structure Predictions v1.0
Usage: cpssp
-h or --help prints this message
-v or --version prints the program version
-s or --sequence-file (FASTA file containing the sequences)
-u or --structure-file (FASTA file containing the structures; mandatory)
-o or --output-file (output filename without extension and numbering)
-w or --image-width (width of the image in cm)
-t or --image-height (maximal height of an image in cm)
-i or --line-indent (indentation at the beginning of the line in cm)
-r or --residues-per-line (number of residues per line)
-l or --line-distance (distance between lines in cm)
-b or --block-distance (distance between blocks in cm)""")


def version():
	# print the program version, copyright and license notice to standard output
	notice = "\n".join((
		"CPSSP 1.0",
		"Copyright (C) 2009 Wolfgang Skala",
		"License LPPL v1.3c: The LaTeX project public license version 1.3c <http://www.latex-project.org/lppl.txt>",
		"This is free software: you are free to change and redistribute it.",
		"There is NO WARRANTY, to the extent permitted by law.",
	))
	print(notice)


### 2. CONSTANTS AND VARIABLES

# alphabets used when reading the FASTA files
AMINO_ACIDS = "ARNDCQEGHILKMFPSTWYV-."  # characters kept from the sequence file: 20 amino acid letters and the gap symbols
SS_ELEMENTS = "BCEGHIST- "              # characters kept from the structure file; commonSyntax() turns "-" and " " into "C"
GAP_SYMBOLS = "-."                      # possible gap symbols

# input and output files; the command line options override these values
sequenceFile = None
structureFile = None
outputFile = "cpsspresult"  # base name of the output files; an index and ".tex" are appended

# layout defaults; the command line options override these values
imageWidth = 15      # total line width (in cm)
imageHeight = 20     # maximal height of the image (in cm); it will be split into separate files
                     # if its natural height exceeds this value; 0 indicates an arbitrary height
lineIndent = 2.5     # indentation at the left of each line (in cm)
resPerLine = 50      # number of residues per line
lineDistance = 0.5   # distance between sequences within one line (in cm)
blockDistance = 1    # distance between sequence blocks (in cm)

	


### 3. MAIN PART

# process command line options
# every option that takes a value overrides the corresponding module-level default;
# an unknown option prints the usage text and exits with status 2, a value that cannot
# be converted prints a message and exits with status 1
try:
	opts, args = getopt.getopt(
		sys.argv[1:],
		"vhs:u:w:t:r:i:l:b:o:",
		["version", "help", "sequence-file=", "structure-file=", "image-width=",
		"image-height=", "residues-per-line=", "line-indent=", "line-distance=",
		"block-distance=", "output-file="])
except getopt.GetoptError:
	usage()
	sys.exit(2)

for opt, arg in opts:
	if opt in ("-h", "--help"):
		usage()
		sys.exit()
	elif opt in ("-v", "--version"):
		version()
		sys.exit()
	elif opt in ("-s", "--sequence-file"):
		sequenceFile = arg
	elif opt in ("-u", "--structure-file"):
		structureFile = arg
	elif opt in ("-o", "--output-file"):
		outputFile = arg
	elif opt in ("-w", "--image-width"):
		try:
			imageWidth = float(arg)
		except ValueError:
			print("Invalid image width.")
			sys.exit(1)
	elif opt in ("-t", "--image-height"):
		try:
			imageHeight = float(arg)
		except ValueError:
			print("Invalid image height.")
			sys.exit(1)
	# getopt reports short options with their dash, so the tuple must contain "-r"
	elif opt in ("-r", "--residues-per-line"):
		try:
			resPerLine = int(arg)
			# the residue width is computed by dividing by this value, so zero and
			# negative numbers are rejected like unparsable input
			if resPerLine < 1:
				raise ValueError(arg)
		except ValueError:
			print("Invalid number of residues per line.")
			sys.exit(1)
	elif opt in ("-i", "--line-indent"):
		try:
			lineIndent = float(arg)
		except ValueError:
			print("Invalid line indentation.")
			sys.exit(1)
	elif opt in ("-l", "--line-distance"):
		try:
			lineDistance = float(arg)
		except ValueError:
			print("Invalid line distance.")
			sys.exit(1)
	elif opt in ("-b", "--block-distance"):
		try:
			blockDistance = float(arg)
		except ValueError:
			print("Invalid block distance.")
			sys.exit(1)

if structureFile is None:
	# the structure file is mandatory
	usage()
	sys.exit(2)
elif sequenceFile is None:
	# compare the predictions from multiple programs for a single protein
	# open FASTA file
	try:
		seqNames, structures = readFasta(structureFile, SS_ELEMENTS)
	except IOError as error:
		print("Could not open '" + error.filename + "'.")
		sys.exit(1)

	# none of the processing steps below fails on an empty list or on structures of
	# unequal length, so both conditions are checked explicitly
	if not structures:
		print("No structures found in '" + structureFile + "'.")
		sys.exit(1)
	if len(set([len(struc) for struc in structures])) > 1:
		print("The structures seem to differ in length.")
		sys.exit(1)

	# process structures
	structures = commonSyntax(structures)
	structures = addGaps(None, structures)
	brokenStructures = breakLines(structures, resPerLine)
else:
	# compare the predictions from a single program for multiple proteins
	# open FASTA files
	try:
		seqNames, sequences = readFasta(sequenceFile, AMINO_ACIDS)
		strucNames, structures = readFasta(structureFile, SS_ELEMENTS)
	except IOError as error:
		print("Could not open '" + error.filename + "'.")
		sys.exit(1)

	# process sequences and structures read from the files
	# removeGaps() and addGaps() index into the sequences and structures; an IndexError
	# raised there is reported as a length mismatch
	try:
		sequences = removeGaps(sequences)
		structures = commonSyntax(structures)
		structures = addGaps(sequences, structures)
		brokenStructures = breakLines(structures, resPerLine)
	except IndexError:
		print("The sequences and structures seem to differ in length.")
		sys.exit(1)


# now for the common part
# calculate or initiate some variables
resWidth = float(imageWidth - lineIndent) / resPerLine # width of a single residue (in cm)
nStruc = len(structures)                               # number of structures
# number of blocks; the longest structure decides, so that every block of every
# structure has an image to go to
nBlocks = max([len(blocks) for blocks in brokenStructures])
if imageHeight == 0:
	blocksPerImage = nBlocks                           # blocks per image (output file)
else:
	blocksPerImage = int((imageHeight + blockDistance) / (nStruc * lineDistance + blockDistance))
# always place at least one block on an image: a value of 0 (image height smaller than
# a single block, or no blocks at all) would cause a division by zero below
blocksPerImage = max(1, blocksPerImage)

# number of images, rounded up; "//" keeps the division integral
nImages = (nBlocks + blocksPerImage - 1) // blocksPerImage
tikzCommands = [""] * nImages # list of strings where each string contains all TikZ commands for an image

# determine the appropriate commands
# each structure occupies one row (curLine) in every block; neighbouring residues of the
# same type are merged into a single drawing command
for curLine, struc in enumerate(brokenStructures):
	curBlock = 0 # index of the block within the current image
	curImage = 0 # index into tikzCommands
	for line, startRes, endRes in struc:
		pieces = [makeTikzLabel(seqNames[curLine], curBlock, curLine)]
		curType = ""
		lastIndex = len(line) - 1
		for i, element in enumerate(line):
			if i == 0:
				# open the first run and emit the start number; a line that begins
				# with a gap shows the residue count reached so far instead
				curType = element
				startPos = i
				firstNumber = startRes if element == "-" else startRes + 1
				pieces.append(makeTikzRes(firstNumber, curBlock, curLine, True))
			if i == lastIndex:
				# close the final run; a trailing "e" is always passed on as "e"
				finalType = "e" if element == "e" else curType
				pieces.append(makeTikzDraw(finalType, curBlock, curLine, startPos, i + 1))
			elif line[i + 1].upper() != element:
				# the type changes after this position: close the run and open a new one
				pieces.append(makeTikzDraw(curType, curBlock, curLine, startPos, i + 1))
				curType = line[i + 1]
				startPos = i + 1
		pieces.append(makeTikzRes(endRes, curBlock, curLine, False, i + 1))
		tikzCommands[curImage] += "".join(pieces)
		# continue on the next image once the current one holds blocksPerImage blocks
		curBlock += 1
		if curBlock % blocksPerImage == 0:
			curImage += 1
			curBlock = 0

# write the output files
# one file per image, named <outputFile><image index>.tex; the first failure is
# reported and ends the program with status 1
for i, commands in enumerate(tikzCommands):
	outputPath = outputFile + str(i) + ".tex"
	try:
		# the with-statement closes the file even if writing fails
		with open(outputPath, "w") as f:
			f.write(commands)
	except IOError:
		# report the path built above; the exception's filename attribute is not
		# guaranteed to be a string for every kind of I/O error
		print("Error while writing '" + outputPath + "'.")
		sys.exit(1)
