#!/usr/bin/python
import subprocess
import sys
import os
import bisect
from lxml import etree as ET

# Placeholder token used in place of ":" when building attribute names such as
# "xmlns:xsi"; it is replaced by a literal ":" in the serialized XML right
# before the text is piped to xml2text.
# NOTE(review): presumably needed because lxml rejects ":" in attribute names -- confirm.
colon = "39584ac56e6d0976692778b80915f94b61767814a8704982"

def hexa(v):
	"""Format the integer *v* as text.

	Values up to and including 9 are rendered with str() (so single digits
	stay plain decimal); anything larger is rendered with hex().
	"""
	return hex(v) if v > 9 else str(v)


# Object class number -> preferred class name, taken from renouveau.xml.
# These names replace the ones found in rules.xml when a device's object
# classes are collected.
#
# NOTE(review): keys 0x0064, 0x0065 and 0x0066 each appear twice in this
# literal.  Python keeps the LAST value for a repeated key, so the
# NV05_*_FROM_CPU names for those three numbers are silently replaced by the
# *_SRCCOPY* names near the end of the table.  Confirm which names are intended.
known_objclasses = {
	0x0001 : "NV01_ROOT",
	0x0002 : "NV01_CONTEXT_DMA",
	0x0003 : "NV01_DEVICE",
	0x0004 : "NV01_TIMER",
	0x0012 : "NV01_CONTEXT_BETA1",
	0x0017 : "NV01_CONTEXT_COLOR_KEY",
	0x0057 : "NV04_CONTEXT_COLOR_KEY",
	0x0018 : "NV01_CONTEXT_PATTERN",
	0x0019 : "NV01_CONTEXT_CLIP_RECTANGLE",
	0x001c : "NV01_RENDER_SOLID_LINE",
	0x005c : "NV04_RENDER_SOLID_LINE",
	0x001d : "NV01_RENDER_SOLID_TRIANGLE",
	0x005d : "NV04_RENDER_SOLID_TRIANGLE",
	0x001e : "NV01_RENDER_SOLID_RECTANGLE",
	0x005e : "NV04_RENDER_SOLID_RECTANGLE",
	0x001f : "NV01_IMAGE_BLIT",
	0x005f : "NV04_IMAGE_BLIT",
	0x009f : "NV12_IMAGE_BLIT",
	0x0021 : "NV01_IMAGE_FROM_CPU",
	0x0061 : "NV04_IMAGE_FROM_CPU",
	0x0065 : "NV05_IMAGE_FROM_CPU", # shadowed by the second 0x0065 entry below
	0x008a : "NV10_IMAGE_FROM_CPU",
	0x038a : "NV30_IMAGE_FROM_CPU",
	0x308a : "NV40_IMAGE_FROM_CPU",
	0x0030 : "NV01_NULL",
	0x0036 : "NV03_STRETCHED_IMAGE_FROM_CPU",
	0x0076 : "NV04_STRETCHED_IMAGE_FROM_CPU",
	0x0066 : "NV05_STRETCHED_IMAGE_FROM_CPU", # shadowed by the second 0x0066 entry below
	0x0366 : "NV30_STRETCHED_IMAGE_FROM_CPU",
	0x3066 : "NV40_STRETCHED_IMAGE_FROM_CPU",
	0x0037 : "NV03_SCALED_IMAGE_FROM_MEMORY",
	0x0077 : "NV04_SCALED_IMAGE_FROM_MEMORY",
	0x0063 : "NV05_SCALED_IMAGE_FROM_MEMORY",
	0x0089 : "NV10_SCALED_IMAGE_FROM_MEMORY",
	0x0389 : "NV30_SCALED_IMAGE_FROM_MEMORY",
	0x3089 : "NV40_SCALED_IMAGE_FROM_MEMORY",
	0x0038 : "NV04_DVD_SUBPICTURE",
	0x0088 : "NV10_DVD_SUBPICTURE",
	0x0039 : "NV04_MEMORY_TO_MEMORY_FORMAT",
	0x5039 : "NV50_MEMORY_TO_MEMORY_FORMAT",
	0x003d : "NV01_MEMORY_LOCAL_BANKED",
#	0x003e : "NV01_MAPPING_SYSTEM",
	0x003f : "NV03_MEMORY_LOCAL_CURSOR",
	0x0040 : "NV01_MEMORY_LOCAL_LINEAR",
	0x0041 : "NV01_MAPPING_LOCAL",
	0x0042 : "NV04_CONTEXT_SURFACES_2D",
	0x0062 : "NV10_CONTEXT_SURFACES_2D",
	0x0362 : "NV30_CONTEXT_SURFACES_2D",
	0x3062 : "NV40_CONTEXT_SURFACES_2D",
	0x0043 : "NV03_CONTEXT_ROP",
	0x0044 : "NV04_IMAGE_PATTERN",
	0x0046 : "NV03_VIDEO_LUT_CURSOR_DAC",
	0x0048 : "NV03_TEXTURED_TRIANGLE",
	0x004a : "NV04_GDI_RECTANGLE_TEXT",
	0x004b : "NV03_GDI_RECTANGLE_TEXT",
	0x0052 : "NV04_SWIZZLED_SURFACE",
	0x009e : "NV20_SWIZZLED_SURFACE",
	0x039e : "NV30_SWIZZLED_SURFACE",
	0x309e : "NV40_SWIZZLED_SURFACE",
	0x0053 : "NV04_CONTEXT_SURFACES_3D",
	0x0093 : "NV10_CONTEXT_SURFACES_3D",
	0x0054 : "NV04_TEXTURED_TRIANGLE",
	0x0094 : "NV10_TEXTURED_TRIANGLE",
	0x0055 : "NV04_MULTITEX_TRIANGLE",
	0x0095 : "NV10_MULTITEX_TRIANGLE",
	0x0056 : "NV10TCL",
	0x0096 : "NV11TCL",
	0x0099 : "NV17TCL",
	0x0058 : "NV03_CONTEXT_SURFACES_2D",
	0x005a : "NV03_CONTEXT_SURFACES_3D",
	0x0060 : "NV04_INDEXED_IMAGE_FROM_CPU",
	0x0064 : "NV05_INDEXED_IMAGE_FROM_CPU", # shadowed by the second 0x0064 entry below
	0x006a : "NV03_CHANNEL_PIO",
	0x006b : "NV03_CHANNEL_DMA",
	0x0072 : "NV04_BETA_SOLID",
	0x007b : "NV10_TEXTURE_FROM_CPU",
	0x037b : "NV30_TEXTURE_FROM_CPU",
	0x307b : "NV40_TEXTURE_FROM_CPU",
	0x007c : "NV10_VIDEO_DISPLAY",
	0x0097 : "NV20TCL",
	0x0597 : "NV25TCL",
	0x0397 : "NV30TCL",
	0x0497 : "NV35TCL",
	0x0697 : "NV34TCL",
	0x4097 : "NV40TCL",
	0x4497 : "NV44TCL",
	0x502d : "NV50_2D",
	0x5097 : "NV50TCL",
	0x8297 : "NV84TCL",
	0x8397 : "NVA0TCL",
	0x8597 : "NVA8TCL",
	0x50c0 : "NV50_COMPUTE",
	0x0010 : "NV_IMAGE_STENCIL",
	0x0011 : "NV_IMAGE_BLEND_AND",
	0x0013 : "NV_IMAGE_ROP_AND",
	0x0015 : "NV_IMAGE_COLOR_KEY",
	0x0064 : "NV01_IMAGE_SRCCOPY_AND", # duplicate key: this value is the one that takes effect
	0x0065 : "NV03_IMAGE_SRCCOPY", # duplicate key: this value is the one that takes effect
	0x0066 : "NV04_IMAGE_SRCCOPY_PREMULT", # duplicate key: this value is the one that takes effect
	0x0067 : "NV04_IMAGE_BLEND_PREMULT",
}

# Load the input rule database and make sure the output directory layout
# (rules/{mmio,objects,misc,strange}) exists.
tree = ET.parse("rules.xml")
root = tree.getroot()
outdir = "rules"
try:
	os.mkdir(outdir)
except OSError:
	# An already existing directory is fine; any other failure (permissions,
	# a regular file in the way, ...) would only resurface later as a less
	# obvious error when the output files are opened, so report it now.
	if not os.path.isdir(outdir):
		raise
for subdir in ("mmio", "objects", "misc", "strange"):
	try:
		os.mkdir(outdir + "/" + subdir)
	except OSError:
		if not os.path.isdir(outdir + "/" + subdir):
			raise

# One Device record per <device> element, in the order they are visited.
devobjs = []

# Stripes of devices without object classes, keyed by the device start offset
# (devices starting at offset 0 are never registered here).
address_stripeobjs = {}
# Stripes of devices with object classes, keyed by their "objsuper" name.
objsuper_stripeobjs = {}
# Every Stripe created by this script, in creation order.
stripeobjs = []

class Stripe(object):
	"""A group of devices whose registers are written to one output file.

	The script additionally assigns ``offset`` and ``end`` on each instance
	after construction; they are not set here.
	"""

	def __init__(self):
		# offset -> {serialized reg32 XML -> list of (name, number) objclasses}
		self.members = {}
		# names of the devices merged into this stripe
		self.names = set()
		# common object-class name suffix, for stripes of object-class devices
		self.objsuper = None
		# name prefix, for stripes made of registers that match no device name
		self.strange_name = None

class Device(object):
	"""Plain attribute bag describing one <device> element.

	The script fills in ``device``, ``name``, ``offset``, ``end``,
	``objclasses`` and ``stripeobj`` after construction.
	"""

# Two extra <device> ranges (NV_PRMIO and NV_PCRTC) that are added to the
# database parsed from rules.xml before any processing happens.
addendum = '<rule-database><device name="NV_PRMIO" start="0x00007100" end="0x0000710f" /><device name="NV_PCRTC" start="0x00600000" end="0x00600fff" /></rule-database>'
addendum_root = ET.fromstring(addendum)
# Detach each extra device from the temporary document and attach it to the
# real root.  NOTE(review): this removes children while iterating over them;
# it relies on the lxml child iterator tolerating that -- confirm.
for device in addendum_root.iterchildren(tag=ET.Element):
	addendum_root.remove(device)
	root.append(device)

# Pass 1: build a Device record for every <device> element and assign each
# device to a Stripe (creating the stripe on first use).
for device in root.iterchildren(tag=ET.Element):
	devobj = Device()
	devobj.device = device
	devobj.name = device.attrib["name"]
	devobj.offset = int(device.attrib["start"], 0)
	# "end" in the XML is inclusive; store it as an exclusive bound
	devobj.end = int(device.attrib["end"], 0) + 1

	# object class number -> name, from the device's direct <field-value> children
	objclasses = {}

	for reglike in device.iterchildren(tag=ET.Element):
		if reglike.tag == "field-value":
			if "name" in reglike.attrib and "value" in reglike.attrib:
				objclasses[int(reglike.attrib["value"], 0)] = reglike.attrib["name"]

	# prefer the names from known_objclasses over the ones in rules.xml
	for objclass in objclasses:
		if objclass in known_objclasses:
			objclasses[objclass] = known_objclasses[objclass]

	# from here on objclasses is a list of (name, number) tuples sorted by name
	objclasses = [(name, num) for num, name in objclasses.items()]
	objclasses.sort()

	# NV_UTOMEM lists no classes of its own; give it a single class without a number
	if not objclasses:
		if devobj.name == "NV_UTOMEM":
			objclasses = [("NV_IMAGE_TO_MEMORY", None)]

	devobj.objclasses = objclasses

	devobjs.append(devobj)

	# Look up an existing stripe: devices with object classes are grouped by
	# "objsuper" (the first class name with everything up to and including its
	# first "_" removed); all other devices are grouped by start offset.
	if devobj.objclasses:
		objsuper = devobj.objclasses[0][0]
		us = objsuper.find("_")
		objsuper = objsuper[us + 1:]
		stripeobj = objsuper_stripeobjs.get(objsuper)
	else:
		stripeobj = address_stripeobjs.get(devobj.offset)

	if not stripeobj:
		# first device of its group: create the stripe with this device's range
		stripeobj = Stripe()
		stripeobj.offset = devobj.offset
		stripeobj.end = devobj.end
		stripeobjs.append(stripeobj)

		# don't merge stuff at offset 0, since they usually aren't in the MMIO domain
		if not devobj.objclasses:
			if devobj.offset != 0:
				address_stripeobjs[devobj.offset] = stripeobj
		else:
			stripeobj.objsuper = objsuper
			objsuper_stripeobjs[objsuper] = stripeobj
	else:
		# joining an existing stripe: only ever grow its end bound
		if devobj.end > stripeobj.end:
			stripeobj.end = devobj.end
	stripeobj.names.add(devobj.name)
	devobj.stripeobj = stripeobj

# Stripes for registers whose name does not start with the name of the device
# covering their address, keyed by the leading part of the register name.
strange_stripeobjs = {}

def common_prefix_len(a, b):
	"""Return the length of the longest common prefix of *a* and *b*.

	The scan is bounded by the shorter of the two inputs, so identical
	inputs (including two empty strings) terminate and yield their full
	length instead of comparing empty slices forever.
	"""
	limit = min(len(a), len(b))
	i = 0
	while i < limit and a[i] == b[i]:
		i += 1
	return i

# Full-width field names that set the "type" attribute of the register.
field_name_to_type_map = {
	"HANDLE": "object",
	"STYLE": "notify_style",
}

# Full-width field names that are never appended to the register name; this
# includes every name that already maps to a type above.
uninteresting_field_names = frozenset((
	"VAL", "VALUE", "PARAMETER", "ADDRESS", "RESERVED",
	"DATA", "BLAH", "COUNT", "CNT", "ARGUMENT",
	"BITMAP", "OFFSET", "FIELD", "MODE", "LE",
)) | frozenset(field_name_to_type_map)

# Registers with one of these (device-prefix-stripped) names are not emitted.
uninteresting_reglike_names = frozenset(("NOP", "NOTIFY", "SET_NOTIFY", "CTX_SWITCH"))

def process_values(field, rnn_field):
	"""Convert the <field-value> children of *field* into rnn <value> elements.

	field     -- source element whose direct <field-value> children are read.
	rnn_field -- output element; receives the <value> children and may have
	             its "type" and "name" attributes updated.  Its "name"
	             attribute must already be set.

	Steps:
	  1. Each <field-value> becomes a <value> appended to rnn_field, with the
	     "<field_name>_" prefix stripped from its name.
	  2. A lone "0 = 0" value is dropped.  If all values are boolean-like
	     names (TRUE/FALSE, ON/OFF, ...), they are removed and rnn_field gets
	     type="boolean" instead.
	  3. If several values share a prefix ending in "_", that prefix is moved
	     from the value names onto the rnn_field name.

	NOTE: this function reads the module-level variable ``field_name`` (it is
	not a parameter).  Callers must have assigned it before calling; the value
	used is whatever the caller's loop last stored there.
	"""
	true_value = None
	false_value = None
	common_prefix = None

	rnn_values = []

	for field_value in field.iterchildren(tag=ET.Element):
		if field_value.tag == "field-value":
			field_value_name = field_value.attrib["name"]
			rnn_value_name = field_value_name
			if rnn_value_name.startswith(field_name + "_"):
				rnn_value_name = field_value_name[len(field_name)+1:]
			rnn_value = ET.Element("value")
			rnn_value.attrib["name"] = rnn_value_name
			rnn_value.attrib["value"] = hexa(int(field_value.attrib["value"], 0))
			rnn_values.append(rnn_value)
			rnn_field.append(rnn_value)
			# remember the last boolean-like value of each polarity
			if rnn_value_name in ("FALSE", "OFF", "DISABLE", "DISABLED", "NO", "0"):
				false_value = rnn_value
			if rnn_value_name in ("TRUE", "ON", "ENABLE", "ENABLED", "YES", "1"):
				true_value = rnn_value

	# avoid reglikes with just "0 = 0"
	if len(rnn_values) == 1 and rnn_values[0].attrib["name"] == "0" and rnn_values[0].attrib["value"] == "0":
		rnn_field.remove(rnn_values[0])
		rnn_values = []
	elif len(rnn_values) and len(rnn_values) == ((1 if true_value is not None else 0) + (1 if false_value is not None else 0)):
		# every value is a boolean spelling: replace them by a boolean type
		if false_value is not None:
			rnn_field.remove(false_value)
		if true_value is not None:
			rnn_field.remove(true_value)
		rnn_field.attrib["type"] = "boolean"

	# longest common prefix of the value names, ignoring "INVALID" and names
	# that start with a digit
	if len(rnn_values) > 1 and rnn_field.attrib["name"] not in ("NOTIFY",):
		for rnn_value in rnn_values:
			rnn_value_name = rnn_value.attrib["name"]
			if rnn_value_name == "INVALID" or rnn_value_name[0:1].isdigit():
				pass
			elif common_prefix is None:
				common_prefix = rnn_value_name
			else:
				common_prefix = common_prefix[:common_prefix_len(common_prefix, rnn_value_name)]

	if common_prefix:
		# only whole "_"-terminated components are moved to the field name
		us = common_prefix.rfind("_")
		if us >= 0:
			common_prefix = common_prefix[:us + 1]
			if common_prefix:
				for rnn_value in rnn_values:
					if rnn_value.attrib["name"].startswith(common_prefix):
						rnn_value.attrib["name"] = rnn_value.attrib["name"][len(common_prefix):]
				rnn_field.attrib["name"] += "_" + common_prefix.rstrip("_")

#print ET.tostring(tree)
# Pass 2: turn every <register>/<array> of every device into an rnn <reg32>
# element and record it in the members of the stripe it belongs to.
for devobj in devobjs:
	for reglike in devobj.device.iterchildren(tag=ET.Element):
		if reglike.tag == "register" or reglike.tag == "array":
			address = int(reglike.attrib["address"], 0)

			reglike_name = reglike.attrib["name"]
			# All devices whose [offset, end) range contains the address, ranked by
			# how long a name prefix they share with the register (the length is
			# negated so that the longest match sorts first).
			# NOTE(review): ties fall through to comparing the Device objects themselves.
			cand_devobjs = [(-common_prefix_len(cand_devobj.name, reglike_name), cand_devobj) for cand_devobj in devobjs if address >= cand_devobj.offset and address < cand_devobj.end]
			cand_devobjs.sort()
			cand_devobj = cand_devobjs[0][1]
			stripeobj = cand_devobj.stripeobj
#			if stripeobj.offset == 0:

			if reglike_name.startswith(cand_devobj.name + "_"):
				# normal case: the register is named after the best candidate device
				stripeobj = cand_devobj.stripeobj
				# the declaring device and the candidate must agree on whether
				# they have object classes
				if devobj.objclasses:
					assert cand_devobj.objclasses
				if cand_devobj.objclasses:
					assert devobj.objclasses
				objclasses = devobj.objclasses
			else:
				# "strange" case: the register name does not match the device.
				# Group such registers by the part of the name before the
				# second underscore, in a stripe spanning 0..0xffffffff.
				us = reglike_name.find("_")
				us = reglike_name.find("_", us + 1)
				strange_stripeobj_name = reglike_name[:us]
				stripeobj = strange_stripeobjs.get(strange_stripeobj_name)
				if not stripeobj:
					stripeobj = Stripe()
					stripeobj.strange_name = strange_stripeobj_name
					stripeobj.offset = 0
					stripeobj.end = 0xffffffff
					strange_stripeobjs[strange_stripeobj_name] = stripeobj
					stripeobjs.append(stripeobj)

				objclasses = []

			assert address >= cand_devobj.offset
			assert address < cand_devobj.end
			# offsets are relative to the start of the candidate device
			offset = address - cand_devobj.offset

			# registers at offset 0 of a device with object classes are not emitted
			if offset == 0 and objclasses:
				continue

			# strip the name of the first matching device of the stripe from the register name
			rnn_reglike_name = reglike_name
			for device_name in stripeobj.names:
				if rnn_reglike_name.startswith(device_name + "_"):
					rnn_reglike_name = rnn_reglike_name[len(device_name) + 1:]
					break

			if rnn_reglike_name in uninteresting_reglike_names:
				continue

			rnn_reglike = ET.Element("reg32")

			rnn_reglike.attrib["offset"] = hex(offset)

			if reglike.tag == "array":
				# length is the largest <array-size> child; stride is the "width" attribute
				max_size = 0
				for field in reglike.iterchildren(tag=ET.Element):
					if field.tag == "array-size":
						size = int(field.attrib["size"], 0)
						if size > max_size:
							max_size = size
				rnn_reglike.attrib["length"] = str(max_size)
				rnn_reglike.attrib["stride"] = str(int(reglike.attrib["width"], 0))

			num_bitfields = 0
			full_field_name = None
			for field in reglike.iterchildren(tag=ET.Element):
				if field.tag == "field":
					# NOTE(review): the bit positions are evaluated with eval(), so
					# rules.xml is executed as Python here; only run this on trusted input.
					i = 0 # for "i" variables in the low-bit and high-bit attributes
					low = eval(field.attrib["low-bit"])
					high = eval(field.attrib["high-bit"])
					# field_name is module-level and is also read inside process_values()
					field_name = field.attrib["name"]
					rnn_field_name = field_name
					if rnn_field_name.startswith(reglike_name + "_"):
						rnn_field_name = rnn_field_name[len(reglike_name)+1:]

					if low == 0 and high == 31:
						# A field covering the whole register is folded into the
						# register itself instead of becoming a <bitfield>.
						if rnn_field_name not in uninteresting_field_names and not rnn_field_name.startswith("NV_"):
							full_field_name = rnn_field_name
						if rnn_field_name in field_name_to_type_map:
							rnn_reglike.attrib["type"] = field_name_to_type_map[rnn_field_name]
						bitfield = rnn_reglike
						# move the field's children up to the register so that the
						# process_values(reglike, ...) call below picks up its values
						for elem in field.iterchildren(tag=ET.Element):
							field.remove(elem)
							reglike.append(elem)
					else:
						bitfield = ET.Element("bitfield")
						rnn_reglike.append(bitfield)
						bitfield.attrib["low"] = str(low)
						bitfield.attrib["high"] = str(high)
						bitfield.attrib["name"] = rnn_field_name
						num_bitfields += 1
						process_values(field, bitfield)

			# a meaningful full-width field name is appended only when there are no bitfields
			if full_field_name is not None and num_bitfields == 0:
				rnn_reglike_name += "_" + full_field_name
			rnn_reglike.attrib["name"] = rnn_reglike_name
			# at this point field_name still holds the name of the last <field> visited above
			process_values(reglike, rnn_reglike)

			# members: offset -> {serialized reg32 -> objclasses}; identical
			# definitions at the same offset share one entry and accumulate classes
			stripeobj.members.setdefault(offset, {}).setdefault(ET.tostring(rnn_reglike), []).extend(objclasses)

# TODO: maybe the Python stdlib provides this function?
def extract_ranges(univ, cur):
	"""Return the runs of consecutive items of *univ* that are members of *cur*.

	univ -- iterable defining the order (the "universe").
	cur  -- iterable of hashable items to look for.

	The result is a list of (first, last) tuples of inclusive positions
	within *univ*, one per maximal run of selected items.
	"""
	wanted = frozenset(cur)
	ranges = []
	run_start = None
	total = 0
	for index, item in enumerate(univ):
		total = index + 1
		if item in wanted:
			# open a new run unless one is already in progress
			if run_start is None:
				run_start = index
		elif run_start is not None:
			# the run ended on the previous position
			ranges.append((run_start, index - 1))
			run_start = None
	# a run that reaches the end of univ is still open here
	if run_start is not None:
		ranges.append((run_start, total - 1))
	return ranges

# Pass 3: write one output file per non-empty stripe.  The rnn XML is built
# in memory and piped through ./xml2text, whose output lands in
# rules/<subdir>/<basename>.
for stripeobj in stripeobjs:
	if not stripeobj.members:
		continue

	# Reset for every stripe: stripes with offset 0 and no device name (the
	# "strange" ones) take neither branch that assigns this flag below, which
	# would otherwise leave it unbound or carrying the previous stripe's value.
	is_mmio = False

	stripe = ET.Element("stripe")

	database = ET.Element("database")
	database.attrib["xmlns"] = "http://nouveau.freedesktop.org/"
	# "colon" is a placeholder that is turned into ":" after serialization
	database.attrib["xmlns" + colon + "xsi"] = "http://www.w3.org/2001/XMLSchema-instance"
	database.attrib["xsi" + colon + "schemaLocation"] = "http://nouveau.freedesktop.org/ rules-ng.xsd"

	# (offset, {serialized reg32 -> objclasses}) pairs in offset order
	members = sorted(stripeobj.members.items())

	# union of all (name, number) object classes used anywhere in the stripe
	objclasses = set()
	for offset, versions in members:
		for version_objclasses in versions.values():
			objclasses.update(version_objclasses)
	objclasses = sorted(objclasses)

	# assert no objclass has conflicting names
	# could actually handle this, but it's not necessary
	# NOTE(review): the tuples are (name, number), so what is checked here is
	# that no NAME occurs with two different numbers; a number carrying two
	# different names is not detected.  Behavior kept as is -- confirm intent.
	objclass_names = set([objname for objname, objnum in objclasses])
	assert len(objclasses) == len(objclass_names)

	for offset, versions in members:
		versions_sorted = [(version_objclasses, xml) for xml, version_objclasses in versions.items()]
		versions_sorted.sort()
		for version_objclasses, xml in versions_sorted:
			elem = ET.fromstring(xml)

			# a definition used by only some classes of the stripe is tagged
			# with the class names (or name ranges) it applies to
			if version_objclasses and set(objclasses) != set(version_objclasses):
				ranges = extract_ranges(objclasses, version_objclasses)
				elem.attrib["variants"]  = " ".join([(objclasses[a][0] if a == b else objclasses[a][0] + "-" + objclasses[b][0]) for a, b in ranges])
			stripe.append(elem)

	if stripeobj.objsuper:
		# object stripe: emit the class enum plus a subchannel domain
		objenum = ET.Element("enum")
		objenum.attrib["bare"] = "yes"
		objenum.attrib["name"] = "obj-class"

		for objname, objclass in objclasses:
			objvalue = ET.Element("value")
			# classes without a number get a name-only <value>
			if objclass is not None:
				objvalue.attrib["value"] = hex(objclass)
			objvalue.attrib["name"] = objname
			objenum.append(objvalue)

		database.append(objenum)

		subchan = ET.Element("domain")
		subchan.attrib["name"] = "NV01_SUBCHAN"
		subchan.attrib["bare"] = "yes"
		subchan.attrib["size"] = "0x2000"

		stripe.attrib["prefix"] = "obj-class"
		#stripe.attrib["varset"] = "obj-class"
		if len(objclasses) == 1:
			stripe.attrib["variants"] = objclasses[0][0]
		else:
			stripe.attrib["variants"] = objclasses[0][0] + "-" + objclasses[-1][0]

		subchan.append(stripe)
		database.append(subchan)
	else:
		# address or "strange" stripe: at most one device name is expected
		name = None
		if stripeobj.names:
			assert len(stripeobj.names) == 1
			name = list(stripeobj.names)[0]

		mmio = ET.Element("domain")
		if stripeobj.offset != 0 or name == "NV_PMC":
			mmio.attrib["name"] = "NV_MMIO"
			mmio.attrib["bare"] = "yes"
			is_mmio = True
		elif name:
			# offset-0 device outside the MMIO domain: gets a domain of its own
			mmio.attrib["name"] = name
		mmio.attrib["prefix"] = "chipset"

		if is_mmio and name:
			stripe.attrib["name"] = name
		stripe.attrib["offset"] = hex(stripeobj.offset)
		stripe.attrib["stride"] = hex(stripeobj.end - stripeobj.offset)
		mmio.append(stripe)
		database.append(mmio)

	# output file name: lower-cased objsuper, strange name or device name,
	# with a leading "NV_" removed in the latter two cases
	if stripeobj.objsuper:
		basename = stripeobj.objsuper.lower()
	else:
		basename = name
		if stripeobj.strange_name:
			basename = stripeobj.strange_name

		if basename.startswith("NV_"):
			basename = basename[3:]
		basename = basename.lower()

	if stripeobj.objsuper:
		subdir = "objects"
	elif stripeobj.strange_name:
		subdir = "strange"
	elif is_mmio:
		subdir = "mmio"
	else:
		subdir = "misc"
	filename = outdir + "/" + subdir + "/" + basename

	# "with" guarantees the file is closed even if starting or feeding
	# xml2text fails; communicate() already waits for the child to exit.
	with open(filename, "w") as outfile:
		xml2text = subprocess.Popen("./xml2text", stdin = subprocess.PIPE, stdout = outfile)
		xml2text.communicate(ET.tostring(database).replace(colon, ":"))

