using System;
using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Security.Claims;
using System.Text.Json;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Components.Authorization;

namespace Client
{
    /// <summary>
    /// <see cref="AuthenticationStateProvider"/> that derives the current user from a JWT handed to
    /// <see cref="MarkUserAsAuthenticated"/>. The token's payload is decoded locally; its signature is
    /// NOT validated here.
    /// </summary>
    public sealed class CustomAuthStateProvider : AuthenticationStateProvider
    {
        // NOTE(review): static, so the token is shared by every instance of this provider in the
        // process. That is harmless for a single-user client, but would leak one user's token to
        // others if this type were ever hosted server-side — confirm the hosting model / DI setup.
        private static string _token;

        private readonly HttpClient _http;

        /// <param name="http">
        /// Client whose default Authorization header is set to "Bearer &lt;token&gt;" whenever a
        /// non-blank token produces an authenticated state.
        /// </param>
        public CustomAuthStateProvider(HttpClient http)
        {
            _http = http;
        }

        /// <summary>Returns the state for the currently stored token (anonymous if there is none).</summary>
        /// <exception cref="FormatException">The stored token's payload is not valid base64url.</exception>
        public override Task<AuthenticationState> GetAuthenticationStateAsync()
        {
            return Task.FromResult(BuildAuthenticationState(_token));
        }

        /// <summary>
        /// Stores <paramref name="token"/> and notifies subscribers that the authentication state changed.
        /// A null/blank token yields an anonymous user and clears the HttpClient's Authorization header.
        /// </summary>
        /// <param name="token">Compact-serialized JWT (header.payload.signature), or null/blank to sign out.</param>
        /// <exception cref="FormatException">
        /// The token is malformed; in that case the previously stored token is left unchanged.
        /// </exception>
        public void MarkUserAsAuthenticated(string token)
        {
            // Build first so a malformed token throws before it replaces the current one; otherwise
            // every later GetAuthenticationStateAsync call would keep throwing.
            var state = BuildAuthenticationState(token);
            _token = token;

            if (string.IsNullOrWhiteSpace(token))
            {
                // Signing out must not leave the previous bearer token attached to outgoing requests.
                _http.DefaultRequestHeaders.Authorization = null;
            }

            NotifyAuthenticationStateChanged(Task.FromResult(state));
        }

        /// <summary>
        /// Builds the authentication state for <paramref name="token"/>. For a non-blank token, it also
        /// sets the HttpClient's default Authorization header to that bearer token.
        /// </summary>
        private AuthenticationState BuildAuthenticationState(string token)
        {
            if (string.IsNullOrWhiteSpace(token))
            {
                return new AuthenticationState(new ClaimsPrincipal(new ClaimsIdentity()));
            }

            // Parse before touching the header so a malformed token leaves the client unchanged.
            // "unique_name" is used as the name claim type; a null role type means ClaimsIdentity's
            // default (ClaimTypes.Role).
            var identity = new ClaimsIdentity(ParseClaimsFromJwt(token), "jwt", "unique_name", null);
            _http.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", token);
            return new AuthenticationState(new ClaimsPrincipal(identity));
        }

        /// <summary>
        /// Decodes the payload (second) segment of <paramref name="jwt"/> into claims. A
        /// <see cref="ClaimTypes.Role"/> entry holding a JSON array is expanded into one role claim per
        /// element. Entries whose value is JSON null are skipped.
        /// </summary>
        /// <exception cref="FormatException">
        /// The token has no payload segment, or the payload is not valid base64url.
        /// </exception>
        private static IEnumerable<Claim> ParseClaimsFromJwt(string jwt)
        {
            var segments = jwt.Split('.');
            if (segments.Length < 2)
            {
                throw new FormatException("The token is not a JWT: it has no payload segment.");
            }

            var jsonBytes = ParseBase64WithoutPadding(segments[1]);
            // Values deserialize as JsonElement (or null for a JSON null).
            var keyValuePairs = JsonSerializer.Deserialize<Dictionary<string, object>>(jsonBytes)
                ?? new Dictionary<string, object>();

            var claims = new List<Claim>();

            if (keyValuePairs.TryGetValue(ClaimTypes.Role, out var roles))
            {
                if (roles is JsonElement { ValueKind: JsonValueKind.Array } roleArray)
                {
                    claims.AddRange(roleArray.EnumerateArray()
                        .Select(role => new Claim(ClaimTypes.Role, role.ToString())));
                }
                else if (roles != null)
                {
                    claims.Add(new Claim(ClaimTypes.Role, roles.ToString()));
                }

                keyValuePairs.Remove(ClaimTypes.Role);
            }

            // Claim's constructor rejects a null value, so JSON-null entries are dropped.
            claims.AddRange(keyValuePairs
                .Where(kvp => kvp.Value != null)
                .Select(kvp => new Claim(kvp.Key, kvp.Value.ToString())));

            return claims;
        }

        /// <summary>
        /// Decodes a base64url string (RFC 4648 §5) as used in JWT segments. It maps '-'/'_' back to
        /// '+'/'/' and restores the '=' padding that JWTs omit.
        /// </summary>
        /// <exception cref="FormatException">The input is not valid base64url.</exception>
        private static byte[] ParseBase64WithoutPadding(string base64)
        {
            var normalized = base64.Replace('-', '+').Replace('_', '/');
            switch (normalized.Length % 4)
            {
                case 2:
                    normalized += "==";
                    break;
                case 3:
                    normalized += "=";
                    break;
            }

            return Convert.FromBase64String(normalized);
        }
    }
}