[Bug] Server fails to launch with multi-LoRA + NGRAM Speculative Decoding on 0.5.4.post3 #12726

@vedantjh2

Description


Checklist

  • 1. I have searched related issues but cannot get the expected help.
  • 2. The bug has not been fixed in the latest version.
  • 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
  • 4. If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/sgl-project/sglang/discussions/new/choose. Otherwise, it will be closed.
  • 5. Please use English, otherwise it will be closed.

Describe the bug

issue:

The server launches correctly on version 0.5.4.post1 but fails to launch on the latest 0.5.4.post3.

On 0.5.4.post1, the server launches successfully with both LoRAs (fr-lora, de-lora) and NGRAM speculative decoding: initialization completes, the LoRA weights load correctly using the Triton backend, CUDA graphs capture without issue, and the HTTP endpoints respond normally.

On 0.5.4.post3, the server fails to start: it crashes during CUDA graph capture with the CSGMV LoRA backend, which this version selects by default (0.5.4.post1 selected the Triton backend with the same command). The crash occurs immediately after the LoRA weights load, inside the prepare_lora_batch() call during the CUDA graph capture stage.
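
For reference, the TypeError in the trace below is the generic error PyTorch raises when torch.tensor() builds an integer tensor from a Python sequence that still contains None. The following is a minimal sketch of that error class, not the actual SGLang code path; the list contents and the idea that the dummy capture batch leaves a LoRA weight index as None are assumptions for illustration only.

import torch

# Minimal sketch, NOT the actual chunked_backend.py logic.
# Assumption for illustration: the dummy batch used for CUDA graph capture
# leaves one LoRA weight index as None, and that None reaches torch.tensor()
# with an integer dtype, producing the exact TypeError seen in the trace.
weight_indices = [0, None]  # hypothetical contents of the index list

try:
    torch.tensor(weight_indices, dtype=torch.int64)
except TypeError as e:
    print(e)  # 'NoneType' object cannot be interpreted as an integer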

trace:

[2025-11-06 01:18:29] Scheduler hit an exception: Traceback (most recent call last):
  File "/home/jobuser/tlatest/lib/python3.10/site-packages/sglang/srt/managers/scheduler.py", line 2791, in run_scheduler_process
    scheduler = Scheduler(
  File "/home/jobuser/tlatest/lib/python3.10/site-packages/sglang/srt/managers/scheduler.py", line 319, in __init__
    self.tp_worker = TpModelWorker(
  File "/home/jobuser/tlatest/lib/python3.10/site-packages/sglang/srt/managers/tp_worker.py", line 237, in __init__
    self._model_runner = ModelRunner(
  File "/home/jobuser/tlatest/lib/python3.10/site-packages/sglang/srt/model_executor/model_runner.py", line 322, in __init__
    self.initialize(min_per_gpu_memory)
  File "/home/jobuser/tlatest/lib/python3.10/site-packages/sglang/srt/model_executor/model_runner.py", line 479, in initialize
    self.init_device_graphs()
  File "/home/jobuser/tlatest/lib/python3.10/site-packages/sglang/srt/model_executor/model_runner.py", line 1995, in init_device_graphs
    self.graph_runner = graph_runners[self.device](self)
  File "/home/jobuser/tlatest/lib/python3.10/site-packages/sglang/srt/model_executor/cuda_graph_runner.py", line 381, in __init__
    self.capture()
  File "/home/jobuser/tlatest/lib/python3.10/site-packages/sglang/srt/model_executor/cuda_graph_runner.py", line 500, in capture
    ) = self.capture_one_batch_size(bs, forward)
  File "/home/jobuser/tlatest/lib/python3.10/site-packages/sglang/srt/model_executor/cuda_graph_runner.py", line 646, in capture_one_batch_size
    self.model_runner.lora_manager.prepare_lora_batch(forward_batch)
  File "/home/jobuser/tlatest/lib/python3.10/site-packages/sglang/srt/lora/lora_manager.py", line 287, in prepare_lora_batch
    self.lora_backend.prepare_lora_batch(
  File "/home/jobuser/tlatest/lib/python3.10/site-packages/sglang/srt/lora/backend/chunked_backend.py", line 173, in prepare_lora_batch
    permutation, weight_indices_reordered = ChunkedSgmvLoRABackend._get_permutation(
  File "/home/jobuser/tlatest/lib/python3.10/site-packages/sglang/srt/lora/backend/chunked_backend.py", line 266, in _get_permutation
    torch.tensor(
  File "/home/jobuser/tlatest/lib/python3.10/site-packages/torch/utils/_device.py", line 103, in __torch_function__
    return func(*args, **kwargs)
TypeError: 'NoneType' object cannot be interpreted as an integer

[2025-11-06 01:18:29] Received sigquit from a child process. It usually means the child failed.
Killed

Reproduction

python -m sglang.launch_server --model-path /shared/public/elr-models/Qwen/Qwen3-4B/1cfa9a7208912126459214e8b04321603b3df60c/ --enable-lora --lora-paths fr-lora=/shared/public/sharing/candidate-evaluation-llm/i18n/experiments/peft/fe27f24518f194fb2b21/2025-11-04T06-45-20/QWEN3_4B_lora_fr de-lora=/shared/public/sharing/candidate-evaluation-llm/i18n/experiments/peft/fe27f24518f194fb2b21/2025-11-04T07-25-08/QWEN3_4B_lora_de --max-loras-per-batch 4 --speculative-algorithm NGRAM --speculative-num-draft-tokens 5 --speculative-ngram-min-match-window-size 2 --speculative-ngram-max-match-window-size 15
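
Note that this same command selects lora_backend='triton' on 0.5.4.post1 but lora_backend='csgmv' on 0.5.4.post3 (see the server args in the logs below). As an untested workaround, explicitly pinning the Triton backend may sidestep the crash, assuming the --lora-backend flag is still accepted on post3:

python -m sglang.launch_server <same arguments as above> --lora-backend triton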

post1:

(tsgl) jobuser [ ~ ]$ python -m sglang.launch_server   --model-path /shared/public/elr-models/Qwen/Qwen3-4B/   --enable-lora   --lora-paths       fr-lora=/shared/public/sharing/fr_path       de-lora=/shared/public/sharing/de_path   --max-loras-per-batch 4   --speculative-algorithm NGRAM   --speculative-num-draft-tokens 5   --speculative-ngram-min-match-window-size 2   --speculative-ngram-max-match-window-size 15
[2025-11-06 01:22:00] WARNING server_args.py:1104: Attention backend not explicitly specified. Use fa3 backend by default.
[2025-11-06 01:22:00] WARNING server_args.py:1518: Max running requests is reset to 48 for speculative decoding. You can override this by explicitly setting --max-running-requests.
[2025-11-06 01:22:00] WARNING server_args.py:1529: The overlap scheduler and mixed chunked prefill are disabled because of using ngram speculative decoding.
[2025-11-06 01:22:00] INFO trace.py:48: opentelemetry package is not installed, tracing disabled
[2025-11-06 01:22:00] server_args=ServerArgs(model_path='/shared/public/elr-models/Qwen/Qwen3-4B/1cfa9a7208912126459214e8b04321603b3df60c/', tokenizer_path='/shared/public/elr-models/Qwen/Qwen3-4B/1cfa9a7208912126459214e8b04321603b3df60c/', tokenizer_mode='auto', tokenizer_worker_num=1, skip_tokenizer_init=False, load_format='auto', model_loader_extra_config='{}', trust_remote_code=False, context_length=None, is_embedding=False, enable_multimodal=None, revision=None, model_impl='auto', host='127.0.0.1', port=30000, grpc_mode=False, skip_server_warmup=False, warmups=None, nccl_port=None, checkpoint_engine_wait_weights_before_ready=False, dtype='auto', quantization=None, quantization_param_path=None, kv_cache_dtype='auto', enable_fp32_lm_head=False, modelopt_quant=None, modelopt_checkpoint_restore_path=None, modelopt_checkpoint_save_path=None, modelopt_export_path=None, quantize_and_serve=False, mem_fraction_static=0.835, max_running_requests=48, max_queued_requests=None, max_total_tokens=None, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', enable_priority_scheduling=False, abort_on_priority_when_disabled=False, schedule_low_priority_values_first=False, priority_scheduling_preemption_threshold=10, schedule_conservativeness=1.0, page_size=1, hybrid_kvcache_ratio=None, swa_full_tokens_ratio=0.8, disable_hybrid_swa_memory=False, radix_eviction_policy='lru', device='cuda', tp_size=1, pp_size=1, pp_max_micro_batch_size=None, stream_interval=1, stream_output=False, random_seed=207795411, constrained_json_whitespace_pattern=None, constrained_json_disable_any_whitespace=False, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, sleep_on_idle=False, log_level='info', log_level_http=None, log_requests=False, log_requests_level=2, crash_dump_folder=None, show_time_cost=False, enable_metrics=False, enable_metrics_for_all_schedulers=False, tokenizer_metrics_custom_labels_header='x-custom-labels', tokenizer_metrics_allowed_custom_labels=None, bucket_time_to_first_token=None, bucket_inter_token_latency=None, bucket_e2e_request_latency=None, collect_tokens_histogram=False, prompt_tokens_buckets=None, generation_tokens_buckets=None, gc_warning_threshold_secs=0.0, decode_log_interval=40, enable_request_time_stats_logging=False, kv_events_config=None, enable_trace=False, oltp_traces_endpoint='localhost:4317', api_key=None, served_model_name='/shared/public/elr-models/Qwen/Qwen3-4B/1cfa9a7208912126459214e8b04321603b3df60c/', weight_version='default', chat_template=None, completion_template=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, tool_call_parser=None, tool_server=None, sampling_defaults='model', dp_size=1, load_balance_method='round_robin', load_watch_interval=0.1, prefill_round_robin_balance=False, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', preferred_sampling_params=None, enable_lora=True, max_lora_rank=None, lora_target_modules=None, lora_paths=[LoRARef(lora_id='0fa260e34ecd45e1b9c499b1acd5fb5c', lora_name='fr-lora', lora_path='/shared/public/sharing/candidate-evaluation-llm/i18n/experiments/peft/fe27f24518f194fb2b21/2025-11-04T06-45-20/QWEN3_4B_lora_fr', pinned=False), LoRARef(lora_id='c3799fe89a5647fb99006933fc94ff7a', lora_name='de-lora', lora_path='/shared/public/sharing/candidate-evaluation-llm/i18n/experiments/peft/fe27f24518f194fb2b21/2025-11-04T07-25-08/QWEN3_4B_lora_de', pinned=False)], max_loaded_loras=None, max_loras_per_batch=4, 
lora_eviction_policy='lru', lora_backend='triton', max_lora_chunk_size=16, attention_backend='fa3', decode_attention_backend=None, prefill_attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', mm_attention_backend=None, nsa_prefill_backend='flashmla_sparse', nsa_decode_backend='fa3', speculative_algorithm='NGRAM', speculative_draft_model_path=None, speculative_draft_model_revision=None, speculative_draft_load_format=None, speculative_num_steps=None, speculative_eagle_topk=10, speculative_num_draft_tokens=5, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, speculative_attention_mode='prefill', speculative_ngram_min_match_window_size=2, speculative_ngram_max_match_window_size=15, speculative_ngram_min_bfs_breadth=1, speculative_ngram_max_bfs_breadth=10, speculative_ngram_match_type='BFS', speculative_ngram_branch_length=18, speculative_ngram_capacity=10000000, ep_size=1, moe_a2a_backend='none', moe_runner_backend='auto', flashinfer_mxfp4_moe_precision='default', enable_flashinfer_allreduce_fusion=False, deepep_mode='auto', ep_num_redundant_experts=0, ep_dispatch_algorithm='static', init_expert_location='trivial', enable_eplb=False, eplb_algorithm='auto', eplb_rebalance_num_iterations=1000, eplb_rebalance_layers_per_chunk=None, eplb_min_rebalancing_utilization_threshold=1.0, expert_distribution_recorder_mode=None, expert_distribution_recorder_buffer_size=1000, enable_expert_distribution_metrics=False, deepep_config=None, moe_dense_tp_size=None, elastic_ep_backend=None, mooncake_ib_device=None, max_mamba_cache_size=None, mamba_ssm_dtype='float32', mamba_full_memory_ratio=0.9, enable_hierarchical_cache=False, hicache_ratio=2.0, hicache_size=0, hicache_write_policy='write_through', hicache_io_backend='kernel', hicache_mem_layout='layer_first', hicache_storage_backend=None, hicache_storage_prefetch_policy='best_effort', hicache_storage_backend_extra_config=None, enable_lmcache=False, kt_amx_weight_path=None, kt_amx_method='AMXINT4', kt_cpuinfer=None, kt_threadpool_count=2, kt_num_gpu_experts=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, cpu_offload_gb=0, offload_group_size=-1, offload_num_in_group=1, offload_prefetch_step=1, offload_mode='cpu', multi_item_scoring_delimiter=None, disable_radix_cache=False, cuda_graph_max_bs=256, cuda_graph_bs=[1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 40, 44, 48, 52, 56, 60, 72, 80, 88, 96, 104, 112, 120, 128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248, 256], disable_cuda_graph=False, disable_cuda_graph_padding=False, enable_profile_cuda_graph=False, enable_cudagraph_gc=False, enable_nccl_nvls=False, enable_symm_mem=False, disable_flashinfer_cutlass_moe_fp4_allgather=False, enable_tokenizer_batch_encode=False, disable_tokenizer_batch_decode=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, enable_mscclpp=False, enable_torch_symm_mem=False, disable_overlap_schedule=True, enable_mixed_chunk=False, enable_dp_attention=False, enable_dp_lm_head=False, enable_two_batch_overlap=False, enable_single_batch_overlap=False, tbo_token_distribution_threshold=0.48, enable_torch_compile=False, enable_piecewise_cuda_graph=False, torch_compile_max_bs=32, piecewise_cuda_graph_max_tokens=4096, piecewise_cuda_graph_tokens=[4, 8, 12, 16, 20, 24, 28, 32, 48, 64, 80, 96, 112, 128, 144, 160, 
176, 192, 208, 224, 240, 256, 288, 320, 352, 384, 416, 448, 480, 512, 640, 768, 896, 1024, 1152, 1280, 1408, 1536, 1664, 1792, 1920, 2048, 2176, 2304, 2432, 2560, 2688, 2816, 2944, 3072, 3200, 3328, 3456, 3584, 3712, 3840, 3968, 4096], piecewise_cuda_graph_compiler='eager', torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, triton_attention_split_tile_size=None, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, enable_weights_cpu_backup=False, allow_auto_truncate=False, enable_custom_logit_processor=False, flashinfer_mla_disable_ragged=False, disable_shared_experts_fusion=False, disable_chunked_prefix_cache=False, disable_fast_image_processor=False, keep_mm_feature_on_device=False, enable_return_hidden_states=False, scheduler_recv_interval=1, numa_node=None, enable_deterministic_inference=False, rl_on_policy_target=None, enable_dynamic_batch_tokenizer=False, dynamic_batch_tokenizer_batch_size=32, dynamic_batch_tokenizer_batch_timeout=0.002, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, disaggregation_mode='null', disaggregation_transfer_backend='mooncake', disaggregation_bootstrap_port=8998, disaggregation_decode_tp=None, disaggregation_decode_dp=None, disaggregation_prefill_pp=1, disaggregation_ib_device=None, disaggregation_decode_enable_offload_kvcache=False, num_reserved_decode_tokens=512, disaggregation_decode_polling_interval=1, custom_weight_loader=[], weight_loader_disable_mmap=False, remote_instance_weight_loader_seed_instance_ip=None, remote_instance_weight_loader_seed_instance_service_port=None, remote_instance_weight_loader_send_weights_group_ports=None, enable_pdmux=False, pdmux_config_path=None, sm_group_num=8)
[2025-11-06 01:22:00] Using default HuggingFace chat template with detected content format: string
[2025-11-06 01:22:06] INFO trace.py:48: opentelemetry package is not installed, tracing disabled
[2025-11-06 01:22:06] INFO trace.py:48: opentelemetry package is not installed, tracing disabled
[2025-11-06 01:22:07] Init torch distributed begin.
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[2025-11-06 01:22:07] Init torch distributed ends. mem usage=0.00 GB
[2025-11-06 01:22:07] MOE_RUNNER_BACKEND is not initialized, the backend will be automatically selected
[2025-11-06 01:22:08] Load weight begin. avail mem=78.58 GB
Loading safetensors checkpoint shards:   0% Completed | 0/3 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  33% Completed | 1/3 [00:00<00:00,  2.30it/s]
Loading safetensors checkpoint shards:  67% Completed | 2/3 [00:00<00:00,  3.80it/s]
Loading safetensors checkpoint shards: 100% Completed | 3/3 [00:01<00:00,  2.90it/s]
Loading safetensors checkpoint shards: 100% Completed | 3/3 [00:01<00:00,  2.94it/s]

[2025-11-06 01:22:10] Load weight end. type=Qwen3ForCausalLM, dtype=torch.bfloat16, avail mem=70.86 GB, mem usage=7.72 GB.
[2025-11-06 01:22:10] Using triton as backend of LoRA kernels.
Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 125.17it/s]

Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 138.44it/s]

[2025-11-06 01:22:10] Using KV cache dtype: torch.bfloat16
[2025-11-06 01:22:10] KV Cache is allocated. #tokens: 421472, K size: 28.94 GB, V size: 28.94 GB
[2025-11-06 01:22:10] Memory pool end. avail mem=12.89 GB
[2025-11-06 01:22:10] Capture cuda graph begin. This can take up to several minutes. avail mem=12.79 GB
[2025-11-06 01:22:10] Capture cuda graph bs [1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 40, 44, 48]
Capturing batches (bs=1 avail_mem=12.19 GB): 100%|█████████████████████████████████████████████| 23/23 [00:02<00:00,  9.88it/s]
[2025-11-06 01:22:13] Capture cuda graph end. Time elapsed: 2.89 s. mem usage=0.61 GB. avail mem=12.18 GB.
[2025-11-06 01:22:14] max_total_num_tokens=421472, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=48, context_len=40960, available_gpu_mem=12.18 GB
[2025-11-06 01:22:14] INFO:     Started server process [191490]
[2025-11-06 01:22:14] INFO:     Waiting for application startup.
[2025-11-06 01:22:14] Using default chat sampling params from model generation config: {'repetition_penalty': 1.0, 'temperature': 0.6, 'top_k': 20, 'top_p': 0.95}
[2025-11-06 01:22:14] Using default chat sampling params from model generation config: {'repetition_penalty': 1.0, 'temperature': 0.6, 'top_k': 20, 'top_p': 0.95}
[2025-11-06 01:22:14] INFO:     Application startup complete.
[2025-11-06 01:22:14] INFO:     Uvicorn running on http://127.0.0.1:30000 (Press CTRL+C to quit)
[2025-11-06 01:22:15] INFO:     127.0.0.1:56380 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-11-06 01:22:15] Prefill batch, #new-seq: 1, #new-token: 6, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0, 
[2025-11-06 01:22:15] INFO:     127.0.0.1:56388 - "POST /generate HTTP/1.1" 200 OK
[2025-11-06 01:22:15] The server is fired up and ready to roll!

post3:

(tlatest) jobuser [ ~ ]$ python -m sglang.launch_server   --model-path /shared/public/elr-models/Qwen/Qwen3-4B/   --enable-lora   --lora-paths       fr-lora=/shared/public/sharing/fr_path       de-lora=/shared/public/sharing/de_path   --max-loras-per-batch 4   --speculative-algorithm NGRAM   --speculative-num-draft-tokens 5   --speculative-ngram-min-match-window-size 2   --speculative-ngram-max-match-window-size 15
[2025-11-06 01:18:19] WARNING server_args.py:1165: Attention backend not explicitly specified. Use fa3 backend by default.
[2025-11-06 01:18:19] WARNING server_args.py:1594: Max running requests is reset to 48 for speculative decoding. You can override this by explicitly setting --max-running-requests.
[2025-11-06 01:18:19] WARNING server_args.py:1605: The overlap scheduler and mixed chunked prefill are disabled because of using ngram speculative decoding.
[2025-11-06 01:18:19] INFO trace.py:52: opentelemetry package is not installed, tracing disabled
[2025-11-06 01:18:19] server_args=ServerArgs(model_path='/shared/public/elr-models/Qwen/Qwen3-4B/1cfa9a7208912126459214e8b04321603b3df60c/', tokenizer_path='/shared/public/elr-models/Qwen/Qwen3-4B/1cfa9a7208912126459214e8b04321603b3df60c/', tokenizer_mode='auto', tokenizer_worker_num=1, skip_tokenizer_init=False, load_format='auto', model_loader_extra_config='{}', trust_remote_code=False, context_length=None, is_embedding=False, enable_multimodal=None, revision=None, model_impl='auto', host='127.0.0.1', port=30000, grpc_mode=False, skip_server_warmup=False, warmups=None, nccl_port=None, checkpoint_engine_wait_weights_before_ready=False, dtype='auto', quantization=None, quantization_param_path=None, kv_cache_dtype='auto', enable_fp32_lm_head=False, modelopt_quant=None, modelopt_checkpoint_restore_path=None, modelopt_checkpoint_save_path=None, modelopt_export_path=None, quantize_and_serve=False, mem_fraction_static=0.835, max_running_requests=48, max_queued_requests=None, max_total_tokens=None, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', enable_priority_scheduling=False, abort_on_priority_when_disabled=False, schedule_low_priority_values_first=False, priority_scheduling_preemption_threshold=10, schedule_conservativeness=1.0, page_size=1, hybrid_kvcache_ratio=None, swa_full_tokens_ratio=0.8, disable_hybrid_swa_memory=False, radix_eviction_policy='lru', device='cuda', tp_size=1, pp_size=1, pp_max_micro_batch_size=None, stream_interval=1, stream_output=False, random_seed=563319386, constrained_json_whitespace_pattern=None, constrained_json_disable_any_whitespace=False, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, sleep_on_idle=False, log_level='info', log_level_http=None, log_requests=False, log_requests_level=2, crash_dump_folder=None, show_time_cost=False, enable_metrics=False, enable_metrics_for_all_schedulers=False, tokenizer_metrics_custom_labels_header='x-custom-labels', tokenizer_metrics_allowed_custom_labels=None, bucket_time_to_first_token=None, bucket_inter_token_latency=None, bucket_e2e_request_latency=None, collect_tokens_histogram=False, prompt_tokens_buckets=None, generation_tokens_buckets=None, gc_warning_threshold_secs=0.0, decode_log_interval=40, enable_request_time_stats_logging=False, kv_events_config=None, enable_trace=False, otlp_traces_endpoint='localhost:4317', api_key=None, served_model_name='/shared/public/elr-models/Qwen/Qwen3-4B/1cfa9a7208912126459214e8b04321603b3df60c/', weight_version='default', chat_template=None, completion_template=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, tool_call_parser=None, tool_server=None, sampling_defaults='model', dp_size=1, load_balance_method='round_robin', load_watch_interval=0.1, prefill_round_robin_balance=False, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', preferred_sampling_params=None, enable_lora=True, max_lora_rank=None, lora_target_modules=None, lora_paths=[LoRARef(lora_id='f19969c706c94295baadf7d6147d7cbc', lora_name='fr-lora', lora_path='/shared/public/sharing/candidate-evaluation-llm/i18n/experiments/peft/fe27f24518f194fb2b21/2025-11-04T06-45-20/QWEN3_4B_lora_fr', pinned=False), LoRARef(lora_id='0423df616c6e4685ba03b9ecb7cbdab9', lora_name='de-lora', lora_path='/shared/public/sharing/candidate-evaluation-llm/i18n/experiments/peft/fe27f24518f194fb2b21/2025-11-04T07-25-08/QWEN3_4B_lora_de', pinned=False)], max_loaded_loras=None, max_loras_per_batch=4, 
lora_eviction_policy='lru', lora_backend='csgmv', max_lora_chunk_size=16, attention_backend='fa3', decode_attention_backend=None, prefill_attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', mm_attention_backend=None, nsa_prefill_backend='flashmla_sparse', nsa_decode_backend='fa3', speculative_algorithm='NGRAM', speculative_draft_model_path=None, speculative_draft_model_revision=None, speculative_draft_load_format=None, speculative_num_steps=None, speculative_eagle_topk=10, speculative_num_draft_tokens=5, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, speculative_attention_mode='prefill', speculative_moe_runner_backend=None, speculative_ngram_min_match_window_size=2, speculative_ngram_max_match_window_size=15, speculative_ngram_min_bfs_breadth=1, speculative_ngram_max_bfs_breadth=10, speculative_ngram_match_type='BFS', speculative_ngram_branch_length=18, speculative_ngram_capacity=10000000, ep_size=1, moe_a2a_backend='none', moe_runner_backend='auto', flashinfer_mxfp4_moe_precision='default', enable_flashinfer_allreduce_fusion=False, deepep_mode='auto', ep_num_redundant_experts=0, ep_dispatch_algorithm='static', init_expert_location='trivial', enable_eplb=False, eplb_algorithm='auto', eplb_rebalance_num_iterations=1000, eplb_rebalance_layers_per_chunk=None, eplb_min_rebalancing_utilization_threshold=1.0, expert_distribution_recorder_mode=None, expert_distribution_recorder_buffer_size=1000, enable_expert_distribution_metrics=False, deepep_config=None, moe_dense_tp_size=None, elastic_ep_backend=None, mooncake_ib_device=None, max_mamba_cache_size=None, mamba_ssm_dtype='float32', mamba_full_memory_ratio=0.9, enable_hierarchical_cache=False, hicache_ratio=2.0, hicache_size=0, hicache_write_policy='write_through', hicache_io_backend='kernel', hicache_mem_layout='layer_first', hicache_storage_backend=None, hicache_storage_prefetch_policy='best_effort', hicache_storage_backend_extra_config=None, enable_lmcache=False, kt_amx_weight_path=None, kt_amx_method='AMXINT4', kt_cpuinfer=None, kt_threadpool_count=2, kt_num_gpu_experts=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, cpu_offload_gb=0, offload_group_size=-1, offload_num_in_group=1, offload_prefetch_step=1, offload_mode='cpu', multi_item_scoring_delimiter=None, disable_radix_cache=False, cuda_graph_max_bs=256, cuda_graph_bs=[1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 40, 44, 48, 52, 56, 60, 64, 72, 80, 88, 96, 104, 112, 120, 128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248, 256], disable_cuda_graph=False, disable_cuda_graph_padding=False, enable_profile_cuda_graph=False, enable_cudagraph_gc=False, enable_nccl_nvls=False, enable_symm_mem=False, disable_flashinfer_cutlass_moe_fp4_allgather=False, enable_tokenizer_batch_encode=False, disable_tokenizer_batch_decode=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, enable_mscclpp=False, enable_torch_symm_mem=False, disable_overlap_schedule=True, enable_mixed_chunk=False, enable_dp_attention=False, enable_dp_lm_head=False, enable_two_batch_overlap=False, enable_single_batch_overlap=False, tbo_token_distribution_threshold=0.48, enable_torch_compile=False, enable_piecewise_cuda_graph=False, torch_compile_max_bs=32, piecewise_cuda_graph_max_tokens=4096, piecewise_cuda_graph_tokens=[4, 8, 12, 16, 20, 24, 28, 
32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240, 256, 288, 320, 352, 384, 416, 448, 480, 512, 640, 768, 896, 1024, 1152, 1280, 1408, 1536, 1664, 1792, 1920, 2048, 2176, 2304, 2432, 2560, 2688, 2816, 2944, 3072, 3200, 3328, 3456, 3584, 3712, 3840, 3968, 4096], piecewise_cuda_graph_compiler='eager', torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, triton_attention_split_tile_size=None, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, enable_weights_cpu_backup=False, allow_auto_truncate=False, enable_custom_logit_processor=False, flashinfer_mla_disable_ragged=False, disable_shared_experts_fusion=False, disable_chunked_prefix_cache=False, disable_fast_image_processor=False, keep_mm_feature_on_device=False, enable_return_hidden_states=False, scheduler_recv_interval=1, numa_node=None, enable_deterministic_inference=False, rl_on_policy_target=None, enable_dynamic_batch_tokenizer=False, dynamic_batch_tokenizer_batch_size=32, dynamic_batch_tokenizer_batch_timeout=0.002, debug_tensor_dump_output_folder=None, debug_tensor_dump_layers=-1, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, disaggregation_mode='null', disaggregation_transfer_backend='mooncake', disaggregation_bootstrap_port=8998, disaggregation_decode_tp=None, disaggregation_decode_dp=None, disaggregation_prefill_pp=1, disaggregation_ib_device=None, disaggregation_decode_enable_offload_kvcache=False, num_reserved_decode_tokens=512, disaggregation_decode_polling_interval=1, custom_weight_loader=[], weight_loader_disable_mmap=False, remote_instance_weight_loader_seed_instance_ip=None, remote_instance_weight_loader_seed_instance_service_port=None, remote_instance_weight_loader_send_weights_group_ports=None, enable_pdmux=False, pdmux_config_path=None, sm_group_num=8, mm_max_concurrent_calls=32, mm_per_request_timeout=10.0, decrypted_config_file=None, decrypted_draft_config_file=None)
[2025-11-06 01:18:19] Using default HuggingFace chat template with detected content format: string
[2025-11-06 01:18:25] INFO trace.py:52: opentelemetry package is not installed, tracing disabled
[2025-11-06 01:18:25] INFO trace.py:52: opentelemetry package is not installed, tracing disabled
[2025-11-06 01:18:26] Init torch distributed begin.
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[2025-11-06 01:18:26] Init torch distributed ends. mem usage=0.00 GB
[2025-11-06 01:18:26] MOE_RUNNER_BACKEND is not initialized, the backend will be automatically selected
[2025-11-06 01:18:27] Load weight begin. avail mem=11.31 GB
Loading safetensors checkpoint shards:   0% Completed | 0/3 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  33% Completed | 1/3 [00:00<00:00,  2.32it/s]
Loading safetensors checkpoint shards:  67% Completed | 2/3 [00:00<00:00,  3.83it/s]
Loading safetensors checkpoint shards: 100% Completed | 3/3 [00:01<00:00,  2.91it/s]
Loading safetensors checkpoint shards: 100% Completed | 3/3 [00:01<00:00,  2.96it/s]

[2025-11-06 01:18:29] Load weight end. type=Qwen3ForCausalLM, dtype=torch.bfloat16, avail mem=3.58 GB, mem usage=7.72 GB.
[2025-11-06 01:18:29] Using csgmv as backend of LoRA kernels.
Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 147.41it/s]

Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 149.70it/s]

[2025-11-06 01:18:29] Using KV cache dtype: torch.bfloat16
[2025-11-06 01:18:29] KV Cache is allocated. #tokens: 12405, K size: 0.85 GB, V size: 0.85 GB
[2025-11-06 01:18:29] Memory pool end. avail mem=1.72 GB
[2025-11-06 01:18:29] Capture cuda graph begin. This can take up to several minutes. avail mem=1.63 GB
[2025-11-06 01:18:29] Capture cuda graph bs [1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 40, 44, 48]
Capturing batches (bs=48 avail_mem=1.43 GB):   0%|                                                      | 0/23 [00:00<?, ?it/s]
[2025-11-06 01:18:29] Scheduler hit an exception: Traceback (most recent call last):
  File "/home/jobuser/tlatest/lib/python3.10/site-packages/sglang/srt/managers/scheduler.py", line 2791, in run_scheduler_process
    scheduler = Scheduler(
  File "/home/jobuser/tlatest/lib/python3.10/site-packages/sglang/srt/managers/scheduler.py", line 319, in __init__
    self.tp_worker = TpModelWorker(
  File "/home/jobuser/tlatest/lib/python3.10/site-packages/sglang/srt/managers/tp_worker.py", line 237, in __init__
    self._model_runner = ModelRunner(
  File "/home/jobuser/tlatest/lib/python3.10/site-packages/sglang/srt/model_executor/model_runner.py", line 322, in __init__
    self.initialize(min_per_gpu_memory)
  File "/home/jobuser/tlatest/lib/python3.10/site-packages/sglang/srt/model_executor/model_runner.py", line 479, in initialize
    self.init_device_graphs()
  File "/home/jobuser/tlatest/lib/python3.10/site-packages/sglang/srt/model_executor/model_runner.py", line 1995, in init_device_graphs
    self.graph_runner = graph_runners[self.device](self)
  File "/home/jobuser/tlatest/lib/python3.10/site-packages/sglang/srt/model_executor/cuda_graph_runner.py", line 381, in __init__
    self.capture()
  File "/home/jobuser/tlatest/lib/python3.10/site-packages/sglang/srt/model_executor/cuda_graph_runner.py", line 500, in capture
    ) = self.capture_one_batch_size(bs, forward)
  File "/home/jobuser/tlatest/lib/python3.10/site-packages/sglang/srt/model_executor/cuda_graph_runner.py", line 646, in capture_one_batch_size
    self.model_runner.lora_manager.prepare_lora_batch(forward_batch)
  File "/home/jobuser/tlatest/lib/python3.10/site-packages/sglang/srt/lora/lora_manager.py", line 287, in prepare_lora_batch
    self.lora_backend.prepare_lora_batch(
  File "/home/jobuser/tlatest/lib/python3.10/site-packages/sglang/srt/lora/backend/chunked_backend.py", line 173, in prepare_lora_batch
    permutation, weight_indices_reordered = ChunkedSgmvLoRABackend._get_permutation(
  File "/home/jobuser/tlatest/lib/python3.10/site-packages/sglang/srt/lora/backend/chunked_backend.py", line 266, in _get_permutation
    torch.tensor(
  File "/home/jobuser/tlatest/lib/python3.10/site-packages/torch/utils/_device.py", line 103, in __torch_function__
    return func(*args, **kwargs)
TypeError: 'NoneType' object cannot be interpreted as an integer

[2025-11-06 01:18:29] Received sigquit from a child process. It usually means the child failed.
Killed

Environment

0.5.4.post1

(tsgl) jobuser [ ~ ]$ python -m sglang.check_env
Python: 3.10.14 (main, Jul 14 2024, 22:24:12) [GCC 11.2.0]
CUDA available: True
GPU 0: NVIDIA H100 80GB HBM3
GPU 0 Compute Capability: 9.0
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 12.6, V12.6.77
CUDA Driver Version: 550.163.01
PyTorch: 2.8.0+cu128
sglang: 0.5.4.post1
sgl_kernel: 0.3.16.post4
flashinfer_python: 0.4.1
triton: 3.4.0
transformers: 4.57.1
torchao: 0.9.0
numpy: 2.2.6
aiohttp: 3.13.2
fastapi: 0.121.0
hf_transfer: 0.1.9
huggingface_hub: 0.36.0
interegular: 0.3.3
modelscope: 1.31.0
orjson: 3.11.4
outlines: 0.1.11
packaging: 25.0
psutil: 7.1.3
pydantic: 2.12.4
python-multipart: 0.0.20
pyzmq: 27.1.0
uvicorn: 0.38.0
uvloop: 0.22.1
vllm: Module Not Found
xgrammar: 0.1.25
openai: 1.99.1
tiktoken: 0.12.0
anthropic: 0.72.0
litellm: Module Not Found
decord2: 2.0.0
NVIDIA Topology: 
        GPU0    NIC0    NIC1    NIC2    NIC3    NIC4    NIC5    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      NODE    PHB     PHB     NODE    SYS     SYS     0-63,128-191    0               N/A
NIC0    NODE     X      NODE    NODE    NODE    SYS     SYS
NIC1    PHB     NODE     X      PIX     NODE    SYS     SYS
NIC2    PHB     NODE    PIX      X      NODE    SYS     SYS
NIC3    NODE    NODE    NODE    NODE     X      SYS     SYS
NIC4    SYS     SYS     SYS     SYS     SYS      X      NODE
NIC5    SYS     SYS     SYS     SYS     SYS     NODE     X 

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

NIC Legend:

  NIC0: mlx5_0
  NIC1: mlx5_1
  NIC2: mlx5_2
  NIC3: mlx5_3
  NIC4: mlx5_4
  NIC5: mlx5_5


ulimit soft: 10000000

0.5.4.post3

python -m sglang.check_env
Python: 3.10.14 (main, Jul 14 2024, 22:24:12) [GCC 11.2.0]
CUDA available: True
GPU 0: NVIDIA H100 80GB HBM3
GPU 0 Compute Capability: 9.0
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 12.6, V12.6.77
CUDA Driver Version: 550.163.01
PyTorch: 2.8.0+cu128
sglang: 0.5.4.post3
sgl_kernel: 0.3.16.post4
flashinfer_python: 0.5.0
flashinfer_cubin: 0.5.0
flashinfer_jit_cache: Module Not Found
triton: 3.4.0
transformers: 4.57.1
torchao: 0.9.0
numpy: 2.2.6
aiohttp: 3.13.2
fastapi: 0.121.0
hf_transfer: 0.1.9
huggingface_hub: 0.36.0
interegular: 0.3.3
modelscope: 1.31.0
orjson: 3.11.4
outlines: 0.1.11
packaging: 25.0
psutil: 7.1.3
pydantic: 2.12.4
python-multipart: 0.0.20
pyzmq: 27.1.0
uvicorn: 0.38.0
uvloop: 0.22.1
vllm: Module Not Found
xgrammar: 0.1.25
openai: 2.6.1
tiktoken: 0.12.0
anthropic: 0.72.0
litellm: Module Not Found
decord2: 2.0.0
NVIDIA Topology: 
        GPU0    NIC0    NIC1    NIC2    NIC3    NIC4    NIC5    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      NODE    PHB     PHB     NODE    SYS     SYS     0-63,128-191    0               N/A
NIC0    NODE     X      NODE    NODE    NODE    SYS     SYS
NIC1    PHB     NODE     X      PIX     NODE    SYS     SYS
NIC2    PHB     NODE    PIX      X      NODE    SYS     SYS
NIC3    NODE    NODE    NODE    NODE     X      SYS     SYS
NIC4    SYS     SYS     SYS     SYS     SYS      X      NODE
NIC5    SYS     SYS     SYS     SYS     SYS     NODE     X 

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

NIC Legend:

  NIC0: mlx5_0
  NIC1: mlx5_1
  NIC2: mlx5_2
  NIC3: mlx5_3
  NIC4: mlx5_4
  NIC5: mlx5_5


ulimit soft: 10000000
