#!/usr/bin/env python

import sparse
import sys
import os
import java2doc

class PythonDocParser(sparse.SimpleParser, object):
  """Re-emits Python source to stdout, inserting docstrings (looked up in
  commentdict) after ``class`` and ``def`` declarations."""

  def __init__(self, commentdict):
    """sets up the PythonParser.

    commentdict maps a class name (for class docstrings) or a
    (classname, functionname) tuple (for method docstrings) to the
    docstring text to insert."""
    # whitespace tokens are kept so parse() can re-emit the source verbatim
    # and adddoc() can read indentation from neighbouring whitespace tokens
    sparse.SimpleParser.__init__(self, includewhitespacetokens=1)
    self.commentdict = commentdict
    # comment openers, mapped below to the token that terminates each one
    self.startcommenttokens = ["#", '"""', "'''"]
    self.endcommenttokens = {"#": "\n", '"""': '"""', "'''": "'''"}
    # commenttokenize runs first so the later tokenizers see comments as
    # single opaque tokens (see keeptogether below)
    self.standardtokenizers = [self.commenttokenize, self.stringtokenize, \
        self.removewhitespace, self.separatetokens]

  def keeptogether(self, input):
    """checks whether a token should be kept together (i.e. not retokenized)"""
    # don't retokenize strings or comments
    return sparse.SimpleParser.keeptogether(self, input) or self.iscommenttoken(input)

  def iscommenttoken(self, input):
    """returns whether the given token is a comment token"""
    # a comment token is one that begins with any of the comment openers
    for startcommenttoken in self.startcommenttokens:
      if input[:len(startcommenttoken)] == startcommenttoken:
        return True
    return False

  def stringtokenize(self, input):
    """makes strings in input into tokens... but keeps comment tokens together"""
    # a token that is already a whole comment must not be split on its quotes
    if self.iscommenttoken(input):
      return [input]
    return sparse.SimpleParser.stringtokenize(self, input)

  def commenttokenize(self, input):
    """makes comment in input into tokens.

    Splits input into alternating non-comment and comment tokens; each
    comment token includes both its opening and closing delimiters."""
    # tokens already identified as strings/comments are left untouched
    if sparse.SimpleParser.keeptogether(self, input): return [input]
    tokens = []
    incomment = False
    laststart = 0             # start of the token currently being accumulated
    endcommenttoken = None    # terminator for the comment being scanned
    for pos in range(len(input)):
      if incomment:
        if input[pos:pos+len(endcommenttoken)] == endcommenttoken:
          # include the terminator in the comment token
          # NOTE(review): reassigning pos only affects this iteration; the
          # for loop resumes from the pre-increment position, so for
          # multi-character terminators the characters just consumed are
          # rescanned for start tokens — looks harmless for these token
          # sets but worth confirming
          pos += len(endcommenttoken)
          if pos > laststart: tokens.append(input[laststart:pos])
          incomment, laststart = False, pos
      else:
        for startcommenttoken in self.startcommenttokens:
          if input[pos:pos+len(startcommenttoken)] == startcommenttoken:
            # flush any pending non-comment text, then open the comment
            if pos > laststart: tokens.append(input[laststart:pos])
            incomment, laststart = True, pos
            endcommenttoken = self.endcommenttokens[startcommenttoken]
    # flush the trailing token (including an unterminated comment)
    if laststart < len(input): tokens.append(input[laststart:])
    return tokens

  def parse(self, source):
    """tokenizes the Python source and writes it to stdout, inserting
    docstrings from self.commentdict after class/def declarations"""
    self.tokenize(source)
    lastcomment = None
    lastend = 0
    expectingclass = False    # the next non-whitespace token is a class name
    currentclass = None       # innermost class seen so far, used to key methods
    varinit = False
    # maps token position -> docstring text to emit just before that token
    self.insertdoc = {}
    for pos, token in enumerate(self.tokens):
      # emit any docstring scheduled (by adddoc on an earlier iteration)
      # for this position, then the token itself
      if pos in self.insertdoc:
        sys.stdout.write(self.insertdoc[pos])
      sys.stdout.write(token)
      if expectingclass and not token.isspace():
        currentclass = token
        # class docstrings are keyed by the bare class name
        comment = self.commentdict.get(currentclass, None)
        self.adddoc(pos, comment)
        expectingclass = False
      elif token == "class":
        expectingclass = True
      if token == "def":
        # assumes tokens are [def, whitespace, name] — TODO confirm that
        # the tokenizer always yields exactly one whitespace token here
        functionname = self.tokens[pos+2]
        if functionname == "__init__":
          # constructors are documented under the class's own name
          # (matching the Java constructor naming convention)
          comment = self.commentdict.get((currentclass, currentclass), None)
        else:
          comment = self.commentdict.get((currentclass, functionname), None)
        self.adddoc(pos, comment)

  def adddoc(self, pos, comment):
    """schedules comment to be inserted (as a docstring) after the
    declaration whose keyword token is at pos; no-op when comment is None"""
    if comment is None: return
    declstart = pos
    declend = declstart
    # scan forward to the ':' that ends the class/def declaration
    while self.tokens[declend] != ':':
      declend += 1
    nextindent = self.tokens[declend+1]
    if "\n" in nextindent:
      # declaration ends the line: reuse the following line's indentation
      self.insertdoc[declend+1] = nextindent + '"""' + comment + '"""'
    else:
      # single-line suite (e.g. "def f(): return x"): synthesize an indent
      # from the declaration's own leading whitespace plus one level
      previousindent = self.tokens[declstart-1]
      nextindent = previousindent + nextindent
      self.insertdoc[declend+1] = nextindent + '"""' + comment + '"""' + previousindent

if __name__ == "__main__":
  # Collect .java source files from the command-line arguments;
  # directories are searched recursively, plain files are used as given.
  allcomments = []
  filenames = []
  for filename in sys.argv[1:]:
    if os.path.isdir(filename):
      for root, dirs, files in os.walk(filename):
        filenames.extend(os.path.join(root, j) for j in files if j.endswith(".java"))
    elif os.path.exists(filename):
      filenames.append(filename)
    else:
      # use sys.stderr.write for consistency with the rest of this block
      # (the original Python 2 print statement was the lone exception)
      sys.stderr.write("could not find %s\n" % filename)
  sys.stderr.write("reading %d java files:\n" % len(filenames))
  # Extract (name, comment) pairs from every java file.
  for filename in filenames:
    sys.stderr.write(".")
    # close the file handle promptly instead of leaking it
    with open(filename, 'r') as srcfile:
      source = srcfile.read()
    javaparser = java2doc.JavaDocParser()
    javaparser.parse(source)
    allcomments.extend(javaparser.comments)
  sys.stderr.write(" finished\n")
  sys.stderr.write("found %d comments\n" % len(allcomments))
  # Merge comments that share a key (same class, or same class/method pair),
  # joining the texts with newlines; missing texts become empty strings.
  commentdict = {}
  for key, value in allcomments:
    if key in commentdict:
      commentdict[key] += "\n" + (value or "")
    else:
      commentdict[key] = value or ""
  sys.stderr.write("combined to %d named comments\n" % len(commentdict))
  # Annotate the Python source on stdin with the collected javadoc comments,
  # writing the result to stdout.
  parser = PythonDocParser(commentdict)
  source = sys.stdin.read()
  parser.parse(source)

