Refer to the code below: it registers forward hooks on a (possibly wrapper-nested) LLaMA-style model so that rank 0 prints, for every hooked module, the input/output shapes and, for DTensor weights, the local vs. global weight shape, device mesh, and placements.
def register_dimension_hooks(model, rank):
    if rank != 0:
        return
    print('\n------------------- Model Structure -------------------')
    print("Model type:", type(model))

    # Get the actual model through the wrapper layers
    if hasattr(model, 'model'):
        model = model.model
    if hasattr(model, 'model'):
        model = model.model
    print("Base model type:", type(model))

    def make_hook(name, rank):
        def hook(module, input, output):
            print(f"\n--------------- Hook: {name} ---------------")
            if hasattr(module, 'weight'):
                weight = module.weight
                print(f"GPU {rank} - {name}:")
                print(f"Input shape: {input[0].shape}")
                if hasattr(weight, '_local_tensor'):
                    local_weight = weight._local_tensor
                    print(f"Local weight shape: {local_weight.shape}")
                    print(f"Global weight shape: {weight.shape}")
                if hasattr(weight, 'device_mesh'):
                    print(f"Device mesh: {weight.device_mesh}")
                    print(f"Placement: {weight.placements}")
            print(f"Output shape: {output.shape}")
            print("-" * 50)
        return hook

    # Register hooks for embedding layer
    if hasattr(model, 'embed_tokens'):
        print("Found embed_tokens")
        model.embed_tokens.register_forward_hook(make_hook('embed_tokens', rank))

    # Register hooks for all transformer layers
    if hasattr(model, 'layers'):
        for i, layer in enumerate(model.layers):
            # Attention blocks
            layer.self_attn.q_proj.register_forward_hook(
                make_hook(f'layer_{i}_q_proj', rank))
            layer.self_attn.k_proj.register_forward_hook(
                make_hook(f'layer_{i}_k_proj', rank))
            layer.self_attn.v_proj.register_forward_hook(
                make_hook(f'layer_{i}_v_proj', rank))
            layer.self_attn.o_proj.register_forward_hook(
                make_hook(f'layer_{i}_o_proj', rank))
            # MLP blocks
            layer.mlp.gate_proj.register_forward_hook(
                make_hook(f'layer_{i}_mlp_gate_proj', rank))
            layer.mlp.up_proj.register_forward_hook(
                make_hook(f'layer_{i}_mlp_up_proj', rank))
            layer.mlp.down_proj.register_forward_hook(
                make_hook(f'layer_{i}_mlp_down_proj', rank))
            # Layer norms
            layer.input_layernorm.register_forward_hook(
                make_hook(f'layer_{i}_input_layernorm', rank))
            layer.post_attention_layernorm.register_forward_hook(
                make_hook(f'layer_{i}_post_attention_layernorm', rank))

    # Register hook for final layer norm
    if hasattr(model, 'norm'):
        model.norm.register_forward_hook(make_hook('final_layernorm', rank))

    # Register hook for LM head
    if hasattr(model, 'lm_head'):
        print("Found lm_head")
        model.lm_head.register_forward_hook(make_hook('lm_head', rank))

    # Print model structure to debug
    print("\nModel attributes:", dir(model))
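As a rough usage sketch (the checkpoint name, process-group setup, and the step that wraps the model for FSDP/tensor parallelism are assumptions, not part of the original code), the hooks would be attached after wrapping and before a forward pass:

    # Minimal sketch, assuming torchrun with NCCL and a Hugging Face LLaMA-style model.
    import torch
    import torch.distributed as dist
    from transformers import AutoModelForCausalLM, AutoTokenizer

    dist.init_process_group("nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    model_name = "meta-llama/Llama-2-7b-hf"  # assumption: any LLaMA-style checkpoint
    model = AutoModelForCausalLM.from_pretrained(model_name).cuda()
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # ... wrap `model` with FSDP / tensor parallelism here so weights become DTensors ...

    register_dimension_hooks(model, rank)  # rank 0 attaches the printing hooks

    inputs = tokenizer("hello world", return_tensors="pt").to(rank)
    with torch.no_grad():
        model(**inputs)  # each hooked module prints its shapes during this forward pass

    dist.destroy_process_group()

Without the parallel wrapping step, the weights are plain tensors, so only the input/output and global weight shapes are printed; the `_local_tensor`, device mesh, and placement lines only appear once the parameters are DTensors.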
Thank you.