#!/usr/bin/env python
#
#############################################################################
#
# MODULE:      i.plr.py
# AUTHOR(S):   Georg Kaspar
# PURPOSE:     Probabilistic label relaxation, postclassification filter
# COPYRIGHT:   (C) 2009
#
#              This program is free software under the GNU General Public
#              License (>=v2). Read the file COPYING that comes with GRASS
#              for details.
#
#############################################################################

#%Module
#% description: Probabilistic label relaxation, postclassification filter
#%End
#%option
#% key: group
#% type: string
#% description: image group to be used
#% required : yes
#%end
#%option
#% key: subgroup
#% type: string
#% description: image subgroup to be used
#% required : yes
#%end
#%option
#% key: sigfile
#% type: string
#% description: Path to sigfile
#% required : yes
#%end
#%option
#% key: output
#% type: string
#% description: Name for resulting raster file
#% required : yes
#%end
#%option
#% key: iterations
#% type: integer
#% description: Number of iterations (1 by default)
#% required : no
#%end
#%option
#% key: ntype
#% type: integer
#% description: type of neighbourhood (4(default) or 8)
#% required : no
#%end

import fnmatch
import os
import sys

import numpy
import grass.script as grass
from osgeo import gdal, gdalnumeric, gdal_array
from osgeo.gdalconst import GDT_Byte

def getMetadata():
    """Store the active GRASS database, location and mapset as module globals.

    Reads grass.gisenv() and sets the globals GISDBASE, LOCATION_NAME and
    MAPSET from its keys of the same name.
    """
    global GISDBASE, LOCATION_NAME, MAPSET
    gisenv = grass.gisenv()
    GISDBASE = gisenv['GISDBASE']
    LOCATION_NAME = gisenv['LOCATION_NAME']
    MAPSET = gisenv['MAPSET']
        
def splitSignatures(path, sigfile):
    """Split a signature file into one file per signature.

    The first line of the input file is skipped. Every following line that
    starts with "#" begins a new signature: it opens ``plr_<n>.sig`` in
    ``path`` (n counting from 1), which starts with the line
    "#produced by i.plr" followed by the "#" line and the signature's lines.
    Lines that appear before the first "#" line go to ``plr_foo.sig``.

    :param path: directory holding the signature file; output goes there too
    :param sigfile: name of the signature file inside ``path``
    :returns: number of signatures written (0 for an empty file)
    """
    counter = 0
    with open(os.path.join(path, sigfile), "r") as stream_in:
        next(stream_in, None)  # skip first line; tolerate an empty file
        stream_out = open(os.path.join(path, "plr_foo.sig"), "w")
        try:
            for line in stream_in:
                if line.startswith("#"):
                    stream_out.close()
                    counter += 1
                    stream_out = open(
                        os.path.join(path, "plr_" + str(counter) + ".sig"), "w")
                    stream_out.write("#produced by i.plr\n")
                stream_out.write(line)
        finally:
            # close() is idempotent, so this is safe even right after a
            # failed open of the next output file
            stream_out.close()
    return counter
    
def normalizeProbabilities(counter):
    """Normalize the class maps plr_rej_1..plr_rej_<counter> to sum to 1.

    Creates the raster ``multiplicand`` = 1 / (sum of all plr_rej_<i>) and,
    for every class i, the raster ``plr_rej_norm_<i>`` =
    plr_rej_<i> * multiplicand.

    Class numbering starts at 1, the same numbering main() uses when it
    creates the plr_rej_<i> maps.

    :param counter: number of classes; nothing is done when it is below 1
    """
    if counter < 1:
        # "1./()" would be an invalid r.mapcalc expression
        return
    class_ids = range(1, counter + 1)
    arg = "+".join("plr_rej_" + str(i) for i in class_ids)
    print("calculating multiplicands, arg=" + arg)
    grass.run_command("r.mapcalc", multiplicand="1./(" + arg + ")")
    for i in class_ids:
        print("normalizing probabilities for class " + str(i))
        grass.run_command("r.mapcalc",
                          plr_rej_norm="plr_rej_" + str(i) + "*multiplicand")
        grass.run_command("g.rename",
                          rast="plr_rej_norm,plr_rej_norm_" + str(i))
        
def getProbability(a, b):
    """Return the compatibility coefficient between class labels a and b.

    Placeholder: identical labels give 1, any two different labels give 0.5.
    """
    # TODO: Implement this!!! (real compatibility coefficients still missing)
    return 1 if a == b else 0.5
    
def cleanUp(path):
    """Remove the temporary files and rasters created by this script.

    Deletes files matching ``plr_*.*`` in ``path`` and in /tmp, then removes
    all ``plr_*`` rasters with g.mremove. Cleanup is best-effort: a file
    that cannot be removed is reported and skipped.

    :param path: directory that received the split signature files
    """
    # Use os.remove instead of "rm" through the shell: path includes the
    # user-supplied group/subgroup names, which must not be shell-interpreted.
    for directory in (path, "/tmp"):
        try:
            names = os.listdir(directory)
        except OSError:
            continue
        for name in fnmatch.filter(names, "plr_*.*"):
            target = os.path.join(directory, name)
            try:
                os.remove(target)
            except OSError as err:
                print("could not remove " + target + ": " + str(err))
    grass.run_command("g.mremove", flags="f", quiet=True, rast="plr_*")
    
def plr_filter(probabilities, width, height, classes, type):
    """Apply one iteration of probabilistic label relaxation.

    For every interior pixel and class z (1-based) the new value is
    ``P[z] * q[z]``, where the neighbourhood support is
    ``q[z] = sum over neighbour cells n of sum over classes c of
    getProbability(z, c) * P[c, n]``.
    Pixels on the image border keep their input values.

    :param probabilities: array indexed [class-1, row, col]; only the first
        ``classes`` x ``height`` x ``width`` entries are used
    :param width: number of columns
    :param height: number of rows
    :param classes: number of classes
    :param type: 8 selects the full 3x3 window (centre included); any other
        value selects left, centre, right, up and down
    :returns: float array of shape (classes, height, width)
    """
    print("starting label relaxation")
    probs = numpy.asarray(probabilities, dtype=float)[:classes, :height, :width]

    # compat[z-1, c-1] = getProbability(z, c); reshape keeps (0, 0) for 0 classes
    labels = range(1, classes + 1)
    compat = numpy.array([[getProbability(z, c) for c in labels] for z in labels],
                         dtype=float).reshape(classes, classes)

    # getProbability does not depend on position, so q factorizes:
    # q[z] = sum_c compat[z, c] * S[c], with S[c] the neighbourhood sum of
    # class c. Each slice below is the interior shifted by one cell.
    centre = probs[:, 1:-1, 1:-1]
    support = (centre
               + probs[:, 1:-1, :-2]   # left   (x-1)
               + probs[:, 1:-1, 2:]    # right  (x+1)
               + probs[:, :-2, 1:-1]   # up     (y-1)
               + probs[:, 2:, 1:-1])   # down   (y+1)
    if type == 8:
        support = support + (probs[:, :-2, :-2] + probs[:, :-2, 2:]
                             + probs[:, 2:, :-2] + probs[:, 2:, 2:])
    q = numpy.tensordot(compat, support, axes=1)

    # border cells keep the input probabilities; interior gets P * q
    results = probs.copy()
    results[:, 1:-1, 1:-1] = centre * q
    print("")
    return results
           
def neighbourhoodFunction4(probabilities, x, y, z, classes):
    """Return the 4-neighbourhood support of class z at pixel (x, y).

    Sums, over the cells left, centre, right, above and below (x, y), the
    value sum_c getProbability(z, c) * probabilities[c-1, row, col] for
    c = 1..classes. Indices are not bounds-checked.
    """
    cells = ((x - 1, y), (x, y), (x + 1, y), (x, y - 1), (x, y + 1))
    total = 0
    for col, row in cells:
        cell_support = 0
        for c in range(1, classes + 1):
            cell_support += getProbability(z, c) * float(probabilities[c - 1, row, col])
        total += cell_support
    return total
        
def neighbourhoodFunction8(probabilities, x, y, z, classes):
    """Return the 3x3-window support of class z at pixel (x, y).

    Sums, over the nine cells of the window centred on (x, y) (row by row),
    the value sum_c getProbability(z, c) * probabilities[c-1, row, col] for
    c = 1..classes. Indices are not bounds-checked.
    """
    window = [(row, col)
              for row in (y - 1, y, y + 1)
              for col in (x - 1, x, x + 1)]
    return sum(
        sum(getProbability(z, c) * float(probabilities[c - 1, row, col])
            for c in range(1, classes + 1))
        for row, col in window)
    
def createMap(probabilities, width, height, classes):
    """Return the most probable class label (1-based) for every pixel.

    :param probabilities: array indexed [class-1, row, col]; only the first
        ``classes`` x ``height`` x ``width`` entries are used
    :param width: number of columns
    :param height: number of rows
    :param classes: number of classes to compare
    :returns: float array of shape (height, width) with labels 1..classes;
        on ties the lowest class number wins. With fewer than one class
        every label is 1.
    """
    print("retrieving class labels")
    if classes < 1:
        return numpy.ones((height, width))
    probs = numpy.asarray(probabilities)[:classes, :height, :width]
    # argmax returns the first maximum, i.e. the lowest class on ties.
    # NOTE(review): argmax treats NaN as maximal, so NaN inputs win.
    return (probs.argmax(axis=0) + 1).astype(float)
    
def export(array, trans, proj, path='/tmp/plr_results.tif'):
    """Write a 2-D array to a single-band Byte GeoTIFF.

    :param array: 2-D array (rows, columns) to write
    :param trans: geotransform to attach to the dataset
    :param proj: projection string (main() passes tif.GetProjection())
    :param path: output file; defaults to the path main() re-imports
    :returns: the path written
    :raises IOError: if GDAL cannot create the file

    NOTE(review): the band type is Byte, so values outside 0..255 cannot be
    represented - confirm how GDAL converts them before using >255 classes.
    """
    driver = gdal.GetDriverByName('GTiff')
    out = driver.Create(path, array.shape[1], array.shape[0], 1, GDT_Byte)
    if out is None:
        raise IOError("could not create " + path)
    try:
        out.SetGeoTransform(trans)
        out.SetProjection(proj)
        gdal_array.BandWriteArray(out.GetRasterBand(1), array)
        out.FlushCache()
    finally:
        # dropping the last reference closes the dataset and completes the write
        out = None
    return path


def main():
    """Run probabilistic label relaxation on a classified image group.

    Steps:
    1. Split the signature file into one file per class.
    2. Run i.maxlik per class and export each reject map.
    3. Read the reject maps via GDAL into a (classes, rows, cols) array.
    4. Apply plr_filter ``iterations`` times and pick the most probable
       class per pixel.
    5. Import the label map into GRASS as ``output``.

    Temporary files and rasters are removed even when a step fails.
    """
    # fetch parameters; an unset optional parameter is the empty string
    group = options['group']
    subgroup = options['subgroup']
    sigfile = options['sigfile']
    output = options['output']
    iterations = int(options['iterations'] or 1)
    ntype = int(options['ntype'] or 4)

    # fetch Metadata
    getMetadata()
    sigpath = (GISDBASE + "/" + LOCATION_NAME + "/" + MAPSET + "/group/" + group
               + "/subgroup/" + subgroup + "/sig/")

    try:
        # split sigfiles
        counter = splitSignatures(sigpath, sigfile)
        print("found " + str(counter) + " signatures")
        if counter < 1:
            # without signatures there is nothing to classify, and the
            # image size/georeference below would never be set
            grass.fatal("no signatures found in " + sigpath + sigfile)

        layers = []
        for i in range(1, counter + 1):
            reject = "plr_rej_" + str(i)
            tif_path = "/tmp/" + reject + ".tif"
            # extract probabilities
            print("extracting probabilities for class " + str(i))
            grass.run_command("i.maxlik", group=group, subgroup=subgroup,
                              sigfile="plr_" + str(i) + ".sig",
                              clas="plr_class" + str(i), reject=reject)
            # export from GRASS
            print("exporting probabilities for class " + str(i) + " to /tmp")
            grass.run_command("r.out.gdal", inp=reject, out=tif_path)

            # import via gdal
            print("reading file")
            tif = gdal.Open(tif_path)
            if tif is None:
                grass.fatal("could not read " + tif_path)
            layers.append(tif.ReadAsArray())
            if i == 1:
                trans = tif.GetGeoTransform()
                proj = tif.GetProjection()
            tif = None  # close the GDAL dataset

        height, width = layers[0].shape[0], layers[0].shape[1]

        # create n-dimensional array
        print("creating 3D-array")
        probabilities = numpy.array(layers)

        print("Image size: " + str(width) + "x" + str(height))
        print("using " + str(ntype) + "-neighbourhood")
        # invoke relaxation process
        results = probabilities.copy()
        for i in range(iterations):
            print(str(iterations - i) + " iteration(s) to go...")
            results = plr_filter(results, width, height, counter, ntype)
        labels = createMap(results, width, height, counter)

        # exporting results
        print("exporting results to /tmp")
        export(labels, trans, proj)

        # import via gdal into GRASS
        print("reading results")
        grass.run_command("r.in.gdal", inp="/tmp/plr_results.tif", out=output)
    finally:
        # clean up
        print("removing temporary files")
        cleanUp(sigpath)
    
if __name__ == "__main__":
    # grass.parser() yields (options, flags); main() reads `options` as a
    # module-level global, so both names are bound here at module scope.
    parsed = grass.parser()
    options, flags = parsed
    main()
