# === UTILITIES ===

import math

def choose(n,r): # nCr
	"""Return the binomial coefficient C(n, r) as an exact int.

	Returns 0 when r < 0 or r > n (there is no way to choose that many).
	"""
	if r < 0 or r > n:
		return 0
	f = math.factorial
	# Integer division keeps the result exact; '/' would round-trip through
	# float and lose precision once the coefficient exceeds 2**53.
	return f(n) // (f(r) * f(n-r))

# === EXCEPTIONS ===

class VectorError (Exception):
	"""Base class for the vector-specific errors raised in this module."""
	pass

class DimensionMismatchError (VectorError):
	"""Raised when two dimension values that must agree do not.

	dim1 and dim2 hold the two conflicting values, in the order the raising
	code supplied them (callers use -1 / 0 as placeholders in some cases).
	"""
	def __init__(self,dim1,dim2):
		# Pass both values up so e.args == (dim1, dim2); this keeps the
		# exception picklable (unpickling re-calls cls(*args)).
		super().__init__(dim1, dim2)
		self.dim1 = dim1
		self.dim2 = dim2

	def __str__(self):
		return 'dimension mismatch: {} vs {}'.format(self.dim1, self.dim2)

# === VECTOR CLASS ===

class multivector:
	"""A linear combination of blades (a multivector) in d-dimensional space.

	Coefficients are stored in self.data, a flat list of length 2**d indexed
	by bitmask: bit k of an index is set when basis vector k is a factor of
	that blade. Index 0 is the scalar slot and index 2**d - 1 the antiscalar;
	the grade of a slot is the number of set bits in its index.
	The products below treat every basis vector as squaring to +1.

	The initializer takes d, the number of dimensions, and an optional fill
	that selects the starting coefficients (see __init__).
	"""
	
	def __init__(self, d=4, fill=0):
		"""Initialize dimensions to d and data to 2**d coefficients.

		fill == 'empty' leaves data as an empty list (the caller appends the
		coefficients itself); fill == 'seq' sets data to [0, 1, ..., 2**d - 1];
		any other value is copied into every slot (0 by default).
		"""
		self.dim = d
		if fill == 'empty': # Special to override initialization if necessary
			self.data = []
		elif fill == 'seq':
			# Must be a real list: a bare range has no .copy() and rejects item assignment
			self.data = list(range(2**d))
		else:
			self.data = [fill]*(2**d)
	
	@staticmethod
	def unit_antiscalar(d):
		"""Return the d-dimensional unit antiscalar: zeros everywhere except a 1 in the last slot."""
		out = multivector(d)
		out.data[-1] = 1 # The last slot (all bits set) is the antiscalar
		return out
	
	def __str__(self):
		"""Return a fancy string representation of a blade or multivector.

		Every grade with a non-zero coefficient becomes one line of the form
		'<prefix>vector: < AXES value, ... >'. A single such grade is returned
		as '<dim>D <line>'; several are listed under a '<dim>-MULTIVECTOR'
		header, which becomes COMPLEX (other grade dim-1) or PARAVECTOR (other
		grade 1) when there are exactly two grades and the scalar is non-zero.
		A multivector with no non-zero coefficient prints as
		'<empty <dim>-multivector>'.
		"""
		parts = []
		highest = -1 # Highest grade seen with a non-zero coefficient
		for grade, part in enumerate(self.components()):
			if part and not all(val == 0 for (_, val) in part):
				line = prefix[grade]+'vector: < '
				line += ', '.join(axes(idx)+' '+str(val) for (idx, val) in part)
				line += ' >'
				parts.append(line)
				highest = grade
		if not parts:
			return '<empty '+str(self.dim)+'-multivector>'
		elif len(parts) == 1:
			return str(self.dim)+'D '+parts[0]
		else:
			title = "MULTIVECTOR"
			if len(parts) == 2 and self.scalar():
				if highest == self.dim-1:
					title = "COMPLEX"
				elif highest == 1:
					title = "PARAVECTOR"
			return '\n'.join(
				[str(self.dim)+'-'+title] +
				['\t'+part for part in parts] )
	
	@staticmethod
	def vectorize(vec,dimset=-1):
		"""Create a multivector from some sort of input.
		
		A two-dimensional list is taken as a list of blades: row g holds the
		C(dim, g) grade-g coefficients in index order, for g = 0..dim.
		A one-dimensional list is taken as a vector.
		Anything else is taken as a scalar.
		
		Optional second argument forces a certain number of dimensions, which
		is necessary for a scalar.

		Returns None for None or an empty list. Raises DimensionMismatchError
		when the rows do not form a valid grade layout, when the result
		disagrees with dimset, or when a scalar is given without dimset.
		"""
		if vec is None: # Nothing to convert
			return None
		
		if hasattr(vec, '__iter__'): # It's a list of some sort
			if len(vec) == 0: # Don't try to deal with empty lists
				return None
			
			if hasattr(vec[0], '__iter__'): # 2D list
				# A single row is a scalar or smaller vector in disguise.
				if len(vec) == 1:
					return multivector.vectorize(vec[0], dimset)
				
				row = [len(r) for r in vec]
				dim = row[1] # Row 1 holds the haplovectors, one per dimension
				if len(vec) != dim+1: # Need one row per grade: 0, 1..dim-1, dim
					raise DimensionMismatchError(len(vec)-1, dim)
				if dimset != -1 and dimset != dim: # If we're forcing a dimension number...
					raise DimensionMismatchError(dimset, dim)
				
				for grade, length in enumerate(row): # Each row must match its entry in Pascal's triangle
					if length != choose(dim, grade):
						raise DimensionMismatchError(length, choose(dim, grade))
				
				# Now transfer it. Iterators consume each row in order without
				# copying or mutating the caller's lists.
				output = multivector(dim)
				rows = [iter(r) for r in vec]
				for i in range(2 ** dim):
					output.data[i] = next(rows[hamming(i)])
				return output
			
			# This is a 1D list which should become a vector.
			dim = len(vec)
			if dimset != -1 and dimset != dim: # If we're forcing a dimension it has to match
				raise DimensionMismatchError(dim, dimset)
			
			output = multivector(dim)
			for i, x in enumerate(vec): # Basis vector i lives at index 1<<i
				output.data[1<<i] = x
			return output
		
		# Must be a scalar - if not, nonsense will result later!
		if dimset == -1: # Need a dimension specified to make this work
			raise DimensionMismatchError(-1, 0)
		output = multivector(dimset)
		output.data[0] = vec # A scalar of 0 is a valid (zero) multivector
		return output

	def clone(self):
		"""Return a copy with its own data list (coefficient objects themselves are shared)."""
		output = multivector(self.dim)
		output.data = self.data.copy()
		return output
	
	def component(self, d):
		"""Return the grade-d part as a list of (index, coefficient) pairs, or None if d is out of range."""
		if d > self.dim or d < 0:
			return None
		return [(i, val) for i, val in enumerate(self.data) if hamming(i) == d]
	
	def icomponent(self, d):
		"""Return the grade-d part as a multivector (other grades zeroed), or None if d is out of range."""
		if d > self.dim or d < 0:
			return None
		out = multivector(self.dim)
		for i, val in enumerate(self.data):
			if hamming(i) == d: # Is this a component we want?
				out.data[i] = val
		return out
	
	def components(self): # Get all components at once
		"""Return a list indexed by grade; entry g lists the (index, coefficient) pairs of grade g."""
		out = [ [] for _ in range(self.dim+1) ]
		for i, val in enumerate(self.data):
			out[hamming(i)].append((i,val))
		return out
	
	def __getitem__(self, n): # Return a particular component
		"""mv[n] -> grade-n part as a multivector; IndexError if n is outside 0..dim."""
		if n < 0 or n > self.dim:
			raise IndexError(n)
		return self.icomponent(n)
	
	def __setitem__(self, n, v): # Assign to a particular component
		"""mv[n] = values: overwrite the grade-n coefficients, in index order, from values.

		values must supply at least C(dim, n) items (IndexError otherwise);
		the caller's sequence is not modified.
		"""
		if n < 0 or n > self.dim:
			raise IndexError(n)
		values = list(v) # Private copy so we never consume the caller's list
		k = 0
		for i in range(2**self.dim):
			if hamming(i) == n:
				self.data[i] = values[k]
				k += 1
	
	def scalar(self): # Get only scalar part, a bit faster than calling component(0)
		"""Return the grade-0 coefficient."""
		return self.data[0]
	
	def vector(self): # Ditto for haplovector part
		"""Return the grade-1 coefficients as a list [x0, x1, ..., x(dim-1)]."""
		return [self.data[2**i] for i in range(self.dim)]
	
	def invert(self):
		"""Negate every component in place and return self."""
		for i, val in enumerate(self.data):
			self.data[i] = -val
		return self
	
	def __neg__(self):
		"""Unary - operator; returns a negated copy."""
		output = self.clone()
		output.invert()
		return output
	
	def __add__(self, other):
		"""Binary + operator; component-wise addition with multivector, list, or scalar."""
		isvec, output = is_multi(other)
		if isvec:
			return self.mvec_add(output)
		else:
			return self.simple_add(output)
	
	def __radd__(self, other):
		return self + other # Addition is commutative
	
	def mvec_add(self, other):
		"""Add two multivectors component-wise and return the result."""
		if self.dim != other.dim:
			raise DimensionMismatchError(self.dim, other.dim)
		output = multivector(self.dim, fill='empty') # Don't want zeroes in this one
		for (val1, val2) in zip(self.data, other.data):
			output.data.append(val1 + val2)
		return output
	
	def simple_add(self, other):
		"""Add a scalar to the scalar slot of a copy (faster than vectorize())."""
		output = self.clone()
		output.data[0] += other
		return output
	
	def __sub__(self, other):
		"""Binary - operator; component-wise subtraction with multivector, list, or scalar."""
		_, output = is_multi(other)
		return self + (-output)
	
	def __rsub__(self, other): # Subtraction is anticommutative
		return other + (-self)
	
	def __iadd__(self, other):
		"""In-place +=; a scalar operand is added to the scalar slot only."""
		isvec, output = is_multi(other)
		if not isvec:
			self.data[0] += output
			return self
		if self.dim != output.dim:
			raise DimensionMismatchError(self.dim, output.dim)
		for i, val in enumerate(output.data):
			self.data[i] += val
		return self
	
	def __isub__(self, other):
		"""In-place -=; a scalar operand is subtracted from the scalar slot only."""
		isvec, output = is_multi(other)
		if not isvec:
			self.data[0] -= output
			return self
		if self.dim != output.dim:
			raise DimensionMismatchError(self.dim, output.dim) # output, not other: other may be a plain list
		for i, val in enumerate(output.data):
			self.data[i] -= val
		return self
	
	def inner_product(self, other):
		"""Return the inner product: the grade-|r-s| parts of all grade-r x grade-s products.

		Summing component products directly only works for vectors, so this
		goes through contract(other, 'inner').
		"""
		if self.dim != other.dim:
			raise DimensionMismatchError(self.dim, other.dim)
		return self.contract(other, 'inner')
	
	def geom_outer_product(self, other, outer):
		"""Shared kernel for the geometric (outer=False) and outer/wedge (outer=True) products."""
		if self.dim != other.dim:
			raise DimensionMismatchError(self.dim, other.dim)
		out = multivector(self.dim)
		for i, a in enumerate(self.data):
			for j, b in enumerate(other.data):
				if bool(i&j) and outer: # When taking the outer product, any parts that are not fully independent (i.e. they share at least one basis) will become zero
					continue # So we can ignore them
				k = i^j # New index is xor of input indices, since two of the same cancel out
				s = multivector.ordering_sign(i, j) # Do we need to flip this? (Anticommutativity!)
				out.data[k] += s*a*b # Now add the product (with the right sign) to the specified component
		return out
	
	def geom_product(self, other):
		"""Return the geometric product self @ other."""
		return self.geom_outer_product(other, False)
	
	def outer_product(self, other):
		"""Return the outer (wedge) product self % other."""
		return self.geom_outer_product(other, True)
	
	def contract(self, other, dir='left'):
		"""Return a contraction of self with other.

		For every grade-r part of self and grade-s part of other, keep the
		grade-k part of their geometric product where k is s-r ('left'),
		r-s ('right') or |r-s| ('inner'); negative k contributes nothing.
		Raises ValueError for any other dir.
		"""
		if self.dim != other.dim:
			raise DimensionMismatchError(self.dim, other.dim)
		if dir == 'left':
			target = lambda r, s: s - r
		elif dir == 'right':
			target = lambda r, s: r - s
		elif dir == 'inner':
			target = lambda r, s: abs(r - s)
		else:
			raise ValueError("unknown contraction direction: " + repr(dir))
		# Grades run 0..dim inclusive; the top (antiscalar) grade must be included.
		mine = [self[r] for r in range(self.dim + 1)]
		theirs = [other[s] for s in range(other.dim + 1)]
		out = multivector(self.dim)
		for r, a in enumerate(mine):
			for s, b in enumerate(theirs):
				k = target(r, s)
				if k >= 0:
					out += (a @ b)[k]
		return out
	
	def scalar_product(self, other):
		"""Return a copy with every coefficient multiplied by the scalar other (other on the left)."""
		output = self.clone()
		output.data[:] = [other*i for i in output.data]
		return output
	
	def lie_product(self, other): # Commutator product, isomorphic to Lie bracket over bivectors
		"""Return (self@other - other@self) / 2; can be used to find e.g. the intersection of planes."""
		return (self@other - other@self) / 2
	
	@staticmethod
	def ordering_sign(i, j): # We may need to flip indices to get the result into canonical order: XY => XY (0), YX => -XY (1), ZYX => -ZXY => XZY => -XYZ (3). Every time we do this we need to flip the sign as well because the wedge product is anticommutative.
		"""Return -1 or +1: the sign picked up reordering blade i followed by blade j into canonical order."""
		# Note: i == j is NOT a shortcut to +1 -- that only holds for some grades.
		if not i or not j: return 1 # A scalar factor never needs reordering
		flip = False # No flip necessary right away
		while i:
			i >>= 1 # The least-significant-bit should be 'first', so we shift the first value toward it
			if parity(i&j): # Now count how many times we 'ran into' something while shifting. We want i (the first value) to have all the less significant bits, so every time it 'collides' with a bit in j (the second value), we need to flip something.
				flip = not flip # Because -1*-1=1 we only care whether this happened an even or odd number of times
		return -1 if flip else 1 # Return a number which can be multiplied by the result
	
	def __mul__(self, other): # Scalar product or inner/dot product
		"""mv * x: inner product if x is a multivector/list, otherwise scalar multiplication."""
		isvec, output = is_multi(other)
		if isvec:
			return self.inner_product(output)
		else:
			return self.scalar_product(other)
	
	def __rmul__(self, other): # Reverse version of scalar or inner product
		"""x * mv: inner product output.inner_product(self) for vector-like x, else scalar multiplication."""
		isvec, output = is_multi(other)
		if isvec:
			return output.inner_product(self) # output, not other: other may be a plain list
		else:
			return self.scalar_product(other)
	
	def __mod__(self, other): # Outer/wedge product
		"""mv % x: outer (wedge) product; a scalar x degenerates to scalar multiplication."""
		isvec, output = is_multi(other)
		if not isvec:
			return self.scalar_product(other)
		return self.outer_product(output)
	
	def __rmod__(self, other): # Outer/wedge product in reverse
		"""x % mv: outer product with x on the left; a scalar x degenerates to scalar multiplication."""
		isvec, output = is_multi(other)
		if not isvec:
			return self.scalar_product(other)
		return output.outer_product(self)
	
	def __matmul__(self, other): # Geometric/direct product
		"""mv @ x: geometric product; a scalar x degenerates to scalar multiplication."""
		isvec, output = is_multi(other)
		if not isvec:
			return self.scalar_product(other)
		return self.geom_product(output)
	
	def __rmatmul__(self, other): # Geometric/direct product in reverse
		"""x @ mv: geometric product with x on the left; a scalar x degenerates to scalar multiplication."""
		isvec, output = is_multi(other)
		if not isvec:
			return self.scalar_product(other)
		return output.geom_product(self)
	
	def __lshift__(self, other): # Left contraction
		isvec, output = is_multi(other)
		if not isvec:
			raise DimensionMismatchError(self.dim, 0)
		return self.contract(output, 'left')
	
	def __rlshift__(self, other): # Left contraction in reverse
		isvec, output = is_multi(other)
		if not isvec:
			raise DimensionMismatchError(self.dim, 0)
		return output.contract(self, 'left')
	
	def __rshift__(self, other): # Right contraction
		isvec, output = is_multi(other)
		if not isvec:
			raise DimensionMismatchError(self.dim, 0)
		return self.contract(output, 'right')
	
	def __rrshift__(self, other): # Right contraction in reverse
		isvec, output = is_multi(other)
		if not isvec:
			raise DimensionMismatchError(self.dim, 0)
		return output.contract(self, 'right')
	
	@staticmethod
	def _require_scalar(other):
		"""Raise DimensionMismatchError(dim_of_other, 0) unless other is a scalar."""
		if not is_scalar(other):
			x = other.dim if isinstance(other, multivector) else len(other) # Get the dimension of the vector
			raise DimensionMismatchError(x, 0)
	
	def __truediv__(self, other):
		"""mv / scalar: divide every coefficient by other."""
		multivector._require_scalar(other)
		out = multivector(self.dim)
		out.data = [i/other for i in self.data]
		return out
	
	def __floordiv__(self, other):
		"""mv // scalar: floor-divide every coefficient by other."""
		multivector._require_scalar(other)
		out = multivector(self.dim)
		out.data = [i//other for i in self.data]
		return out
	
	def __imul__(self, other): # Inline scalar multiplication
		"""mv *= scalar; returns self so the name stays bound to this multivector."""
		multivector._require_scalar(other)
		self.data[:] = [other*i for i in self.data]
		return self
	
	def __itruediv__(self, other): # Inline scalar division
		"""mv /= scalar: divide each coefficient by other; returns self."""
		multivector._require_scalar(other)
		self.data[:] = [i/other for i in self.data]
		return self
	
	def __ifloordiv__(self, other): # Inline scalar floor division
		"""mv //= scalar: floor-divide each coefficient by other; returns self."""
		multivector._require_scalar(other)
		self.data[:] = [i//other for i in self.data]
		return self
	
	def reversion(self): # Reversion antiautomorphism
		"""Return the reverse: grade-k components are negated when k % 4 is 2 or 3."""
		out = multivector(self.dim, fill='empty')
		for i, a in enumerate(self.data):
			s = (hamming(i) // 2) % 2 # Is the (n+2)th triangular number odd or even? If it's odd, it takes an odd number of flips to reverse the order of the bases, so the result will be inverted. If it's even, it takes an even number, so the sign remains. This pattern goes (0 0 1 1 0 0 1 1...) which is floor(n/2)%2
			out.data.append(-a if s else a) # Append either a or -a depending whether we flipped or not
		return out
	
	def __invert__(self): # Ditto
		"""~mv is the reversion."""
		return self.reversion()
	
	def conjugate(self): # Clifford/bar conjugation = combination of grade involution and reversion
		"""Return the Clifford (bar) conjugate as a new multivector.

		A grade-k component is negated when k(k+1)/2 is odd, i.e. the sign
		pattern over k = 0, 1, 2, 3, ... is (+ - - + + - - + ...).
		"""
		out = multivector(self.dim, fill='empty')
		for i, a in enumerate(self.data):
			flip = ((hamming(i) + 1) // 2) % 2 # 0 1 1 0 0 1 1 0 ...
			out.data.append(-a if flip else a)
		return out
	
	def magnsqr(self): # Square of the magnitude
		"""Return the scalar part of self @ ~self."""
		prod = self @ ~self # Geometric product of B and revert-B
		return prod.scalar()
	
	def magn(self): # Magnitude
		"""Return sqrt(magnsqr())."""
		return math.sqrt(self.magnsqr()) # Square root of result = "magnitude"
	
	def __pos__(self): # Ditto
		"""+mv returns the magnitude (a number), not a multivector."""
		return self.magn()
	
	def rotate(self, v): # Rotate a multivector using this bivector as a rotor
		"""Return the sandwich product self @ v @ ~self.

		NOTE(review): self is not checked to be a rotor; results for other
		multivectors are whatever the sandwich product gives.
		"""
		if self.dim != v.dim:
			raise DimensionMismatchError(self.dim, v.dim)
		return self @ v @ ~self
	
	def oneover(self): # Multiplicative inverse
		"""Return ~self / magnsqr().

		This is the true inverse whenever self @ ~self is a pure scalar
		(vectors, bivectors, blades, versors); for other multivectors it is not.
		"""
		return ~self / self.magnsqr()
	
	def normalize(self): # Normalize to a unit multivector
		"""Return self divided by its magnitude."""
		return self / +self
	
	def exp(self): # e^B
		"""Return cos|B| + (B/|B|) sin|B| (Euler's formula).

		This equals e**B only when B/|B| squares to -1 (e.g. a simple
		Euclidean bivector). A zero input returns the scalar 1.
		"""
		theta = +self # Magnitude
		if theta == 0: # Don't try to normalize in this case!
			return multivector(self.dim) + 1
		i = self / theta # Convert B to iθ for some scalar θ and unit bivector i
		return math.cos(theta) + i*math.sin(theta) # Euler's formula!
	
	def makerotor(self, theta, prenorm=False): # Rotor by angle theta with my plane and direction
		"""Return cos(theta/2) + i sin(theta/2), i being self normalized (or self as-is if prenorm)."""
		theta /= 2 # Remove extraneous factor of 2
		if prenorm:
			i = self
		else:
			i = self.normalize()
		return math.cos(theta) + i*math.sin(theta)
	
	def hodge(self): # Hodge dual = geometric product with unit antiscalar
		"""Return unit_antiscalar(dim) @ self."""
		return (multivector.unit_antiscalar(self.dim) @ self)
	
	@staticmethod
	def project_down(v, method='orthographic'):
		"""Drop the last coordinate of v (a list, or a multivector's vector part).

		'orthographic' simply discards it; 'perspective' first divides the
		other coordinates by |last coordinate| (focal length 1), so a last
		coordinate of 0 raises ZeroDivisionError. Other methods raise
		NotImplementedError.
		"""
		if isinstance(v, multivector):
			v = v.vector() # Reduce to haplovector
		if method == 'orthographic':
			return [i for i in v[:-1]]
		elif method == 'perspective':
			alpha = abs(v[-1])
			f = 1 # Focal length for the projection - might want to actually calculate this? Nah...
			return [i*f/alpha for i in v[:-1]]
		else:
			raise NotImplementedError("'"+method+"' projection method")
	
	def projection(self, target, method='orthographic'):
		"""Return the vector part as a list of length target.

		Lower-dimensional multivectors are zero-padded; higher ones are
		projected down one axis at a time with project_down(method).
		"""
		if self.dim <= target:
			add = [0] * (target - self.dim)
			return self.vector() + add
		
		v = self.vector()
		for _ in range(self.dim - target):
			v = multivector.project_down(v, method)
		return v
	
	def is_null(self):
		"""Return True when every coefficient is falsy (zero)."""
		return all(not a for a in self.data)

# === UTILITIES ===

def hamming(n): # Hamming weight
	"""Return how many '1' digits appear in bin(n), i.e. the popcount of n."""
	return sum(1 for digit in bin(n) if digit == '1')

def parity(n): # Whether the number of 'on' bits is odd
	"""Return True when n has an odd number of set bits."""
	ones = bin(n).count('1')
	return ones % 2 == 1

def axes(n):
	"""Name the blade with bitmask n: one axis letter per set bit, lowest bit first; 'S' for 0."""
	if n == 0: # The scalar has no axes
		return 'S'
	digits = bin(n)[::-1] # Reversed so that position k corresponds to bit k
	return ''.join([axisname[k] for k, digit in enumerate(digits) if digit == '1'])

# Letter for basis vector k (bit k of a blade index).
axisname = ['X','Y','Z','W','A','B','C','D','E','F','G']

# Name prefix for grade g, used when printing multivectors.
prefix = ['ceno','haplo','bi','tri','tetra','penta','hexa','septi','octo','ennea','deca']

def is_multi(input): # Return a tuple of (whether this is a multivector, its multivector representation)
	"""Classify an operand for arithmetic.

	Returns (True, input) for a multivector, (True, multivector.vectorize(input))
	for any other object with __iter__, and (False, input) otherwise; callers
	treat the False case as a scalar.
	"""
	if isinstance(input, multivector):
		return (True, input)
	if not hasattr(input, '__iter__'):
		return (False, input)
	return (True, multivector.vectorize(input))

def is_scalar(input):
	"""Return True unless input is a multivector or has __iter__ (lists, tuples, ...)."""
	return not (isinstance(input, multivector) or hasattr(input, '__iter__'))

# === TEST ===

# from random import randint

# a = multivector()
# print(a)
# print()

# a.data[3] = 42
# print(a)
# print()

# for i in range(16):
	# a.data[i] = randint(1,100)
# print(a)