Support for Digest Auth. Little "hacky" due to Bug (?) in Mono (for Android)

This commit is contained in:
PhilippC 2013-05-11 22:12:13 +02:00
parent 8238f9f76f
commit 7a065a9fdc

View File

@ -119,13 +119,6 @@ namespace KeePassLib.Serialization
internal static void ConfigureWebClient(WebClient wc) internal static void ConfigureWebClient(WebClient wc)
{ {
// Not implemented and ignored in Mono < 2.10
try
{
wc.CachePolicy = new RequestCachePolicy(RequestCacheLevel.NoCacheNoStore);
}
catch(NotImplementedException) { }
catch(Exception) { Debug.Assert(false); }
try try
{ {
@ -189,7 +182,7 @@ namespace KeePassLib.Serialization
ValidateServerCertificate; ValidateServerCertificate;
} }
private static IOWebClient CreateWebClient(IOConnectionInfo ioc) private static IOWebClient CreateWebClient(IOConnectionInfo ioc, bool digestAuth)
{ {
PrepareWebAccess(); PrepareWebAccess();
@ -197,14 +190,42 @@ namespace KeePassLib.Serialization
ConfigureWebClient(wc); ConfigureWebClient(wc);
if ((ioc.UserName.Length > 0) || (ioc.Password.Length > 0)) if ((ioc.UserName.Length > 0) || (ioc.Password.Length > 0))
{
//set the credentials without a cache (in case the cache below fails:
wc.Credentials = new NetworkCredential(ioc.UserName, ioc.Password); wc.Credentials = new NetworkCredential(ioc.UserName, ioc.Password);
if (digestAuth)
{
//try to use the credential cache to access with Digest support:
try
{
var credentialCache = new CredentialCache();
credentialCache.Add(
new Uri(new Uri(ioc.Path).GetLeftPart(UriPartial.Authority)),
"Digest",
new NetworkCredential(ioc.UserName, ioc.Password)
);
wc.Credentials = credentialCache;
} catch (NotImplementedException e)
{
Android.Util.Log.Debug("DEBUG", e.ToString());
} catch (Exception e)
{
Android.Util.Log.Debug("DEBUG", e.ToString());
Debug.Assert(false);
}
}
}
else if(NativeLib.IsUnix()) // Mono requires credentials else if(NativeLib.IsUnix()) // Mono requires credentials
wc.Credentials = new NetworkCredential("anonymous", string.Empty); wc.Credentials = new NetworkCredential("anonymous", string.Empty);
return wc; return wc;
} }
private static WebRequest CreateWebRequest(IOConnectionInfo ioc) private static WebRequest CreateWebRequest(IOConnectionInfo ioc, bool digestAuth)
{ {
PrepareWebAccess(); PrepareWebAccess();
@ -212,7 +233,21 @@ namespace KeePassLib.Serialization
ConfigureWebRequest(req); ConfigureWebRequest(req);
if((ioc.UserName.Length > 0) || (ioc.Password.Length > 0)) if((ioc.UserName.Length > 0) || (ioc.Password.Length > 0))
{
req.Credentials = new NetworkCredential(ioc.UserName, ioc.Password); req.Credentials = new NetworkCredential(ioc.UserName, ioc.Password);
if (digestAuth)
{
var credentialCache = new CredentialCache();
credentialCache.Add(
new Uri(new Uri(ioc.Path).GetLeftPart(UriPartial.Authority)), // request url's host
"Digest", // authentication type
new NetworkCredential(ioc.UserName, ioc.Password) // credentials
);
req.Credentials = credentialCache;
}
}
else if(NativeLib.IsUnix()) // Mono requires credentials else if(NativeLib.IsUnix()) // Mono requires credentials
req.Credentials = new NetworkCredential("anonymous", string.Empty); req.Credentials = new NetworkCredential("anonymous", string.Empty);
@ -224,12 +259,24 @@ namespace KeePassLib.Serialization
if (StrUtil.IsDataUri(ioc.Path)) if (StrUtil.IsDataUri(ioc.Path))
{ {
byte[] pbData = StrUtil.DataUriToData(ioc.Path); byte[] pbData = StrUtil.DataUriToData(ioc.Path);
if(pbData != null) return new MemoryStream(pbData, false); if (pbData != null)
return new MemoryStream(pbData, false);
} }
if(ioc.IsLocalFile()) return OpenReadLocal(ioc); if (ioc.IsLocalFile())
return OpenReadLocal(ioc);
try
{
return CreateWebClient(ioc, false).OpenRead(new Uri(ioc.Path));
} catch (WebException ex)
{
if ((ex.Response is HttpWebResponse) && (((HttpWebResponse)ex.Response).StatusCode == HttpStatusCode.Unauthorized))
return CreateWebClient(ioc, true).OpenRead(new Uri(ioc.Path));
else
throw ex;
}
return CreateWebClient(ioc).OpenRead(new Uri(ioc.Path));
} }
#else #else
public static Stream OpenRead(IOConnectionInfo ioc) public static Stream OpenRead(IOConnectionInfo ioc)
@ -248,20 +295,20 @@ namespace KeePassLib.Serialization
class UploadOnCloseMemoryStream: MemoryStream class UploadOnCloseMemoryStream: MemoryStream
{ {
System.Net.WebClient webClient; IOConnectionInfo ioc;
string method; string method;
Uri destinationFilePath; Uri destinationFilePath;
public UploadOnCloseMemoryStream(System.Net.WebClient _webClient, string _method, Uri _destinationFilePath) public UploadOnCloseMemoryStream(IOConnectionInfo _ioc, string _method, Uri _destinationFilePath)
{ {
this.webClient = _webClient; ioc = _ioc;
this.method = _method; this.method = _method;
this.destinationFilePath = _destinationFilePath; this.destinationFilePath = _destinationFilePath;
} }
public UploadOnCloseMemoryStream(System.Net.WebClient _webClient, Uri _destinationFilePath) public UploadOnCloseMemoryStream(IOConnectionInfo _ioc, Uri _destinationFilePath)
{ {
this.webClient = _webClient; this.ioc = _ioc;
this.method = null; this.method = null;
this.destinationFilePath = _destinationFilePath; this.destinationFilePath = _destinationFilePath;
} }
@ -269,6 +316,21 @@ namespace KeePassLib.Serialization
public override void Close() public override void Close()
{ {
base.Close(); base.Close();
try
{
uploadData(IOConnection.CreateWebClient(ioc, false));
} catch (WebException ex)
{
if ((ex.Response is HttpWebResponse) && (((HttpWebResponse) ex.Response).StatusCode == HttpStatusCode.Unauthorized))
uploadData(IOConnection.CreateWebClient(ioc, true));
else
throw ex;
}
}
void uploadData(WebClient webClient)
{
if (method != null) if (method != null)
{ {
webClient.UploadData(destinationFilePath, method, this.ToArray()); webClient.UploadData(destinationFilePath, method, this.ToArray());
@ -276,7 +338,6 @@ namespace KeePassLib.Serialization
{ {
webClient.UploadData(destinationFilePath, this.ToArray()); webClient.UploadData(destinationFilePath, this.ToArray());
} }
} }
} }
@ -293,9 +354,9 @@ namespace KeePassLib.Serialization
if(NativeLib.IsUnix() && (uri.Scheme.Equals(Uri.UriSchemeHttp, if(NativeLib.IsUnix() && (uri.Scheme.Equals(Uri.UriSchemeHttp,
StrUtil.CaseIgnoreCmp) || uri.Scheme.Equals(Uri.UriSchemeHttps, StrUtil.CaseIgnoreCmp) || uri.Scheme.Equals(Uri.UriSchemeHttps,
StrUtil.CaseIgnoreCmp))) StrUtil.CaseIgnoreCmp)))
return new UploadOnCloseMemoryStream(CreateWebClient(ioc), WebRequestMethods.Http.Put, uri); return new UploadOnCloseMemoryStream(ioc, WebRequestMethods.Http.Put, uri);
return new UploadOnCloseMemoryStream(CreateWebClient(ioc), uri); return new UploadOnCloseMemoryStream(ioc, uri);
} }
#else #else
public static Stream OpenWrite(IOConnectionInfo ioc) public static Stream OpenWrite(IOConnectionInfo ioc)
@ -352,12 +413,31 @@ namespace KeePassLib.Serialization
return true; return true;
} }
delegate void DoWithRequest(WebRequest req);
static void RepeatWithDigestOnFail(IOConnectionInfo ioc, DoWithRequest f)
{
WebRequest req = CreateWebRequest(ioc, false);
try{
f(req);
}
catch (WebException ex)
{
if ((ex.Response is HttpWebResponse) && (((HttpWebResponse)ex.Response).StatusCode == HttpStatusCode.Unauthorized))
{
req = CreateWebRequest(ioc, true);
f(req);
}
}
}
public static void DeleteFile(IOConnectionInfo ioc) public static void DeleteFile(IOConnectionInfo ioc)
{ {
if(ioc.IsLocalFile()) { File.Delete(ioc.Path); return; } if(ioc.IsLocalFile()) { File.Delete(ioc.Path); return; }
#if !KeePassLibSD #if !KeePassLibSD
WebRequest req = CreateWebRequest(ioc); RepeatWithDigestOnFail(ioc, (WebRequest req) => {
if(req != null) if(req != null)
{ {
if(req is HttpWebRequest) req.Method = "DELETE"; if(req is HttpWebRequest) req.Method = "DELETE";
@ -371,6 +451,7 @@ namespace KeePassLib.Serialization
DisposeResponse(req.GetResponse(), true); DisposeResponse(req.GetResponse(), true);
} }
});
#endif #endif
} }
@ -388,8 +469,7 @@ namespace KeePassLib.Serialization
if(iocFrom.IsLocalFile()) { File.Move(iocFrom.Path, iocTo.Path); return; } if(iocFrom.IsLocalFile()) { File.Move(iocFrom.Path, iocTo.Path); return; }
#if !KeePassLibSD #if !KeePassLibSD
WebRequest req = CreateWebRequest(iocFrom); RepeatWithDigestOnFail(iocFrom, (WebRequest req)=> { if(req != null)
if(req != null)
{ {
if(req is HttpWebRequest) if(req is HttpWebRequest)
{ {
@ -415,6 +495,8 @@ namespace KeePassLib.Serialization
DisposeResponse(req.GetResponse(), true); DisposeResponse(req.GetResponse(), true);
} }
});
#endif #endif
// using(Stream sIn = IOConnection.OpenRead(iocFrom)) // using(Stream sIn = IOConnection.OpenRead(iocFrom))
@ -435,9 +517,11 @@ namespace KeePassLib.Serialization
{ {
try try
{ {
WebRequest req = CreateWebRequest(ioc); RepeatWithDigestOnFail(ioc, (WebRequest req)=> {
req.Method = strMethod; req.Method = strMethod;
DisposeResponse(req.GetResponse(), true); DisposeResponse(req.GetResponse(), true);
});
} }
catch(Exception) { return false; } catch(Exception) { return false; }