Prim

📖
El algoritmo de Prim sirve para crear el arbol de expansión mínima (ARM) de un grafo. Esto quiere decir que intentará crear un subgrafo con las aristas de menor peso donde cada nodo se visita al menos una vez.
 
Si el grafo solo podrá realizar ARM de la componente conexa del nodo inicial
ℹ️
Pasos
Inicialización
  • Se establece un conjunto de visitados , un array de distancias y el array solución donde se incluirán la aristas que forman parte del ARM resultante (adicionalmente se podrá ir tomando nota de los pesos acumulados, pero eso se puede calcular luego)
  • Se selecciona un vértice inicial (el que se quiera)
Bucle iterativo
Se iterarán los siguientes puntos hasta que la función de selección de nodos devuelva un nodo inválido
  1. Visitar → Se ejecuta la función de selección de nodos y se marca el nodo como visitado
    1. Si el nodo no es válido → Salir del bucle iterativo
  1. Explorar vecinos → Se actualizarán las distancias de los vecinos del nodo seleccionado que no hayan sido visitados
 
Función de selección
Recorreremos el array de distancias, y anotaremos cual es el nodo, no visitado, cuya distancia sea menor en el array. Si todos los nodos ya han sido visitados, o los que quedan tiene distancia infinita (no pertenecen a la misma componente conexa) entonces esta función devolverá un nodo invñalido (None)
💡
Este algoritmo, crece a partir de un nodo actual. Se expandirá al nodo más cercano que no pertenezca todavía a la misma componente conexa. Es como si imaginamos un ejército que cada día conquista el pueblo más cercano a cualquiera de sus fronteras
##############################################
# Prim's algorithm for minimum spanning tree #
##############################################

def getBest(candidatesW, visited):
    """
    This function will return the vertex with the best weight and the weight itself.
    If there is no vertex to return, it will return None and 0 as the weight.
    """
    vertex = None
    weight = float("inf")
    n = len(visited)
    
    for node in range(n):
        # Ensures that the vertex is not visited (factible) and that the weight is less than the current best (greedy)
        if not visited[node] and candidatesW[node] < weight:
            vertex = node
            weight = candidatesW[node]
            
    if vertex is None:
        return None, 0
    else:
        return vertex, weight

def prim(g, initial = 1):
    """
    Given a graph, represented as a list of tuples, where each tuple is (weight, src, dest).
    This function will return the total weight of the minimum spanning tree and the ordered edges that are part of the minimum spanning tree.
    Non connected components and the initial node appear as None in the edges list.
    """
    cumWeight = 0
    n = len(g)
    visited = [False] * n
    candidatesWeight = [float('inf')] * n
    edges = [None] * n
    
    currentVertex = initial
    for _ in range(n - 1):
        # Mark the current vertex as visited
        visited[currentVertex] = True
        # Update the candidates best weight for each vertex besides the conex component
        for weight, src, dest in g[currentVertex]:
            if not visited[dest]:
                minWeight = min(weight, candidatesWeight[dest])
                candidatesWeight[dest] = minWeight
                edges[dest] = (minWeight, currentVertex, dest)
                
        # Get the best candidate (factible, no loops) and update the solution
        currentVertex, newWeight = getBest(candidatesWeight, visited)
        
        # Exit condition
        if currentVertex is None:  # No more candidates
            break
        else:  # Update the solution
            cumWeight += newWeight
        
    return cumWeight, edges
    



##################################################################
def formatGraph(g):
    newG = []
    for i in g:
        newG.append([])
        for src, dst, w in i:
            newG[-1].append((w, src, dst))
    
    return newG

### MAIN ###
if __name__ == "__main__":
    g = [
        [],
        [(1, 3, 1), (1, 4, 2), (1, 7, 6)],
        [(2, 5, 2), (2, 6, 4), (2, 7, 7)],
        [(3, 1, 1), (3, 4, 3), (3, 7, 5)],
        [(4, 1, 2), (4, 3, 3), (4, 5, 1), (4, 6, 9)],
        [(5, 2, 2), (5, 4, 1), (5, 7, 8)],
        [(6, 2, 4), (6, 4, 9)],
        [(7, 1, 6), (7, 2, 7), (7, 3, 5), (7, 5, 8)]
    ]   

    sol, edges = prim(formatGraph(g), 1)
    print(sol, edges)