#!/usr/bin/python3
#
#       stlChecker.py
#
#       stlファイルをチェックし、不具合があれば修正する
#
#   20/12/22    stlのxyzサイズ、表面積の出力を追加
#               面のstlでは、エラーが発生していたので、修正
#   22/08/17    face向きのcheck方法を大幅修正。
#               faceの向き定義方法で隣接faceの向きをチェックする。
#   24/10/01    checkCorrectFaceDirection:"\"文字で警告発生のため、修正。
#

#from operator import truediv
import sys
import numpy as np

# Number of buckets per axis in the spatial grid used to look up coincident points.
nMaxMin: int = 20
# Relative merge tolerance. The absolute tolerance is
# (shortest edge of the triangle being processed) * tolerance; two vertices are
# merged when every coordinate differs by less than that value.
tolerance: float = 1.0e-6
# Becomes "yes" once any correction has been applied to the mesh.
correctFlag: str = "no"
# "allways": always write the output file; any other value: write only when corrected.
outputControl: str = "allways"


# ---- numpy 専用 -----------------------
def normal(vec):
    """Return `vec` scaled to unit Euclidean length.

    A zero-length vector cannot be normalised; a float zero vector of the
    same shape is returned instead, so the result is always a float array
    (previously the degenerate case returned an integer array).
    """
    vec = np.asarray(vec, dtype=float)
    l = np.linalg.norm(vec, ord=2)
    if l == 0.0:
        return np.zeros_like(vec)
    return vec / l

def vector(loc0, loc1):
    """Return the displacement vector pointing from loc0 to loc1."""
    displacement = loc1 - loc0
    return displacement

def length(a,b):
    """Return the Euclidean distance between points a and b."""
    return np.linalg.norm(a - b, ord=2)

def tri3length(tri):
    """Return the three edge lengths of triangle `tri` as an array.

    Order: |tri[0]-tri[1]|, |tri[1]-tri[2]|, |tri[2]-tri[0]|.
    """
    return np.array([length(tri[k], tri[(k + 1) % 3]) for k in range(3)])

def triArea(tri):
    """Return the area of triangle `tri` given as three 3-D points.

    The area is half the norm of the cross product of two edge vectors.
    Heron's formula (used previously) is unstable for needle-shaped
    triangles: the radicand s(s-a)(s-b)(s-c) can round to a small negative
    number, np.sqrt then returns nan, and the nan poisons the total area.
    """
    p1 = np.asarray(tri[0], dtype=float)
    edge1 = np.asarray(tri[1], dtype=float) - p1
    edge2 = np.asarray(tri[2], dtype=float) - p1
    return 0.5 * np.linalg.norm(np.cross(edge1, edge2))

def tet6length(tet):
    """Return the six edge lengths of tetrahedron `tet` (vertices 0..3).

    Order: base edges 0-1, 1-2, 2-0 followed by 0-3, 1-3, 2-3.
    """
    baseEdges = tri3length(tet[:3])
    apexEdges = [length(tet[k], tet[3]) for k in range(3)]
    return np.append(baseEdges, apexEdges)

def faceNormal(face):
    """Return the normal of triangle `face` = (a, b, c), normalised by normal().

    The normal is (b - a) x (c - b), so it follows the right-hand rule for
    the vertex order a -> b -> c; a degenerate triangle yields a zero vector.
    """
    a, b, c = face
    return normal(np.cross(vector(a, b), vector(b, c)))

# ----------------------------------------------


def readStlFile(fileName):
    """Read an ASCII STL file.

    Returns:
        tris: list of facets, each a list of vertex coordinate arrays
            (the "vertex" lines between "outer loop" and the next keyword).
        nVecs: list of facet normals [nx, ny, nz] as written in the file.
        xyzMinMax: bounding box [[xmin, ymin, zmin], [xmax, ymax, zmax]]
            enlarged by (extent / nMaxMin) on both sides of every axis; it is
            used to bucket points.
        xyzMinMaxTrue: the unpadded bounding box of all vertices.

    Raises:
        ValueError: the file contains no vertex lines.
    """
    with open(fileName) as f:
        lines = f.readlines()
    locs = []
    tris = []
    nVecs = []
    inLoop = False
    tri = []
    for line in lines:
        words = line.split()
        if not words:
            # blank lines (e.g. a trailing newline) carry no data
            continue
        key = words[0]
        if key == "facet":
            nVecs.append(list(map(float, words[2:5])))
        elif key == "outer":
            inLoop = True
            tri = []
        elif inLoop:
            if key == "vertex":
                loc = np.array(list(map(float, words[1:4])))
                tri.append(loc)
                locs.append(loc)
            else:
                # first non-vertex line ("endloop") closes the facet
                inLoop = False
                tris.append(tri)
    if not locs:
        raise ValueError("no vertices found in '" + fileName + "'")
    allLocs = np.array(locs)
    trueMin = allLocs.min(axis=0)
    trueMax = allLocs.max(axis=0)
    xyzMinMaxTrue = [list(trueMin), list(trueMax)]
    # pad symmetrically from the true extent so boundary points never land
    # exactly on the last bucket edge
    pad = (trueMax - trueMin) / nMaxMin
    xyzMinMax = [list(trueMin - pad), list(trueMax + pad)]
    return tris, nVecs, xyzMinMax, xyzMinMaxTrue


def getArrayID(p, minMax):
    """Return the bucket (xi, yi, zi) of point `p` in box `minMax`.

    minMax is [[xmin, ymin, zmin], [xmax, ymax, zmax]] and each axis is split
    into nMaxMin buckets. An axis with zero extent always maps to bucket 0.
    Indices are not clamped to the grid.
    """
    lo = minMax[0]
    hi = minMax[1]
    ids = []
    for axis in range(3):
        if hi[axis] == lo[axis]:
            ids.append(0)
            continue
        span = hi[axis] - lo[axis]
        ids.append(int((p[axis] - lo[axis]) / span * nMaxMin))
    return tuple(ids)


def isInSamePoint(points, p, tol, pointsArray, boxPoints):
    """Look for an already-registered point coinciding with `p`.

    Only the point numbers listed in `boxPoints` are examined, in order. A
    point matches when |dx|, |dy| and |dz| are all strictly below `tol`.
    `pointsArray` is accepted for interface compatibility but not used.

    Returns:
        (True, matchingPointNo) on the first match, otherwise
        (False, len(points)) -- the number a newly appended point would get.
    """
    for candidate in boxPoints:
        q = points[candidate]
        if (abs(q[0] - p[0]) < tol
                and abs(q[1] - p[1]) < tol
                and abs(q[2] - p[2]) < tol):
            return True, candidate
    return False, len(points)


def _candidatePointNos(p, tol, xyzMinMax, pointsArray):
    """Collect point numbers from every bucket overlapped by the cube [p-tol, p+tol].

    Any stored point within `tol` of `p` on every axis lies inside that cube,
    so it is found even when it was filed in a bucket next to the one that
    contains `p` itself.
    """
    p = np.asarray(p)
    lowIds = getArrayID(p - tol, xyzMinMax)
    highIds = getArrayID(p + tol, xyzMinMax)
    last = nMaxMin - 1
    xr, yr, zr = [range(max(lo, 0), min(hi, last) + 1)
                  for lo, hi in zip(lowIds, highIds)]
    candidates = []
    for xi in xr:
        for yi in yr:
            for zi in zr:
                candidates.extend(pointsArray[xi][yi][zi])
    return candidates


def getTriPoints(tris, xyzMinMax):
    """Merge coincident vertices and describe each triangle by point numbers.

    Args:
        tris: triangles as returned by readStlFile (lists of coordinate arrays).
        xyzMinMax: padded bounding box [[xmin, ymin, zmin], [xmax, ymax, zmax]]
            used to bucket vertices into an nMaxMin**3 grid.
    Returns:
        (points, triNos): unique vertex coordinates, and for every triangle
        the list of its three indices into `points`.

    Two vertices are merged when every coordinate differs by less than
    (shortest edge of the triangle being processed) * tolerance. All buckets
    the tolerance cube touches are searched, so coincident vertices that fall
    on either side of a bucket boundary are still merged.
    """
    # spatial grid: pointsArray[xi][yi][zi] holds indices into `points`
    pointsArray = [[[[] for _ in range(nMaxMin)] for _ in range(nMaxMin)]
                   for _ in range(nMaxMin)]
    points = []
    triNos = []
    for tri in tris:
        tol = min(tri3length(tri)) * tolerance
        triNo = []
        for p in tri[:3]:
            candidates = _candidatePointNos(p, tol, xyzMinMax, pointsArray)
            found, pNo = isInSamePoint(points, p, tol, pointsArray, candidates)
            if not found:
                points.append(p)
                xi, yi, zi = getArrayID(p, xyzMinMax)
                pointsArray[xi][yi][zi].append(pNo)
            triNo.append(pNo)
        triNos.append(triNo)
    return points, triNos


def getAreaAllTriangles(tris):
    """Return the total surface area: the sum of triArea() over all triangles (0.0 if none)."""
    return sum((triArea(tri) for tri in tris), 0.0)


def createStlFiles(outFileName, points, surfaces):
    """Write triangles to `outFileName` as an ASCII STL solid named "box".

    Args:
        outFileName: path of the file to create or overwrite.
        points: vertex coordinates.
        surfaces: triangles, each a sequence of indices into `points`.
    Each facet normal is recomputed from the vertex order via faceNormal().
    """
    lines = ["solid box\n"]
    for surface in surfaces:
        locs = [points[pNo] for pNo in surface]
        nVec = faceNormal(locs)
        lines.append(" facet normal " + " ".join(map(str, nVec)) + "\n")
        lines.append("   outer loop\n")
        for loc in locs:
            lines.append("     vertex " + " ".join(map(str, loc)) + "\n")
        lines.append("   endloop\n")
        lines.append(" endfacet\n")
    lines.append("endsolid box\n")
    # the context manager closes the file even if a write fails
    with open(outFileName, "w") as f:
        f.writelines(lines)


def deleteSameFaces(triNos):
    """Remove faces that repeat another face's set of point numbers.

    Faces are compared by their sorted point numbers, so a repeat with either
    the same or the reversed vertex order is dropped. The first occurrence
    (with its orientation) is kept, in order of first appearance.
    """
    firstFaces = {}
    for triNo in triNos:
        firstFaces.setdefault(tuple(sorted(triNo)), triNo)
    return list(firstFaces.values())

def deleteNoNeibourFaces(triNos):
    """Keep only faces whose three edges each appear at least twice in triNos.

    An edge is the sorted pair of its point numbers. Faces with an edge that
    no other face uses are dropped; the order of the kept faces is preserved.
    """
    def edgesOf(tri):
        pairs = ((tri[0], tri[1]), (tri[1], tri[2]), (tri[2], tri[0]))
        return [tuple(sorted(pair)) for pair in pairs]

    owners = {}
    for faceNo, tri in enumerate(triNos):
        for edge in edgesOf(tri):
            owners.setdefault(edge, []).append(faceNo)
    return [tri for tri in triNos
            if all(len(owners[edge]) >= 2 for edge in edgesOf(tri))]

def isClosedStl(triNos):
    """Return True when every edge of every face is used by exactly two faces.

    An edge is the sorted pair of its point numbers. An empty face list is
    reported as closed.
    """
    def edgesOf(tri):
        pairs = ((tri[0], tri[1]), (tri[1], tri[2]), (tri[2], tri[0]))
        return [tuple(sorted(pair)) for pair in pairs]

    useCount = {}
    for tri in triNos:
        for edge in edgesOf(tri):
            useCount[edge] = useCount.get(edge, 0) + 1
    return all(useCount[edge] == 2 for tri in triNos for edge in edgesOf(tri))


def checkCorrectFaceDirection(points, triNos):
    """ Check the orientation of the faces and flip faces so that faces
    connected through shared edges are oriented consistently.

              0
           // | \\        Triangles 0-1-2 and 0-2-3 have the same
        1 //  |  \\3      orientation when they traverse their shared
          \\  |  //       edge in opposite directions, e.g.
           \\ | //            triangle 0-1-2: edge 2-0
              2               triangle 0-2-3: edge 0-2

    Face 0 is the reference orientation. Repeated sweeps over the faces
    check (and flip if needed) the neighbours of already-checked faces.
    When a sweep checks nothing, the lowest-numbered unchecked face is
    accepted as-is and becomes a new reference. A face is flipped by
    swapping its first two point numbers.

    Args:
        points: vertex coordinates; passed through and returned unchanged
            (not read by this function).
        triNos: list of faces, each a list of three point numbers. Modified
            in place: flipped faces are replaced by new lists.
    Returns:
        (correctedFaces, points, triNos) where correctedFaces is the number
        of flips performed.
    """
    
    def getFacesHaveEdge(facei, edge, faces, triNos):
        """ Return the faces in `faces`, other than facei, that contain both
        end points of `edge` (may contain duplicates if `faces` does)."""
        neibFaces = []
        for ii in faces:
            Nos = triNos[ii]
            if edge[0] in Nos and edge[1] in Nos:
                if ii != facei:
                    neibFaces.append(ii)
        return neibFaces
    
    def createNeighbourFacesDict(points, triNos):
        """ Build neibFacesDict[facei] = list of faces sharing an edge with facei."""
        #  nodeDict[pointNo] = [faces that use this point]
        nodeDict = {}
        facei = 0
        for tri in triNos:
            for nNo in tri:
                if nNo in nodeDict.keys():
                    nodeDict[nNo].append(facei)
                else:
                    nodeDict[nNo] = [facei]
            facei += 1
        #  neibFacesDict[facei] = [faces adjacent to facei through an edge]
        neibFacesDict = {}
        facei = 0
        for tri in triNos:
            # candidates: every face touching one of this face's points
            faces = []
            for nNo in tri:
                faceList = nodeDict[nNo]
                faces += faceList
            # faces on edge 0 (tri[0]-tri[1])
            edge = [tri[0], tri[1]]
            neibFaces = getFacesHaveEdge(facei, edge, faces, triNos)
            # faces on edge 1 (tri[1]-tri[2])
            edge = [tri[1], tri[2]]
            neibFaces += getFacesHaveEdge(facei, edge, faces, triNos)
            # faces on edge 2 (tri[2]-tri[0])
            edge = [tri[2], tri[0]]
            neibFaces += getFacesHaveEdge(facei, edge, faces, triNos)
            # de-duplicated neighbours of facei
            neibFaces = list(set(neibFaces))
            neibFacesDict[facei] = neibFaces
            facei += 1
        return neibFacesDict

    def getEdgeDirection(pointNos, edge):
        """ Return `edge` as [from, to] in the cyclic vertex order of pointNos."""
        idx0 = pointNos.index(edge[0])
        idx1 = pointNos.index(edge[1])
        if abs(idx0-idx1) == 1:
            # positions 0-1 or 1-2: traversed from the lower to the higher position
            if idx0 < idx1:
                edgeDir = [edge[0], edge[1]]
            else:
                edgeDir = [edge[1], edge[0]]
        else:
            # positions 0 and 2: the closing edge runs from position 2 to position 0
            if idx0 < idx1:
                edgeDir = [edge[1], edge[0]]
            else:
                edgeDir = [edge[0], edge[1]]
        return edgeDir

    def isSameDirection(baseFace, neibFace, points, triNos):
        """ Return True if neibFace is oriented consistently with baseFace,
        i.e. the two faces traverse their shared edge in opposite directions."""
        basePointNos = triNos[baseFace]
        neibPointNos = triNos[neibFace]
        # NOTE(review): assumes the faces share exactly two points; with a
        # single shared point getEdgeDirection would raise IndexError.
        edge = list(set(basePointNos) & set(neibPointNos))
        edgeDirBase = getEdgeDirection(basePointNos, edge)
        edgeDirNeib = getEdgeDirection(neibPointNos, edge)
        if edgeDirBase == edgeDirNeib:
            ans = False
        else:
            ans = True
        return ans

    def correctFaceDirection(neibFace, triNos):
        """ Flip neibFace by swapping its first two point numbers."""
        pointNos = triNos[neibFace]
        newPointNos = [pointNos[1], pointNos[0], pointNos[2]]
        triNos[neibFace] = newPointNos
        return triNos

    def checkNeighbourFacesDir(baseFace, neibFaces, checkedFaces,
                               points, triNos):
        """ Check every unchecked face in neibFaces against baseFace, flip it
        if its orientation disagrees, and mark it as checked.

        Returns (checkedFlag, count, checkedFaces, triNos): checkedFlag is 1
        if at least one face was checked, count is the number of flips."""
        checkedFlag = 0
        correctFlag = 0   # unused local; does not touch the module-level correctFlag
        count = 0
        for neibFace in neibFaces:
            if checkedFaces[neibFace] == 0:
                checkedFlag = 1
                checkedFaces[neibFace] = 1
                if isSameDirection(baseFace, neibFace, points, triNos) == False:
                    triNos = correctFaceDirection(neibFace, triNos)
                    count += 1
        return checkedFlag, count, checkedFaces, triNos

    # map each face to its edge-adjacent faces
    neibFacesDict =createNeighbourFacesDict(points, triNos)
    # per-face "checked" flags; face 0 is the reference orientation
    # NOTE(review): raises IndexError if triNos is empty
    checkedFaces = [ 0 for i in range(len(triNos))]
    checkedFaces[0] = 1
    correctedFaces = 0
    loop = True
    while loop:
        checkedFlag = 0
        for facei in range(len(triNos)):
            # face not checked yet?
            if checkedFaces[facei] == 0:
                # does it have an already-checked neighbour?
                baseFace = -1
                faces = neibFacesDict[facei]
                for faceNo in faces:
                    if checkedFaces[faceNo] == 1:
                        baseFace = faceNo
                        break
                if baseFace >= 0:
                    # check (and fix) all faces adjacent to baseFace
                    neibFaces = neibFacesDict[baseFace]
                    (checkedFlag,
                     count,
                     checkedFaces,
                     triNos) = checkNeighbourFacesDir(
                                        baseFace,
                                        neibFaces, checkedFaces,
                                        points, triNos)
                    # accumulate the number of flipped faces
                    correctedFaces += count
        # was any face checked during this sweep?
        if checkedFlag == 0:
            # no progress: is there still an unchecked face?
            flag = 0
            for facei in range(len(triNos)):
                if checkedFaces[facei] == 0:
                    # accept this face's orientation as a new reference
                    checkedFaces[facei] = 1
                    flag = 1
                    break
            if flag == 0:
                # every face has been checked -> finish
                loop = False
    return correctedFaces, points, triNos   


def stlChecker(fileName, outFileName):
    """Check an ASCII STL file, repair what can be repaired, and report.

    Steps:
      1. Read the file and print the triangle and point counts, the
         bounding-box size and the total surface area.
      2. Delete duplicated faces.
      3. Make the face orientations consistent.
      4. Check whether the surface is closed.
    The result is written to `outFileName` when something was corrected or
    when outputControl is "allways". Only a closed, uncorrected file is
    reported as good.

    Side effects: prints to stdout, sets the module-level correctFlag and may
    write `outFileName`.
    """
    global correctFlag
    # start clean so a previous call in this process cannot leave "yes" behind
    correctFlag = "no"

    # read the file
    print()
    print("reading '" + fileName + "' file...")
    tris, nVecs, xyzMinMax, xyzMinMaxTrue = readStlFile(fileName)
    print("  number of triangles")
    print("    " + str(len(tris)) + " triangles.")

    # merge coincident vertices -> point list + per-triangle point numbers
    print("  number of points")
    points, triNos = getTriPoints(tris, xyzMinMax)
    print("    " + str(len(points)) + " points.")
    # bounding-box size
    print("  stl size")
    mins = xyzMinMaxTrue[0]
    maxs = xyzMinMaxTrue[1]
    sizes = [maxs[k] - mins[k] for k in range(3)]
    print("    xyz: " + " ".join(map(str, sizes)))
    # total surface area
    print("  area of all triangles")
    area = getAreaAllTriangles(tris)
    print("    area: " + str(area))

    # remove faces defined more than once
    print("deleting same faces...")
    n = len(triNos)
    triNos = deleteSameFaces(triNos)
    print("  deleted " + str(n-len(triNos)) + " faces.")
    if n != len(triNos):
        correctFlag = "yes"

    # make face orientations consistent
    print("checking face direction...")
    correctedFaces, points, triNos = checkCorrectFaceDirection(points, triNos)
    print("  changed " + str(correctedFaces) + " faces directions.")
    if correctedFaces != 0:
        correctFlag = "yes"

    # closed surface?
    print("checking stl, closed or not...")
    closed = isClosedStl(triNos)
    if not closed:
        print("  stlFile is not closed!!")
    elif len(triNos) == 0:
        print("  no triangles, after deleted faces")
    else:
        print("  stl file is closed.")

    # write the result
    if correctFlag == "yes" or outputControl == "allways":
        print("creating stl file...")
        createStlFiles(outFileName, points, triNos)
        print("--> checked stl file was saved at '" + outFileName + "'.")
    # an open surface is never reported as good, even if nothing was corrected
    if correctFlag == "no" and closed:
        print("--> '" + fileName + "' is good stlFile!!")

def printHelp():
    """Print the usage text.

    The text documents the real -oc values (allways/onlyCorrect), the
    default output behaviour, and the tolerance basis (the shortest edge of
    each triangle, as used by getTriPoints).
    """
    cont = """
----------- stlChecker.py --------------------------------------------
stlファイルをチェックする。
チェックに伴い、triangleの数、節点の数、stlのxyzサイズ、stlの表面積
を出力する。
stlのチェック内容
・同じfaceが裏表で重複していないか
  重複している場合、片側を削除
・faceの向きが統一されているか
  faceの向きを統一する。（最初のfaceの向きに合わせる）
・閉じたstlかどうか
stlFileを修正した場合は、outputFileに修正結果を保存する。
修正する必要がない場合は、-oc が allways（デフォルト）の時のみ出力する。

＜使用方法＞
stlChecker.py -i <stlFile> [-o <outputFile>] [-t <tolerance>] [-oc <allways/onlyCorrect>]
    -i          入力ファイル名(stlFile)
    -input        ↑
    -o          出力ファイル名。
                省略時は、「入力ファイル名_out.stl」に設定される
    -output       ↑
    -oc         outputの制御(<allways/onlyCorrect>)
                省略時は「allways」（常に出力する。）
                allways以外は、修正した場合のみ出力する。
    -outputControl   ↑
    -t          samePointsのtolerance（デフォルト:1e-6）
                判定値は、(三角形の最短辺の長さ*tolerance)
    -tolerance    ↑
    -h          helpを表示
    -help         ↑
"""
    print(cont)

def getOption(args, fileName, outFileName):
    """Parse the command-line arguments.

    Args:
        args: full argument vector; args[0] is the program name.
        fileName, outFileName: defaults used when the matching option is absent.
    Returns:
        (fileName, outFileName). fileName is "" when help was shown (-h, an
        unknown option, an option missing its value, or no input file),
        which tells the caller there is nothing to do. ".stl" is appended to
        names whose extension is not "stl" (compared case-insensitively).
        outFileName defaults to "<input name without extension>_out.stl".
    Side effects: may set the module-level tolerance and outputControl.
    """
    global tolerance, outputControl
    valueOptions = ("-i", "-input", "-o", "-output",
                    "-oc", "-outputControl", "-t", "-tolerance")
    helpShown = False
    i = 1
    while i < len(args):
        opt = args[i]
        # -h/-help, an unknown option, or an option given without its value
        if opt not in valueOptions or i + 1 >= len(args):
            printHelp()
            helpShown = True
            fileName = ""
            break
        value = args[i + 1]
        if opt in ("-i", "-input"):
            fileName = value
        elif opt in ("-o", "-output"):
            outFileName = value
        elif opt in ("-oc", "-outputControl"):
            outputControl = value
        else:
            tolerance = float(value)
        i += 2
    if fileName == "" and not helpShown:
        # no arguments / no input file: show usage instead of crashing
        printHelp()
    if fileName != "" and fileName.split(".")[-1].lower() != "stl":
        fileName += ".stl"
    if outFileName == "":
        outFileName = ".".join(fileName.split(".")[:-1]) + "_out.stl"
    elif outFileName.split(".")[-1].lower() != "stl":
        outFileName += ".stl"
    return fileName, outFileName


if __name__ == "__main__":
    # Parse the command line; an empty input name means there is nothing to
    # check (help was shown or no input file was given).
    fileName, outFileName = getOption(sys.argv, "", "")
    if fileName:
        stlChecker(fileName, outFileName)
