#!/usr/bin/env python
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
# 
#   http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#

import fcntl, logging, optparse, os, struct, sys, termios, traceback, types
from fnmatch import fnmatchcase as match
from getopt import GetoptError
from logging import getLogger, StreamHandler, Formatter, Filter, \
    WARN, DEBUG, ERROR
from qpid.util import URL

# Log level names accepted by --log-level, mapped to the logging constants.
levels = dict(DEBUG=DEBUG, WARN=WARN, ERROR=ERROR)

# The same names ordered by increasing numeric severity; used in the help
# text of --log-level.
sorted_levels = [label for number, label in
                 sorted([(number, label) for label, number in levels.items()])]

# Command line definition.  Every positional argument is a PATTERN selecting
# the tests to run.
parser = optparse.OptionParser(
  usage="usage: %prog [options] PATTERN ...",
  description="Run tests matching the specified PATTERNs.")

# What to do and where to run it.
parser.add_option(
  "-l", "--list",
  action="store_true", default=False,
  help="list tests instead of executing them")
parser.add_option(
  "-b", "--broker",
  default="localhost",
  help="run tests against BROKER (default %default)")

# Logging controls.
parser.add_option(
  "-f", "--log-file",
  metavar="FILE",
  help="log output to FILE")
parser.add_option(
  "-v", "--log-level",
  metavar="LEVEL", default="WARN",
  help=("only display log messages of LEVEL or higher severity: "
        + ", ".join(sorted_levels) + " (default %default)"))
parser.add_option(
  "-c", "--log-category",
  metavar="CATEGORY", action="append", dest="log_categories", default=[],
  help="log only categories matching CATEGORY pattern")

# Test selection and parameters; each of these may be given more than once.
parser.add_option(
  "-i", "--ignore",
  action="append", default=[],
  help="ignore tests matching IGNORE pattern")
parser.add_option(
  "-I", "--ignore-file",
  metavar="IFILE", action="append", default=[],
  help="ignore tests matching patterns in IFILE")
parser.add_option(
  "-D", "--define",
  metavar="DEFINE", dest="defines", action="append", default=[],
  help="define test parameters")

class Config:
  """Settings handed to the tests and to run_test().

  Every attribute starts out with a default; the option processing code
  that follows overwrites them from the parsed command line.
  """

  def __init__(self):
    # Logging: level applied while a test runs, category patterns used to
    # filter records, and the requested log file name (None when not given).
    self.log_categories = []
    self.log_level = WARN
    self.log_file = None
    # Test parameters given with -D, keyed by name.
    self.defines = {}
    # Broker the tests run against.
    self.broker = URL("localhost")

# Parse the command line and turn it into the module-level state used by the
# rest of the script: config, list_only, includes and excludes.
opts, args = parser.parse_args()

includes = []
# Names containing a double-underscore pair (e.g. __init__) are never tests.
excludes = ["*__*__"]
config = Config()
list_only = opts.list
config.broker = URL(opts.broker)

# -D NAME=VALUE defines a parameter; a bare -D NAME defines it as None.
# Only the first "=" separates the name from the value.
for d in opts.defines:
  if "=" in d:
    name, value = d.split("=", 1)
  else:
    name, value = d, None
  config.defines[name] = value

# NOTE(review): log_file is recorded here, but nothing in the visible source
# opens it -- confirm whether -f is meant to have an effect.
config.log_file = opts.log_file

# Report an unknown level as a usage error instead of a KeyError traceback.
if opts.log_level.upper() not in levels:
  parser.error("invalid log level %r, expected one of: %s"
               % (opts.log_level, ", ".join(sorted_levels)))
config.log_level = levels[opts.log_level.upper()]
config.log_categories = opts.log_categories

excludes.extend([v.strip() for v in opts.ignore])
for v in opts.ignore_file:
  f = open(v)
  try:
    for line in f:
      line = line.strip()
      # Skip blank lines and comments; an empty pattern could only ever
      # match an empty test name.
      if not line or line.startswith("#"):
        continue
      excludes.append(line)
  finally:
    # Close the file even when reading it fails part way through.
    f.close()

includes.extend([a.strip() for a in args])

# With no PATTERN arguments every test is selected.
if not includes:
  includes.append("*")

def included(path):
  """Return True when the test name `path` is selected for this run.

  A name is selected when it matches no pattern in `excludes` and at least
  one pattern in `includes`; exclusion always wins.
  """
  vetoes = [pattern for pattern in excludes if match(path, pattern)]
  if vetoes:
    return False
  selections = [pattern for pattern in includes if match(path, pattern)]
  return len(selections) > 0

def width(default=80):
  """Return the number of columns available for output on stdout.

  The terminal is asked for its window size with the TIOCGWINSZ ioctl.
  `default` is returned instead when stdout is not a terminal, when the
  ioctl fails, or when the terminal reports zero columns.
  """
  if not sys.stdout.isatty():
    return default
  try:
    # The ioctl fills a buffer of four unsigned shorts.
    packed = fcntl.ioctl(sys.stdout.fileno(), termios.TIOCGWINSZ,
                         struct.pack("HHHH", 0, 0, 0, 0))
  except (IOError, OSError):
    # Failing to measure the terminal must not stop the test run.
    return default
  rows, cols, xpx, ypx = struct.unpack("HHHH", packed)
  if cols > 0:
    return cols
  # A width of zero cannot be used to lay out the status line.
  return default

# Column count used by run_test() to lay out each status line.  Measured once
# at start-up and refreshed by the SIGWINCH handler installed below.
WIDTH = width()

def resize(sig, frm):
  """Signal handler: re-measure the terminal and update WIDTH.

  Both handler arguments are ignored.
  """
  global WIDTH
  WIDTH = width()

import signal
# Call resize() whenever SIGWINCH is delivered to the process.
signal.signal(signal.SIGWINCH, resize)

def vt100_attrs(*attrs):
  """Build the escape sequence that selects the given display attributes.

  Each attribute is converted with str() and the results are joined with
  ";" between the "ESC [" introducer and the final "m".
  """
  codes = [str(attr) for attr in attrs]
  return "\x1B[" + ";".join(codes) + "m"

# Escape sequence for attribute 0, appended by colorize() after coloured text
# to switch the display attributes back off.
vt100_reset = vt100_attrs(0)

# Display attributes applied to each status word by colorize_word().  In the
# SGR numbering 31, 32 and 34 select a red, green and blue foreground.
KEYWORDS = {"pass": (32,),
            "fail": (31,),
            "start": (34,),
            "total": (34,)}

# Escape sequences are only emitted when stdout is a terminal.
COLORIZE = sys.stdout.isatty()

def colorize_word(word, text=None):
  """Colour `text` with the attributes registered for `word` in KEYWORDS.

  When `text` is omitted the word itself is coloured.  Words without an
  entry in KEYWORDS are returned uncoloured.
  """
  attrs = KEYWORDS.get(word, ())
  if text is None:
    return colorize(word, *attrs)
  return colorize(text, *attrs)

def colorize(text, *attrs):
  """Wrap `text` in escape sequences selecting `attrs`, then resetting them.

  The text is returned unchanged when no attributes are given or when
  colouring is disabled through COLORIZE.
  """
  if not (attrs and COLORIZE):
    return text
  return "%s%s%s" % (vt100_attrs(*attrs), text, vt100_reset)

def indent(text):
  """Return `text` with two spaces inserted at the start of every line.

  A line is also started after a trailing newline, so the result of
  indenting "a\\n" ends in two spaces.
  """
  return "  " + text.replace("\n", "\n  ")

from qpid.testlib import testrunner
# Pass the broker selected on the command line to qpid.testlib's runner.
# NOTE(review): presumably needed by tests built on qpid.testlib -- confirm.
testrunner.setBroker(str(config.broker))

class Interceptor:
  """State shared by the StreamWrappers that replace stdout and stderr.

  Attributes:
    passthrough -- True when writes go to the real stream untouched
    newline     -- True until the first intercepted write of a test
    indent      -- True when the next write starts at the beginning of a line
    dirty       -- True once an intercepted write carried any text
    last        -- last character written while intercepting, or None
  """

  def __init__(self):
    self.passthrough = True
    self.newline = self.indent = False
    self.dirty = False
    self.last = None

  def begin(self):
    """Start intercepting output and forget what the previous test wrote."""
    self.passthrough = False
    self.newline = self.indent = True
    self.dirty = False
    self.last = None

  def reset(self):
    """Stop intercepting.

    dirty and last are left alone so the caller can still inspect what the
    test wrote.
    """
    self.passthrough = True
    self.newline = self.indent = False

class StreamWrapper:
  """File-like wrapper that indents everything a test writes.

  While the shared interceptor is in passthrough mode, writes reach the
  wrapped stream untouched.  Otherwise the first non-empty write of a test
  is preceded by a " start" marker that ends the status line, and every
  line of output is prefixed with `prefix`.
  """

  def __init__(self, interceptor, stream, prefix="  "):
    self.interceptor = interceptor
    self.stream = stream
    self.prefix = prefix

  def fileno(self):
    """Return the file descriptor of the wrapped stream."""
    return self.stream.fileno()

  def isatty(self):
    """Return whether the wrapped stream is a terminal."""
    return self.stream.isatty()

  def write(self, s):
    """Write `s`, indenting it unless the interceptor is passing through."""
    state = self.interceptor
    if state.passthrough:
      self.stream.write(s)
      return

    # An empty write produces no output, so it must not end the status line
    # or emit a dangling prefix.
    if not s:
      return

    state.dirty = True

    if state.newline:
      # First output of this test: finish the status line before it.
      state.newline = False
      self.stream.write(" %s\n" % colorize_word("start"))
      state.indent = True
    if state.indent:
      self.stream.write(self.prefix)
    if s.endswith("\n"):
      # Prefix the interior lines only; the prefix for the line following
      # the final newline is written by the next call.  This works for a
      # prefix of any length.
      s = s[:-1].replace("\n", "\n%s" % self.prefix) + "\n"
      state.indent = True
    else:
      s = s.replace("\n", "\n%s" % self.prefix)
      state.indent = False
    self.stream.write(s)

    state.last = s[-1]

  def flush(self):
    """Flush the wrapped stream."""
    self.stream.flush()

# A single Interceptor is shared by both wrappers, so stdout and stderr
# switch between passthrough and indented mode together.
interceptor = Interceptor()

out_wrp = StreamWrapper(interceptor, sys.stdout)
err_wrp = StreamWrapper(interceptor, sys.stderr)

# Keep references to the real streams, then replace them process-wide so
# that everything written from here on goes through the wrappers.
out = sys.stdout
err = sys.stderr
sys.stdout = out_wrp
sys.stderr = err_wrp

class PatternFilter(Filter):
  """Logging filter that selects records by logger-name pattern.

  `patterns` are case-sensitive fnmatch patterns.  A record passes when its
  logger name matches at least one of them; with no patterns at all, every
  record passes.  The `patterns` attribute may be replaced at any time.
  """

  def __init__(self, *patterns):
    # The base class takes a logger *name*, not a pattern list.  Name-based
    # filtering is not used here, so initialise it with its default.
    Filter.__init__(self)
    self.patterns = patterns

  def filter(self, record):
    """Return True when `record` should be emitted."""
    if not self.patterns:
      return True
    for p in self.patterns:
      if match(record.name, p):
        return True
    return False

# Send all log records through one handler on the root logger.  sys.stdout
# has already been replaced by out_wrp at this point, so records logged while
# a test runs are indented like the rest of that test's output.
root = getLogger()
handler = StreamHandler(sys.stdout)
# NOTE: this module-level name shadows the builtin filter().  run_test()
# swaps its patterns in and out around every test.
filter = PatternFilter(*config.log_categories)
handler.addFilter(filter)
handler.setFormatter(Formatter("%(asctime)s %(levelname)s %(message)s"))
root.addHandler(handler)
# Level in effect outside of tests; run_test() applies config.log_level while
# a test is running.
root.setLevel(WARN)

log = getLogger("qpid.test")

class Runner:
  """Runs the phases of one test and records the exceptions they raise.

  `exceptions` holds one (phase name, sys.exc_info() triple) entry for
  every phase that failed, in the order the phases ran.
  """

  def __init__(self):
    self.exceptions = []

  def passed(self):
    """Return True when no phase has raised."""
    return not self.exceptions

  def failed(self):
    """Return the recorded failures; the list is empty (false) on success."""
    return self.exceptions

  def run(self, name, phase):
    """Call `phase` and record anything it raises under `name`.

    KeyboardInterrupt is the one exception that is allowed to propagate.
    """
    try:
      phase()
    except KeyboardInterrupt:
      raise
    except:
      failure = (name, sys.exc_info())
      self.exceptions.append(failure)

  def status(self):
    """Return "pass" or "fail" for the phases run so far."""
    if not self.passed():
      return "fail"
    return "pass"

  def print_exceptions(self):
    """Print an indented traceback for every recorded failure."""
    for name, info in self.exceptions:
      trace = "".join(traceback.format_exception(*info))
      print("Error during %s:" % name)
      print(indent(trace).rstrip())

# Number of columns run_test() subtracts from WIDTH when wrapping and padding
# the test name (presumably room for the status word -- confirm).
ST_WIDTH = 8

def run_test(name, test, config):
  """Run one test, print its status line and return True when it passed.

  name   -- dotted test name shown on the status line
  test   -- callable taking no arguments that runs the test and returns a
            Runner-like object (status(), failed(), passed(),
            print_exceptions())
  config -- supplies log_level and log_categories for the test's duration

  The name is printed padded with dots, wrapped at the dots between its
  parts when it does not fit in WIDTH.  Output written by the test is
  intercepted and indented; if there was any, the name is printed again
  before the status so the result stays next to the name.  The log level
  and category patterns in effect before the call are always restored,
  even when `test` raises.
  """
  patterns = filter.patterns
  level = root.level
  filter.patterns = config.log_categories
  root.setLevel(config.log_level)

  try:
    parts = name.split(".")
    line = None
    output = ""
    for part in parts:
      if line:
        if len(line) + len(part) >= (WIDTH - ST_WIDTH - 1):
          # Does not fit: end this line with a continuation marker and
          # start an indented one.
          output += "%s. \\\n" % line
          line = "    %s" % part
        else:
          line = "%s.%s" % (line, part)
      else:
        line = part

    if line:
      output += "%s %s" % (line, (((WIDTH - ST_WIDTH) - len(line))*"."))
    sys.stdout.write(output)
    sys.stdout.flush()
    interceptor.begin()
    try:
      runner = test()
    finally:
      interceptor.reset()
    if interceptor.dirty:
      # The test wrote something: finish its last line if necessary and
      # repeat the name so the status has a line to sit on.
      if interceptor.last != "\n":
        sys.stdout.write("\n")
      sys.stdout.write(output)
    print(" %s" % colorize_word(runner.status()))
    if runner.failed():
      runner.print_exceptions()
    return runner.passed()
  finally:
    # Undo the per-test logging configuration however the test ended.
    root.setLevel(level)
    filter.patterns = patterns

class FunctionTest:

  def __init__(self, test):
    self.test = test

  def name(self):
    return "%s.%s" % (self.test.__module__, self.test.__name__)

  def run(self):
    return run_test(self.name(), self._run, config)

  def _run(self):
    runner = Runner()
    runner.run("test", lambda: self.test(config))
    return runner

  def __repr__(self):
    return "FunctionTest(%r)" % self.test

class MethodTest:

  def __init__(self, cls, method):
    self.cls = cls
    self.method = method

  def name(self):
    return "%s.%s.%s" % (self.cls.__module__, self.cls.__name__, self.method)

  def run(self):
    return run_test(self.name(), self._run, config)

  def _run(self):
    runner = Runner()
    inst = self.cls(self.method)
    test = getattr(inst, self.method)

    if hasattr(inst, "configure"):
      runner.run("configure", lambda: inst.configure(config))
      if runner.failed(): return runner
    if hasattr(inst, "setUp"):
      runner.run("setup", inst.setUp)
      if runner.failed(): return runner
    elif hasattr(inst, "setup"):
      runner.run("setup", inst.setup)
      if runner.failed(): return runner

    runner.run("test", test)

    if hasattr(inst, "tearDown"):
      runner.run("teardown", inst.tearDown)
    elif hasattr(inst, "teardown"):
      runner.run("teardown", inst.teardown)

    return runner

  def __repr__(self):
    return "MethodTest(%r, %r)" % (self.cls, self.method)

class PatternMatcher:
  """Base class for scanners that select objects by name pattern."""

  def __init__(self, *patterns):
    self.patterns = patterns

  def matches(self, name):
    """Return True when `name` matches any pattern (case-sensitively)."""
    hits = [pattern for pattern in self.patterns if match(name, pattern)]
    return len(hits) > 0

class FunctionScanner(PatternMatcher):
  """Scanner that turns plain functions with a matching name into tests."""

  def inspect(self, obj):
    """Return True when `obj` is a function whose own name matches."""
    # Match on the function's __name__; it is the only name available for
    # the object being inspected.
    return isinstance(obj, types.FunctionType) and self.matches(obj.__name__)

  def descend(self, func):
    """Functions contain nothing to scan further: yield no children."""
    # the None is required for older versions of python
    return; yield None

  def extract(self, func):
    """Yield the single FunctionTest wrapping `func`."""
    yield FunctionTest(func)

class ClassScanner(PatternMatcher):
  """Scanner that turns the test methods of matching classes into tests."""

  def inspect(self, obj):
    """Return whether `obj` is a class (of either kind) with a matching name."""
    if type(obj) not in (types.ClassType, types.TypeType):
      return False
    return self.matches(obj.__name__)

  def descend(self, cls):
    """Classes are not scanned any deeper: yield no children."""
    # the None is required for older versions of python
    if False:
      yield None

  def extract(self, cls):
    """Yield a MethodTest per method whose name starts with "test".

    Attributes are visited in sorted name order.
    """
    for name in sorted(dir(cls)):
      member = getattr(cls, name)
      if type(member) != types.MethodType or not name.startswith("test"):
        continue
      yield MethodTest(cls, name)

class ModuleScanner:
  """Scanner that walks into modules; a module itself provides no tests."""

  def inspect(self, obj):
    """Return True when `obj` is a module."""
    return types.ModuleType == type(obj)

  def descend(self, obj):
    """Yield every attribute of the module, in sorted name order."""
    for attr in sorted(dir(obj)):
      yield getattr(obj, attr)

  def extract(self, obj):
    """Yield nothing."""
    # the None is required for older versions of python
    if False:
      yield None

class Harness:
  """Collects tests by scanning objects breadth-first with a set of scanners.

  Attributes:
    scanners -- scanners consulted for every object, in order
    tests    -- tests collected so far
    scanned  -- objects that have already been processed
  """

  def __init__(self):
    self.scanners = [
      ModuleScanner(),
      ClassScanner("*Test", "*Tests", "*TestCase"),
      FunctionScanner("test_*")
      ]
    self.tests = []
    self.scanned = []

  def scan(self, *roots):
    """Scan `roots` and everything reachable from them.

    For every scanner that accepts an object, the tests it extracts are
    appended to `tests` and the children it yields are queued, unless an
    equal object was already scanned or is already waiting in the queue.
    """
    pending = list(roots)

    while pending:
      current = pending.pop(0)
      for scanner in self.scanners:
        if not scanner.inspect(current):
          continue
        self.tests.extend(scanner.extract(current))
        for child in scanner.descend(current):
          if child in self.scanned or child in pending:
            continue
          pending.append(child)
      self.scanned.append(current)

# Collect the tests from the known test packages, then list or run the ones
# selected on the command line.
modules = "qpid.tests", "tests", "tests_0-10"
h = Harness()
for name in modules:
  # A non-empty fromlist makes __import__ return the named module itself
  # rather than its top-level package.
  m = __import__(name, None, None, ["dummy"])
  h.scan(m)

filtered = [t for t in h.tests if included(t.name())]
passed = 0
failed = 0
for t in filtered:
  if list_only:
    print(t.name())
  elif t.run():
    passed += 1
  else:
    failed += 1

if not list_only:
  if failed:
    outcome = "fail"
  else:
    outcome = "pass"
  # The failure count is shown in the "fail" colour only when something
  # actually failed.
  print(" ".join([
      colorize("Totals:", 1),
      colorize_word("total", "%s tests" % len(filtered)) + ",",
      colorize_word("pass", "%s passed" % passed) + ",",
      colorize_word(outcome, "%s failed" % failed)]))
  # Report failure through the exit status so that callers (make, CI,
  # shell scripts) can detect it.
  if failed:
    sys.exit(1)
